Ejemplo n.º 1
0
def test_pass_labels_with_unkown_1():
    """Test labels specified in the init"""
    dummy = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']

    field = LabelField(labels=['LABEL1', 'LABEL2'])
    with pytest.raises(ValueError):
        field.setup(dummy)
Ejemplo n.º 2
0
def test_label_process_one_hot():
    """Test label numericalization."""
    dummy = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']

    field = LabelField(one_hot=True)
    field.setup(dummy)

    assert len(field.vocab) == 3
    assert list(field.process('LABEL1')) == [1, 0, 0]
    assert list(field.process('LABEL2')) == [0, 0, 1]
    assert list(field.process('LABEL3')) == [0, 1, 0]
Ejemplo n.º 3
0
def test_label_process():
    """Test label nuemricalization."""
    dummy = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']

    field = LabelField()
    field.setup(dummy)

    assert len(field.vocab) == 3
    assert int(field.process('LABEL1')) == 0
    assert int(field.process('LABEL2')) == 2
    assert int(field.process('LABEL3')) == 1
Ejemplo n.º 4
0
def test_label_frequencies_3():
    """Test label frequencies."""
    dummy = []

    field = LabelField()
    field.setup(dummy)

    assert len(field.vocab) == 0

    assert len(field.label_freq) == 0
    assert len(field.label_count) == 0
    assert len(field.label_inv_freq) == 0
Ejemplo n.º 5
0
def test_label_process_multilabel():
    """Test label nuemricalization."""
    dummy = ['LABEL1,LABEL2', 'LABEL3', 'LABEL2,LABEL1', 'LABEL2']

    field = LabelField()
    field.setup(dummy)
    assert len(field.vocab) == 4

    field = LabelField(multilabel_sep=',')
    field.setup(dummy)
    assert len(field.vocab) == 3

    assert list(field.process('LABEL1,LABEL2')) == [0, 1]
    assert list(field.process('LABEL2,LABEL1')) == [1, 0]
    assert int(field.process('LABEL2')) == 1
    assert int(field.process('LABEL3')) == 2
Ejemplo n.º 6
0
def test_pass_bool_labels():
    """Test labels specified in the init"""
    dummy = [True, False, True, True]

    field = LabelField(labels=[False, True])
    field.setup(dummy)

    assert len(field.vocab) == 2
    assert int(field.process(False)) == 0
    assert int(field.process(True)) == 1

    field = LabelField(labels=[True, False])
    field.setup(dummy)

    assert len(field.vocab) == 2
    assert int(field.process(False)) == 1
    assert int(field.process(True)) == 0
Ejemplo n.º 7
0
def test_label_frequencies_2():
    """Test label frequencies."""
    dummy = ['LABEL1'] * 80

    field = LabelField()
    field.setup(dummy)

    assert len(field.vocab) == 1

    assert len(field.label_freq) == 1
    assert isclose(field.label_freq[0].item(), 1, rtol=NUMERIC_PRECISION)

    assert len(field.label_count) == 1
    assert isclose(field.label_count[0].item(), 80, rtol=NUMERIC_PRECISION)

    assert len(field.label_inv_freq) == 1
    assert isclose(field.label_inv_freq[0].item(), 1, rtol=NUMERIC_PRECISION)
Ejemplo n.º 8
0
def test_pass_labels():
    """Test labels specified in the init"""
    dummy = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']

    field = LabelField(labels=['LABEL1', 'LABEL2', 'LABEL3'])
    field.setup(dummy)

    assert len(field.vocab) == 3
    assert int(field.process('LABEL1')) == 0
    assert int(field.process('LABEL2')) == 1
    assert int(field.process('LABEL3')) == 2

    field = LabelField(labels=['LABEL3', 'LABEL1', 'LABEL2'])
    field.setup(dummy)

    assert len(field.vocab) == 3
    assert int(field.process('LABEL1')) == 1
    assert int(field.process('LABEL2')) == 2
    assert int(field.process('LABEL3')) == 0
Ejemplo n.º 9
0
def test_label_process_multilabel_one_hot_frequencies():
    """Test label numericalization."""
    dummy = ['LABEL1,LABEL2', 'LABEL3', 'LABEL2,LABEL1', 'LABEL2']

    field = LabelField(multilabel_sep=',', one_hot=True)
    field.setup(dummy)

    assert len(field.label_freq) == 3
    assert isclose(field.label_freq[0].item(), 0.333, rtol=NUMERIC_PRECISION)
    assert isclose(field.label_freq[1].item(), 0.5, rtol=NUMERIC_PRECISION)
    assert isclose(field.label_freq[2].item(), 0.166, rtol=NUMERIC_PRECISION)

    assert len(field.label_count) == 3
    assert isclose(field.label_count[0].item(), 2, rtol=NUMERIC_PRECISION)
    assert isclose(field.label_count[1].item(), 3, rtol=NUMERIC_PRECISION)
    assert isclose(field.label_count[2].item(), 1, rtol=NUMERIC_PRECISION)

    assert len(field.label_inv_freq) == 3
    assert isclose(field.label_inv_freq[0].item(), 3, rtol=NUMERIC_PRECISION)
    assert isclose(field.label_inv_freq[1].item(), 2, rtol=NUMERIC_PRECISION)
    assert isclose(field.label_inv_freq[2].item(), 6, rtol=NUMERIC_PRECISION)
Ejemplo n.º 10
0
def test_label_frequencies():
    """Test label frequencies."""
    dummy = ['LABEL1'] * 80
    dummy.extend(['LABEL2'] * 20)

    field = LabelField()
    field.setup(dummy)

    assert len(field.vocab) == 2

    assert len(field.label_freq) == 2
    assert isclose(field.label_freq[0].item(), 0.8, rtol=NUMERIC_PRECISION)
    assert isclose(field.label_freq[1].item(), 0.2, rtol=NUMERIC_PRECISION)

    assert len(field.label_count) == 2
    assert isclose(field.label_count[0].item(), 80, rtol=NUMERIC_PRECISION)
    assert isclose(field.label_count[1].item(), 20, rtol=NUMERIC_PRECISION)

    assert len(field.label_inv_freq) == 2
    assert isclose(field.label_inv_freq[0].item(), 1.25, rtol=NUMERIC_PRECISION)
    assert isclose(field.label_inv_freq[1].item(), 5, rtol=NUMERIC_PRECISION)