Example #1
def einsum(equation, *operands):
    """Variadic version of torch.einsum to match numpy api.
    """
    # rename symbols to support PyTorch 0.4.1 and earlier,
    # which allow only symbols a-z.
    equation = convert_to_valid_einsum_chars(equation)

    torch, _ = _get_torch_and_device()
    return torch.einsum(equation, operands)
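
A minimal sketch of the variadic call this wrapper supports, assuming only standard torch.einsum semantics (the helper functions above are not reproduced here):

import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)
# numpy-style variadic call: operands are passed separately rather than as a tuple
out = torch.einsum('ij,jk->ik', a, b)
assert out.shape == (2, 4)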
Example #2
def affine_product(X, A, b):
    """ Special case of affine transformation that receives coordinates X in 2-d (x, y)
    and affine matrix A and translation vector b in 3-d (x, y, z). Y = AX + b

    :param torch.Tensor X: A matrix of 2-d coordinates (d1 x d2 x 2).
    :param torch.Tensor A: The first two columns of the affine matrix (3 x 2).
    :param torch.Tensor b: A 3-d translation vector.

    :return: A (d1 x d2 x 3) torch.Tensor corresponding to the transformed coordinates.
    """
    return torch.einsum('ij,klj->kli', (A, X)) + b
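
A shape-check sketch of the contraction above, with hypothetical sizes d1=4, d2=5:

import torch

X = torch.randn(4, 5, 2)   # (d1 x d2 x 2) coordinates
A = torch.randn(3, 2)      # first two columns of the affine matrix
b = torch.randn(3)         # 3-d translation vector
Y = torch.einsum('ij,klj->kli', A, X) + b
assert Y.shape == (4, 5, 3)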
Example #3
    def forward(self, x, A):
        assert A.size(0) == self.kernel_size
        x = self.conv1d(x)

        n, kc, v = x.size()
        x = x.view(n, self.kernel_size, kc//self.kernel_size, v)
        x = torch.einsum('nkcv,kvw->ncw', (x, A))

        # n, kc, t, v = x.size()
        # x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v)
        # x = torch.einsum('nkctv,kvw->nctw', (x, A))
        #print('einsum',x.shape)
        x = x.contiguous()
        return x, A
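
A hedged shape sketch of the 'nkcv,kvw->ncw' contraction, which sums over the kernel axis k and the vertex axis v (all sizes below are hypothetical):

import torch

n, k, c, v, w = 2, 3, 4, 5, 6
x = torch.randn(n, k, c, v)
A = torch.randn(k, v, w)
out = torch.einsum('nkcv,kvw->ncw', x, A)
# same contraction written as an ordinary matrix product over the flattened (k, v) pair
ref = x.permute(0, 2, 1, 3).reshape(n, c, k * v) @ A.reshape(k * v, w)
assert torch.allclose(out, ref, atol=1e-6)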
Example #4
    def forward(self, x):

        if self.share_weights:
            u_hat_vecs = t.matmul(x, self.W)
        else:
            # non-shared weights are not implemented; u_hat_vecs would otherwise be undefined below
            raise NotImplementedError('add later')

        batch_size = x.size(0)
        input_num_capsule = x.size(1)
        u_hat_vecs = u_hat_vecs.view((batch_size, input_num_capsule,
                                      self.num_capsule, self.dim_capsule))
        u_hat_vecs = u_hat_vecs.permute(0, 2, 1, 3)  # permute to (batch_size, num_capsule, input_num_capsule, dim_capsule)
        b = t.zeros_like(u_hat_vecs[:, :, :, 0])  # (batch_size,num_capsule,input_num_capsule)

        for i in range(self.routings):
            b = b.permute(0, 2, 1)
            c = F.softmax(b, dim=2)
            c = c.permute(0, 2, 1)
            b = b.permute(0, 2, 1)
            outputs = self.activation(t.einsum('bij,bijk->bik', (c, u_hat_vecs)))  # batch matrix multiplication
            # outputs shape (batch_size, num_capsule, dim_capsule)
            if i < self.routings - 1:
                b = t.einsum('bik,bijk->bij', (outputs, u_hat_vecs))  # batch matrix multiplication
        return outputs  # (batch_size, num_capsule, dim_capsule)
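
The routing einsum 'bij,bijk->bik' is an ordinary batched matrix multiplication over the input-capsule axis j; a small hedged check with hypothetical sizes:

import torch

b_, i_, j_, k_ = 2, 3, 4, 5
c = torch.softmax(torch.randn(b_, i_, j_), dim=2)       # coupling coefficients
u_hat = torch.randn(b_, i_, j_, k_)                     # prediction vectors
out = torch.einsum('bij,bijk->bik', c, u_hat)
ref = (c.unsqueeze(-2) @ u_hat).squeeze(-2)             # matmul view of the same sum over j
assert torch.allclose(out, ref, atol=1e-6)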
Example #5
    def _oc(a: Tensor, rhs: Tensor, Y: Tensor) -> Tensor:
        r"""Evaluate constraints.

        Note: einsum multiplies Y by a and sums over the `o`-dimension. Einsum
            is ~2x faster than using `(Y * a.view(1, 1, -1)).sum(dim=-1)`.

        Args:
            a: `o`-dim tensor of weights for the outcomes
            rhs: Singleton tensor containing the outcome constraint value
            Y: `... x b x q x o` tensor of function values

        Returns:
            A `... x b x q`-dim tensor where negative values imply feasibility
        """
        lhs = torch.einsum("...o, o", [Y, a])
        return lhs - rhs
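
A hedged check that the "...o, o" einsum above is the weighted sum over the outcome dimension described in the note (shapes are hypothetical):

import torch

Y = torch.randn(7, 3, 2, 4)        # `... x b x q x o` function values
a = torch.randn(4)                 # `o`-dim outcome weights
rhs = torch.tensor(0.5)
lhs = torch.einsum("...o, o", [Y, a])
assert torch.allclose(lhs, (Y * a).sum(dim=-1), atol=1e-6)
assert (lhs - rhs).shape == (7, 3, 2)   # `... x b x q`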
Example #6
 def compute_chunk(left_act, right_act):
     act = torch.einsum('...bac,...dae->...bdce', left_act, right_act)
     act = act.reshape(act.shape[:-2] + (-1, ))
     act = self.output_w(act)
     return act
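
A shape-only sketch of the '...bac,...dae->...bdce' contraction (the self.output_w projection is omitted; all sizes are hypothetical):

import torch

left_act = torch.randn(2, 3, 4, 5)    # ... b a c
right_act = torch.randn(2, 6, 4, 7)   # ... d a e
act = torch.einsum('...bac,...dae->...bdce', left_act, right_act)
assert act.shape == (2, 3, 6, 5, 7)   # contracted over the shared a axis
act = act.reshape(act.shape[:-2] + (-1,))
assert act.shape == (2, 3, 6, 35)     # the (c, e) pair flattened, as in compute_chunk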
Example #7
def apply_TM_2sO(state, env, edge, op=None, verbosity=0):
    r"""
    :param state: underlying 1-site C4v symmetric wavefunction
    :param env: C4v symmetric environment corresponding to ``state``
    :param edge: tensor of dimensions :math:`\chi \times D^2 \times \chi`
    :param op: two-site operator to be inserted into the two consecutive
               transfer matrices
    :param verbosity: logging verbosity
    :type state: IPEPS_C4V
    :type env: ENV_C4V
    :type edge: torch.tensor
    :type op: torch.tensor
    :type verbosity: int
    :return: ``edge`` with two transfer matrices (and operator ``op``, if any) applied.
             The resulting tensor has an identical index structure to the
             original ``edge``
    :rtype: torch.tensor
    
    Applies two transfer matrices to the ``edge`` tensor, including the two-site operator
    ``op`` by contracting the following network::

         -----T-------------T------------
        |     |             |
       edge--(a^+ op_l a)==(a^+ op_r a)--
        |     |             |
         -----T-------------T------------

    where the physical indices `s` and `s'` of the on-site tensor :math:`a`
    and its Hermitian conjugate :math:`a^\dagger` are contracted with the
    identity :math:`\delta_{s,s'}`, or with ``op_l`` and ``op_r`` if ``op`` is supplied.
    The tensors ``op_l`` and ``op_r`` are obtained from the SVD of the two-site operator
    ``op``::

         0  1        0           1          0            1->0
         |  |  SVD   |           |          |            |
        | op |  =  |op_l|--(S--|op^~_r|) = |op_l|--2 2--|op_r|
         |  |        |           |          |            |
         2  3        2           3          2->1         3->1
    """
    # TODO stronger verification
    if op is not None:
        assert (len(op.size()) == 4)

        # pre-process ``op``
        # TODO possibly truncate/compress according to the vanishingly small singular values
        dims_op = op.size()
        op_mat = op.permute(0, 2, 1,
                            3).contiguous().reshape(dims_op[0]**2,
                                                    dims_op[0]**2)
        op_l, s, op_r = torch.svd(op_mat)
        op_l = op_l.reshape(dims_op[0], dims_op[0], s.size()[0])
        op_r = torch.einsum('i,ij->ij', s,
                            op_r.t()).reshape(s.size()[0], dims_op[0],
                                              dims_op[0])
        op_r = op_r.permute(1, 2, 0).contiguous()

    T = env.T[env.keyT]
    # Assume index structure of ``edge`` tensor to be as follows
    #
    #       -- 0
    # edge |-- 1
    #       -- 2
    #
    #   ----0 0--T--1->2
    #  |         2->3
    # edge--1->0
    #  |
    #   ----2->1
    E = torch.tensordot(edge, T, ([0], [0]))
    if verbosity > 0: print("E=edgeT " + str(E.size()))

    # TODO - more efficient contraction with uncontracted-double-layer on-site tensor
    #        Possibly reshape indices 1,2 of E, which are to be contracted with
    #        on-site tensor and contract bra,ket in two steps instead of creating
    #        double layer tensor
    #    /
    # --A--
    #  /|s
    #   X
    # s'|/
    # --A--
    #  /
    #
    # where X is Id or op
    a = next(iter(state.sites.values()))
    dims_a = a.size()
    X = torch.eye(dims_a[0], dtype=a.dtype,
                  device=a.device)[:, :, None] if op is None else op_l
    A = torch.einsum('mefgh,mnl,nabcd->eafbgchdl', a, X, a).contiguous()\
        .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2, -1)

    #   ---------T--2->1
    #  |         3 4
    #  |         0/
    # edge--0 1--A--3
    #  |         2
    #   ----1->0
    E = torch.tensordot(E, A, ([0, 3], [1, 0]))
    if verbosity > 0: print("E=EA " + str(E.size()))

    #   -------T--1->0
    #  |       | 4->2
    #  |       |/
    # edge-----A--3->1
    #  |       2
    #  |       2
    #   --0 0--T--1->3
    E = torch.tensordot(E, T, ([0, 2], [0, 2]))
    if verbosity > 0: print("E=ET " + str(E.size()))

    #   ----0 0----T--1->3
    #  |----2->1   2->4
    # edge--1->0
    #  |
    #   ----3->2
    E = torch.tensordot(E, T, ([0], [0]))
    if verbosity > 0: print("E=ET " + str(E.size()))

    # TODO - more efficient contraction with uncontracted-double-layer on-site tensor
    #        Possibly reshape indices 1,2 of E, which are to be contracted with
    #        on-site tensor and contract bra,ket in two steps instead of creating
    #        double layer tensor
    #    /
    # --A--
    #  /|s
    #   X
    # s'|/
    # --A--
    #  /
    #
    # where X is Id or op
    X = torch.eye(dims_a[0], dtype=a.dtype,
                  device=a.device)[:, :, None] if op is None else op_r
    A = torch.einsum('mefgh,mnl,nabcd->eafbgchdl', a, X, a).contiguous()\
        .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2, -1)

    #   ---------T--3->1
    #  |         4
    #  |----1 4-\0
    # edge--0 1--A--3
    #  |         2
    #   ----2->0
    E = torch.tensordot(E, A, ([0, 1, 4], [1, 4, 0]))
    if verbosity > 0: print("E=EA " + str(E.size()))

    #   -------T--1->0
    #  |       |
    #  |       |
    # edge-----A--3->1
    #  |       2
    #  |       2
    #   --0 0--T--1->2
    E = torch.tensordot(E, T, ([0, 2], [0, 2]))
    if verbosity > 0: print("E=ET " + str(E.size()))

    return E
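
A small hedged check of the SVD split used above: a two-site operator with index order (s0, s1, s0', s1') is factored into op_l and op_r, and contracting them over the SVD index recovers the original operator (d is a hypothetical physical dimension):

import torch

d = 2
op = torch.randn(d, d, d, d)
op_mat = op.permute(0, 2, 1, 3).contiguous().reshape(d * d, d * d)
op_l, s, op_r = torch.svd(op_mat)
op_l = op_l.reshape(d, d, s.size(0))
op_r = torch.einsum('i,ij->ij', s, op_r.t()).reshape(s.size(0), d, d).permute(1, 2, 0).contiguous()
recon = torch.einsum('ijk,lmk->iljm', op_l, op_r)   # contract op_l and op_r over the SVD index
assert torch.allclose(recon, op, atol=1e-5)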
Example #8
    def forward(self,
                w,
                r,
                attn_mask=None,
                mems=None,
                head_mask=None,
                output_attentions=False):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head,
                                 self.d_head)  # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head,
                                 self.d_head)  # qlen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head,
                                 self.d_head)  # qlen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head,
                                 self.d_head)  # qlen x n_head x d_head

        # compute attention score
        rw_head_q = w_head_q + self.r_w_bias  # qlen x bsz x n_head x d_head
        AC = torch.einsum("ibnd,jbnd->ijbn",
                          (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + self.r_r_bias
        BD = torch.einsum("ibnd,jnd->ijbn",
                          (rr_head_q, r_head_k))  # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        # compute attention probability
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = attn_mask == 1  # Switch to bool
            if attn_mask.dim() == 2:
                if next(self.parameters()).dtype == torch.float16:
                    attn_score = (attn_score.float().masked_fill(
                        attn_mask[None, :, :, None],
                        -65000).type_as(attn_score))
                else:
                    attn_score = attn_score.float().masked_fill(
                        attn_mask[None, :, :, None], -1e30).type_as(attn_score)
            elif attn_mask.dim() == 3:
                if next(self.parameters()).dtype == torch.float16:
                    attn_score = attn_score.float().masked_fill(
                        attn_mask[:, :, :, None], -65000).type_as(attn_score)
                else:
                    attn_score = attn_score.float().masked_fill(
                        attn_mask[:, :, :, None], -1e30).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # compute attention vector
        attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(attn_vec.size(0),
                                              attn_vec.size(1),
                                              self.n_head * self.d_head)

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            # residual connection
            outputs = [w + attn_out]
        else:
            # residual connection + layer normalization
            outputs = [self.layer_norm(w + attn_out)]

        if output_attentions:
            outputs.append(attn_prob)

        return outputs
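
A hedged shape check for the content-based score term AC above, contracting query and key over the head dimension d (sizes are hypothetical):

import torch

qlen, klen, bsz, n_head, d_head = 5, 6, 2, 3, 4
rw_head_q = torch.randn(qlen, bsz, n_head, d_head)
w_head_k = torch.randn(klen, bsz, n_head, d_head)
AC = torch.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
assert AC.shape == (qlen, klen, bsz, n_head)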
Example #9
File: torch.py  Project: zimonitrome/einops
 def forward(self, input):
     result = torch.einsum(self.einsum_pattern, input, self.weight)
     if self.bias is not None:
         result += self.bias
     return result
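
A hand-rolled analogue of the forward above, with a hypothetical einsum pattern and weight shape standing in for the layer's configured ones:

import torch

pattern = 'bi,io->bo'            # a plain linear layer expressed as an einsum
weight = torch.randn(16, 8)
bias = torch.randn(8)
x = torch.randn(4, 16)
result = torch.einsum(pattern, x, weight)
result = result + bias
assert result.shape == (4, 8)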
Example #10
def apply_TM_1sO(state, env, edge, op=None, verbosity=0):
    r"""
    :param state: underlying 1-site C4v symmetric wavefunction
    :param env: C4v symmetric environment corresponding to ``state``
    :param edge: tensor of dimensions :math:`\chi \times D^2 \times \chi`
    :param op: operator to be inserted into transfer matrix
    :param verbosity: logging verbosity
    :type state: IPEPS_C4V
    :type env: ENV_C4V
    :type edge: torch.tensor
    :type op: torch.tensor
    :type verbosity: int
    :return: ``edge`` with a single instance of the transfer matrix applied.
             The resulting tensor has an identical index structure to the
             original ``edge`` 
    :rtype: torch.tensor
    
    Applies a single instance of the "transfer matrix" to the ``edge`` tensor  
    by contracting the following network::

         -----T----------
        |     |     
       edge--(a^+ op a)--
        |     |     
         -----T----------

    where the physical indices `s` and `s'` of the on-site tensor :math:`a`
    and its Hermitian conjugate :math:`a^\dagger` are contracted with the
    identity :math:`\delta_{s,s'}` or with ``op`` (if supplied).
    """
    # TODO stronger verification
    if op is not None:
        assert (len(op.size()) == 2)

    T = env.T[env.keyT]
    # Assume index structure of ``edge`` tensor to be as follows
    #
    #       -- 0
    # edge |-- 1
    #       -- 2
    #
    #   --0 0--T--1->2
    #  |       2->3
    # edge--1->0
    #  |
    #   --2->1
    E = torch.tensordot(edge, T, ([0], [0]))
    if verbosity > 0: print("E=edgeT " + str(E.size()))

    # TODO - more efficient contraction with uncontracted-double-layer on-site tensor
    #        Possibly reshape indices 1,2 of E, which are to be contracted with
    #        on-site tensor and contract bra,ket in two steps instead of creating
    #        double layer tensor
    #    /
    # --A--
    #  /|s
    #   X
    # s'|/
    # --A--
    #  /
    #
    # where X is Id or op
    a = next(iter(state.sites.values()))
    dims_a = a.size()
    X = torch.eye(dims_a[0], dtype=a.dtype,
                  device=a.device) if op is None else op
    A = torch.einsum('mefgh,mn,nabcd->eafbgchd', a, X, a).contiguous()\
        .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2)

    #   ---------T--2->1
    #  |         3
    #  |         0
    # edge--0 1--A--3
    #  |         2
    #   ----1->0
    E = torch.tensordot(E, A, ([0, 3], [1, 0]))
    if verbosity > 0: print("E=EA " + str(E.size()))

    #   -------T--1->0
    #  |       |
    #  |       |
    # edge-----A--3->1
    #  |       2
    #  |       2
    #   --0 0--T--1->2
    E = torch.tensordot(E, T, ([0, 2], [0, 2]))
    if verbosity > 0: print("E=ET " + str(E.size()))

    return E
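
A shape-check sketch of the double-layer construction used above, with a hypothetical physical dimension d and bond dimension D:

import torch

d, D = 2, 3
a = torch.randn(d, D, D, D, D)          # on-site tensor: physical index first, then four auxiliary bond indices
X = torch.eye(d)                        # identity insertion, i.e. op is None
A = torch.einsum('mefgh,mn,nabcd->eafbgchd', a, X, a).contiguous()\
    .view(D**2, D**2, D**2, D**2)
assert A.shape == (D**2, D**2, D**2, D**2)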
Example #11
    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(
            incremental_state)
        if self.cross_self_attention and not (
                incremental_state is not None and _self_attn_input_buffer
                is not None and "prev_key" in _self_attn_input_buffer):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat((x.new_zeros(
                    x.size(0), encoder_out.size(0)), self_attn_mask),
                                           dim=1)
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0))
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1)
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        if self.c_attn is not None:
            tgt_len, bsz = x.size(0), x.size(1)
            x = x.view(tgt_len, bsz, self.nh, self.head_dim)
            x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
            x = x.reshape(tgt_len, bsz, self.embed_dim)
        if self.attn_ln is not None:
            x = self.attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn
                or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [
                    saved_state["prev_key"], saved_state["prev_value"]
                ]
            return x, attn, self_attn_state
        return x, attn, None
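
A hedged check for the head-scaling einsum 'tbhd,h->tbhd' used with self.c_attn above (hypothetical sizes):

import torch

tgt_len, bsz, nh, head_dim = 5, 2, 4, 8
x = torch.randn(tgt_len, bsz, nh, head_dim)
c_attn = torch.randn(nh)
scaled = torch.einsum('tbhd,h->tbhd', x, c_attn)     # scale each head independently
assert torch.allclose(scaled, x * c_attn.view(1, 1, nh, 1), atol=1e-6)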
Example #12
def _get_predicts(predicts, coefficients):
    return torch.einsum("ij,j->ij", (predicts, coefficients))
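
The 'ij,j->ij' pattern simply rescales each prediction column by its coefficient; a small hedged check:

import torch

predicts = torch.randn(10, 3)
coefficients = torch.tensor([1.0, 0.5, 2.0])
out = torch.einsum('ij,j->ij', predicts, coefficients)
assert torch.allclose(out, predicts * coefficients, atol=1e-6)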
Example #13
    def gradient(self, *xs, y=None, v=None, ctx=None):
        """Computes the vector--Jacobian product, that is, the gradient of the
        loss function with respect to the problem parameters. The returned
        gradient is a tuple of batched Torch tensors. Can be overridden by the
        derived class to provide a more efficient implementation.

        Arguments:
            xs: ((b, ...), ...) tuple of Torch tensors,
                tuple of batches of input tensors

            y: (b, ...) Torch tensor or None,
                batch of minima of the objective function
            
            v: (b, ...) Torch tensor or None,
                batch of gradients of the loss function with respect to the
                problem output J_Y(x,y)

            ctx: dictionary of contextual information used for computing the
            gradient

        Return Values:
            gradients: ((b, ...), ...) tuple of Torch tensors or Nones,
                batch of gradients of the loss function with respect to the
                problem parameters;
                strictly, returns the vector--Jacobian products J_Y(x,y) * y'(x)
        """

        # Compute optimal value if have not already done so:
        if y is None:
            y, ctx = torch.no_grad()(self.solve)(*xs)
            y.requires_grad = True
        # Set incoming gradient v = J_Y(x,y) to one if not specified:
        if v is None:
            v = torch.ones_like(y)

        b = y.size(0)
        m = y.view(b, -1).size(-1)

        # Get constraint parameters and form batch:
        A, d = self.linear_constraint_parameters(y)
        A = self._expand_as_batch(A, b)
        d = self._expand_as_batch(d, b)

        # Check linear equality constraints are satisfied:
        h = torch.einsum('bpm,bm->bp', (A, y)) - d
        if not self._check_equality_constraints(h):
            warnings.warn("Constraints not satisfied {}".format(
                h.detach().squeeze().cpu().numpy()))

        # Compute relevant derivatives with autograd:
        with torch.enable_grad():
            # Split each input x into a tuple of n tensors of size bx1:
            # Required since gradients can only be computed wrt individual
            # tensors, not slices of a tensor. See:
            # https://discuss.pytorch.org/t/how-to-calculate-gradients-wrt-one-of-inputs/24407
            xs_split, xs_sizes = self._split_inputs(xs)
            xs = self._cat_inputs(xs_split, xs_sizes)

            # Evaluate objective function at (xs,y):
            f = self.objective(*xs, y)  # b

        # Compute partial derivative of f wrt y at (xs,y):
        grad_outputs = torch.ones_like(f)  # b
        fY = grad(f, y, grad_outputs=grad_outputs,
                  create_graph=True)[0].view(b, -1)  # bxm
        if not fY.requires_grad:  # if fY is independent of y
            fY.requires_grad = True

        # Compute second-order partial derivative of f wrt y at (xs,y):
        fYY = self._batch_jacobian(fY, y)
        assert fYY is not None

        # Compute 2nd-order partial derivative of h wrt y at (xs,y) and form H:
        H = fYY.detach()

        # Solve u = -H^-1 v (bxm) and t = H^-1 A^T (bxmxp):
        H = 0.5 * (H + H.transpose(1, 2))  # Ensure that H is symmetric
        v = v.view(b, -1, 1)  # bxmx1
        u, t = self._solve_linear_system(H, (-1.0 * v, A.transpose(-2, -1)))
        u = u.squeeze(-1)  # bxm

        # ToDo: check for NaN values in u and t

        # Solve s = (A H^-1 A^T)^-1 A H^-1 v = -(A t)^-1 A u:
        s = self._solve_linear_system(torch.einsum('bpm,bmq->bpq', (A, t)),
                                      torch.einsum('bpm,bm->bp',
                                                   (A, -1.0 * u)))  # bxpx1
        s = s.squeeze(-1)  # bxp

        # ToDo: check for NaN values in s

        # Compute u + ts = -H^-1 v + H^-1 A^T (A H^-1 A^T)^-1 A H^-1 v:
        uts = u + torch.einsum('bmp,bp->bm', (t, s))  # bxm

        # Compute bi^T (u + ts) for all i:
        gradients = []
        for x_split, x_size in zip(xs_split,
                                   xs_sizes):  # Loop over input tuple
            if isinstance(x_split[0],
                          torch.Tensor) and x_split[0].requires_grad:
                n = len(x_split)
                gradient = x_split[0].new_zeros(b, n)  # bxn
                for i in range(n):
                    # 2nd-order partial derivative of f wrt y and xi at (xs,y):
                    fXiY = self._batch_jacobian(fY, x_split[i])  # bxmx1
                    bi = fXiY.detach().squeeze(-1) if (fXiY is not None) else (
                        torch.zeros_like(fY))  # Shares storage with fXiY
                    gradient[:, i] = torch.einsum('bm,bm->b', (bi, uts))
                # Reshape gradient to size(x):
                gradients.append(gradient.view(x_size))
            else:
                gradients.append(None)
        return tuple(gradients)
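
A hedged check of the 'bpm,bm->bp' constraint contraction above, which is a batched matrix-vector product A y (sizes are hypothetical):

import torch

b, p, m = 4, 2, 6
A = torch.randn(b, p, m)
y = torch.randn(b, m)
d = torch.randn(b, p)
h = torch.einsum('bpm,bm->bp', A, y) - d
assert torch.allclose(h, torch.bmm(A, y.unsqueeze(-1)).squeeze(-1) - d, atol=1e-5)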
Example #14
    def gradient(self, *xs, y=None, v=None, ctx=None):
        """Computes the vector--Jacobian product, that is, the gradient of the
        loss function with respect to the problem parameters. The returned
        gradient is a tuple of batched Torch tensors. Can be overridden by the
        derived class to provide a more efficient implementation.

        Arguments:
            xs: ((b, ...), ...) tuple of Torch tensors,
                tuple of batches of input tensors

            y: (b, ...) Torch tensor or None,
                batch of minima of the objective function
            
            v: (b, ...) Torch tensor or None,
                batch of gradients of the loss function with respect to the
                problem output J_Y(x,y)

            ctx: dictionary of contextual information used for computing the
                 gradient

        Return Values:
            gradients: ((b, ...), ...) tuple of Torch tensors or Nones,
                batch of gradients of the loss function with respect to the
                problem parameters;
                strictly, returns the vector--Jacobian products J_Y(x,y) * y'(x)
        """
        # Compute optimal value if have not already done so:
        if y is None:
            y, ctx = torch.no_grad()(self.solve)(*xs)
            y.requires_grad = True
        # Set incoming gradient v = J_Y(x,y) to one if not specified:
        if v is None:
            v = torch.ones_like(y)

        # Compute relevant derivatives with autograd:
        b = y.size(0)
        m = y.view(b, -1).size(-1)
        with torch.enable_grad():
            # Split each input x into a tuple of n tensors of size bx1:
            # Required since gradients can only be computed wrt individual
            # tensors, not slices of a tensor. See:
            # https://discuss.pytorch.org/t/how-to-calculate-gradients-wrt-one-of-inputs/24407
            xs_split, xs_sizes = self._split_inputs(xs)

            # Evaluate objective function at (xs,y):
            f = self.objective(*self._cat_inputs(xs_split, xs_sizes), y)  # b

        # Compute partial derivative of f wrt y at (xs,y):
        fY = grad(f, y, grad_outputs=torch.ones_like(f),
                  create_graph=True)[0].view(b, -1)  # bxm

        if not self._check_optimality_cond(fY):
            warnings.warn(
                "Non-zero objective function gradient {} at y".format(
                    fY.detach().squeeze().cpu().numpy()))

        # Compute second-order partial derivative of f wrt y at (xs,y):
        fYY = self._batch_jacobian(fY, y)

        # Solve u = -H^-1 v:
        H = fYY.detach()
        H = 0.5 * (H + H.transpose(1, 2))  # Ensure that H is symmetric
        v = v.view(b, -1, 1)
        u = self._solve_linear_system(H, -1.0 * v)  # bxmx1
        u = u.squeeze(-1)  # bxm

        # ToDo: check for NaN values in u

        # Compute -b_i^T H^-1 v (== b_i^T u) for all i:
        gradients = []
        for x_split, x_size in zip(xs_split,
                                   xs_sizes):  # Loop over input tuple
            if isinstance(x_split[0],
                          torch.Tensor) and x_split[0].requires_grad:
                n = len(x_split)
                gradient = x_split[0].new_zeros(b, n)  # bxn
                # 2nd-order partial derivative of f wrt y and x at (xs,y):
                fXiY = torch.zeros_like(fY)  # bxm
                grad_outputs = torch.ones_like(fY)
                for i in range(n):
                    with torch.enable_grad():
                        fXiY = grad(fY,
                                    x_split[i],
                                    grad_outputs=grad_outputs,
                                    create_graph=True)[0]  # bxm
                    bi = fXiY.detach()
                    gradient[:, i] = torch.einsum('bm,bm->b', (bi, u))
                # Reshape gradient to size(x):
                gradients.append(gradient.view(x_size))
            else:
                gradients.append(None)
        return tuple(gradients)
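
The gradient entries above use 'bm,bm->b', a batched inner product b_i^T u; a small hedged check:

import torch

bsz, m = 4, 6
bi = torch.randn(bsz, m)
u = torch.randn(bsz, m)
g = torch.einsum('bm,bm->b', bi, u)
assert torch.allclose(g, (bi * u).sum(dim=1), atol=1e-6)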
Example #15
    def gradient(self, *xs, y=None, v=None, ctx=None):
        """Computes the vector--Jacobian product, that is, the gradient of the
        loss function with respect to the problem parameters. The returned
        gradient is a tuple of batched Torch tensors. Can be overridden by the
        derived class to provide a more efficient implementation.

        Arguments:
            xs: ((b, ...), ...) tuple of Torch tensors,
                tuple of batches of input tensors

            y: (b, ...) Torch tensor or None,
                batch of minima of the objective function
            
            v: (b, ...) Torch tensor or None,
                batch of gradients of the loss function with respect to the
                problem output J_Y(x,y)

            ctx: dictionary of contextual information used for computing the
            gradient

        Return Values:
            gradients: ((b, ...), ...) tuple of Torch tensors or Nones,
                batch of gradients of the loss function with respect to the
                problem parameters;
                strictly, returns the vector--Jacobian products J_Y(x,y) * y'(x)
        """

        # Compute optimal value if have not already done so:
        if y is None:
            y, ctx = torch.no_grad()(self.solve)(*xs)
            y.requires_grad = True
        # Set incoming gradient v = J_Y(x,y) to one if not specified:
        if v is None:
            v = torch.ones_like(y)

        # Compute relevant derivatives with autograd:
        b = y.size(0)
        m = y.view(b, -1).size(-1)
        with torch.enable_grad():
            # Split each input x into a tuple of n tensors of size bx1:
            # Required since gradients can only be computed wrt individual
            # tensors, not slices of a tensor. See:
            # https://discuss.pytorch.org/t/how-to-calculate-gradients-wrt-one-of-inputs/24407
            xs_split, xs_sizes = self._split_inputs(xs)
            xs = self._cat_inputs(xs_split, xs_sizes)

            # Evaluate constraint function(s) at (xs,y):
            h = self._get_constraint_set(xs, y)  # bxp
            if h is None:  # If None, use unconstrained gradient
                return super().gradient(xs, y=y, v=v, ctx=ctx)

            # Evaluate objective function at (xs,y):
            f = self.objective(*xs, y)  # b

        # Compute partial derivative of f wrt y at (xs,y):
        fY = grad(f, y, grad_outputs=torch.ones_like(f),
                  create_graph=True)[0].view(b, -1)  # bxm
        if not fY.requires_grad:  # if fY is independent of y
            fY.requires_grad = True

        # Compute partial derivative of h wrt y at (xs,y):
        hY = self._batch_jacobian(h, y, create_graph=True)
        if not hY.requires_grad:  # if hY is independent of y
            hY.requires_grad = True

        # Compute nu (b, p):
        nu = self._get_nu(fY, hY) if (ctx is None
                                      or 'nu' not in ctx) else ctx['nu']
        nu = nu.unsqueeze(-1) if len(
            nu.size()) == 1 else nu  # Force p dimension

        if not self._check_optimality_cond(fY, hY, nu):
            warnings.warn(
                "Non-zero Lagrangian gradient {} at y. fY: {}, hY: {}, nu: {}".
                format(
                    (fY -
                     torch.einsum('ab,abc->ac',
                                  (nu, hY))).detach().squeeze().cpu().numpy(),
                    fY.detach().squeeze().cpu().numpy(),
                    hY.detach().squeeze().cpu().numpy(),
                    nu.detach().squeeze().cpu().numpy()))

        # Compute second-order partial derivative of f wrt y at (xs,y):
        fYY = self._batch_jacobian(fY, y)

        # Compute 2nd-order partial derivative of h wrt y at (xs,y) and form H:
        H = fYY.detach() if fYY is not None else 0.0  # Shares storage with fYY
        p = h.size(-1)
        for i in range(p):
            with torch.enable_grad():  # Needed when looping over output
                hiYY = self._batch_jacobian(hY[:, i, :], y, create_graph=False)
            if hiYY is not None:
                H -= torch.einsum('b,bmn->bmn', (nu[:, i], hiYY))
        assert isinstance(H, torch.Tensor)

        # Solve u = -H^-1 v (bxm) and t = H^-1 A^T (bxmxp):
        H = 0.5 * (H + H.transpose(1, 2))  # Ensure that H is symmetric
        A = hY.detach()  # Shares storage with hY
        v = v.view(b, -1, 1)  # bxmx1
        u, t = self._solve_linear_system(H, (-1.0 * v, A.transpose(-2, -1)))
        u = u.squeeze(-1)  # bxm

        # ToDo: check for NaN values in u and t

        # Solve s = (A H^-1 A^T)^-1 A H^-1 v = -(A t)^-1 A u:
        s = self._solve_linear_system(torch.einsum('bpm,bmq->bpq', (A, t)),
                                      torch.einsum('bpm,bm->bp',
                                                   (A, -1.0 * u)))  # bxpx1
        s = s.squeeze(-1)  # bxp

        # ToDo: check for NaN values in s

        # Compute u + ts:
        uts = u + torch.einsum('bmp,bp->bm', (t, s))  # bxm

        # Compute bi^T (u + ts) - ci^T s for all i:
        gradients = []
        for x_split, x_size in zip(xs_split,
                                   xs_sizes):  # Loop over input tuple
            if isinstance(x_split[0],
                          torch.Tensor) and x_split[0].requires_grad:
                n = len(x_split)
                gradient = x_split[0].new_zeros(b, n)  # bxn
                for i in range(n):
                    # 2nd-order partial derivative of f wrt y and xi at (xs,y):
                    fXiY = self._batch_jacobian(fY, x_split[i])  # bxmx1
                    bi = fXiY.detach().squeeze(-1) if (fXiY is not None) else (
                        torch.zeros_like(fY))  # Shares storage with fXiY
                    for j in range(p):
                        # 2nd-order partial derivative of hj wrt y and xi at (xs,y):
                        with torch.enable_grad():
                            hjXiY = self._batch_jacobian(
                                hY[:, j, :], x_split[i])  # bxmx1
                        if hjXiY is not None:
                            bi -= torch.einsum(
                                'b,bm->bm',
                                (nu[:, j], hjXiY.detach().squeeze(-1)))  # bxm
                    # Compute partial derivative of h wrt xi at (xs,y):
                    hXi = self._batch_jacobian(h, x_split[i])  # bxpx1
                    if hXi is None:
                        gradient[:, i] = torch.einsum('bm,bm->b', (bi, uts))
                    else:
                        ci = hXi.detach().squeeze(
                            -1)  # Shares storage with hXi
                        gradient[:, i] = (torch.einsum('bm,bm->b', (bi, uts)) -
                                          torch.einsum('bp,bp->b', (ci, s)))
                # Reshape gradient to size(x):
                gradients.append(gradient.view(x_size))
            else:
                gradients.append(None)
        return tuple(gradients)
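
A hedged check of the 'ab,abc->ac' term in the optimality warning above, i.e. the Lagrangian stationarity residual fY - nu^T hY (sizes are hypothetical):

import torch

b, p, m = 4, 2, 6
fY = torch.randn(b, m)
hY = torch.randn(b, p, m)
nu = torch.randn(b, p)
resid = fY - torch.einsum('ab,abc->ac', nu, hY)
assert torch.allclose(resid, fY - torch.bmm(nu.unsqueeze(1), hY).squeeze(1), atol=1e-5)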
Example #16
def get_batch_top3score(target, output):
    weights = torch.as_tensor([1, 1 / 2, 1 / 3]).cuda()
    _, pred = torch.topk(output, k=3, dim=1)
    target = target.reshape(target.shape[0], 1)
    target = target.repeat(1, 3)
    return torch.sum(torch.einsum("ij,j->i", (pred == target).float(), weights)).item()
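
A tiny hedged check of the weighted scoring einsum 'ij,j->i' above, on CPU and with hand-made hit indicators:

import torch

weights = torch.tensor([1.0, 1 / 2, 1 / 3])
hits = torch.tensor([[1.0, 0.0, 0.0],    # target found at rank 1
                     [0.0, 0.0, 1.0]])   # target found at rank 3
scores = torch.einsum('ij,j->i', hits, weights)
assert torch.allclose(scores, torch.tensor([1.0, 1 / 3]), atol=1e-6)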
Example #17
    def forward(self, z1ss, pos_emb, u1ss, mems=None):
        # Note: In this context, qlen means the length of the (small) subsequence; and mlen describes
        #       the length of the padding. Their sum is klen.
        bsz, d_model, qlen = z1ss.size()
        r_w_bias, r_r_bias = self.r_w_bias, self.r_r_bias
        n_head, d_head = self.n_head, self.d_head
        rlen = pos_emb.size(2)

        if mems is None:
            mems = torch.tensor([]).view(0, 0, 0)
        mlen = mems.size(2)
        cat = torch.cat([mems, z1ss], dim=-1)

        if self.pre_lnorm:
            cat = F.layer_norm(cat.transpose(1, 2),
                               (d_model, )).transpose(1, 2)
        w_heads = self.qkv_net(cat)  # (N x 3*d_model x seq_len)
        r_head_k = self.r_net(pos_emb)

        # Input injection
        w_heads += u1ss
        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=1)
        w_head_q = w_head_q[:, :, -qlen:]

        klen = w_head_k.size(2)

        w_head_q = w_head_q.view(bsz, n_head, d_head,
                                 qlen)  # bsz x n_head x d_head x qlen
        w_head_k = w_head_k.view(bsz, n_head, d_head,
                                 klen)  # bsz x n_head x d_head x klen
        w_head_v = w_head_v.view(bsz, n_head, d_head,
                                 klen)  # bsz x n_head x d_head x klen

        r_head_k = r_head_k.view(n_head, d_head,
                                 rlen)  # n_head x d_head x rlen

        # Compute attention score
        rw_head_q = w_head_q + r_w_bias[:, :,
                                        None]  # bsz x n_head x d_head x qlen
        AC = torch.einsum('bndi,bndj->bnij', rw_head_q, w_head_k)
        rr_head_q = w_head_q + r_r_bias[:, :, None]
        BD = torch.einsum('bndi,ndj->bnij', rr_head_q, r_head_k)
        BD = self._rel_shift(BD)  # for relative positional embedding

        attn_score = AC + BD  # bsz x n_head x qlen x klen
        attn_score.mul_(self.scale)

        # Compute attention probability
        # We apply a local mask, with local horizon size of mlen
        local_size = self.local_size or 1000
        attn_mask = (torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen) >
                     0)[None]
        attn_mask += (torch.tril(torch.ones(qlen, klen),
                                 diagonal=mlen - local_size) > 0)[None]
        if attn_mask is not None and attn_mask.any().item():
            attn_score = attn_score.float().masked_fill(
                attn_mask[None], -float('inf')).type_as(attn_score)

        attn_prob = F.softmax(attn_score, dim=-1)  # bsz x n_head x qlen x klen

        # Compute attention vector
        attn_vec = torch.einsum('bnij,bndj->bndi', (attn_prob, w_head_v))

        # [bsz x d x qlen]
        attn_vec = attn_vec.contiguous().view(bsz, n_head * d_head,
                                              attn_vec.size(-1))

        # Linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        # Residual connection + layer normalization (if applicable)
        if self.pre_lnorm:
            out = attn_out + z1ss
        else:
            out = F.layer_norm((attn_out + z1ss).transpose(1, 2),
                               (d_model, )).transpose(1, 2)
        return out
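
A hedged check for the score contraction 'bndi,bndj->bnij' above, which sums over the head dimension d (hypothetical sizes):

import torch

bsz, n_head, d_head, qlen, klen = 2, 3, 4, 5, 6
q = torch.randn(bsz, n_head, d_head, qlen)
k = torch.randn(bsz, n_head, d_head, klen)
AC = torch.einsum('bndi,bndj->bnij', q, k)
assert torch.allclose(AC, q.transpose(-2, -1) @ k, atol=1e-5)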
Example #18
 def last_layer(self, z):
     z = torch.einsum("ij,mnj->imn", z, self.W)
     return z
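
A shape-only sketch of the 'ij,mnj->imn' projection above (all sizes hypothetical):

import torch

z = torch.randn(4, 8)
W = torch.randn(3, 5, 8)
out = torch.einsum('ij,mnj->imn', z, W)   # contract the shared last axis j
assert out.shape == (4, 3, 5)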
Example #19
def train(epoch):
    torch.set_printoptions(precision=16)
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    step_st_time = time.time()
    epoch_time = 0
    print('\nKFAC/KBFGS damping: %f' % damping)
    print('\nNGD damping: %f' % (damping))

    # 
    desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (tag, lr_scheduler.get_last_lr()[0], 0, 0, correct, total))

    writer.add_scalar('train/lr', lr_scheduler.get_last_lr()[0], epoch)

    prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True)
    for batch_idx, (inputs, targets) in prog_bar:

        if optim_name in ['kfac', 'skfac', 'ekfac', 'sgd', 'adam']:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            if optim_name in ['kfac', 'skfac', 'ekfac'] and optimizer.steps % optimizer.TCov == 0:
                # compute true fisher
                optimizer.acc_stats = True
                with torch.no_grad():
                    sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),1).squeeze().to(args.device)
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)
                optimizer.acc_stats = False
                optimizer.zero_grad()  # clear the gradient for computing true-fisher.
            loss.backward()
            optimizer.step()
        elif optim_name in ['kbfgs', 'kbfgsl', 'kbfgsl_2loop', 'kbfgsl_mem_eff']:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net.forward(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            # do another forward-backward pass over batch inside step()
            def closure():
                return inputs, targets, criterion, False # is_autoencoder = False
            optimizer.step(closure)
        elif optim_name == 'exact_ngd':
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            # update Fisher inverse
            if batch_idx % args.freq == 0:
              # compute true fisher
              with torch.no_grad():
                sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),1).squeeze().to(args.device)
              # use backpack extension to compute individual gradient in a batch
              batch_grad = []
              with backpack(BatchGrad()):
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)

              for name, param in net.named_parameters():
                if hasattr(param, "grad_batch"):
                  batch_grad.append(args.batch_size * param.grad_batch.reshape(args.batch_size, -1))
                else:
                  raise NotImplementedError

              J = torch.cat(batch_grad, 1)
              fisher = torch.matmul(J.t(), J) / args.batch_size
              inv = torch.linalg.inv(fisher + damping * torch.eye(fisher.size(0)).to(fisher.device))
              # clean the gradient to compute the true fisher
              optimizer.zero_grad()

            loss.backward()
            # compute the step direction p = F^-1 @ g
            grad_list = []
            for name, param in net.named_parameters():
              grad_list.append(param.grad.data.reshape(-1, 1))
            g = torch.cat(grad_list, 0)
            p = torch.matmul(inv, g)

            start = 0
            for name, param in net.named_parameters():
              end = start + param.data.reshape(-1, 1).size(0)
              param.grad.copy_(p[start:end].reshape(param.grad.data.shape))
              start = end

            optimizer.step()

        ### new optimizer test
        elif optim_name in ['kngd'] :
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            if  optimizer.steps % optimizer.freq == 0:
                # compute true fisher
                optimizer.acc_stats = True
                with torch.no_grad():
                    sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device)
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)
                optimizer.acc_stats = False
                optimizer.zero_grad()  # clear the gradient for computing true-fisher.
                if args.partial_backprop == 'true':
                  idx = sampled_y != targets  # mis-predicted samples only
                  loss = criterion(outputs[idx,:], targets[idx])
                  # print('extra:', idx.sum().item())
            loss.backward()
            optimizer.step()

        elif optim_name == 'ngd':
            if batch_idx % args.freq == 0:
                store_io_(True)
                inputs, targets = inputs.to(args.device), targets.to(args.device)
                optimizer.zero_grad()
                # net.set_require_grad(True)

                outputs = net(inputs)
                damp = damping
                loss = criterion(outputs, targets)
                loss.backward(retain_graph=True)

                # storing original gradient for later use
                grad_org = []
                # grad_dict = {}
                for name, param in net.named_parameters():
                    grad_org.append(param.grad.reshape(1, -1))
                #     grad_dict[name] = param.grad.clone()
                grad_org = torch.cat(grad_org, 1)

                ###### now we have to compute the true fisher
                with torch.no_grad():
                # gg = torch.nn.functional.softmax(outputs, dim=1)
                    sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device)
                
                if args.trial == 'true':
                    update_list, loss = optimal_JJT_v2(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient, super_opt=args.super_opt)
                else:
                    update_list, loss = optimal_JJT(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient)

                # optimizer.zero_grad()
                # update_list, loss = optimal_JJT_fused(outputs, sampled_y, args.batch_size, damping=damp)

                optimizer.zero_grad()
   
                # last part of SMW formula
                grad_new = []
                for name, param in net.named_parameters():
                    param.grad.copy_(update_list[name])
                    grad_new.append(param.grad.reshape(1, -1))
                grad_new = torch.cat(grad_new, 1)   
                # grad_new = grad_org
                store_io_(False)
            else:
                inputs, targets = inputs.to(args.device), targets.to(args.device)
                optimizer.zero_grad()
                # net.set_require_grad(True)

                outputs = net(inputs)
                damp = damping
                loss = criterion(outputs, targets)
                loss.backward()

                # storing original gradient for later use
                grad_org = []
                # grad_dict = {}
                for name, param in net.named_parameters():
                    grad_org.append(param.grad.reshape(1, -1))
                #     grad_dict[name] = param.grad.clone()
                grad_org = torch.cat(grad_org, 1)

                ###### now we have to compute the true fisher
                # with torch.no_grad():
                # gg = torch.nn.functional.softmax(outputs, dim=1)
                    # sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device)
                all_modules = net.modules()

                for m in net.modules():
                    if hasattr(m, "NGD_inv"):                    
                        grad = m.weight.grad
                        if isinstance(m, nn.Linear):
                            I = m.I
                            G = m.G
                            n = I.shape[0]
                            NGD_inv = m.NGD_inv
                            grad_prod = einsum("ni,oi->no", (I, grad))
                            grad_prod = einsum("no,no->n", (grad_prod, G))
                            v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                            gv = einsum("n,no->no", (v, G))
                            gv = einsum("no,ni->oi", (gv, I))
                            gv = gv / n
                            update = (grad - gv)/damp
                            m.weight.grad.copy_(update)
                        elif isinstance(m, nn.Conv2d):
                            if hasattr(m, "AX"):

                                if args.low_rank.lower() == 'true':
                                    ###### using low rank structure
                                    U = m.U
                                    S = m.S
                                    V = m.V
                                    NGD_inv = m.NGD_inv
                                    n = NGD_inv.shape[0]

                                    grad_reshape = grad.reshape(grad.shape[0], -1)
                                    grad_prod = V @ grad_reshape.t().reshape(-1, 1)
                                    grad_prod = torch.diag(S) @ grad_prod
                                    grad_prod = U @ grad_prod
                                    
                                    grad_prod = grad_prod.squeeze()
                                    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                    gv = U.t() @ v.unsqueeze(1)
                                    gv = torch.diag(S) @ gv
                                    gv = V.t() @ gv

                                    gv = gv.reshape(grad_reshape.shape[1], grad_reshape.shape[0]).t()
                                    gv = gv.view_as(grad)
                                    gv = gv / n
                                    update = (grad - gv)/damp
                                    m.weight.grad.copy_(update)
                                else:
                                    AX = m.AX
                                    NGD_inv = m.NGD_inv
                                    n = AX.shape[0]

                                    grad_reshape = grad.reshape(grad.shape[0], -1)
                                    grad_prod = einsum("nkm,mk->n", (AX, grad_reshape))
                                    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                    gv = einsum("nkm,n->mk", (AX, v))
                                    gv = gv.view_as(grad)
                                    gv = gv / n
                                    update = (grad - gv)/damp
                                    m.weight.grad.copy_(update)
                            elif hasattr(m, "I"):
                                I = m.I
                                if args.memory_efficient == 'true':
                                    I = unfold_func(m)(I)
                                G = m.G
                                n = I.shape[0]
                                NGD_inv = m.NGD_inv
                                grad_reshape = grad.reshape(grad.shape[0], -1)
                                x1 = einsum("nkl,mk->nml", (I, grad_reshape))
                                grad_prod = einsum("nml,nml->n", (x1, G))
                                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                gv = einsum("n,nml->nml", (v, G))
                                gv = einsum("nml,nkl->mk", (gv, I))
                                gv = gv.view_as(grad)
                                gv = gv / n
                                update = (grad - gv)/damp
                                m.weight.grad.copy_(update)
                        elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                            if args.batchnorm == 'true':
                                dw = m.dw
                                n = dw.shape[0]
                                NGD_inv = m.NGD_inv
                                grad_prod = einsum("ni,i->n", (dw, grad))

                                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                gv = einsum("n,ni->i", (v, dw))
                                
                                gv = gv / n
                                update = (grad - gv)/damp
                                m.weight.grad.copy_(update)
                        
                        

                # last part of SMW formula
                grad_new = []
                for name, param in net.named_parameters():
                    grad_new.append(param.grad.reshape(1, -1))
                grad_new = torch.cat(grad_new, 1)   
                # grad_new = grad_org


            ##### do kl clip
            lr = lr_scheduler.get_last_lr()[0]
            # vg_sum = 0
            # vg_sum += (grad_new * grad_org ).sum()
            # vg_sum = vg_sum * (lr ** 2)
            # nu = min(1.0, math.sqrt(args.kl_clip / vg_sum))
            # for name, param in net.named_parameters():
            #     param.grad.mul_(nu)

            # optimizer.step()
            # manual optimizing:
            with torch.no_grad():
                for name, param in net.named_parameters():
                    d_p = param.grad.data
                    # print('=== step ===')

                    # apply momentum
                    # if args.momentum != 0:
                    #     buf[name].mul_(args.momentum).add_(d_p)
                    #     d_p.copy_(buf[name])

                    # apply weight decay
                    if args.weight_decay != 0:
                        d_p.add_(param.data, alpha=args.weight_decay)

                    lr = lr_scheduler.get_last_lr()[0]
                    param.data.add_(d_p, alpha=-lr)
                    # print('d_p:', d_p.shape)
                    # print(d_p)



        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (tag, lr_scheduler.get_last_lr()[0], train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
        prog_bar.set_description(desc, refresh=True)
        if args.step_info == 'true' and (batch_idx % 50 == 0 or batch_idx == len(prog_bar) - 1):
            step_saved_time = time.time() - step_st_time
            epoch_time += step_saved_time
            test_acc, test_loss = test(epoch)
            TRAIN_INFO['train_acc'].append(float("{:.4f}".format(100. * correct / total)))
            TRAIN_INFO['test_acc'].append(float("{:.4f}".format(test_acc)))
            TRAIN_INFO['train_loss'].append(float("{:.4f}".format(train_loss/(batch_idx + 1))))
            TRAIN_INFO['test_loss'].append(float("{:.4f}".format(test_loss)))
            TRAIN_INFO['total_time'].append(float("{:.4f}".format(step_saved_time)))
            if args.debug_mem == 'true':
                TRAIN_INFO['memory'].append(torch.cuda.memory_reserved())
            step_st_time = time.time()
            net.train()

    writer.add_scalar('train/loss', train_loss/(batch_idx + 1), epoch)
    writer.add_scalar('train/acc', 100. * correct / total, epoch)
    acc = 100. * correct / total
    train_loss = train_loss/(batch_idx + 1)
    if args.step_info == 'true':
        TRAIN_INFO['epoch_time'].append(float("{:.4f}".format(epoch_time)))
    # save diagonal blocks of exact Fisher inverse or its approximations
    if args.save_inv == 'true':
      all_modules = net.modules()

      count = 0
      start, end = 0, 0
      if optim_name == 'ngd':
        for m in all_modules:
          if m.__class__.__name__ == 'Linear':
            with torch.no_grad():
              I = m.I
              G = m.G
              J = torch.einsum('ni,no->nio', I, G)
              J = J.reshape(J.size(0), -1)
              JTDJ = torch.matmul(J.t(), torch.matmul(m.NGD_inv, J)) / args.batch_size

              with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
                np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy())
                count += 1

          elif m.__class__.__name__ == 'Conv2d':
            with torch.no_grad():
              AX = m.AX
              AX = AX.reshape(AX.size(0), -1)
              JTDJ = torch.matmul(AX.t(), torch.matmul(m.NGD_inv, AX)) / args.batch_size
              with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
                np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy())
                count += 1

      elif optim_name == 'exact_ngd':
        for m in all_modules:
          if m.__class__.__name__ in ['Conv2d', 'Linear']:
            with open('exact/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
              end = start + m.weight.data.reshape(1, -1).size(1)
              np.save(f, inv[start:end,start:end].cpu().numpy())
              start = end + m.bias.data.size(0)
              count += 1

      elif optim_name == 'kfac':
        for m in all_modules:
          if m.__class__.__name__ in ['Conv2d', 'Linear']:
            with open('kfac/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
              G = optimizer.m_gg[m]
              A = optimizer.m_aa[m]

              H_g = torch.linalg.inv(G + math.sqrt(damping) * torch.eye(G.size(0)).to(G.device))
              H_a = torch.linalg.inv(A + math.sqrt(damping) * torch.eye(A.size(0)).to(A.device))

              end = m.weight.data.reshape(1, -1).size(1)
              kfac_inv = torch.kron(H_a, H_g)[:end,:end]
              np.save(f, kfac_inv.cpu().numpy())
              count += 1

    return acc, train_loss
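
The per-layer updates above rely on the Sherman-Morrison-Woodbury identity so that only an n x n system (per-sample dimension) is inverted instead of a parameter-sized one. A minimal numeric sketch, with illustrative stand-ins `J`, `g`, and `damp`, checks this against the direct damped solve:

import torch

torch.manual_seed(0)
n, p, damp = 8, 20, 0.1
J = torch.randn(n, p, dtype=torch.float64)   # per-sample gradient factors
g = torch.randn(p, dtype=torch.float64)      # averaged gradient

# SMW route: invert only an n x n matrix
NGD_inv = torch.linalg.inv(J @ J.T / n + damp * torch.eye(n, dtype=torch.float64))
v = NGD_inv @ (J @ g)
update_smw = (g - J.T @ v / n) / damp

# direct route: invert the damped p x p Fisher approximation
update_direct = torch.linalg.inv(J.T @ J / n + damp * torch.eye(p, dtype=torch.float64)) @ g
assert torch.allclose(update_smw, update_direct)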
예제 #20
0
 def forward(self, positions):
     sinusoid_inp = torch.einsum("i,j->ij", positions.float(),
                                 self.inv_freq)
     emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
     return emb[None, :, :]
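
A self-contained sketch of the embedding above; `dim` and the construction of `inv_freq` are assumptions that mirror the usual sinusoidal setup rather than the original module's attributes:

import torch

dim = 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
positions = torch.arange(6)

sinusoid_inp = torch.einsum("i,j->ij", positions.float(), inv_freq)   # (seq_len, dim/2)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)     # (seq_len, dim)
print(emb[None, :, :].shape)  # torch.Size([1, 6, 8])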
예제 #21
0
    def model(self, data):
        '''
        Define the parameters
        '''
        n_ind = data['n_ind']
        n_trt = data['n_trt']
        n_tms = data['n_tms']
        n_mrk = data['n_mrk']

        n_prs = n_trt * n_tms * n_mrk

        plt_ind = pyro.plate('individuals', n_ind, dim=-3)
        plt_trt = pyro.plate('treatments', n_trt, dim=-2)
        plt_tms = pyro.plate('times', n_tms, dim=-1)

        pars = {}
        # covariance factors
        with plt_tms:
            # learning dt time step sizes
            # if k(t1,t2) is independent of time, can instead learn scales and variances for RBF kernels that use data['time_vals']
            pars['dt0'] = pyro.sample('dt0', dist.Normal(0, 1))
            pars['dt1'] = pyro.sample('dt1', dist.Normal(0, 1))

        pars['theta_trt0'] = pyro.sample('theta_trt0',
                                         dist.HalfCauchy(torch.ones(n_trt)))
        pars['theta_mrk0'] = pyro.sample('theta_mrk0',
                                         dist.HalfCauchy(torch.ones(n_mrk)))
        pars['theta_trt1'] = pyro.sample('theta_trt1',
                                         dist.HalfCauchy(torch.ones(n_trt)))
        pars['L_omega_trt1'] = pyro.sample(
            'L_omega_trt1', dist.LKJCorrCholesky(n_trt, torch.ones(1)))
        pars['theta_mrk1'] = pyro.sample('theta_mrk1',
                                         dist.HalfCauchy(torch.ones(n_mrk)))
        pars['L_omega_mrk1'] = pyro.sample(
            'L_omega_mrk1', dist.LKJCorrCholesky(n_mrk, torch.ones(1)))

        times0 = fun.pad(torch.cumsum(pars['dt0'].exp().log1p(), 0), (1, 0),
                         value=0)[:-1].unsqueeze(1)
        times1 = fun.pad(torch.cumsum(pars['dt1'].exp().log1p(), 0), (1, 0),
                         value=0)[:-1].unsqueeze(1)
        cov_t0 = (-torch.cdist(times0, times0)).exp()
        cov_t1 = (-torch.cdist(times1, times1)).exp()

        cov_i0 = pars['theta_trt0'].diag()
        L_Omega_trt = torch.mm(torch.diag(pars['theta_trt1'].sqrt()),
                               pars['L_omega_trt1'])
        cov_i1 = L_Omega_trt.mm(L_Omega_trt.t())

        cov_m0 = pars['theta_mrk0'].diag()
        L_Omega_mrk = torch.mm(torch.diag(pars['theta_mrk1'].sqrt()),
                               pars['L_omega_mrk1'])
        cov_m1 = L_Omega_mrk.mm(L_Omega_mrk.t())

        # kronecker product of the factors
        cov_itm0 = torch.einsum('ij,tu,mn->itmjun',
                                [cov_i0, cov_t0, cov_m0]).view(n_prs, n_prs)
        cov_itm1 = torch.einsum('ij,tu,mn->itmjun',
                                [cov_i1, cov_t1, cov_m1]).view(n_prs, n_prs)

        # global and individual level params of each marker, treatment, and time point
        pars['glb'] = pyro.sample(
            'glb', dist.MultivariateNormal(torch.zeros(n_prs), cov_itm0))
        with plt_ind:
            pars['ind'] = pyro.sample(
                'ind', dist.MultivariateNormal(torch.zeros(n_prs), cov_itm1))

        # observation noise, time series bias and scale
        pars['noise_scale'] = pyro.sample('noise_scale',
                                          dist.HalfCauchy(torch.ones(n_mrk)))
        pars['t0_scale'] = pyro.sample('t0_scale',
                                       dist.HalfCauchy(torch.ones(n_mrk)))
        with plt_ind:
            pars['t0'] = pyro.sample(
                't0',
                dist.MultivariateNormal(torch.zeros(n_mrk),
                                        pars['t0_scale'].diag()))
            with plt_trt, plt_tms:
                pars['noise'] = pyro.sample(
                    'noise',
                    dist.MultivariateNormal(torch.zeros(n_mrk),
                                            pars['noise_scale'].diag()))

        # likelihood of the data
        distr = self.get_distr(data, pars)
        pyro.sample('obs', distr, obs=data['Y'])
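
The three-factor einsum above, followed by the view, is exactly a Kronecker product of the covariance factors; a quick check with illustrative shapes:

import torch

A, B, C = torch.randn(2, 2), torch.randn(3, 3), torch.randn(4, 4)
n_prs = 2 * 3 * 4
via_einsum = torch.einsum('ij,tu,mn->itmjun', A, B, C).reshape(n_prs, n_prs)
via_kron = torch.kron(A, torch.kron(B, C))
assert torch.allclose(via_einsum, via_kron)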
예제 #22
0
 def conditional(self, input, given):
     return torch.einsum('ik,lk->il', input, self.weight[given,:,:]) + self.bias[given,:].unsqueeze(0) 
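
A minimal usage sketch of the conditional layer above, assuming a weight bank of shape (n_conditions, out_dim, in_dim) and a bias of shape (n_conditions, out_dim); `given` selects one linear map per call:

import torch

n_conditions, in_dim, out_dim = 3, 4, 2
weight = torch.randn(n_conditions, out_dim, in_dim)
bias = torch.randn(n_conditions, out_dim)

x = torch.randn(5, in_dim)
given = 1
y = torch.einsum('ik,lk->il', x, weight[given, :, :]) + bias[given, :].unsqueeze(0)
print(y.shape)  # torch.Size([5, 2])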
예제 #23
0
    def backward(ctx, grad_kernel):
        F, Y, R, norm_coef = ctx.saved_tensors
        batch, a, b = ctx.batch, ctx.a, ctx.b

        grad_F = grad_Y = grad_R = None

        if ctx.needs_input_grad[0]:
            grad_F = grad_kernel.new_zeros(
                *ctx.F_shape)  # [batch, b, l_in * mul_in * m_in]
        if ctx.needs_input_grad[1]:
            grad_Y = grad_kernel.new_zeros(
                *ctx.Y_shape)  # [l_filter * m_filter, batch, a, b]
        if ctx.needs_input_grad[2]:
            grad_R = grad_kernel.new_zeros(
                *ctx.R_shape
            )  # [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]

        begin_R = 0

        begin_out = 0
        for i, (mul_out, l_out, p_out) in enumerate(ctx.Rs_out):
            s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
            begin_out += mul_out * (2 * l_out + 1)

            begin_in = 0
            for j, (mul_in, l_in, p_in) in enumerate(ctx.Rs_in):
                s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
                begin_in += mul_in * (2 * l_in + 1)

                l_filters = ctx.get_l_filters(l_in, p_in, l_out, p_out)
                if not l_filters:
                    continue

                n = mul_out * mul_in * len(l_filters)
                if (grad_Y is not None) or (grad_F is not None):
                    sub_R = R[:, :, :, begin_R:begin_R + n].contiguous().view(
                        batch, a, b, mul_out, mul_in,
                        -1)  # [batch, a, b, mul_out, mul_in, l_filter]
                if grad_R is not None:
                    sub_grad_R = grad_R[:, :, :, begin_R:begin_R + n].contiguous(
                    ).view(batch, a, b, mul_out, mul_in,
                           -1)  # [batch, a, b, mul_out, mul_in, l_filter]

                if grad_F is not None:
                    sub_grad_F = grad_F[:, :, s_in].contiguous().view(
                        batch, b, mul_in,
                        2 * l_in + 1)  # [batch, b, mul_in, 2 * l_in + 1]
                if (grad_Y is not None) or (grad_R is not None):
                    sub_F = F[..., s_in].view(batch, b, mul_in, 2 * l_in + 1)

                grad_K = grad_kernel[:, :, s_out].view(batch, a, mul_out,
                                                       2 * l_out + 1)

                sub_norm_coef = norm_coef[i, j]  # [batch, a, b]

                for k, l_filter in enumerate(l_filters):
                    tmp = sum(2 * l + 1 for l in ctx.set_of_l_filters
                              if l < l_filter)
                    C = o3.clebsch_gordan(l_out,
                                          l_in,
                                          l_filter,
                                          cached=True,
                                          like=grad_kernel)  # [m_out, m_in, m]

                    if (grad_F is not None) or (grad_R is not None):
                        sub_Y = Y[tmp:tmp + 2 * l_filter + 1,
                                  ...]  # [m, batch, a, b]

                    if grad_F is not None:
                        sub_grad_F += torch.einsum(
                            "zaui,ijk,kzab,zabuv,zab->zbvj", grad_K, C, sub_Y,
                            sub_R[..., k],
                            sub_norm_coef)  # [batch, b, mul_in, 2 * l_in + 1]
                    if grad_Y is not None:
                        grad_Y[tmp:tmp + 2 * l_filter + 1,
                               ...] += torch.einsum(
                                   "zaui,ijk,zabuv,zab,zbvj->kzab", grad_K, C,
                                   sub_R[..., k], sub_norm_coef,
                                   sub_F)  # [m, batch, a, b]
                    if grad_R is not None:
                        sub_grad_R[..., k] = torch.einsum(
                            "zaui,ijk,kzab,zab,zbvj->zabuv", grad_K, C, sub_Y,
                            sub_norm_coef,
                            sub_F)  # [batch, a, b, mul_out, mul_in]
                if grad_F is not None:
                    grad_F[:, :,
                           s_in] = sub_grad_F.view(batch, b,
                                                   mul_in * (2 * l_in + 1))
                if grad_R is not None:
                    grad_R[..., begin_R:begin_R + n] += sub_grad_R.view(
                        batch, a, b, -1)
                begin_R += n

        return grad_F, grad_Y, grad_R, None, None, None, None, None
예제 #24
0
def apply_TM_1sO_2(state, env, edge, op=None, verbosity=0):
    r"""
    :param state: underlying 1-site C4v symmetric wavefunction
    :param env: C4v symmetric environment corresponding to ``state``
    :param edge: tensor of dimensions :math:`\chi \times (D^2)^2 \times \chi`
    :param op: two-site operator to be inserted within the two-site transfer matrix
    :param verbosity: logging verbosity
    :type state: IPEPS_C4V
    :type env: ENV_C4V
    :type edge: torch.tensor
    :type op: torch.tensor
    :type verbosity: int
    :return: ``edge`` with a single instance of the transfer matrix applied.
             The resulting tensor has an identical index structure to the
             original ``edge``.
    :rtype: torch.tensor
    
    Applies a single instance of the two-site "transfer matrix" to 
    the ``edge`` tensor by contracting the following network, or its corresponding 
    rotation depending on the ``direction``::

                 -----T----------
                |     |          
               edge--(a^+ o1 a)--
                |     |   |      
                |----(a^+ o2 a)--
                |     |          
                 -----T----------

    The two-site operator is first decomposed into a simple MPO o1--o2
    (TODO case where op comes with an extra MPO index)::
        
         s1'  s2'    s1'      s2'
        |  op   | = |o1|-----|o2|
         s1   s2     s1       s2  

    where the physical indices `s` and `s'` of the on-site tensor :math:`a` 
    and its Hermitian conjugate :math:`a^\dagger` are contracted with 
    identity :math:`\delta_{s,s'}` or ``o1``, ``o2``.
    """

    # TODO stronger verification
    op_1, op_2 = None, None
    if op is not None:
        if len(op.size()) == 4:
            # pre-process ``op``
            # TODO possibly truncate/compress according to the vanishingly small singular values
            dims_op = op.size()
            op_mat = op.permute(0, 2, 1, 3).contiguous().reshape(
                dims_op[0]**2, dims_op[0]**2)
            op_1, s, op_2 = torch.svd(op_mat)
            op_1 = op_1.reshape(dims_op[0], dims_op[0], s.size()[0])
            op_2 = torch.einsum('i,ij->ij', s,
                                op_2.t()).reshape(s.size()[0], dims_op[0],
                                                  dims_op[0])
            op_2 = op_2.permute(1, 2, 0).contiguous()
        else:
            raise ValueError(f"Invalid op: rank {op.size()}")

    # Four basic cases of passed op
    def get_aXa(a, op):
        # a - on-site tensor
        # op - operator
        dims_a = a.size()
        dims_op = None if op is None else op.size()
        if op is None:
            # identity
            A= torch.einsum('nefgh,nabcd->eafbgchd',a,a).contiguous()\
                .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2)
        elif len(dims_op) == 2:
            # one-site operator
            A= torch.einsum('mefgh,mn,nabcd->eafbgchd',a,op,a).contiguous()\
                .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2)
        elif len(dims_op) == 3:
            # edge operators of some MPO within the transfer matrix
            #
            # 0                   0
            # |                   |
            # op--2 ... or ... 2--op
            # |                   |
            # 1                   1
            #
            # assume the last index of the op is the MPO dimension.
            # It will become the last index of the resulting edge
            A= torch.einsum('mefgh,mnl,nabcd->eafbgchdl',a,op,a).contiguous()\
                .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2, -1)
        if verbosity > 0: print(f"aXa {A.size()}")
        return A

    a = next(iter(state.sites.values()))
    T = env.T[env.keyT]
    # Assume index structure of ``edge`` tensor to be as follows
    #
    #       -- 0
    # edge |-- 1
    #      |---2
    #       -- 3
    #
    #   ----0 0--T--1->0
    #  |         2->1
    # edge--1->2
    #  |
    #   ----2->3
    #  |
    #   ----3->4
    E = torch.tensordot(T, edge, ([0], [0]))
    if verbosity > 0: print("E=edgeT " + str(E.size()))

    # TODO - more efficient contraction with uncontracted-double-layer on-site tensor
    #        Possibly reshape indices 1,2 of E, which are to be contracted with
    #        on-site tensor and contract bra,ket in two steps instead of creating
    #        double layer tensor
    #    /
    # --A--
    #  /|s
    #   X
    # s'|/
    # --A--
    #  /
    #
    # where X is Id or op
    A = get_aXa(a, op_1)

    #   ---------T--0
    #  |         1
    #  |         0
    # edge--2 1--A--3->4
    #  |      3<-2 \
    #   ----3->1   (4->5)
    #  |
    #   ----4->2
    E = torch.tensordot(E, A, ([1, 2], [0, 1]))
    if verbosity > 0: print("E=edgeTA " + str(E.size()))

    A = get_aXa(a, op_2)
    #   ---------T--0
    #  |         |
    # edge-------A--4->2
    #  |         | \
    #  |         3 (5)
    #  |         0 (4)
    #  |         | /
    #   ----1 1--A--2->3
    #  |         3->4
    #   ----2->1
    E = torch.tensordot(E,A,([1,3],[1,0])) if op is None else \
        torch.tensordot(E,A,([1,3,5],[1,0,4]))
    if verbosity > 0: print("E=edgeTAA " + str(E.size()))

    #   ---------T--0
    #  |         |
    # edge-------A--2->1
    #  |         |
    #   ---------A--3->2
    #  |         3
    #  |         2
    #   ----1 0--T2--1->3
    E = torch.tensordot(E, T, ([1, 3], [0, 2]))
    if verbosity > 0: print("E=edgeTAAT " + str(E.size()))

    return E
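
A small, self-contained check of the SVD splitting performed above, with a random rank-4 `op` standing in for a real two-site operator: contracting `o1` and `o2` over the singular-value index reconstructs `op`.

import torch

d = 2
op = torch.randn(d, d, d, d)                        # indices s1', s2', s1, s2
op_mat = op.permute(0, 2, 1, 3).contiguous().reshape(d**2, d**2)
op_1, s, op_2 = torch.svd(op_mat)
op_1 = op_1.reshape(d, d, s.size(0))                # (s1', s1, k)
op_2 = torch.einsum('i,ij->ij', s, op_2.t()).reshape(s.size(0), d, d)
op_2 = op_2.permute(1, 2, 0).contiguous()           # (s2', s2, k)

op_rec = torch.einsum('abk,cdk->acbd', op_1, op_2)  # back to (s1', s2', s1, s2)
assert torch.allclose(op_rec, op, atol=1e-5)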
예제 #25
0
def kernel_conv_fn_forward(F, Y, R, norm_coef, Rs_in, Rs_out, get_l_filters,
                           set_of_l_filters):
    """
    :param F: tensor [batch, b, l_in * mul_in * m_in]
    :param Y: tensor [l_filter * m_filter, batch, a, b]
    :param R: tensor [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]
    :param norm_coef: tensor [l_out, l_in, batch, a, b]
    :return: tensor [batch, a, l_out * mul_out * m_out]
    """
    batch, a, b = Y.shape[1:]
    n_in = rs.dim(Rs_in)
    n_out = rs.dim(Rs_out)

    kernel_conv = Y.new_zeros(batch, a, n_out)

    # note: for the normalization we assume that the variance of R[i] is one
    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(Rs_out):
        s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
        begin_out += mul_out * (2 * l_out + 1)

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(Rs_in):
            s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
            begin_in += mul_in * (2 * l_in + 1)

            l_filters = get_l_filters(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # extract the subset of the `R` that corresponds to the couple (l_out, l_in)
            n = mul_out * mul_in * len(l_filters)
            sub_R = R[:, :, :, begin_R:begin_R + n].contiguous().view(
                batch, a, b, mul_out, mul_in,
                -1)  # [batch, a, b, mul_out, mul_in, l_filter]
            begin_R += n

            sub_norm_coef = norm_coef[i, j]  # [batch, a, b]

            K = 0
            for k, l_filter in enumerate(l_filters):
                offset = sum(2 * l + 1 for l in set_of_l_filters
                             if l < l_filter)
                sub_Y = Y[offset:offset + 2 * l_filter + 1,
                          ...]  # [m, batch, a, b]

                C = o3.clebsch_gordan(l_out,
                                      l_in,
                                      l_filter,
                                      cached=True,
                                      like=kernel_conv)  # [m_out, m_in, m]

                K += torch.einsum("ijk,kzab,zabuv,zab,zbvj->zaui", C, sub_Y,
                                  sub_R[...,
                                        k], sub_norm_coef, F[..., s_in].view(
                                            batch, b, mul_in,
                                            -1))  # [batch, a, mul_out, m_out]

            if torch.is_tensor(K):  # at least one l_filter contributed
                kernel_conv[:, :, s_out] += K.view(batch, a, -1)

    return kernel_conv
예제 #26
0
    def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets,
                             full_output, **kwargs):
        """
        Returns a new PredictionStrategy that incorporates the specified inputs and targets as new training data.

        This method is primary responsible for updating the mean and covariance caches. To add fantasy data to a
        GP model, use the :meth:`~gpytorch.models.ExactGP.get_fantasy_model` method.

        Args:
            - :attr:`inputs` (Tensor `b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`): Locations of fantasy
                observations.
            - :attr:`targets` (Tensor `b1 x ... x bk x m` or `f x b1 x ... x bk x m`): Labels of fantasy observations.
            - :attr:`full_inputs` (Tensor `b1 x ... x bk x n+m x d` or `f x b1 x ... x bk x n+m x d`): Training data
                concatenated with fantasy inputs
            - :attr:`full_targets` (Tensor `b1 x ... x bk x n+m` or `f x b1 x ... x bk x n+m`): Training labels
                concatenated with fantasy labels.
            - :attr:`full_output` (:class:`gpytorch.distributions.MultivariateNormal`): Prior called on full_inputs

        Returns:
            - :class:`DefaultPredictionStrategy`
                A `DefaultPredictionStrategy` model with `n + m` training examples, where the `m` fantasy examples have
                been added and all test-time caches have been updated.
        """
        full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix

        batch_shape = full_inputs[0].shape[:-2]

        full_mean = full_mean.view(*batch_shape, -1)
        num_train = self.num_train

        # Evaluate fant x train and fant x fant covariance matrices, leave train x train unevaluated.
        fant_fant_covar = full_covar[..., num_train:, num_train:]
        fant_mean = full_mean[..., num_train:]
        mvn = self.train_prior_dist.__class__(fant_mean, fant_fant_covar)
        fant_likelihood = self.likelihood.get_fantasy_likelihood(**kwargs)
        mvn_obs = fant_likelihood(mvn, inputs, **kwargs)

        fant_fant_covar = mvn_obs.covariance_matrix
        fant_train_covar = delazify(full_covar[..., num_train:, :num_train])

        self.fantasy_inputs = inputs
        self.fantasy_targets = targets
        r"""
        Compute a new mean cache given the old mean cache.

        We have \alpha = K^{-1}y, and we want to solve [K U; U' S][a; b] = [y; y_f], where U' is fant_train_covar,
        S is fant_fant_covar, and y_f is (targets - fant_mean)

        To do this, we solve the bordered linear system of equations for [a; b]:
            AQ = U  # Q = fant_solve
            [S - U'Q]b = y_f - U'\alpha   ==> b = [S - U'Q]^{-1}(y_f - U'\alpha)
            a = \alpha - Qb
        """
        # Get cached K inverse decomp. (or compute if we somehow don't already have the covariance cache)
        K_inverse = self.lik_train_train_covar.root_inv_decomposition()
        fant_solve = K_inverse.matmul(fant_train_covar.transpose(-2, -1))

        # Solve for "b", the lower portion of the *new* \\alpha corresponding to the fantasy points.
        schur_complement = fant_fant_covar - fant_train_covar.matmul(
            fant_solve)

        # we'd like to use a less hacky approach for the following, but einsum can be much faster
        # than unsqueezing/squeezing here (esp. in backward passes); unfortunately it currently has some
        # issues with broadcasting: https://github.com/pytorch/pytorch/issues/15671
        prefix = string.ascii_lowercase[:max(
            fant_train_covar.dim() - self.mean_cache.dim() - 1, 0)]
        ftcm = torch.einsum(prefix + "...yz,...z->" + prefix + "...y",
                            [fant_train_covar, self.mean_cache])

        small_system_rhs = targets - fant_mean - ftcm
        small_system_rhs = small_system_rhs.unsqueeze(-1)
        # Schur complement of a spd matrix is guaranteed to be positive definite
        schur_cholesky = psd_safe_cholesky(schur_complement)
        fant_cache_lower = torch.cholesky_solve(small_system_rhs,
                                                schur_cholesky)

        # Get "a", the new upper portion of the cache corresponding to the old training points.
        fant_cache_upper = self.mean_cache.unsqueeze(-1) - fant_solve.matmul(
            fant_cache_lower)

        fant_cache_upper = fant_cache_upper.squeeze(-1)
        fant_cache_lower = fant_cache_lower.squeeze(-1)

        # New mean cache.
        fant_mean_cache = torch.cat((fant_cache_upper, fant_cache_lower),
                                    dim=-1)

        # now update the root and root inverse
        new_lt = self.lik_train_train_covar.cat_rows(fant_train_covar,
                                                     fant_fant_covar)
        new_root = new_lt.root_decomposition().root.evaluate()
        new_covar_cache = new_lt.root_inv_decomposition().root.evaluate()

        # Expand inputs accordingly if necessary (for fantasies at the same points)
        if full_inputs[0].dim() <= full_targets.dim():
            fant_batch_shape = full_targets.shape[:1]
            n_batch = len(full_mean.shape[:-1])
            repeat_shape = fant_batch_shape + torch.Size([1] * n_batch)
            full_inputs = [
                fi.expand(fant_batch_shape + fi.shape) for fi in full_inputs
            ]
            full_mean = full_mean.expand(fant_batch_shape + full_mean.shape)
            full_covar = BatchRepeatLazyTensor(full_covar, repeat_shape)
            new_root = BatchRepeatLazyTensor(NonLazyTensor(new_root),
                                             repeat_shape)
            # no need to repeat the covar cache, broadcasting will do the right thing

        # Create new DefaultPredictionStrategy object
        fant_strat = self.__class__(
            train_inputs=full_inputs,
            train_prior_dist=self.train_prior_dist.__class__(
                full_mean, full_covar),
            train_labels=full_targets,
            likelihood=fant_likelihood,
            root=new_root,
            inv_root=new_covar_cache,
        )
        add_to_cache(fant_strat, "mean_cache", fant_mean_cache)
        add_to_cache(fant_strat, "covar_cache", new_covar_cache)
        return fant_strat
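
The comment block above describes a bordered linear solve; a minimal numeric sketch (illustrative names, not gpytorch API) confirms that the Schur-complement update matches a direct solve of the augmented system:

import torch

torch.manual_seed(0)
n, m = 5, 2
M = torch.randn(n + m, n + m, dtype=torch.float64)
joint = M @ M.T + (n + m) * torch.eye(n + m, dtype=torch.float64)   # SPD joint covariance
K, U, S = joint[:n, :n], joint[:n, n:], joint[n:, n:]
y, y_f = torch.randn(n, dtype=torch.float64), torch.randn(m, dtype=torch.float64)

alpha = torch.linalg.solve(K, y)                  # old mean cache, K^{-1} y
Q = torch.linalg.solve(K, U)                      # "fant_solve"
schur = S - U.T @ Q                               # Schur complement of K
b = torch.linalg.solve(schur, y_f - U.T @ alpha)  # lower block of the new cache
a = alpha - Q @ b                                 # upper block of the new cache

assert torch.allclose(torch.cat([a, b]), torch.linalg.solve(joint, torch.cat([y, y_f])))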
예제 #27
0
    def forward(self, data):
        """Run SuperGlue on a pair of keypoints and descriptors"""
        desc0, desc1 = data['descriptors0'], data['descriptors1']
        kpts0, kpts1 = data['keypoints0'], data['keypoints1']

        if kpts0.shape[1] == 0 or kpts1.shape[1] == 0:  # no keypoints
            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
            return {
                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int),
                'matches1': kpts1.new_full(shape1, -1, dtype=torch.int),
                'matching_scores0': kpts0.new_zeros(shape0),
                'matching_scores1': kpts1.new_zeros(shape1),
            }

        # Keypoint normalization.
        # kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
        # kpts1 = normalize_keypoints(kpts1, data['image1'].shape)

        # Keypoint MLP encoder.
        # desc0 = desc0 + self.kenc(kpts0, data['scores0'])
        # desc1 = desc1 + self.kenc(kpts1, data['scores1'])
        desc0 = desc0 + self.kenc(kpts0)
        desc1 = desc1 + self.kenc(kpts1)

        # Multi-layer Transformer network.
        desc0, desc1 = self.gnn(desc0, desc1)

        # Final MLP projection.
        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)

        # Compute matching descriptor distance.
        scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
        scores = scores / self.config['descriptor_dim']**.5

        # Run the optimal transport.
        scores = log_optimal_transport(
            scores, self.bin_score, iters=self.config['sinkhorn_iterations'])

        # Construct a loss function on the scores
        # loss = compute_loss(scores, matches_gt)
        # scores: 1 * (m+1) * (n+1), matches_gt: 1 * (m+1) * (n+1)
        # loss = -scores.log() * matches_gt

        # Get the matches with score above "match_threshold".
        max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
        indices0, indices1 = max0.indices, max1.indices
        mutual0 = arange_like(indices0,
                              1)[None] == indices1.gather(1, indices0)
        mutual1 = arange_like(indices1,
                              1)[None] == indices0.gather(1, indices1)
        zero = scores.new_tensor(0)
        mscores0 = torch.where(mutual0, max0.values.exp(), zero)
        # mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
        valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
        valid1 = mutual1 & valid0.gather(1, indices1)
        indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
        indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))

        # hard-code top k values
        top_k_matches0 = scores[0, :-1, :-1].topk(5, dim=0).indices
        return {
            'matches0': indices0,  # use -1 for invalid match
            'matches1': indices1,  # use -1 for invalid match
            # 'matching_scores0': mscores0,
            # 'matching_scores1': mscores1,
            'scores': scores,
            'top_k_matches1': top_k_matches0
        }
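
A minimal sketch of the mutual nearest-neighbour check used above, on a toy score matrix rather than SuperGlue itself: a pair (i, j) survives only if i is j's best match and vice versa.

import torch

scores = torch.tensor([[[0.9, 0.1, 0.0],
                        [0.2, 0.8, 0.1],
                        [0.3, 0.2, 0.1]]])           # 1 x m x n similarity
max0, max1 = scores.max(2), scores.max(1)             # best per row / per column
indices0, indices1 = max0.indices, max1.indices
mutual0 = torch.arange(scores.shape[1])[None] == indices1.gather(1, indices0)
matches0 = torch.where(mutual0, indices0, indices0.new_tensor(-1))
print(matches0)  # tensor([[ 0,  1, -1]]) -- row 2 has no mutual partner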
예제 #28
0
    def forward(self,
                key,
                query,
                mask,
                cache=False,
                boundary_leftmost=0,
                boundary_rightmost=100000):
        """Compute chunkwise energy.

        Args:
            key (FloatTensor): `[B, klen, kdim]`
            query (FloatTensor): `[B, qlen, qdim]`
            mask (ByteTensor): `[B, qlen, klen]`
            cache (bool): cache key and mask
            boundary_leftmost (int): leftmost boundary offset
            boundary_rightmost (int): rightmost boundary offset
        Returns:
            e (FloatTensor): `[B, H_ca, qlen, klen]`

        """
        klen, kdim = key.size()[1:]
        bs, qlen = query.size()[:2]

        # Pre-computation of encoder-side features for computing scores
        if self.key is None or not cache:
            self.key = self.w_key(key).view(-1, klen, self.n_heads,
                                            self.d_k)  # `[B, klen, H_ca, d_k]`
            if mask is not None:
                self.mask = mask.unsqueeze(3).repeat(
                    [1, 1, 1, self.n_heads])  # `[B, qlen, klen, H_ca]`
                mask_size = (bs, qlen, klen, self.n_heads)
                assert self.mask.size() == mask_size, (self.mask.size(),
                                                       mask_size)
            else:
                self.mask = None

        k = self.key
        if k.size(0) != bs:  # for inference
            k = k[0:1].repeat([bs, 1, 1, 1])
        klen = k.size(1)
        q = self.w_query(query).view(-1, qlen, self.n_heads,
                                     self.d_k)  # `[B, qlen, H_ca, d_k]`
        m = self.mask

        # Truncate encoder memories for efficient DECODING
        if boundary_leftmost > 0 or (0 < boundary_rightmost < klen):
            k = k[:, boundary_leftmost:boundary_rightmost + 1]
            klen = k.size(1)
            if m is not None:
                m = m[:, :, boundary_leftmost:boundary_rightmost + 1]

        if self.atype == 'scaled_dot':
            e = torch.einsum("bihd,bjhd->bijh", (q, k)) / self.scale
        elif self.atype == 'add':
            e = self.v(
                torch.relu(k[:, None] + q[:, :, None]).view(
                    bs, qlen, klen, -1))
        # e: `[B, qlen, klen, H_ca]`

        if m is not None:
            NEG_INF = float(
                np.finfo(torch.tensor(0, dtype=e.dtype).numpy().dtype).min)
            e = e.masked_fill_(m == 0, NEG_INF)
        e = e.permute(0, 3, 1, 2)  # `[B, H_ca, qlen, klen]`

        return e
예제 #29
0
    def forward(self,
                hidden_states,
                start_positions=None,
                end_positions=None,
                cls_index=None,
                is_impossible=None,
                p_mask=None):
        outputs = ()

        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index,
                      is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states,
                                         start_positions=start_positions,
                                         p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states,
                                               start_positions=start_positions,
                                               cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            outputs = (total_loss, ) + outputs

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            start_log_probs = F.softmax(start_logits,
                                        dim=-1)  # shape (bsz, slen)

            start_top_log_probs, start_top_index = torch.topk(
                start_log_probs, self.start_n_top,
                dim=-1)  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(
                -1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(
                hidden_states, -2,
                start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(
                -1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
                start_states)  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded,
                                         start_states=start_states,
                                         p_mask=p_mask)
            end_log_probs = F.softmax(end_logits,
                                      dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(
                end_log_probs, self.end_n_top,
                dim=1)  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(
                -1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(
                -1, self.start_n_top * self.end_n_top)

            start_states = torch.einsum("blh,bl->bh", hidden_states,
                                        start_log_probs)
            cls_logits = self.answer_class(hidden_states,
                                           start_states=start_states,
                                           cls_index=cls_index)

            outputs = (start_top_log_probs, start_top_index, end_top_log_probs,
                       end_top_index, cls_logits) + outputs

        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
        # or (if labels are provided) (total_loss,)
        return outputs
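
The `"blh,bl->bh"` contraction near the end is an expectation of the hidden states under the start distribution; a minimal equivalence check with illustrative shapes:

import torch

bsz, slen, hsz = 2, 4, 3
hidden_states = torch.randn(bsz, slen, hsz)
start_probs = torch.softmax(torch.randn(bsz, slen), dim=-1)

pooled = torch.einsum("blh,bl->bh", hidden_states, start_probs)
manual = (hidden_states * start_probs.unsqueeze(-1)).sum(dim=1)
assert torch.allclose(pooled, manual, atol=1e-6)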
예제 #30
0
 def projx(self, x: torch.Tensor) -> torch.Tensor:
     U, _, V = linalg.svd(x, full_matrices=False)
     return torch.einsum("...ik,...kj->...ij", U, V)
예제 #31
0
# adds a dimension of size 1, just like unsqueeze
points = points[None]
print(points)

# -------------------------------------------------------------------

# 3.4  Named tensors

img_t = torch.randn(3, 5, 5)  # shape  [channels, rows, columns]
weights = torch.tensor([0.2126, 0.7152, 0.0722])

batch_t = torch.randn(2, 3, 5, 5)  # shape  [batch, channels, rows, columns]

img_gray_naive = img_t.mean(-3)
batch_gray_naive = batch_t.mean(-3)

print(f"shape_1: {img_gray_naive.shape}, shape_2: {batch_gray_naive.shape}")

unsqueezed_weights = weights.unsqueeze(-1).unsqueeze_(-1)
img_weights = (img_t * unsqueezed_weights)
batch_weights = (batch_t * unsqueezed_weights)
img_gray_weighted = img_weights.sum(-3)
batch_gray_weighted = batch_weights.sum(-3)

print(f"{batch_weights.shape}, {batch_t.shape}, {unsqueezed_weights.shape}")

img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)
batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)
print(batch_gray_weighted_fancy.shape)
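
A short continuation sketch, assuming the prototype named-tensor API (refine_names / align_as) available in recent PyTorch: the channel dimension can be addressed by name instead of by position.

# named tensors are a prototype feature; reuses img_t and weights from above
weights_named = weights.refine_names(..., 'channels')
img_named = img_t.refine_names(..., 'channels', 'rows', 'columns')
gray_named = (img_named * weights_named.align_as(img_named)).sum('channels')
print(gray_named.shape, gray_named.names)  # torch.Size([5, 5]) ('rows', 'columns')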
예제 #32
0
def similarity(x, means):
    return torch.einsum('bhld,hcd->bhlc', x, means)
예제 #33
0
    def forward(self, qk, v, query_len=None, input_mask=None):
        batch_size, seqlen, dim = qk.shape
        query_len = default(query_len, seqlen)
        device = qk.device

        n_buckets = seqlen // self.bucket_size

        buckets = self.hash_vectors(n_buckets, qk)
        # We use the same vector as both a query and a key.
        assert int(buckets.shape[1]) == self.n_hashes * seqlen

        ticker = torch.arange(self.n_hashes * seqlen,
                              device=device).unsqueeze(0).expand_as(buckets)
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = buckets_and_t.detach()

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)
        _, undo_sort = sort_key_val(sticker, ticker, dim=-1)
        del ticker

        sbuckets_and_t = sbuckets_and_t.detach()
        sticker = sticker.detach()
        undo_sort = undo_sort.detach()

        st = (sticker % seqlen)
        sqk = batched_index_select(qk, st)
        sv = batched_index_select(v, st)

        # Split off a "bin" axis so that attention only occurs within chunks.
        chunk_size = self.n_hashes * n_buckets
        bq_t = bkv_t = torch.reshape(st, (batch_size, chunk_size, -1))
        bqk = torch.reshape(sqk, (batch_size, chunk_size, -1, dim))
        bv = torch.reshape(sv, (batch_size, chunk_size, -1, dim))

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = F.normalize(bqk, p=2, dim=-1).type(bq.type())

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        def look_one_back(x):
            x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
            return torch.cat([x, x_extra], dim=2)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)

        # Dot-product attention.
        dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (dim**-0.5)
        masked_value = max_neg_value(dots)

        # Input mask for padding in variable lengthed sequences
        if input_mask is not None:
            input_mask = F.pad(input_mask, (0, seqlen - input_mask.shape[1]),
                               'constant', True)
            mq = input_mask.gather(1, st).reshape((batch_size, chunk_size, -1))
            mkv = look_one_back(mq)
            mask = mq[:, :, :, None] * mkv[:, :, None, :]
            dots.masked_fill_(~mask, masked_value)
            del mask

        # Causal masking
        if self.causal:
            mask = bq_t[:, :, :,
                        None] < bkv_t[:, :, None, :].clamp(max=query_len - 1)
            dots.masked_fill_(mask, masked_value)
            del mask

        # Mask out attention to self except when no other targets are available.
        self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
        dots.masked_fill_(self_mask, TOKEN_SELF_ATTN_VALUE)
        del self_mask

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bq_buckets = bkv_buckets = torch.reshape(
                sbuckets_and_t // seqlen, (batch_size, chunk_size, -1))
            bkv_buckets = look_one_back(bkv_buckets)
            bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :,
                                                                   None, :]
            dots.masked_fill_(bucket_mask, masked_value)
            del bucket_mask

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) When hard_k is set, the code
        # instead masks all but the first occurrence of each query-key pair.
        if not self._allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % chunk_size
            if not self._attend_across_buckets:
                locs1 = buckets * chunk_size + locs1
                locs2 = buckets * chunk_size + locs2
            locs = torch.cat([
                torch.reshape(locs1, (batch_size, self.n_hashes, seqlen)),
                torch.reshape(locs2, (batch_size, self.n_hashes, seqlen)),
            ], 1).permute((0, 2, 1))

            slocs = batched_index_select(locs, st)
            b_locs = torch.reshape(
                slocs, (batch_size, chunk_size, -1, 2 * self.n_hashes))

            b_locs1 = b_locs[:, :, :, None, :self.n_hashes]

            bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, self.n_hashes))
            bq_locs = torch.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :,
                                                                None, :, :])
            # for memory considerations, chunk summation of last dimension for counting duplicates
            dup_counts = chunked_sum(dup_counts,
                                     chunks=(self.n_hashes * batch_size))
            dup_counts = dup_counts.detach()
            assert dup_counts.shape == dots.shape
            dots = dots - torch.log(dup_counts + 1e-9)
            del dup_counts

        # Softmax.
        dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
        dots = torch.exp(dots - dots_logsumexp).type(dots.type())
        dropped_dots = self.dropout(dots)

        bo = torch.einsum('buij,buje->buie', dropped_dots, bv)
        so = torch.reshape(bo, (batch_size, -1, dim))
        slogits = torch.reshape(dots_logsumexp, (
            batch_size,
            -1,
        ))

        class UnsortLogits(Function):
            @staticmethod
            def forward(ctx, so, slogits):
                so = so.detach()
                slogits = slogits.detach()
                o = batched_index_select(so, undo_sort)
                _, logits = sort_key_val(sticker, slogits, dim=-1)
                return o, logits

            @staticmethod
            def backward(ctx, grad_x, grad_y):
                so_grad = batched_index_select(grad_x, sticker)
                _, slogits_grad = sort_key_val(buckets_and_t, grad_y, dim=-1)
                return so_grad, slogits_grad

        o, logits = UnsortLogits.apply(so, slogits)
        o = torch.reshape(o, (batch_size, self.n_hashes, seqlen, dim))
        logits = torch.reshape(logits, (batch_size, self.n_hashes, seqlen, 1))

        if query_len != seqlen:
            query_slice = (slice(None), slice(None), slice(0, query_len))
            o, logits = o[query_slice], logits[query_slice]

        probs = torch.exp(logits -
                          torch.logsumexp(logits, dim=1, keepdim=True))
        out = torch.sum(o * probs, dim=1)

        attn = torch.empty(0, device=device)

        # return unsorted attention weights
        if self._return_attn:
            attn_unsort = ((bq_t * seqlen)[:, :, :, None] +
                           bkv_t[:, :, None, :])
            attn_unsort = attn_unsort.view(batch_size * self.n_hashes,
                                           -1).long()
            unsorted_dots = torch.zeros(batch_size * self.n_hashes,
                                        seqlen * seqlen,
                                        device=device)
            unsorted_dots.scatter_add_(1, attn_unsort,
                                       dots.view_as(attn_unsort))
            del attn_unsort
            unsorted_dots = unsorted_dots.reshape(batch_size, self.n_hashes,
                                                  seqlen, seqlen)
            attn = torch.sum(unsorted_dots[:, :, 0:query_len, :] * probs,
                             dim=1)

        # return output, attention matrix, and bucket distribution
        return out, attn, buckets
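
The hash-sort bookkeeping above is easy to lose track of; a minimal sketch, assuming a simplified `sort_key_val` that mirrors the helper used in the snippet, shows that indexing with `undo_sort` restores the pre-sort order.

import torch

def sort_key_val(keys, vals, dim=-1):
    keys, idx = keys.sort(dim=dim)
    return keys, vals.gather(dim, idx)

buckets = torch.tensor([[2, 0, 1, 0]])
ticker = torch.arange(4)[None]
sbuckets, sticker = sort_key_val(buckets, ticker)   # positions grouped by bucket
_, undo_sort = sort_key_val(sticker, ticker)        # inverse permutation

x = torch.tensor([[10., 11., 12., 13.]])
sx = x.gather(1, sticker)                           # data in bucket-sorted order
assert torch.equal(sx.gather(1, undo_sort), x)      # unsorting recovers the input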
예제 #34
0
def attention(query, key, value):
    dim = query.shape[1]
    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
    prob = torch.nn.functional.softmax(scores, dim=-1)
    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
예제 #35
0
    def forward(self, q, k, v, query_mask=None, key_mask=None, **kwargs):
        b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = *q.shape, k.shape[2], self.window_size, self.context_window_size, self.num_clusters, q.device, q.dtype
        is_reverse = kwargs.pop('_reverse', False)

        out = torch.zeros_like(q, dtype=dtype)

        update_kmeans = self.training and not is_reverse

        key_mask = default(key_mask, query_mask) if not self.receives_context else key_mask
        kv_wsz = wsz if not self.receives_context else c_wsz

        wsz = min(wsz, t)
        kv_wsz = min(kv_wsz, kv_t)

        if not self.shared_qk or self.receives_context:
            dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
            q_dists, k_dists = split_at_index(2, t, dists)
            indices = distribution(q_dists, wsz)
            kv_indices = distribution(k_dists, kv_wsz)
        else:
            dists, aux_loss = self.kmeans(q, update_kmeans)
            k = F.normalize(k, dim=-1).to(q)
            indices = distribution(dists, wsz)
            kv_indices = indices

        q = batched_index_select(q, indices)
        k = batched_index_select(k, kv_indices)
        v = batched_index_select(v, kv_indices)

        reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d)
        q, k, v = map(reshape_with_window, (q, k, v))

        m_k, m_v = map(lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value))
        k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v)))

        dots = torch.einsum('bhnid,bhnjd->bhnij', q, k) * (d ** -0.5)

        mask_value = max_neg_value(dots)

        if exists(query_mask) or exists(key_mask):
            query_mask = default(query_mask, lambda: torch.ones((b, t), device=device).bool())
            key_mask = default(key_mask, lambda: torch.ones((b, kv_t), device=device).bool())

            q_mask = expand_dim(query_mask, 1, h).gather(2, indices)
            kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices)
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (q_mask, kv_mask))
            mask = q_mask[:, :, :, :, None] * kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=1)
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.causal:
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
            mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=1)
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.shared_qk:
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
            mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=0)
            dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE)
            del mask

        dots = dots.softmax(dim=-1)
        dots = self.dropout(dots)

        bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v)
        so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype)
        out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2)
        return out, aux_loss