Пример #1
0
    def __init__(self, params):
        """Build (or reload) the SVG-FP modules, their optimizers and data loaders.

        Args:
            params: configuration dict. Keys read here: ``cuda``, ``seed``,
                ``noreload``, ``g_dim``, ``z_dim``, ``action_size``,
                ``rnn_size``, ``predictor_rnn_layers``,
                ``posterior_rnn_layers``, ``n_channels``, ``batch_size``,
                ``learning_rate``, ``beta1``, ``plot_visdom``, ``env``,
                ``path_data``, ``seq_len``, ``train_buffer_size`` and
                ``test_buffer_size`` (plus whatever ``load_checkpoint`` reads).
        """
        self.params = params
        # Choose device: CUDA only if it was requested *and* is available.
        self.cuda = params["cuda"] and torch.cuda.is_available()
        self.device = torch.device("cuda" if self.cuda else "cpu")
        torch.manual_seed(params["seed"])
        # Let cuDNN benchmark and cache the fastest convolution algorithms.
        torch.backends.cudnn.benchmark = True
        self.loss_function = nn.MSELoss().to(self.device)

        # Initialize model
        if params["noreload"]:
            self.frame_predictor = lstm_models.lstm(
                params["g_dim"] + params["z_dim"] + params["action_size"],
                params["g_dim"], params["rnn_size"],
                params["predictor_rnn_layers"],
                params["batch_size"]).to(self.device)
            self.posterior = lstm_models.gaussian_lstm(
                params["g_dim"], params["z_dim"], params["rnn_size"],
                params["posterior_rnn_layers"],
                params["batch_size"]).to(self.device)
            self.encoder = model.encoder(
                params["g_dim"], params["n_channels"]).to(self.device)
            self.decoder = model.decoder(
                params["g_dim"], params["n_channels"]).to(self.device)
            # Only freshly built modules are initialised: applying init_weights
            # after load_checkpoint() would overwrite the restored weights.
            for module in (self.frame_predictor, self.posterior,
                           self.encoder, self.decoder):
                module.apply(svp_utils.init_weights)
        else:
            self.load_checkpoint()

        # One Adam optimizer per module, all sharing the same hyper-parameters.
        adam_kwargs = dict(lr=params["learning_rate"],
                           betas=(params["beta1"], 0.999))
        self.frame_predictor_optimizer = optim.Adam(
            self.frame_predictor.parameters(), **adam_kwargs)
        self.posterior_optimizer = optim.Adam(
            self.posterior.parameters(), **adam_kwargs)
        self.encoder_optimizer = optim.Adam(
            self.encoder.parameters(), **adam_kwargs)
        self.decoder_optimizer = optim.Adam(
            self.decoder.parameters(), **adam_kwargs)

        if params["plot_visdom"]:
            self.plotter = VisdomLinePlotter(env_name=params['env'])
            self.img_plotter = VisdomImagePlotter(env_name=params['env'])

        # Move the channel axis from position 3 to 1 and scale by 1/255;
        # presumably frames arrive as (seq, H, W, C) in [0, 255] -- confirm.
        transform = transforms.Lambda(
            lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
        self.train_loader = DataLoader(
            RolloutSequenceDataset(params["path_data"], params["seq_len"], transform,
                                   buffer_size=params["train_buffer_size"]),
            batch_size=params['batch_size'], num_workers=2, shuffle=True, drop_last=True)
        self.test_loader = DataLoader(
            RolloutSequenceDataset(params["path_data"], params["seq_len"], transform,
                                   train=False, buffer_size=params["test_buffer_size"]),
            batch_size=params['batch_size'], num_workers=2, shuffle=False, drop_last=True)
Пример #2
0
class SVG_FP_TRAINER():
    """Trainer for an SVG-FP video prediction model.

    Holds four modules -- ``frame_predictor`` (LSTM), ``posterior``
    (Gaussian LSTM), ``encoder`` and ``decoder`` -- with one Adam optimizer
    each, train/test ``DataLoader``s over ``RolloutSequenceDataset`` and
    optional Visdom plotters. At generation time latents are drawn from a
    fixed N(0, I) prior. ``init_svg_model`` must be called before
    ``checkpoint`` because it sets ``self.svg_dir``.
    """

    def __init__(self, params):
        """Build (or reload) the SVG-FP modules, their optimizers and data loaders.

        Args:
            params: configuration dict. Keys read here: ``cuda``, ``seed``,
                ``noreload``, ``g_dim``, ``z_dim``, ``action_size``,
                ``rnn_size``, ``predictor_rnn_layers``,
                ``posterior_rnn_layers``, ``n_channels``, ``batch_size``,
                ``learning_rate``, ``beta1``, ``plot_visdom``, ``env``,
                ``path_data``, ``seq_len``, ``train_buffer_size`` and
                ``test_buffer_size`` (plus ``model_path`` when reloading).
        """
        self.params = params
        # Choose device: CUDA only if it was requested *and* is available.
        self.cuda = params["cuda"] and torch.cuda.is_available()
        self.device = torch.device("cuda" if self.cuda else "cpu")
        torch.manual_seed(params["seed"])
        # Let cuDNN benchmark and cache the fastest convolution algorithms.
        torch.backends.cudnn.benchmark = True
        self.loss_function = nn.MSELoss().to(self.device)

        # Initialize model
        if params["noreload"]:
            self.frame_predictor = lstm_models.lstm(
                params["g_dim"] + params["z_dim"] + params["action_size"],
                params["g_dim"], params["rnn_size"],
                params["predictor_rnn_layers"],
                params["batch_size"]).to(self.device)
            self.posterior = lstm_models.gaussian_lstm(
                params["g_dim"], params["z_dim"], params["rnn_size"],
                params["posterior_rnn_layers"],
                params["batch_size"]).to(self.device)
            self.encoder = model.encoder(
                params["g_dim"], params["n_channels"]).to(self.device)
            self.decoder = model.decoder(
                params["g_dim"], params["n_channels"]).to(self.device)
            # Only freshly built modules are initialised: applying init_weights
            # after load_checkpoint() would overwrite the restored weights.
            for module in self._modules():
                module.apply(svp_utils.init_weights)
        else:
            self.load_checkpoint()

        # One Adam optimizer per module, all sharing the same hyper-parameters.
        adam_kwargs = dict(lr=params["learning_rate"],
                           betas=(params["beta1"], 0.999))
        self.frame_predictor_optimizer = optim.Adam(
            self.frame_predictor.parameters(), **adam_kwargs)
        self.posterior_optimizer = optim.Adam(
            self.posterior.parameters(), **adam_kwargs)
        self.encoder_optimizer = optim.Adam(
            self.encoder.parameters(), **adam_kwargs)
        self.decoder_optimizer = optim.Adam(
            self.decoder.parameters(), **adam_kwargs)

        if params["plot_visdom"]:
            self.plotter = VisdomLinePlotter(env_name=params['env'])
            self.img_plotter = VisdomImagePlotter(env_name=params['env'])

        # Move the channel axis from position 3 to 1 and scale by 1/255;
        # presumably frames arrive as (seq, H, W, C) in [0, 255] -- confirm.
        transform = transforms.Lambda(
            lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
        self.train_loader = DataLoader(
            RolloutSequenceDataset(params["path_data"], params["seq_len"], transform,
                                   buffer_size=params["train_buffer_size"]),
            batch_size=params['batch_size'], num_workers=2, shuffle=True, drop_last=True)
        self.test_loader = DataLoader(
            RolloutSequenceDataset(params["path_data"], params["seq_len"], transform,
                                   train=False, buffer_size=params["test_buffer_size"]),
            batch_size=params['batch_size'], num_workers=2, shuffle=False, drop_last=True)

    def _modules(self):
        """Return the four trainable modules in a fixed order."""
        return (self.frame_predictor, self.posterior, self.encoder, self.decoder)

    def _optimizers(self):
        """Return the four optimizers in the same order as ``_modules``."""
        return (self.frame_predictor_optimizer, self.posterior_optimizer,
                self.encoder_optimizer, self.decoder_optimizer)

    def _reset_hidden(self):
        """Re-initialise the recurrent state of both LSTMs."""
        self.frame_predictor.hidden = self.frame_predictor.init_hidden()
        self.posterior.hidden = self.posterior.init_hidden()

    def _sample_prior(self):
        """Draw z ~ N(0, I) of shape (batch_size, z_dim) on the trainer's device."""
        return torch.randn(self.params["batch_size"], self.params["z_dim"],
                           device=self.device)

    def load_checkpoint(self):
        """Restore the four modules from ``params["model_path"]``.

        The file holds whole pickled modules (as written by ``checkpoint``), so
        ``torch.load`` unpickles arbitrary objects: only load trusted files.
        """
        tmp = torch.load(self.params["model_path"], map_location=self.device)
        print("LOADING CHECKPOINT.............")
        self.frame_predictor = tmp['frame_predictor'].to(self.device)
        self.posterior = tmp['posterior'].to(self.device)
        self.encoder = tmp['encoder'].to(self.device)
        self.decoder = tmp['decoder'].to(self.device)
        # Override the stored batch size with the current configuration.
        self.frame_predictor.batch_size = self.params["batch_size"]
        self.posterior.batch_size = self.params["batch_size"]

    def plot_samples(self, x, actions, epoch):
        """Save a PNG grid and a GIF of stochastic rollouts for one batch.

        For each of 5 samples the LSTMs are conditioned on the first
        ``n_past`` ground-truth frames (latents from the posterior), then the
        model rolls out with prior noise up to ``n_eval`` frames.

        Args:
            x: frame batch indexed as ``x[:, t]`` for time step ``t``.
            actions: action batch indexed as ``actions[:, t]``.
            epoch: used in the output file names.
        """
        params = self.params
        nsample = 5
        n_eval = params["n_eval"]
        n_past = params["n_past"]
        gt_seq = [x[:, i] for i in range(x.shape[1])]
        gen_seq = [[] for _ in range(nsample)]

        with torch.no_grad():
            for s in range(nsample):
                self._reset_hidden()
                gen_seq[s].append(x[:, 0])
                x_in = x[:, 0]
                for i in range(1, n_eval):
                    h_out = self.encoder(x_in)
                    if params["last_frame_skip"] or i < n_past:
                        h, skip = h_out
                    else:
                        h = h_out[0]
                    if i < n_past:
                        # Conditioning phase: feed ground truth, keep it as output.
                        h_target = self.encoder(x[:, i])[0]
                        z_t, _, _ = self.posterior(h_target)
                        self.frame_predictor(torch.cat([h, z_t, actions[:, i - 1]], 1))
                        x_in = x[:, i]
                    else:
                        # Generation phase: sample from the prior and decode.
                        h = self.frame_predictor(
                            torch.cat([h, self._sample_prior(), actions[:, i - 1]], 1))
                        x_in = self.decoder([h, skip])
                    gen_seq[s].append(x_in)

        to_plot = []
        gifs = [[] for _ in range(n_eval)]
        nrow = min(params["batch_size"], 10)
        for i in range(nrow):
            # Ground-truth row followed by one row per sample.
            to_plot.append([gt_seq[t][i] for t in range(n_eval)])
            for s in range(nsample):
                to_plot.append([gen_seq[s][t][i] for t in range(n_eval)])
            for t in range(n_eval):
                gifs[t].append([gt_seq[t][i]] + [gen_seq[s][t][i] for s in range(nsample)])

        check_dir(params["logdir"], "gen")
        fname = '%s/gen/sample_%d.png' % (params["logdir"], epoch)
        svp_utils.save_tensors_image(fname, to_plot)

        fname = '%s/gen/sample_%d.gif' % (params["logdir"], epoch)
        svp_utils.save_gif(fname, gifs)

    def plot_rec(self, x, actions, epoch):
        """Save a PNG of posterior-driven reconstructions for one batch.

        After ``n_past`` ground-truth frames, each frame is decoded from the
        predictor fed with posterior latents of the true next frame.
        """
        params = self.params
        n_past = params["n_past"]
        seq_len = params["seq_len"]
        with torch.no_grad():
            self._reset_hidden()
            gen_seq = [x[:, 0]]
            h_seq = [self.encoder(x[:, i]) for i in range(seq_len)]
            for i in range(1, seq_len):
                h_target = h_seq[i][0]
                if params["last_frame_skip"] or i < n_past:
                    h, skip = h_seq[i - 1]
                else:
                    h = h_seq[i - 1][0]
                z_t, _, _ = self.posterior(h_target)
                predictor_in = torch.cat([h, z_t, actions[:, i - 1]], 1)
                if i < n_past:
                    self.frame_predictor(predictor_in)
                    gen_seq.append(x[:, i])
                else:
                    h_pred = self.frame_predictor(predictor_in)
                    gen_seq.append(self.decoder([h_pred, skip]))

        nrow = min(params["batch_size"], 10)
        to_plot = [[gen_seq[t][i] for t in range(seq_len)] for i in range(nrow)]
        check_dir(params["logdir"], "gen")
        fname = '%s/gen/rec_%d.png' % (params["logdir"], epoch)
        svp_utils.save_tensors_image(fname, to_plot)

    def _train_step(self, obs, action):
        """Run one optimisation step on a batch; return (loss, mse, kld) as floats."""
        params = self.params
        self._reset_hidden()
        for module in self._modules():
            module.zero_grad()

        seq_len = obs.shape[1]
        h_seq = [self.encoder(obs[:, j]) for j in range(seq_len)]
        mse = 0
        kld = 0
        for t in range(1, seq_len):
            h_target = h_seq[t][0]
            if t < params["n_past"] or params["last_frame_skip"]:
                h, skip = h_seq[t - 1]
            else:
                # Keep the skip connections of the last conditioning frame.
                h = h_seq[t - 1][0]
            z_t, mu, logvar = self.posterior(h_target)
            h_pred = self.frame_predictor(torch.cat([h, z_t, action[:, t - 1]], 1))
            x_pred = self.decoder([h_pred, skip])
            mse = mse + self.loss_function(x_pred, obs[:, t])
            kld = kld + kl_criterion(mu, logvar, params["batch_size"])

        mse = mse / params["seq_len"]
        kld = kld / params["seq_len"]
        loss = mse + kld * params["beta"]
        loss.backward()
        for optimizer in self._optimizers():
            optimizer.step()
        # Return Python floats so callers don't keep the autograd graph alive.
        return loss.item(), mse.item(), kld.item()

    def _eval_step(self, obs, action):
        """Evaluate one batch without gradients; return (loss, mse, kld) as floats.

        The first ``n_past`` steps use ground truth. For the latent, the
        second posterior output is used (presumably its mean -- confirm).
        Later steps use prior noise and the model's own predictions. KLD is
        not computed in evaluation and is reported as 0.
        """
        params = self.params
        n_past = params["n_past"]
        self._reset_hidden()
        seq_len = obs.shape[1]
        mse = 0.0
        x_in = obs[:, 0]
        with torch.no_grad():
            for t in range(1, seq_len):
                if t < n_past or params["last_frame_skip"]:
                    h, skip = self.encoder(x_in)
                else:
                    h = self.encoder(x_in)[0]

                if t < n_past:
                    h_target = self.encoder(obs[:, t])[0]
                    _, z_t, _ = self.posterior(h_target)
                    self.frame_predictor(torch.cat([h, z_t, action[:, t - 1]], 1))
                    x_in = obs[:, t]
                else:
                    h = self.frame_predictor(
                        torch.cat([h, self._sample_prior(), action[:, t - 1]], 1))
                    x_in = self.decoder([h, skip])
                mse += self.loss_function(x_in, obs[:, t]).item()
        mse /= params["seq_len"]
        kld = 0.0
        loss = mse + kld * params["beta"]
        return loss, mse, kld

    def data_pass(self, epoch, train):
        """Run one epoch (train or test) over the buffered rollout dataset.

        The dataset is consumed buffer by buffer via ``load_next_buffer``. When
        training, the pass stops early after ``params["shorten_epoch"]``
        buffers. In test mode the first batch is also plotted.

        Args:
            epoch: epoch index (progress bar, plots, Visdom).
            train: True for a training pass, False for evaluation.

        Returns:
            float: average per-batch loss over the pass (NaN if no data).
        """
        params = self.params
        for module in self._modules():
            module.train(train)
        if train:
            loader, mode = self.train_loader, "train"
        else:
            loader, mode = self.test_loader, "test"

        num_of_files = len(loader.dataset._files)
        buffer_size = loader.dataset._buffer_size
        iteration = 0
        final_loss = 0.0
        all_files = 0
        loader.dataset._buffer_index = 0
        break_cond = False
        plot_epoch = True
        while True:
            loader.dataset.load_next_buffer()
            cum_loss = 0.0
            cum_mse = 0.0
            cum_dl = 0.0
            pbar = tqdm(total=len(loader.dataset), desc="Epoch {} - {}".format(epoch, mode))

            for i, data in enumerate(loader):
                obs, action, next_obs = [arr.to(self.device) for arr in data]
                # Append the final next-observation so obs holds seq_len + 1 frames.
                obs = torch.cat([obs, next_obs[:, -1].unsqueeze(1)], 1)

                if train:
                    loss, mse, kld = self._train_step(obs, action)
                else:
                    if plot_epoch:
                        print(">>>>>>>>>>>PLOT<<<<<<<<<<<<<")
                        self.plot_rec(obs, action, epoch)
                        self.plot_samples(obs, action, epoch)
                        plot_epoch = False
                    loss, mse, kld = self._eval_step(obs, action)

                cum_loss += loss
                cum_mse += mse
                cum_dl += kld * params["beta"]
                pbar.set_postfix_str("loss={loss:5.4f} , MSE={mse_loss:5.4f}, DL_loss={dl_loss:5.4f} ".format(
                    loss=cum_loss / (i + 1), mse_loss=cum_mse / (i + 1), dl_loss=cum_dl / (i + 1)))
                pbar.update(params['batch_size'])
            pbar.close()

            final_loss += cum_loss
            all_files += len(loader.dataset)
            iteration += 1
            print("Iteration: " + str(iteration))
            print("Buffer index: " + str(loader.dataset._buffer_index))
            if buffer_size >= num_of_files:
                # Everything fits in one buffer: a single pass is the epoch.
                break
            if params["shorten_epoch"] == iteration:
                break_cond = True
            if loader.dataset._buffer_index == 0 or break_cond:
                break
            # Stop after the next buffer if fewer than a full buffer remain.
            if num_of_files - loader.dataset._buffer_index < buffer_size:
                break_cond = True

        if all_files:
            final_loss = final_loss * params['batch_size'] / all_files
        else:
            final_loss = float("nan")

        print("Average loss {}".format(final_loss))
        if params["plot_visdom"]:
            if train:
                self.plotter.plot('loss', 'train', 'SVG_FP Train Loss', epoch, final_loss)
            else:
                self.plotter.plot('loss', 'test', 'SVG_FP Test Loss', epoch, final_loss)
        return final_loss

    def init_svg_model(self):
        """Set ``self.svg_dir`` to ``<logdir>/svg`` and check_dir its ``samples`` subdir."""
        self.svg_dir = os.path.join(self.params["logdir"], 'svg')
        check_dir(self.svg_dir, 'samples')

    def checkpoint(self, cur_best, test_loss):
        """Save the current modules; also mark as best if ``test_loss`` improves.

        Args:
            cur_best: best test loss so far, or None if there is none yet.
            test_loss: test loss of the current epoch.

        Returns:
            The (possibly updated) best test loss.
        """
        best_filename = os.path.join(self.svg_dir, 'best.tar')
        filename = os.path.join(self.svg_dir, 'checkpoint.tar')
        # `is None` so that a genuine best loss of 0.0 is not treated as unset.
        is_best = cur_best is None or test_loss < cur_best
        if is_best:
            cur_best = test_loss
        save_checkpoint({
            'encoder': self.encoder,
            'decoder': self.decoder,
            'frame_predictor': self.frame_predictor,
            'posterior': self.posterior,
            'test_loss': test_loss,
            'params': self.params
        }, is_best, filename, best_filename)
        return cur_best

    def plot(self, train, test, epochs):
        """Save the train/test loss curves to ``<logdir>/svg_fp_training_curve.png``."""
        plt.plot(epochs, train, label="train loss")
        plt.plot(epochs, test, label="test loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.params["logdir"] + "/svg_fp_training_curve.png")
        plt.close()
Пример #3
0
class VAE_TRAINER():
    """Train and evaluate a convolutional VAE on rollout observation frames.

    The loss is selected by ``params["loss"]`` (``ms-ssim``, ``mse`` or
    ``mix``) and returns ``(loss, mse, ssim)``. Checkpoints are written under
    ``<logdir>/vae``.
    """

    def __init__(self, params):
        """Build transforms, data loaders, the VAE, its optimizer and plotters.

        Args:
            params: configuration dict. Keys read here: ``loss``, ``cuda``,
                ``seed``, ``img_size``, ``path_data``, ``batch_size``,
                ``latent_size``, ``visualize``, ``env`` and ``alpha`` (plus
                ``logdir``, ``noreload`` and ``vae_location`` in
                ``init_vae_model``).
        """
        self.params = params
        self.loss_function = {
            'ms-ssim': ms_ssim_loss,
            'mse': mse_loss,
            'mix': mix_loss
        }[params["loss"]]

        # Choose device: CUDA only if it was requested *and* is available.
        self.cuda = params["cuda"] and torch.cuda.is_available()
        torch.manual_seed(params["seed"])
        # Let cuDNN benchmark and cache the fastest convolution algorithms.
        torch.backends.cudnn.benchmark = True
        self.device = torch.device("cuda" if self.cuda else "cpu")

        # Prepare data transformations (training adds random horizontal flips).
        red_size = params["img_size"]
        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        transform_val = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.ToTensor(),
        ])

        # Initialize data loaders
        op_dataset = RolloutObservationDataset(params["path_data"],
                                               transform_train,
                                               train=True)
        val_dataset = RolloutObservationDataset(params["path_data"],
                                                transform_val,
                                                train=False)
        self.train_loader = torch.utils.data.DataLoader(
            op_dataset,
            batch_size=params["batch_size"],
            shuffle=True,
            num_workers=0)
        self.eval_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=params["batch_size"],
            shuffle=False,
            num_workers=0)

        # Initialize model and hyper-parameters
        self.model = VAE(nc=3,
                         ngf=64,
                         ndf=64,
                         latent_variable_size=params["latent_size"],
                         cuda=self.cuda).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters())
        self.init_vae_model()
        self.visualize = params["visualize"]
        if self.visualize:
            self.plotter = VisdomLinePlotter(env_name=params['env'])
            self.img_plotter = VisdomImagePlotter(env_name=params['env'])
        # A falsy alpha (None / 0) falls back to 1.0.
        self.alpha = params["alpha"] if params["alpha"] else 1.0

    @staticmethod
    def _resolve_epoch(epoch):
        """Return *epoch*, or the module-level ``epoch`` global when it is None.

        Kept for callers that invoke ``eval()`` / ``checkpoint()`` without an
        epoch and relied on a script-level ``epoch`` loop variable.
        """
        if epoch is None:
            epoch = globals().get("epoch")
        return epoch

    def train(self, epoch):
        """Run one training epoch over ``self.train_loader`` and log averages."""
        params = self.params
        self.model.train()
        mse_loss = 0
        ssim_loss = 0
        train_loss = 0
        for batch_idx, data in enumerate(self.train_loader):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss, mse, ssim = self.loss_function(recon_batch, data, mu, logvar,
                                                 self.alpha)
            loss.backward()

            train_loss += loss.item()
            ssim_loss += ssim
            mse_loss += mse
            self.optimizer.step()

            if batch_idx % params["log_interval"] == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data),
                    len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), loss.item()))
                print('MSE: {} , SSIM: {:.4f}'.format(mse, ssim))

        # Approximate number of batches in the epoch.
        step = len(self.train_loader.dataset) / float(params["batch_size"])
        mean_train_loss = train_loss / step
        mean_ssim_loss = ssim_loss / step
        mean_mse_loss = mse_loss / step
        print('-- Epoch: {} Average loss: {:.4f}'.format(
            epoch, mean_train_loss))
        print('-- Average MSE: {:.5f} Average SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'train', 'VAE Train Loss', epoch,
                              mean_train_loss)
        return

    def eval(self, epoch=None):
        """Evaluate on ``self.eval_loader`` without gradients.

        Args:
            epoch: epoch index for titles and plots. Defaults to the
                module-level ``epoch`` for backward compatibility.

        Returns:
            Mean evaluation loss (sum of batch losses / approximate batch count).
        """
        params = self.params
        epoch = self._resolve_epoch(epoch)
        self.model.eval()
        eval_loss = 0
        mse_loss = 0
        ssim_loss = 0
        vis = True
        with torch.no_grad():
            for data in self.eval_loader:
                data = data.to(self.device)
                recon_batch, mu, logvar = self.model(data)

                loss, mse, ssim = self.loss_function(recon_batch, data, mu,
                                                     logvar, self.alpha)
                eval_loss += loss.item()
                ssim_loss += ssim
                mse_loss += mse
                if vis:
                    # Show the first 4 inputs above their reconstructions (first batch only).
                    if self.visualize:
                        org_title = "Epoch: " + str(epoch)
                        comparison1 = torch.cat([
                            data[:4],
                            # Use the actual batch size: the last batch may be short.
                            recon_batch.view(data.size(0), 3,
                                             params["img_size"],
                                             params["img_size"])[:4]
                        ])
                        self.img_plotter.plot(comparison1, org_title)
                    vis = False

        step = len(self.eval_loader.dataset) / float(params["batch_size"])
        mean_eval_loss = eval_loss / step
        mean_ssim_loss = ssim_loss / step
        mean_mse_loss = mse_loss / step
        print('-- Eval set loss: {:.4f}'.format(mean_eval_loss))
        print('-- Eval MSE: {:.5f} Eval SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'eval', 'VAE Eval Loss', epoch,
                              mean_eval_loss)
            self.plotter.plot('loss', 'mse train', 'VAE MSE Loss', epoch,
                              mean_mse_loss)
            self.plotter.plot('loss', 'ssim train', 'VAE MSE Loss', epoch,
                              mean_ssim_loss)

        return mean_eval_loss

    def init_vae_model(self):
        """Set ``self.vae_dir`` and optionally reload model/optimizer state.

        ``<logdir>/vae`` and its ``samples`` subdir are passed to ``check_dir``.
        Unless ``params["noreload"]`` is set, state is restored from
        ``<vae_location>/best.tar``.
        """
        self.vae_dir = os.path.join(self.params["logdir"], 'vae')
        check_dir(self.vae_dir, 'samples')
        if not self.params["noreload"]:  # and os.path.exists(reload_file):
            reload_file = os.path.join(self.params["vae_location"], 'best.tar')
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            state = torch.load(reload_file, map_location=self.device)
            print("Reloading model at epoch {}"
                  ", with eval error {}".format(state['epoch'],
                                                state['precision']))
            self.model.load_state_dict(state['state_dict'])
            self.optimizer.load_state_dict(state['optimizer'])

    def checkpoint(self, cur_best, eval_loss, epoch=None):
        """Save the last checkpoint, and the best one if ``eval_loss`` improves.

        Args:
            cur_best: best eval loss so far, or None if there is none yet.
            eval_loss: eval loss of the current epoch.
            epoch: epoch stored in the checkpoint. Defaults to the
                module-level ``epoch`` for backward compatibility.

        Returns:
            The (possibly updated) best eval loss.
        """
        epoch = self._resolve_epoch(epoch)
        best_filename = os.path.join(self.vae_dir, 'best.tar')
        filename = os.path.join(self.vae_dir, 'checkpoint.tar')
        # `is None` so that a genuine best loss of 0.0 is not treated as unset.
        is_best = cur_best is None or eval_loss < cur_best
        if is_best:
            cur_best = eval_loss

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': self.model.state_dict(),
                'precision': eval_loss,
                'optimizer': self.optimizer.state_dict()
            }, is_best, filename, best_filename)
        return cur_best

    def plot(self, train, eval, epochs):
        """Save the train/eval loss curves to ``<logdir>/vae_training_curve.png``."""
        plt.plot(epochs, train, label="train loss")
        plt.plot(epochs, eval, label="eval loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.params["logdir"] + "/vae_training_curve.png")
        plt.close()
Пример #4
0
    def __init__(self, params):
        """Set up the VAE trainer: loss, device, data pipelines, model and plotters.

        Args:
            params: configuration dict. Reads ``loss``, ``cuda``, ``seed``,
                ``img_size``, ``path_data``, ``batch_size``, ``latent_size``,
                ``visualize``, ``env`` and ``alpha`` (``init_vae_model`` reads
                more).
        """
        self.params = params
        available_losses = {
            'ms-ssim': ms_ssim_loss,
            'mse': mse_loss,
            'mix': mix_loss
        }
        self.loss_function = available_losses[params["loss"]]

        # Use CUDA only when requested and actually present.
        self.cuda = params["cuda"] and torch.cuda.is_available()
        torch.manual_seed(params["seed"])
        # Enable the cuDNN autotuner for convolution algorithm selection.
        torch.backends.cudnn.benchmark = True
        self.device = torch.device("cuda" if self.cuda else "cpu")

        side = params["img_size"]

        def build_pipeline(*augmentations):
            # PIL conversion -> square resize -> optional augmentation -> tensor.
            return transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((side, side)),
                *augmentations,
                transforms.ToTensor(),
            ])

        transform_train = build_pipeline(transforms.RandomHorizontalFlip())
        transform_val = build_pipeline()

        train_set = RolloutObservationDataset(params["path_data"], transform_train, train=True)
        val_set = RolloutObservationDataset(params["path_data"], transform_val, train=False)

        def make_loader(dataset, shuffle):
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=params["batch_size"],
                shuffle=shuffle,
                num_workers=0)

        self.train_loader = make_loader(train_set, True)
        self.eval_loader = make_loader(val_set, False)

        # Model and optimizer (Adam with default hyper-parameters).
        self.model = VAE(nc=3, ngf=64, ndf=64,
                         latent_variable_size=params["latent_size"],
                         cuda=self.cuda).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters())
        self.init_vae_model()

        self.visualize = params["visualize"]
        if self.visualize:
            self.plotter = VisdomLinePlotter(env_name=params['env'])
            self.img_plotter = VisdomImagePlotter(env_name=params['env'])
        # Any falsy alpha (None, 0) becomes 1.0.
        self.alpha = params["alpha"] or 1.0
Пример #5
0
            final_loss = final_loss * params['batch_size'] / all_files
            break

    print("Average loss {}".format(final_loss))
    if train:
        mdn_plotter.plot('loss', 'train', 'MDRNN Train Loss', epoch, final_loss)
    else:
        mdn_plotter.plot('loss', 'test', 'MDRNN Test Loss', epoch, final_loss)
    return final_loss

# Bind the shared epoch routine to its training and evaluation modes.
train = partial(data_pass, train=True)
test = partial(data_pass, train=False)

cur_best = None  # lowest test loss seen so far; None until the first epoch
# Plain module-level name read by data_pass when plotting losses.
mdn_plotter = VisdomLinePlotter(env_name=params['env'])
# Per-epoch histories used to draw the training curve.
cum_train_loss = []
cum_test_loss = []
epochs_list = []
for e in range(params['epochs']):
    # One training pass and one evaluation pass per epoch.
    train_loss=train(e)
    test_loss = test(e)
    # Record histories and redraw the loss curve after every epoch.
    cum_test_loss.append(test_loss)
    cum_train_loss.append(train_loss)
    epochs_list.append(e)
    plot_curve(cum_train_loss, cum_test_loss, epochs_list)

    # Scheduler is stepped with the test loss; presumably a
    # ReduceLROnPlateau-style scheduler -- confirm where it is created.
    scheduler.step(test_loss)
    # NOTE(review): `not cur_best` also treats a best loss of exactly 0.0 as
    # "unset"; `cur_best is None` would be stricter.
    is_best = not cur_best or test_loss < cur_best
    if is_best:
        cur_best = test_loss