def load_model(cfg_file, testing=False):
    """Build a PointNetDet model from a config file and move it to the GPU.

    Args:
        cfg_file: Path to a YAML config file merged into the global ``cfg``.
        testing: When True, also load pretrained weights from
            ``cfg.TEST.WEIGHTS``.

    Returns:
        The model on CUDA, with weights loaded when ``testing`` is True.

    Raises:
        FileNotFoundError: If ``testing`` is True and ``cfg.TEST.WEIGHTS``
            does not point to an existing file.
    """
    merge_cfg_from_file(cfg_file)
    assert_and_infer_cfg()

    model_def = import_from_file(cfg.MODEL.FILE)
    model_def = model_def.PointNetDet

    # 4th input channel is the optional extra per-point feature
    input_channels = 3 if not cfg.DATA.WITH_EXTRA_FEAT else 4
    dataset_name = cfg.DATA.DATASET_NAME
    assert dataset_name in DATASET_INFO
    dataset_category_info = DATASET_INFO[dataset_name]
    NUM_VEC = len(
        dataset_category_info.CLASSES)  # rgb category as extra feature vector
    NUM_CLASSES = len(cfg.MODEL.CLASSES)

    model = model_def(input_channels, num_vec=NUM_VEC, num_classes=NUM_CLASSES)
    model = model.cuda()

    if testing:
        if not os.path.isfile(cfg.TEST.WEIGHTS):
            logging.error("=> no checkpoint found at '{}'".format(
                cfg.TEST.WEIGHTS))
            # was `assert False`: asserts are stripped under `python -O`,
            # so raise a real exception instead
            raise FileNotFoundError(
                "no checkpoint found at '{}'".format(cfg.TEST.WEIGHTS))

        # checkpoint may have been saved on cuda:1; remap to cuda:0
        checkpoint = torch.load(cfg.TEST.WEIGHTS,
                                map_location={'cuda:1': 'cuda:0'})
        if 'state_dict' in checkpoint:
            # Strip the 'module.' prefix that DataParallel adds to key names.
            # The original did an unconditional k[7:], which corrupted keys
            # saved without the prefix; strip only when it is present.
            new_state_dict = OrderedDict(
                (k[len('module.'):] if k.startswith('module.') else k, v)
                for k, v in checkpoint['state_dict'].items())
            model.load_state_dict(new_state_dict)
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                cfg.TEST.WEIGHTS, checkpoint['epoch']))
        else:
            # raw state dict with no wrapping metadata
            model.load_state_dict(checkpoint)
            logging.info("=> loaded checkpoint '{}')".format(
                cfg.TEST.WEIGHTS))

    return model
# Beispiel #2  (paste artifact from a code-example site; the fragment below is truncated)
# 0
                            continue
                        total_dataset_names.append(
                            (dataset_idx, object_i, sample_name))
        else:
            print('Something wrong with data')
            sys.exit()
    return total_dataset_names


if __name__ == '__main__':

    # NOTE(review): this fragment is truncated — the validation script body
    # continues past the log-file setup below but is not visible here.
    set_random_seed()
    args = parse_args()

    # merge the optional YAML config file and command-line overrides into cfg
    if args.cfg_file is not None:
        merge_cfg_from_file(args.cfg_file)

    if args.opts is not None:
        merge_cfg_from_list(args.opts)

    assert_and_infer_cfg()

    # validation outputs go under OUTPUT_DIR/SAVE_SUB_DIR
    SAVE_DIR = os.path.join(cfg.OUTPUT_DIR, cfg.SAVE_SUB_DIR)

    if not os.path.exists(SAVE_DIR):
        os.makedirs(SAVE_DIR)

    # set logger
    # log file name combines the config name with a timestamp
    cfg_name = os.path.basename(args.cfg_file).split('.')[0]
    log_file = '{}_{}_val.log'.format(cfg_name,
                                      time.strftime('%Y-%m-%d-%H-%M'))
from configs.config import merge_cfg_from_list
from configs.config import assert_and_infer_cfg
# from src.frustum_convnet.utils.utils import import_from_file

from pathlib import Path
import sys
import importlib
import shutil
from pyntcloud import PyntCloud
# from src.calib import Calib

from utils.utils import import_from_file

# Module-level setup for the pretrained Waymo car detector (Frustum ConvNet).
use_cam = True  # presumably toggles camera usage elsewhere — not read here; TODO confirm

merge_cfg_from_file('cfgs/det_sample_waymo.yaml')
assert_and_infer_cfg()
fr_weigths_path = 'pretrained_models/car_waymo/model_best.pth'
fr_model_def = import_from_file('models/det_base_onnx.py')
fr_model_def = fr_model_def.PointNetDet
input_channels = 3  # xyz only, no extra per-point features
NUM_VEC = 0  # no one-hot category vector
NUM_CLASSES = 2
fr_model = fr_model_def(input_channels,
                        num_vec=NUM_VEC,
                        num_classes=NUM_CLASSES)
fr_model = torch.nn.DataParallel(fr_model)
if os.path.isfile(fr_weigths_path):

    checkpoint = torch.load(fr_weigths_path)
    # NOTE(review): fragment truncated here — the checkpoint-loading body of
    # this `if` is missing from the visible source.
    if 'state_dict' in checkpoint:
def main():
    """Train PointNetDet as configured by the global ``cfg``.

    Parses command-line arguments, merges config overrides, builds the
    train/val data loaders and the model, then runs the train/validate loop
    for ``cfg.TRAIN.MAX_EPOCH`` epochs.  Checkpoints are written under
    ``cfg.OUTPUT_DIR``: a snapshot every 5th epoch, the best-so-far model
    (``model_best.pth``) and the final model (``model_final.pth``).

    Raises:
        ValueError: If ``cfg.TRAIN.OPTIMIZER`` is neither 'adam' nor 'sgd'.
    """
    # parse arguments
    args = parse_args()

    if args.cfg_file is not None:
        merge_cfg_from_file(args.cfg_file)

    if args.opts is not None:
        merge_cfg_from_list(args.opts)

    assert_and_infer_cfg()

    if not os.path.exists(cfg.OUTPUT_DIR):
        os.makedirs(cfg.OUTPUT_DIR)

    # set logger: log file name combines the config name with a timestamp
    cfg_name = os.path.basename(args.cfg_file).split('.')[0]
    log_file = '{}_{}_train.log'.format(cfg_name,
                                        time.strftime('%Y-%m-%d-%H-%M'))
    log_file = os.path.join(cfg.OUTPUT_DIR, log_file)
    logger = get_logger(log_file)

    logger.info(pprint.pformat(args))
    logger.info('config:\n {}'.format(pprint.pformat(cfg)))

    # set visualize logger (TensorBoard-style), one writer each for train/val
    logger_train = None
    logger_val = None
    if cfg.USE_TFBOARD:
        from utils.logger import Logger
        logger_dir = os.path.join(cfg.OUTPUT_DIR, 'tb_logger', 'train')
        if not os.path.exists(logger_dir):
            os.makedirs(logger_dir)
        logger_train = Logger(logger_dir)

        logger_dir = os.path.join(cfg.OUTPUT_DIR, 'tb_logger', 'val')
        if not os.path.exists(logger_dir):
            os.makedirs(logger_dir)
        logger_val = Logger(logger_dir)

    # import dataset

    set_random_seed()

    # NOTE: the original mixed the root `logging` module with the configured
    # file `logger`; use `logger` consistently so everything reaches the file.
    logger.info(cfg.DATA.FILE)
    dataset_def = import_from_file(cfg.DATA.FILE)
    collate_fn = dataset_def.collate_fn
    dataset_def = dataset_def.ProviderDataset

    # training split: augmented (random flip/shift), batches shuffled
    train_dataset = dataset_def(cfg.DATA.NUM_SAMPLES,
                                split=cfg.TRAIN.DATASET,
                                one_hot=True,
                                random_flip=True,
                                random_shift=True,
                                extend_from_det=cfg.DATA.EXTEND_FROM_DET)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=cfg.TRAIN.BATCH_SIZE,
                                               shuffle=True,
                                               num_workers=cfg.NUM_WORKERS,
                                               pin_memory=True,
                                               drop_last=True,
                                               collate_fn=collate_fn)

    # validation split: no augmentation, keep every sample
    val_dataset = dataset_def(cfg.DATA.NUM_SAMPLES,
                              split=cfg.TEST.DATASET,
                              one_hot=True,
                              random_flip=False,
                              random_shift=False,
                              extend_from_det=cfg.DATA.EXTEND_FROM_DET)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=cfg.TEST.BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=cfg.NUM_WORKERS,
                                             pin_memory=True,
                                             drop_last=False,
                                             collate_fn=collate_fn)

    logger.info('training: sample {} / batch {} '.format(
        len(train_dataset), len(train_loader)))
    logger.info('validation: sample {} / batch {} '.format(
        len(val_dataset), len(val_loader)))

    logger.info(cfg.MODEL.FILE)
    model_def = import_from_file(cfg.MODEL.FILE)
    model_def = model_def.PointNetDet

    # extra per-point features (beyond xyz) extend the input channel count
    input_channels = 3 if not cfg.DATA.WITH_EXTRA_FEAT else cfg.DATA.EXTRA_FEAT_DIM
    # NUM_VEC = 0 if cfg.DATA.CAR_ONLY else 3
    dataset_name = cfg.DATA.DATASET_NAME
    assert dataset_name in DATASET_INFO
    dataset_category_info = DATASET_INFO[dataset_name]
    NUM_VEC = len(
        dataset_category_info.CLASSES)  # rgb category as extra feature vector
    NUM_CLASSES = cfg.MODEL.NUM_CLASSES

    model = model_def(input_channels, num_vec=NUM_VEC, num_classes=NUM_CLASSES)

    logger.info(pprint.pformat(model))

    if cfg.NUM_GPUS > 1:
        model = torch.nn.DataParallel(model)

    model = model.cuda()

    # total parameter count, for the log
    parameters_size = sum(p.numel() for p in model.parameters())

    logger.info('parameters: %d' % parameters_size)

    logger.info('using optimizer method {}'.format(cfg.TRAIN.OPTIMIZER))

    if cfg.TRAIN.OPTIMIZER == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=cfg.TRAIN.BASE_LR,
                               betas=(0.9, 0.999),
                               weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    elif cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=cfg.TRAIN.BASE_LR,
                              momentum=cfg.TRAIN.MOMENTUM,
                              weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        # was `assert False, 'Not support now.'`: asserts vanish under
        # `python -O`, so raise a real exception instead
        raise ValueError(
            'Unsupported optimizer: {}'.format(cfg.TRAIN.OPTIMIZER))

    # miles = [math.ceil(num_epochs*3/8), math.ceil(num_epochs*6/8)]
    # assert isinstance(LR_SETP, list)

    LR_STEPS = cfg.TRAIN.LR_STEPS
    LR_DECAY = cfg.TRAIN.GAMMA

    # multiple milestones -> MultiStepLR; a single value -> periodic StepLR
    if len(LR_STEPS) > 1:
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                      milestones=LR_STEPS,
                                                      gamma=LR_DECAY)
    else:
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                 step_size=LR_STEPS[0],
                                                 gamma=LR_DECAY)

    best_prec1 = 0
    best_epoch = 0
    start_epoch = 0
    # optionally resume from a checkpoint
    if cfg.RESUME:
        if os.path.isfile(cfg.TRAIN.WEIGHTS):
            checkpoint = torch.load(cfg.TRAIN.WEIGHTS)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            best_epoch = checkpoint['best_epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                cfg.TRAIN.WEIGHTS, checkpoint['epoch']))
        else:
            logger.error("=> no checkpoint found at '{}'".format(
                cfg.TRAIN.WEIGHTS))

        # resume from other pretrained model: if the checkpoint had already
        # finished training, restart the counters from scratch
        if start_epoch == cfg.TRAIN.MAX_EPOCH:
            start_epoch = 0
            best_prec1 = 0
            best_epoch = 0

    if cfg.EVAL_MODE:
        validate(val_loader, model, start_epoch, logger_val)
        return

    MAX_EPOCH = cfg.TRAIN.MAX_EPOCH

    for n in range(start_epoch, MAX_EPOCH):

        train(train_loader, model, optimizer, lr_scheduler, n, logger_train)

        # validation IoU against ground truth is the model-selection metric
        ious_gt = validate(val_loader, model, n, logger_val)

        prec1 = ious_gt

        is_best = False
        if prec1 > best_prec1:
            best_prec1 = prec1
            best_epoch = n + 1
            is_best = True
            logger.info(
                'Best model {:04d}, Validation Accuracy {:.6f}'.format(
                    best_epoch, best_prec1))

        # unwrap DataParallel (`.module`) before saving so the checkpoint
        # also loads into a single-GPU model
        save_data = {
            'epoch': n + 1,
            'state_dict': model.state_dict()
            if cfg.NUM_GPUS == 1 else model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_prec1': best_prec1,
            'best_epoch': best_epoch,
        }
        # periodic snapshot every 5 epochs and at the final epoch
        if (n + 1) % 5 == 0 or (n + 1) == MAX_EPOCH:
            torch.save(
                save_data,
                os.path.join(cfg.OUTPUT_DIR, 'model_%04d.pth' % (n + 1)))

        if is_best:
            torch.save(save_data, os.path.join(cfg.OUTPUT_DIR,
                                               'model_best.pth'))

        if (n + 1) == MAX_EPOCH:
            torch.save(save_data,
                       os.path.join(cfg.OUTPUT_DIR, 'model_final.pth'))

    logger.info('Best model {:04d}, Validation Accuracy {:.6f}'.format(
        best_epoch, best_prec1))