All checks were successful
continuous-integration/drone/push Build is passing
Reviewed-on: https://git.deguo.duckdns.org/pavel/argclass/pulls/1 Co-authored-by: Pavel Lutskov <pavel.lutskov@gmail.com> Co-committed-by: Pavel Lutskov <pavel.lutskov@gmail.com>
112 lines
2.8 KiB
Python
112 lines
2.8 KiB
Python
import typing
|
|
from argparse import ArgumentParser
|
|
from dataclasses import dataclass, MISSING
|
|
|
|
|
|
def make_gnu_option(name):
    """Turn an identifier into a GNU-style long option.

    Underscores are replaced by dashes and the result is prefixed with
    ``--``, e.g. ``"dry_run"`` becomes ``"--dry-run"``.
    """
    dashed = "-".join(name.split("_"))
    return "--" + dashed
|
|
|
|
|
|
def decide_default(field_):
    """Derive the argparse default/required settings for a dataclass field.

    Args:
        field_: An object exposing ``default`` and ``default_factory``
            attributes, such as a ``dataclasses.Field``.

    Returns:
        ``{"default": value}`` when the field has a default value or a
        default factory (the factory is called once here), otherwise
        ``{"required": True}``.
    """
    # MISSING is a sentinel, so it is compared by identity. An equality
    # test would invoke the default's own __eq__/__ne__, which can lie
    # (always-equal objects) or fail to yield a bool (array-like values).
    if field_.default is not MISSING:
        return {"default": field_.default}
    if field_.default_factory is not MISSING:
        return {"default": field_.default_factory()}
    return {"required": True}
|
|
|
|
|
|
def get_choices(field_):
    """Return the argparse ``choices`` setting declared in field metadata.

    Yields ``{"choices": ...}`` when the field's metadata has a
    ``"choices"`` key, and an empty dict otherwise.
    """
    metadata = field_.metadata
    if "choices" in metadata:
        return {"choices": metadata["choices"]}
    return {}
|
|
|
|
|
|
def compute_arg_names(name, field_):
    """Build the list of option strings for a field.

    The long option derived from ``name`` always comes first; a short
    option is added when the field metadata contains ``"shortopt"``.
    """
    long_option = make_gnu_option(name)
    if "shortopt" not in field_.metadata:
        return [long_option]
    return [long_option, f'-{field_.metadata["shortopt"]}']
|
|
|
|
|
|
def _prepare_bool(ap: ArgumentParser, name, field_):
    """Register a ``--flag`` / ``--no-flag`` pair for a boolean field.

    Both options live in one mutually exclusive group and write to the
    same destination. The group is required when the field has no
    default; otherwise the field's default is installed on the parser.
    """
    default_cfg = decide_default(field_)
    is_required = "required" in default_cfg
    group = ap.add_mutually_exclusive_group(required=is_required)

    # Positive flag stores True, negated flag stores False.
    for flag, action in ((name, "store_true"), (f"no_{name}", "store_false")):
        group.add_argument(make_gnu_option(flag), action=action, dest=name)

    if is_required:
        return
    ap.set_defaults(**{name: default_cfg["default"]})
|
|
|
|
|
|
def _prepare_list_cfg(name, field_):
    """Assemble ``add_argument`` keyword arguments for a list field.

    The element type is the first type argument of the annotation
    (``str`` when the annotation has none). ``nargs`` is ``"*"`` when the
    metadata sets ``allow_empty`` to a truthy value, otherwise ``"+"``.
    ``name`` is accepted but not used here.
    """
    cfg = dict(decide_default(field_))
    cfg.update(get_choices(field_))

    item_types = typing.get_args(field_.type)
    cfg["type"] = item_types[0] if item_types else str

    allow_empty = field_.metadata.get("allow_empty", False)
    cfg["nargs"] = "*" if allow_empty else "+"
    return cfg
|
|
|
|
|
|
def _prepare_trivial_cfg(name, field_):
    """Assemble ``add_argument`` keyword arguments for a scalar field.

    Combines the default/required setting and any declared choices, and
    uses the field's annotation as the argparse ``type`` converter.
    ``name`` is accepted but not used here.
    """
    return {
        **decide_default(field_),
        **get_choices(field_),
        "type": field_.type,
    }
|
|
|
|
|
|
def _prepare_list(ap: ArgumentParser, name, field_):
    """Add a multi-value option for a list-typed field to the parser."""
    options = _prepare_list_cfg(name, field_)
    flags = compute_arg_names(name, field_)
    ap.add_argument(*flags, **options)
|
|
|
|
|
|
def _prepare_trivial(ap: ArgumentParser, name, field_):
    """Add a single-value option for a scalar field to the parser."""
    options = _prepare_trivial_cfg(name, field_)
    flags = compute_arg_names(name, field_)
    ap.add_argument(*flags, **options)
|
|
|
|
|
|
def prepare_field(ap, name, field_):
    """Register the parser argument(s) for one field, chosen by its type.

    ``bool`` annotations get a flag pair, ``list`` annotations (bare or
    parameterized) get a multi-value option, and everything else is
    treated as a single-value option.
    """
    annotation = field_.type
    if annotation is bool:
        handler = _prepare_bool
    elif annotation is list or typing.get_origin(annotation) is list:
        handler = _prepare_list
    else:
        handler = _prepare_trivial
    handler(ap, name, field_)
|
|
|
|
|
|
def argclass(cls):
    """Class decorator: make ``cls`` a dataclass with a ``parse_args`` classmethod.

    Args:
        cls: The class to convert; it is passed through ``dataclass``.

    Returns:
        The dataclass, with ``parse_args`` attached.
    """
    # Imported locally because the module-level import only brings in
    # ``dataclass`` and ``MISSING``.
    from dataclasses import fields

    @classmethod
    def parse_args(cls, argv=None):
        """Build a parser from the dataclass fields and parse ``argv``.

        Args:
            argv: Argument strings to parse. ``None`` (the default) is
                passed through to ``ArgumentParser.parse_args``, which
                then reads ``sys.argv[1:]``.

        Returns:
            An instance of the class built from the parsed values.
        """
        ap = ArgumentParser()
        # fields() leaves out the ClassVar/InitVar pseudo-fields that
        # __dataclass_fields__ contains. Fields declared with init=False
        # are skipped as well, because the generated __init__ would
        # reject them as keyword arguments.
        for field_ in fields(cls):
            if field_.init:
                prepare_field(ap, field_.name, field_)
        return cls(**vars(ap.parse_args(argv)))

    cls = dataclass(cls)
    cls.parse_args = parse_args
    return cls
|