def func(S, is_training):
    flatten = hk.Flatten()
    batch_norm_m = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm_v = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm_m = partial(batch_norm_m, is_training=is_training)
    batch_norm_v = partial(batch_norm_v, is_training=is_training)
    mu = hk.Sequential((
        hk.Linear(7), batch_norm_m, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(onp.prod(self.env_boxspace.action_space.shape)),
        hk.Reshape(self.env_boxspace.action_space.shape),
    ))
    logvar = hk.Sequential((
        hk.Linear(7), batch_norm_v, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(onp.prod(self.env_boxspace.action_space.shape)),
        hk.Reshape(self.env_boxspace.action_space.shape),
    ))
    return {'mu': mu(flatten(S)), 'logvar': logvar(flatten(S))}

def func(S, is_training):
    env = self.env_discrete
    output_shape = (env.action_space.n, *env.observation_space.shape)
    flatten = hk.Flatten()
    batch_norm_m = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm_v = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm_m = partial(batch_norm_m, is_training=is_training)
    batch_norm_v = partial(batch_norm_v, is_training=is_training)
    mu = hk.Sequential((
        hk.Linear(7), batch_norm_m, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(onp.prod(output_shape)),
        hk.Reshape(output_shape),
    ))
    logvar = hk.Sequential((
        hk.Linear(7), batch_norm_v, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(onp.prod(output_shape)),
        hk.Reshape(output_shape),
    ))
    X = flatten(S)
    return {'mu': mu(X), 'logvar': logvar(X)}

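# --- Usage sketch (added for illustration; not part of the snippets above) ---
# Because hk.BatchNorm keeps its moving averages in Haiku state, functions like
# the ones above are wrapped with hk.transform_with_state rather than
# hk.transform. `toy_func`, its layer sizes, and the dummy input shape are
# assumptions chosen only to make this example self-contained and runnable.
import jax
import jax.numpy as jnp
import haiku as hk


def toy_func(S, is_training):
    batch_norm = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    h = hk.Linear(7)(hk.Flatten()(S))
    h = jnp.tanh(batch_norm(h, is_training=is_training))
    return hk.Linear(1)(h)


model = hk.transform_with_state(toy_func)
rng = jax.random.PRNGKey(0)
S = jnp.zeros((4, 5))                                  # dummy batch of observations
params, state = model.init(rng, S, is_training=True)   # state holds the BatchNorm moving averages
out, state = model.apply(params, state, rng, S, is_training=True)
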
def __init__(self, in_planes, planes, stride=1):
    super().__init__()
    self.conv1 = hk.Conv2D(output_channels=planes, kernel_shape=3, stride=stride,
                           padding='SAME', with_bias=False, data_format='NCHW')
    self.bn1 = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9,
                            data_format='NC...')
    self.conv2 = hk.Conv2D(output_channels=planes, kernel_shape=3, stride=1,
                           padding='SAME', with_bias=False, data_format='NCHW')
    self.bn2 = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9,
                            data_format='NC...')
    self.in_planes = in_planes
    self.planes = planes
    self.stride = stride

def __init__(self, vocab_size, lstm_dim, dropout_rate, is_training=True):
    super().__init__()
    self.is_training = is_training
    self.embed = hk.Embed(vocab_size, lstm_dim)
    self.conv1 = hk.Conv1D(lstm_dim, 3, padding='SAME')
    self.conv2 = hk.Conv1D(lstm_dim, 3, padding='SAME')
    self.conv3 = hk.Conv1D(lstm_dim, 3, padding='SAME')
    self.bn1 = hk.BatchNorm(True, True, 0.9)
    self.bn2 = hk.BatchNorm(True, True, 0.9)
    self.bn3 = hk.BatchNorm(True, True, 0.9)
    self.lstm_fwd = hk.LSTM(lstm_dim)
    self.lstm_bwd = hk.ResetCore(hk.LSTM(lstm_dim))
    self.dropout_rate = dropout_rate

def __call__(self, x, key, dropout_rates, is_training):
    out = hk.Flatten()(x)
    if is_training and (dropout_rates is not None):
        keys = jax.random.split(key, self.nlayers)
    for l in range(self.nlayers):
        out = hk.Linear(self.nhid, with_bias=self.with_bias)(out)
        if self.batch_norm:
            out = hk.BatchNorm(create_scale=False, create_offset=False,
                               decay_rate=0.9)(out, is_training)
        if is_training and (dropout_rates is not None):
            out = dropout(keys[l], dropout_rates[l], out)
        if self.activation == 'relu':
            out = jax.nn.relu(out)
        elif self.activation == 'sigmoid':
            out = jax.nn.sigmoid(out)
        elif self.activation == 'tanh':
            out = jnp.tanh(out)
        elif self.activation == 'linear':
            out = out
    out = hk.Linear(10, with_bias=self.with_bias)(out)
    return out

def __init__(self):
    super().__init__()
    bn_config = {
        'create_scale': True,
        'create_offset': True,
        'decay_rate': 0.999,
    }
    # Definition of the modules.
    self.conv_block = hk.Sequential([
        hk.Conv2D(1, (3, 3), stride=3, rate=1), jax.nn.relu,
        hk.Conv2D(1, (3, 3), stride=3, rate=1), jax.nn.relu,
    ])
    self.conv_res_block = hk.Sequential([
        hk.Conv2D(1, (1, 1), stride=1, rate=1), jax.nn.relu,
        hk.Conv2D(1, (1, 1), stride=1, rate=1), jax.nn.relu,
    ])
    self.reshape_mod = hk.Flatten()
    self.lin_res_block = [
        (hk.Linear(16), hk.BatchNorm(name='lin_batchnorm_0', **bn_config)),
    ]
    self.final_linear = hk.Linear(10)

def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
             is_training: bool) -> jnp.ndarray:
    """Update node features.

    Parameters
    ----------
    node_feats : ndarray of shape (batch_size, N, in_feats)
        Batch input node features. N is the total number of nodes in the
        batch of graphs.
    adj : ndarray of shape (batch_size, N, N)
        Batch adjacency matrix.
    is_training : bool
        Whether the model is training or not.

    Returns
    -------
    new_node_feats : ndarray of shape (batch_size, N, out_feats)
        Batch new node features.
    """
    dropout = self.dropout if is_training is True else 0.0
    # for batch data
    new_node_feats = jax.vmap(self._update_nodes)(node_feats, adj)
    if self.bias:
        new_node_feats += self.b
    new_node_feats = self.activation(new_node_feats)
    if dropout != 0.0:
        new_node_feats = hk.dropout(hk.next_rng_key(), dropout, new_node_feats)
    if self.batch_norm:
        new_node_feats = hk.BatchNorm(True, True, 0.9)(new_node_feats, is_training)
    return new_node_feats

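# --- Illustration (added; not the library's implementation) ---
# A hedged sketch of what the per-graph `_update_nodes` mapped by jax.vmap above
# commonly computes in a GCN layer: aggregate neighbour features through the
# adjacency matrix, then project with a weight matrix. `node_feats` and `adj`
# follow the docstring; `w` and the exact formulation (e.g. whether the
# adjacency is normalized) are assumptions.
def _update_nodes_sketch(node_feats, adj, w):
    # node_feats: (N, in_feats), adj: (N, N), w: (in_feats, out_feats)
    return adj @ node_feats @ w  # -> (N, out_feats)
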
def __call__(self, inputs, is_training):
    dropout_rate = self._dropout_rate if is_training else 0.0
    h = hk.dropout(hk.next_rng_key(), dropout_rate, inputs)
    h = hk.Linear(self._vocab_size, with_bias=False)(h)
    return hk.BatchNorm(create_scale=False, create_offset=False,
                        decay_rate=0.9)(h, is_training)

def __call__(self, x, mask_props=None, is_training=True):
    out = hk.Flatten()(x)
    for l in range(self.nlayers):
        out = hk.Linear(self.nhid, with_bias=self.with_bias)(out)
        if self.batch_norm:
            out = hk.BatchNorm(create_scale=False, create_offset=False,
                               decay_rate=0.9)(out, is_training)
        if mask_props is not None:
            num_units = jnp.floor(mask_props[l] * out.shape[1]).astype(jnp.int32)
            mask = jnp.arange(out.shape[1]) < num_units
            out = jnp.where(mask, out, jnp.zeros(out.shape))
        if self.activation == 'relu':
            out = jax.nn.relu(out)
        elif self.activation == 'sigmoid':
            out = jax.nn.sigmoid(out)
        elif self.activation == 'tanh':
            out = jnp.tanh(out)
        elif self.activation == 'linear':
            out = out
    out = hk.Linear(10, with_bias=self.with_bias)(out)
    return out

def func_boxspace(S, is_training):
    # NB: both heads below share the same BatchNorm module (and its moving statistics).
    batch_norm = hk.BatchNorm(False, False, 0.99)
    mu = hk.Sequential((
        hk.Flatten(), hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8), jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    logvar = hk.Sequential((
        hk.Flatten(), hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8), jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    return {'mu': mu(S), 'logvar': logvar(S)}

def single_mlp(inner_name: str):
    """Creates a single MLP performing the update."""
    mlp = hk.nets.MLP(output_sizes=output_sizes, name=inner_name,
                      activation=activation)
    mlp = jraph.concatenated_args(mlp)
    if normalization_type == 'layer_norm':
        norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True,
                            name=name + '_layer_norm')
    elif normalization_type == 'batch_norm':
        batch_norm = hk.BatchNorm(
            create_scale=True,
            create_offset=True,
            decay_rate=0.9,
            name=f'{inner_name}_batch_norm',
            # hk.running_init() is True while parameters are being initialized,
            # so cross-replica syncing over the named device axis 'i' is only
            # active during apply.
            cross_replica_axis=None if hk.running_init() else 'i',
        )
        norm = lambda x: batch_norm(x, is_training)
    elif normalization_type == 'none':
        return mlp
    else:
        raise ValueError(f'Unknown normalization type {normalization_type}')
    return jraph.concatenated_args(hk.Sequential([mlp, norm]))

def fn(x):
    if dropout:
        x = hk.dropout(hk.next_rng_key(), 0.5, x)
    if batchnorm:
        x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)(
            x, is_training=True)
    return x

def func_discrete_type1(S, A, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential((
        hk.Flatten(), hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8), jnp.tanh,
        hk.Linear(discrete.n), jax.nn.softmax,
    ))
    X = jax.vmap(jnp.kron)(S, A)
    return seq(X)

def func_discrete_type2(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential((
        hk.Flatten(), hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8), jnp.tanh,
        hk.Linear(discrete.n * discrete.n),
        hk.Reshape((discrete.n, discrete.n)),
        jax.nn.softmax,
    ))
    return seq(S)

def func(S, is_training):
    flatten = hk.Flatten()
    batch_norm = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm = partial(batch_norm, is_training=is_training)
    seq = hk.Sequential((
        hk.Linear(7), batch_norm, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(1), jnp.ravel,
    ))
    return seq(flatten(S))

def func(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    logits = hk.Sequential((
        hk.Flatten(), hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8), jnp.tanh,
        hk.Linear(num_bins),
    ))
    return {'logits': logits(S)}

def func(S, is_training):
    flatten = hk.Flatten()
    batch_norm = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm = partial(batch_norm, is_training=is_training)
    seq = hk.Sequential((
        hk.Linear(7), batch_norm, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(self.env_discrete.action_space.n * 51),
        hk.Reshape((self.env_discrete.action_space.n, 51)),
    ))
    return seq(flatten(S))

def __init__(self, c_in, c_out):
    super().__init__()
    self.conv = hk.Conv2D(output_channels=c_out, kernel_shape=3, stride=1,
                          with_bias=False, padding='SAME', data_format='NCHW')
    self.bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9,
                           data_format='NC...')

def _call_layers(
        cfg,
        inp: Tensor,
        batch_norm: bool = True,
        include_top: bool = True,
        initial_weights: Optional[hk.Params] = None,
        output_feature_maps: bool = False,
        is_training: bool = True) -> Union[Tensor, List[Tensor]]:
    x = inp
    partial_results = []

    # Ignore max pooling if we do not append the classifier
    if not include_top:
        cfg = cfg[:-1]

    i = 0
    base_name = 'vgg16/conv2_d'
    for v in cfg:
        if v == 'M':
            partial_results.append(x)
            x = hk.MaxPool(window_shape=2, strides=2, padding="VALID")(x)
        else:
            if i == 0:
                param_name = base_name
            else:
                param_name = base_name + f'_{i}'
            i += 1
            w_init = (None if initial_weights is None else hk.initializers.Constant(
                constant=initial_weights[param_name]['w']))
            b_init = (None if initial_weights is None else hk.initializers.Constant(
                constant=initial_weights[param_name]['b']))
            x = hk.Conv2D(v, kernel_shape=3, stride=1, padding='SAME',
                          w_init=w_init, b_init=b_init)(x)
            if batch_norm:
                # hk.BatchNorm requires the training flag at call time.
                x = hk.BatchNorm(True, True, decay_rate=0.999)(x, is_training)
            x = jax.nn.relu(x)

    partial_results.append(x)
    if not output_feature_maps:
        return partial_results[-1]
    else:
        return partial_results

def func_type2(S, is_training):
    seq = hk.Sequential((
        hk.Flatten(), hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(hk.BatchNorm(False, False, 0.99), is_training=is_training),
        hk.Linear(8), jax.nn.relu,
        hk.Linear(env_discrete.action_space.n),
    ))
    return seq(S)

def __init__(self,
             output_channels: Sequence[int],
             kernel_shapes: Union[int, Sequence[int]] = 3,
             strides: Union[int, Sequence[int]] = 1,
             padding: Union[str, Sequence[str]] = "SAME",
             data_format: str = "NHWC",
             with_batch_norm: bool = False,
             activate_final: bool = False,
             activation: Activation = "leaky_relu",
             name: Optional[str] = None):
    super().__init__(name=name)
    self.output_channels = tuple(output_channels)
    self.num_layers = len(self.output_channels)
    self.kernel_shapes = utils.bcast_if(kernel_shapes, int, self.num_layers)
    self.strides = utils.bcast_if(strides, int, self.num_layers)
    self.padding = utils.bcast_if(padding, str, self.num_layers)
    self.data_format = data_format
    self.with_batch_norm = with_batch_norm
    self.activate_final = activate_final
    self.activation = utils.get_activation(activation)

    if len(self.kernel_shapes) != self.num_layers:
        raise ValueError(f"Kernel shapes is of size {len(self.kernel_shapes)}, "
                         f"while output_channels is of size {self.num_layers}.")
    if len(self.strides) != self.num_layers:
        raise ValueError(f"Strides is of size {len(self.strides)}, "
                         f"while output_channels is of size {self.num_layers}.")
    if len(self.padding) != self.num_layers:
        raise ValueError(f"Padding is of size {len(self.padding)}, "
                         f"while output_channels is of size {self.num_layers}.")

    self.conv_modules = []
    self.bn_modules = []
    for i in range(self.num_layers):
        self.conv_modules.append(
            hk.Conv2D(output_channels=self.output_channels[i],
                      kernel_shape=self.kernel_shapes[i],
                      stride=self.strides[i],
                      padding=self.padding[i],
                      data_format=data_format,
                      name=f"conv_2d_{i}"))
        if with_batch_norm:
            self.bn_modules.append(
                hk.BatchNorm(create_offset=True,
                             create_scale=False,
                             decay_rate=0.999,
                             name=f"batch_norm_{i}"))
        else:
            self.bn_modules.append(None)

def func_boxspace_type1(S, A, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential((
        hk.Flatten(), hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8), jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    X = jax.vmap(jnp.kron)(S, A)
    return seq(X)

def __call__(
        self,
        graph: tp.Union[jnp.ndarray, JAXSparse],
        node_features: jnp.ndarray,
        is_training: tp.Optional[bool] = None,
):
    x = node_features
    for f in self.hidden_filters:
        x = GCN(f)(graph, x)
        x = hk.BatchNorm(True, True, 0.9)(x, is_training=is_training)
        x = jax.nn.relu(x)
        x = dropout(x, self.dropout_rate, is_training=is_training)
    logits = GCN(self.num_classes)(graph, x)
    return logits

def __call__(self, inputs, is_training):
    dropout_rate = self._dropout_rate if is_training else 0.0
    h = jax.nn.softplus(hk.Linear(self._hidden)(inputs))
    h = jax.nn.softplus(hk.Linear(self._hidden)(h))
    h = hk.dropout(hk.next_rng_key(), dropout_rate, h)
    h = hk.Linear(self._num_topics)(h)
    # NB: here we set `create_scale=False` and `create_offset=False` to reduce
    # the number of learnable parameters.
    log_concentration = hk.BatchNorm(create_scale=False, create_offset=False,
                                     decay_rate=0.9)(h, is_training)
    return jnp.exp(log_concentration)

def func(S, A, is_training):
    flatten = hk.Flatten()
    batch_norm = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm = partial(batch_norm, is_training=is_training)
    seq = hk.Sequential((
        hk.Linear(7), batch_norm, jnp.tanh,
        hk.Linear(51),
    ))
    print(S.shape, A.shape)  # debug output of the raw input shapes
    X = jnp.concatenate((flatten(S), flatten(A)), axis=-1)
    return {'logits': seq(X)}

def func_type1(S, A, is_training):
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(hk.BatchNorm(False, False, 0.99), is_training=is_training),
        hk.Linear(8), jax.nn.relu,
        hk.Linear(1), jnp.ravel,
    ))
    S = hk.Flatten()(S)
    A = hk.Flatten()(A)
    X = jnp.concatenate((S, A), axis=-1)
    return seq(X)

def __init__(self, is_training=True):
    super().__init__()
    self.is_training = is_training
    self.encoder = TokenEncoder(FLAGS.vocab_size, FLAGS.acoustic_encoder_dim,
                                0.5, is_training)
    self.decoder = hk.deep_rnn_with_skip_connections([
        hk.LSTM(FLAGS.acoustic_decoder_dim),
        hk.LSTM(FLAGS.acoustic_decoder_dim),
    ])
    self.projection = hk.Linear(FLAGS.mel_dim)

    # prenet
    self.prenet_fc1 = hk.Linear(256, with_bias=True)
    self.prenet_fc2 = hk.Linear(256, with_bias=True)

    # postnet
    self.postnet_convs = ([hk.Conv1D(FLAGS.postnet_dim, 5) for _ in range(4)] +
                          [hk.Conv1D(FLAGS.mel_dim, 5)])
    self.postnet_bns = [hk.BatchNorm(True, True, 0.9) for _ in range(4)] + [None]

def mlp(
    x: tp.Union[jnp.ndarray, JAXSparse],
    is_training: bool,
    ids=None,
    num_classes: tp.Optional[int] = None,
    hidden_filters: tp.Union[int, tp.Iterable[int]] = 64,
    dropout_rate: float = 0.8,
    use_batch_norm: bool = False,
    use_layer_norm: bool = False,
    use_renormalize: bool = False,
    use_gathered_batch_norm: bool = False,
    activation: Activation = jax.nn.relu,
    final_activation: Activation = lambda x: x,
    input_dropout_rate: tp.Optional[float] = None,
    batch_norm_decay: float = 0.9,
    renorm_scale: bool = True,
    w_init=None,
):
    # At most one normalization variant may be enabled.
    assert sum((use_batch_norm, use_layer_norm, use_renormalize,
                use_gathered_batch_norm)) <= 1
    if input_dropout_rate is None:
        input_dropout_rate = dropout_rate
    if isinstance(hidden_filters, int):
        hidden_filters = (hidden_filters,)

    x = dropout(x, input_dropout_rate, is_training=is_training)
    for filters in hidden_filters:
        x = Linear(filters, w_init=w_init)(x)
        if use_batch_norm:
            x = hk.BatchNorm(renorm_scale, True, batch_norm_decay)(x, is_training)
        if use_layer_norm:
            x = hk.LayerNorm(0, renorm_scale, True)(x)
        if use_renormalize:
            x = Renormalize(renorm_scale, True)(x)
        if use_gathered_batch_norm:
            assert ids is not None
            x = GatheredBatchNorm(True, True, batch_norm_decay)(
                x, is_training=is_training, ids=ids)
        x = activation(x)
        x = dropout(x, dropout_rate, is_training=is_training)
    if num_classes is not None:
        x = hk.Linear(num_classes, w_init=w_init)(x)
    return final_activation(x)

def func(S, A, is_training):
    flatten = hk.Flatten()
    batch_norm = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)
    batch_norm = partial(batch_norm, is_training=is_training)
    seq = hk.Sequential((
        hk.Linear(7), batch_norm, jnp.tanh,
        hk.Linear(3), jnp.tanh,
        hk.Linear(1), jnp.ravel,
    ))
    X = jnp.concatenate((flatten(S), flatten(A)), axis=-1)
    return seq(X)

def __init__(self, num_blocks=[5, 5, 5], num_classes=10):
    super().__init__()
    self.conv1 = hk.Conv2D(
        output_channels=16,
        # output_channels=64,
        kernel_shape=3,
        stride=1,
        with_bias=False,
        padding='SAME',
        data_format='NCHW')
    self.bn1 = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9,
                            data_format='NC...')
    self.layer1 = MultiBlock(16, 16, [1] + [1] * (num_blocks[0] - 1))
    self.layer2 = MultiBlock(16, 32, [2] + [1] * (num_blocks[1] - 1))
    self.layer3 = MultiBlock(32, 64, [2] + [1] * (num_blocks[2] - 1))
    self.linear = hk.Linear(num_classes)

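# --- Usage sketch (added for illustration) ---
# A hedged sketch of how a class-based Haiku module like the one whose __init__
# is shown above is typically driven. The class name `ResNet`, a __call__
# signature taking (x, is_training), and the NCHW input shape are assumptions;
# the point being illustrated is that the module must be constructed inside a
# transformed function and, because of hk.BatchNorm, wrapped with
# hk.transform_with_state.
import jax
import jax.numpy as jnp
import haiku as hk


def forward(x, is_training):
    net = ResNet(num_blocks=[5, 5, 5], num_classes=10)  # hypothetical class name
    return net(x, is_training)


model = hk.transform_with_state(forward)
x = jnp.zeros((2, 3, 32, 32))  # NCHW batch (assumed shape)
params, state = model.init(jax.random.PRNGKey(0), x, is_training=True)
logits, state = model.apply(params, state, jax.random.PRNGKey(1), x, is_training=True)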