Files
argclass/argclass.py

113 lines
2.9 KiB
Python

import typing
from argparse import ArgumentParser
from dataclasses import dataclass, MISSING
def make_gnu_option(name):
    """Convert a Python identifier into a GNU-style long option string.

    Every underscore becomes a dash and ``--`` is prepended, so
    ``dry_run`` becomes ``--dry-run``.
    """
    dashed = name.replace("_", "-")
    return f"--{dashed}"
def decide_default(field_):
    """Translate a dataclass field's default into ``add_argument`` kwargs.

    Returns ``{'default': value}`` when the field declares a ``default``.
    When it declares a ``default_factory`` instead, the factory is called
    once, at parser-construction time, and its result is used. Fields with
    neither yield ``{'required': True}``.
    """
    # Test the MISSING sentinel by identity. ``!=`` dispatches to the default
    # value's own __eq__/__ne__, which may be overridden (or return a
    # non-bool) and cause a real default to be mistaken for "no default".
    if field_.default is not MISSING:
        return {'default': field_.default}
    if field_.default_factory is not MISSING:
        return {'default': field_.default_factory()}
    return {'required': True}
def get_choices(field_):
    """Return ``{'choices': ...}`` when the field's metadata declares choices.

    A field with no ``'choices'`` metadata entry produces an empty dict. The
    result can therefore always be splatted into ``add_argument`` kwargs.
    """
    metadata = field_.metadata
    if 'choices' in metadata:
        return {'choices': metadata['choices']}
    return {}
def compute_arg_names(name, field_):
    """Build the option strings under which a field is registered.

    The GNU long option (``--some-name``) always comes first. A short option
    ``-X`` is appended when the field's metadata has a ``'shortopt'`` entry.
    """
    option_names = [make_gnu_option(name)]
    metadata = field_.metadata
    if 'shortopt' in metadata:
        option_names.append(f"-{metadata['shortopt']}")
    return option_names
def _prepare_bool(ap: ArgumentParser, name, field_):
    """Register a ``--name`` / ``--no-name`` flag pair for a bool field.

    Both flags write to the same ``dest`` and live in a mutually exclusive
    group. If the field has no default, the group is required, so exactly
    one of the flags must be given. Otherwise the field's default is
    installed with ``set_defaults``. A ``'shortopt'`` metadata entry, when
    present, is attached to the positive (``store_true``) flag. This matches
    how every other field kind honors it.
    """
    arg_cfg = decide_default(field_)
    required = 'required' in arg_cfg
    bool_group = ap.add_mutually_exclusive_group(required=required)
    # compute_arg_names returns the long option first, followed by ``-X``
    # when a short option is declared. Both map to "true".
    bool_group.add_argument(*compute_arg_names(name, field_),
                            action='store_true',
                            dest=name)
    bool_group.add_argument(make_gnu_option(f'no_{name}'),
                            action='store_false',
                            dest=name)
    if not required:
        # Parser-level defaults override the implicit False/True defaults of
        # the store_true/store_false actions sharing this dest.
        ap.set_defaults(**{name: arg_cfg['default']})
def _prepare_list_cfg(name, field_):
    """Compute ``add_argument`` kwargs for a ``list`` / ``list[T]`` field.

    Besides the default/required and choices entries, ``type`` is set to the
    annotation's first type argument, or ``str`` for a bare ``list``.
    ``nargs`` is ``'*'`` when the field's metadata has a truthy
    ``'allow_empty'``, and ``'+'`` otherwise. ``name`` is unused here.
    """
    base = {
        **decide_default(field_),
        **get_choices(field_),
    }
    type_args = typing.get_args(field_.type)
    element_type = type_args[0] if type_args else str
    nargs = '*' if field_.metadata.get('allow_empty', False) else '+'
    return {**base, 'type': element_type, 'nargs': nargs}
def _prepare_trivial_cfg(name, field_):
    """Compute ``add_argument`` kwargs for a single-valued field.

    The field's annotation is used directly as argparse's ``type``
    converter. It is combined with the default/required and choices
    entries. ``name`` is unused here.
    """
    return {
        **decide_default(field_),
        **get_choices(field_),
        'type': field_.type,
    }
def _prepare_list(ap: ArgumentParser, name, field_):
    """Register a multi-value option on ``ap`` for a list-typed field."""
    list_kwargs = _prepare_list_cfg(name, field_)
    ap.add_argument(*compute_arg_names(name, field_), **list_kwargs)
def _prepare_trivial(ap: ArgumentParser, name, field_):
    """Register a single-value option on ``ap`` for a scalar field."""
    scalar_kwargs = _prepare_trivial_cfg(name, field_)
    ap.add_argument(*compute_arg_names(name, field_), **scalar_kwargs)
def prepare_field(ap, name, field_):
    """Route a dataclass field to the matching argparse registration helper.

    Fields annotated exactly ``bool`` become a ``--x`` / ``--no-x`` flag
    pair. A bare ``list``, or any annotation whose origin is ``list``,
    becomes a multi-value option. Anything else becomes a single-value
    option.
    """
    annotation = field_.type
    if annotation is bool:
        handler = _prepare_bool
    elif annotation is list or typing.get_origin(annotation) is list:
        handler = _prepare_list
    else:
        handler = _prepare_trivial
    handler(ap, name, field_)
def argclass(cls):
    """Class decorator: turn ``cls`` into a dataclass with ``parse_args``.

    ``cls.parse_args(argv=None)`` builds an ``ArgumentParser`` with one
    option per ``__init__`` field of the dataclass. It parses ``argv`` and
    returns ``cls(**parsed_values)``. When ``argv`` is ``None``, argparse
    falls back to ``sys.argv[1:]``.
    """
    from dataclasses import fields

    @classmethod
    def parse_args(cls, argv=None):
        ap = ArgumentParser()
        # fields() omits ClassVar and InitVar pseudo-fields, which
        # __dataclass_fields__ still contains. Fields declared init=False
        # are skipped as well. Neither ClassVar nor init=False fields can be
        # passed to __init__, so registering options for them made
        # cls(**parsed) fail with TypeError.
        for field_ in fields(cls):
            if field_.init:
                prepare_field(ap, field_.name, field_)
        return cls(**vars(ap.parse_args(argv)))

    cls = dataclass(cls)
    cls.parse_args = parse_args
    return cls