Example #1
def plot_to_projector(
    x,
    feature_vector,
    y,
    class_names,
    log_dir="default_log_dir",
    meta_file="metadata.tsv",
):
    assert x.ndim == 4  # (BATCH, H, W, C)

    if os.path.isdir(log_dir):
        shutil.rmtree(log_dir)

    # Create a new clean fresh folder :)
    os.mkdir(log_dir)

    SPRITES_FILE = os.path.join(log_dir, "sprites.png")
    sprite = create_sprite(x)
    cv2.imwrite(SPRITES_FILE, sprite)

    # Generate label names
    labels = [class_names[y[i]] for i in range(int(y.shape[0]))]

    with open(os.path.join(log_dir, meta_file), "w") as f:
        for label in labels:
            f.write("{}\n".format(label))

    if feature_vector.ndim != 2:
        print("NOTE: Feature vector is not of shape (BATCH, FEATURES);"
              " reshaping to try to get it into this form!")
        feature_vector = tf.reshape(feature_vector,
                                    [feature_vector.shape[0], -1])

    # Save the features as a checkpointed variable
    feature_vector = tf.Variable(feature_vector)
    checkpoint = tf.train.Checkpoint(embedding=feature_vector)
    checkpoint.save(os.path.join(log_dir, "embeddings.ckpt"))

    # Set up config
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"
    embedding.metadata_path = meta_file
    embedding.sprite.image_path = "sprites.png"
    embedding.sprite.single_image_dim.extend((x.shape[1], x.shape[2]))
    projector.visualize_embeddings(log_dir, config)
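A minimal, hypothetical invocation of the function above (it assumes `os`, `shutil`, `cv2`, `tensorflow as tf`, the projector plugin, and a `create_sprite` helper that tiles the batch into one sprite image; every shape and name below is made up):

import numpy as np

x = np.random.rand(100, 28, 28, 3).astype(np.float32)   # (BATCH, H, W, C)
features = np.random.rand(100, 64).astype(np.float32)   # (BATCH, FEATURES)
y = np.random.randint(0, 10, size=100)
class_names = ["class_{}".format(i) for i in range(10)]

plot_to_projector(x, features, y, class_names, log_dir="logs/projector")
# then: tensorboard --logdir logs/projector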
Example #2
def plot_with_tensorboard(
    logdir: Path,
    sentences: List[str],
    clusters: Union[List, np.ndarray],
    embeddings: Dict[str, np.ndarray],
):
    """
    Saves Tensorboard embeddings to projector.
    
    Args:
        logdir (Path): Directory where to save.
        sentences (List[str]): Sentences for metadata.
        embeddings (Dict[str, np.ndarray]): Embeddings to plot.
    """
    # Set up a logs directory so TensorBoard knows where to look for files
    logdir.mkdir(exist_ok=True, parents=True)

    # Save labels separately in a line-by-line manner.
    with open(str(logdir / 'metadata.tsv'), "w") as f:
        f.write('Index\tSentence\tCluster\n')
        for i, (sentence, cluster) in enumerate(zip(sentences, clusters)):
            f.write(f"{i}\t{sentence}\t{cluster}\n")

    # Set up config
    config = projector.ProjectorConfig()
    variables = {}
    for name, embedding in embeddings.items():
        # Save the weights we want to analyse as a tf.Variable; the variable's
        # key in the checkpoint below becomes part of the tensor name.
        variable = tf.Variable(embedding, name=name)
        variables[name] = variable

        embedding = config.embeddings.add()
        # The name of the tensor will be suffixed by `/.ATTRIBUTES/VARIABLE_VALUE`
        embedding.tensor_name = f"{name}/.ATTRIBUTES/VARIABLE_VALUE"
        embedding.metadata_path = "metadata.tsv"

    checkpoint = tf.train.Checkpoint(**variables)
    checkpoint.save(str(logdir / "embeddings.ckpt"))
    projector.visualize_embeddings(str(logdir), config)
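A hypothetical call; the sentences, cluster ids, and the embedding key and shape below are invented for illustration:

from pathlib import Path
import numpy as np

sentences = ["a cat sat", "dogs bark", "stocks fell"]
clusters = [0, 0, 1]
embeddings = {"encoder_output": np.random.rand(3, 384).astype(np.float32)}

plot_with_tensorboard(Path("logs/sentences"), sentences, clusters, embeddings)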
Example #3
    def export_movies_embeddings(self, path: pathlib.Path, movie_to_title: dict):
        # Create folders if they don't exist
        if not path.exists():
            path.mkdir(parents=True)
        # Create metadata for the embedding projector
        with open(path / 'metadata.tsv', "w") as f:
            f.writelines([f'{title}\n' for title in movie_to_title.values()])
        # Create a checkpoint storing a variable with embeddings matching the metadata
        indexes = self.movie_to_index[tf.constant(list(movie_to_title.keys()))]
        embeddings = tf.gather(self.movies_embeddings, indices=indexes, axis=0)
        checkpoint = tf.train.Checkpoint(embeddings=tf.Variable(embeddings))
        checkpoint.save(str(path / "embeddings.ckpt"))
        # Set up projector config
        config = projector.ProjectorConfig()
        embedding = config.embeddings.add()
        # The name of the tensor will be suffixed by `/.ATTRIBUTES/VARIABLE_VALUE`
        embedding.tensor_name = "embeddings/.ATTRIBUTES/VARIABLE_VALUE"
        embedding.metadata_path = 'metadata.tsv'
        projector.visualize_embeddings(str(path), config)
Example #4
def main():
    model = tf.keras.models.load_model(MODEL_PATH)
    # with open(os.path.join(LOG_DIR, 'metadata.tsv'), "w") as f:
    #     for subwords in encoder.subwords:
    #         f.write("{}\n".format(subwords))
    #     # Fill in the rest of the labels with "unknown"
    #     for unknown in range(1, encoder.vocab_size - len(encoder.subwords)):
    #         f.write("unknown #{}\n".format(unknown))
    weights = tf.Variable(model.embed.get_weights()[0][1:])
    # Create a checkpoint from the embedding; the filename and key become
    # part of the tensor name.
    checkpoint = tf.train.Checkpoint(embedding=weights)
    checkpoint.save(os.path.join(LOG_DIR, "embedding.ckpt"))
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    # The name of the tensor will be suffixed by `/.ATTRIBUTES/VARIABLE_VALUE`
    embedding.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"
    # embedding.metadata_path = 'metadata.tsv'
    projector.visualize_embeddings(LOG_DIR, config)
Example #5
  def testVisualizeEmbeddings(self):
    # Create a dummy configuration.
    config = projector.ProjectorConfig()
    config.model_checkpoint_path = 'test'
    emb1 = config.embeddings.add()
    emb1.tensor_name = 'tensor1'
    emb1.metadata_path = 'metadata1'

    # Call the API method to save the configuration to a temporary dir.
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir)
    writer = tf.summary.FileWriter(temp_dir)
    projector.visualize_embeddings(writer, config)

    # Read the configurations from disk and make sure it matches the original.
    with tf.gfile.GFile(os.path.join(temp_dir, 'projector_config.pbtxt')) as f:
      config2 = projector.ProjectorConfig()
      text_format.Parse(f.read(), config2)
      self.assertEqual(config, config2)
Example #6
    def create_config(self, with_sprite=True):
        """
        Create a congfig files that defines image tensor name, path to metadata file, path to the sprite image,
        and the size of individual image whithin the sprite image.

        Parameters
        ----------
        with_sprite : bool, optional
            If to save sprite or not, by default True
        """
        config = projector.ProjectorConfig()
        embedding = config.embeddings.add()
        embedding.tensor_name = f'{self.data_name}/.ATTRIBUTES/VARIABLE_VALUE'
        embedding.metadata_path = f'metadata_{self.data_name}.tsv'
        if with_sprite:
            embedding.sprite.image_path = f'sprite_{self.data_name}.png'
            embedding.sprite.single_image_dim.extend(
                [self.image_width, self.image_height])
        projector.visualize_embeddings(self.log_dir, config)
Example #7
    def _configure_embeddings(self):
        """Configure the Projector for embeddings.
        Implementation from tensorflow codebase, but supports multiple models
        """
        try:
            # noinspection PyPackageRequirements
            from tensorboard.plugins import projector
        except ImportError:
            raise ImportError(
                'Failed to import TensorBoard. Please make sure that '
                'TensorBoard integration is complete.')
        config = projector.ProjectorConfig()
        for model_name, model in self.network.model.items():
            for layer in model.layers:
                if isinstance(layer, embeddings.Embedding):
                    embedding = config.embeddings.add()
                    embedding.tensor_name = layer.embeddings.name

                    if self.embeddings_metadata is not None:
                        if isinstance(self.embeddings_metadata, str):
                            embedding.metadata_path = self.embeddings_metadata
                        else:
                            if layer.name in self.embeddings_metadata:
                                embedding.metadata_path = self.embeddings_metadata.pop(
                                    layer.name)

        if self.embeddings_metadata:
            raise ValueError(
                'Unrecognized `Embedding` layer names passed to '
                '`keras.callbacks.TensorBoard` `embeddings_metadata` '
                'argument: ' + str(self.embeddings_metadata))

        class DummyWriter(object):
            """Dummy writer to conform to `Projector` API."""
            def __init__(self, logdir):
                self.logdir = logdir

            def get_logdir(self):
                return self.logdir

        writer = DummyWriter(self.train_log_dir)
        projector.visualize_embeddings(writer, config)
Example #8
def create_config_file(path_logs, tensor_filename, features,
                       image_sprite_filename, metadata_filename,
                       image_sprite_size):

    with open(tensor_filename, 'w') as fw:
        csv_writer = csv.writer(fw, delimiter='\t')
        csv_writer.writerows(features)

    config = projector.ProjectorConfig()
    # One can add multiple embeddings.
    embedding = config.embeddings.add()
    embedding.tensor_path = tensor_filename
    # Link this tensor to its metadata file (e.g. labels).
    embedding.metadata_path = metadata_filename
    # Comment out if you don't want sprites
    embedding.sprite.image_path = image_sprite_filename
    embedding.sprite.single_image_dim.extend(
        [image_sprite_size[0], image_sprite_size[1]])
    # Saves a config file that TensorBoard will read during startup.
    projector.visualize_embeddings(path_logs, config)
Example #9
    def visualize_data_frame(self, features: embeddings.LabelledFeatures) -> None:

        print("Exporting tensorboard logs to: {}".format(self._output_path))

        features = _sample_if_needed(features)

        path_metadata = self._resolved_path(FILENAME_METADATA)
        path_features = self._resolved_path(FILENAME_FEATURES)

        _write_labels(features.labels, path_metadata)

        _save_embedding_as_checkpoint(
            self._maybe_project(features.features), path_features
        )

        projector_config = _create_projector_config(
            path_metadata, self._maybe_create_sprite(features.image_paths)
        )

        projector.visualize_embeddings(self._output_path, projector_config)
Example #10
def write_embedding(log_dir):
    """Writes embedding data and projector configuration to the logdir."""
    metadata_filename = "metadata.tsv"
    tensor_filename = "tensor.tsv"

    labels = ANIMALS.strip().splitlines()
    labels_to_tensors = {label: tensor_for_label(label) for label in labels}
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, metadata_filename), "w") as f:
        for label in labels_to_tensors:
            f.write("{}\n".format(label))
    with open(os.path.join(log_dir, tensor_filename), "w") as f:
        for tensor in labels_to_tensors.values():
            f.write("{}\n".format("\t".join(str(x) for x in tensor)))

    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.metadata_path = metadata_filename
    embedding.tensor_path = tensor_filename
    projector.visualize_embeddings(log_dir, config)
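The snippet above relies on an `ANIMALS` string and a `tensor_for_label` helper that are not shown; a minimal, entirely hypothetical stand-in could be:

import numpy as np

ANIMALS = """
cat
dog
horse
owl
"""

def tensor_for_label(label):
    # Pseudo-embedding per label, standing in for real feature vectors
    rng = np.random.default_rng(abs(hash(label)) % (2 ** 32))
    return rng.random(16)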
Example #11
    def setup_sample_training_data(self, log_dir, writer):

        # get some samples
        segs = [self.datasets.train.sample_segment() for _ in range(1000)]
        sample_X = []
        sample_y = []
        for segment in segs:
            data = self.datasets.train.fetch_segment(segment, augment=False)
            sample_X.append(data)
            sample_y.append(self.labels.index(segment.label))

        X = np.asarray(sample_X, dtype=np.float32)
        y = np.asarray(sample_y, dtype=np.int32)

        data = (X, y, segs)
        sprite_path = os.path.join(log_dir, "examples.png")
        meta_path = os.path.join(log_dir, "examples.tsv")

        config = projector.ProjectorConfig()
        for var_name in ["sample_logits", "sample_hidden"]:
            embedding = config.embeddings.add()
            embedding.tensor_name = var_name
            embedding.metadata_path = meta_path
            embedding.sprite.image_path = sprite_path
            embedding.sprite.single_image_dim.extend([48, 48])

        projector.visualize_embeddings(writer, config)

        # save tsv file containing labels
        with open(meta_path, "w") as f:
            f.write("Index\tLabel\tSource\n")
            for index, segment in enumerate(segs):
                f.write("{}\t{}\t{}\n".format(index, segment.label,
                                              segment.clip_id))

        # save out image previews
        to_vis = X[:, self.training_segment_frames // 2, 0]
        sprite_image = self._create_sprite_image(to_vis)
        plt.imsave(sprite_path, sprite_image, cmap="gray")

        return data
Example #12
def save(df: pd.DataFrame, data_out: str):
    """Saves artifacts for tensorboard embeddings projector"""
    df.to_csv(f'{data_out}metadata.tsv',
              sep='\t',
              columns=['label', 'title'],
              index=False,
              header=True)

    embeddings = df['embedding']
    embeddings = embeddings.tolist()
    embeddings = np.array(embeddings)

    embeddings_tensor = tf.Variable(embeddings, name='embedding')
    checkpoint = tf.train.Checkpoint(embedding=embeddings_tensor)
    checkpoint.save(os.path.join(data_out, 'embedding.ckpt'))

    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"
    embedding.metadata_path = 'metadata.tsv'
    projector.visualize_embeddings(data_out, config)
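A hypothetical DataFrame matching the columns this function expects ('label', 'title', 'embedding'):

import os
import numpy as np
import pandas as pd

os.makedirs("logs", exist_ok=True)
df = pd.DataFrame({
    "label": [0, 1],
    "title": ["first doc", "second doc"],
    "embedding": [np.random.rand(8).tolist(), np.random.rand(8).tolist()],
})
save(df, "logs")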
Example #13
    def _configure_embeddings(self):
        """Configure the Projector for embeddings."""
        # TODO(omalleyt): Add integration tests.
        from tensorflow.python.keras.layers import embeddings
        try:
            from tensorboard.plugins import projector
        except ImportError:
            raise ImportError(
                'Failed to import TensorBoard. Please make sure that '
                'TensorBoard integration is complete.')
        config = projector.ProjectorConfig()
        for layer in self.model.layers:
            if isinstance(layer, embeddings.Embedding):
                embedding = config.embeddings.add()
                embedding.tensor_name = layer.embeddings.name

                if self.embeddings_metadata is not None:
                    if isinstance(self.embeddings_metadata, str):
                        embedding.metadata_path = self.embeddings_metadata
                    else:
                        if layer.name in self.embeddings_metadata:
                            embedding.metadata_path = self.embeddings_metadata.pop(
                                layer.name)

        if self.embeddings_metadata:
            raise ValueError(
                'Unrecognized `Embedding` layer names passed to '
                '`keras.callbacks.TensorBoard` `embeddings_metadata` '
                'argument: ' + str(self.embeddings_metadata.keys()))

        class DummyWriter(object):
            """Dummy writer to conform to `Projector` API."""
            def __init__(self, logdir):
                self.logdir = logdir

            def get_logdir(self):
                return self.logdir

        writer = DummyWriter(self.log_dir)
        projector.visualize_embeddings(writer, config)
Example #14
def visualize(model, output_path):
    meta_file = "test/w2x_metadata.tsv"
    placeholder = np.zeros((len(model.wv.index2word), 50))

    with open(os.path.join(output_path, meta_file), 'wb') as file_metadata:
        for i, word in enumerate(model.wv.index2word):
            placeholder[i] = model[word]
            # temporary solution for https://github.com/tensorflow/tensorflow/issues/9094
            if word == '':
                print(
                    "Empty line; replacing it with a placeholder, otherwise "
                    "it causes a TensorBoard bug"
                )
                file_metadata.write(
                    "{0}".format('<Empty Line>').encode('utf-8') + b'\n')
            else:
                file_metadata.write("{0}".format(word).encode('utf-8') + b'\n')

    # define the model without training; eager execution must be disabled
    # before the TF1-style session/graph API below is used
    tf.disable_eager_execution()
    sess = tf.InteractiveSession()

    embedding = tf.Variable(placeholder, trainable=False, name='w2x_metadata')
    tf.global_variables_initializer().run()
    saver = tf.train.Saver()
    writer = tf.summary.FileWriter(output_path, sess.graph)

    # adding into projector
    config = projector.ProjectorConfig()
    embed = config.embeddings.add()
    embed.tensor_name = 'w2x_metadata'
    embed.metadata_path = meta_file

    projector.visualize_embeddings(writer, config)
    saver.save(sess=sess,
               save_path=os.path.join(output_path, 'w2x_metadata.ckpt'))
    print(
        'Run `tensorboard --logdir={0}` to visualize the result in TensorBoard'
        .format(output_path))
Example #15
def project(word_vecs, id2word):
    log_dir = 'logs'
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    with open(os.path.join(log_dir, 'metadata.tsv'), "w") as f:
        f.write("{}\t{}\n".format("index", "word"))
        for i, word in id2word.items():
            f.write("{}\t{}\n".format(i, word))

    weights = tf.Variable(word_vecs)
    checkpoint = tf.train.Checkpoint(embedding=weights)
    checkpoint.save(os.path.join(log_dir, "embedding.ckpt"))

    # Set up config
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()

    # The name of the tensor will be suffixed by `/.ATTRIBUTES/VARIABLE_VALUE`
    embedding.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"
    embedding.metadata_path = 'metadata.tsv'
    projector.visualize_embeddings(log_dir, config)
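Hypothetical inputs for `project` (real word vectors would come from a trained model):

import numpy as np

id2word = {i: w for i, w in enumerate(["the", "cat", "sat", "on", "mat"])}
word_vecs = np.random.rand(len(id2word), 50).astype(np.float32)
project(word_vecs, id2word)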
Example #16
def visualize_embeddings(images,
                         embeddings,
                         output_dir,
                         thumbnail_size=(32, 32)):
    """

  Args:
    images:
    embeddings:
    output_dir:
    thumbnail_size:

  Returns:

  """
    summary_writer = tf.summary.FileWriter(output_dir)

    sprite_path = os.path.abspath(os.path.join(output_dir, 'sprite.png'))
    metadata_path = os.path.abspath(os.path.join(output_dir, 'metadata.csv'))
    embeddings_path = os.path.join(output_dir, 'embeddings.ckpt')

    embedding_var = tf.Variable(embeddings, name='embeddings')
    sprite = images_to_sprite(images)
    imsave(os.path.join(output_dir, 'sprite.png'), sprite)

    with tf.Session() as sess:
        sess.run(embedding_var.initializer)
        config = projector.ProjectorConfig()

        embedding = config.embeddings.add()
        embedding.tensor_name = embedding_var.name
        embedding.metadata_path = metadata_path
        embedding.sprite.image_path = sprite_path
        embedding.sprite.single_image_dim.extend(thumbnail_size)

        projector.visualize_embeddings(summary_writer, config)
        saver = tf.train.Saver([embedding_var])
        saver.save(sess, embeddings_path, 1)
Example #17
    def write_embeddings(self, Wv, name="WordVectors"):
        """Write embedding matrix to the right place.

        Args:
          Wv: (numpy.ndarray) |V| x d matrix of word embeddings
        """
        #with tf.Graph().as_default(), tf.Session() as session:
        with tf.Graph().as_default(), tf.compat.v1.Session() as session:
            ##
            # Feed embeddings to tf, and save.
            embedding_var = tf.Variable(Wv, name=name, dtype=tf.float32)
            session.run(tf.compat.v1.global_variables_initializer())

            saver = tf.compat.v1.train.Saver()
            saver.save(session, self.CHECKPOINT_FILE, 0)

            ##
            # Save metadata
            summary_writer = tf.compat.v1.summary.FileWriter(self.LOGDIR)
            config = projector.ProjectorConfig()
            embedding = config.embeddings.add()
            embedding.tensor_name = embedding_var.name
            embedding.metadata_path = self.VOCAB_FILE_BASE
            projector.visualize_embeddings(summary_writer, config)

        msg = "Saved {s0:d} x {s1:d} embedding matrix '{name}'"
        msg += " to LOGDIR='{logdir}'"
        print(
            msg.format(s0=Wv.shape[0],
                       s1=Wv.shape[1],
                       name=name,
                       logdir=self.LOGDIR))

        print("To view, run:")
        print("\n  tensorboard --logdir=\"{logdir}\"\n".format(
            logdir=self.LOGDIR))
        print("and navigate to the \"Embeddings\" tab in the web interface.")
Example #18
    def visualization(self, num_visualize, visual_fld):
        """ run "'tensorboard --logdir='visualization'" to see the embeddings """

        # create the list of num_variable most common words to visualize
        word2vec_utils.most_common_words(visual_fld, num_visualize)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ckpt = tf.train.get_checkpoint_state(
                os.path.dirname('checkpoints/checkpoint'))

            # if that checkpoint exists, restore from checkpoint
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

            final_embed_matrix = sess.run(self.embed_matrix)

            # you have to store embeddings in a new variable
            embedding_var = tf.Variable(final_embed_matrix[:num_visualize],
                                        name='embedding')
            sess.run(embedding_var.initializer)

            config = projector.ProjectorConfig()
            summary_writer = tf.summary.FileWriter(visual_fld)

            # add embedding to the config file
            embedding = config.embeddings.add()
            embedding.tensor_name = embedding_var.name

            # link this tensor to its metadata file, in this case the first NUM_VISUALIZE words of vocab
            embedding.metadata_path = 'vocab_' + str(num_visualize) + '.tsv'

            # saves a configuration file that TensorBoard will read during startup.
            projector.visualize_embeddings(summary_writer, config)
            saver_embed = tf.train.Saver([embedding_var])
            saver_embed.save(sess, os.path.join(visual_fld, 'model.ckpt'), 1)
Example #20
def embedding(vlist, rlist, metaPath, spSize):
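    # Assumed inputs: vlist supplies the variable names, rlist the matching
    # values, metaPath a path prefix for meta.tsv / meta.png, and spSize the
    # (width, height) of a single sprite thumbnail.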
    vs = []
    for i in range(0, len(vlist)):
        v = tf.Variable(rlist[i], name=vlist[i].name.split('/')[0])  # x/relu
        vs.append(v)

    with tf.Session() as sess:
        tf.variables_initializer(vs).run()  # assign to vs
        saver = tf.train.Saver(vs)
        saver.save(sess, './log/model.ckpt', 0)  # only contain vs

    # get writer and config
    summary_writer = tf.summary.FileWriter('./log/')
    config = projector.ProjectorConfig()

    # set config
    for v in vs:
        embed = config.embeddings.add()
        embed.tensor_name = v.name
        embed.metadata_path = metaPath + 'meta.tsv'
        embed.sprite.image_path = metaPath + 'meta.png'
        embed.sprite.single_image_dim.extend(spSize)
    # write
    projector.visualize_embeddings(summary_writer, config)
Example #21
def tensorboard_view(v_data, words=None):
    # Eager execution must be disabled before any graph-mode variables
    # or sessions are created
    tf.compat.v1.disable_eager_execution()

    df = pd.DataFrame.from_records(data=v_data)
    tf_data = tf.Variable(
        df.values.transpose()) if words is None else tf.Variable(
            df[words].values.transpose())

    ## Running TensorFlow session
    with tf.Session() as sess:
        saver = tf.train.Saver([tf_data])
        sess.run(tf_data.initializer)
        saver.save(sess, os.path.join(LOG_DIR, 'tf_data.ckpt'))
        config = projector.ProjectorConfig()
        # One can add multiple embeddings.
        embedding = config.embeddings.add()
        embedding.tensor_name = tf_data.name
        # Specify where you find the sprite (we will create this later)
        path_for_mnist_sprites = os.path.join(LOG_DIR, 'mnistdigits.png')
        embedding.sprite.image_path = path_for_mnist_sprites
        embedding.sprite.single_image_dim.extend([28, 28])

        # Link this tensor to its metadata (labels) file
        path_for_metadata = os.path.join(LOG_DIR, 'metadata.tsv')
        with open(path_for_metadata, 'w+') as metadata_file:
            _words = v_data.keys() if words is None else words
            for word in _words:
                metadata_file.write(f'{word}\n')

        embedding.metadata_path = path_for_metadata

        # Saves a config file that TensorBoard will read during startup.
        projector.visualize_embeddings(
            tf.compat.v1.summary.FileWriter(LOG_DIR), config)

    print(f'created logs in {LOG_DIR}')
Example #22
def visualize_embeddings(model_path, dataset, log_dir):
    model = tf.keras.models.load_model(model_path)

    labels = []
    embeddings = []
    for x, y in dataset:
        pred_y = model(x).numpy()
        for m in pred_y:
            embeddings.append(m)
        for m in y:
            labels.append(m)
    # Set up a logs directory, so Tensorboard knows where to look for files
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # Save labels separately in a line-by-line manner.
    with open(os.path.join(log_dir, 'metadata.tsv'), "w") as f:
        for label in labels:
            f.write("{}\n".format(label))

    # Save the embeddings we want to analyse as a variable so they can be
    # checkpointed for the projector.
    embedding_array = np.array(embeddings)
    embedding = tf.Variable(embedding_array, name='embedding')
    # Create a checkpoint from the embedding; the variable's key becomes
    # part of the tensor name.
    checkpoint = tf.train.Checkpoint(embedding=embedding)
    checkpoint.save(os.path.join(log_dir, "embedding.ckpt"))

    # Set up config
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"
    embedding.metadata_path = 'metadata.tsv'
    projector.visualize_embeddings(log_dir, config)
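A sketch of the expected inputs, assuming `model_path` points at a saved Keras model whose output is the embedding (all names and shapes here are invented):

import numpy as np
import tensorflow as tf

xs = np.random.rand(32, 10).astype(np.float32)
ys = np.random.randint(0, 3, size=32)
dataset = tf.data.Dataset.from_tensor_slices((xs, ys)).batch(8)

visualize_embeddings("embedding_model.keras", dataset, "logs/embeddings")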
Example #23
def train_cnn():
    """Training CNN model."""

    # Load sentences, labels, and training parameters
    logger.info("✔︎ Loading data...")

    logger.info("✔︎ Training data processing...")
    train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.embedding_dim)

    logger.info("✔︎ Validation data processing...")
    validation_data = dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.embedding_dim)

    logger.info("Recommended padding Sequence length is: {0}".format(FLAGS.pad_seq_len))

    logger.info("✔︎ Training data padding...")
    x_train_front, x_train_behind, y_train = dh.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info("✔︎ Validation data padding...")
    x_validation_front, x_validation_behind, y_validation = dh.pad_data(validation_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE = dh.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = dh.load_word2vec_matrix(VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and cnn object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            cnn = TextCNN(
                sequence_length=FLAGS.pad_seq_len,
                num_classes=y_train.shape[1],
                vocab_size=VOCAB_SIZE,
                fc_hidden_size=FLAGS.fc_hidden_size,
                embedding_size=FLAGS.embedding_dim,
                embedding_type=FLAGS.embedding_type,
                filter_sizes=list(map(int, FLAGS.filter_sizes.split(','))),
                num_filters=FLAGS.num_filters,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate,
                                                           global_step=cnn.global_step, decay_steps=FLAGS.decay_steps,
                                                           decay_rate=FLAGS.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(cnn.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=cnn.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input("☛ Please input the checkpoints model you want to restore, "
                              "it should look like (1490175368): ")  # The model you want to restore

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input("✘ The format of your input is illegal, please re-input: ")
                logger.info("✔︎ The format of your input is legal, now loading to next step...")
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", cnn.loss)
            acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary, acc_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if FLAGS.train_or_restore == 'R':
                # Load cnn model
                logger.info("✔︎ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = FLAGS.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(cnn.global_step)

            def train_step(x_batch_front, x_batch_behind, y_batch):
                """A single training step"""
                feed_dict = {
                    cnn.input_x_front: x_batch_front,
                    cnn.input_x_behind: x_batch_behind,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    cnn.is_training: True
                }
                _, step, summaries, loss, accuracy = sess.run(
                    [train_op, cnn.global_step, train_summary_op, cnn.loss, cnn.accuracy], feed_dict)
                logger.info("step {0}: loss {1:g}, acc {2:g}".format(step, loss, accuracy))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_batch_front, x_batch_behind, y_batch, writer=None):
                """Evaluates model on a validation set"""
                feed_dict = {
                    cnn.input_x_front: x_batch_front,
                    cnn.input_x_behind: x_batch_behind,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: 1.0,
                    cnn.is_training: False
                }
                step, summaries, loss, accuracy, recall, precision, f1, auc = sess.run(
                    [cnn.global_step, validation_summary_op, cnn.loss, cnn.accuracy,
                     cnn.recall, cnn.precision, cnn.F1, cnn.AUC], feed_dict)
                logger.info("step {0}: loss {1:g}, acc {2:g}, recall {3:g}, precision {4:g}, f1 {5:g}, AUC {6}"
                            .format(step, loss, accuracy, recall, precision, f1, auc))
                if writer:
                    writer.add_summary(summaries, step)

                return accuracy

            # Generate batches
            batches = dh.batch_iter(
                list(zip(x_train_front, x_train_behind, y_train)), FLAGS.batch_size, FLAGS.num_epochs)

            num_batches_per_epoch = int((len(x_train_front) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch in batches:
                x_batch_front, x_batch_behind, y_batch = zip(*batch)
                train_step(x_batch_front, x_batch_behind, y_batch)
                current_step = tf.train.global_step(sess, cnn.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    accuracy = validation_step(x_validation_front, x_validation_behind, y_validation,
                                               writer=validation_summary_writer)
                    best_saver.handle(accuracy, sess, current_step)
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
Example #24
def train_rcnn():
    """Training RCNN model."""

    # Load sentences, labels, and training parameters
    logger.info("✔︎ Loading data...")

    logger.info("✔︎ Training data processing...")
    train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.num_classes,
                                         FLAGS.embedding_dim, data_aug_flag=False)

    logger.info("✔︎ Validation data processing...")
    val_data = dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.num_classes,
                                       FLAGS.embedding_dim, data_aug_flag=False)

    logger.info("Recommended padding Sequence length is: {0}".format(FLAGS.pad_seq_len))

    logger.info("✔︎ Training data padding...")
    x_train, y_train = dh.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info("✔︎ Validation data padding...")
    x_val, y_val = dh.pad_data(val_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE = dh.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = dh.load_word2vec_matrix(VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and rcnn object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            rcnn = TextRCNN(
                sequence_length=FLAGS.pad_seq_len,
                num_classes=FLAGS.num_classes,
                vocab_size=VOCAB_SIZE,
                lstm_hidden_size=FLAGS.lstm_hidden_size,
                fc_hidden_size=FLAGS.fc_hidden_size,
                embedding_size=FLAGS.embedding_dim,
                embedding_type=FLAGS.embedding_type,
                filter_sizes=list(map(int, FLAGS.filter_sizes.split(','))),
                num_filters=FLAGS.num_filters,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate,
                                                           global_step=rcnn.global_step, decay_steps=FLAGS.decay_steps,
                                                           decay_rate=FLAGS.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(rcnn.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=rcnn.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input("☛ Please input the checkpoints model you want to restore, "
                              "it should look like (1490175368): ")  # The model you want to restore

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input("✘ The format of your input is illegal, please re-input: ")
                logger.info("✔︎ The format of your input is legal, now loading to next step...")
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", rcnn.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if FLAGS.train_or_restore == 'R':
                # Load rcnn model
                logger.info("✔︎ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = FLAGS.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(rcnn.global_step)

            def train_step(x_batch, y_batch):
                """A single training step"""
                feed_dict = {
                    rcnn.input_x: x_batch,
                    rcnn.input_y: y_batch,
                    rcnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    rcnn.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, rcnn.global_step, train_summary_op, rcnn.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_val, y_val, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(list(zip(x_val, y_val)), FLAGS.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss = 0, 0.0

                eval_pre_tk = [0.0] * FLAGS.top_num
                eval_rec_tk = [0.0] * FLAGS.top_num
                eval_F_tk = [0.0] * FLAGS.top_num

                true_onehot_labels = []
                predicted_onehot_scores = []
                predicted_onehot_labels_ts = []
                predicted_onehot_labels_tk = [[] for _ in range(FLAGS.top_num)]

                for batch_validation in batches_validation:
                    x_batch_val, y_batch_val = zip(*batch_validation)
                    feed_dict = {
                        rcnn.input_x: x_batch_val,
                        rcnn.input_y: y_batch_val,
                        rcnn.dropout_keep_prob: 1.0,
                        rcnn.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [rcnn.global_step, validation_summary_op, rcnn.scores, rcnn.loss], feed_dict)

                    # Prepare for calculating metrics
                    for i in y_batch_val:
                        true_onehot_labels.append(i)
                    for j in scores:
                        predicted_onehot_scores.append(j)

                    # Predict by threshold
                    batch_predicted_onehot_labels_ts = \
                        dh.get_onehot_label_threshold(scores=scores, threshold=FLAGS.threshold)

                    for k in batch_predicted_onehot_labels_ts:
                        predicted_onehot_labels_ts.append(k)

                    # Predict by topK
                    for top_num in range(FLAGS.top_num):
                        batch_predicted_onehot_labels_tk = dh.get_onehot_label_topk(scores=scores, top_num=top_num+1)

                        for i in batch_predicted_onehot_labels_tk:
                            predicted_onehot_labels_tk[top_num].append(i)

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Calculate Precision & Recall & F1 (threshold & topK)
                eval_pre_ts = precision_score(y_true=np.array(true_onehot_labels),
                                              y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_rec_ts = recall_score(y_true=np.array(true_onehot_labels),
                                           y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_F_ts = f1_score(y_true=np.array(true_onehot_labels),
                                     y_pred=np.array(predicted_onehot_labels_ts), average='micro')

                for top_num in range(FLAGS.top_num):
                    eval_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels),
                                                           y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                           average='micro')
                    eval_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels),
                                                        y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                        average='micro')
                    eval_F_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
                                                  y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                  average='micro')

                # Calculate the average AUC
                eval_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
                                         y_score=np.array(predicted_onehot_scores), average='micro')
                # Calculate the average PR
                eval_prc = average_precision_score(y_true=np.array(true_onehot_labels),
                                                   y_score=np.array(predicted_onehot_scores), average='micro')

                return eval_loss, eval_auc, eval_prc, eval_rec_ts, eval_pre_ts, eval_F_ts, \
                       eval_rec_tk, eval_pre_tk, eval_F_tk

            # Generate batches
            batches_train = dh.batch_iter(
                list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)

            num_batches_per_epoch = int((len(x_train) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, rcnn.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_auc, eval_prc, \
                    eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk = \
                        validation_step(x_val, y_val, writer=validation_summary_writer)

                    logger.info("All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
                                .format(eval_loss, eval_auc, eval_prc))

                    # Predict by threshold
                    logger.info("☛ Predict by threshold: Precision {0:g}, Recall {1:g}, F {2:g}"
                                .format(eval_pre_ts, eval_rec_ts, eval_F_ts))

                    # Predict by topK
                    logger.info("☛ Predict by topK:")
                    for top_num in range(FLAGS.top_num):
                        logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}"
                                    .format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num], eval_F_tk[top_num]))
                    best_saver.handle(eval_prc, sess, current_step)
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
Example #25
def train_han():
    """Training HAN model."""
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args.train_file, args.num_classes, args.word2vec_file, data_aug_flag=False)
    val_data = dh.load_data_and_labels(args.validation_file, args.num_classes, args.word2vec_file, data_aug_flag=False)

    logger.info("Data padding...")
    x_train, y_train = dh.pad_data(train_data, args.pad_seq_len)
    x_val, y_val = dh.pad_data(val_data, args.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            han = TextHAN(
                sequence_length=args.pad_seq_len,
                vocab_size=VOCAB_SIZE,
                embedding_type=args.embedding_type,
                embedding_size=EMBEDDING_SIZE,
                lstm_hidden_size=args.lstm_dim,
                fc_hidden_size=args.fc_dim,
                num_classes=args.num_classes,
                l2_reg_lambda=args.l2_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=args.learning_rate,
                                                           global_step=han.global_step, decay_steps=args.decay_steps,
                                                           decay_rate=args.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(han.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=args.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=han.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            out_dir = dh.get_out_dir(OPTION, logger)
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", han.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=args.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if OPTION == 'R':
                # Load han model
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            if OPTION == 'T':
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = args.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)
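                # Each call writes a projector_config.pbtxt into that writer's logdir,
                # which TensorBoard's Projector tab reads at startup.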

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(han.global_step)

            def train_step(x_batch, y_batch):
                """A single training step"""
                feed_dict = {
                    han.input_x: x_batch,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: args.dropout_rate,
                    han.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, han.global_step, train_summary_op, han.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_val, y_val, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(list(zip(x_val, y_val)), args.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss = 0, 0.0

                eval_pre_tk = [0.0] * args.topK
                eval_rec_tk = [0.0] * args.topK
                eval_F1_tk = [0.0] * args.topK

                true_onehot_labels = []
                predicted_onehot_scores = []
                predicted_onehot_labels_ts = []
                predicted_onehot_labels_tk = [[] for _ in range(args.topK)]

                for batch_validation in batches_validation:
                    x_batch_val, y_batch_val = zip(*batch_validation)
                    feed_dict = {
                        han.input_x: x_batch_val,
                        han.input_y: y_batch_val,
                        han.dropout_keep_prob: 1.0,
                        han.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [han.global_step, validation_summary_op, han.scores, han.loss], feed_dict)

                    # Prepare for calculating metrics
                    for i in y_batch_val:
                        true_onehot_labels.append(i)
                    for j in scores:
                        predicted_onehot_scores.append(j)

                    # Predict by threshold
                    batch_predicted_onehot_labels_ts = \
                        dh.get_onehot_label_threshold(scores=scores, threshold=args.threshold)

                    for k in batch_predicted_onehot_labels_ts:
                        predicted_onehot_labels_ts.append(k)

                    # Predict by topK
                    for top_num in range(args.topK):
                        batch_predicted_onehot_labels_tk = dh.get_onehot_label_topk(scores=scores, top_num=top_num+1)

                        for i in batch_predicted_onehot_labels_tk:
                            predicted_onehot_labels_tk[top_num].append(i)

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Calculate Precision & Recall & F1
                eval_pre_ts = precision_score(y_true=np.array(true_onehot_labels),
                                              y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_rec_ts = recall_score(y_true=np.array(true_onehot_labels),
                                           y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_F1_ts = f1_score(y_true=np.array(true_onehot_labels),
                                      y_pred=np.array(predicted_onehot_labels_ts), average='micro')

                for top_num in range(args.topK):
                    eval_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels),
                                                           y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                           average='micro')
                    eval_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels),
                                                        y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                        average='micro')
                    eval_F1_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
                                                   y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                   average='micro')

                # Calculate the average AUC
                eval_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
                                         y_score=np.array(predicted_onehot_scores), average='micro')
                # Calculate the average PR
                eval_prc = average_precision_score(y_true=np.array(true_onehot_labels),
                                                   y_score=np.array(predicted_onehot_scores), average='micro')

                return eval_loss, eval_auc, eval_prc, eval_pre_ts, eval_rec_ts, eval_F1_ts, \
                       eval_pre_tk, eval_rec_tk, eval_F1_tk

            # Generate batches
            batches_train = dh.batch_iter(
                list(zip(x_train, y_train)), args.batch_size, args.epochs)

            num_batches_per_epoch = int((len(x_train) - 1) / args.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, han.global_step)

                if current_step % args.evaluate_steps == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_auc, eval_prc, \
                    eval_pre_ts, eval_rec_ts, eval_F1_ts, eval_pre_tk, eval_rec_tk, eval_F1_tk = \
                        validation_step(x_val, y_val, writer=validation_summary_writer)

                    logger.info("All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
                                .format(eval_loss, eval_auc, eval_prc))

                    # Predict by threshold
                    logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
                                .format(eval_pre_ts, eval_rec_ts, eval_F1_ts))

                    # Predict by topK
                    logger.info("Predict by topK:")
                    for top_num in range(args.topK):
                        logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F1 {3:g}"
                                    .format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num], eval_F1_tk[top_num]))
                    best_saver.handle(eval_prc, sess, current_step)
                if current_step % args.checkpoint_steps == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("Epoch {0} has finished!".format(current_epoch))

    logger.info("All Done.")
Example #26
def train_mann():
    """Training MANN model."""

    # Load sentences, labels, and training parameters
    logger.info('✔︎ Loading data...')

    logger.info('✔︎ Training data processing...')
    train_data = dh.load_data_and_labels(FLAGS.training_data_file,
                                         FLAGS.num_classes,
                                         FLAGS.embedding_dim)

    logger.info('✔︎ Validation data processing...')
    validation_data = \
        dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.num_classes, FLAGS.embedding_dim)

    logger.info('Recommended padding Sequence length is: {0}'.format(
        FLAGS.pad_seq_len))

    logger.info('✔︎ Training data padding...')
    x_train, y_train = dh.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info('✔︎ Validation data padding...')
    x_validation, y_validation = dh.pad_data(validation_data,
                                             FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE = dh.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = dh.load_word2vec_matrix(
        VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and mann object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            mann = TextMANN(sequence_length=FLAGS.pad_seq_len,
                            num_classes=FLAGS.num_classes,
                            batch_size=FLAGS.batch_size,
                            vocab_size=VOCAB_SIZE,
                            lstm_hidden_size=FLAGS.lstm_hidden_size,
                            fc_hidden_size=FLAGS.fc_hidden_size,
                            embedding_size=FLAGS.embedding_dim,
                            embedding_type=FLAGS.embedding_type,
                            l2_reg_lambda=FLAGS.l2_reg_lambda,
                            pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(
                    learning_rate=FLAGS.learning_rate,
                    global_step=mann.global_step,
                    decay_steps=FLAGS.decay_steps,
                    decay_rate=FLAGS.decay_rate,
                    staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(mann.loss))
                grads, _ = tf.clip_by_global_norm(grads,
                                                  clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(
                    zip(grads, vars),
                    global_step=mann.global_step,
                    name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram(
                        "{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar(
                        "{0}/grad/sparsity".format(v.name),
                        tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input(
                    "☛ Please input the checkpoints model you want to restore, "
                    "it should be like(1490175368): "
                )  # The model you want to restore

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input(
                        '✘ The format of your input is illegal, please re-input: '
                    )
                logger.info(
                    '✔︎ The format of your input is legal, now loading to next step...'
                )

                checkpoint_dir = 'runs/' + MODEL + '/checkpoints/'

                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", mann.loss)

            # Train summaries
            train_summary_op = tf.summary.merge(
                [loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries",
                                                  "validation")
            validation_summary_writer = tf.summary.FileWriter(
                validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=FLAGS.num_checkpoints)

            if FLAGS.train_or_restore == 'R':
                # Load mann model
                logger.info("✔ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph(
                    "{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                checkpoint_dir = os.path.abspath(
                    os.path.join(out_dir, "checkpoints"))
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = 'embedding'
                embedding_conf.metadata_path = FLAGS.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer,
                                               config)

                # Save the embedding visualization
                saver.save(
                    sess, os.path.join(out_dir, 'embedding', 'embedding.ckpt'))

            current_step = sess.run(mann.global_step)

            def train_step(x_batch, y_batch):
                """A single training step"""
                feed_dict = {
                    mann.input_x: x_batch,
                    mann.input_y: y_batch,
                    mann.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    mann.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, mann.global_step, train_summary_op, mann.loss],
                    feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_validation, y_validation, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(
                    list(zip(x_validation, y_validation)), FLAGS.batch_size,
                    FLAGS.num_epochs)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss, eval_rec_ts, eval_acc_ts, eval_F_ts = 0, 0.0, 0.0, 0.0, 0.0
                eval_rec_tk = [0.0] * FLAGS.top_num
                eval_acc_tk = [0.0] * FLAGS.top_num
                eval_F_tk = [0.0] * FLAGS.top_num

                for batch_validation in batches_validation:
                    x_batch_validation, y_batch_validation = zip(
                        *batch_validation)
                    feed_dict = {
                        mann.input_x: x_batch_validation,
                        mann.input_y: y_batch_validation,
                        mann.dropout_keep_prob: 1.0,
                        mann.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run([
                        mann.global_step, validation_summary_op, mann.scores,
                        mann.loss
                    ], feed_dict)

                    # Predict by threshold
                    predicted_labels_threshold, predicted_values_threshold = \
                        dh.get_label_using_scores_by_threshold(scores=scores, threshold=FLAGS.threshold)

                    cur_rec_ts, cur_acc_ts, cur_F_ts = 0.0, 0.0, 0.0

                    for index, predicted_label_threshold in enumerate(
                            predicted_labels_threshold):
                        rec_inc_ts, acc_inc_ts, F_inc_ts = dh.cal_metric(
                            predicted_label_threshold,
                            y_batch_validation[index])
                        cur_rec_ts, cur_acc_ts, cur_F_ts = cur_rec_ts + rec_inc_ts, \
                                                           cur_acc_ts + acc_inc_ts, \
                                                           cur_F_ts + F_inc_ts

                    cur_rec_ts = cur_rec_ts / len(y_batch_validation)
                    cur_acc_ts = cur_acc_ts / len(y_batch_validation)
                    cur_F_ts = cur_F_ts / len(y_batch_validation)

                    eval_rec_ts, eval_acc_ts, eval_F_ts = eval_rec_ts + cur_rec_ts, \
                                                          eval_acc_ts + cur_acc_ts, \
                                                          eval_F_ts + cur_F_ts

                    # Predict by topK
                    topK_predicted_labels = []
                    for top_num in range(FLAGS.top_num):
                        predicted_labels_topk, predicted_values_topk = \
                            dh.get_label_using_scores_by_topk(scores=scores, top_num=top_num+1)
                        topK_predicted_labels.append(predicted_labels_topk)

                    cur_rec_tk = [0.0] * FLAGS.top_num
                    cur_acc_tk = [0.0] * FLAGS.top_num
                    cur_F_tk = [0.0] * FLAGS.top_num

                    for top_num, predicted_labels_topK in enumerate(
                            topK_predicted_labels):
                        for index, predicted_label_topK in enumerate(
                                predicted_labels_topK):
                            rec_inc_tk, acc_inc_tk, F_inc_tk = dh.cal_metric(
                                predicted_label_topK,
                                y_batch_validation[index])
                            cur_rec_tk[top_num], cur_acc_tk[top_num], cur_F_tk[top_num] = \
                                cur_rec_tk[top_num] + rec_inc_tk, \
                                cur_acc_tk[top_num] + acc_inc_tk, \
                                cur_F_tk[top_num] + F_inc_tk

                        cur_rec_tk[top_num] = cur_rec_tk[top_num] / len(
                            y_batch_validation)
                        cur_acc_tk[top_num] = cur_acc_tk[top_num] / len(
                            y_batch_validation)
                        cur_F_tk[top_num] = cur_F_tk[top_num] / len(
                            y_batch_validation)

                        eval_rec_tk[top_num], eval_acc_tk[top_num], eval_F_tk[top_num] = \
                            eval_rec_tk[top_num] + cur_rec_tk[top_num], \
                            eval_acc_tk[top_num] + cur_acc_tk[top_num], \
                            eval_F_tk[top_num] + cur_F_tk[top_num]

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    logger.info("✔︎ validation batch {0}: loss {1:g}".format(
                        eval_counter, cur_loss))
                    logger.info(
                        "︎☛ Predict by threshold: recall {0:g}, accuracy {1:g}, F {2:g}"
                        .format(cur_rec_ts, cur_acc_ts, cur_F_ts))

                    logger.info("︎☛ Predict by topK:")
                    for top_num in range(FLAGS.top_num):
                        logger.info(
                            "Top{0}: recall {1:g}, accuracy {2:g}, F {3:g}".
                            format(top_num + 1, cur_rec_tk[top_num],
                                   cur_acc_tk[top_num], cur_F_tk[top_num]))

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)
                eval_rec_ts = float(eval_rec_ts / eval_counter)
                eval_acc_ts = float(eval_acc_ts / eval_counter)
                eval_F_ts = float(eval_F_ts / eval_counter)

                for top_num in range(FLAGS.top_num):
                    eval_rec_tk[top_num] = float(eval_rec_tk[top_num] /
                                                 eval_counter)
                    eval_acc_tk[top_num] = float(eval_acc_tk[top_num] /
                                                 eval_counter)
                    eval_F_tk[top_num] = float(eval_F_tk[top_num] /
                                               eval_counter)

                return eval_loss, eval_rec_ts, eval_acc_ts, eval_F_ts, eval_rec_tk, eval_acc_tk, eval_F_tk

            # Generate batches
            batches_train = dh.batch_iter(list(zip(x_train, y_train)),
                                          FLAGS.batch_size, FLAGS.num_epochs)

            num_batches_per_epoch = int(
                (len(x_train) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, mann.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_rec_ts, eval_acc_ts, eval_F_ts, eval_rec_tk, eval_acc_tk, eval_F_tk = \
                        validation_step(x_validation, y_validation, writer=validation_summary_writer)

                    logger.info(
                        "All Validation set: Loss {0:g}".format(eval_loss))

                    # Predict by threshold
                    logger.info(
                        "︎☛ Predict by threshold: Recall {0:g}, Accuracy {1:g}, F {2:g}"
                        .format(eval_rec_ts, eval_acc_ts, eval_F_ts))

                    # Predict by topK
                    logger.info("︎☛ Predict by topK:")
                    for top_num in range(FLAGS.top_num):
                        logger.info(
                            "Top{0}: Recall {1:g}, Accuracy {2:g}, F {3:g}".
                            format(top_num + 1, eval_rec_tk[top_num],
                                   eval_acc_tk[top_num], eval_F_tk[top_num]))
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess,
                                      checkpoint_prefix,
                                      global_step=current_step)
                    logger.info(
                        "✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info(
                        "✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
Example #27
def train_sann():
    """Training RNN model."""
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load word2vec model
    word2idx, embedding_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args, args.train_file, word2idx)
    val_data = dh.load_data_and_labels(args, args.validation_file, word2idx)

    # Build a graph and sann object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            sann = TextSANN(sequence_length=args.pad_seq_len,
                            vocab_size=len(word2idx),
                            embedding_type=args.embedding_type,
                            embedding_size=args.embedding_dim,
                            lstm_hidden_size=args.lstm_dim,
                            attention_unit_size=args.attention_dim,
                            attention_hops_size=args.attention_hops_dim,
                            fc_hidden_size=args.fc_dim,
                            num_classes=args.num_classes,
                            l2_reg_lambda=args.l2_lambda,
                            pretrained_embedding=embedding_matrix)

            # Define training procedure
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(
                    learning_rate=args.learning_rate,
                    global_step=sann.global_step,
                    decay_steps=args.decay_steps,
                    decay_rate=args.decay_rate,
                    staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(sann.loss))
                grads, _ = tf.clip_by_global_norm(grads,
                                                  clip_norm=args.norm_ratio)
                train_op = optimizer.apply_gradients(
                    zip(grads, vars),
                    global_step=sann.global_step,
                    name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram(
                        "{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar(
                        "{0}/grad/sparsity".format(v.name),
                        tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            out_dir = dh.get_out_dir(OPTION, logger)
            checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", sann.loss)

            # Train summaries
            train_summary_op = tf.summary.merge(
                [loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries",
                                                  "validation")
            validation_summary_writer = tf.summary.FileWriter(
                validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=args.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir,
                                                num_to_keep=3,
                                                maximize=True)

            if OPTION == 'R':
                # Load sann model
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph(
                    "{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            if OPTION == 'T':
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = args.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer,
                                               config)

                # Save the embedding visualization
                saver.save(
                    sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(sann.global_step)

            def train_step(batch_data):
                """A single training step."""
                x_f, x_b, y_onehot = zip(*batch_data)

                feed_dict = {
                    sann.input_x_front: x_f,
                    sann.input_x_behind: x_b,
                    sann.input_y: y_onehot,
                    sann.dropout_keep_prob: args.dropout_rate,
                    sann.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, sann.global_step, train_summary_op, sann.loss],
                    feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(val_loader, writer=None):
                """Evaluates model on a validation set."""
                batches_validation = dh.batch_iter(
                    list(create_input_data(val_loader)), args.batch_size, 1)

                eval_counter, eval_loss = 0, 0.0
                true_labels = []
                predicted_scores = []
                predicted_labels = []

                for batch_validation in batches_validation:
                    x_f, x_b, y_onehot = zip(*batch_validation)
                    feed_dict = {
                        sann.input_x_front: x_f,
                        sann.input_x_behind: x_b,
                        sann.input_y: y_onehot,
                        sann.dropout_keep_prob: 1.0,
                        sann.is_training: False
                    }
                    step, summaries, predictions, cur_loss = sess.run([
                        sann.global_step, validation_summary_op,
                        sann.topKPreds, sann.loss
                    ], feed_dict)

                    # Prepare for calculating metrics
                    for i in y_onehot:
                        true_labels.append(np.argmax(i))
                    for j in predictions[0]:
                        predicted_scores.append(j[0])
                    for k in predictions[1]:
                        predicted_labels.append(k[0])

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Calculate Precision & Recall & F1
                eval_acc = accuracy_score(y_true=np.array(true_labels),
                                          y_pred=np.array(predicted_labels))
                eval_pre = precision_score(y_true=np.array(true_labels),
                                           y_pred=np.array(predicted_labels),
                                           average='micro')
                eval_rec = recall_score(y_true=np.array(true_labels),
                                        y_pred=np.array(predicted_labels),
                                        average='micro')
                eval_F1 = f1_score(y_true=np.array(true_labels),
                                   y_pred=np.array(predicted_labels),
                                   average='micro')

                # Calculate the average AUC
                eval_auc = roc_auc_score(y_true=np.array(true_labels),
                                         y_score=np.array(predicted_scores),
                                         average='micro')

                return eval_loss, eval_acc, eval_pre, eval_rec, eval_F1, eval_auc

            # Generate batches
            batches_train = dh.batch_iter(list(create_input_data(train_data)),
                                          args.batch_size, args.epochs)
            num_batches_per_epoch = int(
                (len(train_data['f_pad_seqs']) - 1) / args.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                train_step(batch_train)
                current_step = tf.train.global_step(sess, sann.global_step)

                if current_step % args.evaluate_steps == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_acc, eval_pre, eval_rec, eval_F1, eval_auc = \
                        validation_step(val_data, writer=validation_summary_writer)
                    logger.info(
                        "All Validation set: Loss {0:g} | Acc {1:g} | Precision {2:g} | "
                        "Recall {3:g} | F1 {4:g} | AUC {5:g}".format(
                            eval_loss, eval_acc, eval_pre, eval_rec, eval_F1,
                            eval_auc))
                    best_saver.handle(eval_acc, sess, current_step)
                if current_step % args.checkpoint_steps == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess,
                                      checkpoint_prefix,
                                      global_step=current_step)
                    logger.info("Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info(
                        "Epoch {0} has finished!".format(current_epoch))

    logger.info("All Done.")
Example #28
checkpoint.save(os.path.join(embedding_path, "embedding.ckpt"))

# Set up config
config = projector.ProjectorConfig()
embedding = config.embeddings.add()
# The name of the tensor will be suffixed by `/.ATTRIBUTES/VARIABLE_VALUE`
embedding.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"
# Link this tensor to its metadata file (e.g. labels).
embedding.metadata_path = 'metadata.tsv'
# Comment out if you don't want sprites
embedding.sprite.image_path = 'sprite.png'
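# Note: this assumes square thumbnails, since img_data.shape[1] is used for
# both sprite dimensions.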
embedding.sprite.single_image_dim.extend(
    [img_data.shape[1], img_data.shape[1]])
# Saves a config file that TensorBoard will read during startup.
projector.visualize_embeddings(embedding_path, config)
#%%

# Use the iris dataset to illustrate PCA:
import pandas as pd
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
# load dataset into Pandas DataFrame
df = pd.read_csv(url, names=['sepal length','sepal width','petal length','petal width','target'])
df.head()

from sklearn.preprocessing import StandardScaler
variables = ['sepal length','sepal width','petal length','petal width']
x = df.loc[:, variables].values
y = df.loc[:,['target']].values
x = StandardScaler().fit_transform(x)
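
The fragment stops right after standardization. To actually illustrate PCA as the comment above promises, the usual scikit-learn continuation would be along these lines:

from sklearn.decomposition import PCA

# Project the standardized features onto the first two principal components.
pca = PCA(n_components=2)
principal_components = pca.fit_transform(x)
principal_df = pd.DataFrame(principal_components, columns=['PC1', 'PC2'])
print(pca.explained_variance_ratio_)  # share of variance captured per component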
Example #29
    # Report only errors by default
    if not args.verbose:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Generate the embeddings for the projector
    tf.summary.create_file_writer(args.output_dir)
    with open(args.input_embeddings, "r") as embedding_file:
        _, dim = map(int, embedding_file.readline().split())

        embeddings = np.zeros([args.elements, dim], np.float32)
        with open(os.path.join(args.output_dir, "metadata.tsv"),
                  "w") as metadata_file:
            for i, line in zip(range(args.elements), embedding_file):
                form, *embedding = line.split()
                print(form, file=metadata_file)
                embeddings[i] = list(map(float, embedding))

    # Save the variable
    embeddings = tf.Variable(embeddings, dtype=tf.float32)
    checkpoint = tf.train.Checkpoint(embeddings=embeddings)
    checkpoint.save(os.path.join(args.output_dir, "embeddings.ckpt"))

    # Set up the projector config
    config = projector.ProjectorConfig()
    embedding_conf = config.embeddings.add()

    # The name of the tensor will be suffixed by `/.ATTRIBUTES/VARIABLE_VALUE`
    embedding_conf.tensor_name = "embeddings/.ATTRIBUTES/VARIABLE_VALUE"
    embedding_conf.metadata_path = "metadata.tsv"
    projector.visualize_embeddings(args.output_dir, config)
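
Example #29 reads its input in the plain-text word2vec format: a header line giving the number of vectors and their dimensionality, then one token and its vector per line. A tiny hand-made file (hypothetical values, purely for a smoke test) could be produced like this:

# Hypothetical three-token, four-dimensional embedding file:
with open("toy_embeddings.txt", "w") as f:
    f.write("3 4\n"
            "king 0.1 0.2 0.3 0.4\n"
            "queen 0.2 0.1 0.4 0.3\n"
            "apple 0.9 0.8 0.7 0.6\n")
# Pointing the script's input_embeddings argument at this file (with elements=3)
# then writes metadata.tsv, embeddings.ckpt* and projector_config.pbtxt into output_dir.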
Example #30
PATH = os.getcwd()

LOG_DIR = PATH + '/log-1/'
metadata = os.path.join(LOG_DIR, 'metadata.tsv')

# `o` (the feature array) and the `emotion`/`y_train` labels used below are
# assumed to come from earlier preprocessing that is not part of this snippet.
Voices = tf.Variable(o.reshape((len(o), -1)), name='Voices')

#def save_metadata(file):
with open(metadata, 'w') as metadata_file:
    for row in range(702):
        c = emotion[y_train[row]]
        metadata_file.write('{}\n'.format(c))

with tf.compat.v1.Session() as sess:
    saver = tf.compat.v1.train.Saver([Voices])

    sess.run(Voices.initializer)
    saver.save(sess, os.path.join(LOG_DIR, 'model.ckpt'))

    config = projector.ProjectorConfig()
    # One can add multiple embeddings.
    embedding = config.embeddings.add()
    embedding.tensor_name = Voices.name
    # Link this tensor to its metadata file (e.g. labels).
    embedding.metadata_path = metadata
    # Saves a config file that TensorBoard will read during startup.
    projector.visualize_embeddings(tf.compat.v1.summary.FileWriter(LOG_DIR),
                                   config)

#get_ipython().system('tensorboard --logdir=./log-1/')
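
For reference, projector.visualize_embeddings writes a plain-text protobuf named projector_config.pbtxt into LOG_DIR, which TensorBoard parses at startup. For the snippet above it would contain roughly:

embeddings {
  tensor_name: "Voices:0"
  metadata_path: "<absolute path to metadata.tsv>"
}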
Example #31
    def build_model(self):
        with tf.variable_scope('Placeholder'):
            self.nodes_placeholder = tf.placeholder(tf.int32, (None, ),
                                                    name='nodes_placeholder')
            self.seqlen_placeholder = tf.placeholder(tf.int32, (None, ),
                                                     name='seqlen_placeholder')
            self.neighborhood_placeholder = tf.placeholder(
                tf.int32, (None, self.args.sampling_size),
                name='neighborhood_placeholder')
            self.label_placeholder = tf.placeholder(tf.float32, (None, ),
                                                    name='label_placeholder')

        self.data = network.next_batch(self.graph,
                                       self.degree_max,
                                       sampling=True,
                                       sampling_size=self.args.sampling_size)

        with tf.variable_scope('Embeddings'):
            self.embeddings = tf.get_variable(
                'embeddings', [len(self.graph), self.args.embedding_size],
                initializer=tf.constant_initializer(
                    utils.init_embedding(self.degree, self.degree_max,
                                         self.args.embedding_size)))

        with tf.variable_scope('LSTM'):
            cell = tf.contrib.rnn.DropoutWrapper(
                #tf.contrib.rnn.BasicLSTMCell(num_units=self.args.embedding_size),
                tf.contrib.rnn.LayerNormBasicLSTMCell(
                    num_units=self.args.embedding_size, layer_norm=False),
                input_keep_prob=1.0,
                output_keep_prob=1.0)
            _, states = tf.nn.dynamic_rnn(
                cell,
                tf.nn.embedding_lookup(self.embeddings,
                                       self.neighborhood_placeholder),
                dtype=tf.float32,
                sequence_length=self.seqlen_placeholder)
            self.lstm_output = states.h

        with tf.variable_scope('Guilded'):
            self.predict_info = tf.squeeze(
                tf.layers.dense(self.lstm_output,
                                units=1,
                                activation=utils.selu))

        with tf.variable_scope('Loss'):
            self.structure_loss = tf.losses.mean_squared_error(
                tf.nn.embedding_lookup(self.embeddings,
                                       self.nodes_placeholder),
                self.lstm_output)
            self.guilded_loss = tf.reduce_mean(
                tf.abs(tf.subtract(self.predict_info, self.label_placeholder)))
            self.orth_loss = tf.losses.mean_squared_error(
                tf.matmul(self.embeddings, self.embeddings, transpose_a=True),
                tf.eye(self.args.embedding_size))
            self.total_loss = self.structure_loss + self.args.alpha * self.orth_loss + self.args.lamb * self.guilded_loss

        with tf.variable_scope('Optimizer'):
            #self.optimizer = tf.train.AdamOptimizer(self.args.learning_rate)
            self.optimizer = tf.train.RMSPropOptimizer(self.args.learning_rate)
            tvars = tf.trainable_variables()
            grads, self.global_norm = tf.clip_by_global_norm(
                tf.gradients(self.total_loss, tvars), self.args.grad_clip)
            self.train_op = self.optimizer.apply_gradients(zip(grads, tvars))

        with tf.variable_scope('Summary'):
            tf.summary.scalar("orth_loss", self.orth_loss)
            tf.summary.scalar("guilded_loss", self.guilded_loss)
            tf.summary.scalar("structure_loss", self.structure_loss)
            tf.summary.scalar("total_loss", self.total_loss)
            tf.summary.scalar("globol_norm", self.global_norm)
            for (grad, var) in zip(grads, tvars):
                if grad is not None:
                    tf.summary.histogram('grad/{}'.format(var.name), grad)
                    tf.summary.histogram('weight/{}'.format(var.name), var)

            log_dir = os.path.join(self.save_path, 'logs')
            if os.path.exists(log_dir):
                shutil.rmtree(log_dir)
            self.summary_writer = tf.summary.FileWriter(
                log_dir, self.sess.graph)

            config = projector.ProjectorConfig()
            embedding = config.embeddings.add()
            embedding.tensor_name = self.embeddings.name
            embedding.metadata_path = os.path.join(
                self.args.save_path, 'data', 'index.tsv')
            projector.visualize_embeddings(self.summary_writer, config)

            self.merged_summary = tf.summary.merge_all()

        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())
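
The orth_loss in this model is an orthonormality regularizer: the mean squared difference between the embeddings' Gram matrix EᵀE and the identity. A toy NumPy check of the same quantity (not part of the model):

import numpy as np

E = np.random.randn(100, 16)   # e.g. 100 nodes, 16-dim embeddings
gram = E.T @ E                 # EᵀE, shape (16, 16)
orth_loss = np.mean((gram - np.eye(16)) ** 2)
# orth_loss shrinks as the embedding columns approach orthonormality.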
Example #32
  def set_model(self, model):
    """Sets Keras model and creates summary ops."""

    self.model = model
    self._init_writer(model)
    # histogram summaries only enabled in graph mode
    if not context.executing_eagerly():
      self._make_histogram_ops(model)
      self.merged = tf_summary.merge_all()

    # If both embedding_freq and embeddings_data are available, we will
    # visualize embeddings.
    if self.embeddings_freq and self.embeddings_data is not None:
      # Avoid circular dependency.
      from tensorflow.python.keras.engine import training_utils  # pylint: disable=g-import-not-at-top
      self.embeddings_data = training_utils.standardize_input_data(
          self.embeddings_data, model.input_names)

      # If embedding_layer_names are not provided, get all of the embedding
      # layers from the model.
      embeddings_layer_names = self.embeddings_layer_names
      if not embeddings_layer_names:
        embeddings_layer_names = [
            layer.name
            for layer in self.model.layers
            if type(layer).__name__ == 'Embedding'
        ]

      self.assign_embeddings = []
      embeddings_vars = {}

      self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
      self.step = step = array_ops.placeholder(dtypes.int32)

      for layer in self.model.layers:
        if layer.name in embeddings_layer_names:
          embedding_input = self.model.get_layer(layer.name).output
          embedding_size = np.prod(embedding_input.shape[1:])
          embedding_input = array_ops.reshape(embedding_input,
                                              (step, int(embedding_size)))
          shape = (self.embeddings_data[0].shape[0], int(embedding_size))
          embedding = variables.Variable(
              array_ops.zeros(shape), name=layer.name + '_embedding')
          embeddings_vars[layer.name] = embedding
          batch = state_ops.assign(embedding[batch_id:batch_id + step],
                                   embedding_input)
          self.assign_embeddings.append(batch)

      self.saver = saver.Saver(list(embeddings_vars.values()))

      # Create embeddings_metadata dictionary
      if isinstance(self.embeddings_metadata, str):
        embeddings_metadata = {
            layer_name: self.embeddings_metadata
            for layer_name in embeddings_vars.keys()
        }
      else:
        # If embedding_metadata is already a dictionary
        embeddings_metadata = self.embeddings_metadata

      try:
        from tensorboard.plugins import projector
      except ImportError:
        raise ImportError('Failed to import TensorBoard. Please make sure that '
                          'the TensorBoard integration is complete.')

      # TODO(psv): Add integration tests to test embedding visualization
      # with TensorBoard callback. We are unable to write a unit test for this
      # because TensorBoard dependency assumes TensorFlow package is installed.
      config = projector.ProjectorConfig()
      for layer_name, tensor in embeddings_vars.items():
        embedding = config.embeddings.add()
        embedding.tensor_name = tensor.name

        if (embeddings_metadata is not None and
            layer_name in embeddings_metadata):
          embedding.metadata_path = embeddings_metadata[layer_name]

      projector.visualize_embeddings(self.writer, config)
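
Example #32 (and the near-identical #33 below) implements the embedding half of the TF 1.x-era Keras TensorBoard callback. From the user's side it was driven by constructor arguments along these lines (model and data names here are hypothetical; embeddings_data and embeddings_layer_names were later removed in TF 2.x):

from tensorflow.keras.callbacks import TensorBoard

callback = TensorBoard(log_dir='./logs',
                       embeddings_freq=1,            # save embeddings every epoch
                       embeddings_layer_names=None,  # None means all Embedding layers
                       embeddings_metadata='metadata.tsv',
                       embeddings_data=x_val)        # inputs fed through the model
model.fit(x_train, y_train, epochs=5, callbacks=[callback])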
Example #33
  def set_model(self, model):
    """Sets Keras model and creates summary ops."""

    self.model = model
    self.sess = K.get_session()
    # only make histogram summary op if it hasn't already been made
    if self.histogram_freq and self.merged is None:
      for layer in self.model.layers:
        for weight in layer.weights:
          mapped_weight_name = weight.name.replace(':', '_')
          tf_summary.histogram(mapped_weight_name, weight)
          if self.write_images:
            w_img = array_ops.squeeze(weight)
            shape = K.int_shape(w_img)
            if len(shape) == 2:  # dense layer kernel case
              if shape[0] > shape[1]:
                w_img = array_ops.transpose(w_img)
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
            elif len(shape) == 3:  # convnet case
              if K.image_data_format() == 'channels_last':
                # switch to channels_first to display
                # every kernel as a separate image
                w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img,
                                        [shape[0], shape[1], shape[2], 1])
            elif len(shape) == 1:  # bias case
              w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
            else:
              # not possible to handle 3D convnets etc.
              continue

            shape = K.int_shape(w_img)
            assert len(shape) == 4 and shape[-1] in [1, 3, 4]
            tf_summary.image(mapped_weight_name, w_img)

        if self.write_grads:
          for weight in layer.trainable_weights:
            mapped_weight_name = weight.name.replace(':', '_')
            grads = model.optimizer.get_gradients(model.total_loss, weight)

            def is_indexed_slices(grad):
              return type(grad).__name__ == 'IndexedSlices'

            grads = [grad.values if is_indexed_slices(grad) else grad
                     for grad in grads]
            tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)

        if hasattr(layer, 'output'):
          if isinstance(layer.output, list):
            for i, output in enumerate(layer.output):
              tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
          else:
            tf_summary.histogram('{}_out'.format(layer.name), layer.output)
    self.merged = tf_summary.merge_all()

    if self.write_graph:
      self.writer = self._writer_class(self.log_dir, self.sess.graph)
    else:
      self.writer = self._writer_class(self.log_dir)

    # If both embedding_freq and embeddings_data are available, we will
    # visualize embeddings.
    if self.embeddings_freq and self.embeddings_data is not None:
      self.embeddings_data = standardize_input_data(self.embeddings_data,
                                                    model.input_names)

      # If embedding_layer_names are not provided, get all of the embedding
      # layers from the model.
      embeddings_layer_names = self.embeddings_layer_names
      if not embeddings_layer_names:
        embeddings_layer_names = [
            layer.name
            for layer in self.model.layers
            if type(layer).__name__ == 'Embedding'
        ]

      self.assign_embeddings = []
      embeddings_vars = {}

      self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
      self.step = step = array_ops.placeholder(dtypes.int32)

      for layer in self.model.layers:
        if layer.name in embeddings_layer_names:
          embedding_input = self.model.get_layer(layer.name).output
          embedding_size = np.prod(embedding_input.shape[1:])
          embedding_input = array_ops.reshape(embedding_input,
                                              (step, int(embedding_size)))
          shape = (self.embeddings_data[0].shape[0], int(embedding_size))
          embedding = variables.Variable(
              array_ops.zeros(shape), name=layer.name + '_embedding')
          embeddings_vars[layer.name] = embedding
          batch = state_ops.assign(embedding[batch_id:batch_id + step],
                                   embedding_input)
          self.assign_embeddings.append(batch)

      self.saver = saver.Saver(list(embeddings_vars.values()))

      # Create embeddings_metadata dictionary
      if isinstance(self.embeddings_metadata, str):
        embeddings_metadata = {
            layer_name: self.embeddings_metadata
            for layer_name in embeddings_vars.keys()
        }
      else:
        # If embedding_metadata is already a dictionary
        embeddings_metadata = self.embeddings_metadata

      try:
        from tensorboard.plugins import projector
      except ImportError:
        raise ImportError('Failed to import TensorBoard. Please make sure that '
                          'the TensorBoard integration is complete.')

      # TODO(psv): Add integration tests to test embedding visualization
      # with TensorBoard callback. We are unable to write a unit test for this
      # because TensorBoard dependency assumes TensorFlow package is installed.
      config = projector.ProjectorConfig()
      for layer_name, tensor in embeddings_vars.items():
        embedding = config.embeddings.add()
        embedding.tensor_name = tensor.name

        if (embeddings_metadata is not None and
            layer_name in embeddings_metadata):
          embedding.metadata_path = embeddings_metadata[layer_name]

      projector.visualize_embeddings(self.writer, config)