Example #1
def train_seq2reward_and_compute_reward_mse(
    env_name: str,
    model: ModelManager__Union,
    num_train_transitions: int,
    num_test_transitions: int,
    seq_len: int,
    batch_size: int,
    num_train_epochs: int,
    use_gpu: bool,
    saved_seq2reward_path: Optional[str] = None,
):
    """ Train Seq2Reward Network and compute reward mse. """
    env = Gym(env_name=env_name)
    env.seed(SEED)

    manager = model.value
    trainer = manager.initialize_trainer(
        use_gpu=use_gpu,
        reward_options=RewardOptions(),
        normalization_data_map=build_normalizer(env),
    )

    device = "cuda" if use_gpu else "cpu"
    # pyre-fixme[6]: Expected `device` for 2nd param but got `str`.
    trainer_preprocessor = make_replay_buffer_trainer_preprocessor(trainer, device, env)
    test_replay_buffer = ReplayBuffer(
        replay_capacity=num_test_transitions,
        batch_size=batch_size,
        stack_size=seq_len,
        return_everything_as_stack=True,
    )
    fill_replay_buffer(env, test_replay_buffer, num_test_transitions)

    if saved_seq2reward_path is None:
        # train from scratch
        trainer = train_seq2reward(
            env=env,
            trainer=trainer,
            trainer_preprocessor=trainer_preprocessor,
            num_train_transitions=num_train_transitions,
            seq_len=seq_len,
            batch_size=batch_size,
            num_train_epochs=num_train_epochs,
            test_replay_buffer=test_replay_buffer,
        )
    else:
        # load a pretrained model, and just evaluate it
        trainer.seq2reward_network.load_state_dict(torch.load(saved_seq2reward_path))
    state_dim = env.observation_space.shape[0]
    with torch.no_grad():
        trainer.seq2reward_network.eval()
        test_batch = test_replay_buffer.sample_transition_batch(
            batch_size=test_replay_buffer.size
        )
        preprocessed_test_batch = trainer_preprocessor(test_batch)
        adhoc_action_padding(preprocessed_test_batch, state_dim=state_dim)
        losses = trainer.get_loss(preprocessed_test_batch)
        detached_losses = losses.cpu().detach().item()
        trainer.seq2reward_network.train()
    return detached_losses
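The evaluation block at the end follows the standard PyTorch pattern: switch the network to eval mode, compute the loss under torch.no_grad(), then restore train mode. A minimal, self-contained sketch of that pattern with a toy network (the model and data below are hypothetical, not ReAgent code):

import torch
import torch.nn as nn

# Hypothetical toy reward model standing in for trainer.seq2reward_network.
reward_model = nn.Linear(4, 1)
states = torch.randn(32, 4)        # fake batch of states
true_rewards = torch.randn(32, 1)  # fake reward targets

with torch.no_grad():
    reward_model.eval()            # freeze dropout/batch-norm behavior
    predicted = reward_model(states)
    mse = nn.functional.mse_loss(predicted, true_rewards).item()
    reward_model.train()           # restore training mode
print(f"reward MSE: {mse:.4f}")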
Example #2
def run_test_offline(
    env_name: str,
    model: ModelManager__Union,
    replay_memory_size: int,
    num_batches_per_epoch: int,
    num_train_epochs: int,
    passing_score_bar: float,
    num_eval_episodes: int,
    minibatch_size: int,
    use_gpu: bool,
):
    env = Gym(env_name=env_name)
    env.seed(SEED)
    env.action_space.seed(SEED)
    normalization = build_normalizer(env)
    logger.info(f"Normalization is: \n{pprint.pformat(normalization)}")

    manager = model.value
    trainer = manager.initialize_trainer(
        use_gpu=use_gpu,
        reward_options=RewardOptions(),
        normalization_data_map=normalization,
    )

    # first fill the replay buffer to burn_in
    replay_buffer = ReplayBuffer(replay_capacity=replay_memory_size,
                                 batch_size=minibatch_size)
    # always fill full RB
    random_policy = make_random_policy_for_env(env)
    agent = Agent.create_for_env(env, policy=random_policy)
    fill_replay_buffer(
        env=env,
        replay_buffer=replay_buffer,
        desired_size=replay_memory_size,
        agent=agent,
    )

    device = torch.device("cuda") if use_gpu else None
    # pyre-fixme[6]: Expected `device` for 2nd param but got `Optional[torch.device]`.
    trainer_preprocessor = make_replay_buffer_trainer_preprocessor(
        trainer, device, env)

    writer = SummaryWriter()
    with summary_writer_context(writer):
        for epoch in range(num_train_epochs):
            logger.info(f"Evaluating before epoch {epoch}: ")
            eval_rewards = evaluate_cem(env, manager, 1)
            for _ in tqdm(range(num_batches_per_epoch)):
                train_batch = replay_buffer.sample_transition_batch()
                preprocessed_batch = trainer_preprocessor(train_batch)
                trainer.train(preprocessed_batch)

    logger.info(f"Evaluating after training for {num_train_epochs} epochs: ")
    eval_rewards = evaluate_cem(env, manager, num_eval_episodes)
    mean_rewards = np.mean(eval_rewards)
    assert (mean_rewards >= passing_score_bar
            ), f"{mean_rewards} doesn't pass the bar {passing_score_bar}."
Example #3
    def test_box(self):
        env = Gym(env_name="CartPole-v0")
        obs_preprocessor = env.get_obs_preprocessor()
        obs = env.reset()
        state = obs_preprocessor(obs)
        self.assertTrue(state.has_float_features_only)
        self.assertEqual(state.float_features.shape, (1, obs.shape[0]))
        self.assertEqual(state.float_features.dtype, torch.float32)
        self.assertEqual(state.float_features.device, torch.device("cpu"))
        npt.assert_array_almost_equal(obs, state.float_features.squeeze(0))
Example #4
    def test_box_cuda(self):
        env = Gym(env_name="CartPole-v0")
        device = torch.device("cuda")
        obs_preprocessor = env.get_obs_preprocessor(device=device)
        obs = env.reset()
        state = obs_preprocessor(obs)
        self.assertTrue(state.has_float_features_only)
        self.assertEqual(state.float_features.shape, (1, obs.shape[0]))
        self.assertEqual(state.float_features.dtype, torch.float32)
        # `device` doesn't have index. So we need this.
        x = torch.zeros(1, device=device)
        self.assertEqual(state.float_features.device, x.device)
        npt.assert_array_almost_equal(obs, state.float_features.cpu().squeeze(0))
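Both tests assert the same contract: the observation preprocessor wraps a raw NumPy observation into a batched float32 tensor on the requested device, exposed as state.float_features. A rough, hypothetical equivalent of the tensor the tests inspect (not ReAgent's actual preprocessor):

import numpy as np
import torch

def naive_obs_preprocessor(obs: np.ndarray, device=None) -> torch.Tensor:
    # Add a leading batch dimension and cast to float32, optionally moving to a device.
    return torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)

obs = np.array([0.01, -0.02, 0.03, 0.04])  # fake CartPole observation
state = naive_obs_preprocessor(obs)
assert state.shape == (1, obs.shape[0])
assert state.dtype == torch.float32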
Example #5
    def test_create_df_from_replay_buffer(self):
        env_name = "MiniGrid-Empty-5x5-v0"
        env = Gym(env_name=env_name)
        state_dim = env.observation_space.shape[0]
        # Wrap env in TestEnv
        env = TestEnv(env)
        problem_domain = ProblemDomain.DISCRETE_ACTION
        DATASET_SIZE = 1000
        multi_steps = None
        DS = "2021-09-16"

        # Generate data
        df = create_df_from_replay_buffer(
            env=env,
            problem_domain=problem_domain,
            desired_size=DATASET_SIZE,
            multi_steps=multi_steps,
            ds=DS,
            shuffle_df=False,
        )
        self.assertEqual(len(df), DATASET_SIZE)

        # Check data
        preprocessor = PythonSparseToDenseProcessor(list(range(state_dim)))
        for idx, row in df.iterrows():
            df_mdp_id = row["mdp_id"]
            env_mdp_id = str(env.sart[idx][0])
            self.assertEqual(df_mdp_id, env_mdp_id)

            df_seq_num = row["sequence_number"]
            env_seq_num = env.sart[idx][1]
            self.assertEqual(df_seq_num, env_seq_num)

            df_state = preprocessor.process([row["state_features"]
                                             ])[0][0].numpy()
            env_state = env.sart[idx][2]
            npt.assert_array_equal(df_state, env_state)

            df_action = row["action"]
            env_action = str(env.sart[idx][3])
            self.assertEqual(df_action, env_action)

            df_terminal = row["next_action"] == ""
            env_terminal = env.sart[idx][5]
            self.assertEqual(df_terminal, env_terminal)
            if not df_terminal:
                df_reward = float(row["reward"])
                env_reward = float(env.sart[idx][4])
                npt.assert_allclose(df_reward, env_reward)

                df_next_state = preprocessor.process(
                    [row["next_state_features"]])[0][0].numpy()
                env_next_state = env.sart[idx + 1][2]
                npt.assert_array_equal(df_next_state, env_next_state)

                df_next_action = row["next_action"]
                env_next_action = str(env.sart[idx + 1][3])
                self.assertEqual(df_next_action, env_next_action)
            else:
                del env.sart[idx + 1]
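The test walks a flat DataFrame with one row per transition. A hypothetical one-row illustration of the schema it reads (column names come from the test above; the values are made up):

import pandas as pd

# One made-up transition row with the columns the test reads.
row = {
    "mdp_id": "0",                       # episode id, stored as a string
    "sequence_number": 0,                # step index within the episode
    "state_features": {0: 1.0, 1: 0.0},  # sparse feature-id -> value map
    "action": "2",
    "reward": 0.0,
    "next_state_features": {0: 0.0, 1: 1.0},
    "next_action": "1",                  # empty string marks a terminal step
    "ds": "2021-09-16",
}
df = pd.DataFrame([row])
print(df.iloc[0]["mdp_id"], df.iloc[0]["next_action"] == "")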
Example #6
def offline_gym(
    env_name: str,
    pkl_path: str,
    num_train_transitions: int,
    max_steps: Optional[int],
    seed: Optional[int] = None,
):
    """
    Generate samples from a DiscreteRandomPolicy on the Gym environment and
    saves results in a pandas df parquet.
    """
    initialize_seed(seed)
    env = Gym(env_name=env_name)

    replay_buffer = ReplayBuffer(replay_capacity=num_train_transitions,
                                 batch_size=1)
    fill_replay_buffer(env, replay_buffer, num_train_transitions)
    if isinstance(env.action_space, gym.spaces.Discrete):
        is_discrete_action = True
    else:
        assert isinstance(env.action_space, gym.spaces.Box)
        is_discrete_action = False
    df = replay_buffer_to_pre_timeline_df(is_discrete_action, replay_buffer)
    logger.info(f"Saving dataset with {len(df)} samples to {pkl_path}")
    df.to_pickle(pkl_path)
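The output is a plain pandas pickle, so downstream tooling can reload it directly; a short usage sketch (the path is a made-up example):

import pandas as pd

# Reload the pre-timeline dataset written by offline_gym.
df = pd.read_pickle("cartpole_offline.pkl")  # hypothetical path
print(len(df), list(df.columns))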
Example #7
    def test_random_vs_lqr(self):
        """
        Test random actions vs. a LQR controller. LQR controller should perform
        much better than random actions in the linear dynamics environment.
        """
        env = Gym(env_name="LinearDynamics-v0")
        num_test_episodes = 500

        def random_policy(env, state):
            return np.random.uniform(
                env.action_space.low, env.action_space.high, env.action_dim
            )

        def lqr_policy(env, state):
            # Four matrices that characterize the environment
            A, B, Q, R = env.A, env.B, env.Q, env.R
            # Solve discrete algebraic Riccati equation:
            M = linalg.solve_discrete_are(A, B, Q, R)
            K = np.dot(
                linalg.inv(np.dot(np.dot(B.T, M), B) + R), (np.dot(np.dot(B.T, M), A))
            )
            state = state.reshape((-1, 1))
            action = -K.dot(state).squeeze()
            return action

        mean_acc_rws_random = self.run_n_episodes(env, num_test_episodes, random_policy)
        mean_acc_rws_lqr = self.run_n_episodes(env, num_test_episodes, lqr_policy)
        logger.info(f"Mean acc. reward of random policy: {mean_acc_rws_random}")
        logger.info(f"Mean acc. reward of LQR policy: {mean_acc_rws_lqr}")
        assert mean_acc_rws_lqr > mean_acc_rws_random
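lqr_policy computes the standard infinite-horizon discrete-time LQR gain K = (B^T M B + R)^-1 B^T M A, where M solves the discrete algebraic Riccati equation. A self-contained sketch on a made-up double-integrator system (the matrices are illustrative, not those of LinearDynamics-v0):

import numpy as np
from scipy import linalg

# Made-up double-integrator dynamics: x' = A x + B u.
dt = 0.1
A = np.array([[1.0, dt], [0.0, 1.0]])
B = np.array([[0.0], [dt]])
Q = np.eye(2)          # state cost
R = np.array([[1.0]])  # action cost

M = linalg.solve_discrete_are(A, B, Q, R)        # Riccati solution
K = linalg.inv(B.T @ M @ B + R) @ (B.T @ M @ A)  # feedback gain

state = np.array([[1.0], [0.0]])
action = -K @ state    # optimal control: u = -K x
print(K, action.squeeze())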
Example #8
def evaluate_gym(
    env_name: str,
    model: ModelManager__Union,
    publisher: ModelPublisher__Union,
    num_eval_episodes: int,
    passing_score_bar: float,
    max_steps: Optional[int] = None,
):
    publisher_manager = publisher.value
    assert isinstance(
        publisher_manager, FileSystemPublisher
    ), f"publishing manager is type {type(publisher_manager)}, not FileSystemPublisher"
    env = Gym(env_name=env_name)
    torchscript_path = publisher_manager.get_latest_published_model(
        model.value)
    jit_model = torch.jit.load(torchscript_path)
    policy = create_predictor_policy_from_model(jit_model)
    agent = Agent.create_for_env_with_serving_policy(env, policy)
    rewards = evaluate_for_n_episodes(n=num_eval_episodes,
                                      env=env,
                                      agent=agent,
                                      max_steps=max_steps)
    avg_reward = np.mean(rewards)
    logger.info(f"Average reward over {num_eval_episodes} is {avg_reward}.\n"
                f"List of rewards: {rewards}")
    assert (avg_reward >= passing_score_bar
            ), f"{avg_reward} fails to pass the bar of {passing_score_bar}!"
    return
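Serving-side evaluation hinges on TorchScript: the publisher hands back a path to a scripted model and torch.jit.load restores it without the original Python class. A minimal round-trip sketch with a toy module (ToyNet and the path are hypothetical):

import torch
import torch.nn as nn

class ToyNet(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0

scripted = torch.jit.script(ToyNet())
scripted.save("/tmp/toy_net.pt")             # what a publisher would write out
jit_model = torch.jit.load("/tmp/toy_net.pt")
print(jit_model(torch.ones(3)))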
Example #9
def offline_gym_random(
    env_name: str,
    pkl_path: str,
    num_train_transitions: int,
    max_steps: Optional[int],
    seed: int = 1,
):
    """
    Generate samples from a random Policy on the Gym environment and
    saves results in a pandas df parquet.
    """
    env = Gym(env_name=env_name)
    random_policy = make_random_policy_for_env(env)
    agent = Agent.create_for_env(env, policy=random_policy)
    return _offline_gym(env, agent, pkl_path, num_train_transitions, max_steps,
                        seed)
Example #10
def offline_gym_predictor(
    env_name: str,
    model: ModelManager__Union,
    publisher: ModelPublisher__Union,
    pkl_path: str,
    num_train_transitions: int,
    max_steps: Optional[int],
    module_name: str = "default_model",
    seed: int = 1,
):
    """
    Generate samples from a trained Policy on the Gym environment and
    saves results in a pandas df parquet.
    """
    env = Gym(env_name=env_name)
    agent = make_agent_from_model(env, model, publisher, module_name)
    return _offline_gym(env, agent, pkl_path, num_train_transitions, max_steps,
                        seed)
Example #11
    def setUp(self):
        logging.getLogger().setLevel(logging.DEBUG)
        env = Gym("CartPole-v0")
        norm = build_normalizer(env)
        net_builder = FullyConnected(sizes=[8], activations=["linear"])
        cartpole_scorer = net_builder.build_q_network(
            state_feature_config=None,
            state_normalization_data=norm["state"],
            output_dim=len(norm["action"].dense_normalization_parameters),
        )
        policy = Policy(scorer=cartpole_scorer, sampler=SoftmaxActionSampler())
        agent = Agent.create_for_env(env, policy)
        self.max_steps = 3
        self.num_episodes = 6
        self.dataset = EpisodicDataset(
            env=env,
            agent=agent,
            num_episodes=self.num_episodes,
            seed=0,
            max_steps=self.max_steps,
        )
Example #12
def evaluate_gym(
    env_name: str,
    model: ModelManager__Union,
    publisher: ModelPublisher__Union,
    num_eval_episodes: int,
    passing_score_bar: float,
    module_name: str = "default_model",
    max_steps: Optional[int] = None,
):
    initialize_seed(1)
    env = Gym(env_name=env_name)
    agent = make_agent_from_model(env, model, publisher, module_name)

    rewards = evaluate_for_n_episodes(n=num_eval_episodes,
                                      env=env,
                                      agent=agent,
                                      max_steps=max_steps)
    avg_reward = np.mean(rewards)
    logger.info(f"Average reward over {num_eval_episodes} is {avg_reward}.\n"
                f"List of rewards: {rewards}\n"
                f"Passing score bar: {passing_score_bar}")
    assert (avg_reward >= passing_score_bar
            ), f"{avg_reward} fails to pass the bar of {passing_score_bar}!"
    return
Example #13
def run_test_offline(
    env_name: str,
    model: ModelManager__Union,
    replay_memory_size: int,
    num_batches_per_epoch: int,
    num_train_epochs: int,
    passing_score_bar: float,
    num_eval_episodes: int,
    minibatch_size: int,
    use_gpu: bool,
):
    env = Gym(env_name=env_name)
    env.seed(SEED)
    env.action_space.seed(SEED)
    normalization = build_normalizer(env)
    logger.info(f"Normalization is: \n{pprint.pformat(normalization)}")

    manager = model.value
    trainer = manager.build_trainer(
        use_gpu=use_gpu,
        reward_options=RewardOptions(),
        normalization_data_map=normalization,
    )

    # first fill the replay buffer to burn_in
    replay_buffer = ReplayBuffer(
        replay_capacity=replay_memory_size, batch_size=minibatch_size
    )
    # always fill full RB
    random_policy = make_random_policy_for_env(env)
    agent = Agent.create_for_env(env, policy=random_policy)
    fill_replay_buffer(
        env=env,
        replay_buffer=replay_buffer,
        desired_size=replay_memory_size,
        agent=agent,
    )

    device = torch.device("cuda") if use_gpu else None
    dataset = OfflineReplayBufferDataset.create_for_trainer(
        trainer,
        env,
        replay_buffer,
        batch_size=minibatch_size,
        num_batches=num_batches_per_epoch,
        device=device,
    )
    data_loader = torch.utils.data.DataLoader(dataset, collate_fn=identity_collate)
    pl_trainer = pl.Trainer(
        max_epochs=num_train_epochs,
        gpus=int(use_gpu),
        deterministic=True,
        default_root_dir=f"lightning_log_{str(uuid.uuid4())}",
    )
    pl_trainer.fit(trainer, data_loader)

    logger.info(f"Evaluating after training for {num_train_epochs} epochs: ")
    eval_rewards = evaluate_cem(env, manager, trainer, num_eval_episodes)
    mean_rewards = np.mean(eval_rewards)
    assert (
        mean_rewards >= passing_score_bar
    ), f"{mean_rewards} doesn't pass the bar {passing_score_bar}."
Example #14
def create_string_game_data(dataset_size=10000,
                            training_data_ratio=0.9,
                            filter_short_sequence=False):
    SEQ_LEN = 6
    NUM_ACTION = 2
    NUM_MDP_PER_BATCH = 5

    env = Gym(env_name="StringGame-v0", set_max_steps=SEQ_LEN)
    df = create_df_from_replay_buffer(
        env=env,
        problem_domain=ProblemDomain.DISCRETE_ACTION,
        desired_size=dataset_size,
        multi_steps=None,
        ds="2020-10-10",
    )

    if filter_short_sequence:
        batch_size = NUM_MDP_PER_BATCH
        time_diff = torch.ones(SEQ_LEN, batch_size)
        valid_step = SEQ_LEN * torch.ones(batch_size, dtype=torch.int64)[:, None]
        not_terminal = torch.Tensor(
            [0 if i == SEQ_LEN - 1 else 1 for i in range(SEQ_LEN)])
        not_terminal = torch.transpose(not_terminal.tile(NUM_MDP_PER_BATCH, 1),
                                       0, 1)
    else:
        batch_size = NUM_MDP_PER_BATCH * SEQ_LEN
        time_diff = torch.ones(SEQ_LEN, batch_size)
        valid_step = torch.arange(SEQ_LEN, 0, -1).tile(NUM_MDP_PER_BATCH)[:, None]
        not_terminal = torch.transpose(
            torch.tril(torch.ones(SEQ_LEN, SEQ_LEN),
                       diagonal=-1).tile(NUM_MDP_PER_BATCH, 1),
            0,
            1,
        )

    num_batches = int(dataset_size / SEQ_LEN / NUM_MDP_PER_BATCH)
    batches = [None for _ in range(num_batches)]
    batch_count, batch_seq_count = 0, 0
    batch_reward = torch.zeros(SEQ_LEN, batch_size)
    batch_action = torch.zeros(SEQ_LEN, batch_size, NUM_ACTION)
    batch_state = torch.zeros(SEQ_LEN, batch_size, NUM_ACTION)
    for mdp_id in sorted(set(df.mdp_id)):
        mdp = df[df["mdp_id"] == mdp_id].sort_values("sequence_number",
                                                     ascending=True)
        if len(mdp) != SEQ_LEN:
            continue

        all_step_reward = torch.Tensor(list(mdp["reward"]))
        all_step_state = torch.Tensor(
            [list(s.values()) for s in mdp["state_features"]])
        all_step_action = torch.zeros_like(all_step_state)
        all_step_action[torch.arange(SEQ_LEN),
                        [int(a) for a in mdp["action"]]] = 1.0

        for j in range(SEQ_LEN):
            if filter_short_sequence and j > 0:
                break

            reward = torch.zeros_like(all_step_reward)
            reward[:SEQ_LEN - j] = all_step_reward[-(SEQ_LEN - j):]
            batch_reward[:, batch_seq_count] = reward

            state = torch.zeros_like(all_step_state)
            state[:SEQ_LEN - j] = all_step_state[-(SEQ_LEN - j):]
            batch_state[:, batch_seq_count] = state

            action = torch.zeros_like(all_step_action)
            action[:SEQ_LEN - j] = all_step_action[-(SEQ_LEN - j):]
            batch_action[:, batch_seq_count] = action

            batch_seq_count += 1

        if batch_seq_count == batch_size:
            batches[batch_count] = rlt.MemoryNetworkInput(
                reward=batch_reward,
                action=batch_action,
                state=rlt.FeatureData(float_features=batch_state),
                next_state=rlt.FeatureData(float_features=torch.zeros_like(
                    batch_state)),  # fake, not used anyway
                not_terminal=not_terminal,
                time_diff=time_diff,
                valid_step=valid_step,
                step=None,
            )
            batch_count += 1
            batch_seq_count = 0
            batch_reward = torch.zeros_like(batch_reward)
            batch_action = torch.zeros_like(batch_action)
            batch_state = torch.zeros_like(batch_state)
    assert batch_count == num_batches

    num_training_batches = int(training_data_ratio * num_batches)
    training_data = batches[:num_training_batches]
    eval_data = batches[num_training_batches:]
    return training_data, eval_data
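The not_terminal tensor built in the else branch is easiest to see on tiny values: torch.tril(..., diagonal=-1) gives a strictly lower-triangular mask, tiling repeats it once per MDP, and the transpose puts time on the first axis, so within each MDP block column j is non-terminal for exactly its first j steps. A short illustration (SEQ_LEN=3 and two MDPs are made-up values):

import torch

SEQ_LEN, NUM_MDP_PER_BATCH = 3, 2  # tiny values for illustration

not_terminal = torch.transpose(
    torch.tril(torch.ones(SEQ_LEN, SEQ_LEN), diagonal=-1).tile(NUM_MDP_PER_BATCH, 1),
    0,
    1,
)
print(not_terminal.shape)  # torch.Size([3, 6]): (SEQ_LEN, SEQ_LEN * NUM_MDP_PER_BATCH)
print(not_terminal)
# tensor([[0., 1., 1., 0., 1., 1.],
#         [0., 0., 1., 0., 0., 1.],
#         [0., 0., 0., 0., 0., 0.]])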
Example #15
    def test_string_game(self):
        env = Gym(env_name="StringGame-v0")
        env.seed(313)
        mean_acc_reward = self._test_env(env)
        assert 0.1 >= mean_acc_reward
Example #16
def train_mdnrnn_and_train_on_embedded_env(
    env_name: str,
    embedding_model: ModelManager__Union,
    num_embedding_train_transitions: int,
    seq_len: int,
    batch_size: int,
    num_embedding_train_epochs: int,
    train_model: ModelManager__Union,
    num_state_embed_transitions: int,
    num_agent_train_epochs: int,
    num_agent_eval_epochs: int,
    use_gpu: bool,
    passing_score_bar: float,
    saved_mdnrnn_path: Optional[str] = None,
):
    """ Train an agent on embedded states by the MDNRNN. """
    env = Gym(env_name=env_name)
    env.seed(SEED)

    embedding_manager = embedding_model.value
    embedding_trainer = embedding_manager.initialize_trainer(
        use_gpu=use_gpu,
        reward_options=RewardOptions(),
        normalization_data_map=build_normalizer(env),
    )

    device = "cuda" if use_gpu else "cpu"
    embedding_trainer_preprocessor = make_replay_buffer_trainer_preprocessor(
        embedding_trainer,
        # pyre-fixme[6]: Expected `device` for 2nd param but got `str`.
        device,
        env,
    )
    if saved_mdnrnn_path is None:
        # train from scratch
        embedding_trainer = train_mdnrnn(
            env=env,
            trainer=embedding_trainer,
            trainer_preprocessor=embedding_trainer_preprocessor,
            num_train_transitions=num_embedding_train_transitions,
            seq_len=seq_len,
            batch_size=batch_size,
            num_train_epochs=num_embedding_train_epochs,
        )
    else:
        # load a pretrained model, and just evaluate it
        embedding_trainer.memory_network.mdnrnn.load_state_dict(
            torch.load(saved_mdnrnn_path))

    # create embedding dataset
    embed_rb, state_min, state_max = create_embed_rl_dataset(
        env=env,
        memory_network=embedding_trainer.memory_network,
        num_state_embed_transitions=num_state_embed_transitions,
        batch_size=batch_size,
        seq_len=seq_len,
        hidden_dim=embedding_trainer.params.hidden_size,
        use_gpu=use_gpu,
    )
    embed_env = StateEmbedEnvironment(
        gym_env=env,
        mdnrnn=embedding_trainer.memory_network,
        max_embed_seq_len=seq_len,
        state_min_value=state_min,
        state_max_value=state_max,
    )
    agent_manager = train_model.value
    agent_trainer = agent_manager.initialize_trainer(
        use_gpu=use_gpu,
        reward_options=RewardOptions(),
        # pyre-fixme[6]: Expected `EnvWrapper` for 1st param but got
        #  `StateEmbedEnvironment`.
        normalization_data_map=build_normalizer(embed_env),
    )
    device = "cuda" if use_gpu else "cpu"
    agent_trainer_preprocessor = make_replay_buffer_trainer_preprocessor(
        agent_trainer,
        # pyre-fixme[6]: Expected `device` for 2nd param but got `str`.
        device,
        env,
    )
    num_batch_per_epoch = embed_rb.size // batch_size
    # FIXME: This has to be wrapped in dataloader
    for epoch in range(num_agent_train_epochs):
        for _ in tqdm(range(num_batch_per_epoch), desc=f"epoch {epoch}"):
            batch = embed_rb.sample_transition_batch(batch_size=batch_size)
            preprocessed_batch = agent_trainer_preprocessor(batch)
            # FIXME: This should be fitted with Lightning's trainer
            agent_trainer.train(preprocessed_batch)

    # evaluate model
    rewards = []
    policy = agent_manager.create_policy(serving=False)
    # pyre-fixme[6]: Expected `EnvWrapper` for 1st param but got
    #  `StateEmbedEnvironment`.
    agent = Agent.create_for_env(embed_env, policy=policy, device=device)
    # num_processes=1 needed to avoid workers from dying on CircleCI tests
    rewards = evaluate_for_n_episodes(
        n=num_agent_eval_epochs,
        # pyre-fixme[6]: Expected `EnvWrapper` for 2nd param but got
        #  `StateEmbedEnvironment`.
        env=embed_env,
        agent=agent,
        num_processes=1,
    )
    assert (np.mean(rewards) >= passing_score_bar
            ), f"average reward doesn't pass our bar {passing_score_bar}"
    return rewards
Example #17
def train_mdnrnn_and_compute_feature_stats(
    env_name: str,
    model: ModelManager__Union,
    num_train_transitions: int,
    num_test_transitions: int,
    seq_len: int,
    batch_size: int,
    num_train_epochs: int,
    use_gpu: bool,
    saved_mdnrnn_path: Optional[str] = None,
):
    """ Train MDNRNN Memory Network and compute feature importance/sensitivity. """
    env: gym.Env = Gym(env_name=env_name)
    env.seed(SEED)

    manager = model.value
    trainer = manager.initialize_trainer(
        use_gpu=use_gpu,
        reward_options=RewardOptions(),
        normalization_data_map=build_normalizer(env),
    )

    device = "cuda" if use_gpu else "cpu"
    # pyre-fixme[6]: Expected `device` for 2nd param but got `str`.
    trainer_preprocessor = make_replay_buffer_trainer_preprocessor(
        trainer, device, env)
    test_replay_buffer = ReplayBuffer(
        replay_capacity=num_test_transitions,
        batch_size=batch_size,
        stack_size=seq_len,
        return_everything_as_stack=True,
    )
    fill_replay_buffer(env, test_replay_buffer, num_test_transitions)

    if saved_mdnrnn_path is None:
        # train from scratch
        trainer = train_mdnrnn(
            env=env,
            trainer=trainer,
            trainer_preprocessor=trainer_preprocessor,
            num_train_transitions=num_train_transitions,
            seq_len=seq_len,
            batch_size=batch_size,
            num_train_epochs=num_train_epochs,
            test_replay_buffer=test_replay_buffer,
        )
    else:
        # load a pretrained model, and just evaluate it
        trainer.memory_network.mdnrnn.load_state_dict(
            torch.load(saved_mdnrnn_path))

    with torch.no_grad():
        trainer.memory_network.mdnrnn.eval()
        test_batch = test_replay_buffer.sample_transition_batch(
            batch_size=test_replay_buffer.size)
        preprocessed_test_batch = trainer_preprocessor(test_batch)
        feature_importance = calculate_feature_importance(
            env=env,
            trainer=trainer,
            use_gpu=use_gpu,
            test_batch=preprocessed_test_batch,
        )

        feature_sensitivity = calculate_feature_sensitivity(
            env=env,
            trainer=trainer,
            use_gpu=use_gpu,
            test_batch=preprocessed_test_batch,
        )

        trainer.memory_network.mdnrnn.train()
    return feature_importance, feature_sensitivity
Example #18
    def test_pocman(self):
        env = Gym(env_name="Pocman-v0")
        env.seed(313)
        mean_acc_reward = self._test_env(env)
        assert -80 <= mean_acc_reward <= -70