예제 #1
0
def test_pass_labels_with_unkown_1():
    """Expect ``setup`` to reject data holding a label not listed in ``labels``.

    The field is created with an explicit label list that omits 'LABEL3',
    while the data contains it; the test passes only if ``setup`` raises
    ``ValueError``.
    """
    known_labels = ['LABEL1', 'LABEL2']
    # 'LABEL3' is deliberately absent from ``known_labels``.
    samples = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']

    label_field = LabelField(labels=known_labels)
    with pytest.raises(ValueError):
        label_field.setup(samples)
예제 #2
0
def test_label_process_one_hot():
    """Check one-hot numericalization of single labels.

    The asserted encodings put LABEL1 at index 0, LABEL3 at index 1 and
    LABEL2 at index 2, which matches the order of first appearance in the
    data passed to ``setup``.
    """
    samples = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']

    label_field = LabelField(one_hot=True)
    label_field.setup(samples)

    assert len(label_field.vocab) == 3

    expected_encodings = (
        ('LABEL1', [1, 0, 0]),
        ('LABEL2', [0, 0, 1]),
        ('LABEL3', [0, 1, 0]),
    )
    for label, one_hot in expected_encodings:
        assert list(label_field.process(label)) == one_hot
예제 #3
0
def test_label_process():
    """Check that each label is numericalized to the expected integer index.

    The asserted indices (LABEL1 -> 0, LABEL3 -> 1, LABEL2 -> 2) match the
    order of first appearance in the data passed to ``setup``.
    """
    samples = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']

    label_field = LabelField()
    label_field.setup(samples)

    assert len(label_field.vocab) == 3

    expected_indices = (
        ('LABEL1', 0),
        ('LABEL2', 2),
        ('LABEL3', 1),
    )
    for label, index in expected_indices:
        assert int(label_field.process(label)) == index
예제 #4
0
def test_label_frequencies_3():
    """Check that empty data yields an empty vocab and empty label statistics."""
    label_field = LabelField()
    label_field.setup([])

    assert len(label_field.vocab) == 0

    # Frequency, count and inverse frequency must all be empty as well.
    for stat_name in ('label_freq', 'label_count', 'label_inv_freq'):
        assert len(getattr(label_field, stat_name)) == 0
    assert len(field.label_inv_freq) == 0
예제 #5
0
def test_label_process_multilabel():
    """Check numericalization when one example may carry several labels.

    Without a separator every raw string is expected to count as its own
    label (4 vocab entries). With ``multilabel_sep=','`` the strings are
    expected to be split, leaving 3 distinct labels.
    """
    samples = ['LABEL1,LABEL2', 'LABEL3', 'LABEL2,LABEL1', 'LABEL2']

    plain_field = LabelField()
    plain_field.setup(samples)
    assert len(plain_field.vocab) == 4

    multi_field = LabelField(multilabel_sep=',')
    multi_field.setup(samples)
    assert len(multi_field.vocab) == 3

    # A multi-label string maps to one index per label, in the order written.
    multi_cases = (
        ('LABEL1,LABEL2', [0, 1]),
        ('LABEL2,LABEL1', [1, 0]),
    )
    for text, indices in multi_cases:
        assert list(multi_field.process(text)) == indices

    # A single label still converts to a single integer.
    for text, index in (('LABEL2', 1), ('LABEL3', 2)):
        assert int(multi_field.process(text)) == index
예제 #6
0
def test_pass_bool_labels():
    """Check that boolean labels passed to the constructor fix the indices.

    Each case gives a ``labels`` ordering plus the index expected for
    ``False`` and for ``True``; the expected index equals the label's
    position in that ordering.
    """
    samples = [True, False, True, True]

    cases = (
        ([False, True], 0, 1),
        ([True, False], 1, 0),
    )
    for label_order, false_index, true_index in cases:
        label_field = LabelField(labels=label_order)
        label_field.setup(samples)

        assert len(label_field.vocab) == 2
        assert int(label_field.process(False)) == false_index
        assert int(label_field.process(True)) == true_index
예제 #7
0
def test_label_frequencies_2():
    """Check label statistics when the data contains a single distinct label.

    With 80 copies of one label the expected frequency is 1, the expected
    count is 80 and the expected inverse frequency is 1.
    """
    samples = ['LABEL1'] * 80

    label_field = LabelField()
    label_field.setup(samples)

    assert len(label_field.vocab) == 1

    expected_stats = (
        ('label_freq', 1),
        ('label_count', 80),
        ('label_inv_freq', 1),
    )
    for stat_name, expected in expected_stats:
        assert len(getattr(label_field, stat_name)) == 1
        actual = getattr(label_field, stat_name)[0].item()
        assert isclose(actual, expected, rtol=NUMERIC_PRECISION)
예제 #8
0
def test_pass_labels():
    """Check that labels passed to the constructor determine the indices.

    Each case pairs a ``labels`` ordering with the index expected for every
    label; the expected index equals the label's position in that ordering.
    """
    samples = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']

    cases = (
        (
            ['LABEL1', 'LABEL2', 'LABEL3'],
            (('LABEL1', 0), ('LABEL2', 1), ('LABEL3', 2)),
        ),
        (
            ['LABEL3', 'LABEL1', 'LABEL2'],
            (('LABEL1', 1), ('LABEL2', 2), ('LABEL3', 0)),
        ),
    )
    for label_order, expected_indices in cases:
        label_field = LabelField(labels=label_order)
        label_field.setup(samples)

        assert len(label_field.vocab) == 3
        for label, index in expected_indices:
            assert int(label_field.process(label)) == index
예제 #9
0
def test_label_process_multilabel_one_hot_frequencies():
    """Check label statistics for multi-label data with one-hot enabled.

    Splitting the data on ',' gives six label occurrences: LABEL1 twice,
    LABEL2 three times and LABEL3 once. The expected frequencies are
    therefore about 2/6, 3/6 and 1/6, and the expected inverse frequencies
    are 3, 2 and 6.
    """
    samples = ['LABEL1,LABEL2', 'LABEL3', 'LABEL2,LABEL1', 'LABEL2']

    label_field = LabelField(multilabel_sep=',', one_hot=True)
    label_field.setup(samples)

    # NOTE(review): 0.333 and 0.166 are rounded, so these checks rely on
    # NUMERIC_PRECISION (defined elsewhere) being loose enough -- confirm.
    expected_stats = (
        ('label_freq', (0.333, 0.5, 0.166)),
        ('label_count', (2, 3, 1)),
        ('label_inv_freq', (3, 2, 6)),
    )
    for stat_name, expected_values in expected_stats:
        assert len(getattr(label_field, stat_name)) == 3
        for index, expected in enumerate(expected_values):
            actual = getattr(label_field, stat_name)[index].item()
            assert isclose(actual, expected, rtol=NUMERIC_PRECISION)
예제 #10
0
def test_label_frequencies():
    """Check label statistics for an 80/20 split between two labels.

    Expected frequencies are 0.8 and 0.2, expected counts are 80 and 20, and
    expected inverse frequencies are 1.25 and 5.
    """
    samples = ['LABEL1'] * 80 + ['LABEL2'] * 20

    label_field = LabelField()
    label_field.setup(samples)

    assert len(label_field.vocab) == 2

    expected_stats = (
        ('label_freq', (0.8, 0.2)),
        ('label_count', (80, 20)),
        ('label_inv_freq', (1.25, 5)),
    )
    for stat_name, expected_values in expected_stats:
        assert len(getattr(label_field, stat_name)) == 2
        for index, expected in enumerate(expected_values):
            actual = getattr(label_field, stat_name)[index].item()
            assert isclose(actual, expected, rtol=NUMERIC_PRECISION)