Ejemplo n.º 1
0
def main(_):
    predict_dir = '/home/give/Documents/dataset/BOT_Game/train/positive-hm/method5'
    file_names = os.listdir(predict_dir)
    file_pathes = [os.path.join(predict_dir, file_name) for file_name in file_names]
    image_values = [np.array(Image.open(file_path).convert('RGB')) for file_path in file_pathes]
    image_values = np.asarray(image_values, np.float32)
    image_values = image_values[:net_config.BATCH_SIZE]
    new_image_values = []
    for index, image_value in enumerate(image_values):
        image_value = np.asarray(image_value, np.float32)
        image_value = image_value * (1.0 / np.max(image_value))
        image_value = np.asarray(image_value, np.float32)
        img = np.zeros([net_config.IMAGE_W, net_config.IMAGE_H, net_config.IMAGE_CHANNEL])
        for j in range(net_config.IMAGE_CHANNEL):
            img[:, :, j] = np.array(
                Image.fromarray(image_value[:, :, j]).resize([net_config.IMAGE_W, net_config.IMAGE_H])
            )
        new_image_values.append(np.array(img))
    image_values = np.array(new_image_values)
    image_tensor = tf.placeholder(
        tf.float32,
        [net_config.BATCH_SIZE, net_config.IMAGE_W, net_config.IMAGE_H, net_config.IMAGE_CHANNEL]
    )
    label_tensor = tf.placeholder(
        tf.int32,
        [net_config.BATCH_SIZE]
    )
    logits = inference(image_tensor,
                       num_classes=2,
                       is_training=True,
                       bottleneck=False,)
    save_model_path = '/home/give/PycharmProjects/StomachCanner/classification/Net/ResNetHeatMap/models/method5-512'
    print 'image_tensor is ', image_tensor
    print np.shape(image_values)
    val(image_tensor, logits, image_values, label_tensor, [0]*len(image_values), save_model_path=save_model_path)
Ejemplo n.º 2
0
def freeze_mobilenet(meta_file, output_dir="./"):
    """Restore the ResNet checkpoint and export it as text and frozen graphs.

    Despite the name, the network built here is ``resnet.inference``. The
    graph gets an ``img_in`` input placeholder and a ``final_preds`` output.
    Two files are written into ``output_dir``:

    * ``final_preds.pbtxt`` - the graph definition as text;
    * ``final_preds.pb``    - a binary GraphDef with every variable folded
      into a constant (``convert_variables_to_constants``).

    Args:
        meta_file: checkpoint path passed to ``tf.train.Saver.restore``.
        output_dir: directory that receives both files (default: the current
            directory, which was previously hard-coded).
    """
    tf.reset_default_graph()

    inputs = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
                            name='img_in')
    logits = resnet.inference(inputs, NUM_RESIDUAL_BLOCKS, reuse=False)
    # Give the output a stable name so it can be found after freezing.
    tf.identity(logits, name='final_preds')
    output_node_names = 'final_preds'

    output_txt_name = output_node_names + '.pbtxt'
    output_pb_name = output_node_names + '.pb'

    rest_var = slim.get_variables_to_restore()
    with tf.Session() as sess:
        saver = tf.train.Saver(rest_var)
        saver.restore(sess, meta_file)

        tf.train.write_graph(sess.graph_def,
                             output_dir,
                             output_txt_name,
                             as_text=True)
        # Replace variables with their restored values so the .pb is self-contained.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_node_names.split(","))
        tf.train.write_graph(output_graph_def,
                             output_dir,
                             output_pb_name,
                             as_text=False)
Ejemplo n.º 3
0
def main(_):
    """Train the CIFAR-10 ResNet with a staged, placeholder-fed learning rate.

    The learning rate lives in a scalar placeholder. A session hook feeds it
    on every step and recomputes it from the fetched global step via
    ``_lr_in_stage``. Training stops at ``FLAGS.max_steps`` or on a NaN loss.
    """
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        train_files = cifar10_input.get_filenames(data_dir=FLAGS.dataset_dir,
                                                  isTrain=True)
        images, labels = cifar10_input.load_batch(filenames=train_files,
                                                  batch_size=FLAGS.batch_size,
                                                  isTrain=True,
                                                  isShuffle=True)

        total_loss = resnet.loss(resnet.inference(images, isTrain=True), labels)

        lr = tf.placeholder(shape=[], dtype=tf.float32)
        train_op = train(total_loss, global_step, lr)

        class _Hook(tf.train.SessionRunHook):
            """Feeds the current learning rate and periodically logs throughput."""

            def begin(self):
                self._lr = FLAGS.learning_rate
                self._start_time = time.time()

            def before_run(self, run_context):
                # Fetch loss and global step together with train_op; feed lr.
                return tf.train.SessionRunArgs([total_loss, global_step],
                                               feed_dict={lr: self._lr})

            def after_run(self, run_context, run_values):
                loss_value, fetched_step = run_values.results
                step = fetched_step - 1
                self._lr = _lr_in_stage(FLAGS.learning_rate, step, STAGE)
                if step % FLAGS.log_frequency != 0:
                    return
                now = time.time()
                elapsed = now - self._start_time
                self._start_time = now
                examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / elapsed
                sec_per_batch = float(elapsed / FLAGS.log_frequency)
                print('lr: %.4f, step: %d, loss: %.2f (%.1f examples/sec; %.3f sec/batch)'
                      % (self._lr, step, loss_value, examples_per_sec, sec_per_batch))

        hooks = [
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            tf.train.NanTensorHook(total_loss),
            _Hook(),
        ]
        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)
        with tf.train.MonitoredTrainingSession(checkpoint_dir=FLAGS.train_dir,
                                               hooks=hooks,
                                               config=session_config) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
Ejemplo n.º 4
0
def main(_):
    val_positive_path = '/home/give/Documents/dataset/BOT_Game/0-testdataset-jpg'
    val_dataset = DataSet(positive_path=val_positive_path, )
    images, labels = distorted_inputs(val_dataset)
    print images

    is_training = tf.placeholder('bool', [], name='is_training')
    logits = inference(images,
                       num_classes=2,
                       is_training=False,
                       bottleneck=False,
                       num_blocks=[2, 2, 2, 2])
    predicted_labels = val(is_training,
                           logits,
                           images,
                           labels,
                           is_testing=True)
    predicted_paths = val_dataset.images_names
    print np.sum(predicted_labels)
    print predicted_labels
    predicted_names = [
        os.path.basename(path).replace('.jpg', '.tiff')
        for path in predicted_paths
    ]
    write_result(predicted_labels, predicted_names, u'./毕业的西瓜籽.txt')
Ejemplo n.º 5
0
def prune():
    """Do pruning, and save pruned model for retrain.

    The part of this function visible here restores the newest checkpoint
    from ``FLAGS.checkpoint_dir`` into a fresh evaluation graph and measures
    it with ``eval_once``. The pruning and saving steps named above are not
    in this view (TODO confirm where they happen).

    Returns:
        None early, after printing a message, when no checkpoint is found.
    """
    with tf.Graph().as_default() as g:
        # Input evaluation data (eval_data=True presumably selects the
        # evaluation split in rn.inputs; confirm against rn).
        images, labels = rn.inputs(eval_data=True)

        # Build the inference model. NOTE(review): the meaning of the second
        # argument (15) is defined by rn.inference and is not visible here.
        logits = rn.inference(images, 15)

        # Per-example boolean: is the true label the top-1 prediction?
        top_k_op = tf.nn.in_top_k(logits, labels, 1)

        # Saver over all variables in this graph, used to restore the checkpoint.
        saver = tf.train.Saver()

        # InteractiveSession installs itself as the default session, which is
        # what start_queue_runners() below runs against.
        # NOTE(review): this session is never closed in the visible code.
        sess = tf.InteractiveSession()

        # Start the input queue runners. NOTE(review): no Coordinator is
        # passed, so these threads are never stopped/joined in visible code.
        tf.train.start_queue_runners()

        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Global step = suffix after the last '-' of the checkpoint file
            # name (e.g. 'model.ckpt-1000' -> '1000'); kept as a string.
            global_step_num = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]
        else:
            print('No checkpoint file found')
            return

        # Precision of the restored model, as computed by eval_once.
        precision = eval_once(sess, top_k_op)
    """
Ejemplo n.º 6
0
def main(argv=None):
    """Classify every image in ``FLAGS.data_dir`` with a restored ResNet-50.

    The checkpoint ``FLAGS.model_dir/FLAGS.ckpt_file`` is restored into a
    two-class bottleneck ResNet with block layout [3, 4, 6, 3]. Images are
    processed in batches of at most ``MAX_BATCH_SZ``. When ``FLAGS.target``
    is set, a CSV file ``FLAGS.data_dir/FLAGS.target`` is written. It has a
    ``file,label,score`` header and one row per image with the arg-max
    label from ``label_list`` and its softmax score.
    """
    with tf.Session() as sess:
        data_dir = FLAGS.data_dir
        files = [os.path.join(data_dir, item) for item in os.listdir(data_dir)]
        images = tf.placeholder(tf.float32,
                                [None, RESIZE_FINAL, RESIZE_FINAL, 3])
        logits = inference(
            images,
            False,
            num_classes=2,
            num_blocks=[3, 4, 6, 3],  # 50-layer network
            use_bias=False,  # use batch norm instead of biases
            bottleneck=True)
        saver = tf.train.Saver(tf.global_variables())
        saver.restore(sess, os.path.join(FLAGS.model_dir, FLAGS.ckpt_file))

        softmax_output = tf.nn.softmax(logits)

        # Build the read/preprocess ops once and feed file names into them.
        # Creating new ops for every file inside the loop grows the graph
        # without bound.
        file_name = tf.placeholder(tf.string, [])
        preprocessed_image = image_preprocessing(tf.read_file(file_name), [], False)

        # Stay None when no target file is requested, so the loop below can
        # test for it instead of hitting an unbound name.
        writer = None
        output = None
        if FLAGS.target:
            print('Creating output file %s' % FLAGS.target)
            output = open(os.path.join(FLAGS.data_dir, FLAGS.target), 'w')
            writer = csv.writer(output)
            writer.writerow(('file', 'label', 'score'))

        # Round up so the final, partial batch is not silently dropped.
        num_batches = int(math.ceil(len(files) / float(MAX_BATCH_SZ)))
        pg = ProgressBar(num_batches)
        try:
            for j in range(num_batches):
                start_offset = j * MAX_BATCH_SZ
                end_offset = min((j + 1) * MAX_BATCH_SZ, len(files))
                batch_image_files = files[start_offset:end_offset]

                batch_images = []
                for image_file in batch_image_files:
                    print(image_file)
                    batch_images.append(
                        sess.run(preprocessed_image,
                                 feed_dict={file_name: image_file}))
                batch_results = sess.run(
                    softmax_output, feed_dict={images: np.stack(batch_images)})

                for image_file, output_i in zip(batch_image_files, batch_results):
                    best_i = np.argmax(output_i)
                    best_choice = (label_list[best_i], output_i[best_i])
                    if writer is not None:
                        writer.writerow(
                            (image_file, best_choice[0], '%.2f' % best_choice[1]))
                pg.update()
            pg.done()
        finally:
            # Make sure buffered CSV rows reach disk even on failure.
            if output is not None:
                output.close()
Ejemplo n.º 7
0
def main(_):
    """Train a 1000-class, non-bottleneck ResNet (blocks [2, 2, 2, 2]).

    The network is built in training mode on the distorted input pipeline,
    and the resulting logits, images and labels are handed to ``train``.
    """
    batch_images, batch_labels = distorted_inputs()
    network_logits = inference(
        batch_images,
        num_classes=1000,
        is_training=True,
        bottleneck=False,
        num_blocks=[2, 2, 2, 2],
    )
    train(network_logits, batch_images, batch_labels)
def main(_):
    """Train a two-class bottleneck ResNet with block layout [3, 4, 6, 3].

    The training mode is a boolean placeholder. It is passed both to the
    network and to ``train``, presumably so it can be fed per step.
    """
    batch_images, batch_labels = distorted_inputs()
    training_flag = tf.placeholder('bool', [], name='is_training')
    network_logits = inference(
        batch_images,
        num_classes=2,
        is_training=training_flag,
        bottleneck=True,
        num_blocks=[3, 4, 6, 3],
    )
    train(training_flag, network_logits, batch_images, batch_labels)
Ejemplo n.º 9
0
def main(_):
    images, labels = distorted_inputs()
    print images

    is_training = tf.placeholder('bool', [], name='is_training')
    logits = inference(images,
                       num_classes=2,
                       is_training=False,
                       bottleneck=False,
                       num_blocks=[2, 2, 2, 2])
    val(is_training, logits, images, labels)
Ejemplo n.º 10
0
def tower_loss(scope):
    """Build one tower's CIFAR-10 ResNet and return the tower's summed loss.

    Args:
        scope: name scope of this tower. Only entries of the ``'losses'``
            collection created under it are summed.

    Returns:
        Scalar tensor named ``total_loss``. It is also logged as a
        ``total_loss`` scalar summary at the root name scope.
    """
    batch_images, batch_labels = input_data.read_cifar10(
        FLAGS.data_dir, True, FLAGS.batch_size, True)
    tower_logits = resnet.inference(batch_images, FLAGS.num_units_per_block,
                                    FLAGS.is_training)
    # Called for its side effect: the loss terms are presumably added to the
    # 'losses' collection, which is what gets summed below.
    resnet.loss(tower_logits, batch_labels)
    tower_losses = tf.get_collection('losses', scope)
    total_loss = tf.add_n(tower_losses, name='total_loss')
    # name_scope(None) resets to the root scope, so the summary tag is not
    # prefixed with the tower name.
    with tf.name_scope(None):
        tf.summary.scalar("total_loss", total_loss)
    return total_loss
Ejemplo n.º 11
0
    def __init__(self, model_dir=None, img_size=None):
        """Create a CPU Paddle ``fluid.Inferencer`` for the ResNet model.

        Args:
            model_dir: parameter directory. When falsy, the existing
                ``self._model_dir`` (presumably a class-level default) is kept.
            img_size: ``(width, height)`` of the input image. When falsy, the
                existing ``self._img_size`` is kept.
        """
        if model_dir:
            self._model_dir = model_dir
        if img_size:
            self._img_size = img_size

        width = self._img_size[0]
        height = self._img_size[1]
        # The network input is channel-first: (C, H, W).
        data_shape = (3, height, width)
        # NOTE(review): infer_func receives the *result* of resnet.inference,
        # presumably a callable building the net; confirm against resnet.
        network = resnet.inference(data_shape, self._label_cnt)
        self.inferencer = fluid.Inferencer(infer_func=network,
                                           param_path=self._model_dir,
                                           place=fluid.CPUPlace())
Ejemplo n.º 12
0
def main(_):
    [train_images, train_labels], [val_images, val_labels] = distorted_inputs()
    print train_images
    is_training = tf.placeholder('bool', [], name='is_training')
    images, labels = tf.cond(is_training, lambda: (train_images, train_labels),
                             lambda: (val_images, val_labels))
    logits = inference(
        images,
        num_classes=2,
        is_training=True,
        bottleneck=False,
    )
    save_model_path = '/home/give/PycharmProjects/StomachCanner/classification/Net/ResNet/models/method4'
    train(is_training, logits, images, labels, save_model_path=save_model_path)
Ejemplo n.º 13
0
def main(_):
    """Train one ResNet head per character position of a verification code.

    ``FLAGS.letter_num_per_vc`` networks named ``task_0``, ``task_1``, ...
    are built on the same input tensors. ``tf.cond`` selects training or
    validation inputs according to the ``is_training`` placeholder.
    """
    (train_images, train_labels), (val_images, val_labels) = distorted_inputs()
    is_training = tf.placeholder('bool', [], name='is_training')
    images, labels = tf.cond(is_training,
                             lambda: (train_images, train_labels),
                             lambda: (val_images, val_labels))
    logits_multi_task = [
        inference(images,
                  task_name='task_' + str(task_index),
                  num_classes=FLAGS.max_single_vc_length,
                  is_training=True,
                  bottleneck=False)
        for task_index in range(FLAGS.letter_num_per_vc)
    ]
    train(is_training, logits_multi_task, images, labels,
          save_model_path='/home/give/PycharmProjects/AIChallenger/ResNet/models')
Ejemplo n.º 14
0
def evaluate(validation_set, validation_labels):
    """Evaluate model on Dataset for a number of steps.

    Builds a CPU evaluation graph sized to the whole validation set: one
    batch of ``validation_set.shape[0]`` examples. It then calls
    ``do_eval``, repeating every ``FLAGS.eval_interval_secs`` seconds until
    ``FLAGS.run_once`` is set. Each call gets the step returned by the
    previous one.

    Args:
        validation_set: validation images; the first dimension is the number
            of examples.
        validation_labels: labels matching ``validation_set``.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Graph creation
        batch_size = validation_set.shape[0]
        images_placeholder, labels_placeholder = cifar10.placeholder_inputs(
            batch_size)
        logits = resnet.inference(images_placeholder,
                                  FLAGS.num_residual_blocks,
                                  reuse=False)
        predictions = tf.nn.softmax(logits)
        in_top1 = tf.to_float(
            tf.nn.in_top_k(predictions, labels_placeholder, k=1))
        num_correct = tf.reduce_sum(in_top1)
        # NOTE(review): despite its name this is the top-1 *error* rate
        # (misclassified / total), not accuracy. It is left unchanged because
        # do_eval's reporting may rely on it; confirm.
        validation_accuracy = (batch_size - num_correct) / float(batch_size)
        validation_loss = resnet.loss(logits, labels_placeholder)

        # do_eval is given only the saver, not a session, so no session is
        # opened here (one would sit unused and never be closed).
        saver = tf.train.Saver()

        # Create summary writer
        graph_def = tf.get_default_graph().as_graph_def()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                               graph_def=graph_def)
        try:
            step = -1
            while True:
                step = do_eval(saver,
                               summary_writer,
                               validation_accuracy,
                               validation_loss,
                               images_placeholder,
                               labels_placeholder,
                               validation_set,
                               validation_labels,
                               prev_global_step=step)
                if FLAGS.run_once:
                    break
                time.sleep(FLAGS.eval_interval_secs)
        finally:
            # Flush buffered events to disk even if evaluation is interrupted.
            summary_writer.close()
Ejemplo n.º 15
0
def evaluate():
    """Evaluate the newest CIFAR-10 checkpoint on the test set.

    Restores the latest checkpoint in ``FLAGS.train_dir`` and runs
    ``FLAGS.num_test // FLAGS.BATCH_SIZE`` full test batches. It then prints
    the number of correct predictions and the average accuracy. It returns
    without evaluating when no checkpoint exists or when ``FLAGS.num_test``
    is smaller than one batch. An error raised during evaluation is reported
    through the coordinator and re-raised by ``coord.join`` once the
    queue-runner threads have stopped.
    """
    with tf.Graph().as_default():
        test_images, test_labels = input_data.read_cifar10(
            FLAGS.DATA_DIR, False, FLAGS.BATCH_SIZE, False)
        logits = resnet.inference(test_images, 7, is_training=False)
        init = tf.global_variables_initializer()
        acc_num = num_correct_predition(logits, test_labels)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                print('reading checkpoints...')
                checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
                if not checkpoint:
                    # Scoring freshly initialised weights would print a
                    # meaningless accuracy, so stop instead.
                    print('No checkpoint file found')
                    return
                saver.restore(sess, checkpoint)
                print("restore from the checkpoint {0}".format(checkpoint))

                print('\nEvaluating...')
                num_step = int(FLAGS.num_test // FLAGS.BATCH_SIZE)
                num_sample = num_step * FLAGS.BATCH_SIZE
                print(num_step, num_sample)
                if num_step == 0:
                    print('num_test is smaller than BATCH_SIZE; nothing to evaluate')
                    return
                step = 0
                total_correct = 0
                while step < num_step:
                    batch_correct = sess.run(acc_num)
                    total_correct += np.sum(batch_correct)
                    step += 1
                    print(step)
                print('Total testing samples: %d' % num_sample)
                print('Total correct predictions: %d' % total_correct)
                # 100.0 keeps the division floating-point under Python 2 too.
                print('Average accuracy: %.2f%%' %
                      (100.0 * total_correct / num_sample))
            except Exception as e:
                # Do not swallow failures silently: record the error so that
                # coord.join() re-raises it after the threads have stopped.
                coord.request_stop(e)
            finally:
                coord.request_stop()
                coord.join(threads)
Ejemplo n.º 16
0
def main(_):
    """Train the two-class heat-map ResNet on batches fed through placeholders.

    ``generate_next_batch`` produces training and validation batches of
    ``net_config.BATCH_SIZE``. ``train`` feeds them into the image and label
    placeholders; both placeholders have a free batch dimension.
    """
    train_dataset, val_dataset = distorted_inputs()
    batch_size = net_config.BATCH_SIZE
    train_batch = generate_next_batch(train_dataset, batch_size, None)
    val_batch = generate_next_batch(val_dataset, batch_size, None)

    is_training = tf.placeholder('bool', [], name='is_training')
    image_shape = [None, net_config.IMAGE_W, net_config.IMAGE_H,
                   net_config.IMAGE_CHANNEL]
    image_tensor = tf.placeholder(tf.float32, image_shape)
    label_tensor = tf.placeholder(tf.int32, [None])
    logits = inference(image_tensor, num_classes=2, is_training=True,
                       bottleneck=False)
    train(is_training, logits, image_tensor, label_tensor, train_batch, val_batch,
          save_model_path='/home/give/PycharmProjects/StomachCanner/classification/Net/ResNetHeatMap/models/method5-512/1740.0')
def train():
    """Train CIFAR-10 for a number of steps.

    Builds the input pipeline (pinned to the CPU) and the ResNet with its
    loss. It then runs ``train_op`` in a ``MonitoredTrainingSession`` until
    ``StopAtStepHook`` reaches ``FLAGS.max_steps``; ``NanTensorHook`` aborts
    on a NaN loss. Two custom hooks run on every step:

    * ``_LoggerHook`` prints loss and throughput every
      ``FLAGS.log_frequency`` steps.
    * ``_SparsityHook`` drives a ``sparsity_monitor.SparsityMonitor`` over
      ``retrieve_list``.

    Finally it prints the total wall-clock training time in seconds.
    """
    with tf.Graph().as_default() as g:
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = resnet.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model. tensor_list is forwarded to resnet.train below.
        logits, tensor_list = resnet.inference(images)

        # Calculate loss.
        loss = resnet.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters. retrieve_list is handed to the
        # SparsityMonitor (presumably the tensors whose sparsity is tracked).
        train_op, retrieve_list = resnet.train(loss, tensor_list, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            # self._step counts this hook's own runs from 0; it is not read
            # from the graph's global_step.
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        class _SparsityHook(tf.train.SessionRunHook):
            """Lets a SparsityMonitor choose tensors to fetch on each step.

            The monitor (in ``Mode.monitor``, NHWC data format) picks the
            tensors to fetch in ``scheduler_before``. It receives their
            values, the step, the working directory and ``FLAGS.file_io`` in
            ``scheduler_after``.
            """
            def begin(self):
                self._step = -1
                mode = sparsity_monitor.Mode.monitor
                data_format = "NHWC"
                self.monitor = sparsity_monitor.SparsityMonitor(mode, data_format, FLAGS.monitor_interval,\
                                                                FLAGS.monitor_period, retrieve_list)

            def before_run(self, run_context):
                self._step += 1
                selected_list = self.monitor.scheduler_before(self._step)
                return tf.train.SessionRunArgs(
                    selected_list)  # Fetch the tensors the monitor selected.

            def after_run(self, run_context, run_values):
                self.monitor.scheduler_after(run_values.results, self._step,
                                             os.getcwd(), FLAGS.file_io)

        # NOTE(review): the SummarySaverHook that would consume these two is
        # commented out below. The merged op is therefore unused, and the
        # writer (which records graph g under FLAGS.sparsity_dir) is never
        # flushed or closed explicitly.
        sparsity_summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.sparsity_dir, g)

        start = time.time()
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=
            [
                tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                tf.train.NanTensorHook(loss),
                #tf.train.SummarySaverHook(save_steps=FLAGS.log_frequency, summary_writer=summary_writer, summary_op=sparsity_summary_op),
                _LoggerHook(),
                _SparsityHook()
            ],
                config=tf.ConfigProto(log_device_placement=FLAGS.
                                      log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
        end = time.time()
        # Total training wall-clock time in seconds.
        print(end - start)
Ejemplo n.º 18
0
# NOTE: ``tf.Graph().as_default()`` only returns a context manager. It is kept
# in ``g`` for compatibility but never entered, so every op below is built in
# the process-wide default graph.
g = tf.Graph().as_default()
# Decode and preprocess the JPEG on the CPU. ``tf.device`` only takes effect
# as a context manager; a bare ``tf.device(...)`` call is a no-op.
with tf.device('/cpu:0'):
    # Get images and labels.
    image = tf.placeholder(tf.string, name='input')
    reshaped_image = tf.to_float(tf.image.decode_jpeg(image,
                                                      channels=num_channels))
    reshaped_image = tf.image.resize_images(reshaped_image,
                                            (load_size[0], load_size[1]))
    reshaped_image = _test_preprocess(reshaped_image, crop_size, num_channels)
    # Add a leading batch dimension of size 1.
    imgs = reshaped_image[None, ...]
# Inference is left to TensorFlow's default placement. It prefers the GPU for
# ops that have a GPU kernel when a GPU is visible and falls back to the CPU
# otherwise; forcing '/gpu:0' would fail on CPU-only hosts.
# Build a Graph that computes the logits predictions from the
# inference model; ``logits`` is indexed below as a sequence of heads.
logits = resnet.inference(imgs, depth, num_classes, 0.0, False)
# Head 0 (``_id``): top-5 classes and their probabilities.
top5_id = tf.nn.top_k(tf.nn.softmax(logits[0]), 5)
top5ind_id = top5_id.indices
top5val_id = top5_id.values
# Count
top3_cn = tf.nn.top_k(tf.nn.softmax(logits[1]), 3)
top3ind_cn = top3_cn.indices
top3val_cn = top3_cn.values
# Additional Attributes (e.g. description): heads 2..7, top-1 each.
top1_bh = [None] * 6
top1ind_bh = [None] * 6
top1val_bh = [None] * 6

for i in range(0, 6):
    top1_bh[i] = tf.nn.top_k(tf.nn.softmax(logits[i + 2]), 1)
    top1ind_bh[i] = top1_bh[i].indices
    # Fill the probabilities as well; this list was allocated but never set.
    top1val_bh[i] = top1_bh[i].values
Ejemplo n.º 19
0
def convert(graph, img, img_p, layers):
    caffe_model = load_caffe(img_p, layers)

    #for i, n in enumerate(caffe_model.params):
    #    print n

    param_provider = CaffeParamProvider(caffe_model)

    if layers == 50:
        num_blocks = [3, 4, 6, 3]
    elif layers == 101:
        num_blocks = [3, 4, 23, 3]
    elif layers == 152:
        num_blocks = [3, 8, 36, 3]

    with tf.device('/cpu:0'):
        images = tf.placeholder("float32", [None, 224, 224, 3], name="images")
        logits = resnet.inference(images,
                                  is_training=False,
                                  num_blocks=num_blocks,
                                  preprocess=True,
                                  bottleneck=True)
        prob = tf.nn.softmax(logits, name='prob')

    # We write the metagraph first to avoid adding a bunch of
    # assign ops that are used to set variables from caffe.
    # The checkpoint is written to at the end.
    tf.train.export_meta_graph(filename=meta_fn(layers))

    vars_to_restore = tf.all_variables()
    saver = tf.train.Saver(vars_to_restore)

    sess = tf.Session()
    sess.run(tf.initialize_all_variables())

    assigns = []
    for var in vars_to_restore:
        #print var.op.name
        data = parse_tf_varnames(param_provider, var.op.name, layers)
        #print "caffe data shape", data.shape
        #print "tf shape", var.get_shape()
        assigns.append(var.assign(data))
    sess.run(assigns)

    #for op in tf.get_default_graph().get_operations():
    #    print op.name

    i = [
        graph.get_tensor_by_name("scale1/Relu:0"),
        graph.get_tensor_by_name("scale2/MaxPool:0"),
        graph.get_tensor_by_name("scale2/block1/Relu:0"),
        graph.get_tensor_by_name("scale2/block2/Relu:0"),
        graph.get_tensor_by_name("scale2/block3/Relu:0"),
        graph.get_tensor_by_name("scale3/block1/Relu:0"),
        graph.get_tensor_by_name("scale5/block3/Relu:0"),
        graph.get_tensor_by_name("avg_pool:0"),
        graph.get_tensor_by_name("prob:0"),
    ]

    o = sess.run(i, {images: img[np.newaxis, :]})

    assert_almost_equal(caffe_model.blobs['conv1'].data, o[0])
    assert_almost_equal(caffe_model.blobs['pool1'].data, o[1])
    assert_almost_equal(caffe_model.blobs['res2a'].data, o[2])
    assert_almost_equal(caffe_model.blobs['res2b'].data, o[3])
    assert_almost_equal(caffe_model.blobs['res2c'].data, o[4])
    assert_almost_equal(caffe_model.blobs['res3a'].data, o[5])
    assert_almost_equal(caffe_model.blobs['res5c'].data, o[6])
    #assert_almost_equal(np.squeeze(caffe_model.blobs['pool5'].data), o[7])

    print_prob(o[8][0])

    prob_dist = np.linalg.norm(caffe_model.blobs['prob'].data - o[8])
    print 'prob_dist ', prob_dist
    assert prob_dist < 0.2  # XXX can this be tightened?

    # We've already written the metagraph to avoid a bunch of assign ops.
    saver.save(sess, checkpoint_fn(layers), write_meta_graph=False)
Ejemplo n.º 20
0
    return var


# NOTE(review): tf.device() pushes onto the device stack of the graph that is
# default when it is entered. The tf.Graph().as_default() below then installs
# a brand-new graph with an empty device stack, so the ops built inside are
# most likely NOT pinned to '/gpu:1'. Swapping the two `with` statements
# would apply the placement; confirm before relying on it.
with tf.device('/gpu:1'):
    with tf.Graph().as_default():

        # Input clips: (batch_size, num_frames, height, width, channels).
        inputs_placeholder = tf.placeholder(tf.float32,
                                            shape=(batch_size, num_frames,
                                                   height, width, channels))

        # Weights [hidden_size, n_classes] and bias [n_classes] passed to
        # brnn_model.BiRNN. The third argument is presumably a weight-decay
        # factor (0 for the bias); confirm in _variable_with_weight_decay.
        with tf.variable_scope('b_rnn'):
            rw = _variable_with_weight_decay('rw', [hidden_size, n_classes],
                                             0.0005)
            rb = _variable_with_weight_decay('rb', [n_classes], 0.000)

        # ResNet features -> bidirectional RNN -> softmax probabilities.
        feature = resnet.inference(inputs_placeholder)
        outputs = brnn_model.BiRNN(feature, rw, rb)
        outputs = tf.nn.softmax(outputs)
        predict = []

        with tf.Session() as sess:
            # Restore every variable in this graph from model_filename.
            saver = tf.train.Saver()
            saver.restore(sess, model_filename)

            # test
            videos = read.readFile()
            for l in videos:

                # Number of full batches in l[1]; a trailing partial batch is
                # skipped because int(...) truncates the division.
                for batch in range(int(len(l[1]) / batch_size)):

                    nextX, segments = read.readTrainData(batch, l, batch_size)
    processimg = np.transpose(processimg, (2, 0, 1))
    processimg = np.expand_dims(processimg, 0)
    processimg = processimg / 255.0
    processimg = (processimg - 0.5) / 0.5
    return processimg


if __name__ == '__main__':
    # Live webcam loop: detect faces in every frame with the SSD face
    # detector, using a ResNet restored from args.snapshot on the GPU.
    args = parse_args()
    fd_detector = facedetect.facedetect("tf-ssh")
    cudnn.enabled = True

    gpu = "cuda:0"  # device string; not referenced in the visible code
    snapshot_path = args.snapshot

    model = resnet.inference(10).cuda()

    saved_state_dict = torch.load(snapshot_path)
    model.load_state_dict(saved_state_dict)

    # Inference mode (e.g. batch norm uses running statistics).
    model.eval()

    camera = cv2.VideoCapture(0)
    try:
        while True:
            ret, frame = camera.read()
            # read() reports failure (e.g. camera unplugged) through ret.
            if not ret:
                break

            cv2_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            rects, imgs = fd_detector.findfaces(cv2_frame)
    finally:
        # Release the capture device even if reading or detection raises.
        camera.release()
def main(_):
    """Evaluate the latest CIFAR-10 checkpoint and log its top-1 precision.

    Restores the exponential-moving-average values of the model variables
    from the newest checkpoint in ``FLAGS.train_dir``. It then runs
    ``ceil(NUM_SAMPLES / FLAGS.batch_size)`` test batches, prints the
    precision over the batches actually evaluated, and writes it (with any
    merged summaries) to ``FLAGS.test_dir``. It returns early when no
    checkpoint exists. Errors raised during evaluation are passed to the
    coordinator, whose ``join`` re-raises them.
    """
    with tf.Graph().as_default() as g:
        filenames = cifar10_input.get_filenames(data_dir=FLAGS.dataset_dir,
                                                isTrain=False)

        images, labels = cifar10_input.load_batch(filenames=filenames,
                                                  batch_size=FLAGS.batch_size,
                                                  isTrain=False,
                                                  isShuffle=False)

        logits = resnet.inference(images)

        # Per-example boolean: is the true label the top-1 prediction?
        correct_prediction = tf.nn.in_top_k(predictions=logits,
                                            targets=labels,
                                            k=1)

        # Restore the moving-average shadow values into the model variables.
        variable_averages = tf.train.ExponentialMovingAverage(
            decay=FLAGS.variable_averages_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # None when the graph defines no summaries.
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.test_dir, g)
        try:
            with tf.Session() as sess:
                checkpoint = tf.train.get_checkpoint_state(FLAGS.train_dir)
                if not (checkpoint and checkpoint.model_checkpoint_path):
                    print('No checkpoint file found')
                    return
                saver.restore(sess, checkpoint.model_checkpoint_path)
                # Global step = suffix after the last '-' of the file name.
                global_step = checkpoint.model_checkpoint_path.split(
                    '/')[-1].split('-')[-1]

                coord = tf.train.Coordinator()
                threads = []
                try:
                    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                        threads.extend(
                            qr.create_threads(sess=sess,
                                              coord=coord,
                                              daemon=True,
                                              start=True))

                    # float() keeps the ceil meaningful under Python 2 too.
                    num_iter = int(np.ceil(NUM_SAMPLES / float(FLAGS.batch_size)))
                    true_counter = 0
                    step = 0
                    while step < num_iter and not coord.should_stop():
                        num_correct = sess.run(correct_prediction)
                        true_counter += np.sum(num_correct)
                        step += 1

                    # Divide by the samples actually evaluated, which is
                    # fewer than planned if the coordinator stopped early.
                    total_samples = step * FLAGS.batch_size
                    precision = (float(true_counter) / total_samples
                                 if total_samples else 0.0)
                    print('%s: precision :%.3f' % (datetime.now(), precision))

                    summary = tf.Summary()
                    if summary_op is not None:
                        summary.ParseFromString(sess.run(summary_op))
                    summary.value.add(tag='Precision', simple_value=precision)
                    summary_writer.add_summary(summary, global_step)

                except Exception as e:
                    coord.request_stop(e)

                coord.request_stop()
                coord.join(threads, stop_grace_period_secs=10)
        finally:
            # FileWriter buffers events; close() flushes them to disk so the
            # precision summary is not lost when the process exits.
            summary_writer.close()
def train():
  """Train CIFAR-10 for FLAGS.max_steps steps while monitoring tensor sparsity.

  Two hooks run alongside training:
    * _LoggerHook prints loss and throughput every FLAGS.log_frequency steps.
    * _SparsityHook fetches the (tensor, sparsity) pairs returned by
      `resnet.train`. Once a tensor's sparsity exceeds
      FLAGS.sparsity_threshold, it prints the tensor's zero-block ratio for
      FLAGS.monitor_interval steps. It also appends extracted feature maps to
      the module-level `data_dict`. When FLAGS.log_animation is set, it then
      renders those maps as a GIF via `animate`.
  """
  with tf.Graph().as_default() as g:
    global_step = tf.train.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
    # GPU and resulting in a slow down.
    with tf.device('/cpu:0'):
      images, labels = resnet.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model; `tensor_list` holds the tensors to be monitored.
    logits, tensor_list = resnet.inference(images)

    # Calculate loss.
    loss = resnet.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters. Each entry of `retrieve_list` is indexed
    # as (tensor, sparsity) by _SparsityHook below.
    train_op, retrieve_list = resnet.train(loss, tensor_list, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

    class _SparsityHook(tf.train.SessionRunHook):
      """Monitors sparsity / zero-block ratio of the tensors in retrieve_list."""

      def begin(self):
        self._bs_util = block_sparsity_util.BlockSparsityUtil(FLAGS.block_size)
        self._internal_index_keeper = collections.OrderedDict()
        # Per-tensor count of steps since it entered the monitoring window.
        self._local_step = collections.OrderedDict()
        # Flattened fetch list: tensor, sparsity, tensor, sparsity, ...
        # retrieve_list never changes, so build this once.
        self.selected_list = []
        for tensor_tuple in retrieve_list:
          self.selected_list.append(tensor_tuple[0])
          self.selected_list.append(tensor_tuple[1])

      def before_run(self, run_context):
        # Ask for every monitored tensor together with its sparsity value.
        return tf.train.SessionRunArgs(self.selected_list)

      def after_run(self, run_context, run_values):
        results = run_values.results
        # Even positions hold tensor values, odd positions their sparsity.
        self._data_list = results[0::2]
        self._sparsity_list = results[1::2]
        assert len(self._sparsity_list) == len(self.selected_list) // 2
        assert len(self._data_list) == len(self.selected_list) // 2
        format_str = ('local_step: %d %s: sparsity = %.2f difference percentage = %.2f')
        zero_block_format_str = ('local_step: %d %s: zero block ratio = %.2f')
        # Feature maps are extracted at batch index 0, channel index 0.
        batch_idx = 0
        channel_idx = 0
        for i, data in enumerate(self._data_list):
          sparsity = self._sparsity_list[i]
          monitored = self.selected_list[2 * i]
          shape = monitored.get_shape()
          tensor_name = monitored.name
          if tensor_name in self._local_step:
            if (self._local_step[tensor_name] == FLAGS.monitor_interval and
                FLAGS.log_animation):
              # Monitoring window just finished: render the collected
              # feature maps as a GIF, then mark the tensor as done.
              fig, ax = plt.subplots()
              ani = animation.FuncAnimation(fig, animate,
                                            frames=FLAGS.monitor_interval,
                                            fargs=(ax, tensor_name,),
                                            interval=500, repeat=False,
                                            blit=True)
              figure_name = tensor_name.replace('/', '_').replace(':', '_')
              ani.save(figure_name + '.gif', dpi=80, writer='imagemagick')
              # Release the figure; otherwise every monitored tensor leaves
              # an open matplotlib figure alive for the rest of training.
              plt.close(fig)
              self._local_step[tensor_name] += 1
              continue
            if self._local_step[tensor_name] >= FLAGS.monitor_interval:
              # Already monitored for the full window.
              continue
          if (tensor_name not in self._local_step and
              sparsity > FLAGS.sparsity_threshold):
            # First time this tensor crosses the threshold: open its window.
            self._local_step[tensor_name] = 0
            zero_block_ratio = self._bs_util.zero_block_ratio_matrix(data, shape)
            print(zero_block_format_str % (self._local_step[tensor_name], tensor_name, zero_block_ratio))
            print(format_str % (self._local_step[tensor_name], tensor_name,
                                sparsity, 0.0))
            self._internal_index_keeper[tensor_name] = get_non_zero_index(data, shape)
            if tensor_name not in data_dict:
              data_dict[tensor_name] = []
            data_dict[tensor_name].append(feature_map_extraction(data, batch_idx, channel_idx))
            self._local_step[tensor_name] += 1
          elif tensor_name in self._local_step and self._local_step[tensor_name] > 0:
            # Inside the monitoring interval.
            zero_block_ratio = self._bs_util.zero_block_ratio_matrix(data, shape)
            print(zero_block_format_str % (self._local_step[tensor_name], tensor_name, zero_block_ratio))
            # Index-difference tracking is disabled; 0.0 is a placeholder.
            print(format_str % (self._local_step[tensor_name], tensor_name,
                                sparsity, 0.0))
            data_dict[tensor_name].append(feature_map_extraction(data, batch_idx, channel_idx))
            self._local_step[tensor_name] += 1

    sparsity_summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(FLAGS.sparsity_dir, g)

    start = time.time()
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               #tf.train.SummarySaverHook(save_steps=FLAGS.log_frequency, summary_writer=summary_writer, summary_op=sparsity_summary_op),
               _LoggerHook(),
               _SparsityHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
    end = time.time()
    summary_writer.close()
    print(end - start)
Ejemplo n.º 24
0
def train():
    """Train ResNet on CIFAR-10, feeding queue-produced batches via feed_dict.

    Logging and checkpoint cadence:
      * train loss/accuracy is printed every 50 steps;
      * test loss/accuracy is printed every 200 steps;
      * train and test summaries are written every 300 steps;
      * a checkpoint is saved every 2000 steps.
    Each of these also happens on the final step. If FLAGS.train_dir already
    has a checkpoint, training resumes from the step after it.
    """
    with tf.name_scope('inputs'):
        train_images, train_labels = input_data.read_cifar10(
            FLAGS.data_dir, True, FLAGS.batch_size, True)
        test_images, test_labels = input_data.read_cifar10(
            FLAGS.data_dir, False, FLAGS.batch_size, False)
    xs = tf.placeholder(dtype=tf.float32, shape=(FLAGS.batch_size, 32, 32, 3))
    # Labels are fed as a (batch, 10) int tensor; presumably one-hot.
    # TODO confirm against input_data.read_cifar10.
    ys = tf.placeholder(dtype=tf.int32, shape=(FLAGS.batch_size, 10))
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.lr,
                                               global_step,
                                               32000,
                                               0.1,
                                               staircase=False)
    # The summary tag is kept as-is so existing TensorBoard runs still line up.
    tf.summary.scalar('lerning_rate', learning_rate)

    logits = resnet.inference(xs, FLAGS.num_units_per_block, FLAGS.is_training)
    loss = resnet.loss(logits, ys)
    tf.summary.scalar('loss', loss)
    opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
    # Run whatever is registered in UPDATE_OPS (e.g. batch-norm statistics,
    # if the model defines them) together with every optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = opt.minimize(loss, global_step=global_step)

    acc_op = resnet.accurracy(logits, ys)
    tf.summary.scalar('accuracy', acc_op)
    err_op = resnet.error(logits, ys)
    tf.summary.scalar('error', err_op)

    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver(tf.global_variables())
    init = tf.global_variables_initializer()
    coord = tf.train.Coordinator()

    with tf.Session() as sess:
        sess.run(init)
        train_summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                                     sess.graph)
        test_summary_writer = tf.summary.FileWriter(FLAGS.test_dir, sess.graph)
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        start_step = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("restore from the checkpoint {0}".format(checkpoint))
            # The checkpoint suffix is the last step that was completed, so
            # resume from the following step instead of repeating it.
            start_step = int(checkpoint.split('-')[-1]) + 1
        print("start training...")
        try:
            for step in range(start_step, FLAGS.max_steps):
                if coord.should_stop():
                    break
                is_last_step = (step + 1) == FLAGS.max_steps
                log_train = step % 50 == 0 or is_last_step
                log_test = step % 200 == 0 or is_last_step
                write_summaries = step % 300 == 0 or is_last_step

                tra_images_batch, tra_labels_batch = sess.run(
                    [train_images, train_labels])
                train_feed = {xs: tra_images_batch, ys: tra_labels_batch}
                sess.run(train_op, feed_dict=train_feed)

                # Only pull a test batch on steps that actually use it;
                # dequeuing one every step wastes input-pipeline work.
                if log_test or write_summaries:
                    tes_images_batch, tes_labels_batch = sess.run(
                        [test_images, test_labels])
                    test_feed = {xs: tes_images_batch, ys: tes_labels_batch}

                if log_train:
                    tra_los, tra_acc = sess.run([loss, acc_op],
                                                feed_dict=train_feed)
                    print('Step: %d, loss: %.6f, accuracy: %.4f' %
                          (step, tra_los, tra_acc))
                if log_test:
                    tes_los, tes_acc = sess.run([loss, acc_op],
                                                feed_dict=test_feed)
                    print(
                        '***test_loss***Step: %d, loss: %.6f, accuracy: %.4f' %
                        (step, tes_los, tes_acc))
                if write_summaries:
                    summary_str1 = sess.run(summary_op, feed_dict=train_feed)
                    summary_str2 = sess.run(summary_op, feed_dict=test_feed)
                    train_summary_writer.add_summary(summary_str1, step)
                    test_summary_writer.add_summary(summary_str2, step)
                if step % 2000 == 0 or is_last_step:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            # Stop and join the queue threads exactly once, then flush logs.
            coord.request_stop()
            coord.join(threads)
            train_summary_writer.close()
            test_summary_writer.close()
Ejemplo n.º 25
0
# Stack the individual test images into a single batch tensor.
testrecord_images = tf.stack(testrecord_images)

# transpose to set the channel first: perm [0, 3, 1, 2] moves axis 3 to axis 1
# (NHWC -> NCHW, assuming each image is HWC -- TODO confirm).
# NOTE(review): only testrecord_images is transposed here; distorted_images,
# which is fed to inference below, is built elsewhere -- confirm it uses the
# same layout.
testrecord_images = tf.transpose(testrecord_images, perm=[0, 3, 1, 2])

global_step = tf.Variable(0, trainable=False)
# Step-wise learning-rate schedule: len(values) == len(boundaries) + 1.
# The rate starts at 0.1 and drops at steps 10000/15000/20000/25000, ending
# at 0.001.
boundaries = [10000, 15000, 20000, 25000]
values = [0.1, 0.05, 0.01, 0.005, 0.001]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
weight_decay = 2e-4
filters = 16  # the first resnet block filter number
n = 5  # the basic resnet block number, total network layers are 6n+2
ver = 2  # the resnet block version

# Get the inference logits by the model (the second argument is passed as True;
# presumably is_training -- confirm against resnet.inference).
result = resnet.inference(distorted_images, True, filters, n, ver)

# Calculate the cross entropy loss.
# NOTE(review): tf.losses.sparse_softmax_cross_entropy already reduces to a
# scalar by default, so reduce_mean below mainly names the tensor
# 'cross_entropy'.
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=record_labels,
                                                       logits=result)
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')

# Add L2 weight decay to the loss. It is applied to every trainable variable,
# which includes biases and any normalization parameters the model creates.
l2_loss = weight_decay * tf.add_n(
    # loss is computed using fp32 for numerical stability.
    [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()])
tf.summary.scalar('l2_loss', l2_loss)
loss = cross_entropy_mean + l2_loss

# Define the optimizer
Ejemplo n.º 26
0
    xx = int((img.shape[1] - short_edge) / 2)
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]
    resized_img = skimage.transform.resize(crop_img, (size, size))
    return resized_img


if __name__ == "__main__":
    sess = tf.Session()

    batch_size = 1
    num_classes = 1000

    x = tf.placeholder(tf.float32, [batch_size, 224, 224, 3])

    logits = resnet.inference(x,
                              is_training=False,
                              num_classes=num_classes,
                              num_blocks=[3, 4, 6, 3])
    logits = tf.nn.softmax(logits)

    img = load_image("data/waterbottle.jpg")
    img = img[:, :, [2, 1, 0]].reshape((1, 224, 224, 3)) * 255
    img -= np.array(IMAGENET_MEAN_BGR)

    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, CHECKPOINT_FN)

    pred = sess.run(logits, {x: img})
    top_k = 5
    top = pred.ravel().argsort()[-top_k:][::-1]
    for t in top.ravel().tolist():
        for s in synset.synset_map.values():
Ejemplo n.º 27
0
# Build the (non-training) input batch pipeline for the test set.
test_batch, test_label_batch = data_preprocess.get_batch(test,
                                                         test_label,
                                                         IMG_W,
                                                         IMG_H,
                                                         BATCH_SIZE,
                                                         CAPACITY,
                                                         is_training=False)
# Name the input tensor ("input") so it can be referenced from the Android app;
# adding zeros is only a way to attach that name.
# NOTE(review): the zeros shape is [IMG_W, IMG_H, 3]; it must broadcast against
# test_batch's trailing dims -- confirm the W/H order if IMG_W != IMG_H.
test_batch = tf.add(test_batch, tf.zeros([IMG_W, IMG_H, 3]), name="input")
one_hot_labels = tf.one_hot(indices=tf.cast(test_label_batch, tf.int32),
                            depth=2)

# Get the outputs (train op, loss and logits) from the model.
learning_rate = tf.placeholder(tf.float32)
train_op, train_loss, train_logits = resnet.inference(test_batch, one_hot_labels, 1500, \
                                                                 deep=N_CLASSES, is_training=False, batch_size=BATCH_SIZE, lr=learning_rate, \
                                                                 net_mode=FLAGS.net_mode)
# Fraction of the batch whose argmax prediction matches the one-hot label.
correct_prediction = tf.equal(tf.argmax(train_logits, 1),
                              tf.argmax(one_hot_labels, 1))
train__acc = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# Initialization.
sess = tf.Session()
coord = tf.train.Coordinator()  # queue monitoring (coordinates queue threads)
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
total_begin_time = time.time()

# Checkpoint loading.
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
def run_training():
    """Train the `rn` ResNet on distorted CIFAR-10 batches.

    Runs FLAGS.max_steps optimizer steps. Every 10 steps the loss is printed
    and the merged summaries are written to FLAGS.train_dir. A checkpoint is
    saved there every 1000 steps and after the final step.
    """
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = rn.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = rn.inference(images, 15)

        # Calculate loss.
        loss = rn.loss(logits, labels)
        rn._activation_summary(loss)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = rn.training(loss, global_step)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver(tf.global_variables())
        checkpoint_file = os.path.join(FLAGS.train_dir, 'model.ckpt')

        # The session is closed on exit; the queue threads are stopped first.
        with tf.Session() as sess:
            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

            # Run the Op to initialize the variables.
            sess.run(init)

            # Start the input-queue threads under a coordinator so they can
            # be stopped cleanly.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for step in range(FLAGS.max_steps):
                    if coord.should_stop():
                        break
                    log_step = step % 10 == 0
                    start_time = time.time()

                    # train_op dequeues its own batch from `images`/`labels`;
                    # fetching them separately would consume and discard an
                    # extra batch every step.
                    if log_step:
                        # Fetch the summaries in the same run so they describe
                        # the same batch as the reported loss.
                        _, loss_value, summary_str = sess.run(
                            [train_op, loss, summary])
                    else:
                        _, loss_value = sess.run([train_op, loss])

                    duration = time.time() - start_time

                    # Write the summaries and print an overview fairly often.
                    if log_step:
                        print('%s: Step %d: loss = %.2f (%.3f sec)' %
                              (datetime.now(), step, loss_value, duration))
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                        saver.save(sess, checkpoint_file, global_step=step)
                        print(" Step %d, checkpoint saved! " % (step))
            finally:
                coord.request_stop()
                coord.join(threads)
                summary_writer.close()
Ejemplo n.º 29
0
# Command-line flags for this script.
tf.app.flags.DEFINE_string(
    'train_dir', '/tmp/resnet_train',
    """Directory where to write event logs """
    """and checkpoint.""")
tf.app.flags.DEFINE_string('save_model_dir', './models',
                           'the path using to save model')
tf.app.flags.DEFINE_float('learning_rate', 0.01, "learning rate.")
tf.app.flags.DEFINE_integer('batch_size', net_config.BATCH_SIZE, "batch size")
tf.app.flags.DEFINE_integer('max_steps', 500000, "max steps")
tf.app.flags.DEFINE_boolean('resume', True, 'resume from latest saved state')

# Inference graph: any batch size of net_config-sized images, two classes.
img_tensor = tf.placeholder(
    tf.float32,
    [None, net_config.IMAGE_W, net_config.IMAGE_H, net_config.IMAGE_CHANNEL])
label_tensor = tf.placeholder(tf.int32, [None])
logits = inference(img_tensor,
                   num_classes=2,
                   is_training=False,
                   bottleneck=False)

# tf.global_variables / tf.local_variables_initializer replace the deprecated
# tf.all_variables / tf.initialize_local_variables aliases (same behavior).
saver = tf.train.Saver(tf.global_variables())

init = tf.global_variables_initializer()

sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
sess.run(init)
sess.run(tf.local_variables_initializer())
tf.train.start_queue_runners(sess=sess)
# def train(logits, label_value, image_pathes):
#     from image_processing import image_preprocessing
#     filenames = image_pathes
#     labels = label_value
#     filename, label = tf.train.slice_input_producer([filenames, labels], shuffle=True)
Ejemplo n.º 30
0
def train(training_set, training_labels):
    """Train on dataset for a number of steps."""
    with tf.Graph().as_default(), tf.device('/gpu:0'):
        # Create a variable to count the number of train() calls. This equals the
        # number of batches processed * FLAGS.num_gpus.
        global_step = tf.Variable(0, name="global_step", trainable=False)

        # get num of examples in training set
        dataset_num_examples = training_set.shape[0]

        # Calculate the learning rate schedule.
        num_batches_per_epoch = int(dataset_num_examples / FLAGS.batch_size)

        # Decay the learning rate exponentially based on the number of steps.
        '''
    lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                    global_step,
                                    decay_steps,
                                    FLAGS.learning_rate_decay_factor,
                                    staircase=True)
    '''
        lr_placeholder = tf.placeholder(dtype=tf.float32, shape=[])

        # Create an optimizer that performs gradient descent.
        #opt = tf.train.AdamOptimizer(lr)
        opt = tf.train.MomentumOptimizer(lr_placeholder, MOMENTUM)

        #fetch the data batch from training set
        images, labels = cifar10.placeholder_inputs(FLAGS.batch_size)
        logits = resnet.inference(images,
                                  FLAGS.num_residual_blocks,
                                  reuse=False)

        #calc the loss and gradients
        loss = resnet.loss(logits, labels)
        regu_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([loss] + regu_losses)

        grads = opt.compute_gradients(total_loss)

        # Apply the gradients to adjust the shared variables.
        apply_gradients_op = opt.apply_gradients(grads,
                                                 global_step=global_step)

        with tf.control_dependencies([apply_gradients_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge_all()

        validation_accuracy = tf.reduce_sum(resnet.evaluation(
            logits, labels)) / tf.constant(FLAGS.batch_size)
        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # these two parameters is used to measure when to enter next epoch
        local_data_batch_idx = 0
        epoch_counter = 0
        batch_counter = 0

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)
        for step in range(FLAGS.max_steps):
            # change the API for new aug method
            epoch_counter, local_data_batch_idx, feed_dict = cifar10.fill_feed_dict(
                training_set, training_labels, images, labels,
                FLAGS.batch_size, local_data_batch_idx, epoch_counter,
                FLAGS.init_lr, lr_placeholder)

            batch_counter += 1

            if batch_counter > num_batches_per_epoch:
                batch_counter = 0

            start_time = time.time()
            _, loss_value, acc = sess.run(
                [train_op, total_loss, validation_accuracy],
                feed_dict=feed_dict)

            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            examples_per_sec = FLAGS.batch_size / float(duration)

            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)], Train Loss: {}, Time Cost: {}, Train Acc: {}'
                .format(epoch_counter, batch_counter, num_batches_per_epoch,
                        (100. * (batch_counter * FLAGS.batch_size) /
                         (FLAGS.batch_size * num_batches_per_epoch)),
                        loss_value,
                        time.time() - start_time, acc))
            #tf.logging.info("Data batch index: %s, Current epoch idex: %s" % (str(epoch_counter), str(local_data_batch_idx)))

            if step == FLAGS.decay_step0 or step == FLAGS.decay_step1:
                FLAGS.init_lr = 0.1 * FLAGS.init_lr

            if step % 195 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)