def test_25d_init(self):
     """Check the shape of the image aggregated from 2.5D resized windows.

     Batches are popped from a ``ResizeSampler`` and handed to a
     ``WindowAsImageAggregator`` until either the sampler raises
     ``OutOfRangeError`` or ``decode_batch`` returns a false value.  The
     first two dimensions of the file written for subject 0 are then
     compared with the reordered shape of the last window that was read.
     """
     output_dir = os.path.join('testing_data', 'aggregated_identity')
     reader = get_25d_reader()
     sampler = ResizeSampler(
         reader=reader,
         window_sizes=SINGLE_25D_DATA,
         batch_size=1,
         shuffle=False,
         queue_length=50)
     aggregator = WindowAsImageAggregator(image_reader=reader,
                                          output_path=output_dir)

     window_shape = []
     with self.cached_session() as sess:
         sampler.set_num_threads(2)
         while True:
             try:
                 batch = sess.run(sampler.pop_batch_op())
             except tf.errors.OutOfRangeError:
                 break
             # drop the leading (batch) axis, append a trailing singleton
             window_shape = batch['image'].shape[1:] + (1, )
             keep_going = aggregator.decode_batch(
                 {'window_image': batch['image']}, batch['image_location'])
             if not keep_going:
                 break

     subject_id = sampler.reader.get_subject_id(0)
     output_file = os.path.join(
         output_dir,
         '{}_window_image_niftynet_generated.nii.gz'.format(subject_id))

     expected_shape = [window_shape[axis] for axis in NEW_ORDER_2D]
     expected_shape.append(1)
     self.assertAllClose(nib.load(output_file).shape, expected_shape[:2])
     sampler.close_all()
    def test_init_2d_mo_bidimcsv(self):
        """Check image and CSV outputs aggregated from 2D resized windows.

        For every batch the window itself, its sum (``csv_sum``) and a
        two-row array of ``[min, max, sum]`` (``csv_stats2d``) are passed to
        ``WindowAsImageAggregator.decode_batch``.  Afterwards the test checks
        the first two dimensions of the saved image for subject 0 and the
        shapes of the two CSV files as read by pandas.
        """
        output_dir = os.path.join('testing_data', 'aggregated_identity')
        reader = get_2d_reader()
        sampler = ResizeSampler(
            reader=reader,
            window_sizes=MOD_2D_DATA,
            batch_size=1,
            shuffle=False,
            queue_length=50)
        aggregator = WindowAsImageAggregator(image_reader=reader,
                                             output_path=output_dir)

        window_shape = []
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while True:
                try:
                    batch = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
                image = batch['image']
                # drop the leading (batch) axis, append a trailing singleton
                window_shape = image.shape[1:] + (1, )

                window_sum = np.sum((np.asarray(image).flatten()))
                stats_row = np.expand_dims(
                    [np.min(image), np.max(image), np.sum(image)], 0)
                # the same statistics row is stacked twice
                stats_rows = np.concatenate([stats_row, stats_row], axis=0)

                keep_going = aggregator.decode_batch(
                    {
                        'window_image': image,
                        'csv_sum': window_sum,
                        'csv_stats2d': stats_rows
                    }, batch['image_location'])
                if not keep_going:
                    break

        def generated_path(name_template):
            # file names are prefixed with the id of subject 0
            return os.path.join(
                output_dir,
                name_template.format(sampler.reader.get_subject_id(0)))

        output_file = generated_path(
            '{}_window_image_niftynet_generated.nii.gz')
        sum_filename = generated_path('{}_csv_sum_niftynet_generated.csv')
        stats_filename = generated_path(
            '{}_csv_stats2d_niftynet_generated.csv')

        expected_shape = [window_shape[axis] for axis in NEW_ORDER_2D]
        expected_shape.append(1)
        self.assertAllClose(nib.load(output_file).shape, expected_shape[:2])

        sum_table = pd.read_csv(sum_filename)
        self.assertAllClose(sum_table.shape, [1, 2])
        stats_table = pd.read_csv(stats_filename)
        self.assertAllClose(stats_table.shape, [1, 7])
        sampler.close_all()
Ejemplo n.º 3
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Connect the samplers to the GAN network and register its outputs.

        Training mode (``self.is_training``):
            a batch is popped from the first sampler, or, when
            ``validation_every_n > 0``, from the first/last sampler depending
            on ``self.is_validation``.  The network is called with random
            normal noise, the images and the conditioning.  Generator and
            discriminator losses are computed from the network's second and
            third outputs, optionally augmented with the mean regularisation
            loss, reported to the console and tensorboard collections, and
            differentiated separately with respect to the trainable variables
            in the ``generator`` and ``discriminator`` scopes.

        Inference mode:
            the network is fed the sampled ``vector``, an all-zero image and
            the conditioning; its first output and the conditioning window
            location are registered as network outputs, and
            ``self.output_decoder`` is set to a ``WindowAsImageAggregator``.

        :param outputs_collector: collector whose ``add_to_collection`` is
            used to register console, tensorboard and network outputs
        :param gradients_collector: collector receiving
            ``[generator_grads, discriminator_grads]`` (training only)
        """
        if self.is_training:

            def switch_sampler(for_training):
                # first sampler when training, last sampler when validating
                with tf.name_scope('train' if for_training else 'validation'):
                    sampler = self.get_sampler()[0][0 if for_training else -1]
                    return sampler.pop_batch_op()

            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            images = tf.cast(data_dict['image'], tf.float32)
            noise_shape = [
                self.net_param.batch_size, self.gan_param.noise_size
            ]
            noise = tf.random_normal(shape=noise_shape,
                                     mean=0.0,
                                     stddev=1.0,
                                     dtype=tf.float32)
            conditioning = data_dict['conditioning']
            net_output = self.net(noise, images, conditioning,
                                  self.is_training)

            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            real_logits = net_output[1]
            fake_logits = net_output[2]
            lossG, lossD = loss_func(real_logits, fake_logits)
            if self.net_param.decay > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    # the same regularisation term is added to both losses
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(l_reg) for l_reg in reg_losses])
                    lossD = lossD + reg_loss
                    lossG = lossG + reg_loss

            # variables to display in STDOUT
            outputs_collector.add_to_collection(var=lossD,
                                                name='lossD',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=lossG,
                                                name='lossG',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            # variables to display in tensorboard
            outputs_collector.add_to_collection(var=lossG,
                                                name='lossG',
                                                average_over_devices=False,
                                                collection=TF_SUMMARIES)
            # the 'lossD' summary must track the discriminator loss
            # (it previously logged lossG under this name)
            outputs_collector.add_to_collection(var=lossD,
                                                name='lossD',
                                                average_over_devices=True,
                                                collection=TF_SUMMARIES)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            with tf.name_scope('ComputeGradients'):
                # gradients of generator
                generator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                generator_grads = self.optimiser.compute_gradients(
                    lossG, var_list=generator_variables)

                # gradients of discriminator
                discriminator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
                discriminator_grads = self.optimiser.compute_gradients(
                    lossD, var_list=discriminator_variables)
                grads = [generator_grads, discriminator_grads]

                # add the grads back to application_driver's training_grads
                gradients_collector.add_to_collection(grads)
        else:
            # inference: two sampler groups are used, one providing the
            # 'vector' input and one providing the conditioning windows
            data_dict = self.get_sampler()[0][0].pop_batch_op()
            conditioning_dict = self.get_sampler()[1][0].pop_batch_op()
            conditioning = conditioning_dict['conditioning']
            # placeholder image: conditioning shape with the last
            # dimension replaced by 1, filled with zeros
            image_size = conditioning.shape.as_list()[:-1]
            dummy_image = tf.zeros(image_size + [1])
            net_output = self.net(data_dict['vector'], dummy_image,
                                  conditioning, self.is_training)
            outputs_collector.add_to_collection(var=net_output[0],
                                                name='image',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=conditioning_dict['conditioning_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)

            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir)
Ejemplo n.º 4
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Connect the samplers to the autoencoder and register its outputs.

        Training mode (``self.is_training``): a batch is popped from the
        first sampler (or first/last depending on ``self.is_validation`` when
        ``validation_every_n > 0``), passed through the network, and the loss
        (plus the mean regularisation loss when ``decay > 0``) is
        differentiated.  The data loss is reported as
        ``variational_lower_bound`` and three network outputs are logged as
        image summaries named ``Originals``, ``Means`` and ``Variances``.

        Inference mode dispatches on ``self._infer_type``:
        ``encode`` / ``encode-decode`` run the network on sampled images,
        ``sample`` decodes random normal noise, ``linear_interpolation``
        decodes the sampled ``feature`` codes.  Each of them sets
        ``self.output_decoder`` to a ``WindowAsImageAggregator``.

        :param outputs_collector: collector whose ``add_to_collection`` is
            used to register console, tensorboard and network outputs
        :param gradients_collector: collector receiving ``[grads]``
            (training only)
        :raises NotImplementedError: for any other ``self._infer_type``
        """
        def next_batch(use_training_set):
            # first sampler when training, last sampler when validating
            scope = 'train' if use_training_set else 'validation'
            index = 0 if use_training_set else -1
            with tf.name_scope(scope):
                return self.get_sampler()[0][index].pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                batch = tf.cond(tf.logical_not(self.is_validation),
                                lambda: next_batch(True),
                                lambda: next_batch(False))
            else:
                batch = next_batch(True)

            inputs = tf.cast(batch['image'], tf.float32)
            outputs = self.net(inputs, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_factory = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_factory.get_instance(
                    learning_rate=self.action_param.lr)

            criterion = LossFunction(loss_type=self.action_param.loss_type)
            data_loss = criterion(outputs)
            total_loss = data_loss
            if self.net_param.decay > 0.0:
                penalties = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if penalties:
                    mean_penalty = tf.reduce_mean(
                        [tf.reduce_mean(penalty) for penalty in penalties])
                    total_loss = total_loss + mean_penalty
            # collecting gradients variables
            gradients_collector.add_to_collection(
                [self.optimiser.compute_gradients(total_loss)])

            # the data loss (without regularisation) is what gets reported
            outputs_collector.add_to_collection(var=data_loss,
                                                name='variational_lower_bound',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='variational_lower_bound',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)

            # (index into the network outputs, summary name)
            image_summaries = ((4, 'Originals'), (2, 'Means'),
                               (5, 'Variances'))
            for output_index, summary_name in image_summaries:
                outputs_collector.add_to_collection(
                    var=outputs[output_index],
                    name=summary_name,
                    average_over_devices=False,
                    summary_type='image3_coronal',
                    collection=TF_SUMMARIES)
            return

        def emit(tensor, label):
            # every inference output goes to the NETWORK_OUTPUT collection
            outputs_collector.add_to_collection(var=tensor,
                                                name=label,
                                                average_over_devices=True,
                                                collection=NETWORK_OUTPUT)

        def run_net_on_zeros():
            # call the network on an all-zero image of the window size
            zeros_shape = (self.net_param.batch_size,) + \
                          self.action_param.spatial_window_size + (1,)
            return self.net(tf.zeros(zeros_shape), is_training=False)

        def decode_means(code):
            shared = self.net.shared_decoder(code, is_training=False)
            return self.net.decoder_means(shared, is_training=False)

        infer_type = self._infer_type

        if infer_type in ('encode', 'encode-decode'):
            batch = self.get_sampler()[0][0].pop_batch_op()
            inputs = tf.cast(batch['image'], dtype=tf.float32)
            outputs = self.net(inputs, is_training=False)

            emit(batch['image_location'], 'location')
            if infer_type == 'encode-decode':
                emit(outputs[2], 'generated_image')
            else:
                emit(outputs[7], 'embedded')

            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir)
            return

        if infer_type == 'sample':
            outputs = run_net_on_zeros()
            # noise takes the static shape of the network's last output
            code_shape = outputs[-1].get_shape().as_list()
            noise = tf.random_normal(
                shape=code_shape,
                mean=0.0,
                stddev=self.autoencoder_param.noise_stddev,
                dtype=tf.float32)
            emit(decode_means(noise), 'generated_image')

            self.output_decoder = WindowAsImageAggregator(
                image_reader=None,
                output_path=self.action_param.save_seg_dir)
            return

        if infer_type == 'linear_interpolation':
            # construct the entire network
            outputs = run_net_on_zeros()
            batch = self.get_sampler()[0][0].pop_batch_op()
            # reshape the sampled codes to the network's last output shape
            code = tf.reshape(batch['feature'], outputs[-1].get_shape())
            emit(decode_means(code), 'generated_image')
            emit(batch['feature_location'], 'location')

            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir)
            return

        raise NotImplementedError
Ejemplo n.º 5
0
 def initialise_identity_aggregator(self):
     """Set ``self.output_decoder`` to a ``WindowAsImageAggregator``.

     The aggregator is built from the first reader, with the output
     directory and file postfix taken from ``self.action_param``.
     """
     first_reader = self.readers[0]
     params = self.action_param
     self.output_decoder = WindowAsImageAggregator(
         image_reader=first_reader,
         output_path=params.save_seg_dir,
         postfix=params.output_postfix)