示例#1
0
 def __init__(self, input_shape, dict_size=(-1., 1., 20), gamma=None):
     """Set up the kernel dictionary, its bandwidth and the mixing coefficients.

     Args:
       input_shape: Size of the middle dimension of the `alpha` variable;
         presumably the number of input units -- TODO confirm with callers.
       dict_size: Arguments unpacked into `tf.linspace`; with the default
         3-tuple this is (start, stop, num) for the dictionary points `self.d`.
       gamma: Kernel bandwidth. When None it is derived from the dictionary
         span as 0.5 / (2 * (d[-1] - d[0])) ** 2.
     """
     dictionary = tf.linspace(*dict_size)
     self.d = dictionary

     if gamma is not None:
         self.gamma = gamma
     else:
         # Bandwidth derived from twice the dictionary span (stop - start).
         span = dictionary[-1] - dictionary[0]
         self.gamma = .5 / tf.square(2 * span)

     num_points = dictionary.get_shape()[0]
     # NOTE(review): gauss_kernel / RidgeInit are defined elsewhere; this is
     # presumably the kernel (Gram) matrix of the dictionary against itself,
     # which RidgeInit uses to pick the initial alpha -- confirm.
     gram = gauss_kernel(dictionary, dictionary, self.gamma)
     self.alpha = tf.get_variable(
         'alpha',
         shape=(1, input_shape, num_points),
         initializer=RidgeInit(gram, dictionary))
示例#2
0
def kaf(linear, name, kernel='rbf', D=None, gamma=None):
    """Apply a kernel activation function (KAF) to `linear`.

    The activation is a learned combination of kernel evaluations against a
    fixed dictionary `D`: the kernel tensor `K` is multiplied elementwise by
    a trainable coefficient variable `alpha` and summed over its last axis.

    Args:
      linear: Pre-activation tensor. Its static last dimension sizes the
        middle dimension of `alpha`; presumably (batch, units) -- TODO confirm.
      name: Name of the `alpha` variable, created (or reused, because of
        `tf.AUTO_REUSE`) under the 'kaf' variable scope.
      kernel: 'rbf' for `gauss_kernel` over `D`, or 'rbf2d' for
        `gauss_kernel2D` over the `D x D` meshgrid. In the 2-D case `alpha`
        has `units // 2` rows (presumably units are consumed in pairs --
        TODO confirm against gauss_kernel2D) and `len(D) ** 2` columns.
      D: 1-D dictionary tensor. Defaults to 20 points evenly spaced in [-2, 2].
      gamma: Kernel bandwidth, forwarded unchanged to the kernel helper.

    Returns:
      `K * alpha` reduced over its last axis.

    Raises:
      NotImplementedError: If `kernel` is neither 'rbf' nor 'rbf2d'.
    """
    # Fail fast: reject unsupported kernels before any graph ops (default
    # dictionary, variable scope) are created.
    if kernel not in ('rbf', 'rbf2d'):
        raise NotImplementedError(
            "Unsupported KAF kernel {!r}; expected 'rbf' or 'rbf2d'.".format(kernel))

    if D is None:
        D = tf.linspace(start=-2., stop=2., num=20)

    with tf.variable_scope('kaf', reuse=tf.AUTO_REUSE):
        if kernel == 'rbf':
            K = gauss_kernel(linear, D, gamma=gamma)
            alpha = tf.get_variable(name, shape=(1, linear.get_shape()[-1], D.get_shape()[0]),
                                    initializer=tf.random_normal_initializer(stddev=0.1))
        else:  # kernel == 'rbf2d' (validated above)
            Dx, Dy = tf.meshgrid(D, D)
            K = gauss_kernel2D(linear, Dx, Dy, gamma=gamma)
            # One coefficient per point of the D x D grid.
            alpha = tf.get_variable(name,
                                    shape=(1, linear.get_shape()[-1] // 2, D.get_shape()[0] * D.get_shape()[0]),
                                    initializer=tf.random_normal_initializer(stddev=0.1))
        act = tf.reduce_sum(tf.multiply(K, alpha), axis=-1)
    return act
示例#3
0
  def __init__(self,
               num_actions=None,
               observation_size=None,
               num_players=None,
               num_atoms=51,
               vmax=25.,
               gamma=0.99,
               update_horizon=1,
               min_replay_history=500,
               update_period=4,
               target_update_period=500,
               epsilon_train=0.0,
               epsilon_eval=0.0,
               epsilon_decay_period=1000,
               learning_rate=0.000025,
               optimizer_epsilon=0.00003125,
               tf_device='/cpu:*'):
    """Create the Rainbow agent inside its own TensorFlow graph.

    A fresh `tf.Graph` is stored in `self.graph` and made the default graph
    while the return-distribution support is defined and the base agent
    constructor runs.

    Args:
      num_actions: int, number of actions available in any state.
      observation_size: int, length of the observation vector.
      num_players: int, number of players in the game.
      num_atoms: int, number of atoms (buckets) in the support of the
        return distribution.
      vmax: float, largest return representable by the distribution; the
        support spans [-vmax, vmax].
      gamma: float, discount factor.
      update_horizon: int, the 'n' of n-step updates.
      min_replay_history: int, transitions to store before training starts.
      update_period: int, steps between DQN updates.
      target_update_period: int, steps between target-network updates.
      epsilon_train: float, final exploration epsilon during training.
      epsilon_eval: float, exploration epsilon during evaluation.
      epsilon_decay_period: int, steps over which epsilon decays.
      learning_rate: float, optimizer learning rate.
      optimizer_epsilon: float, epsilon of the Adam optimizer.
      tf_device: str, TensorFlow device used for computation.
    """
    self.graph = tf.Graph()

    with self.graph.as_default():
      # Some config tools turn round floats (e.g. 25.) into ints, so coerce
      # back to float before building the support.
      vmax = float(vmax)
      self.num_atoms = num_atoms
      # The support is symmetric. Using -vmax as the lower bound is wasteful,
      # because all rewards are positive, but it does not unduly affect
      # performance.
      self.support = tf.linspace(-vmax, vmax, num_atoms)
      self.learning_rate = learning_rate
      self.optimizer_epsilon = optimizer_epsilon

      # These attributes must exist before the base constructor runs.
      base_agent_kwargs = dict(
          num_actions=num_actions,
          observation_size=observation_size,
          num_players=num_players,
          gamma=gamma,
          update_horizon=update_horizon,
          min_replay_history=min_replay_history,
          update_period=update_period,
          target_update_period=target_update_period,
          epsilon_train=epsilon_train,
          epsilon_eval=epsilon_eval,
          epsilon_decay_period=epsilon_decay_period,
          graph_template=functools.partial(rainbow_template,
                                           num_atoms=num_atoms),
          tf_device=tf_device)
      super(RainbowAgent, self).__init__(**base_agent_kwargs)

      tf.logging.info('\t learning_rate: %f', learning_rate)
      tf.logging.info('\t optimizer_epsilon: %f', optimizer_epsilon)
示例#4
0
def compute_module_criticality(
    objective_fn,
    module_variables_init,
    module_variables_final,
    num_samples_per_iteration=10,
    alpha_grid_size=10,
    sigma_grid_size=10,
    sigma_ratio=1.0,
    loss_threshold_condition=relative_error_condition,
    normalize_error=False,
):
    """Compute the criticality of a module parameterized by `module_variables`.

    A 2-D grid of (alpha, sigma) pairs is searched. For every pair,
    `_interpolate_and_perturb` yields a threshold decision, a loss and a norm.
    Among the pairs whose threshold condition holds, the one with the
    smallest norm is selected, and its norm is reported as the criticality
    score.

    Args:
      objective_fn: Callable mapping an iterable of the module-specific
        variables to the value of the objective function.
      module_variables_init: List of tf.Tensors; the module's variables at
        initialization.
      module_variables_final: List of tf.Tensors; the module's variables at
        convergence.
      num_samples_per_iteration: Number of perturbations sampled per
        iteration.
      alpha_grid_size: Number of intervals for alpha, the interpolation
        coefficient (alpha_grid_size + 1 values in [0, 1]).
      sigma_grid_size: Number of intervals for sigma, the standard deviation
        of the perturbation (sigma_grid_size + 1 values).
      sigma_ratio: Positive scalar k; tested sigmas lie in [k * 1e-16, k].
        The default 1.0 gives the interval [1e-16, 1].
      loss_threshold_condition: Callable taking a reference objective value
        (as `reference_error`) and a candidate objective value, and returning
        a thresholding decision.
      normalize_error: Whether to normalize the minimized error by the
        Frobenius norm of the distance between initial and final parameters.

    Returns:
      A `ModuleCriticalityAnalysis` named tuple with the analysis results.
    """
    initial_objective_value = objective_fn(module_variables_init)
    final_objective_value = objective_fn(module_variables_final)

    # Every grid point is judged against the objective value at convergence.
    threshold_fn = functools.partial(
        loss_threshold_condition, reference_error=final_objective_value)

    # Grid axes: alpha in [0, 1]; sigma in [1e-16, 1] scaled by sigma_ratio.
    # The tiny offset keeps sigma strictly positive.
    zero = tf.cast(0, tf.float32)
    alpha_axis = tf.linspace(zero, 1, alpha_grid_size + 1)
    sigma_axis = tf.linspace(zero + 1e-16, 1, sigma_grid_size + 1) * sigma_ratio
    alpha_mesh, sigma_mesh = tf.meshgrid(alpha_axis, sigma_axis)
    # Flatten so that each index names one (alpha, sigma) candidate.
    alphas = tf.reshape(alpha_mesh, [-1])
    sigmas = tf.reshape(sigma_mesh, [-1])

    def _evaluate_grid_point(point):
        alpha, sigma = point
        return _interpolate_and_perturb(
            alpha=alpha,
            sigma=sigma,
            params_init=module_variables_init,
            params_final=module_variables_final,
            objective_fn=objective_fn,
            loss_threshold_condition=threshold_fn,
            normalize_error=normalize_error,
            num_samples_per_iteration=num_samples_per_iteration,
        )

    conditions, losses, norms = tf.map_fn(
        _evaluate_grid_point,
        elems=(alphas, sigmas),
        dtype=(tf.bool, tf.float32, tf.float32),
    )

    # Points that fail the threshold get an infinite norm so argmin skips
    # them. NOTE(review): if no point passes, every entry is inf and argmin's
    # tie-break decides the index; the reported score is then the unmasked
    # norm at that index.
    masked_norms = tf.where(conditions, norms, tf.ones_like(norms) * np.inf)
    best = tf.math.argmin(masked_norms)
    best_loss = losses[best]
    best_norm = norms[best]
    best_alpha = alphas[best]
    best_sigma = sigmas[best]

    return ModuleCriticalityAnalysis(
        criticality_score=best_norm,
        alpha=best_alpha,
        sigma=best_sigma,
        loss_value=best_loss,
        num_samples_per_iteration=num_samples_per_iteration,
        alpha_grid_size=alpha_grid_size,
        sigma_grid_size=sigma_grid_size,
        sigma_ratio=sigma_ratio,
        initial_objective_value=initial_objective_value,
        final_objective_value=final_objective_value,
    )