Exemplo n.º 1
0
def train_mdnrnn(
    env: EnvWrapper,
    trainer: MDNRNNTrainer,
    trainer_preprocessor,
    num_train_transitions: int,
    seq_len: int,
    batch_size: int,
    num_train_epochs: int,
    # for optional validation
    test_replay_buffer=None,
):
    """Train ``trainer`` on transitions gathered with a random policy.

    A fresh ``ReplayBuffer`` is created and passed to ``fill_replay_buffer``
    together with an agent built around a random policy for ``env``. The
    trainer is then stepped manually for ``num_train_epochs`` epochs of
    ``train_replay_buffer.size // batch_size`` batches each, using the first
    optimizer returned by ``trainer.configure_optimizers()``.

    Args:
        env: Environment used to build the random policy/agent and to fill
            the training replay buffer.
        trainer: Trainer providing ``configure_optimizers``,
            ``train_step_gen``, ``get_loss`` and ``memory_network.mdnrnn``.
        trainer_preprocessor: Callable applied to every sampled batch before
            it is handed to the trainer.
        num_train_transitions: Replay capacity of the training buffer and the
            count passed to ``fill_replay_buffer``.
        seq_len: Passed to the replay buffer as ``stack_size``.
        batch_size: Number of transitions sampled per batch.
        num_train_epochs: Number of passes over the training loop.
        test_replay_buffer: Optional buffer; when given, one batch is sampled
            from it after every epoch and its losses are logged.

    Returns:
        The same ``trainer`` object that was passed in.
    """
    train_replay_buffer = ReplayBuffer(
        replay_capacity=num_train_transitions,
        batch_size=batch_size,
        stack_size=seq_len,
        return_everything_as_stack=True,
    )
    random_policy = make_random_policy_for_env(env)
    agent = Agent.create_for_env(env, policy=random_policy)
    fill_replay_buffer(env, train_replay_buffer, num_train_transitions, agent)
    num_batch_per_epoch = train_replay_buffer.size // batch_size

    logger.info("Made RBs, starting to train now!")
    optimizer = trainer.configure_optimizers()[0]
    for epoch in range(num_train_epochs):
        for i in range(num_batch_per_epoch):
            batch = train_replay_buffer.sample_transition_batch(batch_size=batch_size)
            preprocessed_batch = trainer_preprocessor(batch)
            # Only the first loss yielded by the step generator is optimized.
            loss = next(trainer.train_step_gen(preprocessed_batch, i))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # validation
        if test_replay_buffer is None:
            continue

        trainer.memory_network.mdnrnn.eval()
        try:
            with torch.no_grad():
                test_batch = test_replay_buffer.sample_transition_batch(
                    batch_size=batch_size
                )
                preprocessed_test_batch = trainer_preprocessor(test_batch)
                valid_losses = trainer.get_loss(preprocessed_test_batch)
        finally:
            # Always switch back to training mode, even if validation raised,
            # so a failed validation pass cannot leave the model in eval mode.
            trainer.memory_network.mdnrnn.train()
        # Previously the validation result was computed and then dropped.
        logger.info("Epoch %s validation losses: %s", epoch, valid_losses)
    return trainer
Exemplo n.º 2
0
def train_mdnrnn(
    env: gym.Env,
    trainer: MDNRNNTrainer,
    trainer_preprocessor,
    num_train_transitions: int,
    seq_len: int,
    batch_size: int,
    num_train_epochs: int,
    # for optional validation
    test_replay_buffer=None,
):
    """Run the MDN-RNN training loop on a replay buffer built from ``env``.

    The training buffer is created with ``ReplayBuffer.create_from_env`` and
    handed to ``fill_replay_buffer``. Each epoch consists of
    ``train_replay_buffer.size // batch_size`` calls to ``trainer.train`` on
    preprocessed batches, with the returned losses forwarded to
    ``print_mdnrnn_losses``.

    Args:
        env: Environment used to create and fill the training buffer.
        trainer: Trainer exposing ``train``, ``get_loss`` and
            ``memory_network.mdnrnn``.
        trainer_preprocessor: Callable applied to each sampled batch.
        num_train_transitions: Buffer size and the count passed to
            ``fill_replay_buffer``.
        seq_len: Passed to the replay buffer as ``stack_size``.
        batch_size: Number of transitions per sampled batch.
        num_train_epochs: Number of epochs to run.
        test_replay_buffer: Optional buffer; when provided, one batch from it
            is evaluated under ``torch.no_grad()`` after every epoch.

    Returns:
        The ``trainer`` that was passed in.
    """
    buffer_options = dict(
        env=env,
        replay_memory_size=num_train_transitions,
        batch_size=batch_size,
        stack_size=seq_len,
        return_everything_as_stack=True,
    )
    train_replay_buffer = ReplayBuffer.create_from_env(**buffer_options)
    fill_replay_buffer(env, train_replay_buffer, num_train_transitions)
    batches_per_epoch = train_replay_buffer.size // batch_size
    logger.info("Made RBs, starting to train now!")

    for epoch in range(num_train_epochs):
        for batch_idx in range(batches_per_epoch):
            raw_batch = train_replay_buffer.sample_transition_batch_tensor(
                batch_size=batch_size
            )
            train_losses = trainer.train(trainer_preprocessor(raw_batch))
            print_mdnrnn_losses(epoch, batch_idx, train_losses)

        # Validation is optional; nothing more to do this epoch without it.
        if test_replay_buffer is None:
            continue

        with torch.no_grad():
            # Evaluate in eval mode, then switch back for the next epoch.
            trainer.memory_network.mdnrnn.eval()
            raw_test_batch = test_replay_buffer.sample_transition_batch_tensor(
                batch_size=batch_size
            )
            valid_losses = trainer.get_loss(trainer_preprocessor(raw_test_batch))
            print_mdnrnn_losses(epoch, "validation", valid_losses)
            trainer.memory_network.mdnrnn.train()

    return trainer
Exemplo n.º 3
0
def train_sgd(
    gym_env: OpenAIGymEnvironment,
    trainer: MDNRNNTrainer,
    use_gpu: bool,
    test_run_name: str,
    minibatch_size: int,
    run_details: OpenAiRunDetails,
    test_batch: rlt.PreprocessedTrainingBatch,
):
    """Train ``trainer`` with mini-batches, validating and early-stopping per epoch.

    Two replay buffers (training and validation) are obtained from
    ``get_replay_buffer``. After every epoch the full validation batch is
    scored with ``trainer.get_loss``; training stops early once more than
    ``run_details.early_stopping_patience`` validation results exist and the
    latest validation loss is not below any of the ``patience`` results that
    precede it. Finally the model is put in eval mode and scored on
    ``test_batch``.

    Args:
        gym_env: Environment passed to ``get_replay_buffer``; its
            ``state_dim`` is forwarded to ``trainer.get_loss``.
        trainer: Trainer exposing ``train``, ``get_loss``, the ``cum_*``
            loss collections and ``mdnrnn.mdnrnn``.
        use_gpu: Forwarded to ``sample_memories``.
        test_run_name: Not used by this function.
        minibatch_size: Number of samples drawn per training mini-batch.
        run_details: Supplies ``max_steps`` (must not be ``None``),
            ``num_train_episodes``, ``num_test_episodes``, ``seq_len``,
            ``train_epochs`` and ``early_stopping_patience``.
        test_batch: Batch scored once after training finishes.

    Returns:
        A tuple ``(test_losses, valid_loss_history, trainer)``.
    """
    assert run_details.max_steps is not None

    train_replay_buffer = get_replay_buffer(
        run_details.num_train_episodes,
        run_details.seq_len,
        run_details.max_steps,
        gym_env,
    )
    valid_replay_buffer = get_replay_buffer(
        run_details.num_test_episodes,
        run_details.seq_len,
        run_details.max_steps,
        gym_env,
    )
    # The whole validation buffer is sampled once and reused every epoch.
    valid_batch = valid_replay_buffer.sample_memories(
        valid_replay_buffer.memory_size, use_gpu=use_gpu, batch_first=True
    )
    valid_loss_history = []

    num_batch_per_epoch = train_replay_buffer.memory_size // minibatch_size

    summary_template = (
        "Collected data {} transitions.\n"
        "Training will take {} epochs, with each epoch having {} mini-batches"
        " and each mini-batch having {} samples"
    )
    minibatch_template = (
        "{}-th epoch, {}-th minibatch: \n"
        "loss={}, bce={}, gmm={}, mse={} \n"
        "cum loss={}, cum bce={}, cum gmm={}, cum mse={}\n"
    )
    validation_template = "{}-th epoch, validate loss={}, bce={}, gmm={}, mse={}"

    logger.info(
        summary_template.format(
            train_replay_buffer.memory_size,
            run_details.train_epochs,
            num_batch_per_epoch,
            minibatch_size,
        )
    )

    for epoch_idx in range(run_details.train_epochs):
        for batch_idx in range(num_batch_per_epoch):
            minibatch = train_replay_buffer.sample_memories(
                minibatch_size, use_gpu=use_gpu, batch_first=True
            )
            step_losses = trainer.train(minibatch, batch_first=True)
            logger.info(
                minibatch_template.format(
                    epoch_idx,
                    batch_idx,
                    step_losses["loss"],
                    step_losses["bce"],
                    step_losses["gmm"],
                    step_losses["mse"],
                    np.mean(trainer.cum_loss),
                    np.mean(trainer.cum_bce),
                    np.mean(trainer.cum_gmm),
                    np.mean(trainer.cum_mse),
                )
            )

        # Score the validation batch in eval mode, then resume training mode.
        trainer.mdnrnn.mdnrnn.eval()
        epoch_valid_losses = loss_to_num(
            trainer.get_loss(
                valid_batch, state_dim=gym_env.state_dim, batch_first=True
            )
        )
        valid_loss_history.append(epoch_valid_losses)
        trainer.mdnrnn.mdnrnn.train()
        logger.info(
            validation_template.format(
                epoch_idx,
                epoch_valid_losses["loss"],
                epoch_valid_losses["bce"],
                epoch_valid_losses["gmm"],
                epoch_valid_losses["mse"],
            )
        )

        # earlystopping: compare the newest loss with the `patience` results
        # immediately before it (the newest entry itself is excluded).
        patience = run_details.early_stopping_patience
        newest_loss = valid_loss_history[-1]["loss"]
        previous_window = valid_loss_history[-1 - patience:-1]
        enough_history = len(valid_loss_history) > patience
        if enough_history and all(
            newest_loss >= past["loss"] for past in previous_window
        ):
            break

    trainer.mdnrnn.mdnrnn.eval()
    test_losses = loss_to_num(
        trainer.get_loss(test_batch, state_dim=gym_env.state_dim, batch_first=True)
    )
    logger.info(
        "Test loss: {}, bce={}, gmm={}, mse={}".format(
            test_losses["loss"],
            test_losses["bce"],
            test_losses["gmm"],
            test_losses["mse"],
        )
    )
    logger.info("Valid loss history: {}".format(valid_loss_history))
    return test_losses, valid_loss_history, trainer