Code Example #1
 def unify(self, toprove, uni_toprove, candidates, uni_candidates,
           embedded_candidates):
     """Given two sentences compute variable matches and score."""
     # toprove.shape = (R, Ps, P)
     # uni_toprove.shape = (R, Ps, P, E)
     # candidates.shape = (B, Cs, C)
     # uni_candidates.shape = (B, Cs, C, E)
     # embedded_candidates.shape = (B, Cs, C, E)
     # ---------------------------
     # Setup masks
     mask_toprove = (toprove != 0)  # (R, Ps, P)
     mask_candidates = (candidates == 0)  # (B, Cs, C)
     sim_mask = mask_candidates.astype(np.float32) * MINUS_INF  # (B, Cs, C)
     # ---------------------------
     # Calculate a match for every word in s1 to every word in s2
     # Compute similarity between every provable symbol and candidate symbol
     # (R, Ps, P, E) x (B, Cs, C, E)
     raw_sims = F.einsum("rpse,bcde->brpscd", uni_toprove,
                         uni_candidates)  # (B, R, Ps, P, Cs, C)
     # ---------------------------
     # Calculate attended unified word representations for toprove
     raw_sims += sim_mask[:, None, None, None]  # (B, R, Ps, P, Cs, C)
     sim_weights = F.softmax(raw_sims, -1)  # (B, R, Ps, P, Cs, C)
     sim_weights *= mask_toprove[..., None, None]  # (B, R, Ps, P, Cs, C)
     # (B, R, Ps, P, Cs, C) x (B, Cs, C, E)
     unifications = F.einsum("brpscd,bcde->brpsce", sim_weights,
                             embedded_candidates)  # (B, R, Ps, P, Cs, E)
     return unifications, sim_weights
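A minimal NumPy sketch of the two contractions above (sizes are illustrative, not from the project): the first computes pairwise dot products between every provable symbol and every candidate symbol, the second takes an attention-weighted sum of candidate embeddings.

    import numpy as np

    B, R, Ps, P, Cs, C, E = 2, 3, 4, 5, 3, 4, 8  # assumed sizes
    uni_toprove = np.random.randn(R, Ps, P, E).astype(np.float32)
    uni_candidates = np.random.randn(B, Cs, C, E).astype(np.float32)
    embedded_candidates = np.random.randn(B, Cs, C, E).astype(np.float32)
    # Pairwise similarity of every provable symbol with every candidate symbol
    raw_sims = np.einsum("rpse,bcde->brpscd", uni_toprove, uni_candidates)
    assert raw_sims.shape == (B, R, Ps, P, Cs, C)
    # Attention-weighted sum of candidate embeddings (weights would normally be masked and softmaxed)
    unifications = np.einsum("brpscd,bcde->brpsce", raw_sims, embedded_candidates)
    assert unifications.shape == (B, R, Ps, P, Cs, E)
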
Code Example #2
File: models.py Project: sh-okugawa/HDNNP-tools
    def _predict_d2y(self, xs, dxs, d2xs, differentiate_more):
        """Calculate 2nd-order prediction for each `SubNNP`.

        Args:
            xs (list [~chainer.Variable]):
                Input data for each `SubNNP` constituting this HDNNP
                instance. The shape of data is
                ``n_atom x (n_sample, n_input)``.
            dxs (list [~chainer.Variable]):
                Differentiated input data. The shape of data is
                ``n_atom x (n_sample, n_input, n_deriv)``.
            d2xs (list [~chainer.Variable]):
                Double differentiated input data. The shape of data is
                ``n_atom x (n_sample, n_input, n_deriv, n_deriv)``.
            differentiate_more (bool):
                If True, a deeper calculation graph will be created for
                back-propagation or higher-order differentiation.

        Returns:
            ~chainer.Variable:
                Double differentiated output data. The shape of data is
                ``(n_sample, n_output, n_deriv, n_deriv)``.
        """
        for nnp, x in zip(self, xs):
            nnp.second_differentiate(x, differentiate_more)
        return sum([
            F.einsum('soij,six,sjy->soxy', nnp.results['d2y'], dx, dx) +
            F.einsum('soi,sixy->soxy', nnp.results['dy'], d2x)
            for nnp, dx, d2x in zip(self, dxs, d2xs)
        ])
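The two einsums implement the chain rule for the second derivative: the Hessian of each sub-network contracted with the input derivatives on both sides, plus the gradient contracted with the second input derivatives. A NumPy sketch of the same index pattern (sizes are assumptions for illustration):

    import numpy as np

    n_sample, n_output, n_input, n_deriv = 2, 1, 4, 3  # assumed sizes
    d2y = np.random.randn(n_sample, n_output, n_input, n_input)
    dy = np.random.randn(n_sample, n_output, n_input)
    dx = np.random.randn(n_sample, n_input, n_deriv)
    d2x = np.random.randn(n_sample, n_input, n_deriv, n_deriv)
    # d2y/dq_x dq_y = sum_ij (d2y/dx_i dx_j) (dx_i/dq_x) (dx_j/dq_y) + sum_i (dy/dx_i) (d2x_i/dq_x dq_y)
    out = (np.einsum('soij,six,sjy->soxy', d2y, dx, dx)
           + np.einsum('soi,sixy->soxy', dy, d2x))
    assert out.shape == (n_sample, n_output, n_deriv, n_deriv)
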
Code Example #3
File: inv_net.py Project: nasiryahm/GAIT-prop
    def ortho_gradients(self, ortho_weighting, layer_index):
        weights = Variable(self.layers[layer_index].weight_matrix)
        reg = functions.einsum('ik, jk -> ij', weights, weights)

        target = reg * xp.eye(self.layers[layer_index].weight_matrix.shape[0])
        ortho_loss = functions.sum((reg - target)**2)
        gradient = grad([ortho_loss], [weights])[0].array
        return ortho_weighting * gradient
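The 'ik, jk -> ij' contraction is just W @ W.T (the Gram matrix of the weight rows); subtracting its diagonal and summing the squares penalizes non-orthogonal rows. A NumPy sketch (sizes assumed):

    import numpy as np

    W = np.random.randn(4, 6).astype(np.float32)          # assumed weight matrix shape
    reg = np.einsum('ik,jk->ij', W, W)                    # equivalent to W @ W.T
    target = reg * np.eye(W.shape[0], dtype=np.float32)   # keep only the diagonal
    ortho_loss = ((reg - target) ** 2).sum()              # squared off-diagonal entries
    assert np.allclose(reg, W @ W.T)
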
Code Example #4
 def forward(self, x):
     h = self.res(x)
     if self.dropout:
         h = F.dropout(h)
     h = self.fc(h)
     h = F.einsum('ij, ik->ijk', h, h)
     h = self.conv(h[:, None, :, :])
     h = h[:, 0, 0, :]
     return h
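The 'ij, ik->ijk' contraction is a per-row outer product of the fully connected output with itself, which the following conv then treats as a one-channel image. A NumPy sketch:

    import numpy as np

    h = np.random.randn(2, 3).astype(np.float32)   # assumed (batch, features)
    out = np.einsum('ij,ik->ijk', h, h)            # (2, 3, 3) outer product per row
    ref = h[:, :, None] * h[:, None, :]
    assert np.allclose(out, ref)
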
Code Example #5
    def calculate_logit(self, x, t=None, n_batch_axes=1):
        if n_batch_axes != 1:
            raise NotImplementedError
        if self.lo.W.array is None:
            in_size = chainer.utils.size_of_shape(x.shape[n_batch_axes:])
            self.lo._initialize_params(in_size)

        # Standard call
        y = self.lo(x)
        if not (hasattr(x, 'lower') and hasattr(x, 'upper')):
            return y

        # Call with bounds
        if isinstance(t, chainer.Variable):
            t = t.array
        w = self.lo.W
        b = self.lo.b
        batchsize = x.shape[0]
        n_class = b.shape[0]

        w_correct = w[t]  # (batchsize, dim)
        b_correct = b[t]  # (batchsize, )

        _ar2d = self.xp.tile(self.xp.arange(n_class), (batchsize, 1))
        wrong_ids = _ar2d[_ar2d != t[:, None]].reshape(
            (batchsize, n_class - 1))
        w_wrong = w[wrong_ids]  # (batchsize, n_class - 1, dim)
        b_wrong = b[wrong_ids]  # (batchsize, n_class - 1)

        w = w_wrong - w_correct[:, None, :]
        b = b_wrong - b_correct[:, None]
        w = F.transpose(w, (0, 2, 1))  # (batchsize, dim, n_class - 1)

        lower, upper = x.lower, x.upper
        c = (lower + upper) / 2.  # (batchsize, dim)
        r = (upper - lower) / 2.
        c = F.einsum('ij,ijk->ik', c, w)  # (batchsize, n_class - 1)
        if b is not None:
            c += b
        r = F.einsum('ij,ijk->ik', r, abs(w))
        y.worst = c + r
        return y
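Both 'ij,ijk->ik' einsums are batched vector-matrix products: the interval centre c is pushed through the per-example margin weights and the radius r through their absolute values, so c + r appears to give a worst-case bound on the wrong-minus-correct logit margins over the input box. A NumPy sketch of the contraction (sizes assumed):

    import numpy as np

    batchsize, dim, n_out = 2, 4, 3  # assumed sizes
    c = np.random.randn(batchsize, dim)
    w = np.random.randn(batchsize, dim, n_out)
    out = np.einsum('ij,ijk->ik', c, w)   # per-example vector-matrix product
    ref = np.stack([c[i] @ w[i] for i in range(batchsize)])
    assert np.allclose(out, ref)
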
Code Example #6
File: ima.py Project: nuric/softuni
 def forward(self, stories):
     """Compute the forward inference pass for given stories."""
     self.log = dict()
     # ---------------------------
     vctx, vq, va, supps = stories  # (B, R, P, C), (B, Q), (B,), (B, I)
     # Embed stories
     # ectx = F.embed_id(vctx, wordeye, ignore_label=0) # (B, R, P, C, V)
     # eq = F.embed_id(vq, wordeye, ignore_label=0) # (B, Q, V)
     ectx = self.embed(vctx)  # (B, R, P, C, V)
     eq = self.embed(vq)  # (B, Q, V)
     # ---------------------------
     # Embed predicates
     embedded_preds = seq_rnn_embed(vctx, ectx, self.pred_rnn,
                                    reverse=True)  # (B, R, P, E)
     vector_preds = vctx[
         ..., 0]  # (B, R, P) first character to check if pred is empty
     embedded_query = seq_rnn_embed(vq, eq, self.pred_rnn,
                                    reverse=True)  # (B, E)
     embedded_rules = embedded_preds[:, :, 0]  # (B, R, E) head of rule
     # ---------------------------
     # Perform iterative updates
     state = embedded_query  # (B, E)
     repeated_query = F.repeat(embedded_query[:, None], vctx.shape[1],
                               1)  # (B, R, E)
     rule_mask = np.all(vctx == 0, (2, 3))  # (B, R)
     for _ in range(supps.shape[-1]):
         # Compute attention over memory
         repeated_state = F.repeat(state[:, None], vctx.shape[1],
                                   1)  # (B, R, E)
         combined = F.concat([
             repeated_state, embedded_rules, repeated_query,
             F.squared_difference(repeated_state, embedded_rules),
             embedded_rules * repeated_state
         ], -1)  # (B, R, 5*E)
         att = F.tanh(self.att_dense1(combined,
                                      n_batch_axes=2))  # (B, R, E//2)
         att = self.att_dense2(att, n_batch_axes=2)  # (B, R, 1)
         att = F.squeeze(att, -1)  # (B, R)
         att += rule_mask * MINUS_INF  # (B, R)
         self.tolog('raw_att', att)
         att = F.softmax(att)  # (B, R)
         self.tolog('att', att)
         # Iterate state
         new_states = seq_rnn_embed(
             vector_preds,
             embedded_preds,
             self.unifier,
             initial_state=repeated_state)  # (B, R, E)
         # Update state
         # (B, R) x (B, R, E) -> (B, E)
         state = F.einsum('br,bre->be', att, new_states)  # (B, E)
     return self.out_linear(state)[:, 0]  # (B,)
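The closing 'br,bre->be' contraction is a rule-attention-weighted average of the per-rule states. A NumPy sketch:

    import numpy as np

    B, R, E = 2, 3, 4  # assumed sizes
    att = np.random.rand(B, R)
    new_states = np.random.randn(B, R, E)
    state = np.einsum('br,bre->be', att, new_states)
    ref = (att[..., None] * new_states).sum(axis=1)
    assert np.allclose(state, ref)
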
Code Example #7
    def relative_logits_1d(self, q, rel_k, H, W, Nh, transpose_mask):
        rel_logits = F.einsum('bhxyd,md->bhxym', q, rel_k)

        rel_logits = rel_logits.reshape((-1, Nh * H, W, 2 * W - 1))
        rel_logits = self.rel_to_abs(rel_logits)

        rel_logits = rel_logits.reshape((-1, Nh, H, W, W))
        rel_logits = F.expand_dims(rel_logits, axis=3)
        rel_logits = F.tile(rel_logits, (1, 1, 1, H, 1, 1))

        rel_logits = rel_logits.transpose(transpose_mask)
        rel_logits = rel_logits.reshape((-1, Nh, H * W, H * W))

        return rel_logits
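The 'bhxyd,md->bhxym' contraction dots every query position against each of the 2*W - 1 relative key embeddings; the reshapes that follow convert these relative logits to absolute positions. A NumPy sketch of the contraction (sizes assumed):

    import numpy as np

    B, Nh, H, W, d = 2, 2, 3, 4, 5  # assumed sizes
    q = np.random.randn(B, Nh, H, W, d)
    rel_k = np.random.randn(2 * W - 1, d)
    rel_logits = np.einsum('bhxyd,md->bhxym', q, rel_k)
    assert rel_logits.shape == (B, Nh, H, W, 2 * W - 1)
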
Code Example #8
 def update_state(self, oldstate, mem_att, vmemory, ememory, iteration=0):
     """Update state given old, attention and new possible states."""
     # oldstate.shape == (..., E)
     # mem_att.shape == (..., Ms)
     # vmemory.shape == (..., Ms, M)
     # ememory.shape == (..., Ms, E)
     ostate = F.repeat(oldstate[..., None, :], vmemory.shape[-2],
                       -2)  # (..., Ms, E)
     merged = F.concat([
         ostate, ememory, ostate * ememory,
         F.squared_difference(ostate, ememory)
     ], -1)  # (..., Ms, 4*E)
     mem_inter = self.state_linear(merged,
                                   n_batch_axes=len(merged.shape) -
                                   1)  # (..., Ms, E)
     mem_inter = F.tanh(mem_inter)  # (..., E)
     # (..., Ms) x (..., Ms, E) -> (..., E)
     new_state = F.einsum("...i,...ij->...j", mem_att,
                          mem_inter)  # (..., E)
     return new_state
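The '...i,...ij->...j' pattern uses einsum's ellipsis broadcasting, so the same attention-weighted sum over memory slots works for any number of leading batch axes. A NumPy sketch:

    import numpy as np

    mem_att = np.random.rand(2, 3, 5)        # (..., Ms) with two leading batch axes (assumed)
    mem_inter = np.random.randn(2, 3, 5, 4)  # (..., Ms, E)
    new_state = np.einsum('...i,...ij->...j', mem_att, mem_inter)  # (..., E)
    ref = (mem_att[..., None] * mem_inter).sum(axis=-2)
    assert np.allclose(new_state, ref)
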
Code Example #9
def cosine_loss(tens1, tens2, absol=True):
    """
    Computes the cosine loss between two representations.
    The cos is computed per element, i.e. assumed that 
    tens1[i] and tens2[i] correspond to the representations
    of which we want to compute the cos.
    Works only on chainer 5.x, because of the einsum.
    """
    mat1 = _tensor_to_matrix(tens1, axis=0)
    mat2 = _tensor_to_matrix(tens2, axis=0)
    # # compute the inner product.
    prod = F.einsum('ij,ij->i', mat1, mat2)
    # # compute the norms.
    norm1 = F.batch_l2_norm_squared(mat1)
    norm2 = F.batch_l2_norm_squared(mat2)
    # # compute the final cosine (per element): prod / (||mat1|| * ||mat2||).
    cos = prod / F.sqrt(norm1 * norm2)  # the norms above are squared, hence the sqrt
    if absol:
        # # We restrict the angles to [-90, 90] effectively.
        # # That is, we allow only positive cos.
        cos = F.absolute(cos)
    return F.mean(cos)
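A quick NumPy check (illustrative only) that the row-wise 'ij,ij->i' dot products divided by the product of the row norms reproduce the usual per-element cosine similarity:

    import numpy as np

    mat1 = np.random.randn(3, 4)
    mat2 = np.random.randn(3, 4)
    prod = np.einsum('ij,ij->i', mat1, mat2)   # row-wise inner products
    cos = prod / (np.linalg.norm(mat1, axis=1) * np.linalg.norm(mat2, axis=1))
    ref = [a @ b / (np.linalg.norm(a) * np.linalg.norm(b)) for a, b in zip(mat1, mat2)]
    assert np.allclose(cos, ref)
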
Code Example #10
 def blend_featuremap(self, hs, blend):
     return F.einsum('nijkl,nkli->njkl', hs, blend)
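The 'nijkl,nkli->njkl' contraction blends the i stacked feature maps with a separate set of weights at every spatial location (k, l). A NumPy sketch (sizes assumed):

    import numpy as np

    N, I, J, K, L = 2, 3, 4, 5, 6  # assumed sizes
    hs = np.random.randn(N, I, J, K, L)   # I stacked feature maps of shape (J, K, L)
    blend = np.random.rand(N, K, L, I)    # per-location blending weights over the I maps
    out = np.einsum('nijkl,nkli->njkl', hs, blend)
    ref = (hs * blend.transpose(0, 3, 1, 2)[:, :, None]).sum(axis=1)
    assert np.allclose(out, ref)
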
Code Example #11
    def __call__(self, x, sentence, att_mask=None, train=True):
        with chainer.using_config('train', train), chainer.using_config('enable_backprop', train):
            xp = cuda.get_array_module(x.data)

            h1 = F.leaky_relu(self.dc1(x))
            h2 = F.leaky_relu(self.norm2(self.dc2(h1)))
            h2_ = F.leaky_relu(self.norm2_(self.dc2_(h2)))
            h2__ = F.leaky_relu(self.norm2__(self.dc2__(h2_)))
            h3 = F.leaky_relu(self.norm3(self.dc3(h2__)))
            h3_ = F.leaky_relu(self.norm3_(self.dc3_(h3)))
            h3__ = F.leaky_relu(self.norm3__(self.dc3__(h3_)))
            h4 = F.leaky_relu(self.norm4(self.dc4(h3__)))
            mean = self.dc5_mean(h4)
            var = F.tanh(self.dc5_var(h4))
            rand = xp.random.normal(0, 1, var.data.shape).astype(np.float32)
            z = mean + F.exp(var) * Variable(rand)
            # h6 = F.leaky_relu(self.dc6(h5))

            f0 = F.tanh(self.norm0(self.fc_video0(h4)))
            f1 = F.tanh(self.norm1(self.fc_video1(f0)))
            f3 = self.fc_video2(f1)

            self.l1_.reset_state()
            for i in range(sentence.shape[1]):
                encoded = self.l1_(sentence[:, i])

            s0 = F.tanh(self.norm_text0(self.fc_text0(encoded)))
            s1 = self.fc_text1(s0)
            s2 = F.expand_dims(s1, axis=2)
            s2 = F.repeat(s2, self.att_size * self.att_size, axis=2)
            s2 = F.reshape(s2, (-1, int(8 * self.density), self.att_size, self.att_size))

            m3 = f3 + s2
            m3 = F.tanh(self.norm_mix(m3))
            m4 = F.reshape(self.fc_mix0(m3), (-1, self.att_size * self.att_size))
            # m4 = 20 * F.normalize(m4, axis=1)
            m4 = F.softmax(F.relu(m4), axis=1)

            # h0_ = F.reshape(F.max_pooling_2d(h0_, 2), (-1, 512, self.att_size * self.att_size))
            f3 = F.reshape(f3, (-1, 8 * self.density, self.att_size * self.att_size))
            # f2 = F.einsum('ijk,ik -> ij', h0_, h4)
            # features_rolled = None
            if train:
                masked = att_mask * m4
                features = F.einsum('ijk,ik -> ij', f3, masked)
                # features_rolled = F.einsum('ijk,ik -> ij', f3, xp.roll(masked.data, 1, axis=0))
            else:
                features = F.einsum('ijk,ik -> ij', f3, m4)
            # features = F.dropout(features, 0.5)
            #Classifier
            f0 = self.norm_cls0(F.leaky_relu(self.fc_cls0(features)))
            s2 = self.fc5(f0)
            c2 = self.fc6(f0)

            # h4 = F.reshape(h4, (-1, int(128 * self.density), self.att_size * self.att_size))
            # D_broad = F.broadcast_to(self.D, (h4.shape[0], self.D.shape[0], self.D.shape[1]))
            # toLatent = F.reshape(F.concat((h4, D_broad), axis=1), (-1, int((128 + 8) * self.density), self.att_size * self.att_size))
            # toLatent = self.norm_D(toLatent)

            # m4_prime = Variable(m4.data)
            # toZ = F.einsum('ijk,ik -> ij', toLatent, m4_prime)
            # mean = self.fc_toz_mean(toZ)
            # var = F.tanh(self.fc_toz_mean(toZ))
            # rand = xp.random.normal(0, 1, var.data.shape).astype(np.float32)
            # z = mean + F.exp(var) * Variable(rand)
            # return h5, z, mean, var, encoded, features, features_rolled, h6, F.reshape(m4, (-1, 1, self.att_size, self.att_size)), s2, c2
            return z, var, mean, encoded, features, F.reshape(m4, (-1, 1, self.att_size, self.att_size)), s2, c2
Code Example #12
 def forward(self, stories):
     """Compute the forward inference pass for given stories."""
     self.log = dict()
     # ---------------------------
     vctx, vq, va, supps = stories  # (B, Cs, C), (B, Q), (B, A), (B, I)
     # Embed stories
     ectx = self.embed(vctx)  # (B, Cs, C, E)
     eq = self.embed(vq)  # (B, Q, E)
     # ---------------------------
     # Prepare rules and variable states
     rvctx, rvq, rva, rsupps = self.vrules  # (R, Ls, L), (R, Q), (R, A), (R, I)
     erctx, erq, era = [self.embed(v) for v in self.vrules[:-1]
                        ]  # (R, Ls, L, E), (R, Q, E), (R, A, E)
     # ---------------------------
     # Compute variable map
     vmap = self.compute_vmap()  # (R, V)
     self.tolog('vmap', vmap)
     # ---------------------------
     # Indexing ranges
     nrules_range = np.arange(rvq.shape[0])  # (R,)
     # ---------------------------
     # Rule states
     rs = self.mematt.init_state(rvq, erq)  # (R, E)
     # Original states
     orig_cs = self.mematt.init_state(vq, eq)  # (B, E)
     # ---------------------------
     # Unify query first assuming given query is ground
     uni_erq = self.unification_features(rvq, erq)  # (R, Q, E)
     uni_eq = self.unification_features(vq, eq)  # (B, Q', E)
     qunis, q_uniatt = self.unify(
         rvq[:, None], uni_erq[:, None], vq[:, None], uni_eq[:, None],
         eq[:, None])  # (B, R, 1, Q, 1, E), (B, R, 1, Q, 1, Q')
     qunis = F.squeeze(qunis, (2, 4))  # (B, R, Q, E)
     q_uniatt = F.squeeze(q_uniatt, (2, 4))  # (B, R, Q, Q')
     self.tolog('q_uniatt', q_uniatt)
     # ---------------------------
     # Unified states
     qvgates = vmap[nrules_range[:, None], rvq]  # (R, Q)
     qstate = qvgates[..., None] * qunis + (
         1 - qvgates[..., None]) * erq  # (B, R, Q, E)
     brvq = np.repeat(rvq[None, ...], qstate.shape[0], 0)  # (B, R, Q)
     uni_cs = self.mematt.init_state(brvq, qstate)  # (B, R, E)
     # ---------------------------
     # Compute rule attentions
     num_rules = rvq.shape[0]  # R
     if num_rules > 1:
         cs_feats = self.rule_linear(orig_cs)  # (B, E)
         ratt = cs_feats @ rs.T  # (B, R)
         ratt = F.softmax(ratt, -1)  # (B, R)
         self.tolog('ratt', ratt)
     # ---------------------------
     # Prepare unified state
     if num_rules == 1:
         uni_cs = uni_cs[:, 0]  # (B, E)
     else:
         # (B, R) x (B, R, E) -> (B, E)
         uni_cs = F.einsum('br,bre->be', ratt, uni_cs)  # (B, E)
     # ---------------------------
     # Compute loss from unifying the query
     uniloss = F.mean_squared_error(uni_cs, orig_cs)  # ()
     self.tolog('uniloss', uniloss)
     # ---------------------------
     # Unify body, every symbol to every symbol
     uni_erctx = self.unification_features(rvctx, erctx)  # (R, Ls, L, E)
     uni_ectx = self.unification_features(vctx, ectx)  # (B, Cs, C, E)
     bunis, uni_att = self.unify(
         rvctx, uni_erctx, vctx, uni_ectx,
         ectx)  # (B, R, Ls, L, Cs, C, E), (B, R, Ls, L, Cs, C)
     self.tolog('uni_att', uni_att)
     # ---------------------------
     # Setup memory sequence embeddings
     mem_erctx = self.mematt.seq_embed(rvctx, erctx)  # (R, Ls, E)
     mem_ectx = self.mematt.seq_embed(vctx, ectx)  # (B, Cs, E)
     # ---------------------------
     # Attention masks, and rule variable gates
     bodyattmask = np.all(rvctx == 0, -1)  # (R, Ls)
     candattmask = np.all(vctx == 0, -1)  # (B, Cs)
     ctxvgates = vmap[nrules_range[:, None, None], rvctx,
                      None]  # (R, Ls, L, 1)
     brvctx = np.repeat(rvctx[None, ...], vctx.shape[0], 0)  # (B, R, Ls, L)
     # ---------------------------
     # Compute iterative updates on variables
     for t in range(supps.shape[-1]):
         # ---------------------------
         # Compute which body literal to prove using rule state
         raw_body_att = self.mematt(rs, rvctx, mem_erctx, bodyattmask,
                                    t)  # (R, Ls)
         self.tolog('raw_body_att', raw_body_att)
         body_att = F.softmax(raw_body_att, -1)  # (R, Ls)
         # Compute unified candidate attention
         raw_uni_cands_att = self.mematt(uni_cs, vctx, mem_ectx,
                                         candattmask, t)  # (B, Cs)
         self.tolog('raw_uni_cands_att', raw_uni_cands_att)
         uni_cands_att = F.softmax(raw_uni_cands_att, -1)  # (B, Cs)
         # Compute original candidate attention
         raw_orig_cands_att = self.mematt(orig_cs, vctx, mem_ectx,
                                          candattmask, t)  # (B, Cs)
         self.tolog('raw_orig_cands_att', raw_orig_cands_att)
         orig_cands_att = F.softmax(raw_orig_cands_att, -1)  # (B, Cs)
         # ---------------------------
         # Update states for the rule and original
         rs = self.mematt.update_state(rs, body_att, rvctx, mem_erctx,
                                       t)  # (R, E)
         orig_cs = self.mematt.update_state(orig_cs, orig_cands_att, vctx,
                                            mem_ectx, t)  # (B, E)
         # ---------------------------
         # Compute attended unification over candidates
         # (B, Cs) x (B, R, Ls, L, Cs, E) -> (B, R, Ls, L, E)
         unis = F.einsum('bc,brlsce->brlse', uni_cands_att,
                         bunis)  # (B, R, Ls, L, E)
         # ---------------------------
         # Update candidate states with new variable bindings
         bstate = ctxvgates * unis + (
             1 - ctxvgates) * erctx  # (B, R, Ls, Ls, E)
         mem_bstate = self.mematt.seq_embed(brvctx, bstate)  # (B, R, Ls, E)
         body_att = F.broadcast_to(body_att, bstate.shape[:3])  # (B, R, Ls)
         uni_cs = F.repeat(uni_cs[:, None], rvq.shape[0], 1)  # (B, R, E)
         uni_cs = self.mematt.update_state(uni_cs, body_att, brvctx,
                                           mem_bstate, t)  # (B, R, E)
         # ---------------------------
         # Apply rule attention
         if num_rules == 1:
             uni_cs = uni_cs[:, 0]  # (B, E)
         else:
             # (B, R) x (B, R, E) -> (B, E)
             uni_cs = F.einsum('br,bre->be', ratt, uni_cs)  # (B, E)
         # ---
         # Compute unification loss after this iteration
         uniloss = F.mean_squared_error(uni_cs, orig_cs)  # ()
         self.tolog('uniloss', uniloss)
     # ---------------------------
     # Compute answers based on variable and rule scores
     prediction = self.answer_linear(uni_cs)  # (B, V)
      # Compute auxiliary answers
     rpred = self.answer_linear(rs)  # (R, V)
     self.tolog('rpred', rpred)
     opred = self.answer_linear(orig_cs)  # (B, V)
     self.tolog('opred', opred)
     return prediction
Code Example #13
File: ucnn.py Project: nuric/softuni
    def forward(self, ground_examples: np.ndarray):
        """Compute the forward inference pass for given stories."""
        # ground_examples (B, 1+W*H+1)
        self.log = dict()
        # ---------------------------
        # Invariant ground prediction
        self.compute_ground_loss(self.inv_examples, log_prefix='ig')
        # Ground example prediction
        self.compute_ground_loss(ground_examples, log_prefix='o')
        # ---------------------------
        # Unification case
        task_ids = ground_examples[:, 0]  # (B,)
        ground_inputs = ground_examples[:, 1:-1]  # (B, W*H)

        invs_inputs = self.inv_examples[..., 1:-1]  # (T, I, W*H)
        # invs_inputs = invariant_inputs[task_ids-1] # (B, I, W*H)

        # Embed ground examples
        eg = self.embed(ground_inputs)  # (B, W*H, E)
        ei = self.embed(invs_inputs)  # (T, I, W*H, E)

        # Extract unification features
        tids = F.embed_id(task_ids - 1, np.eye(TASKS,
                                               dtype=np.float32))  # (B, T)
        tids = F.repeat(tids[:, None], eg.shape[1], 1)  # (B, W*H, T)
        itids = np.eye(TASKS, dtype=np.float32)  # (T, T)
        itids = F.tile(itids[:, None, None, :],
                       (1, invs_inputs.shape[1], invs_inputs.shape[2],
                        1))  # (T, I, W*H, T)

        egt = F.concat((eg, tids), -1)  # (B, W*H, E+T)
        eit = F.concat((ei, itids), -1)  # (T, I, W*H, E+T)
        egt = F.reshape(egt, egt.shape[:1] + tuple(GRID) +
                        egt.shape[-1:])  # (B, W, H, E+T)
        eit = F.reshape(eit, (-1, ) + tuple(GRID) +
                        eit.shape[-1:])  # (T*I, W, H, E+T)
        egt = F.swapaxes(egt, -1, -3)  # (B, E+T, W, H)
        eit = F.swapaxes(eit, -1, -3)  # (T*I, E+T, W, H)

        gfeats = F.relu(self.uni_conv1(egt))  # (B, E, W, H)
        ifeats = F.relu(self.uni_conv1(eit))  # (T*I, E, W, H)
        gfeats = self.uni_conv2(gfeats)  # (B, E, W, H)
        ifeats = self.uni_conv2(ifeats)  # (T*I, E, W, H)
        gfeats = F.reshape(gfeats, gfeats.shape[:2] + (-1, ))  # (B, E, W*H)
        ifeats = F.reshape(ifeats, ei.shape[:2] + ifeats.shape[1:2] +
                           (-1, ))  # (T, I, E, W*H)

        batch_ifeats = ifeats[task_ids - 1]  # (B, I, E, W*H)
        # (B, I, E, W*H) x (B, E, W*H) -> (B, I, W*H, W*H)
        uni_att = F.einsum("ijek,iel->ijkl", batch_ifeats,
                           gfeats)  # (B, I, W*H, W*H)
        mask = -100 * (ground_inputs == 0)  # (B, W*H) cannot attend to padding
        uni_att += mask[:, None, None]  # (B, I, W*H, W*H)
        uni_att = F.softmax(uni_att, axis=-1)  # (B, I, W*H, W*H)
        self.tolog('uniatt', uni_att)

        # (B, I, W*H, W*H) x (B, W*H, E) -> (B, I, W*H, E)
        eu = F.einsum("ijkl,ile->ijke", uni_att, eg)  # (B, I, W*H, E)

        # Compute variable map
        vmap = F.sigmoid(self.vmap_params * 10)  # (T, I, V)
        mask = np.ones(VOCAB)  # (V,)
        mask[0] = 0  # padding symbol cannot be variable
        vmap *= mask  # (T, I, V)
        self.tolog('vmap', vmap)
        vmap = vmap[np.arange(vmap.shape[0])[:, None, None],
                    np.arange(vmap.shape[1])[None, :, None],
                    invs_inputs]  # (T, I, W*H)
        vmap = vmap[task_ids - 1]  # (B, I, W*H)

        batch_ei = ei[task_ids - 1]  # (B, I, W*H, E)
        uni_embed = (vmap[..., None] * eu + (1 - vmap)[..., None] * batch_ei
                     )  # (B, I, W*H, E)

        # Make the prediction on the unification
        batch_itids = itids[task_ids - 1]  # (B, I, W*H, T)
        uni_embed = F.concat((uni_embed, batch_itids), -1)  # (B, I, W*H, E+T)
        uni_inputs = F.reshape(uni_embed, uni_embed.shape[:2] + tuple(GRID) +
                               uni_embed.shape[-1:])  # (B, I, W, H, E+T)
        uni_preds = self.predict(uni_inputs)  # (B, I, V)

        # Aggregate results from each invariant
        final_uni_preds = F.sum(uni_preds, -2)  # (B, V)
        # ---------------------------
        return final_uni_preds  # (B, V)
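The pair of einsums in the middle of this example form a soft-attention lookup: 'ijek,iel->ijkl' scores every invariant cell against every ground cell, and after the softmax 'ijkl,ile->ijke' gathers ground embeddings for each invariant cell. A NumPy sketch of the two contractions (sizes assumed, softmax kept naive):

    import numpy as np

    B, I, E, K = 2, 3, 4, 5  # batch, invariants, embedding size, W*H cells (assumed)
    ifeats = np.random.randn(B, I, E, K)   # invariant unification features
    gfeats = np.random.randn(B, E, K)      # ground unification features
    eg = np.random.randn(B, K, E)          # ground embeddings
    uni_att = np.einsum('ijek,iel->ijkl', ifeats, gfeats)   # (B, I, K, K) cell-to-cell scores
    uni_att = np.exp(uni_att) / np.exp(uni_att).sum(-1, keepdims=True)  # softmax over ground cells
    eu = np.einsum('ijkl,ile->ijke', uni_att, eg)           # (B, I, K, E) soft-gathered embeddings
    assert eu.shape == (B, I, K, E)
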
Code Example #14
    def forward(self, texts):
        """Compute the forward inference pass for given stories."""
        # texts [(L1,), (L2,), (L3,)]
        report = dict()

        # ---------------------------
        def sequence_embed(xs):
            """Embed sequences of integers."""
            # xt [(L1,), (L2,), ...]
            xs = list(xs)  # Chainer quirk expects lists
            x_len = [len(x) for x in xs]
            x_section = np.cumsum(x_len[:-1])
            x_concat = F.concat(xs, axis=0)  # (L1+L2...,)
            # ex = self.embed(x_concat) # (..., E)
            ex = F.embed_id(x_concat, wordembeds, ignore_label=0)
            ex = F.tanh(self.embed(ex))  # (..., E)
            uex = self.uni_embed(ex)  # (..., E)
            uvx = self.var_linear(ex)  # (..., 1)
            uvx = F.sigmoid(F.squeeze(uvx, -1))  # (..., )
            # evx = F.concat([ex, uvx[:, None]], -1)  # (..., E+1)
            evxs = F.split_axis(ex, x_section, 0)
            uexs = F.split_axis(uex, x_section, 0)
            uvs = F.split_axis(uvx, x_section, 0)
            return evxs, uexs, uvs

        # Ground example prediction
        ove, ue, uv = sequence_embed(
            texts
        )  # B x [(L1, E), (L2, E), ...], Bx[(L1, E), ...], B x [(L1,), (L2,), ...]
        oys, opred = self.predict(ove)  # B x [(L1, E), ...], (B, 1)
        report['opred'] = opred
        # Invariant example prediction
        ive, iue, iuv = sequence_embed(
            self.inv_examples[0])  # I x [(L1, E), ...] ...
        iys, ipred = self.predict(ive)  # I x [(L1, E), ...], (I, 1)
        report['igpred'] = ipred
        # ---------------------------
        # Compute padding mask
        padded_texts = F.pad_sequence(list(texts)).array  # (B, LB)
        mask = -100 * (padded_texts == 0)  # (B, LB)
        padded_itexts = F.pad_sequence(list(
            self.inv_examples[0])).array  # (I, LI)
        # ---------------------------
        # Extract unification features
        oufeats = F.pad_sequence(ue)  # (B, LB, E)
        iufeats = F.pad_sequence(iue)  # (I, LI, E)
        iuvar = F.pad_sequence(iuv)  # (I, LI)
        report['vmap'] = iuvar
        # ---------------------------
        # Unification attention
        # (I, LI, E) x (B, LB, E) -> (B, I, LI, LB)
        uniatt = F.einsum('ile,bfe->bilf', iufeats, oufeats)
        # Mask to stop attention to padding
        uniatt += mask[:, None, None]  # (B, I, LI, LB)
        uniatt = F.softmax(uniatt, -1)  # (B, I, LI, LB)
        uniatt *= (padded_itexts != 0)[..., None]  # (B, I, LI, LB)
        report['uniatt'] = uniatt
        # ---------------------------
        # Compute unified representation
        padded_ove = F.pad_sequence(ove)  # (B, LB, E)
        padded_ive = F.pad_sequence(ive)  # (I, LI, E)
        # (B, I, LI, LB) x (B, LB, E) -> (B, I, LI, E)
        uve = F.einsum('bilf,bfe->bile', uniatt, padded_ove)
        # ---
        uve = iuvar[..., None] * uve + (
            1 - iuvar[..., None]) * padded_ive  # (B, I, LI, E)
        uve = F.reshape(uve, (-1, ) + uve.shape[2:])  # (B*I, LI, E)
        uve = F.separate(uve, 0)  # B*I x [(LI, E), ...]
        ulens = np.array([len(t) for t in self.inv_examples[0]] *
                         texts.shape[0])  # (B*I,)
        uve = [seq[:l]
               for seq, l in zip(uve, ulens)]  # B*I x [(L1, E), (L2, E), ..]
        # ---------------------------
        # Compute unification predictions
        _, upred = self.predict(uve)  # (B*I, 1)
        upred = F.reshape(
            upred,
            (texts.shape[0], self.inv_examples[0].shape[0], 1))  # (B, I, 1)
        upred = F.sum(upred, 1)  # (B, 1)
        report['upred'] = upred
        # ---------------------------
        return report
Code Example #15
    def forward(self, ground_examples):
        """Compute the forward inference pass for given stories."""
        # ground_examples (B, 1+L+1)
        self.log = dict()
        # ---------------------------
        # Invariant ground prediction
        self.compute_ground_loss(self.inv_examples, log_prefix='ig')
        # Ground example prediction
        self.compute_ground_loss(ground_examples, log_prefix='o')
        # ---------------------------
        # Unification case
        task_ids = ground_examples[:, 0]  # (B,)
        ground_inputs = ground_examples[:, 1:-1]  # (B, L)

        invariant_inputs = self.inv_examples[..., 1:-1]  # (T, I, L)
        invs_inputs = invariant_inputs[task_ids - 1]  # (B, I, L)

        # Compute variable map
        vmap = F.sigmoid(self.vmap_params * 10)  # (T, I, V)
        self.tolog('vmap', vmap)
        vmap = vmap[task_ids - 1]  # (B, I, V)
        vmap = vmap[np.arange(vmap.shape[0])[:, None, None],
                    np.arange(vmap.shape[1])[None, :, None],
                    invs_inputs]  # (B, I, L)

        # Embed ground examples
        eg = self.embed(ground_inputs)  # (B, L, E)
        ei = self.embed(invariant_inputs)  # (T, I, L, E)

        # Embed tasks for RNN init states
        embed_tasks = self.task_embed(task_ids - 1)  # (B, E)
        embed_tasks = F.repeat(embed_tasks[None, ...], 2, axis=0)  # (2, B, E)
        iembed_tasks = self.task_embed(self.inv_examples[..., 0] -
                                       1)  # (T, I, E)
        iembed_tasks = F.repeat(iembed_tasks[None, ...], 2,
                                axis=0)  # (2, T, I, E)
        iembed_tasks = F.reshape(iembed_tasks, [2, -1, EMBED])  # (2, T*I, E)

        # Extract unification features
        ground_rnn = seq_rnn_embed(eg,
                                   self.uni_birnn,
                                   init_state=embed_tasks,
                                   return_sequences=True)  # (B, L, 2*E)
        invs_rnn = seq_rnn_embed(ei,
                                 self.uni_birnn,
                                 init_state=iembed_tasks,
                                 return_sequences=True)  # (T, I, L, 2*E)
        ground_rnn = self.uni_linear(ground_rnn, n_batch_axes=2)  # (B, L, E)
        invs_rnn = self.uni_linear(invs_rnn, n_batch_axes=3)  # (T, I, L, E)
        invs_rnn = invs_rnn[task_ids - 1]  # (B, I, L, E)
        # (B, I, L, E) x (B, L, E) -> (B, I, L, L)
        uni_att = F.einsum("ijke,ile->ijkl", invs_rnn,
                           ground_rnn)  # (B, I, L, L)
        uni_att = F.softmax(uni_att, axis=-1)  # (B, I, L, L)
        self.tolog('uniatt', uni_att)

        # (B, I, L, L) x (B, L, E) -> (B, I, L, E)
        eu = F.einsum("ijkl,ile->ijke", uni_att, eg)  # (B, I, L, E)

        # uni_embed = vmap[..., None]*eg[:, None] + (1-vmap)[..., None]*ei # (B, I, L, E)
        uni_embed = vmap[..., None] * eu + (
            1 - vmap)[..., None] * ei[task_ids - 1]  # (B, I, L, E)
        uni_embed = F.reshape(uni_embed,
                              uni_embed.shape[:-2] + (-1, ))  # (B, I, L*E)

        # Make the prediction on the unification
        ets = F.embed_id(task_ids - 1, np.eye(TASKS,
                                              dtype=np.float32))  # (B, T)
        ets = F.repeat(ets[:, None], vmap.shape[1], axis=1)  # (B, I, T)
        uni_inputs = F.concat((uni_embed, ets), axis=-1)  # (B, I, L*E+T)
        uni_preds = self.predict(uni_inputs)  # (B, I, V)

        # Aggregate results from each invariant
        final_uni_preds = F.sum(uni_preds, -2)  # (B, V)
        # ---------------------------
        return final_uni_preds  # (B, V)