def __init__(self, model, period, time_reference, dirName,
             sample_indices_by_dataset=None,
             hop_legth_cqt=128,
             dataset_keys=None,
             target_dim=1):
    """Create a hook that saves PCA projections of the latent space as wav-related plots.

    Args:
        model: the model whose latent representation is inspected.
        period: trigger period for the hook (in `time_reference` units).
        time_reference: time axis used by the parent hook (e.g. epochs or steps).
        dirName: base output directory; plots go to `dirName + '/pca_latent'`.
        sample_indices_by_dataset: dict mapping dataset key -> list of sample
            indices to project. Defaults to `{VALIDATION: []}`.
        hop_legth_cqt: CQT hop length forwarded to the parent hook.
            (NOTE(review): parameter name has a historical typo, kept for
            backward compatibility with existing callers.)
        dataset_keys: dataset splits handled by the parent hook.
            Defaults to `[TRAIN, VALIDATION]`.
        target_dim: kept for interface compatibility; not used in this body —
            presumably consumed elsewhere, TODO confirm.
    """
    # Fix: avoid mutable default arguments ({} / []), which are shared
    # across every instance of the class. Use None sentinels instead.
    if sample_indices_by_dataset is None:
        sample_indices_by_dataset = {VALIDATION: []}
    if dataset_keys is None:
        dataset_keys = [TRAIN, VALIDATION]

    self._dir_name = dirName + '/pca_latent'
    super().__init__(model, period, time_reference,
                     dataset_keys=dataset_keys,
                     hop_legth_cqt=hop_legth_cqt,
                     dirName=self._dir_name)

    self.model_class_name = type(model).__name__
    self._sample_indices_by_dataset = sample_indices_by_dataset
    # Guard against dataset keys that would cause an infinite-loop iterator.
    check_dataset_keys_not_loop(list(sample_indices_by_dataset.keys()))

    tf_logging.info("Create WavLatentPCAHook for: \n" + \
                    "\n".join([ds_key + ": " + ", ".join(map(str, idxs)) \
                               for ds_key, idxs in sample_indices_by_dataset.items()]))
def load_images(self, session):
    """Fetch the configured sample images for every dataset split.

    Populates `self._images` with a dict of the form
    `{ds_key: (index_list, elements)}`, where `elements` are the dataset
    entries for `self._model.x` at the requested indices.
    """
    check_dataset_keys_not_loop(list(self._images_indexes.keys()))
    loaded = {}
    for ds_key, index_list in self._images_indexes.items():
        elements = self._model.dataset.get_elements(self._model.x,
                                                    self._ds_handle,
                                                    self._ds_handles[ds_key],
                                                    self._ds_initializers[ds_key],
                                                    session,
                                                    index_list)
        loaded[ds_key] = (index_list, elements)
    self._images = loaded
def load_labels(self, session):
    """Lazily fetch labels for the configured sample indices.

    Only runs when the model is conditional and labels were not loaded
    before; fills `self._labels` with `{ds_key: (index_list, elements)}`
    mirroring the structure used for images.
    """
    # Guard clause: nothing to do for unconditional models or when cached.
    if not self._conditional or self._labels is not None:
        return

    check_dataset_keys_not_loop(list(self._images_indexes.keys()))
    loaded = {}
    for ds_key, index_list in self._images_indexes.items():
        elements = self._model.dataset.get_elements(self._model.y,
                                                    self._ds_handle,
                                                    self._ds_handles[ds_key],
                                                    self._ds_initializers[ds_key],
                                                    session,
                                                    index_list)
        loaded[ds_key] = (index_list, elements)
    self._labels = loaded
def __init__(self, model, period, time_reference, dirName,
             sample_indices_by_dataset=None,
             fast_gen=True,        # use fast_generation wavenet for reconstruction without teacher forcing
             debug_fast_gen=True,  # use fast_generation wavenet with the true input shifted and quantized
                                   # to reconstruct with teacher forcing and check the FastGen network
             hop_legth_cqt=128,
             dataset_keys=None,
             save_wav=False,
             compute_reconstruction_metrics=True,
             _plot=True,
             generate_from_mean=True,
             spider_plot_time_splits=None,
             anomaly_detection_params=None):
    """Create a hook that generates audio from the model and logs reconstruction metrics.

    Args:
        model: the generative model used for wav generation.
        period: trigger period for the hook (in `time_reference` units).
        time_reference: time axis used by the parent hook.
        dirName: output directory for metrics files and plots.
        sample_indices_by_dataset: dict of dataset key -> sample indices;
            an empty list means 'all'. Defaults to `{VALIDATION: []}`.
        fast_gen: enable fast-generation reconstruction (no teacher forcing).
        debug_fast_gen: enable teacher-forced fast-generation sanity check.
        hop_legth_cqt: CQT hop length forwarded to the parent hook
            (name typo kept for backward compatibility).
        dataset_keys: dataset splits handled by the parent hook.
            Defaults to `[TRAIN, VALIDATION]`.
        save_wav: whether to write generated audio to disk.
        compute_reconstruction_metrics: whether to write per-split metric files.
        _plot: whether to produce plots.
        generate_from_mean: generate from the latent mean rather than a sample.
        spider_plot_time_splits: optional time splits for spider plots.
        anomaly_detection_params: optional anomaly-detection configuration.
    """
    # Fix: avoid mutable default arguments ({} / []), which are shared
    # across every instance of the class. Use None sentinels instead.
    if sample_indices_by_dataset is None:
        sample_indices_by_dataset = {VALIDATION: []}
    if dataset_keys is None:
        dataset_keys = [TRAIN, VALIDATION]

    super().__init__(model, period, time_reference,
                     dataset_keys=dataset_keys,
                     hop_legth_cqt=hop_legth_cqt,
                     dirName=dirName)

    self.model_class_name = type(model).__name__
    self.save_wav = save_wav
    self.compute_reconstruction_metrics = compute_reconstruction_metrics
    self._plot = _plot
    self.generate_from_mean = generate_from_mean
    self.spider_plot_time_splits = spider_plot_time_splits
    self.anomaly_detection_params = anomaly_detection_params

    if compute_reconstruction_metrics:
        self.reconstr_metrics_file_names = {
            TRAIN: dirName + '/reconstr_metrics_x_train.txt',
            VALIDATION: dirName + '/reconstr_metrics_x_validation.txt',
            TEST: dirName + '/reconstr_metrics_x_test.txt',
        }

    self._sample_indices_by_dataset = sample_indices_by_dataset
    self._fast_gen = bool(fast_gen)
    self._debug_fast_gen = bool(debug_fast_gen)
    # Guard against dataset keys that would cause an infinite-loop iterator.
    check_dataset_keys_not_loop(list(sample_indices_by_dataset.keys()))

    tf_logging.info("Create WavGenerateHook for: \n" + \
                    "\n".join([ds_key + ": " + ", ".join(map(str, idxs or ['all'])) \
                               for ds_key, idxs in sample_indices_by_dataset.items()]))
def load_masks(self, session):
    """Lazily fetch masks for the configured sample indices.

    Only runs when masks were not loaded before and the model actually
    defines a mask tensor; fills `self._masks` with
    `{ds_key: (index_list, elements)}`.
    """
    # Fix: identity comparison with None must use `is None`, not `== None`
    # (`==` can be overloaded and is non-idiomatic; PEP 8).
    if self._masks is None and self._model.mask is not None:
        check_dataset_keys_not_loop(list(self._images_indexes.keys()))
        masks = {ds_key: (index_list,
                          self._model.dataset.get_elements(self._model.mask,
                                                           self._ds_handle,
                                                           self._ds_handles[ds_key],
                                                           self._ds_initializers[ds_key],
                                                           session,
                                                           index_list)) \
                 for (ds_key, index_list) in self._images_indexes.items()}
        # I set something like the following structure, e.g.
        # images = {TRAIN : ([0,100,200,300,400,500], train_images),
        #           VALIDATION : ([0,100,200,300], validation_images),
        #          }
        self._masks = masks
def load_images_once(self, session):
    """Lazily fetch *pairs* of images for the configured couple indices.

    For each dataset split, `self._images_indexes[ds_key]` is a list of
    (first_index, second_index) couples; this loads both elements of every
    couple and stores `self._images` as
    `{ds_key: (couple_indices_list, list_of_image_pairs)}`.
    """
    # Fix: identity comparison with None must use `is None`, not `== None`.
    if self._images is None:
        check_dataset_keys_not_loop(list(self._images_indexes.keys()))
        images = {}
        for ds_key, couple_indices_list in self._images_indexes.items():
            # Fetch the first and second element of every couple in two
            # batched dataset reads, then pair them back up.
            first_elements = self._model.dataset.get_elements(
                self._model.x,
                self._ds_handle,
                self._ds_handles[ds_key],
                self._ds_initializers[ds_key],
                session,
                [i[0] for i in couple_indices_list])
            second_elements = self._model.dataset.get_elements(
                self._model.x,
                self._ds_handle,
                self._ds_handles[ds_key],
                self._ds_initializers[ds_key],
                session,
                [i[1] for i in couple_indices_list])
            # list(zip(...)) is equivalent to the original pair comprehension.
            images[ds_key] = (couple_indices_list,
                              list(zip(first_elements, second_elements)))
        # I set something like the following structure, e.g.
        # images = {TRAIN : ([(0,50),(100,230),(200,790),(300,600),(400,1000),(500,10)], list_of_couples_train_images,
        #           VALIDATION : ([(0,50),(100,230),(200,790),(300,600),(400,1000),(500,10)], list_of_couples_validation_images,
        #          }
        self._images = images
def __init__(self, model, period, time_reference, datasets_keys, plot_offset,
             tensorboard_dir=None, trigger_summaries=False, extra_feed_dict=None):
    """Base initialization for periodic dataset-logging hooks.

    Args:
        model: model exposing dataset initializers/handles (`datasets_initializers`,
            `datasets_handles_nodes`, `ds_handle`).
        period: trigger period, interpreted in `time_reference` units.
        time_reference: either EPOCHS or STEPS; anything else raises ValueError.
        datasets_keys: dataset splits this hook iterates over.
        plot_offset: offset applied to plots (semantics defined by subclasses).
        tensorboard_dir: where to write summaries; required if
            `trigger_summaries` is True.
        trigger_summaries: whether the hook also emits TF summaries.
        extra_feed_dict: extra feeds merged into session.run calls
            (used in before_run). Defaults to an empty dict.

    Raises:
        ValueError: if `time_reference` is not one of EPOCHS/STEPS.
    """
    # Fix: avoid the mutable default argument `extra_feed_dict={}`, which is
    # shared across all instances of the class.
    if extra_feed_dict is None:
        extra_feed_dict = {}

    time_choices = [EPOCHS, STEPS]
    # Fix: idiomatic membership test (`not in` instead of `not x in`).
    if time_reference not in time_choices:
        raise ValueError("time reference attribute can be only in %s" % time_choices)

    self._timer = tf.train.SecondOrStepTimer(every_secs=None, every_steps=period)
    self._time_reference_str = time_reference
    self._time_ref_shortstr = self._time_reference_str[:2]
    self._model = model
    self._plot_offset = plot_offset

    self._extra_feed_dict = extra_feed_dict

    # called in before_run
    self._nodes_to_be_computed_by_run = {}

    check_dataset_keys_not_loop(datasets_keys)
    self._datasets_keys = datasets_keys

    self._ds_initializers = model.datasets_initializers
    self._ds_handles_nodes = model.datasets_handles_nodes
    self._ds_handle = model.ds_handle

    # these need to be defined in the child class
    self._tensors_names = None
    self._tensors_plots = None
    self._tensors_values = None

    self._trigger_summaries = trigger_summaries
    self._tensorboard_dir = tensorboard_dir
    # Fix: corrected "Summeries" typo in the user-facing assertion message.
    assert not self._trigger_summaries or self._tensorboard_dir is not None, \
        "If you specified that you want to Trigger Summaries, you should also specify where to save them."
    self.SUMMARIES_KEY = "log_mean_summaries"

    self._default_plot_bool = True
def __init__(self, model, period, time_reference, dirName,
             sample_indices_by_dataset=None,
             hop_legth_cqt=128,
             dataset_keys=None,
             save_wav=False,
             reconstruct_from_mean=True,
             batch_size=20,
             _plot=True,
             spider_plot_time_splits=None,
             anomaly_detection_params=None,
             compute_reconstruction_metrics=True):
    """Create a hook that reconstructs wavs (teacher-forced) and logs metrics.

    Args:
        model: the model used for reconstruction.
        period: trigger period for the hook (in `time_reference` units).
        time_reference: time axis used by the parent hook.
        dirName: output directory for metric files and plots.
        sample_indices_by_dataset: dict of dataset key -> sample indices;
            an empty list means 'all'. Defaults to `{VALIDATION: []}`.
        hop_legth_cqt: CQT hop length forwarded to the parent hook
            (name typo kept for backward compatibility).
        dataset_keys: dataset splits handled by the parent hook.
            Defaults to `[TRAIN, VALIDATION]`.
        save_wav: whether to write reconstructed audio to disk.
        reconstruct_from_mean: reconstruct from the latent mean rather
            than a sample.
        batch_size: batch size used for reconstruction.
        _plot: whether to produce plots.
        spider_plot_time_splits: optional time splits for spider plots.
        anomaly_detection_params: optional anomaly-detection configuration.
        compute_reconstruction_metrics: whether to write per-split metric files.
    """
    # Fix: avoid mutable default arguments ({} / []), which are shared
    # across every instance of the class. Use None sentinels instead.
    if sample_indices_by_dataset is None:
        sample_indices_by_dataset = {VALIDATION: []}
    if dataset_keys is None:
        dataset_keys = [TRAIN, VALIDATION]

    super().__init__(model, period, time_reference,
                     dataset_keys=dataset_keys,
                     hop_legth_cqt=hop_legth_cqt,
                     dirName=dirName)

    self._plot = _plot
    self.reconstruct_from_mean = reconstruct_from_mean
    self.save_wav = save_wav
    self.batch_size = batch_size
    self.spider_plot_time_splits = spider_plot_time_splits
    self._sample_indices_by_dataset = sample_indices_by_dataset
    self.compute_reconstruction_metrics = compute_reconstruction_metrics
    self.anomaly_detection_params = anomaly_detection_params

    if compute_reconstruction_metrics:
        self.reconstr_metrics_file_names = {
            TRAIN: dirName + '/reconstr_metrics_tf_train.txt',
            VALIDATION: dirName + '/reconstr_metrics_tf_validation.txt',
            TEST: dirName + '/reconstr_metrics_tf_test.txt',
        }

    # Guard against dataset keys that would cause an infinite-loop iterator.
    check_dataset_keys_not_loop(list(sample_indices_by_dataset.keys()))

    tf_logging.info("Create WavReconstructHook for: \n" + \
                    "\n".join([ds_key + ": " + ", ".join(map(str, idxs or ['all'])) \
                               for ds_key, idxs in sample_indices_by_dataset.items()]))
def __init__(self, model, period, time_reference, dirName,
             sample_indices_by_dataset=None,
             hop_legth_cqt=128,
             dataset_keys=None):
    """Create a hook that plots per-sample mean/variance of the Wavenet Gaussian output.

    Args:
        model: the model whose Gaussian output statistics are visualized.
        period: trigger period for the hook (in `time_reference` units).
        time_reference: time axis used by the parent hook.
        dirName: base output directory; plots go to
            `dirName + '/mean_variance_plots'`.
        sample_indices_by_dataset: dict of dataset key -> sample indices.
            Defaults to `{VALIDATION: []}`.
        hop_legth_cqt: CQT hop length forwarded to the parent hook
            (name typo kept for backward compatibility).
        dataset_keys: dataset splits handled by the parent hook.
            Defaults to `[TRAIN, VALIDATION]`.
    """
    # Fix: avoid mutable default arguments ({} / []), which are shared
    # across every instance of the class. Use None sentinels instead.
    if sample_indices_by_dataset is None:
        sample_indices_by_dataset = {VALIDATION: []}
    if dataset_keys is None:
        dataset_keys = [TRAIN, VALIDATION]

    _dirName = dirName + '/mean_variance_plots'
    super().__init__(model, period, time_reference,
                     dataset_keys=dataset_keys,
                     hop_legth_cqt=hop_legth_cqt,
                     dirName=_dirName)

    self._sample_indices_by_dataset = sample_indices_by_dataset
    # Guard against dataset keys that would cause an infinite-loop iterator.
    check_dataset_keys_not_loop(list(sample_indices_by_dataset.keys()))

    tf_logging.info("Create WavenetGaussianVisualizationHook for: \n" +
                    "\n".join([ds_key + ": " + ", ".join(map(str, idxs))
                               for ds_key, idxs in sample_indices_by_dataset.items()]))