# Fragment: builds the permutation estimate `P_hat` and the scalar `loss`
# for several values of `method`.
#
# NOTE(review): the statements below have lost their original line breaks
# and indentation. As written, Python treats everything after the inline
# `# row-stochastic!` as a comment, and the text before it is several
# statements on one line without separators. The original layout must be
# restored before this can run; the notes below describe the statements in
# the order they appear.
#
# 1. Leading statements (the branch header, if any, is outside this view):
#    - `concat_reps` passes through three stacked `tf.layers.dense` layers,
#      each with `n * n` units and no activation argument.
#    - The result is reshaped to `[M, n, n]` and soft-maxed over the last
#      axis to give `P_hat`.
#    - The loss is `softmax_cross_entropy_with_logits_v2` between `P_true`
#      and the un-normalised `[M, n, n]` logits (`dim=2`), averaged over the
#      last axis of the result and then over all remaining entries.
#
# 2. `method == 'sinkhorn'`:
#    - `multi_mnist_cnn.deepnn(l, X, n)` is reshaped to `[M, n, n]` and
#      passed to `sinkhorn_operator` with `temp=temperature`, giving `P_hat`.
#    - `tf.log(P_hat)` is then used as the logits of the same cross-entropy
#      and the same two-stage mean.
#    - NOTE(review): `tf.log` yields -inf for exact zeros. Whether
#      `sinkhorn_operator` can return zeros is not visible here -- confirm,
#      or add a small epsilon before taking the log.
#
# 3. `method == 'gumbel_sinkhorn'` (cut off at the end of this fragment):
#    - Same representation and reshape, and `P_hat` is again computed with
#      `sinkhorn_operator`.
#    - `gumbel_sinkhorn(..., n_samples=n_s)` is also called; only its first
#      return value is kept, and its log is taken. The same -inf caveat
#      applies to that log.
fc1 = tf.layers.dense(concat_reps, n * n) fc2 = tf.layers.dense(fc1, n * n) P_hat_raw = tf.layers.dense(fc2, n * n) P_hat_raw_square = tf.reshape(P_hat_raw, [M, n, n]) P_hat = tf.nn.softmax(P_hat_raw_square, dim=-1) # row-stochastic! losses = tf.nn.softmax_cross_entropy_with_logits_v2( labels=P_true, logits=P_hat_raw_square, dim=2) losses = tf.reduce_mean(losses, axis=-1) loss = tf.reduce_mean(losses) if method == 'sinkhorn': representations = multi_mnist_cnn.deepnn(l, X, n) pre_sinkhorn = tf.reshape(representations, [M, n, n]) P_hat = sinkhorn_operator(pre_sinkhorn, temp=temperature) P_hat_logit = tf.log(P_hat) losses = tf.nn.softmax_cross_entropy_with_logits_v2( labels=P_true, logits=P_hat_logit, dim=2) losses = tf.reduce_mean(losses, axis=-1) loss = tf.reduce_mean(losses) if method == 'gumbel_sinkhorn': representations = multi_mnist_cnn.deepnn(l, X, n) pre_sinkhorn = tf.reshape(representations, [M, n, n]) P_hat = sinkhorn_operator(pre_sinkhorn, temp=temperature) P_hat_sample, _ = gumbel_sinkhorn( pre_sinkhorn, temp=temperature, n_samples=n_s) P_hat_sample_logit = tf.log(P_hat_sample)
# Fragment: builds the regression losses `loss_phi` / `loss_theta` and the
# evaluation quantity `prob_median_eval` for several values of `method`.
#
# NOTE(review): the statements below have lost their original line breaks
# and indentation, so this line is not valid Python as written. The `if`
# header that the first `elif` pairs with is outside this view. Which
# statements sit inside the `with tf.variable_scope(...)` blocks cannot be
# recovered from this text -- restore the layout from the original file.
#
# 1. Leading statements (tail of a branch whose header is not visible):
#    - `fc2` feeds a 10-unit dense layer with ReLU, then a 1-unit dense
#      layer.
#    - The output is squeezed to give `y_hat`.
#    - `loss_phi` is the SUM of squared differences between `y_hat` and `y`.
#    - `loss_theta` aliases `loss_phi`, and `prob_median_eval` is the
#      constant 0.
#    - NOTE(review): `tf.squeeze` is called without an axis, so every
#      size-1 dimension is removed, including the batch dimension if it is
#      1 -- confirm that is acceptable.
#
# 2. `method == 'sinkhorn'`:
#    - Scope 'phi': `multi_mnist_cnn.deepnn(l, X, n)` reshaped to
#      `[M, n, n]` (`pre_sinkhorn`).
#    - Scope 'theta': `multi_mnist_cnn.deepnn(l, X, 1)` reshaped to `[M, n]`
#      (`regression_candidates`).
#    - `P_hat = sinkhorn_operator(pre_sinkhorn, temp=temp)` and
#      `prob_median = get_median_probs(P_hat)`.
#    - `point_estimates` is the sum over axis 1 of
#      `prob_median * regression_candidates`.
#    - `loss_phi` is the MEAN squared difference to `y`, and `loss_theta`
#      aliases it.
#    - For evaluation, the sinkhorn operator is re-applied with
#      `temp=1e-20` and passed through `get_median_probs`.
#    - NOTE(review): this branch reduces with `reduce_mean`, while the
#      leading branch uses `reduce_sum`, so the two losses are on different
#      scales -- confirm this is intentional.
#
# 3. `method == 'gumbel_sinkhorn'` (cut off at the end of this fragment):
#    - Opens scope 'phi' and builds `representations` in the same way.
fc3 = tf.layers.dense(fc2, 10, tf.nn.relu) y_hat = tf.layers.dense(fc3, 1) y_hat = tf.squeeze(y_hat) loss_phi = tf.reduce_sum(tf.squared_difference(y_hat, y)) loss_theta = loss_phi prob_median_eval = 0 elif method == 'sinkhorn': with tf.variable_scope('phi'): representations = multi_mnist_cnn.deepnn(l, X, n) pre_sinkhorn = tf.reshape(representations, [M, n, n]) with tf.variable_scope('theta'): regression_candidates = multi_mnist_cnn.deepnn(l, X, 1) regression_candidates = tf.reshape(regression_candidates, [M, n]) P_hat = sinkhorn_operator(pre_sinkhorn, temp=temp) prob_median = get_median_probs(P_hat) point_estimates = tf.reduce_sum(prob_median * regression_candidates, axis=1) exp_loss = tf.squared_difference(y, point_estimates) loss_phi = tf.reduce_mean(exp_loss) loss_theta = loss_phi P_hat_eval = sinkhorn_operator(pre_sinkhorn, temp=1e-20) prob_median_eval = get_median_probs(P_hat_eval) elif method == 'gumbel_sinkhorn': with tf.variable_scope('phi'): representations = multi_mnist_cnn.deepnn(l, X, n)