# Example #1
# 0
def test_ctc_loss_batch_size(custom_ops, input_size, target_size, num_classes, batch_size, reduction_type, ipu):
    """Test CTC loss/gradient accuracy across different batch sizes.

    The sizes, reduction type and ``ipu`` flag come from the test
    parameters; the case passes when the gradient error is below 1e-4 and
    the loss error is below 1e-5.
    """
    # Forward ``ipu`` (as the other CTC tests do); otherwise the ``ipu``
    # parametrization of this test has no effect on the case being run.
    args = test_utils.args_from_params(
        input_size, target_size, num_classes, batch_size, reduction_type, ipu=ipu
    )
    grad_err, loss_err = run_single_case(args)
    assert grad_err < 1e-4
    assert loss_err < 1e-5
# Example #2
# 0
def test_rnnt_loss_batch_size(custom_ops, input_size, target_size, num_classes,
                              batch_size):
    """Check loss and gradient errors across parametrized batch sizes.

    Both the gradient error and the loss error returned by
    ``run_single_case`` must be below 1e-5.
    """
    shape_args = (input_size, target_size, num_classes, batch_size)
    case_args = test_utils.args_from_params(*shape_args)
    grad_error, loss_error = run_single_case(case_args)
    tolerance = 1e-5
    assert grad_error < tolerance
    assert loss_error < tolerance
# Example #3
# 0
def test_rnnt_loss_variable_tgt(custom_ops, variable_input):
    """Check loss and gradient errors with and without ``variable_input``.

    Uses a fixed small case (input_size=20, target_size=10, num_classes=5,
    batch_size=3); both errors must be below 1e-5.
    """
    case_args = test_utils.args_from_params(
        input_size=20,
        target_size=10,
        num_classes=5,
        batch_size=3,
        variable_input=variable_input,
    )
    grad_error, loss_error = run_single_case(case_args)
    tolerance = 1e-5
    assert grad_error < tolerance
    assert loss_error < tolerance
# Example #4
# 0
def test_rnnt_loss_precision(custom_ops, precision):
    """Check loss and gradient errors for a parametrized ``precision``.

    Uses a fixed small case (input_size=6, target_size=3, num_classes=5,
    batch_size=3); both errors must be below 1e-5.
    """
    settings = dict(
        input_size=6,
        target_size=3,
        num_classes=5,
        batch_size=3,
        precision=precision,
    )
    case_args = test_utils.args_from_params(**settings)
    grad_error, loss_error = run_single_case(case_args)
    tolerance = 1e-5
    assert grad_error < tolerance
    assert loss_error < tolerance
# Example #5
# 0
def test_ctc_loss_asr_dim(custom_ops, batch_size, precision, ipu):
    """Test CTC loss for typical ASR model dimensions.

    Fixes input_size=375, target_size=200, num_classes=36 and "mean"
    reduction; batch size, precision and ``ipu`` come from the test
    parameters.  Gradient error must be below 1e-4, loss error below 1e-5.
    """
    input_size, target_size, num_classes = 375, 200, 36
    args = test_utils.args_from_params(
        input_size,
        target_size,
        num_classes,
        batch_size,
        "mean",
        precision,
        ipu,
    )
    grad_error, loss_error = run_single_case(args)
    assert grad_error < 1e-4
    assert loss_error < 1e-5
# Example #6
# 0
def test_rnnt_loss_logits_scale(custom_ops, logits_scale, precision):
    """Check loss and gradient errors for different ``logits_scale`` values.

    Uses a fixed case (input_size=197, target_size=116, num_classes=64,
    batch_size=2) with ``variable_input=True`` and ``ipu=True``;
    ``logits_scale`` and ``precision`` come from the test parameters.
    """
    case_args = test_utils.args_from_params(
        input_size=197,
        target_size=116,
        num_classes=64,
        batch_size=2,
        variable_input=True,
        logits_scale=logits_scale,
        precision=precision,
        ipu=True,
    )
    grad_error, loss_error = run_single_case(case_args)
    # The gradient tolerance is relaxed a bit (1e-4) so the test still
    # passes for the larger logits scales.
    assert grad_error < 1e-4
    assert loss_error < 1e-5
# Example #7
# 0
def test_ctc_loss_variable_tgt(custom_ops, reduction_type, variable_input, ipu):
    """Test CTC loss for both constant and variable sequence lengths.

    Uses a fixed small case (input_size=20, target_size=10, num_classes=5,
    batch_size=3).  Gradient error must be below 1e-4, loss error below
    1e-5.
    """
    fixed_dims = {
        "input_size": 20,
        "target_size": 10,
        "num_classes": 5,
        "batch_size": 3,
    }
    args = test_utils.args_from_params(
        **fixed_dims,
        reduction_type=reduction_type,
        variable_input=variable_input,
        ipu=ipu,
    )
    grad_error, loss_error = run_single_case(args)
    assert grad_error < 1e-4
    assert loss_error < 1e-5
# Example #8
# 0
def test_ctc_loss_precision(custom_ops, reduction_type, precision, ipu):
    """Test CTC loss across different reduction types and precisions.

    Uses a fixed small case (input_size=6, target_size=3, num_classes=5,
    batch_size=3).  Gradient error must be below 1e-4, loss error below
    1e-5.
    """
    args = test_utils.args_from_params(
        input_size=6,
        target_size=3,
        num_classes=5,
        batch_size=3,
        reduction_type=reduction_type,
        precision=precision,
        ipu=ipu,
    )
    grad_error, loss_error = run_single_case(args)
    assert grad_error < 1e-4
    assert loss_error < 1e-5
# Example #9
# 0
def test_rnnt_loss_large_dim_wplits(custom_ops, ipu, splits):
    """Check loss and gradient errors with a parametrized ``splits`` setting.

    Uses a fixed FLOAT16, variable-input case (input_size=30,
    target_size=15, num_classes=20, batch_size=4) and passes ``splits``
    through to ``run_single_case``; both errors must be below 1e-5.

    NOTE(review): "wplits" in the test name looks like a typo for "splits";
    it is left as-is so existing ``-k`` selections keep matching.
    """
    input_size = 30
    # A non-empty splits configuration must add up to the full input_size.
    assert not splits or sum(splits) == input_size

    case_args = test_utils.args_from_params(
        input_size=input_size,
        target_size=15,
        num_classes=20,
        batch_size=4,
        precision="FLOAT16",
        variable_input=True,
        ipu=ipu,
    )
    grad_error, loss_error = run_single_case(case_args, splits=splits)
    tolerance = 1e-5
    assert grad_error < tolerance
    assert loss_error < tolerance