Beispiel #1
0
 def predict(inp, tar, enc_padding_mask):
     """Run ``transformer`` on ``inp`` and score its outputs against ``tar``.

     The second positional argument passed to ``transformer`` is ``False``
     (presumably the training flag -- confirm against the model's signature).

     Returns:
         ``(micro_f1, macro_f1)`` computed on the model outputs as returned.
     """
     outputs = transformer(inp, False, enc_padding_mask=enc_padding_mask)
     scores = (micro_f1(tar, outputs), macro_f1(tar, outputs))
     return scores
def test(model, x_test, y_true):
    """Predict on ``x_test`` with ``model`` and print micro- then macro-F1.

    Targets and predictions are both wrapped as float32 ``tf.constant``
    tensors before scoring. Nothing is returned; the scores are printed.
    """
    print("Test...")
    raw_predictions = model.predict(x_test)

    expected = tf.constant(y_true, tf.float32)
    predicted = tf.constant(raw_predictions, tf.float32)
    for metric in (micro_f1, macro_f1):
        print(metric(expected, predicted))
Beispiel #3
0
    def train_step(inp, tar):
        """Run one gradient step of ``transformer`` on a single ``(inp, tar)`` batch.

        Builds the encoder padding mask from ``inp``, computes
        ``loss_function(tar, outputs)`` under a ``tf.GradientTape``, applies the
        gradients through ``optimizer``, then feeds the loss to ``train_loss``
        and ``(tar, outputs)`` to ``train_accuracy``.

        Returns:
            ``(micro_f1, macro_f1)`` of this batch's outputs against ``tar``.
        """
        mask = create_padding_mask(inp)

        with tf.GradientTape() as tape:
            # second positional arg True: presumably the training flag -- confirm
            outputs = transformer(inp, True, enc_padding_mask=mask)
            batch_loss = loss_function(tar, outputs)

        params = transformer.trainable_variables
        grads = tape.gradient(batch_loss, params)
        optimizer.apply_gradients(zip(grads, params))

        train_loss(batch_loss)
        train_accuracy(tar, outputs)

        return micro_f1(tar, outputs), macro_f1(tar, outputs)
    def evaluate(test_dataset):
        """Run ``transformer`` over every batch of ``test_dataset`` and score it.

        Args:
            test_dataset: iterable yielding ``(inp, tar)`` batches.

        Returns:
            ``(micro_f1, macro_f1, sample_f1, tars, predictions)``. The micro and
            macro scores are computed on the concatenated model outputs as
            returned; ``tars`` and ``predictions`` are returned binarised at 0.5,
            and ``sample_f1`` (sklearn, ``average='samples'``) is computed on
            those binarised arrays.

        Raises:
            ValueError: if ``test_dataset`` yields no batches.
        """
        outputs, targets = [], []
        # Iterate the dataset itself (the enumerate index was never used) so
        # tqdm can show a total when the dataset's length is known.
        for inp, tar in tqdm(test_dataset):
            enc_padding_mask = create_padding_mask(inp)
            outputs.append(transformer(inp, False, enc_padding_mask=enc_padding_mask))
            targets.append(tar)

        if not outputs:
            # tf.concat would otherwise fail with an opaque op-level error.
            raise ValueError("evaluate() received an empty test_dataset")

        predictions = tf.concat(outputs, axis=0)
        tars = tf.concat(targets, axis=0)
        mi_f1 = micro_f1(tars, predictions)
        ma_f1 = macro_f1(tars, predictions)

        # Binarise at 0.5 for sklearn (assumes outputs are probabilities -- confirm).
        predictions = np.where(predictions > 0.5, 1, 0)
        tars = np.where(tars > 0.5, 1, 0)

        sample_f1 = f1_score(tars, predictions, average='samples')
        return mi_f1, ma_f1, sample_f1, tars, predictions
                '基因的自由组合规律的实质及应用', '郡县制', '人体水盐平衡调节', '内质网的结构和功能', '人体的体温调节',
                '免疫系统的功能', '科学社会主义常识', '与细胞分裂有关的细胞器', '太阳对地球的影响', '古代史', '清末民主革命风潮',
                '复等位基因', '人工授精、试管婴儿等生殖技术', '“重农抑商”政策', '生态系统的营养结构', '减数分裂的概念',
                '地球的外部圈层结构及特点', '细胞的多样性和统一性', '政治', '工业区位因素', '细胞大小与物质运输的关系',
                '夏商两代的政治制度', '农业区位因素', '溶酶体的结构和功能', '生产活动与地域联系', '内环境的稳态', '遗传与进化',
                '胚胎移植', '生物科学与社会', '近代史', '第三产业的兴起和“新经济”的出现', '公民道德与伦理常识', '中心体的结构和功能',
                '社会主义市场经济的伦理要求', '高中', '选官、用官制度的变化', '减数分裂与有丝分裂的比较', '遗传的细胞基础',
                '地球所处的宇宙环境', '培养基与无菌技术', '生活中的法律常识', '高尔基体的结构和功能', '社会主义是中国人民的历史性选择',
                '人口迁移与人口流动', '现代史', '地球与地图', '走进细胞', '生物', '避孕的原理和方法', '血糖平衡的调节',
                '现代生物技术专题', '海峡两岸关系的发展', '生命活动离不开细胞', '兴奋在神经元之间的传递', '历史', '分子与细胞',
                '拉马克的进化学说', '遗传的分子基础', '稳态与环境']

mlb.fit([[label] for label in all_labels])

true_file = './data/all_labels/test/predicate_out.txt'
# (sic) misspelled module-level name kept in case other code refers to it
predcit_file = './output/epochs6_baidu_95/predicate_predict.txt'


def _read_label_lines(path):
    """Return one list of whitespace-separated labels per line of ``path``.

    Blank lines become empty lists, so line *i* of the gold file stays
    aligned with line *i* of the prediction file.
    """
    with open(path, encoding='utf8') as f:
        return [line.split() for line in f]


y_true = _read_label_lines(true_file)
y_pred = _read_label_lines(predcit_file)

# Row i of y_true is compared with row i of y_pred, so the files must align.
if len(y_true) != len(y_pred):
    raise ValueError(
        '{} has {} lines but {} has {}; gold and predicted labels must be '
        'line-aligned'.format(true_file, len(y_true), predcit_file, len(y_pred)))

y_true = mlb.transform(y_true)
y_pred = mlb.transform(y_pred)

print('micro_f1, macro_f1:', micro_f1(y_true, y_pred), macro_f1(y_true, y_pred))
Beispiel #6
0
    def train(self):
        """Train ``self.model`` for up to ``self.epochs`` epochs, then evaluate on test.

        Each epoch runs ``num_train_batches`` optimisation steps (BCE-with-logits
        loss, gradient-norm clipping at 5) and then evaluates on the validation
        split. Whenever validation accuracy improves, the model state is saved to
        ``self.checkpoint_path``. The epoch metrics are logged, and
        ``self.early_stopping`` may end training early. Afterwards the test split
        is evaluated and ``self.plot`` is called with the loss/accuracy curves.

        Returns:
            ``(best_val_result, test_result)``.

            ``best_val_result`` is ``(acc, macro_precision, macro_recall,
            macro_f1, micro_f1)`` from the best validation epoch, or
            ``(0, 0, 0, 0)`` if no epoch ran.

            ``test_result`` is ``(loss, acc, macro_precision, macro_recall,
            macro_f1)``.
        """
        # NOTE(review): each loader's ``run`` is wrapped in ONE iterator that
        # ``next()`` keeps advancing across all epochs, so this assumes ``run``
        # yields at least epochs * batches items (or indefinitely) -- confirm.
        train_dataloader = iter(self.train_loader.run(self.g, self.num_neighbors, self.num_layers))
        val_dataloader = iter(self.val_loader.run(self.g, None, self.num_layers))
        test_dataloader = iter(self.test_loader.run(self.g, None, self.num_layers))

        dur = []  # per-epoch wall time (s), excluding the first two warm-up epochs
        train_losses = []  # one entry per epoch
        train_accuracies = []
        val_losses = []
        val_accuracies = []

        best_val_acc = -1
        best_val_result = (0, 0, 0, 0)

        # len(loader) is treated as a sample count (see last-batch size below).
        num_train_samples = len(self.train_loader)
        num_train_batches = (num_train_samples - 1) // self.batch_size + 1

        num_val_samples = len(self.val_loader)
        num_val_batches = (num_val_samples - 1) // self.batch_size + 1

        for e in range(self.epochs):
            epoch_start = time.time()

            ###########################################
            # minibatch train
            train_num_correct = 0  # batch-size-weighted sum of training accuracies
            train_total_losses = 0  # batch-size-weighted sum of BCE losses
            pred_chunks, label_chunks = [], []

            self.model.train()
            for step in range(num_train_batches):
                # every batch is full except possibly the last one
                if step != num_train_batches - 1:
                    batch_size = self.batch_size
                else:
                    batch_size = num_train_samples - step * self.batch_size

                logits, batch_labels = self._forward_edge_batch(next(train_dataloader))

                # collect thresholded predictions for epoch-level metrics
                pred = torch.sigmoid(logits) > 0.5
                pred_chunks.append(pred.long().cpu().detach().numpy())
                label_chunks.append(batch_labels.detach().cpu().numpy())

                # update step
                train_loss = F.binary_cross_entropy_with_logits(logits, batch_labels.float())
                self.optimizer.zero_grad()
                train_loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5)
                self.optimizer.step()

                mini_batch_accuracy = torch_accuracy(logits, batch_labels)
                train_num_correct += mini_batch_accuracy * batch_size
                train_total_losses += train_loss.item() * batch_size

            # loss and accuracy of this epoch (weighted by batch size)
            train_average_loss = train_total_losses / num_train_samples
            train_accuracy = train_num_correct / num_train_samples
            train_losses.append(train_average_loss)
            train_accuracies.append(train_accuracy)

            # train precision, recall, F1 score
            # NOTE(review): metrics are called as (pred, label); if these helpers
            # follow the (y_true, y_pred) convention, precision and recall are
            # swapped -- confirm their signatures.
            pred_temp = self._concat_flat(pred_chunks)
            label_temp = self._concat_flat(label_chunks)
            train_macro_precision = macro_precision(pred_temp, label_temp)
            train_macro_recall = macro_recall(pred_temp, label_temp)
            train_macro_f1 = macro_f1(pred_temp, label_temp)

            ###########################################
            # validation
            val_average_loss, val_accuracy, pred_temp, label_temp = self._evaluate_split(
                val_dataloader, num_val_batches)
            val_losses.append(val_average_loss)
            val_accuracies.append(val_accuracy)

            val_macro_precision = macro_precision(pred_temp, label_temp)
            val_macro_recall = macro_recall(pred_temp, label_temp)
            val_micro_f1 = micro_f1(pred_temp, label_temp)
            val_macro_f1 = macro_f1(pred_temp, label_temp)

            if val_accuracy > best_val_acc:
                best_val_result = (val_accuracy, val_macro_precision, val_macro_recall, val_macro_f1, val_micro_f1)
                best_val_acc = val_accuracy
                torch.save(self.model.state_dict(), self.checkpoint_path)

            # Record THIS epoch's duration. The timer used to start once before
            # the loop, so "Time(s)" was a mean of cumulative elapsed times.
            # The first two epochs are warm-up and are excluded. Until a timed
            # epoch exists, report NaN explicitly instead of taking np.mean of
            # an empty list (which warns).
            if e >= 2:
                dur.append(time.time() - epoch_start)
            mean_epoch_time = float(np.mean(dur)) if dur else float('nan')

            logging.info("Epoch {:05d} | Time(s) {:.4f} | \n"
                "TrainLoss {:.4f} | TrainAcc {:.4f} | TrainPrecision {:.4f} | TrainRecall {:.4f} | TrainMacroF1 {:.4f}\n"
                "ValLoss {:.4f}   | ValAcc {:.4f}   | ValPrecision {:.4f}   | ValRecall {:.4f}   | ValMacroF1 {:.4f}".
                format(e, mean_epoch_time,
                       train_average_loss, train_accuracy, train_macro_precision, train_macro_recall, train_macro_f1,
                       val_average_loss, val_accuracy, val_macro_precision, val_macro_recall, val_macro_f1))

            ### Early stopping
            self.early_stopping(val_accuracy, self.model)
            if self.early_stopping.early_stop:
                logging.info("Early stopping")
                break

        ### best validation result
        logging.info(
            'Best val result: ValAcc {:.4f}   | ValPrecision {:.4f}    | ValRecall {:.4f}   | ValMacroF1 {:.4f}\n'
            .format(best_val_result[0], best_val_result[1], best_val_result[2], best_val_result[3]))

        ###########################################
        # testing
        # NOTE(review): this evaluates the final-epoch weights, not the
        # checkpoint saved at the best validation accuracy -- confirm intent.
        num_test_samples = len(self.test_loader)
        num_test_batches = (num_test_samples - 1) // self.batch_size + 1

        test_average_loss, test_accuracy, pred_temp, label_temp = self._evaluate_split(
            test_dataloader, num_test_batches)

        test_macro_precision = macro_precision(pred_temp, label_temp)
        test_macro_recall = macro_recall(pred_temp, label_temp)
        test_macro_f1 = macro_f1(pred_temp, label_temp)

        logging.info("Finishing training...\n"
                "TestLoss {:.4f} | TestAcc {:.4f} | TestPrecision {:.4f} | TestRecall {:.4f} | TestMacroF1 {:.4f}\n".
                format(test_average_loss, test_accuracy, test_macro_precision, test_macro_recall, test_macro_f1))

        self.plot(train_losses, val_losses, train_accuracies, val_accuracies)

        return best_val_result, (test_average_loss, test_accuracy, test_macro_precision, test_macro_recall, test_macro_f1)

    def _forward_edge_batch(self, data):
        """Move one sampled mini-batch to ``self.device`` and compute its edge logits.

        Args:
            data: ``(input_nodes, blocks, batch_edges, batch_labels)`` as yielded
                by a loader's ``run`` iterator.

        Returns:
            ``(logits, batch_labels)``, with ``batch_labels`` on ``self.device``.
        """
        input_nodes, blocks, batch_edges, batch_labels = data

        input_nodes = input_nodes.to(self.device)
        blocks = [block.int().to(self.device) for block in blocks]
        batch_edges = batch_edges.to(self.device)
        batch_labels = batch_labels.to(self.device)

        node_embedding = self.model(input_nodes, blocks)
        logits = self.compute_logits(node_embedding, batch_edges)
        return logits, batch_labels

    def _evaluate_split(self, dataloader, num_batches):
        """Evaluate ``self.model`` on ``num_batches`` batches drawn from ``dataloader``.

        Runs in eval mode under ``torch.no_grad()``.

        Returns:
            ``(average_loss, average_accuracy, preds, labels)``. The loss and the
            ``torch_accuracy`` value are plain means over batches, not weighted
            by batch size. ``preds`` and ``labels`` are flat float64 arrays:
            predictions thresholded at sigmoid > 0.5, and the matching targets.
        """
        total_losses = 0
        num_correct = 0  # sum of per-batch accuracies
        pred_chunks, label_chunks = [], []

        self.model.eval()
        with torch.no_grad():
            for _ in range(num_batches):
                logits, batch_labels = self._forward_edge_batch(next(dataloader))

                loss = F.binary_cross_entropy_with_logits(logits, batch_labels.float())
                num_correct += torch_accuracy(logits, batch_labels)
                total_losses += loss.cpu().item()

                pred = torch.sigmoid(logits) > 0.5
                pred_chunks.append(pred.long().cpu().numpy())
                label_chunks.append(batch_labels.cpu().numpy())

        average_loss = total_losses / num_batches
        average_accuracy = num_correct / num_batches
        return (average_loss, average_accuracy,
                self._concat_flat(pred_chunks), self._concat_flat(label_chunks))

    @staticmethod
    def _concat_flat(chunks):
        """Concatenate per-batch arrays into one flat float64 array.

        This gives the same result as repeatedly calling ``np.append`` starting
        from ``np.array([])``: ``np.append`` flattens its inputs and promotes to
        float64. Building a list and concatenating once is linear instead of
        quadratic.
        """
        if not chunks:
            return np.array([])
        return np.concatenate([np.ravel(chunk) for chunk in chunks]).astype(np.float64)