Example #1
import numpy as np
import pytest
import tensorflow as tf

import batches  # module under test; assumed to provide tf_batch_to_canvas


def test_tf_batch_to_canvas():
    tf.enable_eager_execution()  # TF1-style eager execution for the asserts
    x = np.ones((9, 100, 100, 3))
    x = tf.convert_to_tensor(x)
    canvas = batches.tf_batch_to_canvas(x)
    assert canvas.shape == (1, 300, 300, 3)

    canvas = batches.tf_batch_to_canvas(x, cols=5)
    assert canvas.shape == (1, 200, 500, 3)

    canvas = batches.tf_batch_to_canvas(x, cols=1)
    assert canvas.shape == (1, 900, 100, 3)

    canvas = batches.tf_batch_to_canvas(x, cols=0)
    assert canvas.shape == (1, 900, 100, 3)

    canvas = batches.tf_batch_to_canvas(x, cols=None)
    assert canvas.shape == (1, 300, 300, 3)

    x = np.ones((9, 100, 100, 1))
    x = tf.convert_to_tensor(x)
    canvas = batches.tf_batch_to_canvas(x)
    assert canvas.shape == (1, 300, 300, 1)

    x = np.ones((9, 100, 100, 1, 1))
    x = tf.convert_to_tensor(x)
    with pytest.raises(ValueError,
                       match="input tensor has more than 4 dimensions."):
        canvas = batches.tf_batch_to_canvas(x)
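
For reference, a minimal sketch consistent with the assertions above
(hypothetical; the real batches.tf_batch_to_canvas may differ in padding
and layout details):

def tf_batch_to_canvas_sketch(x, cols=None):
    """Tile a [N, H, W, C] batch into one [1, rows*H, cols*W, C] canvas."""
    if len(x.shape) > 4:
        raise ValueError("input tensor has more than 4 dimensions.")
    N, H, W, C = x.shape.as_list()
    if cols is None:
        cols = int(np.ceil(np.sqrt(N)))  # default: roughly square grid
    cols = max(cols, 1)  # cols=0 behaves like a single column
    rows = int(np.ceil(N / cols))
    n_pad = rows * cols - N  # pad the batch with blank images
    x = tf.concat([x, tf.zeros([n_pad, H, W, C], dtype=x.dtype)], axis=0)
    x = tf.reshape(x, [rows, cols, H, W, C])
    x = tf.transpose(x, [0, 2, 1, 3, 4])  # [rows, H, cols, W, C]
    return tf.reshape(x, [1, rows * H, cols * W, C])
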
    def make_loss_ops(self):
        # perceptual network
        losses = dict()
        with tf.variable_scope("VGG19__noinit__"):
            gram_weight = self.config.get("gram_weight", 0.0)
            print("GRAM: {}".format(gram_weight))
            self.vgg19 = vgg19 = deeploss.VGG19Features(
                self.session, default_gram=gram_weight, original_scale=True)
        dim = np.prod(self.model.augmented_views[0].shape.as_list()[1:])
        auto_rec_loss = (1e-3 * 0.5 * dim * vgg19.make_loss_op(
            self.model.augmented_views[2], self.model._generated))

        with tf.variable_scope("prior_gmrf_weight"):
            prior_gmrf_weight = make_var(
                step=self.global_step,
                var_type=self.config["prior_gmrf_weight"]["var_type"],
                options=self.config["prior_gmrf_weight"]["options"],
            )
        with tf.variable_scope("prior_mumford_sha_weight"):
            prior_mumford_sha_weight = make_var(
                step=self.global_step,
                var_type=self.config["prior_mumford_sha_weight"]["var_type"],
                options=self.config["prior_mumford_sha_weight"]["options"],
            )
        with tf.variable_scope("kl_weight"):
            kl_weight = make_linear_var(step=self.global_step,
                                        **self.config["kl_weight"])
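        # The three weights above are step-dependent scalar schedules:
        # make_var presumably dispatches on the configured "var_type",
        # make_linear_var anneals linearly over the global step (a hedged
        # sketch of both helpers appears at the end of this example).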

        mumford_sha_alpha = make_var(
            step=self.global_step,
            var_type=self.config["mumford_sha_alpha"]["var_type"],
            options=self.config["mumford_sha_alpha"]["options"],
        )
        mumford_sha_lambda = make_var(
            step=self.global_step,
            var_type=self.config["mumford_sha_lambda"]["var_type"],
            options=self.config["mumford_sha_lambda"]["options"],
        )
        self.log_ops["mumford_sha_lambda"] = mumford_sha_lambda
        self.log_ops["mumford_sha_alpha"] = mumford_sha_alpha

        # smoothness prior (phase energy)
        prior_gmrf = self.smoothness_prior_simple_gradient()
        prior_gmrf_weighted = prior_gmrf_weight * prior_gmrf

        self.log_ops["prior_gmrf"] = prior_gmrf
        self.log_ops["prior_gmrf_weight"] = prior_gmrf_weight
        self.log_ops["prior_gmrf_weighted"] = prior_gmrf_weighted

        mask0_kl = tf.reduce_sum([
            categorical_kl(m)
            for m in [self.model.m0_sample, self.model.m1_sample]
        ])
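        # categorical_kl above is assumed to return the KL divergence of the
        # per-pixel categorical part distribution from a uniform prior.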
        mask0_kl_weighted = kl_weight * mask0_kl  # TODO: categorical KL
        self.log_ops["mask0_kl_weight"] = kl_weight
        self.log_ops["mask0_kl"] = mask0_kl
        self.log_ops["mask0_kl_weighted"] = mask0_kl_weighted

        log_probs = self.model.m0_logits
        p_labels = nn.softmax(log_probs, spatial=False)
        labels = nn.hard_max(p_labels, 3)
        labels = nn.straight_through_estimator(labels, p_labels)
        entropy_func = self.config.get("entropy_func", "cross_entropy")
        if entropy_func == "cross_entropy":
            weakly_superv_loss_p = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=labels, logits=log_probs, dim=3)
        elif entropy_func == "entropy":
            weakly_superv_loss_p = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=p_labels, logits=log_probs, dim=3)
        else:
            raise ValueError("unknown entropy_func: {}".format(entropy_func))
        weakly_superv_loss_p = tf.reduce_mean(weakly_superv_loss_p)

        gamma = self.config.get("gamma", 3.0)
        corrected_logits_1 = nn.softmax(self.model.m1_logits,
                                        spatial=False) * gamma
        corrected_logits_1 = nn.softmax(corrected_logits_1, spatial=True)
        corrected_logits_1 *= tf.stop_gradient(
            (1 - self.model.m1_sample_hard_mask))

        N, h, w, P = self.model.m1_sample.shape.as_list()
        _, sigma = nn.probs_to_mu_sigma(corrected_logits_1, tf.ones((N, P)))
        # for i in range(sigma.shape.as_list()[1]):
        #     self.log_ops["sigma1_{:02d}".format(i)] = sigma[0, i, 0, 0]
        # for i in range(sigma.shape.as_list()[1]):
        #     self.log_ops["sigma2_{:02d}".format(i)] = sigma[0, i, 1, 1]

        # sigma_filter = tf.reduce_sum(
        #     self.model.m1_sample_hard, axis=(1, 2), keep_dims=True
        # )
        # sigma_filter = tf.to_float(sigma_filter > 0.0)
        # sigma_filter = tf.transpose(sigma_filter, perm=[0, 3, 1, 2])
        # for i in range(sigma.shape.as_list()[1]):
        #     self.log_ops["sigma_active_{:02d}".format(i)] = sigma_filter[0, i, 0, 0]
        # sigma *= tf.stop_gradient(sigma_filter)
        # for i in range(sigma.shape.as_list()[1]):
        #     self.log_ops["sigma2_filtered_{:02d}".format(i)] = sigma[0, i, 1, 1]
        variances = tf.reduce_mean(
            tf.reduce_sum(sigma[:, :, 0, 0] + sigma[:, :, 1, 1], axis=1))

        with tf.variable_scope("variance_weight"):
            variance_weight = make_var(
                step=self.global_step,
                var_type=self.config["variance_weight"]["var_type"],
                options=self.config["variance_weight"]["options"],
            )

        variance_weighted = variance_weight * variances
        self.log_ops["variance_loss_weighted"] = variance_weighted
        self.log_ops["variance_loss"] = variances
        self.log_ops["variance_weight"] = variance_weight

        with tf.variable_scope("weakly_superv_loss_weight_p"):
            weakly_superv_loss_weight_p = make_var(
                step=self.global_step,
                var_type=self.config["weakly_superv_loss_weight_p"]["var_type"],
                options=self.config["weakly_superv_loss_weight_p"]["options"],
            )

        weakly_superv_loss_p_weighted = (weakly_superv_loss_p *
                                         weakly_superv_loss_weight_p)

        self.log_ops[
            "weakly_superv_loss_weight_p"] = weakly_superv_loss_weight_p
        self.log_ops["weakly_superv_loss_p"] = weakly_superv_loss_p
        self.log_ops[
            "weakly_superv_loss_p_weighted"] = weakly_superv_loss_p_weighted

        # per submodule loss

        # delta
        losses["encoder_0"] = auto_rec_loss
        losses["encoder_1"] = auto_rec_loss
        # decoder
        losses["decoder_delta"] = auto_rec_loss

        r, smoothness_cost, contour_cost = nn.mumford_shah(
            self.model.m0_sample, 1.0, 1.0e-2)
        smoothness_cost = tf.reduce_mean(
            tf.reduce_sum(
                (tf.reduce_sum(smoothness_cost, axis=(1, 2),
                               keep_dims=True))**2,
                axis=(3, ),
            ))
        contour_cost = tf.reduce_mean(
            tf.reduce_sum(
                (tf.reduce_sum(contour_cost, axis=(1, 2), keep_dims=True))**2,
                axis=(3, ),
            ))
        p_mumford_sha = prior_mumford_sha_weight * tf.reduce_mean(
            tf.reduce_sum((tf.reduce_sum(r, axis=(1, 2), keep_dims=True))**2,
                          axis=(3, )))
        area_cost = 1.0e-12 * tf.reduce_mean(
            tf.reduce_sum(
                (tf.reduce_sum(
                    self.model.m0_sample, axis=(1, 2), keep_dims=True))**2,
                axis=(3, ),
            ))

        patch_loss = self.model.m0_sample_hard * tf.stop_gradient(
            (1 - self.model.m0_sample_hard_mask))
        patch_loss = tf.reduce_sum(patch_loss, axis=[1, 2, 3])
        patch_loss = tf.reduce_mean(patch_loss, axis=[0])
        with tf.variable_scope("patch_loss_weight"):
            patch_loss_weight = make_var(
                step=self.global_step,
                var_type=self.config["patch_loss_weight"]["var_type"],
                options=self.config["patch_loss_weight"]["options"],
            )
        patch_loss_weighted = patch_loss * patch_loss_weight
        self.log_ops["patch_loss"] = patch_loss
        self.log_ops["patch_loss_weight"] = patch_loss_weight
        self.log_ops["patch_loss_weighted"] = patch_loss_weighted

        if not self.config.get("pretrain", False):
            losses["decoder_visualize"] = (auto_rec_loss +
                                           prior_gmrf_weighted +
                                           mask0_kl_weighted +
                                           weakly_superv_loss_p_weighted +
                                           variance_weighted + p_mumford_sha +
                                           area_cost + patch_loss_weighted)
        else:
            losses["decoder_visualize"] = auto_rec_loss

        # mi estimators
        loss_dis0 = 0.5 * (logit_loss(self.model.logit_joint0, real=True) +
                           logit_loss(self.model.logit_marginal0, real=False))
        losses["mi0_discriminator"] = loss_dis0
        loss_dis1 = 0.5 * (logit_loss(self.model.logit_joint1, real=True) +
                           logit_loss(self.model.logit_marginal1, real=False))
        losses["mi1_discriminator"] = loss_dis1
        # estimator
        losses["mi_estimator"] = 0.5 * (
            logit_loss(self.model.mi_logit_joint, real=True) +
            logit_loss(self.model.mi_logit_marginal, real=False))
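        # logit_loss above is assumed to be the standard sigmoid
        # cross-entropy GAN loss, roughly:
        #   labels = tf.ones_like(logits) if real else tf.zeros_like(logits)
        #   tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        #       labels=labels, logits=logits))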

        # accuracies of discriminators
        def acc(ljoint, lmarg):
            correct_joint = tf.reduce_sum(tf.cast(ljoint > 0.0, tf.int32))
            correct_marginal = tf.reduce_sum(tf.cast(lmarg < 0.0, tf.int32))
            accuracy = (correct_joint +
                        correct_marginal) / (2 * tf.shape(ljoint)[0])
            return accuracy

        dis0_accuracy = acc(self.model.logit_joint0,
                            self.model.logit_marginal0)
        dis1_accuracy = acc(self.model.logit_joint1,
                            self.model.logit_marginal1)
        est_accuracy = acc(self.model.mi_logit_joint,
                           self.model.mi_logit_marginal)

        # averages
        avg_acc0 = make_ema(0.5, dis0_accuracy, self.update_ops)
        avg_acc1 = make_ema(0.5, dis1_accuracy, self.update_ops)
        avg_acc_error = make_ema(0.0, dis1_accuracy - dis0_accuracy,
                                 self.update_ops)
        self.log_ops["avg_acc_error"] = avg_acc_error
        avg_loss_dis0 = make_ema(1.0, loss_dis0, self.update_ops)
        avg_loss_dis1 = make_ema(1.0, loss_dis1, self.update_ops)

        # Parameters
        MI_TARGET = self.config["MI"].get("mi_target", 0.125)
        MI_SLACK = self.config["MI"].get("mi_slack", 0.05)

        LOO_TOL = 0.025
        LON_LR = 0.05
        LON_ADAPTIVE = False

        LOA_INIT = self.config["MI"].get("loa_init", 0.0)
        LOA_LR = self.config["MI"].get("loa_lr", 4.0)
        LOA_ADAPTIVE = self.config["MI"].get("loa_adaptive", True)

        LOR_INIT = self.config["MI"].get("lor_init", 7.5)
        LOR_LR = self.config["MI"].get("lor_lr", 0.05)
        LOR_MIN = self.config["MI"].get("lor_min", 1.0)
        LOR_MAX = self.config["MI"].get("lor_max", 7.5)
        LOR_ADAPTIVE = self.config["MI"].get("lor_adaptive", True)
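
        # Naming used below: LOO = level of overpowering, LON = level of
        # noise, LOA = level of attack (an estimate of the Lagrange
        # multiplier for the MI constraint), LOR = (log) level of
        # regularization.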

        # delta mi minimization
        mim_constraint = logit_constraint(self.model.logit_joint0, real=False)
        independent_mim_constraint = logit_constraint(self.model.logit_joint1,
                                                      real=False)

        # level of overpowering
        avg_mim_constraint = tf.maximum(
            0.0, make_ema(0.0, mim_constraint, self.update_ops))
        avg_independent_mim_constraint = tf.maximum(
            0.0, make_ema(0.0, independent_mim_constraint, self.update_ops))
        self.log_ops["avg_mim"] = avg_mim_constraint
        self.log_ops["avg_independent_mim"] = avg_independent_mim_constraint
        loo = (avg_independent_mim_constraint -
               avg_mim_constraint) / (avg_independent_mim_constraint + 1e-6)
        loo = tf.clip_by_value(loo, 0.0, 1.0)
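        # loo in [0, 1]: relative gap between the independent and the joint
        # constraint violations; near 1 when the joint MI constraint is
        # satisfied much better than the independent one.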
        self.log_ops["loo"] = loo

        # level of noise
        lon_gain = -loo + LOO_TOL
        self.log_ops["lon_gain"] = lon_gain
        lon_lr = LON_LR
        new_lon = tf.clip_by_value(self.model.lon + lon_lr * lon_gain, 0.0,
                                   1.0)
        if LON_ADAPTIVE:
            update_lon = tf.assign(self.model.lon, new_lon)
            self.update_ops.append(update_lon)
        self.log_ops["model_lon"] = self.model.lon

        # OPTION
        adversarial_regularization = self.config.get(
            "adversarial_regularization", True)
        if adversarial_regularization:
            # level of attack - estimate of lagrange multiplier for mi constraint
            initial_loa = LOA_INIT
            loa = tf.Variable(initial_loa, dtype=tf.float32, trainable=False)
            loa_gain = mim_constraint - (1.0 - MI_SLACK) * MI_TARGET
            loa_lr = tf.constant(LOA_LR)
            new_loa = loa + loa_lr * loa_gain
            new_loa = tf.maximum(0.0, new_loa)

            if LOA_ADAPTIVE:
                update_loa = tf.assign(loa, new_loa)
                self.update_ops.append(update_loa)

                adversarial_active = tf.stop_gradient(
                    tf.to_float(loa_lr * loa_gain >= -loa))
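                # Augmented-Lagrangian-style gating: the gradient of
                # loa * g + loa_lr / 2 * g**2 w.r.t. the constraint g is the
                # updated multiplier loa + loa_lr * g, and adversarial_active
                # switches the penalty off once that update would clip to 0.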
                adversarial_weighted_loss = adversarial_active * (
                    loa * loa_gain + loa_lr / 2.0 * tf.square(loa_gain))
            else:
                adversarial_weighted_loss = loa * loa_gain

            losses["encoder_0"] += adversarial_weighted_loss

        # use lor
        # OPTION
        variational_regularization = self.config.get(
            "variational_regularization", True)
        if variational_regularization:
            assert self.model.stochastic_encoder_0

            bottleneck_loss = self.model.z_00_distribution.kl()

            # (log) level of regularization
            initial_lor = LOR_INIT
            lor = tf.Variable(initial_lor, dtype=tf.float32, trainable=False)
            lor_gain = independent_mim_constraint - MI_TARGET
            lor_lr = LOR_LR
            new_lor = tf.clip_by_value(lor + lor_lr * lor_gain, LOR_MIN,
                                       LOR_MAX)
            if LOR_ADAPTIVE:
                update_lor = tf.assign(lor, new_lor)
                self.update_ops.append(update_lor)
            beta_0 = self.config.get("beta_0", 1.0)
            bottleneck_weighted_loss = (beta_0 * tf.exp(lor) *
                                        bottleneck_loss)  # TODO: hardcoded
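            # lor lives in log-space: the effective KL weight
            # beta_0 * tf.exp(lor) adapts multiplicatively within
            # beta_0 * [exp(LOR_MIN), exp(LOR_MAX)].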
            losses["encoder_0"] += bottleneck_weighted_loss
        else:
            assert not self.model.stochastic_encoder_0

        # logging
        for k in losses:
            self.log_ops["loss_{}".format(k)] = losses[k]
        self.log_ops["dis0_accuracy"] = dis0_accuracy
        self.log_ops["dis1_accuracy"] = dis1_accuracy
        self.log_ops["avg_dis0_accuracy"] = avg_acc0
        self.log_ops["avg_dis1_accuracy"] = avg_acc1
        self.log_ops["avg_loss_dis0"] = avg_loss_dis0
        self.log_ops["avg_loss_dis1"] = avg_loss_dis1
        self.log_ops["est_accuracy"] = est_accuracy

        self.log_ops["mi_constraint"] = mim_constraint
        self.log_ops["independent_mi_constraint"] = independent_mim_constraint

        if adversarial_regularization:
            self.log_ops["adversarial_weight"] = loa
            self.log_ops["adversarial_constraint"] = mim_constraint
            self.log_ops[
                "adversarial_weighted_loss"] = adversarial_weighted_loss

            self.log_ops["loa"] = loa
            self.log_ops["loa_gain"] = loa_gain

        if variational_regularization:
            self.log_ops["bottleneck_weight"] = lor
            self.log_ops["bottleneck_loss"] = bottleneck_loss
            self.log_ops["bottleneck_weighted_loss"] = bottleneck_weighted_loss

            self.log_ops["lor"] = lor
            beta_0 = self.config.get("beta_0", 1.0)
            self.log_ops["explor"] = beta_0 * tf.exp(lor)
            self.log_ops["lor_gain"] = lor_gain

        visualize_mask("out_parts_soft",
                       tf.to_float(self.model.out_parts_soft), self.img_ops,
                       True)
        visualize_mask("m0_sample", self.model.m0_sample, self.img_ops, True)

        cols = self.model.encoding_mask.shape.as_list()[0]
        encoding_masks = tf.concat([self.model.encoding_mask], 0)
        encoding_masks = tf_batches.tf_batch_to_canvas(encoding_masks, cols)
        visualize_mask("encoding_masks", tf.to_float(encoding_masks),
                       self.img_ops, False)

        decoding_masks = tf.concat([self.model.decoding_mask], 0)
        decoding_masks = tf_batches.tf_batch_to_canvas(decoding_masks, cols)
        visualize_mask("decoding_masks", tf.to_float(decoding_masks),
                       self.img_ops, False)

        masks = nn.take_only_first_item_in_the_batch(self.model.decoding_mask)
        masks = tf.transpose(masks, perm=[3, 1, 2, 0])
        self.img_ops["masks"] = masks

        correspondence0 = tf.expand_dims(
            self.model.decoding_mask, axis=-1) * tf.expand_dims(
                self.model.augmented_views[0], axis=3)
        correspondence0 = tf.transpose(correspondence0, [3, 0, 1, 2, 4])
        correspondence1 = tf.expand_dims(
            self.model.encoding_mask, axis=-1) * tf.expand_dims(
                self.model.augmented_views[1], axis=3)
        correspondence1 = tf.transpose(correspondence1, [3, 0, 1, 2, 4])
        correspondence = tf.concat([correspondence0, correspondence1], axis=1)
        N_PARTS, _, H, W, _ = correspondence.shape.as_list()

        def make_grid(X):
            X = tf.squeeze(X)
            return tf_batches.tf_batch_to_canvas(X)

        correspondence = list(
            map(make_grid, tf.split(correspondence, N_PARTS, 0)))
        correspondence = tf_batches.tf_batch_to_canvas(
            tf.concat(correspondence, 0), 5)
        self.img_ops.update({"assigned_parts": correspondence})

        p_levels = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9]
        m = nn.take_only_first_item_in_the_batch(self.model.m0_sample)

        def levels(m, i):
            return tf.concat([tf.to_float(m[..., i] > p) for p in p_levels],
                             axis=0)

        level_sets = tf.concat([levels(m, i) for i in range(m.shape[3])], 0)
        level_sets = tf.expand_dims(level_sets, -1)
        level_sets = tf_batches.tf_batch_to_canvas(level_sets, len(p_levels))
        img_title = "m0_sample_levels" + "-{}" * len(p_levels)
        img_title = img_title.format(*p_levels).replace(".", "_")

        ratios = [1.0e-3, 5 * 1.0e-3, 1.0e-2, 5 * 1.0e-2]
        mumford_sha_sets = tf.concat([nn.edge_set(m, 1, r) for r in ratios],
                                     axis=0)  # 5, H, W, 25
        mumford_sha_sets = tf.transpose(mumford_sha_sets, perm=[3, 0, 1, 2])
        a, b, h, w = mumford_sha_sets.shape.as_list()
        mumford_sha_sets = tf.reshape(mumford_sha_sets, (a * b, h, w))
        mumford_sha_sets = tf.expand_dims(mumford_sha_sets, -1)
        mumford_sha_sets = tf_batches.tf_batch_to_canvas(
            mumford_sha_sets, len(ratios))

        p_heatmap = tf.squeeze(nn.colorize(m))  # [H, W, 25, 3]
        p_heatmap = tf.transpose(p_heatmap, perm=[2, 0, 1, 3])

        self.img_ops.update({
            "mumford_sha_edges": mumford_sha_sets,
            "p_heatmap": p_heatmap,
            img_title: level_sets,
            "view0": self.model.inputs["view0"],
            "view1": self.model.inputs["view1"],
            "view0_target": self.model.inputs["view0_target"],
            "cross": self.model._cross_generated,
            "generated": self.model._generated,
        })
        if self.model.use_tps:
            self.img_ops.update({
                "tps_view0":
                self.model.outputs["tps_view0"],
                "tps_view1":
                self.model.outputs["tps_view1"],
                "tps_view0_target":
                self.model.outputs["tps_view0_target"],
            })

        self.log_ops["zr_mumford_sha"] = p_mumford_sha
        self.log_ops["z_mumford_sha_smoothness_cost"] = smoothness_cost
        self.log_ops["z_mumford_sha_contour_cost"] = contour_cost
        self.log_ops["z_area_cost"] = area_cost

        self.log_ops["prior_mumford_sha_weight"] = prior_mumford_sha_weight

        for k in self.config.get("fix_weights", []):
            if k in losses:
                # losses[k] = tf.no_op()
                del losses[k]
        return losses
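
Both make_loss_ops variants rely on a make_ema helper for running averages
of discriminator statistics. A minimal sketch of its assumed behavior
(non-trainable EMA variable plus a registered update op; the decay value is
a guess):

def make_ema(init_value, value, update_ops, decay=0.99):
    """Track an exponential moving average of `value` (hypothetical)."""
    avg = tf.Variable(init_value, dtype=tf.float32, trainable=False)
    update = tf.assign(avg, decay * avg + (1.0 - decay) * value)
    update_ops.append(update)  # caller runs these alongside the train step
    return avg
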
    def make_loss_ops(self):
        # perceptual network
        with tf.variable_scope("VGG19__noinit__"):
            gram_weight = self.config.get("gram_weight", 0.0)
            print("GRAM: {}".format(gram_weight))
            self.vgg19 = vgg19 = deeploss.VGG19Features(
                self.session, default_gram=gram_weight, original_scale=True)
        dim = np.prod(self.model.inputs["view0"].shape.as_list()[1:])
        auto_rec_loss = (1e-3 * 0.5 * dim * vgg19.make_loss_op(
            self.model.inputs["view0"], self.model._generated))

        x0 = nn.take_only_first_item_in_the_batch(self.model.inputs["view0"])
        x1 = nn.take_only_first_item_in_the_batch(self.model.inputs["view1"])
        global_auto_rec_loss = (
            1e-3 * 0.5 * dim *
            vgg19.make_loss_op(x0, self.model.global_generated))

        alpha_auto_rec_loss = (
            1e-3 * 0.5 * dim *
            vgg19.make_loss_op(x1, self.model.alpha_generated))

        pi_auto_rec_loss = (1e-3 * 0.5 * dim *
                            vgg19.make_loss_op(x0, self.model.pi_generated))

        with tf.variable_scope("prior_gmrf_weight"):
            prior_gmrf_weight = make_var(
                step=self.global_step,
                var_type=self.config["prior_gmrf_weight"]["var_type"],
                options=self.config["prior_gmrf_weight"]["options"],
            )
        with tf.variable_scope("prior_mumford_sha_weight"):
            prior_mumford_sha_weight = make_var(
                step=self.global_step,
                var_type=self.config["prior_mumford_sha_weight"]["var_type"],
                options=self.config["prior_mumford_sha_weight"]["options"],
            )
        with tf.variable_scope("kl_weight"):
            kl_weight = make_linear_var(step=self.global_step,
                                        **self.config["kl_weight"])

        mumford_sha_alpha = make_var(
            step=self.global_step,
            var_type=self.config["mumford_sha_alpha"]["var_type"],
            options=self.config["mumford_sha_alpha"]["options"],
        )
        mumford_sha_lambda = make_var(
            step=self.global_step,
            var_type=self.config["mumford_sha_lambda"]["var_type"],
            options=self.config["mumford_sha_lambda"]["options"],
        )
        self.log_ops["mumford_sha_lambda"] = mumford_sha_lambda
        self.log_ops["mumford_sha_alpha"] = mumford_sha_alpha

        # smoothness prior (phase energy)
        prior_gmrf = self.smoothness_prior_simple_gradient()
        prior_gmrf_weighted = prior_gmrf_weight * prior_gmrf
        mumford_sha_prior = self.smoothness_prior_mumfordsha(
            mumford_sha_alpha, mumford_sha_lambda)
        prior_mumford_sha_weighted = mumford_sha_prior * prior_mumford_sha_weight

        self.log_ops["prior_gmrf"] = prior_gmrf
        self.log_ops["prior_gmrf_weight"] = prior_gmrf_weight
        self.log_ops["prior_gmrf_weighted"] = prior_gmrf_weighted

        self.log_ops["prior_mumford_sha"] = mumford_sha_prior
        self.log_ops["prior_mumford_sha_weight"] = prior_mumford_sha_weight
        self.log_ops["prior_mumford_sha_weighted"] = prior_mumford_sha_weighted

        mask0_kl = tf.reduce_sum([
            categorical_kl(m)
            for m in [self.model.m0_sample, self.model.m1_sample]
        ])
        mask0_kl_weighted = kl_weight * mask0_kl  # TODO: categorical KL
        self.log_ops["mask0_kl_weight"] = kl_weight
        self.log_ops["mask0_kl"] = mask0_kl
        self.log_ops["mask0_kl_weighted"] = mask0_kl_weighted

        log_probs = self.model.m0_logits
        p_labels = nn.softmax(log_probs, spatial=False)
        labels = nn.hard_max(p_labels, 3)
        labels = nn.straight_through_estimator(labels, p_labels)
        weakly_superv_loss_p = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=labels, logits=log_probs, dim=3)
        weakly_superv_loss_p = tf.reduce_mean(weakly_superv_loss_p)

        gamma = self.config.get("gamma", 3.0)
        # corrected_logits_1 = self.model.m11_sample ** gamma
        corrected_logits_1 = nn.softmax(self.model.m1_logits, spatial=False)
        corrected_logits_1 = nn.softmax(corrected_logits_1, spatial=True)

        corrected_logits_1 /= tf.reduce_sum(corrected_logits_1,
                                            axis=(1, 2),
                                            keep_dims=True)
        N, h, w, P = self.model.m1_sample.shape.as_list()
        _, sigma = nn.probs_to_mu_sigma(corrected_logits_1, tf.ones((N, P)))
        for i in range(sigma.shape.as_list()[1]):
            self.log_ops["sigma1_{:02d}".format(i)] = sigma[0, i, 0, 0]
        for i in range(sigma.shape.as_list()[1]):
            self.log_ops["sigma2_{:02d}".format(i)] = sigma[0, i, 1, 1]

        sigma_filter = tf.reduce_sum(self.model.m1_sample_hard,
                                     axis=(1, 2),
                                     keep_dims=True)
        sigma_filter = tf.to_float(sigma_filter > 0.0)
        sigma_filter = tf.transpose(sigma_filter, perm=[0, 3, 1, 2])
        for i in range(sigma.shape.as_list()[1]):
            self.log_ops["sigma_active_{:02d}".format(i)] = sigma_filter[0, i,
                                                                         0, 0]
        # sigma *= tf.stop_gradient(sigma_filter)
        # for i in range(sigma.shape.as_list()[1]):
        #     self.log_ops["sigma2_filtered_{:02d}".format(i)] = sigma[0, i, 1, 1]
        variances = tf.reduce_mean(
            tf.reduce_sum(sigma[:, :, 0, 0]**2 + sigma[:, :, 1, 1]**2, axis=1))
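        # Note: unlike the first variant, this one penalizes the squared
        # diagonal entries of sigma rather than the raw values.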

        with tf.variable_scope("variance_weight"):
            variance_weight = make_var(
                step=self.global_step,
                var_type=self.config["variance_weight"]["var_type"],
                options=self.config["variance_weight"]["options"],
            )

        variance_weighted = variance_weight * variances
        self.log_ops["variance_loss_weighted"] = variance_weighted
        self.log_ops["variance_loss"] = variances
        self.log_ops["variance_weight"] = variance_weight

        with tf.variable_scope("weakly_superv_loss_weight_p"):
            weakly_superv_loss_weight_p = make_var(
                step=self.global_step,
                var_type=self.config["weakly_superv_loss_weight_p"]["var_type"],
                options=self.config["weakly_superv_loss_weight_p"]["options"],
            )

        weakly_superv_loss_p_weighted = (weakly_superv_loss_p *
                                         weakly_superv_loss_weight_p)

        self.log_ops[
            "weakly_superv_loss_weight_p"] = weakly_superv_loss_weight_p
        self.log_ops["weakly_superv_loss_p"] = weakly_superv_loss_p
        self.log_ops[
            "weakly_superv_loss_p_weighted"] = weakly_superv_loss_p_weighted

        # per submodule loss
        losses = dict()

        # delta
        losses["encoder_0"] = auto_rec_loss + global_auto_rec_loss
        losses["encoder_1"] = auto_rec_loss + global_auto_rec_loss

        losses["d_single"] = global_auto_rec_loss
        losses["d_alpha"] = alpha_auto_rec_loss
        losses["d_pi"] = pi_auto_rec_loss

        # decoder
        losses["decoder_delta"] = auto_rec_loss

        # z_rec_loss = tf.reduce_sum(
        #     tf.square(self.model.reconstructed_pi1 - self.model.pi_sample_v1),
        #     axis=(1, 2, 3),
        # )
        # z_rec_loss = tf.reduce_mean(
        #     z_rec_loss
        # ) + self.model.z_distribution.kl_to_shifted_standard_normal(
        #     self.model.z_prior_mean
        # )
        # losses["z_decoder"] = z_rec_loss
        # losses["z_encoder"] = z_rec_loss

        if not self.config.get("pretrain", False):
            losses["decoder_visualize"] = (auto_rec_loss +
                                           prior_gmrf_weighted +
                                           prior_mumford_sha_weighted +
                                           mask0_kl_weighted +
                                           weakly_superv_loss_p_weighted +
                                           variance_weighted)
        else:
            losses["decoder_visualize"] = auto_rec_loss

        # mi estimators
        loss_dis0 = 0.5 * (logit_loss(self.model.logit_joint0, real=True) +
                           logit_loss(self.model.logit_marginal0, real=False))
        losses["mi0_discriminator"] = loss_dis0
        loss_dis1 = 0.5 * (logit_loss(self.model.logit_joint1, real=True) +
                           logit_loss(self.model.logit_marginal1, real=False))
        losses["mi1_discriminator"] = loss_dis1
        # estimator
        losses["mi_estimator"] = 0.5 * (
            logit_loss(self.model.mi_logit_joint, real=True) +
            logit_loss(self.model.mi_logit_marginal, real=False))

        # accuracies of discriminators
        def acc(ljoint, lmarg):
            correct_joint = tf.reduce_sum(tf.cast(ljoint > 0.0, tf.int32))
            correct_marginal = tf.reduce_sum(tf.cast(lmarg < 0.0, tf.int32))
            accuracy = (correct_joint +
                        correct_marginal) / (2 * tf.shape(ljoint)[0])
            return accuracy

        dis0_accuracy = acc(self.model.logit_joint0,
                            self.model.logit_marginal0)
        dis1_accuracy = acc(self.model.logit_joint1,
                            self.model.logit_marginal1)
        est_accuracy = acc(self.model.mi_logit_joint,
                           self.model.mi_logit_marginal)

        # averages
        avg_acc0 = make_ema(0.5, dis0_accuracy, self.update_ops)
        avg_acc1 = make_ema(0.5, dis1_accuracy, self.update_ops)
        avg_acc_error = make_ema(0.0, dis1_accuracy - dis0_accuracy,
                                 self.update_ops)
        self.log_ops["avg_acc_error"] = avg_acc_error
        avg_loss_dis0 = make_ema(1.0, loss_dis0, self.update_ops)
        avg_loss_dis1 = make_ema(1.0, loss_dis1, self.update_ops)

        # Parameters
        MI_TARGET = self.config["MI"].get("mi_target", 0.125)
        MI_SLACK = self.config["MI"].get("mi_slack", 0.05)

        LOO_TOL = 0.025
        LON_LR = 0.05
        LON_ADAPTIVE = False

        LOA_INIT = self.config["MI"].get("loa_init", 0.0)
        LOA_LR = self.config["MI"].get("loa_lr", 4.0)
        LOA_ADAPTIVE = self.config["MI"].get("loa_adaptive", True)

        LOR_INIT = self.config["MI"].get("lor_init", 7.5)
        LOR_LR = self.config["MI"].get("lor_lr", 0.05)
        LOR_MIN = self.config["MI"].get("lor_min", 1.0)
        LOR_MAX = self.config["MI"].get("lor_max", 7.5)
        LOR_ADAPTIVE = self.config["MI"].get("lor_adaptive", True)

        # delta mi minimization
        mim_constraint = logit_constraint(self.model.logit_joint0, real=False)
        independent_mim_constraint = logit_constraint(self.model.logit_joint1,
                                                      real=False)

        # level of overpowering
        avg_mim_constraint = tf.maximum(
            0.0, make_ema(0.0, mim_constraint, self.update_ops))
        avg_independent_mim_constraint = tf.maximum(
            0.0, make_ema(0.0, independent_mim_constraint, self.update_ops))
        self.log_ops["avg_mim"] = avg_mim_constraint
        self.log_ops["avg_independent_mim"] = avg_independent_mim_constraint
        loo = (avg_independent_mim_constraint -
               avg_mim_constraint) / (avg_independent_mim_constraint + 1e-6)
        loo = tf.clip_by_value(loo, 0.0, 1.0)
        self.log_ops["loo"] = loo

        # level of noise
        lon_gain = -loo + LOO_TOL
        self.log_ops["lon_gain"] = lon_gain
        lon_lr = LON_LR
        new_lon = tf.clip_by_value(self.model.lon + lon_lr * lon_gain, 0.0,
                                   1.0)
        if LON_ADAPTIVE:
            update_lon = tf.assign(self.model.lon, new_lon)
            self.update_ops.append(update_lon)
        self.log_ops["model_lon"] = self.model.lon

        # OPTION
        adversarial_regularization = True
        if adversarial_regularization:
            # level of attack - estimate of lagrange multiplier for mi constraint
            initial_loa = LOA_INIT
            loa = tf.Variable(initial_loa, dtype=tf.float32, trainable=False)
            loa_gain = mim_constraint - (1.0 - MI_SLACK) * MI_TARGET
            loa_lr = tf.constant(LOA_LR)
            new_loa = loa + loa_lr * loa_gain
            new_loa = tf.maximum(0.0, new_loa)

            if LOA_ADAPTIVE:
                update_loa = tf.assign(loa, new_loa)
                self.update_ops.append(update_loa)

                adversarial_active = tf.stop_gradient(
                    tf.to_float(loa_lr * loa_gain >= -loa))
                adversarial_weighted_loss = adversarial_active * (
                    loa * loa_gain + loa_lr / 2.0 * tf.square(loa_gain))
            else:
                adversarial_weighted_loss = loa * loa_gain

            losses["encoder_0"] += adversarial_weighted_loss

        # use lor
        # OPTION
        variational_regularization = True
        if variational_regularization:
            assert self.model.stochastic_encoder_0

            bottleneck_loss = self.model.z_00_distribution.kl()

            # (log) level of regularization
            initial_lor = LOR_INIT
            lor = tf.Variable(initial_lor, dtype=tf.float32, trainable=False)
            lor_gain = independent_mim_constraint - MI_TARGET
            lor_lr = LOR_LR
            new_lor = tf.clip_by_value(lor + lor_lr * lor_gain, LOR_MIN,
                                       LOR_MAX)
            if LOR_ADAPTIVE:
                update_lor = tf.assign(lor, new_lor)
                self.update_ops.append(update_lor)

            bottleneck_weighted_loss = tf.exp(lor) * bottleneck_loss
            losses["encoder_0"] += bottleneck_weighted_loss
        else:
            assert not self.model.stochastic_encoder_0

        if self.model.pretty:
            weight_pretty = self.config.get("weight_pretty", 4e-3)
            pw = weight_pretty * 0.5 * dim
            # pretty
            loss_dis_pretty = 0.5 * (
                logit_loss(self.model.logit_pretty_orig, real=True) +
                logit_loss(self.model.logit_pretty_fake, real=False))
            losses["pretty_discriminator"] = loss_dis_pretty
            self.log_ops["pretty_dis_logit"] = loss_dis_pretty
            # gradient penalty on real data
            gp_pretty = gradient_penalty(losses["pretty_discriminator"],
                                         [self.model.views[0]])
            self.log_ops["gp_pretty"] = gp_pretty
            gp_weight = 1e2
            losses["pretty_discriminator"] += gp_weight * gp_pretty

            # generator side: train the decoder to fool the pretty discriminator
            loss_dec_pretty = logit_loss(self.model.logit_pretty_fake,
                                         real=True)
            self.log_ops["pretty_dec_logit"] = loss_dec_pretty

            # 2.0 factor kept for backward compatibility in comparisons
            pretty_objective = 2.0 * pw * loss_dec_pretty

            losses["decoder_delta"] += pretty_objective

        # logging
        for k in losses:
            self.log_ops["loss_{}".format(k)] = losses[k]
        self.log_ops["dis0_accuracy"] = dis0_accuracy
        self.log_ops["dis1_accuracy"] = dis1_accuracy
        self.log_ops["avg_dis0_accuracy"] = avg_acc0
        self.log_ops["avg_dis1_accuracy"] = avg_acc1
        self.log_ops["avg_loss_dis0"] = avg_loss_dis0
        self.log_ops["avg_loss_dis1"] = avg_loss_dis1
        self.log_ops["est_accuracy"] = est_accuracy

        self.log_ops["mi_constraint"] = mim_constraint
        self.log_ops["independent_mi_constraint"] = independent_mim_constraint

        if adversarial_regularization:
            self.log_ops["adversarial_weight"] = loa
            self.log_ops["adversarial_constraint"] = mim_constraint
            self.log_ops[
                "adversarial_weighted_loss"] = adversarial_weighted_loss

            self.log_ops["loa"] = loa
            self.log_ops["loa_gain"] = loa_gain

        if variational_regularization:
            self.log_ops["bottleneck_weight"] = lor
            self.log_ops["bottleneck_loss"] = bottleneck_loss
            self.log_ops["bottleneck_weighted_loss"] = bottleneck_weighted_loss

            self.log_ops["lor"] = lor
            self.log_ops["explor"] = tf.exp(lor)
            self.log_ops["lor_gain"] = lor_gain

        visualize_mask("out_parts_soft",
                       tf.to_float(self.model.out_parts_soft), self.img_ops,
                       True)
        visualize_mask("m0_sample", self.model.m0_sample, self.img_ops, True)

        cols = self.model.encoding_mask.shape.as_list()[0]
        encoding_masks = tf.concat([self.model.encoding_mask], 0)
        encoding_masks = tf_batches.tf_batch_to_canvas(encoding_masks, cols)
        visualize_mask("encoding_masks", tf.to_float(encoding_masks),
                       self.img_ops, False)

        decoding_masks = tf.concat([self.model.decoding_mask], 0)
        decoding_masks = tf_batches.tf_batch_to_canvas(decoding_masks, cols)
        visualize_mask("decoding_masks", tf.to_float(decoding_masks),
                       self.img_ops, False)

        # self.img_ops["part_generated"] = self.model._part_generated
        self.img_ops["global_generated"] = self.model.global_generated
        self.img_ops["alpha_generated"] = self.model.alpha_generated
        self.img_ops["pi_generated"] = self.model.pi_generated

        masks = nn.take_only_first_item_in_the_batch(self.model.decoding_mask)
        masks = tf.transpose(masks, perm=[3, 1, 2, 0])
        self.img_ops["masks"] = masks

        correspondence0 = tf.expand_dims(self.model.decoding_mask,
                                         axis=-1) * tf.expand_dims(
                                             self.model.views[0], axis=3)
        correspondence0 = tf.transpose(correspondence0, [3, 0, 1, 2, 4])
        correspondence1 = tf.expand_dims(self.model.encoding_mask,
                                         axis=-1) * tf.expand_dims(
                                             self.model.views[1], axis=3)
        correspondence1 = tf.transpose(correspondence1, [3, 0, 1, 2, 4])
        correspondence = tf.concat([correspondence0, correspondence1], axis=1)
        N_PARTS, _, H, W, _ = correspondence.shape.as_list()

        def make_grid(X):
            X = tf.squeeze(X)
            return tf_batches.tf_batch_to_canvas(X)

        correspondence = list(
            map(make_grid, tf.split(correspondence, N_PARTS, 0)))
        correspondence = tf_batches.tf_batch_to_canvas(
            tf.concat(correspondence, 0), 5)
        self.img_ops.update({"assigned_parts": correspondence})
        # self.img_ops.update(
        # {"part_samples_generated": self.model.part_samples_generated}
        # )

        self.img_ops.update({
            "view0": self.model.inputs["view0"],
            "view1": self.model.inputs["view1"],
            "cross": self.model._cross_generated,
            "generated": self.model._generated,
        })

        for k in self.config.get("fix_weights", []):
            if k in losses:
                # losses[k] = tf.no_op()
                del losses[k]
        return losses
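
The scalar schedules used throughout are produced by make_var and
make_linear_var. A hedged sketch of what they are assumed to do (the
"var_type" dispatch and the option names are guesses inferred from the
call sites above):

def make_linear_var(step, start, end, start_value, end_value,
                    clip_min=0.0, clip_max=1.0):
    """Linearly anneal from start_value at step `start` to end_value at `end`."""
    linear = ((end_value - start_value) / (end - start) *
              (tf.cast(step, tf.float32) - start) + start_value)
    return tf.clip_by_value(linear, clip_min, clip_max)


def make_var(step, var_type, options):
    """Build a step-dependent scalar from a config-selected schedule type."""
    if var_type == "constant":
        return tf.constant(float(options["value"]))
    if var_type == "linear":
        return make_linear_var(step=step, **options)
    raise ValueError("unknown var_type: {}".format(var_type))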