def make_network_with_prior(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (200, 100),
    critic_layer_sizes: Sequence[int] = (400, 300),
    prior_layer_sizes: Sequence[int] = (200, 100),
    policy_keys: Optional[Sequence[str]] = None,
    prior_keys: Optional[Sequence[str]] = None,
) -> Mapping[str, types.TensorTransformation]:
  """Builds the policy, critic and prior networks used by the agent."""
  # Flatten the action spec into a single dimensionality count.
  action_size = np.prod(action_spec.shape, dtype=int)

  # Observation selectors: each concatenates only the requested keys.
  select_policy_obs = functools.partial(
      svg0_utils.batch_concat_selection, concat_keys=policy_keys)
  select_prior_obs = functools.partial(
      svg0_utils.batch_concat_selection, concat_keys=prior_keys)

  def _gaussian_head():
    # Identical head configuration shared by the policy and the prior.
    return networks.MultivariateNormalDiagHead(
        action_size,
        tanh_mean=True,
        min_scale=0.1,
        init_scale=0.7,
        fixed_scale=False,
        use_tfd_independent=False)

  policy_network = snt.Sequential([
      select_policy_obs,
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      _gaussian_head(),
  ])

  # Critic: concatenate (selected) observations with clipped actions, then
  # run an MLP torso into a near-zero-initialized scalar output.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          observation_network=select_policy_obs,
          action_network=networks.ClipToSpec(action_spec)),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  prior_network = snt.Sequential([
      select_prior_obs,
      networks.LayerNormMLP(prior_layer_sizes, activate_final=True),
      _gaussian_head(),
  ])

  return {
      "policy": policy_network,
      "critic": critic_network,
      "prior": prior_network,
  }
def test_snapshot_distribution(self):
  """Verify a snapshotted distribution network restores with identical grads."""
  # Build a small distribution-producing network and create its variables.
  original_net = snt.Sequential([
      networks.LayerNormMLP([10, 10]),
      networks.MultivariateNormalDiagHead(1)
  ])
  input_spec = specs.Array([10], dtype=np.float32)
  tf2_utils.create_variables(original_net, [input_spec])

  # Snapshot the network into a temporary directory.
  save_dir = self.get_tempdir()
  snapshotter = tf2_savers.Snapshotter({'net': original_net},
                                       directory=save_dir)
  snapshotter.save()

  # Restore it through the SavedModel API.
  restored_net = tf.saved_model.load(
      os.path.join(snapshotter.directory, 'net'))

  # Push the same zero batch through both networks and compare gradients of
  # an identical scalar loss on the distribution's mean and variance.
  batch = tf2_utils.add_batch_dim(tf2_utils.zeros_like(input_spec))

  with tf.GradientTape() as tape:
    dist = original_net(batch)
    loss = tf.math.reduce_sum(dist.mean() + dist.variance())
  original_grads = tape.gradient(loss, original_net.trainable_variables)

  with tf.GradientTape() as tape:
    dist = restored_net(batch)
    loss = tf.math.reduce_sum(dist.mean() + dist.variance())
  restored_grads = tape.gradient(loss, restored_net.trainable_variables)

  assert all(
      tree.map_structure(np.allclose, list(original_grads),
                         list(restored_grads)))
def make_dmpo_networks(
    action_spec,
    policy_layer_sizes=(300, 200),
    critic_layer_sizes=(400, 300),
    vmin=-150.,
    vmax=150.,
    num_atoms=51,
):
  """Creates networks used by the agent."""
  # Flattened action dimensionality from the spec shape.
  action_size = np.prod(action_spec.shape, dtype=int)

  # Gaussian policy head on top of a layer-normalised MLP torso.
  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(action_size)
  ])

  # Distributional critic: multiplex observations with clipped actions, then
  # map to a discrete value distribution over `num_atoms` atoms.
  critic_torso = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_layer_sizes),
      action_network=networks.ClipToSpec(action_spec))
  critic_network = snt.Sequential(
      [critic_torso,
       networks.DiscreteValuedHead(vmin, vmax, num_atoms)])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf_utils.batch_concat,
  }
def make_mpo_networks(
    action_spec,
    policy_layer_sizes=(256, 256, 256),
    critic_layer_sizes=(512, 512, 256),
    policy_init_std=1e-9,
    obs_network=None):
  """Creates networks used by the agent."""
  action_size = np.prod(action_spec.shape, dtype=int)
  # The critic regresses a single scalar value, so append a final unit layer.
  critic_sizes = list(critic_layer_sizes) + [1]

  # Near-deterministic Gaussian policy (tiny initial/minimum scale).
  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(
          action_size, init_scale=policy_init_std, min_scale=1e-10)
  ])

  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_sizes),
      action_network=networks.ClipToSpec(action_spec))

  return {
      'policy': policy_network,
      'critic': critic_network,
      # Default to a plain concat when no observation network is supplied.
      'observation': tf_utils.batch_concat if obs_network is None else obs_network,
  }
def make_networks(
    action_spec: types.NestedSpec,
    policy_layer_sizes: Sequence[int] = (10, 10),
    critic_layer_sizes: Sequence[int] = (10, 10),
) -> Dict[str, snt.Module]:
  """Creates networks used by the agent."""
  # Total number of action dimensions, flattened from the spec shape.
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      tf2_utils.batch_concat,
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          action_size,
          tanh_mean=True,
          min_scale=0.3,
          init_scale=0.7,
          fixed_scale=False,
          use_tfd_independent=False)
  ])

  # Critic: concat observations/actions, MLP torso, near-zero scalar output.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  return {'policy': policy_network, 'critic': critic_network}
def make_bc_network(
    action_spec,
    policy_layer_sizes=(256, 256, 256),
    policy_init_std=1e-9,
    binary_grip_action=False):
  """Residual BC network in Sonnet, equivalent to residual policy network."""
  action_size = np.prod(action_spec.shape, dtype=int)

  if policy_layer_sizes:
    # MLP torso followed by a near-deterministic Gaussian head.
    layers = [
        tf_utils.batch_concat,
        networks.LayerNormMLP([int(l) for l in policy_layer_sizes]),
        networks.MultivariateNormalDiagHead(
            action_size, init_scale=policy_init_std, min_scale=1e-10)
    ]
  else:
    # No torso: feed the concatenated observation straight to the arm head.
    layers = [
        tf_utils.batch_concat,
        ArmPolicyNormalDiagHead(
            binary_grip_action=binary_grip_action,
            num_dimensions=action_size,
            init_scale=policy_init_std,
            min_scale=1e-10)
    ]

  # Note: no separate 'observation' entry is returned; the batch_concat is
  # already folded into the policy network above.
  return {
      'policy': snt.Sequential(layers),
  }
def make_networks(
    action_spec: specs.Array,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
) -> Dict[str, snt.Module]:
  """Creates networks used by the agent."""
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(action_size),
  ])

  # Critic multiplexes observations/actions and outputs a discrete value
  # distribution over 10 atoms on [0, 1].
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          critic_network=networks.LayerNormMLP(list(critic_layer_sizes))),
      networks.DiscreteValuedHead(0., 1., 10),
  ])

  return {'policy': policy_network, 'critic': critic_network}
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (50,),
    critic_layer_sizes: Sequence[int] = (50,),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
):
  """Creates networks used by the agent."""
  action_size = np.prod(action_spec.shape, dtype=int)

  # Fixed-scale, tanh-mean Gaussian policy head.
  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          action_size,
          tanh_mean=True,
          init_scale=0.3,
          fixed_scale=True,
          use_tfd_independent=False)
  ])

  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_torso = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(
          critic_layer_sizes, activate_final=True),
      action_network=networks.ClipToSpec(action_spec))
  critic_network = snt.Sequential(
      [critic_torso,
       networks.DiscreteValuedHead(vmin, vmax, num_atoms)])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf2_utils.batch_concat,
  }
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
) -> Dict[str, types.TensorTransformation]:
  """Creates networks used by the agent."""
  # Flattened action dimensionality.
  action_size = np.prod(action_spec.shape, dtype=int)

  # Policy: layer-norm MLP torso into an independent Gaussian head.
  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          action_size, init_scale=0.7, use_tfd_independent=True)
  ])

  # Critic: concat (clipped) actions with observations, then MLP torso into
  # a near-zero-initialized scalar output.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          action_network=networks.ClipToSpec(action_spec)),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf2_utils.batch_concat,
  }
def make_default_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
) -> Mapping[str, types.TensorTransformation]:
  """Creates networks used by the agent."""
  # Flatten the action spec into a single dimensionality count.
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      tf2_utils.batch_concat,
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          action_size,
          tanh_mean=True,
          min_scale=0.3,
          init_scale=0.7,
          fixed_scale=False,
          use_tfd_independent=False)
  ])

  # Critic: concat (clipped) actions with observations, MLP torso, then a
  # near-zero-initialized scalar output.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          action_network=networks.ClipToSpec(action_spec)),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  return {"policy": policy_network, "critic": critic_network}
def make_networks(
    action_spec: specs.BoundedArray,
    num_critic_heads: int,
    policy_layer_sizes: Sequence[int] = (50,),
    critic_layer_sizes: Sequence[int] = (50,),
    num_layers_shared: int = 1,
    distributional_critic: bool = True,
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
):
  """Creates networks used by the agent."""
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          action_size, tanh_mean=False, init_scale=0.69)
  ])

  # A non-distributional critic regresses a scalar, so append a unit layer.
  if not distributional_critic:
    critic_layer_sizes = list(critic_layer_sizes) + [1]

  # Optionally share the first `num_layers_shared` critic layers as a base.
  # NOTE(review): each head below is built from the *full* critic_layer_sizes,
  # so the shared-base sizes are repeated inside every head — confirm this is
  # intended rather than critic_layer_sizes[num_layers_shared:].
  if num_layers_shared:
    critic_network_base = networks.LayerNormMLP(
        critic_layer_sizes[:num_layers_shared], activate_final=True)
  else:
    # No layers are shared.
    critic_network_base = None

  critic_network_heads = [
      snt.nets.MLP(
          critic_layer_sizes, activation=tf.nn.elu, activate_final=False)
      for _ in range(num_critic_heads)
  ]
  if distributional_critic:
    critic_network_heads = [
        snt.Sequential(
            [head, networks.DiscreteValuedHead(vmin, vmax, num_atoms)])
        for head in critic_network_heads
    ]

  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          critic_network=critic_network_base,
          action_network=networks.ClipToSpec(action_spec)),
      networks.Multihead(network_heads=critic_network_heads),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf2_utils.batch_concat,
  }
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (50, 1024, 1024),
    critic_layer_sizes: Sequence[int] = (50, 1024, 1024),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, snt.Module]:
  """Creates networks used by the agent."""
  action_size = np.prod(action_spec.shape, dtype=int)

  def _ortho():
    # A fresh orthogonal initializer instance per layer, as in the original.
    return snt.initializers.Orthogonal()

  policy_network = snt.Sequential([
      networks.LayerNormMLP(
          policy_layer_sizes,
          w_init=_ortho(),
          activation=tf.nn.relu,
          activate_final=True),
      networks.MultivariateNormalDiagHead(
          action_size,
          tanh_mean=False,
          init_scale=1.0,
          fixed_scale=False,
          use_tfd_independent=True,
          w_init=_ortho())
  ])

  # Critic observation path: linear projection + layer norm + tanh squash.
  critic_obs_network = snt.Sequential([
      snt.Linear(critic_layer_sizes[0], w_init=_ortho()),
      snt.LayerNorm(
          axis=slice(1, None), create_scale=True, create_offset=True),
      tf.nn.tanh
  ])

  # Multiplex transformed observations with clipped actions, run the MLP and
  # emit a discrete value distribution over `num_atoms` atoms.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          observation_network=critic_obs_network,
          critic_network=snt.nets.MLP(
              critic_layer_sizes[1:],
              w_init=_ortho(),
              activation=tf.nn.relu,
              activate_final=True),
          action_network=networks.ClipToSpec(action_spec)),
      networks.DiscreteValuedHead(vmin, vmax, num_atoms, w_init=_ortho())
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      # Pixel observations go through the DrQ convolutional torso.
      'observation': networks.DrQTorso(),
  }
def make_lstm_mpo_agent(env_spec: specs.EnvironmentSpec, logger: Logger,
                        hyperparams: Dict, checkpoint_path: str):
  """Builds a recurrent MPO learner and its actor from `hyperparams`.

  `hyperparams` overrides DEFAULT_PARAMS; consumed keys are popped and any
  remaining entries are forwarded to the RecurrentMPO constructor.
  """
  params = DEFAULT_PARAMS.copy()
  params.update(hyperparams)
  action_size = np.prod(env_spec.actions.shape, dtype=int).item()

  policy_network = snt.Sequential([
      networks.LayerNormMLP(
          layer_sizes=[*params.pop('policy_layers'), action_size]),
      networks.MultivariateNormalDiagHead(num_dimensions=action_size)
  ])
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(critic_network=networks.LayerNormMLP(
          layer_sizes=[*params.pop('critic_layers'), 1]))
  ])
  observation_network = networks.DeepRNN([
      networks.LayerNormMLP(layer_sizes=params.pop('observation_layers')),
      networks.LSTM(hidden_size=200)
  ])

  # Split the 'loss_*' entries out of params and strip the prefix; the keys
  # are materialized first so popping does not mutate during iteration.
  loss_keys = [k for k in params if k.startswith('loss_')]
  loss_params = {k.replace('loss_', ''): params.pop(k) for k in loss_keys}
  policy_loss_module = losses.MPO(**loss_params)

  # Make sure observation network is a Sonnet Module.
  observation_network = tf2_utils.to_sonnet_module(observation_network)

  # Create optimizers.
  policy_optimizer = Adam(params.pop('policy_lr'))
  critic_optimizer = Adam(params.pop('critic_lr'))

  # The actor selects the mode of the policy distribution.
  actor = RecurrentActor(
      networks.DeepRNN([
          observation_network, policy_network,
          networks.StochasticModeHead()
      ]))

  # The learner updates the parameters (and initializes them).
  learner = RecurrentMPO(
      environment_spec=env_spec,
      policy_network=policy_network,
      critic_network=critic_network,
      observation_network=observation_network,
      policy_loss_module=policy_loss_module,
      policy_optimizer=policy_optimizer,
      critic_optimizer=critic_optimizer,
      logger=logger,
      checkpoint_path=checkpoint_path,
      **params)
  return learner, actor
def make_default_networks(
    environment_spec: specs.EnvironmentSpec,
    *,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    policy_init_scale: float = 0.7,
    critic_init_scale: float = 1e-3,
    critic_num_components: int = 5,
) -> Mapping[str, snt.Module]:
  """Creates networks used by the agent."""
  # Pull action/observation specs out of the environment spec.
  act_spec = environment_spec.actions
  obs_spec = environment_spec.observations
  action_size = np.prod(act_spec.shape, dtype=int)

  # Observation network is a state-less concat, wrapped as a Sonnet module.
  observation_network = tf2_utils.to_sonnet_module(tf2_utils.batch_concat)

  # Policy: layer-norm MLP torso into an independent Gaussian head.
  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          action_size,
          init_scale=policy_init_scale,
          use_tfd_independent=True)
  ])

  # Critic: concat (clipped) actions with embeddings, MLP torso, then a
  # one-dimensional Gaussian-mixture value head.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          action_network=networks.ClipToSpec(act_spec)),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.GaussianMixtureHead(
          num_dimensions=1,
          num_components=critic_num_components,
          init_scale=critic_init_scale)
  ])

  # Create variables now; the observation network's output spec (embedding)
  # feeds both the policy and the critic.
  emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
  tf2_utils.create_variables(policy_network, [emb_spec])
  tf2_utils.create_variables(critic_network, [emb_spec, act_spec])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': observation_network,
  }
def make_dmpo_networks(
    action_spec,
    policy_layer_sizes=(256, 256, 256),
    critic_layer_sizes=(512, 512, 256),
    vmin=-150.,
    vmax=150.,
    num_atoms=51,
    policy_init_std=1e-9,
    obs_network=None,
    binary_grip_action=False):
  """Creates networks used by the agent."""
  action_size = np.prod(action_spec.shape, dtype=int)

  if policy_layer_sizes:
    # MLP torso followed by a near-deterministic Gaussian head.
    policy_network = snt.Sequential([
        networks.LayerNormMLP([int(l) for l in policy_layer_sizes]),
        networks.MultivariateNormalDiagHead(
            action_size, init_scale=policy_init_std, min_scale=1e-10)
    ])
  else:
    # Useful when initializing from a trained BC network.
    policy_network = snt.Sequential([
        ArmPolicyNormalDiagHead(
            binary_grip_action=binary_grip_action,
            num_dimensions=action_size,
            init_scale=policy_init_std,
            min_scale=1e-10)
    ])

  # Distributional critic on top of the observation/action multiplexer.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          critic_network=networks.LayerNormMLP(critic_layer_sizes),
          action_network=networks.ClipToSpec(action_spec)),
      networks.DiscreteValuedHead(vmin, vmax, num_atoms)
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      # Default to a plain concat when no observation network is supplied.
      'observation': tf_utils.batch_concat if obs_network is None else obs_network,
  }
def make_feed_forward_networks(
    action_spec: specs.BoundedArray,
    z_spec: specs.BoundedArray,
    policy_layer_sizes: Tuple[int, ...] = (256, 256),
    critic_layer_sizes: Tuple[int, ...] = (256, 256),
    discriminator_layer_sizes: Tuple[int, ...] = (256, 256),
    hierarchical_controller_layer_sizes: Tuple[int, ...] = (256, 256),
    vmin: float = -150.,  # Minimum value for the Critic distribution.
    vmax: float = 150.,  # Maximum value for the Critic distribution.
    num_atoms: int = 51,  # Number of atoms for the discrete value distribution.
) -> Dict[str, types.TensorTransformation]:
  """Builds policy/critic/discriminator/controller feed-forward networks."""
  action_size = np.prod(action_spec.shape, dtype=int)
  z_dim = np.prod(z_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(action_size)
  ])

  # Distributional critic over multiplexed observations/clipped actions.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          critic_network=networks.LayerNormMLP(critic_layer_sizes),
          action_network=networks.ClipToSpec(action_spec)),
      networks.DiscreteValuedHead(vmin, vmax, num_atoms),
  ])

  # The discriminator in DIAYN uses the same architecture as the critic;
  # both it and the controller output z_dim logits.
  discriminator_network = networks.LayerNormMLP(
      discriminator_layer_sizes + (z_dim,))
  hierarchical_controller_network = networks.LayerNormMLP(
      hierarchical_controller_layer_sizes + (z_dim,))

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf2_utils.batch_concat,
      'discriminator': discriminator_network,
      'hierarchical_controller': hierarchical_controller_network,
  }
def make_networks(
    action_spec,
    policy_layer_sizes=(10, 10),
    critic_layer_sizes=(10, 10),
):
  """Creates networks used by the agent."""
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(action_size)
  ])

  # Critic regresses a scalar value, hence the extra final unit layer.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(list(critic_layer_sizes) + [1]))

  return {'policy': policy_network, 'critic': critic_network}
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, types.TensorTransformation]:
  """Creates networks used by the agent."""
  # Flattened action dimensionality from the spec shape.
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(action_size)
  ])

  # Distributional critic: multiplex observations with clipped actions and
  # output a discrete value distribution over `num_atoms` atoms.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          critic_network=networks.LayerNormMLP(critic_layer_sizes),
          action_network=networks.ClipToSpec(action_spec)),
      networks.DiscreteValuedHead(vmin, vmax, num_atoms),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      # Shared observation network; here simply a state-less concat.
      'observation': tf2_utils.batch_concat,
  }
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (50, 50),
    critic_layer_sizes: Sequence[int] = (50, 50),
):
  """Creates networks used by the agent."""
  action_size = np.prod(action_spec.shape, dtype=int)
  observation_network = tf2_utils.batch_concat

  # Fixed-scale, tanh-mean Gaussian policy head.
  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          action_size,
          tanh_mean=True,
          init_scale=0.3,
          fixed_scale=True,
          use_tfd_independent=False)
  ])

  # Deterministic evaluation path: observation -> policy -> mean action.
  evaluator_network = snt.Sequential([
      observation_network,
      policy_network,
      networks.StochasticMeanHead(),
  ])

  # Critic: concat (clipped) actions with observations, MLP torso, then a
  # near-zero-initialized scalar output.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          action_network=networks.ClipToSpec(action_spec)),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': observation_network,
      'evaluator': evaluator_network,
  }
def __init__(self, n_classes=None, last_activation=None, fc_layer_sizes=(),
             weight_decay=5e-4, bn_axis=3, batch_norm_decay=0.1,
             init_scheme='v1'):
  """Builds the Resnet18Narrow32 layers.

  Args:
    n_classes: number of output units of the final linear layer; if None, no
      final linear layer and no `log_std` variable are created.
    last_activation: activation applied by the final linear layer.
    fc_layer_sizes: sizes of the fully-connected layers between the pooled
      features and the output layer.
    weight_decay: l2 regularisation strength for conv/linear weights.
    bn_axis: `axis` argument for the first batch-norm layer.
    batch_norm_decay: momentum for the batch-norm layers.
    init_scheme: 'v1' selects the v1_* initializers; anything else selects a
      VarianceScaling conv init and the torch-style linear initializers.
  """
  super(Resnet18Narrow32, self).__init__(name='')
  if init_scheme == 'v1':
    print('Using v1 weight init')
    conv2d_init = v1_conv2d_init
    # Bias is not used in conv layers.
    linear_init = v1_linear_init
    linear_bias_init = v1_linear_bias_init
  else:
    print('Using v2 weight init')
    conv2d_init = keras.initializers.VarianceScaling(
        scale=2.0, mode='fan_out', distribution='untruncated_normal')
    linear_init = torch_linear_init
    linear_bias_init = torch_linear_bias_init

  # Stem: pad, 7x7 stride-2 conv (no bias), batch norm, then 3x3 max pool.
  # Why is this separate instead of padding='same' in tfl.Conv2D?
  self.zero_pad = tfl.ZeroPadding2D(padding=(3, 3),
                                    input_shape=(32, 32, 3),
                                    name='conv1_pad')
  self.conv1 = tfl.Conv2D(
      64, (7, 7),
      strides=(2, 2),
      padding='valid',
      kernel_initializer=conv2d_init,
      kernel_regularizer=keras.regularizers.l2(weight_decay),
      use_bias=False,
      name='conv1')
  self.bn1 = tfl.BatchNormalization(axis=bn_axis,
                                    name='bn_conv1',
                                    momentum=batch_norm_decay,
                                    epsilon=BATCH_NORM_EPSILON)
  self.zero_pad2 = tfl.ZeroPadding2D(padding=(1, 1), name='max_pool_pad')
  self.max_pool = tfl.MaxPooling2D(pool_size=(3, 3),
                                   strides=(2, 2),
                                   padding='valid')

  # Four residual stages with narrow widths (32/64/128/256 output planes).
  self.resblock1 = Resnet18Block(kernel_size=3,
                                 input_planes=64,
                                 output_planes=32,
                                 stage=2,
                                 strides=(1, 1),
                                 weight_decay=weight_decay,
                                 batch_norm_decay=batch_norm_decay,
                                 conv2d_init=conv2d_init)
  self.resblock2 = Resnet18Block(kernel_size=3,
                                 input_planes=32,
                                 output_planes=64,
                                 stage=3,
                                 strides=(2, 2),
                                 weight_decay=weight_decay,
                                 batch_norm_decay=batch_norm_decay,
                                 conv2d_init=conv2d_init)
  self.resblock3 = Resnet18Block(kernel_size=3,
                                 input_planes=64,
                                 output_planes=128,
                                 stage=4,
                                 strides=(2, 2),
                                 weight_decay=weight_decay,
                                 batch_norm_decay=batch_norm_decay,
                                 conv2d_init=conv2d_init)
  # NOTE(review): stage=4 is repeated here (resblock3 also uses stage=4); if
  # `stage` feeds layer names this may collide — confirm it isn't stage=5.
  self.resblock4 = Resnet18Block(kernel_size=3,
                                 input_planes=128,
                                 output_planes=256,
                                 stage=4,
                                 strides=(2, 2),
                                 weight_decay=weight_decay,
                                 batch_norm_decay=batch_norm_decay,
                                 conv2d_init=conv2d_init)

  # Head: global average pool, batch norm, then either a Sonnet policy-style
  # head or a stack of dense layers plus an optional final classifier.
  self.pool = tfl.GlobalAveragePooling2D(name='avg_pool')
  self.bn2 = tfl.BatchNormalization(axis=-1,
                                    name='bn_conv2',
                                    momentum=batch_norm_decay,
                                    epsilon=BATCH_NORM_EPSILON)
  self.fcs = []
  if FLAGS.layer_norm_policy:
    # LayerNorm MLP + Gaussian head, reduced to its mean.
    self.linear = snt.Sequential([
        networks.LayerNormMLP(fc_layer_sizes),
        networks.MultivariateNormalDiagHead(n_classes),
        networks.StochasticMeanHead()
    ])
  else:
    for size in fc_layer_sizes:
      self.fcs.append(
          tfl.Dense(size,
                    activation=tf.nn.relu,
                    kernel_initializer=linear_init,
                    bias_initializer=linear_bias_init,
                    kernel_regularizer=keras.regularizers.l2(weight_decay),
                    bias_regularizer=keras.regularizers.l2(weight_decay)))
    if n_classes is not None:
      self.linear = tfl.Dense(
          n_classes,
          activation=last_activation,
          kernel_initializer=linear_init,
          bias_initializer=linear_bias_init,
          kernel_regularizer=keras.regularizers.l2(weight_decay),
          bias_regularizer=keras.regularizers.l2(weight_decay),
          name='fc%d' % n_classes)
  self.n_classes = n_classes
  if n_classes is not None:
    # Trainable log-std vector with one entry per output dimension.
    self.log_std = tf.Variable(tf.zeros(n_classes),
                               trainable=True,
                               name='log_std')
  self.first_forward_pass = FLAGS.data_smaller
def make_default_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (256, 256, 256),
    critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256),
    shared_weights: bool = True,
) -> Dict[str, snt.Module]:
    """Default networks for mappo.

    Args:
        environment_spec (mava_specs.MAEnvironmentSpec): description of the action
            and observation spaces etc. for each agent in the system.
        policy_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional):
            size of policy networks. Defaults to (256, 256, 256).
        critic_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional):
            size of critic networks. Defaults to (512, 512, 256).
        shared_weights (bool, optional): whether agents should share weights or not.
            Defaults to True.

    Raises:
        ValueError: unknown action_spec type, if actions aren't DiscreteArray
            or BoundedArray.

    Returns:
        Dict[str, snt.Module]: returned agent networks.
    """
    agent_specs = environment_spec.get_agent_specs()
    if shared_weights:
        # Collapse per-agent specs down to one spec per agent *type*.
        agent_specs = {
            agent_key.split("_")[0]: spec for agent_key, spec in agent_specs.items()
        }

    # Broadcast plain sequences of layer sizes to every agent key.
    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            agent_key: policy_networks_layer_sizes for agent_key in agent_specs
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            agent_key: critic_networks_layer_sizes for agent_key in agent_specs
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for agent_key, spec in agent_specs.items():
        # Shared observation network; here simply a state-less identity op.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        action_spec = spec.actions
        # The discrete check must come first: DiscreteArray subclasses BoundedArray.
        if isinstance(action_spec, dm_env.specs.DiscreteArray):  # discrete
            num_actions = action_spec.num_values
            policy_network = snt.Sequential(
                [
                    networks.LayerNormMLP(
                        tuple(policy_networks_layer_sizes[agent_key]) + (num_actions,),
                        activate_final=False,
                    ),
                    tf.keras.layers.Lambda(
                        lambda logits: tfp.distributions.Categorical(logits=logits)
                    ),
                ]
            )
        elif isinstance(action_spec, dm_env.specs.BoundedArray):  # continuous
            num_actions = np.prod(action_spec.shape, dtype=int)
            policy_network = snt.Sequential(
                [
                    networks.LayerNormMLP(
                        policy_networks_layer_sizes[agent_key], activate_final=True
                    ),
                    networks.MultivariateNormalDiagHead(num_dimensions=num_actions),
                    networks.TanhToSpec(action_spec),
                ]
            )
        else:
            raise ValueError(f"Unknown action_spec type, got {action_spec}.")

        # Critic regresses a single scalar value per observation.
        critic_network = snt.Sequential(
            [
                networks.LayerNormMLP(
                    list(critic_networks_layer_sizes[agent_key]) + [1],
                    activate_final=False,
                ),
            ]
        )

        observation_networks[agent_key] = observation_network
        policy_networks[agent_key] = policy_network
        critic_networks[agent_key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }
def make_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (
        256,
        256,
        256,
    ),
    critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256),
    shared_weights: bool = True,
) -> Dict[str, snt.Module]:
    """Creates networks used by the agents."""
    agent_specs = environment_spec.get_agent_specs()
    if shared_weights:
        # Collapse per-agent specs down to one spec per agent *type*.
        agent_specs = {
            agent_key.split("_")[0]: spec for agent_key, spec in agent_specs.items()
        }

    # Broadcast plain sequences of layer sizes to every agent key.
    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            agent_key: policy_networks_layer_sizes for agent_key in agent_specs
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            agent_key: critic_networks_layer_sizes for agent_key in agent_specs
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for agent_key, spec in agent_specs.items():
        # Shared observation network; here simply a state-less identity op.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        action_spec = spec.actions
        # The discrete check must come first: DiscreteArray subclasses BoundedArray.
        if isinstance(action_spec, dm_env.specs.DiscreteArray):  # discrete
            num_actions = action_spec.num_values
            policy_network = snt.Sequential([
                networks.LayerNormMLP(
                    tuple(policy_networks_layer_sizes[agent_key]) + (num_actions,),
                    activate_final=False,
                ),
                tf.keras.layers.Lambda(
                    lambda logits: tfp.distributions.Categorical(logits=logits)
                ),
            ])
        elif isinstance(action_spec, dm_env.specs.BoundedArray):  # continuous
            num_actions = np.prod(action_spec.shape, dtype=int)
            policy_network = snt.Sequential([
                networks.LayerNormMLP(
                    policy_networks_layer_sizes[agent_key], activate_final=True
                ),
                networks.MultivariateNormalDiagHead(num_dimensions=num_actions),
                networks.TanhToSpec(action_spec),
            ])
        else:
            raise ValueError(f"Unknown action_spec type, got {action_spec}.")

        # Critic: MLP torso with a near-zero-initialized scalar output.
        critic_network = snt.Sequential([
            networks.LayerNormMLP(
                critic_networks_layer_sizes[agent_key], activate_final=True
            ),
            networks.NearZeroInitializedLinear(1),
        ])

        observation_networks[agent_key] = observation_network
        policy_networks[agent_key] = policy_network
        critic_networks[agent_key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }