def get_config():
    """Build the run configuration from a JSON file plus command-line flags.

    The command line is parsed first, then the JSON file named by
    ``--config-path`` is loaded and every parsed argument (including ones left
    at their argparse default) overwrites the JSON key of the same name.
    The GPU id list is then normalised and validated, the compute device is
    chosen, and the input/target transform pipelines are attached.

    Returns:
        dict: the merged configuration. In addition to the JSON keys and the
        argparse ``dest`` names it contains ``'device'`` (``'cuda'`` or
        ``'cpu'``), ``'input_transform'`` and ``'target_transform'``.

    Exits via ``parser.error`` (``SystemExit``) if the config file does not
    exist or, when CUDA is available, a requested GPU id is out of range.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--mode', type=str, choices=['train', 'predict'], required=True)
    parser.add_argument('--config-path', type=str, default='config/cfg.json')
    parser.add_argument('--gpu-ids', type=int, nargs='+', default=0, dest='gpu_ids')
    parser.add_argument('--state', type=int, default=1, dest='state')
    parser.add_argument('-e', '--epochs', type=int, default=5,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', type=int, default=2,
                        help='Batch size', dest='batch_size')
    parser.add_argument('-l', '--learning-rate', type=float, default=0.1,
                        help='Learning rate', dest='lr')
    parser.add_argument('-p', '--port', type=int, default=10001,
                        help='Visualization port', dest='port')
    parser.add_argument('-w', '--worker-num', type=int, default=8,
                        help='Dataloader worker number', dest='num_workers')
    parser.add_argument('-c', '--class-num', type=int, default=2,
                        help='class number', dest='class_num')
    parser.add_argument('-v', '--valid-percent', type=get_range_limited_float_type(0, 100), default=10.0,
                        help='Percent of the data that is used as validation (0-100)', dest='valid_percent')

    args = parser.parse_args()

    # Validate explicitly: an ``assert`` here would be stripped under
    # ``python -O`` and let a missing path fall through to ``open``.
    if not os.path.exists(args.config_path):
        parser.error('config json not exists')
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
    with open(args.config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Command-line values (argparse defaults included) override the JSON file.
    config.update(vars(args))

    # ``--gpu-ids`` defaults to a bare int (0) rather than a list.
    if isinstance(config['gpu_ids'], int):
        config['gpu_ids'] = [config['gpu_ids']]
    # De-duplicate while keeping the user's order; ``list(set(...))`` would
    # reorder e.g. [1, 0] -> [0, 1], which matters if callers treat the first
    # id as the primary device.
    config['gpu_ids'] = list(dict.fromkeys(config['gpu_ids']))

    config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
    if config['device'] == 'cuda':
        gpu_num = torch.cuda.device_count()
        if not config['gpu_ids']:
            parser.error('unexpected gpu number')
        for gpu_id in config['gpu_ids']:
            if not 0 <= gpu_id < gpu_num:
                parser.error('invalid gpu id input')

    config['input_transform'] = Compose([
        # CenterCrop(256),
        ToTensor(),
        # The values below are ImageNet statistics; replace them with your
        # dataset's mean and std before enabling normalisation.
        # Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

    config['target_transform'] = Compose([
        # CenterCrop(256),
        ToLabel(),
        # Relabel(255, 21),
    ])

    return config
Пример #2
0
    # Run inference over the sequence at three scales and report throughput.
    tt = time.time()
    all_E = Infer_MO(all_F,
                     all_M,
                     num_frames,
                     num_objects,
                     scales=[0.5, 0.75, 1.0])
    print('{} | num_objects: {}, FPS: {}'.format(
        seq_name, num_objects, num_frames / (time.time() - tt)))

    # Save results for quantitative eval ######################
    folder = 'results/MO' if MO else 'results/SO'
    test_path = os.path.join(folder, seq_name)
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(test_path, exist_ok=True)

    # Padding amounts are the same for every frame, so unpack them once.
    # NOTE(review): the [0] indexing suggests batch-collated values — confirm.
    (lh, uh), (lw, uw) = info['pad']
    top, bottom = int(lh[0]), int(uh[0])
    left, right = int(lw[0]), int(uw[0])

    for f in range(num_frames):
        E = all_E[0, :, f].numpy()
        # make hard label
        E = ToLabel(E)

        # Crop the padding with explicit end indices: a stop of ``-0`` equals 0
        # and would yield an empty array whenever a side had no padding.
        E = E[top:E.shape[0] - bottom, left:E.shape[1] - right]

        img_E = Image.fromarray(E)
        img_E.putpalette(palette)
        img_E.save(os.path.join(test_path, '{:05d}.png'.format(f)))
Пример #3
0
    os.makedirs(os.path.join(args.name + '_results', 'Colorization'))
# Create <name>_results/log if needed; exist_ok=True replaces the separate
# os.path.exists() check and its check-then-create race.
os.makedirs(os.path.join(args.name + '_results', 'log'), exist_ok=True)

# Loss plot
logger = Logger(os.path.join(args.name + '_results', 'log'))

# Input images: centre-crop to args.input_size, convert to a tensor, then
# normalise every channel with mean 0.5 / std 0.5.
image_transform = transforms.Compose(
    [
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Segmentation targets: same centre crop, then the project's ToLabel and
# Relabel(255, 21) steps (presumably remapping label 255 to 21 — confirm).
target_transform = transforms.Compose(
    [
        transforms.CenterCrop(args.input_size),
        ToLabel(),
        Relabel(255, 21),
    ]
)

# Source- and target-domain training loaders differ only in dataset root;
# they are built in order (source first, then target).
train_loader_src, train_loader_tgt = (
    DataLoader(
        VOC12(root, image_transform, target_transform),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
    )
    for root in (args.src_data, args.tgt_data)
)
test_loader_src = data_load('./data/test_data/',
                            'test',