Example #1
def batch_generator(descs, batch_size, train_vocab, word_ngrams, sort_ngrams, shuffle=False,
                    show_progress=True):
    global cache
    inds = np.arange(len(descs))
    rem_inds, batch_inds = next_batch(inds, batch_size, shuffle)

    if show_progress:
        progress_bar = tqdm(total=int(np.ceil(len(descs) / batch_size)))
    while len(batch_inds) > 0:
        batch_descs = [descs[i] for i in batch_inds]
        desc_hashes = [hash(str(desc)) for desc in batch_descs]
        batch = [
            cache[h] if h in cache else
            [0] + [train_vocab[phrase]["id"]
                   for phrase in get_all(desc, word_ngrams, sort_ngrams)
                   if phrase in train_vocab]
            for desc, h in zip(batch_descs, desc_hashes)
        ]

        # memoize the index lists so repeated descriptions are only converted once
        for h, desc_inds in zip(desc_hashes, batch):
            if h not in cache:
                cache[h] = desc_inds
        batch_weights = [[1 / len(i) for _ in range(len(i))] for i in batch]

        cur_lens = np.array([len(i) for i in batch])
        mx_len = max(cur_lens)
        to_pad = mx_len - cur_lens

        batch = [i + [0 for _ in range(pad)] for i, pad in zip(batch, to_pad)]
        batch_weights = [i + [0 for _ in range(pad)] for i, pad in zip(batch_weights, to_pad)]

        rem_inds, batch_inds = next_batch(rem_inds, batch_size, shuffle)
        if show_progress:
            progress_bar.update()
        yield batch, np.expand_dims(batch_weights, axis=2)

    if show_progress:
        progress_bar.close()
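
All of the examples on this page rely on two helpers, next_batch and get_all, that are defined elsewhere in the project and not shown here. The following is a minimal sketch reconstructed from how they are called; the exact signatures and the underscore-joined n-gram format are assumptions, not the project's own code.

import numpy as np


def next_batch(indices, batch_size, shuffle=False):
    """Split indices into (remaining_indices, batch_indices); an empty batch ends the loop."""
    indices = np.asarray(indices)
    if shuffle:
        indices = np.random.permutation(indices)
    return indices[batch_size:], indices[:batch_size]


def get_all(tokens, word_ngrams, sort_ngrams=False):
    """Yield every word n-gram (n = 1..word_ngrams) of a token list as a joined string."""
    for n in range(1, word_ngrams + 1):
        for start in range(len(tokens) - n + 1):
            ngram = tokens[start:start + n]
            if sort_ngrams:
                ngram = sorted(ngram)
            yield "_".join(ngram)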
Example #2
    def get_subwords(self, word):
        word = word.replace("_", " ")
        word_splitted = word.split()
        if len(word_splitted) > self.info["word_ngrams"]:
            return [], []
        else:
            subwords = [
                phrase
                for phrase in get_all(word_splitted, self.info["word_ngrams"],
                                      self.info["sort_ngrams"])
                if phrase in self.word_dict
            ]
            return subwords, [
                self.get_word_id(subword) for subword in subwords
            ]
Example #3
def batch_generator(descs,
                    labels,
                    batch_size,
                    train_vocab,
                    labels_lookup,
                    word_ngrams,
                    shuffle=False):
    global cache
    inds = np.arange(len(descs))
    rem_inds, batch_inds = next_batch(inds, batch_size, shuffle)

    while len(batch_inds) > 0:
        batch_descs = [descs[i] for i in batch_inds]
        desc_hashes = [hash(str(desc)) for desc in batch_descs]
        batch = [
            cache[h] if h in cache else
            [0] + [train_vocab[phrase]["id"]
                   for phrase in get_all(desc, word_ngrams)
                   if phrase in train_vocab]
            for desc, h in zip(batch_descs, desc_hashes)
        ]

        # memoize the index lists so repeated descriptions are only converted once
        for h, desc_inds in zip(desc_hashes, batch):
            if h not in cache:
                cache[h] = desc_inds
        batch_weights = [[1 / len(i) for _ in range(len(i))] for i in batch]
        batch_labels = [labels[i] for i in batch_inds]
        batch_labels = [labels_lookup[label] for label in batch_labels]

        cur_lens = np.array([len(i) for i in batch])
        mx_len = max(cur_lens)
        to_pad = mx_len - cur_lens

        batch = [i + [0 for _ in range(pad)] for i, pad in zip(batch, to_pad)]
        batch_weights = [
            i + [0 for _ in range(pad)]
            for i, pad in zip(batch_weights, to_pad)
        ]

        rem_inds, batch_inds = next_batch(rem_inds, batch_size, shuffle)
        yield batch, np.expand_dims(batch_weights, axis=2), batch_labels
Example #4
    def _batch_generator(self, list_of_texts, batch_size):
        """
        Generate batch from list of texts
        :param list_of_texts: list/array
        :param batch_size: int
        :return: batch word indices, batch word weights
        """
        if self.preprocessing_function:
            list_of_texts = [self.preprocessing_function(str(t)) for t in list_of_texts]
        else:
            list_of_texts = [str(t) for t in list_of_texts]
        inds = np.arange(len(list_of_texts))
        rem_inds, batch_inds = next_batch(inds, batch_size)

        while len(batch_inds) > 0:
            batch, batch_weights = [], []

            descs_words = [list(get_all(list_of_texts[ind].split(), self.info["word_ngrams"], self.info["sort_ngrams"]))
                           for ind in batch_inds]
            num_max_words = max([len(desc_split) for desc_split in descs_words]) + 1

            for desc_words in descs_words:
                init_test_inds = [0] + [self.train_vocab[phrase]["id"] for phrase in desc_words
                                        if phrase in self.train_vocab]

                test_desc_inds = init_test_inds + [0 for _ in range(num_max_words - len(init_test_inds))]
                test_desc_weights = np.zeros_like(test_desc_inds, dtype=float)
                test_desc_weights[:len(init_test_inds)] = 1. / len(init_test_inds)

                batch.append(test_desc_inds)
                batch_weights.append(test_desc_weights)
            rem_inds, batch_inds = next_batch(rem_inds, batch_size)
            batch_weights = np.expand_dims(batch_weights, 2)
            batch = np.array(batch)

            yield batch, batch_weights
Example #5
def main():
    main_start = time.time()
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_path",
                        type=str,
                        help="path to train file",
                        default="./train.txt")
    parser.add_argument("--validation_path",
                        type=str,
                        help="path to validation file",
                        default="")
    parser.add_argument("--label_prefix",
                        type=str,
                        help="label prefix",
                        default="__label__")
    parser.add_argument(
        "--min_word_count",
        type=int,
        default=1,
        help="discard words which appear less than this number")
    parser.add_argument(
        "--min_label_count",
        type=int,
        default=1,
        help="discard labels which appear less than this number")
    parser.add_argument("--dim",
                        type=int,
                        default=100,
                        help="length of embedding vector")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=5,
                        help="number of epochs")
    parser.add_argument("--word_ngrams",
                        type=int,
                        default=1,
                        help="word ngrams")
    parser.add_argument("--batch_size",
                        type=int,
                        default=1024,
                        help="batch size for train")
    parser.add_argument(
        "--batch_size_inference",
        type=int,
        default=1024,
        help=
        "batch size for inference, ignored if validation_path is not provided")
    parser.add_argument("--seed", type=int, default=17)
    parser.add_argument("--learning_rate",
                        type=float,
                        default=0.1,
                        help="learning rate")
    parser.add_argument("--learning_rate_multiplier",
                        type=float,
                        default=0.8,
                        help="learning rate multiplier after each epoch")
    parser.add_argument(
        "--data_fraction",
        type=float,
        default=1,
        help=
        "data fraction, if < 1, train (and validation) data will be randomly sampled"
    )
    parser.add_argument("--use_validation",
                        type=int,
                        default=0,
                        help="evaluate on validation data")
    parser.add_argument("--use_gpu",
                        type=int,
                        default=0,
                        help="use gpu for training")
    parser.add_argument("--gpu_fraction",
                        type=float,
                        default=0.5,
                        help="what fraction of gpu to allocate")
    parser.add_argument("--result_dir",
                        type=str,
                        help="result dir",
                        default="./results/")

    args = parser.parse_args()
    for bool_param in [args.use_validation, args.use_gpu]:
        assert bool_param in [0, 1]
    train_path = args.train_path
    validation_path = args.validation_path
    label_prefix = args.label_prefix
    min_word_count = args.min_word_count
    min_label_count = args.min_label_count
    emb_dim = args.dim
    n_epochs = args.n_epochs
    word_ngrams = args.word_ngrams
    batch_size = args.batch_size
    batch_size_inference = args.batch_size_inference
    initial_learning_rate = args.learning_rate
    learning_rate_multiplier = args.learning_rate_multiplier
    use_validation = bool(args.use_validation)
    seed = args.seed
    data_fraction = args.data_fraction
    use_gpu = bool(args.use_gpu)
    gpu_fraction = args.gpu_fraction
    result_dir = validate(args.result_dir)

    print('training with arguments:')
    print(args)
    print('\n')

    np.random.seed(seed)

    train_descs, train_labels, max_words = parse_txt(train_path,
                                                     return_max_len=True,
                                                     debug_till_row=-1,
                                                     fraction=data_fraction,
                                                     seed=seed,
                                                     label_prefix=label_prefix)

    model_params = {
        "word_ngrams": word_ngrams,
        "word_id_path": os.path.abspath(os.path.join(result_dir, "word_id.json")),
        "label_dict_path": os.path.abspath(os.path.join(result_dir, "label_dict.json"))
    }

    for child_dir in os.listdir(result_dir):
        dir_tmp = os.path.join(result_dir, child_dir)
        if os.path.isdir(dir_tmp):
            shutil.rmtree(dir_tmp)
        if dir_tmp.endswith(".pb"):
            os.remove(dir_tmp)

    max_words_with_ng = 1
    for ng in range(word_ngrams):
        max_words_with_ng += max_words - ng

    print("preparing dataset")
    print("total number of datapoints: {}".format(len(train_descs)))
    print("max number of words in description: {}".format(max_words))
    print("max number of words with n-grams in description: {}".format(
        max_words_with_ng))

    label_dict_path = os.path.join(result_dir, "label_dict.json")
    word_id_path = os.path.join(result_dir, "word_id.json")

    train_vocab = make_train_vocab(train_descs, word_ngrams)
    label_vocab = make_label_vocab(train_labels)

    if min_word_count > 1:
        tmp_cnt = 1
        train_vocab_thresholded = {}
        for k, v in sorted(train_vocab.items(), key=lambda t: t[0]):
            if v["cnt"] >= min_word_count:
                v["id"] = tmp_cnt
                train_vocab_thresholded[k] = v
                tmp_cnt += 1

        train_vocab = train_vocab_thresholded.copy()
        del train_vocab_thresholded

        print(
            "number of unique words and phrases after thresholding: {}".format(
                len(train_vocab)))

    print("\nnumber of labels in train: {}".format(len(set(
        label_vocab.keys()))))
    if min_label_count > 1:
        label_vocab_thresholded = {}
        tmp_cnt = 0
        for k, v in sorted(label_vocab.items(), key=lambda t: t[0]):
            if v["cnt"] >= min_label_count:
                v["id"] = tmp_cnt
                label_vocab_thresholded[k] = v
                tmp_cnt += 1

        label_vocab = label_vocab_thresholded.copy()
        del label_vocab_thresholded

        print("number of unique labels after thresholding: {}".format(
            len(label_vocab)))

    final_train_labels = set(label_vocab.keys())

    with open(label_dict_path, "w+") as outfile:
        json.dump(label_vocab, outfile)
    with open(word_id_path, "w+") as outfile:
        json.dump(train_vocab, outfile)
    with open(os.path.join(result_dir, "model_params.json"), "w+") as outfile:
        json.dump(model_params, outfile)

    num_words_in_train = len(train_vocab)
    num_labels = len(label_vocab)

    train_descs2, train_labels2 = [], []
    labels_lookup = {}

    labels_thrown, descs_thrown = 0, 0
    for train_desc, train_label in zip(tqdm(train_descs), train_labels):
        final_train_inds = [0] + [
            train_vocab[phrase]["id"]
            for phrase in get_all(train_desc, word_ngrams)
            if phrase in train_vocab
        ]
        if len(final_train_inds) == 1:
            descs_thrown += 1
            continue

        if train_label not in labels_lookup:
            if train_label in final_train_labels:
                labels_lookup[train_label] = construct_label(
                    label_vocab[train_label]["id"], num_labels)
            else:
                labels_thrown += 1
                continue

        train_labels2.append(train_label)
        train_descs2.append(train_desc)
    del train_descs, train_labels

    print("\n{} datapoints thrown because of empty description".format(
        descs_thrown))
    if min_label_count > 1:
        print("{} datapoints thrown because of label".format(labels_thrown))

    if use_validation:
        val_descs, val_labels, max_words_val = parse_txt(
            validation_path,
            return_max_len=True,
            label_prefix=label_prefix,
            seed=seed,
            fraction=data_fraction)
        max_words_with_ng_val = 1
        for ng in range(word_ngrams):
            max_words_with_ng_val += max_words_val - ng

        print("\ntotal number of val datapoints: {}".format(len(val_descs)))
        val_descs2, val_labels2 = [], []
        num_thrown_for_label = 0

        for val_desc, val_label in zip(val_descs, val_labels):
            if val_label not in labels_lookup:
                num_thrown_for_label += 1
                continue

            val_descs2.append(val_desc)
            val_labels2.append(val_label)

        val_labels_set = set(val_labels2)

        print("{} datapoints thrown because of label".format(
            num_thrown_for_label))
        print("number of val datapoints after cleaning: {}".format(
            len(val_descs2)))
        print("number of unique labels in val after cleaning: {}".format(
            len(val_labels_set)))
        initial_val_len = len(val_descs)
        del val_descs, val_labels

    if use_gpu:
        device = "/gpu:0"
        config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=gpu_fraction,
                allow_growth=True))
    else:
        device = "/cpu:0"
        config = tf.ConfigProto(allow_soft_placement=True)

    with tf.device(device):
        with tf.Session(config=config) as sess:
            input_ph = tf.placeholder(tf.int32,
                                      shape=[None, None],
                                      name="input")
            weights_ph = tf.placeholder(tf.float32,
                                        shape=[None, None, 1],
                                        name="input_weights")
            labels_ph = tf.placeholder(tf.float32,
                                       shape=[None, num_labels],
                                       name="label")
            learning_rate_ph = tf.placeholder_with_default(
                initial_learning_rate, shape=[], name="learning_rate")

            tf.set_random_seed(seed)

            with tf.name_scope("embeddings"):
                look_up_table = tf.Variable(tf.random_uniform(
                    [num_words_in_train + 1, emb_dim]),
                                            name="embedding_matrix")

            with tf.name_scope("mean_sentece_vector"):
                gath_vecs = tf.gather(look_up_table, input_ph)
                weights_broadcasted = tf.tile(weights_ph,
                                              tf.stack([1, 1, emb_dim]))
                mean_emb = tf.reduce_sum(tf.multiply(weights_broadcasted,
                                                     gath_vecs),
                                         axis=1,
                                         name="sentence_embedding")

            logits = tf.layers.dense(
                mean_emb,
                num_labels,
                use_bias=False,
                kernel_initializer=tf.truncated_normal_initializer(),
                name="logits")
            output = tf.nn.softmax(logits, name="prediction")
            # this is not used in the training, but will be used for inference

            correctly_predicted = tf.nn.in_top_k(logits,
                                                 tf.argmax(labels_ph, axis=1),
                                                 1,
                                                 name="top_1")

            ce_loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_ph,
                                                           logits=logits),
                name="ce_loss")

            train_op = tf.train.AdamOptimizer(learning_rate_ph).minimize(
                ce_loss)
            sess.run(tf.global_variables_initializer())

            train_start = time.time()

            for epoch in range(n_epochs):
                print("\nepoch {} started".format(epoch + 1))

                end_epoch_accuracy, end_epoch_accuracy_k = [], []
                moving_loss = []
                for batch, batch_weights, batch_labels in \
                        batch_generator(train_descs2, train_labels2, batch_size, train_vocab, labels_lookup,
                                        word_ngrams, shuffle=True):
                    _, correct, batch_loss = sess.run(
                        [train_op, correctly_predicted, ce_loss],
                        feed_dict={
                            input_ph: batch,
                            weights_ph: batch_weights,
                            labels_ph: batch_labels,
                            # feed the decayed learning rate; otherwise the
                            # placeholder default (the initial value) would be
                            # used for every epoch
                            learning_rate_ph: initial_learning_rate
                        })

                    end_epoch_accuracy.extend(correct)
                    moving_loss.append(batch_loss)

                print('\ncurrent learning rate: {}'.format(
                    round(initial_learning_rate, 7)))
                print("epoch {} ended".format(epoch + 1))
                print("epoch moving mean loss: {}".format(
                    round(np.mean(moving_loss), 3)))
                print("train moving average accuracy: {}".format(
                    percent(end_epoch_accuracy)))

                initial_learning_rate *= learning_rate_multiplier

                if use_validation:
                    end_epoch_accuracy, end_epoch_accuracy_k = [], []
                    end_epoch_loss = []

                    for batch, batch_weights, batch_labels in \
                            batch_generator(val_descs2, val_labels2, batch_size_inference, train_vocab, labels_lookup,
                                            word_ngrams):
                        correct, batch_loss = sess.run(
                            [correctly_predicted, ce_loss],
                            feed_dict={
                                input_ph: batch,
                                weights_ph: batch_weights,
                                labels_ph: batch_labels
                            })

                        end_epoch_accuracy.extend(correct)
                        end_epoch_loss.append(batch_loss)

                    mean_acc = np.round(
                        100 * np.sum(end_epoch_accuracy) / initial_val_len, 2)

                    print("end epoch mean val accuracy: {}".format(mean_acc))

            freeze_save_graph(sess, result_dir,
                              "model_ep{}.pb".format(epoch + 1), "prediction")
            print("the model is stored at {}".format(result_dir))
            print("the training took {} seconds".format(
                round(time.time() - train_start, 0)))
    print("all process took {} seconds".format(
        round(time.time() - main_start, 0)))
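
The construct_label helper called above is likewise not shown on this page. Given the [None, num_labels] label placeholder and the softmax cross-entropy loss, it is presumably a one-hot encoder along these lines (a sketch, not the original implementation):

import numpy as np


def construct_label(label_id, num_labels):
    """Return a one-hot float vector of length num_labels with a 1 at position label_id."""
    one_hot = np.zeros(num_labels, dtype=np.float32)
    one_hot[label_id] = 1.0
    return one_hot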
Example #6
    def _batch_generator(self, list_of_texts, batch_size, show_progress=False):
        """
        Generate batch from list of texts
        :param list_of_texts: list/array
        :param batch_size: int
        :param show_progress: bool, show progress bar
        :return: batch word indices, batch word weights
        """
        if self.preprocessing_function:
            list_of_texts = [
                self.preprocessing_function(str(text))
                for text in list_of_texts
            ]
        else:
            list_of_texts = [str(text) for text in list_of_texts]
        indices = np.arange(len(list_of_texts))
        remaining_indices, batch_indices = next_batch(indices, batch_size)

        if len(list_of_texts) <= batch_size:
            show_progress = False

        disable_progress_bar = not show_progress
        progress_bar = tqdm(total=int(np.ceil(len(list_of_texts) /
                                              batch_size)),
                            disable=disable_progress_bar)

        while len(batch_indices) > 0:
            batch, batch_weights = [], []

            batch_descriptions = [
                list(
                    get_all(list_of_texts[index].split(),
                            self.info["word_ngrams"],
                            self.info["sort_ngrams"]))
                for index in batch_indices
            ]
            num_max_words = max([
                len(batch_description)
                for batch_description in batch_descriptions
            ]) + 1

            for batch_description in batch_descriptions:
                initial_indices = [0] + [
                    self.word_dict[phrase]["id"]
                    for phrase in batch_description if phrase in self.word_dict
                ]

                description_indices = np.array(
                    initial_indices +
                    [0 for _ in range(num_max_words - len(initial_indices))])
                description_weights = np.zeros_like(description_indices,
                                                    dtype=np.float32)
                description_weights[:len(initial_indices)] = 1. / len(initial_indices)

                batch.append(description_indices)
                batch_weights.append(description_weights)
            remaining_indices, batch_indices = next_batch(
                remaining_indices, batch_size)

            progress_bar.update()
            yield batch, batch_weights

        progress_bar.close()
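
Note that, unlike the variant in Example 4, this generator yields the weight rows without the trailing axis. If the serving graph expects a [batch, words, 1] weights tensor (as in Examples 1, 3 and 4), the caller would add that axis, e.g.:

import numpy as np

# two padded weight rows as this generator would yield them (illustrative values only)
batch_weights = [np.array([0.5, 0.5, 0.0], dtype=np.float32),
                 np.array([1 / 3, 1 / 3, 1 / 3], dtype=np.float32)]
weights_3d = np.expand_dims(batch_weights, axis=2)
print(weights_3d.shape)  # (2, 3, 1)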
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-md", "--model_dir", type=str, help="path where model.pb and model_params.json are")
    parser.add_argument("-tp", "--test_path", type=str, help="path to test file")
    parser.add_argument("-lp", "--label_prefix", type=str, help="label prefix", default="__label__")
    parser.add_argument("-bs", "--batch_size", type=int, default=1024, help="batch size for inference")
    parser.add_argument("-k", "--top_k", type=int, default=1, help="calculate accuracy on top k predictions")
    parser.add_argument("-hc", "--hand_check", type=bool, default=False, help="test on manually inputted data")
    parser.add_argument("-gpu", "--use_gpu", type=bool, default=True, help="use gpu for inference")
    parser.add_argument("-gpu_fr", "--gpu_fraction", type=float, default=0.4, help="what fraction of gpu to allocate")
    args = parser.parse_args()

    model_dir = args.model_dir
    model_params_path = os.path.join(model_dir, "model_params.json")
    model_path = os.path.join(model_dir, "model_best.pb")
    label_prefix = args.label_prefix

    if args.use_gpu:
        device = "/gpu:0"
        config = tf.ConfigProto(allow_soft_placement=True,
                                gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction,
                                                          allow_growth=True))
    else:
        device = "/cpu:0"
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
        config = tf.ConfigProto(allow_soft_placement=True)

    num_thrown_for_label = 0
    with open(model_params_path, "r") as infile:
        model_params = json.load(infile)
    if os.path.isfile(model_params["label_dict_path"]):
        with open(model_params["label_dict_path"], "r") as infile:
            label_dict = json.load(infile)
    else:
        with open(os.path.join(model_dir, "label_dict.json"), "r") as infile:
            label_dict = json.load(infile)
    if os.path.isfile(model_params["word_dict_path"]):
        with open(model_params["word_dict_path"], "r") as infile:
            word_dict = json.load(infile)
    else:
        with open(os.path.join(model_dir, "word_dict.json"), "r") as infile:
            word_dict = json.load(infile)
    word_ngrams = model_params["word_ngrams"]
    sort_ngrams = model_params["sort_ngrams"]

    labels_dict_inverse = {}

    for label, label_info in label_dict.items():
        labels_dict_inverse[label_info["id"]] = label

    with tf.device(device):
        with tf.Session(config=config) as sess:
            run_arg = load_graph(model_path, ["input:0", "input_weights:0", "prediction:0"])
            if args.hand_check:
                while True:
                    query_description = input("Enter the description: ")
                    label_token = [token for token in query_description.split()
                                   if token.startswith(label_prefix)][0]
                    label = label_token[len(label_prefix):]
                    # remove the label token so only the description itself is scored
                    query_description = query_description.replace(label_token, "", 1).strip()
                    test_description_indices = \
                        np.expand_dims([0] + [word_dict[phrase]["id"] for phrase in
                                              get_all(query_description.split(), word_ngrams, sort_ngrams)
                                              if phrase in word_dict], axis=0)

                    test_desc_weights = np.zeros_like(test_description_indices, dtype=np.float32)
                    test_desc_weights[0][:len(test_description_indices[0])] = 1. / len(test_description_indices[0])

                    if label not in label_dict:
                        print("New label")
                        continue

                    probabilities = np.squeeze(sess.run(run_arg[-1], feed_dict={run_arg[0]: test_description_indices,
                                                                                run_arg[1]: test_desc_weights}))

                    max_index = np.argmax(probabilities)
                    max_prob = probabilities[max_index]
                    predicted_label = labels_dict_inverse[max_index]
                    print(predicted_label == label, predicted_label, max_prob)
            else:
                test_descriptions, test_labels = parse_txt(args.test_path)
                test_indices = np.arange(len(test_descriptions))
                print("The total number of test datapoints: {}".format(len(test_descriptions)))

                progress_bar = tqdm(total=int(np.ceil(len(test_descriptions) / args.batch_size)))
                remaining_indices, batch_indices = next_batch(test_indices, args.batch_size)
                accuracy_top_1, accuracy_top_k = 0, 0
                cnt = 0

                while len(batch_indices) > 0:
                    batch_descriptions = [test_descriptions[i] for i in batch_indices]
                    batch_labels = [test_labels[i] for i in batch_indices]

                    batch, batch_weights, batch_labels2 = [], [], []

                    max_words = -1
                    for test_description in batch_descriptions:
                        max_words = max(max_words, len(test_description.split()))

                    num_max_words = 1
                    for ng in range(word_ngrams):
                        num_max_words += max_words - ng

                    for test_description, test_label in zip(batch_descriptions, batch_labels):
                        if test_label not in label_dict:
                            num_thrown_for_label += 1
                            continue
                        initial_test_indices = [0] + [word_dict[phrase]["id"] for phrase in
                                                      get_all(test_description.split(), word_ngrams, sort_ngrams)
                                                      if phrase in word_dict]

                        cnt += 1
                        test_description_indices = \
                            np.array(initial_test_indices +
                                     [0 for _ in range(num_max_words - len(initial_test_indices))])
                        test_description_weights = np.zeros_like(test_description_indices, dtype=np.float32)
                        test_description_weights[:len(initial_test_indices)] = 1. / len(initial_test_indices)

                        batch.append(test_description_indices)
                        batch_weights.append(test_description_weights)
                        batch_labels2.append(label_dict[test_label]["id"])

                    probabilities = sess.run(run_arg[-1], feed_dict={run_arg[0]: batch, run_arg[1]: batch_weights})
                    top_k = [np.argsort(i)[-args.top_k:] for i in probabilities]

                    accuracy_top_k += sum(label_id in top for label_id, top in zip(batch_labels2, top_k))
                    accuracy_top_1 += sum(label_id == top[-1] for label_id, top in zip(batch_labels2, top_k))
                    remaining_indices, batch_indices = next_batch(remaining_indices, args.batch_size)
                    progress_bar.update()
                progress_bar.close()

                print("{} datapoint thrown because of label".format(num_thrown_for_label))
                print("Number of test datapoints after cleaning: {}".format(len(test_descriptions) -
                                                                            num_thrown_for_label))
                print("Number of unique labels in test after cleaning: {}".format(len(set(test_labels))))
                print("Accuracy: {}".format(round(100 * accuracy_top_1 / len(test_descriptions), 2)))
                print("Accuracy top {}: {}".format(args.top_k, round(100 * accuracy_top_k / len(test_descriptions), 2)))