def make_flexible_recurrent_net(
    core_type: str,
    net_type: str,
    output_dims: int,
    num_units: Union[Sequence[int], int],
    num_layers: Optional[int],
    activation: Activation,
    activate_final: bool = False,
    name: Optional[str] = None,
    **unused_kwargs):
  """Builds a stack of recurrent cores wrapped in an `hk.DeepRNN`.

  Each layer is one Haiku recurrent core (vanilla RNN, LSTM or GRU). The
  activation is inserted between consecutive cores. It is also inserted after
  the last core when `activate_final` is set. The last core always has
  `output_dims` units.

  Args:
    core_type: One of "vanilla", "lstm" or "gru" (case-insensitive).
    net_type: Must be "mlp". Convolutional recurrent nets are not supported.
    output_dims: Hidden size of the final core.
    num_units: Either a sequence of hidden sizes, one per non-final layer, or
      a single int shared by all non-final layers.
    num_layers: Total number of cores, including the final one. Required when
      `num_units` is an int. Ignored (and derived) when `num_units` is a
      list/tuple.
    activation: Activation spec resolved through `utils.get_activation`.
      NOTE(review): presumably a name or a callable; confirm against `utils`.
    activate_final: Whether to also apply `activation` after the last core.
    name: Prefix for the per-layer core names (`f"{name}_{i}"`). Defaults to
      `core_type.upper()`. The enclosing `hk.DeepRNN` is always named "RNN".
    **unused_kwargs: Accepted for config compatibility. A warning is logged if
      any are given.

  Returns:
    An `hk.DeepRNN` over the interleaved cores and activations.

  Raises:
    ValueError: If `net_type` is not "mlp", if `core_type` is unrecognized,
      or if `num_units` is an int and `num_layers` is missing or < 1.
  """
  if net_type != "mlp":
    raise ValueError("We do not support convolutional recurrent nets atm.")

  # Resolve the core class once, up front, instead of re-checking per layer.
  core_classes = {
      "vanilla": hk.VanillaRNN,
      "lstm": hk.LSTM,
      "gru": hk.GRU,
  }
  core_cls = core_classes.get(core_type.lower())
  if core_cls is None:
    raise ValueError(f"Unrecognized core_type={core_type}.")

  if unused_kwargs:
    logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s",
                    str(unused_kwargs))

  if isinstance(num_units, (list, tuple)):
    # Explicit per-layer widths: the layer count follows from the list.
    layer_sizes = list(num_units) + [output_dims]
  else:
    # Raise instead of `assert` so validation survives `python -O`.
    if num_layers is None:
      raise ValueError(
          "`num_layers` must be provided when `num_units` is a single int.")
    # Without this check, num_layers <= 0 silently built a single layer and
    # appended a trailing activation regardless of `activate_final`.
    if num_layers < 1:
      raise ValueError(f"`num_layers` must be >= 1, got {num_layers}.")
    layer_sizes = [num_units] * (num_layers - 1) + [output_dims]

  name = name or core_type.upper()
  activation = utils.get_activation(activation)

  last_index = len(layer_sizes) - 1
  layers = []
  for i, size in enumerate(layer_sizes):
    layers.append(core_cls(hidden_size=size, name=f"{name}_{i}"))
    # Put the activation between cores. After the final core, add it only
    # when `activate_final` is requested.
    if i != last_index or activate_final:
      layers.append(activation)
  return hk.DeepRNN(layers, name="RNN")
shape=(BATCH_SIZE, 128)), ModuleDescriptor( name="Conv1DLSTM", create=lambda: hk.Conv1DLSTM([2], 3, 3), shape=(BATCH_SIZE, 2, 2)), ModuleDescriptor( name="Conv2DLSTM", create=lambda: hk.Conv2DLSTM([2, 2], 3, 3), shape=(BATCH_SIZE, 2, 2, 2)), ModuleDescriptor( name="Conv3DLSTM", create=lambda: hk.Conv3DLSTM([2, 2, 2], 3, 3), shape=(BATCH_SIZE, 2, 2, 2, 2)), ModuleDescriptor( name="VanillaRNN", create=lambda: hk.VanillaRNN(8), shape=(BATCH_SIZE, 128)), ) def recurrent_factory( create_core: Callable[[], hk.RNNCore], unroller, ) -> Callable[[], Recurrent]: return lambda: Recurrent(create_core(), unroller) def unroll_descriptors(descriptors, unroller): """Returns `Recurrent` wrapped descriptors with the given unroller applied.""" out = [] for name, create, shape, dtype in descriptors: