示例#1
0
    def evaluate_example(self, example: SingleExample,
                         prediction: SinglePrediction) -> MeanStat:
        """Scores one sequence example by (weighted) token accuracy.

        Args:
          example: One example with target in range [0, num_classes) of
            shape [1]; the target is read from ``example[self.target_key]``.
          prediction: Unnormalized prediction for ``example`` of shape
            [num_classes]. When ``self.pred_key`` is set, the scores are read
            from ``prediction[self.pred_key]`` instead.

        Returns:
          MeanStat holding the token accuracy of the whole example, or the
          accuracy at each token position if ``per_position`` is ``True``.
        """
        labels = example[self.target_key]
        if self.pred_key is None:
            scores = prediction
        else:
            scores = prediction[self.pred_key]

        # An optional additive mask is applied to the scores before the argmax.
        if self.logits_mask is not None:
            scores += jnp.array(self.logits_mask)

        weights = get_target_weight(labels, self.masked_target_values)
        hits = (labels == jnp.argmax(scores, axis=-1)).astype(jnp.float32)
        weighted_hits = hits * weights

        if not self.per_position:
            return MeanStat.new(jnp.sum(weighted_hits), jnp.sum(weights))
        return MeanStat.new(weighted_hits, weights)
示例#2
0
    def test_argmax_consistent(self):
        """Checks that ``argmax`` agrees with ``jnp.argmax`` on random inputs."""
        keys = PRNGSequence(13)

        vec = jax.random.normal(next(keys), shape=(5,))
        mat = jax.random.normal(next(keys), shape=(3, 5))
        ten = jax.random.normal(next(keys), shape=(3, 5, 7))

        # Rank-1 input: the result is a scalar, so a plain equality check is used.
        self.assertEqual(
            argmax(next(keys), vec), jnp.argmax(vec, axis=-1))

        # (array, kwargs passed to ``argmax``); an empty dict exercises the
        # default axis, which is compared against ``axis=-1`` of jnp.argmax.
        array_cases = (
            (mat, {}),
            (mat, {'axis': 0}),
            (ten, {}),
            (ten, {'axis': 0}),
            (ten, {'axis': 1}),
        )
        for arr, kwargs in array_cases:
            reference_axis = kwargs.get('axis', -1)
            self.assertArrayAlmostEqual(
                argmax(next(keys), arr, **kwargs),
                jnp.argmax(arr, axis=reference_axis))
def extract_best_search_space(scores_dict, centrality_key,
                              search_spaces_reduced):
  """Return the search space whose score under ``centrality_key`` is largest.

  The argmax of ``scores_dict[centrality_key]`` is used as an index along the
  first axis of ``search_spaces_reduced``; the two trailing axes are kept whole.
  """
  best_index = jnp.argmax(scores_dict[centrality_key])
  return search_spaces_reduced[best_index, :, :]
# Load the ExoMol line list from the given database path, restricted to the
# wavenumber grid `nus`; `crit` is passed straight through to MdbExomol.
mdbM = moldb.MdbExomol('.database/H2O/1H2-16O/POKAZATEL', nus,
                       crit=1.e-45)  # loading molecular data
molmassM = molinfo.molmass('H2O')  # molecular mass (H2O)

# Evaluate SijT for every line at 1500 K and keep only the lines whose value
# exceeds 1e-25 (presumably line strengths -- confirm units against SijT).
q = mdbM.qr_interp(1500.0)
S = SijT(1500.0, mdbM.logsij0, mdbM.nu_lines, mdbM.elower, q)
mask = S > 1.e-25
mdbM.masking(mask)

# 100 temperatures, log-spaced between 800 and 1600.
Tarr = jnp.logspace(jnp.log10(800), jnp.log10(1600), 100)
qt = vmap(mdbM.qr_interp)(Tarr)
# Map SijT over the temperature grid: arguments 0 (temperature) and 4 (qt)
# are batched, the per-line arrays are shared across temperatures.
SijM = jit(vmap(SijT,
                (0, None, None, None, 0)))(Tarr, mdbM.logsij0, mdbM.nu_lines,
                                           mdbM.elower, qt)

# For each line, find the temperature on the grid at which SijM is largest,
# then report the smallest such temperature.
imax = jnp.argmax(SijM, axis=0)
Tmax = Tarr[imax]
print(jnp.min(Tmax))

# Ratio of planck.piBarr at the second temperature to the first, per
# wavenumber, for two temperature pairs.
pl = planck.piBarr(jnp.array([1100.0, 1000.0]), nus)
print(pl[1] / pl[0])

pl = planck.piBarr(jnp.array([1400.0, 1200.0]), nus)
print(pl[1] / pl[0])

# Plot styling constants and the figure; they are not used within this chunk.
lsa = ['solid', 'dashed', 'dotted', 'dashdot']
lab = ['A', 'B', 'C']
fac = 1.e22
fig = plt.figure(figsize=(12, 6))
# for j,i in enumerate(range(len(mdbM.A))):
# for j,i in enumerate([56,72,141,147,173,236,259,211,290]):
示例#5
0
 def _per_batch(inputs, labels):
   """Fraction of rows in one batch whose predicted class matches the label.

   Classes are taken as the argmax along axis 1 of both ``labels`` and the
   output of ``predict(params, inputs)``.
   """
   expected = jnp.argmax(labels, axis=1)
   scores = predict(params, inputs)
   guessed = jnp.argmax(scores, axis=1)
   return jnp.mean(guessed == expected)
示例#6
0
 def accuracy(params, mask):
     """Masked accuracy of the model (run with ``train=False``) on ``graph``.

     Only entries selected by ``mask`` contribute; the result is the masked
     count of correct argmax predictions divided by the sum of ``mask``.
     """
     scores = model.apply(params, graph, train=False)
     guessed = jnp.argmax(scores, -1)
     expected = jnp.argmax(labels, -1)
     hits = guessed == expected
     return jnp.sum(hits * mask) / jnp.sum(mask)
def ground_truth_label(sentences, lengths):
  """Return, for each row, the index of the largest ``batch_score`` value."""
  return jnp.argmax(batch_score(sentences, lengths), axis=1)
示例#8
0
def accuracy(params, batch, model_predict):
  """Mean agreement between argmax predictions and integer targets.

  Args:
    params: Parameters forwarded to ``model_predict``.
    batch: Pair ``(inputs, targets)``; targets are compared directly with the
      argmax (axis 1) of the model output.
    model_predict: Callable ``(params, inputs) -> scores``.

  Returns:
    The fraction of rows whose predicted class equals the target.
  """
  inputs, targets = batch
  scores = model_predict(params, inputs)
  hits = np.argmax(scores, axis=1) == targets
  return np.mean(hits)
示例#9
0
        def loss(
            params: networks_lib.Params, target_params: networks_lib.Params,
            key_grad: networks_lib.PRNGKey, sample: reverb.ReplaySample
        ) -> Tuple[jnp.ndarray, jnp.ndarray]:
            """Computes mean transformed N-step loss for a batch of sequences.

            Args:
              params: Online network parameters.
              target_params: Target network parameters.
              key_grad: PRNG key; split internally for the burn-in and the
                main unrolls.
              sample: Replay sample; ``sample.data`` holds the sequences and
                ``sample.info.probability`` the sampling probabilities.

            Returns:
              A pair ``(mean_loss, priorities)``: the importance-weighted mean
              loss and the per-sequence priorities (a mixture of the max and
              the mean absolute TD error over time).
            """

            # Convert sample data to sequence-major format [T, B, ...].
            data = utils.batch_to_sequence(sample.data)

            # Get core state & warm it up on observations for a burn-in period.
            # NOTE(review): jax.tree_map is deprecated in newer JAX releases in
            # favour of jax.tree_util.tree_map -- confirm against the pinned
            # JAX version.
            if use_core_state:
                # Replay core state.
                online_state = jax.tree_map(lambda x: x[0],
                                            data.extras['core_state'])
            else:
                online_state = initial_state
            target_state = online_state

            # Maybe burn the core state in.
            if burn_in_length:
                burn_obs = jax.tree_map(lambda x: x[:burn_in_length],
                                        data.observation)
                key_grad, key1, key2 = jax.random.split(key_grad, 3)
                _, online_state = unroll.apply(params, key1, burn_obs,
                                               online_state)
                _, target_state = unroll.apply(target_params, key2, burn_obs,
                                               target_state)

            # Only get data to learn on from after the end of the burn in period.
            data = jax.tree_map(lambda seq: seq[burn_in_length:], data)

            # Unroll on sequences to get online and target Q-Values.
            key1, key2 = jax.random.split(key_grad)
            online_q, _ = unroll.apply(params, key1, data.observation,
                                       online_state)
            target_q, _ = unroll.apply(target_params, key2, data.observation,
                                       target_state)

            # Get value-selector actions from online Q-values for double Q-learning.
            selector_actions = jnp.argmax(online_q, axis=-1)
            # Preprocess discounts & rewards.
            discounts = (data.discount * discount).astype(online_q.dtype)
            rewards = data.reward
            if clip_rewards:
                rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)
            rewards = rewards.astype(online_q.dtype)

            # Get N-step transformed TD error and loss.
            # The TD-error function is vmapped over axis 1 (the batch axis of
            # the sequence-major arrays) and returns errors batched on axis 1.
            batch_td_error_fn = jax.vmap(functools.partial(
                rlax.transformed_n_step_q_learning,
                n=bootstrap_n,
                tx_pair=tx_pair),
                                         in_axes=1,
                                         out_axes=1)
            # TODO(b/183945808): when this bug is fixed, truncations of actions,
            # rewards, and discounts will no longer be necessary.
            batch_td_error = batch_td_error_fn(online_q[:-1], data.action[:-1],
                                               target_q[1:],
                                               selector_actions[1:],
                                               rewards[:-1], discounts[:-1])
            # Sum the squared errors over the time axis: one loss per sequence.
            batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0)

            # Importance weighting.
            probs = sample.info.probability
            # The 1e-6 guards against division by a zero probability.
            importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype)
            importance_weights **= importance_sampling_exponent
            # Normalise so that the largest weight is 1.
            importance_weights /= jnp.max(importance_weights)
            mean_loss = jnp.mean(importance_weights * batch_loss)

            # Calculate priorities as a mixture of max and mean sequence errors.
            abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype)
            max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0)
            mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error,
                                                                 axis=0)
            priorities = (max_priority + mean_priority)

            return mean_loss, priorities
示例#10
0
def hmm_viterbi_log(params, obs_seq, length=None):
    '''
    Computes the most likely sequence of hidden states (the Viterbi path) of a Hidden Markov Model
    for a sequence of observations, working with log probabilities throughout.
    It is based on https://github.com/deepmind/distrax/blob/master/distrax/_src/utils/hmm.py

    Parameters
    ----------
    params : HMM
        Hidden Markov Model; its ``trans_dist``, ``obs_dist`` and ``init_dist`` attributes are read
    obs_seq: array(seq_len)
        History of observed states
    length : int, optional
        Number of time steps treated as valid. Defaults to ``len(obs_seq)``.
        NOTE(review): every mask below tests ``t <= length`` although time indices are 0-based;
        confirm whether ``t < length`` was intended for padded sequences.

    Returns
    -------
    * array(seq_len)
        Most likely state index for each time step. Masked steps are reported as -1, and the
        last entry is -1 unless ``length == seq_len``. When ``seq_len == 1`` the result is a
        one-element array holding the argmax of the initial log probabilities plus the first
        observation log likelihood.
    '''
    seq_len = len(obs_seq)

    if length is None:
        length = seq_len

    trans_dist, obs_dist, init_dist = params.trans_dist, params.obs_dist, params.init_dist

    # Normalise the transition and initial logits into log probabilities.
    trans_log_probs = log_softmax(trans_dist.logits)
    init_log_probs = log_softmax(init_dist.logits)

    n_states = obs_dist.batch_shape[0]

    # Log probability of each state at t = 0 given the first observation.
    first_log_prob = init_log_probs + obs_dist.log_prob(obs_seq[0])

    # A single observation has no transitions: return the best initial state.
    if seq_len == 1:
        return jnp.expand_dims(jnp.argmax(first_log_prob), axis=0)

    def viterbi_forward(prev_logp, t):
        # One forward step: for every successor state, keep the best score
        # and remember which predecessor state achieved it.
        obs_logp = obs_dist.log_prob(obs_seq[t])

        # logp[i, j]: score of being in state i at t - 1 and state j at t.
        logp = jnp.where(
            t <= length,
            prev_logp[..., None] + trans_log_probs + obs_logp[..., None, :],
            -jnp.inf + jnp.zeros_like(trans_log_probs))

        # Masked steps carry the previous scores forward unchanged.
        max_logp_given_successor = jnp.where(t <= length, jnp.max(logp,
                                                                  axis=-2),
                                             prev_logp)
        most_likely_given_successor = jnp.where(t <= length,
                                                jnp.argmax(logp, axis=-2), -1)

        return max_logp_given_successor, most_likely_given_successor

    ts = jnp.arange(1, seq_len)
    final_log_prob, most_likely_sources = lax.scan(viterbi_forward,
                                                   first_log_prob, ts)

    most_likely_initial_given_successor = jnp.argmax(trans_log_probs +
                                                     first_log_prob,
                                                     axis=-2)

    # Prepend one row so that most_likely_sources[t] lines up with time step t
    # (the scan output starts at t = 1).
    most_likely_sources = jnp.concatenate([
        jnp.expand_dims(most_likely_initial_given_successor, axis=0),
        most_likely_sources
    ],
                                          axis=0)

    def viterbi_backward(state, t):
        # One backtracking step: look up the best predecessor of `state`.
        # The lookup is written as a one-hot weighted sum instead of indexing.
        state = jnp.where(
            t <= length,
            jnp.sum(most_likely_sources[t] * one_hot(state, n_states)).astype(
                jnp.int64), state)
        most_likely = jnp.where(t <= length, state, -1)
        return state, most_likely

    # Start from the best final state and walk the back-pointers in reverse.
    final_state = jnp.argmax(final_log_prob)
    _, most_likely_path = lax.scan(viterbi_backward,
                                   final_state,
                                   ts,
                                   reverse=True)

    final_state = jnp.where(length == seq_len, final_state, -1)

    # The backward scan yields states for steps 0..seq_len-2; append the last.
    return jnp.append(most_likely_path, final_state)
示例#11
0
def categorical_logits(key, logits, shape=()):
    """Sample category indices from unnormalised ``logits`` (Gumbel-max trick).

    Gumbel noise is added to the logits and the argmax over the last axis is
    returned. ``shape`` is the batch shape of the samples; when it is empty,
    the leading dimensions of ``logits`` are used instead.
    """
    batch_shape = shape if shape else logits.shape[:-1]
    noise_shape = batch_shape + logits.shape[-1:]
    noise = random.gumbel(key, noise_shape, logits.dtype)
    return np.argmax(noise + logits, axis=-1)
def loss_fn(
    model,
    padded_example_and_rng,
    static_metadata,
    regularization_weights = None,
    reinforce_weight = 1.0,
    baseline_weight = 0.001,
):
  """Loss function for multi-pointer task.

  Args:
    model: The model to evaluate.
    padded_example_and_rng: Padded example to evaluate on, with a PRNGKey.
    static_metadata: Padding configuration for the example, since this may vary
      for different examples.
    regularization_weights: Associates side output key regexes with
      regularization penalties.
    reinforce_weight: Weight to give to the reinforce term.
    baseline_weight: Weight to give to the baseline.

  Returns:
    Tuple of loss and metrics.

  Raises:
    ValueError: If a regularization query matches no collected side output.
  """
  padded_example, rng = padded_example_and_rng

  # Run the model.
  with side_outputs.collect_side_outputs() as collected_side_outputs:
    with flax.nn.stochastic(rng):
      joint_log_probs = model(padded_example, static_metadata)

  # Computing the loss:
  # Extract logits for the correct location.
  log_probs_at_bug = joint_log_probs[padded_example.bug_node_index, :]
  # Compute p(repair) = sum[ p(node) p(repair | node) ]
  # -> log p(repair) = logsumexp[ log p(node) + log p (repair | node) ]
  # (log of the mask is -inf where the mask is zero, which removes those
  # entries from the logsumexp.)
  log_prob_joint = jax.scipy.special.logsumexp(
      log_probs_at_bug + jnp.log(padded_example.repair_node_mask))

  # Metrics:
  # Marginal log probabilities:
  log_prob_bug = jax.scipy.special.logsumexp(log_probs_at_bug)
  log_prob_repair = jax.scipy.special.logsumexp(
      jax.scipy.special.logsumexp(joint_log_probs, axis=0) +
      jnp.log(padded_example.repair_node_mask))

  # Conditional log probabilities:
  log_prob_repair_given_bug = log_prob_joint - log_prob_bug
  log_prob_bug_given_repair = log_prob_joint - log_prob_repair

  # Majority accuracy (1 if we assign the correct tuple > 50%):
  # (note that this is easier to compute, since we can't currently aggregate
  # probability separately for each candidate.)
  log_half = jnp.log(0.5)
  majority_acc_joint = log_prob_joint > log_half

  # Probabilities associated with each node.
  node_node_probs = jnp.exp(joint_log_probs)
  # Accumulate across unique candidates by identifier. This has the same shape,
  # but only the first few values will be populated.
  node_candidate_probs = padded_example.unique_candidate_operator.apply_add(
      in_array=node_node_probs,
      out_array=jnp.zeros_like(node_node_probs),
      in_dims=[1],
      out_dims=[1])

  # Classify: 50% decision boundary
  # Row 0 and column 0 are zeroed out; index 0 is treated as "no bug" below
  # (see `actual_nobug`).
  only_buggy_probs = node_candidate_probs.at[0, :].set(0).at[:, 0].set(0)
  p_buggy = jnp.sum(only_buggy_probs)
  pred_nobug = p_buggy <= 0.5

  # Localize/repair: take most likely bug position, conditioned on being buggy
  pred_bug_loc, pred_cand_id = jnp.unravel_index(
      jnp.argmax(only_buggy_probs), only_buggy_probs.shape)

  actual_nobug = jnp.array(padded_example.bug_node_index == 0)

  actual_bug = jnp.logical_not(actual_nobug)
  pred_bug = jnp.logical_not(pred_nobug)

  metrics = {
      'nll/joint':
          -log_prob_joint,
      'nll/marginal_bug':
          -log_prob_bug,
      'nll/marginal_repair':
          -log_prob_repair,
      'nll/repair_given_bug':
          -log_prob_repair_given_bug,
      'nll/bug_given_repair':
          -log_prob_bug_given_repair,
      'inaccuracy/legacy_overall':
          1 - majority_acc_joint,
      'inaccuracy/overall':
          (~((actual_nobug & pred_nobug) |
             (actual_bug & pred_bug &
              (pred_bug_loc == padded_example.bug_node_index) &
              (pred_cand_id == padded_example.repair_id)))),
      'inaccuracy/classification_overall': (actual_nobug != pred_nobug),
      'inaccuracy/classification_given_nobug':
          train_util.RatioMetric(
              numerator=(actual_nobug & ~pred_nobug), denominator=actual_nobug),
      'inaccuracy/classification_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug & ~pred_bug), denominator=actual_bug),
      'inaccuracy/localized_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_bug_loc == padded_example.bug_node_index)),
              denominator=actual_bug),
      'inaccuracy/repaired_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_cand_id == padded_example.repair_id)),
              denominator=actual_bug),
      'inaccuracy/localized_repaired_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~((pred_bug_loc == padded_example.bug_node_index) &
                             (pred_cand_id == padded_example.repair_id))),
              denominator=actual_bug),
      'inaccuracy/overall_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_bug &
                             (pred_bug_loc == padded_example.bug_node_index) &
                             (pred_cand_id == padded_example.repair_id))),
              denominator=actual_bug),
  }

  # The base loss is the negative joint log-likelihood of (bug, repair).
  loss = -log_prob_joint

  # Report every scalar side output as a metric.
  for k, v in collected_side_outputs.items():
    # Flax collection keys will start with "/".
    if v.shape == ():  # pylint: disable=g-explicit-bool-comparison
      metrics['side' + k] = v

  if regularization_weights:
    total_regularization = 0
    for query, weight in regularization_weights.items():
      logging.info('Regularizing side outputs matching query %s', query)
      found = False
      for k, v in collected_side_outputs.items():
        if re.search(query, k):
          found = True
          logging.info('Regularizing %s with weight %f', k, weight)
          total_regularization += weight * v
      # A query that matches nothing is treated as a configuration error.
      if not found:
        raise ValueError(
            f'Regularization query {query} did not match any side output. '
            f'Side outputs were {set(collected_side_outputs.keys())}')

    loss = loss + total_regularization

  # The REINFORCE terms are added only when the model emitted the
  # single-sample log-probability side output.
  is_single_sample = any(
      k.endswith('one_sample_log_prob_per_edge_per_node')
      for k in collected_side_outputs)
  if is_single_sample:
    # The single-element unpacking requires exactly one matching side output.
    log_prob, = [
        v for k, v in collected_side_outputs.items()
        if k.endswith('one_sample_log_prob_per_edge_per_node')
    ]
    baseline, = [
        v for k, v in collected_side_outputs.items()
        if k.endswith('one_sample_reward_baseline')
    ]

    # Zero out the log-probabilities belonging to padding nodes.
    num_real_nodes = padded_example.input_graph.bundle.graph_metadata.num_nodes
    valid_mask = (
        jnp.arange(static_metadata.bundle_padding.static_max_metadata.num_nodes)
        < num_real_nodes)
    log_prob = jnp.where(valid_mask[None, :], log_prob, 0)
    total_log_prob = jnp.sum(log_prob)

    reinforce_virtual_cost = (
        total_log_prob * jax.lax.stop_gradient(loss - baseline))
    baseline_penalty = jnp.square(loss - baseline)

    # Subtracting the stop-gradient copy makes this term zero in value while
    # keeping its gradient, so it shapes gradients without changing the loss.
    reinforce_virtual_cost_zeroed = reinforce_virtual_cost - jax.lax.stop_gradient(
        reinforce_virtual_cost)

    loss = (
        loss + reinforce_weight * reinforce_virtual_cost_zeroed +
        baseline_weight * baseline_penalty)
    metrics['reinforce_virtual_cost'] = reinforce_virtual_cost
    metrics['baseline_penalty'] = baseline_penalty
    metrics['baseline'] = baseline
    metrics['total_log_prob'] = total_log_prob

  # Cast every metric leaf (including boolean ones) to float32.
  metrics = jax.tree_map(lambda x: x.astype(jnp.float32), metrics)
  return loss, metrics
                      bias_coef=args.bias_coef,
                      activation=args.activation,
                      norm=args.norm)
    _, params = net_init(rng=random.PRNGKey(42), input_shape=(-1, 2))

else:
    raise ValueError

# Loss (and, for classification, accuracy) selected by dataset.
if args.dataset == 'sinusoid':

    def loss(fx, y_hat):
        """Half mean-squared error between outputs ``fx`` and ``y_hat``."""
        return 0.5 * np.mean((fx - y_hat)**2)

elif args.dataset in ['omniglot', 'circle']:

    def loss(fx, targets):
        """Negative sum of ``logsoftmax(fx) * targets`` per leading-axis row."""
        weighted_log_probs = logsoftmax(fx) * targets
        return -np.sum(weighted_log_probs) / targets.shape[0]

    def acc(fx, targets):
        """Fraction of rows where the argmax of outputs and targets agree."""
        guessed = np.argmax(logsoftmax(fx), axis=-1)
        expected = np.argmax(targets, axis=-1)
        return np.mean(guessed == expected)

    param_acc = jit(lambda p, x, y: acc(f(p, x), y))
else:
    raise ValueError

# Parameter-level wrappers: run the network ``f`` and then apply the loss.
param_loss = jit(lambda p, x, y: loss(f(p, x), y))
grad_loss = jit(grad(lambda p, x, y: loss(f(p, x), y)))

# optimizers    #TODO: separate optimizers for nonlinear and linear?
outer_opt_factory = select_opt(args.outer_opt_alg, args.outer_step_size)
outer_opt_init, outer_opt_update, outer_get_params = outer_opt_factory()
inner_opt_factory = select_opt(args.inner_opt_alg, args.inner_step_size)
inner_opt_init, inner_opt_update, inner_get_params = inner_opt_factory()

# consistent task for plotting eval
if args.dataset == 'sinusoid':
示例#14
0
def collect_trajectories(env,
                         policy_net_apply,
                         policy_net_params,
                         num_trajectories=1,
                         policy="greedy",
                         max_timestep=None,
                         epsilon=0.1):
    """Collect trajectories with the given policy net and behaviour.

    Args:
      env: Environment exposing ``reset()`` and ``step(action)``; ``step``
        must return a 4-tuple ``(observation, reward, done, info)``.
      policy_net_apply: Callable invoked as
        ``policy_net_apply(observation_history, policy_net_params)``.
      policy_net_params: Parameters forwarded to ``policy_net_apply``.
      num_trajectories: Number of episodes to roll out.
      policy: One of ``"greedy"``, ``"epsilon-greedy"`` or
        ``"categorical-sampling"``.
      max_timestep: If truthy, an episode stops once the observation history
        holds this many entries, even if the environment is not done.
      epsilon: Probability of a random action for ``"epsilon-greedy"``.

    Returns:
      A list with one ``(observations, actions, rewards)`` tuple per
      trajectory; ``observations`` has the batch dimension squeezed out and
      ``actions`` / ``rewards`` are stacked over time.

    Raises:
      ValueError: If ``policy`` is not one of the supported names.
      TypeError: If the selected action cannot be converted to ``int``
        (diagnostics are logged before re-raising).
    """
    # NOTE(review): `np` and `onp` are both used here; presumably `np` is
    # jax.numpy and `onp` is the original NumPy -- confirm against the
    # file's imports.
    trajectories = []

    for t in range(num_trajectories):
        t_start = time.time()
        rewards = []
        actions = []
        done = False

        observation = env.reset()

        # This is currently shaped (1, 1) + OBS, but new observations will keep
        # getting added to it, making it eventually (1, T+1) + OBS
        observation_history = observation[np.newaxis, np.newaxis, :]

        # Run either till we're done OR if max_timestep is defined only till that
        # timestep.
        ts = 0
        while ((not done)
               and (not max_timestep
                    or observation_history.shape[1] < max_timestep)):
            ts_start = time.time()
            # Run the policy, to pick an action, shape is (1, t, A) because
            # observation_history is shaped (1, t) + OBS
            predictions = policy_net_apply(observation_history,
                                           policy_net_params)

            # We need the predictions for the last time-step, so squeeze the batch
            # dimension and take the last time-step.
            predictions = np.squeeze(predictions, axis=0)[-1]

            # Policy can be run in one of the following ways:
            #  - Greedy
            #  - Epsilon-Greedy
            #  - Categorical-Sampling
            action = None
            if policy == "greedy":
                action = np.argmax(predictions)
            elif policy == "epsilon-greedy":
                # A schedule for epsilon is 1/k where k is the episode number sampled.
                if onp.random.random() < epsilon:
                    # Choose an action at random.
                    action = onp.random.randint(0, high=len(predictions))
                else:
                    # Return the best action.
                    action = np.argmax(predictions)
            elif policy == "categorical-sampling":
                # NOTE: The predictions aren't probabilities but log-probabilities
                # instead, since they were computed with LogSoftmax.
                # So just np.exp them to make them probabilities.
                predictions = np.exp(predictions)
                # Draw one sample and recover the index of the chosen category.
                action = onp.argwhere(
                    onp.random.multinomial(1, predictions) == 1)
            else:
                raise ValueError("Unknown policy: %s" % policy)

            # NOTE: Assumption, single batch.
            try:
                action = int(action)
            except TypeError as err:
                # Let's dump some information before we die off.
                logging.error("Cannot convert action into an integer: [%s]",
                              err)
                logging.error("action.shape: [%s]", action.shape)
                logging.error("action: [%s]", action)
                logging.error("predictions.shape: [%s]", predictions.shape)
                logging.error("predictions: [%s]", predictions)
                logging.error("observation_history: [%s]", observation_history)
                logging.error("policy_net_params: [%s]", policy_net_params)
                log_params(policy_net_params, "policy_net_params")
                raise err

            observation, reward, done, _ = env.step(action)

            # observation is of shape OBS, so add extra dims and concatenate on the
            # time dimension.
            observation_history = np.concatenate(
                [observation_history, observation[np.newaxis, np.newaxis, :]],
                axis=1)

            rewards.append(reward)
            actions.append(action)

            ts += 1
            logging.vlog(
                2,
                "  Collected time-step[ %5d] of trajectory[ %5d] in [%0.2f] msec.",
                ts, t, get_time(ts_start))
        logging.vlog(2, " Collected trajectory[ %5d] in [%0.2f] msec.", t,
                     get_time(t_start))

        # Either the episode finished, or it was cut short at max_timestep.
        assert done or (max_timestep
                        and max_timestep >= observation_history.shape[1])
        # observation_history is (1, T+1) + OBS, lets squeeze out the batch dim.
        observation_history = np.squeeze(observation_history, axis=0)
        trajectories.append(
            (observation_history, np.stack(actions), np.stack(rewards)))

    return trajectories
示例#15
0
 def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
     """Mean agreement between the network's argmax class and ``batch["label"]``."""
     scores = net.apply(params, batch)
     guessed = jnp.argmax(scores, axis=-1)
     return jnp.mean(guessed == batch["label"])
示例#16
0
    def __call__(self, inputs, is_training):
        """Connects the module to some inputs.

        Args:
          inputs: Tensor, final dimension must be equal to ``embedding_dim``.
            All other leading dimensions will be flattened and treated as a
            large batch.
          is_training: boolean, whether this connection is to training data.
            When this is set to ``False``, the internal moving average
            statistics will not be updated.

        Returns:
          dict: Dictionary containing the following keys and values:
            * ``quantize``: Tensor containing the quantized version of the
              input.
            * ``loss``: Tensor containing the loss to optimize.
            * ``perplexity``: Tensor containing the perplexity of the
              encodings.
            * ``encodings``: Tensor containing the discrete encodings, ie
              which element of the quantized space each input element was
              mapped to.
            * ``encoding_indices``: Tensor containing the discrete encoding
              indices, ie which element of the quantized space each input
              element was mapped to.
            * ``distances``: Tensor containing the squared distance between
              each flattened input and each embedding.
        """
        flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
        embeddings = self.embeddings

        # Squared Euclidean distance via ||x||^2 - 2 x.e + ||e||^2; the
        # embeddings are stored with one embedding per column.
        distances = (jnp.sum(flat_inputs**2, 1, keepdims=True) -
                     2 * jnp.matmul(flat_inputs, embeddings) +
                     jnp.sum(embeddings**2, 0, keepdims=True))

        # Nearest embedding = smallest distance = largest negated distance.
        encoding_indices = jnp.argmax(-distances, 1)
        encodings = jax.nn.one_hot(encoding_indices,
                                   self.num_embeddings,
                                   dtype=distances.dtype)

        # NB: if your code crashes with a reshape error on the line below about a
        # Tensor containing the wrong number of values, then the most likely cause
        # is that the input passed in does not have a final dimension equal to
        # self.embedding_dim. Ideally we would catch this with an Assert but that
        # creates various other problems related to device placement / TPUs.
        encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
        quantized = self.quantize(encoding_indices)
        # Commitment term: gradients flow to the inputs only, not to the
        # quantized values.
        e_latent_loss = jnp.mean(
            (jax.lax.stop_gradient(quantized) - inputs)**2)

        if is_training:
            # Number of inputs assigned to each embedding in this batch.
            cluster_size = jnp.sum(encodings, axis=0)
            if self.cross_replica_axis:
                cluster_size = jax.lax.psum(cluster_size,
                                            axis_name=self.cross_replica_axis)
            updated_ema_cluster_size = self.ema_cluster_size(cluster_size)

            # Per-embedding sum of the inputs assigned to it.
            dw = jnp.matmul(flat_inputs.T, encodings)
            if self.cross_replica_axis:
                dw = jax.lax.psum(dw, axis_name=self.cross_replica_axis)
            updated_ema_dw = self.ema_dw(dw)

            # Smooth the cluster sizes with epsilon so that the division
            # below never sees a zero count; the total n is preserved.
            n = jnp.sum(updated_ema_cluster_size)
            updated_ema_cluster_size = (
                (updated_ema_cluster_size + self.epsilon) /
                (n + self.num_embeddings * self.epsilon) * n)

            normalised_updated_ema_w = (
                updated_ema_dw /
                jnp.reshape(updated_ema_cluster_size, [1, -1]))

            # The embeddings are updated through module state rather than by
            # gradient descent.
            hk.set_state("embeddings", normalised_updated_ema_w)
            loss = self.commitment_cost * e_latent_loss

        else:
            loss = self.commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
        avg_probs = jnp.mean(encodings, 0)
        if self.cross_replica_axis:
            avg_probs = jax.lax.pmean(avg_probs,
                                      axis_name=self.cross_replica_axis)
        # exp of the entropy of the average code usage; 1e-10 avoids log(0).
        perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

        return {
            "quantize": quantized,
            "loss": loss,
            "perplexity": perplexity,
            "encodings": encodings,
            "encoding_indices": encoding_indices,
            "distances": distances,
        }
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -np.mean(np.sum(preds * targets, axis=1))


# Locate the first test example whose one-hot label has its 1 at index 7.
tl = test_labels
index7 = tl.tolist().index([0, 0, 0, 0, 0, 0, 0, 1, 0, 0])
print(test_labels[index7])
# Build the perturbed input: take the gradient of `computation` with respect
# to its argument 1 (the input image) and step by `hyper` along its sign.
# NOTE(review): presumably `computation` is the classification loss, making
# this a fast-gradient-sign perturbation -- confirm.
input_image, input_label = shape_as_image(test_images[index7],
                                          test_labels[index7])
grad_newx = grad(computation, 1)(params, input_image, input_label)
newx = input_image + hyper * np.sign(grad_newx)
# Compare the true class with the class predicted for the perturbed image.
target_class = np.argmax(input_label)
predicted_class = np.argmax(predict(params, newx))
# Full output vector of the model for the perturbed image.
predict_vector = predict(params, newx)
print('the target class is :', target_class)
print('the predict class is :', predicted_class)
print('the predicted vector is :', predict_vector)

# Show the perturbed image: scale by 255 and reshape to 28x28.
image = np.array(newx)
image = image * 255
image = image.reshape(28, 28)
plt.imshow(image)
"""## From here is Part 2"""

# Part 2 setup: take the first 1000 test images.
images = np.array(test_images[0:1000])
 def accuracy(y_true, y_pred):
     """Fraction of entries where the argmax of ``y_pred`` (last axis) equals ``y_true``."""
     guessed = jnp.argmax(y_pred, axis=-1)
     return jnp.mean(guessed == y_true)
示例#19
0
    num_complete_batches, leftover = divmod(X.shape[0], batch_size)
    num_batches = num_complete_batches + bool(leftover)
    while True:
        temp, rng = random.split(rng)
        perm = random.permutation(temp, X.shape[0])
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield X[batch_idx], y[batch_idx]


if __name__ == "__main__":
    rng = random.PRNGKey(0)

    # Load MNIST and reshape the images to (N, 28, 28, 1).
    X, y, X_test, y_test = mnist()
    X, X_test = X.reshape(-1, 28, 28, 1), X_test.reshape(-1, 28, 28, 1)
    # Binary task: 1.0 for odd digits, 0.0 for even digits. The test labels
    # previously used `% 1`, which is always 0 and so labelled every test
    # example as 0.0; both splits must use `% 2`.
    y = (np.argmax(y, 1) % 2 == 1).astype(np.float32)
    y_test = (np.argmax(y_test, 1) % 2 == 1).astype(np.float32)

    temp, rng = random.split(rng)
    params, predict = model(temp)

    def loss(params, batch, l2=0.05):
        """Mean binary cross-entropy of ``predict`` on one ``(X, y)`` batch.

        ``l2`` is accepted for interface compatibility but is not used in
        the computation.
        """
        X, y = batch
        y_hat = predict(params, X).reshape(-1)
        return -np.mean(y * np.log(y_hat) + (1. - y) * np.log(1. - y_hat))

    @jit
    def update(i, opt_state, batch):
        """Apply one optimizer step using the gradient of ``loss`` on ``batch``."""
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)
示例#20
0
 def accuracy(trainable_params, untrainable_params, batch):
     """Accuracy of the network on a batch, using the merged parameter sets.

     Target and predicted classes are the argmax along axis 1 of the targets
     and of the network output respectively.
     """
     inputs, targets = batch
     expected = jnp.argmax(targets, axis=1)
     merged = merge_params(trainable_params, untrainable_params)
     scores = net.apply(merged, inputs)
     guessed = jnp.argmax(scores, axis=1)
     return jnp.mean(guessed == expected)
示例#21
0
def test_change_point_x64():
    """Check that NUTS locates the change point of a Poisson count series.

    The model draws two exponential rates and a switch fraction ``tau``; the
    observations use the first rate before ``tau * len(data)`` and the second
    rate afterwards.  The test asserts that the most frequent posterior
    switch index is 44 and, when ``JAX_ENABLE_X64`` is present in the
    environment, that all sampled sites are float64.
    """
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        rate = 1 / np.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(rate))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(rate))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        # Positions before the switch point use lambda1, the rest lambda2.
        before_switch = np.arange(len(data)) < tau * len(data)
        lambda12 = np.where(before_switch, lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = np.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11,
        22, 22, 11, 57, 11, 19, 29, 6, 19, 12,
        22, 12, 18, 72, 32, 9, 7, 13, 19, 23,
        27, 20, 6, 17, 13, 10, 14, 6, 16, 15,
        7, 2, 15, 15, 19, 70, 49, 7, 53, 22,
        21, 31, 19, 11, 18, 20, 12, 35, 17, 23,
        17, 4, 2, 31, 30, 13, 27, 0, 39, 37,
        5, 14, 13, 22,
    ])

    sampler = MCMC(NUTS(model=model), warmup_steps, num_samples)
    sampler.run(random.PRNGKey(4), count_data)
    samples = sampler.get_samples()

    # Convert the switch fraction into an integer index and find its mode.
    switch_index = (samples['tau'] * len(count_data)).astype(np.int32)
    tau_values, counts = onp.unique(switch_index, return_counts=True)
    most_common = tau_values[np.argmax(counts)]
    assert most_common == 44

    if 'JAX_ENABLE_X64' in os.environ:
        for site in ('lambda1', 'lambda2', 'tau'):
            assert samples[site].dtype == np.float64
示例#22
0
def sample_categorical(key, logits):
    """Sample category indices from ``logits`` by adding Gumbel noise.

    Args:
      key: JAX PRNG key used to draw the uniform noise.
      logits: Unnormalized log-probabilities; categories lie on the last axis.

    Returns:
      The arg-max over the last axis of ``logits`` plus the Gumbel noise.
    """
    tiny = jnp.finfo(logits.dtype).tiny
    # Uniform draws are bounded below by ``tiny`` so the inner log is finite.
    uniform = jax.random.uniform(key, logits.shape, minval=tiny, maxval=1.0)
    neg_log_uniform = -jnp.log(uniform)
    # ``tiny`` is added again so the outer log never sees an exact zero.
    noise = -jnp.log(neg_log_uniform + tiny)
    return jnp.argmax(logits + noise, -1)
示例#23
0
def gumbel_max_sampler(logits_1, logits_2, rng):
    """Samples from a Gumbel-max coupling.

    The same Gumbel noise is added to both sets of logits, and the two
    resulting arg-max indices select the single non-zero cell of the output.

    Args:
      logits_1: Logits for the first variable; its shape sets the noise shape.
      logits_2: Logits for the second variable.
      rng: JAX PRNG key.

    Returns:
      A 10x10 array of zeros with a 1 at ``[x, y]``, where ``x`` and ``y`` are
      the sampled indices for ``logits_1`` and ``logits_2``.
    """
    shared_noise = jax.random.gumbel(rng, logits_1.shape)
    row = jnp.argmax(shared_noise + logits_1)
    col = jnp.argmax(shared_noise + logits_2)
    coupling = jnp.zeros([10, 10])
    return coupling.at[row, col].set(1.)
示例#24
0
def main(_):
    """Fine-tune a checkpointed model on selected pool points and evaluate it.

    Steps, in order:
      1. Create ``FLAGS.work_dir`` and open ``stdout.log`` inside it.
      2. Load the train/test/pool splits and normalize the test and pool
         images with the mean/std of the train split.
      3. Restore the optimizer state stored at
         ``FLAGS.root_dir/FLAGS.ckpt_idx`` (prepending the first 7 entries of
         the pretrained parameters when ``FLAGS.pretrained_dir`` is set) and
         log the initial test accuracy.
      4. Select ``FLAGS.n_extra`` pool points by entropy, by the gap between
         the two largest class probabilities, or at random, depending on
         ``FLAGS.uncertain``.
      5. Fine-tune on those points for ``FLAGS.epochs`` epochs (using the
         private update when ``FLAGS.dpsgd`` is set) and log the test
         loss/accuracy after every epoch.

    Args:
      _: Unused positional argument.

    Raises:
      ValueError: If ``FLAGS.uncertain`` is not one of the supported values.
    """
    sns.set()
    sns.set_palette(sns.color_palette('hls', 10))
    npr.seed(FLAGS.seed)

    logging.info('Starting experiment.')

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.work_dir)
    except gfile.GOSError:
        # Best effort: an error here is ignored and surfaces on Open below.
        pass
    # NOTE(review): this handle is only closed on the success path at the end
    # of the function.
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

    # use mean/std of svhn train
    train_images, _, _ = datasets.get_dataset_split(
        name=FLAGS.train_split.split('-')[0],
        split=FLAGS.train_split.split('-')[1],
        shuffle=False)
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    del train_images

    # BEGIN: fetch test data and candidate pool
    test_images, test_labels, _ = datasets.get_dataset_split(
        name=FLAGS.test_split.split('-')[0],
        split=FLAGS.test_split.split('-')[1],
        shuffle=False)
    pool_images, pool_labels, _ = datasets.get_dataset_split(
        name=FLAGS.pool_split.split('-')[0],
        split=FLAGS.pool_split.split('-')[1],
        shuffle=False)

    n_pool = len(pool_images)
    test_images = (test_images -
                   train_mu) / train_std  # normalize w train mu/std
    pool_images = (pool_images -
                   train_mu) / train_std  # normalize w train mu/std

    # augmentation for train/pool data
    if FLAGS.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    # END: fetch test data and candidate pool

    # BEGIN: load ckpt
    opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    # The checkpoint is a pickle, so it must be opened for binary *reading*
    # ('rb'); a write mode cannot be unpickled from and risks clobbering it.
    ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
    with gfile.Open(ckpt_dir, 'rb') as fckpt:
        opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
    params = get_params(opt_state)

    if FLAGS.pretrained_dir is not None:
        with gfile.Open(FLAGS.pretrained_dir, 'rb') as fpre:
            pretrained_opt_state = optimizers.pack_optimizer_state(
                pickle.load(fpre))
        fixed_params = get_params(pretrained_opt_state)[:7]
        # combine fixed pretrained params and dpsgd trained last layers
        params = fixed_params + params
        opt_state = opt_init(params)

    stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
    logging.info('finetune from: %s', ckpt_dir)
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    logging.info('test accuracy: %.2f', test_acc)
    stdout_log.write('test accuracy: {}\n'.format(test_acc))
    stdout_log.flush()
    # END: load ckpt

    # BEGIN: setup for dp model
    @jit
    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad_loss(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

    # END: setup for dp model

    ### BEGIN: prepare extra points picked from pool data
    # BEGIN: on pool data
    # Embed the pool in batches with all layers but the last, then apply the
    # last layer to the concatenated embeddings.
    pool_embeddings = [
        apply_fn_0(params[:-1], pool_images[b_i:b_i + FLAGS.batch_size])
        for b_i in range(0, n_pool, FLAGS.batch_size)
    ]
    pool_embeddings = np.concatenate(pool_embeddings, axis=0)

    pool_logits = apply_fn_1(params[-1:], pool_embeddings)

    pool_true_labels = np.argmax(pool_labels, axis=1)
    pool_predicted_labels = np.argmax(pool_logits, axis=1)
    pool_correct_indices = \
        onp.where(pool_true_labels == pool_predicted_labels)[0]
    pool_incorrect_indices = \
        onp.where(pool_true_labels != pool_predicted_labels)[0]
    assert len(pool_correct_indices) + \
        len(pool_incorrect_indices) == len(pool_labels)

    pool_probs = stax.softmax(pool_logits)

    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
        stdout_log.write('all {} entropy: min {}, max {}\n'.format(
            len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

        pool_entropy_sorted_indices = onp.argsort(pool_entropy)
        # take the n_extra most uncertain points
        pool_uncertain_indices = \
            pool_entropy_sorted_indices[::-1][:FLAGS.n_extra]
        stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
            len(pool_entropy[pool_uncertain_indices]),
            onp.min(pool_entropy[pool_uncertain_indices]),
            onp.max(pool_entropy[pool_uncertain_indices])))

    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(pool_probs.shape) == 2
        sorted_pool_probs = onp.sort(pool_probs, axis=1)
        pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
        assert min(pool_probs_diff) > 0.
        stdout_log.write('all {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff), onp.min(pool_probs_diff),
            onp.max(pool_probs_diff)))

        # the smallest gaps are the most uncertain points
        pool_uncertain_indices = onp.argsort(pool_probs_diff)[:FLAGS.n_extra]
        stdout_log.write('uncertain {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff[pool_uncertain_indices]),
            onp.min(pool_probs_diff[pool_uncertain_indices]),
            onp.max(pool_probs_diff[pool_uncertain_indices])))

    elif FLAGS.uncertain == 2 or FLAGS.uncertain == 'random':
        pool_uncertain_indices = npr.permutation(n_pool)[:FLAGS.n_extra]

    else:
        # Fail here with a clear message instead of a NameError on
        # ``pool_uncertain_indices`` further down.
        raise ValueError('Unsupported FLAGS.uncertain value: {!r}'.format(
            FLAGS.uncertain))

    # END: on pool data
    ### END: prepare extra points picked from pool data

    finetune_images = copy.deepcopy(pool_images[pool_uncertain_indices])
    finetune_labels = copy.deepcopy(pool_labels[pool_uncertain_indices])

    stdout_log.write('Starting fine-tuning...\n')
    logging.info('Starting fine-tuning...')
    stdout_log.flush()

    stdout_log.write('{} points picked via {}\n'.format(
        len(finetune_images), FLAGS.uncertain))
    logging.info('%d points picked via %s', len(finetune_images),
                 FLAGS.uncertain)
    assert FLAGS.n_extra == len(finetune_images)

    for epoch in range(1, FLAGS.epochs + 1):

        # BEGIN: finetune model with extra data, evaluate and save
        num_extra = len(finetune_images)
        num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
        num_batches = num_complete_batches + bool(leftover)

        finetune = data.DataChunk(X=finetune_images,
                                  Y=finetune_labels,
                                  image_size=32,
                                  image_channels=3,
                                  label_dim=1,
                                  label_format='numeric')

        batches = data.minibatcher(finetune,
                                   FLAGS.batch_size,
                                   transform=augmentation)

        # NOTE(review): the step counter and PRNG key are re-created every
        # epoch, so ``private_update`` folds the same (key, step) pairs in
        # each epoch -- confirm this is intended.
        itercount = itertools.count()
        key = random.PRNGKey(FLAGS.seed)

        start_time = time.time()

        for _ in range(num_batches):
            # tmp_time = time.time()
            b = next(batches)
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(b.X, b.Y, dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(b.X, b.Y))
            # stdout_log.write('single update in {:.2f} sec\n'.format(
            #     time.time() - tmp_time))

        epoch_time = time.time() - start_time
        stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
        logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

        # accuracy on test data
        params = get_params(opt_state)

        # keep the previous predictions so the two checkpoints can be compared
        test_pred_0 = test_pred
        test_acc, test_pred = accuracy(params,
                                       shape_as_image(test_images,
                                                      test_labels),
                                       return_predicted_class=True)
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        stdout_log.write(
            'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
                test_loss, 100 * test_acc))
        logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                     100 * test_acc)
        stdout_log.flush()

        # visualize prediction difference between 2 checkpoints.
        if FLAGS.visualize:
            utils.visualize_ckpt_difference(test_images,
                                            np.argmax(test_labels, axis=1),
                                            test_pred_0,
                                            test_pred,
                                            epoch - 1,
                                            epoch,
                                            FLAGS.work_dir,
                                            mu=train_mu,
                                            sigma=train_std)
        # END: finetune model with extra data, evaluate and save

    stdout_log.close()
示例#25
0
def accuracy(params, batch):
  """Return the fraction of examples in ``batch`` that ``predict`` gets right.

  Args:
    params: Model parameters passed to ``predict``.
    batch: ``(inputs, targets)`` pair; the class axis of ``targets`` and of
      the model output is axis 1.

  Returns:
    Mean agreement between predicted and target class indices.
  """
  inputs, targets = batch
  expected = np.argmax(targets, axis=1)
  outputs = predict(params, inputs)
  hits = np.argmax(outputs, axis=1) == expected
  return np.mean(hits)
示例#26
0
def _accuracy(y, y_hat):
    """Compute the accuracy of the predictions with respect to one-hot labels."""
    return np.mean(np.argmax(y, axis=1) == np.argmax(y_hat, axis=1))
示例#27
0
def _gumbel_max(rng, logit_probs):
    return np.argmax(random.gumbel(rng, logit_probs.shape, logit_probs.dtype) +
                     logit_probs,
                     axis=0)
示例#28
0
 def mode(self):
     """Return the index of the largest entry of ``self._probs`` (last axis)."""
     most_likely = jp.argmax(self._probs, axis=-1)
     return most_likely
示例#29
0
def accuracy(params, images, targets):
    """Return the fraction of ``images`` whose predicted class is correct.

    Args:
      params: Model parameters passed to ``batched_predict``.
      images: Batch of inputs passed to ``batched_predict``.
      targets: Per-class targets; the class axis is axis 1.

    Returns:
      Mean agreement between predicted and target class indices.
    """
    expected = jnp.argmax(targets, axis=1)
    outputs = batched_predict(params, images)
    hits = jnp.argmax(outputs, axis=1) == expected
    return jnp.mean(hits)
示例#30
0
 def sample(self, sample_shape, seed):
     """Draw category indices by adding Gumbel noise to ``self._logits``.

     Args:
       sample_shape: Leading shape of the draw; it is concatenated with the
         shape of ``self._logits`` to size the noise.
       seed: JAX PRNG key for the Gumbel noise.

     Returns:
       Arg-max over the last axis of the noise-perturbed logits.
     """
     noise_shape = sample_shape + self._logits.shape
     gumbel_noise = jax.random.gumbel(seed, shape=noise_shape)
     return jp.argmax(self._logits + gumbel_noise, axis=-1)