# Example #1
# 0
def visualize_nearest_neighbours(model, data, global_step, batch_size,
                                 num_steps, num_frames_per_step, split):
    """Visualize nearest neighbours in embedding space.

    Uses the first sequence in the batch as the query: for each of its
    `num_steps` embedded steps, finds the closest step (by mean squared
    distance in embedding space) in every other sequence, then writes two
    image summaries: the query frames stacked with their nearest-neighbour
    frames, and the per-pair softmax similarity matrices.

    Args:
      model: Dict holding the 'cnn' feature extractor and 'emb' embedder.
      data: Dict with a 'frames' tensor; assumed shape is
        (batch_size, num_steps * num_frames_per_step, H, W, C) — TODO confirm
        against the caller.
      global_step: Step value passed to the summary writer.
      batch_size: Number of sequences in the batch.
      num_steps: Number of embedded steps per sequence.
      num_frames_per_step: Context frames consumed per embedding step.
      split: Dataset split name ('train'/'val'/...) used to tag summaries.
    """
    # Set learning_phase to False to use models in inference mode.
    tf.keras.backend.set_learning_phase(0)

    cnn = model['cnn']
    emb = model['emb']

    cnn_feats = get_cnn_feats(cnn, data, training=False)
    emb_feats = emb(cnn_feats, num_steps)
    # Regroup the flat per-step batch into shape (batch, num_steps, channels).
    emb_feats = tf.stack(tf.split(emb_feats, num_steps, axis=0), axis=1)

    # Sequence 0 supplies the query embeddings; all others are candidates.
    query_feats = emb_feats[0]

    frames = data['frames']
    image_list = tf.unstack(frames, num=batch_size, axis=0)
    # Keep only the last frame of each step for display.
    im_list = [image_list[0][num_frames_per_step - 1::num_frames_per_step]]
    # Allocate as float32 up front: tf.summary.image rejects float64 input.
    sim_matrix = np.zeros((batch_size - 1, num_steps, num_steps),
                          dtype=np.float32)

    for i in range(1, batch_size):
        candidate_feats = emb_feats[i]

        img_list = tf.unstack(image_list[i],
                              num=num_steps * num_frames_per_step,
                              axis=0)[num_frames_per_step -
                                      1::num_frames_per_step]
        nn_img_list = []

        for j in range(num_steps):
            # Broadcast query step j against all candidate steps at once.
            curr_query_feats = tf.tile(query_feats[j:j + 1], [num_steps, 1])
            mean_squared_distance = tf.reduce_mean(tf.squared_difference(
                curr_query_feats, candidate_feats),
                                                   axis=1)
            # Negated distance acts as similarity; softmax maps it to weights.
            sim_matrix[i - 1, j] = softmax(-1.0 * mean_squared_distance)
            nn_img_list.append(img_list[tf.argmin(mean_squared_distance)])

        nn_img = tf.stack(nn_img_list, axis=0)
        im_list.append(nn_img)

    def vstack(im):
        # Concatenate the num_steps frames side by side into one strip.
        return tf.concat(tf.unstack(im, num=num_steps), axis=1)

    summary_im = tf.expand_dims(tf.concat([vstack(im) for im in im_list],
                                          axis=0),
                                axis=0)
    tf.summary.image('%s/nn' % split, summary_im, step=global_step)
    # NOTE: a redundant sim_matrix.astype(np.float32) was removed here —
    # sim_matrix is already allocated as float32 above.
    tf.summary.image('%s/similarity_matrix' % split,
                     np.expand_dims(sim_matrix, axis=3),
                     step=global_step)
# Example #2
# 0
def _align_single_cycle(cycle, embs, cycle_length, num_steps, similarity_type,
                        temperature):
    """Takes a single cycle and returns logits (similarity scores) and labels.

    Implements one soft cycle-consistency pass: starting from a randomly
    chosen frame of the first sequence in `cycle`, repeatedly computes a
    soft (attention-weighted) nearest neighbour in the next sequence. The
    similarities from the final hop are returned as logits, with a one-hot
    label marking the originally sampled frame index.

    Args:
      cycle: Index tensor/sequence selecting which batch elements form the
        cycle; `cycle[0]` is both start and (implicitly) target.
      embs: Embeddings tensor; assumed shape is
        (batch, num_steps, channels) — TODO confirm against caller.
      cycle_length: Number of hops to take through the cycle.
      num_steps: Number of frames (steps) per sequence.
      similarity_type: Either 'l2' or 'cosine'.
      temperature: Softmax temperature; lower values make the soft
        nearest-neighbour assignment harder.

    Returns:
      Tuple of (similarity logits from the last hop with shape (num_steps,),
      one-hot labels of the sampled start frame).

    Raises:
      ValueError: If `similarity_type` is neither 'l2' nor 'cosine'.
    """
    # Choose random frame.
    n_idx = tf.random_uniform((), minval=0, maxval=num_steps, dtype=tf.int32)
    # Create labels
    onehot_labels = tf.one_hot(n_idx, num_steps)

    # Choose query feats for first frame. Slicing n_idx:n_idx+1 keeps a
    # leading dim of 1 so the query stays rank-2.
    query_feats = embs[cycle[0], n_idx:n_idx + 1]

    num_channels = tf.shape(query_feats)[-1]
    for c in range(1, cycle_length + 1):
        candidate_feats = embs[cycle[c]]

        if similarity_type == 'l2':
            # Find L2 distance.
            mean_squared_distance = tf.reduce_sum(tf.squared_difference(
                tf.tile(query_feats, [num_steps, 1]), candidate_feats),
                                                  axis=1)
            # Convert L2 distance to similarity.
            similarity = -mean_squared_distance

        elif similarity_type == 'cosine':
            # Dot product of embeddings.
            similarity = tf.squeeze(
                tf.matmul(candidate_feats, query_feats, transpose_b=True))
        else:
            raise ValueError('similarity_type can either be l2 or cosine.')

        # Scale the distance by number of channels. This normalization helps
        # with optimization.
        similarity /= tf.cast(num_channels, tf.float32)
        # Scale the distance by a temperature that helps with how soft/hard
        # the alignment should be.
        similarity /= temperature

        # Attention weights over candidate frames for the soft neighbour.
        beta = tf.nn.softmax(similarity)
        beta = tf.expand_dims(beta, axis=1)
        beta = tf.tile(beta, [1, num_channels])

        # Find weighted nearest neighbour; it becomes the next hop's query.
        query_feats = tf.reduce_sum(beta * candidate_feats,
                                    axis=0,
                                    keepdims=True)

    # `similarity` here is from the final hop of the loop (cycle_length >= 1).
    return similarity, onehot_labels