diff --git a/src/argclass/__init__.py b/src/argclass/__init__.py index ddf86f8..5213372 100644 --- a/src/argclass/__init__.py +++ b/src/argclass/__init__.py @@ -2,6 +2,7 @@ import typing from argparse import ArgumentParser from dataclasses import MISSING, Field, dataclass, field +from enum import Enum SHORTOPT = "__argclass__shortopt" CHOICES = "__argclass__choices" @@ -19,26 +20,25 @@ def _make_gnu_option(name: str): def _decide_default(field_: Field): - arg_cfg = {} + defaults_cfg = {} if field_.default != MISSING: - arg_cfg["default"] = field_.default + defaults_cfg["default"] = field_.default elif field_.default_factory != MISSING: - arg_cfg["default"] = field_.default_factory() + defaults_cfg["default"] = field_.default_factory() else: - arg_cfg["required"] = True - return arg_cfg + defaults_cfg["required"] = True + return defaults_cfg -def _get_choices(field_: Field): - arg_cfg = {} - try: - arg_cfg["choices"] = field_.metadata[CHOICES] - except KeyError: - pass - return arg_cfg +def _get_choices(field_: Field, base_type): + if issubclass(base_type, Enum): + return {"choices": [e.value for e in base_type]} + elif CHOICES in field_.metadata: + return {"choices": field_.metadata[CHOICES]} + return {} -def _compute_arg_names(name, field_): +def _compute_arg_names(name, field_) -> list[str]: names = [_make_gnu_option(name)] # long option if SHORTOPT in field_.metadata: if len(field_.metadata[SHORTOPT]) != 1: @@ -50,8 +50,8 @@ def _compute_arg_names(name, field_): def _prepare_bool(ap: ArgumentParser, name, field_): - arg_cfg = _decide_default(field_) - required = arg_cfg.get("required", False) + defaults_cfg = _decide_default(field_) + required = defaults_cfg.get("required", False) bool_parser = ap.add_mutually_exclusive_group(required=required) bool_parser.add_argument( _make_gnu_option(name), action="store_true", dest=name @@ -60,34 +60,39 @@ def _prepare_bool(ap: ArgumentParser, name, field_): _make_gnu_option(f"no_{name}"), action="store_false", dest=name ) if not required: - ap.set_defaults(**{name: arg_cfg["default"]}) + ap.set_defaults(**{name: defaults_cfg["default"]}) def _prepare_list_cfg(name: str, field_: Field): - arg_cfg = { - **_decide_default(field_), - **_get_choices(field_), - } + defaults_cfg = _decide_default(field_) + subtype = typing.get_args(field_.type) if not subtype: raise ArgclassError( - f"List field {name} must have a subtype. Did you mean {name}: list[str]?" + f"List field {name} must have a subtype. " + "Did you mean {name}: list[str]?" ) - arg_cfg["type"] = subtype[0] + choices_cfg = _get_choices(field_, subtype[0]) if field_.metadata.get(ALLOW_EMPTY, False): - arg_cfg["nargs"] = "*" + nargs = "*" else: - arg_cfg["nargs"] = "+" - return arg_cfg + nargs = "+" + + return { + **defaults_cfg, + **choices_cfg, + "nargs": nargs, + "type": subtype[0], + } -def _prepare_trivial_cfg(name, field_): +def _prepare_trivial_cfg(_, field_: Field): arg_cfg = { **_decide_default(field_), - **_get_choices(field_), + **_get_choices(field_, field_.type), + "type": field_.type, } - arg_cfg["type"] = field_.type return arg_cfg