コード例 #1
0
def test_collector_with_dict_state():
    """Run ``Collector`` against environments built with ``dict_state=True``.

    Three collectors are exercised: one on a single environment, one on a
    four-environment ``VectorEnv``, and one on the same ``VectorEnv`` whose
    buffer uses ``stack_num=4``. After merging the second buffer into the
    first, the stored ``obs.index`` sequence is compared with the expected
    values.
    """
    env = MyTestEnv(size=5, sleep=0, dict_state=True)
    policy = MyPolicy(dict_state=True)

    # Collector on a single environment.
    single_collector = Collector(
        policy, env, ReplayBuffer(size=100), preprocess_fn)
    single_collector.collect(n_step=3)
    single_collector.collect(n_episode=3)

    # Collector on a vector of environments with sizes 2..5; the size is
    # bound through a default argument so each factory keeps its own value.
    envs = VectorEnv([
        lambda size=size: MyTestEnv(size=size, sleep=0, dict_state=True)
        for size in (2, 3, 4, 5)
    ])
    vector_collector = Collector(
        policy, envs, ReplayBuffer(size=100), preprocess_fn)
    vector_collector.collect(n_step=10)
    vector_collector.collect(n_episode=[2, 1, 1, 2])
    batch = vector_collector.sample(10)
    print(batch)

    # Merge the vectorized buffer into the single-env buffer and check the
    # stored indices: consecutive runs counting 0., 1., ... for each length.
    single_collector.buffer.update(vector_collector.buffer)
    run_lengths = [5, 5, 5] + [2, 3, 2, 4, 5] * 2
    expected_index = [
        float(step) for length in run_lengths for step in range(length)
    ]
    merged = single_collector.buffer[:len(single_collector.buffer)]
    assert equal(merged.obs.index, expected_index)

    # Collector whose buffer stacks 4 frames; only the last env collects.
    stacked_collector = Collector(
        policy, envs, ReplayBuffer(size=100, stack_num=4), preprocess_fn)
    stacked_collector.collect(n_episode=[0, 0, 0, 10])
    batch = stacked_collector.sample(10)
    print(batch['obs_next']['index'])
コード例 #2
0
def test_collector_with_ma():
    """Run ``Collector`` on environments created with ``ma_rew=4``.

    Every ``collect`` call is made with a ``reward_metric`` that sums its
    input, and the reported ``'rew'`` is checked to be a single value equal
    to the expected number. The merged buffer contents (``obs`` and ``rew``)
    are compared with the expected sequences as well.
    """
    def reward_metric(x):
        return x.sum()

    def check_scalar_reward(value, expected):
        # The reported reward must hold exactly one element ...
        assert np.asanyarray(value).size == 1
        # ... and that element must match the expected value.
        assert value == expected

    env = MyTestEnv(size=5, sleep=0, ma_rew=4)
    policy = MyPolicy()

    # Collector on a single environment.
    single_collector = Collector(
        policy, env, ReplayBuffer(size=100),
        preprocess_fn, reward_metric=reward_metric)
    check_scalar_reward(single_collector.collect(n_step=3)['rew'], 0.)
    check_scalar_reward(single_collector.collect(n_episode=3)['rew'], 4.)

    # Collector on a vector of environments with sizes 2..5.
    envs = VectorEnv([
        lambda x=size: MyTestEnv(size=x, sleep=0, ma_rew=4)
        for size in (2, 3, 4, 5)
    ])
    vector_collector = Collector(
        policy, envs, ReplayBuffer(size=100),
        preprocess_fn, reward_metric=reward_metric)
    check_scalar_reward(vector_collector.collect(n_step=10)['rew'], 4.)
    check_scalar_reward(
        vector_collector.collect(n_episode=[2, 1, 1, 2])['rew'], 4.)
    batch = vector_collector.sample(10)
    print(batch)

    # Merge the vectorized buffer into the single-env buffer and verify it.
    single_collector.buffer.update(vector_collector.buffer)
    merged = single_collector.buffer[:len(single_collector.buffer)]

    # Expected data is a sequence of runs with these lengths: observations
    # count 0., 1., ... inside a run, and the reward is 1 only on the last
    # entry of a run (repeated 4 times per entry).
    run_lengths = [5, 5, 5] + [2, 3, 2, 4, 5] * 2
    obs = [float(step) for length in run_lengths for step in range(length)]
    assert np.allclose(merged.obs, obs)
    rew = []
    for length in run_lengths:
        rew.extend([0] * (length - 1) + [1])
    assert np.allclose(merged.rew, [[value] * 4 for value in rew])

    # Collector whose buffer stacks 4 frames; only the last env collects.
    stacked_collector = Collector(
        policy, envs, ReplayBuffer(size=100, stack_num=4),
        preprocess_fn, reward_metric=reward_metric)
    check_scalar_reward(
        stacked_collector.collect(n_episode=[0, 0, 0, 10])['rew'], 4.)
    batch = stacked_collector.sample(10)
    print(batch['obs_next'])
コード例 #3
0
ファイル: test_collector.py プロジェクト: tao9/tianshou
def test_collector_with_dict_state():
    """Smoke-test ``Collector`` with ``dict_state=True`` environments.

    Collects from a single environment, from a four-environment
    ``VectorEnv``, and from the same ``VectorEnv`` with a frame-stacking
    buffer (``stack_num=4``), then prints the sampled ``obs_next`` indices.
    No assertions are made; the test passes when nothing raises.
    """
    env = MyTestEnv(size=5, sleep=0, dict_state=True)
    policy = MyPolicy(dict_state=True)

    # Single environment.
    single_collector = Collector(policy, env, ReplayBuffer(size=100))
    single_collector.collect(n_step=3)
    single_collector.collect(n_episode=3)

    # Vector of environments with sizes 2..5; the size is bound through a
    # default argument so each factory keeps its own value.
    envs = VectorEnv([
        lambda size=size: MyTestEnv(size=size, sleep=0, dict_state=True)
        for size in (2, 3, 4, 5)
    ])
    vector_collector = Collector(policy, envs, ReplayBuffer(size=100))
    vector_collector.collect(n_step=10)

    # Same vector of environments, buffer stacking 4 frames.
    stacked_collector = Collector(
        policy, envs, ReplayBuffer(size=100, stack_num=4))
    stacked_collector.collect(n_episode=10)
    batch = stacked_collector.sample(10)
    print(batch['obs_next']['index'])
コード例 #4
0
def test_dqn(args=get_args()):
    """Pre-train the embedding prediction network on random Atari data.

    Steps performed:

    1. Build the training and test vector environments and seed them.
    2. Fill two replay buffers with randomly collected transitions, or load
       them from ``train_data.pkl`` / ``test_data.pkl`` under the embedding
       log directory when ``args.embedding_data_path`` is set.
    3. Train ``embedding_net`` for 100000 iterations, logging the losses to
       TensorBoard and printing the test accuracy at iteration 1 and every
       10000 iterations.
    4. Save the network weights to ``embedding.pth`` and call ``exit()``.

    NOTE: because of the ``exit()`` call in step 4, the DQN training
    pipeline in the second half of this function is currently unreachable.
    It is kept unchanged.

    :param args: parsed options. The default, ``get_args()``, is evaluated
        once, when this function is defined.
    """
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n

    # should be N_FRAMES x H x W

    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # log
    log_path = os.path.join(args.logdir, args.task, 'embedding')
    embedding_writer = SummaryWriter(log_path + '/with_init')

    embedding_net = embedding_prediction.Prediction(
        *args.state_shape, args.action_shape,
        args.device).to(device=args.device)
    embedding_net.apply(embedding_prediction.weights_init)

    if args.embedding_path:
        embedding_net.load_state_dict(torch.load(log_path + '/embedding.pth'))
        print("Loaded agent from: ", log_path + '/embedding.pth')

    pre_buffer = ReplayBuffer(args.buffer_size,
                              save_only_last_obs=True,
                              stack_num=args.frames_stack)
    pre_test_buffer = ReplayBuffer(args.buffer_size // 100,
                                   save_only_last_obs=True,
                                   stack_num=args.frames_stack)

    train_collector = Collector(None, train_envs, pre_buffer)
    test_collector = Collector(None, test_envs, pre_test_buffer)
    if args.embedding_data_path:
        # NOTE(review): pickle can execute arbitrary code while loading;
        # only load buffer files that this script produced itself.
        with open(log_path + '/train_data.pkl', 'rb') as train_file:
            pre_buffer = pickle.load(train_file)
        with open(log_path + '/test_data.pkl', 'rb') as test_file:
            pre_test_buffer = pickle.load(test_file)
        train_collector.buffer = pre_buffer
        test_collector.buffer = pre_test_buffer
        print('load success')
    else:
        print('collect start')
        train_collector = Collector(None, train_envs, pre_buffer)
        test_collector = Collector(None, test_envs, pre_test_buffer)
        train_collector.collect(n_step=args.buffer_size, random=True)
        test_collector.collect(n_step=args.buffer_size // 100, random=True)
        print(len(train_collector.buffer))
        print(len(test_collector.buffer))
        os.makedirs(log_path, exist_ok=True)
        # Context managers make sure the dumps are flushed and closed.
        with open(log_path + '/train_data.pkl', 'wb') as train_file:
            pickle.dump(pre_buffer, train_file)
        with open(log_path + '/test_data.pkl', 'wb') as test_file:
            pickle.dump(pre_test_buffer, test_file)
        print('collect finish')

    # Train the encoding network with the collected data.
    def part_loss(x, device='cpu'):
        """Return the sum over all elements of ``min((1 - x)^2, x^2)``."""
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=device, dtype=torch.float32)
        # The element-wise minimum followed by a full sum does not depend on
        # the shape of x, so no reshape to a fixed batch size is required
        # (a hard-coded reshape fails when the element count is not a
        # multiple of that size, e.g. on the full test batch).
        return torch.min((1 - x).pow(2.0), x.pow(2.0)).sum()

    def as_action_tensor(actions):
        """Return ``actions`` as a tensor, converting when necessary."""
        if isinstance(actions, torch.Tensor):
            return actions
        return torch.tensor(actions, device=args.device, dtype=torch.int64)

    pre_optim = torch.optim.Adam(embedding_net.parameters(), lr=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(pre_optim,
                                                step_size=50000,
                                                gamma=0.1,
                                                last_epoch=-1)
    # batch_size=0 is passed to obtain the evaluation batch once, up front.
    test_batch_data = test_collector.sample(batch_size=0)

    loss_fn = torch.nn.NLLLoss()
    for epoch in range(1, 100001):
        embedding_net.train()
        batch_data = train_collector.sample(batch_size=128)
        pred = embedding_net(batch_data['obs'], batch_data['obs_next'])
        x1 = pred[1]
        x2 = pred[2]
        # Always bind `act`; previously it was only assigned when the stored
        # actions were not already a tensor.
        act = as_action_tensor(batch_data['act'])
        loss_1 = loss_fn(pred[0], act)
        loss_2 = 0.01 * (part_loss(x1, args.device) +
                         part_loss(x2, args.device)) / 128
        loss = loss_1 + loss_2
        embedding_writer.add_scalar('training loss1', loss_1.item(), epoch)
        embedding_writer.add_scalar('training loss2', loss_2.item(), epoch)
        embedding_writer.add_scalar('training loss', loss.item(), epoch)
        pre_optim.zero_grad()
        loss.backward()
        pre_optim.step()
        scheduler.step()

        if epoch % 10000 == 0 or epoch == 1:
            print(pre_optim.state_dict()['param_groups'][0]['lr'])
            correct = 0
            # Print the second-to-last parameter tensor of the network.
            numel_list = list(embedding_net.parameters())[-2]
            print(numel_list)
            embedding_net.eval()
            with torch.no_grad():
                test_pred, x1, x2, _ = embedding_net(
                    test_batch_data['obs'], test_batch_data['obs_next'])
                act = as_action_tensor(test_batch_data['act'])
                loss_1 = loss_fn(test_pred, act)
                loss_2 = 0.01 * (part_loss(x1, args.device) +
                                 part_loss(x2, args.device)) / 128
                loss = loss_1 + loss_2
                embedding_writer.add_scalar('test loss', loss.item(), epoch)
                print(torch.argmax(test_pred, dim=1))
                print(act)
                correct += int((torch.argmax(test_pred, dim=1) == act).sum())
                print('Acc:', correct / len(test_batch_data))

    torch.save(embedding_net.state_dict(),
               os.path.join(log_path, 'embedding.pth'))
    embedding_writer.close()
    # NOTE: the process terminates here, so nothing below this line runs.
    exit()
    # Build the hash table.

    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(device=args.device)

    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net,
                       optim,
                       args.gamma,
                       args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path))
        print("Loaded agent from: ", args.resume_path)

    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    pre_buffer.reset()
    buffer = ReplayBuffer(args.buffer_size,
                          ignore_obs_next=True,
                          save_only_last_obs=True,
                          stack_num=args.frames_stack)
    # collector
    # (original note) pass preprocess_fn into train_collector to reshape the
    # reward -- no preprocess_fn is passed in the call below.
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(x):
        if env.env.spec.reward_threshold:
            return x >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return x >= 20

    def train_fn(x):
        # nature DQN setting, linear decay in the first 1M steps
        now = x * args.collect_per_step * args.step_per_epoch
        if now <= 1e6:
            eps = args.eps_train - now / 1e6 * \
                (args.eps_train - args.eps_train_final)
            policy.set_eps(eps)
        else:
            policy.set_eps(args.eps_train_final)
        print("set eps =", policy.eps)

    def test_fn(x):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=1 / 30)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(policy,
                               train_collector,
                               test_collector,
                               args.epoch,
                               args.step_per_epoch,
                               args.collect_per_step,
                               args.test_num,
                               args.batch_size,
                               train_fn=train_fn,
                               test_fn=test_fn,
                               stop_fn=stop_fn,
                               save_fn=save_fn,
                               writer=writer,
                               test_in_train=False)

    pprint.pprint(result)
    watch()
コード例 #5
0
ファイル: offpolicy.py プロジェクト: xxyqsy/tianshou
def offpolicy_trainer(policy: BasePolicy,
                      train_collector: Collector,
                      test_collector: Collector,
                      max_epoch: int,
                      step_per_epoch: int,
                      collect_per_step: int,
                      episode_per_test: Union[int, List[int]],
                      batch_size: int,
                      update_per_step: int = 1,
                      train_fn: Optional[Callable[[int], None]] = None,
                      test_fn: Optional[Callable[[int], None]] = None,
                      stop_fn: Optional[Callable[[float], bool]] = None,
                      save_fn: Optional[Callable[[BasePolicy], None]] = None,
                      log_fn: Optional[Callable[[dict], None]] = None,
                      writer: Optional[SummaryWriter] = None,
                      log_interval: int = 1,
                      verbose: bool = True,
                      **kwargs) -> Dict[str, Union[float, str]]:
    """Drive an off-policy training run: collect, update, evaluate, repeat.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector that gathers training data.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for evaluation.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: upper bound on the number of epochs; training may
        stop earlier when ``stop_fn`` reports success.
    :param int step_per_epoch: number of policy updates that make up one
        epoch (the total of the per-epoch progress bar).
    :param int collect_per_step: number of frames requested from the
        training collector before each round of policy updates.
    :param episode_per_test: number of episodes for one policy evaluation.
    :param int batch_size: size of the batch sampled from the training
        collector for every policy update.
    :param int update_per_step: multiplier applied to the number of policy
        updates performed after each collection.
    :param function train_fn: called with the current epoch index at the
        beginning of training in that epoch.
    :param function test_fn: called with the current epoch index at the
        beginning of testing in that epoch.
    :param function save_fn: called with the policy whenever the evaluation
        reward improves, and when training stops early.
    :param function stop_fn: receives a reward value and returns a boolean
        telling whether the goal has been reached.
    :param function log_fn: a function that receives env info for logging;
        forwarded to the training collector.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: scalars are written when the global step is a
        multiple of this value.
    :param bool verbose: whether to print a summary line after each epoch.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    # In-training evaluation is only done when the training collector uses
    # the very policy that is being trained.
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # ---- training phase ----
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as bar:
            while bar.n < bar.total:
                result = train_collector.collect(n_step=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
                if test_in_train and stop_fn and stop_fn(result['rew']):
                    # The training reward looks good enough: confirm it with
                    # a real evaluation before stopping.
                    test_result = test_episode(policy, test_collector,
                                               test_fn, epoch,
                                               episode_per_test)
                    if stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        data.update(
                            {k: f'{v:.2f}' for k, v in result.items()})
                        bar.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector,
                                           test_result['rew'])
                    # Not confirmed: switch back to training mode.
                    policy.train()
                    if train_fn:
                        train_fn(epoch)
                # Number of updates for this collection, capped by what is
                # left in the epoch.
                n_update = update_per_step * min(
                    result['n/st'] // collect_per_step, bar.total - bar.n)
                for _ in range(n_update):
                    global_step += 1
                    losses = policy.learn(train_collector.sample(batch_size))
                    should_log = writer and global_step % log_interval == 0
                    for key, value in result.items():
                        data[key] = f'{value:.2f}'
                        if should_log:
                            writer.add_scalar(key,
                                              value,
                                              global_step=global_step)
                    for key, value in losses.items():
                        # Each loss is tracked through its own moving
                        # average, created on first use.
                        if key not in stat:
                            stat[key] = MovAvg()
                        stat[key].add(value)
                        data[key] = f'{stat[key].get():.6f}'
                        if should_log:
                            writer.add_scalar(key,
                                              stat[key].get(),
                                              global_step=global_step)
                    bar.update(1)
                    bar.set_postfix(**data)
            if bar.n <= bar.total:
                bar.update()
        # ---- evaluation phase ----
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward)
コード例 #6
0
ファイル: cartpole_ppo.py プロジェクト: zeta1999/jiminy
def onpolicy_trainer(
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        max_epoch: int,
        frame_per_epoch: int,
        collect_per_step: int,
        repeat_per_collect: int,
        episode_per_test: Union[int, List[int]],
        batch_size: int,
        train_fn: Optional[Callable[[int], None]] = None,
        test_fn: Optional[Callable[[int], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        log_fn: Optional[Callable[[dict], None]] = None,
        writer: Optional[SummaryWriter] = None,
        log_interval: int = 1,
        verbose: bool = True,
        **kwargs
) -> Dict[str, Union[float, str]]:
    """Drive an on-policy training run: collect, learn, evaluate, repeat.

    :param policy: the policy being trained.
    :param train_collector: the collector that gathers training data; its
        buffer is reset after every learning step.
    :param test_collector: the collector used for evaluation.
    :param int max_epoch: upper bound on the number of epochs; training may
        stop earlier when ``stop_fn`` reports success.
    :param int frame_per_epoch: total of the per-epoch progress bar, which
        advances by ``collect_per_step`` after each collect/learn round.
    :param int collect_per_step: value passed as ``n_step`` to the training
        collector for each collection.
    :param int repeat_per_collect: forwarded to ``policy.learn`` together
        with ``batch_size``.
    :param episode_per_test: number of episodes for one policy evaluation.
    :param int batch_size: forwarded to ``policy.learn``.
    :param function train_fn: called with the current epoch index at the
        beginning of training in that epoch.
    :param function test_fn: forwarded to ``test_episode`` together with the
        current epoch index.
    :param function stop_fn: receives a reward value and returns a boolean
        telling whether the goal has been reached.
    :param function save_fn: called with the policy whenever the evaluation
        reward improves, and when training stops early.
    :param function log_fn: forwarded to the training collector.
    :param writer: optional TensorBoard ``SummaryWriter`` for the scalars.
    :param int log_interval: scalars are written when the global step is a
        multiple of this value.
    :param bool verbose: whether to print a summary line after each epoch.

    :return: the value returned by ``gather_info``.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    # In-training evaluation is only done when the training collector uses
    # the very policy that is being trained.
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # ---- training phase ----
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=frame_per_epoch, desc=f'Epoch #{epoch}',
                       **tqdm_config) as bar:
            while bar.n < bar.total:
                result = train_collector.collect(n_step=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
                if test_in_train and stop_fn and stop_fn(result['rew']):
                    # The training reward looks good enough: confirm it with
                    # a real evaluation before stopping.
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)
                    if stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        data.update(
                            {k: f'{v:.2f}' for k, v in result.items()})
                        bar.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                    # Not confirmed: switch back to training mode.
                    policy.train()
                    if train_fn:
                        train_fn(epoch)
                # Learn from everything collected so far, then discard it.
                losses = policy.learn(
                    train_collector.sample(0), batch_size, repeat_per_collect)
                train_collector.reset_buffer()
                global_step += collect_per_step
                should_log = writer and global_step % log_interval == 0
                for key, value in result.items():
                    data[key] = f'{value:.2f}'
                    if should_log:
                        writer.add_scalar(
                            key, value, global_step=global_step)
                for key, value in losses.items():
                    # Each loss is tracked through its own moving average,
                    # created on first use.
                    if key not in stat:
                        stat[key] = MovAvg()
                    stat[key].add(value)
                    data[key] = f'{stat[key].get():.6f}'
                    if should_log:
                        writer.add_scalar(
                            key, stat[key].get(), global_step=global_step)
                bar.update(collect_per_step)
                bar.set_postfix(**data)
            if bar.n <= bar.total:
                bar.update()
        # ---- evaluation phase ----
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)