Exemplo n.º 1
0
    def output_valstats(self,
                        sess,
                        summary_writer,
                        step,
                        batch_x,
                        batch_y,
                        name,
                        store_img=True,
                        get_loss_dict=False):
        """
        Evaluate the network on a validation batch, record the metrics and save the output.

        The reconstruction, the validation loss and the average PSNR are fetched in one
        ``sess.run`` with ``keep_prob=1.`` and ``phase=False``. The loss is fetched either
        as the scalar ``self.net.valid_loss`` or as the dict ``self.net.valid_loss_dict``.

        :param sess: active TensorFlow session
        :param summary_writer: writer forwarded to ``self.record_summary``
        :param step: step value attached to every recorded summary
        :param batch_x: validation inputs fed to ``self.net.x``
        :param batch_y: validation targets fed to ``self.net.y``
        :param name: base file name for ``<prediction_path>/<name>.mat`` / ``<name>_img.tif``
        :param store_img: if True, also save an image (layout chosen by ``SAVE_MODE``)
        :param get_loss_dict: if True, record every entry of ``valid_loss_dict`` as
            ``valid_<key>`` and use its ``'total_loss'`` entry for logging
        """
        # Pick which loss tensor(s) to evaluate before running the graph.
        loss_fetch = (self.net.valid_loss_dict if get_loss_dict
                      else self.net.valid_loss)
        feed = {
            self.net.x: batch_x,
            self.net.y: batch_y,
            self.net.keep_prob: 1.,
            self.net.phase: False
        }
        prediction, fetched_loss, avg_psnr = sess.run(
            [self.net.recons, loss_fetch, self.net.valid_avg_psnr],
            feed_dict=feed)

        if get_loss_dict:
            loss = fetched_loss['total_loss']
            for key, value in fetched_loss.items():
                self.record_summary(summary_writer, 'valid_' + key, value,
                                    step)
        else:
            loss = fetched_loss
            self.record_summary(summary_writer, 'valid_loss', loss, step)

        self.record_summary(summary_writer, 'valid_avg_psnr', avg_psnr, step)

        if SHORT_INFO:
            # Short mode logs only the step, not the metrics.
            logging.info("Iter {:}".format(step))
        else:
            logging.info(
                "Validation Statistics, validation loss= {:.4f}, Avg PSNR= {:.4f}"
                .format(loss, avg_psnr))

        util.save_mat(prediction, "%s/%s.mat" % (self.prediction_path, name))

        if not store_img:
            return

        img_path = "%s/%s_img.tif" % (self.prediction_path, name)
        if SAVE_MODE == 'Original':
            # Only the first sample of the batch is written.
            util.save_img(prediction[0, ...], img_path)
        elif SAVE_MODE == 'Xiaojian':
            # All samples of the batch are tiled into one image.
            util.save_img(util.concat_n_images(prediction), img_path)
Exemplo n.º 2
0
    def output_valstats(self, sess, summary_writer, step, batch_x, batch_y, name, store_img=True):
        """
        Evaluate the network on a validation batch, record/log the metrics and save the output.

        :param sess: active TensorFlow session
        :param summary_writer: writer forwarded to ``self.record_summary``
        :param step: step value attached to the recorded summaries
        :param batch_x: validation inputs fed to ``self.net.x``
        :param batch_y: validation targets fed to ``self.net.y``
        :param name: base file name for ``<prediction_path>/<name>.mat`` / ``<name>_img.tif``
        :param store_img: if True, also save the first sample of the batch as an image
        """
        fetches = [self.net.recons, self.net.valid_loss, self.net.valid_avg_psnr]
        # Evaluation mode: no dropout (keep_prob=1.) and phase=False.
        feed = {self.net.x: batch_x,
                self.net.y: batch_y,
                self.net.keep_prob: 1.,
                self.net.phase: False}
        prediction, loss, avg_psnr = sess.run(fetches, feed_dict=feed)

        for tag, value in (('valid_loss', loss), ('valid_avg_psnr', avg_psnr)):
            self.record_summary(summary_writer, tag, value, step)

        logging.info("Validation Statistics, validation loss= {:.4f}, Avg PSNR= {:.4f}".format(loss, avg_psnr))

        util.save_mat(prediction, "%s/%s.mat"%(self.prediction_path, name))

        if store_img:
            util.save_img(prediction[0,...], "%s/%s_img.tif"%(self.prediction_path, name))
Exemplo n.º 3
0
####              	  PREDICT                    ###
####################################################

# Collect one prediction per validation sample, then save them all to a .mat file.
predicts = []

# Load the whole validation set ('full'); samples are indexed along axis 0.
valid_x, valid_y = valid_provider('full')
num = valid_x.shape[0]

for i in range(num):

    # Progress banner for sample i, framed by blank lines.
    print('')
    print('')
    print('************* {} *************'.format(i))
    print('')
    print('')

    # The single validation sample is stacked in front of 23 freshly drawn
    # training samples before calling predict; only row 0 is kept below.
    # NOTE(review): presumably this pads to a fixed batch size the network
    # expects — confirm against net.predict.
    x_train, y_train = data_provider(23)
    x_input = valid_x[i:i+1,:,:,:]
    x_input = np.concatenate((x_input, x_train), axis=0)
    # NOTE(review): meaning of the positional args (1, True) is defined by
    # net.predict, which is not visible here.
    predict = net.predict(model_path, x_input, 1, True)
    # Keep only the prediction for the validation sample (first row).
    predicts.append(predict[0:1,:,:])

# Stack the per-sample (1, ...) slices along axis 0 and save.
predicts = np.concatenate(predicts, axis=0)
util.save_mat(predicts, 'test{}Noise.mat'.format(level))






Exemplo n.º 4
0
# Number of (fixed) validation samples to predict; when comp_train_result is
# set, the same number of training samples is predicted for comparison.
num = 25
# NOTE(review): `predicts` (and `train_results` when comp_train_result is set)
# are not initialised in this excerpt; presumably they are empty lists created
# earlier in the script — confirm.
valid_x, valid_y = valid_provider(num, fix = True)
if comp_train_result:
    data_x, data_y = data_provider(num)

for i in range(num):

    # Progress banner "i+1/num".
    print('')
    print('************* {}/{} *************'.format(i+1, num))
    print('')

    # Predict one validation sample at a time and keep its (1, ...) slice.
    x_input = valid_x[i:i+1,:,:,:]
    predict = net.predict(model_path, x_input, 1, False)
    predicts.append(predict[0:1,:,:])
    if comp_train_result:
        # Same prediction on the i-th training sample.
        x_train = data_x[i:i+1,:,:,:]
        train_result = net.predict(model_path, x_train, 1, False)
        train_results.append(train_result[0:1,:,:])

predicts = np.concatenate(predicts, axis=0)
util.save_mat(predicts, folder_path+'test_{}.mat'.format(folder))
if comp_train_result:
    # Stack the per-sample (1, ...) slices along axis 0 so the training results
    # are saved with the same layout as `predicts`, instead of passing a plain
    # Python list of arrays to util.save_mat.
    util.save_mat(np.concatenate(train_results, axis=0),
                  folder_path+'train_result_{}.mat'.format(folder))
Exemplo n.º 5
0
    def train(self,
              data_provider,
              output_path,
              valid_provider,
              valid_size,
              training_iters=100,
              epochs=1000,
              dropout=0.75,
              display_step=1,
              save_epoch=50,
              restore=False,
              write_graph=False,
              prediction_path='validation'):
        """
        Launches the training process.

        :param data_provider: callable returning a training batch; called as
            ``data_provider(batch_size)`` or, when ``Y_RAND`` is set, as
            ``data_provider(batch_size, rand_y=True)`` returning ``(x, y, y_rand)``
        :param output_path: path where to store checkpoints and summaries
        :param valid_provider: data provider for the validation dataset
        :param valid_size: batch size for validation provider
        :param training_iters: number of training mini batch iterations per epoch;
            must be >= 1 when ``epochs > 0``
        :param epochs: number of epochs
        :param dropout: value fed to ``self.net.keep_prob`` during training
            (validation feeds 1.), so it acts as a keep probability
        :param display_step: number of steps till outputting stats
        :param save_epoch: an extra checkpoint is written to ``<output_path>/<step>_cpkt/``
            whenever ``epoch % save_epoch == 0``
        :param restore: Flag if previous model should be restored
        :param write_graph: Flag if the computation graph should be written as protobuf file to the output path
        :param prediction_path: path where to save predictions on each epoch
        :return: ``<output_path>/final/model.cpkt`` when ``epochs == 0``, otherwise the
            value returned by the last ``self.net.save`` call for that path
        :raises ValueError: if ``epochs > 0`` and ``training_iters < 1``
        """
        if epochs > 0 and training_iters < 1:
            # With no steps per epoch, `lr`, `step` and `train_output` would never
            # be bound and the per-epoch bookkeeping below would crash.
            raise ValueError(
                "training_iters must be >= 1 when epochs > 0, got {}".format(
                    training_iters))

        # initialize the training process.
        init = self._initialize(training_iters, output_path, restore,
                                prediction_path)

        # create output path
        directory = os.path.join(output_path, "final/")
        if not os.path.exists(directory):
            os.makedirs(directory)

        save_path = os.path.join(directory, "model.cpkt")
        if epochs == 0:
            return save_path

        with tf.Session() as sess:
            if write_graph:
                tf.train.write_graph(sess.graph_def, output_path, "graph.pb",
                                     False)

            sess.run(init)

            if restore:
                ckpt = tf.train.get_checkpoint_state(output_path)
                if ckpt and ckpt.model_checkpoint_path:
                    self.net.restore(sess, ckpt.model_checkpoint_path)

            summary_writer = tf.summary.FileWriter(output_path,
                                                   graph=sess.graph)
            logging.info("Start optimization")

            # select validation dataset (fixed for the whole run)
            valid_x, valid_y = valid_provider(valid_size, fix=True)
            if SAVE_MODE == 'Original':
                util.save_mat(valid_y,
                              "%s/%s.mat" % (self.prediction_path, 'origin_y'))
                util.save_mat(valid_x,
                              "%s/%s.mat" % (self.prediction_path, 'origin_x'))
            elif SAVE_MODE == 'Xiaojian':
                # Tile the validation inputs/targets into single images.
                imgx = util.concat_n_images(valid_x)
                imgy = util.concat_n_images(valid_y)
                util.save_img(
                    imgx, "%s/%s_img.tif" % (self.prediction_path, 'trainOb'))
                util.save_img(
                    imgy, "%s/%s_img.tif" % (self.prediction_path, 'trainGt'))

            # Either the dict of named losses or the single scalar loss is fetched;
            # the rest of the training step is identical for both.
            use_loss_dict = self.net.get_loss_dict
            loss_fetch = (self.net.loss_dict if use_loss_dict
                          else self.net.loss)

            for epoch in range(epochs):
                total_loss = 0
                for step in range(epoch * training_iters,
                                  (epoch + 1) * training_iters):
                    if Y_RAND:
                        batch_x, batch_y, batch_y_rand = data_provider(
                            self.batch_size, rand_y=True)
                    else:
                        batch_x, batch_y = data_provider(self.batch_size)
                        batch_y_rand = batch_y

                    # Run optimization op (backprop); loss/PSNR/recons are
                    # fetched in the same run as the update.
                    _, fetched_loss, lr, avg_psnr, train_output = sess.run(
                        [
                            self.optimizer, loss_fetch,
                            self.learning_rate_node, self.net.avg_psnr,
                            self.net.recons
                        ],
                        feed_dict={
                            self.net.x: batch_x,
                            self.net.y: batch_y,
                            self.net.yRand: batch_y_rand,
                            self.net.keep_prob: dropout,
                            self.net.phase: True
                        })
                    if use_loss_dict:
                        loss_dict = fetched_loss
                        loss = loss_dict['total_loss']
                    else:
                        loss = fetched_loss

                    if step % display_step == 0:
                        logging.info(
                            "Iter {:} (before training on the batch) Minibatch MSE= {:.4f}, Minibatch Avg PSNR= {:.4f}"
                            .format(step, loss, avg_psnr))
                        self.output_minibatch_stats(sess, summary_writer, step,
                                                    batch_x, batch_y)

                    total_loss += loss

                    if use_loss_dict:
                        for loss_name, loss_value in loss_dict.items():
                            self.record_summary(summary_writer,
                                                'training_' + loss_name,
                                                loss_value, step)
                    else:
                        self.record_summary(summary_writer, 'training_loss',
                                            loss, step)

                    self.record_summary(summary_writer, 'training_avg_psnr',
                                        avg_psnr, step)

                # output statistics for epoch (`step` is the epoch's last step)
                self.output_epoch_stats(epoch, total_loss, training_iters, lr)
                self.output_valstats(sess,
                                     summary_writer,
                                     step,
                                     valid_x,
                                     valid_y,
                                     "epoch_%s_valid" % epoch,
                                     store_img=True,
                                     get_loss_dict=use_loss_dict)

                if SAVE_TRAIN_PRED:
                    if SAVE_MODE == 'Original':
                        # First sample of the epoch's last training batch.
                        util.save_img(
                            train_output[0, ...], "%s/%s_img.tif" %
                            (self.prediction_path, "epoch_%s_train" % epoch))
                    elif SAVE_MODE == 'Xiaojian':
                        # Hands the epoch's last training batch to the helper.
                        self.output_train_batch_stats(sess, epoch, batch_x,
                                                      batch_y)

                if epoch % save_epoch == 0:
                    # Periodic snapshot next to the rolling "final" checkpoint.
                    ckpt_dir = os.path.join(output_path,
                                            "{}_cpkt/".format(step))
                    if not os.path.exists(ckpt_dir):
                        os.makedirs(ckpt_dir)
                    self.net.save(sess, os.path.join(ckpt_dir, "model.cpkt"))

                save_path = self.net.save(sess, save_path)

            logging.info("Optimization Finished!")

            return save_path
Exemplo n.º 6
0
    def train(self, data_provider, output_path, valid_provider, valid_size, training_iters=100, epochs=1000, dropout=0.75, display_step=1, save_epoch=50, restore=False, write_graph=False, prediction_path='validation'):
        """
        Launches the training process.

        :param data_provider: callable returning a training batch, called as ``data_provider(batch_size)``
        :param output_path: path where to store checkpoints and summaries
        :param valid_provider: data provider for the validation dataset
        :param valid_size: batch size for validation provider
        :param training_iters: number of training mini batch iterations per epoch; must be >= 1 when ``epochs > 0``
        :param epochs: number of epochs
        :param dropout: value fed to ``self.net.keep_prob`` during training (validation feeds 1.), so it acts as a keep probability
        :param display_step: number of steps till outputting stats
        :param save_epoch: an extra checkpoint is written to ``<output_path>/<step>_cpkt/`` whenever ``epoch % save_epoch == 0``
        :param restore: Flag if previous model should be restored
        :param write_graph: Flag if the computation graph should be written as protobuf file to the output path
        :param prediction_path: path where to save predictions on each epoch
        :return: ``<output_path>/final/model.cpkt`` when ``epochs == 0``, otherwise the value returned by the last ``self.net.save`` call for that path
        :raises ValueError: if ``epochs > 0`` and ``training_iters < 1``
        """
        if epochs > 0 and training_iters < 1:
            # With no steps per epoch, `lr` and `step` would never be bound and
            # the per-epoch bookkeeping below would crash.
            raise ValueError("training_iters must be >= 1 when epochs > 0, got {}".format(training_iters))

        # initialize the training process.
        init = self._initialize(training_iters, output_path, restore, prediction_path)

        # create output path
        directory = os.path.join(output_path, "final/")
        if not os.path.exists(directory):
            os.makedirs(directory)

        save_path = os.path.join(directory, "model.cpkt")
        if epochs == 0:
            return save_path

        with tf.Session() as sess:
            if write_graph:
                tf.train.write_graph(sess.graph_def, output_path, "graph.pb", False)

            sess.run(init)

            if restore:
                ckpt = tf.train.get_checkpoint_state(output_path)
                if ckpt and ckpt.model_checkpoint_path:
                    self.net.restore(sess, ckpt.model_checkpoint_path)

            summary_writer = tf.summary.FileWriter(output_path, graph=sess.graph)
            logging.info("Start optimization")

            # select validation dataset (fixed for the whole run) and save it once
            valid_x, valid_y = valid_provider(valid_size, fix=True)
            util.save_mat(valid_y, "%s/%s.mat"%(self.prediction_path, 'origin_y'))
            util.save_mat(valid_x, "%s/%s.mat"%(self.prediction_path, 'origin_x'))

            for epoch in range(epochs):
                total_loss = 0
                for step in range(epoch*training_iters, (epoch+1)*training_iters):
                    batch_x, batch_y = data_provider(self.batch_size)
                    # Run optimization op (backprop); loss and PSNR are fetched in the same run as the update
                    _, loss, lr, avg_psnr = sess.run([self.optimizer,
                                                      self.net.loss,
                                                      self.learning_rate_node,
                                                      self.net.avg_psnr],
                                                     feed_dict={self.net.x: batch_x,
                                                                self.net.y: batch_y,
                                                                self.net.keep_prob: dropout,
                                                                self.net.phase: True})

                    if step % display_step == 0:
                        logging.info("Iter {:} (before training on the batch) Minibatch MSE= {:.4f}, Minibatch Avg PSNR= {:.4f}".format(step, loss, avg_psnr))
                        self.output_minibatch_stats(sess, summary_writer, step, batch_x, batch_y)

                    total_loss += loss

                    self.record_summary(summary_writer, 'training_loss', loss, step)
                    self.record_summary(summary_writer, 'training_avg_psnr', avg_psnr, step)

                # output statistics for epoch (`step` is the epoch's last step)
                self.output_epoch_stats(epoch, total_loss, training_iters, lr)
                self.output_valstats(sess, summary_writer, step, valid_x, valid_y, "epoch_%s"%epoch, store_img=True)

                if epoch % save_epoch == 0:
                    # Periodic snapshot next to the rolling "final" checkpoint.
                    ckpt_dir = os.path.join(output_path, "{}_cpkt/".format(step))
                    if not os.path.exists(ckpt_dir):
                        os.makedirs(ckpt_dir)
                    self.net.save(sess, os.path.join(ckpt_dir, "model.cpkt"))

                save_path = self.net.save(sess, save_path)

            logging.info("Optimization Finished!")

            return save_path