Beispiel #1
0
def train(**flag_kwargs):
    """Build a model (and optional adversarial attack) from flags and train it.

    Meant to be called on every Horovod worker (``hvd.init()`` is called
    here). Only rank 0 prints the flag summary, owns the logger and prints the
    final message.

    Args:
        **flag_kwargs: Flag values forwarded to ``FlagHolder.initialize``.

    Raises:
        NotImplementedError: If ``dataset`` is neither ``'imagenet'`` nor
            ``'cifar-10'``.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**flag_kwargs)
    hvd.init()
    if FLAGS.step_size is None:
        FLAGS.step_size = get_step_size(FLAGS.epsilon, FLAGS.n_iters,
                                        FLAGS.use_max_step)
        # Mirror the derived value into the raw dict, which is what gets
        # handed to init_logger below.
        FLAGS._dict['step_size'] = FLAGS.step_size

    if hvd.rank() == 0:
        FLAGS.summary()

    if FLAGS.dataset == 'imagenet':
        Trainer = ImagenetTrainer
        nb_classes = 1000 // FLAGS.class_downsample_factor
    elif FLAGS.dataset == 'cifar-10':
        Trainer = CIFAR10Trainer
        # CIFAR-10 has 10 classes; this matches the class count used when the
        # evaluation entry point (`run`) rebuilds a 'cifar-10' model, so the
        # saved state dict can be loaded back.
        nb_classes = 10
    else:
        raise NotImplementedError

    # Only rank 0 creates a logger; the other workers get None.
    if hvd.rank() == 0:
        logger = init_logger(FLAGS.use_wandb, 'train', FLAGS._dict)
    else:
        logger = None
    if FLAGS.checkpoint_dir is None:
        # Non-root ranks have no logger and cannot read log_dir themselves,
        # so take rank 0's value. Every rank enters this branch together
        # (flags are presumably identical on all workers), which the
        # collective broadcast requires.
        checkpoint_dir = logger.log_dir if logger is not None else None
        if hvd.size() > 1:
            checkpoint_dir = hvd.broadcast_object(checkpoint_dir, root_rank=0)
        FLAGS.checkpoint_dir = checkpoint_dir
    print('checkpoint at {}'.format(FLAGS.checkpoint_dir))

    model = get_model(FLAGS.dataset, FLAGS.resnet_size, nb_classes)
    if FLAGS.adv_train:
        attack = get_attack(FLAGS.dataset, FLAGS.attack, FLAGS.epsilon,
                            FLAGS.n_iters, FLAGS.step_size, FLAGS.scale_each)
    else:
        attack = None

    trainer = Trainer(
        # model/checkpoint options
        model=model,
        checkpoint_dir=FLAGS.checkpoint_dir,
        dataset_path=FLAGS.dataset_path,
        # attack options
        attack=attack,
        scale_eps=FLAGS.scale_eps,
        attack_loss=FLAGS.attack_loss,
        # training options
        batch_size=FLAGS.batch_size,
        epochs=FLAGS.epochs,
        stride=FLAGS.class_downsample_factor,
        fp_all_reduce=FLAGS.use_fp16,
        label_smoothing=FLAGS.label_smoothing,
        rand_target=FLAGS.rand_target,
        # logging options
        logger=logger,
        tag=FLAGS.tag)
    trainer.train()

    if hvd.rank() == 0:
        print("Training finished.")
Beispiel #2
0
def run(**flag_kwargs):
    """Load a trained checkpoint and evaluate it against an attack.

    Args:
        **flag_kwargs: Flag values forwarded to ``FlagHolder.initialize``.

    Raises:
        NotImplementedError: If ``dataset`` is not one of ``'imagenet'``,
            ``'imagenet-c'``, ``'cifar-10'`` or ``'cifar-10-c'``.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**flag_kwargs)
    if FLAGS.wandb_ckpt_project is None:
        # Set both the attribute and the raw dict (as done for step_size) so
        # readers of either -- presumably get_ckpt and the logger -- see the
        # fallback project.
        FLAGS.wandb_ckpt_project = FLAGS.wandb_project
        FLAGS._dict['wandb_ckpt_project'] = FLAGS.wandb_project
    if FLAGS.step_size is None:
        FLAGS.step_size = get_step_size(FLAGS.epsilon, FLAGS.n_iters,
                                        FLAGS.use_max_step)
        FLAGS._dict['step_size'] = FLAGS.step_size
    FLAGS.summary()

    # Pick the evaluator first so an unsupported dataset fails fast, before
    # the logger is created or a checkpoint is fetched (previously this left
    # `Evaluator` unbound and crashed with UnboundLocalError).
    if FLAGS.dataset == 'imagenet':
        Evaluator = ImagenetEvaluator
    elif FLAGS.dataset == 'imagenet-c':
        Evaluator = ImagenetCEvaluator
    elif FLAGS.dataset == 'cifar-10':
        Evaluator = CIFAR10Evaluator
    elif FLAGS.dataset == 'cifar-10-c':
        Evaluator = CIFAR10CEvaluator
    else:
        raise NotImplementedError

    logger = init_logger(FLAGS.use_wandb, 'eval', FLAGS._dict)

    if FLAGS.dataset in ['cifar-10', 'cifar-10-c']:
        nb_classes = 10
    else:
        nb_classes = 1000 // FLAGS.class_downsample_factor

    # Corrupted variants reuse the architecture of their clean dataset.
    model_dataset = FLAGS.dataset
    if model_dataset == 'imagenet-c':
        model_dataset = 'imagenet'
    model = get_model(model_dataset, FLAGS.resnet_size, nb_classes)
    ckpt = get_ckpt(FLAGS)
    model.load_state_dict(ckpt['model'])

    attack = get_attack(FLAGS.dataset, FLAGS.attack, FLAGS.epsilon,
                        FLAGS.n_iters, FLAGS.step_size, False)

    evaluator = Evaluator(model=model,
                          attack=attack,
                          dataset=FLAGS.dataset,
                          dataset_path=FLAGS.dataset_path,
                          nb_classes=nb_classes,
                          corruption_type=FLAGS.corruption_type,
                          corruption_name=FLAGS.corruption_name,
                          corruption_level=FLAGS.corruption_level,
                          batch_size=FLAGS.batch_size,
                          stride=FLAGS.class_downsample_factor,
                          fp_all_reduce=FLAGS.use_fp16,
                          logger=logger,
                          tag=FLAGS.tag)
    evaluator.evaluate()
def main(**kwargs):
    """Stack each image's ``*_fft.png`` strips vertically into one PNG.

    For every ``idx`` in ``range(flags.num_image)``:

    1. Collect the files in ``flags.trg_root`` matching
       ``*_{idx:06d}_fft.png``, sorted by name.
    2. Concatenate them along the height axis.
    3. Write the result to ``flags.log_root`` as ``{idx:06d}_merge.png``.

    Indices with no matching file are reported and skipped.
    """
    flags = FlagHolder()
    flags.initialize(**kwargs)
    flags.summary()

    # make logdir
    os.makedirs(flags.log_root, exist_ok=True)

    # ToTensor is stateless, so one transform serves every file.
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()])

    for idx in tqdm.tqdm(range(flags.num_image)):
        trgname = '*_{idx:06d}_fft.png'.format(idx=idx)
        trgpath = os.path.join(flags.trg_root, trgname)
        pngpaths = sorted(glob.glob(trgpath))

        if not pngpaths:
            print('No target files are found.')
            continue

        outs = []
        for pngpath in pngpaths:
            # Close the file handle once the pixels have been converted.
            with Image.open(pngpath) as img:
                x = transform(img).unsqueeze(0)  # presumably (1, RGB, H, 7*W)
            outs.append(x)

        # Stack strips vertically; every strip must share the same width.
        out = torch.cat(outs, dim=-2)

        # save
        savename = '{idx:06d}_merge.png'.format(idx=idx)
        savepath = os.path.join(flags.log_root, savename)
        torchvision.utils.save_image(out, savepath)
Beispiel #4
0
def main(**kwargs):
    """Append FFT panels of the perturbation to each image strip.

    Each ``*.png`` in ``flags.trg_root`` is split into three equal-width
    panels, and the last panel is treated as the perturbation ``delta``.
    ``fft(delta)`` is laid out as four side-by-side grey panels, replicated to
    RGB, and appended to the right of the original strip. The result is saved
    to ``flags.log_root`` as ``<name>_fft<ext>``.
    """
    flags = FlagHolder()
    flags.initialize(**kwargs)
    flags.summary()

    # make logdir
    os.makedirs(flags.log_root, exist_ok=True)

    trg_path = os.path.join(flags.trg_root, '*.png')
    path_list = sorted(glob.glob(trg_path))

    # ToTensor is stateless, so one transform serves every file.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])

    for png_path in tqdm.tqdm(path_list):
        # Close the file handle once the pixels have been converted.
        with Image.open(png_path) as img:
            x = transform(img).unsqueeze(0)  # (1, C, H, 3*W)

        # load delta: the right-most of the three panels
        width = x.size(-1) // 3
        delta = x[:, :, :, width * 2:]

        # fft
        w = fft(delta)  # presumably (B, 4, H, W) -- confirm against fft()
        # Lay the four FFT channels side by side: (B, H, 4*W).
        w = torch.cat(
            [w[:, 0, :, :], w[:, 1, :, :], w[:, 2, :, :], w[:, 3, :, :]],
            dim=-1)
        # Re-insert the channel axis after the batch axis and replicate to
        # RGB: (B, 3, H, 4*W). unsqueeze(0) only gave this shape when B == 1.
        w = w.unsqueeze(1).repeat(1, 3, 1, 1)

        out = torch.cat([x, w], dim=-1)

        # save
        basename = os.path.basename(png_path)
        name, ext = os.path.splitext(basename)
        savename = name + '_fft' + ext
        savepath = os.path.join(flags.log_root, savename)

        torchvision.utils.save_image(out, savepath)
Beispiel #5
0
def main(**flags):
    """Attack one batch of ImageNet-100 images per model and save strips.

    For every model in ``MODEL_NAMES`` and every ``(attack_method, eps)``
    pair in ``ATTACKS``, attacks the first shuffled batch. Per image, it saves
    a PNG with the clean image, the adversarial image and their difference
    placed side by side.

    Relies on ``dataset_root``, ``weight_root`` and ``log_root`` defined
    outside this function (not visible here -- confirm they exist).

    Raises:
        ValueError: If an attack in ``ATTACKS`` is not in ``ATTACK_METHODS``.
    """
    FLAGS = FlagHolder()
    FLAGS.initialize(**flags)

    # dataset (ImageNet100)
    normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                 std=[0.229, 0.224, 0.225])
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        normalize,
    ])

    dataset = torchvision.datasets.ImageFolder(dataset_root, transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=FLAGS.batch_size,
                                         num_workers=FLAGS.num_workers,
                                         shuffle=True)

    # Attack settings. `dataset_name` is kept distinct so the ImageFolder
    # above is not shadowed by the identifier passed to get_attack.
    dataset_name = 'imagenet100'
    n_iters = 50
    scale_each = True
    # NOTE(review): the attack call below passes scale_eps=False explicitly;
    # confirm that is intended.

    for name in MODEL_NAMES:
        # model
        weight_path = os.path.join(weight_root, name + '_model.pth')
        model = torchvision.models.resnet50(num_classes=100)
        model.load_state_dict(torch.load(weight_path))
        model.eval()
        model = model.cuda()

        for i, (x, t) in enumerate(loader):
            x, t = x.cuda(), t.cuda()
            x_unnormalized = unnormalize(x.detach(), MEAN, STD)

            for attack_method, eps_list in tqdm.tqdm(ATTACKS.items()):
                if attack_method not in ATTACK_METHODS:
                    raise ValueError(
                        'unknown attack method: {}'.format(attack_method))

                for eps in eps_list:
                    step_size = get_step_size(eps, n_iters, use_max=False)

                    # get_attack's result is called once more before use here,
                    # so it presumably returns a factory/class -- confirm.
                    attack = get_attack(dataset_name, attack_method, eps,
                                        n_iters, step_size, scale_each)
                    attack = attack()
                    adv = attack(model, x, t, avoid_target=True,
                                 scale_eps=False).detach()

                    # unnormalize
                    adv_unnormalized = unnormalize(adv, MEAN, STD)
                    delta = adv_unnormalized - x_unnormalized

                    for idx in range(x.size(0)):
                        # clean | adversarial | delta, side by side
                        out = torch.cat([x_unnormalized[idx, :, :, :],
                                         adv_unnormalized[idx, :, :, :],
                                         delta[idx, :, :, :]],
                                        dim=-1).unsqueeze(0)

                        save_idx = FLAGS.batch_size * i + idx
                        save_name = attack_method + '_{eps}_{save_idx:06d}.png'.format(
                            eps=eps, save_idx=save_idx)
                        save_path = os.path.join(log_root, save_name)

                        torchvision.utils.save_image(out, save_path)

            # Only the first batch is visualised.
            if i == 0:
                break