Example #1
  def create_loss():
    """Creates the loss to be optimized.

    Returns:
      bound: A float Tensor containing the value of the bound that is
        being optimized.
      loss: A float Tensor that when differentiated yields the gradients
        to apply to the model. Should be optimized via gradient descent.
    """
    inputs, targets, lengths, model = create_dataset_and_model(
        config, split="train", shuffle=True, repeat=True)
    # Compute lower bounds on the log likelihood.
    if config.bound == "elbo":
      ll_per_seq, _, _, _ = bounds.iwae(
          model, (inputs, targets), lengths, num_samples=1)
    elif config.bound == "iwae":
      ll_per_seq, _, _, _ = bounds.iwae(
          model, (inputs, targets), lengths, num_samples=config.num_samples)
    elif config.bound == "fivo":
      ll_per_seq, _, _, _, _ = bounds.fivo(
          model, (inputs, targets), lengths, num_samples=config.num_samples,
          resampling_criterion=bounds.ess_criterion)
    # Compute loss scaled by number of timesteps.
    ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
    ll_per_seq = tf.reduce_mean(ll_per_seq)

    tf.summary.scalar("train_ll_per_seq", ll_per_seq)
    tf.summary.scalar("train_ll_per_t", ll_per_t)

    if config.normalize_by_seq_len:
      return ll_per_t, -ll_per_t
    else:
      return ll_per_seq, -ll_per_seq
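
The loss returned above is simply the negated bound, so it can be fed straight into a standard TF1 optimizer. A minimal sketch of the wiring, assuming create_loss() from Example #1 is in scope; the optimizer and learning rate are illustrative assumptions, not values from the original configuration:

import tensorflow as tf

bound, loss = create_loss()
global_step = tf.train.get_or_create_global_step()
# Adam with a fixed learning rate is an assumption, not the original setup.
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss, global_step=global_step)

with tf.train.MonitoredTrainingSession() as sess:
  while not sess.should_stop():
    _, bound_value = sess.run([train_op, bound])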
Example #2
  def create_graph():
    """Creates the evaluation graph.

    Returns:
      lower_bounds: A tuple of float Tensors containing the values of the 3
        evidence lower bounds, summed across the batch.
      total_batch_length: The total number of timesteps in the batch, summed
        across batch examples.
      batch_size: The batch size.
      global_step: The global step the checkpoint was loaded from.
    """
    global_step = tf.train.get_or_create_global_step()
    inputs, targets, lengths, model = create_dataset_and_model(
        config, split=config.split, shuffle=False, repeat=False)
    # Compute lower bounds on the log likelihood.
    elbo_ll_per_seq, _, _, _ = bounds.iwae(
        model, (inputs, targets), lengths, num_samples=1)
    iwae_ll_per_seq, _, _, _ = bounds.iwae(
        model, (inputs, targets), lengths, num_samples=config.num_samples)
    fivo_ll_per_seq, _, _, _, _ = bounds.fivo(
        model, (inputs, targets), lengths, num_samples=config.num_samples,
        resampling_criterion=bounds.ess_criterion)
    elbo_ll = tf.reduce_sum(elbo_ll_per_seq)
    iwae_ll = tf.reduce_sum(iwae_ll_per_seq)
    fivo_ll = tf.reduce_sum(fivo_ll_per_seq)
    batch_size = tf.shape(lengths)[0]
    total_batch_length = tf.reduce_sum(lengths)
    return ((elbo_ll, iwae_ll, fivo_ll), total_batch_length, batch_size,
            global_step)
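
A sketch of an evaluation loop that consumes the graph above. Because the dataset is built with repeat=False, the iterator eventually raises OutOfRangeError, at which point the summed bounds can be normalized by the total number of timesteps; the checkpoint path is a hypothetical placeholder:

import tensorflow as tf

lower_bounds, total_batch_length, batch_size, global_step = create_graph()
elbo_ll, iwae_ll, fivo_ll = lower_bounds
sums = [0.0, 0.0, 0.0]
total_timesteps = 0
saver = tf.train.Saver()

with tf.Session() as sess:
  # The checkpoint directory is assumed, not taken from the original code.
  saver.restore(sess, tf.train.latest_checkpoint("/tmp/fivo/train"))
  try:
    while True:
      e, i, f, t = sess.run([elbo_ll, iwae_ll, fivo_ll, total_batch_length])
      sums = [s + v for s, v in zip(sums, [e, i, f])]
      total_timesteps += t
  except tf.errors.OutOfRangeError:
    pass

for name, total in zip(["elbo", "iwae", "fivo"], sums):
  print("%s per timestep: %f" % (name, total / total_timesteps))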
Example #3
    def create_graph():
        """Creates the evaluation graph.

    Returns:
      lower_bounds: A tuple of float Tensors containing the values of the 3
        evidence lower bounds, summed across the batch.
      total_batch_length: The total number of timesteps in the batch, summed
        across batch examples.
      batch_size: The batch size.
      global_step: The global step the checkpoint was loaded from.
    """
        global_step = tf.train.get_or_create_global_step()
        inputs, targets, lengths, model = create_dataset_and_model(
            config, split=config.split, shuffle=False, repeat=False)
        # Compute lower bounds on the log likelihood.
        elbo_ll_per_seq, _, _, _ = bounds.iwae(model, (inputs, targets),
                                               lengths,
                                               num_samples=1)
        iwae_ll_per_seq, _, _, _ = bounds.iwae(model, (inputs, targets),
                                               lengths,
                                               num_samples=config.num_samples)
        fivo_ll_per_seq, _, _, _, _ = bounds.fivo(
            model, (inputs, targets),
            lengths,
            num_samples=config.num_samples,
            resampling_criterion=bounds.ess_criterion)
        elbo_ll = tf.reduce_sum(elbo_ll_per_seq)
        iwae_ll = tf.reduce_sum(iwae_ll_per_seq)
        fivo_ll = tf.reduce_sum(fivo_ll_per_seq)
        batch_size = tf.shape(lengths)[0]
        total_batch_length = tf.reduce_sum(lengths)
        return ((elbo_ll, iwae_ll, fivo_ll), total_batch_length, batch_size,
                global_step)
Example #4
    def create_loss():
        """Creates the loss to be optimized.
        Returns:
            bound: A float Tensor containing the value of the bound that is
                   being optimized.
            loss: A float Tensor that when differentiated yields the gradients
                to apply to the model. Should be optimized via gradient descent.
        """
        inputs, targets, mmsis, lengths, model = create_dataset_and_model(
            config, config.split, shuffle=True, repeat=True)
        # Compute lower bounds on the log likelihood.
        if config.bound == "elbo":
            ll_per_seq, _, _, _ = bounds.elbo(model, (inputs, targets),
                                              lengths,
                                              num_samples=1)
        elif config.bound == "fivo":
            ll_per_seq, _, _, _, _ = bounds.fivo(
                model, (inputs, targets),
                lengths,
                num_samples=config.num_samples,
                resampling_criterion=bounds.ess_criterion)
        # Compute loss scaled by number of timesteps.
        ll_per_t = tf.reduce_mean(ll_per_seq / tf.to_float(lengths))
        ll_per_seq = tf.reduce_mean(ll_per_seq)

        tf.summary.scalar("train_ll_per_seq", ll_per_seq)
        tf.summary.scalar("train_ll_per_t", ll_per_t)

        if config.normalize_by_seq_len:
            return ll_per_t, -ll_per_t
        else:
            return ll_per_seq, -ll_per_seq
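
Every example on this page reads its hyperparameters from a config object that is never shown. A hypothetical stand-in containing only the fields referenced in the snippets above (the values are arbitrary):

import collections

Config = collections.namedtuple(
    "Config", ["bound", "num_samples", "split", "normalize_by_seq_len"])

# "fivo" with 4 particles is an arbitrary illustrative choice.
config = Config(bound="fivo", num_samples=4, split="train",
                normalize_by_seq_len=True)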
Example #5
    def create_graph():
        global_step = tf.train.get_or_create_global_step()
        inputs, targets, lengths, model = create_dataset_and_model(
            config, split=config.split, shuffle=False, repeat=False)
        # Compute the FIVO lower bound on the log likelihood.
        fivo_ll_per_seq, _, _, _, _ = bounds.fivo(
            model, (inputs, targets),
            lengths,
            num_samples=config.num_samples,
            resampling_criterion=bounds.ess_criterion)
        fivo_ll = tf.reduce_sum(fivo_ll_per_seq)

        batch_size = tf.shape(lengths)[0]
        total_batch_length = tf.reduce_sum(lengths)
        return (fivo_ll, total_batch_length, batch_size, global_step)
Example #6
def create_graph(bound,
                 state_size,
                 num_timesteps,
                 batch_size,
                 num_samples,
                 num_eval_samples,
                 resampling_schedule,
                 use_resampling_grads,
                 learning_rate,
                 lr_decay_steps,
                 train_p,
                 dtype='float64'):
    if FLAGS.use_bs:
        true_bs = None
    else:
        true_bs = [
            np.zeros([state_size]).astype(dtype) for _ in xrange(num_timesteps)
        ]

    # Make the dataset.
    true_bs, dataset = data.make_dataset(
        bs=true_bs,
        state_size=state_size,
        num_timesteps=num_timesteps,
        batch_size=batch_size,
        num_samples=num_samples,
        variance=FLAGS.variance,
        prior_type=FLAGS.p_type,
        bimodal_prior_weight=FLAGS.bimodal_prior_weight,
        bimodal_prior_mean=FLAGS.bimodal_prior_mean,
        transition_type=FLAGS.transition_type,
        fixed_observation=FLAGS.fixed_observation,
        dtype=dtype)
    itr = dataset.make_one_shot_iterator()
    _, observations = itr.get_next()
    # Make the dataset for eval
    _, eval_dataset = data.make_dataset(
        bs=true_bs,
        state_size=state_size,
        num_timesteps=num_timesteps,
        batch_size=num_eval_samples,
        num_samples=num_eval_samples,
        variance=FLAGS.variance,
        prior_type=FLAGS.p_type,
        bimodal_prior_weight=FLAGS.bimodal_prior_weight,
        bimodal_prior_mean=FLAGS.bimodal_prior_mean,
        transition_type=FLAGS.transition_type,
        fixed_observation=FLAGS.fixed_observation,
        dtype=dtype)
    eval_itr = eval_dataset.make_one_shot_iterator()
    _, eval_observations = eval_itr.get_next()

    # Make the model.
    if bound == "fivo-aux-td":
        model = models.TDModel.create(
            state_size,
            num_timesteps,
            variance=FLAGS.variance,
            train_p=train_p,
            p_type=FLAGS.p_type,
            q_type=FLAGS.q_type,
            mixing_coeff=FLAGS.bimodal_prior_weight,
            prior_mode_mean=FLAGS.bimodal_prior_mean,
            observation_variance=FLAGS.observation_variance,
            transition_type=FLAGS.transition_type,
            use_bs=FLAGS.use_bs,
            dtype=tf.as_dtype(dtype),
            random_seed=FLAGS.random_seed)
    else:
        model = models.Model.create(
            state_size,
            num_timesteps,
            variance=FLAGS.variance,
            train_p=train_p,
            p_type=FLAGS.p_type,
            q_type=FLAGS.q_type,
            mixing_coeff=FLAGS.bimodal_prior_weight,
            prior_mode_mean=FLAGS.bimodal_prior_mean,
            observation_variance=FLAGS.observation_variance,
            transition_type=FLAGS.transition_type,
            use_bs=FLAGS.use_bs,
            r_sigma_init=FLAGS.r_sigma_init,
            dtype=tf.as_dtype(dtype),
            random_seed=FLAGS.random_seed)

    # Compute the bound and loss
    if bound == "iwae":
        (_, losses, ema_op, _, _) = bounds.iwae(model,
                                                observations,
                                                num_timesteps,
                                                num_samples=num_samples)
        (eval_log_p_hat, _, _, eval_states,
         eval_log_weights) = bounds.iwae(model,
                                         eval_observations,
                                         num_timesteps,
                                         num_samples=num_eval_samples,
                                         summarize=True)

        eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

    elif "fivo" in bound:
        if bound == "fivo-aux-td":
            (_, losses, ema_op, _,
             _) = bounds.fivo_aux_td(model,
                                     observations,
                                     num_timesteps,
                                     resampling_schedule=resampling_schedule,
                                     num_samples=num_samples)
            (eval_log_p_hat, _, _, eval_states,
             eval_log_weights) = bounds.fivo_aux_td(
                 model,
                 eval_observations,
                 num_timesteps,
                 resampling_schedule=resampling_schedule,
                 num_samples=num_eval_samples,
                 summarize=True)
        else:
            (_, losses, ema_op, _,
             _) = bounds.fivo(model,
                              observations,
                              num_timesteps,
                              resampling_schedule=resampling_schedule,
                              use_resampling_grads=use_resampling_grads,
                              resampling_type=FLAGS.resampling_method,
                              aux=("aux" in bound),
                              num_samples=num_samples)
            (eval_log_p_hat, _, _, eval_states,
             eval_log_weights) = bounds.fivo(
                 model,
                 eval_observations,
                 num_timesteps,
                 resampling_schedule=resampling_schedule,
                 use_resampling_grads=False,
                 resampling_type="multinomial",
                 aux=("aux" in bound),
                 num_samples=num_eval_samples,
                 summarize=True)
        eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

    summ.summarize_ess(eval_log_weights, only_last_timestep=True)

    # if FLAGS.p_type == "bimodal":
    # # create the observations that showcase the model.
    # mode_odds_ratio = tf.convert_to_tensor([1., 3., 1./3., 512., 1./512.],
    #                                        dtype=tf.float64)
    # mode_odds_ratio = tf.expand_dims(mode_odds_ratio, 1)
    # k = ((num_timesteps+1) * FLAGS.variance) / (2*FLAGS.bimodal_prior_mean)
    # explain_obs = tf.reduce_sum(model.p.bs) + tf.log(mode_odds_ratio) * k
    # explain_obs = tf.tile(explain_obs, [num_eval_samples, 1])
    # # run the model on the explainable observations
    # if bound == "iwae":
    #   (_, _, _, explain_states, explain_log_weights) = bounds.iwae(
    #       model,
    #       explain_obs,
    #       num_timesteps,
    #       num_samples=num_eval_samples)
    # elif bound == "fivo" or "fivo-aux":
    #   (_, _, _, explain_states, explain_log_weights) = bounds.fivo(
    #       model,
    #       explain_obs,
    #       num_timesteps,
    #       resampling_schedule=resampling_schedule,
    #       use_resampling_grads=False,
    #       resampling_type="multinomial",
    #       aux=("aux" in bound),
    #       num_samples=num_eval_samples)
    # summ.summarize_particles(explain_states,
    #                          explain_log_weights,
    #                          explain_obs,
    #                          model)

    # Calculate the true likelihood.
    if hasattr(model.p, 'likelihood') and callable(
            getattr(model.p, 'likelihood')):
        eval_likelihood = model.p.likelihood(
            eval_observations) / FLAGS.num_timesteps
    else:
        eval_likelihood = tf.zeros_like(eval_log_p_hat)

    tf.summary.scalar("log_p_hat", eval_log_p_hat)
    tf.summary.scalar("likelihood", eval_likelihood)
    tf.summary.scalar("bound_gap", eval_likelihood - eval_log_p_hat)
    summ.summarize_model(model,
                         true_bs,
                         eval_observations,
                         eval_states,
                         bound,
                         summarize_r=not bound == "fivo-aux-td")

    # Compute and apply grads.
    global_step = tf.train.get_or_create_global_step()

    apply_grads = make_apply_grads_op(losses, global_step, learning_rate,
                                      lr_decay_steps)

    # Update the emas after applying the grads.
    with tf.control_dependencies([apply_grads]):
        train_op = tf.group(ema_op)
        #train_op = tf.group(ema_op, add_check_numerics_ops())

    return global_step, train_op, eval_log_p_hat, eval_likelihood
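
A minimal driver for the training graph above, assuming the module-level FLAGS that create_graph reads are already defined. Every argument value here is an illustrative assumption rather than a default from the original flags; in particular, resampling_schedule is assumed to be a per-timestep list of booleans:

import tensorflow as tf

global_step, train_op, eval_log_p_hat, eval_likelihood = create_graph(
    bound="fivo",
    state_size=1,
    num_timesteps=5,
    batch_size=4,
    num_samples=4,
    num_eval_samples=64,
    resampling_schedule=[True] * 5,
    use_resampling_grads=False,
    learning_rate=1e-4,
    lr_decay_steps=10000,
    train_p=True)

with tf.train.MonitoredTrainingSession() as sess:
  while not sess.should_stop():
    _, step = sess.run([train_op, global_step])
    if step % 100 == 0:
      print("step %d, eval log p_hat %f" % (step, sess.run(eval_log_p_hat)))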
Example #7
def create_long_chain_graph(bound,
                            state_size,
                            num_obs,
                            steps_per_obs,
                            batch_size,
                            num_samples,
                            num_eval_samples,
                            resampling_schedule,
                            use_resampling_grads,
                            learning_rate,
                            lr_decay_steps,
                            dtype="float64"):
    num_timesteps = num_obs * steps_per_obs + 1
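    # E.g. num_obs=3 and steps_per_obs=2 gives 3 * 2 + 1 = 7 timesteps.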
    # Make the dataset.
    dataset = data.make_long_chain_dataset(
        state_size=state_size,
        num_obs=num_obs,
        steps_per_obs=steps_per_obs,
        batch_size=batch_size,
        num_samples=num_samples,
        variance=FLAGS.variance,
        observation_variance=FLAGS.observation_variance,
        dtype=dtype,
        observation_type=FLAGS.observation_type,
        transition_type=FLAGS.transition_type,
        fixed_observation=FLAGS.fixed_observation)
    itr = dataset.make_one_shot_iterator()
    _, observations = itr.get_next()
    # Make the dataset for eval
    eval_dataset = data.make_long_chain_dataset(
        state_size=state_size,
        num_obs=num_obs,
        steps_per_obs=steps_per_obs,
        batch_size=batch_size,
        num_samples=num_eval_samples,
        variance=FLAGS.variance,
        observation_variance=FLAGS.observation_variance,
        dtype=dtype,
        observation_type=FLAGS.observation_type,
        transition_type=FLAGS.transition_type,
        fixed_observation=FLAGS.fixed_observation)
    eval_itr = eval_dataset.make_one_shot_iterator()
    _, eval_observations = eval_itr.get_next()

    # Make the model.
    model = models.LongChainModel.create(
        state_size,
        num_obs,
        steps_per_obs,
        observation_type=FLAGS.observation_type,
        transition_type=FLAGS.transition_type,
        variance=FLAGS.variance,
        observation_variance=FLAGS.observation_variance,
        dtype=tf.as_dtype(dtype),
        disable_r=FLAGS.disable_r)

    # Compute the bound and loss
    if bound == "iwae":
        (_, losses, ema_op, _, _) = bounds.iwae(model,
                                                observations,
                                                num_timesteps,
                                                num_samples=num_samples)
        (eval_log_p_hat, _, _, _,
         eval_log_weights) = bounds.iwae(model,
                                         eval_observations,
                                         num_timesteps,
                                         num_samples=num_eval_samples,
                                         summarize=False)
        eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
    elif bound == "fivo" or "fivo-aux":
        (_, losses, ema_op, _,
         _) = bounds.fivo(model,
                          observations,
                          num_timesteps,
                          resampling_schedule=resampling_schedule,
                          use_resampling_grads=use_resampling_grads,
                          resampling_type=FLAGS.resampling_method,
                          aux=("aux" in bound),
                          num_samples=num_samples)
        (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.fivo(
            model,
            eval_observations,
            num_timesteps,
            resampling_schedule=resampling_schedule,
            use_resampling_grads=False,
            resampling_type="multinomial",
            aux=("aux" in bound),
            num_samples=num_eval_samples,
            summarize=False)
        eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

    summ.summarize_ess(eval_log_weights, only_last_timestep=True)

    tf.summary.scalar("log_p_hat", eval_log_p_hat)

    # Compute and apply grads.
    global_step = tf.train.get_or_create_global_step()

    apply_grads = make_apply_grads_op(losses, global_step, learning_rate,
                                      lr_decay_steps)

    # Update the emas after applying the grads.
    with tf.control_dependencies([apply_grads]):
        train_op = tf.group(ema_op)

    # We can't calculate the likelihood for most of these models
    # so we just return zeros.
    eval_likelihood = tf.zeros([], dtype=dtype)
    return global_step, train_op, eval_log_p_hat, eval_likelihood
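
make_apply_grads_op is called in both graph builders above but is not defined on this page. A purely hypothetical sketch that is merely consistent with the call sites, assuming losses is a list of loss Tensors; the real helper may differ substantially:

import tensorflow as tf

def make_apply_grads_op(losses, global_step, learning_rate, lr_decay_steps):
    # Hypothetical: exponentially decay the learning rate over lr_decay_steps
    # (the 0.5 decay rate is an arbitrary assumption) and take one Adam step
    # on the summed losses.
    lr = tf.train.exponential_decay(
        learning_rate, global_step, lr_decay_steps, 0.5)
    opt = tf.train.AdamOptimizer(lr)
    total_loss = tf.add_n([tf.reduce_mean(loss) for loss in losses])
    return opt.minimize(total_loss, global_step=global_step)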
Example #8
def create_graph(bound, state_size, num_timesteps, batch_size,
                 num_samples, num_eval_samples, resampling_schedule,
                 use_resampling_grads, learning_rate, lr_decay_steps,
                 train_p, dtype='float64'):
  if FLAGS.use_bs:
    true_bs = None
  else:
    true_bs = [np.zeros([state_size]).astype(dtype) for _ in xrange(num_timesteps)]

  # Make the dataset.
  true_bs, dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Make the dataset for eval
  _, eval_dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=num_eval_samples,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()

  # Make the model.
  if bound == "fivo-aux-td":
    model = models.TDModel.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)
  else:
    model = models.Model.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        r_sigma_init=FLAGS.r_sigma_init,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)

  # Compute the bound and loss
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=True)

    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

  elif "fivo" in bound:
    if bound == "fivo-aux-td":
      (_, losses, ema_op, _, _) = bounds.fivo_aux_td(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo_aux_td(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_eval_samples,
          summarize=True)
    else:
      (_, losses, ema_op, _, _) = bounds.fivo(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=use_resampling_grads,
          resampling_type=FLAGS.resampling_method,
          aux=("aux" in bound),
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=False,
          resampling_type="multinomial",
          aux=("aux" in bound),
          num_samples=num_eval_samples,
          summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

  summ.summarize_ess(eval_log_weights, only_last_timestep=True)

  # if FLAGS.p_type == "bimodal":
    # # create the observations that showcase the model.
    # mode_odds_ratio = tf.convert_to_tensor([1., 3., 1./3., 512., 1./512.],
    #                                        dtype=tf.float64)
    # mode_odds_ratio = tf.expand_dims(mode_odds_ratio, 1)
    # k = ((num_timesteps+1) * FLAGS.variance) / (2*FLAGS.bimodal_prior_mean)
    # explain_obs = tf.reduce_sum(model.p.bs) + tf.log(mode_odds_ratio) * k
    # explain_obs = tf.tile(explain_obs, [num_eval_samples, 1])
    # # run the model on the explainable observations
    # if bound == "iwae":
    #   (_, _, _, explain_states, explain_log_weights) = bounds.iwae(
    #       model,
    #       explain_obs,
    #       num_timesteps,
    #       num_samples=num_eval_samples)
    # elif bound == "fivo" or "fivo-aux":
    #   (_, _, _, explain_states, explain_log_weights) = bounds.fivo(
    #       model,
    #       explain_obs,
    #       num_timesteps,
    #       resampling_schedule=resampling_schedule,
    #       use_resampling_grads=False,
    #       resampling_type="multinomial",
    #       aux=("aux" in bound),
    #       num_samples=num_eval_samples)
    # summ.summarize_particles(explain_states,
    #                          explain_log_weights,
    #                          explain_obs,
    #                          model)

  # Calculate the true likelihood.
  if hasattr(model.p, 'likelihood') and callable(getattr(model.p, 'likelihood')):
    eval_likelihood = model.p.likelihood(eval_observations) / FLAGS.num_timesteps
  else:
    eval_likelihood = tf.zeros_like(eval_log_p_hat)

  tf.summary.scalar("log_p_hat", eval_log_p_hat)
  tf.summary.scalar("likelihood", eval_likelihood)
  tf.summary.scalar("bound_gap", eval_likelihood - eval_log_p_hat)
  summ.summarize_model(model, true_bs, eval_observations, eval_states, bound,
                       summarize_r=not bound == "fivo-aux-td")

  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()

  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)

  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)
    #train_op = tf.group(ema_op, add_check_numerics_ops())

  return global_step, train_op, eval_log_p_hat, eval_likelihood
Example #9
def create_long_chain_graph(bound, state_size, num_obs, steps_per_obs,
                            batch_size, num_samples, num_eval_samples,
                            resampling_schedule, use_resampling_grads,
                            learning_rate, lr_decay_steps, dtype="float64"):
  num_timesteps = num_obs * steps_per_obs + 1
  # Make the dataset.
  dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()
  # Make the dataset for eval
  eval_dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()

  # Make the model.
  model = models.LongChainModel.create(
      state_size,
      num_obs,
      steps_per_obs,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=tf.as_dtype(dtype),
      disable_r=FLAGS.disable_r)

  # Compute the bound and loss
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  elif bound == "fivo" or "fivo-aux":
    (_, losses, ema_op, _, _) = bounds.fivo(
        model,
        observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=use_resampling_grads,
        resampling_type=FLAGS.resampling_method,
        aux=("aux" in bound),
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.fivo(
        model,
        eval_observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=False,
        resampling_type="multinomial",
        aux=("aux" in bound),
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

  summ.summarize_ess(eval_log_weights, only_last_timestep=True)

  tf.summary.scalar("log_p_hat", eval_log_p_hat)

  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()

  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)

  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)

  # We can't calculate the likelihood for most of these models
  # so we just return zeros.
  eval_likelihood = tf.zeros([], dtype=dtype)
  return global_step, train_op, eval_log_p_hat, eval_likelihood