Example 1
def get_model(images, num_classes, filter_scale, for_eval=False):
    net = ICNet_BN({'data': images},
                   is_training=not for_eval,
                   evaluation=for_eval,
                   num_classes=num_classes,
                   filter_scale=filter_scale)
    return net
Example 2
def main():
    args = get_arguments()
    cfg = Config(dataset="others",
                 is_training=False,
                 filter_scale=args.filter_scale)
    cfg.param["num_classes"] = args.num_classes

    net = ICNet_BN(cfg=cfg, mode='inference')

    net.create_session()
    net.restore(args.weights)

    os.makedirs(args.output_dir, exist_ok=True)

    input_files = os.listdir(args.input_dir)
    input_files = [
        f for f in input_files
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    ]

    for input_file in input_files:
        input_path = os.path.join(args.input_dir, input_file)
        img = cv2.imread(input_path)
        # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
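        # Pad 400-pixel-high frames to 600x800 by stacking a black band on top,
        # so the aspect ratio matches the inference size before resizing.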
        if img.shape[0] == 400 and img.ndim == 3:
            img = np.vstack(
                [np.zeros(shape=(200, 800, 3), dtype=img.dtype), img])
        if img.shape != cfg.INFER_SIZE:
            img = cv2.resize(img, (cfg.INFER_SIZE[1], cfg.INFER_SIZE[0]))
        prediction = net.predict(img)[0]
        res = np.concatenate([img, prediction], axis=1)
        output_path = os.path.join(args.output_dir, input_file)
        cv2.imwrite(output_path, res)

        print("Successfully ran inference on: " + input_file + " . Saved to " +
              output_path)
Example 3
def recreate_bn_model(input_imgs_tensor,
                      is_training=True,
                      crop_size=(600, 800)):
    snapshot_dir = './snapshots/'
    restore_from = './model/icnet_cityscapes_trainval_90k_bnnomerge.npy'

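    # Reorder channels RGB -> BGR (the converted Caffe weights expect BGR input)
    # and subtract the per-channel mean IMG_MEAN before feeding the network.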
    img_r, img_g, img_b = tf.split(axis=3,
                                   num_or_size_splits=3,
                                   value=input_imgs_tensor)
    imgs = tf.cast(tf.concat(axis=3, values=[img_b, img_g, img_r]),
                   dtype=tf.float32)
    imgs = imgs - IMG_MEAN

    net = ICNet_BN({'data': imgs},
                   is_training=is_training,
                   num_classes=19,
                   filter_scale=1)

    _, _, pred = extend_reclassifier(net)

    restore_var = tf.global_variables()

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    ckpt = tf.train.get_checkpoint_state(snapshot_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print('restoring from', ckpt.model_checkpoint_path, file=sys.stderr)
        loader = tf.train.Saver(var_list=restore_var)
        loader.restore(sess, ckpt.model_checkpoint_path)
    else:
        print('restoring from', restore_from, file=sys.stderr)
        sess.run(tf.global_variables_initializer())
        net.load(restore_from, sess)

    if crop_size is not None:
        pred = tf.image.crop_to_bounding_box(pred, 0, 0, crop_size[0],
                                             crop_size[1])
    pred = tf.identity(pred, name='output_2positiveclasses')

    return sess, pred
Example 4
def main():
	"""Create the model and start the training."""
	args = get_arguments()
	
	h, w = map(int, args.input_size.split(','))
	input_size = (h, w)
	
	coord = tf.train.Coordinator()
	
	with tf.name_scope("create_inputs"):
		reader = ImageReader(
			DATA_DIR,
			DATA_LIST_PATH,
			input_size,
			args.random_scale,
			args.random_mirror,
			args.ignore_label,
			IMG_MEAN,
			coord)
		image_batch, label_batch = reader.dequeue(args.batch_size)

	#with g.as_default():
	net = ICNet_BN({'data': image_batch}, is_training=True, num_classes=args.num_classes, filter_scale=args.filter_scale)
	sub4_out = net.layers['sub4_out']
	sub24_out = net.layers['sub24_out']
	sub124_out = net.layers['conv6_cls']

	restore_var = tf.global_variables()
	all_trainable = [v for v in tf.trainable_variables() if ('beta' not in v.name and 'gamma' not in v.name) or args.train_beta_gamma]

	loss_sub4 = create_loss(sub4_out, label_batch, args.num_classes, args.ignore_label)
	loss_sub24 = create_loss(sub24_out, label_batch, args.num_classes, args.ignore_label)
	loss_sub124 = create_loss(sub124_out, label_batch, args.num_classes, args.ignore_label)

	l2_losses = [args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'weights' in v.name]
		
	reduced_loss = LAMBDA1 * loss_sub4 +  LAMBDA2 * loss_sub24 + LAMBDA3 * loss_sub124 + tf.add_n(l2_losses)

	# Using Poly learning rate policy 
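	# lr(step) = base_lr * (1 - step / num_steps) ** power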
	base_lr = tf.constant(args.learning_rate)
	step_ph = tf.placeholder(dtype=tf.float32, shape=())
	learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))
		
	# Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
	if not args.update_mean_var:
		update_ops = None
	else:
		update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

	with tf.control_dependencies(update_ops):
		opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
		grads = tf.gradients(reduced_loss, all_trainable)
		train_op = opt_conv.apply_gradients(zip(grads, all_trainable))
Example 5
def main():
    args = get_arguments()

    img, filename = load_img(args.img_path)
    shape = img.shape[0:2]

    x = tf.placeholder(dtype=tf.float32, shape=img.shape)
    img_tf = preprocess(x)
    img_tf, n_shape = check_input(img_tf)

    # Create network.
    net = ICNet_BN({'data': img_tf}, num_classes=num_classes)

    # Predictions.
    raw_output = net.layers['conv6_cls']
    output = tf.image.resize_bilinear(raw_output, tf.shape(img_tf)[1:3])
    output = tf.argmax(output, axis=3)
    pred = tf.expand_dims(output, axis=3)

    # Init tf Session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    restore_var = tf.global_variables()

    ckpt = tf.train.get_checkpoint_state(args.snapshots_dir)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load_step = int(
            os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess, ckpt.model_checkpoint_path)

    preds = sess.run(pred, feed_dict={x: img})

    # print(preds.shape)
    # s = preds.flatten()
    # print(set(s))
    # print((s == 0).sum())
    # print((s == 1).sum())
    # print((s == 2).sum())

    msk = decode_labels(preds, num_classes=num_classes)
    im = Image.fromarray(msk[0])
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    im.save(args.save_dir + filename.replace('.jpg', '.png'))
Example 6
    def _construct_and_fill_model(self):
        logger.info('Will create model.')
        with tf.get_default_graph().as_default():
            img_np = tf.placeholder(tf.float32, shape=(None, None, 3))
            img_shape = tf.shape(img_np)

            w, h = self.input_size_wh
            img_np_4d = tf.expand_dims(img_np, axis=0)
            image_rs_4d = tf.image.resize_bilinear(img_np_4d, (h, w),
                                                   align_corners=True)
            image_rs = tf.squeeze(image_rs_4d, axis=0)
            img = preprocess(image_rs, h, w)

            net = ICNet_BN(
                {'data': img},
                is_training=False,
                num_classes=len(self.train_classes),
                filter_scale=self.train_config['settings']['filter_scale'])
            raw_output = net.layers['conv6_cls']  # 4d

            # Predictions.
            # NOTE: align_corners=False here, unlike the align_corners=True resize above.
            raw_output_up = tf.image.resize_bilinear(
                raw_output,
                size=[img_shape[0], img_shape[1]],
                align_corners=False)
            # raw_output_up = tf.argmax(raw_output_up, dimension=3)

            logger.info('Will load weights from trained model.')
            # Init tf Session
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config)
            init = tf.global_variables_initializer()
            sess.run(init)
            loader = tf.train.Saver(var_list=tf.global_variables())

            # last_checkpoint = tf_saver.latest_checkpoint(output_train_dir)
            last_checkpoint = osp.join(self.helper.paths.model_dir,
                                       'model.ckpt')
            loader.restore(sess, last_checkpoint)

            self.input_images = img_np
            self.predictions = raw_output_up
            self.sess = sess
        logger.info('Model has been created & weights are loaded.')
Example 7
def load_from_checkpoint(shape, path):
    x = tf.placeholder(dtype=tf.float32, shape=shape)
    img_tf = preprocess(x)
    img_tf, n_shape = check_input(img_tf)

    # Create network.
    net = ICNet_BN({'data': img_tf},
                   is_training=False,
                   num_classes=num_classes)

    # Predictions.
    raw_output = net.layers['conv6_cls']
    print('raw_output', raw_output)
    output = tf.image.resize_bilinear(raw_output, tf.shape(img_tf)[1:3])
    output = tf.argmax(output, axis=3)
    pred = tf.expand_dims(output, axis=3)
    pred = pred[0]
    #pred = tf.py_func(decode_labels, [pred, 1, 2], tf.uint8)

    # Init tf Session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    restore_var = tf.global_variables()

    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    ckpt = tf.train.get_checkpoint_state(path)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load_step = int(
            os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        loader.restore(sess, ckpt.model_checkpoint_path)
        print("Restored model parameters from {}".format(
            ckpt.model_checkpoint_path))

    #net.load('./model/icnet_cityscapes_trainval_90k_bnnomerge.npy', sess)
    return sess, pred, x
Example 8
def main():
    """Create the model and start the training."""
    args = get_arguments()
    """
    Get configurations here. We pass some arguments from the command line to init the configuration;
    training hyperparameters can be set in the TrainConfig class.

    Note: we set filter_scale to 1 for the pruned model and 2 for the non-pruned model. The filter counts
          of the non-pruned model are twice those of the pruned model, e.g., [h, w, 64] <-> [h, w, 32].
    """
    cfg = TrainConfig(dataset=args.dataset,
                      is_training=True,
                      random_scale=args.random_scale,
                      random_mirror=args.random_mirror,
                      filter_scale=args.filter_scale)
    cfg.display()

    # Setup training network and training samples
    train_reader = ImageReader(cfg=cfg, mode='train')

    train_net = ICNet_BN(image_reader=train_reader, cfg=cfg, mode='train')
    """
Example 9
def main():
    """Create the model and start the training."""
    args = get_arguments()

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    coord = tf.train.Coordinator()

    with tf.name_scope("create_inputs"):
        reader = ImageReader(DATA_DIR, DATA_LIST_PATH, input_size,
                             args.random_scale, args.random_mirror,
                             args.ignore_label, IMG_MEAN, coord)
        image_batch, label_batch = reader.dequeue(args.batch_size)

    net = ICNet_BN({'data': image_batch},
                   is_training=True,
                   num_classes=args.num_classes,
                   filter_scale=args.filter_scale)

    sub4_out = net.layers['sub4_out']
    sub24_out = net.layers['sub24_out']
    sub124_out = net.layers['conv6_cls']

    restore_var = tf.global_variables()
    all_trainable = [
        v for v in tf.trainable_variables()
        if ('beta' not in v.name and 'gamma' not in v.name)
        or args.train_beta_gamma
    ]

    with tf.name_scope('loss'):
        loss_sub4 = create_loss(sub4_out, label_batch, args.num_classes,
                                args.ignore_label)
        loss_sub24 = create_loss(sub24_out, label_batch, args.num_classes,
                                 args.ignore_label)
        loss_sub124 = create_loss(sub124_out, label_batch, args.num_classes,
                                  args.ignore_label)
        l2_losses = [
            args.weight_decay * tf.nn.l2_loss(v)
            for v in tf.trainable_variables() if 'weights' in v.name
        ]

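        # Total objective: the LAMBDA1/2/3 constants weight the auxiliary cascade
        # losses (sub4, sub24) and the final fused output (sub124), plus L2 weight decay.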
        reduced_loss = LAMBDA1 * loss_sub4 + LAMBDA2 * loss_sub24 + LAMBDA3 * loss_sub124 + tf.add_n(
            l2_losses)

        tf.summary.scalar('sub4', loss_sub4)
        tf.summary.scalar('sub24', loss_sub24)
        tf.summary.scalar('sub124', loss_sub124)
        tf.summary.scalar('total_loss', reduced_loss)

    # Using Poly learning rate policy
    base_lr = tf.constant(args.learning_rate)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))

    # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=20)
    summ = tf.summary.merge_all()
    tenboard_dir = tfboard_dir + str(LEARNING_RATE) + '_' + str(NUM_STEPS)

    writer = tf.summary.FileWriter(tenboard_dir)
    writer.add_graph(sess.graph)
    ckpt = tf.train.get_checkpoint_state(args.snapshot_dir)
    # net.load(args.restore_from, sess)

    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load_step = int(
            os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess,
             './snapshots/3wDataSet/model.ckpt-' + str(START_STEP))
    else:
        print('Restore from pre-trained model...')
        net.load(args.restore_from, sess)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    for step in range(START_STEP, args.num_steps):
        start_time = time.time()

        feed_dict = {step_ph: step}
        if step % args.save_pred_every == 0:
            s, loss_value, loss1, loss2, loss3, _ = sess.run(
                [
                    summ, reduced_loss, loss_sub4, loss_sub24, loss_sub124,
                    train_op
                ],
                feed_dict=feed_dict)
            save(saver, sess, args.snapshot_dir, step)
            writer.add_summary(s, step)
        else:
            s, loss_value, loss1, loss2, loss3, _ = sess.run(
                [
                    summ, reduced_loss, loss_sub4, loss_sub24, loss_sub124,
                    train_op
                ],
                feed_dict=feed_dict)
            writer.add_summary(s, step)
        duration = time.time() - start_time
        print(
            'step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, sub124 = {:.3f} ({:.3f} sec/step)'
            .format(step, loss_value, loss1, loss2, loss3, duration))

    coord.request_stop()
    coord.join(threads)
Example 10
def main():
    args = get_arguments()

    img, filename = load_img(args.img_path)
    shape = img.shape[0:2]

    x = tf.placeholder(dtype=tf.float32, shape=img.shape)
    img_tf = preprocess(x)
    img_tf, n_shape = check_input(img_tf)

    # Create network.
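    # Model names ending in 'bn' (and the 'others' option) use the batch-norm
    # variant; any other name falls back to the plain ICNet graph.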
    if args.model[-2:] == 'bn':
        net = ICNet_BN({'data': img_tf}, num_classes=num_classes)
    elif args.model == 'others':
        net = ICNet_BN({'data': img_tf}, num_classes=num_classes)
    else:
        net = ICNet({'data': img_tf}, num_classes=num_classes)

    raw_output = net.layers['conv6_cls']

    # Predictions.
    raw_output_up = tf.image.resize_bilinear(raw_output,
                                             size=n_shape,
                                             align_corners=True)
    raw_output_up = tf.image.crop_to_bounding_box(raw_output_up, 0, 0,
                                                  shape[0], shape[1])
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    pred = decode_labels(raw_output_up, shape, num_classes)

    # Init tf Session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    restore_var = tf.global_variables()

    if args.model == 'train':
        print('Restore from train30k model...')
        net.load(model_train30k, sess)
    elif args.model == 'trainval':
        print('Restore from trainval90k model...')
        net.load(model_trainval90k, sess)
    elif args.model == 'train_bn':
        print('Restore from train30k bnnomerge model...')
        net.load(model_train30k_bn, sess)
    elif args.model == 'trainval_bn':
        print('Restore from trainval90k bnnomerge model...')
        net.load(model_trainval90k_bn, sess)
    else:
        ckpt = tf.train.get_checkpoint_state(snapshot_dir)
        if ckpt and ckpt.model_checkpoint_path:
            loader = tf.train.Saver(var_list=restore_var)
            load_step = int(
                os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
            load(loader, sess, ckpt.model_checkpoint_path)

    preds = sess.run(pred, feed_dict={x: img})

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    misc.imsave(args.save_dir + filename, preds[0])
Example 11
def main():
    """Create the model and start the training."""
    args = get_arguments()

    """
    Get configurations here. We pass some arguments from the command line to init the configuration;
    training hyperparameters can be set in the TrainConfig class.

    Note: we set filter_scale to 1 for the pruned model and 2 for the non-pruned model. The filter counts
          of the non-pruned model are twice those of the pruned model, e.g., [h, w, 64] <-> [h, w, 32].
    """
    cfg = TrainConfig(dataset=args.dataset,
                      is_training=True,
                      random_scale=args.random_scale,
                      random_mirror=args.random_mirror,
                      filter_scale=args.filter_scale)
    cfg.display()

    # Setup training network and training samples
    train_reader = ImageReader(cfg=cfg, mode='train')
    train_net = ICNet_BN(image_reader=train_reader,
                         cfg=cfg, mode='train')

    loss_sub4, loss_sub24, loss_sub124, reduced_loss = create_losses(train_net, train_net.labels, cfg)

    # Setup validation network and validation samples
    with tf.variable_scope('', reuse=True):
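        # reuse=True makes the validation network share all weights with train_net.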
        val_reader = ImageReader(cfg, mode='eval')
        val_net = ICNet_BN(image_reader=val_reader,
                           cfg=cfg, mode='train')

        val_loss_sub4, val_loss_sub24, val_loss_sub124, val_reduced_loss = create_losses(val_net, val_net.labels, cfg)

    # Using Poly learning rate policy
    base_lr = tf.constant(cfg.LEARNING_RATE)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - step_ph / cfg.TRAINING_STEPS), cfg.POWER))

    # Set restore variable
    restore_var = tf.global_variables()
    all_trainable = [v for v in tf.trainable_variables() if ('beta' not in v.name and 'gamma' not in v.name) or args.train_beta_gamma]

    # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, cfg.MOMENTUM)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))

    # Create session & restore weights (Here we only need to use train_net to create session since we reuse it)
    train_net.create_session()
    # train_net.restore(cfg.model_weight, restore_var)
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)

    # Iterate over training steps.
    for step in range(cfg.TRAINING_STEPS):
        start_time = time.time()

        feed_dict = {step_ph: step}
        if step % cfg.SAVE_PRED_EVERY == 0:
            loss_value, loss1, loss2, loss3, val_loss_value, _ = train_net.sess.run([reduced_loss, loss_sub4, loss_sub24, loss_sub124, val_reduced_loss, train_op], feed_dict=feed_dict)
            train_net.save(saver, cfg.SNAPSHOT_DIR, step)
        else:
            loss_value, loss1, loss2, loss3, val_loss_value, _ = train_net.sess.run([reduced_loss, loss_sub4, loss_sub24, loss_sub124, val_reduced_loss, train_op], feed_dict=feed_dict)

        duration = time.time() - start_time
        print('step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, sub124 = {:.3f}, val_loss: {:.3f} ({:.3f} sec/step)'.\
                    format(step, loss_value, loss1, loss2, loss3, val_loss_value, duration))
Example 12
def main():
    """Create the model and start the training."""
    args = get_arguments()
    print("SAVE TO " + args.snapshot_dir)
    datalists_epoch = {
        1: args.datalist_path_epoch1,
        2: args.datalist_path_epoch2,
        3: args.datalist_path_epoch3,
        4: args.datalist_path_epoch4,
        5: args.datalist_path_epoch5
    }
    if args.cross_val:
        val_epoch = int(args.cross_val)
        train_epochs = [1, 2, 3, 4, 5]
        train_epochs.remove(val_epoch)
        train_lists = [datalists_epoch[i] for i in train_epochs]
        val_lists = datalists_epoch[val_epoch]
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    max_runtime = args.max_runtime
    max_time_seconds = 3600 * max_runtime
    epochs_until_val = 3

    global dataset_class_weights
    if args.weights_for_dataset is None:
        dataset_class_weights = None
    elif args.weights_for_dataset == 'de_top15':
        dataset_class_weights = weights_detop15
    elif args.weights_for_dataset == 'eu_top25':
        dataset_class_weights = weights_eutop25
    elif args.weights_for_dataset == 'world2k':
        dataset_class_weights = weights_world2k
    elif args.weights_for_dataset == 'kaggle_dstl':
        dataset_class_weights = weights_kaggledstl
    elif args.weights_for_dataset == 'vaihingen':
        dataset_class_weights = weights_vaihingen
    elif args.weights_for_dataset == 'de_top15_nores':
        dataset_class_weights = weights_detop15_nores
    elif args.weights_for_dataset == 'eu_top25_nores':
        dataset_class_weights = weights_eutop25_nores
    elif args.weights_for_dataset == 'world2k_nores':
        dataset_class_weights = weights_world2k_nores

    coord = tf.train.Coordinator()

    if args.cross_val:
        with tf.name_scope("create_inputs"):
            reader = ImageReader(args.datadir, train_lists, input_size,
                                 args.random_scale, args.random_mirror,
                                 args.ignore_label, IMG_MEAN, coord)
            image_batch, label_batch = reader.dequeue(args.batch_size)

            # for validation
            reader_val = ImageReader(args.datadir, val_lists, input_size,
                                     args.random_scale, args.random_mirror,
                                     args.ignore_label, IMG_MEAN, coord)
            image_batch_val, label_batch_val = reader_val.dequeue(
                args.batch_size)
    else:

        with tf.name_scope("create_inputs"):
            reader = ImageReader(args.datadir, args.datalist_path, input_size,
                                 args.random_scale, args.random_mirror,
                                 args.ignore_label, IMG_MEAN, coord)
            image_batch, label_batch = reader.dequeue(args.batch_size)

            # for validation
            reader_val = ImageReader(args.datadir, args.datalist_path_val,
                                     input_size, args.random_scale,
                                     args.random_mirror, args.ignore_label,
                                     IMG_MEAN, coord)
            image_batch_val, label_batch_val = reader_val.dequeue(
                args.batch_size)

    net = ICNet_BN({'data': image_batch},
                   is_training=True,
                   num_classes=args.num_classes,
                   filter_scale=args.filter_scale)
    with tf.variable_scope("val"):
        net_val = ICNet_BN({'data': image_batch_val},
                           is_training=True,
                           num_classes=args.num_classes,
                           filter_scale=args.filter_scale)

    sub4_out = net.layers['sub4_out']
    sub24_out = net.layers['sub24_out']
    sub124_out = net.layers['conv6_cls']

    # early stop variables
    last_val_loss_tf = tf.Variable(10000.0, name="last_loss")
    steps_total_tf = tf.Variable(0, name="steps_total")
    val_increased_t_tf = tf.Variable(0, name="loss_increased_t")

    if args.not_restore_last:
        restore_var = [
            v for v in tf.global_variables() if 'conv6_cls' not in v.name
            and 'val' not in v.name and 'sub4_out' not in v.name
            and 'sub24_out' not in v.name and 'sub124_out' not in v.name
        ]
    else:
        # to load last layer, the line 78 in network.py has to be removed too and ignore_missing set to False
        # see https://github.com/hellochick/ICNet-tensorflow/issues/50 BCJuan
        # don't restore val vars
        restore_var = [
            v for v in tf.trainable_variables() if 'val' not in v.name
        ]  #tf.global_variables()
        # don't train val variables
    all_trainable = [
        v for v in tf.trainable_variables()
        if (('beta' not in v.name and 'gamma' not in v.name)
            or args.train_beta_gamma) and 'val' not in v.name
    ]
    # all_trainable = [v for v in tf.trainable_variables() if
    #                  ('beta' not in v.name and 'gamma' not in v.name) or args.train_beta_gamma]

    # print([v for v in tf.global_variables() if v.name in["last_val_loss","steps_total","val_increased_t"]])
    # restore_var.extend([v for v in tf.global_variables() if v.name in["last_val_loss","steps_total","val_increased_t"]])

    # assert not np.any(np.isnan(sub4_out))
    loss_sub4 = create_loss(sub4_out, label_batch, args.num_classes,
                            args.ignore_label)
    loss_sub24 = create_loss(sub24_out, label_batch, args.num_classes,
                             args.ignore_label)
    loss_sub124 = create_loss(sub124_out, label_batch, args.num_classes,
                              args.ignore_label)
    # l2_losses = [args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'weights' in v.name]
    l2_losses = [
        args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if ('weights' in v.name and 'val' not in v.name)
    ]
    reduced_loss = LAMBDA1 * loss_sub4 + LAMBDA2 * loss_sub24 + LAMBDA3 * loss_sub124 + tf.add_n(
        l2_losses)

    ####################### Loss Calculation FOR VALIDATION

    sub4_out_val = net_val.layers['sub4_out']
    sub24_out_val = net_val.layers['sub24_out']
    sub124_out_val = net_val.layers['conv6_cls']

    loss_sub4_val = create_loss(sub4_out_val, label_batch_val,
                                args.num_classes, args.ignore_label)
    loss_sub24_val = create_loss(sub24_out_val, label_batch_val,
                                 args.num_classes, args.ignore_label)
    loss_sub124_val = create_loss(sub124_out_val, label_batch_val,
                                  args.num_classes, args.ignore_label)
    l2_losses_val = [
        args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if ('weights' in v.name and 'val' in v.name)
    ]

    reduced_loss_val = LAMBDA1 * loss_sub4_val + LAMBDA2 * loss_sub24_val + LAMBDA3 * loss_sub124_val + tf.add_n(
        l2_losses_val)
    ####################### End Loss Calculation FOR VALIDATION

    # Using Poly learning rate policy
    base_lr = tf.constant(args.learning_rate)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))

    # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    # start time
    glob_start_time = time.time()

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    if '.npy' not in args.restore_from:
        ckpt = tf.train.get_checkpoint_state(args.restore_from)
    else:
        ckpt = tf.train.get_checkpoint_state(args.snapshot_dir)
    if ckpt and ckpt.model_checkpoint_path:
        vars_to_restore = get_tensors_in_checkpoint_file(
            file_name=ckpt.model_checkpoint_path)
        # print(vars_to_restore)
        # print([v.name for v in restore_var])
        # thanks to https://stackoverflow.com/a/50216949/8862202
        # v.name[:-2] to transform 'conv1_1_3x3_s2/weights:0' to 'conv1_1_3x3_s2/weights'
        vars_to_restore = [
            v for v in restore_var
            if 'val' not in v.name and v.name[:-2] in vars_to_restore
        ]
        # print(vars_to_restore)
        #loader = tf.train.Saver(var_list=restore_var)
        loader = tf.train.Saver(var_list=vars_to_restore)
        load_step = int(
            os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess, ckpt.model_checkpoint_path)
    else:
        print('Restore from pre-trained model...')
        net.load(args.restore_from, sess)
    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    if args.reset_patience:
        z = tf.assign(val_increased_t_tf, 0)
        sess.run(z)

    print(sess.run(last_val_loss_tf))
    print(sess.run(steps_total_tf))
    print(sess.run(val_increased_t_tf))

    if not args.cross_val:
        val_epoch_len = len(reader_val.image_list)
        val_num_steps = val_epoch_len // args.batch_size
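        # One validation pass sweeps the whole validation list once.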
        # Iterate over training steps.
        last_val_loss = sess.run(last_val_loss_tf)
        val_increased_t = sess.run(val_increased_t_tf)
        best_model_step = 0
        total_steps = sess.run(steps_total_tf)
        for step in range(total_steps, args.num_steps + total_steps):
            start_time = time.time()
            feed_dict = {step_ph: step}
            if step % args.save_pred_every == 0:

                # validating
                if args.validate:
                    print("validating: ")
                    print_assign_vars(sess)
                    print("Assigned vars for validation. ")
                    loss_sum = 0
                    for val_step in trange(val_num_steps,
                                           desc='validation',
                                           leave=True):
                        loss_value_v, loss1_v, loss2_v, loss3_v = sess.run(
                            [
                                reduced_loss_val, loss_sub4_val,
                                loss_sub24_val, loss_sub124_val
                            ],
                            feed_dict=feed_dict)
                        loss_sum = loss_sum + loss_value_v
                    loss_avg = loss_sum / val_num_steps

                    if loss_avg > last_val_loss:
                        val_increased_t = val_increased_t + 1
                        if val_increased_t >= args.patience:
                            print(
                                "Terminated Training, Best Model (at step %d) saved 4 validations ago"
                                % best_model_step)
                            f = open("./FINISHED_ICNET", "w+")
                            f.close()
                            break

                    else:
                        val_increased_t = 0
                        best_model_step = step

                    print(
                        'VALIDATION COMPLETE step {:d}\tVal_Loss Increased {:d}/{:d} times\t total loss = {:.3f}'
                        ' last loss = {:.3f}'.format(step, val_increased_t,
                                                     args.patience, loss_avg,
                                                     last_val_loss))

                    last_val_loss = loss_avg
                    steps_assign = tf.assign(steps_total_tf, step)
                    last_val_assign = tf.assign(last_val_loss_tf,
                                                last_val_loss)
                    increased_assign = tf.assign(val_increased_t_tf,
                                                 val_increased_t)
                    print("loss avg " + str(loss_avg))
                    print(sess.run(steps_assign))
                    print(sess.run(last_val_assign))
                    print(sess.run(increased_assign))

                # Saving

                loss_value, loss1, loss2, loss3, _ = sess.run(
                    [
                        reduced_loss, loss_sub4, loss_sub24, loss_sub124,
                        train_op
                    ],
                    feed_dict=feed_dict)
                save(saver, sess, args.snapshot_dir, step)

                # check if max run time is already over
                elapsed = time.time() - glob_start_time
                if (elapsed + 300) > max_time_seconds:
                    print("Training stopped: max run time elapsed")
                    os.remove("./RUNNING_ICNET")
                    break
            else:
                loss_value, loss1, loss2, loss3, _ = sess.run(
                    [
                        reduced_loss, loss_sub4, loss_sub24, loss_sub124,
                        train_op
                    ],
                    feed_dict=feed_dict)
            duration = time.time() - start_time
            print(
                'step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, sub124 = {:.3f} ({:.3f} sec/step)'
                .format(step, loss_value, loss1, loss2, loss3, duration))
        train_duration = time.time() - glob_start_time
        print('Total training time: ' + str(train_duration))
    else:
        # Training with cross validation
        print("Training-Mode CROSS VALIDATION")
        val_epoch_len = len(reader_val.image_list)
        val_num_steps = val_epoch_len // args.batch_size
        print("Val epoch length %d, Num steps %d" %
              (val_epoch_len, val_num_steps))
        last_val_loss = math.inf
        val_not_imp_t = 0

        # train

        for step in range(1000000):
            feed_dict = {step_ph: step}
            train_start = time.time()
            loss_value, loss1, loss2, loss3, _ = sess.run(
                [reduced_loss, loss_sub4, loss_sub24, loss_sub124, train_op],
                feed_dict=feed_dict)
            duration_t = time.time() - train_start
            if args.print_steps:
                print(
                    'step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, sub124 = {:.3f} ({:.3f} sec/step)'
                    .format(step, loss_value, loss1, loss2, loss3, duration_t))

            if step % args.save_pred_every == 0:
                # save and validate
                # SAVE previously trained model
                save(saver, sess, args.snapshot_dir, step)
                # Validate
                print("validating: ")
                start_time = time.time()
                print_assign_vars(sess)
                print("Assigned vars for validation. ")
                loss_sum = 0
                for val_step in trange(val_num_steps,
                                       desc='validation',
                                       leave=True):
                    loss_value_v, loss1_v, loss2_v, loss3_v = sess.run(
                        [
                            reduced_loss_val, loss_sub4_val, loss_sub24_val,
                            loss_sub124_val
                        ],
                        feed_dict=feed_dict)
                    loss_sum = loss_sum + loss_value_v
                duration = time.time() - start_time
                loss_avg = loss_sum / val_num_steps
                print(
                    'VALIDATION COMPLETE step {:d} \t total loss = {:.3f} \t duration = {:.3f}'
                    .format(step, loss_avg, duration))

                # Early stopping: loss_avg is only computed in the validation
                # branch above, so the check must live inside it as well.
                if loss_avg >= last_val_loss:
                    val_not_imp_t = val_not_imp_t + 1
                    if val_not_imp_t >= 4:
                        print(
                            "Terminated training; best model was saved 5 validations ago"
                        )
                        f = open("./FINISHED_ICNET", "w+")
                        f.close()
                        break

                else:
                    val_not_imp_t = 0

                last_val_loss = loss_avg

            # check if max run time is already over
            elapsed = time.time() - glob_start_time
            if (elapsed + 300) > max_time_seconds:
                print("Training stopped: max run time elapsed")
                os.remove("./RUNNING_ICNET")
                break

    coord.request_stop()
    coord.join(threads)
Example 13
def main(_):
    if not FLAGS.output_file:
        raise ValueError(
            'You must supply the path to save to with --output_file')

    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default() as graph:
        shape = INPUT_SIZE.split(',')
        shape = (int(shape[0]), int(shape[1]), 3)

        x = tf.placeholder(name='input',
                           dtype=tf.float32,
                           shape=(1, shape[0], shape[1], 3))

        img_tf = tf.cast(x, dtype=tf.float32)
        # Extract mean.
        img_tf -= IMG_MEAN

        #img_tf = tf.expand_dims(img_tf, dim=0)

        print(img_tf)
        # Create network.
        net = ICNet_BN({'data': img_tf},
                       is_training=False,
                       num_classes=NUM_CLASSES)

        raw_output = net.layers['conv6_cls']
        output = tf.image.resize_bilinear(raw_output,
                                          tf.shape(img_tf)[1:3],
                                          name='raw_output')
        output = tf.argmax(output, axis=3)
        pred = tf.expand_dims(output, axis=3)
        pred = tf.squeeze(pred, 0)
        pred = tf.cast(pred, dtype=tf.int32, name='indices')
        '''
        pred = tf.squeeze(pred, 2)
        cond = tf.equal(pred, tf.constant(1))
        indx = tf.where((cond))
        #values = tf.constant(200, shape=tf.shape(indx)) 
        delta = tf.SparseTensor(indx, tf.constant(200), tf.shape(pred))
        result = pred + tf.sparse_tensor_to_dense(delta)

        pred = tf.expand_dims(output, dim = 2, name = 'indices')
        '''

        #pred = tf.py_func(decode_labels, [pred, 1, 2], tf.uint8)

        # Add extra constants to the graph. They must also be listed as outputs
        # when freezing the graph, otherwise they will be cut.
        tf.constant(label_colours, name='label_colours')
        tf.constant(label_names, name="label_names")

        shape = INPUT_SIZE.split(',')
        shape = (int(shape[0]), int(shape[1]), 3)
        tf.constant(shape, name='input_size')
        tf.constant(["indices"], name="output_name")

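        # Serialize the graph definition only; variable values are not stored
        # here and must be frozen into the graph in a separate step.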
        graph_def = graph.as_graph_def()
        with gfile.GFile(FLAGS.output_file, 'wb') as f:
            f.write(graph_def.SerializeToString())
            print('Successfully written to', FLAGS.output_file)
Example 14
def run_on_video(video_filename,
                 out_filename,
                 model_path,
                 num_classes,
                 save_to='simple',
                 canvas_size=(1600, 800),
                 alpha=0.8,
                 beta=0.2,
                 output_size=(1280, 720),
                 step=1):
    '''
    save_to: simple, double_screen or weighted
    '''
    input_size = (int(INPUT_SIZE.split(',')[0]), int(INPUT_SIZE.split(',')[1]))
    x = tf.placeholder(dtype=tf.float32,
                       shape=(int(input_size[0]), int(input_size[1]), 3))
    img_tf = preprocess(x)
    img_tf, n_shape = check_input(img_tf)

    net = ICNet_BN({'data': img_tf}, num_classes=num_classes)

    raw_output = net.layers['conv6']

    # Predictions.
    raw_output_up = tf.image.resize_bilinear(raw_output,
                                             size=n_shape,
                                             align_corners=True)
    #raw_output_up = tf.image.crop_to_bounding_box(raw_output_up, 0, 0, INPUT_SIZE[0], INPUT_SIZE[1])
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    pred = tf.expand_dims(raw_output_up, axis=3)

    # Init tf Session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    restore_var = tf.global_variables()

    print('model_path', model_path)
    ckpt = tf.train.latest_checkpoint(model_path)
    if ckpt:  # latest_checkpoint returns None when no checkpoint is found
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, ckpt)


######
    cap = cv2.VideoCapture(video_filename)

    out_cap = None
    if save_to == 'double_screen':
        out_cap = cv2.VideoWriter(out_filename.replace('.mp4', '.avi'),
                                  cv2.VideoWriter_fourcc(*"MJPG"), 60,
                                  (canvas_size[0], canvas_size[1]))
    elif save_to == 'weighted':
        out_cap = cv2.VideoWriter(out_filename.replace('.mp4', '.avi'),
                                  cv2.VideoWriter_fourcc(*"MJPG"), 60,
                                  (output_size[0], output_size[1]))

    # Check if camera opened successfully
    if not cap.isOpened():
        print("Error opening video stream or file")
        return

    frame_num = 0
    zf = None
    while cap.isOpened():

        # Capture frame-by-frame; stop when the stream runs out of frames.
        ret, image = cap.read()
        if not ret:
            break

        frame_num = frame_num + 1
        if frame_num % step != 0:
            continue
        print('Processing frame', frame_num)

        if out_cap is None and save_to != 'images':
            out_cap = cv2.VideoWriter(out_filename.replace('.mp4', '.avi'),
                                      cv2.VideoWriter_fourcc(*'MJPG'), 60,
                                      (image.shape[1], image.shape[0]))
        elif save_to == 'images' and zf is None:
            zipfile_name = out_filename.replace('.avi', '.zip')

        original_shape = image.shape
        if image.shape[2] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

        if image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # cv2.resize expects (width, height); input_size is (h, w)
        image = cv2.resize(image, (input_size[1], input_size[0]))

        preds = sess.run(pred, feed_dict={x: image})
        msk = decode_labels(preds, num_classes=num_classes)
        frame = msk[0]

        #cv2.imshow('1', frame)
        #cv2.waitKey(0)
        #print(frame.shape)

        if save_to == 'double_screen':

            canvas = np.zeros((canvas_size[1], canvas_size[0], 3),
                              dtype=np.uint8)
            #frame_orig = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            frame_orig = cv2.resize(
                image, (int(canvas_size[0] / 2), int(canvas_size[1])))

            #frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame,
                               (int(canvas_size[0] / 2), int(canvas_size[1])))

            canvas[:, 0:int(canvas_size[0] / 2), :] = frame_orig
            canvas[:, int(canvas_size[0] / 2):, :] = frame
            #cv2.imshow('1', frame)
            #cv2.waitKey(0)
            canvas = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
            print('canvas shape', canvas.shape)

            out_cap.write(canvas)

        elif save_to == 'simple':

            frame = cv2.resize(frame, (original_shape[1], original_shape[0]),
                               interpolation=cv2.INTER_NEAREST)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            #cv2.imshow('1', frame)
            #cv2.waitKey(0)
            out_cap.write(frame)

        elif save_to == 'images':

            frame = cv2.resize(frame, (original_shape[1], original_shape[0]),
                               interpolation=cv2.INTER_NEAREST)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = cv2.resize(image, (original_shape[1], original_shape[0]),
                               interpolation=cv2.INTER_NEAREST)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            cv2.imwrite('/tmp/1.png', frame)
            #cv2.imwrite('/tmp/1_orig.png', image)
            zf = zipfile.ZipFile(zipfile_name, "a", zipfile.ZIP_DEFLATED)
            name = 'frame_' + '%08d' % frame_num + '.png'
            orig_name = 'frame_orig_' + '%08d' % frame_num + '.png'
            zf.write('/tmp/1.png', name)
            #zf.write('/tmp/1_orig.png', orig_name)
            zf.close()

        elif save_to == 'weighted':

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            frame = cv2.resize(frame, output_size)
            image = cv2.resize(image, output_size)

            frame = cv2.addWeighted(image, alpha, frame, beta, 0)

            out_cap.write(frame)

        elif save_to == 'perspective':

            preds = preds.squeeze()

            img, mask = getCutedRoad(image, preds)

            mask = np.expand_dims(mask, axis=0)
            mask = np.expand_dims(mask, axis=3)

            msk = decode_labels(mask, num_classes=num_classes)
            f = msk[0]

            # h, w = frame.shape[:2]

            # src = np.float32([[x1, y1],    # br
            #           [x0, y1],    # bl
            #           [x0, y0],   # tl
            #           [x1, y0]])  # tr

            # dst = np.float32([[w, h],       # br
            #                 [0, h],       # bl
            #                 [0, 0],       # tl
            #                 [w, 0]])      # tr

            # M = cv2.getPerspectiveTransform(src, dst)
            # Minv = cv2.getPerspectiveTransform(dst, src)

            # warped = cv2.warpPerspective(image, M, (w, h), flags=cv2.INTER_LINEAR)

            # cv2.imshow('1', warped)

            #resized = cv2.resize(img[y0 : y1, x0 : x1], input_size, interpolation = cv2.INTER_NEAREST)

            print((preds == 2).sum() / (preds.shape[0] * preds.shape[1]))
            mask = np.array(mask)
            mask = mask.squeeze()
            print((mask == 2).sum() / (mask.shape[0] * mask.shape[1]))

            cv2.imshow(
                '2',
                cv2.resize(img, input_size, interpolation=cv2.INTER_NEAREST))
            cv2.imshow(
                '3', cv2.resize(f, input_size,
                                interpolation=cv2.INTER_NEAREST))
            cv2.imshow('4', image)
            cv2.waitKey(0)
            #quit()

    cap.release()
    if out_cap is not None:
        out_cap.release()
    if zf is not None:
        zf.close()
Example 15
def main():
    args = get_arguments()
    print(args)

    coord = tf.train.Coordinator()

    tf.reset_default_graph()
    with tf.name_scope("create_inputs"):
        reader = ImageReader(DATA_DIRECTORY, DATA_LIST_PATH, input_size, None,
                             None, ignore_label, IMG_MEAN, coord)
        image, label = reader.image, reader.label
    image_batch, label_batch = tf.expand_dims(image, axis=0), tf.expand_dims(
        label, axis=0)  # Add one batch dimension.

    # Create network.
    if args.model[-2:] == 'bn':
        net = ICNet_BN({'data': image_batch}, num_classes=num_classes)
    else:
        net = ICNet({'data': image_batch}, num_classes=num_classes)

    # Which variables to load.
    restore_var = tf.global_variables()

    # Predictions.
    raw_output = net.layers['conv6_cls']

    raw_output_up = tf.image.resize_bilinear(raw_output,
                                             size=input_size,
                                             align_corners=True)
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    raw_pred = tf.expand_dims(raw_output_up, axis=3)

    # mIoU
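    # Flatten predictions and labels, keep only pixels whose label is a valid
    # class (this drops the ignore label), then accumulate a streaming mean IoU.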
    pred_flatten = tf.reshape(raw_pred, [-1])
    raw_gt = tf.reshape(label_batch, [-1])
    indices = tf.squeeze(tf.where(tf.less_equal(raw_gt, num_classes - 1)), 1)
    gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
    pred = tf.gather(pred_flatten, indices)

    mIoU, update_op = tf.contrib.metrics.streaming_mean_iou(
        pred, gt, num_classes=num_classes)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    restore_var = tf.global_variables()

    if args.model == 'train':
        print('Restore from train30k model...')
        net.load(model_train30k, sess)
    elif args.model == 'trainval':
        print('Restore from trainval90k model...')
        net.load(model_trainval90k, sess)
    elif args.model == 'train_bn':
        print('Restore from train30k bnnomerge model...')
        net.load(model_train30k_bn, sess)
    elif args.model == 'trainval_bn':
        print('Restore from trainval90k bnnomerge model...')
        net.load(model_trainval90k_bn, sess)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    for step in range(num_steps):
        preds, _ = sess.run([pred, update_op])

        if step > 0 and args.measure_time:
            calculate_time(sess, net, raw_pred)

        if step % 10 == 0:
            print('Finish {0}/{1}'.format(step, num_steps))

    print('step {0} mIoU: {1}'.format(step, sess.run(mIoU)))

    coord.request_stop()
    coord.join(threads)
Example 16
def main():
    args = get_arguments()

    if not args.img_path.lower().endswith(('.jpg', '.jpeg', '.png')):
        files = GetAllFilesListRecusive(args.img_path,
                                        ['.jpg', '.jpeg', '.png'])
    else:
        files = [args.img_path]

    shape = INPUT_SIZE.split(',')
    shape = (int(shape[0]), int(shape[1]), 3)

    x = tf.placeholder(dtype=tf.float32, shape=shape)
    img_tf = preprocess(x)
    img_tf, n_shape = check_input(img_tf)

    # Create network.
    net = ICNet_BN({'data': img_tf},
                   is_training=False,
                   num_classes=num_classes)

    # Predictions.
    raw_output = net.layers['conv6_cls']
    output = tf.image.resize_bilinear(raw_output, tf.shape(img_tf)[1:3])
    output = tf.argmax(output, axis=3)
    pred = tf.expand_dims(output, axis=3)

    # Init tf Session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    restore_var = tf.global_variables()

    ckpt = tf.train.get_checkpoint_state(args.snapshots_dir)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load_step = int(
            os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess, ckpt.model_checkpoint_path)

    for path in files:

        img, filename = load_img(path)

        preds = sess.run(pred, feed_dict={x: img})

        msk = decode_labels(preds, num_classes=num_classes)
        im = msk[0]

        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)

        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if args.weighted:
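        # Blend the color mask over the source image, keeping the original
        # pixels wherever the mask is background (black).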
            indx = (im == [0, 0, 0])
            im = cv2.addWeighted(im, 0.7, img, 0.7, -15)
            im[indx] = img[indx]

        cv2.imwrite(args.save_dir + filename.replace('.jpg', '.png'), im)
Example 17
def main():
    """Create the model and start the training."""
    args = get_arguments()
    """
    Get configurations here. We pass some arguments from the command line to init the configuration;
    training hyperparameters can be set in the TrainConfig class.

    Note: we set filter_scale to 1 for the pruned model and 2 for the non-pruned model. The filter counts
          of the non-pruned model are twice those of the pruned model, e.g., [h, w, 64] <-> [h, w, 32].
    """
    cfg = TrainConfig(dataset=args.dataset,
                      is_training=True,
                      random_scale=args.random_scale,
                      random_mirror=args.random_mirror,
                      filter_scale=args.filter_scale)
    if args.num_classes is not None:
        cfg.param["num_classes"] = args.num_classes
    if args.data_dir is not None:
        cfg.param["data_dir"] = args.data_dir
    if args.val_list is not None:
        cfg.param["eval_list"] = args.val_list
    if args.train_list is not None:
        cfg.param["train_list"] = args.train_list
    if args.ignore_label is not None:
        cfg.param["ignore_label"] = args.ignore_label
    if args.eval_size is not None:
        cfg.param["eval_size"] = [
            int(x.strip()) for x in args.eval_size.split("x")[::-1]
        ]
    if args.training_size is not None:
        cfg.TRAINING_SIZE = [
            int(x.strip()) for x in args.training_size.split("x")[::-1]
        ]
    if args.batch_size is not None:
        cfg.BATCH_SIZE = args.batch_size
    if args.learning_rate is not None:
        cfg.LEARNING_RATE = args.learning_rate
    if args.restore_from is not None:
        cfg.model_weight = args.restore_from
    if args.snapshot_dir is not None:
        cfg.SNAPSHOT_DIR = args.snapshot_dir
    if args.restore_from == "scratch":
        from tqdm import tqdm
        import cv2
        import joblib
        if not args.img_mean:
            print(
                "Calculating img mean for custom dataset. To prevent this, specify it with --img-mean next time"
            )
            image_files, annotation_files = read_labeled_image_list(
                cfg.param["data_dir"], cfg.param["train_list"])
            means = joblib.Parallel(n_jobs=6)(
                joblib.delayed(calc_mean)(image_file, cv2)
                for image_file in tqdm(image_files, desc="calc img mean"))
            cfg.IMG_MEAN = np.mean(means, axis=0).tolist()
        else:
            cfg.IMG_MEAN = [float(x.strip()) for x in args.img_mean.split(",")]

    cfg.display()

    # Setup training network and training samples
    train_reader = ImageReader(cfg=cfg, mode='train')
    train_net = ICNet_BN(image_reader=train_reader, cfg=cfg, mode='train')

    loss_sub4, loss_sub24, loss_sub124, reduced_loss = create_losses(
        train_net, train_net.labels, cfg)

    # Setup validation network and validation samples
    with tf.variable_scope('', reuse=True):
        val_reader = ImageReader(cfg, mode='eval')
        val_net = ICNet_BN(image_reader=val_reader, cfg=cfg, mode='train')

        val_loss_sub4, val_loss_sub24, val_loss_sub124, val_reduced_loss = create_losses(
            val_net, val_net.labels, cfg)

    # Using Poly learning rate policy
    base_lr = tf.constant(cfg.LEARNING_RATE)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / cfg.TRAINING_STEPS), cfg.POWER))
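    # Written out, the poly policy is:
    #   lr(step) = LEARNING_RATE * (1 - step / TRAINING_STEPS) ** POWER,
    # annealing smoothly from LEARNING_RATE at step 0 towards 0 at the end.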

    # Set restore variable
    restore_var = tf.global_variables()
    all_trainable = [
        v for v in tf.trainable_variables()
        if ('beta' not in v.name and 'gamma' not in v.name)
        or args.train_beta_gamma
    ]

    # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, cfg.MOMENTUM)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))

    # Create session & restore weights (Here we only need to use train_net to create session since we reuse it)
    train_net.create_session()
    if args.initializer:
        train_net.set_initializer(initializer_algorithm=args.initializer)
    train_net.initialize_variables()
    if args.restore_from != "scratch":
        train_net.restore(cfg.model_weight, restore_var)
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=20)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Total trainable parameters: " + str(total_parameters))

    # Iterate over training steps.
    val_loss_value = 10.0  # placeholder until the first evaluation runs
    min_val_loss = float("inf")
    stagnation = 0
    max_non_decreasing_val_loss = int(
        np.ceil(args.early_stopping_patience * len(train_reader.image_list) /
                (cfg.BATCH_SIZE * cfg.EVAL_EVERY)))
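    # The patience is given in epochs; the expression above converts it into
    # evaluation rounds: ceil(patience * steps_per_epoch / EVAL_EVERY), with
    # steps_per_epoch = len(image_list) / BATCH_SIZE.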
    print(
        "Maximum number of evaluations the val loss may stagnate before "
        "early stopping is applied: " + str(max_non_decreasing_val_loss))
    for step in range(cfg.TRAINING_STEPS):
        start_time = time.time()

        feed_dict = {step_ph: step}
        if step % cfg.EVAL_EVERY == 0:
            loss_value, loss1, loss2, loss3, val_loss_value, _ = train_net.sess.run(
                [
                    reduced_loss, loss_sub4, loss_sub24, loss_sub124,
                    val_reduced_loss, train_op
                ],
                feed_dict=feed_dict)
            if val_loss_value < min_val_loss:
                print("New best val loss {:.3f}. Saving weights...".format(
                    val_loss_value))
                train_net.save(
                    saver,
                    cfg.SNAPSHOT_DIR,
                    step,
                    model_name="val{:.3f}model.ckpt".format(val_loss_value))
                min_val_loss = val_loss_value
                stagnation = 0
            else:
                stagnation += 1
        else:
            loss_value, loss1, loss2, loss3, _ = train_net.sess.run(
                [reduced_loss, loss_sub4, loss_sub24, loss_sub124, train_op],
                feed_dict=feed_dict)

        duration = time.time() - start_time
        print('step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, '
              'sub124 = {:.3f}, val_loss: {:.3f} ({:.3f} sec/step)'.format(
                  step, loss_value, loss1, loss2, loss3, val_loss_value, duration))

        if stagnation > max_non_decreasing_val_loss:
            print("Early stopping")
            break
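
A minimal standalone sketch of the poly learning-rate policy used throughout these examples (pure Python; the base_lr, power and total_steps defaults are illustrative, not taken from any config in this file):

def poly_lr(step, base_lr=0.01, power=0.9, total_steps=60000):
    """Poly decay: anneals base_lr smoothly to 0 over total_steps."""
    return base_lr * (1.0 - step / float(total_steps)) ** power

# e.g. poly_lr(0) == 0.01, poly_lr(30000) ~ 0.0054, poly_lr(60000) == 0.0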
Exemplo n.º 18
0
def main():
    """Create the model and start the training."""
    args = get_arguments()
    
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    
    coord = tf.train.Coordinator()
    
    with tf.name_scope("create_inputs"):
        reader = ImageReader(
            ' ',
            args.data_list,
            input_size,
            args.random_scale,
            args.random_mirror,
            args.ignore_label,
            IMG_MEAN,
            coord)
        image_batch, label_batch = reader.dequeue(args.batch_size)
    
    net = ICNet_BN({'data': image_batch}, is_training=True, num_classes=args.num_classes)
    
    sub4_out = net.layers['sub4_out']
    sub24_out = net.layers['sub24_out']
    sub124_out = net.layers['conv6_cls']

    restore_var = tf.global_variables()
    all_trainable = [v for v in tf.trainable_variables() if ('beta' not in v.name and 'gamma' not in v.name) or args.train_beta_gamma]
   
    loss_sub4 = create_loss(sub4_out, label_batch, args.num_classes, args.ignore_label)
    loss_sub24 = create_loss(sub24_out, label_batch, args.num_classes, args.ignore_label)
    loss_sub124 = create_loss(sub124_out, label_batch, args.num_classes, args.ignore_label)
    l2_losses = [args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'weights' in v.name]
    
    reduced_loss = LAMBDA1 * loss_sub4 +  LAMBDA2 * loss_sub24 + LAMBDA3 * loss_sub124 + tf.add_n(l2_losses)
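    # Cascade label guidance: the total objective is a weighted sum of the
    # low- (sub4), mid- (sub24) and full-resolution (sub124) branch losses,
    # plus L2 weight decay.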

    # Using Poly learning rate policy 
    base_lr = tf.constant(args.learning_rate)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))
    
    # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))
        
    # Set up tf session and initialize variables. 
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()
    
    sess.run(init)
    
    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)

    ckpt = tf.train.get_checkpoint_state(args.snapshot_dir)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess, ckpt.model_checkpoint_path)
    else:
        print('Restore from pre-trained model...')
        net.load(args.restore_from, sess)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    for step in range(args.num_steps):
        start_time = time.time()
        
        feed_dict = {step_ph: step}
        loss_value, loss1, loss2, loss3, _ = sess.run([reduced_loss, loss_sub4, loss_sub24, loss_sub124, train_op], feed_dict=feed_dict)
        if step % args.save_pred_every == 0:
            save(saver, sess, args.snapshot_dir, step)
        duration = time.time() - start_time
        print('step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, sub124 = {:.3f} ({:.3f} sec/step)'.format(step, loss_value, loss1, loss2, loss3, duration))
        
    coord.request_stop()
    coord.join(threads)
Exemplo n.º 19
0
def main():
    """Create the model and start the training."""
    args = get_arguments()
    
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    if args.center_crop_size is None:
        center_crop_size = None
    else:
        hc, wc = map(int, args.center_crop_size.split(','))
        center_crop_size = (hc, wc)

    with tf.name_scope("create_inputs"):
        reader = ImageReader(
            DATA_DIR,
            DATA_LIST_PATH,
            input_size,
            center_crop_size,
            args.random_scale,
            args.random_mirror,
            args.ignore_label,
            IMG_MEAN)
        image_batch, label_batch = reader.dequeue(args.batch_size)

    net = ICNet_BN({'data': image_batch}, is_training=True, num_classes=args.num_classes, filter_scale=args.filter_scale)

    sub4_recls, sub24_recls, sub124_recls = bn_common.extend_reclassifier(net)

    restore_var = tf.global_variables()
    all_trainable = [v for v in tf.trainable_variables() if ('beta' not in v.name and 'gamma' not in v.name) or args.train_beta_gamma]
   
    loss_sub4 = create_loss(sub4_recls, label_batch, args)
    loss_sub24 = create_loss(sub24_recls, label_batch, args)
    loss_sub124 = create_loss(sub124_recls, label_batch, args)
    
    l2_losses = [args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables()
                 if ('weights' in v.name) or ('kernel' in v.name)]
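    # Cover both naming schemes: 'weights' from slim-style layers and
    # 'kernel' from tf.layers-style layers.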
    
    reduced_loss = LAMBDA1 * loss_sub4 +  LAMBDA2 * loss_sub24 + LAMBDA3 * loss_sub124 + tf.add_n(l2_losses)

    # Using Poly learning rate policy 
    base_lr = tf.constant(args.learning_rate)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))
    
    # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))

    # Set up tf session and initialize variables. 
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    
    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=99)

    ckpt = tf.train.get_checkpoint_state(args.snapshot_dir)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, ckpt.model_checkpoint_path)
    else:
        print('Restore from pre-trained model...')
        net.load(args.restore_from, sess)

    # Start queue threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    for step in range(args.num_steps):
        start_time = time.time()
        
        feed_dict = {step_ph: step}
        loss_value, loss1, loss2, loss3, _ = sess.run([reduced_loss, loss_sub4, loss_sub24, loss_sub124, train_op], feed_dict=feed_dict)
        if step % args.save_pred_every == 0:
            save(saver, sess, args.snapshot_dir, step)
        duration = time.time() - start_time
        print('step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, sub124 = {:.3f} ({:.3f} sec/step)'.format(step, loss_value, loss1, loss2, loss3, duration))
        
    coord.request_stop()
    coord.join(threads)

    sess.close()
Exemplo n.º 20
0
def evaluate_checkpoint(model_path, args):
    coord = tf.train.Coordinator()

    tf.reset_default_graph()

    reader = ImageReader(
            args.data_list,
            INPUT_SIZE,
            random_scale = False,
            random_mirror = False,
            ignore_label = IGNORE_LABEL,
            img_mean = IMG_MEAN,
            coord = coord,
            train = False)
    image_batch, label_batch = reader.dequeue(batch_size)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord = coord, sess = sess)

    # Create network.
    net = ICNet_BN({'data': image_batch}, num_classes = num_classes)
    # Which variables to load.
    restore_var = tf.global_variables()

    # Predictions.
    raw_output = net.layers['conv6_cls']

    raw_output_up = tf.image.resize_bilinear(raw_output, size = INPUT_SIZE, align_corners = True)
    raw_output_up = tf.argmax(raw_output_up, axis = 3)
    pred = tf.expand_dims(raw_output_up, axis = 3)

    # mIoU
    pred_flatten = tf.reshape(pred, [-1,])
    raw_gt = tf.reshape(label_batch, [-1,])
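    # Keep only pixels whose label is a valid class; with --ignore-zero,
    # class 0 is excluded from the metric as well.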
    if args.ignore_zero:
        indices = tf.squeeze(tf.where(
            tf.logical_and(
                tf.less_equal(raw_gt, num_classes - 1),
                tf.greater(raw_gt, 0)
                ),), 
            1)
    else:
        indices = tf.squeeze(tf.where(tf.less_equal(raw_gt, num_classes - 1)), 1)

    gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
    pred = tf.gather(pred_flatten, indices)

    mIoU, update_op = tf.contrib.metrics.streaming_mean_iou(pred, gt, num_classes = num_classes)
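    # streaming_mean_iou accumulates a confusion matrix in local variables:
    # run update_op once per batch, then read mIoU for the aggregate score
    # (hence the local_variables_initializer call below).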
    
    # Summaries
    miou_op = tf.summary.scalar('mIOU', mIoU)
    start = time.time()
    logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                        time.gmtime()))
    
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    saver = tf.train.Saver(var_list = restore_var)
    load(saver, sess, model_path)
    

    for step in range(num_steps):
        preds, _ = sess.run([pred, update_op])

        if step % 500 == 0:
            print('Finish {0}/{1}'.format(step + 1, num_steps))

    iou, summ = sess.run([mIoU, miou_op])

    coord.request_stop()
    #coord.join(threads)

    sess.close()

    return summ, iou
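
For reference, a minimal NumPy sketch of the quantity the streaming mean-IoU metric reports once all updates have run (a hypothetical helper, not part of the code above; rows of the confusion matrix are taken as ground truth, columns as predictions):

import numpy as np

def mean_iou(conf_matrix):
    """mIoU from an accumulated confusion matrix: the mean over classes of
    TP / (TP + FP + FN), skipping classes that never appear."""
    tp = np.diag(conf_matrix)
    fp = conf_matrix.sum(axis=0) - tp
    fn = conf_matrix.sum(axis=1) - tp
    denom = tp + fp + fn
    valid = denom > 0
    return (tp[valid] / denom[valid]).mean()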
Exemplo n.º 21
0
def main():
    """Create the model and start the training."""
    args = get_arguments()
    
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    
    coord = tf.train.Coordinator()
    
    with tf.name_scope("create_inputs"):
        reader = ImageReader(
            args.data_list,
            input_size,
            args.random_scale,
            args.random_mirror,
            args.ignore_label,
            IMG_MEAN,
            coord)
        image_batch, label_batch = reader.dequeue(args.batch_size)
    
    net = ICNet_BN({'data': image_batch}, is_training=True, num_classes=args.num_classes)
    
    sub4_out = net.layers['sub4_out']
    sub24_out = net.layers['sub24_out']
    sub124_out = net.layers['conv6_cls']

    fc_list = ['conv6_cls']

    all_trainable = [v for v in tf.trainable_variables() if ('beta' not in v.name and 'gamma' not in v.name) or args.train_beta_gamma]
    # Restore everything unless --not-restore-last is set, in which case the
    # layers in fc_list are skipped.
    restore_var = [v for v in tf.global_variables()
                   if not any(f in v.name for f in fc_list) or not args.not_restore_last]
   
    for v in restore_var:
        print(v.name)

    loss_sub4 = create_loss(sub4_out, label_batch, args.num_classes, args.ignore_label, args.use_class_weights)
    loss_sub24 = create_loss(sub24_out, label_batch, args.num_classes, args.ignore_label, args.use_class_weights)
    loss_sub124 = create_loss(sub124_out, label_batch, args.num_classes, args.ignore_label, args.use_class_weights)
    l2_losses = [args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'weights' in v.name]
    
    loss = LAMBDA1 * loss_sub4 +  LAMBDA2 * loss_sub24 + LAMBDA3 * loss_sub124

    reduced_loss = loss + tf.add_n(l2_losses)


    ##############################
    # visualization and summary
    ##############################


    # Processed predictions: for visualisation.
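    # Each branch's logits are bilinearly upsampled to the input resolution
    # and argmax'ed into per-pixel class maps for the TensorBoard summaries.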

    # Sub 4
    raw_output_up4 = tf.image.resize_bilinear(sub4_out, tf.shape(image_batch)[1:3])
    raw_output_up4 = tf.argmax(raw_output_up4, axis=3)
    pred4 = tf.expand_dims(raw_output_up4, axis=3)
    # Sub 24
    raw_output_up24 = tf.image.resize_bilinear(sub24_out, tf.shape(image_batch)[1:3])
    raw_output_up24 = tf.argmax(raw_output_up24, axis=3)
    pred24 = tf.expand_dims(raw_output_up24, axis=3)
    # Sub 124
    raw_output_up124 = tf.image.resize_bilinear(sub124_out, tf.shape(image_batch)[1:3])
    raw_output_up124 = tf.argmax(raw_output_up124, axis=3)
    pred124 = tf.expand_dims(raw_output_up124, axis=3)

    images_summary = tf.py_func(inv_preprocess, [image_batch, SAVE_NUM_IMAGES, IMG_MEAN], tf.uint8)
    labels_summary = tf.py_func(decode_labels, [label_batch, SAVE_NUM_IMAGES, args.num_classes], tf.uint8)

    preds_summary4 = tf.py_func(decode_labels, [pred4, SAVE_NUM_IMAGES, args.num_classes], tf.uint8)
    preds_summary24 = tf.py_func(decode_labels, [pred24, SAVE_NUM_IMAGES, args.num_classes], tf.uint8)
    preds_summary124 = tf.py_func(decode_labels, [pred124, SAVE_NUM_IMAGES, args.num_classes], tf.uint8)
    
    total_images_summary = tf.summary.image('images', 
                                     tf.concat(axis=2, values=[images_summary, labels_summary, preds_summary124]), 
                                     max_outputs=SAVE_NUM_IMAGES) # Concatenate row-wise.

    total_summary = [total_images_summary]

    loss_summary = tf.summary.scalar('Total_loss', reduced_loss)

    total_summary.append(loss_summary)
    
    summary_writer = tf.summary.FileWriter(args.snapshot_dir,
                                           graph=tf.get_default_graph())
    ##############################
    ##############################

    # Using Poly learning rate policy 
    if LR_SHEDULE == {}:
        base_lr = tf.constant(args.learning_rate)
        step_ph = tf.placeholder(dtype=tf.float32, shape=())
        learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))
    else:
        step_ph = tf.placeholder(dtype=tf.float32, shape=())
        learning_rate = tf.Variable(LR_SHEDULE.popitem()[1], dtype=tf.float32)

    lr_summary = tf.summary.scalar('Learning_rate', learning_rate)
    total_summary.append(lr_summary)
    
    # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))
        
    # Set up tf session and initialize variables. 
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()
    
    sess.run(init)
    
    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list = tf.global_variables(), max_to_keep = 10)

    ckpt = tf.train.get_checkpoint_state(args.snapshot_dir)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess, ckpt.model_checkpoint_path)
    else:
        print('Restore from pre-trained model...')
        net.load(args.restore_from, sess, ignore_layers = fc_list)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    summ_op = tf.summary.merge(total_summary)
    
    # Iterate over training steps.
    for step in range(args.num_steps):
        start_time = time.time()
        
        if LR_SHEDULE != {}:
            # When the next scheduled step is reached, pop it from the schedule
            # and run the assign op so the new rate actually takes effect.
            next_lr_step = next(iter(LR_SHEDULE))
            if step == next_lr_step:
                sess.run(tf.assign(learning_rate, LR_SHEDULE.pop(next_lr_step)))

        feed_dict = {step_ph: step}
        if step % args.save_pred_every == 0:
            
            loss_value, loss1, loss2, loss3, _, summary =\
                sess.run([reduced_loss, loss_sub4, loss_sub24, loss_sub124, train_op, summ_op], feed_dict = feed_dict)

            save(saver, sess, args.snapshot_dir, step)
            summary_writer.add_summary(summary, step)

        else:
            loss_value, loss1, loss2, loss3, _ = sess.run([reduced_loss, loss_sub4, loss_sub24, loss_sub124, train_op], feed_dict=feed_dict)
            
        duration = time.time() - start_time
        print('step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, sub124 = {:.3f} ({:.3f} sec/step)'.format(step, loss_value, loss1, loss2, loss3, duration))
        
    coord.request_stop()
    coord.join(threads)