Example #1
def main():
    global args, config, best_prec1
    args = parser.parse_args()

    with open(args.config) as f:
        # recent PyYAML versions require an explicit Loader
        config = yaml.load(f, Loader=yaml.FullLoader)

    config = EasyDict(config['common'])
    config.save_path = os.path.dirname(args.config)

    # initialize the distributed backend and get this process's rank and the
    # total number of processes
    rank, world_size = dist_init()

    # create model
    bn_group_size = config.model.kwargs.bn_group_size
    bn_var_mode = config.model.kwargs.get('bn_var_mode', 'L2')
    if bn_group_size == 1:
        bn_group = None
    else:
        assert world_size % bn_group_size == 0
        bn_group = simple_group_split(world_size, rank,
                                      world_size // bn_group_size)

    config.model.kwargs.bn_group = bn_group
    config.model.kwargs.bn_var_mode = (link.syncbnVarMode_t.L1
                                       if bn_var_mode == 'L1'
                                       else link.syncbnVarMode_t.L2)
    model = model_entry(config.model)
    if rank == 0:
        print(model)

    model.cuda()

    if config.optimizer.type in ('FP16SGD', 'FusedFP16SGD'):
        args.fp16 = True
    else:
        args.fp16 = False

    if args.fp16:
        # If a module must keep fp32 parameters and also needs fp32 inputs,
        # register it with link.fp16.register_float_module(your_module).
        # If it only needs fp32 parameters, pass cast_args=False to that call,
        # then call link.fp16.init() before calling model.half().
        if config.optimizer.get('fp16_normal_bn', False):
            print('using normal bn for fp16')
            link.fp16.register_float_module(link.nn.SyncBatchNorm2d,
                                            cast_args=False)
            link.fp16.register_float_module(torch.nn.BatchNorm2d,
                                            cast_args=False)
            link.fp16.init()
        model.half()

    # wrap the model so gradients are synchronized across processes
    model = DistModule(model, args.sync)

    # create optimizer
    opt_config = config.optimizer
    opt_config.kwargs.lr = config.lr_scheduler.base_lr
    if config.get('no_wd', False):
        param_group, type2num = param_group_no_wd(model)
        opt_config.kwargs.params = param_group
    else:
        opt_config.kwargs.params = model.parameters()

    optimizer = optim_entry(opt_config)

    # optionally resume from a checkpoint
    last_iter = -1
    best_prec1 = 0
    if args.load_path:
        if args.recover:
            best_prec1, last_iter = load_state(args.load_path,
                                               model,
                                               optimizer=optimizer)
        else:
            load_state(args.load_path, model)

    # let cuDNN auto-tune convolution algorithms for the fixed input size
    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # augmentation
    aug = [
        transforms.RandomResizedCrop(config.augmentation.input_size),
        transforms.RandomHorizontalFlip()
    ]

    for k in config.augmentation.keys():
        assert k in [
            'input_size', 'test_resize', 'rotation', 'colorjitter', 'colorold'
        ]
    rotation = config.augmentation.get('rotation', 0)
    colorjitter = config.augmentation.get('colorjitter', None)
    colorold = config.augmentation.get('colorold', False)

    if rotation > 0:
        aug.append(transforms.RandomRotation(rotation))

    if colorjitter is not None:
        aug.append(transforms.ColorJitter(*colorjitter))

    aug.append(transforms.ToTensor())

    if colorold:
        aug.append(ColorAugmentation())

    aug.append(normalize)

    # train
    train_dataset = McDataset(config.train_root,
                              config.train_source,
                              transforms.Compose(aug),
                              fake=args.fake)

    # val
    val_dataset = McDataset(
        config.val_root,
        config.val_source,
        transforms.Compose([
            transforms.Resize(config.augmentation.test_resize),
            transforms.CenterCrop(config.augmentation.input_size),
            transforms.ToTensor(),
            normalize,
        ]),
        fake=args.fake)

    train_sampler = DistributedGivenIterationSampler(
        train_dataset,
        config.lr_scheduler.max_iter,
        config.batch_size,
        last_iter=last_iter)
    val_sampler = DistributedSampler(val_dataset, round_up=False)

    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=False,
                              num_workers=config.workers,
                              pin_memory=True,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.workers,
                            pin_memory=True,
                            sampler=val_sampler)

    config.lr_scheduler['optimizer'] = (optimizer.optimizer
                                        if isinstance(optimizer, FP16SGD)
                                        else optimizer)
    config.lr_scheduler['last_iter'] = last_iter
    lr_scheduler = get_scheduler(config.lr_scheduler)

    if rank == 0:
        tb_logger = SummaryWriter(config.save_path + '/events')
        logger = create_logger('global_logger', config.save_path + '/log.txt')
        logger.info('args: {}'.format(pprint.pformat(args)))
        logger.info('config: {}'.format(pprint.pformat(config)))
    else:
        tb_logger = None

    if args.evaluate:
        if args.fusion_list is not None:
            validate(val_loader,
                     model,
                     fusion_list=args.fusion_list,
                     fuse_prob=args.fuse_prob)
        else:
            validate(val_loader, model)
        link.finalize()
        return

    train(train_loader, val_loader, model, optimizer, lr_scheduler,
          last_iter + 1, tb_logger)

    link.finalize()
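
This `main()` only shows how the YAML config is consumed. The sketch below is a minimal, hypothetical config assembled from the fields the function actually reads (`model.kwargs.bn_group_size`, `optimizer.type`, `lr_scheduler.base_lr` / `max_iter`, the `augmentation` keys, batch/worker settings, and the dataset roots); the concrete type names and values are illustrative assumptions, not taken from a real project config.

# Minimal, hypothetical config sketch for Example #1 (values are illustrative).
import yaml
from easydict import EasyDict

EXAMPLE_YAML = """
common:
    model:
        type: resnet50            # resolved by model_entry (assumed name)
        kwargs:
            bn_group_size: 8      # must evenly divide world_size
    optimizer:
        type: SGD                 # 'FP16SGD' / 'FusedFP16SGD' would enable fp16
        kwargs:
            momentum: 0.9
            weight_decay: 0.0001
    lr_scheduler:
        type: STEP                # consumed by get_scheduler (assumed name)
        base_lr: 0.1
        max_iter: 100000
    augmentation:
        input_size: 224
        test_resize: 256
    batch_size: 64
    workers: 4
    train_root: /path/to/train
    train_source: /path/to/train_list.txt
    val_root: /path/to/val
    val_source: /path/to/val_list.txt
"""

config = EasyDict(yaml.load(EXAMPLE_YAML, Loader=yaml.FullLoader)['common'])
assert config.model.kwargs.bn_group_size >= 1
print(config.optimizer.type, config.lr_scheduler.base_lr)
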
Example #2
File: main.py Project: gmh14/RobNets
def main():
    global cfg, rank, world_size

    cfg = Config.fromfile(args.config)

    # Set seed
    np.random.seed(cfg.seed)
    cudnn.benchmark = True
    torch.manual_seed(cfg.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(cfg.seed)

    # Model
    print('==> Building model..')
    arch_code = getattr(architecture_code, cfg.model)
    net = models.model_entry(cfg, arch_code)
    rank = 0  # for non-distributed
    world_size = 1  # for non-distributed
    if args.distributed:
        print('==> Initializing distributed training..')
        # Only slurm is supported for now; to use a different launcher, see
        # https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py
        init_dist(launcher='slurm', backend='nccl')
        rank, world_size = get_dist_info()
    net = net.cuda()

    cfg.netpara = sum(p.numel() for p in net.parameters()) / 1e6

    start_epoch = 0
    best_acc = 0
    # Load checkpoint.
    if cfg.get('resume_path', False):
        print('==> Resuming from {}checkpoint {}..'.format(
            ('original ' if cfg.resume_path.origin_ckpt else ''),
            cfg.resume_path.path))
        if cfg.resume_path.origin_ckpt:
            utils.load_state(cfg.resume_path.path, net, rank=rank)
        else:
            if args.distributed:
                net = torch.nn.parallel.DistributedDataParallel(
                    net,
                    device_ids=[torch.cuda.current_device()],
                    output_device=torch.cuda.current_device())
            utils.load_state(cfg.resume_path.path, net, rank=rank)

    # Data
    print('==> Preparing data..')
    if args.eval_only:
        testloader = dataset_entry(cfg, args.distributed, args.eval_only)
    else:
        trainloader, testloader, train_sampler, test_sampler = dataset_entry(
            cfg, args.distributed, args.eval_only)
        print(trainloader, testloader, train_sampler, test_sampler)
    criterion = nn.CrossEntropyLoss()
    if not args.eval_only:
        cfg.attack_param.num_steps = 7
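    # wrap the network with a PGD adversarial-attack module; it is used for
    # adversarial training and for the adversarial test pass below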
    net_adv = AttackPGD(net, cfg.attack_param)

    if not args.eval_only:
        # Train params
        print('==> Setting train parameters..')
        train_param = cfg.train_param
        epochs = train_param.epochs
        init_lr = train_param.learning_rate
        if train_param.get('warm_up_param', False):
            warm_up_param = train_param.warm_up_param
            init_lr = warm_up_param.warm_up_base_lr
            epochs += warm_up_param.warm_up_epochs
        if train_param.get('no_wd', False):
            param_group, type2num, _, _ = utils.param_group_no_wd(net)
            cfg.param_group_no_wd = type2num
            optimizer = torch.optim.SGD(param_group,
                                        lr=init_lr,
                                        momentum=train_param.momentum,
                                        weight_decay=train_param.weight_decay)
        else:
            optimizer = torch.optim.SGD(net.parameters(),
                                        lr=init_lr,
                                        momentum=train_param.momentum,
                                        weight_decay=train_param.weight_decay)

        scheduler = lr_scheduler.CosineLRScheduler(
            optimizer, epochs, train_param.learning_rate_min, init_lr,
            train_param.learning_rate,
            (warm_up_param.warm_up_epochs if train_param.get(
                'warm_up_param', False) else 0))
    # Log
    print('==> Writing log..')
    if rank == 0:
        cfg.save = '{}/{}-{}-{}'.format(cfg.save_path, cfg.model, cfg.dataset,
                                        time.strftime("%Y%m%d-%H%M%S"))
        utils.create_exp_dir(cfg.save)
        logger = utils.create_logger('global_logger', cfg.save + '/log.txt')
        logger.info('config: {}'.format(pprint.pformat(cfg)))

    # Evaluation only
    if args.eval_only:
        assert cfg.get('resume_path', False), \
            'resume_path must be set for eval_only mode'
        print('==> Testing on Clean Data..')
        test(net, testloader, criterion)
        print('==> Testing on Adversarial Data..')
        test(net_adv, testloader, criterion, adv=True)
        return

    # Training process
    for epoch in range(start_epoch, epochs):
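        # DistributedSampler reshuffles per epoch only if set_epoch is called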
        if args.distributed:
            train_sampler.set_epoch(epoch)
            test_sampler.set_epoch(epoch)
        scheduler.step()
        if rank == 0:
            logger.info('Epoch %d learning rate %e', epoch,
                        scheduler.get_lr()[0])

        # Train for one epoch
        train(net_adv, trainloader, criterion, optimizer)

        # Validate for one epoch
        valid_acc = test(net_adv, testloader, criterion, adv=True)

        if rank == 0:
            logger.info('Validation Accuracy: {}'.format(valid_acc))
            is_best = valid_acc > best_acc
            best_acc = max(valid_acc, best_acc)
            print('==> Saving')
            state = {
                'epoch': epoch,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'state_dict': net.state_dict(),
                'scheduler': scheduler
            }
            utils.save_checkpoint(state, is_best, cfg.save)
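
`Config.fromfile` loads an mmcv-style Python config file. The sketch below is a minimal, hypothetical config limited to the fields this `main()` reads (`seed`, `model`, `dataset`, `save_path`, `attack_param`, `train_param`, and the optional `resume_path`); the model/dataset names, all values, and every `attack_param` field except `num_steps` are assumptions, not copied from the RobNets repository.

# Minimal, hypothetical mmcv-style config sketch for Example #2
# (only fields referenced by main(); all values are illustrative).
seed = 0
model = 'robnet_free'        # attribute looked up in architecture_code (assumed name)
dataset = 'cifar10'          # consumed by dataset_entry (assumed name)
save_path = './checkpoints'

# PGD attack settings passed to AttackPGD; only num_steps appears in main(),
# the remaining field names are assumptions.
attack_param = dict(num_steps=7, step_size=2.0 / 255, epsilon=8.0 / 255)

train_param = dict(
    epochs=110,
    learning_rate=0.1,
    learning_rate_min=0.0,
    momentum=0.9,
    weight_decay=1e-4,
    # Optional extras read by main():
    # no_wd=True,
    # warm_up_param=dict(warm_up_base_lr=0.01, warm_up_epochs=5),
)

# Optional: resume from a checkpoint.
# resume_path = dict(path='/path/to/checkpoint.pth.tar', origin_ckpt=False)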