示例#1
0
def pre_benchmark_atk(**kwargs):
    """
    Set defaults and validate options before benchmarking attacks.

    Builds the evaluation DataLoader and extracts ``topk``. It also strips
    every benchmark-only option from ``kwargs``, so the remainder can be
    passed on to initialize the attack.

    Arguments
    ---------
    bs : int
         Batch size of the DataLoader. Defaults to 4.
    trf : transform
          Transform applied to the dataset. Defaults to
          ``get_trf('rz256_cc224_tt_normimgnet')``. It is always replaced by
          ``get_trf('tt_normmnist')`` when ``dset == MNIST``.
    dset : str
           ``'NA'`` to build the dataset with ``dfunc`` (or use ``loader``).
           Otherwise it must be ``IMGNET12`` or ``MNIST``. Defaults to
           ``'NA'``.
    root : str
           Root directory of the dataset. Defaults to ``'./'``.
    topk : tuple
           Top-k values handed back to the caller. Defaults to ``(1, 5)``.
    dfunc : callable
            Dataset constructor used when ``dset == 'NA'``. Defaults to
            ``datasets.ImageFolder``.
    download : bool
               Forwarded as ``download=`` to the dataset constructor when
               ``dset`` is ``IMGNET12`` or ``MNIST``. Defaults to True.
    loader : DataLoader, optional
             Pre-built loader. It is used as-is when ``dset == 'NA'`` and is
             always removed from the returned kwargs.

    Returns
    -------
    loader : DataLoader
             The loader to benchmark on.
    topk : tuple
           The requested top-k values.
    kwargs : dict
             The remaining options with all benchmark-only keys removed.

    Raises
    ------
    NotImplementedError
        If ``dset`` is not ``'NA'``, ``IMGNET12`` or ``MNIST``. This is a
        subclass of RuntimeError.
    """
    # Set the default options for anything not provided explicitly.
    def_dict = {
        'bs': 4,
        'trf': get_trf('rz256_cc224_tt_normimgnet'),
        'dset': 'NA',
        'root': './',
        'topk': (1, 5),
        'dfunc': datasets.ImageFolder,
        'download': True,
    }
    for key, val in def_dict.items():
        kwargs.setdefault(key, val)

    if kwargs['dset'] == 'NA':
        if 'loader' in kwargs:
            loader = kwargs['loader']
        else:
            dset = kwargs['dfunc'](kwargs['root'], transform=kwargs['trf'])
            loader = DataLoader(dset, batch_size=kwargs['bs'], num_workers=2)
    else:
        # Dataset-specific construction.
        if kwargs['dset'] == IMGNET12:
            # torchvision's ImageNet only provides 'train' and 'val' splits;
            # 'val' is the held-out split used for evaluation.
            # NOTE(review): newer torchvision versions do not support
            # downloading ImageNet (archives must already be under root);
            # confirm `download` is accepted by the installed version.
            dset = datasets.ImageNet(kwargs['root'],
                                     split='val',
                                     download=kwargs['download'],
                                     transform=kwargs['trf'])
        elif kwargs['dset'] == MNIST:
            # MNIST always uses its own transform, overriding any caller value.
            kwargs['trf'] = get_trf('tt_normmnist')
            kwargs['dfunc'] = datasets.MNIST
            dset = kwargs['dfunc'](kwargs['root'],
                                   train=False,
                                   download=kwargs['download'],
                                   transform=kwargs['trf'])
        else:
            raise NotImplementedError(
                'Unsupported dset {!r} for benchmarking.'.format(kwargs['dset']))

        loader = DataLoader(dset, shuffle=False, batch_size=kwargs['bs'])
    topk = kwargs['topk']

    for key in def_dict:
        print('[INFO] Setting {} to {}.'.format(key, kwargs[key]))

    # The same kwargs dict is passed on to initialize the attack. Keys that
    # are only meant for this function must be removed, otherwise the attack
    # throws an exception on unexpected arguments.
    for key in def_dict:
        del kwargs[key]
    kwargs.pop('loader', None)

    return loader, topk, kwargs
示例#2
0
def custom(net, tloader, vloader, **kwargs):
    """
    Train ``net`` on a custom dataset supplied as ready-made DataLoaders.

    Arguments
    ---------
    net : nn.Module
          The net which to train.
    tloader : DataLoader
              Loader over the training data.
    vloader : DataLoader
              Loader over the validation data.

    Keyword options (forwarded to ``preprocess_opts`` and then ``clf_fit``):

    optim : nn.optim
            The optimizer to use.
    crit : nn.Module
           The criterion to use.
    lr : float
         The learning rate.
    wd : float
         The weight decay.
    bs : int
         The batch size.
    seed : int
           The particular seed to use.
    epochs : int
             The epochs to train for.
    ckpt : str, None
           Path to the ckpt file. If not None, training is started
           using this ckpt file. Defaults to None.
    root : str
           The root where the dataset is or needs to be downloaded.

    Returns
    -------
    tlist : list
            Contains a list of n 2-tuples, where n == epochs. Each
            tuple (a, b) holds, for training:
            a -> the acc for the corresponding index
            b -> the loss for the corresponding index
    vlist : list
            Same layout as ``tlist``, for validation.
    """
    opti, crit, kwargs = preprocess_opts(net, **kwargs)

    # The caller supplies the loaders (and therefore the transforms), so no
    # transform is built here.
    tlist, vlist = clf_fit(net, crit, opti, tloader, vloader, **kwargs)
    plt_tr_vs_tt(tlist, vlist)
    # Previously these were computed but never returned, despite the
    # documented contract.
    return tlist, vlist
示例#3
0
def cifar10(net, **kwargs):
    """
    Train ``net`` on CIFAR-10.

    Both splits are downloaded into ``root`` if not already present.

    Arguments
    ---------
    net : nn.Module
          The net which to train.

    Keyword options (forwarded to ``preprocess_opts`` and then ``clf_fit``):

    optim : nn.optim
            The optimizer to use.
    crit : nn.Module
           The criterion to use.
    lr : float
         The learning rate.
    wd : float
         The weight decay.
    bs : int
         The batch size.
    seed : int
           The particular seed to use.
    epochs : int
             The epochs to train for.
    ckpt : str, None
           Path to the ckpt file. If not None, training is started
           using this ckpt file. Defaults to None.
    root : str
           The root where the dataset is or needs to be downloaded.

    Returns
    -------
    tlist : list
            Contains a list of n 2-tuples, where n == epochs. Each
            tuple (a, b) holds, for training:
            a -> the acc for the corresponding index
            b -> the loss for the corresponding index
    vlist : list
            Same layout as ``tlist``, for validation.
    """
    opti, crit, kwargs = preprocess_opts(net, dset=CIFAR10, **kwargs)

    # NOTE(review): the same transform is applied to both splits. If 'rr2'
    # denotes a random augmentation, validation images are augmented too;
    # confirm this is intended.
    trf = get_trf('rr2_tt_normimgnet')

    t = datasets.CIFAR10(kwargs['root'],
                         train=True,
                         download=True,
                         transform=trf)
    v = datasets.CIFAR10(kwargs['root'],
                         train=False,
                         download=True,
                         transform=trf)
    tloader = DataLoader(t, shuffle=True, batch_size=kwargs['bs'])
    vloader = DataLoader(v, shuffle=True, batch_size=kwargs['bs'])

    tlist, vlist = clf_fit(net, crit, opti, tloader, vloader, **kwargs)
    plt_tr_vs_tt(tlist, vlist)
    # Return the histories, as the docstring documents.
    return tlist, vlist