Example #1
def setUp(self):
    super(LEOTest, self).setUp()
    self._problem = _random_problem_instance(5, 7, 4)
    # This doesn't call any function, so doesn't need the mocks to be started.
    self._config = get_test_config()
    self._leo = model.LEO(config=self._config)
    self.addCleanup(mock.patch.stopall)
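For context, a minimal sketch of the mock pattern this setUp relies on (the patch target below is an illustration only, not from the original tests): patches started with mock.patch(...).start() stay active until mock.patch.stopall() runs, which is why stopall is registered as a cleanup.

from unittest import mock

def _start_test_mocks():
    # Illustrative target only; the real tests patch their own dependencies.
    patcher = mock.patch("os.getcwd", return_value="/tmp")
    patcher.start()  # Remains active until mock.patch.stopall() is called.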
Example #2
def construct_graph(outer_model_config):
    """Constructs the optimization graph."""
    inner_model_config = config.get_inner_model_config()
    tf.logging.info("inner_model_config: {}".format(inner_model_config))
    leo = model.LEO(inner_model_config, use_64bits_dtype=False)

    num_classes = outer_model_config["num_classes"]
    num_tr_examples_per_class = outer_model_config["num_tr_examples_per_class"]
    metatrain_batch = _construct_examples_batch(
        outer_model_config["metatrain_batch_size"], "train", num_classes,
        num_tr_examples_per_class,
        outer_model_config["num_val_examples_per_class"])
    metatrain_loss, metatrain_accuracy = _construct_loss_and_accuracy(
        leo, metatrain_batch, True)

    metatrain_gradients, metatrain_variables = leo.grads_and_vars(
        metatrain_loss)

    # Avoids NaNs in summaries.
    metatrain_loss = tf.cond(tf.is_nan(metatrain_loss),
                             lambda: tf.zeros_like(metatrain_loss),
                             lambda: metatrain_loss)

    metatrain_gradients = _clip_gradients(
        metatrain_gradients, outer_model_config["gradient_threshold"],
        outer_model_config["gradient_norm_threshold"])

    _construct_training_summaries(metatrain_loss, metatrain_accuracy,
                                  metatrain_gradients, metatrain_variables)
    optimizer = tf.train.AdamOptimizer(
        learning_rate=outer_model_config["outer_lr"])
    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.apply_gradients(
        list(zip(metatrain_gradients, metatrain_variables)), global_step)

    data_config = config.get_data_config()
    tf.logging.info("data_config: {}".format(data_config))
    total_examples_per_class = data_config["total_examples_per_class"]
    metavalid_batch = _construct_examples_batch(
        outer_model_config["metavalid_batch_size"], "val", num_classes,
        num_tr_examples_per_class,
        total_examples_per_class - num_tr_examples_per_class)
    metavalid_loss, metavalid_accuracy = _construct_loss_and_accuracy(
        leo, metavalid_batch, False)

    metatest_batch = _construct_examples_batch(
        outer_model_config["metatest_batch_size"], "test", num_classes,
        num_tr_examples_per_class,
        total_examples_per_class - num_tr_examples_per_class)
    _, metatest_accuracy = _construct_loss_and_accuracy(
        leo, metatest_batch, False)
    _construct_validation_summaries(metavalid_loss, metavalid_accuracy)

    return (train_op, global_step, metatrain_accuracy, metavalid_accuracy,
            metatest_accuracy)
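A minimal driver loop showing how the ops returned by construct_graph could be run. This is a sketch, not part of the original repo: run_training, num_steps, and the choice of tf.train.MonitoredTrainingSession are assumptions.

import tensorflow as tf

def run_training(outer_model_config, num_steps=1000):
    (train_op, global_step, metatrain_accuracy, metavalid_accuracy,
     metatest_accuracy) = construct_graph(outer_model_config)
    with tf.train.MonitoredTrainingSession() as sess:
        step = 0
        while step < num_steps:
            # One outer-loop update; fetch the step and accuracy alongside.
            _, step, accuracy = sess.run(
                [train_op, global_step, metatrain_accuracy])
            if step % 100 == 0:
                tf.logging.info(
                    "step %d: metatrain accuracy %f", step, accuracy)
            # (metavalid_accuracy / metatest_accuracy could be fetched
            # periodically in the same way.)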
Example #3
def construct_graph(outer_model_config):
  """Constructs the optimization graph."""
  inner_model_config = config.get_inner_model_config()
  tf.logging.info("inner_model_config: {}".format(inner_model_config))
  leo = model.LEO(inner_model_config, use_64bits_dtype=False)

  num_classes = outer_model_config["num_classes"]
  num_tr_examples_per_class = outer_model_config["num_tr_examples_per_class"]
  metatrain_batch = _construct_examples_batch(
      outer_model_config["metatrain_batch_size"], "train", num_classes,
      num_tr_examples_per_class,
      outer_model_config["num_val_examples_per_class"])
  (metatrain_loss, metatrain_accuracy, metatrain_generalization_loss,
   kl_components, adapted_kl_components, kl_zn, adapted_kl_zn, kl, adapted_kl,
   latents, adapted_latents, spurious) = _construct_loss_and_accuracy(
       leo, metatrain_batch, True)  # Returned by the inner_leo_loop.

  metatrain_gradients, metatrain_variables = leo.grads_and_vars(metatrain_loss)

  # Avoids NaNs in summaries.
  metatrain_loss = tf.cond(tf.is_nan(metatrain_loss),
                           lambda: tf.zeros_like(metatrain_loss),
                           lambda: metatrain_loss)
  # Analogous NaN guards for the KL diagnostics (adapted_kl_components,
  # adapted_kl_zn, adapted_kl, kl) are left disabled.
  metatrain_gradients = _clip_gradients(
      metatrain_gradients, outer_model_config["gradient_threshold"],
      outer_model_config["gradient_norm_threshold"])

  _construct_training_summaries(
      metatrain_loss, metatrain_accuracy, metatrain_gradients,
      metatrain_variables, metatrain_generalization_loss, kl_components,
      adapted_kl_components, kl_zn, adapted_kl_zn, kl, adapted_kl, latents,
      adapted_latents, spurious)
  optimizer = tf.train.AdamOptimizer(
      learning_rate=outer_model_config["outer_lr"])
  global_step = tf.train.get_or_create_global_step()
  train_op = optimizer.apply_gradients(
      list(zip(metatrain_gradients, metatrain_variables)), global_step)
  # After applying the gradients, compute the meta-validation loss using the
  # same algorithm.
  data_config = config.get_data_config()
  tf.logging.info("data_config: {}".format(data_config))
  total_examples_per_class = data_config["total_examples_per_class"]
  metavalid_batch = _construct_examples_batch(
      outer_model_config["metavalid_batch_size"], "val", num_classes,
      num_tr_examples_per_class,
      total_examples_per_class - num_tr_examples_per_class)
  (metavalid_loss, metavalid_accuracy, metavalid_generalization_loss,
   _, _, _, _, _, _, _, _, _) = _construct_loss_and_accuracy(
       leo, metavalid_batch, False)

  metatest_batch = _construct_examples_batch(
      outer_model_config["metatest_batch_size"], "test", num_classes,
      num_tr_examples_per_class,
      total_examples_per_class - num_tr_examples_per_class)
  (_, metatest_accuracy,
   _, _, _, _, _, _, _, _, _, _) = _construct_loss_and_accuracy(
       leo, metatest_batch, False)

  _construct_validation_summaries(metavalid_loss, metavalid_accuracy,
                                  metavalid_generalization_loss)

  return (train_op, global_step, metatrain_accuracy, metavalid_accuracy,
          metatest_accuracy, kl_components, adapted_kl_components, kl_zn,
          adapted_kl_zn, kl, adapted_kl, latents, adapted_latents, spurious)
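The NaN guard above is copy-pasted per tensor (and disabled for the KL diagnostics). A small helper like the following sketch, which is not in the original code, would express the same tf.cond pattern once and could be applied to each diagnostic tensor:

def _zero_if_nan(tensor):
  """Returns zeros in place of `tensor` when it is NaN (scalar tensors)."""
  return tf.cond(tf.is_nan(tensor),
                 lambda: tf.zeros_like(tensor),
                 lambda: tensor)

# Example: metatrain_loss = _zero_if_nan(metatrain_loss)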
Example #4
def test_construct_float32_leo_graph(self):
    leo = model.LEO(use_64bits_dtype=False, config=self._config)
    problem_instance_32_bits = _random_problem_instance(
        use_64bits_dtype=False)
    # Smoke test: constructing the float32 graph should not raise.
    leo(problem_instance_32_bits)
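For reference, a hypothetical reconstruction of the _random_problem_instance helper used in examples #1 and #4. The container, field names, shapes, and defaults are all assumptions, chosen only to be consistent with the two call sites above (_random_problem_instance(5, 7, 4) and _random_problem_instance(use_64bits_dtype=False)); the real helper lives in the repo's test module.

import collections

import numpy as np
import tensorflow as tf

# Hypothetical container; the real repo defines its own problem structure.
ProblemInstance = collections.namedtuple(
    "ProblemInstance", ["tr_input", "tr_output", "val_input", "val_output"])

def _random_problem_instance(num_classes=5, num_examples=7, embedding_dim=4,
                             use_64bits_dtype=True):
    """Builds a random few-shot problem in the requested precision."""
    dtype = np.float64 if use_64bits_dtype else np.float32
    label_dtype = np.int64 if use_64bits_dtype else np.int32
    inputs = np.random.normal(
        size=[1, num_classes, num_examples, embedding_dim]).astype(dtype)
    labels = np.tile(
        np.arange(num_classes, dtype=label_dtype)[None, :, None, None],
        [1, 1, num_examples, 1])
    return ProblemInstance(
        tr_input=tf.constant(inputs), tr_output=tf.constant(labels),
        val_input=tf.constant(inputs), val_output=tf.constant(labels))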