Example #1
def eval(i_ckpt):
    # does not perform multi-scale test. ms-test is in predict.py
    tf.reset_default_graph()

    if FLAGS.float_type == 16:
        print('\n< using tf.float16 >\n')
        float_type = tf.float16
    else:
        print('\n< using tf.float32 >\n')
        float_type = tf.float32

    input_size = FLAGS.test_image_size
    with tf.device('/cpu:0'):
        data_dir, img_mean, num_classes = find_data_path(FLAGS.database)
        images_filenames, labels_filenames = read_labeled_image_list(
            FLAGS.database, data_dir, 'val')

    images_pl = [tf.placeholder(tf.float32, [None, input_size, input_size, 3])]
    labels_pl = [tf.placeholder(tf.int32, [None, input_size, input_size, 1])]

    model = pspnet_mg.PSPNetMG(num_classes,
                               mode='val',
                               resnet=FLAGS.network,
                               data_format=FLAGS.data_format,
                               float_type=float_type,
                               has_aux_loss=False,
                               structure_in_paper=FLAGS.structure_in_paper)
    logits = model.inference(images_pl)
    probas_op = tf.nn.softmax(logits[0],
                              dim=1 if FLAGS.data_format == 'NCHW' else 3)
    # ========================= end of building model ================================

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False,
                            gpu_options=gpu_options,
                            allow_soft_placement=True)
    sess = tf.Session(config=config)
    # sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if i_ckpt is not None:
        loader = tf.train.Saver(max_to_keep=0)
        loader.restore(sess, i_ckpt)
        eval_step = i_ckpt.split('-')[-1]
        print('Successfully loaded model from %s at step=%s.' %
              (i_ckpt, eval_step))

    print('\n< eval process begins >\n')
    average_loss = 0.0
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    if FLAGS.test_max_iter is None:
        max_iter = len(images_filenames)
    else:
        max_iter = FLAGS.test_max_iter

    step = 0
    show_iter = max_iter // 20
    while step < max_iter:
        image = cv2.imread(images_filenames[step], 1)
        label = cv2.imread(labels_filenames[step], 0)
        label = np.reshape(label, [1, label.shape[0], label.shape[1], 1])
        if 'ADE' in FLAGS.database:  # the first label (0) of ADE is background.
            label -= 1

        imgsplitter = ImageSplitter(image, 1.0, FLAGS.color_switch, input_size,
                                    img_mean)
        feed_dict = {images_pl[0]: imgsplitter.get_split_crops()}
        [logits] = sess.run([probas_op], feed_dict=feed_dict)
        total_logits = imgsplitter.reassemble_crops(logits)
        if FLAGS.mirror == 1:
            image_mirror = image[:, ::-1]
            imgsplitter_mirror = ImageSplitter(image_mirror, 1.0,
                                               FLAGS.color_switch, input_size,
                                               img_mean)
            feed_dict = {images_pl[0]: imgsplitter_mirror.get_split_crops()}
            [logits_m] = sess.run([probas_op], feed_dict=feed_dict)
            logits_m = imgsplitter_mirror.reassemble_crops(logits_m)
            total_logits += logits_m[:, ::-1]

        prediction = np.argmax(total_logits, axis=-1)
        step += 1
        compute_confusion_matrix(label, prediction, confusion_matrix)
        if step % show_iter == 0:
            print('%s %s] %d / %d. iou updating' \
                  % (str(datetime.datetime.now()), str(os.getpid()), step, max_iter))
            compute_iou(confusion_matrix)
            print('imprecise loss', average_loss / step)  # note: average_loss is never accumulated in this function

    precision = compute_iou(confusion_matrix)
    coord.request_stop()
    coord.join(threads)

    return average_loss / max_iter, precision
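The helpers compute_confusion_matrix and compute_iou are imported from elsewhere in the repo and not shown. A minimal sketch consistent with how they are called above — an int64 confusion matrix accumulated in place, with pixels outside [0, num_classes) treated as ignore labels (an assumption) — could look like this:

import numpy as np

def compute_confusion_matrix(label, prediction, confusion_matrix):
    # Rows are ground truth, columns are predictions; accumulation is in place.
    num_classes = confusion_matrix.shape[0]
    label = label.flatten().astype(np.int64)
    prediction = prediction.flatten().astype(np.int64)
    valid = (label >= 0) & (label < num_classes)  # drop ignore pixels
    idx = num_classes * label[valid] + prediction[valid]
    confusion_matrix += np.bincount(idx, minlength=num_classes ** 2).reshape(
        num_classes, num_classes)
    return confusion_matrix

def compute_iou(confusion_matrix):
    # Per-class IoU = TP / (TP + FP + FN); mean over classes that occur.
    tp = np.diag(confusion_matrix).astype(np.float64)
    fp = confusion_matrix.sum(axis=0) - tp
    fn = confusion_matrix.sum(axis=1) - tp
    denom = tp + fp + fn
    iou = tp / np.maximum(denom, 1)
    miou = iou[denom > 0].mean()
    print('mIoU: %.4f' % miou)
    return miou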
Example #2
def infer(image_filename, i_ckpt):

    # < single gpu version >
    # < use FLAGS.batch_size as batch size, i.e. the number of crops run per session call >
    # < use FLAGS.weight_ckpt as i_ckpt >
    # < use FLAGS.database to indicate img_mean and num_classes >

    _, img_mean, num_classes = reader.find_data_path(FLAGS.database)
    crop_size = FLAGS.test_image_size
    # < network >
    model = pspnet_mg.PSPNetMG(num_classes, FLAGS.network, gpu_num(), three_convs_beginning=FLAGS.three_convs_beginning)
    images_pl = [tf.placeholder(tf.float32, [None, crop_size, crop_size, 3])]
    eval_probas_op = model.build_forward_ops(images_pl)

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)
    init = [tf.global_variables_initializer(), tf.local_variables_initializer()]
    sess.run(init)

    loader = tf.train.Saver(max_to_keep=0)
    loader.restore(sess, i_ckpt)

    scales = [1.0]
    if FLAGS.ms == 1:
        scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]

    def run_once(input_image):
        H, W, channel = input_image.shape

        # < in case that input_image is smaller than crop_size >
        dif_height = H - crop_size
        dif_width = W - crop_size
        if dif_height < 0 or dif_width < 0:
            input_image = helper.numpy_pad_image(input_image, dif_height, dif_width)
            H, W, channel = input_image.shape

        # < split this image into crops >
        split_crops = []
        heights = helper.decide_intersection(H, crop_size)
        widths = helper.decide_intersection(W, crop_size)
        for height in heights:
            for width in widths:
                image_crop = input_image[height:height + crop_size, width:width + crop_size]
                split_crops.append(image_crop[np.newaxis, :])

        # < run the crops through the network in chunks of FLAGS.batch_size >
        num_chunks = (len(split_crops) - 1) // FLAGS.batch_size + 1
        proba_crops_list = []
        for chunk_i in range(num_chunks):
            feed_dict = {}
            start = chunk_i * FLAGS.batch_size
            end = min((chunk_i+1)*FLAGS.batch_size, len(split_crops))
            feed_dict[images_pl[0]] = np.concatenate(split_crops[start:end])
            proba_crops_part = sess.run(eval_probas_op, feed_dict=feed_dict)
            proba_crops_list.append(proba_crops_part[0])

        proba_crops = np.concatenate(proba_crops_list)

        # < reassemble >
        reassemble = np.zeros((H, W, num_classes), np.float32)
        index = 0
        for height in heights:
            for width in widths:
                reassemble[height:height + crop_size, width:width + crop_size] += proba_crops[index]
                index += 1

        # < crop to original image >
        if dif_height < 0 or dif_width < 0:
            reassemble = helper.numpy_crop_image(reassemble, dif_height, dif_width)

        return reassemble

    testDir = '/home/lihang/data/cityscape/demoVideo/stuttgart_02'
    imgFiles = mytool.GetFileList(testDir, 'png')

    for idx, imgFile in enumerate(imgFiles):
        print(idx)
        if idx < 502:
            continue
        # NOTE: these ops are re-created on every iteration, so the graph keeps
        # growing; building them once outside the loop would be cheaper.
        img_contents = tf.read_file(imgFile)
        img = tf.image.decode_image(img_contents, channels=3)
        img.set_shape((None, None, 3))  # decode_image does not return a static shape.
        img = tf.cast(img, dtype=tf.float32)
        img -= img_mean

        orig_one_image = sess.run(img)
        orig_height, orig_width, channel = orig_one_image.shape
        total_proba = np.zeros((orig_height, orig_width, num_classes), dtype=np.float32)
        for scale in scales:
            if scale != 1.0:
                one_image = cv2.resize(orig_one_image, dsize=(0, 0), fx=scale, fy=scale)
            else:
                one_image = np.copy(orig_one_image)

            proba = run_once(one_image)
            if FLAGS.mirror == 1:
                proba_mirror = run_once(one_image[:, ::-1])
                proba += proba_mirror[:, ::-1]

            if scale != 1.0:
                proba = cv2.resize(proba, (orig_width, orig_height))

            total_proba += proba

        prediction = np.argmax(total_proba, axis=-1)

        # cv2.imwrite('./demo_examples/demo_prediction.png', prediction)
        if FLAGS.database == 'Cityscapes':
            fileName = imgFile.split('/')[-1]
            cv2.imwrite('./demo_examples/' + fileName,
                        cv2.cvtColor(helper_cityscapes.coloring(prediction), cv2.COLOR_BGR2RGB))
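run_once above leans on helpers that are not shown in this listing. A plausible sketch under stated assumptions — a crop stride of two thirds of crop_size for decide_intersection (the ratio is an assumption, not taken from this repo), and symmetric zero padding for numpy_pad_image (equivalent to mean-pixel padding, since the mean was already subtracted):

import numpy as np

def decide_intersection(total_length, crop_length, stride_ratio=2.0 / 3):
    # Offsets of a sliding window; the last crop is aligned to the border.
    stride = int(crop_length * stride_ratio)
    offsets = list(range(0, total_length - crop_length + 1, stride))
    if offsets[-1] + crop_length < total_length:
        offsets.append(total_length - crop_length)
    return offsets

def numpy_pad_image(image, dif_height, dif_width):
    # dif_* = image_dim - crop_size; negative values mean padding is needed.
    pad_h, pad_w = max(-dif_height, 0), max(-dif_width, 0)
    return np.pad(image, ((pad_h // 2, pad_h - pad_h // 2),
                          (pad_w // 2, pad_w - pad_w // 2), (0, 0)),
                  mode='constant')

def numpy_crop_image(proba, dif_height, dif_width):
    # Undo numpy_pad_image on the (H, W, num_classes) probability map.
    pad_h, pad_w = max(-dif_height, 0), max(-dif_width, 0)
    h_end = proba.shape[0] - (pad_h - pad_h // 2)
    w_end = proba.shape[1] - (pad_w - pad_w // 2)
    return proba[pad_h // 2:h_end, pad_w // 2:w_end]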
Example #3
def train(resume_step=None):
    # < preparing arguments >
    if FLAGS.float_type == 16:
        print('\n< using tf.float16 >\n')
        float_type = tf.float16
    else:
        print('\n< using tf.float32 >\n')
        float_type = tf.float32
    new_layer_names = FLAGS.new_layer_names
    if FLAGS.new_layer_names is not None:
        new_layer_names = new_layer_names.split(',')

    # < data set >
    data_list = FLAGS.subsets_for_training.split(',')
    if len(data_list) < 1:
        data_list = ['train']
    list_images = []
    list_labels = []
    with tf.device('/cpu:0'):
        reader = SegmentationImageReader(
            FLAGS.database,
            data_list, (FLAGS.train_image_size, FLAGS.train_image_size),
            FLAGS.random_scale,
            random_mirror=True,
            random_blur=True,
            random_rotate=FLAGS.random_rotate,
            color_switch=FLAGS.color_switch,
            scale_rate=(FLAGS.scale_min, FLAGS.scale_max))
        for _ in xrange(FLAGS.gpu_num):
            image_batch, label_batch = reader.dequeue(FLAGS.batch_size)
            list_images.append(image_batch)
            list_labels.append(label_batch)

    # < network >
    model = pspnet_mg.PSPNetMG(
        reader.num_classes,
        mode='train',
        resnet=FLAGS.network,
        bn_mode='frozen' if FLAGS.bn_frozen else 'gather',
        data_format=FLAGS.data_format,
        initializer=FLAGS.initializer,
        fine_tune_filename=FLAGS.fine_tune_filename,
        wd_mode=FLAGS.weight_decay_mode,
        gpu_num=FLAGS.gpu_num,
        float_type=float_type,
        has_aux_loss=FLAGS.has_aux_loss,
        train_like_in_paper=FLAGS.train_like_in_paper,
        structure_in_paper=FLAGS.structure_in_paper,
        new_layer_names=new_layer_names,
        loss_type=FLAGS.loss_type,
        consider_dilated=FLAGS.consider_dilated)
    train_ops = model.build_train_ops(list_images, list_labels)

    # < log dir and model id >
    logdir = LogDir(FLAGS.database, model_id())
    logdir.print_all_info()
    if not os.path.exists(logdir.log_dir):
        print('creating ', logdir.log_dir, '...')
        os.mkdir(logdir.log_dir)
    if not os.path.exists(logdir.database_dir):
        print('creating ', logdir.database_dir, '...')
        os.mkdir(logdir.database_dir)
    if not os.path.exists(logdir.exp_dir):
        print('creating ', logdir.exp_dir, '...')
        os.mkdir(logdir.exp_dir)
    if not os.path.exists(logdir.snapshot_dir):
        print('creating ', logdir.snapshot_dir, '...')
        os.mkdir(logdir.snapshot_dir)

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False,
                            gpu_options=gpu_options,
                            allow_soft_placement=True)
    sess = tf.Session(config=config)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    init = [
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    ]
    sess.run(init)

    # < convert npy to .ckpt >
    step = 0
    if '.npy' in FLAGS.fine_tune_filename:
        # Converts .npy weights into a tf ckpt, provided the variable names match.
        fine_tune_variables = []
        npy_dict = np.load(FLAGS.fine_tune_filename).item()
        new_layers_names = ['Momentum']
        for v in tf.global_variables():
            if any(elem in v.name for elem in new_layers_names):
                continue

            name = v.name.split(':0')[0]
            if name not in npy_dict:
                continue

            v.load(npy_dict[name], sess)
            fine_tune_variables.append(v)

        saver = tf.train.Saver(var_list=fine_tune_variables)
        saver.save(sess, logdir.snapshot_dir + '/model.ckpt', global_step=0)
        return

    # < load pre-trained model>
    import_variables = tf.trainable_variables()
    if FLAGS.fine_tune_filename is not None and resume_step is None:
        fine_tune_variables = []
        new_layers_names = model.new_layers_names
        new_layers_names.append('Momentum')
        new_layers_names.append('up_sample')
        for v in import_variables:
            if any(elem in v.name for elem in new_layers_names):
                print('< Finetuning Process: not importing %s >' % v.name)
                continue
            fine_tune_variables.append(v)

        loader = tf.train.Saver(var_list=fine_tune_variables, allow_empty=True)
        loader.restore(sess, FLAGS.fine_tune_filename)
        print('< Successfully loaded fine-tune model from %s. >' %
              FLAGS.fine_tune_filename)
    elif resume_step is not None:
        # ./snapshot/model.ckpt-3000
        i_ckpt = logdir.snapshot_dir + '/model.ckpt-%d' % resume_step

        loader = tf.train.Saver(max_to_keep=0)
        loader.restore(sess, i_ckpt)

        step = resume_step
        print('< Successfully loaded model from %s at step=%s. >' %
              (i_ckpt, resume_step))
    else:
        print('< Not importing any model. >')

    f_log = open(logdir.exp_dir + '/' + str(datetime.datetime.now()) + '.txt',
                 'w')
    f_log.write('step,loss,precision,wd\n')
    f_log.write(sorted_str_dict(FLAGS.__dict__) + '\n')

    print('\n< training process begins >\n')
    average_loss = 0.0
    show_period = 20
    snapshot = FLAGS.snapshot
    max_iter = FLAGS.train_max_iter
    lrn_rate = FLAGS.lrn_rate

    lr_step = []
    if FLAGS.lr_step is not None:
        temps = FLAGS.lr_step.split(',')
        for t in temps:
            lr_step.append(int(t))

    saver = tf.train.Saver(max_to_keep=2)
    t0 = None
    wd_rate = FLAGS.weight_decay_rate
    wd_rate2 = FLAGS.weight_decay_rate2

    if FLAGS.save_first_iteration == 1:
        saver.save(sess, logdir.snapshot_dir + '/model.ckpt', global_step=step)

    has_nan = False
    while step < max_iter + 1:
        if FLAGS.poly_lr == 1:
            lrn_rate = ((1 - 1.0 * step / max_iter)**0.9) * FLAGS.lrn_rate

        step += 1
        if len(lr_step) > 0 and step == lr_step[0]:
            lrn_rate *= FLAGS.step_size
            lr_step.remove(step)

        _, loss, wd, precision = sess.run(
            [train_ops, model.loss, model.wd, model.precision_op],
            feed_dict={
                model.lrn_rate_ph: lrn_rate,
                model.wd_rate_ph: wd_rate,
                model.wd_rate2_ph: wd_rate2
            })

        if math.isnan(loss) or math.isnan(wd):
            print('\nloss or weight norm is nan. Training Stopped!\n')
            has_nan = True
            break

        average_loss += loss

        if step % snapshot == 0:
            saver.save(sess,
                       logdir.snapshot_dir + '/model.ckpt',
                       global_step=step)
            sess.run([tf.local_variables_initializer()])

        if step % show_period == 0:
            left_hours = 0

            if t0 is not None:
                delta_t = (datetime.datetime.now() - t0).total_seconds()
                left_time = (max_iter - step) / show_period * delta_t
                left_hours = left_time / 3600.0

            t0 = datetime.datetime.now()
            average_loss /= show_period

            f_log.write('%d,%f,%f,%f\n' % (step, average_loss, precision, wd))
            f_log.flush()

            print('%s %s] Step %s, lr = %f, wd_rate = %f, wd_rate_2 = %f ' \
                  % (str(datetime.datetime.now()), str(os.getpid()), step, lrn_rate, wd_rate, wd_rate2))
            print('\t loss = %.4f, precision = %.4f, wd = %.4f' %
                  (average_loss, precision, wd))
            print('\t estimated time left: %.1f hours. %d/%d' %
                  (left_hours, step, max_iter))

            average_loss = 0.0

    coord.request_stop()
    coord.join(threads)

    return f_log, logdir, has_nan  # f_log and logdir returned for eval.
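The learning-rate update at the top of the training loop is the "poly" policy from the PSPNet paper: lr = base_lr * (1 - step / max_iter) ** 0.9, recomputed every step and fed into the graph through model.lrn_rate_ph. A standalone illustration with made-up values:

base_lr, max_iter, power = 0.01, 90000, 0.9  # illustrative values only
for step in (0, 30000, 60000, 89999):
    lrn_rate = ((1 - 1.0 * step / max_iter) ** power) * base_lr
    print('step %6d -> lr %.6f' % (step, lrn_rate))  # decays from base_lr toward 0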
Example #4
def predict(i_ckpt):
    tf.reset_default_graph()

    print('================')
    if FLAGS.data_type == 16:
        # print('using tf.float16 =====================')
        data_type = tf.float16
    else:
        # print('using tf.float32 =====================')
        data_type = tf.float32

    image_size = FLAGS.test_image_size
    # print('=====because using pspnet, the inputs have a fixed size and should be divisible by 48:', image_size)
    assert FLAGS.test_image_size % 48 == 0

    num_classes = 2
    IMG_MEAN = np.array((103.939, 116.779, 123.68), dtype=np.float32)
    with tf.device('/cpu:0'):
        coord = tf.train.Coordinator()
        reader = ImageReader('./infer', 'test.txt', '480,480', 'False',
                             'False', 255, IMG_MEAN, coord)

    images_pl = [tf.placeholder(tf.float32, [None, image_size, image_size, 3])]
    labels_pl = [tf.placeholder(tf.int32, [None, image_size, image_size, 1])]

    with tf.variable_scope('resnet_v1_50'):
        model = pspnet_mg.PSPNetMG(
            num_classes,
            None,
            None,
            None,
            mode=FLAGS.mode,
            bn_epsilon=FLAGS.epsilon,
            resnet='resnet_v1_50',
            norm_only=FLAGS.norm_only,
            float_type=data_type,
            has_aux_loss=False,
            structure_in_paper=FLAGS.structure_in_paper,
            resize_images_method=FLAGS.resize_images_method)
        model.inference(images_pl)  # builds the graph; outputs are read via model.probabilities
    # ========================= end of building model ================================

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False,
                            gpu_options=gpu_options,
                            allow_soft_placement=True)
    sess = tf.Session(config=config)
    sess.run(
        [tf.global_variables_initializer(),
         tf.local_variables_initializer()])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if i_ckpt is not None:
        loader = tf.train.Saver(max_to_keep=0)
        loader.restore(sess, i_ckpt)
        eval_step = i_ckpt.split('-')[-1]
        # print('Successfully loaded model from %s at step=%s.' % (i_ckpt, eval_step))

    print('======================= eval process begins =========================')
    if FLAGS.save_prediction == 0 and FLAGS.mode != 'test':
        print('not saving prediction ... ')

    average_loss = 0.0
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    if FLAGS.save_prediction == 1 or FLAGS.mode == 'test':
        try:
            os.mkdir('./' + FLAGS.mode + '_set')
        except OSError:
            pass
        prefix = './' + FLAGS.mode + '_set'
        try:
            os.mkdir(os.path.join(prefix, FLAGS.weights_ckpt.split('/')[-2]))
        except OSError:
            pass
        prefix = os.path.join(prefix, FLAGS.weights_ckpt.split('/')[-2])

    if FLAGS.ms == 1:
        scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    else:
        scales = [1.0]

    images_filenames = reader.image_list
    # labels_filenames = reader.label_list

    if FLAGS.test_max_iter is None:
        max_iter = len(images_filenames)
    else:
        max_iter = FLAGS.test_max_iter

    # IMG_MEAN = [123.680000305, 116.778999329, 103.939002991]  # RGB mean from official PSPNet

    step = 0
    while step < max_iter:
        image = cv2.imread(images_filenames[step], 1)
        # label = np.reshape(label, [1, label.shape[0], label.shape[1], 1])
        image_height, image_width = image.shape[0], image.shape[1]

        total_logits = np.zeros((image_height, image_width, num_classes),
                                np.float32)
        for scale in scales:
            imgsplitter = ImageSplitter(image, scale, FLAGS.color_switch,
                                        image_size, IMG_MEAN)
            crops = imgsplitter.get_split_crops()

            # Suboptimal: larger batches per sess.run are faster, but the
            # maximum feasible batch size is unknown.
            # TODO: find a more efficient way.
            if crops.shape[0] > 10:
                half = crops.shape[0] // 2

                feed_dict = {images_pl[0]: crops[0:half]}
                [logits_0] = sess.run([model.probabilities],
                                      feed_dict=feed_dict)

                feed_dict = {images_pl[0]: crops[half:]}
                [logits_1] = sess.run([model.probabilities],
                                      feed_dict=feed_dict)
                logits = np.concatenate((logits_0, logits_1), axis=0)
            else:
                feed_dict = {images_pl[0]: imgsplitter.get_split_crops()}
                [logits] = sess.run([model.probabilities], feed_dict=feed_dict)
            scale_logits = imgsplitter.reassemble_crops(logits)

            if FLAGS.mirror == 1:
                image_mirror = image[:, ::-1]
                imgsplitter_mirror = ImageSplitter(image_mirror, scale,
                                                   FLAGS.color_switch,
                                                   image_size, IMG_MEAN)
                crops_m = imgsplitter_mirror.get_split_crops()
                if crops_m.shape[0] > 10:
                    half = crops_m.shape[0] // 2

                    feed_dict = {images_pl[0]: crops_m[0:half]}
                    [logits_0] = sess.run([model.probabilities],
                                          feed_dict=feed_dict)

                    feed_dict = {images_pl[0]: crops_m[half:]}
                    [logits_1] = sess.run([model.probabilities],
                                          feed_dict=feed_dict)
                    logits_m = np.concatenate((logits_0, logits_1), axis=0)
                else:
                    feed_dict = {
                        images_pl[0]: imgsplitter_mirror.get_split_crops()
                    }
                    [logits_m] = sess.run([model.probabilities],
                                          feed_dict=feed_dict)
                logits_m = imgsplitter_mirror.reassemble_crops(logits_m)
                scale_logits += logits_m[:, ::-1]

            if scale != 1.0:
                scale_logits = cv2.resize(scale_logits,
                                          (image_width, image_height),
                                          interpolation=cv2.INTER_LINEAR)

            total_logits += scale_logits

        prediction = np.argmax(total_logits, axis=-1)
        # print(np.max(label), np.max(prediction))

        if FLAGS.database == 'Cityscapes' and (FLAGS.save_prediction == 1
                                               or FLAGS.mode == 'test'):
            image_prefix = images_filenames[step].split('/')[-1].split('_leftImg8bit.png')[0] + '_' \
                           + FLAGS.weights_ckpt.split('/')[-2]

            cv2.imwrite(os.path.join(prefix, image_prefix + '_prediction.png'),
                        trainid_to_labelid(prediction))
            if FLAGS.coloring == 1:
                color_prediction = coloring(prediction)
                cv2.imwrite(
                    os.path.join(prefix, image_prefix + '_coloring.png'),
                    cv2.cvtColor(color_prediction, cv2.COLOR_BGR2RGB))
        elif FLAGS.database == 'sonardata' and (FLAGS.save_prediction == 1
                                                or FLAGS.mode == 'test'):
            image_prefix = images_filenames[step].split('/')[-1].split(
                '.png')[0]
            cv2.imwrite(os.path.join(prefix, image_prefix + '.png'),
                        prediction)
        else:
            pass

        step += 1

    coord.request_stop()
    coord.join(threads)

    return average_loss / max_iter
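ImageSplitter, used in every evaluation loop here, is also external. Judging from the call sites, it rescales the image, subtracts img_mean, optionally swaps color channels, cuts overlapping crop_size tiles, and reassemble_crops sums per-tile probabilities back into a full-resolution map. A hypothetical re-implementation of that contract (ignoring the padding the real class presumably applies to images smaller than crop_size, and reusing decide_intersection from the sketch at the end of Example #2):

import cv2
import numpy as np

class ImageSplitter(object):
    # Hypothetical sketch of the interface used above, not the repo's code.
    def __init__(self, image, scale, color_switch, crop_size, img_mean):
        if scale != 1.0:
            image = cv2.resize(image, (0, 0), fx=scale, fy=scale)
        image = image.astype(np.float32) - img_mean
        if color_switch:  # assumed meaning: swap BGR <-> RGB
            image = image[:, :, ::-1]
        self.crop_size = crop_size
        self.shape = image.shape[:2]
        self.offsets = [(h, w)
                        for h in decide_intersection(self.shape[0], crop_size)
                        for w in decide_intersection(self.shape[1], crop_size)]
        self.crops = np.stack([image[h:h + crop_size, w:w + crop_size]
                               for h, w in self.offsets])

    def get_split_crops(self):
        return self.crops  # (num_crops, crop_size, crop_size, 3)

    def reassemble_crops(self, proba_crops):
        out = np.zeros(self.shape + (proba_crops.shape[-1],), np.float32)
        for (h, w), proba in zip(self.offsets, proba_crops):
            out[h:h + self.crop_size, w:w + self.crop_size] += proba
        return out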
Example #5
def predict(i_ckpt):
    assert i_ckpt is not None

    if FLAGS.float_type == 16:
        print('\n< using tf.float16 >\n')
        float_type = tf.float16
    else:
        print('\n< using tf.float32 >\n')
        float_type = tf.float32

    image_size = FLAGS.test_image_size
    assert FLAGS.test_image_size % 48 == 0

    with tf.device('/cpu:0'):
        reader = SegmentationImageReader(FLAGS.database,
                                         FLAGS.mode, (image_size, image_size),
                                         random_scale=False,
                                         random_mirror=False,
                                         random_blur=False,
                                         random_rotate=False,
                                         color_switch=FLAGS.color_switch)

    images_pl = [tf.placeholder(tf.float32, [None, image_size, image_size, 3])]

    model = pspnet_mg.PSPNetMG(reader.num_classes,
                               mode='val',
                               resnet=FLAGS.network,
                               data_format=FLAGS.data_format,
                               float_type=float_type,
                               has_aux_loss=False,
                               structure_in_paper=FLAGS.structure_in_paper)
    logits = model.inference(images_pl)
    probas_op = tf.nn.softmax(logits[0],
                              dim=1 if FLAGS.data_format == 'NCHW' else 3)
    # ========================= end of building model ================================

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False,
                            gpu_options=gpu_options,
                            allow_soft_placement=True)
    sess = tf.Session(config=config)
    sess.run(
        [tf.global_variables_initializer(),
         tf.local_variables_initializer()])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    loader = tf.train.Saver(max_to_keep=0)
    loader.restore(sess, i_ckpt)
    print('Successfully loaded model from %s.' % i_ckpt)

    print(
        '======================= eval process begins ========================='
    )
    if FLAGS.save_prediction == 0 and FLAGS.mode != 'test':
        print('not saving prediction ... ')

    average_loss = 0.0
    confusion_matrix = np.zeros((reader.num_classes, reader.num_classes),
                                dtype=np.int64)

    if FLAGS.save_prediction == 1 or FLAGS.mode == 'test':
        try:
            os.mkdir('./' + FLAGS.mode + '_set')
        except OSError:
            pass
        prefix = './' + FLAGS.mode + '_set'
        try:
            os.mkdir(os.path.join(prefix, FLAGS.weights_ckpt.split('/')[-2]))
        except OSError:
            pass
        prefix = os.path.join(prefix, FLAGS.weights_ckpt.split('/')[-2])

    if FLAGS.ms == 1:
        scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    else:
        scales = [1.0]

    images_filenames = reader.image_list
    labels_filenames = reader.label_list
    img_mean = reader.img_mean

    if FLAGS.test_max_iter is None:
        max_iter = len(images_filenames)
    else:
        max_iter = FLAGS.test_max_iter

    step = 0
    while step < max_iter:
        image = cv2.imread(images_filenames[step], 1)
        label = cv2.imread(labels_filenames[step], 0)
        label = np.reshape(label, [1, label.shape[0], label.shape[1], 1])
        image_height, image_width = image.shape[0], image.shape[1]

        total_logits = np.zeros(
            (image_height, image_width, reader.num_classes), np.float32)
        for scale in scales:
            imgsplitter = ImageSplitter(image, scale, FLAGS.color_switch,
                                        image_size, img_mean)
            crops = imgsplitter.get_split_crops()

            # Suboptimal: larger batches per sess.run are faster, but the
            # maximum feasible batch size is unknown.
            # TODO: find a more efficient way.
            if crops.shape[0] > 10 and FLAGS.database == 'Cityscapes':
                half = crops.shape[0] // 2

                feed_dict = {images_pl[0]: crops[0:half]}
                [logits_0] = sess.run([probas_op], feed_dict=feed_dict)

                feed_dict = {images_pl[0]: crops[half:]}
                [logits_1] = sess.run([probas_op], feed_dict=feed_dict)
                logits = np.concatenate((logits_0, logits_1), axis=0)
            else:
                feed_dict = {images_pl[0]: imgsplitter.get_split_crops()}
                [logits] = sess.run([probas_op], feed_dict=feed_dict)
            scale_logits = imgsplitter.reassemble_crops(logits)

            if FLAGS.mirror == 1:
                image_mirror = image[:, ::-1]
                imgsplitter_mirror = ImageSplitter(image_mirror, scale,
                                                   FLAGS.color_switch,
                                                   image_size, img_mean)
                crops_m = imgsplitter_mirror.get_split_crops()
                if crops_m.shape[0] > 10:
                    half = crops_m.shape[0] // 2

                    feed_dict = {images_pl[0]: crops_m[0:half]}
                    [logits_0] = sess.run([probas_op], feed_dict=feed_dict)

                    feed_dict = {images_pl[0]: crops_m[half:]}
                    [logits_1] = sess.run([probas_op], feed_dict=feed_dict)
                    logits_m = np.concatenate((logits_0, logits_1), axis=0)
                else:
                    feed_dict = {
                        images_pl[0]: imgsplitter_mirror.get_split_crops()
                    }
                    [logits_m] = sess.run([probas_op], feed_dict=feed_dict)
                logits_m = imgsplitter_mirror.reassemble_crops(logits_m)
                scale_logits += logits_m[:, ::-1]

            if scale != 1.0:
                scale_logits = cv2.resize(scale_logits,
                                          (image_width, image_height),
                                          interpolation=cv2.INTER_LINEAR)

            total_logits += scale_logits

        prediction = np.argmax(total_logits, axis=-1)
        # print(np.max(label), np.max(prediction))

        image_prefix = images_filenames[step].split('/')[-1].split(
            '.')[0] + '_' + FLAGS.weights_ckpt.split('/')[-2]
        if FLAGS.database == 'Cityscapes':
            cv2.imwrite(os.path.join(prefix, image_prefix + '_prediction.png'),
                        trainid_to_labelid(prediction))
            if FLAGS.coloring == 1:
                cv2.imwrite(
                    os.path.join(prefix, image_prefix + '_coloring.png'),
                    cv2.cvtColor(coloring(prediction), cv2.COLOR_BGR2RGB))
        else:
            cv2.imwrite(os.path.join(prefix, image_prefix + '_prediction.png'),
                        prediction)
            # TODO: add coloring for databases other than Cityscapes.

        step += 1

        compute_confusion_matrix(label, prediction, confusion_matrix)
        if step % 20 == 0:
            print('%s %s] %d / %d. iou updating' \
                  % (str(datetime.datetime.now()), str(os.getpid()), step, max_iter))
            compute_iou(confusion_matrix)
            print(average_loss / step)  # note: average_loss is never accumulated in this function

    precision = compute_iou(confusion_matrix)
    coord.request_stop()
    coord.join(threads)

    return average_loss / max_iter, precision
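trainid_to_labelid converts the 19 Cityscapes training ids back to official label ids before writing the submission PNGs. A sketch using the standard mapping from the cityscapesScripts label definitions:

import numpy as np

# Standard Cityscapes trainId -> labelId lookup (from cityscapesScripts).
TRAINID_TO_LABELID = np.array(
    [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33],
    dtype=np.uint8)

def trainid_to_labelid(prediction):
    # prediction contains trainIds in [0, 19); index into the lookup table.
    return TRAINID_TO_LABELID[prediction]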
Example #6
def train_and_eval():
    # < data set >
    data_list = FLAGS.subsets_for_training.split(',')
    if len(data_list) < 1:
        data_list = ['train']

    train_reader_inits = []
    eval_reader_inits = []
    with tf.device('/cpu:0'):
        if FLAGS.reader_method == 'queue':
            train_image_reader = reader.QueueBasedImageReader(
                FLAGS.database, data_list)
            batch_images, batch_labels = train_image_reader.get_batch(
                FLAGS.batch_size * gpu_num(), FLAGS.train_image_size,
                FLAGS.random_mirror, FLAGS.random_blur, FLAGS.random_rotate,
                FLAGS.color_switch, FLAGS.random_scale,
                (FLAGS.scale_min, FLAGS.scale_max))
            list_images = tf.split(batch_images, gpu_num())
            list_labels = tf.split(batch_labels, gpu_num())

            eval_image_reader = reader.QueueBasedImageReader(
                FLAGS.database, 'val')
            eval_image, eval_label, _ = eval_image_reader.get_eval_batch(
                FLAGS.color_switch)
        else:
            # performance is not as good as with queue runners.
            train_image_reader = reader.ImageReader(FLAGS.database, data_list)
            train_reader_iterator = train_image_reader.get_batch_iterator(
                FLAGS.batch_size * gpu_num(), FLAGS.train_image_size,
                FLAGS.random_mirror, FLAGS.random_blur, FLAGS.random_rotate,
                FLAGS.color_switch, FLAGS.random_scale,
                (FLAGS.scale_min, FLAGS.scale_max))
            batch_images, batch_labels = train_reader_iterator.get_next()
            list_images = tf.split(batch_images, gpu_num())
            list_labels = tf.split(batch_labels, gpu_num())

            eval_image_reader = reader.ImageReader(FLAGS.database, 'val')
            eval_reader_iterator = eval_image_reader.get_eval_iterator(
                FLAGS.color_switch)
            eval_image, eval_label, _ = eval_reader_iterator.get_next(
            )  # one image.

            train_reader_inits.append(train_reader_iterator.initializer)
            eval_reader_inits.append(eval_reader_iterator.initializer)

    # < network >
    model = pspnet_mg.PSPNetMG(train_image_reader.num_classes,
                               FLAGS.network,
                               gpu_num(),
                               FLAGS.initializer,
                               FLAGS.weight_decay_mode,
                               FLAGS.fine_tune_filename,
                               FLAGS.optimizer,
                               FLAGS.momentum,
                               FLAGS.train_like_in_caffe,
                               FLAGS.three_convs_beginning,
                               FLAGS.new_layer_names,
                               consider_dilated=FLAGS.consider_dilated)
    train_ops, losses_op, metrics_op = model.build_train_ops(
        list_images, list_labels)

    eval_image_pl = []
    crop_size = FLAGS.test_image_size
    for _ in range(gpu_num()):
        eval_image_pl.append(
            tf.placeholder(tf.float32, [None, crop_size, crop_size, 3]))
    eval_probas_op = model.build_forward_ops(eval_image_pl)

    # < log dir and model id >
    exp_dir, snapshot_dir = prepare_log_dir(FLAGS.database, get_model_id())

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False,
                            gpu_options=gpu_options,
                            allow_soft_placement=True)
    sess = tf.Session(config=config)

    if FLAGS.reader_method == 'queue':
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    init = [
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    ] + train_reader_inits
    sess.run(init)

    # < load pre-trained model>
    import_variables = tf.trainable_variables()
    if FLAGS.fine_tune_filename is not None:
        fine_tune_variables = []
        new_layers_names = model.new_layers_names
        new_layers_names.append('Momentum')
        new_layers_names.append('up_sample')
        for v in import_variables:
            if any(elem in v.name for elem in new_layers_names):
                print('\t[verbose] < Finetuning Process: not importing %s >' %
                      v.name)
                continue
            fine_tune_variables.append(v)

        loader = tf.train.Saver(var_list=fine_tune_variables, allow_empty=True)
        loader.restore(sess, FLAGS.fine_tune_filename)
        print('\t[verbose] < Successfully loaded fine-tune model from %s. >' %
              FLAGS.fine_tune_filename)
    else:
        print('\t[verbose] < Not importing any model. >')

    f_log = open(exp_dir + '/' + str(datetime.datetime.now()) + '.txt', 'w')
    tags = ''
    for loss_op in losses_op:
        tags += loss_op.name.split('/')[-1].split(':')[0] + ','
    for metric_op in metrics_op:
        tags += metric_op.name.split('/')[-1].split(':')[0] + ','
    tags = tags[:-1]
    f_log.write(tags + '\n')
    f_log.write(sorted_str_dict(FLAGS.__dict__) + '\n')

    print('\n\t < training process begins >\n')
    show_period = FLAGS.train_max_iter // 2000
    snapshot = FLAGS.snapshot
    max_iter = FLAGS.train_max_iter
    lrn_rate = FLAGS.lrn_rate

    lr_step = []
    if FLAGS.lr_step is not None:
        temps = FLAGS.lr_step.split(',')
        for t in temps:
            lr_step.append(int(t))

    saver = tf.train.Saver(max_to_keep=2)
    t0 = None
    wd_rate = FLAGS.weight_decay_rate
    wd_rate2 = FLAGS.weight_decay_rate2
    has_nan = False
    step = 0

    if FLAGS.save_first_iteration == 1:
        saver.save(sess, snapshot_dir + '/model.ckpt', global_step=step)

    def run_for_eval(input_image):
        H, W, channel = input_image.shape

        # < in case that input_image is smaller than crop_size >
        dif_height = H - crop_size
        dif_width = W - crop_size
        if dif_height < 0 or dif_width < 0:
            input_image = helper.numpy_pad_image(input_image, dif_height,
                                                 dif_width)
            H, W, channel = input_image.shape

        # < split >
        split_crops = []
        heights = helper.decide_intersection(H, crop_size)
        widths = helper.decide_intersection(W, crop_size)
        for height in heights:
            for width in widths:
                image_crop = input_image[height:height + crop_size,
                                         width:width + crop_size]
                split_crops.append(image_crop[np.newaxis, :])

        feed_dict = {}
        splitters = chunks(split_crops, gpu_num())
        for list_index in range(len(splitters) - 1):
            piece_crops = np.concatenate(
                split_crops[splitters[list_index]:splitters[list_index + 1]])
            feed_dict[eval_image_pl[list_index]] = piece_crops

        for i in range(gpu_num()):
            if eval_image_pl[i] not in feed_dict:
                feed_dict[eval_image_pl[i]] = np.zeros(
                    (1, crop_size, crop_size, 3), np.float32)

        proba_crops_pieces = sess.run(eval_probas_op, feed_dict=feed_dict)
        proba_crops = np.concatenate(proba_crops_pieces)

        # < reassemble >
        reassemble = np.zeros((H, W, eval_image_reader.num_classes),
                              np.float32)
        index = 0
        for height in heights:
            for width in widths:
                reassemble[height:height + crop_size,
                           width:width + crop_size] += proba_crops[index]
                index += 1

        # < crop to original image >
        if dif_height < 0 or dif_width < 0:
            reassemble = helper.numpy_crop_image(reassemble, dif_height,
                                                 dif_width)

        return reassemble

    while step < max_iter + 1:
        if FLAGS.poly_lr == 1:
            lrn_rate = ((1 - 1.0 * step / max_iter)**0.9) * FLAGS.lrn_rate

        step += 1
        if len(lr_step) > 0 and step == lr_step[0]:
            lrn_rate *= FLAGS.step_size
            lr_step.remove(step)

        _, losses, metrics = sess.run(
            [train_ops, losses_op, metrics_op],
            feed_dict={
                model.lrn_rate_ph: lrn_rate,
                model.wd_rate_ph: wd_rate,
                model.wd_rate2_ph: wd_rate2
            })

        if math.isnan(losses[0]) or math.isnan(losses[-1]):
            print('\nloss or weight norm is nan. Training Stopped!\n')
            has_nan = True
            break

        if step % show_period == 0:
            left_hours = 0
            if t0 is not None:
                delta_t = (datetime.datetime.now() - t0).total_seconds()
                left_time = (max_iter - step) / show_period * delta_t
                left_hours = left_time / 3600.0
            t0 = datetime.datetime.now()

            # these losses are not averaged.
            merged_losses = losses + metrics
            str_merged_loss = str(step) + ','
            for i, l in enumerate(merged_losses):
                if i == len(merged_losses) - 1:
                    str_merged_loss += str(l) + '\n'
                else:
                    str_merged_loss += str(l) + ','
            f_log.write(str_merged_loss)
            f_log.flush()

            print(
                '%s %s] Step %d, lr = %f, wd_mode = %d, wd_rate = %f, wd_rate_2 = %f '
                % (str(datetime.datetime.now()), str(os.getpid()), step,
                   lrn_rate, FLAGS.weight_decay_mode, wd_rate, wd_rate2))
            for i, tag in enumerate(tags.split(',')):
                print(tag, '=', merged_losses[i], end=', ')
            print('')
            print('\tEstimated time left: %.2f hours. %d/%d' %
                  (left_hours, step, max_iter))

        if step % snapshot == 0 or step == max_iter:
            saver.save(sess, snapshot_dir + '/model.ckpt', global_step=step)
            confusion_matrix = np.zeros(
                (eval_image_reader.num_classes, eval_image_reader.num_classes),
                dtype=np.int64)
            sess.run([tf.local_variables_initializer()] + eval_reader_inits)
            for i in range(len(eval_image_reader.image_list)):
                orig_one_image, one_label = sess.run([eval_image, eval_label])
                proba = run_for_eval(orig_one_image)
                prediction = np.argmax(proba, axis=-1)
                helper.compute_confusion_matrix(one_label, prediction,
                                                confusion_matrix)

            mIoU = helper.compute_iou(confusion_matrix)
            str_merged_loss = 'TEST:' + str(step) + ',' + str(mIoU) + '\n'
            f_log.write(str_merged_loss)
            f_log.flush()

    f_log.close()

    if FLAGS.reader_method == 'queue':
        coord.request_stop()
        coord.join(threads)
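run_for_eval distributes the crops over GPUs via a chunks helper. From its usage (splitters[i] indexing into split_crops), it must return slice boundaries; a sketch consistent with that contract, with unused GPUs later fed dummy zero crops as in the code above:

import math

def chunks(items, num_chunks):
    # Hypothetical helper: boundaries such that items[b[i]:b[i + 1]] is the
    # i-th near-equal piece; may yield fewer than num_chunks pieces.
    per = int(math.ceil(len(items) / float(num_chunks)))
    return list(range(0, len(items), per)) + [len(items)]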
Example #7
def inference(i_ckpt):
    if FLAGS.float_type == 16:
        print('\n< using tf.float16 >\n')
        float_type = tf.float16
    else:
        print('\n< using tf.float32 >\n')
        float_type = tf.float32

    image_size = FLAGS.test_image_size
    assert FLAGS.test_image_size % 48 == 0

    images_pl = [tf.placeholder(tf.float32, [None, image_size, image_size, 3])]
    data_dir, img_mean, num_classes = find_data_path(FLAGS.database)
    model = pspnet_mg.PSPNetMG(num_classes,
                               mode='val',
                               resnet=FLAGS.network,
                               data_format=FLAGS.data_format,
                               float_type=float_type,
                               has_aux_loss=False,
                               structure_in_paper=FLAGS.structure_in_paper)
    logits = model.inference(images_pl)
    probas_op = tf.nn.softmax(logits[0],
                              dim=1 if FLAGS.data_format == 'NCHW' else 3)
    # ========================= end of building model ================================

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False,
                            gpu_options=gpu_options,
                            allow_soft_placement=True)
    sess = tf.Session(config=config)
    sess.run(
        [tf.global_variables_initializer(),
         tf.local_variables_initializer()])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if i_ckpt is not None:
        loader = tf.train.Saver(max_to_keep=0)
        loader.restore(sess, i_ckpt)
        eval_step = i_ckpt.split('-')[-1]
        print('Successfully loaded model from %s at step=%s.' %
              (i_ckpt, eval_step))

    print(
        '======================= eval process begins ========================='
    )
    try:
        os.mkdir('./inference_set')
    except OSError:
        pass
    prefix = './inference_set'
    try:
        os.mkdir(os.path.join(prefix, FLAGS.weights_ckpt.split('/')[-2]))
    except OSError:
        pass
    prefix = os.path.join(prefix, FLAGS.weights_ckpt.split('/')[-2])

    if FLAGS.ms == 1:
        scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    else:
        scales = [1.0]

    def inf_one_image(image_path):
        t0 = datetime.datetime.now()
        image = cv2.imread(image_path, 1)
        image_height, image_width = image.shape[0], image.shape[1]

        total_logits = np.zeros((image_height, image_width, num_classes),
                                np.float32)
        for scale in scales:
            imgsplitter = ImageSplitter(image, scale, FLAGS.color_switch,
                                        image_size, img_mean)
            crops = imgsplitter.get_split_crops()

            # Suboptimal: larger batches per sess.run are faster, but the
            # maximum feasible batch size is unknown.
            # TODO: find a more efficient way.
            if crops.shape[0] > 10 and FLAGS.database == 'Cityscapes':
                half = crops.shape[0] // 2

                feed_dict = {images_pl[0]: crops[0:half]}
                [logits_0] = sess.run([probas_op], feed_dict=feed_dict)

                feed_dict = {images_pl[0]: crops[half:]}
                [logits_1] = sess.run([probas_op], feed_dict=feed_dict)
                logits = np.concatenate((logits_0, logits_1), axis=0)
            else:
                feed_dict = {images_pl[0]: imgsplitter.get_split_crops()}
                [logits] = sess.run([probas_op], feed_dict=feed_dict)
            scale_logits = imgsplitter.reassemble_crops(logits)

            if FLAGS.mirror == 1:
                image_mirror = image[:, ::-1]
                imgsplitter_mirror = ImageSplitter(image_mirror, scale,
                                                   FLAGS.color_switch,
                                                   image_size, img_mean)
                crops_m = imgsplitter_mirror.get_split_crops()
                if crops_m.shape[0] > 10:
                    half = crops_m.shape[0] // 2

                    feed_dict = {images_pl[0]: crops_m[0:half]}
                    [logits_0] = sess.run([probas_op], feed_dict=feed_dict)

                    feed_dict = {images_pl[0]: crops_m[half:]}
                    [logits_1] = sess.run([probas_op], feed_dict=feed_dict)
                    logits_m = np.concatenate((logits_0, logits_1), axis=0)
                else:
                    feed_dict = {
                        images_pl[0]: imgsplitter_mirror.get_split_crops()
                    }
                    [logits_m] = sess.run([probas_op], feed_dict=feed_dict)
                logits_m = imgsplitter_mirror.reassemble_crops(logits_m)
                scale_logits += logits_m[:, ::-1]

            if scale != 1.0:
                scale_logits = cv2.resize(scale_logits,
                                          (image_width, image_height),
                                          interpolation=cv2.INTER_LINEAR)

            total_logits += scale_logits

        prediction = np.argmax(total_logits, axis=-1)

        image_prefix = image_path.split('/')[-1].split(
            '.')[0] + '_' + FLAGS.weights_ckpt.split('/')[-2]
        if FLAGS.database == 'Cityscapes':
            cv2.imwrite(os.path.join(prefix, image_prefix + '_prediction.png'),
                        trainid_to_labelid(prediction))
            cv2.imwrite(os.path.join(prefix, image_prefix + '_coloring.png'),
                        cv2.cvtColor(coloring(prediction), cv2.COLOR_BGR2RGB))
        else:
            cv2.imwrite(os.path.join(prefix, image_prefix + '_prediction.png'),
                        prediction)
            # TODO: add coloring for databases other than Cityscapes.
        delta_t = (datetime.datetime.now() - t0).total_seconds()
        print('\n[info]\t saved!', delta_t, 'seconds.')

    if FLAGS.image_path is not None:
        inf_one_image(FLAGS.image_path)
    else:
        while True:
            image_path = input('Enter the image filename: ')  # raw_input in Python 2
            try:
                inf_one_image(image_path)
            except Exception:
                continue

    coord.request_stop()
    coord.join(threads)

    return
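find_data_path (also used in Example #1) resolves a database name to its data directory, channel mean, and class count. A hypothetical registry sketch; the paths are placeholders, the ImageNet BGR mean from Example #4 stands in for the real per-dataset means, and the class counts are the usual 19 for Cityscapes and 150 for ADE:

import numpy as np

def find_data_path(database):
    # Hypothetical sketch; paths and means are placeholders, not repo values.
    imagenet_bgr_mean = np.array((103.939, 116.779, 123.68), dtype=np.float32)
    registry = {
        'Cityscapes': ('./data/cityscapes', imagenet_bgr_mean, 19),
        'ADE': ('./data/ADEChallengeData2016', imagenet_bgr_mean, 150),
    }
    return registry[database]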
Example #8
def predict(i_ckpt):

    # < single gpu version >
    # < use FLAGS.batch_size as batch size >
    # < use FLAGS.weight_ckpt as i_ckpt >

    reader_init = []
    with tf.device('/cpu:0'):
        if FLAGS.reader_method == 'queue':
            eval_image_reader = reader.QueueBasedImageReader(FLAGS.database, FLAGS.test_subset)
            eval_image, eval_label, eval_image_filename = eval_image_reader.get_eval_batch(FLAGS.color_switch)
        else:
            eval_image_reader = reader.ImageReader(FLAGS.database, FLAGS.test_subset)
            eval_reader_iterator = eval_image_reader.get_eval_iterator(FLAGS.color_switch)
            eval_image, eval_label, eval_image_filename = eval_reader_iterator.get_next()  # one image.
            reader_init.append(eval_reader_iterator.initializer)

    crop_size = FLAGS.test_image_size
    # < network >
    model = pspnet_mg.PSPNetMG(eval_image_reader.num_classes, FLAGS.network, gpu_num(), FLAGS.initializer,
                               FLAGS.weight_decay_mode, FLAGS.fine_tune_filename, FLAGS.optimizer, FLAGS.momentum,
                               FLAGS.train_like_in_caffe, FLAGS.three_convs_beginning, FLAGS.new_layer_names,
                               consider_dilated=FLAGS.consider_dilated)
    images_pl = [tf.placeholder(tf.float32, [None, crop_size, crop_size, 3])]
    eval_probas_op = model.build_forward_ops(images_pl)

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)

    if FLAGS.reader_method == 'queue':
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    init = [tf.global_variables_initializer(), tf.local_variables_initializer()] + reader_init
    sess.run(init)

    loader = tf.train.Saver(max_to_keep=0)
    loader.restore(sess, i_ckpt)

    prefix = i_ckpt.split('model.ckpt')[0] + FLAGS.test_subset + '_set/'
    if not os.path.exists(prefix) and 'test' in FLAGS.test_subset:
        os.mkdir(prefix)
        print('saving predictions to', prefix)

    confusion_matrix = np.zeros((eval_image_reader.num_classes, eval_image_reader.num_classes), dtype=np.int64)
    scales = [1.0]
    if FLAGS.ms == 1:
        scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]

    def run_once(input_image):
        H, W, channel = input_image.shape

        # < in case that input_image is smaller than crop_size >
        dif_height = H - crop_size
        dif_width = W - crop_size
        if dif_height < 0 or dif_width < 0:
            input_image = helper.numpy_pad_image(input_image, dif_height, dif_width)
            H, W, channel = input_image.shape

        # < split this image into crops >
        split_crops = []
        heights = helper.decide_intersection(H, crop_size)
        widths = helper.decide_intersection(W, crop_size)
        for height in heights:
            for width in widths:
                image_crop = input_image[height:height + crop_size, width:width + crop_size]
                split_crops.append(image_crop[np.newaxis, :])

        # < run the crops through the network in chunks of FLAGS.batch_size >
        num_chunks = (len(split_crops) - 1) // FLAGS.batch_size + 1
        proba_crops_list = []
        for chunk_i in range(num_chunks):
            feed_dict = {}
            start = chunk_i * FLAGS.batch_size
            end = min((chunk_i+1)*FLAGS.batch_size, len(split_crops))
            feed_dict[images_pl[0]] = np.concatenate(split_crops[start:end])
            proba_crops_part = sess.run(eval_probas_op, feed_dict=feed_dict)
            proba_crops_list.append(proba_crops_part[0])

        proba_crops = np.concatenate(proba_crops_list)

        # < reassemble >
        reassemble = np.zeros((H, W, eval_image_reader.num_classes), np.float32)
        index = 0
        for height in heights:
            for width in widths:
                reassemble[height:height + crop_size, width:width + crop_size] += proba_crops[index]
                index += 1

        # < crop to original image >
        if dif_height < 0 or dif_width < 0:
            reassemble = helper.numpy_crop_image(reassemble, dif_height, dif_width)

        return reassemble

    for i in range(len(eval_image_reader.image_list)):
        orig_one_image, one_label, image_filename = sess.run([eval_image, eval_label, eval_image_filename])
        orig_height, orig_width, channel = orig_one_image.shape
        total_proba = np.zeros((orig_height, orig_width, eval_image_reader.num_classes), dtype=np.float32)
        for scale in scales:
            if scale != 1.0:
                one_image = cv2.resize(orig_one_image, dsize=(0, 0), fx=scale, fy=scale)
            else:
                one_image = np.copy(orig_one_image)

            proba = run_once(one_image)
            if FLAGS.mirror == 1:
                proba_mirror = run_once(one_image[:, ::-1])
                proba += proba_mirror[:, ::-1]

            if scale != 1.0:
                proba = cv2.resize(proba, (orig_width, orig_height))

            total_proba += proba

        prediction = np.argmax(total_proba, axis=-1)
        helper.compute_confusion_matrix(one_label, prediction, confusion_matrix)

        if 'test' in FLAGS.test_subset:
            if FLAGS.database == 'Cityscapes':
                cv2.imwrite(prefix + prediction_image_create(image_filename),
                            helper_cityscapes.trainid_to_labelid(prediction))
                if FLAGS.coloring == 1:
                    cv2.imwrite(prefix + coloring_image_create(image_filename),
                                cv2.cvtColor(helper_cityscapes.coloring(prediction), cv2.COLOR_BGR2RGB))
            else:
                cv2.imwrite(prefix + prediction_image_create(image_filename, key_word='.'), prediction)

        if i % 100 == 0:
            print('%s %s] %d / %d. iou updating' \
                  % (str(datetime.datetime.now()), str(os.getpid()), i, len(eval_image_reader.image_list)))
            helper.compute_iou(confusion_matrix)

    print('%s %s] %d / %d. iou updating' \
          % (str(datetime.datetime.now()), str(os.getpid()),
             len(eval_image_reader.image_list),
             len(eval_image_reader.image_list)))
    miou = helper.compute_iou(confusion_matrix)

    log_file = i_ckpt.split('model.ckpt')[0] + 'predict-ms' + str(FLAGS.ms) + '-mirror' + str(FLAGS.mirror) + '.txt'
    f_log = open(log_file, 'w')
    f_log.write(sorted_str_dict(FLAGS.__dict__) + '\n')
    ious = helper.compute_iou_each_class(confusion_matrix)
    f_log.write(str(ious) + '\n')
    for i in range(confusion_matrix.shape[0]):
        f_log.write(str(ious[i]) + '\n')
    f_log.write(str(miou) + '\n')

    if FLAGS.reader_method == 'queue':
        coord.request_stop()
        coord.join(threads)

    return
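prediction_image_create and coloring_image_create derive output filenames from the evaluated image's path (a bytes value coming out of sess.run). A sketch matching the calls above and the Cityscapes naming seen in Example #4, where key_word defaults to the Cityscapes suffix:

def prediction_image_create(image_filename, key_word='_leftImg8bit.png'):
    # Hypothetical helper: strip the directory and key_word, append a suffix.
    name = image_filename.decode() if isinstance(image_filename, bytes) else image_filename
    return name.split('/')[-1].split(key_word)[0] + '_prediction.png'

def coloring_image_create(image_filename, key_word='_leftImg8bit.png'):
    name = image_filename.decode() if isinstance(image_filename, bytes) else image_filename
    return name.split('/')[-1].split(key_word)[0] + '_coloring.png'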