Example #1
def analogies(eval_folder, embeddings, embeddings_file, dictionary, reverse_dictionary):
    """
    Evaluate embeddings with respect to analogies
    :param eval_folder: folder in which to write analogy results
    :param embeddings: embedding matrix to evaluate
    :param embeddings_file: file in which the embedding matrix is stored
    :param dictionary: [keys=statement, values=statement index]
    :param reverse_dictionary: [keys=statement index, values=statement]
    """
    # Create folder in which to write analogy results
    folder_analogies = os.path.join(eval_folder, "analogy")
    if not os.path.exists(folder_analogies):
        os.makedirs(folder_analogies)

    # Generate analogy "questions" and write them to a file
    analogy_questions_file = os.path.join(folder_analogies, "questions.txt")
    if not os.path.exists(analogy_questions_file):
        print('\n--- Generate analogy questions and write them to a file')
        analogygen.generate_analogy_questions(analogy_questions_file)

    # Load analogies
    analogy_questions_file_dump = os.path.join(folder_analogies, "questions")
    if not os.path.exists(analogy_questions_file_dump):

        # Read analogies from external file
        print('\n--- Read analogies from file ', analogy_questions_file)
        analogies, analogy_types, n_questions_total, n_questions_relevant = \
            load_analogy_questions(analogy_questions_file, dictionary)

        # Dump analogies into a file to be reused
        print('\n--- Writing analogies into file ', analogy_questions_file_dump)
        i2v_utils.safe_pickle([analogies, analogy_types, n_questions_total, n_questions_relevant],
                              analogy_questions_file_dump)

    else:

        # Load analogies from binary file
        print('\n--- Loading analogies from file ', analogy_questions_file_dump)
        with open(analogy_questions_file_dump, 'rb') as f:
            analogies, analogy_types, n_questions_total, n_questions_relevant = pickle.load(f)

    # Print info
    print('\tFound    {:>10,d} analogy-questions in total, '.format(n_questions_total))
    print('\tof which {:>10,d} are compatible with this vocabulary'.format(n_questions_relevant))

    # Evaluate
    summary = ''
    score_list = list()

    # Evaluate analogies in the embedding space
    analogy_eval_file = os.path.join(folder_analogies, 'res_' + embeddings_file[:-2].replace('/', '_') + '.txt')
    print('\n--- Starting analogy evaluation')

    # List of pairs (number of correctly answered questions in category, number of questions in category)
    scores = evaluate_analogies(embeddings, reverse_dictionary, analogies, analogy_types, analogy_eval_file)
    score_list.append(scores)
    summary += write_score_summary(scores, analogy_types, embeddings_file)

    # Print summary
    print(summary)
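
A minimal usage sketch for the function above. The paths are hypothetical, and it assumes that the dictionary pickle and the embeddings pickle produced by the vocabulary and training steps (see the later examples) already exist, together with the module-level imports and helpers the function relies on.

import os
import pickle

data_folder = "data"                                              # hypothetical
eval_folder = "embeddings/eval"                                   # hypothetical
embeddings_file = "embeddings/emb_cw_2_embeddings/emb_example.p"  # hypothetical

with open(embeddings_file, "rb") as f:
    embeddings = pickle.load(f)          # embeddings matrix, one row per statement index
with open(os.path.join(data_folder, "vocabulary", "dic_pickle"), "rb") as f:
    dictionary = pickle.load(f)          # statement -> index
reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))

analogies(eval_folder, embeddings, embeddings_file, dictionary, reverse_dictionary)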
Example #2
def load_analogies(data_folder):
    """
    Load analogy questions, generating and caching them on first use
    :param data_folder: string containing the path to the parent directory of data sub-folders
    :return: analogies, analogy_types, n_questions_total, n_questions_relevant
    """

    ####################################################################################################################
    # Generate analogy "questions" and write them to a file
    eval_folder = os.path.join(FLAGS.embeddings_folder, "eval")
    folder_analogies = os.path.join(eval_folder, "analogy")

    if not os.path.exists(folder_analogies):
        os.makedirs(folder_analogies)
    analogy_questions_file = os.path.join(folder_analogies, "questions.txt")
    if not os.path.exists(analogy_questions_file):
        print('\n--- Generating analogy questions and writing them to a file')
        analogygen.generate_analogy_questions(analogy_questions_file)

    ####################################################################################################################
    # Read analogy "questions" from file
    folder_vocabulary = os.path.join(data_folder, "vocabulary")
    dictionary_pickle = os.path.join(folder_vocabulary, 'dic_pickle')
    print('\tLoading dictionary from file', dictionary_pickle)
    with open(dictionary_pickle, 'rb') as f:
        dictionary = pickle.load(f)

    analogy_questions_file_dump = os.path.join(folder_analogies, "questions")
    if not os.path.exists(analogy_questions_file_dump):

        # Read analogies from external file
        print('\n--- Read analogies from file ', analogy_questions_file)
        analogies, analogy_types, n_questions_total, n_questions_relevant = \
            load_analogy_questions(analogy_questions_file, dictionary)

        # Dump analogies into a file to be reused
        print('\n--- Writing analogies into file ',
              analogy_questions_file_dump)
        i2v_utils.safe_pickle([
            analogies, analogy_types, n_questions_total, n_questions_relevant
        ], analogy_questions_file_dump)

    else:

        # Load analogies from binary file
        print('\n--- Loading analogies from file ',
              analogy_questions_file_dump)
        with open(analogy_questions_file_dump, 'rb') as f:
            analogies, analogy_types, n_questions_total, n_questions_relevant = pickle.load(
                f)

    # Print info
    print('\tFound    {:>10,d} analogy-questions, '.format(n_questions_total))
    print('\tof which {:>10,d} are compatible with this vocabulary'.format(
        n_questions_relevant))

    return analogies, analogy_types, n_questions_total, n_questions_relevant
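
A short sketch for inspecting the cached dump written by the function above. The path is hypothetical; the four-element layout matches the list pickled by the code.

import pickle

questions_dump = "embeddings/eval/analogy/questions"  # hypothetical path

with open(questions_dump, "rb") as f:
    analogies, analogy_types, n_questions_total, n_questions_relevant = pickle.load(f)

print("analogy categories:", len(analogy_types))
print("questions kept    : {:,d} of {:,d}".format(n_questions_relevant, n_questions_total))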
Example #3
def construct_vocabulary(data_folder, folders):
    """
    Construct vocabulary from XFGs and indexify the data set
    :param data_folder: string containing the path to the parent directory of data sub-folders
    :param folders: list of sub-folders containing pre-processed LLVM IR code

    Files produced for vocabulary:
        data_folder/vocabulary/cutoff_stmts_pickle
        data_folder/vocabulary/cutoff_stmts.csv
        data_folder/vocabulary/dic_pickle
        data_folder/vocabulary/dic.csv
        data_folder/vocabulary/vocabulary_metadata_for_tboard
        data_folder/vocabulary/vocabulary_statistics_class.txt
        data_folder/vocabulary/vocabulary_statistics_freq.txt
    Files produced for pair-building:
        data_folder/*_datasetprep_adjmat/
        data_folder/*_datasetprep_cw_X/file_H_dic_cw_X.p
    Files produced for indexification:
        data_folder/*_dataset_cw_X/data_pairs_cw_X.rec
    """

    # Get options and flags
    context_width = FLAGS.context_width
    cutoff_unknown = FLAGS.cutoff_unknown
    subsample_threshold = FLAGS.subsampling

    # Vocabulary folder
    folder_vocabulary = os.path.join(data_folder, 'vocabulary')
    if not os.path.exists(folder_vocabulary):
        os.makedirs(folder_vocabulary)

    ####################################################################################################################
    # Build vocabulary
    dictionary_csv = os.path.join(folder_vocabulary, 'dic.csv')
    dictionary_pickle = os.path.join(folder_vocabulary, 'dic_pickle')
    cutoff_stmts_pickle = os.path.join(folder_vocabulary, 'cutoff_stmts_pickle')
    if not os.path.exists(dictionary_csv):

        # Combine the source data lists
        print('\n--- Combining', len(folders), 'folders into one data set from which we build a vocabulary')
        source_data_list_combined = dict()  # keys: statements as strings, values: number of occurrences
        num_statements_total = 0

        for folder in folders:

            folder_preprocessed = folder + '_preprocessed'
            transformed_folder = os.path.join(folder_preprocessed, 'data_transformed')
            file_names_dict = get_file_names(folder)
            file_names = file_names_dict.values()
            num_files = len(file_names)
            count = 0

            for file_name in file_names:

                source = os.path.join(transformed_folder, file_name[:-3] + '.p')

                if os.path.exists(source):
                    with open(source, 'rb') as f:

                        # Load lists of statements
                        print('Fetching statements from file {:<60} ({:>2} / {:>2})'.format(
                            source, count, num_files))
                        source_data_list_ = pickle.load(f)

                        # Add to the accumulated list
                        source_data_list_combined = add_to_vocabulary(source_data_list_combined, source_data_list_)

                        # Get numbers
                        num_statements_in_file = len(source_data_list_)
                        num_statements_total += num_statements_in_file
                        print('\tRead        {:>10,d} statements in this file'.format(num_statements_in_file))
                        print('\tAccumulated {:>10,d} statements so far'.format(num_statements_total))
                        del source_data_list_
                        count += 1

        # Get statistics of the combined list before pruning
        print('\n--- Compute some statistics on the combined data')
        vocabulary_statistics(source_data_list_combined, descr="combining data folders")

        # Prune data
        source_data_list_combined, stmts_cut_off = prune_vocabulary(source_data_list_combined, cutoff_unknown)

        # Get statistics of the combined list after pruning
        print('\n--- Compute some statistics on the combined data')
        vocabulary_statistics(source_data_list_combined, descr="pruning combined data")

        # Build the vocabulary
        print('\n--- Building the vocabulary and indices')

        # Set the vocabulary size
        vocabulary_size = len(source_data_list_combined)

        # Build the data set, keeping the ordering from the original files; statement strings are translated to indices
        number_statements = sum(list(source_data_list_combined.values()))
        dictionary = build_dictionary(source_data_list_combined)

        # Print information about the vocabulary to console
        out = '\tAfter building indexed vocabulary:\n' \
              + '\t--- {:<26}: {:>8,d}\n'.format('Number of stmts', number_statements) \
              + '\t--- {:<26}: {:>8,d}\n'.format('Vocabulary size', vocabulary_size)
        print(out)

        # Print information about the vocabulary to file
        vocab_info_file = os.path.join(folder_vocabulary, 'vocabulary_statistics')
        print_vocabulary(source_data_list_combined, vocab_info_file)

        # Print dictionary
        print('Writing dictionary to file', dictionary_pickle)
        i2v_utils.safe_pickle(dictionary, dictionary_pickle)
        print('Writing dictionary to file', dictionary_csv)
        with open(dictionary_csv, 'w', newline='') as f:
            fieldnames = ['#statement', 'index']
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            data = [dict(zip(fieldnames, [k.replace('\n ', '\\n '), v])) for k, v in dictionary.items()]
            writer.writerows(data)

        # Print cut off statements
        print('Writing cut off statements to file', cutoff_stmts_pickle)
        i2v_utils.safe_pickle(stmts_cut_off, cutoff_stmts_pickle)
        cutoff_stmts_csv = os.path.join(folder_vocabulary, 'cutoff_stmts.csv')
        print('Writing cut off statements to file', cutoff_stmts_csv)
        with open(cutoff_stmts_csv, 'w', newline='\n') as f:
            for c in stmts_cut_off:
                f.write(c + '\n')
        del cutoff_stmts_csv

        # Print metadata file used by TensorBoard
        print('Building reverse dictionary...')
        reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
        vocab_metada_file = os.path.join(folder_vocabulary, 'vocabulary_metadata_for_tboard')
        print_vocabulary_metadata(reverse_dictionary, source_data_list_combined, vocab_metada_file)

        # Release variables that are no longer needed to reduce memory usage
        del source_data_list_combined

    ####################################################################################################################
    # Generate data-pair dictionaries

    # Load dictionary and cutoff statements
    print('\n--- Loading dictionary from file', dictionary_pickle)
    with open(dictionary_pickle, 'rb') as f:
        dictionary = pickle.load(f)
    print('Loading cut off statements from file', cutoff_stmts_pickle)
    with open(cutoff_stmts_pickle, 'rb') as f:
        stmts_cut_off = pickle.load(f)
    stmts_cut_off = set(stmts_cut_off)

    # Generate
    print('\n--- Generating data-pair dictionaries from dual graphs and dumping them to files')

    for folder in folders:

        folder_preprocessed = folder + '_preprocessed'
        folder_Dfiles = os.path.join(folder_preprocessed, 'xfg_dual')
        D_files_ = os.listdir(folder_Dfiles + '/')
        D_files = [Df for Df in D_files_ if Df[-2:] == '.p']
        num_D_files = len(D_files)
        folder_H = folder + '_datasetprep_cw_' + str(context_width)
        folder_mat = folder + '_datasetprep_adjmat'
        if not os.path.exists(folder_H):
            os.makedirs(folder_H)
        if not os.path.exists(folder_mat):
            os.makedirs(folder_mat)

        for i, D_file in enumerate(D_files):

            # "In-context" dictionary
            base_filename = D_file[:-2]
            D_file_open = os.path.join(folder_Dfiles, D_file)
            to_dump = os.path.join(folder_H, base_filename + "_H_dic_cw_" + str(context_width) + '.p')
            if not os.path.exists(to_dump):

                # Load dual graph
                print('Build H_dic from:', D_file_open, '(', i, '/', num_D_files, ')')
                with open(D_file_open, 'rb') as f:
                    D = pickle.load(f)

                # Build H-dictionary
                H_dic = build_H_dictionary(D, context_width, folder_mat, base_filename, dictionary, stmts_cut_off)
                print('Print to', to_dump)
                i2v_utils.safe_pickle(H_dic, to_dump)

            else:
                print('Found context-dictionary dump:', to_dump, '(', i, '/', num_D_files, ')')

    ####################################################################################################################
    # Generate data_pairs.rec from data pair dictionary dumps

    # Generate
    print('\n--- Writing .rec files')

    for folder in folders:

        # H dic dump files
        folder_H = folder + '_datasetprep_cw_' + str(context_width)
        H_files_ = os.listdir(folder_H + '/')
        H_files = [Hf for Hf in H_files_ if "_H_dic_cw_" + str(context_width) in Hf and Hf[-2:] == '.p']
        num_H_files = len(H_files)

        # Record files
        folder_REC = folder + '_dataset_cw_' + str(context_width)
        file_rec = os.path.join(folder_REC, 'data_pairs_cw_' + str(context_width) + '.rec')
        if not os.path.exists(folder_REC):
            os.makedirs(folder_REC)

        if not os.path.exists(file_rec):

            # Clear contents
            f = open(file_rec, 'wb')
            f.close()

            data_pairs_in_folder = 0
            for i, H_file in enumerate(H_files):

                dic_dump = os.path.join(folder_H, H_file)

                print('Building data pairs from file', dic_dump, '(', i, '/', num_H_files, ')')
                with open(dic_dump, 'rb') as f:
                    H_dic = pickle.load(f)

                # Get pairs [target, context] from graph and write them to file
                data_pairs = generate_data_pairs_from_H_dictionary(H_dic, subsample_threshold)
                data_pairs_in_folder += len(data_pairs)

                print('writing to fixed-length file: ', file_rec)

                # Start read and write
                counter = 0
                with open(file_rec, 'ab') as rec:

                    # Loop over pairs
                    num_pairs = len(data_pairs)
                    for p in data_pairs:

                        # Print progress every so often
                        if counter % 1000000 == 0 and counter != 0:
                            print('wrote pairs: {:>10,d} / {:>10,d} ...'.format(counter, num_pairs))

                        # Write and increment counter
                        assert int(p[0]) < 184, "Found index " + str(int(p[0]))
                        assert int(p[1]) < 184, "Found index " + str(int(p[1]))
                        rec.write(struct.pack('II', int(p[0]), int(p[1])))
                        counter += 1

            print('Pairs in folder', folder, ':', data_pairs_in_folder)

        else:

            filesize_bytes = os.path.getsize(file_rec)
            # Each pair is two 32-bit integers (8 bytes), so the number of pairs is filesize_bytes / 8
            file_pairs = int(filesize_bytes / 8)
            print('Found', file_rec, 'with #pairs:', file_pairs)
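
A minimal sketch of reading back one of the fixed-length .rec files written above. The path is hypothetical; each record is a (target, context) pair written with struct.pack('II', ...), i.e. two unsigned 32-bit integers per pair, 8 bytes.

import os
import struct

file_rec = "train_dataset_cw_2/data_pairs_cw_2.rec"  # hypothetical path

# 8 bytes per pair (two unsigned 32-bit integers), so #pairs = filesize / 8
n_pairs = os.path.getsize(file_rec) // 8

with open(file_rec, "rb") as f:
    pairs = list(struct.iter_unpack("II", f.read()))  # same 'II' format used for writing

assert len(pairs) == n_pairs
target, context = pairs[0]
print("pairs:", n_pairs, "first pair:", target, context)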
Example #4
def train_embeddings(data_folder, data_folders):
    """
    Main function for embedding training workflow
    :param data_folder: string containing the path to the parent directory of raw data sub-folders
    :param data_folders: list of sub-folders containing pre-processed LLVM IR code
    :return: embeddings matrix and the path of the file it is pickled to

    Folders produced:
        data_folder/FLAGS.embeddings_folder/emb_cw_X_embeddings
        data_folder/FLAGS.embeddings_folder/emb_cw_X_train
    """

    # Get flag values
    restore_tf_variables_from_ckpt = FLAGS.restore
    context_width = FLAGS.context_width
    outfolder = FLAGS.embeddings_folder
    param = {k: FLAGS[k].value for k in FLAGS}

    # Set file signature
    file_signature = i2v_utils.set_file_signature(param, data_folder)

    # Print model parameters
    out_ = '\n--- Data files: '
    print(out_)
    out = out_ + '\n'
    num_data_pairs = 0
    data_pair_files = get_data_pair_files(data_folders, context_width)
    for data_pair_file in data_pair_files:
        filesize_bytes = os.path.getsize(
            data_pair_file
        )  # each pair is two 32-bit integers, i.e. 8 bytes
        file_pairs = int(filesize_bytes / 8)
        num_data_pairs += file_pairs
        out_ = '\t{:<60}: {:>12,d} pairs'.format(data_pair_file, file_pairs)
        print(out_)
        out += out_ + '\n'

    out_ = '\t{:<60}: {:>12,d} pairs'.format('total', num_data_pairs)
    print(out_)
    out += out_ + '\n'

    # Get dictionary and vocabulary
    print('\n\tGetting dictionary ...')
    folder_vocabulary = os.path.join(data_folder, 'vocabulary')
    dictionary_pickle = os.path.join(folder_vocabulary, 'dic_pickle')
    with open(dictionary_pickle, 'rb') as f:
        dictionary = pickle.load(f)
    reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
    del dictionary
    vocabulary_size = len(reverse_dictionary.keys())

    # Print Skip-Gram model parameters
    out_ = '\n--- Skip Gram model parameters'
    print(out_)
    out += out_ + '\n'
    out_ = '\tData folder             : {:<}'.format(data_folder)
    print(out_)
    out += out_ + '\n'
    out_ = '\tNumber of data pairs    : {:>15,d}'.format(num_data_pairs)
    print(out_)
    out += out_ + '\n'
    out_ = '\tVocabulary size         : {:>15,d}'.format(vocabulary_size)
    print(out_)
    out += out_ + '\n'
    out_ = '\tEmbedding size          : {:>15,d}'.format(
        param['embedding_size'])
    print(out_)
    out += out_ + '\n'
    out_ = '\tContext width           : {:>15,d}'.format(
        param['context_width'])
    print(out_)
    out += out_ + '\n'
    out_ = '\tMini-batch size         : {:>15,d}'.format(
        param['mini_batch_size'])
    print(out_)
    out += out_ + '\n'
    out_ = '\tNegative samples in NCE : {:>15,d}'.format(param['num_sampled'])
    print(out_)
    out += out_ + '\n'
    out_ = '\tL2 regularization scale : {:>15,e}'.format(param['beta'])
    print(out_)
    out += out_ + '\n'
    out_ = '\tNumber of epochs        : {:>15,d}'.format(param['num_epochs'])
    print(out_)
    out += out_ + '\n'
    out_ = '\tRestoring a prev. train : {}'.format(
        restore_tf_variables_from_ckpt)
    print(out_)
    out += out_ + '\n'

    # Print training information to file
    log_dir_ = os.path.join(outfolder,
                            'emb_cw_' + str(context_width) + '_train/')
    log_dir = os.path.join(log_dir_, file_signature[1:])
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    train_info_file = os.path.join(log_dir, 'training_info.txt')
    with open(train_info_file, 'w') as f:
        f.write(out)

    # Validation set used to sample nearest neighbors
    # Limit to the words that have a low numeric ID,
    # which by construction are also the most frequent.
    valid_size = 30  # Random set of words to evaluate similarity on.
    valid_window = 50  # Only pick dev samples in the head of the distribution.
    valid_examples = np.random.choice(valid_window, valid_size, replace=False)

    # Copy metadata file into TensorBoard folder
    vocab_metada_file_ = os.path.join(folder_vocabulary,
                                      'vocabulary_metadata_for_tboard')
    v_metadata_file_name = 'vocab_metada_' + file_signature
    vocab_metada_file = os.path.join(log_dir, v_metadata_file_name)
    ckpt_saver_file = os.path.join(log_dir, "inst2vec.ckpt")
    ckpt_saver_file_init = os.path.join(log_dir, "inst2vec-init.ckpt")
    ckpt_saver_file_final = os.path.join(log_dir, "inst2vec-final.ckpt")
    os.makedirs(os.path.dirname(vocab_metada_file), exist_ok=True)
    subprocess.call('cp ' + vocab_metada_file_ + ' ' + vocab_metada_file,
                    shell=True)

    # Train the embeddings (Skip-Gram model)
    print('\n--- Setup completed, starting to train the embeddings')
    folder_embeddings = os.path.join(
        outfolder, 'emb_cw_' + str(context_width) + '_embeddings')
    if not os.path.exists(folder_embeddings):
        os.makedirs(folder_embeddings)
    embeddings_pickle = os.path.join(folder_embeddings,
                                     "emb_" + file_signature + ".p")
    embeddings = train_skip_gram(vocabulary_size, data_folder, data_folders,
                                 num_data_pairs, reverse_dictionary, param,
                                 valid_examples, log_dir, v_metadata_file_name,
                                 embeddings_pickle, ckpt_saver_file,
                                 ckpt_saver_file_init, ckpt_saver_file_final,
                                 restore_tf_variables_from_ckpt)

    # Save the embeddings and dictionaries in an external file to be reused later
    print('\n\tWriting embeddings to file', embeddings_pickle)
    i2v_utils.safe_pickle(embeddings, embeddings_pickle)

    # Write the embeddings to CSV file
    embeddings_csv = os.path.join(folder_embeddings,
                                  "emb_" + file_signature + ".csv")
    print('\t Writing embeddings to file ', embeddings_csv)
    np.savetxt(
        embeddings_csv,
        embeddings,
        delimiter=',',
        header=
        'Embeddings matrix, rows correspond to the embedding vector of statements'
    )

    return embeddings, embeddings_pickle
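
A short sketch of using the artifacts written above: map a statement string to its embedding row via the pickled dictionary. Both paths are hypothetical.

import os
import pickle

dictionary_pickle = os.path.join("data", "vocabulary", "dic_pickle")    # hypothetical
embeddings_pickle = "embeddings/emb_cw_2_embeddings/emb_example.p"      # hypothetical

with open(dictionary_pickle, "rb") as f:
    dictionary = pickle.load(f)        # statement string -> row index
with open(embeddings_pickle, "rb") as f:
    embeddings = pickle.load(f)        # embeddings matrix, one row per statement

stmt = next(iter(dictionary))          # any statement from the vocabulary
vector = embeddings[dictionary[stmt]]
print(stmt, "->", vector[:5])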
Example #5
def train_skip_gram(V, data_folder, data_folders, dataset_size,
                    reverse_dictionary, param, valid_examples, log_dir,
                    vocab_metada_file, embeddings_pickle, ckpt_saver_file,
                    ckpt_saver_file_init, ckpt_saver_file_final,
                    restore_variables):
    """
    Train embeddings (Skip-Gram model)
    :param V: vocabulary size
    :param data_folder: string containing the path to the parent directory of raw data sub-folders
    :param data_folders: list of sub-folders containing pre-processed LLVM IR code
    :param dataset_size: number of data pairs in total in the training data set
    :param reverse_dictionary: [keys=statement index, values=statement]
    :param param: parameters of the inst2vec training
    :param valid_examples: statements to be used as validation examples (list of indices)
    :param log_dir: logging directory for Tensorboard output
    :param vocab_metada_file: vocabulary metadata file for Tensorboard
    :param embeddings_pickle: file in which to pickle embeddings
    :param ckpt_saver_file: checkpoint saver file (intermediate states of training)
    :param ckpt_saver_file_init: checkpoint saver file (initial state of training)
    :param ckpt_saver_file_final: checkpoint saver file (final state of training)
    :param restore_variables: boolean: whether to restore variables from a previous training
    :return: embeddings matrix
    """
    ####################################################################################################################
    # Extract parameters from dictionary "param"
    N = param['embedding_size']
    mini_batch_size = param['mini_batch_size']
    num_sampled = param['num_sampled']
    num_epochs = param['num_epochs']
    learning_rate = param['learning_rate']
    l2_reg_scale = param['beta']
    freq_print_loss = param['freq_print_loss']
    step_print_neighbors = param['step_print_neighbors']
    context_width = param['context_width']

    ####################################################################################################################
    # Set up for analogies
    analogies, analogy_types, n_questions_total, n_questions_relevant = i2v_eval.load_analogies(
        data_folder)
    folder_evaluation = embeddings_pickle.replace('.p', '') + 'eval'
    if not os.path.exists(folder_evaluation):
        os.makedirs(folder_evaluation)
    analogy_evaluation_file = os.path.join(folder_evaluation,
                                           "analogy_results")

    config = None
    options = None
    metadata = None
    if FLAGS.profile:
        options = tf.compat.v1.RunOptions(
            trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
        metadata = tf.compat.v1.RunMetadata()
    if FLAGS.xla:
        config = tf.compat.v1.ConfigProto()
        config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1

    ####################################################################################################################
    # Read data using Tensorflow's data API
    data_files = get_data_pair_files(data_folders, context_width)
    print('\ttraining with data from files:', data_files)
    with tf.compat.v1.name_scope("Reader") as scope:

        random.shuffle(data_files)
        dataset_raw = tf.data.FixedLengthRecordDataset(
            filenames=data_files,
            record_bytes=8)  # <FixedLengthRecordDataset shapes: (), types: tf.string>
        dataset = dataset_raw.map(record_parser)
        dataset = dataset.shuffle(int(1e5))
        dataset_batched = dataset.batch(
            mini_batch_size, drop_remainder=True
        )  #apply(tf.contrib.data.batch_and_drop_remainder(mini_batch_size))
        dataset_batched = dataset_batched.prefetch(int(100000000))
        iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset_batched)
        saveable_iterator = tf.data.experimental.make_saveable_from_iterator(
            iterator)
        next_batch = iterator.get_next()  # one (target, context) pair per row: shape (mini_batch_size, 2)

    ####################################################################################################################
    # Tensorflow computational graph
    # Placeholders for inputs
    with tf.compat.v1.name_scope("Input_Data") as scope:
        train_inputs = next_batch[:, 0]
        train_labels = tf.reshape(next_batch[:, 1],
                                  shape=[mini_batch_size, 1],
                                  name="training_labels")

    # (input) Embedding matrix
    with tf.compat.v1.name_scope("Input_Layer") as scope:
        W_in = tf.Variable(tf.random.uniform([V, N], -1.0, 1.0),
                           name="input-embeddings")

        # Look up the vector representing each source word in the batch (fetches rows of the embedding matrix)
        h = tf.compat.v1.nn.embedding_lookup(params=W_in,
                                             ids=train_inputs,
                                             name="input_embedding_vectors")

    # Normalized embedding matrix
    with tf.compat.v1.name_scope("Embeddings_Normalized") as scope:
        normalized_embeddings = tf.nn.l2_normalize(
            W_in, name="embeddings_normalized")

    # (output) Embedding matrix ("output weights")
    with tf.compat.v1.name_scope("Output_Layer") as scope:
        if FLAGS.softmax:
            W_out = tf.Variable(tf.random.truncated_normal([N, V],
                                                           stddev=1.0 /
                                                           math.sqrt(N)),
                                name="output_embeddings")
        else:
            W_out = tf.Variable(tf.random.truncated_normal([V, N],
                                                           stddev=1.0 /
                                                           math.sqrt(N)),
                                name="output_embeddings")

        # Biases between hidden layer and output layer
        b_out = tf.Variable(tf.zeros([V]), name="nce_bias")

    # Optimization
    with tf.compat.v1.name_scope("Optimization_Block") as scope:
        # Loss function
        if FLAGS.softmax:
            logits = tf.compat.v1.layers.dense(inputs=h, units=V)
            # Flatten labels to shape [mini_batch_size] so the one-hot targets match the [mini_batch_size, V] logits
            onehot = tf.one_hot(tf.reshape(train_labels, [-1]), V)
            loss_tensor = tf.nn.softmax_cross_entropy_with_logits(
                labels=onehot, logits=logits)
        else:
            loss_tensor = tf.nn.nce_loss(weights=W_out,
                                         biases=b_out,
                                         labels=train_labels,
                                         inputs=h,
                                         num_sampled=num_sampled,
                                         num_classes=V)
        train_loss = tf.reduce_mean(input_tensor=loss_tensor, name="nce_loss")

        # Regularization (optional)
        if l2_reg_scale > 0:
            tf.compat.v1.add_to_collection(
                tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, W_in)
            tf.compat.v1.add_to_collection(
                tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, W_out)
            regularizer = tf.keras.regularizers.l2(0.5 * (l2_reg_scale))
            reg_variables = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
            reg_term = regularizer(
                reg_variables
            )  #tf.contrib.layers.apply_regularization(regularizer, reg_variables)
            loss = train_loss + reg_term
        else:
            loss = train_loss

        # Optimizer
        if FLAGS.optimizer == 'adam':
            optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate).minimize(loss)
        elif FLAGS.optimizer == 'nadam':
            optimizer = tf.keras.optimizers.Nadam(
                learning_rate=learning_rate).minimize(loss)
        elif FLAGS.optimizer == 'momentum':
            global_train_step = tf.Variable(0,
                                            trainable=False,
                                            dtype=tf.int32,
                                            name="global_step")
            # Passing global_step to minimize() will increment it at each step.
            optimizer = (tf.compat.v1.train.MomentumOptimizer(
                learning_rate, 0.9).minimize(loss,
                                             global_step=global_train_step))
        else:
            raise ValueError('Unrecognized optimizer ' + FLAGS.optimizer)

    if FLAGS.optimizer != 'momentum':
        global_train_step = tf.Variable(0,
                                        trainable=False,
                                        dtype=tf.int32,
                                        name="global_step")

    ####################################################################################################################
    # Validation block
    with tf.compat.v1.name_scope("Validation_Block") as scope:
        valid_dataset = tf.constant(valid_examples,
                                    dtype=tf.int32,
                                    name="validation_data_size")
        valid_embeddings = tf.compat.v1.nn.embedding_lookup(
            params=normalized_embeddings, ids=valid_dataset)
        cosine_similarity = tf.matmul(valid_embeddings,
                                      normalized_embeddings,
                                      transpose_b=True)

    ####################################################################################################################
    # Summaries
    with tf.compat.v1.name_scope("Summaries") as scope:
        tf.compat.v1.summary.histogram("input_embeddings", W_in)
        tf.compat.v1.summary.histogram("input_embeddings_normalized",
                                       normalized_embeddings)
        tf.compat.v1.summary.histogram("output_embeddings", W_out)
        tf.compat.v1.summary.scalar("nce_loss", loss)

        analogy_score_tensor = tf.Variable(0,
                                           trainable=False,
                                           dtype=tf.int32,
                                           name="analogy_score")
        tf.compat.v1.summary.scalar("analogy_score", analogy_score_tensor)

    ####################################################################################################################
    # Misc.
    restore_completed = False
    init = tf.compat.v1.global_variables_initializer()  # variables initializer
    summary_op = tf.compat.v1.summary.merge_all(
    )  # merge summaries into one operation

    ####################################################################################################################
    # Training
    with tf.compat.v1.Session(config=config) as sess:

        # Add TensorBoard components
        writer = tf.compat.v1.summary.FileWriter(
            log_dir)  # create summary writer
        writer.add_graph(sess.graph)
        gvars = [
            gvar for gvar in tf.compat.v1.global_variables()
            if 'analogy_score' not in gvar.name
        ]
        saver = tf.compat.v1.train.Saver(
            gvars, max_to_keep=5)  # create checkpoint saver
        config = projector.ProjectorConfig()  # create projector config
        embedding = config.embeddings.add()  # add embeddings visualizer
        embedding.tensor_name = W_in.name
        embedding.metadata_path = vocab_metada_file  # link metadata
        projector.visualize_embeddings(
            writer, config)  # add writer and config to projector

        # Set up variables
        if restore_variables:  # restore variables from disk
            restore_file = tf.train.latest_checkpoint(log_dir)
            assert restore_file is not None, "No restore file found in folder " + log_dir
            assert os.path.exists(restore_file + ".index"), \
                "Trying to restore Tensorflow session from non-existing file: " + restore_file + ".index"
            init.run()
            saver.restore(sess, restore_file)
            print("\tVariables restored from file", ckpt_saver_file,
                  "in TensorFlow ")

        else:  # save the computational graph to file and initialize variables

            graph_saver = tf.compat.v1.train.Saver(allow_empty=True)
            init.run()
            graph_saver.save(sess,
                             ckpt_saver_file_init,
                             global_step=0,
                             write_meta_graph=True)
            tf.compat.v1.add_to_collection(
                tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS, saveable_iterator)
            print("\tVariables initialized in TensorFlow")

        # Compute the necessary number of steps for this epoch as well as how often to print the avg loss
        num_steps = int(math.ceil(dataset_size / mini_batch_size))
        step_print_loss = int(math.ceil(num_steps / freq_print_loss))
        print('\tPrinting loss every ', step_print_loss, 'steps, i.e.',
              freq_print_loss, 'times per epoch')

        ################################################################################################################
        # Epoch loop
        epoch = 0
        global_step = 0
        while epoch < int(num_epochs):
            print('\n\tStarting epoch ', epoch)
            sess.run(iterator.initializer)  # initialize iterator

            # If restoring a previous training session, set the right training epoch
            if restore_variables and not restore_completed:
                epoch = int(
                    math.floor(global_train_step.eval() /
                               (dataset_size / mini_batch_size)))
                global_step = global_train_step.eval()
                print('Starting from epoch', epoch)

            ############################################################################################################
            # Loop over steps (mini batches) inside of epoch
            step = 0
            avg_loss = 0
            while True:

                try:

                    # Print average loss every x steps
                    if step_print_loss > 0 and step % int(
                            step_print_loss) == 0:  # update step with logging

                        # If restoring a previous training session, set the right training epoch
                        if restore_variables and not restore_completed:
                            restore_completed = True

                        # Write global step
                        if FLAGS.optimizer != 'momentum':
                            global_train_step.assign(global_step).eval()

                        # Perform an update
                        # print('\tStarting local step {:>6}'.format(step))  # un-comment for debugging
                        [_, loss_val, train_loss_val, global_step] = sess.run(
                            [optimizer, loss, train_loss, global_train_step],
                            options=options,
                            run_metadata=metadata)
                        assert not np.isnan(
                            loss_val), "Loss at step " + str(step) + " is nan"
                        assert not np.isinf(
                            loss_val), "Loss at step " + str(step) + " is inf"
                        avg_loss += loss_val

                        if step > 0:
                            avg_loss /= step_print_loss

                        analogy_score = i2v_eval.evaluate_analogies(
                            W_in.eval(),
                            reverse_dictionary,
                            analogies,
                            analogy_types,
                            analogy_evaluation_file,
                            session=sess,
                            print=i2v_eval.nop)
                        total_analogy_score = sum(
                            [a[0] for a in analogy_score])
                        analogy_score_tensor.assign(
                            total_analogy_score).eval()  # for tf.summary

                        [summary, W_in_val] = sess.run([summary_op, W_in])

                        if FLAGS.savebest is not None:
                            filelist = [f for f in os.listdir(FLAGS.savebest)]
                            scorelist = [
                                int(s.split('-')[1]) for s in filelist
                            ]
                            if len(scorelist
                                   ) == 0 or total_analogy_score > sorted(
                                       scorelist)[-1]:
                                i2v_utils.safe_pickle(
                                    W_in_val, FLAGS.savebest + '/' + 'score-' +
                                    str(total_analogy_score) + '-w.p')

                        # Display average loss
                        print(
                            '{} Avg. loss at epoch {:>6,d}, step {:>12,d} of {:>12,d}, global step {:>15} : {:>12.3f}, analogies: {}'
                            .format(str(datetime.now()), epoch, step,
                                    num_steps, global_step, avg_loss,
                                    str(analogy_score)))
                        avg_loss = 0

                        # Pickle intermediate embeddings
                        i2v_utils.safe_pickle(W_in_val, embeddings_pickle)

                        # Write to TensorBoard
                        saver.save(sess,
                                   ckpt_saver_file,
                                   global_step=global_step,
                                   write_meta_graph=False)
                        writer.add_summary(summary, global_step=global_step)

                        if FLAGS.profile:
                            fetched_timeline = timeline.Timeline(
                                metadata.step_stats)
                            chrome_trace = fetched_timeline.generate_chrome_trace_format(
                            )
                            with open('timeline_step_%d.json' % step,
                                      'w') as f:
                                f.write(chrome_trace)

                        if step > 0 and FLAGS.extreme:
                            sys.exit(22)

                    else:  # ordinary update step
                        [_, loss_val] = sess.run([optimizer, loss])
                        avg_loss += loss_val

                    # Compute and print nearest neighbors every x steps
                    if step_print_neighbors > 0 and step % int(
                            step_print_neighbors) == 0:
                        print_neighbors(op=cosine_similarity,
                                        examples=valid_examples,
                                        top_k=6,
                                        reverse_dictionary=reverse_dictionary)

                    # Update loop index (steps in epoch)
                    step += 1
                    global_step += 1

                except tf.errors.OutOfRangeError:

                    # We reached the end of the epoch
                    print('\n\t Writing embeddings to file ',
                          embeddings_pickle)
                    i2v_utils.safe_pickle([W_in.eval()],
                                          embeddings_pickle)  # NB: pickled as a one-element list; the caller re-pickles the bare matrix
                    epoch += 1  # update loop index (epochs)
                    break  # from this inner loop

        ################################################################################################################
        # End of training:
        # Print the nearest neighbors at the end of the run
        if step_print_neighbors == -1:
            print_neighbors(op=cosine_similarity,
                            examples=valid_examples,
                            top_k=6,
                            reverse_dictionary=reverse_dictionary)

        # Save state of training and close the TensorBoard summary writer
        save_path = saver.save(sess, ckpt_saver_file_final, global_step)
        writer.add_summary(summary, global_step)
        writer.close()

        return W_in.eval()
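
The record_parser mapped over the FixedLengthRecordDataset in the Reader block is not shown in this example. A plausible sketch, assuming each 8-byte record holds a (target, context) pair written with struct.pack('II', ...), is given below; the actual implementation in the repository may differ.

import tensorflow as tf

def record_parser(raw_record):
    """Decode one 8-byte record into a shape-(2,) int32 tensor: [target, context]."""
    pair = tf.io.decode_raw(raw_record, tf.int32)  # int32 so the indices can feed embedding_lookup
    return tf.reshape(pair, [2])

# Usage (hypothetical file name), mirroring the Reader block above
dataset = tf.data.FixedLengthRecordDataset(filenames=["data_pairs_cw_2.rec"],
                                           record_bytes=8).map(record_parser)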