def testDropoutWrapperSerialization(self):
    """DropoutWrapper configs round-trip through get_config/from_config."""
    wrapper_cls = cell_wrappers.DropoutWrapper
    cell = layers.GRUCell(10)

    # Default construction: the config must survive a round trip unchanged.
    wrapper = wrapper_cls(cell)
    config = wrapper.get_config()
    restored = wrapper_cls.from_config(config)
    self.assertDictEqual(config, restored.get_config())
    self.assertIsInstance(restored, wrapper_cls)

    # A lambda state filter must be serialized and keep returning True.
    wrapper = wrapper_cls(cell, dropout_state_filter_visitor=lambda s: True)
    config = wrapper.get_config()
    restored = wrapper_cls.from_config(config)
    self.assertTrue(restored._dropout_state_filter(None))

    # A named-function state filter must also round-trip (returns False).
    def dropout_state_filter_visitor(unused_state):
        return False

    wrapper = wrapper_cls(
        cell, dropout_state_filter_visitor=dropout_state_filter_visitor
    )
    config = wrapper.get_config()
    restored = wrapper_cls.from_config(config)
    self.assertFalse(restored._dropout_state_filter(None))
def distributed_cell(inputs):
    """Apply a GRU cell independently at each timestep.

    Wraps a ``kl.GRUCell`` in ``TimeDistributed`` so the cell is evaluated
    per timestep with the hidden state taken from ``inputs[1]`` rather than
    propagated across timesteps.

    Args:
        inputs: list of exactly two Keras tensors
            ``[cell_inputs, cell_states]``. No shape validation is done;
            all leading dims of the two tensors are assumed equal.

    Returns:
        Keras tensor with the cell output for every timestep.
    """
    assert len(inputs) == 2
    shapes = [elem._keras_shape for elem in inputs]
    # no shape validation, assuming all dims of inputs[0] and inputs[1] are equal
    input_dim, units, ndims = shapes[0][-1], shapes[1][-1], len(shapes[0])
    if ndims > 3:
        # Reorder axes so TimeDistributed sees a single leading "time" axis.
        dims_order = (1,) + tuple(range(2, ndims)) + (2,)
        # BUG FIX: the second element previously permuted inputs[0] again,
        # silently dropping the state tensor; it must permute inputs[1].
        inputs = [
            kl.Permute(dims_order)(inputs[0]),
            kl.Permute(dims_order)(inputs[1]),
        ]
    first_shape, second_shape = shapes[0][2:], shapes[1][2:]
    cell = kl.GRUCell(units, input_shape=first_shape, implementation=0)
    if not cell.built:
        cell.build(first_shape)
    # Concatenate input and state along features so a single Lambda can
    # split them back apart inside TimeDistributed.
    concatenated_inputs = kl.Concatenate()(inputs)

    def timestep_func(x):
        # Split the concatenated tensor into cell input and state; the cell
        # expects states with an extra leading axis.
        cell_inputs = x[..., :input_dim]
        cell_states = x[..., None, input_dim:]
        cell_output = cell.call(cell_inputs, cell_states)
        return cell_output[0]

    func = kl.TimeDistributed(
        kl.Lambda(timestep_func, output_shape=second_shape))
    answer = func(concatenated_inputs)
    if ndims > 3:
        # Undo the initial axis permutation.
        reverse_dims_order = (1, ndims - 1) + tuple(range(2, ndims - 1))
        answer = kl.Permute(reverse_dims_order)(answer)
    return answer
def build_model_gru_cell():
    """Build a stacked-GRU-cell model ending in one dense ReLU layer."""
    inputs = keras.Input(shape=(None, ModelConfig.L_FRAME // 2 + 1))

    # Stack NUM_LAYERS GRU cells into a single recurrent layer.
    cells = [
        layers.GRUCell(ModelConfig.HID_SIZE)
        for _ in range(ModelConfig.NUM_LAYERS)
    ]
    rnn_output = layers.RNN(layers.StackedRNNCells(cells))(inputs)

    feature_dim = np.shape(inputs)[2]
    src1 = layers.Dense(feature_dim, activation="relu")(rnn_output)
    # src2_pre = layers.Dense(input_size, activation="relu")(output_rnn)
    # time-freq masking layer
    # src1 = src1_pre / (src1_pre + src2_pre + np.finfo(float).eps) * inputs
    # src2 = src2_pre / (src1_pre + src2_pre + np.finfo(float).eps) * inputs

    model = keras.Model(inputs=inputs, outputs=src1, name="GRUCell_model")
    model.summary()
    return model
def __init__(self, units, **kwargs):
    """Create the inner GRU cell and hand it to the parent wrapper.

    Args:
        units: hidden size forwarded to ``L.GRUCell``.
        **kwargs: forwarded to the parent constructor; must contain
            ``name``, which is used to derive the inner cell's name.
    """
    inner_name = kwargs['name'] + '_gru'
    self.grucell = L.GRUCell(units, name=inner_name)  # Internal cell
    super().__init__(self.grucell, **kwargs)