def main(output, model, source, lr, batch_size_per_gpu):
    """Train a segmentation network jointly with auxiliary tasks.

    Builds the source and target datasets, a pretrained backbone wrapped in
    ``TaskNet``, one ``Segmentation`` task plus ``Rotation`` and
    ``ContinuousGridRegression`` tasks at crop sizes 100/200/400, and runs
    them through ``TaskTrainer``.

    Args:
        output: Forwarded to ``TaskTrainer`` as its first argument
            (presumably the output directory -- confirm against
            ``TaskTrainer``).
        model: Backbone key: ``'drn'``, ``'r38'`` or ``'deeplab'``.
        source: Source dataset name; several names may be joined with
            ``'+'`` to train on their concatenation.
        lr: Learning rate forwarded to ``TaskTrainer``.
        batch_size_per_gpu: Per-GPU batch size; multiplied by the number of
            visible CUDA devices to get the segmentation batch size.

    Raises:
        RuntimeError: If no CUDA device is visible.
        KeyError: If ``model`` is not one of the known backbone keys.
    """
    # Fail fast: with zero GPUs the batch size would silently become 0 and
    # the unconditional ``.cuda()`` below could not succeed anyway.
    num_gpus = torch.cuda.device_count()
    if num_gpus < 1:
        raise RuntimeError(
            'no CUDA device visible; training requires at least one GPU')

    config_logging()
    batch_size = batch_size_per_gpu * num_gpus
    iterations = 16000
    step = 16000

    # datasets
    source_dataset, source_val_dataset = _load_source_datasets(source)
    source_dataset = RandomRescaleWrapper(
        source_dataset, min_scale=0.5, max_scale=2.0, w=2048, h=1024)
    target_dataset = RandomRescaleWrapper(
        get_dataset('cityscapes'), min_scale=0.5, max_scale=2.0)
    source_val_dataset = RandomSubset(source_val_dataset)
    target_val_dataset = get_dataset('cityscapes', split='val')

    # net
    num_classes = source_dataset.num_classes
    net = TaskNet(_build_backbone(model, num_classes))

    # tasks
    tasks = _build_tasks(net, source_dataset, target_dataset,
                         source_val_dataset, target_val_dataset, batch_size)

    # The tasks were constructed with the bare network; re-point them at the
    # DataParallel wrapper so they all use the multi-GPU module.
    net = torch.nn.DataParallel(net).cuda()
    for task in tasks:
        task.net = net

    trainer = TaskTrainer(output, net, tasks, iterations=iterations, lr=lr,
                          momentum=0.9, step_lr=step)
    trainer.run()


def _load_source_datasets(source):
    """Return ``(train, val)`` source datasets for a possibly '+'-joined name."""
    if '+' not in source:
        return (get_dataset(source, split='all'),
                get_dataset(source, split='val'))

    names = source.split('+')
    # ConcatDataset does not carry ``num_classes`` over from its members, so
    # it is set by hand (19 -- presumably the Cityscapes label set; confirm).
    num_classes = 19
    train_set = ConcatDataset(
        [get_dataset(name, split='all') for name in names])
    train_set.num_classes = num_classes
    val_set = ConcatDataset(
        [get_dataset(name, split='val') for name in names])
    val_set.num_classes = num_classes
    return train_set, val_set


def _build_backbone(model, num_classes):
    """Return the pretrained backbone selected by ``model``; KeyError if unknown."""
    if model == 'drn':
        return drn_c_26(pretrained=True, finetune=True,
                        num_classes=num_classes, out_map=True,
                        out_middle=True)

    if model == 'r38':
        backbone = resnet38d.Net(
            num_classes=num_classes,
            freeze=False,
            pretrained=True,
        )
        # Overlay the checkpoint on the freshly initialised weights so keys
        # missing from the checkpoint keep their current values.
        # map_location='cpu' keeps loading independent of the device the
        # checkpoint was saved from; load_state_dict copies into the params.
        state_dict = backbone.state_dict()
        state_dict.update(torch.load('gta_rna-a1_cls19_s8_ep-0000.pth',
                                     map_location='cpu'))
        backbone.load_state_dict(state_dict)
        # Set before train(); presumably train() consults ``freeze`` --
        # confirm in resnet38d.
        backbone.freeze = True
        backbone.train()
        return backbone

    if model == 'deeplab':
        backbone = deeplab_resnet101(
            num_classes=num_classes,
            pretrained=True,
            freeze=True,
        )
        backbone.load_state_dict(
            torch.load(
                'results/deeplab-gta5-sourceonly/snapshot/net-iter016000.pth',
                map_location='cpu'))
        backbone.train()
        return backbone

    raise KeyError(model)


def _build_tasks(net, source_dataset, target_dataset, source_val_dataset,
                 target_val_dataset, batch_size):
    """Return the segmentation task followed by the auxiliary tasks."""
    tasks = [
        Segmentation(net, source_dataset, source_val_dataset,
                     target_val_dataset, batch_size=batch_size,
                     crop_size=512),
    ]
    # Every auxiliary task uses half the segmentation batch size and is
    # instantiated once per crop size; order matches the original list
    # (all rotations first, then all grid regressions).
    aux_batch_size = batch_size // 2
    for task_cls, prefix in ((Rotation, 'rotation'),
                             (ContinuousGridRegression, 'gridregression')):
        for crop_size in (100, 200, 400):
            tasks.append(
                task_cls(net, source_dataset, target_dataset,
                         source_val_dataset, target_val_dataset,
                         batch_size=aux_batch_size, crop_size=crop_size,
                         name='{}{}'.format(prefix, crop_size)))
    return tasks