def train_network(dataloader, model, loss_function, optimizer, start_lr, end_lr, num_epochs=90, sanity_check=False):
    """Train the network and checkpoint it on several performance criteria.

    A checkpoint is saved (via ``save_model``) whenever:

    * the validation loss reaches a new minimum (``_minval``);
    * the validation AUC reaches a new maximum (``_maxauc``);
    * the validation F1-score reaches a new maximum (``_maxf1``);
    * the training epoch loss reaches a new minimum (``_mintrain``).

    Learning-rate schedule:

    * LR-search mode (``kconfig.tr.lr_search_flag``): the LR is multiplied by
      10 every 2 epochs (MultiStepLR, https://arxiv.org/abs/1803.09820).
    * Otherwise: a one-cycle policy runs over ``num_epochs``, and the LR is
      then lowered step-wise for the configured extra epochs.

    Parameters:
    -----------
    dataloader (dict): {key (str): Value(torch.utils.data.DataLoader)}
        Must contain 'train'. It must also contain 'val' unless an LR search
        or sanity check is run. An optional 'test' loader is evaluated
        whenever the validation AUC or F1-score improves. Each batch is
        ``(X, (class_labels, source_labels))``.
    model (nn.Module or dict): a single model, or a dict with keys
        'feature_repr_model', 'cls_model' and 'src_model' for
        domain-adversarial training.
    loss_function (nn.Module or dict): a single loss, called as
        ``loss_function(output, (class_labels, source_labels))``, or a dict
        with keys 'cls_loss' and 'src_loss' (domain-adversarial training).
    optimizer (Optimizer or dict): a single optimizer, or a dict with keys
        'feature_repr_opt', 'cls_opt' and 'src_opt'.
    start_lr (float): start learning rate of the one-cycle schedule.
    end_lr (float): peak learning rate of the one-cycle schedule. It must be
        greater than ``start_lr``.
    num_epochs (int): number of epochs spanned by the one cycle. Extra epochs
        from the config are added on top.
    sanity_check (bool): if True, train repeatedly on a single batch to
        answer "is the model able to overfit a small amount of data?".
        Validation and metric-based checkpointing are skipped.

    Returns:
    --------
    None: the trained models are persisted through ``save_model``.

    Raises:
    -------
    TypeError: if ``model``, ``loss_function`` and ``optimizer`` are not all
        dicts or all single objects.
    RuntimeError: if the training dataloader yields no batch in an epoch.
    """
    # The adversarial branch indexes all three as dicts and the plain branch
    # calls all three directly. A mix can never train, so fail fast.
    adversarial = isinstance(model, dict)
    if isinstance(loss_function, dict) != adversarial or isinstance(optimizer, dict) != adversarial:
        raise TypeError('model, loss_function and optimizer must either all be dicts '
                        '(domain-adversarial training) or all be single objects')

    if adversarial:
        for key, sub_model in model.items():
            model[key] = sub_model.train()
    else:
        model = model.train()

    # Flat list of optimizers, so the LR/grad helpers below work in both modes.
    optimizers = list(optimizer.values()) if adversarial else [optimizer]

    logger_msg = '\nDataLoader = {}' \
                 '\nModel = {}' \
                 '\nLossFucntion = {}' \
                 '\nOptimizer = {}' \
                 '\nStartLR = {}, EndLR = {}' \
                 '\nNumEpochs = {}'.format(dataloader, model, loss_function, optimizer, start_lr, end_lr, num_epochs)
    logger.info(logger_msg)
    print(logger_msg)

    # [https://arxiv.org/abs/1803.09820]
    # LR finder: grow the LR 10x at epochs 2, 4, ..., 22 to locate a good
    # range for one-cycle training.
    lr_schedulers = []
    if kconfig.tr.lr_search_flag:
        milestones = list(np.arange(2, 24, 2))
        lr_schedulers = [MultiStepLR(optimizer=opt, milestones=milestones, gamma=10, last_epoch=-1)
                         for opt in optimizers]

    # TODO: Cyclic momentum (optimizer.param_groups[0]['momentum']).
    # Weight decay: 0.95 -> 0.8, i.e. as the LR increases during 1Cycle, WD
    # should decrease (https://forums.fast.ai/t/one-cycle-policy/25944/2).
    # The large batch training literature recommends not applying WD to BN.

    def get_lr():
        """Return the rounded LRs.

        Dict mode: the first param group of each optimizer.
        Single-optimizer mode: every param group.
        """
        if adversarial:
            groups = [opt.param_groups[0] for opt in optimizers]
        else:
            groups = optimizer.param_groups
        return [np.round(group['lr'], 11) for group in groups]

    def set_lr(lr):
        """Set ``lr`` on every param group of every optimizer."""
        for opt in optimizers:
            for param_group in opt.param_groups:
                param_group['lr'] = lr

    def set_momentum(m):
        """Set ``momentum`` on every param group (reserved for the cyclic-momentum TODO)."""
        for opt in optimizers:
            for param_group in opt.param_groups:
                param_group['momentum'] = m

    def reset_grad():
        """Zero the gradients of every optimizer."""
        for opt in optimizers:
            opt.zero_grad()

    # 'Training': loss containers
    train_cur_epoch_batchwise_loss = []
    train_epoch_avg_loss_container = []  # Loss of each epoch, averaged over its batches.
    train_all_epoch_batchwise_loss = []

    # 'Validation': loss and metric containers
    val_avg_loss_container = []
    val_report_container = []
    val_f1_container = []
    val_auc_container = []
    val_accuracy_container = []

    # 'Test': metric containers keyed by epoch. They are filled only when a
    # labelled test set is present and a validation metric improved.
    test_auc_container = {}
    test_f1_container = {}
    test_accuracy_container = {}

    # 'Extra' epochs run after the one cycle.
    if kconfig.tr.lr_search_flag:
        extra_epochs = kconfig.one_cycle_policy.extra_epochs.lr_search
    else:
        extra_epochs = kconfig.one_cycle_policy.extra_epochs.train
    total_epochs = num_epochs + extra_epochs

    # One-cycle settings of the learning rate.
    num_steps_upndown = kconfig.one_cycle_policy.num_steps_upndown
    further_lowering_factor = kconfig.one_cycle_policy.extra_epochs.lowering_factor
    further_lowering_factor_steps = kconfig.one_cycle_policy.extra_epochs.lower_after

    def one_cycle_lr_setter(current_epoch):
        """Set the LR for ``current_epoch`` following the one-cycle policy.

        Up to ``num_epochs``:
        * LR ramps linearly from start_lr towards end_lr over roughly the
          first half of ``num_epochs``;
        * it then ramps back down.

        After ``num_epochs``:
        * LR is start_lr divided by ``further_lowering_factor`` once per
          ``further_lowering_factor_steps`` extra epochs.
        """
        start_momentum = 0.95
        end_momentum = 0.85
        if current_epoch <= num_epochs:
            assert end_lr > start_lr, '[EndLR] should be greater than [StartLR]'
            lr_inc_rate = np.round((end_lr - start_lr) / num_steps_upndown, 9)
            lr_inc_epoch_step_len = max(num_epochs / (2 * num_steps_upndown), 1)
            steps_completed = current_epoch / lr_inc_epoch_step_len
            print('[Steps Completed] = ', steps_completed)
            if steps_completed <= num_steps_upndown:
                current_lr = start_lr + (steps_completed * lr_inc_rate)
                current_momentum = start_momentum - ((start_momentum - end_momentum) * int(steps_completed) / num_steps_upndown)
            else:
                current_lr = end_lr - ((steps_completed - num_steps_upndown) * lr_inc_rate)
                current_momentum = end_momentum + ((start_momentum - end_momentum) * int(steps_completed - num_steps_upndown) / num_steps_upndown)
            set_lr(current_lr)
            # Cyclic momentum is computed but intentionally not applied yet:
            # set_momentum(current_momentum)
        else:
            current_lr = start_lr / (
                further_lowering_factor ** ((current_epoch - num_epochs) // further_lowering_factor_steps))
            set_lr(current_lr)

    if sanity_check:
        # Repeat a single batch so the model can be checked for overfitting.
        train_dataloader = [next(iter(dataloader['train']))] * 128
    else:
        train_dataloader = dataloader['train']

    run_validation = not (kconfig.tr.lr_search_flag or sanity_check)

    # Model training for 'total_epochs'.
    # The first epoch runs at the optimizers' initial LR; the schedule is
    # applied at the end of each epoch.
    for epoch in range(total_epochs):
        msg = '\n\n\n[Epoch] = {}'.format(epoch + 1)
        print(msg)
        start_time = time.time()
        start_datetime = datetime.now()

        for i, (X, y) in enumerate(train_dataloader):
            # Debug aid: cut the epoch short to reach validation quickly.
            if kconfig.tr.early_break and i == 3:
                print('[Break] by force for validation check')
                break

            y_cls_lbs, y_src_lbs = y
            X = X.to(device=device, dtype=torch.float32)
            y_cls_lbs = y_cls_lbs.to(device=device, dtype=torch.long)
            y_src_lbs = y_src_lbs.to(device=device, dtype=torch.long)
            y = (y_cls_lbs, y_src_lbs)

            if adversarial:
                # Domain-Adversarial Training of Neural Networks: https://arxiv.org/abs/1505.07818
                # Step 1: step the feature optimizer on cls_loss - src_loss.
                # This lowers the class loss while raising the source loss.
                features_repr = model['feature_repr_model'](X)
                cls_output = model['cls_model'](features_repr)
                src_output = model['src_model'](features_repr)
                reset_grad()
                cls_loss = loss_function['cls_loss'](cls_output, y_cls_lbs)
                src_loss = loss_function['src_loss'](src_output, y_src_lbs)
                feature_loss = cls_loss - src_loss
                feature_loss.backward()
                optimizer['feature_repr_opt'].step()

                # Step 2: train both heads on detached features. The gradients
                # left over from step 1 are discarded first.
                reset_grad()
                features_repr = features_repr.detach()
                cls_output = model['cls_model'](features_repr)
                src_output = model['src_model'](features_repr)
                cls_loss = loss_function['cls_loss'](cls_output, y_cls_lbs)
                cls_loss.backward()
                optimizer['cls_opt'].step()
                src_loss = loss_function['src_loss'](src_output, y_src_lbs)
                src_loss.backward()
                optimizer['src_opt'].step()

                loss = feature_loss + cls_loss + src_loss
            else:
                output = model(X)
                reset_grad()
                loss = loss_function(output, y)
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            train_cur_epoch_batchwise_loss.append(batch_loss)
            train_all_epoch_batchwise_loss.append(batch_loss)
            batch_run_msg = '\nEpoch: [%s/%s], Step: [%s/%s], InitialLR: %s, CurrentLR: %s, Loss: %s' \
                            % (epoch + 1, total_epochs, i + 1, len(train_dataloader), start_lr, get_lr(), batch_loss)
            print(batch_run_msg)

        # ------------------ End of an Epoch ------------------
        # Average over the batches actually trained on. An early break skips
        # the breaking batch, so the loop index is not the right divisor.
        num_batches = len(train_cur_epoch_batchwise_loss)
        if num_batches == 0:
            raise RuntimeError('Training dataloader yielded no batches in epoch {}'.format(epoch + 1))
        epoch_avg_loss = np.round(sum(train_cur_epoch_batchwise_loss) / num_batches, 6)
        train_cur_epoch_batchwise_loss = []
        train_epoch_avg_loss_container.append(epoch_avg_loss)

        # 'Validation': compute metrics on the dataset and save checkpoints.
        # Reaching a new best AUC or F1 also triggers a test-set evaluation.
        test_test_data = False
        accuracy = None
        if run_validation:
            val_loss, val_report, f1_checker, auc_val = cal_loss_and_metric(model, dataloader['val'], loss_function, epoch + 1)
            val_report['roc'] = 'Removed'
            val_report_container.append(val_report)
            val_avg_loss_container.append(val_loss)
            val_f1_container.append(f1_checker)
            val_auc_container.append(auc_val)
            accuracy = val_report.get('accuracy', None)
            val_accuracy_container.append(accuracy)

            if np.round(val_loss, 4) <= np.round(min(val_avg_loss_container), 4):
                model = save_model(model, extra_extension='_minval')
            if np.round(auc_val, 4) >= np.round(max(val_auc_container), 4):
                model = save_model(model, extra_extension='_maxauc')
                test_test_data = True
            if np.round(f1_checker, 4) >= np.round(max(val_f1_container), 4):
                model = save_model(model, extra_extension='_maxf1')
                test_test_data = True

        # Save on a new minimum of the training epoch loss.
        if epoch_avg_loss <= min(train_epoch_avg_loss_container):
            model = save_model(model, extra_extension='_mintrain')

        msg = '\n\n\n\n\nEpoch: [%s/%s], InitialLR: %s, CurrentLR= %s \n' \
              '\n\n[Train] Average Epoch-wise Loss = %s \n' \
              '\n\n********************************************************** [Validation]' \
              '\n\n[Validation] Average Epoch-wise loss = %s \n' \
              '\n\n[Validation] Report () = %s \n' \
              '\n\n[Validation] F-Report = %s\n' \
              '\n\n[Validation] Accuracy = %s\n' \
              % (epoch + 1, total_epochs, start_lr, get_lr(), train_epoch_avg_loss_container,
                 val_avg_loss_container,
                 None if not val_report_container else util.pretty(val_report_container[-1]),
                 val_f1_container, val_accuracy_container)
        logger.info(msg)
        print(msg)

        # 'Test': compute metrics on the test dataset, only if it is present.
        if run_validation and test_test_data and dataloader.get('test', False):
            test_loss, test_report, test_f1_checker, test_auc = cal_loss_and_metric(
                model, dataloader['test'], loss_function, epoch + 1, model_type='test_set')
            test_report['roc'] = 'Removed'
            accuracy = test_report.get('accuracy', None)
            test_auc_container[epoch + 1] = "{0:.3f}".format(round(test_auc, 4))
            test_f1_container[epoch + 1] = "{0:.3f}".format(round(test_f1_checker, 4))
            test_accuracy_container[epoch + 1] = "{}".format(accuracy)

            msg = '\n\n\n\n**********************************************************[Test]\n ' \
                  '[Test] Report = {}' \
                  '\n\n[Test] fscore = {}' \
                  '\n\n[Test] AUC dict = {}' \
                  '\n\n[Test] F1-dict = {}' \
                  '\n\n[Test] Accuracy = {}'.format(util.pretty(test_report), test_f1_checker, test_auc_container,
                                                   test_f1_container, test_accuracy_container)
            logger.info(msg)
            print(msg)

        # Stop training if the model has converged or the LR left the usable range.
        current_lr = get_lr()[0]
        if epoch_avg_loss < 1e-6 or current_lr < 1e-11 or current_lr >= 10:
            msg = '\n\nAvg. Loss = {} or Current LR = {} thus stopping training'.format(epoch_avg_loss, get_lr())
            logger.info(msg)
            print(msg)
            break

        # Update the LR for the next epoch.
        if kconfig.tr.lr_search_flag:
            # Step every scheduler, including the single-optimizer one that was
            # previously never stepped. step() is called once per epoch in
            # order, so it matches the deprecated step(epoch + 1).
            for scheduler in lr_schedulers:
                scheduler.step()
        else:
            one_cycle_lr_setter(epoch + 1)

        # Time keeping for the training epoch.
        end_time = time.time()
        end_datetime = datetime.now()
        msg = '\n\n[Time] taken for epoch({}) time = {}, datetime = {} \n\n'.format(
            epoch + 1, end_time - start_time, end_datetime - start_datetime)
        logger.info(msg)
        print(msg)

    # ----------------- End of training process -----------------
    msg = '\n\n[Epoch Loss] = {}'.format(train_epoch_avg_loss_container)
    logger.info(msg)
    print(msg)

    if kconfig.tr.lr_search_flag:
        losses = train_epoch_avg_loss_container
        plot_file_name = 'training_epoch_loss_for_lr_finder.png'
        title = 'Training Epoch Loss'
    else:
        losses = {'train': train_epoch_avg_loss_container, 'val': val_avg_loss_container}
        plot_file_name = 'training_vs_val_epoch_avg_loss.png'
        title = 'Training vs Validation Epoch Loss'

    plot_loss(losses=losses, plot_file_name=plot_file_name, title=title)
    plot_loss(losses=train_all_epoch_batchwise_loss, plot_file_name='training_batchwise.png',
              title='Training Batchwise Loss', xlabel='#Batchwise')

    # Save the final model.
    model = save_model(model)