コード例 #1
0
def parse_args(description, argv=None):
    """Parse command-line options and merge them into the global ``cfg``.

    Every ``--cfg`` file is loaded first, in the order it was given. The
    individual flags (``--batch``, ``--epoch``, ``--model``, ``--dataset``)
    are then applied afterwards via ``cfg_from_list``. When both
    ``cfg.MODEL_NAME`` and ``cfg.DATASET_NAME`` are non-empty,
    ``cfg.OUTPUT_PATH`` is set from ``get_output_dir``. The output directory
    is then created if it does not already exist.

    Args:
        description: Program description shown in the ``--help`` output.
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace: the parsed options (``cfg_file``, ``batch_size``,
        ``epoch``, ``model``, ``dataset``); options not given are ``None``.

    Raises:
        ValueError: if ``cfg.OUTPUT_PATH`` is still empty after the config
            files and overrides have been applied.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--cfg',
                        dest='cfg_file',
                        action='append',  # may be repeated; loaded in order
                        help='an optional config file',
                        default=None,
                        type=str)
    parser.add_argument('--batch',
                        dest='batch_size',
                        help='batch size',
                        default=None,
                        type=int)
    parser.add_argument('--epoch',
                        dest='epoch',
                        help='epoch number',
                        default=None,
                        type=int)
    parser.add_argument('--model',
                        dest='model',
                        help='model name',
                        default=None,
                        type=str)
    parser.add_argument('--dataset',
                        dest='dataset',
                        help='dataset name',
                        default=None,
                        type=str)
    args = parser.parse_args(argv)

    # Load config files first so the explicit flags below are applied on top.
    if args.cfg_file is not None:
        for cfg_path in args.cfg_file:
            cfg_from_file(cfg_path)

    # Apply individual command-line overrides.
    if args.batch_size is not None:
        cfg_from_list(['BATCH_SIZE', args.batch_size])
    if args.epoch is not None:
        # A single --epoch value sets the train, eval and visual epoch keys.
        cfg_from_list([
            'TRAIN.START_EPOCH', args.epoch, 'EVAL.EPOCH', args.epoch,
            'VISUAL.EPOCH', args.epoch
        ])
    if args.model is not None:
        cfg_from_list(['MODEL_NAME', args.model])
    if args.dataset is not None:
        cfg_from_list(['DATASET_NAME', args.dataset])

    if cfg.MODEL_NAME and cfg.DATASET_NAME:
        outp_path = get_output_dir(cfg.MODEL_NAME, cfg.DATASET_NAME)
        cfg_from_list(['OUTPUT_PATH', outp_path])
    # Explicit check rather than ``assert``: asserts are stripped under -O.
    if not cfg.OUTPUT_PATH:
        raise ValueError('Invalid OUTPUT_PATH! Make sure model name and '
                         'dataset name are specified.')
    # exist_ok=True replaces the racy exists()-then-mkdir() pair.
    Path(cfg.OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

    return args
コード例 #2
0
def test_parse_args(description, argv=None):
    """Parse options for the image-pair test script and prepare ``cfg``.

    The script takes two images plus one line file per image, a model
    parameters path and an output path. ``--cfg`` takes a single config
    file path (default ``experiments/vgg16_scannet.yaml``), which is loaded
    into ``cfg``. When both ``cfg.MODEL_NAME`` and ``cfg.DATASET_NAME`` are
    non-empty, ``cfg.OUTPUT_PATH`` is set from ``get_output_dir``, and that
    directory is created if missing.

    NOTE: despite the ``test_`` prefix this is not a pytest test. It needs a
    ``description`` argument and parses the command line.

    NOTE: ``--output_path`` is only returned in the namespace. The directory
    created here is ``cfg.OUTPUT_PATH``, not ``args.output_path``.

    Args:
        description: Program description shown in the ``--help`` output.
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``cfg_file``, ``model_path``, ``left_img``,
        ``right_img``, ``left_lines``, ``right_lines`` and ``output_path``.

    Raises:
        ValueError: if ``cfg.OUTPUT_PATH`` is still empty after loading the
            config file.
    """
    # Use the caller-supplied description in the --help output.
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--cfg',
                        dest='cfg_file',
                        type=str,
                        help='an optional config file',
                        default="experiments/vgg16_scannet.yaml")
    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='path to the trained model parameters',
        default='output/vgg16_linematching_wire/params/params_0010.pt',
        type=str)
    parser.add_argument('--left_img',
                        dest='left_img',
                        help='left image name',
                        default='test_data/000800.jpg',
                        type=str)
    parser.add_argument('--right_img',
                        dest='right_img',
                        help='right image name',
                        default='test_data/000900.jpg',
                        type=str)
    parser.add_argument('--left_lines',
                        dest='left_lines',
                        help='left lines name',
                        default='test_data/000800.txt',
                        type=str)
    parser.add_argument('--right_lines',
                        dest='right_lines',
                        help='right lines name',
                        default='test_data/000900.txt',
                        type=str)
    parser.add_argument('--output_path',
                        dest='output_path',
                        help='output path name',
                        default='./test_data/',
                        type=str)
    args = parser.parse_args(argv)

    # load cfg from file
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    if cfg.MODEL_NAME and cfg.DATASET_NAME:
        outp_path = get_output_dir(cfg.MODEL_NAME, cfg.DATASET_NAME)
        cfg_from_list(['OUTPUT_PATH', outp_path])
    # Explicit check rather than ``assert``: asserts are stripped under -O.
    if not cfg.OUTPUT_PATH:
        raise ValueError('Invalid OUTPUT_PATH! Make sure model name and '
                         'dataset name are specified.')
    # exist_ok=True replaces the racy exists()-then-mkdir() pair.
    Path(cfg.OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

    return args
コード例 #3
0
    # Set-up for the evaluation loop: thresholds, result containers, the
    # output directory and a single-image, non-shuffled data loader.

    # Start time; not used in the visible lines (presumably for timing later).
    start = time.time()
    # NOTE(review): not used in the visible lines; presumably caps how many
    # detections are kept per image later on.
    max_per_image = 100

    vis = args.vis

    # Threshold is 0.05 when visualisation is on, 0.0 otherwise.
    # Presumably a detection-score cut-off applied later — confirm.
    if vis:
        thresh = 0.05
    else:
        thresh = 0.0

    save_name = 'faster_rcnn_10'
    num_images = len(imdb.image_index)
    # Nested lists indexed [class][image], each starting as an empty list.
    # NOTE(review): `xrange` is a Python 2 builtin; under Python 3 this raises
    # NameError unless a compatibility shim provides it — confirm the runtime.
    all_boxes = [[[] for _ in xrange(num_images)]
                 for _ in xrange(imdb.num_classes)]

    output_dir = get_output_dir(imdb, save_name)
    # Dataset built in non-training mode without normalisation.
    # NOTE(review): the positional `1` is presumably a batch size — confirm
    # against roibatchLoader's signature.
    dataset = roibatchLoader(roidb,
                             ratio_list,
                             ratio_index,
                             1,
                             imdb.num_classes,
                             training=False,
                             normalize=False)
    # One sample per batch in a fixed order (shuffle=False), loaded in the
    # main process (num_workers=0).
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=0,
                                             pin_memory=True)

    # Explicit iterator; presumably advanced with next() in the loop below.
    data_iter = iter(dataloader)
コード例 #4
0
from networks.model_train import model_train
from networks.train import get_training_roidb, train_net
from utils.config import cfg_from_file, get_output_dir, get_log_dir
from utils.datasets.factory import get_imdb
from utils.config import cfg

# Training entry point: load the config and dataset, build the network and
# start training. The config path and dataset name are hard-coded.
if __name__ == '__main__':
    cfg_from_file('tools/text.yml')
    print('Using config:')
    # NOTE(review): `pprint` does not appear in the visible import lines;
    # confirm it is imported elsewhere, otherwise this raises NameError.
    pprint.pprint(cfg)
    imdb = get_imdb('voc_2007_trainval')
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    roidb = get_training_roidb(imdb)

    # Output and log directories are both derived from the dataset object.
    output_dir = get_output_dir(imdb, None)
    log_dir = get_log_dir(imdb)
    print('Output will be saved to `{:s}`'.format(output_dir))
    print('Logs will be saved to `{:s}`'.format(log_dir))

    # NOTE(review): device_name is only printed here; it is not among the
    # visible train_net arguments — confirm whether it is meant to be used.
    device_name = '/gpu:0'
    print(device_name)

    network = model_train()

    # The call below continues past the end of this snippet.
    train_net(network,
              imdb,
              roidb,
              output_dir=output_dir,
              log_dir=log_dir,
              pretrained_model='data/pretrain/VGG_imagenet.npy',
コード例 #5
0
ファイル: train_cls.py プロジェクト: stoneyang/deep_share
    # NOTE(review): this snippet uses Python 2 syntax (`print` statements) and
    # will not run under Python 3 as written.

    # set up caffe: use the GPU given by --gpu_id if set, otherwise the CPU
    if args.gpu_id is not None:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu_id)
    else:
        caffe.set_mode_cpu()

    traindb = get_imdb(args.traindb_name)
    valdb = get_imdb(args.valdb_name)
    print 'Loaded dataset `{:s}` for training'.format(traindb.name)
    print 'Loaded dataset `{:s}` for validation'.format(valdb.name)

    # Split name -> dataset object.
    imdb = {'train': traindb, 'val': valdb}

    # The output directory is derived from the training set only.
    output_dir = get_output_dir(traindb, None)
    print 'Output will be saved to `{:s}`'.format(output_dir)

    # parse class_id if necessary: --cls_id is decoded as JSON when given;
    # otherwise every class index of the training set is used (in Python 2,
    # `range` returns a list).
    if args.cls_id is not None:
        class_id = json.loads(args.cls_id)
    else:
        class_id = range(imdb['train'].num_classes)

    # parse cut_points if necessary: decoded as JSON when --cut_points is
    # given, otherwise left as None
    cut_points = None
    if args.cut_points is not None:
        cut_points = json.loads(args.cut_points)

    # load pretrained parameters (the call continues past this snippet)
    pretrained_params = PretrainedParameter(