"""Build ``argparse`` command lines from dataclasses.

Decorating a class with :func:`argclass` turns it into a dataclass and gives
it a ``parse_args`` classmethod.  Each field accepted by ``__init__`` becomes a
GNU-style long option (``snake_case`` -> ``--snake-case``).  Per-field
behaviour is tuned through ``dataclasses.field(metadata=...)`` keys:

* ``"choices"``     -- forwarded to ``add_argument(choices=...)``
* ``"shortopt"``    -- single-dash alias, e.g. ``"v"`` -> ``-v``
* ``"allow_empty"`` -- list fields only; accept zero values (``nargs="*"``)
"""
import typing
from argparse import ArgumentParser
from dataclasses import dataclass, fields, MISSING


def make_gnu_option(name):
    """Return the long option for a field name: ``foo_bar`` -> ``--foo-bar``."""
    return f'--{name.replace("_", "-")}'


def decide_default(field_):
    """Return the ``add_argument`` kwargs describing *field_*'s default.

    A plain ``default`` or the result of ``default_factory`` becomes
    ``{"default": ...}``; a field with neither yields ``{"required": True}``.
    """
    arg_cfg = {}
    # Compare with the MISSING sentinel by identity: ``!=`` would call the
    # default value's own ``__ne__``, which may be elementwise or may raise.
    if field_.default is not MISSING:
        arg_cfg["default"] = field_.default
    elif field_.default_factory is not MISSING:
        # The factory is invoked once, at parser-construction time.
        arg_cfg["default"] = field_.default_factory()
    else:
        arg_cfg["required"] = True
    return arg_cfg


def get_choices(field_):
    """Return ``{"choices": ...}`` from the field metadata, or ``{}`` if unset."""
    arg_cfg = {}
    try:
        arg_cfg["choices"] = field_.metadata["choices"]
    except KeyError:
        pass
    return arg_cfg


def compute_arg_names(name, field_):
    """Return the option strings for a field: the long option, then any short one."""
    names = [make_gnu_option(name)]  # long option first, so argparse derives dest from it
    try:
        names.append(f'-{field_.metadata["shortopt"]}')
    except KeyError:
        pass
    return names


def _prepare_bool(ap: ArgumentParser, name, field_):
    """Register a ``--name`` / ``--no-name`` flag pair for a ``bool`` field.

    The two flags are mutually exclusive and both write to ``dest=name``.  If
    the field has no default, one of them must be supplied; otherwise the
    field's default is installed via ``set_defaults``.  A ``"shortopt"`` in the
    metadata is attached to the positive ``--name`` flag.
    """
    arg_cfg = decide_default(field_)
    required = "required" in arg_cfg
    bool_parser = ap.add_mutually_exclusive_group(required=required)
    bool_parser.add_argument(
        *compute_arg_names(name, field_), action="store_true", dest=name
    )
    bool_parser.add_argument(
        make_gnu_option(f"no_{name}"), action="store_false", dest=name
    )
    if not required:
        # store_true / store_false each carry their own implicit default;
        # set_defaults overrides both with the field's real default.
        ap.set_defaults(**{name: arg_cfg["default"]})


def _prepare_list_cfg(name, field_):
    """Return ``add_argument`` kwargs for a ``list`` / ``list[T]`` field.

    The element type is ``T`` when the annotation is parameterised, else
    ``str``.  ``nargs`` is ``"*"`` when metadata ``"allow_empty"`` is truthy,
    otherwise ``"+"``.  *name* is unused; it mirrors the other helpers.
    """
    arg_cfg = {
        **decide_default(field_),
        **get_choices(field_),
    }
    subtype = typing.get_args(field_.type)
    arg_cfg["type"] = subtype[0] if subtype else str
    arg_cfg["nargs"] = "*" if field_.metadata.get("allow_empty", False) else "+"
    return arg_cfg


def _prepare_trivial_cfg(name, field_):
    """Return ``add_argument`` kwargs for a scalar field.

    The annotation itself is used as the argparse ``type`` converter.
    *name* is unused; it mirrors the other helpers.
    """
    arg_cfg = {
        **decide_default(field_),
        **get_choices(field_),
    }
    arg_cfg["type"] = field_.type
    return arg_cfg


def _prepare_list(ap: ArgumentParser, name, field_):
    """Add a multi-value option for a list field to *ap*."""
    arg_cfg = _prepare_list_cfg(name, field_)
    arg_names = compute_arg_names(name, field_)
    ap.add_argument(*arg_names, **arg_cfg)


def _prepare_trivial(ap: ArgumentParser, name, field_):
    """Add a single-value option for a scalar field to *ap*."""
    arg_cfg = _prepare_trivial_cfg(name, field_)
    arg_names = compute_arg_names(name, field_)
    ap.add_argument(*arg_names, **arg_cfg)


def prepare_field(ap, name, field_):
    """Dispatch *field_* to the bool, list or scalar option builder by its type."""
    if field_.type is bool:
        _prepare_bool(ap, name, field_)
    elif field_.type is list or typing.get_origin(field_.type) is list:
        _prepare_list(ap, name, field_)
    else:
        _prepare_trivial(ap, name, field_)


def argclass(cls):
    """Class decorator: make *cls* a dataclass with a ``parse_args`` classmethod.

    ``cls.parse_args(argv=None)`` builds a fresh parser from the class's
    fields, parses *argv* (``sys.argv[1:]`` when ``None``, as argparse does)
    and returns a new instance.  Only real fields accepted by ``__init__``
    become options: ``ClassVar`` pseudo-fields and ``field(init=False)``
    entries are skipped, since registering them would break parsing or
    construction.  ``InitVar`` pseudo-fields are not supported.
    """

    @classmethod
    def parse_args(cls, argv=None):
        ap = ArgumentParser()
        # fields() omits ClassVar/InitVar pseudo-fields; init=False fields
        # cannot be passed to the constructor, so they get no option either.
        for field_ in fields(cls):
            if field_.init:
                prepare_field(ap, field_.name, field_)
        return cls(**vars(ap.parse_args(argv)))

    cls = dataclass(cls)
    cls.parse_args = parse_args
    return cls