Example #1
0
def generate_test(kitti_dir, is_hypotheses=True):
    """Run frustum 3D box estimation over every sequence folder under *kitti_dir*.

    Each sequence directory is expected to contain a ``det`` sub-folder holding
    a pickled 2D RGB result file. When *is_hypotheses* is True the tracking
    hypotheses pickle is used, otherwise the raw detection pickle. Results are
    written next to the input via ``test_from_video_detection``.
    """
    for sequence in os.listdir(kitti_dir):
        print("Processing %s" % sequence)
        sequence_dir = os.path.join(kitti_dir, sequence)
        detection_dir = os.path.join(sequence_dir, 'det')
        # Skip stray files; only sequence directories are processed.
        if not os.path.isdir(sequence_dir):
            continue

        # Input pickle and output locations differ only by the mode name.
        mode = 'hypotheses' if is_hypotheses else 'detection'
        det_file = os.path.join(detection_dir, '%s_video_rgb_%s.pickle' % (sequence, mode))
        output_file = os.path.join(detection_dir, '%s_video_rgb_%s_3D.pickle' % (sequence, mode))
        result_dir = os.path.join(detection_dir, '%s_3D' % mode)

        test_dataset = provider.FrustumDataset(
            npoints=NUM_POINT, split='val', rotate_to_center=True,
            overwritten_data_path=det_file, from_rgb_detection=True,
            one_hot=True, is_hypotheses=is_hypotheses)
        test_from_video_detection(test_dataset, output_file, result_dir, sequence_dir)
# NOTE(review): scraped fragment — `parser`, `importlib` and `provider`
# are defined outside the visible lines.
parser.add_argument('--idx_path', default=None, help='filename of txt where each line is a data idx, used for rgb detection -- write <id>.txt for all frames. [default: None]')
parser.add_argument('--dump_result', action='store_true', help='If true, also dump results to .pickle file')
FLAGS = parser.parse_args()

# Set training configurations
BATCH_SIZE = FLAGS.batch_size
MODEL_PATH = FLAGS.model_path
GPU_INDEX = FLAGS.gpu
NUM_POINT = FLAGS.num_point
MODEL = importlib.import_module(FLAGS.model)  # network module chosen on the command line
NUM_CLASSES = 2  # two segmentation classes (foreground/background)
NUM_CHANNEL = 4  # per-point channels — presumably x, y, z + intensity; confirm in provider

# Load Frustum Datasets.
# Validation-split dataset built from an explicit pickle path; inputs may come
# from 2D RGB detections rather than ground-truth boxes.
TEST_DATASET = provider.FrustumDataset(npoints=NUM_POINT, split='val',
    rotate_to_center=True, overwritten_data_path=FLAGS.data_path,
    from_rgb_detection=FLAGS.from_rgb_detection, one_hot=True)

def get_session_and_ops(batch_size, num_point):
    ''' Define model graph, load model parameters,
    create session and return session handle and tensors
    '''
    # NOTE(review): this function is truncated in this excerpt — session
    # creation and the remaining ops lie outside the visible lines.
    with tf.Graph().as_default():
        with tf.device('/gpu:'+str(GPU_INDEX)):
            # Placeholders for the frustum point cloud, one-hot class vector,
            # and all segmentation/box-regression labels.
            pointclouds_pl, one_hot_vec_pl, labels_pl, centers_pl, \
            heading_class_label_pl, heading_residual_label_pl, \
            size_class_label_pl, size_residual_label_pl = \
                MODEL.placeholder_inputs(batch_size, num_point)
            # Toggles train/eval behavior inside the model (e.g. batch norm).
            is_training_pl = tf.placeholder(tf.bool, shape=())
            end_points = MODEL.get_model(pointclouds_pl, one_hot_vec_pl,
                is_training_pl)
# Back up the model definition and training script into the log directory so
# each run records the exact code it used; then open the run log.
os.system('cp %s %s' % (MODEL_FILE, LOG_DIR))  # bkp of model def
os.system('cp %s %s' % (os.path.join(BASE_DIR, 'train.py'), LOG_DIR))
LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
LOG_FOUT.write(str(FLAGS) + '\n')
# BN_INIT_DECAY = 0.5
# BN_DECAY_DECAY_RATE = 0.5
# BN_DECAY_DECAY_STEP = float(DECAY_STEP)
# BN_DECAY_CLIP = 0.99

# Load Frustum Datasets. Use default data paths.
# Pickle names are derived from the object type and split names given on the
# command line.
if FLAGS.dataset == 'kitti':
    TRAIN_DATASET = provider.FrustumDataset(
        npoints=NUM_POINT,
        split=FLAGS.train_sets,
        rotate_to_center=True,
        random_flip=True,
        random_shift=True,
        one_hot=True,
        overwritten_data_path='kitti/frustum_' + FLAGS.objtype + '_' +
        FLAGS.train_sets + '.pickle')
    TEST_DATASET = provider.FrustumDataset(
        npoints=NUM_POINT,
        split=FLAGS.val_sets,
        rotate_to_center=True,
        one_hot=True,
        overwritten_data_path='kitti/frustum_' + FLAGS.objtype + '_' +
        FLAGS.val_sets + '.pickle')
elif FLAGS.dataset == 'nuscenes2kitti':
    SENSOR = FLAGS.sensor
    overwritten_data_path_prefix = 'nuscenes2kitti/frustum_' + FLAGS.objtype + '_' + SENSOR + '_'
    # NOTE(review): truncated here — the remaining constructor arguments are
    # outside this excerpt.
    TRAIN_DATASET = provider.FrustumDataset(
Example #4
0
# Resolve the network module from the command line and back up the code into
# LOG_DIR for reproducibility.
MODEL = importlib.import_module(FLAGS.model) # import network module
MODEL_FILE = os.path.join(ROOT_DIR, 'models', FLAGS.model+'.py')
LOG_DIR = FLAGS.log_dir
if not os.path.exists(LOG_DIR): os.mkdir(LOG_DIR)
os.system('cp %s %s' % (MODEL_FILE, LOG_DIR)) # bkp of model def
os.system('cp %s %s' % (os.path.join(BASE_DIR, 'train.py'), LOG_DIR))
LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
LOG_FOUT.write(str(FLAGS)+'\n')

# Batch-norm decay schedule parameters — presumably consumed by a bn-decay
# helper outside this excerpt.
BN_INIT_DECAY = 0.5
BN_DECAY_DECAY_RATE = 0.5
BN_DECAY_DECAY_STEP = float(DECAY_STEP)
BN_DECAY_CLIP = 0.99

# Load Frustum Datasets. Use default data paths.
# Train split gets flip/shift augmentation; val split does not.
TRAIN_DATASET = provider.FrustumDataset(npoints=NUM_POINT, split='train',
    rotate_to_center=True, random_flip=True, random_shift=True, one_hot=True)
TEST_DATASET = provider.FrustumDataset(npoints=NUM_POINT, split='val',
    rotate_to_center=True, one_hot=True)

def log_string(out_str):
    """Append *out_str* (plus a newline) to LOG_FOUT and echo it to stdout."""
    line = out_str + '\n'
    LOG_FOUT.write(line)
    LOG_FOUT.flush()  # keep the on-disk log current even if the run dies
    print(out_str)

def get_learning_rate(batch):
    # Exponentially-decayed learning rate as a function of the global step.
    # NOTE(review): truncated in this excerpt — the clip/return lines that
    # normally follow are outside the visible lines.
    learning_rate = tf.train.exponential_decay(
                        BASE_LEARNING_RATE,  # Base learning rate.
                        batch * BATCH_SIZE,  # Current index into the dataset.
                        DECAY_STEP,          # Decay step.
                        DECAY_RATE,          # Decay rate.
                        staircase=True)
                           pathsplit[4], pathsplit[5],
                           'results_train_eval_190/')
# NOTE(review): the two lines above are the tail of a path-join whose opening
# is outside this excerpt.
if not os.path.exists(OUTPUT_FILE):
    os.mkdir(OUTPUT_FILE)
# Batch-norm decay schedule parameters.
BN_INIT_DECAY = 0.5
BN_DECAY_DECAY_RATE = 0.5
BN_DECAY_DECAY_STEP = float(DECAY_STEP)
BN_DECAY_CLIP = 0.99

# Load Frustum Datasets. Use default data paths.
#TRAIN_DATASET = provider.FrustumDataset(npoints=NUM_POINT, split='train',res=0,
#rotate_to_center=False, random_flip=False, random_shift=True, one_hot=True)

# Validation sets at two image resolutions — presumably the `res` string
# selects pre-extracted frustums; verify against provider.FrustumDataset.
EVAL_DATASET_224 = provider.FrustumDataset(npoints=NUM_POINT,
                                           database="KITTI",
                                           split='val',
                                           res="224",
                                           rotate_to_center=True,
                                           one_hot=True)

EVAL_DATASET_704 = provider.FrustumDataset(npoints=NUM_POINT,
                                           database="KITTI",
                                           split='val',
                                           res="704",
                                           rotate_to_center=True,
                                           one_hot=True)

#TEST_DATASET_224 =  provider.FrustumDataset('pc_radar_2','KITTI_2',npoints=NUM_POINT, split='test',rotate_to_center=False, one_hot=True,all_batches = True, translate_radar_center=False, store_data=True, proposals_3 =False ,no_color=True)

# NOTE(review): truncated — the remaining arguments of this call are outside
# the excerpt.
TEST_DATASET_224 = provider.FrustumDataset(npoints=NUM_POINT,
                                           database="KITTI_2",
                                           split='test',
Example #6
0
# Back up model definition and training script, then open the run log.
os.system('cp %s %s' % (MODEL_FILE, LOG_DIR))  # bkp of model def
os.system('cp %s %s' % (os.path.join(BASE_DIR, 'train.py'), LOG_DIR))
LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
LOG_FOUT.write(str(FLAGS) + '\n')

# Batch-norm decay schedule parameters.
BN_INIT_DECAY = 0.5
BN_DECAY_DECAY_RATE = 0.5
BN_DECAY_DECAY_STEP = float(DECAY_STEP)
BN_DECAY_CLIP = 0.99

# Load Frustum Datasets. Use default data paths.
# Car-only frustums from AVOD proposals; the pickle suffix is supplied on the
# command line via FLAGS.pickle. Note the absolute, machine-specific paths.
TRAIN_DATASET = provider.FrustumDataset(
    npoints=NUM_POINT,
    split='train',
    overwritten_data_path=
    '/data/ssd/public/jlliu/frustum-pointnets/avod_prop/frustum_caronly_train%s.pickle'
    % FLAGS.pickle,
    rotate_to_center=True,
    random_flip=False,
    random_shift=True,
    extra_feature=True)
TEST_DATASET = provider.FrustumDataset(
    npoints=NUM_POINT,
    split='val',
    overwritten_data_path=
    '/data/ssd/public/jlliu/frustum-pointnets/avod_prop/frustum_caronly_val%s.pickle'
    % FLAGS.pickle,
    rotate_to_center=True,
    extra_feature=True)


# NOTE(review): truncated — the body of this function is outside the excerpt.
def log_string(out_str):
Example #7
0
# Batch-norm decay schedule parameters.
BN_DECAY_DECAY_RATE = 0.5
BN_DECAY_DECAY_STEP = float(DECAY_STEP)
BN_DECAY_CLIP = 0.99

# Load Frustum Datasets.
# Choose between single- and two-frustum pickles of the car/ped/cyclist split.
if FLAGS.two_frustum:
    data_path_train = os.path.join(ROOT_DIR,'kitti/'+FLAGS.frustum_folder+'/two_frustum_carpedcyc_train.pickle')
    data_path_val   = os.path.join(ROOT_DIR,'kitti/'+FLAGS.frustum_folder+'/two_frustum_carpedcyc_val.pickle')
else:
    data_path_train = os.path.join(ROOT_DIR,'kitti/'+FLAGS.frustum_folder+'/frustum_carpedcyc_train.pickle')
    data_path_val   = os.path.join(ROOT_DIR,'kitti/'+FLAGS.frustum_folder+'/frustum_carpedcyc_val.pickle')

print("Loading train frustums from: ",data_path_train)
print("Loading val   frustums from: ",data_path_val)

# Train split gets flip/shift augmentation; val split does not.
TRAIN_DATASET = provider.FrustumDataset(overwritten_data_path= data_path_train,npoints=NUM_POINT, split='train',rotate_to_center=True, random_flip=True, random_shift=True, one_hot=True)
TEST_DATASET = provider.FrustumDataset(overwritten_data_path= data_path_val,npoints=NUM_POINT, split='val',rotate_to_center=True, one_hot=True)

def log_string(out_str):
    """Mirror *out_str* to both the persistent log file and the console."""
    LOG_FOUT.write('%s\n' % out_str)
    LOG_FOUT.flush()
    print(out_str)

def get_learning_rate(batch):
    # Exponentially-decayed learning rate; `batch` is the global step counter.
    learning_rate = tf.train.exponential_decay(
                        BASE_LEARNING_RATE,  # Base learning rate.
                        batch * BATCH_SIZE,  # Current index into the dataset.
                        DECAY_STEP,          # Decay step.
                        DECAY_RATE,          # Decay rate.
                        staircase=True)
    # NOTE(review): 'learing_rate' is a typo — the clipped value is assigned
    # to a dead variable, so the 1e-5 floor never takes effect. (The function
    # is truncated here; the return statement is outside the excerpt.)
    learing_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE!
# NOTE(review): fragment — the `if` matching this `else` is outside the
# excerpt.
else:
    lstm_parameters['flags']['one_branch'] = False
# Propagate LSTM ablation/configuration switches from the command line.
lstm_parameters['flags']['only_whl'] = FLAGS.only_whl
lstm_parameters['flags']['temp_attention'] = FLAGS.temp_attention
lstm_parameters['flags']['add_center'] = FLAGS.add_center
lstm_parameters['flags']['output_attention'] = FLAGS.output_attention
lstm_parameters['flags']['time_indication'] = FLAGS.time_indication
lstm_parameters['flags']['dropout'] = FLAGS.dropout
lstm_parameters['flags']['random_time_sampling'] = False
lstm_parameters['random_n'] = FLAGS.random_n
# Load Frustum Datasets.
if VKITTI:
    # Load Virtual KITTI dataset
    overwritten_val_data_path = os.path.join(ROOT_DIR, 'kitti/frustum_caronly_vkitti_val.pickle')
    TEST_DATASET = provider.FrustumDataset(npoints=NUM_POINT, split='val', 
                                           rotate_to_center=True, one_hot=True,
                                           overwritten_data_path=overwritten_val_data_path,
                                           tracks=FLAGS.tracking,tau=FLAGS.tau,feat_len=feat_len)
elif TRACKING:
    # Load KITTI tracking dataset
    if FLAGS.from_rgb_detection:
        overwritten_val_data_path = os.path.join(ROOT_DIR, 'kitti/tracking_val_rgb_detection.pickle')
    else:
        overwritten_val_data_path = os.path.join(ROOT_DIR, 'kitti/frustum_carpedcyc_tracking_val.pickle')
    # NOTE(review): overwritten_val_data_path computed just above is unused —
    # the dataset reads FLAGS.data_path instead; confirm which is intended.
    TEST_DATASET = provider.FrustumDataset(npoints=NUM_POINT, split='val', 
                                           rotate_to_center=True, one_hot=True,
                                           from_rgb_detection=FLAGS.from_rgb_detection,
                                           overwritten_data_path=FLAGS.data_path,
                                           tracks=FLAGS.tracking,tau=FLAGS.tau,feat_len=feat_len)
else:
    # Load KITTI dataset
    # NOTE(review): truncated — remaining arguments are outside the excerpt.
    TEST_DATASET = provider.FrustumDataset(npoints=NUM_POINT, split='val',
# Training hyper-parameters pulled from the command line.
BASE_LEARNING_RATE = FLAGS.learning_rate
GPU_INDEX = FLAGS.gpu
MOMENTUM = FLAGS.momentum
OPTIMIZER = FLAGS.optimizer
DECAY_STEP = FLAGS.decay_step
DECAY_RATE = FLAGS.decay_rate
NUM_CHANNEL = 3 if FLAGS.no_intensity else 4  # point feature channel
NUM_CLASSES = 2  # segmentation has two classes
# NOTE(review): the triple-quoted string below disables the enclosed setup
# code (model import, code backup, BN constants) — a block-commented region,
# not a docstring in use.
"""
MODEL = importlib.import_module(FLAGS.model) # import network module
MODEL_FILE = os.path.join(ROOT_DIR, 'models', FLAGS.model+'.py')
LOG_DIR = FLAGS.log_dir
if not os.path.exists(LOG_DIR): os.mkdir(LOG_DIR)
os.system('cp %s %s' % (MODEL_FILE, LOG_DIR)) # bkp of model def
os.system('cp %s %s' % (os.path.join(BASE_DIR, 'train.py'), LOG_DIR))
LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
LOG_FOUT.write(str(FLAGS)+'\n')

BN_INIT_DECAY = 0.5
BN_DECAY_DECAY_RATE = 0.5
BN_DECAY_DECAY_STEP = float(DECAY_STEP)
BN_DECAY_CLIP = 0.99
"""
# Load Frustum Datasets. Use default data paths.
#TRAIN_DATASET = provider.FrustumDataset(npoints=NUM_POINT, split='train',
#    rotate_to_center=True, random_flip=False, random_shift=True, one_hot=True)
TEST_DATASET = provider.FrustumDataset(npoints=NUM_POINT,
                                       split='val',
                                       rotate_to_center=True,
                                       one_hot=True)
# NOTE(review): fragment boundary — this `else` belongs to an `if`/`elif`
# chain outside the excerpt.
else:
    print("Wrong model parameter.")
    exit(0)

# 'fusion' models also take image input — presumably RGB crops loaded by the
# dataset; verify in the provider module.
if 'fusion' in MODEL_FILE:
    with_image = True
else:
    with_image = False

provider = import_from_file(DATA_FILE)

TRAIN_DATASET = provider.FrustumDataset(npoints=NUM_POINT,
                                        split=TRAIN_SETS,
                                        rotate_to_center=True,
                                        random_flip=True,
                                        random_shift=True,
                                        one_hot=True,
                                        overwritten_data_path=TRAIN_FILE,
                                        gen_ref=gen_ref,
                                        with_image=with_image)
TEST_DATASET = provider.FrustumDataset(npoints=NUM_POINT,
                                       split=TEST_SETS,
                                       rotate_to_center=True,
                                       one_hot=True,
                                       overwritten_data_path=TEST_FILE,
                                       gen_ref=gen_ref,
                                       with_image=with_image)
# NOTE(review): truncated — the remaining DataLoader arguments are outside
# the excerpt.
train_dataloader = DataLoader(TRAIN_DATASET,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
# Back up the model definition and this script, then open the mistakes log.
if not os.path.exists(LOG_DIR): os.mkdir(LOG_DIR)
os.system('cp %s %s' % (MODEL_FILE, LOG_DIR)) # bkp of model def
os.system('cp %s %s' % (os.path.join(BASE_DIR, 'extract_mistakes.py'), LOG_DIR))
LOG_FOUT = open(os.path.join(LOG_DIR, 'log_extract_mistakes.txt'), 'w')
LOG_FOUT.write(str(FLAGS)+'\n')


# Load Frustum Datasets.
# Validation pickle; two-frustum variant selectable from the command line.
if FLAGS.two_frustum:
    data_path_val   = os.path.join(ROOT_DIR,'kitti/'+FLAGS.frustum_folder+'/two_frustum_carpedcyc_val.pickle')
else:
    data_path_val   = os.path.join(ROOT_DIR,'kitti/'+FLAGS.frustum_folder+'/frustum_carpedcyc_val.pickle')

print("Loading val frustums from: ",data_path_val)

TEST_DATASET = provider.FrustumDataset(overwritten_data_path= data_path_val,npoints=NUM_POINT, split='val',rotate_to_center=True, one_hot=True)

def log_string(out_str):
    """Record *out_str* in the mistakes log and echo it on stdout."""
    # Write-then-flush so the log survives an abrupt termination.
    LOG_FOUT.write(out_str + "\n")
    LOG_FOUT.flush()
    print(out_str)

def main():
    ''' Main function for evaluation and determining mistakes. '''
    # NOTE(review): truncated — only the placeholder construction is visible;
    # the rest of the graph/session code lies outside the excerpt.
    with tf.Graph().as_default():
        with tf.device('/gpu:'+str(GPU_INDEX)):
            # Placeholders for the frustum point cloud, one-hot class vector,
            # and all segmentation/box-regression labels.
            pointclouds_pl, one_hot_vec_pl, labels_pl, centers_pl, \
            heading_class_label_pl, heading_residual_label_pl, \
            size_class_label_pl, size_residual_label_pl = \
                MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
if __name__ == '__main__':

    # Load Frustum Datasets.
    # gen_ref presumably toggles reference generation needed by the convnet
    # variant — confirm in the provider module.
    if 'frustum_pointnet' in MODEL_FILE:
        gen_ref = False
    elif 'frustum_convnet' in MODEL_FILE:
        gen_ref = True
    else:
        print("Wrong model parameter.")
        exit(0)

    provider = import_from_file(DATA_FILE)

    TEST_DATASET = provider.FrustumDataset(npoints=NUM_POINT,
                                           split=TEST_SETS,
                                           rotate_to_center=True,
                                           one_hot=True,
                                           overwritten_data_path=TEST_FILE,
                                           gen_ref=gen_ref)

    # Evaluation loader: fixed order (shuffle=False) for reproducible output.
    test_dataloader = DataLoader(TEST_DATASET,
                                 batch_size=TEST_BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=NUM_WORKERS,
                                 pin_memory=True)
    # Model
    if 'frustum_pointnets_v1' in MODEL_FILE:
        from frustum_pointnets_v1 import FrustumPointNetv1

        model = FrustumPointNetv1(n_classes=NUM_CLASSES,
                                  n_channel=NUM_CHANNEL).cuda()
    # NOTE(review): truncated — this branch's body is outside the excerpt.
    elif 'frustum_convnet_v1' in MODEL_FILE:
Example #13
0
        ROOT_DIR, 'kitti/frustum_caronly_vkitti_train.pickle')
    # NOTE(review): the line above is the tail of a path-join whose opening is
    # outside this excerpt.
    overwritten_val_data_path = os.path.join(
        ROOT_DIR, 'kitti/frustum_caronly_vkitti_val.pickle')

    # The GROUND variant swaps in the ground-prefixed pickles for both splits.
    if GROUND:
        overwritten_train_data_path = os.path.join(
            ROOT_DIR, 'kitti/ground_caronly_vkitti_train.pickle')
        overwritten_val_data_path = os.path.join(
            ROOT_DIR, 'kitti/ground_caronly_vkitti_val.pickle')

    # Train split gets flip/shift augmentation; val split does not.
    TRAIN_DATASET = provider.FrustumDataset(
        npoints=NUM_POINT,
        split='train',
        rotate_to_center=True,
        random_flip=True,
        random_shift=True,
        one_hot=True,
        overwritten_data_path=overwritten_train_data_path,
        tracks=FLAGS.tracks,
        tau=FLAGS.tau,
        feat_len=feat_len)
    TEST_DATASET = provider.FrustumDataset(
        npoints=NUM_POINT,
        split='val',
        rotate_to_center=True,
        one_hot=True,
        overwritten_data_path=overwritten_val_data_path,
        tracks=FLAGS.tracks,
        tau=FLAGS.tau,
        feat_len=feat_len)
# NOTE(review): truncated — the TRACKING branch body is outside the excerpt.
elif TRACKING:
Example #14
0
# Back up the training script and open the run log.
os.system('cp %s %s' % (os.path.join(BASE_DIR, 'train.py'), LOG_DIR))
LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
LOG_FOUT.write(str(FLAGS) + '\n')

# Batch-norm decay schedule parameters.
BN_INIT_DECAY = 0.5
BN_DECAY_DECAY_RATE = 0.5
BN_DECAY_DECAY_STEP = float(DECAY_STEP)
BN_DECAY_CLIP = 0.99

OUTPUT_FILE = os.path.join(LOG_DIR, 'results/')

# Load Frustum Datasets. Use default data paths.
# NOTE(review): the train split passes res=0 (int) while the eval splits pass
# "224"/"704" (str) — confirm the provider accepts both forms.
TRAIN_DATASET = provider.FrustumDataset(npoints=NUM_POINT,
                                        database='KITTI',
                                        split='train',
                                        res=0,
                                        rotate_to_center=True,
                                        random_flip=False,
                                        random_shift=True,
                                        one_hot=True)
EVAL_DATASET_224 = provider.FrustumDataset(npoints=NUM_POINT,
                                           database='KITTI',
                                           split='val',
                                           res="224",
                                           rotate_to_center=True,
                                           one_hot=True)
EVAL_DATASET_704 = provider.FrustumDataset(npoints=NUM_POINT,
                                           database='KITTI',
                                           split='val',
                                           res="704",
                                           rotate_to_center=True,
                                           one_hot=True)