def extract_reconstructions(encoder_input, style_mu, class_mu, class_logvar):
    grouped_mu, _ = accumulate_group_evidence(class_mu.data, class_logvar.data,
                                              torch.zeros(style_mu.size(0), 1))

    # optimize directly w.r.t. the decoder inputs (detached leaf tensors)
    decoder_style_input = style_mu.clone().detach().requires_grad_(True)
    decoder_content_input = grouped_mu[0].clone().detach().requires_grad_(True)

    optimizer = optim.Adam([decoder_style_input, decoder_content_input])

    for _ in range(50):
        optimizer.zero_grad()

        # rebuild the expanded content inside the loop so each backward()
        # gets a fresh graph from the leaf tensor (expanding once outside
        # the loop would fail on the second backward pass)
        content = decoder_content_input.expand(
            style_mu.size(0), decoder_content_input.size(0)).double()

        reconstructed = decoder(decoder_style_input, content)
        reconstruction_error = torch.sum(
            (reconstructed - encoder_input).pow(2))
        reconstruction_error.backward()

        optimizer.step()

    return reconstructed, reconstruction_error
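
A hedged usage sketch for the helper above: `encoder`, `decoder`, and `test_batch` are assumed to be in scope (none of them are defined in this snippet), and the encoder's posterior parameters are passed straight through:

# hypothetical driver for extract_reconstructions(); all names assumed
style_mu, style_logvar, class_mu, class_logvar = encoder(test_batch)
reconstructed, error = extract_reconstructions(
    test_batch, style_mu, class_mu, class_logvar)
print('final reconstruction error:', error.item())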
Example #2
  def forward(self, x, edge_index, batch, num_graphs):

    # batch_size = data.num_graphs
    if x is None:
        x = torch.ones(batch.shape[0]).to(device)

    node_mu, node_logvar, class_mu, class_logvar = self.encoder(x, edge_index, batch)
    grouped_mu, grouped_logvar = accumulate_group_evidence(
        class_mu.data, class_logvar.data, batch, True
    )

    # kl-divergence error for style latent space
    node_kl_divergence_loss = torch.mean(
        - 0.5 * torch.sum(1 + node_logvar - node_mu.pow(2) - node_logvar.exp())
    )
    node_kl_divergence_loss = 0.0000001 * node_kl_divergence_loss * num_graphs
    node_kl_divergence_loss.backward(retain_graph=True)

    # kl-divergence error for class latent space
    class_kl_divergence_loss = torch.mean(
        - 0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) - grouped_logvar.exp())
    )
    class_kl_divergence_loss = 0.0000001 * class_kl_divergence_loss * num_graphs
    class_kl_divergence_loss.backward(retain_graph=True)

    # reconstruct samples
    """
    sampling from group mu and logvar for each graph in mini-batch differently makes
    the decoder consider class latent embeddings as random noise and ignore them 
    """
    node_latent_embeddings = reparameterize(training=True, mu=node_mu, logvar=node_logvar)
    class_latent_embeddings = group_wise_reparameterize(
        training=True, mu=grouped_mu, logvar=grouped_logvar, labels_batch=batch, cuda=True
    )

    # need to reduce MI (mutual information) between node and class latents
    '''measure='JSD'
    mi_loss = local_global_loss_disen(node_latent_embeddings, class_latent_embeddings, edge_index, batch, measure)
    mi_loss.backward(retain_graph=True)'''

    reconstructed_node = self.decoder(node_latent_embeddings, class_latent_embeddings, edge_index)
    # check input features first
    # print('recon ', x[0], reconstructed_node[0])
    reconstruction_error = 0.1 * mse_loss(reconstructed_node, x) * num_graphs
    reconstruction_error.backward()

    
    return reconstruction_error.item(), class_kl_divergence_loss.item(), node_kl_divergence_loss.item()
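
The `reparameterize` and `group_wise_reparameterize` helpers are called throughout these examples but never shown. A minimal sketch of what they are assumed to do: the standard Gaussian reparameterization trick, and a group-wise variant that draws one sample per group and scatters it to the group's members (group ids assumed to be 0..G-1):

import torch

def reparameterize(training, mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); return the mean at eval time
    if not training:
        return mu
    std = (0.5 * logvar).exp()
    return mu + std * torch.randn_like(std)

def group_wise_reparameterize(training, mu, logvar, labels_batch, cuda=False):
    # one noise sample per group row, then index back to per-sample order;
    # assumes labels_batch holds group ids matching the rows of mu/logvar
    if training:
        z = mu + (0.5 * logvar).exp() * torch.randn_like(mu)
    else:
        z = mu
    return z[labels_batch.long().view(-1)]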
Example #3
def run_through_network(X, labels_batch):
    style_mu, style_logvar, class_mu, class_logvar = encoder(Variable(X))
    grouped_mu, grouped_logvar = accumulate_group_evidence(
        class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda
    )
    importance_sampling(X[0], style_mu[0], style_logvar[0], class_mu[0], class_logvar[0])

    # reconstruct samples
    style_latent_embeddings = reparameterize(training=True, mu=style_mu, logvar=style_logvar)
    class_latent_embeddings = group_wise_reparameterize(
        training=True, mu=grouped_mu, logvar=grouped_logvar, labels_batch=labels_batch, cuda=FLAGS.cuda
    )

    reconstructed_images = decoder(style_latent_embeddings, class_latent_embeddings)

    return reconstructed_images
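
`accumulate_group_evidence` is likewise assumed rather than shown. In group-level VAEs of this kind it typically fuses the per-sample class posteriors within a group via a product of Gaussians, i.e. precisions add and the fused mean is precision-weighted; a sketch under that assumption (group ids again taken to be 0..G-1):

import torch

def accumulate_group_evidence(class_mu, class_logvar, labels_batch, is_cuda=False):
    # product of Gaussians: total precision = sum of member precisions,
    # fused mean = precision-weighted mean of member means
    labels = labels_batch.long().view(-1)
    n_groups = int(labels.max().item()) + 1
    inv_var = torch.exp(-class_logvar)

    grouped_mu = class_mu.new_zeros(n_groups, class_mu.size(1))
    grouped_logvar = class_mu.new_zeros(n_groups, class_mu.size(1))
    for g in range(n_groups):
        members = labels == g
        precision = inv_var[members].sum(dim=0)
        grouped_logvar[g] = -torch.log(precision)
        grouped_mu[g] = (class_mu[members] * inv_var[members]).sum(dim=0) / precision
    return grouped_mu, grouped_logvar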
Example #4
def process(FLAGS, X, labels_batch, encoder, decoder):

    style_mu, style_logvar, class_mu, class_logvar = encoder(X.cuda())

    content_mu, content_logvar, list_g, sizes_group = \
            accumulate_group_evidence(FLAGS, class_mu, class_logvar,
                                      labels_batch, FLAGS.cuda)
    style_latent_embeddings = reparameterize(training=True,
                                             mu=style_mu,
                                             logvar=style_logvar)

    class_latent_embeddings, indexes, sizes = group_wise_reparameterize_each(
        training=True,
        mu=content_mu,
        logvar=content_logvar,
        labels_batch=labels_batch,
        list_groups_labels=list_g,
        sizes_group=sizes_group,
        cuda=FLAGS.cuda)

    # kl-divergence error for style latent space
    style_kl_divergence_loss = 0.5 * (-1 - style_logvar[indexes, :] +
                                      style_mu[indexes, :].pow(2) +
                                      style_logvar[indexes, :].exp()).sum()
    # kl-divergence error for class latent space
    class_kl_divergence_loss = 0.5 * (-1 - content_logvar + content_mu.pow(2) +
                                      content_logvar.exp()).sum()

    # reconstruct samples
    #reorder by the same order as class_latent_embeddings
    mu_x, logvar_x = decoder(style_latent_embeddings[indexes, :],
                             class_latent_embeddings)
    scale_x = (torch.exp(logvar_x) + 1e-12)**0.5
    scale_x = scale_x.view(X.size(0), 784)
    # create normal distribution on output pixel
    mu_x = mu_x.view(X.size(0), 784)
    prob_x = Normal(mu_x, scale_x)
    logp_batch = prob_x.log_prob(X[indexes, :].view(X.size(0), 784)).sum(1)

    reconstruction_proba = logp_batch.sum(0)
    n_groups = content_mu.size(0)
    elbo = (reconstruction_proba - style_kl_divergence_loss -
            class_kl_divergence_loss) / n_groups

    return elbo, reconstruction_proba / n_groups, style_kl_divergence_loss / n_groups, class_kl_divergence_loss / n_groups
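
Since `process` returns an ELBO (higher is better), a hypothetical training step would minimize its negation; `encoder`, `decoder`, `X`, and `labels_batch` are assumed to exist:

# illustrative training step around process(); not from the original source
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),
                       lr=FLAGS.initial_learning_rate)
optimizer.zero_grad()
elbo, recon, style_kl, class_kl = process(FLAGS, X, labels_batch, encoder, decoder)
(-elbo).backward()  # gradient ascent on the ELBO
optimizer.step()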
Example #5
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
    """
    variable definition
    """
    X = torch.FloatTensor(FLAGS.batch_size, 1, FLAGS.image_size,
                          FLAGS.image_size)
    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        X = X.cuda()
    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))
    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)
    loader = cycle(
        DataLoader(mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(int(len(mnist) / FLAGS.batch_size)):
            # load a mini-batch
            image_batch, labels_batch = next(loader)

            # set zero_grad for the optimizer
            auto_encoder_optimizer.zero_grad()

            X.copy_(image_batch)

            style_mu, style_logvar, class_mu, class_logvar = encoder(
                Variable(X))
            grouped_mu, grouped_logvar = accumulate_group_evidence(
                class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda)

            # kl-divergence error for style latent space
            style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                 style_logvar.exp()))
            style_kl_divergence_loss /= (FLAGS.batch_size *
                                         FLAGS.num_channels *
                                         FLAGS.image_size * FLAGS.image_size)
            style_kl_divergence_loss.backward(retain_graph=True)

            # kl-divergence error for class latent space
            class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                 grouped_logvar.exp()))
            class_kl_divergence_loss /= (FLAGS.batch_size *
                                         FLAGS.num_channels *
                                         FLAGS.image_size * FLAGS.image_size)
            class_kl_divergence_loss.backward(retain_graph=True)

            # reconstruct samples
            """
            sampling from group mu and logvar for each image in mini-batch differently makes
            the decoder consider class latent embeddings as random noise and ignore them 
            """
            style_latent_embeddings = reparameterize(training=True,
                                                     mu=style_mu,
                                                     logvar=style_logvar)
            class_latent_embeddings = group_wise_reparameterize(
                training=True,
                mu=grouped_mu,
                logvar=grouped_logvar,
                labels_batch=labels_batch,
                cuda=FLAGS.cuda)

            reconstructed_images = decoder(style_latent_embeddings,
                                           class_latent_embeddings)

            reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_images, Variable(X))
            reconstruction_error.backward()

            auto_encoder_optimizer.step()

            if (iteration + 1) % 50 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' +
                      str(reconstruction_error.item()))
                print('Style KL-Divergence loss: ' +
                      str(style_kl_divergence_loss.item()))
                print('Class KL-Divergence loss: ' +
                      str(class_kl_divergence_loss.item()))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    epoch, iteration,
                    reconstruction_error.item(),
                    style_kl_divergence_loss.item(),
                    class_kl_divergence_loss.item()))

            # write to tensorboard
            writer.add_scalar(
                'Reconstruction loss', reconstruction_error.item(),
                epoch * (int(len(mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar(
                'Style KL-Divergence loss', style_kl_divergence_loss.item(),
                epoch * (int(len(mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar(
                'Class KL-Divergence loss', class_kl_divergence_loss.item(),
                epoch * (int(len(mnist) / FLAGS.batch_size) + 1) + iteration)

        # save checkpoints after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
            print()
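
The closed-form KL term against a standard normal recurs in nearly every example above; factoring it out would look like this (a sketch, not part of the original code):

def kl_standard_normal(mu, logvar):
    # KL( N(mu, exp(logvar)) || N(0, I) ), summed over all elements
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())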
Example #6
            X_i = test_data[i]  # 1-d vector
            l = X_i.size(0)

            errors = []

            for eta in range(10, 90):
                print('\tRunning eta =', eta)
                g1 = X_i[0:eta].view(eta, -1).cuda()
                g2 = X_i[eta:l].view(l-eta, -1).cuda()

                total_error = 0
                for g in [g1, g2]:
                    style_mu, _, class_mu, class_logvar = encoder(g)
                    grouped_mu, _ = accumulate_group_evidence(
                        class_mu.data, class_logvar.data, torch.zeros(g.size(0), 1)
                    )

                    decoder_style_input = style_mu.clone().detach().to('cuda').requires_grad_(True)
                    decoder_content_input = grouped_mu[0].clone().detach().to('cuda').requires_grad_(True)

                    # expand the group mean across the group, as in Example #1
                    content = decoder_content_input.expand(
                        g.size(0), decoder_content_input.size(0))

                    optimizer = optim.Adam(
                        [decoder_style_input, decoder_content_input],
                        lr=0.01  # this may be an important parameter
                    )

                    for iterations in range(500):
                        optimizer.zero_grad()
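
The snippet is truncated at the start of the inner loop. Judging from Example #1, the elided body presumably reconstructs from the fitted latents and back-propagates the squared error; a sketch of that assumption:

# assumed loop body, mirroring Example #1 (not in the original snippet)
reconstructed = decoder(decoder_style_input, content)
reconstruction_error = torch.sum((reconstructed - g).pow(2))
reconstruction_error.backward()
optimizer.step()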
Example #7
        DataLoader(paired_mnist,
                   batch_size=FLAGS.num_test_samples,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    image_batch, _, labels_batch = next(loader)

    style_mu, style_logvar, class_mu, class_logvar = encoder(
        Variable(image_batch))
    style_latent_embeddings = reparameterize(training=True,
                                             mu=style_mu,
                                             logvar=style_logvar)

    if FLAGS.accumulate_evidence:
        grouped_mu, grouped_logvar = accumulate_group_evidence(
            class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda)

        class_latent_embeddings = group_wise_reparameterize(
            training=True,
            mu=grouped_mu,
            logvar=grouped_logvar,
            labels_batch=labels_batch,
            cuda=FLAGS.cuda)
    else:
        class_latent_embeddings = reparameterize(training=True,
                                                 mu=class_mu,
                                                 logvar=class_logvar)

    # perform t-SNE embedding
    vis_data = TSNE(n_components=2, verbose=1, perplexity=30.0,
                    n_iter=1000).fit_transform(
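
The `fit_transform` call is cut off above. A typical completion, plus a scatter plot of the 2-D embedding, might look like this (the argument and the plotting code are assumptions):

# assumed completion of the truncated t-SNE call, plus a plot sketch
import matplotlib.pyplot as plt

vis_data = TSNE(n_components=2, verbose=1, perplexity=30.0,
                n_iter=1000).fit_transform(
                    class_latent_embeddings.detach().cpu().numpy())
plt.scatter(vis_data[:, 0], vis_data[:, 1],
            c=labels_batch.cpu().numpy(), s=5, cmap='tab10')
plt.savefig('tsne_class_latents.png')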
Example #8
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
    """
    variable definition
    """
    X = torch.FloatTensor(FLAGS.batch_size, 784)
    '''
    run on GPU if GPU is available
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder.to(device=device)
    decoder.to(device=device)
    X = X.to(device=device)
    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))
    """
    
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    dirs = os.listdir(os.path.join(os.getcwd(), 'data'))
    print('Loading double multivariate normal time series data...')
    for dsname in dirs:
        params = dsname.split('_')
        if params[2] == 'theta=-1':
            print('Running dataset ', dsname)
            ds = DoubleMulNormal(dsname)
            # ds = experiment3(1000, 50, 3)
            loader = cycle(
                DataLoader(ds,
                           batch_size=FLAGS.batch_size,
                           shuffle=True,
                           drop_last=True))

            # initialize summary writer
            writer = SummaryWriter()

            for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
                print()
                print(
                    'Epoch #' + str(epoch) +
                    '........................................................')

                # the total loss at each epoch after running iterations of batches
                total_loss = 0

                for iteration in range(int(len(ds) / FLAGS.batch_size)):
                    # load a mini-batch
                    image_batch, labels_batch = next(loader)

                    # set zero_grad for the optimizer
                    auto_encoder_optimizer.zero_grad()

                    X.copy_(image_batch)

                    style_mu, style_logvar, class_mu, class_logvar = encoder(
                        Variable(X))
                    grouped_mu, grouped_logvar = accumulate_group_evidence(
                        class_mu.data, class_logvar.data, labels_batch,
                        FLAGS.cuda)

                    # kl-divergence error for style latent space
                    style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                         style_logvar.exp()))
                    style_kl_divergence_loss /= (FLAGS.batch_size *
                                                 FLAGS.num_channels *
                                                 FLAGS.image_size *
                                                 FLAGS.image_size)
                    style_kl_divergence_loss.backward(retain_graph=True)

                    # kl-divergence error for class latent space
                    class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 *
                        torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                  grouped_logvar.exp()))
                    class_kl_divergence_loss /= (FLAGS.batch_size *
                                                 FLAGS.num_channels *
                                                 FLAGS.image_size *
                                                 FLAGS.image_size)
                    class_kl_divergence_loss.backward(retain_graph=True)

                    # reconstruct samples
                    """
                    sampling from group mu and logvar for each image in mini-batch differently makes
                    the decoder consider class latent embeddings as random noise and ignore them 
                    """
                    style_latent_embeddings = reparameterize(
                        training=True, mu=style_mu, logvar=style_logvar)
                    class_latent_embeddings = group_wise_reparameterize(
                        training=True,
                        mu=grouped_mu,
                        logvar=grouped_logvar,
                        labels_batch=labels_batch,
                        cuda=FLAGS.cuda)

                    reconstructed_images = decoder(style_latent_embeddings,
                                                   class_latent_embeddings)

                    reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                        reconstructed_images, Variable(X))
                    reconstruction_error.backward()

                    total_loss += style_kl_divergence_loss + class_kl_divergence_loss + reconstruction_error

                    auto_encoder_optimizer.step()

                    if (iteration + 1) % 50 == 0:
                        print('\tIteration #' + str(iteration))
                        print('Reconstruction loss: ' +
                              str(reconstruction_error.item()))
                        print('Style KL loss: ' +
                              str(style_kl_divergence_loss.item()))
                        print('Class KL loss: ' +
                              str(class_kl_divergence_loss.item()))

                    # write to log
                    with open(FLAGS.log_file, 'a') as log:
                        log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                            epoch, iteration,
                            reconstruction_error.item(),
                            style_kl_divergence_loss.item(),
                            class_kl_divergence_loss.item()))

                    # write to tensorboard
                    writer.add_scalar(
                        'Reconstruction loss', reconstruction_error.item(),
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)
                    writer.add_scalar(
                        'Style KL-Divergence loss',
                        style_kl_divergence_loss.item(),
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)
                    writer.add_scalar(
                        'Class KL-Divergence loss',
                        class_kl_divergence_loss.item(),
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)

                    if epoch == 0 and (iteration + 1) % 50 == 0:
                        torch.save(
                            encoder.state_dict(),
                            os.path.join('checkpoints', 'encoder_' + dsname))
                        torch.save(
                            decoder.state_dict(),
                            os.path.join('checkpoints', 'decoder_' + dsname))

                # save checkpoints after every 10 epochs
                if (epoch + 1) % 10 == 0 or (epoch + 1) == FLAGS.end_epoch:
                    torch.save(
                        encoder.state_dict(),
                        os.path.join('checkpoints', 'encoder_' + dsname))
                    torch.save(
                        decoder.state_dict(),
                        os.path.join('checkpoints', 'decoder_' + dsname))

                print('Total loss at current epoch: ', total_loss.item())
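
Both training loops draw batches with `next(loader)` from a `cycle`-wrapped DataLoader. Assuming `cycle` is the usual re-iterating generator (rather than `itertools.cycle`, which would cache and replay the first epoch's batches without reshuffling), a minimal sketch:

def cycle(iterable):
    # restart the DataLoader each epoch so shuffling takes effect
    while True:
        for item in iterable:
            yield item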
Example #9
    specified_factor_images = []
    for i in range(1, 11):
        specified_factor_images.append(image_array[0][i][:, :, 0])

    specified_factor_images = np.asarray(specified_factor_images)
    specified_factor_images = np.expand_dims(specified_factor_images, axis=3)
    specified_factor_images = np.transpose(specified_factor_images,
                                           (0, 3, 1, 2))
    specified_factor_images = torch.FloatTensor(specified_factor_images)
    specified_factor_images = specified_factor_images.contiguous()

    if FLAGS.accumulate_evidence:
        # sample a big batch, accumulate evidence and use that for class embeddings
        image_batch, _, labels_batch = next(loader)
        _, __, class_mu, class_logvar = encoder(image_batch)
        content_mu, content_logvar, list_g, sizes_group = accumulate_group_evidence(
            FLAGS, class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda)

    # generate all possible combinations using the encoder and decoder architecture in the grid
    for row in range(1, 11):
        style_image = image_array[row][0]
        style_image = np.transpose(style_image, (2, 0, 1))
        style_image = torch.FloatTensor(style_image)
        style_image = style_image.contiguous()
        style_image = style_image[0, :, :]
        style_image = style_image.view(1, 1, 28, 28)
        style_mu, style_logvar, _, _ = encoder(Variable(style_image))
        style_latent_embeddings = reparameterize(training=True,
                                                 mu=style_mu,
                                                 logvar=style_logvar)

        for col in range(1, 11):