Example #1
    def train(self, data_loader):
        print('Training...')
        with torch.autograd.set_detect_anomaly(True):
            self.epoch += 1
            self.set_train()
            record_trace = utils.Record()
            record_image = utils.Record()
            record_inter = utils.Record()
            record_kld = utils.Record()
            start_time = time.time()
            progress = progressbar.ProgressBar(maxval=len(data_loader)).start()
            for i, (trace, image) in enumerate(data_loader):
                progress.update(i + 1)
                trace = trace.cuda()
                image = image.cuda()
                self.zero_grad()
                trace_embed = self.TraceEncoder(trace)
                image_embed = self.ImageEncoder(image)
                # NOTE: the same embedding is reused as both mean and log-variance
                trace_mu, trace_logvar = trace_embed, trace_embed
                image_mu, image_logvar = image_embed, image_embed
                trace_z = utils.reparameterize(trace_mu, trace_logvar)
                image_z = utils.reparameterize(image_mu, image_logvar)
                trace2image, trace_inter = self.Decoder(trace_z)
                image2image, image_inter = self.Decoder(image_z)

                err_trace = self.l1(trace2image, image)
                err_image = self.l1(image2image, image)
                #err_inter = self.l2(trace_inter, image_inter)
                err_kld = self.kld(image_mu, image_logvar, trace_mu,
                                   trace_logvar)

                #(err_trace + err_image + err_inter + self.args['beta'] * err_kld).backward()
                (err_trace + err_image +
                 self.args['beta'] * err_kld).backward()

                self.optimizer.step()

                record_trace.add(err_trace)
                record_image.add(err_image)
                #record_inter.add(err_inter)
                record_kld.add(err_kld)
            progress.finish()
            utils.clear_progressbar()
            print('----------------------------------------')
            print('Epoch: %d' % self.epoch)
            print('Costs time: %.2fs' % (time.time() - start_time))
            print('Loss of Trace to Image: %f' % (record_trace.mean()))
            print('Loss of Image to Image: %f' % (record_image.mean()))
            print('Loss of KL-Divergence: %f' % (record_kld.mean()))
            print('----------------------------------------')
            utils.save_image(image.data, ('%s/image/train/target_%03d.jpg' %
                                          (self.args['vae_dir'], self.epoch)))
            utils.save_image(trace2image.data,
                             ('%s/image/train/tr2im_%03d.jpg' %
                              (self.args['vae_dir'], self.epoch)))
            utils.save_image(image2image.data,
                             ('%s/image/train/im2im_%03d.jpg' %
                              (self.args['vae_dir'], self.epoch)))
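
For reference: throughout these examples, reparameterize implements the standard reparameterization trick. A minimal sketch of the common PyTorch form (each repository's helper may differ, e.g. by a training flag or a log-std convention):

import torch

def reparameterize(mu, logvar):
    """Differentiable sample z ~ N(mu, diag(sigma^2)): z = mu + sigma * eps.

    A minimal reference sketch, not any one repository's exact helper.
    """
    std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2), so sigma = exp(logvar / 2)
    eps = torch.randn_like(std)    # noise drawn outside the computation graph
    return mu + eps * std          # gradients flow through mu and logvar only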
Example #2
    def forward(self,
                goals,
                goals_length,
                posts,
                posts_length,
                origin_responses=None):
        goal_output, _ = self.goal_encoder(goals)  # [B, G, H]
        goal_h = batch_gather_3_1(goal_output, goals_length)  # [B, H]

        batchsz, max_sen, max_word = posts.shape
        post_flat = posts.view(batchsz * max_sen, max_word)
        post_output_flat, _ = self.sys_encoder(post_flat)
        post_output = post_output_flat.view(batchsz, max_sen, max_word,
                                            -1)  # [B, S, P, H]
        post_h = batch_gather_4_2(post_output, posts_length)  # [B, S, H]

        context_output, _ = self.context_encoder(
            post_h, goal_h.unsqueeze(0))  # [B, S, H]
        posts_sen_length = posts_length.gt(0).sum(1)  # [B]

        context = batch_gather_3_1(context_output, posts_sen_length)  # [B, H]
        mu, logvar = self.mu_net(context), self.logvar_net(context)
        last_context = batch_gather_3_1(context_output, posts_sen_length - 1)
        mu_last, logvar_last = self.mu_net_last(
            last_context), self.logvar_net_last(last_context)
        z = reparameterize(mu_last, logvar_last)
        hidden = self.concat_net(torch.cat([context, z], dim=1))

        teacher = 1 if origin_responses is not None else 0
        a_weights, _, _ = self.usr_decoder(inputs=origin_responses,
                                           encoder_hidden=hidden.unsqueeze(0),
                                           teacher_forcing_ratio=teacher)
        t_weights = self.terminal_net(context).squeeze(1)

        return a_weights, t_weights, (mu_last, logvar_last, mu, logvar)
def do_epoch_bnn(model, dataloader, criterion, optim=None, T=1):
    total_loss = 0
    total_accuracy = 0
    for x, y_true in tqdm(dataloader, leave=False):
        x, y_true = x.to(device), y_true.to(device)

        # y_prob = 0.
        # for t in range(T):
        #     y_pred, s_pred = model(x)
        #     y_prob += 1./T * F.softmax(reparameterize(y_pred, s_pred), dim=1)
        # loss = F.nll_loss(torch.log(y_prob), y_true)
        # # True Bayesian network should average over probabilities (T times),
        # # however, logit with CrossEntropyLoss is more stable than log(softmax) with NLLLoss
        y_logit = 0.
        for t in range(T):
            y_pred, s_pred = model(x)
            y_logit += 1. / T * reparameterize(y_pred, s_pred)
        loss = criterion(y_logit, y_true)

        if optim is not None:
            optim.zero_grad()
            loss.backward()
            optim.step()

        total_loss += loss.item()
        total_accuracy += (y_logit.max(1)[1] == y_true).float().mean().item()
    mean_loss = total_loss / len(dataloader)
    mean_accuracy = total_accuracy / len(dataloader)

    return mean_loss, mean_accuracy
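
The commented-out block in do_epoch_bnn records the trade-off: a true Bayesian average is over probabilities, but averaging reparameterized logits into CrossEntropyLoss is more stable than log(softmax) into NLLLoss. If averaging probabilities is required, a numerically safer variant is to average in log space with logsumexp; a sketch, assuming model(x) returns mean and log-variance logits as above:

import math

import torch
import torch.nn.functional as F

def mc_log_prob(model, x, T=10):
    """log of the Monte-Carlo average of softmax probabilities over T samples.

    Computes log((1/T) * sum_t softmax(z_t)) as logsumexp(log_softmax(z_t)) - log(T),
    which avoids the underflow of torch.log(y_prob). Sketch only; reuses the
    reparameterize helper from the examples.
    """
    log_probs = []
    for _ in range(T):
        y_pred, s_pred = model(x)            # mean and log-variance logits
        z = reparameterize(y_pred, s_pred)   # one reparameterized draw
        log_probs.append(F.log_softmax(z, dim=1))
    stacked = torch.stack(log_probs)         # [T, B, C]
    return torch.logsumexp(stacked, dim=0) - math.log(T)

# usage: loss = F.nll_loss(mc_log_prob(model, x, T), y_true)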
Example #4
    def get_embeddings(self, loader):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        ret = []
        y = []
        with torch.no_grad():
            for data in loader:
                data.to(device)
                x, edge_index, batch = data.x, data.edge_index, data.batch
                if x is None:
                    x = torch.ones((batch.shape[0], 1)).to(device)
                node_mu, node_logvar, class_mu, class_logvar = self.encoder(x, edge_index, batch)

                '''grouped_mu, grouped_logvar = accumulate_group_evidence(
                    class_mu.data, class_logvar.data, batch, True
                )

                accumulated_class_latent_embeddings = group_wise_reparameterize(
                    training=False, mu=grouped_mu, logvar=grouped_logvar, labels_batch=batch, cuda=True
                )

                class_emb = global_mean_pool(accumulated_class_latent_embeddings, batch)'''

                node_latent_embeddings = reparameterize(training=True, mu=node_mu, logvar=node_logvar)

                ret.append(node_latent_embeddings.cpu().numpy())
                y.append(data.y.cpu().numpy())
        ret = np.concatenate(ret, 0)
        y = np.concatenate(y, 0)
        return ret, y
Example #5
    def forward(self, x):
        mu, logsigma, classcode = self.encoder(x)
        contentcode = reparameterize(mu, logsigma)
        latentcode = torch.cat([contentcode, classcode], dim=1)

        recon_x = self.decoder(latentcode)

        return mu, logsigma, classcode, recon_x
def backward_loss(x, model, device):
    mu, logsigma, classcode = model.encoder(x)
    shuffled_classcode = classcode[torch.randperm(classcode.shape[0])]
    randcontent = torch.randn_like(mu).to(device)

    latentcode1 = torch.cat([randcontent, classcode], dim=1)
    latentcode2 = torch.cat([randcontent, shuffled_classcode], dim=1)

    recon_imgs1 = model.decoder(latentcode1).detach()
    recon_imgs2 = model.decoder(latentcode2).detach()

    cycle_mu1, cycle_logsigma1, cycle_classcode1 = model.encoder(recon_imgs1)
    cycle_mu2, cycle_logsigma2, cycle_classcode2 = model.encoder(recon_imgs2)

    cycle_contentcode1 = reparameterize(cycle_mu1, cycle_logsigma1)
    cycle_contentcode2 = reparameterize(cycle_mu2, cycle_logsigma2)

    bloss = F.l1_loss(cycle_contentcode1, cycle_contentcode2)
    return bloss
Example #7
def train_bayes(train_loader, model, optimizer, scheduler, epoch, device):
    """Train for one epoch on the training set"""
    scheduler.step()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, label) in enumerate(train_loader):
        label = label.to(device)
        input = input.to(device)

        # compute output
        y_softmax = 0.
        for t in range(args.T_train):
            y_pred, s_pred = model(input)
            y_softmax += 1. / args.T_train * F.softmax(
                reparameterize(y_pred, s_pred), dim=1)
        loss = RegCrossEntropyLoss_bayes(y_softmax, label, args.mr_weight_l2,
                                         args.mr_weight_negent,
                                         args.mr_weight_kld, args.num_class)

        # measure accuracy and record loss
        prec1 = accuracy(y_softmax.data, label, topk=(1, ))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}], '
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f}), '
                  'Loss {loss.val:.4f} ({loss.avg:.4f}), '
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch,
                      i,
                      len(train_loader),
                      batch_time=batch_time,
                      loss=losses,
                      top1=top1))
    # log to TensorBoard
    if args.tensorboard:
        log_value('train_loss', losses.avg, epoch)
        log_value('train_acc', top1.avg, epoch)
Example #8
def validate_bayes(val_loader, model, epoch, device):
    """Perform validation on the validation set"""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, label) in enumerate(val_loader):
            label = label.to(device)
            input = input.to(device)

            # compute output
            y_softmax = 0.
            for t in range(args.T):
                y_pred, s_pred = model(input)
                y_softmax += 1. / args.T * F.softmax(
                    reparameterize(y_pred, s_pred), dim=1)
            loss = RegCrossEntropyLoss_bayes(y_softmax, label,
                                             args.mr_weight_l2,
                                             args.mr_weight_negent,
                                             args.mr_weight_kld,
                                             args.num_class)

            # measure accuracy and record loss
            prec1 = accuracy(y_softmax.data, label, topk=(1, ))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}], '
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f}), '
                      'Loss {loss.val:.4f} ({loss.avg:.4f}), '
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
    # log to TensorBoard
    if args.tensorboard:
        log_value('val_loss', losses.avg, epoch)
        log_value('val_acc', top1.avg, epoch)
    return top1.avg
def sparseTopicTransformer_loss(encoder, decoder, doc_batch):
    mus, log_sigmas = encoder(doc_batch, device)
    samples = utils.reparameterize(mus, log_sigmas)
    # print('\t\tmus.shape = ', mus.shape,' log_sigma.shape = ', log_sigmas.shape)
    thetas = samples.softmax(dim=1)
    hidden_states, labels = decoder(doc_batch, device)
    print('\t\thidden_states.shape = ', hidden_states.shape)
    print('\t\tlabels.shape = ', labels.shape)
    print('***************************************')
    loss = utils.kld(mus, log_sigmas) + utils.recovery_loss(
        hidden_states, samples) + utils.sparsity_regularization(thetas)
    return loss
def forward_loss(x, model):
    mu, logsigma, classcode = model.encoder(x)
    contentcode = reparameterize(mu, logsigma)
    shuffled_classcode = classcode[torch.randperm(classcode.shape[0])]

    latentcode1 = torch.cat([contentcode, shuffled_classcode], dim=1)
    latentcode2 = torch.cat([contentcode, classcode], dim=1)

    recon_x1 = model.decoder(latentcode1)
    recon_x2 = model.decoder(latentcode2)

    return vae_loss(x, mu, logsigma, recon_x1) + vae_loss(
        x, mu, logsigma, recon_x2)
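
forward_loss above sums two vae_loss calls, one per decoding of the shared content code. The repository's definition is not shown; as a sketch under the usual conventions (L2 reconstruction, logsigma as log standard deviation), such a helper computes:

import torch
import torch.nn.functional as F

def vae_loss(x, mu, logsigma, recon_x, beta=1.0):
    """Reconstruction + beta * KL, per batch element. A sketch, not the
    repository's actual definition; assumes logvar = 2 * logsigma."""
    recon = F.mse_loss(recon_x, x, reduction='sum') / x.size(0)
    logvar = 2.0 * logsigma
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon + beta * kl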
Example #11
    def get_dataset_sample_for_latent_fc(self, dataset):
        s = dataset if isinstance(dataset, dict) else next(dataset)
        sample = {}
        sample['phases'] = torch.Tensor(s['phases'][:, 0, 0, 3, :, 0].numpy()).cuda()
        sample['action'] = torch.Tensor(s['corrected_action'][:, :1].numpy()).cuda()
        sample['reward'] = s['reward'].numpy()

        mu = (torch.Tensor(s['mu'].numpy())).cuda()
        logvar = (torch.Tensor(s['logvar'].numpy())).cuda()
        latent = reparameterize(mu, logvar)

        sample['x'] = latent[:, 0]
        sample['y'] = latent[:, 1]
        return sample
Example #12
  def forward(self, x, edge_index, batch, num_graphs):

    # batch_size = data.num_graphs
    if x is None:
        x = torch.ones(batch.shape[0]).to(device)

    node_mu, node_logvar, class_mu, class_logvar = self.encoder(x, edge_index, batch)
    grouped_mu, grouped_logvar = accumulate_group_evidence(
        class_mu.data, class_logvar.data, batch, True
    )

    # kl-divergence error for style latent space
    node_kl_divergence_loss = torch.mean(
        - 0.5 * torch.sum(1 + node_logvar - node_mu.pow(2) - node_logvar.exp())
    )
    node_kl_divergence_loss = 0.0000001 * node_kl_divergence_loss * num_graphs
    node_kl_divergence_loss.backward(retain_graph=True)

    # kl-divergence error for class latent space
    class_kl_divergence_loss = torch.mean(
        - 0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) - grouped_logvar.exp())
    )
    class_kl_divergence_loss = 0.0000001 * class_kl_divergence_loss * num_graphs
    class_kl_divergence_loss.backward(retain_graph=True)

    # reconstruct samples
    """
    sampling from the group mu and logvar differently for each graph in the mini-batch
    makes the decoder treat the class latent embeddings as random noise and ignore them
    """
    node_latent_embeddings = reparameterize(training=True, mu=node_mu, logvar=node_logvar)
    class_latent_embeddings = group_wise_reparameterize(
        training=True, mu=grouped_mu, logvar=grouped_logvar, labels_batch=batch, cuda=True
    )

    #need to reduce ml between node and class latents
    '''measure='JSD'
    mi_loss = local_global_loss_disen(node_latent_embeddings, class_latent_embeddings, edge_index, batch, measure)
    mi_loss.backward(retain_graph=True)'''

    reconstructed_node = self.decoder(node_latent_embeddings, class_latent_embeddings, edge_index)
    #check input feat first
    #print('recon ', x[0],reconstructed_node[0])
    reconstruction_error = 0.1 * mse_loss(reconstructed_node, x) * num_graphs
    reconstruction_error.backward()

    
    return reconstruction_error.item(), class_kl_divergence_loss.item(), node_kl_divergence_loss.item()
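
The docstring inside forward explains why the class-latent noise must be shared within a group: sampling a fresh epsilon per graph makes the class code look like independent noise to the decoder. A plausible sketch of group_wise_reparameterize under that reading, assuming labels_batch holds contiguous group indices in [0, num_groups) as the PyTorch Geometric batch vector does (the repositories' actual helpers may differ):

import torch

def group_wise_reparameterize(training, mu, logvar, labels_batch, cuda=False):
    """One eps per group, broadcast to all members of that group.

    Sketch only: mu/logvar are [num_groups, D], labels_batch is [N] with
    values in [0, num_groups). The cuda flag is a no-op here because
    randn_like already matches mu's device.
    """
    if not training:
        return mu[labels_batch]   # deterministic: use the group mean
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)   # [num_groups, D]: one draw per group
    z = mu + eps * std
    return z[labels_batch]        # expand to [N, D] by group index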
def run_through_network(X, labels_batch):
    style_mu, style_logvar, class_mu, class_logvar = encoder(Variable(X))
    grouped_mu, grouped_logvar = accumulate_group_evidence(
        class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda
    )
    importance_sampling(X[0], style_mu[0], style_logvar[0], class_mu[0], class_logvar[0])

    # reconstruct samples
    style_latent_embeddings = reparameterize(training=True, mu=style_mu, logvar=style_logvar)
    class_latent_embeddings = group_wise_reparameterize(
        training=True, mu=grouped_mu, logvar=grouped_logvar, labels_batch=labels_batch, cuda=FLAGS.cuda
    )

    reconstructed_images = decoder(style_latent_embeddings, class_latent_embeddings)

    return reconstructed_images
def topic_loss(encoder, decoder, doc_batch, tau, V, B):
    # print('***************************************')
    mus, log_sigmas = encoder(doc_batch, device)
    samples = utils.reparameterize(mus, log_sigmas)
    # print('\t\tmus.shape = ', mus.shape,' log_sigma.shape = ', log_sigmas.shape)
    thetas = samples.softmax(dim=1)
    hidden_states, labels = decoder(doc_batch, device)
    print('\t\tthetas.shape = ', thetas.shape)
    print('\t\thidden_states.shape = ', hidden_states.shape)
    # print('\t\tlabels.shape = ', labels.shape)

    b = torch.matmul(hidden_states, tau)
    b = b.sigmoid()
    b = b.view(hidden_states.shape[0], hidden_states.shape[1], 1)
    logits1 = torch.matmul(hidden_states, V)

    bias = torch.matmul(thetas, B)
    bias = bias.view(decoder.batch_size, 1, decoder.num_words)
    print('\t\tlogits1.shape = ', logits1.shape)
    print('\t\tbias.shape = ', bias.shape)

    logits0 = logits1 + bias

    # print('\t\tb.shape = ', b.shape)
    print('\t\tlogits0.shape = ', logits0.shape)
    # print('\t\tlogits1.shape = ', logits1.shape)

    probs = logits1.softmax(dim=2) * b + (1 - b) * logits0.softmax(dim=2)
    # assert (probs.cpu() < 0).sum() == 0, 'there must be something wrong'
    # probs += (1-b)*logits.softmax(dim=2)

    # print('\t\tprobs.shape = ', probs.shape)
    # print('\t\tprobs.sum = ', probs.sum().item())

    probs = probs.view(-1, decoder.num_words)
    labels = labels.view(-1)
    masks = (labels > 0).float()

    loss = -1 * torch.sum(
        torch.log(probs[range(probs.shape[0]), labels]) *
        masks) / torch.sum(masks).item()
    # print('after flatten: probs.shape = ', probs.shape)
    # print('labels.shape = ', labels.shape)
    # print('labels = ', labels)

    loss += utils.kld(mus, log_sigmas)
    return loss
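
Both topic-model losses add utils.kld(mus, log_sigmas) to the reconstruction term. For reference, the closed-form KL against a standard normal prior, written for the log-sigma convention the variable names suggest (a sketch; the actual helper may use log-variance or a different reduction):

import torch

def kld(mus, log_sigmas):
    """KL( N(mu, sigma^2) || N(0, I) ), averaged over the batch.

    Sketch assuming log_sigmas holds log standard deviations, i.e.
    logvar = 2 * log_sigmas; adjust if the convention differs.
    """
    logvar = 2.0 * log_sigmas
    kl_per_doc = -0.5 * torch.sum(1 + logvar - mus.pow(2) - logvar.exp(), dim=1)
    return kl_per_doc.mean()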
Example #15
    def sample_from_conditional(self, X, z=None, full_cov=False):
        """
        Calculates self.conditional and also draws a sample, adding input propagation if necessary

        If z=None then the tensorflow random_normal function is used to generate the
        N(0, 1) samples; otherwise z is used as the whitened sample points

        :param X: Input locations (S,N,D_in)
        :param full_cov: Whether to compute correlations between outputs
        :param z: None, or the sampled points in whitened representation
        :return: mean (S,N,D), var (S,N,N,D or S,N,D), samples (S,N,D)
        """
        mean, var = self.conditional_SND(X, full_cov=full_cov)

        # set shapes
        S = tf.shape(X)[0]
        N = tf.shape(X)[1]
        D = mean.shape[-1]
        # D = self.num_outputs

        mean = tf.reshape(mean, (S, N, D))
        if full_cov:
            var = tf.reshape(var, (S, N, N, D))
        else:
            var = tf.reshape(var, (S, N, D))

        if z is None:
            z = tf.random.normal(tf.shape(mean), dtype=gpflow.default_float())
        samples = reparameterize(mean, var, z, full_cov=full_cov)

        if self.input_prop_dim:
            shape = [tf.shape(X)[0], tf.shape(X)[1], self.input_prop_dim]
            X_prop = tf.reshape(X[:, :, :self.input_prop_dim], shape)

            samples = tf.concat([X_prop, samples], 2)
            mean = tf.concat([X_prop, mean], 2)

            if full_cov:
                shape = (tf.shape(X)[0], tf.shape(X)[1], tf.shape(X)[1],
                         tf.shape(var)[3])
                zeros = tf.zeros(shape, dtype=gpflow.default_float())
                var = tf.concat([zeros, var], 3)
            else:
                var = tf.concat([tf.zeros_like(X_prop), var], 2)

        return samples, mean, var
    def sample_from_conditional(self, X, z=None, full_cov=False):
        """
        Calculates self.conditional and also draws a sample

        If z=None then the tensorflow random_normal function is used to generate the
        N(0, 1) samples; otherwise z is used as the whitened sample points

        :param X: Input locations (S,N,D_in)
        :param full_cov: Whether to compute correlations between outputs
        :param z: None, or the sampled points in whitened representation
        :return: mean (S,N,D), var (S,N,N,D or S,N,D), samples (S,N,D)
        """
        mean, var = self.conditional(X, full_cov=full_cov)
        if z is None:
            z = tf.random_normal(tf.shape(mean), dtype=settings.float_type)
        samples = reparameterize(mean, var, z, full_cov=full_cov)
        return samples, mean, var
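
Note that the GP-layer reparameterize here differs from the VAE helpers: it takes a variance (not a log-variance) plus pre-drawn whitened noise z. A sketch of the diagonal case, matching the (S, N, D) shapes in the docstrings (the full_cov branch would instead scale z by a Cholesky factor of the (S, N, N, D) covariance):

import tensorflow as tf

def reparameterize(mean, var, z, full_cov=False):
    """samples = mean + sqrt(var) * z for diagonal covariance. Sketch only."""
    if full_cov:
        # the real helper scales z by a Cholesky factor of the full covariance
        raise NotImplementedError('full covariance omitted in this sketch')
    return mean + z * tf.sqrt(tf.maximum(var, 1e-12))  # clamp avoids sqrt of tiny values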
Example #17
def process(FLAGS, X, labels_batch, encoder, decoder):

    style_mu, style_logvar, class_mu, class_logvar = encoder(X.cuda())

    content_mu, content_logvar, list_g, sizes_group = \
            accumulate_group_evidence(FLAGS,class_mu, class_logvar,
                                    labels_batch, FLAGS.cuda)
    style_latent_embeddings = reparameterize(training=True,
                                             mu=style_mu,
                                             logvar=style_logvar)

    class_latent_embeddings, indexes, sizes = group_wise_reparameterize_each(
        training=True,
        mu=content_mu,
        logvar=content_logvar,
        labels_batch=labels_batch,
        list_groups_labels=list_g,
        sizes_group=sizes_group,
        cuda=FLAGS.cuda)

    # kl-divergence error for style latent space
    style_kl_divergence_loss = 0.5 * (-1 - style_logvar[indexes, :] +
                                      style_mu[indexes, :].pow(2) +
                                      style_logvar[indexes, :].exp()).sum()
    # kl-divergence error for class latent space
    class_kl_divergence_loss = 0.5 * (-1 - content_logvar + content_mu.pow(2) +
                                      content_logvar.exp()).sum()

    # reconstruct samples
    #reorder by the same order as class_latent_embeddings
    mu_x, logvar_x = decoder(style_latent_embeddings[indexes, :],
                             class_latent_embeddings)
    scale_x = (torch.exp(logvar_x) + 1e-12)**0.5
    scale_x = scale_x.view(X.size(0), 784)
    # create normal distribution on output pixel
    mu_x = mu_x.view(X.size(0), 784)
    prob_x = Normal(mu_x, scale_x)
    logp_batch = prob_x.log_prob(X[indexes, :].view(X.size(0), 784)).sum(1)

    reconstruction_proba = logp_batch.sum(0)
    n_groups = content_mu.size(0)
    elbo = (reconstruction_proba - style_kl_divergence_loss -
            class_kl_divergence_loss) / n_groups

    return elbo, reconstruction_proba / n_groups, style_kl_divergence_loss / n_groups, class_kl_divergence_loss / n_groups
Example #18
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.dis_model(img_flat)

        return validity

if __name__ == '__main__':
    """
    test for network outputs
    """
    encoder = Encoder(16, 16)
    decoder = Decoder(16, 16)

    classifier = Classifier(z_dim=16, num_classes=10)

    mnist = datasets.MNIST(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(mnist, batch_size=64, shuffle=True, num_workers=0, drop_last=True))

    image_batch, labels_batch = next(loader)

    mu, logvar, class_latent_space = encoder(Variable(image_batch))
    style_latent_space = reparameterize(training=True, mu=mu, logvar=logvar)

    reconstructed_image = decoder(style_latent_space, class_latent_space)
    classifier_pred = classifier(style_latent_space)

    print(reconstructed_image.size())
    print(classifier_pred.size())
        image_batch, _, labels_batch = next(loader)
        _, __, class_mu, class_logvar = encoder(image_batch)
        content_mu, content_logvar, list_g, sizes_group = accumulate_group_evidence(
            FLAGS, class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda)

    # generate all possible combinations using the encoder and decoder architecture in the grid
    for row in range(1, 11):
        style_image = image_array[row][0]
        style_image = np.transpose(style_image, (2, 0, 1))
        style_image = torch.FloatTensor(style_image)
        style_image = style_image.contiguous()
        style_image = style_image[0, :, :]
        style_image = style_image.view(1, 1, 28, 28)
        style_mu, style_logvar, _, _ = encoder(Variable(style_image))
        style_latent_embeddings = reparameterize(training=True,
                                                 mu=style_mu,
                                                 logvar=style_logvar)

        for col in range(1, 11):
            if not FLAGS.accumulate_evidence:
                class_image = image_array[col][0]
                class_image = np.transpose(class_image, (2, 0, 1))
                class_image = torch.FloatTensor(class_image)
                class_image = class_image.contiguous()
                class_image = class_image[0, :, :]
                class_image = class_image.view(1, 1, 28, 28)
                _, _, class_mu, class_logvar = encoder(Variable(class_image))
                specified_factor_temp = reparameterize(training=True,
                                                       mu=class_mu,
                                                       logvar=class_logvar)
            else:
Example #20
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
    """
    variable definition
    """
    X = torch.FloatTensor(FLAGS.batch_size, 1, FLAGS.image_size,
                          FLAGS.image_size)
    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        X = X.cuda()
    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))
    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)
    loader = cycle(
        DataLoader(mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(int(len(mnist) / FLAGS.batch_size)):
            # load a mini-batch
            image_batch, labels_batch = next(loader)

            # set zero_grad for the optimizer
            auto_encoder_optimizer.zero_grad()

            X.copy_(image_batch)

            style_mu, style_logvar, class_mu, class_logvar = encoder(
                Variable(X))
            grouped_mu, grouped_logvar = accumulate_group_evidence(
                class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda)

            # kl-divergence error for style latent space
            style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                 style_logvar.exp()))
            style_kl_divergence_loss /= (FLAGS.batch_size *
                                         FLAGS.num_channels *
                                         FLAGS.image_size * FLAGS.image_size)
            style_kl_divergence_loss.backward(retain_graph=True)

            # kl-divergence error for class latent space
            class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                 grouped_logvar.exp()))
            class_kl_divergence_loss /= (FLAGS.batch_size *
                                         FLAGS.num_channels *
                                         FLAGS.image_size * FLAGS.image_size)
            class_kl_divergence_loss.backward(retain_graph=True)

            # reconstruct samples
            """
            sampling from the group mu and logvar differently for each image in the mini-batch
            makes the decoder treat the class latent embeddings as random noise and ignore them
            """
            style_latent_embeddings = reparameterize(training=True,
                                                     mu=style_mu,
                                                     logvar=style_logvar)
            class_latent_embeddings = group_wise_reparameterize(
                training=True,
                mu=grouped_mu,
                logvar=grouped_logvar,
                labels_batch=labels_batch,
                cuda=FLAGS.cuda)

            reconstructed_images = decoder(style_latent_embeddings,
                                           class_latent_embeddings)

            reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_images, Variable(X))
            reconstruction_error.backward()

            auto_encoder_optimizer.step()

            if (iteration + 1) % 50 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' +
                      str(reconstruction_error.data.storage().tolist()[0]))
                print('Style KL-Divergence loss: ' +
                      str(style_kl_divergence_loss.data.storage().tolist()[0]))
                print('Class KL-Divergence loss: ' +
                      str(class_kl_divergence_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    epoch, iteration,
                    reconstruction_error.data.storage().tolist()[0],
                    style_kl_divergence_loss.data.storage().tolist()[0],
                    class_kl_divergence_loss.data.storage().tolist()[0]))

            # write to tensorboard
            writer.add_scalar(
                'Reconstruction loss',
                reconstruction_error.data.storage().tolist()[0],
                epoch * (int(len(mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar(
                'Style KL-Divergence loss',
                style_kl_divergence_loss.data.storage().tolist()[0],
                epoch * (int(len(mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar(
                'Class KL-Divergence loss',
                class_kl_divergence_loss.data.storage().tolist()[0],
                epoch * (int(len(mnist) / FLAGS.batch_size) + 1) + iteration)

        # save checkpoints after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
def generate_sequence(x):
    if opt.svg_name == 'GT':
        return x
    elif opt.svg_name == 'SVG':
        gen_seq = []
        frame_predictor.hidden = frame_predictor.init_hidden()
        posterior.hidden = posterior.init_hidden()
        prior.hidden = prior.init_hidden()
        x_in = x[0]
        gen_seq.append(x_in)
        for i in range(1, opt.n_eval):
            h = encoder(x_in)
            if i < opt.n_past:   
                h, skip = h
            else:
                h, _ = h
            h = h.detach()
            if i < opt.n_past:
                h_target = encoder(x[i])[0].detach()
                z_t, _, _ = posterior(h_target)
                prior(h)
                frame_predictor(torch.cat([h, z_t], 1))
                x_in = x[i]
                gen_seq.append(x_in)
            else:
                z_t, _, _ = prior(h)
                h = frame_predictor(torch.cat([h, z_t], 1)).detach()
                x_in = decoder([h, skip]).detach()
                gen_seq.append(x_in)

        return gen_seq

    elif opt.svg_name == 'DSVG':
        gen_seq = []
        x_in = x[0]
        xs = []
        for i in range(0, opt.n_past):
            xs.append(x[i])
    
        random.shuffle(xs)
    
        mu_c, logvar_c, skip = cont_encoder(torch.cat(xs, 1))
        mu_c = mu_c.detach()
        
        gen_seq.append(x_in)
        h = pose_encoder(x[0]).detach()
        for i in range(1, opt.n_eval):
            h_target = pose_encoder(x[i]).detach()
            if i < opt.n_past:
                mu_t_p, logvar_t_p = posterior_pose(torch.cat([h_target, mu_c], 1), time_step=i - 1)
                z_t_p = utils.reparameterize(mu_t_p, logvar_t_p)
                prior(torch.cat([h, mu_c], 1), time_step=i - 1)
                frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1)
                x_in = x[i]
                gen_seq.append(x_in)
                h = h_target
            else:
                mu_t_pp, logvar_t_pp = prior(torch.cat([h, mu_c], 1), time_step=i - 1)
                z_t = utils.reparameterize(mu_t_pp, logvar_t_pp)
                h_pred = frame_predictor(torch.cat([z_t, mu_c], 1), time_step=i - 1).detach()
                x_in = decoder([h_pred, skip]).detach()
                gen_seq.append(x_in)
                h = pose_encoder(x_in).detach()

        return gen_seq
    else:
        raise ValueError('Unknown svg model: %s' % opt.svg_name)
Example #22
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
    """
    variable definition
    """

    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)

    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)
    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()
    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()
    """
    optimizer and scheduler definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    reverse_cycle_optimizer = optim.Adam(list(encoder.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer,
                                                       step_size=80,
                                                       gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(
        reverse_cycle_optimizer, step_size=80, gamma=0.1)
    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\tReverse_cycle_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        # update the learning rate scheduler
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(
                Variable(X_1))
            style_latent_space_1 = reparameterize(training=True,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)

            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) -
                                 style_logvar_1.exp()))
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels *
                                     FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(
                Variable(X_2))
            style_latent_space_2 = reparameterize(training=True,
                                                  mu=style_mu_2,
                                                  logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) -
                                 style_logvar_2.exp()))
            kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels *
                                     FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_2.backward(retain_graph=True)

            reconstructed_X_1 = decoder(style_latent_space_1,
                                        class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2,
                                        class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            reconstruction_error = (
                reconstruction_error_1 +
                reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2
                                   ) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(Variable(style_latent_space),
                                        class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space),
                                        class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False,
                                                  mu=style_mu_2,
                                                  logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(
                style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' +
                      str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' +
                      str(kl_divergence_error.data.storage().tolist()[0]))
                print('Reverse cycle loss: ' +
                      str(reverse_cycle_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    epoch, iteration,
                    reconstruction_error.data.storage().tolist()[0],
                    kl_divergence_error.data.storage().tolist()[0],
                    reverse_cycle_loss.data.storage().tolist()[0]))

            # write to tensorboard
            writer.add_scalar(
                'Reconstruction loss',
                reconstruction_error.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)
            writer.add_scalar(
                'KL-Divergence loss',
                kl_divergence_error.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)
            writer.add_scalar(
                'Reverse cycle loss',
                reverse_cycle_loss.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
            """
            save reconstructed images and style swapped image generations to check progress
            """
            image_batch_1, image_batch_2, _ = next(loader)
            image_batch_3, _, __ = next(loader)

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)
            X_3.copy_(image_batch_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False,
                                                  mu=style_mu_3,
                                                  logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1,
                                          class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3,
                                          class_latent_space_2)

            # save input image batch
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            image_batch = np.concatenate(
                (image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(
                reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_x = np.concatenate(
                (reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x,
                        name=str(epoch) + '_target',
                        save=True)

            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            style_batch = np.concatenate(
                (style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            # save style swapped reconstructed batch
            reconstructed_style = np.transpose(
                reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_style = np.concatenate(
                (reconstructed_style, reconstructed_style,
                 reconstructed_style),
                axis=3)
            imshow_grid(reconstructed_style,
                        name=str(epoch) + '_style_target',
                        save=True)
                                download=True,
                                train=False,
                                transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.num_test_samples,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    image_batch, _, labels_batch = next(loader)

    style_mu, style_logvar, class_mu, class_logvar = encoder(
        Variable(image_batch))
    style_latent_embeddings = reparameterize(training=True,
                                             mu=style_mu,
                                             logvar=style_logvar)

    if FLAGS.accumulate_evidence:
        grouped_mu, grouped_logvar = accumulate_group_evidence(
            class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda)

        class_latent_embeddings = group_wise_reparameterize(
            training=True,
            mu=grouped_mu,
            logvar=grouped_logvar,
            labels_batch=labels_batch,
            cuda=FLAGS.cuda)
    else:
        class_latent_embeddings = reparameterize(training=True,
                                                 mu=class_mu,
    def forward(self, x):
        mu, logsigma = self.encoder(x)
        z = reparameterize(mu, logsigma)
        recon_x = self.decoder(z)

        return mu, logsigma, recon_x
Example #25
                           download=True,
                           train=True,
                           transform=transform_config)
    loader = cycle(
        DataLoader(mnist,
                   batch_size=64,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    image_batch, labels_batch = next(loader)

    style_mu, style_logvar, class_mu, class_logvar = encoder(
        Variable(image_batch))

    style_reparam = reparameterize(training=True,
                                   mu=style_mu,
                                   logvar=style_logvar)
    class_reparam = reparameterize(training=True,
                                   mu=class_mu,
                                   logvar=class_logvar)

    reconstructed_image = decoder(style_reparam, class_reparam)

    style_classifier_pred = classifier(style_reparam)
    class_classifier_pred = classifier(class_reparam)

    print(reconstructed_image.size())
    print(style_classifier_pred.size())
    print(class_classifier_pred.size())
Example #26
    def sample(self, x):
        x = F.relu(self.fc1(x), inplace=True)
        x = F.relu(self.fc2(x), inplace=True)
        means, log_stds = self.fc3(x).chunk(2, dim=-1)
        return utils.reparameterize(means, log_stds.clamp_(-20, 2))
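
Note the convention change in this sampler: fc3 produces means and log standard deviations (clamped to [-20, 2], the usual guard in squashed-Gaussian policies), so utils.reparameterize here must scale by exp(log_std) rather than exp(0.5 * logvar). A sketch of that variant:

import torch

def reparameterize(means, log_stds):
    """Log-std variant: z = mean + exp(log_std) * eps, eps ~ N(0, I).

    Sketch; contrast with the log-variance form, which uses exp(0.5 * logvar).
    """
    return means + torch.exp(log_stds) * torch.randn_like(means)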
Example #27
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        raise Exception('This is not implemented')
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    """
    variable definition
    """

    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()
    adversarial_loss = nn.BCELoss()

    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()
        adversarial_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()

    """
    optimizer and scheduler definition
    """
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer, step_size=80, gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(reverse_cycle_optimizer, step_size=80, gamma=0.1)
    generator_scheduler = optim.lr_scheduler.StepLR(generator_optimizer, step_size=80, gamma=0.1)
    discriminator_scheduler = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=80, gamma=0.1)

    # Used later to define discriminator ground truths
    Tensor = torch.cuda.FloatTensor if FLAGS.cuda else torch.FloatTensor

    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            headers = ['Epoch', 'Iteration', 'Reconstruction_loss', 'KL_divergence_loss', 'Reverse_cycle_loss']

            if FLAGS.forward_gan:
              headers.extend(['Generator_forward_loss', 'Discriminator_forward_loss'])

            if FLAGS.reverse_gan:
              headers.extend(['Generator_reverse_loss', 'Discriminator_reverse_loss'])

            log.write('\t'.join(headers) + '\n')

    # load data set and create data loader instance
    print('Loading CIFAR paired dataset...')
    paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # Save a batch of images to use for visualization
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        # update the learning rate scheduler
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()
        generator_scheduler.step()
        discriminator_scheduler.step()

        for iteration in range(int(len(paired_cifar) / FLAGS.batch_size)):
            # Adversarial ground truths
            valid = Variable(Tensor(FLAGS.batch_size, 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(FLAGS.batch_size, 1).fill_(0.0), requires_grad=False)

            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(Variable(X_1))
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            )
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(Variable(X_2))
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) - style_logvar_2.exp())
            )
            kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_2.backward(retain_graph=True)

            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

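            # Unscaled totals, used below only for logging.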
            reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # A-1. Discriminator training during forward cycle
            if FLAGS.forward_gan:
              # Training generator
              generator_optimizer.zero_grad()

              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_f_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_f_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_f_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_f_loss.backward()

              discriminator_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

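            # Sample fresh style codes from the N(0, I) prior for the reverse cycle.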
            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(Variable(style_latent_space), class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space), class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef  # rescale for logging only; backward() already used the scaled loss

            reverse_cycle_optimizer.step()

            # B-1. Discriminator training during reverse cycle
            if FLAGS.reverse_gan:
              # Training generator
              generator_optimizer.zero_grad()

              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_r_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_r_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_r_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_r_loss.backward()

              discriminator_optimizer.step()

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))
                print('Reverse cycle loss: ' + str(reverse_cycle_loss.data.storage().tolist()[0]))

                if FLAGS.forward_gan:
                  print('Generator F loss: ' + str(gen_f_loss.data.storage().tolist()[0]))
                  print('Discriminator F loss: ' + str(dis_f_loss.data.storage().tolist()[0]))

                if FLAGS.reverse_gan:
                  print('Generator R loss: ' + str(gen_r_loss.data.storage().tolist()[0]))
                  print('Discriminator R loss: ' + str(dis_r_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                row = []

                row.append(epoch)
                row.append(iteration)
                row.append(reconstruction_error.data.storage().tolist()[0])
                row.append(kl_divergence_error.data.storage().tolist()[0])
                row.append(reverse_cycle_loss.data.storage().tolist()[0])

                if FLAGS.forward_gan:
                  row.append(gen_f_loss.data.storage().tolist()[0])
                  row.append(dis_f_loss.data.storage().tolist()[0])

                if FLAGS.reverse_gan:
                  row.append(gen_r_loss.data.storage().tolist()[0])
                  row.append(dis_r_loss.data.storage().tolist()[0])

                row = [str(x) for x in row]
                log.write('\t'.join(row) + '\n')

            # write to tensorboard
            writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Reverse cycle loss', reverse_cycle_loss.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.forward_gan:
              writer.add_scalar('Generator F loss', gen_f_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
              writer.add_scalar('Discriminator F loss', dis_f_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.reverse_gan:
              writer.add_scalar('Generator R loss', gen_r_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
              writer.add_scalar('Discriminator R loss', dis_r_loss.data.storage().tolist()[0],
                                epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            """
            save reconstructed images and style swapped image generations to check progress
            """

            X_1.copy_(image_sample_1)
            X_2.copy_(image_sample_2)
            X_3.copy_(image_sample_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            # save input image batch
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              image_batch = np.concatenate((image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              reconstructed_x = np.concatenate((reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True)

            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              style_batch = np.concatenate((style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            # save style swapped reconstructed batch
            reconstructed_style = np.transpose(reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
              reconstructed_style = np.concatenate((reconstructed_style, reconstructed_style, reconstructed_style), axis=3)
            imshow_grid(reconstructed_style, name=str(epoch) + '_style_target', save=True)
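The training loop above calls a reparameterize(training=..., mu=..., logvar=...) helper defined elsewhere in the repository. A minimal sketch of what the call sites imply, assuming the standard VAE reparameterization trick (the original implementation may differ in detail):

import torch

def reparameterize(training, mu, logvar):
    # Sample z = mu + sigma * eps while training; return the mean
    # deterministically at eval time, matching the training=False calls
    # used for visualization above.
    if training:
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std
    return mu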
                Z_logvar = Z_logvar.detach()
                # Real
                # label = torch.full((b_size,), cfg["real_label"], device=device)
                _, Dis_X = netD(X)
                errD_real = Dis_X.view(-1)
                errD_real.backward(one)

                # Fake
                X_tilde = netG(Z_mu)
                # label.fill_(cfg["fake_label"]).to(device)
                _, Dis_X_tilde = netD(X_tilde)
                errD_fake = Dis_X_tilde.view(-1)
                errD_fake.backward(minus_one)

                # Sampled
                Zp = reparameterize(Z_mu, Z_logvar)
                Xp = netG(Zp).detach()
                # label.fill_(cfg["fake_label"]).to(device)
                _, Dis_Xp = netD(Xp)
                # Dis_Xp = Dis_Xp.view(-1)
                # errD_resamp = calc_BCE_loss(Dis_Xp, label)
                errD_resamp = Dis_Xp.view(-1)
                errD_resamp.backward(minus_one)
                #
                errD = errD_real - errD_fake - errD_resamp
                optimizerD.step()

            # for i, data in enumerate(dataloader, 0):

            # Update Generator Network: freeze the critic so the next
            # backward pass trains only the generator.
            for p in netD.parameters():
                p.requires_grad = False
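The fragment above follows the WGAN-style critic update: the critic netD is driven with +1/-1 gradient seeds and is then frozen so that only the generator learns from the next backward pass. A minimal sketch of the pieces this pattern assumes (names taken from the fragment; set_requires_grad is a hypothetical helper, and the seeds presume the critic score is reduced to a scalar, e.g. with .mean(), before backward(one) is called):

import torch

# +1 / -1 gradient seeds for tensor.backward(gradient=...), as used above.
one = torch.tensor(1.0)
minus_one = -1.0 * one

def set_requires_grad(net, flag):
    # Freeze (flag=False) or unfreeze (flag=True) every parameter of a
    # network, so alternating critic/generator updates do not leak
    # gradients into the side currently held fixed.
    for p in net.parameters():
        p.requires_grad = flag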
Beispiel #29
0
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
    """
    variable definition
    """
    X = torch.FloatTensor(FLAGS.batch_size, 784)
    '''
    run on GPU if GPU is available
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder.to(device=device)
    decoder.to(device=device)
    X = X.to(device=device)
    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))
    """
    
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is False when training starts from scratch, so write a fresh log header
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    dirs = os.listdir(os.path.join(os.getcwd(), 'data'))
    print('Loading double multivariate normal time series data...')
    for dsname in dirs:
        params = dsname.split('_')
        if params[2] == 'theta=-1':
            print('Running dataset ', dsname)
            ds = DoubleMulNormal(dsname)
            # ds = experiment3(1000, 50, 3)
            loader = cycle(
                DataLoader(ds,
                           batch_size=FLAGS.batch_size,
                           shuffle=True,
                           drop_last=True))

            # initialize summary writer
            writer = SummaryWriter()

            for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
                print()
                print(
                    'Epoch #' + str(epoch) +
                    '........................................................')

                # the total loss at each epoch after running iterations of batches
                total_loss = 0

                for iteration in range(int(len(ds) / FLAGS.batch_size)):
                    # load a mini-batch
                    image_batch, labels_batch = next(loader)

                    # set zero_grad for the optimizer
                    auto_encoder_optimizer.zero_grad()

                    X.copy_(image_batch)

                    style_mu, style_logvar, class_mu, class_logvar = encoder(
                        Variable(X))
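                    # Pool the class evidence of all samples that share a
                    # label into a single Gaussian per group.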
                    grouped_mu, grouped_logvar = accumulate_group_evidence(
                        class_mu.data, class_logvar.data, labels_batch,
                        FLAGS.cuda)

                    # kl-divergence error for style latent space
                    style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                         style_logvar.exp()))
                    style_kl_divergence_loss /= (FLAGS.batch_size *
                                                 FLAGS.num_channels *
                                                 FLAGS.image_size *
                                                 FLAGS.image_size)
                    style_kl_divergence_loss.backward(retain_graph=True)

                    # kl-divergence error for class latent space
                    class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 *
                        torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                  grouped_logvar.exp()))
                    class_kl_divergence_loss /= (FLAGS.batch_size *
                                                 FLAGS.num_channels *
                                                 FLAGS.image_size *
                                                 FLAGS.image_size)
                    class_kl_divergence_loss.backward(retain_graph=True)

                    # reconstruct samples
                    """
                    sampling from group mu and logvar for each image in mini-batch differently makes
                    the decoder consider class latent embeddings as random noise and ignore them 
                    """
                    style_latent_embeddings = reparameterize(
                        training=True, mu=style_mu, logvar=style_logvar)
                    class_latent_embeddings = group_wise_reparameterize(
                        training=True,
                        mu=grouped_mu,
                        logvar=grouped_logvar,
                        labels_batch=labels_batch,
                        cuda=FLAGS.cuda)

                    reconstructed_images = decoder(style_latent_embeddings,
                                                   class_latent_embeddings)

                    reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                        reconstructed_images, Variable(X))
                    reconstruction_error.backward()

                    # accumulate as a plain float so the computation graphs
                    # are not kept alive across the whole epoch
                    total_loss += (style_kl_divergence_loss +
                                   class_kl_divergence_loss +
                                   reconstruction_error).item()

                    auto_encoder_optimizer.step()

                    if (iteration + 1) % 50 == 0:
                        print('\tIteration #' + str(iteration))
                        print('Reconstruction loss: ' + str(
                            reconstruction_error.data.storage().tolist()[0]))
                        print('Style KL loss: ' +
                              str(style_kl_divergence_loss.data.storage().
                                  tolist()[0]))
                        print('Class KL loss: ' +
                              str(class_kl_divergence_loss.data.storage().
                                  tolist()[0]))

                    # write to log
                    with open(FLAGS.log_file, 'a') as log:
                        log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                            epoch, iteration,
                            reconstruction_error.data.storage().tolist()[0],
                            style_kl_divergence_loss.data.storage().tolist()
                            [0],
                            class_kl_divergence_loss.data.storage().tolist()
                            [0]))

                    # write to tensorboard
                    writer.add_scalar(
                        'Reconstruction loss',
                        reconstruction_error.data.storage().tolist()[0],
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)
                    writer.add_scalar(
                        'Style KL-Divergence loss',
                        style_kl_divergence_loss.data.storage().tolist()[0],
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)
                    writer.add_scalar(
                        'Class KL-Divergence loss',
                        class_kl_divergence_loss.data.storage().tolist()[0],
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)

                    if epoch == 0 and (iteration + 1) % 50 == 0:
                        torch.save(
                            encoder.state_dict(),
                            os.path.join('checkpoints', 'encoder_' + dsname))
                        torch.save(
                            decoder.state_dict(),
                            os.path.join('checkpoints', 'decoder_' + dsname))

                # save checkpoints after every 10 epochs
                if (epoch + 1) % 10 == 0 or (epoch + 1) == FLAGS.end_epoch:
                    torch.save(
                        encoder.state_dict(),
                        os.path.join('checkpoints', 'encoder_' + dsname))
                    torch.save(
                        decoder.state_dict(),
                        os.path.join('checkpoints', 'decoder_' + dsname))

                print('Total loss at current epoch: ', total_loss)
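Beispiel #29 hinges on two helpers imported from the grouped-VAE (multi-level VAE) code base, accumulate_group_evidence and group_wise_reparameterize. A minimal sketch of the idea behind them, with simplified signatures (the originals also take a cuda flag and handle the autograd wiring, omitted here):

import torch

def accumulate_group_evidence(class_mu, class_logvar, labels_batch):
    # Multiply the per-sample Gaussian posteriors of samples sharing a
    # label: precisions add, and the grouped mean is the precision-weighted
    # average. Every member of a group gets the same grouped (mu, logvar).
    var = class_logvar.exp()
    grouped_mu = torch.zeros_like(class_mu)
    grouped_logvar = torch.zeros_like(class_logvar)
    for label in labels_batch.unique():
        mask = labels_batch == label
        group_var = 1.0 / (1.0 / var[mask]).sum(dim=0)
        group_mu = group_var * (class_mu[mask] / var[mask]).sum(dim=0)
        grouped_mu[mask] = group_mu
        grouped_logvar[mask] = group_var.log()
    return grouped_mu, grouped_logvar

def group_wise_reparameterize(training, mu, logvar, labels_batch):
    # Draw one eps per group so that samples sharing a label also share
    # the same class-latent sample; at eval time return the mean.
    if not training:
        return mu
    std = (0.5 * logvar).exp()
    eps = torch.zeros_like(std)
    for label in labels_batch.unique():
        mask = labels_batch == label
        group_eps = torch.randn_like(std[mask][:1])  # single draw per group
        eps[mask] = group_eps.expand_as(std[mask])
    return mu + eps * std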
Beispiel #30
0
def make_gifs(x, idx, name):
    # get approx posterior sample
    posterior_gen = []
    posterior_gen.append(x[0])
    x_in = x[0]
    # ------------ calculate the content posterior
    xs = []
    for i in range(0, opt.n_past):
        xs.append(x[i])
    #if True:
    random.shuffle(xs)
    #xc = torch.cat(xs, 1)
    mu_c, logvar_c, skip = cont_encoder(torch.cat(xs, 1))
    mu_c = mu_c.detach()

    for i in range(1, opt.n_eval):
        h_target = pose_encoder(x[i]).detach()
        mu_t_p, logvar_t_p = posterior_pose(torch.cat([h_target, mu_c], 1),
                                            time_step=i - 1)
        z_t_p = utils.reparameterize(mu_t_p, logvar_t_p)
        if i < opt.n_past:
            frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1)
            posterior_gen.append(x[i])
            x_in = x[i]
        else:
            h_pred = frame_predictor(torch.cat([z_t_p, mu_c], 1),
                                     time_step=i - 1).detach()
            x_in = decoder([h_pred, skip]).detach()
            posterior_gen.append(x_in)

    nsample = opt.nsample
    ssim = np.zeros((opt.batch_size, nsample, opt.n_future))
    psnr = np.zeros((opt.batch_size, nsample, opt.n_future))

    #ccm_pred = np.zeros((opt.batch_size, nsample, opt.n_eval-1))
    #ccm_gt = np.zeros((opt.batch_size, opt.n_eval-1))

    progress = progressbar.ProgressBar(maxval=nsample).start()
    all_gen = []
    '''for i in range(1, opt.n_eval):
        out_gt = discriminator(torch.cat([x[0], x[i]],dim=1))
        ccm_i_gt = out_gt.mean().data.cpu().numpy()
        print('time step %d, mean out gt: %.4f'%(i,ccm_i_gt))
        ccm_gt[:,i-1] = out_gt.squeeze().data.cpu().numpy()'''

    hs = []
    for i in range(0, opt.n_past):
        hs.append(pose_encoder(x[i]).detach())

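    # Roll out nsample stochastic predictions: frames before opt.n_past are
    # conditioned on ground truth (posterior); later frames are generated
    # from the learned prior.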
    for s in range(nsample):
        progress.update(s + 1)
        gen_seq = []
        gt_seq = []
        x_in = x[0]
        all_gen.append([])
        all_gen[s].append(x_in)

        h = pose_encoder(x[0]).detach()
        for i in range(1, opt.n_eval):
            h_target = pose_encoder(x[i]).detach()

            if i < opt.n_past:
                mu_t_p, logvar_t_p = posterior_pose(torch.cat([h_target, mu_c],
                                                              1),
                                                    time_step=i - 1)
                z_t_p = utils.reparameterize(mu_t_p, logvar_t_p)
                prior(torch.cat([h, mu_c], 1), time_step=i - 1)
                frame_predictor(torch.cat([z_t_p, mu_c], 1), time_step=i - 1)
                x_in = x[i]
                all_gen[s].append(x_in)
                h = h_target
            else:
                mu_t_pp, logvar_t_pp = prior(torch.cat([h, mu_c], 1),
                                             time_step=i - 1)
                z_t = utils.reparameterize(mu_t_pp, logvar_t_pp)
                h_pred = frame_predictor(torch.cat([z_t, mu_c], 1),
                                         time_step=i - 1).detach()
                x_in = decoder([h_pred, skip]).detach()
                gen_seq.append(x_in.data.cpu().numpy())
                gt_seq.append(x[i].data.cpu().numpy())
                all_gen[s].append(x_in)
                h = pose_encoder(x_in).detach()

            #out_pred = discriminator(torch.cat([x[0],x_in],dim=1))
            #ccm_i_pred = out_pred.mean().data.cpu().numpy()
            #print('time step %d, mean out pred: %.4f'%(i,ccm_i_pred))
            #ccm_pred[:, s, i-1] = out_pred.squeeze().data.cpu().numpy()

        _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq)

    progress.finish()
    utils.clear_progressbar()

    best_ssim = np.zeros((opt.batch_size, opt.n_future))
    best_psnr = np.zeros((opt.batch_size, opt.n_future))
    ###### ssim ######
    for i in range(opt.batch_size):
        gifs = [[] for t in range(opt.n_eval)]
        text = [[] for t in range(opt.n_eval)]

        mean_ssim = np.mean(ssim[i], 1)
        ordered = np.argsort(mean_ssim)
        best_ssim[i, :] = ssim[i, ordered[-1], :]

        mean_psnr = np.mean(psnr[i], 1)
        ordered_p = np.argsort(mean_psnr)
        best_psnr[i, :] = psnr[i, ordered_p[-1], :]

        rand_sidx = [np.random.randint(nsample) for s in range(3)]

        # -- generate gifs
        for t in range(opt.n_eval):
            # gt
            gifs[t].append(add_border(x[t][i], 'green'))
            text[t].append('Ground\ntruth')
            #posterior
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            gifs[t].append(add_border(posterior_gen[t][i], color))
            text[t].append('Approx.\nposterior')
            # best
            if t < opt.n_past:
                color = 'green'
            else:
                color = 'red'
            sidx = ordered[-1]
            gifs[t].append(add_border(all_gen[sidx][t][i], color))
            text[t].append('Best SSIM')
            # random 3
            for s in range(len(rand_sidx)):
                gifs[t].append(add_border(all_gen[rand_sidx[s]][t][i], color))
                text[t].append('Random\nsample %d' % (s + 1))

        fname = '%s/samples/%s_%d.gif' % (opt.log_dir, name, idx + i)
        utils.save_gif_with_text(fname, gifs, text)

        # -- generate samples
        to_plot = []
        gts = []
        best_p = []
        rand_samples = [[] for s in range(len(rand_sidx))]
        for t in range(opt.n_eval):
            # gt
            gts.append(x[t][i])
            best_p.append(all_gen[ordered_p[-1]][t][i])

            # sample
            for s in range(len(rand_sidx)):
                rand_samples[s].append(all_gen[rand_sidx[s]][t][i])

        to_plot.append(gts)
        to_plot.append(best_p)
        for s in range(len(rand_sidx)):
            to_plot.append(rand_samples[s])
        fname = '%s/samples/%s_%d.png' % (opt.log_dir, name, idx + i)
        utils.save_tensors_image(fname, to_plot)

    return best_ssim, best_psnr  #, ccm_pred, ccm_gt
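make_gifs also leans on a few small utilities that sit outside this excerpt. Plausible minimal versions, inferred from the call sites (the real utils module may differ):

import torch

def reparameterize(mu, logvar):
    # Two-argument variant matching utils.reparameterize above: always
    # samples z = mu + sigma * eps.
    std = (0.5 * logvar).exp()
    return mu + std * torch.randn_like(std)

def add_border(x, color, pad=1):
    # Frame a (C, H, W) image with a colored border: green marks frames
    # conditioned on ground truth, red marks generated frames.
    c, h, w = x.shape
    framed = torch.zeros(3, h + 2 * pad, w + 2 * pad)
    framed[1 if color == 'green' else 0] = 0.7
    framed[:, pad:h + pad, pad:w + pad] = x if c == 3 else x.expand(3, h, w)
    return framed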