commit 1c5e34d32cbe774d50664fd6cd6d1ef764a3db55 Author: Pavel Lutskov Date: Tue Feb 2 20:34:54 2021 +0100 initial commit diff --git a/argclass.py b/argclass.py new file mode 100644 index 0000000..61b3b16 --- /dev/null +++ b/argclass.py @@ -0,0 +1,123 @@ +import typing +from argparse import ArgumentParser +from dataclasses import dataclass, field, MISSING + + +def make_gnu_option(name): + return f'--{name.replace("_", "-")}' + + +def decide_default(field_): + arg_cfg = {} + if field_.default != MISSING: + arg_cfg['default'] = field_.default + elif field_.default_factory != MISSING: + arg_cfg['default'] = field_.default_factory() + else: + arg_cfg['required'] = True + return arg_cfg + + +def get_choices(field_): + arg_cfg = {} + try: + arg_cfg['choices'] = field_.metadata['choices'] + except KeyError: + pass + return arg_cfg + + +def compute_arg_names(name, field_): + names = [make_gnu_option(name)] # long option + try: + names.append(f'-{field_.metadata["shortopt"]}') + except KeyError: + pass + return names + + +def _prepare_bool(ap: ArgumentParser, name, field_): + arg_cfg = decide_default(field_) + required = 'required' in arg_cfg + bool_parser = ap.add_mutually_exclusive_group(required=required) + bool_parser.add_argument(make_gnu_option(name), + action='store_true', + dest=name) + bool_parser.add_argument(make_gnu_option(f'no_{name}'), + action='store_false', + dest=name) + if not required: + ap.set_defaults(**{name: arg_cfg['default']}) + + +def _prepare_list_cfg(name, field_): + arg_cfg = { + **decide_default(field_), + **get_choices(field_), + } + arg_cfg['type'] = typing.get_args(field_.type)[0] + if field_.metadata.get('allow_empty', False): + arg_cfg['nargs'] = '*' + else: + arg_cfg['nargs'] = '+' + return arg_cfg + + +def _prepare_trivial_cfg(name, field_): + arg_cfg = { + **decide_default(field_), + **get_choices(field_), + } + arg_cfg['type'] = field_.type + return arg_cfg + + +def _prepare_list(ap: ArgumentParser, name, field_): + arg_cfg = _prepare_list_cfg(name, field_) + arg_names = compute_arg_names(name, field_) + ap.add_argument(*arg_names, **arg_cfg) + + +def _prepare_trivial(ap: ArgumentParser, name, field_): + arg_cfg = _prepare_trivial_cfg(name, field_) + arg_names = compute_arg_names(name, field_) + ap.add_argument(*arg_names, **arg_cfg) + + +def prepare_field(ap, name, field_): + if field_.type is bool: + _prepare_bool(ap, name, field_) + elif typing.get_origin(field_.type) is list: + _prepare_list(ap, name, field_) + else: + _prepare_trivial(ap, name, field_) + + +def argclass(cls): + + class ArgClass(dataclass(cls)): + + @classmethod + def parse_args(cls): + ap = ArgumentParser() + for name, field_ in cls.__dataclass_fields__.items(): + prepare_field(ap, name, field_) + return ap.parse_args() + + return ArgClass + + +@argclass +class A: + required_int_arg: int + required_bool_arg: bool + required_int_arg_with_choices: int = field(metadata={'choices': [3, 4, 5]}) + arg_with_default: str = 'abc' + list_arg_with_default: list[int] = field(default_factory=lambda: [1, 2]) + list_arg_with_default_mb_empty: list[int] = field( + default_factory=lambda: [], metadata={'allow_empty': True} + ) + + +if __name__ == '__main__': + print(A.parse_args())