Example #1
    def forward(self, input_ids, attention_mask=None):
        hidden_states = self.bert(input_ids=input_ids,
                                  attention_mask=attention_mask)[2]

        # flatten each layer of hidden_states into one row: [num_layers, batch_size * seqlen, embedding_dim]
        hidden_state_one_row = torch.stack([
            torch.cat(torch.tensor_split(hidden_state,
                                         sections=hidden_state.shape[0]),
                      dim=1).squeeze() for hidden_state in hidden_states
        ])
        attention_mask_one_row = torch.cat(torch.tensor_split(
            attention_mask, sections=attention_mask.shape[0]),
                                           dim=1).squeeze()
        # last_hidden_state.shape[0] instead of BatchSize,
        # because the last batch may be a remainder of the division and != BatchSize
        attention_output = self.sparse_attention(
            hidden_state_one_row, attention_mask=attention_mask_one_row)[0]
        # the output is [num_layers, batch_size * seqlen, embedding_dim]; take the last layer
        # of shape [batch_size * seqlen, embedding_dim] and turn it into [1, batch_size * seqlen, embedding_dim]
        output = self.output(attention_output)[-1].unsqueeze(0)
        # reshape this hidden_state back to the batched form [batch_size, seq_len, embedding_dim]
        batch_output = torch.cat(torch.tensor_split(
            output, sections=hidden_states[-1].shape[0], dim=1),
                                 dim=0)
        predictions = self.linear(batch_output)

        return predictions
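
A minimal check (not part of the original model, shapes made up) showing that the split-and-cat pattern above is, for contiguous tensors, equivalent to a plain reshape between [batch_size, seq_len, embedding_dim] and [1, batch_size * seq_len, embedding_dim]:

import torch

B, S, D = 4, 7, 16                       # hypothetical batch, sequence and embedding sizes
h = torch.randn(B, S, D)

# flatten: split along the batch dim, concatenate along the sequence dim
one_row = torch.cat(torch.tensor_split(h, B), dim=1)
assert torch.equal(one_row, h.reshape(1, B * S, D))

# unflatten: split the long sequence back into B pieces and restack them as a batch
back = torch.cat(torch.tensor_split(one_row, B, dim=1), dim=0)
assert torch.equal(back, h)
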
Example #2
 def tensor_indexing_ops(self):
     x = torch.randn(2, 4)
     y = torch.randn(4, 4)
     t = torch.tensor([[0, 0], [1, 0]])
     mask = x.ge(0.5)
     i = [0, 1]
     # wrap all results in a tuple so that len() receives a single argument
     return len((
         torch.cat((x, x, x), 0),
         torch.concat((x, x, x), 0),
         torch.conj(x),
         torch.chunk(x, 2),
         torch.dsplit(torch.randn(2, 2, 4), i),
         torch.column_stack((x, x)),
         torch.dstack((x, x)),
         torch.gather(x, 0, t),
         torch.hsplit(x, i),
         torch.hstack((x, x)),
         torch.index_select(x, 0, torch.tensor([0, 1])),
         x.index(t),
         torch.masked_select(x, mask),
         torch.movedim(x, 1, 0),
         torch.moveaxis(x, 1, 0),
         torch.narrow(x, 0, 0, 2),
         torch.nonzero(x),
         torch.permute(x, (0, 1)),
         torch.reshape(x, (-1, )),
         torch.row_stack((x, x)),
         torch.select(x, 0, 0),
         torch.scatter(x, 0, t, x),
         x.scatter(0, t, x.clone()),
         torch.diagonal_scatter(y, torch.ones(4)),
         torch.select_scatter(y, torch.ones(4), 0, 0),
         torch.slice_scatter(x, x),
         torch.scatter_add(x, 0, t, x),
         x.scatter_(0, t, y),
         x.scatter_add_(0, t, y),
         # torch.scatter_reduce(x, 0, t, reduce="sum"),
         torch.split(x, 1),
         torch.squeeze(x, 0),
         torch.stack([x, x]),
         torch.swapaxes(x, 0, 1),
         torch.swapdims(x, 0, 1),
         torch.t(x),
         torch.take(x, t),
         torch.take_along_dim(x, torch.argmax(x)),
         torch.tensor_split(x, 1),
         torch.tensor_split(x, [0, 1]),
         torch.tile(x, (2, 2)),
         torch.transpose(x, 0, 1),
         torch.unbind(x),
         torch.unsqueeze(x, -1),
         torch.vsplit(x, i),
         torch.vstack((x, x)),
         torch.where(x),
         torch.where(t > 0, t, 0),
         torch.where(t > 0, t, t),
     ))
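
For reference (mine, not part of the test), the two tensor_split calls exercised above behave as follows on the 2x4 input; splitting at index 0 yields a leading empty chunk:

import torch

x = torch.randn(2, 4)

parts = torch.tensor_split(x, 1)          # a single chunk: the whole tensor
print([tuple(p.shape) for p in parts])    # [(2, 4)]

parts = torch.tensor_split(x, [0, 1])     # split before row 0 and before row 1
print([tuple(p.shape) for p in parts])    # [(0, 4), (1, 4), (1, 4)]
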
Example #3
def _handle_row_wise_sharding_sharded_tensor(input, world_size, weight,
                                             local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear when the input is a sharded tensor. (Detailed explanations
    of the logic can be found in the comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    results = []
    local_shard = input.local_shards()[0].tensor
    if input.sharding_spec().dim not in (-1, len(input.size()) - 1):
        raise NotImplementedError(
            "The case when the input does not come from col-wise sharded "
            "linear is not supported for row-wise sharded linear.")

    for tensor in torch.tensor_split(local_shard, world_size):
        results.append(
            tensor.matmul(local_shard_t) +
            _BiasTensorPartial.apply(world_size, bias))

    # Return the partial local result.
    return _PartialTensor(torch.cat(results), pg)
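
An illustration (not from the sharded-linear code) of why tensor_split is used here: unlike torch.chunk, it always returns exactly world_size pieces even when the shard size is not divisible, assigning the remainder to the leading chunks.

import torch

world_size = 4
local_shard = torch.randn(9, 8)           # 9 rows do not divide evenly by 4

pieces = torch.tensor_split(local_shard, world_size)
print([p.shape[0] for p in pieces])       # [3, 2, 2, 2] -- always world_size pieces

chunks = torch.chunk(local_shard, world_size)
print([c.shape[0] for c in chunks])       # [3, 3, 3]    -- may return fewer pieces
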
Example #4
    def forward(
        self, x: torch.Tensor,
        thw: Tuple[int, int,
                   int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        x, tensor_dim = _unsqueeze(x, 4, 1)

        # Separate the class token and reshape the input
        class_token, x = torch.tensor_split(x, indices=(1, ), dim=2)
        x = x.transpose(2, 3)
        B, N, C = x.shape[:3]
        x = x.reshape((B * N, C) + thw).contiguous()

        # normalizing prior to pooling is useful when we use BN, which can be absorbed to speed up inference
        if self.norm_before_pool and self.norm_act is not None:
            x = self.norm_act(x)

        # apply the pool on the input and add back the token
        x = self.pool(x)
        T, H, W = x.shape[2:]
        x = x.reshape(B, N, C, -1).transpose(2, 3)
        x = torch.cat((class_token, x), dim=2)

        if not self.norm_before_pool and self.norm_act is not None:
            x = self.norm_act(x)

        x = _squeeze(x, 4, 1, tensor_dim)
        return x, (T, H, W)
Example #5
def get_feats(_input: torch.Tensor, model: trojanvision.models.ImageModel):
    input_list: list[torch.Tensor] = torch.tensor_split(
        _input,
        _input.size(0) // dataset.batch_size)
    feats = torch.cat(
        [model.get_final_fm(sub_input) for sub_input in input_list])
    return feats
Example #6
    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        segments = torch.tensor_split(embedding, (self.output_size, ), dim=1)
        policy, value = segments[0], segments[1]

        return self.policy_activation(policy), self.value_activation(value)
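
A sketch with made-up sizes of the single-index split used above: passing (self.output_size,) as indices slices the embedding into a policy head of width output_size and a value head holding the remaining columns.

import torch

output_size = 3                                  # assumed policy width
embedding = torch.randn(5, output_size + 1)      # batch of 5: policy logits plus one value column
policy, value = torch.tensor_split(embedding, (output_size,), dim=1)
print(policy.shape, value.shape)                 # torch.Size([5, 3]) torch.Size([5, 1])
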
Example #7
def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
    dim = utils.canonicalize_dims(a.ndim, dim)
    check(
        a.shape[dim] % 2 == 0,
        lambda:
        f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
    )
    b, c = torch.tensor_split(a, 2, dim)

    return b * torch.sigmoid(c)
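
A quick sanity check (not part of the reference implementation): splitting the halving dimension in two with tensor_split and gating reproduces torch.nn.functional.glu.

import torch
import torch.nn.functional as F

a = torch.randn(4, 6)
b, c = torch.tensor_split(a, 2, dim=-1)          # two equal halves of the even-sized dim
assert torch.allclose(b * torch.sigmoid(c), F.glu(a, dim=-1))
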
Example #8
    def forward(self, features: torch.Tensor) -> List[torch.Tensor]:
        if self.__num_features > 1:
            real_feature_slices = torch.tensor_split(features,
                                                     self.feature_dims[1:],
                                                     dim=-1)
        else:
            real_feature_slices = [features]

        return [
            proj(real_feature_slice) for proj, real_feature_slice in zip(
                self._projector, real_feature_slices)
        ]
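
Since feature_dims is not shown, here is a hedged sketch assuming it stores cumulative start offsets beginning at 0 (e.g. [0, 3, 5] for three features of widths 3, 2 and 4 packed into a 9-wide vector), so feature_dims[1:] gives the split boundaries.

import torch

feature_dims = [0, 3, 5]                  # hypothetical cumulative offsets
features = torch.randn(2, 9)              # widths 3 + 2 + 4 packed along the last dim

slices = torch.tensor_split(features, feature_dims[1:], dim=-1)
print([s.shape[-1] for s in slices])      # [3, 2, 4]
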
Example #9
    def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):

        # memoized_vec_l = None # Memoized vector lists
        ly = []

        for k, sparse_index_group_batch in enumerate(lS_i):
            sparse_offset_group_batch = lS_o[k]

            if self.is_memoized:  # memoization enabled
                if (k in self.memoize_idx) and (
                        k != self.memoize_idx[-1]):
                    # memoized table other than the last one: append a placeholder
                    # that is filled in below from the split of the last table's output
                    ly.append(None)
                else:
                    if v_W_l[k] is not None:
                        per_sample_weights = v_W_l[k].gather(
                            0, sparse_index_group_batch)
                    else:
                        per_sample_weights = None

                    E = emb_l[k]
                    V = E(
                        sparse_index_group_batch,
                        sparse_offset_group_batch,
                        per_sample_weights=per_sample_weights,
                    )

                    if next(self.top_l.parameters()).is_cuda:
                        if k in self.idx_2_cpu:
                            ly.append(V.cuda())
                        else:
                            ly.append(V)
                    else:
                        ly.append(V)

                    # print(type(V))
                    # Type: torch.Tensor
            else:
                print("Not implemented for now (apply_emb)")

        split_tensor = list(
            torch.tensor_split(ly[self.memoize_idx[-1]],
                               len(self.memoize_idx),
                               dim=1))

        # print(f"len: {len(split_tensor)}, shapes: {[_.shape for _ in split_tensor ]}")

        for idx in range(len(split_tensor)):
            ly[self.memoize_idx[idx]] = split_tensor[idx]

        # print(f"chk - len: {len(ly)}, types: {[type(_) for _ in ly ]}, shapes: {[_.shape for _ in split_tensor ]}")

        return ly
Example #10
    def encode_batch(self, x):
        ##### Get encoder latent and generate learned parameters.
        x_encoded = self.encoder(x)
        latent_shape = x_encoded.shape
        flattened_latent = self.flatten(x_encoded)
        mu, log_var = torch.tensor_split(flattened_latent, 2, axis=1)
        ##### Define learned multivariate distribution Q(z|x).
        std = torch.exp(log_var / 2)
        self.q = torch.distributions.Normal(mu, std)
        ##### Sample latent from distribution.
        z = self.q.rsample()

        return z, latent_shape
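
A sketch with arbitrary sizes: the even split of the flattened latent yields mu and log-variance halves; for an even width this is the same as torch.chunk(x, 2, dim=1).

import torch

flattened_latent = torch.randn(8, 64)     # assumed batch of 8, latent width 64
mu, log_var = torch.tensor_split(flattened_latent, 2, dim=1)
print(mu.shape, log_var.shape)            # torch.Size([8, 32]) torch.Size([8, 32])
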
Example #11
def _handle_row_wise_sharding_sharded_tensor(
    input, world_size, weight, local_shard_t, bias, pg
):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear when the input is a sharded tensor. (Detailed explanations
    of the logic can be found in the comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    results = []
    local_shard = input.local_shards()[0].tensor
    indices = [0] * world_size
    reaggrance_partial = False
    for idx, placement in enumerate(input._sharding_spec.placements):
        indices[placement.rank()] = idx
        if idx != placement.rank():
            reaggrance_partial = True

    for tensor in torch.tensor_split(local_shard, world_size):
        results.append(
            tensor.matmul(local_shard_t) + _BiasTensorPartial.apply(world_size, bias)
        )
    if reaggrance_partial:
        results = [results[idx] for idx in indices]

    # Return the partial local result.
    return _PartialTensor(torch.cat(results), pg)
Example #12
    def training_step(self, test_batch, batch_idx, optimizer_idx):
        x, y, prior_info = test_batch
        x = x.permute(0, 3, 1, 2)
        y = y.permute(0, 3, 1, 2)

        opt_encoders_decoders, opt_encoders, opt_disc = self.configure_optimizers(
        )

        opt_encoders_decoders.zero_grad()

        # X data flow
        y_hat = self.Sz(self.Rx(x))
        x_translation_loss = self.mse_loss_weighted(y, y_hat, 1 - prior_info)
        x_cycled = self.Qz(self.Py(y_hat))
        x_cycle_loss = mse_loss(x, x_cycled)
        x_reconstructed = self.Qz(self.Rx(x))
        x_recon_loss = mse_loss(x, x_reconstructed)

        # Y data flow
        x_hat = self.Qz(self.Py(y))
        y_translation_loss = self.mse_loss_weighted(x, x_hat, 1 - prior_info)
        y_cycled = self.Sz(self.Rx(x_hat))
        y_cycle_loss = mse_loss(y, y_cycled)
        y_reconstructed = self.Sz(self.Py(y))
        y_recon_loss = mse_loss(y, y_reconstructed)

        total_AE_loss = (W_RECON * (x_recon_loss + y_recon_loss) + W_HAT *
                         (x_translation_loss + y_translation_loss) + W_CYCLE *
                         (x_cycle_loss + y_cycle_loss))

        self.log("Reconstruction loss",
                 W_RECON * (x_recon_loss + y_recon_loss))
        self.log("Prior information loss",
                 W_HAT * (x_translation_loss + y_translation_loss))
        self.log("Total AutoEncoders loss", total_AE_loss, prog_bar=True)

        self.manual_backward(total_AE_loss, opt_encoders_decoders)
        opt_encoders_decoders.step()

        opt_encoders.zero_grad()

        generator_code = self.discriminator(torch.cat(
            (self.Rx(x), self.Py(y))))
        x_disc, y_disc = torch.tensor_split(generator_code, 2, dim=0)
        disc_code_loss = W_D * (mse_loss(torch.zeros_like(x_disc), x_disc) +
                                mse_loss(torch.ones_like(y_disc), y_disc))
        self.log("Discriminator code loss", disc_code_loss)

        self.manual_backward(disc_code_loss, opt_encoders)
        opt_encoders.step()

        opt_disc.zero_grad()

        disc_out = self.discriminator(torch.cat((self.Rx(x), self.Py(y))))
        x_disc, y_disc = torch.tensor_split(disc_out, 2, dim=0)

        disc_loss = W_D * (mse_loss(torch.ones_like(x_disc), x_disc) +
                           mse_loss(torch.zeros_like(y_disc), y_disc))

        self.log("Discriminator loss", disc_loss)

        self.manual_backward(disc_loss, opt_disc)
        opt_disc.step()
Example #13
def sinkhorn_knopp(a,
                   b,
                   M,
                   reg,
                   numItermax=5000,
                   stopThr=1e-5,
                   verbose=False,
                   log=False,
                   **kwargs):

    # init data
    b = b.T
    dim_a = a.shape[0]
    dim_b = b.shape[0]
    if len(b.shape) > 1:
        nb_hists = b.shape[1]
    else:
        nb_hists = 0

    if log:
        log = {'err': []}

    # we assume that no distances are null except those of the diagonal of
    # distances
    if nb_hists:
        u = torch.ones(
            (dim_a, nb_hists), dtype=torch.float64, device=device) / dim_a
        v = torch.ones(
            (dim_b, nb_hists), dtype=torch.float64, device=device) / dim_b
    else:
        u = torch.ones(dim_a, dtype=torch.float64, device=device) / dim_a
        v = torch.ones(dim_b, dtype=torch.float64, device=device) / dim_b

    # print(reg)

    # The next two lines are equivalent to K = torch.exp(-M / reg), but faster to compute
    K = torch.divide(M, -reg).to(torch.float64)
    torch.exp(K, out=K)

    # print(np.min(K))
    tmp2 = torch.empty(b.shape, dtype=torch.float64, device=device)

    Kp = (1 / a).reshape(-1, 1) * K
    cpt = 0
    err = 1
    while (err > stopThr and cpt < numItermax):
        uprev = u
        vprev = v
        #print (u.shape)
        #print (K.shape)
        KtransposeU = torch.matmul(torch.transpose(K, 0, 1), u)
        v = torch.divide(b, KtransposeU)
        u = 1. / torch.matmul(Kp, v)
        #u = 1. / torch.mv((1 / a).reshape(-1, 1) * K, v)

        if (torch.any(KtransposeU == 0) or torch.any(torch.isnan(u))
                or torch.any(torch.isnan(v)) or torch.any(torch.isinf(u))
                or torch.any(torch.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print('Warning: numerical errors at iteration', cpt)
            u = uprev
            v = vprev
            break
        if cpt % 1000 == 0:
            # we can speed up the process by checking the error only every
            # 1000 iterations
            if nb_hists:
                tmp2 = torch.einsum('ik,ij,jk->jk', u, K, v)
            else:
                # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
                tmp2 = torch.einsum('i,ij,j->j', u, K, v)
            err = torch.linalg.norm(tmp2 - b)  # violation of marginal
            if log:
                log['err'].append(err)

            if verbose:
                #if cpt % 5000 == 0:
                #print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err))
        cpt = cpt + 1
    if log:
        log['u'] = u
        log['v'] = v

    if nb_hists:  # return only loss
        #res = torch.ones(nb_hists, dtype=torch.float64, device=device)*1000
        num_splits = 5
        u_split = torch.tensor_split(u, num_splits, dim=1)
        v_split = torch.tensor_split(v, num_splits, dim=1)

        # u0 = u[:,0]
        # v0 = v[:,0]
        # print(u0.unsqueeze(1).shape)
        # print(u_split[0].shape)
        # print(torch.sum(u0.unsqueeze(1)-u_split[0]))
        # print(torch.sum(v0.unsqueeze(1)-v_split[0]))
        # print(torch.einsum('ik,ij,jk,ij->k', u_split[0], K,  v_split[0], M) )
        # print(torch.einsum('ik,ij,jk,ij->k', u0.unsqueeze(1), K, v0.unsqueeze(1), M) )

        res = []
        for i in range(0, len(u_split)):
            #print(torch.einsum('ik,ij,jk,ij->k', u_split[i], K, v_split[i], M))
            res.append(
                torch.einsum('ik,ij,jk,ij->k', u_split[i].to(torch.float64), K,
                             v_split[i].to(torch.float64), M))
        #print (res)
        res = torch.cat(res)
        #print (res[0].item())
        #print (res)

        #for i in range(0,nb_hists):
        # res[i] = torch.sum(M*(u[:,i].reshape((-1, 1)) * K * v[:,i].reshape((1, -1))))
        #res[i] = torch.einsum('ik,ij,jk,ij->k', u[:,i], K, v[:,i], M)
        if log:
            return res, log
        else:
            return res

    else:  # return OT matrix

        if log:
            return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log
        else:
            return torch.sum(M * (u.reshape((-1, 1)) * K * v.reshape((1, -1))))
Example #14
    def forward(self, output):
        # output is assumed to stack anchors, positives and negatives along dim 0,
        # so split it into three equal groups of size n_batch
        n_batch = int(output.shape[0] / 3)
        anchor, pos, neg = torch.tensor_split(output, (n_batch, 2 * n_batch),
                                              dim=0)

        return self.triplet(anchor, pos, neg)
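
Usage sketch under the assumption (not stated in the snippet) that the batch stacks anchors, positives and negatives along dim 0, which is what the three-way split above relies on.

import torch

n_batch, dim = 4, 128                      # hypothetical sizes
output = torch.cat([torch.randn(n_batch, dim) for _ in range(3)], dim=0)

anchor, pos, neg = torch.tensor_split(output, 3, dim=0)
print(anchor.shape, pos.shape, neg.shape)  # each torch.Size([4, 128])
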
Example #15
    def forward(self, x):
        embedding = self.encoder(x)
        segments = torch.tensor_split(embedding, (self.output_size, ), dim=1)
        policy, value = segments[0], segments[1]

        return self.policy_activation(policy), self.value_activation(value)
Example #16
def collate_fn(batch_list: Sequence[Tuple],
               device: torch.device,
               use_thermo: bool = True,
               use_dist: bool = False,
               multiclass: bool = False,
               bins: Optional[torch.Tensor] = None,
               dist_idxs: Optional[List[int]] = None,
               ceiling: float = torch.inf,
               inv: bool = False,
               return_raw: bool = False,
               inv_eps: float = 1e-8):
    """Collates list of tensors in batch"""
    assert not inv or not multiclass

    lengths = torch.tensor([tup[0].shape[1] for tup in batch_list])
    pad = max(lengths)
    seqs_, thermos_, cons_, dists_ = [], [], [], []
    raw_dists = []

    dist_idxs = dist_idxs if dist_idxs else torch.arange(10)

    for i, tup in enumerate(batch_list):
        seq_ = tup[0]
        offset = (pad.item() - seq_.shape[1]) // 2
        seq_idxs = seq_.coalesce().indices()
        seq_idxs[1, :] += offset
        seq = torch.zeros(4, pad, device=device)
        seq[seq_idxs[0, :], seq_idxs[1, :]] = 1
        seq = concat(seq)
        seqs_.append(seq)

        thermo_ = tup[1]
        # error handling accounts for what's probably a PyTorch bug
        try:
            thermo_idxs = thermo_.coalesce().indices()
        except NotImplementedError:
            thermo_idxs = torch.tensor([[], []], dtype=torch.long)
        # apply the centering offset before writing into the padded matrix,
        # matching the handling of the sequence and contact indices
        thermo_idxs += offset
        thermo = torch.zeros(1, pad, pad, device=device)
        thermo[0, thermo_idxs[0, :], thermo_idxs[1, :]] = 1
        thermos_.append(thermo)

        con_ = tup[2]
        con_idxs = con_.coalesce().indices()
        con_idxs += offset
        con = torch.zeros(1, pad, pad, device=device)
        con[0, con_idxs[0, :], con_idxs[1, :]] = 1
        cons_.append(con)

        dist_ = tup[3]
        dist = dist_.to(device)
        # dist = dist.clip(-torch.inf, ceiling)
        if inv:
            dist_out = 1 / (dist + inv_eps)
            dist_out[dist <= 0] = torch.nan
        elif multiclass:
            dist_out = one_hot_bin(dist, bins)
            dist_out[:, dist <= 0] = torch.nan
        else:
            dist_out = dist
            dist_out[dist <= 0] = torch.nan
        dists_.append(dist_out)
        raw_dists.append(dist)

    seqs = torch.stack(seqs_)
    thermos = torch.stack(thermos_)
    cons = torch.stack(cons_)
    dists_stack = torch.stack(dists_)
    # dists_stack = dists_stack[:,:,dist_idxs,:,:]
    dists_stack = dists_stack[..., dist_idxs, :, :]
    dists = torch.tensor_split(dists_stack, len(dist_idxs), dim=-3)
    dists = [dist.squeeze(-3) for dist in dists]

    ipts = seqs
    if use_thermo:
        ipts = torch.cat((ipts, thermos), 1)
    opts = cons
    if use_dist:
        opts = (opts, *dists)
        if return_raw:
            opts = (opts, raw_dists[0])

    return ipts, opts
Example #17
 def _tensor_split(self, tensor, idx, axis=0):
     return torch.tensor_split(tensor, idx, axis=axis)