Example #1
0
def test_batch_seq_mask():
    """
    Check `batch_seq_mask` against a hand-written mask for the lengths 2, 5 and 3.

    Every expected row has 5 entries (the longest length); entries at positions
    below the sequence's length are `False`, the remaining ones are `True`.
    """
    lengths = [2, 5, 3]
    expected = [
        [False, False, True, True, True],
        [False, False, False, False, False],
        [False, False, False, True, True],
    ]
    assert batch_seq_mask(lengths) == expected
Example #2
0
def train_callback_plmtg(batch_group: tuple, model: nn.Module, loss_fn: nn.Module, config: TrainerConfig) -> torch.FloatTensor:
    """
    Training-step callback for the model PLMTG.

    Builds a mask from the batch lengths, runs the model on the inputs with that
    mask, and applies the loss function to the flattened logits and outputs.

    Args:
        batch_group (tuple): Pair whose second element is a mapping of batch tensors
            (keys used here: "lengths", "inputs", "outputs"). The first element is ignored.
        model (nn.Module): Model to be trained; called as `model(inputs, mask)`.
        loss_fn (nn.Module): Loss function; called as `loss_fn(flat_logits, flat_targets)`.
        config (TrainerConfig): Trainer configuration (not used by this callback).

    Returns:
        torch.FloatTensor: The value returned by `loss_fn` for this batch.
    """
    _, tensors = batch_group

    # Build the mask from plain Python lengths, then move it next to the inputs.
    seq_lengths = tensors["lengths"].tolist()
    mask = torch.tensor(batch_seq_mask(seq_lengths))
    inputs = tensors["inputs"]
    mask = mask.to(inputs.device)

    logits = model(inputs, mask)

    # Collapse every leading dimension so each row of logits lines up with one target entry.
    last_dim = logits.shape[-1]
    flat_logits = logits.view(-1, last_dim)
    flat_targets = tensors["outputs"].view(-1)

    return loss_fn(flat_logits, flat_targets)
Example #3
0
def train_callback_ses(batch_group: tuple, model: nn.Module, loss_fn: nn.Module, config: TrainerConfig) -> torch.FloatTensor:
    """
    Training-step callback for the model SES.

    The model itself returns the loss; this callback only prepares the inverted
    mask and, when several GPUs are configured, averages the returned loss.

    Args:
        batch_group (tuple): Pair whose second element is a mapping of batch tensors
            (keys used here: "lengths", "inputs", "lexicons", "weights", "outputs").
            The first element is ignored.
        model (nn.Module): Model to be trained; its forward call returns the loss.
        loss_fn (nn.Module): Loss function (not used by this callback).
        config (TrainerConfig): Trainer configuration; only `config.gpu` is read.

    Returns:
        torch.FloatTensor: The loss for this batch.
    """
    _, tensors = batch_group

    seq_mask = torch.tensor(batch_seq_mask(tensors["lengths"].tolist()))
    # uint8 so that `1 - seq_mask` below yields the inverted mask.
    seq_mask = seq_mask.to(tensors["inputs"].device).to(torch.uint8)

    loss = model(
        tensors["inputs"],
        tensors["lexicons"],
        tensors["weights"],
        1 - seq_mask,
        tensors["outputs"],
    )

    # NOTE(review): with more than one entry in config.gpu the loss is averaged -
    # presumably the model then returns one loss per device; confirm against the trainer.
    return loss.mean() if len(config.gpu) > 1 else loss
Example #4
0
def test_callback_blcrf(batch_group: tuple, model: nn.Module, out_adapter: BaseOutAdapter, config: TrainerConfig) -> tuple:
    """
    Callback of the function test for the model BLCRF.

    Runs the model on one batch and converts both the reference outputs and the
    predictions from tag ids to raw tags, keeping for sample `i` only the
    positions from `batch[i]["start"]` up to (excluding) `lengths[i]`.

    Args:
        batch_group (tuple): Pair `(batch, ready_to_eat)`: the raw samples (each providing
            a "start" index) and the mapping of batch tensors (keys used here:
            "lengths", "inputs", "outputs").
        model (nn.Module): Model to be evaluated; called as `model(inputs, 1 - mask)`.
        out_adapter (BaseOutAdapter): The output adapter that converts tensors to raw data;
            indexed with a tag id to obtain the tag.
        config (TrainerConfig): Trainer configuration (not used by this callback).

    Returns:
        tuple: `(true_tags, pred_tags)`, two lists with one list of tags per sample.
    """
    batch, ready_to_eat = batch_group

    # Convert the lengths once; they feed both the mask and the slicing below.
    lengths = ready_to_eat["lengths"].tolist()

    # uint8 so that `1 - mask` yields the inverted mask handed to the model.
    mask = torch.tensor(batch_seq_mask(lengths)).to(ready_to_eat["inputs"].device).to(torch.uint8)
    predictions = model(ready_to_eat["inputs"], 1 - mask)

    def decode(sequences: list) -> list:
        # Map the tag ids of each sequence to tags, restricted to [start, length).
        return [
            [out_adapter[tag_id] for tag_id in sequence[batch[i]["start"]:lengths[i]]]
            for i, sequence in enumerate(sequences)
        ]

    true_tags = decode(ready_to_eat["outputs"].tolist())
    pred_tags = decode(predictions.tolist())

    return (true_tags, pred_tags)