def smc_body_fn(stage, state, smc_kernel_result):
     """Run one stage of SMC with constant temperature."""
     (new_marginal, new_inv_temperature,
      log_weights) = update_weights_temperature(
          smc_kernel_result.inverse_temperature,
          smc_kernel_result.particle_info.likelihood_log_prob)
     # TODO(b/152412213) Use a tf.scan to better collect debug info.
     if PRINT_DEBUG:
         tf.print(
             'Stage:', stage, 'Beta:', new_inv_temperature, 'n_steps:',
             smc_kernel_result.num_steps, 'accept:',
             tf.reduce_mean(
                 smc_kernel_result.particle_info.accept_prob),
             'scaling:',
             tf.reduce_mean(smc_kernel_result.particle_info.scalings))
     resampled_state, resampled_particle_info = resample(
         log_weights, state, smc_kernel_result.particle_info,
         seed_stream())
     num_steps, scalings = tuning(smc_kernel_result.num_steps,
                                  resampled_particle_info.scalings,
                                  resampled_particle_info.accept_prob)
     next_state, acceptance_rate, tempered_log_prob = mutate(
         resampled_state, scalings, num_steps, new_inv_temperature)
     next_pkr = SMCResults(
         num_steps=num_steps,
         inverse_temperature=new_inv_temperature,
         log_marginal_likelihood=(
             new_marginal + smc_kernel_result.log_marginal_likelihood),
         particle_info=ParticleInfo(
             accept_prob=acceptance_rate,
             scalings=scalings,
             tempered_log_prob=tempered_log_prob,
             likelihood_log_prob=likelihood_log_prob_fn(*next_state),
         ))
     return stage + 1, next_state, next_pkr
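
Since `PRINT_DEBUG` is an ordinary Python boolean, the `if` above is resolved once while the body is traced into a graph, so the `tf.print` op is only added when debugging is enabled. A minimal standalone sketch of that gating pattern (the function and values below are illustrative, not part of the example):

import tensorflow as tf

PRINT_DEBUG = True

@tf.function
def body(stage, accept_prob):
    if PRINT_DEBUG:  # a Python bool: evaluated at trace time, not per step
        tf.print('Stage:', stage, 'accept:', tf.reduce_mean(accept_prob))
    return stage + 1

body(tf.constant(0), tf.constant([0.5, 0.7]))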
Example #2
    def smc_body_fn(stage, state, smc_kernel_result):
      """Run one stage of SMC with constant temperature."""
      (
          new_marginal,
          new_inv_temperature,
          log_weights
      ) = update_weights_temperature(
          smc_kernel_result.inverse_temperature,
          smc_kernel_result.particle_info.likelihood_log_prob)
      # TODO(b/152412213) Use a tf.scan to better collect debug info.
      if PRINT_DEBUG:
        tf.print(
            'Stage:', stage,
            'Beta:', new_inv_temperature,
            'n_steps:', smc_kernel_result.num_steps,
            'accept:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_accept_prob, axis=0)),
            'scaling:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_scalings, axis=0))
            )
      (resampled_state,
       resampled_particle_info), _ = weighted_resampling.resample(
           particles=(state, smc_kernel_result.particle_info),
           log_weights=log_weights,
           resample_fn=resample_fn,
           seed=seed_stream)
      next_num_steps, next_log_scalings = tuning_fn(
          smc_kernel_result.num_steps,
          resampled_particle_info.log_scalings,
          resampled_particle_info.log_accept_prob)
      # Skip tuning at stage 0.
      next_num_steps = tf.where(stage == 0,
                                smc_kernel_result.num_steps,
                                next_num_steps)
      next_log_scalings = tf.where(stage == 0,
                                   resampled_particle_info.log_scalings,
                                   next_log_scalings)
      next_num_steps = tf.clip_by_value(
          next_num_steps, min_num_steps, max_num_steps)

      next_state, log_accept_prob, tempered_log_prob = mutate(
          resampled_state,
          next_log_scalings,
          next_num_steps,
          new_inv_temperature)
      next_pkr = SMCResults(
          num_steps=next_num_steps,
          inverse_temperature=new_inv_temperature,
          log_marginal_likelihood=(new_marginal +
                                   smc_kernel_result.log_marginal_likelihood),
          particle_info=ParticleInfo(
              log_accept_prob=log_accept_prob,
              log_scalings=next_log_scalings,
              tempered_log_prob=tempered_log_prob,
              likelihood_log_prob=likelihood_log_prob_fn(*next_state),
          ))
      return stage + 1, next_state, next_pkr
Example #3
    def train_step(self, data):
        tf.print(train_step_message)
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x)
            loss = self.compiled_loss(y, y_pred)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return {}
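
This `train_step` override relies on Keras to call it from `Model.fit`; `tf.print` fires each step even when the step is compiled into a graph. A minimal, self-contained sketch of wiring up such an override (the model, data, and message below are illustrative):

import tensorflow as tf

train_step_message = 'running one train step'

class PrintingModel(tf.keras.Sequential):
    def train_step(self, data):
        tf.print(train_step_message)
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.trainable_variables))
        return {'loss': loss}

model = PrintingModel([tf.keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')
model.fit(tf.zeros([8, 4]), tf.zeros([8, 1]), epochs=1, verbose=0)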
Example #4
def multi_output_loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """Calculates the MSE over the outputs."""
    tf.print("Sum of actual masking: ", tf.reduce_sum(y_true))
    tf.print("Sum of predicted masking: ", tf.reduce_sum(y_pred))
    # loss_multiplier = tf.where(tf.greater(y_true, tf.constant(5.)), tf.constant(10.),
    #                        tf.constant(1.))
    loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
    # tf.print("Y true: ", y_true)
    # tf.print("Loss multiplier: ", loss_multiplier)
    # loss *= tf.cast(loss_multiplier, dtype=tf.float32)
    return tf.reduce_mean(loss)
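
A loss with this `(y_true, y_pred)` signature can be passed straight to `Model.compile`, after which the `tf.print` calls fire on every training batch. A usage sketch (the model and data are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='adam', loss=multi_output_loss)
model.fit(tf.ones([4, 3]), tf.ones([4, 1]), epochs=1, verbose=0)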
Example #5
  def _maybe_warn_increased_dof(self,
                                component_name,
                                increased_dof):
    """Warns or raises when `increased_dof` is True."""
    # Short-circuit when increased_dof is statically False.
    increased_dof_ = tf.get_static_value(increased_dof)
    if increased_dof_ is False:  # pylint: disable=g-bool-id-comparison
      return

    error_message = (
        'Nested component "{}" in composition "{}" operates on inputs '
        'with increased degrees of freedom. This may result in an '
        'incorrect log_det_jacobian.'
        ).format(component_name, self.name)

    # When validate_args is True, we raise on increased DoF.
    if self._validate_args:
      if increased_dof_:
        raise ValueError(error_message)
      return assert_util.assert_equal(False, increased_dof, error_message)

    if (not tf.executing_eagerly() and
        control_flow_util.GraphOrParentsInXlaContext(tf1.get_default_graph())):
      return  # No StringFormat or Print ops in XLA.

    # Otherwise, we print a warning and continue.
    return ps.cond(
        pred=increased_dof,
        false_fn=tf.no_op,
        true_fn=lambda: tf.print(  # pylint: disable=g-long-lambda
            'WARNING: ' + error_message, output_stream=sys.stderr))
Example #6
  def _maybe_warn_increased_dof(self,
                                component_name,
                                component_ldj,
                                increased_dof):
    """Warns or raises when `increased_dof` is True."""
    # Short-circuit when the component LDJ is statically zero.
    if (tf.get_static_value(tf.rank(component_ldj)) == 0
        and tf.get_static_value(component_ldj) == 0):
      return

    # Short-circuit when increased_dof is statically False.
    increased_dof_ = tf.get_static_value(increased_dof)
    if increased_dof_ is False:  # pylint: disable=g-bool-id-comparison
      return

    error_message = (
        'Nested component "{}" in composition "{}" operates on inputs '
        'with increased degrees of freedom. This may result in an '
        'incorrect log_det_jacobian.'
        ).format(component_name, self.name)

    # When validate_args is True, we raise on increased DoF.
    if self._validate_args:
      if increased_dof_:
        raise ValueError(error_message)
      return assert_util.assert_equal(False, increased_dof, error_message)

    # Otherwise, we print a warning and continue.
    return ps.cond(
        pred=increased_dof,
        false_fn=tf.no_op,
        true_fn=lambda: tf.print(  # pylint: disable=g-long-lambda
            'WARNING: ' + error_message, output_stream=sys.stderr))
Example #7
def _stability_limit_tensor(total_count, dtype):
  limit = tf.cast(BATES_TOTAL_COUNT_STABILITY_LIMITS[dtype], dtype)
  return tf.cond(
      tf.math.reduce_any(total_count > limit),
      # pylint: disable=g-long-lambda
      lambda: tf.print(
          'WARNING: Bates PDF/CDF is unstable for `total_count` >', limit,
          output_stream=sys.stderr),
      tf.no_op)
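
The `tf.cond(pred, true_fn, tf.no_op)` idiom above emits the warning as a graph op only when the predicate holds, with no host-side Python branching. A runnable sketch of the same idiom detached from the Bates-specific constants (the threshold is an assumption for illustration):

import sys
import tensorflow as tf

def warn_if_large(x, limit=100.):
    return tf.cond(
        tf.math.reduce_any(x > limit),
        lambda: tf.print('WARNING: value exceeds', limit,
                         output_stream=sys.stderr),
        tf.no_op)

warn_if_large(tf.constant([1., 200.]))  # prints the warning to stderr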
Example #8
    def __init__(self,
                 distribution_fn,
                 index_ranges,
                 parameter_fns,
                 coding_rank,
                 channel_axis=-1,
                 dtype=tf.float32,
                 likelihood_bound=1e-9,
                 tail_mass=2**-8,
                 range_coder_precision=12):
        """Initializer.

    Arguments:
      distribution_fn: A callable returning a `tfp.distributions.Distribution`
        object, which is used to model the distribution of the bottleneck tensor
        values including additive uniform noise - typically a `Distribution`
        class or factory function. The callable will receive keyword arguments
        as determined by `parameter_fns`. For best results, the distributions
        should be flexible enough to have a unit-width uniform distribution as a
        special case, since this is the distribution an element will take on
        when its bottleneck value is constant (due to the additive noise).
      index_ranges: Integer or iterable of integers. If a single integer,
        `indexes` must have the same shape as `bottleneck`, and `channel_axis`
        is ignored. Its values must be in the range `[0, index_ranges)`. If an
        iterable of integers, `indexes` must have an additional dimension at
        position `channel_axis`, and the values of the `n`th channel must be in
        the range `[0, index_ranges[n])`.
      parameter_fns: Dict of strings to callables. Functions mapping `indexes`
        to each distribution parameter. For each item, `indexes` is passed to
        the callable, and the string key and return value make up one keyword
        argument to `distribution_fn`.
      coding_rank: Integer. Number of innermost dimensions considered a coding
        unit. Each coding unit is compressed to its own bit string, and the
        `bits()` method sums over each coding unit.
      channel_axis: Integer. For iterable `index_ranges`, determines the
        position of the channel axis in `indexes`. Defaults to the last
        dimension.
      dtype: `tf.dtypes.DType`. The data type of all floating-point
        computations carried out in this class.
      likelihood_bound: Float. Lower bound for likelihood values, to prevent
        training instabilities.
      tail_mass: Float. Approximate probability mass which is range encoded with
        less precision, by using a Golomb-like code.
      range_coder_precision: Integer. Precision passed to the range coding op.
    """
        if coding_rank <= 0:
            raise ValueError("`coding_rank` must be larger than 0.")

        self._distribution_fn = distribution_fn
        if not callable(self.distribution_fn):
            raise TypeError(
                "`distribution_fn` must be a class or factory function.")
        try:
            self._index_ranges = int(index_ranges)
        except TypeError:
            self._index_ranges = tuple(int(r) for r in index_ranges)  # pytype: disable=attribute-error
        self._parameter_fns = dict(parameter_fns)
        for name, fn in self.parameter_fns.items():
            if not isinstance(name, str):
                raise TypeError("`parameter_fns` must have string keys.")
            if not callable(fn):
                raise TypeError(
                    "`parameter_fns['{}']` must be callable.".format(name))
        self._channel_axis = int(channel_axis)
        dtype = tf.as_dtype(dtype)

        if isinstance(self.index_ranges, int):
            indexes = tf.range(self.index_ranges, dtype=dtype)
        else:
            indexes = [tf.range(r, dtype=dtype) for r in self.index_ranges]
            indexes = tf.meshgrid(*indexes, indexing="ij")
            indexes = tf.stack(indexes, axis=self.channel_axis)
        parameters = {k: f(indexes) for k, f in self.parameter_fns.items()}
        distribution = self.distribution_fn(**parameters)  # pylint:disable=not-callable
        tf.print(distribution.batch_shape)
        tf.print(distribution.event_shape)

        super().__init__(distribution,
                         coding_rank,
                         likelihood_bound=likelihood_bound,
                         tail_mass=tail_mass,
                         range_coder_precision=range_coder_precision)
Example #9
 def print_ids(self, ids):
     string_tensor = tf.strings.as_string(ids)
     tf.print(string_tensor)
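
`tf.strings.as_string` converts a numeric tensor to strings before printing; `tf.print` also takes formatting controls such as `sep` and `summarize`. A small illustrative sketch:

import tensorflow as tf

ids = tf.constant([1, 2, 3])
tf.print(tf.strings.as_string(ids))           # ["1" "2" "3"]
tf.print('ids:', ids, sep=' ', summarize=-1)  # summarize=-1 prints every element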
Example #10
    def estimate_average_reward(self, dataset: dataset_lib.OffpolicyDataset,
                                target_policy: tf_policy.TFPolicy):
        """Estimates value (average per-step reward) of policy.

    Args:
      dataset: The dataset to sample experience from.
      target_policy: The policy whose value we want to estimate.

    Returns:
      Estimated average per-step reward of the target policy.
    """
        def weight_fn(env_step):
            zeta = self._get_value(self._zeta_network, env_step)
            policy_ratio = 1.0
            if not self._solve_for_state_action_ratio:
                tfagents_timestep = dataset_lib.convert_to_tfagents_timestep(
                    env_step)
                target_log_probabilities = target_policy.distribution(
                    tfagents_timestep).action.log_prob(env_step.action)
                policy_ratio = tf.exp(target_log_probabilities -
                                      env_step.get_log_probability())
            return zeta * common_lib.reverse_broadcast(policy_ratio, zeta)

        def init_nu_fn(env_step, valid_steps):
            """Computes average initial nu values of episodes."""
            # env_step is an episode, and we just want the first step.
            if tf.rank(valid_steps) == 1:
                first_step = tf.nest.map_structure(lambda t: t[0, ...],
                                                   env_step)
            else:
                first_step = tf.nest.map_structure(lambda t: t[:, 0, ...],
                                                   env_step)
            value = self._get_average_value(self._nu_network, first_step,
                                            target_policy)
            return value

        nu_zero = (1 - self._gamma) * estimator_lib.get_fullbatch_average(
            dataset,
            limit=None,
            by_steps=False,
            truncate_episode_at=1,
            reward_fn=init_nu_fn)

        dual_step = estimator_lib.get_fullbatch_average(
            dataset,
            limit=None,
            by_steps=True,
            reward_fn=self._reward_fn,
            weight_fn=weight_fn)

        tf.summary.scalar('nu_zero', nu_zero)
        tf.summary.scalar('lam', self._norm_regularizer * self._lam)
        tf.summary.scalar('dual_step', dual_step)

        constraint, f_nu, f_zeta = self._eval_constraint_and_regs(
            dataset, target_policy)
        lagrangian = nu_zero + self._norm_regularizer * self._lam + constraint
        overall = (lagrangian + self._primal_regularizer * f_nu -
                   self._dual_regularizer * f_zeta)
        tf.summary.scalar('constraint', constraint)
        tf.summary.scalar('nu_reg', self._primal_regularizer * f_nu)
        tf.summary.scalar('zeta_reg', self._dual_regularizer * f_zeta)
        tf.summary.scalar('lagrangian', lagrangian)
        tf.summary.scalar('overall', overall)
        tf.print('step', tf.summary.experimental.get_step(), 'nu_zero =',
                 nu_zero, 'lam =', self._norm_regularizer * self._lam,
                 'dual_step =', dual_step, 'constraint =', constraint,
                 'preg =', self._primal_regularizer * f_nu, 'dreg =',
                 self._dual_regularizer * f_zeta, 'lagrangian =', lagrangian,
                 'overall =', overall)

        return dual_step
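
The method pairs each `tf.summary.scalar` with a `tf.print`, so the same values land in TensorBoard and on the console. A minimal sketch of that pairing (the writer path and metric value are illustrative):

import tensorflow as tf

writer = tf.summary.create_file_writer('/tmp/summaries')
step = tf.Variable(0, dtype=tf.int64)
tf.summary.experimental.set_step(step)

with writer.as_default():
    dual_step = tf.constant(0.42)
    tf.summary.scalar('dual_step', dual_step)
    tf.print('step', tf.summary.experimental.get_step(),
             'dual_step =', dual_step)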
Example #11
def main(argv):
    load_dir = FLAGS.load_dir
    save_dir = FLAGS.save_dir
    env_name = FLAGS.env_name
    seed = FLAGS.seed
    tabular_obs = FLAGS.tabular_obs
    num_trajectory = FLAGS.num_trajectory
    max_trajectory_length = FLAGS.max_trajectory_length
    alpha = FLAGS.alpha
    alpha_target = FLAGS.alpha_target
    gamma = FLAGS.gamma
    nu_learning_rate = FLAGS.nu_learning_rate
    zeta_learning_rate = FLAGS.zeta_learning_rate
    nu_regularizer = FLAGS.nu_regularizer
    zeta_regularizer = FLAGS.zeta_regularizer
    num_steps = FLAGS.num_steps
    batch_size = FLAGS.batch_size

    f_exponent = FLAGS.f_exponent
    primal_form = FLAGS.primal_form

    primal_regularizer = FLAGS.primal_regularizer
    dual_regularizer = FLAGS.dual_regularizer
    kl_regularizer = FLAGS.kl_regularizer
    zero_reward = FLAGS.zero_reward
    norm_regularizer = FLAGS.norm_regularizer
    zeta_pos = FLAGS.zeta_pos

    scale_reward = FLAGS.scale_reward
    shift_reward = FLAGS.shift_reward
    transform_reward = FLAGS.transform_reward

    kl_regularizer = FLAGS.kl_regularizer
    eps_std = FLAGS.eps_std

    def reward_fn(env_step):
        reward = env_step.reward * scale_reward + shift_reward
        if transform_reward is None:
            return reward
        if transform_reward == 'exp':
            reward = tf.math.exp(reward)
        elif transform_reward == 'cuberoot':
            reward = tf.sign(reward) * tf.math.pow(tf.abs(reward), 1.0 / 3.0)
        else:
            raise ValueError(
                'Reward {} not implemented.'.format(transform_reward))
        return reward

    hparam_str = ('{ENV_NAME}_tabular{TAB}_alpha{ALPHA}_seed{SEED}_'
                  'numtraj{NUM_TRAJ}_maxtraj{MAX_TRAJ}').format(
                      ENV_NAME=env_name,
                      TAB=tabular_obs,
                      ALPHA=alpha,
                      SEED=seed,
                      NUM_TRAJ=num_trajectory,
                      MAX_TRAJ=max_trajectory_length)
    train_hparam_str = (
        'nlr{NLR}_zlr{ZLR}_zeror{ZEROR}_preg{PREG}_dreg{DREG}_kreg{KREG}_nreg{NREG}_'
        'pform{PFORM}_fexp{FEXP}_zpos{ZPOS}_'
        'scaler{SCALER}_shiftr{SHIFTR}_transr{TRANSR}').format(
            NLR=nu_learning_rate,
            ZLR=zeta_learning_rate,
            ZEROR=zero_reward,
            PREG=primal_regularizer,
            DREG=dual_regularizer,
            KREG=kl_regularizer,
            NREG=norm_regularizer,
            PFORM=primal_form,
            FEXP=f_exponent,
            ZPOS=zeta_pos,
            SCALER=scale_reward,
            SHIFTR=shift_reward,
            TRANSR=transform_reward,
        )

    train_hparam_str = ('eps{EPS}_kl{KL}').format(EPS=eps_std,
                                                  KL=kl_regularizer)

    if save_dir is not None:
        target_hparam_str = hparam_str.replace(
            'alpha{}'.format(alpha),
            'alpha{}_alphat{}'.format(alpha, alpha_target))
        save_dir = os.path.join(save_dir, target_hparam_str, train_hparam_str)
        summary_writer = tf.summary.create_file_writer(logdir=save_dir)
        summary_writer.set_as_default()
    else:
        tf.summary.create_noop_writer().set_as_default()

    directory = os.path.join(load_dir, hparam_str)
    print('Loading dataset from', directory)
    dataset = Dataset.load(directory)
    #dataset = Dataset.load(directory.replace('alpha{}'.format(alpha), 'alpha0.0'))

    all_steps = dataset.get_all_steps()
    max_reward = tf.reduce_max(all_steps.reward)
    min_reward = tf.reduce_min(all_steps.reward)
    print('num loaded steps', dataset.num_steps)
    print('num loaded total steps', dataset.num_total_steps)
    print('num loaded episodes', dataset.num_episodes)
    print('num loaded total episodes', dataset.num_total_episodes)
    print('min reward', min_reward, 'max reward', max_reward)
    print('behavior per-step',
          estimator_lib.get_fullbatch_average(dataset, gamma=gamma))

    activation_fn = tf.nn.relu
    kernel_initializer = tf.keras.initializers.GlorotUniform()
    hidden_dims = (64, 64)
    input_spec = (dataset.spec.observation, dataset.spec.action)
    nu_network = ValueNetwork(input_spec,
                              output_dim=2,
                              fc_layer_params=hidden_dims,
                              activation_fn=activation_fn,
                              kernel_initializer=kernel_initializer,
                              last_kernel_initializer=kernel_initializer)
    output_activation_fn = tf.math.square if zeta_pos else tf.identity
    zeta_network = ValueNetwork(input_spec,
                                output_dim=2,
                                fc_layer_params=hidden_dims,
                                activation_fn=activation_fn,
                                output_activation_fn=output_activation_fn,
                                kernel_initializer=kernel_initializer,
                                last_kernel_initializer=kernel_initializer)

    nu_optimizer = tf.keras.optimizers.Adam(nu_learning_rate)
    zeta_optimizer = tf.keras.optimizers.Adam(zeta_learning_rate)
    lam_optimizer = tf.keras.optimizers.Adam(nu_learning_rate)

    estimator = NeuralBayesDice(dataset.spec,
                                nu_network,
                                zeta_network,
                                nu_optimizer,
                                zeta_optimizer,
                                lam_optimizer,
                                gamma,
                                zero_reward=zero_reward,
                                f_exponent=f_exponent,
                                primal_form=primal_form,
                                reward_fn=reward_fn,
                                primal_regularizer=primal_regularizer,
                                dual_regularizer=dual_regularizer,
                                kl_regularizer=kl_regularizer,
                                eps_std=FLAGS.eps_std,
                                norm_regularizer=norm_regularizer,
                                nu_regularizer=nu_regularizer,
                                zeta_regularizer=zeta_regularizer)

    global_step = tf.Variable(0, dtype=tf.int64)
    tf.summary.experimental.set_step(global_step)

    target_policy = get_target_policy(load_dir, env_name, tabular_obs,
                                      alpha_target)
    running_losses = []
    all_dual = []
    for step in range(num_steps):
        transitions_batch = dataset.get_step(batch_size, num_steps=2)
        initial_steps_batch, _ = dataset.get_episode(batch_size,
                                                     truncate_episode_at=1)
        initial_steps_batch = tf.nest.map_structure(lambda t: t[:, 0, ...],
                                                    initial_steps_batch)
        losses = estimator.train_step(initial_steps_batch, transitions_batch,
                                      target_policy)
        running_losses.append(losses)
        if step % 500 == 0 or step == num_steps - 1:
            num_samples = 100
            dual_ests = []
            for i in range(num_samples):
                dual_est = estimator.estimate_average_reward(
                    dataset, target_policy, write_summary=(i == 0))
                dual_ests.append(dual_est)
            tf.summary.scalar('dual/mean', tf.math.reduce_mean(dual_ests))
            tf.summary.scalar('dual/std', tf.math.reduce_std(dual_ests))

            tf.print('dual/mean =', tf.math.reduce_mean(dual_ests),
                     'dual/std =', tf.math.reduce_std(dual_ests))

            all_dual.append(dual_ests)
            running_losses = []
        global_step.assign_add(1)

    if save_dir is not None:
        # np.save writes binary data, so the GFile must be opened in 'wb' mode.
        np.save(tf.io.gfile.GFile(os.path.join(save_dir, 'results.npy'), 'wb'),
                all_dual)

    print('Done!')
Example #12
def print_tensor(tensor):
    tf.print(tensor)
def stderr(*args, **kwargs):
    """Print to stderr"""
    with tf.device('/device:CPU:0'):
        tf.print(*args, output_stream=sys.stderr, **kwargs)
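
Pinning the print to the CPU avoids device-placement issues when the surrounding code runs on a GPU or TPU. The helper is used exactly like `tf.print` (a hypothetical call):

stderr('loss diverged at step', 42)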
Example #14
def input_tens(tens):
    tf.print(tens)
Example #15
    def projection(self, X, y, resp, weights, t_range):

        print("----Tracing___projection")

        @tf.function
        def expec_ll(alpha1, alpha2):  #, i, j, c1, c2):

            temp_lower = tf.identity(self.lower)
            temp_upper = tf.identity(self.upper)

            self.lower.assign(alpha1)
            self.upper.assign(alpha2)

            log_cond = self.expected_ll(X, y, resp, weights)

            self.lower.assign(temp_lower)
            self.upper.assign(temp_upper)

            #tf.print(self.lower)
            #tf.print(self.upper)

            return log_cond

        tmp_indexes = tf.where(
            tf.less(self.no_ovelap_test(), -self.theta / 50.))
        #tf.print("number of remaining overlapping: ", tf.size(tmp_indexes))

        while not (tf.equal(tf.size(tmp_indexes), 0)):
            #print("while  looop")

            classes = tf.cast(tf.math.floordiv(tmp_indexes, self.n_components),
                              tf.int32)
            good_indexes = tf.cast(
                tf.math.floormod(tmp_indexes, self.n_components), tf.int32)

            score = tf.TensorArray(dtype=tf.float32,
                                   size=0,
                                   dynamic_size=True,
                                   name="score",
                                   clear_after_read=False)
            #Matrix of updates
            alpha1 = tf.TensorArray(dtype=tf.float32,
                                    size=0,
                                    dynamic_size=True,
                                    name="alpha1",
                                    clear_after_read=False)

            alpha2 = tf.TensorArray(dtype=tf.float32,
                                    size=0,
                                    dynamic_size=True,
                                    name="alpha2",
                                    clear_after_read=False)

            #tf.print(classes)
            #tf.print(good_indexes)

            #For each update, compute the entropy

            for it in tf.range(tf.minimum(tf.constant(self.data_dim), 20)):
                #print("toto")
                d = t_range[it]
                #tf.print(self.lower)
                #print("d loooooop")

                if self.upper[classes[0, 0], good_indexes[0, 0],
                              d] > self.upper[classes[0, 1],
                                              good_indexes[0, 1], d]:

                    alpha1 = alpha1.write(
                        2 * it,
                        tf.tensor_scatter_nd_update(
                            self.lower,
                            [[classes[0, 0], good_indexes[0, 0], d]],
                            [self.upper[classes[0, 1], good_indexes[0, 1], d]
                             ]))

                    alpha2 = alpha2.write(2 * it, self.upper)
                    score = score.write(
                        2 * it,
                        expec_ll(alpha1.read(2 * it), alpha2.read(2 * it)))
                    #,good_indexes[0,0], good_indexes[0,1], classes[0,0], classes [0,1]))

                else:

                    alpha1 = alpha1.write(
                        2 * it,
                        tf.tensor_scatter_nd_update(
                            self.lower,
                            [[classes[0, 1], good_indexes[0, 1], d]],
                            [self.upper[classes[0, 0], good_indexes[0, 0], d]
                             ]))

                    alpha2 = alpha2.write(2 * it, self.upper)

                    score = score.write(
                        2 * it,
                        expec_ll(alpha1.read(2 * it), alpha2.read(2 * it)))
                    #,
                    #good_indexes[0,0], good_indexes[0,1], classes[0,0], classes [0,1]))

                if self.lower[classes[0, 0], good_indexes[0, 0],
                              d] < self.lower[classes[0, 1],
                                              good_indexes[0, 1], d]:

                    alpha2 = alpha2.write(
                        2 * it + 1,
                        tf.tensor_scatter_nd_update(
                            self.upper,
                            [[classes[0, 0], good_indexes[0, 0], d]],
                            [self.lower[classes[0, 1], good_indexes[0, 1], d]
                             ]))

                    alpha1 = alpha1.write(2 * it + 1, self.lower)

                    score = score.write(
                        2 * it + 1,
                        expec_ll(alpha1.read(2 * it + 1),
                                 alpha2.read(2 * it + 1)))
                    #,
                    #   good_indexes[0,0], good_indexes[0,1], classes[0,0], classes [0,1]))

                else:

                    alpha2 = alpha2.write(
                        2 * it + 1,
                        tf.tensor_scatter_nd_update(
                            self.upper,
                            [[classes[0, 1], good_indexes[0, 1], d]],
                            [self.lower[classes[0, 0], good_indexes[0, 0], d]
                             ]))

                    alpha1 = alpha1.write(2 * it + 1, self.lower)

                    score = score.write(
                        2 * it + 1,
                        expec_ll(alpha1.read(2 * it + 1),
                                 alpha2.read(2 * it + 1)))

            #tf.print(score.stack())

            #change the values of alpha corresponding to the lowest update
            true_score = score.stack()
            #ind = tf.cast(tf.math.argmin(tf.boolean_mask(true_score, tf.greater(true_score,0))), tf.int32)
            ind = tf.cast(tf.math.argmin(true_score), tf.int32)
            #tf.print(ind)

            self.lower.assign(alpha1.read(ind))
            self.upper.assign(alpha2.read(ind))

            #Re-compute the no-overlapp
            tmp_indexes = tf.where(tf.less(self.no_ovelap_test(), -self.theta))
            #if not(tf.equal(tf.size(tmp_indexes), 0)):
            #    tf.print(self.no_ovelap_test()[tmp_indexes[0,0], tmp_indexes[0,1]])
            tf.print("number of remaining overlapping: ", tf.size(tmp_indexes))
def main(argv):
    env_name = FLAGS.env_name
    seed = FLAGS.seed
    tabular_obs = FLAGS.tabular_obs
    num_trajectory = FLAGS.num_trajectory
    max_trajectory_length = FLAGS.max_trajectory_length
    load_dir = FLAGS.load_dir
    save_dir = FLAGS.save_dir
    gamma = FLAGS.gamma
    assert 0 <= gamma < 1.
    alpha = FLAGS.alpha
    alpha_target = FLAGS.alpha_target

    num_steps = FLAGS.num_steps
    batch_size = FLAGS.batch_size
    zeta_learning_rate = FLAGS.zeta_learning_rate
    nu_learning_rate = FLAGS.nu_learning_rate
    solve_for_state_action_ratio = FLAGS.solve_for_state_action_ratio
    eps_std = FLAGS.eps_std
    kl_regularizer = FLAGS.kl_regularizer

    target_policy = get_target_policy(load_dir,
                                      env_name,
                                      tabular_obs,
                                      alpha=alpha_target)

    hparam_str = ('{ENV_NAME}_tabular{TAB}_alpha{ALPHA}_seed{SEED}_'
                  'numtraj{NUM_TRAJ}_maxtraj{MAX_TRAJ}').format(
                      ENV_NAME=env_name,
                      TAB=tabular_obs,
                      ALPHA=alpha,
                      SEED=seed,
                      NUM_TRAJ=num_trajectory,
                      MAX_TRAJ=max_trajectory_length)

    directory = os.path.join(load_dir, hparam_str)
    print('Loading dataset.')
    dataset = Dataset.load(directory)
    print('num loaded steps', dataset.num_steps)
    print('num loaded total steps', dataset.num_total_steps)
    print('num loaded episodes', dataset.num_episodes)
    print('num loaded total episodes', dataset.num_total_episodes)
    print('behavior per-step',
          estimator_lib.get_fullbatch_average(dataset, gamma=gamma))

    train_hparam_str = ('eps{EPS}_kl{KL}').format(EPS=eps_std,
                                                  KL=kl_regularizer)

    if save_dir is not None:
        # Save for a specific alpha target
        target_hparam_str = hparam_str.replace(
            'alpha{}'.format(alpha),
            'alpha{}_alphat{}'.format(alpha, alpha_target))
        save_dir = os.path.join(save_dir, target_hparam_str, train_hparam_str)
        summary_writer = tf.summary.create_file_writer(logdir=save_dir)
    else:
        summary_writer = tf.summary.create_noop_writer()

    estimator = TabularBayesDice(
        dataset_spec=dataset.spec,
        gamma=gamma,
        solve_for_state_action_ratio=solve_for_state_action_ratio,
        zeta_learning_rate=zeta_learning_rate,
        nu_learning_rate=nu_learning_rate,
        kl_regularizer=kl_regularizer,
        eps_std=eps_std,
    )
    estimator.prepare_dataset(dataset, target_policy)

    global_step = tf.Variable(0, dtype=tf.int64)
    tf.summary.experimental.set_step(global_step)
    with summary_writer.as_default():
        running_losses = []
        running_estimates = []
        for step in range(num_steps):
            loss = estimator.train_step()[0]
            running_losses.append(loss)
            global_step.assign_add(1)

            if step % 500 == 0 or step == num_steps - 1:
                print('step', step, 'losses', np.mean(running_losses, 0))
                estimate = estimator.estimate_average_reward(
                    dataset, target_policy)
                tf.debugging.check_numerics(estimate, 'NaN in estimate')
                running_estimates.append(estimate)
                tf.print('est', tf.math.reduce_mean(estimate),
                         tf.math.reduce_std(estimate))

                running_losses = []

    if save_dir is not None:
        # np.save writes binary data, so the GFile must be opened in 'wb' mode.
        with tf.io.gfile.GFile(os.path.join(save_dir, 'results.npy'),
                               'wb') as f:
            np.save(f, running_estimates)
        print('saved results to %s' % save_dir)

    print('Done!')
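
The `tf.debugging.check_numerics` call above turns a silent NaN or Inf in the estimate into a hard error before anything is logged. A minimal sketch of the same guard (the values are illustrative):

import tensorflow as tf

estimate = tf.constant([1.0, 2.0])
estimate = tf.debugging.check_numerics(estimate, 'NaN in estimate')
tf.print('est', tf.math.reduce_mean(estimate), tf.math.reduce_std(estimate))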
Example #17
    def estimate_average_reward(self,
                                dataset: dataset_lib.OffpolicyDataset,
                                target_policy: tf_policy.TFPolicy,
                                write_summary: bool = False):
        """Estimates value (average per-step reward) of policy.

    Args:
      dataset: The dataset to sample experience from.
      target_policy: The policy whose value we want to estimate.
      write_summary: Whether to also write `tf.summary` scalars and print the
        estimate.

    Returns:
      Estimated average per-step reward of the target policy.
    """
        def weight_fn(env_step):
            zeta, _, _ = self._sample_value(self._zeta_network, env_step)
            policy_ratio = 1.0
            if not self._solve_for_state_action_ratio:
                tfagents_timestep = dataset_lib.convert_to_tfagents_timestep(
                    env_step)
                target_log_probabilities = target_policy.distribution(
                    tfagents_timestep).action.log_prob(env_step.action)
                policy_ratio = tf.exp(target_log_probabilities -
                                      env_step.get_log_probability())
            return zeta * common_lib.reverse_broadcast(policy_ratio, zeta)

        def init_nu_fn(env_step, valid_steps):
            """Computes average initial nu values of episodes."""
            # env_step is an episode, and we just want the first step.
            if tf.rank(valid_steps) == 1:
                first_step = tf.nest.map_structure(lambda t: t[0, ...],
                                                   env_step)
            else:
                first_step = tf.nest.map_structure(lambda t: t[:, 0, ...],
                                                   env_step)
            value, _, _ = self._sample_average_value(self._nu_network,
                                                     first_step, target_policy)
            return value

        dual_step = estimator_lib.get_fullbatch_average(
            dataset,
            limit=None,
            by_steps=True,
            reward_fn=self._reward_fn,
            weight_fn=weight_fn)

        nu_zero = (1 - self._gamma) * estimator_lib.get_fullbatch_average(
            dataset,
            limit=None,
            by_steps=False,
            truncate_episode_at=1,
            reward_fn=init_nu_fn)

        if not write_summary:
            return dual_step

        tf.summary.scalar('eval/dual_step', dual_step)
        tf.summary.scalar('eval/nu_zero', nu_zero)
        tf.print('step', tf.summary.experimental.get_step(), 'dual_step =',
                 dual_step, 'nu_zero =', nu_zero)

        return dual_step
Example #18
def train_step_black_box(data, labels_one_hot, samples, weights, _lambda):
    print("----Tracing--train_step_black_box")

    @tf.function
    def share_loss(X, weights):
        print("----Tracing---share_loss")

        def kl_divergence(x_d):
            print("---Tracing the KL")
            kl = tf.keras.losses.KLDivergence()
            return kl(tf.exp(model.compute_log_conditional_distribution(x_d)),
                      black_box(x_d))

        return tfp.monte_carlo.expectation(f=kl_divergence,
                                           samples=X,
                                           log_prob=model.log_pdf,
                                           use_reparametrization=False)

    with tf.GradientTape() as tape1:
        # share_loss = _lambda*black_box.share_loss(X = samples,  sTGMA = model , weights = weights)
        # cross_entropy = cross_ent(labels_one_hot, black_box(data))

        # loss = cross_entropy + share_loss + black_box.losses()
        #gradients = tape.gradient(loss , black_box.trainable_variables)

        print("--tracing-gradient_persistent")
        #print(samples)
        #print(weights)
        #print(black_box(data))
        share_loss = share_loss(X=samples, weights=weights)

    with tf.GradientTape() as tape2:
        cross_ent = tf.keras.losses.CategoricalCrossentropy()
        logits = black_box(data)
        cross_entropy = cross_ent(labels_one_hot, logits)
        # loss = cross_entropy + share_loss + black_box.losses()
    gradients1 = tape1.gradient(share_loss, black_box.trainable_variables)
    gradients2 = tape2.gradient(cross_entropy, black_box.trainable_variables)
    #tf.print([grads.shape for grads in gradients1] )
    #print("tattaataaa")
    numerator = tf.constant(0.0)
    denominator = tf.constant(0.0)

    for grads1, grads2 in zip(gradients1, gradients2):
        numerator = numerator + tf.reduce_sum(grads2 * grads2 -
                                              grads1 * grads2)
        denominator = denominator + tf.norm(grads1 - grads2)**2
    qiota = 1. - 1. / (1. + _lambda)
    tau = tf.math.maximum(tf.math.minimum(numerator / denominator, qiota), 0.0)
    gradients = [
        tau * grads1 + (1 - tau) * grads2
        for grads1, grads2 in zip(gradients1, gradients2)
    ]
    tf.print("Tau param: ", tau)
    optimizer_black_box.apply_gradients(
        zip(gradients, black_box.trainable_variables))

    del tape1
    del tape2

    return cross_entropy, share_loss, tau  #, gradients
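
The step above blends gradients from two tapes with a convex weight `tau` derived from the gradients themselves. A simplified, self-contained sketch of the blending idea, with a fixed `tau` standing in for the adaptive rule (the model, data, and constant weight are assumptions):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
x, y = tf.ones([4, 3]), tf.ones([4, 1])

with tf.GradientTape() as tape1:
    loss_a = tf.reduce_mean(tf.square(model(x) - y))
with tf.GradientTape() as tape2:
    loss_b = tf.reduce_mean(tf.abs(model(x) - y))

grads_a = tape1.gradient(loss_a, model.trainable_variables)
grads_b = tape2.gradient(loss_b, model.trainable_variables)

tau = 0.3  # fixed here; computed adaptively in the example above
blended = [tau * ga + (1 - tau) * gb for ga, gb in zip(grads_a, grads_b)]
tf.print('Tau param:', tau)
optimizer.apply_gradients(zip(blended, model.trainable_variables))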
Example #19
 def gather(self, string_values, indices):
     tf.print(tf.gather(tf.as_string(string_values), indices))
def dataset_no_vars_loop(ds, dds):
  # Here `ds` appears to be a tf.distribute.Strategy and `dds` a distributed
  # dataset: each per-replica value is summed across replicas and printed.
  for pr in dds:
    tf.print(ds.reduce('SUM', pr, axis=None))
Example #21
def dataset_no_vars_loop(ds):
    for e in ds:
        tf.print(e)
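
AutoGraph lets a `tf.function` iterate a `tf.data.Dataset` directly, turning the loop into a graph-level loop; `tf.print` runs once per element. A usage sketch for the loop above:

import tensorflow as tf

ds = tf.data.Dataset.range(3)
tf.function(dataset_no_vars_loop)(ds)  # prints 0, 1, 2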
def return_with_default(x):
  if x > 0:
    tf.print('x', x)
    return x
  return x * x
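
Because `x > 0` depends on its argument, AutoGraph rewrites this `if` into a `tf.cond` when the function is wrapped in `tf.function`, and the `tf.print` runs only on the taken branch. A sketch of exercising it with tensor arguments:

import tensorflow as tf

converted = tf.function(return_with_default)
tf.print(converted(tf.constant(3)))   # prints 'x 3', then 3
tf.print(converted(tf.constant(-2)))  # prints 4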
Example #23
def iterator_no_vars_loop(ds):
    for e in iter(ds):
        tf.print(e)
                                            batch,
                                            training=False)

            total_loss += ld_loss + f0_loss
            epoch_ld_loss += ld_loss
            epoch_f0_loss += f0_loss

            epoch_total_loss += total_loss

        grads = tape.gradient(total_loss, control_model.trainable_variables)

        optimizer.apply_gradients(zip(grads,
                                      control_model.trainable_variables))

        tf.print("epoch:", epoch, "step:", batch_count, "total_loss: ",
                 tr(total_loss), "ld_loss:", tr(ld_loss), "f0_loss:",
                 tr(f0_loss))

        batch_count += 1

    with trn_summary_writer.as_default():
        tf.summary.scalar("epoch total loss:",
                          epoch_total_loss / batch_count,
                          step=epoch)
        tf.summary.scalar("epoch ld_loss:",
                          epoch_ld_loss / batch_count,
                          step=epoch)
        tf.summary.scalar("epoch f0_loss:",
                          epoch_f0_loss / batch_count,
                          step=epoch)
Example #25
def main(argv):

    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    tf.random.set_seed(FLAGS.seed)

    if FLAGS.loudness_traindata_proto_file_pattern is None:
        raise app.UsageError(
            "Must provide --loudness_traindata_proto_file_pattern.")

    log_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                           FLAGS.logs_dir)
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    logging.info("TensorFlow seed: %d", FLAGS.seed)

    input_shape = None
    if FLAGS.mode == "test":
        raise NotImplementedError("Did not implement mode test.")
        data = get_datasets(FLAGS.loudness_testdata_proto_file_pattern,
                            1,
                            carfac=FLAGS.use_carfac)
        logging.info("Created testing datasets")
        model = tf.keras.models.load_model(FLAGS.load_model_from_file)
        logging.info("Loaded model")
    elif FLAGS.mode == "train":
        data = get_datasets(
            FLAGS.loudness_traindata_proto_file_pattern,
            FLAGS.batch_size,
            carfac=FLAGS.use_carfac,
            extra_file_pattern=(
                FLAGS.extra_loudness_traindata_proto_file_pattern))
        frequency_bins = None
        for example in data["train"].take(1):
            input_example, target_example = example
            input_shape = input_example.shape
            carfac_channels = input_example.shape[1]
            frequency_bins = input_example.shape[2]
        logging.info("Created model")
    elif FLAGS.mode == "eval_once":
        data = get_testdata(FLAGS.loudness_testdata_proto_file_pattern,
                            carfac=FLAGS.use_carfac)
        frequency_bins = None
        for example in data["test"].take(1):
            input_example, target_example, _ = example
            input_shape = input_example.shape
            carfac_channels = input_example.shape[1]
            frequency_bins = input_example.shape[2]
    model = LoudnessPredictor(
        frequency_bins=frequency_bins,
        carfac_channels=carfac_channels,
        num_rows_channel_kernel=FLAGS.num_rows_channel_kernel,
        num_cols_channel_kernel=FLAGS.num_cols_channel_kernel,
        num_filters_channels=FLAGS.num_filters_channels,
        num_rows_bin_kernel=FLAGS.num_rows_bin_kernel,
        num_cols_bin_kernel=FLAGS.num_cols_bin_kernel,
        num_filters_bins=FLAGS.num_filters_bins,
        dropout_p=FLAGS.dropout_p,
        use_channels=FLAGS.use_carfac,
        seed=FLAGS.seed)
    if FLAGS.load_from_checkpoint:
        path_to_load = os.path.join(log_dir, FLAGS.load_from_checkpoint)
        logging.info("Attempting to load model from %s", path_to_load)
        loaded = False
        try:
            model.load_weights(path_to_load)
            loaded = True
            logging.info("Loaded model")
        except Exception as err:
            logging.info(
                "Unable to load log dir checkpoint %s, trying "
                "'load_from_checkpoint' flag: %s", path_to_load, err)
            path_to_load = FLAGS.load_from_checkpoint
            try:
                model.load_weights(path_to_load)
                loaded = True
            except Exception as err:
                logging.info("Unable to load flag checkpoint %s: %s",
                             path_to_load, err)
    else:
        loaded = False

    example_image_batch = []
    if FLAGS.mode == "train":
        data_key = "train"
        for example in data[data_key].take(4):
            input_example, target = example
            input_shape = input_example.shape
            tf.print("(start train) input shape: ", input_shape)
            tf.print("(start train) target phons shape: ", target.shape)
            input_example = tf.expand_dims(input_example[0], axis=0)
            example_image_batch.append([input_example, target])

    elif FLAGS.mode == "eval_once":
        data_key = "test"
        for example in data[data_key].take(4):
            input_example, target, _ = example
            input_shape = input_example.shape
            tf.print("(start eval) input shape: ", input_shape)
            tf.print("(start eval) target phons shape: ", target.shape)
            input_example = tf.expand_dims(input_example[0], axis=0)
            example_image_batch.append([input_example, target])

    callbacks = [helpers.StepIncrementingCallback()]
    callbacks.append(
        tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                       histogram_freq=1,
                                       update_freq="batch",
                                       write_graph=True))
    model.build(input_shape)
    logging.info("Model summary")
    model.summary()

    if FLAGS.extra_loudness_traindata_proto_file_pattern:
        extra_data = True
    else:
        extra_data = False
    save_ckpt = log_dir + "/cp_carfac{}_extradata{}".format(
        FLAGS.use_carfac, extra_data) + "_{epoch:04d}.ckpt"
    logging.info("Save checkpoint to: %s" % save_ckpt)
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(filepath=save_ckpt,
                                           save_weights_only=True,
                                           verbose=1,
                                           period=5))

    if FLAGS.mode == "train":
        logging.info("Starting training for %d epochs" % FLAGS.epochs)
        if FLAGS.extra_loudness_traindata_proto_file_pattern:
            steps_per_epoch = (317 + 639) // FLAGS.batch_size
        else:
            steps_per_epoch = 317 // FLAGS.batch_size
        train(model, data["train"], data["validate"], FLAGS.learning_rate,
              FLAGS.epochs, steps_per_epoch, callbacks)
    elif FLAGS.mode == "test":
        raise NotImplementedError("Mode test not implemented.")
        evaluate(model, data["test"], batch_size=FLAGS.eval_batch_size)
    elif FLAGS.mode == "eval_once":
        if not loaded:
            raise ValueError(
                "Trying to evaluate a model with uninitialized weights.")
        save_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                log_dir)
        write_predictions(model,
                          data["test"],
                          batch_size=1,
                          save_directory=save_dir,
                          save_file=FLAGS.save_predictions_file)
        return
    else:
        raise ValueError("Specified value for '--mode' (%s) unknown",
                         FLAGS.mode)
def iterator_no_vars_loop(ds, dds):
  for pr in iter(dds):
    tf.print(ds.reduce('SUM', pr, axis=None))
Example #27
def main(argv):
    env_name = FLAGS.env_name
    seed = FLAGS.seed
    tabular_obs = FLAGS.tabular_obs
    num_trajectory = FLAGS.num_trajectory
    num_expert_trajectory = FLAGS.num_expert_trajectory
    max_trajectory_length = FLAGS.max_trajectory_length
    alpha = FLAGS.alpha
    alpha_expert = FLAGS.alpha_expert
    load_dir = FLAGS.load_dir
    save_dir = FLAGS.save_dir
    gamma = FLAGS.gamma
    assert 0 <= gamma < 1.
    embed_dim = FLAGS.embed_dim
    fourier_dim = FLAGS.fourier_dim
    embed_learning_rate = FLAGS.embed_learning_rate
    learning_rate = FLAGS.learning_rate
    finetune = FLAGS.finetune
    latent_policy = FLAGS.latent_policy
    embed_learner = FLAGS.embed_learner
    num_steps = FLAGS.num_steps
    embed_pretraining_steps = FLAGS.embed_pretraining_steps
    batch_size = FLAGS.batch_size

    hparam_str = ('{ENV_NAME}_tabular{TAB}_alpha{ALPHA}_seed{SEED}_'
                  'numtraj{NUM_TRAJ}_maxtraj{MAX_TRAJ}').format(
                      ENV_NAME=env_name,
                      TAB=tabular_obs,
                      ALPHA=alpha,
                      SEED=seed,
                      NUM_TRAJ=num_trajectory,
                      MAX_TRAJ=max_trajectory_length)
    directory = os.path.join(load_dir, hparam_str)
    print('Loading dataset.')
    dataset = Dataset.load(directory)
    print('num loaded steps', dataset.num_steps)
    print('num loaded total steps', dataset.num_total_steps)
    print('num loaded episodes', dataset.num_episodes)
    print('num loaded total episodes', dataset.num_total_episodes)
    estimate = estimator_lib.get_fullbatch_average(dataset, gamma=gamma)
    print('data per step avg', estimate)

    hparam_str = ('{ENV_NAME}_tabular{TAB}_alpha{ALPHA}_seed{SEED}_'
                  'numtraj{NUM_TRAJ}_maxtraj{MAX_TRAJ}').format(
                      ENV_NAME=env_name,
                      TAB=tabular_obs,
                      ALPHA=alpha_expert,
                      SEED=seed,
                      NUM_TRAJ=num_expert_trajectory,
                      MAX_TRAJ=max_trajectory_length)
    directory = os.path.join(load_dir, hparam_str)
    print('Loading expert dataset.')
    expert_dataset = Dataset.load(directory)
    print('num loaded expert steps', expert_dataset.num_steps)
    print('num loaded total expert steps', expert_dataset.num_total_steps)
    print('num loaded expert episodes', expert_dataset.num_episodes)
    print('num loaded total expert episodes',
          expert_dataset.num_total_episodes)
    expert_estimate = estimator_lib.get_fullbatch_average(expert_dataset,
                                                          gamma=gamma)
    print('expert data per step avg', expert_estimate)

    hparam_dict = {
        'env_name': env_name,
        'alpha_expert': alpha_expert,
        'seed': seed,
        'num_trajectory': num_trajectory,
        'num_expert_trajectory': num_expert_trajectory,
        'max_trajectory_length': max_trajectory_length,
        'embed_learner': embed_learner,
        'embed_dim': embed_dim,
        'fourier_dim': fourier_dim,
        'embed_learning_rate': embed_learning_rate,
        'learning_rate': learning_rate,
        'latent_policy': latent_policy,
        'finetune': finetune,
    }
    hparam_str = ','.join([
        '%s=%s' % (k, str(hparam_dict[k])) for k in sorted(hparam_dict.keys())
    ])
    summary_writer = tf.summary.create_file_writer(
        os.path.join(save_dir, hparam_str, 'train'))

    if embed_learner == 'sgd' or not embed_learner:
        algo = TabularBCSGD(dataset.spec,
                            gamma=gamma,
                            embed_dim=embed_dim,
                            embed_learning_rate=embed_learning_rate,
                            learning_rate=learning_rate,
                            finetune=finetune,
                            latent_policy=latent_policy)
    elif embed_learner == 'svd':
        algo = TabularBCSVD(dataset.spec,
                            gamma=gamma,
                            embed_dim=embed_dim,
                            learning_rate=learning_rate)
    elif embed_learner == 'energy':
        algo = TabularBCEnergy(dataset.spec,
                               gamma=gamma,
                               embed_dim=embed_dim,
                               fourier_dim=fourier_dim,
                               embed_learning_rate=embed_learning_rate,
                               learning_rate=learning_rate)
    else:
        raise ValueError('embed learner %s not supported' % embed_learner)

    if embed_learner == 'svd':
        embed_dict = algo.solve(dataset)
        with summary_writer.as_default():
            for k, v in embed_dict.items():
                tf.summary.scalar(f'embed/{k}', v, step=0)
                print('embed', k, v)
    else:
        algo.prepare_datasets(dataset, expert_dataset)
        if embed_learner is not None:
            for step in range(embed_pretraining_steps):
                batch = dataset.get_step(batch_size, num_steps=2)
                embed_dict = algo.train_embed(batch)
                if step % FLAGS.eval_interval == 0:
                    with summary_writer.as_default():
                        for k, v in embed_dict.items():
                            tf.summary.scalar(f'embed/{k}', v, step=step)
                            print('embed', step, k, v)

    for step in range(num_steps):
        batch = expert_dataset.get_step(batch_size, num_steps=2)
        info_dict = algo.train_step(batch)
        if step % FLAGS.eval_interval == 0:
            with summary_writer.as_default():
                for k, v in info_dict.items():
                    tf.summary.scalar(f'bc/{k}', v, step=step)
                    print('bc', k, v)

            policy_fn, policy_info_spec = algo.get_policy()
            onpolicy_data = get_onpolicy_dataset(env_name, tabular_obs,
                                                 policy_fn, policy_info_spec)
            onpolicy_episodes, _ = onpolicy_data.get_episode(
                100, truncate_episode_at=max_trajectory_length)
            with summary_writer.as_default():
                tf.print('eval/reward', np.mean(onpolicy_episodes.reward))
                tf.summary.scalar('eval/reward',
                                  np.mean(onpolicy_episodes.reward),
                                  step=step)
def sample_sequential_monte_carlo(prior_log_prob_fn,
                                  likelihood_log_prob_fn,
                                  current_state,
                                  max_num_steps=25,
                                  max_stage=100,
                                  make_kernel_fn=make_rwmh_kernel_fn,
                                  optimal_accept=0.234,
                                  target_accept_prob=0.99,
                                  ess_threshold_ratio=0.5,
                                  parallel_iterations=10,
                                  seed=None,
                                  name=None):
    """Runs Sequential Monte Carlo to sample from the posterior distribution.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'prior' distribution:

  `exp(prior_log_prob_fn(x))`

  and the target 'posterior' distribution:

  `exp(prior_log_prob_fn(x) + likelihood_log_prob_fn(x))`,

  by mutating a collection of MC samples (i.e., particles). This approach is
  also known as a particle filter in some of the literature.

  Args:
    prior_log_prob_fn: Python callable that returns the log density of the
      prior distribution.
    likelihood_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the likelihood distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    max_num_steps: The maximum number of kernel transition steps in one
      mutation of the MC samples. Note that the actual number of steps in one
      mutation is tuned during sampling and is likely lower than
      `max_num_steps`.
    max_stage: Integer maximum number of stages over which the inverse
      temperature is increased from 0 to 1.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution.  Note: this function
      creates a new `target_log_prob_fn`, an interpolation between the
      supplied `prior_log_prob_fn` and `likelihood_log_prob_fn`; it is this
      tempered function which is passed to `make_kernel_fn`.
    optimal_accept: Optimal acceptance ratio for the `TransitionKernel`.
      Defaults to 0.234, the optimum for a Random Walk Metropolis kernel.
    target_accept_prob: Target acceptance probability at the end of one mutation
      step.
    ess_threshold_ratio: Target ratio of the effective sample size to the
      number of particles; it determines the next inverse temperature at each
      stage.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: Python integer or TFP seedstream to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_sequential_monte_carlo').

  Returns:
    n_stage: Number of stages the SMC run went through before reaching an
      inverse temperature of 1 (or hitting `max_stage`).
    final_state: `Tensor` or Python `list` of `Tensor`s representing the
      final state(s) of the Markov chain(s). These are the posterior
      samples.
    final_kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  """

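    # A sketch of the tempering scheme used below: at inverse temperature
    # `beta`, the target log density is
    #   prior_log_prob_fn(x) + beta * likelihood_log_prob_fn(x),
    # and `beta` is annealed from 0 (the prior) to 1 (the posterior), with
    # each stage reweighting, resampling, tuning, and mutating the particles.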
    with tf.name_scope(name or 'sample_sequential_monte_carlo'):
        seed_stream = SeedStream(seed, salt='smc_seed')

        unwrap_state_list = not tf.nest.is_nested(current_state)
        if unwrap_state_list:
            current_state = [current_state]
        current_state = [
            tf.convert_to_tensor(s, dtype_hint=tf.float32)
            for s in current_state
        ]

        num_replica = ps.size0(current_state[0])
        effective_sample_size_threshold = tf.cast(
            num_replica * ess_threshold_ratio, tf.int32)

        def preprocess_state(init_state):
            """Initial preprocessing at Stage 0."""
            dimension = ps.reduce_sum(
                [ps.reduce_prod(ps.shape(x)[1:]) for x in init_state])
            likelihood_log_prob = likelihood_log_prob_fn(*init_state)

            # Default to the optimal for normal distributed targets.
            # TODO(b/152412213): Revisit this tuning.
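            # 2.38**2 / dimension is the asymptotically optimal random walk
            # Metropolis proposal variance for Gaussian targets (Roberts,
            # Gelman & Gilks, 1997).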
            scale_start = (
                tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
                tf.constant(dimension, dtype=likelihood_log_prob.dtype))
            # TODO(b/152412213): Enable batch of batches style by using non-scalar
            # inverse_temperature
            inverse_temperature = tf.zeros([], dtype=likelihood_log_prob.dtype)
            scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(
                scale_start, 1.)
            kernel = make_kernel_fn(_make_tempered_target_log_prob_fn(
                prior_log_prob_fn, likelihood_log_prob_fn,
                inverse_temperature),
                                    init_state,
                                    scalings,
                                    seed=seed_stream())
            pkr = kernel.bootstrap_results(init_state)
            mh_results = _find_inner_mh_results(pkr)

            particle_info = ParticleInfo(
                accept_prob=ps.ones_like(likelihood_log_prob),
                scalings=scalings,
                tempered_log_prob=mh_results.accepted_results.target_log_prob,
                likelihood_log_prob=likelihood_log_prob,
            )

            return SMCResults(num_steps=tf.convert_to_tensor(max_num_steps,
                                                             dtype=tf.int32,
                                                             name='num_steps'),
                              inverse_temperature=inverse_temperature,
                              log_marginal_likelihood=tf.constant(
                                  0., dtype=likelihood_log_prob.dtype),
                              particle_info=particle_info)

        def update_weights_temperature(inverse_temperature,
                                       likelihood_log_prob):
            """Calculate the next inverse temperature and update weights."""

            likelihood_diff = likelihood_log_prob - tf.reduce_max(
                likelihood_log_prob)

            def _body_fn(new_beta, upper_beta, lower_beta, eff_size,
                         log_weights):
                """One iteration of the temperature and weight update."""
                new_beta = (lower_beta + upper_beta) / 2.0
                log_weights = (new_beta -
                               inverse_temperature) * likelihood_diff
                log_weights_norm = (log_weights -
                                    tf.math.reduce_logsumexp(log_weights))
                eff_size = tf.cast(
                    tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm)),
                    tf.int32)
                upper_beta = tf.where(
                    eff_size < effective_sample_size_threshold, new_beta,
                    upper_beta)
                lower_beta = tf.where(
                    eff_size < effective_sample_size_threshold, lower_beta,
                    new_beta)
                return new_beta, upper_beta, lower_beta, eff_size, log_weights

            (new_beta, upper_beta, lower_beta, eff_size,
             log_weights) = tf.while_loop(  # pylint: disable=unused-variable
                 cond=lambda new_beta, upper_beta, lower_beta, eff_size, *_:  # pylint: disable=g-long-lambda
                 (upper_beta - lower_beta > 1e-6) &
                 (eff_size != effective_sample_size_threshold),
                 body=_body_fn,
                 loop_vars=(tf.zeros_like(inverse_temperature),
                            tf.cast(2.0, inverse_temperature.dtype),
                            inverse_temperature, tf.cast(0, tf.int32),
                            tf.zeros_like(likelihood_diff)),
                 parallel_iterations=parallel_iterations)

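            # If the bisection exits above `beta = 1`, clamp to 1 and
            # recompute the weights for the final jump to the untempered
            # posterior.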
            log_weights = tf.where(new_beta < 1., log_weights,
                                   (1. - inverse_temperature) *
                                   likelihood_diff)
            marginal_loglike_ = reduce_logmeanexp(
                (new_beta - inverse_temperature) * likelihood_log_prob)

            return marginal_loglike_, tf.clip_by_value(new_beta, 0.,
                                                       1.), log_weights

        def resample(log_weights, current_state, particle_info, seed=None):
            """Resample particles based on importance weights."""
            with tf.name_scope('resample_particles'):
                seed = SeedStream(seed, salt='resample_particles')
                resampling_indexes = tf.random.categorical(
                    [log_weights],
                    ps.reduce_prod(ps.shape(log_weights)),
                    seed=seed())
                next_state = tf.nest.map_structure(
                    lambda x: tf.reshape(tf.gather(x, resampling_indexes),
                                         ps.shape(x)), current_state)
                next_particle_info = tf.nest.map_structure(
                    lambda x: tf.reshape(tf.gather(x, resampling_indexes),
                                         ps.shape(x)), particle_info)

                return next_state, next_particle_info

        def tuning(num_steps, scalings, accept_prob):
            """Tune scaling and/or num_steps based on the acceptance rate."""
            num_proposed = num_replica * num_steps
            accept_prob = tf.cast(accept_prob, dtype=scalings.dtype)
            avg_scaling = tf.exp(
                tf.math.log(tf.reduce_mean(scalings)) +
                (tf.reduce_mean(accept_prob) - optimal_accept))
            scalings = 0.5 * (
                avg_scaling +
                tf.exp(tf.math.log(scalings) + (accept_prob - optimal_accept)))

            if TUNE_STEPS:
                avg_accept = tf.math.maximum(
                    1.0 / tf.cast(num_proposed, dtype=accept_prob.dtype),
                    tf.reduce_mean(accept_prob))
                num_steps = tf.clip_by_value(
                    tf.cast(tf.math.log1p(
                        -tf.cast(target_accept_prob, dtype=avg_accept.dtype)) /
                            tf.math.log1p(-avg_accept),
                            dtype=num_steps.dtype), 2, max_num_steps)

            return num_steps, scalings

        def mutate(current_state, scalings, num_steps, inverse_temperature):
            """Mutate the state using a Transition kernel."""
            with tf.name_scope('mutate_states'):
                kernel = make_kernel_fn(_make_tempered_target_log_prob_fn(
                    prior_log_prob_fn, likelihood_log_prob_fn,
                    inverse_temperature),
                                        current_state,
                                        scalings,
                                        seed=seed_stream())
                pkr = kernel.bootstrap_results(current_state)
                mh_results = _find_inner_mh_results(pkr)

                def mutate_onestep(i, state, pkr, accept_count):
                    next_state, next_kernel_results = kernel.one_step(
                        state, pkr)
                    mh_results = _find_inner_mh_results(next_kernel_results)
                    # TODO(b/152412213) Cumulate log_acceptance_ratio instead.
                    accept_count += tf.cast(mh_results.is_accepted,
                                            accept_count.dtype)
                    return i + 1, next_state, next_kernel_results, accept_count

                (_, next_state, next_kernel_results,
                 accept_count) = tf.while_loop(
                     cond=lambda i, *args: i < num_steps,
                     body=mutate_onestep,
                     loop_vars=(tf.zeros([],
                                         dtype=tf.int32), current_state, pkr,
                                tf.zeros_like(mh_results.is_accepted,
                                              tf.float32)),
                     parallel_iterations=parallel_iterations)
                next_mh_results = _find_inner_mh_results(next_kernel_results)

                return (next_state,
                        accept_count / tf.cast(num_steps, accept_count.dtype),
                        next_mh_results.accepted_results.target_log_prob)

        pkr = preprocess_state(current_state)
        # Run stage 0 once outside the while loop to build the loop variables.
        new_marginal, new_inv_temperature, log_weights = update_weights_temperature(
            pkr.inverse_temperature, pkr.particle_info.likelihood_log_prob)
        if PRINT_DEBUG:
            tf.print('Stage:', 0, 'Beta:', new_inv_temperature, 'n_steps:',
                     pkr.num_steps, 'accept:',
                     tf.reduce_mean(pkr.particle_info.accept_prob), 'scaling:',
                     tf.reduce_mean(pkr.particle_info.scalings))
        resampled_state, resampled_particle_info = resample(
            log_weights, current_state, pkr.particle_info, seed_stream())
        next_state, acceptance_rate, tempered_log_prob = mutate(
            resampled_state, resampled_particle_info.scalings, pkr.num_steps,
            new_inv_temperature)
        next_pkr = SMCResults(
            num_steps=pkr.num_steps,
            inverse_temperature=new_inv_temperature,
            log_marginal_likelihood=pkr.log_marginal_likelihood + new_marginal,
            particle_info=ParticleInfo(
                accept_prob=acceptance_rate,
                scalings=resampled_particle_info.scalings,
                tempered_log_prob=tempered_log_prob,
                likelihood_log_prob=likelihood_log_prob_fn(*next_state),
            ))

        # Stage > 0
        def smc_body_fn(stage, state, smc_kernel_result):
            """Run one stage of SMC with constant temperature."""
            (new_marginal, new_inv_temperature,
             log_weights) = update_weights_temperature(
                 smc_kernel_result.inverse_temperature,
                 smc_kernel_result.particle_info.likelihood_log_prob)
            # TODO(b/152412213) Use a tf.scan to better collect debug info.
            if PRINT_DEBUG:
                tf.print(
                    'Stage:', stage, 'Beta:', new_inv_temperature, 'n_steps:',
                    smc_kernel_result.num_steps, 'accept:',
                    tf.reduce_mean(
                        smc_kernel_result.particle_info.accept_prob),
                    'scaling:',
                    tf.reduce_mean(smc_kernel_result.particle_info.scalings))
            resampled_state, resampled_particle_info = resample(
                log_weights, state, smc_kernel_result.particle_info,
                seed_stream())
            num_steps, scalings = tuning(smc_kernel_result.num_steps,
                                         resampled_particle_info.scalings,
                                         resampled_particle_info.accept_prob)
            next_state, acceptance_rate, tempered_log_prob = mutate(
                resampled_state, scalings, num_steps, new_inv_temperature)
            next_pkr = SMCResults(
                num_steps=num_steps,
                inverse_temperature=new_inv_temperature,
                log_marginal_likelihood=(
                    new_marginal + smc_kernel_result.log_marginal_likelihood),
                particle_info=ParticleInfo(
                    accept_prob=acceptance_rate,
                    scalings=scalings,
                    tempered_log_prob=tempered_log_prob,
                    likelihood_log_prob=likelihood_log_prob_fn(*next_state),
                ))
            return stage + 1, next_state, next_pkr

        (n_stage, final_state, final_kernel_results) = tf.while_loop(
            cond=lambda i, state, pkr: (i < max_stage) & (  # pylint: disable=g-long-lambda
                pkr.inverse_temperature < 1.),
            body=smc_body_fn,
            loop_vars=(tf.ones([], dtype=tf.int32), next_state, next_pkr),
            parallel_iterations=parallel_iterations)
        if unwrap_state_list:
            final_state = final_state[0]
        return n_stage, final_state, final_kernel_results
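
A minimal usage sketch, assuming the function above and its module-level
helpers are in scope and that TensorFlow Probability is installed; the toy
Gaussian prior/likelihood below is illustrative only and not part of the
original example.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Hypothetical toy model: a broad Gaussian prior and a tighter Gaussian
# likelihood in two dimensions.
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(2), scale_diag=2. * tf.ones(2))
likelihood = tfd.MultivariateNormalDiag(
    loc=tf.ones(2), scale_diag=.5 * tf.ones(2))

# Initialize 1000 particles by sampling from the prior.
init_state = prior.sample(1000, seed=42)

n_stage, posterior_samples, final_kernel_results = (
    sample_sequential_monte_carlo(
        prior_log_prob_fn=prior.log_prob,
        likelihood_log_prob_fn=likelihood.log_prob,
        current_state=init_state,
        seed=1))
print('Stages run:', int(n_stage))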
Example No. 29
def input_(x_tens):
    tf.print(x_tens)


def my_fn(x):
    return {k: tf.print(v, [v], k + ": ") for k, v in x.items()}