Refactor the parser and remove caveman's testing
This commit is contained in:
37
argclass.py
37
argclass.py
@@ -1,6 +1,6 @@
|
||||
import typing
|
||||
from argparse import ArgumentParser
|
||||
from dataclasses import dataclass, field, MISSING
|
||||
from dataclasses import dataclass, MISSING
|
||||
|
||||
|
||||
def make_gnu_option(name):
|
||||
@@ -55,7 +55,12 @@ def _prepare_list_cfg(name, field_):
|
||||
**decide_default(field_),
|
||||
**get_choices(field_),
|
||||
}
|
||||
arg_cfg['type'] = typing.get_args(field_.type)[0]
|
||||
subtype = typing.get_args(field_.type)
|
||||
if not subtype:
|
||||
arg_cfg['type'] = str
|
||||
else:
|
||||
arg_cfg['type'] = subtype[0]
|
||||
|
||||
if field_.metadata.get('allow_empty', False):
|
||||
arg_cfg['nargs'] = '*'
|
||||
else:
|
||||
@@ -87,7 +92,7 @@ def _prepare_trivial(ap: ArgumentParser, name, field_):
|
||||
def prepare_field(ap, name, field_):
|
||||
if field_.type is bool:
|
||||
_prepare_bool(ap, name, field_)
|
||||
elif typing.get_origin(field_.type) is list:
|
||||
elif field_.type is list or typing.get_origin(field_.type) is list:
|
||||
_prepare_list(ap, name, field_)
|
||||
else:
|
||||
_prepare_trivial(ap, name, field_)
|
||||
@@ -95,29 +100,13 @@ def prepare_field(ap, name, field_):
|
||||
|
||||
def argclass(cls):
|
||||
|
||||
class ArgClass(dataclass(cls)):
|
||||
|
||||
@classmethod
|
||||
def parse_args(cls):
|
||||
def parse_args(cls, argv):
|
||||
ap = ArgumentParser()
|
||||
for name, field_ in cls.__dataclass_fields__.items():
|
||||
prepare_field(ap, name, field_)
|
||||
return ap.parse_args()
|
||||
return cls(**vars(ap.parse_args(argv)))
|
||||
|
||||
return ArgClass
|
||||
|
||||
|
||||
@argclass
|
||||
class A:
|
||||
required_int_arg: int
|
||||
required_bool_arg: bool
|
||||
required_int_arg_with_choices: int = field(metadata={'choices': [3, 4, 5]})
|
||||
arg_with_default: str = 'abc'
|
||||
list_arg_with_default: list[int] = field(default_factory=lambda: [1, 2])
|
||||
list_arg_with_default_mb_empty: list[int] = field(
|
||||
default_factory=lambda: [], metadata={'allow_empty': True}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(A.parse_args())
|
||||
cls = dataclass(cls)
|
||||
cls.parse_args = parse_args
|
||||
return cls
|
||||
|
||||
Reference in New Issue
Block a user