Code example #1
0
def train_seq2reward_and_compute_reward_mse(
    env_name: str,
    model: ModelManager__Union,
    num_train_transitions: int,
    num_test_transitions: int,
    seq_len: int,
    batch_size: int,
    num_train_epochs: int,
    use_gpu: bool,
    saved_seq2reward_path: Optional[str] = None,
):
    """Train (or load) a Seq2Reward network and evaluate it on a held-out batch.

    When ``saved_seq2reward_path`` is None the network is trained from scratch
    via ``train_seq2reward``; otherwise the saved weights are loaded and only
    evaluated. Returns the evaluation losses as a list of Python floats.
    """
    env = Gym(env_name=env_name)
    env.seed(SEED)

    trainer = model.value.initialize_trainer(
        use_gpu=use_gpu,
        reward_options=RewardOptions(),
        normalization_data_map=build_normalizer(env),
    )

    # pyre-fixme[6]: Expected `device` for 2nd param but got `str`.
    preprocessor = make_replay_buffer_trainer_preprocessor(
        trainer, "cuda" if use_gpu else "cpu", env)

    # Held-out buffer used for the final loss computation.
    eval_buffer = ReplayBuffer(
        replay_capacity=num_test_transitions,
        batch_size=batch_size,
        stack_size=seq_len,
        return_everything_as_stack=True,
    )
    fill_replay_buffer(env, eval_buffer, num_test_transitions)

    if saved_seq2reward_path is not None:
        # Load a pretrained model and only evaluate it.
        trainer.seq2reward_network.load_state_dict(
            torch.load(saved_seq2reward_path))
    else:
        # Train from scratch.
        trainer = train_seq2reward(
            env=env,
            trainer=trainer,
            trainer_preprocessor=preprocessor,
            num_train_transitions=num_train_transitions,
            seq_len=seq_len,
            batch_size=batch_size,
            num_train_epochs=num_train_epochs,
            test_replay_buffer=eval_buffer,
        )

    state_dim = env.observation_space.shape[0]
    with torch.no_grad():
        trainer.seq2reward_network.eval()
        raw_batch = eval_buffer.sample_transition_batch(
            batch_size=eval_buffer.size)
        eval_batch = preprocessor(raw_batch)
        adhoc_padding(eval_batch, state_dim=state_dim)
        detached_losses = [
            loss.cpu().detach().item() for loss in trainer.get_loss(eval_batch)
        ]
        trainer.seq2reward_network.train()
    return detached_losses
Code example #2
0
    def test_sparse_input(self):
        """ReplayBuffer round-trips sparse id_list / id_score_list features.

        Fills half the buffer with synthetic transitions, samples them back by
        explicit indices, and checks that the buffer's flattened
        (offsets, ids[, scores]) representation matches one built by hand.
        """
        replay_capacity = 100
        num_transitions = replay_capacity // 2
        memory = ReplayBuffer(stack_size=1,
                              replay_capacity=replay_capacity,
                              update_horizon=1)

        def make_transition(step):
            # id_list features: plain lists of ids.
            ids_a = list(range(0, step % 4))
            ids_b = list(range(step % 4, 4))
            # id_score_list features: parallel (ids, scores) tuples.
            scored_a = (list(range(0, step % 7)),
                        [j + 0.5 for j in range(0, step % 7)])
            scored_b = (list(range(step % 7, 7)),
                        [j + 0.5 for j in range(step % 7, 7)])
            return {
                "observation": np.ones(OBS_SHAPE, dtype=OBS_TYPE),
                "action": int(2 * step),
                "reward": float(3 * step),
                "terminal": step % 4,
                "id_list": {
                    "sparse_feat1": ids_a,
                    "sparse_feat2": ids_b
                },
                "id_score_list": {
                    "sparse_feat3": scored_a,
                    "sparse_feat4": scored_b
                },
            }

        for step in range(num_transitions):
            memory.add(**make_transition(step))

        indices = list(range(num_transitions - 1))
        batch = memory.sample_transition_batch(len(indices),
                                               torch.tensor(indices))

        # Build the expected flattened representation by hand: each feature
        # accumulates (offsets, ids) for id_list and (offsets, ids, scores)
        # for id_score_list.
        expected = {
            "id_list": {
                "sparse_feat1": ([], []),
                "sparse_feat2": ([], [])
            },
            "id_score_list": {
                "sparse_feat3": ([], [], []),
                "sparse_feat4": ([], [], []),
            },
            "next_id_list": {
                "sparse_feat1": ([], []),
                "sparse_feat2": ([], [])
            },
            "next_id_score_list": {
                "sparse_feat3": ([], [], []),
                "sparse_feat4": ([], [], []),
            },
        }
        for step in range(num_transitions - 1):
            current = make_transition(step)
            following = make_transition(step + 1)
            # "next_*" keys read from the following transition; the others
            # read from the current one. Each accumulator is independent, so
            # only the per-key order matters (and it matches the original).
            for key, feats in (("id_list", current),
                               ("id_score_list", current),
                               ("next_id_list", following),
                               ("next_id_score_list", following)):
                source_key = key[len("next_"):] if key.startswith(
                    "next_") else key
                for feat_id, acc in expected[key].items():
                    acc[0].append(len(acc[1]))  # row offset into the id list
                    raw = feats[source_key][feat_id]
                    if len(acc) == 2:
                        # id_list style: ids only.
                        acc[1].extend(raw)
                    else:
                        # id_score_list style: parallel ids and scores.
                        acc[1].extend(raw[0])
                        acc[2].extend(raw[1])

        # Compare every component (offsets, ids, and scores when present)
        # against what the buffer returned.
        for key in [
                "id_list", "id_score_list", "next_id_list",
                "next_id_score_list"
        ]:
            for feat_id, acc in expected[key].items():
                actual = getattr(batch, key)[feat_id]
                for component in range(len(acc)):
                    npt.assert_array_equal(acc[component], actual[component])

        # sample random
        _ = memory.sample_transition_batch(10)
Code example #3
0
File: test_world_model.py  Project: zachkeer/ReAgent
def train_mdnrnn_and_compute_feature_stats(
    env_name: str,
    model: ModelManager__Union,
    num_train_transitions: int,
    num_test_transitions: int,
    seq_len: int,
    batch_size: int,
    num_train_epochs: int,
    use_gpu: bool,
    saved_mdnrnn_path: Optional[str] = None,
):
    """Train (or load) an MDN-RNN memory network and compute feature stats.

    When ``saved_mdnrnn_path`` is None the network is trained from scratch via
    ``train_mdnrnn``; otherwise the saved weights are loaded and only
    evaluated. Returns ``(feature_importance, feature_sensitivity)`` computed
    on a held-out batch.
    """
    env: gym.Env = Gym(env_name=env_name)
    env.seed(SEED)

    trainer = model.value.initialize_trainer(
        use_gpu=use_gpu,
        reward_options=RewardOptions(),
        normalization_data_map=build_normalizer(env),
    )

    # pyre-fixme[6]: Expected `device` for 2nd param but got `str`.
    preprocessor = make_replay_buffer_trainer_preprocessor(
        trainer, "cuda" if use_gpu else "cpu", env)

    # Held-out buffer used for the feature statistics below.
    eval_buffer = ReplayBuffer(
        replay_capacity=num_test_transitions,
        batch_size=batch_size,
        stack_size=seq_len,
        return_everything_as_stack=True,
    )
    fill_replay_buffer(env, eval_buffer, num_test_transitions)

    if saved_mdnrnn_path is not None:
        # Load a pretrained model and only evaluate it.
        trainer.memory_network.mdnrnn.load_state_dict(
            torch.load(saved_mdnrnn_path))
    else:
        # Train from scratch.
        trainer = train_mdnrnn(
            env=env,
            trainer=trainer,
            trainer_preprocessor=preprocessor,
            num_train_transitions=num_train_transitions,
            seq_len=seq_len,
            batch_size=batch_size,
            num_train_epochs=num_train_epochs,
            test_replay_buffer=eval_buffer,
        )

    with torch.no_grad():
        trainer.memory_network.mdnrnn.eval()
        raw_batch = eval_buffer.sample_transition_batch(
            batch_size=eval_buffer.size)
        eval_batch = preprocessor(raw_batch)

        feature_importance = calculate_feature_importance(
            env=env,
            trainer=trainer,
            use_gpu=use_gpu,
            test_batch=eval_batch,
        )
        feature_sensitivity = calculate_feature_sensitivity(
            env=env,
            trainer=trainer,
            use_gpu=use_gpu,
            test_batch=eval_batch,
        )

        trainer.memory_network.mdnrnn.train()
    return feature_importance, feature_sensitivity