def einsum(equation, *operands):
    """Variadic version of torch.einsum to match the numpy API."""
    # rename symbols to support PyTorch 0.4.1 and earlier,
    # which allow only symbols a-z.
    equation = convert_to_valid_einsum_chars(equation)
    torch, _ = _get_torch_and_device()
    return torch.einsum(equation, operands)
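# A minimal usage sketch, assuming convert_to_valid_einsum_chars and
# _get_torch_and_device (referenced above) are importable from the same module.
# The wrapper accepts operands variadically, like numpy.einsum, and forwards
# them to torch.einsum as a tuple:
import torch

x = torch.randn(3, 4)
y = torch.randn(4, 5)
z = einsum('ab,bc->ac', x, y)   # same result as torch.einsum('ab,bc->ac', x, y)
assert z.shape == (3, 5)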
def affine_product(X, A, b):
    """Special case of an affine transformation that receives coordinates X in 2-d (x, y)
    and an affine matrix A and translation vector b in 3-d (x, y, z).

    Y = AX + b

    :param torch.Tensor X: A matrix of 2-d coordinates (d1 x d2 x 2).
    :param torch.Tensor A: The first two columns of the affine matrix (3 x 2).
    :param torch.Tensor b: A 3-d translation vector.
    :return: A (d1 x d2 x 3) torch.Tensor corresponding to the transformed coordinates.
    """
    return torch.einsum('ij,klj->kli', (A, X)) + b
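# A minimal usage sketch (shapes follow the docstring; values are illustrative):
import torch

X = torch.randn(8, 8, 2)   # grid of 2-d coordinates
A = torch.randn(3, 2)      # first two columns of a 3x3 affine matrix
b = torch.randn(3)         # 3-d translation
Y = affine_product(X, A, b)
assert Y.shape == (8, 8, 3)
# equivalent to applying A @ x + b at every grid position:
assert torch.allclose(Y[0, 0], A @ X[0, 0] + b, atol=1e-6)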
def forward(self, x, A):
    assert A.size(0) == self.kernel_size
    x = self.conv1d(x)
    n, kc, v = x.size()
    x = x.view(n, self.kernel_size, kc // self.kernel_size, v)
    x = torch.einsum('nkcv,kvw->ncw', (x, A))
    # n, kc, t, v = x.size()
    # x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v)
    # x = torch.einsum('nkctv,kvw->nctw', (x, A))
    # print('einsum', x.shape)
    x = x.contiguous()
    return x, A
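# A small check of the 'nkcv,kvw->ncw' contraction used above: for each kernel
# slice k, features are propagated over the graph via A[k] and then summed over
# k (sizes below are illustrative):
import torch

n, k, c, v, w = 2, 3, 4, 5, 5
x = torch.randn(n, k, c, v)
A = torch.randn(k, v, w)
out = torch.einsum('nkcv,kvw->ncw', x, A)
ref = sum(x[:, i] @ A[i] for i in range(k))
assert torch.allclose(out, ref, atol=1e-5)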
def forward(self, x):
    if self.share_weights:
        u_hat_vecs = t.matmul(x, self.W)
    else:
        print('add later')
    batch_size = x.size(0)
    input_num_capsule = x.size(1)
    u_hat_vecs = u_hat_vecs.view((batch_size, input_num_capsule,
                                  self.num_capsule, self.dim_capsule))
    # reshape to (batch_size, num_capsule, input_num_capsule, dim_capsule)
    u_hat_vecs = u_hat_vecs.permute(0, 2, 1, 3)
    b = t.zeros_like(u_hat_vecs[:, :, :, 0])  # (batch_size, num_capsule, input_num_capsule)
    for i in range(self.routings):
        b = b.permute(0, 2, 1)
        c = F.softmax(b, dim=2)
        c = c.permute(0, 2, 1)
        b = b.permute(0, 2, 1)
        # batch matrix multiplication; outputs shape (batch_size, num_capsule, dim_capsule)
        outputs = self.activation(t.einsum('bij,bijk->bik', (c, u_hat_vecs)))
        if i < self.routings - 1:
            # batch matrix multiplication
            b = t.einsum('bik,bijk->bij', (outputs, u_hat_vecs))
    return outputs  # (batch_size, num_capsule, dim_capsule)
def _oc(a: Tensor, rhs: Tensor, Y: Tensor) -> Tensor:
    r"""Evaluate constraints.

    Note: einsum multiplies Y by a and sums over the `o`-dimension. Einsum
        is ~2x faster than using `(Y * a.view(1, 1, -1)).sum(dim=-1)`.

    Args:
        a: `o`-dim tensor of weights for the outcomes
        rhs: Singleton tensor containing the outcome constraint value
        Y: `... x b x q x o` tensor of function values

    Returns:
        A `... x b x q`-dim tensor where negative values imply feasibility
    """
    lhs = torch.einsum("...o, o", [Y, a])
    return lhs - rhs
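# A minimal sketch of evaluating such an outcome constraint (shapes follow the
# docstring; the values are illustrative, not from any real model):
import torch

Y = torch.randn(4, 3, 2)           # b=4 batches, q=3 points, o=2 outcomes
a = torch.tensor([1.0, -1.0])       # constraint weights over the outcomes
rhs = torch.tensor(0.5)             # constraint right-hand side

slack = _oc(a, rhs, Y)              # shape (4, 3)
feasible = slack <= 0               # negative (or zero) slack implies feasibility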
def compute_chunk(left_act, right_act):
    act = torch.einsum('...bac,...dae->...bdce', left_act, right_act)
    act = act.reshape(act.shape[:-2] + (-1,))
    act = self.output_w(act)
    return act
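# Shape check for the '...bac,...dae->...bdce' contraction above: the shared
# `a` axis is summed out and the two channel axes are kept, pairing every b
# position with every d position (sizes are illustrative):
import torch

left_act = torch.randn(2, 7, 10, 3)    # ... b a c
right_act = torch.randn(2, 8, 10, 4)   # ... d a e
act = torch.einsum('...bac,...dae->...bdce', left_act, right_act)
assert act.shape == (2, 7, 8, 3, 4)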
def apply_TM_2sO(state, env, edge, op=None, verbosity=0): r""" :param state: underlying 1-site C4v symmetric wavefunction :param env: C4v symmetric environment corresponding to ``state`` :param edge: tensor of dimensions :math:`\chi \times D^2 \times \chi` :param op: two-site operator to be inserted into the two consecutive transfer matrices :param verbosity: logging verbosity :type state: IPEPS_C4V :type env: ENV_C4V :type edge: torch.tensor :type op: torch.tensor :type verbosity: int :return: ``edge`` with two transfer matrices (and operator ``op``, if any) applied. The resulting tensor has an identical index structure as the original ``edge`` :rtype: torch.tensor Applies two transfer matrices to the ``edge`` tensor, including the two-site operator ``op`` by contracting the following network:: -----T-------------T------------ | | | edge--(a^+ op_l a)==(a^+ op_r a)-- | | | -----T-------------T------------ where the physical indices `s` and `s'` of the on-site tensor :math:`a` and it's hermitian conjugate :math:`a^\dagger` are contracted with identity :math:`\delta_{s,s'}` or ``op_l`` and ``op_r`` if ``op`` is supplied. The ``op_l`` and ``op_r`` are given by the SVD decomposition of two-site operator ``op``:: 0 1 0 1 0 1->0 | | SVD | | | | | op | = |op_l|--(S--|op^~_r|) = |op_l|--2 2--|op_r| | | | | | | 2 3 2 3 2->1 3->1 """ # TODO stronger verification if op is not None: assert (len(op.size()) == 4) # pre-process ``op`` # TODO possibly truncate/compress according to the vanishingly small singular values dims_op = op.size() op_mat = op.permute(0, 2, 1, 3).contiguous().reshape(dims_op[0]**2, dims_op[0]**2) op_l, s, op_r = torch.svd(op_mat) op_l = op_l.reshape(dims_op[0], dims_op[0], s.size()[0]) op_r = torch.einsum('i,ij->ij', s, op_r.t()).reshape(s.size()[0], dims_op[0], dims_op[0]) op_r = op_r.permute(1, 2, 0).contiguous() T = env.T[env.keyT] # Assume index structure of ``edge`` tensor to be as follows # # -- 0 # edge |-- 1 # -- 2 # # ----0 0--T--1->2 # | 2->3 # edge--1->0 # | # ----2->1 E = torch.tensordot(edge, T, ([0], [0])) if verbosity > 0: print("E=edgeT " + str(E.size())) # TODO - more efficent contraction with uncontracted-double-layer on-site tensor # Possibly reshape indices 1,2 of E, which are to be contracted with # on-site tensor and contract bra,ket in two steps instead of creating # double layer tensor # / # --A-- # /|s # X # s'|/ # --A-- # / # # where X is Id or op a = next(iter(state.sites.values())) dims_a = a.size() X = torch.eye(dims_a[0], dtype=a.dtype, device=a.device)[:, :, None] if op is None else op_l A= torch.einsum('mefgh,mnl,nabcd->eafbgchdl',a,X,a).contiguous()\ .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2, -1) # ---------T--2->1 # | 3 4 # | 0/ # edge--0 1--A--3 # | 2 # ----1->0 E = torch.tensordot(E, A, ([0, 3], [1, 0])) if verbosity > 0: print("E=EA " + str(E.size())) # -------T--1->0 # | | 4->2 # | |/ # edge-----A--3->1 # | 2 # | 2 # --0 0--T--1->3 E = torch.tensordot(E, T, ([0, 2], [0, 2])) if verbosity > 0: print("E=ET " + str(E.size())) # ----0 0----T--1->3 # |----2->1 2->4 # edge--1->0 # | # ----3->2 E = torch.tensordot(E, T, ([0], [0])) if verbosity > 0: print("E=ET " + str(E.size())) # TODO - more efficent contraction with uncontracted-double-layer on-site tensor # Possibly reshape indices 1,2 of E, which are to be contracted with # on-site tensor and contract bra,ket in two steps instead of creating # double layer tensor # / # --A-- # /|s # X # s'|/ # --A-- # / # # where X is Id or op X = torch.eye(dims_a[0], dtype=a.dtype, 
device=a.device)[:, :, None] if op is None else op_r A= torch.einsum('mefgh,mnl,nabcd->eafbgchdl',a,X,a).contiguous()\ .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2, -1) # ---------T--3->1 # | 4 # |----1 4-\0 # edge--0 1--A--3 # | 2 # ----2->0 E = torch.tensordot(E, A, ([0, 1, 4], [1, 4, 0])) if verbosity > 0: print("E=EA " + str(E.size())) # -------T--1->0 # | | # | | # edge-----A--3->1 # | 2 # | 2 # --0 0--T--1->2 E = torch.tensordot(E, T, ([0, 2], [0, 2])) if verbosity > 0: print("E=ET " + str(E.size())) return E
def forward(self, w, r, attn_mask=None, mems=None, head_mask=None, output_attentions=False): qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) if mems is not None: cat = torch.cat([mems, w], 0) if self.pre_lnorm: w_heads = self.qkv_net(self.layer_norm(cat)) else: w_heads = self.qkv_net(cat) r_head_k = self.r_net(r) w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) w_head_q = w_head_q[-qlen:] else: if self.pre_lnorm: w_heads = self.qkv_net(self.layer_norm(w)) else: w_heads = self.qkv_net(w) r_head_k = self.r_net(r) w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) klen = w_head_k.size(0) w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head # compute attention score rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head rr_head_q = w_head_q + self.r_r_bias BD = torch.einsum("ibnd,jnd->ijbn", (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head BD = self._rel_shift(BD) # [qlen x klen x bsz x n_head] attn_score = AC + BD attn_score.mul_(self.scale) # compute attention probability if attn_mask is not None and torch.sum(attn_mask).item(): attn_mask = attn_mask == 1 # Switch to bool if attn_mask.dim() == 2: if next(self.parameters()).dtype == torch.float16: attn_score = (attn_score.float().masked_fill( attn_mask[None, :, :, None], -65000).type_as(attn_score)) else: attn_score = attn_score.float().masked_fill( attn_mask[None, :, :, None], -1e30).type_as(attn_score) elif attn_mask.dim() == 3: if next(self.parameters()).dtype == torch.float16: attn_score = attn_score.float().masked_fill( attn_mask[:, :, :, None], -65000).type_as(attn_score) else: attn_score = attn_score.float().masked_fill( attn_mask[:, :, :, None], -1e30).type_as(attn_score) # [qlen x klen x bsz x n_head] attn_prob = F.softmax(attn_score, dim=1) attn_prob = self.dropatt(attn_prob) # Mask heads if we want to if head_mask is not None: attn_prob = attn_prob * head_mask # compute attention vector attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v)) # [qlen x bsz x n_head x d_head] attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) # linear projection attn_out = self.o_net(attn_vec) attn_out = self.drop(attn_out) if self.pre_lnorm: # residual connection outputs = [w + attn_out] else: # residual connection + layer normalization outputs = [self.layer_norm(w + attn_out)] if output_attentions: outputs.append(attn_prob) return outputs
def forward(self, input):
    result = torch.einsum(self.einsum_pattern, input, self.weight)
    if self.bias is not None:
        result += self.bias
    return result
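# A minimal sketch of a module that could own this forward pass. The attribute
# names (einsum_pattern, weight, bias) come from the code above; the particular
# pattern, shapes and initialization here are assumptions for illustration only:
import torch
import torch.nn as nn

class EinsumLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.einsum_pattern = "bi,oi->bo"   # batch x in, out x in -> batch x out
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, input):
        result = torch.einsum(self.einsum_pattern, input, self.weight)
        if self.bias is not None:
            result += self.bias
        return result

layer = EinsumLinear(16, 8)
out = layer(torch.randn(4, 16))
assert out.shape == (4, 8)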
def apply_TM_1sO(state, env, edge, op=None, verbosity=0): r""" :param state: underlying 1-site C4v symmetric wavefunction :param env: C4v symmetric environment corresponding to ``state`` :param edge: tensor of dimensions :math:`\chi \times D^2 \times \chi` :param op: operator to be inserted into transfer matrix :param verbosity: logging verbosity :type state: IPEPS_C4V :type env: ENV_C4V :type edge: torch.tensor :type op: torch.tensor :type verbosity: int :return: ``edge`` with a single instance of the transfer matrix applied. The resulting tensor has an identical index structure as the original ``edge`` :rtype: torch.tensor Applies a single instance of the "transfer matrix" to the ``edge`` tensor by contracting the following network:: -----T---------- | | edge--(a^+ op a)-- | | -----T---------- where the physical indices `s` and `s'` of the on-site tensor :math:`a` and it's hermitian conjugate :math:`a^\dagger` are contracted with identity :math:`\delta_{s,s'}` or ``op`` (if supplied). """ # TODO stronger verification if op is not None: assert (len(op.size()) == 2) T = env.T[env.keyT] # Assume index structure of ``edge`` tensor to be as follows # # -- 0 # edge |-- 1 # -- 2 # # --0 0--T--1->2 # | 2->3 # edge--1->0 # | # --2->1 E = torch.tensordot(edge, T, ([0], [0])) if verbosity > 0: print("E=edgeT " + str(E.size())) # TODO - more efficent contraction with uncontracted-double-layer on-site tensor # Possibly reshape indices 1,2 of E, which are to be contracted with # on-site tensor and contract bra,ket in two steps instead of creating # double layer tensor # / # --A-- # /|s # X # s'|/ # --A-- # / # # where X is Id or op a = next(iter(state.sites.values())) dims_a = a.size() X = torch.eye(dims_a[0], dtype=a.dtype, device=a.device) if op is None else op A= torch.einsum('mefgh,mn,nabcd->eafbgchd',a,X,a).contiguous()\ .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2) # ---------T--2->1 # | 3 # | 0 # edge--0 1--A--3 # | 2 # ----1->0 E = torch.tensordot(E, A, ([0, 3], [1, 0])) if verbosity > 0: print("E=EA " + str(E.size())) # -------T--1->0 # | | # | | # edge-----A--3->1 # | 2 # | 2 # --0 0--T--1->2 E = torch.tensordot(E, T, ([0, 2], [0, 2])) if verbosity > 0: print("E=ET " + str(E.size())) return E
def forward( self, x, encoder_out: Optional[torch.Tensor] = None, encoder_padding_mask: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, prev_self_attn_state: Optional[List[torch.Tensor]] = None, prev_attn_state: Optional[List[torch.Tensor]] = None, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None, need_attn: bool = False, need_head_weights: bool = False, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor, optional): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. need_attn (bool, optional): return attention weights need_head_weights (bool, optional): return attention weights for each head (default: return average over heads). Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ if need_head_weights: need_attn = True residual = x if self.normalize_before: x = self.self_attn_layer_norm(x) if prev_self_attn_state is not None: prev_key, prev_value = prev_self_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_self_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] assert incremental_state is not None self.self_attn._set_input_buffer(incremental_state, saved_state) _self_attn_input_buffer = self.self_attn._get_input_buffer( incremental_state) if self.cross_self_attention and not ( incremental_state is not None and _self_attn_input_buffer is not None and "prev_key" in _self_attn_input_buffer): if self_attn_mask is not None: assert encoder_out is not None self_attn_mask = torch.cat((x.new_zeros( x.size(0), encoder_out.size(0)), self_attn_mask), dim=1) if self_attn_padding_mask is not None: if encoder_padding_mask is None: assert encoder_out is not None encoder_padding_mask = self_attn_padding_mask.new_zeros( encoder_out.size(1), encoder_out.size(0)) self_attn_padding_mask = torch.cat( (encoder_padding_mask, self_attn_padding_mask), dim=1) assert encoder_out is not None y = torch.cat((encoder_out, x), dim=0) else: y = x x, attn = self.self_attn( query=x, key=y, value=y, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) if self.c_attn is not None: tgt_len, bsz = x.size(0), x.size(1) x = x.view(tgt_len, bsz, self.nh, self.head_dim) x = torch.einsum("tbhd,h->tbhd", x, self.c_attn) x = x.reshape(tgt_len, bsz, self.embed_dim) if self.attn_ln is not None: x = self.attn_ln(x) x = self.dropout_module(x) x = self.residual_connection(x, residual) if not self.normalize_before: x = self.self_attn_layer_norm(x) if self.encoder_attn is not None and encoder_out is not None: residual = x if self.normalize_before: x = self.encoder_attn_layer_norm(x) if prev_attn_state is not None: prev_key, prev_value = prev_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] assert incremental_state is not None self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) x = self.dropout_module(x) x = 
self.residual_connection(x, residual) if not self.normalize_before: x = self.encoder_attn_layer_norm(x) residual = x if self.normalize_before: x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) x = self.activation_dropout_module(x) if self.ffn_layernorm is not None: x = self.ffn_layernorm(x) x = self.fc2(x) x = self.dropout_module(x) if self.w_resid is not None: residual = torch.mul(self.w_resid, residual) x = self.residual_connection(x, residual) if not self.normalize_before: x = self.final_layer_norm(x) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) assert saved_state is not None if self_attn_padding_mask is not None: self_attn_state = [ saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"], ] else: self_attn_state = [ saved_state["prev_key"], saved_state["prev_value"] ] return x, attn, self_attn_state return x, attn, None
def _get_predicts(predicts, coefficients):
    return torch.einsum("ij,j->ij", (predicts, coefficients))
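# Illustration of the contraction: each column j of `predicts` is scaled by
# coefficients[j], i.e. an element-wise broadcast written via einsum:
import torch

predicts = torch.rand(5, 3)                    # e.g. 5 samples, 3 classes
coefficients = torch.tensor([1.0, 0.5, 2.0])
scaled = _get_predicts(predicts, coefficients)
assert torch.allclose(scaled, predicts * coefficients)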
def gradient(self, *xs, y=None, v=None, ctx=None): """Computes the vector--Jacobian product, that is, the gradient of the loss function with respect to the problem parameters. The returned gradient is a tuple of batched Torch tensors. Can be overridden by the derived class to provide a more efficient implementation. Arguments: xs: ((b, ...), ...) tuple of Torch tensors, tuple of batches of input tensors y: (b, ...) Torch tensor or None, batch of minima of the objective function v: (b, ...) Torch tensor or None, batch of gradients of the loss function with respect to the problem output J_Y(x,y) ctx: dictionary of contextual information used for computing the gradient Return Values: gradients: ((b, ...), ...) tuple of Torch tensors or Nones, batch of gradients of the loss function with respect to the problem parameters; strictly, returns the vector--Jacobian products J_Y(x,y) * y'(x) """ # Compute optimal value if have not already done so: if y is None: y, ctx = torch.no_grad()(self.solve)(*xs) y.requires_grad = True # Set incoming gradient v = J_Y(x,y) to one if not specified: if v is None: v = torch.ones_like(y) b = y.size(0) m = y.view(b, -1).size(-1) # Get constraint parameters and form batch: A, d = self.linear_constraint_parameters(y) A = self._expand_as_batch(A, b) d = self._expand_as_batch(d, b) # Check linear equality constraints are satisfied: h = torch.einsum('bpm,bm->bp', (A, y)) - d if not self._check_equality_constraints(h): warnings.warn("Constraints not satisfied {}".format( h.detach().squeeze().cpu().numpy())) # Compute relevant derivatives with autograd: with torch.enable_grad(): # Split each input x into a tuple of n tensors of size bx1: # Required since gradients can only be computed wrt individual # tensors, not slices of a tensors. 
See: # https://discuss.pytorch.org/t/how-to-calculate-gradients-wrt-one-of-inputs/24407 xs_split, xs_sizes = self._split_inputs(xs) xs = self._cat_inputs(xs_split, xs_sizes) # Evaluate objective function at (xs,y): f = self.objective(*xs, y) # b # Compute partial derivative of f wrt y at (xs,y): grad_outputs = torch.ones_like(f) # b fY = grad(f, y, grad_outputs=grad_outputs, create_graph=True)[0].view(b, -1) # bxm if not fY.requires_grad: # if fY is independent of y fY.requires_grad = True # Compute second-order partial derivative of f wrt y at (xs,y): fYY = self._batch_jacobian(fY, y) assert fYY is not None # Compute 2nd-order partial derivative of h wrt y at (xs,y) and form H: H = fYY.detach() # Solve u = -H^-1 v (bxm) and t = H^-1 A^T (bxmxp): H = 0.5 * (H + H.transpose(1, 2)) # Ensure that H is symmetric v = v.view(b, -1, 1) # bxmx1 u, t = self._solve_linear_system(H, (-1.0 * v, A.transpose(-2, -1))) u = u.squeeze(-1) # bxm # ToDo: check for NaN values in u and t # Solve s = (A H^-1 A^T)^-1 A H^-1 v = -(A t)^-1 A u: s = self._solve_linear_system(torch.einsum('bpm,bmq->bpq', (A, t)), torch.einsum('bpm,bm->bp', (A, -1.0 * u))) # bxpx1 s = s.squeeze(-1) # bxp # ToDo: check for NaN values in s # Compute u + ts = -H^-1 v + H^-1 A^T (A H^-1 A^T)^-1 A H^-1 v: uts = u + torch.einsum('bmp,bp->bm', (t, s)) # bxm # Compute bi^T (u + ts) for all i: gradients = [] for x_split, x_size in zip(xs_split, xs_sizes): # Loop over input tuple if isinstance(x_split[0], torch.Tensor) and x_split[0].requires_grad: n = len(x_split) gradient = x_split[0].new_zeros(b, n) # bxn for i in range(n): # 2nd-order partial derivative of f wrt y and xi at (xs,y): fXiY = self._batch_jacobian(fY, x_split[i]) # bxmx1 bi = fXiY.detach().squeeze(-1) if (fXiY is not None) else ( torch.zeros_like(fY)) # Shares storage with fXiY gradient[:, i] = torch.einsum('bm,bm->b', (bi, uts)) # Reshape gradient to size(x): gradients.append(gradient.view(x_size)) else: gradients.append(None) return tuple(gradients)
def gradient(self, *xs, y=None, v=None, ctx=None): """Computes the vector--Jacobian product, that is, the gradient of the loss function with respect to the problem parameters. The returned gradient is a tuple of batched Torch tensors. Can be overridden by the derived class to provide a more efficient implementation. Arguments: xs: ((b, ...), ...) tuple of Torch tensors, tuple of batches of input tensors y: (b, ...) Torch tensor or None, batch of minima of the objective function v: (b, ...) Torch tensor or None, batch of gradients of the loss function with respect to the problem output J_Y(x,y) ctx: dictionary of contextual information used for computing the gradient Return Values: gradients: ((b, ...), ...) tuple of Torch tensors or Nones, batch of gradients of the loss function with respect to the problem parameters; strictly, returns the vector--Jacobian products J_Y(x,y) * y'(x) """ # Compute optimal value if have not already done so: if y is None: y, ctx = torch.no_grad()(self.solve)(*xs) y.requires_grad = True # Set incoming gradient v = J_Y(x,y) to one if not specified: if v is None: v = torch.ones_like(y) # Compute relevant derivatives with autograd: b = y.size(0) m = y.view(b, -1).size(-1) with torch.enable_grad(): # Split each input x into a tuple of n tensors of size bx1: # Required since gradients can only be computed wrt individual # tensors, not slices of a tensor. See: # https://discuss.pytorch.org/t/how-to-calculate-gradients-wrt-one-of-inputs/24407 xs_split, xs_sizes = self._split_inputs(xs) # Evaluate objective function at (xs,y): f = self.objective(*self._cat_inputs(xs_split, xs_sizes), y) # b # Compute partial derivative of f wrt y at (xs,y): fY = grad(f, y, grad_outputs=torch.ones_like(f), create_graph=True)[0].view(b, -1) # bxm if not self._check_optimality_cond(fY): warnings.warn( "Non-zero objective function gradient {} at y".format( fY.detach().squeeze().cpu().numpy())) # Compute second-order partial derivative of f wrt y at (xs,y): fYY = self._batch_jacobian(fY, y) # Solve u = -H^-1 v: H = fYY.detach() H = 0.5 * (H + H.transpose(1, 2)) # Ensure that H is symmetric v = v.view(b, -1, 1) u = self._solve_linear_system(H, -1.0 * v) # bxmx1 u = u.squeeze(-1) # bxm # ToDo: check for NaN values in u # Compute -b_i^T H^-1 v (== b_i^T u) for all i: gradients = [] for x_split, x_size in zip(xs_split, xs_sizes): # Loop over input tuple if isinstance(x_split[0], torch.Tensor) and x_split[0].requires_grad: n = len(x_split) gradient = x_split[0].new_zeros(b, n) # bxn # 2nd-order partial derivative of f wrt y and x at (xs,y): fXiY = torch.zeros_like(fY) # bxm grad_outputs = torch.ones_like(fY) for i in range(n): with torch.enable_grad(): fXiY = grad(fY, x_split[i], grad_outputs=grad_outputs, create_graph=True)[0] # bxm bi = fXiY.detach() gradient[:, i] = torch.einsum('bm,bm->b', (bi, u)) # Reshape gradient to size(x): gradients.append(gradient.view(x_size)) else: gradients.append(None) return tuple(gradients)
def gradient(self, *xs, y=None, v=None, ctx=None): """Computes the vector--Jacobian product, that is, the gradient of the loss function with respect to the problem parameters. The returned gradient is a tuple of batched Torch tensors. Can be overridden by the derived class to provide a more efficient implementation. Arguments: xs: ((b, ...), ...) tuple of Torch tensors, tuple of batches of input tensors y: (b, ...) Torch tensor or None, batch of minima of the objective function v: (b, ...) Torch tensor or None, batch of gradients of the loss function with respect to the problem output J_Y(x,y) ctx: dictionary of contextual information used for computing the gradient Return Values: gradients: ((b, ...), ...) tuple of Torch tensors or Nones, batch of gradients of the loss function with respect to the problem parameters; strictly, returns the vector--Jacobian products J_Y(x,y) * y'(x) """ # Compute optimal value if have not already done so: if y is None: y, ctx = torch.no_grad()(self.solve)(*xs) y.requires_grad = True # Set incoming gradient v = J_Y(x,y) to one if not specified: if v is None: v = torch.ones_like(y) # Compute relevant derivatives with autograd: b = y.size(0) m = y.view(b, -1).size(-1) with torch.enable_grad(): # Split each input x into a tuple of n tensors of size bx1: # Required since gradients can only be computed wrt individual # tensors, not slices of a tensors. See: # https://discuss.pytorch.org/t/how-to-calculate-gradients-wrt-one-of-inputs/24407 xs_split, xs_sizes = self._split_inputs(xs) xs = self._cat_inputs(xs_split, xs_sizes) # Evaluate constraint function(s) at (xs,y): h = self._get_constraint_set(xs, y) # bxp if h is None: # If None, use unconstrained gradient return super().gradient(xs, y=y, v=v, ctx=ctx) # Evaluate objective function at (xs,y): f = self.objective(*xs, y) # b # Compute partial derivative of f wrt y at (xs,y): fY = grad(f, y, grad_outputs=torch.ones_like(f), create_graph=True)[0].view(b, -1) # bxm if not fY.requires_grad: # if fY is independent of y fY.requires_grad = True # Compute partial derivative of h wrt y at (xs,y): hY = self._batch_jacobian(h, y, create_graph=True) if not hY.requires_grad: # if hY is independent of y hY.requires_grad = True # Compute nu (b, p): nu = self._get_nu(fY, hY) if (ctx is None or 'nu' not in ctx) else ctx['nu'] nu = nu.unsqueeze(-1) if len( nu.size()) == 1 else nu # Force p dimension if not self._check_optimality_cond(fY, hY, nu): warnings.warn( "Non-zero Lagrangian gradient {} at y. fY: {}, hY: {}, nu: {}". 
format( (fY - torch.einsum('ab,abc->ac', (nu, hY))).detach().squeeze().cpu().numpy(), fY.detach().squeeze().cpu().numpy(), hY.detach().squeeze().cpu().numpy(), nu.detach().squeeze().cpu().numpy())) # Compute second-order partial derivative of f wrt y at (xs,y): fYY = self._batch_jacobian(fY, y) # Compute 2nd-order partial derivative of h wrt y at (xs,y) and form H: H = fYY.detach() if fYY is not None else 0.0 # Shares storage with fYY p = h.size(-1) for i in range(p): with torch.enable_grad(): # Needed when looping over output hiYY = self._batch_jacobian(hY[:, i, :], y, create_graph=False) if hiYY is not None: H -= torch.einsum('b,bmn->bmn', (nu[:, i], hiYY)) assert isinstance(H, torch.Tensor) # Solve u = -H^-1 v (bxm) and t = H^-1 A^T (bxmxp): H = 0.5 * (H + H.transpose(1, 2)) # Ensure that H is symmetric A = hY.detach() # Shares storage with hY v = v.view(b, -1, 1) # bxmx1 u, t = self._solve_linear_system(H, (-1.0 * v, A.transpose(-2, -1))) u = u.squeeze(-1) # bxm # ToDo: check for NaN values in u and t # Solve s = (A H^-1 A^T)^-1 A H^-1 v = -(A t)^-1 A u: s = self._solve_linear_system(torch.einsum('bpm,bmq->bpq', (A, t)), torch.einsum('bpm,bm->bp', (A, -1.0 * u))) # bxpx1 s = s.squeeze(-1) # bxp # ToDo: check for NaN values in s # Compute u + ts: uts = u + torch.einsum('bmp,bp->bm', (t, s)) # bxm # Compute bi^T (u + ts) - ci^T s for all i: gradients = [] for x_split, x_size in zip(xs_split, xs_sizes): # Loop over input tuple if isinstance(x_split[0], torch.Tensor) and x_split[0].requires_grad: n = len(x_split) gradient = x_split[0].new_zeros(b, n) # bxn for i in range(n): # 2nd-order partial derivative of f wrt y and xi at (xs,y): fXiY = self._batch_jacobian(fY, x_split[i]) # bxmx1 bi = fXiY.detach().squeeze(-1) if (fXiY is not None) else ( torch.zeros_like(fY)) # Shares storage with fXiY for j in range(p): # 2nd-order partial derivative of hj wrt y and xi at (xs,y): with torch.enable_grad(): hjXiY = self._batch_jacobian( hY[:, j, :], x_split[i]) # bxmx1 if hjXiY is not None: bi -= torch.einsum( 'b,bm->bm', (nu[:, j], hjXiY.detach().squeeze(-1))) # bxm # Compute partial derivative of h wrt xi at (xs,y): hXi = self._batch_jacobian(h, x_split[i]) # bxpx1 if hXi is None: gradient[:, i] = torch.einsum('bm,bm->b', (bi, uts)) else: ci = hXi.detach().squeeze( -1) # Shares storage with hXi gradient[:, i] = (torch.einsum('bm,bm->b', (bi, uts)) - torch.einsum('bp,bp->b', (ci, s))) # Reshape gradient to size(x): gradients.append(gradient.view(x_size)) else: gradients.append(None) return tuple(gradients)
def get_batch_top3score(target, output):
    weights = torch.as_tensor([1, 1 / 2, 1 / 3]).cuda()
    _, pred = torch.topk(output, k=3, dim=1)
    target = target.reshape(target.shape[0], 1)
    target = target.repeat(1, 3)
    return torch.sum(torch.einsum("ij,j->i", (pred == target).float(), weights)).item()
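# Worked example of the scoring (requires CUDA, since the weights above are
# moved to the GPU): the true class appearing at rank 1, 2 and 3 contributes
# 1, 1/2 and 1/3 respectively, so the batch total below is 1 + 1/2 + 1/3:
import torch

if torch.cuda.is_available():
    output = torch.tensor([[0.90, 0.05, 0.03, 0.02],
                           [0.40, 0.50, 0.06, 0.04],
                           [0.30, 0.35, 0.05, 0.34]]).cuda()
    target = torch.tensor([0, 0, 0]).cuda()
    score = get_batch_top3score(target, output)   # approximately 1.8333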
def forward(self, z1ss, pos_emb, u1ss, mems=None):
    # Note: In this context, qlen means the length of the (small) subsequence; and mlen describes
    # the length of the padding. Their sum is klen.
    bsz, d_model, qlen = z1ss.size()
    r_w_bias, r_r_bias = self.r_w_bias, self.r_r_bias
    n_head, d_head = self.n_head, self.d_head
    rlen = pos_emb.size(2)

    if mems is None:
        mems = torch.tensor([]).view(0, 0, 0)
    mlen = mems.size(2)
    cat = torch.cat([mems, z1ss], dim=-1)

    if self.pre_lnorm:
        cat = F.layer_norm(cat.transpose(1, 2), (d_model,)).transpose(1, 2)
    w_heads = self.qkv_net(cat)       # (N x 3*d_model x seq_len)
    r_head_k = self.r_net(pos_emb)

    # Input injection
    w_heads += u1ss
    w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=1)
    w_head_q = w_head_q[:, :, -qlen:]

    klen = w_head_k.size(2)

    w_head_q = w_head_q.view(bsz, n_head, d_head, qlen)   # bsz x n_head x d_head x qlen
    w_head_k = w_head_k.view(bsz, n_head, d_head, klen)   # bsz x n_head x d_head x klen
    w_head_v = w_head_v.view(bsz, n_head, d_head, klen)   # bsz x n_head x d_head x klen

    r_head_k = r_head_k.view(n_head, d_head, rlen)        # n_head x d_head x rlen

    # Compute attention score
    rw_head_q = w_head_q + r_w_bias[:, :, None]           # bsz x n_head x d_head x qlen
    AC = torch.einsum('bndi,bndj->bnij', rw_head_q, w_head_k)
    rr_head_q = w_head_q + r_r_bias[:, :, None]
    BD = torch.einsum('bndi,ndj->bnij', rr_head_q, r_head_k)
    BD = self._rel_shift(BD)  # for relative positional embedding

    attn_score = AC + BD      # bsz x n_head x qlen x klen
    attn_score.mul_(self.scale)

    # Compute attention probability
    # We apply a local mask, with local horizon size of mlen
    local_size = self.local_size or 1000
    attn_mask = (torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen) > 0)[None]
    attn_mask += (torch.tril(torch.ones(qlen, klen), diagonal=mlen - local_size) > 0)[None]
    if attn_mask is not None and attn_mask.any().item():
        attn_score = attn_score.float().masked_fill(
            attn_mask[None], -float('inf')).type_as(attn_score)

    attn_prob = F.softmax(attn_score, dim=-1)  # bsz x n_head x qlen x klen

    # Compute attention vector
    attn_vec = torch.einsum('bnij,bndj->bndi', (attn_prob, w_head_v))

    # [bsz x d x qlen]
    attn_vec = attn_vec.contiguous().view(bsz, n_head * d_head, attn_vec.size(-1))

    # Linear projection
    attn_out = self.o_net(attn_vec)
    attn_out = self.drop(attn_out)

    # Residual connection + layer normalization (if applicable)
    if self.pre_lnorm:
        out = attn_out + z1ss
    else:
        out = F.layer_norm((attn_out + z1ss).transpose(1, 2), (d_model,)).transpose(1, 2)
    return out
def last_layer(self, z):
    z = torch.einsum("ij,mnj->imn", z, self.W)
    return z
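# Shape illustration of the "ij,mnj->imn" contraction above; self.W is assumed
# to be a 3-d tensor whose last axis matches the feature dimension of z:
import torch

z = torch.randn(4, 16)        # batch x features
W = torch.randn(10, 10, 16)   # m x n x features
out = torch.einsum("ij,mnj->imn", z, W)
assert out.shape == (4, 10, 10)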
def train(epoch): torch.set_printoptions(precision=16) print('\nEpoch: %d' % epoch) net.train() train_loss = 0 correct = 0 total = 0 step_st_time = time.time() epoch_time = 0 print('\nKFAC/KBFGS damping: %f' % damping) print('\nNGD damping: %f' % (damping)) # desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % (tag, lr_scheduler.get_last_lr()[0], 0, 0, correct, total)) writer.add_scalar('train/lr', lr_scheduler.get_last_lr()[0], epoch) prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True) for batch_idx, (inputs, targets) in prog_bar: if optim_name in ['kfac', 'skfac', 'ekfac', 'sgd', 'adam']: inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, targets) if optim_name in ['kfac', 'skfac', 'ekfac'] and optimizer.steps % optimizer.TCov == 0: # compute true fisher optimizer.acc_stats = True with torch.no_grad(): sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),1).squeeze().to(args.device) loss_sample = criterion(outputs, sampled_y) loss_sample.backward(retain_graph=True) optimizer.acc_stats = False optimizer.zero_grad() # clear the gradient for computing true-fisher. loss.backward() optimizer.step() elif optim_name in ['kbfgs', 'kbfgsl', 'kbfgsl_2loop', 'kbfgsl_mem_eff']: inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() outputs = net.forward(inputs) loss = criterion(outputs, targets) loss.backward() # do another forward-backward pass over batch inside step() def closure(): return inputs, targets, criterion, False # is_autoencoder = False optimizer.step(closure) elif optim_name == 'exact_ngd': inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, targets) # update Fisher inverse if batch_idx % args.freq == 0: # compute true fisher with torch.no_grad(): sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),1).squeeze().to(args.device) # use backpack extension to compute individual gradient in a batch batch_grad = [] with backpack(BatchGrad()): loss_sample = criterion(outputs, sampled_y) loss_sample.backward(retain_graph=True) for name, param in net.named_parameters(): if hasattr(param, "grad_batch"): batch_grad.append(args.batch_size * param.grad_batch.reshape(args.batch_size, -1)) else: raise NotImplementedError J = torch.cat(batch_grad, 1) fisher = torch.matmul(J.t(), J) / args.batch_size inv = torch.linalg.inv(fisher + damping * torch.eye(fisher.size(0)).to(fisher.device)) # clean the gradient to compute the true fisher optimizer.zero_grad() loss.backward() # compute the step direction p = F^-1 @ g grad_list = [] for name, param in net.named_parameters(): grad_list.append(param.grad.data.reshape(-1, 1)) g = torch.cat(grad_list, 0) p = torch.matmul(inv, g) start = 0 for name, param in net.named_parameters(): end = start + param.data.reshape(-1, 1).size(0) param.grad.copy_(p[start:end].reshape(param.grad.data.shape)) start = end optimizer.step() ### new optimizer test elif optim_name in ['kngd'] : inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, targets) if optimizer.steps % optimizer.freq == 0: # compute true fisher optimizer.acc_stats = True with torch.no_grad(): sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device) loss_sample = criterion(outputs, 
sampled_y) loss_sample.backward(retain_graph=True) optimizer.acc_stats = False optimizer.zero_grad() # clear the gradient for computing true-fisher. if args.partial_backprop == 'true': idx = (sampled_y == targets) == False loss = criterion(outputs[idx,:], targets[idx]) # print('extra:', idx.sum().item()) loss.backward() optimizer.step() elif optim_name == 'ngd': if batch_idx % args.freq == 0: store_io_(True) inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() # net.set_require_grad(True) outputs = net(inputs) damp = damping loss = criterion(outputs, targets) loss.backward(retain_graph=True) # storing original gradient for later use grad_org = [] # grad_dict = {} for name, param in net.named_parameters(): grad_org.append(param.grad.reshape(1, -1)) # grad_dict[name] = param.grad.clone() grad_org = torch.cat(grad_org, 1) ###### now we have to compute the true fisher with torch.no_grad(): # gg = torch.nn.functional.softmax(outputs, dim=1) sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device) if args.trial == 'true': update_list, loss = optimal_JJT_v2(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient, super_opt=args.super_opt) else: update_list, loss = optimal_JJT(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient) # optimizer.zero_grad() # update_list, loss = optimal_JJT_fused(outputs, sampled_y, args.batch_size, damping=damp) optimizer.zero_grad() # last part of SMW formula grad_new = [] for name, param in net.named_parameters(): param.grad.copy_(update_list[name]) grad_new.append(param.grad.reshape(1, -1)) grad_new = torch.cat(grad_new, 1) # grad_new = grad_org store_io_(False) else: inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() # net.set_require_grad(True) outputs = net(inputs) damp = damping loss = criterion(outputs, targets) loss.backward() # storing original gradient for later use grad_org = [] # grad_dict = {} for name, param in net.named_parameters(): grad_org.append(param.grad.reshape(1, -1)) # grad_dict[name] = param.grad.clone() grad_org = torch.cat(grad_org, 1) ###### now we have to compute the true fisher # with torch.no_grad(): # gg = torch.nn.functional.softmax(outputs, dim=1) # sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device) all_modules = net.modules() for m in net.modules(): if hasattr(m, "NGD_inv"): grad = m.weight.grad if isinstance(m, nn.Linear): I = m.I G = m.G n = I.shape[0] NGD_inv = m.NGD_inv grad_prod = einsum("ni,oi->no", (I, grad)) grad_prod = einsum("no,no->n", (grad_prod, G)) v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() gv = einsum("n,no->no", (v, G)) gv = einsum("no,ni->oi", (gv, I)) gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) elif isinstance(m, nn.Conv2d): if hasattr(m, "AX"): if args.low_rank.lower() == 'true': ###### using low rank structure U = m.U S = m.S V = m.V NGD_inv = m.NGD_inv n = NGD_inv.shape[0] grad_reshape = grad.reshape(grad.shape[0], -1) grad_prod = V @ grad_reshape.t().reshape(-1, 1) grad_prod = torch.diag(S) @ grad_prod grad_prod = U @ grad_prod grad_prod = grad_prod.squeeze() v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() gv = U.t() @ v.unsqueeze(1) gv = torch.diag(S) @ gv gv = V.t() @ gv gv = gv.reshape(grad_reshape.shape[1], 
grad_reshape.shape[0]).t() gv = gv.view_as(grad) gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) else: AX = m.AX NGD_inv = m.NGD_inv n = AX.shape[0] grad_reshape = grad.reshape(grad.shape[0], -1) grad_prod = einsum("nkm,mk->n", (AX, grad_reshape)) v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() gv = einsum("nkm,n->mk", (AX, v)) gv = gv.view_as(grad) gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) elif hasattr(m, "I"): I = m.I if args.memory_efficient == 'true': I = unfold_func(m)(I) G = m.G n = I.shape[0] NGD_inv = m.NGD_inv grad_reshape = grad.reshape(grad.shape[0], -1) x1 = einsum("nkl,mk->nml", (I, grad_reshape)) grad_prod = einsum("nml,nml->n", (x1, G)) v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() gv = einsum("n,nml->nml", (v, G)) gv = einsum("nml,nkl->mk", (gv, I)) gv = gv.view_as(grad) gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): if args.batchnorm == 'true': dw = m.dw n = dw.shape[0] NGD_inv = m.NGD_inv grad_prod = einsum("ni,i->n", (dw, grad)) v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() gv = einsum("n,ni->i", (v, dw)) gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) # last part of SMW formula grad_new = [] for name, param in net.named_parameters(): grad_new.append(param.grad.reshape(1, -1)) grad_new = torch.cat(grad_new, 1) # grad_new = grad_org ##### do kl clip lr = lr_scheduler.get_last_lr()[0] # vg_sum = 0 # vg_sum += (grad_new * grad_org ).sum() # vg_sum = vg_sum * (lr ** 2) # nu = min(1.0, math.sqrt(args.kl_clip / vg_sum)) # for name, param in net.named_parameters(): # param.grad.mul_(nu) # optimizer.step() # manual optimizing: with torch.no_grad(): for name, param in net.named_parameters(): d_p = param.grad.data # print('=== step ===') # apply momentum # if args.momentum != 0: # buf[name].mul_(args.momentum).add_(d_p) # d_p.copy_(buf[name]) # apply weight decay if args.weight_decay != 0: d_p.add_(args.weight_decay, param.data) lr = lr_scheduler.get_last_lr()[0] param.data.add_(-lr, d_p) # print('d_p:', d_p.shape) # print(d_p) train_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % (tag, lr_scheduler.get_last_lr()[0], train_loss / (batch_idx + 1), 100. * correct / total, correct, total)) prog_bar.set_description(desc, refresh=True) if args.step_info == 'true' and (batch_idx % 50 == 0 or batch_idx == len(prog_bar) - 1): step_saved_time = time.time() - step_st_time epoch_time += step_saved_time test_acc, test_loss = test(epoch) TRAIN_INFO['train_acc'].append(float("{:.4f}".format(100. * correct / total))) TRAIN_INFO['test_acc'].append(float("{:.4f}".format(test_acc))) TRAIN_INFO['train_loss'].append(float("{:.4f}".format(train_loss/(batch_idx + 1)))) TRAIN_INFO['test_loss'].append(float("{:.4f}".format(test_loss))) TRAIN_INFO['total_time'].append(float("{:.4f}".format(step_saved_time))) if args.debug_mem == 'true': TRAIN_INFO['memory'].append(torch.cuda.memory_reserved()) step_st_time = time.time() net.train() writer.add_scalar('train/loss', train_loss/(batch_idx + 1), epoch) writer.add_scalar('train/acc', 100. * correct / total, epoch) acc = 100. 
* correct / total train_loss = train_loss/(batch_idx + 1) if args.step_info == 'true': TRAIN_INFO['epoch_time'].append(float("{:.4f}".format(epoch_time))) # save diagonal blocks of exact Fisher inverse or its approximations if args.save_inv == 'true': all_modules = net.modules() count = 0 start, end = 0, 0 if optim_name == 'ngd': for m in all_modules: if m.__class__.__name__ == 'Linear': with torch.no_grad(): I = m.I G = m.G J = torch.einsum('ni,no->nio', I, G) J = J.reshape(J.size(0), -1) JTDJ = torch.matmul(J.t(), torch.matmul(m.NGD_inv, J)) / args.batch_size with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f: np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy()) count += 1 elif m.__class__.__name__ == 'Conv2d': with torch.no_grad(): AX = m.AX AX = AX.reshape(AX.size(0), -1) JTDJ = torch.matmul(AX.t(), torch.matmul(m.NGD_inv, AX)) / args.batch_size with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f: np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy()) count += 1 elif optim_name == 'exact_ngd': for m in all_modules: if m.__class__.__name__ in ['Conv2d', 'Linear']: with open('exact/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f: end = start + m.weight.data.reshape(1, -1).size(1) np.save(f, inv[start:end,start:end].cpu().numpy()) start = end + m.bias.data.size(0) count += 1 elif optim_name == 'kfac': for m in all_modules: if m.__class__.__name__ in ['Conv2d', 'Linear']: with open('kfac/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f: G = optimizer.m_gg[m] A = optimizer.m_aa[m] H_g = torch.linalg.inv(G + math.sqrt(damping) * torch.eye(G.size(0)).to(G.device)) H_a = torch.linalg.inv(A + math.sqrt(damping) * torch.eye(A.size(0)).to(A.device)) end = m.weight.data.reshape(1, -1).size(1) kfac_inv = torch.kron(H_a, H_g)[:end,:end] np.save(f, kfac_inv.cpu().numpy()) count += 1 return acc, train_loss
def forward(self, positions):
    sinusoid_inp = torch.einsum("i,j->ij", positions.float(), self.inv_freq)
    emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
    return emb[None, :, :]
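# A minimal self-contained sketch of a positional-embedding module this forward
# could belong to. The inv_freq construction follows the usual Transformer-XL
# convention and is an assumption here, as is the parameter name demb:
import torch
import torch.nn as nn

class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, positions):
        # outer product of positions and inverse frequencies
        sinusoid_inp = torch.einsum("i,j->ij", positions.float(), self.inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        return emb[None, :, :]

pe = PositionalEmbedding(demb=32)
emb = pe(torch.arange(10))
assert emb.shape == (1, 10, 32)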
def model(self, data): ''' Define the parameters ''' n_ind = data['n_ind'] n_trt = data['n_trt'] n_tms = data['n_tms'] n_mrk = data['n_mrk'] n_prs = n_trt * n_tms * n_mrk plt_ind = pyro.plate('individuals', n_ind, dim=-3) plt_trt = pyro.plate('treatments', n_trt, dim=-2) plt_tms = pyro.plate('times', n_tms, dim=-1) pars = {} # covariance factors with plt_tms: # learning dt time step sizes # if k(t1,t2) is independent of time, can instead learn scales and variances for RBF kernels that use data['time_vals'] pars['dt0'] = pyro.sample('dt0', dist.Normal(0, 1)) pars['dt1'] = pyro.sample('dt1', dist.Normal(0, 1)) pars['theta_trt0'] = pyro.sample('theta_trt0', dist.HalfCauchy(torch.ones(n_trt))) pars['theta_mrk0'] = pyro.sample('theta_mrk0', dist.HalfCauchy(torch.ones(n_mrk))) pars['theta_trt1'] = pyro.sample('theta_trt1', dist.HalfCauchy(torch.ones(n_trt))) pars['L_omega_trt1'] = pyro.sample( 'L_omega_trt1', dist.LKJCorrCholesky(n_trt, torch.ones(1))) pars['theta_mrk1'] = pyro.sample('theta_mrk1', dist.HalfCauchy(torch.ones(n_mrk))) pars['L_omega_mrk1'] = pyro.sample( 'L_omega_mrk1', dist.LKJCorrCholesky(n_mrk, torch.ones(1))) times0 = fun.pad(torch.cumsum(pars['dt0'].exp().log1p(), 0), (1, 0), value=0)[:-1].unsqueeze(1) times1 = fun.pad(torch.cumsum(pars['dt1'].exp().log1p(), 0), (1, 0), value=0)[:-1].unsqueeze(1) cov_t0 = (-torch.cdist(times0, times0)).exp() cov_t1 = (-torch.cdist(times1, times1)).exp() cov_i0 = pars['theta_trt0'].diag() L_Omega_trt = torch.mm(torch.diag(pars['theta_trt1'].sqrt()), pars['L_omega_trt1']) cov_i1 = L_Omega_trt.mm(L_Omega_trt.t()) cov_m0 = pars['theta_mrk0'].diag() L_Omega_mrk = torch.mm(torch.diag(pars['theta_mrk1'].sqrt()), pars['L_omega_mrk1']) cov_m1 = L_Omega_mrk.mm(L_Omega_mrk.t()) # kronecker product of the factors cov_itm0 = torch.einsum('ij,tu,mn->itmjun', [cov_i0, cov_t0, cov_m0]).view(n_prs, n_prs) cov_itm1 = torch.einsum('ij,tu,mn->itmjun', [cov_i1, cov_t1, cov_m1]).view(n_prs, n_prs) # global and individual level params of each marker, treatment, and time point pars['glb'] = pyro.sample( 'glb', dist.MultivariateNormal(torch.zeros(n_prs), cov_itm0)) with plt_ind: pars['ind'] = pyro.sample( 'ind', dist.MultivariateNormal(torch.zeros(n_prs), cov_itm1)) # observation noise, time series bias and scale pars['noise_scale'] = pyro.sample('noise_scale', dist.HalfCauchy(torch.ones(n_mrk))) pars['t0_scale'] = pyro.sample('t0_scale', dist.HalfCauchy(torch.ones(n_mrk))) with plt_ind: pars['t0'] = pyro.sample( 't0', dist.MultivariateNormal(torch.zeros(n_mrk), pars['t0_scale'].diag())) with plt_trt, plt_tms: pars['noise'] = pyro.sample( 'noise', dist.MultivariateNormal(torch.zeros(n_mrk), pars['noise_scale'].diag())) # likelihood of the data distr = self.get_distr(data, pars) pyro.sample('obs', distr, obs=data['Y'])
def conditional(self, input, given):
    return (torch.einsum('ik,lk->il', input, self.weight[given, :, :])
            + self.bias[given, :].unsqueeze(0))
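# A minimal sketch of the conditional linear map above: `given` selects one
# weight matrix and bias vector out of a bank (names and sizes assumed):
import torch

num_conditions, out_features, in_features = 3, 8, 16
weight = torch.randn(num_conditions, out_features, in_features)
bias = torch.randn(num_conditions, out_features)

x = torch.randn(5, in_features)
given = 1
y = torch.einsum('ik,lk->il', x, weight[given, :, :]) + bias[given, :].unsqueeze(0)
assert torch.allclose(y, x @ weight[given].t() + bias[given], atol=1e-5)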
def backward(ctx, grad_kernel): F, Y, R, norm_coef = ctx.saved_tensors batch, a, b = ctx.batch, ctx.a, ctx.b grad_F = grad_Y = grad_R = None if ctx.needs_input_grad[0]: grad_F = grad_kernel.new_zeros( *ctx.F_shape) # [batch, b, l_in * mul_in * m_in] if ctx.needs_input_grad[1]: grad_Y = grad_kernel.new_zeros( *ctx.Y_shape) # [l_filter * m_filter, batch, a, b] if ctx.needs_input_grad[2]: grad_R = grad_kernel.new_zeros( *ctx.R_shape ) # [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter] begin_R = 0 begin_out = 0 for i, (mul_out, l_out, p_out) in enumerate(ctx.Rs_out): s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1)) begin_out += mul_out * (2 * l_out + 1) begin_in = 0 for j, (mul_in, l_in, p_in) in enumerate(ctx.Rs_in): s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1)) begin_in += mul_in * (2 * l_in + 1) l_filters = ctx.get_l_filters(l_in, p_in, l_out, p_out) if not l_filters: continue n = mul_out * mul_in * len(l_filters) if (grad_Y is not None) or (grad_F is not None): sub_R = R[:, :, :, begin_R:begin_R + n].contiguous().view( batch, a, b, mul_out, mul_in, -1) # [batch, a, b, mul_out, mul_in, l_filter] if grad_R is not None: sub_grad_R = grad_R[:, :, :, begin_R:begin_R + n].contiguous( ).view(batch, a, b, mul_out, mul_in, -1) # [batch, a, b, mul_out, mul_in, l_filter] if grad_F is not None: sub_grad_F = grad_F[:, :, s_in].contiguous().view( batch, b, mul_in, 2 * l_in + 1) # [batch, b, mul_in, 2 * l_in + 1] if (grad_Y is not None) or (grad_R is not None): sub_F = F[..., s_in].view(batch, b, mul_in, 2 * l_in + 1) grad_K = grad_kernel[:, :, s_out].view(batch, a, mul_out, 2 * l_out + 1) sub_norm_coef = norm_coef[i, j] # [batch, a, b] for k, l_filter in enumerate(l_filters): tmp = sum(2 * l + 1 for l in ctx.set_of_l_filters if l < l_filter) C = o3.clebsch_gordan(l_out, l_in, l_filter, cached=True, like=grad_kernel) # [m_out, m_in, m] if (grad_F is not None) or (grad_R is not None): sub_Y = Y[tmp:tmp + 2 * l_filter + 1, ...] # [m, batch, a, b] if grad_F is not None: sub_grad_F += torch.einsum( "zaui,ijk,kzab,zabuv,zab->zbvj", grad_K, C, sub_Y, sub_R[..., k], sub_norm_coef) # [batch, b, mul_in, 2 * l_in + 1 if grad_Y is not None: grad_Y[tmp:tmp + 2 * l_filter + 1, ...] += torch.einsum( "zaui,ijk,zabuv,zab,zbvj->kzab", grad_K, C, sub_R[..., k], sub_norm_coef, sub_F) # [m, batch, a, b] if grad_R is not None: sub_grad_R[..., k] = torch.einsum( "zaui,ijk,kzab,zab,zbvj->zabuv", grad_K, C, sub_Y, sub_norm_coef, sub_F) # [batch, a, b, mul_out, mul_in] if grad_F is not None: grad_F[:, :, s_in] = sub_grad_F.view(batch, b, mul_in * (2 * l_in + 1)) if grad_R is not None: grad_R[..., begin_R:begin_R + n] += sub_grad_R.view( batch, a, b, -1) begin_R += n return grad_F, grad_Y, grad_R, None, None, None, None, None
def apply_TM_1sO_2(state, env, edge, op=None, verbosity=0): r""" :param state: underlying 1-site C4v symmetric wavefunction :param env: C4v symmetric environment corresponding to ``state`` :param edge: tensor of dimensions :math:`\chi \times (D^2)^2 \times \chi` :param op: two-site operator to be inserted within the two-site transfer matrix :param verbosity: logging verbosity :type state: IPEPS_C4V :type env: ENV_C4V :type edge: torch.tensor :type op: torch.tensor :type verbosity: int :return: ``edge`` with a single instance of the transfer matrix applied The resulting tensor has an identical index structure as the original ``edge`` :rtype: torch.tensor Applies a single instance of the two-site "transfer matrix" to the ``edge`` tensor by contracting the following network, or its corresponding rotation depending on the ``direction``:: -----T---------- | | edge--(a^+ o1 a)-- | | | |----(a^+ o2 a)-- | | -----T---------- The two-site operator is first decomposed into a simple MPO o1--o2 (TODO case where op comes with an extra MPO index):: s1' s2' s1' s2' | op | = |o1|-----|o2| s1 s2 s1 s2 where the physical indices `s` and `s'` of the on-site tensor :math:`a` and it's hermitian conjugate :math:`a^\dagger` are contracted with identity :math:`\delta_{s,s'}` or ``o1``, ``o2``. """ # TODO stronger verification op_1, op_2 = None, None if op is not None: if len(op.size()) == 4: # pre-process ``op`` # TODO possibly truncate/compress according to the vanishingly small singular values dims_op = op.size() op_mat = op.permute(0, 2, 1, 3).contiguous().reshape( dims_op[0]**2, dims_op[0]**2) op_1, s, op_2 = torch.svd(op_mat) op_1 = op_1.reshape(dims_op[0], dims_op[0], s.size()[0]) op_2 = torch.einsum('i,ij->ij', s, op_2.t()).reshape(s.size()[0], dims_op[0], dims_op[0]) op_2 = op_2.permute(1, 2, 0).contiguous() else: raise ValueError(f"Invalid op: rank {op.size()}") # Four basic cases of passed op def get_aXa(a, op): # a - on-site tensor # op - operator dims_a = a.size() dims_op = None if op is None else op.size() if op is None: # identity A= torch.einsum('nefgh,nabcd->eafbgchd',a,a).contiguous()\ .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2) elif len(dims_op) == 2: # one-site operator A= torch.einsum('mefgh,mn,nabcd->eafbgchd',a,op,a).contiguous()\ .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2) elif len(dims_op) == 3: # edge operators of some MPO within the transfer matrix # # 0 0 # | | # op--2 ... or ... 2--op # | | # 1 1 # # assume the last index of the op is the MPO dimension. 
# It will become the last index of the resulting edge A= torch.einsum('mefgh,mnl,nabcd->eafbgchdl',a,op,a).contiguous()\ .view(dims_a[1]**2, dims_a[2]**2, dims_a[3]**2, dims_a[4]**2, -1) if verbosity > 0: print(f"aXa {A.size()}") return A a = next(iter(state.sites.values())) T = env.T[env.keyT] # Assume index structure of ``edge`` tensor to be as follows # # -- 0 # edge |-- 1 # |---2 # -- 3 # # ----0 0--T--1->0 # | 2->1 # edge--1->2 # | # ----2->3 # | # ----3->4 E = torch.tensordot(T, edge, ([0], [0])) if verbosity > 0: print("E=edgeT " + str(E.size())) # TODO - more efficent contraction with uncontracted-double-layer on-site tensor # Possibly reshape indices 1,2 of E, which are to be contracted with # on-site tensor and contract bra,ket in two steps instead of creating # double layer tensor # / # --A-- # /|s # X # s'|/ # --A-- # / # # where X is Id or op A = get_aXa(a, op_1) # ---------T--0 # | 1 # | 0 # edge--2 1--A--3->4 # | 3<-2 \ # ----3->1 (4->5) # | # ----4->2 E = torch.tensordot(E, A, ([1, 2], [0, 1])) if verbosity > 0: print("E=edgeTA " + str(E.size())) A = get_aXa(a, op_2) # ---------T--0 # | | # edge-------A--4->2 # | | \ # | 3 (5) # | 0 (4) # | | / # ----1 1--A--2->3 # | 3->4 # ----2->1 E = torch.tensordot(E,A,([1,3],[1,0])) if op is None else \ torch.tensordot(E,A,([1,3,5],[1,0,4])) if verbosity > 0: print("E=edgeTAA " + str(E.size())) # ---------T--0 # | | # edge-------A--2->1 # | | # ---------A--3->2 # | 3 # | 2 # ----1 0--T2--1->3 E = torch.tensordot(E, T, ([1, 3], [0, 2])) if verbosity > 0: print("E=edgeTAAT " + str(E.size())) return E
def kernel_conv_fn_forward(F, Y, R, norm_coef, Rs_in, Rs_out, get_l_filters, set_of_l_filters):
    """
    :param F: tensor [batch, b, l_in * mul_in * m_in]
    :param Y: tensor [l_filter * m_filter, batch, a, b]
    :param R: tensor [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]
    :param norm_coef: tensor [l_out, l_in, batch, a, b]
    :return: tensor [batch, a, l_out * mul_out * m_out]
    """
    batch, a, b = Y.shape[1:]
    n_in = rs.dim(Rs_in)
    n_out = rs.dim(Rs_out)

    kernel_conv = Y.new_zeros(batch, a, n_out)

    # note: for the normalization we assume that the variance of R[i] is one
    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(Rs_out):
        s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
        begin_out += mul_out * (2 * l_out + 1)

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(Rs_in):
            s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
            begin_in += mul_in * (2 * l_in + 1)

            l_filters = get_l_filters(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # extract the subset of `R` that corresponds to the couple (l_out, l_in)
            n = mul_out * mul_in * len(l_filters)
            sub_R = R[:, :, :, begin_R:begin_R + n].contiguous().view(
                batch, a, b, mul_out, mul_in, -1)  # [batch, a, b, mul_out, mul_in, l_filter]
            begin_R += n

            sub_norm_coef = norm_coef[i, j]  # [batch, a, b]

            K = 0
            for k, l_filter in enumerate(l_filters):
                offset = sum(2 * l + 1 for l in set_of_l_filters if l < l_filter)
                sub_Y = Y[offset:offset + 2 * l_filter + 1, ...]  # [m, batch, a, b]

                C = o3.clebsch_gordan(l_out, l_in, l_filter, cached=True,
                                      like=kernel_conv)  # [m_out, m_in, m]

                K += torch.einsum("ijk,kzab,zabuv,zab,zbvj->zaui",
                                  C, sub_Y, sub_R[..., k], sub_norm_coef,
                                  F[..., s_in].view(batch, b, mul_in, -1))  # [batch, a, mul_out, m_out]

            # K stays an int (0) only if the filter loop did not run
            if not isinstance(K, int):
                kernel_conv[:, :, s_out] += K.view(batch, a, -1)

    return kernel_conv
def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output, **kwargs): """ Returns a new PredictionStrategy that incorporates the specified inputs and targets as new training data. This method is primary responsible for updating the mean and covariance caches. To add fantasy data to a GP model, use the :meth:`~gpytorch.models.ExactGP.get_fantasy_model` method. Args: - :attr:`inputs` (Tensor `b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`): Locations of fantasy observations. - :attr:`targets` (Tensor `b1 x ... x bk x m` or `f x b1 x ... x bk x m`): Labels of fantasy observations. - :attr:`full_inputs` (Tensor `b1 x ... x bk x n+m x d` or `f x b1 x ... x bk x n+m x d`): Training data concatenated with fantasy inputs - :attr:`full_targets` (Tensor `b1 x ... x bk x n+m` or `f x b1 x ... x bk x n+m`): Training labels concatenated with fantasy labels. - :attr:`full_output` (:class:`gpytorch.distributions.MultivariateNormal`): Prior called on full_inputs Returns: - :class:`DefaultPredictionStrategy` A `DefaultPredictionStrategy` model with `n + m` training examples, where the `m` fantasy examples have been added and all test-time caches have been updated. """ full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix batch_shape = full_inputs[0].shape[:-2] full_mean = full_mean.view(*batch_shape, -1) num_train = self.num_train # Evaluate fant x train and fant x fant covariance matrices, leave train x train unevaluated. fant_fant_covar = full_covar[..., num_train:, num_train:] fant_mean = full_mean[..., num_train:] mvn = self.train_prior_dist.__class__(fant_mean, fant_fant_covar) fant_likelihood = self.likelihood.get_fantasy_likelihood(**kwargs) mvn_obs = fant_likelihood(mvn, inputs, **kwargs) fant_fant_covar = mvn_obs.covariance_matrix fant_train_covar = delazify(full_covar[..., num_train:, :num_train]) self.fantasy_inputs = inputs self.fantasy_targets = targets r""" Compute a new mean cache given the old mean cache. We have \alpha = K^{-1}y, and we want to solve [K U; U' S][a; b] = [y; y_f], where U' is fant_train_covar, S is fant_fant_covar, and y_f is (targets - fant_mean) To do this, we solve the bordered linear system of equations for [a; b]: AQ = U # Q = fant_solve [S - U'Q]b = y_f - U'\alpha ==> b = [S - U'Q]^{-1}(y_f - U'\alpha) a = \alpha - Qb """ # Get cached K inverse decomp. (or compute if we somehow don't already have the covariance cache) K_inverse = self.lik_train_train_covar.root_inv_decomposition() fant_solve = K_inverse.matmul(fant_train_covar.transpose(-2, -1)) # Solve for "b", the lower portion of the *new* \\alpha corresponding to the fantasy points. schur_complement = fant_fant_covar - fant_train_covar.matmul( fant_solve) # we'd like to use a less hacky approach for the following, but einsum can be much faster than # than unsqueezing/squeezing here (esp. 
in backward passes), unfortunately it currently has some # issues with broadcasting: https://github.com/pytorch/pytorch/issues/15671 prefix = string.ascii_lowercase[:max( fant_train_covar.dim() - self.mean_cache.dim() - 1, 0)] ftcm = torch.einsum(prefix + "...yz,...z->" + prefix + "...y", [fant_train_covar, self.mean_cache]) small_system_rhs = targets - fant_mean - ftcm small_system_rhs = small_system_rhs.unsqueeze(-1) # Schur complement of an SPD matrix is guaranteed to be positive definite schur_cholesky = psd_safe_cholesky(schur_complement) fant_cache_lower = torch.cholesky_solve(small_system_rhs, schur_cholesky) # Get "a", the new upper portion of the cache corresponding to the old training points. fant_cache_upper = self.mean_cache.unsqueeze(-1) - fant_solve.matmul( fant_cache_lower) fant_cache_upper = fant_cache_upper.squeeze(-1) fant_cache_lower = fant_cache_lower.squeeze(-1) # New mean cache. fant_mean_cache = torch.cat((fant_cache_upper, fant_cache_lower), dim=-1) # now update the root and root inverse new_lt = self.lik_train_train_covar.cat_rows(fant_train_covar, fant_fant_covar) new_root = new_lt.root_decomposition().root.evaluate() new_covar_cache = new_lt.root_inv_decomposition().root.evaluate() # Expand inputs accordingly if necessary (for fantasies at the same points) if full_inputs[0].dim() <= full_targets.dim(): fant_batch_shape = full_targets.shape[:1] n_batch = len(full_mean.shape[:-1]) repeat_shape = fant_batch_shape + torch.Size([1] * n_batch) full_inputs = [ fi.expand(fant_batch_shape + fi.shape) for fi in full_inputs ] full_mean = full_mean.expand(fant_batch_shape + full_mean.shape) full_covar = BatchRepeatLazyTensor(full_covar, repeat_shape) new_root = BatchRepeatLazyTensor(NonLazyTensor(new_root), repeat_shape) # no need to repeat the covar cache, broadcasting will do the right thing # Create new DefaultPredictionStrategy object fant_strat = self.__class__( train_inputs=full_inputs, train_prior_dist=self.train_prior_dist.__class__( full_mean, full_covar), train_labels=full_targets, likelihood=fant_likelihood, root=new_root, inv_root=new_covar_cache, ) add_to_cache(fant_strat, "mean_cache", fant_mean_cache) add_to_cache(fant_strat, "covar_cache", new_covar_cache) return fant_strat
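A minimal sketch of the einsum prefix trick used above, with made-up shapes: the extra leading letters let the matrix carry fantasy batch dimensions that the cached vector lacks, and the result matches the plain unsqueeze/squeeze matmul.

import string
import torch
fant_train_covar = torch.randn(7, 2, 4, 10)   # hypothetical f x b x m x n
mean_cache = torch.randn(2, 10)               # hypothetical b x n
prefix = string.ascii_lowercase[:max(fant_train_covar.dim() - mean_cache.dim() - 1, 0)]
ftcm = torch.einsum(prefix + "...yz,...z->" + prefix + "...y", fant_train_covar, mean_cache)
reference = (fant_train_covar @ mean_cache.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(ftcm, reference, atol=1e-5)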
def forward(self, data): """Run SuperGlue on a pair of keypoints and descriptors""" desc0, desc1 = data['descriptors0'], data['descriptors1'] kpts0, kpts1 = data['keypoints0'], data['keypoints1'] if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1] return { 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int), 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int), 'matching_scores0': kpts0.new_zeros(shape0), 'matching_scores1': kpts1.new_zeros(shape1), } # Keypoint normalization. # kpts0 = normalize_keypoints(kpts0, data['image0'].shape) # kpts1 = normalize_keypoints(kpts1, data['image1'].shape) # Keypoint MLP encoder. # desc0 = desc0 + self.kenc(kpts0, data['scores0']) # desc1 = desc1 + self.kenc(kpts1, data['scores1']) desc0 = desc0 + self.kenc(kpts0) desc1 = desc1 + self.kenc(kpts1) # Multi-layer Transformer network. desc0, desc1 = self.gnn(desc0, desc1) # Final MLP projection. mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) # Compute matching descriptor distance. scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) scores = scores / self.config['descriptor_dim']**.5 # Run the optimal transport. scores = log_optimal_transport( scores, self.bin_score, iters=self.config['sinkhorn_iterations']) # construct the loss over the scores # loss = compute_loss(scores, matches_gt) # scores: 1 * (m+1) * (n+1), matches_gt: 1 * (m+1) * (n+1) # loss = -scores.log() * matches_gt # Get the matches with score above "match_threshold". max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) indices0, indices1 = max0.indices, max1.indices mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) zero = scores.new_tensor(0) mscores0 = torch.where(mutual0, max0.values.exp(), zero) # mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) valid0 = mutual0 & (mscores0 > self.config['match_threshold']) valid1 = mutual1 & valid0.gather(1, indices1) indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) # hard-coded top-k value (k=5) top_k_matches0 = scores[0, :-1, :-1].topk(5, dim=0).indices return { 'matches0': indices0, # use -1 for invalid match 'matches1': indices1, # use -1 for invalid match # 'matching_scores0': mscores0, # 'matching_scores1': mscores1, 'scores': scores, 'top_k_matches1': top_k_matches0 }
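The score computation above is just a batched inner product between every descriptor of image 0 and image 1; a toy check with illustrative sizes (not the real SuperGlue dimensions):

import torch
b, d, n, m = 2, 256, 100, 120   # batch, descriptor dim, keypoints in image 0 / image 1 (made up)
mdesc0, mdesc1 = torch.randn(b, d, n), torch.randn(b, d, m)
scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) / d ** 0.5
assert torch.allclose(scores, mdesc0.transpose(1, 2) @ mdesc1 / d ** 0.5, atol=1e-4)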
def forward(self, key, query, mask, cache=False, boundary_leftmost=0, boundary_rightmost=100000): """Compute chunkwise energy. Args: key (FloatTensor): `[B, klen, kdim]` query (FloatTensor): `[B, qlen, qdim]` mask (ByteTensor): `[B, qlen, klen]` cache (bool): cache key and mask boundary_leftmost (int): leftmost boundary offset boundary_rightmost (int): rightmost boundary offset Returns: e (FloatTensor): `[B, H_ca, qlen, klen]` """ klen, kdim = key.size()[1:] bs, qlen = query.size()[:2] # Pre-computation of encoder-side features for computing scores if self.key is None or not cache: self.key = self.w_key(key).view(-1, klen, self.n_heads, self.d_k) # `[B, klen, H_ca, d_k]` if mask is not None: self.mask = mask.unsqueeze(3).repeat( [1, 1, 1, self.n_heads]) # `[B, qlen, klen, H_ca]` mask_size = (bs, qlen, klen, self.n_heads) assert self.mask.size() == mask_size, (self.mask.size(), mask_size) else: self.mask = None k = self.key if k.size(0) != bs: # for inference k = k[0:1].repeat([bs, 1, 1, 1]) klen = k.size(1) q = self.w_query(query).view(-1, qlen, self.n_heads, self.d_k) # `[B, qlen, H_ca, d_k]` m = self.mask # Truncate encoder memories for efficient DECODING if boundary_leftmost > 0 or (0 < boundary_rightmost < klen): k = k[:, boundary_leftmost:boundary_rightmost + 1] klen = k.size(1) if m is not None: m = m[:, :, boundary_leftmost:boundary_rightmost + 1] if self.atype == 'scaled_dot': e = torch.einsum("bihd,bjhd->bijh", (q, k)) / self.scale elif self.atype == 'add': e = self.v( torch.relu(k[:, None] + q[:, :, None]).view( bs, qlen, klen, -1)) # e: `[B, qlen, klen, H_ca]` if m is not None: NEG_INF = float( np.finfo(torch.tensor(0, dtype=e.dtype).numpy().dtype).min) e = e.masked_fill_(m == 0, NEG_INF) e = e.permute(0, 3, 1, 2) # `[B, H_ca, qlen, klen]` return e
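For reference, the scaled-dot branch contracts the per-head feature dimension and keeps the (query, key, head) axes; a shape-only sketch with invented sizes:

import torch
B, qlen, klen, H, d_k = 2, 5, 7, 4, 8   # illustrative sizes
q = torch.randn(B, qlen, H, d_k)
k = torch.randn(B, klen, H, d_k)
e = torch.einsum("bihd,bjhd->bijh", q, k) / d_k ** 0.5
assert e.shape == (B, qlen, klen, H)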
def forward(self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None): outputs = () start_logits = self.start_logits(hidden_states, p_mask=p_mask) if start_positions is not None and end_positions is not None: # If we are on multi-GPU, let's remove the dimension added by batch splitting for x in (start_positions, end_positions, cls_index, is_impossible): if x is not None and x.dim() > 1: x.squeeze_(-1) # during training, compute the end logits based on the ground truth of the start position end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) loss_fct = CrossEntropyLoss() start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if cls_index is not None and is_impossible is not None: # Predict answerability from the representation of CLS and START cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) loss_fct_cls = nn.BCEWithLogitsLoss() cls_loss = loss_fct_cls(cls_logits, is_impossible) # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss total_loss += cls_loss * 0.5 outputs = (total_loss, ) + outputs else: # during inference, compute the end logits based on beam search bsz, slen, hsz = hidden_states.size() start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) start_top_log_probs, start_top_index = torch.topk( start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top) start_top_index_exp = start_top_index.unsqueeze(-1).expand( -1, -1, hsz) # shape (bsz, start_n_top, hsz) start_states = torch.gather( hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) start_states = start_states.unsqueeze(1).expand( -1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( start_states) # shape (bsz, slen, start_n_top, hsz) p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) end_top_log_probs, end_top_index = torch.topk( end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top) end_top_log_probs = end_top_log_probs.view( -1, self.start_n_top * self.end_n_top) end_top_index = end_top_index.view( -1, self.start_n_top * self.end_n_top) start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits # or (if labels are provided) (total_loss,) return outputs
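In the inference branch above, "blh,bl->bh" forms an expectation of the hidden states under the start-position distribution; a small sketch with toy shapes confirms it equals the explicit weighted sum:

import torch
bsz, slen, hsz = 2, 10, 16              # toy sizes
hidden_states = torch.randn(bsz, slen, hsz)
start_probs = torch.softmax(torch.randn(bsz, slen), dim=-1)
expected = torch.einsum("blh,bl->bh", hidden_states, start_probs)
assert torch.allclose(expected, (hidden_states * start_probs.unsqueeze(-1)).sum(1), atol=1e-5)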
def projx(self, x: torch.Tensor) -> torch.Tensor: U, _, V = linalg.svd(x, full_matrices=False) return torch.einsum("...ik,...kj->...ij", U, V)
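A small usage sketch of this projection, here built with torch.linalg.svd (which returns V^T directly): U V^T is the nearest matrix with orthonormal columns, so the projected result satisfies Q^T Q = I.

import torch
x = torch.randn(5, 3)
U, _, Vh = torch.linalg.svd(x, full_matrices=False)
q = torch.einsum("...ik,...kj->...ij", U, Vh)
assert torch.allclose(q.T @ q, torch.eye(3), atol=1e-5)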
# adds a dimension of size 1, just like unsqueeze points = points[None] print(points) # ------------------------------------------------------------------- # 3.4 Named tensors img_t = torch.randn(3, 5, 5) # shape [channels, rows, columns] weights = torch.tensor([0.2126, 0.7152, 0.0722]) batch_t = torch.randn(2, 3, 5, 5) # shape [batch, channels, rows, columns] img_gray_naive = img_t.mean(-3) batch_gray_naive = batch_t.mean(-3) print(f"shape_1: {img_gray_naive.shape}, shape_2: {batch_gray_naive.shape}") unsqueezed_weights = weights.unsqueeze(-1).unsqueeze_(-1) img_weights = (img_t * unsqueezed_weights) batch_weights = (batch_t * unsqueezed_weights) img_gray_weighted = img_weights.sum(-3) batch_gray_weighted = batch_weights.sum(-3) print(f"{batch_weights.shape}, {batch_t.shape}, {unsqueezed_weights.shape}") img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights) batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights) print(batch_gray_weighted_fancy.shape)
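A quick check that the einsum form matches the explicit broadcast-and-sum version used earlier in the snippet:

import torch
batch_t = torch.randn(2, 3, 5, 5)
weights = torch.tensor([0.2126, 0.7152, 0.0722])
fancy = torch.einsum('...chw,c->...hw', batch_t, weights)
manual = (batch_t * weights.view(1, 3, 1, 1)).sum(-3)
assert torch.allclose(fancy, manual, atol=1e-6)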
def similarity(x, means): return torch.einsum('bhld,hcd->bhlc', x, means)
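Shape illustration for the similarity above (per-head dot products between tokens and cluster means); all sizes are made up:

import torch
b, h, l, d, c = 2, 4, 10, 16, 8
x, means = torch.randn(b, h, l, d), torch.randn(h, c, d)
sim = torch.einsum('bhld,hcd->bhlc', x, means)
assert sim.shape == (b, h, l, c)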
def forward(self, qk, v, query_len=None, input_mask=None): batch_size, seqlen, dim = qk.shape query_len = default(query_len, seqlen) device = qk.device n_buckets = seqlen // self.bucket_size buckets = self.hash_vectors(n_buckets, qk) # We use the same vector as both a query and a key. assert int(buckets.shape[1]) == self.n_hashes * seqlen ticker = torch.arange(self.n_hashes * seqlen, device=device).unsqueeze(0).expand_as(buckets) buckets_and_t = seqlen * buckets + (ticker % seqlen) buckets_and_t = buckets_and_t.detach() # Hash-based sort ("s" at the start of variable names means "sorted") sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1) _, undo_sort = sort_key_val(sticker, ticker, dim=-1) del ticker sbuckets_and_t = sbuckets_and_t.detach() sticker = sticker.detach() undo_sort = undo_sort.detach() st = (sticker % seqlen) sqk = batched_index_select(qk, st) sv = batched_index_select(v, st) # Split off a "bin" axis so that attention only occurs within chunks. chunk_size = self.n_hashes * n_buckets bq_t = bkv_t = torch.reshape(st, (batch_size, chunk_size, -1)) bqk = torch.reshape(sqk, (batch_size, chunk_size, -1, dim)) bv = torch.reshape(sv, (batch_size, chunk_size, -1, dim)) # Hashing operates on unit-length vectors. Unnormalized query vectors are # fine because they effectively provide a learnable temperature for the # attention softmax, but normalizing keys is needed so that similarity for # the purposes of attention correctly corresponds to hash locality. bq = bqk bk = F.normalize(bqk, p=2, dim=-1).type(bq.type()) # Allow each chunk to attend within itself, and also one chunk back. Chunk # boundaries might occur in the middle of a sequence of items from the # same bucket, so this increases the chances of attending to relevant items. def look_one_back(x): x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1) return torch.cat([x, x_extra], dim=2) bk = look_one_back(bk) bv = look_one_back(bv) bkv_t = look_one_back(bkv_t) # Dot-product attention. dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (dim**-0.5) masked_value = max_neg_value(dots) # Input mask for padding in variable-length sequences if input_mask is not None: input_mask = F.pad(input_mask, (0, seqlen - input_mask.shape[1]), 'constant', True) mq = input_mask.gather(1, st).reshape((batch_size, chunk_size, -1)) mkv = look_one_back(mq) mask = mq[:, :, :, None] * mkv[:, :, None, :] dots.masked_fill_(~mask, masked_value) del mask # Causal masking if self.causal: mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :].clamp(max=query_len - 1) dots.masked_fill_(mask, masked_value) del mask # Mask out attention to self except when no other targets are available. self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :] dots.masked_fill_(self_mask, TOKEN_SELF_ATTN_VALUE) del self_mask # Mask out attention to other hash buckets. if not self._attend_across_buckets: bq_buckets = bkv_buckets = torch.reshape( sbuckets_and_t // seqlen, (batch_size, chunk_size, -1)) bkv_buckets = look_one_back(bkv_buckets) bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :] dots.masked_fill_(bucket_mask, masked_value) del bucket_mask # Don't double-count query-key pairs across multiple rounds of hashing. # There are two possible strategies here. (1) The default is to count how # many times a query-key pair is repeated, and to lower its log-prob # correspondingly at each repetition. (2) When hard_k is set, the code # instead masks all but the first occurrence of each query-key pair.
if not self._allow_duplicate_attention: locs1 = undo_sort // bq_t.shape[-1] locs2 = (locs1 + 1) % chunk_size if not self._attend_across_buckets: locs1 = buckets * chunk_size + locs1 locs2 = buckets * chunk_size + locs2 locs = torch.cat([ torch.reshape(locs1, (batch_size, self.n_hashes, seqlen)), torch.reshape(locs2, (batch_size, self.n_hashes, seqlen)), ], 1).permute((0, 2, 1)) slocs = batched_index_select(locs, st) b_locs = torch.reshape( slocs, (batch_size, chunk_size, -1, 2 * self.n_hashes)) b_locs1 = b_locs[:, :, :, None, :self.n_hashes] bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, self.n_hashes)) bq_locs = torch.reshape(bq_locs, b_locs.shape) bkv_locs = look_one_back(b_locs) dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :]) # for memory considerations, chunk summation of last dimension for counting duplicates dup_counts = chunked_sum(dup_counts, chunks=(self.n_hashes * batch_size)) dup_counts = dup_counts.detach() assert dup_counts.shape == dots.shape dots = dots - torch.log(dup_counts + 1e-9) del dup_counts # Softmax. dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True) dots = torch.exp(dots - dots_logsumexp).type(dots.type()) dropped_dots = self.dropout(dots) bo = torch.einsum('buij,buje->buie', dropped_dots, bv) so = torch.reshape(bo, (batch_size, -1, dim)) slogits = torch.reshape(dots_logsumexp, ( batch_size, -1, )) class UnsortLogits(Function): @staticmethod def forward(ctx, so, slogits): so = so.detach() slogits = slogits.detach() o = batched_index_select(so, undo_sort) _, logits = sort_key_val(sticker, slogits, dim=-1) return o, logits @staticmethod def backward(ctx, grad_x, grad_y): so_grad = batched_index_select(grad_x, sticker) _, slogits_grad = sort_key_val(buckets_and_t, grad_y, dim=-1) return so_grad, slogits_grad o, logits = UnsortLogits.apply(so, slogits) o = torch.reshape(o, (batch_size, self.n_hashes, seqlen, dim)) logits = torch.reshape(logits, (batch_size, self.n_hashes, seqlen, 1)) if query_len != seqlen: query_slice = (slice(None), slice(None), slice(0, query_len)) o, logits = o[query_slice], logits[query_slice] probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True)) out = torch.sum(o * probs, dim=1) attn = torch.empty(0, device=device) # return unsorted attention weights if self._return_attn: attn_unsort = ((bq_t * seqlen)[:, :, :, None] + bkv_t[:, :, None, :]) attn_unsort = attn_unsort.view(batch_size * self.n_hashes, -1).long() unsorted_dots = torch.zeros(batch_size * self.n_hashes, seqlen * seqlen, device=device) unsorted_dots.scatter_add_(1, attn_unsort, dots.view_as(attn_unsort)) del attn_unsort unsorted_dots = unsorted_dots.reshape(batch_size, self.n_hashes, seqlen, seqlen) attn = torch.sum(unsorted_dots[:, :, 0:query_len, :] * probs, dim=1) # return output, attention matrix, and bucket distribution return out, attn, buckets
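The two contractions in this forward pass are ordinary attention written in einsum form over chunked tensors; a shape-only sketch with invented sizes (the two leading axes stand for batch and chunk, and keys/values are doubled by look_one_back):

import torch
bh, n_chunks, chunk, dim = 2, 4, 8, 16
bq = torch.randn(bh, n_chunks, chunk, dim)
bk = torch.randn(bh, n_chunks, 2 * chunk, dim)
bv = torch.randn(bh, n_chunks, 2 * chunk, dim)
dots = torch.einsum('bhie,bhje->bhij', bq, bk) * dim ** -0.5
bo = torch.einsum('buij,buje->buie', dots.softmax(dim=-1), bv)
assert bo.shape == bq.shape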
def attention(query, key, value): dim = query.shape[1] scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5 prob = torch.nn.functional.softmax(scores, dim=-1) return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
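Toy usage of the helper above (assuming it is in scope); the inputs are laid out as [batch, dim_per_head, heads, tokens], so both contractions run over the feature axis d:

import torch
b, d, h, n, m = 2, 32, 4, 10, 12        # illustrative sizes
query, key, value = torch.randn(b, d, h, n), torch.randn(b, d, h, m), torch.randn(b, d, h, m)
out, prob = attention(query, key, value)
assert out.shape == (b, d, h, n) and prob.shape == (b, h, n, m)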
def forward(self, q, k, v, query_mask=None, key_mask=None, **kwargs): b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = *q.shape, k.shape[2], self.window_size, self.context_window_size, self.num_clusters, q.device, q.dtype is_reverse = kwargs.pop('_reverse', False) out = torch.zeros_like(q, dtype=dtype) update_kmeans = self.training and not is_reverse key_mask = default(key_mask, query_mask) if not self.receives_context else key_mask kv_wsz = wsz if not self.receives_context else c_wsz wsz = min(wsz, t) kv_wsz = min(kv_wsz, kv_t) if not self.shared_qk or self.receives_context: dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans) q_dists, k_dists = split_at_index(2, t, dists) indices = distribution(q_dists, wsz) kv_indices = distribution(k_dists, kv_wsz) else: dists, aux_loss = self.kmeans(q, update_kmeans) k = F.normalize(k, dim=-1).to(q) indices = distribution(dists, wsz) kv_indices = indices q = batched_index_select(q, indices) k = batched_index_select(k, kv_indices) v = batched_index_select(v, kv_indices) reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d) q, k, v = map(reshape_with_window, (q, k, v)) m_k, m_v = map(lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value)) k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v))) dots = torch.einsum('bhnid,bhnjd->bhnij', q, k) * (d ** -0.5) mask_value = max_neg_value(dots) if exists(query_mask) or exists(key_mask): query_mask = default(query_mask, lambda: torch.ones((b, t), device=device).bool()) key_mask = default(key_mask, lambda: torch.ones((b, kv_t), device=device).bool()) q_mask = expand_dim(query_mask, 1, h).gather(2, indices) kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices) q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (q_mask, kv_mask)) mask = q_mask[:, :, :, :, None] * kv_mask[:, :, :, None, :] mask = F.pad(mask, (self.num_mem_kv, 0), value=1) dots.masked_fill_(~mask, mask_value) del mask if self.causal: q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)) mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :] mask = F.pad(mask, (self.num_mem_kv, 0), value=1) dots.masked_fill_(~mask, mask_value) del mask if self.shared_qk: q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)) mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :] mask = F.pad(mask, (self.num_mem_kv, 0), value=0) dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE) del mask dots = dots.softmax(dim=-1) dots = self.dropout(dots) bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v) so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype) out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2) return out, aux_loss
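Shape-only sketch of the windowed attention above, with invented sizes: queries and keys are grouped into nc clusters of window size wsz, and the key side gains num_mem_kv memory slots:

import torch
b, h, nc, wsz, d, n_mem = 2, 4, 3, 5, 16, 2
q = torch.randn(b, h, nc, wsz, d)
k = torch.randn(b, h, nc, wsz + n_mem, d)
dots = torch.einsum('bhnid,bhnjd->bhnij', q, k) * d ** -0.5
assert dots.shape == (b, h, nc, wsz, wsz + n_mem)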