import torch
import torch.nn as nn

import data_loaders
# Model import paths assumed from the repo layout.
from models.resnet import ResNet18
from models.wide_resnet import WideResNet


def load_baseline_model(args):
    """Build the baseline CNN and its data loaders.

    :param args: parsed command-line arguments (dataset, model, batch_size,
        data_augmentation, load_baseline_checkpoint, ...)
    :return: (model, train_loader, val_loader, test_loader, checkpoint)
    """
    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(
            args.batch_size, val_split=True,
            augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(
            args.batch_size, val_split=True,
            augmentation=args.data_augmentation)
    elif args.dataset == 'mnist':
        # NOTE: this overrides whatever subset sizes were passed on the
        # command line.
        args.datasize, args.valsize, args.testsize = 100, 100, 100
        num_train = args.datasize
        if args.datasize == -1:
            num_train = 50000
        from data_loaders import load_mnist
        train_loader, val_loader, test_loader = load_mnist(
            args.batch_size,
            subset=[args.datasize, args.valsize, args.testsize],
            num_train=num_train)

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'wideresnet':
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10,
                         dropRate=0.3)

    checkpoint = None
    if args.load_baseline_checkpoint:
        checkpoint = torch.load(args.load_baseline_checkpoint)
        cnn.load_state_dict(checkpoint['model_state_dict'])

    model = cnn.cuda()
    model.train()
    return model, train_loader, val_loader, test_loader, checkpoint
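def _demo_load_baseline():
    # Hedged usage sketch (not part of the original file): builds a
    # hypothetical argparse.Namespace with the fields load_baseline_model
    # reads; the values here are illustrative, not the repo's defaults.
    import argparse
    args = argparse.Namespace(
        dataset='cifar10', model='resnet18', batch_size=128,
        data_augmentation=True, load_baseline_checkpoint=None)
    model, train_loader, val_loader, test_loader, ckpt = \
        load_baseline_model(args)
    return model, train_loader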
if args.dataset == 'cifar10':
    num_classes = 10
    train_loader, val_loader, test_loader = data_loaders.load_cifar10(
        args.batch_size, val_split=True, augmentation=args.data_augmentation)
elif args.dataset == 'cifar100':
    num_classes = 100
    train_loader, val_loader, test_loader = data_loaders.load_cifar100(
        args.batch_size, val_split=True, augmentation=args.data_augmentation)

if args.model == 'resnet18':
    cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
    cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10,
                     dropRate=0.3)

cnn = cnn.cuda()
criterion = nn.CrossEntropyLoss().cuda()

if args.optimizer == 'sgdm':
    cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.lr,
                                    momentum=0.9, nesterov=True,
                                    weight_decay=args.wdecay)
elif args.optimizer == 'sgd':
    cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.lr,
                                    weight_decay=args.wdecay)
elif args.optimizer == 'adam':
    cnn_optimizer = torch.optim.Adam(cnn.parameters(), lr=args.lr,
                                     weight_decay=args.wdecay)
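def _demo_baseline_step(batch):
    # Minimal single-step sketch (illustrative, not from the repo): assumes
    # `cnn`, `criterion`, and `cnn_optimizer` were built by the dispatch
    # above and that `batch` is an (images, labels) pair of CUDA tensors.
    images, labels = batch
    cnn_optimizer.zero_grad()
    loss = criterion(cnn(images), labels)
    loss.backward()
    cnn_optimizer.step()
    return loss.item()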
import os

import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

# Repo-local helpers; import paths assumed from usage.
from utils import Harakiri, Meter, JsonLogger
from datasets import ImageList


def main(args):
    harakiri = Harakiri()
    harakiri.set_max_plateau(20)
    train_loss_meter = Meter()
    val_loss_meter = Meter()
    val_accuracy_meter = Meter()
    log = JsonLogger(args.log_path, rand_folder=True)
    log.update(args.__dict__)
    state = args.__dict__
    state['exp_dir'] = os.path.dirname(log.path)
    state['start_lr'] = state['lr']
    print(state)

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    train_dataset = ImageList(args.root_folder, args.train_listfile,
                              transform=transforms.Compose([
                                  transforms.Resize(256),
                                  transforms.RandomCrop(224),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor(),
                                  transforms.Normalize(imagenet_mean,
                                                       imagenet_std)
                              ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=False,
                                               num_workers=args.num_workers)
    val_dataset = ImageList(args.root_folder, args.val_listfile,
                            transform=transforms.Compose([
                                transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize(imagenet_mean,
                                                     imagenet_std)
                            ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=False,
                                             num_workers=args.num_workers)

    if args.attention_depth == 0:
        from models.wide_resnet import WideResNet
        model = WideResNet().finetune(args.nlabels).cuda()
    else:
        from models.wide_resnet_attention import WideResNetAttention
        model = WideResNetAttention(args.nlabels, args.attention_depth,
                                    args.attention_width, args.has_gates,
                                    args.reg_weight).finetune(args.nlabels)

    # if args.load != "":
    #     net.load_state_dict(torch.load(args.load), strict=False)
    #     net = net.cuda()

    # The pretrained backbone gets a 10x smaller learning rate than the
    # freshly initialized classifier head.
    optimizer = optim.SGD(
        [{'params': model.get_base_params(), 'lr': args.lr * 0.1},
         {'params': model.get_classifier_params()}],
        lr=args.lr, weight_decay=1e-4, momentum=0.9, nesterov=True)

    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, range(args.ngpu)).cuda()
    else:
        model = model.cuda()

    criterion = torch.nn.NLLLoss().cuda()

    def train():
        """Run one training epoch."""
        model.train()
        for data, label in train_loader:
            # `non_blocking=True` replaces the old `async=True`, which is a
            # syntax error on Python 3.7+; the Variable wrappers are gone
            # since tensors track gradients directly in modern PyTorch.
            data = data.cuda(non_blocking=True)
            label = label.cuda()
            optimizer.zero_grad()
            if args.attention_depth > 0:
                output, loss = model(data)
                if args.reg_weight > 0:
                    loss = loss.mean()
                else:
                    loss = 0
            else:
                loss = 0
                output = model(data)
            loss += F.nll_loss(output, label)
            loss.backward()
            optimizer.step()
            # `.item()` replaces the deprecated `loss.data[0]`.
            train_loss_meter.update(loss.item(), data.size(0))
        state['train_loss'] = train_loss_meter.mean()

    def val():
        """Evaluate on the validation set."""
        model.eval()
        with torch.no_grad():  # replaces the deprecated `volatile=True`
            for data, label in val_loader:
                data = data.cuda(non_blocking=True)
                label = label.cuda()
                if args.attention_depth > 0:
                    output, loss = model(data)
                else:
                    output = model(data)
                    loss = F.nll_loss(output, label)
                val_loss_meter.update(loss.item(), data.size(0))
                preds = output.max(1)[1]
                val_accuracy_meter.update(
                    (preds == label).float().sum().item(), data.size(0))
        state['val_loss'] = val_loss_meter.mean()
        state['val_accuracy'] = val_accuracy_meter.mean()

    best_accuracy = 0
    counter = 0
    for epoch in range(args.epochs):
        train()
        val()
        harakiri.update(epoch, state['val_accuracy'])
        if state['val_accuracy'] > best_accuracy:
            counter = 0
            best_accuracy = state['val_accuracy']
            if args.save:
                torch.save(model.state_dict(),
                           os.path.join(state['exp_dir'], 'model.pytorch'))
        else:
            counter += 1
        state['epoch'] = epoch + 1
        log.update(state)
        print(state)
        if (epoch + 1) in args.schedule:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
            state['lr'] *= 0.1
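# Hedged entry-point sketch (not from the repo): every flag below is inferred
# from the `args.*` accesses in main(); the names match those accesses but
# the defaults are assumptions.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_folder', required=True)
    parser.add_argument('--train_listfile', required=True)
    parser.add_argument('--val_listfile', required=True)
    parser.add_argument('--log_path', default='./logs')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--nlabels', type=int, required=True)
    parser.add_argument('--attention_depth', type=int, default=0)
    parser.add_argument('--attention_width', type=int, default=4)
    parser.add_argument('--has_gates', action='store_true')
    parser.add_argument('--reg_weight', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--schedule', type=int, nargs='+', default=[30, 60])
    parser.add_argument('--save', action='store_true')
    main(parser.parse_args())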
import time

import torch

# `cfg`, `DEPTH`, and `WideResNet` come from the surrounding repo; these
# import paths are assumed.
from config import cfg, DEPTH
from models.wide_resnet import WideResNet


class NeuralNet:
    def __init__(self):
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
            print('WARNING: Found no valid GPU device - Running on CPU')
        self.model = WideResNet(DEPTH, cfg.NUM_TRANS)
        # `.to(device)` instead of `.cuda(device)` so the CPU fallback works.
        self.model.to(self.device)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)
        self.optimizer = None

    def test(self, test_gen):
        batch_time = self.AverageMeter('Time', ':6.3f')
        losses = self.AverageMeter('Loss', ':.4e')
        top1 = self.AverageMeter('Acc@1', ':6.2f')
        progress = self.ProgressMeter(len(test_gen), batch_time, losses,
                                      top1, prefix='Test: ')

        # switch to evaluate mode
        self.model.eval()
        with torch.no_grad():
            end = time.time()
            for i, (input, target, _) in enumerate(test_gen):
                input = input.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True)

                # Compute output
                output, _ = self.model([input, target])
                loss = self.criterion(output, target)

                # Measure accuracy and record loss; `topk` must be a tuple,
                # so `(1,)` rather than `(1)`, which is just the int 1
                acc1 = self._accuracy(output, target, topk=(1,))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1[0].cpu().detach().item(), input.size(0))

                # Measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # Print to screen
                if i % 100 == 0:
                    progress.print(i)

            # TODO: this should also be done with the ProgressMeter
            print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
        return top1.avg

    def train(self, train_gen, test_gen, epochs, lr=0.0001, lr_plan=None,
              momentum=0.9, wd=5e-4):
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr,
                                         momentum=momentum, weight_decay=wd)
        # self.optimizer = torch.optim.Adam(self.model.parameters())
        for epoch in range(epochs):
            self._adjust_lr_rate(self.optimizer, epoch, lr_plan)
            print("=> Training (specific label)")
            self._train_step(train_gen, epoch, self.optimizer)
            print("=> Validation (entire dataset)")
            self.test(test_gen)

    def _train_step(self, train_gen, epoch, optimizer):
        self.model.train()
        batch_time = self.AverageMeter('Time', ':6.3f')
        data_time = self.AverageMeter('Data', ':6.3f')
        losses = self.AverageMeter('Loss', ':.4e')
        top1 = self.AverageMeter('Acc@1', ':6.2f')
        progress = self.ProgressMeter(len(train_gen), batch_time, data_time,
                                      losses, top1,
                                      prefix="Epoch: [{}]".format(epoch))
        end = time.time()
        for i, (input, target, _) in enumerate(train_gen):
            # measure data loading time
            data_time.update(time.time() - end)

            input = input.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)

            # Compute output
            output, trans_out = self.model([input, target])
            loss = self.criterion(output, target)

            # measure accuracy and record loss
            acc1 = self._accuracy(output, target, topk=(1,))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1[0].cpu().detach().item(), input.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 100 == 0:
                progress.print(i)

    def evaluate(self, eval_gen):
        # switch to evaluate mode
        self.model.eval()
        score_func_list = []
        labels_list = []
        with torch.no_grad():
            for i, (input, target, labels) in enumerate(eval_gen):
                input = input.to(self.device, non_blocking=True)
                # Target: the transformation class
                target = target.to(self.device, non_blocking=True)
                # The true label: keep one label per group of NUM_TRANS
                # transformed copies of the same sample
                labels = labels[[cfg.NUM_TRANS * x
                                 for x in range(len(labels) // cfg.NUM_TRANS)
                                 ]].to(self.device, non_blocking=True)

                # Compute output
                # TODO: Rewrite this code section, can be more efficient
                output_SM = self.model([input, target])
                target_mat = torch.zeros_like(output_SM[0])
                target_mat[range(output_SM[0].shape[0]), target] = 1
                # Average, over the NUM_TRANS transformations, of the softmax
                # mass assigned to the correct transformation class
                target_SM = (target_mat * output_SM[0]).sum(dim=1).view(
                    -1, cfg.NUM_TRANS).sum(dim=1)
                score_func_list.append(1 / cfg.NUM_TRANS * target_SM)
                labels_list.append(labels)
        return torch.cat(score_func_list), torch.cat(labels_list)

    def _adjust_lr_rate(self, optimizer, epoch, lr_dict):
        if lr_dict is None:
            return
        for key, value in lr_dict.items():
            if epoch == key:
                print("=> New learning rate set of {}".format(value))
                for param_group in optimizer.param_groups:
                    param_group['lr'] = value

    def summary(self, x_size, print_it=True):
        return self.model.summary(x_size, print_it=print_it)

    def print_weights(self):
        self.model.print_weights()

    @staticmethod
    def _accuracy(output, target, topk=(1,)):
        """Computes the accuracy over the k top predictions for the
        specified values of k"""
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0,
                                                                keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res

    class AverageMeter(object):
        """Computes and stores the average and current value"""

        def __init__(self, name, fmt=':f'):
            self.name = name
            self.fmt = fmt
            self.reset()

        def reset(self):
            self.val = 0
            self.avg = 0
            self.sum = 0
            self.count = 0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

        def __str__(self):
            fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
            return fmtstr.format(**self.__dict__)

    class ProgressMeter(object):
        def __init__(self, num_batches, *meters, prefix=""):
            self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
            self.meters = meters
            self.prefix = prefix

        def print(self, batch):
            entries = [self.prefix + self.batch_fmtstr.format(batch)]
            entries += [str(meter) for meter in self.meters]
            print('\t'.join(entries))

        def _get_batch_fmtstr(self, num_batches):
            num_digits = len(str(num_batches))
            fmt = '{:' + str(num_digits) + 'd}'
            return '[' + fmt + '/' + fmt.format(num_batches) + ']'
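def _demo_neuralnet(train_gen, test_gen):
    # Hedged usage sketch (not part of the original file): `train_gen` and
    # `test_gen` are assumed to yield (input, transformation_target, label)
    # triples, matching the unpacking in _train_step and test above; the
    # epoch count and learning rates are illustrative.
    net = NeuralNet()
    net.train(train_gen, test_gen, epochs=10, lr=1e-3,
              lr_plan={5: 1e-4})  # drop the LR to 1e-4 at epoch 5
    scores, labels = net.evaluate(test_gen)
    return scores, labels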