Example #1
 def compute_strong_fwd_bwd_loss(self, y_fwd, y_bwd, targets):
     if self.label_smoothing > 0.:
         targets = torch.clip(targets,
                              min=self.label_smoothing,
                              max=1 - self.label_smoothing)
     strong_targets_fwd = torch.cummax(targets, dim=-1)[0]
     strong_targets_bwd = torch.cummax(targets.flip(-1), dim=-1)[0].flip(-1)
     loss = nn.BCELoss(reduction='none')(y_fwd, strong_targets_fwd)
     if y_bwd is not None:
         loss = (
             loss / 2 +
             nn.BCELoss(reduction='none')(y_bwd, strong_targets_bwd) / 2)
     return loss
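As a quick illustration of the cummax trick used above (a minimal sketch with a made-up target row): the forward cumulative maximum turns a weak 0/1 target into a "strong" target that stays on once the event has started, and the flipped pass does the same from the right.

import torch

targets = torch.tensor([[0., 0., 1., 0., 0.]])
strong_fwd = torch.cummax(targets, dim=-1)[0]                    # [0., 0., 1., 1., 1.]
strong_bwd = torch.cummax(targets.flip(-1), dim=-1)[0].flip(-1)  # [1., 1., 1., 0., 0.]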
Example #2
    def __call__(self, weights, y):
        """Compute maximum drawdown.

        Parameters
        ----------
        weights : torch.Tensor
            Tensor of shape `(n_samples, n_assets)` representing the predicted weights by our portfolio optimizer.

        y : torch.Tensor
            Tensor of shape `(n_samples, n_channels, horizon, n_assets)` representing the evolution over the next
            `horizon` timesteps.

        Returns
        -------
        torch.Tensor
            Tensor of shape `(n_samples,)` representing the per sample maximum drawdown.

        """
        cumrets = 1 + portfolio_cumulative_returns(weights,
                                                   y[:, self.returns_channel,
                                                     ...],
                                                   input_type=self.input_type,
                                                   output_type='simple')

        cummax = torch.cummax(cumrets, 1)[0]  # (n_samples, n_timesteps)

        div = (cumrets / cummax) - 1  # drawdown from the running peak, (n_samples, n_timesteps)

        end = div.argmin(dim=1)  # (n_samples,)
        mdd = div.gather(1, end.view(-1, 1)).view(-1)

        return -mdd
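The same idea works standalone; a minimal sketch (made-up cumulative returns) of maximum drawdown via the running peak:

import torch

cumrets = torch.tensor([[1.00, 1.10, 0.95, 1.05, 0.90]])
running_peak = torch.cummax(cumrets, dim=1)[0]  # [1.00, 1.10, 1.10, 1.10, 1.10]
drawdown = (cumrets / running_peak) - 1         # <= 0 at every timestep
mdd = -drawdown.min(dim=1)[0]                   # largest peak-to-trough drop, here ~0.1818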
Example #3
 def forward(self, x):
     if torch.__version__ >= '1.5.0':
         dim, flip = self.cummax_dim_flip[self.mode]
         if flip:
             x = x.flip(dim)
         pool_tensor, _ = torch.cummax(x, dim=dim)
         if flip:
             pool_tensor = pool_tensor.flip(dim)
         return pool_tensor
     else:
         return self.corner_pool.apply(x)
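For context, corner pooling is a directional running maximum; a minimal sketch (assuming, as in the module above, that `cummax_dim_flip` maps one pooling direction to a plain cummax and the opposite direction to flip-cummax-flip):

import torch

x = torch.tensor([[3., 1., 2., 5., 4.]])
no_flip = torch.cummax(x, dim=-1)[0]                      # [3., 3., 3., 5., 5.]
with_flip = torch.cummax(x.flip(-1), dim=-1)[0].flip(-1)  # [5., 5., 5., 5., 4.]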
Example #4
 def forward(self, input, warpfield):
     '''
     :param input: audio signal to be warped (B x 2 x T)
     :param warpfield: the corresponding warpfield (B x 2 x T)
     :return: the warped signal (B x 2 x T), guaranteed to be monotonic
     '''
     warpfield = self._to_absolute_positions(warpfield, input.shape[-1])
     # ensure monotonicity: each warp position must be at least as large as the previous one
     warpfield = th.cummax(warpfield, dim=-1)[0]
     # warp
     warped = self.warper(input, warpfield)
     return warped
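A minimal sketch (made-up warpfield) of the monotonicity step in isolation:

import torch

warp = torch.tensor([[0.0, 2.0, 1.5, 3.0, 2.8]])
mono = torch.cummax(warp, dim=-1)[0]  # [0.0, 2.0, 2.0, 3.0, 3.0]: positions never move backwards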
Example #5
def boundariesfilt(score_arr, stepfilt_length, axis):
    if stepfilt_length > 0:
        temp_scores_fwd = stepfilt(
            score_arr, stepfilt_length, axis=axis
        )
        temp_scores_bwd = stepfilt(
            np.flip(score_arr, axis=axis), stepfilt_length, axis=axis
        )
    else:
        temp_scores_fwd = score_arr
        temp_scores_bwd = np.flip(score_arr, axis=axis)
    # envelope: pointwise minimum of the forward and backward cumulative maxima
    return np.minimum(
        torch.cummax(
            torch.from_numpy(temp_scores_fwd.copy()),
            dim=axis
        )[0].numpy(),
        np.flip(
            torch.cummax(
                torch.from_numpy(temp_scores_bwd.copy()),
                dim=axis
            )[0].numpy(),
            axis=axis
        ),
    )
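The pointwise minimum of the forward and backward cumulative maxima acts as an envelope that fills dips between peaks; a minimal sketch with made-up scores and the smoothing disabled:

import numpy as np
import torch

s = np.array([0.1, 0.8, 0.3, 0.9, 0.2], dtype=np.float32)
fwd = torch.cummax(torch.from_numpy(s), dim=0)[0].numpy()      # [0.1, 0.8, 0.8, 0.9, 0.9]
bwd = np.flip(torch.cummax(torch.from_numpy(np.flip(s).copy()), dim=0)[0].numpy())
env = np.minimum(fwd, bwd)                                     # [0.1, 0.8, 0.8, 0.9, 0.2]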
Example #6
    def parse(self, x, pos):
        """Parse input sentence.

        Args:
          x: input tokens (required).
          pos: position for each token (optional).
        Returns:
          distance: syntactic distance
          height: syntactic height
        """

        mask = (x != self.pad)
        mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)

        h = self.emb(x)
        for i in range(self.n_parse_layers):
            h = h.masked_fill(~mask[:, :, None], 0)
            h = self.parser_layers[i](h)

        height = self.height_ff(h).squeeze(-1)
        height.masked_fill_(~mask, -1e9)

        distance = self.distance_ff(h).squeeze(-1)
        distance.masked_fill_(~mask_shifted, 1e9)

        # Calibrating the distance and height to the same level
        length = distance.size(1)
        height_max = height[:, None, :].expand(-1, length, -1)
        height_max = torch.cummax(height_max.triu(0) -
                                  torch.ones_like(height_max).tril(-1) * 1e9,
                                  dim=-1)[0].triu(0)

        margin_left = torch.relu(
            F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) -
            height_max)
        margin_right = torch.relu(distance[:, None, :] - height_max)
        margin = torch.where(margin_left > margin_right, margin_right,
                             margin_left).triu(0)

        margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1),
                                  dim=1)
        margin.masked_fill_(~margin_mask, 0)
        margin = margin.max()

        distance = distance - margin

        return distance, height
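The triu/tril construction above uses cummax to compute, for every span (i, j) with j >= i, the maximum height over h[i..j]; a minimal sketch on a single made-up sequence:

import torch

h = torch.tensor([1., 3., 2.])
hm = h.unsqueeze(0).expand(3, -1)
hm = torch.cummax(hm.triu(0) - torch.ones_like(hm).tril(-1) * 1e9, dim=-1)[0].triu(0)
# hm[i, j] == max(h[i..j]) on the upper triangle:
# [[1., 3., 3.],
#  [0., 3., 3.],
#  [0., 0., 2.]]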
Example #7
    def forward(self, x):
        if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
            if torch.onnx.is_in_onnx_export():
                assert torch.__version__ >= '1.7.0', \
                    'When `cummax` serves as an intermediate component whose ' \
                    'outputs are used as inputs for other modules, the ' \
                    'PyTorch version must be >= 1.7.0; otherwise an error ' \
                    'appears like: `RuntimeError: tuple appears in op that ' \
                    'does not forward tuples, unsupported kind: prim::PythonOp`.'

            dim, flip = self.cummax_dim_flip[self.mode]
            if flip:
                x = x.flip(dim)
            pool_tensor, _ = torch.cummax(x, dim=dim)
            if flip:
                pool_tensor = pool_tensor.flip(dim)
            return pool_tensor
        else:
            return self.corner_pool.apply(x)
Example #8
    def monotonize(self, input_data):
        # number of quantiles
        num_quantiles = input_data.size()[-1]

        # split into below 50% and above 50%
        idx_50 = num_quantiles // 2

        # if a small number of quantiles are estimated or quantile levels are not centered at 0.5
        if num_quantiles < 3 or self.quantile_levels[0, idx_50] != 0.5:
            return input_data

        # below 50%
        below_50 = input_data[:, :(idx_50 + 1)].contiguous()
        below_50 = torch.flip(
            torch.cummin(torch.flip(below_50, [-1]), -1)[0], [-1])

        # above 50%
        above_50 = input_data[:, idx_50:].contiguous()
        above_50 = torch.cummax(above_50, -1)[0]

        # refined output
        ordered_data = torch.cat([below_50[:, :-1], above_50], -1)
        return ordered_data
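A minimal sketch of the rearrangement (made-up quantile estimates at levels [0.1, 0.3, 0.5, 0.7, 0.9], so the median sits at index 2): cummin runs right-to-left below the median and cummax left-to-right above it, yielding a non-decreasing output with the median unchanged.

import torch

q = torch.tensor([[0.2, 0.5, 0.4, 0.6, 0.55]])
below = torch.flip(torch.cummin(torch.flip(q[:, :3], [-1]), -1)[0], [-1])  # [0.2, 0.4, 0.4]
above = torch.cummax(q[:, 2:], -1)[0]                                      # [0.4, 0.6, 0.6]
mono = torch.cat([below[:, :-1], above], -1)                               # [0.2, 0.4, 0.4, 0.6, 0.6]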
Example #9
    def compute_lkd(self, arr_params, act, stim, side, return_details):
        '''
        Generates the loglikelihood (and prior)
        Params:
            arr_params (array): parameters of shape [nb_chains, nb_params]
            act (array of shape [nb_sessions, nb_trials]): actions performed by the mice
            stim (array of shape [nb_sessions, nb_trials]): stimulus contrast (between -1 and 1) observed by the mice
            side (array of shape [nb_sessions, nb_trials]): stimulus side (-1 (right), 1 (left)) observed by the mice
            return_details (boolean): if True, return the loglikelihood and the prior; otherwise return only the loglikelihood
        Output:
            loglikelihood (array of length nb_chains): loglikelihood for each chain
            prior (array of shape [nb_sessions, nb_chains, nb_trials]): prior for each chain and session
        '''
        nb_chains = len(arr_params)
        if not self.repetition_bias and self.UB_fit and not self.loguniform_prior:
            tau0, tau1, tau2, gamma, zeta_pos, zeta_neg, lapse_pos, lapse_neg = torch.tensor(
                arr_params, device=self.device, dtype=torch.float32).T
            lb, tau, ub = tau0, tau0 + tau1, tau0 + tau1 + tau2
        elif self.UB_fit and self.repetition_bias and not self.loguniform_prior:
            tau0, tau1, tau2, gamma, zeta_pos, zeta_neg, lapse_pos, lapse_neg, rep_bias = torch.tensor(
                arr_params, device=self.device, dtype=torch.float32).T
            lb, tau, ub = tau0, tau0 + tau1, tau0 + tau1 + tau2
        elif not self.UB_fit and not self.loguniform_prior and not self.repetition_bias:
            tau0, tau1, gamma, zeta_pos, zeta_neg, lapse_pos, lapse_neg = torch.tensor(
                arr_params, device=self.device, dtype=torch.float32).T
            lb, tau, ub = tau0, tau0 + tau1, torch.zeros(
                len(gamma), device=self.device,
                dtype=torch.float32) + self.nb_blocklengths
        elif not self.UB_fit and self.loguniform_prior and self.repetition_bias:
            loglb, logtau, gamma, zeta_pos, zeta_neg, lapse_pos, lapse_neg, rep_bias = torch.tensor(
                arr_params, device=self.device, dtype=torch.float32).T
            lb, tau, ub = torch.exp(loglb), torch.exp(logtau), torch.zeros(
                len(gamma), device=self.device,
                dtype=torch.float32) + self.nb_blocklengths
        elif not self.UB_fit and self.loguniform_prior and not self.repetition_bias:
            loglb, logtau, gamma, zeta_pos, zeta_neg, lapse_pos, lapse_neg = torch.tensor(
                arr_params, device=self.device, dtype=torch.float32).T
            lb, tau, ub = torch.exp(loglb), torch.exp(logtau), torch.zeros(
                len(gamma), device=self.device,
                dtype=torch.float32) + self.nb_blocklengths
        else:
            raise ValueError('model instance not supported')
        act, stim, side = torch.tensor(act,
                                       device=self.device,
                                       dtype=torch.float32), torch.tensor(
                                           stim,
                                           device=self.device,
                                           dtype=torch.float32), torch.tensor(
                                               side,
                                               device=self.device,
                                               dtype=torch.float32)
        nb_sessions = len(act)

        alpha = torch.zeros([
            nb_sessions, nb_chains, act.shape[-1], self.nb_blocklengths,
            self.nb_typeblocks
        ],
                            device=self.device,
                            dtype=torch.float32)
        alpha[:, :, 0, 0, 1] = 1
        alpha = alpha.reshape(nb_sessions, nb_chains, -1,
                              self.nb_typeblocks * self.nb_blocklengths)
        h = torch.zeros([
            nb_sessions, nb_chains, self.nb_typeblocks * self.nb_blocklengths
        ],
                        device=self.device,
                        dtype=torch.float32)

        zetas = unsqueeze(zeta_pos) * (torch.unsqueeze(
            side, 1) > 0) + unsqueeze(zeta_neg) * (torch.unsqueeze(side, 1) <=
                                                   0)
        lapses = unsqueeze(lapse_pos) * (torch.unsqueeze(
            side, 1) > 0) + unsqueeze(lapse_neg) * (torch.unsqueeze(side, 1) <=
                                                    0)

        # build transition matrix
        b = torch.zeros([self.nb_blocklengths, 3, 3],
                        device=self.device,
                        dtype=torch.float32)
        b[1:][:, 0, 0], b[1:][:, 1, 1], b[1:][:, 2,
                                              2] = 1, 1, 1  # case when l_t > 0
        b[0][0][-1], b[0][-1][0], b[0][1][np.array(
            [0, 2])] = 1, 1, 1. / 2  # case when l_t = 1
        n = torch.unsqueeze(
            torch.arange(1,
                         self.nb_blocklengths + 1,
                         device=self.device,
                         dtype=torch.float32), 0)
        ref = torch.exp(-n / torch.unsqueeze(tau, -1)) * (torch.unsqueeze(
            lb, -1) <= n) * (torch.unsqueeze(ub, -1) >= n)
        hazard = torch.cummax(
            ref /
            torch.flip(torch.cumsum(torch.flip(ref, (1, )), 1) + 1e-18,
                       (1, )), 1)[0]
        padding = torch.unsqueeze(
            torch.zeros(self.nb_blocklengths - 1,
                        device=self.device,
                        dtype=torch.float32), 0)

        l = torch.stack([
            torch.tensor(torch.cat(
                (torch.unsqueeze(hazard[i], -1),
                 torch.cat((torch.diag(1 - hazard[i, :-1]), padding), axis=0)),
                axis=-1),
                         dtype=torch.float32) for i in range(len(hazard))
        ])  # l_{t-1}, l_t

        transition = torch.stack([
            1e-12 + torch.transpose(
                l[k][:, :, np.newaxis, np.newaxis] * b[np.newaxis], 1,
                2).reshape(self.nb_typeblocks * self.nb_blocklengths, -1)
            for k in range(len(l))
        ])
        # likelihood
        Rhos = Normal(loc=torch.unsqueeze(stim, 1), scale=zetas).cdf(0)
        ones = torch.ones((nb_chains, nb_sessions, act.shape[-1]),
                          device=self.device,
                          dtype=torch.float32)
        gamma_unsqueezed, side_unsqueezed = torch.unsqueeze(
            torch.unsqueeze(gamma, 1), -1), torch.unsqueeze(side, 0)
        lks = torch.stack([
            gamma_unsqueezed * (side_unsqueezed == -1) +
            (1 - gamma_unsqueezed) * (side_unsqueezed == 1), ones * 1. / 2,
            gamma_unsqueezed * (side_unsqueezed == 1) +
            (1 - gamma_unsqueezed) * (side_unsqueezed == -1)
        ]).T
        to_update = torch.unsqueeze(torch.unsqueeze(act != 0, -1), -1) * 1

        for i_trial in range(act.shape[-1]):
            if i_trial > 0:
                alpha[:, :, i_trial] = torch.sum(
                    torch.unsqueeze(h, -1) * transition, axis=2
                ) * to_update[:, i_trial - 1] + alpha[:, :, i_trial - 1] * (
                    1 - to_update[:, i_trial - 1])
            h = alpha[:, :, i_trial] * lks[i_trial].repeat(
                1, 1, self.nb_blocklengths)
            h = h / torch.unsqueeze(torch.sum(h, axis=-1), -1)

        predictive = torch.sum(
            alpha.reshape(nb_sessions, nb_chains, -1, self.nb_blocklengths,
                          self.nb_typeblocks), 3)
        Pis = predictive[:, :, :, 0] * unsqueeze(
            gamma) + predictive[:, :, :, 1] * 0.5 + predictive[:, :, :, 2] * (
                1 - unsqueeze(gamma))
        pRight, pLeft = Pis * Rhos, (1 - Pis) * (1 - Rhos)
        pActions = torch.stack(
            (pRight / (pRight + pLeft), pLeft / (pRight + pLeft)))

        unsqueezed_lapses = torch.unsqueeze(lapses, 0)
        if self.repetition_bias:
            unsqueezed_rep_bias = torch.unsqueeze(
                torch.unsqueeze(torch.unsqueeze(rep_bias, 0), 0), -1)
            pActions[:, :, :, 0] = pActions[:, :, :, 0] * (
                1 - unsqueezed_lapses[:, :, :, 0]) + unsqueezed_lapses[:, :, :,
                                                                       0] / 2.
            pActions[:, :, :, 1:] = pActions[:, :, :, 1:] * (
                1 - unsqueezed_lapses[:, :, :, 1:] - unsqueezed_rep_bias
            ) + unsqueezed_lapses[:, :, :,
                                  1:] / 2. + unsqueezed_rep_bias * torch.unsqueeze(
                                      torch.stack(((act[:, :-1] == -1) * 1,
                                                   (act[:, :-1] == 1) * 1)), 2)
        else:
            pActions = pActions * (1 - torch.unsqueeze(
                lapses, 0)) + torch.unsqueeze(lapses, 0) / 2.

        pActions[torch.isnan(pActions)] = 0

        p_ch = pActions[0] * (torch.unsqueeze(act, 1) == -1) + pActions[1] * (
            torch.unsqueeze(act, 1) == 1) + 1 * (torch.unsqueeze(
                act, 1) == 0)  # discard trials where agent did not answer

        priors = 1 - Pis.detach().cpu()
        p_ch_cpu = p_ch.detach().cpu()
        logp_ch = torch.log(
            torch.minimum(torch.maximum(p_ch_cpu, torch.tensor(1e-8)),
                          torch.tensor(1 - 1e-8)))

        # clean up gpu memory
        # if self.use_gpu:
        #     del tau0, tau1, tau2, gamma, zeta_pos, zeta_neg, lapse_pos, lapse_neg, lb, tau, ub, act, stim, side, s, lks
        #     del alpha, h, zetas, lapses, b, n, ref, hazard, padding, l, transition, ones, Rhos, gamma_unsqueezed
        #     del predictive, Pis, pRight, pLeft, pActions, p_ch, unsqueezed_lapses
        #     if self.repetition_bias:
        #         del rep_bias, unsqueezed_rep_bias
        #     torch.cuda.empty_cache()

        if return_details:
            return logp_ch, priors
        return np.array(torch.sum(logp_ch, axis=(0, -1)))
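The cummax in the hazard-rate computation enforces a monotone hazard: the ratio of the block-length weights to their reverse cumulative sum (a survival function) can be numerically non-monotone, and the running maximum cleans that up. A standalone sketch with made-up exponential weights:

import torch

n = torch.arange(1., 6.)
ref = torch.exp(-n / 2.0)  # unnormalized block-length weights
survival = torch.flip(torch.cumsum(torch.flip(ref, (0,)), 0) + 1e-18, (0,))
hazard = torch.cummax(ref / survival, 0)[0]  # non-decreasing hazard rate, last entry ~1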
Example #10
 def other_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     c = torch.randint(0, 8, (5, ), dtype=torch.int64)
     e = torch.randn(4, 3)
     f = torch.randn(4, 4, 4)
     size = [0, 1]
     dims = [0, 1]
     return (
         torch.atleast_1d(a),
         torch.atleast_2d(a),
         torch.atleast_3d(a),
         torch.bincount(c),
         torch.block_diag(a),
         torch.broadcast_tensors(a),
         torch.broadcast_to(a, (4,)),
         # torch.broadcast_shapes(a),
         torch.bucketize(a, b),
         torch.cartesian_prod(a),
         torch.cdist(e, e),
         torch.clone(a),
         torch.combinations(a),
         torch.corrcoef(a),
         # torch.cov(a),
         torch.cross(e, e),
         torch.cummax(a, 0),
         torch.cummin(a, 0),
         torch.cumprod(a, 0),
         torch.cumsum(a, 0),
         torch.diag(a),
         torch.diag_embed(a),
         torch.diagflat(a),
         torch.diagonal(e),
         torch.diff(a),
         torch.einsum("iii", f),
         torch.flatten(a),
         torch.flip(e, dims),
         torch.fliplr(e),
         torch.flipud(e),
         torch.kron(a, b),
         torch.rot90(e),
         torch.gcd(c, c),
         torch.histc(a),
         torch.histogram(a),
         torch.meshgrid(a),
         torch.lcm(c, c),
         torch.logcumsumexp(a, 0),
         torch.ravel(a),
         torch.renorm(e, 1, 0, 5),
         torch.repeat_interleave(c),
         torch.roll(a, 1, 0),
         torch.searchsorted(a, b),
         torch.tensordot(e, e),
         torch.trace(e),
         torch.tril(e),
         torch.tril_indices(3, 3),
         torch.triu(e),
         torch.triu_indices(3, 3),
         torch.vander(a),
         torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
         torch.view_as_complex(torch.randn(4, 2)),
         torch.resolve_conj(a),
         torch.resolve_neg(a),
     )
Example #11
    def forward(self, agg_graph: dgl.DGLGraph, prop_graph: dgl.DGLGraph,
                traversal_order, new_node_ids) -> torch.Tensor:
        tg = agg_graph.local_var()
        pg = prop_graph.local_var()

        nfeat = tg.ndata["nfeat"]
        # h_self = nfeat
        h_self = self.encode_time(nfeat, tg.ndata["timestamp"])
        tg.ndata["nfeat"] = h_self
        tg.edata["efeat"] = self.fc_edge(tg.edata["efeat"])
        # efeat = tg.edata["efeat"]
        # tg.apply_edges(lambda edges: {
        #     "efeat":
        #     torch.cat((edges.src["nfeat"], edges.data["efeat"]), dim=1)
        # })
        # tg.edata["efeat"] = self.encode_time(tg.edata["efeat"], tg.edata["timestamp"])
        degs = tg.ndata["degree"]

        # agg_graph aggregation
        if self._agg_type == "pool":
            tg.edata["efeat"] = F.relu(self.fc_pool(tg.edata["efeat"]))
            tg.update_all(fn.u_add_e("nfeat", "efeat", "m"),
                          fn.max("m", "neigh"))
            h_neigh = tg.ndata["neigh"]
        elif self._agg_type in ["mean", "gcn", "lstm"]:
            tg.update_all(fn.u_add_e("nfeat", "efeat", "m"),
                          fn.sum("m", "neigh"))
            h_neigh = tg.ndata["neigh"]
        else:
            raise KeyError("Aggregator type {} not recognized.".format(
                self._agg_type))

        pg.ndata["neigh"] = h_neigh
        # prop_graph propagation
        if False:  # disabled path: per-node propagation via prop_nodes, kept for reference
            if self._agg_type == "mean":
                pg.prop_nodes(traversal_order,
                              message_func=fn.copy_src("neigh", "tmp"),
                              reduce_func=fn.sum("tmp", "acc"))
                h_neigh = h_neigh + pg.ndata["acc"]
                h_neigh = h_neigh / degs.unsqueeze(-1)
            elif self._agg_type == "gcn":
                pg.prop_nodes(traversal_order,
                              message_func=fn.copy_src("neigh", "tmp"),
                              reduce_func=fn.sum("tmp", "acc"))
                h_neigh = h_neigh + pg.ndata["acc"]
                h_neigh = (h_self + h_neigh) / (degs.unsqueeze(-1) + 1)
            elif self._agg_type == "pool":
                pg.prop_nodes(traversal_order,
                              message_func=fn.copy_src("neigh", "tmp"),
                              reduce_func=fn.max("tmp", "acc"))
                h_neigh = torch.max(h_neigh, pg.ndata["acc"])
            elif self._agg_type == "lstm":
                h_neighs = [
                    self._lstm_reducer(h_neigh[ids]) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
        else:
            if self._agg_type == "mean":
                h_neighs = [
                    torch.cumsum(h_neigh[ids], dim=0) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
                h_neigh = h_neigh / degs.unsqueeze(-1)
            elif self._agg_type == "gcn":
                h_neighs = [
                    torch.cumsum(h_neigh[ids], dim=0) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
                h_neigh = (h_self + h_neigh) / (degs.unsqueeze(-1) + 1)
            elif self._agg_type == "pool":
                h_neighs = [
                    torch.cummax(h_neigh[ids], dim=0)[0]  # keep values from the (values, indices) tuple
                    for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
            elif self._agg_type == "lstm":
                h_neighs = [
                    self._lstm_reducer(h_neigh[ids]) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]

        if self._agg_type == "gcn":
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        return rst
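Note that torch.cummax returns a (values, indices) namedtuple, which is why the "pool" branch above must index [0] before concatenating; a minimal sketch:

import torch

vals, idx = torch.cummax(torch.tensor([[1., 3.], [2., 2.]]), dim=0)
# vals: [[1., 3.], [2., 3.]]  running max down each column
# idx:  [[0, 0], [1, 0]]      row index where the running max was attained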
Example #12
    def simulate(self, arr_params, stim, side, valid, nb_simul=50, ignore_likelihood=False):
        '''
        Simulate choices for each chain and return the average fraction of correct choices per chain.
        '''
        assert stim.shape == side.shape, "side and stim don't have the same shape"
        if self.repetition_bias:
            raise NotImplementedError

        nb_chains = len(arr_params)
        if arr_params.shape[-1] == 4:
            zeta_pos, zeta_neg, lapse_pos, lapse_neg = torch.tensor(arr_params, device=self.device, dtype=torch.float32).T
        else:
            raise NotImplementedError

        stim, side = torch.tensor(stim, device=self.device, dtype=torch.float32), torch.tensor(side, device=self.device, dtype=torch.float32)
        nb_sessions = len(stim)
        lb, tau, ub, gamma = 20, 60, 100, 0.8

        alpha = torch.zeros([nb_sessions, nb_chains, stim.shape[-1], self.nb_blocklengths, self.nb_typeblocks], device=self.device, dtype=torch.float32)
        alpha[:, :, 0, 0, 1] = 1
        alpha = alpha.reshape(nb_sessions, nb_chains, -1, self.nb_typeblocks * self.nb_blocklengths)
        h = torch.zeros([nb_sessions, nb_chains, self.nb_typeblocks * self.nb_blocklengths], device=self.device, dtype=torch.float32)

        # arr_params has exactly 4 parameters here (enforced above), so zeta/lapse are side-dependent
        zetas = unsqueeze(zeta_pos) * (torch.unsqueeze(side, 1) > 0) + unsqueeze(zeta_neg) * (torch.unsqueeze(side, 1) <= 0)
        lapses = unsqueeze(lapse_pos) * (torch.unsqueeze(side, 1) > 0) + unsqueeze(lapse_neg) * (torch.unsqueeze(side, 1) <= 0)

        # build transition matrix
        b = torch.zeros([self.nb_blocklengths, 3, 3], device=self.device, dtype=torch.float32)
        b[1:][:,0,0], b[1:][:,1,1], b[1:][:,2,2] = 1, 1, 1 # case when l_t > 0
        b[0][0][-1], b[0][-1][0], b[0][1][np.array([0, 2])] = 1, 1, 1./2 # case when l_t = 1
        n = torch.arange(1, self.nb_blocklengths+1, device=self.device, dtype=torch.float32)
        ref    = torch.exp(-n/tau) * (lb <= n) * (ub >= n)
        hazard = torch.cummax(ref/torch.flip(torch.cumsum(torch.flip(ref, (0,)), 0) + 1e-18, (0,)), 0)[0]
        padding = torch.zeros(self.nb_blocklengths-1, device=self.device, dtype=torch.float32)
        l = torch.cat((torch.unsqueeze(hazard, -1), torch.cat(
                    (torch.diag(1 - hazard[:-1]), padding[np.newaxis]), axis=0)), axis=-1) # l_{t-1}, l_t
        transition = 1e-12 + torch.transpose(l[:,:,np.newaxis,np.newaxis] * b[np.newaxis], 1, 2).reshape(self.nb_typeblocks * self.nb_blocklengths, -1)        

        # likelihood
        Rhos = Normal(loc=torch.unsqueeze(stim, 1), scale=zetas).cdf(0)
        ones = torch.ones((nb_sessions, stim.shape[-1]), device=self.device, dtype=torch.float32)
        lks = torch.stack([gamma*(side==-1) + (1-gamma) * (side==1), ones * 1./2, gamma*(side==1) + (1-gamma)*(side==-1)]).T

        for i_trial in range(stim.shape[-1]):
            # save priors
            if i_trial > 0:
                alpha[:, :, i_trial] = torch.sum(torch.unsqueeze(h, -1) * transition, axis=2)
            h = alpha[:, :, i_trial] * torch.unsqueeze(lks[i_trial], 1).repeat(1, 1, self.nb_blocklengths)
            h = h/torch.unsqueeze(torch.sum(h, axis=-1), -1)

        predictive = torch.sum(alpha.reshape(nb_sessions, nb_chains, -1, self.nb_blocklengths, self.nb_typeblocks), 3)
        Pis  = predictive[:, :, :, 0] * gamma + predictive[:, :, :, 1] * 0.5 + predictive[:, :, :, 2] * (1 - gamma)
        if not ignore_likelihood:
            pRight, pLeft = Pis * Rhos, (1 - Pis) * (1 - Rhos)
            pActions = torch.stack((pRight/(pRight + pLeft), pLeft/(pRight + pLeft)))
            pActions = pActions * (1 - torch.unsqueeze(lapses, 0)) + torch.unsqueeze(lapses, 0) / 2.
        else:
            pRight, pLeft = (Pis > 0.5)*1. + Pis * (Pis==0.5), ((1 - Pis) > 0.5) * 1. + Pis * (Pis==0.5)
            pActions = torch.stack((pRight/(pRight + pLeft), pLeft/(pRight + pLeft)))

        act_sim = 2 * (torch.rand(nb_sessions, nb_chains, stim.shape[-1], nb_simul) < torch.unsqueeze(pActions[1], -1)) - 1

        correct = (act_sim == side[:, np.newaxis, :, np.newaxis])
        correct = np.array(correct, dtype=float)
        valid_arr = np.tile(valid[:, np.newaxis, :, np.newaxis], (1, nb_chains, 1, nb_simul))
        correct[np.logical_not(valid_arr)] = np.nan
        perf = np.nanmean(correct, axis=(0, -2, -1))
        return perf
Example #13
 print(torch.amin(mat1, 0))  # column-wise
 print(torch.amin(mat1, 1))  # row-wise
 print(torch.argmin(mat1))  # over all elements
 print(torch.argmin(mat1, 0))  # column-wise
 print(torch.argmin(mat1, 1))  # row-wise
 print(torch.argsort(mat1, 0))  # column-wise, returns the indices
 print(torch.argsort(mat1, 1))  # row-wise
 print(torch.topk(mat1, 2))
 # print(torch.msort(mat1))  # sorts along dim 0
 print(torch.kthvalue(mat1, 1, 0))
 print(torch.kthvalue(mat1, 1, 1))
 print(torch.logsumexp(mat1, 1))  # row-wise
 """cum"""
 print("cum function:")
 print(torch.logcumsumexp(x, dim=0))  # log of the cumulative sum of exp(x_i)
 print(torch.cummax(x, dim=0))
 print(torch.cummin(x, dim=0))
 print(torch.cumprod(x, dim=0))
 print(torch.cumsum(x, dim=0))
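 # Note: torch.cummax and torch.cummin return a (values, indices) namedtuple,
 # so their prints above show both tensors; index [0] or .values for the values alone.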
 """vec <> vec"""
 a = torch.tensor([9.7, float('nan'), 3.1, float('nan')])
 b = torch.tensor([-2.2, 0.5, float('nan'), float('nan')])
 c = torch.tensor([9.7, 1, 3.1, 4])
 d = torch.tensor([1.7, 1.2, 3.1, 2])
 print(torch.maximum(a, b))
 print(torch.minimum(a, b))
 print(torch.fmod(a, 2))
 print(torch.dist(c, d, 1))  # p-norm
 print(torch.norm(c))
 print(torch.div(c, d))
 print(torch.true_divide(c, d))  # rounding_mode=None