示例#1
0
    def eval(self, dataset_path: str, types_path: str, input_reader_cls: BaseInputReader):
        """Evaluate the model on a dataset, streaming it in fixed-size batches.

        Args:
            dataset_path: Path to the dataset file (one JSON document per line).
            types_path: Path to the entity/relation type specification file.
            input_reader_cls: Reader class used to parse documents and types.
        """
        args = self.args
        dataset_label = 'test'

        self._logger.info("Dataset: %s" % dataset_path)
        self._logger.info("Model: %s" % args.model_type)

        # create log csv files
        self._init_eval_logging(dataset_label)

        batch_size = 32
        model = None  # loaded lazily on the first batch, then reused

        def _eval_batch(documents):
            # Parse the accumulated raw JSON lines and evaluate on them.
            nonlocal model
            input_reader = input_reader_cls(types_path, self._tokenizer,
                                            max_span_size=args.max_span_size, logger=self._logger)
            input_reader.dump_dataset(dataset_label, documents)
            self._log_datasets(input_reader)

            # create model once (the original reloaded the checkpoint from
            # disk for every batch; the weights are identical each time)
            if model is None:
                model_class = models.get_model(self.args.model_type)

                config = PhobertConfig.from_pretrained(self.args.model_path, cache_dir=self.args.cache_path)
                util.check_version(config, model_class, self.args.model_path)

                model = model_class.from_pretrained(self.args.model_path,
                                                    config=self.args.config_path,
                                                    cache_dir=self.args.cache_path,
                                                    # SpERT model parameters
                                                    cls_token=0,
                                                    relation_types=input_reader.relation_type_count - 1,
                                                    entity_types=input_reader.entity_type_count,
                                                    max_pairs=self.args.max_pairs,
                                                    prop_drop=self.args.prop_drop,
                                                    size_embedding=self.args.size_embedding,
                                                    freeze_transformer=self.args.freeze_transformer)
                model.to(self._device)

            # evaluate
            self._eval(model, input_reader.get_dataset(dataset_label), input_reader)
            self._logger.info("Logged in: %s" % self._log_path)

        # read dataset and evaluate batch by batch
        with open(dataset_path, 'r') as json_file:
            documents = []
            for line in json_file:
                documents.append(line)
                if len(documents) == batch_size:
                    _eval_batch(documents)
                    documents = []

            # FIX: the original dropped the trailing partial batch — any
            # documents left over when the line count is not a multiple of
            # batch_size were never evaluated.
            if documents:
                _eval_batch(documents)

        self._close_summary_writer()
示例#2
0
    def eval(self, dataset_path: str, types_path: str,
             input_reader_cls: BaseInputReader):
        """Run a single evaluation pass over the dataset at ``dataset_path``.

        Args:
            dataset_path: Path to the evaluation dataset file.
            types_path: Path to the entity/relation type specification file.
            input_reader_cls: Reader class used to parse documents and types.
        """
        args = self.args
        dataset_label = 'test'

        self._logger.info("Dataset: %s" % dataset_path)
        self._logger.info("Model: %s" % args.model_type)

        # set up per-run csv log files
        self._init_eval_logging(dataset_label)

        # parse the type specification and the dataset
        reader = input_reader_cls(types_path,
                                  self._tokenizer,
                                  max_span_size=args.max_span_size,
                                  logger=self._logger)
        reader.read({dataset_label: dataset_path})
        self._log_datasets(reader)

        # resolve the model class and verify checkpoint compatibility
        model_class = models.get_model(args.model_type)

        config = BertConfig.from_pretrained(args.model_path,
                                            cache_dir=args.cache_path)
        util.check_version(config, model_class, args.model_path)

        # load the pretrained weights
        model = model_class.from_pretrained(
            args.model_path,
            config=config,
            # SpERT model parameters
            cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
            relation_types=reader.relation_type_count - 1,
            entity_types=reader.entity_type_count,
            max_pairs=args.max_pairs,
            prop_drop=args.prop_drop,
            size_embedding=args.size_embedding,
            freeze_transformer=args.freeze_transformer,
            cache_dir=args.cache_path)
        model.to(self._device)

        # run evaluation and flush logs
        self._eval(model, reader.get_dataset(dataset_label), reader)

        self._logger.info("Logged in: %s" % self._log_path)
        self._close_summary_writer()
示例#3
0
    def _load_model(self, input_reader):
        """Instantiate the SpERT model from the configured pretrained checkpoint.

        Entity/relation type counts are taken from ``input_reader``; the
        checkpoint's version is validated against the model class before
        loading.
        """
        args = self._args
        model_class = models.get_model(args.model_type)

        # load the checkpoint config and verify version compatibility
        config = BertConfig.from_pretrained(args.model_path, cache_dir=args.cache_path)
        util.check_version(config, model_class, args.model_path)
        config.spert_version = model_class.VERSION

        return model_class.from_pretrained(
            args.model_path,
            config=config,
            # SpERT model parameters
            cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
            # no dedicated output for the 'none' relation class, hence - 1
            relation_types=input_reader.relation_type_count - 1,
            entity_types=input_reader.entity_type_count,
            max_pairs=args.max_pairs,
            prop_drop=args.prop_drop,
            size_embedding=args.size_embedding,
            freeze_transformer=args.freeze_transformer,
            cache_dir=args.cache_path)
示例#4
0
    def eval(self, dataset_path: str, types_path: str, input_reader_cls: BaseInputReader):
        """Evaluate the pretrained model on the dataset at ``dataset_path``.

        Args:
            dataset_path: Path to the evaluation dataset file.
            types_path: Path to the entity/relation type specification file.
            input_reader_cls: Reader class used to parse documents and types.
        """
        args = self.args
        dataset_label = 'test'

        self._logger.info("Dataset: %s" % dataset_path)
        self._logger.info("Model: %s" % args.model_type)

        # csv log files for this evaluation run
        self._init_eval_logging(dataset_label)

        # parse types and documents
        reader = input_reader_cls(types_path, self._tokenizer, self._logger)
        reader.read({dataset_label: dataset_path})
        self._log_datasets(reader)

        # resolve the model class and load pretrained weights
        model_class = models.get_model(args.model_type)
        model = model_class.from_pretrained(
            args.model_path,
            cache_dir=args.cache_path,
            # additional model parameters
            cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
            # no node for 'none' class
            relation_types=reader.relation_type_count - 1,
            entity_types=reader.entity_type_count,
            max_pairs=args.max_pairs,
            prop_drop=args.prop_drop,
            size_embedding=args.size_embedding,
            freeze_transformer=args.freeze_transformer)
        model.to(self._device)

        # run evaluation
        self._eval(model, reader.get_dataset(dataset_label), reader)
        self._logger.info("Logged in: %s" % self._log_path)

        # wait for the sampler to finish (presumably background workers — confirm)
        self._sampler.join()
示例#5
0
    def train(self, train_path: str, valid_path: str, types_path: str,
              input_reader_cls: BaseInputReader):
        """Train SpERT on `train_path` and evaluate on `valid_path`.

        Reads both datasets, loads the pretrained checkpoint, trains for
        `args.epochs` epochs with a linear warmup/decay schedule, evaluates
        the validation set, and saves the final model.

        Args:
            train_path: Path to the training dataset file.
            valid_path: Path to the validation dataset file.
            types_path: Path to the entity/relation type specification file.
            input_reader_cls: Reader class used to parse documents and types.
        """
        args = self.args
        train_label, valid_label = 'train', 'valid'

        self._logger.info("Datasets: %s, %s" % (train_path, valid_path))
        self._logger.info("Model type: %s" % args.model_type)

        # create log csv files
        self._init_train_logging(train_label)
        self._init_eval_logging(valid_label)

        # read datasets (neg_entity_count / neg_relation_count presumably
        # control negative sampling in the reader — confirm against reader)
        input_reader = input_reader_cls(types_path, self._tokenizer,
                                        args.neg_entity_count,
                                        args.neg_relation_count,
                                        args.max_span_size, self._logger)
        input_reader.read({train_label: train_path, valid_label: valid_path})
        self._log_datasets(input_reader)

        train_dataset = input_reader.get_dataset(train_label)
        train_sample_count = train_dataset.document_count
        # optimizer updates per epoch / in total; floor division drops any
        # final partial batch from the count
        updates_epoch = train_sample_count // args.train_batch_size
        updates_total = updates_epoch * args.epochs

        validation_dataset = input_reader.get_dataset(valid_label)

        self._logger.info("Updates per epoch: %s" % updates_epoch)
        self._logger.info("Updates total: %s" % updates_total)

        # create model
        model_class = models.get_model(self.args.model_type)

        # load checkpoint config and verify it matches this model class
        config = BertConfig.from_pretrained(self.args.model_path,
                                            cache_dir=self.args.cache_path)
        util.check_version(config, model_class, self.args.model_path)

        config.spert_version = model_class.VERSION
        model = model_class.from_pretrained(
            self.args.model_path,
            config=config,
            # SpERT model parameters
            cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
            # no dedicated output for the 'none' relation class, hence - 1
            relation_types=input_reader.relation_type_count - 1,
            entity_types=input_reader.entity_type_count,
            max_pairs=self.args.max_pairs,
            prop_drop=self.args.prop_drop,
            size_embedding=self.args.size_embedding,
            freeze_transformer=self.args.freeze_transformer,
            cache_dir=self.args.cache_path)

        # SpERT is currently optimized on a single GPU and not thoroughly tested in a multi GPU setup
        # If you still want to train SpERT on multiple GPUs, uncomment the following lines
        # # parallelize model
        # if self._device.type != 'cpu':
        #     model = torch.nn.DataParallel(model)

        model.to(self._device)

        # create optimizer
        optimizer_params = self._get_optimizer_params(model)
        optimizer = AdamW(optimizer_params,
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          correct_bias=False)
        # create scheduler: linear warmup over the first lr_warmup fraction
        # of updates, then linear decay to zero
        scheduler = transformers.get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.lr_warmup * updates_total,
            num_training_steps=updates_total)
        # create loss function (reduction='none' keeps per-element losses;
        # SpERTLoss handles reduction and the optimizer/scheduler steps)
        rel_criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
        entity_criterion = torch.nn.CrossEntropyLoss(reduction='none')
        compute_loss = SpERTLoss(rel_criterion, entity_criterion, model,
                                 optimizer, scheduler, args.max_grad_norm)

        # eval validation set before any training, if requested
        if args.init_eval:
            self._eval(model, validation_dataset, input_reader, 0,
                       updates_epoch)

        # train
        for epoch in range(args.epochs):
            # train epoch
            self._train_epoch(model, compute_loss, optimizer, train_dataset,
                              updates_epoch, epoch)

            # eval validation set every epoch, or only after the last epoch
            # when final_eval is set
            if not args.final_eval or (epoch == args.epochs - 1):
                self._eval(model, validation_dataset, input_reader, epoch + 1,
                           updates_epoch)

        # save final model
        extra = dict(epoch=args.epochs,
                     updates_epoch=updates_epoch,
                     epoch_iteration=0)
        global_iteration = args.epochs * updates_epoch
        self._save_model(
            self._save_path,
            model,
            self._tokenizer,
            global_iteration,
            optimizer=optimizer if self.args.save_optimizer else None,
            extra=extra,
            include_iteration=False,
            name='final_model')

        self._logger.info("Logged in: %s" % self._log_path)
        self._logger.info("Saved in: %s" % self._save_path)
        self._close_summary_writer()
示例#6
0
    def predict(self, dataset_path: str, types_path: str, input_reader_cls: BaseInputReader):
        """Run prediction over the dataset, streaming it in batches of
        ``args.predict_batch_size`` documents.

        The model checkpoint is loaded once, on the first batch, and reused
        for all subsequent batches.

        Args:
            dataset_path: Path to the dataset file (one JSON document per line).
            types_path: Path to the entity/relation type specification file.
            input_reader_cls: Reader class used to parse documents and types.
        """
        args = self.args
        dataset_label = 'predict'

        self._logger.info("Dataset: %s" % dataset_path)
        self._logger.info("Model: %s" % args.model_type)

        model_class = models.get_model(self.args.model_type)
        model = None  # loaded lazily on the first batch, then reused

        def _predict_batch(documents):
            # Parse the accumulated raw JSON lines and predict on them.
            # FIX: the original duplicated this whole block in a for/else
            # clause; it is now shared between full and trailing batches.
            nonlocal model
            input_reader = input_reader_cls(types_path, self._tokenizer,
                                            max_span_size=args.max_span_size, logger=self._logger)
            input_reader.dump_dataset(dataset_label, documents)

            # create model once
            if model is None:
                model = model_class.from_pretrained(self.args.model_path,
                                                    config=self.args.config_path,
                                                    cache_dir=self.args.cache_path,
                                                    # SpERT model parameters
                                                    cls_token=0,
                                                    relation_types=input_reader.relation_type_count - 1,
                                                    entity_types=input_reader.entity_type_count,
                                                    max_pairs=self.args.max_pairs,
                                                    prop_drop=self.args.prop_drop,
                                                    size_embedding=self.args.size_embedding,
                                                    freeze_transformer=self.args.freeze_transformer)
                model.to(self._device)

            # predict
            self._predict(model, input_reader.get_dataset(dataset_label), input_reader)

        # read dataset and predict batch by batch
        batch_size = self.args.predict_batch_size
        with open(dataset_path, 'r') as json_file:
            documents = []
            for line in json_file:
                documents.append(line)
                if len(documents) == batch_size:
                    _predict_batch(documents)
                    documents = []

            # FIX: removed the dead `if not line: break` (iterating a text
            # file never yields an empty string, so the break never fired).
            # The trailing partial batch is handled here; an empty batch is
            # skipped instead of being dumped and predicted on.
            if documents:
                _predict_batch(documents)

        self._logger.info("Logged in: %s" % self._log_path)
        self._close_summary_writer()