def unstack(a, axis=0):
    """Split `a` along `axis` into a list of arrays with that axis removed.

    This is the inverse of stacking: the returned list has `a.shape[axis]`
    entries, each one a slice of `a` without the `axis` dimension.
    """
    num_pieces = a.shape[axis]
    pieces = jnp.split(a, num_pieces, axis=axis)
    return [jnp.squeeze(piece, axis=axis) for piece in pieces]
Пример #2
0
    def apply(self, inputs, info, config, train=False, cache=None):
        """Runs the GGNN model on a batch and returns output-token logits.

        Tokens of every statement are embedded and summarized by a stacked
        LSTM, the resulting per-node embeddings are refined by repeatedly
        applying one shared GGNN layer, and the embedding of each example's
        exit node is projected to the output vocabulary.

        Args:
          inputs: Mapping providing 'start_index', 'exit_index', 'steps',
            'edge_types', 'source_indices', 'dest_indices' and 'data'.
          info: Dataset info; supplies the statement vocabulary size, the
            output vocabulary size and `max_diameter`.
          config: Configuration; `config.model.hidden_size`,
            `config.model.rnn_cell.layers` and
            `config.initialization.maxval` are read here.
          train: Not used in this method body.
          cache: Not used in this method body.

        Returns:
          Logits of shape (batch_size, 1, output_token_vocabulary_size).
        """
        start_indexes = inputs['start_index']  # pylint: disable=unused-variable
        exit_indexes = inputs['exit_index']
        steps_all = jnp.squeeze(inputs['steps'], axis=-1)
        # steps_all.shape: batch_size
        edge_types = inputs['edge_types']
        source_indices = inputs['source_indices']
        dest_indices = inputs['dest_indices']
        vocab_size = info.features[info._builder.key('statements')].vocab_size  # pylint: disable=protected-access
        output_token_vocabulary_size = info.output_vocab_size
        hidden_size = config.model.hidden_size
        data = inputs['data'].astype('int32')
        unused_batch_size, num_nodes, unused_statement_length = data.shape

        # Number of GGNN iterations that are unrolled below; individual
        # examples stop updating earlier according to `steps_all`.
        max_steps = int(1.5 * info.max_diameter)

        # Init parameters
        def emb_init(key, shape, dtype=jnp.float32):
            # Uniform initialization in [-maxval, maxval].
            return jax.random.uniform(key, shape, dtype,
                                      -config.initialization.maxval,
                                      config.initialization.maxval)

        embed = Embed.shared(num_embeddings=vocab_size,
                             features=hidden_size,
                             emb_init=emb_init,
                             name='embed')

        cells = create_lstm_cells(config.model.rnn_cell.layers)
        lstm = StackedRNNCell.shared(cells=cells)
        # The initial carry is built with a fixed PRNG key.
        initial_state = lstm.initialize_carry(jax.random.PRNGKey(0), cells, (),
                                              hidden_size)

        def embed_statement(token_embeddings):
            # token_embeddings.shape: 4, hidden_size
            _, results = lax.scan(lstm, initial_state, token_embeddings)
            # Keep only the output produced at the last scanned token.
            return results[-1]

        # vmap once over the nodes of an example, then once over the batch.
        embed_all_statements_single_example = jax.vmap(embed_statement)
        embed_all_statements = jax.vmap(embed_all_statements_single_example)

        output_dense = nn.Dense.shared(
            name='output_dense',
            features=output_token_vocabulary_size,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6))

        node_embeddings = embed(data)
        # node_embeddings.shape:
        #     batch_size, num_nodes, statement_length, hidden_size
        statement_embeddings = embed_all_statements(node_embeddings)
        # statement_embeddings.shape: batch_size, num_nodes, hidden_size

        gnn_layer_single_example = GGNNLayer.shared(num_nodes=num_nodes,
                                                    hidden_size=hidden_size,
                                                    config=config)
        gnn_layer = jax.vmap(gnn_layer_single_example)

        # statement_embeddings.shape: batch_size, num_nodes, hidden_size
        for step in range(max_steps):
            new_statement_embeddings = gnn_layer(statement_embeddings,
                                                 source_indices, dest_indices,
                                                 edge_types)
            # steps_all.shape: batch_size
            valid = jnp.expand_dims(step < steps_all, axis=(1, 2))
            # valid.shape: batch_size, 1, 1
            # Examples whose own step count has been reached keep their
            # previous embeddings; the others take the updated ones.
            statement_embeddings = jnp.where(valid, new_statement_embeddings,
                                             statement_embeddings)

        def get_final_state(statement_embeddings, exit_index):
            # Select the embedding of the exit node for a single example.
            return statement_embeddings[exit_index]

        final_states = jax.vmap(get_final_state)(statement_embeddings,
                                                 exit_indexes)
        # final_states.shape: batch_size, hidden_size
        logits = output_dense(final_states)
        # logits.shape: batch_size, output_token_vocabulary_size
        logits = jnp.expand_dims(logits, axis=1)
        # logits.shape: batch_size, 1, output_token_vocabulary_size
        return logits
 def loss_fun(x, step):
     """Return the mean logistic data loss plus a norm penalty on `x`.

     Args:
         x: parameters passed to `predict_fun` and to `utils.norm`.
         step: ignored; accepted only to satisfy the caller's signature.

     Returns:
         `mean(log(1 + exp(logits)) - targets * logits)` plus
         `l2_pen * utils.norm(x)`.
     """
     del step  # Unused.
     logits = jnp.squeeze(predict_fun(x, features))
     # logaddexp(0, z) equals log(1 + exp(z)) but does not overflow to inf
     # (and produce NaN gradients) when z is large, unlike log1p(exp(z)).
     softplus_logits = jnp.logaddexp(0., logits)
     data_loss = jnp.mean(softplus_logits - targets * logits)
     reg_loss = l2_pen * utils.norm(x)
     return data_loss + reg_loss
Пример #4
0
        def affine_transform(dist_params, scale, shift, value_transform=None):
            """Project an affinely transformed categorical distribution onto the atoms.

            Implements the "Categorical Algorithm" from
            https://arxiv.org/abs/1707.06887.

            Args:
                dist_params: dict whose 'logits' entry is a rank-2 array; after
                    the softmax it is asserted to have shape
                    (batch_size, num_bins).
                scale: scalar or rank-1 array of per-example multipliers.
                shift: scalar or rank-1 array of per-example offsets.
                value_transform: optional pair ``(f, f_inv)``; when omitted the
                    identity is used for both.

            Returns:
                dict with a single 'logits' entry holding the log of the
                projected bin masses (masses are floored at 1e-16 before the
                log is taken).
            """

            # check inputs
            chex.assert_rank([dist_params['logits'], scale, shift],
                             [2, {0, 1}, {0, 1}])
            p = jax.nn.softmax(dist_params['logits'])
            batch_size = p.shape[0]

            # Broadcast scalar scale/shift to one value per batch element.
            if isscalar(scale):
                scale = jnp.full(shape=(batch_size, ),
                                 fill_value=jnp.squeeze(scale))
            if isscalar(shift):
                shift = jnp.full(shape=(batch_size, ),
                                 fill_value=jnp.squeeze(shift))

            chex.assert_shape(p, (batch_size, self.num_bins))
            chex.assert_shape([scale, shift], (batch_size, ))

            if value_transform is None:
                f = f_inv = lambda x: x
            else:
                f, f_inv = value_transform

            # variable names correspond to those defined in: https://arxiv.org/abs/1707.06887
            z = self.__atoms
            Vmin, Vmax, Δz = z[0], z[-1], z[1] - z[0]
            # Tz = f(scale * f_inv(z) + shift), computed per batch element.
            Tz = f(jax.vmap(jnp.add)(jnp.outer(scale, f_inv(z)), shift))
            Tz = jnp.clip(Tz, Vmin, Vmax)  # keep values in valid range
            chex.assert_shape(Tz, (batch_size, self.num_bins))

            b = (Tz - Vmin) / Δz  # float in [0, num_bins - 1]
            l = jnp.floor(b).astype(
                'int32')  # noqa: E741   # int in {0, 1, ..., num_bins - 1}
            u = jnp.ceil(b).astype('int32')  # int in {0, 1, ..., num_bins - 1}
            chex.assert_shape([p, b, l, u], (batch_size, self.num_bins))

            # Scatter each bin's mass onto its lower and upper neighbours.
            # NOTE(review): `jax.ops.index_add` was removed from newer JAX
            # releases in favour of `m.at[idx].add(...)` -- confirm the JAX
            # version this project pins before upgrading.
            m = jnp.zeros_like(p)
            i = jnp.expand_dims(jnp.arange(batch_size), axis=1)  # batch index
            m = jax.ops.index_add(m, (i, l),
                                  p * (u - b),
                                  indices_are_sorted=True)
            m = jax.ops.index_add(m, (i, u),
                                  p * (b - l),
                                  indices_are_sorted=True)
            # When b is an integer (l == u) both weights above are zero, so the
            # full mass is added back here.
            m = jax.ops.index_add(m, (i, l),
                                  p * (l == u),
                                  indices_are_sorted=True)
            # chex.assert_tree_all_close(jnp.sum(m, axis=1), jnp.ones(batch_size), rtol=1e-6)

            # # The above index trickery is equivalent to:
            # m_alt = onp.zeros((batch_size, self.num_bins))
            # for i in range(batch_size):
            #     for j in range(self.num_bins):
            #         if l[i, j] == u[i, j]:
            #             m_alt[i, l[i, j]] += p[i, j]  # don't split if b[i, j] is an integer
            #         else:
            #             m_alt[i, l[i, j]] += p[i, j] * (u[i, j] - b[i, j])
            #             m_alt[i, u[i, j]] += p[i, j] * (b[i, j] - l[i, j])
            # chex.assert_tree_all_close(m, m_alt, rtol=1e-6)
            return {'logits': jnp.log(jnp.maximum(m, 1e-16))}
Пример #5
0
 def f_jax(x):
     """Remove the size-one axis at position 1 of `x`."""
     squeezed = jnp.squeeze(x, axis=1)
     return squeezed
Пример #6
0
def _inputs_to_kernel(x1, x2, use_pooling, compute_ntk):
    """Builds the initial `Kernel` from (batches of) raw inputs.

    Private helper; the docstring and example are for internal reference.

    The kernel holds the empirical covariances between inputs (and, when
    needed, between their pixels) from which the covariance of the Gaussian
    Process of an infinite network is computed. Only the entries that are
    required are tracked: covariances along the channel / feature dimension
    are not stored.

    Args:
      x1: a 2D array of shape `[batch_size_1, n_features]` or a 4D array of
        shape `[batch_size_1, height, width, channels]`.
      x2: an optional array with the same shape as `x1` except possibly for
        the leading batch size. `None` means `x2 == x1`.
      use_pooling: whether pooling is used somewhere in the model. When true
        and the inputs are 4D, pixel-pixel covariances are tracked as well.
        Has no effect for 2D inputs.
      compute_ntk: if true the returned kernel's `ntk` is initialised to
        `0.`, otherwise it is `None`.

    Example:
      ```python
          >>> x = np.ones((10, 32, 16, 3))
          >>> _inputs_to_kernel(x, None, use_pooling=True,
          >>>                   compute_ntk=True).nngp.shape
          (10, 10, 32, 32, 16, 16)
          >>> _inputs_to_kernel(x, None, use_pooling=False,
          >>>                   compute_ntk=False).ntk
          None
      ```

    Returns:
      a `Kernel` object.

    Raises:
      ValueError: if `x1` and `x2` disagree on their non-batch shape, or if
        the inputs are neither 2D nor 4D.
    """

    def as_canonical_float(x):
        return x.astype(xla_bridge.canonicalize_dtype(np.float64))

    x1 = as_canonical_float(x1)
    var1 = _get_variance(x1)

    var2 = None
    if x2 is None:
        x2 = x1
    elif x1.shape[1:] == x2.shape[1:]:
        x2 = as_canonical_float(x2)
        var2 = _get_variance(x2)
    else:
        raise ValueError(
            '`x1` and `x2` are expected to be batches of'
            ' inputs with the same shape (apart from the batch size),'
            ' got %s and %s.' % (str(x1.shape), str(x2.shape)))

    if x1.ndim not in (2, 4):
        raise ValueError('Inputs must be 2D or 4D `np.ndarray`s of shape '
                         '`[batch_size, n_features]` or '
                         '`[batch_size, height, width, channels]`, '
                         'got %s.' % str(x1.shape))

    if use_pooling and x1.ndim == 4:
        # Contract the channel axis of `x1` against that of `x2`, then
        # interleave the axes as (batch1, batch2, h1, h2, w1, w2).
        x2 = np.expand_dims(x2, -1)
        channel_mean = np.dot(x1, x2) / x1.shape[-1]
        nngp = np.transpose(np.squeeze(channel_mean, -1),
                            (0, 3, 1, 4, 2, 5))
    else:
        nngp = _batch_uncentered_covariance(x1, x2)

    ntk = 0. if compute_ntk else None
    return Kernel(
        var1,
        nngp,
        var2,
        ntk,
        False,  # is_gaussian
        True,  # is_height_width
    )
Пример #7
0
    def test_monte_carlo_generator(self, batch_size, device_count,
                                   store_on_device, get):
        """Compares generator-mode Monte Carlo samples with single-`n` samples.

        The sampler built with a list of sample counts must yield, for every
        count, the same value as a sampler built for that count alone, and the
        last sample is compared against the analytic kernel.
        """
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            8, 1)
        x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

        log_n_max = 4
        n_samples = [2**k for k in range(log_n_max)]
        sample_generator = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, n_samples, batch_size, device_count,
            store_on_device)

        # `get` is forwarded positionally only when it was supplied, so every
        # call below has exactly the arguments it would get from separate
        # `get is None` / `get is not None` branches.
        get_args = () if get is None else (get,)

        samples_12 = sample_generator(x1, x2, *get_args)
        samples_34 = sample_generator(x3, x4, *get_args)

        count = 0
        paired = zip(n_samples, samples_12, samples_34)
        for count, (n, s_12, s_34) in enumerate(paired, start=1):
            sample_fn = monte_carlo.monte_carlo_kernel_fn(
                init_fn, apply_fn, key, n, batch_size, device_count,
                store_on_device)
            sample_12 = sample_fn(x1, x2, *get_args)
            sample_34 = sample_fn(x3, x4, *get_args)
            self.assertAllClose(s_12, sample_12, True)
            self.assertAllClose(s_12, s_34, True)
            self.assertAllClose(s_12, sample_34, True)

        self.assertEqual(log_n_max, count)

        ker_analytic_12 = stax_kernel_fn(x1, x2, *get_args)
        ker_analytic_34 = stax_kernel_fn(x3, x4, *get_args)

        # `s_12` is the sample from the last loop iteration; drop the two
        # trailing singleton axes of its NTK before the analytic comparison.
        if get == 'ntk':
            s_12 = np.squeeze(s_12, (-1, -2))
        elif get is None or 'ntk' in get:
            s_12 = s_12._replace(ntk=np.squeeze(s_12.ntk, (-1, -2)))

        self.assertAllClose(ker_analytic_12, s_12, True, 2., 2.)
        self.assertAllClose(ker_analytic_12, ker_analytic_34, True)
Пример #8
0
def squeeze(x, axis=None):
    """Remove size-one dimensions from `x`, tolerating zero-dimensional input.

    Args:
        x: array to squeeze.
        axis: axis (or axes) to remove; ``None`` removes every size-one axis.

    Returns:
        `x` itself when it is zero-dimensional and `axis` is ``None``, ``0``
        or ``-1``; otherwise the result of ``_jnp.squeeze(x, axis)``.

    Raises:
        ValueError: if `x` is zero-dimensional and `axis` is anything other
            than ``None``, ``0`` or ``-1``.
    """
    if x.shape != ():
        return _jnp.squeeze(x, axis)
    if axis is None or axis == 0 or axis == -1:
        # Nothing to remove from a scalar; treat these axes as a no-op.
        return x
    # ValueError (still an Exception subclass) lets callers catch this case
    # specifically instead of having to catch every Exception.
    raise ValueError('tried to squeeze a zero-dimensional input by axis {}'.format(axis))
Пример #9
0
def mpo_compute_weights_and_temperature_loss(
    sample_q_values: Array,
    temperature_constraint: LagrangePenalty,
    projection_operator: Callable[[Numeric], Numeric],
    sample_axis: int = 0,
) -> Tuple[Array, Array, Scalar]:
  """Computes the MPO E-step weights and the temperature loss.

  The Q-values of the sampled actions are divided by the (projected)
  temperature and turned into softmax weights over the sample axis; the same
  scaled values give the temperature loss.

  Args:
    sample_q_values: Q-values evaluated on the sampled actions, with the
      samples laid out along `sample_axis`.
    temperature_constraint: Lagrange constraint whose `alpha` is the raw
      temperature and whose `epsilon` is the rank-0 bound used in the loss.
    projection_operator: Function applied to `alpha` to obtain the temperature
      (projects it into the positive range).
    sample_axis: Axis of `sample_q_values` holding the sampled actions;
      negative values count from the end.

  Returns:
    The temperature loss, the normalized weights and the number of action
    samples along `sample_axis`.

  Raises:
    ValueError: if `sample_axis` is outside the rank of `sample_q_values`.
  """
  alpha = temperature_constraint.alpha
  epsilon = temperature_constraint.epsilon
  chex.assert_rank(epsilon, 0)
  chex.assert_type([sample_q_values, alpha, epsilon], float)

  # Resolve a negative axis, then range-check it.
  rank = sample_q_values.ndim
  if sample_axis < 0:
    sample_axis += rank
  if sample_axis < 0 or sample_axis >= rank:
    raise ValueError(
        f"`sample_axis` {sample_axis} not in array rank {rank}")

  n_action_samples = sample_q_values.shape[sample_axis]

  # The temperature must be positive, hence the projection.
  temperature = projection_operator(alpha)
  scaled_q = sample_q_values / temperature

  q_logsumexp = jax.scipy.special.logsumexp(
      scaled_q, axis=sample_axis, keepdims=True)

  # Dual objective for the temperature: the epsilon term plus the temperature
  # times the log-mean-exp of the scaled Q-values.
  log_mean_exp = (jnp.squeeze(q_logsumexp, axis=sample_axis)
                  - jnp.log(n_action_samples))
  temperature_loss = temperature * epsilon + (temperature * log_mean_exp)

  # Softmax over the sampled actions, renormalized before the M-step.
  unnormalized = jnp.exp(scaled_q - q_logsumexp)
  normalizer = jnp.sum(unnormalized, axis=sample_axis, keepdims=True)
  norm_weights = unnormalized / normalizer

  return temperature_loss, norm_weights, n_action_samples
Пример #10
0
 def _time_derivative_map(self, input_array, time, sc: OrderedDict):
     """Evaluate `self.predict_time` on one (input, time) point.

     `time` is appended to `input_array`, a leading axis is added, and the
     prediction is returned with all size-one axes squeezed out.
     """
     predict = self.predict_time
     params = sc["params_time"]
     point = np.append(input_array, time)
     batch_of_one = np.expand_dims(point, 0)
     return np.squeeze(predict(params, batch_of_one))
Пример #11
0
    def __call__(self, inputs: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Computes the policy logits and the value estimate for `inputs`.

        Returns:
            A `(logits, value)` pair; the trailing axis of the value layer's
            output is squeezed away.
        """
        policy_logits = self._policy_layer(inputs)  # [B, A]
        raw_value = self._value_layer(inputs)
        return policy_logits, jnp.squeeze(raw_value, axis=-1)  # value: [B]
Пример #12
0
 def _map(self, input_array, time, sc: OrderedDict):
     """Evaluate `self.predict` on a single input.

     A leading axis is added to `input_array` before the call and all
     size-one axes are squeezed from the result. `time` is not used.
     """
     predict = self.predict
     params = sc["params"]
     batch_of_one = np.expand_dims(input_array, 0)
     return np.squeeze(predict(params, batch_of_one))
Пример #13
0
 def testSqueeze(self, arg_shape, dtype, ax, rng):
   """Runs `lnp.squeeze` through the `_CheckAgainstNumpy` and `_CompileAndCheck` helpers."""

   def onp_fun(x):
     return onp.squeeze(x, ax)

   def lnp_fun(x):
     return lnp.squeeze(x, ax)

   def args_maker():
     return [rng(arg_shape, dtype)]

   self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
   self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
Пример #14
0
def piter(f, x, trust_radius, args):
    """Performs one trust-region step using an eigen-decomposed Hessian.

    The Hessian of `f` at `x` is eigen-decomposed, the trust-region
    subproblem is solved in the eigen-basis by Newton iterations on the shift
    `sigma`, and the step is accepted or rejected and the trust radius updated
    following Nocedal and Wright (2nd ed.), Algorithm 4.1.

    Args:
      f: scalar-valued function called as `f(x, *args)`.
      x: current parameter vector.
      trust_radius: current trust-region radius.
      args: extra positional arguments forwarded to `f`.

    Returns:
      Tuple `(x_out, trust_radius_out, val, gradmag, edm, e0)`: the (possibly
      unchanged) parameters, the updated trust radius, the function value and
      gradient norm at the input `x`, the estimated distance to the minimum
      for the unconstrained solution (only valid if `e0 > 0`), and the first
      Hessian eigenvalue returned by `eigh`.
    """

    fg = jax.value_and_grad(f)
    g = jax.grad(f)
    #h = jax.hessian(f)
    h = jax.jacfwd(g)
    #h = jax.jacrev(g)

    #compute function value and gradient
    val, grad = fg(x, *args)
    gradmag = np.linalg.norm(grad, axis=-1)
    gradcol = np.expand_dims(grad, axis=-1)

    #compute hessian and eigen-decomposition
    hess = h(x, *args)
    e, u = np.linalg.eigh(hess)

    ut = u.T

    #convert gradient to eigen-basis
    a = np.matmul(ut, gradcol)
    a = np.squeeze(a, axis=-1)

    lam = e
    e0 = lam[..., 0]

    #TODO deal with null gradient components and repeated eigenvectors
    lambar = lam
    abarsq = np.square(a)

    # phi(s) = 1/|p(s)| - 1/trust_radius, where |p(s)| is the length of the
    # step obtained with eigenvalues shifted by s; it is zero when the step
    # lies exactly on the trust-region boundary.
    def phif(s):
        pmagsq = np.sum(abarsq / np.square(lambar + s), axis=-1)
        pmag = np.sqrt(pmagsq)
        phipartial = np.reciprocal(pmag)
        # If -s coincides with an eigenvalue the step length is infinite, so
        # its reciprocal is set to zero explicitly.
        singular = np.any(np.equal(-s, lambar), axis=-1)
        phipartial = np.where(singular, 0., phipartial)
        phi = phipartial - np.reciprocal(trust_radius)
        return phi

    # Returns phi(s) together with its derivative with respect to s.
    def phiphiprime(s):
        phi = phif(s)
        pmagsq = np.sum(abarsq / np.square(lambar + s), axis=-1)
        phiprime = np.power(pmagsq, -1.5) * np.sum(
            abarsq / np.power(lambar + s, 3), axis=-1)
        return (phi, phiprime)

    #check if unconstrained solution is valid
    sigma0 = np.maximum(-e0, 0.)
    phisigma0 = phif(sigma0)
    usesolu = np.logical_and(e0 > 0., phisigma0 >= 0.)

    # Starting value for sigma; forced to zero where the unconstrained
    # solution is used.
    sigma = np.max(np.abs(a) / trust_radius - lam, axis=-1)
    sigma = np.maximum(sigma, 0.)
    sigma = np.where(usesolu, 0., sigma)
    phi, phiprime = phiphiprime(sigma)

    #TODO, add handling of additional cases here (singular and "hard" cases)

    #iteratively solve for sigma, enforcing unconstrained solution sigma=0 where appropriate
    unconverged = np.ones(shape=sigma.shape, dtype=np.bool_)
    j = 0
    maxiter = 200

    #This can't work with vmap+jit because of the dynamic condition, so we use the jax while_loop below
    #j = 0
    #while np.logical_and(np.any(unconverged), j<maxiter):
    #sigma = sigma - phi/phiprime
    #sigma = np.where(usesolu, 0., sigma)
    #phiout, phiprimeout = phiphiprime(sigma)
    #unconverged = np.logical_and( (phiout > phi) , (phiout < 0.) )
    #phi,phiprime = (phiout, phiprimeout)
    #j = j +1

    # Loop while any element is unconverged and fewer than maxiter
    # iterations have run.
    def cond(vals):
        sigma, phi, phiprime, unconverged, j = vals
        return np.logical_and(np.any(unconverged), j < maxiter)

    # One Newton update of sigma; an element counts as unconverged while phi
    # is still increasing and negative.
    def body(vals):
        sigma, phi, phiprime, unconverged, j = vals
        sigma = sigma - phi / phiprime
        sigma = np.where(usesolu, 0., sigma)
        phiout, phiprimeout = phiphiprime(sigma)
        unconverged = np.logical_and((phiout > phi), (phiout < 0.))
        phi, phiprime = (phiout, phiprimeout)
        j = j + 1
        return (sigma, phi, phiprime, unconverged, j)

    sigma = jax.lax.while_loop(cond, body,
                               (sigma, phi, phiprime, unconverged, j))[0]

    #compute solution from eigenvalues and eigenvectors
    coeffs = -a / (lam + sigma)
    coeffscol = np.expand_dims(coeffs, axis=-1)

    p = np.matmul(u, coeffscol)
    p = np.squeeze(p, axis=-1)

    #compute predicted reduction in loss function from eigenvalues and eigenvectors
    predicted_reduction = -np.sum(a * coeffs + 0.5 * lam * np.square(coeffs),
                                  axis=-1)

    #compute actual reduction in loss
    x_new = x + p
    val_new = f(x_new, *args)
    #actual_reduction = -(val_new - val)
    actual_reduction = val - val_new

    #update trust radius and output parameters, following Nocedal and Wright 2nd ed. Algorithm 4.1
    eta = 0.15
    trust_radius_max = 1e3
    # The denominator is replaced by 1 when the actual reduction is exactly
    # zero (giving rho = 0); any remaining NaN is mapped to 0 as well.
    rho = actual_reduction / np.where(np.equal(actual_reduction, 0.), 1.,
                                      predicted_reduction)
    rho = np.where(np.isnan(rho), 0., rho)
    at_boundary = np.logical_not(usesolu)
    # Shrink on poor agreement; double (capped at trust_radius_max) on good
    # agreement when the step was constrained; otherwise keep the radius.
    trust_radius_out = np.where(
        rho < 0.25, 0.25 * trust_radius,
        np.where(np.logical_and(rho > 0.75, at_boundary),
                 np.minimum(2. * trust_radius, trust_radius_max),
                 trust_radius))

    # Accept the step only if the reduction ratio exceeds eta.
    x_out = np.where(rho > eta, x_new, x)

    #compute estimated distance to minimum for unconstrained solution (only valid if e0>0)
    coeffs0 = -a / lam
    edm = -np.sum(a * coeffs0 + 0.5 * lam * np.square(coeffs0), axis=-1)

    return x_out, trust_radius_out, val, gradmag, edm, e0
Пример #15
0
def collect_trajectories(env,
                         policy_net_apply,
                         policy_net_params,
                         num_trajectories=1,
                         policy="greedy",
                         max_timestep=None,
                         epsilon=0.1):
    """Collect trajectories with the given policy net and behaviour.

    Args:
        env: environment exposing `reset()` and `step(action)`, where `step`
            returns `(observation, reward, done, info)`.
        policy_net_apply: callable invoked as
            `policy_net_apply(observation_history, policy_net_params)`.
        policy_net_params: parameters forwarded to `policy_net_apply`.
        num_trajectories: number of episodes to roll out.
        policy: one of "greedy", "epsilon-greedy" or "categorical-sampling".
        max_timestep: if truthy, an episode is cut off once the observation
            history reaches this many entries.
        epsilon: probability of picking a uniformly random action when
            `policy` is "epsilon-greedy".

    Returns:
        A list with one `(observations, actions, rewards)` tuple per
        trajectory; `observations` has the batch dimension squeezed out and
        `actions` / `rewards` are stacked along the time axis.

    Raises:
        ValueError: if `policy` is not one of the supported names.
        TypeError: if the chosen action cannot be converted to an `int`.
    """
    trajectories = []

    for t in range(num_trajectories):
        t_start = time.time()
        rewards = []
        actions = []
        done = False

        observation = env.reset()

        # This is currently shaped (1, 1) + OBS, but new observations will keep
        # getting added to it, making it eventually (1, T+1) + OBS
        observation_history = observation[np.newaxis, np.newaxis, :]

        # Run either till we're done OR if max_timestep is defined only till that
        # timestep.
        ts = 0
        while ((not done)
               and (not max_timestep
                    or observation_history.shape[1] < max_timestep)):
            ts_start = time.time()
            # Run the policy, to pick an action, shape is (1, t, A) because
            # observation_history is shaped (1, t) + OBS
            predictions = policy_net_apply(observation_history,
                                           policy_net_params)

            # We need the predictions for the last time-step, so squeeze the batch
            # dimension and take the last time-step.
            predictions = np.squeeze(predictions, axis=0)[-1]

            # Policy can be run in one of the following ways:
            #  - Greedy
            #  - Epsilon-Greedy
            #  - Categorical-Sampling
            action = None
            if policy == "greedy":
                action = np.argmax(predictions)
            elif policy == "epsilon-greedy":
                # A schedule for epsilon is 1/k where k is the episode number sampled.
                if onp.random.random() < epsilon:
                    # Choose an action at random.
                    action = onp.random.randint(0, high=len(predictions))
                else:
                    # Return the best action.
                    action = np.argmax(predictions)
            elif policy == "categorical-sampling":
                # NOTE: The predictions aren't probabilities but log-probabilities
                # instead, since they were computed with LogSoftmax.
                # So just np.exp them to make them probabilities.
                predictions = np.exp(predictions)
                # The draw is one-hot, so argmax yields the sampled index as a
                # scalar. The previous argwhere(...) produced a (1, 1) array,
                # and int() on an array with ndim > 0 is deprecated in NumPy.
                action = onp.argmax(onp.random.multinomial(1, predictions))
            else:
                raise ValueError("Unknown policy: %s" % policy)

            # NOTE: Assumption, single batch.
            try:
                action = int(action)
            except TypeError as err:
                # Let's dump some information before we die off.
                logging.error("Cannot convert action into an integer: [%s]",
                              err)
                logging.error("action.shape: [%s]", action.shape)
                logging.error("action: [%s]", action)
                logging.error("predictions.shape: [%s]", predictions.shape)
                logging.error("predictions: [%s]", predictions)
                logging.error("observation_history: [%s]", observation_history)
                logging.error("policy_net_params: [%s]", policy_net_params)
                log_params(policy_net_params, "policy_net_params")
                # Bare raise re-raises the active exception unchanged.
                raise

            observation, reward, done, _ = env.step(action)

            # observation is of shape OBS, so add extra dims and concatenate on the
            # time dimension.
            observation_history = np.concatenate(
                [observation_history, observation[np.newaxis, np.newaxis, :]],
                axis=1)

            rewards.append(reward)
            actions.append(action)

            ts += 1
            logging.vlog(
                2,
                "  Collected time-step[ %5d] of trajectory[ %5d] in [%0.2f] msec.",
                ts, t, get_time(ts_start))
        logging.vlog(2, " Collected trajectory[ %5d] in [%0.2f] msec.", t,
                     get_time(t_start))

        # Either the episode finished or it was cut off at max_timestep.
        assert done or (max_timestep
                        and max_timestep >= observation_history.shape[1])
        # observation_history is (1, T+1) + OBS, lets squeeze out the batch dim.
        observation_history = np.squeeze(observation_history, axis=0)
        trajectories.append(
            (observation_history, np.stack(actions), np.stack(rewards)))

    return trajectories
Пример #16
0
 def wrapper(*args, **kwargs):
     """Call `f` with an extra axis on every input leaf, then remove it from the outputs."""

     def add_axis(t):
         return jnp.expand_dims(t, axis=axis)

     def drop_axis(t):
         return jnp.squeeze(t, axis=axis)

     expanded_args = tree.map_structure(add_axis, args)
     expanded_kwargs = tree.map_structure(add_axis, kwargs)
     result = f(*expanded_args, **expanded_kwargs)
     return tree.map_structure(drop_axis, result)
Пример #17
0
def ppo_loss(policy_net_apply,
             new_policy_params,
             old_policy_params,
             value_net_apply,
             value_net_params,
             padded_observations,
             padded_actions,
             padded_rewards,
             reward_mask,
             gamma=0.99,
             lambda_=0.95,
             epsilon=0.2):
    """Returns the negated, mask-averaged clipped PPO objective.

    Shapes are asserted throughout: rewards, actions and the reward mask are
    (B, T); observations lead with (B, T+1); the value net returns
    (B, T+1, 1).
    """
    batch, horizon = padded_rewards.shape
    bt_shape = (batch, horizon)
    bt1_shape = (batch, horizon + 1)
    assert padded_observations.shape[:2] == bt1_shape
    assert padded_actions.shape == bt_shape
    assert padded_rewards.shape == bt_shape
    assert reward_mask.shape == bt_shape

    # Value predictions: (B, T+1, 1).
    predicted_values = value_net_apply(padded_observations, value_net_params)
    assert predicted_values.shape == bt1_shape + (1,)

    # TD residuals, (B, T), from the values with the trailing axis removed.
    values_2d = np.squeeze(predicted_values, axis=2)  # (B, T+1)
    td_deltas = deltas(values_2d, padded_rewards, reward_mask, gamma=gamma)
    assert td_deltas.shape == bt_shape

    # Advantages, (B, T).
    advantages = gae_advantages(td_deltas,
                                reward_mask,
                                lambda_=lambda_,
                                gamma=gamma)
    assert advantages.shape == bt_shape

    # Log-probabilities under the old and new policies, both (B, T+1, A).
    log_probab_actions_old = policy_net_apply(padded_observations,
                                              old_policy_params)
    log_probab_actions_new = policy_net_apply(padded_observations,
                                              new_policy_params)
    assert log_probab_actions_old.shape[:2] == bt1_shape
    assert log_probab_actions_new.shape[:2] == bt1_shape
    assert log_probab_actions_old.shape[-1] == log_probab_actions_new.shape[-1]

    # Probability ratios, (B, T).
    ratios = compute_probab_ratios(log_probab_actions_old,
                                   log_probab_actions_new, padded_actions,
                                   reward_mask)
    assert ratios.shape == bt_shape

    # Clipped objective, (B, T).
    objective = clipped_objective(ratios,
                                  advantages,
                                  reward_mask,
                                  epsilon=epsilon)
    assert objective.shape == bt_shape

    # Scalar: sum of the objective divided by the sum of the mask; the loss
    # is its negation.
    return -(np.sum(objective) / np.sum(reward_mask))
Пример #18
0
  def __call__(self, inputs: Array) -> Array:
    """Convolves `inputs` with this module's kernel.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).
        An input with one dimension fewer is treated as a single unbatched
        example.

    Returns:
      The convolved data, with the bias added when `use_bias` is set.

    Raises:
      TypeError: if `kernel_size` is a bare int rather than a sequence.
    """
    inputs = jnp.asarray(inputs, self.dtype)

    if isinstance(self.kernel_size, int):
      raise TypeError('The kernel size must be specified as a'
                      ' tuple/list of integers (eg.: [3, 3]).')
    kernel_size = tuple(self.kernel_size)
    num_spatial_dims = len(kernel_size)

    def maybe_broadcast(x):
      # `None` is kept as a backward-compatible sentinel meaning "1".
      value = 1 if x is None else x
      if isinstance(value, int):
        return (value,) * num_spatial_dims
      return value

    # A missing batch dimension is added here and removed again afterwards.
    is_single_input = inputs.ndim == num_spatial_dims + 1
    if is_single_input:
      inputs = jnp.expand_dims(inputs, axis=0)

    strides = maybe_broadcast(self.strides)
    input_dilation = maybe_broadcast(self.input_dilation)
    kernel_dilation = maybe_broadcast(self.kernel_dilation)

    in_features = inputs.shape[-1]
    assert in_features % self.feature_group_count == 0
    features_per_group = in_features // self.feature_group_count
    kernel_shape = kernel_size + (features_per_group, self.features)
    kernel = jnp.asarray(
        self.param('kernel', self.kernel_init, kernel_shape), self.dtype)

    y = lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        self.padding,
        lhs_dilation=input_dilation,
        rhs_dilation=kernel_dilation,
        dimension_numbers=_conv_dimension_numbers(inputs.shape),
        feature_group_count=self.feature_group_count,
        precision=self.precision)

    if is_single_input:
      y = jnp.squeeze(y, axis=0)
    if self.use_bias:
      bias = jnp.asarray(
          self.param('bias', self.bias_init, (self.features,)), self.dtype)
      y = y + bias
    return y
Пример #19
0
def calculate_likelihood(
    theta,
    fg_ids,
    fg_covs,
    bg_covs,
    fg_covs_thin,
    bg_covs_thin,
    quad_weights,
    counts,
    n_s,
    n_fg,
):
    """Return the summed output of the quadrature-approximated PPM likelihood.

    Predicted means and variances are built separately for the foreground
    rows and for the background rows, flattened, concatenated (foreground
    first) and handed to ``expected_ppm_likelihood_quadrature_approx``.

    Args:
        theta: mapping with the entries ``"w_means"``, ``"w_vars"``,
            ``"w_means_thin"``, ``"w_vars_thin"``, ``"intercept_means"`` and
            ``"intercept_vars"``.
        fg_ids: index used to pick, per foreground row, a column of
            ``w_means`` / ``w_vars`` and an entry of the intercepts.
        fg_covs: foreground covariates, multiplied element-wise with the
            selected weight columns and summed over axis 1.
        bg_covs: background covariates, matrix-multiplied with all weights.
        fg_covs_thin: foreground covariates for the ``*_thin`` weights.
        bg_covs_thin: background covariates for the ``*_thin`` weights.
        quad_weights: weights attached to the background rows; tiled ``n_s``
            times and transposed.
        counts: observed values attached to the foreground rows.
        n_s: number of times ``quad_weights`` is tiled.
        n_fg: ignored; it is overwritten with the number of foreground
            predictions before being used.

    Returns:
        The sum of all terms returned by
        ``expected_ppm_likelihood_quadrature_approx``.
    """
    w_means, w_vars = theta["w_means"], theta["w_vars"]
    w_means_thin, w_vars_thin = theta["w_means_thin"], theta["w_vars_thin"]
    intercept_means = theta["intercept_means"]
    intercept_vars = theta["intercept_vars"]

    # Foreground: one weight column per row, then thinning, then intercept.
    fg_mean = jnp.sum(fg_covs * w_means[:, fg_ids].T, axis=1)
    fg_var = jnp.sum(fg_covs**2 * w_vars[:, fg_ids].T, axis=1)

    fg_mean = fg_mean + jnp.squeeze(fg_covs_thin @ w_means_thin)
    fg_var = fg_var + jnp.squeeze(fg_covs_thin**2 @ w_vars_thin)

    fg_mean = fg_mean + intercept_means[fg_ids]
    fg_var = fg_var + intercept_vars[fg_ids]

    # Background: every row against every weight column, intercept
    # broadcast as a row vector, then thinning.
    bg_mean = bg_covs @ w_means + intercept_means.reshape(1, -1)
    bg_var = bg_covs**2 @ w_vars + intercept_vars.reshape(1, -1)

    bg_mean = bg_mean + bg_covs_thin @ w_means_thin
    bg_var = bg_var + bg_covs_thin**2 @ w_vars_thin

    # Background rows carry the quadrature weights and an all-zero target.
    bg_weights = jnp.tile(quad_weights, (n_s, 1)).T
    bg_targets = jnp.zeros_like(bg_weights)

    # The caller-supplied n_fg is discarded in favour of the actual count.
    n_fg = fg_mean.shape[0]
    fg_weights = jnp.repeat(-1, n_fg)

    targets = jnp.concatenate([counts, bg_targets.reshape(-1)])
    weights = jnp.concatenate([fg_weights, bg_weights.reshape(-1)])
    means = jnp.concatenate([fg_mean, bg_mean.reshape(-1)])
    variances = jnp.concatenate([fg_var, bg_var.reshape(-1)])

    terms = expected_ppm_likelihood_quadrature_approx(targets, weights,
                                                      means, variances)
    return jnp.sum(terms)
Пример #20
0
def _unpack_inputs(inputs: Array) -> Tuple[Array, Array]:
    """Split a two-column array into its ``mu`` and ``sigma_sq`` columns.

    ``inputs`` is promoted to at least 2-D and must then have rank 2. It is
    cut into two equal halves along axis 1 and axis 1 is squeezed out of each
    half, so the second dimension has to be exactly 2.

    Returns:
        The pair ``(mu, sigma_sq)``, each with the leading dimension only.
    """
    inputs = jnp.atleast_2d(inputs)
    chex.assert_rank(inputs, 2)
    mu_column, sigma_sq_column = jnp.hsplit(inputs, 2)
    mu = jnp.squeeze(mu_column, 1)
    sigma_sq = jnp.squeeze(sigma_sq_column, 1)
    return mu, sigma_sq
Пример #21
0
 def __call__(self, x):
     """Return ``loc + scale_tril @ x``, applied over the last axis of ``x``."""
     # Turn the last axis into a column so that matmul treats it as a vector.
     column = x[..., np.newaxis]
     product = np.matmul(self.scale_tril, column)
     # Remove the trailing singleton axis again before shifting by ``loc``.
     return self.loc + np.squeeze(product, axis=-1)
Пример #22
0
def main(_):
    """Fine-tune a checkpointed model on uncertain points picked from a pool.

    Seeds the plotting and NumPy random state, creates ``FLAGS.work_dir`` and
    runs the experiment while ``<work_dir>/stdout.log`` is open. The log is
    held in a ``with`` block so that it is closed even when the experiment
    raises part-way through.

    Args:
        _: unused positional argument.
    """
    sns.set()
    sns.set_palette(sns.color_palette('hls', 10))
    npr.seed(FLAGS.seed)

    logging.info('Starting experiment.')

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.work_dir)
    except gfile.GOSError:
        # Deliberately best effort: a real problem with the directory will
        # surface when the log file is opened just below.
        pass

    log_path = '{}/stdout.log'.format(FLAGS.work_dir)
    with gfile.Open(log_path, 'w+') as stdout_log:
        _select_and_finetune(stdout_log)


def _select_and_finetune(stdout_log):
    """Pick uncertain pool points, fine-tune on them and evaluate per epoch.

    Args:
        stdout_log: open, writable file-like object that receives progress
            lines. The caller is responsible for closing it.

    Raises:
        ValueError: if ``FLAGS.uncertain`` or ``FLAGS.distance`` holds a value
            that none of the implemented strategies handles.
    """
    # BEGIN: fetch test data and candidate pool
    test_images, test_labels, _ = datasets.get_dataset_split(
        name=FLAGS.test_split.split('-')[0],
        split=FLAGS.test_split.split('-')[1],
        shuffle=False)
    pool_images, pool_labels, _ = datasets.get_dataset_split(
        name=FLAGS.pool_split.split('-')[0],
        split=FLAGS.pool_split.split('-')[1],
        shuffle=False)

    n_pool = len(pool_images)
    # normalize to range [-1.0, 127./128]
    test_images = test_images / np.float32(128.0) - np.float32(1.0)
    pool_images = pool_images / np.float32(128.0) - np.float32(1.0)

    # augmentation for train/pool data
    if FLAGS.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    # END: fetch test data and candidate pool

    _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    # BEGIN: load ckpt
    ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # point FLAGS.root_dir at checkpoints from a trusted source.
    with gfile.Open(ckpt_dir, 'rb') as fckpt:
        opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
    params = get_params(opt_state)

    stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
    logging.info('finetune from: %s', ckpt_dir)
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    logging.info('test accuracy: %.2f', test_acc)
    stdout_log.write('test accuracy: {}\n'.format(test_acc))
    stdout_log.flush()
    # END: load ckpt

    # BEGIN: setup for dp model
    @jit
    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad_loss(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

    # END: setup for dp model

    n_uncertain = FLAGS.n_extra + FLAGS.uncertain_extra

    ### BEGIN: prepare extra points picked from pool data
    # BEGIN: on pool data
    pool_embeddings = [apply_fn_0(params[:-1],
                                  pool_images[b_i:b_i + FLAGS.batch_size]) \
                       for b_i in range(0, n_pool, FLAGS.batch_size)]
    pool_embeddings = np.concatenate(pool_embeddings, axis=0)

    pool_logits = apply_fn_1(params[-1:], pool_embeddings)

    pool_true_labels = np.argmax(pool_labels, axis=1)
    pool_predicted_labels = np.argmax(pool_logits, axis=1)
    pool_correct_indices = \
        onp.where(pool_true_labels == pool_predicted_labels)[0]
    pool_incorrect_indices = \
        onp.where(pool_true_labels != pool_predicted_labels)[0]
    assert len(pool_correct_indices) + \
        len(pool_incorrect_indices) == len(pool_labels)

    pool_probs = stax.softmax(pool_logits)

    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
        stdout_log.write('all {} entropy: min {}, max {}\n'.format(
            len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

        pool_entropy_sorted_indices = onp.argsort(pool_entropy)
        # take the n_uncertain most uncertain points
        pool_uncertain_indices = \
            pool_entropy_sorted_indices[::-1][:n_uncertain]
        stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
            len(pool_entropy[pool_uncertain_indices]),
            onp.min(pool_entropy[pool_uncertain_indices]),
            onp.max(pool_entropy[pool_uncertain_indices])))

    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(pool_probs.shape) == 2
        sorted_pool_probs = onp.sort(pool_probs, axis=1)
        pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
        assert min(pool_probs_diff) > 0.
        pool_uncertain_indices = onp.argsort(pool_probs_diff)[:n_uncertain]

    else:
        # Fail here with a clear message instead of hitting a NameError on
        # pool_uncertain_indices further down.
        raise ValueError(
            'Unsupported FLAGS.uncertain value: {!r}'.format(FLAGS.uncertain))

    # END: on pool data

    # BEGIN: cluster uncertain pool points
    big_pca = sklearn.decomposition.PCA(n_components=pool_embeddings.shape[1])
    big_pca.random_state = FLAGS.seed
    # fit PCA onto embeddings of all the pool points
    big_pca.fit(pool_embeddings)

    # For uncertain points, project embeddings onto the first K components
    pool_uncertain_projected_embeddings, _ = utils.project_embeddings(
        pool_embeddings[pool_uncertain_indices], big_pca, FLAGS.k_components)

    n_cluster = int(FLAGS.n_extra / FLAGS.ppc)
    cluster_method = get_cluster_method('{}_nc-{}'.format(
        FLAGS.clustering, n_cluster))
    cluster_method.random_state = FLAGS.seed
    pool_uncertain_cluster_labels = cluster_method.fit_predict(
        pool_uncertain_projected_embeddings)
    pool_uncertain_cluster_label_indices = {
        x: []
        for x in set(pool_uncertain_cluster_labels)
    }  # local i within n_uncertain
    for i, c_label in enumerate(pool_uncertain_cluster_labels):
        pool_uncertain_cluster_label_indices[c_label].append(i)

    # Membership in a NumPy array is a linear scan; build a set once so the
    # per-image check in the visualization loop is O(1).
    if FLAGS.visualize:
        pool_correct_set = set(pool_correct_indices.tolist())

    # find center of each cluster
    # aka, the most representative point of each 'tough' cluster
    pool_picked_indices = []
    pool_uncertain_cluster_label_pick = {}
    for c_label, indices in pool_uncertain_cluster_label_indices.items():
        cluster_projected_embeddings = \
            pool_uncertain_projected_embeddings[indices]
        cluster_center = onp.mean(cluster_projected_embeddings,
                                  axis=0,
                                  keepdims=True)
        if FLAGS.distance == 0 or FLAGS.distance == 'euclidean':
            cluster_distances = euclidean_distances(
                cluster_projected_embeddings, cluster_center).reshape(-1)
        elif FLAGS.distance == 1 or FLAGS.distance == 'weighted_euclidean':
            cluster_distances = weighted_euclidean_distances(
                cluster_projected_embeddings, cluster_center,
                big_pca.singular_values_[:FLAGS.k_components])
        else:
            # Without this, cluster_distances would be undefined (or stale
            # from the previous cluster) below.
            raise ValueError(
                'Unsupported FLAGS.distance value: {!r}'.format(
                    FLAGS.distance))

        # Order the cluster members from closest to farthest from the center.
        sorted_is = onp.argsort(cluster_distances)
        sorted_indices = onp.array(indices)[sorted_is]
        pool_uncertain_cluster_label_indices[c_label] = sorted_indices
        center_i = sorted_indices[0]  # center_i in 3000
        pool_uncertain_cluster_label_pick[c_label] = center_i
        pool_picked_indices.extend(
            pool_uncertain_indices[sorted_indices[:FLAGS.ppc]])

        # BEGIN: visualize cluster of picked uncertain pool
        if FLAGS.visualize:
            this_cluster = []
            for i in sorted_indices:
                idx = pool_uncertain_indices[i]
                img = pool_images[idx]
                if idx in pool_correct_set:
                    border_color = 'green'
                else:
                    border_color = 'red'
                    img = utils.mark_labels(img, pool_predicted_labels[idx],
                                            pool_true_labels[idx])
                img = utils.denormalize(img, 128., 128.)
                img = np.squeeze(utils.to_rgb(np.expand_dims(img, 0)))
                img = utils.add_border(img, width=2, color=border_color)
                this_cluster.append(img)
            utils.tile_image_list(
                this_cluster, '{}/picked_uncertain_pool_cid-{}'.format(
                    FLAGS.work_dir, c_label))
        # END: visualize cluster of picked uncertain pool

    # END: cluster uncertain pool points

    pool_picked_indices = list(set(pool_picked_indices))

    # Top up with random uncertain points that were not picked via clusters
    # until exactly FLAGS.n_extra points are selected.
    n_gap = FLAGS.n_extra - len(pool_picked_indices)
    gap_indices = list(set(pool_uncertain_indices) - set(pool_picked_indices))
    pool_picked_indices.extend(npr.choice(gap_indices, n_gap, replace=False))
    stdout_log.write('n_gap: {}\n'.format(n_gap))
    ### END: prepare extra points picked from pool data

    finetune_images = copy.deepcopy(pool_images[pool_picked_indices])
    finetune_labels = copy.deepcopy(pool_labels[pool_picked_indices])

    stdout_log.write('{} points picked via {}\n'.format(
        len(finetune_images), FLAGS.uncertain))
    logging.info('%d points picked via %s', len(finetune_images),
                 FLAGS.uncertain)
    assert FLAGS.n_extra == len(finetune_images)
    # END: gather points to be used for finetuning

    stdout_log.write('Starting fine-tuning...\n')
    logging.info('Starting fine-tuning...')
    stdout_log.flush()

    for epoch in range(1, FLAGS.epochs + 1):

        # BEGIN: finetune model with extra data, evaluate and save
        num_extra = len(finetune_images)
        num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
        num_batches = num_complete_batches + bool(leftover)

        finetune = data.DataChunk(X=finetune_images,
                                  Y=finetune_labels,
                                  image_size=28,
                                  image_channels=1,
                                  label_dim=1,
                                  label_format='numeric')

        batches = data.minibatcher(finetune,
                                   FLAGS.batch_size,
                                   transform=augmentation)

        # NOTE(review): the step counter and the PRNG key are both recreated
        # every epoch, so private_update folds the same (key, i) pairs in each
        # epoch. Confirm that reusing the same noise keys across epochs is
        # intended.
        itercount = itertools.count()
        key = random.PRNGKey(FLAGS.seed)

        start_time = time.time()

        for _ in range(num_batches):
            # tmp_time = time.time()
            b = next(batches)
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(b.X, b.Y, dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(b.X, b.Y))
            # stdout_log.write('single update in {:.2f} sec\n'.format(
            #     time.time() - tmp_time))

        epoch_time = time.time() - start_time
        stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
        logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

        # accuracy on test data
        params = get_params(opt_state)

        # Keep the previous predictions so the two checkpoints can be compared.
        test_pred_0 = test_pred
        test_acc, test_pred = accuracy(params,
                                       shape_as_image(test_images,
                                                      test_labels),
                                       return_predicted_class=True)
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        stdout_log.write(
            'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
                test_loss, 100 * test_acc))
        logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                     100 * test_acc)
        stdout_log.flush()

        # visualize prediction difference between 2 checkpoints.
        if FLAGS.visualize:
            utils.visualize_ckpt_difference(test_images,
                                            np.argmax(test_labels, axis=1),
                                            test_pred_0,
                                            test_pred,
                                            epoch - 1,
                                            epoch,
                                            FLAGS.work_dir,
                                            mu=128.,
                                            sigma=128.)

    # END: finetune model with extra data, evaluate and save
Пример #23
0
def forward_fn(x):
    """Run ``x`` through a ``MyLinear(10)`` and sum its output over the last axis.

    Any remaining size-1 axes are squeezed out of the summed result.
    """
    outputs = MyLinear(10)(x)
    summed = jnp.sum(outputs, axis=-1)
    return jnp.squeeze(summed)
Пример #24
0
  def run_train(self,
                experiment_dir,
                work_unit_dir,
                rng,
                yield_results=False):
    """Train a Dream Field and save results to work_unit_dir."""
    t_start = time.time()
    config = self.config

    logging.info('Local devices: %s', jax.local_devices())
    logging.info('All devices: %s', jax.devices())

    ## Load CLIP
    encode_image, encode_text, preprocess_image, tokenize_fn = (
        helpers.load_image_text_model(config.loss_model))

    ## Pick a prompt
    template = config.get('query_template', '{query}')
    query = template.format(query=config.query)
    z_clip = encode_text(tokenize_fn(query))

    ## Encode retrieval set
    if config.queries_r:
      if config.retrieve_models[0] == config.loss_model:
        # Reuse loss model.
        encode_image_r, preprocess_image_r = encode_image, preprocess_image
        encode_text_r, tokenize_fn_r = encode_text, tokenize_fn
      else:
        # Load new model.
        encode_image_r, encode_text_r, preprocess_image_r, tokenize_fn_r = (
            helpers.load_image_text_model(config.retrieve_models[0]))

      if config.query not in config.queries_r:
        config.queries_r.append(config.query)
      z_clip_r = encode_text_r(tokenize_fn_r(config.queries_r))
      true_idx_r = config.queries_r.index(config.query)
      assert true_idx_r >= 0  # Input query must be set of retrieval queries.

      del encode_text_r, tokenize_fn_r  # Clean up retrieval text encoder.

    del encode_text, tokenize_fn  # Clean up text encoder.

    ## Scene origin manually tracked
    scene_origin = scene.EMA(np.zeros(3, dtype=np.float64), decay=0.999)

    def train_step(state, rays, key, *multistep_constants):
      """Perform a training iteration, optionally composed of multiple substeps.

      Using multiple substeps slightly reduces training time, but only one
      substep per training iteration is used in experiments.

      Args:
        state: Optimizer state.
        rays: Camera rays for rendering, shared across all substeps.
        key: PRNGKey for random number generation (e.g. for augmentations).
        *multistep_constants: Training constants that can vary across substeps.
          7 arrays of constants of length config.substeps are expected:
            (1) lrs: learning rates
            (2) scs: scale factor for integrated positional encoding. Larger
              scales lead to a blurrier appearance. A constant sc=1 is the
              standard mip-NeRF IPE, and used by Dream Fields.
            (3) sns: standard deviation of pre-activation noise for NeRF
              density. Dream Fields use sn=0. density(x) = softplus(s(x) + eps),
              eps ~ N(0, sn^2)
            (4) mrs: norm of radiance mask, defining scene bounds.
            (5) betas: scale of beta prior loss. Dream Fields use beta=0.
            (6) acct: transmittance loss hyperparameter, defining the target
              average opacity. This is 1 - tau (target transmittance).
            (7) acclam: weight of transmittance loss.

      Returns:
        state: Updated optimizer state.
        last_augs: Augmented views of renderings from the last substep.
        mean_losses: Dictionary of losses averaged over replicas and substeps.
        scene_origin: Updated origin of the scene, based on the center of mass.
      """
      # NOTE(jainajay): rays are shared across all substeps
      pmean = functools.partial(jax.lax.pmean, axis_name='batch')
      psum = functools.partial(jax.lax.psum, axis_name='batch')

      def loss_fn(params, key, sc, sn, mr, beta, acct, acclam):
        render_key, aug_key, key = random.split(key, 3)

        # Render from nerf
        (rgb_est_flat, _, acc_est_flat), aux = render_rays(
            rays=rays,
            variables=params,
            rng=render_key,
            config=config,
            sc=sc,
            sigma_noise_std=sn,
            mask_rad=mr,
            origin=scene_origin.value,
            train=True)
        rgb_est = scene.gather_and_reshape(rgb_est_flat, config.render_width, 3)
        acc_est = scene.gather_and_reshape(acc_est_flat, config.render_width, 1)
        # Make augmentations process specific
        aug_key = random.fold_in(aug_key, pid)
        # Perform augmentations and resize to clip_width
        augs = augment.augment_rendering(config, rgb_est, acc_est, aug_key)

        # Run through CLIP
        z_est = encode_image(preprocess_image(augs))
        clip_loss = -(z_est * z_clip).sum(-1).mean()
        total_loss = clip_loss

        transparency_loss = config.get('transparency_loss', None)
        acc_mean = np.mean(acc_est)
        aux['losses']['acc_mean'] = acc_mean
        if transparency_loss == 'neg_lam_transmittance_clipped':
          # Compute the Dream Fields transmittance loss for scene sparsity.
          trans_mean = 1 - acc_mean
          trans_mean_clipped = np.minimum(1 - acct, trans_mean)
          reg = acclam * trans_mean_clipped
          total_loss -= reg

          aux['losses']['trans_mean_clipped'] = trans_mean_clipped
          aux['losses']['acc_reg_additive'] = reg
        else:
          assert transparency_loss is None

        # Compute a sparsity loss by placing a bimodal beta prior on the
        # per-pixel transmittance. This prior was proposed by Lombardi et al
        # in "Neural Volumes: Learning Dynamic Renderable Volumes from Images"
        # and is used only in ablations.
        beta_loss = np.mean(
            np.log(np.maximum(1e-6, acc_est_flat)) +
            np.log(np.maximum(1e-6, 1. - acc_est_flat)))
        total_loss += beta_loss * beta

        # Compute a weighted mean of each replica's estimated scene origin,
        # since replicas get a different subset of rays
        total_sigma = psum(aux['scene_origin_sigma'])
        aux['scene_origin'] = psum(aux['scene_origin'] *
                                   aux['scene_origin_sigma'] / total_sigma)
        # Compute loss that pushes scene content to 0 origin. We set the loss
        # weight zero_origin_lam = 0 in experiments so the loss is just for
        # logging how far the origin has drifted.
        origin_loss = np.sum(np.square(aux['scene_origin']))
        if config.get('zero_origin_lam', 0.):
          total_loss += config.zero_origin_lam * origin_loss

        aux['losses'].update({
            'clip_loss': clip_loss,
            'beta_loss': beta_loss,
            'origin_loss': origin_loss,
            'loss': total_loss,
        })
        aux['augs'] = augs
        return total_loss, aux

      grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

      # Scan over substeps
      def body_fn(state, step_constants):
        lr, step_constants = step_constants[0], step_constants[1:]
        grad_fn_key, _ = random.split(key, 2)
        (_, aux), grad = grad_fn(state.target, grad_fn_key, *step_constants)
        grad = pmean(grad)  # all-reduce grad
        aux['losses'] = pmean(aux['losses'])
        aux['losses']['grad_norm'] = helpers.tree_norm(grad)
        state = state.apply_gradient(grad, learning_rate=lr)
        return state, aux

      assert len(multistep_constants) == 7
      multistep_constants = np.array(multistep_constants).T

      if config.substeps == 1:
        state, aux = body_fn(state, np.squeeze(multistep_constants))
        last_augs = aux['augs']
      else:
        state, aux = jax.lax.scan(body_fn, state, multistep_constants)
        # Augmentations from last substep.
        # Shape: [n_local_aug, clip_width, clip_width, 3]
        last_augs = aux['augs'][-1]

      # Average each type of loss over substeps
      mean_losses = jax.tree_map(np.mean, aux['losses'])
      return state, last_augs, mean_losses, aux['scene_origin']

    train_pstep = jax.pmap(
        train_step,
        axis_name='batch',
        in_axes=(0, 0, 0, None, None, None, None, None, None, None))

    onp.random.seed(config.seed)

    n_device = jax.local_device_count()
    pid = jax.process_index()
    logging.info('n_device %d', n_device)
    ## Modified NeRF architecture, with swish, softplus, skips.
    variables, render_rays = helpers.init_nerf_model(rng.advance(1), config)
    state = flax.optim.Adam(config.lr0, eps=config.adam_eps).create(variables)

    ## Try to restore a checkpoint.
    restore_dir = config.get('restore_dir', experiment_dir)
    restore_dir = os.path.join(restore_dir, os.path.basename(work_unit_dir))
    if checkpoints.latest_checkpoint(restore_dir):
      restored = checkpoints.restore_checkpoint(
          restore_dir,
          target={
              'origin': np.zeros(3),
              'state': state,
              'vars': variables
          })
      scene_origin.value = onp.array(restored['origin'])
      state = restored['state']
      variables = restored['vars']
      logging.info('restored checkpoint from step %d', state.state.step)
    else:
      logging.info('did not find checkpoint in %s', restore_dir)

    ## Replicate state.
    step_init = state.state.step
    helpers.defragment()
    state = flax.jax_utils.replicate(state, jax.devices())
    helpers.defragment()

    ## pmap'd rendering for test time evaluation.
    kwargs_test = dict(rng=None, sigma_noise_std=0.)
    config_test = ml_collections.ConfigDict(config)
    config_test.update(config.test)
    config_test_hq = ml_collections.ConfigDict(config_test)
    config_test_hq.update(config.test_hq)

    @functools.partial(jax.pmap, in_axes=(0, None, None, None))
    def render_test_p(rays, variables, sc=1., mr=1.):
      return render_rays(
          rays=rays,
          variables=variables,
          sc=sc,
          mask_rad=mr,
          origin=scene_origin.value,
          config=config_test,
          **kwargs_test)[0]

    @functools.partial(jax.pmap, in_axes=(0, None, None, None))
    def render_test_hq_p(rays, variables, sc=1., mr=1.):
      return render_rays(
          rays=rays,
          variables=variables,
          config=config_test_hq,
          sc=sc,
          mask_rad=mr,
          origin=scene_origin.value,
          **kwargs_test)[0]

    def render_test(rays, variables, sc=1., mr=1., hq=False):
      sh = rays[0].shape
      rays = [x.reshape((jax.device_count(), -1) + x.shape[1:]) for x in rays]
      if hq:
        out = render_test_hq_p(rays, variables, sc, mr)
      else:
        out = render_test_p(rays, variables, sc, mr)
      out = [x.reshape(sh[:-1] + (-1,)) for x in out]
      return out

    def render_loop(rays, variables, sc=1., mr=1., chunk=2**13, hq=False):
      sh = list(rays[0].shape[:-1])
      rays = [x.reshape((-1,) + x.shape[-1:]) for x in rays]
      outs = [
          render_test([x[i:i + chunk]
                       for x in rays], variables, sc, mr, hq=hq)
          for i in range(0, rays[0].shape[0], chunk)
      ]
      outs = [
          np.reshape(np.concatenate([z[i]
                                     for z in outs]), sh + [-1])
          for i in range(3)
      ]
      return outs

    ## Training loop
    t_total = 0.
    logging.info('Experiment dir %s', experiment_dir)
    logging.info('Work unit dir %s', work_unit_dir)
    gfile.makedirs(work_unit_dir)

    # Set up metric writer
    writer = metric_writers.create_default_writer(
        work_unit_dir, asynchronous=True, just_logging=jax.process_index() > 0)
    if jax.process_index() == 0:
      train_config = config.copy_and_resolve_references()
      log.write_config_json(train_config, work_unit_dir)

    # Scale instrinsics to different resolutions.
    hwf_clip_r = scene.scale_intrinsics(config.retrieve_widths[0])
    hwf_base = scene.scale_intrinsics(config.render_width)
    hwf_video = scene.scale_intrinsics(config.get('lq_video_width', 300.))
    hwf_video_hq = scene.scale_intrinsics(config.get('hq_video_width', 400.))

    # JIT compile ray generation
    @jax.jit
    def camera_ray_batch_base(p, focal_mult):
      return scene.camera_ray_batch(p, *hwf_base[:2], hwf_base[2] * focal_mult)

    @jax.jit
    def sample_pose_focal(key):
      return scene.sample_camera(key, config.th_range, config.phi_range,
                                 config.rad_range, config.focal_mult_range)

    shard_rays_jit = jax.jit(functools.partial(scene.shard_rays))

    def sample_iter_data(key, step):
      # Sample pose, focal length multiplier.
      pose, rad, focal_mult = sample_pose_focal(key)

      # Generate rays, shaped for pmap over devices.
      rays = camera_ray_batch_base(pose, focal_mult)
      rays_in = shard_rays_jit(rays)
      # Select rays for this process
      rays_in = jax.tree_map(lambda x: x[pid], rays_in)

      substeps = np.arange(start=step, stop=step + config.substeps, step=1)

      # mip-NeRF scale annealing.
      decays = config.mipnerf.decay_start * (
          1 - substeps / config.mipnerf.decay_iters)
      scs = np.maximum(1., 2**decays)

      # Sigma noise annealing.
      sns = schedule.sigma_noise_std_fn(
          substeps, i_split=config.sn_i_split, sn0=config.sn0, sn1=config.sn1)

      # Scene bounds annealing.
      mrs = schedule.mask_rad_fn(
          substeps, i_split=config.mr_i_split, mr0=config.mr0, mr1=config.mr1)

      # Anneal target opacity (1 - transmittance).
      accts = schedule.anneal_exponentially(substeps, config.acc_target_i_split,
                                            config.acc_target0,
                                            config.acc_target1)
      # The area of an object on the image plane grows with the focal length
      # and shrinks with increasing camera radius. Scale target opacity
      # proportionally with the squared focal multiplier and inversely
      # proportionally with the squared camera radius. For consistency with
      # early experiments that did not use this scaling, we also scale by a
      # constant, 1 / (4^2 * 1.2).
      acct_scaling = focal_mult**2 / ((rad / 4.)**2) / 1.2
      accts = np.minimum(1., acct_scaling * accts)
      acclams = np.where(substeps < config.acc_lam_after, 0., config.acc_lam)

      # Beta prior encourages either 0 or 1 opacity for rays
      betas = np.where(substeps < config.beta_after, .0,
                       config.get('beta_lam', .001))

      # Learning rate schedule.
      # NOTE: vectorized calculation of lrs doesn't work with multiple substeps
      lrs = schedule.lr_fn(
          substeps,
          i_split=config.lr_i_split,
          i_end=config.iters,
          lr0=config.lr0,
          lr1=config.lr1,
          lr2=config.lr2,
          cosine_decay=config.lr_cosine_decay)

      return substeps, rays_in, lrs, scs, sns, mrs, betas, accts, acclams

    pbar = tqdm.trange(
        step_init,
        config.iters + config.substeps,
        config.substeps,
        desc='training')
    for i in pbar:
      t = time.time()

      substeps, rays_in, lrs, scs, sns, mrs, betas, accts, acclams = (
          sample_iter_data(rng.advance(1), i))
      l = substeps[-1]

      keys_pstep = rng.split(n_device)
      # NOTE: loss is averaged across substeps.
      new_state, augs, mean_losses, new_scene_origin = train_pstep(
          state, rays_in, keys_pstep, lrs, scs, sns, mrs, betas, accts, acclams)

      # Reduce across devices
      mean_losses = jax.tree_map(np.mean, mean_losses)

      # Gradient skipping if nan.
      if (helpers.all_finite_tree(mean_losses) and
          helpers.all_finite_tree(new_state)):
        state = new_state
      else:
        logging.warn('Skipping update on step %d. non-finite loss or state', i)
        continue

      # Update scene origin.
      if config.get('ema_scene_origin', False):
        if helpers.all_finite(new_scene_origin):
          scene_origin.update(new_scene_origin[0])
        else:
          logging.warn(
              'Skipping origin update on step %d. '
              'non-finite origin. old: %s skipped update: %s', i,
              scene_origin.value, new_scene_origin)

      ## Yield results, for display in colab.
      augs = augs.reshape(-1, *augs.shape[2:])  # devices, n_localaug, HWC->BHWC
      if yield_results:
        yield mean_losses, augs, scene_origin.value
      else:
        yield None
      pbar.set_description(f'Loss: {mean_losses["loss"]:.4f}')

      ## Logging.
      if i == 0:
        continue

      t_total += time.time() - t

      if i % config.log_scalars_every == 0:
        scalars = {f'losses/{key}': value for key, value in mean_losses.items()}
        scalars.update({
            'schedule/mipnerf_scale': scs[-1],
            'schedule/lr': lrs[-1],
            'schedule/mask_rad': mrs[-1],
            'schedule/sigma_noise_std': sns[-1],
            'schedule/beta': betas[-1],
            'schedule/acc_target': accts[-1],
            'schedule/acc_lam': acclams[-1],
            'origin/x': scene_origin.value[0],
            'origin/y': scene_origin.value[1],
            'origin/z': scene_origin.value[2],
            'origin/norm': np.linalg.norm(scene_origin.value),
        })

        secs_per_iter = t_total / (l - step_init)
        iters_per_sec = (l - step_init) / t_total
        wall = time.time() - t_start
        scalars.update({
            'system/wall': wall,
            'system/secs_per_iter': secs_per_iter,
            'system/iters_per_sec': iters_per_sec,
        })

      if i % config.render_every == 0:
        variables = helpers.state_to_variables(state)
        cam2world = scene.pose_spherical(30., -45., 4.)
        rays = scene.camera_ray_batch(cam2world, *hwf_clip_r)

        # Render with no scale manipulation.
        outs = render_loop(rays, variables, sc=1., mr=mrs[-1], hq=True)
        outs = [np.squeeze(x) for x in outs]
        step_images = {
            'render/rgb': outs[0][None],
            'render/depth': outs[1][None, Ellipsis, None],
            'render/acc': outs[2][None, Ellipsis, None],
        }

        # Compute retrieval metric.
        if config.queries_r:
          z_est = encode_image_r(preprocess_image_r(outs[0][None]))
          cosine_sim = (z_est * z_clip_r).sum(-1)  # 1d, num retrieval queries
          log_prob = nn.log_softmax(cosine_sim)
          prefix = f'val/{config.retrieve_models[0]}/retrieve_'
          scalars.update({
              f'{prefix}cosine_sim':
                  cosine_sim[true_idx_r],
              f'{prefix}loss':
                  -log_prob[true_idx_r],
              f'{prefix}acc':
                  (np.argmax(cosine_sim) == true_idx_r).astype(float)
          })

        augs_tiled = log.make_image_grid(augs[:8])
        step_images['render/augmentations'] = augs_tiled

        fig = plt.figure()
        plt.imshow(1. / np.maximum(config.near, outs[1]))
        plt.colorbar()
        plt.title('disparity')
        disparity = log.plot_to_image(fig)
        step_images['render/disparity'] = disparity

        writer.write_images(step=l, images=step_images)

      if config.render_lq_video and (i == config.iters or config.video_every and
                                     i % config.video_every == 0):

        def rays_theta(th):
          cam2world = scene.pose_spherical(th, -30., 4.)
          return scene.camera_ray_batch(cam2world, *hwf_video)

        th_range = np.linspace(
            0, 360, config.get('lq_video_n_frames', 60), endpoint=False)
        variables = helpers.state_to_variables(state)
        frames_all = [
            render_loop(rays_theta(th), variables, scs[-1], mrs[-1], hq=False)
            for th in tqdm.tqdm(th_range, desc='render video')
        ]

        videos = [[np.squeeze(f[i]) for f in frames_all] for i in range(3)]
        for video, label in zip(videos, 'rgb depth acc'.split()):
          scale = (label == 'depth')
          log.log_video(
              None, video, 'frames', label, l, work_unit_dir, scale=scale)

      if i % config.log_scalars_every == 0:
        writer.write_scalars(step=l, scalars=scalars)

      if i % config.flush_every == 0:
        writer.flush()

      defrag_every = config.get('defragment_every', default=0)
      if defrag_every and i % defrag_every == 0:
        helpers.defragment()

      if config.get('checkpoint_every') and i % config.checkpoint_every == 0:
        saved_path = checkpoints.save_checkpoint(
            ckpt_dir=work_unit_dir,
            target={
                'state': flax.jax_utils.unreplicate(state),
                'vars': helpers.state_to_variables(state),
                'origin': np.array(scene_origin.value),
            },
            step=l,
            keep=1,
            overwrite=True,
            keep_every_n_steps=config.get('keep_every_n_steps', None))
        logging.info('saved checkpoint to %s', saved_path)

      # Make a higher res, higher frame rate video.
      if config.render_hq_video and (config.get('hq_video_every', None) and
                                     i % config.hq_video_every == 0 or
                                     i == config.iters):

        my_rays = lambda c2w: scene.camera_ray_batch(c2w, *hwf_video_hq)
        th_range = np.linspace(
            0, 360, config.get('hq_video_n_frames', 240), endpoint=False)
        poses = [scene.pose_spherical(th, -30., 4.) for th in th_range]
        variables = helpers.state_to_variables(state)
        frames_all = [
            render_loop(my_rays(pose), variables, 1., config.mr1, hq=True)
            for pose in tqdm.tqdm(poses, 'render hq video')
        ]

        videos = [
            [onp.array(np.squeeze(f[j])) for f in frames_all] for j in range(3)
        ]
        meta_path = os.path.join(work_unit_dir, 'meta_hq.npy')
        with gfile.GFile(meta_path, 'wb') as f:
          logging.info('saving metadata for rendered hq frames to %s',
                       meta_path)
          onp.save(f, dict(poses=onp.array(poses), hwf=onp.array(hwf_video_hq)))
        for video, label in zip(videos, 'rgb depth acc'.split()):
          scale = (label == 'depth')
          log.log_video(
              None, video, 'frames_hq', label, i, work_unit_dir, scale=scale)

    writer.flush()
    writer.close()
    logging.info('%f sec elapsed total', time.time() - t_start)
Пример #25
0
 def apply_fun(params, inputs, **kwargs):
     """Squeeze `inputs` along the `axis` captured from the enclosing scope.

     `params` and any extra keyword arguments are accepted but ignored.
     """
     del params, kwargs  # Unused; this function has no parameters.
     squeezed = jnp.squeeze(inputs, axis=axis)
     return squeezed
Пример #26
0
  def render_from_checkpoint(self,
                             work_unit_dir,
                             widths,
                             render_test_hq_p,
                             step=None):
    """Restore a checkpoint and render the validation view at several widths.

    Args:
      work_unit_dir: Directory that is searched for checkpoints.
      widths: Iterable of image widths to render at.
      render_test_hq_p: Callable invoked as
        `render_test_hq_p(rays, variables, origin)` to render sharded rays.
      step: Forwarded to `checkpoints.restore_checkpoint`.

    Returns:
      A pair `(latest_checkpoint, renders)`. `renders` maps each width to a
      list of squeezed arrays. When no checkpoint is found, or restoring
      raises ValueError, `renders` holds all-zero placeholders (four arrays
      per width with 3, 1, 1 and 3 channels); otherwise it holds the three
      outputs produced by the render loop.
    """
    # Placeholder outputs handed back on any failure path.
    zero_outs = {}
    for width in widths:
      zero_outs[width] = [
          np.zeros((width, width, c)).squeeze() for c in (3, 1, 1, 3)
      ]

    latest_checkpoint = checkpoints.latest_checkpoint(work_unit_dir)
    if not latest_checkpoint:
      print(f'ERROR: no checkpoint found in {work_unit_dir}')
      return latest_checkpoint, zero_outs

    try:
      restored = checkpoints.restore_checkpoint(
          work_unit_dir, target=None, step=step)
    except ValueError as e:
      print(f'ERROR loading checkpoint from {work_unit_dir} at step {step}:', e)
      return latest_checkpoint, zero_outs

    origin = restored['origin']
    variables = flax.core.frozen_dict.FrozenDict(restored['vars'])
    # A non-finite origin is only reported; rendering still proceeds.
    if not np.all(np.isfinite(origin)):
      print('origin', origin, 'has nan value(s) for wu', work_unit_dir)

    def render_test(rays):
      """Render one chunk of flat rays and trim it to the input length."""
      in_shape = rays[0].shape
      sharded = scene.padded_shard_rays(rays, multihost=False)
      rendered = render_test_hq_p(sharded, variables, origin)
      flat_rows = onp.prod(in_shape[:-1])
      # Gather flat, then unpad back to the number of rays passed in.
      return [x.reshape((flat_rows, -1))[:in_shape[0]] for x in rendered]

    def render_loop(rays, chunk=2**16):
      """Render `rays` in chunks and restore their leading dimensions."""
      lead_dims = list(rays[0].shape[:-1])
      flat_rays = [x.reshape((-1,) + x.shape[-1:]) for x in rays]
      num_rays = flat_rays[0].shape[0]

      chunk_outs = []
      for start in range(0, num_rays, chunk):
        chunk_outs.append(
            render_test([x[start:start + chunk] for x in flat_rays]))

      merged = []
      for channel in range(3):
        stacked = np.concatenate([z[channel] for z in chunk_outs])
        merged.append(np.reshape(stacked, lead_dims + [-1]))
      return merged

    # Render validation view once per distinct width.
    renders_by_width = {}
    for width in set(widths):
      logging.info('rendering at width %d', width)
      intrinsics = scene.scale_intrinsics(width)
      pose = scene.pose_spherical(30., -45., 4.)
      view_rays = scene.camera_ray_batch(pose, *intrinsics)
      renders_by_width[width] = [np.squeeze(x) for x in render_loop(view_rays)]

    return latest_checkpoint, renders_by_width
Пример #27
0
 def apply(self, inputs: jnp.ndarray):
     """Return a categorical policy and a value estimate for `inputs`.

     The inputs are passed through `utils.batch_concat`, then fed to two
     separate MLPs: one producing `num_actions` logits and one producing a
     single value whose trailing axis is squeezed away.
     """
     features = utils.batch_concat(inputs)
     policy_logits = MLP(features, [64, 64, num_actions])
     state_value = jnp.squeeze(MLP(features, [64, 64, 1]), axis=-1)
     policy = tfd.Categorical(logits=policy_logits)
     return policy, state_value
Пример #28
0
    def train_step(state, rays, key, *multistep_constants):
      """Perform a training iteration, optionally composed of multiple substeps.

      Using multiple substeps slightly reduces training time, but only one
      substep per training iteration is used in experiments.

      Args:
        state: Optimizer state.
        rays: Camera rays for rendering, shared across all substeps.
        key: PRNGKey for random number generation (e.g. for augmentations). A
          fresh subkey is derived from it for every substep.
        *multistep_constants: Training constants that can vary across substeps.
          7 arrays of constants of length config.substeps are expected:
            (1) lrs: learning rates
            (2) scs: scale factor for integrated positional encoding. Larger
              scales lead to a blurrier appearance. A constant sc=1 is the
              standard mip-NeRF IPE, and used by Dream Fields.
            (3) sns: standard deviation of pre-activation noise for NeRF
              density. Dream Fields use sn=0. density(x) = softplus(s(x) + eps),
              eps ~ N(0, sn^2)
            (4) mrs: norm of radiance mask, defining scene bounds.
            (5) betas: scale of beta prior loss. Dream Fields use beta=0.
            (6) acct: transmittance loss hyperparameter, defining the target
              average opacity. This is 1 - tau (target transmittance).
            (7) acclam: weight of transmittance loss.

      Returns:
        state: Updated optimizer state.
        last_augs: Augmented views of renderings from the last substep.
        mean_losses: Dictionary of losses averaged over replicas and substeps.
        scene_origin: Updated origin of the scene, based on the center of mass.
      """
      # NOTE(jainajay): rays are shared across all substeps
      pmean = functools.partial(jax.lax.pmean, axis_name='batch')
      psum = functools.partial(jax.lax.psum, axis_name='batch')

      def loss_fn(params, key, sc, sn, mr, beta, acct, acclam):
        """Render, augment, and score one substep; returns (loss, aux)."""
        render_key, aug_key, key = random.split(key, 3)

        # Render from nerf
        (rgb_est_flat, _, acc_est_flat), aux = render_rays(
            rays=rays,
            variables=params,
            rng=render_key,
            config=config,
            sc=sc,
            sigma_noise_std=sn,
            mask_rad=mr,
            origin=scene_origin.value,
            train=True)
        rgb_est = scene.gather_and_reshape(rgb_est_flat, config.render_width, 3)
        acc_est = scene.gather_and_reshape(acc_est_flat, config.render_width, 1)
        # Make augmentations process specific
        aug_key = random.fold_in(aug_key, pid)
        # Perform augmentations and resize to clip_width
        augs = augment.augment_rendering(config, rgb_est, acc_est, aug_key)

        # Run through CLIP
        z_est = encode_image(preprocess_image(augs))
        clip_loss = -(z_est * z_clip).sum(-1).mean()
        total_loss = clip_loss

        transparency_loss = config.get('transparency_loss', None)
        acc_mean = np.mean(acc_est)
        aux['losses']['acc_mean'] = acc_mean
        if transparency_loss == 'neg_lam_transmittance_clipped':
          # Compute the Dream Fields transmittance loss for scene sparsity.
          trans_mean = 1 - acc_mean
          trans_mean_clipped = np.minimum(1 - acct, trans_mean)
          reg = acclam * trans_mean_clipped
          total_loss -= reg

          aux['losses']['trans_mean_clipped'] = trans_mean_clipped
          aux['losses']['acc_reg_additive'] = reg
        else:
          assert transparency_loss is None

        # Compute a sparsity loss by placing a bimodal beta prior on the
        # per-pixel transmittance. This prior was proposed by Lombardi et al
        # in "Neural Volumes: Learning Dynamic Renderable Volumes from Images"
        # and is used only in ablations.
        beta_loss = np.mean(
            np.log(np.maximum(1e-6, acc_est_flat)) +
            np.log(np.maximum(1e-6, 1. - acc_est_flat)))
        total_loss += beta_loss * beta

        # Compute a weighted mean of each replica's estimated scene origin,
        # since replicas get a different subset of rays
        total_sigma = psum(aux['scene_origin_sigma'])
        aux['scene_origin'] = psum(aux['scene_origin'] *
                                   aux['scene_origin_sigma'] / total_sigma)
        # Compute loss that pushes scene content to 0 origin. We set the loss
        # weight zero_origin_lam = 0 in experiments so the loss is just for
        # logging how far the origin has drifted.
        origin_loss = np.sum(np.square(aux['scene_origin']))
        if config.get('zero_origin_lam', 0.):
          total_loss += config.zero_origin_lam * origin_loss

        aux['losses'].update({
            'clip_loss': clip_loss,
            'beta_loss': beta_loss,
            'origin_loss': origin_loss,
            'loss': total_loss,
        })
        aux['augs'] = augs
        return total_loss, aux

      grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

      # Scan over substeps. The PRNG key travels in the carry so that every
      # substep draws its own randomness instead of reusing one subkey.
      def body_fn(carry, step_constants):
        """Apply one gradient update; carry is (optimizer state, PRNG key)."""
        state, substep_key = carry
        lr, step_constants = step_constants[0], step_constants[1:]
        # The first substep's subkey equals split(key, 2)[0], so the
        # single-substep path produces the same randomness as before.
        grad_fn_key, substep_key = random.split(substep_key, 2)
        (_, aux), grad = grad_fn(state.target, grad_fn_key, *step_constants)
        grad = pmean(grad)  # all-reduce grad
        aux['losses'] = pmean(aux['losses'])
        aux['losses']['grad_norm'] = helpers.tree_norm(grad)
        state = state.apply_gradient(grad, learning_rate=lr)
        return (state, substep_key), aux

      assert len(multistep_constants) == 7
      # Transpose so that each row holds the 7 constants of one substep.
      multistep_constants = np.array(multistep_constants).T

      if config.substeps == 1:
        # Skip lax.scan entirely when there is a single substep.
        (state, _), aux = body_fn((state, key),
                                  np.squeeze(multistep_constants))
        last_augs = aux['augs']
      else:
        (state, _), aux = jax.lax.scan(body_fn, (state, key),
                                       multistep_constants)
        # Augmentations from last substep.
        # Shape: [n_local_aug, clip_width, clip_width, 3]
        last_augs = aux['augs'][-1]

      # Average each type of loss over substeps
      mean_losses = jax.tree_util.tree_map(np.mean, aux['losses'])
      return state, last_augs, mean_losses, aux['scene_origin']
def munchausen_target_quantile_values(network, target_params, states, actions,
                                      next_states, rewards, terminals,
                                      num_tau_prime_samples,
                                      num_quantile_samples, cumulative_gamma,
                                      rng, tau, alpha, clip_value_min):
    """Build the munchausen target for return values at given quantiles.

    Args:
      network: Object whose `apply(params, x, num_quantiles=..., rng=...)`
        returns outputs exposing `quantile_values` and `representation`.
      target_params: Parameters passed to every `network.apply` call here.
      states: Current states fed to the network.
      actions: Actions taken; one-hot encoded against the action dimension.
      next_states: Next states fed to the network.
      rewards: Rewards; the munchausen bonus is added to a copy, the caller's
        value is never modified in place.
      terminals: Terminal flags; cast to float32 and used to mask the discount.
      num_tau_prime_samples: Number of quantiles for the next-state target
        outputs, and the tiling factor for rewards and discounts.
      num_quantile_samples: Number of quantiles used to estimate Q-values.
      cumulative_gamma: Discount applied to the next-state target.
      rng: PRNG key; split into one fresh key per network call.
      tau: Temperature passed to the stable (log-)softmax helpers.
      alpha: Scale of the munchausen log-policy bonus.
      clip_value_min: Lower clipping bound for the log-policy term (upper
        bound is 1).

    Returns:
      A tuple `(rng, target_quantile_vals, curr_state_representation,
      next_state_representation)`. The target has a trailing axis of size one
      appended, and the last three entries are wrapped in `stop_gradient`.
    """
    # NOTE(review): the axis=0 reductions below suggest this runs on a single
    # transition (e.g. under vmap) - confirm against the caller.
    rng, rng1, rng2, rng3 = jax.random.split(rng, num=4)
    target_action = network.apply(target_params,
                                  states,
                                  num_quantiles=num_quantile_samples,
                                  rng=rng1)
    curr_state_representation = target_action.representation
    curr_state_representation = jnp.squeeze(curr_state_representation)
    is_terminal_multiplier = 1. - terminals.astype(jnp.float32)
    # Incorporate terminal state to discount factor.
    gamma_with_terminal = cumulative_gamma * is_terminal_multiplier
    gamma_with_terminal = jnp.tile(gamma_with_terminal,
                                   [num_tau_prime_samples])

    replay_net_target_outputs = network.apply(
        target_params,
        next_states,
        num_quantiles=num_tau_prime_samples,
        rng=rng2)
    replay_quantile_values = replay_net_target_outputs.quantile_values

    target_next_action = network.apply(target_params,
                                       next_states,
                                       num_quantiles=num_quantile_samples,
                                       rng=rng3)
    target_next_quantile_values_action = target_next_action.quantile_values
    # Q-values are estimated as the mean over sampled quantiles.
    replay_next_target_q_values = jnp.squeeze(
        jnp.mean(target_next_quantile_values_action, axis=0))

    q_state_values = target_action.quantile_values
    replay_target_q_values = jnp.squeeze(jnp.mean(q_state_values, axis=0))

    num_actions = q_state_values.shape[-1]
    replay_action_one_hot = jax.nn.one_hot(actions, num_actions)
    replay_next_log_policy = stable_scaled_log_softmax(
        replay_next_target_q_values, tau, axis=0)
    replay_next_policy = stable_softmax(replay_next_target_q_values,
                                        tau,
                                        axis=0)
    replay_log_policy = stable_scaled_log_softmax(replay_target_q_values,
                                                  tau,
                                                  axis=0)

    tau_log_pi_a = jnp.sum(replay_log_policy * replay_action_one_hot, axis=0)
    # Bounds are passed positionally: the `a_min`/`a_max` keyword names are
    # deprecated in newer JAX releases.
    tau_log_pi_a = jnp.clip(tau_log_pi_a, clip_value_min, 1)
    munchausen_term = alpha * tau_log_pi_a
    weighted_logits = (replay_next_policy *
                       (replay_quantile_values - replay_next_log_policy))

    target_quantile_vals = jnp.sum(weighted_logits, axis=1)
    # Rebind rather than `+=` so a NumPy `rewards` buffer owned by the caller
    # is never mutated in place.
    rewards = rewards + munchausen_term
    rewards = jnp.tile(rewards, [num_tau_prime_samples])
    target_quantile_vals = (rewards +
                            gamma_with_terminal * target_quantile_vals)
    next_state_representation = target_next_action.representation
    next_state_representation = jnp.squeeze(next_state_representation)

    return (rng, jax.lax.stop_gradient(target_quantile_vals[:, None]),
            jax.lax.stop_gradient(curr_state_representation),
            jax.lax.stop_gradient(next_state_representation))
Пример #30
0
def onnx_squeeze(x, axes: List[int]):
    """Remove size-one dimensions from `x` along `axes`.

    Args:
      x: Array to squeeze.
      axes: Axes to remove. Any iterable of integers is accepted (it is
        normalised to a tuple of Python ints before being handed to
        `jnp.squeeze`), as is a single integer. `None` squeezes every
        size-one axis.

    Returns:
      The result of `jnp.squeeze` on `x`.
    """
    if axes is None:
        axis = None
    else:
        try:
            # `jnp.squeeze` documents `axis` as an int or a tuple of ints.
            axis = tuple(int(a) for a in axes)
        except TypeError:
            # Not iterable: treat as a single integer axis.
            axis = int(axes)
    return jnp.squeeze(x, axis)