Beispiel #1
0
 def build_graph(self):
     """Build the SRN training graph on ``self.device``.

     Sets ``self.model``, ``self.g_loss``, ``self.global_step``,
     ``self.g_train_op`` and the pair ``self.train_summary`` /
     ``self.loss_summary``.
     """
     with tf.device(self.device):
         model = self.model = SRN(self.config)
         self.g_loss = model.build_train()
         step = self.global_step = tf.train.get_or_create_global_step()
         self.g_train_op = model.train(step)
         summaries = model.get_summaries()
         self.train_summary, self.loss_summary = summaries
Beispiel #2
0
def main(argv=None):
    """Parse command-line options and run ``Test`` on a trained SRN model.

    Args:
        argv: argument strings to parse; ``None`` parses ``sys.argv[1:]``
            (standard ``argparse`` behavior).
    """
    # arguments parsing
    import argparse
    # `--dtype` is an index into this list
    dtypes = [tf.int8, tf.float16, tf.float32, tf.float64]
    argp = argparse.ArgumentParser()
    # testing parameters
    argp.add_argument('dataset')
    argp.add_argument('--num-epochs', type=int, default=1)
    argp.add_argument('--random-seed', type=int)
    argp.add_argument('--device', default='/gpu:0')
    argp.add_argument('--postfix', default='')
    argp.add_argument('--train-dir', default='./train{postfix}.tmp')
    argp.add_argument('--test-dir', default='./test{postfix}.tmp')
    argp.add_argument('--log-file', default='test.log')
    argp.add_argument('--batch-size', type=int, default=1)
    # data parameters
    # `choices` turns an out-of-range index into a usage error; without it a
    # negative index silently picked a dtype from the end of the list and a
    # large one crashed later with an IndexError
    argp.add_argument('--dtype', type=int, default=2,
                      choices=range(len(dtypes)))
    argp.add_argument('--data-format', default='NCHW')
    argp.add_argument('--patch-height', type=int, default=512)
    argp.add_argument('--patch-width', type=int, default=512)
    argp.add_argument('--in-channels', type=int, default=3)
    argp.add_argument('--out-channels', type=int, default=3)
    # pre-processing parameters
    Data.add_arguments(argp)
    # model parameters
    SRN.add_arguments(argp)
    argp.add_argument('--scaling', type=int, default=1)
    # parse
    args = argp.parse_args(argv)
    # expand the `{postfix}` placeholder in the directory templates
    args.train_dir = args.train_dir.format(postfix=args.postfix)
    args.test_dir = args.test_dir.format(postfix=args.postfix)
    args.dtype = dtypes[args.dtype]
    args.pre_down = True
    # run testing
    test = Test(args)
    test()
Beispiel #3
0
 def build_graph(self):
     """Build the SRN inference graph, the test losses and PNG encoders.

     Sets ``self.model``, ``self.losses`` (list of loss tensors) and
     ``self.pngs`` (encoded inputs, then labels, then outputs).
     """
     with tf.device(self.device):
         # named placeholders, fed as 'inputs:0' / 'labels:0'
         inputs = tf.placeholder(tf.float32, name='inputs')
         labels = tf.placeholder(tf.float32, name='labels')
         model = self.model = SRN(self.config)
         outputs = model.build_model(inputs)
         self.losses = [loss for loss in test_losses(labels, outputs)]
     # post-processing for output, placed on the CPU
     with tf.device('/cpu:0'):
         images = (inputs, labels, outputs)
         if self.config.data_format == 'NCHW':
             # NCHW -> NHWC
             images = tuple(tf.transpose(image, [0, 2, 3, 1])
                            for image in images)
         batch_size = self.batch_size
         png_inputs, png_labels, png_outputs = (
             BatchPNG(image, batch_size) for image in images)
         self.pngs = png_inputs + png_labels + png_outputs
Beispiel #4
0
def main(argv=None):
    """Parse command-line options and run ``Train`` for an SRN model.

    Args:
        argv: argument strings to parse; ``None`` parses ``sys.argv[1:]``
            (standard ``argparse`` behavior).
    """
    # arguments parsing
    import argparse
    # `--dtype` is an index into this list
    dtypes = [tf.int8, tf.float16, tf.float32, tf.float64]
    argp = argparse.ArgumentParser()
    # training parameters
    argp.add_argument('dataset')
    argp.add_argument('--num-epochs', type=int, default=24)
    argp.add_argument('--max-steps', type=int)
    argp.add_argument('--random-seed', type=int)
    argp.add_argument('--device', default='/gpu:0')
    argp.add_argument('--postfix', default='')
    argp.add_argument('--pretrain-dir', default='')
    argp.add_argument('--train-dir', default='./train{postfix}.tmp')
    argp.add_argument('--restore', action='store_true')
    argp.add_argument('--save-steps', type=int, default=5000)
    argp.add_argument('--ckpt-period', type=int, default=600)
    argp.add_argument('--log-frequency', type=int, default=100)
    argp.add_argument('--log-file', default='train.log')
    argp.add_argument('--batch-size', type=int, default=32)
    argp.add_argument('--val-size', type=int, default=256)
    # data parameters
    # `choices` turns an out-of-range index into a usage error; without it a
    # negative index silently picked a dtype from the end of the list and a
    # large one crashed later with an IndexError
    argp.add_argument('--dtype', type=int, default=2,
                      choices=range(len(dtypes)))
    argp.add_argument('--data-format', default='NCHW')
    argp.add_argument('--patch-height', type=int, default=128)
    argp.add_argument('--patch-width', type=int, default=128)
    argp.add_argument('--in-channels', type=int, default=3)
    argp.add_argument('--out-channels', type=int, default=3)
    # pre-processing parameters
    Data.add_arguments(argp)
    # model parameters
    SRN.add_arguments(argp)
    argp.add_argument('--scaling', type=int, default=1)
    # parse
    args = argp.parse_args(argv)
    # expand the `{postfix}` placeholder in the directory template
    args.train_dir = args.train_dir.format(postfix=args.postfix)
    args.dtype = dtypes[args.dtype]
    # run training
    train = Train(args)
    train()
Beispiel #5
0
def main(argv=None):
    """Parse command-line options and export a trained SRN model via ``Graph``.

    Args:
        argv: argument strings to parse; ``None`` parses ``sys.argv[1:]``
            (standard ``argparse`` behavior).
    """
    # arguments parsing
    import argparse
    # `--dtype` is an index into this list
    dtypes = [tf.int8, tf.float16, tf.float32, tf.float64]
    argp = argparse.ArgumentParser()
    # testing parameters
    argp.add_argument('--postfix', default='')
    argp.add_argument('--train-dir', default='./train{postfix}.tmp')
    argp.add_argument('--model-dir', default='./model{postfix}.tmp')
    # data parameters
    # `choices` turns an out-of-range index into a usage error; without it a
    # negative index silently picked a dtype from the end of the list and a
    # large one crashed later with an IndexError
    argp.add_argument('--dtype', type=int, default=2,
                      choices=range(len(dtypes)))
    argp.add_argument('--data-format', default='NCHW')
    argp.add_argument('--in-channels', type=int, default=3)
    argp.add_argument('--out-channels', type=int, default=3)
    # model parameters
    SRN.add_arguments(argp)
    argp.add_argument('--scaling', type=int, default=1)
    # parse
    args = argp.parse_args(argv)
    # expand the `{postfix}` placeholder in the directory templates
    args.train_dir = args.train_dir.format(postfix=args.postfix)
    args.model_dir = args.model_dir.format(postfix=args.postfix)
    args.dtype = dtypes[args.dtype]
    # save model
    graph = Graph(args)
    graph()
Beispiel #6
0
class Graph:
    """Export a trained SRN model into ``model_dir``.

    Writes a text GraphDef (``model.graphdef``), restores the latest checkpoint
    from ``train_dir`` through the model's ``rvars`` mapping, then writes a
    MetaGraph (``model.meta``) and re-saves the variables through ``svars``.
    All attributes of ``config`` are copied onto the instance.
    """

    def __init__(self, config):
        self.postfix = None
        self.train_dir = None
        self.model_dir = None
        # copy all the properties from config object
        self.config = config
        self.__dict__.update(config.__dict__)

    def initialize(self):
        """Check ``train_dir`` exists and (re)create an empty ``model_dir``.

        An existing ``model_dir`` is removed only after interactive
        confirmation; any answer other than exactly 'Y' exits the program.

        Raises:
            FileNotFoundError: if ``train_dir`` does not exist.
        """
        # arXiv 1509.09308
        # a new class of fast algorithms for convolutional neural networks using Winograd's minimal filtering algorithms
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
        # create testing directory
        if not os.path.exists(self.train_dir):
            raise FileNotFoundError('Could not find folder {}'.format(
                self.train_dir))
        if os.path.exists(self.model_dir):
            # NOTE(review): '[Y/n]' conventionally means Enter = yes, but only
            # an exact 'Y' proceeds here
            eprint('Confirm removing {}\n[Y/n]'.format(self.model_dir))
            if input() != 'Y':
                import sys
                sys.exit()
            import shutil
            shutil.rmtree(self.model_dir, ignore_errors=True)
            eprint('Removed: ' + self.model_dir)
        os.makedirs(self.model_dir)

    def build_graph(self):
        """Instantiate the SRN network and build its inference graph."""
        self.model = SRN(self.config)
        self.model.build_model()

    def build_saver(self):
        """Create the restoring (``rvars``) and saving (``svars``) Savers."""
        # a Saver object to restore the variables with mappings
        self.saver_r = tf.train.Saver(self.model.rvars)
        # a Saver object to save the variables without mappings
        self.saver_s = tf.train.Saver(self.model.svars)

    def run(self, sess):
        """Restore the latest training checkpoint and export the model.

        Raises:
            FileNotFoundError: if ``train_dir`` contains no checkpoint.
        """
        # locate the checkpoint first: `latest_checkpoint` returns None when
        # there is none, and failing here avoids a half-written export
        ckpt = tf.train.latest_checkpoint(self.train_dir)
        if ckpt is None:
            raise FileNotFoundError('Could not find a checkpoint in {}'.format(
                self.train_dir))
        # save the GraphDef
        tf.train.write_graph(tf.get_default_graph(),
                             self.model_dir,
                             'model.graphdef',
                             as_text=True)
        # restore variables from checkpoint
        self.saver_r.restore(sess, ckpt)
        # save the model parameters
        self.saver_s.export_meta_graph(os.path.join(self.model_dir,
                                                    'model.meta'),
                                       as_text=False,
                                       clear_devices=True,
                                       clear_extraneous_savers=True)
        self.saver_s.save(sess,
                          os.path.join(self.model_dir, 'model'),
                          write_meta_graph=False,
                          write_state=False)

    def __call__(self):
        """Run the whole export in a fresh ``tf.Graph``."""
        self.initialize()
        with tf.Graph().as_default():
            self.build_graph()
            self.build_saver()
            with create_session() as sess:
                self.run(sess)
Beispiel #7
0
 def build_graph(self):
     """Instantiate the SRN network and build its inference graph."""
     srn = self.model = SRN(self.config)
     srn.build_model()
Beispiel #8
0
class Train:
    """Train an SRN model with periodic logging, validation and checkpointing.

    Every attribute of ``config`` (the parsed argument namespace) is copied
    onto the instance; the ``None`` assignments in ``__init__`` only declare
    the attributes this class reads.
    """

    def __init__(self, config):
        self.random_seed = None
        self.device = None
        self.postfix = None
        self.pretrain_dir = None
        self.train_dir = None
        self.restore = None
        self.save_steps = None
        self.ckpt_period = None
        self.log_frequency = None
        self.log_file = None
        self.batch_size = None
        self.val_size = None
        # copy all the properties from config object
        self.config = config
        self.__dict__.update(config.__dict__)

    def initialize(self):
        """Prepare the environment and the training directory.

        Unless restoring, an existing ``train_dir`` is removed after
        interactive confirmation (any answer other than exactly 'Y' exits)
        and recreated empty. Seeds the RNGs when ``random_seed`` is set.
        """
        # arXiv 1509.09308
        # a new class of fast algorithms for convolutional neural networks using Winograd's minimal filtering algorithms
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
        # create training directory
        if not self.restore:
            if os.path.exists(self.train_dir):
                # NOTE(review): '[Y/n]' conventionally means Enter = yes, but
                # only an exact 'Y' proceeds here
                eprint('Confirm removing {}\n[Y/n]'.format(self.train_dir))
                if input() != 'Y':
                    import sys
                    sys.exit()
                import shutil
                shutil.rmtree(self.train_dir, ignore_errors=True)
                eprint('Removed: ' + self.train_dir)
            os.makedirs(self.train_dir)
        # set deterministic random seed
        if self.random_seed is not None:
            reset_random(self.random_seed)

    def get_dataset(self):
        """Create the data source and pre-load the validation batches."""
        self.data = Data(self.config)
        self.epoch_steps = self.data.epoch_steps
        self.max_steps = self.data.max_steps
        # pre-computing validation set
        self.val_inputs = []
        self.val_labels = []
        for _inputs, _labels in self.data.gen_val():
            self.val_inputs.append(_inputs)
            self.val_labels.append(_labels)

    def build_graph(self):
        """Build the SRN training graph, global step and summaries."""
        with tf.device(self.device):
            self.model = SRN(self.config)
            self.g_loss = self.model.build_train()
            self.global_step = tf.train.get_or_create_global_step()
            self.g_train_op = self.model.train(self.global_step)
            self.train_summary, self.loss_summary = self.model.get_summaries()

    def build_saver(self):
        """Create the Saver objects and export the MetaGraph to ``train_dir``."""
        # a Saver object to restore the variables with mappings, used for
        # restoring from a pre-trained model. Created whenever a pre-trained
        # model is given: `run` also falls back to it when `restore` is set but
        # no checkpoint exists yet (previously an AttributeError).
        if self.pretrain_dir:
            self.saver_pt = tf.train.Saver(self.model.rvars)
        # a Saver object to save recent checkpoints
        self.saver_ckpt = tf.train.Saver(max_to_keep=5,
                                         save_relative_paths=True)
        # a Saver object to save the variables without mappings
        # used for saving checkpoints throughout the entire training progress
        self.saver = tf.train.Saver(self.model.svars,
                                    max_to_keep=1 << 16,
                                    save_relative_paths=True)
        # save the graph
        self.saver.export_meta_graph(os.path.join(self.train_dir,
                                                  'model.meta'),
                                     as_text=False,
                                     clear_devices=True,
                                     clear_extraneous_savers=True)

    def train_session(self):
        """Open the train/val summary writers and return a new session.

        The writers are closed by ``__call__`` once training ends.
        """
        self.train_writer = tf.summary.FileWriter(self.train_dir + '/train',
                                                  tf.get_default_graph(),
                                                  max_queue=20,
                                                  flush_secs=120)
        self.val_writer = tf.summary.FileWriter(self.train_dir + '/val')
        return create_session()

    def run_sess(self,
                 sess,
                 global_step,
                 data_gen,
                 options=None,
                 run_metadata=None):
        """Run one training step, plus logging and validation when due.

        Args:
            sess: the session to run in.
            global_step: value of the global step before this training step.
            data_gen: iterator yielding ``(inputs, labels)`` training batches.
            options: optional ``tf.RunOptions`` for the training run (profiling).
            run_metadata: optional ``tf.RunMetadata`` receiving trace data.
        """
        from datetime import datetime
        import time
        epoch = global_step // self.epoch_steps
        last_step = global_step + 1 >= self.max_steps
        # always log on the final step, otherwise every `log_frequency` steps
        logging = last_step or (self.log_frequency > 0
                                and global_step % self.log_frequency == 0)
        # training - train op
        inputs, labels = next(data_gen)
        feed_dict = {
            self.model.g_training: True,
            'Input:0': inputs,
            'Label:0': labels
        }
        if logging:
            fetch = (self.train_summary, self.g_train_op,
                     self.model.g_losses_acc)
            summary, _, _ = sess.run(fetch, feed_dict, options, run_metadata)
            self.train_writer.add_summary(summary, global_step)
        else:
            fetch = (self.g_train_op, self.model.g_losses_acc)
            sess.run(fetch, feed_dict, options, run_metadata)
        # training - log summary
        if logging:
            # loss summary
            fetch = [self.loss_summary] + self.model.g_log_losses
            summary, train_loss = sess.run(fetch)
            self.train_writer.add_summary(summary, global_step)
            # logging
            time_current = time.time()
            duration = time_current - self.log_last
            self.log_last = time_current
            sec_batch = duration / self.log_frequency if self.log_frequency > 0 else 0
            # sec_batch is 0 when log_frequency <= 0 (only the final step is
            # logged then); previously this divided by zero on the last step
            samples_sec = self.batch_size / sec_batch if sec_batch > 0 else 0.0
            train_log = ('{}: epoch {}, step {}, train loss: {:.5}'
                         ' ({:.1f} samples/sec, {:.3f} sec/batch)'.format(
                             datetime.now(), epoch, global_step, train_loss,
                             samples_sec, sec_batch))
            eprint(train_log)
        # validation
        if logging:
            for inputs, labels in zip(self.val_inputs, self.val_labels):
                feed_dict = {'Input:0': inputs, 'Label:0': labels}
                fetch = [self.model.g_losses_acc]
                sess.run(fetch, feed_dict)
            # loss summary
            fetch = [self.loss_summary] + self.model.g_log_losses
            summary, val_loss = sess.run(fetch)
            self.val_writer.add_summary(summary, global_step)
            # logging
            val_log = ('{}: epoch {}, step {}, val loss: {:.5}'.format(
                datetime.now(), epoch, global_step, val_loss))
            eprint(val_log)
        # log result for the last step (last_step implies logging, so
        # train_loss and val_loss are both defined here)
        if self.log_file and last_step:
            last_log = (
                'epoch {}, step {}, train loss: {:.5}, val loss: {:.5}'.format(
                    epoch, global_step, train_loss, val_loss))
            with open(self.log_file, 'a', encoding='utf-8') as fd:
                fd.write('Training No.{}\n'.format(self.postfix))
                fd.write(self.train_dir + '\n')
                fd.write('{}\n'.format(datetime.now()))
                fd.write(last_log + '\n\n')

    def run(self, sess):
        """Initialize or restore variables, then train until ``max_steps``.

        Variables come from the latest checkpoint (when ``restore`` is set and
        one exists), else from the pre-trained model, else fresh
        initialization. Periodically profiles, checkpoints and saves models.
        """
        import time
        # restore from checkpoint
        if self.restore and os.path.exists(
                os.path.join(self.train_dir, 'checkpoint')):
            latest_ckpt = tf.train.latest_checkpoint(self.train_dir,
                                                     'checkpoint')
            self.saver_ckpt.restore(sess, latest_ckpt)
        # restore pre-trained model
        elif self.pretrain_dir:
            # initialize everything first so variables possibly not covered by
            # the pre-trained mapping (`rvars`) still get values; the restore
            # then overwrites the mapped ones
            sess.run((tf.initializers.global_variables(),
                      tf.initializers.local_variables()))
            self.saver_pt.restore(sess, os.path.join(self.pretrain_dir,
                                                     'model'))
        # otherwise, initialize from start
        else:
            initializers = (tf.initializers.global_variables(),
                            tf.initializers.local_variables())
            sess.run(initializers)
        # profiler
        profile_offset = 1000 + self.log_frequency // 2
        profile_step = 10000
        builder = tf.profiler.ProfileOptionBuilder
        profiler = tf.profiler.Profiler(sess.graph)
        # initialization
        self.log_last = time.time()
        ckpt_last = time.time()
        # dataset generator
        global_step = tf.train.global_step(sess, self.global_step)
        data_gen = self.data.gen_main(global_step)
        # run training session
        while True:
            # global step
            global_step = tf.train.global_step(sess, self.global_step)
            if global_step >= self.max_steps:
                eprint('Training finished at step={}'.format(global_step))
                break
            # run session
            if global_step % profile_step == profile_offset:
                # profiling every few steps
                options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_meta = tf.RunMetadata()
                self.run_sess(sess, global_step, data_gen, options, run_meta)
                profiler.add_step(global_step, run_meta)
                # profile the parameters
                if global_step == profile_offset:
                    ofile = os.path.join(self.train_dir, 'parameters.log')
                    profiler.profile_name_scope(
                        builder(builder.trainable_variables_parameter()).
                        with_file_output(ofile).build())
                # profile the timing of model operations
                ofile = os.path.join(
                    self.train_dir,
                    'time_and_memory_{:0>7}.log'.format(global_step))
                profiler.profile_operations(
                    builder(builder.time_and_memory()).with_file_output(
                        ofile).build())
                # generate a timeline
                timeline = os.path.join(self.train_dir, 'timeline')
                profiler.profile_graph(
                    builder(builder.time_and_memory()).with_step(
                        global_step).with_timeline_output(timeline).build())
            else:
                self.run_sess(sess, global_step, data_gen)
            # save checkpoints periodically or when training finished
            if self.ckpt_period > 0:
                time_current = time.time()
                if time_current - ckpt_last >= self.ckpt_period or global_step + 1 >= self.max_steps:
                    ckpt_last = time_current
                    self.saver_ckpt.save(
                        sess, os.path.join(self.train_dir, 'model.ckpt'),
                        global_step, 'checkpoint')
            # save model every few steps
            if self.save_steps > 0 and global_step % self.save_steps == 0:
                self.saver.save(sess,
                                os.path.join(
                                    self.train_dir,
                                    'model_{:0>7}'.format(global_step)),
                                write_meta_graph=False,
                                write_state=False)
        # auto detect problems and generate advice
        ALL_ADVICE = {
            'ExpensiveOperationChecker': {},
            'AcceleratorUtilizationChecker': {},
            'JobChecker': {},
            'OperationChecker': {}
        }
        profiler.advise(ALL_ADVICE)

    def __call__(self):
        """Run the full training pipeline in a fresh ``tf.Graph``."""
        self.initialize()
        self.get_dataset()
        with tf.Graph().as_default():
            self.build_graph()
            self.build_saver()
            with self.train_session() as sess:
                try:
                    self.run(sess)
                finally:
                    # flush and release the event files; otherwise summaries
                    # still queued at the end of training can be lost
                    self.train_writer.close()
                    self.val_writer.close()
Beispiel #9
0
class Test:
    """Evaluate a trained SRN model on a test set.

    ``run_last`` evaluates the latest checkpoint and writes PNG images and a
    log entry. ``run_steps`` evaluates every ``model_<step>`` checkpoint and
    saves the statistics (``stats.npy``) and a plot (``stats.png``). All
    attributes of ``config`` are copied onto the instance.
    """

    def __init__(self, config):
        self.random_seed = None
        self.device = None
        self.postfix = None
        self.train_dir = None
        self.test_dir = None
        self.log_file = None
        self.batch_size = None
        # copy all the properties from config object
        self.config = config
        self.__dict__.update(config.__dict__)

    def initialize(self):
        """Check ``train_dir`` exists and (re)create an empty ``test_dir``.

        An existing ``test_dir`` is removed only after interactive
        confirmation; any answer other than exactly 'Y' exits the program.

        Raises:
            FileNotFoundError: if ``train_dir`` does not exist.
        """
        # arXiv 1509.09308
        # a new class of fast algorithms for convolutional neural networks using Winograd's minimal filtering algorithms
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
        # create testing directory
        if not os.path.exists(self.train_dir):
            raise FileNotFoundError('Could not find folder {}'.format(
                self.train_dir))
        if os.path.exists(self.test_dir):
            # NOTE(review): '[Y/n]' conventionally means Enter = yes, but only
            # an exact 'Y' proceeds here
            eprint('Confirm removing {}\n[Y/n]'.format(self.test_dir))
            if input() != 'Y':
                import sys
                sys.exit()
            import shutil
            shutil.rmtree(self.test_dir, ignore_errors=True)
            eprint('Removed: ' + self.test_dir)
        os.makedirs(self.test_dir)
        # set deterministic random seed
        if self.random_seed is not None:
            reset_random(self.random_seed)

    def get_dataset(self):
        """Create the data source and pre-load all test batches."""
        self.data = Data(self.config)
        self.epoch_steps = self.data.epoch_steps
        self.max_steps = self.data.max_steps
        # pre-computing testing set
        self.test_inputs = []
        self.test_labels = []
        data_gen = self.data.gen_main()
        for _inputs, _labels in data_gen:
            self.test_inputs.append(_inputs)
            self.test_labels.append(_labels)

    def build_graph(self):
        """Build the inference graph, test losses and PNG encoders."""
        with tf.device(self.device):
            inputs = tf.placeholder(tf.float32, name='inputs')
            labels = tf.placeholder(tf.float32, name='labels')
            self.model = SRN(self.config)
            outputs = self.model.build_model(inputs)
            self.losses = list(test_losses(labels, outputs))
        # post-processing for output
        with tf.device('/cpu:0'):
            # convert to NHWC format
            if self.config.data_format == 'NCHW':
                inputs = tf.transpose(inputs, [0, 2, 3, 1])
                labels = tf.transpose(labels, [0, 2, 3, 1])
                outputs = tf.transpose(outputs, [0, 2, 3, 1])
            # PNG output
            self.pngs = (BatchPNG(inputs, self.batch_size) +
                         BatchPNG(labels, self.batch_size) +
                         BatchPNG(outputs, self.batch_size))

    def build_saver(self):
        """Create the Saver that restores variables through ``rvars``."""
        # a Saver object to restore the variables with mappings
        self.saver = tf.train.Saver(self.model.rvars)

    def run_last(self, sess):
        """Evaluate the latest checkpoint and save the PNG images.

        Writes input/label/output PNGs into ``test_dir`` and, when
        ``log_file`` is set, appends PSNR and MAD to it.

        Raises:
            FileNotFoundError: if ``train_dir`` contains no checkpoint.
        """
        # latest checkpoint (`latest_checkpoint` returns None if there is none)
        ckpt = tf.train.latest_checkpoint(self.train_dir)
        if ckpt is None:
            raise FileNotFoundError('Could not find a checkpoint in {}'.format(
                self.train_dir))
        self.saver.restore(sess, ckpt)
        # to be fetched
        num_losses = len(self.losses)
        fetch = self.losses + self.pngs
        losses_sum = [0] * num_losses
        # run session
        for step in range(self.epoch_steps):
            feed_dict = {
                'inputs:0': self.test_inputs[step],
                'labels:0': self.test_labels[step]
            }
            ret = sess.run(fetch, feed_dict)
            ret_losses = ret[:num_losses]
            ret_pngs = ret[num_losses:]
            # sum of losses
            losses_sum = [s + l for s, l in zip(losses_sum, ret_losses)]
            # save images: sample indices covered by this batch
            _range = range(step * self.batch_size,
                           (step + 1) * self.batch_size)
            ofiles = (['{:0>5}.0.inputs.png'.format(i) for i in _range] +
                      ['{:0>5}.1.labels.png'.format(i) for i in _range] + [
                          '{:0>5}.2.outputs{}.png'.format(i, self.postfix)
                          for i in _range
                      ])
            ofiles = [os.path.join(self.test_dir, f) for f in ofiles]
            for i, png in enumerate(ret_pngs):
                with open(ofiles[i], 'wb') as fd:
                    fd.write(png)
        # summary
        if self.log_file:
            from datetime import datetime
            losses_mean = [l / self.epoch_steps for l in losses_sum]
            # PSNR with a peak value of 1; presumably losses[0] is the MSE
            # on a [0, 1] scale - TODO confirm against test_losses
            psnr = 10 * np.log10(1 /
                                 losses_mean[0]) if losses_mean[0] > 0 else 100
            test_log = 'PSNR (RGB):{}, MAD (RGB): {}'\
                .format(psnr, *losses_mean[1:])
            with open(self.log_file, 'a', encoding='utf-8') as fd:
                fd.write('Testing No.{}\n'.format(self.postfix))
                fd.write(self.test_dir + '\n')
                fd.write('{}\n'.format(datetime.now()))
                fd.write(test_log + '\n\n')

    def run_steps(self, sess):
        """Evaluate every ``model_<step>`` checkpoint and save stats and plot.

        Each row of ``stats.npy`` is ``[step, mean loss 0, mean loss 1, ...]``.
        Does nothing (apart from a message) when no such checkpoint exists.
        """
        import re
        prefix = 'model_'
        pattern = re.compile(re.escape(prefix) + r'(\d+)')
        # get checkpoints of every few steps
        files = listdir_files(self.train_dir,
                              recursive=False,
                              filter_ext=['.index'])
        ckpts = []
        for f in files:
            ckpt = os.path.splitext(f)[0]
            # match the file name only: a directory name containing the
            # prefix must not select unrelated checkpoints or step numbers
            match = pattern.search(os.path.basename(ckpt))
            if match:
                ckpts.append((int(match.group(1)), ckpt))
        if not ckpts:
            # np.stack would fail on an empty list below
            eprint('No checkpoints matching {}* found in {}'.format(
                prefix, self.train_dir))
            return
        # numeric order of training steps
        ckpts.sort()
        stats = []
        # test all the checkpoints
        for ckpt_num, ckpt in ckpts:
            self.saver.restore(sess, ckpt)
            # to be fetched
            fetch = self.losses
            losses_sum = [0] * len(self.losses)
            # run session
            for step in range(self.epoch_steps):
                feed_dict = {
                    'inputs:0': self.test_inputs[step],
                    'labels:0': self.test_labels[step]
                }
                ret_losses = sess.run(fetch, feed_dict)
                # sum of losses
                losses_sum = [s + l for s, l in zip(losses_sum, ret_losses)]
            # summary
            losses_mean = [l / self.epoch_steps for l in losses_sum]
            # stats
            stats.append(np.array([float(ckpt_num)] + losses_mean))
        # save stats
        import matplotlib.pyplot as plt
        stats = np.stack(stats)
        np.save(os.path.join(self.test_dir, 'stats.npy'), stats)
        # save plot
        fig, ax = plt.subplots()
        ax.set_title('Test Error with Training Progress')
        ax.set_xlabel('training steps')
        ax.set_ylabel('MAD (RGB)')
        ax.set_xscale('linear')
        ax.set_yscale('log')
        # the earliest checkpoint is left out of the plot
        stats = stats[1:]
        ax.plot(stats[:, 0], stats[:, 2])
        plt.tight_layout()
        plt.savefig(os.path.join(self.test_dir, 'stats.png'))
        plt.close(fig)

    def __call__(self):
        """Run the whole evaluation in a fresh ``tf.Graph``."""
        self.initialize()
        self.get_dataset()
        with tf.Graph().as_default():
            self.build_graph()
            self.build_saver()
            with create_session() as sess:
                self.run_last(sess)
                self.run_steps(sess)