Beispiel #1
0
def main(args):
    """Train a CVAE on MNIST, saving samples, a loss curve and the weights.

    Args:
        args: Parsed command-line namespace. Fields read: ``folder`` (output
            directory), ``seed``, ``cuda`` (enables DataLoader
            ``num_workers``/``pin_memory``), ``batch_size`` and ``epochs``.

    Writes into ``args.folder``: ``sample_<epoch>.png`` after every epoch
    (a 10x10 grid in which column ``k`` is decoded with label ``k``),
    ``cvae_loss_curve.png`` and the final ``cvae.pth`` state dict.
    """
    # Create the output folder (and any missing parents) if needed.
    os.makedirs(args.folder, exist_ok=True)

    # Load data
    torch.manual_seed(args.seed)
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    # Load model (on the GPU whenever one is available).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CVAE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Conditioning labels for the sample grid: 0..9 repeated 10 times, so
    # with nrow=10 every column of the saved grid shares one label.
    # NOTE(review): as before, the labels stay on the CPU while the latent
    # samples go to ``device``; confirm CVAE.decode handles that.
    label = torch.from_numpy(np.asarray(list(range(10)) * 10))

    # Train and generate sample every epoch
    loss_list = []
    for epoch in range(1, args.epochs + 1):
        model.train()
        loss_list.append(train(epoch, model, train_loader, optimizer))

        model.eval()
        # Sampling needs no gradients; no_grad avoids building a graph.
        with torch.no_grad():
            z = torch.randn(100, 20).to(device)
            sample = model.decode(z, label).cpu()
        save_image(sample.view(100, 1, 28, 28),
                   os.path.join(args.folder, 'sample_' + str(epoch) + '.png'),
                   nrow=10)

    plt.plot(range(len(loss_list)), loss_list, '-o')
    plt.savefig(os.path.join(args.folder, 'cvae_loss_curve.png'))
    torch.save(model.state_dict(), os.path.join(args.folder, 'cvae.pth'))
Beispiel #2
0
def main(**kwargs):
    """
    Main function that trains the model
    1. Retrieve arguments from kwargs
    2. Prepare data
    3. Train
    4. Display and save first batch of test set (truth and reconstructed) after every epoch
    5. If latent dimension is 2, display and save latent variable of first batch of test set after every epoch

    Args:
        dataset: Which dataset to use ('MNIST' or 'CIFAR10')
        decoder_type: How to model the output pixels, 'Gaussian' or 'Bernoulli'
        model_sigma: In case of Gaussian decoder, whether to model the sigmas too
        epochs: How many epochs to train model
        batch_size: Size of training / testing batch
        learning_rate: Learning rate
        latent_dim: Dimension of latent variable
        print_every: How often (in batches) to print training progress
        resume_path: The path of saved model with which to resume training
        resume_epoch: In case of resuming, the number of epochs already done

    Any argument that is not given falls back to the module-level ``defaults`` dict.

    Raises:
        ValueError: If ``dataset`` or ``decoder_type`` is not a supported value.

    Notes:
        - Saves the whole model object to folder 'saved_model/' (created if missing)
          every 20 epochs and when done
        - Capable of training from scratch and resuming (provide saved model location to argument resume_path)
        - Schedules learning rate with optim.lr_scheduler.ReduceLROnPlateau ('min', patience=5, default factor 0.1)
            : Multiplies the learning rate by 0.1 once the epoch-mean training loss has not
              improved for more than 5 epochs
    """
    import os

    # Retrieve arguments
    dataset = kwargs.get('dataset', defaults['dataset'])
    decoder_type = kwargs.get('decoder_type', defaults['decoder_type'])
    if decoder_type not in ('Bernoulli', 'Gaussian'):
        raise ValueError(f'Unsupported decoder_type: {decoder_type!r}')
    # model_sigma is only meaningful for the Gaussian decoder; every use below
    # is behind a Bernoulli check first, so False is a safe placeholder.
    model_sigma = (kwargs.get('model_sigma', defaults['model_sigma'])
                   if decoder_type == 'Gaussian' else False)
    epochs = kwargs.get('epochs', defaults['epochs'])
    batch_size = kwargs.get('batch_size', defaults['batch_size'])
    lr = kwargs.get('learning_rate', defaults['learning_rate'])
    latent_dim = kwargs.get('latent_dim', defaults['latent_dim'])
    print_every = kwargs.get('print_every', defaults['print_every'])
    resume_path = kwargs.get('resume_path', defaults['resume_path'])
    resume_epoch = kwargs.get('resume_epoch', defaults['resume_epoch'])

    dataset_classes = {'MNIST': datasets.MNIST, 'CIFAR10': datasets.CIFAR10}
    if dataset not in dataset_classes:
        raise ValueError(f'Unsupported dataset: {dataset!r}')

    # Specify dataset transform on load
    if decoder_type == 'Bernoulli':
        # Binarize pixels so they are valid Bernoulli targets.
        trsf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: (x >= 0.5).float())
        ])
    else:
        trsf = transforms.ToTensor()

    # Load dataset with transform (root folder is named after the dataset)
    dataset_cls = dataset_classes[dataset]
    train_data = dataset_cls(root=dataset,
                             train=True,
                             transform=trsf,
                             download=True)
    test_data = dataset_cls(root=dataset,
                            train=False,
                            transform=trsf,
                            download=True)

    # Instantiate dataloader
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=batch_size,
                                              shuffle=False)

    # Instantiate/Load model and optimizer
    if resume_path:
        # NOTE(review): torch.load unpickles the whole model object; only load
        # trusted files. Recent PyTorch defaults to weights_only=True, which
        # may reject full-model pickles — confirm for the installed version.
        autoencoder = torch.load(resume_path, map_location=device)
        print('Loaded saved model at ' + resume_path)
    elif decoder_type == 'Bernoulli':
        autoencoder = CVAE(latent_dim, dataset, decoder_type).to(device)
    else:
        autoencoder = CVAE(latent_dim, dataset, decoder_type,
                           model_sigma).to(device)
    optimizer = optim.Adam(autoencoder.parameters(), lr=lr)

    # Instantiate learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     verbose=True,
                                                     patience=5)

    # Announce current mode
    print(
        f'Start training CVAE with Gaussian encoder and {decoder_type} decoder on {dataset} dataset from epoch {resume_epoch+1}'
    )

    # Prepare batch to display with plt. The builtin next() works on every
    # PyTorch version; the iterator's .next() method no longer exists.
    first_test_batch, first_test_batch_label = next(iter(test_loader))
    first_test_batch = first_test_batch.to(device)
    first_test_batch_label = first_test_batch_label.to(device)

    # Display latent variable distribution before any training
    if latent_dim == 2 and resume_epoch == 0:
        with torch.no_grad():
            # The forward pass stores the latent sample on autoencoder.z.
            autoencoder(first_test_batch, first_test_batch_label)
        display_and_save_latent(autoencoder.z, first_test_batch_label,
                                f'-{decoder_type}-z{latent_dim}-e000')

    total_epochs = epochs + resume_epoch
    os.makedirs('saved_model', exist_ok=True)

    # Train
    autoencoder.train()
    for epoch in range(resume_epoch, total_epochs):
        # Python floats only: storing loss tensors would keep every batch's
        # autograd graph alive until the end of the epoch.
        loss_hist = []
        for batch_ind, (input_data, input_label) in enumerate(train_loader):
            input_data, input_label = input_data.to(device), input_label.to(
                device)

            # Forward propagation
            if decoder_type == 'Bernoulli':
                z_mu, z_sigma, p = autoencoder(input_data, input_label)
            elif model_sigma:
                z_mu, z_sigma, out_mu, out_sigma = autoencoder(
                    input_data, input_label)
            else:
                z_mu, z_sigma, out_mu = autoencoder(input_data, input_label)

            # Calculate loss: closed-form KL to N(0, I) per sample
            KL_divergence_i = 0.5 * torch.sum(
                z_mu**2 + z_sigma**2 - torch.log(1e-8 + z_sigma**2) - 1.,
                dim=1)
            if decoder_type == 'Bernoulli':
                reconstruction_loss_i = -torch.sum(F.binary_cross_entropy(
                    p, input_data, reduction='none'),
                                                   dim=(1, 2, 3))
            elif model_sigma:
                # Gaussian log-likelihood up to a constant; 6.28 ~ 2*pi.
                reconstruction_loss_i = -0.5 * torch.sum(
                    torch.log(1e-8 + 6.28 * out_sigma**2) +
                    ((input_data - out_mu)**2) / (out_sigma**2),
                    dim=(1, 2, 3))
            else:
                reconstruction_loss_i = -0.5 * torch.sum(
                    (input_data - out_mu)**2, dim=(1, 2, 3))
            ELBO_i = reconstruction_loss_i - KL_divergence_i
            loss = -torch.mean(ELBO_i)

            loss_value = loss.item()
            loss_hist.append(loss_value)

            # Backward propagation
            optimizer.zero_grad()
            loss.backward()

            # Update parameters
            optimizer.step()

            # Print progress
            if batch_ind % print_every == 0:
                train_log = 'Epoch {:03d}/{:03d}\tLoss: {:.6f}\t\tTrain: [{}/{} ({:.0f}%)]           '.format(
                    epoch + 1, total_epochs,
                    loss_value, batch_ind + 1, len(train_loader),
                    100. * batch_ind / len(train_loader))
                print(train_log, end='\r')
                sys.stdout.flush()

        # Learning rate decay on the epoch-mean training loss
        scheduler.step(sum(loss_hist) / len(loss_hist))

        # Save model every 20 epochs (the final epoch is saved below instead)
        if (epoch + 1) % 20 == 0 and epoch + 1 != total_epochs:
            PATH = f'saved_model/{dataset}-{decoder_type}-e{epoch+1}-z{latent_dim}' + datetime.datetime.now(
            ).strftime("-%b-%d-%H-%M-%p")
            torch.save(autoencoder, PATH)
            print('\vTemporarily saved model to ' + PATH)

        # Display training result with test set
        data = f'-{decoder_type}-z{latent_dim}-e{epoch+1:03d}'
        with torch.no_grad():
            autoencoder.eval()
            if decoder_type == 'Bernoulli':
                z_mu, z_sigma, p = autoencoder(first_test_batch,
                                               first_test_batch_label)
                output = torch.bernoulli(p)

                if latent_dim == 2:
                    display_and_save_latent(autoencoder.z,
                                            first_test_batch_label, data)

                display_and_save_batch("Binarized-truth",
                                       first_test_batch,
                                       data,
                                       save=(epoch == 0))
                display_and_save_batch("Mean-reconstruction",
                                       p,
                                       data,
                                       save=True)
                display_and_save_batch("Sampled-reconstruction",
                                       output,
                                       data,
                                       save=True)

            elif model_sigma:
                z_mu, z_sigma, out_mu, out_sigma = autoencoder(
                    first_test_batch, first_test_batch_label)
                output = torch.normal(out_mu, out_sigma).clamp(0., 1.)

                if latent_dim == 2:
                    display_and_save_latent(autoencoder.z,
                                            first_test_batch_label, data)

                display_and_save_batch("Truth",
                                       first_test_batch,
                                       data,
                                       save=(epoch == 0))
                display_and_save_batch("Mean-reconstruction",
                                       out_mu,
                                       data,
                                       save=True)

            else:
                z_mu, z_sigma, out_mu = autoencoder(first_test_batch,
                                                    first_test_batch_label)
                output = torch.normal(out_mu,
                                      torch.ones_like(out_mu)).clamp(0., 1.)

                if latent_dim == 2:
                    display_and_save_latent(autoencoder.z,
                                            first_test_batch_label, data)

                display_and_save_batch("Truth",
                                       first_test_batch,
                                       data,
                                       save=(epoch == 0))
                display_and_save_batch("Mean-reconstruction",
                                       out_mu,
                                       data,
                                       save=True)
            autoencoder.train()

    # Save final model
    PATH = f'saved_model/{dataset}-{decoder_type}-e{total_epochs}-z{latent_dim}' + datetime.datetime.now(
    ).strftime("-%b-%d-%H-%M-%p")
    torch.save(autoencoder, PATH)
    print('\vSaved model to ' + PATH)
    # NOTE(review): everything below references `args` and `run`, which are not
    # defined in this function; it appears to have been spliced in from a
    # different script. Kept verbatim — confirm whether it belongs here.
    # fetch data
    data = locate('data.get_%s' % args.dataset)(args)

    # make dataloaders
    train_loader, val_loader, test_loader = [
        CLDataLoader(elem, args, train=t)
        for elem, t in zip(data, [True, False, False])
    ]

    model = ResNet18(args.n_classes, nf=20,
                     input_size=args.input_size).to(args.device)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    gen = CVAE(20, args).cuda()  # this is actually an autoencoder
    opt_gen = torch.optim.Adam(gen.parameters())

    # build buffer
    if args.store_latents:
        buffer = Buffer(args, input_size=(20 * 4 * 4, ))
    else:
        buffer = Buffer(args)

    buffer.min_per_class = 0
    print('multiple heads ', args.multiple_heads)

    if run == 0:
        print("number of classifier parameters:",
              sum([np.prod(p.size()) for p in model.parameters()]))
        print("number of generator parameters: ",
              sum([np.prod(p.size()) for p in gen.parameters()]))
Beispiel #4
0
def train(train_A_dir,
          train_B_dir,
          model_dir,
          model_name,
          random_seed,
          val_A_dir,
          val_B_dir,
          output_dir,
          tensorboard_dir,
          load_path,
          gen_eval=True):
    """Train (or evaluate) a CVAE that converts speech from class 0 (A) to class 1 (B).

    Both training sets are loaded and encoded with WORLD features. The log-F0
    and MCEP normalization statistics are written to ``model_dir`` as
    ``logf0s_normalization.npz`` and ``mcep_normalization.npz``.

    If ``load_path`` is given, that state dict is loaded. When ``val_A_dir`` is
    also set and ``gen_eval`` is true, every file in ``val_A_dir`` is converted
    to ``<output_dir>/converted_A/eval_<name>`` and the process exits via
    ``sys.exit(0)``.

    Otherwise a fresh model is trained for ``num_epochs`` iterations with
    losses logged to TensorBoard. Every 1000 iterations it converts one
    validation file (the first entry of ``os.listdir(val_A_dir)``) and saves a
    checkpoint to ``<model_dir>/<model_name>/<epoch>.ckpt``.

    Args:
        train_A_dir, train_B_dir: Directories of training wavs for class A / B.
        model_dir: Where normalization statistics and checkpoints are written.
        model_name: Sub-directory of ``model_dir`` for checkpoints.
        random_seed: Seed for NumPy and PyTorch RNGs.
        val_A_dir, val_B_dir: Optional validation wav directories (may be None).
        output_dir: Root for converted validation audio.
        tensorboard_dir: Log directory for ``SummaryWriter``.
        load_path: Optional path of a saved state dict.
        gen_eval: When loading, whether to convert ``val_A_dir`` and exit.
    """
    import sys

    np.random.seed(random_seed)
    # Seed torch as well so weight initialisation is reproducible.
    torch.manual_seed(random_seed)

    # For now, copy hyperparams used in the CycleGAN
    num_epochs = 100000
    learning_rate = 0.0002
    sampling_rate = 16000
    num_mcep = 24
    frame_period = 5.0
    n_frames = 128
    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Use the same pre-processing as the CycleGAN
    print("Begin Preprocessing")

    wavs_A = load_wavs(wav_dir=train_A_dir, sr=sampling_rate)
    wavs_B = load_wavs(wav_dir=train_B_dir, sr=sampling_rate)
    print("Finished Loading")

    f0s_A, _, _, _, coded_sps_A = world_encode_data(
        wavs=wavs_A,
        fs=sampling_rate,
        frame_period=frame_period,
        coded_dim=num_mcep)
    f0s_B, _, _, _, coded_sps_B = world_encode_data(
        wavs=wavs_B,
        fs=sampling_rate,
        frame_period=frame_period,
        coded_dim=num_mcep)
    print("Finished Encoding")

    log_f0s_mean_A, log_f0s_std_A = logf0_statistics(f0s_A)
    log_f0s_mean_B, log_f0s_std_B = logf0_statistics(f0s_B)

    print('Log Pitch A')
    print('Mean: %f, Std: %f' % (log_f0s_mean_A, log_f0s_std_A))
    print('Log Pitch B')
    print('Mean: %f, Std: %f' % (log_f0s_mean_B, log_f0s_std_B))

    coded_sps_A_transposed = transpose_in_list(lst=coded_sps_A)
    coded_sps_B_transposed = transpose_in_list(lst=coded_sps_B)

    coded_sps_A_norm, coded_sps_A_mean, coded_sps_A_std = coded_sps_normalization_fit_transoform(
        coded_sps=coded_sps_A_transposed)
    print("Input data fixed.")
    coded_sps_B_norm, coded_sps_B_mean, coded_sps_B_std = coded_sps_normalization_fit_transoform(
        coded_sps=coded_sps_B_transposed)

    os.makedirs(model_dir, exist_ok=True)
    np.savez(os.path.join(model_dir, 'logf0s_normalization.npz'),
             mean_A=log_f0s_mean_A,
             std_A=log_f0s_std_A,
             mean_B=log_f0s_mean_B,
             std_B=log_f0s_std_B)
    np.savez(os.path.join(model_dir, 'mcep_normalization.npz'),
             mean_A=coded_sps_A_mean,
             std_A=coded_sps_A_std,
             mean_B=coded_sps_B_mean,
             std_B=coded_sps_B_std)

    validation_A_output_dir = None
    if val_A_dir is not None:
        validation_A_output_dir = os.path.join(output_dir, 'converted_A')
        os.makedirs(validation_A_output_dir, exist_ok=True)

    if val_B_dir is not None:
        os.makedirs(os.path.join(output_dir, 'converted_B'), exist_ok=True)

    print("End Preprocessing")

    def convert_A_to_B(model, filepath, output_path):
        """Convert one wav file from class 0 to class 1 and write it to output_path."""
        print("Converting {0} from Class 0 to Class 1".format(filepath))
        wav, _ = librosa.load(filepath, sr=sampling_rate, mono=True)
        wav = wav_padding(wav=wav,
                          sr=sampling_rate,
                          frame_period=frame_period,
                          multiple=4)
        f0, _, sp, ap = world_decompose(wav=wav,
                                        fs=sampling_rate,
                                        frame_period=frame_period)
        f0_converted = pitch_conversion(f0=f0,
                                        mean_log_src=log_f0s_mean_A,
                                        std_log_src=log_f0s_std_A,
                                        mean_log_target=log_f0s_mean_B,
                                        std_log_target=log_f0s_std_B)
        coded_sp = world_encode_spectral_envelop(sp=sp,
                                                 fs=sampling_rate,
                                                 dim=num_mcep)
        # Normalise with A's statistics, convert, then de-normalise with B's.
        coded_sp_norm = (coded_sp.T - coded_sps_A_mean) / coded_sps_A_std
        with torch.no_grad():
            coded_sp_converted_norm, _, _ = model.convert(
                np.array([coded_sp_norm]), 0, 1, device)
        coded_sp_converted_norm = np.squeeze(
            coded_sp_converted_norm.cpu().numpy())
        coded_sp_converted = coded_sp_converted_norm * coded_sps_B_std + coded_sps_B_mean
        coded_sp_converted = np.ascontiguousarray(coded_sp_converted.T)
        decoded_sp_converted = world_decode_spectral_envelop(
            coded_sp=coded_sp_converted, fs=sampling_rate)
        wav_transformed = world_speech_synthesis(
            f0=f0_converted,
            decoded_sp=decoded_sp_converted,
            ap=ap,
            fs=sampling_rate,
            frame_period=frame_period)
        # NOTE(review): librosa.output.write_wav was removed in librosa 0.8;
        # newer environments need e.g. soundfile.write instead.
        librosa.output.write_wav(output_path, wav_transformed, sampling_rate)

    if load_path is not None:
        model = CVAE(num_mcep, 128, num_mcep, 2)
        # map_location lets a CUDA checkpoint load on a CPU-only machine.
        model.load_state_dict(torch.load(load_path, map_location=device))
        model.eval()
        model.to(device)
        print("Loaded Model from path %s" % load_path)
        if val_A_dir is not None and gen_eval:
            print("Generating Evaluation Data")
            for file in os.listdir(val_A_dir):
                convert_A_to_B(
                    model, os.path.join(val_A_dir, file),
                    os.path.join(validation_A_output_dir,
                                 'eval_' + os.path.basename(file)))
            sys.exit(0)
        # NOTE(review): otherwise the loaded model is discarded and training
        # below starts from a fresh model — confirm that is intended.

    print("Begin Training")

    model = CVAE(num_mcep, 128, num_mcep, 2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    writer = SummaryWriter(tensorboard_dir)
    checkpoint_dir = os.path.join(model_dir, model_name)

    try:
        for epoch in tqdm(range(num_epochs)):
            dataset_A, dataset_B = sample_train_data(
                dataset_A=coded_sps_A_norm,
                dataset_B=coded_sps_B_norm,
                n_frames=n_frames)
            dataset_A = torch.tensor(dataset_A).to(torch.float)
            dataset_B = torch.tensor(dataset_B).to(torch.float)

            # One-hot condition labels of shape (n_samples, 2, depth):
            # channel 0 is set for every frame of A, channel 1 for B.
            n_samples, _, depth = dataset_A.shape
            y_A = torch.zeros(n_samples, 2, depth)
            y_A[:, 0, :] = 1.
            y_B = torch.zeros(n_samples, 2, depth)
            y_B[:, 1, :] = 1.

            X = torch.cat((dataset_A, dataset_B)).to(device)
            Y = torch.cat((y_A, y_B)).to(device)

            model.train()
            # Clear last iteration's gradients; without this they accumulate
            # over every step of training.
            optimizer.zero_grad()
            out, z_mu, z_var = model(X, Y)

            # NOTE(review): X holds mean/std-normalised features, which are
            # not confined to [0, 1] — confirm BCE is the intended loss.
            rec_loss = F.binary_cross_entropy(out, X, reduction='sum')
            # KL to N(0, I); the formula treats z_var as a log-variance.
            kl_diver = -0.5 * torch.sum(1 + z_var - z_mu.pow(2) - z_var.exp())

            loss = rec_loss + kl_diver

            writer.add_scalar('Reconstruction Loss', rec_loss, epoch)
            writer.add_scalar('KL-Divergence', kl_diver, epoch)
            writer.add_scalar('Total Loss', loss, epoch)

            loss.backward()
            optimizer.step()

            if epoch % 1000 == 0:
                if val_A_dir is not None:
                    print('Generating Validation Data...')
                    # Only one file is converted to keep validation cheap.
                    val_files = os.listdir(val_A_dir)
                    if val_files:
                        file = val_files[0]
                        convert_A_to_B(
                            model, os.path.join(val_A_dir, file),
                            os.path.join(
                                validation_A_output_dir,
                                str(epoch) + '_' + os.path.basename(file)))
                print('Saving Checkpoint')
                os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(model.state_dict(),
                           os.path.join(checkpoint_dir,
                                        '{0}.ckpt'.format(epoch)))
    finally:
        writer.close()
    train[train.columns[train.columns != "class"]]), np.array(
        pd.get_dummies(train["class"]))
# Test features reuse the *training* column selection (everything except
# "class") so both feature matrices share one column order; labels are
# one-hot encoded with pandas.
testx = np.array(test[train.columns[train.columns != "class"]])
testy = np.array(pd.get_dummies(test["class"]))

# Training hyper-parameters and dataset sizes.
batch_size = 512
max_epoch = 100
train_N = len(train)
test_N = len(test)

# Flip ``gpu`` to True to place the model on CUDA.
gpu = False
device = "cuda" if gpu else "cpu"

model = CVAE().to(device)
# NOTE(review): torch.optim.Adadelta's default lr is 1.0; 1e-3 is unusually
# small for this optimizer — confirm it is intended.
opt = optim.Adadelta(model.parameters(), lr=1e-3)


def Loss_function(x_hat, x, mu, logsimga):
    """Return the negative VAE evidence lower bound, summed over all elements.

    Args:
        x_hat: Reconstruction; binary cross-entropy requires values in [0, 1].
        x: Target tensor with the same shape as ``x_hat``.
        mu: Mean of the approximate posterior.
        logsimga: Log-variance of the approximate posterior (the KL formula
            below treats it as such; the misspelled name is kept so keyword
            callers keep working).

    Returns:
        Scalar tensor: summed binary cross-entropy plus KL(q(z|x) || N(0, I)).
    """
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    reconstruction_loss = F.binary_cross_entropy(x_hat, x, reduction='sum')
    kl_div = -0.5 * th.sum(1 + logsimga - mu.pow(2) - logsimga.exp())

    return reconstruction_loss + kl_div


def create_batch(x, y):
    """Shuffle ``x`` and ``y`` with one shared random permutation.

    NOTE(review): this copy of the function is truncated after ``batch_x = [``;
    only the shuffling step is visible here.
    """
    # A single permutation applied to both keeps rows of x and y aligned.
    a = list(range(len(x)))
    np.random.shuffle(a)
    # Indexing with a list of positions assumes NumPy-style arrays — confirm.
    x = x[a]
    y = y[a]
    batch_x = [
Beispiel #6
0
    torch.manual_seed(run)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # fetch data
    data = locate('data.get_%s' % args.dataset)(args)

    # make dataloaders
    train_loader, val_loader, test_loader  = [CLDataLoader(elem, args, train=t) for elem, t in zip(data, [True, False, False])]


    model = ResNet18(args.n_classes, nf=20, input_size=args.input_size).to(args.device)
    opt   = torch.optim.SGD(model.parameters(), lr=0.1)

    gen     = CVAE(20, args).cuda() # this is actually an autoencoder
    opt_gen = torch.optim.Adam(gen.parameters())

    # build buffer
    if args.store_latents:
        buffer = Buffer(args, input_size = (20*4*4,))
    else:
        buffer = Buffer(args)

    buffer.min_per_class = 0
    print('multiple heads ', args.multiple_heads)

    if run == 0:
        print("number of classifier parameters:", sum([np.prod(p.size()) for p in model.parameters()]))
        print("number of generator parameters: ", sum([np.prod(p.size()) for p in gen.parameters()]))
        print("buffer parameters:              ", np.prod(buffer.bx.size()))
Beispiel #7
0
class CVAEInterface():
    """Driver around a ``CVAE``: dataset loading, training, sampling and plotting.

    Args:
        run_id: Forwarded to ``CVAE(run_id=...)``.
        output_path: Directory for files written by :meth:`plot`. When
            non-empty, an existing directory at this path is DELETED and
            recreated. An empty string or ``None`` disables it.
        env_path_root: Directory containing ``<env_id>.txt`` files read by
            :meth:`visualize_map`.
    """

    def __init__(self, run_id=1, output_path="", env_path_root=""):
        super().__init__()
        self.cvae = CVAE(run_id=run_id)
        # NOTE(review): self.cvae is not moved to self.device here; presumably
        # CVAE does that itself -- confirm.
        self.device = torch.device('cuda' if CUDA_AVAILABLE else 'cpu')
        self.output_path = output_path
        self.env_path_root = env_path_root

        # Truthiness check: the default "" used to reach os.mkdir("") and
        # raise FileNotFoundError.
        if self.output_path:
            # Destructive: output from a previous run is wiped.
            if os.path.exists(self.output_path):
                shutil.rmtree(self.output_path)
            os.mkdir(self.output_path)

    def load_dataset(self, dataset_root, data_type="arm", mode="train"):
        """Load per-environment sample files and build the matching dataloader(s).

        Each sub-directory of ``dataset_root`` whose name starts with a digit is
        an environment; its ``data_<data_type>.txt`` is read with
        ``np.loadtxt``.

        Args:
            dataset_root: Directory with one sub-directory per environment.
            data_type: "arm", "base" or "both".
            mode: "train" sets ``self.train_dataloader`` and
                ``self.train_paths_dataset``. "test" sets
                ``self.test_condition_vars`` and, for "arm"/"base",
                ``self.test_dataloader``; for "both" it sets
                ``self.arm_test_dataloader`` and ``self.base_test_dataloader``.
        """
        assert (data_type == "both" or data_type == "arm"
                or data_type == "base")
        assert (mode == "train" or mode == "test")
        print("Loading {} dataset for mode : {}, path : {}".format(
            data_type, mode, dataset_root))
        self.data_type = data_type

        paths_dataset = PathsDataset(type="FULL_STATE")
        # Rows of [env_index, condition...] used to build the test loaders.
        all_condition_vars = []
        for env_dir_index in filter(lambda f: f[0].isdigit(),
                                    os.listdir(dataset_root)):
            env_paths, condition_vars = self._read_env_paths(
                dataset_root, env_dir_index, data_type, mode)
            # Training rows: [env_index, x (X_DIM columns), condition],
            # deduplicated per environment.
            env_paths = np.unique(np.hstack(
                (env_paths[:, :X_DIM], condition_vars)), axis=0)
            paths_dataset.add_env_paths(
                self._prepend_env_index(env_paths, env_dir_index).tolist())
            all_condition_vars += self._prepend_env_index(
                condition_vars, env_dir_index).tolist()
            print("Added {} states from {} environment".format(
                env_paths.shape[0], env_dir_index))

        if mode == "train":
            # Set for every data_type so visualize_train_data also works for
            # "both". test_condition_vars is deliberately left untouched here:
            # previously "both" overwrote it with training conditions.
            self.train_dataloader = DataLoader(paths_dataset,
                                               batch_size=TRAIN_BATCH_SIZE,
                                               shuffle=True)
            self.train_paths_dataset = paths_dataset
            return

        if data_type != "both":
            self.test_condition_vars = np.unique(all_condition_vars, axis=0)
            print("Unique test conditions count : {}".format(
                self.test_condition_vars.shape[0]))
            self.test_dataloader = self._make_condition_loader(
                self.test_condition_vars)
        else:
            # Drop columns 4 and 5 so conditions common to both datasets
            # collapse to one row, then re-insert fixed flag values at
            # 2 * POINT_DIM for each loader.
            # NOTE(review): presumably columns 4, 5 are the arm/base flags and
            # equal 2 * POINT_DIM, 2 * POINT_DIM + 1 -- confirm.
            self.test_condition_vars = np.unique(np.delete(
                np.array(all_condition_vars), [4, 5], axis=1), axis=0)
            print("Unique test conditions count : {}".format(
                self.test_condition_vars.shape[0]))
            self.arm_test_dataloader = self._make_condition_loader(
                self._insert_type_flags(self.test_condition_vars, (0, 1)))
            self.base_test_dataloader = self._make_condition_loader(
                self._insert_type_flags(self.test_condition_vars, (1, 0)))

    @staticmethod
    def _read_env_paths(dataset_root, env_dir_index, data_type, mode):
        """Load one environment's file; return (filtered rows, condition columns)."""
        # ndmin=2 keeps single-row files 2-D so the column slicing below works.
        env_paths = np.loadtxt(os.path.join(dataset_root, env_dir_index,
                                            "data_{}.txt".format(data_type)),
                               ndmin=2)
        if IGNORE_START:
            start = env_paths[:, X_DIM:2 * X_DIM]
            euc_dist = np.linalg.norm(start - env_paths[:, :X_DIM], axis=1)
            far_from_start = np.where(euc_dist > 5.0)
            print(far_from_start)
            env_paths = env_paths[far_from_start[0], :]
            return env_paths, env_paths[:, 2 * X_DIM:2 * X_DIM + C_DIM]
        if mode == "train":
            # Drop samples close to the start so they are less frequent in
            # the learned distribution.
            start = env_paths[:, X_DIM:X_DIM + POINT_DIM]
            euc_dist = np.linalg.norm(start - env_paths[:, :X_DIM], axis=1)
            env_paths = env_paths[euc_dist > 2.0]
        return env_paths, env_paths[:, X_DIM:X_DIM + C_DIM]

    @staticmethod
    def _prepend_env_index(rows, env_dir_index):
        """Prepend a column filled with the numeric environment index."""
        env_index = np.full((rows.shape[0], 1), float(env_dir_index))
        return np.hstack((env_index, rows))

    @staticmethod
    def _insert_type_flags(condition_vars, flags):
        """Insert ``flags`` as consecutive columns starting at 2 * POINT_DIM."""
        out = condition_vars
        # Inserting in reverse at a fixed index leaves flags in given order.
        for value in reversed(flags):
            out = np.insert(out, 2 * POINT_DIM, value, axis=1)
        return out

    @staticmethod
    def _make_condition_loader(condition_vars):
        """Repeat each condition TEST_SAMPLES times and wrap it in a loader."""
        dataset = PathsDataset(type="CONDITION_ONLY")
        dataset.add_env_paths(
            np.repeat(condition_vars, TEST_SAMPLES, 0).tolist())
        # shuffle=False: test() slices predictions per condition by position.
        return DataLoader(dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

    def visualize_train_data(self, num_conditions=1):
        """Plot all training samples of ``num_conditions`` random conditions.

        Requires a prior ``load_dataset(..., mode="train")``. Figures go to
        ``self.cvae.tboard`` under ``train_data/condition_<i>``.
        """
        print("Plotting input data for {} random conditions".format(
            num_conditions))
        rows = np.array(self.train_paths_dataset.paths)
        env_ids = rows[:, 0]
        states = rows[:, 1:1 + X_DIM]
        conditions = rows[:, 1 + X_DIM:]
        for c_i in range(num_conditions):
            rand_index = np.random.randint(0, rows.shape[0])
            condition = conditions[rand_index]
            # Exact row equality; np.isin only tested set membership, so rows
            # holding a permutation of `condition` were matched too.
            indices = np.where((conditions == condition).all(axis=1))[0]
            fig = self.plot(states[indices], condition,
                            env_id=env_ids[rand_index])
            self.cvae.tboard.add_figure('train_data/condition_{}'.format(c_i),
                                        fig, 0)
        self.cvae.tboard.flush()

    def visualize_map(self, env_id):
        """Draw the "wall"/"table" rectangles of ``<env_path_root>/<env_id>.txt``.

        Each matching line is split on single spaces; fields 1, 2 are taken as
        the centre and fields 4, 5 as the width/height of the rectangle.
        """
        path = "{}/{}.txt".format(self.env_path_root, int(env_id))
        plt.title('Environment - {}'.format(env_id))
        with open(path, "r") as f:
            for raw_line in f:
                fields = raw_line.split(" ")
                if "wall" in fields[0] or "table" in fields[0]:
                    x = float(fields[1])
                    y = float(fields[2])
                    l = float(fields[4])
                    b = float(fields[5])
                    plt.gca().add_patch(
                        Rectangle((x - l / 2, y - b / 2), l, b))
        plt.draw()

    def plot(self, x, c, env_id=None, suffix=0, write_file=False, show=False):
        """Plot samples and environment - from train input or predicted output.

        Args:
            x: Samples; columns 0 and 1 are drawn as x/y (green).
            c: Condition vector. With IGNORE_START, c[0:2] is the goal;
                otherwise c[0:2] is the start (blue) and c[2:4] the goal (red).
            env_id: If given, obstacles are drawn via :meth:`visualize_map`.
            suffix: Appended to output file names.
            write_file: Save the figure, ``x`` and start/goal points to
                ``self.output_path``.
            show: Call ``plt.show()``.

        Returns:
            The new matplotlib figure (left open for the caller).
        """
        if IGNORE_START:
            start = None
            goal = c[0:2]
        else:
            start = c[0:2]
            goal = c[2:4]
        fig1 = plt.figure(figsize=(10, 6), dpi=80)
        plt.scatter(x[:, 0], x[:, 1], color="green", s=70, alpha=0.1)
        if start is not None:
            plt.scatter(start[0], start[1], color="blue", s=70, alpha=0.6)
        plt.scatter(goal[0], goal[1], color="red", s=70, alpha=0.6)
        if env_id is not None:
            self.visualize_map(env_id)

        plt.xlabel('x')
        plt.ylabel('y')
        plt.xlim(0, X_MAX)
        plt.ylim(0, Y_MAX)
        if write_file:
            plt.savefig('{}/gen_points_fig_{}.png'.format(
                self.output_path, suffix))
            np.savetxt('{}/gen_points_{}.txt'.format(self.output_path, suffix),
                       x,
                       fmt="%.2f",
                       delimiter=',')
            # Without a start point only the goal is written; this used to
            # raise NameError because `start` was never bound.
            points = (np.atleast_2d(goal) if start is None else
                      np.vstack((start, goal)))
            np.savetxt('{}/start_goal_{}.txt'.format(self.output_path, suffix),
                       points,
                       fmt="%.2f",
                       delimiter=',')
        if show:
            plt.show()
        return fig1

    def load_saved_cvae(self, decoder_path):
        """Load saved decoder weights into ``self.cvae``."""
        print("Loading saved CVAE")
        self.cvae.load_decoder(decoder_path)

    def test_single(self,
                    env_id,
                    sample_size=1000,
                    c_test=None,
                    visualize=True):
        """Draw ``sample_size`` samples for a single condition vector.

        Args:
            env_id: Environment used when plotting.
            sample_size: Number of samples to draw.
            c_test: 1-D numpy condition vector (required despite the default).
            visualize: Plot the samples and write them to ``output_path``.

        Returns:
            Numpy array of generated samples.

        Raises:
            ValueError: if ``c_test`` is None.
        """
        if c_test is None:
            raise ValueError("c_test (a condition vector) is required")
        self.cvae.eval()
        c_test_gpu = torch.from_numpy(c_test).float().to(self.device)
        c_test_gpu = torch.unsqueeze(c_test_gpu, dim=0)
        x_test = self.cvae.inference(sample_size=sample_size, c=c_test_gpu)
        x_test = x_test.detach().cpu().numpy()

        if visualize:
            fig = self.plot(x_test,
                            c_test,
                            env_id=env_id,
                            show=False,
                            write_file=True,
                            suffix=0)
            # The figure is not returned, so close it to avoid leaking it.
            plt.close(fig)
        return x_test

    def test(self, epoch, dataloader, write_file=False, suffix="",
             num_epochs=None):
        """Sample the CVAE for every test condition and log one figure each.

        Args:
            epoch: Current epoch, used in log lines and figure tags.
            dataloader: Condition-only loader built by :meth:`load_dataset`
                (TEST_SAMPLES consecutive rows per condition, unshuffled).
            write_file: Forwarded to :meth:`plot`.
            suffix: Appended to the TensorBoard figure tag.
            num_epochs: Total epochs, shown in log lines when given. This
                replaces a reference to an undefined ``num_epochs`` name.
        """
        self.cvae.eval()
        num_batches = len(dataloader)
        epoch_label = ("{:02d}".format(epoch) if num_epochs is None else
                       "{:02d}/{:02d}".format(epoch, num_epochs))
        x_test_predicted = []
        with torch.no_grad():
            for iteration, batch in enumerate(dataloader):
                c_test_data = batch['condition'].float().to(self.device)
                x_test = self.cvae.batch_inference(c=c_test_data)
                x_test_predicted += x_test.detach().cpu().numpy().tolist()
                if (iteration % LOG_INTERVAL == 0
                        or iteration == num_batches - 1):
                    print("Test Epoch {} Batch {:04d}/{:d}, Iteration {}".
                          format(epoch_label, iteration, num_batches - 1,
                                 iteration))

        x_test_predicted = np.array(x_test_predicted)
        # One plot per unique condition, TEST_SAMPLES predictions each.
        for c_i in range(self.test_condition_vars.shape[0]):
            x_test = x_test_predicted[c_i * TEST_SAMPLES:(c_i + 1) *
                                      TEST_SAMPLES]
            # Column 0 is the environment id; the rest is only used for
            # plotting, so no arm/base flag is needed here.
            env_id = self.test_condition_vars[c_i, 0]
            c_test = self.test_condition_vars[c_i, 1:]
            fig = self.plot(x_test,
                            c_test,
                            env_id=env_id,
                            suffix=c_i,
                            write_file=write_file)
            self.cvae.tboard.add_figure(
                'test_epoch_{}/condition_{}_{}'.format(epoch, c_i, suffix),
                fig, 0)
            if c_i % LOG_INTERVAL == 0:
                print("Plotting condition : {}".format(c_i))
        self.cvae.tboard.flush()

    def train(self,
              run_id=1,
              num_epochs=1,
              initial_learning_rate=0.001,
              weight_decay=0.0001):
        """Train the CVAE with Adam, testing and checkpointing periodically.

        Args:
            run_id: Unused; kept for interface compatibility.
            num_epochs: Number of passes over ``self.train_dataloader``.
            initial_learning_rate: Adam learning rate.
            weight_decay: Adam weight decay.
        """
        optimizer = torch.optim.Adam(self.cvae.parameters(),
                                     lr=initial_learning_rate,
                                     weight_decay=weight_decay)
        num_batches = len(self.train_dataloader)
        counter = 0  # global step; defined even if the loader is empty
        for epoch in range(num_epochs):
            # test() switches to eval mode, so re-enable training each epoch.
            self.cvae.train()
            for iteration, batch in enumerate(self.train_dataloader):
                x = batch['state'].float().to(self.device)
                c = batch['condition'].float().to(self.device)
                recon_x, mean, log_var, _ = self.cvae(x, c)

                loss = self.cvae.loss_fn(recon_x, x, mean, log_var)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                counter = epoch * num_batches + iteration
                if (iteration % LOG_INTERVAL == 0
                        or iteration == num_batches - 1):
                    print(
                        "Train Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Iteration {}, Loss {:9.4f}"
                        .format(epoch, num_epochs, iteration, num_batches - 1,
                                counter, loss.item()))
                    self.cvae.tboard.add_scalar('train/loss', loss.item(),
                                                counter)

            if epoch % TEST_INTERVAL == 0 or epoch == num_epochs - 1:
                # Test CVAE for all c by drawing samples
                if self.data_type != "both":
                    self.test(epoch, self.test_dataloader,
                              num_epochs=num_epochs)
                else:
                    self.test(epoch, self.arm_test_dataloader, suffix="arm",
                              num_epochs=num_epochs)
                    self.test(epoch, self.base_test_dataloader, suffix="base",
                              num_epochs=num_epochs)

            if epoch % SAVE_INTERVAL == 0 and epoch > 0:
                self.cvae.save_model_weights(counter)

        self.cvae.save_model_weights('final')
Beispiel #8
0
def _ids_to_words(word_ids, vocab):
    """Map token ids to words, stopping after (and including) the first '<end>'."""
    words = []
    for word_id in word_ids:
        word = vocab.idx2word[word_id]
        words.append(word)
        if word == '<end>':
            break
    return words


def main(args):
    """Train the image-conditioned caption CVAE with KL annealing and early stopping.

    Reads from ``args``: model_path, vocab_path, train/val image dirs and
    caption paths, batch_size, num_workers, the CVAE hyper-parameters,
    learning_rate, num_epochs, log_step, save_step, anneal_function, k, x0,
    and optionally ``patience`` (defaults to 5 epochs).

    Checkpoints are saved every ``args.save_step`` steps and after every
    epoch. After each epoch the loss on the first 3 validation batches is
    computed; training stops after ``patience`` epochs without improvement.
    """
    # Create model directory
    os.makedirs(args.model_path, exist_ok=True)

    # Load vocabulary wrapper.
    # NOTE: pickle can execute arbitrary code; only load trusted vocab files.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    pad_idx = vocab.word2idx['<pad>']
    sos_idx = vocab.word2idx['<start>']
    eos_idx = vocab.word2idx['<end>']
    unk_idx = vocab.word2idx['<unk>']

    # Build data loader
    train_data_loader, valid_data_loader = get_loader(
        args.train_image_dir,
        args.val_image_dir,
        args.train_caption_path,
        args.val_caption_path,
        vocab,
        args.batch_size,
        shuffle=True,
        num_workers=args.num_workers)

    patience_limit = getattr(args, 'patience', 5)
    num_valid_batches = 3  # validation batches evaluated per epoch

    def kl_anneal_function(anneal_function, step, k, x0):
        """Return the KL weight at ``step`` (logistic or linear ramp to 1)."""
        if anneal_function == 'logistic':
            # expit(z) == 1 / (1 + exp(-z)), without overflow.
            return float(expit(k * (step - x0)))
        if anneal_function == 'linear':
            return min(1, step / x0)
        raise ValueError(
            "unknown anneal_function: {!r}".format(anneal_function))

    nll = torch.nn.NLLLoss(ignore_index=pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k,
                x0):
        """Return (NLL loss, KL divergence, KL weight) for one batch."""
        # Cut off padding beyond the longest sequence, then flatten.
        # `.item()` replaces `.data[0]`, which fails on the 0-dim tensor that
        # torch.max returns in PyTorch >= 0.4.
        max_len = int(torch.max(length).item())
        target = target[:, :max_len].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        nll_loss = nll(logp, target)
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, k, x0)
        return nll_loss, KL_loss, KL_weight

    # Build the models
    model = CVAE(vocab_size=len(vocab),
                 embedding_size=args.embedding_size,
                 rnn_type=args.rnn_type,
                 hidden_size=args.hidden_size,
                 word_dropout=args.word_dropout,
                 embedding_dropout=args.embedding_dropout,
                 latent_size=args.latent_size,
                 max_sequence_length=args.max_sequence_length,
                 num_layers=args.num_layers,
                 bidirectional=args.bidirectional,
                 pad_idx=pad_idx,
                 sos_idx=sos_idx,
                 eos_idx=eos_idx,
                 unk_idx=unk_idx)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    def batch_loss(images, captions, lengths, step):
        """Forward one batch; return (loss, logp, target captions)."""
        # Decoder input drops the last token, the target drops the first
        # (presumably '<start>'), so lengths shrink by one.
        images = images.to(device)
        captions_src = captions[:, :-1].to(device)
        captions_tgt = captions[:, 1:].to(device)
        lengths = (lengths - 1).to(device)

        logp, mean, logv, _ = model(images, captions_src, lengths)
        NLL_loss, KL_loss, KL_weight = loss_fn(logp, captions_tgt, lengths,
                                               mean, logv,
                                               args.anneal_function, step,
                                               args.k, args.x0)
        loss = (NLL_loss + KL_weight * KL_loss) / args.batch_size
        return loss, logp, captions_tgt

    total_step = len(train_data_loader)
    step_for_kl_annealing = 0
    best_valid_loss = float("inf")
    patience = 0

    for epoch in range(args.num_epochs):
        # Validation below switches to eval mode; restore training mode.
        model.train()
        for i, (images, captions, lengths) in enumerate(train_data_loader):
            loss, logp, captions_tgt = batch_loss(images, captions, lengths,
                                                  step_for_kl_annealing)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step_for_kl_annealing += 1

            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_step, loss.item(),
                            np.exp(loss.item())))
                # Show the last sample of the batch: reconstruction vs truth.
                outputs = model._sample(logp).cpu().numpy()
                reconstructed = ' '.join(_ids_to_words(outputs[-1], vocab))
                ground_truth = ' '.join(
                    _ids_to_words(captions_tgt.cpu().numpy()[-1], vocab))
                print("ground_truth: {0} \n reconstructed: {1}\n".format(
                    ground_truth, reconstructed))

            # Save the model checkpoints
            if (i + 1) % args.save_step == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(args.model_path,
                                 'model-{}-{}.ckpt'.format(epoch + 1, i + 1)))

        torch.save(
            model.state_dict(),
            os.path.join(args.model_path,
                         'model-{}-epoch.ckpt'.format(epoch + 1)))

        # Validation on the first few batches, in eval mode and without
        # autograd graphs (previously graphs were kept for every batch).
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for j, (images, captions, lengths) in enumerate(valid_data_loader):
                loss, _, _ = batch_loss(images, captions, lengths,
                                        step_for_kl_annealing)
                valid_loss += loss.item()
                if j == num_valid_batches - 1:
                    break
        print("validation loss for epoch {}: {}".format(epoch + 1, valid_loss))
        print("patience is at {}".format(patience))
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            patience = 0
        else:
            patience += 1

        if patience >= patience_limit:
            print("early stopping at epoch {}".format(epoch + 1))
            break
Beispiel #9
0
def main():
    """Evaluate a saved CVAE checkpoint on the test file and plot its predictions.

    Builds ``Settings(model='CVAE', ...)`` and loads the pickle
    ``config_CVAE.test_file`` from ``config_CVAE.path`` (relative to the
    current working directory). It restores the checkpoint
    ``config_CVAE.resume`` from ``<exp_dir>/<model_name>``, then calls
    ``predict_action`` ``num_output`` times for every element of ``data``.
    Finally it prints the average inference time and shows a scatter plot of
    targets (red) against CVAE predictions (black).

    Returns ``None`` early, after printing a message, if the configured
    optimizer is neither "AdamW" nor "AdamWR" or if the checkpoint file does
    not exist.

    The RNN/GPT comparison code is commented out; see the review notes in the
    body.
    """
    # config_RNN = Settings(model='RNN', model_name='9.23_dropout0.1_RNN', resume='best.pth')
    # config_GPT = Settings(model='GPT', model_name='10.3_GPT_dim8_layer6', resume='best.pth')
    config_CVAE = Settings(model='CVAE',
                           model_name='10.10_CVAE_dim16',
                           resume='best.pth')

    dataset_path = os.path.join(os.getcwd(), config_CVAE.path)
    dataset_filename = config_CVAE.test_file
    device = config_CVAE.device
    # model_dir_RNN = os.path.join(config_RNN.exp_dir, config_RNN.model_name)
    # NOTE(review): model_dir_GPT is never used and is built from the CVAE
    # config, not a GPT one.
    model_dir_GPT = os.path.join(config_CVAE.exp_dir, config_CVAE.model_name)
    model_dir_CVAE = os.path.join(config_CVAE.exp_dir, config_CVAE.model_name)

    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(os.path.join(dataset_path, dataset_filename), 'rb') as f:
        dataset = pickle.load(f)
    dataset = LightDarkDataset(config_CVAE, dataset)
    data, targets = collect_data(config_CVAE, dataset)

    # with open(os.path.join(dataset_path, 'light_dark_sample_len15.pickle'), 'rb') as f:
    #     sample = pickle.load(f)
    # data, targets = sample['data'], sample['targets']

    # model_RNN = RNN(config_RNN).to(device)
    # NOTE(review): model_GPT is built and moved to `device` but never used
    # below (its predict_action call is commented out).
    model_GPT = GPT2(config_CVAE).to(device)
    model_CVAE = CVAE(config_CVAE).to(device)

    # optimizer_RNN = th.optim.AdamW(model_RNN.parameters(),
    #                            lr=config_RNN.learning_rate,
    #                            weight_decay=config_RNN.weight_decay)
    # optimizer_GPT = th.optim.AdamW(model_GPT.parameters(),
    #                            lr=config_GPT.learning_rate,
    #                            weight_decay=config_GPT.weight_decay)
    # Optimizer/scheduler are only needed so load_checkpoint can restore them.
    optimizer_CVAE = th.optim.AdamW(model_CVAE.parameters(),
                                    lr=config_CVAE.learning_rate,
                                    weight_decay=config_CVAE.weight_decay)

    if config_CVAE.optimizer == 'AdamW':
        # scheduler_RNN = th.optim.lr_scheduler.LambdaLR(optimizer_RNN, lambda step: min((step+1)/config_RNN.warmup_step, 1))
        # scheduler_GPT = th.optim.lr_scheduler.LambdaLR(optimizer_GPT, lambda step: min((step+1)/config_GPT.warmup_step, 1))
        # Linear warm-up to the base learning rate over warmup_step steps.
        scheduler_CVAE = th.optim.lr_scheduler.LambdaLR(
            optimizer_CVAE, lambda step: min(
                (step + 1) / config_CVAE.warmup_step, 1))
    elif config_CVAE.optimizer == 'AdamWR':
        # scheduler_RNN = CosineAnnealingWarmUpRestarts(
        #     optimizer=optimizer_RNN,
        #     T_0=config_RNN.T_0,
        #     T_mult=config_RNN.T_mult,
        #     eta_max=config_RNN.lr_max,
        #     T_up=config_RNN.warmup_step,
        #     gamma=config_RNN.lr_mult
        # )
        # scheduler_GPT = CosineAnnealingWarmUpRestarts(
        #     optimizer=optimizer_GPT,
        #     T_0=config_GPT.T_0,
        #     T_mult=config_GPT.T_mult,
        #     eta_max=config_GPT.lr_max,
        #     T_up=config_GPT.warmup_step,
        #     gamma=config_GPT.lr_mult
        # )
        scheduler_CVAE = CosineAnnealingWarmUpRestarts(
            optimizer=optimizer_CVAE,
            T_0=config_CVAE.T_0,
            T_mult=config_CVAE.T_mult,
            eta_max=config_CVAE.lr_max,
            T_up=config_CVAE.warmup_step,
            gamma=config_CVAE.lr_mult)
    else:
        # |FIXME| using error?exception?logging?
        print(
            f'"{config_CVAE.optimizer}" is not support!! You should select "AdamW" or "AdamWR".'
        )
        return

    # load checkpoint for resuming
    if config_CVAE.resume is not None:
        # filename_RNN = os.path.join(model_dir_RNN, config_RNN.resume)
        # filename_GPT = os.path.join(model_dir_GPT, config_GPT.resume)
        filename_CVAE = os.path.join(model_dir_CVAE, config_CVAE.resume)

        # if os.path.isfile(filename_RNN):
        #     start_epoch_RNN, best_error_RNN, model_RNN, optimizer_RNN, scheduler_RNN = load_checkpoint(config_RNN, filename_RNN, model_RNN, optimizer_RNN, scheduler_RNN)
        #     start_epoch_RNN += 1
        #     print("[RNN]Loaded checkpoint '{}' (epoch {})".format(config_RNN.resume, start_epoch_RNN))
        # else:
        #     # |FIXME| using error?exception?logging?
        #     print("No checkpoint found at '{}'".format(config_RNN.resume))
        #     return

        # if os.path.isfile(filename_GPT):
        #     start_epoch_GPT, best_error_GPT, model_GPT, optimizer_GPT, scheduler_GPT = load_checkpoint(config_GPT, filename_GPT, model_GPT, optimizer_GPT, scheduler_GPT)
        #     start_epoch_GPT += 1
        #     print("[GPT]Loaded checkpoint '{}' (epoch {})".format(config_GPT.resume, start_epoch_GPT))
        # else:
        #     # |FIXME| using error?exception?logging?
        #     print("No checkpoint found at '{}'".format(config_GPT.resume))
        #     return

        if os.path.isfile(filename_CVAE):
            start_epoch_CVAE, best_error_CVAE, model_CVAE, optimizer_CVAE, scheduler_CVAE = load_checkpoint(
                config_CVAE, filename_CVAE, model_CVAE, optimizer_CVAE,
                scheduler_CVAE)
            start_epoch_CVAE += 1
            print("[CVAE]Loaded checkpoint '{}' (epoch {})".format(
                config_CVAE.resume, start_epoch_CVAE))
        else:
            # |FIXME| using error?exception?logging?
            print("No checkpoint found at '{}'".format(config_CVAE.resume))
            return

    # pred_RNN = []
    # NOTE(review): pred_GPT and total_time_GPT stay empty/zero because the
    # GPT prediction calls are commented out.
    pred_GPT = []
    pred_CVAE = []
    # total_time_RNN = 0.
    total_time_GPT = 0.
    total_time_CVAE = 0.
    # num_output predictions per input, each returning (prediction, seconds).
    for d in data:
        for i in range(config_CVAE.num_output):
            # tmp_pred_RNN, time_RNN = predict_action(config_RNN, model_RNN, d)
            # tmp_pred_GPT, time_GPT = predict_action(config_GPT, model_GPT, d)
            tmp_pred_CVAE, time_CVAE = predict_action(config_CVAE, model_CVAE,
                                                      d)

            # pred_RNN.append(tmp_pred_RNN)
            # pred_GPT.append(tmp_pred_GPT)
            pred_CVAE.append(tmp_pred_CVAE)
            # total_time_RNN += time_RNN
            # total_time_GPT += time_GPT
            total_time_CVAE += time_CVAE

    # Flatten to (N, 2) point arrays for plotting.
    targets = np.asarray(targets).reshape(-1, 2)
    # pred_RNN = np.asarray(pred_RNN).reshape(-1, 2)
    pred_GPT = np.asarray(pred_GPT).reshape(-1, 2)
    pred_CVAE = np.asarray(pred_CVAE).reshape(-1, 2)

    # print(f'Inference time for RNN: {total_time_RNN / (config_RNN.num_input * config_RNN.num_output)}')
    # print(f'Inference time for GPT: {total_time_GPT / (config_GPT.num_input * config_GPT.num_output)}')
    # NOTE(review): the loop ran len(data) * num_output predictions, but this
    # divides by num_input * num_output -- confirm len(data) == num_input.
    print(
        f'Inference time for CVAE: {total_time_CVAE / (config_CVAE.num_input * config_CVAE.num_output)}'
    )

    plt.xlim(-7, 7)
    plt.ylim(-7, 7)
    plt.scatter(targets[:, 0], targets[:, 1], c='red')
    # plt.scatter(pred_RNN[:,0], pred_RNN[:,1], c='green')
    # plt.scatter(pred_GPT[:,0], pred_GPT[:,1], c='blue')
    plt.scatter(pred_CVAE[:, 0], pred_CVAE[:, 1], c='black')
    plt.show()