예제 #1
0
def eval_h5(conf, ckpt):
    """
    Evaluate a trained model on the training and test data streams.

    Rebuilds the multi-tower inference graph, lets `tools.tf_boilerplate`
    set up the session (restoring from `ckpt`), and prints the approximate
    PSNR that `eval_epoch` reports for the training and the test stream.

    Args:
      conf: configuration dictionary; reads 'cropw' and 'cw' and is passed
        on to `tools.prepare_data`, `model.inference`, `model.loss` and
        `tools.tf_boilerplate`.
      ckpt: restore from ckpt (forwarded to `tools.tf_boilerplate`).
    """
    cropw = conf['cropw']
    cw = conf['cw']
    num_gpus = FLAGS.num_gpus

    # Prepare data
    tr_stream, te_stream = tools.prepare_data(conf)

    # Close the data streams even if graph construction or evaluation fails.
    try:
        with tf.Graph().as_default(), tf.device('/cpu:0' if FLAGS.dev_assign else None):
            # Placeholders: the input is cw x cw, the target is the input
            # minus a cropw-wide border on every side.
            Xs = [tf.placeholder(tf.float32, [None, cw, cw, 1], name='X_%02d' % i)
                  for i in range(num_gpus)]
            Ys = [tf.placeholder(tf.float32, [None, cw - 2*cropw, cw - 2*cropw, 1],
                                 name='Y_%02d' % i)
                  for i in range(num_gpus)]

            # Build one inference tower per GPU, sharing variables across towers.
            y_splits = []
            for i in range(num_gpus):
                with tf.device(('/gpu:%d' % i) if FLAGS.dev_assign else None):
                    with tf.name_scope('%s_%02d' % (FLAGS.tower_name, i)) as scope:
                        y_split = model.inference(Xs[i], conf)
                        y_splits.append(y_split)
                        # The loss value is not used for evaluation; the op is
                        # still built so the graph mirrors the one built by
                        # `train` (presumably matters if model.loss adds graph
                        # state the checkpoint expects -- TODO confirm).
                        model.loss(y_split, Ys[i], conf, scope)

                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()

            # Concatenate the per-tower predictions along the batch axis.
            y = tf.concat(0, y_splits, name='y')

            # Tensorflow boilerplate (no summaries are needed for evaluation).
            sess, _, _, _ = tools.tf_boilerplate(None, conf, ckpt)

            # Evaluation
            psnr_tr = eval_epoch(Xs, Ys, y, sess, tr_stream, cropw)
            psnr_te = eval_epoch(Xs, Ys, y, sess, te_stream, cropw)
            print('approx psnr_tr=%.3f' % psnr_tr)
            print('approx psnr_te=%.3f' % psnr_te)
    finally:
        tr_stream.close()
        te_stream.close()
예제 #2
0
def _crop_border(arr, cropw):
    """Return `arr` without a `cropw`-wide border on axes 1 and 2.

    ``arr[:, cropw:-cropw]`` is empty when ``cropw == 0`` (because ``-0 == 0``),
    so that case returns `arr` unchanged instead.
    """
    if cropw == 0:
        return arr
    return arr[:, cropw:-cropw, cropw:-cropw]


def _split_batch(placeholders, batch, num_towers):
    """Pair each tower's placeholder with its contiguous slice of `batch`.

    Every tower gets ``batch.shape[0] // num_towers`` rows along axis 0; the
    last tower additionally receives any remainder rows.

    Returns:
      list of (placeholder, array slice) pairs, one per tower.
    """
    chunk_size = batch.shape[0]
    gpu_chunk = chunk_size // num_towers
    pairs = []
    for i in range(num_towers):
        stop = chunk_size if i == num_towers - 1 else (i + 1) * gpu_chunk
        pairs.append((placeholders[i], batch[i * gpu_chunk:stop]))
    return pairs


def train(conf, ckpt=None):
    """
    Train the multi-tower model for conf['n_epochs'] epochs.

    One model replica ("tower") is built per GPU with shared variables; the
    per-tower gradients (optionally norm-clipped) are averaged and applied
    with Adam. Training PSNR is printed every 10 steps, summaries are written
    every 25 steps, and a checkpoint is saved every 100 steps, at the end of
    every epoch and once more after the last epoch.

    Args:
      conf: configuration dictionary; reads 'cropw', 'mb_size', 'path_tmp',
        'n_epochs', 'cw' and 'grad_norm_thresh' and is passed on to the
        `tools` and `model` helpers.
      ckpt: restore from ckpt (forwarded to `tools.tf_boilerplate`).
    """
    cropw = conf['cropw']
    mb_size = conf['mb_size']
    path_tmp = conf['path_tmp']
    n_epochs = conf['n_epochs']
    cw = conf['cw']
    grad_norm_thresh = conf['grad_norm_thresh']
    num_gpus = FLAGS.num_gpus
    ckpt_path = os.path.join(path_tmp, 'ckpt')

    tools.reset_tmp(path_tmp)

    # Prepare data
    tr_stream, te_stream = tools.prepare_data(conf)
    n_tr = tr_stream.dataset.num_examples

    # Close the data streams even if training fails part-way.
    try:
        with tf.Graph().as_default(), tf.device('/cpu:0' if FLAGS.dev_assign else None):
            # Exponential decay learning rate
            global_step = tf.get_variable(
                'global_step', [], initializer=tf.constant_initializer(0),
                dtype=tf.int32, trainable=False)
            lr = tools.exp_decay_lr(global_step, n_tr, conf)

            # Create an optimizer that performs gradient descent
            opt = tf.train.AdamOptimizer(lr)

            # Placeholders: the input is cw x cw, the target is the input
            # minus a cropw-wide border on every side.
            Xs = [tf.placeholder(tf.float32, [None, cw, cw, 1], name='X_%02d' % i)
                  for i in range(num_gpus)]
            Ys = [tf.placeholder(tf.float32, [None, cw - 2*cropw, cw - 2*cropw, 1],
                                 name='Y_%02d' % i)
                  for i in range(num_gpus)]

            # Calculate the gradients for each model tower
            tower_grads = []
            y_splits = []
            summs = []
            for i in range(num_gpus):
                with tf.device(('/gpu:%d' % i) if FLAGS.dev_assign else None):
                    with tf.name_scope('%s_%02d' % (FLAGS.tower_name, i)) as scope:
                        # Build the tower; variables are shared across towers.
                        y_split = model.inference(Xs[i], conf)
                        y_splits.append(y_split)
                        total_loss = model.loss(y_split, Ys[i], conf, scope)

                        # Gradients for this tower's slice of the batch,
                        # optionally clipped.
                        gvs = opt.compute_gradients(total_loss)
                        if grad_norm_thresh > 0:
                            gvs = tools.clip_by_norm(gvs, grad_norm_thresh)
                        tower_grads.append(gvs)

                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()

                        # Retain the summaries from the final tower.
                        summs = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)

            y = tf.concat(0, y_splits, name='y')

            # Average the tower gradients; this is the synchronization point
            # across all towers.
            gvs = tools.average_gradients(tower_grads)

            # Apply the gradients to adjust the shared variables.
            apply_grad_op = opt.apply_gradients(gvs, global_step=global_step)

            # Add a summary to track the learning rate.
            summs.append(tf.scalar_summary('learning_rate', lr))

            # Add histograms for gradients. Variables without a gradient have
            # g == None; a Tensor cannot be used as a Python bool, so test
            # against None explicitly.
            for g, v in gvs:
                if g is not None:
                    v_name = re.sub('%s_[0-9]*/' % FLAGS.tower_name, '', v.op.name)
                    summs.append(tf.histogram_summary(v_name + '/gradients', g))

            # Tensorflow boilerplate
            sess, saver, summ_writer, summ_op = tools.tf_boilerplate(summs, conf, ckpt)

            # Baseline error
            bpsnr_tr = tools.baseline_psnr(tr_stream)
            bpsnr_te = tools.baseline_psnr(te_stream)
            print('approx baseline psnr_tr=%.3f' % bpsnr_tr)
            print('approx baseline psnr_te=%.3f' % bpsnr_te)

            # Train
            format_str = ('%s| %04d PSNR=%.3f (Tr: %.1fex/s; %.1fs/batch)'
                          '(Te: %.1fex/s; %.1fs/batch)')
            step = 0
            for epoch in range(n_epochs):
                print('--- Epoch %d ---' % epoch)
                for X_c, y_c in tr_stream.get_epoch_iterator():
                    # Targets must match the cropped Ys placeholder shape.
                    y_c = _crop_border(y_c, cropw)
                    dict_input1 = _split_batch(Xs, X_c, num_gpus)
                    dict_input2 = _split_batch(Ys, y_c, num_gpus)
                    feed = dict(dict_input1 + dict_input2)

                    start_time = time.time()
                    sess.run(apply_grad_op, feed_dict=feed)
                    duration_tr = time.time() - start_time

                    if step % 10 == 0:
                        # Inference only needs the inputs.
                        feed2 = dict(dict_input1)

                        start_time = time.time()
                        y_eval = sess.run(y, feed_dict=feed2)
                        duration_eval = time.time() - start_time

                        psnr = tools.eval_psnr(y_c, y_eval)
                        ex_per_step_tr = mb_size * num_gpus / duration_tr
                        ex_per_step_eval = mb_size * num_gpus / duration_eval
                        print(format_str % (datetime.now().time(), step, psnr,
                              ex_per_step_tr, float(duration_tr / num_gpus),
                              ex_per_step_eval, float(duration_eval / num_gpus)))

                    if step % 25 == 0:
                        summ_str = sess.run(summ_op, feed_dict=feed)
                        summ_writer.add_summary(summ_str, step)

                    if step % 100 == 0:
                        saver.save(sess, ckpt_path, global_step=step)

                    step += 1

                # End-of-epoch checkpoint.
                saver.save(sess, ckpt_path, global_step=step)

            saver.save(sess, ckpt_path, global_step=step)
    finally:
        tr_stream.close()
        te_stream.close()