def create_small_net_with_conv_layer(self, conv_layer, outputs_per_channel):
    """Wire ``conv_layer`` into a small Input -> conv -> Flatten -> Dense(1) net.

    ``conv_layer`` is attached to the already-existing ``self.input_layer``.
    A Flatten layer and a single-output Dense layer are stacked on top. The
    dense kernel alternates +1.0 / -1.0 and has 2 * ``outputs_per_channel``
    rows. The dense bias is 1.0. After the forward-pass variables are built,
    the dense layer is put in ``OneAndZeros`` scoring mode and made active,
    and ``update_mxts()`` is called on the input layer.

    Sets ``self.conv_layer``, ``self.flatten_layer``, ``self.dense_layer`` and
    ``self.inp`` (a float32 array of shape (2, 4, 2) holding the values
    -8.0 .. 7.0).

    Args:
        conv_layer: the convolution layer under test.
        outputs_per_channel: number of (+1, -1) weight pairs in the dense
            kernel. NOTE(review): assumes the flattened conv output has
            exactly 2 * outputs_per_channel elements -- confirm per test.
    """
    self.conv_layer = conv_layer
    self.conv_layer.set_inputs(self.input_layer)

    self.flatten_layer = layers.Flatten()
    self.flatten_layer.set_inputs(self.conv_layer)

    # Row of alternating +1/-1 weights, transposed into a single-output
    # column of shape (2 * outputs_per_channel, 1).
    kernel = np.array(
        [[w for _ in range(outputs_per_channel) for w in (1.0, -1.0)]],
        dtype="float32").T
    self.dense_layer = layers.Dense(
        kernel=kernel,
        bias=np.array([1]).astype("float32"),
        dense_mxts_mode=DenseMxtsMode.Linear)
    self.dense_layer.set_inputs(self.flatten_layer)
    self.dense_layer.build_fwd_pass_vars()

    self.input_layer.reset_mxts_updated()
    self.dense_layer.set_scoring_mode(layers.ScoringMode.OneAndZeros)
    self.dense_layer.set_active()
    self.input_layer.update_mxts()

    # 16 values shifted to -8..7, reshaped to (2, 2, 4), then the last two
    # axes are swapped to give shape (2, 4, 2).
    self.inp = (np.arange(16).reshape((2, 2, 4)).astype("float32")
                - 8.0).transpose((0, 2, 1))
def create_small_net_with_pool_layer(self, pool_layer, outputs_per_channel):
    """Stack Flatten -> Dense(1) on ``pool_layer`` fed by ``self.input_layer``.

    The dense kernel repeats the weight pair (2, 3) ``outputs_per_channel``
    times and is stored as a float32 column. The dense bias is 1.0. Once the
    forward-pass variables are built, the dense layer is put in
    ``OneAndZeros`` scoring mode and made active, and ``update_mxts()`` is
    called on the input layer.

    Sets ``self.pool_layer``, ``self.flatten_layer`` and ``self.dense_layer``.
    """
    pool = self.pool_layer = pool_layer
    pool.set_inputs(self.input_layer)

    flatten = self.flatten_layer = layers.Flatten()
    flatten.set_inputs(pool)

    # (1, 2 * outputs_per_channel) row of repeated (2, 3) pairs; transposed
    # below into a single-output column.
    weight_row = np.array(
        [[w for _ in range(outputs_per_channel) for w in (2, 3)]])
    dense = self.dense_layer = layers.Dense(
        kernel=weight_row.astype("float32").T,
        bias=np.array([1]).astype("float32"),
        dense_mxts_mode=DenseMxtsMode.Linear)
    dense.set_inputs(flatten)
    dense.build_fwd_pass_vars()

    dense.set_scoring_mode(layers.ScoringMode.OneAndZeros)
    dense.set_active()
    self.input_layer.update_mxts()
def prepare_batch_norm_deeplift_model(self, axis):
    """Build Input -> BatchNormalization -> Flatten -> Dense(1) for a test.

    The input has ``batch_shape=(None, 2, 2, 2)``. The batch-norm layer takes
    its parameters from ``self.gamma``, ``self.beta``, ``self.mean``,
    ``self.var`` and ``self.epsilon``, and ``axis`` is forwarded to it. The
    dense layer has an all-ones (8, 1) float32 kernel and a zero bias, so it
    presumably sums the 8 flattened batch-norm outputs (verify against
    ``layers.Dense``).

    Sets ``self.input_layer``, ``self.batch_norm_layer``,
    ``self.flatten_layer`` and ``self.dense_layer``.
    """
    input_layer = self.input_layer = layers.Input(batch_shape=(None, 2, 2, 2))

    norm_kwargs = dict(gamma=self.gamma, beta=self.beta, axis=axis,
                       mean=self.mean, var=self.var, epsilon=self.epsilon)
    batch_norm = self.batch_norm_layer = layers.BatchNormalization(
        **norm_kwargs)
    batch_norm.set_inputs(input_layer)

    flatten = self.flatten_layer = layers.Flatten()
    flatten.set_inputs(batch_norm)

    # (1, 8) ones transposed to (8, 1): one weight per element of the
    # flattened 2x2x2 batch-norm output.
    ones_kernel = np.ones((1, 8)).astype("float32").T
    zero_bias = np.zeros(1).astype("float32")
    dense = self.dense_layer = layers.Dense(
        kernel=ones_kernel, bias=zero_bias,
        dense_mxts_mode=DenseMxtsMode.Linear)
    dense.set_inputs(flatten)
    dense.build_fwd_pass_vars()

    # Select the single output unit (task index 0), then update the
    # input layer's mxts.
    dense.set_scoring_mode(layers.ScoringMode.OneAndZeros)
    dense.set_active()
    dense.update_task_index(0)
    input_layer.update_mxts()
def setUp(self):
    """Build two inputs -> Concat(axis=1) -> Flatten -> Dense(1) fixture.

    Each input has ``batch_shape=(None, 1, 1, 1)``. The concatenated and
    flattened result feeds a Dense layer with kernel [[1], [2]] and bias [1].
    Both input layers have their mxts-updated state reset and are then
    updated after the dense layer is activated in ``OneAndZeros`` scoring
    mode.

    ``self.inp1`` and ``self.inp2`` are separate integer arrays of shape
    (2, 1, 1, 1) holding the values 1 and 2.
    """
    inputs = [layers.Input(batch_shape=(None, 1, 1, 1)) for _ in range(2)]
    self.input_layer1, self.input_layer2 = inputs

    concat = self.concat_layer = layers.Concat(axis=1)
    concat.set_inputs(list(inputs))

    flatten = self.flatten_layer = layers.Flatten()
    flatten.set_inputs(concat)

    dense = self.dense_layer = layers.Dense(
        kernel=np.array([[1, 2]]).T,
        bias=[1],
        dense_mxts_mode=DenseMxtsMode.Linear)
    dense.set_inputs(flatten)
    dense.build_fwd_pass_vars()

    for input_layer in inputs:
        input_layer.reset_mxts_updated()

    dense.set_scoring_mode(layers.ScoringMode.OneAndZeros)
    dense.set_active()

    for input_layer in inputs:
        input_layer.update_mxts()

    # Two distinct (not aliased) arrays with identical contents.
    self.inp1, self.inp2 = (np.arange(1, 3).reshape((2, 1, 1, 1))
                            for _ in range(2))