def schedule_from_config(config):
  """Constructs a linear or constant schedule from a schedule config."""
  if config.get("type") == "linear":
    return schedules.LinearSchedule(
        config.get("num_steps"), config.get("final"), config.get("initial"))
  elif config.get("type") == "constant":
    return schedules.ConstantSchedule(config.get("value"))
  else:
    raise ValueError(
        "Unsupported schedule type: {}".format(config.get("type")))
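
# For reference, a minimal sketch of the schedule interface that
# schedule_from_config constructs and the training loop below consumes: a
# constructor taking the arguments used above, and a value(step) method. The
# real implementations live in the schedules module; these underscore-prefixed
# classes are illustrative only and are not used anywhere in this file. An
# illustrative config for schedule_from_config might look like
# {"type": "linear", "num_steps": 100000, "initial": 1.0, "final": 0.0} or
# {"type": "constant", "value": 0.0} (all values hypothetical).
class _ConstantScheduleSketch(object):
  """Returns the same value at every step."""

  def __init__(self, value):
    self._value = value

  def value(self, step):
    del step  # Unused: the schedule is constant.
    return self._value


class _LinearScheduleSketch(object):
  """Linearly anneals from initial to final over num_steps steps."""

  def __init__(self, num_steps, final, initial):
    self._num_steps = num_steps
    self._final = final
    self._initial = initial

  def value(self, step):
    # Clamp so that the schedule holds at final after num_steps.
    frac = min(float(step) / self._num_steps, 1.0)
    return self._initial + frac * (self._final - self._initial)
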
def main(_):
  logging.info("Seed: %d", FLAGS.seed)
  np.random.seed(FLAGS.seed)
  torch.random.manual_seed(FLAGS.seed)

  if FLAGS.save_freq % FLAGS.small_eval_freq != 0:
    raise ValueError(
        ("Save frequency ({}) must be a multiple of evaluation frequency"
         " ({}). This allows choosing checkpoints based on their evaluation"
         " scores.").format(FLAGS.save_freq, FLAGS.small_eval_freq))

  if FLAGS.full_eval_freq % FLAGS.small_eval_freq != 0:
    raise ValueError(
        ("Full evaluation frequency ({}) must be a multiple of small"
         " evaluation frequency ({}) so that their values can be compared."
        ).format(FLAGS.full_eval_freq, FLAGS.small_eval_freq))

  exp_dir = os.path.join(FLAGS.experiment_base_dir, FLAGS.experiment_name)
  common_utils.create_experiment_directory(exp_dir, FLAGS.force_overwrite)

  tensorboard_dir = os.path.join(exp_dir, "tensorboard")
  tf.disable_eager_execution()
  tb_writer = tf.summary.FileWriter(tensorboard_dir)

  predictions_dir = os.path.join(exp_dir, "predictions")
  os.makedirs(predictions_dir, exist_ok=True)
  checkpoints_dir = os.path.join(exp_dir, "checkpoints")
  os.makedirs(checkpoints_dir, exist_ok=True)
  evict_trace_dir = os.path.join(exp_dir, "evictions")
  os.makedirs(evict_trace_dir, exist_ok=True)

  model_config = cfg.Config.from_files_and_bindings(
      FLAGS.model_configs, FLAGS.model_bindings)
  logging.info("Model config: %s", model_config)
  with open(os.path.join(exp_dir, "model_config.json"), "w") as f:
    model_config.to_file(f)

  cache_config = cfg.Config.from_files_and_bindings(
      FLAGS.cache_configs, FLAGS.cache_bindings)
  logging.info("Cache config: %s", cache_config)
  with open(os.path.join(exp_dir, "cache_config.json"), "w") as f:
    cache_config.to_file(f)

  dagger_schedule_config = cfg.Config.from_files_and_bindings(
      FLAGS.dagger_schedule_configs, FLAGS.dagger_schedule_bindings)
  logging.info("DAgger config: %s", dagger_schedule_config)
  with open(os.path.join(exp_dir, "dagger_config.json"), "w") as f:
    dagger_schedule_config.to_file(f)
  dagger_schedule = schedule_from_config(dagger_schedule_config)

  # Process everything on GPU if available.
  device = torch.device("cpu")
  if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    device = torch.device("cuda:0")
  logging.info("Device: %s", device)

  policy_model = model.EvictionPolicyModel.from_config(model_config).to(device)
  optimizer = optim.Adam(policy_model.parameters(), lr=model_config.get("lr"))

  step = 0
  get_step = lambda: step
  oracle_valid_data, hit_rates = next(measure_cache_hit_rate(
      FLAGS.valid_memtrace, cache_config, policy_model,
      schedules.ConstantSchedule(0), get_step,
      os.path.join(evict_trace_dir, "oracle_valid.txt")))
  log_hit_rates(tb_writer, "cache_hit_rate/oracle_valid", hit_rates, step)

  with tqdm.tqdm(total=FLAGS.total_steps) as pbar:
    while True:  # Loop until step == FLAGS.total_steps.
      # Optimization: Instead of passing through the whole memory trace for
      # training and only using update_freq batches of it, we lazily gather
      # k * update_freq batches and still train on a subsample of update_freq
      # batches. The value of k = collection_multiplier trades off between:
      # - The set of k * update_freq examples is consecutive in the memory
      #   trace. As k gets small, these examples become less i.i.d., since
      #   they are temporally correlated. The examples cannot be randomly
      #   accessed within the memory trace, because computing the cache state
      #   at time t requires all previous cache accesses.
      # - As k gets large, training becomes slower, since we must perform k
      #   times as much collecting work as training work.
      max_examples = (dagger_schedule_config.get("update_freq") *
                      FLAGS.collection_multiplier * FLAGS.batch_size)
      train_data_generator = measure_cache_hit_rate(
          FLAGS.train_memtrace, cache_config, policy_model, dagger_schedule,
          get_step, os.path.join(evict_trace_dir, "mixture-train-{}.txt"),
          max_examples=max_examples)
      for train_data, hit_rates in train_data_generator:
        log_hit_rates(
            tb_writer, "cache_hit_rate/train_mixture_policy", hit_rates, step)
        utils.log_scalar(
            tb_writer, "cache_hit_rate/mixture_parameter",
            dagger_schedule.value(step), step)

        for batch_num, batch in enumerate(utils.as_batches(
            [train_data], FLAGS.batch_size,
            model_config.get("sequence_length"))):
          def evaluate_helper(eval_size, suffix):
            """Evaluates the model on train / valid data on and off-policy.

            Args:
              eval_size (int): the number of examples to evaluate on.
              suffix (str): appended to all logging and tensorboard paths.
            """
            evaluate(policy_model, oracle_valid_data[-eval_size:], step,
                     "off_policy_valid" + suffix, tb_writer, predictions_dir)
            # train_data is defined in the loop, but evaluate_helper is only
            # called in the same loop iteration.
            # pylint: disable=cell-var-from-loop
            evaluate(policy_model, train_data[-eval_size:], step,
                     "train" + suffix, tb_writer, predictions_dir)
            # pylint: enable=cell-var-from-loop

            # Log the cache hit rates on portions of train / valid.
            _, hit_rates = next(measure_cache_hit_rate(
                FLAGS.train_memtrace, cache_config, policy_model,
                schedules.ConstantSchedule(1), get_step,
                os.path.join(
                    evict_trace_dir, "train{}-{}.txt".format(suffix, step)),
                max_examples=eval_size, use_oracle_scores=False))
            log_hit_rates(
                tb_writer, "cache_hit_rate/train" + suffix, hit_rates, step)

            # Use oracle scores, since the eviction trace in
            # log_evaluate_stats will log with on-policy scores.
            on_policy_valid_data, hit_rates = next(measure_cache_hit_rate(
                FLAGS.valid_memtrace, cache_config, policy_model,
                schedules.ConstantSchedule(1), get_step,
                os.path.join(
                    evict_trace_dir, "valid{}-{}.txt".format(suffix, step)),
                max_examples=eval_size))
            log_hit_rates(
                tb_writer, "cache_hit_rate/valid" + suffix, hit_rates, step)
            evaluate(policy_model, on_policy_valid_data[-eval_size:], step,
                     "on_policy_valid" + suffix, tb_writer, predictions_dir)

          if step % FLAGS.small_eval_freq == 0:
            evaluate_helper(FLAGS.small_eval_size, "")

          if step % FLAGS.full_eval_freq == 0:
            evaluate_helper(len(oracle_valid_data), "_full")

          if step % FLAGS.save_freq == 0 and step != 0:
            save_path = os.path.join(checkpoints_dir, "{}.ckpt".format(step))
            with open(save_path, "wb") as save_file:
              checkpoint_buffer = io.BytesIO()
              torch.save(policy_model.state_dict(), checkpoint_buffer)
              logging.info("Saving model checkpoint to: %s", save_path)
              save_file.write(checkpoint_buffer.getvalue())

          optimizer.zero_grad()
          losses = policy_model.loss(
              batch, model_config.get("sequence_length") // 2)
          total_loss = sum(losses.values())
          total_loss.backward()
          optimizer.step()

          pbar.update(1)
          step += 1

          if step % FLAGS.tb_freq == 0:
            utils.log_scalar(tb_writer, "loss/total", total_loss, step)
            for loss_name, loss_value in losses.items():
              utils.log_scalar(
                  tb_writer, "loss/{}".format(loss_name), loss_value, step)

          if step == FLAGS.total_steps:
            return

          # Break out of the inner loop to get the next set of
          # k * update_freq batches.
          if batch_num == dagger_schedule_config.get("update_freq"):
            break
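
# A hedged sketch of restoring a checkpoint written by the save block in main
# above. The checkpoint holds only the model's state_dict, so the model must
# first be rebuilt from the same model_config. This helper is illustrative and
# is not called anywhere in this file; load_path stands in for a concrete
# checkpoints_dir/<step>.ckpt path.
def _load_checkpoint_sketch(load_path, model_config, device):
  policy_model = model.EvictionPolicyModel.from_config(model_config).to(device)
  # map_location lets a checkpoint trained on GPU be loaded onto CPU.
  policy_model.load_state_dict(torch.load(load_path, map_location=device))
  return policy_model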