def test_batch_seq_mask():
    """
    Test the function `batch_seq_mask`.
    """
    mask = batch_seq_mask([2, 5, 3])
    batch_mask = [[False, False, True, True, True],
                  [False, False, False, False, False],
                  [False, False, False, True, True]]
    assert mask == batch_mask
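# For reference, a minimal sketch of the behaviour the test above pins down,
# assuming `batch_seq_mask` pads every sequence to the longest length in the
# batch and marks padded positions with True (a sketch only, not necessarily
# the actual implementation defined elsewhere in the package):
#
#     def batch_seq_mask(lengths):
#         max_len = max(lengths)
#         return [[pos >= length for pos in range(max_len)] for length in lengths]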
def train_callback_plmtg(batch_group: tuple, model: nn.Module, loss_fn: nn.Module,
                         config: TrainerConfig) -> torch.FloatTensor:
    """
    Callback of the `train` function for the PLMTG model.

    Args:
        batch_group (tuple): Batch tensor data.
        model (nn.Module): Model to be trained.
        loss_fn (nn.Module): Loss function.
        config (TrainerConfig): Trainer configuration.

    Returns:
        torch.FloatTensor: Training loss of the batch.
    """
    _, ready_to_eat = batch_group
    # True marks padded positions beyond each sequence's real length.
    mask = torch.tensor(batch_seq_mask(ready_to_eat["lengths"].tolist())).to(ready_to_eat["inputs"].device)
    logits = model(ready_to_eat["inputs"], mask)
    # Flatten to (batch * seq_len, num_classes) vs. (batch * seq_len,) for the loss.
    loss = loss_fn(logits.view(-1, logits.shape[-1]), ready_to_eat["outputs"].view(-1))
    return loss
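# Usage sketch for the callback above, assuming a plain PyTorch training loop;
# `loader` and `optimizer` are hypothetical names, not part of this module:
#
#     for batch_group in loader:
#         loss = train_callback_plmtg(batch_group, model, loss_fn, config)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()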
def train_callback_ses(batch_group: tuple, model: nn.Module, loss_fn: nn.Module,
                       config: TrainerConfig) -> torch.FloatTensor:
    """
    Callback of the `train` function for the SES model.

    Args:
        batch_group (tuple): Batch tensor data.
        model (nn.Module): Model to be trained.
        loss_fn (nn.Module): Loss function.
        config (TrainerConfig): Trainer configuration.

    Returns:
        torch.FloatTensor: Training loss of the batch.
    """
    _, ready_to_eat = batch_group
    # Padding mask as uint8; `1 - mask` below flips it into a valid-token mask.
    mask = torch.tensor(batch_seq_mask(ready_to_eat["lengths"].tolist())).to(ready_to_eat["inputs"].device).to(torch.uint8)
    loss = model(ready_to_eat["inputs"], ready_to_eat["lexicons"], ready_to_eat["weights"],
                 1 - mask, ready_to_eat["outputs"])
    # With multiple GPUs (e.g. DataParallel) the model returns one loss per replica.
    if len(config.gpu) > 1:
        loss = loss.mean()
    return loss
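# Worked example of the mask flip above, using the lengths from
# `test_batch_seq_mask` ([2, 5, 3]): the padding mask for the first sequence is
# [0, 0, 1, 1, 1] as uint8, so `1 - mask` gives [1, 1, 0, 0, 0], i.e. ones on
# real tokens and zeros on padding, which is the convention the SES model's
# forward pass appears to expect here.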
def test_callback_blcrf(batch_group: tuple, model: nn.Module, out_adapter: BaseOutAdapter,
                        config: TrainerConfig) -> tuple:
    """
    Callback of the `test` function for the BLCRF model.

    Args:
        batch_group (tuple): Batch tensor data.
        model (nn.Module): Model to be evaluated.
        out_adapter (BaseOutAdapter): The output adapter that converts tensors to raw data.
        config (TrainerConfig): Trainer configuration.

    Returns:
        tuple: Gold tag sequences and predicted tag sequences.
    """
    batch, ready_to_eat = batch_group
    mask = torch.tensor(batch_seq_mask(ready_to_eat["lengths"].tolist())).to(ready_to_eat["inputs"].device).to(torch.uint8)
    predictions = model(ready_to_eat["inputs"], 1 - mask)
    lengths = ready_to_eat["lengths"].tolist()
    # Map tag ids back to tag strings, keeping only the real tokens of each sequence.
    true_tags = [[out_adapter[tag_id] for tag_id in sequence[batch[i]["start"]:lengths[i]]]
                 for i, sequence in enumerate(ready_to_eat["outputs"].tolist())]
    pred_tags = [[out_adapter[tag_id] for tag_id in sequence[batch[i]["start"]:lengths[i]]]
                 for i, sequence in enumerate(predictions.tolist())]
    return (true_tags, pred_tags)
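# Usage sketch for the tag sequences returned above, assuming an entity-level
# scorer such as seqeval is acceptable (a hypothetical choice, not required by
# this module):
#
#     from seqeval.metrics import f1_score
#
#     true_tags, pred_tags = test_callback_blcrf(batch_group, model, out_adapter, config)
#     print(f1_score(true_tags, pred_tags))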