def evaluate_model(self, epoch):

		self.model.zero_grad()
		self.model.eval()
		
	
		all_tys   = None
		all_preds = None

		
		true_and_prediction = []
		if cf.TASK == "end_to_end":
			for (i, (batch_x, batch_y, batch_z, _, batch_tx, batch_ty, batch_tm)) in enumerate(self.data_loader):

				if len(batch_x) < cf.BATCH_SIZE:
					continue


				batch_y = batch_y.float().to(device)

				

				# 3. Build the token to wordpiece mapping using batch_tm, built during the build_data stage.
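				# (Assumed format: the mapping gives, for each token index, the list of wordpiece
				#  indexes that token was split into, e.g. "playing" -> ["play", "##ing"].)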
				token_idxs_to_wp_idxs = build_token_to_wp_mapping(batch_tm)



				non_padding_indexes = torch.BoolTensor((batch_tx > 0))
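				# (Token id 0 is assumed to be the padding id, so this mask marks the real tokens.)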

				if cf.EMBEDDING_MODEL == "bert":
					wordpieces = batch_to_wordpieces(batch_x, self.wordpiece_vocab)

					# 2. Encode the wordpieces into Bert vectors
					bert_embs  = wordpieces_to_bert_embs(wordpieces, self.bc)
					bert_embs = bert_embs.to(device)

					y_hat = self.model(bert_embs)

					#loss = model.calculate_loss(y_hat, batch_x, batch_y, batch_z)

					# 4. Retrieve the token predictions for this batch, from the model.
					token_preds = self.model.predict_token_labels(bert_embs, token_idxs_to_wp_idxs)
				elif cf.EMBEDDING_MODEL in ['random', 'glove', 'word2vec']:
					batch_tx_cuda = batch_tx.long().to(device)
					#print(batch_tx.size())

					y_hat = self.model(batch_tx_cuda)

					#print(y_hat[0])

					# 4. Retrieve the token predictions for this batch, from the model.
					token_preds = self.model.predict_labels(y_hat).cpu()

					#print(token_preds[0])


				
				#print token_preds, "<TP", len(token_preds)
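				# Indexing with the boolean mask keeps only the non-padding positions;
				# per-token label/prediction tensors retain any trailing class dimension.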
				token_preds = token_preds[non_padding_indexes]


				batch_tx = batch_tx[non_padding_indexes]
				batch_ty = batch_ty[non_padding_indexes]
			

				if all_tys is None:
					all_tys = batch_ty
				else:
					all_tys    = torch.cat((all_tys, batch_ty))

				if all_preds is None:
					all_preds = token_preds
				else:
					all_preds = torch.cat((all_preds, token_preds))

				if i == 0:
					logger.info("\n" + self.get_tagged_sent_example(batch_tx, token_preds, batch_ty))

		elif cf.TASK == "mention_level":
			if self.model.attention_type == "scalar":
				logger.info("Component weights: " + str(self.model.component_weights))
			num_batches = len(self.data_loader)
			for (i, (batch_xl, batch_xr, batch_xa, batch_xm, batch_y)) in enumerate(self.data_loader):
				

				# 1. Convert the batch_x from wordpiece ids into wordpieces
				wordpieces_l = batch_to_wordpieces(batch_xl, self.wordpiece_vocab)
				wordpieces_r = batch_to_wordpieces(batch_xr, self.wordpiece_vocab)
				#wordpieces_a = batch_to_wordpieces(batch_xa, self.wordpiece_vocab)
				wordpieces_m = batch_to_wordpieces(batch_xm, self.wordpiece_vocab)

				# 2. Encode the wordpieces into Bert vectors
				bert_embs_l  = wordpieces_to_bert_embs(wordpieces_l, self.bc).to(device)
				bert_embs_r  = wordpieces_to_bert_embs(wordpieces_r, self.bc).to(device)				
				#bert_embs_a  = wordpieces_to_bert_embs(wordpieces_a, self.bc).to(device)
				bert_embs_m  = wordpieces_to_bert_embs(wordpieces_m, self.bc).to(device)
								
				mention_preds = self.model.evaluate(bert_embs_l, bert_embs_r, None, bert_embs_m)
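				# batch_xa is left unused here (None is passed in its place), matching the
				# commented-out wordpieces_a / bert_embs_a lines above.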

				batch_y = batch_y.float().to(device)

				for j, row in enumerate(batch_y):

					labels = self.hierarchy.onehot2categories(batch_y[j])
					preds = self.hierarchy.onehot2categories(mention_preds[j])

					

					true_and_prediction.append((labels, preds))

				sys.stdout.write("\rEvaluating batch %d / %d" % (i, num_batches))

				#if all_tys is None:
				#	all_tys = batch_y
				#else:
				#	
				#	all_tys    = torch.cat((all_tys, batch_y))
				#
				#if all_preds is None:
				#	all_preds = mention_preds
				#else:
				
				#	all_preds = torch.cat((all_preds, mention_preds))
		
		# Convert all one-hot to categories

		def build_true_and_preds(tys, preds):
			true_and_prediction = []
			empty = 0
			for i, row in enumerate(tys):	
				true_cats = self.hierarchy.onehot2categories(tys[i])		
				pred_cats = self.hierarchy.onehot2categories(preds[i])
				#if pred_cats == []:
				#	empty += 1
				true_and_prediction.append((true_cats, pred_cats))	
			#if empty > 0:
			#	logger.warn("There were %d empty predictions." % empty)
			return true_and_prediction	
	


		#all_tys = all_tys.cpu()
		#all_preds = all_preds.cpu()

		#acc = accuracy_score(all_tys, all_preds)
		#micro_f1 = f1_score(all_tys, all_preds, average="micro")
		#macro_f1 = f1_score(all_tys, all_preds, average="macro")

		#logger.info("                  Micro F1: %.4f\tMacro F1: %.4f\tAcc: %.4f" % (micro_f1, macro_f1, acc))

		if cf.TASK == "end_to_end":

			print(all_tys)
			# Filter out any completely-zero rows in all_tys, i.e. the words that are not entities
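			# torch.nonzero returns (row, column) coordinates of the nonzero entries; keeping
			# only the row coordinates (deduplicated) gives the indexes of labelled tokens.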
			nonzeros = torch.nonzero(all_tys)
			indexes = torch.index_select(nonzeros, dim=1, index=torch.tensor([0])).view(-1)
			indexes = torch.unique(indexes)
			filtered_tys = all_tys[indexes]
			filtered_preds = all_preds[indexes]

			filtered_acc = accuracy_score(filtered_tys, filtered_preds)
			filtered_micro_f1 = f1_score(filtered_tys, filtered_preds, average="micro")
			filtered_macro_f1 = f1_score(filtered_tys, filtered_preds, average="macro")

			# Predictable: only considers labels that appear in the test hierarchy. A category is not 'predictable' if it only appears in the training hierarchy.
			overlapping_category_ids = self.hierarchy.get_overlapping_category_ids()
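			# overlapping_category_ids is assumed to hold the column indexes of the categories
			# shared by the training and test hierarchies, used to slice the one-hot matrices.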

			predictable_tys = all_tys[:, overlapping_category_ids]
			predictable_preds = all_preds[:, overlapping_category_ids]

			predictable_acc = accuracy_score(predictable_tys, predictable_preds)
			predictable_micro_f1 = f1_score(predictable_tys, predictable_preds, average="micro")
			predictable_macro_f1 = f1_score(predictable_tys, predictable_preds, average="macro")

			# Filtered + Predictable: combines Filtered and Predictable, i.e. entities only, restricted to categories that appear in the test hierarchy
			filtered_predictable_tys = filtered_tys[:, overlapping_category_ids]
			filtered_predictable_preds = filtered_preds[:, overlapping_category_ids]

			filtered_predictable_acc = accuracy_score(filtered_predictable_tys, filtered_predictable_preds)
			filtered_predictable_micro_f1 = f1_score(filtered_predictable_tys, filtered_predictable_preds, average="micro")
			filtered_predictable_macro_f1 = f1_score(filtered_predictable_tys, filtered_predictable_preds, average="macro")
		

			logger.info("Classification report (all):")
			logger.info("\n" + classification_report(all_tys, all_preds, target_names=self.hierarchy.categories))

			logger.info("Classification report (filtered, test categories only):")
			logger.info("\n" + classification_report(predictable_tys, predictable_preds, target_names=self.hierarchy.get_overlapping_categories()))
		
		
			logger.info("(Filtered)        Micro F1: %.4f\tMacro F1: %.4f\tAcc: %.4f" % (filtered_micro_f1, filtered_macro_f1, filtered_acc))
			logger.info("(Predictable)     Micro F1: %.4f\tMacro F1: %.4f\tAcc: %.4f" % (predictable_micro_f1, predictable_macro_f1, predictable_acc))
			logger.info("(F + Predictable) Micro F1: %.4f\tMacro F1: %.4f\tAcc: %.4f" % (filtered_predictable_micro_f1, filtered_predictable_macro_f1, filtered_predictable_acc))

			logger.info("\nUsing NFGEC:")
			nfgec_default  			= build_true_and_preds(all_tys, all_preds)
			nfgec_filtered 			= build_true_and_preds(filtered_tys, filtered_preds)	
			nfgec_predictable 		= build_true_and_preds(predictable_tys, predictable_preds)
			nfgec_filtered_predictable	 = build_true_and_preds(filtered_predictable_tys, filtered_predictable_preds)
	
			logger.info("                  Micro F1: %.4f\tMacro F1: %.4f\tAcc: %.4f" % (nfgec_evaluate.loose_micro(nfgec_default)[2], nfgec_evaluate.loose_macro(nfgec_default)[2], nfgec_evaluate.strict(nfgec_default)[2]))	
			logger.info("(Filtered)        Micro F1: %.4f\tMacro F1: %.4f\tAcc: %.4f" % (nfgec_evaluate.loose_micro(nfgec_filtered)[2], nfgec_evaluate.loose_macro(nfgec_filtered)[2], nfgec_evaluate.strict(nfgec_filtered)[2]))		
			logger.info("(Predictable)     Micro F1: %.4f\tMacro F1: %.4f\tAcc: %.4f" % (nfgec_evaluate.loose_micro(nfgec_predictable)[2], nfgec_evaluate.loose_macro(nfgec_predictable)[2], nfgec_evaluate.strict(nfgec_predictable)[2]))	
			logger.info("(F + Predictable) Micro F1: %.4f\tMacro F1: %.4f\tAcc: %.4f" % (nfgec_evaluate.loose_micro(nfgec_filtered_predictable)[2], nfgec_evaluate.loose_macro(nfgec_filtered_predictable)[2], nfgec_evaluate.strict(nfgec_filtered_predictable)[2]))	



			return (filtered_micro_f1 + filtered_macro_f1 + predictable_micro_f1 + predictable_macro_f1 + filtered_predictable_micro_f1 + filtered_predictable_macro_f1) / 6

		elif cf.TASK == "mention_level":
			print("")
			print(len(true_and_prediction))
			#nfgec_default  			= build_true_and_preds(all_tys, all_preds)
			micro, macro, acc = nfgec_evaluate.loose_micro(true_and_prediction)[2], nfgec_evaluate.loose_macro(true_and_prediction)[2], nfgec_evaluate.strict(true_and_prediction)[2]
			logger.info("                  Micro F1: %.4f\tMacro F1: %.4f\tAcc: %.4f" % (micro, macro, acc))
			return (acc + macro + micro) / 3
Example #2
def train(model,
          data_loaders,
          word_vocab,
          wordpiece_vocab,
          hierarchy,
          ground_truth_triples,
          epoch_start=1):

    logger.info("Training model.")

    # Set up a new Bert Client, for encoding the wordpieces
    bc = BertClient()

    modelEvaluator = ModelEvaluator(model, data_loaders['dev'], word_vocab,
                                    wordpiece_vocab, hierarchy,
                                    ground_truth_triples, cf)

    #optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=cf.LEARNING_RATE, momentum=0.9)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=cf.LEARNING_RATE)  #, momentum=0.9)
    model.cuda()
    print(cf.LEARNING_RATE)

    num_batches = len(data_loaders["train"])
    max_epochs = 1000
    progress_bar = ProgressBar(num_batches=num_batches,
                               max_epochs=max_epochs,
                               logger=logger)
    avg_loss_list = []

    # Train the model

    for epoch in range(epoch_start, max_epochs + 1):
        epoch_start_time = time.time()
        epoch_losses = []

        for (i, (batch_x, batch_y, batch_z, _, batch_tx, _,
                 _)) in enumerate(data_loaders["train"]):

            if len(batch_x) < cf.BATCH_SIZE:
                continue

            # 1. Convert wordpiece ids into wordpiece tokens
            wordpieces = batch_to_wordpieces(batch_x, wordpiece_vocab)
            wordpiece_embs = wordpieces_to_bert_embs(wordpieces, bc)

            # 2. Create sinusoidal positional embeddings and concatenate them with the BERT embeddings

            wordpiece_embs = wordpiece_embs.to(device)
            batch_y = batch_y.float().to(device)
            batch_z = batch_z.float().to(device)

            # 3. Feed these vectors to our model

            if cf.POSITIONAL_EMB_DIM > 0:
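                # Presumably fairseq-style sinusoidal embeddings: a ones tensor of shape
                # (batch, seq_len) supplies the positions, and the resulting
                # POSITIONAL_EMB_DIM-wide vectors are concatenated to the BERT embeddings
                # along the feature dimension (dim=2) below.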
                sin_embs = SinusoidalPositionalEmbedding(
                    embedding_dim=cf.POSITIONAL_EMB_DIM,
                    padding_idx=0,
                    left_pad=True)
                sin_embs = sin_embs(
                    torch.ones([batch_x.size()[0],
                                batch_x.size()[1]])).to(device)
                joined_embs = torch.cat((wordpiece_embs, sin_embs), dim=2)
            else:
                joined_embs = wordpiece_embs

            # if len(batch_x) < cf.BATCH_SIZE:
            # 	zeros = torch.zeros((cf.BATCH_SIZE - len(batch_x), joined_embs.size()[1], joined_embs.size()[2])).to(device)
            # 	joined_embs = torch.cat((joined_embs, zeros), dim=0)
            # 	print(joined_embs)
            # 	print(joined_embs.size())

            model.zero_grad()
            model.train()

            y_hat = model(joined_embs)

            loss = model.calculate_loss(y_hat, batch_x, batch_y, batch_z)

            # 4. Backpropagate
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())  # store a plain float so the graph is not kept alive

            # 5. Draw the progress bar
            progress_bar.draw_bar(i, epoch, epoch_start_time)

        avg_loss = sum(epoch_losses) / float(len(epoch_losses))
        avg_loss_list.append(avg_loss)

        progress_bar.draw_completed_epoch(avg_loss, avg_loss_list, epoch,
                                          epoch_start_time)

        modelEvaluator.evaluate_every_n_epochs(1, epoch)
Example #3
def train(model, data_loaders, word_vocab, wordpiece_vocab, hierarchy, epoch_start = 1):

	logger.info("Training model.")
	
	# Set up a new Bert Client, for encoding the wordpieces
	if cf.EMBEDDING_MODEL == "bert":
		bc = BertClient()
	else:
		bc = None
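	# For the non-BERT embedding models no BertClient is needed; the training loop below
	# feeds token ids (batch_tx) straight to the model instead of wordpiece vectors.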

	modelEvaluator = ModelEvaluator(model, data_loaders['dev'], word_vocab, wordpiece_vocab, hierarchy, bc)
	
	#optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=cf.LEARNING_RATE, momentum=0.9)
	optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cf.LEARNING_RATE)#, eps=1e-4, amsgrad=True)#, momentum=0.9)
	model.cuda()


	num_batches = len(data_loaders["train"])
	print(num_batches)
	progress_bar = ProgressBar(num_batches = num_batches, max_epochs = cf.MAX_EPOCHS, logger = logger)
	avg_loss_list = []

	# Train the model

	for epoch in range(epoch_start, cf.MAX_EPOCHS + 1):
		epoch_start_time = time.time()
		epoch_losses = []

		if cf.TASK == "end_to_end":
			if cf.BATCH_SIZE != 10:
				print("Warning: batch size must currently be set to 10 for the end-to-end model.")
			for (i, (batch_x, batch_y, batch_z, _, batch_tx, batch_ty, _)) in enumerate(data_loaders["train"]):


				if len(batch_x) < cf.BATCH_SIZE:
					continue

				batch_y = batch_y.float().to(device)
				batch_z = batch_z.float().to(device)

				model.zero_grad()
				model.train()

				#if i > 1:
				#	continue
				# 1. Convert the batch_x from wordpiece ids into wordpieces
				if cf.EMBEDDING_MODEL == "bert":
					wordpieces = batch_to_wordpieces(batch_x, wordpiece_vocab)


			
					# 2. Encode the wordpieces into Bert vectors
					bert_embs  = wordpieces_to_bert_embs(wordpieces, bc)

					bert_embs = bert_embs.to(device)

					y_hat = model(bert_embs)

					loss = model.calculate_loss(y_hat, batch_x, batch_y, batch_z)
					
				elif cf.EMBEDDING_MODEL in ['random', 'glove', 'word2vec']:
					batch_tx_cuda = batch_tx.long().to(device)
					batch_ty = batch_ty.float().to(device)

					#print(batch_tx.size())

					y_hat = model(batch_tx_cuda)

					loss = model.calculate_loss(y_hat, batch_tx, batch_ty, batch_z)

				# 4. Backpropagate
				loss.backward()
				optimizer.step()
				epoch_losses.append(loss.item())  # store a plain float so the graph is not kept alive

				# 5. Draw the progress bar
				progress_bar.draw_bar(i, epoch, epoch_start_time)

		elif cf.TASK == "mention_level":
			for (i, (batch_xl, batch_xr, batch_xa, batch_xm, batch_y)) in enumerate(data_loaders["train"]):

				#torch.cuda.empty_cache()
				#if i > 1:
				#	continue
				# 1. Convert the batch_x from wordpiece ids into wordpieces
				wordpieces_l = batch_to_wordpieces(batch_xl, wordpiece_vocab)
				wordpieces_r = batch_to_wordpieces(batch_xr, wordpiece_vocab)
				#wordpieces_a = batch_to_wordpieces(batch_xa, wordpiece_vocab)
				wordpieces_m = batch_to_wordpieces(batch_xm, wordpiece_vocab)
	
				
				#print len(wordpieces_l[0]), len(wordpieces_r[0]), len(wordpieces_m[0])
				

				#print len(wordpieces_l[0]),  len(wordpieces_r[0]),  len(wordpieces_a[0]) ,  len(wordpieces_m[0])
				

				# 2. Encode the wordpieces into Bert vectors
				bert_embs_l  = wordpieces_to_bert_embs(wordpieces_l, bc).to(device)
				bert_embs_r  = wordpieces_to_bert_embs(wordpieces_r, bc).to(device)				
				#bert_embs_a  = wordpieces_to_bert_embs(wordpieces_a, bc).to(device)
				bert_embs_m  = wordpieces_to_bert_embs(wordpieces_m, bc).to(device)
				
				batch_y = batch_y.float().to(device)	

				# 3. Feed these Bert vectors to our model
				model.zero_grad()
				model.train()

				y_hat = model(bert_embs_l, bert_embs_r, None, bert_embs_m)

				loss = model.calculate_loss(y_hat, batch_y)

				# 4. Backpropagate
				loss.backward()
				optimizer.step()
				epoch_losses.append(loss.item())  # store a plain float so the graph is not kept alive

				# 5. Draw the progress bar
				progress_bar.draw_bar(i, epoch, epoch_start_time)

		avg_loss = sum(epoch_losses) / float(len(epoch_losses))
		avg_loss_list.append(avg_loss)

		progress_bar.draw_completed_epoch(avg_loss, avg_loss_list, epoch, epoch_start_time)

		#logger.info(avg_loss)

		modelEvaluator.evaluate_every_n_epochs(1, epoch)
Example #4
    def evaluate_model(self, epoch, data_loader, mode="training"):

        self.model.zero_grad()
        self.model.eval()

        all_txs = []
        all_tys = []
        all_tys_et = []
        all_preds = []
        all_preds_et = []

        self.model.batch_size = 1  # Set the batch size to 1 for evaluation to avoid missing any sents
        for (i, (batch_x, _, _, _, _, _, batch_tx, batch_ty_tr, batch_ty_et,
                 batch_tm)) in enumerate(data_loader):

            # 1. Convert the batch_x from wordpiece ids into wordpieces
            #if mode == "training":
            wordpieces = batch_to_wordpieces(batch_x, self.wordpiece_vocab)

            #seq_lens = [len([w for w in doc if w != "[PAD]"]) for doc in wordpieces]

            wordpiece_embs = wordpieces_to_bert_embs(wordpieces,
                                                     self.bc).to(device)

            sin_embs = SinusoidalPositionalEmbedding(
                embedding_dim=self.cf.POSITIONAL_EMB_DIM,
                padding_idx=0,
                left_pad=True)
            sin_embs = sin_embs(
                torch.ones([wordpiece_embs.size()[0],
                            self.cf.MAX_SENT_LEN])).to(device)
            joined_embs = torch.cat((wordpiece_embs, sin_embs), dim=2)

            # 3. Build the token to wordpiece mapping using batch_tm, built during the build_data stage.
            token_idxs_to_wp_idxs = build_token_to_wp_mapping(batch_tm)

            non_padding_indexes = torch.ByteTensor((batch_tx > 0))

            # 4. Retrieve the token predictions for this batch, from the model.
            token_preds_tr, token_preds_et = self.model.predict_token_labels(
                joined_embs, token_idxs_to_wp_idxs)
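            # predict_token_labels returns two token-level prediction tensors, one per label
            # space (the *_tr and *_et heads); both are collected separately below.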

            for batch in token_preds_tr:
                all_preds.append(batch.int().cpu().numpy().tolist())

            for batch in token_preds_et:
                all_preds_et.append(batch.int().cpu().numpy().tolist())

            for batch in batch_tx:
                all_txs.append(batch.int().cpu().numpy().tolist())

            if mode == "training":
                for batch in batch_ty_tr:
                    all_tys.append(batch.int().cpu().numpy().tolist())
                for batch in batch_ty_et:
                    all_tys_et.append(batch.int().cpu().numpy().tolist())

        self.model.batch_size = self.cf.BATCH_SIZE
        tagged_sents = self.get_tagged_sents_from_preds(
            all_txs, all_tys, all_tys_et, all_preds, all_preds_et)
        predicted_triples = self.get_triples_from_tagged_sents(tagged_sents)
        logger.info("\n" + self.get_tagged_sent_example(tagged_sents[:1]))

        def build_true_and_preds(tys, preds):
            true_and_prediction = []
            for b, batch in enumerate(tys):
                for i, row in enumerate(batch):
                    true_cats = self.hierarchy_tr.onehot2categories(tys[b][i])
                    pred_cats = self.hierarchy_tr.onehot2categories(
                        preds[b][i])
                    true_and_prediction.append((true_cats, pred_cats))
            return true_and_prediction

        logger.info("Predicted Triples: ")
        triples_str = ""
        for idx, a in enumerate(predicted_triples):
            for t in a:
                triples_str += "%s %s\n" % (idx, ",".join([
                    "[ " + " ".join(w) + " ]" if i >= 3 else " ".join(w)
                    for i, w in enumerate(t)
                ]))
        logger.info("\n" + triples_str)

        print("")

        if mode == "training":
            # logger.info("Ground truth triples: ")

            # triples_str = ""
            # for idx, a in enumerate(self.ground_truth_triples):
            # 	for t in a:
            # 		triples_str += "%s %s\n" % (idx, ", ".join([" ".join(w) for w in t]))
            # logger.info("\n" + triples_str)

            triples_scores = []
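            # Score each sentence's predicted triples against its gold triples (with each
            # component joined back into a string), then average the per-sentence scores.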
            for pt, gt in zip(predicted_triples, self.ground_truth_triples):
                ts = evaluate_triples(
                    [[' '.join(t[0]), ' '.join(t[1]), ' '.join(t[2])]
                     for t in pt],
                    [[' '.join(t[0]), ' '.join(t[1]), ' '.join(t[2])]
                     for t in gt])
                triples_scores.append(ts)
            triples_score = sum(triples_scores) / len(triples_scores)
            #triples_score = 0.0
            true_and_predictions = build_true_and_preds(all_tys, all_preds)
            micro_f1 = loose_micro(true_and_predictions)[2]
            logger.info("                  Micro F1: %.4f\t" % (micro_f1))
            logger.info("                  Triples score: %.4f\t" %
                        (triples_score))
            return triples_score
        else:
            return predicted_triples