def evaluate(model, x_test, y_test):
    inputs = NamedTensor(x_test, ("batch", "slen"))
    y_test = NamedTensor(y_test, ("batch",))
    preds, vector = model(inputs)
    preds = preds.max("classes")[1]
    eval_acc = (preds == y_test).sum("batch").item() / len(y_test)
    return eval_acc, vector.cpu().detach().numpy()
def test_model(model, input_file, filename, TEXT, output_name="classes", use_cuda=False):
    row_num = 0
    V = len(TEXT.vocab)
    with open(filename, "w") as fout:
        print('id,word', file=fout)
        with open(input_file, 'r') as fin:
            for line in tqdm(fin.readlines()):
                if use_cuda:
                    batch_text = NamedTensor(
                        Tensor([TEXT.vocab.stoi[s] for s in line.split(' ')[:-1]]).unsqueeze(1).long(),
                        names=('seqlen', 'batch')
                    ).cuda()
                else:
                    batch_text = NamedTensor(
                        Tensor([TEXT.vocab.stoi[s] for s in line.split(' ')[:-1]]).unsqueeze(1).long(),
                        names=('seqlen', 'batch')
                    )
                _, best_words = ntorch.topk(
                    model(batch_text)[{'seqlen': -1, 'classes': slice(1, V)}],
                    output_name, 20)
                best_words += 1
                for row in best_words.cpu().numpy():
                    row_num += 1
                    print(f'{row_num},{tensor_to_text(row, TEXT)}', file=fout)
def forward(self, hidden):
    dotted = (hidden * hidden.rename("seqlen", "seqlen2")).sum("embedding")
    mask = torch.arange(hidden.size('seqlen'))
    mask = (NamedTensor(mask, names='seqlen')
            < NamedTensor(mask, names='seqlen2')).float()
    mask[mask.byte()] = -inf
    if self.cuda_enabled:
        attn = ((dotted + mask.cuda())
                / (hidden.size("embedding") ** .5)).softmax('seqlen2')
    else:
        attn = ((dotted + mask)
                / (hidden.size("embedding") ** .5)).softmax('seqlen2')
    return (attn * hidden.rename('seqlen', 'seqlen2')).sum('seqlen2')
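# The forward above is masked scaled dot-product self-attention:
#   attn = softmax((H Hᵀ + M) / sqrt(d_embedding)) over "seqlen2",
# where M is -inf wherever seqlen < seqlen2, so a position never attends to
# future positions. A hedged, standalone sketch of building that causal mask
# (only the "seqlen"/"seqlen2" names from the module above are reused; torch,
# NamedTensor, and `inf` imports are assumed as elsewhere in these snippets):
def build_causal_mask(seqlen):
    idx = torch.arange(seqlen)
    mask = (NamedTensor(idx, names="seqlen")
            < NamedTensor(idx, names="seqlen2")).float()
    mask[mask.byte()] = -inf  # same in-place fill used in forward above
    return mask  # ("seqlen", "seqlen2"): 0 on/below diagonal, -inf above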
def train_test_one_split(cv, train_index, test_index):
    x_train, y_train = X[train_index], Y[train_index]
    x_test, y_test = X[test_index], Y[test_index]
    x_train = torch.from_numpy(x_train).long()
    y_train = torch.from_numpy(y_train).long()
    dataset_train = TensorDataset(x_train, y_train)
    train_loader = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=False,
    )
    x_test = torch.from_numpy(x_test).long()
    y_test = torch.from_numpy(y_test).long()

    model = CNN(kernel_sizes, num_filters, embedding_dim, pretrained_embeddings)
    if cv == 0:
        print("\n{}\n".format(str(model)))
    if use_cuda:
        model = model.cuda()

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(parameters, lr=0.0002)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(10):
        tic = time.time()
        eval_acc, sentence_vector = evaluate(model, x_test, y_test)
        model.train()
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs, labels
            inputs = NamedTensor(inputs, ("batch", "slen"))
            labels = NamedTensor(labels, ("batch",))
            preds, _ = model(inputs)
            loss = preds.reduce2(labels, loss_fn, ("batch", "classes"))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        eval_acc, sentence_vector = evaluate(model, x_test, y_test)
        print("[epoch: {:d}] train_loss: {:.3f} acc: {:.3f} ({:.1f}s)".format(
            epoch, loss.item(), eval_acc, time.time() - tic))
    return eval_acc, sentence_vector
def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            data = NamedTensor(data, ("batch", "ch", "height", "width"))
            recon_batch, normal = model(data)
            test_loss += loss_function(recon_batch, data, normal).item()
            if i == 0:
                n = min(data.size("batch"), 8)
                group = [
                    data.narrow("batch", 0, n),
                    recon_batch.split(x=("ch", "height", "width"),
                                      height=28, width=28).narrow("batch", 0, n),
                ]
                comparison = ntorch.cat(group, "batch")
                save_image(
                    comparison.values.cpu(),
                    "results/reconstruction_" + str(epoch) + ".png",
                    nrow=n,
                )
    test_loss /= len(test_loader.dataset)
    print("====> Test set loss: {:.4f}".format(test_loss))
def init_state(self, N):
    if self._N != N:
        self._N = N
        self._state = (
            NamedTensor(
                torch.zeros(self.nlayers, N, self.rnn_sz).type(self.dtype),
                names=("layers", "batch", "rnns"),
            ).to(self.lutx.weight.device),
            NamedTensor(
                torch.zeros(self.nlayers, N, self.rnn_sz).type(self.dtype),
                names=("layers", "batch", "rnns"),
            ).to(self.lutx.weight.device),
        )
    return self._state
def _gen_timing_signal(self, length, channels, min_timescale=1.0, max_timescale=1e4):
    """
    Generates a [length, channels] timing signal consisting of sinusoids,
    returned as a NamedTensor with names ("seqlen", "embedding").
    Adapted from:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
    """
    position = np.arange(length)
    num_timescales = channels // 2
    log_timescale_increment = math.log(
        float(max_timescale) / float(min_timescale)) / (float(num_timescales) - 1)
    inv_timescales = min_timescale * np.exp(
        np.arange(num_timescales).astype(np.float) * -log_timescale_increment)
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, channels % 2]],
                    "constant", constant_values=[0.0, 0.0])
    signal = signal.reshape([length, channels])
    return NamedTensor(
        torch.from_numpy(signal).type(torch.FloatTensor).to(self.device),
        names=("seqlen", "embedding"),
    )
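# A hedged usage sketch for the timing signal above: it is typically added to
# the token embeddings before the encoder. `embedded` and its ("batch",
# "seqlen", "embedding") names are illustrative assumptions, not from the
# original snippet; NamedTensor addition aligns by name and broadcasts over
# the missing "batch" dimension.
signal = self._gen_timing_signal(embedded.shape["seqlen"],
                                 embedded.shape["embedding"])
embedded = embedded + signal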
def sample_a(self, probs, logits, K, dim, sampledim):
    a_s = NamedTensor(
        torch.multinomial(
            probs.stack(("batch", "time"), "samplebatch")
                 .transpose("samplebatch", dim)
                 .values,
            K,
            True,
        ),
        names=("samplebatch", sampledim),
    ).chop("samplebatch", ("batch", "time"), batch=logits.shape["batch"])
    a_s_log_p = logits.gather(dim, a_s, sampledim)
    return a_s, a_s_log_p
def _gen_bias_mask(self, max_length):
    """
    Generates bias values (-Inf) to mask future timesteps during attention
    """
    np_mask = np.triu(np.full([max_length, max_length], -np.inf), 1)
    torch_mask = torch.from_numpy(np_mask).type(torch.FloatTensor)
    torch_mask = NamedTensor(torch_mask, names=("queries", "seqlen"))
    return torch_mask
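# Hedged sketch of how the bias mask above could enter an attention layer:
# it is added to the raw attention logits before the softmax, so masked
# (future) positions get probability ~0. The `logits` tensor and its
# ("batch", "queries", "seqlen") names are assumptions for illustration;
# the addition broadcasts by name over "batch".
mask = self._gen_bias_mask(max_length)
attn = (logits + mask).softmax("seqlen")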
def load_kaggle_data(path_to_data, TEXT, device):
    with open(path_to_data) as f:
        data = f.read()
    sentences = [sent for sent in data.split("\n")[:-1]]
    convert_sent_to_int = lambda sent: [
        TEXT.vocab.stoi[word] for word in sent.split(" ")[:-1]
    ]
    sent_list = np.array([convert_sent_to_int(sent) for sent in sentences])
    return NamedTensor(torch.from_numpy(sent_list).to(device),
                       names=("batch", "seqlen"))
def numericalize(self, arr, device=None):
    vals = super(NamedField, self).numericalize(arr, device=device)
    if isinstance(vals, list) or isinstance(vals, tuple):
        assert len(vals) == 2
        var, lengths = vals
        if self.sequential and not self.batch_first:
            var = NamedTensor(var, self.names + ("batch",))
        else:
            var = NamedTensor(var, ("batch",) + self.names)
        lengths = NamedTensor(lengths, ("batch",))
        return var, lengths
    else:
        if self.sequential and not self.batch_first:
            var = NamedTensor(vals, self.names + ("batch",))
        else:
            var = NamedTensor(vals, ("batch",) + self.names)
        return var
def forward(self, x, s, x_info, r, r_info, ue, ue_info, ut, ut_info, v2d):
    emb = self.lutx(x)
    N = emb.shape["batch"]
    T = emb.shape["time"]

    e = self.lute(r[0]).rename("e", "r")
    t = self.lutt(r[1]).rename("t", "r")
    v = self.lutv(r[2]).rename("v", "r")
    # r: R x N x Er, Wa r: R x N x H
    r = self.Wa(ntorch.cat([e, t, v], dim="r").tanh())

    if not self.inputfeed:
        # rnn_o: T x N x H
        rnn_o, s = self.rnn(emb, s, x_info.lengths)
        # ea: T x N x R
        _, ea, ec = attn(rnn_o, r, r_info.mask)
        if self.noattn:
            ec = r.mean("els").repeat("time", ec.shape["time"])
        self.ea = ea
        out = self.Wc(ntorch.cat([rnn_o, ec], "rnns")).tanh()
    else:
        out = []
        ect = NamedTensor(
            torch.zeros(N, self.r_emb_sz).to(emb.values.device),
            names=("batch", "rnns"),
        )
        for t in range(T):
            inp = ntorch.cat(
                [emb.get("time", t), ect.rename("rnns", "x")], "x"
            ).repeat("time", 1)
            rnn_o, s = self.rnn(inp, s)
            rnn_o = rnn_o.get("time", 0)
            _, eat, ect = attn(rnn_o, r, r_info.mask)
            out.append(ntorch.cat([rnn_o, ect], "rnns"))
        out = self.Wc(ntorch.stack(out, "time")).tanh()
    # return unnormalized vocab
    return self.proj(self.drop(out)), s
def __init__(self, TEXT, LABEL, hidden_attn, hidden_aligned, hidden_final,
             intra_attn=False, hidden_intra_attn=200, dropout=0.5,
             device='cpu', freeze_emb=True):
    super().__init__()

    # record parameters
    self.device = device
    self.dropout = dropout

    # initialize embedding
    self.pretrained_emb = TEXT.vocab.vectors.to(device)
    self.embedding = (ntorch.nn.Embedding.from_pretrained(
        self.pretrained_emb.values,
        freeze=freeze_emb).spec('seqlen', 'embedding'))
    self.embedding.weight[1] = torch.zeros(300)
    self.embedding_projection = ntorch.nn.Linear(
        self.pretrained_emb.shape['embedding'], 200).spec('embedding', 'embedding')
    emb_dim = 200

    # initialize intra attn
    self.intra_attn = intra_attn
    if self.intra_attn:
        distance_bias = (torch.distributions.normal.Normal(
            0, 0.01).sample(sample_shape=torch.Size([11])))
        self.distance_bias = NamedTensor(distance_bias, names='bias')
        self.register_parameter("distance_bias", self.distance_bias)
        self.feedforward_intra_attn = MLP(emb_dim, emb_dim, 'embedding',
                                          'hidden', self.dropout)
        emb_dim = 2 * emb_dim

    # initialize feedforward modules
    self.feedforward_attn = MLP(emb_dim, hidden_attn, 'embedding', 'hidden',
                                self.dropout)
    self.feedforward_aligned = MLP(2 * emb_dim, hidden_aligned, 'embedding',
                                   'hidden', self.dropout)
    self.feedforward_agg = MLP(2 * hidden_aligned, hidden_final, 'hidden',
                               'final', self.dropout)
    self.final_linear = ntorch.nn.Linear(
        hidden_final, len(LABEL.vocab)).spec('final', 'logit')
    self.to(device)
def get_distance_bias_matrix(self, dim, name1, name2):
    if dim > 10:
        vec = torch.zeros(dim)
        vec[:10] = self.distance_bias.values[:10]
        vec[10:] = torch.zeros(vec[11:].shape[0] + 1).fill_(
            self.distance_bias.values[10])
    else:
        vec = self.distance_bias.values[:dim]
    distance_bias_matrix = torch.zeros(dim, dim)
    for row in range(dim):
        distance_bias_matrix[row, row:] = vec[0:dim - row]
    distance_bias_matrix = distance_bias_matrix + distance_bias_matrix.transpose(0, 1)
    return NamedTensor(distance_bias_matrix.to(self.device),
                       names=(name1, name2))
def predict(self, text, predict_last=False):
    """Make prediction on named tensor with dimensions 'batch' and 'seqlen'"""
    batch = text.transpose("batch", "seqlen").values.numpy()
    batch_size, text_len = batch.shape[0], batch.shape[1]
    predictions = np.zeros([batch_size, text_len, self.vocab_size])
    for batch_id, text in enumerate(batch):
        for word_id, word in enumerate(text):
            if predict_last and word_id != (len(text) - 1):
                continue
            minus1 = word
            if (word_id - 1) >= 0:
                minus2 = text[word_id - 1]
            else:
                minus2 = None
            predictions[batch_id, word_id] = self._get_pred_dist(minus1, minus2)
    return NamedTensor(torch.from_numpy(predictions),
                       names=("batch", "seqlen", "distribution"))
def load(device='cpu', pretrained_embedding='glove.6B.300d', embedding_dim=300,
         embedding_num=100, batch_size=16):
    # Our input $x$
    TEXT = NamedField(names=('seqlen',))

    # Our labels $y$
    LABEL = NamedField(sequential=False, names=())

    # create train val test split
    train, val, test = torchtext.datasets.SNLI.splits(TEXT, LABEL)

    # build vocabs
    TEXT.build_vocab(train)
    LABEL.build_vocab(train)

    # create iters
    train_iter, val_iter = torchtext.data.BucketIterator.splits(
        (train, val), batch_size=batch_size, device=torch.device(device),
        repeat=False)
    test_iter = torchtext.data.BucketIterator(test, train=False, batch_size=10,
                                              device=torch.device(device))

    # Build the vocabulary with word embeddings
    # Out-of-vocabulary (OOV) words are hashed to one of 100 random embeddings,
    # each initialized to mean 0 and standard deviation 1 (Sec 5.1)
    unk_vectors = [torch.randn(embedding_dim) for _ in range(embedding_num)]
    TEXT.vocab.load_vectors(vectors=pretrained_embedding,
                            unk_init=lambda x: random.choice(unk_vectors))

    # normalized to have l_2 norm of 1
    vectors = TEXT.vocab.vectors
    vectors = vectors / vectors.norm(dim=1, keepdim=True)
    vectors = NamedTensor(vectors, ('word', 'embedding'))
    TEXT.vocab.vectors = vectors

    return train_iter, val_iter, test_iter, TEXT, LABEL
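# Hedged usage sketch for `load` above. The premise/hypothesis/label batch
# attributes come from the SNLI fields (the same attributes are used in
# `visualize_attn` below); with these NamedField defaults each sequence tensor
# is named ("seqlen", "batch"). Argument values here are illustrative only.
train_iter, val_iter, test_iter, TEXT, LABEL = load(device='cpu', batch_size=16)
batch = next(iter(train_iter))
premise, hypothesis, label = batch.premise, batch.hypothesis, batch.label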
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        data = NamedTensor(data, ("batch", "ch", "height", "width"))
        optimizer.zero_grad()
        recon_batch, normal = model(data)
        loss = loss_function(recon_batch, data, normal)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch,
                batch_idx * len(data),
                len(train_loader.dataset),
                100.0 * batch_idx / len(train_loader),
                loss.item() / len(data),
            ))
    print("====> Epoch: {} Average loss: {:.4f}".format(
        epoch, train_loss / len(train_loader.dataset)))
def forward(self, text):
    scores = []

    def _logscore(tokens, c):
        class_count = self.class_counts[c]
        if class_count > 0:
            logscore = 0
            for token in tokens:
                if token == self.padding_idx:
                    continue
                # log P(token | class)
                logscore += math.log(self.token_class_counts[(token, c)]) \
                    - math.log(self.class_total_counts[c])
            # log P(class)
            logscore += math.log(class_count)
        else:
            logscore = -float('inf')
        return logscore

    logscores = []
    for sent in text.unbind('batch'):
        logscores.append([_logscore(sent.tolist(), c)
                          for c in range(self.num_classes)])
    return NamedTensor(torch.Tensor(logscores), ('batch', 'classes'))
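# Hedged sketch: turning the ("batch", "classes") log-scores returned above
# into hard class predictions, mirroring the max-over-"classes" pattern used
# in `evaluate` earlier in these snippets. `model` and `text` are assumed.
log_scores = model(text)                    # NamedTensor ("batch", "classes")
predictions = log_scores.max("classes")[1]  # index of the best class per example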
def check_sa(
    x, y, r, ue, ut, v2d, vt2d, d,
    sa_py, sa_pe, sa_pt, sa_pc,
):
    def top5(t, query, p):
        if query == "e":
            p5, e5 = p.get("time", t).topk("e", 5)
            print(" | ".join(
                f"{qca_model.Ve.itos[ue.get('els', x).item()]} ({p:.2f})"
                for p, x in zip(p5.tolist(), e5.tolist())
            ))
        elif query == "t":
            p5, t5 = p.get("time", t).topk("t", 5)
            print(" | ".join(
                f"{qca_model.Vt.itos[ut.get('els', x).item()]} ({p:.2f})"
                for p, x in zip(p5.tolist(), t5.tolist())
            ))
        else:
            raise NotImplementedError

    # check if alphabetical reps are close to numerical
    words = ["2", "4", "6", "8"]
    for word in words:
        input = (
            sa_model.lutx.weight[sa_model.Vx.stoi[word]]
            if not sa_model.v2d
            else sa_model.lutv.weight[sa_model.Vv.stoi[word]]
        )
        if sa_model.mlp:
            probs, idx = (sa_model.lutx.weight @ sa_model.Wvy1(
                sa_model.Wvy0(NamedTensor(input, names=("ctxt",))).tanh()
            ).tanh().values).softmax(0).topk(5)
        else:
            probs, idx = (sa_model.lutx.weight @ input).softmax(0).topk(5)
        print(f"{word} probs " + " || ".join(
            f"{sa_model.Vx.itos[x]}: {p:.2f}"
            for p, x in zip(probs.tolist(), idx.tolist())))

    # check if alphabetical and numerical words are aligned correctly and high prob under model
    ytext = [TEXT.vocab.itos[y] for y in y.tolist()]
    t = 185
    print(ytext[t-30:t+10])
    print(ytext[t])
    print()
    print("blake griffin assists alphabetical")
    print(f"py: {sa_py.get('time', t).item()}")
    print(f"pc: {sa_pc.get('time', t).get('copy', 1).item()}")
    top5(t, "e", sa_pe)
    top5(t, "t", sa_pt)

    t = 182
    print()
    print("blake griffin rebounds numerical")
    print(f"py: {sa_py.get('time', t).item()}")
    print(f"pc: {sa_pc.get('time', t).get('copy', 1).item()}")
    top5(t, "e", sa_pe)
    top5(t, "t", sa_pt)

    # check top pc
    print()
    print("Checking top copy prob words")
    probs, time = sa_pc.get("copy", 1).topk("time", 10)
    print(" || ".join(f"{ytext[t]}: {p:.2f}"
                      for p, t in zip(probs.tolist(), time.tolist())))

    print()
    print("Checking for garbage alignments => the")
    [(top5(i, "e", sa_pe), top5(i, "t", sa_pt))
     for i, x in enumerate(ytext[:50]) if x == "the"]

    import pdb; pdb.set_trace()
print('LABEL.vocab', LABEL.vocab)

train_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits(
    (train, val, test), batch_size=16, device=torch.device(device),
    repeat=False)

import random
unk_vectors = [torch.randn(300) for _ in range(100)]
TEXT.vocab.load_vectors(vectors='glove.6B.300d',
                        unk_init=lambda x: random.choice(unk_vectors))
vectors = TEXT.vocab.vectors
vectors = vectors / vectors.norm(dim=1, keepdim=True)
ntorch_vectors = NamedTensor(vectors, ('word', 'embedding'))
TEXT.vocab.vectors = ntorch_vectors


def visualize_attn(model):
    batch = next(iter(train_iter))
    a, b, y = batch.premise, batch.hypothesis, batch.label
    model.showAttention(a, b, TEXT)
    print("did it")


def test_code(model, name="predictions.txt"):
    "All models should be able to be run with following command."
    upload = []
    # Update: for kaggle the bucket iterator needs to have batch_size 10
def forward_sup(
    self, text, text_info, x, states, x_info, r, r_info, vt,
    ue, ue_info, ut, ut_info, v2d, y, y_info,
    T=None, E=None, R=None, learn=False,
):
    e, t, v = r
    N = x.shape["batch"]

    # posterior
    nll, log_pv_y, attn = self.rnnvie(text, text_info, e, t, v=v,
                                      r_info=r_info, learn=learn)
    pv_y = log_pv_y.exp()
    log_pv = self.log_pv_x(e, t)
    pv_y = log_pv_y.exp()
    soft_v = pv_y.dot(
        "v", NamedTensor(self.crnnlm.lutv.weight, names=("v", "r")))

    nll_pv = -log_pv.gather("v", v.repeat("i", 1), "i").get(
        "i", 0)[r_info.mask].sum()
    nll_qv_y = -log_pv_y.gather("v", v.repeat("i", 1), "i").get(
        "i", 0)[r_info.mask].sum()
    v_total = r_info.mask.sum()

    e = self.crnnlm.lute(e).rename("e", "r")
    t = self.crnnlm.lutt(t).rename("t", "r")
    v = self.crnnlm.lutv(v).rename("v", "r")
    sup_r = [e.repeat("k", 1), t.repeat("k", 1), v.repeat("k", 1)]

    log_py_v, s = self.crnnlm(x, None, x_info, sup_r, r_info, ue, ue_info,
                              ut, ut_info, v2d)
    log_py_v = log_py_v.log_softmax("vocab")

    y_mask = y.ne(1)
    nwt = y_mask.sum()
    nll_py_v = -log_py_v.gather("vocab", y.repeat("y", 1), "y").get(
        "y", 0)[y_mask].sum()

    if learn:
        # nll_qv_y / v_total is in self.rnnvie
        #((nll_pv + nll_qv_y) / v_total + nll_py_v / nwt).backward()
        (nll_pv / v_total + nll_py_v / nwt).backward()

    rvinfo = RvInfo(
        log_py_v=log_py_v,
        log_pv_y=log_pv_y,
        log_pv=log_pv,
    )
    return rvinfo, s, nll_pv, nll_py_v, nll_qv_y, v_total, nwt
def forward(
    self, text, text_info, x, states, x_info, r, r_info, vt,
    ue, ue_info, ut, ut_info, v2d, y, y_info,
    T=None, E=None, R=None, learn=False,
):
    e, t, v = r

    # posterior
    nll, log_pv_y, attn = self.rnnvie(text, text_info, e, t)
    pv_y = log_pv_y.exp()
    log_pv = self.log_pv_x(e, t)
    pv_y = log_pv_y.exp()
    soft_v = pv_y.dot(
        "v", NamedTensor(self.crnnlm.lutv.weight, names=("v", "r")))

    e = self.crnnlm.lute(e).rename("e", "r")
    t = self.crnnlm.lutt(t).rename("t", "r")
    #v = self.crnnlm.lutv(v).rename("v", "r")
    v_total = r_info.mask.sum()

    # sample from log_pr_x
    v_s, v_s_log_q = self.sample_v(pv_y, log_pv_y, self.K, "v", "k")
    hard_v = self.crnnlm.lutv(v_s).rename("v", "r")

    soft_r = [e.repeat("k", 1), t.repeat("k", 1), soft_v.repeat("k", 1)]
    hard_r = [e.repeat("k", self.K), t.repeat("k", self.K), hard_v]

    log_py_Ev, s = self.crnnlm(x, None, x_info, soft_r, r_info, ue, ue_info,
                               ut, ut_info, v2d)
    log_py_v, s = self.crnnlm(x, None, x_info, hard_r, r_info, ue, ue_info,
                              ut, ut_info, v2d)
    log_py_Ev = log_py_Ev.log_softmax("vocab")
    log_py_v = log_py_v.log_softmax("vocab")

    kl = pv_y * (log_pv_y - log_pv)
    kl_sum = kl[r_info.mask].sum()

    y_mask = y.ne(1)
    nwt = y_mask.sum()

    ll_soft = log_py_Ev.gather("vocab", y.repeat("y", 1), "y").get("y", 0)
    ll_hard = log_py_v.gather("vocab", y.repeat("y", 1), "y").get("y", 0)
    nll_sum = -ll_hard.mean("k")[y_mask].sum()
    reward = (ll_hard.detach() - ll_soft.detach()) * v_s_log_q
    reward_sum = -reward.mean("k")[y_mask].sum()
    #import pdb; pdb.set_trace()

    if learn:
        (nll_sum + kl_sum + reward_sum).div(nwt).backward()
    if kl_sum.item() < 0:
        import pdb
        pdb.set_trace()

    rvinfo = RvInfo(
        log_py_v=log_py_v,
        log_py_Ev=log_py_Ev,
        log_pv_y=log_pv_y,
        log_pv=log_pv,
    )
    return (
        rvinfo,
        s,
        nll_sum.detach(),
        kl_sum.detach(),
        kl_sum.clone().fill_(0),
        v_total,
        nwt,
    )
def __init__(self, lstm, unigram):
    self.lstm = lstm.cuda()
    self.unigram = NamedTensor(torch.log(unigram + 1e-6).cuda(),
                               names='classes')
def _ie_loop(
    self,
    iter,
    optimizer=None,
    clip=0,
    learn=False,
    re=None,
    T=None,
    E=None,
    R=None,
):
    self.train(learn)
    context = torch.enable_grad if learn else torch.no_grad
    if not self.v2d:
        self.copy_x_to_v()

    cum_loss = 0
    cum_ntokens = 0
    cum_rx = 0
    cum_kl = 0
    batch_loss = 0
    batch_ntokens = 0
    states = None
    cum_TP, cum_FP, cum_FN, cum_TM, cum_TG = 0, 0, 0, 0, 0
    batch_TP, batch_FP, batch_FN, batch_TM, batch_TG = 0, 0, 0, 0, 0
    # for debugging
    cum_e_correct = 0
    cum_t_correct = 0
    cum_correct = 0

    with context():
        titer = tqdm(iter) if learn else iter
        for i, batch in enumerate(titer):
            if learn:
                optimizer.zero_grad()
            text, x_info = batch.text
            mask = x_info.mask
            lens = x_info.lengths
            L = text.shape["time"]
            x = text.narrow("time", 0, L - 1)
            y = text.narrow("time", 1, L - 1)
            #x = text[:-1]
            #y = text[1:]
            x_info.lengths.sub_(1)

            e, e_info = batch.entities
            t, t_info = batch.types
            v, v_info = batch.values
            lene = e_info.lengths
            lent = t_info.lengths
            lenv = v_info.lengths
            #rlen, N = e.shape
            #r = torch.stack([e, t, v], dim=-1)
            r = [e, t, v]
            assert (lene == lent).all()
            lenr = lene
            r_info = e_info

            vt, vt_info = batch.values_text
            ue, ue_info = batch.uentities
            ut, ut_info = batch.utypes
            v2d = batch.v2d
            vt2d = batch.vt2d

            # should i include <eos> in ppl?
            nwords = y.ne(1).sum()
            # assert nwords == lens.sum()
            T = y.shape["time"]
            N = y.shape["batch"]

            #if states is None:
            states = self.init_state(N)
            logits, _ = self(x, states, x_info, r, r_info, vt, ue, ue_info,
                             ut, ut_info, v2d, vt2d)
            nll = self.loss(logits, y)

            # only for crnnlma though
            """
            cor, ecor, tcor, tot = Ie.pat1(
                self.ea.rename("els", "e"),
                self.ta.rename("els", "t"),
                batch.ie_d,
            )
            """
            #ds = batch.ie_d
            etvs = batch.ie_etv
            #num_cells = batch.num_cells
            num_cells = float(sum(len(d) for d in etvs))

            log_pe = self.ea.cpu()
            log_pt = self.ta.cpu()
            log_pc = self.log_pc.cpu()
            ue = ue.cpu()
            ut = ut.cpu()
            # need log_pv, need to check if self.v2d
            # batch x time x hid
            h = self.lutx(y).values
            log_pv = torch.einsum(
                "nth,vh->ntv", [h, self.lutv.weight.data]).log_softmax(-1)
            log_pv = NamedTensor(log_pv, names=("batch", "time", "v"))

            tp, fp, fn, tm, tg = pr(
                etvs,
                ue,
                ut,
                log_pe,
                log_pt,
                log_pv,
                log_pc=log_pc,
                lens=lens,
                Ve=self.Ve,
                Vt=self.Vt,
                Vv=self.Vv,
                Vx=self.Vx,
                text=text,
            )
            cum_TP += tp
            cum_FP += fp
            cum_FN += fn
            cum_TM += tm
            cum_TG += tg
            batch_TP += tp
            batch_FP += fp
            batch_FN += fn
            batch_TM += tm
            batch_TG += tg

            # For debugging
            ds = batch.ie_et_d
            for batch, d in enumerate(ds):
                for t, (es, ts) in d.items():
                    #t = t + 1
                    _, e_max = log_pe.get("batch", batch).get("time", t).max("e")
                    _, t_max = log_pt.get("batch", batch).get("time", t).max("t")
                    e_preds = ue.get("batch", batch).get("els", e_max.item())
                    t_preds = ut.get("batch", batch).get("els", t_max.item())
                    correct = (es.eq(e_preds) * ts.eq(t_preds)).any().float().item()
                    e_correct = es.eq(e_preds).any().float().item()
                    t_correct = ts.eq(t_preds).any().float().item()
                    cum_correct += correct
                    cum_e_correct += e_correct
                    cum_t_correct += t_correct

            if learn:
                if clip > 0:
                    gnorm = clip_(self.parameters(), clip)
                    #for param in self.rnn_parameters():
                    #gnorm = clip_(param, clip)
                optimizer.step()

            cum_loss += nll.item()
            cum_ntokens += num_cells
            batch_loss += nll.item()
            batch_ntokens += num_cells
            if re is not None and i % re == -1 % re:
                titer.set_postfix(
                    loss=batch_loss / batch_ntokens,
                    gnorm=gnorm,
                    p=batch_TP / (batch_TP + batch_FP),
                    r=batch_TP / (batch_TP + batch_FN),
                )
                batch_loss = 0
                batch_ntokens = 0
                batch_TP, batch_FP, batch_FN, batch_TM, batch_TG = 0, 0, 0, 0, 0

    print(
        f"DBG acc: {cum_correct / cum_ntokens} | E acc: {cum_e_correct / cum_ntokens} | T acc: {cum_t_correct / cum_ntokens}"
    )
    print(
        f"p: {cum_TP / (cum_TP + cum_FP)} || r: {cum_TP / (cum_TP + cum_FN)}"
    )
    print(f"total supervised cells: {cum_ntokens}")
    return cum_loss, cum_ntokens