Example #1
 # Assumes module-level imports: collections, tensorflow as tf,
 # tensorflow.keras.mixed_precision (imported as prec), and the
 # project's tools, models, and expl helpers.
 def __init__(self, config, logger, dataset):
     self._config = config
     self._logger = logger
     # The compute dtype follows the global mixed-precision policy.
     self._float = prec.global_policy().compute_dtype
     # Counters gating how often the agent logs, trains, pretrains,
     # resets, and stops exploring as the step count grows.
     self._should_log = tools.Every(config.log_every)
     self._should_train = tools.Every(config.train_every)
     self._should_pretrain = tools.Once()
     self._should_reset = tools.Every(config.reset_every)
     self._should_expl = tools.Until(
         int(config.expl_until / config.action_repeat))
     # Running means for metrics accumulated between log writes.
     self._metrics = collections.defaultdict(tf.metrics.Mean)
     # Keep the global step counter on the CPU.
     with tf.device('cpu:0'):
         self._step = tf.Variable(count_steps(config.traindir),
                                  dtype=tf.int64)
     # Schedules: wrap each spec in a callable evaluated at the current step.
     config.actor_entropy = (
         lambda x=config.actor_entropy: tools.schedule(x, self._step))
     config.actor_state_entropy = (
         lambda x=config.actor_state_entropy: tools.schedule(x, self._step))
     config.imag_gradient_mix = (
         lambda x=config.imag_gradient_mix: tools.schedule(x, self._step))
     self._dataset = iter(dataset)
     # World model plus an actor-critic trained on imagined rollouts.
     self._wm = models.WorldModel(self._step, config)
     self._task_behavior = models.ImagBehavior(config, self._wm,
                                               config.behavior_stop_grad)
     # Task reward is read off the world model's reward head.
     reward = lambda f, s, a: self._wm.heads['reward'](f).mode()
     # Pick the exploration behavior named in the config.
     self._expl_behavior = dict(
         greedy=lambda: self._task_behavior,
         random=lambda: expl.Random(config),
         plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
     )[config.expl_behavior]()
     # Train step to initialize variables including optimizer statistics.
     self._train(next(self._dataset))
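
The `tools.Every`, `tools.Once`, and `tools.Until` counters used above are not part of the snippet. As a minimal sketch of their apparent contract (inferred from how they are constructed here, not the project's actual code), they can be written as small stateful callables keyed on the step count:

 class Every:
     """Fires at most once per `every` steps; never fires if `every` is falsy."""

     def __init__(self, every):
         self._every = every
         self._last = None

     def __call__(self, step):
         if not self._every:
             return False
         if self._last is None or step >= self._last + self._every:
             self._last = step
             return True
         return False

 class Once:
     """Fires on the first call only."""

     def __init__(self):
         self._once = True

     def __call__(self):
         once, self._once = self._once, False
         return once

 class Until:
     """Fires while `step` is below `until`; always fires if `until` is falsy."""

     def __init__(self, until):
         self._until = until

     def __call__(self, step):
         return step < self._until if self._until else True

Under this sketch, `Every(config.train_every)` fires on its first call and then again once at least `train_every` steps have passed, which is the training cadence the constructor sets up.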
Example #2
 # Assumes module-level imports of the project's models, networks,
 # and tools helpers.
 def __init__(self, config, world_model, reward=None):
     self._config = config
     self._reward = reward
     # Separate imagination behavior for exploration; its actor is
     # exposed as this object's policy.
     self._behavior = models.ImagBehavior(config, world_model)
     self.actor = self._behavior.actor
     # Output size of the quantity the ensemble predicts; discrete
     # latents flatten to stoch * discrete values.
     stoch_size = config.dyn_stoch
     if config.dyn_discrete:
         stoch_size *= config.dyn_discrete
     size = {
         'embed': 32 * config.cnn_depth,
         'stoch': stoch_size,
         'deter': config.dyn_deter,
         'feat': stoch_size + config.dyn_deter,
     }[self._config.disag_target]
     kw = dict(shape=size,
               layers=config.disag_layers,
               units=config.disag_units,
               act=config.act)
     # Ensemble of dense heads; the spread of their predictions is the
     # exploration signal.
     self._networks = [
         networks.DenseHead(**kw) for _ in range(config.disag_models)
     ]
     # One optimizer shared across all ensemble members.
     self._opt = tools.Optimizer('ensemble',
                                 config.model_lr,
                                 config.opt_eps,
                                 config.grad_clip,
                                 config.weight_decay,
                                 opt=config.opt)
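
What the snippet does not show is how the ensemble becomes a reward. In Plan2Explore, the intrinsic reward is the disagreement between the ensemble's predictions; a minimal sketch of that idea (an illustration, not the repository's exact method, assuming each head returns a distribution with a `.mean()` method as the `DenseHead` instances above do) could look like:

 import tensorflow as tf

 def disagreement_reward(networks, inputs):
     # Stack each member's mean prediction: (models, batch, size).
     preds = tf.stack([net(inputs).mean() for net in networks], axis=0)
     # Standard deviation across members, averaged over the feature
     # dimension; large values mark states the model is unsure about.
     return tf.reduce_mean(tf.math.reduce_std(preds, axis=0), axis=-1)

Maximizing this quantity inside the learned world model steers the imagined policy toward states where the ensemble members disagree, i.e., where the dynamics are still poorly understood.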