Files
argclass/argclass.py
Pavel Lutskov f93588f4bf
All checks were successful
continuous-integration/drone/push Build is passing
Add black to CI/CD (#1)
Reviewed-on: https://git.deguo.duckdns.org/pavel/argclass/pulls/1
Co-authored-by: Pavel Lutskov <pavel.lutskov@gmail.com>
Co-committed-by: Pavel Lutskov <pavel.lutskov@gmail.com>
2021-09-12 12:09:57 +02:00

112 lines
2.8 KiB
Python

import typing
from argparse import ArgumentParser
from dataclasses import MISSING, dataclass, fields
def make_gnu_option(name):
    """Return the long command-line option for *name*.

    Underscores in *name* are replaced with dashes and the result is
    prefixed with ``--`` (e.g. ``"dry_run"`` becomes ``"--dry-run"``).
    """
    dashed = name.replace("_", "-")
    return "--" + dashed
def decide_default(field_):
    """Build the default-related ``add_argument`` keywords for a field.

    Returns ``{"default": value}`` when the dataclass field declares a
    default (directly or through a ``default_factory``, which is called
    here), and ``{"required": True}`` when it declares neither.
    """
    # MISSING is a sentinel object, so it must be compared by identity.
    # Using ``!=`` would invoke the default value's own ``__eq__``/``__ne__``,
    # which can misreport a real default as missing or raise for values
    # whose comparison does not yield a plain bool.
    if field_.default is not MISSING:
        return {"default": field_.default}
    if field_.default_factory is not MISSING:
        return {"default": field_.default_factory()}
    return {"required": True}
def get_choices(field_):
    """Return the ``choices`` keyword for ``add_argument``, if any.

    The value is read from the ``"choices"`` entry of the field metadata;
    an empty dict is returned when the metadata has no such entry.
    """
    try:
        allowed = field_.metadata["choices"]
    except KeyError:
        return {}
    return {"choices": allowed}
def compute_arg_names(name, field_):
    """Return the option strings under which the field is exposed.

    The list always starts with the long option derived from *name*; a
    short option ``-<shortopt>`` follows when the field metadata contains
    a ``"shortopt"`` entry.
    """
    long_opt = make_gnu_option(name)
    try:
        short_opt = f'-{field_.metadata["shortopt"]}'
    except KeyError:
        # No short option configured: only the long form is available.
        return [long_opt]
    return [long_opt, short_opt]
def _prepare_bool(ap: ArgumentParser, name, field_):
    """Register a ``--name`` / ``--no-name`` flag pair for a bool field.

    Both flags live in one mutually exclusive group and write to the same
    destination *name*. The group is required when the field has no
    default; otherwise the field default is installed as the parser-level
    default for *name*.
    """
    defaults = decide_default(field_)
    is_required = "required" in defaults
    group = ap.add_mutually_exclusive_group(required=is_required)

    # Positive flag stores True, negated flag stores False, same dest.
    flag_actions = ((name, "store_true"), (f"no_{name}", "store_false"))
    for flag, action in flag_actions:
        group.add_argument(make_gnu_option(flag), action=action, dest=name)

    if is_required:
        return
    ap.set_defaults(**{name: defaults["default"]})
def _prepare_list_cfg(name, field_):
    """Build the ``add_argument`` keywords for a list-typed field.

    Combines the default/required and choices settings with:

    * ``type``: the first type argument of the annotation, or ``str``
      when the annotation has no type arguments;
    * ``nargs``: ``"*"`` when the metadata entry ``"allow_empty"`` is
      truthy, ``"+"`` otherwise.

    *name* is accepted but not used by this function.
    """
    cfg = dict(decide_default(field_))
    cfg.update(get_choices(field_))

    type_args = typing.get_args(field_.type)
    cfg["type"] = type_args[0] if type_args else str

    allow_empty = field_.metadata.get("allow_empty", False)
    cfg["nargs"] = "*" if allow_empty else "+"
    return cfg
def _prepare_trivial_cfg(name, field_):
    """Build the ``add_argument`` keywords for a plain (scalar) field.

    The result merges the default/required and choices settings and uses
    the field annotation itself as the ``type`` converter. *name* is
    accepted but not used by this function.
    """
    return {
        **decide_default(field_),
        **get_choices(field_),
        "type": field_.type,
    }
def _prepare_list(ap: ArgumentParser, name, field_):
    """Add the argument for a list-typed field to the parser *ap*."""
    options = _prepare_list_cfg(name, field_)
    flags = compute_arg_names(name, field_)
    ap.add_argument(*flags, **options)
def _prepare_trivial(ap: ArgumentParser, name, field_):
    """Add the argument for a plain (non-bool, non-list) field to *ap*."""
    options = _prepare_trivial_cfg(name, field_)
    flags = compute_arg_names(name, field_)
    ap.add_argument(*flags, **options)
def prepare_field(ap, name, field_):
    """Register the dataclass field *field_* on the parser *ap*.

    Dispatches on the field annotation: ``bool`` fields get a flag pair,
    ``list`` (bare or parameterized) fields get a multi-value option, and
    everything else is treated as a single-value option.
    """
    annotation = field_.type
    if annotation is bool:
        handler = _prepare_bool
    elif annotation is list or typing.get_origin(annotation) is list:
        handler = _prepare_list
    else:
        handler = _prepare_trivial
    handler(ap, name, field_)
def argclass(cls):
    """Class decorator: make *cls* a dataclass that can parse its own CLI.

    The class is passed through ``dataclass`` and gains a ``parse_args``
    classmethod that builds an ``ArgumentParser`` from the dataclass
    fields, parses the given arguments, and returns an instance of the
    class populated from the parsed values.
    """

    @classmethod
    def parse_args(cls, argv=None):
        """Parse *argv* and return an instance of the decorated class.

        When *argv* is ``None`` it is forwarded as-is to
        ``ArgumentParser.parse_args``, which then reads ``sys.argv[1:]``.
        """
        ap = ArgumentParser()
        # ``fields()`` yields real instance fields only; the raw
        # ``__dataclass_fields__`` mapping also contains ClassVar/InitVar
        # pseudo-fields. Fields declared with ``init=False`` are skipped
        # too: the generated ``__init__`` does not accept them, so exposing
        # them as options would make the constructor call below fail.
        for field_ in fields(cls):
            if field_.init:
                prepare_field(ap, field_.name, field_)
        return cls(**vars(ap.parse_args(argv)))

    cls = dataclass(cls)
    cls.parse_args = parse_args
    return cls