def train_network(dataloader, model, loss_function, optimizer, start_lr, end_lr, num_epochs=90, sanity_check=False):
	"""Trains the network and saves for different checkpoints such as minimum train/val loss, f1-score, AUC etc. different performance metrics

		Parameters:
		-----------
			dataloader (dict): {key (str):  Value(torch.utils.data.DataLoader)} training and validation dataloader to respective purposes
			model (nn.Module): models to traine the face-recognition
			loss_function (torch.nn.Module): Module to mesure loss between target and model-output
			optimizer (Optimizer): Non vanilla gradient descent method to optimize learning and descent direction
			start_lr (float): For one cycle training the start learning rate
			end_lr (float): the end learning must be greater than start learning rate
			num_epochs (int): number of epochs the one cycle is 
			sanity_check (bool): if the training is perfomed to check the sanity of the model. i.e. to anaswer 'is model is able to overfit for small amount of data?'

		Returns:
		--------
			None: perfoms the required task of training
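
		Example:
		--------
			A minimal sketch of the adversarial (dict-based) setup; the networks, loaders and
			LR values are illustrative placeholders, only the dict keys match what this
			function expects:

			>>> loaders = {'train': train_loader, 'val': val_loader}  # batches of (X, (class_labels, source_labels))
			>>> models = {'feature_repr_model': feat_net, 'cls_model': cls_head, 'src_model': src_head}
			>>> losses = {'cls_loss': nn.CrossEntropyLoss(), 'src_loss': nn.CrossEntropyLoss()}
			>>> optims = {'feature_repr_opt': torch.optim.SGD(feat_net.parameters(), lr=1e-4, momentum=0.9),
			...           'cls_opt': torch.optim.SGD(cls_head.parameters(), lr=1e-4, momentum=0.9),
			...           'src_opt': torch.optim.SGD(src_head.parameters(), lr=1e-4, momentum=0.9)}
			>>> train_network(loaders, models, losses, optims, start_lr=1e-4, end_lr=1e-2, num_epochs=90)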

	"""

	if isinstance(model, dict):
		for k, v in model.items():
			model[k] = v.train()
	else:
		model = model.train()

	logger_msg = '\nDataLoader = {}' \
				 '\nModel = {}' \
				 '\nLossFunction = {}' \
				 '\nOptimizer = {}' \
				 '\nStartLR = {}, EndLR = {}' \
				 '\nNumEpochs = {}'.format(dataloader, model, loss_function, optimizer, start_lr, end_lr, num_epochs)

	logger.info(logger_msg); print(logger_msg)

	# [https://arxiv.org/abs/1803.09820]
	# This is used to find an optimal learning-rate range that can then be used in the one-cycle training policy
	# [LR]TODO: for finding optimal learning rate
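	# The scheduler below multiplies the LR by gamma=10 every 2 epochs (milestones 2, 4, ..., 22),
	# sweeping it across several orders of magnitude. Plotting the resulting epoch losses
	# (see 'training_epoch_loss_for_lr_finder.png' produced at the bottom) shows where the loss
	# starts to drop steeply and where it diverges, which bounds a reasonable [start_lr, end_lr] pair.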

	lr_scheduler = {}
	if kconfig.tr.lr_search_flag:
		if isinstance(optimizer, dict):
			for k, opt in optimizer.items():
				lr_scheduler[k] = MultiStepLR(optimizer=opt, milestones=list(np.arange(2, 24, 2)), gamma=10, last_epoch=-1)
		else:
			lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=list(np.arange(2, 24, 2)), gamma=10, last_epoch=-1)
		

	# TODO: Cyclic momentum
	# optimizer.param_groups[0]['momentum'] # weight_decay
	# 0.95 -> 0.8.
	# This implies that as LR increases during 1Cycle, WD should decrease. 
	# https://forums.fast.ai/t/one-cycle-policy/25944/2
	# The large-batch training literature recommends not applying weight decay to BatchNorm parameters; so if you are wondering what to do: don't apply WD to BN.
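	# Note: the pieces for cyclic momentum are already in place but disabled: one_cycle_lr_setter()
	# below interpolates 'current_momentum' from 0.95 down to 0.85 as the LR rises (and back), and
	# set_momentum() would write it into the optimizer(s); the actual call is still commented out.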


	def get_lr():
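		# Report the current per-param-group learning rates (rounded for display). When
		# 'optimizer' is a dict, only the first optimizer is inspected (hence the 'break'
		# below); set_lr() assigns the same LR to every optimizer, so one is representative.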
		lr = []
		# pdb.set_trace()

		if isinstance(optimizer, dict): 
			for k, opt in optimizer.items():
				for param_group in opt.param_groups:
					lr.append(np.round(param_group['lr'], 11))
				break
		else:
			for param_group in optimizer.param_groups:
				lr.append(np.round(param_group['lr'], 11))
		return lr



	def set_lr(lr):
		if isinstance(optimizer, dict):
			for k, opt in optimizer.items():
				for param_group in opt.param_groups:
					param_group['lr'] = lr
		else:
			for param_group in optimizer.param_groups:
				param_group['lr'] = lr


	def set_momentum(m):
		# pdb.set_trace()

		if isinstance(optimizer, dict):
			for k, opt in optimizer.items():
				for param_group in opt.param_groups:
					param_group['momentum'] = m
		else:
			for param_group in optimizer.param_groups:
				param_group['momentum'] = m


	# 'Training': loss Containers
	train_cur_epoch_batchwise_loss = []
	train_epoch_avg_loss_container = []  # Stores the loss for each epoch, averaged over batches.
	train_all_epoch_batchwise_loss = []

	# 'Validation': loss containers
	val_avg_loss_container = []


	# 'Validation': Metric Containers 
	val_report_container = []
	val_f1_container = []
	val_auc_container = []
	val_accuracy_container = []

	# 'Test': Metric Containers. Only computed and stored when a certain condition is met. 
	# Of course, this is performed only when a labeled test set is present. 
	test_auc_container = {}
	test_f1_container = {}
	test_accuracy_container = {}

	
	# 'Extra' epochs
	if kconfig.tr.lr_search_flag:
		extra_epochs = kconfig.one_cycle_policy.extra_epochs.lr_search # 4 
	else:
		extra_epochs = kconfig.one_cycle_policy.extra_epochs.train  # 20
	total_epochs = num_epochs + extra_epochs


	# One cycle setting of Learning Rate
	num_steps_upndown = kconfig.one_cycle_policy.num_steps_upndown # 10
	further_lowering_factor = kconfig.one_cycle_policy.extra_epochs.lowering_factor  # 10
	further_lowering_factor_steps = kconfig.one_cycle_policy.extra_epochs.lower_after # 4


	# One-cycle learning-rate schedule (with an annealing tail for the extra epochs)
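	# Illustration (assuming the config defaults noted in the comments above: num_steps_upndown=10,
	# num_epochs=90, extra_epochs=20, lowering factor 10 applied every 4 epochs):
	#   - epochs 1..~45:  LR climbs from start_lr to end_lr in 10 equal steps
	#                     (one step every num_epochs / (2 * num_steps_upndown) = 4.5 epochs);
	#   - epochs ~45..90: LR descends back from end_lr to start_lr at the same rate;
	#   - epochs 91..110: LR returns to start_lr and is then divided by 10 after every further 4 epochs.
	# Momentum would move the opposite way (0.95 -> 0.85 -> 0.95) if set_momentum() were enabled.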
	def one_cycle_lr_setter(current_epoch):
		start_momentum = 0.95
		end_momentum = 0.85
		current_momentum = None
		if current_epoch <= num_epochs:
			assert end_lr > start_lr, '[EndLR] should be greater than [StartLR]'
			lr_inc_rate = np.round((end_lr - start_lr) / (num_steps_upndown), 9)
			lr_inc_epoch_step_len = max(num_epochs / (2 * num_steps_upndown), 1)

			steps_completed = current_epoch / lr_inc_epoch_step_len
			print('[Steps Completed] = ', steps_completed)
			if steps_completed <= num_steps_upndown:
				current_lr = start_lr + (steps_completed * lr_inc_rate)
				current_momentum = start_momentum - ((start_momentum - end_momentum) * int(steps_completed) / num_steps_upndown)
			else:
				current_lr = end_lr - ((steps_completed - num_steps_upndown) * lr_inc_rate)
				current_momentum = end_momentum + ((start_momentum - end_momentum) * int(steps_completed - num_steps_upndown) / num_steps_upndown)

			set_lr(current_lr)
			# set_momentum(current_momentum)
		else:
			current_lr = start_lr / (
						further_lowering_factor ** ((current_epoch - num_epochs) // further_lowering_factor_steps))
			set_lr(current_lr)



	if sanity_check:
		train_dataloader = next(iter(dataloader['train']))
		train_dataloader = [train_dataloader] * 128
	else:
		train_dataloader = dataloader['train']


	def reset_grad(optimizer):
		# Zero Grad
		if isinstance(optimizer, dict):
			for k, opt in optimizer.items():
				opt.zero_grad()
		else:
			optimizer.zero_grad()


	# Model training for 'total_epochs' epochs
	counter = 0
	ep_ctr = 0
	for epoch in range(total_epochs):
		msg = '\n\n\n[Epoch] = {}'.format(epoch + 1)
		print(msg)
		start_time = time.time()
		start_datetime = datetime.now()
		
		for i, (X, y) in enumerate(train_dataloader): 

			y_cls_lbs, y_src_lbs = y
			# print(f'[Class Labels Counts] = {y_cls_lbs.unique(return_counts=True)   }', end='')
			# print(f'[Source Labels Counts] = {y_src_lbs.unique(return_counts=True)   }', end='')


			X = X.to(device=device, dtype=torch.float32)  # torch.float is an alias for torch.float32

			# pdb.set_trace()


			# ep_ctr += 1
			# if ep_ctr % 2 == 0 and (kconfig.tr.train_flag or kconfig.tr.sanity_check_flag):
			# 	src_idx_randperm = torch.randperm(len(y_src_lbs))
			# 	y_src_lbs = y_src_lbs[src_idx_randperm]


			y_cls_lbs = y_cls_lbs.to(device=device, dtype=torch.long)
			y_src_lbs = y_src_lbs.to(device=device, dtype=torch.long)

			y = (y_cls_lbs, y_src_lbs)

			
			# TODO: early breaker
			if kconfig.tr.early_break and i == 3:
				print('[Break] by force for validation check')
				break




			# Forward pass
			if isinstance(model, dict):
				features_repr = model['feature_repr_model'](X)

				cls_output = model['cls_model'](features_repr)
				src_output = model['src_model'](features_repr)

			else:	
				output = model(X) #


			# pdb.set_trace()

			# Reset gradient
			reset_grad(optimizer)
 

			# Domain-Adversarial Training of Neural Networks: https://arxiv.org/abs/1505.07818
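			# Update order below: (1) the feature extractor is stepped on
			# feature_loss = cls_loss - src_loss, so it learns class-discriminative features that
			# the source (domain) head cannot separate -- for the feature extractor, this subtraction
			# plays the role of the paper's gradient-reversal layer; (2) the features are detached and
			# the class and source heads are each stepped on their own loss; (3) the summed 'loss'
			# at the end of the block is accumulated for logging only.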
			if isinstance(loss_function, dict):
				cls_loss = loss_function['cls_loss'](cls_output, y_cls_lbs)
				src_loss = loss_function['src_loss'](src_output, y_src_lbs)

				feature_loss = cls_loss - src_loss

				# pdb.set_trace()


				feature_loss.backward()


				# print(f"\n[Wt] = {model['feature_repr_model'].fc1[0].weight[:10, ...]}")
				# print(f"\n[grad] = {model['feature_repr_model'].fc1[0].weight.grad[:10, ...]}")

				

				optimizer['feature_repr_opt'].step()

				print('[After] step')
				# print(f"\n[Wt] = {model['feature_repr_model'].fc1[0].weight[:10, ...]}")
				# print(f"\n[grad] = {model['feature_repr_model'].fc1[0].weight.grad[:10, ...]}")



				# print('[Check what happens to gradient]')


				reset_grad(optimizer)


				# pdb.set_trace()

				features_repr = features_repr.detach()

				cls_output = model['cls_model'](features_repr)
				src_output = model['src_model'](features_repr)


				# cls_output = model['cls_model'](features_repr)
				cls_loss = loss_function['cls_loss'](cls_output, y_cls_lbs)
				cls_loss.backward()

				optimizer['cls_opt'].step()



				# src_output = model['src_model'](features_repr)
				src_loss = loss_function['src_loss'](src_output, y_src_lbs)
				src_loss.backward()

				optimizer['src_opt'].step()


				loss = feature_loss + cls_loss + src_loss


			else:
				loss = loss_function(output, y)
				loss.backward()
				optimizer.step()




			# gap = 1
			# if counter > 1 and counter % gap == 0:
			# 	print()
			# 	print('\n\n\n[BBBBBBefore loss.backward()]')
			# 	print_wt_n_output(model, y, output, optimizer=optimizer)
				
			# loss.backward()
			
			# if counter > 1 and counter % gap == 0:
			# 	print('\n\n\n[AAAAAAfter loss.backward()]')
			# 	print_wt_n_output(model, y, output, optimizer=optimizer)
			
			# optimizer.step()
			
			# if counter > 1 and counter % gap == 0:
			# 	print('\n\n\n[AAAAAfter optimizer.step()]')
			# 	print_wt_n_output(model, y, output, optimizer=optimizer)
			# counter += 1

			# pdb.set_trace()
			# if isinstance(loss_function, dict):

			# 	# check <model['feature_repr_model']> grad before and after also 
			# 	# check <model['cls_model']>	
			# 	feature_loss.backward()


			# 	cls_loss.backward()
			# 	src_loss.backward()

			# else:
			# 	loss.backward()


			# set_momentum(0.90)


			train_cur_epoch_batchwise_loss.append(loss.item())
			train_all_epoch_batchwise_loss.append(loss.item())

			batch_run_msg = '\nEpoch: [%s/%s], Step: [%s/%s], InitialLR: %s, CurrentLR: %s, Loss: %s' \
							% (epoch + 1, total_epochs, i + 1, len(train_dataloader), start_lr, get_lr(), loss.item())
			print(batch_run_msg)
		#------------------ End of an Epoch ------------------ 
		
		# store average loss
		epoch_avg_loss = np.round(sum(train_cur_epoch_batchwise_loss) / (i + 1.0), 6)
		train_cur_epoch_batchwise_loss = []
		train_epoch_avg_loss_container.append(epoch_avg_loss)
		

		# 'Validation': compute metrics on the validation set, used for saving the models at checkpoints.
		if not (kconfig.tr.lr_search_flag or sanity_check):
			val_loss, val_report, f1_checker, auc_val = cal_loss_and_metric(model, dataloader['val'], loss_function, epoch+1)
			val_report['roc'] = 'Removed'
		

		# 'Validation': save model if certain condition is met on the computed metrics. 
		test_test_data = False
		accuracy = None 
		if not (kconfig.tr.lr_search_flag or sanity_check):
			val_report_container.append(val_report)  # ['epoch_' + str(epoch)] = val_report

			# Check point for which models will be saved
			val_avg_loss_container.append(val_loss)
			val_f1_container.append(f1_checker)	
			val_auc_container.append(auc_val)

			accuracy = val_report.get('accuracy', None)
			val_accuracy_container.append(accuracy)

			if np.round(val_loss, 4) <= np.round(min(val_avg_loss_container), 4):
				model = save_model(model, extra_extension='_minval') # + '_epoch_' + str(epoch))

			if np.round(auc_val, 4) >= np.round(max(val_auc_container), 4):
				model = save_model(model, extra_extension='_maxauc') # + '_epoch_' + str(epoch))
				test_test_data = True

			if np.round(f1_checker, 4) >= np.round(max(val_f1_container), 4):
				model = save_model(model, extra_extension='_maxf1') # + '_epoch_' + str(epoch))
				test_test_data = True



		# Save
		if epoch_avg_loss <= min(train_epoch_avg_loss_container):
			model = save_model(model, extra_extension='_mintrain')


		
		# Logger msg
		msg = '\n\n\n\n\nEpoch: [%s/%s], InitialLR: %s, CurrentLR= %s \n' \
			  '\n\n[Train] Average Epoch-wise Loss = %s \n' \
			  '\n\n********************************************************** [Validation]' \
			  '\n\n[Validation] Average Epoch-wise loss = %s \n' \
			  '\n\n[Validation] Report = %s \n'\
			  '\n\n[Validation] F1 history = %s\n'\
			  '\n\n[Validation] Accuracy history = %s\n'\
			  %(epoch+1, total_epochs, start_lr, get_lr(), train_epoch_avg_loss_container, val_avg_loss_container, None if not val_report_container else util.pretty(val_report_container[-1]), val_f1_container, val_accuracy_container)
		logger.info(msg); print(msg)


		# 'Test': compute metrics on the test dataset. Again, of course, only if it is present. 
		if not (kconfig.tr.lr_search_flag or sanity_check) and test_test_data and dataloader.get('test', False):
			test_loss, test_report, test_f1_checker, test_auc = cal_loss_and_metric(model, dataloader['test'], loss_function, epoch+1, model_type='test_set')
			
			test_report['roc'] = 'Removed'
			accuracy = test_report.get('accuracy', None)

			
			test_auc_container[epoch+1] = "{0:.3f}".format(round(test_auc, 4)) 
			test_f1_container[epoch+1] = "{0:.3f}".format(round(test_f1_checker, 4))
			test_accuracy_container[epoch+1] = "{}".format(accuracy)

			
			msg = '\n\n\n\n**********************************************************[Test]\n '\
				  '[Test] Report = {}' \
				  '\n\n[Test] fscore = {}' \
				  '\n\n[Test] AUC dict = {}' \
				  '\n\n[Test] F1-dict = {}'\
				  '\n\n[Test] Accuracy = {}'.format(util.pretty(test_report), test_f1_checker, test_auc_container, test_f1_container, test_accuracy_container)

			logger.info(msg); print(msg)

		
		# Stop training if the model has already converged (near-zero training loss) or the LR has left a useful range.
		if epoch_avg_loss < 1e-6 or get_lr()[0] < 1e-11 or get_lr()[0] >= 10:
			msg = '\n\nAvg. Loss = {} or Current LR = {} thus stopping training'.format(epoch_avg_loss, get_lr())
			logger.info(msg)
			print(msg)
			break
			
		
		# LR update: step the LR-search scheduler(s), or move along the one-cycle up/down schedule.
		if kconfig.tr.lr_search_flag:
			# [LR]TODO: only used while estimating a good learning rate
			if isinstance(lr_scheduler, dict):
				for key, lr_scheduler_eg in lr_scheduler.items():
					lr_scheduler_eg.step(epoch + 1)
			else:
				lr_scheduler.step(epoch + 1)
		else:
			one_cycle_lr_setter(epoch + 1)

		# Time keeping for training epoch time. 
		end_time = time.time()
		end_datetime = datetime.now()
		msg = '\n\n[Time] taken for epoch({}) time = {}, datetime = {} \n\n'.format(epoch+1, end_time - start_time, end_datetime - start_datetime)
		logger.info(msg); print(msg)

	# ----------------- End of training process -----------------

	msg = '\n\n[Epoch Loss] = {}'.format(train_epoch_avg_loss_container)
	logger.info(msg); print(msg)

	
	# [LR]TODO: change for lr finder
	if kconfig.tr.lr_search_flag:
		losses = train_epoch_avg_loss_container
		plot_file_name = 'training_epoch_loss_for_lr_finder.png'
		title = 'Training Epoch Loss'
	else:
		losses = {'train': train_epoch_avg_loss_container, 'val': val_avg_loss_container}
		plot_file_name = 'training_vs_val_epoch_avg_loss.png'
		title= 'Training vs Validation Epoch Loss'
	plot_loss(losses=losses,
			plot_file_name=plot_file_name,
			title=title)
	plot_loss(losses=train_all_epoch_batchwise_loss, plot_file_name='training_batchwise.png', title='Training Batchwise Loss',
			xlabel='#Batchwise')
		

	# Save the model		
	model = save_model(model)