def write_to_tensorboard(self, tb_writer, tb_tag, step):
  """Logs the top-i success rate of every tracked key to tensorboard.

  Args:
    tb_writer (FileWriter): used to log.
    tb_tag (str): prefix for all tensorboard keys written here.
    step (int): step number to use in tensorboard.
  """
  for key, top_i_successes in self._num_top_i_successes.items():
    # Epsilon keeps the division well-defined for keys with zero accesses.
    denominator = self._num_accesses[key] + 1e-8
    for i in range(self._k):
      utils.log_scalar(
          tb_writer, "{}/{}_top_{}".format(tb_tag, key, i + 1),
          top_i_successes[i] / denominator, step)
def write_to_tensorboard(self, tb_writer, tb_tag, step):
  """Logs mean weighted tau, over masked timesteps and over all timesteps.

  Args:
    tb_writer (FileWriter): used to log.
    tb_tag (str): prefix for the tensorboard keys.
    step (int): step number to use in tensorboard.
  """
  taus = np.array(self._weighted_taus)
  masks = np.array(self._masks)
  # Mean restricted to timesteps where the mask is set (eviction decisions);
  # epsilon guards against division by zero when nothing is masked.
  masked_mean_tau = np.sum(taus * masks) / (np.sum(masks) + 1e-8)
  utils.log_scalar(tb_writer, "{}/eviction_weighted_tau".format(tb_tag),
                   masked_mean_tau, step)
  utils.log_scalar(tb_writer, "{}/total_weighted_tau".format(tb_tag),
                   np.mean(taus), step)
def log_hit_rates(tb_writer, tb_key, hit_rates, step):
  """Logs list of cumulative hit rates to tensorboard.

  Args:
    tb_writer (FileWriter): used to log.
    tb_key (str): used as the tensorboard key.
    hit_rates (list[float]): the hit rates to log. Assumed that hit_rates[i]
      is the cumulative hit rate on the first i / len(hit_rates) portion of
      the data.
    step (int): step number to use in tensorboard.
  """
  num_rates = len(hit_rates)
  # Intermediate entries are tagged with the fraction of data they cover.
  for index, hit_rate in enumerate(hit_rates[:-1]):
    fraction = (index + 1) / num_rates
    utils.log_scalar(
        tb_writer, tb_key + "_{:.2f}".format(fraction), hit_rate, step)
  # The last entry covers all the data and is logged under the bare key.
  utils.log_scalar(tb_writer, tb_key, hit_rates[-1], step)
def write_to_tensorboard(self, tb_writer, tb_tag, step):
  """Logs log10 gaps between evicted scores and optimal (oracle) scores.

  Args:
    tb_writer (FileWriter): used to log.
    tb_tag (str): prefix for the tensorboard keys.
    step (int): step number to use in tensorboard.
  """
  masks = np.array(self._masks)
  evicted_scores = np.array(self._evicted_scores)
  optimal_scores = np.array(self._optimal_scores)
  # Two views of the same gap: additive (shifted by 1 so log10 is defined
  # when the gap is zero) and multiplicative.
  gaps = (
      ("difference", np.log10(evicted_scores - optimal_scores + 1)),
      ("quotient", np.log10(optimal_scores / evicted_scores)),
  )
  for gap_type, gap in gaps:
    # Mean over masked (eviction) timesteps only; epsilon guards the
    # division when no timestep is masked.
    masked_mean_gap = np.sum(gap * masks) / (np.sum(masks) + 1e-8)
    utils.log_scalar(
        tb_writer,
        "{}/eviction_oracle_score_{}_gap".format(tb_tag, gap_type),
        masked_mean_gap, step)
    utils.log_scalar(
        tb_writer, "{}/oracle_score_{}_gap".format(tb_tag, gap_type),
        np.mean(gap), step)
def main(_):
  """Trains an eviction policy model with DAgger-style data collection.

  Seeds all RNGs, validates the evaluation/checkpoint frequency flags,
  creates the experiment directory tree, snapshots the model / cache /
  DAgger configs to disk, then alternates between collecting training data
  under the scheduled mixture policy and taking optimization steps, with
  periodic evaluation and checkpointing, until FLAGS.total_steps.
  """
  logging.info("Seed: %d", FLAGS.seed)
  np.random.seed(FLAGS.seed)
  torch.random.manual_seed(FLAGS.seed)

  # Checkpoints are chosen by evaluation score, so every save step must also
  # be an evaluation step.
  if FLAGS.save_freq % FLAGS.small_eval_freq != 0:
    raise ValueError((
        "Save frequency ({}) must be a multiple of evaluation frequency ({})."
        " Allows choosing checkpoints based on their evaluation scores."
    ).format(FLAGS.save_freq, FLAGS.small_eval_freq))

  if FLAGS.full_eval_freq % FLAGS.small_eval_freq != 0:
    raise ValueError(
        ("Full evaluation frequency ({}) must be a multiple of small"
         " evaluation frequency ({}) so that their values can be compared."
        ).format(FLAGS.full_eval_freq, FLAGS.small_eval_freq))

  # Experiment directory layout: tensorboard/, predictions/, checkpoints/,
  # evictions/.
  exp_dir = os.path.join(FLAGS.experiment_base_dir, FLAGS.experiment_name)
  common_utils.create_experiment_directory(exp_dir, FLAGS.force_overwrite)
  tensorboard_dir = os.path.join(exp_dir, "tensorboard")
  tf.disable_eager_execution()
  tb_writer = tf.summary.FileWriter(tensorboard_dir)
  predictions_dir = os.path.join(exp_dir, "predictions")
  os.makedirs(predictions_dir, exist_ok=True)
  checkpoints_dir = os.path.join(exp_dir, "checkpoints")
  os.makedirs(checkpoints_dir, exist_ok=True)
  evict_trace_dir = os.path.join(exp_dir, "evictions")
  os.makedirs(evict_trace_dir, exist_ok=True)

  # Snapshot every config into the experiment directory for reproducibility.
  model_config = cfg.Config.from_files_and_bindings(FLAGS.model_configs,
                                                    FLAGS.model_bindings)
  logging.info("Model config: %s", model_config)
  with open(os.path.join(exp_dir, "model_config.json"), "w") as f:
    model_config.to_file(f)

  cache_config = cfg.Config.from_files_and_bindings(FLAGS.cache_configs,
                                                    FLAGS.cache_bindings)
  logging.info("Cache config: %s", cache_config)
  with open(os.path.join(exp_dir, "cache_config.json"), "w") as f:
    cache_config.to_file(f)

  dagger_schedule_config = cfg.Config.from_files_and_bindings(
      FLAGS.dagger_schedule_configs, FLAGS.dagger_schedule_bindings)
  logging.info("DAgger config: %s", dagger_schedule_config)
  with open(os.path.join(exp_dir, "dagger_config.json"), "w") as f:
    dagger_schedule_config.to_file(f)
  dagger_schedule = schedule_from_config(dagger_schedule_config)

  # Process everything on GPU if available
  device = torch.device("cpu")
  if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    device = torch.device("cuda:0")
  logging.info("Device: %s", device)

  policy_model = model.EvictionPolicyModel.from_config(model_config).to(
      device)
  optimizer = optim.Adam(policy_model.parameters(), lr=model_config.get("lr"))

  # get_step closes over the mutable step counter so that data collection can
  # observe training progress as it advances below.
  step = 0
  get_step = lambda: step
  # ConstantSchedule(0): presumably collects the validation data fully under
  # the oracle policy (mixture parameter 0) — confirm against
  # measure_cache_hit_rate.
  oracle_valid_data, hit_rates = next(
      measure_cache_hit_rate(
          FLAGS.valid_memtrace, cache_config, policy_model,
          schedules.ConstantSchedule(0), get_step,
          os.path.join(evict_trace_dir, "oracle_valid.txt")))
  log_hit_rates(tb_writer, "cache_hit_rate/oracle_valid", hit_rates, step)

  with tqdm.tqdm(total=FLAGS.total_steps) as pbar:
    while True:  # loop for waiting until steps == FLAGS.total_steps
      # Optimization: Instead of passing through the whole memory trace for
      # training and only using update_freq many of them, we lazily gather k *
      # update_freq batches and still train on a subsample of update_freq.
      # The value of k=collection_multiplier trades off between:
      # - The set of k * update_freq examples are all consecutive in the
      # memory trace. As k gets small, the set of these examples becomes less
      # i.i.d., as they are temporally correlated. The examples cannot be
      # random access within the memory trace, since at time t, we require the
      # previous cache accesses to compute the cache state at time t.
      # - As k gets large, training becomes slower, as we must perform k times
      # as much collecting work than training work.
      max_examples = (dagger_schedule_config.get("update_freq") *
                      FLAGS.collection_multiplier * FLAGS.batch_size)
      # NOTE(review): this eviction-trace filename contains an unfilled "{}"
      # placeholder — presumably filled in by measure_cache_hit_rate; confirm.
      train_data_generator = measure_cache_hit_rate(
          FLAGS.train_memtrace, cache_config, policy_model, dagger_schedule,
          get_step, os.path.join(evict_trace_dir, "mixture-train-{}.txt"),
          max_examples=max_examples)
      for train_data, hit_rates in train_data_generator:
        log_hit_rates(tb_writer, "cache_hit_rate/train_mixture_policy",
                      hit_rates, step)
        utils.log_scalar(tb_writer, "cache_hit_rate/mixture_parameter",
                         dagger_schedule.value(step), step)

        for batch_num, batch in enumerate(
            utils.as_batches([train_data], FLAGS.batch_size,
                             model_config.get("sequence_length"))):
          def evaluate_helper(eval_size, suffix):
            """Evaluates the model on train / valid data on and off-policy.

            Args:
              eval_size (int): the number of examples to evaluate on.
              suffix (str): appended to all logging and tensorboard paths.
            """
            evaluate(policy_model, oracle_valid_data[-eval_size:], step,
                     "off_policy_valid" + suffix, tb_writer, predictions_dir)
            # train_data is defined in the loop, but evaluate_helper is only
            # called in the same loop iteration.
            # pylint: disable=cell-var-from-loop
            evaluate(policy_model, train_data[-eval_size:], step,
                     "train" + suffix, tb_writer, predictions_dir)
            # pylint: enable=cell-var-from-loop

            # Log the cache hit rates on portions of train / valid
            _, hit_rates = next(
                measure_cache_hit_rate(
                    FLAGS.train_memtrace, cache_config, policy_model,
                    schedules.ConstantSchedule(1), get_step,
                    os.path.join(
                        evict_trace_dir,
                        "train{}-{}.txt".format(suffix, step)),
                    max_examples=eval_size, use_oracle_scores=False))
            log_hit_rates(tb_writer, "cache_hit_rate/train" + suffix,
                          hit_rates, step)

            # Use oracle scores, since eviction trace in log_evaluate_stats will
            # log with on-policy scores.
            on_policy_valid_data, hit_rates = next(
                measure_cache_hit_rate(
                    FLAGS.valid_memtrace, cache_config, policy_model,
                    schedules.ConstantSchedule(1), get_step,
                    os.path.join(
                        evict_trace_dir,
                        "valid{}-{}.txt".format(suffix, step)),
                    max_examples=eval_size))
            log_hit_rates(tb_writer, "cache_hit_rate/valid" + suffix,
                          hit_rates, step)
            evaluate(policy_model, on_policy_valid_data[-eval_size:], step,
                     "on_policy_valid" + suffix, tb_writer, predictions_dir)

          # Small (cheap) evaluations happen often; full evaluations cover all
          # of oracle_valid_data and happen less often.
          if step % FLAGS.small_eval_freq == 0:
            evaluate_helper(FLAGS.small_eval_size, "")

          if step % FLAGS.full_eval_freq == 0:
            evaluate_helper(len(oracle_valid_data), "_full")

          if step % FLAGS.save_freq == 0 and step != 0:
            save_path = os.path.join(checkpoints_dir, "{}.ckpt".format(step))
            with open(save_path, "wb") as save_file:
              # Serializes to an in-memory buffer before touching the file —
              # presumably so a failing torch.save cannot leave a truncated
              # checkpoint body on disk; confirm intent.
              checkpoint_buffer = io.BytesIO()
              torch.save(policy_model.state_dict(), checkpoint_buffer)
              logging.info("Saving model checkpoint to: %s", save_path)
              save_file.write(checkpoint_buffer.getvalue())

          optimizer.zero_grad()
          # Second argument is sequence_length // 2 — presumably a warm-up /
          # burn-in prefix excluded from the loss; confirm against
          # EvictionPolicyModel.loss.
          losses = policy_model.loss(
              batch, model_config.get("sequence_length") // 2)
          total_loss = sum(losses.values())
          total_loss.backward()
          optimizer.step()
          pbar.update(1)
          step += 1

          if step % FLAGS.tb_freq == 0:
            utils.log_scalar(tb_writer, "loss/total", total_loss, step)
            for loss_name, loss_value in losses.items():
              utils.log_scalar(tb_writer, "loss/{}".format(loss_name),
                               loss_value, step)

          if step == FLAGS.total_steps:
            return

          # Break out of inner-loop to get next set of k * update_freq batches
          if batch_num == dagger_schedule_config.get("update_freq"):
            break