Example #1
def test_reconstruction_loss(weights, text):
    reconstructer = ReconstructText(64, text.shape[1])
    wv = reconstructer(weights)
    # detach both tensors so the loss is computed without gradient tracking
    # (requires_grad cannot be toggled on a non-leaf tensor such as wv)
    wv = wv.detach()
    text = text.detach()
    r_loss = reconstruction_loss(text, wv)
    print("Reconstruction Loss:", r_loss)
Example #2
    def _call(self, x):
        mu, sigma = self.encoder(x)
        p = torch.diag(sigma[0]**2)
        # Reparameterization: epsilon is standard normal, z = mu + sigma * epsilon.
        sampler = Normal(loc=torch.zeros_like(mu), scale=torch.ones_like(sigma))
        epsilon = sampler.sample((self.n_samples, ))
        z = sigma * epsilon + mu
        x_tilde = self.decoder(z)
        x_flatten_tilde = x_tilde.flatten()
        bar_eta = mu.reshape((mu.shape[0], -1, mu.shape[1]))
        r = torch.diag(x_flatten_tilde * (1 - x_flatten_tilde))
        gamma = 0
        alpha = 1
        eta_1 = z[0]  # the sample trajectory, integrated alongside bar_eta
        lambda_1 = 0
        for j in range(self.intervals.shape[0]):
            lambda_1 = lambda_1 + self.intervals[j]
            print(f'interval {j}: lambda_1={lambda_1}, step={self.intervals[j]}')
            # Linearize the decoder around bar_eta via its Jacobian.
            h_eta = self.decoder(bar_eta).flatten()
            H = self.get_jacobian(bar_eta, h_eta).squeeze()
            bar_eta = bar_eta.squeeze(1)
            square_mat = lambda_1 * H @ p @ H.transpose(1, 0) + r
            A_j_lambda = -0.5 * p @ H.transpose(1, 0) @ square_mat.inverse() @ H
            temp = ((torch.eye(bar_eta.shape[-1]) + lambda_1 * A_j_lambda)
                    @ p @ H.transpose(1, 0) @ r.inverse()
                    @ x.reshape((784, 1)) + A_j_lambda @ mu.transpose(1, 0))
            b_j_lambda = (torch.eye(bar_eta.shape[-1]) +
                          2 * lambda_1 * A_j_lambda) @ temp
            # Euler step for both the mean trajectory and the sample trajectory.
            bar_eta = (bar_eta.transpose(1, 0) + self.intervals[j] *
                       (A_j_lambda @ bar_eta.transpose(1, 0) +
                        b_j_lambda)).transpose(1, 0)
            eta_1 = (eta_1.transpose(1, 0) + self.intervals[j] *
                     (A_j_lambda @ eta_1.transpose(1, 0) +
                      b_j_lambda)).transpose(1, 0)
            # Accumulate |det(I + dt * A)| for the log-volume correction gamma.
            alpha_mat = (torch.eye(A_j_lambda.shape[0]) +
                         self.intervals[j] * A_j_lambda)
            alpha = alpha * torch.abs(torch.det(alpha_mat))
        gamma = gamma + torch.log(alpha)
        x_reconstruction = torch.sigmoid(self.decoder(eta_1).flatten())
        log_x_g_z = -reconstruction_loss(  # negated: log p(x | z)
            x_reconstruction.reshape(
                (self.n_samples, 1, x.shape[-2], x.shape[-1])), x)
        log_z = torch.sum(-0.5 * eta_1**2, -1)
        log_z_g_x = torch.sum(-torch.log(sigma) -
                              0.5 * (z - mu)**2 * sigma.reciprocal()**2)
        elbo = log_x_g_z + log_z - log_z_g_x + gamma
        print(-log_x_g_z, '\n', log_z, '\n', log_z_g_x, '\n', gamma)
        return x_reconstruction, eta_1, gamma, elbo
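The method depends on a get_jacobian helper that is not shown. A minimal sketch of one plausible implementation, stacking per-output gradients with torch.autograd.grad (an assumption about the missing helper, not the original code):

import torch

def get_jacobian(self, inputs, outputs):
    # One row per scalar output: d(outputs[i]) / d(inputs), flattened.
    # Assumes `inputs` participates in the graph with requires_grad=True.
    rows = []
    for i in range(outputs.numel()):
        grad_i, = torch.autograd.grad(outputs.flatten()[i], inputs,
                                      retain_graph=True)
        rows.append(grad_i.flatten())
    return torch.stack(rows)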
Example #3
def test(valid_queue, model, num_samples, args, logging):
    if args.distributed:
        dist.barrier()
    nelbo_avg = utils.AvgrageMeter()
    neg_log_p_avg = utils.AvgrageMeter()
    model.eval()
    for step, x in enumerate(valid_queue):
        x = x[0] if len(x) > 1 else x
        x = x.float().cuda()

        # change bit length
        x = utils.pre_process(x, args.num_x_bits)

        with torch.no_grad():
            nelbo, log_iw = [], []
            for k in range(num_samples):
                logits, log_q, log_p, kl_all, _ = model(x)
                output = model.decoder_output(logits)
                recon_loss = utils.reconstruction_loss(output,
                                                       x,
                                                       crop=model.crop_output)
                balanced_kl, _, _ = utils.kl_balancer(kl_all, kl_balance=False)
                nelbo_batch = recon_loss + balanced_kl
                nelbo.append(nelbo_batch)
                log_iw.append(
                    utils.log_iw(output,
                                 x,
                                 log_q,
                                 log_p,
                                 crop=model.crop_output))

            nelbo = torch.mean(torch.stack(nelbo, dim=1))
            log_p = torch.mean(
                torch.logsumexp(torch.stack(log_iw, dim=1), dim=1) -
                np.log(num_samples))

        nelbo_avg.update(nelbo.data, x.size(0))
        neg_log_p_avg.update(-log_p.data, x.size(0))

    utils.average_tensor(nelbo_avg.avg, args.distributed)
    utils.average_tensor(neg_log_p_avg.avg, args.distributed)
    if args.distributed:
        # block to sync
        dist.barrier()
    logging.info('val, step: %d, NELBO: %f, neg Log p %f', step, nelbo_avg.avg,
                 neg_log_p_avg.avg)
    return neg_log_p_avg.avg, nelbo_avg.avg
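The log_p reported above is the standard importance-weighted estimate: with K samples, log p(x) ≈ logsumexp_k(log w_k) − log K, where log w_k = log p(x|z_k) + log p(z_k) − log q(z_k|x). A sketch of what utils.log_iw presumably returns per sample (an assumption; the real helper may differ):

def log_iw_sketch(output, x, log_q, log_p, crop=False):
    # reconstruction_loss returns -log p(x|z), so negate it to build log w.
    recon = utils.reconstruction_loss(output, x, crop=crop)
    return -recon + log_p - log_q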
Example #4
    def compute_output(self,
                       X,
                       Y,
                       keep_prob=cfg.keep_prob,
                       regularization_scale=cfg.regularization_scale):

        print("Size of input:")
        print(X.get_shape())

        # 1. Convolve the input image up to the digit capsules.
        digit_caps = self._image_to_digitcaps(X)

        # 2. Get the margin loss
        margin_loss = u.margin_loss(digit_caps, Y)

        # 3. Reconstruct the images
        reconstructed_image, reconstruction_1, reconstruction_2 = self._digitcaps_to_image(
            digit_caps, Y)

        # 4. Get the reconstruction loss
        reconstruction_loss = u.reconstruction_loss(reconstructed_image, X)

        # 5. Get the total loss
        total_loss = margin_loss + regularization_scale * reconstruction_loss

        # 6. Get the batch accuracy
        batch_accuracy = u.acc(digit_caps, Y)

        # 7. Reconstruct all possible images
        memo = self._digitcaps_to_memo(X, digit_caps)

        # 8. Get the memo capsules
        memo_caps = self._memo_to_digitcaps(memo, keep_prob=keep_prob)

        # 9. Get the memo margin loss
        memo_margin_loss = u.margin_loss(memo_caps, Y)

        # 10. Get the memo accuracy
        memo_accuracy = u.acc(memo_caps, Y)

        # 11. Return all of the losses and reconstructions
        return (total_loss, margin_loss, reconstruction_loss,
                reconstructed_image, reconstruction_1, reconstruction_2,
                batch_accuracy, memo, memo_margin_loss, memo_accuracy)
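u.margin_loss is not shown; here is a sketch of the standard capsule margin loss from Sabour et al. (2017), written over capsule lengths. The real helper takes the capsule vectors themselves, so its signature and internals may differ:

import tensorflow as tf

def margin_loss_sketch(lengths, one_hot_labels, m_plus=0.9, m_minus=0.1, lam=0.5):
    # L_k = T_k * max(0, m+ - ||v_k||)^2 + lam * (1 - T_k) * max(0, ||v_k|| - m-)^2
    present = one_hot_labels * tf.square(tf.maximum(0., m_plus - lengths))
    absent = lam * (1. - one_hot_labels) * tf.square(
        tf.maximum(0., lengths - m_minus))
    return tf.reduce_mean(tf.reduce_sum(present + absent, axis=1))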
Example #5
def train_main_network(xi_t, xi_tk, xj_tk):
    content_encoder.zero_grad()
    pose_encoder.zero_grad()
    decoder.zero_grad()

    # Compute content vectors of video i
    ci_t = content_encoder(xi_t)
    ci_tk = content_encoder(xi_tk).detach()

    # Compute pose vectors of video i
    pi_t = pose_encoder(xi_t)
    pi_tk = pose_encoder(xi_tk).detach()

    # Compute pose vector of video j
    pj_tk = pose_encoder(xj_tk).detach()

    # Compute the scene discrimination vectors
    discr_same = discriminator(pi_t, pi_tk)
    discr_diff = discriminator(pi_t, pj_tk)

    # Compute the reconstructed image
    pred_xitk = decoder(ci_t, pi_tk)

    # Similarity loss
    sim_loss = nutils.similarity_loss(ci_t, ci_tk, device=device)

    # Reconstruction loss
    rec_loss = nutils.reconstruction_loss(pred_xitk, xi_tk, device=device)

    # Adversarial loss
    adv_loss = nutils.adversarial_loss(discr_same, discr_diff, device=device)

    # Total loss
    loss = rec_loss + alpha * sim_loss + beta * adv_loss

    loss.backward()

    pose_encoder_optim.step()
    content_encoder_optim.step()
    decoder_optim.step()

    return get_value(rec_loss), get_value(sim_loss), get_value(adv_loss)
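The loss terms follow the disentangled content/pose recipe of DRNET-style models. A sketch of plausible definitions for the two simpler helpers (assumptions about nutils; the originals may differ, e.g. in reduction or weighting):

import torch.nn.functional as F

def similarity_loss_sketch(ci_t, ci_tk, device=None):
    # Content vectors of two frames from the same video should agree.
    return F.mse_loss(ci_t, ci_tk)

def reconstruction_loss_sketch(pred, target, device=None):
    # Pixel-wise reconstruction of the target frame.
    return F.mse_loss(pred, target)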
Example #6
def train(train_queue, model, cnn_optimizer, grad_scalar, global_step,
          warmup_iters, writer, logging):
    alpha_i = utils.kl_balancer_coeff(num_scales=model.num_latent_scales,
                                      groups_per_scale=model.groups_per_scale,
                                      fun='square')
    nelbo = utils.AvgrageMeter()
    model.train()
    for step, x in enumerate(train_queue):
        x = x[0] if len(x) > 1 else x
        x = x.half().cuda()

        # change bit length
        x = utils.pre_process(x, args.num_x_bits)

        # warm-up lr
        if global_step < warmup_iters:
            lr = args.learning_rate * float(global_step) / warmup_iters
            for param_group in cnn_optimizer.param_groups:
                param_group['lr'] = lr

        # sync parameters; this may not be strictly necessary
        if step % 100 == 0:
            utils.average_params(model.parameters(), args.distributed)

        cnn_optimizer.zero_grad()
        with autocast():
            logits, log_q, log_p, kl_all, kl_diag = model(x)

            output = model.decoder_output(logits)
            kl_coeff = utils.kl_coeff(
                global_step, args.kl_anneal_portion * args.num_total_iter,
                args.kl_const_portion * args.num_total_iter,
                args.kl_const_coeff)

            recon_loss = utils.reconstruction_loss(output,
                                                   x,
                                                   crop=model.crop_output)
            balanced_kl, kl_coeffs, kl_vals = utils.kl_balancer(
                kl_all, kl_coeff, kl_balance=True, alpha_i=alpha_i)

            nelbo_batch = recon_loss + balanced_kl
            loss = torch.mean(nelbo_batch)
            norm_loss = model.spectral_norm_parallel()
            bn_loss = model.batchnorm_loss()
            # get spectral regularization coefficient (lambda)
            if args.weight_decay_norm_anneal:
                assert args.weight_decay_norm_init > 0 and args.weight_decay_norm > 0, 'init and final wdn should be positive.'
                wdn_coeff = (1. - kl_coeff) * np.log(
                    args.weight_decay_norm_init) + kl_coeff * np.log(
                        args.weight_decay_norm)
                wdn_coeff = np.exp(wdn_coeff)
            else:
                wdn_coeff = args.weight_decay_norm

            loss += norm_loss * wdn_coeff + bn_loss * wdn_coeff

        grad_scalar.scale(loss).backward()
        utils.average_gradients(model.parameters(), args.distributed)
        grad_scalar.step(cnn_optimizer)
        grad_scalar.update()
        nelbo.update(loss.data, 1)

        if (global_step + 1) % 100 == 0:
            if (global_step + 1) % 1000 == 0:  # reduced frequency
                n = int(np.floor(np.sqrt(x.size(0))))
                x_img = x[:n * n]
                output_img = output.mean if isinstance(
                    output, torch.distributions.bernoulli.Bernoulli
                ) else output.sample()
                output_img = output_img[:n * n]
                x_tiled = utils.tile_image(x_img, n)
                output_tiled = utils.tile_image(output_img, n)
                in_out_tiled = torch.cat((x_tiled, output_tiled), dim=2)
                writer.add_image('reconstruction', in_out_tiled, global_step)

            # norm
            writer.add_scalar('train/norm_loss', norm_loss, global_step)
            writer.add_scalar('train/bn_loss', bn_loss, global_step)
            writer.add_scalar('train/norm_coeff', wdn_coeff, global_step)

            utils.average_tensor(nelbo.avg, args.distributed)
            logging.info('train %d %f', global_step, nelbo.avg)
            writer.add_scalar('train/nelbo_avg', nelbo.avg, global_step)
            writer.add_scalar(
                'train/lr',
                cnn_optimizer.state_dict()['param_groups'][0]['lr'],
                global_step)
            writer.add_scalar('train/nelbo_iter', loss, global_step)
            writer.add_scalar('train/kl_iter', torch.mean(sum(kl_all)),
                              global_step)
            writer.add_scalar(
                'train/recon_iter',
                torch.mean(
                    utils.reconstruction_loss(output,
                                              x,
                                              crop=model.crop_output)),
                global_step)
            writer.add_scalar('kl_coeff/coeff', kl_coeff, global_step)
            total_active = 0
            for i, kl_diag_i in enumerate(kl_diag):
                utils.average_tensor(kl_diag_i, args.distributed)
                num_active = torch.sum(kl_diag_i > 0.1).detach()
                total_active += num_active

                # per-group KL coefficient and KL value
                writer.add_scalar('kl/active_%d' % i, num_active, global_step)
                writer.add_scalar('kl_coeff/layer_%d' % i, kl_coeffs[i],
                                  global_step)
                writer.add_scalar('kl_vals/layer_%d' % i, kl_vals[i],
                                  global_step)
            writer.add_scalar('kl/total_active', total_active, global_step)

        global_step += 1

    utils.average_tensor(nelbo.avg, args.distributed)
    return nelbo.avg, global_step
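utils.kl_coeff implements KL warm-up: hold a small constant coefficient for the first kl_const_portion of training, then ramp linearly toward 1 over kl_anneal_portion. A sketch of that schedule (hedged; the library's exact clamping may differ):

def kl_coeff_sketch(step, total_step, constant_step, min_kl_coeff):
    # Flat at min_kl_coeff for the first constant_step iterations,
    # then a linear ramp up to 1.0 over total_step iterations.
    return max(min((step - constant_step) / total_step, 1.0), min_kl_coeff)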
Example #7
def caps_model_fn(features, labels, mode):
    """Model function for the capsule network."""
    hooks = []
    train_log_dict = {}
    # Input Layer
    # Reshape X to 4-D tensor: [batch_size, width, height, channels]
    # Fashion MNIST images are 28x28 pixels, and have one color channel
    input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])

    # A little bit cheaper version of the capsule network in: Dynamic Routing Between Capsules
    # Std. convolutional layer
    conv1 = tf.layers.conv2d(inputs=input_layer,
                             filters=256,
                             kernel_size=[9, 9],
                             padding="valid",
                             activation=tf.nn.relu,
                             name="ReLU_Conv1")
    conv1 = tf.expand_dims(conv1, axis=-2)
    # Convolutional capsules; no routing, since the units of the previous layer are one-dimensional
    primarycaps = caps.conv2d(conv1,
                              32,
                              8, [9, 9],
                              strides=(2, 2),
                              name="PrimaryCaps")
    primarycaps = tf.reshape(
        primarycaps,
        [-1, primarycaps.shape[1].value * primarycaps.shape[2].value * 32, 8])
    # Fully connected capsules with routing by agreement
    digitcaps = caps.dense(primarycaps,
                           10,
                           16,
                           iter_routing=iter_routing,
                           learn_coupling=learn_coupling,
                           mapfn_parallel_iterations=mapfn_parallel_iterations,
                           name="DigitCaps")
    # The length of the capsule activation vectors encodes the probability of an entity being present
    lengths = tf.sqrt(tf.reduce_sum(tf.square(digitcaps), axis=2) + epsilon,
                      name="Lengths")

    # Predictions (for PREDICT mode)
    predictions = {
        # Generate predictions (for PREDICT and EVAL mode)
        "classes": tf.argmax(lengths, axis=1),
        "probabilities": tf.nn.softmax(lengths, name="Softmax")
    }

    if regularization:
        masked_digitcaps_pred = mask_one(digitcaps,
                                         lengths,
                                         is_predicting=True)
        with tf.variable_scope(tf.get_variable_scope()):
            reconstruction_pred = decoder_nn(masked_digitcaps_pred)
        predictions["reconstruction"] = reconstruction_pred

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Calculate Loss (for both TRAIN and EVAL modes)
    onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10)
    m_loss = margin_loss(onehot_labels, lengths)
    train_log_dict["margin loss"] = m_loss
    tf.summary.scalar("margin_loss", m_loss)
    if regularization:
        masked_digitcaps = mask_one(digitcaps, onehot_labels)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            reconstruction = decoder_nn(masked_digitcaps)
        rec_loss = reconstruction_loss(input_layer, reconstruction)
        train_log_dict["reconstruction loss"] = rec_loss
        tf.summary.scalar("reconstruction_loss", rec_loss)
        loss = m_loss + lambda_reg * rec_loss
    else:
        loss = m_loss

    # Configure the Training Op (for TRAIN mode)
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Logging hook
        train_log_dict["accuracy"] = tf.metrics.accuracy(
            labels=labels, predictions=predictions["classes"])[1]
        logging_hook = tf.train.LoggingTensorHook(
            train_log_dict, every_n_iter=config.save_summary_steps)
        # Summary hook
        summary_hook = tf.train.SummarySaverHook(
            save_steps=config.save_summary_steps,
            output_dir=model_dir,
            summary_op=tf.summary.merge_all())
        hooks += [logging_hook, summary_hook]
        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(start_lr, global_step,
                                                   decay_steps, decay_rate)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss=loss,
                                      global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          training_hooks=hooks)

    # Add evaluation metrics (for EVAL mode)
    eval_metric_ops = {
        "accuracy":
        tf.metrics.accuracy(labels=labels, predictions=predictions["classes"])
    }
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      eval_metric_ops=eval_metric_ops)
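mask_one, used for the reconstruction regularizer above, is not shown. A sketch of the usual capsule masking trick: keep only one 16-D capsule vector, the ground-truth class during training and the most probable class at prediction time (an assumption about the helper, not its actual source):

import tensorflow as tf

def mask_one_sketch(digitcaps, labels_or_lengths, is_predicting=False):
    # digitcaps: [batch, 10, 16]
    if is_predicting:
        mask = tf.one_hot(tf.argmax(labels_or_lengths, axis=1), depth=10)
    else:
        mask = labels_or_lengths  # already one-hot labels
    masked = digitcaps * tf.expand_dims(mask, axis=-1)
    return tf.reshape(masked, [-1, 10 * 16])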
Example #8
    def forward(self, x, global_step, args):

        if args.fp16:
            x = x.half()

        metrics = {}

        alpha_i = utils.kl_balancer_coeff(
            num_scales=self.num_latent_scales,
            groups_per_scale=self.groups_per_scale,
            fun='square')

        x_in = self.preprocess(x)
        if args.fp16:
            x_in = x_in.half()
        s = self.stem(x_in)

        # perform pre-processing
        for cell in self.pre_process:
            s = cell(s)

        # run the main encoder tower
        combiner_cells_enc = []
        combiner_cells_s = []
        for cell in self.enc_tower:
            if cell.cell_type == 'combiner_enc':
                combiner_cells_enc.append(cell)
                combiner_cells_s.append(s)
            else:
                s = cell(s)

        # reverse combiner cells and their input for decoder
        combiner_cells_enc.reverse()
        combiner_cells_s.reverse()

        idx_dec = 0
        ftr = self.enc0(s)  # this reduces the channel dimension
        param0 = self.enc_sampler[idx_dec](ftr)
        mu_q, log_sig_q = torch.chunk(param0, 2, dim=1)
        dist = Normal(mu_q, log_sig_q)  # for the first approx. posterior
        z, _ = dist.sample()
        log_q_conv = dist.log_p(z)

        # apply normalizing flows
        nf_offset = 0
        for n in range(self.num_flows):
            z, log_det = self.nf_cells[n](z, ftr)
            log_q_conv -= log_det
        nf_offset += self.num_flows
        all_q = [dist]
        all_log_q = [log_q_conv]

        # To make sure we do not pass any deterministic features from x to decoder.
        s = 0

        # prior for z0
        dist = Normal(mu=torch.zeros_like(z), log_sigma=torch.zeros_like(z))
        log_p_conv = dist.log_p(z)
        all_p = [dist]
        all_log_p = [log_p_conv]

        idx_dec = 0
        s = self.prior_ftr0.unsqueeze(0)
        batch_size = z.size(0)
        s = s.expand(batch_size, -1, -1)
        for cell in self.dec_tower:
            if cell.cell_type == 'combiner_dec':
                if idx_dec > 0:
                    # form prior
                    param = self.dec_sampler[idx_dec - 1](s)
                    mu_p, log_sig_p = torch.chunk(param, 2, dim=1)

                    # form encoder
                    ftr = combiner_cells_enc[idx_dec - 1](
                        combiner_cells_s[idx_dec - 1], s)
                    param = self.enc_sampler[idx_dec](ftr)
                    mu_q, log_sig_q = torch.chunk(param, 2, dim=1)
                    dist = Normal(mu_p + mu_q, log_sig_p +
                                  log_sig_q) if self.res_dist else Normal(
                                      mu_q, log_sig_q)
                    z, _ = dist.sample()
                    log_q_conv = dist.log_p(z)
                    # apply NF
                    for n in range(self.num_flows):
                        z, log_det = self.nf_cells[nf_offset + n](z, ftr)
                        log_q_conv -= log_det
                    nf_offset += self.num_flows
                    all_log_q.append(log_q_conv)
                    all_q.append(dist)

                    # evaluate log_p(z)
                    dist = Normal(mu_p, log_sig_p)
                    log_p_conv = dist.log_p(z)
                    all_p.append(dist)
                    all_log_p.append(log_p_conv)

                # 'combiner_dec'
                s = cell(s, z)
                idx_dec += 1
            else:
                s = cell(s)

        if self.vanilla_vae:
            s = self.stem_decoder(z)

        for cell in self.post_process:
            s = cell(s)

        logits = self.image_conditional(s)

        # compute kl
        kl_all = []
        kl_diag = []
        log_p, log_q = 0., 0.
        for q, p, log_q_conv, log_p_conv in zip(all_q, all_p, all_log_q,
                                                all_log_p):
            if self.with_nf:
                kl_per_var = log_q_conv - log_p_conv
            else:
                kl_per_var = q.kl(p)

            kl_diag.append(torch.mean(torch.sum(kl_per_var, dim=2), dim=0))
            kl_all.append(torch.sum(kl_per_var, dim=[1, 2]))
            log_q += torch.sum(log_q_conv, dim=[1, 2])
            log_p += torch.sum(log_p_conv, dim=[1, 2])

        output = self.decoder_output(logits)
        """
        def _spectral_loss(x_target, x_out, args):
            if hps.use_nonrelative_specloss:
                sl = spectral_loss(x_target, x_out, args) / args.bandwidth['spec']
            else:
                sl = spectral_convergence(x_target, x_out, args)
            sl = t.mean(sl)
            return sl

        def _multispectral_loss(x_target, x_out, args):
            sl = multispectral_loss(x_target, x_out, args) / args.bandwidth['spec']
            sl = t.mean(sl)
            return sl
        """

        kl_coeff = utils.kl_coeff(global_step,
                                  args.kl_anneal_portion * args.num_total_iter,
                                  args.kl_const_portion * args.num_total_iter,
                                  args.kl_const_coeff)
        recon_loss = utils.reconstruction_loss(output,
                                               x,
                                               crop=self.crop_output)
        balanced_kl, kl_coeffs, kl_vals = utils.kl_balancer(kl_all,
                                                            kl_coeff,
                                                            kl_balance=True,
                                                            alpha_i=alpha_i)

        nelbo_batch = recon_loss + balanced_kl
        loss = torch.mean(nelbo_batch)

        bn_loss = self.batchnorm_loss()
        norm_loss = self.spectral_norm_parallel()

        #x_target = audio_postprocess(x.float(), args)
        #x_out = audio_postprocess(output.sample(), args)

        #spec_loss = _spectral_loss(x_target, x_out, args)
        #multispec_loss = _multispectral_loss(x_target, x_out, args)

        if args.weight_decay_norm_anneal:
            assert args.weight_decay_norm_init > 0 and args.weight_decay_norm > 0, 'init and final wdn should be positive.'
            wdn_coeff = (1. - kl_coeff) * np.log(
                args.weight_decay_norm_init) + kl_coeff * np.log(
                    args.weight_decay_norm)
            wdn_coeff = np.exp(wdn_coeff)
        else:
            wdn_coeff = args.weight_decay_norm

        loss += bn_loss * wdn_coeff + norm_loss * wdn_coeff

        metrics.update(
            dict(recon_loss=recon_loss,
                 bn_loss=bn_loss,
                 norm_loss=norm_loss,
                 wdn_coeff=wdn_coeff,
                 kl_all=torch.mean(sum(kl_all)),
                 kl_coeff=kl_coeff))

        for key, val in metrics.items():
            metrics[key] = val.detach()

        return output, loss, metrics
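Note that Normal here is not torch.distributions.Normal: it is constructed from a mean and a log-sigma, sample() returns a pair, and it exposes log_p and kl, as the call sites above show. A self-contained sketch consistent with those call sites (an assumption about the class, not its actual source):

import math
import torch

class NormalSketch:
    def __init__(self, mu, log_sigma):
        self.mu = mu
        self.sigma = torch.exp(log_sigma)

    def sample(self):
        eps = torch.randn_like(self.mu)
        return self.mu + self.sigma * eps, eps  # (sample, noise) pair

    def log_p(self, z):
        # Per-dimension Gaussian log-density.
        return (-0.5 * ((z - self.mu) / self.sigma) ** 2
                - torch.log(self.sigma) - 0.5 * math.log(2 * math.pi))

    def kl(self, p):
        # KL(self || p) between diagonal Gaussians, per dimension.
        term1 = (self.mu - p.mu) / p.sigma
        term2 = self.sigma / p.sigma
        return 0.5 * (term1 ** 2 + term2 ** 2) - 0.5 - torch.log(term2)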