Example #1
def test(images_tensor, labels_tensor, parameters_path, test_data_dir):
    '''
    Evaluate the trained model on the test set.
    :param images_tensor: placeholder tensor for the input images
    :param labels_tensor: placeholder tensor for the labels
    :param parameters_path: directory the model checkpoints were saved to
    :param test_data_dir: directory containing the test images
    :return: None
    '''
    logits = inference(images_tensor, False)
    prediction_tensor = tf.argmax(logits, 1)
    correct_prediction = tf.equal(tf.argmax(logits, 1),
                                  tf.cast(tf.squeeze(labels_tensor), tf.int64))
    accuracy_tensor = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    saver = tf.train.Saver()

    with tf.Session() as sess:
        full_path = tf.train.latest_checkpoint(parameters_path)
        print(full_path)
        saver.restore(sess, full_path)

        names = os.listdir(test_data_dir)
        paths = []
        labels = []
        all_predictions = []
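        # the last character of each file name encodes the class label; files outside classes 0-3 are skipped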
        for name in names:
            if int(name[-1]) not in [0, 1, 2, 3]:
                continue
            paths.append(os.path.join(test_data_dir, name))
            labels.append(int(name[-1]))
        start_index = 0
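        # run the test set through the network in batches; the last batch may be smaller than BATCH_SIZE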
        while True:
            end_index = start_index + net_config.BATCH_SIZE
            if end_index > len(paths):
                end_index = len(paths)
            cur_batch_paths = paths[start_index:end_index]
            cur_batch_images = [
                load_ROI(cur_path) for cur_path in cur_batch_paths
            ]
            cur_batch_labels = labels[start_index:end_index]
            cur_batch_images = resize_images(cur_batch_images,
                                             net_config.ROI_SIZE_W, True)
            print(np.shape(cur_batch_images))
            print(np.shape(cur_batch_labels))
            predicted, accuracy = sess.run(
                [prediction_tensor, accuracy_tensor],
                feed_dict={
                    images_tensor: np.array(cur_batch_images),
                    labels_tensor: np.squeeze(cur_batch_labels)
                })
            print(accuracy)
            all_predictions.extend(predicted)
            start_index = end_index
            if start_index >= len(paths):
                break
    calculate_acc_error(all_predictions, labels)
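
Every example ends by calling calculate_acc_error, which none of the snippets define. A minimal sketch consistent with the call sites above (positional calls pass predicted class ids plus ground-truth labels, and example #6 uses the keywords logits=/label=/show=); the per-class breakdown is an assumed behavior, not the original implementation:

import numpy as np

def calculate_acc_error(logits, label, show=True):
    # 'logits' are predicted class ids here, matching how the examples call it
    predicted = np.asarray(logits).squeeze().astype(int)
    truth = np.asarray(label).squeeze().astype(int)
    if show:
        for cls in np.unique(truth):
            mask = truth == cls
            wrong = int(np.sum(predicted[mask] != cls))
            print('class %d: %d / %d misclassified' % (cls, wrong, int(np.sum(mask))))
    print('overall accuracy: %.4f' % np.mean(predicted == truth))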
Example #2
def execute_classify(train_features, train_labels, val_features, val_labels, test_features, test_labels):
    from LeaningBased.BoVW_DualDict.classification import SVM, LinearSVM, KNN
    predicted_label, c_params, g_params, max_c, max_g, accs = SVM.do(train_features, train_labels, val_features,
                                                                     val_labels,
                                                                     adjust_parameters=True)
    predicted_label, accs = SVM.do(train_features, train_labels, test_features, test_labels,
                                   adjust_parameters=False, C=max_c, gamma=max_g)
    calculate_acc_error(predicted_label, test_labels)
    print('ACA is ', accs)
    return accs
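
Examples #2 and #3 depend on an SVM.do variant that, with adjust_parameters=True, grid-searches C and gamma on a validation split (the two examples unpack different return tuples, so the real helper apparently changed between versions). A hypothetical module-level sketch using scikit-learn's SVC; the name svm_do, the parameter grids, and the return layout follow example #2 and are illustrative only:

import numpy as np
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

def svm_do(train_features, train_labels, eval_features, eval_labels,
           adjust_parameters=True, C=1.0, gamma=0.01):
    train_labels = np.ravel(train_labels)
    eval_labels = np.ravel(eval_labels)
    if not adjust_parameters:
        # evaluate once with the supplied hyper-parameters
        clf = SVC(C=C, gamma=gamma).fit(train_features, train_labels)
        predicted = clf.predict(eval_features)
        return predicted, accuracy_score(eval_labels, predicted)
    # exhaustive grid search, keeping the (C, gamma) pair with the best validation accuracy
    c_params, g_params = [0.1, 1, 10, 100], [1e-3, 1e-2, 1e-1]
    best_pred, max_c, max_g, best_acc = None, None, None, -1.0
    for c in c_params:
        for g in g_params:
            clf = SVC(C=c, gamma=g).fit(train_features, train_labels)
            predicted = clf.predict(eval_features)
            acc = accuracy_score(eval_labels, predicted)
            if acc > best_acc:
                best_pred, max_c, max_g, best_acc = predicted, c, g, acc
    return best_pred, c_params, g_params, max_c, max_g, best_acc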
Example #3
    from utils.classification import SVM, KNN

    train_data = scio.loadmat(
        './features/crossvalidation/0/saved/train.npy.mat')
    train_features = train_data['features']
    train_labels = train_data['labels']

    val_data = scio.loadmat('./features/crossvalidation/0/saved/val.npy.mat')
    val_features = val_data['features']
    val_labels = val_data['labels']

    test_data = scio.loadmat('./features/crossvalidation/0/saved/test.npy.mat')
    test_features = test_data['features']
    test_labels = test_data['labels']
    # SVM
    predicted_label, c_params, g_params, accs = SVM.do(train_features,
                                                       train_labels,
                                                       val_features,
                                                       val_labels,
                                                       adjust_parameters=True)
    # evaluate on the test split with the parameters tuned on the validation split
    predicted_label, acc = SVM.do(train_features,
                                  train_labels,
                                  test_features,
                                  test_labels,
                                  adjust_parameters=False,
                                  C=c_params,
                                  gamma=g_params)
    print('ACC is ', acc)
    calculate_acc_error(predicted_label, test_labels)
Example #4
def main(_):
    roi_images = tf.placeholder(shape=[
        None, net_config.ROI_SIZE_W, net_config.ROI_SIZE_H,
        net_config.IMAGE_CHANNEL
    ],
                                dtype=np.float32,
                                name='roi_input')
    expand_roi_images = tf.placeholder(shape=[
        None, net_config.EXPAND_SIZE_W, net_config.EXPAND_SIZE_H,
        net_config.IMAGE_CHANNEL
    ],
                                       dtype=np.float32,
                                       name='expand_roi_input')
    batch_size_tensor = tf.placeholder(dtype=tf.int32, shape=[])
    is_training_tensor = tf.placeholder(dtype=tf.bool, shape=[])
    logits, _, _, _ = inference_small(roi_images,
                                      expand_roi_images,
                                      phase_names=['NC', 'ART', 'PV'],
                                      num_classes=4,
                                      is_training=is_training_tensor,
                                      batch_size=batch_size_tensor)
    model_path = '/home/give/PycharmProjects/MICCAI2018/deeplearning/Parallel/parameters/1'
    # model_path = '/home/give/PycharmProjects/MedicalImage/Net/forpatch/cross_validation/model/multiscale/parallel/0/2200.0'
    predictions = tf.nn.softmax(logits)
    saver = tf.train.Saver(tf.global_variables())
    print(predictions)

    predicted_label_tensor = tf.argmax(predictions, axis=1)
    print(predicted_label_tensor)
    init = tf.global_variables_initializer()
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    sess.run(init)
    tf.train.start_queue_runners(sess=sess)
    latest = tf.train.latest_checkpoint(model_path)
    if not latest:
        print("No checkpoint to continue from in", model_path)
        sys.exit(1)
    print("resume", latest)
    saver.restore(sess, latest)

    data_dir = '/home/give/Documents/dataset/MICCAI2018/Patches/crossvalidation/1/test'
    slice_dir = '/home/give/Documents/dataset/MICCAI2018/Slices/crossvalidation/1/test'
    labels = []
    paths = []
    for typeid in [0, 1, 2, 3]:
        cur_path = os.path.join(data_dir, str(typeid))
        names = os.listdir(cur_path)
        labels.extend([typeid] * len(names))
        paths.extend([os.path.join(cur_path, name) for name in names])
    paths, labels = shuffle_image_label(paths, labels)
    start_index = 0
    predicted_labels = []
    liver_density = load_raw_liver_density()
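    # batched inference over the shuffled test patches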
    while True:
        if start_index >= len(paths):
            break
        print(start_index, len(paths))
        end_index = start_index + net_config.BATCH_SIZE
        cur_paths = paths[start_index:end_index]
        cur_roi_images = [np.asarray(load_patch(path)) for path in cur_paths]
        cur_expand_roi_images = [
            np.asarray(load_patch(path, return_roi=True, parent_dir=slice_dir))
            for path in cur_paths
        ]
        cur_roi_images = resize_images(cur_roi_images, net_config.ROI_SIZE_W,
                                       True)
        cur_expand_roi_images = resize_images(cur_expand_roi_images,
                                              net_config.EXPAND_SIZE_W, True)
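        # look up each patch's liver density by the file-name prefix before the last '_'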
        cur_liver_densitys = [
            liver_density[os.path.basename(path)[:os.path.basename(path).rfind('_')]]
            for path in cur_paths
        ]
        # for i in range(len(cur_roi_images)):
        #     for j in range(3):
        #         cur_roi_images[i, :, :, j] = (1.0 * cur_roi_images[i, :, :, j]) / (1.0 * cur_liver_densitys[i][j])
        #         cur_expand_roi_images[i, :, :, j] = (1.0 * cur_expand_roi_images[i, :, :, j]) / (
        #         1.0 * cur_liver_densitys[i][j])
        predicted_batch_labels = sess.run(
            predicted_label_tensor,
            feed_dict={
                roi_images: cur_roi_images,
                expand_roi_images: cur_expand_roi_images,
                is_training_tensor: False,
                batch_size_tensor: len(cur_roi_images)
            })
        batch_labels = labels[start_index:end_index]
        predicted_labels.extend(predicted_batch_labels)
        start_index = end_index
        calculate_acc_error(predicted_batch_labels, batch_labels)
    calculate_acc_error(predicted_labels, labels)
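
shuffle_image_label (used here and in example #7) is not shown in any snippet; a sketch of the assumed behavior, shuffling the path and label lists in unison:

import random

def shuffle_image_label(paths, labels):
    # shuffle both lists together so every path keeps its label
    pairs = list(zip(paths, labels))
    random.shuffle(pairs)
    paths, labels = zip(*pairs)
    return list(paths), list(labels)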
Example #5
    @staticmethod
    def do(train_data, train_label, test_data, test_label=None, adjust_parameters=True):
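        # NOTE: adjust_parameters is accepted for interface compatibility but ignored by this linear-SVM variant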
        train_data = np.array(train_data).squeeze()
        train_label = np.array(train_label).squeeze()
        test_data = np.array(test_data).squeeze()
        if test_label is not None:
            test_label = np.array(test_label).squeeze()
        svm = LinearSVC()
        svm.fit(train_data, train_label)
        predicts = svm.predict(test_data)
        acc = None
        if test_label is not None:
            acc = accuracy_score(test_label, predicts)
            print(acc)
        return predicts


if __name__ == '__main__':
    data = scio.loadmat('/home/give/PycharmProjects/MedicalImage/BoVW/data_256_False.mat')
    train_features = data['train_features']
    val_features = data['val_features']
    train_labels = data['train_labels']
    val_labels = data['val_labels']
    val_labels = np.squeeze(val_labels)
    print(np.shape(train_features), np.shape(train_labels))
    predicted_label = SVM.do(train_features, train_labels, val_features, val_labels, adjust_parameters=True)

    np.save('./predicted_res.npy', predicted_label)
    # predicted_label = np.load('./predicted_res.npy')
    calculate_acc_error(predicted_label, val_labels)
Example #6
def train(logits,
          local_output_tensor,
          global_output_tensor,
          represent_feature_tensor,
          images_tensor,
          expand_images_tensor,
          labels_tensor,
          is_training_tensor,
          save_model_path=None,
          step_width=100,
          record_loss=False):
    cross_id = 1
    has_centerloss = False
    patches_dir = '/home/give/Documents/dataset/MICCAI2018/Patches/crossvalidation'
    roi_dir = '/home/give/Documents/dataset/MICCAI2018/Slices/crossvalidation'
    pre_load = True
    train_dataset = DataSet(os.path.join(patches_dir, str(cross_id), 'train'),
                            'train',
                            pre_load=pre_load,
                            rescale=True,
                            divied_liver=False,
                            expand_is_roi=True,
                            full_roi_path=os.path.join(roi_dir, str(cross_id),
                                                       'train'))
    val_dataset = DataSet(os.path.join(patches_dir, str(cross_id), 'val'),
                          'val',
                          pre_load=pre_load,
                          rescale=True,
                          divied_liver=False,
                          expand_is_roi=True,
                          full_roi_path=os.path.join(roi_dir, str(cross_id),
                                                     'val'))

    train_batchdata = train_dataset.get_next_batch(net_config.BATCH_SIZE)
    val_batchdata = val_dataset.get_next_batch(net_config.BATCH_SIZE)

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    val_step = tf.get_variable('val_step', [],
                               initializer=tf.constant_initializer(0),
                               trainable=False)
    # inter loss
    loss_local = loss(local_output_tensor, labels_tensor)
    loss_global = loss(global_output_tensor, labels_tensor)
    loss_last = loss(logits, labels_tensor)
    loss_inter = loss_local_coefficient * loss_local + loss_global_coefficient * loss_global + loss_all_coefficient * loss_last

    # intra loss
    if has_centerloss:
        represent_feature_tensor_shape = represent_feature_tensor.get_shape().as_list()
        print('represent_feature_tensor_shape is ', represent_feature_tensor_shape)
        centers_value = np.zeros(
            [category_num, represent_feature_tensor_shape[1]],
            dtype=np.float32)
        print('centers_value shape is ', np.shape(centers_value))
        centers_tensor = tf.placeholder(
            dtype=tf.float32,
            shape=[category_num, represent_feature_tensor_shape[1]])
        print('centers_tensor shape is ', centers_tensor.get_shape())
        center_loss = calculate_centerloss(represent_feature_tensor,
                                           labels_tensor,
                                           centers_tensor=centers_tensor)
        owner_step = tf.py_func(update_centers, [
            centers_tensor, represent_feature_tensor, labels_tensor,
            category_num
        ], tf.float32)

        loss_ = loss_inter + _lambda * center_loss
    else:
        loss_ = loss_inter
    predictions = tf.nn.softmax(logits)
    print('predictions shape is ', predictions)
    print('label is ', labels_tensor)
    top1_error = top_k_error(predictions, labels_tensor, 1)
    labels_onehot = tf.one_hot(labels_tensor, logits.get_shape().as_list()[-1])
    print('output node is ', logits.get_shape().as_list()[-1])
    accuracy_tensor = calculate_accuracy(predictions, labels_onehot)

    # loss_avg
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss_]))
    tf.summary.scalar('loss_avg', ema.average(loss_))

    # validation stats
    ema = tf.train.ExponentialMovingAverage(0.9, val_step)
    val_op = tf.group(val_step.assign_add(1), ema.apply([top1_error]))
    top1_error_avg = ema.average(top1_error)
    tf.summary.scalar('val_top1_error_avg', top1_error_avg)

    tf.summary.scalar('learning_rate', FLAGS.learning_rate)

    opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
    grads = opt.compute_gradients(loss_)
    for grad, var in grads:
        if grad is not None and not FLAGS.minimal_summaries:
            tf.summary.histogram(var.op.name + '/gradients', grad)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    if not FLAGS.minimal_summaries:
        # Display the training images in the visualizer.
        tf.summary.image('images', images_tensor)

        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

    batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
    batchnorm_updates_op = tf.group(*batchnorm_updates)

    if has_centerloss:
        with tf.control_dependencies(
            [apply_gradient_op, batchnorm_updates_op, owner_step]):
            train_op = tf.no_op('train')
    else:
        train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

    saver = tf.train.Saver(tf.global_variables())

    summary_op = tf.summary.merge_all()

    init = tf.global_variables_initializer()

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    sess.run(init)
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
    val_summary_writer = tf.summary.FileWriter(FLAGS.log_val_dir, sess.graph)
    if FLAGS.resume:
        latest = tf.train.latest_checkpoint(FLAGS.load_model_path)
        if not latest:
            print("No checkpoint to continue from in", FLAGS.load_model_path)
            sys.exit(1)
        print("resume", latest)
        saver.restore(sess, latest)

    for x in range(FLAGS.max_steps + 1):
        start_time = time.time()

        step = sess.run(global_step)
        if has_centerloss:
            i = [train_op, loss_, owner_step]
        else:
            i = [train_op, loss_]
        write_summary = step % 100 == 0 and step > 1
        if write_summary:
            i.append(summary_op)
        train_roi_batch_images, train_expand_roi_batch_images, train_labels = next(train_batchdata)
        o = sess.run(
            i,
            feed_dict={
                images_tensor: train_roi_batch_images,
                expand_images_tensor: train_expand_roi_batch_images,
                labels_tensor: train_labels,
                # centers_tensor: centers_value,
                is_training_tensor: True
            })
        if has_centerloss:
            centers_value = o[2]
        loss_value = o[1]

        duration = time.time() - start_time

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if (step - 1) % step_width == 0:
            top1_error_value, accuracy_value, labels_values, predictions_values = sess.run(
                [top1_error, accuracy_tensor, labels_tensor, predictions],
                feed_dict={
                    images_tensor: train_roi_batch_images,
                    expand_images_tensor: train_expand_roi_batch_images,
                    labels_tensor: train_labels,
                    # centers_tensor: centers_value,
                    is_training_tensor: True
                })
            predictions_values = np.argmax(predictions_values, axis=1)
            examples_per_sec = FLAGS.batch_size / float(duration)
            # accuracy = eval_accuracy(predictions_values, labels_values)
            format_str = (
                'step %d, loss = %.2f, top1 error = %g, accuracy value = %g  (%.1f examples/sec; %.3f '
                'sec/batch)')

            print(format_str % (step, loss_value, top1_error_value,
                                accuracy_value, examples_per_sec, duration))
        if write_summary:
            if has_centerloss:
                summary_str = o[3]
            else:
                summary_str = o[2]
            summary_writer.add_summary(summary_str, step)

        # Save the model checkpoint periodically.
        if step > 1 and step % step_width == 0:

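            # saver.save below writes model.ckpt-(step+1).* (global_step was already incremented by train_op),
            # and the glob copies exactly those files into a per-step snapshot directory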
            checkpoint_path = os.path.join(save_model_path, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=global_step)
            save_dir = os.path.join(save_model_path, str(step))
            if not os.path.exists(save_dir):
                os.mkdir(save_dir)
            filenames = glob(
                os.path.join(save_model_path,
                             '*-' + str(int(step + 1)) + '.*'))
            for filename in filenames:
                shutil.copy(filename,
                            os.path.join(save_dir, os.path.basename(filename)))
        # Run validation periodically
        if step > 1 and step % step_width == 0:
            val_roi_batch_images, val_expand_roi_batch_images, val_labels = next(val_batchdata)
            _, top1_error_value, summary_value, accuracy_value, labels_values, predictions_values = sess.run(
                [
                    val_op, top1_error, summary_op, accuracy_tensor,
                    labels_tensor, predictions
                ],
                {
                    images_tensor: val_roi_batch_images,
                    expand_images_tensor: val_expand_roi_batch_images,
                    # centers_tensor: centers_value,
                    labels_tensor: val_labels,
                    is_training_tensor: False
                })
            predictions_values = np.argmax(predictions_values, axis=1)
            # accuracy = eval_accuracy(predictions_values, labels_values)
            calculate_acc_error(logits=predictions_values,
                                label=labels_values,
                                show=True)
            print('Validation top1 error %.2f, accuracy value %f' %
                  (top1_error_value, accuracy_value))
            val_summary_writer.add_summary(summary_value, step)
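
top_k_error in examples #6 and #8 is also undefined in the snippets. A minimal TF1-style sketch, assuming it returns the fraction of the batch whose true class is missing from the top-k predictions:

import tensorflow as tf

def top_k_error(predictions, labels, k):
    # fraction of the batch whose true label is NOT among the k highest-scoring classes
    batch_size = tf.cast(tf.shape(predictions)[0], tf.float32)
    in_top_k = tf.cast(tf.nn.in_top_k(predictions, tf.cast(labels, tf.int32), k=k), tf.float32)
    return (batch_size - tf.reduce_sum(in_top_k)) / batch_size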
Example #7
def main(_):
    roi_images = tf.placeholder(shape=[
        None, net_config.ROI_SIZE_W, net_config.ROI_SIZE_H,
        net_config.IMAGE_CHANNEL
    ],
                                dtype=np.float32,
                                name='roi_input')
    expand_roi_images = tf.placeholder(shape=[
        None, net_config.EXPAND_SIZE_W, net_config.EXPAND_SIZE_H,
        net_config.IMAGE_CHANNEL
    ],
                                       dtype=np.float32,
                                       name='expand_roi_input')
    batch_size_tensor = tf.placeholder(dtype=tf.int32, shape=[])
    is_training_tensor = tf.placeholder(dtype=tf.bool, shape=[])
    logits = inference_small(roi_images,
                             expand_roi_images,
                             phase_names=['NC', 'ART', 'PV'],
                             num_classes=4,
                             point_phase=[2],
                             is_training=is_training_tensor,
                             batch_size=batch_size_tensor)
    # model_path = '/home/give/PycharmProjects/MedicalImage/Net/ICIP/4-class/Patch_ROI/models/300.0/'
    # model_path = '/home/give/PycharmProjects/MedicalImage/Net/forpatch/cross_validation/model/multiscale/parallel/0/2200.0'
    model_path = '/home/give/PycharmProjects/MedicalImage/Net/ICIP/4-class/Patch_ROI/models_7'
    predictions = tf.nn.softmax(logits)
    saver = tf.train.Saver(tf.global_variables())
    print(predictions)

    predicted_label_tensor = tf.argmax(predictions, axis=1)
    print(predicted_label_tensor)
    init = tf.global_variables_initializer()
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    sess.run(init)
    tf.train.start_queue_runners(sess=sess)
    latest = tf.train.latest_checkpoint(model_path)
    if not latest:
        print("No checkpoint to continue from in", model_path)
        sys.exit(1)
    print("resume", latest)
    saver.restore(sess, latest)

    data_dir = '/home/give/Documents/dataset/MedicalImage/MedicalImage/Patches/ICIP/only-patch-7/val'
    labels = []
    paths = []
    mapping_label = {0: 0, 1: 1, 2: 2, 3: 3}
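    # identity label mapping, kept as a hook for re-grouping the four classes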
    for typeid in [0, 1, 2, 3]:
        cur_path = os.path.join(data_dir, str(typeid))
        names = os.listdir(cur_path)
        labels.extend([mapping_label[typeid]] * len(names))
        paths.extend([os.path.join(cur_path, name) for name in names])
    paths, labels = shuffle_image_label(paths, labels)
    start_index = 0
    predicted_labels = []
    while True:
        if start_index >= len(paths):
            break
        print(start_index, len(paths))
        end_index = start_index + net_config.BATCH_SIZE
        cur_paths = paths[start_index:end_index]
        cur_roi_images = [np.asarray(load_patch(path)) for path in cur_paths]
        cur_expand_roi_images = [
            np.asarray(load_patch(
                path,
                return_roi=True,
                parent_dir='/home/give/Documents/dataset/MedicalImage/MedicalImage/SL_TrainAndVal/val'
            )) for path in cur_paths
        ]
        cur_roi_images = resize_images(cur_roi_images, net_config.ROI_SIZE_W,
                                       True)
        cur_expand_roi_images = resize_images(cur_expand_roi_images,
                                              net_config.EXPAND_SIZE_W, True)
        predicted_batch_labels = sess.run(
            predicted_label_tensor,
            feed_dict={
                roi_images: cur_roi_images,
                expand_roi_images: cur_expand_roi_images,
                is_training_tensor: False,
                batch_size_tensor: len(cur_roi_images)
            })
        batch_labels = labels[start_index:end_index]
        predicted_labels.extend(predicted_batch_labels)
        start_index = end_index
        calculate_acc_error(predicted_batch_labels, batch_labels)
    calculate_acc_error(predicted_labels, labels)
Example #8
def train(logits,
          images_tensor,
          expand_images_tensor,
          labels_tensor,
          is_training_tensor,
          save_model_path=None,
          step_width=100,
          record_loss=False):
    train_dataset = DataSet(
        '/home/give/Documents/dataset/MedicalImage/MedicalImage/Patches/ICIP/140_only_patch/train',
        'train',
        rescale=True,
        divied_liver=False,
        expand_is_roi=True,
        full_roi_path='/home/give/Documents/dataset/MedicalImage/MedicalImage/SL_TrainAndVal/train'
    )
    val_dataset = DataSet(
        '/home/give/Documents/dataset/MedicalImage/MedicalImage/Patches/ICIP/140_only_patch/val',
        'val',
        rescale=True,
        divied_liver=False,
        expand_is_roi=True,
        full_roi_path='/home/give/Documents/dataset/MedicalImage/MedicalImage/SL_TrainAndVal/val'
    )
    loss_value_record_file_path = '/home/give/PycharmProjects/MedicalImage/Net/ICIP/4-class/Patch_ROI/models_draw_line/loss_value'
    if record_loss:
        with open(loss_value_record_file_path, 'w') as record_file:
            record_file.write('step training_loss val_loss training_acc val_acc\n')
    train_batchdata = train_dataset.get_next_batch(net_config.BATCH_SIZE)
    val_batchdata = val_dataset.get_next_batch(net_config.BATCH_SIZE)

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    val_step = tf.get_variable('val_step', [],
                               initializer=tf.constant_initializer(0),
                               trainable=False)
    loss_ = loss(logits, labels_tensor)
    predictions = tf.nn.softmax(logits)
    print('predictions shape is ', predictions)
    print('label is ', labels_tensor)
    top1_error = top_k_error(predictions, labels_tensor, 1)
    labels_onehot = tf.one_hot(labels_tensor, logits.get_shape().as_list()[-1])
    print('output node is ', logits.get_shape().as_list()[-1])
    accuracy_tensor = calculate_accuracy(predictions, labels_onehot)

    # loss_avg
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss_]))
    tf.summary.scalar('loss_avg', ema.average(loss_))

    # validation stats
    ema = tf.train.ExponentialMovingAverage(0.9, val_step)
    val_op = tf.group(val_step.assign_add(1), ema.apply([top1_error]))
    top1_error_avg = ema.average(top1_error)
    tf.summary.scalar('val_top1_error_avg', top1_error_avg)

    tf.summary.scalar('learning_rate', FLAGS.learning_rate)

    opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
    grads = opt.compute_gradients(loss_)
    for grad, var in grads:
        if grad is not None and not FLAGS.minimal_summaries:
            tf.summary.histogram(var.op.name + '/gradients', grad)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    if not FLAGS.minimal_summaries:
        # Display the training images in the visualizer.
        tf.summary.image('images', images_tensor)

        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

    batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

    saver = tf.train.Saver(tf.global_variables())

    summary_op = tf.summary.merge_all()

    init = tf.global_variables_initializer()

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    sess.run(init)
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
    val_summary_writer = tf.summary.FileWriter(FLAGS.log_val_dir, sess.graph)
    if FLAGS.resume:
        latest = tf.train.latest_checkpoint(FLAGS.load_model_path)
        if not latest:
            print("No checkpoint to continue from in", FLAGS.load_model_path)
            sys.exit(1)
        print("resume", latest)
        saver.restore(sess, latest)

    for x in range(FLAGS.max_steps + 1):
        start_time = time.time()

        step = sess.run(global_step)
        i = [train_op, loss_, accuracy_tensor]

        write_summary = step % 100 == 0 and step > 1
        if write_summary:
            i.append(summary_op)
        train_roi_batch_images, train_expand_roi_batch_images, train_labels = next(train_batchdata)
        o = sess.run(i,
                     feed_dict={
                         images_tensor: train_roi_batch_images,
                         expand_images_tensor: train_expand_roi_batch_images,
                         labels_tensor: train_labels,
                         is_training_tensor: True
                     })

        loss_value = o[1]

        # compute the loss/accuracy on a validation batch for the loss-curve record
        if record_loss:
            val_roi_batch_images, val_expand_roi_batch_images, val_labels = next(val_batchdata)
            val_loss_value, val_accuracy_value = sess.run(
                [loss_, accuracy_tensor],
                feed_dict={
                    images_tensor: val_roi_batch_images,
                    expand_images_tensor: val_expand_roi_batch_images,
                    labels_tensor: val_labels,
                    is_training_tensor: False
                })
            print('step %d, training loss_value %f accuracy %f, validation loss_value %f accuracy %f'
                  % (step, loss_value, o[2], val_loss_value, val_accuracy_value))
            with open(loss_value_record_file_path, 'a') as record_file:
                record_file.write('%d %f %f %f %f\n' %
                                  (step, loss_value, val_loss_value, o[2], val_accuracy_value))

        duration = time.time() - start_time

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if (step - 1) % step_width == 0:
            top1_error_value, accuracy_value, labels_values, predictions_values = sess.run(
                [top1_error, accuracy_tensor, labels_tensor, predictions],
                feed_dict={
                    images_tensor: train_roi_batch_images,
                    expand_images_tensor: train_expand_roi_batch_images,
                    labels_tensor: train_labels,
                    is_training_tensor: True
                })
            predictions_values = np.argmax(predictions_values, axis=1)
            examples_per_sec = FLAGS.batch_size / float(duration)
            # accuracy = eval_accuracy(predictions_values, labels_values)
            format_str = (
                'step %d, loss = %.2f, top1 error = %g, accuracy value = %g  (%.1f examples/sec; %.3f '
                'sec/batch)')

            print(format_str % (step, loss_value, top1_error_value,
                                accuracy_value, examples_per_sec, duration))
        if write_summary:
            summary_str = o[3]
            summary_writer.add_summary(summary_str, step)

        if step > 1 and step % step_width == 0:

            checkpoint_path = os.path.join(save_model_path, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=global_step)
            save_dir = os.path.join(save_model_path, str(step))
            if not os.path.exists(save_dir):
                os.mkdir(save_dir)
            filenames = glob(
                os.path.join(save_model_path,
                             '*-' + str(int(step + 1)) + '.*'))
            for filename in filenames:
                shutil.copy(filename,
                            os.path.join(save_dir, os.path.basename(filename)))
        # Run validation periodically
        if step > 1 and step % step_width == 0:
            if not record_loss:
                val_roi_batch_images, val_expand_roi_batch_images, val_labels = val_batchdata.next(
                )
            _, top1_error_value, summary_value, accuracy_value, labels_values, predictions_values = sess.run(
                [
                    val_op, top1_error, summary_op, accuracy_tensor,
                    labels_tensor, predictions
                ], {
                    images_tensor: val_roi_batch_images,
                    expand_images_tensor: val_expand_roi_batch_images,
                    labels_tensor: val_labels,
                    is_training_tensor: False
                })
            predictions_values = np.argmax(predictions_values, axis=1)
            # accuracy = eval_accuracy(predictions_values, labels_values)
            calculate_acc_error(logits=predictions_values,
                                label=labels_values,
                                show=True)
            print('Validation top1 error %.2f, accuracy value %f' %
                  (top1_error_value, accuracy_value))
            val_summary_writer.add_summary(summary_value, step)