package input import ( "fmt" "strings" "tasksquire/common" "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/textinput" "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/huh" "github.com/charmbracelet/huh/accessibility" "github.com/charmbracelet/lipgloss" ) // MultiSelect is a form multi-select field. type MultiSelect struct { common *common.Common value *[]string key string // customization title string description string options []Option[string] filterable bool filteredOptions []Option[string] limit int height int // error handling validate func([]string) error err error // state cursor int focused bool filtering bool filter textinput.Model viewport viewport.Model // options width int accessible bool theme *huh.Theme keymap huh.MultiSelectKeyMap // new hasNewOption bool newInput textinput.Model newInputActive bool } // NewMultiSelect returns a new multi-select field. func NewMultiSelect(common *common.Common) *MultiSelect { filter := textinput.New() filter.Prompt = "/" newInput := textinput.New() newInput.Prompt = "New: " return &MultiSelect{ common: common, options: []Option[string]{}, value: new([]string), validate: func([]string) error { return nil }, filtering: false, filter: filter, newInput: newInput, newInputActive: false, } } // Value sets the value of the multi-select field. func (m *MultiSelect) Value(value *[]string) *MultiSelect { m.value = value for i, o := range m.options { for _, v := range *value { if o.Value == v { m.options[i].selected = true break } } } return m } // Key sets the key of the select field which can be used to retrieve the value // after submission. func (m *MultiSelect) Key(key string) *MultiSelect { m.key = key return m } // Title sets the title of the multi-select field. func (m *MultiSelect) Title(title string) *MultiSelect { m.title = title return m } // Description sets the description of the multi-select field. func (m *MultiSelect) Description(description string) *MultiSelect { m.description = description return m } // Options sets the options of the multi-select field. func (m *MultiSelect) Options(hasNewOption bool, options ...Option[string]) *MultiSelect { m.hasNewOption = hasNewOption if m.hasNewOption { newOption := []Option[string]{ {Key: "(new)", Value: ""}, } options = append(newOption, options...) } if len(options) <= 0 { return m } for i, o := range options { for _, v := range *m.value { if o.Value == v { options[i].selected = true break } } } m.options = options m.filteredOptions = options m.updateViewportHeight() return m } // Filterable sets the multi-select field as filterable. func (m *MultiSelect) Filterable(filterable bool) *MultiSelect { m.filterable = filterable return m } // Limit sets the limit of the multi-select field. func (m *MultiSelect) Limit(limit int) *MultiSelect { m.limit = limit return m } // Height sets the height of the multi-select field. func (m *MultiSelect) Height(height int) *MultiSelect { // What we really want to do is set the height of the viewport, but we // need a theme applied before we can calcualate its height. m.height = height m.updateViewportHeight() return m } // Validate sets the validation function of the multi-select field. func (m *MultiSelect) Validate(validate func([]string) error) *MultiSelect { m.validate = validate return m } // Error returns the error of the multi-select field. func (m *MultiSelect) Error() error { return m.err } // Skip returns whether the multiselect should be skipped or should be blocking. func (*MultiSelect) Skip() bool { return false } // Zoom returns whether the multiselect should be zoomed. func (*MultiSelect) Zoom() bool { return false } // Focus focuses the multi-select field. func (m *MultiSelect) Focus() tea.Cmd { m.focused = true return nil } // Blur blurs the multi-select field. func (m *MultiSelect) Blur() tea.Cmd { m.focused = false return nil } // KeyBinds returns the help message for the multi-select field. func (m *MultiSelect) KeyBinds() []key.Binding { return []key.Binding{ m.keymap.Toggle, m.keymap.Up, m.keymap.Down, m.keymap.Filter, m.keymap.SetFilter, m.keymap.ClearFilter, m.keymap.Prev, m.keymap.Submit, m.keymap.Next, } } // Init initializes the multi-select field. func (m *MultiSelect) Init() tea.Cmd { return nil } // Update updates the multi-select field. func (m *MultiSelect) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Enforce height on the viewport during update as we need themes to // be applied before we can calculate the height. m.updateViewportHeight() var cmd tea.Cmd if m.filtering { m.filter, cmd = m.filter.Update(msg) } if m.newInputActive { m.newInput, cmd = m.newInput.Update(msg) switch msg := msg.(type) { case tea.KeyMsg: switch { case key.Matches(msg, m.common.Keymap.Ok): newOptions := []Option[string]{} for _, item := range strings.Split(m.newInput.Value(), " ") { newOptions = append(newOptions, Option[string]{ Key: item, Value: item, selected: true, }) } m.options = append(m.options, newOptions...) filteredNewOptions := []Option[string]{} for _, item := range newOptions { if m.filterFunc(item.Key) { filteredNewOptions = append(filteredNewOptions, item) } } m.filteredOptions = append(m.filteredOptions, filteredNewOptions...) m.newInputActive = false m.newInput.SetValue("") m.newInput.Blur() case key.Matches(msg, m.common.Keymap.Back): m.newInputActive = false m.newInput.Blur() return m, SuppressBack() } } } switch msg := msg.(type) { case tea.KeyMsg: m.err = nil switch { case key.Matches(msg, m.keymap.Filter): m.setFilter(true) return m, m.filter.Focus() case key.Matches(msg, m.keymap.SetFilter) && m.filtering: if len(m.filteredOptions) <= 0 { m.filter.SetValue("") m.filteredOptions = m.options } m.setFilter(false) case key.Matches(msg, m.common.Keymap.Back) && m.filtering: m.filter.SetValue("") m.filteredOptions = m.options m.setFilter(false) case key.Matches(msg, m.keymap.ClearFilter): m.filter.SetValue("") m.filteredOptions = m.options m.setFilter(false) case key.Matches(msg, m.keymap.Up): if m.filtering && msg.String() == "k" { break } m.cursor = max(m.cursor-1, 0) if m.cursor < m.viewport.YOffset { m.viewport.SetYOffset(m.cursor) } case key.Matches(msg, m.keymap.Down): if m.filtering && msg.String() == "j" { break } m.cursor = min(m.cursor+1, len(m.filteredOptions)-1) if m.cursor >= m.viewport.YOffset+m.viewport.Height { m.viewport.LineDown(1) } case key.Matches(msg, m.keymap.GotoTop): if m.filtering { break } m.cursor = 0 m.viewport.GotoTop() case key.Matches(msg, m.keymap.GotoBottom): if m.filtering { break } m.cursor = len(m.filteredOptions) - 1 m.viewport.GotoBottom() case key.Matches(msg, m.keymap.HalfPageUp): m.cursor = max(m.cursor-m.viewport.Height/2, 0) m.viewport.HalfViewUp() case key.Matches(msg, m.keymap.HalfPageDown): m.cursor = min(m.cursor+m.viewport.Height/2, len(m.filteredOptions)-1) m.viewport.HalfViewDown() case key.Matches(msg, m.keymap.Toggle) && !m.filtering: if m.hasNewOption && m.cursor == 0 { m.newInputActive = true m.newInput.Focus() } else { for i, option := range m.options { if option.Key == m.filteredOptions[m.cursor].Key { if !m.options[m.cursor].selected && m.limit > 0 && m.numSelected() >= m.limit { break } selected := m.options[i].selected m.options[i].selected = !selected m.filteredOptions[m.cursor].selected = !selected m.finalize() } } } case key.Matches(msg, m.keymap.Prev): m.finalize() if m.err != nil { return m, nil } return m, huh.PrevField case key.Matches(msg, m.keymap.Next, m.keymap.Submit): m.finalize() if m.err != nil { return m, nil } return m, huh.NextField } if m.filtering { m.filteredOptions = m.options if m.filter.Value() != "" { m.filteredOptions = nil for _, option := range m.options { if m.filterFunc(option.Key) { m.filteredOptions = append(m.filteredOptions, option) } } } if len(m.filteredOptions) > 0 { m.cursor = min(m.cursor, len(m.filteredOptions)-1) m.viewport.SetYOffset(clamp(m.cursor, 0, len(m.filteredOptions)-m.viewport.Height)) } } } return m, cmd } // updateViewportHeight updates the viewport size according to the Height setting // on this multi-select field. func (m *MultiSelect) updateViewportHeight() { // If no height is set size the viewport to the number of options. if m.height <= 0 { m.viewport.Height = len(m.options) return } const minHeight = 1 m.viewport.Height = max(minHeight, m.height- lipgloss.Height(m.titleView())- lipgloss.Height(m.descriptionView())) } func (m *MultiSelect) numSelected() int { var count int for _, o := range m.options { if o.selected { count++ } } return count } func (m *MultiSelect) finalize() { *m.value = make([]string, 0) for _, option := range m.options { if option.selected { *m.value = append(*m.value, option.Value) } } m.err = m.validate(*m.value) } func (m *MultiSelect) activeStyles() *huh.FieldStyles { theme := m.theme if theme == nil { theme = huh.ThemeCharm() } if m.focused { return &theme.Focused } return &theme.Blurred } func (m *MultiSelect) titleView() string { if m.title == "" { return "" } var ( styles = m.activeStyles() sb = strings.Builder{} ) if m.filtering { sb.WriteString(m.filter.View()) } else if m.filter.Value() != "" { sb.WriteString(styles.Title.Render(m.title) + styles.Description.Render("/"+m.filter.Value())) } else { sb.WriteString(styles.Title.Render(m.title)) } if m.err != nil { sb.WriteString(styles.ErrorIndicator.String()) } return sb.String() } func (m *MultiSelect) descriptionView() string { return m.activeStyles().Description.Render(m.description) } func (m *MultiSelect) choicesView() string { var ( styles = m.activeStyles() c = styles.MultiSelectSelector.String() sb strings.Builder ) for i, option := range m.filteredOptions { if m.newInputActive && i == 0 { sb.WriteString(c) sb.WriteString(m.newInput.View()) sb.WriteString("\n") continue } else if m.cursor == i { sb.WriteString(c) } else { sb.WriteString(strings.Repeat(" ", lipgloss.Width(c))) } if m.filteredOptions[i].selected { sb.WriteString(styles.SelectedPrefix.String()) sb.WriteString(styles.SelectedOption.Render(option.Key)) } else { sb.WriteString(styles.UnselectedPrefix.String()) sb.WriteString(styles.UnselectedOption.Render(option.Key)) } if i < len(m.options)-1 { sb.WriteString("\n") } } for i := len(m.filteredOptions); i < len(m.options)-1; i++ { sb.WriteString("\n") } return sb.String() } // View renders the multi-select field. func (m *MultiSelect) View() string { styles := m.activeStyles() m.viewport.SetContent(m.choicesView()) var sb strings.Builder if m.title != "" { sb.WriteString(m.titleView()) sb.WriteString("\n") } if m.description != "" { sb.WriteString(m.descriptionView() + "\n") } sb.WriteString(m.viewport.View()) return styles.Base.Render(sb.String()) } func (m *MultiSelect) printOptions() { styles := m.activeStyles() var sb strings.Builder sb.WriteString(styles.Title.Render(m.title)) sb.WriteString("\n") for i, option := range m.options { if option.selected { sb.WriteString(styles.SelectedOption.Render(fmt.Sprintf("%d. %s %s", i+1, "✓", option.Key))) } else { sb.WriteString(fmt.Sprintf("%d. %s %s", i+1, " ", option.Key)) } sb.WriteString("\n") } fmt.Println(sb.String()) } // setFilter sets the filter of the select field. func (m *MultiSelect) setFilter(filter bool) { m.filtering = filter m.keymap.SetFilter.SetEnabled(filter) m.keymap.Filter.SetEnabled(!filter) m.keymap.Next.SetEnabled(!filter) m.keymap.Submit.SetEnabled(!filter) m.keymap.Prev.SetEnabled(!filter) m.keymap.ClearFilter.SetEnabled(!filter && m.filter.Value() != "") } // filterFunc returns true if the option matches the filter. func (m *MultiSelect) filterFunc(option string) bool { // XXX: remove diacritics or allow customization of filter function. return strings.Contains(strings.ToLower(option), strings.ToLower(m.filter.Value())) } // Run runs the multi-select field. func (m *MultiSelect) Run() error { if m.accessible { return m.runAccessible() } return huh.Run(m) } // runAccessible() runs the multi-select field in accessible mode. func (m *MultiSelect) runAccessible() error { m.printOptions() styles := m.activeStyles() var choice int for { fmt.Printf("Select up to %d options. 0 to continue.\n", m.limit) choice = accessibility.PromptInt("Select: ", 0, len(m.options)) if choice == 0 { m.finalize() err := m.validate(*m.value) if err != nil { fmt.Println(err) continue } break } if !m.options[choice-1].selected && m.limit > 0 && m.numSelected() >= m.limit { fmt.Printf("You can't select more than %d options.\n", m.limit) continue } m.options[choice-1].selected = !m.options[choice-1].selected if m.options[choice-1].selected { fmt.Printf("Selected: %s\n\n", m.options[choice-1].Key) } else { fmt.Printf("Deselected: %s\n\n", m.options[choice-1].Key) } m.printOptions() } var values []string for _, option := range m.options { if option.selected { *m.value = append(*m.value, option.Value) values = append(values, option.Key) } } fmt.Println(styles.SelectedOption.Render("Selected:", strings.Join(values, ", ")+"\n")) return nil } // WithTheme sets the theme of the multi-select field. func (m *MultiSelect) WithTheme(theme *huh.Theme) huh.Field { if m.theme != nil { return m } m.theme = theme m.filter.Cursor.Style = m.theme.Focused.TextInput.Cursor m.filter.PromptStyle = m.theme.Focused.TextInput.Prompt m.updateViewportHeight() return m } // WithKeyMap sets the keymap of the multi-select field. func (m *MultiSelect) WithKeyMap(k *huh.KeyMap) huh.Field { m.keymap = k.MultiSelect return m } // WithAccessible sets the accessible mode of the multi-select field. func (m *MultiSelect) WithAccessible(accessible bool) huh.Field { m.accessible = accessible return m } // WithWidth sets the width of the multi-select field. func (m *MultiSelect) WithWidth(width int) huh.Field { m.width = width return m } // WithHeight sets the height of the multi-select field. func (m *MultiSelect) WithHeight(height int) huh.Field { m.height = height return m } // WithPosition sets the position of the multi-select field. func (m *MultiSelect) WithPosition(p huh.FieldPosition) huh.Field { if m.filtering { return m } m.keymap.Prev.SetEnabled(!p.IsFirst()) m.keymap.Next.SetEnabled(!p.IsLast()) m.keymap.Submit.SetEnabled(p.IsLast()) return m } // GetKey returns the multi-select's key. func (m *MultiSelect) GetKey() string { return m.key } // GetValue returns the multi-select's value. func (m *MultiSelect) GetValue() any { return *m.value } func min(a, b int) int { if a < b { return a } return b } func max(a, b int) int { if a > b { return a } return b } func clamp(n, low, high int) int { if low > high { low, high = high, low } return min(high, max(low, n)) } type SuppressBackMsg struct{} func SuppressBack() tea.Cmd { return func() tea.Msg { return SuppressBackMsg{} } }