Example #1
def speech_rnn_latent_values(
    lib,
    feats_fn,
    latent_dict,
):

    x, labels, lengths, keys = (
        data_library.load_speech_data_from_npz(feats_fn))

    data_library.truncate_data_dim(x, lengths, lib["input_dim"],
                                   lib["max_frames"])
    iterator = batching_library.speech_iterator(
        x, 1, shuffle_batches_every_epoch=False)
    indices = iterator.indices

    tf.reset_default_graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    key_counter = 0

    X = tf.placeholder(tf.float32, [None, None, lib["input_dim"]])
    target = tf.placeholder(tf.int32, [None, lib["num_classes"]])
    X_lengths = tf.placeholder(tf.int32, [None])
    train_flag = tf.placeholder_with_default(False, shape=())
    training_placeholders = [X, X_lengths, target]

    model = model_legos_library.rnn_classifier_architecture(
        [X, X_lengths], train_flag, model_setup_library.activation_lib(), lib)

    output = model["output"]
    latent = model["latent"]

    model_fn = model_setup_library.get_model_fn(lib)

    saver = tf.train.Saver()
    with tf.Session(config=config) as sesh:
        saver.restore(sesh, model_fn)
        for batch_feats, batch_lengths in tqdm(iterator,
                                               desc="Extracting latents",
                                               ncols=COL_LENGTH):
            lat = sesh.run(latent,
                           feed_dict={
                               X: batch_feats,
                               X_lengths: batch_lengths,
                               train_flag: False
                           })

            latent_dict[keys[indices[key_counter]]] = lat
            key_counter += 1

    print("Total number of keys: {}".format(key_counter))
Example #2
def main():

    lib = library_setup()

    num_episodes = 400
    K = lib["K"]
    M = lib["M"]
    Q = lib["Q"]

    images, image_labels, image_keys = (data_library.load_image_data_from_npz(
        lib["image_data_dir"]))
    im_x, im_labels, im_keys = images, image_labels, image_keys

    # Speech data
    sp_x, sp_labels, sp_lengths, sp_keys = (
        data_library.load_speech_data_from_npz(lib["speech_data_dir"]))
    max_frames = 100
    d_frame = 13
    print("\nLimiting dimensionality: {}".format(d_frame))
    print("Limiting number of frames: {}\n".format(max_frames))
    data_library.truncate_data_dim(sp_x, sp_lengths, d_frame, max_frames)

    speech_x, speech_labels, speech_lengths, speech_keys = sp_x, sp_labels, sp_lengths, sp_keys

    episodes_fn = lib["data_fn"] + "_episodes.txt"
    print("Generating epsiodes...\n")

    file = open(episodes_fn, "w")

    for episode_counter in tqdm(range(1, num_episodes + 1)):

        support_set = few_shot_learning_library.construct_few_shot_support_set_with_keys(
            speech_x, speech_labels, speech_keys, speech_lengths, images,
            image_labels, image_keys, M, K)

        im_keys_to_not_include = []
        sp_keys_to_not_include = []
        for key in support_set:
            for im_key in support_set[key]["image_keys"]:
                im_keys_to_not_include.append(im_key)
            for sp_key in support_set[key]["speech_keys"]:
                sp_keys_to_not_include.append(sp_key)

        query_dict = few_shot_learning_library.sample_multiple_keys(
            speech_x,
            speech_labels,
            speech_keys,
            speech_lengths,
            Q,
            exclude_key_list=sp_keys_to_not_include)

        matching_dict = few_shot_learning_library.sample_multiple_keys(
            images,
            image_labels,
            image_keys,
            num_to_sample=M - 1,
            num_of_each_sample=1,
            exclude_key_list=im_keys_to_not_include)

        file.write("Episode {}\n".format(episode_counter))

        file.write("{}\n".format("Support set:"))
        file.write("{}\n".format("Labels: "))
        for key in support_set:
            file.write("{}\n".format(key))
        file.write("{}\n".format("Image keys: "))
        for key in support_set:
            for i in range(len(support_set[key]["image_keys"])):
                file.write("{}".format(support_set[key]["image_keys"][i]))
                if i == len(support_set[key]["image_keys"]) - 1:
                    file.write("\n")
                else:
                    file.write(" ")
        file.write("{}\n".format("Speech keys: "))
        for key in support_set:
            for i in range(len(support_set[key]["speech_keys"])):
                file.write("{}".format(support_set[key]["speech_keys"][i]))
                if i == len(support_set[key]["speech_keys"]) - 1:
                    file.write("\n")
                else:
                    file.write(" ")

        file.write("{}\n".format("Query:"))
        file.write("{}\n".format("Labels: "))
        for key in query_dict:
            file.write("{}\n".format(key))
        file.write("{}\n".format("Keys: "))
        for key in query_dict:
            for i in range(len(query_dict[key]["keys"])):
                file.write("{}".format(query_dict[key]["keys"][i]))
                if i == len(query_dict[key]["keys"]) - 1:
                    file.write("\n")
                else:
                    file.write(" ")

        file.write("{}\n".format("Matching set:"))
        file.write("{}\n".format("Labels: "))
        for key in matching_dict:
            file.write("{}\n".format(key))
        file.write("{}\n".format("Keys: "))
        for key in matching_dict:
            for i in range(len(matching_dict[key]["keys"])):
                file.write("{}".format(matching_dict[key]["keys"][i]))
                if i == len(matching_dict[key]["keys"]) - 1:
                    file.write("\n")
                else:
                    file.write(" ")

        file.write("\n")

        if episode_check(support_set, query_dict, matching_dict):
            break

    file.close()

    print("Wrote epsiodes to {}".format(episodes_fn))
Example #3
def main():

    args = check_argv()

    if args.metric == "cosine":
        dist_func = "cosine"
    elif args.metric == "euclidean":
        dist_func = "euclidean"
    elif args.metric == "euclidean_squared":
        dist_func == "sqeuclidean"

    print("Start time: {}".format(datetime.datetime.now()))
    if args.speech_feats_fn.split("_")[-2] == args.image_feats_fn.split(
            "/")[-1].split('.')[0]:
        subset = args.image_feats_fn.split("/")[-1].split('.')[0]
    elif args.speech_feats_fn.split(
            "_")[-2] == "val" and args.image_feats_fn.split("/")[-1].split(
                '.')[0] == "validation":
        subset = args.image_feats_fn.split("/")[-1].split('.')[0]
    else:
        sys.exit(0)

    if (args.speech_feats_fn.split("/")[2] not in SPEECH_DATASETS
            or args.image_feats_fn.split("/")[2] not in IMAGE_DATASETS):
        print("Specified datasets to get pairs for are not valid.")
        sys.exit(0)

    key_pair_file = path.join(
        pair_path,
        args.speech_feats_fn.split("/")[2] + "_speech_" +
        args.image_feats_fn.split("/")[2] + "_image_pairs")
    model = "classifier"
    util_library.check_dir(key_pair_file)

    if not os.path.isfile(key_pair_file):

        speech_latent_npz = path.join(
            pair_path,
            "/".join(args.speech_feats_fn.split(".")[-2].split("/")[2:]),
            model + "_latents", model + "_feats.npz")
        image_latent_npz = path.join(
            pair_path,
            "/".join(args.image_feats_fn.split(".")[-2].split("/")[2:]),
            model + "_latents", model + "_feats.npz")

        image_latents, image_keys = data_library.load_latent_data_from_npz(
            image_latent_npz)
        image_latents = np.asarray(image_latents)
        image_latents = np.squeeze(image_latents)

        speech_latents, speech_keys = data_library.load_latent_data_from_npz(
            speech_latent_npz)
        speech_latents = np.asarray(speech_latents)
        speech_latents = np.squeeze(speech_latents)

        speech_latents = (speech_latents - speech_latents.mean(axis=0)
                          ) / speech_latents.std(axis=0)
        image_latents = (image_latents - image_latents.mean(axis=0)
                         ) / image_latents.std(axis=0)

        im_x, im_labels, im_keys = (data_library.load_image_data_from_npz(
            args.image_feats_fn))

        sp_x, sp_labels, sp_lengths, sp_keys = (
            data_library.load_speech_data_from_npz(args.speech_feats_fn))
        max_frames = 100
        d_frame = 13
        print("\nLimiting dimensionality: {}".format(d_frame))
        print("Limiting number of frames: {}\n".format(max_frames))
        data_library.truncate_data_dim(sp_x, sp_lengths, d_frame, max_frames)

        support_set = few_shot_learning_library.construct_few_shot_support_set_with_keys(
            sp_x, sp_labels, sp_keys, sp_lengths, im_x, im_labels, im_keys, 11,
            5)

        support_set_speech_keys = []
        support_set_image_keys = []
        support_set_speech_latents = []
        support_set_image_latents = []

        for key in support_set:
            support_set_speech_keys.extend(support_set[key]["speech_keys"])
            for sp_key in support_set[key]["speech_keys"]:
                ind = np.where(np.asarray(speech_keys) == sp_key)[0][0]
                support_set_speech_latents.append(speech_latents[ind, :])

            support_set_image_keys.extend(support_set[key]["image_keys"])
            for im_key in support_set[key]["image_keys"]:
                ind = np.where(np.asarray(image_keys) == im_key)[0][0]
                support_set_image_latents.append(image_latents[ind, :])

        support_set_speech_latents = np.asarray(support_set_speech_latents)
        support_set_image_latents = np.asarray(support_set_image_latents)

        support_dict = {}
        already_used = []

        for key in support_set_speech_keys:
            label = key.split("_")[0]

            for im_key in support_set_image_keys:
                if few_shot_learning_library.label_test(
                        im_key.split("_")[0],
                        label) and im_key not in already_used:
                    support_dict[key] = im_key
                    already_used.append(im_key)
                    break

        speech_dict = {}
        speech_distances = cdist(speech_latents, support_set_speech_latents,
                                 dist_func)
        speech_indexes = np.argsort(speech_distances, axis=1)
        for i, sp_key in enumerate(speech_keys):
            if sp_key not in support_set_speech_keys:
                # Map each non-support utterance to the image key paired
                # with its nearest support-set speech item.
                ind = speech_indexes[i, 0]
                speech_dict[sp_key] = support_dict[
                    support_set_speech_keys[ind]]

        image_dict = {}
        image_distances = cdist(image_latents, support_set_image_latents,
                                dist_func)
        image_indexes = np.argsort(image_distances, axis=1)

        for i, im_key in enumerate(image_keys):
            if im_key not in support_set_image_keys:
                # Group each non-support image under its nearest
                # support-set image key.
                ind = image_indexes[i, 0]
                image_dict.setdefault(support_set_image_keys[ind],
                                      []).append(im_key)

        already_used_im_keys = []
        pair_file = open(path.join(key_pair_file, subset + "_pairs.txt"),
                         'w')
        for sp_key in tqdm(speech_dict,
                           desc="Generating speech-image pairs",
                           ncols=COL_LENGTH):
            possible_im_keys = image_dict[speech_dict[sp_key]]

            for i in range(len(possible_im_keys)):
                possible_key = possible_im_keys[i]
                if possible_key not in already_used_im_keys:
                    pair_file.write(f'{sp_key}\t{possible_key}\n')
                    already_used_im_keys.append(possible_key)
                    image_dict[speech_dict[sp_key]].remove(possible_key)
                    break

        pair_file.close()
    print("End time: {}".format(datetime.datetime.now()))
Example #4
def main():

    lib = library_setup()

    num_episodes = 400
    K = lib["K"]
    M = lib["M"]
    Q = lib["Q"]

    # Speech data
    sp_x, sp_labels, sp_lengths, sp_keys = (
        data_library.load_speech_data_from_npz(lib["data_dir"]))
    max_frames = 100
    d_frame = 13
    print("\nLimiting dimensionality: {}".format(d_frame))
    print("Limiting number of frames: {}\n".format(max_frames))
    data_library.truncate_data_dim(sp_x, sp_lengths, d_frame, max_frames)

    if lib["data_type"] == "buckeye":
        digit_list = [
            "one", "two", "three", "four", "five", "six", "seven", "eight",
            "nine", "zero", "oh"
        ]
        speech_x = []
        speech_labels = []
        speech_lengths = []
        speech_keys = []
        for i, label in enumerate(sp_labels):
            if label not in digit_list:
                speech_x.append(sp_x[i])
                speech_labels.append(sp_labels[i])
                speech_lengths.append(sp_lengths[i])
                speech_keys.append(sp_keys[i])

    else:
        speech_x, speech_labels, speech_lengths, speech_keys = sp_x, sp_labels, sp_lengths, sp_keys

    if lib["data_type"] == "buckeye":
        speech_x, speech_keys, speech_labels, speech_lengths = filter_buckeye_set(
            speech_x, speech_keys, speech_labels, M, K, Q)
    episodes_fn = lib["data_fn"] + "_episodes.txt"
    print("Generating epsiodes...\n")

    file = open(episodes_fn, "w")

    for episode_counter in tqdm(range(1, num_episodes + 1)):

        support_set = few_shot_learning_library.construct_few_shot_support_set_with_keys(
            sp_x=speech_x,
            sp_labels=speech_labels,
            sp_keys=speech_keys,
            sp_lengths=speech_lengths,
            im_x=None,
            im_labels=None,
            im_keys=None,
            num_to_sample=M,
            num_of_each_sample=K)

        sp_keys_to_not_include = []
        labels_wanted = []
        for key in support_set:
            labels_wanted.append(key)
            for sp_key in support_set[key]["speech_keys"]:
                sp_keys_to_not_include.append(sp_key)

        query_dict = few_shot_learning_library.sample_multiple_keys(
            speech_x,
            speech_labels,
            speech_keys,
            speech_lengths,
            Q,
            exclude_key_list=sp_keys_to_not_include,
            labels_wanted=labels_wanted)

        for key in query_dict:
            for sp_key in query_dict[key]["keys"]:
                sp_keys_to_not_include.append(sp_key)

        matching_dict = few_shot_learning_library.sample_multiple_keys(
            speech_x,
            speech_labels,
            speech_keys,
            num_to_sample=len(labels_wanted),
            num_of_each_sample=1,
            exclude_key_list=sp_keys_to_not_include,
            labels_wanted=labels_wanted)

        file.write("Episode {}\n".format(episode_counter))

        file.write("{}\n".format("Support set:"))
        file.write("{}\n".format("Labels: "))
        for key in support_set:
            file.write("{}\n".format(key))

        file.write("{}\n".format("Keys: "))
        for key in support_set:
            for i in range(len(support_set[key]["speech_keys"])):
                file.write("{}".format(support_set[key]["speech_keys"][i]))
                if i == len(support_set[key]["speech_keys"]) - 1:
                    file.write("\n")
                else:
                    file.write(" ")

        file.write("{}\n".format("Query:"))
        file.write("{}\n".format("Labels: "))
        for key in query_dict:
            file.write("{}\n".format(key))
        file.write("{}\n".format("Keys: "))
        for key in query_dict:
            for i in range(len(query_dict[key]["keys"])):
                file.write("{}".format(query_dict[key]["keys"][i]))
                if i == len(query_dict[key]["keys"]) - 1:
                    file.write("\n")
                else:
                    file.write(" ")

        file.write("{}\n".format("Matching set:"))
        file.write("{}\n".format("Labels: "))
        for key in matching_dict:
            file.write("{}\n".format(key))
        file.write("{}\n".format("Keys: "))
        for key in matching_dict:
            for i in range(len(matching_dict[key]["keys"])):
                file.write("{}".format(matching_dict[key]["keys"][i]))
                if i == len(matching_dict[key]["keys"]) - 1:
                    file.write("\n")
                else:
                    file.write(" ")

        file.write("\n")

        if episode_check(support_set, query_dict, matching_dict):
            break

    file.close()

    print("Wrote epsiodes to {}".format(episodes_fn))
Example #5
#_____________________________________________________________________________________________________________________________________
#
# Data processing
#
#_____________________________________________________________________________________________________________________________________

print("\n" + "-" * PRINT_LENGTH)
print("Processing training data")
print("-" * PRINT_LENGTH)
image_train_x, image_train_labels, image_train_keys = (
    data_library.load_image_data_from_npz(base_lib["image_train_fn"]))
speech_train_x, speech_train_labels, speech_train_lengths, speech_train_keys = (
    data_library.load_speech_data_from_npz(base_lib["speech_train_fn"]))
data_library.truncate_data_dim(speech_train_x, speech_train_lengths,
                               base_lib["speech_input_dim"],
                               base_lib["max_frames"])
train_mask = data_library.get_mask(speech_train_x, speech_train_lengths, None)
print("\n" + "-" * PRINT_LENGTH)
print("Processing validation data")
print("-" * PRINT_LENGTH)
image_val_x, image_val_labels, image_val_keys = (
    data_library.load_image_data_from_npz(base_lib["image_val_fn"]))
speech_val_x, speech_val_labels, speech_val_lengths, speech_val_keys = (
    data_library.load_speech_data_from_npz(base_lib["speech_val_fn"]))
data_library.truncate_data_dim(speech_val_x, speech_val_lengths,
                               base_lib["speech_input_dim"],
                               base_lib["max_frames"])
val_mask = data_library.get_mask(speech_val_x, speech_val_lengths)
print("\n" + "-" * PRINT_LENGTH)
print("Processing test data")