예제 #1
0
def loss_one_kid(logits, labels):
    '''Build the one-vs-rest sigmoid cross-entropy loss of the active kid.

    The active kid (``active_kid``) scores class ``active_kid.index``
    (1-based): samples whose label equals that index are positives, every
    other sample is a negative.

    input params
      logits : tf.Tensor of raw (pre-sigmoid) scores. It must have the same
               shape as the expanded labels, i.e. [bs, 1].
      labels : int64 tf.Tensor / placeholder of class ids, shape [bs]
               (ids start from 1). The batch size may be statically unknown.

    return the tf op that sums every loss in ``active_kid.loss_name``: the
    mean cross entropy added here plus whatever else has been added to that
    collection.
    '''
    assert active_kid.index >= 1
    logger.info("Loss Active Kid is %s" % active_kid.name)
    kid = active_kid.index

    # Compare against a scalar and let tf.equal broadcast over the batch.
    # The previous `[kid] * bs` construction needed a static batch size and
    # failed when the batch dimension was None.
    lbl = tf.constant(kid, tf.int64, name=active_kid.name + '_label')
    lbl = tf.cast(tf.equal(labels, lbl), logits.dtype)
    l = tf.expand_dims(lbl, 1, name=active_kid.name + '_expand_dim_label')
    sigmoid_cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits, labels=l, name=active_kid.name + "_sigmoid_entropy")
    cross_entropy_mean = tf.reduce_mean(sigmoid_cross_entropy,
                                        name=active_kid.name + '_loss')
    debugger.add_scalar_summaries(name=active_kid.name + '_loss',
                                  x=cross_entropy_mean)

    # The total loss is the sum of everything in the kid's loss collection.
    # Presumably weight-decay (L2) terms are added to the same collection
    # elsewhere -- TODO confirm.
    tf.add_to_collection(active_kid.loss_name, cross_entropy_mean)
    return tf.add_n(tf.get_collection(active_kid.loss_name),
                    name=active_kid.name + 'total_loss')
예제 #2
0
def accuracy_rate(predictions, labels, num_classes):
    """Return the accuracy rate based on dense predictions and sparse labels.

    predictions: [bs, num_classes] scores; column j scores class j + 1
    labels     : [bs] class ids (range: [1, num_classes])
    num_classes: int
    return (acc, kids_err, kids_counts):
      acc         : top-1 accuracy in percent (0~100%)
      kids_err    : list of per-class error rates formatted like '12.34%'
                    ('0.00%' for a class that has no samples)
      kids_counts : list of per-class sample counts
    """
    import numpy as np
    # asarray accepts lists, tuples and arrays (arrays are not copied).
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    assert predictions.shape[0] == labels.shape[0]
    # argmax is 0-based while labels start from index 1.
    top1_index = np.argmax(predictions, 1) + 1
    correct = top1_index == labels
    acc = 100.0 * np.sum(correct) / predictions.shape[0]

    missed_kids = labels[~correct]
    kids_counts = []
    kids_err = []
    for kid in range(1, num_classes + 1):
        cnt = int(np.sum(labels == kid))
        failed_cnt = int(np.sum(missed_kids == kid))
        # Float math (no Python 2 integer truncation); a class without any
        # sample has no error instead of raising ZeroDivisionError.
        err = 100.0 * failed_cnt / cnt if cnt else 0.0
        kids_err.append('%.2f%%' % err)
        kids_counts.append(cnt)
    logger.info('err cnt size: %s' % kids_counts)
    logger.info('err rate: %s' % kids_err)
    return acc, kids_err, kids_counts
예제 #3
0
def train(dataset, args):
    '''sks main train loop.

    Trains every kid in ``args.kids`` one after another. It then combines
    their test-set predictions into an overall prediction, once plain and
    once with the belief policy, and logs both accuracies. Finally it saves
    and plots the reports under ``args.log_dir``.
    '''
    global ARGS
    ARGS = args
    cpu_setting()
    # load_reports() reads the module-level num_classes, so set it first.
    global num_classes
    num_classes = dataset.num_classes()

    total_start = time.time()
    reports = load_reports(ARGS.log_dir)
    kids_predictions_on_test_data = dict()
    val_rates = dict()

    for kid in ARGS.kids:
        kids_predictions_on_test_data[kid] = train_one_kid(
            kid, dataset, reports[kid])
        # For fairness, the belief policy only uses the latest *validation*
        # f1, acc, pprec, nprec, tpr and tnr. Do NOT use test metrics: the
        # final check is done on the test data, so that would be cheating.
        val_report = reports[kid]['val']
        val_rates[kid] = {
            name: val_report[name][-1]
            for name in ('f1', 'acc', 'pprec', 'nprec', 'tpr', 'tnr')
        }

    test_labels = dataset.infer_labels('test')

    test_predict = predict_overall_from_kids(kids_predictions_on_test_data,
                                             val_rates, dataset)
    test_acc, kids_err, kids_counts = sks.accuracy_rate(
        test_predict, test_labels, num_classes)

    test_predict_with_policy = predict_overall_from_kids(
        kids_predictions_on_test_data, val_rates, dataset, True)
    test_acc_with_policy, _, _ = sks.accuracy_rate(
        test_predict_with_policy, test_labels, num_classes)
    logger.info('test accuracy        = %.2f%%' % test_acc)
    logger.info('test accuracy policy = %.2f%%' % test_acc_with_policy)

    # NOTE: the misspelled key 'test_accruracy' is kept on purpose so that
    # previously saved reports and their readers stay compatible.
    reports['overall']['test_accruracy'] = test_acc
    reports['overall']['kids_err'] = kids_err
    reports['overall']['test_samples_of_every_kid'] = kids_counts
    spend_time = 'total time: %.1f sec' % (time.time() - total_start)
    reports['overall']['spend_time'] = spend_time
    logger.info(spend_time)
    logger.info("saving reports...")
    logger.debug(reports)
    save_reports(reports, ARGS.log_dir)
    plot_reports(reports, ARGS.log_dir, False)
    plot_reports(reports, ARGS.log_dir, True)
예제 #4
0
     def log_test_results(infer_ops, subset, feed_name, report):
         '''Evaluate ``infer_ops`` on a whole data subset and record metrics.

         Runs the inference ops over every sample of ``subset`` ('val' or
         'test') and converts the raw labels to this kid's labels. It then
         appends loss, f1, acc, pprec, nprec, tpr and tnr to ``report`` and
         logs them.

         return the inferred values for the subset
         '''
         assert subset in ['val', 'test']
         outputs = sks.infer_one_kid_with_all_dataset_in_batches(
             sess, infer_ops, dataset, subset, feed_name,
             ARGS.val_batch_size, ARGS.use_fp16)
         kid_labels = sks.convert_label_to_kid_label(
             dataset.infer_labels(subset), kid)
         loss_value = sks.cpu_loss_one_kid(outputs, kid_labels)
         f1, acc, pprec, nprec, tpr, tnr = sks.cpu_acc_one_kid(
             outputs, kid_labels)

         # Same append order as the metric unpacking above.
         for key, value in (('loss', loss_value), ('f1', f1), ('acc', acc),
                            ('pprec', pprec), ('nprec', nprec),
                            ('tpr', tpr), ('tnr', tnr)):
             report[key].append(value)

         message = (
             'kid: %d, loss: %.4f, f1 %.2f%%, acc: %.2f%%, pprec: %.2f%%, nprec: %.2f%%, tpr: %.2f%%, tnr: %.2f%%'
             % (kid, loss_value, f1 * 100, acc * 100.0, pprec * 100,
                nprec * 100, tpr * 100.0, tnr * 100.0))
         logger.info(message)
         return outputs
예제 #5
0
 def extract_labels(self, filename, num_images):
     """Extract the labels into a vector of int64 label IDs.

     Opens ``filename`` with gzip, skips its 8-byte header and decodes the
     next ``num_images`` bytes as unsigned label values. Every label is
     shifted by +1 so that labels start from 1 (0 is never a label).

     Raises:
       ValueError: if the file holds fewer than ``num_images`` labels
         (previously a truncated file silently produced a shorter vector).
     """
     import gzip
     logger.info('Extracting %s' % filename)
     with gzip.open(filename) as bytestream:
         bytestream.read(8)  # skip the 8-byte file header
         buf = bytestream.read(num_images)
     if len(buf) < num_images:
         raise ValueError('%s holds only %d labels, expected %d' %
                          (filename, len(buf), num_images))
     labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
     # all labels start from 1 now, skipping 0!
     return labels + 1
예제 #6
0
def predict_overall_from_kids(kids_predicts,
                              val_rates,
                              dataset,
                              use_belief=False,
                              threshold=0.5):
    '''Predict overall on test data from all available kids.

    For every kid (class id 1..num_classes) its test-set output is taken from
    ``kids_predicts``; a kid without predictions contributes 0.5 for every
    sample.

    When ``use_belief`` is True, each output is pulled towards ``threshold``
    by the kid's validation rates. Values above the threshold have their
    distance to it scaled by ``pprec``; values at or below it are scaled by
    ``nprec``. A kid missing from ``val_rates`` uses 0.0 for both, which
    flattens it onto the threshold.

    Each sample's row of scores is finally passed through
    ``sks.cpu_softmax``. The inputs are not modified: the belief adjustment
    works on a copy.

    return the prediction (shape: (test_num, num_classes)) overall
    '''
    import numpy as np
    res = []
    num_samples = dataset.test_samples()
    num_classes = dataset.num_classes()
    for kid in range(1, num_classes + 1):
        if kid in kids_predicts:
            # Copy so the belief adjustment never rewrites the caller's data.
            logits = np.array(kids_predicts[kid])
        else:
            # have not predicted this kid
            logits = np.full(num_samples, 0.5)
        assert len(logits) == num_samples
        rate = val_rates.get(kid, 0)
        logger.info("kid %d belief rate:" % kid)
        logger.info(rate)
        if isinstance(rate, int):
            # have not pre-validated rate
            rate = {'acc': 0.0, 'pprec': 0.0, 'nprec': 0.0}
        assert all(rate.get(key, -1) != -1
                   for key in ('acc', 'pprec', 'nprec'))
        if use_belief:
            # Vectorized form of the former per-sample loop.
            logits = np.where(
                logits > threshold,
                threshold + (logits - threshold) * rate['pprec'],
                threshold - (threshold - logits) * rate['nprec'])
        res.append(logits)
    out = np.array(res).transpose()
    assert out.shape == (num_samples, num_classes)
    for i in range(num_samples):
        out[i] = sks.cpu_softmax(out[i])
    return out
예제 #7
0
def load_reports(prefix):
    '''Load the reports saved under ``prefix`` if they exist, else create them.

    A new reports dict maps every kid id in [1, num_classes] (the
    module-level ``num_classes``) to {'train', 'val', 'test'}. Each of those
    is a dict of field name -> empty list. The reports dict also holds an
    empty 'overall' dict. Every kid entry is then passed to
    ``check_kid_report``.

    NOTE: the file is read with pickle, which can execute arbitrary code, so
    it must come from a trusted source (normally a previous run).

    return dict
    '''
    import pickle
    import os
    # Same path scheme as before so existing report files are still found.
    filename = prefix + '//reports.txt'
    if os.path.isfile(filename):
        logger.info("last report %s" % filename)
        with open(filename, "rb") as f:
            reports = pickle.load(f)
        # train() writes into reports['overall']; tolerate files without it.
        reports.setdefault('overall', dict())
        logger.info("last overall report: %s" % reports['overall'])
    else:
        metric_names = ['f1', 'acc', 'pprec', 'nprec', 'tpr', 'tnr']
        subset_fields = (
            ('train', ['epoch', 'loss', 'lr'] + metric_names),
            ('val', ['epoch', 'loss'] + metric_names),
            ('test', ['epoch', 'loss'] + metric_names),
        )
        reports = dict()
        for kid in range(1, num_classes + 1):  # [1, num_classes]
            # Fresh lists for every kid/subset/field (never shared).
            reports[kid] = {
                subset: {field: [] for field in fields}
                for subset, fields in subset_fields
            }
        reports['overall'] = dict()

    for kid in range(1, num_classes + 1):
        check_kid_report(reports.get(kid, 0))
    return reports
예제 #8
0
 def maybe_download(self, data_dir, filename):
     """Download the data from Yann's website, unless it's already here.

     Creates ``data_dir`` if it does not exist. When ``data_dir/filename``
     is missing, it is fetched from
     ``http://yann.lecun.com/exdb/mnist/<filename>``.

     return the local path of the file
     """
     import os
     from six.moves import urllib
     data_url = 'http://yann.lecun.com/exdb/mnist/'
     if not tf.gfile.Exists(data_dir):
         tf.gfile.MakeDirs(data_dir)
     filepath = os.path.join(data_dir, filename)
     if not tf.gfile.Exists(filepath):
         filepath, _ = urllib.request.urlretrieve(data_url + filename,
                                                  filepath)
         with tf.gfile.GFile(filepath) as f:
             size = f.size()
         # Previously rendered as "... 1234bytes." (missing space).
         logger.info('Successfully downloaded %s %d bytes.' %
                     (filename, size))
     return filepath
예제 #9
0
def init_variables(sess, saver=None, checkpoint_dir=None):
    """Initialize the variables of the graph.

    If both ``saver`` and ``checkpoint_dir`` are given and the directory
    holds a checkpoint with a non-empty path, the variables are restored
    from it. Otherwise they are freshly initialized with
    ``tf.global_variables_initializer()``.

    return the last iter number: the number after the final '-' in the
    checkpoint file name (e.g. 1000 for '/my-path/model-1000'), or 0 when
    the variables were freshly initialized.
    """
    assert sess is not None, 'session should not be none, use tf.Session()'

    # Only look up a checkpoint when it could actually be restored; the old
    # code called get_checkpoint_state(None) with the default arguments.
    ckpt = None
    if saver is not None and checkpoint_dir is not None:
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

    checkpoint_path = ckpt.model_checkpoint_path if ckpt is not None else None
    if not checkpoint_path:
        # No saver/dir, no checkpoint, or an empty path: start from scratch.
        # (An empty path used to return None without initializing anything.)
        sess.run(tf.global_variables_initializer())
        return 0

    # Restores from checkpoint
    saver.restore(sess, checkpoint_path)
    # Assuming model_checkpoint_path looks something like:
    #   /my-path/model-1000,
    # extract last_iter from it.
    last_iter = checkpoint_path.split('/')[-1].split('-')[-1]
    logger.info("Init from checkpoint " + checkpoint_path)
    return int(last_iter)
예제 #10
0
def infer_one_kid(input_data,
                  input_channels,
                  img_height,
                  img_width,
                  is_training,
                  data_format='nchw',
                  use_gpu=True,
                  dtype=tf.float32):
    '''Build the inference graph of the currently active kid.

    All variables are created under ``tf.variable_scope(active_kid.name)``.
    The network has a single output (final fc_op with dim_out=1).

    Args:
      input_data: tensor that is reshaped to [-1, C, H, W] ('nchw') or
        [-1, H, W, C] ('nhwc').
      input_channels, img_height, img_width: image geometry used for that
        reshape; ``input_channels`` is also passed to the first cbr_op.
      is_training: if True the raw fc output is returned; otherwise
        ``tf.sigmoid`` of it is returned.
      data_format: 'nchw' or 'nhwc' (asserted).
      use_gpu: NOTE(review): not read anywhere in this function.
      dtype: dtype forwarded to cbr_op/conv_op/fc_op.

    Returns:
      the fc output (logits) when training, its sigmoid otherwise.
    '''
    assert data_format in ['nchw', 'nhwc'], "only support nchw or nhwc yet"
    assert active_kid.index >= 1
    logger.info("Inference Active Kid is %s" % active_kid.name)

    #assert data_format == 'nchw', "only support NCHW yet"
    if data_format == 'nchw':
        shape = [-1, input_channels, img_height, img_width]
    else:
        shape = [-1, img_height, img_width, input_channels]
    feat = tf.reshape(input_data, shape)

    # Channel axis of the chosen layout; branches are concatenated on it.
    concat_axis = 1 if data_format == 'nchw' else 3

    #def infer_mnist():

    #def infer_imagenet():

    with tf.variable_scope(active_kid.name):
        # NOTE(review): presumably resets the auto-generated layer-name
        # counters so each kid's scope gets the same layer names -- confirm.
        # Because names may depend on call order, do not reorder the op
        # calls below without checking checkpoint compatibility.
        reset_all_default_names()
        logger.info(feat.get_shape().as_list())
        # Presumed argument order (TODO confirm against cbr_op): input,
        # in_channels, out_channels, kernel_h, kernel_w, stride_h, stride_w,
        # padding, is_training, data_format.
        y = cbr_op(feat,
                   input_channels,
                   3,
                   3,
                   3,
                   1,
                   1,
                   'VALID',
                   is_training,
                   data_format,
                   dtype=dtype)
        logger.info(y.get_shape().as_list())

        # Every branch below goes through exactly `pyramid_depth` max-pools
        # in total, so all branches presumably end at the same spatial size
        # and can be concatenated on the channel axis.
        pyramid_depth = 4
        # branch 1: conv on y, then pyramid_depth x (pool + conv).
        concat_list = []
        b1 = conv_op(y, 3, 1, 3, 3, 1, 1, 'SAME', data_format, dtype=dtype)
        logger.info(b1.get_shape().as_list())
        for repeat in range(pyramid_depth):
            b1 = pool_op(b1, 3, 3, 2, 2, 'MAX', data_format, padding='VALID')
            logger.info(b1.get_shape().as_list())
            # NOTE(review): an older comment said "replace 3x3 ==> 1x3 &&
            # 3x1", but this call still passes a single 3x3 kernel (assuming
            # the argument order noted above).
            b1 = conv_op(b1,
                         1,
                         1,
                         3,
                         3,
                         1,
                         1,
                         'SAME',
                         data_format,
                         dtype=dtype)
            logger.info(b1.get_shape().as_list())
        concat_list.append(b1)

        # Pyramid of y: level k (1-based) is y after k successive max-pools.
        pyramid = []
        pool = y
        for idx in xrange(pyramid_depth):
            pool = pool_op(pool,
                           3,
                           3,
                           2,
                           2,
                           'MAX',
                           data_format,
                           padding='VALID')
            logger.info(pool.get_shape().as_list())
            pyramid.append(pool)

        # Each pyramid level gets a conv, then (pyramid_depth - level) more
        # pool + conv steps, so it reaches the same pool count as branch 1.
        p_idx = 1
        for p in pyramid:
            p = conv_op(p, 3, 1, 3, 3, 1, 1, 'SAME', data_format, dtype=dtype)
            logger.info(p.get_shape().as_list())
            for repeat in range(pyramid_depth - p_idx):
                p = pool_op(p, 3, 3, 2, 2, 'MAX', data_format, padding='VALID')
                # NOTE(review): same stale "replace 3x3 ==> 1x3 && 3x1"
                # remark as in branch 1; the call still uses a 3x3 kernel.
                p = conv_op(p,
                            1,
                            1,
                            3,
                            3,
                            1,
                            1,
                            'SAME',
                            data_format,
                            dtype=dtype)
                logger.info(p.get_shape().as_list())
            concat_list.append(p)
            p_idx += 1

        concat = concat_op(concat_list, concat_axis)

        # Product of the three non-batch dims (valid for either layout); used
        # as the fc input size. NOTE(review): presumably fc_op flattens
        # `concat` to [bs, dim_in] itself -- confirm.
        shape = concat.get_shape().as_list()  # tf.shape(pool)
        dim_in = shape[1] * shape[2] * shape[3]
        logger.info("last fc input shape: %s" % shape)
        logger.info("last fc input dims: %d" % dim_in)

        # Single output unit: the kid's one-vs-rest score.
        y = fc_op(concat, dim_in=dim_in, dim_out=1, dtype=dtype, bias_init=0.1)

        # Training returns raw logits (the loss applies the sigmoid itself);
        # inference returns probabilities.
        return y if is_training else tf.sigmoid(
            y, name=active_kid.name + '_sigmoid')
예제 #11
0
def train_one_kid(kid, dataset, report):
    '''Train one kid and return its final predictions on the test data.

    param:
      kid (input): kid index, must satisfy 1 <= kid <= dataset.num_classes()
      dataset (input): dataset
      report (output): this kid's report. Its 'train'/'val'/'test' metric
        lists are appended to and 'spend_time' is set.
    return the last pred_values of the test data: the outputs of the test
    ops, which already include the sigmoid.
    '''
    # Validate before touching the global active-kid state.
    assert 1 <= kid <= dataset.num_classes(), 'invalid kid id: %d' % kid
    active_kid.set_active_kid(kid)
    logger.info('Training ' + active_kid.name)
    input_channels = dataset.num_channels()
    img_height = dataset.img_height()
    img_width = dataset.img_width()
    dims = input_channels * img_height * img_width
    with tf.name_scope('inputs'):
        train_data = tf.placeholder(dtype=data_type(),
                                    shape=[ARGS.batch_size, dims])
        train_label = tf.placeholder(tf.int64, shape=[ARGS.batch_size])
        val_data = tf.placeholder(dtype=data_type(),
                                  shape=[ARGS.val_batch_size, dims])
        test_data = tf.placeholder(dtype=data_type(),
                                   shape=[ARGS.val_batch_size, dims])

    with tf.variable_scope('sks'):
        logits = sks.infer_one_kid(train_data, input_channels, img_height,
                                   img_width, True, ARGS.data_format,
                                   ARGS.use_gpu, data_type())
        loss = sks.loss_one_kid(logits, train_label)

    with tf.name_scope('trainer'):
        ################# MomentumTrainer #################
        # Optimizer: set up a variable that's incremented once per batch and
        # controls the learning rate decay.
        batch_num = tf.Variable(0, dtype=data_type())  #, trainable=False)
        # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
        # TODO: self-design lr policy
        learning_rate = tf.train.exponential_decay(
            ARGS.learning_rate,
            global_step=batch_num * ARGS.batch_size,
            decay_steps=dataset.train_samples() /
            25,  # tested 25 is better than 50
            decay_rate=ARGS.learning_rate_decay_rate,
            staircase=True
        )  # True means (global_step / decay_steps) is an integer division
        # Use simple momentum for the optimization.
        optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
        debugger.add_scalar_summaries('learning_rate_kid_%d' % kid,
                                      learning_rate)
        opt_op = optimizer.minimize(loss, global_step=batch_num)
        # batchnorm moving average
        bn_averages = tf.train.ExponentialMovingAverage(
            BATCHNORM_MOVING_AVERAGE_DECAY, batch_num)
        bn_averages_op = bn_averages.apply(active_kid.bn_moving_average_list)
        trainer = tf.group(opt_op, bn_averages_op)
        avg_result_list = active_kid.get_bn_moving_result_list(bn_averages)
        for var in avg_result_list:
            trainer = tf.group(trainer, var)
        ################# MomentumTrainer #################

    # Predictions for the validation and test data reuse the trained
    # variables.
    with tf.variable_scope('sks', reuse=True):
        val_infer_ops = sks.infer_one_kid(val_data, input_channels, img_height,
                                          img_width, False, ARGS.data_format,
                                          ARGS.use_gpu, data_type())
        test_infer_ops = sks.infer_one_kid(test_data, input_channels,
                                           img_height, img_width, False,
                                           ARGS.data_format, ARGS.use_gpu,
                                           data_type())

    # Create a saver. TODO: how about only inference?
    saver = tf.train.Saver(tf.global_variables())
    check_kid_report(report)
    train_report = report['train']
    val_report = report['val']
    test_report = report['test']
    with tf.Session() as sess:
        ckp_path = ARGS.checkpoint_dir + '/kid_%d' % kid
        last_iter = sks.init_variables(sess, saver, ckp_path)
        # TODO: do not need merge all, only need merge this kid's ops
        run_options, run_metadata = debugger.init_merge_all(
            ARGS.log_dir, sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Defined once, before the loop. It used to be re-created on every
        # iteration, and the final call after the loop raised NameError
        # when the loop ran zero times.
        def log_test_results(infer_ops, subset, feed_name, subset_report):
            '''Evaluate ``infer_ops`` on the whole ``subset`` and record the
            loss, f1, acc, pprec, nprec, tpr and tnr in ``subset_report``.

            return pred_values
            '''
            assert subset in ['val', 'test']
            logits_values = sks.infer_one_kid_with_all_dataset_in_batches(
                sess, infer_ops, dataset, subset, feed_name,
                ARGS.val_batch_size, ARGS.use_fp16)
            raw_labels = dataset.infer_labels(subset)
            labels_values = sks.convert_label_to_kid_label(raw_labels, kid)
            loss_value = sks.cpu_loss_one_kid(logits_values, labels_values)
            pred_values = logits_values
            f1, acc, pprec, nprec, tpr, tnr = sks.cpu_acc_one_kid(
                pred_values, labels_values)
            subset_report['loss'].append(loss_value)
            subset_report['f1'].append(f1)
            subset_report['acc'].append(acc)
            subset_report['pprec'].append(pprec)
            subset_report['nprec'].append(nprec)
            subset_report['tpr'].append(tpr)
            subset_report['tnr'].append(tnr)
            log_content = (kid, loss_value, f1 * 100, acc * 100.0,
                           pprec * 100, nprec * 100, tpr * 100.0,
                           tnr * 100.0)
            logger.info(
                'kid: %d, loss: %.4f, f1 %.2f%%, acc: %.2f%%, pprec: %.2f%%, nprec: %.2f%%, tpr: %.2f%%, tnr: %.2f%%'
                % log_content)
            return pred_values

        sum_time = 0
        min_time = float('inf')
        total_start = time.time()
        # Next multiple of the *_every_n_epoch period at which each action
        # fires.
        log_val = 1
        log_test = 1
        log_ckpt = 1
        log_prof = 0
        log_summary = 0
        i = 0
        while (i < ARGS.max_iter):
            epoch = i * ARGS.batch_size / dataset.train_samples()
            if (epoch > ARGS.num_epochs):
                break
            global_iter = last_iter + i

            start_time = time.time()
            batch_data, batch_label = dataset.next_batch(ARGS.batch_size, kid)
            feed_dict = {train_data: batch_data, train_label: batch_label}
            sess.run(trainer,
                     feed_dict=feed_dict,
                     options=run_options,
                     run_metadata=run_metadata)
            duration = time.time() - start_time
            sum_time += duration
            min_time = min(min_time, duration)

            # print and save log
            if i % ARGS.log_every_n_iters == 0:
                _logits, _loss, _lr = sess.run([logits, loss, learning_rate],
                                               feed_dict=feed_dict,
                                               options=run_options,
                                               run_metadata=run_metadata)
                train_report['epoch'].append(epoch)
                train_report['loss'].append(_loss)
                train_report['lr'].append(_lr)

                # cal the training acc
                _labels = sks.convert_label_to_kid_label(batch_label, kid)
                _preds = _logits.reshape(len(_labels))
                _f1, _acc, _pprec, _nprec, _tpr, _tnr = sks.cpu_acc_one_kid(
                    _preds, _labels)
                train_report['f1'].append(_f1)
                train_report['acc'].append(_acc)
                train_report['pprec'].append(_pprec)
                train_report['nprec'].append(_nprec)
                train_report['tpr'].append(_tpr)
                train_report['tnr'].append(_tnr)

                # print logs
                format_str = 'kid: %d, iter %d (epoch %.2f), lr %.6f, loss : %.4f, f1 %.2f%%, acc: %.2f%%, pprec: %.2f%%, nprec: %.2f%%, tpr: %.2f%%, tnr: %.2f%% (%.1f ms/iter)'
                log_content = (kid, i, epoch, _lr, _loss, _f1 * 100,
                               _acc * 100, _pprec * 100, _nprec * 100,
                               _tpr * 100, _tnr * 100, duration * 1000)
                logger.info(format_str % log_content)

            if epoch / ARGS.val_every_n_epoch > log_val or (
                    i + 1) == ARGS.max_iter:
                logger.info('kid: %d, epoch %.2f, validation info:' %
                            (kid, epoch))
                log_test_results(val_infer_ops, 'val', val_data, val_report)
                val_report['epoch'].append(epoch)
                log_val += 1

            if epoch / ARGS.test_every_n_epoch > log_test or (
                    i + 1) == ARGS.max_iter:
                logger.info('kid: %d, epoch %.2f, test info: -------- ' %
                            (kid, epoch))
                log_test_results(test_infer_ops, 'test', test_data,
                                 test_report)
                test_report['epoch'].append(epoch)
                log_test += 1

            if epoch / ARGS.ckpt_every_n_epoch > log_ckpt or (
                    i + 1) == ARGS.max_iter:
                checkpoint_path = ckp_path + '/kid_%d_model' % kid
                saver.save(sess, checkpoint_path, global_step=global_iter)
                logger.info('save checkpoint to %s-%d' %
                            (checkpoint_path, global_iter))
                log_ckpt += 1

            if epoch / ARGS.profil_every_n_epoch > log_prof:
                filename = ARGS.log_dir + ('/timeline_iter%d' %
                                           global_iter) + '.json'
                debugger.save_timeline(filename, run_metadata)
                debugger.save_tfprof(ARGS.log_dir, sess.graph, run_metadata)
                log_prof += 1

            # Division (not modulo) like the schedules above: with `%` the
            # value stays below the period, so summaries stopped after the
            # first few saves.
            if epoch / ARGS.summary_every_n_epoch > log_summary:
                debugger.save_summaries(sess, feed_dict, run_options,
                                        run_metadata, i)
                log_summary += 1

            sys.stdout.flush()
            i += 1
        test_pred = log_test_results(test_infer_ops, 'test', test_data,
                                     test_report)
        coord.request_stop()
        coord.join(threads)
        # Guard against zero iterations (max_iter == 0 or an immediate break).
        avg_time_ms = sum_time * 1000.0 / i if i else 0.0
        content = (time.time() - total_start, avg_time_ms, min_time * 1000.0)
        spend_time = 'total time: %.1f sec, train avg time: %.3f ms, train min_time: %.3f ms' % content
        logger.info(spend_time)
        report['spend_time'] = spend_time

    return test_pred