Example #1
def input_fn(data_dir,
             subset,
             num_shards,
             batch_size,
             use_distortion_for_training=True):
    """Create input graph for model.

  Args:
    data_dir: Directory where TFRecords representing the dataset are located.
    subset: one of 'train', 'validate' and 'eval'.
    num_shards: num of towers participating in data-parallel training.
    batch_size: total batch size for training to be divided by the number of
      shards.
    use_distortion_for_training: True to use distortions.
  Returns:
    two lists of tensors for features and labels, each of num_shards length.
  """
    with tf.device('/cpu:0'):
        use_distortion = subset == 'train' and use_distortion_for_training
        dataset = cifar10.Cifar10DataSet(data_dir, subset, use_distortion)
        image_batch, label_batch = dataset.make_batch(batch_size)

        if args.synthetic:
            image_batch = tf.random_uniform((batch_size, 32, 32, 3))
            label_batch = tf.ones((batch_size, ), dtype=tf.int32)

        assert num_shards <= 1  # remove multi-gpu support for now
        return [image_batch], [label_batch]
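A minimal sketch of how this input_fn might be driven in graph mode (the data directory and batch size are placeholders, and the function also expects a module-level args object for the synthetic-data switch):

import tensorflow as tf  # TF 1.x graph-mode API, as in the examples on this page

# Hypothetical call; '/tmp/cifar10_data' and 128 are illustrative values.
features, labels = input_fn(data_dir='/tmp/cifar10_data',
                            subset='train',
                            num_shards=1,
                            batch_size=128)
with tf.train.MonitoredTrainingSession() as sess:
    images, targets = sess.run([features[0], labels[0]])
    print(images.shape, targets.shape)  # expected: (128, 32, 32, 3) (128,)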
Example #2
    def __init__(self, args):
        self.args = args
        self.args.stages = list(map(int, self.args.stages.split('-')))
        self.args.growth = list(map(int, self.args.growth.split('-')))

        self.num_examples_train = cifar10.Cifar10DataSet.num_examples_per_epoch(True)
        self.num_examples_eval = cifar10.Cifar10DataSet.num_examples_per_epoch(False)
        self.num_batches_train = self.num_examples_train // self.args.bsize
        self.num_batches_eval = self.num_examples_eval // self.args.bsize

        with tf.device('/cpu:0'):

            with tf.name_scope('dataset'):
                self.dataset = cifar10.Cifar10DataSet(self.args.bsize)
                self.dataset.make_batch()

            with tf.name_scope('cosine_annealing_lr'):
                total_iters = self.args.ep * self.num_batches_train
                lr_op = self.args.lr * 0.5 * (1.0 + tf.cos(np.pi * (tf.train.get_or_create_global_step() / total_iters)))

            self.opt = tf.train.MomentumOptimizer(lr_op, self.args.momentum, use_nesterov=True)

        print ('========= TRAINING CONDENSENET =========')
        print ('      Initial LR : {}'.format(self.args.lr))
        print ('        LR decay : cosine annealing')
        print ('       Optimizer : Momentum Optimizer')
        print ('          Epochs : {}'.format(self.args.ep))
        print ('        Momentum : {}'.format(self.args.momentum))
        print ('          Stages : {}'.format(self.args.stages))
        print ('    Growth Rates : {}'.format(self.args.growth))
        print ('      Batch size : {}'.format(self.args.bsize))
        print ('  Num Batches EP : {}'.format(self.num_batches_train))
        print ('  Train Examples : {}'.format(self.num_examples_train))
        print ('    Val Examples : {}'.format(self.num_batches_eval * self.args.bsize))
        print ('========================================')
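The learning-rate schedule built above is plain cosine annealing over the whole run; as a self-contained sketch of the same formula (lr0, step and total_iters are illustrative names, not part of the example):

import math

def cosine_annealed_lr(lr0, step, total_iters):
    """Cosine annealing: decays lr0 towards 0 over total_iters steps."""
    return lr0 * 0.5 * (1.0 + math.cos(math.pi * step / total_iters))

# lr starts at lr0, is exactly lr0 / 2 halfway through, and approaches 0 at the end.
assert abs(cosine_annealed_lr(0.1, 0, 1000) - 0.1) < 1e-9
assert abs(cosine_annealed_lr(0.1, 500, 1000) - 0.05) < 1e-9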
Example #3
def input_fn(data_dir,
             subset,
             batch_size):
  dataset = cifar10.Cifar10DataSet(data_dir, subset, subset=='train')
  image_batch, label_batch = dataset.make_batch(batch_size)
  #return [image_batch], [label_batch]

  return {"x": image_batch}, label_batch
Example #4
def get_dataset(subset):
    data_dir = './mount/data/cifar10/'
    use_distortion = False
    if subset == 'train':
        use_distortion = True
    return cifar10.Cifar10DataSet(data_dir,
                                  subset=subset,
                                  use_distortion=use_distortion)
Example #5
def input_fn(mode, batch_size=128):
    dataset = cifar10.Cifar10DataSet(data_dir=config.data_dir,
                                     subset=mode,
                                     use_distortion=mode == 'train')

    with tf.device('/cpu:0'):
        image_batch, label_batch = dataset.make_batch(batch_size)

        return image_batch, label_batch
Example #6
def load_preprocess_training_batch(batch_id, batch_size, subset,
                                   use_distortion_for_training, data_dir,
                                   epochs=1):
    """
    Load the preprocessed training data and return it in batches of <batch_size> or less.
    """
    #with tf.device('/cpu:0'):
    use_distortion = 'train' in subset and use_distortion_for_training
    dataset = cifar10.Cifar10DataSet(data_dir, subset, use_distortion)
    image_batch, label_batch = dataset.make_batch(batch_size, epochs)
    return image_batch, label_batch
Example #7
def input_fn(data_dir,
             subset,
             num_shards,
             batch_size,
             variable_strategy,
             run_config,
             use_distortion_for_training=True):
    """Create input graph for model.

  Args:
    data_dir: Directory where TFRecords representing the dataset are located.
    subset: one of 'train', 'validate' and 'eval'.
    num_shards: num of towers participating in data-parallel training.
    batch_size: total batch size for training to be divided by the number of
      shards.
    use_distortion_for_training: True to use distortions.
  Returns:
    two lists of tensors for features and labels, each of num_shards length.
  """
    #Is this called on every batch? make_batch is called here.

    if subset == 'train':
        batch_size = run_config.get_node_batch_size

    tf.logging.info(">>> Num shards: " + str(num_shards))
    tf.logging.info('batch-size value fed to input fn: ' + str(batch_size))

    consolidation_device = '/gpu:0' if variable_strategy == 'GPU' else '/cpu:0'
    with tf.device(consolidation_device):
        use_distortion = subset == 'train' and use_distortion_for_training
        dataset = cifar10.Cifar10DataSet(data_dir, subset, use_distortion)
        #XXX This is where the sharding needs to happen based on the worker?
        image_batch, label_batch = dataset.make_batch(batch_size)
        if num_shards <= 1:
            # No GPU available or only 1 GPU.
            return [image_batch], [label_batch]

        # Note that passing num=batch_size is safe here, even though
        # dataset.batch(batch_size) can, in some cases, return fewer than batch_size
        # examples. This is because it does so only when repeating for a limited
        # number of epochs, but our dataset repeats forever.
        # XXX repeats forever: what does that mean!?!?

        image_batch = tf.unstack(image_batch, num=batch_size, axis=0)
        label_batch = tf.unstack(label_batch, num=batch_size, axis=0)
        feature_shards = [[] for i in range(num_shards)]
        label_shards = [[] for i in range(num_shards)]
        for i in xrange(batch_size):
            idx = i % num_shards
            feature_shards[idx].append(image_batch[i])
            label_shards[idx].append(label_batch[i])
        feature_shards = [tf.parallel_stack(x) for x in feature_shards]
        label_shards = [tf.parallel_stack(x) for x in label_shards]
        return feature_shards, label_shards
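The unstack / parallel_stack block above is just a round-robin split of one big batch into per-tower shards; the same assignment as a tiny NumPy sketch (sizes are illustrative):

import numpy as np

batch_size, num_shards = 8, 2
examples = np.arange(batch_size)          # stand-ins for per-example tensors
shards = [[] for _ in range(num_shards)]
for i in range(batch_size):
    shards[i % num_shards].append(examples[i])
# Tower 0 gets examples 0, 2, 4, 6; tower 1 gets 1, 3, 5, 7.
print([np.stack(s) for s in shards])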
Example #8
def input_fn(data_dir,
             subset,
             num_shards,
             batch_size,
             use_distortion_for_training=True):
    """Create input graph for model.

  Args:
    data_dir: Directory where TFRecords representing the dataset are located.
    subset: one of 'train', 'validate' and 'eval'.
    num_shards: num of towers participating in data-parallel training.
    batch_size: total batch size for training to be divided by the number of
      shards.
    use_distortion_for_training: True to use distortions.
  Returns:
    two lists of tensors for features and labels, each of num_shards length.
  """
    with tf.device('/cpu:0'):
        use_distortion = subset == 'train' and use_distortion_for_training
        dataset = cifar10.Cifar10DataSet(data_dir, subset, use_distortion)
        image_batch, label_batch = dataset.make_batch(batch_size)

        assert not (args.synthetic and args.synthetic_labels)

        if args.synthetic:
            image_batch = tf.random_uniform((batch_size, 32, 32, 3))
            label_batch = tf.ones((batch_size, ), dtype=tf.int32)

        if args.synthetic_labels:
            label_batch = tf.ones((batch_size, ), dtype=tf.int32)

        if num_shards <= 1:
            # No GPU available or only 1 GPU.
            return [image_batch], [label_batch]

        # Note that passing num=batch_size is safe here, even though
        # dataset.batch(batch_size) can, in some cases, return fewer than batch_size
        # examples. This is because it does so only when repeating for a limited
        # number of epochs, but our dataset repeats forever.
        image_batch = tf.unstack(image_batch, num=batch_size, axis=0)
        label_batch = tf.unstack(label_batch, num=batch_size, axis=0)
        feature_shards = [[] for i in range(num_shards)]
        label_shards = [[] for i in range(num_shards)]
        for i in xrange(batch_size):
            idx = i % num_shards
            feature_shards[idx].append(image_batch[i])
            label_shards[idx].append(label_batch[i])
        feature_shards = [tf.parallel_stack(x) for x in feature_shards]
        label_shards = [tf.parallel_stack(x) for x in label_shards]
        return feature_shards, label_shards
Example #9
def input_fn(subset, num_shards):
    """Create input graph for model.

  Args:
    subset: one of 'train', 'validate' and 'eval'.
    num_shards: num of towers participating in data-parallel training.
  Returns:
    two lists of tensors for features and labels, each of num_shards length.
  """
    if subset == 'train':
        batch_size = FLAGS.train_batch_size
    elif subset == 'validate' or subset == 'eval':
        batch_size = FLAGS.eval_batch_size
    else:
        raise ValueError(
            'Subset must be one of \'train\', \'validate\' and \'eval\'')
    with tf.device('/cpu:0'):
        use_distortion = subset == 'train' and FLAGS.use_distortion_for_training
        dataset = cifar10.Cifar10DataSet(FLAGS.data_dir, subset,
                                         use_distortion)
        image_batch, label_batch = dataset.make_batch(batch_size)
        if num_shards <= 1:
            # No GPU available or only 1 GPU.
            return [image_batch], [label_batch]

        # Note that passing num=batch_size is safe here, even though
        # dataset.batch(batch_size) can, in some cases, return fewer than batch_size
        # examples. This is because it does so only when repeating for a limited
        # number of epochs, but our dataset repeats forever.
        image_batch = tf.unstack(image_batch, num=batch_size, axis=0)
        label_batch = tf.unstack(label_batch, num=batch_size, axis=0)
        feature_shards = [[] for i in range(num_shards)]
        label_shards = [[] for i in range(num_shards)]
        for i in xrange(batch_size):
            idx = i % num_shards
            feature_shards[idx].append(image_batch[i])
            label_shards[idx].append(label_batch[i])
        feature_shards = [tf.parallel_stack(x) for x in feature_shards]
        label_shards = [tf.parallel_stack(x) for x in label_shards]
        return feature_shards, label_shards
Example #10
def input_fn(data_dir,
             subset,
             num_shards,
             batch_size,
             use_distortion_for_training=True):
  """Create input graph for model.

  Args:
    data_dir: Directory where TFRecords representing the dataset are located.
    subset: one of 'train', 'validate' and 'eval'.
    num_shards: num of towers participating in data-parallel training.
    batch_size: total batch size for training to be divided by the number of
      shards.
    use_distortion_for_training: True to use distortions.
  Returns:
    two lists of tensors for features and labels, each of num_shards length.
  """
  with tf.device('/cpu:0'):
    use_distortion = subset == 'train' and use_distortion_for_training
    dataset = cifar10.Cifar10DataSet(data_dir, subset, use_distortion)
    image_batch, label_batch = dataset.make_batch(batch_size)
    if num_shards <= 1:
      # No GPU available or only 1 GPU.
      return [image_batch], [label_batch]

    # Note that passing num=batch_size is safe here, even though
    # dataset.batch(batch_size) can, in some cases, return fewer than batch_size
    # examples. This is because it does so only when repeating for a limited
    # number of epochs, but our dataset repeats forever.
    image_batch = tf.unstack(image_batch, num=batch_size, axis=0)
    label_batch = tf.unstack(label_batch, num=batch_size, axis=0)
    feature_shards = [[] for i in range(num_shards)]
    label_shards = [[] for i in range(num_shards)]
    for i in xrange(batch_size):
      idx = i % num_shards
      feature_shards[idx].append(image_batch[i])
      label_shards[idx].append(label_batch[i])
    feature_shards = [tf.parallel_stack(x) for x in feature_shards]
    label_shards = [tf.parallel_stack(x) for x in label_shards]
    return feature_shards, label_shards
Example #11
def input_fn(subset, num_shards):
  """Create input graph for model.

  Args:
    subset: one of 'train', 'validate' and 'eval'.
    num_shards: num of towers participating in data-parallel training.
  Returns:
    two lists of tensors for features and labels, each of num_shards length.
  """
  dataset = cifar10.Cifar10DataSet(FLAGS.data_dir)
  is_training = (subset == 'train')
  if is_training:
    batch_size = FLAGS.train_batch_size
  else:
    batch_size = FLAGS.eval_batch_size
  with tf.device('/cpu:0'), tf.name_scope('batching'):
    # CPU loads all data from disk since there're only 60k 32*32 RGB images.
    all_images, all_labels = dataset.read_all_data(subset)
    dataset = tf.contrib.data.Dataset.from_tensor_slices(
        (all_images, all_labels))
    dataset = dataset.map(
        lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.int32)),
        num_threads=2,
        output_buffer_size=batch_size)

    # Image preprocessing.
    def _preprocess(image, label):
      # If GPU is available, NHWC to NCHW transpose is done in ResNetCifar10
      # class, not included in preprocessing.
      return cifar10.Cifar10DataSet.preprocess(
          image, is_training, FLAGS.use_distortion_for_training), label
    dataset = dataset.map(
        _preprocess, num_threads=batch_size, output_buffer_size=2 * batch_size)
    # Repeat infinitely.
    dataset = dataset.repeat()
    if is_training:
      min_fraction_of_examples_in_queue = 0.4
      min_queue_examples = int(
          cifar10.Cifar10DataSet.num_examples_per_epoch(subset) *
          min_fraction_of_examples_in_queue)
      # Ensure that the capacity is sufficiently large to provide good random
      # shuffling
      dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    image_batch, label_batch = iterator.get_next()
    if num_shards <= 1:
      # No GPU available or only 1 GPU.
      return [image_batch], [label_batch]

    # Note that passing num=batch_size is safe here, even though
    # dataset.batch(batch_size) can, in some cases, return fewer than batch_size
    # examples. This is because it does so only when repeating for a limited
    # number of epochs, but our dataset repeats forever.
    image_batch = tf.unstack(image_batch, num=batch_size, axis=0)
    label_batch = tf.unstack(label_batch, num=batch_size, axis=0)
    feature_shards = [[] for i in range(num_shards)]
    label_shards = [[] for i in range(num_shards)]
    for i in xrange(batch_size):
      idx = i % num_shards
      feature_shards[idx].append(image_batch[i])
      label_shards[idx].append(label_batch[i])
    feature_shards = [tf.parallel_stack(x) for x in feature_shards]
    label_shards = [tf.parallel_stack(x) for x in label_shards]
    return feature_shards, label_shards
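Example #11 uses the pre-1.4 tf.contrib.data location of the Dataset API; on TF 1.4+ the same pipeline would look roughly like the sketch below, reusing _preprocess, all_images, all_labels and the shuffle heuristic from above (num_parallel_calls replaces num_threads, and output_buffer_size becomes an explicit prefetch):

# Rough tf.data equivalent of the contrib pipeline above (assumed TF 1.4+ API).
dataset = tf.data.Dataset.from_tensor_slices((all_images, all_labels))
dataset = dataset.map(
    lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.int32)),
    num_parallel_calls=2)
dataset = dataset.map(_preprocess, num_parallel_calls=batch_size)
dataset = dataset.repeat()
if is_training:
    dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(2)
image_batch, label_batch = dataset.make_one_shot_iterator().get_next()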
Example #12
def main(data_dir,
         graph,
         mult,
         batch_size=1000,
         iterations=10,
         tune=True,
         write_log=False,
         **unused):
    tf.reset_default_graph()

    logdir = "log"
    subset = "validation"
    subset = "eval"
    input_layer = "input"
    output_layer = "resnet/tower_0/fully_connected/dense/BiasAdd:0"
    use_distortion = False

    num_intra_threads = 2

    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        intra_op_parallelism_threads=num_intra_threads)

    with tf.Session() as sess:
        dataset = cifar10.Cifar10DataSet(data_dir, subset, use_distortion)
        image_batch, label_batch = dataset.make_batch(batch_size)

        # assignment of attributes at protobuf level
        idop = 0
        with tf.gfile.GFile(graph, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            for node in graph_def.node:
                if node.op != "AxConv2D": continue
                tf.logging.info("%s = %s (tune = %s)" %
                                (node.name, mult, tune))
                val = node.attr.get_or_create("AxMult")
                val.s = str.encode(mult)

                val = node.attr.get_or_create("AxTune")
                val.b = tune
                #node.attr["AxMult"] = val
                idop += 1
                #node.attr["AxMult"] = None

        # load graph
        predictions = tf.import_graph_def(graph_def,
                                          name="TestedNet",
                                          input_map={input_layer: image_batch},
                                          return_elements=[output_layer])

        graph = sess.graph

        labels = tf.reshape(label_batch, shape=[batch_size])
        probs = tf.reshape(predictions, shape=[batch_size, 11])

        mval = tf.argmax(probs, 1, output_type=tf.int32)
        equality = tf.equal(mval, labels)
        accuracy = tf.reduce_mean(tf.cast(equality, tf.float32))

        if write_log:
            writer = tf.summary.FileWriter(logdir)
            writer.add_graph(sess.graph)
            merged = tf.summary.merge_all()

        acc, cnt = 0, 0
        bacc = 0
        for it in range(iterations):
            start = time.time()
            amean = sess.run(accuracy)

            tf.logging.info("Mean accuracy (run %d): %.5f in %f sec", it,
                            amean,
                            time.time() - start)
            bacc += amean

        tf.logging.info("results;mult=%s;tune=%s;accuracy=%.3f" %
                        (str(mult), str(tune), float(bacc) / float(it + 1)))
        return float(bacc) / float(it + 1)
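A hedged example of invoking this benchmark; the frozen-graph path and the multiplier name are placeholders that depend on whichever AxConv2D build is installed:

# Hypothetical invocation; all argument values below are placeholders.
mean_acc = main(data_dir='/tmp/cifar10_data',
                graph='frozen_resnet.pb',
                mult='mul8u_XXX',
                batch_size=1000,
                iterations=10,
                tune=True)
print('mean accuracy over 10 runs: %.3f' % mean_acc)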
Example #13
def main(_):
    num_train_examples = 45000
    melt.apps.train.init()

    batch_size = melt.batch_size()
    num_gpus = melt.num_gpus()

    batch_size_per_gpu = FLAGS.batch_size

    # the total batch size is unchanged, but FLAGS.batch_size changes to batch_size / num_gpus
    #print('--------------batch_size, FLAGS.batch_size, num_steps_per_epoch', batch_size, FLAGS.batch_size, num_train_examples // batch_size)

    global_scope = FLAGS.algo
    with tf.variable_scope(global_scope) as global_scope:
        data_format = 'channels_first'
        num_layers = 44
        batch_norm_decay = 0.997
        batch_norm_epsilon = 1e-05
        data_dir = './mount/data/cifar10/'
        with tf.variable_scope('main') as scope:
            model = cifar10_model.ResNetCifar10(
                num_layers,
                batch_norm_decay=batch_norm_decay,
                batch_norm_epsilon=batch_norm_epsilon,
                is_training=True,
                data_format=data_format)

            dataset = cifar10.Cifar10DataSet(data_dir,
                                             subset='train',
                                             use_distortion=True)

            ## This is wrong: it makes every GPU read the same data, so convergence is slower, though the test result ends up better
            #_, image_batch, label_batch = dataset.make_batch(FLAGS.batch_size)
            def loss_function():
                # doing this, 2 GPUs give results similar to 1 GPU: a slightly better valid result and a slightly worse test result, likely due to randomness
                _, image_batch, label_batch = dataset.make_batch(
                    batch_size_per_gpu)
                return tower_loss(model, image_batch, label_batch)

            #loss_function = lambda: tower_loss(model, image_batch, label_batch)
            loss = melt.tower_losses(loss_function, num_gpus)
            pred = model.predict()
            pred = pred['classes']
            label_batch = dataset.label_batch
            acc = tf.reduce_mean(tf.to_float(tf.equal(pred, label_batch)))

            #tf.summary.image('train/image', dataset.image_batch)
            # # Compute confusion matrix
            # matrix = tf.confusion_matrix(label_batch, pred, num_classes=10)
            # # Get a image tensor for summary usage
            # image_tensor = draw_confusion_matrix(matrix)
            # tf.summary.image('train/confusion_matrix', image_tensor)

            scope.reuse_variables()
            ops = [loss, acc]

            # TODO multiple gpu validation and inference

            validator = cifar10_model.ResNetCifar10(
                num_layers,
                batch_norm_decay=batch_norm_decay,
                batch_norm_epsilon=batch_norm_epsilon,
                is_training=False,
                data_format=data_format)

            valid_dataset = cifar10.Cifar10DataSet(data_dir,
                                                   subset='valid',
                                                   use_distortion=False)
            valid_iterator = valid_dataset.make_batch(batch_size)
            valid_id_batch, valid_image_batch, valid_label_batch = valid_iterator.get_next(
            )

            valid_loss = tower_loss(validator, valid_image_batch,
                                    valid_label_batch)
            valid_pred = validator.predict()
            valid_pred = valid_pred['classes']

            ## seems not to work with non-repeat mode..
            #tf.summary.image('valid/image', valid_image_batch)
            ## Compute confusion matrix
            #matrix = tf.confusion_matrix(valid_label_batch, valid_pred, num_classes=10)
            ## Get a image tensor for summary usage
            #image_tensor = draw_confusion_matrix(matrix)
            #tf.summary.image('valid/confusion_matrix', image_tensor)

            #loss_function = lambda: tower_loss(validator, val_image_batch, val_label_batch)
            #val_loss = melt.tower_losses(loss_function, FLAGS.num_gpus, is_training=False)
            #eval_ops = [val_loss]

            metric_eval_fn = lambda model_path=None: \
                                evaluator.evaluate([valid_id_batch, valid_loss, valid_pred, valid_label_batch, valid_image_batch],
                                                   valid_iterator,
                                                   model_path=model_path)

            predictor = cifar10_model.ResNetCifar10(
                num_layers,
                batch_norm_decay=batch_norm_decay,
                batch_norm_epsilon=batch_norm_epsilon,
                is_training=False,
                data_format=data_format)

            predictor.init_predict()

            test_dataset = cifar10.Cifar10DataSet(data_dir,
                                                  subset='test',
                                                  use_distortion=False)
            test_iterator = test_dataset.make_batch(batch_size)
            test_id_batch, test_image_batch, test_label_batch = test_iterator.get_next(
            )

            test_pred = predictor.predict(test_image_batch,
                                          input_data_format='channels_last')
            test_pred = test_pred['classes']

            inference_fn = lambda model_path=None: \
                                evaluator.inference([test_id_batch, test_pred],
                                                    test_iterator,
                                                    model_path=model_path)

            global eval_names
            names = ['loss', 'acc']

        melt.apps.train_flow(ops,
                             names=names,
                             metric_eval_fn=metric_eval_fn,
                             inference_fn=inference_fn,
                             model_dir=FLAGS.model_dir,
                             num_steps_per_epoch=num_train_examples //
                             batch_size)
Example #14
def main(_):
    num_train_examples = 45000
    melt.apps.init()

    batch_size = melt.batch_size()
    num_gpus = melt.num_gpus()

    batch_size_per_gpu = FLAGS.batch_size

    # the total batch size is unchanged, but FLAGS.batch_size changes to batch_size / num_gpus
    #print('--------------batch_size, FLAGS.batch_size, num_steps_per_epoch', batch_size, FLAGS.batch_size, num_train_examples // batch_size)

    global_scope = FLAGS.algo
    with tf.variable_scope(global_scope) as global_scope:
        data_format = 'channels_first'
        num_layers = 44
        batch_norm_decay = 0.997
        batch_norm_epsilon = 1e-05
        data_dir = './mount/data/cifar10/'
        with tf.variable_scope('main') as scope:
            model = cifar10_model.ResNetCifar10(
                num_layers,
                batch_norm_decay=batch_norm_decay,
                batch_norm_epsilon=batch_norm_epsilon,
                training=True,
                data_format=data_format)

            dataset = cifar10.Cifar10DataSet(data_dir,
                                             subset='train',
                                             use_distortion=True)
            # this is faster than the method above
            iterator = dataset.make_batch(batch_size)
            batch = iterator.get_next()

            ## Now below is also ok...
            # x = {'id': batch[0], 'image': batch[1]}
            # y = batch[2]
            # batch = (x, y)
            # x, y = melt.split_batch(batch, batch_size, num_gpus)
            # image_batches, label_batches = [item['image'] for item in x], y

            _, image_batches, label_batches = melt.split_batch(
                batch, batch_size, num_gpus)

            def loss_function(i):
                return tower_loss(model, image_batches[i], label_batches[i])

            label_batch = label_batches[-1]

            #loss_function = lambda: tower_loss(model, image_batch, label_batch)
            loss = melt.tower(loss_function, num_gpus)
            pred = model.predict()
            pred = pred['classes']
            #label_batch = dataset.label_batch
            acc = tf.reduce_mean(tf.to_float(tf.equal(pred, label_batch)))

            #tf.summary.image('train/image', dataset.image_batch)
            # # Compute confusion matrix
            # matrix = tf.confusion_matrix(label_batch, pred, num_classes=10)
            # # Get a image tensor for summary usage
            # image_tensor = draw_confusion_matrix(matrix)
            # tf.summary.image('train/confusion_matrix', image_tensor)

            scope.reuse_variables()
            ops = [loss, acc]

            validator = cifar10_model.ResNetCifar10(
                num_layers,
                batch_norm_decay=batch_norm_decay,
                batch_norm_epsilon=batch_norm_epsilon,
                training=False,
                data_format=data_format)

            valid_dataset = cifar10.Cifar10DataSet(data_dir,
                                                   subset='valid',
                                                   use_distortion=False)
            valid_iterator = valid_dataset.make_batch(batch_size)
            valid_batch = valid_iterator.get_next()
            valid_id_batches, valid_image_batches, valid_label_batches = melt.split_batch(
                valid_batch, batch_size, num_gpus, training=False)

            def valid_loss_fn(i):
                valid_loss = tower_loss(validator, valid_image_batches[i],
                                        valid_label_batches[i])
                valid_pred = validator.predict()
                return valid_id_batches[i], valid_loss, valid_pred[
                    'classes'], valid_label_batches[i], valid_image_batches[i]

            num_valid_examples = dataset.num_examples_per_epoch(subset='valid')
            valid_ops = melt.tower(valid_loss_fn, num_gpus, training=False)

            ## seems not to work with non-repeat mode..
            #tf.summary.image('valid/image', valid_image_batch)
            ## Compute confusion matrix
            #matrix = tf.confusion_matrix(valid_label_batch, valid_pred, num_classes=10)
            ## Get a image tensor for summary usage
            #image_tensor = draw_confusion_matrix(matrix)
            #tf.summary.image('valid/confusion_matrix', image_tensor)

            #loss_function = lambda: tower_loss(validator, val_image_batch, val_label_batch)
            #val_loss = melt.tower_losses(loss_function, FLAGS.num_gpus, training=False)
            #eval_ops = [val_loss]

            metric_eval_fn = lambda model_path=None: \
                                evaluator.evaluate(valid_ops,
                                                   valid_iterator,
                                                   num_steps=-(-num_valid_examples // batch_size),
                                                   num_examples=num_valid_examples,
                                                   model_path=model_path,
                                                   num_gpus=num_gpus)

            predictor = cifar10_model.ResNetCifar10(
                num_layers,
                batch_norm_decay=batch_norm_decay,
                batch_norm_epsilon=batch_norm_epsilon,
                training=False,
                data_format=data_format)

            predictor.init_predict()

            test_dataset = cifar10.Cifar10DataSet(data_dir,
                                                  subset='test',
                                                  use_distortion=False)
            test_iterator = test_dataset.make_batch(batch_size)

            test_batch = test_iterator.get_next()
            test_id_batches, test_image_batches, test_label_batches = melt.split_batch(
                test_batch, batch_size, num_gpus, training=False)

            def test_fn(i):
                test_pred = predictor.predict(test_image_batches[i])
                test_pred = test_pred['classes']
                return test_id_batches[i], test_pred

            num_test_examples = dataset.num_examples_per_epoch(subset='test')
            test_ops = melt.tower(test_fn, num_gpus, training=False)
            inference_fn = lambda model_path=None: \
                                evaluator.inference(test_ops,
                                                    test_iterator,
                                                    num_steps=-(-num_test_examples // batch_size),
                                                    num_examples=num_test_examples,
                                                    model_path=model_path,
                                                    num_gpus=num_gpus)

            global eval_names
            names = ['loss', 'acc']

        melt.apps.train_flow(ops,
                             names=names,
                             metric_eval_fn=metric_eval_fn,
                             inference_fn=inference_fn,
                             model_dir=FLAGS.model_dir,
                             num_steps_per_epoch=num_train_examples //
                             batch_size)
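melt.split_batch and melt.tower are project-specific helpers; the underlying idea is just slicing one big batch into per-GPU sub-batches and building one loss per device, roughly as in this sketch (tower_loss and model are taken from the example; the batch tensors and everything else are illustrative and not the melt API):

# Rough illustration of per-GPU batch splitting; not the melt.split_batch API.
num_gpus = 2
image_splits = tf.split(image_batch, num_or_size_splits=num_gpus, axis=0)
label_splits = tf.split(label_batch, num_or_size_splits=num_gpus, axis=0)
tower_losses_list = []
for i in range(num_gpus):
    with tf.device('/gpu:%d' % i):
        tower_losses_list.append(tower_loss(model, image_splits[i], label_splits[i]))
total_loss = tf.reduce_mean(tower_losses_list)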