class Infer:
    """Evaluate a (optionally checkpointed) classifier on the target dataloader.

    The model is built from ``args.network`` ('resnet50' selects resnet50;
    anything else falls back to resnet18). If ``args.infer_model`` is set,
    weights are loaded from that checkpoint path.
    """

    # Older checkpoints name the last linear layer ``fc``; the model used here
    # exposes it as ``class_classifier``, so those keys are renamed on load.
    _FC_KEY_RENAMES = {
        'fc.weight': 'class_classifier.weight',
        'fc.bias': 'class_classifier.bias',
    }

    def __init__(self, args, device):
        """Build the model, optionally load a checkpoint, and create the dataloader.

        Args:
            args: run configuration; reads ``network``, ``pretrained``,
                ``n_classes`` and ``infer_model``.
            device: device the model and batches are moved to.

        Raises:
            ValueError: if ``args.infer_model`` is set but is not an existing file.
        """
        self.args = args
        self.device = device
        network = resnet50 if args.network == 'resnet50' else resnet18
        model = network(pretrained=self.args.pretrained, classes=args.n_classes)
        self.model = model.to(device)

        self.model_params_PATH = None
        if args.infer_model:
            self.model_params_PATH = args.infer_model
            self._load_checkpoint(self.model_params_PATH)

        self.dataloader = data_helper.get_tgt_dataloader(
            self.args, patches=self.model.is_patch_based())

    @classmethod
    def _extract_state_dict(cls, checkpoint):
        """Return the model state_dict contained in ``checkpoint``.

        Accepts both a wrapped checkpoint (``{'model': state_dict, ...}``) and
        a bare state_dict such as the one ``Trainer`` saves with
        ``torch.save(model.state_dict(), ...)``. Legacy ``fc.*`` keys are renamed
        to ``class_classifier.*``. The renaming happens in place so that any
        extra attributes on the state_dict object (e.g. ``_metadata``) survive.
        """
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        for old_key, new_key in cls._FC_KEY_RENAMES.items():
            if old_key in state_dict:
                state_dict[new_key] = state_dict.pop(old_key)
        return state_dict

    def _load_checkpoint(self, path):
        """Load weights from ``path`` into ``self.model``.

        Raises:
            ValueError: if ``path`` is not an existing file.
        """
        if not isfile(path):
            raise ValueError(f"Failed to find checkpoint {path}")
        print(f"=> loading checkpoint '{path}'")
        # map_location lets a checkpoint saved on GPU be loaded on any device.
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(self._extract_state_dict(checkpoint))
        print("=> loaded checkpoint")

    def eval(self):
        """Run inference on the target dataloader and log accuracy, AUC and FPRs."""
        self.model.eval()
        self.logger = Logger(self.args)
        with torch.no_grad():
            total = len(self.dataloader.dataset)
            class_correct, auc_dict = self.do_test(self.dataloader)
            class_acc = float(class_correct) / total
            self.logger.log_test('Inference Result', {'class_acc': class_acc})
            for metric in ('auc', 'fpr_980', 'fpr_991'):
                self.logger.log_test('Inference Result', {metric: auc_dict[metric]})

    def do_test(self, loader):
        """Return ``(correct_count, auc_dict)`` over every batch of ``loader``.

        ``loader`` yields ``((data, unused, class_labels), unused)`` items.
        ``auc_dict`` holds the values returned by ``AUCMeter.calculate()`` under
        the keys 'auc', 'fpr_980', 'fpr_991', 'fpr_993', 'fpr_995', 'fpr_997',
        'fpr_999', 'fpr_1' and 'thresholds'.
        """
        class_correct = 0
        auc_meter = AUCMeter()
        for (data, _nouse, class_l), _ in tqdm(loader):
            data = data.to(self.device)
            class_l = class_l.to(self.device)
            # NOTE(review): Trainer calls self.model(data, feature_flag=False);
            # confirm the model's forward also accepts this positional form.
            class_logit = self.model(data, class_l, False)
            _, cls_pred = class_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            cls_score = F.softmax(class_logit, dim=1)
            auc_meter.update(class_l.cpu(), cls_score.cpu())

        # Explicit unpacking keeps the original strictness: a change in the
        # number of returned metrics raises instead of silently truncating.
        (auc, fpr_980, fpr_991, fpr_993, fpr_995, fpr_997, fpr_999, fpr_1,
         thresholds) = auc_meter.calculate()
        auc_dict = {
            'auc': auc,
            'fpr_980': fpr_980,
            'fpr_991': fpr_991,
            'fpr_993': fpr_993,
            'fpr_995': fpr_995,
            'fpr_997': fpr_997,
            'fpr_999': fpr_999,
            'fpr_1': fpr_1,
            'thresholds': thresholds,
        }
        return class_correct, auc_dict
class Trainer:
    """Train a ResNet classifier jointly with an IntraClsInfoMax (DIM) objective.

    Each epoch trains on ``source_loader``, then evaluates on the 'val' and
    'test' loaders. Per-epoch accuracies are recorded in ``self.results``, and
    the weights with the highest test accuracy so far are saved to
    ``<log_path>/best_test.pth``.
    """

    def __init__(self, args, device):
        """Build models, dataloaders, optimizers/schedulers and the logger.

        Args:
            args: run configuration; reads ``network``, ``n_classes``, ``alpha``,
                ``beta``, ``gamma``, ``epochs``, ``learning_rate``, ``train_all``,
                ``nesterov``, ``source`` and ``target``.
            device: device the models are moved to.
        """
        self.args = args
        self.device = device
        # 'resnet50' selects resnet50; any other name falls back to resnet18.
        network = resnet50 if args.network == 'resnet50' else resnet18
        model = network(pretrained=True, classes=args.n_classes)
        self.model = model.to(device)
        self.D_model = IntraClsInfoMax(
            alpha=args.alpha, beta=args.beta, gamma=args.gamma).to(device)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(self.val_loader.dataset),
               len(self.target_loader.dataset)))

        # Main optimizer: classifier + global/local DIM discriminators.
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            [self.model, self.D_model.global_d, self.D_model.local_d],
            args.epochs, args.learning_rate, args.train_all,
            nesterov=args.nesterov)
        # Separate optimizer for the prior discriminator.
        self.dis_optimizer, self.dis_scheduler = get_optim_and_scheduler(
            [self.D_model.prior_d], args.epochs, args.learning_rate,
            args.train_all, nesterov=args.nesterov)

        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None
        self.max_test_acc = 0.0
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs),
        }

    def _do_epoch(self, device='cuda'):
        """Train for one epoch, then evaluate on every loader in ``test_loaders``.

        Args:
            device: unused (``self.device`` is used); kept for backward compatibility.
        """
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        self.D_model.train()
        for it, ((data, _jig_l, class_l), _d_idx) in enumerate(self.source_loader):
            data = data.to(self.device)
            class_l = class_l.to(self.device)

            self.optimizer.zero_grad()
            # Augment with horizontally flipped copies along dim 3; the batch
            # doubles and labels are duplicated to match.
            data_flip = torch.flip(data, (3, )).detach().clone()
            data = torch.cat((data, data_flip))
            class_l = torch.cat((class_l, class_l))

            y, M = self.model(data, feature_flag=True)

            # Classification loss
            class_logit = self.model.class_classifier(y)
            class_loss = criterion(class_logit, class_l)

            # DIM loss: M_prime / class_prime shift every sample one position
            # forward along the batch dim (first element wraps to the end).
            M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
            class_prime = torch.cat((class_l[1:], class_l[0].unsqueeze(0)), dim=0)
            DIM_loss = self.D_model(y, M, M_prime, (class_l, class_prime))
            P_loss = self.D_model.prior_loss(y)
            DIM_loss = DIM_loss - P_loss

            loss = class_loss + DIM_loss
            loss.backward()
            self.optimizer.step()

            # Prior discriminator update on detached features.
            self.dis_optimizer.zero_grad()
            P_loss = self.D_model.prior_loss(y.detach())
            P_loss.backward()
            self.dis_optimizer.step()

            _, cls_pred = class_logit.max(dim=1)
            losses = {
                'class': class_loss.detach().item(),
                'DIM': DIM_loss.detach().item(),
                'P_loss': P_loss.detach().item(),
            }
            self.logger.log(
                it, len(self.source_loader), losses,
                {"class": torch.sum(cls_pred == class_l.data).item()},
                data.shape[0])
            del loss, class_loss, class_logit, DIM_loss

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                class_correct = self.do_test(loader)
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"class": class_acc})
                self.results[phase][self.current_epoch] = class_acc
                if phase == 'test' and class_acc > self.max_test_acc:
                    # Track the running best so only improvements overwrite
                    # the saved checkpoint.
                    self.max_test_acc = class_acc
                    torch.save(
                        self.model.state_dict(),
                        os.path.join(self.logger.log_path,
                                     'best_{}.pth'.format(phase)))

    def do_test(self, loader):
        """Return the number of correctly classified samples in ``loader``.

        ``loader`` yields ``((data, unused, class_labels), unused)`` items.
        """
        class_correct = 0
        for (data, _nouse, class_l), _ in loader:
            data = data.to(self.device)
            class_l = class_l.to(self.device)
            class_logit = self.model(data, feature_flag=False)
            _, cls_pred = class_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
        return class_correct

    def _current_lrs(self):
        """Return the current learning rate of every param group of both optimizers."""
        return [group['lr']
                for optimizer in (self.optimizer, self.dis_optimizer)
                for group in optimizer.param_groups]

    def do_training(self):
        """Run ``args.epochs`` epochs, report the best epoch, and return ``(logger, model)``."""
        for self.current_epoch in range(self.args.epochs):
            self.logger.new_epoch(self._current_lrs())
            self._do_epoch()  # uses self.current_epoch
            # Since PyTorch 1.1, schedulers must step after the optimizer's
            # updates; stepping first would skip the initial LR value.
            self.scheduler.step()
            self.dis_scheduler.step()

        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g, best epoch: %g" %
              (val_res.max().item(), test_res[idx_best].item(),
               test_res.max().item(), idx_best.item()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model