Implement choices help for Enum args

This commit is contained in:
2026-04-01 21:17:40 +02:00
parent 02ee0dab3d
commit fc8a192a07

View File

@@ -2,6 +2,7 @@
import typing import typing
from argparse import ArgumentParser from argparse import ArgumentParser
from dataclasses import MISSING, Field, dataclass, field from dataclasses import MISSING, Field, dataclass, field
from enum import Enum
SHORTOPT = "__argclass__shortopt" SHORTOPT = "__argclass__shortopt"
CHOICES = "__argclass__choices" CHOICES = "__argclass__choices"
@@ -19,26 +20,25 @@ def _make_gnu_option(name: str):
def _decide_default(field_: Field): def _decide_default(field_: Field):
arg_cfg = {} defaults_cfg = {}
if field_.default != MISSING: if field_.default != MISSING:
arg_cfg["default"] = field_.default defaults_cfg["default"] = field_.default
elif field_.default_factory != MISSING: elif field_.default_factory != MISSING:
arg_cfg["default"] = field_.default_factory() defaults_cfg["default"] = field_.default_factory()
else: else:
arg_cfg["required"] = True defaults_cfg["required"] = True
return arg_cfg return defaults_cfg
def _get_choices(field_: Field): def _get_choices(field_: Field, base_type):
arg_cfg = {} if issubclass(base_type, Enum):
try: return {"choices": [e.value for e in base_type]}
arg_cfg["choices"] = field_.metadata[CHOICES] elif CHOICES in field_.metadata:
except KeyError: return {"choices": field_.metadata[CHOICES]}
pass return {}
return arg_cfg
def _compute_arg_names(name, field_): def _compute_arg_names(name, field_) -> list[str]:
names = [_make_gnu_option(name)] # long option names = [_make_gnu_option(name)] # long option
if SHORTOPT in field_.metadata: if SHORTOPT in field_.metadata:
if len(field_.metadata[SHORTOPT]) != 1: if len(field_.metadata[SHORTOPT]) != 1:
@@ -50,8 +50,8 @@ def _compute_arg_names(name, field_):
def _prepare_bool(ap: ArgumentParser, name, field_): def _prepare_bool(ap: ArgumentParser, name, field_):
arg_cfg = _decide_default(field_) defaults_cfg = _decide_default(field_)
required = arg_cfg.get("required", False) required = defaults_cfg.get("required", False)
bool_parser = ap.add_mutually_exclusive_group(required=required) bool_parser = ap.add_mutually_exclusive_group(required=required)
bool_parser.add_argument( bool_parser.add_argument(
_make_gnu_option(name), action="store_true", dest=name _make_gnu_option(name), action="store_true", dest=name
@@ -60,34 +60,39 @@ def _prepare_bool(ap: ArgumentParser, name, field_):
_make_gnu_option(f"no_{name}"), action="store_false", dest=name _make_gnu_option(f"no_{name}"), action="store_false", dest=name
) )
if not required: if not required:
ap.set_defaults(**{name: arg_cfg["default"]}) ap.set_defaults(**{name: defaults_cfg["default"]})
def _prepare_list_cfg(name: str, field_: Field): def _prepare_list_cfg(name: str, field_: Field):
arg_cfg = { defaults_cfg = _decide_default(field_)
**_decide_default(field_),
**_get_choices(field_),
}
subtype = typing.get_args(field_.type) subtype = typing.get_args(field_.type)
if not subtype: if not subtype:
raise ArgclassError( raise ArgclassError(
f"List field {name} must have a subtype. Did you mean {name}: list[str]?" f"List field {name} must have a subtype. "
"Did you mean {name}: list[str]?"
) )
arg_cfg["type"] = subtype[0] choices_cfg = _get_choices(field_, subtype[0])
if field_.metadata.get(ALLOW_EMPTY, False): if field_.metadata.get(ALLOW_EMPTY, False):
arg_cfg["nargs"] = "*" nargs = "*"
else: else:
arg_cfg["nargs"] = "+" nargs = "+"
return arg_cfg
return {
**defaults_cfg,
**choices_cfg,
"nargs": nargs,
"type": subtype[0],
}
def _prepare_trivial_cfg(name, field_): def _prepare_trivial_cfg(_, field_: Field):
arg_cfg = { arg_cfg = {
**_decide_default(field_), **_decide_default(field_),
**_get_choices(field_), **_get_choices(field_, field_.type),
"type": field_.type,
} }
arg_cfg["type"] = field_.type
return arg_cfg return arg_cfg