def train(args, model, tokenizer, train_dataset):
	train_sampler = RandomSampler(train_dataset)
	train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
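	# Path to the fine-tuned BERT genre classifier consumed by classify_bert() for the auxiliary loss below.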
	bert_path = "./models/model_out_bert_cased"

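	# Total number of optimization steps: either capped by max_steps or derived from epochs and gradient accumulation.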
	if args.max_steps > 0:
	    t_total = args.max_steps
	    args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
	else:
	    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

	# Prepare optimizer and schedule (linear warmup and decay)
	no_decay = ["bias", "LayerNorm.weight"]
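	# Exclude biases and LayerNorm weights from weight decay; all other parameters use args.weight_decay.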
	optimizer_grouped_parameters = [
	    {
	        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
	        "weight_decay": args.weight_decay,
	    },
	    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
	]

	optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
	scheduler = get_linear_schedule_with_warmup(
	    optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
	)

	print("***** Running training *****")
	print(f"  Num examples = {len(train_dataset)}")
	print(f"  Num Epochs = {args.num_train_epochs}" )
	print(
	    f"  Total train batch size (w. accumulation) = {args.train_batch_size * args.gradient_accumulation_steps}"
	)
	print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
	print(f"  Total optimization steps = {t_total}", )

	global_step = 0
	tr_loss, b_loss, logging_loss, b_logging_loss = 0.0, 0.0, 0.0, 0.0
	best_dev_acc = 0.0
	best_steps = 0
	topk = 10
	model.zero_grad()
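	# Genre control tokens expected in the training sequences; they are matched against the BERT classifier's labels below.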
	genres = ['<Comedy>', '<Action>', '<Adventure>', '<Crime>', '<Drama>', '<Fantasy>', '<Horror>', '<Romance>', '<Sci-Fi>', '<Thriller>']
	train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
	# context_tokens = tokenizer.encode(args.genres)


	# model_path = './models'
	if not os.path.exists(args.output_dir):
	    os.makedirs(args.output_dir)
	bert_loss = 0

	for epoch in train_iterator:
		epoch_iterator = tqdm(train_dataloader, desc="Iteration")
		for step, batch in enumerate(epoch_iterator):
			model.train()
			batch = tuple(t.to(args.device) for t in batch)
			# print(batch)
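			# Causal LM objective: labels are the input ids themselves (the model shifts them internally when computing the loss).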
			inputs = {
				"input_ids": batch[0],
				"attention_mask": batch[1],
				"token_type_ids": batch[2],
				"labels": batch[0]
				}
			outputs = model(**inputs)
			# print(outputs[1].shape)
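			# outputs[1] holds the token logits; take the top-1 token at each position as the greedy prediction.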
			logits = F.softmax(outputs[1], dim=-1)
			predictions = torch.topk(logits, k=1, dim=-1)[1].squeeze(-1)
			# print(predictions.shape)
			loss = outputs[0]
			
			# text = tokenizer.decode(out[i])

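			# Auxiliary genre loss: for a random ~10% of sequences in the batch, decode the greedy
			# predictions, classify them with the fine-tuned BERT model, and penalize low predicted
			# probability for each genre token present in the original input.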
			bert_loss = 0
			for i in range(predictions.shape[0]):
				p = random.uniform(0, 1)
				if p > 0.1:
					continue
				out = predictions[i, :].tolist()
				orig = batch[0][i, :].tolist()
				o_genres = []
				for j in range(len(batch[0][i, :])):
					token = tokenizer.decode([batch[0][i,j]])
					if token in genres:
						o_genres.append(token.replace("<","").replace(">",""))
				pred = classify_bert(tokenizer.decode(out), bert_path)
				inter_loss = 0
				c = 0
				for l in o_genres:
					inter_loss += (1 - pred[l])
					c += 1
				if c > 0:
					bert_loss += float(inter_loss) / c
			bert_loss = float(bert_loss) / predictions.shape[0]
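			# Averaged over the full batch size, even though only the sampled ~10% of examples contribute.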
			# print(bert_loss)

			if args.gradient_accumulation_steps > 1:
				loss = loss / args.gradient_accumulation_steps
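			# NOTE: bert_loss is a plain Python float computed from decoded text, so it raises the
			# reported loss value but contributes no gradient of its own.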
			loss += bert_loss
			loss.backward()
			torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
			tr_loss += loss.item()
			b_loss += bert_loss
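			# Only step the optimizer and scheduler every gradient_accumulation_steps micro-batches.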
			if (step + 1) % args.gradient_accumulation_steps == 0:
				optimizer.step()
				scheduler.step()  # Update learning rate schedule
				model.zero_grad()
				global_step += 1
				if args.logging_steps > 0 and global_step % args.logging_steps == 0:
					print(f"\nAverage loss: {(tr_loss - logging_loss) / args.logging_steps}, BERT loss: {(b_loss - b_logging_loss) / args.logging_steps} at global step: {global_step}")
					# print(f"\nAverage loss: {(tr_loss - logging_loss) / args.logging_steps} at global step: {global_step}")
					logging_loss = tr_loss
					b_logging_loss = b_loss
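				# Periodically save a full checkpoint: model weights, vocabulary, and the training args.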
				if args.save_steps > 0 and global_step % args.save_steps == 0:
					# torch.save(model.state_dict(), os.path.join(model_path, f"{args.model_type}_funnies_{epoch}.pt"))
					output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
					if not os.path.exists(output_dir):
					    os.makedirs(output_dir)
					model_to_save = (
					    model.module if hasattr(model, "module") else model
					)  
					model_to_save.save_pretrained(output_dir)
					tokenizer.save_vocabulary(output_dir)
					torch.save(args, os.path.join(output_dir, "training_args.bin"))
					print(f"Saving model checkpoint to {output_dir}")
			if args.max_steps > 0 and global_step > args.max_steps:
				epoch_iterator.close()
				break
		if args.max_steps > 0 and global_step > args.max_steps:
			train_iterator.close()
			break