class Trainer:
    """Multi-task trainer combining object classification with up to three
    self-supervised auxiliary tasks: jigsaw puzzle solving, odd-one-out
    detection, and rotation prediction.

    Each batch carries a `self_sup_task` id selecting which auxiliary head the
    sample feeds: 0 = jigsaw, 1 = odd-one-out, 2 = rotation, 3 = all tasks.
    The loader packs the active task's label into `jigsaw_l` for every task
    (the same tensor is used as target for the odd and rotation heads) —
    presumably the dataset stores one self-supervision label per sample;
    TODO confirm against the data_helper implementation.
    """

    def __init__(self, args, device):
        # Fixed weights for the three auxiliary losses in the total loss.
        self.alpha_jigsaw_weight = 0.5
        self.alpha_odd_weight = 0.5
        self.alpha_rotation_weight = 0.5
        self.args = args
        self.device = device
        self.betaJigen = args.betaJigen  # NOTE(review): stored but never read in this class
        # Model exposes four heads: class, jigsaw (31-way), odd (10-way), rotation (4-way).
        model = model_factory.get_network(args.network)(classes=args.n_classes, jigsaw_classes=31, odd_classes=10, rotation_classes=4)
        #if args.rotation== True:
        #    model = model_factory.get_network(args.network)(classes=args.n_classes,jigsaw_classes=4)
        #elif args.oddOneOut == True:
        #    model = model_factory.get_network(args.network)(classes=args.n_classes,jigsaw_classes=10)
        #else:
        #    model = model_factory.get_network(args.network)(classes=args.n_classes,jigsaw_classes=31)
        self.model = model.to(device)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(args)
        self.target_loader = data_helper.get_val_dataloader(args)
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(self.val_loader.dataset),
               len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model, args.epochs, args.learning_rate, args.train_all)
        self.n_classes = args.n_classes
        # Number of tasks averaged in the validation score:
        # class + jigsaw always, plus odd and/or rotation when enabled.
        if args.oddOneOut == True and args.rotation == True:
            self.nTasks = 4
        elif args.oddOneOut == True or args.rotation == True:
            self.nTasks = 3
        else:
            self.nTasks = 2

    def _do_epoch(self):
        """Train for one epoch on the source loader, then evaluate every test
        loader and record the averaged per-task accuracy in ``self.results``.

        Requires ``self.logger``, ``self.results`` and ``self.current_epoch``
        to have been set by :meth:`do_training`.
        """
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, (data, class_l, jigsaw_l, self_sup_task) in enumerate(self.source_loader):
            # source_loader provides the training data only
            data, class_l, jigsaw_l, self_sup_task = data.to(self.device), class_l.to(self.device), jigsaw_l.to(self.device), self_sup_task.to(self.device)
            self.optimizer.zero_grad()
            # One forward pass yields the logits of all four heads.
            class_logit, jigsaw_logit, odd_logit, rotation_logit = self.model(data)
            # Jigsaw loss on samples tagged for the jigsaw task (0) or all tasks (3).
            jigsaw_loss = criterion(
                jigsaw_logit[(self_sup_task == 0) | (self_sup_task == 3)],
                jigsaw_l[(self_sup_task == 0) | (self_sup_task == 3)])
            # Odd-one-out loss on samples tagged 1 or 3 (label lives in jigsaw_l).
            if self.args.oddOneOut == True:
                odd_loss = criterion(
                    odd_logit[(self_sup_task == 1) | (self_sup_task == 3)],
                    jigsaw_l[(self_sup_task == 1) | (self_sup_task == 3)])
            else:
                odd_loss = 0
            # Rotation loss on samples tagged 2 or 3 (label lives in jigsaw_l).
            if self.args.rotation == True:
                rotation_loss = criterion(
                    rotation_logit[(self_sup_task == 2) | (self_sup_task == 3)],
                    jigsaw_l[(self_sup_task == 2) | (self_sup_task == 3)])
            else:
                rotation_loss = 0
            # Classification loss is computed only on non-scrambled images
            # (jigsaw label 0 appears to mean "original ordering" — TODO confirm).
            class_loss = criterion(class_logit[jigsaw_l == 0], class_l[jigsaw_l == 0])
            # Per-head predictions, restricted to the samples each head was trained on.
            _, jigsaw_pred = jigsaw_logit[(self_sup_task == 0) | (self_sup_task == 3)].max(dim=1)
            if self.args.oddOneOut == True:
                _, odd_pred = odd_logit[(self_sup_task == 1) | (self_sup_task == 3)].max(dim=1)
            if self.args.rotation == True:
                _, rotation_pred = rotation_logit[(self_sup_task == 2) | (self_sup_task == 3)].max(dim=1)
            _, cls_pred = class_logit.max(dim=1)
            # Total loss: classification plus weighted auxiliary losses
            # (disabled tasks contribute 0).
            loss = class_loss + self.alpha_jigsaw_weight * jigsaw_loss + self.alpha_odd_weight * odd_loss + self.alpha_rotation_weight * rotation_loss
            loss.backward()
            self.optimizer.step()
            # Log losses/accuracies; the four branches differ only in which
            # auxiliary task entries are included.
            if self.args.oddOneOut == True and self.args.rotation == True:
                self.logger.log(
                    it, len(self.source_loader), {
                        "Class Loss ": class_loss.item(),
                        "Jigsaw Loss": jigsaw_loss.item(),
                        "Odd Loss": odd_loss.item(),
                        "Rotation Loss": rotation_loss.item()
                    }, {
                        "Class Accuracy ": torch.sum(cls_pred == class_l.data).item(),
                        "Jigsaw Accuracy ": torch.sum(jigsaw_pred == jigsaw_l[(self_sup_task == 0) | (self_sup_task == 3)].data).item(),
                        "Odd Accuracy ": torch.sum(odd_pred == jigsaw_l[(self_sup_task == 1) | (self_sup_task == 3)].data).item(),
                        "Rotation Accuracy ": torch.sum(rotation_pred == jigsaw_l[(self_sup_task == 2) | (self_sup_task == 3)].data).item()
                    }, data.shape[0])
            elif self.args.oddOneOut == True and self.args.rotation == False:
                self.logger.log(
                    it, len(self.source_loader), {
                        "Class Loss ": class_loss.item(),
                        "Jigsaw Loss": jigsaw_loss.item(),
                        "Odd Loss": odd_loss.item()
                    }, {
                        "Class Accuracy ": torch.sum(cls_pred == class_l.data).item(),
                        "Jigsaw Accuracy ": torch.sum(jigsaw_pred == jigsaw_l[(self_sup_task == 0) | (self_sup_task == 3)].data).item(),
                        "Odd Accuracy ": torch.sum(odd_pred == jigsaw_l[(self_sup_task == 1) | (self_sup_task == 3)].data).item()
                    }, data.shape[0])
            elif self.args.oddOneOut == False and self.args.rotation == True:
                self.logger.log(
                    it, len(self.source_loader), {
                        "Class Loss ": class_loss.item(),
                        "Jigsaw Loss": jigsaw_loss.item(),
                        "Rotation Loss": rotation_loss.item()
                    }, {
                        "Class Accuracy ": torch.sum(cls_pred == class_l.data).item(),
                        "Jigsaw Accuracy ": torch.sum(jigsaw_pred == jigsaw_l[(self_sup_task == 0) | (self_sup_task == 3)].data).item(),
                        "Rotation Accuracy ": torch.sum(rotation_pred == jigsaw_l[(self_sup_task == 2) | (self_sup_task == 3)].data).item()
                    }, data.shape[0])
            else:
                self.logger.log(
                    it, len(self.source_loader), {
                        "Class Loss ": class_loss.item(),
                        "Jigsaw Loss": jigsaw_loss.item()
                    }, {
                        "Class Accuracy ": torch.sum(cls_pred == class_l.data).item(),
                        "Jigsaw Accuracy ": torch.sum(jigsaw_pred == jigsaw_l[(self_sup_task == 0) | (self_sup_task == 3)].data).item()
                    }, data.shape[0])
            # Free graph/logits early to reduce peak memory within the epoch.
            del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit, odd_loss, rotation_loss, odd_logit, rotation_logit
        # Evaluation phase: average the enabled tasks' accuracies into one score.
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                class_correct, jigsaw_correct, odd_correct, rotation_correct = self.do_test(loader)
                class_acc = float(class_correct) / total
                jigsaw_acc = float(jigsaw_correct) / total
                odd_acc = float(odd_correct) / total
                rotation_acc = float(rotation_correct) / total
                # Disabled tasks contribute 0 correct, and nTasks excludes them,
                # so the mean is over the active tasks only.
                acc = (class_acc + jigsaw_acc + odd_acc + rotation_acc) / self.nTasks
                self.logger.log_test(phase, {"Classification Accuracy": acc})
                self.results[phase][self.current_epoch] = acc

    def do_test(self, loader):
        """Count correct predictions for each head over ``loader``.

        Returns a 4-tuple of scalar tensors:
        (class_correct, jigsaw_correct, odd_correct, rotation_correct).
        Disabled heads return 0. Odd/rotation predictions are compared against
        ``jigsaw_l`` on the full batch (no task-id filtering at test time).
        """
        class_correct = 0
        jigsaw_correct = 0
        odd_correct = 0
        rotation_correct = 0
        for it, (data, class_l, jigsaw_l, self_sup_task) in enumerate(loader):
            data, class_l, jigsaw_l, self_sup_task = data.to(self.device), class_l.to(self.device), jigsaw_l.to(self.device), self_sup_task.to(self.device)
            class_logit, jigsaw_logit, odd_logit, rotation_logit = self.model(data)
            _, jigsaw_pred = jigsaw_logit.max(dim=1)
            if self.args.oddOneOut == True:
                _, odd_pred = odd_logit.max(dim=1)
                odd_correct += torch.sum(odd_pred == jigsaw_l.data)
            if self.args.rotation == True:
                _, rotation_pred = rotation_logit.max(dim=1)
                rotation_correct += torch.sum(rotation_pred == jigsaw_l.data)
            _, cls_pred = class_logit.max(dim=1)
            jigsaw_correct += torch.sum(jigsaw_pred == jigsaw_l.data)
            class_correct += torch.sum(cls_pred == class_l.data)
        return class_correct, jigsaw_correct, odd_correct, rotation_correct

    def do_training(self):
        """Run the full training schedule.

        Creates the logger and the per-epoch result buffers, trains for
        ``args.epochs`` epochs (stepping the LR scheduler after each epoch),
        prints the best val/test scores, and returns ``(logger, model)``.
        """
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {"val": torch.zeros(self.args.epochs), "test": torch.zeros(self.args.epochs)}
        for self.current_epoch in range(self.args.epochs):
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
            self.scheduler.step()
        val_res = self.results["val"]
        test_res = self.results["test"]
        # Model selection: epoch with the best validation score.
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g" %
              (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
class Trainer:
    """Plain supervised baseline ("deep all") trainer.

    Trains a resnet18 classifier on the aggregated source domains and
    evaluates classification accuracy on the "val" and "test" loaders after
    every epoch. No self-supervised auxiliary task is used here.
    """

    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = resnet18(pretrained=True, classes=args.n_classes)  # ------
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(args, patches=model.is_patch_based())
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" % (
            len(self.source_loader.dataset), len(self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(model, args.epochs, args.learning_rate, args.train_all,
                                                                 nesterov=args.nesterov)
        # Kept for interface parity with the jigsaw trainers; this class does
        # not read them.
        self.jig_weight = args.jig_weight
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self):
        """Train one epoch, then evaluate all test loaders.

        Requires ``self.logger``, ``self.results`` and ``self.current_epoch``
        set by :meth:`do_training`.
        """
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(self.device), class_l.to(
                self.device), d_idx.to(self.device)
            self.optimizer.zero_grad()
            # Third argument is a training flag consumed by the model's forward
            # (True during training, False in do_test).
            class_logit = self.model(data, class_l, True)
            class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            loss = class_loss
            loss.backward()
            self.optimizer.step()
            self.logger.log(it, len(self.source_loader),
                            {"class": class_loss.item()},
                            {"class": torch.sum(cls_pred == class_l.data).item(), },
                            data.shape[0])
            # Release graph/logits early to limit peak memory.
            del loss, class_loss, class_logit
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                class_correct = self.do_test(loader)
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"class": class_acc})
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        """Return the number of correctly classified samples in ``loader``."""
        class_correct = 0
        for it, ((data, nouse, class_l), _) in enumerate(loader):
            data, nouse, class_l = data.to(self.device), nouse.to(self.device), class_l.to(self.device)
            class_logit = self.model(data, class_l, False)
            _, cls_pred = class_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
        return class_correct

    def do_training(self):
        """Run the full training schedule and return ``(logger, model)``."""
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {"val": torch.zeros(self.args.epochs), "test": torch.zeros(self.args.epochs)}
        for self.current_epoch in range(self.args.epochs):
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
            # Bug fix: the scheduler must be stepped AFTER the epoch's
            # optimizer steps (PyTorch >= 1.1 contract). Previously it was
            # called before _do_epoch, which skipped the initial learning rate
            # and triggered PyTorch's ordering warning. The other trainers in
            # this file already use this order.
            self.scheduler.step()
        val_res = self.results["val"]
        test_res = self.results["test"]
        # Model selection: epoch with the best validation accuracy.
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g, best epoch: %g" % (
            val_res.max(), test_res[idx_best], test_res.max(), idx_best))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
class Trainer:
    """JiGen-style trainer: joint classification + jigsaw-permutation
    prediction, with optional exclusion of a target domain present in the
    source list and an optional multi-permutation test-time ensemble.
    """

    def __init__(self, args, device):
        self.args = args
        self.device = device
        # jigsaw_classes = n permutations + 1 (index 0 = unscrambled image).
        model = model_factory.get_network(args.network)(
            jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(self.val_loader.dataset),
               len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model, args.epochs, args.learning_rate, args.train_all,
            nesterov=args.nesterov)
        self.jig_weight = args.jig_weight  # weight of the jigsaw loss in the total loss
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        # If the target domain is also listed as a source, remember its index
        # so its samples can be excluded from the classification loss.
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self):
        """Train one epoch (class + jigsaw losses), then evaluate all test
        loaders. Requires ``self.logger``/``self.results``/``self.current_epoch``
        set by :meth:`do_training`.
        """
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            # Remnants of a disabled DANN-style domain branch:
            # absolute_iter_count = it + self.current_epoch * self.len_dataloader
            # p = float(absolute_iter_count) / self.args.epochs / self.len_dataloader
            # lambda_val = 2. / (1. + np.exp(-10 * p)) - 1
            # if domain_error > 2.0:
            #     lambda_val = 0
            #     print("Shutting down LAMBDA to prevent implosion")
            self.optimizer.zero_grad()
            jigsaw_logit, class_logit = self.model(data)  # , lambda_val=lambda_val)
            jigsaw_loss = criterion(jigsaw_logit, jig_l)
            # domain_loss = criterion(domain_logit, d_idx)
            # domain_error = domain_loss.item()
            if self.only_non_scrambled:
                # Classification loss only on unscrambled images (jig_l == 0),
                # additionally excluding target-domain samples when present.
                if self.target_id is not None:
                    idx = (jig_l == 0) & (d_idx != self.target_id)
                    class_loss = criterion(class_logit[idx], class_l[idx])
                else:
                    class_loss = criterion(class_logit[jig_l == 0], class_l[jig_l == 0])
            elif self.target_id:
                # NOTE(review): truthiness check — a target_id of 0 falls
                # through to the else branch; confirm whether that is intended.
                class_loss = criterion(class_logit[d_idx != self.target_id],
                                       class_l[d_idx != self.target_id])
            else:
                class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            # _, domain_pred = domain_logit.max(dim=1)
            loss = class_loss + jigsaw_loss * self.jig_weight  # + 0.1 * domain_loss
            loss.backward()
            self.optimizer.step()
            self.logger.log(
                it, len(self.source_loader), {
                    "jigsaw": jigsaw_loss.item(),
                    "class": class_loss.item()  # , "domain": domain_loss.item()
                },  # ,"lambda": lambda_val},
                {
                    "jigsaw": torch.sum(jig_pred == jig_l.data).item(),
                    "class": torch.sum(cls_pred == class_l.data).item(),
                    # "domain": torch.sum(domain_pred == d_idx.data).item()
                }, data.shape[0])
            # Release graph/logits early to limit peak memory.
            del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                if loader.dataset.isMulti():
                    # Multi-permutation datasets get the ensembled evaluation.
                    jigsaw_correct, class_correct, single_acc = self.do_test_multi(loader)
                    print("Single vs multi: %g %g" %
                          (float(single_acc) / total, float(class_correct) / total))
                else:
                    jigsaw_correct, class_correct = self.do_test(loader)
                jigsaw_acc = float(jigsaw_correct) / total
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"jigsaw": jigsaw_acc, "class": class_acc})
                # Only classification accuracy drives model selection.
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        """Count correct class and jigsaw predictions over ``loader``.

        Returns ``(jigsaw_correct, class_correct)`` as scalar tensors.
        """
        jigsaw_correct = 0
        class_correct = 0
        domain_correct = 0  # NOTE(review): never updated or returned; leftover
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            jigsaw_logit, class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data)
        return jigsaw_correct, class_correct

    def do_test_multi(self, loader):
        """Ensembled evaluation over multiple permutations of each image.

        ``data`` is indexed as [batch, permutation, ...]; permutation 0 is the
        original image (its softmax is up-weighted by ``4 * n_permutations``
        before averaging). Returns ``(jigsaw_correct, class_correct,
        single_correct)`` where ``single_correct`` uses only permutation 0.
        """
        jigsaw_correct = 0
        class_correct = 0
        single_correct = 0
        for it, ((data, jig_l, class_l), d_idx) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            n_permutations = data.shape[1]
            class_logits = torch.zeros(n_permutations, data.shape[0],
                                       self.n_classes).to(self.device)
            for k in range(n_permutations):
                # [1] selects the class head's output of the model.
                class_logits[k] = F.softmax(self.model(data[:, k])[1], dim=1)
            class_logits[0] *= 4 * n_permutations  # bias more the original image
            class_logit = class_logits.mean(0)
            _, cls_pred = class_logit.max(dim=1)
            # Single-image baseline: predictions on the unpermuted image only.
            jigsaw_logit, single_logit = self.model(data[:, 0])
            _, jig_pred = jigsaw_logit.max(dim=1)
            _, single_logit = single_logit.max(dim=1)
            single_correct += torch.sum(single_logit == class_l.data)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data[:, 0])
        return jigsaw_correct, class_correct, single_correct

    def do_training(self):
        """Run the full training schedule and return ``(logger, model)``.

        NOTE(review): ``scheduler.step()`` is called before ``_do_epoch`` here
        (pre-1.1 PyTorch ordering) — confirm against the installed version.
        """
        self.logger = Logger(self.args, update_frequency=30)  # , "domain", "lambda"
        self.results = {"val": torch.zeros(self.args.epochs), "test": torch.zeros(self.args.epochs)}
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        #print("Best val %g, corresponding test %g - best test: %g" % (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
class Trainer:
    """JiGen trainer with an unlabeled target-domain stream: in addition to
    the source class + jigsaw losses, each iteration draws a target batch and
    adds a target jigsaw loss and an entropy loss on the target class
    predictions (unsupervised adaptation).
    """

    def __init__(self, args, device):
        self.args = args
        self.device = device
        # jigsaw_classes = n permutations + 1 (index 0 = unscrambled image).
        model = model_factory.get_network(args.network)(
            jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        # The target domain gets its own jigsaw loader, so it is removed from
        # the source list if present. NOTE: this mutates args.source in place.
        if args.target in args.source:
            print("No need to include target in source, it is automatically done by this script")
            k = args.source.index(args.target)
            args.source = args.source[:k] + args.source[k + 1:]
            print("Source: %s" % args.source)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_jig_loader = data_helper.get_target_jigsaw_loader(args)
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, target jig: %d, val %d, test %d" %
              (len(self.source_loader.dataset),
               len(self.target_jig_loader.dataset),
               len(self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model, args.epochs, args.learning_rate, args.train_all,
            nesterov=args.nesterov)
        self.jig_weight = args.jig_weight          # weight of the source jigsaw loss
        self.target_weight = args.target_weight    # weight of the target jigsaw loss
        self.target_entropy = args.entropy_weight  # weight of the target entropy loss
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes

    def _do_epoch(self):
        """Train one epoch over paired (source, target) batches, then evaluate.

        The shorter source loader drives the epoch length; the target jigsaw
        loader is cycled endlessly with itertools.cycle.
        """
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, (source_batch, target_batch) in enumerate(
                zip(self.source_loader, itertools.cycle(self.target_jig_loader))):
            (data, jig_l, class_l), d_idx = source_batch
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            tdata, tjig_l, _ = target_batch
            tdata, tjig_l = tdata.to(self.device), tjig_l.to(self.device)
            self.optimizer.zero_grad()
            # Source losses: supervised jigsaw (+ class below).
            jigsaw_logit, class_logit = self.model(data)
            jigsaw_loss = criterion(jigsaw_logit, jig_l)
            # Target losses: jigsaw on the target batch, plus entropy
            # minimization on class predictions of unscrambled target images.
            target_jigsaw_logit, target_class_logit = self.model(tdata)
            target_jigsaw_loss = criterion(target_jigsaw_logit, tjig_l)
            target_entropy_loss = entropy_loss(target_class_logit[tjig_l == 0])
            if self.only_non_scrambled:
                # Classification loss only on unscrambled source images.
                class_loss = criterion(class_logit[jig_l == 0], class_l[jig_l == 0])
            else:
                class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            loss = class_loss + jigsaw_loss * self.jig_weight + target_jigsaw_loss * self.target_weight + target_entropy_loss * self.target_entropy
            loss.backward()
            self.optimizer.step()
            self.logger.log(
                it, len(self.source_loader), {
                    "jigsaw": jigsaw_loss.item(),
                    "class": class_loss.item(),
                    "t_jigsaw": target_jigsaw_loss.item(),
                    "entropy": target_entropy_loss.item()
                }, {
                    "jigsaw": torch.sum(jig_pred == jig_l.data).item(),
                    "class": torch.sum(cls_pred == class_l.data).item(),
                }, data.shape[0])
            # Release graph/logits early to limit peak memory.
            del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit, target_jigsaw_logit, target_jigsaw_loss
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                if loader.dataset.isMulti():
                    # NOTE(review): this class defines no do_test_multi — this
                    # branch would raise AttributeError; presumably isMulti()
                    # is always False in this configuration. TODO confirm.
                    jigsaw_correct, class_correct, single_acc = self.do_test_multi(loader)
                    print("Single vs multi: %g %g" %
                          (float(single_acc) / total, float(class_correct) / total))
                else:
                    jigsaw_correct, class_correct = self.do_test(loader)
                jigsaw_acc = float(jigsaw_correct) / total
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"jigsaw": jigsaw_acc, "class": class_acc})
                # Only classification accuracy drives model selection.
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        """Count correct class and jigsaw predictions over ``loader``.

        Returns ``(jigsaw_correct, class_correct)`` as scalar tensors.
        """
        jigsaw_correct = 0
        class_correct = 0
        domain_correct = 0  # NOTE(review): never updated or returned; leftover
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            jigsaw_logit, class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data)
        return jigsaw_correct, class_correct

    def do_training(self):
        """Run the full training schedule and return ``(logger, model)``.

        NOTE(review): ``scheduler.step()`` is called before ``_do_epoch`` here
        (pre-1.1 PyTorch ordering) — confirm against the installed version.
        """
        self.logger = Logger(self.args, update_frequency=30)  # , "domain", "lambda"
        self.results = {"val": torch.zeros(self.args.epochs), "test": torch.zeros(self.args.epochs)}
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g" %
              (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
class Trainer:
    """RSC-style supervised trainer with checkpoint resume, AUC-based
    evaluation, and top-3 / last-3-epoch checkpoint saving.

    Trains a resnet (18 or 50) classifier with horizontal-flip batch
    augmentation, evaluates accuracy plus AUC/FPR metrics through AUCMeter,
    and keeps the three best-AUC checkpoints under ``_save_models_dir``.
    """

    def __init__(self, args, device):
        self.args = args
        self.device = device
        if args.network == 'resnet18':
            model = resnet18(pretrained=self.args.pretrained, classes=args.n_classes)
        elif args.network == 'resnet50':
            model = resnet50(pretrained=self.args.pretrained, classes=args.n_classes)
        else:
            # Fallback: any unrecognized network name gets resnet18.
            model = resnet18(pretrained=self.args.pretrained, classes=args.n_classes)
        self.model = model.to(device)
        # Optionally resume model weights and starting epoch from a checkpoint.
        if args.resume:
            if isfile(args.resume):
                print(f"=> loading checkpoint '{args.resume}'")
                checkpoint = torch.load(args.resume)
                self.args.start_epoch = checkpoint['epoch']
                self.model.load_state_dict(checkpoint['model'])
                print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
            else:
                raise ValueError(f"Failed to find checkpoint {args.resume}")
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        # self.target_loader = data_helper.get_val_dataloader(args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_tgt_dataloader(
            self.args, patches=model.is_patch_based())
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(self.val_loader.dataset),
               len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model, args.epochs, args.learning_rate, args.train_all,
            nesterov=args.nesterov)
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None
        # Running best-3 AUC values, aligned with the best{1,2,3}.pth files.
        self.topk = [0 for _ in range(3)]

    def _do_epoch(self, epoch=None):
        """Train one epoch with flip augmentation, then evaluate all loaders.

        epoch -- current epoch index, forwarded to the model (RSC schedule)
        and to save_model. Requires ``self.logger``/``self.results``/
        ``self.current_epoch`` set by :meth:`do_training`.
        """
        if self.args.loss == 'ce':
            criterion = nn.CrossEntropyLoss()
        elif self.args.loss == 'fl':
            criterion = FocalLoss(class_num=self.args.n_classes)
        self.model.train()
        for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            self.optimizer.zero_grad()
            # Double the batch with a horizontal flip (dim 3 = width).
            data_flip = torch.flip(data, (3, )).detach().clone()
            data = torch.cat((data, data_flip))
            # Labels are clamped to [0, 9] — presumably a 10-class setting;
            # TODO confirm this matches args.n_classes.
            class_l = torch.clamp(class_l, 0, 9)
            class_l = torch.cat((class_l, class_l))
            class_logit = self.model(data, class_l, self.args.RSC_flag, epoch)
            class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            loss = class_loss
            loss.backward()
            self.optimizer.step()
            self.logger.log(it, len(self.source_loader),
                            {"loss": class_loss.item()},
                            {"class": torch.sum(cls_pred == class_l.data).item(), },
                            data.shape[0])
            # Release graph/logits early to limit peak memory.
            del loss, class_loss, class_logit
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                class_correct, auc_dict = self.do_test(loader)
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"class_acc": class_acc})
                self.logger.log_test(phase, {"auc": auc_dict['auc']})
                self.logger.log_test(phase, {"fpr_980": auc_dict['fpr_980']})
                self.logger.log_test(phase, {"fpr_991": auc_dict['fpr_991']})
                self.results[phase][self.current_epoch] = class_acc
                # save best&latest model params (selection on the val phase)
                if phase == 'val':
                    self.save_model(epoch, auc_dict)
                del auc_dict

    def do_test(self, loader):
        """Evaluate ``loader``: return (correct-count tensor, AUC metrics dict).

        Accumulates softmax scores in an AUCMeter and reads back AUC plus FPR
        at several TPR operating points.
        """
        class_correct = 0
        auc_meter = AUCMeter()
        for it, ((data, nouse, class_l), _) in enumerate(loader):
            data, nouse, class_l = data.to(self.device), nouse.to(
                self.device), class_l.to(self.device)
            class_logit = self.model(data, class_l, False)
            _, cls_pred = class_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            cls_score = F.softmax(class_logit, dim=1)
            auc_meter.update(class_l.cpu(), cls_score.cpu())
        auc, fpr_980, fpr_991, fpr_993, fpr_995, fpr_997, fpr_999, fpr_1, thresholds = auc_meter.calculate()
        auc_dict = {
            'auc': auc,
            'fpr_980': fpr_980,
            'fpr_991': fpr_991,
            'fpr_993': fpr_993,
            'fpr_995': fpr_995,
            'fpr_997': fpr_997,
            'fpr_999': fpr_999,
            'fpr_1': fpr_1,
            'thresholds': thresholds
        }
        return class_correct, auc_dict

    def do_training(self):
        """Run training from args.start_epoch to args.epochs; return (logger, model)."""
        self.logger = Logger(self.args, update_frequency=50)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.start_epoch, self.args.epochs):
            self._do_epoch(self.current_epoch)
            self.scheduler.step()
            # NOTE(review): new_epoch is called after the epoch ran — the
            # first trained epoch is logged before any new_epoch call; confirm
            # the Logger tolerates this ordering.
            self.logger.new_epoch(self.scheduler.get_last_lr())
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g, best epoch: %g"
              % (val_res.max(), test_res[idx_best], test_res.max(), idx_best))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model

    def save_model(self, epoch, auc_dict):
        """Persist checkpoints: maintain the top-3 AUC models and always save
        the last three epochs' models.

        epoch    -- epoch index used in filenames and the saved state.
        auc_dict -- metrics dict from :meth:`do_test`; 'auc' ranks the model.
        """
        if not exists(_save_models_dir):
            os.mkdir(_save_models_dir)
        tmp_auc, tmp_fpr_980 = auc_dict['auc'], auc_dict['fpr_980']
        # Bug fix: build the checkpoint payload up front. It was previously
        # created only inside the top-k branch, so the "last 3 epochs" save
        # below raised NameError whenever the epoch did not enter the top 3.
        state_to_save = {
            'model': self.model.state_dict(),
            'auc_dict': auc_dict,
            'epoch': epoch
        }
        # Insert into the top-3 list: shift lower ranks (and their files) down
        # one slot, then write the new checkpoint at rank i+1.
        for i, rec in enumerate(self.topk):
            if tmp_auc > rec:
                for j in range(len(self.topk) - 1, i, -1):
                    self.topk[j] = self.topk[j - 1]
                    _j, _jm1 = join(_save_models_dir, f"tgt_{self.args.target}_src_{'-'.join(self.args.source)}_RSC_{self.args.RSC_flag}_best{j+1}.pth"),\
                        join(_save_models_dir, f"tgt_{self.args.target}_src_{'-'.join(self.args.source)}_RSC_{self.args.RSC_flag}_best{j}.pth")
                    if exists(_jm1):
                        os.rename(_jm1, _j)
                self.topk[i] = tmp_auc
                model_saved_path = join(
                    _save_models_dir,
                    f"tgt_{self.args.target}_src_{'-'.join(self.args.source)}_RSC_{self.args.RSC_flag}_best{i+1}.pth"
                )
                torch.save(state_to_save, model_saved_path)
                print(f'=>Best{i+1} model updated and saved in path {model_saved_path}')
                break
        # Always keep the checkpoints of the final three epochs.
        if epoch in range(self.args.epochs - 3, self.args.epochs):
            model_saved_path = join(
                _save_models_dir,
                f"tgt_{self.args.target}_src_{'-'.join(self.args.source)}_RSC_{self.args.RSC_flag}_epochs{epoch}.pth"
            )
            torch.save(state_to_save, model_saved_path)
            print(f'=>Last{self.args.epochs - epoch} model updated and saved in path {model_saved_path}')
class Trainer:
    """Jigsaw-puzzle auxiliary-task trainer (JiGen-style).

    Trains a network with a joint objective: object classification plus a
    jigsaw-permutation classification head, weighted by ``args.jig_weight``.
    """

    def __init__(self, args, device):
        # args: parsed command-line namespace; device: torch device for model/data.
        self.args = args
        self.device = device
        # Network exposes two heads: jigsaw-permutation (+1 for "not scrambled")
        # and object classes.
        model = model_factory.get_network(args.network)(
            jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model,
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.jig_weight = args.jig_weight
        # If set, the classification loss is computed only on unscrambled images.
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        # When the target domain is also listed among the sources, remember its
        # index so its samples can be excluded from the classification loss.
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None
        self.best_val_jigsaw = 0.0   # NOTE(review): tracked nowhere else in this block
        self.best_class_acc = 0.0    # best validation accuracy seen so far
        _, logname = Logger.get_name_from_args(args)
        self.folder_name = "%s/%s_to_%s/%s" % (args.folder_name, "-".join(
            sorted(args.source)), args.target, logname)

    def _do_epoch(self):
        """Run one training epoch, then evaluate on the val/test loaders and
        save best/latest checkpoints."""
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        epoch_loss = 0
        pbar = pkbar.Pbar(name='Epoch Progress',
                          target=len(self.source_loader))
        for it, ((data, jig_l, class_l),
                 d_idx) in enumerate(self.source_loader):
            pbar.update(it)
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            # absolute_iter_count = it + self.current_epoch * self.len_dataloader
            # p = float(absolute_iter_count) / self.args.epochs / self.len_dataloader
            # lambda_val = 2. / (1. + np.exp(-10 * p)) - 1
            # if domain_error > 2.0:
            #     lambda_val = 0
            #     print("Shutting down LAMBDA to prevent implosion")
            self.optimizer.zero_grad()
            jigsaw_logit, class_logit = self.model(
                data)  # , lambda_val=lambda_val)
            jigsaw_loss = criterion(jigsaw_logit, jig_l)
            # domain_loss = criterion(domain_logit, d_idx)
            # domain_error = domain_loss.item()
            if self.only_non_scrambled:
                if self.target_id is not None:
                    # Use only unscrambled images NOT belonging to the target domain.
                    idx = (jig_l == 0) & (d_idx != self.target_id)
                    class_loss = criterion(class_logit[idx], class_l[idx])
                else:
                    # jig_l == 0 marks the original (unscrambled) images.
                    class_loss = criterion(class_logit[jig_l == 0],
                                           class_l[jig_l == 0])
            elif self.target_id:
                # All images, but exclude the target domain from classification.
                class_loss = criterion(class_logit[d_idx != self.target_id],
                                       class_l[d_idx != self.target_id])
            else:
                class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            if self.args.deep_all:
                # DeepAll baseline: drop the jigsaw term entirely.
                jigsaw_loss = torch.Tensor([0.0])
                loss = class_loss
            else:
                loss = class_loss + jigsaw_loss * self.jig_weight  # + 0.1 * domain_loss
            # _, domain_pred = domain_logit.max(dim=1)
            epoch_loss = epoch_loss + loss
            loss.backward()
            self.optimizer.step()
            self.logger.log(
                it,
                len(self.source_loader),
                {
                    "jigsaw": jigsaw_loss.item(),
                    "class": class_loss.item()
                    # , "domain": domain_loss.item()
                },  # ,"lambda": lambda_val},
                {
                    "jigsaw": torch.sum(jig_pred == jig_l.data).item(),
                    "class": torch.sum(cls_pred == class_l.data).item(),
                    # "domain": torch.sum(domain_pred == d_idx.data).item()
                },
                data.shape[0])
            # Drop references so tensors can be freed before the next batch.
            del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                if loader.dataset.isMulti():
                    jigsaw_correct, class_correct, single_acc = self.do_test_multi(
                        loader)
                    print("Single vs multi: %g %g" %
                          (float(single_acc) / total,
                           float(class_correct) / total))
                else:
                    jigsaw_correct, class_correct = self.do_test(loader)
                jigsaw_acc = float(jigsaw_correct) / total
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {
                    "jigsaw": jigsaw_acc,
                    "class": class_acc
                })
                self.results[phase][self.current_epoch] = class_acc
        # Checkpointing: save "best" when validation accuracy improves,
        # and always refresh the "latest" snapshot.
        # NOTE(review): reconstructed from mangled source — placement after the
        # evaluation loop (once per epoch) is the presumed intent; confirm.
        if (self.results['val'][self.current_epoch] > self.best_class_acc):
            self.best_class_acc = self.results['val'][self.current_epoch]
            print("Saving new best at epoch: {}".format(self.current_epoch))
            self.save_model(
                os.path.join("logs", self.folder_name, 'best_model.pth'))
        print("Saving latest at epoch: {}".format(self.current_epoch))
        self.save_model(
            os.path.join("logs", self.folder_name, 'latest_model.pth'))

    def save_model(self, file_path):
        """Serialize model/optimizer state plus current val/test accuracy."""
        torch.save(
            {
                'epoch': self.current_epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'best_val_acc': self.results['val'][self.current_epoch],
                'test_acc': self.results['test'][self.current_epoch]
            }, file_path)

    def do_test(self, loader):
        """Count correct jigsaw and class predictions over `loader`."""
        jigsaw_correct = 0
        class_correct = 0
        domain_correct = 0
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            jigsaw_logit, class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data)
        return jigsaw_correct, class_correct

    def do_test_multi(self, loader):
        """Test with multiple permutations per image: average softmax over the
        permutations (original image up-weighted) vs. single-image prediction."""
        jigsaw_correct = 0
        class_correct = 0
        single_correct = 0
        for it, ((data, jig_l, class_l), d_idx) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            # data is assumed (batch, n_permutations, C, H, W) here — the
            # second axis indexes the permuted copies of each image.
            n_permutations = data.shape[1]
            class_logits = torch.zeros(n_permutations, data.shape[0],
                                       self.n_classes).to(self.device)
            for k in range(n_permutations):
                class_logits[k] = F.softmax(self.model(data[:, k])[1], dim=1)
            class_logits[
                0] *= 4 * n_permutations  # bias more the original image
            class_logit = class_logits.mean(0)
            _, cls_pred = class_logit.max(dim=1)
            jigsaw_logit, single_logit = self.model(data[:, 0])
            _, jig_pred = jigsaw_logit.max(dim=1)
            _, single_logit = single_logit.max(dim=1)
            single_correct += torch.sum(single_logit == class_l.data)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data[:, 0])
        return jigsaw_correct, class_correct, single_correct

    def do_training(self):
        """Full training loop; returns (logger, model) and writes args/results
        to the log folder."""
        self.logger = Logger(self.args,
                             update_frequency=30)  # , "domain", "lambda"
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.epochs):
            start_time = time.time()
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
            end_time = time.time()
            print(f"Runtime of the epoch is {end_time - start_time}")
        val_res = self.results["val"]
        test_res = self.results["test"]
        # Model selection: report the test accuracy at the best-validation epoch.
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g" %
              (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        # Save Arguments
        with open(osp.join('logs', self.folder_name, 'args.txt'), 'w') as f:
            json.dump(self.args.__dict__, f, indent=2)
        # Save results
        with open(osp.join('logs', self.folder_name, 'results.txt'),
                  'w') as f:
            f.write("Best val %g, corresponding test %g - best test: %g" %
                    (val_res.max(), test_res[idx_best], test_res.max()))
        return self.logger, self.model
class Trainer:
    """Jigsaw-auxiliary trainer that evaluates on multiple held-out target
    domains (one test loader per unseen domain) and reports their mean accuracy.
    """

    def __init__(self, args, device):
        self.args = args
        self.device = device
        # Two-headed network: jigsaw permutations (+1 for "unscrambled") and classes.
        model = model_factory.get_network(args.network)(
            jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
        self.model = model.to(device)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        # One test loader per unseen target domain.
        self.target_test_loaders = data_helper.get_jigsaw_test_dataloaders(
            args, patches=model.is_patch_based())
        # Evaluate on Validation & Test datasets
        self.evaluation_loaders = {
            "val": self.val_loader,
            "test": self.target_test_loaders
        }
        print("Dataset size: train %d, val %d" %
              (len(self.source_loader.dataset), len(self.val_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model,
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.jig_weight = args.jig_weight
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self):
        """One training epoch followed by per-domain evaluation."""
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l),
                 d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            self.optimizer.zero_grad()
            jigsaw_logit, class_logit = self.model(data)
            jigsaw_loss = criterion(jigsaw_logit, jig_l)
            if self.only_non_scrambled:
                # Classify species only on intact (unscrambled) images.
                if self.target_id is not None:
                    # Image is unscrambled AND its domain is not the target domain
                    # (the target domain is never trained on, only predicted).
                    idx = (jig_l == 0) & (d_idx != self.target_id)
                    class_loss = criterion(class_logit[idx], class_l[idx])
                else:
                    class_loss = criterion(class_logit[jig_l == 0],
                                           class_l[jig_l == 0])
            elif self.target_id:
                # Classify all images (scrambled included); exclude target domain.
                class_loss = criterion(class_logit[d_idx != self.target_id],
                                       class_l[d_idx != self.target_id])
            else:
                # Classify all images (scrambled included).
                class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            loss = class_loss + jigsaw_loss * self.jig_weight  # + 0.1 * domain_loss
            loss.backward()
            self.optimizer.step()
            self.logger.log(
                it, len(self.source_loader), {
                    "jigsaw": jigsaw_loss.item(),
                    "class": class_loss.item()
                }, {
                    "jigsaw": torch.sum(jig_pred == jig_l.data).item(),
                    "class": torch.sum(cls_pred == class_l.data).item(),
                }, data.shape[0])
            # Release tensor references before the next batch.
            del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.evaluation_loaders.items():
                if phase == 'test':
                    # Evaluate each unseen target domain separately, then average.
                    belonged_dataset = data_helper.get_belonged_dataset(
                        self.args.source[0])
                    target_domains = [
                        item for item in belonged_dataset
                        if item not in self.args.source
                    ]
                    acc_sum = 0.0
                    for didx in range(len(loader)):
                        dkey = phase + '-' + target_domains[didx]
                        test_loader = loader[didx]
                        test_total = len(test_loader.dataset)
                        jigsaw_correct, class_correct = self.do_test(
                            test_loader)
                        # FIX: was `float(jigsaw_correct) / total` — `total` is
                        # undefined (or stale from the val phase) in this branch.
                        jigsaw_acc = float(jigsaw_correct) / test_total
                        class_acc = float(class_correct) / test_total
                        self.logger.log_test(dkey, {"class": class_acc})
                        if dkey not in self.results.keys():
                            self.results[dkey] = torch.zeros(self.args.epochs)
                        self.results[dkey][self.current_epoch] = class_acc
                        acc_sum += class_acc
                    self.logger.log_test(phase,
                                         {"class": acc_sum / len(loader)})
                    self.results[phase][
                        self.current_epoch] = acc_sum / len(loader)
                else:
                    total = len(loader.dataset)
                    if loader.dataset.isMulti():
                        jigsaw_correct, class_correct, single_acc = self.do_test_multi(
                            loader)
                        print("Single vs multi: %g %g" %
                              (float(single_acc) / total,
                               float(class_correct) / total))
                    else:
                        jigsaw_correct, class_correct = self.do_test(loader)
                    jigsaw_acc = float(jigsaw_correct) / total
                    class_acc = float(class_correct) / total
                    self.logger.log_test(phase, {
                        "jigsaw": jigsaw_acc,
                        "class": class_acc
                    })
                    self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        """Count correct jigsaw and class predictions over `loader`."""
        jigsaw_correct = 0
        class_correct = 0
        domain_correct = 0
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            jigsaw_logit, class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data)
        return jigsaw_correct, class_correct

    def do_test_multi(self, loader):
        """Multi-permutation test: average softmax over permutations (original
        image up-weighted) vs. single-image prediction."""
        jigsaw_correct = 0
        class_correct = 0
        single_correct = 0
        for it, ((data, jig_l, class_l), d_idx) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            # data assumed (batch, n_permutations, C, H, W) — TODO confirm.
            n_permutations = data.shape[1]
            class_logits = torch.zeros(n_permutations, data.shape[0],
                                       self.n_classes).to(self.device)
            for k in range(n_permutations):
                class_logits[k] = F.softmax(self.model(data[:, k])[1], dim=1)
            class_logits[
                0] *= 4 * n_permutations  # bias more the original image
            class_logit = class_logits.mean(0)
            _, cls_pred = class_logit.max(dim=1)
            jigsaw_logit, single_logit = self.model(data[:, 0])
            _, jig_pred = jigsaw_logit.max(dim=1)
            _, single_logit = single_logit.max(dim=1)
            single_correct += torch.sum(single_logit == class_l.data)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data[:, 0])
        return jigsaw_correct, class_correct, single_correct

    def do_training(self):
        """Full training loop; returns (logger, model)."""
        self.logger = Logger(self.args,
                             update_frequency=30)  # , "domain", "lambda"
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_test_best = test_res.argmax()
        print("Best test acc: %g in epoch: %d" %
              (test_res.max(), idx_test_best + 1))
        self.logger.save_best(test_res[idx_test_best].item(),
                              test_res.max().item())
        return self.logger, self.model
class Trainer:
    """Trainer combining cross-entropy classification with an intra-class
    Deep InfoMax (DIM) objective plus an adversarial prior discriminator.
    """

    def __init__(self, args, device):
        self.args = args
        self.device = device
        # Backbone selection; default to resnet18 for unknown names.
        if args.network == 'resnet18':
            model = resnet18(pretrained=True, classes=args.n_classes)
        elif args.network == 'resnet50':
            model = resnet50(pretrained=True, classes=args.n_classes)
        else:
            model = resnet18(pretrained=True, classes=args.n_classes)
        self.model = model.to(device)
        # Mutual-information estimator / discriminators.
        self.D_model = IntraClsInfoMax(alpha=args.alpha,
                                       beta=args.beta,
                                       gamma=args.gamma).to(device)
        # print(self.model)
        # print(self.D_model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))
        # Main optimizer covers the backbone and the global/local discriminators;
        # a separate optimizer trains the prior discriminator adversarially.
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            [self.model, self.D_model.global_d, self.D_model.local_d],
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.dis_optimizer, self.dis_scheduler = get_optim_and_scheduler(
            [self.D_model.prior_d],
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)  # args.learning_ratee*1e-3
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None
        self.max_test_acc = 0.0
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }

    def _do_epoch(self, device='cuda'):
        """One training epoch (classification + DIM + prior losses), then
        evaluation; saves the model whenever test accuracy reaches a new best."""
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        self.D_model.train()
        for it, ((data, jig_l, class_l),
                 d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            self.optimizer.zero_grad()
            # Augment the batch with horizontal flips (flip along width axis).
            data_flip = torch.flip(data, (3, )).detach().clone()
            data = torch.cat((data, data_flip))
            class_l = torch.cat((class_l, class_l))
            y, M = self.model(data, feature_flag=True)
            # Classification Loss
            class_logit = self.model.class_classifier(y)
            class_loss = criterion(class_logit, class_l)
            # G loss - DIM Loss - P_loss
            M_prime = torch.cat(
                (M[1:], M[0].unsqueeze(0)),
                dim=0)  # Move feature to front position one by one
            class_prime = torch.cat((class_l[1:], class_l[0].unsqueeze(0)),
                                    dim=0)
            class_ll = (class_l, class_prime)
            DIM_loss = self.D_model(y, M, M_prime, class_ll)
            P_loss = self.D_model.prior_loss(y)
            # Generator side: reward fooling the prior discriminator.
            DIM_loss = DIM_loss - P_loss
            # DIM_loss=self.beta*(DIM_loss-P_loss)
            loss = class_loss + DIM_loss
            loss.backward()
            self.optimizer.step()
            # Discriminator side: train prior_d on detached features.
            self.dis_optimizer.zero_grad()
            P_loss = self.D_model.prior_loss(y.detach())
            P_loss.backward()
            self.dis_optimizer.step()
            # Prediction
            _, cls_pred = class_logit.max(dim=1)
            losses = {
                'class': class_loss.detach().item(),
                'DIM': DIM_loss.detach().item(),
                'P_loss': P_loss.detach().item()
            }
            self.logger.log(
                it, len(self.source_loader), losses, {
                    "class": torch.sum(cls_pred == class_l.data).item(),
                }, data.shape[0])
            del loss, class_loss, class_logit, DIM_loss
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                class_correct = self.do_test(loader)
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"class": class_acc})
                self.results[phase][self.current_epoch] = class_acc
                if phase == 'test' and class_acc > self.max_test_acc:
                    # FIX: record the new best — previously max_test_acc was
                    # never updated, so the checkpoint was overwritten every
                    # epoch regardless of whether accuracy improved.
                    self.max_test_acc = class_acc
                    torch.save(
                        self.model.state_dict(),
                        os.path.join(self.logger.log_path,
                                     'best_{}.pth'.format(phase)))

    def do_test(self, loader):
        """Count correct class predictions over `loader`."""
        class_correct = 0
        for it, ((data, nouse, class_l), _) in enumerate(loader):
            data, nouse, class_l = data.to(self.device), nouse.to(
                self.device), class_l.to(self.device)
            class_logit = self.model(data, feature_flag=False)
            _, cls_pred = class_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
        return class_correct

    def do_training(self):
        """Full training loop; returns (logger, model)."""
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.dis_scheduler.step()
            self.logger.new_epoch(
                [*self.scheduler.get_lr(), *self.dis_scheduler.get_lr()])
            self._do_epoch()  # use self.current_epoch
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print(
            "Best val %g, corresponding test %g - best test: %g, best epoch: %g"
            % (val_res.max(), test_res[idx_best], test_res.max(), idx_best))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
class Trainer:
    """Plain supervised trainer: cross-entropy object classification only,
    evaluated each epoch on the validation and target-test loaders."""

    def __init__(self, args, device):
        self.args = args
        self.device = device
        network = model_factory.get_network(args.network)(
            classes=args.n_classes)
        self.model = network.to(device)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args)
        self.target_loader = data_helper.get_val_dataloader(args)
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        sizes = (len(self.source_loader.dataset),
                 len(self.val_loader.dataset),
                 len(self.target_loader.dataset))
        print("Dataset size: train %d, val %d, test %d" % sizes)
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            network, args.epochs, args.learning_rate, args.train_all)
        self.n_classes = args.n_classes

    def _do_epoch(self):
        """Run one training pass over the source loader, then evaluate."""
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for batch_idx, (inputs, labels) in enumerate(self.source_loader):
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            self.optimizer.zero_grad()
            logits = self.model(inputs)
            batch_loss = criterion(logits, labels)
            _, predictions = logits.max(dim=1)
            batch_loss.backward()
            self.optimizer.step()
            self.logger.log(
                batch_idx, len(self.source_loader),
                {"Class Loss ": batch_loss.item()},
                {"Class Accuracy ": torch.sum(
                    predictions == labels.data).item()},
                inputs.shape[0])
            # Drop tensor references so memory is reclaimed promptly.
            del batch_loss, logits
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                n_samples = len(loader.dataset)
                n_correct = self.do_test(loader)
                accuracy = float(n_correct) / n_samples
                self.logger.log_test(phase,
                                     {"Classification Accuracy": accuracy})
                self.results[phase][self.current_epoch] = accuracy

    def do_test(self, loader):
        """Return the number of correctly classified samples in `loader`."""
        n_correct = 0
        for inputs, labels in loader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            _, predictions = self.model(inputs).max(dim=1)
            n_correct += torch.sum(predictions == labels.data)
        return n_correct

    def do_training(self):
        """Train for args.epochs epochs; returns (logger, model)."""
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.epochs):
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
            self.scheduler.step()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g" %
              (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
class Trainer:
    """Multi-task trainer: object classification plus jigsaw, and optionally
    rotation and odd-one-out self-supervised heads.

    task_type encoding (per sample): 0 = ordinary image (classification +
    jigsaw), 1 = jigsaw task, 2 = rotation task, 3 = odd-one-out task.
    """

    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = model_factory.get_network(args.network)(
            classes=args.n_classes,
            jigsaw_classes=31,
            rotation_classes=4,
            odd_classes=9)
        self.model = model.to(device)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args)
        self.target_loader = data_helper.get_val_dataloader(args)
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model, args.epochs, args.learning_rate, args.train_all)
        self.n_classes = args.n_classes
        # Base tasks: classification + jigsaw; each enabled flag adds one more.
        self.nTasks = 2
        if args.rotation == True:
            self.nTasks += 1
        if args.odd_one_out == True:
            self.nTasks += 1
        print("N of tasks: " + str(self.nTasks))

    def _do_epoch(self):
        """One training epoch over all enabled tasks, then evaluation."""
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, (data, class_l, jigsaw_label,
                 task_type) in enumerate(self.source_loader):
            rotation_loss = 0
            odd_one_out_loss = 0
            rotation_pred = 0
            odd_pred = 0
            data, class_l, jigsaw_label, task_type = data.to(
                self.device), class_l.to(self.device), jigsaw_label.to(
                    self.device), task_type.to(self.device)
            self.optimizer.zero_grad()
            class_logit, jigsaw_logit, rotation_logit, odd_logit = self.model(
                data)
            # Classification loss only on ordinary (task 0) images.
            class_loss = criterion(class_logit[task_type == 0],
                                   class_l[task_type == 0])
            jigsaw_loss = criterion(
                jigsaw_logit[(task_type == 0) | (task_type == 1)],
                jigsaw_label[(task_type == 0) | (task_type == 1)])
            _, cls_pred = class_logit.max(dim=1)
            _, jigsaw_pred = jigsaw_logit.max(dim=1)
            if self.args.rotation == True:
                # Rotation loss if the task is classification of "rotation".
                rotation_loss = criterion(
                    rotation_logit[(task_type == 0) | (task_type == 2)],
                    jigsaw_label[(task_type == 0) | (task_type == 2)])
                _, rotation_pred = rotation_logit.max(dim=1)
            if self.args.odd_one_out == True:
                # Odd-one-out loss if the task is "odd one out".
                odd_one_out_loss = criterion(
                    odd_logit[(task_type == 0) | (task_type == 3)],
                    jigsaw_label[(task_type == 0) | (task_type == 3)])
                _, odd_pred = odd_logit.max(dim=1)
            jig_loss = jigsaw_loss * self.args.jigsaw_alpha
            rot_loss = rotation_loss * self.args.beta_rotated
            odd_loss = odd_one_out_loss * self.args.beta_odd
            # FIX: odd_loss was previously added twice, silently doubling
            # the effective beta_odd weight.
            loss = class_loss + jig_loss + rot_loss + odd_loss
            loss.backward()
            self.optimizer.step()
            self.logger.log(it, len(self.source_loader),
                            {"Class Loss ": class_loss.item()}, {
                                "Class Accuracy ":
                                torch.sum(cls_pred == class_l.data).item()
                            }, data.shape[0])
            self.logger.log(
                it, len(self.source_loader),
                {"Jigsaw Loss ": jigsaw_loss.item()}, {
                    "Jigsaw Accuracy ":
                    torch.sum(jigsaw_pred == jigsaw_label.data).item()
                }, data.shape[0])
            if self.args.rotation == True:
                self.logger.log(
                    it, len(self.source_loader),
                    {"Rotation Loss ": rotation_loss.item()}, {
                        "Rotation Accuracy ":
                        torch.sum(rotation_pred == jigsaw_label.data).item()
                    }, data.shape[0])
            if self.args.odd_one_out == True:
                self.logger.log(
                    it, len(self.source_loader),
                    {"Odd one out Loss ": odd_loss.item()}, {
                        "Odd one out Accuracy ":
                        torch.sum(odd_pred == jigsaw_label.data).item()
                    }, data.shape[0])
            # Drop references so tensors can be freed before the next batch.
            del loss, class_loss, class_logit, jigsaw_loss, jigsaw_logit
            del rotation_loss, odd_one_out_loss, jig_loss, rot_loss, odd_loss, rotation_pred, odd_pred
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                class_correct, jigsaw_correct, rotation_correct, odd_correct = self.do_test(
                    loader)
                class_acc = float(class_correct) / total
                jigsaw_acc = float(jigsaw_correct) / total
                rotation_acc = 0
                odd_acc = 0
                if self.args.rotation == True:
                    rotation_acc = float(rotation_correct) / total
                if self.args.odd_one_out == True:
                    odd_acc = float(odd_correct) / total
                self.logger.log_test(
                    phase, {
                        "Classification Accuracy": class_acc,
                        "Jigsaw Accuracy": jigsaw_acc
                    })
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        """Count correct predictions for every enabled head over `loader`."""
        class_correct = 0
        jigsaw_correct = 0
        rotation_correct = 0
        odd_correct = 0
        for it, (data, class_l, jigsaw_label,
                 task_type) in enumerate(loader):
            data, class_l, jigsaw_label, task_type = data.to(
                self.device), class_l.to(self.device), jigsaw_label.to(
                    self.device), task_type.to(self.device)
            class_logit, jigsaw_logit, rotation_logit, odd_logit = self.model(
                data)
            _, cls_pred = class_logit.max(dim=1)
            _, jigsaw_pred = jigsaw_logit.max(dim=1)
            if self.args.rotation == True:
                _, rotation_pred = rotation_logit.max(dim=1)
                rotation_correct += torch.sum(
                    rotation_pred == jigsaw_label.data)
            if self.args.odd_one_out == True:
                _, odd_pred = odd_logit.max(dim=1)
                odd_correct += torch.sum(odd_pred == jigsaw_label.data)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jigsaw_pred == jigsaw_label.data)
        return class_correct, jigsaw_correct, rotation_correct, odd_correct

    def do_training(self):
        """Full training loop; returns (logger, model)."""
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.epochs):
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
            self.scheduler.step()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g" %
              (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
class Trainer:
    """Jigsaw trainer extended with REx (variance) and IRM penalties computed
    per source domain (d_idx in {0, 1, 2})."""

    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = model_factory.get_network(args.network)(
            jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model,
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov,
            adam=args.adam)
        self.jig_weight = args.jig_weight
        # Weights for the REx / IRM penalty terms (0 disables the term).
        self.rex_weight_class = args.rex_weight_class
        self.irm_weight_class = args.irm_weight_class
        self.rex_weight_jigsaw = args.rex_weight_jigsaw
        self.irm_weight_jigsaw = args.irm_weight_jigsaw
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self):
        """One training epoch with per-domain REx/IRM penalties, then
        evaluation on the val/test loaders."""
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l),
                 d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            self.optimizer.zero_grad()
            jigsaw_logit, class_logit = self.model(
                data)  # , lambda_val=lambda_val)
            # Per-domain jigsaw losses and IRM penalties (three source domains).
            j1 = criterion(jigsaw_logit[d_idx == 0], jig_l[d_idx == 0])
            j1_irm = compute_irm_penalty(jigsaw_logit[d_idx == 0],
                                         jig_l[d_idx == 0], criterion)
            j2 = criterion(jigsaw_logit[d_idx == 1], jig_l[d_idx == 1])
            j2_irm = compute_irm_penalty(jigsaw_logit[d_idx == 1],
                                         jig_l[d_idx == 1], criterion)
            j3 = criterion(jigsaw_logit[d_idx == 2], jig_l[d_idx == 2])
            j3_irm = compute_irm_penalty(jigsaw_logit[d_idx == 2],
                                         jig_l[d_idx == 2], criterion)
            rex_jigsaw = compute_rex_penalty(j1, j2, j3)
            jigsaw_loss = j1 + j2 + j3
            irm_jigsaw = (j1_irm + j2_irm + j3_irm) / 3
            if self.only_non_scrambled:
                if self.target_id is not None:
                    idx = (jig_l == 0) & (d_idx != self.target_id)
                    class_loss = criterion(class_logit[idx], class_l[idx])
                    rex_class = torch.Tensor([0.]).cuda()
                    irm_class = torch.Tensor([0.]).cuda()
                else:
                    # Unscrambled images only, split by domain.
                    class_loss_1 = criterion(
                        class_logit[(jig_l == 0) & (d_idx == 0)],
                        class_l[(jig_l == 0) & (d_idx == 0)])
                    class_irm_1 = compute_irm_penalty(
                        class_logit[(jig_l == 0) & (d_idx == 0)],
                        class_l[(jig_l == 0) & (d_idx == 0)], criterion)
                    class_loss_2 = criterion(
                        class_logit[(jig_l == 0) & (d_idx == 1)],
                        class_l[(jig_l == 0) & (d_idx == 1)])
                    class_irm_2 = compute_irm_penalty(
                        class_logit[(jig_l == 0) & (d_idx == 1)],
                        class_l[(jig_l == 0) & (d_idx == 1)], criterion)
                    class_loss_3 = criterion(
                        class_logit[(jig_l == 0) & (d_idx == 2)],
                        class_l[(jig_l == 0) & (d_idx == 2)])
                    class_irm_3 = compute_irm_penalty(
                        class_logit[(jig_l == 0) & (d_idx == 2)],
                        class_l[(jig_l == 0) & (d_idx == 2)], criterion)
                    class_loss = class_loss_1 + class_loss_2 + class_loss_3
                    irm_class = (class_irm_1 + class_irm_2 + class_irm_3) / 3
                    # FIX: third argument was class_loss_2, dropping domain 3
                    # from the REx variance penalty.
                    rex_class = compute_rex_penalty(class_loss_1,
                                                    class_loss_2,
                                                    class_loss_3)
            elif self.target_id:
                class_loss = criterion(class_logit[d_idx != self.target_id],
                                       class_l[d_idx != self.target_id])
                rex_class = torch.Tensor([0.]).cuda()
                irm_class = torch.Tensor([0.]).cuda()
            else:
                class_loss_1 = criterion(class_logit[(d_idx == 0)],
                                         class_l[(d_idx == 0)])
                class_irm_1 = compute_irm_penalty(class_logit[(d_idx == 0)],
                                                  class_l[(d_idx == 0)],
                                                  criterion)
                class_loss_2 = criterion(class_logit[(d_idx == 1)],
                                         class_l[(d_idx == 1)])
                class_irm_2 = compute_irm_penalty(class_logit[(d_idx == 1)],
                                                  class_l[(d_idx == 1)],
                                                  criterion)
                class_loss_3 = criterion(class_logit[(d_idx == 2)],
                                         class_l[(d_idx == 2)])
                class_irm_3 = compute_irm_penalty(class_logit[(d_idx == 2)],
                                                  class_l[(d_idx == 2)],
                                                  criterion)
                class_loss = class_loss_1 + class_loss_2 + class_loss_3
                irm_class = (class_irm_1 + class_irm_2 + class_irm_3) / 3
                # FIX: same typo as above — was class_loss_2 twice.
                rex_class = compute_rex_penalty(class_loss_1, class_loss_2,
                                                class_loss_3)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            # _, domain_pred = domain_logit.max(dim=1)
            rex_loss = self.rex_weight_class * rex_class + self.rex_weight_jigsaw * self.jig_weight * rex_jigsaw
            irm_loss = self.irm_weight_class * irm_class + self.irm_weight_jigsaw * self.jig_weight * irm_jigsaw
            if self.rex_weight_class == 0. and self.rex_weight_jigsaw == 0. and self.irm_weight_jigsaw == 0. and self.irm_weight_class == 0.:
                loss = class_loss + jigsaw_loss * self.jig_weight
            elif self.irm_weight_jigsaw == 0. and self.irm_weight_class == 0.:
                loss = class_loss + jigsaw_loss * self.jig_weight + rex_loss
            elif self.rex_weight_class == 0. and self.rex_weight_jigsaw == 0.:
                loss = class_loss + jigsaw_loss * self.jig_weight + irm_loss
            else:
                # FIX: this case (both REx and IRM weights nonzero) previously
                # fell through with `loss` unbound, crashing at loss.backward().
                loss = class_loss + jigsaw_loss * self.jig_weight + rex_loss + irm_loss
            loss.backward()
            self.optimizer.step()
            self.logger.log(
                it,
                len(self.source_loader),
                {
                    "jigsaw": jigsaw_loss.item(),
                    "class": class_loss.item(),
                    "rex loss class": rex_class.item(),
                    "rex loss jigsaw": rex_jigsaw.item(),
                    "rext total": rex_loss.item(),
                    "irm loss class": irm_class.item(),
                    "irm loss jigsaw": irm_jigsaw.item(),
                    "irm total": irm_loss.item()
                },  # ,"lambda": lambda_val},
                {
                    "jigsaw": torch.sum(jig_pred == jig_l.data).item(),
                    "class": torch.sum(cls_pred == class_l.data).item(),
                    # "domain": torch.sum(domain_pred == d_idx.data).item()
                },
                data.shape[0])
            del loss, class_loss, jigsaw_loss, rex_loss, jigsaw_logit, class_logit
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                if loader.dataset.isMulti():
                    jigsaw_correct, class_correct, single_acc = self.do_test_multi(
                        loader)
                    print("Single vs multi: %g %g" %
                          (float(single_acc) / total,
                           float(class_correct) / total))
                else:
                    jigsaw_correct, class_correct = self.do_test(loader)
                jigsaw_acc = float(jigsaw_correct) / total
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {
                    "jigsaw": jigsaw_acc,
                    "class": class_acc
                })
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        """Count correct jigsaw and class predictions over `loader`."""
        jigsaw_correct = 0
        class_correct = 0
        domain_correct = 0
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            jigsaw_logit, class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data)
        return jigsaw_correct, class_correct

    def do_test_multi(self, loader):
        """Multi-permutation test: average softmax over permutations (original
        image up-weighted) vs. single-image prediction."""
        jigsaw_correct = 0
        class_correct = 0
        single_correct = 0
        for it, ((data, jig_l, class_l), d_idx) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            # data assumed (batch, n_permutations, C, H, W) — TODO confirm.
            n_permutations = data.shape[1]
            class_logits = torch.zeros(n_permutations, data.shape[0],
                                       self.n_classes).to(self.device)
            for k in range(n_permutations):
                class_logits[k] = F.softmax(self.model(data[:, k])[1], dim=1)
            class_logits[
                0] *= 4 * n_permutations  # bias more the original image
            class_logit = class_logits.mean(0)
            _, cls_pred = class_logit.max(dim=1)
            jigsaw_logit, single_logit = self.model(data[:, 0])
            _, jig_pred = jigsaw_logit.max(dim=1)
            _, single_logit = single_logit.max(dim=1)
            single_correct += torch.sum(single_logit == class_l.data)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data[:, 0])
        return jigsaw_correct, class_correct, single_correct

    def do_training(self):
        """Full training loop; writes a result summary file and returns
        (logger, model)."""
        self.logger = Logger(self.args,
                             update_frequency=30)  # , "domain", "lambda"
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        #print("Best val %g, corresponding test %g - best test: %g" % (val_res.max(), test_res[idx_best], test_res.max()))
        name = self.args.prefix + "_" + str(self.args.source[0]) + str(
            self.args.source[1]
        ) + str(self.args.source[2]) + "_" + str(
            self.args.target
        ) + "_eps%d_bs%d_lr%g_class%d_jigClass%d_rexWeightClass%g_rexWeightJig%g_irmWeightClass%g_irmWeightJig%g_jigWeight%g" % (
            self.args.epochs, self.args.batch_size, self.args.learning_rate,
            self.args.n_classes, self.args.jigsaw_n_classes,
            self.args.rex_weight_class, self.args.rex_weight_jigsaw,
            self.args.irm_weight_class, self.args.irm_weight_jigsaw,
            self.args.jig_weight)
        with open('./result_summary_txt/' + name + '.txt', 'a+') as f:
            f.write('best validation accuracy: ' + str(val_res.max()) +
                    ' test acc at best val acc: ' + str(test_res[idx_best]) +
                    ' max test: ' + str(test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
class Trainer:
    """Trainer for a domain-aware model with a shared and a domain-specific head.

    Each training batch carries a domain index ``d_idx``; the model receives a
    one-hot encoding of that index and returns a domain-specific logit plus a
    shared class logit.  An orthogonality penalty on ``self.model.sms`` keeps
    the per-class specific subspaces decorrelated.
    """

    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = model_factory.get_network(args.network)(classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(self.val_loader.dataset),
               len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model, args.epochs, args.learning_rate, args.train_all,
            nesterov=args.nesterov)
        self.n_classes = args.n_classes
        # If the target domain also appears among the sources, remember its
        # index so callers can exclude/inspect it.
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self):
        """Run one training epoch, then evaluate on every test loader."""
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
            # NOTE(review): number of source domains is hard-coded here;
            # presumably it should track len(args.source) — confirm.
            NUM_DOMAINS = 3
            # BUGFIX: build the one-hot domain tensor on self.device instead of
            # the hard-coded 'cuda' so the trainer also runs on CPU or a
            # non-default GPU.
            oh_dids = torch.tensor(one_hot(d_idx, NUM_DOMAINS),
                                   dtype=torch.float, device=self.device)
            data = data.to(self.device)
            jig_l = jig_l.to(self.device)
            class_l = class_l.to(self.device)
            d_idx = d_idx.to(self.device)
            self.optimizer.zero_grad()
            specific_logit, class_logit = self.model(data, oh_dids)
            specific_loss = criterion(specific_logit, class_l)
            class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)

            # Orthogonality regularizer: for each class, push the K x K Gram
            # matrix of its specific components toward the identity.
            sms = self.model.sms
            K = 2
            # BUGFIX: create the identity stack on self.device, not via .cuda().
            diag_tensor = torch.stack(
                [torch.eye(K) for _ in range(self.n_classes)],
                dim=0).to(self.device)
            cps = torch.stack(
                [torch.matmul(sms[:, :, c],
                              torch.transpose(sms[:, :, c], 0, 1))
                 for c in range(self.n_classes)], dim=0)
            if self.args.network.startswith('caffenet'):
                # Penalize only off-diagonal energy for caffenet backbones.
                orth_loss = torch.mean(
                    (1 - diag_tensor) * (cps - diag_tensor) ** 2)
            else:
                orth_loss = torch.mean((cps - diag_tensor) ** 2)

            loss = class_loss + specific_loss + orth_loss
            loss.backward()
            self.optimizer.step()

            self.logger.log(
                it, len(self.source_loader),
                {"specific": specific_loss.item(),
                 "class": class_loss.item()},
                {"class": torch.sum(cls_pred == class_l.data).item()},
                data.shape[0])
            del loss, class_loss, specific_loss, specific_logit, class_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                specific_correct, class_correct = self.do_test(loader)
                specific_acc = float(specific_correct) / total
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"specific": specific_acc,
                                             "class": class_acc})
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        """Count correct predictions of both heads over *loader*.

        Uses an all-zeros (domain 0) one-hot id at test time, since the target
        domain is unseen.  Returns ``(specific_correct, class_correct)``.
        """
        specific_correct = 0
        class_correct = 0
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data = data.to(self.device)
            jig_l = jig_l.to(self.device)
            class_l = class_l.to(self.device)
            dummy_ids = one_hot(np.zeros(len(data), dtype=np.int32), 3)
            # BUGFIX: place the dummy domain ids on self.device, not 'cuda'.
            specific_logit, class_logit = self.model(
                data,
                torch.tensor(dummy_ids, dtype=torch.float, device=self.device))
            _, cls_pred = class_logit.max(dim=1)
            _, specific_pred = specific_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            specific_correct += torch.sum(specific_pred == class_l.data)
        # Debug: inspect the learned domain embeddings and common/specific
        # weighting after each evaluation pass.
        print(self.model.embs, self.model.cs_wt)
        return specific_correct, class_correct

    def do_training(self):
        """Train for ``args.epochs`` epochs; return the logger and model."""
        self.logger = Logger(self.args, update_frequency=30)  # , "domain", "lambda"
        self.results = {"val": torch.zeros(self.args.epochs),
                        "test": torch.zeros(self.args.epochs)}
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
class Trainer:
    """Trainer that doubles each batch with horizontally flipped copies.

    The model is called with ``(data, class_l, flag, epoch)`` during training
    (flag ``True``) and ``(data, class_l, False)`` at test time; the flag and
    epoch presumably drive an RSC-style feature-dropping schedule inside the
    model — confirm against the model implementation.
    """

    def __init__(self, args, device):
        self.args = args
        self.device = device
        # resnet18 is the default backbone for any unrecognized network name.
        if args.network == 'resnet50':
            model = resnet50(pretrained=True, classes=args.n_classes)
        else:
            model = resnet18(pretrained=True, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(self.val_loader.dataset),
               len(self.target_loader.dataset)))
        # Optimizer: SGD (see get_optim_and_scheduler).
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model, args.epochs, args.learning_rate, args.train_all,
            nesterov=args.nesterov)
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self, epoch=None):
        """Train one epoch (flip-augmented), then evaluate all test loaders.

        Loop variables: ``data`` is the input batch, ``jig_l`` an auxiliary
        label (unused here), ``class_l`` the class label, ``d_idx`` the domain
        index (unused here).
        """
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
            data = data.to(self.device)
            jig_l = jig_l.to(self.device)
            class_l = class_l.to(self.device)
            d_idx = d_idx.to(self.device)
            self.optimizer.zero_grad()

            # Horizontal flip (axis 3 = width); detach().clone() keeps the
            # flipped copy out of the autograd graph of the original tensor.
            data_flip = torch.flip(data, (3, )).detach().clone()
            # Concatenate originals and flips: batch size doubles, and the
            # labels are duplicated to match.
            data = torch.cat((data, data_flip))
            class_l = torch.cat((class_l, class_l))

            # BUGFIX: removed per-iteration debug prints of data.shape that
            # flooded stdout on every batch.
            class_logit = self.model(data, class_l, True, epoch)
            class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            loss = class_loss

            loss.backward()
            self.optimizer.step()

            self.logger.log(
                it, len(self.source_loader), {"class": class_loss.item()}, {
                    "class": torch.sum(cls_pred == class_l.data).item(),
                }, data.shape[0])
            del loss, class_loss, class_logit

        # Evaluation on validation and target test sets.
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                class_correct = self.do_test(loader)
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"class": class_acc})
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        """Return the number of correctly classified samples in *loader*."""
        class_correct = 0
        for it, ((data, nouse, class_l), _) in enumerate(loader):
            data = data.to(self.device)
            nouse = nouse.to(self.device)
            class_l = class_l.to(self.device)
            class_logit = self.model(data, class_l, False)
            _, cls_pred = class_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
        return class_correct

    def do_training(self):
        """Run the full training loop; return the logger and trained model."""
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.epochs):
            # Advance the learning-rate schedule, then run one epoch.
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch(self.current_epoch)
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print(
            "Best val %g, corresponding test %g - best test: %g, best epoch: %g"
            % (val_res.max(), test_res[idx_best], test_res.max(), idx_best))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
class Trainer: def __init__(self, args, device): self.args = args self.device = device model = model_factory.get_network(args.network)(classes=args.n_classes) self.model = model.to(device) # print(self.model) self.source_loader, self.val_loader = data_helper.get_train_dataloader( args, patches=model.is_patch_based()) self.target_loader = data_helper.get_val_dataloader( args, patches=model.is_patch_based()) self.test_loaders = { "val": self.val_loader, "test": self.target_loader } self.len_dataloader = len(self.source_loader) print("Dataset size: train %d, val %d, test %d" % (len(self.source_loader.dataset), len( self.val_loader.dataset), len(self.target_loader.dataset))) self.optimizer, self.scheduler, self.optimizer_par, self.scheduler_par = get_optim_and_scheduler_PAR( model, args.epochs, args.learning_rate, args.par_learning_rate, args.train_all, nesterov=args.nesterov) self.par_weight = args.par_weight self.only_non_scrambled = args.classify_only_sane self.n_classes = args.n_classes if args.target in args.source: self.target_id = args.source.index(args.target) print("Target in source: %d" % self.target_id) print(args.source) else: self.target_id = None # import ipdb;ipdb.set_trace() def accuracy(self, output, target, topk=(1, )): """Computes the accuracy over the k top predictions for the specified values of k""" with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(correct_k) return res def _do_epoch(self): criterion = nn.CrossEntropyLoss() self.model.train() for it, ((data, _, class_l), _) in enumerate(self.source_loader): data, class_l = data.to(self.device), class_l.to(self.device) # update par classifier self.optimizer_par.zero_grad() class_logit, par_logit = self.model(data) m, n = par_logit.shape[1], par_logit.shape[2] par_class_l = 
class_l.view(-1, 1, 1, 1).repeat(1, m, n, 1).view(-1) par_loss = criterion(par_logit.view(-1, self.n_classes), par_class_l) _, par_pred = par_logit.view(-1, self.n_classes).max(dim=1) par_loss.backward() self.optimizer_par.step() # update main classifier self.optimizer.zero_grad() class_logit, par_logit = self.model(data) class_loss = criterion(class_logit, class_l) # import ipdb;ipdb.set_trace() par_loss2 = criterion(par_logit.view(-1, self.n_classes), par_class_l) # top1_correct_pred, top5_correct_pred = self.accuracy(class_logit, class_l, topk=[1,5]) _, cls_pred = class_logit.max(dim=1) loss = class_loss - par_loss2 * self.par_weight # loss = class_loss loss.backward() self.optimizer.step() self.logger.log( it, len(self.source_loader), { "par": par_loss.item(), "class": class_loss.item() }, # ,"lambda": lambda_val}, { "par": torch.sum(par_pred == par_class_l.data).type( torch.FloatTensor) / (m * n), "class": torch.sum(cls_pred == class_l.data).item(), # "top5 class": top5_correct_pred.item(), }, data.shape[0]) # print(time()-begin) del loss, class_loss, par_loss, par_logit, class_logit self.model.eval() with torch.no_grad(): for phase, loader in self.test_loaders.items(): total = len(loader.dataset) par_correct, top1_correct_pred, top5_correct_pred = self.do_test( loader) par_acc = float(par_correct) / total class_top1_acc = float(top1_correct_pred) / total class_top5_acc = float(top5_correct_pred) / total self.logger.log_test( phase, { "par": par_acc, "class top1": class_top1_acc, "class top5": class_top5_acc }) self.results[phase + 'top1'][self.current_epoch] = class_top1_acc self.results[phase + 'top5'][self.current_epoch] = class_top5_acc def do_test(self, loader): par_correct = 0 # class_correct = 0 class_correct_top1 = 0 class_correct_top5 = 0 domain_correct = 0 for it, ((data, _, class_l), _) in enumerate(loader): data, class_l = data.to(self.device), class_l.to(self.device) class_logit, par_logit = self.model(data) m, n = par_logit.shape[1], 
par_logit.shape[2] par_class_l = class_l.view(-1, 1, 1, 1).repeat(1, m, n, 1).view(-1) _, cls_pred = class_logit.max(dim=1) _, par_pred = par_logit.view(-1, self.n_classes).max(dim=1) top1_correct_pred, top5_correct_pred = self.accuracy(class_logit, class_l, topk=[1, 5]) # class_correct += torch.sum(cls_pred == class_l.data) class_correct_top1 += top1_correct_pred class_correct_top5 += top5_correct_pred # import ipdb;ipdb.set_trace() par_correct += torch.sum(par_pred == par_class_l.data).type( torch.FloatTensor) / (m * n) return par_correct, class_correct_top1, class_correct_top5 def do_training(self): self.logger = Logger(self.args, update_frequency=30) # , "domain", "lambda" self.results = { "valtop1": torch.zeros(self.args.epochs), "valtop5": torch.zeros(self.args.epochs), "testtop1": torch.zeros(self.args.epochs), "testtop5": torch.zeros(self.args.epochs) } for self.current_epoch in range(self.args.epochs): self.scheduler.step() self.scheduler_par.step() self.logger.new_epoch(self.scheduler.get_lr()) self._do_epoch() val_res = self.results["valtop1"] testtop1_res = self.results["testtop1"] testtop5_res = self.results["testtop5"] idx_best = val_res.argmax() print( "Best val %g, corresponding test top1 acc %g top5 acc %g - best test top1: %g, top5: %g" % (val_res.max(), testtop1_res[idx_best], testtop5_res[idx_best], testtop1_res.max(), testtop5_res.max())) self.logger.save_best(testtop1_res[idx_best], testtop1_res.max()) return self.logger, self.model