Ejemplo n.º 1
0
def main(args):
    """Train a model from command-line arguments, then optionally save and replay it.

    Args:
        args: Argument list handed to the parser built by
            ``common_arg_parser``; entries the parser does not recognise are
            converted into ``extra_args`` by ``parse_cmdline_kwargs``.

    Returns:
        The model returned by ``train``.

    Note:
        ``log_path`` is not defined inside this function; it is read from the
        enclosing scope. When ``args.play`` is truthy the replay loop has no
        exit condition, so the cleanup that follows it only runs when
        ``args.play`` is falsy.
    """
    parser = common_arg_parser()
    args, leftover = parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(leftover)

    if args.extra_import is not None:
        import_module(args.extra_import)

    # Only the root process (or a run without MPI) writes log output; every
    # other MPI rank gets a logger configured with no output formats.
    if MPI is not None and MPI.COMM_WORLD.Get_rank() != 0:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        rank = 0
        logger.configure(dir=log_path, format_strs=['stdout', 'csv'])

    model, env = train(args, extra_args)

    # Saving is restricted to rank 0.
    if args.save_path is not None and rank == 0:
        model.save(osp.expanduser(args.save_path))

    if args.play:
        logger.log("Running trained model")
        obs = env.reset()

        if hasattr(model, 'initial_state'):
            rnn_state = model.initial_state
        else:
            rnn_state = None
        mask = np.zeros((1, ))

        ep_return = 0
        step_count = 0
        while True:
            # The two branches unpack a different number of values from
            # ``model.step``; both arities are kept exactly as they were.
            if rnn_state is None:
                actions, _, _, _, _ = model.step(obs)
            else:
                actions, _, rnn_state, _ = model.step(obs, S=rnn_state, M=mask)

            obs, rew, done, _ = env.step(actions)
            if isinstance(env, VecEnv):
                ep_return += rew[0]
            else:
                ep_return += rew
            env.render()

            if isinstance(done, np.ndarray):
                done = done.any()
            step_count += 1
            if not done:
                continue

            # Episode finished: report, then start over. The step counter is
            # intentionally left running across episodes.
            print(f'episode_rew={ep_return}')
            print(step_count)
            ep_return = 0
            obs = env.reset()

    env.close()
    model.sess.close()
    tf.reset_default_graph()
    return model
Ejemplo n.º 2
0
def main(args):
    """Train a model, optionally save it, and optionally replay it in a new environment.

    Args:
        args: Argument list handed to the parser built by
            ``common_arg_parser``; entries the parser does not recognise are
            converted into ``extra_args`` by ``parse_cmdline_kwargs``.

    Returns:
        The model returned by ``train``.

    Note:
        The training environment is closed right after ``train`` returns and
        a separate one is built with ``build_env`` for replay. The replay
        loop has no exit condition, so the statements after it are only
        reached if that ever changes.
    """
    parser = common_arg_parser()
    args, leftover = parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(leftover)

    # Only the root process (or a run without MPI) gets the default logger;
    # every other MPI rank gets a logger configured with no output formats.
    if MPI is not None and MPI.COMM_WORLD.Get_rank() != 0:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        rank = 0
        logger.configure()

    model, env = train(args, extra_args)
    env.close()

    # Saving is restricted to rank 0.
    if args.save_path is not None and rank == 0:
        model.save(osp.expanduser(args.save_path))

    if not args.play:
        return model

    logger.log("Running trained model")
    env = build_env(args)
    # In demo mode, override the environment's starting positions.
    if args.demo:
        env.starting_positions = get_all_states(args.env)

    obs = env.reset()

    def _zero_inputs(nlstm=128, **_unused):
        # ``nlstm`` is picked out of ``extra_args`` by keyword when present;
        # every other entry is accepted and ignored.
        state_shape = (args.num_env or 1, 2 * nlstm)
        return np.zeros(state_shape), np.zeros(1)

    rnn_state, mask = _zero_inputs(**extra_args)
    while True:
        actions, _, rnn_state, _ = model.step(obs, S=rnn_state, M=mask)
        obs, _, done, _ = env.step(actions)
        env.render()
        if isinstance(done, np.ndarray):
            done = done.any()

        if done:
            obs = env.reset()

    env.close()

    return model
Ejemplo n.º 3
0
def testbaselines(args):
    """Step an environment built by ``make_atari`` with randomly sampled actions.

    Parses ``args`` with the parser from ``common_arg_parser``, resolves the
    environment id through ``get_env_type``, then renders and steps the
    environment for 10000 iterations, resetting it whenever ``done`` is truthy.

    NOTE(review): everything from ``rank = 0`` onward looks like it was
    spliced in from a different function. As written, the ``else:`` clause
    binds to the ``for`` loop, and ``extra_args`` / ``save_path`` are not
    assigned anywhere in this function (the assignment of ``extra_args`` is
    commented out), so they would have to come from module scope. Confirm
    against the original source before relying on this tail.
    """
    # Use the baselines environment-wrapping helpers and try to run the simplest possible code.
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    #extra_args = parse_cmdline_kwargs(unknown_args)

    env_type, env_id = get_env_type(args)
    print('env_type: {}'.format(env_type))
    env = make_atari(env_id)
    #env = build_env(args)
    print("env builded ",env)
    obs = env.reset()
    reset = True

    print("env reseted")
    for t in range(10000):
        env.render()
        # Sample an action from the environment's action space.
        action = env.action_space.sample()
        new_obs, rew, done, _ = env.step(action)
        obs = new_obs
        if done:
            print(done)
            obs = env.reset()
            reset = True
        # NOTE(review): these two statements sit inside the loop body, so the
        # logger is reconfigured on every iteration -- presumably unintended.
        rank = 0
        logger.configure()
    else:
        # This is a for/else: the loop contains no ``break``, so this branch
        # runs once after the loop completes.
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    # ``env`` is rebound here to the environment returned by ``retrain``.
    model, env = retrain(args, extra_args, save_path)
    model.save(save_path)

    # env.close()


# Script entry point: parse the command line, then run each mode whose flag
# is set. The three checks are independent ``if`` statements, so more than
# one mode can run in a single invocation.
if __name__ == '__main__':

    # parser = argparse.ArgumentParser(description='Train or test neural net motor controller.')
    # parser.add_argument('--train', dest='train', action='store_true', default=False)
    # parser.add_argument('--test', dest='test', action='store_true', default=True)
    # Path handed to every mode below as the model location.
    save_path = './models/ddpg'
    # args = parser.parse_args()
    arg_parser = common_arg_parser()
    # No argument list is passed, so argparse falls back to sys.argv.
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    # Needs to be registered in common/cmd_util.py
    # (presumably the train/test/retrain flags read below -- TODO confirm).
    # NOTE(review): ``main`` is called with three arguments here; ``main_test``
    # and ``main_retrain`` are not defined in the visible part of the file.
    if args.train:
        main(args, extra_args, save_path)
    if args.test:
        main_test(args, extra_args, save_path)
    if args.retrain:  # train based on the train checkpoint stored in save_path
        main_retrain(args, extra_args, save_path)