Example #1
def test_procedure(net, test_records):
    # Load data
    reader = SketchReader(tfrecord_list=test_records,
                          raw_size=[256, 256, 25],
                          shuffle=False,
                          num_threads=hyper_params['nbThreads'],
                          batch_size=1,
                          nb_epoch=1)
    raw_input = reader.next_batch()
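
    # Unpack the raw batch into the individual input tensors used below.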

    npr_lines, ds, _, fm, fm_inv, gt_normal, gt_depth, _, mask_cline_inv, mask_shape, mask_ds, mask_line, \
    mask_line_inv, mask_2d, sel_mask, vdotn_scalar = net.cook_raw_inputs(raw_input)

    # Network forward
    logit_d, logit_n, logit_c, _ = net.load_baseline_net(
        npr_lines,
        ds,
        mask_2d,
        fm,
        sel_mask,
        vdotn_scalar,
        hyper_params['rootFt'],
        is_training=False)

    # Test loss
    test_loss, test_d_loss, test_n_loss, test_ds_loss, test_reg_loss, test_real_dloss, \
    test_real_nloss, test_omega_loss, out_gt_normal, out_f_normal, out_gt_depth, out_f_depth, out_gt_ds, gt_lines, \
    reg_mask, out_cf_map = loss(logit_d, logit_n, logit_c, gt_normal, gt_depth, mask_shape, mask_ds, mask_cline_inv,
                                ds, npr_lines, fm_inv)

    return test_loss, test_d_loss, test_n_loss, test_ds_loss, test_reg_loss, test_real_dloss, \
           test_real_nloss, test_omega_loss, out_gt_normal, out_f_normal, out_gt_depth, \
           out_f_depth, out_gt_ds, gt_lines, reg_mask, out_cf_map
Example #2
def validation_procedure(net, val_records, reg_weight):
    # Load data
    with tf.name_scope('eval_inputs') as _:
        reader = SketchReader(tfrecord_list=val_records,
                              raw_size=[256, 256, 25],
                              shuffle=False,
                              num_threads=hyper_params['nbThreads'],
                              batch_size=hyper_params['batchSize'])
        raw_input = reader.next_batch()

        npr_lines, ds, _, fm, fm_inv, gt_normal, gt_depth, gt_field, mask_cline_inv, mask_shape, mask_ds, _, \
        _, mask2d, selm, ndotv = net.cook_raw_inputs(raw_input)

    # Network forward
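    # The field net runs first and its prediction feeds GeomNet; reuse=True makes
    # both networks share variables that were already created elsewhere in the graph.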
    logit_f, field_vars = net.load_field_net(npr_lines,
                                             mask2d,
                                             ds,
                                             fm,
                                             selm,
                                             ndotv,
                                             hyper_params['rootFt'],
                                             is_training=False,
                                             reuse=True)

    logit_d, logit_n, logit_c, _ = net.load_GeomNet(npr_lines,
                                                    ds,
                                                    mask2d,
                                                    fm,
                                                    selm,
                                                    ndotv,
                                                    logit_f,
                                                    mask_cline_inv,
                                                    hyper_params['rootFt'],
                                                    is_training=False,
                                                    reuse=True)

    # Validate loss
    val_loss, val_d_loss, val_n_loss, val_ds_loss, val_reg_loss, val_real_dloss, val_real_nloss, \
    val_omega_loss = loss(logit_d,
                          logit_n,
                          logit_c,
                          gt_normal,
                          gt_depth,
                          mask_shape,
                          mask_ds,
                          mask_cline_inv,
                          reg_weight,
                          fm_inv,
                          scope='test_loss')

    # TensorBoard
    proto_list = collect_vis_img_val(logit_d, logit_n, logit_c, logit_f,
                                     npr_lines, gt_normal, gt_depth,
                                     mask_shape, ds, gt_field, mask_cline_inv,
                                     fm, fm_inv, selm, ndotv)
    merged_val = tf.summary.merge(proto_list)

    return merged_val, val_loss, val_d_loss, val_n_loss, val_ds_loss, val_reg_loss, val_real_dloss, \
           val_real_nloss, val_omega_loss, field_vars
Example #3
def validation_procedure(net, val_records, reg_weight):
    # Load data
    with tf.name_scope('eval_inputs') as _:
        reader = SketchReader(tfrecord_list=val_records, raw_size=[256, 256, 25], shuffle=False,
                              num_threads=hyper_params['nbThreads'], batch_size=hyper_params['batchSize'])
        raw_input = reader.next_batch()

        npr_lines, ds, _, fm, fm_inv, gt_normal, gt_depth, _, mask_cline_inv, mask_shape, mask_ds, _, \
        _, mask_2d, sel_mask, vdotn_scalar = net.cook_raw_inputs(raw_input)

    # Network forward
    logit_d, logit_n, logit_c, _ = net.load_baseline_net(npr_lines,
                                                          ds,
                                                          mask_2d,
                                                          fm,
                                                          sel_mask,
                                                          vdotn_scalar,
                                                          hyper_params['rootFt'],
                                                          is_training=False,
                                                          reuse=True)

    # Validate loss
    val_loss, val_d_loss, val_n_loss, val_ds_loss, val_reg_loss, val_real_dloss, val_real_nloss, \
    val_omega_loss = loss(logit_d,
                          logit_n,
                          logit_c,
                          gt_normal,
                          gt_depth,
                          mask_shape,
                          mask_ds,
                          mask_cline_inv,
                          reg_weight,
                          fm_inv,
                          scope='test_loss')

    # TensorBoard
    proto_list = collect_vis_img_val(logit_d,
                                     logit_n,
                                     logit_c,
                                     npr_lines,
                                     gt_normal,
                                     gt_depth,
                                     mask_shape,
                                     ds,
                                     mask_cline_inv,
                                     fm,
                                     fm_inv,
                                     sel_mask,
                                     vdotn_scalar,
                                     mask_2d)
    merged_val = tf.summary.merge(proto_list)

    return merged_val, val_loss, val_d_loss, val_n_loss, val_ds_loss, val_reg_loss, val_real_dloss, \
           val_real_nloss, val_omega_loss, [npr_lines, ds, fm, fm_inv, gt_normal, gt_depth, mask_cline_inv,
                                            mask_shape, mask_ds, mask_2d, sel_mask, vdotn_scalar]
Example #4
def test_procedure(net, test_records):
    # Load data
    reader = SketchReader(tfrecord_list=test_records, raw_size=[256, 256, 25],
                          shuffle=False, num_threads=hyper_params['nbThreads'],
                          batch_size=1, nb_epoch=1)
    raw_input = reader.next_batch()

    npr_lines, ds, _, fm, fm_inv, gt_normal, gt_depth, gt_field, mask_cline_inv, mask_shape, mask_ds, _, _, \
    mask2d, selm, ndotv = net.cook_raw_inputs(raw_input)

    # Network forward
    logit_f, _ = net.load_field_net(npr_lines,
                                    mask2d,
                                    ds,
                                    fm,
                                    selm,
                                    ndotv,
                                    hyper_params['rootFt'],
                                    is_training=False)

    logit_d, logit_n, logit_c, _ = net.load_GeomNet(npr_lines,
                                                    ds,
                                                    mask2d,
                                                    fm,
                                                    selm,
                                                    ndotv,
                                                    logit_f,
                                                    mask_cline_inv,
                                                    hyper_params['rootFt'],
                                                    is_training=False)

    # Test loss
    test_loss, test_d_loss, test_n_loss, test_ds_loss, test_reg_loss, test_real_dloss, \
    test_real_nloss, test_omega_loss, out_gt_normal, out_f_normal, out_gt_depth, out_f_depth, out_gt_ds, gt_lines, \
    reg_mask, out_cf_map, test_gt_a, test_gt_b, test_f_a, test_f_b \
        = loss(logit_d,
               logit_n,
               logit_c,
               gt_normal,
               gt_depth,
               mask_shape,
               mask_ds,
               mask_cline_inv,
               ds,
               npr_lines,
               logit_f,
               gt_field,
               fm_inv)

    return test_loss, test_d_loss, test_n_loss, test_ds_loss, test_reg_loss, test_real_dloss, \
           test_real_nloss, test_omega_loss, out_gt_normal, out_f_normal, out_gt_depth, \
           out_f_depth, out_gt_ds, gt_lines, reg_mask, out_cf_map, test_gt_a, test_gt_b, test_f_a, test_f_b, \
           [npr_lines, ds, fm, fm_inv, gt_normal, gt_depth, gt_field,
            mask_cline_inv, mask_shape, mask_ds, mask2d, selm, ndotv]
Example #5
def validation_procedure(net, val_records, reg_weight):
    # Load data
    with tf.name_scope('eval_inputs') as _:
        reader = SketchReader(tfrecord_list=val_records, raw_size=[256, 256, 25], shuffle=False,
                              num_threads=hyper_params['nbThreads'], batch_size=hyper_params['batchSize'])
        raw_input = reader.next_batch()

        npr_lines, ds, _, fm, fm_inv, gt_normal, gt_depth, _, mask_cline_inv, mask_shape, mask_ds, _, _, \
        mask_2d, sel_mask, vdotn_scalar = net.cook_raw_inputs(raw_input)

    # Network forward
    logit_d, logit_n, _ = net.load_dn_naive_net(npr_lines,
                                                ds,
                                                mask_2d,
                                                fm,
                                                sel_mask,
                                                vdotn_scalar,
                                                hyper_params['rootFt'],
                                                is_training=False,
                                                reuse=True)

    # Validate loss
    val_loss, val_d_loss, val_n_loss, val_real_dloss, val_real_nloss, val_r_loss, val_ds_loss \
        = loss(logit_d,
               logit_n,
               gt_normal,
               gt_depth,
               mask_shape,
               mask_ds,
               mask_cline_inv,
               reg_weight,
               fm_inv,
               mode=hyper_params['lossId'],
               scope='test_loss')

    # TensorBoard
    proto_list = collect_vis_img_val(logit_d,
                                     logit_n,
                                     npr_lines,
                                     gt_normal,
                                     gt_depth,
                                     mask_shape,
                                     ds,
                                     mask_cline_inv,
                                     fm,
                                     fm_inv,
                                     sel_mask,
                                     vdotn_scalar,
                                     mask_2d)
    merged_val = tf.summary.merge(proto_list)

    return merged_val, val_loss, val_d_loss, val_n_loss, val_real_dloss, val_real_nloss, val_r_loss, val_ds_loss
Example #6
def validation_procedure(net, val_records):
    # Load data
    with tf.name_scope('eval_inputs') as _:
        reader = SketchReader(tfrecord_list=val_records,
                              raw_size=[256, 256, 25],
                              shuffle=False,
                              num_threads=hyper_params['nbThreads'],
                              batch_size=hyper_params['batchSize'])
        raw_input = reader.next_batch()

        npr_lines, ds, _, fm, fm_inv, _, _, gt_field, cline_mask_inv, mask_shape, _, _, _, mask2d, selm, vdotn \
            = net.cook_raw_inputs(raw_input)

    # Network forward
    logit_f, _ = net.load_field_net(npr_lines,
                                    mask2d,
                                    ds,
                                    fm,
                                    selm,
                                    vdotn,
                                    hyper_params['rootFt'],
                                    is_training=False,
                                    reuse=True)

    # Validate loss
    val_loss, val_data_loss, val_smooth_loss = loss(logit_f,
                                                    gt_field,
                                                    mask_shape,
                                                    cline_mask_inv,
                                                    fm_inv,
                                                    scope='test_loss')

    # TensorBoard
    proto_list = collect_vis_img_val(logit_f, npr_lines, gt_field,
                                     mask_shape, cline_mask_inv,
                                     ds, fm, selm,
                                     vdotn, fm_inv)

    merged_val = tf.summary.merge(proto_list)

    return merged_val, val_loss, val_data_loss, val_smooth_loss, \
           [npr_lines, ds, fm, fm_inv, gt_field, cline_mask_inv, mask_shape, mask2d, selm, vdotn]
Example #7
def test_procedure(net, test_records):
    # Load data
    reader = SketchReader(tfrecord_list=test_records,
                          raw_size=[256, 256, 25],
                          shuffle=False,
                          num_threads=hyper_params['nbThreads'],
                          batch_size=1,
                          nb_epoch=1)
    raw_input = reader.next_batch()

    npr_lines, ds, _, fm, _, gt_normal, gt_depth, _, mask_cline_inv, mask_shape, mask_ds, _, _, mask_2d, \
    sel_m, vdotn = net.cook_raw_inputs(raw_input)

    # Network forward
    logit_d, logit_n, _ = net.load_dn_naive_net(npr_lines,
                                                ds,
                                                mask_2d,
                                                fm,
                                                sel_m,
                                                vdotn,
                                                hyper_params['rootFt'],
                                                is_training=False)

    # Test loss
    test_loss, test_d_loss, test_n_loss, test_real_dloss, test_real_nloss, out_gt_normal, out_f_normal, out_gt_depth, \
    out_f_depth, out_gt_ds, gt_lines = loss(logit_d,
                                            logit_n,
                                            gt_normal,
                                            gt_depth,
                                            mask_shape,
                                            mask_ds,
                                            mask_cline_inv,
                                            ds,
                                            npr_lines,
                                            mode=hyper_params['lossId'])

    return test_loss, test_d_loss, test_n_loss, test_real_dloss, test_real_nloss, out_gt_normal, out_f_normal, \
           out_gt_depth, out_f_depth, out_gt_ds, gt_lines
Example #8
def train_procedure(net, train_records):
    nb_gpus = hyper_params['nb_gpus']

    # Load data
    with tf.name_scope('train_inputs') as _:
        bSize = hyper_params['batchSize'] * nb_gpus
        nbThreads = hyper_params['nbThreads'] * nb_gpus
        reader = SketchReader(tfrecord_list=train_records,
                              raw_size=[256, 256, 25],
                              shuffle=True,
                              num_threads=nbThreads,
                              batch_size=bSize)
        raw_input = reader.next_batch()

        npr_lines, ds, _, fm, fm_inv, _, _, gt_field, cline_mask_inv, mask_shape, _, _, _, mask_2d, selm, vdotn \
            = net.cook_raw_inputs(raw_input)

    # initialize optimizer
    opt = tf.train.AdamOptimizer()

    # split data
    with tf.name_scope('divide_data'):
        gpu_npr_lines = tf.split(npr_lines, nb_gpus, axis=0)
        gpu_mask2d = tf.split(mask_2d, nb_gpus, axis=0)
        gpu_ds = tf.split(ds, nb_gpus, axis=0)
        gpu_fm = tf.split(fm, nb_gpus, axis=0)
        gpu_fm_inv = tf.split(fm_inv, nb_gpus, axis=0)
        gpu_selm = tf.split(selm, nb_gpus, axis=0)
        gpu_vdotn = tf.split(vdotn, nb_gpus, axis=0)
        gpu_gt_field = tf.split(gt_field, nb_gpus, axis=0)
        gpu_mask_shape = tf.split(mask_shape, nb_gpus, axis=0)
        gpu_mask_cline = tf.split(cline_mask_inv, nb_gpus, axis=0)

    tower_grads = []
    tower_loss_collected = []
    tower_total_losses = []
    tower_data_losses = []
    tower_smooth_losses = []

    # TensorBoard: images
    gpu0_npr_lines_imgs = None
    gpu0_logit_f_imgs = None
    gpu0_gt_f_imgs = None
    gpu0_shape_mask_imgs = None
    gpu0_shape_cline_imgs = None
    gpu0_ds_imgs = None
    gpu0_fm_imgs = None
    gpu0_fm_inv_imgs = None
    gpu0_selm_imgs = None
    gpu0_vdotn_imgs = None
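
    # Build one tower per GPU: each tower runs the forward pass and loss on its
    # slice of the batch; variables are shared and gradients are averaged later.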

    with tf.variable_scope(tf.get_variable_scope()):
        for gpu_id in range(nb_gpus):
            with tf.device('/gpu:%d' % gpu_id):
                with tf.name_scope('tower_%s' % gpu_id) as _:
                    # Network forward
                    logit_f, _ = net.load_field_net(gpu_npr_lines[gpu_id],
                                                    gpu_mask2d[gpu_id],
                                                    gpu_ds[gpu_id],
                                                    gpu_fm[gpu_id],
                                                    gpu_selm[gpu_id],
                                                    gpu_vdotn[gpu_id],
                                                    hyper_params['rootFt'],
                                                    is_training=True)

                    # Training loss
                    train_loss, train_data_loss, train_smooth_loss = loss(
                        logit_f,
                        gpu_gt_field[gpu_id],
                        gpu_mask_shape[gpu_id],
                        gpu_mask_cline[gpu_id],
                        gpu_fm_inv[gpu_id],
                        scope='train_loss')

                    # reuse variables
                    tf.get_variable_scope().reuse_variables()

                    # collect gradients and every loss
                    tower_grads.append(opt.compute_gradients(train_loss))
                    tower_total_losses.append(train_loss)
                    tower_data_losses.append(train_data_loss)
                    tower_smooth_losses.append(train_smooth_loss)

                    # TensorBoard: collect images from GPU 0
                    if gpu_id == 0:
                        gpu0_npr_lines_imgs = gpu_npr_lines[gpu_id]
                        gpu0_logit_f_imgs = logit_f
                        gpu0_gt_f_imgs = gpu_gt_field[gpu_id]
                        gpu0_shape_mask_imgs = gpu_mask_shape[gpu_id]
                        gpu0_shape_cline_imgs = gpu_mask_cline[gpu_id]
                        gpu0_ds_imgs = gpu_ds[gpu_id]
                        gpu0_fm_imgs = gpu_fm[gpu_id]
                        gpu0_fm_inv_imgs = gpu_fm_inv[gpu_id]
                        gpu0_selm_imgs = gpu_selm[gpu_id]
                        gpu0_vdotn_imgs = gpu_vdotn[gpu_id]

        tower_loss_collected.append(tower_total_losses)
        tower_loss_collected.append(tower_data_losses)
        tower_loss_collected.append(tower_smooth_losses)

    # Solver
    with tf.name_scope('solve') as _:
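        # Run any pending update ops (e.g. batch-norm statistics) before applying
        # the tower-averaged gradients.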
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            grads = average_gradient(tower_grads)
            averaged_losses = average_losses(tower_loss_collected)
            apply_gradient_op = opt.apply_gradients(grads)
            train_op = tf.group(apply_gradient_op)

    # TensorBoard: visualization
    train_diff_proto = tf.summary.scalar('Training_TotalLoss',
                                         averaged_losses[0])
    train_data_loss_proto = tf.summary.scalar('Training_DataL1Loss',
                                              averaged_losses[1])
    train_smooth_loss_proto = tf.summary.scalar('Training_SmoothL1Loss',
                                                averaged_losses[2])

    proto_list = collect_vis_img(gpu0_logit_f_imgs, gpu0_npr_lines_imgs,
                                 gpu0_gt_f_imgs, gpu0_shape_mask_imgs,
                                 gpu0_shape_cline_imgs, gpu0_ds_imgs,
                                 gpu0_fm_imgs, gpu0_selm_imgs, gpu0_vdotn_imgs,
                                 gpu0_fm_inv_imgs)

    proto_list.append(train_diff_proto)
    proto_list.append(train_data_loss_proto)
    proto_list.append(train_smooth_loss_proto)
    merged_train = tf.summary.merge(proto_list)

    return merged_train, train_op, averaged_losses[0], \
           [npr_lines, ds, fm, fm_inv, gt_field, cline_mask_inv, mask_shape, mask_2d, selm, vdotn]
Example #9
def train_procedure(net, train_records, reg_weight):
    nb_gpus = hyper_params['nb_gpus']

    # Load data
    with tf.name_scope('train_inputs') as _:
        bSize = hyper_params['batchSize'] * nb_gpus
        nbThreads = hyper_params['nbThreads'] * nb_gpus
        reader = SketchReader(tfrecord_list=train_records,
                              raw_size=[256, 256, 25],
                              shuffle=True,
                              num_threads=nbThreads,
                              batch_size=bSize)
        raw_input = reader.next_batch()

        npr_lines, ds, _, fm, fm_inv, gt_normal, gt_depth, gt_field, mask_cline_inv, mask_shape, mask_ds, _, \
        _, mask2d, selm, ndotv = net.cook_raw_inputs(raw_input)

    # initialize optimizer
    opt = tf.train.AdamOptimizer()

    # split data
    with tf.name_scope('divide_data'):
        gpu_npr_lines = tf.split(npr_lines, nb_gpus, axis=0)
        gpu_ds = tf.split(ds, nb_gpus, axis=0)
        gpu_fm = tf.split(fm, nb_gpus, axis=0)
        gpu_fm_inv = tf.split(fm_inv, nb_gpus, axis=0)
        gpu_selm = tf.split(selm, nb_gpus, axis=0)
        gpu_ndotv = tf.split(ndotv, nb_gpus, axis=0)
        gpu_gt_normal = tf.split(gt_normal, nb_gpus, axis=0)
        gpu_gt_depth = tf.split(gt_depth, nb_gpus, axis=0)
        gpu_gt_field = tf.split(gt_field, nb_gpus, axis=0)
        gpu_mask_shape = tf.split(mask_shape, nb_gpus, axis=0)
        gpu_mask_ds = tf.split(mask_ds, nb_gpus, axis=0)
        gpu_cline_inv = tf.split(mask_cline_inv, nb_gpus, axis=0)
        gpu_mask2d = tf.split(mask2d, nb_gpus, axis=0)

    tower_grads = []
    tower_loss_collected = []
    tower_total_losses = []
    tower_d_losses = []
    tower_n_losses = []
    tower_ds_losses = []
    tower_reg_losses = []
    tower_abs_d_losses = []
    tower_abs_n_losses = []
    tower_omega_losses = []

    # TensorBoard: images
    gpu0_npr_lines_imgs = None
    gpu0_gt_ds_imgs = None
    gpu0_logit_d_imgs = None
    gpu0_logit_n_imgs = None
    gpu0_logit_c_imgs = None
    gpu0_logit_f_imgs = None
    gpu0_gt_normal_imgs = None
    gpu0_gt_depth_imgs = None
    gpu0_shape_mask_imgs = None
    gpu0_gt_field_imgs = None
    gpu0_mask_cline_inv_imgs = None
    gpu0_fm_imgs = None
    gpu0_fm_inv_imgs = None
    gpu0_selm_imgs = None
    gpu0_ndotv_imgs = None
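
    # One tower per GPU: each runs the frozen field net followed by GeomNet on
    # its slice of the batch; per-tower gradients are averaged in the solver below.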

    with tf.variable_scope(tf.get_variable_scope()):
        for gpu_id in range(nb_gpus):
            with tf.device('/gpu:%d' % gpu_id):
                with tf.name_scope('tower_%s' % gpu_id) as _:
                    # Network forward
                    logit_f, _ = net.load_field_net(gpu_npr_lines[gpu_id],
                                                    gpu_mask2d[gpu_id],
                                                    gpu_ds[gpu_id],
                                                    gpu_fm[gpu_id],
                                                    gpu_selm[gpu_id],
                                                    gpu_ndotv[gpu_id],
                                                    hyper_params['rootFt'],
                                                    is_training=False)

                    logit_d, logit_n, logit_c, _ = net.load_GeomNet(
                        gpu_npr_lines[gpu_id],
                        gpu_ds[gpu_id],
                        gpu_mask2d[gpu_id],
                        gpu_fm[gpu_id],
                        gpu_selm[gpu_id],
                        gpu_ndotv[gpu_id],
                        logit_f,
                        gpu_cline_inv[gpu_id],
                        hyper_params['rootFt'],
                        is_training=True)

                    # Training loss
                    train_loss, train_d_loss, train_n_loss, train_ds_loss, train_reg_loss, train_real_dloss, \
                    train_real_nloss, train_omega_loss = loss(logit_d,
                                                              logit_n,
                                                              logit_c,
                                                              gpu_gt_normal[gpu_id],
                                                              gpu_gt_depth[gpu_id],
                                                              gpu_mask_shape[gpu_id],
                                                              gpu_mask_ds[gpu_id],
                                                              gpu_cline_inv[gpu_id],
                                                              reg_weight,
                                                              gpu_fm_inv[gpu_id],
                                                              scope='train_loss')

                    # reuse variables
                    tf.get_variable_scope().reuse_variables()

                    # collect gradients and every loss
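                    # Only GeomNet variables (scope 'SASMFGeoNet') are optimized here;
                    # the field net was built with is_training=False and stays frozen.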
                    cur_trainable_vars = tf.get_collection(
                        tf.GraphKeys.TRAINABLE_VARIABLES, scope='SASMFGeoNet')
                    tower_grads.append(
                        opt.compute_gradients(train_loss,
                                              var_list=cur_trainable_vars))
                    tower_total_losses.append(train_loss)
                    tower_d_losses.append(train_d_loss)
                    tower_n_losses.append(train_n_loss)
                    tower_ds_losses.append(train_ds_loss)
                    tower_reg_losses.append(train_reg_loss)
                    tower_abs_d_losses.append(train_real_dloss)
                    tower_abs_n_losses.append(train_real_nloss)
                    tower_omega_losses.append(train_omega_loss)

                    # TensorBoard: collect images from GPU 0
                    if gpu_id == 0:
                        gpu0_npr_lines_imgs = gpu_npr_lines[gpu_id]
                        gpu0_gt_ds_imgs = gpu_ds[gpu_id]
                        gpu0_logit_d_imgs = logit_d
                        gpu0_logit_n_imgs = logit_n
                        gpu0_logit_c_imgs = logit_c
                        gpu0_logit_f_imgs = logit_f
                        gpu0_gt_normal_imgs = gpu_gt_normal[gpu_id]
                        gpu0_gt_depth_imgs = gpu_gt_depth[gpu_id]
                        gpu0_shape_mask_imgs = gpu_mask_shape[gpu_id]
                        gpu0_gt_field_imgs = gpu_gt_field[gpu_id]
                        gpu0_mask_cline_inv_imgs = gpu_cline_inv[gpu_id]
                        gpu0_fm_imgs = gpu_fm[gpu_id]
                        gpu0_fm_inv_imgs = gpu_fm_inv[gpu_id]
                        gpu0_selm_imgs = gpu_selm[gpu_id]
                        gpu0_ndotv_imgs = gpu_ndotv[gpu_id]

        tower_loss_collected.append(tower_total_losses)
        tower_loss_collected.append(tower_d_losses)
        tower_loss_collected.append(tower_n_losses)
        tower_loss_collected.append(tower_ds_losses)
        tower_loss_collected.append(tower_reg_losses)
        tower_loss_collected.append(tower_abs_d_losses)
        tower_loss_collected.append(tower_abs_n_losses)
        tower_loss_collected.append(tower_omega_losses)

    # Solve
    with tf.name_scope('solve') as _:
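        # Restrict update ops to the GeomNet scope as well, matching the
        # trainable-variable filter used when computing gradients.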
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       scope='.*SASMFGeoNet.*')
        with tf.control_dependencies(update_ops):
            grads = average_gradient(tower_grads)
            averaged_losses = average_losses(tower_loss_collected)
            apply_gradient_op = opt.apply_gradients(grads)
            train_op = tf.group(apply_gradient_op)

    # TensorBoard: visualization
    train_diff_proto = tf.summary.scalar('Training_TotalLoss',
                                         averaged_losses[0])
    train_diff_d_proto = tf.summary.scalar('Training_DepthL2Loss',
                                           averaged_losses[1])
    train_diff_n_proto = tf.summary.scalar('Training_NormalL2Loss',
                                           averaged_losses[2])
    train_diff_ds_proto = tf.summary.scalar('Training_DepthSampleL2Loss',
                                            averaged_losses[3])
    train_diff_reg_proto = tf.summary.scalar('Training_RegL2Loss',
                                             averaged_losses[4])
    train_diff_reald_proto = tf.summary.scalar('Training_RealDLoss',
                                               averaged_losses[5])
    train_diff_realn_proto = tf.summary.scalar('Training_RealNLoss',
                                               averaged_losses[6])
    train_diff_omega_proto = tf.summary.scalar('Training_OmegaLoss',
                                               averaged_losses[7])

    proto_list = collect_vis_img(
        gpu0_logit_d_imgs, gpu0_logit_n_imgs, gpu0_logit_c_imgs,
        gpu0_logit_f_imgs, gpu0_npr_lines_imgs, gpu0_gt_normal_imgs,
        gpu0_gt_depth_imgs, gpu0_shape_mask_imgs, gpu0_gt_ds_imgs,
        gpu0_gt_field_imgs, gpu0_mask_cline_inv_imgs, gpu0_fm_imgs,
        gpu0_fm_inv_imgs, gpu0_selm_imgs, gpu0_ndotv_imgs)

    proto_list.append(train_diff_proto)
    proto_list.append(train_diff_d_proto)
    proto_list.append(train_diff_n_proto)
    proto_list.append(train_diff_ds_proto)
    proto_list.append(train_diff_reg_proto)
    proto_list.append(train_diff_reald_proto)
    proto_list.append(train_diff_realn_proto)
    proto_list.append(train_diff_omega_proto)
    merged_train = tf.summary.merge(proto_list)

    return merged_train, train_op, averaged_losses[0]