Ejemplo n.º 1
0
    def create_process(self, vector: jnp.array):
        """Build the tag pair for a "create" event from a 4x4 matrix.

        ``left_matrix`` (2x4) keeps rows 2 and 3 of ``vector`` and
        ``right_matrix`` (4x3) keeps columns 1-3 (2x4 @ 4x4 @ 4x3 -> 2x3), so
        ``tags == vector[2:4, 1:4]``. The updated tags copy columns 1-2 of
        row 1 into row 0.

        NOTE(review): the update rule is reconstructed from a broken in-place
        assignment (``final_tags[0][1:3] = tags[1][:, 1:3]``); confirm the
        intended semantics.

        Args:
            vector: 4x4 numeric array (NumPy or JAX). It is not modified.

        Returns:
            Flat array of 12 values: the 6 original tags followed by the 6
            updated tags.
        """
        # astype returns a new array, so the result must be kept; np.float no
        # longer exists in current NumPy, hence the explicit dtype string.
        vector = vector.astype('float64')
        left_matrix = jnp.array([[0, 0, 1, 0], [0, 0, 0, 1]])
        right_matrix = jnp.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
        tags = jnp.dot(jnp.dot(left_matrix, vector), right_matrix)
        # JAX arrays are immutable: use a functional update, which also keeps
        # ``tags`` itself unchanged.
        final_tags = tags.at[0, 1:3].set(tags[1, 1:3])
        # jnp.append without ``axis`` flattens both operands.
        return jnp.append(tags, final_tags)
Ejemplo n.º 2
0
    def __init__(
        self,
        output_channels: int,
        kernel_shape: Union[int, Tuple[int, int]],
        resample_kernel: jnp.array,
        downsample_factor: int = 1,
        gain: float = 1.0,
        data_format: ChannelOrder = ChannelOrder.channels_last,
        name: str = None,
    ):
        """Set up a strided "VALID" Conv2D plus a normalized resample kernel.

        Args:
            output_channels: Number of output channels, forwarded to
                ``hk.Conv2D``.
            kernel_shape: Convolution kernel shape, forwarded to ``hk.Conv2D``.
            resample_kernel: 1-D separable taps (expanded to 2-D with an outer
                product) or an already 2-D kernel.
            downsample_factor: Used as the convolution stride.
            gain: The stored kernel is scaled so that it sums to ``gain``.
            data_format: Channel layout; its ``name`` is passed to
                ``hk.Conv2D``.
            name: Forwarded to ``super().__init__``.

        Raises:
            ValueError: If ``resample_kernel`` is neither 1-D nor 2-D.
        """
        super().__init__(name=name)
        if resample_kernel.ndim == 1:
            # Separable taps: k[:, None] * k[None, :] gives the 2-D kernel.
            resample_kernel = resample_kernel[:, None] * resample_kernel[None, :]
        elif resample_kernel.ndim != 2:
            # Reject 0-d as well as >2-d kernels; the former chained comparison
            # ``0 <= ndim > 2`` only caught the latter.
            raise ValueError(
                f"Resample kernel has invalid shape {resample_kernel.shape}")

        self.conv = hk.Conv2D(
            output_channels,
            kernel_shape=kernel_shape,
            stride=downsample_factor,
            padding="VALID",
            data_format=data_format.name,
        )
        # Normalize the kernel to unit sum, then apply the gain.
        self.resample_kernel = jnp.array(
            resample_kernel) * gain / resample_kernel.sum()
        self.downsample_factor = downsample_factor
        self.data_format = data_format
Ejemplo n.º 3
0
    def read_process(self, vector: jnp.array):
        """Compute the tag update for a "read" event.

        ``left_matrix`` (2x4) keeps rows 2 and 3 of ``vector`` and
        ``right_matrix`` (4x3) keeps columns 1-3 (2x4 @ 4x4 @ 4x3 -> 2x3), so
        ``tags == vector[2:4, 1:4]``. Tag layout:

            srcNode: sTag iTag cTag
            desNode: sTag iTag cTag

        The updated tags replace row 0, columns 1-2 with the column-wise
        minimum over both rows.

        Args:
            vector: 4x4 array. Elements that are torch ``Tensor``s or NumPy
                arrays are replaced *in place* by Python floats (``detach``
                drops any autograd history); all other elements are left
                as-is.

        Returns:
            Flat array of 12 values: the 6 original tags followed by the 6
            updated tags.
        """
        for i, row in enumerate(vector):
            for j, elem in enumerate(row):
                if isinstance(elem, Tensor):
                    vector[i][j] = float(elem.cpu().detach().numpy())
                elif isinstance(elem, np.ndarray):
                    vector[i][j] = float(elem)
        vector = vector.astype('float64')

        left_matrix = jnp.array([[0, 0, 1, 0], [0, 0, 0, 1]])
        right_matrix = jnp.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
        tags = jnp.dot(jnp.dot(left_matrix, vector), right_matrix)

        # ``.at[...].set`` is the functional update that replaced the removed
        # ``jax.ops.index_update``.
        final_tags = tags.at[0, 1:3].set(jnp.min(tags[:, 1:3], axis=0))
        # jnp.append without ``axis`` flattens both operands.
        return jnp.append(tags, final_tags)
Ejemplo n.º 4
0
 def from_model_output(cls, *, logits: jnp.array, labels: jnp.array,
                       **kwargs) -> Metric:
   """Creates the metric from per-example argmax correctness.

   Args:
     logits: Scores whose last axis indexes classes; must have exactly one
       more dimension than ``labels``.
     labels: Class ids; their dtype must be ``jnp.int32``.
     **kwargs: Forwarded to the parent ``from_model_output``.

   Returns:
     The metric built by the parent ``from_model_output`` from float32
     indicators (1.0 where ``argmax(logits) == labels``, else 0.0).

   Raises:
     ValueError: If the rank or dtype requirement is not met.
   """
   ranks_match = logits.ndim == labels.ndim + 1
   if not ranks_match or labels.dtype != jnp.int32:
     raise ValueError(
         f"Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}=="
         f"labels.ndim+1={labels.ndim + 1}")
   is_correct = logits.argmax(axis=-1) == labels
   return super().from_model_output(
       values=is_correct.astype(jnp.float32), **kwargs)
Ejemplo n.º 5
0
def log_likelihood(x: np.array, theta: np.array):
    """
    Return the summed log-density of the data ``x`` given parameters ``theta``.

    ``_calc_vars(theta)`` supplies the mean and covariance of a multivariate
    normal. ``x`` is reshaped into 4 points of 2 values each (so it must hold
    8 values), and the per-point log-densities are added up.

    Raises:
        AssertionError: if ``theta`` is not a 1-D array of length 5.
    """
    assert theta.ndim == 1 and len(theta) == 5, "theta must be a 1D array of length 5"
    mean, cov = _calc_vars(theta)
    points = x.reshape((4, 2))
    per_point = jax.scipy.stats.multivariate_normal.logpdf(
        points, mean=mean, cov=cov)
    return per_point.sum()
Ejemplo n.º 6
0
    def __init__(self, inp_feat_uncentered_gram:np.array) -> None:
        """Precompute the constant term used to center input features.

        Meant for an uncentered feature vector Φ = [Φ_1, …, Φ_n] that feeds a
        centered operator.

        Args:
            inp_feat_uncentered_gram (np.array): Square matrix returned by
                inp_feat_uncentered.inner(), where inp_feat_uncentered == Φ.

        Sets ``self.const_term`` (shape (n, 1)) to the mean of the row means
        minus each row mean.
        """
        gram = inp_feat_uncentered_gram
        assert len(gram.shape) == 2
        n_rows, n_cols = gram.shape
        assert n_rows == n_cols
        row_means = gram.mean(axis=1, keepdims=True)
        self.const_term = row_means.mean() - row_means
Ejemplo n.º 7
0
 def from_model_output(cls,
                       values: jnp.array,
                       mask: Optional[jnp.array] = None,
                       **_) -> Metric:
   """Builds the running sums (total, sum of squares, count) from a batch.

   Args:
     values: Scalar or 1-D array; a scalar is promoted to shape (1,). The
       rank is validated with ``utils.check_param``.
     mask: Optional per-value mask. Entries where it is 0/False contribute to
       none of ``total``, ``sum_of_squares`` or ``count``. Defaults to ones.
     **_: Ignored.

   Returns:
     ``cls(total=..., sum_of_squares=..., count=...)``.
   """
   if values.ndim == 0:
     values = values[None]
   utils.check_param(values, ndim=1)
   if mask is None:
     mask = jnp.ones(values.shape[0])
   # Mask ``total`` exactly like ``sum_of_squares``; otherwise masked-out
   # entries are excluded from ``count`` but still summed into ``total``.
   zeros = jnp.zeros_like(values)
   return cls(
       total=jnp.where(mask, values, zeros).sum(),
       sum_of_squares=jnp.where(mask, values**2, zeros).sum(),
       count=mask.sum(),
   )
Ejemplo n.º 8
0
    def __init__(self,
                 inp_feat: InpVecT,
                 outp_feat: OutVecT,
                 matr: np.array,
                 mean_center_inp: bool = False,
                 decenter_outp: bool = False,
                 normalize=False,
                 outp_bias: np.array = None):
        """Store a coefficient matrix together with its input/output features.

        Args:
            inp_feat: Input features. If ``mean_center_inp`` is true they are
                replaced by
                ``inp_feat.extend_reduce([CenterInpFeat(inp_feat.inner())])``.
            outp_feat: Output features; ``len(outp_feat)`` sizes the bias.
            matr: Stored as ``self.matr``.
            mean_center_inp: Whether to wrap ``inp_feat`` with the
                ``CenterInpFeat`` reduction.
            decenter_outp: If true, the bias is a (1, n) row filled with
                ``1 / n`` (n = ``len(outp_feat)``). Mutually exclusive with
                ``outp_bias``.
            normalize: Stored as ``self._normalize``.
            outp_bias: Explicit bias whose ``shape[1]`` equals
                ``len(outp_feat)``. When omitted (and ``decenter_outp`` is
                false) a (1, n) zero bias is used.
        """
        assert not (
            decenter_outp and outp_bias is not None
        ), "Either decenter_outp == True or outp_bias != None, but not both"
        self.matr = matr
        self.mean_center_inp = mean_center_inp

        if mean_center_inp:
            self.inp_feat = inp_feat.extend_reduce(
                [CenterInpFeat(inp_feat.inner())])
        else:
            self.inp_feat = inp_feat
        self.outp_feat = outp_feat
        self._normalize = normalize

        self.debias_outp = decenter_outp
        n_out = len(outp_feat)
        if decenter_outp:
            outp_bias = np.ones((1, n_out)) / n_out
        elif outp_bias is None:
            # Only default to zeros when no bias was given, so an explicit
            # ``outp_bias`` is honoured instead of being overwritten.
            outp_bias = np.zeros((1, n_out))

        assert outp_bias.shape[1] == n_out
        if len(outp_bias.squeeze().shape) == 1:
            # Squeezes to a vector: store it as a single (1, n) row.
            self.bias = outp_bias.squeeze()[np.newaxis, :]
        else:
            assert outp_bias.shape[0] == n_out
            self.bias = outp_bias
Ejemplo n.º 9
0
def sample(rng, theta: np.array, num_samples_per_theta: int):
    """
    Draw simulated data for each parameter vector in ``theta``.

    For every row of ``theta``, ``_calc_vars`` yields a mean and covariance.
    ``4 * num_samples_per_theta`` points are drawn from that multivariate
    normal and grouped into ``num_samples_per_theta`` samples of 4 points.

    Args:
        rng: JAX PRNG key. It is split so each parameter vector gets its own
            independent key.
        theta: Array of shape (5,) or (batch, 5).
        num_samples_per_theta: Number of samples drawn per parameter vector.

    Returns:
        Array of shape (batch * num_samples_per_theta, 4 * d), where d is the
        dimension of the mean from ``_calc_vars``; rows for the first
        parameter vector come first.
    """

    def _sample(key, _th):
        mu, Sigma = _calc_vars(_th)
        samp = jax.random.multivariate_normal(
            key, mu, Sigma, shape=(4 * num_samples_per_theta, 1)
        )
        return np.reshape(samp, (num_samples_per_theta, 4, -1))

    # Actually enforce the documented contract: 1-D or 2-D, last dim == 5.
    assert theta.ndim in (1, 2) and theta.shape[-1] == 5, (
        "theta must be a 1/2D array with 5D final dim"
    )
    if theta.ndim == 1:
        theta = theta.reshape((1, theta.shape[0]))
    # One key per parameter vector: a single key closed over inside vmap
    # would make every theta reuse exactly the same normal draws.
    keys = jax.random.split(rng, theta.shape[0])
    return np.reshape(
        jax.vmap(_sample)(keys, theta),
        (theta.shape[0] * num_samples_per_theta, -1),
    )
Ejemplo n.º 10
0
 def __init__(
     self,
     resample_kernel: jnp.array,
     upsample_factor: int = 1,
     gain: float = 1.0,
     data_format: ChannelOrder = ChannelOrder.channels_last,
     name: str = None,
 ):
     """Store a normalized 2-D resample kernel and the upsampling settings.

     Args:
         resample_kernel: 1-D separable taps (expanded to 2-D with an outer
             product) or an already 2-D kernel.
         upsample_factor: Stored as ``self.upsample_factor``.
         gain: The stored kernel is scaled so that it sums to ``gain``.
         data_format: Stored as ``self.data_format``.
         name: Forwarded to ``super().__init__``.

     Raises:
         ValueError: If ``resample_kernel`` is neither 1-D nor 2-D.
     """
     super().__init__(name=name)
     if resample_kernel.ndim == 1:
         # Separable taps: k[:, None] * k[None, :] gives the 2-D kernel.
         resample_kernel = resample_kernel[:, None] * resample_kernel[None, :]
     elif resample_kernel.ndim != 2:
         # Reject 0-d as well as >2-d kernels; the former chained comparison
         # ``0 <= ndim > 2`` only caught the latter.
         raise ValueError(
             f"Resample kernel has invalid shape {resample_kernel.shape}")
     # Normalize the kernel to unit sum, then apply the gain.
     self.resample_kernel = jnp.array(
         resample_kernel) * gain / resample_kernel.sum()
     self.upsample_factor = upsample_factor
     self.data_format = data_format
Ejemplo n.º 11
0
    def __init__(self,
                 start: np.array,
                 periodicities: np.array = None,
                 stepsizes: np.array = None,
                 example: np.array = None):
        """Set up a rollout that zig-zags through indices, digit by digit.

        When a digit would become larger than its periodicity, the previous
        digit is increased. The leading periodicity may be non-finite to mark
        an ever-increasing digit. NOTE(review): earlier docs mentioned both
        NaN and ``np.inf`` for this; confirm which one the stepping code
        expects.

        Args:
            start (np.array): Initial index, stored as ``self.current``.
            periodicities (np.array, optional): Per-digit periodicities.
                Required unless ``example`` is given.
            stepsizes (np.array, optional): Per-digit increments. Required
                unless ``example`` is given.
            example (np.array, optional): Consecutive indices (one per row).
                If given, it takes precedence: each digit's step size is the
                median positive difference between consecutive rows, and its
                periodicity is the column maximum plus that step size.
        """
        self.current = start

        if example is None:
            assert stepsizes is not None and periodicities is not None
            self.stepsizes = stepsizes
            self.periodicities = periodicities
        else:
            deltas = example[1:] - example[:-1]
            positive_medians = [np.median(col[col > 0]) for col in deltas.T]
            self.stepsizes = np.array(positive_medians)
            # Presumably the first value past the observed range, i.e. where
            # the digit wraps around.
            self.periodicities = example.max(0) + self.stepsizes
Ejemplo n.º 12
0
 def __call__(self, inp: np.array, axis: int = 0) -> np.array:
     """Scale ``inp`` along ``axis`` by ``self.prefactors``.

     The ``(axis + 1) % 2`` arithmetic presumes a 2-D ``inp``: the
     prefactors get a singleton dimension on the other axis so that they
     broadcast along ``axis``. ``inp`` is first cast to the prefactors' dtype.
     """
     prefactors = self.prefactors
     assert prefactors.shape[0] == inp.shape[axis]
     other_axis = (axis + 1) % 2
     cast_input = inp.astype(prefactors.dtype)
     return cast_input * np.expand_dims(prefactors, axis=other_axis)
Ejemplo n.º 13
0
 def __call__(self, inp: np.array, axis: int = 0) -> np.array:
     """Average ``inp`` over ``axis``, keeping that axis with length 1."""
     averaged = inp.mean(keepdims=True, axis=axis)
     return averaged
Ejemplo n.º 14
0
 def __call__(self, inp: np.array, axis: int = 0) -> np.array:
     """Scale ``inp`` along ``axis`` by ``self.prefactors`` (no shape check).

     The ``(axis + 1) % 2`` arithmetic presumes a 2-D ``inp``: the
     prefactors get a singleton dimension on the other axis so that they
     broadcast along ``axis``. ``inp`` is first cast to the prefactors' dtype.
     """
     target_dtype = self.prefactors.dtype
     cast_input = inp.astype(target_dtype)
     other_axis = (axis + 1) % 2
     return cast_input * np.expand_dims(self.prefactors, axis=other_axis)
Ejemplo n.º 15
0
 def get_read_grad(self, vector: jnp.array):
     """Return slice ``[:, 1, :]`` of the Jacobian produced by ``self.jrev``.

     ``vector`` is cast to float64 before being passed to ``self.jrev``.
     NOTE(review): the Jacobian was annotated as [12 * 4 * 4] (12 outputs
     for a 4x4 input); confirm against how ``jrev`` is built.
     """
     as_float = vector.astype('float64')
     jacobian = self.jrev(as_float)
     return jacobian[:, 1, :]
Ejemplo n.º 16
0
 def reduce_first_ax(self, inp:np.array) -> np.array:
     """Subtract the axis-0 mean from ``inp``, then add ``self.const_term``."""
     axis0_mean = inp.mean(0, keepdims=True)
     centered = inp - axis0_mean
     return centered + self.const_term
Ejemplo n.º 17
0
 def get_exec_grad(self, vector: jnp.array):
     """Return slice ``[:, 1, :]`` of the Jacobian produced by ``self.jrev``.

     ``vector`` is cast to float64 first. NOTE(review): the Jacobian was
     annotated as [12 * 4 * 4]; confirm against how ``jrev`` is built.
     """
     jacobian = self.jrev(vector.astype('float64'))
     second_input_row = jacobian[:, 1, :]
     return second_input_row
Ejemplo n.º 18
0
    def write_process(self, vector: jnp.array):
        """Compute the tag update for a "write" event.

        ``left_matrix`` (2x4) keeps rows 2 and 3 of ``vector`` and
        ``right_matrix`` (4x3) keeps columns 1-3, so
        ``tags == vector[2:4, 1:4]`` (2x3). Row 1 of ``vector`` holds the
        parameters ``[a_b, a_e, benign_thresh, susp_thresh]``.

        Three candidates replace row 0, columns 1-2 of ``tags``:
          * benign: column-wise min of ``tags`` with ``a_b`` added to row 0,
          * suspect environment: the same with ``a_e``,
          * dangerous: plain column-wise min of ``tags[:, 1:3]``.
        They are blended with softmax weights computed from the two
        thresholds.

        Args:
            vector: 4x4 array. Elements that are torch ``Tensor``s or NumPy
                arrays are replaced *in place* by Python floats (``detach``
                drops their autograd history); other elements are left as-is.

        Returns:
            Flat array of 12 values: the 6 original tags followed by the 6
            blended tags.
        """
        # Replace torch tensors / NumPy arrays by plain floats, in place.
        for i, row in enumerate(vector):
            for j, elem in enumerate(row):
                if isinstance(elem, Tensor):
                    vector[i][j] = float(elem.cpu().detach().numpy())
                elif isinstance(elem, np.ndarray):
                    vector[i][j] = float(elem)

        # Read the parameters only after the conversion above so none of them
        # can still be a torch Tensor.
        a_b = vector[1][0]
        a_e = vector[1][1]
        benign_thresh = vector[1][2]
        susp_thresh = vector[1][3]

        vector = vector.astype('float64')
        left_matrix = jnp.array([[0, 0, 1, 0], [0, 0, 0, 1]])
        right_matrix = jnp.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
        tags = jnp.dot(jnp.dot(left_matrix, vector), right_matrix)

        benign_mul = benign_thresh + susp_thresh
        susp_mul = (1 - benign_thresh) + susp_thresh
        dangerous_mul = (1 - benign_thresh) + (1 - susp_thresh)

        attenuation_b = jnp.array([[0, a_b, a_b], [0, 0, 0]])
        attenuation_e = jnp.array([[0, a_e, a_e], [0, 0, 0]])

        # ``.at[...].set`` is the functional update that replaced the removed
        # ``jax.ops.index_update``.
        tag_benign = tags.at[0, 1:3].set(
            jnp.min(tags + attenuation_b, axis=0)[1:3])
        tag_susp_env = tags.at[0, 1:3].set(
            jnp.min(tags + attenuation_e, axis=0)[1:3])
        tag_dangerous = tags.at[0, 1:3].set(jnp.min(tags[:, 1:3], axis=0))

        # (3, 6) matrix of flattened candidates, blended by softmax weights.
        possible_tags = jnp.stack(
            [tag_benign.ravel(), tag_susp_env.ravel(), tag_dangerous.ravel()])
        tags_probability = jax.nn.softmax(
            jnp.array([benign_mul, susp_mul, dangerous_mul]))
        final_tags = jnp.dot(tags_probability, possible_tags)

        # jnp.append without ``axis`` flattens both operands.
        return jnp.append(tags, final_tags)