initial commit
This commit is contained in:
123
argclass.py
Normal file
123
argclass.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import typing
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from dataclasses import dataclass, field, MISSING
|
||||||
|
|
||||||
|
|
||||||
|
def make_gnu_option(name):
|
||||||
|
return f'--{name.replace("_", "-")}'
|
||||||
|
|
||||||
|
|
||||||
|
def decide_default(field_):
|
||||||
|
arg_cfg = {}
|
||||||
|
if field_.default != MISSING:
|
||||||
|
arg_cfg['default'] = field_.default
|
||||||
|
elif field_.default_factory != MISSING:
|
||||||
|
arg_cfg['default'] = field_.default_factory()
|
||||||
|
else:
|
||||||
|
arg_cfg['required'] = True
|
||||||
|
return arg_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_choices(field_):
|
||||||
|
arg_cfg = {}
|
||||||
|
try:
|
||||||
|
arg_cfg['choices'] = field_.metadata['choices']
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return arg_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def compute_arg_names(name, field_):
|
||||||
|
names = [make_gnu_option(name)] # long option
|
||||||
|
try:
|
||||||
|
names.append(f'-{field_.metadata["shortopt"]}')
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_bool(ap: ArgumentParser, name, field_):
|
||||||
|
arg_cfg = decide_default(field_)
|
||||||
|
required = 'required' in arg_cfg
|
||||||
|
bool_parser = ap.add_mutually_exclusive_group(required=required)
|
||||||
|
bool_parser.add_argument(make_gnu_option(name),
|
||||||
|
action='store_true',
|
||||||
|
dest=name)
|
||||||
|
bool_parser.add_argument(make_gnu_option(f'no_{name}'),
|
||||||
|
action='store_false',
|
||||||
|
dest=name)
|
||||||
|
if not required:
|
||||||
|
ap.set_defaults(**{name: arg_cfg['default']})
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_list_cfg(name, field_):
|
||||||
|
arg_cfg = {
|
||||||
|
**decide_default(field_),
|
||||||
|
**get_choices(field_),
|
||||||
|
}
|
||||||
|
arg_cfg['type'] = typing.get_args(field_.type)[0]
|
||||||
|
if field_.metadata.get('allow_empty', False):
|
||||||
|
arg_cfg['nargs'] = '*'
|
||||||
|
else:
|
||||||
|
arg_cfg['nargs'] = '+'
|
||||||
|
return arg_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_trivial_cfg(name, field_):
|
||||||
|
arg_cfg = {
|
||||||
|
**decide_default(field_),
|
||||||
|
**get_choices(field_),
|
||||||
|
}
|
||||||
|
arg_cfg['type'] = field_.type
|
||||||
|
return arg_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_list(ap: ArgumentParser, name, field_):
|
||||||
|
arg_cfg = _prepare_list_cfg(name, field_)
|
||||||
|
arg_names = compute_arg_names(name, field_)
|
||||||
|
ap.add_argument(*arg_names, **arg_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_trivial(ap: ArgumentParser, name, field_):
|
||||||
|
arg_cfg = _prepare_trivial_cfg(name, field_)
|
||||||
|
arg_names = compute_arg_names(name, field_)
|
||||||
|
ap.add_argument(*arg_names, **arg_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_field(ap, name, field_):
|
||||||
|
if field_.type is bool:
|
||||||
|
_prepare_bool(ap, name, field_)
|
||||||
|
elif typing.get_origin(field_.type) is list:
|
||||||
|
_prepare_list(ap, name, field_)
|
||||||
|
else:
|
||||||
|
_prepare_trivial(ap, name, field_)
|
||||||
|
|
||||||
|
|
||||||
|
def argclass(cls):
|
||||||
|
|
||||||
|
class ArgClass(dataclass(cls)):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_args(cls):
|
||||||
|
ap = ArgumentParser()
|
||||||
|
for name, field_ in cls.__dataclass_fields__.items():
|
||||||
|
prepare_field(ap, name, field_)
|
||||||
|
return ap.parse_args()
|
||||||
|
|
||||||
|
return ArgClass
|
||||||
|
|
||||||
|
|
||||||
|
@argclass
|
||||||
|
class A:
|
||||||
|
required_int_arg: int
|
||||||
|
required_bool_arg: bool
|
||||||
|
required_int_arg_with_choices: int = field(metadata={'choices': [3, 4, 5]})
|
||||||
|
arg_with_default: str = 'abc'
|
||||||
|
list_arg_with_default: list[int] = field(default_factory=lambda: [1, 2])
|
||||||
|
list_arg_with_default_mb_empty: list[int] = field(
|
||||||
|
default_factory=lambda: [], metadata={'allow_empty': True}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print(A.parse_args())
|
||||||
Reference in New Issue
Block a user