Beispiel #1
0
                          1e-10, dtype=tf.float32)  # simulate hard sort
                      )

# Run tag for the sorting experiment. Note that %d truncates ``tau * 10``
# to an integer, so the fractional part of the product is dropped.
experiment_id = 'sort-%s-M%d-n%d-l%d-t%d' % (
    method, M, n, l, tau * 10)
checkpoint_path = 'checkpoints/%s/' % experiment_id

# A scalar string placeholder picks, at run time, which iterator feeds the
# graph. Each step yields four float32 tensors with the static shapes below.
handle = tf.placeholder(dtype=tf.string, shape=())
X_iterator = tf.data.Iterator.from_string_handle(
    string_handle=handle,
    output_types=(tf.float32,) * 4,
    output_shapes=(
        (M, n, l * 28, 28),
        (M,),
        (M, n),
        (M, n),
    ),
)
X, y, median_scores, true_scores = X_iterator.get_next()

# Append a trailing axis, (M, n) -> (M, n, 1), before calling neuralsort.
true_scores = tf.expand_dims(true_scores, axis=2)
# NOTE(review): 1e-10 is presumably a near-zero temperature that makes
# util.neuralsort return an (almost) hard permutation matrix -- confirm
# against util.neuralsort.
P_true = util.neuralsort(true_scores, 1e-10)

if method == 'vanilla':
    # Baseline: the output of multi_mnist_cnn.deepnn is flattened to
    # n * n values per batch element and pushed through three dense layers
    # (no activation argument is passed) that emit n * n logits.
    representations = multi_mnist_cnn.deepnn(l, X, n)
    concat_reps = tf.reshape(representations, shape=[M, n * n])
    fc1 = tf.layers.dense(inputs=concat_reps, units=n * n)
    fc2 = tf.layers.dense(inputs=fc1, units=n * n)
    P_hat_raw = tf.layers.dense(inputs=fc2, units=n * n)
    P_hat_raw_square = tf.reshape(P_hat_raw, shape=[M, n, n])

    # Softmax over the last axis, so every row of P_hat sums to one.
    P_hat = tf.nn.softmax(logits=P_hat_raw_square, dim=-1)

    # Cross-entropy of each logit row against the matching row of P_true,
    # averaged over rows first and over the batch second.
    losses = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=P_hat_raw_square, labels=P_true, dim=2),
        axis=-1)
    loss = tf.reduce_mean(losses)
Beispiel #2
0
# Scalar float32 tensor: 1e-10 when ``evaluation`` is true, ``tau`` otherwise.
temp = tf.cond(
    evaluation,
    true_fn=lambda: tf.convert_to_tensor(1e-10, dtype=tf.float32),
    false_fn=lambda: tf.convert_to_tensor(tau, dtype=tf.float32),
)

# Run tag for the median experiment. Note that %d truncates ``tau * 10``
# to an integer, so the fractional part of the product is dropped.
experiment_id = 'median-%s-M%d-n%d-l%d-t%d' % (
    method, M, n, l, tau * 10)
checkpoint_path = 'checkpoints/%s/' % experiment_id

# A scalar string placeholder picks, at run time, which iterator feeds the
# graph. Each step yields four float32 tensors with the static shapes below.
handle = tf.placeholder(dtype=tf.string, shape=())
X_iterator = tf.data.Iterator.from_string_handle(
    string_handle=handle,
    output_types=(tf.float32,) * 4,
    output_shapes=(
        (M, n, l * 28, 28),
        (M,),
        (M, n),
        (M, n),
    ),
)
X, y, median_scores, true_scores = X_iterator.get_next()

# Append a trailing axis, (M, n) -> (M, n, 1), before calling neuralsort.
true_scores = tf.expand_dims(true_scores, axis=2)
# NOTE(review): 1e-10 is presumably a near-zero temperature that makes
# util.neuralsort return an (almost) hard permutation matrix -- confirm
# against util.neuralsort.
P_true = util.neuralsort(true_scores, 1e-10)
n_prime = n


def get_median_probs(P):
    """Return row ``n // 2`` of every matrix in ``P``, rescaled to sum to one.

    ``P`` is indexed along three axes; ``n`` is read from module scope. The
    selected row of each batch element is divided by its own sum over
    axis 1, so each returned row adds up to one.
    """
    middle_row = P[:, n // 2, :]
    return middle_row / tf.reduce_sum(middle_row, axis=1, keepdims=True)


if method == 'vanilla':
    with tf.variable_scope("phi"):
        representations = multi_mnist_cnn.deepnn(l, X, 10)
    representations = tf.reshape(representations, [M, n * 10])