Example #1
    def update_actor_and_alpha(self, obs, update_alpha=True):
        from ml_logger import logger

        _, pi, log_pi, log_std = self.actor(obs, detach=True)
        actor_Q1, actor_Q2 = self.critic(obs, pi, detach=True)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()

        logger.store_metrics(actor_loss=actor_loss)
        # entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1)

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if update_alpha:
            self.log_alpha_optim.zero_grad()
            alpha_loss = (self.alpha *
                          (-log_pi - self.target_entropy).detach()).mean()

            logger.store_metrics({
                'alpha/loss': alpha_loss,
                'alpha/value': self.alpha
            })

            alpha_loss.backward()
            self.log_alpha_optim.step()
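
The alpha branch above is SAC's automatic temperature tuning: the temperature is trained so that the policy entropy tracks a fixed target. As a point of reference, here is a minimal self-contained sketch of just that update; the exp(log_alpha) parameterization, the -|A| target-entropy heuristic, and all dimensions are illustrative assumptions, not taken from the repository.

import torch

action_dim = 6                                   # assumed action dimensionality
target_entropy = -float(action_dim)              # common SAC heuristic: -|A|
log_alpha = torch.zeros(1, requires_grad=True)   # alpha = exp(log_alpha) > 0
log_alpha_optim = torch.optim.Adam([log_alpha], lr=1e-4)

log_pi = torch.randn(256, 1)                     # stand-in for actor log-probs
alpha = log_alpha.exp()
# Only alpha receives gradients; the entropy-gap term is detached, mirroring
# (-log_pi - self.target_entropy).detach() in the snippet above.
alpha_loss = (alpha * (-log_pi - target_entropy).detach()).mean()

log_alpha_optim.zero_grad()
alpha_loss.backward()
log_alpha_optim.step()
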
Example #2
    def update_inverse_dynamics(self, obs, obs_next, action):
        from ml_logger import logger
        assert obs.shape[-1] == 84 and obs_next.shape[-1] == 84

        pred_action = self.pad_head(obs, obs_next)
        pad_loss = F.mse_loss(pred_action, action)

        self.pad_optimizer.zero_grad()
        pad_loss.backward()
        self.pad_optimizer.step()

        logger.store_metrics({'inm/loss': pad_loss})
Example #3
    def update_curl(self, x, x_pos):
        from ml_logger import logger
        assert x.size(-1) == 84 and x_pos.size(-1) == 84

        z_a = self.curl_head.encoder(x)
        with torch.no_grad():
            z_pos = self.critic_target.encoder(x_pos)

        logits = self.curl_head.compute_logits(z_a, z_pos)
        labels = torch.arange(logits.shape[0]).long().to(self.device)
        curl_loss = F.cross_entropy(logits, labels)

        self.curl_optimizer.zero_grad()
        curl_loss.backward()
        self.curl_optimizer.step()
        logger.store_metrics(aux_loss=curl_loss)
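
The snippet relies on curl_head.compute_logits, which is not shown here. In the reference CURL formulation the logits are a bilinear similarity between anchor and positive embeddings, with matching pairs on the diagonal, which is why the labels are simply torch.arange(batch_size). The sketch below illustrates that idea; CURLHead and its layout are hypothetical stand-ins, not the repository's actual class.

import torch
import torch.nn as nn

class CURLHead(nn.Module):
    """Bilinear InfoNCE head (sketch); the real head also wraps an encoder."""

    def __init__(self, z_dim):
        super().__init__()
        self.W = nn.Parameter(torch.rand(z_dim, z_dim))

    def compute_logits(self, z_a, z_pos):
        # (B, z) @ (z, z) @ (z, B) -> (B, B): entry [i, j] scores anchor i
        # against positive j, so row i's correct "class" is column i.
        Wz = torch.matmul(self.W, z_pos.T)                   # (z_dim, B)
        logits = torch.matmul(z_a, Wz)                       # (B, B)
        return logits - logits.max(dim=1, keepdim=True)[0]   # numerical stability
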
Example #4
def test_store_metrics_prefix(setup):
    import numpy as np
    from ml_logger import logger

    logger.remove("metrics.pkl")

    for i in range(10):
        with logger.Prefix(metrics="test"):
            logger.store_metrics(value=1.0)
        with logger.Prefix(metrics="eval"):
            logger.store_metrics(value=3.0)

    logger.log_metrics_summary(key_values={'step': 10})

    assert logger.read_metrics("test/value/mean")[0] == 1.0
    assert logger.read_metrics("step")[0] == 10
Example #5
    def update_critic(self, obs, action, reward, next_obs, not_done):
        from ml_logger import logger
        with torch.no_grad():
            _, policy_action, log_pi, _ = self.actor(next_obs)
            target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
            target_V = torch.min(target_Q1,
                                 target_Q2) - self.alpha.detach() * log_pi
            target_Q = reward + (not_done * self.discount * target_V)

        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)
        logger.store_metrics({'critic/loss': critic_loss})

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()
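
The no_grad block computes the soft Bellman target reward + discount * (1 - done) * (min(Q1', Q2') - alpha * log_pi). A toy computation of that target, with made-up numbers purely for illustration:

import torch

reward   = torch.tensor([[1.0]])
not_done = torch.tensor([[1.0]])
discount = 0.99
alpha    = torch.tensor(0.1)
log_pi   = torch.tensor([[-1.2]])
target_Q1, target_Q2 = torch.tensor([[5.0]]), torch.tensor([[4.5]])

target_V = torch.min(target_Q1, target_Q2) - alpha * log_pi   # 4.5 + 0.12 = 4.62
target_Q = reward + not_done * discount * target_V            # 1 + 0.99 * 4.62 = 5.5738
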
Example #6
    def update_soda(self, replay_buffer):
        from ml_logger import logger
        x = replay_buffer.sample_soda(self.soda_batch_size)
        assert x.size(-1) == 100

        aug_x = x.clone()

        x = augmentations.random_crop(x)
        aug_x = augmentations.random_crop(aug_x)
        aug_x = augmentations.random_overlay(aug_x)

        soda_loss = self.compute_soda_loss(aug_x, x)

        self.soda_optimizer.zero_grad()
        soda_loss.backward()
        self.soda_optimizer.step()
        logger.store_metrics({'soda/loss': soda_loss})

        utils.soft_update_params(self.predictor, self.predictor_target,
                                 self.soda_tau)
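
utils.soft_update_params keeps the target network an exponential moving average of the online network. Its implementation is not shown; a common form of that Polyak update, written here as an assumption about what the helper does, is:

import torch

def soft_update_params(net, target_net, tau):
    """target <- tau * online + (1 - tau) * target, parameter by parameter."""
    with torch.no_grad():
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
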
Example #7
def evaluate(env, agent, num_episodes, save_video=None):
    from ml_logger import logger

    episode_rewards, frames = [], []
    for i in trange(num_episodes, desc="Eval"):
        obs = env.reset()
        done = False
        episode_reward = 0
        while not done:
            with utils.Eval(agent):
                action = agent.select_action(obs)
            obs, reward, done, _ = env.step(action)
            if save_video:
                frames.append(env.render('rgb_array', width=64, height=64))
            episode_reward += reward

        if save_video:
            # `frames` is shared across episodes, so each call re-saves the
            # video with every frame collected so far under the same key.
            logger.save_video(frames, key=save_video)
        logger.store_metrics(episode_reward=episode_reward)
        episode_rewards.append(episode_reward)

    return np.mean(episode_rewards)
Example #8
    logger.flush()
# outputs ~>
# ╒════════════════════╤════════════════════════════╕
# │       reward       │             20             │
# ├────────────────────┼────────────────────────────┤
# │      timestep      │             0              │
# ├────────────────────┼────────────────────────────┤
# │  some val/smooth   │             10             │
# ├────────────────────┼────────────────────────────┤
# │       status       │          step (0)          │
# ├────────────────────┼────────────────────────────┤
# │      timestamp     │'2018-11-04T11:37:03.324824'│
# ╘════════════════════╧════════════════════════════╛

for i in range(100):
    logger.store_metrics(metrics={'some_val/smooth': 10}, some=20, timestep=i)

logger.peek_stored_metrics(len=4)
# outputs ~>
#      some      |   timestep    |some_val/smooth
# ━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━
#       20       |       0       |      10
#       20       |       1       |      10
#       20       |       2       |      10
#       20       |       3       |      10

### The metrics are stored in-memory. Now we need to actually log the summaries:
logger.log_metrics_summary(silent=True)
# outputs ~> . (data is now logged to the server)

### Logging Matplotlib pyplots
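
The excerpt cuts off at the matplotlib heading. For completeness, a minimal sketch of figure logging follows; it assumes the installed ml_logger version exposes logger.savefig, and the key and figure contents are illustrative.

import matplotlib.pyplot as plt
import numpy as np
from ml_logger import logger

xs = np.linspace(-5, 5, 100)
plt.plot(xs, np.cos(xs), label='cos(x)')
plt.legend()
logger.savefig("figures/cosine.png")  # written relative to the current logger prefix
plt.close()
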
Example #9
def train(deps=None, **kwargs):
    from ml_logger import logger
    from dmc_gen.config import Args

    Args._update(deps, **kwargs)
    logger.log_params(Args=vars(Args))

    utils.set_seed_everywhere(Args.seed)
    wrappers.VideoWrapper.prefix = wrappers.ColorWrapper.prefix = DMCGEN_DATA

    # Initialize environments
    image_size = 84 if Args.algo == 'sac' else 100
    env = wrappers.make_env(
        domain_name=Args.domain,
        task_name=Args.task,
        seed=Args.seed,
        episode_length=Args.episode_length,
        action_repeat=Args.action_repeat,
        image_size=image_size,
    )
    test_env = wrappers.make_env(domain_name=Args.domain,
                                 task_name=Args.task,
                                 seed=Args.seed + 42,
                                 episode_length=Args.episode_length,
                                 action_repeat=Args.action_repeat,
                                 image_size=image_size,
                                 mode=Args.eval_mode)

    # Prepare agent
    cropped_obs_shape = (3 * Args.frame_stack, 84, 84)
    agent = make_agent(algo=Args.algo,
                       obs_shape=cropped_obs_shape,
                       act_shape=env.action_space.shape,
                       args=Args).to(Args.device)

    if Args.load_checkpoint:
        print('Loading from checkpoint:', Args.load_checkpoint)
        logger.load_module(agent,
                           path="models/*.pkl",
                           wd=Args.load_checkpoint,
                           map_location=Args.device)

    replay_buffer = utils.ReplayBuffer(obs_shape=env.observation_space.shape,
                                       action_shape=env.action_space.shape,
                                       capacity=Args.train_steps,
                                       batch_size=Args.batch_size)

    episode, episode_reward, episode_step, done = 0, 0, 0, True
    logger.start('train')
    for step in range(Args.start_step, Args.train_steps + 1):
        if done:
            if step > Args.start_step:
                logger.store_metrics({'dt_epoch': logger.split('train')})
                logger.log_metrics_summary(dict(step=step),
                                           default_stats='mean')

            # Evaluate agent periodically
            if step % Args.eval_freq == 0:
                logger.store_metrics(episode=episode)
                with logger.Prefix(metrics="eval/"):
                    evaluate(env,
                             agent,
                             Args.eval_episodes,
                             save_video=f"videos/{step:08d}_train.mp4")
                with logger.Prefix(metrics="test/"):
                    evaluate(test_env,
                             agent,
                             Args.eval_episodes,
                             save_video=f"videos/{step:08d}_test.mp4")
                logger.log_metrics_summary(dict(step=step),
                                           default_stats='mean')

            # Save agent periodically
            if step > Args.start_step and step % Args.save_freq == 0:
                with logger.Sync():
                    logger.save_module(agent, f"models/{step:06d}.pkl")
                if Args.save_last:
                    logger.remove(f"models/{step - Args.save_freq:06d}.pkl")
                # torch.save(agent, os.path.join(model_dir, f'{step}.pt'))

            logger.store_metrics(episode_reward=episode_reward,
                                 episode=episode + 1,
                                 prefix="train/")

            obs = env.reset()
            episode_reward, episode_step, done = 0, 0, False
            episode += 1

        # Sample action for data collection
        if step < Args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.Eval(agent):
                action = agent.sample_action(obs)

        # Run training update
        if step >= Args.init_steps:
            num_updates = Args.init_steps if step == Args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, step)

        # Take step
        next_obs, reward, done, _ = env.step(action)
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
            done)
        replay_buffer.add(obs, action, reward, next_obs, done_bool)
        episode_reward += reward
        obs = next_obs

        episode_step += 1

    logger.print(
        f'Completed training for {Args.domain}_{Args.task}/{Args.algo}/{Args.seed}'
    )
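
Because the entry point forwards deps and keyword overrides into Args._update, a local run can override any Args field directly. A hedged launch sketch is shown below; the field names match the snippet, but the values are illustrative only.

if __name__ == '__main__':
    # Keyword arguments override fields on Args via Args._update(deps, **kwargs).
    train(seed=100, algo='soda', domain='walker', task='walk')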