import typing
from argparse import ArgumentParser
from dataclasses import MISSING, Field, dataclass, fields


def make_gnu_option(name: str) -> str:
    """Return the GNU-style long option for a field name.

    Underscores become dashes, e.g. ``"dry_run"`` -> ``"--dry-run"``.
    """
    return f'--{name.replace("_", "-")}'


def decide_default(field_: Field) -> typing.Dict[str, typing.Any]:
    """Return the ``add_argument`` kwargs describing a field's default.

    A field with ``default`` yields ``{'default': default}``; a field with
    ``default_factory`` yields ``{'default': default_factory()}`` (the factory
    is called here, each time this function runs); a field with neither
    yields ``{'required': True}``.
    """
    # Compare against the MISSING sentinel by identity: ``!=`` would invoke
    # the default value's own ``__eq__``/``__ne__``, which may be overloaded
    # to return non-bools or to raise.
    if field_.default is not MISSING:
        return {'default': field_.default}
    if field_.default_factory is not MISSING:
        return {'default': field_.default_factory()}
    return {'required': True}


def get_choices(field_: Field) -> typing.Dict[str, typing.Any]:
    """Return ``{'choices': ...}`` when the field metadata has a ``choices``
    entry, otherwise an empty dict."""
    if 'choices' in field_.metadata:
        return {'choices': field_.metadata['choices']}
    return {}


def compute_arg_names(name: str, field_: Field) -> typing.List[str]:
    """Return the option strings for a field.

    Always includes the long option from ``make_gnu_option``; appends ``-X``
    when the field metadata has a ``shortopt`` entry ``X``.
    """
    names = [make_gnu_option(name)]  # long option
    if 'shortopt' in field_.metadata:
        names.append(f'-{field_.metadata["shortopt"]}')
    return names


def _prepare_bool(ap: ArgumentParser, name: str, field_: Field) -> None:
    """Register a ``--name`` / ``--no-name`` flag pair for a ``bool`` field.

    Both flags are in one mutually exclusive group and write to ``name``.
    When the field has no default the group is required; otherwise the
    default is installed via ``set_defaults``. A ``shortopt`` in the
    metadata becomes an alias of the positive flag.
    """
    arg_cfg = decide_default(field_)
    required = 'required' in arg_cfg
    bool_parser = ap.add_mutually_exclusive_group(required=required)
    # Same naming helper as non-bool fields, so ``shortopt`` is honored here too.
    bool_parser.add_argument(
        *compute_arg_names(name, field_), action='store_true', dest=name
    )
    bool_parser.add_argument(
        make_gnu_option(f'no_{name}'), action='store_false', dest=name
    )
    if not required:
        # set_defaults also overrides the store_true/store_false implicit
        # defaults of both actions sharing ``dest=name``.
        ap.set_defaults(**{name: arg_cfg['default']})


def _prepare_list_cfg(name: str, field_: Field) -> typing.Dict[str, typing.Any]:
    """Build ``add_argument`` kwargs for a ``list`` / ``list[T]`` field.

    The element type is ``T`` when the annotation is parameterized, else
    ``str``. ``nargs`` is ``'*'`` when metadata sets a truthy
    ``allow_empty``, otherwise ``'+'``. ``name`` is unused; it is kept for
    signature symmetry with the other builders.
    """
    arg_cfg = {
        **decide_default(field_),
        **get_choices(field_),
    }
    subtype = typing.get_args(field_.type)
    arg_cfg['type'] = subtype[0] if subtype else str
    arg_cfg['nargs'] = '*' if field_.metadata.get('allow_empty', False) else '+'
    return arg_cfg


def _prepare_trivial_cfg(name: str, field_: Field) -> typing.Dict[str, typing.Any]:
    """Build ``add_argument`` kwargs for a scalar field.

    The annotated type itself is used as the argparse ``type`` converter.
    ``name`` is unused; it is kept for signature symmetry.
    """
    arg_cfg = {
        **decide_default(field_),
        **get_choices(field_),
    }
    arg_cfg['type'] = field_.type
    return arg_cfg


def _prepare_list(ap: ArgumentParser, name: str, field_: Field) -> None:
    """Register a list-valued option for ``field_`` on ``ap``."""
    arg_cfg = _prepare_list_cfg(name, field_)
    arg_names = compute_arg_names(name, field_)
    # Explicit dest: argparse would otherwise strip leading dashes from the
    # long option, which loses leading underscores (``_x`` -> ``---x`` -> ``x``)
    # and would then not match the dataclass __init__ parameter.
    ap.add_argument(*arg_names, dest=name, **arg_cfg)


def _prepare_trivial(ap: ArgumentParser, name: str, field_: Field) -> None:
    """Register a scalar option for ``field_`` on ``ap``."""
    arg_cfg = _prepare_trivial_cfg(name, field_)
    arg_names = compute_arg_names(name, field_)
    # Explicit dest so the parsed attribute name always equals the field name.
    ap.add_argument(*arg_names, dest=name, **arg_cfg)


def prepare_field(ap: ArgumentParser, name: str, field_: Field) -> None:
    """Dispatch ``field_`` to the bool, list, or scalar option builder."""
    if field_.type is bool:
        _prepare_bool(ap, name, field_)
    elif field_.type is list or typing.get_origin(field_.type) is list:
        _prepare_list(ap, name, field_)
    else:
        _prepare_trivial(ap, name, field_)


def argclass(cls):
    """Turn ``cls`` into a dataclass and attach a ``parse_args`` classmethod.

    ``cls.parse_args(argv=None)`` builds a fresh ``ArgumentParser`` with one
    option per ``__init__``-accepted dataclass field, parses ``argv`` (argparse
    falls back to ``sys.argv`` when it is ``None``) and returns
    ``cls(**parsed_values)``.
    """
    @classmethod
    def parse_args(cls, argv=None):
        ap = ArgumentParser()
        # dataclasses.fields() omits ClassVar and InitVar pseudo-fields, and
        # init=False fields are skipped because __init__ rejects them.
        for field_ in fields(cls):
            if field_.init:
                prepare_field(ap, field_.name, field_)
        return cls(**vars(ap.parse_args(argv)))

    cls = dataclass(cls)
    cls.parse_args = parse_args
    return cls