import os

import numpy as np
import torch
from tqdm import tqdm

# Project-local helpers assumed available in this repo (exact import paths may differ):
# GradualWarmupScheduler, LabelSmoothingLoss, and the evaluation routine `test`.


class TrainLoop(object):

    def __init__(self, models_dict, optimizer_task, source_loader,
                 test_source_loader, target_loader, nadir_slack, alpha,
                 patience, factor, label_smoothing, warmup_its, lr_threshold,
                 verbose=-1, cp_name=None, save_cp=False, checkpoint_path=None,
                 checkpoint_epoch=None, cuda=True, logging=False,
                 ablation='no', train_mode='hv'):

        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                os.mkdir(self.checkpoint_path)

        self.save_epoch_fmt_task = os.path.join(
            self.checkpoint_path, 'task' + cp_name) if cp_name else os.path.join(
                self.checkpoint_path, 'task_checkpoint_{}ep.pt')
        self.save_epoch_fmt_domain = os.path.join(
            self.checkpoint_path, 'Domain_{}' + cp_name) if cp_name else os.path.join(
                self.checkpoint_path, 'Domain_{}.pt')

        self.cuda_mode = cuda
        self.feature_extractor = models_dict['feature_extractor']
        self.task_classifier = models_dict['task_classifier']
        self.domain_discriminator_list = models_dict['domain_discriminator_list']
        self.optimizer_task = optimizer_task
        self.source_loader = source_loader
        self.test_source_loader = test_source_loader
        self.target_loader = target_loader
        self.history = {
            'loss_task': [],
            'hypervolume': [],
            'loss_domain': [],
            'accuracy_source': [],
            'accuracy_target': []
        }
        self.cur_epoch = 0
        self.total_iter = 0
        self.nadir_slack = nadir_slack
        self.alpha = alpha
        self.ablation = ablation
        self.train_mode = train_mode
        self.device = next(self.feature_extractor.parameters()).device

        its_per_epoch = (len(source_loader.dataset) // source_loader.batch_size + 1
                         if len(source_loader.dataset) % source_loader.batch_size > 0
                         else len(source_loader.dataset) // source_loader.batch_size)
        patience = patience * (1 + its_per_epoch)

        self.after_scheduler_task = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer_task,
            factor=factor,
            patience=patience,
            verbose=True if verbose > 0 else False,
            threshold=lr_threshold,
            min_lr=1e-7)
        self.after_scheduler_disc_list = [
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                disc.optimizer,
                factor=factor,
                patience=patience,
                verbose=True if verbose > 0 else False,
                threshold=lr_threshold,
                min_lr=1e-7) for disc in self.domain_discriminator_list
        ]
        self.verbose = verbose
        self.save_cp = save_cp

        self.scheduler_task = GradualWarmupScheduler(
            self.optimizer_task,
            total_epoch=warmup_its,
            after_scheduler=self.after_scheduler_task)
        self.scheduler_disc_list = [
            GradualWarmupScheduler(self.domain_discriminator_list[i].optimizer,
                                   total_epoch=warmup_its,
                                   after_scheduler=sch_disc)
            for i, sch_disc in enumerate(self.after_scheduler_disc_list)
        ]

        if checkpoint_epoch is not None:
            self.load_checkpoint(checkpoint_epoch)

        self.logging = logging
        if self.logging:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter()

        if label_smoothing > 0.0:
            self.ce_criterion = LabelSmoothingLoss(label_smoothing, lbl_set_size=7)
        else:
            self.ce_criterion = torch.nn.CrossEntropyLoss()  # torch.nn.NLLLoss()

        # Weighted two-class loss for the domain discriminators.
        weight = torch.tensor([2.0 / 3.0, 1.0 / 3.0]).to(self.device)
        # d_cr = torch.nn.CrossEntropyLoss(weight=weight)
        self.d_cr = torch.nn.NLLLoss(weight=weight)

    #### Edit ####
    def adjust_learning_rate(self, optimizer, epoch=1, every_n=700, In_lr=0.01):
        """Sets the learning rate to the initial LR decayed by 10 every n epochs."""
        every_n_epoch = every_n  # n_epoch / n_step
        lr = In_lr * (0.1 ** (epoch // every_n_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr
    ######

    def train(self, n_epochs=1, save_every=1):
        '''
        ##### Edit ####
        # init necessary objects
        num_steps = n_epochs  # * (len(train_loader.dataset) / train_loader.batch_size)
        # yd = Variable(torch.from_numpy(np.hstack([np.repeat(1, int(batch_size / 2)),
        #                                           np.repeat(0, int(batch_size / 2))]).reshape(50, 1)))
        j = 0
        lambd = 0
        max_lambd = 1 - self.alpha
        max_lr_t = list(self.optimizer_task.param_groups)[-1]['initial_lr']
        max_lr_d = list(self.domain_discriminator_list[0].optimizer.param_groups)[-1]['initial_lr']  # 0.005
        # lr_2 = 0
        # lr = In_lr
        # Domain_Classifier.set_lambda(lambd)
        # print(num_steps)
        '''
        max_lr_t = list(self.optimizer_task.param_groups)[-1]['initial_lr']
        max_lr_d = list(self.domain_discriminator_list[0].optimizer.param_groups)[-1]['initial_lr']
        ###############

        while self.cur_epoch < n_epochs:
            cur_loss_task = 0
            cur_hypervolume = 0
            cur_loss_total = 0
            source_iter = tqdm(enumerate(self.source_loader), disable=False)

            '''
            ## Edit ##
            ## update lambda and learning rate as suggested in the paper
            p = float(j) / num_steps
            lambd = max_lambd * round(2. / (1. + np.exp(-10. * p)) - 1, 3)  # I changed 10 to 7
            # lambd = lambd
            lr_t = max_lr_t / (1. + 10 * p) ** 0.75
            lr_d = max_lr_d / (1. + 10 * p) ** 0.75
            lr_t = self.adjust_learning_rate(self.optimizer_task, In_lr=lr_t)
            for _, disc in enumerate(self.domain_discriminator_list):
                lr_d = self.adjust_learning_rate(disc.optimizer, In_lr=lr_d)
            self.alpha = 1 - lambd
            # print('Alpha = {}'.format(self.alpha))
            j += 1
            '''
            lr_t = max_lr_t  # / (1. + 10 * p) ** 0.75
            lr_d = max_lr_d  # / (1. + 10 * p) ** 0.75
            '''
            lr_t = self.adjust_learning_rate(self.optimizer_task,
                                             epoch=self.cur_epoch,
                                             every_n=0.8 * n_epochs,
                                             In_lr=lr_t)
            for _, disc in enumerate(self.domain_discriminator_list):
                lr_d = self.adjust_learning_rate(disc.optimizer,
                                                 epoch=self.cur_epoch,
                                                 every_n=0.8 * n_epochs,
                                                 In_lr=lr_d)
            '''
            ##########

            print('Epoch {}/{} | Alpha = {:1.3} | Lr_task = {:1.4} | Lr_dis = {:1.4} '.format(
                self.cur_epoch + 1, n_epochs, self.alpha, lr_t, lr_d))

            for t, batch in source_iter:
                if self.ablation == 'all':
                    cur_losses = self.train_step_ablation_all(batch)
                else:
                    cur_losses = self.train_step(batch)

                self.scheduler_task.step(
                    epoch=self.total_iter,
                    metrics=1. - self.history['accuracy_source'][-1]
                    if self.cur_epoch > 0 else np.inf)
                for sched in self.scheduler_disc_list:
                    sched.step(
                        epoch=self.total_iter,
                        metrics=1. - self.history['accuracy_source'][-1]
                        if self.cur_epoch > 0 else np.inf)

                cur_loss_task += cur_losses[0]
                cur_hypervolume += cur_losses[1]
                cur_loss_total += cur_losses[2]
                self.total_iter += 1

                if self.logging:
                    self.writer.add_scalar('train/task_loss', cur_losses[0], self.total_iter)
                    self.writer.add_scalar('train/hypervolume_loss', cur_losses[1], self.total_iter)
                    self.writer.add_scalar('train/total_loss', cur_losses[2], self.total_iter)

            self.history['loss_task'].append(cur_loss_task / (t + 1))
            self.history['hypervolume'].append(cur_hypervolume / (t + 1))

            print('Current task loss: {}.'.format(cur_loss_task / (t + 1)))
            print('Current hypervolume: {}.'.format(cur_hypervolume / (t + 1)))

            self.history['accuracy_source'].append(
                test(self.test_source_loader,
                     self.feature_extractor,
                     self.task_classifier,
                     self.domain_discriminator_list,
                     self.device,
                     source_target='source',
                     epoch=self.cur_epoch,
                     tb_writer=self.writer if self.logging else None))
            self.history['accuracy_target'].append(
                test(self.target_loader,
                     self.feature_extractor,
                     self.task_classifier,
                     self.domain_discriminator_list,
                     self.device,
                     source_target='target',
                     epoch=self.cur_epoch,
                     tb_writer=self.writer if self.logging else None))

            idx = np.argmax(self.history['accuracy_source'])
            print('Valid. on SOURCE data - Current acc., best acc., best acc target, and epoch: '
                  '{:0.4f}, {:0.4f}, {:0.4f}, {}'.format(
                      self.history['accuracy_source'][-1],
                      np.max(self.history['accuracy_source']),
                      self.history['accuracy_target'][idx],
                      1 + np.argmax(self.history['accuracy_source'])))
            print('Valid. on TARGET data - Current acc., best acc., and epoch: '
                  '{:0.4f}, {:0.4f}, {}'.format(
                      self.history['accuracy_target'][-1],
                      np.max(self.history['accuracy_target']),
                      1 + np.argmax(self.history['accuracy_target'])))

            if self.logging:
                self.writer.add_scalar('misc/LR-task',
                                       self.optimizer_task.param_groups[0]['lr'],
                                       self.total_iter)
                for i, disc in enumerate(self.domain_discriminator_list):
                    self.writer.add_scalar('misc/LR-disc{}'.format(i),
                                           disc.optimizer.param_groups[0]['lr'],
                                           self.total_iter)

            self.cur_epoch += 1

            if self.save_cp and (self.cur_epoch % save_every == 0
                                 or self.history['accuracy_target'][-1] >
                                 np.max([-np.inf] + self.history['accuracy_target'][:-1])):
                self.checkpointing()

        if self.logging:
            self.writer.close()

        idx_final = np.argmax(self.history['accuracy_source'])
        idx_loss = np.argmin(self.history['loss_task'])
        print('min loss task = {} and corresponding target accuracy is = {}'.format(
            idx_loss + 1, self.history['accuracy_target'][idx_loss]))

        return (np.max(self.history['accuracy_target']),
                self.history['accuracy_target'][idx],
                self.history['accuracy_target'][idx_loss],
                self.history['accuracy_target'][-1])

    def train_step(self, batch):
        self.feature_extractor.train()
        self.task_classifier.train()
        for disc in self.domain_discriminator_list:
            disc = disc.train()

        x_1, x_2, x_3, y_task_1, y_task_2, y_task_3, y_domain_1, y_domain_2, y_domain_3 = batch
        x = torch.cat((x_1, x_2, x_3), dim=0)
        y_task = torch.cat((y_task_1, y_task_2, y_task_3), dim=0)
        y_domain = torch.cat((y_domain_1, y_domain_2, y_domain_3), dim=0)

        if self.cuda_mode:
            x = x.to(self.device)
            y_task = y_task.to(self.device)

        # COMPUTING FEATURES
        features = self.feature_extractor.forward(x)
        features_ = features.detach()

        # DOMAIN DISCRIMINATORS (First)
        for i, disc in enumerate(self.domain_discriminator_list):
            y_predict = disc.forward(features_).squeeze()
            curr_y_domain = torch.where(y_domain == i,
                                        torch.ones(y_domain.size(0)),
                                        torch.zeros(y_domain.size(0)))
            curr_y_domain.type_as(y_domain)
            # print(y_domain.shape, curr_y_domain.shape, y_predict.shape)
            # print(sum(curr_y_domain))
            curr_y_domain = curr_y_domain.long()
            if self.cuda_mode:
                curr_y_domain = curr_y_domain.long().to(self.device)

            # loss_domain_discriminator = F.binary_cross_entropy_with_logits(y_predict, curr_y_domain)
            # weight = torch.tensor([2.0 / 3.0, 1.0 / 3.0]).to(self.device)
            # d_cr = torch.nn.CrossEntropyLoss(weight=weight)
            # d_cr = torch.nn.NLLLoss(weight=weight)
            loss_domain_discriminator = self.d_cr(y_predict, curr_y_domain)
            # print(loss_domain_discriminator)

            if self.logging:
                self.writer.add_scalar('train/D{}_loss'.format(i),
                                       loss_domain_discriminator, self.total_iter)

            disc.optimizer.zero_grad()
            loss_domain_discriminator.backward()
            disc.optimizer.step()

        # UPDATE TASK CLASSIFIER AND FEATURE EXTRACTOR
        task_out = self.task_classifier.forward(features)

        loss_domain_disc_list = []
        loss_domain_disc_list_float = []
        for i, disc in enumerate(self.domain_discriminator_list):
            y_predict = disc.forward(features).squeeze()
            curr_y_domain = torch.where(y_domain == i,
                                        torch.zeros(y_domain.size(0)),
                                        torch.ones(y_domain.size(0)))
            curr_y_domain = curr_y_domain.long()
            if self.cuda_mode:
                curr_y_domain = curr_y_domain.long().to(self.device)

            # loss_domain_disc_list.append(F.binary_cross_entropy_with_logits(y_predict, curr_y_domain))
            # y_predict = y_predict.long().to(self.device)
            loss_domain_disc_list.append(self.d_cr(y_predict, curr_y_domain))
            loss_domain_disc_list_float.append(loss_domain_disc_list[-1].detach().item())

        if self.train_mode == 'hv':
            self.update_nadir_point(loss_domain_disc_list_float)

        hypervolume = 0
        for loss in loss_domain_disc_list:
            if self.train_mode == 'hv':
                hypervolume -= torch.log(self.nadir - loss + 1e-6)
                # hypervolume -= torch.log(loss)
            elif self.train_mode == 'avg':
                hypervolume -= loss

        task_loss = self.ce_criterion(task_out, y_task)
        loss_total = self.alpha * task_loss + (1 - self.alpha) * hypervolume / len(loss_domain_disc_list)

        self.optimizer_task.zero_grad()
        loss_total.backward()
        self.optimizer_task.step()

        losses_return = task_loss.item(), hypervolume.item(), loss_total.item()
        return losses_return

    def train_step_ablation_all(self, batch):
        self.feature_extractor.train()
        self.task_classifier.train()
        for disc in self.domain_discriminator_list:
            disc = disc.train()

        x_1, x_2, x_3, y_task_1, y_task_2, y_task_3, _, _, _ = batch
        x = torch.cat((x_1, x_2, x_3), dim=0)
        y_task = torch.cat((y_task_1, y_task_2, y_task_3), dim=0)

        if self.cuda_mode:
            x = x.to(self.device)
            y_task = y_task.to(self.device)

        # COMPUTING FEATURES
        features = self.feature_extractor.forward(x)
        task_out = self.task_classifier.forward(features)
        task_loss = torch.nn.CrossEntropyLoss()(task_out, y_task)

        self.optimizer_task.zero_grad()
        task_loss.backward()
        self.optimizer_task.step()

        # Return three values so the training loop can index losses consistently
        # with train_step() (task loss, domain/hypervolume term, total loss).
        losses_return = task_loss.item(), 0., task_loss.item()
        return losses_return

    def checkpointing(self):
        if self.verbose > 0:
            print(' ')
            print('Checkpointing...')

        ckpt = {
            'feature_extractor_state': self.feature_extractor.state_dict(),
            'task_classifier_state': self.task_classifier.state_dict(),
            'optimizer_task_state': self.optimizer_task.state_dict(),
            'scheduler_task_state': self.scheduler_task.state_dict(),
            'history': self.history,
            'cur_epoch': self.cur_epoch
        }
        torch.save(ckpt, self.save_epoch_fmt_task.format(self.cur_epoch))

        for i, disc in enumerate(self.domain_discriminator_list):
            ckpt = {
                'model_state': disc.state_dict(),
                'optimizer_disc_state': disc.optimizer.state_dict(),
                'scheduler_disc_state': self.scheduler_disc_list[i].state_dict()
            }
            torch.save(ckpt, self.save_epoch_fmt_domain.format(i + 1))

    def load_checkpoint(self, epoch):
        ckpt = self.save_epoch_fmt_task.format(epoch)

        if os.path.isfile(ckpt):
            ckpt = torch.load(ckpt)
            # Load model state
            self.feature_extractor.load_state_dict(ckpt['feature_extractor_state'])
            self.task_classifier.load_state_dict(ckpt['task_classifier_state'])
            # Load optimizer state
            self.optimizer_task.load_state_dict(ckpt['optimizer_task_state'])
            # Load scheduler state
            self.scheduler_task.load_state_dict(ckpt['scheduler_task_state'])
            # Load history
            self.history = ckpt['history']
            self.cur_epoch = ckpt['cur_epoch']

            for i, disc in enumerate(self.domain_discriminator_list):
                ckpt = torch.load(self.save_epoch_fmt_domain.format(i + 1))
                disc.load_state_dict(ckpt['model_state'])
                disc.optimizer.load_state_dict(ckpt['optimizer_disc_state'])
                self.scheduler_disc_list[i].load_state_dict(ckpt['scheduler_disc_state'])
        else:
            print('No checkpoint found at: {}'.format(ckpt))

    def print_grad_norms(self, model):
        norm = 0.0
        for params in list(filter(lambda p: p.grad is not None, model.parameters())):
            norm += params.grad.norm(2).item()
        print('Sum of grads norms: {}'.format(norm))

    def update_nadir_point(self, losses_list):
        self.nadir = float(np.max(losses_list) * self.nadir_slack + 1e-8)
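
# Standalone sketch (not part of the training pipeline above): illustrates the
# nadir-point/hypervolume aggregation used in train_step() and
# update_nadir_point(). The slack value and the toy per-discriminator losses
# below are arbitrary; only the formulas mirror the 'hv' branch above.
if __name__ == '__main__':
    import numpy as np
    import torch

    nadir_slack = 1.5
    # pretend adversarial losses from the three domain discriminators
    disc_losses = [torch.tensor(0.62, requires_grad=True),
                   torch.tensor(0.70, requires_grad=True),
                   torch.tensor(0.55, requires_grad=True)]

    # nadir point: slightly worse than the worst current loss
    nadir = float(np.max([l.item() for l in disc_losses]) * nadir_slack + 1e-8)

    # hypervolume-style objective: push every discriminator loss away from the nadir
    hypervolume = 0
    for loss in disc_losses:
        hypervolume -= torch.log(nadir - loss + 1e-6)

    hypervolume.backward()
    print('nadir:', nadir, 'hypervolume:', hypervolume.item())
    print('gradients w.r.t. each loss:', [l.grad.item() for l in disc_losses])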
def train(args, logger):
    task_time = time.strftime("%Y-%m-%d %H:%M", time.localtime())

    Path("./saved_models/").mkdir(parents=True, exist_ok=True)
    Path("./pretrained_models/").mkdir(parents=True, exist_ok=True)
    MODEL_SAVE_PATH = './saved_models/'
    Pretrained_MODEL_PATH = './pretrained_models/'
    get_model_name = lambda part: f'{part}-{args.data}-{args.tasks}-{args.prefix}.pth'
    get_pretrain_model_name = lambda part: f'{part}-{args.data}-LP-{args.prefix}.pth'

    device_string = 'cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu >= 0 else 'cpu'
    print('Model training with ' + device_string)
    device = torch.device(device_string)

    g = load_graphs(f"./data/{args.data}.dgl")[0][0]
    efeat_dim = g.edata['feat'].shape[1]
    nfeat_dim = efeat_dim

    train_loader, val_loader, test_loader, num_val_samples, num_test_samples = dataloader(args, g)

    encoder = Encoder(args, nfeat_dim, n_head=args.n_head, dropout=args.dropout).to(device)
    decoder = Decoder(args, nfeat_dim).to(device)
    msg2mail = Msg2Mail(args, nfeat_dim)
    fraud_sampler = frauder_sampler(g)

    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),
                                 lr=args.lr, weight_decay=args.weight_decay)
    scheduler_lr = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40)
    if args.warmup:
        scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=3,
                                                  after_scheduler=scheduler_lr)
        optimizer.zero_grad()
        optimizer.step()

    loss_fcn = torch.nn.BCEWithLogitsLoss()
    loss_fcn = loss_fcn.to(device)

    early_stopper = EarlyStopMonitor(logger=logger, max_round=args.patience, higher_better=True)

    if args.pretrain:
        logger.info(f'Loading the linkpred pretrained attention based encoder model')
        encoder.load_state_dict(torch.load(Pretrained_MODEL_PATH + get_pretrain_model_name('Encoder')))

    for epoch in range(args.n_epoch):
        # reset node state
        g.ndata['mail'] = torch.zeros((g.num_nodes(), args.n_mail, nfeat_dim + 2), dtype=torch.float32)
        g.ndata['feat'] = torch.zeros((g.num_nodes(), nfeat_dim), dtype=torch.float32)  # init as zero, people can init it using others.
        g.ndata['last_update'] = torch.zeros((g.num_nodes()), dtype=torch.float32)
        encoder.train()
        decoder.train()
        start_epoch = time.time()
        m_loss = []
        logger.info('start {} epoch, current optim lr is {}'.format(epoch, optimizer.param_groups[0]['lr']))

        for batch_idx, (input_nodes, pos_graph, neg_graph, blocks, frontier, current_ts) in enumerate(train_loader):
            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device) if neg_graph is not None else None

            if not args.no_time or not args.no_pos:
                current_ts, pos_ts, num_pos_nodes = get_current_ts(args, pos_graph, neg_graph)
                pos_graph.ndata['ts'] = current_ts
            else:
                current_ts, pos_ts, num_pos_nodes = None, None, None

            # `_` holds the reversed negative graph (or None) and is passed to the encoder.
            _ = dgl.add_reverse_edges(neg_graph) if neg_graph is not None else None
            emb, _ = encoder(dgl.add_reverse_edges(pos_graph), _, num_pos_nodes)

            if batch_idx != 0:
                if 'LP' not in args.tasks and args.balance:
                    neg_graph = fraud_sampler.sample_fraud_event(g, args.bs // 5, current_ts.max().cpu()).to(device)

            logits, labels = decoder(emb, pos_graph, neg_graph)

            loss = loss_fcn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            m_loss.append(loss.item())

            # MSG Passing
            with torch.no_grad():
                mail = msg2mail.gen_mail(args, emb, input_nodes, pos_graph, frontier, 'train')
                if not args.no_time:
                    g.ndata['last_update'][pos_graph.ndata[dgl.NID][:num_pos_nodes]] = pos_ts.to('cpu')
                g.ndata['feat'][pos_graph.ndata[dgl.NID]] = emb.to('cpu')
                g.ndata['mail'][input_nodes] = mail

            if batch_idx % 100 == 1:
                gpu_mem = torch.cuda.max_memory_allocated() / 1.074e9 if torch.cuda.is_available() and args.gpu >= 0 else 0
                torch.cuda.empty_cache()
                mem_perc = psutil.virtual_memory().percent
                cpu_perc = psutil.cpu_percent(interval=None)
                output_string = f'Epoch {epoch} | Step {batch_idx}/{len(train_loader)} | CPU {cpu_perc:.1f}% | Sys Mem {mem_perc:.1f}% | GPU Mem {gpu_mem:.4f}GB '
                output_string += f'| {args.tasks} Loss {np.mean(m_loss):.4f}'
                logger.info(output_string)

        total_epoch_time = time.time() - start_epoch
        logger.info(' training epoch: {} took {:.4f}s'.format(epoch, total_epoch_time))

        val_ap, val_auc, val_acc, val_loss = eval_epoch(args, logger, g, val_loader, encoder, decoder,
                                                        msg2mail, loss_fcn, device, num_val_samples)
        logger.info('Val {} Task | ap: {:.4f} | auc: {:.4f} | acc: {:.4f} | Loss: {:.4f}'.format(
            args.tasks, val_ap, val_auc, val_acc, val_loss))

        if args.warmup:
            scheduler_warmup.step(epoch)
        else:
            scheduler_lr.step()

        early_stopper_metric = val_ap if 'LP' in args.tasks else val_auc

        if early_stopper.early_stop_check(early_stopper_metric):
            logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round))
            logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
            encoder.load_state_dict(torch.load(MODEL_SAVE_PATH + get_model_name('Encoder')))
            decoder.load_state_dict(torch.load(MODEL_SAVE_PATH + get_model_name('Decoder')))
            test_result = [early_stopper.best_ap, early_stopper.best_auc, early_stopper.best_acc, early_stopper.best_loss]
            break

        test_ap, test_auc, test_acc, test_loss = eval_epoch(args, logger, g, test_loader, encoder, decoder,
                                                            msg2mail, loss_fcn, device, num_test_samples)
        logger.info('Test {} Task | ap: {:.4f} | auc: {:.4f} | acc: {:.4f} | Loss: {:.4f}'.format(
            args.tasks, test_ap, test_auc, test_acc, test_loss))
        test_result = [test_ap, test_auc, test_acc, test_loss]

        if early_stopper.best_epoch == epoch:
            early_stopper.best_ap = test_ap
            early_stopper.best_auc = test_auc
            early_stopper.best_acc = test_acc
            early_stopper.best_loss = test_loss
            logger.info(f'Saving the best model at epoch {early_stopper.best_epoch}')
            torch.save(encoder.state_dict(), MODEL_SAVE_PATH + get_model_name('Encoder'))
            torch.save(decoder.state_dict(), MODEL_SAVE_PATH + get_model_name('Decoder'))
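
# Standalone sketch of the per-epoch node-state reset performed at the top of the
# epoch loop above, on a tiny DGL graph. The feature width (4) and mailbox size (3)
# are arbitrary stand-ins for nfeat_dim and args.n_mail.
if __name__ == '__main__':
    import dgl
    import torch

    g_demo = dgl.graph(([0, 1, 2], [1, 2, 0]))  # 3 nodes, 3 edges
    nfeat_dim_demo, n_mail_demo = 4, 3
    g_demo.ndata['mail'] = torch.zeros((g_demo.num_nodes(), n_mail_demo, nfeat_dim_demo + 2), dtype=torch.float32)
    g_demo.ndata['feat'] = torch.zeros((g_demo.num_nodes(), nfeat_dim_demo), dtype=torch.float32)
    g_demo.ndata['last_update'] = torch.zeros((g_demo.num_nodes(),), dtype=torch.float32)
    print(g_demo.ndata['mail'].shape, g_demo.ndata['feat'].shape, g_demo.ndata['last_update'].shape)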
params_to_finetune = list(im_im_scorer_model.parameters())
models_to_save.append(im_im_scorer_model)

# optimizer
optfunc = {
    'adam': optim.Adam,
    'rmsprop': optim.RMSprop,
    'sgd': optim.SGD
}[args.optimizer]
pretrain_optimizer = optfunc(params_to_pretrain, lr=args.pt_lr)
finetune_optimizer = optfunc(params_to_finetune, lr=args.ft_lr)
# models_to_save.append(optimizer)

after_scheduler = optim.lr_scheduler.StepLR(pretrain_optimizer, 4000, 0.5)
# after_scheduler = optim.lr_scheduler.CosineAnnealingLR(pretrain_optimizer, T_max=500 * args.pt_epochs - 1000)
pt_scheduler = GradualWarmupScheduler(pretrain_optimizer, 1.0, total_epoch=1000, after_scheduler=after_scheduler)
ft_scheduler = optim.lr_scheduler.StepLR(finetune_optimizer, 80 * args.ft_epochs, 0.5)

print(sum([p.numel() for p in params_to_pretrain]))
print(sum([p.numel() for p in params_to_finetune]))

if args.load_checkpoint and os.path.exists(os.path.join(args.exp_dir, 'checkpoint.pth.tar')):
    ckpt_path = os.path.join(args.exp_dir, 'checkpoint.pth.tar')
    sds = torch.load(ckpt_path, map_location=device)
    for m in models_to_save:
        if (not isinstance(m, TransformerAgg)) and (not isinstance(m, RelationNetAgg)):
            print(m.load_state_dict(sds[repr(m)]))
    print("loaded checkpoint")
class TrainLoop(object):

    def __init__(self, args, models_dict, optimizer_task, source_loader,
                 test_source_loader, target_loader, nadir_slack, alpha,
                 patience, factor, label_smoothing, warmup_its, lr_threshold,
                 verbose=-1, cp_name=None, save_cp=True, checkpoint_path=None,
                 checkpoint_epoch=None, cuda=True, logging=False,
                 ablation='no', train_mode='hv'):

        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                try:
                    os.mkdir(self.checkpoint_path)
                except OSError:
                    pass

        self.save_epoch_fmt_task = os.path.join(
            self.checkpoint_path, 'task' + cp_name) if cp_name else os.path.join(
                self.checkpoint_path, 'task_checkpoint_{}ep.pt')
        self.save_epoch_fmt_domain = os.path.join(
            self.checkpoint_path, 'Domain_{}' + cp_name) if cp_name else os.path.join(
                self.checkpoint_path, 'Domain_{}.pt')

        self.cuda_mode = cuda
        self.feature_extractor = models_dict['feature_extractor']
        self.task_classifier = models_dict['task_classifier']
        self.domain_discriminator_list = models_dict['domain_discriminator_list']
        self.optimizer_task = optimizer_task
        self.source_loader = source_loader
        self.test_source_loader = test_source_loader
        self.target_loader = target_loader
        self.history = {
            'loss_task': [],
            'loss_total': [],
            'loss_domain': [],
            'accuracy_source': [],
            'accuracy_target': [],
            'loss_task_val_source': []
        }
        self.cur_epoch = 0
        self.total_iter = 0
        self.nadir_slack = nadir_slack
        self.alpha = alpha
        self.ablation = ablation
        self.train_mode = train_mode
        self.device = next(self.feature_extractor.parameters()).device
        self.args = args

        its_per_epoch = (len(source_loader.dataset) // source_loader.batch_size + 1
                         if len(source_loader.dataset) % source_loader.batch_size > 0
                         else len(source_loader.dataset) // source_loader.batch_size)
        patience = patience * (1 + its_per_epoch)

        self.after_scheduler_task = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer_task,
            factor=factor,
            patience=patience,
            verbose=True if verbose > 0 else False,
            threshold=lr_threshold,
            min_lr=1e-7)
        self.after_scheduler_disc_list = [
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                disc.optimizer,
                factor=factor,
                patience=patience,
                verbose=True if verbose > 0 else False,
                threshold=lr_threshold,
                min_lr=1e-7) for disc in self.domain_discriminator_list
        ]
        self.verbose = verbose
        self.save_cp = save_cp

        self.scheduler_task = GradualWarmupScheduler(
            self.optimizer_task,
            total_epoch=warmup_its,
            after_scheduler=self.after_scheduler_task)
        self.scheduler_disc_list = [
            GradualWarmupScheduler(self.domain_discriminator_list[i].optimizer,
                                   total_epoch=warmup_its,
                                   after_scheduler=sch_disc)
            for i, sch_disc in enumerate(self.after_scheduler_disc_list)
        ]

        if checkpoint_epoch is not None:
            self.load_checkpoint(checkpoint_epoch)

        self.logging = logging
        if self.logging:
            from torch.utils.tensorboard import SummaryWriter
            log_path = args.checkpoint_path + 'runs/'
            self.writer = SummaryWriter(log_path)

        if label_smoothing > 0.0:
            self.ce_criterion = LabelSmoothingLoss(label_smoothing, lbl_set_size=7)
        else:
            self.ce_criterion = torch.nn.CrossEntropyLoss()

    def train(self, n_epochs=1, save_every=1):

        while self.cur_epoch < n_epochs:
            print('Epoch {}/{}'.format(self.cur_epoch + 1, n_epochs))

            self.cur_loss_task = 0
            self.cur_hypervolume = 0
            self.cur_loss_total = 0
            source_iter = tqdm(enumerate(self.source_loader),
                               total=len(self.source_loader),
                               disable=False)

            for t, batch in source_iter:
                if self.ablation == 'all':
                    cur_losses = self.train_step_ablation_all(batch)
                else:
                    cur_losses = self.train_step(batch)

                self.scheduler_task.step(
                    epoch=self.total_iter,
                    metrics=1. - self.history['accuracy_source'][-1]
                    if self.cur_epoch > 0 else np.inf)
                for sched in self.scheduler_disc_list:
                    sched.step(
                        epoch=self.total_iter,
                        metrics=1. - self.history['accuracy_source'][-1]
                        if self.cur_epoch > 0 else np.inf)

                self.cur_loss_task += cur_losses[0]
                self.cur_hypervolume += cur_losses[1]
                self.cur_loss_total += cur_losses[2]
                self.total_iter += 1

                if self.logging:
                    self.writer.add_scalar('Iteration/task_loss', cur_losses[0], self.total_iter)
                    self.writer.add_scalar('Iteration/domain_loss', cur_losses[1], self.total_iter)
                    self.writer.add_scalar('Iteration/total_loss', cur_losses[2], self.total_iter)

            self.history['loss_task'].append(self.cur_loss_task / (t + 1))
            self.history['loss_domain'].append(self.cur_hypervolume / (t + 1))
            self.history['loss_total'].append(self.cur_loss_total / (t + 1))

            acc_source, loss_task_val_source = test(
                self.test_source_loader,
                self.feature_extractor,
                self.task_classifier,
                self.domain_discriminator_list,
                self.device,
                source_target='source',
                epoch=self.cur_epoch,
                tb_writer=self.writer if self.logging else None)
            acc_target, loss_task_target = test(
                self.target_loader,
                self.feature_extractor,
                self.task_classifier,
                self.domain_discriminator_list,
                self.device,
                source_target='target',
                epoch=self.cur_epoch,
                tb_writer=self.writer if self.logging else None)

            self.history['accuracy_source'].append(acc_source)
            self.history['accuracy_target'].append(acc_target)
            self.history['loss_task_val_source'].append(loss_task_val_source)

            self.source_epoch_best_loss_task = np.argmin(self.history['loss_task'])
            self.source_epoch_best_loss_domain = np.argmin(self.history['loss_domain'])
            self.source_epoch_best_loss_total = np.argmin(self.history['loss_total'])
            self.source_epoch_best_loss_task_val = np.argmin(self.history['loss_task_val_source'])
            self.source_epoch_best = np.argmax(self.history['accuracy_source'])
            self.target_epoch_best = np.argmax(self.history['accuracy_target'])

            self.source_best_acc = np.max(self.history['accuracy_source'])
            self.target_best_loss_task = self.history['accuracy_target'][self.source_epoch_best_loss_task]
            self.target_best_loss_domain = self.history['accuracy_target'][self.source_epoch_best_loss_domain]
            self.target_best_loss_total = self.history['accuracy_target'][self.source_epoch_best_loss_total]
            self.target_best_source_acc = self.history['accuracy_target'][self.source_epoch_best]
            self.target_best_acc = np.max(self.history['accuracy_target'])
            self.target_best_acc_loss_task_val = self.history['accuracy_target'][self.source_epoch_best_loss_task_val]

            self.print_results()

            if self.logging:
                self.writer.add_scalar('misc/LR-task',
                                       self.optimizer_task.param_groups[0]['lr'],
                                       self.total_iter)
                for i, disc in enumerate(self.domain_discriminator_list):
                    self.writer.add_scalar('misc/LR-disc{}'.format(i),
                                           disc.optimizer.param_groups[0]['lr'],
                                           self.total_iter)
                self.writer.add_scalar('Epoch/Loss-total', self.history['loss_total'][-1], self.cur_epoch)
                self.writer.add_scalar('Epoch/Loss-task', self.history['loss_task'][-1], self.cur_epoch)
                self.writer.add_scalar('Epoch/Loss-domain', self.history['loss_domain'][-1], self.cur_epoch)
                self.writer.add_scalar('Epoch/Loss-task-val', self.history['loss_task_val_source'][-1], self.cur_epoch)
                self.writer.add_scalar('Epoch/Acc-Source', self.history['accuracy_source'][-1], self.cur_epoch)
                self.writer.add_scalar('Epoch/Acc-target', self.history['accuracy_target'][-1], self.cur_epoch)

            self.cur_epoch += 1

            if self.save_cp and (self.cur_epoch % save_every == 0
                                 or self.history['accuracy_target'][-1] >
                                 np.max([-np.inf] + self.history['accuracy_target'][:-1])):
                self.checkpointing()

        if self.logging:
            self.writer.close()

        results_acc = [
            self.target_best_loss_task, self.target_best_loss_domain,
            self.target_best_loss_total, self.target_best_source_acc,
            self.source_best_acc, self.target_best_acc
        ]
        results_epochs = [
            self.source_epoch_best_loss_task, self.source_epoch_best_loss_domain,
            self.source_epoch_best_loss_total, self.source_epoch_best,
            self.target_epoch_best
        ]

        return np.min(self.history['loss_task_val_source']), results_acc, results_epochs

    def train_step(self, batch):
        self.feature_extractor.train()
        self.task_classifier.train()
        for disc in self.domain_discriminator_list:
            disc = disc.train()

        x_1, x_2, x_3, y_task_1, y_task_2, y_task_3, y_domain_1, y_domain_2, y_domain_3 = batch
        x = torch.cat((x_1, x_2, x_3), dim=0)
        y_task = torch.cat((y_task_1, y_task_2, y_task_3), dim=0)
        y_domain = torch.cat((y_domain_1, y_domain_2, y_domain_3), dim=0)

        if self.cuda_mode:
            x = x.to(self.device)
            y_task = y_task.to(self.device)

        # COMPUTING FEATURES
        features = self.feature_extractor.forward(x)
        features_ = features.detach()

        # DOMAIN DISCRIMINATORS
        for i, disc in enumerate(self.domain_discriminator_list):
            y_predict = disc.forward(features_).squeeze()
            curr_y_domain = torch.where(y_domain == i,
                                        torch.ones(y_domain.size(0)),
                                        torch.zeros(y_domain.size(0)))
            if self.cuda_mode:
                curr_y_domain = curr_y_domain.float().to(self.device)

            loss_domain_discriminator = F.binary_cross_entropy_with_logits(y_predict, curr_y_domain)

            if self.logging:
                self.writer.add_scalar('train/D{}_loss'.format(i),
                                       loss_domain_discriminator, self.total_iter)

            disc.optimizer.zero_grad()
            loss_domain_discriminator.backward()
            disc.optimizer.step()

        # UPDATE TASK CLASSIFIER AND FEATURE EXTRACTOR
        task_out = self.task_classifier.forward(features)

        loss_domain_disc_list = []
        loss_domain_disc_list_float = []
        for i, disc in enumerate(self.domain_discriminator_list):
            y_predict = disc.forward(features).squeeze()
            curr_y_domain = torch.where(y_domain == i,
                                        torch.zeros(y_domain.size(0)),
                                        torch.ones(y_domain.size(0)))
            if self.cuda_mode:
                curr_y_domain = curr_y_domain.float().to(self.device)

            loss_domain_disc_list.append(F.binary_cross_entropy_with_logits(y_predict, curr_y_domain))
            loss_domain_disc_list_float.append(loss_domain_disc_list[-1].detach().item())

        if self.train_mode == 'hv':
            self.update_nadir_point(loss_domain_disc_list_float)

        hypervolume = 0
        for loss in loss_domain_disc_list:
            if self.train_mode == 'hv':
                hypervolume -= torch.log(self.nadir - loss + 1e-6)
            elif self.train_mode == 'avg':
                hypervolume -= loss

        task_loss = self.ce_criterion(task_out, y_task)
        loss_total = self.alpha * task_loss + (1 - self.alpha) * hypervolume / len(loss_domain_disc_list)

        self.optimizer_task.zero_grad()
        loss_total.backward()
        self.optimizer_task.step()

        losses_return = task_loss.item(), hypervolume.item(), loss_total.item()
        return losses_return

    def train_step_ablation_all(self, batch):
        self.feature_extractor.train()
        self.task_classifier.train()
        for disc in self.domain_discriminator_list:
            disc = disc.train()

        x_1, x_2, x_3, y_task_1, y_task_2, y_task_3, _, _, _ = batch
        x = torch.cat((x_1, x_2, x_3), dim=0)
        y_task = torch.cat((y_task_1, y_task_2, y_task_3), dim=0)

        if self.cuda_mode:
            x = x.to(self.device)
            y_task = y_task.to(self.device)

        # COMPUTING FEATURES
        features = self.feature_extractor.forward(x)
        task_out = self.task_classifier.forward(features)
        task_loss = torch.nn.CrossEntropyLoss()(task_out, y_task)

        self.optimizer_task.zero_grad()
        task_loss.backward()
        self.optimizer_task.step()

        # Return three values so the training loop can index losses consistently
        # with train_step() (task loss, domain term, total loss).
        losses_return = task_loss.item(), 0., task_loss.item()
        return losses_return

    def checkpointing(self):
        if self.verbose > 0:
            print(' ')
            print('Checkpointing...')

        ckpt = {
            'feature_extractor_state': self.feature_extractor.state_dict(),
            'task_classifier_state': self.task_classifier.state_dict(),
            'optimizer_task_state': self.optimizer_task.state_dict(),
            'scheduler_task_state': self.scheduler_task.state_dict(),
            'history': self.history,
            'cur_epoch': self.cur_epoch
        }
        torch.save(ckpt, self.save_epoch_fmt_task.format(self.cur_epoch))

        for i, disc in enumerate(self.domain_discriminator_list):
            ckpt = {
                'model_state': disc.state_dict(),
                'optimizer_disc_state': disc.optimizer.state_dict(),
                'scheduler_disc_state': self.scheduler_disc_list[i].state_dict()
            }
            torch.save(ckpt, self.save_epoch_fmt_domain.format(i + 1))

    def load_checkpoint(self, epoch):
        ckpt = self.save_epoch_fmt_task.format(epoch)

        if os.path.isfile(ckpt):
            ckpt = torch.load(ckpt)
            # Load model state
            self.feature_extractor.load_state_dict(ckpt['feature_extractor_state'])
            self.task_classifier.load_state_dict(ckpt['task_classifier_state'])
            # Load optimizer state
            self.optimizer_task.load_state_dict(ckpt['optimizer_task_state'])
            # Load scheduler state
            self.scheduler_task.load_state_dict(ckpt['scheduler_task_state'])
            # Load history
            self.history = ckpt['history']
            self.cur_epoch = ckpt['cur_epoch']

            for i, disc in enumerate(self.domain_discriminator_list):
                ckpt = torch.load(self.save_epoch_fmt_domain.format(i + 1))
                disc.load_state_dict(ckpt['model_state'])
                disc.optimizer.load_state_dict(ckpt['optimizer_disc_state'])
                self.scheduler_disc_list[i].load_state_dict(ckpt['scheduler_disc_state'])
        else:
            print('No checkpoint found at: {}'.format(ckpt))

    def print_grad_norms(self, model):
        norm = 0.0
        for params in list(filter(lambda p: p.grad is not None, model.parameters())):
            norm += params.grad.norm(2).item()
        print('Sum of grads norms: {}'.format(norm))

    def update_nadir_point(self, losses_list):
        self.nadir = float(np.max(losses_list) * self.nadir_slack + 1e-8)

    def print_results(self):
        print('Current task loss: {}.'.format(self.history['loss_task'][-1]))
        print('Current hypervolume: {}.'.format(self.history['loss_domain'][-1]))
        print('Current total loss: {}.'.format(self.history['loss_total'][-1]))
        print('VALIDATION ON SOURCE DOMAINS - {}, {}, {}'.format(
            self.args.source1, self.args.source2, self.args.source3))
        print('Current, best, and epoch: {:0.4f}, {:0.4f}, {}'.format(
            self.history['accuracy_source'][-1], self.source_best_acc, self.source_epoch_best))
        print('VALIDATION ON TARGET DOMAIN - {}'.format(self.args.target))
        print('Current, best, and epoch: {:0.4f}, {:0.4f}, {}'.format(
            self.history['accuracy_target'][-1], self.target_best_acc, self.target_epoch_best))
        print('VALIDATION ON TARGET DOMAIN - BEST TOTAL LOSS')
        print('Best and epoch: {:0.4f}, {}'.format(self.target_best_loss_total, self.source_epoch_best_loss_total))
        print('VALIDATION ON TARGET DOMAIN - BEST SOURCE VAL ACC')
        print('Best and epoch: {:0.4f}, {}'.format(self.target_best_source_acc, self.source_epoch_best))
        print('VALIDATION ON TARGET DOMAIN - BEST TASK LOSS')
        print('Best and epoch: {:0.4f}, {}'.format(self.target_best_loss_task, self.source_epoch_best_loss_task))
        print('VALIDATION ON TARGET DOMAIN - BEST DOMAIN DISC LOSS')
        print('Best and epoch: {:0.4f}, {}'.format(self.target_best_loss_domain, self.source_epoch_best_loss_domain))
        print('VALIDATION ON TARGET DOMAIN - BEST VAL TASK LOSS')
        print('Best and epoch: {:0.4f}, {}'.format(self.target_best_acc_loss_task_val, self.source_epoch_best_loss_task_val))
import torch

from utils import GradualWarmupScheduler

if __name__ == '__main__':
    v = torch.zeros(10)
    optim = torch.optim.SGD([v], lr=0.01)
    scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=10)
    for epoch in range(1, 20):
        scheduler.step(epoch)
        print(epoch, optim.param_groups[0]['lr'])
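
# Variant of the check above: warmup chained into a cosine schedule via
# `after_scheduler`, mirroring how GradualWarmupScheduler is used in the training
# scripts in this repo. The learning rate, multiplier, and epoch counts are arbitrary.
if __name__ == '__main__':
    v2 = torch.zeros(10)
    optim2 = torch.optim.SGD([v2], lr=0.01)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, T_max=10)
    scheduler2 = GradualWarmupScheduler(optim2, multiplier=8, total_epoch=5, after_scheduler=cosine)
    for epoch in range(1, 20):
        scheduler2.step(epoch)
        print(epoch, optim2.param_groups[0]['lr'])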
def run():
    seed_torch(seed=config.seed)
    os.makedirs(config.MODEL_PATH, exist_ok=True)
    setup_logger(config.MODEL_PATH + 'log.txt')
    writer = SummaryWriter(config.MODEL_PATH)

    folds = pd.read_csv(config.fold_csv)
    folds.head()

    if config.tile_stats_csv:
        attention_df = pd.read_csv(config.tile_stats_csv)
        attention_df.head()
    else:
        # No tile-statistics file configured; the datasets below receive None.
        attention_df = None

    # train val split
    if config.DEBUG:
        folds = folds.sample(n=50, random_state=config.seed).reset_index(drop=True).copy()

    logging.info(f"fold: {config.fold}")
    fold = config.fold
    # trn_idx = folds[folds['fold'] != fold].index
    # val_idx = folds[folds['fold'] == fold].index
    trn_idx = folds[folds[f'fold_{fold}'] == 0].index
    val_idx = folds[folds[f'fold_{fold}'] == 1].index
    df_train = folds.loc[trn_idx]
    df_val = folds.loc[val_idx]

    # ------single image------
    if config.strategy == 'stitched':
        train_dataset = PANDADataset(image_folder=config.DATA_PATH,
                                     df=df_train,
                                     image_size=config.IMG_SIZE,
                                     num_tiles=config.num_tiles,
                                     rand=False,
                                     transform=get_transforms(phase='train'),
                                     attention_df=attention_df)
        valid_dataset = PANDADataset(image_folder=config.DATA_PATH,
                                     df=df_val,
                                     image_size=config.IMG_SIZE,
                                     num_tiles=config.num_tiles,
                                     rand=False,
                                     transform=get_transforms(phase='valid'),
                                     attention_df=attention_df)
    # ------image tiles------
    else:
        train_dataset = PANDADatasetTiles(image_folder=config.DATA_PATH,
                                          df=df_train,
                                          image_size=config.IMG_SIZE,
                                          num_tiles=config.num_tiles,
                                          transform=get_transforms(phase='train'),
                                          attention_df=attention_df)
        valid_dataset = PANDADatasetTiles(image_folder=config.DATA_PATH,
                                          df=df_val,
                                          image_size=config.IMG_SIZE,
                                          num_tiles=config.num_tiles,
                                          transform=get_transforms(phase='valid'),
                                          attention_df=attention_df)

    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              sampler=RandomSampler(train_dataset),
                              num_workers=multiprocessing.cpu_count(),
                              pin_memory=True)
    val_loader = DataLoader(valid_dataset,
                            batch_size=config.batch_size,
                            sampler=SequentialSampler(valid_dataset),
                            num_workers=multiprocessing.cpu_count(),
                            pin_memory=True)

    device = torch.device("cuda")

    # model = EnetNetVLAD(num_clusters=config.num_cluster, num_tiles=config.num_tiles,
    #                     num_classes=config.num_class, arch=config.backbone)
    # model = EnetV1(backbone=config.backbone, num_classes=config.num_class)
    # ------Model use for Generate Tile Weights--------
    # model = EfficientModel(c_out=config.num_class, n_tiles=config.num_tiles,
    #                        tile_size=config.IMG_SIZE,
    #                        name=config.backbone,
    #                        strategy='bag',
    #                        head='attention')
    # --------------------------------------------------
    # model = Regnet(num_classes=config.num_class, ckpt=config.pretrain_model)
    model = RegnetNetVLAD(num_clusters=config.num_cluster,
                          num_tiles=config.num_tiles,
                          num_classes=config.num_class,
                          ckpt=config.pretrain_model)
    model = model.to(device)
    if config.multi_gpu:
        model = torch.nn.DataParallel(model)
    if config.ckpt_path:
        model.load_state_dict(torch.load(config.ckpt_path))

    warmup_factor = 10
    warmup_epo = 1
    optimizer = Adam(model.parameters(), lr=config.lr / warmup_factor)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.num_epoch - warmup_epo)
    scheduler = GradualWarmupScheduler(optimizer,
                                       multiplier=warmup_factor,
                                       total_epoch=warmup_epo,
                                       after_scheduler=scheduler_cosine)

    if config.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)

    best_score = 0.
    best_loss = 100.
    if config.model_type == 'reg':
        optimized_rounder = OptimizedRounder()

    optimizer.zero_grad()
    optimizer.step()

    for epoch in range(1, config.num_epoch + 1):
        if scheduler:
            scheduler.step(epoch - 1)

        if config.model_type != 'reg':
            train_fn(train_loader, model, optimizer, device, epoch, writer, df_train)
            metric = eval_fn(val_loader, model, device, epoch, writer, df_val)
        else:
            coefficients = train_fn(train_loader, model, optimizer, device, epoch, writer,
                                    df_train, optimized_rounder)
            metric = eval_fn(val_loader, model, device, epoch, writer, df_val, coefficients)

        score = metric['score']
        val_loss = metric['loss']

        if score > best_score:
            best_score = score
            logging.info(f"Epoch {epoch} - found best score {best_score}")
            save_model(model, config.MODEL_PATH + f"best_kappa_f{fold}.pth")
        if val_loss < best_loss:
            best_loss = val_loss
            logging.info(f"Epoch {epoch} - found best loss {best_loss}")
            save_model(model, config.MODEL_PATH + f"best_loss_f{fold}.pth")
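
# A minimal entry-point sketch for the script above; the repo may already define
# its own __main__ guard elsewhere, in which case this is redundant.
if __name__ == '__main__':
    run()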