Example #1
0
def test_net():
    """Build the SKETCHNET test graph, saver and session config.

    Lists the ``*.tfrecords`` files in ``hyper_params['dbTest']`` whose names
    contain ``'test'`` and builds the test graph via ``test_procedure``.
    Then creates a ``tf.train.Saver`` and a ``tf.ConfigProto``.

    NOTE(review): this example ends right after the session config is
    created. No session is opened and none of the tensors returned by
    ``test_procedure`` are evaluated. The rest of the original function
    appears to be missing, so confirm against the full source.

    Relies on module-level ``hyper_params``, ``SKETCHNET`` and
    ``test_procedure``.
    """
    # Set logging
    test_logger = logging.getLogger('main.testing')
    test_logger.info('---Begin testing: ---')

    # Load network
    net = SKETCHNET()

    # Testing data
    data_records = [item for item in os.listdir(hyper_params['dbTest'])
                    if item.endswith('.tfrecords')]
    test_records = [os.path.join(hyper_params['dbTest'], item)
                    for item in data_records if 'test' in item]

    test_loss, test_d_loss, test_n_loss, test_ds_loss, test_r_loss, test_real_dloss, \
        test_real_nloss, test_omega_loss, test_gt_normal, test_f_normal, test_gt_depth, test_f_depth, test_gt_ds, \
        test_gt_lines, test_reg_mask, test_f_cfmap, test_gt_a, test_gt_b, test_f_a, test_f_b, test_inputList \
        = test_procedure(net, test_records)

    # Saver
    tf_saver = tf.train.Saver()

    config = tf.ConfigProto()
Example #2
0
def train_net():
    """Train SKETCHNET with periodic validation, checkpointing and summaries.

    Builds the training graph from ``*.tfrecords`` files in
    ``hyper_params['dbTrain']`` whose names contain ``'train'``. Builds the
    validation graph from files in ``hyper_params['dbEval']`` whose names
    contain ``'eval'``. Then runs ``hyper_params['maxIter']`` steps:

      * every ``exeValStep`` steps, the validation losses are averaged over
        ``num_eval_itr`` batches and written as summaries. One randomly
        chosen batch also writes its full validation summary;
      * every ``saveModelStep`` steps, a checkpoint is saved under
        ``hyper_params['outDir'] + '/savedModel'``;
      * one training step runs. Every ``dispLossStep`` steps, its summary is
        written and its loss is logged.

    Each batch is fetched from the input pipeline tensors and fed back into
    the named ``*_input:0`` placeholders.

    Relies on module-level ``hyper_params``, ``output_folder``, ``SKETCHNET``,
    ``train_procedure``, ``validation_procedure`` and ``randint``.
    """
    # Set logging
    train_logger = logging.getLogger('main.training')
    train_logger.info('---Begin training: ---')

    # Load network
    net = SKETCHNET()

    # Input placeholders, in the order the input pipelines return their
    # tensors.
    input_names = (
        'npr_input:0', 'ds_input:0', 'fLMask_input:0', 'fLInvMask_input:0',
        'gtField_input:0', 'clIMask_input:0', 'shapeMask_input:0',
        '2dMask_input:0', 'sLMask_input:0', 'curvMag_input:0',
    )

    def make_feed(real_input):
        """Map a fetched input batch onto the named input placeholders."""
        # Indexing (rather than zip) keeps the IndexError on a short batch.
        return {name: real_input[i] for i, name in enumerate(input_names)}

    # Train
    train_data_records = [
        item for item in os.listdir(hyper_params['dbTrain'])
        if item.endswith('.tfrecords')
    ]
    train_records = [
        os.path.join(hyper_params['dbTrain'], item)
        for item in train_data_records if 'train' in item
    ]
    train_summary, train_step, train_loss, train_inputList = train_procedure(
        net, train_records)

    # Validation
    val_data_records = [
        item for item in os.listdir(hyper_params['dbEval'])
        if item.endswith('.tfrecords')
    ]
    val_records = [
        os.path.join(hyper_params['dbEval'], item) for item in val_data_records
        if 'eval' in item
    ]
    # One validation pass is sized from the first eval record file only. The
    # extra iteration covers a trailing partial batch. When the count divides
    # evenly, it runs one batch more than needed.
    num_eval_samples = sum(
        1 for _ in tf.python_io.tf_record_iterator(val_records[0]))
    num_eval_itr = num_eval_samples // hyper_params['batchSize'] + 1

    val_proto, val_total_loss, val_data_loss, val_smooth_loss, val_inputList = validation_procedure(
        net, val_records)

    # Each validation loss tensor, the name of the placeholder that receives
    # its average, and its summary tag. One table drives the fetch, the
    # accumulation and the summary feed, so their orders cannot drift apart.
    val_loss_terms = [
        (val_total_loss, 'val_loss', 'Validating_TotalLoss'),
        (val_data_loss, 'val_data_loss', 'Validating_DataL1Loss'),
        (val_smooth_loss, 'val_smooth_loss', 'Validating_SmoothL1Loss'),
    ]
    val_fetches = []
    val_avg_placeholders = []
    val_summary_protos = []
    for loss_tensor, placeholder_name, summary_tag in val_loss_terms:
        placeholder = tf.placeholder(tf.float32, name=placeholder_name)
        val_fetches.append(loss_tensor)
        val_avg_placeholders.append(placeholder)
        val_summary_protos.append(tf.summary.scalar(summary_tag, placeholder))
    valid_loss_merge = tf.summary.merge(val_summary_protos)

    # Saver
    tf_saver = tf.train.Saver(max_to_keep=100)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    # config.log_device_placement = True

    with tf.Session(config=config) as sess:
        # TF summary
        train_writer = tf.summary.FileWriter(output_folder + '/train',
                                             sess.graph)

        # initialize
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # Start input enqueue threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        train_logger.info('pre-load data to fill data buffer...')

        try:
            for titr in range(hyper_params['maxIter']):
                # Validation
                if titr % hyper_params['exeValStep'] == 0:
                    idx = randint(0, num_eval_itr - 1)
                    loss_sums = [0.0] * len(val_fetches)
                    for eitr in range(num_eval_itr):
                        val_feed = make_feed(sess.run(val_inputList))
                        if eitr == idx:
                            results = sess.run([val_proto] + val_fetches,
                                               feed_dict=val_feed)
                            train_writer.add_summary(results[0], titr)
                            cur_losses = results[1:]
                        else:
                            cur_losses = sess.run(val_fetches,
                                                  feed_dict=val_feed)
                        loss_sums = [acc + cur
                                     for acc, cur in zip(loss_sums, cur_losses)]
                    loss_avgs = [acc / num_eval_itr for acc in loss_sums]

                    valid_summary = sess.run(
                        valid_loss_merge,
                        feed_dict=dict(zip(val_avg_placeholders, loss_avgs)))
                    train_writer.add_summary(valid_summary, titr)
                    train_logger.info('Validation loss at step {} is: {}'.format(
                        titr, loss_avgs[0]))

                # Save model
                if titr % hyper_params['saveModelStep'] == 0:
                    tf_saver.save(
                        sess, hyper_params['outDir'] +
                        '/savedModel/my_model{:d}.ckpt'.format(titr))
                    train_logger.info('Save model at step: {:d}'.format(titr))

                # Training
                train_feed = make_feed(sess.run(train_inputList))
                t_summary, _, t_loss = sess.run(
                    [train_summary, train_step, train_loss],
                    feed_dict=train_feed)

                # display
                if titr % hyper_params['dispLossStep'] == 0:
                    train_writer.add_summary(t_summary, titr)
                    train_logger.info('Training loss at step {} is: {}'.format(
                        titr, t_loss))
        finally:
            # Stop the input threads and flush summaries, even if a step raises.
            coord.request_stop()
            coord.join(threads)
            train_writer.close()
Example #3
0
def test_net():
    """Evaluate the latest SKETCHNET checkpoint on the test records.

    Steps:

      1. Build the test graph from the ``*.tfrecords`` files in
         ``hyper_params['dbTest']`` whose names contain ``'test'``.
      2. Restore the latest checkpoint from ``hyper_params['cktDir']``, if
         there is one.
      3. Write the graph definition as text to
         ``hyper_params['outDir']/hyper_params['graphName']``.
      4. Run test batches until the input pipeline raises ``OutOfRangeError``.

    Each batch is fetched from the pipeline and fed back into the named
    ``*_input:0`` placeholders, and the per-case losses are logged. For the
    first 200 cases, the ground-truth and predicted depth and normal maps
    are written as ``.exr`` files to ``out_img_dir``. The mean total loss is
    logged at the end.

    Relies on module-level ``hyper_params``, ``out_img_dir``, ``SKETCHNET``
    and ``test_procedure``.
    """
    # Set logging
    test_logger = logging.getLogger('main.testing')
    test_logger.info('---Begin testing: ---')

    # Load network
    net = SKETCHNET()

    # Testing data
    data_records = [item for item in os.listdir(hyper_params['dbTest'])
                    if item.endswith('.tfrecords')]
    test_records = [os.path.join(hyper_params['dbTest'], item)
                    for item in data_records if 'test' in item]

    test_loss, test_d_loss, test_n_loss, test_real_dloss, test_real_nloss, test_gt_normal, test_f_normal, \
        test_gt_depth, test_f_depth, test_gt_ds, test_gt_lines, test_inputList = test_procedure(net, test_records)

    # Input placeholders, in the order the test pipeline returns its tensors.
    input_names = ('npr_input:0', 'ds_input:0', 'fLMask_input:0',
                   'gtN_input:0', 'gtD_input:0', 'clIMask_input:0',
                   'shapeMask_input:0', 'dsMask_input:0', '2dMask_input:0',
                   'sLMask_input:0', 'curvMag_input:0')
    # Only the first test cases get image dumps.
    max_saved_cases = 200

    def write_exr(path, batch, swap_rb=False):
        """Write ``batch[0]`` to ``path`` as float32, flipped along axis 0.

        ``swap_rb`` reverses the channel order to [2, 1, 0]. This is
        presumably RGB -> BGR, the order ``cv2.imwrite`` expects.
        """
        image = batch[0, :, :, :]
        if swap_rb:
            image = image[:, :, [2, 1, 0]]
        # ``astype`` returns a new array, so its result must be used for the
        # float32 conversion to take effect.
        cv2.imwrite(path, np.flip(image, 0).astype(np.float32))

    # Saver
    tf_saver = tf.train.Saver()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    # config.log_device_placement = True

    with tf.Session(config=config) as sess:
        # initialize
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init_op)

        # Restore model
        ckpt = tf.train.latest_checkpoint(hyper_params['cktDir'])
        if ckpt:
            tf_saver.restore(sess, ckpt)
            test_logger.info('restore from the checkpoint {}'.format(ckpt))
        else:
            test_logger.warning(
                'no checkpoint found in {}; testing freshly initialized '
                'weights'.format(hyper_params['cktDir']))

        # write graph:
        tf.train.write_graph(sess.graph_def,
                             hyper_params['outDir'],
                             hyper_params['graphName'],
                             as_text=True)
        test_logger.info('save graph tp pbtxt, done')

        # Start input enqueue threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        titr = 0
        avg_loss = 0.0
        try:
            while not coord.should_stop():

                # get real input
                test_real_input = sess.run(test_inputList)
                feed = {name: test_real_input[i]
                        for i, name in enumerate(input_names)}

                t_loss, t_d_loss, t_n_loss, t_real_dloss, t_real_nloss, \
                    t_gt_normal, t_f_normal, t_gt_depth, t_f_depth, t_gt_ds, t_gt_lines = sess.run(
                        [test_loss, test_d_loss, test_n_loss, test_real_dloss, test_real_nloss, test_gt_normal,
                         test_f_normal, test_gt_depth, test_f_depth, test_gt_ds, test_gt_lines],
                        feed_dict=feed)

                # Record loss
                avg_loss += t_loss
                test_logger.info(
                    'Test case {}, loss: {}, {}, {}, {}, {}, 0.0, 0.0, 0.0, 0.0'.format(
                        titr, t_loss, t_real_dloss, t_real_nloss, t_d_loss, t_n_loss))

                # Write img out
                if titr < max_saved_cases:
                    write_exr(os.path.join(out_img_dir, 'gt_depth_' + str(titr) + '.exr'),
                              t_gt_depth)
                    write_exr(os.path.join(out_img_dir, 'fwd_depth_' + str(titr) + '.exr'),
                              t_f_depth)
                    write_exr(os.path.join(out_img_dir, 'gt_normal_' + str(titr) + '.exr'),
                              t_gt_normal, swap_rb=True)
                    write_exr(os.path.join(out_img_dir, 'fwd_normal_' + str(titr) + '.exr'),
                              t_f_normal, swap_rb=True)

                titr += 1
                if titr % 100 == 0:
                    print('Iteration: {}'.format(titr))

        except tf.errors.OutOfRangeError:
            # The input queue is exhausted; this is the normal end of testing.
            print('Test Done.')
        finally:
            coord.request_stop()

        # Report outside the ``try``. The loop normally ends by raising
        # OutOfRangeError, which would skip any statement placed after it
        # inside the ``try``.
        if titr > 0:
            avg_loss /= titr
        test_logger.info('Finish test model, average loss is: {}'.format(avg_loss))

        # Finish testing
        coord.join(threads)
def test_net():
    """Evaluate the latest SKETCHNET checkpoint on the test records.

    Steps:

      1. Build the test graph from the ``*.tfrecords`` files in
         ``hyper_params['dbTest']`` whose names contain ``'test'``.
      2. Restore the latest checkpoint from ``hyper_params['cktDir']``, if
         there is one.
      3. Write the graph definition as text to
         ``hyper_params['outDir']/hyper_params['graphName']``.
      4. Evaluate test batches until the input pipeline raises
         ``OutOfRangeError``.

    For every case, the losses are logged. The ground-truth and predicted
    depth and normal maps, plus the predicted confidence map, are written as
    ``.exr`` files to ``out_img_dir``. The mean total loss is logged at the
    end.

    Relies on module-level ``hyper_params``, ``out_img_dir``, ``SKETCHNET``
    and ``test_procedure``.
    """
    # Set logging
    test_logger = logging.getLogger('main.testing')
    test_logger.info('---Begin testing: ---')

    # Load network
    net = SKETCHNET()

    # Testing data
    data_records = [
        item for item in os.listdir(hyper_params['dbTest'])
        if item.endswith('.tfrecords')
    ]
    test_records = [
        os.path.join(hyper_params['dbTest'], item) for item in data_records
        if 'test' in item
    ]

    test_loss, test_d_loss, test_n_loss, test_ds_loss, test_r_loss, test_real_dloss, \
        test_real_nloss, test_omega_loss, test_gt_normal, test_f_normal, test_gt_depth, test_f_depth, test_gt_ds, \
        test_gt_lines, test_reg_mask, test_f_cfmap = test_procedure(net, test_records)

    def write_exr(path, batch, swap_rb=False):
        """Write ``batch[0]`` to ``path`` as float32, flipped along axis 0.

        ``swap_rb`` reverses the channel order to [2, 1, 0]. This is
        presumably RGB -> BGR, the order ``cv2.imwrite`` expects.
        """
        image = batch[0, :, :, :]
        if swap_rb:
            image = image[:, :, [2, 1, 0]]
        # ``astype`` returns a new array, so its result must be used for the
        # float32 conversion to take effect.
        cv2.imwrite(path, np.flip(image, 0).astype(np.float32))

    # Saver
    tf_saver = tf.train.Saver()

    with tf.Session() as sess:
        # initialize
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # Restore model
        ckpt = tf.train.latest_checkpoint(hyper_params['cktDir'])
        if ckpt:
            tf_saver.restore(sess, ckpt)
            test_logger.info('restore from the checkpoint {}'.format(ckpt))
        else:
            test_logger.warning(
                'no checkpoint found in {}; testing freshly initialized '
                'weights'.format(hyper_params['cktDir']))

        # writeGraph:
        tf.train.write_graph(sess.graph_def,
                             hyper_params['outDir'],
                             hyper_params['graphName'],
                             as_text=True)
        test_logger.info('save graph tp pbtxt, done')

        # Start input enqueue threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        titr = 0
        avg_loss = 0.0
        try:
            while not coord.should_stop():

                t_loss, t_d_loss, t_n_loss, t_ds_loss, t_r_loss, t_real_dloss, t_real_nloss, \
                    t_omega_loss, t_gt_normal, t_f_normal, t_gt_depth, t_f_depth, t_gt_ds, t_gt_lines, t_reg_mask, \
                    t_f_cfmap = sess.run(
                        [test_loss, test_d_loss, test_n_loss, test_ds_loss, test_r_loss,
                         test_real_dloss, test_real_nloss, test_omega_loss, test_gt_normal, test_f_normal,
                         test_gt_depth, test_f_depth, test_gt_ds, test_gt_lines, test_reg_mask, test_f_cfmap])

                # Record loss
                avg_loss += t_loss
                test_logger.info(
                    'Test case {}, loss: {}, {}, {}, {}, {}, {}, {}, 0.0, {}'.
                    format(titr, t_loss, t_real_dloss, t_real_nloss, t_d_loss,
                           t_n_loss, t_ds_loss, t_r_loss, t_omega_loss))

                # Write img out (for every test case)
                write_exr(os.path.join(out_img_dir, 'gt_depth_' + str(titr) + '.exr'),
                          t_gt_depth)
                write_exr(os.path.join(out_img_dir, 'fwd_depth_' + str(titr) + '.exr'),
                          t_f_depth)
                write_exr(os.path.join(out_img_dir, 'gt_normal_' + str(titr) + '.exr'),
                          t_gt_normal, swap_rb=True)
                write_exr(os.path.join(out_img_dir, 'fwd_normal_' + str(titr) + '.exr'),
                          t_f_normal, swap_rb=True)
                write_exr(os.path.join(out_img_dir, 'fwd_conf_map_' + str(titr) + '.exr'),
                          t_f_cfmap)

                titr += 1
                if titr % 100 == 0:
                    print('Iteration: {}'.format(titr))

        except tf.errors.OutOfRangeError:
            # The input queue is exhausted; this is the normal end of testing.
            print('Test Done.')
        finally:
            coord.request_stop()

        # Report outside the ``try``. The loop normally ends by raising
        # OutOfRangeError, which would skip any statement placed after it
        # inside the ``try``.
        if titr > 0:
            avg_loss /= titr
        test_logger.info(
            'Finish test model, average loss is: {}'.format(avg_loss))

        # Finish testing
        coord.join(threads)
Example #5
0
def train_net():
    """Train SKETCHNET with a scheduled regularization weight.

    Builds the training graph from ``*.tfrecords`` files in
    ``hyper_params['dbTrain']`` whose names contain ``'train'``. Builds the
    validation graph from files in ``hyper_params['dbEval']`` whose names
    contain ``'eval'``. Both graphs take the ``reg_weight`` placeholder.
    After initialization, the variables returned last by
    ``validation_procedure`` (the direction-field network) are restored from
    the latest checkpoint in ``hyper_params['fCktDir']``, if one exists.

    Then runs ``hyper_params['maxIter']`` steps:

      * the regularization weight starts at ``hyper_params['regWeight']`` and
        is multiplied by 3.275 every 10000 steps, capped at 2.0;
      * every ``exeValStep`` steps, all validation losses are averaged over
        ``num_eval_itr`` batches and written as summaries. One randomly
        chosen batch also writes its full validation summary;
      * every ``saveModelStep`` steps, a checkpoint is written under
        ``hyper_params['outDir'] + '/savedModel'``;
      * one training step runs. Every ``dispLossStep`` steps, its summary is
        written and its loss is logged.

    Relies on module-level ``hyper_params``, ``output_folder``, ``SKETCHNET``,
    ``train_procedure``, ``validation_procedure`` and ``randint``.
    """
    # Set logging
    train_logger = logging.getLogger('main.training')
    train_logger.info('---Begin training: ---')

    # Load network
    net = SKETCHNET()

    # regularization weight, fed on every sess.run below
    reg_weight_value = tf.placeholder(tf.float32, name='reg_weight')

    # Train
    train_data_records = [
        item for item in os.listdir(hyper_params['dbTrain'])
        if item.endswith('.tfrecords')
    ]
    train_records = [
        os.path.join(hyper_params['dbTrain'], item)
        for item in train_data_records if 'train' in item
    ]
    train_summary, train_step, train_loss = train_procedure(
        net, train_records, reg_weight_value)

    # Validation
    val_data_records = [
        item for item in os.listdir(hyper_params['dbEval'])
        if item.endswith('.tfrecords')
    ]
    val_records = [
        os.path.join(hyper_params['dbEval'], item) for item in val_data_records
        if 'eval' in item
    ]
    # One validation pass is sized from the first eval record file only. The
    # extra iteration covers a trailing partial batch. When the count divides
    # evenly, it runs one batch more than needed.
    num_eval_samples = sum(
        1 for _ in tf.python_io.tf_record_iterator(val_records[0]))
    num_eval_itr = num_eval_samples // hyper_params['batchSize'] + 1

    val_proto, val_loss, val_d_loss, val_n_loss, val_ds_loss, val_reg_loss, val_real_dloss, \
        val_real_nloss, val_omega_loss, field_vars = validation_procedure(
            net, val_records, reg_weight_value)

    # Each validation loss tensor, the name of the placeholder that receives
    # its average, and its summary tag. One table drives the fetch, the
    # accumulation and the summary feed, so their orders cannot drift apart.
    val_loss_terms = [
        (val_loss, 'val_loss', 'Validating_TotalLoss'),
        (val_d_loss, 'val_d_loss', 'Validating_DepthL2Loss'),
        (val_n_loss, 'val_n_loss', 'Validating_NormalL2Loss'),
        (val_ds_loss, 'val_ds_loss', 'Validating_DepthSampleL2Loss'),
        (val_reg_loss, 'val_reg_loss', 'Validating_RegL2Loss'),
        (val_real_dloss, 'val_real_dloss', 'Validating_RealDLoss'),
        (val_real_nloss, 'val_real_nloss', 'Validating_RealNLoss'),
        (val_omega_loss, 'val_omega_loss', 'Validating_OmegaLoss'),
    ]
    val_fetches = []
    val_avg_placeholders = []
    val_summary_protos = []
    for loss_tensor, placeholder_name, summary_tag in val_loss_terms:
        placeholder = tf.placeholder(tf.float32, name=placeholder_name)
        val_fetches.append(loss_tensor)
        val_avg_placeholders.append(placeholder)
        val_summary_protos.append(tf.summary.scalar(summary_tag, placeholder))
    valid_value_merge = tf.summary.merge(val_summary_protos)

    # Saver
    tf_saver = tf.train.Saver(max_to_keep=100)
    tf_field_saver = tf.train.Saver(var_list=field_vars)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        # TF summary
        train_writer = tf.summary.FileWriter(output_folder + '/train',
                                             sess.graph)

        # initialize
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # load direction field network checkpoint
        f_ckpt = tf.train.latest_checkpoint(hyper_params['fCktDir'])
        if f_ckpt:
            tf_field_saver.restore(sess, f_ckpt)
            train_logger.info('restore from the checkpoint {}'.format(f_ckpt))
        else:
            train_logger.warning(
                'no field checkpoint found in {}; field variables keep their '
                'initial values'.format(hyper_params['fCktDir']))

        # Start input enqueue threads
        train_logger.info('pre-load data into data buffer...')
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # reg_weight init
        cur_weight = hyper_params['regWeight']
        try:
            for titr in range(hyper_params['maxIter']):

                # Regularization weight schedule for 2 GPUs: multiply by 3.275
                # every 10000 steps, capped at 2.0. Previously used settings:
                # 4 GPUs multiply by 3.985 every 5000 steps; 1 GPU multiplies
                # by 1.738 every 10000 steps (both also capped at 2.0).
                if titr % 10000 == 0 and titr > 0:
                    cur_weight = min(cur_weight * 3.275, 2.0)
                reg_feed = {reg_weight_value: cur_weight}

                # Validation
                if titr % hyper_params['exeValStep'] == 0:
                    idx = randint(0, num_eval_itr - 1)
                    loss_sums = [0.0] * len(val_fetches)
                    for eitr in range(num_eval_itr):
                        if eitr == idx:
                            results = sess.run([val_proto] + val_fetches,
                                               feed_dict=reg_feed)
                            train_writer.add_summary(results[0], titr)
                            cur_losses = results[1:]
                        else:
                            cur_losses = sess.run(val_fetches,
                                                  feed_dict=reg_feed)
                        loss_sums = [acc + cur
                                     for acc, cur in zip(loss_sums, cur_losses)]
                    loss_avgs = [acc / num_eval_itr for acc in loss_sums]

                    valid_summary = sess.run(
                        valid_value_merge,
                        feed_dict=dict(zip(val_avg_placeholders, loss_avgs)))
                    train_writer.add_summary(valid_summary, titr)
                    train_logger.info('Validation loss at step {} is: {}'.format(
                        titr, loss_avgs[0]))

                # Save model
                if titr % hyper_params['saveModelStep'] == 0:
                    tf_saver.save(
                        sess, hyper_params['outDir'] +
                        '/savedModel/my_model{:d}.ckpt'.format(titr))
                    train_logger.info(
                        'Save model at step: {:d}, reg weight is: {}'.format(
                            titr, cur_weight))

                # Training
                t_summary, _, t_loss = sess.run(
                    [train_summary, train_step, train_loss],
                    feed_dict=reg_feed)

                # Display
                if titr % hyper_params['dispLossStep'] == 0:
                    train_writer.add_summary(t_summary, titr)
                    train_logger.info('Training loss at step {} is: {}'.format(
                        titr, t_loss))
        finally:
            # Stop the input threads and flush summaries, even if a step raises.
            coord.request_stop()
            coord.join(threads)
            train_writer.close()