Example #1
# Third-party imports inferred from the usage below; `args` and the
# project-local modules (utils, models, architecture_code, dataset_entry,
# AttackPGD, test) are assumed to be provided by the surrounding repository.
import sys
import time
import pprint

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from mmcv import Config  # Config.fromfile matches the mmcv Config API


def main():
    global cfg

    cfg = Config.fromfile(args.config)

    cfg.save = '{}/{}-{}-{}'.format(cfg.save_path, cfg.model, cfg.dataset,
                                    time.strftime("%Y%m%d-%H%M%S"))
    utils.create_exp_dir(cfg.save)

    logger = utils.create_logger('global_logger', cfg.save + '/log.txt')

    if not torch.cuda.is_available():
        logger.info('no gpu device available')
        sys.exit(1)

    # Set cuda device & seed
    torch.cuda.set_device(cfg.gpu)
    np.random.seed(cfg.seed)
    cudnn.benchmark = True
    torch.manual_seed(cfg.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(cfg.seed)

    # Model
    print('==> Building model..')
    # Look up the architecture code attribute by name (avoids eval on config input)
    arch_code = getattr(architecture_code, cfg.model)
    net = models.model_entry(cfg, arch_code)
    net = net.cuda()

    cfg.netpara = sum(p.numel() for p in net.parameters()) / 1e6  # model size in millions of parameters
    logger.info('config: {}'.format(pprint.pformat(cfg)))

    # Load checkpoint.
    if not Debug:  # `Debug` is assumed to be a module-level flag defined elsewhere
        print('==> Resuming from checkpoint..')
        utils.load_state(cfg.resume_path, net)

    # Data
    print('==> Preparing data..')

    testloader = dataset_entry(cfg)
    criterion = nn.CrossEntropyLoss()
    net_adv = AttackPGD(net, cfg.attack_param)

    print('==> Testing on Clean Data..')
    test(net, testloader, criterion)

    print('==> Testing on Adversarial Data..')
    test(net_adv, testloader, criterion, adv=True)
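
The `AttackPGD` wrapper used in both examples is not shown here. As a rough sketch, a projected gradient descent (PGD) wrapper along the following lines would fit the call site `AttackPGD(net, cfg.attack_param)`; only `num_steps` is confirmed by the configs (Example #2 overrides `cfg.attack_param.num_steps`), while `epsilon` and `step_size` are assumed field names and the forward signature is a guess, not the repository's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttackPGD(nn.Module):
    """Sketch of a PGD adversarial-example wrapper (assumed interface)."""

    def __init__(self, model, attack_param):
        super().__init__()
        self.model = model
        self.epsilon = attack_param.epsilon      # assumed field name
        self.num_steps = attack_param.num_steps  # set in the configs above
        self.step_size = attack_param.step_size  # assumed field name

    def forward(self, inputs, targets):
        # Random start inside the L-inf ball around the clean inputs.
        x = inputs.detach() + torch.zeros_like(inputs).uniform_(-self.epsilon, self.epsilon)
        for _ in range(self.num_steps):
            x.requires_grad_()
            with torch.enable_grad():
                loss = F.cross_entropy(self.model(x), targets)
            grad = torch.autograd.grad(loss, x)[0]
            # Ascend the loss, then project back onto the ball and the valid pixel range.
            x = x.detach() + self.step_size * torch.sign(grad.detach())
            x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
            x = torch.clamp(x, 0.0, 1.0)
        return self.model(x), x

With a wrapper of this shape, `test(net_adv, testloader, criterion, adv=True)` evaluates the model on perturbed inputs, while `test(net, ...)` sees the clean data.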
Example #2
# Third-party imports inferred from the usage below; `args` and the
# project-local modules (utils, models, architecture_code, dataset_entry,
# AttackPGD, lr_scheduler, train, test) are assumed to be provided by the
# surrounding repository.
import time
import pprint

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from mmcv import Config
from mmcv.runner import get_dist_info, init_dist


def main():
    global cfg, rank, world_size

    cfg = Config.fromfile(args.config)

    # Set seed
    np.random.seed(cfg.seed)
    cudnn.benchmark = True
    torch.manual_seed(cfg.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(cfg.seed)

    # Model
    print('==> Building model..')
    # Look up the architecture code attribute by name (avoids eval on config input)
    arch_code = getattr(architecture_code, cfg.model)
    net = models.model_entry(cfg, arch_code)
    rank = 0  # for non-distributed
    world_size = 1  # for non-distributed
    if args.distributed:
        print('==> Initializing distributed training..')
        # Only SLURM is supported for now; to use a different launcher, see
        # https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py
        init_dist(launcher='slurm', backend='nccl')
        rank, world_size = get_dist_info()
    net = net.cuda()

    cfg.netpara = sum(p.numel() for p in net.parameters()) / 1e6  # model size in millions of parameters

    start_epoch = 0
    best_acc = 0
    # Load checkpoint.
    if cfg.get('resume_path', False):
        print('==> Resuming from {}checkpoint {}..'.format(
            ('original ' if cfg.resume_path.origin_ckpt else ''),
            cfg.resume_path.path))
        if cfg.resume_path.origin_ckpt:
            utils.load_state(cfg.resume_path.path, net, rank=rank)
        else:
            if args.distributed:
                net = torch.nn.parallel.DistributedDataParallel(
                    net,
                    device_ids=[torch.cuda.current_device()],
                    output_device=torch.cuda.current_device())
            utils.load_state(cfg.resume_path.path, net, rank=rank)

    # Data
    print('==> Preparing data..')
    if args.eval_only:
        testloader = dataset_entry(cfg, args.distributed, args.eval_only)
    else:
        trainloader, testloader, train_sampler, test_sampler = dataset_entry(
            cfg, args.distributed, args.eval_only)
        print(trainloader, testloader, train_sampler, test_sampler)
    criterion = nn.CrossEntropyLoss()
    if not args.eval_only:
        cfg.attack_param.num_steps = 7
    net_adv = AttackPGD(net, cfg.attack_param)

    if not args.eval_only:
        # Train params
        print('==> Setting train parameters..')
        train_param = cfg.train_param
        epochs = train_param.epochs
        init_lr = train_param.learning_rate
        if train_param.get('warm_up_param', False):
            warm_up_param = train_param.warm_up_param
            init_lr = warm_up_param.warm_up_base_lr
            epochs += warm_up_param.warm_up_epochs
        if train_param.get('no_wd', False):
            param_group, type2num, _, _ = utils.param_group_no_wd(net)
            cfg.param_group_no_wd = type2num
            optimizer = torch.optim.SGD(param_group,
                                        lr=init_lr,
                                        momentum=train_param.momentum,
                                        weight_decay=train_param.weight_decay)
        else:
            optimizer = torch.optim.SGD(net.parameters(),
                                        lr=init_lr,
                                        momentum=train_param.momentum,
                                        weight_decay=train_param.weight_decay)

        warm_up_epochs = (warm_up_param.warm_up_epochs
                          if train_param.get('warm_up_param', False) else 0)
        scheduler = lr_scheduler.CosineLRScheduler(
            optimizer, epochs, train_param.learning_rate_min, init_lr,
            train_param.learning_rate, warm_up_epochs)
    # Log
    print('==> Writing log..')
    if rank == 0:
        cfg.save = '{}/{}-{}-{}'.format(cfg.save_path, cfg.model, cfg.dataset,
                                        time.strftime("%Y%m%d-%H%M%S"))
        utils.create_exp_dir(cfg.save)
        logger = utils.create_logger('global_logger', cfg.save + '/log.txt')
        logger.info('config: {}'.format(pprint.pformat(cfg)))

    # Evaluation only
    if args.eval_only:
        assert cfg.get('resume_path', False), \
            'resume_path must be set when running in eval_only mode'
        print('==> Testing on Clean Data..')
        test(net, testloader, criterion)
        print('==> Testing on Adversarial Data..')
        test(net_adv, testloader, criterion, adv=True)
        return

    # Training process
    for epoch in range(start_epoch, epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
            test_sampler.set_epoch(epoch)
        scheduler.step()
        if rank == 0:
            logger.info('Epoch %d learning rate %e', epoch,
                        scheduler.get_lr()[0])

        # Train for one epoch
        train(net_adv, trainloader, criterion, optimizer)

        # Validate for one epoch
        valid_acc = test(net_adv, testloader, criterion, adv=True)

        if rank == 0:
            logger.info('Validation Accuracy: {}'.format(valid_acc))
            is_best = valid_acc > best_acc
            best_acc = max(valid_acc, best_acc)
            print('==> Saving')
            state = {
                'epoch': epoch,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'state_dict': net.state_dict(),
                'scheduler': scheduler
            }
            utils.save_checkpoint(state, is_best, cfg.save)
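
Several helpers used in Example #2 (`utils.param_group_no_wd`, `utils.save_checkpoint`, `lr_scheduler.CosineLRScheduler`, `train`, `test`) live elsewhere in the repository. As one illustration, the `no_wd` branch above typically means excluding biases and normalization parameters from weight decay; the sketch below shows what a helper with the same call signature might look like under that assumption. It is not the repository's actual implementation, and the last two return values are placeholders for values the call site discards.

import torch.nn as nn


def param_group_no_wd(model):
    """Sketch of a weight-decay / no-weight-decay parameter split (assumed behaviour).

    Returns parameter groups usable with torch.optim.SGD and a mapping that
    counts, per layer type, how many tensors landed in the no-decay group.
    The last two return values are placeholders; the call site ignores them.
    """
    no_wd, with_wd = [], []
    type2num = {}
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
            # Normalization layers: neither scale nor shift is decayed.
            for pname, p in module.named_parameters(recurse=False):
                no_wd.append(p)
                key = '{}.{}'.format(type(module).__name__, pname)
                type2num[key] = type2num.get(key, 0) + 1
        else:
            for pname, p in module.named_parameters(recurse=False):
                if pname == 'bias':
                    no_wd.append(p)
                    key = '{}.bias'.format(type(module).__name__)
                    type2num[key] = type2num.get(key, 0) + 1
                else:
                    with_wd.append(p)
    # The no-decay group overrides the optimizer-level weight_decay with 0.
    param_group = [{'params': with_wd},
                   {'params': no_wd, 'weight_decay': 0.0}]
    return param_group, type2num, with_wd, no_wd

Passing `param_group` to `torch.optim.SGD(param_group, lr=init_lr, momentum=..., weight_decay=...)` keeps decay on convolution and linear weights while leaving biases and normalization parameters unregularized, which is the usual motivation for a `no_wd` option.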