def bench(trainset_path=r'..\Dataset.SR\Train'):
    """Benchmark the throughput of the training input pipeline.

    Lists the images under ``trainset_path``, builds the ``inputs`` pipeline
    on the CPU and evaluates the (low-res, high-res) batch pair repeatedly,
    printing the elapsed time and samples per second every
    ``FLAGS.log_frequency`` steps.

    The run stops after ``FLAGS.num_epochs`` epochs worth of batches, or
    earlier if the monitored session requests a stop.

    Args:
        trainset_path: directory scanned for ``.jpeg``/``.jpg``/``.png``
            files. Defaults to the previously hard-coded relative path.
    """
    # inputs
    files = helper.listdir_files(trainset_path,
                                 filter_ext=['.jpeg', '.jpg', '.png'],
                                 encoding=True)
    steps_per_epoch = len(files) // FLAGS.batch_size
    epoch_size = steps_per_epoch * FLAGS.batch_size
    max_steps = steps_per_epoch * FLAGS.num_epochs
    # keep only whole batches
    files = files[:epoch_size]

    with tf.device('/cpu:0'):
        images_lr, images_hr = inputs(files, is_training=True)

    # session
    config = tf.ConfigProto(log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.train.MonitoredTrainingSession(
            config=config, log_step_count_steps=FLAGS.log_frequency) as sess:
        step = 0
        t = time.time()
        # max_steps used to be computed but never used, so the loop only
        # ended when the session asked to stop; bound it explicitly
        while step < max_steps and not sess.should_stop():
            if step % FLAGS.log_frequency == 0:
                # note: the step-0 report times session start-up, not batches
                c_t = time.time()
                duration = c_t - t
                # guard against a zero interval from a coarse clock
                sps = FLAGS.batch_size * FLAGS.log_frequency / duration if duration > 0 else 0
                print('step {}: {} seconds, {} samples per second'.format(
                    step, duration, sps))
                t = c_t
            sess.run([images_lr, images_hr])
            step += 1
Example #2
0
def train():
    """Train the image-compression model (``ICmodel``).

    Shuffles the image files in ``FLAGS.dataset``, trims them to whole
    batches and runs a ``MonitoredTrainingSession`` for ``FLAGS.num_epochs``
    epochs, with optional Chrome-trace timelines and periodic checkpoints.

    When ``FLAGS.lr_decay_steps < 0`` and ``FLAGS.lr_decay_factor != 0`` a
    fixed validation set is pre-fetched, the generator loss on it is
    evaluated every ``|FLAGS.lr_decay_steps|`` steps, and the generator
    learning rate is decayed whenever the mean + median of the most recent
    window of validation losses is not lower than that of the previous
    window.
    """
    import random
    files = helper.listdir_files(FLAGS.dataset,
                                 filter_ext=['.jpeg', '.jpg', '.png'])
    random.shuffle(files)
    steps_per_epoch = len(files) // FLAGS.batch_size
    epoch_size = steps_per_epoch * FLAGS.batch_size
    max_steps = steps_per_epoch * FLAGS.num_epochs
    files = files[:epoch_size]
    print('epoch size: {}\n{} steps per epoch\n{} epochs\n{} steps'.format(
        epoch_size, steps_per_epoch, FLAGS.num_epochs, max_steps))

    # validation-driven learning-rate decay is enabled by a negative
    # lr_decay_steps (its magnitude is the validation interval) together
    # with a non-zero decay factor
    use_val_decay = FLAGS.lr_decay_steps < 0 and FLAGS.lr_decay_factor != 0

    # validation set
    if use_val_decay:
        # about a tenth of an epoch, capped at 50 batches, and at least one
        # batch so the stride below never divides by zero
        val_size = min(FLAGS.batch_size * 50,
                       epoch_size // (10 * FLAGS.batch_size) * FLAGS.batch_size)
        val_size = max(val_size, FLAGS.batch_size)
        val_batches = val_size // FLAGS.batch_size
        # evenly spaced subset of the (shuffled) training files
        val_files = files[::(epoch_size + val_size - 1) // val_size]
        val_src_batches = []
        val_losses = []
        with tf.Graph().as_default():
            # dataset
            with tf.device('/cpu:0'):
                val_src = inputs(FLAGS, val_files, is_training=True)
            # session
            gpu_options = tf.GPUOptions(allow_growth=True)
            config = tf.ConfigProto(gpu_options=gpu_options)
            with tf.Session(config=config) as sess:
                # fetch the batches once so every later evaluation is fed
                # exactly the same data
                for _ in range(val_batches):
                    val_src_batches.append(sess.run(val_src))

    # main training graph
    with tf.Graph().as_default():
        # pre-processing for input
        with tf.device('/cpu:0'):
            images_src = inputs(FLAGS, files, is_training=True)

        # build model
        model = ICmodel(FLAGS, data_format=FLAGS.data_format,
            input_range=FLAGS.input_range, output_range=FLAGS.output_range,
            multiGPU=FLAGS.multiGPU, use_fp16=FLAGS.use_fp16,
            image_channels=FLAGS.image_channels, input_height=FLAGS.patch_height,
            input_width=FLAGS.patch_width, batch_size=FLAGS.batch_size)

        g_loss, d_loss = model.build_train(images_src)

        # lr decay operator
        def _get_val_window(lr, lr_last, lr_decay_op):
            """Create the adaptive validation-window variables.

            Returns ``(val_window, val_window_op)``: a float64 variable
            holding the number of validation losses compared per window
            (initially 30), and an op that, after ``lr_decay_op``, updates the
            window increment according to how far ``lr``/``lr_last`` have
            fallen relative to ``FLAGS.learning_rate * 0.1`` and adds it to
            the window.
            """
            with tf.variable_scope('validation_window'):
                val_window = tf.Variable(30.0, trainable=False, dtype=tf.float64,
                    name='validation_window_size')
                val_window_inc_base = 10.0 * np.log(1 - FLAGS.lr_decay_factor) / np.log(0.5)
                val_window_inc = tf.Variable(val_window_inc_base, trainable=False,
                    dtype=tf.float64, name='validation_window_inc')
                tf.summary.scalar('val_window', val_window)
                tf.summary.scalar('val_window_inc', val_window_inc)
                with tf.control_dependencies([lr_decay_op]):
                    def f1_t(): # lr > learning_rate * 0.1
                        return tf.assign(val_window_inc, val_window_inc * 0.9, use_locking=True)
                    def f2_t(): # lr_last > learning_rate * 0.1 >= lr
                        return tf.assign(val_window_inc, val_window_inc_base, use_locking=True)
                    def f2_f(): # learning_rate * 0.1 >= lr_last
                        return tf.assign(val_window_inc, val_window_inc * 0.95, use_locking=True)
                    # the updated increment gets its own name so the branch
                    # closures above always refer to the Variable
                    new_inc = tf.cond(lr > FLAGS.learning_rate * 0.1, f1_t,
                        lambda: tf.cond(lr_last > FLAGS.learning_rate * 0.1, f2_t, f2_f))
                    val_window_op = tf.assign_add(val_window, new_inc, use_locking=True)
                return val_window, val_window_op

        if use_val_decay:
            g_lr_decay_op = model.lr_decay()
            val_window, val_window_op = _get_val_window(model.g_lr, model.g_lr_last, g_lr_decay_op)

        # training step and op
        global_step = tf.train.get_or_create_global_step()
        g_train_op = model.train(global_step)

        # a saver object which will save all the variables
        saver = tf.train.Saver(var_list=model.g_svars,
            max_to_keep=1 << 16, save_relative_paths=True)

        # separate saver restoring model.g_rvars from the pre-trained model
        if FLAGS.pretrain_dir and not FLAGS.restore:
            saver0 = tf.train.Saver(var_list=model.g_rvars)

        # save the graph
        saver.export_meta_graph(os.path.join(FLAGS.train_dir, 'model.meta'),
            as_text=False, clear_devices=True, clear_extraneous_savers=True)

        # monitored session
        gpu_options = tf.GPUOptions(allow_growth=True)
        config = tf.ConfigProto(gpu_options=gpu_options,
            log_device_placement=FLAGS.log_device_placement)

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[tf.train.StopAtStepHook(last_step=max_steps),
                       tf.train.NanTensorHook(g_loss),
                       tf.train.NanTensorHook(d_loss),
                       LoggerHook([g_loss, d_loss], steps_per_epoch)],
                config=config, log_step_count_steps=FLAGS.log_frequency) as mon_sess:
            # options
            sess = helper.get_session(mon_sess)
            if FLAGS.timeline_steps > 0:
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
            # restore pre-trained model
            if FLAGS.pretrain_dir and not FLAGS.restore:
                saver0.restore(sess, os.path.join(FLAGS.pretrain_dir, 'model'))
            # validation-window state: val_window only exists when
            # validation-driven decay is enabled (reading it unconditionally
            # raised NameError otherwise)
            if use_val_decay:
                val_window_ = int(np.round(sess.run(val_window)))
                lr_decay_last = val_window_
            # training session call
            def run_sess(options=None, run_metadata=None):
                mon_sess.run(g_train_op, options=options, run_metadata=run_metadata)
            # run session
            while not mon_sess.should_stop():
                global_step_ = tf.train.global_step(sess, global_step)
                # collect timeline info at the first 10 multiples of timeline_steps
                if FLAGS.timeline_steps > 0 and global_step_ // FLAGS.timeline_steps < 10 and global_step_ % FLAGS.timeline_steps == 0:
                    run_sess(run_options, run_metadata)
                    # Create the Timeline object, and write it to a json
                    tl = timeline.Timeline(run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    with open(os.path.join(FLAGS.train_dir, 'timeline_{:0>7}.json'.format(global_step_)), 'a') as f:
                        f.write(ctf)
                else:
                    run_sess()
                # save model periodically
                if FLAGS.save_steps > 0 and global_step_ % FLAGS.save_steps == 0:
                    saver.save(sess, os.path.join(FLAGS.train_dir, 'model_{:0>7}'.format(global_step_)),
                               write_meta_graph=False, write_state=False)
                # test model on validation set every |lr_decay_steps| steps
                if use_val_decay and global_step_ % FLAGS.lr_decay_steps == 0:
                    # get validation error on current model
                    val_batches_loss = [
                        sess.run(g_loss, feed_dict={images_src: val_batch})
                        for val_batch in val_src_batches]
                    val_loss = np.mean(val_batches_loss)
                    val_losses.append(val_loss)
                    print('validation: step {}, val_loss = {:.8}'.format(global_step_, val_loss))
                    # compare recent few losses to previous few losses, decay learning rate if not decreasing
                    if len(val_losses) >= lr_decay_last + val_window_:
                        val_current = np.sort(val_losses[-val_window_ : ])
                        val_previous = np.sort(val_losses[-val_window_ * 2 : -val_window_])
                        val_current = np.mean(val_current), np.median(val_current), np.min(val_current)
                        val_previous = np.mean(val_previous), np.median(val_previous), np.min(val_previous)
                        print('    statistics of {} losses (mean | median | min)'.format(val_window_))
                        print('        previous: {}'.format(val_previous))
                        print('        current:  {}'.format(val_current))
                        # decay when mean + median did not improve
                        if val_current[0] + val_current[1] >= val_previous[0] + val_previous[1]:
                            lr_decay_last = len(val_losses)
                            # val_window_op has a control dependency on
                            # g_lr_decay_op; one run executes the decay once
                            # and widens the window
                            val_window_, lr_ = sess.run((val_window_op, g_lr_decay_op))
                            val_window_ = int(np.round(val_window_))
                            print('    learning rate decayed to {}'.format(lr_))
Example #3
0
def test():
    """Evaluate the MRS model on the test spectra and write reports.

    Restores the latest checkpoint in ``FLAGS.train_dir`` (or, with
    ``FLAGS.progress``, every ``model_*`` checkpoint in sorted order), runs
    one pass over the ``.npy`` spectra in ``FLAGS.dataset`` and prints the
    mean losses for each checkpoint. For the last evaluated checkpoint it
    writes per-label error-ratio histograms and a ground-truth/prediction
    log to ``FLAGS.test_dir``. With ``FLAGS.progress`` it also saves
    ``stats.npy`` and a loss-versus-step plot.

    Raises:
        ValueError: if no checkpoint to evaluate is found in
            ``FLAGS.train_dir``.
    """
    # label names
    if FLAGS.num_labels == 4:
        LABEL_NAMES = ['creatine', 'gaba', 'glutamate', 'glutamine']
    elif FLAGS.num_labels == 12:
        LABEL_NAMES = [
            'choline-truncated', 'creatine', 'gaba', 'glutamate', 'glutamine',
            'glycine', 'lactate', 'myo-inositol', 'NAAG-truncated',
            'n-acetylaspartate', 'phosphocreatine', 'taurine'
        ]
    elif FLAGS.num_labels == 16:
        LABEL_NAMES = [
            'acetate', 'aspartate', 'choline-truncated', 'creatine', 'gaba',
            'glutamate', 'glutamine', 'histamine', 'histidine', 'lactate',
            'myo-inositol', 'n-acetylaspartate', 'scyllo-inositol',
            'succinate', 'taurine', 'valine'
        ]
    else:
        LABEL_NAMES = list(range(FLAGS.num_labels))

    print(*LABEL_NAMES)
    print('No.{}'.format(FLAGS.postfix))

    # get dataset files
    labels_file = os.path.join(FLAGS.dataset, 'labels/labels.npy')
    files = helper.listdir_files(FLAGS.dataset,
                                 recursive=False,
                                 filter_ext=['.npy'],
                                 encoding=True)
    steps_per_epoch = len(files) // FLAGS.batch_size
    epoch_size = steps_per_epoch * FLAGS.batch_size
    max_steps = steps_per_epoch
    # keep only whole batches
    files = files[:epoch_size]

    with tf.Graph().as_default():
        # pre-processing for input
        with tf.device('/cpu:0'):
            spectrum, labels_ref = inputs(FLAGS,
                                          files,
                                          labels_file,
                                          epoch_size,
                                          is_testing=True)

        # build model
        model = MRSmodel(FLAGS,
                         data_format=FLAGS.data_format,
                         seq_size=FLAGS.seq_size,
                         num_labels=FLAGS.num_labels)

        model.build_model(spectrum)

        # a saver object which will save all the variables
        saver = tf.train.Saver(var_list=model.g_mvars)

        # get output
        labels_pred = tf.get_default_graph().get_tensor_by_name('Output:0')

        # losses: the last tensor returned by get_losses is fetched separately
        # as the error ratios (histogrammed per label below); the others are
        # the losses that get averaged
        losses = list(get_losses(labels_ref, labels_pred))
        ret_loss = losses[:-1]
        ret = ret_loss + [losses[-1], labels_ref, labels_pred]
        num_loss = len(ret_loss)

        # model files
        if FLAGS.progress:
            mfiles = helper.listdir_files(FLAGS.train_dir,
                                          recursive=False,
                                          filter_ext=['.index'],
                                          encoding=None)
            # strip the '.index' extension to get checkpoint prefixes
            mfiles = [f[:-6] for f in mfiles if 'model_' in f]
            mfiles.sort()
            stats = []
        else:
            mfiles = [tf.train.latest_checkpoint(FLAGS.train_dir)]
        # the reports below need at least one evaluated checkpoint
        if not mfiles or mfiles[0] is None:
            raise ValueError('no model checkpoint found in {}'.format(
                FLAGS.train_dir))

        for model_file in mfiles:
            with setup() as sess:
                # restore variables from checkpoint
                saver.restore(sess, model_file)

                # run session; per-batch results are kept for the reports
                sum_loss = [0] * num_loss
                all_errors = []
                ref_batches = []
                pred_batches = []
                for _ in range(max_steps):
                    cur_ret = sess.run(ret)
                    cur_loss = cur_ret[:num_loss]
                    all_errors.append(cur_ret[num_loss])
                    ref_batches.append(cur_ret[num_loss + 1])
                    pred_batches.append(cur_ret[num_loss + 2])
                    # monitor losses
                    sum_loss = [s + l for s, l in zip(sum_loss, cur_loss)]

                # summary: per-batch means, with the accuracy/rate entries
                # further normalized per sample (and per label)
                mean_loss = [l / max_steps for l in sum_loss]
                mean_loss[2] /= FLAGS.batch_size
                mean_loss[3] /= FLAGS.batch_size
                mean_loss[4] /= FLAGS.batch_size * FLAGS.num_labels
                mean_loss[5] /= FLAGS.batch_size * FLAGS.num_labels
                print('{} Metabolites'.format(FLAGS.num_labels))
                print('MSE threshold {}'.format(FLAGS.mse_thresh))
                print('MAD threshold {}'.format(FLAGS.mad_thresh))
                print(
                    'Totally {} Samples, MSE {}, MAD {}, MSE accuracy {}, MAD accuracy {}, FP rate {}, FN rate {}'
                    .format(epoch_size, *mean_loss))

                # save stats; strip the 'model_' prefix to get the step
                if FLAGS.progress:
                    model_num = os.path.split(model_file)[1][6:]
                    stats.append(np.array([float(model_num)] + mean_loss))

    # errors of the last evaluated checkpoint, one histogram per label
    import matplotlib.pyplot as plt
    all_errors = np.concatenate(all_errors, axis=0)
    for i in range(FLAGS.num_labels):
        plt.figure()
        plt.title('Error Ratio Histogram - {}'.format(LABEL_NAMES[i]))
        plt.hist(all_errors[:, i], bins=100, range=(0, 1))
        plt.savefig(os.path.join(FLAGS.test_dir, 'hist_{}.png'.format(i)))
        plt.close()

    # labels of the last evaluated checkpoint
    labels_ref = np.concatenate(ref_batches, axis=0)
    labels_pred = np.concatenate(pred_batches, axis=0)
    with open(os.path.join(FLAGS.test_dir, 'labels.log'), mode='w') as log_file:
        log_file.write('Labels (Ground Truth)\nLabels (Predicted)\n\n')
        for i in range(epoch_size):
            log_file.write('{}\n{}\n\n'.format(labels_ref[i], labels_pred[i]))

    # save stats
    if FLAGS.progress:
        stats = np.stack(stats)
        np.save(os.path.join(FLAGS.test_dir, 'stats.npy'), stats)
        fig, ax = plt.subplots()
        ax.set_title('Test Error with Training Progress')
        ax.set_xlabel('training steps')
        ax.set_ylabel('mean absolute difference')
        ax.set_xscale('linear')
        # a zero lower limit is invalid on a log axis (matplotlib ignores it
        # with a warning), so the y limits are left to autoscaling
        ax.set_yscale('log')
        ax.plot(stats[:, 0], stats[:, 2])
        plt.tight_layout()
        plt.savefig(os.path.join(FLAGS.test_dir, 'stats.png'))
        plt.close()

    print('')
Example #4
0
def train():
    """Train the super-resolution GAN (``SRmodel``) with critic iterations.

    Every outer iteration runs ``FLAGS.critic_iters`` discriminator updates
    followed by one generator update through the monitored session, so the
    step counts derived from the number of batches are divided by
    ``FLAGS.critic_iters + 1``. Optionally restores pre-trained weights,
    records Chrome-trace timelines at the first ten multiples of
    ``FLAGS.timeline_steps``, and saves checkpoints every
    ``FLAGS.save_steps`` steps.
    """
    import random
    candidates = helper.listdir_files(FLAGS.dataset,
                                      filter_ext=['.jpeg', '.jpg', '.png'],
                                      encoding=True)
    random.shuffle(candidates)
    num_batches = len(candidates) // FLAGS.batch_size
    epoch_size = num_batches * FLAGS.batch_size
    files = candidates[:epoch_size]
    # one iteration runs critic_iters + 1 train ops (presumably one input
    # batch each), so the step budget shrinks accordingly
    batches_per_iter = FLAGS.critic_iters + 1
    steps_per_epoch = num_batches // batches_per_iter
    max_steps = num_batches * FLAGS.num_epochs // batches_per_iter
    print('epoch size: {}\n{} steps per epoch\n{} epochs\n{} steps'.format(
        epoch_size, steps_per_epoch, FLAGS.num_epochs, max_steps))

    with tf.Graph().as_default():
        # input pipeline pinned to the CPU
        with tf.device('/cpu:0'):
            images_lr, images_hr = inputs(FLAGS, files, is_training=True)

        # the network receives low-resolution patches
        model = SRmodel(FLAGS,
                        data_format=FLAGS.data_format,
                        input_range=FLAGS.input_range,
                        output_range=FLAGS.output_range,
                        multiGPU=FLAGS.multiGPU,
                        use_fp16=FLAGS.use_fp16,
                        scaling=FLAGS.scaling,
                        image_channels=FLAGS.image_channels,
                        input_height=FLAGS.patch_height // FLAGS.scaling,
                        input_width=FLAGS.patch_width // FLAGS.scaling)
        gd_loss = model.build_train(images_lr, images_hr)

        global_step = tf.train.get_or_create_global_step()
        g_train_op, d_train_op = model.train(global_step)

        # checkpoint writer for the generator variables
        saver = tf.train.Saver(var_list=model.g_svars,
                               max_to_keep=1 << 16,
                               save_relative_paths=True)
        saver.export_meta_graph(os.path.join(FLAGS.train_dir, 'model.meta'),
                                as_text=False,
                                clear_devices=True,
                                clear_extraneous_savers=True)

        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(allow_growth=True),
            log_device_placement=FLAGS.log_device_placement)
        train_hooks = [
            tf.train.StopAtStepHook(last_step=max_steps),
            tf.train.NanTensorHook(gd_loss[0]),
            tf.train.NanTensorHook(gd_loss[1]),
            LoggerHook(gd_loss, steps_per_epoch),
        ]

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=train_hooks,
                config=session_config,
                log_step_count_steps=FLAGS.log_frequency) as mon_sess:
            sess = helper.get_session(mon_sess)
            tracing = FLAGS.timeline_steps > 0
            if tracing:
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
            if FLAGS.pretrain_dir and not FLAGS.restore:
                saver.restore(sess, os.path.join(FLAGS.pretrain_dir, 'model'))

            def run_sess(options=None, run_metadata=None):
                """Run the critic updates, then one generator update."""
                # discriminator updates go through `sess` (presumably the raw
                # session behind mon_sess, so they bypass the hooks)
                for _ in range(FLAGS.critic_iters):
                    sess.run(d_train_op)
                mon_sess.run(g_train_op,
                             options=options,
                             run_metadata=run_metadata)

            while not mon_sess.should_stop():
                step = tf.train.global_step(sess, global_step)
                want_trace = (tracing
                              and step // FLAGS.timeline_steps < 10
                              and step % FLAGS.timeline_steps == 0)
                if not want_trace:
                    run_sess()
                else:
                    run_sess(run_options, run_metadata)
                    trace = timeline.Timeline(
                        run_metadata.step_stats).generate_chrome_trace_format()
                    trace_path = os.path.join(
                        FLAGS.train_dir, 'timeline_{:0>7}.json'.format(step))
                    with open(trace_path, 'a') as f:
                        f.write(trace)
                if FLAGS.save_steps > 0 and step % FLAGS.save_steps == 0:
                    ckpt_path = os.path.join(FLAGS.train_dir,
                                             'model_{:0>7}'.format(step))
                    saver.save(sess,
                               ckpt_path,
                               write_meta_graph=False,
                               write_state=False)
Example #5
0
def test():
    """Evaluate the super-resolution model (``SRmodel``) on a test image set.

    All test batches are first fetched in a separate graph and then fed
    through placeholders to the model restored from the latest checkpoint in
    ``FLAGS.train_dir``. The LR, HR, super-resolved and bicubic-upscaled
    images are written as PNGs to ``FLAGS.test_dir``, and the mean losses
    are printed, with the first converted to PSNR as ``10 * log10(1 / loss)``.
    With ``FLAGS.progress``, every ``model_*`` checkpoint is also evaluated,
    and the mean losses are saved to ``stats.npy`` and plotted against the
    training step.
    """
    # get dataset
    files = helper.listdir_files(FLAGS.dataset,
                                 filter_ext=['.jpeg', '.jpg', '.png'])
    steps_per_epoch = len(files) // FLAGS.batch_size
    epoch_size = steps_per_epoch * FLAGS.batch_size
    max_steps = steps_per_epoch
    # keep only whole batches
    files = files[:epoch_size]

    # test set: fetch every batch once so identical data can be fed to each
    # evaluated checkpoint
    test_lr_batches = []
    test_hr_batches = []
    with tf.Graph().as_default():
        # dataset
        with tf.device('/cpu:0'):
            test_lr, test_hr = inputs(FLAGS, files, is_testing=True)
        with session() as sess:
            for _ in range(max_steps):
                _lr, _hr = sess.run((test_lr, test_hr))
                test_lr_batches.append(_lr)
                test_hr_batches.append(_hr)

    with tf.Graph().as_default():
        images_lr = tf.placeholder(tf.float32, name='InputLR')
        images_hr = tf.placeholder(tf.float32, name='InputHR')

        # build model
        model = SRmodel(FLAGS,
                        data_format=FLAGS.data_format,
                        input_range=FLAGS.input_range,
                        output_range=FLAGS.output_range,
                        multiGPU=FLAGS.multiGPU,
                        use_fp16=FLAGS.use_fp16,
                        scaling=FLAGS.scaling,
                        image_channels=FLAGS.image_channels)

        model.build_model(images_lr)

        # a Saver object to restore the variables with mappings
        saver = tf.train.Saver(var_list=model.g_rvars)

        # get output
        images_sr = tf.get_default_graph().get_tensor_by_name('Output:0')

        # losses
        ret_loss = list(get_losses(images_hr, images_sr))
        num_loss = len(ret_loss)

        # post-processing for output
        with tf.device('/cpu:0'):
            # data format conversion
            if FLAGS.data_format == 'NCHW':
                images_lr = utils.image.NCHW2NHWC(images_lr)
                images_hr = utils.image.NCHW2NHWC(images_hr)
                images_sr = utils.image.NCHW2NHWC(images_sr)

            # Bicubic upscaling
            shape = tf.shape(images_lr)
            upsize = [shape[-3] * FLAGS.scaling, shape[-2] * FLAGS.scaling]
            images_bicubic = tf.image.resize_images(
                images_lr,
                upsize,
                tf.image.ResizeMethod.BICUBIC,
                align_corners=True)

            # PNGs output
            ret_pngs = []
            for images in (images_lr, images_hr, images_sr, images_bicubic):
                ret_pngs.extend(helper.BatchPNG(images, FLAGS.batch_size))

        # output file-name tags, in the same order as the ret_pngs groups
        png_tags = ['0.LR', '1.HR', '2.SR{}'.format(FLAGS.postfix), '3.Bicubic']

        def _feed(step):
            """Return the feed dict for pre-fetched test batch ``step``."""
            return {
                'InputLR:0': test_lr_batches[step],
                'InputHR:0': test_hr_batches[step]
            }

        # the session is a context manager so it is closed when done (it was
        # previously opened with a bare session() call and never closed)
        with session() as sess:
            # test latest checkpoint
            model_file = tf.train.latest_checkpoint(FLAGS.train_dir)
            saver.restore(sess, model_file)

            ret = ret_loss + ret_pngs
            sum_loss = [0] * num_loss
            for step in range(max_steps):
                cur_ret = sess.run(ret, feed_dict=_feed(step))
                cur_loss = cur_ret[:num_loss]
                cur_pngs = cur_ret[num_loss:]
                # monitor losses
                sum_loss = [s + l for s, l in zip(sum_loss, cur_loss)]
                # images output, numbered by sample index within the test set
                start = step * FLAGS.batch_size
                indices = range(start, start + FLAGS.batch_size)
                ofiles = [
                    os.path.join(FLAGS.test_dir, '{:0>5}.{}.png'.format(i, tag))
                    for tag in png_tags for i in indices
                ]
                helper.WriteFiles(cur_pngs, ofiles)

            # summary
            print('No.{}'.format(FLAGS.postfix))
            mean_loss = [l / max_steps for l in sum_loss]
            psnr = 10 * np.log10(1 / mean_loss[0]) if mean_loss[0] > 0 else 100
            print('PSNR (RGB) {}, MAD (RGB) {}, '
                  'SS-SSIM(Y) {}, MS-SSIM (Y) {}, loss {}'.format(
                      psnr, *mean_loss[1:]))

            # test progressively saved models
            stats = []
            if FLAGS.progress:
                mfiles = helper.listdir_files(FLAGS.train_dir,
                                              recursive=False,
                                              filter_ext=['.index'],
                                              encoding=None)
                # strip the '.index' extension to get checkpoint prefixes
                mfiles = [f[:-6] for f in mfiles if 'model_' in f]
                mfiles.sort()
            else:
                mfiles = []

            for model_file in mfiles:
                # restore variables from saved model
                saver.restore(sess, model_file)

                # run session
                sum_loss = [0] * num_loss
                for step in range(max_steps):
                    cur_loss = sess.run(ret_loss, feed_dict=_feed(step))
                    sum_loss = [s + l for s, l in zip(sum_loss, cur_loss)]

                # summary; strip the 'model_' prefix to get the step number
                mean_loss = [l / max_steps for l in sum_loss]
                model_num = os.path.split(model_file)[1][6:]
                stats.append(np.array([float(model_num)] + mean_loss))

    # save stats
    import matplotlib.pyplot as plt
    if FLAGS.progress:
        stats = np.stack(stats)
        np.save(os.path.join(FLAGS.test_dir, 'stats.npy'), stats)
        fig, ax = plt.subplots()
        ax.set_title('Test Error with Training Progress')
        ax.set_xlabel('training steps')
        ax.set_ylabel('MAD (RGB)')
        ax.set_xscale('linear')
        ax.set_yscale('log')
        # the earliest checkpoint is kept in stats.npy but left out of the plot
        stats = stats[1:]
        ax.plot(stats[:, 0], stats[:, 2])
        plt.tight_layout()
        plt.savefig(os.path.join(FLAGS.test_dir, 'stats.png'))
        plt.close()

    print('')
Example #6
0
def test():
    """Evaluate the image-compression model (``ICmodel``) on a test set.

    Restores the latest checkpoint in ``FLAGS.train_dir``, runs one pass over
    the images in ``FLAGS.dataset``, writes the source, decoded and encoded
    images as PNGs to ``FLAGS.test_dir``, and prints the mean losses, with
    the first converted to PSNR as ``10 * log10(1 / loss)``. With
    ``FLAGS.progress``, every ``model_*`` checkpoint is also evaluated, and
    the mean losses are saved to ``stats.npy`` and plotted against the
    training step.
    """
    files = helper.listdir_files(FLAGS.dataset,
                                 filter_ext=['.jpeg', '.jpg', '.png'])
    steps_per_epoch = len(files) // FLAGS.batch_size
    epoch_size = steps_per_epoch * FLAGS.batch_size
    max_steps = steps_per_epoch
    # keep only whole batches
    files = files[:epoch_size]

    with tf.Graph().as_default():
        # pre-processing for input
        with tf.device('/cpu:0'):
            images_src = inputs(FLAGS, files, is_testing=True)

        # build model
        model = ICmodel(FLAGS, data_format=FLAGS.data_format,
            input_range=FLAGS.input_range, output_range=FLAGS.output_range,
            multiGPU=FLAGS.multiGPU, use_fp16=FLAGS.use_fp16,
            image_channels=FLAGS.image_channels, input_height=FLAGS.patch_height,
            input_width=FLAGS.patch_width, batch_size=FLAGS.batch_size)

        model.build_model(images_src, None)

        # a saver object which will save all the variables
        saver = tf.train.Saver(var_list=model.g_svars)

        # get output
        images_enc = tf.get_default_graph().get_tensor_by_name('Encoded:0')
        images_dec = tf.get_default_graph().get_tensor_by_name('Decoded:0')

        # losses
        ret_loss = list(get_losses(images_src, images_dec, images_enc))
        num_loss = len(ret_loss)

        # post-processing for output
        with tf.device('/cpu:0'):
            # data format conversion
            if FLAGS.data_format == 'NCHW':
                images_src = utils.image.NCHW2NHWC(images_src)
                images_dec = utils.image.NCHW2NHWC(images_dec)
                images_enc = utils.image.NCHW2NHWC(images_enc)

            # PNGs output
            ret_pngs = []
            for images in (images_src, images_dec, images_enc):
                ret_pngs.extend(helper.BatchPNG(images, FLAGS.batch_size))

        # output file-name tags, in the same order as the ret_pngs groups
        png_tags = ['0.src', '1.dec{}'.format(FLAGS.postfix),
                    '2.enc{}'.format(FLAGS.postfix)]

        # test latest checkpoint
        with setup() as sess:
            # restore variables from latest checkpoint
            model_file = tf.train.latest_checkpoint(FLAGS.train_dir)
            saver.restore(sess, model_file)

            # run session
            ret = ret_loss + ret_pngs
            sum_loss = [0] * num_loss
            for step in range(max_steps):
                cur_ret = sess.run(ret)
                cur_loss = cur_ret[:num_loss]
                cur_pngs = cur_ret[num_loss:]
                # monitor losses
                sum_loss = [s + l for s, l in zip(sum_loss, cur_loss)]
                # images output, numbered by sample index within the test set
                start = step * FLAGS.batch_size
                indices = range(start, start + FLAGS.batch_size)
                ofiles = [
                    os.path.join(FLAGS.test_dir, '{:0>5}.{}.png'.format(i, tag))
                    for tag in png_tags for i in indices
                ]
                helper.WriteFiles(cur_pngs, ofiles)

            # summary
            print('No.{}'.format(FLAGS.postfix))
            mean_loss = [l / max_steps for l in sum_loss]
            psnr = 10 * np.log10(1 / mean_loss[0]) if mean_loss[0] > 0 else 100
            print('PSNR (RGB) {}, MAD (RGB) {}, SS-SSIM(Y) {}, MS-SSIM (Y) {}, Entropy {}'.format(
                   psnr, *mean_loss[1:]))

        # test progressively saved models
        stats = []
        if FLAGS.progress:
            mfiles = helper.listdir_files(FLAGS.train_dir, recursive=False,
                                          filter_ext=['.index'],
                                          encoding=None)
            # strip the '.index' extension to get checkpoint prefixes
            mfiles = [f[:-6] for f in mfiles if 'model_' in f]
            mfiles.sort()
        else:
            mfiles = []

        for model_file in mfiles:
            with setup() as sess:
                # restore variables from saved model
                saver.restore(sess, model_file)

                # run session
                sum_loss = [0] * num_loss
                for _ in range(max_steps):
                    cur_loss = sess.run(ret_loss)
                    sum_loss = [s + l for s, l in zip(sum_loss, cur_loss)]

                # summary; strip the 'model_' prefix to get the step number
                mean_loss = [l / max_steps for l in sum_loss]
                model_num = os.path.split(model_file)[1][6:]
                stats.append(np.array([float(model_num)] + mean_loss))

    # save stats
    import matplotlib.pyplot as plt
    if FLAGS.progress:
        stats = np.stack(stats)
        np.save(os.path.join(FLAGS.test_dir, 'stats.npy'), stats)
        fig, ax = plt.subplots()
        ax.set_title('Test Error with Training Progress')
        ax.set_xlabel('training steps')
        ax.set_ylabel('MAD (RGB)')
        ax.set_xscale('linear')
        # a zero lower limit is invalid on a log axis (matplotlib ignores it
        # with a warning), so the y limits are left to autoscaling
        ax.set_yscale('log')
        # the earliest checkpoint is kept in stats.npy but left out of the plot
        stats = stats[1:]
        ax.plot(stats[:, 0], stats[:, 2])
        plt.tight_layout()
        plt.savefig(os.path.join(FLAGS.test_dir, 'stats.png'))
        plt.close()

    print('')
Example #7
0
def _label_names(num_labels):
    """Return the display names of the ``num_labels`` predicted quantities.

    The 4-, 12- and 16-label configurations have fixed metabolite name lists;
    any other count falls back to the integer indices ``0 .. num_labels - 1``.
    """
    if num_labels == 4:
        return ['creatine', 'gaba', 'glutamate', 'glutamine']
    if num_labels == 12:
        return [
            'choline-truncated', 'creatine', 'gaba', 'glutamate', 'glutamine',
            'glycine', 'lactate', 'myo-inositol', 'NAAG-truncated',
            'n-acetylaspartate', 'phosphocreatine', 'taurine'
        ]
    if num_labels == 16:
        return [
            'acetate', 'aspartate', 'choline-truncated', 'creatine', 'gaba',
            'glutamate', 'glutamine', 'histamine', 'histidine', 'lactate',
            'myo-inositol', 'n-acetylaspartate', 'scyllo-inositol',
            'succinate', 'taurine', 'valine'
        ]
    return list(range(num_labels))


def _save_error_histograms(all_errors, label_names, test_dir, var_index):
    """Save one 100-bin histogram over [0, 1] per label column of ``all_errors``."""
    import matplotlib.pyplot as plt
    for index, name in enumerate(label_names):
        plt.figure()
        plt.title('Error Ratio Histogram - {}'.format(name))
        plt.hist(all_errors[:, index], bins=100, range=(0, 1))
        plt.savefig(
            os.path.join(test_dir, 'hist{}_{}.png'.format(var_index, index)))
        plt.close()


def _write_labels_log(labels_gt, labels_pd, num_rows, test_dir, var_index):
    """Write the first ``num_rows`` ground-truth/predicted label pairs to a log."""
    path = os.path.join(test_dir, 'labels{}.log'.format(var_index))
    with open(path, mode='w') as file:
        file.write('Labels (Ground Truth)\nLabels (Predicted)\n\n')
        file.writelines('{}\n{}\n\n'.format(labels_gt[row], labels_pd[row])
                        for row in range(num_rows))


def _plot_responses(labels_gt, labels_pd, label_names, test_dir, var_index):
    """Plot every predicted label against the ground truth of ``var_index``."""
    import matplotlib.pyplot as plt
    plt.figure()
    plt.title('Predicted Responses to {}'.format(label_names[var_index]))
    x = labels_gt[:, var_index]
    for index, name in enumerate(label_names):
        plt.plot(x, labels_pd[:, index], label=name)
    plt.legend(loc=2)
    plt.savefig(os.path.join(test_dir, 'val{}.png'.format(var_index)))
    plt.close()


def test():
    """Evaluate the latest checkpoint in ``FLAGS.train_dir`` on ``FLAGS.dataset``.

    Restores the most recent checkpoint and runs the model over every full
    batch of ``.npy`` files in ``FLAGS.dataset``. Labels are read from
    ``labels/labels.npy``. It prints the averaged metrics and writes these
    files into ``FLAGS.test_dir``:

    * ``hist{var_index}_{i}.png``: per-label error histograms,
    * ``labels{var_index}.log``: ground-truth/predicted label pairs,
    * ``val{var_index}.png``: every predicted label plotted against the
      ground truth of label ``FLAGS.var_index``.

    Raises:
        ValueError: if the dataset holds fewer than ``FLAGS.batch_size`` files,
            so not even one full batch can be evaluated.
    """
    label_names = _label_names(FLAGS.num_labels)

    # get dataset files; trailing files that do not fill a batch are dropped
    labels_file = os.path.join(FLAGS.dataset, 'labels/labels.npy')
    files = helper.listdir_files(FLAGS.dataset,
                                 recursive=False,
                                 filter_ext=['.npy'],
                                 encoding=True)
    steps_per_epoch = len(files) // FLAGS.batch_size
    if steps_per_epoch < 1:
        # fail fast: otherwise the loss averaging below divides by zero
        # after the whole graph has been built and restored
        raise ValueError(
            'need at least {} .npy files in {} for one batch, found {}'.format(
                FLAGS.batch_size, FLAGS.dataset, len(files)))
    epoch_size = steps_per_epoch * FLAGS.batch_size
    max_steps = steps_per_epoch
    files = files[:epoch_size]

    with tf.Graph().as_default():
        # setup global tensorflow state
        sess, summary_writer = setup()
        try:
            # pre-processing for input
            with tf.device('/cpu:0'):
                spectrum, labels_gt = inputs(FLAGS,
                                             files,
                                             labels_file,
                                             epoch_size,
                                             is_testing=True)

            # model inference and losses
            labels_pd = model.inference(spectrum, is_training=False)
            loss_tensors = list(get_losses(labels_gt, labels_pd))
            # the last output of get_losses is the per-label error array that
            # is histogrammed below; everything before it is a metric that is
            # summed over batches
            num_metrics = len(loss_tensors) - 1
            fetches = loss_tensors + [labels_gt, labels_pd]

            # restore variables from checkpoint
            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir))

            # run session, accumulating metrics and collecting raw outputs
            sum_loss = [0] * num_metrics
            error_batches, gt_batches, pd_batches = [], [], []
            for _ in range(max_steps):
                cur_ret = sess.run(fetches)
                sum_loss = [total + value for total, value in
                            zip(sum_loss, cur_ret[:num_metrics])]
                batch_errors, batch_gt, batch_pd = cur_ret[num_metrics:]
                error_batches.append(batch_errors)
                gt_batches.append(batch_gt)
                pd_batches.append(batch_pd)
        finally:
            # release the session even if restore or evaluation fails
            sess.close()

    # summary: per-batch means; metrics 2-3 are then normalized per sample and
    # 4-5 per (sample, label), matching the "accuracy"/"rate" names below
    mean_loss = [l / max_steps for l in sum_loss]
    mean_loss[2] /= FLAGS.batch_size
    mean_loss[3] /= FLAGS.batch_size
    mean_loss[4] /= FLAGS.batch_size * FLAGS.num_labels
    mean_loss[5] /= FLAGS.batch_size * FLAGS.num_labels
    print('{} Metabolites'.format(FLAGS.num_labels))
    print('MSE threshold {}'.format(FLAGS.mse_thresh))
    print('MAD threshold {}'.format(FLAGS.mad_thresh))
    print(
        'Totally {} Samples, MSE {}, MAD {}, MSE accuracy {}, MAD accuracy {}, FP rate {}, FN rate {}'
        .format(epoch_size, *mean_loss))

    # errors
    all_errors = np.concatenate(error_batches, axis=0)
    _save_error_histograms(all_errors, label_names, FLAGS.test_dir,
                           FLAGS.var_index)

    # labels
    labels_gt = np.concatenate(gt_batches, axis=0)
    labels_pd = np.concatenate(pd_batches, axis=0)
    _write_labels_log(labels_gt, labels_pd, epoch_size, FLAGS.test_dir,
                      FLAGS.var_index)

    # draw plots
    _plot_responses(labels_gt, labels_pd, label_names, FLAGS.test_dir,
                    FLAGS.var_index)

    print('')
Example #8
0
def train():
    """Train an MRSmodel on the ``.npy`` spectra in ``FLAGS.dataset``.

    Labels are read from ``<FLAGS.dataset>/labels/labels.npy``. One epoch is
    ``FLAGS.epoch_size`` files (or every listed file when that flag is <= 0),
    rounded down to a multiple of ``FLAGS.batch_size``. Training stops after
    ``FLAGS.num_epochs`` epochs' worth of steps via ``StopAtStepHook``.

    Writes into ``FLAGS.train_dir``:

    * ``model.pbtxt`` / ``model.meta`` graph exports,
    * the checkpoints managed by ``MonitoredTrainingSession``,
    * optional ``timeline_*.json`` Chrome traces (every
      ``FLAGS.timeline_steps`` steps),
    * optional ``model_*`` snapshots (every ``FLAGS.save_steps`` steps).
    """
    labels_file = os.path.join(FLAGS.dataset, 'labels/labels.npy')
    files = helper.listdir_files(FLAGS.dataset,
                                 recursive=False,
                                 filter_ext=['.npy'],
                                 encoding='utf-8')
    # NOTE(review): an explicit FLAGS.epoch_size larger than len(files) is not
    # checked; files[:epoch_size] below would then hold fewer files than
    # epoch_size claims — confirm inputs() tolerates that.
    epoch_size = FLAGS.epoch_size if FLAGS.epoch_size > 0 else len(files)
    # drop the trailing partial batch so every step sees a full batch
    steps_per_epoch = epoch_size // FLAGS.batch_size
    epoch_size = steps_per_epoch * FLAGS.batch_size
    max_steps = steps_per_epoch * FLAGS.num_epochs
    files = files[:epoch_size]
    print('epoch size: {}\n{} steps per epoch\n{} epochs\n{} steps'.format(
        epoch_size, steps_per_epoch, FLAGS.num_epochs, max_steps))

    with tf.Graph().as_default():
        # pre-processing for input
        with tf.device('/cpu:0'):
            spectrum, labels_ref = inputs(FLAGS,
                                          files,
                                          labels_file,
                                          epoch_size,
                                          is_training=True)

        # build model
        model = MRSmodel(FLAGS,
                         data_format=FLAGS.data_format,
                         seq_size=FLAGS.seq_size,
                         num_labels=FLAGS.num_labels)

        g_loss = model.build_train(spectrum, labels_ref)

        # training step and op
        global_step = tf.train.get_or_create_global_step()
        g_train_op = model.train(global_step)

        # profiler
        #profiler(train_op)

        # a saver object which will save all the variables
        # (var_list is restricted to model.g_mvars — presumably the model's
        # own variables; confirm against MRSmodel). max_to_keep=1 << 16
        # (65536) effectively keeps every model_* snapshot.
        saver = tf.train.Saver(var_list=model.g_mvars,
                               max_to_keep=1 << 16,
                               save_relative_paths=True)

        # save the graph, once as human-readable text and once as binary
        saver.export_meta_graph(os.path.join(FLAGS.train_dir, 'model.pbtxt'),
                                as_text=True,
                                clear_devices=True,
                                clear_extraneous_savers=True)
        saver.export_meta_graph(os.path.join(FLAGS.train_dir, 'model.meta'),
                                as_text=False,
                                clear_devices=True,
                                clear_extraneous_savers=True)

        # monitored session
        gpu_options = tf.GPUOptions(allow_growth=True)
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=FLAGS.log_device_placement)

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=max_steps),
                    tf.train.NanTensorHook(g_loss),
                    LoggerHook(g_loss, steps_per_epoch)
                ],
                config=config,
                log_step_count_steps=FLAGS.log_frequency) as mon_sess:
            # options
            # presumably unwraps the underlying tf.Session so the saver and
            # global_step reads bypass the hooks — confirm in helper
            sess = helper.get_session(mon_sess)
            if FLAGS.timeline_steps > 0:
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
            # restore pre-trained model
            # (skipped when FLAGS.restore is set, i.e. when resuming from
            # FLAGS.train_dir should take precedence over the pre-trained weights)
            if FLAGS.pretrain_dir and not FLAGS.restore:
                saver.restore(sess, os.path.join(FLAGS.pretrain_dir, 'model'))
            # sessions
            def run_sess(options=None, run_metadata=None):
                # one training step through the monitored session, so hooks run
                mon_sess.run(g_train_op,
                             options=options,
                             run_metadata=run_metadata)

            # run session
            while not mon_sess.should_stop():
                # NOTE(review): step is read BEFORE the training step runs; if
                # g_train_op increments global_step, the model_{step} snapshot
                # below holds weights one update past its name — confirm intent.
                step = tf.train.global_step(sess, global_step)
                if FLAGS.timeline_steps > 0 and step % FLAGS.timeline_steps == 0:
                    run_sess(run_options, run_metadata)
                    # Create the Timeline object, and write it to a json
                    tl = timeline.Timeline(run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    # NOTE(review): mode 'a' appends; if the file for this step
                    # already exists (e.g. a re-run), the result is two JSON
                    # documents back to back, which is not valid JSON. 'w' is
                    # likely intended.
                    with open(
                            os.path.join(FLAGS.train_dir,
                                         'timeline_{:0>7}.json'.format(step)),
                            'a') as f:
                        f.write(ctf)
                else:
                    run_sess()
                # periodic numbered snapshot, separate from the checkpoints the
                # monitored session writes itself; no meta graph or
                # 'checkpoint' state file is written for these
                if FLAGS.save_steps > 0 and step % FLAGS.save_steps == 0:
                    saver.save(sess,
                               os.path.join(FLAGS.train_dir,
                                            'model_{:0>7}'.format(step)),
                               write_meta_graph=False,
                               write_state=False)