Example 1
    def initialize(self, resource: Resources, configs: HParams):

        self.resource = resource

        self.word_alphabet = resource.get("word_alphabet")
        self.char_alphabet = resource.get("char_alphabet")
        self.ner_alphabet = resource.get("ner_alphabet")

        word_embedding_table = resource.get('word_embedding_table')

        self.config_model = configs.config_model
        self.config_data = configs.config_data

        self.normalize_func = utils.normalize_digit_word

        self.device = torch.device("cuda") if torch.cuda.is_available() \
            else torch.device("cpu")

        utils.set_random_seed(self.config_model.random_seed)

        self.model = BiRecurrentConvCRF(
            word_embedding_table, self.char_alphabet.size(),
            self.ner_alphabet.size(), self.config_model).to(device=self.device)

        self.optim = SGD(self.model.parameters(),
                         lr=self.config_model.learning_rate,
                         momentum=self.config_model.momentum,
                         nesterov=True)

        self.trained_epochs = 0

        self.resource.update(model=self.model)
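
The pattern above (pick a device, fix the random seed, build the model, move it to the device, and create the optimizer) can be reproduced in isolation. A minimal self-contained sketch with a toy module standing in for BiRecurrentConvCRF; the seed and hyperparameter values are illustrative, not taken from the config objects used above:

import torch
from torch import nn
from torch.optim import SGD

torch.manual_seed(42)  # illustrative fixed seed for reproducible runs

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Toy stand-in for BiRecurrentConvCRF; any nn.Module is moved to the device the same way.
model = nn.Linear(100, 10).to(device)

optim = SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)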
Example 2
            def load_model(path):
                model = BiRecurrentConvCRF(word_embedding_table,
                                           self.char_alphabet.size(),
                                           self.ner_alphabet.size(),
                                           self.config_model)

                if os.path.exists(path):
                    with open(path, "rb") as f:
                        weights = torch.load(f, map_location=self.device)
                        model.load_state_dict(weights)
                return model
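
The helper above only loads weights when the checkpoint file exists and relies on `map_location` to move tensors saved on another device onto the current one. A self-contained sketch of the same pattern, with a hypothetical helper name, toy model, and checkpoint path:

import os

import torch
from torch import nn


def load_weights(model: nn.Module, path: str, device: torch.device) -> nn.Module:
    """Load a saved state dict into ``model`` if ``path`` exists; otherwise return it unchanged."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            # map_location remaps tensors saved on another device (e.g. a GPU) onto ``device``.
            weights = torch.load(f, map_location=device)
        model.load_state_dict(weights)
    return model


# Hypothetical usage with a toy model and checkpoint path.
toy_model = nn.Linear(4, 2)
toy_model = load_weights(toy_model, "checkpoints/ner_model.pt", torch.device("cpu"))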
Example 3
    def initialize(self, resources: Resources, configs: Config):
        """
        The training pipeline will run this initialization method during
        the initialization phase and send resources in as parameters.

        Args:
            resources: The resources shared in the pipeline.
            configs: Configuration object for this trainer.
        """
        self.resource = resources

        self.word_alphabet = resources.get("word_alphabet")
        self.char_alphabet = resources.get("char_alphabet")
        self.ner_alphabet = resources.get("ner_alphabet")

        word_embedding_table = resources.get("word_embedding_table")

        self.config_model = configs.config_model
        self.config_data = configs.config_data

        self.normalize_func = utils.normalize_digit_word

        self.device = (
            torch.device("cuda")
            if torch.cuda.is_available()
            else torch.device("cpu")
        )

        utils.set_random_seed(self.config_model.random_seed)

        self.model = BiRecurrentConvCRF(
            word_embedding_table,
            self.char_alphabet.size(),
            self.ner_alphabet.size(),
            self.config_model,
        ).to(device=self.device)

        self.optim = SGD(
            self.model.parameters(),
            lr=self.config_model.learning_rate,
            momentum=self.config_model.momentum,
            nesterov=True,
        )

        self.trained_epochs = 0

        self.resource.update(model=self.model)
Example 4
class CoNLLNERTrainer(BaseTrainer):
    def __init__(self):
        super().__init__()

        self.model = None

        self.word_alphabet = None
        self.char_alphabet = None
        self.ner_alphabet = None

        self.config_model = None
        self.config_data = None
        self.normalize_func = None

        self.device = None
        self.optim, self.trained_epochs = None, None

        self.resource: Optional[Resources] = None

        self.train_instances_cache = []

        # Just for recording
        self.max_char_length = 0

        self.__past_dev_result = None

    def initialize(self, resource: Resources, configs: HParams):

        self.resource = resource

        self.word_alphabet = resource.get("word_alphabet")
        self.char_alphabet = resource.get("char_alphabet")
        self.ner_alphabet = resource.get("ner_alphabet")

        word_embedding_table = resource.get('word_embedding_table')

        self.config_model = configs.config_model
        self.config_data = configs.config_data

        self.normalize_func = utils.normalize_digit_word

        self.device = torch.device("cuda") if torch.cuda.is_available() \
            else torch.device("cpu")

        utils.set_random_seed(self.config_model.random_seed)

        self.model = BiRecurrentConvCRF(
            word_embedding_table, self.char_alphabet.size(),
            self.ner_alphabet.size(), self.config_model).to(device=self.device)

        self.optim = SGD(self.model.parameters(),
                         lr=self.config_model.learning_rate,
                         momentum=self.config_model.momentum,
                         nesterov=True)

        self.trained_epochs = 0

        self.resource.update(model=self.model)

    def data_request(self):
        request_string = {
            "context_type": Sentence,
            "request": {
                Token: ["ner"],
                Sentence: [],  # span by default
            }
        }
        return request_string

    def consume(self, instance):
        tokens = instance["Token"]
        word_ids = []
        char_id_seqs = []
        ner_tags, ner_ids = tokens["ner"], []

        for word in tokens["text"]:
            char_ids = []
            for char in word:
                char_ids.append(self.char_alphabet.get_index(char))
            if len(char_ids) > self.config_data.max_char_length:
                char_ids = char_ids[:self.config_data.max_char_length]
            char_id_seqs.append(char_ids)

            word = self.normalize_func(word)
            word_ids.append(self.word_alphabet.get_index(word))

        for ner in ner_tags:
            ner_ids.append(self.ner_alphabet.get_index(ner))

        max_len = max([len(char_seq) for char_seq in char_id_seqs])
        self.max_char_length = max(self.max_char_length, max_len)

        self.train_instances_cache.append((word_ids, char_id_seqs, ner_ids))

    def pack_finish_action(self, pack_count):
        pass

    def epoch_finish_action(self, epoch):
        """
        at the end of each dataset_iteration, we perform the training,
        and set validation flags
        :return:
        """
        counter = len(self.train_instances_cache)
        logger.info(f"Total number of ner_data: {counter}")

        lengths = \
            sum([len(instance[0]) for instance in self.train_instances_cache])

        logger.info(f"Average sentence length: {(lengths / counter):0.3f}")

        train_err = 0.0
        train_total = 0.0

        start_time = time.time()
        self.model.train()

        # Each time we will clear and reload the train_instances_cache
        instances = self.train_instances_cache
        random.shuffle(self.train_instances_cache)
        data_iterator = torchtext.data.iterator.pool(
            instances,
            self.config_data.batch_size_tokens,
            key=lambda x: len(x[0]),  # sort by the length of word_ids
            batch_size_fn=batch_size_fn,
            random_shuffler=torchtext.data.iterator.RandomShuffler())

        step = 0

        for batch in data_iterator:
            step += 1
            batch_data = self.get_batch_tensor(batch, device=self.device)
            word, char, labels, masks, lengths = batch_data

            self.optim.zero_grad()
            loss = self.model(word, char, labels, mask=masks)
            loss.backward()
            self.optim.step()

            num_inst = word.size(0)
            train_err += loss.item() * num_inst
            train_total += num_inst

            # update log
            if step % 200 == 0:
                logger.info(f"Train: {step}, "
                            f"loss: {(train_err / train_total):0.3f}")

        logger.info(f"Epoch: {epoch}, steps: {step}, "
                    f"loss: {(train_err / train_total):0.3f}, "
                    f"time: {(time.time() - start_time):0.3f}s")

        self.trained_epochs = epoch

        if epoch % self.config_model.decay_interval == 0:
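            # Inverse-time decay: lr = learning_rate / (1 + trained_epochs * decay_rate).
            # With, say, learning_rate=0.01 and decay_rate=0.05 (illustrative values),
            # this gives lr ~= 0.0095 after 1 epoch and 0.008 after 5 epochs.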
            lr = self.config_model.learning_rate / \
                 (1.0 + self.trained_epochs * self.config_model.decay_rate)
            for param_group in self.optim.param_groups:
                param_group["lr"] = lr
            logger.info(f"Update learning rate to {lr:0.3f}")

        self.request_eval()
        self.train_instances_cache.clear()

        if epoch >= self.config_data.num_epochs:
            self.request_stop_train()

    @torch.no_grad()
    def get_loss(self, instances: Iterator) -> float:
        losses = 0
        val_data = list(instances)
        for i in tqdm(range(0, len(val_data),
                            self.config_data.test_batch_size)):
            b_data = val_data[i:i + self.config_data.test_batch_size]
            batch = self.get_batch_tensor(b_data, device=self.device)

            word, char, labels, masks, unused_lengths = batch
            loss = self.model(word, char, labels, mask=masks)
            losses += loss.item()

        mean_loss = losses / len(val_data)
        return mean_loss

    def post_validation_action(self, eval_result):
        if self.__past_dev_result is None or \
                (eval_result["eval"]["f1"] >
                 self.__past_dev_result["eval"]["f1"]):
            self.__past_dev_result = eval_result
            logger.info("Validation f1 increased, saving model")
            self.save_model_checkpoint()

        best_epoch = self.__past_dev_result["epoch"]
        acc, prec, rec, f1 = (self.__past_dev_result["eval"]["accuracy"],
                              self.__past_dev_result["eval"]["precision"],
                              self.__past_dev_result["eval"]["recall"],
                              self.__past_dev_result["eval"]["f1"])
        logger.info(f"Best val acc: {acc: 0.3f}, precision: {prec:0.3f}, "
                    f"recall: {rec:0.3f}, F1: {f1:0.3f}, epoch={best_epoch}")

        if "test" in self.__past_dev_result:
            acc, prec, rec, f1 = (self.__past_dev_result["test"]["accuracy"],
                                  self.__past_dev_result["test"]["precision"],
                                  self.__past_dev_result["test"]["recall"],
                                  self.__past_dev_result["test"]["f1"])
            logger.info(
                f"Best test acc: {acc: 0.3f}, precision: {prec: 0.3f}, "
                f"recall: {rec: 0.3f}, F1: {f1: 0.3f}, "
                f"epoch={best_epoch}")

    def finish(self, resources: Resources):  # pylint: disable=unused-argument
        if self.resource:
            keys_to_serializers = {}
            for key in resources.keys():
                if key == "model":
                    keys_to_serializers[key] = \
                        lambda x, y: pickle.dump(x.state_dict(), open(y, "wb"))
                else:
                    keys_to_serializers[key] = \
                        lambda x, y: pickle.dump(x, open(y, "wb"))

            self.resource.save(keys_to_serializers)

        self.save_model_checkpoint()

    def save_model_checkpoint(self):
        states = {
            "model": self.model.state_dict(),
            "optimizer": self.optim.state_dict(),
        }
        torch.save(states, self.config_model.model_path)

    def load_model_checkpoint(self):
        ckpt = torch.load(self.config_model.model_path)
        logger.info("restoring model from %s", self.config_model.model_path)
        self.model.load_state_dict(ckpt["model"])
        self.optim.load_state_dict(ckpt["optimizer"])

    def get_batch_tensor(
            self, data: List[Tuple[List[int], List[List[int]], List[int]]],
            device: Optional[torch.device] = None) -> \
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
                  torch.Tensor]:
        """Get the tensors to be fed into the model.

        Args:
            data: A list of tuple (word_ids, char_id_sequences, ner_ids)
            device: The device for the tensors.

        Returns:
            A tuple where

            - ``words``: A tensor of shape `[batch_size, batch_length]`
              representing the word ids in the batch
            - ``chars``: A tensor of shape
              `[batch_size, batch_length, char_length]` representing the char
              ids for each word in the batch
            - ``ners``: A tensor of shape `[batch_size, batch_length]`
              representing the ner ids for each word in the batch
            - ``masks``: A tensor of shape `[batch_size, batch_length]`
              marking the valid positions in the batch; 1 indicates a real
              token and 0 indicates padding.
            - ``lengths``: A tensor of shape `[batch_size]` representing the
              length of each sentence in the batch
        """
        batch_size = len(data)
        batch_length = max([len(d[0]) for d in data])
        char_length = max(
            [max([len(charseq) for charseq in d[1]]) for d in data])

        char_length = min(
            self.config_data.max_char_length,
            char_length + self.config_data.num_char_pad,
        )

        wid_inputs = np.empty([batch_size, batch_length], dtype=np.int64)
        cid_inputs = np.empty([batch_size, batch_length, char_length],
                              dtype=np.int64)
        nid_inputs = np.empty([batch_size, batch_length], dtype=np.int64)

        masks = np.zeros([batch_size, batch_length], dtype=np.float32)

        lengths = np.empty(batch_size, dtype=np.int64)

        for i, inst in enumerate(data):
            wids, cid_seqs, nids = inst

            inst_size = len(wids)
            lengths[i] = inst_size
            # word ids
            wid_inputs[i, :inst_size] = wids
            wid_inputs[i, inst_size:] = self.word_alphabet.pad_id
            for c, cids in enumerate(cid_seqs):
                cid_inputs[i, c, :len(cids)] = cids
                cid_inputs[i, c, len(cids):] = self.char_alphabet.pad_id
            cid_inputs[i, inst_size:, :] = self.char_alphabet.pad_id
            # ner ids
            nid_inputs[i, :inst_size] = nids
            nid_inputs[i, inst_size:] = self.ner_alphabet.pad_id
            # masks
            masks[i, :inst_size] = 1.0

        words = torch.from_numpy(wid_inputs).to(device)
        chars = torch.from_numpy(cid_inputs).to(device)
        ners = torch.from_numpy(nid_inputs).to(device)
        masks = torch.from_numpy(masks).to(device)
        lengths = torch.from_numpy(lengths).to(device)

        return words, chars, ners, masks, lengths
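
To make the padding and mask layout described in the docstring concrete, here is a small self-contained sketch that builds the `words` and `masks` tensors for a toy two-sentence batch; the word ids and the pad id of 0 are made up for illustration:

import numpy as np
import torch

# Toy batch: two sentences of word ids, lengths 3 and 2.
batch = [[3, 7, 1], [5, 2]]
pad_id = 0  # assumed padding index

batch_size = len(batch)
batch_length = max(len(seq) for seq in batch)

wid_inputs = np.full([batch_size, batch_length], pad_id, dtype=np.int64)
masks = np.zeros([batch_size, batch_length], dtype=np.float32)
for i, seq in enumerate(batch):
    wid_inputs[i, :len(seq)] = seq   # right-pad each sentence to batch_length
    masks[i, :len(seq)] = 1.0        # 1 marks real tokens, 0 marks padding

words = torch.from_numpy(wid_inputs)  # tensor([[3, 7, 1], [5, 2, 0]])
mask = torch.from_numpy(masks)        # tensor([[1., 1., 1.], [1., 1., 0.]])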