Example #1
def up_block(lower_res_inputs,
             same_res_inputs,
             output_channels,
             kernel_shape,
             stride=1,
             rate=1,
             num_convs=2,
             initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
             regularizers=None,
             nonlinearity=tf.nn.relu,
             up_sampling_op=lambda x, size: tf.image.resize_images(x, size,
                                            method=tf.image.ResizeMethod.BILINEAR, align_corners=True),
             data_format='NCHW',
             name='up_block'):
    """A block made up of an up-sampling step followed by several convolutional layers."""
    with tf.variable_scope(name):
        if data_format == 'NHWC':
            # Spatial dims are axes 1 and 2 in NHWC.
            spatial_shape = same_res_inputs.get_shape()[1:3]
            features = up_sampling_op(lower_res_inputs, spatial_shape)
            features = tf.concat([features, same_res_inputs], axis=-1)
        else:
            # Spatial dims are axes 2 and 3 in NCHW; the resize op expects NHWC,
            # so transpose around the up-sampling step.
            spatial_shape = same_res_inputs.get_shape()[2:]
            lower_res_inputs = tf.transpose(lower_res_inputs, perm=[0,2,3,1])
            features = up_sampling_op(lower_res_inputs, spatial_shape)
            features = tf.transpose(features, perm=[0,3,1,2])
            features = tf.concat([features, same_res_inputs], axis=1)

        for _ in range(num_convs):
            features = snt.Conv2D(output_channels, kernel_shape, stride, rate, data_format=data_format,
                                  initializers=initializers, regularizers=regularizers)(features)
            features = nonlinearity(features)

        return features
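
# Usage sketch (not from the source): wiring up_block to merge a low-resolution
# decoder map with a same-resolution skip connection. Shapes are hypothetical,
# and graph-mode TF1 with the NCHW default plus the module's own imports
# (tf, snt, he_normal) are assumed.
lower = tf.placeholder(tf.float32, shape=(4, 64, 8, 8))    # N, C, H, W
skip = tf.placeholder(tf.float32, shape=(4, 32, 16, 16))
merged = up_block(lower, skip, output_channels=32, kernel_shape=3)
# merged: (4, 32, 16, 16) after bilinear up-sampling, channel-wise concat,
# and two 3x3 conv + ReLU layers.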
Example #2
    def load(self, experiment_dir):
        config_filename = os.path.join(experiment_dir, CONFIG_FILENAME)
        cf = SourceFileLoader('cf', config_filename).load_module()
        self._punet = ProbUNet(
            latent_dim=cf.latent_dim,
            num_channels=cf.num_channels,
            num_1x1_convs=cf.num_1x1_convs,
            num_classes=cf.num_classes,
            num_convs_per_block=cf.num_convs_per_block,
            initializers={
                'w': training_utils.he_normal(),
                'b': tf.truncated_normal_initializer(stddev=0.001)
            },
            regularizers={
                'w': tf.contrib.layers.l2_regularizer(1.0),
                'b': tf.contrib.layers.l2_regularizer(1.0)
            })
        self._x = tf.placeholder(tf.float32, shape=cf.network_input_shape)
        self._y = tf.placeholder(tf.uint8, shape=cf.label_shape)
        self._sigma_multiplier = tf.placeholder(tf.float32, shape=(6, ))
        # is_training=True to get posterior_net as well
        self._punet(self._x,
                    self._y,
                    is_training=True,
                    one_hot_labels=cf.one_hot_labels)

        # posterior-based inference
        self._posterior_latent_mu_op = self._punet._q.mean()
        self._posterior_latent_stddev_op = self._punet._q.stddev()
        self._posterior_latent_sample_op = self._punet._q.sample()
        self._posterior_sample_det_op = self._punet.reconstruct(
            z_q=self._posterior_latent_mu_op, softmax=True)
        self._posterior_sample_op = self._punet.reconstruct(
            z_q=self._posterior_latent_sample_op, softmax=True)
        self._posterior_external_sample_op = self._punet.reconstruct(
            z_q=self._posterior_latent_mu_op +
            self._sigma_multiplier * self._posterior_latent_stddev_op,
            softmax=True)

        # prior-based inference
        self._prior_latent_mu_op = self._punet._p.mean()
        self._prior_latent_stddev_op = self._punet._p.stddev()
        self._prior_latent_sample_op = self._punet._p.sample()
        self._prior_sample_det_op = self._punet.reconstruct(
            z_q=self._prior_latent_mu_op, softmax=True)
        self._prior_sample_op = self._punet.reconstruct(
            z_q=self._prior_latent_sample_op, softmax=True)
        self._prior_external_sample_op = self._punet.reconstruct(
            z_q=self._prior_latent_mu_op +
            self._sigma_multiplier * self._prior_latent_stddev_op,
            softmax=True)

        self._sampled_logits_op = self._punet.sample()

        saver = tf.train.Saver(save_relative_paths=True)
        self._session = tf.train.MonitoredTrainingSession()
        print("Experiment dir:", experiment_dir)
        latest_ckpt_path = tf.train.latest_checkpoint(experiment_dir)
        print("Loading model from:", latest_ckpt_path)
        saver.restore(self._session, latest_ckpt_path)
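
# A hedged sketch of driving the ops built by load(); the instance `model` and
# the batch arrays x_batch / y_batch are assumptions, not part of the source.
import numpy as np

feed = {model._x: x_batch, model._y: y_batch,
        model._sigma_multiplier: np.zeros(6, dtype=np.float32)}
# Deterministic reconstruction from the posterior mean.
det_seg = model._session.run(model._posterior_sample_det_op, feed_dict=feed)
# Shift the latent by one standard deviation along the first latent axis.
sigma = np.zeros(6, dtype=np.float32)
sigma[0] = 1.0
feed[model._sigma_multiplier] = sigma
shifted_seg = model._session.run(model._posterior_external_sample_op, feed_dict=feed)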
Example #3
def down_block(features,
               output_channels,
               kernel_shape,
               stride=1,
               rate=1,
               num_convs=2,
               initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
               regularizers=None,
               nonlinearity=tf.nn.relu,
               down_sample_input=True,
               down_sampling_op=lambda x, df: tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2],
                                                padding='SAME', data_format=df),
               data_format='NCHW',
               name='down_block'):
    """A block made up of a down-sampling step followed by several convolutional layers."""
    with tf.variable_scope(name):
        if down_sample_input:
            features = down_sampling_op(features, data_format)

        for _ in range(num_convs):
            features = snt.Conv2D(output_channels, kernel_shape, stride, rate, data_format=data_format,
                                  initializers=initializers, regularizers=regularizers)(features)
            features = nonlinearity(features)

        return features
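
# Usage sketch (hypothetical shapes, NCHW default and the module's imports
# assumed): the 2x2 average pooling halves the spatial dims, then num_convs
# conv + ReLU layers are applied.
features_in = tf.placeholder(tf.float32, shape=(4, 32, 32, 32))   # N, C, H, W
features_out = down_block(features_in, output_channels=64, kernel_shape=3)
# features_out: (4, 64, 16, 16)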
Example #4
    def __init__(self,
                 latent_dim,
                 num_channels,
                 nonlinearity=tf.nn.relu,
                 num_convs_per_block=3,
                 initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
                 regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)},
                 data_format='NCHW',
                 down_sampling_op=lambda x, df:\
                         tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2], padding='SAME', data_format=df),
                 name="conv_dist"):
        self._latent_dim = latent_dim
        self._initializers = initializers
        self._regularizers = regularizers
        self._data_format = data_format

        if data_format == 'NCHW':
            self._channel_axis = 1
            self._spatial_axes = [2,3]
        else:
            self._channel_axis = -1
            self._spatial_axes = [1,2]

        super(AxisAlignedConvGaussian, self).__init__(name=name)
        with self._enter_variable_scope():
            tf.logging.info('Building ConvGaussian.')
            self._encoder = VGG_Encoder(num_channels, nonlinearity, num_convs_per_block, initializers, regularizers,
                                        data_format=data_format, down_sampling_op=down_sampling_op)
Example #5
    def __init__(self,
                 num_channels,
                 num_classes,
                 nonlinearity=tf.nn.relu,
                 num_convs_per_block=3,
                 initializers={
                     'w': he_normal(),
                     'b': tf.truncated_normal_initializer(stddev=0.001)
                 },
                 regularizers={
                     'w': tf.contrib.layers.l2_regularizer(1.0),
                     'b': tf.contrib.layers.l2_regularizer(1.0)
                 },
                 data_format='NCHW',
                 up_sampling_op=lambda x, size: tf.image.resize_images(
                     x,
                     size,
                     method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                     align_corners=True),
                 name="vgg_dec"):
        super(VGG_Decoder, self).__init__(name=name)
        self._num_channels = num_channels
        self._num_classes = num_classes
        self._nonlinearity = nonlinearity
        self._num_convs = num_convs_per_block
        self._initializers = initializers
        self._regularizers = regularizers
        self._data_format = data_format
        self._up_sampling_op = up_sampling_op
Example #6
    def __init__(
            self,
            num_channels,
            nonlinearity=tf.nn.relu,
            num_convs_per_block=3,
            initializers={
                'w': he_normal(),
                'b': tf.truncated_normal_initializer(stddev=0.001)
            },
            regularizers={
                'w': tf.contrib.layers.l2_regularizer(1.0),
                'b': tf.contrib.layers.l2_regularizer(1.0)
            },
            data_format='NCHW',
            down_sampling_op=lambda x, df: tf.nn.avg_pool(x,
                                                          ksize=[1, 1, 2, 2],
                                                          strides=[1, 1, 2, 2],
                                                          padding='SAME',
                                                          data_format=df),
            name="vgg_enc"):
        super(VGG_Encoder, self).__init__(name=name)
        self._num_channels = num_channels
        self._nonlinearity = nonlinearity
        self._num_convs = num_convs_per_block
        self._initializers = initializers
        self._regularizers = regularizers
        self._data_format = data_format
        self._down_sampling_op = down_sampling_op
Example #7
def down_block(
        features,
        output_channels,
        kernel_shape,
        stride=1,
        rate=1,
        num_convs=2,
        initializers={
            'w': he_normal(),
            'b': tf.truncated_normal_initializer(stddev=0.001)
        },
        regularizers=None,
        nonlinearity=tf.nn.relu,
        down_sample_input=True,
        # The op is called as down_sampling_op(features, data_format) below, so
        # the default must accept a data_format argument; a resize-based default
        # taking a target size would fail at that call site.
        down_sampling_op=lambda x, df: tf.nn.avg_pool(
            x, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
            padding='SAME', data_format=df),
        data_format='NCHW',
        name='down_block'):
    """A block made up of a down-sampling step followed by several convolutional layers."""
    with tf.variable_scope(name):
        if down_sample_input:
            features = down_sampling_op(features, data_format)

        for _ in range(num_convs):
            features = snt.Conv2D(output_channels,
                                  kernel_shape,
                                  stride,
                                  rate,
                                  data_format=data_format,
                                  initializers=initializers,
                                  regularizers=regularizers)(features)
            features = nonlinearity(features)

        return features
Example #8
    def __init__(self,
                 latent_dim,
                 num_channels,
                 num_classes,
                 num_1x1_convs=3,
                 nonlinearity=tf.nn.relu,
                 num_convs_per_block=3,
                 initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
                 regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)},
                 data_format='NCHW',
                 down_sampling_op=lambda x, df:\
                         tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2], padding='SAME', data_format=df),
                 up_sampling_op=lambda x, size:\
                         tf.image.resize_images(x, size, method=tf.image.ResizeMethod.BILINEAR, align_corners=True),
                 name='prob_unet'):
        super(ProbUNet, self).__init__(name=name)
        self._data_format = data_format
        self._num_classes = num_classes

        with self._enter_variable_scope():
            self._unet = UNet(num_channels=num_channels,
                              num_classes=num_classes,
                              nonlinearity=nonlinearity,
                              num_convs_per_block=num_convs_per_block,
                              initializers=initializers,
                              regularizers=regularizers,
                              data_format=data_format,
                              down_sampling_op=down_sampling_op,
                              up_sampling_op=up_sampling_op)

            self._f_comb = Conv1x1Decoder(num_classes=num_classes,
                                          num_1x1_convs=num_1x1_convs,
                                          num_channels=num_channels[0],
                                          nonlinearity=nonlinearity,
                                          data_format=data_format,
                                          initializers=initializers,
                                          regularizers=regularizers)

            self._prior =\
                AxisAlignedConvGaussian(latent_dim=latent_dim, num_channels=num_channels,
                                        nonlinearity=nonlinearity, num_convs_per_block=num_convs_per_block,
                                        initializers=initializers, regularizers=regularizers, name='prior')

            self._posterior =\
                AxisAlignedConvGaussian(latent_dim=latent_dim, num_channels=num_channels,
                                        nonlinearity=nonlinearity, num_convs_per_block=num_convs_per_block,
                                        initializers=initializers, regularizers=regularizers, name='posterior')
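
# A minimal numpy sketch of the fusion step the 1x1 decoder (f_comb) presumably
# performs; the broadcast-and-concat below is an assumption for illustration,
# not the source implementation.
import numpy as np

features = np.zeros((2, 32, 16, 16), dtype=np.float32)   # U-Net output, NCHW
z = np.random.randn(2, 6).astype(np.float32)             # latent sample, latent_dim=6
z_map = np.tile(z[:, :, None, None], (1, 1, 16, 16))     # broadcast over the grid
fused = np.concatenate([features, z_map], axis=1)        # (2, 38, 16, 16), fed to 1x1 convs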
Example #9
    def __init__(self,
                 num_channels,
                 num_classes,
                 nonlinearity=tf.nn.relu,
                 num_convs_per_block=3,
                 initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
                 regularizers=None,
                 data_format='NCHW',
                 down_sampling_op=lambda x, df: tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2],
                                                padding='SAME', data_format=df),
                 up_sampling_op=lambda x, size: tf.image.resize_images(x, size,
                                                method=tf.image.ResizeMethod.BILINEAR, align_corners=True),
                 name="unet"):
        super(UNet, self).__init__(name=name)
        with self._enter_variable_scope():
            tf.logging.info('Building U-Net.')
            self._encoder = VGG_Encoder(num_channels, nonlinearity, num_convs_per_block, initializers, regularizers,
                                        data_format=data_format, down_sampling_op=down_sampling_op)
            self._decoder = VGG_Decoder(num_channels, num_classes, nonlinearity, num_convs_per_block, initializers,
                                        regularizers, data_format=data_format, up_sampling_op=up_sampling_op)
Example #10
def write_test_predictions(cf):
    """
    Write samples as numpy arrays.
    :param cf: config module
    :return:
    """
    # do not use all gpus
    os.environ["CUDA_VISIBLE_DEVICES"] = cf.cuda_visible_devices

    data_dir = os.path.join(cf.data_dir, cf.resolution)
    data_dict = loadFiles(label_density=cf.label_density,
                          split='val',
                          input_path=data_dir,
                          cities=None,
                          instance=False)
    # prepare out_dir
    if not os.path.isdir(cf.out_dir):
        os.mkdir(cf.out_dir)

    logging.info('Writing to {}'.format(cf.out_dir))

    # initialize computation graph
    prob_unet = ProbUNet(latent_dim=cf.latent_dim,
                         num_channels=cf.num_channels,
                         num_1x1_convs=cf.num_1x1_convs,
                         num_classes=cf.num_classes,
                         num_convs_per_block=cf.num_convs_per_block,
                         initializers={
                             'w': training_utils.he_normal(),
                             'b': tf.truncated_normal_initializer(stddev=0.001)
                         },
                         regularizers={
                             'w': tf.contrib.layers.l2_regularizer(1.0),
                             'b': tf.contrib.layers.l2_regularizer(1.0)
                         })
    x = tf.placeholder(tf.float32, shape=cf.network_input_shape)

    with tf.device(cf.gpu_device):
        prob_unet(x, is_training=False, one_hot_labels=cf.one_hot_labels)
        sampled_logits = prob_unet.sample()

    saver = tf.train.Saver(save_relative_paths=True)
    with tf.train.MonitoredTrainingSession() as sess:

        print('EXP DIR', cf.exp_dir)
        latest_ckpt_path = tf.train.latest_checkpoint(cf.exp_dir)
        print('CKPT PATH', latest_ckpt_path)
        saver.restore(sess, latest_ckpt_path)

        for k, v in tqdm(data_dict.items()):
            img = np.load(v['data']) / 255.
            # add batch dimension
            img = img[np.newaxis]

            for i in range(cf.num_samples):
                sample = sess.run(sampled_logits, feed_dict={x: img})
                sample = np.argmax(sample, axis=1)[:, np.newaxis]
                sample = sample.astype(np.uint8)
                sample_path = os.path.join(
                    cf.out_dir, '{}_sample{}_labelIds.npy'.format(k, i))
                np.save(sample_path, sample)
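
# Hypothetical consumer of the files written above: stack the samples saved for
# one image id `k` and flag pixels where the drawn segmentations disagree
# (`cf` and `k` as in write_test_predictions).
samples = np.stack([
    np.load(os.path.join(cf.out_dir, '{}_sample{}_labelIds.npy'.format(k, i)))
    for i in range(cf.num_samples)
])
label_maps = np.squeeze(samples)                        # (num_samples, H, W)
disagreement = (label_maps != label_maps[0]).any(axis=0)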
Example #11
def train(cf):
    """Perform training from scratch."""

    # do not use all gpus
    os.environ["CUDA_VISIBLE_DEVICES"] = cf.cuda_visible_devices

    # initialize data providers
    data_provider = get_train_generators(cf)
    train_provider = data_provider['train']
    val_provider = data_provider['val']

    prob_unet = ProbUNet(
        latent_dim=cf.latent_dim,
        num_channels=cf.num_channels,
        num_1x1_convs=cf.num_1x1_convs,
        num_classes=cf.num_classes,
        num_convs_per_block=cf.num_convs_per_block,
        initializers={
            'w': training_utils.he_normal(),
            'b': tf.truncated_normal_initializer(stddev=0.001)
        },
        regularizers={'w': tf.contrib.layers.l2_regularizer(1.0)})

    x = tf.placeholder(tf.float32, shape=cf.network_input_shape)
    y = tf.placeholder(tf.uint8, shape=cf.label_shape)
    mask = tf.placeholder(tf.uint8, shape=cf.loss_mask_shape)

    global_step = tf.train.get_or_create_global_step()

    if cf.learning_rate_schedule == 'piecewise_constant':
        learning_rate = tf.train.piecewise_constant(x=global_step,
                                                    **cf.learning_rate_kwargs)
    else:
        learning_rate = tf.train.exponential_decay(
            learning_rate=cf.initial_learning_rate,
            global_step=global_step,
            **cf.learning_rate_kwargs)
    with tf.device(cf.gpu_device):
        prob_unet(x, y, is_training=True, one_hot_labels=cf.one_hot_labels)
        elbo = prob_unet.elbo(y,
                              reconstruct_posterior_mean=cf.use_posterior_mean,
                              beta=cf.beta,
                              loss_mask=mask,
                              analytic_kl=cf.analytic_kl,
                              one_hot_labels=cf.one_hot_labels)
        reconstructed_logits = prob_unet._rec_logits
        sampled_logits = prob_unet.sample()

        reg_loss = cf.regularization_weight * tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        loss = -elbo + reg_loss
        rec_loss = prob_unet._rec_loss_mean
        kl = prob_unet._kl

        mean_val_rec_loss = tf.placeholder(tf.float32,
                                           shape=(),
                                           name="mean_val_rec_loss")
        mean_val_kl = tf.placeholder(tf.float32, shape=(), name="mean_val_kl")

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        loss, global_step=global_step)

    # prepare tf summaries
    train_elbo_summary = tf.summary.scalar('train_elbo', elbo)
    train_kl_summary = tf.summary.scalar('train_kl', kl)
    train_rec_loss_summary = tf.summary.scalar('rec_loss', rec_loss)
    train_loss_summary = tf.summary.scalar('train_loss', loss)
    reg_loss_summary = tf.summary.scalar('train_reg_loss', reg_loss)
    lr_summary = tf.summary.scalar('learning_rate', learning_rate)
    beta_summary = tf.summary.scalar('beta', cf.beta)
    training_summary_op = tf.summary.merge([
        train_loss_summary, reg_loss_summary, lr_summary, train_elbo_summary,
        train_kl_summary, train_rec_loss_summary, beta_summary
    ])
    batches_per_second = tf.placeholder(tf.float32,
                                        shape=(),
                                        name="batches_per_sec_placeholder")
    timing_summary = tf.summary.scalar('batches_per_sec', batches_per_second)
    val_rec_loss_summary = tf.summary.scalar('val_loss', mean_val_rec_loss)
    val_kl_summary = tf.summary.scalar('val_kl', mean_val_kl)
    validation_summary_op = tf.summary.merge(
        [val_rec_loss_summary, val_kl_summary])

    # Variable initialization is handled by the MonitoredTrainingSession below.

    # Add ops to save and restore all the variables.
    saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir=cf.exp_dir,
        save_steps=cf.save_every_n_steps,
        saver=tf.train.Saver(save_relative_paths=True))
    # save config
    shutil.copyfile(cf.config_path, os.path.join(cf.exp_dir, 'used_config.py'))

    with tf.train.MonitoredTrainingSession(hooks=[saver_hook]) as sess:
        summary_writer = tf.summary.FileWriter(cf.exp_dir, sess.graph)
        logging.info('Model: {}'.format(cf.exp_dir))

        for i in tqdm(range(cf.n_training_batches),
                      disable=cf.disable_progress_bar):

            start_time = time.time()
            train_batch = next(train_provider)
            _, train_summary = sess.run(
                [optimizer, training_summary_op],
                feed_dict={
                    x: train_batch['data'],
                    y: train_batch['seg'],
                    mask: train_batch['loss_mask']
                })
            summary_writer.add_summary(train_summary, i)
            time_delta = time.time() - start_time
            train_speed = sess.run(
                timing_summary,
                feed_dict={batches_per_second: 1. / time_delta})
            summary_writer.add_summary(train_speed, i)

            # validation
            if i % cf.validation['every_n_batches'] == 0:

                train_rec = sess.run(reconstructed_logits,
                                     feed_dict={
                                         x: train_batch['data'],
                                         y: train_batch['seg']
                                     })
                image_path = os.path.join(
                    cf.exp_dir, 'batch_{}_train_reconstructions.png'.format(
                        i // cf.validation['every_n_batches']))
                training_utils.plot_batch(train_batch,
                                          train_rec,
                                          num_classes=cf.num_classes,
                                          cmap=cf.color_map,
                                          out_dir=image_path)

                running_mean_val_rec_loss = 0.
                running_mean_val_kl = 0.

                for j in range(cf.validation['n_batches']):
                    val_batch = next(val_provider)
                    val_rec, val_sample, val_rec_loss, val_kl = sess.run(
                        [reconstructed_logits, sampled_logits, rec_loss, kl],
                        feed_dict={x: val_batch['data'], y: val_batch['seg'],
                                   mask: val_batch['loss_mask']})
                    running_mean_val_rec_loss += val_rec_loss / cf.validation[
                        'n_batches']
                    running_mean_val_kl += val_kl / cf.validation['n_batches']

                    if j == 0:
                        image_path = os.path.join(
                            cf.exp_dir,
                            'batch_{}_val_reconstructions.png'.format(
                                i // cf.validation['every_n_batches']))
                        training_utils.plot_batch(val_batch,
                                                  val_rec,
                                                  num_classes=cf.num_classes,
                                                  cmap=cf.color_map,
                                                  out_dir=image_path)
                        image_path = os.path.join(
                            cf.exp_dir, 'batch_{}_val_samples.png'.format(
                                i // cf.validation['every_n_batches']))

                        for _ in range(3):
                            val_sample_ = sess.run(sampled_logits,
                                                   feed_dict={
                                                       x: val_batch['data'],
                                                       y: val_batch['seg']
                                                   })
                            val_sample = np.concatenate(
                                [val_sample, val_sample_], axis=1)

                        training_utils.plot_batch(val_batch,
                                                  val_sample,
                                                  num_classes=cf.num_classes,
                                                  cmap=cf.color_map,
                                                  out_dir=image_path)

                val_summary = sess.run(validation_summary_op,
                                       feed_dict={
                                           mean_val_rec_loss:
                                           running_mean_val_rec_loss,
                                           mean_val_kl: running_mean_val_kl
                                       })
                summary_writer.add_summary(val_summary, i)

                if cf.disable_progress_bar:
                    logging.info('Evaluating batch {}/{}: validation loss={}, kl={}'
                        .format(i, cf.n_training_batches, running_mean_val_rec_loss, running_mean_val_kl))

            sess.run(global_step)
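
# Hypothetical entry point (the config path is an assumption), mirroring how
# load() in Example #2 imports its config module:
from importlib.machinery import SourceFileLoader

cf = SourceFileLoader('cf', '/path/to/used_config.py').load_module()
train(cf)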
Example #12
def sample(cf, args):
    """Sampling from the learnt conditional distribution."""

    sample_size = args.sample_size
    time_stamp = args.time_stamp
    ckpt_dir = os.path.join(cf.project_dir, 'experiments', time_stamp)
    sample_dir = os.path.join(cf.project_dir, 'samples', time_stamp)

    if not os.path.exists(sample_dir):
        os.mkdir(sample_dir)

    log_path = os.path.join(sample_dir, 'sampling_stat.log')
    logging.basicConfig(filename=log_path, level=logging.DEBUG, filemode='a')

    prl_dncnn = PRL(latent_dim=cf.latent_dim,
                    output_channels=cf.output_channels,
                    num_channels=cf.num_channels,
                    det_net_depth=cf.det_net_depth,
                    merging_depth=cf.merging_depth,
                    num_convs_per_block=cf.num_convs_per_block,
                    initializers={
                        'w': training_utils.he_normal(),
                        'b': tf.truncated_normal_initializer(stddev=0.001)
                    },
                    regularizers={'w': tf.contrib.layers.l2_regularizer(1.0)},
                    data_format=cf.data_format,
                    name='prl_dncnn')

    x = tf.placeholder(tf.float32,
                       shape=cf.network_input_shape,
                       name='observation')
    y = tf.placeholder(tf.float32,
                       shape=cf.network_output_shape,
                       name='ground_truth')
    is_training = tf.placeholder(tf.bool)

    prl_dncnn(x, y, is_training, is_inference=True)
    sampled_imgs = prl_dncnn.inference_sample(x)

    saver = tf.train.Saver()

    [val_data_noisy_list,
     val_data_clean_list] = test_data_list(img_dir=cf.validation_data_dir,
                                           noise_type=cf.noise_type,
                                           noise_param=cf.noise_param,
                                           data_format=cf.data_format)
    num_data = len(val_data_clean_list)

    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        for i in tqdm(range(num_data)):
            restored_samples = []
            sampling_start_time = time.time()
            for j in range(sample_size):
                smpl_img = sess.run(sampled_imgs,
                                    feed_dict={
                                        x: val_data_noisy_list[i],
                                        is_training: False
                                    })
                restored_samples.append(smpl_img)
            sampling_time_delta = time.time() - sampling_start_time
            restored_samples = np.asarray(restored_samples)

            if cf.data_format == 'NCHW':
                restored_samples = np.squeeze(restored_samples, axis=(1, 2))
            else:
                restored_samples = np.squeeze(restored_samples, axis=(1, 4))

            save_path = os.path.join(
                sample_dir,
                '{}_img{}_t{}_s{}.npy'.format(cf.validation_data_name, i,
                                              time_stamp, sample_size))
            np.save(save_path, restored_samples)
            logging.info(
                '{}s used for image {} of sample size {}, average time: {} s/sample'
                .format(sampling_time_delta, i, sample_size,
                        sampling_time_delta / sample_size))

    val_noisy_path = os.path.join(
        sample_dir, '{}_val_noisy.npy'.format(cf.validation_data_name))
    val_clean_path = os.path.join(
        sample_dir, '{}_val_clean.npy'.format(cf.validation_data_name))

    for i in range(len(val_data_noisy_list)):
        if cf.data_format == 'NCHW':
            val_data_noisy_list[i] = np.squeeze(val_data_noisy_list[i],
                                                axis=(0, 1))
            val_data_clean_list[i] = np.squeeze(val_data_clean_list[i],
                                                axis=(0, 1))
        else:
            val_data_noisy_list[i] = np.squeeze(val_data_noisy_list[i],
                                                axis=(0, 3))
            val_data_clean_list[i] = np.squeeze(val_data_clean_list[i],
                                                axis=(0, 3))

    np.save(val_noisy_path, np.asarray(val_data_noisy_list))
    np.save(val_clean_path, np.asarray(val_data_clean_list))
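
# Hedged post-processing sketch for one saved sample stack: `save_path` refers
# to a file written by the loop above, holding an array of shape
# (sample_size, H, W) after the squeeze.
stack = np.load(save_path)
mean_img = stack.mean(axis=0)
std_img = stack.std(axis=0)   # large values mark pixels the model is unsure about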
Example #13
def train(cf):
    """Perform training of PRL with DnCNN from scratch."""

    if cf.use_single_gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = cf.cuda_visible_devices

    train_dataset = train_generator(img_dir=cf.training_data_dir, data_format=cf.data_format,
                                    shuffle_every_n_epochs=cf.shuffle_every_n_epochs, batch_size=cf.batch_size,
                                    noise_type=cf.noise_type, noise_param=cf.noise_param)

    prl_dncnn = PRL(latent_dim=cf.latent_dim,
                    output_channels=cf.output_channels,
                    num_channels=cf.num_channels,
                    det_net_depth=cf.det_net_depth,
                    merging_depth=cf.merging_depth,
                    num_convs_per_block=cf.num_convs_per_block,
                    initializers={'w': training_utils.he_normal(),
                                  'b': tf.truncated_normal_initializer(stddev=0.001)},
                    regularizers={'w': tf.contrib.layers.l2_regularizer(1.0)},
                    data_format=cf.data_format,
                    name='prl_dncnn')

    x = tf.placeholder(tf.float32, shape=cf.network_input_shape, name='observation')
    y = tf.placeholder(tf.float32, shape=cf.network_output_shape, name='ground_truth')
    beta = tf.placeholder(tf.float32, shape=(), name='beta')
    is_training = tf.placeholder(tf.bool)

    global_step = tf.train.get_or_create_global_step()

    if cf.learning_rate_schedule == 'piecewise_constant':
        learning_rate = tf.train.piecewise_constant(x=global_step, **cf.learning_rate_kwargs)
    else:
        learning_rate = tf.train.exponential_decay(learning_rate=cf.initial_learning_rate,
                                                   global_step=global_step,
                                                   **cf.learning_rate_kwargs)

    prl_dncnn(x, y, is_training, is_inference=False)

    model_loss = prl_dncnn.loss_mini_batch(x, y, beta=beta, analytic_kl=cf.analytic_kl, use_ref_mean=cf.use_ref_mean)
    reg_loss = cf.regularization_weight * tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    loss = model_loss + reg_loss

    # prepare for summaries
    ref_rec_loss = prl_dncnn._rec_loss
    kl = prl_dncnn._kl_val

    ref_sample = prl_dncnn._ref_sample
    inf_sample = prl_dncnn.inference_sample(x)

    gt_unnormed = training_utils.img_unnorm(y, cf.noise_type, cf.noise_param)
    ref_unnormed = training_utils.img_unnorm(ref_sample, cf.noise_type, cf.noise_param)
    inf_unnormed = training_utils.img_unnorm(inf_sample, cf.noise_type, cf.noise_param)

    if cf.data_format == 'NCHW':
        gt_unnormed = tf.transpose(gt_unnormed, perm=(0, 2, 3, 1))
        ref_unnormed = tf.transpose(ref_unnormed, perm=(0, 2, 3, 1))
        inf_unnormed = tf.transpose(inf_unnormed, perm=(0, 2, 3, 1))

    # can be used for both training and validation
    # ref_avg_mse = tf.reduce_mean(tf.squared_difference(gt_unnormed, ref_unnormed))
    # ref_avg_ssim = tf.reduce_mean(tf.image.ssim(gt_unnormed, ref_unnormed, max_val=cf.peak_val))
    # ref_avg_psnr = tf.reduce_mean(tf.image.psnr(gt_unnormed, ref_unnormed, max_val=cf.peak_val))
    # tf.metrics.mean_squared_error returns a (value, update_op) pair and needs
    # local-variable initialization, so a plain squared difference is used instead.
    inf_avg_mse = tf.reduce_mean(tf.squared_difference(gt_unnormed, inf_unnormed))
    inf_avg_ssim = tf.reduce_mean(tf.image.ssim(gt_unnormed, inf_unnormed, max_val=cf.peak_val))
    inf_avg_psnr = tf.reduce_mean(tf.image.psnr(gt_unnormed, inf_unnormed, max_val=cf.peak_val))

    # ---------------------------- training summaries
    # loss summaries
    train_ref_rec_loss_summary = tf.summary.scalar('train_ref_rec_loss', ref_rec_loss)
    train_kl_summary = tf.summary.scalar('train_kl', kl)
    train_model_loss_summary = tf.summary.scalar('train_elbo', model_loss)
    train_reg_loss_summary = tf.summary.scalar('train_reg_loss', reg_loss)
    train_loss_summary = tf.summary.scalar('train_loss', loss)
    # quantitative indicator summaries (not in use during training, to save time)
    # train_ref_avg_mse_summary = tf.summary.scalar('train_reference_avg_mse', ref_avg_mse)
    # train_ref_avg_ssim_summary = tf.summary.scalar('train_reference_avg_ssim', ref_avg_ssim)
    # train_ref_avg_psnr_summary = tf.summary.scalar('train_reference_avg_psnr', ref_avg_psnr)
    # train_inf_avg_mse_summary = tf.summary.scalar('train_inference_avg_mse', inf_avg_mse)
    # train_inf_avg_ssim_summary = tf.summary.scalar('train_inference_avg_ssim', inf_avg_ssim)
    # train_inf_avg_psnr_summary = tf.summary.scalar('train_inference_avg_psnr', inf_avg_psnr)
    # hyper-parameter summaries
    lr_summary = tf.summary.scalar('learning_rate', learning_rate)
    beta_summary = tf.summary.scalar('kl_beta', beta)
    # merging summaries
    train_summary_op = tf.summary.merge([lr_summary, beta_summary,
                                         train_loss_summary,
                                         train_model_loss_summary,
                                         train_reg_loss_summary,
                                         train_ref_rec_loss_summary,
                                         train_kl_summary])

    # ---------------------------- timing summaries
    batches_per_second = tf.placeholder(tf.float32, shape=(), name='batches_per_second')
    timing_summary = tf.summary.scalar('batches_per_sec', batches_per_second)

    # ---------------------------- validation summaries
    val_avg_ref_rec_loss = tf.placeholder(tf.float32, shape=(), name='mean_val_ref_rec_loss')
    val_avg_kl = tf.placeholder(tf.float32, shape=(), name='mean_val_kl')
    # val_avg_ref_mse = tf.placeholder(tf.float32, shape=(), name='mean_val_ref_mse')
    # val_avg_ref_ssim = tf.placeholder(tf.float32, shape=(), name='mean_val_ref_ssim')
    # val_avg_ref_psnr = tf.placeholder(tf.float32, shape=(), name='mean_val_ref_psnr')
    val_avg_inf_mse = tf.placeholder(tf.float32, shape=(), name='mean_val_inf_mse')
    val_avg_inf_ssim = tf.placeholder(tf.float32, shape=(), name='mean_val_inf_ssim')
    val_avg_inf_psnr = tf.placeholder(tf.float32, shape=(), name='mean_val_inf_psnr')

    val_ref_rec_loss_summary = tf.summary.scalar('validation_ref_rec_loss', val_avg_ref_rec_loss)
    val_kl_summary = tf.summary.scalar('validation_kl', val_avg_kl)
    val_avg_inf_mse_summary = tf.summary.scalar('validation_avg_inf_mse', val_avg_inf_mse)
    val_avg_inf_ssim_summary = tf.summary.scalar('validation_avg_inf_ssim', val_avg_inf_ssim)
    val_avg_inf_psnr_summary = tf.summary.scalar('validation_avg_inf_psnr', val_avg_inf_psnr)

    validation_summary_op = tf.summary.merge([val_ref_rec_loss_summary,
                                              val_kl_summary,
                                              val_avg_inf_mse_summary,
                                              val_avg_inf_ssim_summary,
                                              val_avg_inf_psnr_summary])

    update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_op):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step=global_step)

    # Variable initialization is handled by the MonitoredTrainingSession below.

    saver_hook = tf.train.CheckpointSaverHook(checkpoint_dir=cf.experiment_dir,
                                              save_steps=cf.save_every_n_steps,
                                              saver=tf.train.Saver(save_relative_paths=True))

    shutil.copyfile(cf.config_path, os.path.join(cf.experiment_dir, 'used_config.py'))

    with tf.train.MonitoredTrainingSession(hooks=[saver_hook]) as sess:
        summary_writer = tf.summary.FileWriter(cf.experiment_dir, sess.graph)
        logging.info('Model: {}'.format(cf.experiment_dir))

        training_start_time = time.time()
        for i in tqdm(range(cf.num_training_batches), disable=cf.disable_progress_bar):
            sess_start_time = time.time()
            [train_data_noisy, train_data_clean] = next(train_dataset)
            _, train_summary = sess.run([optimizer, train_summary_op],
                                        feed_dict={x: train_data_noisy, y: train_data_clean,
                                                   beta: cf.kl_weight, is_training: True})
            summary_writer.add_summary(train_summary, i)
            sess_time_delta = time.time() - sess_start_time

            train_speed = sess.run(timing_summary, feed_dict={batches_per_second: 1. / sess_time_delta})
            summary_writer.add_summary(train_speed, i)

            if i % cf.batches_per_epoch == 0:
                running_avg_val_ref_rec_loss = 0.
                running_avg_val_kl = 0.
                running_avg_val_inf_mse = 0.
                running_avg_val_inf_ssim = 0.
                running_avg_val_inf_psnr = 0.

                [val_data_noisy_list, val_data_clean_list] = test_data_list(img_dir=cf.validation_data_dir,
                                                                            noise_type=cf.noise_type,
                                                                            noise_param=cf.noise_param,
                                                                            data_format=cf.data_format)
                num_val_data = len(val_data_clean_list)
                val_ref_img_list = []
                val_inf_img_list = []
                for j in range(num_val_data):
                    val_ref_img, val_inf_img, val_ref_rec_loss, val_kl, val_inf_mse, val_inf_ssim, val_inf_psnr = \
                        sess.run([ref_unnormed, inf_unnormed,
                                  ref_rec_loss, kl, inf_avg_mse, inf_avg_ssim, inf_avg_psnr],
                                 feed_dict={x: val_data_noisy_list[j], y: val_data_clean_list[j],
                                            beta: cf.kl_weight, is_training: False})

                    running_avg_val_ref_rec_loss += val_ref_rec_loss / num_val_data
                    running_avg_val_kl += val_kl / num_val_data
                    running_avg_val_inf_mse += val_inf_mse / num_val_data
                    running_avg_val_inf_ssim += val_inf_ssim / num_val_data
                    running_avg_val_inf_psnr += val_inf_psnr / num_val_data

                    val_ref_img_list.append(val_ref_img)
                    val_inf_img_list.append(val_inf_img)

                image_path = os.path.join(cf.experiment_image_dir,
                                          'epoch_{}_val_samples.png'.format(i//cf.batches_per_epoch))
                training_utils.save_sample_img(val_data_clean_list, val_ref_img_list, val_inf_img_list,
                                               img_path=image_path,
                                               noise_type=cf.noise_type, noise_param=cf.noise_param,
                                               colormap=cf.colormap)

                val_summary = sess.run(validation_summary_op,
                                       feed_dict={val_avg_ref_rec_loss: running_avg_val_ref_rec_loss,
                                                  val_avg_kl: running_avg_val_kl,
                                                  val_avg_inf_mse: running_avg_val_inf_mse,
                                                  val_avg_inf_ssim: running_avg_val_inf_ssim,
                                                  val_avg_inf_psnr: running_avg_val_inf_psnr})
                summary_writer.add_summary(val_summary, i)

                if cf.disable_progress_bar:
                    logging.info('Evaluating batch {}/{}: validation loss={}, kl={}'
                                 .format(i, cf.num_training_batches, running_avg_val_ref_rec_loss, running_avg_val_kl))
            sess.run(global_step)

        training_time_delta = time.time() - training_start_time
        logging.info('Total training time (including validation runs): %f s' % training_time_delta)
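
# Since beta is fed per step through a placeholder, a KL warm-up can be dropped
# in without graph changes; a hypothetical linear ramp (the source feeds the
# constant cf.kl_weight):
def beta_schedule(step, warmup_steps=10000, final_beta=1.0):
    return final_beta * min(1.0, step / float(warmup_steps))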