Exemplo n.º 1
0
def construct_graph(outer_model_config, a, b, layers):
    """Constructs the optimization graph.

    Builds a LeastSquareMeta learner, the meta-train loss and its Adam
    update op, plus meta-validation and meta-test accuracy ops and the
    associated summaries.

    Args:
        outer_model_config: dict of outer-loop hyperparameters (batch
            sizes, class/example counts, learning rate, gradient
            thresholds).
        a: inner-loop parameter forwarded to model.LeastSquareMeta.
        b: l2 weight forwarded to model.LeastSquareMeta.
        layers: layer specification forwarded to model.LeastSquareMeta.

    Returns:
        Tuple of (train_op, global_step, metatrain_accuracy,
        metavalid_accuracy, metatest_accuracy).
    """
    inner_cfg = config.get_inner_model_config()
    tf.logging.info("inner_model_config: {}".format(inner_cfg))
    n_classes = outer_model_config["num_classes"]
    meta_learner = model.LeastSquareMeta(
        layers, a, n_classes, limited=False, l2_weight=b)

    tr_per_class = outer_model_config["num_tr_examples_per_class"]
    train_batch = _construct_examples_batch_unlimited(
        outer_model_config["metatrain_batch_size"], "train", n_classes,
        tr_per_class, outer_model_config["num_val_examples_per_class"])
    train_loss, _, train_accuracy = _construct_loss_and_accuracy(
        meta_learner, train_batch, True)

    gradients, variables = meta_learner.grads_and_vars(train_loss)

    # Replace a NaN loss with zeros so summaries stay finite.
    train_loss = tf.cond(tf.is_nan(train_loss),
                         lambda: tf.zeros_like(train_loss),
                         lambda: train_loss)

    gradients = _clip_gradients(
        gradients, outer_model_config["gradient_threshold"],
        outer_model_config["gradient_norm_threshold"])

    _construct_training_summaries(train_loss, train_accuracy,
                                  gradients, variables)
    optimizer = tf.train.AdamOptimizer(
        learning_rate=outer_model_config["outer_lr"])
    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.apply_gradients(
        list(zip(gradients, variables)), global_step)

    data_cfg = config.get_data_config()
    tf.logging.info("data_config: {}".format(data_cfg))
    total_per_class = data_cfg["total_examples_per_class"]
    valid_batch = _construct_examples_batch_unlimited(
        outer_model_config["metavalid_batch_size"], "val", n_classes,
        tr_per_class, total_per_class - tr_per_class)
    valid_loss, _, valid_accuracy = _construct_loss_and_accuracy(
        meta_learner, valid_batch, False)

    test_batch = _construct_examples_batch_unlimited(
        outer_model_config["metatest_batch_size"], "test", n_classes,
        tr_per_class, total_per_class - tr_per_class)
    _, _, test_accuracy = _construct_loss_and_accuracy(
        meta_learner, test_batch, False)
    _construct_validation_summaries(valid_loss, valid_accuracy)

    return (train_op, global_step, train_accuracy, valid_accuracy,
            test_accuracy)
Exemplo n.º 2
0
def _construct_examples_batch(batch_size, split, num_classes,
                              num_tr_examples_per_class,
                              num_val_examples_per_class):
  """Fetches one batch of examples for `split` and unpacks it."""
  provider = data.DataProvider(split, config.get_data_config())
  batch = provider.get_batch(batch_size, num_classes,
                             num_tr_examples_per_class,
                             num_val_examples_per_class)
  return utils.unpack_data(batch)
Exemplo n.º 3
0
def main(argv):
    """Entry point: prints the active configs and runs training."""
    del argv  # Unused.
    tf.logging.set_verbosity(tf.logging.INFO)
    for cfg in (config.get_data_config(),
                config.get_inner_model_config(),
                config.get_outer_model_config()):
        print(cfg)

    run_training_loop(FLAGS.checkpoint_path)
Exemplo n.º 4
0
def _construct_examples_batch(batch_size, split, num_classes,
                              num_tr_examples_per_class,
                              num_val_examples_per_class,
                              use_cross=False):
  """Fetches and unpacks one examples batch, optionally cross-domain."""
  provider = data.DataProvider(split, config.get_data_config(),
                               feat_dim=FLAGS.feat_dim, use_cross=use_cross)
  batch = provider.get_batch(
      batch_size, num_classes, num_tr_examples_per_class,
      num_val_examples_per_class,
      num_pretrain_classes=FLAGS.num_pretrain_classes)
  return utils.unpack_data(batch)
Exemplo n.º 5
0
def _construct_examples_batch(batch_size,
                              split,
                              num_classes,
                              num_tr_examples_per_class,
                              num_val_examples_per_class,
                              db_path,
                              sp_para=None):
    """Loads a prebuilt example DB and unpacks one batch from it.

    When `sp_para` is given, it is unpacked as
    (test_id, sp_bias, weights, k) and used to configure the provider's
    self-paced sample weighting and its held-out test task.
    """
    provider = data.DataProvider(split, config.get_data_config())
    provider.load_db(db_path)
    if sp_para:
        test_id, sp_bias, weights, k = sp_para
        provider.set_sp_paras(weights, sp_bias)
        provider.set_test_id(test_id, k)

    batch = provider.get_batch(batch_size, num_classes,
                               num_tr_examples_per_class,
                               num_val_examples_per_class)
    return utils.unpack_data(batch)
Exemplo n.º 6
0
def build_db(checkpoint_path, db_name, sample_size, db_title=""):
    """Creates (only if absent) an example database under checkpoint_path.

    The DB directory name encodes split, title, sample size, and the
    train/validation example counts per class.

    Returns:
        The basename of the DB directory.
    """
    outer_model_config = config.get_outer_model_config()
    num_classes = outer_model_config["num_classes"]
    tr_size = outer_model_config["num_tr_examples_per_class"]

    # NOTE(review): for the test split, val_size assumes 600 total
    # examples per class — confirm against the dataset config.
    val_size = (600 - tr_size if db_name == "test"
                else outer_model_config["num_val_examples_per_class"])

    db_dir = "%s%s_%i_%i_%i" % (db_name, db_title, sample_size, tr_size,
                                val_size)
    save_path = osp.join(checkpoint_path, db_dir)

    if not osp.exists(save_path):
        provider = data.DataProvider(db_name,
                                     config.get_data_config(),
                                     verbose=False)
        provider.create_db(sample_size, num_classes, tr_size, val_size)
        provider.save_db(save_path)

    return osp.basename(save_path)
Exemplo n.º 7
0
def construct_graph(outer_model_config):
  """Constructs the optimization graph.

  Builds the LEO meta-learner, the meta-train loss/accuracy and its Adam
  update op, plus meta-validation and meta-test ops and their summaries.

  Args:
    outer_model_config: dict of outer-loop hyperparameters (batch sizes,
      class/example counts, learning rate, gradient thresholds).

  Returns:
    Tuple of (train_op, global_step, metatrain_accuracy,
    metavalid_accuracy, metatest_accuracy, kl_components,
    adapted_kl_components, kl_zn, adapted_kl_zn, kl, adapted_kl, latents,
    adapted_latents, spurious).
  """
  inner_model_config = config.get_inner_model_config()
  tf.logging.info("inner_model_config: {}".format(inner_model_config))
  leo = model.LEO(inner_model_config, use_64bits_dtype=False)

  num_classes = outer_model_config["num_classes"]
  num_tr_examples_per_class = outer_model_config["num_tr_examples_per_class"]
  metatrain_batch = _construct_examples_batch(
      outer_model_config["metatrain_batch_size"], "train", num_classes,
      num_tr_examples_per_class,
      outer_model_config["num_val_examples_per_class"])
  # Values returned by the inner LEO loop.
  (metatrain_loss, metatrain_accuracy, metatrain_generalization_loss,
   kl_components, adapted_kl_components, kl_zn, adapted_kl_zn, kl,
   adapted_kl, latents, adapted_latents,
   spurious) = _construct_loss_and_accuracy(leo, metatrain_batch, True)

  metatrain_gradients, metatrain_variables = leo.grads_and_vars(metatrain_loss)

  # Avoids NaNs in summaries.
  metatrain_loss = tf.cond(tf.is_nan(metatrain_loss),
                           lambda: tf.zeros_like(metatrain_loss),
                           lambda: metatrain_loss)

  metatrain_gradients = _clip_gradients(
      metatrain_gradients, outer_model_config["gradient_threshold"],
      outer_model_config["gradient_norm_threshold"])

  _construct_training_summaries(
      metatrain_loss, metatrain_accuracy, metatrain_gradients,
      metatrain_variables, metatrain_generalization_loss, kl_components,
      adapted_kl_components, kl_zn, adapted_kl_zn, kl, adapted_kl, latents,
      adapted_latents, spurious)
  optimizer = tf.train.AdamOptimizer(
      learning_rate=outer_model_config["outer_lr"])
  global_step = tf.train.get_or_create_global_step()
  train_op = optimizer.apply_gradients(
      list(zip(metatrain_gradients, metatrain_variables)), global_step)

  # After applying the gradients, compute the meta-validation loss using
  # the same algorithm.
  data_config = config.get_data_config()
  tf.logging.info("data_config: {}".format(data_config))
  total_examples_per_class = data_config["total_examples_per_class"]
  metavalid_batch = _construct_examples_batch(
      outer_model_config["metavalid_batch_size"], "val", num_classes,
      num_tr_examples_per_class,
      total_examples_per_class - num_tr_examples_per_class)
  (metavalid_loss, metavalid_accuracy, metavalid_generalization_loss,
   *_unused) = _construct_loss_and_accuracy(leo, metavalid_batch, False)

  metatest_batch = _construct_examples_batch(
      outer_model_config["metatest_batch_size"], "test", num_classes,
      num_tr_examples_per_class,
      total_examples_per_class - num_tr_examples_per_class)
  _, metatest_accuracy, *_unused = _construct_loss_and_accuracy(
      leo, metatest_batch, False)

  _construct_validation_summaries(metavalid_loss, metavalid_accuracy,
                                  metavalid_generalization_loss)

  return (train_op, global_step, metatrain_accuracy, metavalid_accuracy,
          metatest_accuracy, kl_components, adapted_kl_components, kl_zn,
          adapted_kl_zn, kl, adapted_kl, latents, adapted_latents, spurious)
Exemplo n.º 8
0
def construct_graph(outer_model_config):
    """Constructs the optimization graph.

    Builds an IFSL meta-learner (fully flag-configured when
    FLAGS.deconfound is set), the meta-train loss with its Adam update op,
    meta-validation/meta-test accuracy ops, and a single-episode
    "break down" pass producing per-episode hardness/correctness tensors.

    Args:
        outer_model_config: dict of outer-loop hyperparameters (batch
            sizes, class/example counts, learning rate, gradient
            thresholds).

    Returns:
        Tuple of (train_op, global_step, metatrain_accuracy,
        metavalid_accuracy, metatest_accuracy, metatrain_dacc,
        metavalid_dacc, metatest_dacc, hardness, correct).
    """
    inner_model_config = config.get_inner_model_config()
    tf.logging.info("inner_model_config: {}".format(inner_model_config))
    if FLAGS.deconfound:
        # Deconfounded variant: every IFSL option comes from command-line
        # flags.
        leo = model.IFSL(inner_model_config,
                         use_64bits_dtype=False,
                         n_splits=FLAGS.n_splits,
                         is_cosine_feature=FLAGS.is_cosine_feature,
                         fusion=FLAGS.fusion,
                         classifier=FLAGS.classifier,
                         num_classes=FLAGS.pretrain_num_classes,
                         logit_fusion=FLAGS.logit_fusion,
                         use_x_only=FLAGS.use_x_only,
                         preprocess_before_split=FLAGS.preprocess_before_split,
                         preprocess_after_split=FLAGS.preprocess_after_split,
                         normalize_before_center=FLAGS.normalize_before_center,
                         normalize_d=FLAGS.normalize_d,
                         normalize_ed=FLAGS.normalize_ed)
    else:
        # Baseline: a fixed IFSL configuration given positionally —
        # presumably matching the keyword order used above (verify against
        # model.IFSL's signature).
        # leo = model.LEO(inner_model_config, use_64bits_dtype=False)
        leo = model.IFSL(inner_model_config, False, 1, False, "concat",
                         "single", FLAGS.pretrain_num_classes, "product", True)

    num_classes = outer_model_config["num_classes"]
    num_tr_examples_per_class = outer_model_config["num_tr_examples_per_class"]
    metatrain_batch = _construct_examples_batch(
        outer_model_config["metatrain_batch_size"], "train", num_classes,
        num_tr_examples_per_class,
        outer_model_config["num_val_examples_per_class"])
    metatrain_loss, metatrain_accuracy, metatrain_dacc, metatrain_hessians = _construct_loss_and_accuracy(
        leo, metatrain_batch, True)

    # Third element of grads_and_vars is discarded here.
    metatrain_gradients, metatrain_variables, _ = leo.grads_and_vars(
        metatrain_loss)

    # Avoids NaNs in summaries.
    metatrain_loss = tf.cond(tf.is_nan(metatrain_loss),
                             lambda: tf.zeros_like(metatrain_loss),
                             lambda: metatrain_loss)

    metatrain_gradients = _clip_gradients(
        metatrain_gradients, outer_model_config["gradient_threshold"],
        outer_model_config["gradient_norm_threshold"])

    _construct_training_summaries(metatrain_loss, metatrain_accuracy,
                                  metatrain_gradients, metatrain_variables,
                                  metatrain_hessians)
    optimizer = tf.train.AdamOptimizer(
        learning_rate=outer_model_config["outer_lr"])
    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.apply_gradients(
        list(zip(metatrain_gradients, metatrain_variables)), global_step)

    data_config = config.get_data_config()
    tf.logging.info("data_config: {}".format(data_config))
    total_examples_per_class = data_config["total_examples_per_class"]

    split = "val"
    metavalid_batch = _construct_examples_batch(
        outer_model_config["metavalid_batch_size"], split, num_classes,
        num_tr_examples_per_class,
        total_examples_per_class - num_tr_examples_per_class)
    metavalid_loss, metavalid_accuracy, metavalid_dacc, _ = _construct_loss_and_accuracy(
        leo, metavalid_batch, False)

    # Cross-domain testing uses a fixed 15 validation examples per class;
    # in-domain testing derives the count from the data config.
    if not FLAGS.cross:
        metatest_batch = _construct_examples_batch(
            outer_model_config["metatest_batch_size"],
            "test",
            num_classes,
            num_tr_examples_per_class,
            total_examples_per_class - num_tr_examples_per_class,
            use_cross=FLAGS.cross)
    else:
        metatest_batch = _construct_examples_batch(
            outer_model_config["metatest_batch_size"],
            "test",
            num_classes,
            num_tr_examples_per_class,
            15,
            use_cross=FLAGS.cross)

    _, metatest_accuracy, metatest_dacc, _ = _construct_loss_and_accuracy(
        leo, metatest_batch, False)
    _construct_validation_summaries(metavalid_loss, metavalid_accuracy)

    # Single-episode batch used to break down test performance; calling the
    # model with these flags yields per-episode hardness and correctness.
    break_down_batch = _construct_examples_batch(1, "test", num_classes,
                                                 num_tr_examples_per_class, 15)
    hardness, correct = leo(break_down_batch, False, True, True)
    return (train_op, global_step, metatrain_accuracy, metavalid_accuracy,
            metatest_accuracy, metatrain_dacc, metavalid_dacc, metatest_dacc,
            hardness, correct)
Exemplo n.º 9
0
def sp_construct_graph(lam, layers, outer_model_config, train_path, test_path,
                       sp_para):
    """Constructs the optimization graph.

    Self-paced variant: the training loss is regularized against a
    validation batch drawn from the test DB for the same held-out task.

    Args:
        lam: inner-loop parameter forwarded to model.LeastSquareMeta.
        layers: layer specification forwarded to model.LeastSquareMeta.
        outer_model_config: dict of outer-loop hyperparameters.
        train_path: path of the prebuilt training example DB.
        test_path: path of the prebuilt test example DB.
        sp_para: tuple (test_id, sp_bias, weights, k) configuring
            self-paced sampling.

    Returns:
        Tuple of (train_op, global_step, metatrain_loss, reg_loss,
        metavalid_accuracy, reset_optimizer_op).
    """
    inner_cfg = config.get_inner_model_config()
    tf.logging.info("inner_model_config: {}".format(inner_cfg))
    n_classes = outer_model_config["num_classes"]

    meta_learner = model.LeastSquareMeta(layers,
                                         lam,
                                         n_classes,
                                         limited=False,
                                         l2_weight=1e-6)

    test_data = data.DataProvider("test", config.get_data_config())
    test_data.load_db(test_path)

    tr_per_class = outer_model_config["num_tr_examples_per_class"]

    train_batch = _construct_examples_batch(
        outer_model_config["metatrain_batch_size"],
        "train",
        n_classes,
        tr_per_class,
        outer_model_config["num_val_examples_per_class"],
        train_path,
        sp_para=sp_para)

    data_cfg = config.get_data_config()
    tf.logging.info("data_config: {}".format(data_cfg))
    total_per_class = data_cfg["total_examples_per_class"]
    test_id, _sp_bias, _weights, _k = sp_para
    # The validation batch keeps the same held-out task id but disables
    # self-paced weighting.
    valid_batch = _construct_examples_batch(
        1,
        "test",
        n_classes,
        tr_per_class,
        total_per_class - tr_per_class,
        test_path,
        sp_para=(test_id, False, None, None))
    valid_loss, _, valid_accuracy = _construct_loss_and_accuracy(
        meta_learner, valid_batch, False)

    # The training loss is computed with the validation batch as the
    # regularization target.
    train_loss, reg_loss, train_accuracy = _construct_loss_and_accuracy(
        meta_learner, train_batch, True, test_batch=valid_batch)

    gradients, variables = meta_learner.grads_and_vars(train_loss)

    # Replace a NaN loss with zeros so summaries stay finite.
    train_loss = tf.cond(tf.is_nan(train_loss),
                         lambda: tf.zeros_like(train_loss),
                         lambda: train_loss)

    gradients = _clip_gradients(
        gradients, outer_model_config["gradient_threshold"],
        outer_model_config["gradient_norm_threshold"])

    _construct_training_summaries(train_loss, train_accuracy,
                                  gradients, variables)
    optimizer = tf.train.AdamOptimizer(
        learning_rate=outer_model_config["outer_lr"])

    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.apply_gradients(
        list(zip(gradients, variables)), global_step)

    # Allows callers to re-initialize Adam's slot variables between runs.
    reset_optimizer_op = tf.variables_initializer(optimizer.variables())

    return (train_op, global_step, train_loss, reg_loss,
            valid_accuracy, reset_optimizer_op)