initial commit

This commit is contained in:
2021-02-02 20:34:54 +01:00
commit 1c5e34d32c

123
argclass.py Normal file
View File

@@ -0,0 +1,123 @@
import typing
from argparse import ArgumentParser
from dataclasses import dataclass, field, MISSING
def make_gnu_option(name):
return f'--{name.replace("_", "-")}'
def decide_default(field_):
arg_cfg = {}
if field_.default != MISSING:
arg_cfg['default'] = field_.default
elif field_.default_factory != MISSING:
arg_cfg['default'] = field_.default_factory()
else:
arg_cfg['required'] = True
return arg_cfg
def get_choices(field_):
arg_cfg = {}
try:
arg_cfg['choices'] = field_.metadata['choices']
except KeyError:
pass
return arg_cfg
def compute_arg_names(name, field_):
names = [make_gnu_option(name)] # long option
try:
names.append(f'-{field_.metadata["shortopt"]}')
except KeyError:
pass
return names
def _prepare_bool(ap: ArgumentParser, name, field_):
arg_cfg = decide_default(field_)
required = 'required' in arg_cfg
bool_parser = ap.add_mutually_exclusive_group(required=required)
bool_parser.add_argument(make_gnu_option(name),
action='store_true',
dest=name)
bool_parser.add_argument(make_gnu_option(f'no_{name}'),
action='store_false',
dest=name)
if not required:
ap.set_defaults(**{name: arg_cfg['default']})
def _prepare_list_cfg(name, field_):
arg_cfg = {
**decide_default(field_),
**get_choices(field_),
}
arg_cfg['type'] = typing.get_args(field_.type)[0]
if field_.metadata.get('allow_empty', False):
arg_cfg['nargs'] = '*'
else:
arg_cfg['nargs'] = '+'
return arg_cfg
def _prepare_trivial_cfg(name, field_):
arg_cfg = {
**decide_default(field_),
**get_choices(field_),
}
arg_cfg['type'] = field_.type
return arg_cfg
def _prepare_list(ap: ArgumentParser, name, field_):
arg_cfg = _prepare_list_cfg(name, field_)
arg_names = compute_arg_names(name, field_)
ap.add_argument(*arg_names, **arg_cfg)
def _prepare_trivial(ap: ArgumentParser, name, field_):
arg_cfg = _prepare_trivial_cfg(name, field_)
arg_names = compute_arg_names(name, field_)
ap.add_argument(*arg_names, **arg_cfg)
def prepare_field(ap, name, field_):
if field_.type is bool:
_prepare_bool(ap, name, field_)
elif typing.get_origin(field_.type) is list:
_prepare_list(ap, name, field_)
else:
_prepare_trivial(ap, name, field_)
def argclass(cls):
class ArgClass(dataclass(cls)):
@classmethod
def parse_args(cls):
ap = ArgumentParser()
for name, field_ in cls.__dataclass_fields__.items():
prepare_field(ap, name, field_)
return ap.parse_args()
return ArgClass
@argclass
class A:
required_int_arg: int
required_bool_arg: bool
required_int_arg_with_choices: int = field(metadata={'choices': [3, 4, 5]})
arg_with_default: str = 'abc'
list_arg_with_default: list[int] = field(default_factory=lambda: [1, 2])
list_arg_with_default_mb_empty: list[int] = field(
default_factory=lambda: [], metadata={'allow_empty': True}
)
if __name__ == '__main__':
print(A.parse_args())