def __call__(self, features):
    """Map ``features`` to an action distribution.

    ``features`` passes through ``self._layers`` dense layers
    (``self._units`` units, activation ``self._act``), then through an output
    layer named ``'hout'`` whose parameterization depends on ``self._dist``:

    * ``'tanh_normal'``: diagonal Normal with mean
      ``mean_scale * tanh(m / mean_scale)`` and std
      ``softplus(s + raw_init_std) + min_std``, squashed by
      ``tools.TanhBijector``, made ``Independent`` over the last axis and
      wrapped in ``tools.SampleDist``.
    * ``'onehot'``: ``tools.OneHotDist`` over ``self._size`` logits.
    * ``'gumbel'``: ``tfd.RelaxedOneHotCategorical`` with temperature 0.1 over
      ``self._size`` logits, wrapped in ``tools.SampleDist``.

    Raises:
        NotImplementedError: if ``self._dist`` is none of the above; the
            exception message is the offending name.
    """
    # Inverse softplus of the target initial std: with a zero pre-activation,
    # softplus(0 + raw_init_std) == self._init_std.
    raw_init_std = np.log(np.exp(self._init_std) - 1)
    x = features
    for index in range(self._layers):
        x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
    if self._dist == 'tanh_normal':
        # https://www.desmos.com/calculator/rcmcf5jwe7
        x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
        mean, std = tf.split(x, 2, -1)
        # Soft-clip the mean into (-mean_scale, mean_scale).
        mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
        std = tf.nn.softplus(std + raw_init_std) + self._min_std
        dist = tfd.Normal(mean, std)
        dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
        dist = tfd.Independent(dist, 1)
        dist = tools.SampleDist(dist)
    elif self._dist == 'onehot':
        x = self.get('hout', tfkl.Dense, self._size)(x)
        dist = tools.OneHotDist(x)
    elif self._dist == 'gumbel':
        x = self.get('hout', tfkl.Dense, self._size)(x)
        dist = tfd.RelaxedOneHotCategorical(temperature=1e-1, logits=x)
        dist = tools.SampleDist(dist)
    else:
        # Name the unsupported head so configuration mistakes are obvious.
        raise NotImplementedError(self._dist)
    return dist
def actor(self, feat):
    """Return an action distribution that ignores ``feat`` except for its shape.

    The distribution has shape ``feat.shape[:-1] + [num_actions]``. When
    ``config.actor_dist`` is ``'onehot'`` it is a ``tools.OneHotDist`` with
    all-zero logits. Otherwise it is an element-wise ``tfd.Uniform`` over
    [-1, 1] with dtype ``self._float``. Note that this one is not wrapped in
    ``tfd.Independent``.
    """
    leading_dims = feat.shape[:-1]
    action_shape = leading_dims + [self._config.num_actions]
    if self._config.actor_dist != 'onehot':
        upper = tf.ones(action_shape, self._float)
        return tfd.Uniform(-upper, upper)
    return tools.OneHotDist(tf.zeros(action_shape))
def __init__(self, env, model_maker, config, training=True):
    """Set up the agent around ``env`` and build its model.

    Args:
        env: environment exposing ``reset()`` and ``action_space``. The action
            space must provide ``.low`` and either ``.n`` or ``.shape``.
        model_maker: callable invoked as
            ``model_maker(env, training, step_variable, summary_writer, config)``.
            Its result is stored in ``self.model``.
        config: configuration object, mutated in place. ``num_actions`` is set.
            ``actor_entropy``, ``actor_state_entropy`` and
            ``imag_gradient_mix`` are replaced by zero-argument callables that
            evaluate ``tools.schedule`` at the current step.
        training: when true, ``self.prefill_and_make_dataset()`` is called
            before the step counter and model are created.
    """
    self.env = env
    self._c = config
    self._precision = config.precision
    self._float = prec.global_policy().compute_dtype

    self.ob = self.env.reset()
    self.state = None
    acts = self.env.action_space
    # NOTE(review): always builds a one-hot random actor from ``acts.low``,
    # regardless of whether the action space is discrete — confirm this is
    # only used with one-hot action spaces.
    self.random_actor = tools.OneHotDist(tf.zeros_like(acts.low)[None])
    self._c.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
    print("self._c.num_actions:", self._c.num_actions)

    self.batch_size = self._c.batch_size
    # When not doing model-based learning, this presumably controls the
    # replay buffer instead.
    self.batch_length = self._c.batch_length
    self.TD_size = 1  # no TD
    self.play_records = []
    self.advantage = True
    self.total_step = 1
    self.save_play_img = False
    self.RGB_array_list = []
    self.episode_reward = 0
    # NOTE(review): a comment here used to say "to avoid divide by zero",
    # which does not fit a 0 initial value — confirm which counter it meant.
    self.episode_step = 0
    self.datadir = self._c.logdir / "episodes"
    self._writer = tf.summary.create_file_writer(
        "./tf_log", max_queue=1000, flush_millis=20000)
    if training:
        self.prefill_and_make_dataset()

    # NOTE(review): "cpu:1" is an unusual device name (a second CPU device
    # rarely exists); confirm it is intended.
    with tf.device("cpu:1"):
        self._step = tf.Variable(count_steps(self.datadir), dtype=tf.int64)
    # Each lambda binds the config's current value as a default argument, so
    # it keeps the original value even though the attribute is reassigned.
    self._c.actor_entropy = (
        lambda x=self._c.actor_entropy: tools.schedule(x, self._step))
    self._c.actor_state_entropy = (
        lambda x=self._c.actor_state_entropy: tools.schedule(x, self._step))
    self._c.imag_gradient_mix = (
        lambda x=self._c.imag_gradient_mix: tools.schedule(x, self._step))
    self.model = model_maker(self.env, training, self._step, self._writer, self._c)
def _exploration(self, action, training):
    """Perturb a policy action with exploration noise.

    Args:
        action: action tensor from the actor.
        training: selects ``config.expl_amount`` when true, otherwise
            ``config.eval_noise``.

    Returns:
        ``action`` itself when the selected noise amount equals 0. Otherwise:

        * if ``'onehot'`` occurs in ``config.actor_dist``, a one-hot sample
          from probabilities ``amount / num_actions + (1 - amount) * action``
          (a mix of the uniform distribution and ``action``);
        * else ``action`` plus Gaussian noise with std ``amount``, clipped to
          [-1, 1].
    """
    amount = self._config.expl_amount if training else self._config.eval_noise
    if amount == 0:
        return action
    amount = tf.cast(amount, self._float)
    if 'onehot' in self._config.actor_dist:
        probs = amount / self._config.num_actions + (1 - amount) * action
        return tools.OneHotDist(probs=probs).sample()
    # Both branches return, so no fallback raise is needed here.
    return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
def get_dist(self, state, dtype=None):
    """Build the latent distribution described by ``state``.

    Continuous latents (``self._discrete`` false) use ``state['mean']`` and
    ``state['std']``, cast to ``dtype`` when it is truthy, and give a
    ``tfd.MultivariateNormalDiag``.

    Discrete latents use ``state['logit']``, cast to float32, and give a
    ``tools.OneHotDist`` made ``Independent`` over one axis. Unless ``dtype``
    is ``tf.float32``, this is wrapped in ``tools.DtypeDist`` with ``dtype``,
    or with the original logit dtype when ``dtype`` is falsy.
    """
    if not self._discrete:
        loc, scale = state['mean'], state['std']
        if dtype:
            loc = tf.cast(loc, dtype)
            scale = tf.cast(scale, dtype)
        return tfd.MultivariateNormalDiag(loc, scale)
    raw_logit = state['logit']
    categorical = tfd.Independent(
        tools.OneHotDist(tf.cast(raw_logit, tf.float32)), 1)
    if dtype != tf.float32:
        categorical = tools.DtypeDist(categorical, dtype or raw_logit.dtype)
    return categorical
def __call__(self, features):
    """Map ``features`` to an action distribution.

    ``features`` passes through ``self._layers_num`` dense layers
    (``self._units`` units, activation ``self._act``), then through an output
    layer named ``"hout"`` chosen by ``self._dist``:

    * ``"tanh_normal"``: diagonal Normal with mean
      ``mean_scale * tanh(m / mean_scale)`` and std
      ``softplus(s + raw_init_std) + min_std``, squashed by
      ``tools.TanhBijector``, made ``Independent`` over the last axis and
      wrapped in ``tools.SampleDist``.
    * ``"onehot"``: ``tools.OneHotDist`` over ``self._size`` logits.

    Raises:
        NotImplementedError: if ``self._dist`` is neither of the above; the
            exception message is the offending name.
    """
    # Inverse softplus of the target initial std: with a zero pre-activation,
    # softplus(0 + raw_init_std) == self._init_std.
    raw_init_std = np.log(np.exp(self._init_std) - 1)
    x = features
    for index in range(self._layers_num):
        x = self.get(f"h{index}", tf.keras.layers.Dense, self._units, self._act)(x)
    if self._dist == "tanh_normal":
        # https://www.desmos.com/calculator/rcmcf5jwe7
        x = self.get("hout", tf.keras.layers.Dense, 2 * self._size)(x)
        mean, std = tf.split(x, 2, -1)
        # Soft-clip the mean into (-mean_scale, mean_scale).
        mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
        std = tf.nn.softplus(std + raw_init_std) + self._min_std
        dist = tfd.Normal(mean, std)
        dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
        dist = tfd.Independent(dist, 1)
        dist = tools.SampleDist(dist)
    elif self._dist == "onehot":
        x = self.get("hout", tf.keras.layers.Dense, self._size)(x)
        dist = tools.OneHotDist(x)
    else:
        # `dist` is unbound on this path; report the configured name instead
        # (referencing `dist` here raised UnboundLocalError).
        raise NotImplementedError(self._dist)
    return dist
def main(config):
    """Train and periodically evaluate a Dreamer agent.

    Step 1: setup. Divides the step-count settings by ``config.action_repeat``
    and configures TensorFlow (eager debugging, GPU memory growth, mixed
    precision). Creates the log and episode directories, then loads stored
    episodes and builds the train/eval environments.

    Step 2: prefill. Fills the datasets with a random policy.

    Step 3: loop. Alternates evaluation and training until the agent's step
    counter reaches ``config.steps``. Saves ``variables.pkl`` after every
    training phase and restores it at start-up if present.

    All environments are closed on exit, including when an exception escapes
    the training loop. Failures while closing are reported, not raised.
    """
    logdir = pathlib.Path(config.logdir).expanduser()
    config.traindir = config.traindir or logdir / 'train_eps'
    config.evaldir = config.evaldir or logdir / 'eval_eps'
    # Presumably converts environment-frame budgets into agent steps.
    config.steps //= config.action_repeat
    config.eval_every //= config.action_repeat
    config.log_every //= config.action_repeat
    config.time_limit //= config.action_repeat
    config.act = getattr(tf.nn, config.act)

    if config.debug:
        tf.config.experimental_run_functions_eagerly(True)
    if config.gpu_growth:
        message = 'No GPU found. To actually train on CPU remove this assert.'
        assert tf.config.experimental.list_physical_devices('GPU'), message
        for gpu in tf.config.experimental.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)
    assert config.precision in (16, 32), config.precision
    if config.precision == 16:
        prec.set_policy(prec.Policy('mixed_float16'))
    print('Logdir', logdir)
    logdir.mkdir(parents=True, exist_ok=True)
    config.traindir.mkdir(parents=True, exist_ok=True)
    config.evaldir.mkdir(parents=True, exist_ok=True)
    step = count_steps(config.traindir)
    logger = tools.Logger(logdir, config.action_repeat * step)

    print('Create envs.')
    if config.offline_traindir:
        directory = config.offline_traindir.format(**vars(config))
    else:
        directory = config.traindir
    train_eps = tools.load_episodes(directory, limit=config.dataset_size)
    if config.offline_evaldir:
        directory = config.offline_evaldir.format(**vars(config))
    else:
        directory = config.evaldir
    eval_eps = tools.load_episodes(directory, limit=1)

    def make(mode):
        return make_env(config, logger, mode, train_eps, eval_eps)

    train_envs = [make('train') for _ in range(config.envs)]
    eval_envs = [make('eval') for _ in range(config.envs)]
    try:
        acts = train_envs[0].action_space
        config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0]

        prefill = max(0, config.prefill - count_steps(config.traindir))
        print(f'Prefill dataset ({prefill} steps).')
        if hasattr(acts, 'discrete'):
            random_actor = tools.OneHotDist(tf.zeros_like(acts.low)[None])
        else:
            random_actor = tfd.Independent(
                tfd.Uniform(acts.low[None], acts.high[None]), 1)

        def random_agent(o, d, s):
            action = random_actor.sample()
            logprob = random_actor.log_prob(action)
            return {'action': action, 'logprob': logprob}, None

        tools.simulate(random_agent, train_envs, prefill)
        tools.simulate(random_agent, eval_envs, episodes=1)
        logger.step = config.action_repeat * count_steps(config.traindir)

        print('Simulate agent.')
        train_dataset = make_dataset(train_eps, config)
        eval_dataset = iter(make_dataset(eval_eps, config))
        agent = Dreamer(config, logger, train_dataset)
        if (logdir / 'variables.pkl').exists():
            agent.load(logdir / 'variables.pkl')
            # Resuming from a checkpoint: presumably skips the pretraining phase.
            agent._should_pretrain._once = False

        state = None
        while agent._step.numpy().item() < config.steps:
            logger.write()
            print('Start evaluation.')
            video_pred = agent._wm.video_pred(next(eval_dataset))
            logger.video('eval_openl', video_pred)
            eval_policy = functools.partial(agent, training=False)
            tools.simulate(eval_policy, eval_envs, episodes=1)
            print('Start training.')
            state = tools.simulate(agent, train_envs, config.eval_every, state=state)
            agent.save(logdir / 'variables.pkl')
    finally:
        # Best-effort cleanup that runs even if training raised; report
        # failures instead of silently swallowing them.
        for env in train_envs + eval_envs:
            try:
                env.close()
            except Exception as exc:
                print(f'Failed to close environment: {exc!r}')
def __call__(self, features, dtype=None):
    """Build the output distribution for ``features``.

    ``features`` passes through ``self._layers`` dense layers of
    ``self._units`` units with activation ``self._act``. When
    ``self._outscale`` is truthy, the last of those layers gets a
    ``VarianceScaling(self._outscale)`` kernel initializer. A dense layer
    named ``'hout'`` then parameterizes the head chosen by ``self._dist``:

    * ``'tanh_normal'``: Normal with mean ``tanh(m)`` and std
      ``softplus(s + init_std) + min_std``, tanh-squashed, ``Independent``
      over one axis, wrapped in ``tools.SampleDist``.
    * ``'tanh_normal_5'``: as above, but mean ``5 * tanh(m / 5)`` and std
      ``softplus(s + 5) + 5``.
    * ``'normal'``: ``Independent`` Normal with std
      ``softplus(s + init_std) + min_std``.
    * ``'normal_1'``: ``Independent`` Normal with unit std.
    * ``'trunc_normal'``: ``tools.SafeTruncatedNormal`` on [-1, 1] computed
      in float32, wrapped in ``tools.DtypeDist(dtype)``, then ``Independent``.
    * ``'onehot'``: ``tools.OneHotDist`` over float32 logits, wrapped in
      ``tools.DtypeDist(dtype)``.
    * ``'onehot_gumble'``: ``tools.GumbleDist`` with temperature
      ``self._temp``.

    Args:
        features: input tensor.
        dtype: optional dtype. The ``tanh_normal*``, ``normal*`` and
            ``onehot_gumble`` heads cast the ``'hout'`` output to it when it
            is truthy. It is also forwarded to the ``DtypeDist``,
            ``OneHotDist`` and ``GumbleDist`` wrappers.

    Raises:
        NotImplementedError: if ``self._dist`` is not one of the names above.
    """
    hidden = features
    final_index = self._layers - 1
    for index in range(self._layers):
        layer_kwargs = {}
        if index == final_index and self._outscale:
            layer_kwargs['kernel_initializer'] = (
                tf.keras.initializers.VarianceScaling(self._outscale))
        hidden = self.get(
            f'h{index}', tfkl.Dense, self._units, self._act, **layer_kwargs)(hidden)

    def project(width):
        # Every head shares the output layer name 'hout'.
        return self.get('hout', tfkl.Dense, width)(hidden)

    def cast_if_requested(tensor):
        return tf.cast(tensor, dtype) if dtype else tensor

    kind = self._dist
    if kind in ('tanh_normal', 'tanh_normal_5'):
        # https://www.desmos.com/calculator/rcmcf5jwe7
        loc, scale = tf.split(cast_if_requested(project(2 * self._size)), 2, -1)
        if kind == 'tanh_normal':
            loc = tf.tanh(loc)
            scale = tf.nn.softplus(scale + self._init_std) + self._min_std
        else:
            loc = 5 * tf.tanh(loc / 5)
            scale = tf.nn.softplus(scale + 5) + 5
        squashed = tfd.TransformedDistribution(
            tfd.Normal(loc, scale), tools.TanhBijector())
        return tools.SampleDist(tfd.Independent(squashed, 1))
    if kind == 'normal':
        loc, scale = tf.split(cast_if_requested(project(2 * self._size)), 2, -1)
        scale = tf.nn.softplus(scale + self._init_std) + self._min_std
        return tfd.Independent(tfd.Normal(loc, scale), 1)
    if kind == 'normal_1':
        loc = cast_if_requested(project(self._size))
        return tfd.Independent(tfd.Normal(loc, 1), 1)
    if kind == 'trunc_normal':
        # https://www.desmos.com/calculator/mmuvuhnyxo
        loc, scale = tf.split(tf.cast(project(2 * self._size), tf.float32), 2, -1)
        loc = tf.tanh(loc)
        scale = 2 * tf.nn.sigmoid(scale / 2) + self._min_std
        truncated = tools.DtypeDist(
            tools.SafeTruncatedNormal(loc, scale, -1, 1), dtype)
        return tfd.Independent(truncated, 1)
    if kind == 'onehot':
        logits = tf.cast(project(self._size), tf.float32)
        return tools.DtypeDist(tools.OneHotDist(logits, dtype=dtype), dtype)
    if kind == 'onehot_gumble':
        logits = cast_if_requested(project(self._size))
        return tools.GumbleDist(self._temp, logits, dtype=dtype)
    raise NotImplementedError(self._dist)
def random_policy(agent_state, num_actions=6, device='cuda'):
    """Sample a uniformly random one-hot action.

    Args:
        agent_state: unused; presumably accepted so the function matches the
            policy-callback signature its callers expect.
        num_actions: length of the one-hot action vector. Defaults to 6, the
            previously hard-coded value.
        device: torch device for the zero logits. Defaults to ``'cuda'`` as
            before; pass ``'cpu'`` on machines without a GPU.

    Returns:
        The first (only) row of the sampled one-hot action as a NumPy array,
        presumably of shape ``(num_actions,)``.
    """
    # All-zero float logits of shape (1, num_actions): every action is
    # equally likely.
    logits = torch.zeros(1, num_actions, device=device)
    action_dist = tools.OneHotDist(logits)
    action = action_dist.sample()
    return action.detach()[0].cpu().numpy()