From c03fe2f2840647ece78c5fed02744d5af9ea35e2 Mon Sep 17 00:00:00 2001 From: Pavel Lutskov Date: Sat, 27 Mar 2021 15:42:20 +0100 Subject: [PATCH] Refactor the parser and remove caveman's testing --- argclass.py | 45 +++++++++++++++++---------------------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/argclass.py b/argclass.py index 61b3b16..55a8198 100644 --- a/argclass.py +++ b/argclass.py @@ -1,6 +1,6 @@ import typing from argparse import ArgumentParser -from dataclasses import dataclass, field, MISSING +from dataclasses import dataclass, MISSING def make_gnu_option(name): @@ -55,7 +55,12 @@ def _prepare_list_cfg(name, field_): **decide_default(field_), **get_choices(field_), } - arg_cfg['type'] = typing.get_args(field_.type)[0] + subtype = typing.get_args(field_.type) + if not subtype: + arg_cfg['type'] = str + else: + arg_cfg['type'] = subtype[0] + if field_.metadata.get('allow_empty', False): arg_cfg['nargs'] = '*' else: @@ -87,7 +92,7 @@ def _prepare_trivial(ap: ArgumentParser, name, field_): def prepare_field(ap, name, field_): if field_.type is bool: _prepare_bool(ap, name, field_) - elif typing.get_origin(field_.type) is list: + elif field_.type is list or typing.get_origin(field_.type) is list: _prepare_list(ap, name, field_) else: _prepare_trivial(ap, name, field_) @@ -95,29 +100,13 @@ def prepare_field(ap, name, field_): def argclass(cls): - class ArgClass(dataclass(cls)): + @classmethod + def parse_args(cls, argv): + ap = ArgumentParser() + for name, field_ in cls.__dataclass_fields__.items(): + prepare_field(ap, name, field_) + return cls(**vars(ap.parse_args(argv))) - @classmethod - def parse_args(cls): - ap = ArgumentParser() - for name, field_ in cls.__dataclass_fields__.items(): - prepare_field(ap, name, field_) - return ap.parse_args() - - return ArgClass - - -@argclass -class A: - required_int_arg: int - required_bool_arg: bool - required_int_arg_with_choices: int = field(metadata={'choices': [3, 4, 5]}) - arg_with_default: str = 'abc' - list_arg_with_default: list[int] = field(default_factory=lambda: [1, 2]) - list_arg_with_default_mb_empty: list[int] = field( - default_factory=lambda: [], metadata={'allow_empty': True} - ) - - -if __name__ == '__main__': - print(A.parse_args()) + cls = dataclass(cls) + cls.parse_args = parse_args + return cls