예제 #1
0
 def test_AccuaryMetric8(self):
     # (8) Map the metric's `pred`/`target` arguments onto the custom keys
     # "predictions"/"targets" and evaluate all-zero inputs, expecting acc == 1.
     # NOTE(review): as written this test cannot fail. An AssertionError from
     # assertDictEqual is an Exception and is swallowed by the handler below,
     # and the `else` line asserts nothing: True is the condition, False is the
     # message, and the trailing string is a discarded tuple element.
     # TODO decide whether an exception is expected here, then rewrite with
     # assertRaises or without the try/except.
     try:
         acc_metric = AccuracyMetric(pred='predictions', target='targets')
         acc_metric(
             pred_dict={"predictions": torch.zeros(4, 3, 2)},
             target_dict={'targets': torch.zeros(4, 3)},
         )
         self.assertDictEqual(acc_metric.get_metric(), {'acc': 1})
     except Exception as err:
         print(err)
     else:
         self.assertTrue(True, False), "No exception catches."
예제 #2
0
 def test_AccuaryMetric10(self):
     # (10) check _fast_metric: evaluate with an extra "seq_len" entry in the
     # prediction dict, expecting acc == 1.
     # NOTE(review): "seq_len" has 3 entries while the predictions' first
     # dimension is 4, and the default-constructed metric is handed
     # "predictions"/"targets" keys. Confirm whether an exception is the
     # intended outcome. As written the test cannot fail: an AssertionError
     # from assertDictEqual is swallowed by `except Exception`, and the
     # `else` line asserts nothing.
     try:
         acc_metric = AccuracyMetric()
         acc_metric(
             pred_dict={
                 "predictions": torch.zeros(4, 3, 2),
                 "seq_len": torch.ones(3) * 3,
             },
             target_dict={'targets': torch.zeros(4, 3)},
         )
         self.assertDictEqual(acc_metric.get_metric(), {'acc': 1})
     except Exception as err:
         print(err)
     else:
         self.assertTrue(True, False), "No exception catches."
예제 #3
0
    def test_AccuracyMetric2(self):
        # (2) with corrupted size: pred has shape (4, 3, 2) but target has
        # shape (4,), so evaluating this batch must raise.
        # assertRaises turns a missing exception into a real test failure;
        # merely printing "No exception catches." would let the test pass.
        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4)}
        # Built outside the guarded region so a construction error surfaces
        # as a test error instead of counting as the expected exception.
        metric = AccuracyMetric()

        # `Exception` mirrors the original broad handler; the concrete type
        # raised for a shape mismatch is not visible from this file.
        with self.assertRaises(Exception, msg="No exception catches.") as ctx:
            metric(
                pred_dict=pred_dict,
                target_dict=target_dict,
            )
        print(ctx.exception)
예제 #4
0
    def test_AccuracyMetric3(self):
        # (3) the second batch is corrupted size.
        # The previous trailing `self.assertTrue(True, False), "..."` was a
        # no-op tuple expression, so the test could never fail. assertRaises
        # now requires the corrupted batch to raise.
        metric = AccuracyMetric()

        # First batch: pred (4, 3, 2) vs target (4, 3). This is the
        # well-formed case and is assumed to be accepted; it runs outside
        # the guard so an unexpected failure here is reported as an error.
        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)

        # Second batch: target (4,) no longer matches pred, so it must raise.
        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4)}
        with self.assertRaises(Exception, msg="No exception catches.") as ctx:
            metric(pred_dict=pred_dict, target_dict=target_dict)
        print(ctx.exception)
예제 #5
0
def _evaluate_accuracy(model, dataset, batch_size=8):
    """Run ``model.predict`` over ``dataset`` and return the accuracy result.

    The model is put into eval mode, and gradient tracking is disabled for
    the pass. The model's previous train/eval mode is restored afterwards,
    even if evaluation raises.

    :param model: model exposing ``predict(word_pieces, word_nums, seq_len)``.
    :param dataset: DataSet with "word_pieces", "word_nums" and "seq_len"
        inputs and a "target" target, as prepared in :func:`train`.
    :param batch_size: evaluation batch size.
    :return: the value of ``AccuracyMetric().get_metric()`` after all batches.
    """
    from fastNLP import AccuracyMetric, DataSetIter

    metric = AccuracyMetric()
    batches = DataSetIter(batch_size=batch_size, dataset=dataset, sampler=None)
    device = _get_model_device(model)
    was_training = model.training
    # eval() switches off training-only behaviour (e.g. dropout layers);
    # no_grad() avoids building an autograd graph that is never used.
    model.eval()
    try:
        with torch.no_grad():
            for batch_x, batch_y in batches:
                _move_dict_value_to_device(batch_x, batch_y, device=device)
                output = model.predict(batch_x["word_pieces"],
                                       batch_x["word_nums"],
                                       batch_x["seq_len"])
                metric(output, batch_y)
    finally:
        model.train(was_training)
    return metric.get_metric()


def train():
    """Fine-tune BertForSentenceMatching on the triple dataset, then test it.

    Loads '../models/all4bert_new_triple.txt' and splits it with
    ``split(0.2)`` into a training part and a held-out part. The held-out
    part is split with ``split(0.5)`` into validation and test sets. The
    function then derives the input/target fields and trains with Adam for
    ``n_epochs`` epochs. After every epoch it prints validation accuracy and
    saves the whole model to '../models/bert_model_max_triple.pkl'. Finally
    it prints test accuracy and saves the model again.

    NOTE(review): relies on module-level names not visible here
    (``data_set_loader``, ``embed``, ``addWords``, ``addWordPiece``,
    ``processItem``, ``processNum``, ``addSeqlen``, ``processTarget``, and
    the ``_move_*`` / ``_get_model_device`` helpers). Confirm they are
    defined or imported in this module.
    """
    from fastNLP.models.Mybert import BertForSentenceMatching
    from fastNLP import DataSetIter
    from fastNLP.core.utils import _pseudo_tqdm as tqdm
    from fastNLP.io import ModelSaver

    n_epochs = 10
    train_set = data_set_loader._load('../models/all4bert_new_triple.txt')
    train_set, tmp_set = train_set.split(0.2)
    val_set, test_set = tmp_set.split(0.5)
    data_bundle = [train_set, val_set, test_set]

    for dataset in data_bundle:
        dataset.apply(addWords, new_field_name="p_words")
        dataset.apply(addWordPiece, new_field_name="t_words")
        dataset.apply(processItem, new_field_name="word_pieces")
        dataset.apply(processNum, new_field_name="word_nums")
        dataset.apply(addSeqlen, new_field_name="seq_len")
        dataset.apply(processTarget, new_field_name="target")

    for dataset in data_bundle:
        dataset.field_arrays["word_pieces"].is_input = True
        dataset.field_arrays["seq_len"].is_input = True
        dataset.field_arrays["word_nums"].is_input = True
        dataset.field_arrays["target"].is_target = True

    print("In total " + str(len(data_bundle)) + " datasets:")
    print("Trainset has " + str(len(train_set)) + " instances.")
    print("Validateset has " + str(len(val_set)) + " instances.")
    print("Testset has " + str(len(test_set)) + " instances.")
    train_set.print_field_meta()

    # The second argument is the number of target classes.
    model = BertForSentenceMatching(embed, 3)
    if torch.cuda.is_available():
        model = _move_model_to_device(model, device=0)
    # The model does not move after this point, so look the device up once.
    device = _get_model_device(model)

    train_batch = DataSetIter(batch_size=16, dataset=train_set, sampler=None)
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
    loss_fn = torch.nn.CrossEntropyLoss()
    saver = ModelSaver("../models/bert_model_max_triple.pkl")
    print_every = 10

    # NOTE(review): `total` counts epochs, while update() below advances by
    # training steps. Confirm the intended progress-bar unit.
    with tqdm(total=n_epochs,
              postfix='loss:{0:<6.5f}',
              leave=False,
              dynamic_ncols=True) as pbar:
        for epoch in range(1, n_epochs + 1):
            pbar.set_description_str(
                desc="Epoch {}/{}".format(epoch, n_epochs))
            model.train()
            avg_loss = 0
            for step, (batch_x, batch_y) in enumerate(train_batch, start=1):
                _move_dict_value_to_device(batch_x, batch_y, device=device)
                optimizer.zero_grad()
                # Call the module (not .forward) so registered hooks run.
                output = model(batch_x["word_pieces"],
                               batch_x["word_nums"],
                               batch_x["seq_len"])
                loss = loss_fn(output['pred'], batch_y['target'])
                loss.backward()
                optimizer.step()
                avg_loss += loss.item()
                if step % print_every == 0:
                    avg_loss = float(avg_loss) / print_every
                    print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6}".format(
                        epoch, step, avg_loss)
                    pbar.update(print_every)
                    pbar.set_postfix_str(print_output)
                    avg_loss = 0
            eval_result = _evaluate_accuracy(model, val_set)
            print("ACC on Validate Set:", eval_result)
            # NOTE(review): the file name says "max", but the latest epoch
            # is always saved, not the best one. Confirm the intent.
            saver.save_pytorch(model, param_only=False)

    eval_result = _evaluate_accuracy(model, test_set)
    print("ACC on Test Set:", eval_result)
    saver.save_pytorch(model, param_only=False)