def test_25d_init(self):
    """Resize-sample a 2.5D volume and aggregate the windows back to an image.

    Checks that the aggregated NIfTI file on disk has the expected
    (reordered) spatial shape of the sampled window.
    """
    reader = get_25d_reader()
    sampler = ResizeSampler(reader=reader,
                            window_sizes=SINGLE_25D_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    aggregator = WindowAsImageAggregator(
        image_reader=reader,
        output_path=os.path.join('testing_data', 'aggregated_identity'))
    out_shape = []
    with self.cached_session() as sess:
        sampler.set_num_threads(2)
        more_batch = True
        while more_batch:
            try:
                out = sess.run(sampler.pop_batch_op())
            except tf.errors.OutOfRangeError:
                # sampler queue exhausted
                break
            # drop the batch axis, append a channel axis
            out_shape = out['image'].shape[1:] + (1, )
            more_batch = aggregator.decode_batch(
                {'window_image': out['image']}, out['image_location'])
    subject_id = sampler.reader.get_subject_id(0)
    output_file = os.path.join(
        'testing_data', 'aggregated_identity',
        '{}_window_image_niftynet_generated.nii.gz'.format(subject_id))
    # reorder spatial axes and append the channel dimension
    out_shape = [out_shape[i] for i in NEW_ORDER_2D] + [1]
    self.assertAllClose(nib.load(output_file).shape, out_shape[:2])
    sampler.close_all()
    def test_init_2d_mo_bidimcsv(self):
        """Resize-sample 2D multimodal data; aggregate an image plus csv outputs.

        Verifies the aggregated image shape and the shapes of the generated
        per-subject csv files (a scalar sum and a 2x3 stats block).
        """
        reader = get_2d_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=MOD_2D_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = WindowAsImageAggregator(
            image_reader=reader,
            output_path=os.path.join('testing_data', 'aggregated_identity'))
        out_shape = []
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            more_batch = True
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    # sampler queue exhausted
                    break
                img = out['image']
                # drop the batch axis, append a channel axis
                out_shape = img.shape[1:] + (1, )
                total = np.sum(np.asarray(img).flatten())
                # one row of (min, max, sum), duplicated to a 2x3 block
                row = [np.min(img), np.max(img), np.sum(img)]
                stats_val = np.stack([row, row], axis=0)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': img,
                        'csv_sum': total,
                        'csv_stats2d': stats_val
                    }, out['image_location'])
        subject_id = sampler.reader.get_subject_id(0)
        out_dir = os.path.join('testing_data', 'aggregated_identity')
        output_file = os.path.join(
            out_dir,
            '{}_window_image_niftynet_generated.nii.gz'.format(subject_id))
        sum_filename = os.path.join(
            out_dir,
            '{}_csv_sum_niftynet_generated.csv'.format(subject_id))
        stats_filename = os.path.join(
            out_dir,
            '{}_csv_stats2d_niftynet_generated.csv'.format(subject_id))

        # reorder spatial axes and append the channel dimension
        out_shape = [out_shape[i] for i in NEW_ORDER_2D] + [1]
        self.assertAllClose(nib.load(output_file).shape, out_shape[:2])
        min_pd = pd.read_csv(sum_filename)
        self.assertAllClose(min_pd.shape, [1, 2])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [1, 7])
        sampler.close_all()
Ejemplo n.º 3
0
 def initialise_resize_aggregator(self):
     """Create the resize-samples aggregator used to assemble inference output."""
     params = self.action_param
     self.output_decoder = ResizeSamplesAggregator(
         image_reader=self.readers[0],
         output_path=params.save_seg_dir,
         window_border=params.border,
         interp_order=params.output_interp_order,
         postfix=params.output_postfix)
Ejemplo n.º 4
0
 def initialise_grid_aggregator(self):
     """Create the grid-samples aggregator used to assemble inference output."""
     params = self.action_param
     self.output_decoder = GridSamplesAggregator(
         image_reader=self.readers[0],
         output_path=params.save_seg_dir,
         window_border=params.border,
         interp_order=params.output_interp_order,
         postfix=params.output_postfix,
         fill_constant=params.fill_constant)
Ejemplo n.º 5
0
class AutoencoderApplication(BaseApplication):
    """NiftyNet application driving training and inference of an autoencoder.

    Inference supports four modes (validated against SUPPORTED_INFERENCE):
    'encode', 'encode-decode', 'sample' and 'linear_interpolation'.
    """
    REQUIRED_CONFIG_SECTION = "AUTOENCODER"

    def __init__(self, net_param, action_param, action):
        """Store configuration; readers and samplers are built later.

        :param net_param: network configuration parameters
        :param action_param: parameters of the requested action
        :param action: the requested action (e.g. train/inference)
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting autoencoder application')

        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        # populated by initialise_dataset_loader
        self.data_param = None
        self.autoencoder_param = None

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create image readers for training or for the selected inference type.

        :param data_param: input-source configuration
        :param task_param: task parameters (provides ``inference_type``)
        :param data_partitioner: supplies the per-phase file lists
        :raises NotImplementedError: when evaluation is requested
        """
        self.data_param = data_param
        self.autoencoder_param = task_param

        # inference type is only meaningful (and validated) outside training
        if not self.is_training:
            self._infer_type = look_up_operations(
                self.autoencoder_param.inference_type, SUPPORTED_INFERENCE)
        else:
            self._infer_type = None

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_evaluation:
            # BUG FIX: the exception was previously created but never
            # raised, so evaluation silently fell through
            raise NotImplementedError('Evaluation is not yet '
                                      'supported in this application.')
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        # when training, _infer_type is None so the branches below are skipped
        if self._infer_type in ('encode', 'encode-decode'):
            self.readers = [ImageReader(['image'])]
            self.readers[0].initialise(data_param,
                                       task_param,
                                       file_lists[0])
        elif self._infer_type == 'sample':
            # sampling from the prior requires no input images
            self.readers = []
        elif self._infer_type == 'linear_interpolation':
            # interpolation operates on pre-computed feature vectors
            self.readers = [ImageReader(['feature'])]
            # BUG FIX: previously passed ``[file_lists]`` (a doubly nested
            # list); use the first file list, as in the branch above
            self.readers[0].initialise(data_param,
                                       task_param,
                                       file_lists[0])
        # if self.is_training or self._infer_type in ('encode', 'encode-decode'):
        #    mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        #    self.reader.add_preprocessing_layers([mean_var_normaliser])

    def initialise_sampler(self):
        """Build one sampler per reader, depending on the action/inference type.

        ``self.sampler`` is a list containing a single list of samplers
        (one per reader); 'sample' mode needs no sampler at all.
        """
        self.sampler = []
        if self.is_training:
            # shuffled resize sampler for training
            self.sampler.append([ResizeSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=1,
                shuffle_buffer=True,
                queue_length=self.net_param.queue_length) for reader in
                self.readers])
            return
        if self._infer_type in ('encode', 'encode-decode'):
            # deterministic order (no shuffling) for inference
            self.sampler.append([ResizeSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=1,
                shuffle_buffer=False,
                queue_length=self.net_param.queue_length) for reader in
                self.readers])
            return
        if self._infer_type == 'linear_interpolation':
            self.sampler.append([LinearInterpolateSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                n_interpolations=self.autoencoder_param.n_interpolations,
                queue_length=self.net_param.queue_length) for reader in
                self.readers])
            return

    def initialise_network(self):
        """Instantiate the network, with optional L1/L2 weight regularisation."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire samplers into the network graph for training or inference.

        Training builds the loss/optimiser and registers console and
        TF-summary outputs; inference builds the graph for the selected
        inference type and creates the matching window aggregator.
        """
        def switch_sampler(for_training):
            # train sampler is at index 0, validation sampler at index -1
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                # choose train/validation batches at run time
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_output = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            data_loss = loss_func(net_output)
            loss = data_loss
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                    loss = loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            outputs_collector.add_to_collection(
                var=data_loss, name='variational_lower_bound',
                average_over_devices=True, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='variational_lower_bound',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)

            # NOTE(review): net_output index meanings (4=originals, 2=means,
            # 5=variances, 7=embedding) depend on the network implementation
            # -- confirm against the chosen net
            outputs_collector.add_to_collection(
                var=net_output[4], name='Originals',
                average_over_devices=False, summary_type='image3_coronal',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=net_output[2], name='Means',
                average_over_devices=False, summary_type='image3_coronal',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=net_output[5], name='Variances',
                average_over_devices=False, summary_type='image3_coronal',
                collection=TF_SUMMARIES)
        else:
            if self._infer_type in ('encode', 'encode-decode'):
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                image = tf.cast(data_dict['image'], dtype=tf.float32)
                net_output = self.net(image, is_training=False)

                outputs_collector.add_to_collection(
                    var=data_dict['image_location'], name='location',
                    average_over_devices=True, collection=NETWORK_OUTPUT)

                if self._infer_type == 'encode-decode':
                    outputs_collector.add_to_collection(
                        var=net_output[2], name='generated_image',
                        average_over_devices=True, collection=NETWORK_OUTPUT)
                if self._infer_type == 'encode':
                    outputs_collector.add_to_collection(
                        var=net_output[7], name='embedded',
                        average_over_devices=True, collection=NETWORK_OUTPUT)

                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'sample':
                # decode random codes drawn from a Gaussian prior
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                noise_shape = net_output[-1].shape.as_list()
                stddev = self.autoencoder_param.noise_stddev
                noise = tf.random_normal(shape=noise_shape,
                                         mean=0.0,
                                         stddev=stddev,
                                         dtype=tf.float32)
                partially_decoded_sample = self.net.shared_decoder(
                    noise, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)

                outputs_collector.add_to_collection(
                    var=decoder_output, name='generated_image',
                    average_over_devices=True, collection=NETWORK_OUTPUT)
                # no reader: generated samples have no source subject
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=None,
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'linear_interpolation':
                # construct the entire network, then decode sampled codes
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                real_code = data_dict['feature']
                real_code = tf.reshape(real_code, net_output[-1].get_shape())
                partially_decoded_sample = self.net.shared_decoder(
                    real_code, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)

                outputs_collector.add_to_collection(
                    var=decoder_output, name='generated_image',
                    average_over_devices=True, collection=NETWORK_OUTPUT)
                outputs_collector.add_to_collection(
                    var=data_dict['feature_location'], name='location',
                    average_over_devices=True, collection=NETWORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
            else:
                raise NotImplementedError

    def interpret_output(self, batch_output):
        """Forward network outputs to the aggregator.

        :param batch_output: dict of fetched NETWORK_OUTPUT tensors
        :return: True to continue, or the aggregator's continue flag
        """
        if self.is_training:
            return True
        else:
            infer_type = look_up_operations(
                self.autoencoder_param.inference_type,
                SUPPORTED_INFERENCE)
            if infer_type == 'encode':
                return self.output_decoder.decode_batch(
                    batch_output['embedded'],
                    batch_output['location'][:, 0:1])
            if infer_type == 'encode-decode':
                return self.output_decoder.decode_batch(
                    batch_output['generated_image'],
                    batch_output['location'][:, 0:1])
            if infer_type == 'sample':
                return self.output_decoder.decode_batch(
                    batch_output['generated_image'],
                    None)
            if infer_type == 'linear_interpolation':
                # first two location columns identify the interpolation pair
                return self.output_decoder.decode_batch(
                    batch_output['generated_image'],
                    batch_output['location'][:, :2])
Ejemplo n.º 6
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire samplers into the network graph for training or inference.

        Training builds the loss/optimiser and registers console and
        TF-summary outputs; inference builds the graph for the selected
        inference type ('encode', 'encode-decode', 'sample',
        'linear_interpolation') and creates the matching window aggregator.
        """
        def switch_sampler(for_training):
            # train sampler is at index 0, validation sampler at index -1
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                # choose train/validation batches at run time
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_output = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            data_loss = loss_func(net_output)
            loss = data_loss
            # add the mean of any registered regularisation losses
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                    loss = loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            outputs_collector.add_to_collection(
                var=data_loss, name='variational_lower_bound',
                average_over_devices=True, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='variational_lower_bound',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)

            # NOTE(review): net_output index meanings (4=originals, 2=means,
            # 5=variances, 7=embedding) depend on the network implementation
            # -- confirm against the chosen net
            outputs_collector.add_to_collection(
                var=net_output[4], name='Originals',
                average_over_devices=False, summary_type='image3_coronal',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=net_output[2], name='Means',
                average_over_devices=False, summary_type='image3_coronal',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=net_output[5], name='Variances',
                average_over_devices=False, summary_type='image3_coronal',
                collection=TF_SUMMARIES)
        else:
            if self._infer_type in ('encode', 'encode-decode'):
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                image = tf.cast(data_dict['image'], dtype=tf.float32)
                net_output = self.net(image, is_training=False)

                outputs_collector.add_to_collection(
                    var=data_dict['image_location'], name='location',
                    average_over_devices=True, collection=NETWORK_OUTPUT)

                if self._infer_type == 'encode-decode':
                    outputs_collector.add_to_collection(
                        var=net_output[2], name='generated_image',
                        average_over_devices=True, collection=NETWORK_OUTPUT)
                if self._infer_type == 'encode':
                    outputs_collector.add_to_collection(
                        var=net_output[7], name='embedded',
                        average_over_devices=True, collection=NETWORK_OUTPUT)

                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'sample':
                # decode random codes drawn from a Gaussian prior
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                noise_shape = net_output[-1].shape.as_list()
                stddev = self.autoencoder_param.noise_stddev
                noise = tf.random_normal(shape=noise_shape,
                                         mean=0.0,
                                         stddev=stddev,
                                         dtype=tf.float32)
                partially_decoded_sample = self.net.shared_decoder(
                    noise, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)

                outputs_collector.add_to_collection(
                    var=decoder_output, name='generated_image',
                    average_over_devices=True, collection=NETWORK_OUTPUT)
                # no reader: generated samples have no source subject
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=None,
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'linear_interpolation':
                # construct the entire network
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                real_code = data_dict['feature']
                real_code = tf.reshape(real_code, net_output[-1].get_shape())
                partially_decoded_sample = self.net.shared_decoder(
                    real_code, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)

                outputs_collector.add_to_collection(
                    var=decoder_output, name='generated_image',
                    average_over_devices=True, collection=NETWORK_OUTPUT)
                outputs_collector.add_to_collection(
                    var=data_dict['feature_location'], name='location',
                    average_over_devices=True, collection=NETWORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
            else:
                raise NotImplementedError
Ejemplo n.º 7
0
class AutoencoderApplication(BaseApplication):
    REQUIRED_CONFIG_SECTION = "AUTOENCODER"

    def __init__(self, net_param, action_param, is_training):
        """Record configuration; readers and samplers are created later.

        :param net_param: network configuration parameters
        :param action_param: training/inference action parameters
        :param is_training: True when the application is run for training
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting autoencoder application')

        self.is_training = is_training
        self.net_param = net_param
        self.action_param = action_param

        # populated by initialise_dataset_loader
        self.data_param = None
        self.autoencoder_param = None

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Create image readers for training or for the selected inference type.

        :param data_param: input-source configuration
        :param task_param: task parameters (provides ``inference_type``)
        :param data_partitioner: supplies train/validation/inference file lists
        """
        self.data_param = data_param
        self.autoencoder_param = task_param

        # inference type is only meaningful (and validated) outside training
        if not self.is_training:
            self._infer_type = look_up_operations(
                self.autoencoder_param.inference_type, SUPPORTED_INFERENCE)
        else:
            self._infer_type = None

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                # train + validation readers
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.train_files)

            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        # when training, _infer_type is None so the branches below are skipped
        if self._infer_type in ('encode', 'encode-decode'):
            self.readers = [ImageReader(['image'])]
            self.readers[0].initialise(data_param, task_param,
                                       data_partitioner.inference_files)
        elif self._infer_type == 'sample':
            # sampling from the prior requires no input images
            self.readers = []
        elif self._infer_type == 'linear_interpolation':
            # interpolation operates on pre-computed feature vectors
            self.readers = [ImageReader(['feature'])]
            self.readers[0].initialise(data_param, task_param,
                                       data_partitioner.inference_files)
        # if self.is_training or self._infer_type in ('encode', 'encode-decode'):
        #    mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        #    self.reader.add_preprocessing_layers([mean_var_normaliser])

    def initialise_sampler(self):
        """Build one sampler per reader, depending on the action/inference type.

        ``self.sampler`` becomes a list holding a single list of samplers;
        'sample' mode leaves it empty since no input data is needed.
        """
        self.sampler = []
        if self.is_training:
            # shuffled resize samplers for training
            samplers = []
            for reader in self.readers:
                samplers.append(
                    ResizeSampler(reader=reader,
                                  data_param=self.data_param,
                                  batch_size=self.net_param.batch_size,
                                  windows_per_image=1,
                                  shuffle_buffer=True,
                                  queue_length=self.net_param.queue_length))
            self.sampler.append(samplers)
            return
        if self._infer_type in ('encode', 'encode-decode'):
            # deterministic order (no shuffling) for inference
            samplers = []
            for reader in self.readers:
                samplers.append(
                    ResizeSampler(reader=reader,
                                  data_param=self.data_param,
                                  batch_size=self.net_param.batch_size,
                                  windows_per_image=1,
                                  shuffle_buffer=False,
                                  queue_length=self.net_param.queue_length))
            self.sampler.append(samplers)
            return
        if self._infer_type == 'linear_interpolation':
            samplers = []
            for reader in self.readers:
                samplers.append(LinearInterpolateSampler(
                    reader=reader,
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size,
                    n_interpolations=self.autoencoder_param.n_interpolations,
                    queue_length=self.net_param.queue_length))
            self.sampler.append(samplers)
            return

    def initialise_network(self):
        """Instantiate the network, with optional L1/L2 weight regularisation."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        # build regularizers only when a positive decay is configured
        if decay > 0:
            if reg_type == 'l2':
                from tensorflow.contrib.layers.python.layers import regularizers
                w_regularizer = regularizers.l2_regularizer(decay)
                b_regularizer = regularizers.l2_regularizer(decay)
            elif reg_type == 'l1':
                from tensorflow.contrib.layers.python.layers import regularizers
                w_regularizer = regularizers.l1_regularizer(decay)
                b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            w_regularizer=w_regularizer, b_regularizer=b_regularizer)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_output = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            data_loss = loss_func(net_output)
            loss = data_loss
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                    loss = loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            outputs_collector.add_to_collection(var=data_loss,
                                                name='variational_lower_bound',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='variational_lower_bound',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)

            outputs_collector.add_to_collection(var=net_output[4],
                                                name='Originals',
                                                average_over_devices=False,
                                                summary_type='image3_coronal',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=net_output[2],
                                                name='Means',
                                                average_over_devices=False,
                                                summary_type='image3_coronal',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=net_output[5],
                                                name='Variances',
                                                average_over_devices=False,
                                                summary_type='image3_coronal',
                                                collection=TF_SUMMARIES)
        else:
            if self._infer_type in ('encode', 'encode-decode'):
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                image = tf.cast(data_dict['image'], dtype=tf.float32)
                net_output = self.net(image, is_training=False)

                outputs_collector.add_to_collection(
                    var=data_dict['image_location'],
                    name='location',
                    average_over_devices=True,
                    collection=NETWORK_OUTPUT)

                if self._infer_type == 'encode-decode':
                    outputs_collector.add_to_collection(
                        var=net_output[2],
                        name='generated_image',
                        average_over_devices=True,
                        collection=NETWORK_OUTPUT)
                if self._infer_type == 'encode':
                    outputs_collector.add_to_collection(
                        var=net_output[7],
                        name='embedded',
                        average_over_devices=True,
                        collection=NETWORK_OUTPUT)

                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'sample':
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                noise_shape = net_output[-1].get_shape().as_list()
                stddev = self.autoencoder_param.noise_stddev
                noise = tf.random_normal(shape=noise_shape,
                                         mean=0.0,
                                         stddev=stddev,
                                         dtype=tf.float32)
                partially_decoded_sample = self.net.shared_decoder(
                    noise, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)

                outputs_collector.add_to_collection(var=decoder_output,
                                                    name='generated_image',
                                                    average_over_devices=True,
                                                    collection=NETWORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=None,
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'linear_interpolation':
                # construct the entire network
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                real_code = data_dict['feature']
                real_code = tf.reshape(real_code, net_output[-1].get_shape())
                partially_decoded_sample = self.net.shared_decoder(
                    real_code, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)

                outputs_collector.add_to_collection(var=decoder_output,
                                                    name='generated_image',
                                                    average_over_devices=True,
                                                    collection=NETWORK_OUTPUT)
                outputs_collector.add_to_collection(
                    var=data_dict['feature_location'],
                    name='location',
                    average_over_devices=True,
                    collection=NETWORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
            else:
                raise NotImplementedError

    def interpret_output(self, batch_output):
        """Aggregate one batch of inference outputs.

        During training this is a no-op that simply signals "continue".
        During inference, the window/location pair handed to the decoder
        depends on the configured inference type.
        """
        if self.is_training:
            return True
        infer_type = look_up_operations(
            self.autoencoder_param.inference_type, SUPPORTED_INFERENCE)
        decode = self.output_decoder.decode_batch
        if infer_type == 'encode':
            # Latent codes, keyed by the subject id column only.
            return decode(batch_output['embedded'],
                          batch_output['location'][:, 0:1])
        if infer_type == 'encode-decode':
            # Reconstructions, keyed by the subject id column only.
            return decode(batch_output['generated_image'],
                          batch_output['location'][:, 0:1])
        if infer_type == 'sample':
            # Pure samples have no source image location.
            return decode(batch_output['generated_image'], None)
        if infer_type == 'linear_interpolation':
            # Interpolations keep both endpoint identifiers.
            return decode(batch_output['generated_image'],
                          batch_output['location'][:, :2])
Ejemplo n.º 8
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph for the autoencoder.

        Training: computes the variational lower bound (plus optional
        regularisation), registers gradients with ``gradients_collector``
        and adds console/TensorBoard outputs.
        Inference: builds the sub-graph selected by ``self._infer_type``
        ('encode', 'encode-decode', 'sample' or 'linear_interpolation')
        and creates ``self.output_decoder`` for window aggregation.
        """
        def switch_sampler(for_training):
            # Pop a batch from the training (index 0) or validation
            # (index -1) sampler under a matching name scope.
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            # Switch between train/validation batches at run time when
            # validation iterations are interleaved with training.
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_output = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            data_loss = loss_func(net_output)
            loss = data_loss
            # Add the mean of all regularisation losses when decay is set.
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                    loss = loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            # Loss on the console and as a TensorBoard scalar.
            outputs_collector.add_to_collection(var=data_loss,
                                                name='variational_lower_bound',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='variational_lower_bound',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)

            # Image summaries: net_output indices 4/2/5 are displayed as
            # originals / means / variances (coronal view).
            outputs_collector.add_to_collection(var=net_output[4],
                                                name='Originals',
                                                average_over_devices=False,
                                                summary_type='image3_coronal',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=net_output[2],
                                                name='Means',
                                                average_over_devices=False,
                                                summary_type='image3_coronal',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=net_output[5],
                                                name='Variances',
                                                average_over_devices=False,
                                                summary_type='image3_coronal',
                                                collection=TF_SUMMARIES)
        else:
            if self._infer_type in ('encode', 'encode-decode'):
                # Encode real images; optionally also reconstruct them.
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                image = tf.cast(data_dict['image'], dtype=tf.float32)
                net_output = self.net(image, is_training=False)

                outputs_collector.add_to_collection(
                    var=data_dict['image_location'],
                    name='location',
                    average_over_devices=True,
                    collection=NETWORK_OUTPUT)

                if self._infer_type == 'encode-decode':
                    # net_output[2]: the reconstructed image.
                    outputs_collector.add_to_collection(
                        var=net_output[2],
                        name='generated_image',
                        average_over_devices=True,
                        collection=NETWORK_OUTPUT)
                if self._infer_type == 'encode':
                    # net_output[7]: the embedded (latent) representation.
                    outputs_collector.add_to_collection(
                        var=net_output[7],
                        name='embedded',
                        average_over_devices=True,
                        collection=NETWORK_OUTPUT)

                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'sample':
                # Decode Gaussian noise drawn in the latent space; the
                # network is first built on a dummy image to obtain the
                # latent shape from its last output.
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                noise_shape = net_output[-1].get_shape().as_list()
                stddev = self.autoencoder_param.noise_stddev
                noise = tf.random_normal(shape=noise_shape,
                                         mean=0.0,
                                         stddev=stddev,
                                         dtype=tf.float32)
                partially_decoded_sample = self.net.shared_decoder(
                    noise, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)

                outputs_collector.add_to_collection(var=decoder_output,
                                                    name='generated_image',
                                                    average_over_devices=True,
                                                    collection=NETWORK_OUTPUT)
                # No reader: generated windows have no source image.
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=None,
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'linear_interpolation':
                # construct the entire network
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                # Decode externally supplied latent codes ('feature'),
                # reshaped to the latent shape the network expects.
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                real_code = data_dict['feature']
                real_code = tf.reshape(real_code, net_output[-1].get_shape())
                partially_decoded_sample = self.net.shared_decoder(
                    real_code, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)

                outputs_collector.add_to_collection(var=decoder_output,
                                                    name='generated_image',
                                                    average_over_devices=True,
                                                    collection=NETWORK_OUTPUT)
                outputs_collector.add_to_collection(
                    var=data_dict['feature_location'],
                    name='location',
                    average_over_devices=True,
                    collection=NETWORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
            else:
                raise NotImplementedError
Ejemplo n.º 9
0
class GANApplication(BaseApplication):
    """Conditional GAN application.

    Training jointly optimises a generator and a discriminator;
    inference generates images from random noise vectors conditioned on
    'conditioning' windows from the inference reader.
    """
    REQUIRED_CONFIG_SECTION = "GAN"

    def __init__(self, net_param, action_param, is_training):
        """Store configuration; data/task parameters are supplied later
        via ``initialise_dataset_loader``."""
        BaseApplication.__init__(self)
        tf.logging.info('starting GAN application')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.gan_param = None

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create image readers and attach preprocessing layers.

        Training readers load 'image' + 'conditioning' (plus a separate
        validation reader when validation iterations are enabled); the
        inference reader only needs 'conditioning'.
        """
        self.data_param = data_param
        self.gan_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.train_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image', 'conditioning'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            inference_reader = ImageReader(['conditioning'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        # Optional foreground mask restricting normalisation statistics.
        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image',
            binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        # Random flip / scale / rotate augmentation, training only.
        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        for reader in self.readers:
            reader.add_preprocessing_layers(
                normalisation_layers + augmentation_layers)

    def initialise_sampler(self):
        """Create samplers.

        Training: one resize sampler per reader.  Inference: a random
        vector sampler plus a resize sampler that repeats each image
        once per interpolation step.
        """
        self.sampler = []
        if self.is_training:
            self.sampler.append([ResizeSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=1,
                shuffle_buffer=True,
                queue_length=self.net_param.queue_length) for reader in
                self.readers])
        else:
            self.sampler.append([RandomVectorSampler(
                names=('vector',),
                vector_size=(self.gan_param.noise_size,),
                batch_size=self.net_param.batch_size,
                n_interpolations=self.gan_param.n_interpolations,
                repeat=None,
                queue_length=self.net_param.queue_length) for _ in
                self.readers])
            # repeat each resized image n times, so that each
            # image matches one random vector,
            # (n = self.gan_param.n_interpolations)
            self.sampler.append([ResizeSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.gan_param.n_interpolations,
                shuffle_buffer=False,
                queue_length=self.net_param.queue_length) for reader in
                self.readers])

    def initialise_network(self):
        """Instantiate the GAN network named in the configuration."""
        self.net = ApplicationNetFactory.create(self.net_param.name)()

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training graph (adversarial losses and per-network
        gradients) or the inference graph (noise -> generated image)."""
        if self.is_training:
            def switch_sampler(for_training):
                # Pop a batch from the training (index 0) or validation
                # (index -1) sampler under a matching name scope.
                with tf.name_scope('train' if for_training else 'validation'):
                    sampler = self.get_sampler()[0][0 if for_training else -1]
                    return sampler.pop_batch_op()

            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            images = tf.cast(data_dict['image'], tf.float32)
            noise_shape = [self.net_param.batch_size,
                           self.gan_param.noise_size]
            noise = tf.random_normal(shape=noise_shape,
                                     mean=0.0,
                                     stddev=1.0,
                                     dtype=tf.float32)
            conditioning = data_dict['conditioning']
            net_output = self.net(
                noise, images, conditioning, self.is_training)

            loss_func = LossFunction(
                loss_type=self.action_param.loss_type)
            real_logits = net_output[1]
            fake_logits = net_output[2]
            lossG, lossD = loss_func(real_logits, fake_logits)
            # Add the mean regularisation loss to both losses.
            if self.net_param.decay > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(l_reg) for l_reg in reg_losses])
                    lossD = lossD + reg_loss
                    lossG = lossG + reg_loss

            # variables to display in STDOUT
            outputs_collector.add_to_collection(
                var=lossD, name='lossD', average_over_devices=True,
                collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=lossG, name='lossG', average_over_devices=False,
                collection=CONSOLE)
            # variables to display in tensorboard
            outputs_collector.add_to_collection(
                var=lossG, name='lossG', average_over_devices=False,
                collection=TF_SUMMARIES)
            # BUGFIX: this summary previously recorded lossG under the
            # name 'lossD'; it must record the discriminator loss.
            outputs_collector.add_to_collection(
                var=lossD, name='lossD', average_over_devices=True,
                collection=TF_SUMMARIES)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            with tf.name_scope('ComputeGradients'):
                # gradients of generator
                generator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                generator_grads = self.optimiser.compute_gradients(
                    lossG, var_list=generator_variables)

                # gradients of discriminator
                discriminator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
                discriminator_grads = self.optimiser.compute_gradients(
                    lossD, var_list=discriminator_variables)
                grads = [generator_grads, discriminator_grads]

                # add the grads back to application_driver's training_grads
                gradients_collector.add_to_collection(grads)
        else:
            # Inference: generate from random vectors conditioned on
            # resized 'conditioning' windows; the image input is a
            # placeholder of zeros.
            data_dict = self.get_sampler()[0][0].pop_batch_op()
            conditioning_dict = self.get_sampler()[1][0].pop_batch_op()
            conditioning = conditioning_dict['conditioning']
            image_size = conditioning.shape.as_list()[:-1]
            dummy_image = tf.zeros(image_size + [1])
            net_output = self.net(data_dict['vector'],
                                  dummy_image,
                                  conditioning,
                                  self.is_training)
            outputs_collector.add_to_collection(
                var=net_output[0],
                name='image',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=conditioning_dict['conditioning_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)

            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir)

    def interpret_output(self, batch_output):
        """Aggregate generated windows at inference; no-op in training."""
        if self.is_training:
            return True
        return self.output_decoder.decode_batch(
            batch_output['image'], batch_output['location'])
Ejemplo n.º 10
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the GAN training graph (adversarial losses and per-network
        gradients) or the inference graph (noise -> generated image)."""
        if self.is_training:
            def switch_sampler(for_training):
                # Pop a batch from the training (index 0) or validation
                # (index -1) sampler under a matching name scope.
                with tf.name_scope('train' if for_training else 'validation'):
                    sampler = self.get_sampler()[0][0 if for_training else -1]
                    return sampler.pop_batch_op()

            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            images = tf.cast(data_dict['image'], tf.float32)
            noise_shape = [self.net_param.batch_size,
                           self.gan_param.noise_size]
            noise = tf.random_normal(shape=noise_shape,
                                     mean=0.0,
                                     stddev=1.0,
                                     dtype=tf.float32)
            conditioning = data_dict['conditioning']
            net_output = self.net(
                noise, images, conditioning, self.is_training)

            loss_func = LossFunction(
                loss_type=self.action_param.loss_type)
            real_logits = net_output[1]
            fake_logits = net_output[2]
            lossG, lossD = loss_func(real_logits, fake_logits)
            # Add the mean regularisation loss to both losses.
            if self.net_param.decay > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(l_reg) for l_reg in reg_losses])
                    lossD = lossD + reg_loss
                    lossG = lossG + reg_loss

            # variables to display in STDOUT
            outputs_collector.add_to_collection(
                var=lossD, name='lossD', average_over_devices=True,
                collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=lossG, name='lossG', average_over_devices=False,
                collection=CONSOLE)
            # variables to display in tensorboard
            outputs_collector.add_to_collection(
                var=lossG, name='lossG', average_over_devices=False,
                collection=TF_SUMMARIES)
            # BUGFIX: this summary previously recorded lossG under the
            # name 'lossD'; it must record the discriminator loss.
            outputs_collector.add_to_collection(
                var=lossD, name='lossD', average_over_devices=True,
                collection=TF_SUMMARIES)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            with tf.name_scope('ComputeGradients'):
                # gradients of generator
                generator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                generator_grads = self.optimiser.compute_gradients(
                    lossG, var_list=generator_variables)

                # gradients of discriminator
                discriminator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
                discriminator_grads = self.optimiser.compute_gradients(
                    lossD, var_list=discriminator_variables)
                grads = [generator_grads, discriminator_grads]

                # add the grads back to application_driver's training_grads
                gradients_collector.add_to_collection(grads)
        else:
            # Inference: generate from random vectors conditioned on
            # resized 'conditioning' windows; the image input is a
            # placeholder of zeros.
            data_dict = self.get_sampler()[0][0].pop_batch_op()
            conditioning_dict = self.get_sampler()[1][0].pop_batch_op()
            conditioning = conditioning_dict['conditioning']
            image_size = conditioning.shape.as_list()[:-1]
            dummy_image = tf.zeros(image_size + [1])
            net_output = self.net(data_dict['vector'],
                                  dummy_image,
                                  conditioning,
                                  self.is_training)
            outputs_collector.add_to_collection(
                var=net_output[0],
                name='image',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=conditioning_dict['conditioning_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)

            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir)
Ejemplo n.º 11
0
class GANApplication(BaseApplication):
    """Conditional GAN application.

    Training: samples (image, conditioning) windows, feeds random noise to
    the generator, and collects generator/discriminator losses and their
    gradients.  Inference: feeds interpolated random vectors plus resized
    conditioning windows to the generator and aggregates the generated
    windows as output images.
    """
    REQUIRED_CONFIG_SECTION = "GAN"

    def __init__(self, net_param, action_param, action):
        """Store configuration; readers/samplers are created later.

        :param net_param: network configuration parameters
        :param action_param: parameters of the train/inference action
        :param action: name of the requested action
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting GAN application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        # populated by initialise_dataset_loader
        self.data_param = None
        self.gan_param = None

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Create image readers and attach normalisation/augmentation layers.

        Training readers provide 'image' and 'conditioning'; the inference
        reader provides 'conditioning' only.

        :raises NotImplementedError: when the evaluation action is requested
        """
        self.data_param = data_param
        self.gan_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image', 'conditioning'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        elif self.is_inference:
            inference_reader = ImageReader(['conditioning'])
            inference_reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [inference_reader]
        elif self.is_evaluation:
            # fix: the exception was previously created but never raised
            raise NotImplementedError('Evaluation is not yet '
                                      'supported in this application.')

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        # guard against appending None when `normalisation` is requested
        # without a histogram reference file (consistent with the
        # `is not None` checks used by the regression application)
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        for reader in self.readers:
            reader.add_preprocessing_layers(normalisation_layers +
                                            augmentation_layers)

    def initialise_sampler(self):
        """Create window samplers.

        Training uses a resize sampler only; inference pairs a random-vector
        sampler (generator input) with a resize sampler that repeats each
        image ``n_interpolations`` times so every image matches one vector.
        """
        self.sampler = []
        if self.is_training:
            # NOTE(review): this sampler uses `data_param=`/`shuffle_buffer=`
            # keywords while other applications in this file pass
            # `window_sizes=`/`shuffle=` — confirm against the ResizeSampler
            # signature in use before changing.
            self.sampler.append([
                ResizeSampler(reader=reader,
                              data_param=self.data_param,
                              batch_size=self.net_param.batch_size,
                              windows_per_image=1,
                              shuffle_buffer=True,
                              queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])
        else:
            self.sampler.append([
                RandomVectorSampler(
                    names=('vector', ),
                    vector_size=(self.gan_param.noise_size, ),
                    batch_size=self.net_param.batch_size,
                    n_interpolations=self.gan_param.n_interpolations,
                    repeat=None,
                    queue_length=self.net_param.queue_length)
                for _ in self.readers
            ])
            # repeat each resized image n times, so that each
            # image matches one random vector,
            # (n = self.gan_param.n_interpolations)
            self.sampler.append([
                ResizeSampler(
                    reader=reader,
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size,
                    windows_per_image=self.gan_param.n_interpolations,
                    shuffle_buffer=False,
                    queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])

    def initialise_network(self):
        """Instantiate the GAN network from the configured name."""
        self.net = ApplicationNetFactory.create(self.net_param.name)()

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph.

        Training collects lossG/lossD for console and TensorBoard and
        registers per-network gradients; inference registers generated
        images and their locations for aggregation.
        """
        if self.is_training:

            def switch_sampler(for_training):
                # first sampler feeds training, last feeds validation
                with tf.name_scope('train' if for_training else 'validation'):
                    sampler = self.get_sampler()[0][0 if for_training else -1]
                    return sampler.pop_batch_op()

            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            images = tf.cast(data_dict['image'], tf.float32)
            noise_shape = [
                self.net_param.batch_size, self.gan_param.noise_size
            ]
            noise = tf.random_normal(shape=noise_shape,
                                     mean=0.0,
                                     stddev=1.0,
                                     dtype=tf.float32)
            conditioning = data_dict['conditioning']
            net_output = self.net(noise, images, conditioning,
                                  self.is_training)

            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            real_logits = net_output[1]
            fake_logits = net_output[2]
            lossG, lossD = loss_func(real_logits, fake_logits)
            if self.net_param.decay > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(l_reg) for l_reg in reg_losses])
                    lossD = lossD + reg_loss
                    lossG = lossG + reg_loss

            # variables to display in STDOUT
            outputs_collector.add_to_collection(var=lossD,
                                                name='lossD',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=lossG,
                                                name='lossG',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            # variables to display in tensorboard
            outputs_collector.add_to_collection(var=lossG,
                                                name='lossG',
                                                average_over_devices=False,
                                                collection=TF_SUMMARIES)
            # fix: the 'lossD' summary previously reported lossG
            outputs_collector.add_to_collection(var=lossD,
                                                name='lossD',
                                                average_over_devices=True,
                                                collection=TF_SUMMARIES)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            with tf.name_scope('ComputeGradients'):
                # gradients of generator
                generator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                generator_grads = self.optimiser.compute_gradients(
                    lossG, var_list=generator_variables)

                # gradients of discriminator
                discriminator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
                discriminator_grads = self.optimiser.compute_gradients(
                    lossD, var_list=discriminator_variables)
                grads = [generator_grads, discriminator_grads]

                # add the grads back to application_driver's training_grads
                gradients_collector.add_to_collection(grads)
        else:
            data_dict = self.get_sampler()[0][0].pop_batch_op()
            conditioning_dict = self.get_sampler()[1][0].pop_batch_op()
            conditioning = conditioning_dict['conditioning']
            # the generator only needs the spatial shape, not real content
            image_size = conditioning.shape.as_list()[:-1]
            dummy_image = tf.zeros(image_size + [1])
            net_output = self.net(data_dict['vector'], dummy_image,
                                  conditioning, self.is_training)
            outputs_collector.add_to_collection(var=net_output[0],
                                                name='image',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=conditioning_dict['conditioning_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)

            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir)

    def interpret_output(self, batch_output):
        """Decode generated windows during inference; no-op when training."""
        if self.is_training:
            return True
        return self.output_decoder.decode_batch(batch_output['image'],
                                                batch_output['location'])
Ejemplo n.º 12
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the GAN training or inference graph.

        Training collects lossG/lossD for console and TensorBoard and
        registers per-network gradients; inference registers generated
        images and their locations for window aggregation.
        """
        if self.is_training:

            def switch_sampler(for_training):
                # first sampler feeds training, last feeds validation
                with tf.name_scope('train' if for_training else 'validation'):
                    sampler = self.get_sampler()[0][0 if for_training else -1]
                    return sampler.pop_batch_op()

            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            images = tf.cast(data_dict['image'], tf.float32)
            noise_shape = [
                self.net_param.batch_size, self.gan_param.noise_size
            ]
            noise = tf.random_normal(shape=noise_shape,
                                     mean=0.0,
                                     stddev=1.0,
                                     dtype=tf.float32)
            conditioning = data_dict['conditioning']
            net_output = self.net(noise, images, conditioning,
                                  self.is_training)

            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            real_logits = net_output[1]
            fake_logits = net_output[2]
            lossG, lossD = loss_func(real_logits, fake_logits)
            if self.net_param.decay > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(l_reg) for l_reg in reg_losses])
                    lossD = lossD + reg_loss
                    lossG = lossG + reg_loss

            # variables to display in STDOUT
            outputs_collector.add_to_collection(var=lossD,
                                                name='lossD',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=lossG,
                                                name='lossG',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            # variables to display in tensorboard
            outputs_collector.add_to_collection(var=lossG,
                                                name='lossG',
                                                average_over_devices=False,
                                                collection=TF_SUMMARIES)
            # fix: the 'lossD' summary previously reported lossG
            outputs_collector.add_to_collection(var=lossD,
                                                name='lossD',
                                                average_over_devices=True,
                                                collection=TF_SUMMARIES)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            with tf.name_scope('ComputeGradients'):
                # gradients of generator
                generator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                generator_grads = self.optimiser.compute_gradients(
                    lossG, var_list=generator_variables)

                # gradients of discriminator
                discriminator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
                discriminator_grads = self.optimiser.compute_gradients(
                    lossD, var_list=discriminator_variables)
                grads = [generator_grads, discriminator_grads]

                # add the grads back to application_driver's training_grads
                gradients_collector.add_to_collection(grads)
        else:
            data_dict = self.get_sampler()[0][0].pop_batch_op()
            conditioning_dict = self.get_sampler()[1][0].pop_batch_op()
            conditioning = conditioning_dict['conditioning']
            # the generator only needs the spatial shape, not real content
            image_size = conditioning.shape.as_list()[:-1]
            dummy_image = tf.zeros(image_size + [1])
            net_output = self.net(data_dict['vector'], dummy_image,
                                  conditioning, self.is_training)
            outputs_collector.add_to_collection(var=net_output[0],
                                                name='image',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=conditioning_dict['conditioning_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)

            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir)
Ejemplo n.º 13
0
class RegressionApplication(BaseApplication):
    """Image regression application.

    Maps each sampling strategy to a (training sampler, inference sampler,
    aggregator) triple, builds the regression network, and wires loss,
    optional border cropping, and gradient collection.
    """
    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, action):
        """Store configuration and register the sampling strategies.

        :param net_param: network configuration parameters
        :param action_param: parameters of the train/inference action
        :param action: name of the requested action
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        # populated by initialise_dataset_loader
        self.data_param = None
        self.regression_param = None
        # strategy name -> (train sampler, inference sampler, aggregator)
        self.SUPPORTED_SAMPLING = {
            'uniform':
            (self.initialise_uniform_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'weighted':
            (self.initialise_weighted_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler,
             self.initialise_resize_aggregator),
            'balanced':
            (self.initialise_balanced_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Create image readers and attach padding/normalisation/augmentation.

        Augmentation layers are attached to the first reader only (the
        training reader); validation readers get padding + normalisation.

        :raises ValueError: when the requested action is not supported
        """
        self.data_param = data_param
        self.regression_param = task_param

        # initialise input image readers
        if self.is_training:
            reader_names = ('image', 'output', 'weight', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image', )
        elif self.is_evaluation:
            reader_names = ('image', 'output', 'inferred')
        else:
            tf.logging.fatal('Action `%s` not supported. Expected one of %s',
                             self.action, self.SUPPORTED_PHASES)
            raise ValueError
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(phase=reader_phase,
                                                        action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(data_param, task_param,
                                                 file_list)
            for file_list in file_lists
        ]

        # initialise input preprocessing layers
        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image') \
            if self.net_param.whitening else None
        histogram_normaliser = HistogramNormalisationLayer(
            image_name='image',
            modalities=vars(task_param).get('image'),
            model_filename=self.net_param.histogram_ref_file,
            norm_type=self.net_param.norm_type,
            cutoff=self.net_param.cutoff,
            name='hist_norm_layer') \
            if (self.net_param.histogram_ref_file and
                self.net_param.normalisation) else None
        rgb_normaliser = RGBHistogramEquilisationLayer(
            image_name='image', name='rbg_norm_layer'
        ) if self.net_param.rgb_normalisation else None

        normalisation_layers = []
        if histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if mean_var_normaliser is not None:
            normalisation_layers.append(mean_var_normaliser)
        if rgb_normaliser is not None:
            normalisation_layers.append(rgb_normaliser)

        volume_padding_layer = [
            PadLayer(image_name=SUPPORTED_INPUT,
                     border=self.net_param.volume_padding_size,
                     mode=self.net_param.volume_padding_mode,
                     pad_to=self.net_param.volume_padding_to_size)
        ]

        # initialise training data augmentation layers
        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=train_param.scaling_percentage[0],
                        max_percentage=train_param.scaling_percentage[1],
                        antialiasing=train_param.antialiasing,
                        isotropic=train_param.isotropic_scaling))
            if train_param.rotation_angle:
                # the same condition was previously re-tested in a
                # redundant nested `if`; checked once here
                rotation_layer = RandomRotationLayer()
                rotation_layer.init_uniform_angle(train_param.rotation_angle)
                augmentation_layers.append(rotation_layer)
            if train_param.do_elastic_deformation:
                spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(
                    RandomElasticDeformationLayer(
                        spatial_rank=spatial_rank,
                        num_controlpoints=train_param.num_ctrl_points,
                        std_deformation_sigma=train_param.deformation_sigma,
                        proportion_to_augment=train_param.proportion_to_deform)
                )

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(volume_padding_layer +
                                                 normalisation_layers +
                                                 augmentation_layers)

        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers)

    def initialise_uniform_sampler(self):
        """Sample windows uniformly at random from each volume."""
        self.sampler = [[
            UniformSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_weighted_sampler(self):
        """Sample windows with frequency proportional to a weight map."""
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_resize_sampler(self):
        """Resize each whole volume to the window size (one window/image)."""
        self.sampler = [[
            ResizeSampler(reader=reader,
                          window_sizes=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle=self.is_training,
                          smaller_final_batch_mode=self.net_param.
                          smaller_final_batch_mode,
                          queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_sampler(self):
        """Sample a dense sliding-window grid (used at inference time)."""
        self.sampler = [[
            GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=self.net_param.
                smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_balanced_sampler(self):
        """Sample windows with balanced label frequencies."""
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_aggregator(self):
        """Stitch grid-sampled windows back into full volumes."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix,
            fill_constant=self.action_param.fill_constant)

    def initialise_resize_aggregator(self):
        """Resize network windows back to the original volume shape."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_identity_aggregator(self):
        """Write network windows out unchanged (no resizing)."""
        self.output_decoder = WindowAsImageAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        """Create the sampler matching the configured sampling strategy."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        """Create the output aggregator (identity override supported)."""
        if self.net_param.force_output_identity_resizing:
            self.initialise_identity_aggregator()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Build the single-output regression network with regularisers."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph for regression."""
        def switch_sampler(for_training):
            # first sampler feeds training, last feeds validation
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            self.patience = self.action_param.patience
            self.mode = self.action_param.early_stopping_mode
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            weight_map = data_dict.get('weight', None)
            border = self.regression_param.loss_border
            # Decide in Python whether the border is trivial (no cropping).
            # The previous test `border == None or
            # tf.reduce_sum(tf.abs(border)) == 0` compared a graph tensor
            # with `== 0`, which is always False in TF1 graph mode, so an
            # all-zero border wrongly went through the cropping path.
            if border is None:
                trivial_border = True
            elif isinstance(border, (list, tuple)):
                trivial_border = not any(border)
            else:
                trivial_border = border == 0
            if trivial_border:
                data_loss = loss_func(prediction=net_out,
                                      ground_truth=data_dict['output'],
                                      weight_map=weight_map)
            else:
                crop_layer = CropLayer(border)
                weight_map = None if weight_map is None else crop_layer(
                    weight_map)
                data_loss = loss_func(prediction=crop_layer(net_out),
                                      ground_truth=crop_layer(
                                          data_dict['output']),
                                      weight_map=weight_map)
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            # Get all vars
            to_optimise = tf.trainable_variables()
            vars_to_freeze = \
                self.action_param.vars_to_freeze or \
                self.action_param.vars_to_restore
            if vars_to_freeze:
                import re
                var_regex = re.compile(vars_to_freeze)
                # Only optimise vars that are not frozen
                to_optimise = \
                    [v for v in to_optimise if not var_regex.search(v.name)]
                tf.logging.info(
                    "Optimizing %d out of %d trainable variables, "
                    "the other variables are fixed (--vars_to_freeze %s)",
                    len(to_optimise), len(tf.trainable_variables()),
                    vars_to_freeze)

            self.total_loss = loss

            grads = self.optimiser.compute_gradients(
                loss, var_list=to_optimise, colocate_gradients_with_ops=True)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            # collecting output variables
            outputs_collector.add_to_collection(var=self.total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=self.total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)

        elif self.is_inference:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)
            net_out = PostProcessingLayer('IDENTITY')(net_out)

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Decode inference windows through the aggregator; no-op in training."""
        if self.is_inference:
            return self.output_decoder.decode_batch(
                {'window_reg': batch_output['window']},
                batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Create the regression evaluator over the first reader."""
        self.eval_param = eval_param
        self.evaluator = RegressionEvaluator(self.readers[0],
                                             self.regression_param, eval_param)

    def add_inferred_output(self, data_param, task_param):
        """Register the inferred output using the 'output' spec as template."""
        return self.add_inferred_output_like(data_param, task_param, 'output')
Ejemplo n.º 14
0
 def initialise_identity_aggregator(self):
     """Create a window-as-image aggregator writing network windows
     out unchanged (identity resizing) to ``save_seg_dir``.

     NOTE(review): this fragment is indented with a single space and sits
     outside any visible class — it appears to be a paste artifact of the
     method defined inside the application classes; verify before reuse.
     """
     self.output_decoder = WindowAsImageAggregator(
         image_reader=self.readers[0],
         output_path=self.action_param.save_seg_dir,
         postfix=self.action_param.output_postfix)
class MultiOutputApplication(BaseApplication):
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, action):
        """Test application that emits several named outputs per window.

        :param net_param: network-level configuration namespace.
        :param action_param: train/inference action configuration.
        :param action: requested action string.
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting multioutput test')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.multioutput_param = None

        # window_sampling mode ->
        #   (training sampler, inference sampler, aggregator initialiser)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
            'classifier': (self.initialise_resize_sampler,
                           self.initialise_resize_sampler,
                           self.initialise_classifier_aggregator),
            'identity': (self.initialise_uniform_sampler,
                         self.initialise_resize_sampler,
                         self.initialise_identity_aggregator),
        }

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Create one ImageReader per partitioned file list.

        The input sections each reader loads depend on the current action;
        an unsupported action logs fatally and raises ``ValueError``.
        """
        self.data_param = data_param
        self.multioutput_param = task_param

        # choose the input sections required by the current action
        if self.is_training:
            reader_names = ('image', 'label', 'weight', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image', )
        elif self.is_evaluation:
            reader_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal('Action `%s` not supported. Expected one of %s',
                             self.action, self.SUPPORTED_PHASES)
            raise ValueError
        # `dataset_to_infer` is optional on the action parameters
        reader_phase = getattr(self.action_param, 'dataset_to_infer', None)
        file_lists = data_partitioner.get_file_lists_by(phase=reader_phase,
                                                        action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(
                data_param, task_param, file_list)
            for file_list in file_lists
        ]

    def initialise_uniform_sampler(self):
        """Create one UniformSampler per reader for training."""
        per_reader = [
            UniformSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]
        self.sampler = [per_reader]

    def initialise_weighted_sampler(self):
        """Create one WeightedSampler per reader for training."""
        per_reader = [
            WeightedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]
        self.sampler = [per_reader]

    def initialise_resize_sampler(self):
        """Create one ResizeSampler per reader (shuffled only in training)."""
        per_reader = [
            ResizeSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle=self.is_training,
                smaller_final_batch_mode=(
                    self.net_param.smaller_final_batch_mode),
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]
        self.sampler = [per_reader]

    def initialise_grid_sampler(self):
        """Create one GridSampler per reader (used at inference time)."""
        per_reader = [
            GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=(
                    self.net_param.smaller_final_batch_mode),
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]
        self.sampler = [per_reader]

    def initialise_balanced_sampler(self):
        """Create one BalancedSampler per reader.

        NOTE(review): not referenced by SUPPORTED_SAMPLING in this file --
        confirm whether it is reachable from elsewhere.
        """
        per_reader = [
            BalancedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]
        self.sampler = [per_reader]

    def initialise_grid_aggregator(self):
        """Build the grid-window aggregator writing to ``save_seg_dir``."""
        agg_kwargs = dict(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix,
            fill_constant=self.action_param.fill_constant)
        self.output_decoder = GridSamplesAggregator(**agg_kwargs)

    def initialise_resize_aggregator(self):
        """Build the resize aggregator writing to ``save_seg_dir``."""
        agg_kwargs = dict(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)
        self.output_decoder = ResizeSamplesAggregator(**agg_kwargs)

    def initialise_identity_aggregator(self):
        """Build the aggregator that writes each window out as an image."""
        agg_kwargs = dict(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            postfix=self.action_param.output_postfix)
        self.output_decoder = WindowAsImageAggregator(**agg_kwargs)

    def initialise_classifier_aggregator(self):
        # Deliberately a no-op (see commented-out body below): in the
        # 'classifier' sampling mode `self.output_decoder` is never set,
        # so inference-time interpret_output would fail -- presumably this
        # mode is unused in this test application; confirm before enabling.
        pass
        # self.output_decoder = ClassifierSamplesAggregator(
        #     image_reader=self.readers[0],
        #     output_path=self.action_param.save_seg_dir,
        #     postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        """Build the sampler matching the configured window_sampling mode."""
        handlers = self.SUPPORTED_SAMPLING[self.net_param.window_sampling]
        if self.is_training:
            handlers[0]()
        elif self.is_inference:
            handlers[1]()

    def initialise_aggregator(self):
        """Build the output aggregator for the configured sampling mode."""
        _, _, make_aggregator = \
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling]
        make_aggregator()

    def initialise_network(self):
        """Instantiate the 'toynet' network with optional L1/L2 regularisers.

        Regularisers are only created when `reg_type` is 'l1' or 'l2' and
        `decay` is strictly positive; otherwise both stay ``None``.
        """
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type in ('l1', 'l2') and decay > 0:
            # deferred import: tf.contrib is only needed when regularising
            from tensorflow.contrib.layers.python.layers import regularizers
            make_reg = {'l1': regularizers.l1_regularizer,
                        'l2': regularizers.l2_regularizer}[reg_type]
            w_regularizer = make_reg(decay)
            b_regularizer = make_reg(decay)

        self.net = ApplicationNetFactory.create('toynet')(
            num_classes=self.multioutput_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire samplers, network, loss and collectors into the TF graph.

        Training: sample -> net -> loss (+ optional regularisation) ->
        gradients; the data loss is exported to CONSOLE and TF_SUMMARIES.
        Inference: sample -> net -> argmax/softmax post-processing; four
        named outputs plus window locations go to NETWORK_OUTPUT and the
        aggregator is initialised.
        """
        # pop a batch from the training or validation sampler, under a
        # matching name scope
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            # extract data
            # when validation iterations are enabled, pick the sampler at
            # run time from the `is_validation` flag
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            loss_func = LossFunction(
                n_class=self.multioutput_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.multioutput_param.softmax)
            # `label`/`weight` may be absent from the batch; the loss
            # function then receives None for those arguments
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=data_dict.get('label', None),
                                  weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            # add the mean of the regularisation terms only when decay is
            # positive and regularisers were actually created
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            # set the optimiser and the gradient
            to_optimise = tf.trainable_variables()
            vars_to_freeze = \
                self.action_param.vars_to_freeze or \
                self.action_param.vars_to_restore
            if vars_to_freeze:
                import re
                var_regex = re.compile(vars_to_freeze)
                # Only optimise vars that are not frozen
                to_optimise = \
                    [v for v in to_optimise if not var_regex.search(v.name)]
                tf.logging.info(
                    "Optimizing %d out of %d trainable variables, "
                    "the other variables fixed (--vars_to_freeze %s)",
                    len(to_optimise), len(tf.trainable_variables()),
                    vars_to_freeze)

            grads = self.optimiser.compute_gradients(
                loss, var_list=to_optimise, colocate_gradients_with_ops=True)

            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)

        elif self.is_inference:

            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            num_classes = self.multioutput_param.num_classes
            argmax_layer = PostProcessingLayer('ARGMAX',
                                               num_classes=num_classes)
            softmax_layer = PostProcessingLayer('SOFTMAX',
                                                num_classes=num_classes)

            arg_max_out = argmax_layer(net_out)
            soft_max_out = softmax_layer(net_out)
            # sum_prob_out = tf.reshape(tf.reduce_sum(soft_max_out),[1,1])
            # min_prob_out = tf.reshape(tf.reduce_min(soft_max_out),[1,1])
            # scalar reductions over the softmax output; the 'csv_' name
            # prefix presumably makes the aggregator write them as CSV
            # values -- verify against the aggregator implementation
            sum_prob_out = tf.reduce_sum(soft_max_out)
            min_prob_out = tf.reduce_min(soft_max_out)

            outputs_collector.add_to_collection(var=arg_max_out,
                                                name='window_argmax',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=soft_max_out,
                                                name='window_softmax',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=sum_prob_out,
                                                name='csv_sum',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=min_prob_out,
                                                name='csv_min',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Feed the four named inference outputs to the aggregator.

        :param batch_output: dict with 'window_argmax', 'window_softmax',
            'csv_sum', 'csv_min' and 'location' entries.
        :return: the aggregator's ``decode_batch`` result during inference,
            ``True`` during training (keep iterating).
        """
        if not self.is_inference:
            return True
        named_outputs = {
            'window_argmax': batch_output['window_argmax'],
            'window_softmax': batch_output['window_softmax'],
            'csv_sum': batch_output['csv_sum'],
            'csv_min': batch_output['csv_min'],
        }
        return self.output_decoder.decode_batch(named_outputs,
                                                batch_output['location'])