def clf_test(net, vloader, crit: nn.Module = nn.CrossEntropyLoss):
    """
    This function helps in quickly testing the network.

    NOTE(review): another ``def clf_test`` appears later in this file and
    rebinds this name at import time, so this version is shadowed -- confirm
    which one should be kept.

    Arguments
    ---------
    net : nn.Module
          The net which to test.
    vloader : torch.utils.data.DataLoader or any iterable
              Yields ``(images, labels)`` batches.
    crit : loss class, nn.Module instance or loss callable
           If calling it with no arguments succeeds (e.g. a loss class such
           as the default ``nn.CrossEntropyLoss``), the result is used as the
           criterion; otherwise ``crit`` is used as-is.

    Returns
    -------
    vacc : the value of ``accuracy(correct, total)`` where ``total`` is the
           number of samples actually seen.
    vloss : float
            Mean of the per-batch losses.

    Raises
    ------
    ValueError
        If ``vloader`` yields no batches.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    net.eval()

    # Instantiate loss classes. An already-built module or a plain loss
    # function raises TypeError when called without arguments; keep it as-is.
    try:
        crit = crit()
    except TypeError:
        pass

    vcorr = 0
    vloss = 0
    nsamples = 0
    nbatches = 0
    with torch.no_grad():
        for data, labl in tqdm(vloader):
            data, labl = data.to(device), labl.to(device)
            out = net(data)
            vloss += crit(out, labl).item()
            vcorr += (out.argmax(dim=1) == labl).float().sum()
            # Count real samples: the last batch may be smaller than
            # batch_size, so len(loader) * batch_size would over-count.
            nsamples += data.size(0)
            nbatches += 1

    if nbatches == 0:
        raise ValueError('vloader yielded no batches')

    vacc = accuracy(vcorr, nsamples)
    vloss /= nbatches
    return vacc, vloss
def clf_train(net, tloader, opti: torch.optim.Optimizer, crit: nn.Module, **kwargs):
    """
    Train ``net`` for one pass over ``tloader`` and report top-1/top-5 accuracy.

    NOTE(review): another ``def clf_train`` appears later in this file and
    rebinds this name at import time, so this version is shadowed -- confirm
    which one should be kept.

    Arguments
    ---------
    net : nn.Module
          The net which to train.
    tloader : torch.utils.data.DataLoader or any iterable
              Yields ``(images, labels)`` batches.
    opti : torch.optim.Optimizer
           Optimizer stepping ``net``'s parameters.
    crit : loss class, nn.Module instance or loss callable
           If calling it with no arguments succeeds (e.g. a loss class), the
           result is used as the criterion; otherwise ``crit`` is used as-is.
    topk : tuple, keyword-only via **kwargs, default (1, 5)
           Only ``(1, 5)`` is supported for now.

    Returns
    -------
    ((acc1, acc5), tloss)
        Sample-weighted running averages from the two AvgMeters, and the mean
        of the per-batch losses.

    Raises
    ------
    ValueError
        If ``topk`` is not ``(1, 5)`` or ``tloader`` yields no batches.
    """
    # TODO Support topk other than (1, 5); the two meters below are
    # hard-wired to top-1 and top-5.
    topk = kwargs.get('topk', (1, 5))
    if topk != (1, 5):
        raise ValueError('topk other than (1, 5) not supported for now.')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    net.train()

    a1mtr = AvgMeter('train_acc1')
    a5mtr = AvgMeter('train_acc5')

    # Instantiate loss classes. An already-built module or a plain loss
    # function raises TypeError when called without arguments; keep it as-is.
    try:
        crit = crit()
    except TypeError:
        pass

    tloss = 0
    nbatches = 0
    for data, labl in tqdm(tloader):
        data, labl = data.to(device), labl.to(device)
        out = net(data)
        loss = crit(out, labl)

        opti.zero_grad()
        loss.backward()
        opti.step()

        # Metrics only; no autograd graph needed.
        with torch.no_grad():
            tloss += loss.item()
            acc1, acc5 = accuracy(out, labl, topk=topk)
            a1mtr(acc1, data.size(0))
            a5mtr(acc5, data.size(0))
        nbatches += 1

    if nbatches == 0:
        raise ValueError('tloader yielded no batches')

    tloss /= nbatches
    return (a1mtr.avg, a5mtr.avg), tloss
def clf_test(net, vloader, crit: nn.Module = nn.CrossEntropyLoss, topk=(1, 5)):
    """
    This function helps in quickly testing the network.

    Arguments
    ---------
    net : nn.Module
          The net which to test.
    vloader : torch.utils.data.DataLoader or any iterable
              Yields ``(images, labels)`` batches.
    crit : loss class, nn.Module instance or loss callable
           If calling it with no arguments succeeds (e.g. a loss class such
           as the default ``nn.CrossEntropyLoss``), the result is used as the
           criterion; otherwise ``crit`` is used as-is.
    topk : tuple, default (1, 5)
           Only ``(1, 5)`` is supported for now.

    Returns
    -------
    ((acc1, acc5), vloss)
        Sample-weighted running averages from the two AvgMeters, and the mean
        of the per-batch losses.

    Raises
    ------
    ValueError
        If ``topk`` is not ``(1, 5)`` or ``vloader`` yields no batches.
    """
    # TODO Support topk other than (1, 5); the two meters below are
    # hard-wired to top-1 and top-5.
    if topk != (1, 5):
        raise ValueError('topk other than (1, 5) not supported for now.')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    a1mtr = AvgMeter('test_acc1')
    a5mtr = AvgMeter('test_acc5')
    # Must be initialised: it is accumulated below (it was commented out,
    # which made every call fail with UnboundLocalError).
    vloss = 0
    nbatches = 0

    net.to(device)
    net.eval()

    # Instantiate loss classes. An already-built module or a plain loss
    # function raises TypeError when called without arguments; keep it as-is.
    try:
        crit = crit()
    except TypeError:
        pass

    with torch.no_grad():
        for data, labl in tqdm(vloader):
            data, labl = data.to(device), labl.to(device)
            out = net(data)
            vloss += crit(out, labl).item()
            acc1, acc5 = accuracy(out, labl, topk=topk)
            a1mtr(acc1, data.size(0))
            a5mtr(acc5, data.size(0))
            nbatches += 1

    if nbatches == 0:
        raise ValueError('vloader yielded no batches')

    vloss /= nbatches
    return (a1mtr.avg, a5mtr.avg), vloss
def clf_train(net, tloader, opti: torch.optim.Optimizer, crit: nn.Module, **kwargs):
    """
    Train ``net`` for one pass over ``tloader``.

    NOTE(review): this definition rebinds the name ``clf_train`` already
    defined earlier in this file (the topk-aware version), so this is the
    version that is live after import -- confirm that is intended.

    Arguments
    ---------
    net : nn.Module
          The net which to train.
    tloader : torch.utils.data.DataLoader or any iterable
              Yields ``(images, labels)`` batches.
    opti : torch.optim.Optimizer
           Optimizer stepping ``net``'s parameters.
    crit : loss class, nn.Module instance or loss callable
           If calling it with no arguments succeeds (e.g. a loss class), the
           result is used as the criterion; otherwise ``crit`` is used as-is.
    **kwargs
           Accepted for call-compatibility; not read by this function.

    Returns
    -------
    tacc : the value of ``accuracy(correct, total)`` where ``total`` is the
           number of samples actually seen.
    tloss : float
            Mean of the per-batch losses.

    Raises
    ------
    ValueError
        If ``tloader`` yields no batches.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    net.train()

    # Instantiate loss classes. An already-built module or a plain loss
    # function raises TypeError when called without arguments; keep it as-is.
    try:
        crit = crit()
    except TypeError:
        pass

    tcorr = 0
    tloss = 0
    nsamples = 0
    nbatches = 0
    for data, labl in tqdm(tloader):
        data, labl = data.to(device), labl.to(device)
        out = net(data)
        loss = crit(out, labl)

        opti.zero_grad()
        loss.backward()
        opti.step()

        # Metrics only; no autograd graph needed.
        with torch.no_grad():
            tloss += loss.item()
            tcorr += (out.argmax(dim=1) == labl).float().sum()
        # Count real samples: the last batch may be smaller than batch_size,
        # so len(loader) * batch_size would over-count.
        nsamples += data.size(0)
        nbatches += 1

    if nbatches == 0:
        raise ValueError('tloader yielded no batches')

    tloss /= nbatches
    tacc = accuracy(tcorr, nsamples)
    return tacc, tloss
def benchmark_atk(atk, net: nn.Module, **kwargs):
    """
    Helper function to benchmark a particular attack on a particular dataset.
    All benchmarks that are present in this repository are created using this
    function.

    Prints the top-k accuracy and the mean cross-entropy loss of ``net`` on
    the clean inputs and on the adversarial inputs produced by ``atk``.

    Arguments
    ---------
    atk : scratchai.attacks.attacks
          The attack class to use; it is instantiated as ``atk(net, **kw)``
          with the kwargs returned by ``pre_benchmark_atk``.
    net : nn.Module
          The net which is to be attacked. It is frozen before benchmarking.
    **kwargs
          Passed to ``pre_benchmark_atk``, which returns the data loader, the
          ``topk`` tuple and the kwargs forwarded to the attack. The keys
          below are handled there (not visible in this function):
    root : str
           The root directory of the dataset.
    dfunc : function
            The function that can take the root and torchvision.transforms
            and return a torchvision.Datasets object.
            Defaults to datasets.ImageFolder.
    trf : torchvision.Transforms
          The transforms that you want to apply.
          Defaults to get_trf('rz256_cc224_tt_normimgnet').
    bs : int
         The batch size. Defaults to 4.
    """
    loader, topk, kwargs = pre_benchmark_atk(**kwargs)

    freeze(net)
    print('[INFO] Net Frozen!')

    atk = atk(net, **kwargs)
    atk_name = name_from_object(atk)
    net_name = name_from_object(net)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    net.eval()

    crit = nn.CrossEntropyLoss()
    oatopk = Topk('original accuracy', topk)
    aatopk = Topk('adversarial accuracy', topk)
    loss = 0
    adv_loss = 0
    nbatches = 0

    for data, labl in tqdm(loader):
        data, labl = data.to(device), labl.to(device)
        # The attack may need input gradients, so it runs outside no_grad.
        # It gets its own copy so ``data`` stays the clean batch.
        adv_data = atk(data.clone())

        # Evaluation only: no autograd graph is needed for these passes.
        with torch.no_grad():
            adv_out = net(adv_data)
            out = net(data)
            loss += crit(out, labl).item()
            adv_loss += crit(adv_out, labl).item()
            acc = accuracy(out, labl, topk)
            adv_acc = accuracy(adv_out, labl, topk)

        oatopk.update(acc, data.size(0))
        aatopk.update(adv_acc, data.size(0))
        nbatches += 1

    # Guard against an empty loader instead of dividing by zero.
    loss /= max(nbatches, 1)
    adv_loss /= max(nbatches, 1)

    print('\nAttack Summary on {} with {} attack:'.format(net_name, atk_name))
    print('-' * 45)
    print(oatopk)
    print('original loss: {:.4f}'.format(loss))
    print('-' * 35)
    print(aatopk)
    print('adversarial loss: {:.4f}'.format(adv_loss))