Implement choices help for Enum args
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
import typing
|
||||
from argparse import ArgumentParser
|
||||
from dataclasses import MISSING, Field, dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
SHORTOPT = "__argclass__shortopt"
|
||||
CHOICES = "__argclass__choices"
|
||||
@@ -19,26 +20,25 @@ def _make_gnu_option(name: str):
|
||||
|
||||
|
||||
def _decide_default(field_: Field):
|
||||
arg_cfg = {}
|
||||
defaults_cfg = {}
|
||||
if field_.default != MISSING:
|
||||
arg_cfg["default"] = field_.default
|
||||
defaults_cfg["default"] = field_.default
|
||||
elif field_.default_factory != MISSING:
|
||||
arg_cfg["default"] = field_.default_factory()
|
||||
defaults_cfg["default"] = field_.default_factory()
|
||||
else:
|
||||
arg_cfg["required"] = True
|
||||
return arg_cfg
|
||||
defaults_cfg["required"] = True
|
||||
return defaults_cfg
|
||||
|
||||
|
||||
def _get_choices(field_: Field):
|
||||
arg_cfg = {}
|
||||
try:
|
||||
arg_cfg["choices"] = field_.metadata[CHOICES]
|
||||
except KeyError:
|
||||
pass
|
||||
return arg_cfg
|
||||
def _get_choices(field_: Field, base_type):
|
||||
if issubclass(base_type, Enum):
|
||||
return {"choices": [e.value for e in base_type]}
|
||||
elif CHOICES in field_.metadata:
|
||||
return {"choices": field_.metadata[CHOICES]}
|
||||
return {}
|
||||
|
||||
|
||||
def _compute_arg_names(name, field_):
|
||||
def _compute_arg_names(name, field_) -> list[str]:
|
||||
names = [_make_gnu_option(name)] # long option
|
||||
if SHORTOPT in field_.metadata:
|
||||
if len(field_.metadata[SHORTOPT]) != 1:
|
||||
@@ -50,8 +50,8 @@ def _compute_arg_names(name, field_):
|
||||
|
||||
|
||||
def _prepare_bool(ap: ArgumentParser, name, field_):
|
||||
arg_cfg = _decide_default(field_)
|
||||
required = arg_cfg.get("required", False)
|
||||
defaults_cfg = _decide_default(field_)
|
||||
required = defaults_cfg.get("required", False)
|
||||
bool_parser = ap.add_mutually_exclusive_group(required=required)
|
||||
bool_parser.add_argument(
|
||||
_make_gnu_option(name), action="store_true", dest=name
|
||||
@@ -60,34 +60,39 @@ def _prepare_bool(ap: ArgumentParser, name, field_):
|
||||
_make_gnu_option(f"no_{name}"), action="store_false", dest=name
|
||||
)
|
||||
if not required:
|
||||
ap.set_defaults(**{name: arg_cfg["default"]})
|
||||
ap.set_defaults(**{name: defaults_cfg["default"]})
|
||||
|
||||
|
||||
def _prepare_list_cfg(name: str, field_: Field):
|
||||
arg_cfg = {
|
||||
**_decide_default(field_),
|
||||
**_get_choices(field_),
|
||||
}
|
||||
defaults_cfg = _decide_default(field_)
|
||||
|
||||
subtype = typing.get_args(field_.type)
|
||||
if not subtype:
|
||||
raise ArgclassError(
|
||||
f"List field {name} must have a subtype. Did you mean {name}: list[str]?"
|
||||
f"List field {name} must have a subtype. "
|
||||
"Did you mean {name}: list[str]?"
|
||||
)
|
||||
arg_cfg["type"] = subtype[0]
|
||||
choices_cfg = _get_choices(field_, subtype[0])
|
||||
|
||||
if field_.metadata.get(ALLOW_EMPTY, False):
|
||||
arg_cfg["nargs"] = "*"
|
||||
nargs = "*"
|
||||
else:
|
||||
arg_cfg["nargs"] = "+"
|
||||
return arg_cfg
|
||||
nargs = "+"
|
||||
|
||||
return {
|
||||
**defaults_cfg,
|
||||
**choices_cfg,
|
||||
"nargs": nargs,
|
||||
"type": subtype[0],
|
||||
}
|
||||
|
||||
|
||||
def _prepare_trivial_cfg(name, field_):
|
||||
def _prepare_trivial_cfg(_, field_: Field):
|
||||
arg_cfg = {
|
||||
**_decide_default(field_),
|
||||
**_get_choices(field_),
|
||||
**_get_choices(field_, field_.type),
|
||||
"type": field_.type,
|
||||
}
|
||||
arg_cfg["type"] = field_.type
|
||||
return arg_cfg
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user