def forward(self, input_ids, attention_mask=None):
    hidden_states = self.bert(input_ids=input_ids, attention_mask=attention_mask)[2]
    # Flatten each layer in hidden_states into a single row:
    # [num_layers, batch_size * seq_len, embedding_dim]
    hidden_state_one_row = torch.stack([
        torch.cat(torch.tensor_split(hidden_state, sections=hidden_state.shape[0]),
                  dim=1).squeeze()
        for hidden_state in hidden_states
    ])
    attention_mask_one_row = torch.cat(torch.tensor_split(
        attention_mask, sections=attention_mask.shape[0]), dim=1).squeeze()
    # Use last_hidden_state.shape[0] instead of the configured batch size,
    # because the last batch may be a remainder of the split and != BatchSize.
    attention_output = self.sparse_attention(
        hidden_state_one_row, attention_mask=attention_mask_one_row)[0]
    # The output is [num_layers, batch_size * seq_len, embedding_dim]; take the last
    # layer of shape [batch_size * seq_len, embedding_dim] and turn it into
    # [1, batch_size * seq_len, embedding_dim].
    output = self.output(attention_output)[-1].unsqueeze(0)
    # Reshape this hidden state back to batch form: [batch_size, seq_len, embedding_dim].
    batch_output = torch.cat(torch.tensor_split(
        output, sections=hidden_states[-1].shape[0], dim=1), dim=0)
    predictions = self.linear(batch_output)
    return predictions
def tensor_indexing_ops(self):
    x = torch.randn(2, 4)
    y = torch.randn(4, 4)
    t = torch.tensor([[0, 0], [1, 0]])
    mask = x.ge(0.5)
    i = [0, 1]
    # Collect the results in a tuple so len() receives a single argument.
    return len((
        torch.cat((x, x, x), 0),
        torch.concat((x, x, x), 0),
        torch.conj(x),
        torch.chunk(x, 2),
        torch.dsplit(torch.randn(2, 2, 4), i),
        torch.column_stack((x, x)),
        torch.dstack((x, x)),
        torch.gather(x, 0, t),
        torch.hsplit(x, i),
        torch.hstack((x, x)),
        torch.index_select(x, 0, torch.tensor([0, 1])),
        x.index(t),
        torch.masked_select(x, mask),
        torch.movedim(x, 1, 0),
        torch.moveaxis(x, 1, 0),
        torch.narrow(x, 0, 0, 2),
        torch.nonzero(x),
        torch.permute(x, (0, 1)),
        torch.reshape(x, (-1,)),
        torch.row_stack((x, x)),
        torch.select(x, 0, 0),
        torch.scatter(x, 0, t, x),
        x.scatter(0, t, x.clone()),
        torch.diagonal_scatter(y, torch.ones(4)),
        torch.select_scatter(y, torch.ones(4), 0, 0),
        torch.slice_scatter(x, x),
        torch.scatter_add(x, 0, t, x),
        x.scatter_(0, t, y),
        x.scatter_add_(0, t, y),
        # torch.scatter_reduce(x, 0, t, reduce="sum"),
        torch.split(x, 1),
        torch.squeeze(x, 0),
        torch.stack([x, x]),
        torch.swapaxes(x, 0, 1),
        torch.swapdims(x, 0, 1),
        torch.t(x),
        torch.take(x, t),
        torch.take_along_dim(x, torch.argmax(x)),
        torch.tensor_split(x, 1),
        torch.tensor_split(x, [0, 1]),
        torch.tile(x, (2, 2)),
        torch.transpose(x, 0, 1),
        torch.unbind(x),
        torch.unsqueeze(x, -1),
        torch.vsplit(x, i),
        torch.vstack((x, x)),
        torch.where(x),
        torch.where(t > 0, t, 0),
        torch.where(t > 0, t, t),
    ))
def _handle_row_wise_sharding_sharded_tensor(input, world_size, weight,
                                             local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear when the input is a sharded tensor. (Detailed explanations of
    the logic can be found in the comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of the linear op.
        pg: process group.

    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    results = []
    local_shard = input.local_shards()[0].tensor
    if input.sharding_spec().dim not in (-1, len(input.size()) - 1):
        raise NotImplementedError(
            "The case when the input does not come from col-wise sharded "
            "linear is not supported for row-wise sharded linear.")
    for tensor in torch.tensor_split(local_shard, world_size):
        results.append(
            tensor.matmul(local_shard_t) +
            _BiasTensorPartial.apply(world_size, bias))
    # Return the partial local result.
    return _PartialTensor(torch.cat(results), pg)
def forward(
        self, x: torch.Tensor,
        thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
    x, tensor_dim = _unsqueeze(x, 4, 1)

    # Separate the class token and reshape the input
    class_token, x = torch.tensor_split(x, indices=(1,), dim=2)
    x = x.transpose(2, 3)
    B, N, C = x.shape[:3]
    x = x.reshape((B * N, C) + thw).contiguous()

    # Normalizing prior to pooling is useful when we use BN, which can be
    # absorbed to speed up inference.
    if self.norm_before_pool and self.norm_act is not None:
        x = self.norm_act(x)

    # Apply the pool on the input and add back the token
    x = self.pool(x)
    T, H, W = x.shape[2:]
    x = x.reshape(B, N, C, -1).transpose(2, 3)
    x = torch.cat((class_token, x), dim=2)

    if not self.norm_before_pool and self.norm_act is not None:
        x = self.norm_act(x)

    x = _squeeze(x, 4, 1, tensor_dim)
    return x, (T, H, W)
def get_feats(_input: torch.Tensor, model: trojanvision.models.ImageModel):
    input_list: list[torch.Tensor] = torch.tensor_split(
        _input, _input.size(0) // dataset.batch_size)
    feats = torch.cat(
        [model.get_final_fm(sub_input) for sub_input in input_list])
    return feats
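# A minimal sketch (hypothetical sizes, not part of the snippet above):
# torch.tensor_split with an integer argument yields that many nearly equal
# chunks, so the chunk size in get_feats only approximates dataset.batch_size
# when the input size is not an exact multiple of it.
import torch

total, batch_size = 10, 4
x = torch.randn(total, 3, 32, 32)
chunks = torch.tensor_split(x, total // batch_size)  # 2 chunks
print([c.shape[0] for c in chunks])  # [5, 5] rather than [4, 4, 2]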
def forward(self, x):
    # in lightning, forward defines the prediction/inference actions
    embedding = self.encoder(x)
    segments = torch.tensor_split(embedding, (self.output_size,), dim=1)
    policy, value = segments[0], segments[1]
    return self.policy_activation(policy), self.value_activation(value)
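# A small illustration (hypothetical widths): splitting at (output_size,)
# along dim 1 gives a policy head of width output_size and a value head
# holding the remaining columns.
import torch

output_size, extra = 4, 1
embedding = torch.randn(2, output_size + extra)
policy, value = torch.tensor_split(embedding, (output_size,), dim=1)
print(policy.shape, value.shape)  # torch.Size([2, 4]) torch.Size([2, 1])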
def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
    dim = utils.canonicalize_dims(a.ndim, dim)
    check(
        a.shape[dim] % 2 == 0,
        lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
    )
    b, c = torch.tensor_split(a, 2, dim)
    return b * torch.sigmoid(c)
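# A quick sanity check (assumes an even-sized halving dimension): the
# reference above should agree with torch.nn.functional.glu.
import torch
import torch.nn.functional as F

a = torch.randn(4, 6)
b, c = torch.tensor_split(a, 2, dim=-1)
assert torch.allclose(b * torch.sigmoid(c), F.glu(a, dim=-1))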
def forward(self, features: torch.Tensor) -> List[torch.Tensor]:
    if self.__num_features > 1:
        real_feature_slices = torch.tensor_split(features,
                                                 self.feature_dims[1:],
                                                 dim=-1)
    else:
        real_feature_slices = [features]
    return [
        proj(real_feature_slice)
        for proj, real_feature_slice in zip(self._projector,
                                            real_feature_slices)
    ]
def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
    # memoized_vec_l = None  # Memoized vector lists
    ly = []
    for k, sparse_index_group_batch in enumerate(lS_i):
        sparse_offset_group_batch = lS_o[k]
        if self.is_memoized:
            if (k in self.memoize_idx) and (
                    k != self.memoize_idx[-1]):  # if in memoized table
                ly.append(None)
            else:
                if v_W_l[k] is not None:
                    per_sample_weights = v_W_l[k].gather(
                        0, sparse_index_group_batch)
                else:
                    per_sample_weights = None
                E = emb_l[k]
                V = E(
                    sparse_index_group_batch,
                    sparse_offset_group_batch,
                    per_sample_weights=per_sample_weights,
                )
                if next(self.top_l.parameters()).is_cuda:
                    if k in self.idx_2_cpu:
                        ly.append(V.cuda())
                    else:
                        ly.append(V)
                else:
                    ly.append(V)
        else:
            print("Not implemented for now (apply_emb)")

    # Split the output of the last memoized table along the feature dimension
    # and scatter the pieces back into the memoized slots of ly.
    split_tensor = list(
        torch.tensor_split(ly[self.memoize_idx[-1]],
                           len(self.memoize_idx),
                           dim=1))
    for idx in range(len(split_tensor)):
        ly[self.memoize_idx[idx]] = split_tensor[idx]
    return ly
def encode_batch(self, x):
    ##### Get encoder latent and generate learned parameters.
    x_encoded = self.encoder(x)
    latent_shape = x_encoded.shape
    flattened_latent = self.flatten(x_encoded)
    mu, log_var = torch.tensor_split(flattened_latent, 2, axis=1)

    ##### Define learned multivariate distribution Q(z|x).
    std = torch.exp(log_var / 2)
    self.q = torch.distributions.Normal(mu, std)

    ##### Sample latent from distribution.
    z = self.q.rsample()
    return z, latent_shape
def _handle_row_wise_sharding_sharded_tensor(
    input, world_size, weight, local_shard_t, bias, pg
):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear when the input is a sharded tensor. (Detailed explanations of
    the logic can be found in the comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of the linear op.
        pg: process group.

    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    results = []
    local_shard = input.local_shards()[0].tensor
    indices = [0] * world_size
    reaggrance_partial = False
    for idx, placement in enumerate(input._sharding_spec.placements):
        indices[placement.rank()] = idx
        if idx != placement.rank():
            reaggrance_partial = True
    for tensor in torch.tensor_split(local_shard, world_size):
        results.append(
            tensor.matmul(local_shard_t)
            + _BiasTensorPartial.apply(world_size, bias)
        )
    if reaggrance_partial:
        results = [results[idx] for idx in indices]
    # Return the partial local result.
    return _PartialTensor(torch.cat(results), pg)
def training_step(self, test_batch, batch_idx, optimizer_idx):
    x, y, prior_info = test_batch
    x = x.permute(0, 3, 1, 2)
    y = y.permute(0, 3, 1, 2)
    opt_encoders_decoders, opt_encoders, opt_disc = self.configure_optimizers()

    opt_encoders_decoders.zero_grad()
    # X data flow
    y_hat = self.Sz(self.Rx(x))
    x_translation_loss = self.mse_loss_weighted(y, y_hat, 1 - prior_info)
    x_cycled = self.Qz(self.Py(y_hat))
    x_cycle_loss = mse_loss(x, x_cycled)
    x_reconstructed = self.Qz(self.Rx(x))
    x_recon_loss = mse_loss(x, x_reconstructed)

    # Y data flow
    x_hat = self.Qz(self.Py(y))
    y_translation_loss = self.mse_loss_weighted(x, x_hat, 1 - prior_info)
    y_cycled = self.Sz(self.Rx(x_hat))
    y_cycle_loss = mse_loss(y, y_cycled)
    x_hat = self.Sz(self.Py(y))
    y_recon_loss = mse_loss(y, x_hat)

    total_AE_loss = (W_RECON * (x_recon_loss + y_recon_loss) +
                     W_HAT * (x_translation_loss + y_translation_loss) +
                     W_CYCLE * (x_cycle_loss + y_cycle_loss))
    self.log("Reconstruction loss", W_RECON * (x_recon_loss + y_recon_loss))
    self.log("Prior information loss",
             W_HAT * (x_translation_loss + y_translation_loss))
    self.log("Total AutoEncoders loss", total_AE_loss, prog_bar=True)
    self.manual_backward(total_AE_loss, opt_encoders_decoders)
    opt_encoders_decoders.step()

    opt_encoders.zero_grad()
    generator_code = self.discriminator(torch.cat((self.Rx(x), self.Py(y))))
    x_disc, y_disc = torch.tensor_split(generator_code, 2, dim=0)
    disc_code_loss = W_D * (mse_loss(torch.zeros_like(x_disc), x_disc) +
                            mse_loss(torch.ones_like(y_disc), y_disc))
    self.log("Discriminator code loss", disc_code_loss)
    self.manual_backward(disc_code_loss, opt_encoders)
    opt_encoders.step()

    opt_disc.zero_grad()
    disc_out = self.discriminator(torch.cat((self.Rx(x), self.Py(y))))
    x_disc, y_disc = torch.tensor_split(disc_out, 2, dim=0)
    disc_loss = W_D * (mse_loss(torch.ones_like(x_disc), x_disc) +
                       mse_loss(torch.zeros_like(y_disc), y_disc))
    self.log("Discriminator loss", disc_loss)
    self.manual_backward(disc_loss, opt_disc)
    opt_disc.step()
def sinkhorn_knopp(a, b, M, reg, numItermax=5000, stopThr=1e-5, verbose=False,
                   log=False, **kwargs):
    # init data
    b = b.T
    dim_a = a.shape[0]
    dim_b = b.shape[0]

    if len(b.shape) > 1:
        nb_hists = b.shape[1]
    else:
        nb_hists = 0

    if log:
        log = {'err': []}

    # we assume that no distances are null except those of the diagonal of
    # distances
    if nb_hists:
        u = torch.ones((dim_a, nb_hists), dtype=torch.float64,
                       device=device) / dim_a
        v = torch.ones((dim_b, nb_hists), dtype=torch.float64,
                       device=device) / dim_b
    else:
        u = torch.ones(dim_a, dtype=torch.float64, device=device) / dim_a
        v = torch.ones(dim_b, dtype=torch.float64, device=device) / dim_b

    # Next lines are equivalent to K = np.exp(-M / reg), but faster to compute
    K = torch.divide(M, -reg).to(torch.float64)
    torch.exp(K, out=K)

    tmp2 = torch.empty(b.shape, dtype=torch.float64, device=device)
    Kp = (1 / a).reshape(-1, 1) * K
    cpt = 0
    err = 1
    while (err > stopThr and cpt < numItermax):
        uprev = u
        vprev = v
        KtransposeU = torch.matmul(torch.transpose(K, 0, 1), u)
        v = torch.divide(b, KtransposeU)
        u = 1. / torch.matmul(Kp, v)

        if (torch.any(KtransposeU == 0)
                or torch.any(torch.isnan(u)) or torch.any(torch.isnan(v))
                or torch.any(torch.isinf(u)) or torch.any(torch.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print('Warning: numerical errors at iteration', cpt)
            u = uprev
            v = vprev
            break
        if cpt % 1000 == 0:
            # we can speed up the process by checking the error only every
            # 1000th iteration
            if nb_hists:
                tmp2 = torch.einsum('ik,ij,jk->jk', u, K, v)
            else:
                # compute the right marginal: tmp2 = (diag(u) K diag(v))^T 1
                tmp2 = torch.einsum('i,ij,j->j', u, K, v)
            err = torch.linalg.norm(tmp2 - b)  # violation of marginal
            if log:
                log['err'].append(err)
            if verbose:
                print('{:5d}|{:8e}|'.format(cpt, err))
        cpt = cpt + 1

    if log:
        log['u'] = u
        log['v'] = v

    if nb_hists:
        # return only the loss; split the scaling vectors into chunks of
        # histograms to keep the einsum contractions small
        num_splits = 5
        u_split = torch.tensor_split(u, num_splits, dim=1)
        v_split = torch.tensor_split(v, num_splits, dim=1)
        res = []
        for i in range(0, len(u_split)):
            res.append(
                torch.einsum('ik,ij,jk,ij->k', u_split[i].to(torch.float64),
                             K, v_split[i].to(torch.float64), M))
        res = torch.cat(res)
        if log:
            return res, log
        else:
            return res
    else:
        # return OT matrix
        if log:
            return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log
        else:
            return torch.sum(M * (u.reshape((-1, 1)) * K * v.reshape((1, -1))))
def forward(self, output):
    # The batch is stacked as [anchor; positive; negative] along dim 0, so it
    # must be split into three equal chunks of size n_batch each (splitting
    # into n_batch chunks would break the three-way unpacking).
    n_batch = output.shape[0] // 3
    anchor, pos, neg = torch.tensor_split(output, 3, dim=0)
    assert anchor.shape[0] == n_batch
    return self.triplet(anchor, pos, neg)
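# A minimal sketch (hypothetical shapes): with the batch stacked as
# [anchor; positive; negative] along dim 0, splitting into three chunks
# recovers equally sized anchor/positive/negative embeddings.
import torch

n_batch, emb_dim = 8, 128
output = torch.randn(3 * n_batch, emb_dim)
anchor, pos, neg = torch.tensor_split(output, 3, dim=0)
assert anchor.shape == pos.shape == neg.shape == (n_batch, emb_dim)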
def forward(self, x):
    embedding = self.encoder(x)
    segments = torch.tensor_split(embedding, (self.output_size,), dim=1)
    policy, value = segments[0], segments[1]
    return self.policy_activation(policy), self.value_activation(value)
def collate_fn(batch_list: Sequence[Tuple],
               device: torch.device,
               use_thermo: bool = True,
               use_dist: bool = False,
               multiclass: bool = False,
               bins: Optional[torch.Tensor] = None,
               dist_idxs: Optional[List[int]] = None,
               ceiling: float = torch.inf,
               inv: bool = False,
               return_raw: bool = False,
               inv_eps: float = 1e-8):
    """Collates list of tensors in batch"""
    assert not inv or not multiclass
    lengths = torch.tensor([tup[0].shape[1] for tup in batch_list])
    pad = max(lengths)
    seqs_, thermos_, cons_, dists_ = [], [], [], []
    # dists_ = [[] for dist_idx in dist_idxs]  # TODO: clean up this line
    dists_ = []
    raw_dists = []
    dist_idxs = dist_idxs if dist_idxs else torch.arange(10)

    for i, tup in enumerate(batch_list):
        seq_ = tup[0]
        offset = (pad.item() - seq_.shape[1]) // 2
        seq_idxs = seq_.coalesce().indices()
        seq_idxs[1, :] += offset
        seq = torch.zeros(4, pad, device=device)
        seq[seq_idxs[0, :], seq_idxs[1, :]] = 1
        seq = concat(seq)
        seqs_.append(seq)

        thermo_ = tup[1]
        # error handling accounts for what's probably a PyTorch bug
        try:
            thermo_idxs = thermo_.coalesce().indices()
        except NotImplementedError:
            thermo_idxs = torch.tensor([[], []], dtype=torch.long)
        # apply the centering offset before filling the padded matrix
        thermo_idxs += offset
        thermo = torch.zeros(1, pad, pad, device=device)
        thermo[0, thermo_idxs[0, :], thermo_idxs[1, :]] = 1
        thermos_.append(thermo)

        con_ = tup[2]
        con_idxs = con_.coalesce().indices()
        con_idxs += offset
        con = torch.zeros(1, pad, pad, device=device)
        con[0, con_idxs[0, :], con_idxs[1, :]] = 1
        cons_.append(con)

        dist_ = tup[3]
        dist = dist_.to(device)
        # dist = dist.clip(-torch.inf, ceiling)
        if inv:
            dist_out = 1 / (dist + inv_eps)
            dist_out[dist <= 0] = torch.nan
        elif multiclass:
            dist_out = one_hot_bin(dist, bins)
            dist_out[:, dist <= 0] = torch.nan
        else:
            dist_out = dist
            dist_out[dist <= 0] = torch.nan
        dists_.append(dist_out)
        raw_dists.append(dist)

    seqs = torch.stack(seqs_)
    thermos = torch.stack(thermos_)
    cons = torch.stack(cons_)
    dists_stack = torch.stack(dists_)
    # dists_stack = dists_stack[:,:,dist_idxs,:,:]
    dists_stack = dists_stack[..., dist_idxs, :, :]
    dists = torch.tensor_split(dists_stack, len(dist_idxs), dim=-3)
    dists = [dist.squeeze(-3) for dist in dists]

    ipts = seqs
    if use_thermo:
        ipts = torch.cat((ipts, thermos), 1)
    opts = cons
    if use_dist:
        opts = (opts, *dists)
    if return_raw:
        opts = (opts, raw_dists[0])
    return ipts, opts
def _tensor_split(self, tensor, idx, axis=0):
    return torch.tensor_split(tensor, idx, axis=axis)
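# A short illustration of how the idx argument forwarded by this wrapper is
# interpreted by torch.tensor_split: an integer gives that many nearly equal
# sections, while a sequence gives explicit split indices.
import torch

x = torch.arange(10)
print([t.tolist() for t in torch.tensor_split(x, 3)])       # sizes 4, 3, 3
print([t.tolist() for t in torch.tensor_split(x, [3, 7])])  # x[:3], x[3:7], x[7:]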