예제 #1
0
파일: callbacks.py 프로젝트: haroon123/trax
    def on_step_end(self, step):
        """Evaluates the inner model and logs mean nonzero prediction errors.

        Copies the inner model's weights out of the loop's eval model into
        ``self._model``, runs ``self._n_steps`` evaluation batches, and logs one
        ``pred_error/context_{context}/horizon_{horizon}`` scalar per metric key
        to a summary writer rooted at ``<output_dir>/srl_eval``.

        Args:
          step: Current training step (not used by this method).
        """
        summary_writer = jaxboard.SummaryWriter(
            os.path.join(self._loop.output_dir, 'srl_eval'))
        try:
            self._model.weights = serialization_utils.extract_inner_model(
                self._loop.eval_model.weights)

            # Accumulate each metric's per-batch values; keys are destructured
            # below as (context, horizon) pairs.
            metrics = collections.defaultdict(list)
            for _ in range(self._n_steps):
                batch = self._eval_task.next_batch()
                step_metrics = self._eval_batch(batch)
                for (key, value) in step_metrics.items():
                    metrics[key].append(value)

            def metric_name(context, horizon):
                return f'pred_error/context_{context}/horizon_{horizon}'

            def mean_nonzero_error(errors):
                # Convert to an array before comparing: `some_list != 0` is a
                # single Python bool, which made the denominator 1 instead of
                # the number of nonzero entries. Assumes every batch yields
                # values of the same shape for a given key -- TODO confirm
                # against _eval_batch.
                errors = np.asarray(errors)
                # Zero entries are excluded from the average; if every entry is
                # zero this evaluates 0/0 (nan under numpy's default errstate).
                return np.sum(errors) / np.count_nonzero(errors)

            metrics = {
                metric_name(context, horizon): mean_nonzero_error(errors)
                for ((context, horizon), errors) in metrics.items()
            }
            self._loop.log_summary(metrics, summary_writer, '', 'srl_eval')
        finally:
            # Always release the writer, even if evaluation raises.
            summary_writer.close()
예제 #2
0
 def on_step_end(self, step):
     """Evaluates the inner model and logs the results under ``srl_eval``.

     Extracts the inner model's weights from the loop's eval model, passes them
     to ``self.evaluate``, and logs the returned metrics through
     ``self._loop.log_summary``. The summary writer is closed even on failure.

     Args:
       step: Current training step (not used by this method).
     """
     log_dir = os.path.join(self._loop.output_dir, 'srl_eval')
     writer = jaxboard.SummaryWriter(log_dir)
     try:
         inner_weights = serialization_utils.extract_inner_model(
             self._loop.eval_model.weights)
         summaries = self.evaluate(inner_weights)
         self._loop.log_summary(summaries, writer, '', 'srl_eval')
     finally:
         writer.close()
예제 #3
0
def init_policy_from_world_model_checkpoint(policy_weights, model_output_dir,
                                            substitute_fn):
    """Initializes policy parameters from world model parameters.

    Reads the pickled checkpoint ``model.pkl`` from ``model_output_dir``,
    extracts the inner model from its ``'weights'`` entry, and passes a slice of
    those weights to ``substitute_fn``.

    Args:
      policy_weights: Current policy weights, forwarded to ``substitute_fn``.
      model_output_dir: Directory containing the world model's ``model.pkl``.
      substitute_fn: Callable invoked as
        ``substitute_fn(policy_weights, model_weights_slice)``.

    Returns:
      Whatever ``substitute_fn`` returns.
    """
    pickler = utils.get_pickle_module()
    checkpoint_path = os.path.join(model_output_dir, 'model.pkl')
    # Read the checkpoint directly: going through trax.load_trainer_state would
    # create a circular import.
    with tf.io.gfile.GFile(checkpoint_path, 'rb') as checkpoint_file:
        raw_weights = pickler.load(checkpoint_file)['weights']
    inner_weights = serialization_utils.extract_inner_model(raw_weights)
    # TODO(pkozakowski): The following, brittle slice is hardcoded for
    # transplanting parameters from TransformerLM to TransformerDecoder-based
    # policy network of the same configuration. Figure out a more general method.
    transplanted = inner_weights[1:-2]
    return substitute_fn(policy_weights, transplanted)
예제 #4
0
 def _extract_weights(self, weights):
     """Returns ``serialization_utils.extract_inner_model(weights)``."""
     inner = serialization_utils.extract_inner_model(weights)
     return inner