コード例 #1
0
def train(hps):
    """Training loop.

    Builds the input pipeline and the ResNet graph, prints parameter and
    float-op statistics, then runs the training op inside a
    ``MonitoredTrainingSession`` until the session asks to stop. One
    tab-separated line of metrics (global step, total cost, classification
    cost, linear cost, linearity evaluation, precision) is appended to the
    log file per loop iteration.

    The log file path is read from the name ``filepath``, which is not
    defined in this function and must be provided by the enclosing module.

    Args:
        hps: Hyperparameter object; ``hps.batch_size`` is read here and the
            whole object is handed to ``fb_resnet_model.ResNet``.
    """
    images, labels = cifar_input.build_input(FLAGS.dataset,
                                             FLAGS.train_data_path,
                                             hps.batch_size, FLAGS.mode)
    model = fb_resnet_model.ResNet(hps, images, labels, FLAGS.mode)
    model.build_graph()

    param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
        tf.get_default_graph(),
        tfprof_options=tf.contrib.tfprof.model_analyzer.
        TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
    sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)

    tf.contrib.tfprof.model_analyzer.print_model_analysis(
        tf.get_default_graph(),
        tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)

    # Fraction of examples in the batch whose arg-max prediction matches the
    # arg-max of the label tensor.
    truth = tf.argmax(model.labels, axis=1)
    predictions = tf.argmax(model.predictions, axis=1)
    precision = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth)))

    summary_hook = tf.train.SummarySaverHook(
        save_steps=100,
        output_dir=FLAGS.train_dir,
        summary_op=tf.summary.merge(
            [model.summaries,
             tf.summary.scalar('Precision', precision)]))

    logging_hook = tf.train.LoggingTensorHook(
        tensors={
            'step': model.global_step,
            'clss_loss': model.cost_cls,
            'loss': model.cost,
            'linear_loss': model.cost_lin,
            'linear_weight': model._weight,
            'precision': precision,
        },
        every_n_iter=10)

    class _LearningRateSetterHook(tf.train.SessionRunHook):
        """Sets learning_rate based on global step."""

        def begin(self):
            self._lrn_rate = 0.1

        def before_run(self, run_context):
            return tf.train.SessionRunArgs(
                model.global_step,  # Asks for global step value.
                feed_dict={model.lrn_rate:
                           self._lrn_rate})  # Sets learning rate

        def after_run(self, run_context, run_values):
            # Piecewise-constant schedule keyed on the global step that was
            # fetched alongside the run.
            train_step = run_values.results
            if train_step < 10000:
                self._lrn_rate = 0.1
            elif train_step < 20000:
                self._lrn_rate = 0.01
            elif train_step < 40000:
                self._lrn_rate = 0.001
            else:
                self._lrn_rate = 0.0001

    # The context manager guarantees the metrics file is flushed and closed
    # even when the training session raises.
    with open(filepath, 'w') as wfile:
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.log_root,
                hooks=[logging_hook, _LearningRateSetterHook()],
                chief_only_hooks=[summary_hook],
                # Since we provide a SummarySaverHook, we need to disable
                # default SummarySaverHook. To do that we set
                # save_summaries_steps to 0.
                save_summaries_steps=0,
                config=tf.ConfigProto(
                    allow_soft_placement=True)) as mon_sess:
            while not mon_sess.should_stop():
                # NOTE(review): train_op is executed twice per loop
                # iteration and the values fetched by this first run are
                # never used. Kept so the training schedule is unchanged;
                # confirm whether the extra step is intentional.
                mon_sess.run(
                    [model.train_op, model.norm_kernel, model.global_step])
                _, gs, cost, cost_cls, cost_lin, lineval, prec = mon_sess.run([
                    model.train_op, model.global_step, model.cost,
                    model.cost_cls, model.cost_lin, model.lin_eval, precision
                ])
                wfile.write('%d\t%f\t%f\t%f\t%f\t%f\n' %
                            (gs, cost, cost_cls, cost_lin, lineval, prec))
コード例 #2
0
def evaluate(hps):
    """Eval loop.

    Builds the evaluation graph, then repeatedly restores the latest
    checkpoint found under ``FLAGS.log_root``, runs
    ``FLAGS.eval_batch_count`` batches, and writes precision, best
    precision, weight-linearity and the model's own summaries to
    ``FLAGS.eval_dir``. The loop exits after one evaluation when
    ``FLAGS.eval_once`` is set; otherwise it sleeps and evaluates again.

    Args:
        hps: Hyperparameter object; ``hps.batch_size`` is read here and the
            whole object is handed to ``fb_resnet_model.ResNet``.
    """
    # Seconds to wait between evaluations and before retrying when no
    # usable checkpoint is available.
    poll_interval_secs = 60

    images, labels = cifar_input.build_input(FLAGS.dataset,
                                             FLAGS.eval_data_path,
                                             hps.batch_size, FLAGS.mode)
    model = fb_resnet_model.ResNet(hps, images, labels, FLAGS.mode)
    model.build_graph()
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    tf.train.start_queue_runners(sess)

    best_precision = 0.0
    while True:
        try:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
        except tf.errors.OutOfRangeError as e:
            tf.logging.error('Cannot restore checkpoint: %s', e)
            # Back off before retrying instead of spinning on the
            # checkpoint directory.
            time.sleep(poll_interval_secs)
            continue
        if not (ckpt_state and ckpt_state.model_checkpoint_path):
            tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
            # Same back-off: the trainer may not have written a checkpoint
            # yet, so wait rather than busy-loop.
            time.sleep(poll_interval_secs)
            continue
        tf.logging.info('Loading checkpoint %s',
                        ckpt_state.model_checkpoint_path)
        saver.restore(sess, ckpt_state.model_checkpoint_path)

        total_prediction, correct_prediction = 0, 0
        for _ in six.moves.range(FLAGS.eval_batch_count):
            (summaries, loss, predictions, truth, train_step) = sess.run([
                model.summaries, model.cost, model.predictions, model.labels,
                model.global_step
            ])

            # Compare arg-max class indices of labels and predictions.
            truth = np.argmax(truth, axis=1)
            predictions = np.argmax(predictions, axis=1)
            correct_prediction += np.sum(truth == predictions)
            total_prediction += predictions.shape[0]

        # Weight linearity evaluation. The fetch is passed as a one-element
        # list, so `kernel` is a one-element list as well.
        kernel = sess.run([model.norm_kernel])
        kernel_lin_eval = weu.weight_lin_eval(kernel)
        precision = float(correct_prediction) / float(total_prediction)
        best_precision = max(precision, best_precision)

        # All scalar summaries are tagged with the global step of the last
        # evaluated batch.
        kernel_linearity = tf.Summary()
        kernel_linearity.value.add(tag='Weight Linearity',
                                   simple_value=kernel_lin_eval)
        summary_writer.add_summary(kernel_linearity, train_step)

        precision_summ = tf.Summary()
        precision_summ.value.add(tag='Precision', simple_value=precision)
        summary_writer.add_summary(precision_summ, train_step)

        best_precision_summ = tf.Summary()
        best_precision_summ.value.add(tag='Best Precision',
                                      simple_value=best_precision)
        summary_writer.add_summary(best_precision_summ, train_step)

        summary_writer.add_summary(summaries, train_step)
        tf.logging.info(
            'loss: %.3f, weight_linearity: %.5f, precision: %.3f, best precision: %.3f'
            % (loss, kernel_lin_eval, precision, best_precision))
        summary_writer.flush()

        if FLAGS.eval_once:
            break

        time.sleep(poll_interval_secs)