Example #1
def evaluate_helper(eval_size, suffix):
    """Evaluates the model on train / valid data on and off-policy.

    Args:
      eval_size (int): the number of examples to evaluate on.
      suffix (str): appended to all logging and tensorboard paths.
    """
    evaluate(policy_model, oracle_valid_data[-eval_size:],
             step, "off_policy_valid" + suffix, tb_writer,
             predictions_dir)
    # train_data is defined in the loop, but evaluate_helper is only
    # called in the same loop iteration.
    # pylint: disable=cell-var-from-loop
    evaluate(policy_model, train_data[-eval_size:], step,
             "train" + suffix, tb_writer, predictions_dir)
    # pylint: enable=cell-var-from-loop

    # Log the cache hit rates on portions of train / valid
    _, hit_rates = next(
        measure_cache_hit_rate(
            FLAGS.train_memtrace,
            cache_config,
            policy_model,
            schedules.ConstantSchedule(1),
            get_step,
            os.path.join(
                evict_trace_dir,
                "train{}-{}.txt".format(suffix, step)),
            max_examples=eval_size,
            use_oracle_scores=False))
    log_hit_rates(tb_writer,
                  "cache_hit_rate/train" + suffix,
                  hit_rates, step)

    # Use oracle scores, since eviction trace in log_evaluate_stats will
    # log with on-policy scores.
    on_policy_valid_data, hit_rates = next(
        measure_cache_hit_rate(
            FLAGS.valid_memtrace,
            cache_config,
            policy_model,
            schedules.ConstantSchedule(1),
            get_step,
            os.path.join(
                evict_trace_dir,
                "valid{}-{}.txt".format(suffix, step)),
            max_examples=eval_size))
    log_hit_rates(tb_writer,
                  "cache_hit_rate/valid" + suffix,
                  hit_rates, step)
    evaluate(policy_model,
             on_policy_valid_data[-eval_size:], step,
             "on_policy_valid" + suffix, tb_writer,
             predictions_dir)
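The helper closes over names from the enclosing training loop (policy_model, train_data, oracle_valid_data, tb_writer, and the various FLAGS), so it is not runnable in isolation. In the full training loop (Example #3), it is called on two cadences, a cheap evaluation over a small slice and a periodic full pass over the validation data:

if step % FLAGS.small_eval_freq == 0:
    evaluate_helper(FLAGS.small_eval_size, "")

if step % FLAGS.full_eval_freq == 0:
    evaluate_helper(len(oracle_valid_data), "_full")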
Example #2
def schedule_from_config(config):
  if config.get("type") == "linear":
    return schedules.LinearSchedule(
        config.get("num_steps"), config.get("final"), config.get("initial"))
  elif config.get("type") == "constant":
    return schedules.ConstantSchedule(config.get("value"))
  else:
    raise ValueError("Unsupported schedule type: {}".format(config.get("type")))
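A minimal usage sketch of schedule_from_config; the values below are hypothetical, and a plain dict stands in for the cfg.Config object since both expose .get():

# Hypothetical DAgger schedule settings: ramp the mixture parameter from 0 to 1
# over 100k steps.
linear_config = {"type": "linear", "num_steps": 100000, "final": 1.0, "initial": 0.0}
dagger_schedule = schedule_from_config(linear_config)

# Schedule objects expose .value(step), as used in the training loop (Example #3).
mixture = dagger_schedule.value(50000)

constant_config = {"type": "constant", "value": 0.5}
fixed_schedule = schedule_from_config(constant_config)

# Any other type raises ValueError:
# schedule_from_config({"type": "exponential"})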
Example #3
def main(_):
    logging.info("Seed: %d", FLAGS.seed)
    np.random.seed(FLAGS.seed)
    torch.random.manual_seed(FLAGS.seed)

    if FLAGS.save_freq % FLAGS.small_eval_freq != 0:
        raise ValueError((
            "Save frequency ({}) must be a multiple of evaluation frequency ({})."
            " Allows choosing checkpoints based on their evaluation scores."
        ).format(FLAGS.save_freq, FLAGS.small_eval_freq))

    if FLAGS.full_eval_freq % FLAGS.small_eval_freq != 0:
        raise ValueError(
            ("Full evaluation frequency ({}) must be a multiple of small"
             " evaluation frequency ({}) so that their values can be compared."
             ).format(FLAGS.full_eval_freq, FLAGS.small_eval_freq))

    exp_dir = os.path.join(FLAGS.experiment_base_dir, FLAGS.experiment_name)
    common_utils.create_experiment_directory(exp_dir, FLAGS.force_overwrite)
    tensorboard_dir = os.path.join(exp_dir, "tensorboard")
    tf.disable_eager_execution()
    tb_writer = tf.summary.FileWriter(tensorboard_dir)

    predictions_dir = os.path.join(exp_dir, "predictions")
    os.makedirs(predictions_dir, exist_ok=True)

    checkpoints_dir = os.path.join(exp_dir, "checkpoints")
    os.makedirs(checkpoints_dir, exist_ok=True)

    evict_trace_dir = os.path.join(exp_dir, "evictions")
    os.makedirs(evict_trace_dir, exist_ok=True)

    model_config = cfg.Config.from_files_and_bindings(FLAGS.model_configs,
                                                      FLAGS.model_bindings)
    logging.info("Model config: %s", model_config)
    with open(os.path.join(exp_dir, "model_config.json"), "w") as f:
        model_config.to_file(f)

    cache_config = cfg.Config.from_files_and_bindings(FLAGS.cache_configs,
                                                      FLAGS.cache_bindings)
    logging.info("Cache config: %s", cache_config)
    with open(os.path.join(exp_dir, "cache_config.json"), "w") as f:
        cache_config.to_file(f)

    dagger_schedule_config = cfg.Config.from_files_and_bindings(
        FLAGS.dagger_schedule_configs, FLAGS.dagger_schedule_bindings)
    logging.info("DAgger config: %s", dagger_schedule_config)
    with open(os.path.join(exp_dir, "dagger_config.json"), "w") as f:
        dagger_schedule_config.to_file(f)
    dagger_schedule = schedule_from_config(dagger_schedule_config)
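    # The schedule's value at the current step is the DAgger mixture parameter
    # logged in the training loop below: 0 means data is collected purely under
    # the oracle policy, 1 purely under the learned policy (cf. the
    # ConstantSchedule(0) / ConstantSchedule(1) calls further down).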

    # Process everything on GPU if available
    device = torch.device("cpu")
    if torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        device = torch.device("cuda:0")
    logging.info("Device: %s", device)

    policy_model = model.EvictionPolicyModel.from_config(model_config).to(
        device)
    optimizer = optim.Adam(policy_model.parameters(),
                           lr=model_config.get("lr"))

    step = 0
    get_step = lambda: step
    oracle_valid_data, hit_rates = next(
        measure_cache_hit_rate(
            FLAGS.valid_memtrace, cache_config, policy_model,
            schedules.ConstantSchedule(0), get_step,
            os.path.join(evict_trace_dir, "oracle_valid.txt")))
    log_hit_rates(tb_writer, "cache_hit_rate/oracle_valid", hit_rates, step)

    with tqdm.tqdm(total=FLAGS.total_steps) as pbar:
        while True:  # loop for waiting until steps == FLAGS.total_steps
            # Optimization: Instead of passing through the whole memory trace for
            # training and only using update_freq many of them, we lazily gather k *
            # update_freq batches and still train on a subsample of update_freq.
            # The value of k=collection_multiplier trades off between:
            #   - The set of k * update_freq examples are all consecutive in the
            #   memory trace. As k gets small, the set of these examples becomes less
            #   i.i.d., as they are temporally correlated. The examples cannot be
            #   random access within the memory trace, since at time t, we require the
            #   previous cache accesses to compute the cache state at time t.
            #   - As k gets large, training becomes slower, as we must perform k times
            #   as much collecting work than training work.
            max_examples = (dagger_schedule_config.get("update_freq") *
                            FLAGS.collection_multiplier * FLAGS.batch_size)
            train_data_generator = measure_cache_hit_rate(
                FLAGS.train_memtrace,
                cache_config,
                policy_model,
                dagger_schedule,
                get_step,
                os.path.join(evict_trace_dir, "mixture-train-{}.txt"),
                max_examples=max_examples)
            for train_data, hit_rates in train_data_generator:
                log_hit_rates(tb_writer, "cache_hit_rate/train_mixture_policy",
                              hit_rates, step)
                utils.log_scalar(tb_writer, "cache_hit_rate/mixture_parameter",
                                 dagger_schedule.value(step), step)

                for batch_num, batch in enumerate(
                        utils.as_batches([train_data], FLAGS.batch_size,
                                         model_config.get("sequence_length"))):

                    def evaluate_helper(eval_size, suffix):
                        """Evaluates the model on train / valid data on and off-policy.

                        Args:
                          eval_size (int): the number of examples to evaluate on.
                          suffix (str): appended to all logging and tensorboard paths.
                        """
                        evaluate(policy_model, oracle_valid_data[-eval_size:],
                                 step, "off_policy_valid" + suffix, tb_writer,
                                 predictions_dir)
                        # train_data is defined in the loop, but evaluate_helper is only
                        # called in the same loop iteration.
                        # pylint: disable=cell-var-from-loop
                        evaluate(policy_model, train_data[-eval_size:], step,
                                 "train" + suffix, tb_writer, predictions_dir)
                        # pylint: enable=cell-var-from-loop

                        # Log the cache hit rates on portions of train / valid
                        _, hit_rates = next(
                            measure_cache_hit_rate(
                                FLAGS.train_memtrace,
                                cache_config,
                                policy_model,
                                schedules.ConstantSchedule(1),
                                get_step,
                                os.path.join(
                                    evict_trace_dir,
                                    "train{}-{}.txt".format(suffix, step)),
                                max_examples=eval_size,
                                use_oracle_scores=False))
                        log_hit_rates(tb_writer,
                                      "cache_hit_rate/train" + suffix,
                                      hit_rates, step)

                        # Use oracle scores, since eviction trace in log_evaluate_stats will
                        # log with on-policy scores.
                        on_policy_valid_data, hit_rates = next(
                            measure_cache_hit_rate(
                                FLAGS.valid_memtrace,
                                cache_config,
                                policy_model,
                                schedules.ConstantSchedule(1),
                                get_step,
                                os.path.join(
                                    evict_trace_dir,
                                    "valid{}-{}.txt".format(suffix, step)),
                                max_examples=eval_size))
                        log_hit_rates(tb_writer,
                                      "cache_hit_rate/valid" + suffix,
                                      hit_rates, step)
                        evaluate(policy_model,
                                 on_policy_valid_data[-eval_size:], step,
                                 "on_policy_valid" + suffix, tb_writer,
                                 predictions_dir)

                    if step % FLAGS.small_eval_freq == 0:
                        evaluate_helper(FLAGS.small_eval_size, "")

                    if step % FLAGS.full_eval_freq == 0:
                        evaluate_helper(len(oracle_valid_data), "_full")

                    if step % FLAGS.save_freq == 0 and step != 0:
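                        # Serialize the checkpoint to an in-memory buffer
                        # first, then write it to disk in a single call.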
                        save_path = os.path.join(checkpoints_dir,
                                                 "{}.ckpt".format(step))
                        with open(save_path, "wb") as save_file:
                            checkpoint_buffer = io.BytesIO()
                            torch.save(policy_model.state_dict(),
                                       checkpoint_buffer)
                            logging.info("Saving model checkpoint to: %s",
                                         save_path)
                            save_file.write(checkpoint_buffer.getvalue())

                    optimizer.zero_grad()
                    losses = policy_model.loss(
                        batch,
                        model_config.get("sequence_length") // 2)
                    total_loss = sum(losses.values())
                    total_loss.backward()
                    optimizer.step()
                    pbar.update(1)
                    step += 1

                    if step % FLAGS.tb_freq == 0:
                        utils.log_scalar(tb_writer, "loss/total", total_loss,
                                         step)
                        for loss_name, loss_value in losses.items():
                            utils.log_scalar(tb_writer,
                                             "loss/{}".format(loss_name),
                                             loss_value, step)

                    if step == FLAGS.total_steps:
                        return

                    # Break out of inner-loop to get next set of k * update_freq batches
                    if batch_num == dagger_schedule_config.get("update_freq"):
                        break