Ejemplo n.º 1
0
def voxelnet_init():
    """Create the module-level TensorFlow session and RPN3D model.

    Builds the session configuration from ``cfg``, stores the new session and
    model in the global names ``sess`` and ``model``, and restores the latest
    checkpoint from ``save_model_dir`` when a checkpoint state is found there.
    """
    global sess, model

    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
            visible_device_list=cfg.GPU_AVAILABLE,
            allow_growth=True),
        device_count={"GPU": cfg.GPU_USE_COUNT},
        allow_soft_placement=True,
    )
    sess = tf.Session(config=session_config)

    model = RPN3D(cls=cfg.DETECT_OBJ,
                  single_batch_size=args.single_batch_size,
                  avail_gpus=cfg.GPU_AVAILABLE.split(','))

    # No checkpoint state: nothing is restored by this function.
    if not tf.train.get_checkpoint_state(save_model_dir):
        return

    print("Reading model parameters from %s" % save_model_dir)
    model.saver.restore(sess, tf.train.latest_checkpoint(save_model_dir))
def main(_):
    """Convert ``frozen.pb`` into a TensorRT graph and write it to disk.

    Instantiates RPN3D in a fresh default graph to obtain its output node
    names, parses ``frozen.pb`` from ``save_model_dir``, passes it through
    ``trt.create_inference_graph`` with ``args.precision`` and stores the
    result as ``newFrozenModel_TRT_<precision>.pb`` in the same directory.
    Finally the default graph is written to the ``logs`` directory.

    Args:
        _: unused positional argument (presumably argv from ``tf.app.run`` —
            confirm against the caller).
    """
    with tf.Graph().as_default():

        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
                visible_device_list=cfg.GPU_AVAILABLE,
                allow_growth=True),
            device_count={"GPU": cfg.GPU_USE_COUNT},
            allow_soft_placement=True,
        )

        # The session is closed again right after the model is constructed.
        with tf.Session(config=session_config) as build_sess:
            model = RPN3D(cls=cfg.DETECT_OBJ,
                          single_batch_size=args.single_batch_size,
                          avail_gpus=cfg.GPU_AVAILABLE.split(','))

        # We need the names of the tensors, not of the ops, hence the ':0'.
        output_tensors = [op_name + ':0'
                          for op_name in model.get_output_nodes_names()]

        print(output_tensors)
        print("\n\n\n")

        with gfile.FastGFile(save_model_dir + "/frozen.pb", 'rb') as frozen_file:
            frozen_graph_def = tf.GraphDef()
            frozen_graph_def.ParseFromString(frozen_file.read())

            trt_graph = trt.create_inference_graph(
                input_graph_def=frozen_graph_def,
                outputs=output_tensors,
                max_batch_size=2,
                max_workspace_size_bytes=max_workspace_size_bytes,
                minimum_segment_size=6,
                precision_mode=args.precision)

            trt_pb_path = save_model_dir + "/newFrozenModel_TRT_{}.pb".format(
                args.precision)
            with gfile.FastGFile(trt_pb_path, 'wb') as trt_file:
                trt_file.write(trt_graph.SerializeToString())
                print("TRT graph written to path ", trt_pb_path)

            # Dump the session graph so it can be inspected in TensorBoard.
            with tf.Session() as log_sess:
                graph_writer = tf.summary.FileWriter('logs', log_sess.graph)
                graph_writer.close()
def main(_):
    """Restore RPN3D, run one prediction step and export a frozen graph.

    The model is built in a fresh default graph, the latest checkpoint in
    ``save_model_dir`` is restored when present, a single batch from
    ``test_dir`` is pushed through ``predict_step`` and the graph is then
    saved to ``frozen.pb`` inside ``save_model_dir``.

    Args:
        _: unused positional argument (presumably argv from ``tf.app.run`` —
            confirm against the caller).
    """
    with tf.Graph().as_default():
        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
                visible_device_list=cfg.GPU_AVAILABLE,
                allow_growth=True),
            device_count={"GPU": cfg.GPU_USE_COUNT},
            allow_soft_placement=True,
        )

        with tf.Session(config=session_config) as sess:
            model = RPN3D(cls=cfg.DETECT_OBJ,
                          single_batch_size=args.single_batch_size,
                          avail_gpus=cfg.GPU_AVAILABLE.split(','))

            # Parameter restore (skipped when no checkpoint state exists).
            checkpoint_found = tf.train.get_checkpoint_state(save_model_dir)
            if checkpoint_found:
                print("Reading model parameters from %s" % save_model_dir)
                latest = tf.train.latest_checkpoint(save_model_dir)
                model.saver.restore(sess, latest)

            # Just one run (per the original author: to initialize all the
            # variables); every batch after the first is ignored.
            for first_batch in iterate_data(
                    test_dir,
                    shuffle=False,
                    aug=False,
                    is_testset=True,
                    batch_size=args.single_batch_size * cfg.GPU_USE_COUNT,
                    multi_gpu_sum=cfg.GPU_USE_COUNT):
                _tags, _results = model.predict_step(sess,
                                                     first_batch,
                                                     summary=False,
                                                     vis=False)
                break

            model.save_frozen_graph(sess, save_model_dir + "/frozen.pb")
def main(_):
    """Load ``frozen.pb`` and dump the session graph for TensorBoard.

    Loads the frozen graph from ``save_model_dir`` via ``load_graph``, opens a
    session on it, instantiates RPN3D inside that session context and writes
    the session graph to the ``tensorboard_logs`` directory.

    Args:
        _: unused positional argument (presumably argv from ``tf.app.run`` —
            confirm against the caller).
    """
    with tf.Graph().as_default():

        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
                visible_device_list=cfg.GPU_AVAILABLE,
                allow_growth=True),
            device_count={"GPU": cfg.GPU_USE_COUNT},
            allow_soft_placement=True,
        )

        frozen_graph = load_graph(save_model_dir + "/frozen.pb")
        with tf.Session(config=session_config, graph=frozen_graph) as sess:
            model = RPN3D(cls=cfg.DETECT_OBJ,
                          single_batch_size=args.single_batch_size,
                          avail_gpus=cfg.GPU_AVAILABLE.split(','))

            graph_writer = tf.summary.FileWriter('tensorboard_logs', sess.graph)
            graph_writer.close()
Ejemplo n.º 5
0
def main(_):
    """Train RPN3D from KittiLoader queues with periodic summaries.

    Opens one loader on the ``training`` folder (augmentation on) and one on
    the ``testing`` folder (augmentation off) below ``dataset_dir``, restores
    the latest checkpoint from ``save_model_dir`` or initializes fresh
    variables, and trains until ``model.epoch`` reaches ``args.max_epoch``.
    Summaries are written to ``log_dir``. Checkpoints are saved periodically,
    when ``check_if_should_pause`` asks for a pause (followed by
    ``sys.exit(0)``), and once more after training finishes.

    Args:
        _: unused positional argument (presumably argv from ``tf.app.run`` —
            confirm against the caller).
    """
    # TODO: split file support
    with tf.Graph().as_default():
        global save_model_dir
        batch_size = args.single_batch_size * cfg.GPU_USE_COUNT

        with KittiLoader(object_dir=os.path.join(dataset_dir, 'training'),
                         queue_size=50,
                         require_shuffle=True,
                         is_testset=False,
                         batch_size=batch_size,
                         use_multi_process_num=8,
                         multi_gpu_sum=cfg.GPU_USE_COUNT,
                         aug=True) as train_loader, \
                KittiLoader(object_dir=os.path.join(dataset_dir, 'testing'),
                            queue_size=50,
                            require_shuffle=True,
                            is_testset=False,
                            batch_size=batch_size,
                            use_multi_process_num=8,
                            multi_gpu_sum=cfg.GPU_USE_COUNT,
                            aug=False) as valid_loader:

            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
                visible_device_list=cfg.GPU_AVAILABLE,
                allow_growth=True)
            config = tf.ConfigProto(
                gpu_options=gpu_options,
                device_count={
                    "GPU": cfg.GPU_USE_COUNT,
                },
                allow_soft_placement=True,
            )
            with tf.Session(config=config) as sess:
                model = RPN3D(
                    cls=cfg.DETECT_OBJ,
                    single_batch_size=args.single_batch_size,
                    learning_rate=args.lr,
                    max_gradient_norm=5.0,
                    is_train=True,
                    alpha=1.5,
                    beta=1,
                    avail_gpus=cfg.GPU_AVAILABLE.split(',')
                )
                # param init/restore
                if tf.train.get_checkpoint_state(save_model_dir):
                    print("Reading model parameters from %s" % save_model_dir)
                    model.saver.restore(
                        sess, tf.train.latest_checkpoint(save_model_dir))
                else:
                    print("Created model with fresh parameters.")
                    tf.global_variables_initializer().run()

                checkpoint_prefix = os.path.join(save_model_dir, 'checkpoint')

                # train and validate
                # Both intervals below are used as modulo divisors. Clamp them
                # to at least 1 so that a dataset yielding fewer than three
                # iterations per epoch (or fewer samples than one batch) does
                # not raise ZeroDivisionError.
                iter_per_epoch = max(1, int(len(train_loader) / batch_size))
                save_model_interval = max(1, int(iter_per_epoch / 3))

                summary_interval = 5
                summary_image_interval = 20
                validate_interval = 60

                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
                while model.epoch.eval() < args.max_epoch:
                    # `step` instead of `iter` to avoid shadowing the builtin.
                    step = model.global_step.eval()
                    is_summary = step % summary_interval == 0
                    is_summary_image = step % summary_image_interval == 0
                    is_validate = step % validate_interval == 0

                    if step % save_model_interval == 0:
                        model.saver.save(sess, checkpoint_prefix,
                                         global_step=model.global_step)
                    if step % iter_per_epoch == 0:
                        sess.run(model.epoch_add_op)
                        print('train {} epoch, total: {}'.format(
                            model.epoch.eval(), args.max_epoch))

                    ret = model.train_step(
                        sess, train_loader.load(), train=True,
                        summary=is_summary)
                    print('train: {}/{} @ epoch:{}/{} loss: {} reg_loss: {} cls_loss: {} {}'.format(
                        step, iter_per_epoch * args.max_epoch,
                        model.epoch.eval(), args.max_epoch,
                        ret[0], ret[1], ret[2], args.tag))

                    # The last element of each step result is passed to the
                    # writer as the summary.
                    if is_summary:
                        summary_writer.add_summary(ret[-1], step)

                    if is_summary_image:
                        ret = model.predict_step(
                            sess, valid_loader.load(), summary=True)
                        summary_writer.add_summary(ret[-1], step)

                    if is_validate:
                        ret = model.validate_step(
                            sess, valid_loader.load(), summary=True)
                        summary_writer.add_summary(ret[-1], step)

                    if check_if_should_pause(args.tag):
                        model.saver.save(sess, checkpoint_prefix,
                                         global_step=model.global_step)
                        print('pause and save model @ {} steps:{}'.format(
                            save_model_dir, model.global_step.eval()))
                        sys.exit(0)

                print('train done. total epoch:{} iter:{}'.format(
                    model.epoch.eval(), model.global_step.eval()))

                # finally save model
                model.saver.save(sess, checkpoint_prefix,
                                 global_step=model.global_step)
Ejemplo n.º 6
0
def main(_):
    """Train RPN3D epoch by epoch, with periodic validation and evaluation.

    Restores the latest checkpoint from ``save_model_dir`` (resuming at the
    stored epoch + 1) or initializes fresh variables, then iterates over
    ``train_dir`` for every epoch up to ``args.max_epoch``. Each training step
    is printed and appended to ``log/train.txt``; summaries go to ``log_dir``.
    A checkpoint is saved after every epoch. Every 10th epoch the predictions
    for ``val_dir`` are written below ``args.output_path/<epoch>`` and
    ``./kitti_eval/launch_test.sh`` is launched on them. A pause request from
    ``check_if_should_pause`` saves a checkpoint and calls ``sys.exit(0)``.

    Args:
        _: unused positional argument (presumably argv from ``tf.app.run`` —
            confirm against the caller).
    """
    # TODO: split file support
    with tf.Graph().as_default():
        global save_model_dir
        start_epoch = 0
        global_counter = 0
        batch_size = args.single_batch_size * cfg.GPU_USE_COUNT

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
            visible_device_list=cfg.GPU_AVAILABLE,
            allow_growth=True)
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            device_count={
                "GPU": cfg.GPU_USE_COUNT,
            },
            allow_soft_placement=True,
        )
        with tf.Session(config=config) as sess:
            # NOTE(review): avail_gpus is passed as the raw string here, not
            # split on ',' — confirm that this RPN3D expects that.
            model = RPN3D(
                cls=cfg.DETECT_OBJ,
                single_batch_size=args.single_batch_size,
                learning_rate=args.lr,
                max_gradient_norm=5.0,
                alpha=args.alpha,
                beta=args.beta,
                avail_gpus=cfg.GPU_AVAILABLE
            )
            # param init/restore
            if tf.train.get_checkpoint_state(save_model_dir):
                print("Reading model parameters from %s" % save_model_dir)
                model.saver.restore(sess,
                                    tf.train.latest_checkpoint(save_model_dir))
                start_epoch = model.epoch.eval() + 1
                global_counter = model.global_step.eval() + 1
            else:
                print("Created model with fresh parameters.")
                tf.global_variables_initializer().run()

            # train and validate
            summary_interval = 5
            summary_val_interval = 10
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            checkpoint_prefix = os.path.join(save_model_dir, 'checkpoint')

            # One template for both the console line and the log-file line.
            log_template = (
                'train: {} @ epoch:{}/{} loss: {:.4f} reg_loss: {:.4f} '
                'cls_loss: {:.4f} cls_pos_loss: {:.4f} cls_neg_loss: {:.4f} '
                'forward time: {:.4f} batch time: {:.4f}')

            # Bind `epoch` up front: when resuming with
            # start_epoch >= args.max_epoch the loop body never runs, and the
            # final report below would otherwise hit an unbound local.
            epoch = start_epoch

            # training
            for epoch in range(start_epoch, args.max_epoch):
                counter = 0
                batch_time = time.time()
                for batch in iterate_data(train_dir,
                                          shuffle=True,
                                          aug=True,
                                          is_testset=False,
                                          batch_size=batch_size,
                                          multi_gpu_sum=cfg.GPU_USE_COUNT):

                    counter += 1
                    global_counter += 1
                    is_summary = counter % summary_interval == 0

                    start_time = time.time()
                    ret = model.train_step(sess,
                                           batch,
                                           train=True,
                                           summary=is_summary)
                    forward_time = time.time() - start_time
                    # batch_time held the start timestamp up to this point.
                    batch_time = time.time() - batch_time

                    message = log_template.format(
                        counter, epoch, args.max_epoch, ret[0], ret[1],
                        ret[2], ret[3], ret[4], forward_time, batch_time)
                    print(message)
                    with open('log/train.txt', 'a') as f:
                        f.write(message + ' \n')

                    if is_summary:
                        print("summary_interval now")
                        summary_writer.add_summary(ret[-1], global_counter)

                    if counter % summary_val_interval == 0:
                        print("summary_val_interval now")
                        batch = sample_test_data(
                            val_dir,
                            batch_size,
                            multi_gpu_sum=cfg.GPU_USE_COUNT)

                        ret = model.validate_step(sess, batch, summary=True)
                        summary_writer.add_summary(ret[-1], global_counter)

                        # Prediction is best-effort. Catch Exception rather
                        # than using a bare except so that KeyboardInterrupt
                        # and SystemExit still stop the run.
                        try:
                            ret = model.predict_step(sess, batch, summary=True)
                            summary_writer.add_summary(ret[-1], global_counter)
                        except Exception:
                            print("prediction skipped due to error")

                    if check_if_should_pause(args.tag):
                        model.saver.save(sess,
                                         checkpoint_prefix,
                                         global_step=model.global_step)
                        print('pause and save model @ {} steps:{}'.format(
                            save_model_dir, model.global_step.eval()))
                        sys.exit(0)

                    # Restart the per-batch timer for the next iteration.
                    batch_time = time.time()

                sess.run(model.epoch_add_op)

                model.saver.save(sess,
                                 checkpoint_prefix,
                                 global_step=model.global_step)

                # dump test data every 10 epochs
                if (epoch + 1) % 10 == 0:
                    epoch_dir = os.path.join(args.output_path, str(epoch))

                    # create output folders
                    os.makedirs(epoch_dir, exist_ok=True)
                    os.makedirs(os.path.join(epoch_dir, 'data'),
                                exist_ok=True)
                    if args.vis:
                        os.makedirs(os.path.join(epoch_dir, 'vis'),
                                    exist_ok=True)

                    for batch in iterate_data(
                            val_dir,
                            shuffle=False,
                            aug=False,
                            is_testset=False,
                            batch_size=batch_size,
                            multi_gpu_sum=cfg.GPU_USE_COUNT):

                        if args.vis:
                            tags, results, front_images, bird_views, heatmaps = model.predict_step(
                                sess, batch, summary=False, vis=True)
                        else:
                            tags, results = model.predict_step(sess,
                                                               batch,
                                                               summary=False,
                                                               vis=False)

                        # One label file per tag.
                        for tag, result in zip(tags, results):
                            of_path = os.path.join(epoch_dir, 'data',
                                                   tag + '.txt')
                            with open(of_path, 'w+') as f:
                                labels = box3d_to_label([result[:, 1:8]],
                                                        [result[:, 0]],
                                                        [result[:, -1]],
                                                        coordinate='lidar')[0]
                                for line in labels:
                                    f.write(line)
                                print('write out {} objects to {}'.format(
                                    len(labels), tag))

                        # dump visualizations
                        if args.vis:
                            for tag, front_image, bird_view, heatmap in zip(
                                    tags, front_images, bird_views, heatmaps):
                                front_img_path = os.path.join(
                                    epoch_dir, 'vis', tag + '_front.jpg')
                                bird_view_path = os.path.join(
                                    epoch_dir, 'vis', tag + '_bv.jpg')
                                heatmap_path = os.path.join(
                                    epoch_dir, 'vis', tag + '_heatmap.jpg')
                                cv2.imwrite(front_img_path, front_image)
                                cv2.imwrite(bird_view_path, bird_view)
                                cv2.imwrite(heatmap_path, heatmap)

                    # execute evaluation code
                    # NOTE(review): the command is a space-joined shell
                    # string, so an output path containing spaces will be
                    # split by the shell.
                    cmd_1 = "./kitti_eval/launch_test.sh"
                    cmd_2 = epoch_dir
                    cmd_3 = os.path.join(epoch_dir, 'log')
                    os.system(" ".join([cmd_1, cmd_2, cmd_3]))

            print('train done. total epoch:{} iter:{}'.format(
                epoch, model.global_step.eval()))

            # finally save model
            model.saver.save(sess,
                             checkpoint_prefix,
                             global_step=model.global_step)
Ejemplo n.º 7
0
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
            visible_device_list=cfg.GPU_AVAILABLE,
            allow_growth=True)

        config = tf.ConfigProto(
            gpu_options=gpu_options,
            device_count={
                "GPU": cfg.GPU_USE_COUNT,
            },
            allow_soft_placement=True,
        )

        with tf.Session(config=config) as sess:
            model = RPN3D(cls=cfg.DETECT_OBJ,
                          single_batch_size=args.single_batch_size,
                          avail_gpus=cfg.GPU_AVAILABLE.split(','))
            if tf.train.get_checkpoint_state(save_model_dir):
                print("Reading model parameters from %s" % save_model_dir)
                model.saver.restore(sess,
                                    tf.train.latest_checkpoint(save_model_dir))

            test_count = 0
            for batch in iterate_data(train_dir,
                                      shuffle=False,
                                      aug=False,
                                      is_testset=False,
                                      batch_size=args.single_batch_size *
                                      cfg.GPU_USE_COUNT,
                                      multi_gpu_sum=cfg.GPU_USE_COUNT):
Ejemplo n.º 8
0
def main():
    """Train RPN3D with file logging, periodic validation and prediction.

    Creates ``./log/<TAG>`` and ``./save_model/<TAG>``, copies the model and
    config sources into the log directory, then restores a checkpoint
    (``cfg.LOAD_CHECKPT`` if set, otherwise the latest one in the save
    directory) or initializes fresh variables. For every epoch up to
    ``cfg.MAX_EPOCH`` it trains on the ``training`` folder of ``cfg.DATA_DIR``;
    each time the global step is a multiple of ``cfg.VALIDATE_INTERVAL`` it
    runs one validation batch and one best-effort prediction batch from the
    ``validation`` folder. A checkpoint is saved after every epoch.
    """
    # load config
    train_dataset_dir = os.path.join(cfg.DATA_DIR, "training")
    val_dataset_dir = os.path.join(cfg.DATA_DIR, "validation")
    eval_dataset_dir = os.path.join(cfg.DATA_DIR, "validation")
    save_model_dir = os.path.join("./save_model", cfg.TAG)
    log_dir = os.path.join("./log", cfg.TAG)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(save_model_dir, exist_ok=True)
    config = gpu_config()
    max_epoch = cfg.MAX_EPOCH
    batch_size = cfg.SINGLE_BATCH_SIZE * cfg.GPU_USE_COUNT

    # config logging
    logging.basicConfig(filename='./log/' + cfg.TAG + '/train.log',
                        level=logging.ERROR)
    # Keep a copy of the sources used for this run in log/$TAG/.
    copyfile("model/model.py", os.path.join(log_dir, "model.py"))
    copyfile("model/group_pointcloud.py",
             os.path.join(log_dir, "group_pointcloud.py"))
    copyfile("model/rpn.py", os.path.join(log_dir, "rpn.py"))
    copyfile("config.py", os.path.join(log_dir, "config.py"))

    def new_val_generator():
        """Return a fresh shuffled batch generator over the validation set."""
        return iterate_data(val_dataset_dir,
                            shuffle=True,
                            aug=False,
                            is_testset=False,
                            batch_size=batch_size,
                            multi_gpu_sum=cfg.GPU_USE_COUNT)

    print("tag: {}".format(cfg.TAG))
    logging.critical("tag: {}".format(cfg.TAG))
    with tf.Session(config=config) as sess:
        # load model
        model = RPN3D(cls=cfg.DETECT_OBJ,
                      single_batch_size=cfg.SINGLE_BATCH_SIZE,
                      is_training=True,
                      learning_rate=cfg.LR,
                      max_gradient_norm=5.0,
                      alpha=cfg.ALPHA,
                      beta=cfg.BETA,
                      gamma=cfg.GAMMA,
                      avail_gpus=cfg.GPU_AVAILABLE.split(','))
        saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                               max_to_keep=10,
                               pad_step_number=True,
                               keep_checkpoint_every_n_hours=1.0)
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # param init/restore
        if cfg.LOAD_CHECKPT is not None:
            print("Reading model parameters from {}, {}".format(
                save_model_dir, cfg.LOAD_CHECKPT))
            logging.critical("Reading model parameters from {}, {}".format(
                save_model_dir, cfg.LOAD_CHECKPT))
            saver.restore(sess, os.path.join(save_model_dir, cfg.LOAD_CHECKPT))
            start_epoch = model.epoch.eval() + 1
        elif tf.train.get_checkpoint_state(save_model_dir):
            print("Reading model parameters from %s" % save_model_dir)
            logging.critical("Reading model parameters from %s" %
                             save_model_dir)
            saver.restore(sess, tf.train.latest_checkpoint(save_model_dir))
            start_epoch = model.epoch.eval() + 1
        else:
            print("Created model with fresh parameters.")
            logging.critical("Created model with fresh parameters.")
            tf.global_variables_initializer().run()
            start_epoch = 0

        # train
        for epoch in range(start_epoch, max_epoch):
            # load data
            data_generator_train = iterate_data(
                train_dataset_dir,
                shuffle=True,
                aug=False,
                is_testset=False,
                batch_size=batch_size,
                multi_gpu_sum=cfg.GPU_USE_COUNT)
            data_generator_val = new_val_generator()

            for batch in data_generator_train:
                # train
                ret = model.train_step(sess, batch, train=True, summary=True)
                step = model.global_step.eval()
                output_log(epoch,
                           step,
                           pos_cls_loss=ret[3],
                           neg_cls_loss=ret[4],
                           cls_loss=ret[2],
                           reg_loss=ret[1],
                           loss=ret[0],
                           phase="train",
                           logger=logging)
                summary_writer.add_summary(ret[-1], step)

                if step % cfg.VALIDATE_INTERVAL != 0:
                    continue

                # val
                # The validation generator can run out before the training
                # one does. Restart it instead of letting StopIteration
                # abort the whole training run.
                try:
                    val_batch = next(data_generator_val)
                except StopIteration:
                    data_generator_val = new_val_generator()
                    val_batch = next(data_generator_val, None)

                if val_batch is None:
                    # Even a fresh generator produced nothing.
                    print("validation skipped: no validation batch available")
                else:
                    ret = model.validate_step(sess, val_batch, summary=True)
                    output_log(epoch,
                               model.global_step.eval(),
                               pos_cls_loss=ret[3],
                               neg_cls_loss=ret[4],
                               cls_loss=ret[2],
                               reg_loss=ret[1],
                               loss=ret[0],
                               phase="validation",
                               logger=logging)
                    summary_writer.add_summary(ret[-1],
                                               model.global_step.eval())

                # eval
                eval_batch = sample_test_data(
                    eval_dataset_dir,
                    batch_size,
                    multi_gpu_sum=cfg.GPU_USE_COUNT)
                # Prediction is best-effort. Catch Exception rather than
                # using a bare except so that KeyboardInterrupt and
                # SystemExit still stop the run.
                try:
                    ret = model.predict_step(sess,
                                             eval_batch,
                                             summary=True)
                    summary_writer.add_summary(ret[-1],
                                               model.global_step.eval())
                except Exception:
                    print("prediction skipped due to error")

            sess.run(model.epoch_add_op)
            # NOTE(review): checkpoints are restored through `saver` above
            # but written through `model.saver` here, so the max_to_keep /
            # pad_step_number settings of `saver` do not apply to saving —
            # confirm this is intended.
            model.saver.save(sess,
                             os.path.join(save_model_dir, 'checkpoint'),
                             global_step=model.global_step)
        print('{} Training Done!'.format(cfg.TAG))
        logging.critical('{} Training Done!'.format(cfg.TAG))
Ejemplo n.º 9
0
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
        visible_device_list=GPU_AVAILABLE,
        allow_growth=True)

    config = tf.ConfigProto(gpu_options=gpu_options,
                            device_count={
                                "GPU": GPU_USE_COUNT,
                            },
                            allow_soft_placement=True,
                            log_device_placement=True)

    with tf.Session(config=config) as sess:
        model = RPN3D(cls=cfg.DETECT_OBJ,
                      single_batch_size=bs,
                      avail_gpus=GPU_AVAILABLE)
        if tf.train.get_checkpoint_state(save_model_dir):
            print("Reading model parameters from %s" % save_model_dir)
            model.saver.restore(sess,
                                tf.train.latest_checkpoint(save_model_dir))
        counter = 0
        #         with experiment.test():
        for batch in iterate_data(val_dir,
                                  shuffle=False,
                                  aug=False,
                                  is_testset=False,
                                  batch_size=bs * GPU_USE_COUNT,
                                  multi_gpu_sum=GPU_USE_COUNT):
            #             experiment.log_metric("counter",counter)
Ejemplo n.º 10
0
def main(_):
    """Run RPN3D over ``test_dir`` and write predictions below ``res_dir``.

    Builds the model in a fresh default graph, restores the latest checkpoint
    from ``save_model_dir`` when present, and for every batch writes one
    ``<tag>.txt`` label file into ``res_dir/data``. When ``args.vis`` is set,
    the front image, bird view and heatmap of each sample are additionally
    written to ``res_dir/vis`` as JPEG files.

    Args:
        _: unused positional argument (presumably argv from ``tf.app.run`` —
            confirm against the caller).
    """
    with tf.Graph().as_default():

        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
                visible_device_list=cfg.GPU_AVAILABLE,
                allow_growth=True),
            device_count={"GPU": cfg.GPU_USE_COUNT},
            allow_soft_placement=True,
        )

        with tf.Session(config=session_config) as sess:
            model = RPN3D(cls=cfg.DETECT_OBJ,
                          decrease=args.decrease,
                          minimize=args.minimize,
                          single_batch_size=args.single_batch_size,
                          avail_gpus=cfg.GPU_AVAILABLE.split(','))

            # Parameter restore (skipped when no checkpoint state exists).
            if tf.train.get_checkpoint_state(save_model_dir):
                print("Reading model parameters from %s" % save_model_dir)
                latest = tf.train.latest_checkpoint(save_model_dir)
                model.saver.restore(sess, latest)

            for batch in iterate_data(
                    test_dir,
                    shuffle=False,
                    aug=False,
                    is_testset=True,
                    batch_size=args.single_batch_size * cfg.GPU_USE_COUNT,
                    multi_gpu_sum=cfg.GPU_USE_COUNT):

                # With vis=True the prediction step also returns the images.
                if args.vis:
                    tags, results, front_images, bird_views, heatmaps = model.predict_step(
                        sess, batch, summary=False, vis=True)
                else:
                    tags, results = model.predict_step(sess,
                                                       batch,
                                                       summary=False,
                                                       vis=False)

                # One label file per tag.
                for tag, result in zip(tags, results):
                    label_path = os.path.join(res_dir, 'data', tag + '.txt')
                    with open(label_path, 'w+') as label_file:
                        labels = box3d_to_label([result[:, 1:8]],
                                                [result[:, 0]],
                                                [result[:, -1]],
                                                coordinate='lidar')[0]
                        for label_line in labels:
                            label_file.write(label_line)
                        print('write out {} objects to {}'.format(
                            len(labels), tag))

                if not args.vis:
                    continue

                # dump visualizations
                for tag, front_image, bird_view, heatmap in zip(
                        tags, front_images, bird_views, heatmaps):
                    cv2.imwrite(
                        os.path.join(res_dir, 'vis', tag + '_front.jpg'),
                        front_image)
                    cv2.imwrite(
                        os.path.join(res_dir, 'vis', tag + '_bv.jpg'),
                        bird_view)
                    cv2.imwrite(
                        os.path.join(res_dir, 'vis', tag + '_heatmap.jpg'),
                        heatmap)
Ejemplo n.º 11
0
def main():
    """Evaluate every saved checkpoint on the validation split.

    For each checkpoint name returned by ``get_ckpt_list(save_model_dir)``
    whose prediction directory under ``./predicts-all/<TAG>/`` does not exist
    yet, the checkpoint is restored, predictions for the validation data are
    written as label files (``data/``) and images (``vis/``), and the external
    ``./kitti_eval/evaluate_object_3d_offline`` tool is run with its output
    redirected to ``cmd.log`` inside that directory.
    """
    # load config
    eval_dataset_dir = os.path.join(cfg.DATA_DIR, "validation")
    save_model_dir = os.path.join("./save_model", cfg.TAG)
    log_dir = os.path.join("./log", cfg.TAG)
    predall_dir = os.path.join("./predicts-all", cfg.TAG)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(predall_dir, exist_ok=True)
    config = gpu_config()

    # config logging
    # The level is ERROR, so the progress messages below are emitted with
    # logging.critical to make sure they end up in the log file.
    logging.basicConfig(filename=os.path.join(log_dir, 'test_all.log'),
                        level=logging.ERROR)

    with tf.Session(config=config) as sess:
        # load model
        model = RPN3D(cls=cfg.DETECT_OBJ,
                      single_batch_size=cfg.SINGLE_BATCH_SIZE,
                      is_training=False,
                      learning_rate=cfg.LR,
                      max_gradient_norm=5.0,
                      alpha=cfg.ALPHA,
                      beta=cfg.BETA,
                      gamma=cfg.GAMMA,
                      avail_gpus=cfg.GPU_AVAILABLE.split(','))
        saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                               max_to_keep=10,
                               pad_step_number=True,
                               keep_checkpoint_every_n_hours=1.0)
        # param init/restore
        print("Reading model parameters from %s" % save_model_dir)
        logging.critical("Reading model parameters from %s", save_model_dir)
        for ckpt in get_ckpt_list(save_model_dir):
            ckpt_msg = "Checkpoint: {}".format(ckpt)
            print(ckpt_msg)
            logging.critical(ckpt_msg)
            pred_dir = os.path.join(predall_dir, ckpt)
            if os.path.isdir(pred_dir):
                # A prediction directory already exists for this checkpoint.
                continue
            os.makedirs(pred_dir, exist_ok=True)
            os.makedirs(os.path.join(pred_dir, "data"), exist_ok=True)
            os.makedirs(os.path.join(pred_dir, "vis"), exist_ok=True)
            saver.restore(sess, os.path.join(save_model_dir, ckpt))
            # load data
            data_generator_test = iterate_data(
                eval_dataset_dir,
                shuffle=False,
                aug=False,
                is_testset=False,
                batch_size=cfg.SINGLE_BATCH_SIZE * cfg.GPU_USE_COUNT,
                multi_gpu_sum=cfg.GPU_USE_COUNT)
            for batch in data_generator_test:
                tags, results, front_images, bird_views, heatmaps = model.predict_step(
                    sess, batch, summary=False, vis=True)

                # One label file per sample in the batch.
                for tag, result in zip(tags, results):
                    of_path = os.path.join(pred_dir, 'data', tag + '.txt')
                    labels = box3d_to_label([result[:, 1:8]],
                                            [result[:, 0]],
                                            [result[:, -1]],
                                            coordinate='lidar')[0]
                    with open(of_path, 'w') as f:
                        f.writelines(labels)
                    written_msg = 'write out {} objects to {}'.format(
                        len(labels), tag)
                    print(written_msg)
                    logging.critical(written_msg)

                # Dump visualizations once per batch. This loop used to be
                # nested in the loop above, which rewrote every image of the
                # batch once per result.
                for tag, front_image, bird_view, heatmap in zip(
                        tags, front_images, bird_views, heatmaps):
                    front_img_path = os.path.join(pred_dir, 'vis',
                                                  tag + '_front.jpg')
                    bird_view_path = os.path.join(pred_dir, 'vis',
                                                  tag + '_bv.jpg')
                    heatmap_path = os.path.join(pred_dir, 'vis',
                                                tag + '_heatmap.jpg')
                    cv2.imwrite(front_img_path, front_image)
                    cv2.imwrite(bird_view_path, bird_view)
                    cv2.imwrite(heatmap_path, heatmap)

            print('{} Testing Done! Starting Evaluation!'.format(cfg.TAG +
                                                                 ckpt))
            logging.critical(
                '{} Testing Done!  Starting Evaluation!'.format(cfg.TAG +
                                                                ckpt))
            # ./kitti_eval/evaluate_object_3d_offline /usr/app/TuneDataKitti/validation/label_2/ ./predicts/$TAG/ > ./predicts/$TAG/cmd.log
            cmd = " ".join([
                "./kitti_eval/evaluate_object_3d_offline",
                os.path.join(eval_dataset_dir, "label_2/"),
                pred_dir,
                ">",
                os.path.join(pred_dir, "cmd.log"),
            ])
            status = os.system(cmd)
            if status != 0:
                # Best effort: record the failure but keep going with the
                # remaining checkpoints.
                logging.error("evaluation command returned status %s: %s",
                              status, cmd)
            print('{} Evaluation Done!'.format(cfg.TAG + ckpt))
            logging.critical('{} Evaluation Done!'.format(cfg.TAG + ckpt))
        print('{} Testing Done!'.format(cfg.TAG))
        logging.critical('{} Testing Done!'.format(cfg.TAG))
Ejemplo n.º 12
0
 with KittiLoader(object_dir=dataset_dir, queue_size=100, require_shuffle=False, is_testset=True, batch_size=args.single_batch_size*cfg.GPU_USE_COUNT, use_multi_process_num=8, multi_gpu_sum=cfg.GPU_USE_COUNT) as test_loader:
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION, 
         visible_device_list=cfg.GPU_AVAILABLE,
         allow_growth=True)
     config = tf.ConfigProto(
         gpu_options=gpu_options,
         device_count={
             "GPU" : cfg.GPU_USE_COUNT,  
         allow_soft_placement=True,
         }
     )
     with tf.Session(config=config) as sess:
         model = RPN3D(
             cls=cfg.DETECT_OBJ,
             batch_size=1,
             tag=args.tag,
             is_train=False,
             avail_gpus=cfg.GPU_AVAILABLE.split(',')
         )
         while True:
             data = test_loader.load()
             if data is None:
                 print('test done.')
                 break
             ret = model.predict_step(sess, data)
             # ret: A, B
             # A: (N) tag
             # B: (N, N') (class, x, y, z, h, w, l, rz, score)
             for tag, result in zip(*ret):
                 of_path = os.path.join(args.output_path, tag + '.txt')
                 with open(of_path, 'w+') as f:
Ejemplo n.º 13
0
def main(_):
    """Train RPN3D with periodic summaries, checkpoints and prediction dumps.

    The latest checkpoint in ``save_model_dir`` is restored when one exists;
    otherwise all variables are freshly initialized. Training then runs up to
    ``args.max_epoch`` epochs over ``train_dir``. A checkpoint is saved after
    every epoch, and every 10th epoch the predictions on ``val_dir`` are
    written below ``res_dir/<epoch>/data``. When
    ``check_if_should_pause(args.tag)`` is true the model is saved and the
    process exits.

    Args:
        _: Unused (presumably the argv list handed over by the app runner;
            confirm against the caller).
    """
    with tf.Graph().as_default():

        start_epoch = 0
        global_counter = 0

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
            visible_device_list=cfg.GPU_AVAILABLE,
            allow_growth=True)

        config = tf.ConfigProto(
            gpu_options=gpu_options,
            device_count={
                "GPU": cfg.GPU_USE_COUNT,
            },
            allow_soft_placement=True,
        )

        with tf.Session(config=config) as sess:
            model = RPN3D(cls=cfg.DETECT_OBJ,
                          decrease=args.decrease,
                          minimize=args.minimize,
                          single_batch_size=args.single_batch_size,
                          learning_rate=args.lr,
                          max_gradient_norm=5.0,
                          alpha=args.alpha,
                          beta=args.beta,
                          avail_gpus=cfg.GPU_AVAILABLE.split(','))

            # param init/restore
            if tf.train.get_checkpoint_state(save_model_dir):
                print("Reading model parameters from %s" % save_model_dir)
                model.saver.restore(sess,
                                    tf.train.latest_checkpoint(save_model_dir))
                # NOTE(review): epoch_add_op is run after every finished
                # epoch below, so if it increments model.epoch by one the
                # extra +1 here would skip an epoch on resume. Confirm.
                start_epoch = model.epoch.eval() + 1
                global_counter = model.global_step.eval() + 1
            else:
                print("Created model with fresh parameters.")
                tf.global_variables_initializer().run()

            # train and validate
            summary_interval = 5
            summary_val_interval = 10
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # The writer is closed on every exit path, including the
            # sys.exit(0) taken when a pause is requested.
            try:
                # training
                for epoch in range(start_epoch, args.max_epoch):
                    counter = 0
                    batch_time = time.time()
                    for batch in iterate_data(
                            train_dir,
                            shuffle=True,
                            aug=True,
                            is_testset=False,
                            batch_size=args.single_batch_size *
                            cfg.GPU_USE_COUNT,
                            multi_gpu_sum=cfg.GPU_USE_COUNT):

                        counter += 1
                        global_counter += 1

                        is_summary = counter % summary_interval == 0

                        start_time = time.time()
                        ret = model.train_step(sess,
                                               batch,
                                               train=True,
                                               summary=is_summary)
                        forward_time = time.time() - start_time
                        batch_time = time.time() - batch_time

                        progress = (
                            'train: {} @ epoch:{}/{} loss: {:.4f} reg_loss: {:.4f} cls_loss: {:.4f} cls_pos_loss: {:.4f} cls_neg_loss: {:.4f} forward time: {:.4f} batch time: {:.4f}'
                            .format(counter, epoch + 1, args.max_epoch,
                                    ret[0], ret[1], ret[2], ret[3], ret[4],
                                    forward_time, batch_time))
                        print(progress)
                        with open(os.path.join('log', 'train.txt'), 'a') as f:
                            f.write(progress + ' \n')

                        if is_summary:
                            print("summary_interval now")
                            summary_writer.add_summary(ret[-1], global_counter)

                        if counter % summary_val_interval == 0:
                            print("summary_val_interval now")
                            val_batch = sample_test_data(
                                val_dir,
                                args.single_batch_size * cfg.GPU_USE_COUNT,
                                multi_gpu_sum=cfg.GPU_USE_COUNT)

                            val_ret = model.validate_step(sess,
                                                          val_batch,
                                                          summary=True)
                            summary_writer.add_summary(val_ret[-1],
                                                       global_counter)

                        if check_if_should_pause(args.tag):
                            model.saver.save(sess,
                                             os.path.join(save_model_dir,
                                                          'checkpoint'),
                                             global_step=model.global_step)
                            print('pause and save model @ {} steps:{}'.format(
                                save_model_dir, model.global_step.eval()))
                            sys.exit(0)

                        batch_time = time.time()

                    sess.run(model.epoch_add_op)

                    model.saver.save(sess,
                                     os.path.join(save_model_dir,
                                                  'checkpoint'),
                                     global_step=model.global_step)

                    # dump test data every 10 epochs
                    if (epoch + 1) % 10 == 0:
                        dump_dir = os.path.join(res_dir, str(epoch), 'data')
                        # Creates res_dir/<epoch> as well.
                        os.makedirs(dump_dir, exist_ok=True)

                        for val_batch in iterate_data(
                                val_dir,
                                shuffle=False,
                                aug=False,
                                is_testset=False,
                                batch_size=args.single_batch_size *
                                cfg.GPU_USE_COUNT,
                                multi_gpu_sum=cfg.GPU_USE_COUNT):

                            tags, results = model.predict_step(sess,
                                                               val_batch,
                                                               summary=False,
                                                               vis=False)

                            for tag, result in zip(tags, results):
                                of_path = os.path.join(dump_dir, tag + '.txt')
                                labels = box3d_to_label(
                                    [result[:, 1:8]],
                                    [result[:, 0]],
                                    [result[:, -1]],
                                    coordinate='lidar')[0]
                                with open(of_path, 'w') as f:
                                    f.writelines(labels)
                                print('write out {} objects to {}'.format(
                                    len(labels), tag))

                # finally save model
                model.saver.save(sess,
                                 os.path.join(save_model_dir, 'checkpoint'),
                                 global_step=model.global_step)
            finally:
                summary_writer.close()
Ejemplo n.º 14
0
def main(_):
    """Train RPN3D from a KittiLoader, with optional test/validation passes.

    The latest checkpoint in ``save_model_dir`` is restored when one exists;
    otherwise all variables are freshly initialized. If the module-level
    ``is_test`` flag is set, a test pass over the validation set is run
    first. The training loop then runs until ``model.epoch`` reaches
    ``args.max_epoch``, driving summaries, checkpoints, validation and the
    optional false-patch extraction from the model's global step.

    Args:
        _: Unused (presumably the argv list handed over by the app runner;
            confirm against the caller).
    """
    # TODO: split file support
    warn("main start")
    with tf.Graph().as_default():
        global save_model_dir
        with KittiLoader(object_dir=os.path.join(dataset_dir, dataset),
                         queue_size=16,
                         require_shuffle=shuffle,
                         is_testset=False,
                         batch_size=args.single_batch_size * cfg.GPU_USE_COUNT,
                         use_multi_process_num=8,
                         split_file=split_file,
                         valid_file=valid_file,
                         multi_gpu_sum=cfg.GPU_USE_COUNT) as train_loader:

            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
                visible_device_list=cfg.GPU_AVAILABLE,
                allow_growth=True)
            config = tf.ConfigProto(
                gpu_options=gpu_options,
                device_count={
                    "GPU": cfg.GPU_USE_COUNT,
                },
                allow_soft_placement=True,
            )
            with tf.Session(config=config) as sess:
                model = RPN3D(cls=cfg.DETECT_OBJ,
                              single_batch_size=args.single_batch_size,
                              learning_rate=args.lr,
                              max_gradient_norm=5.0,
                              is_train=True,
                              alpha=1.5,
                              beta=1.0,
                              avail_gpus=cfg.GPU_AVAILABLE.split(','))
                # param init/restore
                if tf.train.get_checkpoint_state(save_model_dir):
                    model.saver.restore(
                        sess, tf.train.latest_checkpoint(save_model_dir))
                    warn("loading done")
                else:
                    warn("Created model with fresh parameters.")
                    tf.global_variables_initializer().run()

                # train and validate
                # At least 1, so the modulo below cannot divide by zero when
                # the loader holds fewer samples than one full batch.
                iter_per_epoch = max(
                    1,
                    int(len(train_loader) /
                        (args.single_batch_size * cfg.GPU_USE_COUNT)))

                summary_interval = 50
                summary_image_interval = 50
                save_model_interval = 50
                validate_interval = 8000

                if is_test:
                    total_iter = int(
                        np.ceil(train_loader.validset_size /
                                train_loader.batch_size))
                    save_result_folder = os.path.join(save_result_dir,
                                                      "test_result")
                    mkdir_p(save_result_folder)
                    for idx in range(total_iter):
                        t0 = time.time()
                        # Only logged here: test_step is fed from load().
                        start_idx, end_idx = _batch_bounds(
                            train_loader.batch_size,
                            train_loader.validset_size, idx)
                        warn("start: {} end: {}".format(start_idx, end_idx))

                        model.test_step(sess,
                                        train_loader.load(),
                                        output_path=save_result_folder,
                                        summary=True,
                                        visualize=True)
                        t1 = time.time()
                        warn("test: {:.2f} sec | remaining {:.2f} sec {}/{}".
                             format(t1 - t0, (t1 - t0) * (total_iter - idx),
                                    idx, total_iter))

                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                # t0 is read at the first epoch boundary (step 0) before any
                # train step has set it, so give it a defined value here.
                t0 = time.time()

                # Close the writer on every exit path.
                try:
                    while model.epoch.eval() < args.max_epoch:
                        progress = model.epoch.eval() / args.max_epoch
                        train_loader.progress = progress

                        step = model.global_step.eval()
                        is_summary = step % summary_interval == 0
                        is_summary_image = step % summary_image_interval == 0
                        is_validate = step % validate_interval == 0
                        if step % save_model_interval == 0:
                            model.saver.save(sess,
                                             os.path.join(save_model_dir,
                                                          'checkpoint'),
                                             global_step=model.global_step)
                        if step % iter_per_epoch == 0:
                            sess.run(model.epoch_add_op)
                            t1 = time.time()
                            warn("train: {}".format(t1 - t0))
                            print('train {} epoch, total: {}'.format(
                                model.epoch.eval(), args.max_epoch))
                        t0 = time.time()
                        ret = model.train_step(sess,
                                               train_loader.load(),
                                               train=True,
                                               summary=is_summary)
                        t1 = time.time()
                        warn("train: {}".format(t1 - t0))
                        print(
                            'train: {:.2f} / 1 : {}/{} @ epoch:{}/{} {:.2f} sec, remaining: {:.2f} min, loss: {:.2f} corner_loss: {:.2f} reg_loss: {:.2f} cls_loss: {:.2f} {}'
                            .format(train_loader.progress, step,
                                    iter_per_epoch * args.max_epoch,
                                    model.epoch.eval(), args.max_epoch,
                                    t1 - t0,
                                    (t1 - t0) *
                                    (iter_per_epoch * args.max_epoch - step)
                                    // 60,
                                    ret[0], ret[3], ret[1], ret[2], args.tag))

                        if is_summary:
                            summary_writer.add_summary(ret[-1], step)

                        if is_summary_image:
                            t0 = time.time()
                            ret = model.predict_step(
                                sess,
                                train_loader.load_specified(),
                                step,
                                summary=True)
                            summary_writer.add_summary(ret[-1], step)
                            t1 = time.time()
                            warn("predict: {}".format(t1 - t0))

                        if is_validate:
                            total_iter = int(
                                np.ceil(train_loader.validset_size /
                                        train_loader.batch_size))
                            save_result_folder = os.path.join(
                                save_result_dir, "{}".format(step))
                            mkdir_p(save_result_folder)
                            for idx in range(total_iter):
                                t0 = time.time()
                                start_idx, end_idx = _batch_bounds(
                                    train_loader.batch_size,
                                    train_loader.validset_size, idx)
                                warn("start: {} end: {}".format(
                                    start_idx, end_idx))

                                model.validate_step(
                                    sess,
                                    train_loader.load_specified(
                                        np.arange(start_idx, end_idx)),
                                    output_path=save_result_folder,
                                    summary=True,
                                    visualize=False)
                                t1 = time.time()
                                warn(
                                    "valid: {:.2f} sec | remaining {:.2f} sec {}/{}"
                                    .format(t1 - t0,
                                            (t1 - t0) * (total_iter - idx),
                                            idx, total_iter))
                            cmd = "./evaluate_object {}".format(step)
                            os.system(cmd)

                        # NOTE(review): when this flag is set the whole
                        # training set is processed on every training step.
                        # Confirm that this is intended.
                        if extract_false_patch:
                            warn("Extracting false patch from Train set")
                            total_iter = int(
                                np.ceil(train_loader.dataset_size /
                                        train_loader.batch_size))
                            save_result_folder = os.path.join(
                                save_result_dir,
                                "{}_false_patch".format(step))
                            mkdir_p(save_result_folder)
                            for idx in range(total_iter):
                                t0 = time.time()
                                start_idx, end_idx = _batch_bounds(
                                    train_loader.batch_size,
                                    train_loader.dataset_size, idx)
                                warn("start: {} end: {}".format(
                                    start_idx, end_idx))

                                model.false_patch_step(
                                    sess,
                                    train_loader.load_specified_train(
                                        np.arange(start_idx, end_idx)),
                                    output_path=save_result_folder,
                                    summary=True,
                                    visualize=False)
                                t1 = time.time()
                                warn(
                                    "valid: {:.2f} sec | remaining {:.2f} sec {}/{}"
                                    .format(t1 - t0,
                                            (t1 - t0) * (total_iter - idx),
                                            idx, total_iter))

                    print('train done. total epoch:{} iter:{}'.format(
                        model.epoch.eval(), model.global_step.eval()))

                    # finally save model
                    model.saver.save(sess,
                                     os.path.join(save_model_dir,
                                                  'checkpoint'),
                                     global_step=model.global_step)
                finally:
                    summary_writer.close()


def _batch_bounds(batch_size, total, idx):
    """Return ``(start, end)`` sample indices for batch number ``idx``.

    A final batch that would run past ``total`` is shifted back so that it
    still spans ``batch_size`` samples and ends exactly at ``total``.
    """
    if batch_size * (idx + 1) > total:
        return total - batch_size, total
    return batch_size * idx, batch_size * (idx + 1)
Ejemplo n.º 15
0
        # Session setup and prediction loop of a test script. The enclosing
        # function header is not part of this snippet.
        # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
        #                     visible_device_list=cfg.TESTING_GPU,
        #                     allow_growth=True)
        # Only the device list is set here; the memory-fraction and
        # allow_growth options of the commented-out variant above are not.
        gpu_options = tf.GPUOptions(visible_device_list=cfg.TESTING_GPU)
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            device_count={
                "GPU": cfg.TESTING_GPU_USE_COUNT,
            },
            allow_soft_placement=True,
        )

        with tf.Session(config=config) as sess:
            model = RPN3D(cls=cfg.DETECT_OBJ,
                          single_batch_size=args.single_batch_size,
                          avail_gpus=cfg.TESTING_GPU.split(','))
            # Restore the latest checkpoint when save_model_dir has one.
            # NOTE(review): there is no else branch, so nothing is restored
            # or initialized here when no checkpoint is found. Confirm that
            # this is intended.
            if tf.train.get_checkpoint_state(save_model_dir):
                print("Reading model parameters from %s" % save_model_dir)
                model.saver.restore(sess,
                                    tf.train.latest_checkpoint(save_model_dir))

            # Iterate over val_dir without shuffling or augmentation; the
            # batch size is the per-GPU batch size times the GPU count.
            for batch in iterate_data(val_dir,
                                      shuffle=False,
                                      aug=False,
                                      is_testset=False,
                                      batch_size=args.single_batch_size *
                                      cfg.TESTING_GPU_USE_COUNT,
                                      multi_gpu_sum=cfg.TESTING_GPU_USE_COUNT):

                # NOTE(review): the snippet is cut off after this line; the
                # body of the branch is not visible.
                if args.vis:
Ejemplo n.º 16
0
def main(_):
    """Train the RPN3D model with periodic summaries, validation and checkpoints.

    Builds a fresh graph and a GPU-configured session, restores the latest
    checkpoint found in ``save_model_dir`` (or initializes all variables when
    there is none), then trains from the restored epoch up to
    ``args.max_epoch``. A checkpoint is written after every epoch, once more
    at the end, and also right before exiting when
    ``check_if_should_pause(args.tag)`` returns true.

    Args:
        _: Unused positional argument (presumably the argv list forwarded by
            ``tf.app.run`` -- confirm against the caller).
    """
    # TODO: split file support
    with tf.Graph().as_default():
        start_epoch = 0
        global_counter = 0

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
            visible_device_list=cfg.GPU_AVAILABLE,
            allow_growth=True)
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            device_count={
                "GPU": cfg.GPU_USE_COUNT,
            },
            allow_soft_placement=True,
        )
        with tf.Session(config=config) as sess:
            model = RPN3D(cls=cfg.DETECT_OBJ,
                          single_batch_size=args.single_batch_size,
                          learning_rate=args.lr,
                          max_gradient_norm=5.0,
                          is_train=True,
                          alpha=args.alpha,
                          beta=args.beta,
                          avail_gpus=cfg.GPU_AVAILABLE.split(','))
            # param init/restore
            if tf.train.get_checkpoint_state(save_model_dir):
                print("Reading model parameters from %s" % save_model_dir)
                model.saver.restore(sess,
                                    tf.train.latest_checkpoint(save_model_dir))
                start_epoch = model.epoch.eval() + 1
                global_counter = model.global_step.eval() + 1
            else:
                print("Created model with fresh parameters.")
                tf.global_variables_initializer().run()

            # How often (in training iterations within an epoch) to write a
            # training summary and to run a validation/prediction pass.
            summary_interval = 5
            summary_val_interval = 10
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # Pre-bind `epoch` so the final report below still works when the
            # restored checkpoint is already at or past args.max_epoch and the
            # loop body never runs.
            epoch = start_epoch

            # The writer is closed in `finally` so that pending summary events
            # are flushed on every exit path, including the sys.exit(0) pause.
            try:
                # training
                for epoch in range(start_epoch, args.max_epoch):
                    counter = 0
                    for batch in iterate_data(
                            train_dir,
                            shuffle=True,
                            aug=True,
                            is_testset=False,
                            batch_size=args.single_batch_size *
                            cfg.GPU_USE_COUNT,
                            multi_gpu_sum=cfg.GPU_USE_COUNT):

                        counter += 1
                        global_counter += 1

                        is_summary = counter % summary_interval == 0

                        start_time = time.time()
                        ret = model.train_step(sess,
                                               batch,
                                               train=True,
                                               summary=is_summary)
                        times = time.time() - start_time

                        print(
                            'train: {} @ epoch:{}/{} loss: {} reg_loss: {} cls_loss: {} time: {}'
                            .format(counter, epoch, args.max_epoch, ret[0],
                                    ret[1], ret[2], times))
                        with open('log/train.txt', 'a') as f:
                            f.write(
                                'train: {} @ epoch:{}/{} loss: {} reg_loss: {} cls_loss: {} time: {} \n'
                                .format(counter, epoch, args.max_epoch, ret[0],
                                        ret[1], ret[2], times))

                        if is_summary:
                            print("summary_interval now")
                            # The summary is the last element of the result
                            # when train_step was called with summary=True.
                            summary_writer.add_summary(ret[-1], global_counter)

                        if counter % summary_val_interval == 0:
                            print("summary_val_interval now")
                            val_batch = sample_test_data(
                                val_dir,
                                args.single_batch_size * cfg.GPU_USE_COUNT,
                                multi_gpu_sum=cfg.GPU_USE_COUNT)

                            ret = model.validate_step(sess,
                                                      val_batch,
                                                      summary=True)
                            summary_writer.add_summary(ret[-1], global_counter)

                            # Prediction summaries are best-effort: a failure
                            # must not abort training. Only catch Exception so
                            # Ctrl-C and SystemExit still propagate.
                            try:
                                ret = model.predict_step(sess,
                                                         val_batch,
                                                         summary=True)
                                summary_writer.add_summary(
                                    ret[-1], global_counter)
                            except Exception as err:
                                print(
                                    "prediction skipped due to error: {!r}"
                                    .format(err))

                        if check_if_should_pause(args.tag):
                            model.saver.save(sess,
                                             os.path.join(save_model_dir,
                                                          'checkpoint'),
                                             global_step=model.global_step)
                            print('pause and save model @ {} steps:{}'.format(
                                save_model_dir, model.global_step.eval()))
                            sys.exit(0)

                    sess.run(model.epoch_add_op)

                    model.saver.save(sess,
                                     os.path.join(save_model_dir,
                                                  'checkpoint'),
                                     global_step=model.global_step)

                print('train done. total epoch:{} iter:{}'.format(
                    epoch, model.global_step.eval()))

                # finally save model
                model.saver.save(sess,
                                 os.path.join(save_model_dir, 'checkpoint'),
                                 global_step=model.global_step)
            finally:
                summary_writer.close()