"loss_valid": [],
    "acc_train": [],
    "acc_valid": [],
}

# Create callbacks
checkpoint = CheckPoint(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    savedir=CHECKPOINT_DIR,
    improved_delta=IMPROVED_DELTA,
    last_best_loss=np.inf,
)
earlystop = EarlyStopping(
    not_improved_thres=N_NOT_IMPROVED,
    improved_delta=IMPROVED_DELTA,
)
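
# Note: CheckPoint and EarlyStopping above are project-specific callbacks whose
# definitions are not part of this snippet. Below is a minimal sketch of what such
# an EarlyStopping callback might look like; the attribute names and call
# convention are assumptions, not the actual implementation.
class _EarlyStoppingSketch:
    def __init__(self, not_improved_thres, improved_delta):
        self.not_improved_thres = not_improved_thres  # epochs tolerated without improvement
        self.improved_delta = improved_delta          # minimum decrease that counts as improvement
        self.best_loss = np.inf
        self.num_not_improved = 0

    def __call__(self, loss):
        """Return True when training should stop."""
        if loss < self.best_loss - self.improved_delta:
            self.best_loss = loss
            self.num_not_improved = 0
        else:
            self.num_not_improved += 1
        return self.num_not_improved >= self.not_improved_thres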

#------------------------------------------------------------------------------
#   Train the model
#------------------------------------------------------------------------------
for epoch in range(1, args.n_epochs + 1):
    print("------------------------------------------------------------------")
    # Train model
    loss_train, acc_train, time_train = train_on_epoch(model=model,
                                                       device=DEVICE,
                                                       dataloader=train_loader,
                                                       loss_fn=loss_fn,
                                                       optimizer=optimizer,
                                                       epoch=epoch)
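
    # The snippet is cut off here. A hedged sketch of how the rest of the epoch
    # loop might continue with the callbacks created above; valid_on_epoch,
    # valid_loader, the name of the metrics dict, and the call signatures of
    # checkpoint / earlystop are all assumptions, not the original code.
    loss_valid, acc_valid, time_valid = valid_on_epoch(model=model,
                                                       device=DEVICE,
                                                       dataloader=valid_loader,
                                                       loss_fn=loss_fn,
                                                       epoch=epoch)

    # Record metrics in the dict defined at the top of the snippet
    history["loss_valid"].append(loss_valid)
    history["acc_train"].append(acc_train)
    history["acc_valid"].append(acc_valid)

    # Save on improvement and stop once the validation loss stops improving
    checkpoint(loss_valid)
    if earlystop(loss_valid):
        print("Early stopping triggered at epoch {}".format(epoch))
        break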
Example #2
vae = VAE(3, LSIZE).to(device)
vae.load_state_dict(state['state_dict'])

# Loading model
rnn_dir = join(args.logdir, 'mdrnn')
rnn_file = join(rnn_dir, 'best.tar')

if not exists(rnn_dir):
    mkdir(rnn_dir)

mdrnn = MDRNN(LSIZE, ASIZE, RSIZE, 5)
mdrnn.to(device)
optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
earlystopping = EarlyStopping('min', patience=30)


if exists(rnn_file) and not args.noreload:
    rnn_state = torch.load(rnn_file)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(
              rnn_state["epoch"], rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    # the scheduler / early-stopping state also lives in the MDRNN checkpoint
    scheduler.load_state_dict(rnn_state["scheduler"])
    earlystopping.load_state_dict(rnn_state["earlystopping"])


# Data Loading
transform = transforms.Lambda(
    lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
Example #3
    # Create the network. The network currently lacks a method to pass in inputs,
    # but it will need to know:
    #   - the input and output sizes for each of the embeddings,
    #   - the output sizes and number of layers for the encoders that combine real
    #     and integer features for the RNN and the MLP,
    #   - the output size of the LSTM,
    #   - the number of layers for the decoder predicting the target.

    net = ISD_RNN_MLP(in_seq_list, out_seq_list, in_stat_list, out_stat_list, in_seq_real, in_stat_real, args.in_rnn, args.out_rnn, args.out_mlp).to(device)

    print(next(net.parameters()).is_cuda)  # confirm the network parameters live on the GPU

    # Set up the optimizer to work on the network parameters
    optimizer = torch.optim.RMSprop(net.parameters(), lr=args.lr, alpha=.9)
    # Scheduler and early stopper; the scheduler specifies how to reduce the lr once the accuracy saturates
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
    # The early stopper stops the job after there has been no improvement for patience epochs
    early_stopper = EarlyStopping('min', patience=10)
    # Actual training loop
    cur_best = None

    #rnn_epoch(0, train=True)

    for e in range(args.epochs):
        # Each epoch runs one pass over the training set and one over the validation set
        rnn_epoch(e, train=True)
        with torch.no_grad():
            test_loss = rnn_epoch(e, train=False)
        # Check whether the new test_loss is the current best
        is_best = cur_best is None or test_loss < cur_best  # TODO: finish this
        if is_best:
            cur_best = test_loss
        # Save the latest model; if it is the best so far, also save it as best.tar
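
        # The snippet ends before the save itself. A minimal sketch of that step,
        # following the save_checkpoint pattern used by the other examples in this
        # collection; the save directory, the join/save_checkpoint helpers and the
        # checkpoint keys are assumptions here.
        checkpoint_fname = join(rnn_dir, 'checkpoint.tar')
        best_fname = join(rnn_dir, 'best.tar')
        save_checkpoint({
            'epoch': e,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'earlystopping': early_stopper.state_dict(),
            'precision': test_loss,
        }, is_best, checkpoint_fname, best_fname)

        # Step the scheduler / early stopper on the validation loss and stop if needed
        scheduler.step(test_loss)
        early_stopper.step(test_loss)
        if early_stopper.stop:
            print("End of training due to early stopping at epoch {}".format(e))
            break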
Example #5
# training
prev_rnn_dir = rnn_dir
if args.iteration_num is not None:
    rnn_dir = join(args.logdir, "mdrnn", "iter_{}".format(args.iteration_num))
    prev_rnn_dir = join(args.logdir, "mdrnn",
                        "iter_{}".format(args.iteration_num - 1))
if not exists(rnn_dir):
    makedirs(rnn_dir)
rnn_file = join(rnn_dir, "best.tar")
prev_rnn_file = join(prev_rnn_dir, "best.tar")

mdrnn = MDRNN(LSIZE, ASIZE, RSIZE, 5)
mdrnn.to(device)
optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=0.9)
scheduler = ReduceLROnPlateau(optimizer, "min", factor=0.5, patience=5)
earlystopping = EarlyStopping("min", patience=30)

if exists(prev_rnn_file) and not args.noreload:
    rnn_state = torch.load(prev_rnn_file)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(rnn_state["epoch"],
                                      rnn_state["precision"]))
    print("MDRNN loaded from {}".format(prev_rnn_file))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    # the scheduler / early-stopping state also lives in the MDRNN checkpoint
    scheduler.load_state_dict(rnn_state["scheduler"])
    earlystopping.load_state_dict(rnn_state["earlystopping"])

# Data Loading
dataset_dir = args.dataset_dir
if args.iteration_num is not None:
Example #6
def train_vae(logdir, traindir, epochs=100, testdir=None):
    print('Training VAE using traindir', traindir)
    batch_size = 100  # maybe this should be changed back to the original value of 32
    noreload = False  # if True, the best model is not reloaded
    nosamples = False  # if True, samples are not saved during training

    testdir = testdir if testdir else traindir

    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((RED_SIZE, RED_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((RED_SIZE, RED_SIZE)),
        transforms.ToTensor(),
    ])

    dataset_train = RolloutObservationDataset(traindir,
                                              transform_train,
                                              train=True)
    dataset_test = RolloutObservationDataset(testdir,
                                             transform_test,
                                             train=False)
    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=2)

    model = VAE(3, LSIZE).to(device)
    optimizer = optim.Adam(model.parameters())
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
    earlystopping = EarlyStopping('min', patience=30)

    # Reconstruction + KL divergence losses summed over all elements and batch
    def loss_function(recon_x, x, mu, logsigma):
        """ VAE loss function """
        BCE = F.mse_loss(recon_x, x, reduction='sum')

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) -
                               (2 * logsigma).exp())
        return BCE + KLD

    def train(epoch):
        """ One training epoch """
        model.train()
        dataset_train.load_next_buffer()
        train_loss = 0
        for batch_idx, data in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            if batch_idx % 20 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data)))

        print('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, train_loss / len(train_loader.dataset)))

    def test():
        """ One test epoch """
        model.eval()
        dataset_test.load_next_buffer()
        test_loss = 0
        with torch.no_grad():
            for data in test_loader:
                data = data.to(device)
                recon_batch, mu, logvar = model(data)
                test_loss += loss_function(recon_batch, data, mu,
                                           logvar).item()

        test_loss /= len(test_loader.dataset)
        print('====> Test set loss: {:.4f}'.format(test_loss))
        return test_loss

    # check vae dir exists, if not, create it
    vae_dir = join(logdir, 'vae')
    if not exists(vae_dir):
        mkdir(vae_dir)
        mkdir(join(vae_dir, 'samples'))

    reload_file = join(vae_dir, 'best.tar')
    if not noreload and exists(reload_file):
        state = torch.load(reload_file)
        print("Reloading model at epoch {}"
              ", with test error {}".format(state['epoch'],
                                            state['precision']))
        model.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
        scheduler.load_state_dict(state['scheduler'])
        earlystopping.load_state_dict(state['earlystopping'])

    cur_best = None

    for epoch in range(1, epochs + 1):
        train(epoch)
        test_loss = test()
        scheduler.step(test_loss)
        earlystopping.step(test_loss)

        # checkpointing
        best_filename = join(vae_dir, 'best.tar')
        filename = join(vae_dir, 'checkpoint.tar')
        is_best = not cur_best or test_loss < cur_best
        if is_best:
            cur_best = test_loss

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'precision': test_loss,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'earlystopping': earlystopping.state_dict()
            }, is_best, filename, best_filename)

        if not nosamples:
            with torch.no_grad():
                sample = torch.randn(RED_SIZE, LSIZE).to(device)
                sample = model.decoder(sample).cpu()
                save_image(
                    sample.view(64, 3, RED_SIZE, RED_SIZE),
                    join(vae_dir, 'samples/sample_' + str(epoch) + '.png'))

        if earlystopping.stop:
            print(
                "End of Training because of early stopping at epoch {}".format(
                    epoch))
            break
Example #7
                                 rollout_time_limit)
        r_gen.rollout(flatten_parameters(controller.parameters()), render=True)


# 1. Collect experience data by performing random rollouts
#generate_random_rollout_data(random_rollout_dir, random_rollout_num)

# 2-1. Build the dataset used to train the VAE
v_dataset_train, v_dataset_test = make_vae_dataset(rollout_root_dir)

# 2-2. Create the VAE model (V)
v_model = VAE(3, LSIZE).to(device)
v_optimizer = optim.Adam(v_model.parameters())
v_scheduler = ReduceLROnPlateau(v_optimizer, 'min', factor=0.5,
                                patience=5)  # originally 5
v_earlystopping = EarlyStopping('min', patience=30)  # patience 30 -> 10

# 2-3. Train the VAE model (V)
v_model_train_proc(vae_dir,
                   v_model,
                   v_dataset_train,
                   v_dataset_test,
                   v_optimizer,
                   v_scheduler,
                   v_earlystopping,
                   skip_train=True,
                   max_train_epochs=1000)

# hardmaru finished VAE training with 10k rollouts and epoch=10.
# ctellec used 1k rollouts and a max epoch of 1000, but 100 was probably already enough.
Example #8
def train_mdrnn(logdir, traindir, epochs=10, testdir=None):
    BSIZE = 80  # maybe this should be changed back to the original value of 16
    noreload = False  # if True, the best model is not reloaded
    SEQ_LEN = 32
    epochs = int(epochs)

    testdir = testdir if testdir else traindir
    cuda = torch.cuda.is_available()

    torch.manual_seed(123)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    # Loading VAE
    vae_file = join(logdir, 'vae', 'best.tar')
    assert exists(vae_file), "No trained VAE in the logdir..."
    state = torch.load(vae_file)
    print("Loading VAE at epoch {} "
          "with test error {}".format(
              state['epoch'], state['precision']))

    vae = VAE(3, LSIZE).to(device)
    vae.load_state_dict(state['state_dict'])

    # Loading model
    rnn_dir = join(logdir, 'mdrnn')
    rnn_file = join(rnn_dir, 'best.tar')

    if not exists(rnn_dir):
        mkdir(rnn_dir)

    mdrnn = MDRNN(LSIZE, ASIZE, RSIZE, 5)
    mdrnn.to(device)
    optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
    earlystopping = EarlyStopping('min', patience=30)


    if exists(rnn_file) and not noreload:
        rnn_state = torch.load(rnn_file)
        print("Loading MDRNN at epoch {} "
              "with test error {}".format(
                  rnn_state["epoch"], rnn_state["precision"]))
        mdrnn.load_state_dict(rnn_state["state_dict"])
        optimizer.load_state_dict(rnn_state["optimizer"])
        # the scheduler / early-stopping state also lives in the MDRNN checkpoint
        scheduler.load_state_dict(rnn_state["scheduler"])
        earlystopping.load_state_dict(rnn_state["earlystopping"])


    # Data Loading
    transform = transforms.Lambda(
        lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
    train_loader = DataLoader(
        RolloutSequenceDataset(traindir, SEQ_LEN, transform, buffer_size=30),
        batch_size=BSIZE, num_workers=8, shuffle=True)
    test_loader = DataLoader(
        RolloutSequenceDataset(testdir, SEQ_LEN, transform, train=False, buffer_size=10),
        batch_size=BSIZE, num_workers=8)

    def to_latent(obs, next_obs):
        """ Transform observations to latent space.

        :args obs: 5D torch tensor (BSIZE, SEQ_LEN, ASIZE, SIZE, SIZE)
        :args next_obs: 5D torch tensor (BSIZE, SEQ_LEN, ASIZE, SIZE, SIZE)

        :returns: (latent_obs, latent_next_obs)
            - latent_obs: 3D torch tensor (BSIZE, SEQ_LEN, LSIZE)
            - next_latent_obs: 3D torch tensor (BSIZE, SEQ_LEN, LSIZE)
        """
        with torch.no_grad():
            obs, next_obs = [
                f.interpolate(x.view(-1, 3, SIZE, SIZE), size=RED_SIZE,
                              mode='bilinear', align_corners=True)
                for x in (obs, next_obs)]

            (obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma) = [
                vae(x)[1:] for x in (obs, next_obs)]

            latent_obs, latent_next_obs = [
                (x_mu + x_logsigma.exp() * torch.randn_like(x_mu)).view(BSIZE, SEQ_LEN, LSIZE)
                for x_mu, x_logsigma in
                [(obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma)]]
        return latent_obs, latent_next_obs

    def get_loss(latent_obs, action, reward, terminal, latent_next_obs):
        """ Compute losses.

        The loss that is computed is:
        (GMMLoss(latent_next_obs, GMMPredicted) + MSE(reward, predicted_reward) +
             BCE(terminal, logit_terminal)) / (LSIZE + 2)
        The LSIZE + 2 factor is here to counteract the fact that the GMMLoss scales
        approximately linearly with LSIZE. All losses are averaged over both the
        batch and the sequence dimensions (the first two dimensions).

        :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
        :args action: (BSIZE, SEQ_LEN, ASIZE) torch tensor
        :args reward: (BSIZE, SEQ_LEN) torch tensor
        :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor

        :returns: dictionary of losses, containing the gmm, the mse, the bce and
            the averaged loss.
        """
        latent_obs, action,\
            reward, terminal,\
            latent_next_obs = [arr.transpose(1, 0)
                               for arr in [latent_obs, action,
                                           reward, terminal,
                                           latent_next_obs]]
        mus, sigmas, logpi, rs, ds = mdrnn(action, latent_obs)
        gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
        bce = f.binary_cross_entropy_with_logits(ds, terminal)
        mse = f.mse_loss(rs, reward)
        loss = (gmm + bce + mse) / (LSIZE + 2)
        return dict(gmm=gmm, bce=bce, mse=mse, loss=loss)


    def data_pass(epoch, train): # pylint: disable=too-many-locals
        """ One pass through the data """
        if train:
            mdrnn.train()
            loader = train_loader
        else:
            mdrnn.eval()
            loader = test_loader

        loader.dataset.load_next_buffer()

        cum_loss = 0
        cum_gmm = 0
        cum_bce = 0
        cum_mse = 0

        pbar = tqdm(total=len(loader.dataset), desc="Epoch {}".format(epoch))
        for i, data in enumerate(loader):
            obs, action, reward, terminal, next_obs = [arr.to(device) for arr in data]

            # transform obs
            latent_obs, latent_next_obs = to_latent(obs, next_obs)

            if train:
                losses = get_loss(latent_obs, action, reward,
                                  terminal, latent_next_obs)

                optimizer.zero_grad()
                losses['loss'].backward()
                optimizer.step()
            else:
                with torch.no_grad():
                    losses = get_loss(latent_obs, action, reward,
                                      terminal, latent_next_obs)

            cum_loss += losses['loss'].item()
            cum_gmm += losses['gmm'].item()
            cum_bce += losses['bce'].item()
            cum_mse += losses['mse'].item()

            pbar.set_postfix_str("loss={loss:10.6f} bce={bce:10.6f} "
                                 "gmm={gmm:10.6f} mse={mse:10.6f}".format(
                                     loss=cum_loss / (i + 1), bce=cum_bce / (i + 1),
                                     gmm=cum_gmm / LSIZE / (i + 1), mse=cum_mse / (i + 1)))
            pbar.update(BSIZE)
        pbar.close()
        return cum_loss * BSIZE / len(loader.dataset)

    train = partial(data_pass, train=True)
    test = partial(data_pass, train=False)

    # track the best test loss across epochs (must be initialised outside the loop)
    cur_best = None
    for e in range(epochs):
        train(e)
        test_loss = test(e)
        scheduler.step(test_loss)
        earlystopping.step(test_loss)

        is_best = not cur_best or test_loss < cur_best
        if is_best:
            cur_best = test_loss
        checkpoint_fname = join(rnn_dir, 'checkpoint.tar')
        save_checkpoint({
            "state_dict": mdrnn.state_dict(),
            "optimizer": optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'earlystopping': earlystopping.state_dict(),
            "precision": test_loss,
            "epoch": e}, is_best, checkpoint_fname,
                        rnn_file)

        if earlystopping.stop:
            print("End of Training because of early stopping at epoch {}".format(e))
            break
Example #9
    def test_mdrnn_learning(self):
        num_epochs = 300
        num_episodes = 400
        batch_size = 200
        action_dim = 2
        seq_len = 5
        state_dim = 2
        simulated_num_gaussian = 2
        mdrnn_num_gaussian = 2
        simulated_hidden_size = 3
        mdrnn_hidden_size = 10
        mdrnn_hidden_layer = 1
        adam_lr = 0.01
        cur_state_mem = numpy.zeros((num_episodes, seq_len, state_dim))
        next_state_mem = numpy.zeros((num_episodes, seq_len, state_dim))
        action_mem = numpy.zeros((num_episodes, seq_len, action_dim))
        reward_mem = numpy.zeros((num_episodes, seq_len))
        terminal_mem = numpy.zeros((num_episodes, seq_len))
        next_mus_mem = numpy.zeros(
            (num_episodes, seq_len, simulated_num_gaussian, state_dim))

        swm = SimulatedWorldModel(
            action_dim=action_dim,
            state_dim=state_dim,
            num_gaussian=simulated_num_gaussian,
            lstm_num_layer=1,
            lstm_hidden_dim=simulated_hidden_size,
        )

        actions = torch.eye(action_dim)
        for e in range(num_episodes):
            swm.init_hidden(batch_size=1)
            next_state = torch.randn((1, 1, state_dim))
            for s in range(seq_len):
                cur_state = next_state

                action = actions[numpy.random.randint(action_dim)].clone().view(
                    1, 1, action_dim)
                next_mus, reward = swm(action, cur_state)
                terminal = 0
                if s == seq_len - 1:
                    terminal = 1

                next_pi = torch.ones(
                    simulated_num_gaussian) / simulated_num_gaussian
                index = Categorical(next_pi).sample((1, )).long().item()
                next_state = next_mus[0, 0, index].view(1, 1, state_dim)

                print(
                    "{} cur_state: {}, action: {}, next_state: {}, reward: {}, terminal: {}"
                    .format(e, cur_state, action, next_state, reward,
                            terminal))
                print("next_pi: {}, sampled index: {}".format(next_pi, index))
                print("next_mus:", next_mus, "\n")

                cur_state_mem[e, s, :] = cur_state.detach().numpy()
                action_mem[e, s, :] = action.numpy()
                reward_mem[e, s] = reward.detach().numpy()
                terminal_mem[e, s] = terminal
                next_state_mem[e, s, :] = next_state.detach().numpy()
                next_mus_mem[e, s, :, :] = next_mus.detach().numpy()

        mdrnn = MDRNN(
            latents=state_dim,
            actions=action_dim,
            gaussians=mdrnn_num_gaussian,
            hiddens=mdrnn_hidden_size,
            layers=mdrnn_hidden_layer,
        )
        mdrnn.train()
        optimizer = torch.optim.Adam(mdrnn.parameters(), lr=adam_lr)
        num_batch = num_episodes // batch_size
        earlystopping = EarlyStopping('min', patience=30)

        cum_loss = []
        cum_gmm = []
        cum_bce = []
        cum_mse = []
        for e in range(num_epochs):
            for i in range(0, num_batch):
                mdrnn.init_hidden(batch_size=batch_size)
                optimizer.zero_grad()
                sample_indices = numpy.random.randint(num_episodes,
                                                      size=batch_size)

                obs, action, reward, terminal, next_obs = \
                    cur_state_mem[sample_indices], \
                    action_mem[sample_indices], \
                    reward_mem[sample_indices], \
                    terminal_mem[sample_indices], \
                    next_state_mem[sample_indices]
                obs, action, reward, terminal, next_obs = \
                    torch.tensor(obs, dtype=torch.float), \
                    torch.tensor(action, dtype=torch.float), \
                    torch.tensor(reward, dtype=torch.float), \
                    torch.tensor(terminal, dtype=torch.float), \
                    torch.tensor(next_obs, dtype=torch.float)

                print("learning at epoch {} step {} best score {} counter {}".
                      format(e, i, earlystopping.best,
                             earlystopping.num_bad_epochs))
                losses = self.get_loss(obs, action, reward, terminal, next_obs,
                                       state_dim, mdrnn)
                losses['loss'].backward()
                optimizer.step()

                cum_loss += [losses['loss'].item()]
                cum_gmm += [losses['gmm'].item()]
                cum_bce += [losses['bce'].item()]
                cum_mse += [losses['mse'].item()]
                print(
                    "loss={loss:10.6f} bce={bce:10.6f} gmm={gmm:10.6f} mse={mse:10.6f}"
                    .format(
                        loss=losses['loss'],
                        bce=losses['bce'],
                        gmm=losses['gmm'],
                        mse=losses['mse'],
                    ))
                print(
                    "cum loss={loss:10.6f} cum bce={bce:10.6f} cum gmm={gmm:10.6f} cum mse={mse:10.6f}"
                    .format(
                        loss=numpy.mean(cum_loss),
                        bce=numpy.mean(cum_bce),
                        gmm=numpy.mean(cum_gmm),
                        mse=numpy.mean(cum_mse),
                    ))

                print()

            earlystopping.step(numpy.mean(cum_loss[-num_batch:]))
            if numpy.mean(cum_loss[-num_batch:]) < -3. and earlystopping.stop:
                break

        assert numpy.mean(cum_loss[-num_batch:]) < -3.

        sample_indices = [0]
        mdrnn.init_hidden(batch_size=len(sample_indices))
        mdrnn.eval()
        obs, action, reward, terminal, next_obs = \
            cur_state_mem[sample_indices], \
            action_mem[sample_indices], \
            reward_mem[sample_indices], \
            terminal_mem[sample_indices], \
            next_state_mem[sample_indices]
        obs, action, reward, terminal, next_obs = \
            torch.tensor(obs, dtype=torch.float), \
            torch.tensor(action, dtype=torch.float), \
            torch.tensor(reward, dtype=torch.float), \
            torch.tensor(terminal, dtype=torch.float), \
            torch.tensor(next_obs, dtype=torch.float)
        transpose_obs, transpose_action, transpose_reward, transpose_terminal, transpose_next_obs = \
            self.transpose(obs, action, reward, terminal, next_obs)
        mus, sigmas, logpi, rs, ds = mdrnn(transpose_action, transpose_obs)
        pi = torch.exp(logpi)
        gl = gmm_loss(transpose_next_obs, mus, sigmas, logpi)
        print(gl)

        print()