Example #1
    def __init__(self,
                 env,
                 test_env,
                 log_dir,
                 num_steps=5 * (10**7),
                 lr=5e-5,
                 gamma=0.99,
                 multi_step=1,
                 update_interval=4,
                 target_update_interval=10000,
                 start_steps=50000,
                 epsilon_train=0.01,
                 epsilon_eval=0.001,
                 epsilon_decay_steps=250000,
                 double_q_learning=False,
                 log_interval=100,
                 eval_interval=250000,
                 num_eval_steps=125000,
                 max_episode_steps=27000,
                 seed=0,
                 cuda=True):

        self.env = env
        self.test_env = test_env
        self.num_actions = env.num_actions

        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu")

        self.start_steps = start_steps
        self.max_episode_steps = max_episode_steps
        self.epsilon_train = LinearAnneaer(1.0, epsilon_train,
                                           epsilon_decay_steps)
        self.epsilon_eval = epsilon_eval
        self.num_steps = num_steps
        self.num_eval_steps = num_eval_steps
        self.lr = lr
        self.gamma = gamma
        self.double_q_learning = double_q_learning

        self.target_update_interval = target_update_interval
        self.log_interval = log_interval
        self.update_interval = update_interval
        self.eval_interval = eval_interval

        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)
        self.writer = SummaryWriter(log_dir=self.summary_dir)
        self.train_return = RunningMeanStats(log_interval)

        self.steps = 0
        self.episodes = 0
        self.learning_steps = 0
        self.best_eval_score = -np.inf
Example #2
    def __init__(self,
                 env,
                 log_dir,
                 num_steps=3000000,
                 initial_latent_steps=100000,
                 batch_size=256,
                 latent_batch_size=32,
                 num_sequences=8,
                 lr=0.0003,
                 latent_lr=0.0001,
                 feature_dim=256,
                 latent1_dim=32,
                 latent2_dim=256,
                 hidden_units=[256, 256],
                 memory_size=1e5,
                 gamma=0.99,
                 target_update_interval=1,
                 tau=0.005,
                 entropy_tuning=True,
                 ent_coef=0.2,
                 leaky_slope=0.2,
                 grad_clip=None,
                 updates_per_step=1,
                 start_steps=10000,
                 training_log_interval=10,
                 learning_log_interval=100,
                 eval_interval=50000,
                 cuda=True,
                 seed=0,
                 colab='save'):
        self.env = env
        self.observation_shape = self.env.observation_space.shape
        self.action_shape = self.env.action_space.shape
        self.action_repeat = self.env.action_repeat

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)
        # torch.backends.cudnn.deterministic = True  # It can hurt performance.
        # torch.backends.cudnn.benchmark = False  # It can hurt performance.

        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu")

        self.latent = LatentNetwork(self.observation_shape, self.action_shape,
                                    feature_dim, latent1_dim, latent2_dim,
                                    hidden_units, leaky_slope).to(self.device)

        self.policy = GaussianPolicy(
            num_sequences * feature_dim +
            (num_sequences - 1) * self.action_shape[0], self.action_shape[0],
            hidden_units).to(self.device)

        # Policy is updated without the encoder.
        self.latent_optim = Adam(self.latent.parameters(), lr=latent_lr)
        self.memory = LazyMemory(memory_size, num_sequences,
                                 self.observation_shape, self.action_shape,
                                 self.device)

        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        self.images_dir = os.path.join(log_dir, 'images')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)
        if not os.path.exists(self.images_dir):
            os.makedirs(self.images_dir)

        self.writer = SummaryWriter(log_dir=self.summary_dir)
        self.train_rewards = RunningMeanStats(training_log_interval)

        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.initial_latent_steps = initial_latent_steps
        self.num_sequences = num_sequences
        self.num_steps = num_steps
        self.tau = tau
        self.batch_size = batch_size
        self.latent_batch_size = latent_batch_size
        self.start_steps = start_steps
        self.gamma = gamma
        self.entropy_tuning = entropy_tuning
        self.grad_clip = grad_clip
        self.updates_per_step = updates_per_step
        self.training_log_interval = training_log_interval
        self.learning_log_interval = learning_log_interval
        self.target_update_interval = target_update_interval
        self.eval_interval = eval_interval
        self.colab = colab
Example #3
class LatentTrainer:
    def __init__(self,
                 env,
                 log_dir,
                 num_steps=3000000,
                 initial_latent_steps=100000,
                 batch_size=256,
                 latent_batch_size=32,
                 num_sequences=8,
                 lr=0.0003,
                 latent_lr=0.0001,
                 feature_dim=256,
                 latent1_dim=32,
                 latent2_dim=256,
                 hidden_units=[256, 256],
                 memory_size=1e5,
                 gamma=0.99,
                 target_update_interval=1,
                 tau=0.005,
                 entropy_tuning=True,
                 ent_coef=0.2,
                 leaky_slope=0.2,
                 grad_clip=None,
                 updates_per_step=1,
                 start_steps=10000,
                 training_log_interval=10,
                 learning_log_interval=100,
                 eval_interval=50000,
                 cuda=True,
                 seed=0,
                 colab='save'):
        self.env = env
        self.observation_shape = self.env.observation_space.shape
        self.action_shape = self.env.action_space.shape
        self.action_repeat = self.env.action_repeat

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)
        # torch.backends.cudnn.deterministic = True  # It can hurt performance.
        # torch.backends.cudnn.benchmark = False  # It can hurt performance.

        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu")

        self.latent = LatentNetwork(self.observation_shape, self.action_shape,
                                    feature_dim, latent1_dim, latent2_dim,
                                    hidden_units, leaky_slope).to(self.device)

        self.policy = GaussianPolicy(
            num_sequences * feature_dim +
            (num_sequences - 1) * self.action_shape[0], self.action_shape[0],
            hidden_units).to(self.device)

        # Policy is updated without the encoder.
        self.latent_optim = Adam(self.latent.parameters(), lr=latent_lr)
        self.memory = LazyMemory(memory_size, num_sequences,
                                 self.observation_shape, self.action_shape,
                                 self.device)

        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        self.images_dir = os.path.join(log_dir, 'images')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)
        if not os.path.exists(self.images_dir):
            os.makedirs(self.images_dir)

        self.writer = SummaryWriter(log_dir=self.summary_dir)
        self.train_rewards = RunningMeanStats(training_log_interval)

        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.initial_latent_steps = initial_latent_steps
        self.num_sequences = num_sequences
        self.num_steps = num_steps
        self.tau = tau
        self.batch_size = batch_size
        self.latent_batch_size = latent_batch_size
        self.start_steps = start_steps
        self.gamma = gamma
        self.entropy_tuning = entropy_tuning
        self.grad_clip = grad_clip
        self.updates_per_step = updates_per_step
        self.training_log_interval = training_log_interval
        self.learning_log_interval = learning_log_interval
        self.target_update_interval = target_update_interval
        self.eval_interval = eval_interval
        self.colab = colab

    def run(self):
        while True:
            self.train_episode()
            if self.steps > self.num_steps:
                break

    def is_update(self):
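        # Start updating once the replay buffer holds more than one batch and
        # the initial exploration phase (start_steps, scaled by action_repeat)
        # has finished.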
        return len(self.memory) > self.batch_size and \
               self.steps >= self.start_steps * self.action_repeat

    def reset_deque(self, state):
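        # Pad the state/action deques with zeros so the policy always sees a
        # fixed-length history, even at the start of an episode.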
        state_deque = deque(maxlen=self.num_sequences)
        action_deque = deque(maxlen=self.num_sequences - 1)

        for _ in range(self.num_sequences - 1):
            state_deque.append(np.zeros(self.observation_shape,
                                        dtype=np.uint8))
            action_deque.append(np.zeros(self.action_shape, dtype=np.uint8))
        state_deque.append(state)

        return state_deque, action_deque

    def deque_to_batch(self, state_deque, action_deque):
        # Convert deques to batched tensor.
        state = np.array(state_deque, dtype=np.uint8)
        state = torch.ByteTensor(
            state).unsqueeze(0).to(self.device).float() / 255.0
        with torch.no_grad():
            feature = self.latent.encoder(state).view(1, -1)

        action = np.array(action_deque, dtype=np.float32)
        action = torch.FloatTensor(action).view(1, -1).to(self.device)
        feature_action = torch.cat([feature, action], dim=-1)
        return feature_action

    def explore(self, state_deque, action_deque):
        # Act with randomness
        feature_action = self.deque_to_batch(state_deque, action_deque)
        with torch.no_grad():
            action, _, _ = self.policy.sample(feature_action)
        return action.cpu().numpy().reshape(-1)

    def train_episode(self):
        self.episodes += 1
        episode_reward = 0.
        episode_steps = 0
        done = False
        state = self.env.reset()
        self.memory.set_initial_state(state)
        state_deque, action_deque = self.reset_deque(state)

        while not done:
            if self.steps >= self.start_steps * self.action_repeat:
                action = self.explore(state_deque, action_deque)
            else:
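                # Uniform random action in [-1, 1] during the initial
                # exploration phase.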
                action = 2 * np.random.rand(*self.action_shape) - 1

            next_state, reward, done, _ = self.env.step(action)
            self.steps += self.action_repeat
            episode_steps += self.action_repeat
            episode_reward += reward

            self.memory.append(action, reward, next_state, done)

            if self.is_update():
                # First, train the latent model only.
                if self.learning_steps < self.initial_latent_steps:
                    print('-' * 60)
                    print('Learning the latent model only...')
                    for _ in tqdm(range(self.initial_latent_steps)):
                        self.learning_steps += 1
                        self.learn_latent()
                    print('Finish learning the latent model.')
                    print('-' * 60)

            state_deque.append(next_state)
            action_deque.append(action)

        # Log the running mean of training rewards.
        self.train_rewards.append(episode_reward)
        if self.episodes % self.training_log_interval == 0:
            self.writer.add_scalar('reward/train', self.train_rewards.get(),
                                   self.steps)

        print(f'episode: {self.episodes:<4}  '
              f'episode steps: {episode_steps:<4}  ')

    def learn_latent(self):
        images_seq, actions_seq, rewards_seq, dones_seq = \
            self.memory.sample_latent(self.latent_batch_size)
        latent_loss = self.calc_latent_loss(images_seq, actions_seq,
                                            rewards_seq, dones_seq)
        update_params(self.latent_optim, self.latent, latent_loss,
                      self.grad_clip)

        if self.learning_steps % self.learning_log_interval == 0:
            self.writer.add_scalar('loss/latent',
                                   latent_loss.detach().item(),
                                   self.learning_steps)

    def calc_latent_loss(self, images_seq, actions_seq, rewards_seq,
                         dones_seq):
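        # The latent loss is a negative ELBO: the KL divergence between the
        # posterior and prior latent dynamics minus the decoder's
        # log-likelihood of the observed image sequence.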
        features_seq = self.latent.encoder(images_seq)

        # Sample from posterior dynamics.
        (latent1_post_samples, latent2_post_samples), \
        (latent1_post_dists, latent2_post_dists) = \
            self.latent.sample_posterior(features_seq, actions_seq)
        # Sample from prior dynamics.
        (latent1_pri_samples, latent2_pri_samples), \
        (latent1_pri_dists, latent2_pri_dists) = \
            self.latent.sample_prior(actions_seq)

        # KL divergence loss.
        kld_loss = calc_kl_divergence(latent1_post_dists, latent1_pri_dists)

        # Log likelihood loss of generated observations.
        images_seq_dists = self.latent.decoder(
            [latent1_post_samples, latent2_post_samples])
        log_likelihood_loss = images_seq_dists.log_prob(images_seq).mean(
            dim=0).sum()

        latent_loss = \
            kld_loss - log_likelihood_loss

        if self.learning_steps % self.learning_log_interval == 0:
            reconst_error = (images_seq - images_seq_dists.loc).pow(2).mean(
                dim=(0, 1)).sum().item()
            self.writer.add_scalar('stats/reconst_error', reconst_error,
                                   self.learning_steps)

        if self.learning_steps % self.learning_log_interval == 0:
            gt_images = images_seq[0].detach().cpu()
            post_images = images_seq_dists.loc[0].detach().cpu()

            with torch.no_grad():
                pri_images = self.latent.decoder(
                    [latent1_pri_samples[:1],
                     latent2_pri_samples[:1]]).loc[0].detach().cpu()
                cond_pri_samples, _ = self.latent.sample_prior(
                    actions_seq[:1], features_seq[:1, 0])
                cond_pri_images = self.latent.decoder(
                    cond_pri_samples).loc[0].detach().cpu()

            images = torch.cat(
                [gt_images, post_images, cond_pri_images, pri_images], dim=-2)

            for idx, img in enumerate(gt_images):
                Image.fromarray((img * 255).numpy().astype(np.uint8).transpose(
                    [1, 2, 0])).save(
                        os.path.join(self.images_dir,
                                     'gt_image%03i' % idx + '.png'))
                Image.fromarray((post_images[idx] * 255).numpy().astype(
                    np.uint8).transpose([1, 2, 0])).save(
                        os.path.join(self.images_dir,
                                     'post_images%03i' % idx + '.png'))
                Image.fromarray((cond_pri_images[idx] * 255).numpy().astype(
                    np.uint8).transpose([1, 2, 0])).save(
                        os.path.join(self.images_dir,
                                     'cond_pri_image%03i' % idx + '.png'))
                Image.fromarray((pri_images[idx] * 255).numpy().astype(
                    np.uint8).transpose([1, 2, 0])).save(
                        os.path.join(self.images_dir,
                                     'pri_images%03i' % idx + '.png'))

            # Visualize multiple of 8 images because each row contains 8
            # images at most.
            self.writer.add_images('images/gt_posterior_cond-prior_prior',
                                   images[:(len(images) // 8) * 8],
                                   self.learning_steps)

        return latent_loss

    def save_models(self):
        self.latent.encoder.save(os.path.join(self.model_dir, 'encoder.pth'))
        self.latent.save(os.path.join(self.model_dir, 'latent.pth'))
        self.policy.save(os.path.join(self.model_dir, 'policy.pth'))

    def save_images(self, images):
        pass

    def __del__(self):
        self.writer.close()
        self.env.close()
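
The helpers update_params and calc_kl_divergence used in the example above are imported from elsewhere in the project and are not shown on this page. The following is only a minimal sketch of what they plausibly look like, inferred from their call sites here; the exact signatures and the gradient-clipping strategy are assumptions.

import torch
from torch.distributions import kl_divergence


def update_params(optim, network, loss, grad_clip=None, retain_graph=False):
    # Zero the gradients, backpropagate the loss and take one optimizer step,
    # optionally clipping the gradient norm of the given network (assumed).
    optim.zero_grad()
    loss.backward(retain_graph=retain_graph)
    if grad_clip is not None and network is not None:
        torch.nn.utils.clip_grad_norm_(network.parameters(), grad_clip)
    optim.step()


def calc_kl_divergence(p_list, q_list):
    # Sum KL(p || q) over a sequence of distribution pairs, averaging over
    # the batch dimension, matching how the image log-likelihood is reduced.
    kld = 0.0
    for p, q in zip(p_list, q_list):
        kld = kld + kl_divergence(p, q).mean(dim=0).sum()
    return kld
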
Example #4
    def __init__(self,
                 env,
                 test_env,
                 log_dir,
                 num_steps=5 * (10**7),
                 batch_size=32,
                 memory_size=10**6,
                 gamma=0.99,
                 multi_step=1,
                 update_interval=4,
                 target_update_interval=10000,
                 start_steps=50000,
                 epsilon_train=0.01,
                 epsilon_eval=0.001,
                 epsilon_decay_steps=250000,
                 double_q_learning=False,
                 log_interval=100,
                 eval_interval=250000,
                 num_eval_steps=125000,
                 max_episode_steps=27000,
                 grad_cliping=5.0,
                 cuda=True,
                 seed=0):

        self.env = env
        self.test_env = test_env

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)
        self.test_env.seed(2**31 - 1 - seed)

        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu")

        self.online_net = None
        self.target_net = None

        # Replay memory that stores stacked frames memory-efficiently.
        self.memory = LazyMultiStepMemory(memory_size,
                                          self.env.observation_space.shape,
                                          self.device, gamma, multi_step)

        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)

        self.writer = SummaryWriter(log_dir=self.summary_dir)
        self.train_return = RunningMeanStats(log_interval)

        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.best_eval_score = -np.inf
        self.num_actions = env.num_actions
        self.num_steps = num_steps
        self.batch_size = batch_size

        self.double_q_learning = double_q_learning

        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.num_eval_steps = num_eval_steps
        self.gamma_n = gamma**multi_step
        self.start_steps = start_steps
        self.epsilon_train = LinearAnneaer(1.0, epsilon_train,
                                           epsilon_decay_steps)
        self.epsilon_eval = epsilon_eval
        self.update_interval = update_interval
        self.target_update_interval = target_update_interval
        self.max_episode_steps = max_episode_steps
        self.grad_cliping = grad_cliping
Example #5
class BaseAgent:
    def __init__(self,
                 env,
                 test_env,
                 log_dir,
                 num_steps=5 * (10**7),
                 batch_size=32,
                 memory_size=10**6,
                 gamma=0.99,
                 multi_step=1,
                 update_interval=4,
                 target_update_interval=10000,
                 start_steps=50000,
                 epsilon_train=0.01,
                 epsilon_eval=0.001,
                 epsilon_decay_steps=250000,
                 double_q_learning=False,
                 log_interval=100,
                 eval_interval=250000,
                 num_eval_steps=125000,
                 max_episode_steps=27000,
                 grad_cliping=5.0,
                 cuda=True,
                 seed=0):

        self.env = env
        self.test_env = test_env

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)
        self.test_env.seed(2**31 - 1 - seed)

        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu")

        self.online_net = None
        self.target_net = None

        # Replay memory that stores stacked frames memory-efficiently.
        self.memory = LazyMultiStepMemory(memory_size,
                                          self.env.observation_space.shape,
                                          self.device, gamma, multi_step)

        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)

        self.writer = SummaryWriter(log_dir=self.summary_dir)
        self.train_return = RunningMeanStats(log_interval)

        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.best_eval_score = -np.inf
        self.num_actions = env.num_actions
        self.num_steps = num_steps
        self.batch_size = batch_size

        self.double_q_learning = double_q_learning

        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.num_eval_steps = num_eval_steps
        self.gamma_n = gamma**multi_step
        self.start_steps = start_steps
        self.epsilon_train = LinearAnneaer(1.0, epsilon_train,
                                           epsilon_decay_steps)
        self.epsilon_eval = epsilon_eval
        self.update_interval = update_interval
        self.target_update_interval = target_update_interval
        self.max_episode_steps = max_episode_steps
        self.grad_cliping = grad_cliping

    def run(self):
        while True:
            self.train_episode()
            if self.steps > self.num_steps:
                break

    def is_update(self):
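        # Learn every update_interval environment steps, once start_steps
        # steps have been collected.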
        return self.steps % self.update_interval == 0\
            and self.steps >= self.start_steps

    def is_greedy(self, eval=False):
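        # Decide whether to act randomly: a small fixed epsilon during
        # evaluation; during training, act randomly for the first start_steps
        # steps and with a linearly annealed epsilon afterwards.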
        if eval:
            return np.random.rand() < self.epsilon_eval
        else:
            return self.steps < self.start_steps\
                or np.random.rand() < self.epsilon_train.get()

    def update_target(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def explore(self):
        # Act with randomness.
        action = self.env.action_space.sample()
        return action

    def exploit(self, state):
        # one-hot encoding
        state = torch.eye(self.env.nrow * self.env.ncol,
                          dtype=torch.float32)[state].to(
                              self.device).unsqueeze(0)
        # state = torch.LongTensor([state]).to(self.device)

        with torch.no_grad():
            action = self.online_net.calculate_q(states=state).argmax().item()

        return action

    def learn(self):
        raise NotImplementedError

    def save_models(self, save_dir):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(self.online_net.state_dict(),
                   os.path.join(save_dir, 'online_net.pth'))
        torch.save(self.target_net.state_dict(),
                   os.path.join(save_dir, 'target_net.pth'))

    def load_models(self, save_dir):
        self.online_net.load_state_dict(
            torch.load(os.path.join(save_dir, 'online_net.pth')))
        self.target_net.load_state_dict(
            torch.load(os.path.join(save_dir, 'target_net.pth')))

    def train_episode(self):
        self.online_net.train()
        self.target_net.train()

        self.episodes += 1
        episode_return = 0.
        episode_steps = 0

        done = False
        state = self.env.reset()

        while (not done) and episode_steps <= self.max_episode_steps:
            if self.is_greedy(eval=False):
                action = self.explore()
            else:
                action = self.exploit(state)

            next_state, reward, done, _ = self.env.step(action)

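            # Treat reaching the step limit without terminating as a failure:
            # assign a penalty and end the episode.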
            if (not done) and (episode_steps == self.max_episode_steps):
                reward = -100
                done = True

            self.memory.append(state, action, reward, next_state, done)

            self.steps += 1
            episode_steps += 1
            episode_return += reward
            state = next_state

            self.train_step_interval()

        # We log the running mean of training returns.
        self.train_return.append(episode_return)

        # We log training returns against training frames = 4 * steps.
        if self.episodes % self.log_interval == 0:
            self.writer.add_scalar('return/train', self.train_return.get(),
                                   4 * self.steps)

        if self.episodes % 1000 == 0:
            print(f'Episode: {self.episodes:<4}  '
                  f'episode steps: {episode_steps:<4}  '
                  f'return: {episode_return:<5.1f}')

    def train_step_interval(self):
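        # Called once per environment step: anneal epsilon, periodically sync
        # the target network, run a learning update, and evaluate at fixed
        # intervals.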
        self.epsilon_train.step()

        if self.steps % self.target_update_interval == 0:
            self.update_target()

        if self.is_update():
            self.learn()

        if self.steps % self.eval_interval == 0:
            self.online_net.eval()
            self.evaluate()
            self.save_models(os.path.join(self.model_dir, 'final'))
            self.online_net.train()

    def evaluate(self):
        num_episodes = 0
        num_steps = 0
        total_return = 0.0

        while True:
            state = self.test_env.reset()
            episode_steps = 0
            episode_return = 0.0
            done = False
            while (not done) and episode_steps <= self.max_episode_steps:
                if self.is_greedy(eval=True):
                    action = self.explore()
                else:
                    action = self.exploit(state)

                next_state, reward, done, _ = self.test_env.step(action)
                num_steps += 1
                episode_steps += 1
                episode_return += reward
                state = next_state

            num_episodes += 1
            total_return += episode_return

            if num_steps > self.num_eval_steps:
                break

        mean_return = total_return / num_episodes

        if mean_return > self.best_eval_score:
            self.best_eval_score = mean_return
            self.save_models(os.path.join(self.model_dir, 'best'))

        # We log evaluation results along with training frames = 4 * steps.
        self.writer.add_scalar('return/test', mean_return, 4 * self.steps)
        print('-' * 60)
        print(f'Num steps: {self.steps:<5}  ' f'return: {mean_return:<5.1f}')
        print('-' * 60)

    def plot(self, q_value, dist=None):
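        # Visualize Q-values as a heatmap over the grid, the greedy policy as
        # arrows, and (optionally) per-state action-value distributions.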
        os.makedirs(self.log_dir + '/' + str(self.episodes))

        state_size = 3
        q_nrow = self.env.nrow * state_size
        q_ncol = self.env.ncol * state_size

        # Normalize Q-values to the range [xmin, xmax] = [-1, 1].
        xmin = -1
        xmax = 1
        q_value = (q_value - q_value.min()) / (q_value.max() - q_value.min()) * \
                  (xmax - xmin) + xmin

        # Zero out Q-values for actions along the grid edges.
        q_value[:, 0] = 0
        q_value[-1, :, 1] = 0
        q_value[:, -1, 2] = 0
        # q_value[0, :, 3] = 0

        value = np.zeros((q_nrow, q_ncol))

        # 0.Left, 1.Down, 2.Right, 3.Up, 4.Center
        value[1::3, ::3] += q_value[:, :, 0]
        value[2::3, 1::3] += q_value[:, :, 1]
        value[1::3, 2::3] += q_value[:, :, 2]
        # value[::3, 1::3] += q_value[:, :, 3]
        value[1::3, 1::3] += q_value.mean(axis=2)

        # Heatmap Plot
        fig = plt.figure(figsize=(6, 12))
        ax = fig.add_subplot(1, 1, 1)
        mappable0 = plt.imshow(value,
                               cmap=cm.jet,
                               interpolation="bilinear",
                               vmax=abs(value).max(),
                               vmin=-abs(value).max())

        ax.set_xticks(np.arange(-0.5, q_ncol, 3))
        ax.set_yticks(np.arange(-0.5, q_nrow, 3))
        ax.set_xticklabels(range(self.env.ncol + 1))
        ax.set_yticklabels(range(self.env.nrow + 1))
        ax.grid(which="both")

        # Markers for start, goal, hole and sub-goal cells.
        # Start: green, Goal: red, Hole/Cliff: blue, 'g': orange
        for i in range(0, self.env.nrow):
            y = i * 3 + 1
            for j in range(self.env.ncol):
                x = j * 3 + 1
                if self.env.desc[i][j] == b'S':
                    ax.plot([x], [y],
                            marker="o",
                            color='g',
                            markersize=40,
                            alpha=0.8)
                    ax.text(x, y + 0.3, 'START', ha='center', size=12, c='w')
                elif self.env.desc[i][j] == b'G':
                    ax.plot([x], [y],
                            marker="o",
                            color='r',
                            markersize=40,
                            alpha=0.8)
                    ax.text(x, y + 0.3, 'GOAL', ha='center', size=12, c='w')
                elif self.env.desc[i][j] == b'H':
                    ax.plot([x], [y],
                            marker="x",
                            color='b',
                            markersize=30,
                            markeredgewidth=10,
                            alpha=0.8)
                elif self.env.desc[i][j] == b'g':
                    ax.plot([x], [y],
                            marker="o",
                            color='orange',
                            markersize=30,
                            markeredgewidth=10,
                            alpha=0.8)

        fig.colorbar(mappable0, ax=ax, orientation="vertical")

        plt.savefig(self.log_dir + '/' + str(self.episodes) + '/' +
                    'heatmap.png')
        plt.close()

        # Optimization Path
        fig = plt.figure(figsize=(6, 12))
        ax = fig.add_subplot(1, 1, 1)
        value = np.zeros((q_nrow, q_ncol))
        mappable0 = plt.imshow(value,
                               cmap=cm.jet,
                               interpolation="bilinear",
                               vmax=abs(value).max(),
                               vmin=-abs(value).max())

        opt_act = q_value.argmax(axis=2)
        self.plot_arrow(ax, opt_act)

        ax.set_xticks(np.arange(-0.5, q_ncol, 3))
        ax.set_yticks(np.arange(-0.5, q_nrow, 3))
        ax.set_xticklabels(range(self.env.ncol + 1))
        ax.set_yticklabels(range(self.env.nrow + 1))
        ax.grid(which="both")

        # Markers for start, goal, hole and sub-goal cells.
        # Start: green, Goal: red, Hole/Cliff: blue, 'g': orange
        for i in range(0, self.env.nrow):
            y = i * 3 + 1
            for j in range(self.env.ncol):
                x = j * 3 + 1
                if self.env.desc[i][j] == b'S':
                    ax.plot([x], [y],
                            marker="o",
                            color='g',
                            markersize=40,
                            alpha=0.8)
                    ax.text(x, y + 0.3, 'START', ha='center', size=12, c='w')
                elif self.env.desc[i][j] == b'G':
                    ax.plot([x], [y],
                            marker="o",
                            color='r',
                            markersize=40,
                            alpha=0.8)
                    ax.text(x, y + 0.3, 'GOAL', ha='center', size=12, c='w')
                elif self.env.desc[i][j] == b'H':
                    ax.plot([x], [y],
                            marker="x",
                            color='b',
                            markersize=30,
                            markeredgewidth=10,
                            alpha=0.8)
                elif self.env.desc[i][j] == b'g':
                    ax.plot([x], [y],
                            marker="o",
                            color='orange',
                            markersize=30,
                            markeredgewidth=10,
                            alpha=0.8)

        fig.colorbar(mappable0, ax=ax, orientation="vertical")

        plt.savefig(self.log_dir + '/' + str(self.episodes) + '/' +
                    'optimization.png')
        plt.close()

        # Distribution Plot
        if dist is not None:
            # sns.set(rc={"figure.figsize": (6, 12)});
            plt.gcf().set_size_inches(6, 12)

            for i in range(self.env.nrow * self.env.ncol):
                plt.subplot(self.env.nrow,
                            self.env.ncol,
                            i + 1,
                            facecolor='#EAEAF2',
                            fc='#EAEAF2')
                for x in ['left', 'bottom', 'top', 'right']:
                    plt.gca().spines[x].set_visible(False)
                    # plt.gca().spines['top'].set_visible(False)

                # for j, c in zip(range(4), ['red', 'blue', 'green', 'darkorange']):
                for j, c in enumerate(['red', 'blue', 'green']):
                    ax = sns.distplot(dist[i, :, j], color=c, hist=False)
                    ax.fill_between(ax.lines[j].get_xydata()[:, 0],
                                    ax.lines[j].get_xydata()[:, 1],
                                    color=c,
                                    alpha=0.3)

            # Remove all ticks from every subplot.
            plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])

            plt.savefig(self.log_dir + '/' + str(self.episodes) + '/' +
                        'distribution.png')
            plt.close()

    def plot_arrow(self, ax, opt_act):
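        # Draw an arrow in each 3x3 cell block pointing along the greedy
        # action (0: left, 1: down, 2: right).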
        for y in range(self.env.nrow):
            for x in range(1, self.env.ncol):
                if opt_act[y][x] == 0:  # Left
                    ax.annotate('',
                                xy=[0 + 3 * x, 1 + 3 * y],
                                xytext=[2 + 3 * x, 1 + 3 * y],
                                arrowprops=dict(shrink=10,
                                                width=20,
                                                headwidth=40,
                                                headlength=20,
                                                connectionstyle='arc3',
                                                facecolor='red',
                                                edgecolor='red'))
                elif opt_act[y][x] == 1:  # Down
                    ax.annotate('',
                                xy=[1 + 3 * x, 2 + 3 * y],
                                xytext=[1 + 3 * x, 0 + 3 * y],
                                arrowprops=dict(shrink=10,
                                                width=20,
                                                headwidth=40,
                                                headlength=20,
                                                connectionstyle='arc3',
                                                facecolor='red',
                                                edgecolor='red'))
                elif opt_act[y][x] == 2:  # Right
                    ax.annotate('',
                                xy=[2 + 3 * x, 1 + 3 * y],
                                xytext=[0 + 3 * x, 1 + 3 * y],
                                arrowprops=dict(shrink=10,
                                                width=20,
                                                headwidth=40,
                                                headlength=20,
                                                connectionstyle='arc3',
                                                facecolor='red',
                                                edgecolor='red'))

    def __del__(self):
        self.env.close()
        self.test_env.close()
        self.writer.close()
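
LinearAnneaer and RunningMeanStats come from the author's utility module and are not reproduced in these examples. A minimal sketch consistent with how they are used above (.get(), .step(), .append()) might look like this; the implementation details are assumptions.

from collections import deque

import numpy as np


class LinearAnneaer:
    # Linearly anneals a value from start_value to end_value over num_steps
    # calls to step(), then holds it constant.
    def __init__(self, start_value, end_value, num_steps):
        self.start_value = start_value
        self.end_value = end_value
        self.num_steps = num_steps
        self.steps = 0

    def step(self):
        self.steps = min(self.steps + 1, self.num_steps)

    def get(self):
        ratio = self.steps / self.num_steps
        return self.start_value + (self.end_value - self.start_value) * ratio


class RunningMeanStats:
    # Running mean over the last n appended values, used for smoothed logging.
    def __init__(self, n=10):
        self.stats = deque(maxlen=n)

    def append(self, x):
        self.stats.append(x)

    def get(self):
        return np.mean(self.stats)
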
Example #6
    def __init__(self, env, log_dir, num_steps=3000000, batch_size=256,
                 lr=0.0003, hidden_units=[256, 256], memory_size=1e6,
                 gamma=0.99, tau=0.005, entropy_tuning=True, ent_coef=0.2,
                 multi_step=1, per=False, alpha=0.6, beta=0.4,
                 beta_annealing=0.0001, grad_clip=None, updates_per_step=1,
                 start_steps=10000, log_interval=10, target_update_interval=1,
                 eval_interval=1000, cuda=True, seed=0):
        self.env = env

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)
        torch.backends.cudnn.deterministic = True  # This can hurt performance.
        torch.backends.cudnn.benchmark = False

        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu")

        self.policy = GaussianPolicy(
            self.env.observation_space.shape[0],
            self.env.action_space.shape[0],
            hidden_units=hidden_units).to(self.device)
        self.critic = TwinnedQNetwork(
            self.env.observation_space.shape[0],
            self.env.action_space.shape[0],
            hidden_units=hidden_units).to(self.device)
        self.critic_target = TwinnedQNetwork(
            self.env.observation_space.shape[0],
            self.env.action_space.shape[0],
            hidden_units=hidden_units).to(self.device).eval()

        # copy parameters of the learning network to the target network
        hard_update(self.critic_target, self.critic)
        # disable gradient calculations of the target network
        grad_false(self.critic_target)

        self.policy_optim = Adam(self.policy.parameters(), lr=lr)
        self.q1_optim = Adam(self.critic.Q1.parameters(), lr=lr)
        self.q2_optim = Adam(self.critic.Q2.parameters(), lr=lr)

        if entropy_tuning:
            # Target entropy is -|A|.
            self.target_entropy = -torch.prod(torch.Tensor(
                self.env.action_space.shape).to(self.device)).item()
            # We optimize log(alpha), instead of alpha.
            self.log_alpha = torch.zeros(
                1, requires_grad=True, device=self.device)
            self.alpha = self.log_alpha.exp()
            self.alpha_optim = Adam([self.log_alpha], lr=lr)
        else:
            # fixed alpha
            self.alpha = torch.tensor(ent_coef).to(self.device)

        if per:
            # Replay memory with prioritized experience replay.
            # See https://github.com/ku2482/rltorch/blob/master/rltorch/memory
            self.memory = PrioritizedMemory(
                memory_size, self.env.observation_space.shape,
                self.env.action_space.shape, self.device, gamma, multi_step,
                alpha=alpha, beta=beta, beta_annealing=beta_annealing)
        else:
            # Replay memory without prioritized experience replay.
            # See https://github.com/ku2482/rltorch/blob/master/rltorch/memory
            self.memory = MultiStepMemory(
                memory_size, self.env.observation_space.shape,
                self.env.action_space.shape, self.device, gamma, multi_step)

        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)

        self.writer = SummaryWriter(log_dir=self.summary_dir)
        self.train_rewards = RunningMeanStats(log_interval)

        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.num_steps = num_steps
        self.tau = tau
        self.per = per
        self.batch_size = batch_size
        self.start_steps = start_steps
        self.gamma_n = gamma ** multi_step
        self.entropy_tuning = entropy_tuning
        self.grad_clip = grad_clip
        self.updates_per_step = updates_per_step
        self.log_interval = log_interval
        self.target_update_interval = target_update_interval
        self.eval_interval = eval_interval
Example #7
class SacAgent:

    def __init__(self, env, log_dir, num_steps=3000000, batch_size=256,
                 lr=0.0003, hidden_units=[256, 256], memory_size=1e6,
                 gamma=0.99, tau=0.005, entropy_tuning=True, ent_coef=0.2,
                 multi_step=1, per=False, alpha=0.6, beta=0.4,
                 beta_annealing=0.0001, grad_clip=None, updates_per_step=1,
                 start_steps=10000, log_interval=10, target_update_interval=1,
                 eval_interval=1000, cuda=True, seed=0):
        self.env = env

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)
        torch.backends.cudnn.deterministic = True  # This can hurt performance.
        torch.backends.cudnn.benchmark = False

        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu")

        self.policy = GaussianPolicy(
            self.env.observation_space.shape[0],
            self.env.action_space.shape[0],
            hidden_units=hidden_units).to(self.device)
        self.critic = TwinnedQNetwork(
            self.env.observation_space.shape[0],
            self.env.action_space.shape[0],
            hidden_units=hidden_units).to(self.device)
        self.critic_target = TwinnedQNetwork(
            self.env.observation_space.shape[0],
            self.env.action_space.shape[0],
            hidden_units=hidden_units).to(self.device).eval()

        # copy parameters of the learning network to the target network
        hard_update(self.critic_target, self.critic)
        # disable gradient calculations of the target network
        grad_false(self.critic_target)

        self.policy_optim = Adam(self.policy.parameters(), lr=lr)
        self.q1_optim = Adam(self.critic.Q1.parameters(), lr=lr)
        self.q2_optim = Adam(self.critic.Q2.parameters(), lr=lr)

        if entropy_tuning:
            # Target entropy is -|A|.
            self.target_entropy = -torch.prod(torch.Tensor(
                self.env.action_space.shape).to(self.device)).item()
            # We optimize log(alpha), instead of alpha.
            self.log_alpha = torch.zeros(
                1, requires_grad=True, device=self.device)
            self.alpha = self.log_alpha.exp()
            self.alpha_optim = Adam([self.log_alpha], lr=lr)
        else:
            # fixed alpha
            self.alpha = torch.tensor(ent_coef).to(self.device)

        if per:
            # Replay memory with prioritized experience replay.
            # See https://github.com/ku2482/rltorch/blob/master/rltorch/memory
            self.memory = PrioritizedMemory(
                memory_size, self.env.observation_space.shape,
                self.env.action_space.shape, self.device, gamma, multi_step,
                alpha=alpha, beta=beta, beta_annealing=beta_annealing)
        else:
            # Replay memory without prioritized experience replay.
            # See https://github.com/ku2482/rltorch/blob/master/rltorch/memory
            self.memory = MultiStepMemory(
                memory_size, self.env.observation_space.shape,
                self.env.action_space.shape, self.device, gamma, multi_step)

        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)

        self.writer = SummaryWriter(log_dir=self.summary_dir)
        self.train_rewards = RunningMeanStats(log_interval)

        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.num_steps = num_steps
        self.tau = tau
        self.per = per
        self.batch_size = batch_size
        self.start_steps = start_steps
        self.gamma_n = gamma ** multi_step
        self.entropy_tuning = entropy_tuning
        self.grad_clip = grad_clip
        self.updates_per_step = updates_per_step
        self.log_interval = log_interval
        self.target_update_interval = target_update_interval
        self.eval_interval = eval_interval

    def run(self):
        while True:
            self.train_episode()
            if self.steps > self.num_steps:
                break

    def is_update(self):
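        # Update once the replay buffer holds more than one batch and the
        # initial random-exploration phase (start_steps) is over.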
        return len(self.memory) > self.batch_size and\
            self.steps >= self.start_steps

    def act(self, state):
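        # Use uniform random actions for the first start_steps steps, then
        # sample actions from the policy.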
        if self.start_steps > self.steps:
            action = self.env.action_space.sample()
        else:
            action = self.explore(state)
        return action

    def explore(self, state):
        # act with randomness
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            action, _, _ = self.policy.sample(state)
        return action.cpu().numpy().reshape(-1)

    def exploit(self, state):
        # act without randomness
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            _, _, action = self.policy.sample(state)
        return action.cpu().numpy().reshape(-1)

    def calc_current_q(self, states, actions, rewards, next_states, dones):
        curr_q1, curr_q2 = self.critic(states, actions)
        return curr_q1, curr_q2

    def calc_target_q(self, states, actions, rewards, next_states, dones):
        with torch.no_grad():
            next_actions, next_entropies, _ = self.policy.sample(next_states)
            next_q1, next_q2 = self.critic_target(next_states, next_actions)
            next_q = torch.min(next_q1, next_q2) + self.alpha * next_entropies

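        # Soft Bellman backup: bootstrap only on non-terminal transitions,
        # discounted by gamma ** multi_step.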
        target_q = rewards + (1.0 - dones) * self.gamma_n * next_q

        return target_q

    def train_episode(self):
        self.episodes += 1
        episode_reward = 0.
        episode_steps = 0
        done = False
        state = self.env.reset()

        while not done:
            action = self.act(state)
            next_state, reward, done, _ = self.env.step(action)
            self.steps += 1
            episode_steps += 1
            episode_reward += reward

            # Ignore done when the agent reaches the time horizon
            # (set done=True only when the agent actually fails).
            if episode_steps >= self.env._max_episode_steps:
                masked_done = False
            else:
                masked_done = done

            if self.per:
                batch = to_batch(
                    state, action, reward, next_state, masked_done,
                    self.device)
                with torch.no_grad():
                    curr_q1, curr_q2 = self.calc_current_q(*batch)
                target_q = self.calc_target_q(*batch)
                error = torch.abs(curr_q1 - target_q).item()
                # We need the true done signal, in addition to the masked one,
                # to compute multi-step rewards correctly.
                self.memory.append(
                    state, action, reward, next_state, masked_done, error,
                    episode_done=done)
            else:
                # We need the true done signal, in addition to the masked one,
                # to compute multi-step rewards correctly.
                self.memory.append(
                    state, action, reward, next_state, masked_done,
                    episode_done=done)

            if self.is_update():
                for _ in range(self.updates_per_step):
                    self.learn()

            if self.steps % self.eval_interval == 0:
                self.evaluate()
                self.save_models()

            state = next_state

        # We log running mean of training rewards.
        self.train_rewards.append(episode_reward)

        if self.episodes % self.log_interval == 0:
            self.writer.add_scalar(
                'reward/train', self.train_rewards.get(), self.steps)

        print(f'episode: {self.episodes:<4}  '
              f'episode steps: {episode_steps:<4}  '
              f'reward: {episode_reward:<5.1f}')

    def learn(self):
        self.learning_steps += 1
        if self.learning_steps % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        if self.per:
            # batch with indices and priority weights
            batch, indices, weights = \
                self.memory.sample(self.batch_size)
        else:
            batch = self.memory.sample(self.batch_size)
            # set priority weights to 1 when we don't use PER.
            weights = 1.

        q1_loss, q2_loss, errors, mean_q1, mean_q2 =\
            self.calc_critic_loss(batch, weights)
        policy_loss, entropies = self.calc_policy_loss(batch, weights)

        update_params(
            self.q1_optim, self.critic.Q1, q1_loss, self.grad_clip)
        update_params(
            self.q2_optim, self.critic.Q2, q2_loss, self.grad_clip)
        update_params(
            self.policy_optim, self.policy, policy_loss, self.grad_clip)

        if self.entropy_tuning:
            entropy_loss = self.calc_entropy_loss(entropies, weights)
            update_params(self.alpha_optim, None, entropy_loss)
            self.alpha = self.log_alpha.exp()
            self.writer.add_scalar(
                'loss/alpha', entropy_loss.detach().item(), self.steps)

        if self.per:
            # update priority weights
            self.memory.update_priority(indices, errors.cpu().numpy())

        if self.learning_steps % self.log_interval == 0:
            self.writer.add_scalar(
                'loss/Q1', q1_loss.detach().item(),
                self.learning_steps)
            self.writer.add_scalar(
                'loss/Q2', q2_loss.detach().item(),
                self.learning_steps)
            self.writer.add_scalar(
                'loss/policy', policy_loss.detach().item(),
                self.learning_steps)
            self.writer.add_scalar(
                'stats/alpha', self.alpha.detach().item(),
                self.learning_steps)
            self.writer.add_scalar(
                'stats/mean_Q1', mean_q1, self.learning_steps)
            self.writer.add_scalar(
                'stats/mean_Q2', mean_q2, self.learning_steps)
            self.writer.add_scalar(
                'stats/entropy', entropies.detach().mean().item(),
                self.learning_steps)

    def calc_critic_loss(self, batch, weights):
        curr_q1, curr_q2 = self.calc_current_q(*batch)
        target_q = self.calc_target_q(*batch)

        # TD errors for updating priority weights
        errors = torch.abs(curr_q1.detach() - target_q)
        # We log means of Q to monitor training.
        mean_q1 = curr_q1.detach().mean().item()
        mean_q2 = curr_q2.detach().mean().item()

        # Critic loss is mean squared TD errors with priority weights.
        q1_loss = torch.mean((curr_q1 - target_q).pow(2) * weights)
        q2_loss = torch.mean((curr_q2 - target_q).pow(2) * weights)
        return q1_loss, q2_loss, errors, mean_q1, mean_q2

    def calc_policy_loss(self, batch, weights):
        states, actions, rewards, next_states, dones = batch

        # We re-sample actions to calculate expectations of Q.
        sampled_action, entropy, _ = self.policy.sample(states)
        # expectations of Q with clipped double Q technique
        q1, q2 = self.critic(states, sampled_action)
        q = torch.min(q1, q2)

        # Policy objective is maximization of (Q + alpha * entropy) with
        # priority weights.
        policy_loss = torch.mean((- q - self.alpha * entropy) * weights)
        return policy_loss, entropy

    def calc_entropy_loss(self, entropy, weights):
        # Intuitively, we increase alpha when the policy entropy is below the
        # target entropy, and decrease it otherwise.
        entropy_loss = -torch.mean(
            self.log_alpha * (self.target_entropy - entropy).detach()
            * weights)
        return entropy_loss

    def evaluate(self):
        episodes = 10
        returns = np.zeros((episodes,), dtype=np.float32)

        for i in range(episodes):
            state = self.env.reset()
            episode_reward = 0.
            done = False
            while not done:
                action = self.exploit(state)
                next_state, reward, done, _ = self.env.step(action)
                episode_reward += reward
                state = next_state
            returns[i] = episode_reward

        mean_return = np.mean(returns)

        self.writer.add_scalar(
            'reward/test', mean_return, self.steps)
        print('-' * 60)
        print(f'Num steps: {self.steps:<5}  '
              f'reward: {mean_return:<5.1f}')
        print('-' * 60)

    def save_models(self):
        self.policy.save(os.path.join(self.model_dir, 'policy.pth'))
        self.critic.save(os.path.join(self.model_dir, 'critic.pth'))
        self.critic_target.save(
            os.path.join(self.model_dir, 'critic_target.pth'))

    def __del__(self):
        self.writer.close()
        self.env.close()
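
The utilities hard_update, soft_update, grad_false and to_batch used in this example are defined elsewhere in the project. The sketch below shows one plausible implementation consistent with their call sites; the exact signatures are assumptions.

import torch


def hard_update(target, source):
    # Copy every parameter of the source network into the target network.
    target.load_state_dict(source.state_dict())


def soft_update(target, source, tau):
    # Polyak averaging: target <- (1 - tau) * target + tau * source.
    with torch.no_grad():
        for t, s in zip(target.parameters(), source.parameters()):
            t.data.mul_(1.0 - tau)
            t.data.add_(tau * s.data)


def grad_false(network):
    # Freeze a network so no gradients are computed for its parameters.
    for param in network.parameters():
        param.requires_grad = False


def to_batch(state, action, reward, next_state, done, device):
    # Wrap a single transition as float tensors with a batch dimension of 1.
    state = torch.FloatTensor(state).unsqueeze(0).to(device)
    action = torch.FloatTensor(action).unsqueeze(0).to(device)
    reward = torch.FloatTensor([reward]).unsqueeze(0).to(device)
    next_state = torch.FloatTensor(next_state).unsqueeze(0).to(device)
    done = torch.FloatTensor([done]).unsqueeze(0).to(device)
    return state, action, reward, next_state, done
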
Example #8
    def __init__(self,
                 env,
                 log_dir,
                 num_steps=3000000,
                 initial_latent_steps=100000,
                 batch_size=256,
                 latent_batch_size=32,
                 num_sequences=8,
                 lr=0.0003,
                 latent_lr=0.0001,
                 feature_dim=256,
                 latent1_dim=32,
                 latent2_dim=256,
                 hidden_units=[256, 256],
                 memory_size=1e5,
                 gamma=0.99,
                 target_update_interval=1,
                 tau=0.005,
                 entropy_tuning=True,
                 ent_coef=0.2,
                 leaky_slope=0.2,
                 grad_clip=None,
                 updates_per_step=1,
                 start_steps=10000,
                 training_log_interval=10,
                 learning_log_interval=100,
                 eval_interval=50000,
                 cuda=True,
                 seed=0):

        self.env = env
        self.observation_shape = self.env.observation_space.shape
        self.action_shape = self.env.action_space.shape
        self.action_repeat = self.env.action_repeat

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)
        # torch.backends.cudnn.deterministic = True  # It can hurt performance.
        # torch.backends.cudnn.benchmark = False  # It can hurt performance.

        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu")

        self.latent = LatentNetwork(self.observation_shape, self.action_shape,
                                    feature_dim, latent1_dim, latent2_dim,
                                    hidden_units, leaky_slope).to(self.device)

        self.policy = GaussianPolicy(
            num_sequences * feature_dim +
            (num_sequences - 1) * self.action_shape[0], self.action_shape[0],
            hidden_units).to(self.device)

        self.critic = TwinnedQNetwork(latent1_dim + latent2_dim,
                                      self.action_shape[0],
                                      hidden_units).to(self.device)
        self.critic_target = TwinnedQNetwork(
            latent1_dim + latent2_dim, self.action_shape[0],
            hidden_units).to(self.device).eval()

        # Copy parameters of the learning network to the target network.
        soft_update(self.critic_target, self.critic, 1.0)
        # Disable gradient calculations of the target network.
        grad_false(self.critic_target)

        # Policy is updated without the encoder.
        self.policy_optim = Adam(self.policy.parameters(), lr=lr)
        self.q_optim = Adam(self.critic.parameters(), lr=lr)
        self.latent_optim = Adam(self.latent.parameters(), lr=latent_lr)

        if entropy_tuning:
            # Target entropy is -|A|.
            self.target_entropy = -self.action_shape[0]
            # We optimize log(alpha) because alpha is always larger than 0.
            self.log_alpha = torch.zeros(1,
                                         requires_grad=True,
                                         device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=lr)
            self.alpha = self.log_alpha.detach().exp()

        else:
            self.alpha = ent_coef

        self.memory = LazyMemory(memory_size, num_sequences,
                                 self.observation_shape, self.action_shape,
                                 self.device)

        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)

        self.writer = SummaryWriter(log_dir=self.summary_dir)
        self.train_rewards = RunningMeanStats(training_log_interval)

        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.initial_latent_steps = initial_latent_steps
        self.num_sequences = num_sequences
        self.num_steps = num_steps
        self.tau = tau
        self.batch_size = batch_size
        self.latent_batch_size = latent_batch_size
        self.start_steps = start_steps
        self.gamma = gamma
        self.entropy_tuning = entropy_tuning
        self.grad_clip = grad_clip
        self.updates_per_step = updates_per_step
        self.training_log_interval = training_log_interval
        self.learning_log_interval = learning_log_interval
        self.target_update_interval = target_update_interval
        self.eval_interval = eval_interval
Example No. 9
class SlacAgent:
    def __init__(self,
                 env,
                 log_dir,
                 num_steps=3000000,
                 initial_latent_steps=100000,
                 batch_size=256,
                 latent_batch_size=32,
                 num_sequences=8,
                 lr=0.0003,
                 latent_lr=0.0001,
                 feature_dim=256,
                 latent1_dim=32,
                 latent2_dim=256,
                 hidden_units=[256, 256],
                 memory_size=1e5,
                 gamma=0.99,
                 target_update_interval=1,
                 tau=0.005,
                 entropy_tuning=True,
                 ent_coef=0.2,
                 leaky_slope=0.2,
                 grad_clip=None,
                 updates_per_step=1,
                 start_steps=10000,
                 training_log_interval=10,
                 learning_log_interval=100,
                 eval_interval=50000,
                 cuda=True,
                 seed=0):

        self.env = env
        self.observation_shape = self.env.observation_space.shape
        self.action_shape = self.env.action_space.shape
        self.action_repeat = self.env.action_repeat

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)
        # torch.backends.cudnn.deterministic = True  # It harms performance.
        # torch.backends.cudnn.benchmark = False  # It harms performance.

        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu")

        self.latent = LatentNetwork(self.observation_shape, self.action_shape,
                                    feature_dim, latent1_dim, latent2_dim,
                                    hidden_units, leaky_slope).to(self.device)

        self.policy = GaussianPolicy(
            num_sequences * feature_dim +
            (num_sequences - 1) * self.action_shape[0], self.action_shape[0],
            hidden_units).to(self.device)

        self.critic = TwinnedQNetwork(latent1_dim + latent2_dim,
                                      self.action_shape[0],
                                      hidden_units).to(self.device)
        self.critic_target = TwinnedQNetwork(
            latent1_dim + latent2_dim, self.action_shape[0],
            hidden_units).to(self.device).eval()

        # Copy parameters of the learning network to the target network.
        soft_update(self.critic_target, self.critic, 1.0)
        # Disable gradient calculations of the target network.
        grad_false(self.critic_target)

        # Policy is updated without the encoder.
        self.policy_optim = Adam(self.policy.parameters(), lr=lr)
        self.q_optim = Adam(self.critic.parameters(), lr=lr)
        self.latent_optim = Adam(self.latent.parameters(), lr=latent_lr)

        if entropy_tuning:
            # Target entropy is -|A|.
            self.target_entropy = -self.action_shape[0]
            # We optimize log(alpha) because alpha should always be positive.
            self.log_alpha = torch.zeros(1,
                                         requires_grad=True,
                                         device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=lr)
            self.alpha = self.log_alpha.detach().exp()

        else:
            self.alpha = ent_coef

        self.memory = LazyMemory(memory_size, num_sequences,
                                 self.observation_shape, self.action_shape,
                                 self.device)

        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)

        self.writer = SummaryWriter(log_dir=self.summary_dir)
        self.train_rewards = RunningMeanStats(training_log_interval)

        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.initial_latent_steps = initial_latent_steps
        self.num_sequences = num_sequences
        self.num_steps = num_steps
        self.tau = tau
        self.batch_size = batch_size
        self.latent_batch_size = latent_batch_size
        self.start_steps = start_steps
        self.gamma = gamma
        self.entropy_tuning = entropy_tuning
        self.grad_clip = grad_clip
        self.updates_per_step = updates_per_step
        self.training_log_interval = training_log_interval
        self.learning_log_interval = learning_log_interval
        self.target_update_interval = target_update_interval
        self.eval_interval = eval_interval
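        # A note on the construction above (an editorial summary, not part of
        # the original source): the actor receives the raw feature/action
        # history, i.e. num_sequences * feature_dim encoded frames plus
        # (num_sequences - 1) * |A| past actions, while the twinned critics
        # and their target copies operate on the compact latent state of size
        # latent1_dim + latent2_dim. This split lets the critic be learned in
        # latent space while the policy only needs the encoder features.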

    def run(self):
        while True:
            self.train_episode()
            if self.steps > self.num_steps:
                break

    def is_update(self):
        return len(self.memory) > self.batch_size and\
            self.steps >= self.start_steps * self.action_repeat

    def reset_deque(self, state):
        state_deque = deque(maxlen=self.num_sequences)
        action_deque = deque(maxlen=self.num_sequences - 1)

        for _ in range(self.num_sequences - 1):
            state_deque.append(np.zeros(self.observation_shape,
                                        dtype=np.uint8))
            action_deque.append(np.zeros(self.action_shape, dtype=np.uint8))
        state_deque.append(state)

        return state_deque, action_deque

    def deque_to_batch(self, state_deque, action_deque):
        # Convert deques to batched tensor.
        state = np.array(state_deque, dtype=np.uint8)[None, ...]
        state = torch.ByteTensor(state).to(self.device).float() / 255.0
        with torch.no_grad():
            feature = self.latent.encoder(state).view(1, -1)

        action = np.array(action_deque, dtype=np.float32)
        action = torch.FloatTensor(action).view(1, -1).to(self.device)
        feature_action = torch.cat([feature, action], dim=-1)
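        # Shape check (an illustrative note, not from the original source):
        # `feature` is (1, num_sequences * feature_dim) and `action` is
        # (1, (num_sequences - 1) * |A|), so `feature_action` has
        # num_sequences * feature_dim + (num_sequences - 1) * |A| columns,
        # which matches the input size passed to GaussianPolicy in __init__
        # (with the defaults, 8 * 256 + 7 * |A|).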
        return feature_action

    def explore(self, state_deque, action_deque):
        # Act with randomness
        feature_action = self.deque_to_batch(state_deque, action_deque)
        with torch.no_grad():
            action, _, _ = self.policy.sample(feature_action)
        return action.cpu().numpy().reshape(-1)

    def exploit(self, state_deque, action_deque):
        # Act without randomness
        feature_action = self.deque_to_batch(state_deque, action_deque)
        with torch.no_grad():
            _, _, action = self.policy.sample(feature_action)
        return action.cpu().numpy().reshape(-1)

    def train_episode(self):
        self.episodes += 1
        episode_reward = 0.
        episode_steps = 0
        done = False
        state = self.env.reset()
        self.memory.set_initial_state(state)
        state_deque, action_deque = self.reset_deque(state)

        while not done:
            if self.steps >= self.start_steps * self.action_repeat:
                action = self.explore(state_deque, action_deque)
            else:
                action = 2 * np.random.rand(*self.action_shape) - 1

            next_state, reward, done, _ = self.env.step(action)
            self.steps += self.action_repeat
            episode_steps += self.action_repeat
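            # `steps` counts environment steps: the wrapper repeats each agent
            # action `action_repeat` times, which is why `start_steps` is also
            # scaled by `action_repeat` here and in `is_update`.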
            episode_reward += reward

            self.memory.append(action, reward, next_state, done)

            if self.is_update():
                # First, train the latent model only.
                if self.learning_steps < self.initial_latent_steps:
                    print('-' * 60)
                    print('Learning the latent model only...')
                    for _ in range(self.initial_latent_steps):
                        self.learning_steps += 1
                        self.learn_latent()
                    print('Finish learning the latent model.')
                    print('-' * 60)

                for _ in range(self.updates_per_step):
                    self.learn()

                if self.steps % self.eval_interval == 0:
                    self.evaluate()
                    self.save_models()

            state_deque.append(next_state)
            action_deque.append(action)

        # We log the running mean of training rewards.
        self.train_rewards.append(episode_reward)

        if self.episodes % self.training_log_interval == 0:
            self.writer.add_scalar('reward/train', self.train_rewards.get(),
                                   self.steps)

        print(f'episode: {self.episodes:<4}  '
              f'episode steps: {episode_steps:<4}  '
              f'reward: {episode_reward:<5.1f}')

    def learn(self):
        self.learning_steps += 1
        if self.learning_steps % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)
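            # `soft_update` is a repository helper that is not shown on this
            # page; it is assumed to perform the usual Polyak averaging
            #     theta_target <- tau * theta + (1 - tau) * theta_target,
            # so the tau=1.0 call in __init__ amounts to a hard copy. A full
            # sketch of the assumed helpers is given after the class.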

        # Update the latent model.
        self.learn_latent()
        # Update policy and critic.
        self.learn_sac()

    def learn_latent(self):
        images_seq, actions_seq, rewards_seq, dones_seq =\
            self.memory.sample_latent(self.latent_batch_size)
        latent_loss = self.calc_latent_loss(images_seq, actions_seq,
                                            rewards_seq, dones_seq)
        update_params(self.latent_optim, self.latent, latent_loss,
                      self.grad_clip)

        if self.learning_steps % self.learning_log_interval == 0:
            self.writer.add_scalar('loss/latent',
                                   latent_loss.detach().item(),
                                   self.learning_steps)

    def learn_sac(self):
        images_seq, actions_seq, rewards =\
            self.memory.sample_sac(self.batch_size)

        # NOTE: Don't update the encoder part of the policy here.
        with torch.no_grad():
            # f(1:t+1)
            features_seq = self.latent.encoder(images_seq)
            latent_samples, _ = self.latent.sample_posterior(
                features_seq, actions_seq)

        # z(t), z(t+1)
        latents_seq = torch.cat(latent_samples, dim=-1)
        latents = latents_seq[:, -2]
        next_latents = latents_seq[:, -1]
        # a(t)
        actions = actions_seq[:, -1]
        # fa(t)=(x(1:t), a(1:t-1)), fa(t+1)=(x(2:t+1), a(2:t))
        feature_actions, next_feature_actions =\
            create_feature_actions(features_seq, actions_seq)

        q1_loss, q2_loss = self.calc_critic_loss(latents, next_latents,
                                                 actions, next_feature_actions,
                                                 rewards)
        update_params(self.q_optim, self.critic, q1_loss + q2_loss,
                      self.grad_clip)

        policy_loss, entropies = self.calc_policy_loss(latents,
                                                       feature_actions)
        update_params(self.policy_optim, self.policy, policy_loss,
                      self.grad_clip)

        if self.entropy_tuning:
            entropy_loss = self.calc_entropy_loss(entropies)
            update_params(self.alpha_optim, None, entropy_loss)
            self.alpha = self.log_alpha.detach().exp().item()
        else:
            entropy_loss = 0.

        if self.learning_steps % self.learning_log_interval == 0:
            self.writer.add_scalar('loss/Q1',
                                   q1_loss.detach().item(),
                                   self.learning_steps)
            self.writer.add_scalar('loss/Q2',
                                   q2_loss.detach().item(),
                                   self.learning_steps)
            self.writer.add_scalar('loss/policy',
                                   policy_loss.detach().item(),
                                   self.learning_steps)
            self.writer.add_scalar('loss/alpha',
                                   entropy_loss.detach().item(),
                                   self.learning_steps)
            self.writer.add_scalar('stats/alpha', self.alpha,
                                   self.learning_steps)
            self.writer.add_scalar('stats/entropy',
                                   entropies.detach().mean().item(),
                                   self.learning_steps)

    def calc_latent_loss(self, images_seq, actions_seq, rewards_seq,
                         dones_seq):
        features_seq = self.latent.encoder(images_seq)

        # Sample from posterior dynamics.
        (latent1_post_samples, latent2_post_samples),\
            (latent1_post_dists, latent2_post_dists) =\
            self.latent.sample_posterior(features_seq, actions_seq)
        # Sample from prior dynamics.
        (latent1_pri_samples, latent2_pri_samples),\
            (latent1_pri_dists, latent2_pri_dists) =\
            self.latent.sample_prior(actions_seq)

        # KL divergence loss.
        kld_loss = calc_kl_divergence(latent1_post_dists, latent1_pri_dists)

        # Log likelihood loss of generated observations.
        images_seq_dists = self.latent.decoder(
            [latent1_post_samples, latent2_post_samples])
        log_likelihood_loss = images_seq_dists.log_prob(images_seq).mean(
            dim=0).sum()

        # Log likelihood loss of generated rewards.
        rewards_seq_dists = self.latent.reward_predictor([
            latent1_post_samples[:, :-1], latent2_post_samples[:, :-1],
            actions_seq,
            latent1_post_samples[:, 1:], latent2_post_samples[:, 1:]
        ])
        reward_log_likelihoods =\
            rewards_seq_dists.log_prob(rewards_seq) * (1.0 - dones_seq)
        reward_log_likelihood_loss = reward_log_likelihoods.mean(dim=0).sum()

        latent_loss =\
            kld_loss - log_likelihood_loss - reward_log_likelihood_loss
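        # In equation form (a hedged reading of the computation above): with
        # posterior samples z ~ q(z | x_{1:T+1}, a_{1:T}) and the learned
        # prior p, the objective is a negative ELBO with a reward decoder,
        #     L_latent = KL(q(z1) || p(z1))
        #                - E_q[ log p(x_{1:T+1} | z_{1:T+1}) ]
        #                - E_q[ (1 - done_t) * log p(r_t | z_t, a_t, z_{t+1}) ].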

        if self.learning_steps % self.learning_log_interval == 0:
            reconst_error = (images_seq - images_seq_dists.loc).pow(2).mean(
                dim=(0, 1)).sum().item()
            reward_reconst_error = (
                (rewards_seq - rewards_seq_dists.loc).pow(2) *
                (1.0 - dones_seq)).mean(dim=(0, 1)).sum().detach().item()
            self.writer.add_scalar('stats/reconst_error', reconst_error,
                                   self.learning_steps)
            self.writer.add_scalar('stats/reward_reconst_error',
                                   reward_reconst_error, self.learning_steps)

        if self.learning_steps % (100 * self.learning_log_interval) == 0:
            gt_images = images_seq[0].detach().cpu()
            post_images = images_seq_dists.loc[0].detach().cpu()

            with torch.no_grad():
                pri_images = self.latent.decoder(
                    [latent1_pri_samples[:1],
                     latent2_pri_samples[:1]]).loc[0].detach().cpu()
                cond_pri_samples, _ = self.latent.sample_prior(
                    actions_seq[:1], features_seq[:1, 0])
                cond_pri_images = self.latent.decoder(
                    cond_pri_samples).loc[0].detach().cpu()

            images = torch.cat(
                [gt_images, post_images, cond_pri_images, pri_images], dim=-2)

            # Visualize a multiple of 8 images because each row contains at
            # most 8 images.
            self.writer.add_images('images/gt_posterior_cond-prior_prior',
                                   images[:(len(images) // 8) * 8],
                                   self.learning_steps)

        return latent_loss

    def calc_critic_loss(self, latents, next_latents, actions,
                         next_feature_actions, rewards):
        # Q(z(t), a(t))
        curr_q1, curr_q2 = self.critic(latents, actions)
        # E[Q(z(t+1), a(t+1)) + alpha * H(pi)]
        with torch.no_grad():
            next_actions, next_entropies, _ =\
                self.policy.sample(next_feature_actions)
            next_q1, next_q2 = self.critic_target(next_latents, next_actions)
            next_q = torch.min(next_q1, next_q2) + self.alpha * next_entropies
        # r(t) + gamma * E[Q(z(t+1), a(t+1)) + alpha * H(pi)]
        target_q = rewards + self.gamma * next_q
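        # Restating the comments above as a single backup (not an addition to
        # the algorithm): y = r(t) + gamma * (min_i Q'_i(z(t+1), a(t+1))
        # + alpha * H(pi)), with a(t+1) sampled from the current policy and
        # the min taken over the twinned target critics (clipped double Q).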

        # Critic losses are mean squared TD errors.
        q1_loss = 0.5 * torch.mean((curr_q1 - target_q).pow(2))
        q2_loss = 0.5 * torch.mean((curr_q2 - target_q).pow(2))

        if self.learning_steps % self.learning_log_interval == 0:
            mean_q1 = curr_q1.detach().mean().item()
            mean_q2 = curr_q2.detach().mean().item()
            self.writer.add_scalar('stats/mean_Q1', mean_q1,
                                   self.learning_steps)
            self.writer.add_scalar('stats/mean_Q2', mean_q2,
                                   self.learning_steps)

        return q1_loss, q2_loss

    def calc_policy_loss(self, latents, feature_actions):
        # Re-sample actions to calculate expectations of Q.
        sampled_actions, entropies, _ = self.policy.sample(feature_actions)
        # E[Q(z(t), a(t))]
        q1, q2 = self.critic(latents, sampled_actions)
        q = torch.min(q1, q2)

        # Policy objective is maximization of (Q + alpha * entropy).
        policy_loss = torch.mean(-q - self.alpha * entropies)
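        # Gradients reach the policy through `sampled_actions`, assuming
        # `policy.sample` uses the reparameterization trick (as in standard
        # SAC); only `policy_optim` steps on this loss, so the critic weights
        # are left untouched by the policy update.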

        return policy_loss, entropies

    def calc_entropy_loss(self, entropies):
        # Intuitively, we increase alpha when the policy entropy is below the
        # target entropy, and decrease it otherwise.
        entropy_loss = -torch.mean(self.log_alpha *
                                   (self.target_entropy - entropies.detach()))
        return entropy_loss

    def evaluate(self):
        episodes = 10
        returns = np.zeros((episodes, ), dtype=np.float32)

        for i in range(episodes):
            state = self.env.reset()
            episode_reward = 0.
            done = False
            state_deque, action_deque = self.reset_deque(state)

            while not done:
                # Act deterministically for evaluation.
                action = self.exploit(state_deque, action_deque)
                next_state, reward, done, _ = self.env.step(action)
                episode_reward += reward
                state_deque.append(next_state)
                action_deque.append(action)

            returns[i] = episode_reward

        mean_return = np.mean(returns)
        std_return = np.std(returns)

        self.writer.add_scalar('reward/test', mean_return, self.steps)
        print('-' * 60)
        print(f'environment steps: {self.steps:<5}  '
              f'return: {mean_return:<5.1f} +/- {std_return:<5.1f}')
        print('-' * 60)

    def save_models(self):
        self.latent.encoder.save(os.path.join(self.model_dir, 'encoder.pth'))
        self.latent.save(os.path.join(self.model_dir, 'latent.pth'))
        self.policy.save(os.path.join(self.model_dir, 'policy.pth'))
        self.critic.save(os.path.join(self.model_dir, 'critic.pth'))
        self.critic_target.save(
            os.path.join(self.model_dir, 'critic_target.pth'))

    def __del__(self):
        self.writer.close()
        self.env.close()
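
The class above calls several small helpers (soft_update, grad_false, update_params) that are defined elsewhere in the repository and are not shown on this page. The sketch below is only a guess at their behaviour, reconstructed from how they are called here, namely soft_update(target, source, tau), grad_false(net), and update_params(optim, network, loss, grad_clip=None); the actual implementations may differ.

import torch
from torch import nn


def soft_update(target, source, tau):
    # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target.
    # Calling it with tau=1.0, as done right after construction, is a hard copy.
    with torch.no_grad():
        for t, s in zip(target.parameters(), source.parameters()):
            t.data.mul_(1.0 - tau)
            t.data.add_(tau * s.data)


def grad_false(network):
    # Freeze the target network so autograd never tracks its parameters.
    for p in network.parameters():
        p.requires_grad = False


def update_params(optim, network, loss, grad_clip=None):
    # One optimizer step with optional gradient clipping. `network` may be
    # None (as for the temperature parameter), in which case nothing is
    # clipped.
    optim.zero_grad()
    loss.backward()
    if grad_clip is not None and network is not None:
        nn.utils.clip_grad_norm_(network.parameters(), grad_clip)
    optim.step()

A typical driver would wrap a pixel-based control environment so that it exposes action_repeat, construct SlacAgent(env, log_dir), and call agent.run().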