def unify(self, toprove, uni_toprove, candidates, uni_candidates, embedded_candidates):
    """Attend from every symbol to prove over every candidate symbol.

    Shapes, as annotated by the original author:
        toprove:             (R, Ps, P)
        uni_toprove:         (R, Ps, P, E)
        candidates:          (B, Cs, C)
        uni_candidates:      (B, Cs, C, E)
        embedded_candidates: (B, Cs, C, E)

    Returns:
        tuple: ``(unifications, sim_weights)`` with shapes
        (B, R, Ps, P, Cs, E) and (B, R, Ps, P, Cs, C).
    """
    # Zero entries are masked out (presumably padding ids -- confirm).
    # Candidates equal to 0 receive a MINUS_INF offset before the softmax.
    candidate_penalty = (candidates == 0).astype(np.float32) * MINUS_INF  # (B, Cs, C)
    # Symbols to prove equal to 0 have their attention rows zeroed afterwards.
    keep_toprove = (toprove != 0)  # (R, Ps, P)

    # Dot product over E between each symbol to prove and each candidate symbol.
    # (R, Ps, P, E) x (B, Cs, C, E) -> (B, R, Ps, P, Cs, C)
    scores = F.einsum("rpse,bcde->brpscd", uni_toprove, uni_candidates)
    scores += candidate_penalty[:, None, None, None]

    # Normalise over the last axis (symbols inside one candidate).
    sim_weights = F.softmax(scores, -1)  # (B, R, Ps, P, Cs, C)
    sim_weights *= keep_toprove[..., None, None]

    # Weighted sum of candidate embeddings.
    # (B, R, Ps, P, Cs, C) x (B, Cs, C, E) -> (B, R, Ps, P, Cs, E)
    unifications = F.einsum("brpscd,bcde->brpsce", sim_weights, embedded_candidates)
    return unifications, sim_weights
def _predict_d2y(self, xs, dxs, d2xs, differentiate_more):
    """Compute the 2nd-order prediction summed over every `SubNNP`.

    Args:
        xs (list [~chainer.Variable]):
            Input data for each `SubNNP` constituting this HDNNP
            instance. The shape of data is
            ``n_atom x (n_sample, n_input)``.
        dxs (list [~chainer.Variable]):
            Differentiated input data. The shape of data is
            ``n_atom x (n_sample, n_input, n_deriv)``.
        d2xs (list [~chainer.Variable]):
            Double differentiated input data. The shape of data is
            ``n_atom x (n_sample, n_input, n_deriv, n_deriv)``.
        differentiate_more (bool):
            If True, a deeper calculation graph is created for
            back-propagation or higher-order differentiation.

    Returns:
        ~chainer.Variable:
            Double differentiated output data. The shape of data is
            ``(n_sample, n_output, n_deriv, n_deriv)``.
    """
    # Populate ``results['dy']`` / ``results['d2y']`` on every sub-network.
    for subnet, inputs in zip(self, xs):
        subnet.second_differentiate(inputs, differentiate_more)

    # Chain rule for the second derivative, accumulated over sub-networks.
    total = 0
    for subnet, dx, d2x in zip(self, dxs, d2xs):
        curvature_term = F.einsum(
            'soij,six,sjy->soxy', subnet.results['d2y'], dx, dx)
        slope_term = F.einsum(
            'soi,sixy->soxy', subnet.results['dy'], d2x)
        total = total + (curvature_term + slope_term)
    return total
def ortho_gradients(self, ortho_weighting, layer_index):
    """Return the scaled gradient of an orthogonality penalty for one layer.

    The penalty is the sum of squared off-diagonal entries of ``W W^T``,
    where ``W`` is the ``weight_matrix`` of ``self.layers[layer_index]``.

    Args:
        ortho_weighting: Factor the gradient is multiplied by.
        layer_index: Index into ``self.layers``.

    Returns:
        ``ortho_weighting`` times the gradient array of the penalty with
        respect to the layer's weight matrix.
    """
    layer = self.layers[layer_index]
    weights = Variable(layer.weight_matrix)

    # Gram matrix of the rows of W.
    gram = functions.einsum('ik, jk -> ij', weights, weights)
    # Keep only the diagonal so that (gram - diagonal) holds the off-diagonal part.
    diagonal = gram * xp.eye(layer.weight_matrix.shape[0])
    residual = gram - diagonal
    ortho_loss = functions.sum(residual ** 2)

    loss_grads = grad([ortho_loss], [weights])
    return ortho_weighting * loss_grads[0].array
def forward(self, x):
    """Run ``res`` -> optional dropout -> ``fc`` -> outer product -> ``conv``.

    Returns the slice ``[:, 0, 0, :]`` of the convolution output.
    """
    features = self.res(x)
    if self.dropout:
        features = F.dropout(features)
    features = self.fc(features)

    # Per-sample outer product of the feature vector with itself: (N, D) -> (N, D, D)
    pairwise = F.einsum('ij, ik->ijk', features, features)

    # Insert a singleton channel axis before handing the map to the convolution.
    conv_out = self.conv(pairwise[:, None, :, :])
    return conv_out[:, 0, 0, :]
def calculate_logit(self, x, t=None, n_batch_axes=1):
    """Compute logits and, for interval inputs, a worst-case margin bound.

    The linear layer ``self.lo`` is always applied to ``x``. When ``x``
    additionally carries ``lower`` and ``upper`` attributes, the returned
    variable gets a ``worst`` attribute of shape
    ``(batchsize, n_class - 1)``: for every wrong class, an upper bound of
    ``logit_wrong - logit_correct`` over the box ``[lower, upper]``
    (centre term plus radius times ``|W_wrong - W_correct|``).

    Args:
        x: Input variable, optionally with ``lower`` / ``upper`` bounds.
        t: Correct class ids, shape ``(batchsize,)``. Required when ``x``
            carries bounds. May be a ``chainer.Variable`` or an array.
        n_batch_axes: Only ``1`` is supported.

    Returns:
        The logits ``self.lo(x)``, with ``worst`` set when bounds exist.

    Raises:
        NotImplementedError: If ``n_batch_axes != 1``.
        TypeError: If ``x`` carries bounds but ``t`` is ``None``.
    """
    if n_batch_axes != 1:
        raise NotImplementedError
    if self.lo.W.array is None:
        # Lazily initialise the weights from the flattened input size.
        in_size = chainer.utils.size_of_shape(x.shape[n_batch_axes:])
        self.lo._initialize_params(in_size)

    # Standard call
    y = self.lo(x)
    if not (hasattr(x, 'lower') and hasattr(x, 'upper')):
        return y

    # Call with bounds
    if t is None:
        raise TypeError(
            'class labels `t` are required when `x` carries bounds')
    if isinstance(t, chainer.Variable):
        t = t.array

    w = self.lo.W
    b = self.lo.b  # None when the linear layer was built without a bias
    batchsize = x.shape[0]
    # Take the class count from W so that a bias-free layer also works.
    n_class = w.shape[0]

    # For every sample, the ids of all classes except the correct one.
    _ar2d = self.xp.tile(self.xp.arange(n_class), (batchsize, 1))
    wrong_ids = _ar2d[_ar2d != t[:, None]].reshape(
        (batchsize, n_class - 1))

    # Weight difference (wrong - correct): (batchsize, n_class - 1, dim)
    w_diff = w[wrong_ids] - w[t][:, None, :]
    w_diff = F.transpose(w_diff, (0, 2, 1))  # (batchsize, dim, n_class - 1)

    lower, upper = x.lower, x.upper
    c = (lower + upper) / 2.  # box centre, (batchsize, dim)
    r = (upper - lower) / 2.  # box radius, (batchsize, dim)

    c = F.einsum('ij,ijk->ik', c, w_diff)  # (batchsize, n_class - 1)
    if b is not None:
        # Bias difference (wrong - correct): (batchsize, n_class - 1)
        c += b[wrong_ids] - b[t][:, None]
    r = F.einsum('ij,ijk->ik', r, abs(w_diff))

    y.worst = c + r
    return y
def forward(self, stories):
    """Compute the forward inference pass for given stories.

    ``stories`` unpacks into ``(vctx, vq, va, supps)`` with the shapes
    (B, R, P, C), (B, Q), (B,), (B, I) as annotated by the original
    author. The number of iterations equals ``supps.shape[-1]``.

    Returns:
        ``self.out_linear(state)[:, 0]``, annotated as shape (B,).
    """
    self.log = dict()
    # ---------------------------
    vctx, vq, _va, supps = stories  # (B, R, P, C), (B, Q), (B,), (B, I)
    num_rules = vctx.shape[1]
    num_steps = supps.shape[-1]

    # Embed stories
    ectx = self.embed(vctx)  # (B, R, P, C, V)
    eq = self.embed(vq)  # (B, Q, V)
    # ---------------------------
    # Embed predicates
    embedded_preds = seq_rnn_embed(
        vctx, ectx, self.pred_rnn, reverse=True)  # (B, R, P, E)
    # First character of every predicate, used to check whether it is empty
    vector_preds = vctx[..., 0]  # (B, R, P)
    embedded_query = seq_rnn_embed(
        vq, eq, self.pred_rnn, reverse=True)  # (B, E)
    rule_heads = embedded_preds[:, :, 0]  # (B, R, E) head of rule
    # ---------------------------
    # Perform iterative updates
    state = embedded_query  # (B, E)
    tiled_query = F.repeat(embedded_query[:, None], num_rules, 1)  # (B, R, E)
    # Rules whose entries are all zero get a MINUS_INF attention offset
    empty_rule = np.all(vctx == 0, (2, 3))  # (B, R)

    for _ in range(num_steps):
        # Compute attention over memory
        tiled_state = F.repeat(state[:, None], num_rules, 1)  # (B, R, E)
        att_inputs = [
            tiled_state,
            rule_heads,
            tiled_query,
            F.squared_difference(tiled_state, rule_heads),
            rule_heads * tiled_state,
        ]
        combined = F.concat(att_inputs, -1)  # (B, R, 5*E)
        att = F.tanh(self.att_dense1(combined, n_batch_axes=2))  # (B, R, E//2)
        att = self.att_dense2(att, n_batch_axes=2)  # (B, R, 1)
        att = F.squeeze(att, -1)  # (B, R)
        att += empty_rule * MINUS_INF  # (B, R)
        self.tolog('raw_att', att)
        att = F.softmax(att)  # (B, R)
        self.tolog('att', att)

        # Candidate next state for every rule
        new_states = seq_rnn_embed(
            vector_preds, embedded_preds, self.unifier,
            initial_state=tiled_state)  # (B, R, E)

        # Attention-weighted sum: (B, R) x (B, R, E) -> (B, E)
        state = F.einsum('br,bre->be', att, new_states)

    logits = self.out_linear(state)
    return logits[:, 0]  # (B,)
def relative_logits_1d(self, q, rel_k, H, W, Nh, transpose_mask):
    """Compute relative-position logits along one spatial dimension.

    Args:
        q: Query tensor contracted as ``'bhxyd'``.
        rel_k: Relative key embeddings contracted as ``'md'``.
        H, W: Sizes used in the reshapes below.
        Nh: Number of heads used in the reshapes below.
        transpose_mask: Axis permutation applied before the final reshape.

    Returns:
        Logits reshaped to ``(-1, Nh, H * W, H * W)``.
    """
    logits = F.einsum('bhxyd,md->bhxym', q, rel_k)

    # Fold heads and rows together, then convert relative to absolute indexing.
    folded = logits.reshape((-1, Nh * H, W, 2 * W - 1))
    absolute = self.rel_to_abs(folded)
    absolute = absolute.reshape((-1, Nh, H, W, W))

    # Insert a new axis at position 3 and replicate it H times.
    expanded = F.expand_dims(absolute, axis=3)
    tiled = F.tile(expanded, (1, 1, 1, H, 1, 1))

    permuted = tiled.transpose(transpose_mask)
    return permuted.reshape((-1, Nh, H * W, H * W))
def update_state(self, oldstate, mem_att, vmemory, ememory, iteration=0):
    """Update state given old, attention and new possible states.

    Shapes, as annotated by the original author:
        oldstate: (..., E)
        mem_att:  (..., Ms)
        vmemory:  (..., Ms, M)
        ememory:  (..., Ms, E)

    ``iteration`` is accepted but not used by this implementation.

    Returns:
        The attention-weighted new state, shape (..., E).
    """
    num_slots = vmemory.shape[-2]
    # Copy the old state once per memory slot.
    tiled_state = F.repeat(oldstate[..., None, :], num_slots, -2)  # (..., Ms, E)

    interaction_feats = [
        tiled_state,
        ememory,
        tiled_state * ememory,
        F.squared_difference(tiled_state, ememory),
    ]
    merged = F.concat(interaction_feats, -1)  # (..., Ms, 4*E)

    # Every axis except the last is treated as a batch axis.
    batch_axes = len(merged.shape) - 1
    candidates = F.tanh(
        self.state_linear(merged, n_batch_axes=batch_axes))  # (..., Ms, E)

    # (..., Ms) x (..., Ms, E) -> (..., E)
    return F.einsum("...i,...ij->...j", mem_att, candidates)
def cosine_loss(tens1, tens2, absol=True):
    """
    Computes the cosine loss between two representations.
    The cos is computed per element, i.e. tens1[i] and tens2[i] are assumed
    to be the pair of representations whose cosine we want.

    Works only on chainer 5.x, because of the einsum.

    Args:
        tens1, tens2: Tensors that ``_tensor_to_matrix(..., axis=0)`` turns
            into matrices with one row per element.
        absol (bool): If True, the absolute value of each cosine is used.

    Returns:
        The mean over elements of the (optionally absolute) cosine.

    Note:
        A row with zero norm yields a 0/0 division for that element.
    """
    mat1 = _tensor_to_matrix(tens1, axis=0)
    mat2 = _tensor_to_matrix(tens2, axis=0)
    # # compute the per-row inner product.
    prod = F.einsum('ij,ij->i', mat1, mat2)
    # # compute the per-row *squared* L2 norms.
    sq_norm1 = F.batch_l2_norm_squared(mat1)
    sq_norm2 = F.batch_l2_norm_squared(mat2)
    # # compute the final cosine (per element).
    # The denominator is ||a_i|| * ||b_i|| = sqrt(||a_i||^2 * ||b_i||^2).
    # The previous F.matmul of the two squared-norm vectors gave a single
    # scalar for the whole batch, which is not a cosine.
    cos = prod / F.sqrt(sq_norm1 * sq_norm2)
    if absol:
        # # We restrict the angles to [-90, 90] effectively.
        # # That is, we allow only positive cos.
        cos = F.absolute(cos)
    return F.mean(cos)
def blend_featuremap(self, hs, blend):
    """Blend feature maps by contracting ``hs`` with ``blend`` over axis ``i``.

    ``hs`` is indexed as ``n i j k l`` and ``blend`` as ``n k l i``; the
    shared ``i`` axis is summed out, giving a result indexed ``n j k l``.
    """
    blended = F.einsum('nijkl,nkli->njkl', hs, blend)
    return blended
def __call__(self, x, sentence, att_mask=None, train=True):
    """Encode ``x``, attend over its features with ``sentence`` and classify.

    Args:
        x: Input passed to the ``dc1`` .. ``dc4`` stack.
        sentence: Token sequence; column ``i`` is fed to ``self.l1_`` at
            step ``i``.
        att_mask: Optional multiplier applied to the attention map when
            ``train`` is True. When ``None`` the attention map is used
            unmasked.
        train (bool): Sets chainer's ``train`` and ``enable_backprop``
            configuration for the whole call.

    Returns:
        tuple: ``(z, var, mean, encoded, features, attention_map, s2, c2)``
        where ``attention_map`` is reshaped to
        ``(-1, 1, att_size, att_size)``.
    """
    with chainer.using_config('train', train), \
            chainer.using_config('enable_backprop', train):
        xp = cuda.get_array_module(x.data)
        att_area = self.att_size * self.att_size
        # Use one integer channel count for both reshapes below; the
        # second reshape previously omitted the int() conversion.
        n_channels = int(8 * self.density)

        # Encoder stack
        h1 = F.leaky_relu(self.dc1(x))
        h2 = F.leaky_relu(self.norm2(self.dc2(h1)))
        h2_ = F.leaky_relu(self.norm2_(self.dc2_(h2)))
        h2__ = F.leaky_relu(self.norm2__(self.dc2__(h2_)))
        h3 = F.leaky_relu(self.norm3(self.dc3(h2__)))
        h3_ = F.leaky_relu(self.norm3_(self.dc3_(h3)))
        h3__ = F.leaky_relu(self.norm3__(self.dc3__(h3_)))
        h4 = F.leaky_relu(self.norm4(self.dc4(h3__)))

        # Latent sample: mean + exp(var) * standard normal noise
        mean = self.dc5_mean(h4)
        var = F.tanh(self.dc5_var(h4))
        rand = xp.random.normal(0, 1, var.data.shape).astype(np.float32)
        z = mean + F.exp(var) * Variable(rand)

        # Feature branch
        f0 = F.tanh(self.norm0(self.fc_video0(h4)))
        f1 = F.tanh(self.norm1(self.fc_video1(f0)))
        f3 = self.fc_video2(f1)

        # Sentence branch: feed the tokens one step at a time and keep the
        # output of the last step.
        self.l1_.reset_state()
        for i in range(sentence.shape[1]):
            encoded = self.l1_(sentence[:, i])
        s0 = F.tanh(self.norm_text0(self.fc_text0(encoded)))
        s1 = self.fc_text1(s0)
        # Copy the sentence vector to every position of the attention grid.
        s2 = F.expand_dims(s1, axis=2)
        s2 = F.repeat(s2, att_area, axis=2)
        s2 = F.reshape(s2, (-1, n_channels, self.att_size, self.att_size))

        # Attention map over the grid
        m3 = f3 + s2
        m3 = F.tanh(self.norm_mix(m3))
        m4 = F.reshape(self.fc_mix0(m3), (-1, att_area))
        m4 = F.softmax(F.relu(m4), axis=1)

        f3 = F.reshape(f3, (-1, n_channels, att_area))

        # Apply the mask only in training mode and only when one is given,
        # so that the default arguments no longer fail on `None * m4`.
        if train and att_mask is not None:
            attention = att_mask * m4
        else:
            attention = m4
        features = F.einsum('ijk,ik -> ij', f3, attention)

        # Classifier
        f0 = self.norm_cls0(F.leaky_relu(self.fc_cls0(features)))
        s2 = self.fc5(f0)
        c2 = self.fc6(f0)

        attention_map = F.reshape(
            m4, (-1, 1, self.att_size, self.att_size))
        return z, var, mean, encoded, features, attention_map, s2, c2
def forward(self, stories):
    """Compute the forward inference pass for given stories.

    ``stories`` unpacks into ``(vctx, vq, va, supps)`` with annotated
    shapes (B, Cs, C), (B, Q), (B, A), (B, I); ``self.vrules`` holds the
    matching rule tensors with leading axis R. The number of iterations
    equals ``supps.shape[-1]``. Intermediate values are recorded through
    ``self.tolog``.

    Returns:
        ``self.answer_linear`` applied to the final unified state,
        annotated as shape (B, V).
    """
    self.log = dict()
    # ---------------------------
    vctx, vq, _va, supps = stories  # (B, Cs, C), (B, Q), (B, A), (B, I)
    # Embed stories
    ectx = self.embed(vctx)  # (B, Cs, C, E)
    eq = self.embed(vq)  # (B, Q, E)
    # ---------------------------
    # Prepare rules and variable states
    rvctx, rvq, _rva, _rsupps = self.vrules  # (R, Ls, L), (R, Q), (R, A), (R, I)
    # (R, Ls, L, E), (R, Q, E), (R, A, E)
    erctx, erq, _era = [self.embed(part) for part in self.vrules[:-1]]
    # ---------------------------
    # Compute variable map
    vmap = self.compute_vmap()  # (R, V)
    self.tolog('vmap', vmap)
    # ---------------------------
    num_rules = rvq.shape[0]  # R
    rule_idx = np.arange(num_rules)  # (R,)
    # ---------------------------
    # Initial rule states and original (non-unified) states
    rs = self.mematt.init_state(rvq, erq)  # (R, E)
    orig_cs = self.mematt.init_state(vq, eq)  # (B, E)
    # ---------------------------
    # Unify query first assuming given query is ground
    uni_erq = self.unification_features(rvq, erq)  # (R, Q, E)
    uni_eq = self.unification_features(vq, eq)  # (B, Q', E)
    # A singleton sentence axis is added so that `unify` can be reused.
    qunis, q_uniatt = self.unify(
        rvq[:, None], uni_erq[:, None],
        vq[:, None], uni_eq[:, None],
        eq[:, None])  # (B, R, 1, Q, 1, E), (B, R, 1, Q, 1, Q')
    qunis = F.squeeze(qunis, (2, 4))  # (B, R, Q, E)
    q_uniatt = F.squeeze(q_uniatt, (2, 4))  # (B, R, Q, Q')
    self.tolog('q_uniatt', q_uniatt)
    # ---------------------------
    # Unified states: interpolate between unified and rule embeddings
    qvgates = vmap[rule_idx[:, None], rvq]  # (R, Q)
    query_gate = qvgates[..., None]
    qstate = query_gate * qunis + (1 - query_gate) * erq  # (B, R, Q, E)
    brvq = np.repeat(rvq[None, ...], qstate.shape[0], 0)  # (B, R, Q)
    uni_cs = self.mematt.init_state(brvq, qstate)  # (B, R, E)
    # ---------------------------
    # Compute rule attentions
    if num_rules > 1:
        cs_feats = self.rule_linear(orig_cs)  # (B, E)
        ratt = F.softmax(cs_feats @ rs.T, -1)  # (B, R)
        self.tolog('ratt', ratt)
    # ---------------------------
    # Prepare unified state
    if num_rules == 1:
        uni_cs = uni_cs[:, 0]  # (B, E)
    else:
        # (B, R) x (B, R, E) -> (B, E)
        uni_cs = F.einsum('br,bre->be', ratt, uni_cs)
    # ---------------------------
    # Compute loss from unifying the query
    uniloss = F.mean_squared_error(uni_cs, orig_cs)  # ()
    self.tolog('uniloss', uniloss)
    # ---------------------------
    # Unify body, every symbol to every symbol
    uni_erctx = self.unification_features(rvctx, erctx)  # (R, Ls, L, E)
    uni_ectx = self.unification_features(vctx, ectx)  # (B, Cs, C, E)
    # (B, R, Ls, L, Cs, E), (B, R, Ls, L, Cs, C)
    bunis, uni_att = self.unify(rvctx, uni_erctx, vctx, uni_ectx, ectx)
    self.tolog('uni_att', uni_att)
    # ---------------------------
    # Setup memory sequence embeddings
    mem_erctx = self.mematt.seq_embed(rvctx, erctx)  # (R, Ls, E)
    mem_ectx = self.mematt.seq_embed(vctx, ectx)  # (B, Cs, E)
    # ---------------------------
    # Attention masks, and rule variable gates
    bodyattmask = np.all(rvctx == 0, -1)  # (R, Ls)
    candattmask = np.all(vctx == 0, -1)  # (B, Cs)
    ctxvgates = vmap[rule_idx[:, None, None], rvctx, None]  # (R, Ls, L, 1)
    brvctx = np.repeat(rvctx[None, ...], vctx.shape[0], 0)  # (B, R, Ls, L)
    # ---------------------------
    # Compute iterative updates on variables
    for step in range(supps.shape[-1]):
        # Which body literal to prove, using the rule state
        raw_body_att = self.mematt(
            rs, rvctx, mem_erctx, bodyattmask, step)  # (R, Ls)
        self.tolog('raw_body_att', raw_body_att)
        body_att = F.softmax(raw_body_att, -1)  # (R, Ls)
        # Candidate attention from the unified state
        raw_uni_cands_att = self.mematt(
            uni_cs, vctx, mem_ectx, candattmask, step)  # (B, Cs)
        self.tolog('raw_uni_cands_att', raw_uni_cands_att)
        uni_cands_att = F.softmax(raw_uni_cands_att, -1)  # (B, Cs)
        # Candidate attention from the original state
        raw_orig_cands_att = self.mematt(
            orig_cs, vctx, mem_ectx, candattmask, step)  # (B, Cs)
        self.tolog('raw_orig_cands_att', raw_orig_cands_att)
        orig_cands_att = F.softmax(raw_orig_cands_att, -1)  # (B, Cs)
        # ---------------------------
        # Update states for the rule and original
        rs = self.mematt.update_state(
            rs, body_att, rvctx, mem_erctx, step)  # (R, E)
        orig_cs = self.mematt.update_state(
            orig_cs, orig_cands_att, vctx, mem_ectx, step)  # (B, E)
        # ---------------------------
        # Attended unification over candidates
        # (B, Cs) x (B, R, Ls, L, Cs, E) -> (B, R, Ls, L, E)
        unis = F.einsum('bc,brlsce->brlse', uni_cands_att, bunis)
        # ---------------------------
        # Update candidate states with new variable bindings
        bstate = ctxvgates * unis + (1 - ctxvgates) * erctx  # (B, R, Ls, L, E)
        mem_bstate = self.mematt.seq_embed(brvctx, bstate)  # (B, R, Ls, E)
        body_att = F.broadcast_to(body_att, bstate.shape[:3])  # (B, R, Ls)
        uni_cs = F.repeat(uni_cs[:, None], rvq.shape[0], 1)  # (B, R, E)
        uni_cs = self.mematt.update_state(
            uni_cs, body_att, brvctx, mem_bstate, step)  # (B, R, E)
        # ---------------------------
        # Apply rule attention
        if num_rules == 1:
            uni_cs = uni_cs[:, 0]  # (B, E)
        else:
            # (B, R) x (B, R, E) -> (B, E)
            uni_cs = F.einsum('br,bre->be', ratt, uni_cs)
        # ---
        # Compute unification loss after this iteration
        uniloss = F.mean_squared_error(uni_cs, orig_cs)  # ()
        self.tolog('uniloss', uniloss)
    # ---------------------------
    # Compute answers based on variable and rule scores
    prediction = self.answer_linear(uni_cs)  # (B, V)
    # Auxiliary answers, only logged
    rpred = self.answer_linear(rs)  # (R, V)
    self.tolog('rpred', rpred)
    opred = self.answer_linear(orig_cs)  # (B, V)
    self.tolog('opred', opred)
    return prediction
def forward(self, ground_examples: np.ndarray):
    """Compute the forward inference pass for given stories.

    ``ground_examples`` is annotated as (B, 1+W*H+1): column 0 is read as
    a one-based task id and columns ``1:-1`` as the grid input. The same
    slicing is applied to ``self.inv_examples``.

    Returns:
        Predictions summed over the invariant axis, annotated as (B, V).
    """
    self.log = dict()
    # ---------------------------
    # Invariant ground prediction, then ground example prediction
    self.compute_ground_loss(self.inv_examples, log_prefix='ig')
    self.compute_ground_loss(ground_examples, log_prefix='o')
    # ---------------------------
    # Unification case
    task_idx = ground_examples[:, 0] - 1  # zero-based task index, (B,)
    ground_inputs = ground_examples[:, 1:-1]  # (B, W*H)
    invs_inputs = self.inv_examples[..., 1:-1]  # (T, I, W*H)
    grid_shape = tuple(GRID)
    # Embed ground examples and invariants
    eg = self.embed(ground_inputs)  # (B, W*H, E)
    ei = self.embed(invs_inputs)  # (T, I, W*H, E)
    # One-hot task ids, repeated for every grid cell
    tids = F.embed_id(task_idx, np.eye(TASKS, dtype=np.float32))  # (B, T)
    tids = F.repeat(tids[:, None], eg.shape[1], 1)  # (B, W*H, T)
    itids = np.eye(TASKS, dtype=np.float32)  # (T, T)
    itids = F.tile(
        itids[:, None, None, :],
        (1, invs_inputs.shape[1], invs_inputs.shape[2], 1))  # (T, I, W*H, T)
    # Append the task one-hot to each embedding and lay the cells out on the grid
    egt = F.concat((eg, tids), -1)  # (B, W*H, E+T)
    eit = F.concat((ei, itids), -1)  # (T, I, W*H, E+T)
    egt = F.reshape(egt, egt.shape[:1] + grid_shape + egt.shape[-1:])  # (B, W, H, E+T)
    eit = F.reshape(eit, (-1, ) + grid_shape + eit.shape[-1:])  # (T*I, W, H, E+T)
    egt = F.swapaxes(egt, -1, -3)  # (B, E+T, W, H)
    eit = F.swapaxes(eit, -1, -3)  # (T*I, E+T, W, H)
    # Extract unification features with the two convolutions
    gfeats = F.relu(self.uni_conv1(egt))  # (B, E, W, H)
    ifeats = F.relu(self.uni_conv1(eit))  # (T*I, E, W, H)
    gfeats = self.uni_conv2(gfeats)  # (B, E, W, H)
    ifeats = self.uni_conv2(ifeats)  # (T*I, E, W, H)
    gfeats = F.reshape(gfeats, gfeats.shape[:2] + (-1, ))  # (B, E, W*H)
    ifeats = F.reshape(
        ifeats, ei.shape[:2] + ifeats.shape[1:2] + (-1, ))  # (T, I, E, W*H)
    batch_ifeats = ifeats[task_idx]  # (B, I, E, W*H)
    # (B, I, E, W*H) x (B, E, W*H) -> (B, I, W*H, W*H)
    uni_att = F.einsum("ijek,iel->ijkl", batch_ifeats, gfeats)
    # Ground cells equal to 0 cannot be attended to (padding)
    pad_penalty = -100 * (ground_inputs == 0)  # (B, W*H)
    uni_att += pad_penalty[:, None, None]  # (B, I, W*H, W*H)
    uni_att = F.softmax(uni_att, axis=-1)  # (B, I, W*H, W*H)
    self.tolog('uniatt', uni_att)
    # (B, I, W*H, W*H) x (B, W*H, E) -> (B, I, W*H, E)
    eu = F.einsum("ijkl,ile->ijke", uni_att, eg)
    # Compute variable map
    vmap = F.sigmoid(self.vmap_params * 10)  # (T, I, V)
    not_padding = np.ones(VOCAB)  # (V,)
    not_padding[0] = 0  # padding symbol cannot be variable
    vmap *= not_padding  # (T, I, V)
    self.tolog('vmap', vmap)
    # Look up the variable gate of the symbol at every invariant cell
    task_axis = np.arange(vmap.shape[0])[:, None, None]
    inv_axis = np.arange(vmap.shape[1])[None, :, None]
    vmap = vmap[task_axis, inv_axis, invs_inputs]  # (T, I, W*H)
    vmap = vmap[task_idx]  # (B, I, W*H)
    batch_ei = ei[task_idx]  # (B, I, W*H, E)
    # Interpolate between the unified and the invariant embedding
    uni_embed = (vmap[..., None] * eu
                 + (1 - vmap)[..., None] * batch_ei)  # (B, I, W*H, E)
    # Make the prediction on the unification
    batch_itids = itids[task_idx]  # (B, I, W*H, T)
    uni_embed = F.concat((uni_embed, batch_itids), -1)  # (B, I, W*H, E+T)
    uni_inputs = F.reshape(
        uni_embed,
        uni_embed.shape[:2] + grid_shape + uni_embed.shape[-1:])  # (B, I, W, H, E+T)
    uni_preds = self.predict(uni_inputs)  # (B, I, V)
    # Aggregate results from each invariant
    final_uni_preds = F.sum(uni_preds, -2)  # (B, V)
    # ---------------------------
    return final_uni_preds
def forward(self, texts):
    """Compute the forward inference pass for given stories.

    ``texts`` is a collection of variable-length integer sequences,
    annotated as [(L1,), (L2,), (L3,)]. Invariant examples are read from
    ``self.inv_examples[0]``.

    Returns:
        dict: report with keys ``'opred'``, ``'igpred'``, ``'vmap'``,
        ``'uniatt'`` and ``'upred'``.
    """
    report = dict()

    # ---------------------------
    def sequence_embed(xs):
        """Embed sequences of integers.

        Returns three tuples split back per sequence: embeddings,
        unification features and variable gates.
        """
        seqs = list(xs)  # Chainer quirk expects lists
        # Boundaries at which the concatenated batch is split again
        boundaries = np.cumsum([len(seq) for seq in seqs][:-1])
        flat = F.concat(seqs, axis=0)  # (L1+L2...,)
        embedded = F.embed_id(flat, wordembeds, ignore_label=0)
        embedded = F.tanh(self.embed(embedded))  # (..., E)
        uni_feats = self.uni_embed(embedded)  # (..., E)
        var_gate = self.var_linear(embedded)  # (..., 1)
        var_gate = F.sigmoid(F.squeeze(var_gate, -1))  # (..., )
        evxs = F.split_axis(embedded, boundaries, 0)
        uexs = F.split_axis(uni_feats, boundaries, 0)
        uvs = F.split_axis(var_gate, boundaries, 0)
        return evxs, uexs, uvs

    inv_texts = self.inv_examples[0]
    # Ground example prediction
    # B x [(L1, E), ...], B x [(L1, E), ...], B x [(L1,), ...]
    ove, ue, _uv = sequence_embed(texts)
    _oys, opred = self.predict(ove)  # B x [(L1, E), ...], (B, 1)
    report['opred'] = opred
    # Invariant example prediction
    ive, iue, iuv = sequence_embed(inv_texts)  # I x [(L1, E), ...] ...
    _iys, ipred = self.predict(ive)  # I x [(L1, E), ...], (I, 1)
    report['igpred'] = ipred
    # ---------------------------
    # Compute padding mask
    padded_texts = F.pad_sequence(list(texts)).array  # (B, LB)
    pad_penalty = -100 * (padded_texts == 0)  # (B, LB)
    padded_itexts = F.pad_sequence(list(inv_texts)).array  # (I, LI)
    # ---------------------------
    # Extract unification features
    oufeats = F.pad_sequence(ue)  # (B, LB, E)
    iufeats = F.pad_sequence(iue)  # (I, LI, E)
    iuvar = F.pad_sequence(iuv)  # (I, LI)
    report['vmap'] = iuvar
    # ---------------------------
    # Unification attention
    # (I, LI, E) x (B, LB, E) -> (B, I, LI, LB)
    uniatt = F.einsum('ile,bfe->bilf', iufeats, oufeats)
    # Mask to stop attention to padding
    uniatt += pad_penalty[:, None, None]  # (B, I, LI, LB)
    uniatt = F.softmax(uniatt, -1)  # (B, I, LI, LB)
    # Zero the rows that belong to padded invariant positions
    uniatt *= (padded_itexts != 0)[..., None]  # (B, I, LI, LB)
    report['uniatt'] = uniatt
    # ---------------------------
    # Compute unified representation
    padded_ove = F.pad_sequence(ove)  # (B, LB, E)
    padded_ive = F.pad_sequence(ive)  # (I, LI, E)
    # (B, I, LI, LB) x (B, LB, E) -> (B, I, LI, E)
    uve = F.einsum('bilf,bfe->bile', uniatt, padded_ove)
    # Interpolate between the unified and the invariant embedding
    gate = iuvar[..., None]
    uve = gate * uve + (1 - gate) * padded_ive  # (B, I, LI, E)
    uve = F.reshape(uve, (-1, ) + uve.shape[2:])  # (B*I, LI, E)
    uve = F.separate(uve, 0)  # B*I x [(LI, E), ...]
    # Invariant lengths repeated once per ground example: (B*I,)
    ulens = np.array([len(t) for t in inv_texts] * texts.shape[0])
    # Trim the padding off every unified sequence again
    uve = [seq[:length] for seq, length in zip(uve, ulens)]
    # ---------------------------
    # Compute unification predictions
    _, upred = self.predict(uve)  # (B*I, 1)
    upred = F.reshape(
        upred, (texts.shape[0], inv_texts.shape[0], 1))  # (B, I, 1)
    upred = F.sum(upred, 1)  # (B, 1)
    report['upred'] = upred
    # ---------------------------
    return report
def forward(self, ground_examples):
    """Compute the forward inference pass for given stories.

    ``ground_examples`` is annotated as (B, 1+L+1): column 0 is read as a
    one-based task id and columns ``1:-1`` as the input sequence. The same
    slicing is applied to ``self.inv_examples``.

    Returns:
        Predictions summed over the invariant axis, annotated as (B, V).
    """
    self.log = dict()
    # ---------------------------
    # Invariant ground prediction, then ground example prediction
    self.compute_ground_loss(self.inv_examples, log_prefix='ig')
    self.compute_ground_loss(ground_examples, log_prefix='o')
    # ---------------------------
    # Unification case
    task_ids = ground_examples[:, 0]  # (B,)
    task_idx = task_ids - 1  # zero-based task index, (B,)
    ground_inputs = ground_examples[:, 1:-1]  # (B, L)
    invariant_inputs = self.inv_examples[..., 1:-1]  # (T, I, L)
    invs_inputs = invariant_inputs[task_idx]  # (B, I, L)
    # Compute variable map
    vmap = F.sigmoid(self.vmap_params * 10)  # (T, I, V)
    self.tolog('vmap', vmap)
    vmap = vmap[task_idx]  # (B, I, V)
    # Look up the variable gate of the symbol at every invariant position
    batch_axis = np.arange(vmap.shape[0])[:, None, None]
    inv_axis = np.arange(vmap.shape[1])[None, :, None]
    vmap = vmap[batch_axis, inv_axis, invs_inputs]  # (B, I, L)
    # Embed ground examples and invariants
    eg = self.embed(ground_inputs)  # (B, L, E)
    ei = self.embed(invariant_inputs)  # (T, I, L, E)
    # Embed tasks for RNN init states; repeated twice along a new leading axis
    embed_tasks = self.task_embed(task_idx)  # (B, E)
    embed_tasks = F.repeat(embed_tasks[None, ...], 2, axis=0)  # (2, B, E)
    iembed_tasks = self.task_embed(self.inv_examples[..., 0] - 1)  # (T, I, E)
    iembed_tasks = F.repeat(iembed_tasks[None, ...], 2, axis=0)  # (2, T, I, E)
    iembed_tasks = F.reshape(iembed_tasks, [2, -1, EMBED])  # (2, T*I, E)
    # Extract unification features
    ground_rnn = seq_rnn_embed(
        eg, self.uni_birnn, init_state=embed_tasks,
        return_sequences=True)  # (B, L, 2*E)
    invs_rnn = seq_rnn_embed(
        ei, self.uni_birnn, init_state=iembed_tasks,
        return_sequences=True)  # (T, I, L, 2*E)
    ground_rnn = self.uni_linear(ground_rnn, n_batch_axes=2)  # (B, L, E)
    invs_rnn = self.uni_linear(invs_rnn, n_batch_axes=3)  # (T, I, L, E)
    invs_rnn = invs_rnn[task_idx]  # (B, I, L, E)
    # (B, I, L, E) x (B, L, E) -> (B, I, L, L)
    uni_att = F.einsum("ijke,ile->ijkl", invs_rnn, ground_rnn)
    uni_att = F.softmax(uni_att, axis=-1)  # (B, I, L, L)
    self.tolog('uniatt', uni_att)
    # (B, I, L, L) x (B, L, E) -> (B, I, L, E)
    eu = F.einsum("ijkl,ile->ijke", uni_att, eg)
    # Interpolate between the unified and the invariant embedding
    batch_ei = ei[task_idx]  # (B, I, L, E)
    uni_embed = (vmap[..., None] * eu
                 + (1 - vmap)[..., None] * batch_ei)  # (B, I, L, E)
    uni_embed = F.reshape(
        uni_embed, uni_embed.shape[:-2] + (-1, ))  # (B, I, L*E)
    # Make the prediction on the unification, with a one-hot task id appended
    ets = F.embed_id(task_idx, np.eye(TASKS, dtype=np.float32))  # (B, T)
    ets = F.repeat(ets[:, None], vmap.shape[1], axis=1)  # (B, I, T)
    uni_inputs = F.concat((uni_embed, ets), axis=-1)  # (B, I, L*E+T)
    uni_preds = self.predict(uni_inputs)  # (B, I, V)
    # Aggregate results from each invariant
    final_uni_preds = F.sum(uni_preds, -2)  # (B, V)
    # ---------------------------
    return final_uni_preds