Example #1
import torch

# SeqsTokenizer / EntityTokenizer / DescTokenizer come from the KEMCE project
seq_tokenizer = SeqsTokenizer(dict_file)        # medical-code vocabulary
ent_tokenize = EntityTokenizer(ent_vocab_file)  # entity vocabulary
desc_tokenize = DescTokenizer(code2desc_file)   # code-to-description mapping

# use a GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

# a two-visit record; one code (D_V4581, at position 2) is masked for the MLM objective
visit = '[CLS] D_41401 D_V4581 D_53081 D_496 D_30000 [SEP] D_4241 D_2720 D_V4581 D_4538 [SEP]'
mask_input = '[CLS] D_41401 [MASK] D_53081 D_496 D_30000 [SEP] D_4241 D_2720 D_V4581 D_4538 [SEP]'
mask_label = '[PAD] [PAD] D_V4581 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'  # label only at the masked position
ent_input_str = '[UNK] D_41401 D_V4581 D_53081 D_496 D_30000 [UNK] D_4241 D_2720 D_V4581 D_4538 [UNK]'  # special tokens map to [UNK]
token_type = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]  # 0 = first visit (through the first [SEP]), 1 = second visit
ent_mask =   [1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0]
input_mask = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]  # no padding, so every position is attended to
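# Aside: token_type can be derived from the token string rather than written by
# hand; segment 0 runs up to and including the first [SEP], segment 1 covers the
# second visit. The derivation below is illustrative only, not part of KEMCE.
tokens = visit.split()
first_sep = tokens.index('[SEP]')
derived_token_type = [0 if i <= first_sep else 1 for i in range(len(tokens))]
assert derived_token_type == token_type  # [0]*7 + [1]*5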

# tokenize() returns (tokens, ids); only the ids are needed here
_, mask_input = seq_tokenizer.tokenize(mask_input)
_, mask_label = seq_tokenizer.tokenize(mask_label)
_, ent_input = ent_tokenize.tokenize(ent_input_str)
_, desc_input = desc_tokenize.tokenize(ent_input_str)

# look up dense embeddings for the entity ids (ent_embedding is defined elsewhere)
ent_input_tensor = torch.tensor(ent_input).long()
ent_input_embed = ent_embedding(ent_input_tensor).to(device)

seq_input_tensor = torch.tensor(mask_input).long().to(device)
token_type_tensor = torch.tensor(token_type).long().to(device)
desc_input_tensor = torch.tensor(desc_input).long().to(device)
ent_mask_tensor = torch.tensor(ent_mask).to(device)
input_mask_tensor = torch.tensor(input_mask).to(device)
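# Aside: most BERT-style PyTorch models expect a leading batch dimension; whether
# the KEMCE model does is an assumption here, but a single example can be batched
# with unsqueeze(0) if needed:
batched_seq_input = seq_input_tensor.unsqueeze(0)     # shape (1, 12)
batched_token_type = token_type_tensor.unsqueeze(0)
batched_input_mask = input_mask_tensor.unsqueeze(0)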

masked_index = 2
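# The hard-coded masked_index is simply the position of [MASK] in the masked
# visit string; it can be derived instead of written by hand (the string literal
# is repeated here because mask_input now holds ids):
masked_tokens = '[CLS] D_41401 [MASK] D_53081 D_496 D_30000 [SEP] D_4241 D_2720 D_V4581 D_4538 [SEP]'.split()
assert masked_tokens.index('[MASK]') == masked_index  # == 2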
Example #2
import pickle

seqs_file = data_path + 'outputs/kemce/data/raw/mimic.seqs'
ent_file = data_path + 'outputs/kemce/data/raw/mimic.entity'
config_json = data_path + 'src/KEMCE/kemce_config.json'

config = BertConfig.from_json_file(config_json)
seqs_tokenizer = SeqsTokenizer(seqs_vocab_file)
ent_tokenize = EntityTokenizer(ent_vocab_file)
desc_tokenize = DescTokenizer(code2desc_file)

with open(seqs_file, 'rb') as f:
    seqs = pickle.load(f)
with open(ent_file, 'rb') as f:
    ents = pickle.load(f)

visit_sample = seqs[0]
ent_sample = ents[0]

seq_tokens, seq_input = seqs_tokenizer.tokenize(visit_sample)
ent_tokens, ent_input = ent_tokenize.tokenize(ent_sample)
desc_tokens, desc_input = desc_tokenize.tokenize(ent_sample)
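# Sanity check (an assumption generalising from Example #1, where the code,
# entity and description sequences are position-aligned and equal in length):
assert len(seq_tokens) == len(ent_tokens) == len(desc_tokens)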

# mask the token at position 7 and keep its original id as the only label
masked_index = 7
seq_tokens[masked_index] = '[MASK]'
mask_labels = [0] * len(seq_tokens)
mask_labels[masked_index] = seq_input[masked_index]
print(visit_sample)
print(seq_input)
print(mask_labels)
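# Hedged aside (not shown in these examples): with 0 marking every unmasked
# position in mask_labels, a standard PyTorch MLM loss would skip those
# positions via ignore_index=0. The shapes below are illustrative only.
import torch
import torch.nn as nn
label_tensor = torch.tensor(mask_labels).long()
vocab_size = max(mask_labels) + 1                       # enough classes for this toy example
dummy_logits = torch.randn(len(mask_labels), vocab_size)
mlm_loss = nn.CrossEntropyLoss(ignore_index=0)(dummy_logits, label_tensor)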

# re-encode the token list now that [MASK] has been inserted
seq_input = seqs_tokenizer.convert_tokens_to_ids(seq_tokens)

# build the entity mask from the entity tokens
ent_mask = []
for ent in ent_tokens: