import torch
import torch.nn.functional as F
from torch.autograd import Variable

# NOTE: the import paths for the project-local modules below are assumptions;
# adjust them to wherever CharEmbedding, the biaffine layers, Helpers and cle
# actually live in this repo.
from modules import CharEmbedding, ShorterBiaffine, LongerBiaffine
import Helpers
import cle

# Hyperparameter constants; the values mirror the keyword defaults used by the
# classes further down in this file.
EMBED_DIM = 100
LSTM_DIM = 400
LSTM_LAYERS = 3
REDUCE_DIM_ARC = 100
REDUCE_DIM_LABEL = 100
LEARNING_RATE = 1e-3


# NOTE: the class header was missing in the source; the name is a placeholder.
class SemtagAwareParser(torch.nn.Module):
    def __init__(self, sizes, args):
        super().__init__()
        self.use_cuda = args.cuda
        self.debug = args.debug
        # self.embeddings_chars = CharEmbedding(sizes, EMBED_DIM)
        self.embeddings_forms = torch.nn.Embedding(sizes['vocab'], EMBED_DIM)
        self.embeddings_tags = torch.nn.Embedding(sizes['postags'], EMBED_DIM)
        self.lstm = torch.nn.LSTM(500 + sizes['semtags'], LSTM_DIM, LSTM_LAYERS,
                                  batch_first=True, bidirectional=True, dropout=0.33)
        self.mlp_head = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_ARC)
        self.mlp_dep = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_ARC)
        self.mlp_deprel_head = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_LABEL)
        self.mlp_deprel_dep = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_LABEL)
        self.mlp_tag = torch.nn.Linear(300, 150)
        self.out_tag = torch.nn.Linear(150, sizes['semtags'])
        self.lstm_tag = torch.nn.LSTM(EMBED_DIM, 150, LSTM_LAYERS - 2,
                                      batch_first=True, bidirectional=True, dropout=0.33)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=0.33)
        # self.biaffine = Biaffine(REDUCE_DIM_ARC + 1, REDUCE_DIM_ARC, BATCH_SIZE)
        self.biaffine = ShorterBiaffine(REDUCE_DIM_ARC)
        self.label_biaffine = LongerBiaffine(REDUCE_DIM_LABEL, REDUCE_DIM_LABEL, sizes['deprels'])
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
        self.optimiser = torch.optim.Adam(self.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.9))
        if self.use_cuda:
            self.biaffine.cuda()
            self.label_biaffine.cuda()
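
# The arc scorer used throughout this file is ShorterBiaffine, defined
# elsewhere in the repo. As a point of reference, here is a minimal sketch of
# the biaffine attention it is assumed to implement (after Dozat & Manning
# 2017). BiaffineSketch is illustrative only and is not used by the parsers
# in this file; device placement is omitted for brevity.
class BiaffineSketch(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        # the +1 column folds a per-head bias term into the bilinear weight
        self.weight = torch.nn.Parameter(torch.randn(dim + 1, dim))

    def forward(self, head, dep):
        # head, dep: (batch x seq x dim) -> arc scores (batch x seq x seq),
        # where entry [b, i, j] scores token j as the head of token i
        batch_size, seq_len, _ = head.size()
        ones = Variable(torch.ones(batch_size, seq_len, 1))
        dep = torch.cat([dep, ones], dim=2)
        return torch.matmul(torch.matmul(dep, self.weight), head.transpose(1, 2))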
# NOTE: the class header was also missing here; the name is a placeholder.
class CharAwareParser(torch.nn.Module):
    def __init__(self, sizes, args, vocab, embeddings=None, embed_dim=100,
                 lstm_dim=400, lstm_layers=3, reduce_dim_arc=100,
                 reduce_dim_label=100, learning_rate=1e-3):
        super().__init__()
        self.use_cuda = args.use_cuda
        self.use_chars = args.use_chars
        self.save = args.save
        self.vocab = vocab  # for writer
        self.test_file = args.test[0]

        if self.use_chars:
            self.embeddings_chars = CharEmbedding(sizes['chars'], embed_dim, lstm_dim, lstm_layers)

        self.embeddings_forms = torch.nn.Embedding(sizes['vocab'], embed_dim)
        if args.embed:
            self.embeddings_forms.weight.data.copy_(vocab[0].vectors)
            self.compress = torch.nn.Linear(300, 100)
        self.embeddings_tags = torch.nn.Embedding(sizes['postags'], 100)
        self.lstm = torch.nn.LSTM(200, lstm_dim, lstm_layers,
                                  batch_first=True, bidirectional=True, dropout=0.33)
        self.mlp_head = torch.nn.Linear(2 * lstm_dim, reduce_dim_arc)
        self.mlp_dep = torch.nn.Linear(2 * lstm_dim, reduce_dim_arc)
        self.mlp_deprel_head = torch.nn.Linear(2 * lstm_dim, reduce_dim_label)
        self.mlp_deprel_dep = torch.nn.Linear(2 * lstm_dim, reduce_dim_label)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=0.33)
        # self.biaffine = Biaffine(reduce_dim_arc + 1, reduce_dim_arc, BATCH_SIZE)
        self.biaffine = ShorterBiaffine(reduce_dim_arc)
        self.label_biaffine = LongerBiaffine(reduce_dim_label, reduce_dim_label, sizes['deprels'])
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
        self.optimiser = torch.optim.Adam(self.parameters(), lr=learning_rate, betas=(0.9, 0.9))
        if self.use_cuda:
            self.biaffine.cuda()
            self.label_biaffine.cuda()
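
# Hedged construction sketch: the `sizes` keys below are exactly those read by
# the constructors in this file, and the argparse attribute names mirror what
# the class above accesses (use_cuda, use_chars, save, embed, test). All
# concrete numbers are placeholders, and _demo_build_parser is illustrative
# rather than part of the original training pipeline.
def _demo_build_parser():
    import argparse
    sizes = {'vocab': 10000, 'postags': 18, 'deprels': 40,
             'semtags': 73, 'chars': 120, 'langs': 2}
    args = argparse.Namespace(use_cuda=False, use_chars=False, save=None,
                              embed=None, test=['test.conllu'])
    return CharAwareParser(sizes, args, vocab=None)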
class TagAndParse(torch.nn.Module):
    def __init__(self, sizes, args, vocab, embeddings=None, embed_dim=100,
                 lstm_dim=400, lstm_layers=3, reduce_dim_arc=100,
                 reduce_dim_label=100, learning_rate=1e-3):
        super().__init__()
        self.use_cuda = args.use_cuda
        self.use_chars = args.use_chars
        self.save = args.save
        self.vocab = vocab  # for writer
        self.test_file = args.test[0]

        # for tagger
        self.embeddings_forms = torch.nn.Embedding(sizes['vocab'], embed_dim)
        if args.embed:
            self.embeddings_forms.weight.data.copy_(embeddings.vectors)
            # self.embeddings_forms.weight.requires_grad = False
        # self.embeddings_forms_random = torch.nn.Embedding(sizes['vocab'], embed_dim)
        self.tag_lstm = torch.nn.LSTM(embed_dim, 150, lstm_layers - 2,
                                      batch_first=True, bidirectional=True, dropout=0.33)
        self.tag_mlp = torch.nn.Linear(300, 150)
        self.tag_out = torch.nn.Linear(150, sizes['postags'])

        if self.use_chars:
            self.embeddings_chars = CharEmbedding(sizes['chars'], embed_dim, lstm_dim - 2, lstm_layers)
        self.embeddings_tags = torch.nn.Embedding(sizes['postags'], embed_dim)
        self.lstm = torch.nn.LSTM(embed_dim + 300 + sizes['postags'], lstm_dim, lstm_layers,
                                  batch_first=True, bidirectional=True, dropout=0.33)
        self.mlp_head = torch.nn.Linear(2 * lstm_dim, reduce_dim_arc)
        self.mlp_dep = torch.nn.Linear(2 * lstm_dim, reduce_dim_arc)
        self.mlp_deprel_head = torch.nn.Linear(2 * lstm_dim, reduce_dim_label)
        self.mlp_deprel_dep = torch.nn.Linear(2 * lstm_dim, reduce_dim_label)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=0.33)
        # self.biaffine = Biaffine(reduce_dim_arc + 1, reduce_dim_arc, BATCH_SIZE)
        self.biaffine = ShorterBiaffine(reduce_dim_arc)
        self.label_biaffine = LongerBiaffine(reduce_dim_label, reduce_dim_label, sizes['deprels'])
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
        params = filter(lambda p: p.requires_grad, self.parameters())
        self.optimiser = torch.optim.Adam(params, lr=learning_rate, betas=(0.9, 0.9))
        if self.use_cuda:
            self.biaffine.cuda()
            self.label_biaffine.cuda()

    def forward(self, forms, tags, pack, chars, char_pack):
        form_embeds = F.dropout(self.embeddings_forms(forms), p=0.33, training=self.training)
        # form_embeds_random = F.dropout(self.embeddings_forms_random(forms), p=0.33, training=self.training)
        if self.use_chars:
            form_embeds += F.dropout(self.embeddings_chars(chars, char_pack), p=0.33, training=self.training)

        # tag
        packed_form = torch.nn.utils.rnn.pack_padded_sequence(form_embeds, pack.tolist(), batch_first=True)
        out_tag_lstm, _ = self.tag_lstm(packed_form)
        out_tag_lstm, _ = torch.nn.utils.rnn.pad_packed_sequence(out_tag_lstm, batch_first=True)
        out_tag_mlp = F.dropout(self.relu(self.tag_mlp(out_tag_lstm)), p=0.33, training=self.training)
        y_pred_postag = self.tag_out(out_tag_mlp)

        embeds = torch.cat([form_embeds, out_tag_lstm, y_pred_postag], dim=2)

        # pack/unpack for LSTM
        embeds = torch.nn.utils.rnn.pack_padded_sequence(embeds, pack.tolist(), batch_first=True)
        output, _ = self.lstm(embeds)
        output, _ = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        # predict heads
        reduced_head_head = F.dropout(self.relu(self.mlp_head(output)), p=0.33, training=self.training)
        reduced_head_dep = F.dropout(self.relu(self.mlp_dep(output)), p=0.33, training=self.training)
        y_pred_head = self.biaffine(reduced_head_head, reduced_head_dep)

        # predict deprels using heads
        reduced_deprel_head = F.dropout(self.relu(self.mlp_deprel_head(output)), p=0.33, training=self.training)
        reduced_deprel_dep = F.dropout(self.relu(self.mlp_deprel_dep(output)), p=0.33, training=self.training)
        predicted_labels = y_pred_head.max(2)[1]
        selected_heads = torch.stack([
            torch.index_select(reduced_deprel_head[n], 0, predicted_labels[n])
            for n, _ in enumerate(predicted_labels)
        ])
        y_pred_label = self.label_biaffine(selected_heads, reduced_deprel_dep)
        y_pred_label = Helpers.extract_best_label_logits(predicted_labels, y_pred_label, pack)
        if self.use_cuda:
            y_pred_label = y_pred_label.cuda()

        return y_pred_head, y_pred_label, y_pred_postag

    '''
    1. The bare minimum that needs to be loaded is forms, upos, head, deprel
       (could change later); load those.
    2. Initialise everything else to None; load it if necessary based on
       command-line args.
    3. Pass everything, whether it has been loaded or not, to the forward
       function; if it is unnecessary it won't be used.
    '''
    def train_(self, epoch, train_loader):
        self.train()
        train_loader.init_epoch()

        for i, batch in enumerate(train_loader):
            chars, length_per_word_per_sent = None, None
            (x_forms, pack), x_tags, y_heads, y_deprels = \
                batch.form, batch.upos, batch.head, batch.deprel

            # TODO: add something similar for semtags
            if self.use_chars:
                (chars, _, length_per_word_per_sent) = batch.char

            y_pred_head, y_pred_deprel, y_pred_postags = self(
                x_forms, x_tags, pack, chars, length_per_word_per_sent)

            # reshape for cross-entropy
            batch_size, longest_sentence_in_batch = y_heads.size()

            # predictions: (B x S x S) => (B * S x S)
            # heads: (B x S) => (B * S)
            y_pred_head = y_pred_head.view(batch_size * longest_sentence_in_batch, -1)
            y_heads = y_heads.contiguous().view(batch_size * longest_sentence_in_batch)

            # predictions: (B x S x D) => (B * S x D)
            # deprels: (B x S) => (B * S)
            y_pred_deprel = y_pred_deprel.view(batch_size * longest_sentence_in_batch, -1)
            y_deprels = y_deprels.contiguous().view(batch_size * longest_sentence_in_batch)

            # same for tags
            y_pred_postags = y_pred_postags.view(batch_size * longest_sentence_in_batch, -1)
            y_tags = x_tags.contiguous().view(batch_size * longest_sentence_in_batch)

            # sum losses
            train_loss = self.criterion(y_pred_head, y_heads) \
                + self.criterion(y_pred_deprel, y_deprels) \
                + 0.75 * self.criterion(y_pred_postags, y_tags)

            self.zero_grad()
            train_loss.backward()
            self.optimiser.step()

            print("Epoch: {}\t{}/{}\tloss: {}".format(
                epoch, (i + 1) * len(x_forms), len(train_loader.dataset), train_loss.data[0]))

        if self.save:
            with open(self.save[0], "wb") as f:
                torch.save(self.state_dict(), f)

    def evaluate_(self, test_loader, print_conll=False):
        las_correct, uas_correct, total = 0, 0, 0
        self.eval()
        for i, batch in enumerate(test_loader):
            chars, length_per_word_per_sent = None, None
            (x_forms, pack), x_tags, y_heads, y_deprels = \
                batch.form, batch.upos, batch.head, batch.deprel

            # TODO: add something similar for semtags
            if self.use_chars:
                (chars, _, length_per_word_per_sent) = batch.char

            mask = torch.zeros(pack.size()[0], max(pack)).type(torch.LongTensor)
            for n, size in enumerate(pack):
                mask[n, 0:size] = 1

            # get labels
            # TODO: ensure well-formed tree
            y_pred_head, y_pred_deprel, y_pred_postags = [
                i.max(2)[1] for i in
                self(x_forms, x_tags, pack, chars, length_per_word_per_sent)
            ]

            mask = mask.type(torch.ByteTensor)
            if self.use_cuda:
                mask = mask.cuda()
            mask = Variable(mask)
            mask[:, 0] = 0  # don't score the root position

            heads_correct = ((y_heads == y_pred_head) * mask)
            deprels_correct = ((y_deprels == y_pred_deprel) * mask)

            # excepts should never trigger; leave them in just in case
            try:
                uas_correct += heads_correct.nonzero().size(0)
            except RuntimeError:
                pass

            try:
                las_correct += (heads_correct * deprels_correct).nonzero().size(0)
            except RuntimeError:
                pass

            total += mask.nonzero().size(0)

            if print_conll:
                deprel_vocab = self.vocab[1]
                deprels = [deprel_vocab.itos[i.data[0]] for i in y_pred_deprel.view(-1, 1)]

                heads_softmaxes = self(x_forms, x_tags, pack, chars, length_per_word_per_sent)[0][0]
                heads_softmaxes = F.softmax(heads_softmaxes, dim=1)

                json = cle.mst(heads_softmaxes.data.numpy())
                # json = cle.mst(i, pad) for i, pad in zip(self(x_forms, x_tags, pack, chars,
                #                                              length_per_word_per_sent)[0], pack)

                Helpers.write_to_conllu(self.test_file, json, deprels, i)

        print("UAS = {}/{} = {}\nLAS = {}/{} = {}".format(
            uas_correct, total, uas_correct / total, las_correct, total, las_correct / total))
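
# Sketch of the driver loop implied by the strategy notes above: train_ runs
# one pass over the data per call and evaluate_ reports UAS/LAS. The loaders
# are assumed to be torchtext-style iterators exposing init_epoch() and
# batches with .form/.upos/.head/.deprel (plus .char when characters are on).
def _demo_run_tag_and_parse(model, train_loader, test_loader, epochs=10):
    for epoch in range(epochs):
        model.train_(epoch, train_loader)
    model.evaluate_(test_loader, print_conll=False)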
class Parser(torch.nn.Module):
    def __init__(self, sizes, vocab, args):
        super().__init__()
        self.use_cuda = args.cuda
        self.debug = args.debug
        self.use_chars = args.chars

        if self.use_chars:
            self.embeddings_chars = CharEmbedding(sizes['chars'], EMBED_DIM, LSTM_DIM, LSTM_LAYERS)

        self.embeddings_forms = torch.nn.Embedding(sizes['vocab'], EMBED_DIM)
        if args.embed:
            self.embeddings_forms.weight.data.copy_(vocab.vectors)
        self.embeddings_forms_rand = torch.nn.Embedding(sizes['vocab'], EMBED_DIM)
        # self.embeddings_tags = torch.nn.Embedding(sizes['postags'], EMBED_DIM)
        self.lstm = torch.nn.LSTM(700 + sizes['semtags'] + sizes['postags'], LSTM_DIM, LSTM_LAYERS + 1,
                                  batch_first=True, bidirectional=True, dropout=0.33)
        self.mlp_head = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_ARC)
        self.mlp_dep = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_ARC)
        self.mlp_deprel_head = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_LABEL)
        self.mlp_deprel_dep = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_LABEL)

        # POS head
        self.mlp_tag = torch.nn.Linear(300, 150)
        self.out_tag = torch.nn.Linear(150, sizes['postags'])
        # semtag head
        self.mlp_semtag = torch.nn.Linear(500, 200)
        self.out_semtag = torch.nn.Linear(200, sizes['semtags'])

        self.lstm_tag = torch.nn.LSTM(EMBED_DIM * 2, 150, 1,
                                      batch_first=True, bidirectional=True, dropout=0.33)
        self.lstm_semtag = torch.nn.LSTM(EMBED_DIM * 5 + sizes['postags'], 250, 1,
                                         batch_first=True, bidirectional=True, dropout=0.33)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=0.33)
        # self.biaffine = Biaffine(REDUCE_DIM_ARC + 1, REDUCE_DIM_ARC, BATCH_SIZE)
        self.biaffine = ShorterBiaffine(REDUCE_DIM_ARC)
        self.label_biaffine = LongerBiaffine(REDUCE_DIM_LABEL, REDUCE_DIM_LABEL, sizes['deprels'])
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
        self.optimiser = torch.optim.Adam(self.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.9))
        if self.use_cuda:
            self.biaffine.cuda()
            self.label_biaffine.cuda()

    def forward(self, forms, tags, semtags, pack, chars, char_pack):
        # embed and dropout forms; concat
        # TODO: same mask embedding
        if self.use_chars:
            char_embeds = self.dropout(self.embeddings_chars(chars, char_pack))
        form_embeds = self.dropout(self.embeddings_forms(forms))
        # tag_embeds = self.dropout(self.embeddings_tags(tags))

        # task-specific embeddings
        form_embeds_rand = self.dropout(self.embeddings_forms_rand(forms))

        # merge all embeddings
        if self.use_chars:
            form_embeds += char_embeds
        form_embeds = torch.cat([form_embeds_rand, form_embeds], dim=2)

        # pack/unpack for lstm_tag
        tagging_embeds = torch.nn.utils.rnn.pack_padded_sequence(form_embeds, pack.tolist(), batch_first=True)
        output_tag, _ = self.lstm_tag(tagging_embeds)
        output_tag, _ = torch.nn.utils.rnn.pad_packed_sequence(output_tag, batch_first=True)

        # POS
        mlp_tag = self.dropout(self.relu(self.mlp_tag(output_tag)))
        y_pred_tag = self.out_tag(mlp_tag)

        # concat original embeddings with POS LSTM and softmax outs
        output_tag = torch.cat([output_tag, form_embeds, y_pred_tag], dim=2)

        # pack/unpack for lstm_semtag
        semtagging_embeds = torch.nn.utils.rnn.pack_padded_sequence(output_tag, pack.tolist(), batch_first=True)
        output_semtag, _ = self.lstm_semtag(semtagging_embeds)
        output_semtag, _ = torch.nn.utils.rnn.pad_packed_sequence(output_semtag, batch_first=True)

        # semtags
        mlp_semtag = self.dropout(self.relu(self.mlp_semtag(output_semtag)))
        y_pred_semtag = self.out_semtag(mlp_semtag)

        # print(output_tag.size(), output_semtag.size())  # debug

        # concat original embeddings with semtag LSTM and softmax outs
        embeds = torch.cat([form_embeds, output_semtag, y_pred_semtag, y_pred_tag], dim=2)
        # print(embeds.size())  # debug

        # pack/unpack for the parsing LSTM
        embeds = torch.nn.utils.rnn.pack_padded_sequence(embeds, pack.tolist(), batch_first=True)
        output, _ = self.lstm(embeds)
        output, _ = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        # predict heads
        reduced_head_head = self.dropout(self.relu(self.mlp_head(output)))
        reduced_head_dep = self.dropout(self.relu(self.mlp_dep(output)))
        y_pred_head = self.biaffine(reduced_head_head, reduced_head_dep)

        if self.debug:
            return y_pred_head, Variable(torch.rand(y_pred_head.size()))

        # predict deprels using heads
        reduced_deprel_head = self.dropout(self.relu(self.mlp_deprel_head(output)))
        reduced_deprel_dep = self.dropout(self.relu(self.mlp_deprel_dep(output)))
        predicted_labels = y_pred_head.max(2)[1]
        selected_heads = torch.stack([
            torch.index_select(reduced_deprel_head[n], 0, predicted_labels[n])
            for n, _ in enumerate(predicted_labels)
        ])
        y_pred_label = self.label_biaffine(selected_heads, reduced_deprel_dep)
        y_pred_label = Helpers.extract_best_label_logits(predicted_labels, y_pred_label, pack)
        if self.use_cuda:
            y_pred_label = y_pred_label.cuda()

        return y_pred_head, y_pred_label, y_pred_semtag, y_pred_tag

    def train_(self, epoch, train_loader):
        self.train()
        train_loader.init_epoch()

        for i, batch in enumerate(train_loader):
            (x_forms, pack), (chars, _, length_per_word_per_sent), x_tags, y_heads, y_deprels, x_sem = \
                batch.form, batch.char, batch.upos, batch.head, batch.deprel, batch.sem

            mask = torch.zeros(pack.size()[0], max(pack)).type(torch.LongTensor)
            for n, size in enumerate(pack):
                mask[n, 0:size] = 1

            y_pred_head, y_pred_deprel, y_pred_semtag, y_pred_tag = self(
                x_forms, x_tags, x_sem, pack, chars, length_per_word_per_sent)

            # reshape for cross-entropy
            batch_size, longest_sentence_in_batch = y_heads.size()

            # predictions: (B x S x S) => (B * S x S)
            # heads: (B x S) => (B * S)
            y_pred_head = y_pred_head.view(batch_size * longest_sentence_in_batch, -1)
            y_heads = y_heads.contiguous().view(batch_size * longest_sentence_in_batch)

            # predictions: (B x S x D) => (B * S x D)
            # deprels: (B x S) => (B * S)
            y_pred_deprel = y_pred_deprel.view(batch_size * longest_sentence_in_batch, -1)
            y_deprels = y_deprels.contiguous().view(batch_size * longest_sentence_in_batch)

            # semtags
            y_pred_semtag = y_pred_semtag.view(batch_size * longest_sentence_in_batch, -1)
            x_sem = x_sem.contiguous().view(batch_size * longest_sentence_in_batch)

            # POS tags
            y_pred_tag = y_pred_tag.view(batch_size * longest_sentence_in_batch, -1)
            x_tags = x_tags.contiguous().view(batch_size * longest_sentence_in_batch)

            # sum losses
            train_loss = self.criterion(y_pred_head, y_heads)
            if not self.debug:
                train_loss += self.criterion(y_pred_deprel, y_deprels)
                train_loss += 0.5 * self.criterion(y_pred_semtag, x_sem)
                train_loss += 0.5 * self.criterion(y_pred_tag, x_tags)

            self.zero_grad()
            train_loss.backward()
            self.optimiser.step()

            print("Epoch: {}\t{}/{}\tloss: {}".format(
                epoch, (i + 1) * len(x_forms), len(train_loader.dataset), train_loss.data[0]))

    def evaluate_(self, test_loader):
        las_correct, uas_correct, semtags_correct, tags_correct, total = 0, 0, 0, 0, 0
        self.eval()
        for i, batch in enumerate(test_loader):
            (x_forms, pack), (chars, _, length_per_word_per_sent), x_tags, y_heads, y_deprels, x_sem = \
                batch.form, batch.char, batch.upos, batch.head, batch.deprel, batch.sem

            mask = torch.zeros(pack.size()[0], max(pack)).type(torch.LongTensor)
            for n, size in enumerate(pack):
                mask[n, 0:size] = 1

            # get labels
            # TODO: ensure well-formed tree
            y_pred_head, y_pred_deprel, y_pred_semtag, y_pred_tag = [
                i.max(2)[1] for i in
                self(x_forms, x_tags, x_sem, pack, chars, length_per_word_per_sent)
            ]

            mask = mask.type(torch.ByteTensor)
            if self.use_cuda:
                mask = mask.cuda()
            mask = Variable(mask)

            heads_correct = ((y_heads == y_pred_head) * mask)
            deprels_correct = ((y_deprels == y_pred_deprel) * mask)
            # tags_correct = ((x_tags == y_pred_tag) * mask)

            # excepts should never trigger; leave them in just in case
            try:
                uas_correct += heads_correct.nonzero().size(0)
            except RuntimeError:
                pass

            try:
                las_correct += (heads_correct * deprels_correct).nonzero().size(0)
            except RuntimeError:
                pass

            try:
                semtags_correct += ((x_sem == y_pred_semtag) * mask).nonzero().size(0)
            except RuntimeError:
                pass

            try:
                tags_correct += ((x_tags == y_pred_tag) * mask).nonzero().size(0)
            except RuntimeError:
                pass

            total += mask.nonzero().size(0)

        print("UAS = {}/{} = {}\nLAS = {}/{} = {}\nTAG = {}/{} = {}\nSEMTAG = {}/{} = {}".format(
            uas_correct, total, uas_correct / total,
            las_correct, total, las_correct / total,
            tags_correct, total, tags_correct / total,
            semtags_correct, total, semtags_correct / total))
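
# Every evaluate_ above builds a (batch x max_len) mask from the sentence
# lengths in `pack` and counts hits only under it. The same pattern as a
# standalone helper (a sketch; the real methods inline it):
def _demo_masked_accuracy(gold, pred, lengths):
    # gold, pred: (batch x max_len) LongTensors; lengths: per-sentence token counts
    mask = torch.zeros(gold.size()).type(torch.LongTensor)
    for n, size in enumerate(lengths):
        mask[n, 0:size] = 1
    correct = (gold == pred).type(torch.LongTensor) * mask
    return correct.sum(), mask.sum()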
class CSParser(torch.nn.Module):
    def __init__(self, sizes, args):
        super().__init__()
        self.use_cuda = args.cuda
        self.debug = args.debug
        # self.embeddings_chars = CharEmbedding(sizes, EMBED_DIM)
        self.embeddings_forms = torch.nn.Embedding(sizes['vocab'], EMBED_DIM)
        self.embeddings_tags = torch.nn.Embedding(sizes['postags'], EMBED_DIM)
        self.embeddings_langs = torch.nn.Embedding(sizes['langs'], EMBED_DIM)
        self.lstm = torch.nn.LSTM(2 * EMBED_DIM, LSTM_DIM, LSTM_LAYERS,
                                  batch_first=True, bidirectional=True, dropout=0.33)
        self.mlp_head = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_ARC)
        self.mlp_dep = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_ARC)
        self.mlp_deprel_head = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_LABEL)
        self.mlp_deprel_dep = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_LABEL)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=0.33)
        # self.biaffine = Biaffine(REDUCE_DIM_ARC + 1, REDUCE_DIM_ARC, BATCH_SIZE)
        self.biaffine = ShorterBiaffine(REDUCE_DIM_ARC)
        self.label_biaffine = LongerBiaffine(REDUCE_DIM_LABEL, REDUCE_DIM_LABEL, sizes['deprels'])
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)

        # langid layers
        self.langid_mlp = torch.nn.Linear(2 * LSTM_DIM, REDUCE_DIM_ARC)
        self.langid_out = torch.nn.Linear(REDUCE_DIM_ARC, sizes['langs'])

        # build the optimiser only after every submodule (including the langid
        # head) is registered, so that self.parameters() covers them all
        self.optimiser = torch.optim.Adam(self.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.9))

        if self.use_cuda:
            self.biaffine.cuda()
            self.label_biaffine.cuda()

    def langid_fwd(self, forms, tags, pack):
        form_embeds = self.dropout(self.embeddings_forms(forms))
        tag_embeds = self.dropout(self.embeddings_tags(tags))
        embeds = torch.cat([form_embeds, tag_embeds], dim=2)
        embeds = torch.nn.utils.rnn.pack_padded_sequence(embeds, pack.tolist(), batch_first=True)
        lstm_out, _ = self.lstm(embeds)
        lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        mlp_out = self.dropout(self.relu(self.langid_mlp(lstm_out)))
        return self.langid_out(mlp_out)

    def forward(self, forms, tags, pack):
        # embed and dropout forms and tags; concat
        # TODO: same mask embedding
        # char_embeds = self.embeddings_chars(chars, pack)
        form_embeds = self.dropout(self.embeddings_forms(forms))
        tag_embeds = self.dropout(self.embeddings_tags(tags))
        # lang_embeds = self.dropout(self.embeddings_tags(tags))
        embeds = torch.cat([form_embeds, tag_embeds], dim=2)

        # pack/unpack for LSTM
        embeds = torch.nn.utils.rnn.pack_padded_sequence(embeds, pack.tolist(), batch_first=True)
        output, _ = self.lstm(embeds)
        output, _ = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        # predict heads
        reduced_head_head = self.dropout(self.relu(self.mlp_head(output)))
        reduced_head_dep = self.dropout(self.relu(self.mlp_dep(output)))
        y_pred_head = self.biaffine(reduced_head_head, reduced_head_dep)

        if self.debug:
            return y_pred_head, Variable(torch.rand(y_pred_head.size()))

        # predict deprels using heads
        reduced_deprel_head = self.dropout(self.relu(self.mlp_deprel_head(output)))
        reduced_deprel_dep = self.dropout(self.relu(self.mlp_deprel_dep(output)))
        predicted_labels = y_pred_head.max(2)[1]
        selected_heads = torch.stack([
            torch.index_select(reduced_deprel_head[n], 0, predicted_labels[n])
            for n, _ in enumerate(predicted_labels)
        ])
        y_pred_label = self.label_biaffine(selected_heads, reduced_deprel_dep)
        y_pred_label = Helpers.extract_best_label_logits(predicted_labels, y_pred_label, pack)
        if self.use_cuda:
            y_pred_label = y_pred_label.cuda()

        return y_pred_head, y_pred_label

    def train_(self, epoch, train_loader):
        self.train()
        train_loader.init_epoch()

        for i, batch in enumerate(train_loader):
            (x_forms, pack), x_tags, y_heads, y_deprels, y_langs = \
                batch.form, batch.upos, batch.head, batch.deprel, batch.misc

            mask = torch.zeros(pack.size()[0], max(pack)).type(torch.LongTensor)
            for n, size in enumerate(pack):
                mask[n, 0:size] = 1

            y_pred_head, y_pred_deprel = self(x_forms, x_tags, pack)
            y_pred_langs = self.langid_fwd(x_forms, x_tags, pack)

            # reshape for cross-entropy
            batch_size, longest_sentence_in_batch = y_heads.size()

            # predictions: (B x S x S) => (B * S x S)
            # heads: (B x S) => (B * S)
            y_pred_head = y_pred_head.view(batch_size * longest_sentence_in_batch, -1)
            y_heads = y_heads.contiguous().view(batch_size * longest_sentence_in_batch)

            # predictions: (B x S x D) => (B * S x D)
            # deprels: (B x S) => (B * S)
            y_pred_deprel = y_pred_deprel.view(batch_size * longest_sentence_in_batch, -1)
            y_deprels = y_deprels.contiguous().view(batch_size * longest_sentence_in_batch)

            # langid
            y_pred_langs = y_pred_langs.view(batch_size * longest_sentence_in_batch, -1)
            y_langs = y_langs.contiguous().view(batch_size * longest_sentence_in_batch)

            train_loss = self.criterion(y_pred_head, y_heads)
            if not self.debug:
                train_loss += self.criterion(y_pred_deprel, y_deprels)
                # the language-ID loss is subtracted rather than added
                train_loss -= self.criterion(y_pred_langs, y_langs)

            self.zero_grad()
            train_loss.backward()
            self.optimiser.step()

            print("Epoch: {}\t{}/{}\tloss: {}".format(
                epoch, (i + 1) * len(x_forms), len(train_loader.dataset), train_loss.data[0]))

    def evaluate_(self, test_loader):
        las_correct, uas_correct, total = 0, 0, 0
        self.eval()
        for i, batch in enumerate(test_loader):
            (x_forms, pack), x_tags, y_heads, y_deprels = \
                batch.form, batch.upos, batch.head, batch.deprel

            mask = torch.zeros(pack.size()[0], max(pack)).type(torch.LongTensor)
            for n, size in enumerate(pack):
                mask[n, 0:size] = 1

            # get labels
            # TODO: ensure well-formed tree
            y_pred_head, y_pred_deprel = [i.max(2)[1] for i in self(x_forms, x_tags, pack)]

            mask = mask.type(torch.ByteTensor)
            if self.use_cuda:
                mask = mask.cuda()
            mask = Variable(mask)

            heads_correct = ((y_heads == y_pred_head) * mask)
            deprels_correct = ((y_deprels == y_pred_deprel) * mask)

            # excepts should never trigger; leave them in just in case
            try:
                uas_correct += heads_correct.nonzero().size(0)
            except RuntimeError:
                pass

            try:
                las_correct += (heads_correct * deprels_correct).nonzero().size(0)
            except RuntimeError:
                pass

            total += mask.nonzero().size(0)

        print("UAS = {}/{} = {}\nLAS = {}/{} = {}".format(
            uas_correct, total, uas_correct / total, las_correct, total, las_correct / total))
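
# All the train_ methods above share one reshape idiom: (B x S x C) scores and
# (B x S) gold indices are flattened so CrossEntropyLoss (with ignore_index=-1
# masking the padding) can score every token at once. A minimal sketch:
def _demo_flatten_for_xent(scores, gold, criterion):
    # scores: (B x S x C); gold: (B x S), padded positions set to -1
    batch_size, seq_len = gold.size()
    return criterion(scores.view(batch_size * seq_len, -1),
                     gold.contiguous().view(batch_size * seq_len))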