Ejemplo n.º 1
0
def train(input_tfr_pool, val_tfr_pool, out_dir, log_dir, mean, sbatch, wd):
    """Train Multi-View Network for a number of steps.

    Builds a TF1 graph around ``sk_net.infer_os`` trained with
    ``sk_net.L2_loss_os`` on the (2D, 3D, occlusion) targets produced by
    ``create_bb_pip(input_tfr_pool, ...)``, optionally adds one validation
    branch per file in ``val_tfr_pool``, then runs ``total_iters`` training
    iterations with periodic logging, validation and checkpointing.

    Args:
        input_tfr_pool: training TFRecord input(s), passed to ``create_bb_pip``.
        val_tfr_pool: iterable of validation TFRecord files; empty or ``None``
            disables validation.
        out_dir: directory for checkpoints (file prefix ``single_key``).
        log_dir: directory for the ``tf.summary.FileWriter`` event files.
        mean: forwarded to ``create_bb_pip`` (presumably an input mean used
            for normalization -- TODO confirm).
        sbatch: batch size of both the training and validation pipelines.
        wd: weight decay forwarded to ``sk_net.L2_loss_os``.
    """
    log_freq = 100
    val_freq = 1000
    model_save_freq = 10000
    tf.logging.set_verbosity(tf.logging.ERROR)

    # total number of training iterations
    total_iters = 140001
    lrs = [0.01, 0.001, 0.0001]
    # presumably the length of each learning-rate stage in `lrs`;
    # confirm against `optimizer`
    steps = [
        int(total_iters * 0.5),
        int(total_iters * 0.4),
        int(total_iters * 0.1)
    ]

    config = tf.ConfigProto(log_device_placement=False)
    with tf.Graph().as_default():
        sys.stderr.write("Building Network ... \n")
        global_step = tf.contrib.framework.get_or_create_global_step()

        images, gt_2d, gt_3d, gt_occ = create_bb_pip(input_tfr_pool,
                                                     1000,
                                                     sbatch,
                                                     mean,
                                                     shuffle=True)

        # inference model
        k2d_dim = gt_2d.get_shape().as_list()[1]
        k3d_dim = gt_3d.get_shape().as_list()[1]
        pred_key = sk_net.infer_os(images, 36, tp=True)

        # Calculate loss
        total_loss, data_loss = sk_net.L2_loss_os(pred_key,
                                                  [gt_2d, gt_3d, gt_occ],
                                                  weight_decay=wd)
        train_op, _ = optimizer(total_loss, global_step, lrs, steps)
        sys.stderr.write("Train Graph Done ... \n")

        # Always defined so the training loop can test them safely even when
        # validation is disabled.
        val_pool = []
        val_iters = []
        if val_tfr_pool:
            for val_tfr in val_tfr_pool:
                total_val_num = ndata_tfrecords(val_tfr)
                # number of full batches in one validation epoch
                total_val_iters = int(float(total_val_num) / sbatch)
                val_iters.append(total_val_iters)
                val_images, val_gt_2d, val_gt_3d, _ = create_bb_pip(
                    [val_tfr], 1000, sbatch, mean, shuffle=False)

                # reuse_=True presumably shares the training branch's weights
                val_pred_key = sk_net.infer_os(val_images,
                                               36,
                                               tp=False,
                                               reuse_=True)
                _, val_data_loss = sk_net.L2_loss_23d(val_pred_key,
                                                      [val_gt_2d, val_gt_3d],
                                                      None)
                val_pool.append(val_data_loss)
            sys.stderr.write("Validation Graph Done ... \n")

        # merge_all() returns None when no summary op has been registered;
        # None cannot be passed to sess.run, so it is handled below.
        merged = tf.summary.merge_all()

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        with tf.Session(config=config) as sess:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            model_saver = tf.train.Saver(max_to_keep=15)

            sys.stderr.write("Initializing ... \n")
            sess.run(init_op)

            # start the queue-runner threads that feed the input pipelines
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            model_prefix = os.path.join(out_dir, 'single_key')
            timer = 0
            timer_count = 0
            last_saved = None

            try:
                sys.stderr.write("Start Training --- OUT DIM: %d, %d\n" %
                                 (k2d_dim, k3d_dim))
                for i in xrange(total_iters):
                    ts = time.time()
                    if i > 0 and i % log_freq == 0:
                        if merged is None:
                            key_loss, _ = sess.run([data_loss, train_op])
                        else:
                            key_loss, _, summary = sess.run(
                                [data_loss, train_op, merged])
                            summary_writer.add_summary(summary, i)
                            summary_writer.flush()

                        # average time of the non-logging steps since the last
                        # log line; max() avoids a zero division if log_freq == 1
                        sys.stderr.write(
                            'Training %d (%fs) --- Key L2 Loss: %f\n' %
                            (i, timer / max(timer_count, 1), key_loss))
                        timer = 0
                        timer_count = 0
                    else:
                        sess.run([train_op])
                        timer += time.time() - ts
                        timer_count += 1

                    if val_pool and i > 0 and i % val_freq == 0:
                        sys.stderr.write('Validation %d\n' % i)
                        for cid, v_dl in enumerate(val_pool):
                            val_key_loss = eval_one_epoch(sess, v_dl,
                                                          val_iters[cid])
                            sys.stderr.write(
                                'Class %d --- Key L2 Loss: %f\n' %
                                (cid, val_key_loss))

                    if i > 0 and i % model_save_freq == 0:
                        model_saver.save(sess, model_prefix, global_step=i)
                        last_saved = i

                # final checkpoint, unless the last iteration was just saved
                final_step = total_iters - 1
                if total_iters > 0 and last_saved != final_step:
                    model_saver.save(sess, model_prefix,
                                     global_step=final_step)
            finally:
                # release the writer and the feeding threads even on failure
                summary_writer.close()
                coord.request_stop()
                coord.join(threads, stop_grace_period_secs=5)
Ejemplo n.º 2
0
def train(input_tfr_pool, val_tfr_pool, out_dir, log_dir, mean, sbatch, wd,
          restore_path=('log_hg_s4_256_23d_v2_1.0/model/'
                        'single_key_4s_hg_23d_v2_1.0-149999')):
    """Train Multi-View Network for a number of steps.

    Hourglass "v3" variant. A 3D-aware preprocessing network
    (``sk_net.modified_hg_preprocessing_with_3d_info_v2``, built with
    ``tp=False``) feeds a stacked hourglass
    (``hg._graph_hourglass_modified_v1``). The model is trained with a heat-map
    cross-entropy loss (``ut._bce_loss``) and RMSProp, using a staircase
    exponential learning-rate decay driven by the ``train_steps`` variable.

    The checkpoint saver is built before ``train_steps`` and before the
    optimizer. It therefore restores and saves only the network variables,
    and the step counter and RMSProp slots always start freshly initialized.

    Args:
        input_tfr_pool: training TFRecord input(s), passed to ``create_bb_pip``.
        val_tfr_pool: iterable of validation TFRecord files; empty or ``None``
            disables validation.
        out_dir: directory for checkpoints.
        log_dir: directory for the ``tf.summary.FileWriter`` (graph only).
        mean: forwarded to ``create_bb_pip`` (presumably an input mean used
            for normalization -- TODO confirm).
        sbatch: batch size of both the training and validation pipelines.
        wd: unused by this variant; kept for interface compatibility.
        restore_path: checkpoint restored into the network variables right
            after initialization. Defaults to the previously hard-coded path.
    """
    log_freq = 100
    val_freq = 4000
    model_save_freq = 10000
    tf.logging.set_verbosity(tf.logging.ERROR)

    # total number of training iterations
    total_iters = 151000

    config = tf.ConfigProto(log_device_placement=False)
    with tf.Graph().as_default():
        sys.stderr.write("Building Network ... \n")

        images, gt_keys_hm, gt_3d = create_bb_pip(input_tfr_pool,
                                                  1000,
                                                  sbatch,
                                                  mean,
                                                  shuffle=True)

        # number of keypoint heat-map channels
        out_dim = 36

        # preprocessing with 3D intermediate outputs (prep part kept fixed
        # via tp=False); pred_3d is not used by the current loss
        hg_input, pred_3d = sk_net.modified_hg_preprocessing_with_3d_info_v2(
            images, 36 * 2, 36 * 3, reuse_=False, tp=False)

        r3 = tf.image.resize_nearest_neighbor(hg_input, size=[64, 64])

        pred_keys_hm = hg._graph_hourglass_modified_v1(
            input=r3,
            dropout_rate=0.2,
            outDim=out_dim,
            tiny=False,
            modif=False,
            is_training=True)

        # heat-map cross-entropy only; the 3D loss term is disabled
        total_loss = ut._bce_loss(logits=pred_keys_hm,
                                  gtMaps=gt_keys_hm,
                                  name='ce_loss',
                                  weighted=False)
        init_learning_rate = 2.5e-4

        # Built here on purpose: tf.train.Saver() captures only the variables
        # that exist now, so train_steps and the optimizer slots created below
        # are neither restored nor saved.
        model_saver_23d_v1 = tf.train.Saver()

        train_step = tf.Variable(0, name='train_steps', trainable=False)

        lr_hg = tf.train.exponential_decay(init_learning_rate,
                                           global_step=train_step,
                                           decay_rate=0.96,
                                           decay_steps=2000,
                                           staircase=True,
                                           name="learning_rate")

        rmsprop_optimizer = tf.train.RMSPropOptimizer(learning_rate=lr_hg)

        # run the UPDATE_OPS collection (presumably batch-norm moving
        # averages) together with every optimizer step
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op_hg = rmsprop_optimizer.minimize(total_loss, train_step)

        sys.stderr.write("Train Graph Done ... \n")

        # Always defined so the training loop can test them safely even when
        # validation is disabled.
        val_pool = []
        val_iters = []
        accur_pool = []
        if val_tfr_pool:
            for val_tfr in val_tfr_pool:
                total_val_num = ndata_tfrecords(val_tfr)
                # number of full batches in one validation epoch
                total_val_iters = int(float(total_val_num) / sbatch)
                val_iters.append(total_val_iters)
                val_images, val_gt_keys_hm, val_gt_3d = create_bb_pip(
                    [val_tfr], 1000, sbatch, mean, shuffle=False)

                # NOTE(review): training builds the preprocessing with the
                # "_v2" function (which returns a (features, pred_3d) pair);
                # validation uses the non-v2 function. Confirm both create the
                # same variables so reuse_=True shares the trained weights.
                val_r3 = sk_net.modified_hg_preprocessing_with_3d_info(
                    val_images, 36 * 2, 36 * 3, reuse_=True, tp=False)
                val_r3 = tf.image.resize_nearest_neighbor(val_r3,
                                                          size=[64, 64])
                val_pred_keys_hm = hg._graph_hourglass_modified_v1(
                    input=val_r3,
                    outDim=out_dim,
                    is_training=False,
                    tiny=False,
                    modif=False,
                    reuse=True)

                val_train_loss_hg = ut._bce_loss(logits=val_pred_keys_hm,
                                                 gtMaps=val_gt_keys_hm,
                                                 name="val_ce_loss")
                # use the real batch size instead of a hard-coded 16
                val_accur = ut._accuracy_computation(output=val_pred_keys_hm,
                                                     gtMaps=val_gt_keys_hm,
                                                     nStack=4,
                                                     batchSize=sbatch)

                val_pool.append(val_train_loss_hg)
                accur_pool.append(val_accur)
            sys.stderr.write("Validation Graph Done ... \n")

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        with tf.Session(config=config) as sess:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sys.stderr.write("Initializing ... \n")
            sess.run(init_op)

            # v3: restore the pre-trained network (preprocessing + hourglass)
            # from restore_path, then keep training the hourglass
            print('restoring -v3')
            model_saver_23d_v1.restore(sess, restore_path)
            print('restored successfully - v3')

            print('initial-sanity check')
            print('init_step: ', sess.run(train_step))
            print('init_lr: ', sess.run(lr_hg))

            # start the queue-runner threads that feed the input pipelines
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            model_prefix = os.path.join(out_dir,
                                        'single_key_4s_hg_23d_v3_1.0_8-10')
            timer = 0
            timer_count = 0
            last_saved = None

            try:
                sys.stderr.write("Start Training --- OUT DIM: %d\n" %
                                 (out_dim))
                logger.info("Start Training --- OUT DIM: %d\n" % (out_dim))
                for step in xrange(total_iters):
                    ts = time.time()
                    if step > 0 and step % log_freq == 0:
                        key_loss, _ = sess.run([total_loss, train_op_hg])
                        # average time of the non-logging steps since the last
                        # log line; max() avoids a zero division if log_freq == 1
                        msg = ('Training %d (%fs) --- Key L2 Loss: %f\n' %
                               (step, timer / max(timer_count, 1), key_loss))
                        sys.stderr.write(msg)
                        logger.info(msg)
                        timer = 0
                        timer_count = 0
                    else:
                        sess.run(train_op_hg)
                        timer += time.time() - ts
                        timer_count += 1

                    if val_pool and step > 0 and step % val_freq == 0:
                        cur_lr = sess.run(lr_hg)
                        print("lr: ", cur_lr)
                        logger.info('lr: {}'.format(cur_lr))

                        sys.stderr.write('Validation %d\n' % step)
                        logger.info(('Validation %d\n' % step))

                        # validation loss per class, via eval_one_epoch
                        for cid, v_dl in enumerate(val_pool):
                            val_key_loss = eval_one_epoch(sess, v_dl,
                                                          val_iters[cid])
                            sys.stderr.write(
                                'Class %d --- Key HM CE Loss: %f\n' %
                                (cid, val_key_loss))
                            logger.info('Class %d --- Key HM CE Loss: %f\n' %
                                        (cid, val_key_loss))

                        # per-keypoint accuracy averaged over one epoch
                        for cid, accur in enumerate(accur_pool):
                            batches = [sess.run(accur)
                                       for _ in xrange(val_iters[cid])]
                            rec = np.mean(np.array(batches), axis=0)
                            avg_accur = np.mean(rec)
                            sys.stderr.write(
                                'Class %d -- Avg Accuracy : %f\n' %
                                (cid, avg_accur))
                            sys.stderr.write(
                                'Classs {} -- All Accuracy:\n{}\n'.format(
                                    cid, rec))
                            logger.info('Class %d -- Avg Accuracy : %f\n' %
                                        (cid, avg_accur))
                            logger.info(
                                'Class {} -- All Accuracy:\n {}\n'.format(
                                    cid, rec))

                    if step > 0 and step % model_save_freq == 0:
                        model_saver_23d_v1.save(sess,
                                                model_prefix,
                                                global_step=step)
                        last_saved = step

                # final checkpoint, unless the last iteration was just saved
                final_step = total_iters - 1
                if total_iters > 0 and last_saved != final_step:
                    model_saver_23d_v1.save(sess, model_prefix,
                                            global_step=final_step)
            finally:
                # release the writer and the feeding threads even on failure
                summary_writer.close()
                coord.request_stop()
                coord.join(threads, stop_grace_period_secs=5)
Ejemplo n.º 3
0
def train(input_tfr_pool, val_tfr_pool, out_dir, log_dir, mean, sbatch, wd):
    """Train Multi-View Network for a number of steps.

    Stacked-hourglass variant. ``hg._graph_hourglass`` (``nFeat=512``) is
    trained on keypoint heat maps with a cross-entropy loss (``ut._bce_loss``)
    and RMSProp, using a staircase exponential learning-rate decay on the
    global step. Optional validation branches report loss and per-keypoint
    accuracy.

    Args:
        input_tfr_pool: training TFRecord input(s), passed to ``create_bb_pip``.
        val_tfr_pool: iterable of validation TFRecord files; empty or ``None``
            disables validation.
        out_dir: directory for checkpoints (file prefix ``single_key_4s_hg``).
        log_dir: directory for the ``tf.summary.FileWriter`` (graph only).
        mean: forwarded to ``create_bb_pip`` (presumably an input mean used
            for normalization -- TODO confirm).
        sbatch: batch size of both the training and validation pipelines.
        wd: unused by this variant; kept for interface compatibility.
    """
    log_freq = 100
    val_freq = 2000
    model_save_freq = 5000
    tf.logging.set_verbosity(tf.logging.ERROR)

    # total number of training iterations
    total_iters = 200000

    config = tf.ConfigProto(log_device_placement=False)
    with tf.Graph().as_default():
        sys.stderr.write("Building Network ... \n")
        global_step = tf.contrib.framework.get_or_create_global_step()

        images, gt_keys_hm = create_bb_pip(input_tfr_pool,
                                           1000,
                                           sbatch,
                                           mean,
                                           shuffle=True)

        # number of keypoint heat-map channels
        out_dim = 36
        # Feature width of the hourglass. It must be identical for the
        # training and the reused validation graph, otherwise variable
        # sharing fails on a shape mismatch.
        n_feat = 512
        pred_keys_hm = hg._graph_hourglass(input=images,
                                           nFeat=n_feat,
                                           dropout_rate=0.2,
                                           outDim=out_dim,
                                           tiny=False,
                                           modif=False,
                                           is_training=True)

        total_loss = ut._bce_loss(logits=pred_keys_hm,
                                  gtMaps=gt_keys_hm,
                                  name='ce_loss',
                                  weighted=False)
        init_learning_rate = 2.5e-4
        lr_hg = tf.train.exponential_decay(init_learning_rate,
                                           global_step=global_step,
                                           decay_rate=0.96,
                                           decay_steps=2000,
                                           staircase=True,
                                           name="learning_rate")

        rmsprop_optimizer = tf.train.RMSPropOptimizer(learning_rate=lr_hg)

        # run the UPDATE_OPS collection (presumably batch-norm moving
        # averages) together with every optimizer step
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op_hg = rmsprop_optimizer.minimize(total_loss, global_step)

        sys.stderr.write("Train Graph Done ... \n")

        # Always defined so the training loop can test them safely even when
        # validation is disabled.
        val_pool = []
        val_iters = []
        accur_pool = []
        if val_tfr_pool:
            for val_tfr in val_tfr_pool:
                total_val_num = ndata_tfrecords(val_tfr)
                # number of full batches in one validation epoch
                total_val_iters = int(float(total_val_num) / sbatch)
                val_iters.append(total_val_iters)
                val_images, val_gt_keys_hm = create_bb_pip([val_tfr],
                                                           1000,
                                                           sbatch,
                                                           mean,
                                                           shuffle=False)

                val_pred_keys_hm = hg._graph_hourglass(input=val_images,
                                                       nFeat=n_feat,
                                                       outDim=out_dim,
                                                       is_training=False,
                                                       tiny=False,
                                                       modif=False,
                                                       reuse=True)

                val_train_loss_hg = ut._bce_loss(logits=val_pred_keys_hm,
                                                 gtMaps=val_gt_keys_hm,
                                                 name="val_ce_loss")
                # use the real batch size instead of a hard-coded 16
                val_accur = ut._accuracy_computation(output=val_pred_keys_hm,
                                                     gtMaps=val_gt_keys_hm,
                                                     nStack=4,
                                                     batchSize=sbatch)

                val_pool.append(val_train_loss_hg)
                accur_pool.append(val_accur)
            sys.stderr.write("Validation Graph Done ... \n")

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        with tf.Session(config=config) as sess:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            model_saver = tf.train.Saver()

            sys.stderr.write("Initializing ... \n")
            sess.run(init_op)

            # start the queue-runner threads that feed the input pipelines
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            model_prefix = os.path.join(out_dir, 'single_key_4s_hg')
            timer = 0
            timer_count = 0
            last_saved = None

            try:
                sys.stderr.write("Start Training --- OUT DIM: %d\n" %
                                 (out_dim))
                logger.info("Start Training --- OUT DIM: %d\n" % (out_dim))
                for step in xrange(total_iters):
                    ts = time.time()
                    if step > 0 and step % log_freq == 0:
                        key_loss, _ = sess.run([total_loss, train_op_hg])
                        # average time of the non-logging steps since the last
                        # log line; max() avoids a zero division if log_freq == 1
                        msg = ('Training %d (%fs) --- Key L2 Loss: %f\n' %
                               (step, timer / max(timer_count, 1), key_loss))
                        sys.stderr.write(msg)
                        logger.info(msg)
                        timer = 0
                        timer_count = 0
                    else:
                        sess.run(train_op_hg)
                        timer += time.time() - ts
                        timer_count += 1

                    if val_pool and step > 0 and step % val_freq == 0:
                        cur_lr = sess.run(lr_hg)
                        print("lr: ", cur_lr)
                        logger.info('lr: {}'.format(cur_lr))

                        sys.stderr.write('Validation %d\n' % step)
                        logger.info(('Validation %d\n' % step))

                        # validation loss per class, via eval_one_epoch
                        for cid, v_dl in enumerate(val_pool):
                            val_key_loss = eval_one_epoch(sess, v_dl,
                                                          val_iters[cid])
                            sys.stderr.write(
                                'Class %d --- Key HM CE Loss: %f\n' %
                                (cid, val_key_loss))
                            logger.info('Class %d --- Key HM CE Loss: %f\n' %
                                        (cid, val_key_loss))

                        # per-keypoint accuracy averaged over one epoch
                        for cid, accur in enumerate(accur_pool):
                            batches = [sess.run(accur)
                                       for _ in xrange(val_iters[cid])]
                            rec = np.mean(np.array(batches), axis=0)
                            avg_accur = np.mean(rec)
                            sys.stderr.write(
                                'Class %d -- Avg Accuracy : %f\n' %
                                (cid, avg_accur))
                            sys.stderr.write(
                                'Classs {} -- All Accuracy:\n{}\n'.format(
                                    cid, rec))
                            logger.info('Class %d -- Avg Accuracy : %f\n' %
                                        (cid, avg_accur))
                            logger.info(
                                'Class {} -- All Accuracy:\n {}\n'.format(
                                    cid, rec))

                    if step > 0 and step % model_save_freq == 0:
                        model_saver.save(sess, model_prefix, global_step=step)
                        last_saved = step

                # final checkpoint, unless the last iteration was just saved
                final_step = total_iters - 1
                if total_iters > 0 and last_saved != final_step:
                    model_saver.save(sess, model_prefix,
                                     global_step=final_step)
            finally:
                # release the writer and the feeding threads even on failure
                summary_writer.close()
                coord.request_stop()
                coord.join(threads, stop_grace_period_secs=5)