def update_actor_and_alpha(self, obs, update_alpha=True):
    from ml_logger import logger

    # detach the conv encoder so the actor loss does not update it
    _, pi, log_pi, log_std = self.actor(obs, detach=True)
    actor_Q1, actor_Q2 = self.critic(obs, pi, detach=True)

    # entropy-regularized policy loss on the clipped double-Q estimate
    actor_Q = torch.min(actor_Q1, actor_Q2)
    actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()
    logger.store_metrics(actor_loss=actor_loss)

    # entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1)

    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()

    if update_alpha:
        # temperature loss drives the policy entropy toward the target entropy
        self.log_alpha_optim.zero_grad()
        alpha_loss = (self.alpha * (-log_pi - self.target_entropy).detach()).mean()
        logger.store_metrics({'alpha/loss': alpha_loss,
                              'alpha/value': self.alpha})

        alpha_loss.backward()
        self.log_alpha_optim.step()
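
# `self.alpha` above is assumed to be the exponential of a learned `log_alpha`,
# the standard SAC parameterization (optimizing the log keeps the temperature
# positive). A minimal sketch; the attribute names, initial temperature, and
# learning rate are illustrative assumptions, not necessarily this repo's code.
import numpy as np
import torch

class TemperatureMixin:
    def __init__(self, init_temperature=0.1, alpha_lr=1e-4, device='cuda'):
        self.log_alpha = torch.tensor(np.log(init_temperature), dtype=torch.float32,
                                      device=device, requires_grad=True)
        self.log_alpha_optim = torch.optim.Adam([self.log_alpha], lr=alpha_lr)

    @property
    def alpha(self):
        return self.log_alpha.exp()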
def update_inverse_dynamics(self, obs, obs_next, action):
    from ml_logger import logger

    assert obs.shape[-1] == 84 and obs_next.shape[-1] == 84

    # predict the action that produced the obs -> obs_next transition
    pred_action = self.pad_head(obs, obs_next)
    pad_loss = F.mse_loss(pred_action, action)

    self.pad_optimizer.zero_grad()
    pad_loss.backward()
    self.pad_optimizer.step()
    logger.store_metrics({'inm/loss': pad_loss})
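
# A sketch of what an inverse-dynamics head such as `self.pad_head` typically
# computes (PAD-style): embed both observations with a shared encoder,
# concatenate the features, and regress the action that caused the transition.
# The class name, `encoder.out_dim`, and the hidden size are illustrative
# assumptions, not this repo's exact implementation.
import torch
import torch.nn as nn

class InverseDynamicsHead(nn.Module):
    def __init__(self, encoder, action_dim, hidden_dim=256):
        super().__init__()
        self.encoder = encoder
        self.mlp = nn.Sequential(
            nn.Linear(2 * encoder.out_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, action_dim))

    def forward(self, obs, obs_next):
        h = torch.cat([self.encoder(obs), self.encoder(obs_next)], dim=-1)
        return self.mlp(h)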
def update_curl(self, x, x_pos):
    from ml_logger import logger

    assert x.size(-1) == 84 and x_pos.size(-1) == 84

    z_a = self.curl_head.encoder(x)
    with torch.no_grad():
        z_pos = self.critic_target.encoder(x_pos)

    # InfoNCE: the positive for anchor i sits on the diagonal, hence arange labels
    logits = self.curl_head.compute_logits(z_a, z_pos)
    labels = torch.arange(logits.shape[0]).long().to(self.device)
    curl_loss = F.cross_entropy(logits, labels)

    self.curl_optimizer.zero_grad()
    curl_loss.backward()
    self.curl_optimizer.step()
    logger.store_metrics(aux_loss=curl_loss)
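
# Sketch of the contrastive head assumed by `update_curl`. In CURL the logits
# are a bilinear similarity z_a^T W z_pos between anchor and positive
# embeddings, which makes the diagonal entries the positive pairs. This follows
# the CURL reference implementation; `encoder.out_dim` and the attribute names
# are assumptions about this repo.
import torch
import torch.nn as nn

class CURLHead(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.W = nn.Parameter(torch.rand(encoder.out_dim, encoder.out_dim))

    def compute_logits(self, z_a, z_pos):
        Wz = torch.matmul(self.W, z_pos.T)    # (z_dim, B)
        logits = torch.matmul(z_a, Wz)        # (B, B)
        # subtract the row-wise max for numerical stability of the softmax
        return logits - logits.max(dim=1, keepdim=True)[0]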
def test_store_metrics_prefix(setup):
    import numpy as np
    from ml_logger import logger

    logger.remove("metrics.pkl")
    for i in range(10):
        with logger.Prefix(metrics="test"):
            logger.store_metrics(value=1.0)
        with logger.Prefix(metrics="eval"):
            logger.store_metrics(value=3.0)

    logger.log_metrics_summary(key_values={'step': 10})
    assert logger.read_metrics("test/value/mean")[0] == 1.0
    assert logger.read_metrics("step")[0] == 10
def update_critic(self, obs, action, reward, next_obs, not_done):
    from ml_logger import logger

    # compute the Bellman target from the target critic, without gradients
    with torch.no_grad():
        _, policy_action, log_pi, _ = self.actor(next_obs)
        target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
        target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_pi
        target_Q = reward + (not_done * self.discount * target_V)

    current_Q1, current_Q2 = self.critic(obs, action)
    critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
    logger.store_metrics({'critic/loss': critic_loss})

    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()
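
# The critic update above is typically followed by a Polyak (soft) update of
# the target critic, via the same `utils.soft_update_params` helper that
# `update_soda` below calls. A minimal sketch of such a helper; where exactly
# this repo invokes it for the critic is an assumption.
def soft_update_params(net, target_net, tau):
    # target <- tau * online + (1 - tau) * target, parameter by parameter
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)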
def update_soda(self, replay_buffer):
    from ml_logger import logger

    x = replay_buffer.sample_soda(self.soda_batch_size)
    assert x.size(-1) == 100

    aug_x = x.clone()

    x = augmentations.random_crop(x)
    aug_x = augmentations.random_crop(aug_x)
    aug_x = augmentations.random_overlay(aug_x)

    soda_loss = self.compute_soda_loss(aug_x, x)

    self.soda_optimizer.zero_grad()
    soda_loss.backward()
    self.soda_optimizer.step()
    logger.store_metrics({'soda/loss': soda_loss})

    utils.soft_update_params(self.predictor, self.predictor_target, self.soda_tau)
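
# Sketch of the BYOL-style objective behind `compute_soda_loss`: embed the
# augmented view with the online predictor, embed the clean view with the EMA
# target (no gradients), L2-normalize both embeddings, and regress one onto the
# other. This follows the SODA paper's reference code; the attribute names are
# assumptions about this repo.
import torch
import torch.nn.functional as F

def compute_soda_loss(self, aug_x, x):
    h_aug = self.predictor(aug_x)
    with torch.no_grad():
        h = self.predictor_target(x)
    h_aug = F.normalize(h_aug, p=2, dim=1)
    h = F.normalize(h, p=2, dim=1)
    return F.mse_loss(h_aug, h)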
def evaluate(env, agent, num_episodes, save_video=None):
    from ml_logger import logger

    episode_rewards, frames = [], []
    for i in trange(num_episodes, desc="Eval"):
        obs = env.reset()
        done = False
        episode_reward = 0
        while not done:
            with utils.Eval(agent):
                action = agent.select_action(obs)
            obs, reward, done, _ = env.step(action)
            if save_video:
                frames.append(env.render('rgb_array', width=64, height=64))
            episode_reward += reward

        if save_video:
            logger.save_video(frames, key=save_video)
        logger.store_metrics(episode_reward=episode_reward)
        episode_rewards.append(episode_reward)

    return np.mean(episode_rewards)
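
# `utils.Eval` above is a context manager that puts the agent into eval mode
# for action selection and restores its previous mode on exit. A minimal sketch
# of such a helper, assuming the agent is a torch.nn.Module; the actual
# implementation in this repo may differ.
class Eval:
    def __init__(self, *modules):
        self.modules = modules

    def __enter__(self):
        self.prev = [m.training for m in self.modules]
        for m in self.modules:
            m.train(False)

    def __exit__(self, *exc):
        for m, was_training in zip(self.modules, self.prev):
            m.train(was_training)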
logger.flush()
# outputs ~>
# ╒════════════════════╤════════════════════════════╕
# │       reward       │             20             │
# ├────────────────────┼────────────────────────────┤
# │      timestep      │             0              │
# ├────────────────────┼────────────────────────────┤
# │  some val/smooth   │             10             │
# ├────────────────────┼────────────────────────────┤
# │       status       │          step (0)          │
# ├────────────────────┼────────────────────────────┤
# │     timestamp      │'2018-11-04T11:37:03.324824'│
# ╘════════════════════╧════════════════════════════╛

for i in range(100):
    logger.store_metrics(metrics={'some_val/smooth': 10},
                         some=20,
                         timestep=i)
logger.peek_stored_metrics(len=4)
# outputs ~>
#       some     |   timestep    |some_val/smooth
# ━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━
#        20      |       0       |       10
#        20      |       1       |       10
#        20      |       2       |       10
#        20      |       3       |       10

### The metrics are stored in-memory. Now we need to actually log the summaries:
logger.log_metrics_summary(silent=True)
# outputs ~> . (data is now logged to the server)

### Logging Matplotlib pyplots
def train(deps=None, **kwargs):
    from ml_logger import logger
    from dmc_gen.config import Args

    Args._update(deps, **kwargs)
    logger.log_params(Args=vars(Args))

    utils.set_seed_everywhere(Args.seed)
    wrappers.VideoWrapper.prefix = wrappers.ColorWrapper.prefix = DMCGEN_DATA

    # Initialize environments
    image_size = 84 if Args.algo == 'sac' else 100
    env = wrappers.make_env(
        domain_name=Args.domain,
        task_name=Args.task,
        seed=Args.seed,
        episode_length=Args.episode_length,
        action_repeat=Args.action_repeat,
        image_size=image_size,
    )
    test_env = wrappers.make_env(
        domain_name=Args.domain,
        task_name=Args.task,
        seed=Args.seed + 42,
        episode_length=Args.episode_length,
        action_repeat=Args.action_repeat,
        image_size=image_size,
        mode=Args.eval_mode,
    )

    # Prepare agent
    cropped_obs_shape = (3 * Args.frame_stack, 84, 84)
    agent = make_agent(algo=Args.algo,
                       obs_shape=cropped_obs_shape,
                       act_shape=env.action_space.shape,
                       args=Args).to(Args.device)

    if Args.load_checkpoint:
        print('Loading from checkpoint:', Args.load_checkpoint)
        logger.load_module(agent,
                           path="models/*.pkl",
                           wd=Args.load_checkpoint,
                           map_location=Args.device)

    replay_buffer = utils.ReplayBuffer(obs_shape=env.observation_space.shape,
                                       action_shape=env.action_space.shape,
                                       capacity=Args.train_steps,
                                       batch_size=Args.batch_size)

    episode, episode_reward, episode_step, done = 0, 0, 0, True
    logger.start('train')
    for step in range(Args.start_step, Args.train_steps + 1):
        if done:
            if step > Args.start_step:
                logger.store_metrics({'dt_epoch': logger.split('train')})
                logger.log_metrics_summary(dict(step=step), default_stats='mean')

            # Evaluate agent periodically
            if step % Args.eval_freq == 0:
                logger.store_metrics(episode=episode)
                with logger.Prefix(metrics="eval/"):
                    evaluate(env, agent, Args.eval_episodes,
                             save_video=f"videos/{step:08d}_train.mp4")
                with logger.Prefix(metrics="test/"):
                    evaluate(test_env, agent, Args.eval_episodes,
                             save_video=f"videos/{step:08d}_test.mp4")
                logger.log_metrics_summary(dict(step=step), default_stats='mean')

            # Save agent periodically
            if step > Args.start_step and step % Args.save_freq == 0:
                with logger.Sync():
                    logger.save_module(agent, f"models/{step:06d}.pkl")
                if Args.save_last:
                    logger.remove(f"models/{step - Args.save_freq:06d}.pkl")
                # torch.save(agent, os.path.join(model_dir, f'{step}.pt'))

            logger.store_metrics(episode_reward=episode_reward,
                                 episode=episode + 1,
                                 prefix="train/")

            obs = env.reset()
            episode_reward, episode_step, done = 0, 0, False
            episode += 1

        # Sample action for data collection
        if step < Args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.Eval(agent):
                action = agent.sample_action(obs)

        # Run training update
        if step >= Args.init_steps:
            num_updates = Args.init_steps if step == Args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, step)

        # Take step
        next_obs, reward, done, _ = env.step(action)
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
        replay_buffer.add(obs, action, reward, next_obs, done_bool)
        episode_reward += reward
        obs = next_obs
        episode_step += 1

    logger.print(f'Completed training for {Args.domain}_{Args.task}/{Args.algo}/{Args.seed}')
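
# A hedged usage sketch: `train` merges keyword overrides into `Args` through
# `Args._update(deps, **kwargs)`, so a local run could look like the following.
# The specific parameter values are illustrative only.
if __name__ == '__main__':
    train(algo='sac', domain='walker', task='walk', seed=100)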