def _process_stems(save_dir, encoder_stem, decoder_stem, losses_stem,
                   config_stem):
    """Build the full file names of the saved model artifacts.

    Parameters
    ----------
    save_dir : str or pathlib.Path
        Directory that will contain every artifact.
    encoder_stem, decoder_stem, losses_stem, config_stem : str or pathlib.Path
        File names of the individual artifacts, without extension.

    Returns
    -------
    tuple
        ``(encoder_file, decoder_file, loss_file, config_file)``, using the
        suffixes ``.h5``, ``.h5``, ``.dict`` and ``.json`` respectively, each
        converted by ``check_path(..., path_type=str)``.
    """
    # check_path (defined elsewhere) is assumed to coerce every entry to
    # path_type -- TODO confirm against its definition.
    directory, enc, dec, losses, cfg = check_path(
        [save_dir, encoder_stem, decoder_stem, losses_stem, config_stem],
        path_type=Path)
    stem_suffix_pairs = (
        (enc, '.h5'),
        (dec, '.h5'),
        (losses, '.dict'),
        (cfg, '.json'),
    )
    full_paths = [directory / stem.with_suffix(ext)
                  for stem, ext in stem_suffix_pairs]
    encoder_file, decoder_file, loss_file, config_file = check_path(
        full_paths, path_type=str)
    return encoder_file, decoder_file, loss_file, config_file
def __init__(self, images, output_dir, stem='reconstruction', fmt='jpg',
             batch_size=32):
    """Store the settings for writing image reconstructions.

    Parameters
    ----------
    images
        Images to reconstruct; stored unchanged on ``self.images``.
    output_dir : str or pathlib.Path
        Output directory, normalised through ``check_path`` to a ``Path``.
    stem : str
        Stored as ``self.stem``; presumably the base name of the output
        files -- confirm against the method that writes them.
    fmt : str
        File extension, with or without a leading dot (``'jpg'`` and
        ``'.jpg'`` are equivalent). Stored with a leading dot.
    batch_size : int
        Stored as ``self.batch_size``.

    Raises
    ------
    ValueError
        If ``fmt`` is an empty string.
    """
    super().__init__()
    # An empty extension used to fail with an opaque IndexError on fmt[0].
    if not fmt:
        raise ValueError('fmt must be a non-empty file extension.')
    self.images = images
    self.output_dir = check_path(output_dir, Path)
    self.stem = stem
    # Normalise to a leading-dot extension, e.g. 'jpg' -> '.jpg'.
    self.format = fmt if fmt.startswith('.') else '.' + fmt
    # Left unset here; presumably assigned later from the model -- confirm.
    self.model_predict = None
    self.batch_size = batch_size
def __init__(self, csv_file, data, batch_size=512, verbose=False,
             overwrite=False):
    """Store the configuration of a callback that works with a CSV file.

    Parameters
    ----------
    csv_file : str or pathlib.Path
        Target CSV file, normalised through ``check_path`` to a ``Path``.
    data
        Input data; stored unchanged on ``self.data``.
    batch_size : int
        Stored as ``self.batch_size``.
    verbose : bool
        Stored as ``self.verbose``.
    overwrite : bool
        Stored as ``self.overwrite``.
    """
    super().__init__()
    self.csv_file = check_path(csv_file, path_type=Path)
    self.data = data
    self.batch_size = batch_size
    self.verbose = verbose
    self.overwrite = overwrite
    # Runtime state, all unset at construction; presumably filled in once
    # the encoder / latent statistics are available -- confirm.
    self.encoder = self.num_latents = self.csv_header = None
    self.batch_means = self.batch_stds = None
def __init__(self, image, output_dir, stem='traversal',
             traversal_range=(-4, 4), traversal_resolution=10):
    """Store the settings for a traversal check on a single image.

    Parameters
    ----------
    image : numpy.ndarray
        A single image. A 4-D array must have a leading axis of length 1;
        any other rank gets a new leading axis of length 1 added via
        ``np.expand_dims(image, axis=0)``.
    output_dir : str or pathlib.Path
        Output directory, normalised through ``check_path`` to a ``Path``.
    stem : str
        Stored as ``self.stem``; presumably the base name of the output
        files -- confirm.
    traversal_range : tuple
        Stored as ``self.traversal_range``; default ``(-4, 4)``.
    traversal_resolution : int
        Stored as ``self.traversal_resolution``; default ``10``.

    Raises
    ------
    ValueError
        If ``image`` is 4-D and holds more than one image.
    """
    super().__init__()
    if image.ndim == 4:
        # Explicit check rather than ``assert``, which is stripped under
        # ``python -O`` and would then let a whole batch through silently.
        if image.shape[0] != 1:
            raise ValueError(
                'Traversal check can only be performed on a single image.')
    else:
        image = np.expand_dims(image, axis=0)
    self.image = image
    self.output_dir = check_path(output_dir, Path)
    self.stem = stem
    # Left unset here; presumably assigned later from the model -- confirm.
    self.tcvae = None
    self.traversal_range = traversal_range
    self.traversal_resolution = traversal_resolution
def save(self, save_dir, encoder_stem='encoder', decoder_stem='decoder',
         losses_stem='losses', config_stem='config', overwrite=False):
    """
    Saves model resources into a user-specified directory.

    Writes the encoder and decoder via their ``save`` methods, pickles
    ``self.loss_dict``, and writes a JSON config holding ``batch_size`` and
    ``dataset_size``. File extensions are added by ``_process_stems``.

    Parameters
    ----------
    save_dir : str or pathlib.Path
        The directory where model resources will be saved. Missing parent
        directories are created.
    encoder_stem : str
        The name of the saved encoder file without file extension.
    decoder_stem : str
        The name of the saved decoder file without file extension.
    losses_stem : str
        The name of the saved loss dictionary without file extension.
    config_stem : str
        The name of the saved JSON config file without file extension.
    overwrite : bool
        Whether to delete and recreate an already existing directory.

    Raises
    ------
    FileExistsError
        If ``save_dir`` already exists and ``overwrite`` is False.
    """
    save_dir = check_path(save_dir, Path)
    if save_dir.exists():
        if not overwrite:
            # Same exception type mkdir() would raise, but with guidance.
            raise FileExistsError(
                f'{save_dir} already exists; pass overwrite=True to '
                f'replace it.')
        # Existing contents are discarded before the new files are written.
        rmtree(save_dir)
    save_dir.mkdir(parents=True)
    encoder_file, decoder_file, loss_file, config_file = _process_stems(
        save_dir, encoder_stem, decoder_stem, losses_stem, config_stem)
    self.encoder.save(encoder_file)
    self.decoder.save(decoder_file)
    with open(loss_file, 'wb') as f:
        pkl.dump(self.loss_dict, f)
    config = {
        'batch_size': self.batch_size,
        'dataset_size': self.dataset_size,
    }
    write_json(config_file, config)
def __init__(self, model, model_dir, monitor):
    """Store the model, its directory and the quantity being monitored.

    Parameters
    ----------
    model
        Stored as ``self.model_``. NOTE(review): the trailing underscore
        presumably avoids clashing with a framework-managed ``model``
        attribute -- confirm.
    model_dir : str or pathlib.Path
        Directory normalised through ``check_path`` to a ``Path``.
    monitor
        Stored unchanged as ``self.monitor``.
    """
    super().__init__()
    resolved_dir = check_path(model_dir, Path)
    self.model_, self.monitor = model, monitor
    self.model_dir = resolved_dir
    # Starts at +inf; presumably the monitored value is minimised, so any
    # first observation counts as an improvement -- confirm.
    self.best_value = float('inf')