def speech_rnn_latent_values(lib, feats_fn, latent_dict):
    """Restore a trained speech RNN classifier and store its latent
    representation for every utterance in feats_fn in latent_dict."""
    x, labels, lengths, keys = (
        data_library.load_speech_data_from_npz(feats_fn))
    data_library.truncate_data_dim(x, lengths, lib["input_dim"],
                                   lib["max_frames"])
    iterator = batching_library.speech_iterator(
        x, 1, shuffle_batches_every_epoch=False)
    indices = iterator.indices

    tf.reset_default_graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    key_counter = 0

    X = tf.placeholder(tf.float32, [None, None, lib["input_dim"]])
    target = tf.placeholder(tf.int32, [None, lib["num_classes"]])
    X_lengths = tf.placeholder(tf.int32, [None])
    train_flag = tf.placeholder_with_default(False, shape=())
    training_placeholders = [X, X_lengths, target]

    model = model_legos_library.rnn_classifier_architecture(
        [X, X_lengths], train_flag, model_setup_library.activation_lib(),
        lib)
    output = model["output"]
    latent = model["latent"]

    model_fn = model_setup_library.get_model_fn(lib)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sesh:
        saver.restore(sesh, model_fn)
        # batch_lengths avoids shadowing the full-dataset lengths above
        for feats, batch_lengths in tqdm(iterator,
                                         desc="Extracting latents",
                                         ncols=COL_LENGTH):
            lat = sesh.run(latent,
                           feed_dict={
                               X: feats,
                               X_lengths: batch_lengths,
                               train_flag: False
                           })
            latent_dict[keys[indices[key_counter]]] = lat
            key_counter += 1

    print("Total number of keys: {}".format(key_counter))
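
# A minimal usage sketch (not in the original source): how the extractor
# above is typically driven. The wrapper name, the feats path and the
# output filename are hypothetical placeholders; only
# speech_rnn_latent_values itself comes from this file, and the npz
# key -> array layout is an assumption.
def extract_and_save_latents_sketch(lib, feats_fn, out_fn):
    latent_dict = {}
    # Fills latent_dict with one latent vector per utterance key.
    speech_rnn_latent_values(lib, feats_fn, latent_dict)
    # Persist as an npz keyed by utterance key.
    np.savez_compressed(out_fn, **latent_dict)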
def main():
    lib = library_setup()
    num_episodes = 400
    K = lib["K"]
    M = lib["M"]
    Q = lib["Q"]

    # Image data
    images, image_labels, image_keys = (data_library.load_image_data_from_npz(
        lib["image_data_dir"]))
    im_x, im_labels, im_keys = images, image_labels, image_keys

    # Speech data
    sp_x, sp_labels, sp_lengths, sp_keys = (
        data_library.load_speech_data_from_npz(lib["speech_data_dir"]))
    max_frames = 100
    d_frame = 13
    print("\nLimiting dimensionality: {}".format(d_frame))
    print("Limiting number of frames: {}\n".format(max_frames))
    data_library.truncate_data_dim(sp_x, sp_lengths, d_frame, max_frames)
    speech_x, speech_labels, speech_lengths, speech_keys = (
        sp_x, sp_labels, sp_lengths, sp_keys)

    episodes_fn = lib["data_fn"] + "_episodes.txt"
    print("Generating episodes...\n")

    with open(episodes_fn, "w") as file:
        for episode_counter in tqdm(range(1, num_episodes + 1)):
            support_set = few_shot_learning_library.construct_few_shot_support_set_with_keys(
                speech_x, speech_labels, speech_keys, speech_lengths, images,
                image_labels, image_keys, M, K)

            im_keys_to_not_include = []
            sp_keys_to_not_include = []
            for key in support_set:
                im_keys_to_not_include.extend(support_set[key]["image_keys"])
                sp_keys_to_not_include.extend(support_set[key]["speech_keys"])

            query_dict = few_shot_learning_library.sample_multiple_keys(
                speech_x, speech_labels, speech_keys, speech_lengths, Q,
                exclude_key_list=sp_keys_to_not_include)
            matching_dict = few_shot_learning_library.sample_multiple_keys(
                images, image_labels, image_keys, num_to_sample=M - 1,
                num_of_each_sample=1,
                exclude_key_list=im_keys_to_not_include)

            file.write("Episode {}\n".format(episode_counter))
            file.write("Support set:\n")
            file.write("Labels: \n")
            for key in support_set:
                file.write("{}\n".format(key))
            file.write("Image keys: \n")
            for key in support_set:
                file.write(" ".join(support_set[key]["image_keys"]) + "\n")
            file.write("Speech keys: \n")
            for key in support_set:
                file.write(" ".join(support_set[key]["speech_keys"]) + "\n")

            file.write("Query:\n")
            file.write("Labels: \n")
            for key in query_dict:
                file.write("{}\n".format(key))
            file.write("Keys: \n")
            for key in query_dict:
                file.write(" ".join(query_dict[key]["keys"]) + "\n")

            file.write("Matching set:\n")
            file.write("Labels: \n")
            for key in matching_dict:
                file.write("{}\n".format(key))
            file.write("Keys: \n")
            for key in matching_dict:
                file.write(" ".join(matching_dict[key]["keys"]) + "\n")
            file.write("\n")

            if episode_check(support_set, query_dict, matching_dict):
                break

    print("Wrote episodes to {}".format(episodes_fn))
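
# episode_check is called above but not defined in this excerpt. A sketch
# of one plausible implementation, under the assumption that it flags key
# leakage between the support, query and matching sets; the real
# criterion may differ.
def episode_check_sketch(support_set, query_dict, matching_dict):
    support_keys = set()
    for label in support_set:
        support_keys.update(support_set[label]["speech_keys"])
        support_keys.update(support_set[label].get("image_keys", []))
    query_keys = {k for label in query_dict
                  for k in query_dict[label]["keys"]}
    match_keys = {k for label in matching_dict
                  for k in matching_dict[label]["keys"]}
    # True as soon as any key appears in more than one of the three sets.
    return bool((support_keys & query_keys) or (support_keys & match_keys)
                or (query_keys & match_keys))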
def main():
    args = check_argv()

    if args.metric == "cosine":
        dist_func = "cosine"
    elif args.metric == "euclidean":
        dist_func = "euclidean"
    elif args.metric == "euclidean_squared":
        dist_func = "sqeuclidean"

    print("Start time: {}".format(datetime.datetime.now()))

    if args.speech_feats_fn.split("_")[-2] == args.image_feats_fn.split(
            "/")[-1].split(".")[0]:
        subset = args.image_feats_fn.split("/")[-1].split(".")[0]
    elif (args.speech_feats_fn.split("_")[-2] == "val"
          and args.image_feats_fn.split("/")[-1].split(".")[0]
          == "validation"):
        subset = args.image_feats_fn.split("/")[-1].split(".")[0]
    else:
        sys.exit(0)

    valid_dataset_pair = (
        args.speech_feats_fn.split("/")[2] in SPEECH_DATASETS
        and args.image_feats_fn.split("/")[2] in IMAGE_DATASETS)
    if not valid_dataset_pair:
        print("Specified dataset to get pairs for is not valid.")
        sys.exit(0)

    key_pair_dir = path.join(
        pair_path, args.speech_feats_fn.split("/")[2] + "_speech_" +
        args.image_feats_fn.split("/")[2] + "_image_pairs")
    model = "classifier"
    util_library.check_dir(key_pair_dir)

    if not os.path.isfile(key_pair_dir):
        speech_latent_npz = path.join(
            pair_path,
            "/".join(args.speech_feats_fn.split(".")[-2].split("/")[2:]),
            model + "_latents", model + "_feats.npz")
        image_latent_npz = path.join(
            pair_path,
            "/".join(args.image_feats_fn.split(".")[-2].split("/")[2:]),
            model + "_latents", model + "_feats.npz")

        image_latents, image_keys = data_library.load_latent_data_from_npz(
            image_latent_npz)
        image_latents = np.squeeze(np.asarray(image_latents))
        speech_latents, speech_keys = data_library.load_latent_data_from_npz(
            speech_latent_npz)
        speech_latents = np.squeeze(np.asarray(speech_latents))

        # Standardise each latent space before measuring distances
        speech_latents = (speech_latents - speech_latents.mean(axis=0)
                          ) / speech_latents.std(axis=0)
        image_latents = (image_latents - image_latents.mean(axis=0)
                         ) / image_latents.std(axis=0)

        im_x, im_labels, im_keys = (data_library.load_image_data_from_npz(
            args.image_feats_fn))
        sp_x, sp_labels, sp_lengths, sp_keys = (
            data_library.load_speech_data_from_npz(args.speech_feats_fn))
        max_frames = 100
        d_frame = 13
        print("\nLimiting dimensionality: {}".format(d_frame))
        print("Limiting number of frames: {}\n".format(max_frames))
        data_library.truncate_data_dim(sp_x, sp_lengths, d_frame, max_frames)

        support_set = few_shot_learning_library.construct_few_shot_support_set_with_keys(
            sp_x, sp_labels, sp_keys, sp_lengths, im_x, im_labels, im_keys,
            11, 5)

        support_set_speech_keys = []
        support_set_image_keys = []
        support_set_speech_latents = []
        support_set_image_latents = []
        for key in support_set:
            support_set_speech_keys.extend(support_set[key]["speech_keys"])
            for sp_key in support_set[key]["speech_keys"]:
                ind = np.where(np.asarray(speech_keys) == sp_key)[0][0]
                support_set_speech_latents.append(speech_latents[ind, :])
            support_set_image_keys.extend(support_set[key]["image_keys"])
            for im_key in support_set[key]["image_keys"]:
                ind = np.where(np.asarray(image_keys) == im_key)[0][0]
                support_set_image_latents.append(image_latents[ind, :])
        support_set_speech_latents = np.asarray(support_set_speech_latents)
        support_set_image_latents = np.asarray(support_set_image_latents)

        # Pair each support speech key with a same-label support image key
        support_dict = {}
        already_used = []
        for key in support_set_speech_keys:
            label = key.split("_")[0]
            for im_key in support_set_image_keys:
                if few_shot_learning_library.label_test(
                        im_key.split("_")[0],
                        label) and im_key not in already_used:
                    support_dict[key] = im_key
                    already_used.append(im_key)
                    break

        speech_dict = {}
        # For every non-support speech item, find its nearest support
        # speech latent and inherit that support item's paired image key
        speech_distances = cdist(speech_latents, support_set_speech_latents,
                                 dist_func)
        speech_indexes = np.argsort(speech_distances, axis=1)
        for i, sp_key in enumerate(speech_keys):
            if sp_key not in support_set_speech_keys:
                ind = speech_indexes[i, 0]  # nearest support item
                speech_dict[sp_key] = support_dict[
                    support_set_speech_keys[ind]]

        # Group every non-support image under its nearest support image
        image_dict = {}
        image_distances = cdist(image_latents, support_set_image_latents,
                                dist_func)
        image_indexes = np.argsort(image_distances, axis=1)
        for i, im_key in enumerate(image_keys):
            if im_key not in support_set_image_keys:
                ind = image_indexes[i, 0]  # nearest support item
                if support_set_image_keys[ind] not in image_dict:
                    image_dict[support_set_image_keys[ind]] = []
                image_dict[support_set_image_keys[ind]].append(im_key)

        # Write one not-yet-used image key for each speech key
        already_used_im_keys = []
        with open(path.join(key_pair_dir, subset + "_pairs.txt"),
                  "w") as pair_file:
            for sp_key in tqdm(speech_dict,
                               desc="Generating speech-image pairs",
                               ncols=COL_LENGTH):
                possible_im_keys = image_dict[speech_dict[sp_key]]
                for possible_key in possible_im_keys:
                    if possible_key not in already_used_im_keys:
                        pair_file.write(f"{sp_key}\t{possible_key}\n")
                        already_used_im_keys.append(possible_key)
                        image_dict[speech_dict[sp_key]].remove(possible_key)
                        break

    print("End time: {}".format(datetime.datetime.now()))
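
# The pairing above rests on a cdist + argsort pattern: row i of the
# argsort lists support indices from nearest to farthest, and only
# column 0 (the nearest neighbour) is used. A self-contained toy version
# of that step; the function name is a hypothetical placeholder.
import numpy as np
from scipy.spatial.distance import cdist

def nearest_support_index_sketch(latents, support_latents, metric="cosine"):
    # distances[i, j]: distance from item i to support item j
    distances = cdist(latents, support_latents, metric)
    # Index of each item's nearest support latent.
    return np.argsort(distances, axis=1)[:, 0]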
def main():
    lib = library_setup()
    num_episodes = 400
    K = lib["K"]
    M = lib["M"]
    Q = lib["Q"]

    # Speech data
    sp_x, sp_labels, sp_lengths, sp_keys = (
        data_library.load_speech_data_from_npz(lib["data_dir"]))
    max_frames = 100
    d_frame = 13
    print("\nLimiting dimensionality: {}".format(d_frame))
    print("Limiting number of frames: {}\n".format(max_frames))
    data_library.truncate_data_dim(sp_x, sp_lengths, d_frame, max_frames)

    if lib["data_type"] == "buckeye":
        # Buckeye episodes exclude spoken digits
        digit_list = [
            "one", "two", "three", "four", "five", "six", "seven", "eight",
            "nine", "zero", "oh"
        ]
        speech_x = []
        speech_labels = []
        speech_lengths = []
        speech_keys = []
        for i, label in enumerate(sp_labels):
            if label not in digit_list:
                speech_x.append(sp_x[i])
                speech_labels.append(sp_labels[i])
                speech_lengths.append(sp_lengths[i])
                speech_keys.append(sp_keys[i])
        speech_x, speech_keys, speech_labels, speech_lengths = filter_buckeye_set(
            speech_x, speech_keys, speech_labels, M, K, Q)
    else:
        speech_x, speech_labels, speech_lengths, speech_keys = (
            sp_x, sp_labels, sp_lengths, sp_keys)

    episodes_fn = lib["data_fn"] + "_episodes.txt"
    print("Generating episodes...\n")

    with open(episodes_fn, "w") as file:
        for episode_counter in tqdm(range(1, num_episodes + 1)):
            support_set = few_shot_learning_library.construct_few_shot_support_set_with_keys(
                sp_x=speech_x, sp_labels=speech_labels, sp_keys=speech_keys,
                sp_lengths=speech_lengths, im_x=None, im_labels=None,
                im_keys=None, num_to_sample=M, num_of_each_sample=K)

            sp_keys_to_not_include = []
            labels_wanted = []
            for key in support_set:
                labels_wanted.append(key)
                sp_keys_to_not_include.extend(support_set[key]["speech_keys"])

            query_dict = few_shot_learning_library.sample_multiple_keys(
                speech_x, speech_labels, speech_keys, speech_lengths, Q,
                exclude_key_list=sp_keys_to_not_include,
                labels_wanted=labels_wanted)
            for key in query_dict:
                sp_keys_to_not_include.extend(query_dict[key]["keys"])

            matching_dict = few_shot_learning_library.sample_multiple_keys(
                speech_x, speech_labels, speech_keys,
                num_to_sample=len(labels_wanted), num_of_each_sample=1,
                exclude_key_list=sp_keys_to_not_include,
                labels_wanted=labels_wanted)

            file.write("Episode {}\n".format(episode_counter))
            file.write("Support set:\n")
            file.write("Labels: \n")
            for key in support_set:
                file.write("{}\n".format(key))
            file.write("Keys: \n")
            for key in support_set:
                file.write(" ".join(support_set[key]["speech_keys"]) + "\n")

            file.write("Query:\n")
            file.write("Labels: \n")
            for key in query_dict:
                file.write("{}\n".format(key))
            file.write("Keys: \n")
            for key in query_dict:
                file.write(" ".join(query_dict[key]["keys"]) + "\n")

            file.write("Matching set:\n")
            file.write("Labels: \n")
            for key in matching_dict:
                file.write("{}\n".format(key))
            file.write("Keys: \n")
            for key in matching_dict:
                file.write(" ".join(matching_dict[key]["keys"]) + "\n")
            file.write("\n")

            if episode_check(support_set, query_dict, matching_dict):
                break

    print("Wrote episodes to {}".format(episodes_fn))
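
# filter_buckeye_set is called above but not defined in this excerpt. A
# minimal sketch under one assumption: it keeps only labels frequent
# enough to fill a K-shot support set plus Q queries and one matching
# item, deriving lengths back from the features. The actual filtering
# rule may differ.
from collections import Counter

def filter_buckeye_set_sketch(x, keys, labels, M, K, Q):
    counts = Counter(labels)
    min_needed = K + Q + 1  # support + query + matching examples per label
    keep = [i for i, label in enumerate(labels)
            if counts[label] >= min_needed]
    kept_x = [x[i] for i in keep]
    kept_keys = [keys[i] for i in keep]
    kept_labels = [labels[i] for i in keep]
    kept_lengths = [len(x[i]) for i in keep]  # frame counts from features
    return kept_x, kept_keys, kept_labels, kept_lengths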
#_____________________________________________________________________________________________________________________________________
#
# Data processing
#
#_____________________________________________________________________________________________________________________________________

print("\n" + "-" * PRINT_LENGTH)
print("Processing training data")
print("-" * PRINT_LENGTH)

image_train_x, image_train_labels, image_train_keys = (
    data_library.load_image_data_from_npz(base_lib["image_train_fn"]))
speech_train_x, speech_train_labels, speech_train_lengths, speech_train_keys = (
    data_library.load_speech_data_from_npz(base_lib["speech_train_fn"]))
data_library.truncate_data_dim(speech_train_x, speech_train_lengths,
                               base_lib["speech_input_dim"],
                               base_lib["max_frames"])
train_mask = data_library.get_mask(speech_train_x, speech_train_lengths, None)

print("\n" + "-" * PRINT_LENGTH)
print("Processing validation data")
print("-" * PRINT_LENGTH)

image_val_x, image_val_labels, image_val_keys = (
    data_library.load_image_data_from_npz(base_lib["image_val_fn"]))
speech_val_x, speech_val_labels, speech_val_lengths, speech_val_keys = (
    data_library.load_speech_data_from_npz(base_lib["speech_val_fn"]))
data_library.truncate_data_dim(speech_val_x, speech_val_lengths,
                               base_lib["speech_input_dim"],
                               base_lib["max_frames"])
val_mask = data_library.get_mask(speech_val_x, speech_val_lengths)

print("\n" + "-" * PRINT_LENGTH)
print("Processing test data")
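
# get_mask comes from data_library and is not shown here. A minimal
# sketch of the usual padding mask for variable-length speech, assuming
# mask[i, t] should be 1.0 for real frames and 0.0 for padding; the
# library version may differ (note the extra None argument above).
import numpy as np

def pad_mask_sketch(lengths, max_frames):
    t = np.arange(max_frames)[None, :]   # (1, max_frames)
    lens = np.asarray(lengths)[:, None]  # (num_utterances, 1)
    return (t < lens).astype(np.float32)  # 1.0 while t < length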