def test_tf_batch_to_canvas():
    """Exercise batches.tf_batch_to_canvas over column counts, channel
    depths, and an invalid input rank."""
    import tensorflow as tf
    tf.enable_eager_execution()

    batch = tf.convert_to_tensor(np.ones((9, 100, 100, 3)))
    # Default column count tiles the nine 100x100 images into a 3x3 grid.
    assert batches.tf_batch_to_canvas(batch).shape == (1, 300, 300, 3)
    # Explicit column counts; cols=0 and cols=1 both give a single column,
    # cols=None matches the default layout.
    for n_cols, want in ((5, (1, 200, 500, 3)),
                         (1, (1, 900, 100, 3)),
                         (0, (1, 900, 100, 3)),
                         (None, (1, 300, 300, 3))):
        assert batches.tf_batch_to_canvas(batch, cols=n_cols).shape == want

    # Single-channel input keeps its channel depth.
    gray = tf.convert_to_tensor(np.ones((9, 100, 100, 1)))
    assert batches.tf_batch_to_canvas(gray).shape == (1, 300, 300, 1)

    # Rank-5 input is rejected.
    bad = tf.convert_to_tensor(np.ones((9, 100, 100, 1, 1)))
    with pytest.raises(ValueError,
                       match="input tensor has more than 4 dimensions."):
        batches.tf_batch_to_canvas(bad)
def make_loss_ops(self):
    """Assemble all training losses and return them keyed by submodule.

    Builds the VGG19 perceptual reconstruction loss, mask priors (GMRF
    smoothness, Mumford-Shah, categorical KL), weakly supervised part
    losses, the mutual-information discriminator/estimator losses, and
    the adaptive MI regularizers (loa / lor multipliers).

    Returns:
        dict: submodule name -> scalar loss tensor, with keys such as
        "encoder_0", "encoder_1", "decoder_delta", "decoder_visualize",
        "mi0_discriminator", "mi1_discriminator", "mi_estimator".
        Keys listed in ``config["fix_weights"]`` are deleted before
        returning.

    Side effects:
        Populates ``self.log_ops`` (scalar summaries) and
        ``self.img_ops`` (image summaries); appends EMA and
        multiplier-update ops to ``self.update_ops``; creates
        ``self.vgg19``.
    """
    # perceptual network
    losses = dict()
    with tf.variable_scope("VGG19__noinit__"):
        gram_weight = self.config.get("gram_weight", 0.0)
        print("GRAM: {}".format(gram_weight))
        self.vgg19 = vgg19 = deeploss.VGG19Features(
            self.session, default_gram=gram_weight, original_scale=True)
    # Scale the perceptual loss by the flattened per-example dimensionality.
    dim = np.prod(self.model.augmented_views[0].shape.as_list()[1:])
    auto_rec_loss = (1e-3 * 0.5 * dim * vgg19.make_loss_op(
        self.model.augmented_views[2], self.model._generated))
    # Scheduled loss weights; schedule shape/values come from the config.
    with tf.variable_scope("prior_gmrf_weight"):
        prior_gmrf_weight = make_var(
            step=self.global_step,
            var_type=self.config["prior_gmrf_weight"]["var_type"],
            options=self.config["prior_gmrf_weight"]["options"],
        )
    with tf.variable_scope("prior_mumford_sha_weight"):
        prior_mumford_sha_weight = make_var(
            step=self.global_step,
            var_type=self.config["prior_mumford_sha_weight"]["var_type"],
            options=self.config["prior_mumford_sha_weight"]["options"],
        )
    with tf.variable_scope("kl_weight"):
        kl_weight = make_linear_var(step=self.global_step,
                                    **self.config["kl_weight"])
    mumford_sha_alpha = make_var(
        step=self.global_step,
        var_type=self.config["mumford_sha_alpha"]["var_type"],
        options=self.config["mumford_sha_alpha"]["options"],
    )
    mumford_sha_lambda = make_var(
        step=self.global_step,
        var_type=self.config["mumford_sha_lambda"]["var_type"],
        options=self.config["mumford_sha_lambda"]["options"],
    )
    self.log_ops["mumford_sha_lambda"] = mumford_sha_lambda
    self.log_ops["mumford_sha_alpha"] = mumford_sha_alpha
    # smoothness prior (phase energy)
    prior_gmrf = self.smoothness_prior_simple_gradient()
    prior_gmrf_weighted = prior_gmrf_weight * prior_gmrf
    self.log_ops["prior_gmrf"] = prior_gmrf
    self.log_ops["prior_gmrf_weight"] = prior_gmrf_weight
    self.log_ops["prior_gmrf_weighted"] = prior_gmrf_weighted
    # KL over both mask samples, summed into one scalar.
    mask0_kl = tf.reduce_sum([
        categorical_kl(m)
        for m in [self.model.m0_sample, self.model.m1_sample]
    ])
    mask0_kl_weighted = kl_weight * mask0_kl  # TODO: categorical KL
    self.log_ops["mask0_kl_weight"] = kl_weight
    self.log_ops["mask0_kl"] = mask0_kl
    self.log_ops["mask0_kl_weighted"] = mask0_kl_weighted
    # Weakly supervised part loss: labels are the model's own hard-maxed
    # probabilities, passed through a straight-through estimator.
    log_probs = self.model.m0_logits
    p_labels = nn.softmax(log_probs, spatial=False)
    labels = nn.hard_max(p_labels, 3)
    labels = nn.straight_through_estimator(labels, p_labels)
    if self.config.get("entropy_func", "cross_entropy") == "cross_entropy":
        weakly_superv_loss_p = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels, log_probs, dim=3)
    elif self.config.get("entropy_func", "cross_entropy") == "entropy":
        # "entropy" variant uses the soft probabilities as targets.
        weakly_superv_loss_p = tf.nn.softmax_cross_entropy_with_logits_v2(
            p_labels, log_probs, dim=3)
    else:
        raise ValueError("unkown entropy_func")
    weakly_superv_loss_p = tf.reduce_mean(weakly_superv_loss_p)
    gamma = self.config.get("gamma", 3.0)
    # Sharpen (scale by gamma) and renormalize the part probabilities,
    # then zero out parts that are hard-masked; stop_gradient keeps the
    # mask out of the gradient path.
    corrected_logits_1 = nn.softmax(self.model.m1_logits,
                                    spatial=False) * gamma
    corrected_logits_1 = nn.softmax(corrected_logits_1, spatial=True)
    corrected_logits_1 *= tf.stop_gradient(
        (1 - self.model.m1_sample_hard_mask))
    N, h, w, P = self.model.m1_sample.shape.as_list()
    _, sigma = nn.probs_to_mu_sigma(corrected_logits_1, tf.ones((N, P)))
    # for i in range(sigma.shape.as_list()[1]):
    #     self.log_ops["sigma1_{:02d}".format(i)] = sigma[0, i, 0, 0]
    # for i in range(sigma.shape.as_list()[1]):
    #     self.log_ops["sigma2_{:02d}".format(i)] = sigma[0, i, 1, 1]
    # sigma_filter = tf.reduce_sum(
    #     self.model.m1_sample_hard, axis=(1, 2), keep_dims=True
    # )
    # sigma_filter = tf.to_float(sigma_filter > 0.0)
    # sigma_filter = tf.transpose(sigma_filter, perm=[0, 3, 1, 2])
    # for i in range(sigma.shape.as_list()[1]):
    #     self.log_ops["sigma_active_{:02d}".format(i)] = sigma_filter[0, i, 0, 0]
    # sigma *= tf.stop_gradient(sigma_filter)
    # for i in range(sigma.shape.as_list()[1]):
    #     self.log_ops["sigma2_filtered_{:02d}".format(i)] = sigma[0, i, 1, 1]
    # Trace of the per-part spatial covariance, summed over parts.
    variances = tf.reduce_mean(
        tf.reduce_sum(sigma[:, :, 0, 0] + sigma[:, :, 1, 1], axis=1))
    with tf.variable_scope("variance_weight"):
        variance_weight = make_var(
            step=self.global_step,
            var_type=self.config["variance_weight"]["var_type"],
            options=self.config["variance_weight"]["options"],
        )
    variance_weighted = variance_weight * variances
    self.log_ops["variance_loss_weighted"] = variance_weighted
    self.log_ops["variance_loss"] = variances
    self.log_ops["variance_weight"] = variance_weight
    with tf.variable_scope("weakly_superv_loss_weight_p"):
        weakly_superv_loss_weight_p = make_var(
            step=self.global_step,
            var_type=self.config["weakly_superv_loss_weight_p"]
            ["var_type"],
            options=self.config["weakly_superv_loss_weight_p"]["options"],
        )
    weakly_superv_loss_p_weighted = (weakly_superv_loss_p *
                                     weakly_superv_loss_weight_p)
    self.log_ops[
        "weakly_superv_loss_weight_p"] = weakly_superv_loss_weight_p
    self.log_ops["weakly_superv_loss_p"] = weakly_superv_loss_p
    self.log_ops[
        "weakly_superv_loss_p_weighted"] = weakly_superv_loss_p_weighted
    # per submodule loss
    # delta
    losses["encoder_0"] = auto_rec_loss
    losses["encoder_1"] = auto_rec_loss
    # decoder
    losses["decoder_delta"] = auto_rec_loss
    # Mumford-Shah decomposition of the mask; each term is spatially
    # summed, squared per part, then averaged.
    r, smoothness_cost, contour_cost = nn.mumford_shah(
        self.model.m0_sample, 1.0, 1.0e-2)
    smoothness_cost = tf.reduce_mean(
        tf.reduce_sum(
            (tf.reduce_sum(smoothness_cost, axis=(1, 2),
                           keep_dims=True))**2,
            axis=(3, ),
        ))
    contour_cost = tf.reduce_mean(
        tf.reduce_sum(
            (tf.reduce_sum(contour_cost, axis=(1, 2), keep_dims=True))**2,
            axis=(3, ),
        ))
    p_mumford_sha = prior_mumford_sha_weight * tf.reduce_mean(
        tf.reduce_sum((tf.reduce_sum(r, axis=(1, 2), keep_dims=True))**2,
                      axis=(3, )))
    # Penalize total (squared) part area; tiny fixed weight.
    area_cost = 1.0e-12 * tf.reduce_mean(
        tf.reduce_sum(
            (tf.reduce_sum(
                self.model.m0_sample, axis=(1, 2), keep_dims=True))**2,
            axis=(3, ),
        ))
    # Penalize hard mask mass falling outside the allowed region.
    patch_loss = self.model.m0_sample_hard * tf.stop_gradient(
        (1 - self.model.m0_sample_hard_mask))
    patch_loss = tf.reduce_sum(patch_loss, axis=[1, 2, 3])
    patch_loss = tf.reduce_mean(patch_loss, axis=[0])
    with tf.variable_scope("patch_loss_weight"):
        patch_loss_weight = make_var(
            step=self.global_step,
            var_type=self.config["patch_loss_weight"]["var_type"],
            options=self.config["patch_loss_weight"]["options"],
        )
    patch_loss_weighted = patch_loss * patch_loss_weight
    self.log_ops["patch_loss"] = patch_loss
    self.log_ops["patch_loss_weight"] = patch_loss_weight
    self.log_ops["patch_loss_weighted"] = patch_loss_weighted
    # During pretraining only the reconstruction loss is used.
    if not self.config.get("pretrain", False):
        losses["decoder_visualize"] = (auto_rec_loss + prior_gmrf_weighted +
                                       mask0_kl_weighted +
                                       weakly_superv_loss_p_weighted +
                                       variance_weighted + p_mumford_sha +
                                       area_cost + patch_loss_weighted)
    else:
        losses["decoder_visualize"] = auto_rec_loss
    # mi estimators
    loss_dis0 = 0.5 * (logit_loss(self.model.logit_joint0, real=True) +
                       logit_loss(self.model.logit_marginal0, real=False))
    losses["mi0_discriminator"] = loss_dis0
    loss_dis1 = 0.5 * (logit_loss(self.model.logit_joint1, real=True) +
                       logit_loss(self.model.logit_marginal1, real=False))
    losses["mi1_discriminator"] = loss_dis1
    # estimator
    losses["mi_estimator"] = 0.5 * (
        logit_loss(self.model.mi_logit_joint, real=True) +
        logit_loss(self.model.mi_logit_marginal, real=False))

    # accuracies of discriminators
    def acc(ljoint, lmarg):
        # Fraction of joint logits > 0 and marginal logits < 0.
        correct_joint = tf.reduce_sum(tf.cast(ljoint > 0.0, tf.int32))
        correct_marginal = tf.reduce_sum(tf.cast(lmarg < 0.0, tf.int32))
        accuracy = (correct_joint +
                    correct_marginal) / (2 * tf.shape(ljoint)[0])
        return accuracy

    dis0_accuracy = acc(self.model.logit_joint0, self.model.logit_marginal0)
    dis1_accuracy = acc(self.model.logit_joint1, self.model.logit_marginal1)
    est_accuracy = acc(self.model.mi_logit_joint,
                       self.model.mi_logit_marginal)
    # averages
    avg_acc0 = make_ema(0.5, dis0_accuracy, self.update_ops)
    avg_acc1 = make_ema(0.5, dis1_accuracy, self.update_ops)
    avg_acc_error = make_ema(0.0, dis1_accuracy - dis0_accuracy,
                             self.update_ops)
    self.log_ops["avg_acc_error"] = avg_acc_error
    avg_loss_dis0 = make_ema(1.0, loss_dis0, self.update_ops)
    avg_loss_dis1 = make_ema(1.0, loss_dis1, self.update_ops)
    # Parameters
    MI_TARGET = self.config["MI"].get("mi_target", 0.125)
    MI_SLACK = self.config["MI"].get("mi_slack", 0.05)
    LOO_TOL = 0.025
    LON_LR = 0.05
    LON_ADAPTIVE = False
    LOA_INIT = self.config["MI"].get("loa_init", 0.0)
    LOA_LR = self.config["MI"].get("loa_lr", 4.0)
    LOA_ADAPTIVE = self.config["MI"].get("loa_adaptive", True)
    LOR_INIT = self.config["MI"].get("lor_init", 7.5)
    LOR_LR = self.config["MI"].get("lor_lr", 0.05)
    LOR_MIN = self.config["MI"].get("lor_min", 1.0)
    LOR_MAX = self.config["MI"].get("lor_max", 7.5)
    LOR_ADAPTIVE = self.config["MI"].get("lor_adaptive", True)
    # delta mi minimization
    mim_constraint = logit_constraint(self.model.logit_joint0, real=False)
    independent_mim_constraint = logit_constraint(self.model.logit_joint1,
                                                  real=False)
    # level of overpowering
    avg_mim_constraint = tf.maximum(
        0.0, make_ema(0.0, mim_constraint, self.update_ops))
    avg_independent_mim_constraint = tf.maximum(
        0.0, make_ema(0.0, independent_mim_constraint, self.update_ops))
    self.log_ops["avg_mim"] = avg_mim_constraint
    self.log_ops["avg_independent_mim"] = avg_independent_mim_constraint
    loo = (avg_independent_mim_constraint -
           avg_mim_constraint) / (avg_independent_mim_constraint + 1e-6)
    loo = tf.clip_by_value(loo, 0.0, 1.0)
    self.log_ops["loo"] = loo
    # level of noise
    lon_gain = -loo + LOO_TOL
    self.log_ops["lon_gain"] = lon_gain
    lon_lr = LON_LR
    new_lon = tf.clip_by_value(self.model.lon + lon_lr * lon_gain, 0.0, 1.0)
    if LON_ADAPTIVE:
        update_lon = tf.assign(self.model.lon, new_lon)
        self.update_ops.append(update_lon)
    self.log_ops["model_lon"] = self.model.lon
    # OPTION
    adversarial_regularization = self.config.get(
        "adversarial_regularization", True)
    if adversarial_regularization:
        # level of attack - estimate of lagrange multiplier for mi constraint
        initial_loa = LOA_INIT
        loa = tf.Variable(initial_loa, dtype=tf.float32, trainable=False)
        loa_gain = mim_constraint - (1.0 - MI_SLACK) * MI_TARGET
        loa_lr = tf.constant(LOA_LR)
        new_loa = loa + loa_lr * loa_gain
        new_loa = tf.maximum(0.0, new_loa)
        if LOA_ADAPTIVE:
            update_loa = tf.assign(loa, new_loa)
            self.update_ops.append(update_loa)
            # Only penalize while the updated multiplier stays positive.
            adversarial_active = tf.stop_gradient(
                tf.to_float(loa_lr * loa_gain >= -loa))
            adversarial_weighted_loss = adversarial_active * (
                loa * loa_gain + loa_lr / 2.0 * tf.square(loa_gain))
        else:
            adversarial_weighted_loss = loa * loa_gain
        losses["encoder_0"] += adversarial_weighted_loss
    # use lor
    # OPTION
    variational_regularization = self.config.get(
        "variational_regularization", True)
    if variational_regularization:
        assert self.model.stochastic_encoder_0
        bottleneck_loss = self.model.z_00_distribution.kl()
        # (log) level of regularization
        initial_lor = LOR_INIT
        lor = tf.Variable(initial_lor, dtype=tf.float32, trainable=False)
        lor_gain = independent_mim_constraint - MI_TARGET
        lor_lr = LOR_LR
        new_lor = tf.clip_by_value(lor + lor_lr * lor_gain, LOR_MIN,
                                   LOR_MAX)
        if LOR_ADAPTIVE:
            update_lor = tf.assign(lor, new_lor)
            self.update_ops.append(update_lor)
        beta_0 = self.config.get("beta_0", 1.0)
        bottleneck_weighted_loss = beta_0 * tf.exp(
            lor) * bottleneck_loss  # TODO: hardcoded
        losses["encoder_0"] += bottleneck_weighted_loss
    else:
        assert not self.model.stochastic_encoder_0
    # logging
    for k in losses:
        self.log_ops["loss_{}".format(k)] = losses[k]
    self.log_ops["dis0_accuracy"] = dis0_accuracy
    self.log_ops["dis1_accuracy"] = dis1_accuracy
    self.log_ops["avg_dis0_accuracy"] = avg_acc0
    self.log_ops["avg_dis1_accuracy"] = avg_acc1
    self.log_ops["avg_loss_dis0"] = avg_loss_dis0
    self.log_ops["avg_loss_dis1"] = avg_loss_dis1
    self.log_ops["est_accuracy"] = est_accuracy
    self.log_ops["mi_constraint"] = mim_constraint
    self.log_ops["independent_mi_constraint"] = independent_mim_constraint
    if adversarial_regularization:
        self.log_ops["adversarial_weight"] = loa
        self.log_ops["adversarial_constraint"] = mim_constraint
        self.log_ops[
            "adversarial_weighted_loss"] = adversarial_weighted_loss
        self.log_ops["loa"] = loa
        self.log_ops["loa_gain"] = loa_gain
    if variational_regularization:
        self.log_ops["bottleneck_weight"] = lor
        self.log_ops["bottleneck_loss"] = bottleneck_loss
        self.log_ops["bottleneck_weighted_loss"] = bottleneck_weighted_loss
        self.log_ops["lor"] = lor
        beta_0 = self.config.get("beta_0", 1.0)
        self.log_ops["explor"] = beta_0 * tf.exp(lor)
        self.log_ops["lor_gain"] = lor_gain
    # Visualization: mask grids, per-part correspondences, level sets.
    visualize_mask("out_parts_soft", tf.to_float(self.model.out_parts_soft),
                   self.img_ops, True)
    visualize_mask("m0_sample", self.model.m0_sample, self.img_ops, True)
    cols = self.model.encoding_mask.shape.as_list()[0]
    encoding_masks = tf.concat([self.model.encoding_mask], 0)
    encoding_masks = tf_batches.tf_batch_to_canvas(encoding_masks, cols)
    visualize_mask("encoding_masks", tf.to_float(encoding_masks),
                   self.img_ops, False)
    decoding_masks = tf.concat([self.model.decoding_mask], 0)
    decoding_masks = tf_batches.tf_batch_to_canvas(decoding_masks, cols)
    visualize_mask("decoding_masks", tf.to_float(decoding_masks),
                   self.img_ops, False)
    masks = nn.take_only_first_item_in_the_batch(self.model.decoding_mask)
    masks = tf.transpose(masks, perm=[3, 1, 2, 0])
    self.img_ops["masks"] = masks
    # Per-part masked views, parts moved to the leading axis.
    correspondence0 = tf.expand_dims(
        self.model.decoding_mask, axis=-1) * tf.expand_dims(
            self.model.augmented_views[0], axis=3)
    correspondence0 = tf.transpose(correspondence0, [3, 0, 1, 2, 4])
    correspondence1 = tf.expand_dims(
        self.model.encoding_mask, axis=-1) * tf.expand_dims(
            self.model.augmented_views[1], axis=3)
    correspondence1 = tf.transpose(correspondence1, [3, 0, 1, 2, 4])
    correspondence = tf.concat([correspondence0, correspondence1], axis=1)
    N_PARTS, _, H, W, _ = correspondence.shape.as_list()

    def make_grid(X):
        # One canvas per part.
        X = tf.squeeze(X)
        return tf_batches.tf_batch_to_canvas(X)

    correspondence = list(
        map(make_grid, tf.split(correspondence, N_PARTS, 0)))
    correspondence = tf_batches.tf_batch_to_canvas(
        tf.concat(correspondence, 0), 5)
    self.img_ops.update({"assigned_parts": correspondence})
    p_levels = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9]
    m = nn.take_only_first_item_in_the_batch(self.model.m0_sample)

    def levels(m, i):
        # Threshold part i of m at each probability level.
        return tf.concat([tf.to_float(m[..., i] > p) for p in p_levels],
                         axis=0)

    level_sets = tf.concat([levels(m, i) for i in range(m.shape[3])], 0)
    level_sets = tf.expand_dims(level_sets, -1)
    level_sets = tf_batches.tf_batch_to_canvas(level_sets, len(p_levels))
    img_title = "m0_sample_levels" + "-{}" * len(p_levels)
    img_title = img_title.format(*p_levels).replace(".", "_")
    ratios = [1.0e-3, 5 * 1.0e-3, 1.0e-2, 5 * 1.0e-2]
    mumford_sha_sets = tf.concat([nn.edge_set(m, 1, r) for r in ratios],
                                 axis=0)  # 5, H, W, 25
    mumford_sha_sets = tf.transpose(mumford_sha_sets, perm=[3, 0, 1, 2])
    a, b, h, w = mumford_sha_sets.shape.as_list()
    mumford_sha_sets = tf.reshape(mumford_sha_sets, (a * b, h, w))
    mumford_sha_sets = tf.expand_dims(mumford_sha_sets, -1)
    mumford_sha_sets = tf_batches.tf_batch_to_canvas(
        mumford_sha_sets, len(ratios))
    p_heatmap = tf.squeeze(nn.colorize(m))  # [H, W, 25, 3]
    p_heatmap = tf.transpose(p_heatmap, perm=[2, 0, 1, 3])
    self.img_ops.update({
        "mumford_sha_edges": mumford_sha_sets,
        "p_heatmap": p_heatmap,
        img_title: level_sets,
        "view0": self.model.inputs["view0"],
        "view1": self.model.inputs["view1"],
        "view0_target": self.model.inputs["view0_target"],
        "cross": self.model._cross_generated,
        "generated": self.model._generated,
    })
    if self.model.use_tps:
        self.img_ops.update({
            "tps_view0": self.model.outputs["tps_view0"],
            "tps_view1": self.model.outputs["tps_view1"],
            "tps_view0_target": self.model.outputs["tps_view0_target"],
        })
    self.log_ops["zr_mumford_sha"] = p_mumford_sha
    self.log_ops["z_mumford_sha_smoothness_cost"] = smoothness_cost
    self.log_ops["z_mumford_sha_contour_cost"] = contour_cost
    self.log_ops["z_area_cost"] = area_cost
    self.log_ops["prior_mumford_sha_weight"] = prior_mumford_sha_weight
    # Drop losses the config asks to freeze.
    for k in self.config.get("fix_weights", []):
        if k in losses.keys():
            # losses[k] = tf.no_op()
            del losses[k]
        else:
            pass
    return losses
def make_grid(X):
    """Drop singleton dimensions of *X* and tile its batch onto one canvas."""
    squeezed = tf.squeeze(X)
    canvas = tf_batches.tf_batch_to_canvas(squeezed)
    return canvas
def make_loss_ops(self):
    """Assemble all training losses and return them keyed by submodule.

    Variant with additional per-decoder reconstruction targets
    (global / alpha / pi generated images) and an optional "pretty"
    adversarial branch.

    Returns:
        dict: submodule name -> scalar loss tensor, with keys such as
        "encoder_0", "encoder_1", "d_single", "d_alpha", "d_pi",
        "decoder_delta", "decoder_visualize", "mi0_discriminator",
        "mi1_discriminator", "mi_estimator" (and
        "pretty_discriminator" when ``self.model.pretty``). Keys listed
        in ``config["fix_weights"]`` are deleted before returning.

    Side effects:
        Populates ``self.log_ops`` and ``self.img_ops``; appends EMA
        and multiplier-update ops to ``self.update_ops``; creates
        ``self.vgg19``.
    """
    # perceptual network
    with tf.variable_scope("VGG19__noinit__"):
        gram_weight = self.config.get("gram_weight", 0.0)
        print("GRAM: {}".format(gram_weight))
        self.vgg19 = vgg19 = deeploss.VGG19Features(
            self.session, default_gram=gram_weight, original_scale=True)
    # Scale the perceptual losses by the flattened per-example dimensionality.
    dim = np.prod(self.model.inputs["view0"].shape.as_list()[1:])
    auto_rec_loss = (1e-3 * 0.5 * dim * vgg19.make_loss_op(
        self.model.inputs["view0"], self.model._generated))
    x0 = nn.take_only_first_item_in_the_batch(self.model.inputs["view0"])
    x1 = nn.take_only_first_item_in_the_batch(self.model.inputs["view1"])
    global_auto_rec_loss = (
        1e-3 * 0.5 * dim *
        vgg19.make_loss_op(x0, self.model.global_generated))
    alpha_auto_rec_loss = (
        1e-3 * 0.5 * dim *
        vgg19.make_loss_op(x1, self.model.alpha_generated))
    pi_auto_rec_loss = (1e-3 * 0.5 * dim *
                        vgg19.make_loss_op(x0, self.model.pi_generated))
    # Scheduled loss weights; schedule shape/values come from the config.
    with tf.variable_scope("prior_gmrf_weight"):
        prior_gmrf_weight = make_var(
            step=self.global_step,
            var_type=self.config["prior_gmrf_weight"]["var_type"],
            options=self.config["prior_gmrf_weight"]["options"],
        )
    with tf.variable_scope("prior_mumford_sha_weight"):
        prior_mumford_sha_weight = make_var(
            step=self.global_step,
            var_type=self.config["prior_mumford_sha_weight"]["var_type"],
            options=self.config["prior_mumford_sha_weight"]["options"],
        )
    with tf.variable_scope("kl_weight"):
        kl_weight = make_linear_var(step=self.global_step,
                                    **self.config["kl_weight"])
    mumford_sha_alpha = make_var(
        step=self.global_step,
        var_type=self.config["mumford_sha_alpha"]["var_type"],
        options=self.config["mumford_sha_alpha"]["options"],
    )
    mumford_sha_lambda = make_var(
        step=self.global_step,
        var_type=self.config["mumford_sha_lambda"]["var_type"],
        options=self.config["mumford_sha_lambda"]["options"],
    )
    self.log_ops["mumford_sha_lambda"] = mumford_sha_lambda
    self.log_ops["mumford_sha_alpha"] = mumford_sha_alpha
    # smoothness prior (phase energy)
    prior_gmrf = self.smoothness_prior_simple_gradient()
    prior_gmrf_weighted = prior_gmrf_weight * prior_gmrf
    mumford_sha_prior = self.smoothness_prior_mumfordsha(
        mumford_sha_alpha, mumford_sha_lambda)
    prior_mumford_sha_weighted = mumford_sha_prior * prior_mumford_sha_weight
    self.log_ops["prior_gmrf"] = prior_gmrf
    self.log_ops["prior_gmrf_weight"] = prior_gmrf_weight
    self.log_ops["prior_gmrf_weighted"] = prior_gmrf_weighted
    self.log_ops["prior_mumford_sha"] = mumford_sha_prior
    self.log_ops["prior_mumford_sha_weight"] = prior_mumford_sha_weight
    self.log_ops["prior_mumford_sha_weighted"] = prior_mumford_sha_weighted
    # KL over both mask samples, summed into one scalar.
    mask0_kl = tf.reduce_sum([
        categorical_kl(m)
        for m in [self.model.m0_sample, self.model.m1_sample]
    ])
    mask0_kl_weighted = kl_weight * mask0_kl  # TODO: categorical KL
    self.log_ops["mask0_kl_weight"] = kl_weight
    self.log_ops["mask0_kl"] = mask0_kl
    self.log_ops["mask0_kl_weighted"] = mask0_kl_weighted
    # Weakly supervised part loss: labels are the model's own hard-maxed
    # probabilities, passed through a straight-through estimator.
    log_probs = self.model.m0_logits
    p_labels = nn.softmax(log_probs, spatial=False)
    labels = nn.hard_max(p_labels, 3)
    labels = nn.straight_through_estimator(labels, p_labels)
    weakly_superv_loss_p = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels, log_probs, dim=3)
    weakly_superv_loss_p = tf.reduce_mean(weakly_superv_loss_p)
    # NOTE(review): gamma is read but no longer used — its only use is
    # the commented-out line below. Confirm before removing.
    gamma = self.config.get("gamma", 3.0)
    # corrected_logits_1 = self.model.m11_sample ** gamma
    corrected_logits_1 = nn.softmax(self.model.m1_logits, spatial=False)
    corrected_logits_1 = nn.softmax(corrected_logits_1, spatial=True)
    corrected_logits_1 /= tf.reduce_sum(corrected_logits_1,
                                        axis=(1, 2),
                                        keep_dims=True)
    N, h, w, P = self.model.m1_sample.shape.as_list()
    _, sigma = nn.probs_to_mu_sigma(corrected_logits_1, tf.ones((N, P)))
    # Per-part covariance diagonals of the first batch element.
    for i in range(sigma.shape.as_list()[1]):
        self.log_ops["sigma1_{:02d}".format(i)] = sigma[0, i, 0, 0]
    for i in range(sigma.shape.as_list()[1]):
        self.log_ops["sigma2_{:02d}".format(i)] = sigma[0, i, 1, 1]
    # Flag parts with any hard-mask support.
    sigma_filter = tf.reduce_sum(self.model.m1_sample_hard,
                                 axis=(1, 2),
                                 keep_dims=True)
    sigma_filter = tf.to_float(sigma_filter > 0.0)
    sigma_filter = tf.transpose(sigma_filter, perm=[0, 3, 1, 2])
    for i in range(sigma.shape.as_list()[1]):
        self.log_ops["sigma_active_{:02d}".format(i)] = sigma_filter[0, i,
                                                                     0, 0]
    # sigma *= tf.stop_gradient(sigma_filter)
    # for i in range(sigma.shape.as_list()[1]):
    #     self.log_ops["sigma2_filtered_{:02d}".format(i)] = sigma[0, i, 1, 1]
    # Squared covariance diagonals here (the other variant uses the trace).
    variances = tf.reduce_mean(
        tf.reduce_sum(sigma[:, :, 0, 0]**2 + sigma[:, :, 1, 1]**2, axis=1))
    with tf.variable_scope("variance_weight"):
        variance_weight = make_var(
            step=self.global_step,
            var_type=self.config["variance_weight"]["var_type"],
            options=self.config["variance_weight"]["options"],
        )
    variance_weighted = variance_weight * variances
    self.log_ops["variance_loss_weighted"] = variance_weighted
    self.log_ops["variance_loss"] = variances
    self.log_ops["variance_weight"] = variance_weight
    with tf.variable_scope("weakly_superv_loss_weight_p"):
        weakly_superv_loss_weight_p = make_var(
            step=self.global_step,
            var_type=self.config["weakly_superv_loss_weight_p"]
            ["var_type"],
            options=self.config["weakly_superv_loss_weight_p"]["options"],
        )
    weakly_superv_loss_p_weighted = (weakly_superv_loss_p *
                                     weakly_superv_loss_weight_p)
    self.log_ops[
        "weakly_superv_loss_weight_p"] = weakly_superv_loss_weight_p
    self.log_ops["weakly_superv_loss_p"] = weakly_superv_loss_p
    self.log_ops[
        "weakly_superv_loss_p_weighted"] = weakly_superv_loss_p_weighted
    # per submodule loss
    losses = dict()
    # delta
    losses["encoder_0"] = auto_rec_loss + global_auto_rec_loss
    losses["encoder_1"] = auto_rec_loss + global_auto_rec_loss
    losses["d_single"] = global_auto_rec_loss
    losses["d_alpha"] = alpha_auto_rec_loss
    losses["d_pi"] = pi_auto_rec_loss
    # decoder
    losses["decoder_delta"] = auto_rec_loss
    # z_rec_loss = tf.reduce_sum(
    #     tf.square(self.model.reconstructed_pi1 - self.model.pi_sample_v1),
    #     axis=(1, 2, 3),
    # )
    # z_rec_loss = tf.reduce_mean(
    #     z_rec_loss
    # ) + self.model.z_distribution.kl_to_shifted_standard_normal(
    #     self.model.z_prior_mean
    # )
    # losses["z_decoder"] = z_rec_loss
    # losses["z_encoder"] = z_rec_loss
    # During pretraining only the reconstruction loss is used.
    if not self.config.get("pretrain", False):
        losses["decoder_visualize"] = (auto_rec_loss + prior_gmrf_weighted +
                                       prior_mumford_sha_weighted +
                                       mask0_kl_weighted +
                                       weakly_superv_loss_p_weighted +
                                       variance_weighted)
    else:
        losses["decoder_visualize"] = auto_rec_loss
    # mi estimators
    loss_dis0 = 0.5 * (logit_loss(self.model.logit_joint0, real=True) +
                       logit_loss(self.model.logit_marginal0, real=False))
    losses["mi0_discriminator"] = loss_dis0
    loss_dis1 = 0.5 * (logit_loss(self.model.logit_joint1, real=True) +
                       logit_loss(self.model.logit_marginal1, real=False))
    losses["mi1_discriminator"] = loss_dis1
    # estimator
    losses["mi_estimator"] = 0.5 * (
        logit_loss(self.model.mi_logit_joint, real=True) +
        logit_loss(self.model.mi_logit_marginal, real=False))

    # accuracies of discriminators
    def acc(ljoint, lmarg):
        # Fraction of joint logits > 0 and marginal logits < 0.
        correct_joint = tf.reduce_sum(tf.cast(ljoint > 0.0, tf.int32))
        correct_marginal = tf.reduce_sum(tf.cast(lmarg < 0.0, tf.int32))
        accuracy = (correct_joint +
                    correct_marginal) / (2 * tf.shape(ljoint)[0])
        return accuracy

    dis0_accuracy = acc(self.model.logit_joint0, self.model.logit_marginal0)
    dis1_accuracy = acc(self.model.logit_joint1, self.model.logit_marginal1)
    est_accuracy = acc(self.model.mi_logit_joint,
                       self.model.mi_logit_marginal)
    # averages
    avg_acc0 = make_ema(0.5, dis0_accuracy, self.update_ops)
    avg_acc1 = make_ema(0.5, dis1_accuracy, self.update_ops)
    avg_acc_error = make_ema(0.0, dis1_accuracy - dis0_accuracy,
                             self.update_ops)
    self.log_ops["avg_acc_error"] = avg_acc_error
    avg_loss_dis0 = make_ema(1.0, loss_dis0, self.update_ops)
    avg_loss_dis1 = make_ema(1.0, loss_dis1, self.update_ops)
    # Parameters
    MI_TARGET = self.config["MI"].get("mi_target", 0.125)
    MI_SLACK = self.config["MI"].get("mi_slack", 0.05)
    LOO_TOL = 0.025
    LON_LR = 0.05
    LON_ADAPTIVE = False
    LOA_INIT = self.config["MI"].get("loa_init", 0.0)
    LOA_LR = self.config["MI"].get("loa_lr", 4.0)
    LOA_ADAPTIVE = self.config["MI"].get("loa_adaptive", True)
    LOR_INIT = self.config["MI"].get("lor_init", 7.5)
    LOR_LR = self.config["MI"].get("lor_lr", 0.05)
    LOR_MIN = self.config["MI"].get("lor_min", 1.0)
    LOR_MAX = self.config["MI"].get("lor_max", 7.5)
    LOR_ADAPTIVE = self.config["MI"].get("lor_adaptive", True)
    # delta mi minimization
    mim_constraint = logit_constraint(self.model.logit_joint0, real=False)
    independent_mim_constraint = logit_constraint(self.model.logit_joint1,
                                                  real=False)
    # level of overpowering
    avg_mim_constraint = tf.maximum(
        0.0, make_ema(0.0, mim_constraint, self.update_ops))
    avg_independent_mim_constraint = tf.maximum(
        0.0, make_ema(0.0, independent_mim_constraint, self.update_ops))
    self.log_ops["avg_mim"] = avg_mim_constraint
    self.log_ops["avg_independent_mim"] = avg_independent_mim_constraint
    loo = (avg_independent_mim_constraint -
           avg_mim_constraint) / (avg_independent_mim_constraint + 1e-6)
    loo = tf.clip_by_value(loo, 0.0, 1.0)
    self.log_ops["loo"] = loo
    # level of noise
    lon_gain = -loo + LOO_TOL
    self.log_ops["lon_gain"] = lon_gain
    lon_lr = LON_LR
    new_lon = tf.clip_by_value(self.model.lon + lon_lr * lon_gain, 0.0, 1.0)
    if LON_ADAPTIVE:
        update_lon = tf.assign(self.model.lon, new_lon)
        self.update_ops.append(update_lon)
    self.log_ops["model_lon"] = self.model.lon
    # OPTION (always on in this variant, unlike the config-driven one)
    adversarial_regularization = True
    if adversarial_regularization:
        # level of attack - estimate of lagrange multiplier for mi constraint
        initial_loa = LOA_INIT
        loa = tf.Variable(initial_loa, dtype=tf.float32, trainable=False)
        loa_gain = mim_constraint - (1.0 - MI_SLACK) * MI_TARGET
        loa_lr = tf.constant(LOA_LR)
        new_loa = loa + loa_lr * loa_gain
        new_loa = tf.maximum(0.0, new_loa)
        if LOA_ADAPTIVE:
            update_loa = tf.assign(loa, new_loa)
            self.update_ops.append(update_loa)
            # Only penalize while the updated multiplier stays positive.
            adversarial_active = tf.stop_gradient(
                tf.to_float(loa_lr * loa_gain >= -loa))
            adversarial_weighted_loss = adversarial_active * (
                loa * loa_gain + loa_lr / 2.0 * tf.square(loa_gain))
        else:
            adversarial_weighted_loss = loa * loa_gain
        losses["encoder_0"] += adversarial_weighted_loss
    # use lor
    # OPTION (always on in this variant)
    variational_regularization = True
    if variational_regularization:
        assert self.model.stochastic_encoder_0
        bottleneck_loss = self.model.z_00_distribution.kl()
        # (log) level of regularization
        initial_lor = LOR_INIT
        lor = tf.Variable(initial_lor, dtype=tf.float32, trainable=False)
        lor_gain = independent_mim_constraint - MI_TARGET
        lor_lr = LOR_LR
        new_lor = tf.clip_by_value(lor + lor_lr * lor_gain, LOR_MIN,
                                   LOR_MAX)
        if LOR_ADAPTIVE:
            update_lor = tf.assign(lor, new_lor)
            self.update_ops.append(update_lor)
        bottleneck_weighted_loss = tf.exp(lor) * bottleneck_loss
        losses["encoder_0"] += bottleneck_weighted_loss
    else:
        assert not self.model.stochastic_encoder_0
    if self.model.pretty:
        weight_pretty = self.config.get("weight_pretty", 4e-3)
        pw = weight_pretty * 0.5 * dim
        # pretty: discriminator on original vs. generated images
        loss_dis_pretty = 0.5 * (
            logit_loss(self.model.logit_pretty_orig, real=True) +
            logit_loss(self.model.logit_pretty_fake, real=False))
        losses["pretty_discriminator"] = loss_dis_pretty
        self.log_ops["pretty_dis_logit"] = loss_dis_pretty
        # gradient penalty on real data
        gp_pretty = gradient_penalty(losses["pretty_discriminator"],
                                     [self.model.views[0]])
        self.log_ops["gp_pretty"] = gp_pretty
        gp_weight = 1e2
        losses["pretty_discriminator"] += gp_weight * gp_pretty
        # generator side: make the fake logits look real
        loss_dec_pretty = logit_loss(self.model.logit_pretty_fake,
                                     real=True)
        self.log_ops["pretty_dec_logit"] = loss_dec_pretty
        # NOTE(review): self-assignment below is a no-op kept from the
        # original.
        loss_dec_pretty = loss_dec_pretty
        pretty_objective = (2.0 * pw * loss_dec_pretty
                            )  # backward comp for comparison
        losses["decoder_delta"] += pretty_objective
    # logging
    for k in losses:
        self.log_ops["loss_{}".format(k)] = losses[k]
    self.log_ops["dis0_accuracy"] = dis0_accuracy
    self.log_ops["dis1_accuracy"] = dis1_accuracy
    self.log_ops["avg_dis0_accuracy"] = avg_acc0
    self.log_ops["avg_dis1_accuracy"] = avg_acc1
    self.log_ops["avg_loss_dis0"] = avg_loss_dis0
    self.log_ops["avg_loss_dis1"] = avg_loss_dis1
    self.log_ops["est_accuracy"] = est_accuracy
    self.log_ops["mi_constraint"] = mim_constraint
    self.log_ops["independent_mi_constraint"] = independent_mim_constraint
    if adversarial_regularization:
        self.log_ops["adversarial_weight"] = loa
        self.log_ops["adversarial_constraint"] = mim_constraint
        self.log_ops[
            "adversarial_weighted_loss"] = adversarial_weighted_loss
        self.log_ops["loa"] = loa
        self.log_ops["loa_gain"] = loa_gain
    if variational_regularization:
        self.log_ops["bottleneck_weight"] = lor
        self.log_ops["bottleneck_loss"] = bottleneck_loss
        self.log_ops["bottleneck_weighted_loss"] = bottleneck_weighted_loss
        self.log_ops["lor"] = lor
        self.log_ops["explor"] = tf.exp(lor)
        self.log_ops["lor_gain"] = lor_gain
    # Visualization: mask grids and per-part correspondences.
    visualize_mask("out_parts_soft", tf.to_float(self.model.out_parts_soft),
                   self.img_ops, True)
    visualize_mask("m0_sample", self.model.m0_sample, self.img_ops, True)
    cols = self.model.encoding_mask.shape.as_list()[0]
    encoding_masks = tf.concat([self.model.encoding_mask], 0)
    encoding_masks = tf_batches.tf_batch_to_canvas(encoding_masks, cols)
    visualize_mask("encoding_masks", tf.to_float(encoding_masks),
                   self.img_ops, False)
    decoding_masks = tf.concat([self.model.decoding_mask], 0)
    decoding_masks = tf_batches.tf_batch_to_canvas(decoding_masks, cols)
    visualize_mask("decoding_masks", tf.to_float(decoding_masks),
                   self.img_ops, False)
    # self.img_ops["part_generated"] = self.model._part_generated
    self.img_ops["global_generated"] = self.model.global_generated
    self.img_ops["alpha_generated"] = self.model.alpha_generated
    self.img_ops["pi_generated"] = self.model.pi_generated
    masks = nn.take_only_first_item_in_the_batch(self.model.decoding_mask)
    masks = tf.transpose(masks, perm=[3, 1, 2, 0])
    self.img_ops["masks"] = masks
    # Per-part masked views, parts moved to the leading axis.
    correspondence0 = tf.expand_dims(self.model.decoding_mask,
                                     axis=-1) * tf.expand_dims(
                                         self.model.views[0], axis=3)
    correspondence0 = tf.transpose(correspondence0, [3, 0, 1, 2, 4])
    correspondence1 = tf.expand_dims(self.model.encoding_mask,
                                     axis=-1) * tf.expand_dims(
                                         self.model.views[1], axis=3)
    correspondence1 = tf.transpose(correspondence1, [3, 0, 1, 2, 4])
    correspondence = tf.concat([correspondence0, correspondence1], axis=1)
    N_PARTS, _, H, W, _ = correspondence.shape.as_list()

    def make_grid(X):
        # One canvas per part.
        X = tf.squeeze(X)
        return tf_batches.tf_batch_to_canvas(X)

    correspondence = list(
        map(make_grid, tf.split(correspondence, N_PARTS, 0)))
    correspondence = tf_batches.tf_batch_to_canvas(
        tf.concat(correspondence, 0), 5)
    self.img_ops.update({"assigned_parts": correspondence})
    # self.img_ops.update(
    #     {"part_samples_generated": self.model.part_samples_generated}
    # )
    self.img_ops.update({
        "view0": self.model.inputs["view0"],
        "view1": self.model.inputs["view1"],
        "cross": self.model._cross_generated,
        "generated": self.model._generated,
    })
    # Drop losses the config asks to freeze.
    for k in self.config.get("fix_weights", []):
        if k in losses.keys():
            # losses[k] = tf.no_op()
            del losses[k]
        else:
            pass
    return losses