示例#1
0
 def train(self, num_steps_limit, checkpoint_steps, checkpoint_path):
     """Train until `num_steps_limit`, evaluating every `checkpoint_steps`.

     At each global step that is a multiple of `checkpoint_steps`:
     - the meta-validation context/target metrics are obtained from
       `utils.evaluate_and_average(..., 10)` and printed;
     - when the target metric improves on the best value seen so far, the
       checkpoint is copied with `utils.copy_checkpoint` for early stopping.
       Higher is better for `eval_metric_type == 'acc'`; lower is better
       otherwise;
     - `self.visualise` is called with a step-specific save name.

     The step equal to `num_steps_limit` is still evaluated, but no training
     step is run for it.

     Args:
         num_steps_limit: final global step of the loop.
         checkpoint_steps: evaluation interval, in global steps.
         checkpoint_path: checkpoint location passed to
             `utils.copy_checkpoint`.
     """
     step = tf.compat.v1.train.global_step(self.sess, self.global_step)
     higher_is_better = self.eval_metric_type == 'acc'
     best_eval_metric = 0.0 if higher_is_better else 1.e16
     while step <= num_steps_limit:
         if step % checkpoint_steps == 0:
             context_metric, target_metric = utils.evaluate_and_average(
                 self.sess, [self.eval_tr_metric, self.eval_val_metric], 10)
             print("[  Step: {1} meta-valid context_{0}: {2:.5f}, "
                   "meta-valid target_{0}: {3:.5f}  ]".format(
                       self.eval_metric_type, step, context_metric,
                       target_metric))
             if higher_is_better:
                 improved = target_metric > best_eval_metric
             else:
                 improved = target_metric < best_eval_metric
             if improved:
                 # Keep a copy of the best checkpoint for early stopping.
                 utils.copy_checkpoint(
                     checkpoint_path, step, target_metric,
                     eval_metric_type=self.eval_metric_type)
                 best_eval_metric = target_metric
             self.visualise(save_name="{0}-{1}".format(self.name, step))
         if step == num_steps_limit:
             break
         # One training step; the fetched train metrics are not used here.
         self.sess.run([self.train_op, self.train_tr_metric,
                        self.train_val_metric])
         step = tf.compat.v1.train.global_step(self.sess, self.global_step)
示例#2
0
def run_training_loop(checkpoint_path, a, b, layers):
    """Runs the training loop, either saving a checkpoint or evaluating it.

    Training mode (``FLAGS.evaluation_mode`` falsy) runs ``train_op`` until
    the global step reaches ``num_steps_limit`` from the outer model config.
    Whenever the global step is a multiple of ``FLAGS.checkpoint_steps``:
    - the meta-validation accuracy is computed with
      ``utils.evaluate_and_average(..., 10)``;
    - if it beats the best value seen so far, the checkpoint is copied via
      ``utils.copy_checkpoint`` for early stopping.

    Evaluation mode estimates meta-test accuracy using
    ``10000 // metatest_batch_size`` evaluations and returns it.

    Args:
        checkpoint_path: checkpoint directory for the MonitoredTrainingSession
            and for ``utils.copy_checkpoint``.
        a: forwarded unchanged to ``construct_graph``.
        b: forwarded unchanged to ``construct_graph``.
        layers: forwarded unchanged to ``construct_graph``.

    Returns:
        In evaluation mode, the value returned by
        ``utils.evaluate_and_average(..., has_std=True)``. In training mode,
        ``None``.
    """
    outer_model_config = config.get_outer_model_config()
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("outer_model_config: {}".format(outer_model_config))

    (train_op, global_step, metatrain_accuracy, metavalid_accuracy,
     metatest_accuracy) = construct_graph(outer_model_config, a, b, layers)

    num_steps_limit = outer_model_config["num_steps_limit"]
    best_metavalid_accuracy = 0.

    tf_config = tf.ConfigProto()
    # Let TensorFlow grow GPU memory on demand rather than reserving it all.
    tf_config.gpu_options.allow_growth = True

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=checkpoint_path,
            save_summaries_steps=FLAGS.checkpoint_steps,
            log_step_count_steps=FLAGS.checkpoint_steps,
            save_checkpoint_steps=FLAGS.checkpoint_steps,
            config=tf_config) as sess:
        if not FLAGS.evaluation_mode:
            # Log meta-train accuracy twice per checkpoint interval. Clamp the
            # interval to 1 so that checkpoint_steps == 1 does not turn the
            # modulo below into a division by zero.
            train_log_steps = max(FLAGS.checkpoint_steps // 2, 1)
            global_step_ev = sess.run(global_step)
            while global_step_ev < num_steps_limit:
                if global_step_ev % FLAGS.checkpoint_steps == 0:
                    # Just after saving checkpoint, calculate accuracy 10 times and save
                    # the best checkpoint for early stopping.
                    metavalid_accuracy_ev = utils.evaluate_and_average(
                        sess, metavalid_accuracy, 10)
                    tf.logging.info("Step: {} meta-valid accuracy: {}".format(
                        global_step_ev, metavalid_accuracy_ev))

                    if metavalid_accuracy_ev > best_metavalid_accuracy:
                        utils.copy_checkpoint(checkpoint_path, global_step_ev,
                                              metavalid_accuracy_ev)
                        best_metavalid_accuracy = metavalid_accuracy_ev

                _, global_step_ev, metatrain_accuracy_ev = sess.run(
                    [train_op, global_step, metatrain_accuracy])
                if global_step_ev % train_log_steps == 0:
                    tf.logging.info("Step: {} meta-train accuracy: {}".format(
                        global_step_ev, metatrain_accuracy_ev))
        else:
            # Evaluation mode expects checkpoint saving to be disabled.
            assert not FLAGS.checkpoint_steps
            num_metatest_estimates = (
                10000 // outer_model_config["metatest_batch_size"])

            test_accuracy = utils.evaluate_and_average(sess,
                                                       metatest_accuracy,
                                                       num_metatest_estimates,
                                                       has_std=True)
            return test_accuracy
示例#3
0
def run_training_loop(checkpoint_path):
    """Runs the training loop, either saving a checkpoint or evaluating it.

    Training mode (``FLAGS.evaluation_mode`` falsy) runs ``train_op`` until
    the tracked step reaches ``num_steps_limit`` from the outer model config.
    Whenever that step is a multiple of ``FLAGS.checkpoint_steps``:
    - the meta-validation accuracy is computed with
      ``utils.evaluate_and_average(..., 10)``;
    - if it beats the best value seen so far, the checkpoint is copied via
      ``utils.copy_checkpoint`` for early stopping.

    Evaluation mode estimates meta-test accuracy using
    ``10000 // metatest_batch_size`` evaluations, logs it, and pickles it to
    ``<checkpoint_path>/test_accuracy``.

    Args:
        checkpoint_path: checkpoint/summary directory for the
            MonitoredTrainingSession, for ``utils.copy_checkpoint``, and for
            the pickled test accuracy.
    """
    outer_model_config = config.get_outer_model_config()
    tf.logging.info("outer_model_config: {}".format(outer_model_config))
    (train_op, global_step, metatrain_accuracy, metavalid_accuracy,
     metatest_accuracy) = construct_graph(outer_model_config)

    num_steps_limit = outer_model_config["num_steps_limit"]
    best_metavalid_accuracy = 0.

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=checkpoint_path,
            save_summaries_steps=FLAGS.checkpoint_steps,
            log_step_count_steps=FLAGS.checkpoint_steps,
            save_checkpoint_steps=FLAGS.checkpoint_steps,
            summary_dir=checkpoint_path) as sess:
        if not FLAGS.evaluation_mode:
            # Log meta-train accuracy twice per checkpoint interval. Clamp the
            # interval to 1 so that checkpoint_steps == 1 does not turn the
            # modulo below into a division by zero.
            train_log_steps = max(FLAGS.checkpoint_steps // 2, 1)
            global_step_ev = sess.run(global_step)
            while global_step_ev < num_steps_limit:
                if global_step_ev % FLAGS.checkpoint_steps == 0:
                    # Just after saving checkpoint, calculate accuracy 10 times and save
                    # the best checkpoint for early stopping.
                    metavalid_accuracy_ev = utils.evaluate_and_average(
                        sess, metavalid_accuracy, 10)
                    tf.logging.info("Step: {} meta-valid accuracy: {}".format(
                        global_step_ev, metavalid_accuracy_ev))

                    if metavalid_accuracy_ev > best_metavalid_accuracy:
                        utils.copy_checkpoint(checkpoint_path, global_step_ev,
                                              metavalid_accuracy_ev)
                        best_metavalid_accuracy = metavalid_accuracy_ev

                _, global_step_ev, metatrain_accuracy_ev = sess.run(
                    [train_op, global_step, metatrain_accuracy])

                if global_step_ev % train_log_steps == 0:
                    tf.logging.info("Step: {} meta-train accuracy: {}".format(
                        global_step_ev, metatrain_accuracy_ev))
                # NOTE(review): global_step_ev was just fetched together with
                # train_op. This extra increment shifts the evaluation and
                # loop-exit checks one step past the fetched value. Kept
                # as-is; confirm whether the fetched step is pre- or
                # post-increment.
                global_step_ev += 1
        else:
            num_metatest_estimates = (
                10000 // outer_model_config["metatest_batch_size"])
            test_accuracy = utils.evaluate_and_average(sess, metatest_accuracy,
                                                       num_metatest_estimates)

            tf.logging.info("Metatest accuracy: %f", test_accuracy)
            with tf.gfile.Open(os.path.join(checkpoint_path, "test_accuracy"),
                               "wb") as f:
                pickle.dump(test_accuracy, f)
示例#4
0
def _accumulate_latent_changes(latents_hist, latents_ev, spurious_ev):
  """Appends per-component latent changes to `latents_hist`, grouped by class.

  The change is `latents_ev - spurious_ev`. For each component, the class id
  is read from field [0] of the matching entry in `latents_ev`, and field [1]
  of the change is recorded.

  Args:
    latents_hist: nested list indexed as [class][component] -> list of
      recorded values. Only class ids in ``range(len(latents_hist))`` are
      recorded; others are skipped.
    latents_ev: evaluated latents. Assumed to be an array indexable as
      [batch][k][j][component][field]; TODO confirm against construct_graph.
    spurious_ev: evaluated array subtracted elementwise from `latents_ev`.
  """
  num_classes = len(latents_hist)
  for batch_change, batch_latents in zip(latents_ev - spurious_ev, latents_ev):
    for k, c in enumerate(batch_change):
      for j, components in enumerate(c):
        for i, component in enumerate(components):
          cl = int(batch_latents[k][j][i][0])
          # Only class ids 0..num_classes-1 have a history slot. An
          # inclusive `<= num_classes` bound would index one past the end.
          if 0 <= cl < num_classes:  # collect data for sampled classes
            latents_hist[cl][i].append(component[1])


def _plot_latent_history(latents_hist):
  """Plots the latent-change history: one subplot per class, one line per component."""
  num_classes = len(latents_hist)
  # squeeze=False always yields a 2-D array of axes, even for a single class.
  fig, axes = plt.subplots(num_classes, sharex='col', figsize=(20, 20),
                           squeeze=False)
  axes = axes[:, 0]
  for i, class_hist in enumerate(latents_hist):
    colors = cm.rainbow(np.linspace(0, 1, len(class_hist)))
    for component_hist, c in zip(class_hist, colors):
      axes[i].plot(range(0, len(component_hist)), component_hist, c=c)
    axes[i].set_title('class=' + str(i) + ' Change in Latents')
  axes[-1].set_xlabel('step')
  plt.show()
  # Close the figure explicitly. Otherwise a new figure stays open after
  # every checkpoint with non-interactive backends.
  plt.close(fig)


def run_training_loop(checkpoint_path):
  """Runs the training loop, either saving a checkpoint or evaluating it.

  Training mode (``FLAGS.evaluation_mode`` falsy) runs ``train_op`` until the
  global step reaches ``num_steps_limit``. Whenever the global step is a
  multiple of ``FLAGS.checkpoint_steps``, the loop:
  - evaluates meta-validation accuracy (a single evaluation);
  - evaluates ``latents``, ``adapted_latents`` and ``spurious``;
  - appends the latent changes to a per-class history and plots it;
  - copies the checkpoint via ``utils.copy_checkpoint`` if the accuracy is
    the best so far.

  Evaluation mode estimates meta-test accuracy using
  ``10000 // metatest_batch_size`` evaluations, logs it, and pickles it to
  ``<checkpoint_path>/test_accuracy``.

  Args:
    checkpoint_path: checkpoint/summary directory for the
      MonitoredTrainingSession, for ``utils.copy_checkpoint``, and for the
      pickled test accuracy.
  """
  outer_model_config = config.get_outer_model_config()
  tf.logging.info("outer_model_config: {}".format(outer_model_config))
  # The KL tensors are unpacked to match construct_graph's return value but
  # are not evaluated in this loop.
  (train_op, global_step, metatrain_accuracy, metavalid_accuracy,
   metatest_accuracy, kl_components, adapted_kl_components, kl_zn,
   adapted_kl_zn, kl, adapted_kl, latents, adapted_latents,
   spurious) = construct_graph(outer_model_config)

  num_steps_limit = outer_model_config["num_steps_limit"]
  best_metavalid_accuracy = 0.

  # Latent-change history used for plotting: [class][component] -> values,
  # sized for 5 classes x 64 latent components.
  latents_hist = [[[] for _ in range(64)] for _ in range(5)]

  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=checkpoint_path,
      save_summaries_steps=FLAGS.checkpoint_steps,
      log_step_count_steps=FLAGS.checkpoint_steps,
      save_checkpoint_steps=FLAGS.checkpoint_steps,
      summary_dir=checkpoint_path) as sess:
    if not FLAGS.evaluation_mode:
      # Log meta-train accuracy twice per checkpoint interval. Clamp the
      # interval to 1 so that checkpoint_steps == 1 does not turn the modulo
      # below into a division by zero.
      train_log_steps = max(FLAGS.checkpoint_steps // 2, 1)
      global_step_ev = sess.run(global_step)
      while global_step_ev < num_steps_limit:
        if global_step_ev % FLAGS.checkpoint_steps == 0:
          # Just after saving a checkpoint, evaluate meta-validation accuracy
          # once and keep the best checkpoint for early stopping.
          metavalid_accuracy_ev = utils.evaluate_and_average(
              sess, metavalid_accuracy, 1)  # runs the session for validation

          latents_ev = utils.evaluate(sess, latents)
          # NOTE: adapted_latents_ev is evaluated but currently unused below.
          adapted_latents_ev = utils.evaluate(sess, adapted_latents)
          spurious_ev = utils.evaluate(sess, spurious)

          _accumulate_latent_changes(latents_hist, latents_ev, spurious_ev)
          _plot_latent_history(latents_hist)

          tf.logging.info("Step: {} meta-valid accuracy: {}".format(
              global_step_ev, metavalid_accuracy_ev))

          if metavalid_accuracy_ev > best_metavalid_accuracy:
            utils.copy_checkpoint(checkpoint_path, global_step_ev,
                                  metavalid_accuracy_ev)
            best_metavalid_accuracy = metavalid_accuracy_ev

        _, global_step_ev, metatrain_accuracy_ev = sess.run(
            [train_op, global_step, metatrain_accuracy])  # runs the session for training

        if global_step_ev % train_log_steps == 0:
          tf.logging.info("Step: {} meta-train accuracy: {}".format(
              global_step_ev, metatrain_accuracy_ev))
    else:
      # Evaluation mode expects checkpoint saving to be disabled.
      assert not FLAGS.checkpoint_steps
      num_metatest_estimates = (
          10000 // outer_model_config["metatest_batch_size"])

      test_accuracy = utils.evaluate_and_average(sess, metatest_accuracy,
                                                 num_metatest_estimates)  # runs the session for testing

      tf.logging.info("Metatest accuracy: %f", test_accuracy)
      with tf.gfile.Open(
          os.path.join(checkpoint_path, "test_accuracy"), "wb") as f:
        pickle.dump(test_accuracy, f)
示例#5
0
def run_training_loop(checkpoint_path):
    """Runs the training loop, either saving a checkpoint or evaluating it.

    Training mode (``FLAGS.evaluation_mode`` falsy) runs ``train_op`` until
    the global step reaches ``num_steps_limit``. Whenever the global step is
    a multiple of ``FLAGS.checkpoint_steps``:
    - meta-validation accuracy and dacc are computed with
      ``utils.evaluate_and_average_acc_dacc(..., 10)``;
    - the checkpoint is copied via ``utils.copy_checkpoint`` when accuracy
      improves;
    - the best dacc is tracked for logging.

    Evaluation mode has two variants:
    - ``FLAGS.hacc`` falsy: estimates meta-test accuracy using
      ``2000 // metatest_batch_size`` evaluations, logs it, and pickles it
      to ``<checkpoint_path>/test_accuracy``.
    - ``FLAGS.hacc`` truthy: collects ``hardness``/``correct`` over 2000
      session runs, prints the mean of ``correct``, and pickles both arrays
      to ``hacc/<FLAGS.config>``.

    Args:
        checkpoint_path: checkpoint/summary directory for the
            MonitoredTrainingSession, for ``utils.copy_checkpoint``, and for
            the pickled test accuracy.
    """
    outer_model_config = config.get_outer_model_config()
    tf.logging.info("outer_model_config: {}".format(outer_model_config))
    (train_op, global_step, metatrain_accuracy, metavalid_accuracy,
     metatest_accuracy, metatrain_dacc, metavalid_dacc, metatest_dacc,
     hardness, correct) = construct_graph(outer_model_config)

    num_steps_limit = outer_model_config["num_steps_limit"]
    best_metavalid_accuracy = 0.
    best_metavalid_dacc = 0.
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=checkpoint_path,
            save_summaries_steps=FLAGS.checkpoint_steps,
            log_step_count_steps=FLAGS.checkpoint_steps,
            save_checkpoint_steps=FLAGS.checkpoint_steps,
            summary_dir=checkpoint_path) as sess:
        if not FLAGS.evaluation_mode:
            # Log meta-train accuracy twice per checkpoint interval. Clamp the
            # interval to 1 so that checkpoint_steps == 1 does not turn the
            # modulo below into a division by zero.
            train_log_steps = max(FLAGS.checkpoint_steps // 2, 1)
            global_step_ev = sess.run(global_step)
            while global_step_ev < num_steps_limit:
                if global_step_ev % FLAGS.checkpoint_steps == 0:
                    # Just after saving checkpoint, calculate accuracy 10 times and save
                    # the best checkpoint for early stopping.
                    metavalid_accuracy_ev, metavalid_dacc_ev = utils.evaluate_and_average_acc_dacc(
                        sess, metavalid_accuracy, metavalid_dacc, 10)
                    tf.logging.info(
                        "Step: {} meta-valid accuracy: {}, dacc: {} best acc: {} best dacc: {}"
                        .format(global_step_ev, metavalid_accuracy_ev,
                                metavalid_dacc_ev, best_metavalid_accuracy,
                                best_metavalid_dacc))

                    if metavalid_accuracy_ev > best_metavalid_accuracy:
                        utils.copy_checkpoint(checkpoint_path, global_step_ev,
                                              metavalid_accuracy_ev)
                        best_metavalid_accuracy = metavalid_accuracy_ev
                    if metavalid_dacc_ev > best_metavalid_dacc:
                        best_metavalid_dacc = metavalid_dacc_ev
                _, global_step_ev, metatrain_accuracy_ev = sess.run(
                    [train_op, global_step, metatrain_accuracy])
                if global_step_ev % train_log_steps == 0:
                    tf.logging.info("Step: {} meta-train accuracy: {}".format(
                        global_step_ev, metatrain_accuracy_ev))
        else:
            if not FLAGS.hacc:
                # Evaluation mode expects checkpoint saving to be disabled.
                assert not FLAGS.checkpoint_steps
                num_metatest_estimates = (
                    2000 // outer_model_config["metatest_batch_size"])
                # Not changed to dacc yet
                test_accuracy = utils.evaluate_and_average(
                    sess, metatest_accuracy, num_metatest_estimates)

                tf.logging.info("Metatest accuracy: %f", test_accuracy)
                with tf.gfile.Open(
                        os.path.join(checkpoint_path, "test_accuracy"),
                        "wb") as f:
                    pickle.dump(test_accuracy, f)
            else:
                all_hardness = []
                all_correct = []
                for _ in range(2000):
                    hardness_ev, correct_ev = sess.run([hardness, correct])
                    # Keep the diagonal slices hardness_ev[k, :, k] for
                    # k in 0..4. Presumably one per class of a 5-way task;
                    # TODO confirm against construct_graph.
                    hardness_ev = [hardness_ev[k, :, k] for k in range(5)]
                    hardness_ev = np.array(hardness_ev).flatten()
                    correct_ev = np.array(correct_ev).flatten()
                    all_hardness.append(hardness_ev)
                    all_correct.append(correct_ev)
                all_hardness = np.array(all_hardness).flatten()
                all_correct = np.array(all_correct).flatten()
                save_file = {"hardness": all_hardness, "correct": all_correct}
                print(all_correct.sum() / len(all_correct))
                # Context manager guarantees the pickle is flushed and the
                # file handle is closed.
                with open("hacc/" + FLAGS.config, "wb") as f:
                    pickle.dump(save_file, f)