Example #1
def test_euclidean_distance(self):
    with self.test_session():
        # Row-wise distances: identical rows give 0, the others 1 and 2.
        x = tf.constant([[0, 1], [0, 1], [1, 0]])
        y = tf.constant([[0, 1], [0, 2], [-1, 0]])
        dists_t = ops.euclidean_distance(x, y)
        dists = dists_t.eval()
        np.testing.assert_array_almost_equal(dists, np.array([0, 1, 2]))
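The test above pins down the expected contract: ops.euclidean_distance takes two equally shaped tensors and returns the row-wise L2 distance. A minimal NumPy sketch of that contract (the name row_l2_distance is hypothetical, not part of ops, and the real op may broadcast or batch differently):

import numpy as np

def row_l2_distance(x, y):
    # Hypothetical reference: L2 distance between matching rows of two (N, E) arrays.
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.sqrt((diff ** 2).sum(axis=-1))

print(row_l2_distance([[0, 1], [0, 1], [1, 0]],
                      [[0, 1], [0, 2], [-1, 0]]))  # -> [0. 1. 2.]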
Example #2
def compute_distances(xe, ye, I, train=True):
    r"""
    Computes pairwise distances for all pairs of query items and
    potential neighbors.

    :param xe: BxNxE tensor of database item embeddings
    :param ye: BxMxE tensor of query item embeddings
    :param I: BxMxO index tensor that selects O potential neighbors for each item in ye
    :param train: if True, use the dense training path; if False, use the memory-saving inference path (tensor comprehensions when available, forward only)

    :return: a BxMxO tensor of distances
    """

    # xe -> b n e
    # ye -> b m e
    # I  -> b m o
    b, n, e = xe.shape
    m = ye.shape[1]
    o = I.shape[2]

    if not train:
        # If -> b (m*o) e : neighbor indices expanded over the embedding dim
        If = I.view(b, m * o, 1).expand(b, m * o, e)

        # ye -> b m e 1
        ye = ye.unsqueeze(3)

        if ops.has_tensor_comprehensions():
            D = -2 * ops.indexed_matmul_1_tc(xe, ye.squeeze(3), I).unsqueeze(3)
        else:
            # This is slower than inference with tensor comprehensions :(
            D = -2 * torch.cat([
                ops.indexed_matmul_1(xe, ye[:, i:i + 10, :, :].squeeze(3),
                                     I[:, i:i + 10, :]).unsqueeze(3)
                for i in range(0, m, 10)
            ], dim=1)

        # ||x||^2 term, gathered for the selected neighbors
        xe_sqs = (xe**2).sum(dim=-1, keepdim=True)
        xe_sqs_ind = xe_sqs.gather(dim=1, index=If[:, :, 0:1]).view(b, m, o, 1)
        D += xe_sqs_ind
        # ||y||^2 term
        D += (ye**2).sum(dim=-2, keepdim=True)

        # D -> b m o
        D = D.squeeze(3)
    else:
        # D_full -> b m n
        D_full = ops.euclidean_distance(ye, xe.permute(0, 2, 1))

        # D -> b m o
        D = D_full.gather(dim=2, index=I)

    return -D
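The inference branch above relies on the expansion ||x - y||^2 = ||x||^2 - 2*x.y + ||y||^2 and only ever materialises the O candidate neighbors per query. A compact, unoptimised sketch of the same result in plain PyTorch (no tensor comprehensions, no indexed_matmul helpers; it trades memory for simplicity by forming the full BxMxN matrix first):

import torch

def compute_distances_dense(xe, ye, I):
    # Reference sketch: full pairwise squared distances, then gather the
    # O candidates per query. Sign convention matches compute_distances
    # above (negated distances).
    D_full = (ye ** 2).sum(-1, keepdim=True) \
             - 2 * torch.bmm(ye, xe.transpose(1, 2)) \
             + (xe ** 2).sum(-1).unsqueeze(1)          # B x M x N
    return -D_full.gather(dim=2, index=I)              # B x M x O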
Example #3
def run_experiment(targets_fn, run_name):
    dataset_fn = datasets.mnist_digits
    model_fn = models.mlp_model
    batch_size = 100
    nat_loss_coefficient = 20.0
    optimizer = tf.train.AdamOptimizer()
    eval_steps = 200

    (train_x, _), (validation_x, _), _ = dataset_fn()
    targets = targets_fn(len(train_x))

    input_t = tf.placeholder(dtype=tf.float32,
                             name='input_t',
                             shape=(None, *train_x.shape[1:]))
    target_t = tf.placeholder_with_default(
        # float32 default so it matches the float32 model output
        input=np.zeros((batch_size, targets.shape[1]), dtype=np.float32),
        name='target_t',
        shape=(None, targets.shape[1]))

    reconstructed_t, z_t = model_fn(input_t, targets.shape[1])

    mean_reconstruction_loss_t = tf.nn.l2_loss(reconstructed_t -
                                               input_t) / batch_size

    cost_matrix_t = ops.cost_matrix(z_t,
                                    target_t,
                                    loss_func=ops.euclidean_distance)
    new_assignment_indices = ops.hungarian_method(cost_matrix_t)
    new_targets_t = tf.gather(target_t, new_assignment_indices)

    nat_loss_t = ops.euclidean_distance(new_targets_t, z_t)
    mean_nat_loss_t = tf.reduce_mean(nat_loss_t)

    total_loss = (mean_nat_loss_t *
                  nat_loss_coefficient) + mean_reconstruction_loss_t

    global_step_t = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(total_loss, global_step_t)

    tf.summary.scalar('mean_nat_loss', mean_nat_loss_t)
    tf.summary.scalar('mean_reconstruction_loss/train',
                      mean_reconstruction_loss_t)

    logdir = os.path.join('model_logs', run_name)
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir=logdir,
                                             save_checkpoint_secs=200,
                                             save_summaries_steps=200)
    metric_logger = metric_logging.TFEventsLogger(log_dir=logdir)

    logging.info('Training')
    while True:
        batch_indices = np.random.choice(len(train_x), size=batch_size)
        batch_images = train_x[batch_indices]
        batch_target = targets[batch_indices]

        current_step, new_targets, _ = sess.run(
            [global_step_t, new_targets_t, train_op],
            feed_dict={
                input_t: batch_images,
                target_t: batch_target
            })

        targets[batch_indices] = new_targets

        if current_step % eval_steps == 2:
            # Evaluate the validation set in batch_size chunks so the
            # per-batch reconstruction loss stays comparable to training.
            validation_session_results = [
                sess.run(
                    [mean_reconstruction_loss_t, reconstructed_t],
                    feed_dict={
                        input_t: validation_x[i:i + batch_size],
                    }
                ) for i in range(0, len(validation_x), batch_size)
            ]

            validation_reconstruction_loss = np.mean(
                [v[0] for v in validation_session_results])
            # just get the first batch of validation images
            validation_images = validation_session_results[0][1]

            metric_logger.log_scalar('mean_reconstruction_loss/validation',
                                     validation_reconstruction_loss,
                                     current_step)
            metric_logger.log_images(
                'validation_reconstructed',
                # workaround needed to save monochrome images
                validation_images[:10][:, :, :, 0],
                current_step)
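Example #3 uses euclidean_distance twice: inside ops.cost_matrix to build the bipartite cost between the embeddings z_t and the candidate targets, and again as the training loss against the reassigned targets (the "noise as targets" pattern). A hedged sketch of that reassignment step, with scipy.optimize.linear_sum_assignment standing in for ops.hungarian_method (which may use a different convention):

import numpy as np
from scipy.optimize import linear_sum_assignment

def reassign_targets(z, targets):
    # Sketch only: permute `targets` so each embedding row in `z` gets the
    # target that minimises the total pairwise L2 cost.
    cost = np.linalg.norm(z[:, None, :] - targets[None, :, :], axis=-1)
    row_ind, col_ind = linear_sum_assignment(cost)
    return targets[col_ind]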
Example #4
def run_experiment(dataset_fn,
                   model_fn,
                   batching_fn,
                   batch_size,
                   run_name,
                   eval_steps,
                   train_steps,
                   config_path=None):
    optimizer = tf.train.AdamOptimizer(1e-3)

    train_x, targets = dataset_fn()
    target_assignments = np.arange(len(targets))

    assert len(train_x) == len(targets)

    assert len(train_x) % batch_size == 0, \
        'the dataset size must be a multiple of batch_size (validation reuses it) due to laziness'

    input_t = tf.placeholder(dtype=tf.float32,
                             name='input_t',
                             shape=(batch_size, *train_x.shape[1:]))
    target_t = tf.placeholder_with_default(
        # float32 default so it matches the float32 model output
        input=np.zeros((batch_size, targets.shape[1]), dtype=np.float32),
        name='target_t',
        shape=(batch_size, targets.shape[1]))

    z_t = model_fn(input_t, targets.shape[1])

    cost_matrix_t = ops.cost_matrix(z_t,
                                    target_t,
                                    loss_func=ops.euclidean_distance)
    new_assignment_indices_t = ops.hungarian_method(
        tf.expand_dims(cost_matrix_t, 0))[0]
    new_targets_t = tf.gather(target_t, new_assignment_indices_t)

    nat_loss_t = ops.euclidean_distance(new_targets_t, z_t)
    mean_nat_loss_t = tf.reduce_mean(nat_loss_t)

    global_step_t = tf.train.get_or_create_global_step()

    train_op = optimizer.minimize(mean_nat_loss_t, global_step_t)

    tf.summary.scalar('nat_loss/train', mean_nat_loss_t)

    logdir = os.path.join('model_logs', run_name)
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir=logdir,
                                             save_checkpoint_secs=200,
                                             save_summaries_steps=eval_steps)
    metric_logger = metric_logging.TensorboardLogger(
        writer=tf.summary.FileWriterCache.get(logdir))
    if config_path is not None:
        shutil.copyfile(config_path, f'{logdir}/config.py')
    assignments_path = f'{logdir}/assignments.p'
    if os.path.isfile(assignments_path):
        with open(assignments_path, 'rb') as f:
            target_assignments = pickle.load(f)

    train_noise_image, *_ = np.histogram2d(targets[:, 0],
                                           targets[:, 1],
                                           bins=(256, 256))
    metric_logger.log_images('train_noise_image',
                             [train_noise_image.T / train_noise_image.max()],
                             0)

    logging.info('Training')
    current_step = 0
    moving_nat_loss = 0.5
    moving_reassignment_fraction = 1.0

    while current_step < train_steps:
        batch_indices = batching_fn(batch_size=batch_size,
                                    targets=targets,
                                    context={
                                        'current_step':
                                        current_step,
                                        'average_l2_loss':
                                        moving_nat_loss,
                                        'moving_reassignment_fraction':
                                        moving_reassignment_fraction
                                    })

        batch_input = train_x[batch_indices]
        target_indices = target_assignments[batch_indices]
        batch_target = targets[target_indices]

        current_step, mean_loss_for_batch, new_assignment_indices, _ = sess.run(
            [
                global_step_t, mean_nat_loss_t, new_assignment_indices_t,
                train_op
            ],
            feed_dict={
                input_t: batch_input,
                target_t: batch_target
            })
        moving_nat_loss = (0.99 * moving_nat_loss) + (0.01 *
                                                      mean_loss_for_batch)
        fraction_changed = (new_assignment_indices != np.arange(
            len(new_assignment_indices))).mean()
        moving_reassignment_fraction = (
            0.99 * moving_reassignment_fraction) + (0.01 * fraction_changed)

        # Update target assignment
        target_assignments[batch_indices] = target_assignments[batch_indices][
            new_assignment_indices].copy()

        if current_step % eval_steps == 1:
            with open(assignments_path, 'wb') as f:
                pickle.dump(target_assignments, f)
            metric_logger.log_scalar('fraction_of_targets_changing',
                                     moving_reassignment_fraction,
                                     current_step)

            validation_results = [
                sess.run([z_t],
                         feed_dict={
                             input_t: train_x[i:i + batch_size],
                         }) for i in range(0, len(train_x), batch_size)
            ]
            validation_z = np.concatenate([x[0] for x in validation_results])

            validation_noise_image, *_ = np.histogram2d(validation_z[:, 0],
                                                        validation_z[:, 1],
                                                        bins=(256, 256))
            metric_logger.log_images(
                'validation_noise_image',
                [validation_noise_image.T / validation_noise_image.max()],
                current_step)
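Example #4 leaves batching_fn as a parameter; from the call site it only needs to accept batch_size, targets and a context dict and return an array of indices into train_x. A minimal uniform-sampling sketch compatible with that signature (the context keys such as moving_reassignment_fraction are simply ignored here):

import numpy as np

def uniform_batching_fn(batch_size, targets, context):
    # Smallest batching_fn matching the call in run_experiment: ignore the
    # context and sample batch_size indices uniformly without replacement.
    return np.random.choice(len(targets), size=batch_size, replace=False)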