Example #1
    def evaluate(self):
        # Put the classifier (and the coreference model, when one is used) in eval mode.
        self.clf.eval()
        if self.config[MODEL_TYPE] > 1:
            self.coref_trainer.model.eval()
        # Run the parser over the validation trees with gradient tracking disabled.
        with torch.no_grad():
            evaluator = Evaluator(self, self.data_helper, self.config)
            evaluator.eval_parser(self.data_helper.val_trees)
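This snippet follows the standard PyTorch evaluation recipe: switch the modules to eval mode, then disable gradient tracking before running inference. A minimal stand-alone sketch of that recipe (the two-layer model is purely illustrative, not part of the codebase above):

import torch
import torch.nn as nn

# Illustrative stand-in for self.clf; any nn.Module behaves the same way.
model = nn.Sequential(nn.Linear(8, 4), nn.Dropout(p=0.5), nn.Linear(4, 2))

model.eval()              # dropout is disabled, batch norm would use running stats
with torch.no_grad():     # no gradient bookkeeping during inference
    scores = model(torch.randn(3, 8))
print(scores.shape)       # torch.Size([3, 2])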
Example #2
    def train_classifier(self, train_loader, dev_loader):
        """Train the classifier on train_loader, evaluate on dev_loader after
        every epoch, and save a checkpoint per epoch."""
        self.optim = Optimizer(self.clf.parameters(), lr=self.config[LR])
        # Resume from a saved checkpoint when training does not start at epoch 1.
        if self.config[EPOCH_START] != 1:
            self.load('../data/model/' + self.config[MODEL_NAME] + "_" +
                      str(self.config[EPOCH_START]))

        for epoch in range(1, self.config[NUM_EPOCHS] + 1):
            cost_acc = 0
            self.clf.train()
            print("============ epoch: ", epoch, " ============")
            for i, data in enumerate(train_loader):
                docs, gold_actions = data
                # sr_parse returns a tuple; its second element is the cost for this step.
                cost_acc += self.sr_parse(docs, gold_actions, self.optim)[1]
                if i % 50 == 0:
                    print("Cost on step ", i, "is ", cost_acc)

            print("Total cost for epoch ", epoch, "is ", cost_acc)
            # Evaluate on the dev set with gradients disabled.
            self.clf.eval()
            with torch.no_grad():
                evaluator = Evaluator(self, self.clf.data_helper)
                evaluator.eval_parser(dev_loader, path=None)
            self.save('../data/model/',
                      self.config[MODEL_NAME] + "_" + str(epoch), epoch)
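The loop above unpacks each batch as docs, gold_actions, so train_loader and dev_loader are expected to yield such pairs. A hypothetical sketch of a loader with that shape using torch.utils.data; the toy dataset, the action labels, and the collate function are illustrative assumptions, not part of the original code:

from torch.utils.data import DataLoader, Dataset

class ToyParsingDataset(Dataset):
    """Hypothetical dataset where each item is a (doc, gold_actions) pair."""
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

items = [(f"doc-{i}", ["shift", "reduce"]) for i in range(4)]
# collate_fn returns the single pair untouched, so each iteration yields one
# (doc, gold_actions) tuple, matching the unpacking in train_classifier.
loader = DataLoader(ToyParsingDataset(items), batch_size=1,
                    collate_fn=lambda batch: batch[0])
for docs, gold_actions in loader:
    print(docs, gold_actions)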
Example #3
                                    parse_type=args.parse_type,
                                    isFlat=args.isFlat)
        train_model(data_helper)
    if args.eval:
        # Evaluate models on the RST-DT test set
        model_subdir = "RN~model" if args.isFlat else "N~model"
        evaluator = Evaluator(isFlat=args.isFlat,
                              model_dir=os.path.join(args.output_dir,
                                                     model_subdir))
        evaluator.eval_parser(data_dir=args.data_dir,
                              output_dir=args.output_dir,
                              report=True,
                              bcvocab=brown_clusters,
                              draw=False,
                              isFlat=args.isFlat)

    if args.pred:
        # Run the parser in prediction mode
        model_subdir = "RN~model" if args.isFlat else "N~model"
        evaluator = Evaluator(isFlat=args.isFlat,
                              model_dir=os.path.join(args.output_dir,
                                                     model_subdir))
        print("predicting")
        evaluator.pred_parser(output_dir=args.output_dir,
                              parse_type=args.parse_type,
Example #4
    train_dirname = (args.train_dir[:-1] if args.train_dir[-1] == os.sep else
                     args.train_dir).split(os.sep)[-1]
    HELPER_PATH = f"..{os.sep}data{os.sep}{train_dirname}_data_helper_rst.bin"
    print("Helper path:", HELPER_PATH)

    if args.prepare:
        # Create training data
        #coref_model = CorefScore(higher_order=True).to(config[DEVICE])
        coref_model = CorefScore().to(config[DEVICE])

        coref_trainer = Trainer(coref_model, [], [], [], debug=False)

        data_helper.create_data_helper(args.train_dir, config, coref_trainer)
        data_helper.save_data_helper(HELPER_PATH)

    if args.train:
        train_model_coref(data_helper, config)

    if args.eval:
        # Evaluate models on the RST-DT test set
        data_helper.load_data_helper(HELPER_PATH)

        parser = get_discourse_parser(data_helper, config)
        parser.load('../data/model/' + config[MODEL_NAME])
        print("Evaluating")
        with torch.no_grad():
            evaluator = Evaluator(parser, data_helper, config)
            evaluator.eval_parser(None,
                                  path=args.eval_dir,
                                  use_parseval=args.use_parseval)
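The train_dirname expression at the top of this example (drop a trailing separator, split on os.sep, keep the last component) can also be written with the standard library. A close equivalent for typical paths, shown with illustrative directory names:

import os

def train_dir_name(train_dir):
    # normpath drops trailing separators and basename keeps the last component,
    # which matches the strip-then-split expression above for ordinary paths.
    return os.path.basename(os.path.normpath(train_dir))

print(train_dir_name(os.path.join("..", "data", "rst_train") + os.sep))  # rst_train
print(train_dir_name(os.path.join("..", "data", "rst_train")))           # rst_train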
Example #5
    parser.add_argument('--eval_dir', help='eval data directory')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    # Use brown clusters
    with gzip.open("../data/resources/bc3200.pickle.gz") as fin:
        print('Load Brown clusters for creating features ...')
        brown_clusters = pickle.load(fin)
    data_helper = DataHelper(max_action_feat_num=330000,
                             max_relation_feat_num=300000,
                             min_action_feat_occur=1,
                             min_relation_feat_occur=1,
                             brown_clusters=brown_clusters)
    if args.prepare:
        # Create training data
        data_helper.create_data_helper(data_dir=args.train_dir)
        data_helper.save_data_helper('../data/data_helper.bin')
    if args.train:
        data_helper.load_data_helper('../data/data_helper.bin')
        data_helper.load_train_data(data_dir=args.train_dir)
        train_model(data_helper)
    if args.eval:
        # Evaluate models on the RST-DT test set
        evaluator = Evaluator(model_dir='../data/model')
        evaluator.eval_parser(path=args.eval_dir,
                              report=True,
                              bcvocab=brown_clusters,
                              draw=False)
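Only the tail of parse_args() is visible in these snippets. Judging from the flags the main block reads (--prepare, --train, --eval, --train_dir, --eval_dir), a hypothetical reconstruction could look like the sketch below; the description, the extra help strings, and any options beyond those flags are assumptions:

import argparse

def parse_args():
    parser = argparse.ArgumentParser(
        description="Prepare data, train, and evaluate the discourse parser.")
    parser.add_argument('--prepare', action='store_true',
                        help='create the training data / data helper')
    parser.add_argument('--train', action='store_true',
                        help='train the parsing model')
    parser.add_argument('--eval', action='store_true',
                        help='evaluate on the RST-DT test set')
    parser.add_argument('--train_dir', help='training data directory')
    parser.add_argument('--eval_dir', help='eval data directory')
    return parser.parse_args()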