def create_process(self, vector: jnp.array):
    """Compute the tag vector produced by a process-creation event.

    ``left_matrix`` selects rows 2 and 3 of ``vector`` and ``right_matrix``
    selects columns 1..3, so ``tags`` is the 2x3 sub-block
    ``vector[2:4, 1:4]``.  The updated block is a copy of ``tags`` in which
    row 0, columns 1:3 are replaced by row 1, columns 1:3.

    NOTE(review): the original statement
    ``final_tags[0][1:3] = tags[1][:, 1:3]`` could not execute (in-place
    write on an immutable JAX array, 2-D index on a 1-D row).  Copying
    row 1 into row 0 is the presumed intent -- confirm.

    Args:
        vector: Array with an ``astype`` method; the two matrix products
            require it to be 4x4.

    Returns:
        A 1-D array: the flattened ``tags`` followed by the flattened
        updated tags.
    """
    # The original discarded the result of ``astype`` and used the removed
    # ``np.float`` alias; actually convert here.
    vector = vector.astype(float)
    left_matrix = jnp.array([[0, 0, 1, 0], [0, 0, 0, 1]])
    right_matrix = jnp.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
    tags = jnp.dot(jnp.dot(left_matrix, vector), right_matrix)
    # JAX arrays are immutable: ``.at[...].set`` returns an updated copy.
    final_tags = tags.at[0, 1:3].set(tags[1, 1:3])
    # ``jnp.append`` without an axis flattens both operands.
    return jnp.append(tags.reshape(1, -1), final_tags)
def __init__(
    self,
    output_channels: int,
    kernel_shape: Union[int, Tuple[int, int]],
    resample_kernel: jnp.array,
    downsample_factor: int = 1,
    gain: float = 1.0,
    data_format: ChannelOrder = ChannelOrder.channels_last,
    name: str = None,
):
    """Set up a strided convolution together with a resampling kernel.

    Args:
        output_channels: Forwarded to ``hk.Conv2D``.
        kernel_shape: Forwarded to ``hk.Conv2D``.
        resample_kernel: 1-D or 2-D array.  A 1-D kernel is expanded to 2-D
            by taking its outer product with itself.
        downsample_factor: Used as the convolution stride and stored on the
            instance.
        gain: Multiplier applied to the sum-normalised resample kernel.
        data_format: Channel ordering; its ``name`` is forwarded to
            ``hk.Conv2D``.
        name: Forwarded to the parent constructor.

    Raises:
        ValueError: If ``resample_kernel`` is neither 1-D nor 2-D.
    """
    super().__init__(name=name)
    if resample_kernel.ndim == 1:
        # Separable kernel: build the 2-D kernel as an outer product.
        resample_kernel = resample_kernel[:, None] * resample_kernel[None, :]
    elif resample_kernel.ndim != 2:
        # The previous test ``0 <= ndim > 2`` only rejected ndim > 2 and let
        # 0-d kernels through unvalidated.
        raise ValueError(
            f"Resample kernel has invalid shape {resample_kernel.shape}")

    self.conv = hk.Conv2D(
        output_channels,
        kernel_shape=kernel_shape,
        stride=downsample_factor,
        padding="VALID",
        data_format=data_format.name,
    )
    # Normalise the kernel to sum to ``gain``.
    self.resample_kernel = jnp.array(
        resample_kernel) * gain / resample_kernel.sum()
    self.downsample_factor = downsample_factor
    self.data_format = data_format
def read_process(self, vector: jnp.array):
    """Compute the tag vector produced by a read event.

    Shapes: (2x4) . (4x4) . (4x3) -> 2x3.  Per the original notes the
    resulting ``tags`` block is laid out as::

        srcNode: sTag iTag cTag
        desNode: sTag iTag cTag

    The updated block is a copy of ``tags`` in which row 0, columns 1:3
    hold the column-wise minimum of ``tags[:, 1:3]``.

    Args:
        vector: 4x4 array.  A NumPy object array may contain torch tensors
            or NumPy arrays as elements; those entries are replaced by
            plain floats **in place** (as before), which drops any autograd
            trace they carried.

    Returns:
        A 1-D array: the flattened ``tags`` followed by the flattened
        updated tags.
    """
    # Only object arrays can hold Tensor / ndarray elements; skipping the
    # scan otherwise avoids 16 pointless indexing ops (also under tracing).
    if vector.dtype == object:
        for i, row in enumerate(vector):
            for j, item in enumerate(row):
                if isinstance(item, Tensor):
                    vector[i][j] = float(item.cpu().detach().numpy())
                elif isinstance(item, np.ndarray):
                    vector[i][j] = float(item)
    vector = vector.astype(float)

    left_matrix = jnp.array([[0, 0, 1, 0], [0, 0, 0, 1]])
    right_matrix = jnp.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
    tags = jnp.dot(jnp.dot(left_matrix, vector), right_matrix)
    # ``jax.ops.index_update`` was removed from JAX; ``.at[...].set`` is the
    # supported functional update.
    final_tags = tags.at[0, 1:3].set(jnp.min(tags[:, 1:3], axis=0))
    # ``jnp.append`` without an axis flattens both operands, so no explicit
    # reshape is needed.
    return jnp.append(tags, final_tags)
def from_model_output(cls, *, logits: jnp.array, labels: jnp.array,
                      **kwargs) -> Metric:
    """Build the metric from logits and integer labels.

    Each element's value is 1.0 where the argmax of ``logits`` over the
    last axis equals the label and 0.0 otherwise; the values are handed to
    the parent class's ``from_model_output`` together with ``kwargs``.

    Raises:
        ValueError: If ``logits`` does not have exactly one more dimension
            than ``labels`` or ``labels`` is not of dtype ``jnp.int32``.
    """
    rank_mismatch = logits.ndim != labels.ndim + 1
    wrong_label_dtype = labels.dtype != jnp.int32
    if rank_mismatch or wrong_label_dtype:
        raise ValueError(
            f"Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}=="
            f"labels.ndim+1={labels.ndim + 1}")
    predictions = logits.argmax(axis=-1)
    hits = (predictions == labels).astype(jnp.float32)
    return super().from_model_output(values=hits, **kwargs)
def log_likelihood(x: np.array, theta: np.array):
    """Return the summed Gaussian log-density of ``x`` for parameters ``theta``.

    ``theta`` must be a 1-D array of length 5; it is turned into a mean and
    a covariance by ``_calc_vars``.  ``x`` is reshaped to four 2-D points
    and the per-point multivariate-normal log-densities are added up.
    """
    is_flat = theta.ndim == 1
    assert is_flat and len(theta) == 5, "theta must be a 1D array of length 5"
    mean, covariance = _calc_vars(theta)
    points = x.reshape((4, 2))
    log_densities = jax.scipy.stats.multivariate_normal.logpdf(
        points, mean=mean, cov=covariance)
    return log_densities.sum()
def __init__(self, inp_feat_uncentered_gram:np.array) -> None:
    """Precompute the constant term used to center input features.

    Meant for an uncentered feature vector Φ = [Φ_1, …, Φ_n].

    Args:
        inp_feat_uncentered_gram (np.array): Square 2-D array; per the
            original documentation this is the output of
            inp_feat_uncentered.inner(), where inp_feat_uncentered == Φ.
    """
    gram_shape = inp_feat_uncentered_gram.shape
    assert len(gram_shape) == 2
    assert gram_shape[0] == gram_shape[1]

    # Row means as a column vector; the stored term is the grand mean of
    # those row means minus each row mean.
    row_means = inp_feat_uncentered_gram.mean(axis=1, keepdims=True)
    grand_mean = row_means.mean()
    self.const_term = grand_mean - row_means
def from_model_output(cls, values: jnp.array, mask: Optional[jnp.array] = None, **_) -> Metric:
    """Build the metric from a 1-D batch of values and an optional mask.

    Args:
        values: Scalar or 1-D array.  A scalar is promoted to length 1.
        mask: Optional array selecting which entries of ``values`` count.
            Defaults to all ones (every entry counts).
        **_: Ignored.

    Returns:
        An instance of ``cls`` holding the masked sum, the masked sum of
        squares and the number of selected entries.
    """
    if values.ndim == 0:
        values = values[None]
    utils.check_param(values, ndim=1)
    if mask is None:
        mask = jnp.ones(values.shape[0])
    zeros = jnp.zeros_like(values)
    return cls(
        # ``total`` previously summed *all* values, ignoring the mask,
        # while ``sum_of_squares`` and ``count`` honoured it; the three
        # statistics must be computed over the same entries.
        total=jnp.where(mask, values, zeros).sum(),
        sum_of_squares=jnp.where(mask, values**2, zeros).sum(),
        count=mask.sum(),
    )
def __init__(self, inp_feat: InpVecT, outp_feat: OutVecT, matr: np.array, mean_center_inp: bool = False, decenter_outp: bool = False, normalize=False, outp_bias: np.array = None):
    """Store the operator matrix, its feature vectors and the output bias.

    Args:
        inp_feat: Input feature vector.  When ``mean_center_inp`` is true
            it is wrapped via ``extend_reduce`` with a ``CenterInpFeat``
            built from ``inp_feat.inner()``.
        outp_feat: Output feature vector; only its length is used here.
        matr: Operator matrix, stored as ``self.matr``.
        mean_center_inp: Whether to center the input features.
        decenter_outp: If true, the bias is the uniform row
            ``ones((1, n)) / n`` with ``n = len(outp_feat)``.
        normalize: Stored as ``self._normalize``.
        outp_bias: Optional explicit bias with ``shape[1] == len(outp_feat)``.
            Mutually exclusive with ``decenter_outp``.  When neither is
            given the bias is a row of zeros.
    """
    assert not (
        decenter_outp and outp_bias is not None
    ), "Either decenter_outp == True or outp_bias != None, but not both"
    self.matr = matr
    self.mean_center_inp = mean_center_inp
    if mean_center_inp:
        self.inp_feat = inp_feat.extend_reduce(
            [CenterInpFeat(inp_feat.inner())])
    else:
        self.inp_feat = inp_feat
    self.outp_feat = outp_feat
    self._normalize = normalize
    self.debias_outp = decenter_outp

    n_out = len(outp_feat)
    if decenter_outp:
        outp_bias = np.ones((1, n_out)) / n_out
    elif outp_bias is None:
        # Previously this branch ran unconditionally and silently replaced
        # a caller-supplied ``outp_bias`` with zeros.
        outp_bias = np.zeros((1, n_out))

    assert outp_bias.shape[1] == n_out
    squeezed = outp_bias.squeeze()
    if len(squeezed.shape) == 1:
        # Single bias row: normalise to shape (1, n_out).
        self.bias = squeezed[np.newaxis, :]
    else:
        assert outp_bias.shape[0] == n_out
        self.bias = outp_bias
def sample(rng, theta: np.array, num_samples_per_theta: int):
    """Draw samples for one or several parameter vectors ``theta``.

    Args:
        rng: Random key passed to ``jax.random.multivariate_normal``.
        theta: 1-D array of length 5, or 2-D array whose last dimension is
            5 (one parameter vector per row).
        num_samples_per_theta: Number of samples drawn per parameter
            vector; each sample consists of 4 draws.

    Returns:
        Array with ``theta.shape[0] * num_samples_per_theta`` rows, each
        row being one flattened sample.
    """

    def _sample(_th):
        mu, Sigma = _calc_vars(_th)
        # NOTE(review): ``rng`` is closed over and ``vmap`` maps only over
        # ``theta``, so every row of ``theta`` uses the same key -- confirm
        # that shared randomness across rows is intended.
        samp = jax.random.multivariate_normal(
            rng, mu, Sigma, shape=(4 * num_samples_per_theta, 1)
        )
        # NOTE(review): this reshape runs on values traced by ``jax.vmap``,
        # so ``np`` presumably aliases ``jax.numpy`` in this file -- confirm.
        return np.reshape(samp, (num_samples_per_theta, 4, -1))

    # The previous check ``assert theta.shape[-1]`` only tested truthiness,
    # so any non-empty final dimension (and any rank) passed.
    assert theta.ndim in (1, 2) and theta.shape[-1] == 5, \
        "theta must be a 1/2D array with 5D final dim"
    if theta.ndim == 1:
        theta = theta.reshape((1, theta.shape[0]))
    return np.reshape(
        jax.vmap(_sample)(theta), (theta.shape[0] * num_samples_per_theta, -1)
    )
def __init__(
    self,
    resample_kernel: jnp.array,
    upsample_factor: int = 1,
    gain: float = 1.0,
    data_format: ChannelOrder = ChannelOrder.channels_last,
    name: str = None,
):
    """Store a normalised resampling kernel and the upsampling settings.

    Args:
        resample_kernel: 1-D or 2-D array.  A 1-D kernel is expanded to 2-D
            by taking its outer product with itself.
        upsample_factor: Stored on the instance.
        gain: Multiplier applied to the sum-normalised resample kernel.
        data_format: Channel ordering, stored on the instance.
        name: Forwarded to the parent constructor.

    Raises:
        ValueError: If ``resample_kernel`` is neither 1-D nor 2-D.
    """
    super().__init__(name=name)
    if resample_kernel.ndim == 1:
        # Separable kernel: build the 2-D kernel as an outer product.
        resample_kernel = resample_kernel[:, None] * resample_kernel[None, :]
    elif resample_kernel.ndim != 2:
        # The previous test ``0 <= ndim > 2`` only rejected ndim > 2 and let
        # 0-d kernels through unvalidated.
        raise ValueError(
            f"Resample kernel has invalid shape {resample_kernel.shape}")

    # Normalise the kernel to sum to ``gain``.
    self.resample_kernel = jnp.array(
        resample_kernel) * gain / resample_kernel.sum()
    self.upsample_factor = upsample_factor
    self.data_format = data_format
def __init__(self, start: np.array, periodicities: np.array = None, stepsizes: np.array = None, example: np.array = None):
    """Initialise the rollout state, step sizes and periodicities.

    Per the original documentation, an IndexRollout zig-zags through
    indices with the given periodicities: when a digit would exceed its
    periodicity, the previous digit is increased.  NOTE(review): the
    original text says both that at most one periodicity may be NaN
    ("ever-increasing") and that the first one may be np.inf -- confirm
    which convention holds.

    Args:
        start (np.array): Initial index, stored as ``self.current``.
        periodicities (np.array, optional): Periodicity per digit.
            Required when ``example`` is not given.
        stepsizes (np.array, optional): Step size per digit.  Required
            when ``example`` is not given.
        example (np.array, optional): Consecutive indices, one per row.
            When given, the step size of each column is the median of its
            positive row-to-row differences and the periodicity is the
            column maximum plus that step size; ``periodicities`` and
            ``stepsizes`` are then ignored.
    """
    self.current = start

    if example is None:
        assert stepsizes is not None and periodicities is not None
        self.stepsizes = stepsizes
        self.periodicities = periodicities
        return

    # Row-to-row differences, examined one column (digit) at a time.
    deltas = example[1:] - example[:-1]
    medians = []
    for column in deltas.T:
        medians.append(np.median(column[column > 0]))
    self.stepsizes = np.array(medians)
    self.periodicities = example.max(0) + self.stepsizes
def __call__(self, inp: np.array, axis: int = 0) -> np.array:
    """Scale ``inp`` by ``self.prefactors`` along ``axis``.

    ``inp`` is cast to the dtype of the prefactors.  The prefactors get a
    new axis at position ``(axis + 1) % 2`` (position 1 for ``axis=0``,
    position 0 for ``axis=1``) so that they broadcast against ``inp``.
    Assumes 1-D prefactors and a 2-D ``inp`` -- TODO confirm.

    Raises:
        AssertionError: If the number of prefactors differs from the size
            of ``inp`` along ``axis``.
    """
    factors = self.prefactors
    assert factors.shape[0] == inp.shape[axis]
    converted = inp.astype(factors.dtype)
    broadcastable = np.expand_dims(factors, axis=(axis + 1) % 2)
    return converted * broadcastable
def __call__(self, inp: np.array, axis: int = 0) -> np.array:
    """Average ``inp`` along ``axis``, keeping the reduced axis with size 1."""
    averaged = inp.mean(axis=axis, keepdims=True)
    return averaged
def __call__(self, inp: np.array, axis: int = 0) -> np.array:
    """Scale ``inp`` by ``self.prefactors`` along ``axis``.

    ``inp`` is cast to the dtype of the prefactors.  The prefactors get a
    new axis at position ``(axis + 1) % 2`` (position 1 for ``axis=0``,
    position 0 for ``axis=1``) so that they broadcast against ``inp``.
    Assumes 1-D prefactors and a 2-D ``inp`` -- TODO confirm.
    """
    factors = self.prefactors
    converted = inp.astype(factors.dtype)
    broadcastable = np.expand_dims(factors, axis=(axis + 1) % 2)
    return converted * broadcastable
def get_read_grad(self, vector: jnp.array):
    """Return the slice at index 1 of the middle axis of ``self.jrev(vector)``.

    ``vector`` is cast to float64 before being passed to ``self.jrev``.
    Per the original note the result of ``self.jrev`` is laid out as
    [12 * 4 * 4]; ``self.jrev`` is presumably a Jacobian function built
    elsewhere -- confirm.
    """
    as_float = vector.astype('float64')
    jacobian = self.jrev(as_float)
    # Keep only index 1 along the second axis.
    return jacobian[:, 1, :]
def reduce_first_ax(self, inp:np.array) -> np.array:
    """Subtract the mean over axis 0 from ``inp`` and add ``self.const_term``."""
    column_means = inp.mean(0, keepdims=True)
    centered = inp - column_means
    return centered + self.const_term
def get_exec_grad(self, vector: jnp.array):
    """Return the slice at index 1 of the middle axis of ``self.jrev(vector)``.

    ``vector`` is cast to float64 before being passed to ``self.jrev``.
    Per the original note the result of ``self.jrev`` is laid out as
    [12 * 4 * 4].
    """
    as_float = vector.astype('float64')
    jacobian = self.jrev(as_float)
    # Keep only index 1 along the second axis.
    return jacobian[:, 1, :]
def write_process(self, vector: jnp.array):
    """Compute the tag vector produced by a write event.

    Row 1 of ``vector`` supplies the parameters
    ``[a_b, a_e, benign_thresh, susp_thresh]``; rows 2 and 3, columns 1..3
    form the 2x3 ``tags`` block.  Three candidate updates of ``tags`` are
    built, each replacing row 0, columns 1:3:

    * benign:    column-wise min of ``tags`` with ``a_b`` added to row 0;
    * suspect:   column-wise min of ``tags`` with ``a_e`` added to row 0;
    * dangerous: column-wise min of ``tags`` unchanged.

    The candidates are blended with softmax weights computed from the two
    thresholds.

    Args:
        vector: 4x4 array.  A NumPy object array may contain torch tensors
            or NumPy arrays as elements; those entries are replaced by
            plain floats **in place** (as before), which drops any autograd
            trace they carried.

    Returns:
        A 1-D array: the flattened ``tags`` followed by the blended
        (flattened) updated tags.
    """
    # Only object arrays can hold Tensor / ndarray elements; skipping the
    # scan otherwise avoids 16 pointless indexing ops (also under tracing).
    if vector.dtype == object:
        for i, row in enumerate(vector):
            for j, item in enumerate(row):
                if isinstance(item, Tensor):
                    vector[i][j] = float(item.cpu().detach().numpy())
                elif isinstance(item, np.ndarray):
                    vector[i][j] = float(item)
    vector = vector.astype(float)

    # Read the parameters *after* conversion so that thresholds stored as
    # tensors are plain numbers too (previously only a_b / a_e were fixed up).
    a_b = vector[1][0]
    a_e = vector[1][1]
    benign_thresh = vector[1][2]
    susp_thresh = vector[1][3]

    left_matrix = jnp.array([[0, 0, 1, 0], [0, 0, 0, 1]])
    right_matrix = jnp.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
    tags = jnp.dot(jnp.dot(left_matrix, vector), right_matrix)

    benign_mul = benign_thresh + susp_thresh
    susp_mul = (1 - benign_thresh) + susp_thresh
    dangerous_mul = (1 - benign_thresh) + (1 - susp_thresh)

    attenuation_b = jnp.array([[0, a_b, a_b], [0, 0, 0]])
    attenuation_e = jnp.array([[0, a_e, a_e], [0, 0, 0]])

    # ``jax.ops.index_update`` was removed from JAX; ``.at[...].set`` is the
    # supported functional update.
    tag_benign = tags.at[0, 1:3].set(
        jnp.min(tags + attenuation_b, axis=0)[1:3]).reshape(1, -1)
    tag_susp_env = tags.at[0, 1:3].set(
        jnp.min(tags + attenuation_e, axis=0)[1:3]).reshape(1, -1)
    tag_dangerous = tags.at[0, 1:3].set(
        jnp.min(tags[:, 1:3], axis=0)).reshape(1, -1)

    # One flattened candidate per row, weighted by the softmax below.
    possible_tags = jnp.concatenate(
        [tag_benign, tag_susp_env, tag_dangerous])
    tags_probability = jax.nn.softmax(
        jnp.array([benign_mul, susp_mul, dangerous_mul]))
    final_tags = jnp.dot(tags_probability, possible_tags)

    return jnp.append(tags.reshape(1, -1), final_tags)