Example #1

# Assumes FLAGS, reader and cross_conv_model are defined in the enclosing
# module; the imports below are what this snippet itself needs.
import os
import sys

import numpy as np
import tensorflow as tf
def _Train():
    params = dict()
    params['batch_size'] = FLAGS.batch_size
    params['seq_len'] = FLAGS.sequence_length
    params['image_size'] = FLAGS.image_size
    params['is_training'] = True
    params['norm_scale'] = FLAGS.norm_scale
    params['scale'] = FLAGS.scale
    params['learning_rate'] = FLAGS.learning_rate
    params['l2_loss'] = FLAGS.l2_loss
    params['reconstr_loss'] = FLAGS.reconstr_loss
    params['kl_loss'] = FLAGS.kl_loss

    train_dir = os.path.join(FLAGS.log_root, 'train')

    images = reader.ReadInput(FLAGS.data_filepattern,
                              shuffle=True,
                              params=params)
    images *= params['scale']
    # Increasing the scale makes training much faster.
    image_diff_list = reader.SequenceToImageAndDiff(images)
    model = cross_conv_model.CrossConvModel(image_diff_list, params)
    model.Build()
    # Print parameter statistics for the freshly built graph.
    tf.contrib.tfprof.model_analyzer.print_model_analysis(
        tf.get_default_graph())

    summary_writer = tf.summary.FileWriter(train_dir)
    # The Supervisor checkpoints to FLAGS.log_root every 60 seconds; summaries
    # are written manually below, so the default summary_op is disabled.
    sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                             summary_op=None,
                             is_chief=True,
                             save_model_secs=60,
                             global_step=model.global_step)
    sess = sv.prepare_or_wait_for_session(
        FLAGS.master, config=tf.ConfigProto(allow_soft_placement=True))

    total_loss = 0.0
    step = 0
    # Running sums for the empirical mean of z (and of its log stddev),
    # accumulated between saves and written out below for eval.
    sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
    sample_z_stddev_log = np.zeros(model.z_stddev_log.get_shape().as_list())
    sample_step = 0

    while True:
        _, loss_val, total_steps, summaries, z_mean, z_stddev_log = sess.run([
            model.train_op, model.loss, model.global_step, model.summary_op,
            model.z_mean, model.z_stddev_log
        ])

        sample_z_mean += z_mean
        sample_z_stddev_log += z_stddev_log
        total_loss += loss_val
        step += 1
        sample_step += 1

        if step % 100 == 0:
            summary_writer.add_summary(summaries, total_steps)
            sys.stderr.write('step: %d, loss: %f\n' %
                             (total_steps, total_loss / step))
            total_loss = 0.0
            step = 0

        # The sampled z statistics are saved for eval.
        # Averaging over 10k steps seems better than 1k; 100k is untried.
        if sample_step % 10000 == 0:
            with tf.gfile.Open(os.path.join(FLAGS.log_root, 'z_mean.npy'),
                               'wb') as f:
                np.save(f, sample_z_mean / sample_step)
            with tf.gfile.Open(
                    os.path.join(FLAGS.log_root, 'z_stddev_log.npy'),
                    'wb') as f:
                np.save(f, sample_z_stddev_log / sample_step)
            sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
            sample_z_stddev_log = np.zeros(
                model.z_stddev_log.get_shape().as_list())
            sample_step = 0
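
For context, here is a minimal sketch of how the saved statistics can be turned back into a latent sample, assuming the model follows the standard VAE reparameterization (z = mean + exp(z_stddev_log) * eps). The file names match those written above; the helper itself (_SampleZ) is illustrative and not part of the original module.

import io
import os

import numpy as np
import tensorflow as tf


def _SampleZ(log_root):
    # Load the averaged statistics written by _Train.
    with tf.gfile.Open(os.path.join(log_root, 'z_mean.npy'), 'rb') as f:
        z_mean = np.load(io.BytesIO(f.read()))
    with tf.gfile.Open(os.path.join(log_root, 'z_stddev_log.npy'),
                       'rb') as f:
        z_stddev_log = np.load(io.BytesIO(f.read()))
    # Reparameterization: z = mean + stddev * eps, with eps ~ N(0, 1).
    eps = np.random.normal(size=z_mean.shape)
    return z_mean + np.exp(z_stddev_log) * eps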
Example #2

# Assumes FLAGS, reader and cross_conv_model are defined in the enclosing
# module; the imports below are what this snippet itself needs.
import io
import os
import sys
import time

import numpy as np
import tensorflow as tf
def _Eval():
    params = dict()
    params['batch_size'] = FLAGS.batch_size
    params['seq_len'] = FLAGS.sequence_length
    params['image_size'] = FLAGS.image_size
    params['is_training'] = False
    params['norm_scale'] = FLAGS.norm_scale
    params['scale'] = FLAGS.scale
    params['l2_loss'] = FLAGS.l2_loss
    params['reconstr_loss'] = FLAGS.reconstr_loss
    params['kl_loss'] = FLAGS.kl_loss

    eval_dir = os.path.join(FLAGS.log_root, 'eval')

    images = reader.ReadInput(FLAGS.data_filepattern,
                              shuffle=False,
                              params=params)
    images *= params['scale']
    # Scale the images the same way as in training.
    image_diff_list = reader.SequenceToImageAndDiff(images)
    model = cross_conv_model.CrossConvModel(image_diff_list, params)
    model.Build()

    summary_writer = tf.summary.FileWriter(eval_dir)
    saver = tf.train.Saver()
    sess = tf.Session('', config=tf.ConfigProto(allow_soft_placement=True))
    tf.train.start_queue_runners(sess)

    while True:
        # Poll for a new checkpoint once a minute.
        time.sleep(60)
        try:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
        except tf.errors.OutOfRangeError as e:
            sys.stderr.write('Cannot restore checkpoint: %s\n' % e)
            continue
        if not (ckpt_state and ckpt_state.model_checkpoint_path):
            sys.stderr.write('No model to eval yet at %s\n' % FLAGS.log_root)
            continue
        sys.stderr.write('Loading checkpoint %s\n' %
                         ckpt_state.model_checkpoint_path)
        saver.restore(sess, ckpt_state.model_checkpoint_path)
        # Use the empirical distribution of z from the training set.
        if not tf.gfile.Exists(os.path.join(FLAGS.log_root, 'z_mean.npy')):
            sys.stderr.write('No z at %s\n' % FLAGS.log_root)
            continue

        with tf.gfile.Open(os.path.join(FLAGS.log_root, 'z_mean.npy'),
                           'rb') as f:
            sample_z_mean = np.load(io.BytesIO(f.read()))
        with tf.gfile.Open(os.path.join(FLAGS.log_root,
                                        'z_stddev_log.npy'), 'rb') as f:
            sample_z_stddev_log = np.load(io.BytesIO(f.read()))

        total_loss = 0.0
        for _ in range(FLAGS.eval_batch_count):
            # Feeding z_mean / z_stddev_log overrides the encoder's output
            # with the empirical statistics collected during training.
            loss_val, total_steps, summaries = sess.run(
                [model.loss, model.global_step, model.summary_op],
                feed_dict={
                    model.z_mean: sample_z_mean,
                    model.z_stddev_log: sample_z_stddev_log
                })
            total_loss += loss_val

        summary_writer.add_summary(summaries, total_steps)
        sys.stderr.write('steps: %d, loss: %f\n' %
                         (total_steps, total_loss / FLAGS.eval_batch_count))
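
The eval loop works because in TF1 Session.run's feed_dict can override any tensor in the graph, not only placeholders: feeding model.z_mean and model.z_stddev_log replaces the encoder's output with the training-set statistics loaded above. A standalone toy illustration of that mechanism (hypothetical values):

import tensorflow as tf

x = tf.constant(2.0)
y = x * 3.0  # an intermediate tensor, not a placeholder
z = y + 1.0

with tf.Session() as sess:
    print(sess.run(z))                       # 7.0: y is computed from x
    print(sess.run(z, feed_dict={y: 10.0}))  # 11.0: y is overridden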