Ejemplo n.º 1
0
def _process_stems(save_dir, encoder_stem, decoder_stem, losses_stem,
                   config_stem):
    """Build the file names under which model resources are stored.

    All five arguments are first normalised with ``check_path(...,
    path_type=Path)``. Each stem then receives its fixed extension:

    - ``.h5`` for the encoder and for the decoder;
    - ``.dict`` for the losses;
    - ``.json`` for the config.

    The resulting names are joined onto ``save_dir``.

    Returns
    -------
    tuple
        ``(encoder_file, decoder_file, loss_file, config_file)``, each
        converted with ``check_path(..., path_type=str)``.
    """
    save_dir, *stems = check_path(
        [save_dir, encoder_stem, decoder_stem, losses_stem, config_stem],
        path_type=Path)
    # NOTE(review): ``with_suffix`` *replaces* an existing suffix, so a stem
    # such as 'encoder.v2' becomes 'encoder.h5' -- confirm this is intended.
    extensions = ('.h5', '.h5', '.dict', '.json')
    targets = [save_dir / stem.with_suffix(ext)
               for stem, ext in zip(stems, extensions)]
    encoder_file, decoder_file, loss_file, config_file = check_path(
        targets, path_type=str)
    return encoder_file, decoder_file, loss_file, config_file
Ejemplo n.º 2
0
 def __init__(self,
              images,
              output_dir,
              stem='reconstruction',
              fmt='jpg',
              batch_size=32):
     """Store the settings used to write reconstructed images.

     Parameters
     ----------
     images
         Input images; stored unchanged as ``self.images``.
     output_dir : str or pathlib.Path
         Output directory, normalised to a ``Path`` with ``check_path``.
     stem : str
         File-name stem, stored as ``self.stem``.
     fmt : str
         File extension, with or without the leading dot. Both ``'jpg'``
         and ``'.jpg'`` are stored as ``'.jpg'`` in ``self.format``.
     batch_size : int
         Stored as ``self.batch_size``.

     Raises
     ------
     ValueError
         If ``fmt`` is an empty string.
     """
     super().__init__()
     self.images = images
     self.output_dir = check_path(output_dir, Path)
     self.stem = stem
     # An empty extension cannot be normalised meaningfully (it would yield
     # a bare '.'), so reject it with a clear message instead of failing on
     # string indexing.
     if not fmt:
         raise ValueError('fmt must be a non-empty file extension.')
     self.format = fmt if fmt.startswith('.') else '.' + fmt
     # Starts unset; presumably assigned later (not visible here).
     self.model_predict = None
     self.batch_size = batch_size
Ejemplo n.º 3
0
    def __init__(self,
                 csv_file,
                 data,
                 batch_size=512,
                 verbose=False,
                 overwrite=False):
        """Configure the callback's CSV target and the data it works on.

        Parameters
        ----------
        csv_file : str or pathlib.Path
            Output CSV path, normalised to a ``Path`` with ``check_path``.
        data
            Data stored unchanged as ``self.data``.
        batch_size : int
            Stored as ``self.batch_size``.
        verbose : bool
            Stored as ``self.verbose``.
        overwrite : bool
            Stored as ``self.overwrite``.
        """
        super().__init__()

        # Caller-supplied configuration.
        self.csv_file = check_path(csv_file, path_type=Path)
        self.data = data
        self.batch_size = batch_size
        self.overwrite = overwrite
        self.verbose = verbose

        # Internal state; every entry starts as None and is presumably filled
        # in later by other methods (not visible here).
        self.encoder = self.num_latents = self.csv_header = None
        self.batch_means = self.batch_stds = None
Ejemplo n.º 4
0
 def __init__(self,
              image,
              output_dir,
              stem='traversal',
              traversal_range=(-4, 4),
              traversal_resolution=10):
     """Store a single image and the settings used for the traversal check.

     Parameters
     ----------
     image : numpy.ndarray
         How the array is handled depends on its rank:

         - a 4-D array must have size 1 along axis 0 and is kept as-is;
         - any other rank gets a new leading axis via ``np.expand_dims``.

         Presumably this is a single image without a batch axis; confirm
         against callers.
     output_dir : str or pathlib.Path
         Output directory, normalised to a ``Path`` with ``check_path``.
     stem : str
         Stored as ``self.stem``.
     traversal_range : tuple
         Stored as ``self.traversal_range``.
     traversal_resolution : int
         Stored as ``self.traversal_resolution``.

     Raises
     ------
     ValueError
         If a 4-D ``image`` holds more than one entry along axis 0.
     """
     super().__init__()
     if image.ndim == 4:
         # Explicit check rather than ``assert``: asserts are stripped when
         # Python runs with -O, which would silently accept a whole batch.
         if image.shape[0] != 1:
             raise ValueError(
                 'Traversal check can only be performed on a single image.')
     else:
         image = np.expand_dims(image, axis=0)
     self.image = image
     self.output_dir = check_path(output_dir, Path)
     self.stem = stem
     # Starts unset; presumably assigned later (not visible here).
     self.tcvae = None
     self.traversal_range = traversal_range
     self.traversal_resolution = traversal_resolution
Ejemplo n.º 5
0
    def save(self,
             save_dir,
             encoder_stem='encoder',
             decoder_stem='decoder',
             losses_stem='losses',
             config_stem='config',
             overwrite=False):
        """
        Saves model resources into a user-specified directory.

        The directory, including any missing parents, is created if it does
        not exist. Resources are written as follows:

        - the encoder and decoder via their own ``save`` methods;
        - the loss dictionary via ``pickle``;
        - the batch size and dataset size as JSON.

        Parameters
        ----------
        save_dir : str or pathlib.Path
            The directory where model resources will be saved.
        encoder_stem : str
            The name of the saved encoder file without file extension.
        decoder_stem : str
            The name of the saved decoder file without file extension.
        losses_stem : str
            The name of the saved loss dictionary without file extension.
        config_stem : str
            The name of the saved config file without file extension.
        overwrite : bool
            Whether to delete and recreate an already existing directory
            before saving. If False, the new files are written into the
            existing directory without clearing it first.
        """
        save_dir = check_path(save_dir, Path)
        if save_dir.exists() and overwrite:
            rmtree(save_dir)
        # Ensure the target directory exists. This covers both a first-time
        # save and a directory just removed above. Without it, the file
        # writes below fail for a non-existent save_dir.
        save_dir.mkdir(parents=True, exist_ok=True)

        encoder_file, decoder_file, loss_file, config_file = _process_stems(
            save_dir, encoder_stem, decoder_stem, losses_stem, config_stem)

        self.encoder.save(encoder_file)
        self.decoder.save(decoder_file)
        with open(loss_file, 'wb') as f:
            pkl.dump(self.loss_dict, f)
        config = {
            'batch_size': self.batch_size,
            'dataset_size': self.dataset_size
        }
        write_json(config_file, config)
Ejemplo n.º 6
0
 def __init__(self, model, model_dir, monitor):
     """Store the model, its output directory and the monitored quantity.

     ``best_value`` starts at positive infinity. Presumably this makes any
     first observed value count as an improvement, with lower assumed
     better; TODO confirm against the code that updates it.

     Parameters
     ----------
     model
         Stored as ``self.model_``.
     model_dir : str or pathlib.Path
         Normalised to a ``Path`` with ``check_path``.
     monitor
         Stored as ``self.monitor``.
     """
     super().__init__()
     self.best_value = float('inf')
     self.monitor = monitor
     self.model_dir = check_path(model_dir, Path)
     # Kept under ``model_`` rather than ``model``; presumably to avoid
     # clashing with an attribute managed by the base class. Confirm.
     self.model_ = model