Example #1
def test_ReplayBuffer():
    """
    tianshou.data.ReplayBuffer
    buf.add()
    buf.get()
    buf.update()
    buf.sample()
    buf.reset()
    len(buf)
    :return:
    """
    buf1 = ReplayBuffer(size=15)
    for i in range(3):
        buf1.add(obs=i,
                 act=i,
                 rew=i,
                 done=i,
                 obs_next=i + 1,
                 info={},
                 weight=None)
    print(len(buf1))
    print(buf1.obs)
    buf2 = ReplayBuffer(size=10)
    for i in range(15):
        buf2.add(obs=i,
                 act=i,
                 rew=i,
                 done=i,
                 obs_next=i + 1,
                 info={},
                 weight=None)
    print(buf2.obs)
    buf1.update(buf2)
    print(buf1.obs)
    index = [1, 3, 5]
    # key is a required argument
    print(buf2.get(index, key='obs'))
    print('--------------------')
    sample_data, indice = buf2.sample(batch_size=4)
    print(sample_data, indice)
    print(sample_data.obs == buf2[indice].obs)
    print('--------------------')
    # buf.reset() only resets the index, not the content.
    print(len(buf2))
    buf2.reset()
    print(len(buf2))
    print(buf2)
    print('--------------------')
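For reference, a sketch of what the prints above are expected to show, assuming the standard circular-buffer semantics of tianshou's ReplayBuffer (values are illustrative expectations, not captured output):

    # len(buf1) == 3; buf1.obs holds [0, 1, 2] followed by zero padding up to size 15.
    # buf2 (size 10) wraps around after 15 adds, so in storage order
    #     buf2.obs == [10, 11, 12, 13, 14, 5, 6, 7, 8, 9]
    # buf1.update(buf2) appends buf2's valid data in chronological order,
    # so len(buf1) becomes 13.
    # buf2.reset() only rewinds the index: len(buf2) drops to 0 while the
    # underlying arrays keep their old contents.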
Example #2
def test_episodic_returns(size=2560):
    fn = BasePolicy.compute_episodic_return
    buf = ReplayBuffer(20)
    batch = Batch(
        done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
        rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
        info=Batch({
            'TimeLimit.truncated':
            np.array([False, False, False, False, False, True, False, False])
        }))
    for b in batch:
        b.obs = b.act = 1
        buf.add(b)
    returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=.1, gae_lambda=1)
    ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
    assert np.allclose(returns, ans)
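    # e.g. ans[1] = 1 + 0.1 * 2 + 0.1 ** 2 * 3 = 1.23: with no value estimates,
    # gae_lambda=1 reduces GAE to plain discounted returns within each episode
    # (indices 1..3 here), discounted by gamma=0.1.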
    buf.reset()
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 1, 0.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    for b in batch:
        b.obs = b.act = 1
        buf.add(b)
    returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=.1, gae_lambda=1)
    ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
    assert np.allclose(returns, ans)
    buf.reset()
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 0, 1.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    for b in batch:
        b.obs = b.act = 1
        buf.add(b)
    returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=.1, gae_lambda=1)
    ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
    assert np.allclose(returns, ans)
    buf.reset()
    batch = Batch(
        done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
        rew=np.array(
            [101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]),
    )
    for b in batch:
        b.obs = b.act = 1
        buf.add(b)
    v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
    returns, _ = fn(batch,
                    buf,
                    buf.sample_indices(0),
                    v,
                    gamma=0.99,
                    gae_lambda=0.95)
    ground_truth = np.array([
        454.8344, 376.1143, 291.298, 200., 464.5610, 383.1085, 295.387, 201.,
        474.2876, 390.1027, 299.476, 202.
    ])
    assert np.allclose(returns, ground_truth)
    buf.reset()
    batch = Batch(done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
                  rew=np.array([
                      101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109,
                      202
                  ]),
                  info=Batch({
                      'TimeLimit.truncated':
                      np.array([
                          False, False, False, True, False, False, False, True,
                          False, False, False, False
                      ])
                  }))
    for b in batch:
        b.obs = b.act = 1
        buf.add(b)
    v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
    returns, _ = fn(batch,
                    buf,
                    buf.sample_indices(0),
                    v,
                    gamma=0.99,
                    gae_lambda=0.95)
    ground_truth = np.array([
        454.0109, 375.2386, 290.3669, 199.01, 462.9138, 381.3571, 293.5248,
        199.02, 474.2876, 390.1027, 299.476, 202.
    ])
    assert np.allclose(returns, ground_truth)

    if __name__ == '__main__':
        buf = ReplayBuffer(size)
        batch = Batch(
            done=np.random.randint(100, size=size) == 0,
            rew=np.random.random(size),
        )
        for b in batch:
            b.obs = b.act = 1
            buf.add(b)
        indices = buf.sample_indices(0)

        def vanilla():
            return compute_episodic_return_base(batch, gamma=.1)

        def optimized():
            return fn(batch, buf, indices, gamma=.1, gae_lambda=1.0)

        cnt = 3000
        print('GAE vanilla', timeit(vanilla, setup=vanilla, number=cnt))
        print('GAE optim  ', timeit(optimized, setup=optimized, number=cnt))
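compute_episodic_return_base is not defined in this snippet; a minimal reference implementation consistent with how it is used above (plain discounted returns that restart at episode ends, i.e. GAE with gae_lambda=1 and zero value estimates) might look like this:

    import numpy as np

    def compute_episodic_return_base(batch, gamma):
        # Walk the batch backwards, accumulating gamma-discounted rewards and
        # restarting the accumulator whenever an episode terminates.
        returns = np.zeros_like(batch.rew)
        running = 0.0
        for i in reversed(range(len(batch.rew))):
            running = batch.rew[i] + gamma * running * (1.0 - batch.done[i])
            returns[i] = running
        return returns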
Example #3
File: MAPPO.py  Project: Luckych454/MAFRL
def test_Fedppo(args=get_args()):
    torch.set_num_threads(1)  # for poor CPU
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = DummyVectorEnv(
    #     [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # # test_envs = gym.make(args.task)
    # test_envs = DummyVectorEnv(
    #     [lambda: gym.make(args.task) for _ in range(args.test_num)])
    if args.data_quantity != 0:
        env.set_data_quantity(args.data_quantity)
    if args.data_quality != 0:
        env.set_data_quality(args.data_quality)
    if args.psi != 0:
        env.set_psi(args.psi)
    if args.nu != 0:
        env.set_nu(args.nu)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # train_envs.seed(args.seed)
    # test_envs.seed(args.seed)
    # model

    # server policy
    server_policy = build_policy(0, args)

    # client policy
    ND_policy = build_policy(1, args)
    RD_policy = build_policy(2, args)
    FD_policy = build_policy(3, args)
    # use replay buffers directly instead of a Collector
    server_buffer = ReplayBuffer(args.buffer_size)
    ND_buffer = ReplayBuffer(args.buffer_size)
    RD_buffer = ReplayBuffer(args.buffer_size)
    FD_buffer = ReplayBuffer(args.buffer_size)

    # log
    log_path = os.path.join(args.logdir, args.task, 'ppo')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    # from here on the trainer and tester are written by hand
    # to check the server's convergence, the client networks are not trained at first...
    start_time = time.time()
    _server_obs, _ND_obs, _RD_obs, _FD_obs = env.reset()
    _server_act = _server_rew = _done = _info = None
    server_buffer.reset()
    _ND_act = _ND_rew = _RD_act = _RD_rew = _FD_act = _FD_rew = [None]
    ND_buffer.reset()
    RD_buffer.reset()
    FD_buffer.reset()
    all_server_costs = []
    all_ND_utility = []
    all_RD_utility = []
    all_FD_utility = []
    all_leak_probability = []
    for epoch in range(1, 1 + args.epoch):
        # each epoch collects N*T transitions, then trains M times with batch size B
        server_costs = []
        ND_utility = []
        FD_utility = []
        RD_utility = []
        leak_probability = []
        payment = []
        expected_time = []
        training_time = []
        with tqdm.tqdm(total=args.step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                # collect data without gradients
                # server
                _server_obs, _ND_obs, _RD_obs, _FD_obs = env.reset()
                server_batch = Batch(obs=_server_obs,
                                     act=_server_act,
                                     rew=_server_rew,
                                     done=_done,
                                     obs_next=None,
                                     info=_info,
                                     policy=None)
                with torch.no_grad():
                    server_result = server_policy(server_batch, None)
                _server_policy = [{}]
                _server_act = to_numpy(server_result.act)
                # ND
                ND_batch = Batch(obs=_ND_obs,
                                 act=_ND_act,
                                 rew=_ND_rew,
                                 done=_done,
                                 obs_next=None,
                                 info=_info,
                                 policy=None)
                with torch.no_grad():
                    ND_result = ND_policy(ND_batch, None)
                _ND_policy = [{}]
                _ND_act = to_numpy(ND_result.act)
                # RD
                RD_batch = Batch(obs=_RD_obs,
                                 act=_RD_act,
                                 rew=_RD_rew,
                                 done=_done,
                                 obs_next=None,
                                 info=_info,
                                 policy=None)
                with torch.no_grad():
                    RD_result = RD_policy(RD_batch, None)
                _RD_policy = [{}]
                _RD_act = to_numpy(RD_result.act)
                # FD
                FD_batch = Batch(obs=_FD_obs,
                                 act=_FD_act,
                                 rew=_FD_rew,
                                 done=_done,
                                 obs_next=None,
                                 info=_info,
                                 policy=None)
                with torch.no_grad():
                    FD_result = FD_policy(FD_batch, None)
                _FD_policy = [{}]
                _FD_act = to_numpy(FD_result.act)
                # print(_ND_act.shape)
                server_obs_next, ND_obs_next, RD_obs_next, FD_obs_next, _server_rew, _client_rew, _done, _info = env.step(
                    _server_act[0], _ND_act[0], _RD_act[0], _FD_act[0])

                server_costs.append(_server_rew)
                ND_utility.append(_client_rew[0])
                RD_utility.append(_client_rew[1])
                FD_utility.append(_client_rew[2])
                leak_probability.append(_info[0]["leak"])
                payment.append(env.payment)
                expected_time.append(env.expected_time)
                training_time.append(env.global_time * env.time_lambda)
                # add the transitions to the replay buffers
                server_buffer.add(
                    Batch(obs=_server_obs[0],
                          act=_server_act[0],
                          rew=_server_rew[0],
                          done=_done[0],
                          obs_next=server_obs_next[0],
                          info=_info[0],
                          policy=_server_policy[0]))
                ND_buffer.add(
                    Batch(obs=_ND_obs[0],
                          act=_ND_act[0],
                          rew=_client_rew[0],
                          done=_done[0],
                          obs_next=ND_obs_next[0],
                          info=_info[0],
                          policy=_ND_policy[0]))
                RD_buffer.add(
                    Batch(obs=_RD_obs[0],
                          act=_RD_act[0],
                          rew=_client_rew[1],
                          done=_done[0],
                          obs_next=RD_obs_next[0],
                          info=_info[0],
                          policy=_RD_policy[0]))
                FD_buffer.add(
                    Batch(obs=_FD_obs[0],
                          act=_FD_act[0],
                          rew=_client_rew[2],
                          done=_done[0],
                          obs_next=FD_obs_next[0],
                          info=_info[0],
                          policy=_FD_policy[0]))
                t.update(1)
                _server_obs = server_obs_next
                _ND_obs = ND_obs_next
                _RD_obs = RD_obs_next
                _FD_obs = FD_obs_next
        all_server_costs.append(np.array(server_costs).mean())
        all_ND_utility.append(np.array(ND_utility).mean())
        all_RD_utility.append(np.array(RD_utility).mean())
        all_FD_utility.append(np.array(FD_utility).mean())
        all_leak_probability.append(np.array(leak_probability).mean())
        print("current bandwidth:", env.bandwidth)
        print("leak signal:", env.leak_NU, env.leak_FU)
        print("current server cost:", np.array(server_costs).mean())
        print("current device utility:", all_ND_utility[-1],
              all_RD_utility[-1], all_FD_utility[-1])
        print("leak probability:", all_leak_probability[-1])
        print("server_act:", _server_act[0])
        print("device_acts:", _ND_act[0], _RD_act[0], _FD_act[0])
        print("payment cost:", np.array(payment).mean())
        print("Expected time cost:", np.array(expected_time).mean())
        print("Training time cost:", np.array(training_time).mean())
        # print("server_act:",_server_act)
        # print("client_act:",_client_act)
        print("info:", env.communication_time, env.computation_time,
              env.K_theta)
        server_batch_data, server_indice = server_buffer.sample(0)
        server_batch_data = server_policy.process_fn(server_batch_data,
                                                     server_buffer,
                                                     server_indice)
        server_policy.learn(server_batch_data, args.batch_size,
                            args.repeat_per_collect)
        server_buffer.reset()

        ND_batch_data, ND_indice = ND_buffer.sample(0)
        ND_batch_data = ND_policy.process_fn(ND_batch_data, ND_buffer,
                                             ND_indice)
        ND_policy.learn(ND_batch_data, args.batch_size,
                        args.repeat_per_collect)
        ND_buffer.reset()

        RD_batch_data, RD_indice = RD_buffer.sample(0)
        RD_batch_data = RD_policy.process_fn(RD_batch_data, RD_buffer,
                                             RD_indice)
        RD_policy.learn(RD_batch_data, args.batch_size,
                        args.repeat_per_collect)
        RD_buffer.reset()

        FD_batch_data, FD_indice = FD_buffer.sample(0)
        FD_batch_data = FD_policy.process_fn(FD_batch_data, FD_buffer,
                                             FD_indice)
        FD_policy.learn(FD_batch_data, args.batch_size,
                        args.repeat_per_collect)
        FD_buffer.reset()
    print("all_server_cost:", all_server_costs)
    print("all_ND_utility:", all_ND_utility)
    print("all_RD_utility:", all_RD_utility)
    print("all_FD_utility:", all_FD_utility)
    print("all_leak_probability:", all_leak_probability)
    plt.plot(all_server_costs)
    plt.show()
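The update step above repeats the same sample -> process_fn -> learn -> reset cycle for the server and each device policy; as a sketch (the helper name below is illustrative, not part of the project), the pattern could be factored out:

    def update_agent(policy, buffer, batch_size, repeat):
        # Draw the whole on-policy buffer (batch_size=0), let the policy compute
        # returns/advantages in process_fn, run the update, then clear the
        # buffer for the next epoch.
        batch_data, indices = buffer.sample(0)
        batch_data = policy.process_fn(batch_data, buffer, indices)
        policy.learn(batch_data, batch_size, repeat)
        buffer.reset()

    # e.g. update_agent(server_policy, server_buffer,
    #                   args.batch_size, args.repeat_per_collect)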
Example #4
def test_dqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n

    # should be N_FRAMES x H x W

    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # log
    log_path = os.path.join(args.logdir, args.task, 'embedding')
    embedding_writer = SummaryWriter(log_path + '/with_init')

    embedding_net = embedding_prediction.Prediction(
        *args.state_shape, args.action_shape,
        args.device).to(device=args.device)
    embedding_net.apply(embedding_prediction.weights_init)

    if args.embedding_path:
        embedding_net.load_state_dict(torch.load(log_path + '/embedding.pth'))
        print("Loaded agent from: ", log_path + '/embedding.pth')
    # numel_list = [p.numel() for p in embedding_net.parameters()]
    # print(sum(numel_list), numel_list)

    pre_buffer = ReplayBuffer(args.buffer_size,
                              save_only_last_obs=True,
                              stack_num=args.frames_stack)
    pre_test_buffer = ReplayBuffer(args.buffer_size // 100,
                                   save_only_last_obs=True,
                                   stack_num=args.frames_stack)

    train_collector = Collector(None, train_envs, pre_buffer)
    test_collector = Collector(None, test_envs, pre_test_buffer)
    if args.embedding_data_path:
        pre_buffer = pickle.load(open(log_path + '/train_data.pkl', 'rb'))
        pre_test_buffer = pickle.load(open(log_path + '/test_data.pkl', 'rb'))
        train_collector.buffer = pre_buffer
        test_collector.buffer = pre_test_buffer
        print('load success')
    else:
        print('collect start')
        train_collector = Collector(None, train_envs, pre_buffer)
        test_collector = Collector(None, test_envs, pre_test_buffer)
        train_collector.collect(n_step=args.buffer_size, random=True)
        test_collector.collect(n_step=args.buffer_size // 100, random=True)
        print(len(train_collector.buffer))
        print(len(test_collector.buffer))
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        pickle.dump(pre_buffer, open(log_path + '/train_data.pkl', 'wb'))
        pickle.dump(pre_test_buffer, open(log_path + '/test_data.pkl', 'wb'))
        print('collect finish')

    # train the embedding network on the collected data
    def part_loss(x, device='cpu'):
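        # Pushes each element of x toward the nearest of {0, 1}: per element it
        # keeps min(x^2, (1 - x)^2), so x = 0.5 contributes 0.25 while x = 0.0
        # or x = 1.0 contributes nothing (a binarization penalty on the codes).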
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=device, dtype=torch.float32)
        x = x.view(128, -1)
        temp = torch.cat(
            ((1 - x).pow(2.0).unsqueeze_(0), x.pow(2.0).unsqueeze_(0)), dim=0)
        temp_2 = torch.min(temp, dim=0)[0]
        return torch.sum(temp_2)

    pre_optim = torch.optim.Adam(embedding_net.parameters(), lr=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(pre_optim,
                                                step_size=50000,
                                                gamma=0.1,
                                                last_epoch=-1)
    test_batch_data = test_collector.sample(batch_size=0)

    loss_fn = torch.nn.NLLLoss()
    # train_loss = []
    for epoch in range(1, 100001):
        embedding_net.train()
        batch_data = train_collector.sample(batch_size=128)
        # print(batch_data)
        # print(batch_data['obs'][0] == batch_data['obs'][1])
        pred = embedding_net(batch_data['obs'], batch_data['obs_next'])
        x1 = pred[1]
        x2 = pred[2]
        # print(torch.argmax(pred[0], dim=1))
        if not isinstance(batch_data['act'], torch.Tensor):
            act = torch.tensor(batch_data['act'],
                               device=args.device,
                               dtype=torch.int64)
        else:
            act = batch_data['act'].to(device=args.device, dtype=torch.int64)
        # print(pred[0].dtype)
        # print(act.dtype)
        # l2_norm = sum(p.pow(2.0).sum() for p in embedding_net.net.parameters())
        # loss = loss_fn(pred[0], act) + 0.001 * (part_loss(x1) + part_loss(x2)) / 64
        loss_1 = loss_fn(pred[0], act)
        loss_2 = 0.01 * (part_loss(x1, args.device) +
                         part_loss(x2, args.device)) / 128
        loss = loss_1 + loss_2
        # print(loss_1)
        # print(loss_2)
        embedding_writer.add_scalar('training loss1', loss_1.item(), epoch)
        embedding_writer.add_scalar('training loss2', loss_2, epoch)
        embedding_writer.add_scalar('training loss', loss.item(), epoch)
        # train_loss.append(loss.detach().item())
        pre_optim.zero_grad()
        loss.backward()
        pre_optim.step()
        scheduler.step()

        if epoch % 10000 == 0 or epoch == 1:
            print(pre_optim.state_dict()['param_groups'][0]['lr'])
            # print("Epoch: %d,Train: Loss: %f" % (epoch, float(loss.item())))
            correct = 0
            numel_list = [p for p in embedding_net.parameters()][-2]
            print(numel_list)
            embedding_net.eval()
            with torch.no_grad():
                test_pred, x1, x2, _ = embedding_net(
                    test_batch_data['obs'], test_batch_data['obs_next'])
                if not isinstance(test_batch_data['act'], torch.Tensor):
                    act = torch.tensor(test_batch_data['act'],
                                       device=args.device,
                                       dtype=torch.int64)
                else:
                    act = test_batch_data['act'].to(device=args.device,
                                                    dtype=torch.int64)
                loss_1 = loss_fn(test_pred, act)
                loss_2 = 0.01 * (part_loss(x1, args.device) +
                                 part_loss(x2, args.device)) / 128
                loss = loss_1 + loss_2
                embedding_writer.add_scalar('test loss', loss.item(), epoch)
                # print("Test Loss: %f" % (float(loss)))
                print(torch.argmax(test_pred, dim=1))
                print(act)
                correct += int((torch.argmax(test_pred, dim=1) == act).sum())
                print('Acc:', correct / len(test_batch_data))

    torch.save(embedding_net.state_dict(),
               os.path.join(log_path, 'embedding.pth'))
    embedding_writer.close()
    # plt.figure()
    # plt.plot(np.arange(100000),train_loss)
    # plt.show()
    exit()
    # build the hash table

    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(device=args.device)

    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net,
                       optim,
                       args.gamma,
                       args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path))
        print("Loaded agent from: ", args.resume_path)

    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    pre_buffer.reset()
    buffer = ReplayBuffer(args.buffer_size,
                          ignore_obs_next=True,
                          save_only_last_obs=True,
                          stack_num=args.frames_stack)
    # collector
    # pass preprocess_fn to train_collector to reshape the reward
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(x):
        if env.spec.reward_threshold:
            return x >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return x >= 20

    def train_fn(x):
        # nature DQN setting, linear decay in the first 1M steps
        now = x * args.collect_per_step * args.step_per_epoch
        if now <= 1e6:
            eps = args.eps_train - now / 1e6 * \
                (args.eps_train - args.eps_train_final)
            policy.set_eps(eps)
        else:
            policy.set_eps(args.eps_train_final)
        print("set eps =", policy.eps)

    def test_fn(x):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=1 / 30)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(policy,
                               train_collector,
                               test_collector,
                               args.epoch,
                               args.step_per_epoch,
                               args.collect_per_step,
                               args.test_num,
                               args.batch_size,
                               train_fn=train_fn,
                               test_fn=test_fn,
                               stop_fn=stop_fn,
                               save_fn=save_fn,
                               writer=writer,
                               test_in_train=False)

    pprint.pprint(result)
    watch()
Example #5
class Collector(object):
    """The :class:`~tianshou.data.Collector` enables the policy to interact
    with different types of environments conveniently.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: an environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
        ``None``, it will automatically assign a small-size
        :class:`~tianshou.data.ReplayBuffer`.
    :param int stat_size: the window size of the moving average used to record
        the collecting speed, defaults to 100.

    Example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
        buffers = [ReplayBuffer(size=5000) for _ in range(3)]
        # you can also pass a list of replay buffers to the collector, one per env
        # collector = Collector(policy, envs, buffer=buffers)
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect at least 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        #   sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

        # sample data with a given number of batch-size:
        batch_data = collector.sample(batch_size=64)
        # policy.learn(batch_data)  # btw, vanilla policy gradient only
        #   supports on-policy training, so here we pick all data in the buffer
        batch_data = collector.sample(batch_size=0)
        policy.learn(batch_data)
        # on-policy algorithms use the collected data only once, so here we
        #   clear the buffer
        collector.reset_buffer()

    When collecting data from multiple environments into a single buffer, the
    cached buffers are enabled automatically, and the collector may return more
    data than the requested limit.

    .. note::

        Please make sure the given environment has a time limitation.
    """

    def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs):
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0
        if buffer is None:
            self.buffer = ReplayBuffer(100)
        else:
            self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # True if buf is a list
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num, \
                    'The number of data buffer does not match the number of ' \
                    'input env.'
                self._multi_buf = True
            elif isinstance(self.buffer, ReplayBuffer):
                self._cached_buf = [
                    ListReplayBuffer() for _ in range(self.env_num)]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.step_speed = MovAvg(stat_size)
        self.episode_speed = MovAvg(stat_size)

    def reset_buffer(self):
        """Reset the main data buffer."""
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
        else:
            self.buffer.reset()

    def get_env_num(self):
        """Return the number of environments the collector has."""
        return self.env_num

    def reset_env(self):
        """Reset all of the environment(s)' states and reset all of the cache
        buffers (if need).
        """
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
        else:
            self.reward, self.length = 0, 0
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed=None):
        """Reset all the seed(s) of the given environment(s)."""
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        """Render all the environment(s)."""
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        """Close the environment(s)."""
        if hasattr(self.env, 'close'):
            self.env.close()

    def _make_batch(self, data):
        """Return [data]."""
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def _reset_state(self, id):
        """Reset self.state[id]."""
        if self.state is None:
            return
        if isinstance(self.state, list):
            self.state[id] = None
        elif isinstance(self.state, dict):
            for k in self.state:
                if isinstance(self.state[k], list):
                    self.state[k][id] = None
                elif isinstance(self.state[k], torch.Tensor) or \
                        isinstance(self.state[k], np.ndarray):
                    self.state[k][id] = 0
        elif isinstance(self.state, torch.Tensor) or \
                isinstance(self.state, np.ndarray):
            self.state[id] = 0

    def collect(self, n_step=0, n_episode=0, render=None):
        """Collect a specified number of step or episode.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect (in each
            environment).
        :type n_episode: int or list
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        warning_count = 0
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            if warning_count >= 100000:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)
            if self._multi_env:
                batch_data = Batch(
                    obs=self._obs, act=self._act, rew=self._rew,
                    done=self._done, obs_next=None, info=self._info)
            else:
                batch_data = Batch(
                    obs=self._make_batch(self._obs),
                    act=self._make_batch(self._act),
                    rew=self._make_batch(self._rew),
                    done=self._make_batch(self._done),
                    obs_next=None,
                    info=self._make_batch(self._info))
            with torch.no_grad():
                result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            elif not isinstance(self._act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if render is not None:
                self.env.render()
                if render > 0:
                    time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i], 'act': self._act[i],
                        'rew': self._rew[i], 'done': self._done[i],
                        'obs_next': obs_next[i], 'info': self._info[i]}
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        self._reset_state(i)
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                self.buffer.add(
                    self._obs, self._act[0], self._rew,
                    self._done, obs_next, self._info)
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self.env.reset()
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = max(time.time() - start_time, 1e-9)
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
        the final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer; otherwise, it will sample a batch with the given
            batch_size.
        """
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                batch_index = np.random.choice(
                    total, batch_size, p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
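A minimal usage sketch of the statistics dictionary documented above (assuming a collector built as in the class docstring example):

    stats = collector.collect(n_episode=3)
    print(f"{stats['n/ep']} episodes in {stats['n/st']} steps, "
          f"mean reward {stats['rew']:.2f}, mean length {stats['len']:.1f}")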
Example #6
class Collector(object):
    """docstring for Collector"""
    def __init__(self, policy, env, buffer=None, stat_size=100):
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0
        if buffer is None:
            self.buffer = ReplayBuffer(100)
        else:
            self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # True if buf is a list
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num, \
                    'The number of data buffer does not match the number of ' \
                    'input env.'
                self._multi_buf = True
            elif isinstance(self.buffer, ReplayBuffer):
                self._cached_buf = [
                    ListReplayBuffer() for _ in range(self.env_num)
                ]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.step_speed = MovAvg(stat_size)
        self.episode_speed = MovAvg(stat_size)

    def reset_buffer(self):
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
        else:
            self.buffer.reset()

    def get_env_num(self):
        return self.env_num

    def reset_env(self):
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
        else:
            self.reward, self.length = 0, 0
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed=None):
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        if hasattr(self.env, 'close'):
            self.env.close()

    def _make_batch(self, data):
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def collect(self, n_step=0, n_episode=0, render=0):
        warning_count = 0
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            if warning_count >= 100000:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)
            if self._multi_env:
                batch_data = Batch(obs=self._obs,
                                   act=self._act,
                                   rew=self._rew,
                                   done=self._done,
                                   obs_next=None,
                                   info=self._info)
            else:
                batch_data = Batch(obs=self._make_batch(self._obs),
                                   act=self._make_batch(self._act),
                                   rew=self._make_batch(self._rew),
                                   done=self._make_batch(self._done),
                                   obs_next=None,
                                   info=self._make_batch(self._info))
            result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            elif not isinstance(self._act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if render > 0:
                self.env.render()
                time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i],
                        'act': self._act[i],
                        'rew': self._rew[i],
                        'done': self._done[i],
                        'obs_next': obs_next[i],
                        'info': self._info[i]
                    }
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        if isinstance(self.state, list):
                            self.state[i] = None
                        elif self.state is not None:
                            if isinstance(self.state[i], dict):
                                self.state[i] = {}
                            else:
                                self.state[i] = self.state[i] * 0
                            if isinstance(self.state, torch.Tensor):
                                # remove ref count in pytorch (?)
                                self.state = self.state.detach()
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                self.buffer.add(self._obs, self._act[0], self._rew, self._done,
                                obs_next, self._info)
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self.env.reset()
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = time.time() - start_time
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                batch_index = np.random.choice(total,
                                               batch_size,
                                               p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data