import sys
import time

import numpy as np

# NOTE: these utility imports are assumptions based on the Spinning-Up-style
# helper names used below (EpochLogger, colorize, proc_id, setup_logger_kwargs);
# adjust paths and signatures to this repo's actual modules. Logger (a stdout
# tee) and obs2state (dict-observation flattener) are assumed repo-local.
from spinup.utils.logx import EpochLogger, colorize
from spinup.utils.run_utils import setup_logger_kwargs
from spinup.utils.mpi_tools import proc_id


def trainer(net, env, args):
    # logger
    exp_name = args.exp_name + '_' + args.RL_name + '_' + args.env_name
    logger_kwargs = setup_logger_kwargs(exp_name=exp_name,
                                        seed=args.seed,
                                        output_dir=args.output_dir + "/")
    logger = EpochLogger(**logger_kwargs)
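    # under MPI, only the rank-0 process tees stdout into print.log and
    # saves the run config, so parallel workers don't clobber the files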
    if proc_id() == 0:
        sys.stdout = Logger(logger_kwargs["output_dir"] + "/print.log",
                            sys.stdout)
        logger.save_config(locals(), __file__)
    # start running
    start_time = time.time()
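    # evaluation-only variant: each epoch runs test rollouts and logs the
    # average test return; the training loop lives in Example #2 below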
    for i in range(args.n_epochs):
        test_ep_reward, logger = net.test_agent(
            args=args,
            env=env,
            n=10,
            logger=logger,
            obs2state=obs2state,
        )
        logger.store(TestEpRet=test_ep_reward)

        logger.log_tabular('Epoch', i)
        logger.log_tabular('TestEpRet', average_only=True)

        logger.log_tabular('TestSuccess', average_only=True)
        logger.dump_tabular()

    print(
        colorize("experiment %s has finished" % logger.output_file.name,
                 'green',
                 bold=True))
    net.save_simple_network(logger_kwargs["output_dir"])
    net.save_norm(logger_kwargs["output_dir"])
    net.save_replay_buffer(logger_kwargs["output_dir"])

# Example #2

def trainer(net, env, args):
    # logger
    exp_name = args.exp_name + '_' + args.RL_name + '_' + args.env_name
    logger_kwargs = setup_logger_kwargs(exp_name=exp_name,
                                        seed=args.seed,
                                        output_dir=args.output_dir + "/")
    logger = EpochLogger(**logger_kwargs)
    if proc_id() == 0:
        sys.stdout = Logger(logger_kwargs["output_dir"] + "/print.log",
                            sys.stdout)
        logger.save_config(locals(), __file__)
    # start running
    start_time = time.time()
    for i in range(args.n_epochs):
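        # one epoch = n_cycles collection episodes of n_steps env steps each,
        # with gradient updates after every episode, then one test pass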
        for c in range(args.n_cycles):
            obs = env.reset()
            episode_trans = []
            s = obs2state(obs)
            ep_reward = 0
            real_ep_reward = 0
            episode_time = time.time()

            success = []
            for j in range(args.n_steps):
                a = net.get_action(s, noise_scale=args.noise_ps)

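                # epsilon-random exploration: with probability random_eps,
                # replace the policy action with a uniform random one
                # (the usual DDPG+HER exploration recipe)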
                if np.random.random() < args.random_eps:
                    a = np.random.uniform(low=-net.a_bound,
                                          high=net.a_bound,
                                          size=net.act_dim)
                a = np.clip(a, -net.a_bound, net.a_bound)
                obs_next, r, done, info = env.step(a)
                # not every env reports an "is_success" flag in info;
                # fall back to the raw done signal in that case
                success.append(info.get("is_success", int(done)))
                s_ = obs2state(obs_next)

                # visualization
                if args.render and i % 3 == 0 and c % 20 == 0:
                    env.render()

                # gym's TimeLimit wrapper returns done=True at the step cap;
                # don't treat that time-limit cutoff as a true terminal state
                done = False if j == args.n_steps - 1 else done

                if not args.her:
                    net.store_transition((s, a, r, s_, done))

                episode_trans.append([obs, a, r, obs_next, done, info])
                s = s_
                obs = obs_next
                ep_reward += r
                real_ep_reward += r
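            # with HER enabled the whole episode is kept so that goals can be
            # relabelled in hindsight via env.compute_reward when replayed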
            if args.her:
                net.save_episode(episode_trans=episode_trans,
                                 reward_func=env.compute_reward,
                                 obs2state=obs2state)
            logger.store(EpRet=ep_reward)
            logger.store(EpRealRet=real_ep_reward)

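            # fixed budget of 40 gradient updates per collection episode;
            # the two learning rates passed to net.learn are tied together
            # (base_lr and 2 * base_lr)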
            for _ in range(40):
                outs = net.learn(
                    args.batch_size,
                    args.base_lr,
                    args.base_lr * 2,
                )
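                # net.learn appears to return None for the Q estimates until
                # the replay buffer is warm enough; skip logging until then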
                if outs[1] is not None:
                    logger.store(Q1=outs[1])
                    logger.store(Q2=outs[2])
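            # print only partially-successful episodes (succeeded on some
            # steps but not all), which are the informative ones to inspect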
            if 0.0 < sum(success) < args.n_steps:
                print("epoch:", i, "\tep:", c, "\tep_rew:",
                      ep_reward, "\ttime:",
                      np.round(time.time() - episode_time, 3), '\tdone:',
                      sum(success))

        test_ep_reward, logger = net.test_agent(
            args=args,
            env=env,
            n=10,
            logger=logger,
            obs2state=obs2state,
        )
        logger.store(TestEpRet=test_ep_reward)

        logger.log_tabular('Epoch', i)
        logger.log_tabular('EpRet', average_only=True)
        logger.log_tabular('EpRealRet', average_only=True)
        logger.log_tabular('TestEpRet', average_only=True)

        logger.log_tabular('Q1', with_min_and_max=True)
        logger.log_tabular('Q2', average_only=True)

        logger.log_tabular('TestSuccess', average_only=True)

        # cumulative env steps: (i + 1) completed epochs, each of
        # n_cycles episodes with n_steps steps
        logger.log_tabular('TotalEnvInteracts',
                           (i + 1) * args.n_cycles * args.n_steps)
        logger.log_tabular('TotalTime', time.time() - start_time)
        logger.dump_tabular()

    print(
        colorize("experiment %s has finished" % logger.output_file.name,
                 'green',
                 bold=True))
    net.save_simple_network(logger_kwargs["output_dir"])
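

# Minimal usage sketch (hypothetical: the attribute names mirror the args the
# trainers read above; the env id, numeric values, and `net` are placeholders):
#
#     import gym
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(
#         exp_name='demo', RL_name='ddpg_her', env_name='FetchPush-v1',
#         seed=0, output_dir='logs', n_epochs=50, n_cycles=50, n_steps=50,
#         noise_ps=0.2, random_eps=0.3, her=True, render=False,
#         batch_size=256, base_lr=1e-3)
#     env = gym.make(args.env_name)
#     trainer(net, env, args)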