コード例 #1
0
ファイル: calib_by_gt.py プロジェクト: godspeed1989/voxelnet
def calib_file(fin, out_dir):
    """Snap predicted 'Car' boxes of one label file onto matching ground truth.

    The prediction file ``fin`` is parsed (the trailing score column is
    dropped before parsing), the ground-truth label with the same tag is
    loaded from ``<cfg.DATA_DIR>/validation/label_2``, and for every
    prediction whose best 3D IoU with a ground-truth box is at least 0.5 the
    columns ``2:4`` of the lidar box (z and h according to the inline layout
    comment below) are overwritten with the ground-truth values.  The
    calibrated values are then written back into fields ``-8`` and ``-3`` of
    each original prediction line and saved as ``<out_dir>/<tag>.txt``.

    Nothing is written when either the predictions or the ground truth
    contain no 'Car' boxes.

    Args:
        fin: Path of the prediction label file; the tag is the part of the
            file name between the last two dots.
        out_dir: Directory receiving the calibrated label file.
    """
    data_tag = fin.split('/')[-1].split('.')[-2]
    print(data_tag)
    # Read the prediction file exactly once, up front.  The original code
    # re-read it after opening the output file for writing, which yields an
    # empty input when ``out_dir`` is the directory that contains ``fin``.
    with open(fin, 'r') as f_pred:
        fin_data = f_pred.readlines()
    # predict
    labels = [line.rpartition(' ')[0] for line in fin_data]  # skip score
    pred_boxes3d = label_to_gt_box3d([data_tag], [labels],
                                     cls='Car',
                                     coordinate='lidar')[0]
    pred_boxes3d = np.array(pred_boxes3d)
    # ground truth
    val_dir = os.path.join(cfg.DATA_DIR, 'validation')
    f_gt = os.path.join(val_dir, 'label_2', data_tag + '.txt')
    with open(f_gt, 'r') as f_label:
        gt_labels = f_label.readlines()
    gt_boxes3d = label_to_gt_box3d([data_tag], [gt_labels],
                                   cls='Car',
                                   coordinate='lidar')[0]
    gt_boxes3d = np.array(gt_boxes3d)
    # load
    P, Tr, R = load_calib(os.path.join(cfg.CALIB_DIR, data_tag + '.txt'))
    # calibrate z if the iou with ground truth > 0.5
    if pred_boxes3d.shape[0] and gt_boxes3d.shape[0]:
        iou = py_rbbox_overlaps_3d(
            np.ascontiguousarray(pred_boxes3d, dtype=np.float32),
            np.ascontiguousarray(gt_boxes3d, dtype=np.float32))
        for i in range(iou.shape[0]):
            idx = np.argmax(iou[i])
            if iou[i][idx] < 0.5:  # no corresponding gt for this prediction
                continue
            # !!! HERE 1/2
            # x(0) y(1) z(2) h(3) w(4) l(5) r(6)
            pred_boxes3d[i][2:4] = gt_boxes3d[idx][2:4]
        # write calibrated result
        of_path = os.path.join(out_dir, data_tag + '.txt')
        # NOTE(review): the handle is deliberately left open here because the
        # remainder of this function is not visible in this excerpt and may
        # still use or close it - confirm and switch to ``with`` if possible.
        fout = open(of_path, 'w')
        # Every prediction line must have produced exactly one 'Car' box.
        assert pred_boxes3d.shape[0] == len(fin_data)
        num_objs = pred_boxes3d.shape[0]
        for i, line in zip(range(num_objs), fin_data):
            ret = line.split()
            # Convert the single calibrated lidar box back to a label line.
            label = box3d_to_label([pred_boxes3d[np.newaxis, i]],
                                   [np.zeros(num_objs)], [np.ones(num_objs)],
                                   coordinate='lidar',
                                   P2=P,
                                   T_VELO_2_CAM=Tr,
                                   R_RECT_0=R)[0][0]
            label = label.split()
            # !!! HERE 2/2
            # ..., h(-8), w(-7), l(-6), x(-5), y(-4), z(-3), r(-2), score(-1)
            # Only the two calibrated fields are replaced; everything else is
            # kept verbatim from the original prediction line.
            ret[-8] = label[-8]
            ret[-3] = label[-3]
            fout.write(' '.join(ret) + '\n')
        '''
コード例 #2
0
ファイル: test.py プロジェクト: godspeed1989/voxelnet
                                                       vis=False,
                                                       is_testset=True)

                # ret: A, B
                # A: (N) tag
                # B: (N, N') (class, x, y, z, h, w, l, rz, score)
                for tag, result in zip(tags, results):
                    of_path = os.path.join(args.output_path, 'data',
                                           tag + '.txt')
                    with open(of_path, 'w+') as f:
                        P, Tr, R = load_calib(
                            os.path.join(cfg.CALIB_DIR, tag + '.txt'))
                        labels = box3d_to_label([result[:, 1:8]],
                                                [result[:, 0]],
                                                [result[:, -1]],
                                                coordinate='lidar',
                                                P2=P,
                                                T_VELO_2_CAM=Tr,
                                                R_RECT_0=R)[0]
                        for line in labels:
                            f.write(line)
                        print_green('write out {} objects to {}'.format(
                            len(labels), tag))
                # dump visualizations
                if args.vis:
                    for tag, front_image, bird_view, heatmap in zip(
                            tags, front_images, bird_views, heatmaps):
                        front_img_path = os.path.join(args.output_path, 'vis',
                                                      tag + '_front.jpg')
                        bird_view_path = os.path.join(args.output_path, 'vis',
                                                      tag + '_bv.jpg')
コード例 #3
0
def aug_data(tag, object_dir, aug_pc=True, use_newtag=False, sampler=None):
    """Load one sample and optionally apply a random global augmentation.

    The image, lidar point cloud and label file of ``tag`` are read from
    ``object_dir``.  When ``sampler`` is given, additional ground-truth
    objects returned by ``sampler.sample_all`` are appended to the boxes,
    class names and points.  When ``aug_pc`` is true, one global transform is
    picked at random: rotation (4 in 10), translation (3 in 10) or scaling
    (3 in 10), and is applied to both the points and the boxes.

    Args:
        tag: Sample identifier (file stem shared by image, lidar and label).
        object_dir: Dataset directory containing ``image_2``,
            ``cfg.VELODYNE_DIR`` and ``label_2``.
        aug_pc: Apply a random global point-cloud augmentation.
        use_newtag: Return the augmentation-specific tag instead of ``tag``.
        sampler: Optional object exposing ``sample_all(cls, boxes)``.

    Returns:
        Tuple ``(tag_or_newtag, rgb, lidar, voxel_dict, label)``.

    Raises:
        AssertionError: If the image cannot be read or the lidar file holds
            no points.
    """
    # Re-seed from fresh entropy on every call (presumably so that forked
    # worker processes do not share one random state - verify against caller).
    np.random.seed()
    rgb = cv2.imread(os.path.join(object_dir, 'image_2', tag + '.png'))
    # The message is passed as a string; the original passed ``print(...)``,
    # whose return value is None, so the AssertionError carried no text.
    assert rgb is not None, 'ERROR rgb {} {}'.format(object_dir, tag)
    rgb = cv2.resize(rgb, (cfg.IMAGE_WIDTH, cfg.IMAGE_HEIGHT))
    #
    lidar_path = os.path.join(object_dir, cfg.VELODYNE_DIR, tag + '.bin')
    lidar = np.fromfile(lidar_path, dtype=np.float32).reshape(-1, 4)
    assert lidar.shape[0], 'ERROR lidar {} {}'.format(object_dir, tag)
    #
    label_path = os.path.join(object_dir, 'label_2', tag + '.txt')
    with open(label_path, 'r') as f_label:
        label = np.array(f_label.readlines())  # (N')
    classes = np.array([line.split()[0] for line in label])  # (N')
    # (N', 7) x, y, z, h, w, l, r
    gt_box3d = label_to_gt_box3d([tag],
                                 label[np.newaxis, :],
                                 cls=cfg.DETECT_OBJ,
                                 coordinate='lidar')
    gt_box3d = gt_box3d[0]
    if sampler is not None:
        # The sampler uses a box layout with columns 3 and 5 swapped relative
        # to ours, so swap on the way in and on the way out.
        avoid_collision_boxes = gt_box3d.copy()
        avoid_collision_boxes[:, [3, 5]] = gt_box3d[:, [5, 3]]
        sampled_dict = sampler.sample_all(cfg.DETECT_OBJ,
                                          avoid_collision_boxes)
        if sampled_dict is not None:
            sampled_gt = sampled_dict["gt_boxes"]
            sampled_box = sampled_gt.copy()
            sampled_box[:, [3, 5]] = sampled_gt[:, [5, 3]]
            # The rotation sign is flipped and wrapped by angle_in_limit.
            for i in range(sampled_box.shape[0]):
                sampled_box[i, 6] = angle_in_limit(-sampled_gt[i, 6])
            gt_box3d = np.concatenate([gt_box3d, sampled_box], axis=0)
            classes = np.concatenate([classes, sampled_dict["gt_names"]])
            lidar = np.concatenate([lidar, sampled_dict["points"]], axis=0)

    if aug_pc:
        choice = np.random.randint(0, 10)
        if choice < 4:
            # global rotation
            angle = np.random.uniform(-np.pi / 30, np.pi / 30)
            lidar[:, 0:3] = point_transform(lidar[:, 0:3], 0, 0, 0, rz=angle)
            gt_box3d = box_transform(gt_box3d,
                                     0,
                                     0,
                                     0,
                                     r=angle,
                                     coordinate='lidar')
            newtag = 'aug_{}_1_{:.4f}'.format(tag, angle).replace('.', '_')
        elif choice < 7:
            # global translation
            # The original ranges had identical bounds (e.g. (-0.1, -0.1)),
            # which always produced the same constant shift instead of a
            # random one; symmetric ranges restore the randomness.
            tx = np.random.uniform(-0.1, 0.1)
            ty = np.random.uniform(-0.1, 0.1)
            tz = np.random.uniform(-0.15, 0.15)
            lidar[:, 0:3] = point_transform(lidar[:, 0:3], tx, ty, tz)
            gt_box3d = box_transform(gt_box3d, tx, ty, tz, coordinate='lidar')
            newtag = 'aug_{}_2_trans'.format(tag).replace('.', '_')
        else:
            # global scaling: positions and sizes scale, rotation does not
            factor = np.random.uniform(0.95, 1.05)
            lidar[:, 0:3] = lidar[:, 0:3] * factor
            gt_box3d[:, 0:6] = gt_box3d[:, 0:6] * factor
            newtag = 'aug_{}_3_{:.4f}'.format(tag, factor).replace('.', '_')
    else:
        newtag = tag

    # Convert the (possibly augmented) lidar boxes back to label lines.
    P, Tr, R = load_calib(os.path.join(cfg.CALIB_DIR, tag + '.txt'))
    label = box3d_to_label(gt_box3d[np.newaxis, ...],
                           classes[np.newaxis, ...],
                           coordinate='lidar',
                           P2=P,
                           T_VELO_2_CAM=Tr,
                           R_RECT_0=R)[0]  # (N')
    voxel_dict = process_pointcloud(tag, lidar)
    if use_newtag:
        return newtag, rgb, lidar, voxel_dict, label
    return tag, rgb, lidar, voxel_dict, label
コード例 #4
0
def train_epochs(model, train_batcher, rand_test_batcher, val_batcher, params,
                 cfg, ckpt, ckpt_manager, strategy):
    """Run the distributed training loop with summaries and test dumps.

    Each iteration runs one distributed training step, logs the five
    returned losses to stdout and to ``<logdir>/train.txt``, and periodically
    writes training/validation/prediction summaries.  After each pass over
    the batches, predictions for ``val_batcher`` are dumped whenever the
    epoch number is a multiple of ``params["dump_test_interval"]``.  A
    checkpoint is saved on ``KeyboardInterrupt``; any other exception is
    reported (with traceback) and ends training without re-raising.

    Args:
        model: Object exposing ``train_step`` and ``trainable_variables``.
        train_batcher: Iterator of training batches with ``num_examples``
            and ``anchors`` attributes.
        rand_test_batcher: Iterator of batches used for validation summaries.
        val_batcher: Object whose ``batcher`` attribute iterates the batches
            used for the periodic test dump.
        params: Dict read for ``dump_vis``, ``model_dir``, ``model_name``,
            ``dump_test_interval``, ``summary_interval``,
            ``summary_val_interval``, ``summary_flush_interval``,
            ``n_epochs`` and ``batch_size``.
        cfg: Configuration object (``KITTY_EVAL_SCRIPT`` is read here; the
            object is also forwarded to ``predict_step``).
        ckpt: Checkpoint object with ``step`` and ``epoch`` variables.
        ckpt_manager: Manager used to save checkpoints.
        strategy: ``tf.distribute`` strategy used to run and reduce steps.
    """

    @tf.function
    def distributed_train_step():
        """Run one training step on every replica and sum each loss."""
        batch = next(train_batcher)
        per_replica_losses = strategy.run(
            model.train_step,
            args=(batch["feature_buffer"], batch["coordinate_buffer"],
                  batch["targets"], batch["pos_equal_one"],
                  batch["pos_equal_one_reg"], batch["pos_equal_one_sum"],
                  batch["neg_equal_one"], batch["neg_equal_one_sum"]))
        return [
            strategy.reduce(tf.distribute.ReduceOp.SUM,
                            per_replica_loss,
                            axis=None)
            for per_replica_loss in per_replica_losses
        ]

    @tf.function
    def distributed_validate_step():
        """Run one step on a validation batch; return (losses, batch)."""
        batch = next(rand_test_batcher)
        # NOTE(review): this calls ``model.train_step`` on validation data.
        # If ``train_step`` applies gradients, validation batches update the
        # weights - confirm whether a separate evaluation step was intended.
        per_replica_losses = strategy.run(
            model.train_step,
            args=(batch["feature_buffer"], batch["coordinate_buffer"],
                  batch["targets"], batch["pos_equal_one"],
                  batch["pos_equal_one_reg"], batch["pos_equal_one_sum"],
                  batch["neg_equal_one"], batch["neg_equal_one_sum"]))
        return [
            strategy.reduce(tf.distribute.ReduceOp.SUM,
                            per_replica_loss,
                            axis=None)
            for per_replica_loss in per_replica_losses
        ], batch

    dump_vis = params["dump_vis"]  # bool
    # Only referenced by the disabled evaluation command further below.
    kitti_eval_script = cfg.KITTY_EVAL_SCRIPT

    sum_logdir = os.path.join(params["model_dir"], params["model_name"],
                              "train_log/summary_logdir")
    logdir = os.path.join(params["model_dir"], params["model_name"],
                          "train_log/logdir")
    dump_test_logdir = os.path.join(params["model_dir"], params["model_name"],
                                    "train_log/dump_test_logdir")

    os.makedirs(sum_logdir, exist_ok=True)
    os.makedirs(logdir, exist_ok=True)
    os.makedirs(dump_test_logdir, exist_ok=True)

    dump_interval = params["dump_test_interval"]  # 10
    summary_interval = params["summary_interval"]  # 5
    summary_val_interval = params["summary_val_interval"]  # 10
    summary_flush_interval = params["summary_flush_interval"]
    summary_writer = tf.summary.create_file_writer(sum_logdir)

    epoch = ckpt.epoch
    # NOTE(review): the divisor here is ``num_examples`` while the loop below
    # uses ``num_batches`` - confirm which one ``epoch_counter`` expects.
    epoch.assign(epoch_counter(ckpt.step.numpy(), train_batcher.num_examples))

    try:
        while epoch.numpy() <= params["n_epochs"]:
            # NOTE(review): an extra batch is only counted when the remainder
            # is exactly 1; a ceiling division would use ``!= 0``.  Left
            # unchanged because the batcher's remainder policy is not visible.
            num_batches = train_batcher.num_examples // params[
                "batch_size"] + (1 if train_batcher.num_examples %
                                 params["batch_size"] == 1 else 0)
            if num_batches == 0:
                # ``epoch`` only advances inside the batch loop, so zero
                # batches would make this ``while`` spin forever.
                raise ValueError(
                    "no training batches: num_examples={} batch_size={}".
                    format(train_batcher.num_examples, params["batch_size"]))
            for step in range(num_batches):

                epoch.assign(epoch_counter(ckpt.step.numpy(), num_batches))
                if epoch.numpy() > params["n_epochs"]:
                    break

                global_step = ckpt.step.numpy()
                tf.summary.experimental.set_step(global_step)

                t0 = time.time()
                losses = distributed_train_step()
                t1 = time.time() - t0

                print(
                    'train: {} @ epoch:{}/{} global_step:{} loss: {} reg_loss: {} cls_loss: {} cls_pos_loss: {} cls_neg_loss: {} batch time: {:.4f}'
                    .format(step + 1, epoch.numpy(), params["n_epochs"],
                            ckpt.step.numpy(),
                            colored('{:.4f}'.format(losses[0]), "red"),
                            colored('{:.4f}'.format(losses[1]), "magenta"),
                            colored('{:.4f}'.format(losses[2]), "yellow"),
                            colored('{:.4f}'.format(losses[3]), "blue"),
                            colored('{:.4f}'.format(losses[4]), "cyan"), t1))
                with open('{}/train.txt'.format(logdir), 'a') as f:
                    f.write(
                        'train: {} @ epoch:{}/{} global_step:{} loss: {:.4f} reg_loss: {:.4f} cls_loss: {:.4f} cls_pos_loss: {:.4f} cls_neg_loss: {:.4f} batch time: {:.4f} \n'
                        .format(step + 1, epoch.numpy(), params["n_epochs"],
                                ckpt.step.numpy(), losses[0], losses[1],
                                losses[2], losses[3], losses[4], t1))

                if (step + 1) % summary_interval == 0:
                    train_summary(summary_writer,
                                  list(losses) + [model.trainable_variables])

                if (step + 1) % summary_val_interval == 0:
                    print("summary_val_interval now")

                    ret, batch = distributed_validate_step()
                    val_summary(summary_writer, ret)
                    # Prediction summaries are best effort: a failure is
                    # reported but must not stop training.
                    try:
                        ret = predict_step(model,
                                           batch,
                                           train_batcher.anchors,
                                           cfg,
                                           params,
                                           summary=True)
                        pred_summary(summary_writer, ret)
                    except Exception as ex:
                        print("".join(
                            traceback.TracebackException.from_exception(
                                ex).format()))
                        print(f"prediction skipped due to error: {str(ex)}")

                if (step + 1) % summary_flush_interval == 0:
                    summary_writer.flush()

                # NOTE(review): checkpoints are keyed on ``num_examples``
                # steps rather than ``num_batches`` - confirm intent.
                if global_step % train_batcher.num_examples == 0:
                    ckpt_manager.save(checkpoint_number=ckpt.step.numpy())
                    print("Saved checkpoint for step {}".format(
                        ckpt.step.numpy()))
                    summary_writer.flush()

                ckpt.step.assign_add(1)

            # dump test data every ``dump_interval`` epochs
            if (epoch.numpy()) % dump_interval == 0:
                print("dump_test")
                # create output folders for this epoch
                epoch_dir = os.path.join(dump_test_logdir, str(epoch.numpy()))
                data_dir = os.path.join(epoch_dir, 'data')
                vis_dir = os.path.join(epoch_dir, 'vis')
                os.makedirs(epoch_dir, exist_ok=True)
                os.makedirs(data_dir, exist_ok=True)
                if dump_vis:
                    os.makedirs(vis_dir, exist_ok=True)

                for eval_step, batch in enumerate(val_batcher.batcher):
                    res = predict_step(model,
                                       batch,
                                       train_batcher.anchors,
                                       cfg,
                                       params,
                                       summary=False,
                                       vis=dump_vis)
                    tags, results = res["tag"], res["scores"]
                    if dump_vis:
                        front_images = res["front_image"]
                        bird_views = res["bird_view"]
                        heatmaps = res["heatmap"]
                    for tag, result in zip(tags, results):
                        of_path = os.path.join(data_dir, tag + '.txt')
                        with open(of_path, 'w+') as f:
                            # result columns: class, 7 box values, score
                            labels = box3d_to_label(tag, [result[:, 1:8]],
                                                    [result[:, 0]],
                                                    [result[:, -1]],
                                                    coordinate='lidar')[0]
                            for line in labels:
                                f.write(line)
                            print('write out {} objects to {}'.format(
                                len(labels), tag))
                    # dump visualizations
                    if dump_vis:
                        for tag, front_image, bird_view, heatmap in zip(
                                tags, front_images, bird_views, heatmaps):
                            front_img_path = os.path.join(
                                vis_dir, tag + '_front.jpg')
                            bird_view_path = os.path.join(
                                vis_dir, tag + '_bv.jpg')
                            heatmap_path = os.path.join(
                                vis_dir, tag + '_heatmap.jpg')
                            cv2.imwrite(front_img_path, front_image)
                            cv2.imwrite(bird_view_path, bird_view)
                            cv2.imwrite(heatmap_path, heatmap)
                            print(
                                'write out 3 (front image, bird view and heatmap) jpegs to {}'
                                .format(tag))

                # execute evaluation code (currently disabled)
                #cmd_1 = "./"+kitti_eval_script
                #cmd_2 = os.path.join(cfg.DATA_DIR, "validation", "label_2")
                #cmd_3 = os.path.join( dump_test_logdir, str(epoch.numpy()) )
                #cmd_4 = os.path.join( dump_test_logdir, str(epoch.numpy()), 'log' )
                #os.system( " ".join( [cmd_1, cmd_2, cmd_3, cmd_4] ) ).read()

    except KeyboardInterrupt:
        # Keep the progress made so far when the user interrupts training.
        ckpt_manager.save(checkpoint_number=ckpt.step.numpy())
        print("Saved checkpoint for step {}".format(ckpt.step.numpy()))
        summary_writer.flush()
    except Exception as e:
        # Report the full traceback (same format as the prediction handler
        # above) so the failure location is not lost; still do not re-raise.
        print("".join(
            traceback.TracebackException.from_exception(e).format()))
        print(f"Unexpected exception happened: {str(e)}")
コード例 #5
0
ファイル: train.py プロジェクト: godspeed1989/voxelnet
def main(_):
    """Train the RPN3D model, with periodic validation, dumps and checkpoints.

    Opens a timestamped log file, builds the model inside a TF1 session,
    restores the latest checkpoint when ``args.restore`` is set and one
    exists, optionally loads pretrained (V)AE encoder weights, and then
    trains from ``start_epoch`` up to ``args.max_epoch``.  A checkpoint is
    saved after every epoch; every 10th epoch the validation predictions are
    written below ``args.output_path`` and the KITTI evaluation script is
    launched.  Relies on module-level names such as ``args``, ``cfg``,
    ``save_model_dir``, ``log_dir``, ``train_dir``, ``val_dir``, ``sampler``
    and ``AUG_DATA``.

    Args:
        _: Unused positional argument.
    """
    global log_f
    global save_model_dir
    timestr = time.strftime("%b-%d_%H-%M-%S", time.localtime())
    # Kept open on purpose: ``log_f`` is a module-level handle assigned here
    # (presumably consumed by ``log_print`` - verify).
    log_f = open('log/train_{}.txt'.format(timestr), 'w')
    log_print(str(cfg))
    # TODO: split file support
    with tf.Graph().as_default():
        start_epoch = 0
        global_counter = 0

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
            visible_device_list=cfg.GPU_AVAILABLE,
            allow_growth=True)
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            device_count={
                "GPU": cfg.GPU_USE_COUNT,
            },
            allow_soft_placement=True,
        )
        with tf.Session(config=config) as sess:
            model = RPN3D(cls=cfg.DETECT_OBJ,
                          single_batch_size=args.single_batch_size,
                          learning_rate=args.lr,
                          max_gradient_norm=5.0,
                          alpha=args.alpha,
                          beta=args.beta,
                          avail_gpus=cfg.GPU_AVAILABLE.split(','))
            # param init/restore
            if args.restore and tf.train.get_checkpoint_state(save_model_dir):
                log_print("Reading model parameters from %s" % save_model_dir)
                model.saver.restore(sess,
                                    tf.train.latest_checkpoint(save_model_dir))
                start_epoch = model.epoch.eval() + 1
                global_counter = model.global_step.eval() + 1
            else:
                log_print("Created model with fresh parameters.")
                tf.global_variables_initializer().run()

            # Optionally overwrite the encoder with pretrained weights.
            if cfg.FEATURE_NET_TYPE == 'FeatureNet_AE' and cfg.FeatureNet_AE_WPATH:
                ae_checkpoint_file = tf.train.latest_checkpoint(
                    cfg.FeatureNet_AE_WPATH)
                log_print("Load Pretrained FeatureNet_AE weights %s" %
                          ae_checkpoint_file)
                ae_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                            scope='ae_encoder')
                ae_saver = tf.train.Saver(
                    var_list={v.op.name: v
                              for v in ae_vars})
                ae_saver.restore(sess, ae_checkpoint_file)
            if cfg.FEATURE_NET_TYPE == 'FeatureNet_VAE' and cfg.FeatureNet_VAE_WPATH:
                vae_checkpoint_file = tf.train.latest_checkpoint(
                    cfg.FeatureNet_VAE_WPATH)
                log_print("Load Pretrained FeatureNet_VAE weights %s" %
                          vae_checkpoint_file)
                vae_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                             scope='vae_encoder')
                vae_saver = tf.train.Saver(
                    var_list={v.op.name: v
                              for v in vae_vars})
                vae_saver.restore(sess, vae_checkpoint_file)

            # train and validate
            summary_interval = 5
            summary_val_interval = 20
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            parameter_num = np.sum(
                [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
            log_print('Parameter number: {}'.format(parameter_num))

            # Bind ``epoch`` up front so the final log line below does not
            # raise NameError when the epoch range is empty
            # (start_epoch >= args.max_epoch after a restore).
            epoch = start_epoch

            # training
            for epoch in range(start_epoch, args.max_epoch):
                counter = 0
                batch_time = time.time()
                for batch in iterate_data(train_dir,
                                          db_sampler=sampler,
                                          shuffle=True,
                                          aug=AUG_DATA,
                                          is_testset=False,
                                          batch_size=args.single_batch_size *
                                          cfg.GPU_USE_COUNT,
                                          multi_gpu_sum=cfg.GPU_USE_COUNT):
                    counter += 1
                    global_counter += 1

                    is_summary = counter % summary_interval == 0

                    start_time = time.time()
                    ret = model.train_step(sess,
                                           batch,
                                           train=True,
                                           summary=is_summary)
                    forward_time = time.time() - start_time
                    # Time since the previous batch finished, including
                    # data loading.
                    batch_time = time.time() - batch_time

                    log_print(
                        'train: {} @ epoch:{}/{} loss: {:.4f} reg_loss: {:.4f} cls_loss: {:.4f} cls_pos_loss: {:.4f} cls_neg_loss: {:.4f} forward time: {:.4f} batch time: {:.4f}'
                        .format(counter, epoch, args.max_epoch, ret[0], ret[1],
                                ret[2], ret[3], ret[4], forward_time,
                                batch_time),
                        write=is_summary)

                    if is_summary:
                        log_print("summary_interval now")
                        summary_writer.add_summary(ret[-1], global_counter)

                    if counter % summary_val_interval == 0:
                        log_print("summary_val_interval now")
                        # Random sample single batch data
                        batch = sample_test_data(
                            val_dir,
                            args.single_batch_size * cfg.GPU_USE_COUNT,
                            multi_gpu_sum=cfg.GPU_USE_COUNT)

                        ret = model.validate_step(sess, batch, summary=True)
                        summary_writer.add_summary(ret[-1], global_counter)
                        log_print(
                            'validation: loss: {:.4f} reg_loss: {:.4f} cls_loss: {:.4f} '
                            .format(ret[0], ret[1], ret[2]))

                        # Warnings are escalated to errors so a prediction
                        # that triggers one is skipped instead of logged.
                        with warnings.catch_warnings():
                            warnings.filterwarnings('error')
                            try:
                                ret = model.predict_step(sess,
                                                         batch,
                                                         summary=True)
                                summary_writer.add_summary(
                                    ret[-1], global_counter)
                            except Exception:
                                # ``Exception`` rather than a bare except so
                                # Ctrl-C / SystemExit still stop training.
                                log_print('prediction skipped due to error',
                                          'red')

                    if check_if_should_pause(args.tag):
                        model.saver.save(sess,
                                         os.path.join(save_model_dir, timestr),
                                         global_step=model.global_step)
                        log_print('pause and save model @ {} steps:{}'.format(
                            save_model_dir, model.global_step.eval()))
                        sys.exit(0)

                    batch_time = time.time()

                sess.run(model.epoch_add_op)

                model.saver.save(sess,
                                 os.path.join(save_model_dir, timestr),
                                 global_step=model.global_step)

                # dump test data every 10 epochs
                if (epoch + 1) % 10 == 0:
                    # create output folders for this epoch
                    epoch_dir = os.path.join(args.output_path, str(epoch))
                    data_dir = os.path.join(epoch_dir, 'data')
                    vis_dir = os.path.join(epoch_dir, 'vis')
                    os.makedirs(epoch_dir, exist_ok=True)
                    os.makedirs(data_dir, exist_ok=True)
                    if args.vis:
                        os.makedirs(vis_dir, exist_ok=True)

                    for batch in iterate_data(
                            val_dir,
                            shuffle=False,
                            aug=False,
                            is_testset=False,
                            batch_size=args.single_batch_size *
                            cfg.GPU_USE_COUNT,
                            multi_gpu_sum=cfg.GPU_USE_COUNT):
                        if args.vis:
                            tags, results, front_images, bird_views, heatmaps = model.predict_step(
                                sess, batch, summary=False, vis=True)
                        else:
                            tags, results = model.predict_step(sess,
                                                               batch,
                                                               summary=False,
                                                               vis=False)

                        for tag, result in zip(tags, results):
                            of_path = os.path.join(data_dir, tag + '.txt')
                            with open(of_path, 'w+') as f:
                                P, Tr, R = load_calib(
                                    os.path.join(cfg.CALIB_DIR, tag + '.txt'))
                                # result columns: class, 7 box values, score
                                labels = box3d_to_label([result[:, 1:8]],
                                                        [result[:, 0]],
                                                        [result[:, -1]],
                                                        coordinate='lidar',
                                                        P2=P,
                                                        T_VELO_2_CAM=Tr,
                                                        R_RECT_0=R)[0]
                                for line in labels:
                                    f.write(line)
                                log_print('write out {} objects to {}'.format(
                                    len(labels), tag))
                        # dump visualizations
                        if args.vis:
                            for tag, front_image, bird_view, heatmap in zip(
                                    tags, front_images, bird_views, heatmaps):
                                front_img_path = os.path.join(
                                    vis_dir, tag + '_front.jpg')
                                bird_view_path = os.path.join(
                                    vis_dir, tag + '_bv.jpg')
                                heatmap_path = os.path.join(
                                    vis_dir, tag + '_heatmap.jpg')
                                cv2.imwrite(front_img_path, front_image)
                                cv2.imwrite(bird_view_path, bird_view)
                                cv2.imwrite(heatmap_path, heatmap)

                    # execute evaluation code
                    cmd_1 = "./kitti_eval/launch_test.sh"
                    cmd_2 = epoch_dir
                    cmd_3 = os.path.join(epoch_dir, 'log')
                    os.system(" ".join([cmd_1, cmd_2, cmd_3]))

            log_print('train done. total epoch:{} iter:{}'.format(
                epoch, model.global_step.eval()))

            # finally save model
            model.saver.save(sess,
                             os.path.join(save_model_dir, 'checkpoint'),
                             global_step=model.global_step)
コード例 #6
0
def _next_validation_batch(val_dataloader, val_dataloader_iter):
    """Return ``(batch, iterator)``, restarting the iterator when it runs out.

    A single iterator over the validation loader is sampled repeatedly during
    training; once it is exhausted a fresh one is created so that sampling can
    continue instead of raising ``StopIteration``.
    """
    try:
        return next(val_dataloader_iter), val_dataloader_iter
    except StopIteration:
        val_dataloader_iter = iter(val_dataloader)
        return next(val_dataloader_iter), val_dataloader_iter


def _summarize_validation(model, val_data, summary_writer, epoch, global_counter):
    """Evaluate one validation batch, log it, and return its loss as a float.

    The validation losses (not the training ones) are written under the
    ``validate/*`` tags, and the images returned by ``model.module.predict``
    are added to the summary writer. The model is left in evaluation mode;
    the caller switches back to training mode on its next iteration.

    Raises:
        Exception: if prediction or image logging fails; the original error is
            chained as ``__cause__``.
    """
    with torch.no_grad():
        model.train(False)  # Validation mode

        # Forward pass for validation and prediction
        probs, deltas, val_loss, val_cls_loss, val_reg_loss, val_cls_pos_loss, val_cls_neg_loss = model(val_data)

        summary_writer.add_scalars(str(epoch + 1), {'validate/loss': val_loss.item(),
                                                    'validate/reg_loss': val_reg_loss.item(),
                                                    'validate/cls_loss': val_cls_loss.item(),
                                                    'validate/cls_pos_loss': val_cls_pos_loss.item(),
                                                    'validate/cls_neg_loss': val_cls_neg_loss.item()}, global_counter)

        try:
            # Prediction
            tags, ret_box3d_scores, ret_summary = model.module.predict(val_data, probs, deltas, summary = True)

            for (tag, img) in ret_summary:
                # presumably img[0] is an HWC image being converted to CHW -- TODO confirm
                img = img[0].transpose(2, 0, 1)
                summary_writer.add_image(tag, img, global_counter)
        except Exception as err:
            # Keep the original message/type but preserve the real cause.
            raise Exception('Prediction skipped due to an error!') from err

        return val_loss.item()


def _dump_test_data(model, val_dataloader, epoch):
    """Write predictions for the whole validation loader and run the evaluator.

    Outputs go under ``<args.output_path>/<epoch + 1>/`` in the ``data``,
    ``log`` and (when ``args.vis`` is set) ``vis`` sub-directories. Afterwards
    ``./eval/KITTI/launch_test.sh`` is invoked with the epoch directory and its
    ``log`` directory as arguments.
    """
    epoch_dir = os.path.join(args.output_path, str(epoch + 1))
    data_dir = os.path.join(epoch_dir, 'data')
    eval_log_dir = os.path.join(epoch_dir, 'log')
    vis_dir = os.path.join(epoch_dir, 'vis')

    # Create output folders
    os.makedirs(epoch_dir, exist_ok = True)
    os.makedirs(data_dir, exist_ok = True)
    os.makedirs(eval_log_dir, exist_ok = True)
    if args.vis:
        os.makedirs(vis_dir, exist_ok = True)

    model.train(False)  # Validation mode

    with torch.no_grad():
        for val_data in val_dataloader:
            # Forward pass for validation and prediction
            probs, deltas, val_loss, val_cls_loss, val_reg_loss, cls_pos_loss_rec, cls_neg_loss_rec = model(val_data)

            front_images, bird_views, heatmaps = None, None, None
            if args.vis:
                tags, ret_box3d_scores, front_images, bird_views, heatmaps = \
                    model.module.predict(val_data, probs, deltas, summary = False, vis = True)
            else:
                tags, ret_box3d_scores = model.module.predict(val_data, probs, deltas, summary = False, vis = False)

            # tags: (N)
            # ret_box3d_scores: (N, N'); (class, x, y, z, h, w, l, rz, score)
            for tag, score in zip(tags, ret_box3d_scores):
                output_path = os.path.join(data_dir, tag + '.txt')
                with open(output_path, 'w+') as f:
                    labels = box3d_to_label([score[:, 1:8]], [score[:, 0]], [score[:, -1]], coordinate = 'lidar')[0]
                    for line in labels:
                        f.write(line)
                    print('Write out {} objects to {}'.format(len(labels), tag))

            # Dump visualizations
            if args.vis:
                for tag, front_image, bird_view, heatmap in zip(tags, front_images, bird_views, heatmaps):
                    cv2.imwrite(os.path.join(vis_dir, tag + '_front.jpg'), front_image)
                    cv2.imwrite(os.path.join(vis_dir, tag + '_bv.jpg'), bird_view)
                    cv2.imwrite(os.path.join(vis_dir, tag + '_heatmap.jpg'), heatmap)

    # Run evaluation code
    os.system(' '.join(['./eval/KITTI/launch_test.sh', epoch_dir, eval_log_dir]))


def run():
    """Train RPN3D, periodically validating, checkpointing and dumping results.

    Configuration is read from the module-level ``args`` and ``cfg`` objects.
    When ``args.resumed_model`` names an existing file under ``save_model_dir``
    the epoch, global step counter, best loss and model weights are restored
    from it; otherwise training starts from scratch.

    Per epoch the function:
      * trains over the whole training loader, printing/logging every
        ``args.print_freq`` steps and writing summaries every
        ``args.summary_interval`` steps;
      * samples one validation batch every ``args.summary_val_interval`` steps;
      * saves a checkpoint, flagged as best when the epoch's mean sampled
        validation loss improves on the best seen so far;
      * every ``args.val_epoch`` epochs dumps predictions for the validation
        set and launches the evaluation script.

    The log file and the summary writer are closed even if training aborts.
    """
    import contextlib  # local import: only needed to close the summary writer

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(x) for x in cfg.GPU_AVAILABLE)

    start_epoch = 0
    global_counter = 0
    min_loss = sys.float_info.max

    # Build data loader
    train_dataset = Dataset(os.path.join(cfg.DATA_DIR, 'training'), shuffle = True, aug = True, is_testset = False)
    train_dataloader = DataLoader(train_dataset, batch_size = args.batch_size, shuffle = True, collate_fn = collate_fn,
                                  num_workers = args.workers, pin_memory = False)

    val_dataset = Dataset(os.path.join(cfg.DATA_DIR, 'validation'), shuffle = False, aug = False, is_testset = False)
    val_dataloader = DataLoader(val_dataset, batch_size = args.batch_size, shuffle = False, collate_fn = collate_fn,
                                num_workers = args.workers, pin_memory = False)
    val_dataloader_iter = iter(val_dataloader)

    # Build model
    model = RPN3D(cfg.DETECT_OBJ, args.alpha, args.beta)

    # Resume model if necessary
    if args.resumed_model:
        model_file = os.path.join(save_model_dir, args.resumed_model)
        if os.path.isfile(model_file):
            checkpoint = torch.load(model_file)
            start_epoch = checkpoint['epoch']
            global_counter = checkpoint['global_counter']
            min_loss = checkpoint['min_loss']
            model.load_state_dict(checkpoint['state_dict'])
            print(("=> Loaded checkpoint '{}' (epoch {}, global_counter {})".format(
                args.resumed_model, checkpoint['epoch'], checkpoint['global_counter'])))
        else:
            print(("=> No checkpoint found at '{}'".format(args.resumed_model)))

    model = nn.DataParallel(model).cuda()

    # Optimization scheme
    optimizer = optim.Adam(model.parameters(), lr = args.lr)

    lr_sched = optim.lr_scheduler.MultiStepLR(optimizer, [150])

    tot_epoch = start_epoch

    # Init file log and TensorBoardX writer; both are released on any exit path.
    with open(os.path.join(args.log_root, args.log_name), 'a') as log, \
            contextlib.closing(SummaryWriter(log_dir)) as summary_writer:

        # train and validate
        for epoch in range(start_epoch, args.max_epoch):
            # Learning rate scheme
            # NOTE(review): stepping the scheduler before the epoch's optimizer
            # steps matches older PyTorch usage; newer releases expect it after
            # optimizer.step(). The scheduler is also not fast-forwarded on
            # resume. Left unchanged -- confirm against the PyTorch version in use.
            lr_sched.step()

            counter = 0
            batch_time = time.time()

            tot_val_loss = 0
            tot_val_times = 0

            for data in train_dataloader:

                model.train(True)   # Training mode

                counter += 1
                global_counter += 1

                start_time = time.time()

                # Forward pass for training
                _, _, loss, cls_loss, reg_loss, cls_pos_loss_rec, cls_neg_loss_rec = model(data)

                forward_time = time.time() - start_time

                loss.backward()

                # Clip gradient
                clip_grad_norm_(model.parameters(), 5)

                optimizer.step()
                optimizer.zero_grad()

                batch_time = time.time() - batch_time

                if counter % args.print_freq == 0:
                    # Print training info
                    info = 'Train: {} @ epoch:{}/{} loss: {:.4f} reg_loss: {:.4f} cls_loss: {:.4f} cls_pos_loss: {:.4f} ' \
                           'cls_neg_loss: {:.4f} forward time: {:.4f} batch time: {:.4f}'.format(
                        counter, epoch + 1, args.max_epoch, loss.item(), reg_loss.item(), cls_loss.item(), cls_pos_loss_rec.item(),
                        cls_neg_loss_rec.item(), forward_time, batch_time)
                    info = '{}\t'.format(time.asctime(time.localtime())) + info
                    print(info)

                    # Write training info to log
                    log.write(info + '\n')
                    log.flush()

                # Summarize training info
                if counter % args.summary_interval == 0:
                    print("summary_interval now")
                    summary_writer.add_scalars(str(epoch + 1), {'train/loss' : loss.item(),
                                                                'train/reg_loss' : reg_loss.item(),
                                                                'train/cls_loss' : cls_loss.item(),
                                                                'train/cls_pos_loss' : cls_pos_loss_rec.item(),
                                                                'train/cls_neg_loss' : cls_neg_loss_rec.item()}, global_counter)

                # Summarize validation info
                if counter % args.summary_val_interval == 0:
                    print('summary_val_interval now')

                    # Sample one batch, wrapping around when the loader is exhausted
                    val_data, val_dataloader_iter = _next_validation_batch(val_dataloader, val_dataloader_iter)

                    # Add sampled data loss
                    tot_val_loss += _summarize_validation(model, val_data, summary_writer, epoch, global_counter)
                    tot_val_times += 1

                # Restart the per-batch timer so data loading is included next time
                batch_time = time.time()

            # Save the best model. Without any validation sample this epoch the
            # average is undefined, so treat it as "not an improvement".
            if tot_val_times:
                avg_val_loss = tot_val_loss / float(tot_val_times)
            else:
                avg_val_loss = float('inf')
            is_best = avg_val_loss < min_loss
            min_loss = min(avg_val_loss, min_loss)
            save_checkpoint({'epoch': epoch + 1, 'global_counter': global_counter, 'state_dict': model.module.state_dict(), 'min_loss': min_loss},
                            is_best, args.saved_model.format(cfg.DETECT_OBJ))

            # Dump test data every args.val_epoch epochs
            if (epoch + 1) % args.val_epoch == 0:   # Time consuming
                _dump_test_data(model, val_dataloader, epoch)

            tot_epoch = epoch + 1

    print('Train done with total epoch:{}, iter:{}'.format(tot_epoch, global_counter))
コード例 #7
0
def run():
    """Run a restored RPN3D over the validation split and write its detections.

    Configuration is read from the module-level ``args`` and ``cfg`` objects.
    The checkpoint ``args.resumed_model`` is loaded from ``save_model_dir``;
    for every validation sample a label file is written to
    ``<args.output_path>/data/<tag>.txt`` and, when ``args.vis`` is set, the
    front-view, bird-view and heatmap images are written to
    ``<args.output_path>/vis/``.

    Raises:
        Exception: if ``args.resumed_model`` is empty, or if it names a file
            that does not exist (testing an uninitialised model would only
            produce meaningless output).
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in cfg.GPU_AVAILABLE)

    # Build data loader
    val_dataset = Dataset(os.path.join(cfg.DATA_DIR, 'validation'),
                          shuffle=False,
                          aug=False,
                          is_testset=False)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                collate_fn=collate_fn,
                                num_workers=args.workers,
                                pin_memory=False)

    # Build model
    model = RPN3D(cfg.DETECT_OBJ)

    # Resume model: a checkpoint is mandatory for testing
    if not args.resumed_model:
        raise Exception('No pre-trained model to test!')

    model_file = os.path.join(save_model_dir, args.resumed_model)
    if not os.path.isfile(model_file):
        raise Exception(
            "=> No checkpoint found at '{}'".format(args.resumed_model))

    checkpoint = torch.load(model_file)
    model.load_state_dict(checkpoint['state_dict'])
    print(("=> Loaded checkpoint '{}' (epoch {}, global_counter {})".
           format(args.resumed_model, checkpoint['epoch'],
                  checkpoint['global_counter'])))

    model = nn.DataParallel(model).cuda()

    model.train(False)  # Validation mode

    # Make sure the output folders exist before anything is written to them
    data_dir = os.path.join(args.output_path, 'data')
    vis_dir = os.path.join(args.output_path, 'vis')
    os.makedirs(data_dir, exist_ok=True)
    if args.vis:
        os.makedirs(vis_dir, exist_ok=True)

    with torch.no_grad():
        for val_data in val_dataloader:

            # Forward pass for validation and prediction
            probs, deltas, val_loss, val_cls_loss, val_reg_loss, cls_pos_loss_rec, cls_neg_loss_rec = model(
                val_data)

            front_images, bird_views, heatmaps = None, None, None
            if args.vis:
                tags, ret_box3d_scores, front_images, bird_views, heatmaps = \
                    model.module.predict(val_data, probs, deltas, summary = False, vis = True)
            else:
                tags, ret_box3d_scores = model.module.predict(val_data,
                                                              probs,
                                                              deltas,
                                                              summary=False,
                                                              vis=False)

            # tags: (N)
            # ret_box3d_scores: (N, N'); (class, x, y, z, h, w, l, rz, score)
            for tag, score in zip(tags, ret_box3d_scores):
                output_path = os.path.join(data_dir, tag + '.txt')
                with open(output_path, 'w+') as f:
                    labels = box3d_to_label([score[:, 1:8]], [score[:, 0]],
                                            [score[:, -1]],
                                            coordinate='lidar')[0]
                    for line in labels:
                        f.write(line)
                    print('Write out {} objects to {}'.format(
                        len(labels), tag))

            # Dump visualizations
            if args.vis:
                for tag, front_image, bird_view, heatmap in zip(
                        tags, front_images, bird_views, heatmaps):
                    cv2.imwrite(os.path.join(vis_dir, tag + '_front.jpg'),
                                front_image)
                    cv2.imwrite(os.path.join(vis_dir, tag + '_bv.jpg'),
                                bird_view)
                    cv2.imwrite(os.path.join(vis_dir, tag + '_heatmap.jpg'),
                                heatmap)