def __init__(self, config, logger, dataset):
    self._config = config
    self._logger = logger
    self._float = prec.global_policy().compute_dtype
    # Step-based gates that decide when to log, train, pretrain, reset, and explore.
    self._should_log = tools.Every(config.log_every)
    self._should_train = tools.Every(config.train_every)
    self._should_pretrain = tools.Once()
    self._should_reset = tools.Every(config.reset_every)
    self._should_expl = tools.Until(
        int(config.expl_until / config.action_repeat))
    self._metrics = collections.defaultdict(tf.metrics.Mean)
    with tf.device('cpu:0'):
        self._step = tf.Variable(count_steps(config.traindir), dtype=tf.int64)
    # Schedules: wrap the raw config values so they are re-evaluated at the
    # current training step whenever they are read.
    config.actor_entropy = (
        lambda x=config.actor_entropy: tools.schedule(x, self._step))
    config.actor_state_entropy = (
        lambda x=config.actor_state_entropy: tools.schedule(x, self._step))
    config.imag_gradient_mix = (
        lambda x=config.imag_gradient_mix: tools.schedule(x, self._step))
    self._dataset = iter(dataset)
    self._wm = models.WorldModel(self._step, config)
    self._task_behavior = models.ImagBehavior(
        config, self._wm, config.behavior_stop_grad)
    # Reward predicted from model features; used by the Plan2Explore behavior.
    reward = lambda f, s, a: self._wm.heads['reward'](f).mode()
    self._expl_behavior = dict(
        greedy=lambda: self._task_behavior,
        random=lambda: expl.Random(config),
        plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
    )[config.expl_behavior]()
    # Train step to initialize variables including optimizer statistics.
    self._train(next(self._dataset))
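# For context, a minimal, self-contained sketch of the step-gating helpers used
# above (Every / Once / Until). This is an assumption about their behavior for
# illustration only, not the repo's tools module: Every fires once per interval
# of steps, Once fires exactly once, and Until stays true while the step count
# is below a threshold.
class Every:
    def __init__(self, every):
        self._every = every
        self._last = None

    def __call__(self, step):
        if not self._every:
            return False
        if self._last is None or step >= self._last + self._every:
            self._last = step
            return True
        return False


class Once:
    def __init__(self):
        self._fired = False

    def __call__(self):
        result = not self._fired
        self._fired = True
        return result


class Until:
    def __init__(self, until):
        self._until = until

    def __call__(self, step):
        return step < self._until if self._until else True


# Example: with train_every=5, training triggers at steps 0, 5, 10, ...
should_train = Every(5)
assert [should_train(s) for s in (0, 3, 5, 7, 10)] == [True, False, True, False, True]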
def train(self, data):
    data = self.preprocess(data)
    with tf.GradientTape() as model_tape:
        embed = self.encoder(data)
        post, prior = self.dynamics.observe(embed, data['action'])
        kl_balance = tools.schedule(self._config.kl_balance, self._step)
        kl_free = tools.schedule(self._config.kl_free, self._step)
        kl_scale = tools.schedule(self._config.kl_scale, self._step)
        kl_loss, kl_value = self.dynamics.kl_loss(
            post, prior, kl_balance, kl_free, kl_scale)
        feat = self.dynamics.get_feat(post)
        likes = {}
        for name, head in self.heads.items():
            # Only heads listed in grad_heads propagate gradients back into the
            # latent features; the others are trained on detached inputs.
            grad_head = (name in self._config.grad_heads)
            inp = feat if grad_head else tf.stop_gradient(feat)
            pred = head(inp, tf.float32)
            like = pred.log_prob(tf.cast(data[name], tf.float32))
            likes[name] = tf.reduce_mean(like) * self._scales.get(name, 1.0)
        # Model loss: KL regularizer minus the scaled reconstruction
        # log-likelihoods of all heads.
        model_loss = kl_loss - sum(likes.values())
    model_parts = [self.encoder, self.dynamics] + list(self.heads.values())
    metrics = self._model_opt(model_tape, model_loss, model_parts)
    metrics.update({f'{name}_loss': -like for name, like in likes.items()})
    metrics['kl_balance'] = kl_balance
    metrics['kl_free'] = kl_free
    metrics['kl_scale'] = kl_scale
    metrics['kl'] = tf.reduce_mean(kl_value)
    metrics['prior_ent'] = self.dynamics.get_dist(prior).entropy()
    metrics['post_ent'] = self.dynamics.get_dist(post).entropy()
    return embed, post, feat, kl_value, metrics
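# For context, a self-contained sketch of the balanced KL regularizer that
# dynamics.kl_loss is assumed to implement above (illustrative only, not the
# repo's exact code). The KL between posterior and prior is computed twice,
# once with the prior detached and once with the posterior detached, so that
# kl_balance controls how strongly each side is pulled toward the other;
# kl_free clips each term from below (free bits) and kl_scale rescales the sum.
import tensorflow as tf


def gaussian_kl(mean_a, std_a, mean_b, std_b):
    # Closed-form KL(N(a) || N(b)) for diagonal Gaussians, summed over dims.
    return tf.reduce_sum(
        tf.math.log(std_b / std_a)
        + (tf.square(std_a) + tf.square(mean_a - mean_b)) / (2.0 * tf.square(std_b))
        - 0.5, axis=-1)


def balanced_kl_loss(post, prior, kl_balance=0.8, kl_free=1.0, kl_scale=1.0):
    sg = tf.stop_gradient
    # Unmodified KL value, reported as a metric.
    value = gaussian_kl(post['mean'], post['std'], prior['mean'], prior['std'])
    # Gradient flows into the posterior on one term, into the prior on the other.
    loss_post = gaussian_kl(post['mean'], post['std'], sg(prior['mean']), sg(prior['std']))
    loss_prior = gaussian_kl(sg(post['mean']), sg(post['std']), prior['mean'], prior['std'])
    loss_post = tf.maximum(tf.reduce_mean(loss_post), kl_free)
    loss_prior = tf.maximum(tf.reduce_mean(loss_prior), kl_free)
    loss = kl_scale * (kl_balance * loss_prior + (1.0 - kl_balance) * loss_post)
    return loss, value


# Toy usage with a batch of 2 and 4-dimensional latents.
post = {'mean': tf.random.normal([2, 4]), 'std': tf.ones([2, 4])}
prior = {'mean': tf.zeros([2, 4]), 'std': tf.ones([2, 4])}
loss, value = balanced_kl_loss(post, prior)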
def __init__(self, env, model_maker, config, training=True):
    self.env = env
    # self.act_space = self.env.action_space
    self._c = config
    self._precision = config.precision
    self._float = prec.global_policy().compute_dtype
    # self.ob, _, _, _ = self.env.step(
    #     self.env._env.action_space.sample()
    # )  # whether the action space is discrete or not, 0 is a valid action
    self.ob = self.env.reset()
    self.state = None
    acts = self.env.action_space
    self.random_actor = tools.OneHotDist(tf.zeros_like(acts.low)[None])
    self._c.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
    print("self._c.num_actions:", self._c.num_actions)
    # self.batch_size = 16
    self.batch_size = self._c.batch_size
    # For model-free learning, batch_length only controls the replay buffer.
    self.batch_length = self._c.batch_length
    self.TD_size = 1  # no TD
    self.play_records = []
    self.advantage = True
    self.total_step = 1
    self.save_play_img = False
    self.RGB_array_list = []
    self.episode_reward = 0
    self.episode_step = 0  # to avoid dividing by zero
    self.datadir = self._c.logdir / "episodes"
    self._writer = tf.summary.create_file_writer(
        "./tf_log", max_queue=1000, flush_millis=20000)
    if training:
        self.prefill_and_make_dataset()
    with tf.device("cpu:1"):
        self._step = tf.Variable(count_steps(self.datadir), dtype=tf.int64)
    # Schedules: wrap the raw config values so they are re-evaluated at the
    # current training step whenever they are read.
    self._c.actor_entropy = (
        lambda x=self._c.actor_entropy: tools.schedule(x, self._step))
    self._c.actor_state_entropy = (
        lambda x=self._c.actor_state_entropy: tools.schedule(x, self._step))
    self._c.imag_gradient_mix = (
        lambda x=self._c.imag_gradient_mix: tools.schedule(x, self._step))
    self.model = model_maker(self.env, training, self._step, self._writer, self._c)
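# For context, a minimal sketch of the schedule() helper wrapped into the
# lambdas above. This is an assumption for illustration only, not the repo's
# tools module: a plain number is treated as a constant, and a spec such as
# 'linear(initial,final,duration)' is interpolated against the current step.
import re


def schedule(spec, step):
    try:
        return float(spec)  # constant schedule, e.g. 1e-4
    except ValueError:
        pass
    match = re.match(r'linear\((.+),(.+),(.+)\)', spec)
    if match:
        initial, final, duration = (float(g) for g in match.groups())
        mix = min(max(float(step) / duration, 0.0), 1.0)
        return (1.0 - mix) * initial + mix * final
    raise NotImplementedError(spec)


# Example: anneal from 3e-3 to 3e-4 over 10000 steps.
assert schedule('linear(3e-3,3e-4,10000)', 0) == 3e-3
assert abs(schedule('linear(3e-3,3e-4,10000)', 10000) - 3e-4) < 1e-12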