Implement choices help for Enum args

This commit is contained in:
2026-04-01 21:17:40 +02:00
parent 02ee0dab3d
commit fc8a192a07

View File

@@ -2,6 +2,7 @@
import typing
from argparse import ArgumentParser
from dataclasses import MISSING, Field, dataclass, field
from enum import Enum
SHORTOPT = "__argclass__shortopt"
CHOICES = "__argclass__choices"
@@ -19,26 +20,25 @@ def _make_gnu_option(name: str):
def _decide_default(field_: Field):
arg_cfg = {}
defaults_cfg = {}
if field_.default != MISSING:
arg_cfg["default"] = field_.default
defaults_cfg["default"] = field_.default
elif field_.default_factory != MISSING:
arg_cfg["default"] = field_.default_factory()
defaults_cfg["default"] = field_.default_factory()
else:
arg_cfg["required"] = True
return arg_cfg
defaults_cfg["required"] = True
return defaults_cfg
def _get_choices(field_: Field):
arg_cfg = {}
try:
arg_cfg["choices"] = field_.metadata[CHOICES]
except KeyError:
pass
return arg_cfg
def _get_choices(field_: Field, base_type):
if issubclass(base_type, Enum):
return {"choices": [e.value for e in base_type]}
elif CHOICES in field_.metadata:
return {"choices": field_.metadata[CHOICES]}
return {}
def _compute_arg_names(name, field_):
def _compute_arg_names(name, field_) -> list[str]:
names = [_make_gnu_option(name)] # long option
if SHORTOPT in field_.metadata:
if len(field_.metadata[SHORTOPT]) != 1:
@@ -50,8 +50,8 @@ def _compute_arg_names(name, field_):
def _prepare_bool(ap: ArgumentParser, name, field_):
arg_cfg = _decide_default(field_)
required = arg_cfg.get("required", False)
defaults_cfg = _decide_default(field_)
required = defaults_cfg.get("required", False)
bool_parser = ap.add_mutually_exclusive_group(required=required)
bool_parser.add_argument(
_make_gnu_option(name), action="store_true", dest=name
@@ -60,34 +60,39 @@ def _prepare_bool(ap: ArgumentParser, name, field_):
_make_gnu_option(f"no_{name}"), action="store_false", dest=name
)
if not required:
ap.set_defaults(**{name: arg_cfg["default"]})
ap.set_defaults(**{name: defaults_cfg["default"]})
def _prepare_list_cfg(name: str, field_: Field):
arg_cfg = {
**_decide_default(field_),
**_get_choices(field_),
}
defaults_cfg = _decide_default(field_)
subtype = typing.get_args(field_.type)
if not subtype:
raise ArgclassError(
f"List field {name} must have a subtype. Did you mean {name}: list[str]?"
f"List field {name} must have a subtype. "
"Did you mean {name}: list[str]?"
)
arg_cfg["type"] = subtype[0]
choices_cfg = _get_choices(field_, subtype[0])
if field_.metadata.get(ALLOW_EMPTY, False):
arg_cfg["nargs"] = "*"
nargs = "*"
else:
arg_cfg["nargs"] = "+"
return arg_cfg
nargs = "+"
return {
**defaults_cfg,
**choices_cfg,
"nargs": nargs,
"type": subtype[0],
}
def _prepare_trivial_cfg(name, field_):
def _prepare_trivial_cfg(_, field_: Field):
arg_cfg = {
**_decide_default(field_),
**_get_choices(field_),
**_get_choices(field_, field_.type),
"type": field_.type,
}
arg_cfg["type"] = field_.type
return arg_cfg