Code example #1
        def apply_fns():
            embed_apply_fn = hk.without_apply_rng(
                hk.transform(embedding)).apply
            transformer_apply_fn = hk.without_apply_rng(
                hk.transform(transformer)).apply

            return embed_apply_fn, transformer_apply_fn
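
Example #1 is only a fragment: embedding, transformer, and the surrounding scope come from the project it was taken from. As a self-contained illustration of the same hk.without_apply_rng(hk.transform(...)) pattern, here is a minimal sketch (not code from that project):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Any function that builds Haiku modules inline can be transformed.
    return hk.Linear(8)(x)

net = hk.without_apply_rng(hk.transform(forward))
params = net.init(jax.random.PRNGKey(0), jnp.zeros((1, 4)))
out = net.apply(params, jnp.zeros((1, 4)))  # apply takes no rng argument

hk.transform turns the module-building function into a pure (init, apply) pair; hk.without_apply_rng strips the rng argument from apply, which is valid whenever the forward pass itself uses no randomness.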
Code example #2
def init(arch, h, L, act, seed_init, **args):
    if act == 'silu':
        act = jax.nn.silu
    if act == 'gelu':
        act = jax.nn.gelu
    if act == 'relu':
        act = jax.nn.relu

    act = normalize_act(act)

    xtr, xte, ytr, yte = dataset(**args)
    print('dataset generated', flush=True)

    if arch == 'mlp':
        model = hk.without_apply_rng(
            hk.transform(lambda x: mlp([h] * L, act, x)))

        xtr = xtr.reshape(xtr.shape[0], -1)
        xte = xte.reshape(xte.shape[0], -1)

    if arch == 'mnas':
        model = hk.without_apply_rng(hk.transform(lambda x: mnas(h, act, x)))

    print(f'xtr.shape={xtr.shape} xte.shape={xte.shape}', flush=True)

    w = model.init(jax.random.PRNGKey(seed_init), xtr)
    print('network initialized', flush=True)

    return model, w, xtr, xte, ytr, yte
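
The chain of if-statements above maps activation names to functions one case at a time; a dictionary lookup is a compact equivalent (a sketch assuming the same three supported names):

ACTIVATIONS = {'silu': jax.nn.silu, 'gelu': jax.nn.gelu, 'relu': jax.nn.relu}
act = ACTIVATIONS.get(act, act)  # leave act unchanged if it is not a known name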
Code example #3
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        network: PolicyValueNet,
        optimizer: optax.GradientTransformation,
        rng: hk.PRNGSequence,
        sequence_length: int,
        discount: float,
        td_lambda: float,
    ):

        # Define loss function.
        def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
            """"Actor-critic loss."""
            logits, values = network(trajectory.observations)
            td_errors = rlax.td_lambda(
                v_tm1=values[:-1],
                r_t=trajectory.rewards,
                discount_t=trajectory.discounts * discount,
                v_t=values[1:],
                lambda_=jnp.array(td_lambda),
            )
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=trajectory.actions,
                adv_t=td_errors,
                w_t=jnp.ones_like(td_errors))

            return actor_loss + critic_loss

        # Transform the loss into a pure function.
        loss_fn = hk.without_apply_rng(hk.transform(loss)).apply

        # Define update function.
        @jax.jit
        def sgd_step(state: TrainingState,
                     trajectory: sequence.Trajectory) -> TrainingState:
            """Does a step of SGD over a trajectory."""
            gradients = jax.grad(loss_fn)(state.params, trajectory)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)
            return TrainingState(params=new_params, opt_state=new_opt_state)

        # Initialize network parameters and optimiser state.
        init, forward = hk.without_apply_rng(hk.transform(network))
        dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
        initial_params = init(next(rng), dummy_observation)
        initial_opt_state = optimizer.init(initial_params)

        # Internalize state.
        self._state = TrainingState(initial_params, initial_opt_state)
        self._forward = jax.jit(forward)
        self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
        self._sgd_step = sgd_step
        self._rng = rng
Code example #4
def make_haiku_networks(
        env_spec: specs.EnvironmentSpec, forward_fn: Any,
        initial_state_fn: Any,
        unroll_fn: Any) -> IMPALANetworks[types.RecurrentState]:
    """Builds functional impala network from recurrent model definitions."""
    # Make networks purely functional.
    forward_hk = hk.without_apply_rng(hk.transform(forward_fn))
    initial_state_hk = hk.without_apply_rng(hk.transform(initial_state_fn))
    unroll_hk = hk.without_apply_rng(hk.transform(unroll_fn))

    # Define networks init functions.
    def initial_state_init_fn(rng: networks_lib.PRNGKey) -> hk.Params:
        return initial_state_hk.init(rng)

    # Note: batch axis is not needed for the actors.
    dummy_obs = utils.zeros_like(env_spec.observations)
    dummy_obs_sequence = utils.add_batch_dim(dummy_obs)

    def unroll_init_fn(rng: networks_lib.PRNGKey,
                       initial_state: types.RecurrentState) -> hk.Params:
        return unroll_hk.init(rng, dummy_obs_sequence, initial_state)

    return IMPALANetworks(forward_fn=forward_hk.apply,
                          unroll_init_fn=unroll_init_fn,
                          unroll_fn=unroll_hk.apply,
                          initial_state_init_fn=initial_state_init_fn,
                          initial_state_fn=initial_state_hk.apply)
Code example #5
File: misc.py Project: winston-ds/rljax
def make_quantile_nerwork(
    rng,
    state_space,
    action_space,
    fn,
    num_quantiles,
):
    """
    Make Quantile Nerwork for FQF.
    """
    fake_state = state_space.sample()[None, ...]
    if len(state_space.shape) == 1:
        fake_state = fake_state.astype(np.float32)
    network_dict = {}
    params_dict = {}

    if len(state_space.shape) == 3:
        network_dict["feature"] = hk.without_apply_rng(
            hk.transform(lambda s: DQNBody()(s)))
        fake_feature = np.zeros((1, 7 * 7 * 64), dtype=np.float32)
    else:
        network_dict["feature"] = hk.without_apply_rng(
            hk.transform(lambda s: s))
        fake_feature = fake_state
    params_dict["feature"] = network_dict["feature"].init(
        next(rng), fake_state)

    fake_cum_p = np.empty((1, num_quantiles), dtype=np.float32)
    network_dict["quantile"] = hk.without_apply_rng(hk.transform(fn))
    params_dict["quantile"] = network_dict["quantile"].init(
        next(rng), fake_feature, fake_cum_p)

    network_dict = hk.data_structures.to_immutable_dict(network_dict)
    params_dict = hk.data_structures.to_immutable_dict(params_dict)
    return network_dict, params_dict, fake_feature
Code example #6
def load(name: str, device: Union[str, torch.device] = "cpu", jit=True):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model (default) or more hackable non-JIT model.

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        state_dict = torch.jit.load(model_path, map_location=device if jit else "cpu").eval().state_dict()
    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    clip_params = get_params(state_dict)

    # jax model
    def clip_jax(image, text):
        clip = CLIP(**clip_params)
        return clip.encode_image(image), clip.encode_text(text)

    def vit_jax(image):
        clip = CLIP(**clip_params)
        return clip.encode_image(image)

    def text_jax(text):
        clip = CLIP(**clip_params)
        return clip.encode_text(text)

    rng_key = jax.random.PRNGKey(42)
    transformed = hk.transform(clip_jax)
    jax_params = transformed.init(rng=rng_key, image=jnp.zeros((1, 3, 224, 224)), text=jnp.zeros((1, 77), dtype=jnp.int16))
    jax_params = convert_params(state_dict, jax_params)

    image_fn = hk.without_apply_rng(hk.transform(vit_jax)).apply
    text_fn = hk.without_apply_rng(hk.transform(text_jax)).apply

    return image_fn, text_fn, jax_params, _transform(clip_params["image_resolution"])
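
A hypothetical call of the returned functions (the model name and input shapes are assumptions based on the init call above, not taken from the source):

image_fn, text_fn, jax_params, preprocess = load("ViT-B/32")
image_embed = image_fn(jax_params, jnp.zeros((1, 3, 224, 224)))
text_embed = text_fn(jax_params, jnp.zeros((1, 77), dtype=jnp.int16))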
Code example #7
File: main.py Project: ChrisWaites/data-deletion
def main(_):
    FLAGS.alsologtostderr = True

    # Make training dataset.
    train_data = iter(
        dataset.load(tfds.Split.TRAIN,
                     batch_size=FLAGS.train_batch_size,
                     sequence_length=FLAGS.sequence_length))

    # Make evaluation dataset(s).
    eval_data = {  # pylint: disable=g-complex-comprehension
        split: iter(
            dataset.load(split,
                         batch_size=FLAGS.eval_batch_size,
                         sequence_length=FLAGS.sequence_length))
        for split in [tfds.Split.TRAIN, tfds.Split.TEST]
    }

    # Make loss, sampler, and optimizer.
    params_init, loss_fn = hk.without_apply_rng(hk.transform(sequence_loss))
    _, sample_fn = hk.without_apply_rng(hk.transform(sample))
    opt_init, _ = make_optimizer()

    loss_fn = jax.jit(loss_fn)
    sample_fn = jax.jit(sample_fn, static_argnums=[3])

    # Initialize training state.
    rng = hk.PRNGSequence(FLAGS.seed)
    initial_params = params_init(next(rng), next(train_data))
    initial_opt_state = opt_init(initial_params)
    state = TrainingState(params=initial_params, opt_state=initial_opt_state)

    # Training loop.
    for step in tqdm(range(FLAGS.training_steps + 1)):
        # Do a batch of SGD.
        train_batch = next(train_data)
        state = update(state, train_batch)

        # Periodically generate samples.
        if step % FLAGS.sampling_interval == 0:
            # First element of training batch.
            context = train_batch['input'][:, 0]
            assert context.ndim == 1
            rng_key = next(rng)
            samples = sample_fn(state.params, rng_key, context,
                                FLAGS.sample_length)

            prompt = dataset.decode(context)
            continuation = dataset.decode(samples)

            #logging.info('Prompt: %s', prompt)
            #logging.info('Continuation: %s', continuation)

        # Periodically evaluate training and test loss.
        if step % FLAGS.evaluation_interval == 0:
            for split, ds in eval_data.items():
                eval_batch = next(ds)
                loss = loss_fn(state.params, eval_batch)
Code example #8
    def __init__(self, random_seed, num_classes, batch_size, max_steps,
                 enable_double_transpose, checkpoint_to_evaluate,
                 allow_train_from_scratch, freeze_backbone, network_config,
                 optimizer_config, lr_schedule_config, evaluation_config,
                 checkpointing_config):
        """Constructs the experiment.

    Args:
      random_seed: the random seed to use when initializing network weights.
      num_classes: the number of classes; used for the online evaluation.
      batch_size: the total batch size; should be a multiple of the number of
        available accelerators.
      max_steps: the number of training steps; used for the lr/target network
        ema schedules.
      enable_double_transpose: see dataset.py; only has effect on TPU.
      checkpoint_to_evaluate: the path to the checkpoint to evaluate.
      allow_train_from_scratch: whether to allow training without specifying a
        checkpoint to evaluate (training from scratch).
      freeze_backbone: whether the backbone resnet should remain frozen (linear
        evaluation) or be trainable (fine-tuning).
      network_config: the configuration for the network.
      optimizer_config: the configuration for the optimizer.
      lr_schedule_config: the configuration for the learning rate schedule.
      evaluation_config: the evaluation configuration.
      checkpointing_config: the configuration for checkpointing.
    """

        self._random_seed = random_seed
        self._enable_double_transpose = enable_double_transpose
        self._num_classes = num_classes
        self._lr_schedule_config = lr_schedule_config
        self._batch_size = batch_size
        self._max_steps = max_steps
        self._checkpoint_to_evaluate = checkpoint_to_evaluate
        self._allow_train_from_scratch = allow_train_from_scratch
        self._freeze_backbone = freeze_backbone
        self._optimizer_config = optimizer_config
        self._evaluation_config = evaluation_config

        # Checkpointed experiment state.
        self._experiment_state = None

        # Input pipelines.
        self._train_input = None
        self._eval_input = None

        backbone_fn = functools.partial(self._backbone_fn, **network_config)
        self.forward_backbone = hk.without_apply_rng(
            hk.transform_with_state(backbone_fn))
        self.forward_classif = hk.without_apply_rng(
            hk.transform(self._classif_fn))
        self.update_pmap = jax.pmap(self._update_func, axis_name='i')
        self.eval_batch_jit = jax.jit(self._eval_batch)

        self._is_backbone_training = not self._freeze_backbone

        self._checkpointer = checkpointing.Checkpointer(**checkpointing_config)
Code example #9
def make_networks(
    spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Tuple[int, ...] = (256, 256)
) -> SACNetworks:
    """Creates networks used by the agent."""

    num_dimensions = np.prod(spec.actions.shape, dtype=int)

    def _actor_fn(obs):
        network = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes),
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu,
                        activate_final=True),
            networks_lib.NormalTanhDistribution(num_dimensions),
        ])
        return network(obs)

    def _critic_fn(obs, action):
        network1 = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes) + [1],
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu),
        ])
        network2 = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes) + [1],
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu),
        ])
        input_ = jnp.concatenate([obs, action], axis=-1)
        value1 = network1(input_)
        value2 = network2(input_)
        return jnp.concatenate([value1, value2], axis=-1)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    critic = hk.without_apply_rng(hk.transform(_critic_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    return SACNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda key: policy.init(key, dummy_obs), policy.apply),
        q_network=networks_lib.FeedForwardNetwork(
            lambda key: critic.init(key, dummy_obs, dummy_action),
            critic.apply),
        log_prob=lambda params, actions: params.log_prob(actions),
        sample=lambda params, key: params.sample(seed=key),
        sample_eval=lambda params, key: params.mode())
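
without_apply_rng is safe here because applying the policy is deterministic: it returns a distribution object, and randomness enters only through the explicit key handed to sample. A hypothetical usage sketch (spec and the observation are assumed):

networks = make_networks(spec)
key = jax.random.PRNGKey(0)
params = networks.policy_network.init(key)
obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
dist = networks.policy_network.apply(params, obs)
action = networks.sample(dist, key)  # same as dist.sample(seed=key)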
Code example #10
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        forward_fn: networks.PolicyValueRNN,
        unroll_fn: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], hk.LSTMState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        forward_fn_transformed = hk.without_apply_rng(hk.transform(forward_fn))
        unroll_fn_transformed = hk.without_apply_rng(hk.transform(unroll_fn))
        initial_state_fn_transformed = hk.without_apply_rng(
            hk.transform(initial_state_fn))

        config = IMPALAConfig(
            sequence_length=sequence_length,
            sequence_period=sequence_period,
            discount=discount,
            max_queue_size=max_queue_size,
            batch_size=batch_size,
            learning_rate=learning_rate,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            seed=seed,
            max_abs_reward=max_abs_reward,
            max_gradient_norm=max_gradient_norm,
        )
        super().__init__(
            environment_spec=environment_spec,
            forward_fn=forward_fn_transformed.apply,
            unroll_init_fn=unroll_fn_transformed.init,
            unroll_fn=unroll_fn_transformed.apply,
            initial_state_init_fn=initial_state_fn_transformed.init,
            initial_state_fn=initial_state_fn_transformed.apply,
            config=config,
            counter=counter,
            logger=logger,
        )
Code example #11
    def setUp(self):
        super(GatedLinearNetworkTest, self).setUp()
        self._name = "test_network"
        self._rng = hk.PRNGSequence(jax.random.PRNGKey(42))

        self._output_sizes = (4, 5, 6)
        self._context_dim = 2
        self._bias_len = 3

        def gln_factory():
            return gaussian.GatedLinearNetwork(
                output_sizes=self._output_sizes,
                context_dim=self._context_dim,
                bias_len=self._bias_len,
                name=self._name,
            )

        def inference_fn(inputs, side_info):
            return gln_factory().inference(inputs, side_info, 0.5)

        def batch_inference_fn(inputs, side_info):
            return jax.vmap(inference_fn, in_axes=(0, 0))(inputs, side_info)

        def update_fn(inputs, side_info, label, learning_rate):
            params, predictions, unused_loss = gln_factory().update(
                inputs, side_info, label, learning_rate, 0.5)
            return predictions, params

        def batch_update_fn(inputs, side_info, label, learning_rate):
            batched_update = jax.vmap(update_fn, in_axes=(0, 0, 0, None))
            predictions, params = batched_update(
                inputs, side_info, label, learning_rate)
            avg_params = tree.map_structure(lambda x: jnp.mean(x, axis=0),
                                            params)
            return predictions, avg_params

        # Haiku transform functions.
        self._init_fn, inference_fn_ = hk.without_apply_rng(
            hk.transform_with_state(inference_fn))
        self._batch_init_fn, batch_inference_fn_ = hk.without_apply_rng(
            hk.transform_with_state(batch_inference_fn))
        _, update_fn_ = hk.without_apply_rng(
            hk.transform_with_state(update_fn))
        _, batch_update_fn_ = hk.without_apply_rng(
            hk.transform_with_state(batch_update_fn))

        self._inference_fn = jax.jit(inference_fn_)
        self._batch_inference_fn = jax.jit(batch_inference_fn_)
        self._update_fn = jax.jit(update_fn_)
        self._batch_update_fn = jax.jit(batch_update_fn_)
Code example #12
File: agent_test.py Project: zerocurve/acme
    def test_impala(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def forward_fn(x, s):
            model = MyNetwork(spec.actions.num_values)
            return model(x, s)

        def initial_state_fn(batch_size: Optional[int] = None):
            model = MyNetwork(spec.actions.num_values)
            return model.initial_state(batch_size)

        def unroll_fn(inputs, state):
            model = MyNetwork(spec.actions.num_values)
            return hk.static_unroll(model, inputs, state)

        # We pass pure, Haiku-agnostic functions to the agent.
        forward_fn_transformed = hk.without_apply_rng(hk.transform(forward_fn))
        unroll_fn_transformed = hk.without_apply_rng(hk.transform(unroll_fn))
        initial_state_fn_transformed = hk.without_apply_rng(
            hk.transform(initial_state_fn))

        # Construct the agent.
        config = impala_agent.IMPALAConfig(
            sequence_length=3,
            sequence_period=3,
            batch_size=6,
        )
        agent = impala.IMPALAFromConfig(
            environment_spec=spec,
            forward_fn=forward_fn_transformed.apply,
            initial_state_init_fn=initial_state_fn_transformed.init,
            initial_state_fn=initial_state_fn_transformed.apply,
            unroll_init_fn=unroll_fn_transformed.init,
            unroll_fn=unroll_fn_transformed.apply,
            config=config,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Code example #13
File: networks.py Project: kokizzu/google-research
def make_networks(
    spec,
    build_actor_fn=build_standard_actor_fn,
    img_encoder_fn=None,
):
    """Creates networks used by the agent."""
    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    if isinstance(spec.actions, specs.DiscreteArray):
        num_dimensions = spec.actions.num_values
        # _actor_fn = procgen_networks.build_procgen_actor_fn(num_dimensions)
    else:
        num_dimensions = np.prod(spec.actions.shape, dtype=int)

    _actor_fn = build_actor_fn(num_dimensions)

    if img_encoder_fn is not None:
        img_encoder = hk.without_apply_rng(hk.transform(img_encoder_fn))
        key = jax.random.PRNGKey(seed=42)
        temp_encoder_params = img_encoder.init(key, dummy_obs['state_image'])
        dummy_hidden = img_encoder.apply(temp_encoder_params,
                                         dummy_obs['state_image'])
        img_encoder_network = networks_lib.FeedForwardNetwork(
            lambda key: img_encoder.init(key, dummy_hidden), img_encoder.apply)
        dummy_policy_input = dict(
            state_image=dummy_hidden,
            state_dense=dummy_obs['state_dense'],
        )
    else:
        img_encoder_fn = None
        dummy_policy_input = dummy_obs
        img_encoder_network = None

    policy = hk.without_apply_rng(hk.transform(_actor_fn))

    return BCNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda key: policy.init(key, dummy_policy_input), policy.apply),
        log_prob=lambda params, actions: params.log_prob(actions),
        sample=lambda params, key: params.sample(seed=key),
        sample_eval=lambda params, key: params.mode(),
        img_encoder=img_encoder_network,
    )
Code example #14
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> D4PGNetworks:
    """Creates networks used by the agent."""

    action_spec = spec.actions

    num_dimensions = np.prod(action_spec.shape, dtype=int)
    critic_atoms = jnp.linspace(vmin, vmax, num_atoms)

    def _actor_fn(obs):
        network = hk.Sequential([
            utils.batch_concat,
            networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True),
            networks_lib.NearZeroInitializedLinear(num_dimensions),
            networks_lib.TanhToSpec(action_spec),
        ])
        return network(obs)

    def _critic_fn(obs, action):
        network = hk.Sequential([
            utils.batch_concat,
            networks_lib.LayerNormMLP(
                layer_sizes=[*critic_layer_sizes, num_atoms]),
        ])
        value = network([obs, action])
        return value, critic_atoms

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    critic = hk.without_apply_rng(hk.transform(_critic_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    return D4PGNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda rng: policy.init(rng, dummy_obs), policy.apply),
        critic_network=networks_lib.FeedForwardNetwork(
            lambda rng: critic.init(rng, dummy_obs, dummy_action),
            critic.apply))
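
The critic here is distributional: it returns num_atoms logits over the fixed support critic_atoms = linspace(vmin, vmax, num_atoms). The scalar Q-value is the expectation under the softmax of those logits, roughly (a sketch, not code from the project):

logits, atoms = networks.critic_network.apply(critic_params, obs, action)
q_value = jnp.sum(jax.nn.softmax(logits, axis=-1) * atoms, axis=-1)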
Code example #15
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Tuple[int, ...] = (256, 256),
    critic_layer_sizes: Tuple[int, ...] = (256, 256),
    activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
) -> CRRNetworks:
  """Creates networks used by the agent."""
  num_actions = np.prod(spec.actions.shape, dtype=int)

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

  def _policy_fn(obs: jnp.ndarray) -> jnp.ndarray:
    network = hk.Sequential([
        hk.nets.MLP(
            list(policy_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_actions),
    ])
    return network(obs)

  policy = hk.without_apply_rng(hk.transform(_policy_fn))
  policy_network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)

  def _critic_fn(obs, action):
    network = hk.Sequential([
        hk.nets.MLP(
            list(critic_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation),
    ])
    data = jnp.concatenate([obs, action], axis=-1)
    return network(data)

  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  critic_network = networks_lib.FeedForwardNetwork(
      lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply)

  return CRRNetworks(
      policy_network=policy_network,
      critic_network=critic_network,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())
Code example #16
    def test_dqn(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def network(x):
            model = hk.Sequential(
                [hk.Flatten(),
                 hk.nets.MLP([50, 50, spec.actions.num_values])])
            return model(x)

        # Make network purely functional
        network_hk = hk.without_apply_rng(hk.transform(network))
        dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

        network = networks_lib.FeedForwardNetwork(
            init=lambda rng: network_hk.init(rng, dummy_obs),
            apply=network_hk.apply)

        # Construct the agent.
        agent = dqn.DQN(environment_spec=spec,
                        network=network,
                        batch_size=10,
                        samples_per_insert=2,
                        min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Code example #17
    def __init__(
            self,
            num_critics=2,
            fn_error=None,
            lr_error=3e-4,
            units_error=(256, 256, 256),
            d2rl=False,
            init_error=10.0,
    ):
        if fn_error is None:

            def fn_error(s, a):
                return ContinuousQFunction(
                    num_critics=num_critics,
                    hidden_units=units_error,
                    d2rl=d2rl,
                )(s, a)

        # Error model.
        self.error = hk.without_apply_rng(hk.transform(fn_error))
        self.params_error = self.params_error_target = self.error.init(
            next(self.rng), *self.fake_args_critic)
        opt_init, self.opt_error = optix.adam(lr_error)
        self.opt_state_error = opt_init(self.params_error)
        # Running mean of error.
        self.rm_error_list = [
            jnp.array(init_error, dtype=jnp.float32)
            for _ in range(num_critics)
        ]
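
optix was the early name of what is now optax; under that assumption, the same optimizer setup in current code would read roughly as below (example #28 further down uses the same old API):

import optax

opt = optax.adam(lr_error)
opt_state_error = opt.init(params_error)
# later, in the update step:
# updates, opt_state_error = opt.update(grads, opt_state_error)
# params_error = optax.apply_updates(params_error, updates)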
Code example #18
        def eval_apply_fn(params, x, y, mask):
            embed_apply_fn, transformer_apply_fn = apply_fns()

            if early_collect:
                bf16_params = maybe_shard(to_bf16(params), mp_shard_strategy)
            else:
                bf16_params = to_bf16(params)

            def eval_loss(x, y):
                loss, correct = Projection(config).loss(x, y)
                return {
                    "loss": loss.mean(axis=-1),
                    "last_loss": loss[:, -1],
                    "all_loss": loss,
                    "correct": correct
                }

            projection_apply_fn = hk.without_apply_rng(
                hk.transform(eval_loss)).apply

            x = embed_apply_fn(bf16_params["embed"], x)

            def apply_scan_fn(layer_in, layer_state):
                x, mask = layer_in
                x = to_bf16(transformer_apply_fn(layer_state, x, mask))
                return (x, mask), None

            x = jax.lax.scan(apply_scan_fn, (to_bf16(x), mask),
                             xs=bf16_params["transformer"])[0][0]

            return projection_apply_fn(bf16_params["proj"], x, y)
Code example #19
File: networks.py Project: vishalbelsare/acme
def make_discrete_networks(
    environment_spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Sequence[int] = (512, ),
    use_conv: bool = True,
) -> PPONetworks:
    """Creates networks used by the agent for discrete action environments.

  Args:
    environment_spec: Environment spec used to define number of actions.
    hidden_layer_sizes: Network definition.
    use_conv: Whether to use a conv or MLP feature extractor.
  Returns:
    PPONetworks
  """

    num_actions = environment_spec.actions.num_values

    def forward_fn(inputs):
        layers = []
        if use_conv:
            layers.extend([networks_lib.AtariTorso()])
        layers.extend([
            hk.nets.MLP(hidden_layer_sizes, activation=jax.nn.relu),
            networks_lib.CategoricalValueHead(num_values=num_actions)
        ])
        policy_value_network = hk.Sequential(layers)
        return policy_value_network(inputs)

    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))
    dummy_obs = utils.zeros_like(environment_spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    network = networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
    # Create PPONetworks to add functionality required by the agent.
    return make_ppo_networks(network)
Code example #20
def make_networks(
        spec: specs.EnvironmentSpec,
        discrete_actions: bool = False) -> networks_lib.FeedForwardNetwork:
    """Creates networks used by the agent."""

    if discrete_actions:
        final_layer_size = spec.actions.num_values
    else:
        final_layer_size = np.prod(spec.actions.shape, dtype=int)

    def _actor_fn(obs, is_training=False, key=None):
        # is_training and key allow defining train/test-dependent modules
        # like dropout.
        del is_training
        del key
        if discrete_actions:
            network = hk.nets.MLP([64, 64, final_layer_size])
        else:
            network = hk.Sequential([
                networks_lib.LayerNormMLP([64, 64], activate_final=True),
                networks_lib.NormalTanhDistribution(final_layer_size),
            ])
        return network(obs)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)
    network = networks_lib.FeedForwardNetwork(
        lambda key: policy.init(key, dummy_obs), policy.apply)
    return network
Code example #21
    def get_particular_critic_init(w_init, b_init, key, obs, act):
        def _critic_with_particular_init(obs, action):
            raise NotImplementedError(
                'Not implemented for MIMO, Not implemented for new version that also returns h1, h2'
            )
            network1 = hk.Sequential([
                hk.nets.MLP(list(critic_hidden_layer_sizes) + [1],
                            w_init=w_init,
                            b_init=b_init,
                            activation=jax.nn.relu,
                            activate_final=False),
            ])
            input_ = jnp.concatenate([obs, action], axis=-1)
            value1 = network1(input_)
            if use_double_q:
                network2 = hk.Sequential([
                    hk.nets.MLP(list(critic_hidden_layer_sizes) + [1],
                                w_init=w_init,
                                b_init=b_init,
                                activation=jax.nn.relu,
                                activate_final=False),
                ])
                value2 = network2(input_)
                return jnp.concatenate([value1, value2], axis=-1)
            else:
                return value1

        init_fn = hk.without_apply_rng(
            hk.transform(_critic_with_particular_init)).init
        return init_fn(key, obs, act)
Code example #22
    def test_conv1d_cvxpy_relaxation(self):
        def conv1d_model(inp):
            return hk.Conv1D(output_channels=1,
                             kernel_shape=2,
                             padding='VALID',
                             stride=1,
                             with_bias=True)(inp)

        z = jnp.array([3., 4.])
        z = jnp.reshape(z, [1, 2, 1])

        params = {
            'conv1_d': {
                'w': jnp.ones((2, 1, 1), dtype=jnp.float32),
                'b': jnp.array([2.])
            }
        }

        fun = functools.partial(
            hk.without_apply_rng(hk.transform(conv1d_model)).apply, params)
        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)

        lower_bounds, upper_bounds = self.get_bounds(fun, input_bounds)

        self.assertAlmostEqual(7., lower_bounds, delta=1e-5)
        self.assertAlmostEqual(11., upper_bounds, delta=1e-5)
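
A quick check of the asserted bounds (my arithmetic, not from the source): with the all-ones 2-tap kernel and bias 2, the single VALID output is (3 + 4) + 2 = 9; each of the two inputs may move by ±1, shifting the output by at most ±2, hence the interval [7, 11].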
Code example #23
File: ibp_test.py Project: deepmind/jax_verify
    def test_conv2d_ibp(self):
        def conv2d_model(inp):
            return hk.Conv2D(output_channels=1,
                             kernel_shape=(2, 2),
                             padding='VALID',
                             stride=1,
                             with_bias=True)(inp)

        z = jnp.array([1., 2., 3., 4.])
        z = jnp.reshape(z, [1, 2, 2, 1])

        params = {
            'conv2_d': {
                'w': jnp.ones((2, 2, 1, 1), dtype=jnp.float32),
                'b': jnp.array([2.])
            }
        }

        fun = functools.partial(
            hk.without_apply_rng(hk.transform(conv2d_model)).apply, params)
        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
        output_bounds = jax_verify.interval_bound_propagation(
            fun, input_bounds)

        self.assertAlmostEqual(8., output_bounds.lower)
        self.assertAlmostEqual(16., output_bounds.upper)
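
Sanity check (my arithmetic, not from the source): the all-ones 2x2 kernel with bias 2 gives (1 + 2 + 3 + 4) + 2 = 12 at the nominal input; four inputs each perturbed by ±1 shift the output by at most ±4, so interval bound propagation yields exactly [8, 16].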
Code example #24
def optimize_club(num_steps: int):
    """Solves the karte club problem by optimizing the assignments of students."""
    network = hk.without_apply_rng(hk.transform(network_definition))
    zacharys_karate_club = get_zacharys_karate_club()
    labels = get_ground_truth_assignments_for_zacharys_karate_club()
    params = network.init(jax.random.PRNGKey(42), zacharys_karate_club)

    @jax.jit
    def prediction_loss(params):
        decoded_nodes = network.apply(params, zacharys_karate_club)
        # We interpret the decoded nodes as a pair of logits for each node.
        log_prob = jax.nn.log_softmax(decoded_nodes)
        # The only two assignments we know a-priori are those of Mr. Hi (Node 0)
        # and John A (Node 33).
        return -(log_prob[0, 0] + log_prob[33, 1])

    opt_init, opt_update = optax.adam(1e-2)
    opt_state = opt_init(params)

    @jax.jit
    def update(params, opt_state):
        g = jax.grad(prediction_loss)(params)
        updates, opt_state = opt_update(g, opt_state)
        return optax.apply_updates(params, updates), opt_state

    @jax.jit
    def accuracy(params):
        decoded_nodes = network.apply(params, zacharys_karate_club)
        return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == labels)

    for step in range(num_steps):
        logging.info("step %r accuracy %r", step, accuracy(params).item())
        params, opt_state = update(params, opt_state)
Code example #25
def main(_):

    network = hk.without_apply_rng(hk.transform(network_definition))
    input_graph = get_random_graph()
    params = network.init(jax.random.PRNGKey(42), input_graph)
    output_graph = network.apply(params, input_graph)
    print(tree.tree_map(lambda x: x.shape, output_graph))
Code example #26
  def test_graph_embedding_model_runs(self):
    graph = jraph.GraphsTuple(
        nodes=np.array([[0, 1, 1],
                        [1, 2, 0],
                        [0, 3, 0],
                        [0, 4, 4]], dtype=np.float32),
        edges=np.array([[1, 1],
                        [2, 2],
                        [3, 3]], dtype=np.float32),
        senders=np.array([0, 1, 2], dtype=np.int32),
        receivers=np.array([1, 2, 3], dtype=np.int32),
        n_node=np.array([4], dtype=np.int32),
        n_edge=np.array([3], dtype=np.int32),
        globals=None)
    embed_dim = 3

    def forward(graph):
      return embedding.GraphEmbeddingModel(embed_dim=3, num_layers=2)(graph)

    init_fn, apply_fn = hk.without_apply_rng(hk.transform(forward))
    key = hk.PRNGSequence(8)
    params = init_fn(next(key), graph)
    out = apply_fn(params, graph)

    self.assertEqual(out.nodes.shape, (graph.nodes.shape[0], embed_dim))
    self.assertEqual(out.edges.shape, (graph.edges.shape[0], embed_dim))
    np.testing.assert_array_equal(out.senders, graph.senders)
    np.testing.assert_array_equal(out.receivers, graph.receivers)
    np.testing.assert_array_equal(out.n_node, graph.n_node)
Code example #27
def make_haiku_networks(
        spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
    """Creates Haiku networks to be used by the agent."""

    num_actions = spec.actions.num_values

    def forward_fn(inputs):
        policy_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP([64, 64]),
            networks_lib.CategoricalHead(num_actions)
        ])
        value_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP([64, 64]),
            hk.Linear(1), lambda x: jnp.squeeze(x, axis=-1)
        ])

        action_distribution = policy_network(inputs)
        value = value_network(inputs)
        return (action_distribution, value)

    # Transform into pure functions.
    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    return networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
Code example #28
File: train.py Project: tirkarthi/dm-haiku
def update(state: TrainingState, batch: dataset.Batch) -> TrainingState:
    """Does a step of SGD given inputs & targets."""
    _, optimizer = optix.adam(FLAGS.learning_rate)
    _, loss_fn = hk.without_apply_rng(hk.transform(sequence_loss))
    gradients = jax.grad(loss_fn)(state.params, batch)
    updates, new_opt_state = optimizer(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return TrainingState(params=new_params, opt_state=new_opt_state)
Code example #29
def main(config: Config, ckpt_path: str, structures_dir: str, output: str):
    # Prepare test dataset
    atom_featurizer = AtomFeaturizer(
        atom_features_json=config.atom_init_features_path)
    bond_featurizer = BondFeaturizer(dmin=config.dmin,
                                     dmax=config.cutoff,
                                     num_filters=config.num_bond_features)
    dataset, list_ids = create_dataset(
        atom_featurizer=atom_featurizer,
        bond_featurizer=bond_featurizer,
        structures_dir=structures_dir,
        targets_csv_path="",
        max_num_neighbors=config.max_num_neighbors,
        cutoff=config.cutoff,
        is_training=False,
        seed=config.seed,
        n_jobs=config.n_jobs,
    )

    # Define model
    model_fn_t = get_model_fn_t(
        num_initial_atom_features=atom_featurizer.num_initial_atom_features,
        num_atom_features=config.num_atom_features,
        num_bond_features=config.num_bond_features,
        num_convs=config.num_convs,
        num_hidden_layers=config.num_hidden_layers,
        num_hidden_features=config.num_hidden_features,
        max_num_neighbors=config.max_num_neighbors,
        batch_size=config.batch_size,
    )
    model = hk.without_apply_rng(model_fn_t)

    # Load checkpoint
    params, state, normalizer = restore_checkpoint(ckpt_path)

    @jax.jit
    def predict_one_step(batch: Batch) -> jnp.ndarray:
        predictions, _ = model.apply(params, state, batch, is_training=False)
        return predictions

    # Prediction
    batch_size = config.batch_size
    steps_per_epoch = (len(dataset) + batch_size - 1) // batch_size
    predictions = []
    for i in range(steps_per_epoch):
        batch = collate_pool(
            dataset[i * batch_size:min(len(dataset), (i + 1) * batch_size)],
            False)  # train=False
        preds = predict_one_step(batch)
        predictions.append(preds)
    predictions = jnp.concatenate(predictions)  # (len(dataset), 1)

    # denormalize predictions
    denormed_preds = normalizer.denormalize(predictions)

    with open(output, "w") as f:
        for i, idx in enumerate(list_ids):
            f.write(f"{idx},{denormed_preds[i, 0]}\n")
Code example #30
def build_network(num_hidden_units: int, num_actions: int) -> hk.Transformed:
    """Factory for a simple MLP network for approximating Q-values."""
    def q(obs):
        flatten = lambda x: jnp.reshape(x, (-1, ))
        network = hk.Sequential(
            [flatten, nets.MLP([num_hidden_units, num_actions])])
        return network(obs)

    return hk.without_apply_rng(hk.transform(q))
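
A hypothetical use of the returned hk.Transformed pair (sizes and observation shape are assumed):

q_network = build_network(num_hidden_units=32, num_actions=4)
rng = jax.random.PRNGKey(0)
obs = jnp.zeros((8,), dtype=jnp.float32)
params = q_network.init(rng, obs)
q_values = q_network.apply(params, obs)  # one Q-value per action; no rng needed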