Example #1
import logging
from typing import Optional, Sequence, Union

import haiku as hk

# Note: `Activation` and `utils` come from the surrounding package;
# `utils.get_activation` resolves an activation spec to a callable.
def make_flexible_recurrent_net(core_type: str,
                                net_type: str,
                                output_dims: int,
                                num_units: Union[Sequence[int], int],
                                num_layers: Optional[int],
                                activation: Activation,
                                activate_final: bool = False,
                                name: Optional[str] = None,
                                **unused_kwargs):
    """Commonly used for creating a flexible recurrences."""
    if net_type != "mlp":
        raise ValueError("We do not support convolutional recurrent nets atm.")
    if unused_kwargs:
        logging.warning("Unused kwargs of `make_flexible_recurrent_net`: %s",
                        str(unused_kwargs))

    if isinstance(num_units, (list, tuple)):
        num_units = list(num_units) + [output_dims]
        num_layers = len(num_units)
    else:
        assert num_layers is not None
        num_units = [num_units] * (num_layers - 1) + [output_dims]
    name = name or core_type.upper()

    activation = utils.get_activation(activation)
    core_list = []
    for i, n in enumerate(num_units):
        if core_type.lower() == "vanilla":
            core_list.append(hk.VanillaRNN(hidden_size=n, name=f"{name}_{i}"))
        elif core_type.lower() == "lstm":
            core_list.append(hk.LSTM(hidden_size=n, name=f"{name}_{i}"))
        elif core_type.lower() == "gru":
            core_list.append(hk.GRU(hidden_size=n, name=f"{name}_{i}"))
        else:
            raise ValueError(f"Unrecognized core_type={core_type}.")
        if i != num_layers - 1:
            core_list.append(activation)
    if activate_final:
        core_list.append(activation)

    return hk.DeepRNN(core_list, name="RNN")
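
A minimal usage sketch (not from the original source): the returned `hk.DeepRNN` is itself an `hk.RNNCore`, so it can be unrolled over a time-major sequence with `hk.dynamic_unroll` inside `hk.transform`. The argument values are illustrative, and passing the string "relu" assumes `utils.get_activation` accepts activation names.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(sequence):
    # Two-layer LSTM stack; the final layer has output_dims=4 units.
    core = make_flexible_recurrent_net(
        core_type="lstm", net_type="mlp", output_dims=4,
        num_units=32, num_layers=2, activation="relu")
    initial_state = core.initial_state(batch_size=sequence.shape[1])
    outputs, _ = hk.dynamic_unroll(core, sequence, initial_state)
    return outputs

model = hk.transform(forward)
seq = jnp.zeros((10, 2, 8))  # [time, batch, features]
params = model.init(jax.random.PRNGKey(0), seq)
outputs = model.apply(params, None, seq)  # outputs.shape == (10, 2, 4)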
Example #2
        shape=(BATCH_SIZE, 128)),
    ModuleDescriptor(
        name="Conv1DLSTM",
        create=lambda: hk.Conv1DLSTM([2], 3, 3),
        shape=(BATCH_SIZE, 2, 2)),
    ModuleDescriptor(
        name="Conv2DLSTM",
        create=lambda: hk.Conv2DLSTM([2, 2], 3, 3),
        shape=(BATCH_SIZE, 2, 2, 2)),
    ModuleDescriptor(
        name="Conv3DLSTM",
        create=lambda: hk.Conv3DLSTM([2, 2, 2], 3, 3),
        shape=(BATCH_SIZE, 2, 2, 2, 2)),
    ModuleDescriptor(
        name="VanillaRNN",
        create=lambda: hk.VanillaRNN(8),
        shape=(BATCH_SIZE, 128)),
)
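
# For reference: each `ModuleDescriptor` above is unpacked in
# `unroll_descriptors` below as (name, create, shape, dtype). A sketch of
# its likely definition (an assumption, following Haiku test-suite style):
#
#   class ModuleDescriptor(NamedTuple):
#     name: str
#     create: Callable[[], Any]
#     shape: Tuple[int, ...]
#     dtype: Any = jnp.float32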


def recurrent_factory(
    create_core: Callable[[], hk.RNNCore],
    unroller,
) -> Callable[[], Recurrent]:
  return lambda: Recurrent(create_core(), unroller)
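
# Hypothetical usage: the closure defers construction, so each caller
# builds a fresh module instance.
#   make_net = recurrent_factory(lambda: hk.LSTM(8), hk.dynamic_unroll)
#   net = make_net()  # == Recurrent(hk.LSTM(8), hk.dynamic_unroll)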


def unroll_descriptors(descriptors, unroller):
  """Returns `Recurrent` wrapped descriptors with the given unroller applied."""
  out = []
  for name, create, shape, dtype in descriptors: