Ejemplo n.º 1
0
    def __init__(self,
                 model,
                 period,
                 time_reference,
                 dirName,
                 sample_indices_by_dataset=None,
                 hop_legth_cqt=128,
                 dataset_keys=None,
                 target_dim=1):
        """Create the latent-PCA hook (logged as ``WavLatentPCAHook``).

        Output goes under ``dirName + '/pca_latent'``, which is stored on
        ``self._dir_name`` and forwarded to the parent constructor as ``dirName``.

        Args:
            model: model observed by the hook; its class name is recorded in
                ``self.model_class_name``.
            period: forwarded to the parent hook constructor.
            time_reference: forwarded to the parent hook constructor.
            dirName: base output directory; ``'/pca_latent'`` is appended.
            sample_indices_by_dataset: mapping from dataset key to a list of
                sample indices. Defaults to ``{VALIDATION: []}``.
            hop_legth_cqt: forwarded unchanged to the parent constructor
                (parameter name kept as-is for compatibility).
            dataset_keys: dataset keys forwarded to the parent constructor.
                Defaults to ``[TRAIN, VALIDATION]``.
            target_dim: accepted for interface compatibility.
                NOTE(review): it is neither used nor stored by this
                constructor — confirm whether that is intended.
        """
        # Build defaults per call: literal dict/list defaults are created once at
        # definition time and would be shared (and mutable) across all instances.
        if sample_indices_by_dataset is None:
            sample_indices_by_dataset = {VALIDATION: []}
        if dataset_keys is None:
            dataset_keys = [TRAIN, VALIDATION]

        # Assigned before the parent constructor runs, as in the original order.
        self._dir_name = dirName + '/pca_latent'
        super().__init__(model,
                         period,
                         time_reference,
                         dataset_keys=dataset_keys,
                         hop_legth_cqt=hop_legth_cqt,
                         dirName=self._dir_name)

        self.model_class_name = type(model).__name__

        self._sample_indices_by_dataset = sample_indices_by_dataset

        check_dataset_keys_not_loop(list(sample_indices_by_dataset.keys()))

        tf_logging.info("Create WavLatentPCAHook for: \n" +
                        "\n".join([ds_key + ": " + ", ".join(map(str, idxs))
                                   for ds_key, idxs in sample_indices_by_dataset.items()]))
Ejemplo n.º 2
0
    def load_images(self, session):
        """Fetch the input samples selected in ``self._images_indexes``.

        For every ``(ds_key, index_list)`` pair, the model input tensor
        ``self._model.x`` is evaluated through
        ``self._model.dataset.get_elements`` using that dataset's handle and
        initializer. The result is stored in ``self._images`` as
        ``{ds_key: (index_list, fetched_elements)}``.

        Args:
            session: session passed through to ``get_elements``
                (presumably a TensorFlow session — confirm with the caller).
        """
        check_dataset_keys_not_loop(list(self._images_indexes.keys()))

        fetched = {}
        for ds_key, index_list in self._images_indexes.items():
            elements = self._model.dataset.get_elements(
                self._model.x,
                self._ds_handle,
                self._ds_handles[ds_key],
                self._ds_initializers[ds_key],
                session,
                index_list,
            )
            fetched[ds_key] = (index_list, elements)

        # Assign only once everything was fetched, so a failure leaves the old value.
        self._images = fetched
    def load_labels(self, session):
        """Fetch labels for the samples in ``self._images_indexes``, at most once.

        This does nothing unless the hook is conditional (``self._conditional``
        is truthy) and no labels have been cached yet (``self._labels is None``).
        Otherwise it evaluates ``self._model.y`` per dataset key and stores
        ``{ds_key: (index_list, fetched_labels)}`` in ``self._labels``.

        Args:
            session: session passed through to ``get_elements``.
        """
        # Guard clause: equivalent to the positive "conditional and not cached" test.
        if not self._conditional or self._labels is not None:
            return

        check_dataset_keys_not_loop(list(self._images_indexes.keys()))

        fetched = {}
        for ds_key, index_list in self._images_indexes.items():
            fetched[ds_key] = (
                index_list,
                self._model.dataset.get_elements(
                    self._model.y,
                    self._ds_handle,
                    self._ds_handles[ds_key],
                    self._ds_initializers[ds_key],
                    session,
                    index_list,
                ),
            )

        self._labels = fetched
Ejemplo n.º 4
0
    def __init__(
            self,
            model,
            period,
            time_reference,
            dirName,
            sample_indices_by_dataset=None,
            fast_gen=True,
            debug_fast_gen=True,
            hop_legth_cqt=128,
            dataset_keys=None,
            save_wav=False,
            compute_reconstruction_metrics=True,
            _plot=True,
            generate_from_mean=True,
            spider_plot_time_splits=None,
            anomaly_detection_params=None):
        """Create the generation hook (logged as ``WavGenerateHook``).

        Args:
            model: model observed by the hook; its class name is recorded.
            period: forwarded to the parent hook constructor.
            time_reference: forwarded to the parent hook constructor.
            dirName: output directory; also the prefix of the reconstruction
                metric file names.
            sample_indices_by_dataset: mapping from dataset key to a list of
                sample indices. An empty list is logged as ``'all'``.
                Defaults to ``{VALIDATION: []}``.
            fast_gen: use fast_generation wavenet for reconstruction without
                teacher forcing (stored as ``bool``).
            debug_fast_gen: use fast_generation wavenet with the true input
                shifted and quantized to reconstruct with teacher forcing and
                check the FastGen network (stored as ``bool``).
            hop_legth_cqt: forwarded to the parent constructor.
            dataset_keys: forwarded to the parent constructor.
                Defaults to ``[TRAIN, VALIDATION]``.
            save_wav, _plot, generate_from_mean, spider_plot_time_splits,
            anomaly_detection_params: stored on the instance as given.
            compute_reconstruction_metrics: when truthy, per-dataset metric
                file names are set in ``self.reconstr_metrics_file_names``.
                NOTE(review): that attribute is not created otherwise.
        """
        # Per-call defaults: literal dict/list defaults would be shared by all instances.
        if sample_indices_by_dataset is None:
            sample_indices_by_dataset = {VALIDATION: []}
        if dataset_keys is None:
            dataset_keys = [TRAIN, VALIDATION]

        super().__init__(model,
                         period,
                         time_reference,
                         dataset_keys=dataset_keys,
                         hop_legth_cqt=hop_legth_cqt,
                         dirName=dirName)

        self.model_class_name = type(model).__name__
        self.save_wav = save_wav
        self.compute_reconstruction_metrics = compute_reconstruction_metrics
        self._plot = _plot
        self.generate_from_mean = generate_from_mean
        self.spider_plot_time_splits = spider_plot_time_splits

        self.anomaly_detection_params = anomaly_detection_params

        if compute_reconstruction_metrics:
            self.reconstr_metrics_file_names = {
                TRAIN: dirName + '/reconstr_metrics_x_train.txt',
                VALIDATION: dirName + '/reconstr_metrics_x_validation.txt',
                TEST: dirName + '/reconstr_metrics_x_test.txt',
            }

        self._sample_indices_by_dataset = sample_indices_by_dataset

        self._fast_gen = bool(fast_gen)
        self._debug_fast_gen = bool(debug_fast_gen)

        check_dataset_keys_not_loop(list(sample_indices_by_dataset.keys()))

        tf_logging.info("Create WavGenerateHook for: \n" +
                        "\n".join([ds_key + ": " + ", ".join(map(str, idxs or ['all']))
                                   for ds_key, idxs in sample_indices_by_dataset.items()]))
    def load_masks(self, session):
        """Fetch masks for the samples in ``self._images_indexes``, at most once.

        This does nothing if masks are already cached (``self._masks`` is not
        ``None``) or if the model has no mask tensor (``self._model.mask is None``).
        Otherwise it evaluates ``self._model.mask`` per dataset key and stores,
        for example::

            masks = {TRAIN: ([0, 100, 200, 300, 400, 500], train_masks),
                     VALIDATION: ([0, 100, 200, 300], validation_masks)}

        Args:
            session: session passed through to ``get_elements``.
        """
        # Identity test for None: `== None` would invoke the cached object's __eq__,
        # which is ambiguous (or elementwise) for array-like values.
        if self._masks is not None or self._model.mask is None:
            return

        check_dataset_keys_not_loop(list(self._images_indexes.keys()))

        masks = {
            ds_key: (index_list,
                     self._model.dataset.get_elements(self._model.mask,
                                                      self._ds_handle,
                                                      self._ds_handles[ds_key],
                                                      self._ds_initializers[ds_key],
                                                      session,
                                                      index_list))
            for ds_key, index_list in self._images_indexes.items()
        }

        self._masks = masks
    def load_images_once(self, session):
        """Fetch pairs of input samples for ``self._images_indexes``, at most once.

        Each value of ``self._images_indexes`` is a list of index couples.
        ``self._model.x`` is evaluated once for all first indices and once for
        all second indices, and the two results are zipped into pairs. The
        result is stored in ``self._images``, for example::

            images = {TRAIN: ([(0, 50), (100, 230), (200, 790)], list_of_couples_train_images),
                      VALIDATION: ([(0, 50), (100, 230)], list_of_couples_validation_images)}

        Nothing happens if ``self._images`` is already set.

        Args:
            session: session passed through to ``get_elements``.
        """
        # Identity test: `== None` would call __eq__ on whatever was cached.
        if self._images is not None:
            return

        check_dataset_keys_not_loop(list(self._images_indexes.keys()))

        images = {}
        for ds_key, couple_indices_list in self._images_indexes.items():
            first_images = self._model.dataset.get_elements(
                self._model.x, self._ds_handle, self._ds_handles[ds_key],
                self._ds_initializers[ds_key], session,
                [couple[0] for couple in couple_indices_list])
            second_images = self._model.dataset.get_elements(
                self._model.x, self._ds_handle, self._ds_handles[ds_key],
                self._ds_initializers[ds_key], session,
                [couple[1] for couple in couple_indices_list])
            images[ds_key] = (couple_indices_list,
                              list(zip(first_images, second_images)))

        self._images = images
Ejemplo n.º 7
0
    def __init__(self,
                 model,
                 period,
                 time_reference,
                 datasets_keys,
                 plot_offset,
                 tensorboard_dir=None,
                 trigger_summaries=False,
                 extra_feed_dict=None):
        """Initialise the common hook state.

        Args:
            model: model providing ``datasets_initializers``,
                ``datasets_handles_nodes`` and ``ds_handle``.
            period: step interval for the ``tf.train.SecondOrStepTimer``.
            time_reference: must be ``EPOCHS`` or ``STEPS``; its first two
                characters are kept as a short label.
            datasets_keys: dataset keys the hook works on, validated by
                ``check_dataset_keys_not_loop``.
            plot_offset: stored as ``self._plot_offset``.
            tensorboard_dir: directory for triggered summaries; required when
                ``trigger_summaries`` is truthy.
            trigger_summaries: whether summaries should be triggered.
            extra_feed_dict: extra feeds stored on the instance. Defaults to a
                fresh empty dict for each instance.

        Raises:
            ValueError: if ``time_reference`` is not ``EPOCHS``/``STEPS``, or if
                ``trigger_summaries`` is set without ``tensorboard_dir``.
        """
        time_choices = [EPOCHS, STEPS]
        if time_reference not in time_choices:
            raise ValueError("time reference attribute can be only in %s" %
                             time_choices)

        self._timer = tf.train.SecondOrStepTimer(every_secs=None,
                                                 every_steps=period)

        self._time_reference_str = time_reference
        self._time_ref_shortstr = self._time_reference_str[:2]

        self._model = model

        self._plot_offset = plot_offset
        # Fresh dict per instance: a literal `{}` default would be shared by all hooks.
        self._extra_feed_dict = {} if extra_feed_dict is None else extra_feed_dict
        # called in before_run
        self._nodes_to_be_computed_by_run = {}

        check_dataset_keys_not_loop(datasets_keys)
        self._datasets_keys = datasets_keys
        self._ds_initializers = model.datasets_initializers
        self._ds_handles_nodes = model.datasets_handles_nodes
        self._ds_handle = model.ds_handle

        # these needs to be defined in the child class
        self._tensors_names = None
        self._tensors_plots = None
        self._tensors_values = None

        self._trigger_summaries = trigger_summaries
        self._tensorboard_dir = tensorboard_dir
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self._trigger_summaries and self._tensorboard_dir is None:
            raise ValueError(
                "If you specified that you want to Trigger Summeries, you should also specify where to save them.")
        self.SUMMARIES_KEY = "log_mean_summaries"

        self._default_plot_bool = True
Ejemplo n.º 8
0
    def __init__(self,
                 model,
                 period,
                 time_reference,
                 dirName,
                 sample_indices_by_dataset=None,
                 hop_legth_cqt=128,
                 dataset_keys=None,
                 save_wav=False,
                 reconstruct_from_mean=True,
                 batch_size=20,
                 _plot=True,
                 spider_plot_time_splits=None,
                 anomaly_detection_params=None,
                 compute_reconstruction_metrics=True):
        """Create the reconstruction hook (logged as ``WavReconstructHook``).

        Args:
            model: forwarded to the parent hook constructor.
            period: forwarded to the parent hook constructor.
            time_reference: forwarded to the parent hook constructor.
            dirName: output directory; also the prefix of the reconstruction
                metric file names.
            sample_indices_by_dataset: mapping from dataset key to a list of
                sample indices. An empty list is logged as ``'all'``.
                Defaults to ``{VALIDATION: []}``.
            hop_legth_cqt: forwarded to the parent constructor.
            dataset_keys: forwarded to the parent constructor.
                Defaults to ``[TRAIN, VALIDATION]``.
            save_wav, reconstruct_from_mean, batch_size, _plot,
            spider_plot_time_splits, anomaly_detection_params: stored on the
                instance as given.
            compute_reconstruction_metrics: when truthy, per-dataset metric
                file names are set in ``self.reconstr_metrics_file_names``.
                NOTE(review): that attribute is not created otherwise.
        """
        # Per-call defaults: literal dict/list defaults would be shared by all instances.
        if sample_indices_by_dataset is None:
            sample_indices_by_dataset = {VALIDATION: []}
        if dataset_keys is None:
            dataset_keys = [TRAIN, VALIDATION]

        super().__init__(model,
                         period,
                         time_reference,
                         dataset_keys=dataset_keys,
                         hop_legth_cqt=hop_legth_cqt,
                         dirName=dirName)

        self._plot = _plot
        self.reconstruct_from_mean = reconstruct_from_mean
        self.save_wav = save_wav
        self.batch_size = batch_size
        self.spider_plot_time_splits = spider_plot_time_splits
        self._sample_indices_by_dataset = sample_indices_by_dataset

        self.compute_reconstruction_metrics = compute_reconstruction_metrics

        self.anomaly_detection_params = anomaly_detection_params

        if compute_reconstruction_metrics:
            self.reconstr_metrics_file_names = {
                TRAIN: dirName + '/reconstr_metrics_tf_train.txt',
                VALIDATION: dirName + '/reconstr_metrics_tf_validation.txt',
                TEST: dirName + '/reconstr_metrics_tf_test.txt',
            }

        check_dataset_keys_not_loop(list(sample_indices_by_dataset.keys()))

        tf_logging.info("Create WavReconstructHook for: \n" +
                        "\n".join([ds_key + ": " + ", ".join(map(str, idxs or ['all']))
                                   for ds_key, idxs in sample_indices_by_dataset.items()]))
    def __init__(self,
                 model,
                 period,
                 time_reference,
                 dirName,
                 sample_indices_by_dataset=None,
                 hop_legth_cqt=128,
                 dataset_keys=None
                 ):
        """Create the mean/variance visualization hook.

        The hook is logged as ``WavenetGaussianVisualizationHook``. Output goes
        under ``dirName + '/mean_variance_plots'``, which is forwarded to the
        parent constructor.

        Args:
            model: forwarded to the parent hook constructor.
            period: forwarded to the parent hook constructor.
            time_reference: forwarded to the parent hook constructor.
            dirName: base output directory; ``'/mean_variance_plots'`` is appended.
            sample_indices_by_dataset: mapping from dataset key to a list of
                sample indices. Defaults to ``{VALIDATION: []}``.
            hop_legth_cqt: forwarded to the parent constructor.
            dataset_keys: forwarded to the parent constructor.
                Defaults to ``[TRAIN, VALIDATION]``.
        """
        # Per-call defaults: literal dict/list defaults would be shared by all instances.
        if sample_indices_by_dataset is None:
            sample_indices_by_dataset = {VALIDATION: []}
        if dataset_keys is None:
            dataset_keys = [TRAIN, VALIDATION]

        _dirName = dirName + '/mean_variance_plots'
        super().__init__(model, period, time_reference, dataset_keys=dataset_keys, hop_legth_cqt=hop_legth_cqt,
                         dirName=_dirName)

        self._sample_indices_by_dataset = sample_indices_by_dataset

        check_dataset_keys_not_loop(list(sample_indices_by_dataset.keys()))

        tf_logging.info("Create WavenetGaussianVisualizationHook for: \n"
                        + "\n".join([ds_key + ": " + ", ".join(map(str, idxs))
                                     for ds_key, idxs in sample_indices_by_dataset.items()]))