def prepare_s3dis_test():
    """Preprocess every file of the test split, one file at a time.

    Each name in the test half of ``get_block_train_test_split()`` is passed
    to ``prepare_s3dis_test_single_file`` and the wall-clock time spent on it
    is printed.
    """
    _train_files, test_files = get_block_train_test_split()

    for test_file in test_files:
        started = time.time()
        prepare_s3dis_test_single_file(test_file)
        elapsed = time.time() - started
        print('done {} cost {} s'.format(test_file, elapsed))
示例#2
0
def eval():
    """Restore a trained model and run a single test pass over the test split.

    The checkpoint is read from ``FLAGS.eval_model`` and a summary writer is
    opened on ``FLAGS.train_dir``.  The summary writer, the TensorFlow
    session and the data provider are all released before returning, whether
    or not evaluation succeeds.
    """
    _train_names, test_names = get_block_train_test_split()
    test_list = ['data/S3DIS/sampled_test/' + name for name in test_names]

    def fn(model, filename):
        # Forward only entries 0, 2, 3, 4 and 12 of the pickled record.
        # NOTE(review): presumably xyz, rgb, covariance, label and
        # block-minimum data -- confirm against the code writing the files.
        data = read_pkl(filename)
        return data[0], data[2], data[3], data[4], data[12]

    test_provider = Provider(test_list, 'test',
                             FLAGS.batch_size * FLAGS.num_gpus, fn)

    sess = None
    summary_writer = None
    try:
        pls = build_placeholder(FLAGS.num_gpus)
        # Floor division keeps this an integer on Python 2 and Python 3.
        batch_num_per_epoch = 2000 // FLAGS.num_gpus
        ops = train_ops(pls['xyzs'], pls['feats'], pls['lbls'],
                        pls['is_training'], batch_num_per_epoch)

        feed_dict = {}
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)

        saver = tf.train.Saver(max_to_keep=500)
        saver.restore(sess, FLAGS.eval_model)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)
        test_one_epoch(ops, pls, sess, saver, test_provider, 0, feed_dict,
                       summary_writer)

    finally:
        # Release TensorFlow resources (closing the writer also flushes
        # pending events); the inner ``finally`` makes sure the provider is
        # shut down even if one of those close() calls fails.
        try:
            if summary_writer is not None:
                summary_writer.close()
            if sess is not None:
                sess.close()
        finally:
            test_provider.close()
def merge_train_by_area():
    """Merge the sampled training pickles into larger per-area pickles.

    The training file list is shuffled, then for each area index 1..6 the
    files whose ``get_area()`` matches are read from
    ``data/S3DIS/sampled_train_new/`` and their first five fields are
    accumulated.  Whenever more than 1000 items have been collected the
    accumulated data is saved as
    ``data/S3DIS/merged_train_new/<area>_<chunk>.pkl`` and a new chunk is
    started; any remainder is saved at the end of the area.  The path of
    every written pickle is recorded, one per line, in
    ``cached/s3dis_merged_train.txt``.
    """
    from io_util import get_block_train_test_split
    train_list, _test_list = get_block_train_test_split()
    random.shuffle(train_list)

    field_num = 5
    out_pattern = 'data/S3DIS/merged_train_new/{}_{}.pkl'

    # ``with`` guarantees the index file is flushed and closed even if
    # reading or saving a pickle raises part-way through.
    with open('cached/s3dis_merged_train.txt', 'w') as f:

        def flush(area, chunk_idx, chunk):
            # Save one merged chunk and record its path in the index file.
            out_path = out_pattern.format(area, chunk_idx)
            save_pkl(out_path, chunk)
            f.write(out_path + '\n')

        for ai in xrange(1, 7):
            cur_data = [[] for _ in xrange(field_num)]
            cur_idx = 0
            for fn in train_list:
                if get_area(fn) != ai:
                    continue
                data = read_pkl('data/S3DIS/sampled_train_new/' + fn)
                for i in xrange(field_num):
                    cur_data[i] += data[i]

                if len(cur_data[0]) > 1000:
                    flush(ai, cur_idx, cur_data)
                    cur_idx += 1
                    cur_data = [[] for _ in xrange(field_num)]

            # Write whatever is left over for this area.
            if len(cur_data[0]) > 0:
                flush(ai, cur_idx, cur_data)

            print('area {} done'.format(ai))
def compute_weight():
    """Print label counts and derived class weights over train + test data.

    Field 4 of every pickle in the sampled test and train directories is
    gathered and histogrammed with bin edges ``range(14)``.  The counts are
    plotted as a bar chart saved to ``s3dis_dist.png`` and printed, then
    normalised to fractions and turned into weights
    ``1 / log(1.2 + fraction)``, which are printed as well.
    """
    from io_util import get_block_train_test_split, get_class_names
    import numpy as np
    train_names, test_names = get_block_train_test_split()

    # Test files first, then training files, as in the file listing order.
    pkl_paths = ['data/S3DIS/sampled_test/' + name for name in test_names]
    pkl_paths.extend(
        ['data/S3DIS/sampled_train/' + name for name in train_names])

    label_chunks = []
    for pkl_path in pkl_paths:
        label_chunks.extend(read_pkl(pkl_path)[4])
    all_labels = np.concatenate(label_chunks, axis=0)

    counts, _edges = np.histogram(all_labels, range(14))
    bar_positions = np.arange(len(counts))
    plt.figure(0, figsize=(10, 8), dpi=80)
    plt.bar(bar_positions, counts, tick_label=get_class_names())
    plt.savefig('s3dis_dist.png')
    plt.close()

    print(counts)
    fractions = counts.astype(np.float32)
    fractions = fractions / np.sum(fractions)
    weights = 1 / np.log(1.2 + fractions)

    print(weights)
def eval():
    """Restore a trained model and run a single test pass on room blocks.

    Per-GPU placeholders are created here, the network is built with
    ``train_ops``, the checkpoint ``FLAGS.eval_model`` is restored and
    ``test_one_epoch`` is run once.  The summary writer, the TensorFlow
    session and the data provider are all released before returning, whether
    or not evaluation succeeds.
    """
    _train_names, test_names = get_block_train_test_split()
    test_list = ['data/S3DIS/room_block_10_10/' + name for name in test_names]

    test_provider = Provider(test_list, 'test',
                             FLAGS.batch_size * FLAGS.num_gpus, read_fn)

    sess = None
    summary_writer = None
    try:
        # (dict key, dtype, shape, placeholder name prefix); one placeholder
        # of each kind is created per GPU, in this order.
        placeholder_specs = [
            ('xyzs', tf.float32, [None, 3], 'xyz'),
            ('rgbs', tf.float32, [None, 3], 'rgb'),
            ('covars', tf.float32, [None, 9], 'covar'),
            ('lbls', tf.int64, [None], 'lbl'),
            ('nidxs', tf.int32, [None], 'nidxs'),
            ('nidxs_lens', tf.int32, [None], 'nidxs_lens'),
            ('nidxs_bgs', tf.int32, [None], 'nidxs_bgs'),
            ('cidxs', tf.int32, [None], 'cidxs'),
            ('weights', tf.float32, [None], 'weights'),
        ]
        pls = {}
        for key, _dtype, _shape, _prefix in placeholder_specs:
            pls[key] = []
        for i in xrange(FLAGS.num_gpus):
            for key, dtype, shape, prefix in placeholder_specs:
                pls[key].append(
                    tf.placeholder(dtype, shape, '{}{}'.format(prefix, i)))

        pmiu = neighbor_anchors_v2()
        pls['is_training'] = tf.placeholder(tf.bool, name='is_training')
        pls['pmiu'] = tf.placeholder(tf.float32, name='pmiu')

        # Floor division keeps this an integer on Python 2 and Python 3.
        batch_num_per_epoch = 2500 // FLAGS.num_gpus
        ops = train_ops(pls['xyzs'], pls['rgbs'], pls['covars'], pls['lbls'],
                        pls['cidxs'], pls['nidxs'], pls['nidxs_lens'],
                        pls['nidxs_bgs'], pmiu.shape[1], pls['is_training'],
                        batch_num_per_epoch, pls['pmiu'], pls['weights'])

        # ``pmiu`` is constant for the whole run, so it is fed once through
        # the shared feed dict.
        feed_dict = {}
        feed_dict[pls['pmiu']] = pmiu
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)

        saver = tf.train.Saver(max_to_keep=500)
        saver.restore(sess, FLAGS.eval_model)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)
        test_one_epoch(ops, pls, sess, saver, test_provider, 0, feed_dict,
                       summary_writer)

    finally:
        # Release TensorFlow resources (closing the writer also flushes
        # pending events); the inner ``finally`` makes sure the provider is
        # shut down even if one of those close() calls fails.
        try:
            if summary_writer is not None:
                summary_writer.close()
            if sess is not None:
                sess.close()
        finally:
            test_provider.close()
def prepare_s3dis_train():
    """Preprocess every file of the training split in 8 worker processes.

    One ``prepare_s3dis_train_single_file`` task is submitted per training
    file.  All results are waited for, so an exception raised in a worker is
    re-raised here.  The process pool is shut down before returning.
    """
    train_list, _test_list = get_block_train_test_split()

    from concurrent.futures import ProcessPoolExecutor
    # The context manager shuts the pool down (waiting for its workers) even
    # when one of the tasks fails.
    with ProcessPoolExecutor(8) as executor:
        futures = [
            executor.submit(prepare_s3dis_train_single_file, fn)
            for fn in train_list
        ]

        for future in futures:
            future.result()
示例#7
0
def train():
    """Train on the sampled training split, testing after every epoch.

    When ``FLAGS.restore`` is set the whole model is restored from
    ``FLAGS.restore_model``.  Otherwise all variables are initialised and
    the trainable variables whose names start with ``base`` or ``class_mlp``
    are then restored from ``FLAGS.base_restore``.  Epochs run from
    ``FLAGS.restore_epoch`` up to ``FLAGS.train_epoch_num``.  The summary
    writer, the session and both providers are released before returning.
    """
    train_names, test_names = get_block_train_test_split()
    # test_list=['data/S3DIS/sampled_train/'+fn for fn in train_list[:2]]
    train_list = ['data/S3DIS/sampled_train/' + name for name in train_names]
    test_list = ['data/S3DIS/sampled_test/' + name for name in test_names]

    def fn(model, filename):
        # Forward only entries 0, 2, 3, 4 and 12 of the pickled record.
        # NOTE(review): presumably xyz, rgb, covariance, label and
        # block-minimum data -- confirm against the code writing the files.
        data = read_pkl(filename)
        return data[0], data[2], data[3], data[4], data[12]

    train_provider = Provider(train_list, 'train',
                              FLAGS.batch_size * FLAGS.num_gpus, fn)
    test_provider = Provider(test_list, 'test',
                             FLAGS.batch_size * FLAGS.num_gpus, fn)

    sess = None
    summary_writer = None
    try:
        pls = build_placeholder(FLAGS.num_gpus)
        # Floor division keeps this an integer on Python 2 and Python 3.
        batch_num_per_epoch = 2000 // FLAGS.num_gpus
        ops = train_ops(pls['xyzs'], pls['feats'], pls['lbls'],
                        pls['is_training'], batch_num_per_epoch)

        feed_dict = {}
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)
        saver = tf.train.Saver(max_to_keep=500)
        if FLAGS.restore:
            saver.restore(sess, FLAGS.restore_model)
        else:
            sess.run(tf.global_variables_initializer())

            # Overwrite the freshly initialised 'base*' / 'class_mlp*'
            # variables with the values stored in FLAGS.base_restore.
            base_var = [
                var for var in tf.trainable_variables()
                if var.name.startswith(('base', 'class_mlp'))
            ]
            base_saver = tf.train.Saver(base_var)
            base_saver.restore(sess, FLAGS.base_restore)

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        for epoch_num in xrange(FLAGS.restore_epoch, FLAGS.train_epoch_num):
            train_one_epoch(ops, pls, sess, summary_writer, train_provider,
                            epoch_num, feed_dict)
            test_one_epoch(ops, pls, sess, saver, test_provider, epoch_num,
                           feed_dict)

    finally:
        # Release TensorFlow resources (closing the writer also flushes
        # pending events); the inner ``finally`` makes sure the providers
        # are shut down even if one of those close() calls fails.
        try:
            if summary_writer is not None:
                summary_writer.close()
            if sess is not None:
                sess.close()
        finally:
            train_provider.close()
            test_provider.close()
def visual_room():
    """Write one label-coloured point file per office room.

    Files from both splits whose second-to-last ``_``-separated name part is
    ``office`` are read from ``data/S3DIS/office_block/``.  Every block is
    shifted by its entry in ``block_mins``, all blocks and labels are
    concatenated, and the points are written to ``test_result/<file>.txt``
    with the class colour of each label.
    """
    names, test_names = get_block_train_test_split()
    names += test_names
    office_files = [name for name in names if name.split('_')[-2] == 'office']
    from draw_util import get_class_colors, output_points
    palette = get_class_colors()
    for office_file in office_files:
        record = read_pkl('data/S3DIS/office_block/' + office_file)
        xyzs, rgbs, covars, labels, block_mins = record
        # Add each block's offset in place before merging the blocks.
        for block_idx in xrange(len(xyzs)):
            xyzs[block_idx] += block_mins[block_idx]
        room_xyzs = np.concatenate(xyzs, axis=0)
        room_labels = np.concatenate(labels, axis=0)

        out_path = 'test_result/{}.txt'.format(office_file)
        output_points(out_path, room_xyzs, palette[room_labels])
def train():
    """Train on the sampled training split, testing after every epoch.

    The providers hand the complete pickled record of each file to the
    network.  When ``FLAGS.restore`` is set the model is restored from
    ``FLAGS.restore_model``, otherwise all variables are initialised.
    Epochs run from ``FLAGS.restore_epoch`` up to ``FLAGS.train_epoch_num``.
    The summary writer, the session and both providers are released before
    returning.
    """
    train_names, test_names = get_block_train_test_split()
    train_list = ['data/S3DIS/sampled_train/' + name for name in train_names]
    test_list = ['data/S3DIS/sampled_test/' + name for name in test_names]
    fn = lambda model, filename: read_pkl(filename)

    train_provider = Provider(train_list, 'train',
                              FLAGS.batch_size * FLAGS.num_gpus, fn)
    test_provider = Provider(test_list, 'test',
                             FLAGS.batch_size * FLAGS.num_gpus, fn)

    sess = None
    summary_writer = None
    try:
        pls = build_placeholder(FLAGS.num_gpus)
        pmiu = neighbor_anchors_v2()

        # Floor division keeps this an integer on Python 2 and Python 3.
        batch_num_per_epoch = 2000 // FLAGS.num_gpus
        ops = train_ops(pls['cxyzs'], pls['dxyzs'], pls['rgbs'], pls['covars'],
                        pls['vlens'], pls['vlens_bgs'], pls['vcidxs'],
                        pls['cidxs'], pls['nidxs'], pls['nidxs_lens'],
                        pls['nidxs_bgs'], pls['lbls'], pmiu.shape[1],
                        pls['is_training'], batch_num_per_epoch, pls['pmiu'])

        # ``pmiu`` is constant for the whole run, so it is fed once through
        # the shared feed dict.
        feed_dict = {}
        feed_dict[pls['pmiu']] = pmiu
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)
        saver = tf.train.Saver(max_to_keep=500)
        if FLAGS.restore:
            saver.restore(sess, FLAGS.restore_model)
        else:
            sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        for epoch_num in xrange(FLAGS.restore_epoch, FLAGS.train_epoch_num):
            train_one_epoch(ops, pls, sess, summary_writer, train_provider,
                            epoch_num, feed_dict)
            test_one_epoch(ops, pls, sess, saver, test_provider, epoch_num,
                           feed_dict)

    finally:
        # Release TensorFlow resources (closing the writer also flushes
        # pending events); the inner ``finally`` makes sure the providers
        # are shut down even if one of those close() calls fails.
        try:
            if summary_writer is not None:
                summary_writer.close()
            if sess is not None:
                sess.close()
        finally:
            train_provider.close()
            test_provider.close()
示例#10
0
def train():
    """Train on the 'nolimits' sampled split, testing after every epoch.

    The training file list is shuffled once before the providers are
    created; both providers read files through ``test_fn``.  When
    ``FLAGS.restore`` is set the model is restored from
    ``FLAGS.restore_model``, otherwise all variables are initialised.
    Epochs run from ``FLAGS.restore_epoch`` up to ``FLAGS.train_epoch_num``.
    The summary writer, the session and both providers are released before
    returning.
    """
    train_names, test_names = get_block_train_test_split()
    train_list = [
        'data/S3DIS/sampled_train_nolimits/' + name for name in train_names
    ]
    # train_list=['data/S3DIS/sampled_train_no_aug/'+fn for fn in train_list]
    # with open('cached/s3dis_merged_train.txt', 'r') as f:
    #     train_list=[line.strip('\n') for line in f.readlines()]
    random.shuffle(train_list)
    test_list = [
        'data/S3DIS/sampled_test_nolimits/' + name for name in test_names
    ]

    train_provider = Provider(train_list, 'train',
                              FLAGS.batch_size * FLAGS.num_gpus, test_fn)
    test_provider = Provider(test_list, 'test',
                             FLAGS.batch_size * FLAGS.num_gpus, test_fn)

    sess = None
    summary_writer = None
    try:
        pls = build_placeholder(FLAGS.num_gpus)
        # Floor division keeps this an integer on Python 2 and Python 3.
        batch_num_per_epoch = 2000 // FLAGS.num_gpus
        ops = train_ops(pls['xyzs'], pls['feats'], pls['lbls'],
                        pls['is_training'], batch_num_per_epoch)

        feed_dict = {}
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)
        saver = tf.train.Saver(max_to_keep=500)
        if FLAGS.restore:
            saver.restore(sess, FLAGS.restore_model)
        else:
            sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        for epoch_num in xrange(FLAGS.restore_epoch, FLAGS.train_epoch_num):
            train_one_epoch(ops, pls, sess, summary_writer, train_provider,
                            epoch_num, feed_dict)
            test_one_epoch(ops, pls, sess, saver, test_provider, epoch_num,
                           feed_dict)

    finally:
        # Release TensorFlow resources (closing the writer also flushes
        # pending events); the inner ``finally`` makes sure the providers
        # are shut down even if one of those close() calls fails.
        try:
            if summary_writer is not None:
                summary_writer.close()
            if sess is not None:
                sess.close()
        finally:
            train_provider.close()
            test_provider.close()
def train():
    """Fine-tune from a checkpoint, leaving ``class_mlp*`` freshly initialised.

    All variables are initialised first; then every trainable variable whose
    name does not start with ``class_mlp`` is restored from
    ``FLAGS.restore_model``.  Epochs run from ``FLAGS.restore_epoch`` up to
    ``FLAGS.train_epoch_num``, each followed by a test pass.  The session
    and both providers are released before returning.
    """
    import random
    train_names, test_names = get_block_train_test_split()
    train_list = [
        'data/S3DIS/sampled_train_nolimits/' + name for name in train_names
    ]
    random.shuffle(train_list)
    test_list = [
        'data/S3DIS/sampled_test_nolimits/' + name for name in test_names
    ]

    def test_fn(model, filename):
        # The unpacking doubles as a check that the record has five fields.
        xyzs, rgbs, covars, lbls, block_mins = read_pkl(filename)
        return xyzs, rgbs, covars, lbls, block_mins

    train_provider = Provider(train_list, 'train',
                              FLAGS.batch_size * FLAGS.num_gpus, test_fn)
    test_provider = Provider(test_list, 'test',
                             FLAGS.batch_size * FLAGS.num_gpus, test_fn)

    sess = None
    try:
        pls = build_placeholder(FLAGS.num_gpus)
        # Floor division keeps this an integer on Python 2 and Python 3.
        batch_num_per_epoch = 2000 // FLAGS.num_gpus
        ops = train_ops(pls['xyzs'], pls['feats'], pls['lbls'],
                        pls['is_training'], batch_num_per_epoch)

        feed_dict = {}
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)
        # Initialise everything first so the variables excluded from the
        # restore below still hold valid values.
        sess.run(tf.global_variables_initializer())
        var_list = [
            var for var in tf.trainable_variables()
            if not var.name.startswith('class_mlp')
        ]
        saver = tf.train.Saver(max_to_keep=500, var_list=var_list)
        saver.restore(sess, FLAGS.restore_model)

        for epoch_num in xrange(FLAGS.restore_epoch, FLAGS.train_epoch_num):
            train_one_epoch(ops, pls, sess, train_provider, epoch_num,
                            feed_dict)
            test_one_epoch(ops, pls, sess, test_provider, epoch_num, feed_dict)

    finally:
        # Close the session; the inner ``finally`` makes sure the providers
        # are shut down even if that close() call fails.
        try:
            if sess is not None:
                sess.close()
        finally:
            train_provider.close()
            test_provider.close()
示例#12
0
def downsample_and_save():
    """Write five downsampled copies of every room block.

    The train and test lists are merged and traversed five times.  On each
    visit the room is read from ``data/S3DIS/room_block_10_10/``, passed
    through ``downsample(points, labels, 0.03)`` and saved under
    ``data/S3DIS/room_block_10_10_ds0.03/`` with the first ``_``-separated
    part of its name replaced by a running counter.  Every new name is also
    written, one per line, to ``cached/room_block_ds0.03_stems.txt``.
    """
    stems, test_stems = get_block_train_test_split()
    stems += test_stems
    counter = 0
    with open('cached/room_block_ds0.03_stems.txt', 'w') as stem_file:
        for _round in xrange(5):
            for stem in stems:
                points, labels = read_room_pkl(
                    'data/S3DIS/room_block_10_10/' + stem)
                points, labels = downsample(points, labels, 0.03)

                # Swap the leading name part for the running counter.
                new_stem = '_'.join([str(counter)] + stem.split('_')[1:])
                stem_file.write(new_stem + '\n')
                save_room_pkl(
                    'data/S3DIS/room_block_10_10_ds0.03/' + new_stem, points,
                    labels)
                counter += 1
示例#13
0
def load_data():
    """Collect the first three blocks of 20 randomly chosen training files.

    Returns three parallel lists: for every selected block, the first entry
    of its ``xyzs`` item, its ``rgbs`` and ``covars`` concatenated along
    axis 1, and its labels.
    """
    names, _test_names = get_block_train_test_split()
    paths = ['data/S3DIS/sampled_train/' + name for name in names]
    random.shuffle(paths)

    xyz_out, feat_out, label_out = [], [], []
    for path in paths[:20]:
        record = read_pkl(path)
        # Only entries 0, 2, 3, 4 and 12 of the pickled record are read;
        # entry 12 is looked up but not used any further.
        xyzs, rgbs, covars, labels = record[0], record[2], record[3], record[4]
        _block_mins = record[12]
        for block_idx in xrange(3):
            xyz_out.append(xyzs[block_idx][0])
            feat_out.append(
                np.concatenate([rgbs[block_idx], covars[block_idx]], axis=1))
            label_out.append(labels[block_idx])

    return xyz_out, feat_out, label_out
def compare():
    """Cluster block covariances with k-means and dump coloured point files.

    For three randomly chosen training files, up to the first ten blocks
    are processed: the number of rows whose absolute covariance sum is below
    1e-3 is printed, the covariances are clustered into five groups, and the
    first entry of the block's ``cxyzs`` item is written to
    ``test_result/<file index>_<block index>.txt`` with a random colour per
    cluster.
    """
    from draw_util import output_points
    from sklearn.cluster import KMeans
    names, _test_names = get_block_train_test_split()
    random.shuffle(names)

    paths = ['data/S3DIS/sampled_train/' + name for name in names]
    for file_idx, path in enumerate(paths[:3]):
        # Unpacking all thirteen fields also checks the record layout.
        (cxyzs, dxyzs, rgbs, covars, lbls, vlens, vlens_bgs, vcidxs, cidxs,
         nidxs, nidxs_bgs, nidxs_lens, block_mins) = read_pkl(path)

        block_count = len(cxyzs[:10])
        for block_idx in xrange(block_count):
            block_covars = covars[block_idx]
            near_zero = np.sum(np.abs(block_covars), axis=1) < 1e-3
            print(np.sum(near_zero))
            clusterer = KMeans(5)
            # Colours are drawn before fitting, keeping the order of the
            # random draws unchanged.
            palette = np.random.randint(0, 256, [5, 3])
            cluster_ids = clusterer.fit_predict(block_covars)
            out_path = 'test_result/{}_{}.txt'.format(file_idx, block_idx)
            output_points(out_path, cxyzs[block_idx][0], palette[cluster_ids])

    print('//////////////////////////')
def test_block_train():
    """Dump every block of one hard-coded training file as a point file.

    Reads ``1_Area_1_conferenceRoom_2.pkl`` from
    ``data/S3DIS/sampled_train_nolimits/`` and writes block ``i`` to
    ``test_result/<i>.txt``, with coordinates shifted by the block's entry
    in ``block_mins`` and colours computed as ``rgb * 127 + 128``.
    """
    # The split is still queried although its result is not used below.
    train_list, test_list = get_block_train_test_split()

    # get_class_colors is not used below; the import is kept unchanged.
    from draw_util import get_class_colors, output_points

    room_file = '1_Area_1_conferenceRoom_2.pkl'
    record = read_pkl(
        'data/S3DIS/sampled_train_nolimits/{}'.format(room_file))
    xyzs, rgbs, covars, lbls, block_mins = record
    for block_idx, block_xyz in enumerate(xyzs):
        shifted_xyz = block_xyz + block_mins[block_idx]
        block_rgb = rgbs[block_idx] * 127 + 128
        output_points('test_result/{}.txt'.format(block_idx), shifted_xyz,
                      block_rgb)
def test_covar():
    """Sample blocks from the first training room and cluster their covariances.

    The room is sampled with ``sample_block`` (no rescaling, swapping or
    flipping).  For up to the first five blocks the covariances are
    clustered into five groups with k-means and the block is written to
    ``test_result/<i>.txt`` with a random colour per cluster.
    """
    names, _test_names = get_block_train_test_split()
    points, labels = read_pkl('data/S3DIS/room_block_10_10/' + names[0])
    sampling_options = dict(min_pn=512,
                            use_rescale=False,
                            swap=False,
                            flip_x=False,
                            flip_y=False,
                            covar_ds_stride=0.075,
                            covar_nn_size=0.15)
    # sstride, bsize and bstride are not defined in this function; they
    # are looked up from the enclosing (module) scope.
    xyzs, rgbs, covars, lbls = sample_block(points, labels, sstride, bsize,
                                            bstride, **sampling_options)

    from sklearn.cluster import KMeans
    from draw_util import output_points
    block_count = len(xyzs[:5])
    for block_idx in xrange(block_count):
        clusterer = KMeans(5)
        # Colours are drawn before fitting, keeping the order of the random
        # draws unchanged.
        palette = np.random.randint(0, 256, [5, 3])
        cluster_ids = clusterer.fit_predict(covars[block_idx])
        output_points('test_result/{}.txt'.format(block_idx), xyzs[block_idx],
                      palette[cluster_ids])
def prepare_subset():
    """Build the office-only block dataset under ``data/S3DIS/office_block/``.

    Files from both splits whose second-to-last ``_``-separated name part is
    ``office`` are processed once each by ``prepare_subset_single_file``
    with randomly chosen flip/swap flags.  The first five fields of the
    result are saved under the same file name, and the time spent plus the
    mean length of the items in field 0 are printed.
    """
    names, test_names = get_block_train_test_split()
    names += test_names
    office_files = [name for name in names if name.split('_')[-2] == 'office']

    for office_file in office_files:
        started = time.time()
        src_path = 'data/S3DIS/room_block_10_10/' + office_file
        # Three independent coin flips, drawn in this order.
        flip_x = random.random() < 0.5
        flip_y = random.random() < 0.5
        swap = random.random() < 0.5

        data = prepare_subset_single_file(src_path, 0.075, 1.5, 0.75, 128,
                                          True, swap, flip_x, flip_y, True,
                                          True)
        merged = [[] for _ in xrange(5)]
        for field_idx in xrange(5):
            merged[field_idx] += data[field_idx]

        save_pkl('data/S3DIS/office_block/' + office_file, merged)
        elapsed = time.time() - started
        mean_len = np.mean([len(block) for block in merged[0]])
        print('done {} cost {} s pn {}'.format(office_file, elapsed, mean_len))
def interpolate(sxyzs, sprobs, qxyzs, ratio=1.0 / (2 * 0.075 * 0.075)):
    """Return probabilities for ``qxyzs`` interpolated from ``sxyzs``/``sprobs``.

    The shapes of both point sets are printed first.  Neighbour index lists
    are obtained from ``libPointUtil.findNeighborInAnotherCPU`` (called with
    the constant 6), flattened together with their lengths and start
    offsets, and handed to ``libPointUtil.interpolateProbsGPU`` along with
    ``ratio``.
    """
    for cloud in (sxyzs, qxyzs):
        print(cloud.shape)

    neighbor_lists = libPointUtil.findNeighborInAnotherCPU(sxyzs, qxyzs, 6)
    neighbor_lens = np.asarray([len(idxs) for idxs in neighbor_lists],
                               dtype=np.int32)
    neighbor_bgs = compute_nidxs_bgs(neighbor_lens)
    flat_neighbors = np.concatenate(neighbor_lists, axis=0)
    return libPointUtil.interpolateProbsGPU(sxyzs, qxyzs, sprobs,
                                            flat_neighbors, neighbor_lens,
                                            neighbor_bgs, ratio)


if __name__ == "__main__":
    # Script entry point: build the session once, then visit the test rooms
    # in random order.
    # NOTE(review): this block looks truncated -- all_preds/all_labels and
    # the fp/tp/fn counters are initialised but never updated, and qrn is
    # computed but never used within the visible lines.
    import random
    train_list, test_list = get_block_train_test_split()
    sess, pls, ops, feed_dict = build_session()
    all_preds, all_labels = [], []
    # Three counters with 13 slots each (presumably false positives, true
    # positives and false negatives per class -- confirm).
    fp = np.zeros(13, dtype=np.uint64)
    tp = np.zeros(13, dtype=np.uint64)
    fn = np.zeros(13, dtype=np.uint64)
    random.shuffle(test_list)
    for fi, fs in enumerate(test_list):
        sxyzs, slbls, sprobs = eval_room_probs(fs, sess, pls, ops, feed_dict)
        filename = 'data/S3DIS/room_block_10_10/' + fs
        points, labels = read_pkl(filename)
        # First three columns of the room's points as contiguous float32.
        qxyzs = np.ascontiguousarray(points[:, :3], np.float32)
        qn = qxyzs.shape[0]
        rn = 1000000
        # Number of chunks of at most rn points: the division is integer
        # division under Python 2, and the remainder check rounds it up.
        qrn = qn / rn
        if qn % rn != 0: qrn += 1
def train():
    """Train on the 'nolimits' sampled split, testing after every epoch.

    The training file list is shuffled once before the providers are
    created.  When ``FLAGS.restore`` is set the model is restored from
    ``FLAGS.restore_model``, otherwise all variables are initialised.
    Epochs run from ``FLAGS.restore_epoch`` up to ``FLAGS.train_epoch_num``.
    The summary writer, the session and both providers are released before
    returning.
    """
    import random
    from aug_util import flip, swap_xy
    train_names, test_names = get_block_train_test_split()
    train_list = [
        'data/S3DIS/sampled_train_nolimits/' + name for name in train_names
    ]
    # train_list=['data/S3DIS/sampled_train_no_aug/'+fn for fn in train_list]
    # with open('cached/s3dis_merged_train.txt', 'r') as f:
    #     train_list=[line.strip('\n') for line in f.readlines()]
    random.shuffle(train_list)
    test_list = [
        'data/S3DIS/sampled_test_nolimits/' + name for name in test_names
    ]

    def train_fn(model, filename):
        # Reader with augmentation: each block is flipped along axis 0,
        # flipped along axis 1 and has x/y swapped with probability 0.5
        # each, and its colours get uniform jitter in [-0.02, 0.02).
        # NOTE(review): this reader is never passed to a Provider below
        # (both use test_fn), so no augmentation is applied -- confirm
        # whether that is intentional.
        xyzs, rgbs, covars, lbls, block_mins = read_pkl(filename)

        num = len(xyzs)
        for i in xrange(num):
            # pt_num=len(xyzs[i])
            # ds_ratio=np.random.uniform(0.8,1.0)
            # idxs=np.random.choice(pt_num,int(ds_ratio*pt_num),False)
            #
            # xyzs[i]=xyzs[i][idxs]
            # rgbs[i]=rgbs[i][idxs]
            # covars[i]=covars[i][idxs]
            # lbls[i]=lbls[i][idxs]

            if random.random() < 0.5:
                xyzs[i] = flip(xyzs[i], axis=0)

            if random.random() < 0.5:
                xyzs[i] = flip(xyzs[i], axis=1)

            if random.random() < 0.5:
                xyzs[i] = swap_xy(xyzs[i])

            jitter_color = np.random.uniform(-0.02, 0.02, rgbs[i].shape)
            rgbs[i] += jitter_color

        return xyzs, rgbs, covars, lbls, block_mins

    def test_fn(model, filename):
        # Plain reader; the unpacking doubles as a check that the record
        # has five fields.
        xyzs, rgbs, covars, lbls, block_mins = read_pkl(filename)
        return xyzs, rgbs, covars, lbls, block_mins

    train_provider = Provider(train_list, 'train',
                              FLAGS.batch_size * FLAGS.num_gpus, test_fn)
    test_provider = Provider(test_list, 'test',
                             FLAGS.batch_size * FLAGS.num_gpus, test_fn)

    sess = None
    summary_writer = None
    try:
        pls = build_placeholder(FLAGS.num_gpus)
        # Floor division keeps this an integer on Python 2 and Python 3.
        batch_num_per_epoch = 2000 // FLAGS.num_gpus
        ops = train_ops(pls['xyzs'], pls['feats'], pls['lbls'],
                        pls['is_training'], batch_num_per_epoch)

        feed_dict = {}
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)
        saver = tf.train.Saver(max_to_keep=500)
        if FLAGS.restore:
            saver.restore(sess, FLAGS.restore_model)
        else:
            sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        for epoch_num in xrange(FLAGS.restore_epoch, FLAGS.train_epoch_num):
            train_one_epoch(ops, pls, sess, summary_writer, train_provider,
                            epoch_num, feed_dict)
            test_one_epoch(ops, pls, sess, saver, test_provider, epoch_num,
                           feed_dict)

    finally:
        # Release TensorFlow resources (closing the writer also flushes
        # pending events); the inner ``finally`` makes sure the providers
        # are shut down even if one of those close() calls fails.
        try:
            if summary_writer is not None:
                summary_writer.close()
            if sess is not None:
                sess.close()
        finally:
            train_provider.close()
            test_provider.close()