def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Compute embeddings for each node in the graphs.

  Args:
    graphs: a set of graphs batched into a single graph. The nodes and edges
      are represented as feature tensors.

  Returns:
    graphs: new graph with node embeddings updated (shape
      [n_nodes, embed_dim]).
  """
  nodes = hk.Linear(self._embed_dim)(graphs.nodes)
  edges = hk.Linear(self._embed_dim)(graphs.edges)
  nodes = hk.LayerNorm(
      axis=-1, create_scale=True, create_offset=True)(jax.nn.gelu(nodes))
  edges = hk.LayerNorm(
      axis=-1, create_scale=True, create_offset=True)(jax.nn.gelu(edges))
  graphs = graphs._replace(nodes=nodes, edges=edges)
  graphs = gn.SimpleGraphNet(
      num_layers=self._num_layers,
      msg_hidden_size_factor=self._msg_hidden_size_factor,
      layer_norm=self._use_layer_norm)(graphs)
  return graphs
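# Hypothetical usage sketch (not part of the snippet above): constructing a
# tiny batched jraph.GraphsTuple whose node/edge features could be fed to an
# embedder like the __call__ above. The feature sizes here are arbitrary.
import jax.numpy as jnp
import jraph

toy_graphs = jraph.GraphsTuple(
    nodes=jnp.ones([3, 4]),      # 3 nodes, 4 input features each
    edges=jnp.ones([2, 4]),      # 2 edges, 4 input features each
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]),       # a single graph in the batch
    n_edge=jnp.array([2]),
    globals=None)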
def __init__(self, C, attention_fn, name='SlotAttention'):
  super().__init__(name=name)
  he_init = hk.initializers.VarianceScaling(scale=2.0)
  self.num_slots = C['slots']
  self.slot_size = C['slot_size']
  self.attn_eps = C['attention_eps']
  self.mlp_hidden_size = C['mlp_hidden_size']
  # Learnable mu and sigma (no covariance, dim slot_dim) to initialize slots.
  # Learnable linear transforms for K, Q, V attention - use Glorot init here?
  self.k = hk.Linear(self.slot_size, w_init=he_init, with_bias=False)
  self.q = hk.Linear(self.slot_size, w_init=he_init, with_bias=False)
  self.v = hk.Linear(self.slot_size, w_init=he_init, with_bias=False)
  # Layer norms for slots before and after attention (GRU omitted).
  self.layer_norm_in = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
  self.layer_norm_1 = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
  self.layer_norm_2 = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
  # Slot update function is learned by a GRU with hidden_states = slot_dim.
  self.mlp = hk.Sequential([
      hk.Linear(self.mlp_hidden_size, w_init=he_init),
      jax.nn.relu,  # MLP + residual connection improves output
      hk.Linear(self.slot_size, w_init=he_init),
  ])
  self.attention_fn = attention_fn
def __init__(self,
             layer_sizes: Sequence[int],
             w_init: hk.initializers.Initializer = uniform_initializer,
             activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu,
             activate_final: bool = False,
             name: str = 'feedforward_mlp_torso'):
  """Construct the MLP.

  Args:
    layer_sizes: a sequence of ints specifying the size of each layer.
    w_init: initializer for Linear layers.
    activation: nonlinearity to use in the MLP, defaults to elu. Note! The
      default activation differs from the usual MLP default of ReLU for
      legacy reasons.
    activate_final: whether or not to use the activation function on the
      final layer of the neural network.
    name: a name for the module.
  """
  super().__init__(name=name)
  self._network = hk.Sequential([
      hk.Linear(layer_sizes[0], w_init=w_init),
      hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
      jax.lax.tanh,
      hk.nets.MLP(layer_sizes[1:],
                  w_init=w_init,
                  activation=activation,
                  activate_final=activate_final),
  ])
def large_critic(x):
  # Inspired by the critic networks used in RL Unplugged.
  x = hk.Sequential([
      hk.Linear(400, w_init=rlu_uniform_initializer),
      hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
      jax.lax.tanh,
  ])(x)
  x = hk.Linear(1024, w_init=rlu_uniform_initializer)(x)
  for _ in range(4):
    x = network_utils.ResidualLayerNormBlock(
        [1024, 1024],
        activation=jax.nn.relu,
        w_init=rlu_uniform_initializer,
    )(x)
  h = x
  # v = hk.Linear(1, w_init=rlu_uniform_initializer)(h)
  # v = hk.Linear(critic_output_dim)(h)
  all_vs = []
  for _ in range(critic_output_dim):
    head_v = hk.Linear(256, w_init=rlu_uniform_initializer)(h)
    head_v = jax.nn.relu(head_v)
    head_v = hk.Linear(1, w_init=rlu_uniform_initializer)(head_v)
    all_vs.append(head_v)
  v = jnp.concatenate(all_vs, axis=-1)
  return v, h
def base_network(x, layers):
  x = hk.nets.MLP(layers[:-1], activate_final=True)(x)
  x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
  x = hk.Linear(layers[-1])(x)
  x = jax.nn.relu(x)
  x = hk.Linear(self._num_actions)(x)
  return x
def __init__(self, C, position_enc_fn, name=None):
  super().__init__(name=name)
  he_init = hk.initializers.VarianceScaling(scale=2.0)
  channels = C['encoder_cnn_channels']
  kernels = C['encoder_cnn_kernels']
  strides = C['encoder_cnn_strides']
  hidden_size = channels[-1]
  self.cnn_layers = hk.Sequential([
      hk.Conv2D(channels[0], kernels[0], stride=strides[0], padding='SAME',
                w_init=he_init, with_bias=True),
      jax.nn.relu,
      hk.Conv2D(channels[1], kernels[1], stride=strides[1], padding='SAME',
                w_init=he_init, with_bias=True),
      jax.nn.relu,
      hk.Conv2D(channels[2], kernels[2], stride=strides[2], padding='SAME',
                w_init=he_init, with_bias=True),
      jax.nn.relu,
      hk.Conv2D(hidden_size, kernels[3], stride=strides[3], padding='SAME',
                w_init=he_init, with_bias=True),
      jax.nn.relu,
  ])
  self.pos_embed = SoftPositionEmbed(hidden_size, C['hidden_res'],
                                     position_enc_fn)
  self.linears = hk.Sequential([
      # i.e. a 1x1 convolution (32 neurons shared across all locations)
      hk.Reshape((-1, hidden_size)),  # Flatten spatial dims (works with batch)
      hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
      hk.Linear(32, w_init=he_init),
      jax.nn.relu,
      hk.Linear(32, w_init=he_init),
  ])
def single_mlp(inner_name: str):
  """Creates a single MLP performing the update."""
  mlp = hk.nets.MLP(output_sizes=output_sizes,
                    name=inner_name,
                    activation=activation)
  mlp = jraph.concatenated_args(mlp)
  if normalization_type == 'layer_norm':
    norm = hk.LayerNorm(axis=-1,
                        create_scale=True,
                        create_offset=True,
                        name=name + '_layer_norm')
  elif normalization_type == 'batch_norm':
    batch_norm = hk.BatchNorm(
        create_scale=True,
        create_offset=True,
        decay_rate=0.9,
        name=f'{inner_name}_batch_norm',
        cross_replica_axis=None if hk.running_init() else 'i',
    )
    norm = lambda x: batch_norm(x, is_training)
  elif normalization_type == 'none':
    return mlp
  else:
    raise ValueError(f'Unknown normalization type {normalization_type}')
  return jraph.concatenated_args(hk.Sequential([mlp, norm]))
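# Hypothetical illustration of jraph.concatenated_args (shapes are assumed, not
# from the snippet above): a jraph update function is called with several
# feature arrays (e.g. edge features, sender/receiver node features, globals);
# the wrapper concatenates them on the last axis so a plain MLP can consume
# them as one array.
import jax.numpy as jnp
import jraph

sum_features = jraph.concatenated_args(lambda x: x.sum(axis=-1))
out = sum_features(jnp.ones([5, 3]), jnp.ones([5, 2]))
# The two arguments are concatenated to shape [5, 5]; out has shape [5].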
def __init__(self, config, name=None):
  super().__init__(name=name)
  out_dim = config["n_vocab"]
  self.dim = out_dim
  self.norm = hk.LayerNorm(-1, True, True)
  self.proj = hk.Linear(self.dim)
def __init__(self,
             make_inner_op: MakeInnerOp,
             non_linearity: NonLinearity = jax.nn.relu,
             use_layer_norm: bool = False,
             name: str = 'residual_block'):
  super().__init__(name=name)
  self.inner_op1 = make_inner_op()
  self.inner_op2 = make_inner_op()
  self.non_linearity = non_linearity
  self.use_layer_norm = use_layer_norm
  if use_layer_norm:
    self.layernorm1 = hk.LayerNorm(axis=(1, 2, 3),
                                   create_scale=True,
                                   create_offset=True,
                                   eps=1e-6)
    self.layernorm2 = hk.LayerNorm(axis=(1, 2, 3),
                                   create_scale=True,
                                   create_offset=True,
                                   eps=1e-6)
def norm(name):
  layer_norm = hk.LayerNorm(axis=-1,
                            name=name,
                            create_scale=True,
                            create_offset=True)

  def norm_apply(x, **kwargs):  # So that this code works with the is_training kwarg.
    return layer_norm(x)

  return norm_apply
def quantile_net(x, quantile_fractions):
  x_size = x.shape[-1]
  x_tiled = jnp.tile(x[:, None, :], [num_bins, 1])
  quantiles_emb = quantile_cos_embedding(quantile_fractions)
  quantiles_emb = hk.Linear(x_size)(quantiles_emb)
  quantiles_emb = hk.LayerNorm(axis=-1,
                               create_scale=True,
                               create_offset=True)(quantiles_emb)
  quantiles_emb = jax.nn.sigmoid(quantiles_emb)
  x = x_tiled * quantiles_emb
  x = hk.Linear(x_size)(x)
  x = jax.nn.relu(x)
  return x
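# The helper quantile_cos_embedding is referenced but not shown above. A common
# IQN-style choice (an assumption here, not necessarily the original helper)
# embeds each quantile fraction tau with cosines of increasing frequency:
import jax.numpy as jnp

def quantile_cos_embedding(quantile_fractions, embedding_dim=64):
  """Maps [num_quantiles] fractions to [num_quantiles, embedding_dim]."""
  freqs = jnp.arange(1, embedding_dim + 1, dtype=jnp.float32)
  return jnp.cos(jnp.pi * quantile_fractions[:, None] * freqs[None, :])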
def __init__(self, config, name=None, init_scale=1.):
  super().__init__(name=name)
  self.dim = config["d_model"]
  self.n_head = config["n_heads"]
  self.d_head = config["d_head"]
  self.d_rotary = config["pe_rotary_dims"]
  self.mp_num = thread_resources.env.shape['mp']

  self.norm = hk.LayerNorm(-1, True, True)
  self.input_proj = hk.Linear(self.d_head * self.n_head * 3 + self.dim * 4)
  self.output_proj = hk.Linear(
      self.dim,
      w_init=hk.initializers.TruncatedNormal(stddev=init_scale / jnp.sqrt(self.dim)))
def mlp(
    x: tp.Union[jnp.ndarray, JAXSparse],
    is_training: bool,
    ids=None,
    num_classes: tp.Optional[int] = None,
    hidden_filters: tp.Union[int, tp.Iterable[int]] = 64,
    dropout_rate: float = 0.8,
    use_batch_norm: bool = False,
    use_layer_norm: bool = False,
    use_renormalize: bool = False,
    use_gathered_batch_norm: bool = False,
    activation: Activation = jax.nn.relu,
    final_activation: Activation = lambda x: x,
    input_dropout_rate: tp.Optional[float] = None,
    batch_norm_decay: float = 0.9,
    renorm_scale: bool = True,
    w_init=None,
):
  # At most one normalization scheme may be enabled.
  assert sum((use_batch_norm, use_layer_norm, use_renormalize,
              use_gathered_batch_norm)) <= 1
  if input_dropout_rate is None:
    input_dropout_rate = dropout_rate
  if isinstance(hidden_filters, int):
    hidden_filters = (hidden_filters,)
  x = dropout(x, input_dropout_rate, is_training=is_training)
  for filters in hidden_filters:
    x = Linear(filters, w_init=w_init)(x)
    if use_batch_norm:
      x = hk.BatchNorm(renorm_scale, True, batch_norm_decay)(x, is_training)
    if use_layer_norm:
      x = hk.LayerNorm(0, renorm_scale, True)(x)
    if use_renormalize:
      x = Renormalize(renorm_scale, True)(x)
    if use_gathered_batch_norm:
      assert ids is not None
      x = GatheredBatchNorm(True, True, batch_norm_decay)(
          x, is_training=is_training, ids=ids)
    x = activation(x)
    x = dropout(x, dropout_rate, is_training=is_training)
  if num_classes is not None:
    x = hk.Linear(num_classes, w_init=w_init)(x)
  return final_activation(x)
def _build_mlp(
    name: str,
    output_sizes: Sequence[int],
    use_layer_norm=False,
    activation=jax.nn.relu,
):
  """Builds an MLP, optionally with layernorm."""
  net = hk.nets.MLP(output_sizes=output_sizes,
                    name=name + "_mlp",
                    activation=activation)
  if use_layer_norm:
    layer_norm = hk.LayerNorm(axis=-1,
                              create_scale=True,
                              create_offset=True,
                              name=name + "_layer_norm")
    net = hk.Sequential([net, layer_norm])
  return jraph.concatenated_args(net)
def getnorm(type):
  if type == "layernorm":
    return ReplicatedLayerNorm()
  if type == "layernorm-desync":
    return hk.LayerNorm(-1, True, True)
  elif type == "layernorm-nobias":
    return ReplicatedLayerNorm(offset=False)
  elif type == "rmsnorm":
    return RMSNorm(False, True)
  elif type == "scalenorm":
    return RMSNorm(False, False)
  elif type == "rmsnorm-bias":
    return RMSNorm(True, True)
  elif type == "scalenorm-bias":
    return RMSNorm(True, False)
  else:
    raise Exception("Not implemented")
def make_downsampling_layer(
    strategy: Union[str, DownsamplingStrategy],
    output_channels: int,
) -> hk.SupportsCall:
  """Returns a sequence of modules corresponding to the desired downsampling."""
  strategy = DownsamplingStrategy(strategy)
  if strategy is DownsamplingStrategy.AVG_POOL:
    return hk.AvgPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding='SAME')
  elif strategy is DownsamplingStrategy.CONV:
    return hk.Sequential([
        hk.Conv2D(output_channels, kernel_shape=3, stride=2,
                  w_init=hk.initializers.TruncatedNormal(1e-2)),
    ])
  elif strategy is DownsamplingStrategy.LAYERNORM_RELU_CONV:
    return hk.Sequential([
        hk.LayerNorm(axis=(1, 2, 3), create_scale=True, create_offset=True,
                     eps=1e-6),
        jax.nn.relu,
        hk.Conv2D(output_channels, kernel_shape=3, stride=2,
                  w_init=hk.initializers.TruncatedNormal(1e-2)),
    ])
  elif strategy is DownsamplingStrategy.CONV_MAX:
    return hk.Sequential([
        hk.Conv2D(output_channels, kernel_shape=3, stride=1),
        hk.MaxPool(window_shape=(3, 3, 1), strides=(2, 2, 1), padding='SAME'),
    ])
  else:
    raise ValueError(
        'Unrecognized downsampling strategy. Expected one of'
        f' {[strategy.value for strategy in DownsamplingStrategy]}'
        f' but received {strategy}.')
def _update_nodes(self, graph: jraph.GraphsTuple,
                  messages: ArrayType) -> ArrayType:
  """Compute updated node representations."""
  x = jax.ops.segment_sum(messages, graph.receivers,
                          num_segments=graph.nodes.shape[0])
  x = jnp.concatenate([graph.nodes, x], axis=-1)
  layer_sizes = self._node_hidden_sizes[:]
  if self._residual:
    layer_sizes += [graph.nodes.shape[-1]]
  x = hk.nets.MLP(layer_sizes, activate_final=False)(x)
  if self._layer_norm:
    x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
  if self._residual:
    return graph.nodes + x
  else:
    return x
def __init__(self,
             layer_sizes,
             activation,
             with_bias=True,
             w_init=None,
             b_init=None,
             name='ResidualLayerNormBlock'):
  super().__init__(name=name)
  self._mlp = hk.nets.MLP(
      output_sizes=layer_sizes,
      w_init=w_init,
      b_init=b_init,
      with_bias=with_bias,
      activation=activation,
      activate_final=False)
  self._layer_norm = hk.LayerNorm(
      axis=-1, create_scale=True, create_offset=True)
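# A plausible apply step for the block above (an assumption; the original
# __call__ is not shown here): run the MLP, add the residual input, then
# layer-normalise the sum. Written as a standalone Haiku function for
# illustration; it assumes layer_sizes[-1] == x.shape[-1] so the residual
# addition is well defined, consistent with how the block is used in
# large_critic above.
import haiku as hk
import jax

def residual_layer_norm_block(x, layer_sizes, activation=jax.nn.relu):
  mlp = hk.nets.MLP(output_sizes=layer_sizes, activation=activation,
                    activate_final=False)
  layer_norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
  return layer_norm(x + mlp(x))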
def __init__(self, layer_sizes: Sequence[int], activate_final: bool = False):
  """Construct the MLP.

  Args:
    layer_sizes: a sequence of ints specifying the size of each layer.
    activate_final: whether or not to use the activation function on the
      final layer of the neural network.
  """
  super().__init__(name='feedforward_mlp_torso')
  self._network = hk.Sequential([
      hk.Linear(layer_sizes[0], w_init=uniform_initializer),
      hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
      jax.lax.tanh,
      hk.nets.MLP(layer_sizes[1:],
                  w_init=uniform_initializer,
                  activation=jax.nn.elu,
                  activate_final=activate_final),
  ])
def layer_norm(x):
  return hk.LayerNorm(-1, True, True)(x)
def layer_norm(x, name=None):
  return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True,
                      name=name)(x)
def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
  """Apply a unique LayerNorm to x with default settings."""
  return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True,
                      name=name)(x)
def layer_norm(x: jnp.ndarray) -> jnp.ndarray:
  """Apply a unique LayerNorm to x with default settings."""
  return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
def layer_norm(x):
  return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
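# Minimal usage sketch for helpers like the layer_norm functions above (the
# forward function and feature sizes here are illustrative, not from any of
# the snippets): Haiku modules, including hk.LayerNorm, must be created and
# applied inside a function wrapped with hk.transform.
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  x = hk.Linear(32)(x)
  return layer_norm(jax.nn.relu(x))

forward_fn = hk.without_apply_rng(hk.transform(forward))
params = forward_fn.init(jax.random.PRNGKey(0), jnp.ones([4, 8]))
out = forward_fn.apply(params, jnp.ones([4, 8]))  # shape (4, 32)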
def __call__(self, x):
  w_init = hk.initializers.Orthogonal(scale=1.0)
  x = hk.Linear(self.feature_dim, w_init=w_init)(x)
  x = hk.LayerNorm(axis=1, create_scale=True, create_offset=True)(x)
  x = jnp.tanh(x)
  return x
def __call__(self, activations, sequence_mask, update_affine, is_training,
             initial_act, safe_key=None, static_feat_2d=None, aatype=None):
  c = self.config

  if safe_key is None:
    safe_key = prng.SafeKey(hk.next_rng_key())

  def safe_dropout_fn(tensor, safe_key):
    return prng.safe_dropout(
        tensor=tensor,
        safe_key=safe_key,
        rate=c.dropout,
        is_deterministic=self.global_config.deterministic,
        is_training=is_training)

  affine = quat_affine.QuatAffine.from_tensor(activations['affine'])

  act = activations['act']
  attention_module = InvariantPointAttention(self.config, self.global_config)
  # Attention
  attn = attention_module(inputs_1d=act,
                          inputs_2d=static_feat_2d,
                          mask=sequence_mask,
                          affine=affine)
  act += attn
  safe_key, *sub_keys = safe_key.split(3)
  sub_keys = iter(sub_keys)
  act = safe_dropout_fn(act, next(sub_keys))
  act = hk.LayerNorm(axis=[-1],
                     create_scale=True,
                     create_offset=True,
                     name='attention_layer_norm')(act)

  final_init = 'zeros' if self.global_config.zero_init else 'linear'

  # Transition
  input_act = act
  for i in range(c.num_layer_in_transition):
    init = 'relu' if i < c.num_layer_in_transition - 1 else final_init
    act = common_modules.Linear(c.num_channel,
                                initializer=init,
                                name='transition')(act)
    if i < c.num_layer_in_transition - 1:
      act = jax.nn.relu(act)
  act += input_act
  act = safe_dropout_fn(act, next(sub_keys))
  act = hk.LayerNorm(axis=[-1],
                     create_scale=True,
                     create_offset=True,
                     name='transition_layer_norm')(act)

  if update_affine:
    # This block corresponds to
    # Jumper et al. (2021) Alg. 23 "Backbone update"
    affine_update_size = 6

    # Affine update
    affine_update = common_modules.Linear(affine_update_size,
                                          initializer=final_init,
                                          name='affine_update')(act)
    affine = affine.pre_compose(affine_update)

  sc = MultiRigidSidechain(c.sidechain, self.global_config)(
      affine.scale_translation(c.position_scale), [act, initial_act], aatype)

  outputs = {'affine': affine.to_tensor(), 'sc': sc}

  affine = affine.apply_rotation_tensor_fn(jax.lax.stop_gradient)

  new_activations = {'act': act, 'affine': affine.to_tensor()}
  return new_activations, outputs
def generate_affines(representations, batch, config, global_config,
                     is_training, safe_key):
  """Generate predicted affines for a single chain.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"

  This is the main part of the structure module - it iteratively applies
  folding to produce a set of predicted residue positions.

  Args:
    representations: Representations dictionary.
    batch: Batch dictionary.
    config: Config for the structure module.
    global_config: Global config.
    is_training: Whether the model is being trained.
    safe_key: A prng.SafeKey object that wraps a PRNG key.

  Returns:
    A dictionary containing residue affines and sidechain positions.
  """
  c = config
  sequence_mask = batch['seq_mask'][:, None]

  act = hk.LayerNorm(axis=[-1],
                     create_scale=True,
                     create_offset=True,
                     name='single_layer_norm')(representations['single'])

  initial_act = act
  act = common_modules.Linear(c.num_channel, name='initial_projection')(act)

  affine = generate_new_affine(sequence_mask)

  fold_iteration = FoldIteration(c, global_config, name='fold_iteration')

  assert len(batch['seq_mask'].shape) == 1

  activations = {
      'act': act,
      'affine': affine.to_tensor(),
  }

  act_2d = hk.LayerNorm(axis=[-1],
                        create_scale=True,
                        create_offset=True,
                        name='pair_layer_norm')(representations['pair'])

  outputs = []
  safe_keys = safe_key.split(c.num_layer)
  for sub_key in safe_keys:
    activations, output = fold_iteration(
        activations,
        initial_act=initial_act,
        static_feat_2d=act_2d,
        safe_key=sub_key,
        sequence_mask=sequence_mask,
        update_affine=True,
        is_training=is_training,
        aatype=batch['aatype'])
    outputs.append(output)

  output = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
  # Include the activations in the output dict for use by the LDDT-Head.
  output['act'] = activations['act']

  return output
        shape=(BATCH_SIZE, 2, 2, 3)),
    ModuleDescriptor(
        name="Bias",
        create=lambda: hk.Bias(),
        shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(
        name="Flatten",
        create=lambda: hk.Flatten(),
        shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(
        name="InstanceNorm",
        create=lambda: hk.InstanceNorm(True, True),
        shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(
        name="LayerNorm",
        create=lambda: hk.LayerNorm(1, True, True),
        shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(
        name="SpectralNorm",
        create=lambda: hk.SpectralNorm(),
        shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(
        name="nets.ResNet",
        create=lambda: Training(hk.nets.ResNet((3, 4, 6, 3), 1000)),
        shape=(BATCH_SIZE, 3, 3, 2)),
    # pylint: disable=g-long-lambda
    ModuleDescriptor(
        name="nets.MobileNetV1",
        create=lambda: Training(hk.nets.MobileNetV1(num_classes=1000,
                                                    strides=(1, 1, 1),
                                                    channels=(16, 32, 64))),
        create=lambda: Training(hk.BatchNorm(True, True, 0.9)),
        shape=(BATCH_SIZE, 2, 2, 3)),
    ModuleDescriptor(name="Bias",
                     create=lambda: hk.Bias(),
                     shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(name="Flatten",
                     create=lambda: hk.Flatten(),
                     shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(name="InstanceNorm",
                     create=lambda: hk.InstanceNorm(True, True),
                     shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(name="GroupNorm",
                     create=lambda: hk.GroupNorm(5),
                     shape=(BATCH_SIZE, 4, 4, 10)),
    ModuleDescriptor(name="LayerNorm",
                     create=lambda: hk.LayerNorm(1, True, True, param_axis=-1),
                     shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(
        name="MultiHeadAttention",
        create=lambda: MultiInput(  # pylint: disable=g-long-lambda
            hk.MultiHeadAttention(num_heads=8, key_size=64, w_init_scale=1.0),
            num_inputs=3),
        shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(name="RMSNorm",
                     create=lambda: hk.RMSNorm(1),
                     shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(name="SpectralNorm",
                     create=lambda: hk.SpectralNorm(),
                     shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(name="nets.ResNet",
                     create=lambda: Training(hk.nets.ResNet(