コード例 #1
0
    def _read_latest_config_files(self, run_path_pairs):
        """Reads and returns the projector config files in every run directory."""
        configs = {}
        config_fpaths = {}
        for run_name, logdir in run_path_pairs:
            config = projector_config_pb2.ProjectorConfig()
            config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
            if file_io.file_exists(config_fpath):
                file_content = file_io.read_file_to_string(config_fpath)
                text_format.Merge(file_content, config)

            has_tensor_files = False
            for embedding in config.embeddings:
                if embedding.tensor_path:
                    has_tensor_files = True
                    break

            if not config.model_checkpoint_path:
                # See if you can find a checkpoint file in the logdir.
                ckpt_path = _find_latest_checkpoint(logdir)
                if not ckpt_path and not has_tensor_files:
                    continue
                if ckpt_path:
                    config.model_checkpoint_path = ckpt_path

            # Sanity check for the checkpoint file.
            if (config.model_checkpoint_path
                    and not checkpoint_exists(config.model_checkpoint_path)):
                logging.warning('Checkpoint file %s not found',
                                config.model_checkpoint_path)
                continue
            configs[run_name] = config
            config_fpaths[run_name] = config_fpath
        return configs, config_fpaths
コード例 #2
0
ファイル: latent_space.py プロジェクト: oskopek/mvae
def main() -> None:
    """Plot the latent space of the newest embeddings in a checkpoint folder.

    Reads ``projector_config.pbtxt`` from the folder passed via ``--path``.
    Embedding tensor names are expected to look like ``<name>:<step>``.
    Only the embeddings carrying the highest step are kept, and one
    latent-space plot per kept embedding is exported into the same folder.

    Raises:
        ValueError: If the folder or its projector config file does not exist,
            or if a tensor name has no ``:<int>`` suffix.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, help="Path to checkpoint folder.", required=True)
    args = parser.parse_args()

    if not os.path.isdir(args.path):
        raise ValueError(f"Input folder doesn't exist: '{args.path}'.")

    projector_meta_path = os.path.join(args.path, "projector_config.pbtxt")
    if not os.path.isfile(projector_meta_path):
        raise ValueError(f"Projector metadata file doesn't exist: '{projector_meta_path}'.")

    # Protobuf text format is UTF-8; don't depend on the platform locale encoding.
    with open(projector_meta_path, encoding="utf-8") as f:
        txt = f.read()

    projector_config = protobuf.text_format.Parse(txt, projector_config_pb2.ProjectorConfig())  # type: ignore

    def to_key(embedding_info: projector_config_pb2.EmbeddingInfo) -> Tuple[str, int]:
        # Split on the last colon only, so names that themselves contain ':'
        # still yield (name, step) instead of failing to unpack.
        name, checkpoint = embedding_info.tensor_name.rsplit(":", 1)
        return name, int(checkpoint)

    embeddings = {to_key(e): e for e in projector_config.embeddings}
    if not embeddings:
        return
    # Compute the newest step once rather than re-evaluating max() per item.
    last_step = max(step for _, step in embeddings)
    last_embeddings = [v for (_, step), v in embeddings.items() if step >= last_step]
    for embedding_info in last_embeddings:
        figure, name = plot_latent_space(embedding_info, args.path)
        utils.export_plots(figure,
                           filename=os.path.join(args.path, name),
                           title=f"Latent space {name}",
                           box=True,
                           x_range_start=-1,
                           x_range_end=1,
                           y_range_start=-1,
                           y_range_end=1)
コード例 #3
0
    def testAddEmbeddingWithTwoMetadataColumns(self):
        """Two metadata columns are written as a header row plus tab-separated rows."""
        asset = plugin_asset.get_plugin_asset(
            projector_plugin.ProjectorPluginAsset)

        meta = projector_plugin.EmbeddingMetadata(3)
        meta.add_column('labels', ['a', 'b', 'друг јазик'])
        meta.add_column('sizes', [10, 20, 30])
        asset.add_embedding('test', np.array([[1], [2], [3]]), meta)

        # Build the config the asset manager is expected to have produced.
        want_config = projector_config_pb2.ProjectorConfig()
        info = want_config.embeddings.add()
        info.tensor_name = 'test'
        info.tensor_shape.extend([3, 1])
        info.tensor_path = 'test_values.tsv'
        info.metadata_path = 'test_metadata.tsv'

        want_assets = {
            'projector_config.pbtxt': text_format.MessageToString(want_config),
            'test_values.tsv': b'1\n2\n3\n',
            'test_metadata.tsv': 'labels\tsizes\na\t10\nb\t20\nдруг јазик\t30\n',
        }
        self.assertEqual(asset.assets(), want_assets)
コード例 #4
0
  def testVisualizeEmbeddings(self):
    """visualize_embeddings writes a config that round-trips through pbtxt."""
    # Create a dummy configuration.
    config = projector_config_pb2.ProjectorConfig()
    config.model_checkpoint_path = 'test'
    emb1 = config.embeddings.add()
    emb1.tensor_name = 'tensor1'
    emb1.metadata_path = 'metadata1'

    # Call the API method to save the configuration to a temporary dir.
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir)
    writer = writer_lib.FileWriter(temp_dir)
    # Cleanups run in LIFO order, so the writer is closed before the
    # directory it writes into is removed.
    self.addCleanup(writer.close)
    projector.visualize_embeddings(writer, config)

    # Read the configuration from disk and make sure it matches the original.
    with gfile.GFile(os.path.join(temp_dir, 'projector_config.pbtxt')) as f:
      config2 = projector_config_pb2.ProjectorConfig()
      text_format.Parse(f.read(), config2)
    self.assertEqual(config, config2)
コード例 #5
0
    def testAddEmbeddingNoMetadata(self):
        """Without metadata, only the config and the values TSV are produced."""
        asset = plugin_asset.get_plugin_asset(
            projector_plugin.ProjectorPluginAsset)
        asset.add_embedding('test', np.array([[1, 2, 3.1]]))

        want_config = projector_config_pb2.ProjectorConfig()
        info = want_config.embeddings.add()
        info.tensor_name = 'test'
        info.tensor_shape.extend([1, 3])
        info.tensor_path = 'test_values.tsv'

        want_assets = {
            'projector_config.pbtxt': text_format.MessageToString(want_config),
            'test_values.tsv': b'1\t2\t3.1\n',
        }
        self.assertEqual(asset.assets(), want_assets)
コード例 #6
0
    def testAddMetadataForVariable(self):
        """Metadata attached to a variable yields the config plus a metadata TSV."""
        asset = plugin_asset.get_plugin_asset(
            projector_plugin.ProjectorPluginAsset)
        meta = projector_plugin.EmbeddingMetadata(3)
        meta.add_column('Labels', ['a', 'b', 'c'])
        asset.add_metadata_for_embedding_variable('test', meta)

        want_config = projector_config_pb2.ProjectorConfig()
        info = want_config.embeddings.add()
        info.tensor_name = 'test'
        info.metadata_path = 'test_metadata.tsv'

        want_assets = {
            'projector_config.pbtxt': text_format.MessageToString(want_config),
            # With a single column, no header row is expected.
            'test_metadata.tsv': 'a\nb\nc\n',
        }
        self.assertEqual(asset.assets(), want_assets)
コード例 #7
0
def _latest_checkpoints_changed(configs, run_path_pairs):
    """Returns true if the latest checkpoint has changed in any of the runs."""
    for run_name, logdir in run_path_pairs:
        if run_name not in configs:
            config = projector_config_pb2.ProjectorConfig()
            config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
            if file_io.file_exists(config_fpath):
                file_content = file_io.read_file_to_string(config_fpath)
                text_format.Merge(file_content, config)
        else:
            config = configs[run_name]

        # See if you can find a checkpoint file in the logdir.
        ckpt_path = _find_latest_checkpoint(logdir)
        if not ckpt_path:
            continue
        if config.model_checkpoint_path != ckpt_path:
            return True
    return False
コード例 #8
0
    def _GenerateProjectorTestData(self):
        """Write a projector config and a V1 checkpoint into ``self.log_dir``.

        The config references the tensor ``var1:0``. The checkpoint is saved
        under ``<log_dir>/model`` and holds:
          * ``var1``: shape [1, 2], initialised to 6.0;
          * ``var2``: shape [10, 10];
          * ``var3``: shape [100, 100].
        """
        config_path = os.path.join(self.log_dir, 'projector_config.pbtxt')
        config = projector_config_pb2.ProjectorConfig()
        embedding = config.embeddings.add()
        # Add an embedding by its canonical tensor name.
        embedding.tensor_name = 'var1:0'
        config_pbtxt = text_format.MessageToString(config)
        with gfile.GFile(config_path, 'w') as f:
            f.write(config_pbtxt)

        # Write a checkpoint with some dummy variables. The session is bound
        # to the fresh graph and closed on exit instead of being leaked.
        with ops.Graph().as_default(), session.Session() as sess:
            checkpoint_path = os.path.join(self.log_dir, 'model')
            variable_scope.get_variable(
                'var1', [1, 2], initializer=init_ops.constant_initializer(6.0))
            variable_scope.get_variable('var2', [10, 10])
            variable_scope.get_variable('var3', [100, 100])
            sess.run(variables.global_variables_initializer())
            saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
            saver.save(sess, checkpoint_path)
コード例 #9
0
    def testAddEmbeddingWithThumbnails(self):
        """Thumbnails are packed into a sprite PNG referenced from the config.

        Three embedding rows are added together with two 2x2x3 thumbnails.
        The expected sprite is a 4x4 image whose unused cells are zero.
        """
        manager = plugin_asset.get_plugin_asset(
            projector_plugin.ProjectorPluginAsset)

        image1 = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
        image2 = np.array([[[10, 20, 30], [40, 50, 60]],
                           [[70, 80, 90], [100, 110, 120]]])
        manager.add_embedding('test',
                              np.array([[1], [2], [3]]),
                              thumbnails=[image1, image2],
                              thumbnail_dim=[2, 2])

        assets = manager.assets()

        config = projector_config_pb2.ProjectorConfig()
        embedding = config.embeddings.add()
        embedding.tensor_name = 'test'
        embedding.tensor_shape.extend([3, 1])
        embedding.tensor_path = 'test_values.tsv'
        embedding.sprite.image_path = 'test_sprite.png'
        embedding.sprite.single_image_dim.extend([2, 2])
        expected_config_pbtxt = text_format.MessageToString(config)

        self.assertEqual(assets['projector_config.pbtxt'],
                         expected_config_pbtxt)
        self.assertEqual(assets['test_values.tsv'], b'1\n2\n3\n')

        png_bytes = assets['test_sprite.png']
        # Decode in a throwaway graph; the session is closed on exit rather
        # than leaked.
        with ops.Graph().as_default(), session.Session() as s:
            image_array = image_ops.decode_png(png_bytes).eval(
                session=s).tolist()
        expected_master_image = [[[1, 2, 3], [4, 5, 6], [10, 20, 30],
                                  [40, 50, 60]],
                                 [[7, 8, 9], [10, 11, 12], [70, 80, 90],
                                  [100, 110, 120]],
                                 [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                                 [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]
        self.assertEqual(image_array, expected_master_image)
コード例 #10
0
 def __init__(self):
     """Start with no used names, no assets and an empty projector config."""
     self._used_names = set()
     self._assets = dict()
     self._config = projector_config_pb2.ProjectorConfig()