def test_label_process_one_hot():
    """Check that ``one_hot=True`` turns each label into a one-hot vector.

    The expected hot positions follow the order in which labels first
    appear in ``dummy``: LABEL1, LABEL3, LABEL2.
    """
    dummy = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']
    field = LabelField(one_hot=True)
    field.setup(dummy)
    assert len(field.vocab) == 3

    expected_vectors = {
        'LABEL1': [1, 0, 0],
        'LABEL2': [0, 0, 1],
        'LABEL3': [0, 1, 0],
    }
    for label, vector in expected_vectors.items():
        assert list(field.process(label)) == vector
def test_label_process():
    """Test label numericalization.

    Without an explicit ``labels`` argument, the expected indices follow the
    order in which labels first appear in ``dummy``: LABEL1 -> 0,
    LABEL3 -> 1, LABEL2 -> 2.
    """
    dummy = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']
    field = LabelField()
    field.setup(dummy)
    # Duplicate 'LABEL2' entries must collapse into a single vocab entry.
    assert len(field.vocab) == 3
    assert int(field.process('LABEL1')) == 0
    assert int(field.process('LABEL2')) == 2
    assert int(field.process('LABEL3')) == 1
def test_label_process_multilabel():
    """Test multilabel numericalization via ``multilabel_sep``.

    Without a separator, every distinct raw string (e.g. 'LABEL1,LABEL2')
    is treated as its own label. With ``multilabel_sep=','`` the strings
    are split into individual labels. A multi-label input then yields a
    sequence of indices in input order, and a single label yields a scalar
    index.
    """
    dummy = ['LABEL1,LABEL2', 'LABEL3', 'LABEL2,LABEL1', 'LABEL2']

    # No separator: the four distinct raw strings become four labels.
    field = LabelField()
    field.setup(dummy)
    assert len(field.vocab) == 4

    # With ',' as separator only LABEL1, LABEL2 and LABEL3 remain.
    field = LabelField(multilabel_sep=',')
    field.setup(dummy)
    assert len(field.vocab) == 3
    assert list(field.process('LABEL1,LABEL2')) == [0, 1]
    # The order of the indices mirrors the order of labels in the input.
    assert list(field.process('LABEL2,LABEL1')) == [1, 0]
    assert int(field.process('LABEL2')) == 1
    assert int(field.process('LABEL3')) == 2
def test_pass_labels_with_unkown_2():
    """Test that a label outside the explicit ``labels`` list is rejected.

    The field is built with ``labels=['LABEL1', 'LABEL2', 'LABEL3']``.
    Processing 'LABEL4', which is not among them, is expected to raise
    ``ValueError``.
    """
    dummy = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']
    field = LabelField(labels=['LABEL1', 'LABEL2', 'LABEL3'])
    field.setup(dummy)
    with pytest.raises(ValueError):
        list(field.process('LABEL4'))
def test_pass_bool_labels():
    """Check that boolean ``labels`` fix the index assigned to each value.

    Each case pairs a ``labels`` ordering with the index every boolean is
    expected to map to. The setup data alone does not decide the mapping.
    """
    dummy = [True, False, True, True]
    cases = [
        ([False, True], {False: 0, True: 1}),
        ([True, False], {False: 1, True: 0}),
    ]
    for labels, expected in cases:
        field = LabelField(labels=labels)
        field.setup(dummy)
        assert len(field.vocab) == 2
        for value, index in expected.items():
            assert int(field.process(value)) == index
def test_pass_labels():
    """Check that an explicit ``labels`` list dictates label indices.

    The same setup data is used with two different ``labels`` orderings.
    Each label's expected index is its position in ``labels``, not its
    order of appearance in ``dummy``.
    """
    dummy = ['LABEL1', 'LABEL3', 'LABEL2', 'LABEL2']
    cases = [
        (['LABEL1', 'LABEL2', 'LABEL3'],
         {'LABEL1': 0, 'LABEL2': 1, 'LABEL3': 2}),
        (['LABEL3', 'LABEL1', 'LABEL2'],
         {'LABEL1': 1, 'LABEL2': 2, 'LABEL3': 0}),
    ]
    for labels, expected in cases:
        field = LabelField(labels=labels)
        field.setup(dummy)
        assert len(field.vocab) == 3
        for label, index in expected.items():
            assert int(field.process(label)) == index