Example #1
0
def main():
    """Entry point for distributed training.

    Parses CLI args, loads the YAML experiment config, initializes the
    distributed backend, builds the model / optimizer / data pipeline,
    then either runs a one-shot evaluation (``--evaluate``) or the full
    training loop.  Communicates through module-level globals
    ``args``, ``config`` and ``best_prec1``.
    """
    global args, config, best_prec1
    args = parser.parse_args()

    # safe_load: yaml.load() without an explicit Loader is deprecated
    # since PyYAML 5.1 and can construct arbitrary Python objects from
    # a hostile config file.  Experiment configs here are plain YAML.
    with open(args.config) as f:
        config = yaml.safe_load(f)

    # Only the 'common' section is used; save_path defaults to the
    # directory the config file lives in.
    config = EasyDict(config['common'])
    config.save_path = os.path.dirname(args.config)

    rank, world_size = dist_init()

    # --- create model -----------------------------------------------------
    # Sync-BN process groups: bn_group_size ranks share BN statistics.
    bn_group_size = config.model.kwargs.bn_group_size
    bn_var_mode = config.model.kwargs.get('bn_var_mode', 'L2')
    if bn_group_size == 1:
        bn_group = None
    else:
        assert world_size % bn_group_size == 0
        bn_group = simple_group_split(world_size, rank,
                                      world_size // bn_group_size)

    config.model.kwargs.bn_group = bn_group
    config.model.kwargs.bn_var_mode = (link.syncbnVarMode_t.L1 if bn_var_mode
                                       == 'L1' else link.syncbnVarMode_t.L2)
    model = model_entry(config.model)
    if rank == 0:
        print(model)

    model.cuda()

    # fp16 training is implied by the optimizer type rather than a flag.
    if config.optimizer.type == 'FP16SGD' or config.optimizer.type == 'FusedFP16SGD':
        args.fp16 = True
    else:
        args.fp16 = False

    if args.fp16:
        # if you have modules that must use fp32 parameters, and need fp32 input
        # try use link.fp16.register_float_module(your_module)
        # if you only need fp32 parameters set cast_args=False when call this
        # function, then call link.fp16.init() before call model.half()
        if config.optimizer.get('fp16_normal_bn', False):
            print('using normal bn for fp16')
            link.fp16.register_float_module(link.nn.SyncBatchNorm2d,
                                            cast_args=False)
            link.fp16.register_float_module(torch.nn.BatchNorm2d,
                                            cast_args=False)
            link.fp16.init()
        model.half()

    model = DistModule(model, args.sync)

    # --- create optimizer -------------------------------------------------
    opt_config = config.optimizer
    opt_config.kwargs.lr = config.lr_scheduler.base_lr
    if config.get('no_wd', False):
        # Exclude bias/BN parameters from weight decay.
        param_group, type2num = param_group_no_wd(model)
        opt_config.kwargs.params = param_group
    else:
        opt_config.kwargs.params = model.parameters()

    optimizer = optim_entry(opt_config)

    # --- optionally resume from a checkpoint ------------------------------
    last_iter = -1
    best_prec1 = 0
    if args.load_path:
        if args.recover:
            # Full recovery: restore optimizer state and training position.
            best_prec1, last_iter = load_state(args.load_path,
                                               model,
                                               optimizer=optimizer)
        else:
            # Weights-only load (e.g. fine-tuning from a pretrained model).
            load_state(args.load_path, model)

    cudnn.benchmark = True

    # --- data loading -----------------------------------------------------
    # ImageNet mean/std normalization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # augmentation
    aug = [
        transforms.RandomResizedCrop(config.augmentation.input_size),
        transforms.RandomHorizontalFlip()
    ]

    # Fail fast on unknown augmentation keys (likely config typos).
    for k in config.augmentation.keys():
        assert k in [
            'input_size', 'test_resize', 'rotation', 'colorjitter', 'colorold'
        ]
    rotation = config.augmentation.get('rotation', 0)
    colorjitter = config.augmentation.get('colorjitter', None)
    colorold = config.augmentation.get('colorold', False)

    if rotation > 0:
        aug.append(transforms.RandomRotation(rotation))

    if colorjitter is not None:
        aug.append(transforms.ColorJitter(*colorjitter))

    aug.append(transforms.ToTensor())

    if colorold:
        # Legacy PCA-style color augmentation; operates on tensors,
        # so it must come after ToTensor().
        aug.append(ColorAugmentation())

    aug.append(normalize)

    # train
    train_dataset = McDataset(config.train_root,
                              config.train_source,
                              transforms.Compose(aug),
                              fake=args.fake)

    # val
    val_dataset = McDataset(
        config.val_root, config.val_source,
        transforms.Compose([
            transforms.Resize(config.augmentation.test_resize),
            transforms.CenterCrop(config.augmentation.input_size),
            transforms.ToTensor(),
            normalize,
        ]), args.fake)

    # Samplers do the sharding, so shuffle=False in the DataLoaders below.
    train_sampler = DistributedGivenIterationSampler(
        train_dataset,
        config.lr_scheduler.max_iter,
        config.batch_size,
        last_iter=last_iter)
    val_sampler = DistributedSampler(val_dataset, round_up=False)

    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=False,
                              num_workers=config.workers,
                              pin_memory=True,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.workers,
                            pin_memory=True,
                            sampler=val_sampler)

    # The scheduler must see the wrapped (inner) optimizer for fp16.
    # NOTE(review): only FP16SGD is unwrapped here while line above treats
    # FusedFP16SGD as fp16 too -- confirm FusedFP16SGD subclasses FP16SGD,
    # otherwise the scheduler gets the wrapper instead of the inner optimizer.
    config.lr_scheduler['optimizer'] = optimizer.optimizer if isinstance(
        optimizer, FP16SGD) else optimizer
    config.lr_scheduler['last_iter'] = last_iter
    lr_scheduler = get_scheduler(config.lr_scheduler)

    # Only rank 0 writes TensorBoard events and the text log.
    if rank == 0:
        tb_logger = SummaryWriter(config.save_path + '/events')
        logger = create_logger('global_logger', config.save_path + '/log.txt')
        logger.info('args: {}'.format(pprint.pformat(args)))
        logger.info('config: {}'.format(pprint.pformat(config)))
    else:
        tb_logger = None

    if args.evaluate:
        if args.fusion_list is not None:
            # Multi-checkpoint fusion evaluation.
            validate(val_loader,
                     model,
                     fusion_list=args.fusion_list,
                     fuse_prob=args.fuse_prob)
        else:
            validate(val_loader, model)
        link.finalize()
        return

    train(train_loader, val_loader, model, optimizer, lr_scheduler,
          last_iter + 1, tb_logger)

    link.finalize()
Example #2
0
def dist_finalize():
    """Tear down the distributed communication backend via link.finalize()."""
    link.finalize()