示例#1
0
def main(args):
    """Train a CVAE on MNIST, saving a sample grid every epoch and the final weights.

    Args:
        args: Parsed command-line namespace. Reads ``folder`` (output
            directory), ``seed``, ``cuda`` (enables DataLoader worker /
            pinned-memory options), ``batch_size`` and ``epochs``.

    Side effects:
        Writes ``sample_<epoch>.png`` for every epoch, plus
        ``cvae_loss_curve.png`` and ``cvae.pth`` into ``args.folder``.
    """
    # Create the output folder (and any missing parents) if it does not exist.
    os.makedirs(args.folder, exist_ok=True)

    # Load data
    torch.manual_seed(args.seed)
    # Worker/pinned-memory options follow the --cuda flag.
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    # Load model on the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CVAE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Conditioning labels for the 10x10 sample grid: 0..9 repeated 10 times.
    # Loop-invariant, so built once.
    # NOTE(review): labels stay on the CPU exactly as in the original code;
    # presumably model.decode moves them to the model's device — confirm.
    label = torch.from_numpy(np.asarray(list(range(10)) * 10))

    # Train and generate sample every epoch
    loss_list = []
    for epoch in range(1, args.epochs + 1):
        model.train()
        loss_list.append(train(epoch, model, train_loader, optimizer))

        model.eval()
        # Sampling needs no autograd graph.
        with torch.no_grad():
            # Draw on the CPU (as before) so the random stream does not
            # depend on the device, then move to the model's device.
            sample = torch.randn(100, 20).to(device)
            sample = model.decode(sample, label).cpu()
        save_image(sample.view(100, 1, 28, 28),
                   os.path.join(args.folder, 'sample_' + str(epoch) + '.png'),
                   nrow=10)

    # Loss curve with a 1-based epoch axis, matching the sample file names.
    fig = plt.figure()
    plt.plot(range(1, len(loss_list) + 1), loss_list, '-o')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.savefig(os.path.join(args.folder, 'cvae_loss_curve.png'))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    torch.save(model.state_dict(), os.path.join(args.folder, 'cvae.pth'))
示例#2
0
def main(**kwargs):
    """
    Main function that trains the model
    1. Retrieve arguments from kwargs
    2. Prepare data
    3. Train
    4. Display and save first batch of test set (truth and reconstructed) after every epoch
    5. If latent dimension is 2, display and save latent variable of first batch of test set after every epoch

    Args:
        dataset: Which dataset to use ('MNIST' or 'CIFAR10')
        decoder_type: How to model the output pixels, 'Gaussian' or 'Bernoulli'
        model_sigma: In case of Gaussian decoder, whether to model the sigmas too
        epochs: How many epochs to train model (in this run)
        batch_size: Size of training / testing batch
        learning_rate: Learning rate ('lr' is accepted as an alias)
        latent_dim: Dimension of latent variable
        print_every: How often (in batches) to print training progress
        resume_path: The path of saved model with which to resume training
        resume_epoch: In case of resuming, the number of epochs already done

    Raises:
        ValueError: If ``dataset`` or ``decoder_type`` is not one of the supported values.

    Notes:
        - Saves model to folder 'saved_model/' (created if missing) every 20 epochs and when done
        - Capable of training from scratch and resuming (provide saved model location to argument resume_path)
        - Schedules learning rate with optim.lr_scheduler.ReduceLROnPlateau
            : Decays learning rate (default factor 0.1) when the mean training loss of an epoch
              does not decrease for 5 epochs
    """
    import os  # local import: only needed to create the checkpoint folder

    # Retrieve arguments
    dataset = kwargs.get('dataset', defaults['dataset'])
    decoder_type = kwargs.get('decoder_type', defaults['decoder_type'])
    # model_sigma is only looked up (and only used) for the Gaussian decoder.
    model_sigma = False
    if decoder_type == 'Gaussian':
        model_sigma = kwargs.get('model_sigma', defaults['model_sigma'])
    epochs = kwargs.get('epochs', defaults['epochs'])
    batch_size = kwargs.get('batch_size', defaults['batch_size'])
    # 'learning_rate' is the canonical key; 'lr' is accepted as an alias.
    lr = kwargs.get('learning_rate',
                    kwargs.get('lr', defaults['learning_rate']))
    latent_dim = kwargs.get('latent_dim', defaults['latent_dim'])
    print_every = kwargs.get('print_every', defaults['print_every'])
    resume_path = kwargs.get('resume_path', defaults['resume_path'])
    resume_epoch = kwargs.get('resume_epoch', defaults['resume_epoch'])

    # Fail early instead of hitting an undefined name further down.
    if decoder_type not in ('Bernoulli', 'Gaussian'):
        raise ValueError(
            f"decoder_type must be 'Bernoulli' or 'Gaussian', got {decoder_type!r}")
    if dataset not in ('MNIST', 'CIFAR10'):
        raise ValueError(
            f"dataset must be 'MNIST' or 'CIFAR10', got {dataset!r}")

    # Specify dataset transform on load; the Bernoulli decoder is trained on
    # pixels binarized at 0.5.
    if decoder_type == 'Bernoulli':
        trsf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: (x >= 0.5).float())
        ])
    else:
        trsf = transforms.ToTensor()

    # Load dataset with transform (downloaded into a folder named after the dataset)
    dataset_cls = datasets.MNIST if dataset == 'MNIST' else datasets.CIFAR10
    train_data = dataset_cls(root=dataset,
                             train=True,
                             transform=trsf,
                             download=True)
    test_data = dataset_cls(root=dataset,
                            train=False,
                            transform=trsf,
                            download=True)

    # Instantiate dataloader
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=batch_size,
                                              shuffle=False)

    # Instantiate/Load model and optimizer
    if resume_path:
        # NOTE: torch.load unpickles a whole model object; only load trusted files.
        autoencoder = torch.load(resume_path, map_location=device)
        print('Loaded saved model at ' + resume_path)
    elif decoder_type == 'Bernoulli':
        autoencoder = CVAE(latent_dim, dataset, decoder_type).to(device)
    else:
        autoencoder = CVAE(latent_dim, dataset, decoder_type,
                           model_sigma).to(device)
    optimizer = optim.Adam(autoencoder.parameters(), lr=lr)

    # Instantiate learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     verbose=True,
                                                     patience=5)

    # Announce current mode
    print(
        f'Start training CVAE with Gaussian encoder and {decoder_type} decoder on {dataset} dataset from epoch {resume_epoch+1}'
    )

    # Prepare batch to display with plt (next(iter(...)) works on every
    # DataLoader version; the old `.next()` method does not).
    first_test_batch, first_test_batch_label = next(iter(test_loader))
    first_test_batch = first_test_batch.to(device)
    first_test_batch_label = first_test_batch_label.to(device)

    # Display latent variable distribution before any training
    if latent_dim == 2 and resume_epoch == 0:
        autoencoder.eval()
        with torch.no_grad():
            autoencoder(first_test_batch, first_test_batch_label)
            display_and_save_latent(autoencoder.z, first_test_batch_label,
                                    f'-{decoder_type}-z{latent_dim}-e000')

    # torch.save does not create missing directories.
    os.makedirs('saved_model', exist_ok=True)
    final_epoch = epochs + resume_epoch

    # Train
    autoencoder.train()
    for epoch in range(resume_epoch, final_epoch):
        loss_hist = []
        for batch_ind, (input_data, input_label) in enumerate(train_loader):
            input_data = input_data.to(device)
            input_label = input_label.to(device)

            # Forward propagation
            if decoder_type == 'Bernoulli':
                z_mu, z_sigma, p = autoencoder(input_data, input_label)
            elif model_sigma:
                z_mu, z_sigma, out_mu, out_sigma = autoencoder(
                    input_data, input_label)
            else:
                z_mu, z_sigma, out_mu = autoencoder(input_data, input_label)

            # Per-sample KL(N(z_mu, z_sigma^2) || N(0, I)), treating z_sigma
            # as a standard deviation; 1e-8 guards against log(0).
            KL_divergence_i = 0.5 * torch.sum(
                z_mu**2 + z_sigma**2 - torch.log(1e-8 + z_sigma**2) - 1.,
                dim=1)
            # Per-sample reconstruction log-likelihood (summed over C, H, W).
            if decoder_type == 'Bernoulli':
                reconstruction_loss_i = -torch.sum(F.binary_cross_entropy(
                    p, input_data, reduction='none'),
                                                   dim=(1, 2, 3))
            elif model_sigma:
                # Gaussian log-likelihood; 6.28 approximates 2*pi.
                reconstruction_loss_i = -0.5 * torch.sum(
                    torch.log(1e-8 + 6.28 * out_sigma**2) +
                    ((input_data - out_mu)**2) / (out_sigma**2),
                    dim=(1, 2, 3))
            else:
                reconstruction_loss_i = -0.5 * torch.sum(
                    (input_data - out_mu)**2, dim=(1, 2, 3))
            ELBO_i = reconstruction_loss_i - KL_divergence_i
            loss = -torch.mean(ELBO_i)

            # Keep a detached copy: storing `loss` itself would retain every
            # batch's autograd graph until the end of the epoch.
            loss_hist.append(loss.detach())

            # Backward propagation
            optimizer.zero_grad()
            loss.backward()

            # Update parameters
            optimizer.step()

            # Print progress
            if batch_ind % print_every == 0:
                train_log = 'Epoch {:03d}/{:03d}\tLoss: {:.6f}\t\tTrain: [{}/{} ({:.0f}%)]           '.format(
                    epoch + 1, final_epoch,
                    loss.item(), batch_ind + 1, len(train_loader),
                    100. * batch_ind / len(train_loader))
                print(train_log, end='\r')
                sys.stdout.flush()

        # Learning rate decay on the mean training loss of this epoch
        if loss_hist:
            scheduler.step(torch.stack(loss_hist).mean().item())

        # Save model every 20 epochs (the last epoch is saved after the loop)
        if (epoch + 1) % 20 == 0 and epoch + 1 != final_epoch:
            PATH = f'saved_model/{dataset}-{decoder_type}-e{epoch+1}-z{latent_dim}' + datetime.datetime.now(
            ).strftime("-%b-%d-%H-%M-%p")
            torch.save(autoencoder, PATH)
            print('\vTemporarily saved model to ' + PATH)

        # Display training result with test set
        data = f'-{decoder_type}-z{latent_dim}-e{epoch+1:03d}'
        # The ground truth never changes, so save it only on this run's first epoch.
        save_truth = (epoch == resume_epoch)
        autoencoder.eval()
        with torch.no_grad():
            # Third output is the Bernoulli mean `p` or the Gaussian mean `out_mu`.
            outputs = autoencoder(first_test_batch, first_test_batch_label)

            if latent_dim == 2:
                display_and_save_latent(autoencoder.z,
                                        first_test_batch_label, data)

            if decoder_type == 'Bernoulli':
                p = outputs[2]
                display_and_save_batch("Binarized-truth",
                                       first_test_batch,
                                       data,
                                       save=save_truth)
                display_and_save_batch("Mean-reconstruction",
                                       p,
                                       data,
                                       save=True)
                display_and_save_batch("Sampled-reconstruction",
                                       torch.bernoulli(p),
                                       data,
                                       save=True)
            else:
                out_mu = outputs[2]
                display_and_save_batch("Truth",
                                       first_test_batch,
                                       data,
                                       save=save_truth)
                display_and_save_batch("Mean-reconstruction",
                                       out_mu,
                                       data,
                                       save=True)
        autoencoder.train()

    # Save final model
    PATH = f'saved_model/{dataset}-{decoder_type}-e{final_epoch}-z{latent_dim}' + datetime.datetime.now(
    ).strftime("-%b-%d-%H-%M-%p")
    torch.save(autoencoder, PATH)
    print('\vSaved model to ' + PATH)
示例#3
0
文件: train.py 项目: tbrx/CVAE
min_temperature = 0.5  # not used in the visible part of this script
decay_rate = 0.95  # per-epoch multiplicative learning-rate decay
for epoch in range(num_epochs):
    # Epoch timer: covers the LR update, training and evaluation below.
    st = time.time()
    # Learning rate scheduling: exponential decay, learning_rate * decay_rate**epoch.
    model.assign_lr(learning_rate * (decay_rate ** epoch))
    train_loss = []
    test_loss = []

    # One "epoch" = len // batch_size mini-batches drawn uniformly at random
    # with replacement (not a full pass over the data).
    for iteration in range(len(train_molecules) // batch_size):
        n = np.random.randint(len(train_molecules), size=batch_size)
        x = np.array([train_molecules[i] for i in n])
        l = np.array([train_length[i] for i in n])
        c = np.array([train_labels[i] for i in n])
        cost = model.train(x, l, c)
        train_loss.append(cost)

    # Evaluation uses the same random-batch sampling on the test split.
    for iteration in range(len(test_molecules) // batch_size):
        n = np.random.randint(len(test_molecules), size=batch_size)
        x = np.array([test_molecules[i] for i in n])
        l = np.array([test_length[i] for i in n])
        c = np.array([test_labels[i] for i in n])
        cost = model.test(x, l, c)
        test_loss.append(cost)

    train_loss = np.mean(np.array(train_loss))
    test_loss = np.mean(np.array(test_loss))
    end = time.time()
    if epoch == 0:
        print('epoch\ttrain_loss\ttest_loss\ttime (s)')
示例#4
0
class CVAEInterface():
    """Front-end around a ``CVAE`` model for loading path data, training,
    sampling for test conditions and plotting the generated samples.

    Figures are logged through ``self.cvae.tboard``; when requested, plots and
    generated points are also written into ``output_path``.
    """

    def __init__(self, run_id=1, output_path="", env_path_root=""):
        """
        Args:
            run_id: Forwarded to ``CVAE(run_id=...)``.
            output_path: Directory for generated figures / text files. When
                non-empty it is deleted (if it exists) and recreated. An empty
                string or ``None`` leaves the filesystem untouched.
            env_path_root: Directory containing ``<env_id>.txt`` map files.
        """
        super().__init__()
        self.cvae = CVAE(run_id=run_id)
        self.device = torch.device('cuda' if CUDA_AVAILABLE else 'cpu')
        self.output_path = output_path
        self.env_path_root = env_path_root

        # Start from a clean output directory. The falsy check also covers the
        # default "" (os.mkdir("") would raise FileNotFoundError).
        if self.output_path:
            if os.path.exists(self.output_path):
                shutil.rmtree(self.output_path)
            os.makedirs(self.output_path)

    def load_dataset(self, dataset_root, data_type="arm", mode="train"):
        """Build dataloaders from per-environment ``data_<data_type>.txt`` files.

        Every entry of ``dataset_root`` whose name starts with a digit is
        treated as an environment directory; its name is stored as the env
        index in column 0 of every row.

        Args:
            dataset_root: Directory containing one sub-directory per environment.
            data_type: "arm", "base" or "both".
            mode: "train" builds ``self.train_dataloader``; "test" builds
                condition-only dataloaders that repeat each unique condition
                ``TEST_SAMPLES`` times.
        """
        assert data_type in ("both", "arm", "base")
        assert mode in ("train", "test")
        # Should show different count and path for different modes
        print("Loading {} dataset for mode : {}, path : {}".format(
            data_type, mode, dataset_root))
        self.data_type = data_type

        paths_dataset = PathsDataset(type="FULL_STATE")
        c_test_dataset = PathsDataset(type="CONDITION_ONLY")
        env_dir_paths = os.listdir(dataset_root)
        # [env_index, condition...] rows of every environment, used to test
        # sample generation on each condition.
        all_condition_vars = []
        for env_dir_index in filter(lambda f: f[0].isdigit(), env_dir_paths):
            env_paths_file = os.path.join(dataset_root, env_dir_index,
                                          "data_{}.txt".format(data_type))
            env_paths = np.loadtxt(env_paths_file)
            if IGNORE_START:
                # Drop samples within 5.0 (Euclidean) of the start state;
                # here the condition starts at column 2 * X_DIM.
                start = env_paths[:, X_DIM:2 * X_DIM]
                samples = env_paths[:, :X_DIM]
                euc_dist = np.linalg.norm(start - samples, axis=1)
                far_from_start = np.where(euc_dist > 5.0)
                env_paths = env_paths[far_from_start[0], :]
                condition_vars = env_paths[:, 2 * X_DIM:2 * X_DIM + C_DIM]
            else:
                if mode == "train":
                    # Fewer points near the start in training data, to reduce
                    # them in the sampled output.
                    start = env_paths[:, X_DIM:X_DIM + POINT_DIM]
                    samples = env_paths[:, :X_DIM]
                    euc_dist = np.linalg.norm(start - samples, axis=1)
                    far_from_start = np.where(euc_dist > 2.0)
                    env_paths = env_paths[far_from_start[0], :]

                condition_vars = env_paths[:, X_DIM:X_DIM + C_DIM]

            # Train rows: [env_index, state (X_DIM cols), condition (C_DIM cols)],
            # with duplicates removed.
            env_paths = np.hstack((env_paths[:, :X_DIM], condition_vars))
            env_paths = np.unique(env_paths, axis=0)
            env_index = np.empty((env_paths.shape[0], 1))
            env_index.fill(env_dir_index)
            data = np.hstack((env_index, env_paths))
            paths_dataset.add_env_paths(data.tolist())

            # Test rows: [env_index, condition]
            env_index = np.empty((condition_vars.shape[0], 1))
            env_index.fill(env_dir_index)
            data = np.hstack((env_index, condition_vars))
            all_condition_vars += data.tolist()
            print("Added {} states from {} environment".format(
                env_paths.shape[0], env_dir_index))

        dataloader = DataLoader(paths_dataset,
                                batch_size=TRAIN_BATCH_SIZE,
                                shuffle=True)

        if data_type != "both":
            # Depending on which dataset is being loaded, set the right variables
            if mode == "train":
                self.train_dataloader = dataloader
                self.train_paths_dataset = paths_dataset
            elif mode == "test":
                self.test_condition_vars = np.unique(all_condition_vars,
                                                     axis=0)
                print("Unique test conditions count : {}".format(
                    self.test_condition_vars.shape[0]))
                # Tile condition variables to predict given number of samples for x
                all_condition_vars_tile = np.repeat(self.test_condition_vars,
                                                    TEST_SAMPLES, 0)
                c_test_dataset.add_env_paths(all_condition_vars_tile.tolist())
                c_test_dataloader = DataLoader(c_test_dataset,
                                               batch_size=TEST_BATCH_SIZE,
                                               shuffle=False)
                self.test_dataloader = c_test_dataloader
        else:
            arm_test_dataset = PathsDataset(type="CONDITION_ONLY")
            base_test_dataset = PathsDataset(type="CONDITION_ONLY")

            all_condition_vars = np.array(all_condition_vars)
            # NOTE(review): columns 4 and 5 are presumably the arm/base flag
            # columns of the loaded data — confirm against the data format.
            self.test_condition_vars = np.delete(all_condition_vars, [4, 5],
                                                 axis=1)
            self.test_condition_vars = np.unique(self.test_condition_vars,
                                                 axis=0)
            print("Unique test conditions count : {}".format(
                self.test_condition_vars.shape[0]))

            # Insert a two-column flag pair at column 2 * POINT_DIM:
            # arm conditions get (0, 1), base conditions get (1, 0).
            arm_condition_vars = np.insert(self.test_condition_vars,
                                           2 * POINT_DIM,
                                           1,
                                           axis=1)
            arm_condition_vars = np.insert(arm_condition_vars,
                                           2 * POINT_DIM,
                                           0,
                                           axis=1)

            arm_condition_vars = np.repeat(arm_condition_vars, TEST_SAMPLES, 0)
            arm_test_dataset.add_env_paths(arm_condition_vars.tolist())
            arm_test_dataloader = DataLoader(arm_test_dataset,
                                             batch_size=TEST_BATCH_SIZE,
                                             shuffle=False)

            base_condition_vars = np.insert(self.test_condition_vars,
                                            2 * POINT_DIM,
                                            0,
                                            axis=1)
            base_condition_vars = np.insert(base_condition_vars,
                                            2 * POINT_DIM,
                                            1,
                                            axis=1)

            base_condition_vars = np.repeat(base_condition_vars, TEST_SAMPLES,
                                            0)
            base_test_dataset.add_env_paths(base_condition_vars.tolist())
            base_test_dataloader = DataLoader(base_test_dataset,
                                              batch_size=TEST_BATCH_SIZE,
                                              shuffle=False)

            if mode == "train":
                self.train_dataloader = dataloader
                # Also expose the raw dataset so visualize_train_data works here too.
                self.train_paths_dataset = paths_dataset
            elif mode == "test":
                self.arm_test_dataloader = arm_test_dataloader
                self.base_test_dataloader = base_test_dataloader

    def visualize_train_data(self, num_conditions=1):
        """Plot the training states of ``num_conditions`` randomly chosen conditions.

        Each figure is logged to TensorBoard as ``train_data/condition_<i>``.
        """
        print("Plotting input data for {} random conditions".format(
            num_conditions))
        # Rows are [env_index, x, y, condition...] (columns indexed below).
        paths = np.array(self.train_paths_dataset.paths)
        env_ids = paths[:, 0]
        all_input_paths = paths[:, 1:]
        states = all_input_paths[:, :2]
        conditions = all_input_paths[:, 2:]
        for c_i in range(num_conditions):
            rand_index = np.random.randint(0, all_input_paths.shape[0])
            condition = conditions[rand_index]
            env_id = env_ids[rand_index]
            # Exact row-wise match. np.isin would also accept rows whose values
            # merely occur somewhere in `condition` (e.g. swapped coordinates).
            # NOTE(review): rows from other environments with an identical
            # condition are included too — confirm that is intended.
            matches = np.all(conditions == condition, axis=1)
            x = states[matches]
            fig = self.plot(x, condition, env_id=env_id)
            self.cvae.tboard.add_figure('train_data/condition_{}'.format(c_i),
                                        fig, 0)
        self.cvae.tboard.flush()

    def visualize_map(self, env_id):
        """Draw the wall/table rectangles of map ``<env_path_root>/<env_id>.txt``
        onto the current matplotlib axes.

        Each line is space-separated; for lines whose first token contains
        "wall" or "table", tokens 1, 2 are the centre and tokens 4, 5 the
        extents along x and y.
        """
        path = "{}/{}.txt".format(self.env_path_root, int(env_id))
        plt.title('Environment - {}'.format(env_id))
        with open(path, "r") as f:
            for line in f:
                tokens = line.split(" ")
                if "wall" in tokens[0] or "table" in tokens[0]:
                    x = float(tokens[1])
                    y = float(tokens[2])
                    l = float(tokens[4])
                    b = float(tokens[5])
                    rect = Rectangle((x - l / 2, y - b / 2), l, b)
                    plt.gca().add_patch(rect)
        plt.draw()

    def plot(self, x, c, env_id=None, suffix=0, write_file=False, show=False):
        '''
            Plot samples and environment - from train input or predicted output

            Args:
                x: Samples; columns 0 and 1 are plotted as (x, y).
                c: Condition. With IGNORE_START, c[0:2] is the goal; otherwise
                    c[0:2] is the start and c[2:4] the goal.
                env_id: If given, the environment map is drawn as well.
                suffix: Suffix for the files written when write_file is True.
                write_file: Save the figure, the samples and the start/goal
                    points into self.output_path.
                show: Call plt.show().

            Returns:
                The created matplotlib figure.
        '''
        if IGNORE_START:
            start = None
            goal = c[0:2]
        else:
            start = c[0:2]
            goal = c[2:4]
        # For given conditional, plot the samples
        fig1 = plt.figure(figsize=(10, 6), dpi=80)
        plt.scatter(x[:, 0], x[:, 1], color="green", s=70, alpha=0.1)
        if start is not None:
            plt.scatter(start[0], start[1], color="blue", s=70, alpha=0.6)
        plt.scatter(goal[0], goal[1], color="red", s=70, alpha=0.6)
        if env_id is not None:
            self.visualize_map(env_id)

        plt.xlabel('x')
        plt.ylabel('y')
        plt.xlim(0, X_MAX)
        plt.ylim(0, Y_MAX)
        if write_file:
            plt.savefig('{}/gen_points_fig_{}.png'.format(
                self.output_path, suffix))
            np.savetxt('{}/gen_points_{}.txt'.format(self.output_path, suffix),
                       x,
                       fmt="%.2f",
                       delimiter=',')
            # Without a start point (IGNORE_START) only the goal is written.
            markers = (goal,) if start is None else (start, goal)
            np.savetxt('{}/start_goal_{}.txt'.format(self.output_path, suffix),
                       np.vstack(markers),
                       fmt="%.2f",
                       delimiter=',')
        if show:
            plt.show()
        return fig1

    def load_saved_cvae(self, decoder_path):
        """Load decoder weights from ``decoder_path`` into ``self.cvae``."""
        print("Loading saved CVAE")
        self.cvae.load_decoder(decoder_path)

    def test_single(self,
                    env_id,
                    sample_size=1000,
                    c_test=None,
                    visualize=True):
        """Draw ``sample_size`` samples for a single condition.

        Args:
            env_id: Environment id used when plotting the map.
            sample_size: Number of samples to draw.
            c_test: Condition as a numpy array (required despite the None default).
            visualize: Plot the samples and write them into output_path.

        Returns:
            The generated samples as a numpy array.
        """
        self.cvae.eval()
        c_test_gpu = torch.from_numpy(c_test).float().to(self.device)
        c_test_gpu = torch.unsqueeze(c_test_gpu, dim=0)
        # Pure inference: no autograd graph needed.
        with torch.no_grad():
            x_test = self.cvae.inference(sample_size=sample_size,
                                         c=c_test_gpu)
        x_test = x_test.detach().cpu().numpy()

        if visualize:
            self.plot(x_test,
                      c_test,
                      env_id=env_id,
                      show=False,
                      write_file=True,
                      suffix=0)
        return x_test

    def test(self, epoch, dataloader, write_file=False, suffix="",
             num_epochs=None):
        """Sample for every condition in ``dataloader`` and plot per unique condition.

        Assumes ``dataloader`` yields each unique condition of
        ``self.test_condition_vars`` repeated ``TEST_SAMPLES`` times in order,
        as built by ``load_dataset(mode="test")``.

        Args:
            epoch: Current epoch (used for logging / TensorBoard tags).
            dataloader: Condition-only test dataloader.
            write_file: Also write plots and points into output_path.
            suffix: Suffix for the TensorBoard tag (e.g. "arm" / "base").
            num_epochs: Total epochs, for progress messages only. Defaults to
                the value passed to the last ``train`` call, else ``epoch + 1``.
        """
        if num_epochs is None:
            num_epochs = getattr(self, "num_epochs", epoch + 1)

        x_test_predicted = []
        self.cvae.eval()
        last_iteration = len(dataloader) - 1
        with torch.no_grad():
            for iteration, batch in enumerate(dataloader):
                c_test_data = batch['condition'].float().to(self.device)
                x_test = self.cvae.batch_inference(c=c_test_data)
                x_test_predicted += x_test.detach().cpu().numpy().tolist()
                if iteration % LOG_INTERVAL == 0 or iteration == last_iteration:
                    print(
                        "Test Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Iteration {}".
                        format(epoch, num_epochs, iteration, last_iteration,
                               iteration))

        x_test_predicted = np.array(x_test_predicted)
        # Draw plot for each unique condition
        for c_i in range(self.test_condition_vars.shape[0]):
            x_test = x_test_predicted[c_i * TEST_SAMPLES:(c_i + 1) *
                                      TEST_SAMPLES]
            # Fine because c_test is used only for plotting, we dont need arm/base label here
            c_test = self.test_condition_vars[c_i, 1:]
            env_id = self.test_condition_vars[c_i, 0]
            fig = self.plot(x_test,
                            c_test,
                            env_id=env_id,
                            suffix=c_i,
                            write_file=write_file)
            self.cvae.tboard.add_figure(
                'test_epoch_{}/condition_{}_{}'.format(epoch, c_i, suffix),
                fig, 0)
            if c_i % LOG_INTERVAL == 0:
                print("Plotting condition : {}".format(c_i))
        self.cvae.tboard.flush()

    def train(self,
              run_id=1,
              num_epochs=1,
              initial_learning_rate=0.001,
              weight_decay=0.0001):
        """Train ``self.cvae`` with Adam, testing and checkpointing periodically.

        Args:
            run_id: Unused; kept for interface compatibility (the model's run
                id is fixed in ``__init__``).
            num_epochs: Number of epochs.
            initial_learning_rate: Adam learning rate.
            weight_decay: Adam weight decay.
        """
        self.num_epochs = num_epochs
        optimizer = torch.optim.Adam(self.cvae.parameters(),
                                     lr=initial_learning_rate,
                                     weight_decay=weight_decay)
        num_batches = len(self.train_dataloader)
        # Defined up front so checkpointing works even with an empty loader.
        counter = 0
        for epoch in range(num_epochs):
            # test() switches the model to eval mode, so re-enable training
            # mode at the start of every epoch.
            self.cvae.train()
            for iteration, batch in enumerate(self.train_dataloader):
                x = batch['state'].float().to(self.device)
                c = batch['condition'].float().to(self.device)
                recon_x, mean, log_var, _ = self.cvae(x, c)

                loss = self.cvae.loss_fn(recon_x, x, mean, log_var)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Global step across epochs, used as the TensorBoard x-axis.
                counter = epoch * num_batches + iteration
                if iteration % LOG_INTERVAL == 0 or iteration == num_batches - 1:
                    loss_value = loss.item()
                    print(
                        "Train Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Iteration {}, Loss {:9.4f}"
                        .format(epoch, num_epochs, iteration,
                                num_batches - 1, counter, loss_value))
                    self.cvae.tboard.add_scalar('train/loss', loss_value,
                                                counter)

            if epoch % TEST_INTERVAL == 0 or epoch == num_epochs - 1:
                # Test CVAE for all c by drawing samples
                if self.data_type != "both":
                    self.test(epoch, self.test_dataloader,
                              num_epochs=num_epochs)
                else:
                    self.test(epoch, self.arm_test_dataloader, suffix="arm",
                              num_epochs=num_epochs)
                    self.test(epoch, self.base_test_dataloader, suffix="base",
                              num_epochs=num_epochs)

            if epoch % SAVE_INTERVAL == 0 and epoch > 0:
                self.cvae.save_model_weights(counter)

        self.cvae.save_model_weights('final')
for epoch in range(args.num_epochs):
    # Epoch timer: covers training and evaluation below.
    st = time.time()
    # Learning-rate scheduling (model.assign_lr) is intentionally disabled here.
    train_loss = []
    test_loss = []

    # One "epoch" = len // batch_size mini-batches drawn uniformly at random
    # with replacement (not a full pass over the data).
    for iteration in range(len(train_molecules_input) // args.batch_size):
        n = np.random.randint(len(train_molecules_input), size=args.batch_size)
        x = np.array([train_molecules_input[i] for i in n])
        y = np.array([train_molecules_output[i] for i in n])
        l = np.array([train_length[i] for i in n])
        cost = model.train(x, y, l)
        train_loss.append(cost)

    # Evaluation uses the same random-batch sampling on the test split.
    for iteration in range(len(test_molecules_input) // args.batch_size):
        n = np.random.randint(len(test_molecules_input), size=args.batch_size)
        x = np.array([test_molecules_input[i] for i in n])
        y = np.array([test_molecules_output[i] for i in n])
        l = np.array([test_length[i] for i in n])
        cost = model.test(x, y, l)
        test_loss.append(cost)

    train_loss = np.mean(np.array(train_loss))
    test_loss = np.mean(np.array(test_loss))
    end = time.time()
    if epoch == 0:
        print('epoch\ttrain_loss\ttest_loss\ttime (s)')