def test_add_smiles(self) -> None:
    """Test add_smiles."""
    smiles = 'CCO'
    entities = ['Initiator', 'Monomer']
    polymer_language = PolymerTokenizer(entity_names=entities)
    polymer_language.add_smiles(smiles)
    self.assertEqual(polymer_language.number_of_tokens, 41)

def test__update_max_token_sequence_length(self) -> None:
    """Test _update_max_token_sequence_length."""
    smiles = 'CCO'
    entities = ['Initiator', 'Monomer', 'Catalyst']
    polymer_language = PolymerTokenizer(entity_names=entities)
    self.assertEqual(polymer_language.max_token_sequence_length, 0)
    polymer_language.add_smiles(smiles)
    self.assertEqual(polymer_language.max_token_sequence_length, 5)

def test__update_language_dictionaries_with_tokens(self) -> None:
    """Test _update_language_dictionaries_with_tokens."""
    smiles = 'CCO'
    entities = ['Initiator', 'Monomer', 'Catalyst']
    polymer_language = PolymerTokenizer(entity_names=entities)
    polymer_language._update_language_dictionaries_with_tokens(
        polymer_language.smiles_tokenizer(smiles)
    )
    self.assertTrue(
        'C' in polymer_language.token_to_index
        and 'O' in polymer_language.token_to_index
    )
    self.assertEqual(polymer_language.number_of_tokens, 43)

def test_add_smi(self) -> None:
    """Test add_smi."""
    content = os.linesep.join([
        'CCO CHEMBL545',
        'C CHEMBL17564',
        'CO CHEMBL14688',
        'NCCS CHEMBL602',
    ])
    with TestFileContent(content) as test_file:
        entities = ['Initiator', 'Monomer']
        polymer_language = PolymerTokenizer(entity_names=entities)
        polymer_language.add_smi(test_file.filename)
        self.assertEqual(polymer_language.number_of_tokens, 43)

def test_token_indexes_to_smiles(self) -> None:
    """Test token_indexes_to_smiles."""
    smiles = 'CCO'
    entities = ['Initiator', 'Monomer', 'Catalyst']
    polymer_language = PolymerTokenizer(entity_names=entities)
    polymer_language.add_smiles(smiles)
    token_indexes = [
        polymer_language.token_to_index[token] for token in smiles
    ]
    self.assertEqual(
        polymer_language.token_indexes_to_smiles(token_indexes), 'CCO'
    )
    # entity start/stop tokens should be ignored when decoding back to SMILES
    token_indexes = (
        [polymer_language.token_to_index['<MONOMER_START>']]
        + token_indexes
        + [polymer_language.token_to_index['<MONOMER_STOP>']]
    )
    self.assertEqual(
        polymer_language.token_indexes_to_smiles(token_indexes), 'CCO'
    )

def test_smiles_to_token_indexes(self) -> None:
    """Test smiles_to_token_indexes."""
    smiles = 'CCO'
    entities = ['Initiator', 'Monomer', 'Catalyst']
    polymer_language = PolymerTokenizer(entity_names=entities)
    polymer_language.add_smiles(smiles)
    token_indexes = [
        polymer_language.token_to_index[token] for token in smiles
    ]
    polymer_language.update_entity('monomer')
    self.assertListEqual(
        list(polymer_language.smiles_to_token_indexes(smiles)),
        [polymer_language.token_to_index['<MONOMER_START>']]
        + token_indexes
        + [polymer_language.token_to_index['<MONOMER_STOP>']],
    )
    polymer_language.update_entity('catalyst')
    self.assertListEqual(
        list(polymer_language.smiles_to_token_indexes(smiles)),
        [polymer_language.token_to_index['<CATALYST_START>']]
        + token_indexes
        + [polymer_language.token_to_index['<CATALYST_STOP>']],
    )

    # SELFIES
    polymer_language = PolymerTokenizer(
        entity_names=entities, smiles_tokenizer=split_selfies
    )
    transform = Selfies()
    selfies = transform(smiles)
    polymer_language.add_smiles(selfies)
    token_indexes = [
        polymer_language.token_to_index[token]
        for token in ['[C]', '[C]', '[O]']
    ]
    polymer_language.update_entity('monomer')
    self.assertListEqual(
        list(polymer_language.smiles_to_token_indexes(selfies)),
        [polymer_language.token_to_index['<MONOMER_START>']]
        + token_indexes
        + [polymer_language.token_to_index['<MONOMER_STOP>']],
    )

def test_vocab_roundtrip(self) -> None:
    """Test saving and reloading the vocabulary."""
    smiles = 'CCO'
    entities = ['Initiator', 'Monomer', 'Catalyst']
    source_language = PolymerTokenizer(entity_names=entities)
    source_language.add_smiles(smiles)
    # attributes to test
    vocab = source_language.token_to_index
    vocab_ = source_language.index_to_token
    max_len = source_language.max_token_sequence_length
    count = source_language.token_count
    total = source_language.number_of_tokens

    # just vocab
    with tempfile.TemporaryDirectory() as tempdir:
        source_language.save_vocabulary(tempdir)
        polymer_language = PolymerTokenizer(entity_names=entities)
        polymer_language.load_vocabulary(tempdir)
        self.assertDictEqual(vocab, polymer_language.token_to_index)
        self.assertDictEqual(vocab_, polymer_language.index_to_token)

    # pretrained
    with tempfile.TemporaryDirectory() as tempdir:
        source_language.save_pretrained(tempdir)
        polymer_language = PolymerTokenizer.from_pretrained(tempdir)
        self.assertDictEqual(vocab, polymer_language.token_to_index)
        self.assertDictEqual(vocab_, polymer_language.index_to_token)
        self.assertEqual(max_len, polymer_language.max_token_sequence_length)
        self.assertDictEqual(count, polymer_language.token_count)
        self.assertEqual(total, polymer_language.number_of_tokens)
        self.assertEqual(entities, polymer_language.entities)