def visualize_nearest_neighbours(model, data, global_step, batch_size,
                                 num_steps, num_frames_per_step, split):
  """Visualize nearest neighbours in embedding space.

  For each step of the first (query) video in the batch, finds the nearest
  candidate step (by mean squared distance in embedding space) in every other
  video, and writes two image summaries: a grid of the query frames alongside
  their nearest-neighbour frames, and the per-pair softmax similarity matrices.

  Args:
    model: Dict with keys 'cnn' (frame feature extractor) and 'emb'
      (embedding network applied on top of CNN features).
    data: Dict of input tensors; must contain 'frames' with leading dimension
      batch_size. (Exact tensor layout is consumed by get_cnn_feats —
      assumed [batch, num_steps * num_frames_per_step, H, W, C].)
    global_step: Step value attached to the written summaries.
    batch_size: Number of videos in the batch; video 0 is the query.
    num_steps: Number of embedding steps per video.
    num_frames_per_step: Frames aggregated into each embedding step; the last
      frame of each step is the one visualized.
    split: Name prefix (e.g. 'train'/'val') for the summary tags.
  """
  # Set learning_phase to False to use models in inference mode.
  tf.keras.backend.set_learning_phase(0)
  cnn = model['cnn']
  emb = model['emb']
  cnn_feats = get_cnn_feats(cnn, data, training=False)
  emb_feats = emb(cnn_feats, num_steps)
  # Regroup the flat [batch * num_steps, d] embeddings into
  # [batch, num_steps, d] so each video's steps can be indexed directly.
  emb_feats = tf.stack(tf.split(emb_feats, num_steps, axis=0), axis=1)
  query_feats = emb_feats[0]
  frames = data['frames']
  image_list = tf.unstack(frames, num=batch_size, axis=0)
  # Keep only the last frame of each step for visualization.
  im_list = [image_list[0][num_frames_per_step - 1::num_frames_per_step]]
  sim_matrix = np.zeros(
      (batch_size - 1, num_steps, num_steps), dtype=np.float32)
  for i in range(1, batch_size):
    candidate_feats = emb_feats[i]
    img_list = tf.unstack(
        image_list[i], num=num_steps * num_frames_per_step,
        axis=0)[num_frames_per_step - 1::num_frames_per_step]
    nn_img_list = []
    for j in range(num_steps):
      curr_query_feats = tf.tile(query_feats[j:j + 1], [num_steps, 1])
      mean_squared_distance = tf.reduce_mean(
          tf.squared_difference(curr_query_feats, candidate_feats), axis=1)
      # Negate distances so that smaller distance -> larger similarity.
      sim_matrix[i - 1, j] = softmax(-1.0 * mean_squared_distance)
      nn_img_list.append(img_list[tf.argmin(mean_squared_distance)])
    nn_img = tf.stack(nn_img_list, axis=0)
    im_list.append(nn_img)

  def vstack(im):
    # Lay out a video's step-frames side by side into one wide image.
    return tf.concat(tf.unstack(im, num=num_steps), axis=1)

  summary_im = tf.expand_dims(
      tf.concat([vstack(im) for im in im_list], axis=0), axis=0)
  tf.summary.image('%s/nn' % split, summary_im, step=global_step)
  # sim_matrix was allocated as float32 above, so no dtype conversion is
  # needed before writing the summary (tf.summary.image rejects float64).
  tf.summary.image(
      '%s/similarity_matrix' % split,
      np.expand_dims(sim_matrix, axis=3), step=global_step)
def _align_single_cycle(cycle, embs, cycle_length, num_steps,
                        similarity_type, temperature):
  """Takes a single cycle and returns logits (similarity scores) and labels.

  Picks one random frame of the first video in the cycle, then soft-aligns it
  through each subsequent video by taking a softmax-weighted nearest neighbour
  of the running query embedding. The similarities from the final hop are the
  logits; the label is the index of the originally chosen frame.

  Args:
    cycle: Integer indices of the videos forming the cycle; cycle[0] is both
      the start and (by construction of the caller) the cycle's return target.
    embs: Embedding tensor indexable as embs[video_idx] ->
      [num_steps, num_channels] per-frame embeddings.
    cycle_length: Number of hops; videos cycle[1] .. cycle[cycle_length] are
      visited in order.
    num_steps: Number of frames (steps) per video.
    similarity_type: Either 'l2' or 'cosine'.
    temperature: Softmax temperature; lower values give harder alignments.

  Returns:
    Tuple (similarity, onehot_labels): the final-hop similarity scores of
    shape [num_steps], and a one-hot label of the sampled frame index.

  Raises:
    ValueError: If similarity_type is neither 'l2' nor 'cosine'.
  """
  # Choose a random frame of the first video as the query.
  n_idx = tf.random_uniform((), minval=0, maxval=num_steps, dtype=tf.int32)
  # Create labels.
  onehot_labels = tf.one_hot(n_idx, num_steps)
  # Choose query feats for first frame.
  query_feats = embs[cycle[0], n_idx:n_idx + 1]
  num_channels = tf.shape(query_feats)[-1]
  for c in range(1, cycle_length + 1):
    candidate_feats = embs[cycle[c]]
    if similarity_type == 'l2':
      # Find L2 distance.
      mean_squared_distance = tf.reduce_sum(
          tf.squared_difference(
              tf.tile(query_feats, [num_steps, 1]), candidate_feats), axis=1)
      # Convert L2 distance to similarity.
      similarity = -mean_squared_distance
    elif similarity_type == 'cosine':
      # Dot product of embeddings.
      similarity = tf.squeeze(
          tf.matmul(candidate_feats, query_feats, transpose_b=True))
    else:
      raise ValueError('similarity_type can either be l2 or cosine.')
    # Scale the distance by number of channels. This normalization helps with
    # optimization.
    similarity /= tf.cast(num_channels, tf.float32)
    # Scale the distance by a temperature that helps with how soft/hard the
    # alignment should be.
    similarity /= temperature
    # Soft nearest neighbour: softmax weights over candidate frames.
    beta = tf.nn.softmax(similarity)
    beta = tf.expand_dims(beta, axis=1)
    beta = tf.tile(beta, [1, num_channels])
    # Find weighted nearest neighbour; this becomes the query for the next hop.
    query_feats = tf.reduce_sum(beta * candidate_feats, axis=0, keepdims=True)
  # NOTE(review): assumes cycle_length >= 1, otherwise `similarity` is
  # unbound — callers appear to guarantee this; confirm.
  return similarity, onehot_labels