Command-line argument parsing code adapted from the baselines library

User submission · 2022-09-03



As the title says, here is the code:

import argparse


def arg_parser():
    """
    Build the argument parser (options adapted from baselines) and parse sys.argv.
    Returns (namespace_of_known_args, list_of_unrecognized_args).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
    parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
    parser.add_argument('--seed', help='RNG seed', type=int, default=None)
    parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
    parser.add_argument('--num_timesteps', type=float, default=1e6)
    parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None)
    parser.add_argument('--gamestate', help='game state to load (so far only used in retro games)', default=None)
    parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
    parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float)
    parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str)
    parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
    parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
    parser.add_argument('--log_path', help='Directory to save learning curve data.', default=None, type=str)
    parser.add_argument('--play', default=False, action='store_true')
    # parse_known_args() does not fail on unrecognized options; it returns them separately
    return parser.parse_known_args()


def parse_unknown_args(args):
    """
    Parse arguments not consumed by the arg parser into a dictionary.
    Accepts both '--key=value' and '--key value'; values are kept as strings here.
    """
    retval = {}
    preceded_by_key = False
    for arg in args:
        if arg.startswith('--'):
            if '=' in arg:
                # split on the first '=' only, so values may themselves contain '='
                key, value = arg[2:].split('=', 1)
                retval[key] = value
                preceded_by_key = False
            else:
                key = arg[2:]
                preceded_by_key = True
        elif preceded_by_key:
            retval[key] = arg
            preceded_by_key = False
    return retval


def parse_cmdline_kwargs(args, unknown_args):
    """
    Add the leftover command-line arguments to args as attributes,
    evaluating each value as a Python expression when possible.
    """
    def parse(v):
        assert isinstance(v, str)
        try:
            # note: eval executes arbitrary code taken from the command line
            return eval(v)
        except (NameError, SyntaxError):
            # not a valid expression (e.g. a bare word like 'me'): keep the string
            return v

    args.__dict__.update({k: parse(v) for k, v in parse_unknown_args(unknown_args).items()})
    return args


if __name__ == '__main__':
    args, unknown_args = arg_parser()
    print(args)
    args = parse_cmdline_kwargs(args, unknown_args)
    print(args)
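The leftover-argument loop handles the space-separated form `--key value` as well, which the run below does not exercise. A minimal sketch that checks this without a shell by passing an explicit list to `parse_known_args` instead of letting it read `sys.argv`. The two-option parser is a trimmed stand-in for `arg_parser()`, and the extra options `--lr` and `--nsteps` are hypothetical.

```python
import argparse


def parse_unknown_args(args):
    """Essentially the same loop as parse_unknown_args above: leftover tokens -> dict of strings."""
    retval = {}
    preceded_by_key = False
    key = None
    for arg in args:
        if arg.startswith('--'):
            if '=' in arg:
                key, value = arg[2:].split('=', 1)
                retval[key] = value
                preceded_by_key = False
            else:
                key = arg[2:]
                preceded_by_key = True
        elif preceded_by_key:
            retval[key] = arg
            preceded_by_key = False
    return retval


# Trimmed-down stand-in for arg_parser(): only two known options
parser = argparse.ArgumentParser()
parser.add_argument('--env', type=str, default='Reacher-v2')
parser.add_argument('--seed', type=int, default=None)

# Explicit argv instead of sys.argv; --lr and --nsteps are hypothetical extras
argv = ['--seed', '3', '--lr', '3e-4', '--nsteps=2048']
args, unknown = parser.parse_known_args(argv)
print(args)                         # Namespace(env='Reacher-v2', seed=3)
print(unknown)                      # ['--lr', '3e-4', '--nsteps=2048']
print(parse_unknown_args(unknown))  # {'lr': '3e-4', 'nsteps': '2048'}
```

Note that a bare `--flag` followed directly by another option is dropped, since the loop only records a key once it sees its value.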

Run it (saved as test.py) with a few extra options that the parser does not define:

python test.py --aaa=me --xxx=11.11 --abc=True --cde=1+99

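Output. The first line holds only the known options; the second also contains the extra options merged in by parse_cmdline_kwargs. The attributes appear in alphabetical order, as argparse printed them before Python 3.9; newer versions list them in the order they were set.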

Namespace(alg='ppo2', env='Reacher-v2', env_type=None, gamestate=None, log_path=None, network=None, num_env=None, num_timesteps=1000000.0, play=False, reward_scale=1.0, save_path=None, save_video_interval=0, save_video_length=200, seed=None)

Namespace(aaa='me', abc=True, alg='ppo2', cde=100, env='Reacher-v2', env_type=None, gamestate=None, log_path=None, network=None, num_env=None, num_timesteps=1000000.0, play=False, reward_scale=1.0, save_path=None, save_video_interval=0, save_video_length=200, seed=None, xxx=11.11)
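In the second line every extra value went through `eval`: `me` raises NameError and stays a string, `11.11` becomes a float, `True` a bool, and `1+99` is computed as 100. Because `eval` will run any expression passed on the command line, one safer variant swaps in the standard library's `ast.literal_eval`, which accepts only Python literals. This is a substitute technique, not what the baselines code does, and the trade-off is that arithmetic such as `1+99` is no longer evaluated. `parse_value` is a hypothetical name.

```python
import ast


def parse_value(v):
    """Convert a command-line string to a Python literal when possible, else keep the string."""
    assert isinstance(v, str)
    try:
        # literal_eval only accepts literals: numbers, strings, bools, None, tuples, lists, dicts, sets
        return ast.literal_eval(v)
    except (ValueError, TypeError, SyntaxError):
        # bare names like 'me', arithmetic like '1+99' and function calls end up here
        return v


for raw in ['me', '11.11', 'True', '1+99', '[64, 64]', "__import__('os')"]:
    print(repr(raw), '->', repr(parse_value(raw)))
```

Dropping `parse_value` in place of the inner `parse` function would keep list- or dict-valued options such as `--layers=[64,64]` working while refusing to execute code.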

=======================================

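This is a fairly clean way to parse run-time arguments: every option, including ad-hoc ones given only on the command line, ends up as an attribute of a single Namespace, so later code can use all parameters the same way.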
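A minimal sketch of what that looks like in later code: the extra options are read as ordinary attributes, or the Namespace is turned into a dict with `vars()` and forwarded as keyword arguments. The `train` function and its parameters are hypothetical, and the Namespace is built by hand to mirror a few fields of the second printout.

```python
import argparse


def train(alg, num_timesteps, **alg_kwargs):
    """Hypothetical entry point: named parameters for known options, **alg_kwargs for extras."""
    return f"{alg} for {int(num_timesteps)} steps with {alg_kwargs}"


# Namespace shaped like the second printout, trimmed to a few fields
args = argparse.Namespace(alg='ppo2', num_timesteps=1e6, aaa='me', cde=100)

# Extra options are plain attributes, just like those defined with add_argument
print(args.cde)

# vars() returns the Namespace's own __dict__, so copy it before popping keys
kwargs = dict(vars(args))
result = train(kwargs.pop('alg'), kwargs.pop('num_timesteps'), **kwargs)
print(result)
```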

