# Example #1
def create_loss(output, label, num_classes, ignore_label):
    """Class-weighted sparse softmax cross-entropy loss.

    Args:
        output: logits tensor; reshaped to [num_pixels, num_classes].
        label: ground-truth label map, resized/flattened via prepare_label.
        num_classes: number of valid classes.
        ignore_label: label value excluded from the loss via get_mask.

    Returns:
        Scalar loss over the non-ignored pixels.
    """
    raw_pred = tf.reshape(output, [-1, num_classes])
    label = prepare_label(label,
                          tf.stack(output.get_shape()[1:3]),
                          num_classes=num_classes,
                          one_hot=False)
    label = tf.reshape(label, [-1])

    # Keep only pixels whose label is a valid class (drops ignore_label).
    indices = get_mask(label, num_classes, ignore_label)
    gt = tf.cast(tf.gather(label, indices), tf.int32)
    pred = tf.gather(raw_pred, indices)

    # Per-class weights (class-balancing scheme as used in SegNet); fall back
    # to uniform weighting. BUG FIX: the default must be a *float* constant --
    # an int32 constant here produced int per-pixel weights, which do not
    # match the float32 loss inside tf.losses.sparse_softmax_cross_entropy.
    global dataset_class_weights
    if dataset_class_weights is None:
        dataset_class_weights = tf.constant([1.0] * num_classes)
    class_weights = dataset_class_weights
    weights = tf.gather(class_weights, gt)

    # tf.losses.sparse_softmax_cross_entropy already returns a reduced scalar
    # (weighted mean by default); the extra reduce_mean is a harmless no-op
    # kept for interface parity with the other loss helpers in this file.
    loss = tf.losses.sparse_softmax_cross_entropy(logits=pred,
                                                  labels=gt,
                                                  weights=weights)
    reduced_loss = tf.reduce_mean(loss)

    return reduced_loss
# Example #2
def createLoss_crossEntropyMedianFrequencyBalancing(output, label, num_classes,
                                                    ignore_label, loss_weight):
    """Define cross-entropy loss with median frequency balancing."""
    # One row of class scores per pixel.
    flat_logits = tf.reshape(output, [-1, num_classes])

    processed = prepare_label(label,
                              tf.stack(output.get_shape()[1:3]),
                              num_classes=num_classes,
                              one_hot=False)
    flat_labels = tf.reshape(processed, [-1])

    # Select only the pixels whose label is a real class (not ignore_label).
    keep = get_mask(flat_labels, num_classes, ignore_label)
    gt_ids = tf.cast(tf.gather(flat_labels, keep), tf.int32)
    kept_logits = tf.gather(flat_logits, keep)
    gt_onehot = tf.one_hot(gt_ids, depth=num_classes)

    # Weighted sigmoid cross-entropy: loss_weight scales the positive term.
    per_pixel = tf.nn.weighted_cross_entropy_with_logits(targets=gt_onehot,
                                                         logits=kept_logits,
                                                         pos_weight=loss_weight)
    return tf.reduce_mean(per_pixel)
# Example #3
def create_loss(output, label, num_classes, ignore_label, use_w=False):
    """Sparse softmax cross-entropy loss with optional class re-weighting.

    Args:
        output: logits tensor; reshaped to [num_pixels, num_classes].
        label: ground-truth label map, resized/flattened via prepare_label.
        num_classes: number of valid classes.
        ignore_label: label value excluded from the loss via get_mask.
        use_w: if True, scale each pixel's loss by the sum of CLASS_WEIGHTS
            entries for every class involved at that pixel (either as the
            ground truth or as the predicted class).

    Returns:
        Scalar mean loss over the non-ignored pixels.
    """
    raw_pred = tf.reshape(output, [-1, num_classes])
    label = prepare_label(label, tf.stack(output.get_shape()[1:3]), num_classes=num_classes, one_hot=False)
    label = tf.reshape(label, [-1])

    indices = get_mask(label, num_classes, ignore_label)
    gt = tf.cast(tf.gather(label, indices), tf.int32)
    pred = tf.gather(raw_pred, indices)

    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=gt)

    # Make mistakes for class N more important for the network.
    if use_w:
        if len(CLASS_WEIGHTS) != num_classes:
            print('Incorrect class weights, it will be not used')
        else:
            # BUG FIX: the original compared the *first logit channel*
            # (a raw float score, tf.unstack(pred, axis=-1)[0]) against the
            # class index i, which is meaningless. Compare the predicted
            # class id (argmax over the class axis) instead, matching the
            # intent of the commented-out gt-only weighting it replaced.
            pred_classes = tf.cast(tf.argmax(pred, axis=-1), tf.int32)
            mask = tf.zeros_like(loss)
            for i, w in enumerate(CLASS_WEIGHTS):
                involved = tf.logical_or(tf.equal(gt, i), tf.equal(pred_classes, i))
                mask = mask + tf.cast(involved, tf.float32) * tf.constant(w)

            loss = loss * mask

    reduced_loss = tf.reduce_mean(loss)

    return reduced_loss
# Example #4
def create_loss(pred, label, args):
    """F-score ("dice-style") loss for the road and non-ego-car channels.

    Args:
        pred: raw network output; sigmoid is applied per channel here.
        label: label map turned into a 3-class one-hot via prepare_label.
        args: must provide loss_mult_nonego_car and loss_mult_road.

    Returns:
        Scalar loss: weighted sum of (1 - F2) for cars and (1 - F0.5)
        for road.
    """
    with tf.variable_scope('optimizer_fscore'):
        logits = tf.sigmoid(pred)

        label_onehot = prepare_label(label, tf.stack(pred.get_shape()[1:3]), num_classes=3, one_hot=True)

        # pred has 2 channels; the one-hot label has 3 (first one unused here).
        logits_cls1, logits_cls2 = tf.split(logits, axis=-1, num_or_size_splits=2)
        _, labels_cls1, labels_cls2 = tf.split(label_onehot, axis=-1, num_or_size_splits=3)

        def f_score(logits_1cls, labels_1cls, beta):
            # Soft true/false positive/negative counts from probabilities.
            true_positive = tf.reduce_sum(tf.multiply(logits_1cls, labels_1cls))
            false_positive = tf.reduce_sum(tf.multiply(logits_1cls, (1 - labels_1cls)))
            false_negative = tf.reduce_sum(tf.multiply((1 - logits_1cls), labels_1cls))

            # BUG FIX: guard the divisions with a small epsilon -- when a
            # class is absent from the batch (all counts zero) the original
            # computed 0/0 and poisoned the loss with NaN.
            eps = 1e-7
            precision = true_positive / (true_positive + false_positive + eps)
            recall = true_positive / (true_positive + false_negative + eps)

            f = (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall + eps)
            return f

        # beta > 1 weights recall higher (cars); beta < 1 weights precision (road).
        f_car = f_score(logits_cls2, labels_cls2, 2.0)
        f_car_loss = 1.0 - f_car

        f_road = f_score(logits_cls1, labels_cls1, 0.5)
        f_road_loss = 1.0 - f_road

        overall_loss = f_car_loss * args.loss_mult_nonego_car + f_road_loss * args.loss_mult_road

    return overall_loss
def create_loss(output, label, num_classes, ignore_label):
    """Mean sparse softmax cross-entropy over pixels not marked ignore."""
    logits_2d = tf.reshape(output, [-1, num_classes])
    labels_resized = prepare_label(label, tf.stack(output.get_shape()[1:3]),
                                   num_classes=num_classes, one_hot=False)
    labels_flat = tf.reshape(labels_resized, [-1])

    # Indices of pixels whose label is a valid class.
    valid_idx = get_mask(labels_flat, num_classes, ignore_label)
    valid_labels = tf.cast(tf.gather(labels_flat, valid_idx), tf.int32)
    valid_logits = tf.gather(logits_2d, valid_idx)

    per_pixel_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=valid_logits, labels=valid_labels)
    return tf.reduce_mean(per_pixel_loss)
# Example #6
def forward(net, labels, num_classes):
    """Gather valid-pixel logits/labels and the dense argmax prediction.

    Args:
        net: model whose 'conv6' layer holds the raw logits.
        labels: ground-truth label batch, resized via prepare_label.
        num_classes: number of valid classes.

    Returns:
        (prediction, gt, label_proc, raw_output_up): flattened logits and
        int32 labels restricted to pixels with label < num_classes, the
        processed label map, and the per-pixel argmax of the raw logits.
    """
    raw_output = net.layers['conv6']
    raw_prediction = tf.reshape(raw_output, [-1, num_classes])
    label_proc = prepare_label(labels, tf.stack(raw_output.get_shape()[1:3]), num_classes=num_classes,
                               one_hot=False)  # [batch_size, h, w]
    raw_gt = tf.reshape(label_proc, [-1])
    # Keep pixels whose label is a valid class id (ignore label is assumed
    # to be >= num_classes after prepare_label -- TODO confirm).
    indices = tf.squeeze(tf.where(tf.less_equal(raw_gt, num_classes - 1)), 1)
    gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
    prediction = tf.gather(raw_prediction, indices)

    # FIX: 'dimension' is the long-deprecated alias of 'axis' in tf.argmax.
    raw_output_up = tf.argmax(raw_output, axis=3)

    return prediction, gt, label_proc, raw_output_up
# Example #7
def create_loss(output, label, num_classes, ignore_label):
    """Standard per-pixel softmax cross-entropy, averaged over valid pixels."""
    # Flatten logits: one row of class scores per pixel.
    flattened = tf.reshape(output, [-1, num_classes])
    prepared = prepare_label(label, tf.stack(output.get_shape()[1:3]),
                             num_classes=num_classes, one_hot=False)
    gt_flat = tf.reshape(prepared, [-1])

    # Mask out the ignore label before computing the loss.
    mask_idx = get_mask(gt_flat, num_classes, ignore_label)
    sparse_targets = tf.cast(tf.gather(gt_flat, mask_idx), tf.int32)
    selected_logits = tf.gather(flattened, mask_idx)

    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=sparse_targets, logits=selected_logits)
    reduced = tf.reduce_mean(xent)
    return reduced
# Example #8
def createLoss_softmaxCrossEntropy(output, label, num_classes, ignore_label):
    """Define softmax cross-entropy loss.

    Args:
        output: logits tensor; reshaped to [num_pixels, num_classes].
        label: ground-truth label map, resized/flattened via prepare_label.
        num_classes: number of valid classes.
        ignore_label: label value excluded from the loss via get_mask.

    Returns:
        Scalar mean cross-entropy over the non-ignored pixels.
    """
    # FIX: decode_labels was imported here but never used in this function.
    from tools import prepare_label

    raw_pred = tf.reshape(
        output,
        [-1, num_classes])  # force 2nd dimension to be of length num_classes
    label = prepare_label(label,
                          tf.stack(output.get_shape()[1:3]),
                          num_classes=num_classes,
                          one_hot=False)
    label = tf.reshape(label, [-1])  # flatten tensor

    # Restrict loss to pixels whose label is a valid class.
    indices = get_mask(label, num_classes, ignore_label)
    gt = tf.cast(tf.gather(label, indices), tf.int32)
    pred = tf.gather(raw_pred, indices)

    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred,
                                                          labels=gt)
    reduced_loss = tf.reduce_mean(loss)

    return reduced_loss
# Example #9
def main():
    """Create the PSPNet101 model and run the training loop.

    Builds the queue-based input pipeline, the network, a pixel-wise
    softmax loss with L2 weight decay, and three Momentum optimizers
    (base lr for the backbone, 10x for pyramid-module weights, 20x for
    pyramid-module biases) under a poly learning-rate schedule. Restores
    the latest checkpoint from SNAPSHOT_DIR if present, then trains for
    args.num_steps steps, saving a snapshot every args.save_pred_every.
    """
    args = get_arguments()

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    tf.set_random_seed(args.random_seed)

    coord = tf.train.Coordinator()

    # Queue-based input pipeline (TF1 style); threads are started below.
    with tf.name_scope("create_inputs"):
        reader = ImageReader(args.data_dir, args.data_list, input_size,
                             args.random_scale, args.random_mirror,
                             args.ignore_label, IMG_MEAN, coord)
        image_batch, label_batch = reader.dequeue(args.batch_size)

    net = PSPNet101({'data': image_batch},
                    is_training=True,
                    num_classes=args.num_classes)

    # 'conv6' holds the raw (pre-softmax) logits.
    raw_output = net.layers['conv6']

    # According to the prototxt in the Caffe implementation, the learning
    # rate must be multiplied by 10.0 in the pyramid module.
    fc_list = [
        'conv5_3_pool1_conv', 'conv5_3_pool2_conv', 'conv5_3_pool3_conv',
        'conv5_3_pool6_conv', 'conv6', 'conv5_4'
    ]
    # All global variables are restorable from a snapshot.
    restore_var = [v for v in tf.global_variables()]
    # BN beta/gamma are trained only when args.train_beta_gamma is set.
    all_trainable = [
        v for v in tf.trainable_variables()
        if ('beta' not in v.name and 'gamma' not in v.name)
        or args.train_beta_gamma
    ]
    fc_trainable = [
        v for v in all_trainable if v.name.split('/')[0] in fc_list
    ]
    conv_trainable = [
        v for v in all_trainable if v.name.split('/')[0] not in fc_list
    ]  # lr * 1.0
    fc_w_trainable = [v for v in fc_trainable
                      if 'weights' in v.name]  # lr * 10.0
    fc_b_trainable = [v for v in fc_trainable
                      if 'biases' in v.name]  # lr * 20.0
    # Sanity checks: the three groups must exactly partition the trainables.
    assert (len(all_trainable) == len(fc_trainable) + len(conv_trainable))
    assert (len(fc_trainable) == len(fc_w_trainable) + len(fc_b_trainable))

    # Predictions: ignoring all predictions with labels greater or equal than n_classes
    raw_prediction = tf.reshape(raw_output, [-1, args.num_classes])
    label_proc = prepare_label(label_batch,
                               tf.stack(raw_output.get_shape()[1:3]),
                               num_classes=args.num_classes,
                               one_hot=False)  # [batch_size, h, w]
    raw_gt = tf.reshape(label_proc, [
        -1,
    ])
    # Keep only pixels whose label is a valid class id (< num_classes).
    indices = tf.squeeze(tf.where(tf.less_equal(raw_gt, args.num_classes - 1)),
                         1)
    gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
    prediction = tf.gather(raw_prediction, indices)

    # Pixel-wise softmax loss.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=prediction,
                                                          labels=gt)
    # L2 weight decay on all 'weights' variables (biases/BN excluded).
    l2_losses = [
        args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'weights' in v.name
    ]
    reduced_loss = tf.reduce_mean(loss) + tf.add_n(l2_losses)

    # Poly learning-rate policy: base_lr * (1 - step/num_steps)^power.
    base_lr = tf.constant(args.learning_rate)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))

    # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
    if args.update_mean_var == False:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # One optimizer per lr group; gradients computed once and sliced.
    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
        opt_fc_w = tf.train.MomentumOptimizer(learning_rate * 10.0,
                                              args.momentum)
        opt_fc_b = tf.train.MomentumOptimizer(learning_rate * 20.0,
                                              args.momentum)

        grads = tf.gradients(reduced_loss,
                             conv_trainable + fc_w_trainable + fc_b_trainable)
        grads_conv = grads[:len(conv_trainable)]
        grads_fc_w = grads[len(conv_trainable):(len(conv_trainable) +
                                                len(fc_w_trainable))]
        grads_fc_b = grads[(len(conv_trainable) + len(fc_w_trainable)):]

        train_op_conv = opt_conv.apply_gradients(
            zip(grads_conv, conv_trainable))
        train_op_fc_w = opt_fc_w.apply_gradients(
            zip(grads_fc_w, fc_w_trainable))
        train_op_fc_b = opt_fc_b.apply_gradients(
            zip(grads_fc_b, fc_b_trainable))

        train_op = tf.group(train_op_conv, train_op_fc_w, train_op_fc_b)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # Resume from the latest snapshot, deriving the step from its filename.
    ckpt = tf.train.get_checkpoint_state(SNAPSHOT_DIR)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load_step = int(
            os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess, ckpt.model_checkpoint_path)
    else:
        print('No checkpoint file found.')
        load_step = 0

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    for step in range(args.num_steps):
        start_time = time.time()

        feed_dict = {step_ph: step}
        # NOTE(review): both branches run the same sess.run; the split
        # exists only so a snapshot is saved every save_pred_every steps.
        if step % args.save_pred_every == 0:
            loss_value, _ = sess.run([reduced_loss, train_op],
                                     feed_dict=feed_dict)
            save(saver, sess, args.snapshot_dir, step)
        else:
            loss_value, _ = sess.run([reduced_loss, train_op],
                                     feed_dict=feed_dict)
        duration = time.time() - start_time
        print('step {:d} \t loss = {:.3f}, ({:.3f} sec/step)'.format(
            step, loss_value, duration))

    coord.request_stop()
    coord.join(threads)
# Example #10
def main():
    """Multi-GPU training entry point for DeepLabResNetModel.

    Splits each dequeued batch across args.gpu_nums GPUs, sums the per-GPU
    cross-entropy over valid pixels, normalizes by the total valid-pixel
    count, adds L2 weight decay, and trains with Momentum plus an
    exponential moving average of the variables. Image/lr/mIoU/loss
    summaries are written to args.snapshot_dir.
    """
    args = get_arguments()
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    #tf.set_random_seed(args.random_seed)

    coord = tf.train.Coordinator()

    # Graph is built explicitly with CPU as the default device; per-GPU
    # work is pinned inside the tower loop below.
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # NOTE(review): despite the original "poly" comment this is an
        # exponential-decay schedule -- lr is halved every 20000 steps.
        base_lr = tf.constant(args.learning_rate)
        step_ph = tf.placeholder(dtype=tf.float32, shape=())
        learning_rate = tf.train.exponential_decay(base_lr,
                                    step_ph,
                                    20000,
                                    0.5,
                                    staircase=True)

        tf.summary.scalar('lr', learning_rate)

        opt = tf.train.MomentumOptimizer(learning_rate, 0.9)

        #opt = tf.train.RMSPropOptimizer(learning_rate, 0.9, momentum=0.9, epsilon=1e-10)

        #opt = tf.train.AdamOptimizer(learning_rate)

        # Per-GPU loss sums; train_op temporarily collects mIoU update ops.
        losses = []
        train_op = []

        # One reader dequeues the combined batch, split across GPUs below.
        total_batch_size = args.batch_size*args.gpu_nums

        with tf.name_scope('DeepLabResNetModel') as scope:
            with tf.name_scope("create_inputs"):
                reader = ImageReader(
                    args.data_dir,
                    args.data_list,
                    input_size,
                    args.random_blur,
                    args.random_scale,
                    args.random_mirror,
                    args.random_rotate,
                    args.ignore_label,
                    IMG_MEAN,
                    coord)
                image_batch, label_batch = reader.dequeue(total_batch_size)

                images_splits = tf.split(axis=0, num_or_size_splits=args.gpu_nums, value=image_batch)
                labels_splits = tf.split(axis=0, num_or_size_splits=args.gpu_nums, value=label_batch)

            # The model receives the list of splits; presumably it builds
            # one tower per split -- TODO confirm against DeepLabResNetModel.
            net = DeepLabResNetModel({'data': images_splits}, is_training=True, num_classes=args.num_classes)

            raw_output_list = net.layers['fc_voc12']

            num_valide_pixel = 0
            for i in range(len(raw_output_list)):
                # Pin each tower's loss/summary ops to its own GPU.
                with tf.device('/gpu:%d' % i):
                    # Upsample logits back to the crop size for loss/summary.
                    raw_output_up = tf.image.resize_bilinear(raw_output_list[i], size=input_size, align_corners=True)

                    tf.summary.image('images_{}'.format(i), images_splits[i]+IMG_MEAN, max_outputs = 4)
                    tf.summary.image('labels_{}'.format(i), labels_splits[i], max_outputs = 4)

                    tf.summary.image('predict_{}'.format(i), tf.cast(tf.expand_dims(tf.argmax(raw_output_up, -1),3),tf.float32), max_outputs = 4)

                    all_trainable = [v for v in tf.trainable_variables()]

                    # Predictions: ignoring all predictions with labels greater or equal than n_classes
                    raw_prediction = tf.reshape(raw_output_up, [-1, args.num_classes])
                    label_proc = prepare_label(labels_splits[i], tf.stack(raw_output_up.get_shape()[1:3]), num_classes=args.num_classes, one_hot=False) # [batch_size, h, w]
                    raw_gt = tf.reshape(label_proc, [-1,])
                    #indices = tf.squeeze(tf.where(tf.less_equal(raw_gt, args.num_classes - 1)), 1)
                    # Valid pixels: 0 <= label < num_classes.
                    indices = tf.where(tf.logical_and(tf.less(raw_gt, args.num_classes), tf.greater_equal(raw_gt, 0)))
                    gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
                    prediction = tf.gather(raw_prediction, indices)
                    # Streaming mIoU; its update op must run with train_op.
                    mIoU, update_op = tf.contrib.metrics.streaming_mean_iou(tf.argmax(tf.nn.softmax(prediction), axis=-1), gt, num_classes=args.num_classes)
                    tf.summary.scalar('mean IoU_{}'.format(i), mIoU)
                    train_op.append(update_op)

                    # Pixel-wise softmax loss.
                    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=prediction, labels=gt)
                    num_valide_pixel += tf.shape(gt)[0]

                    # Sum (not mean) per tower; normalized jointly below.
                    losses.append(tf.reduce_sum(loss))

            l2_losses = [args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'weights' in v.name]
            # Mean over all valid pixels across GPUs, plus weight decay.
            reduced_loss = tf.truediv(tf.reduce_sum(losses), tf.cast(num_valide_pixel, tf.float32)) + tf.add_n(l2_losses)
            tf.summary.scalar('average_loss', reduced_loss)

        grads = tf.gradients(reduced_loss, all_trainable, colocate_gradients_with_ops=True)

        # Exponential moving average of trainables + BN statistics.
        variable_averages = tf.train.ExponentialMovingAverage(0.99, step_ph)

        variables_to_average = (tf.trainable_variables() + tf.moving_average_variables())
        variables_averages_op = variable_averages.apply(variables_to_average)

        # Group the optimizer step with the collected mIoU update ops...
        train_op = tf.group(opt.apply_gradients(zip(grads, all_trainable)), *train_op)

        # ...and with the moving-average update.
        train_op = tf.group(train_op, variables_averages_op)

        summary_op = tf.summary.merge_all()

        # Set up tf session and initialize variables.
        config = tf.ConfigProto()
        config.allow_soft_placement=True
        sess = tf.Session(config=config)
        # Local variables are needed by the streaming mIoU metric.
        init = [tf.global_variables_initializer(),tf.local_variables_initializer()]
        sess.run(init)
        # Saver for storing checkpoints of the model.
        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=2)


        # Restore from a ResNet ImageNet checkpoint: only non-'fc' trainables.
        #restore_var = [v for v in tf.trainable_variables() if 'fc' not in v.name]+[v for v in tf.global_variables() if ('moving_mean' in v.name or 'moving_variance' in v.name) and ('biased' not in v.name and 'local_step' not in v.name)]
        restore_var = [v for v in tf.trainable_variables() if 'fc' not in v.name]

        ckpt = tf.train.get_checkpoint_state(args.restore_from)
        if ckpt and ckpt.model_checkpoint_path:
            loader = tf.train.Saver(var_list=restore_var)
            load(loader, sess, ckpt.model_checkpoint_path)
        else:
            print('No checkpoint file found.')

        """
        #restore from snapshot
        restore_var = tf.global_variables()

        ckpt = tf.train.get_checkpoint_state(args.snapshot_dir)
        if ckpt and ckpt.model_checkpoint_path:
            loader = tf.train.Saver(var_list=restore_var, allow_empty=True)
            load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
            load(loader, sess, ckpt.model_checkpoint_path)
        else:
            print('No checkpoint file found.')
            load_step = 0
        """
        # Start queue threads.
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)

        summary_writer = tf.summary.FileWriter(args.snapshot_dir, graph=sess.graph)
        # Iterate over training steps: snapshot every save_pred_every steps,
        # write summaries (and print mIoU) every 100 steps.
        for step in range(args.num_steps):
            start_time = time.time()

            feed_dict = {step_ph: step}
            if step % args.save_pred_every == 0 and step != 0:
                loss_value, _ = sess.run([reduced_loss, train_op], feed_dict=feed_dict)
                save(saver, sess, args.snapshot_dir, step)
            elif step%100 == 0:
                summary_str, loss_value, _, IOU = sess.run([summary_op, reduced_loss, train_op, mIoU], feed_dict=feed_dict)
                duration = time.time() - start_time
                summary_writer.add_summary(summary_str, step)
                print('step {:d} \t loss = {:.3f}, mean_IoU = {:.3f}, ({:.3f} sec/step)'.format(step, loss_value, IOU, duration))
            else:
                loss_value, _ = sess.run([reduced_loss, train_op], feed_dict=feed_dict)

        coord.request_stop()
        coord.join(threads)
# Example #11
def main():
    # lr_decay = 0.5
    # decay_every = 100
    """Create the u-net-style ('unext') model and start the training.

    Builds the input queue and network, a cross-entropy + L2 loss, an
    Adam optimizer on an exponentially decayed learning rate, resumes
    from SNAPSHOT_DIR if a checkpoint exists, and trains for
    args.num_steps steps with periodic snapshots and summaries.
    """
    args = get_arguments()

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    tf.set_random_seed(args.random_seed)

    coord = tf.train.Coordinator()

    # Queue-based input pipeline; reader threads are started further down.
    with tf.name_scope("create_inputs"):
        reader = ImageReader(
            args.data_list,
            input_size,
            args.random_scale,
            args.random_mirror,
            args.ignore_label,
            IMG_MEAN,
            coord)
        image_batch, label_batch = reader.dequeue(args.batch_size)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True
    # config.allow_soft_placement = True
    # config.intra_op_parallelism_threads = 1
    sess = tf.Session(config = config)
    net = unext(image_batch, is_train = True, reuse = False, n_out = NUM_CLASSES)

    # Predictions: ignoring all predictions with labels greater or equal than n_classes
    raw_output = net.outputs
    raw_prediction = tf.reshape(raw_output, [-1, args.num_classes])
    label_proc = prepare_label(label_batch, tf.stack(raw_output.get_shape()[1:3]), num_classes=args.num_classes, one_hot=False) # [batch_size, h, w]
    raw_gt = tf.reshape(label_proc, [-1,])
    # Keep only pixels whose label is a valid class id (< num_classes).
    indices = tf.squeeze(tf.where(tf.less_equal(raw_gt, args.num_classes - 1)), 1)
    gt = tf.cast(tf.gather(raw_gt, indices), dtype = tf.int32)
    prediction = tf.gather(raw_prediction, indices)

    main_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits = prediction, labels = gt)

    # L2 decay on 'kernel' variables (TensorLayer/Keras-style naming).
    t_vars = tf.trainable_variables()
    l2_losses = [args.weight_decay * tf.nn.l2_loss(v) for v in t_vars if 'kernel' in v.name]
    #reduced_loss = 0.5 * tf.reduce_mean(main_loss) + generalised_dice_loss(prediction, gt) + tf.add_n(l2_losses)
    reduced_loss = tf.reduce_mean(main_loss) + tf.add_n(l2_losses)

    # Processed predictions: for visualisation.
    # NOTE(review): 'dimension=' and 'dim=' are deprecated aliases of 'axis='.
    raw_output_up = tf.image.resize_bilinear(raw_output, tf.shape(image_batch)[1:3,])
    raw_output_up = tf.argmax(raw_output_up, dimension = 3)
    pred = tf.expand_dims(raw_output_up, dim = 3)

    # Image summary.
    images_summary = tf.py_func(inv_preprocess, [image_batch, args.save_num_images, IMG_MEAN], tf.uint8)
    labels_summary = tf.py_func(decode_labels, [label_batch, args.save_num_images, args.num_classes], tf.uint8)
    preds_summary = tf.py_func(decode_labels, [pred, args.save_num_images, args.num_classes], tf.uint8)

    total_summary = tf.summary.image('images',
                                     tf.concat(axis=2, values=[images_summary, labels_summary, preds_summary]),
                                     max_outputs=args.save_num_images) # Concatenate row-wise.
    loss_summary = tf.summary.scalar('TotalLoss', reduced_loss)
    summary_writer = tf.summary.FileWriter(args.snapshot_dir,
                                           graph=tf.get_default_graph())

    # Learning-rate schedule: exponential decay driven by the step placeholder
    # (args.power is used as the decay rate here).
    base_lr = tf.constant(args.learning_rate)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.train.exponential_decay(base_lr, step_ph, args.num_steps, args.power)

    lr_summary = tf.summary.scalar('LearningRate', learning_rate)
    #train_op = tf.train.MomentumOptimizer(learning_rate, args.momentum).minimize(reduced_loss, var_list = t_vars)
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(reduced_loss, var_list = t_vars)
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list = tf.global_variables(), max_to_keep = 10)

    # Resume from the latest snapshot if one exists.
    ckpt = tf.train.get_checkpoint_state(SNAPSHOT_DIR)
    if ckpt and ckpt.model_checkpoint_path:
        #restore_vars = list([t for t in tf.global_variables() if not 'uconv1' in t.name])
        loader = tf.train.Saver(var_list = tf.global_variables())
        load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess, ckpt.model_checkpoint_path)
    else:
        print('No checkpoint file found.')
        load_step = 0

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord = coord, sess = sess)

    # Iterate over training steps. Loss/lr summaries are written every
    # save_summary_every steps; image summaries and snapshots every
    # save_pred_every steps.
    save_summary_every = 10
    for step in range(args.num_steps):
        start_time = time.time()

        feed_dict = {step_ph: step}
        if not step % args.save_pred_every == 0:
            loss_value, _, l_summary, lr_summ = sess.run([reduced_loss, train_op, loss_summary, lr_summary], feed_dict=feed_dict)
            duration = time.time() - start_time
        elif step % args.save_pred_every == 0:
            loss_value, _, summary, l_summary, lr_summ = sess.run([reduced_loss, train_op, total_summary, loss_summary, lr_summary], feed_dict=feed_dict)
            duration = time.time() - start_time
            save(saver, sess, args.snapshot_dir, step)
            summary_writer.add_summary(summary, step)

        if step % save_summary_every == 0:

            summary_writer.add_summary(l_summary, step)
            summary_writer.add_summary(lr_summ, step)

        print('step {:d} \t loss = {:.3f}, ({:.3f} sec/step)'.format(step, loss_value, duration))

    coord.request_stop()
    coord.join(threads)
# Example #12
    def run(self):
        """Build the PSPNet training graph and run the training loop.

        Mirrors the PSPNet Caffe recipe: pixel-wise softmax cross-entropy
        plus L2 weight decay, poly learning-rate schedule, and three
        Momentum optimizers (base lr for the backbone, 10x for pyramid
        weights, 20x for pyramid biases). Restores the latest checkpoint
        from self.log_dir if present, then trains for self.num_steps
        steps, snapshotting every self.save_pred_freq steps.
        """
        tf.set_random_seed(self.random_seed)
        coord = tf.train.Coordinator()

        # Input pipeline: queue-based reader producing (image, label) batches.
        with tf.name_scope("create_inputs"):
            reader = ImageReader(self.data_dir, self.data_train_list,
                                 self.input_size, self.random_scale,
                                 self.random_mirror, self.ignore_label,
                                 self.img_mean, coord)
            image_batch, label_batch = reader.dequeue(self.batch_size)

        # Network: 'conv6' holds the raw (pre-softmax) logits.
        net = PSPNet({'data': image_batch},
                     is_training=True,
                     num_classes=self.num_classes)
        raw_output = net.layers['conv6']

        # According to the prototxt in the Caffe implementation, the learning
        # rate must be multiplied by 10.0 in the pyramid module.
        fc_list = [
            'conv5_3_pool1_conv', 'conv5_3_pool2_conv', 'conv5_3_pool3_conv',
            'conv5_3_pool6_conv', 'conv6', 'conv5_4'
        ]
        # All global variables (used when restoring a checkpoint).
        restore_var = [v for v in tf.global_variables()]
        # All trainable variables; BN beta/gamma only if train_beta_gamma.
        all_trainable = [
            v for v in tf.trainable_variables()
            if ('beta' not in v.name and 'gamma' not in v.name)
            or self.train_beta_gamma
        ]
        # Split trainables: pyramid-module (fc_list) vs backbone conv.
        fc_trainable = [
            v for v in all_trainable if v.name.split('/')[0] in fc_list
        ]
        conv_trainable = [
            v for v in all_trainable if v.name.split('/')[0] not in fc_list
        ]  # lr * 1.0
        fc_w_trainable = [v for v in fc_trainable
                          if 'weights' in v.name]  # lr * 10.0
        fc_b_trainable = [v for v in fc_trainable
                          if 'biases' in v.name]  # lr * 20.0
        # Sanity checks: the groups must exactly partition the trainables.
        assert (len(all_trainable) == len(fc_trainable) + len(conv_trainable))
        assert (len(fc_trainable) == len(fc_w_trainable) + len(fc_b_trainable))

        # Predictions: ignoring all predictions with labels greater or equal than n_classes
        raw_prediction = tf.reshape(raw_output, [-1, self.num_classes])
        label_process = prepare_label(label_batch,
                                      tf.stack(raw_output.get_shape()[1:3]),
                                      num_classes=self.num_classes,
                                      one_hot=False)  # [batch_size, h, w]
        raw_gt = tf.reshape(label_process, [
            -1,
        ])
        # Keep only pixels whose label is a valid class id (< num_classes).
        indices = tf.squeeze(
            tf.where(tf.less_equal(raw_gt, self.num_classes - 1)), 1)
        gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
        prediction = tf.gather(raw_prediction, indices)

        # Pixel-wise softmax loss.
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=prediction, labels=gt)
        # L2 weight decay on 'weights' variables only.
        l2_losses = [
            self.weight_decay * tf.nn.l2_loss(v)
            for v in tf.trainable_variables() if 'weights' in v.name
        ]
        reduced_loss = tf.reduce_mean(loss) + tf.add_n(l2_losses)

        # Poly learning-rate policy: base_lr * (1 - step/num_steps)^power.
        base_lr = tf.constant(self.learning_rate)
        step_ph = tf.placeholder(dtype=tf.float32, shape=())
        learning_rate = tf.scalar_mul(
            base_lr, tf.pow((1 - step_ph / self.num_steps), self.power))

        # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
        update_ops = None if not self.update_mean_var else tf.get_collection(
            tf.GraphKeys.UPDATE_OPS)

        # Optimize the groups at different learning rates: gradients are
        # computed once for all groups, then sliced and applied per group.
        with tf.control_dependencies(update_ops):
            opt_conv = tf.train.MomentumOptimizer(learning_rate, self.momentum)
            opt_fc_w = tf.train.MomentumOptimizer(learning_rate * 10.0,
                                                  self.momentum)
            opt_fc_b = tf.train.MomentumOptimizer(learning_rate * 20.0,
                                                  self.momentum)

            grads = tf.gradients(
                reduced_loss, conv_trainable + fc_w_trainable + fc_b_trainable)
            grads_conv = grads[:len(conv_trainable)]
            grads_fc_w = grads[len(conv_trainable):(len(conv_trainable) +
                                                    len(fc_w_trainable))]
            grads_fc_b = grads[(len(conv_trainable) + len(fc_w_trainable)):]

            train_op_conv = opt_conv.apply_gradients(
                zip(grads_conv, conv_trainable))
            train_op_fc_w = opt_fc_w.apply_gradients(
                zip(grads_fc_w, fc_w_trainable))
            train_op_fc_b = opt_fc_b.apply_gradients(
                zip(grads_fc_b, fc_b_trainable))

            train_op = tf.group(train_op_conv, train_op_fc_w, train_op_fc_b)
            pass

        sess = tf.Session(config=self.config)
        sess.run(tf.global_variables_initializer())

        # Saver for storing checkpoints of the model.
        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Restore the model from the latest checkpoint in log_dir, if any.
        ckpt = tf.train.get_checkpoint_state(self.log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            tf.train.Saver(var_list=restore_var).restore(
                sess, ckpt.model_checkpoint_path)
            Tools.print_info("Restored model parameters from {}".format(
                ckpt.model_checkpoint_path))
        else:
            Tools.print_info('No checkpoint file found.')
            pass

        # Start queue threads.
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)

        # Iterate over training steps; both branches run the same training
        # step, the first additionally saves a checkpoint.
        for step in range(self.num_steps):
            start_time = time.time()
            if step % self.save_pred_freq == 0:
                loss_value, _ = sess.run([reduced_loss, train_op],
                                         feed_dict={step_ph: step})
                saver.save(sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')
            else:
                loss_value, _ = sess.run([reduced_loss, train_op],
                                         feed_dict={step_ph: step})
            duration = time.time() - start_time
            Tools.print_info(
                'step {:d} \t loss = {:.3f}, ({:.3f} sec/step)'.format(
                    step, loss_value, duration))

        coord.request_stop()
        coord.join(threads)
        pass
# Example #13
def main():
    """Build the MobileNet-PSP segmentation training graph and train it.

    Configuration comes from the module-level FLAGS; relies on the
    module-level helpers ImageReader, MobileNet, prepare_label, load,
    save and the IMG_MEAN constant.
    """
    # Print every flag in deterministic (sorted) order.
    # NOTE: the original called .sort() on dict.items(), which is a view
    # in Python 3 and has no sort() method; the sorted copy was also
    # never used by the printing loop.
    for params, value in sorted(FLAGS.__flags.items()):
        print('{}: {}'.format(params, value))

    input_size = (FLAGS.train_image_size, FLAGS.train_image_size)

    tf.set_random_seed(1234)

    coord = tf.train.Coordinator()

    # Queue-based input pipeline producing image/label batches.
    reader = ImageReader(FLAGS.data_dir, FLAGS.data_list, input_size,
                         FLAGS.random_scale, FLAGS.random_mirror,
                         FLAGS.ignore_label, IMG_MEAN, coord)
    image_batch, label_batch = reader.dequeue(FLAGS.batch_size)

    raw_output = MobileNet(image_batch,
                           isTraining=True,
                           updateBeta=FLAGS.update_beta)

    # Pyramid-module layers get a larger learning rate (10x for weights,
    # 20x for biases); everything else trains at the base rate.
    psp_list = [
        'conv_ds_15a', 'conv_ds_15b', 'conv_ds_15c', 'conv_ds_15d',
        'conv_ds_16', 'conv_ds_17'
    ]
    all_trainable = [v for v in tf.trainable_variables()]
    if not FLAGS.update_beta:
        # Freeze batch-norm beta parameters when requested.
        all_trainable = [v for v in all_trainable if 'beta' not in v.name]
    psp_trainable = [
        v for v in all_trainable if v.name.split('/')[1] in psp_list and (
            'weights' in v.name or 'biases' in v.name)
    ]
    conv_trainable = [v for v in all_trainable
                      if v not in psp_trainable]  # lr * 1.0
    psp_w_trainable = [v for v in psp_trainable
                       if 'weights' in v.name]  # lr * 10.0
    psp_b_trainable = [v for v in psp_trainable
                       if 'biases' in v.name]  # lr * 20.0

    # Sanity checks: the groups partition the trainable variables.
    assert len(all_trainable) == len(psp_trainable) + len(conv_trainable)
    assert len(psp_trainable) == len(psp_w_trainable) + len(psp_b_trainable)

    # Predictions: ignore all pixels whose label is >= num_classes.
    raw_prediction = tf.reshape(raw_output, [-1, FLAGS.num_classes])
    label_proc = prepare_label(label_batch,
                               tf.stack(raw_output.get_shape()[1:3]),
                               num_classes=FLAGS.num_classes,
                               one_hot=False)  # [batch_size, h, w]
    raw_gt = tf.reshape(label_proc, [
        -1,
    ])
    indices = tf.squeeze(
        tf.where(tf.less_equal(raw_gt, FLAGS.num_classes - 1)), 1)
    gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
    prediction = tf.gather(raw_prediction, indices)

    # Pixel-wise softmax loss over the valid pixels only.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=prediction,
                                                          labels=gt)
    # L2 regularisation on the weight variables.
    l2_losses = [
        FLAGS.weight_decay * tf.nn.l2_loss(v)
        for v in tf.trainable_variables() if 'weights' in v.name
    ]
    reduced_loss = tf.reduce_mean(loss) + tf.add_n(l2_losses)
    # TODO: auxiliary loss

    # Poly learning-rate policy, fed the current epoch number.
    current_epoch = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.train.polynomial_decay(
        FLAGS.start_learning_rate,
        current_epoch,
        FLAGS.decay_steps,
        end_learning_rate=FLAGS.end_learning_rate,
        power=FLAGS.learning_rate_decay_power,
        name="poly_learning_rate")

    # Collect moving_mean/moving_variance updates only when requested.
    if not FLAGS.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        if FLAGS.optimizer == 'momentum':
            opt_conv = tf.train.MomentumOptimizer(learning_rate,
                                                  FLAGS.momentum)
            opt_psp_w = tf.train.MomentumOptimizer(learning_rate * 10.0,
                                                   FLAGS.momentum)
            opt_psp_b = tf.train.MomentumOptimizer(learning_rate * 20.0,
                                                   FLAGS.momentum)
        elif FLAGS.optimizer == 'rmsprop':
            opt_conv = tf.train.RMSPropOptimizer(
                learning_rate,
                decay=FLAGS.rmsprop_decay,
                momentum=FLAGS.rmsprop_momentum,
                epsilon=FLAGS.opt_epsilon)
            opt_psp_w = tf.train.RMSPropOptimizer(
                learning_rate * 10.0,
                decay=FLAGS.rmsprop_decay,
                momentum=FLAGS.rmsprop_momentum,
                epsilon=FLAGS.opt_epsilon)
            opt_psp_b = tf.train.RMSPropOptimizer(
                learning_rate * 20.0,
                decay=FLAGS.rmsprop_decay,
                momentum=FLAGS.rmsprop_momentum,
                epsilon=FLAGS.opt_epsilon)
        else:
            # Fail fast: the original fell through to a confusing
            # NameError on `opt_conv` for any other optimizer value.
            raise ValueError(
                'Unsupported optimizer: {}'.format(FLAGS.optimizer))

        # One gradients call over the concatenated variable list, then
        # split the result back into the three groups (order matters).
        grads = tf.gradients(
            reduced_loss, conv_trainable + psp_w_trainable + psp_b_trainable)
        grads_conv = grads[:len(conv_trainable)]
        grads_psp_w = grads[len(conv_trainable):(len(conv_trainable) +
                                                 len(psp_w_trainable))]
        grads_psp_b = grads[(len(conv_trainable) + len(psp_w_trainable)):]

        train_op_conv = opt_conv.apply_gradients(
            zip(grads_conv, conv_trainable))
        train_op_psp_w = opt_psp_w.apply_gradients(
            zip(grads_psp_w, psp_w_trainable))
        train_op_psp_b = opt_psp_b.apply_gradients(
            zip(grads_psp_b, psp_b_trainable))

        train_op = tf.group(train_op_conv, train_op_psp_w, train_op_psp_b)

    restore_var = tf.global_variables()

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=500)

    load(sess, FLAGS.pretrained_checkpoint, restore_var)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Epoch-based training loop; one checkpoint per epoch.
    for epoch in range(FLAGS.start_epoch,
                       FLAGS.start_epoch + FLAGS.num_epochs):

        total_loss = 0.0
        for step in range(1, FLAGS.num_steps + 1):

            start_time = time.time()

            feed_dict = {current_epoch: epoch}
            loss_value, _ = sess.run([reduced_loss, train_op],
                                     feed_dict=feed_dict)

            duration = time.time() - start_time
            print('step {:d} \t loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))
            # TODO: ignore NaN loss
            total_loss += loss_value

        save(saver, sess, FLAGS.log_dir, epoch)
        total_loss /= FLAGS.num_steps
        print('Epoch {:d} completed! Total Loss = {:.3f}'.format(
            epoch, total_loss))

    coord.request_stop()
    coord.join(threads)
Пример #14
0
def main():
    """Create the PSPNet50 model and run the training loop.

    Configuration comes from get_arguments(); relies on module-level
    ImageReader, PSPNet50, prepare_label, inv_preprocess, decode_labels,
    IMG_MEAN, SNAPSHOT_DIR, USE_CLASS_WEIGHTS, CLASS_WEIGHTS, NUM_CLASSES
    and load/save helpers.
    """
    args = get_arguments()

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    tf.set_random_seed(args.random_seed)

    coord = tf.train.Coordinator()

    with tf.name_scope("create_inputs"):
        reader = ImageReader(args.data_list, input_size, args.random_scale,
                             args.random_mirror, args.ignore_label, IMG_MEAN,
                             coord)
        image_batch, label_batch = reader.dequeue(args.batch_size)

    net = PSPNet50({'data': image_batch},
                   is_training=True,
                   num_classes=args.num_classes)

    raw_output = net.layers['conv6']

    # According to the prototxt in the Caffe implementation, the pyramid
    # module layers use a 10x (weights) / 20x (biases) learning rate.
    fc_list = [
        'conv5_3_pool1_conv', 'conv5_3_pool2_conv', 'conv5_3_pool3_conv',
        'conv5_3_pool6_conv', 'conv6', 'conv5_4'
    ]
    # Restore every variable, except the final (fc_list) layers when
    # --not-restore-last was requested.
    restore_var = [
        v for v in tf.global_variables()
        if not any(f in v.name for f in fc_list) or not args.not_restore_last
    ]
    all_trainable = [
        v for v in tf.trainable_variables()
        if 'gamma' not in v.name and 'beta' not in v.name
    ]
    fc_trainable = [
        v for v in all_trainable if v.name.split('/')[0] in fc_list
    ]
    conv_trainable = [
        v for v in all_trainable if v.name.split('/')[0] not in fc_list
    ]  # lr * 1.0
    fc_w_trainable = [v for v in fc_trainable
                      if 'weights' in v.name]  # lr * 10.0
    fc_b_trainable = [v for v in fc_trainable
                      if 'biases' in v.name]  # lr * 20.0
    # Sanity checks: the groups partition the trainable variables.
    assert len(all_trainable) == len(fc_trainable) + len(conv_trainable)
    assert len(fc_trainable) == len(fc_w_trainable) + len(fc_b_trainable)

    # Predictions: ignore all pixels whose label is >= num_classes.
    raw_prediction = tf.reshape(raw_output, [-1, args.num_classes])
    label_proc = prepare_label(label_batch,
                               tf.stack(raw_output.get_shape()[1:3]),
                               num_classes=args.num_classes,
                               one_hot=False)  # [batch_size, h, w]
    raw_gt = tf.reshape(label_proc, [
        -1,
    ])
    indices = tf.squeeze(tf.where(tf.less_equal(raw_gt, args.num_classes - 1)),
                         1)
    gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
    prediction = tf.gather(raw_prediction, indices)

    # Pixel-wise softmax loss.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=prediction,
                                                          labels=gt)

    # Optionally re-weight per-class losses so mistakes on some classes
    # matter more to the network.
    if USE_CLASS_WEIGHTS:
        if len(CLASS_WEIGHTS) != NUM_CLASSES:
            print('Incorrect class weights, it will be not used')
        else:
            mask = tf.zeros_like(loss)

            for i, w in enumerate(CLASS_WEIGHTS):
                mask = mask + tf.cast(tf.equal(gt, i),
                                      tf.float32) * tf.constant(w)
            loss = loss * mask

    # L2 regularisation on the weight variables.
    l2_losses = [
        args.weight_decay * tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'weights' in v.name
    ]
    reduced_loss = tf.reduce_mean(loss) + tf.add_n(l2_losses)

    # Processed predictions: for visualisation.
    # Uses `axis=` (the `dimension=`/`dim=` keywords are deprecated and
    # the rest of this file already uses `axis=`).
    raw_output_up = tf.image.resize_bilinear(raw_output,
                                             tf.shape(image_batch)[1:3, ])
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    pred = tf.expand_dims(raw_output_up, axis=3)

    # Image summary: input | ground truth | prediction side by side.
    images_summary = tf.py_func(inv_preprocess,
                                [image_batch, args.save_num_images, IMG_MEAN],
                                tf.uint8)
    labels_summary = tf.py_func(
        decode_labels, [label_batch, args.save_num_images, args.num_classes],
        tf.uint8)
    preds_summary = tf.py_func(decode_labels,
                               [pred, args.save_num_images, args.num_classes],
                               tf.uint8)

    total_summary = tf.summary.image(
        'images',
        tf.concat(axis=2,
                  values=[images_summary, labels_summary, preds_summary]),
        max_outputs=args.save_num_images)  # Concatenate row-wise.
    summary_writer = tf.summary.FileWriter(args.snapshot_dir,
                                           graph=tf.get_default_graph())

    # Poly learning-rate policy.
    base_lr = tf.constant(args.learning_rate)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / args.num_steps), args.power))

    # Gets moving_mean/moving_variance updates from UPDATE_OPS on demand.
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
        opt_fc_w = tf.train.MomentumOptimizer(learning_rate * 10.0,
                                              args.momentum)
        opt_fc_b = tf.train.MomentumOptimizer(learning_rate * 20.0,
                                              args.momentum)

        # One gradients call over the concatenated variable list, then
        # split the result back into the three groups (order matters).
        grads = tf.gradients(reduced_loss,
                             conv_trainable + fc_w_trainable + fc_b_trainable)
        grads_conv = grads[:len(conv_trainable)]
        grads_fc_w = grads[len(conv_trainable):(len(conv_trainable) +
                                                len(fc_w_trainable))]
        grads_fc_b = grads[(len(conv_trainable) + len(fc_w_trainable)):]

        train_op_conv = opt_conv.apply_gradients(
            zip(grads_conv, conv_trainable))
        train_op_fc_w = opt_fc_w.apply_gradients(
            zip(grads_fc_w, fc_w_trainable))
        train_op_fc_b = opt_fc_b.apply_gradients(
            zip(grads_fc_b, fc_b_trainable))

        train_op = tf.group(train_op_conv, train_op_fc_w, train_op_fc_b)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # Resume from the latest snapshot when one exists.
    # NOTE(review): restore reads SNAPSHOT_DIR while saving uses
    # args.snapshot_dir — confirm these point to the same directory.
    ckpt = tf.train.get_checkpoint_state(SNAPSHOT_DIR)
    if ckpt and ckpt.model_checkpoint_path:
        loader = tf.train.Saver(var_list=restore_var)
        load_step = int(
            os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
        load(loader, sess, ckpt.model_checkpoint_path)
    else:
        print('No checkpoint file found.')
        load_step = 0

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    for step in range(args.num_steps):
        start_time = time.time()

        feed_dict = {step_ph: step}
        if step % args.save_pred_every == 0:
            loss_value, _, summary = sess.run(
                [reduced_loss, train_op, total_summary], feed_dict=feed_dict)
            summary_writer.add_summary(summary, step)
            save(saver, sess, args.snapshot_dir, step)
        else:
            # The original also fetched raw_gt/raw_output/gt/prediction
            # and printed their shapes every step — leftover debug code
            # that transferred large tensors from device for no purpose.
            loss_value, _ = sess.run([reduced_loss, train_op],
                                     feed_dict=feed_dict)
        duration = time.time() - start_time
        print('step {:d} \t loss = {:.3f}, ({:.3f} sec/step)'.format(
            step, loss_value, duration))

    coord.request_stop()
    coord.join(threads)
Пример #15
0
    def train_setup(self):
        """Build the Deeplab_v2 training graph: inputs, loss, per-group
        optimizers, savers and image summaries.

        Side effects: sets self.coord, self.image_batch, self.label_batch,
        self.reduced_loss, self.curr_step, self.train_op, self.saver,
        self.loader, self.pred, self.total_summary, self.summary_writer.
        """
        tf.set_random_seed(self.parameters.random_seed)

        # Queue coordinator and queue-based input pipeline.
        self.coord = tf.train.Coordinator()
        input_size = (self.parameters.input_height,
                      self.parameters.input_width)
        with tf.name_scope("create_inputs"):
            reader = ImageReader(self.parameters.data_dir,
                                 self.parameters.data_list, input_size,
                                 self.parameters.random_scale,
                                 self.parameters.random_mirror,
                                 self.parameters.ignore_label, IMG_MEAN,
                                 self.coord)
            self.image_batch, self.label_batch = reader.dequeue(
                self.parameters.batch_size)

        net = Deeplab_v2(self.image_batch, self.parameters.num_classes, True,
                         self.parameters.dilated_type)
        # Restore only the encoder: skip the decoder ('fc') and 'fix_w'
        # variables, which are trained from scratch.
        restore_var = [
            v for v in tf.global_variables()
            if 'fc' not in v.name and 'fix_w' not in v.name
        ]
        all_trainable = tf.trainable_variables()
        encoder_trainable = [v for v in all_trainable if 'fc' not in v.name]
        decoder_trainable = [v for v in all_trainable if 'fc' in v.name]

        # Decoder weights/gammas train at lr*10, biases/betas at lr*20
        # (see the optimizer setup below).
        decoder_w_trainable = [
            v for v in decoder_trainable
            if 'weights' in v.name or 'gamma' in v.name
        ]
        decoder_b_trainable = [
            v for v in decoder_trainable
            if 'biases' in v.name or 'beta' in v.name
        ]
        # Sanity checks: the groups partition the trainable variables.
        assert (len(all_trainable) == len(decoder_trainable) +
                len(encoder_trainable))
        assert (len(decoder_trainable) == len(decoder_w_trainable) +
                len(decoder_b_trainable))

        raw_output = net.outputs

        output_shape = tf.shape(raw_output)
        output_size = (output_shape[1], output_shape[2])

        # Resize labels to the network output resolution, then flatten.
        label_proc = prepare_label(self.label_batch,
                                   output_size,
                                   num_classes=self.parameters.num_classes,
                                   one_hot=False)
        raw_gt = tf.reshape(label_proc, [
            -1,
        ])
        # Keep only pixels whose label is a valid class index; drops the
        # ignore label, which is assumed to be >= num_classes.
        indices = tf.squeeze(
            tf.where(tf.less_equal(raw_gt, self.parameters.num_classes - 1)),
            1)
        gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
        raw_prediction = tf.reshape(raw_output,
                                    [-1, self.parameters.num_classes])
        prediction = tf.gather(raw_prediction, indices)

        # Pixel-wise softmax cross-entropy over the valid pixels only.
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=prediction, labels=gt)
        # L2 regularization
        l2_losses = [
            self.parameters.weight_decay * tf.nn.l2_loss(v)
            for v in all_trainable if 'weights' in v.name
        ]
        # Loss function
        self.reduced_loss = tf.reduce_mean(loss) + tf.add_n(l2_losses)

        # Poly learning-rate policy driven by the step placeholder.
        base_lr = tf.constant(self.parameters.learning_rate)
        self.curr_step = tf.placeholder(dtype=tf.float32, shape=())
        learning_rate = tf.scalar_mul(
            base_lr,
            tf.pow((1 - self.curr_step / self.parameters.num_steps),
                   self.parameters.power))

        # Three optimizers: encoder at lr*1, decoder weights at lr*10,
        # decoder biases at lr*20.
        opt_encoder = tf.train.MomentumOptimizer(learning_rate,
                                                 self.parameters.momentum)
        opt_decoder_w = tf.train.MomentumOptimizer(learning_rate * 10.0,
                                                   self.parameters.momentum)
        opt_decoder_b = tf.train.MomentumOptimizer(learning_rate * 20.0,
                                                   self.parameters.momentum)

        # One gradients call over the concatenated variable list, then
        # split the result back into the three groups (order matters).
        grads = tf.gradients(
            self.reduced_loss,
            encoder_trainable + decoder_w_trainable + decoder_b_trainable)
        grads_encoder = grads[:len(encoder_trainable)]
        grads_decoder_w = grads[len(encoder_trainable):(
            len(encoder_trainable) + len(decoder_w_trainable))]
        grads_decoder_b = grads[(len(encoder_trainable) +
                                 len(decoder_w_trainable)):]

        train_op_conv = opt_encoder.apply_gradients(
            zip(grads_encoder, encoder_trainable))
        train_op_fc_w = opt_decoder_w.apply_gradients(
            zip(grads_decoder_w, decoder_w_trainable))
        train_op_fc_b = opt_decoder_b.apply_gradients(
            zip(grads_decoder_b, decoder_b_trainable))

        update_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS
        )  # for collecting moving_mean and moving_variance
        with tf.control_dependencies(update_ops):
            self.train_op = tf.group(train_op_conv, train_op_fc_w,
                                     train_op_fc_b)

        # Checkpoint saver (max_to_keep=0 keeps every checkpoint) and a
        # separate loader restricted to the restorable (encoder) set.
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=0)

        self.loader = tf.train.Saver(var_list=restore_var)

        # Upsampled argmax predictions, for visualisation only.
        raw_output_up = tf.image.resize_bilinear(raw_output, input_size)
        raw_output_up = tf.argmax(raw_output_up, axis=3)
        self.pred = tf.expand_dims(raw_output_up, dim=3)
        # Image summary.
        images_summary = tf.py_func(inv_preprocess,
                                    [self.image_batch, 2, IMG_MEAN], tf.uint8)
        labels_summary = tf.py_func(
            decode_labels, [self.label_batch, 2, self.parameters.num_classes],
            tf.uint8)
        preds_summary = tf.py_func(decode_labels,
                                   [self.pred, 2, self.parameters.num_classes],
                                   tf.uint8)
        self.total_summary = tf.summary.image(
            'images',
            tf.concat(axis=2,
                      values=[images_summary, labels_summary, preds_summary]),
            max_outputs=2)  # Concatenate row-wise.
        if not os.path.exists(self.parameters.logdir):
            os.makedirs(self.parameters.logdir)
        self.summary_writer = tf.summary.FileWriter(
            self.parameters.logdir, graph=tf.get_default_graph())