Esempio n. 1
0
    def __init__(self,
                 logdir,
                 game,
                 policy,
                 optimizer=torch.optim.Adam,
                 n_simulations=100,
                 buffer_size=200,
                 batch_size=64,
                 batches_per_game=1,
                 gpu=True,
                 lr=1e-2,
                 weight_decay=1e-4):
        """Build the policy, optimizer, replay buffer and self-play manager.

        Args:
            logdir: Run directory; checkpoints are kept in ``<logdir>/ckpts``.
            game: Game object handed to the self-play manager.
            policy: Policy network; it is moved to the selected device.
            optimizer: Optimizer factory, called as
                ``optimizer(params, lr=lr, weight_decay=weight_decay)``.
            n_simulations: Stored as ``self.n_sims``.
            buffer_size: Size passed to ``GameReplay``.
            batch_size: Stored as ``self.batch_size``.
            batches_per_game: Stored as ``self.batches_per_game``.
            gpu: Use ``cuda:0`` when True and CUDA is available, else the CPU.
            lr: Learning rate forwarded to the optimizer. The default keeps
                the value that used to be hard-coded here.
            weight_decay: Weight decay forwarded to the optimizer. The default
                keeps the value that used to be hard-coded here.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.game = game
        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

        self.n_sims = n_simulations
        self.batch_size = batch_size
        self.batches_per_game = batches_per_game

        # The optimizer is created after the policy has been moved so that it
        # tracks the parameters living on the target device.
        self.pi = policy.to(self.device)
        self.opt = optimizer(self.pi.parameters(), lr=lr,
                             weight_decay=weight_decay)

        self.buffer = GameReplay(buffer_size)
        self.data_manager = SelfPlayManager(self.pi, self.game, self.buffer,
                                            self.device)

        self.mse = nn.MSELoss()

        # Step counter; starts at zero for a fresh run.
        self.t = 0
Esempio n. 2
0
    def __init__(self, logdir, model, opt, batch_size, num_workers, gpu=True):
        """Set up the MNIST data pipelines, the model and its optimizer.

        Args:
            logdir: Run directory; checkpoints are kept in ``<logdir>/ckpts``.
            model: Network to train; it is moved to the selected device.
            opt: Callable that builds the optimizer from the model parameters.
            batch_size: Batch size used by both data loaders.
            num_workers: Worker count passed to both data loaders.
            gpu: Use ``cuda:0`` when True and CUDA is available, else the CPU.
        """
        self.logdir = logdir
        ckpt_dir = os.path.join(logdir, 'ckpts')
        self.ckptr = Checkpointer(ckpt_dir)

        # Tensor conversion followed by normalization with fixed statistics.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.1307, ), (0.3081, ))
        self.transform = transforms.Compose([to_tensor, normalize])

        # Train and test splits are stored under separate root directories.
        self.data_train = datasets.MNIST(
            './data_train', download=True, transform=self.transform)
        self.data_test = datasets.MNIST(
            './data_test', download=True, train=False,
            transform=self.transform)

        # Only the training loader uses the stateful, shuffling sampler.
        loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
        self.sampler = StatefulSampler(self.data_train, shuffle=True)
        self.dtrain = DataLoader(self.data_train, sampler=self.sampler,
                                 **loader_kwargs)
        self.dtest = DataLoader(self.data_test, **loader_kwargs)

        # Training progress bookkeeping.
        self._diter = None
        self.t = 0
        self.epochs = 0
        self.batch_size = batch_size

        use_cuda = gpu and torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')

        # Move the model first so the optimizer is built from the parameters
        # on the target device.
        self.model = model
        self.model.to(self.device)
        self.opt = opt(self.model.parameters())
Esempio n. 3
0
    def __init__(self,
                 logdir,
                 model,
                 opt,
                 datafile,
                 batch_size,
                 num_workers,
                 gpu=True):
        """Set up the demonstration data loader, the model and its optimizer.

        Args:
            logdir: Run directory; checkpoints are kept in ``<logdir>/ckpts``.
            model: Network to train; it is moved to the selected device.
            opt: Callable that builds the optimizer from the model parameters.
            datafile: Passed to ``DemonstrationData`` to build the dataset.
            batch_size: Batch size of the training data loader.
            num_workers: Worker count of the training data loader.
            gpu: Use ``cuda:0`` when True and CUDA is available, else the CPU.
        """
        self.logdir = logdir
        ckpt_dir = os.path.join(logdir, 'ckpts')
        self.ckptr = Checkpointer(ckpt_dir)

        # Dataset, stateful shuffling sampler, and the loader built on both.
        dataset = DemonstrationData(datafile)
        self.data = dataset
        self.sampler = StatefulSampler(dataset, shuffle=True)
        self.dtrain = DataLoader(dataset,
                                 sampler=self.sampler,
                                 batch_size=batch_size,
                                 num_workers=num_workers)

        # Training progress bookkeeping.
        self._diter = None
        self.t = 0
        self.epochs = 0
        self.batch_size = batch_size

        use_cuda = gpu and torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')

        # Move the model first so the optimizer is built from the parameters
        # on the target device.
        self.model = model
        self.model.to(self.device)
        self.opt = opt(self.model.parameters())
Esempio n. 4
0
 def __init__(self, env, logdir, device):
     """Restore a ``BCNet`` from the checkpoints under ``<logdir>/ckpts``.

     Args:
         env: Not used by this constructor.
         logdir: Run directory that contains the ``ckpts`` directory.
         device: Requested torch device; replaced by ``'cpu'`` when CUDA
             is not available.
     """
     ckpt_dir = os.path.join(logdir, 'ckpts')
     self.ckptr = Checkpointer(ckpt_dir)

     # Fall back to the CPU on machines without CUDA.
     self.device = device if torch.cuda.is_available() else 'cpu'

     self.net = BCNet()
     self.net.to(self.device)

     # The network weights are stored under the 'model' key.
     checkpoint = self.ckptr.load()
     self.net.load_state_dict(checkpoint['model'])
Esempio n. 5
0
 def __init__(self, env, logdir, device):
     """Restore the PPO policy from the checkpoints under ``<logdir>/ckpts``.

     Args:
         env: Environment handed to ``drone_ppo_policy_fn`` to build the
             policy.
         logdir: Run directory that contains the ``ckpts`` directory.
         device: Requested torch device; replaced by ``'cpu'`` when CUDA
             is not available.
     """
     ckpt_dir = os.path.join(logdir, 'ckpts')
     self.ckptr = Checkpointer(ckpt_dir)

     # Fall back to the CPU on machines without CUDA.
     target_device = device if torch.cuda.is_available() else 'cpu'

     self.pi = drone_ppo_policy_fn(env)
     self.pi.to(target_device)

     # The policy weights are stored under the 'pi' key.
     checkpoint = self.ckptr.load()
     self.pi.load_state_dict(checkpoint['pi'])
     self.pi.eval()

     self.device = target_device
Esempio n. 6
0
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 nenv=1,
                 optimizer=torch.optim.Adam,
                 batch_size=32,
                 rollout_length=None,
                 gamma=0.99,
                 lambda_=0.95,
                 norm_advantages=False,
                 epochs_per_rollout=10,
                 max_grad_norm=None,
                 ent_coef=0.01,
                 vf_coef=0.5,
                 clip_param=0.2,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 gpu=True):
        """Create the environment, policy, optimizer and rollout manager.

        Args:
            logdir: Run directory; checkpoints are kept in ``<logdir>/ckpts``.
            env_fn: Environment factory, called as ``env_fn(nenv=nenv)``.
            policy_fn: Policy factory, called with the wrapped environment.
            nenv: Number of environments requested from ``env_fn``.
            optimizer: Optimizer factory, called with the policy parameters.
            batch_size, rollout_length, gamma, lambda_, norm_advantages:
                Forwarded to ``RolloutDataManager``; ``gamma`` is also given
                to the reward normalization wrapper.
            epochs_per_rollout, max_grad_norm, ent_coef, vf_coef, clip_param,
            eval_num_episodes, record_num_episodes: Stored on the instance.
            gpu: Use ``cuda:0`` when True and CUDA is available, else the CPU.
        """
        # Plain configuration that is only stored for later use.
        self.logdir = logdir
        self.env_fn = env_fn
        self.nenv = nenv
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.epochs_per_rollout = epochs_per_rollout
        self.max_grad_norm = max_grad_norm
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.clip_param = clip_param

        use_cuda = gpu and torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')

        # Reward normalization sits inside the episode logger.
        raw_env = env_fn(nenv=nenv)
        self.env = VecEpisodeLogger(VecRewardNormWrapper(raw_env, gamma))

        # Move the policy first so the optimizer is built from the parameters
        # on the target device.
        policy = policy_fn(self.env)
        self.pi = policy.to(self.device)
        self.opt = optimizer(self.pi.parameters())

        rollout_kwargs = dict(batch_size=batch_size,
                              rollout_length=rollout_length,
                              gamma=gamma,
                              lambda_=lambda_,
                              norm_advantages=norm_advantages)
        self.data_manager = RolloutDataManager(self.env,
                                               PPOActor(self.pi),
                                               self.device,
                                               **rollout_kwargs)

        # Unreduced loss: one value per element is returned.
        self.mse = nn.MSELoss(reduction='none')

        self.t = 0
Esempio n. 7
0
    def __init__(self, env, logdir, device, switch_prob=0.001):
        """Load one ``BCNet`` per run directory found under ``logdir``.

        Every entry of ``logdir`` that contains a ``ckpts`` directory
        contributes one checkpointer and one network. One of the loaded
        networks is picked at random as the initial ``current_actor``.

        Args:
            env: Not used by this constructor.
            logdir: Directory whose sub-directories hold the checkpoints.
            device: Requested torch device; replaced by ``'cpu'`` when CUDA
                is not available.
            switch_prob: Stored as ``self.switch_prob``.
        """
        # Keep only the entries that actually contain a 'ckpts' directory.
        run_names = []
        for entry in os.listdir(logdir):
            if os.path.isdir(os.path.join(logdir, entry, 'ckpts')):
                run_names.append(entry)

        checkpointers = []
        for name in run_names:
            checkpointers.append(
                Checkpointer(os.path.join(logdir, name, 'ckpts')))
        self.ckptrs = checkpointers

        if not torch.cuda.is_available():
            device = 'cpu'
        self.device = device

        # Create every network before any weights are loaded.
        networks = []
        for _ in run_names:
            networks.append(BCNet())
        self.nets = networks

        for ckptr, net in zip(self.ckptrs, self.nets):
            net.to(device)
            state = ckptr.load()
            net.load_state_dict(state['model'])

        self.current_actor = np.random.choice(self.nets)
        self.switch_prob = switch_prob
Esempio n. 8
0
class BCActor(object):
    """Actor trained with behavioral cloning."""

    def __init__(self, env, logdir, device):
        """Restore the network from the checkpoints under ``<logdir>/ckpts``.

        Args:
            env: Not used by this constructor.
            logdir: Run directory that contains the ``ckpts`` directory.
            device: Requested torch device; replaced by ``'cpu'`` when CUDA
                is not available.
        """
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        if not torch.cuda.is_available():
            device = 'cpu'
        self.device = device
        self.net = BCNet()
        self.net.to(device)
        # The network weights are stored under the 'model' key.
        # NOTE(review): the network is not switched to eval mode here, unlike
        # the policy in DroneReacherActor -- confirm whether BCNet needs it.
        self.net.load_state_dict(self.ckptr.load()['model'])

    def __call__(self, ob):
        """Sample an action for ``ob`` and clip it to ``[-1, 1]``.

        Args:
            ob: Observation, either a numpy array or a torch tensor. Any
                non-tensor input is converted with ``torch.from_numpy``.

        Returns:
            The sampled action as a numpy array, clipped to ``[-1, 1]``.
        """
        with torch.no_grad():
            # Accept tensors as well as arrays; torch.from_numpy raises
            # TypeError when it is handed a tensor.
            if not torch.is_tensor(ob):
                ob = torch.from_numpy(ob)
            dist = self.net(ob.to(self.device))
            ac = dist.sample().cpu().numpy()
            return np.clip(ac, -1., 1.)
Esempio n. 9
0
class DroneReacherActor(object):
    """Actor that runs a restored DroneReacher PPO policy."""

    def __init__(self, env, logdir, device):
        """Restore the policy from the checkpoints under ``<logdir>/ckpts``.

        Args:
            env: Environment handed to ``drone_ppo_policy_fn`` to build the
                policy.
            logdir: Run directory that contains the ``ckpts`` directory.
            device: Requested torch device; replaced by ``'cpu'`` when CUDA
                is not available.
        """
        ckpt_dir = os.path.join(logdir, 'ckpts')
        self.ckptr = Checkpointer(ckpt_dir)

        # Fall back to the CPU on machines without CUDA.
        target_device = device if torch.cuda.is_available() else 'cpu'

        self.pi = drone_ppo_policy_fn(env)
        self.pi.to(target_device)

        # The policy weights are stored under the 'pi' key.
        checkpoint = self.ckptr.load()
        self.pi.load_state_dict(checkpoint['pi'])
        self.pi.eval()

        self.device = target_device

    def __call__(self, ob):
        """Return the policy's action for ``ob`` as a numpy array.

        Numpy observations are converted to tensors and moved to the actor's
        device; any other input is passed to the policy unchanged.
        """
        with torch.no_grad():
            obs = ob
            if isinstance(obs, np.ndarray):
                obs = torch.from_numpy(obs).to(self.device)
            outputs = self.pi(obs)
            return outputs.action.cpu().numpy()
class ResidualPPO2(Algorithm):
    """PPO algorithm with upgrades.

    This version is described in https://arxiv.org/abs/1707.02286 and
    https://github.com/joschu/modular_rl/blob/master/modular_rl/ppo.py

    The policy and the value function are separate networks with separate
    optimizers. The policy is only updated once ``self.t`` has reached
    ``policy_training_start``; the value function is updated on every step.
    """

    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 value_fn,
                 nenv=1,
                 opt_pi=torch.optim.Adam,
                 opt_vf=torch.optim.Adam,
                 batch_size=32,
                 rollout_length=None,
                 gamma=0.99,
                 lambda_=0.95,
                 ent_coef=0.01,
                 norm_advantages=False,
                 epochs_pi=10,
                 epochs_vf=10,
                 max_grad_norm=None,
                 kl_target=0.01,
                 alpha=1.5,
                 policy_training_start=10000,
                 eval_num_episodes=10,
                 record_num_episodes=0,
                 gpu=True):
        """Create the environment, networks, optimizers and rollout manager.

        Args:
            logdir: Run directory; checkpoints are kept in ``<logdir>/ckpts``.
            env_fn: Environment factory, called as ``env_fn(nenv=nenv)``.
            policy_fn: Policy factory, called with the wrapped environment.
            value_fn: Value-function factory, called with the wrapped
                environment.
            nenv: Number of environments requested from ``env_fn``.
            opt_pi: Optimizer factory for the policy parameters.
            opt_vf: Optimizer factory for the value-function parameters.
            batch_size, rollout_length, gamma, lambda_, norm_advantages:
                Forwarded to ``RolloutDataManager``; ``gamma`` is also given
                to the reward normalization wrapper.
            ent_coef: Weight of the entropy term in the policy loss.
            epochs_pi: Maximum number of policy epochs per step.
            epochs_vf: Number of value-function epochs per step.
            max_grad_norm: If truthy, gradients are clipped to this norm.
            kl_target: Target KL between the old and the new policy.
            alpha: Factor by which the KL weight is scaled up or down.
            policy_training_start: Value of ``self.t`` from which the policy
                is trained.
            eval_num_episodes: Episodes used by ``evaluate``.
            record_num_episodes: Episodes recorded by ``evaluate``.
            gpu: Use ``cuda:0`` when True and CUDA is available, else the CPU.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.ent_coef = ent_coef
        self.epochs_pi = epochs_pi
        self.epochs_vf = epochs_vf
        self.max_grad_norm = max_grad_norm
        self.kl_target = kl_target
        # Starting value of the adaptive KL weight; also the value it is
        # reset to in step() when the measured KL is far above the target.
        self.initial_kl_weight = 0.2
        self.kl_weight = self.initial_kl_weight
        self.alpha = alpha
        self.policy_training_start = policy_training_start
        self.device = torch.device('cuda:0' if gpu and torch.cuda.is_available()
                                   else 'cpu')

        self.env = VecEpisodeLogger(VecRewardNormWrapper(env_fn(nenv=nenv),
                                                         gamma))

        self.pi = policy_fn(self.env).to(self.device)
        self.vf = value_fn(self.env).to(self.device)
        self.opt_pi = opt_pi(self.pi.parameters())
        self.opt_vf = opt_vf(self.vf.parameters())
        self._actor = ResidualPPOActor(self.pi, self.vf, policy_training_start)
        self.data_manager = RolloutDataManager(
            self.env,
            self._actor,
            self.device,
            batch_size=batch_size,
            rollout_length=rollout_length,
            gamma=gamma,
            lambda_=lambda_,
            norm_advantages=norm_advantages)

        self.mse = nn.MSELoss()

        self.t = 0

    def compute_kl(self):
        """Compute KL divergence of new and old policies.

        Returns:
            The mean KL over all batches of the current rollout, weighted by
            the number of samples in each batch (0 if there are no batches).
        """
        kl = 0
        n = 0
        # This is a pure measurement: the result is turned into a number
        # right away, so no autograd graph needs to be built.
        with torch.no_grad():
            for batch in self.data_manager.sampler():
                outs = self.pi(batch['obs'])
                old_dist = outs.dist.from_tensors(batch['dist'])
                k = old_dist.kl(outs.dist).mean().detach().cpu().numpy()
                s = nest.flatten(batch['action'])[0].shape[0]
                # Running weighted mean over the samples seen so far.
                kl = (n / (n + s)) * kl + (s / (n + s)) * k
                n += s
        return kl

    def loss_pi(self, batch):
        """Compute the policy losses for one batch.

        Returns:
            A dict with the keys 'pi' (surrogate loss), 'ent' (negative mean
            entropy), 'kl' (mean KL to the old policy), 'kl_pen' (squared
            hinge on the KL above twice the target) and 'total' (the weighted
            sum that is optimized).
        """
        outs = self.pi(batch['obs'])

        # compute policy loss
        logp = outs.dist.log_prob(batch['action'])
        assert logp.shape == batch['logp'].shape
        ratio = torch.exp(logp - batch['logp'])
        assert ratio.shape == batch['atarg'].shape

        old_dist = outs.dist.from_tensors(batch['dist'])
        kl = old_dist.kl(outs.dist)
        # The penalty is zero until the KL exceeds twice the target.
        kl_pen = (kl - 2 * self.kl_target).clamp(min=0).pow(2)
        losses = {}
        losses['pi'] = -(ratio * batch['atarg']).mean()
        losses['ent'] = -outs.dist.entropy().mean()
        losses['kl'] = kl.mean()
        losses['kl_pen'] = kl_pen.mean()
        losses['total'] = (losses['pi'] + self.ent_coef * losses['ent']
                           + self.kl_weight * losses['kl'] + 1000 * losses['kl_pen'])
        return losses

    def loss_vf(self, batch):
        """Return the MSE between the predicted values and ``batch['vtarg']``."""
        return self.mse(self.vf(batch['obs']).value, batch['vtarg'])

    def _clip_and_log_grad_norm(self, parameters, name):
        """Clip the gradients of ``parameters`` and log the norm.

        The norm returned by the clipping call is logged as ``alg/<name>``
        and its value capped at ``max_grad_norm`` as ``alg/<name>_clipped``.
        """
        norm = nn.utils.clip_grad_norm_(parameters, self.max_grad_norm)
        logger.add_scalar(f'alg/{name}', norm, self.t, time.time())
        logger.add_scalar(f'alg/{name}_clipped',
                          min(norm, self.max_grad_norm),
                          self.t, time.time())

    def step(self):
        """Compute rollout, loss, and update model.

        Returns:
            The updated step counter ``self.t``.
        """
        self.pi.train()
        self.t += self.data_manager.rollout()
        losses = {'pi': [], 'vf': [], 'ent': [], 'kl': [], 'total': [],
                  'kl_pen': []}

        #######################
        # Update pi
        #######################

        if self.t >= self.policy_training_start:
            kl_too_big = False
            for _ in range(self.epochs_pi):
                if kl_too_big:
                    break
                for batch in self.data_manager.sampler():
                    self.opt_pi.zero_grad()
                    loss = self.loss_pi(batch)
                    # break if new policy is too different from old policy
                    if loss['kl'] > 4 * self.kl_target:
                        kl_too_big = True
                        break
                    loss['total'].backward()

                    for k, v in loss.items():
                        losses[k].append(v.detach().cpu().numpy())

                    if self.max_grad_norm:
                        self._clip_and_log_grad_norm(self.pi.parameters(),
                                                     'grad_norm')
                    self.opt_pi.step()

        #######################
        # Update value function
        #######################
        for _ in range(self.epochs_vf):
            for batch in self.data_manager.sampler():
                self.opt_vf.zero_grad()
                loss = self.loss_vf(batch)
                losses['vf'].append(loss.detach().cpu().numpy())
                loss.backward()
                if self.max_grad_norm:
                    self._clip_and_log_grad_norm(self.vf.parameters(),
                                                 'vf_grad_norm')
                self.opt_vf.step()

        # Losses that were never recorded (e.g. the policy losses before
        # policy training starts) are not logged.
        for k, v in losses.items():
            if len(v) > 0:
                logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time())

        # update weight on kl to match kl_target.
        if self.t >= self.policy_training_start:
            kl = self.compute_kl()
            if kl > 10.0 * self.kl_target and self.kl_weight < self.initial_kl_weight:
                self.kl_weight = self.initial_kl_weight
            elif kl > 1.3 * self.kl_target:
                self.kl_weight *= self.alpha
            elif kl < 0.7 * self.kl_target:
                self.kl_weight /= self.alpha

            logger.add_scalar('alg/kl', kl, self.t, time.time())
            logger.add_scalar('alg/kl_weight', self.kl_weight, self.t, time.time())

        # Log how far the stored value predictions are from the 'q_mc'
        # entries of the rollout.
        data = self.data_manager.storage.get_rollout()
        value_error = data['vpred'].data - data['q_mc'].data
        logger.add_scalar('alg/value_error_mean',
                          value_error.mean().cpu().numpy(), self.t, time.time())
        logger.add_scalar('alg/value_error_std',
                          value_error.std().cpu().numpy(), self.t, time.time())
        return self.t

    def evaluate(self):
        """Evaluate model.

        Writes the evaluation stats to ``<logdir>/eval`` and the recording to
        ``<logdir>/video``, then switches the policy and the environment back
        to train mode.
        """
        self.pi.eval()
        misc.set_env_to_eval_mode(self.env)

        # Eval policy
        os.makedirs(os.path.join(self.logdir, 'eval'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'eval',
                               self.ckptr.format.format(self.t) + '.json')
        stats = rl_evaluate(self.env, self.pi, self.eval_num_episodes,
                            outfile, self.device)
        logger.add_scalar('eval/mean_episode_reward', stats['mean_reward'],
                          self.t, time.time())
        logger.add_scalar('eval/mean_episode_length', stats['mean_length'],
                          self.t, time.time())

        # Record policy
        os.makedirs(os.path.join(self.logdir, 'video'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'video',
                               self.ckptr.format.format(self.t) + '.mp4')
        rl_record(self.env, self.pi, self.record_num_episodes, outfile,
                  self.device)

        self.pi.train()
        misc.set_env_to_train_mode(self.env)

    def save(self):
        """Save the networks, optimizers, KL weight, env state and step."""
        state_dict = {
            'pi': self.pi.state_dict(),
            'vf': self.vf.state_dict(),
            'opt_pi': self.opt_pi.state_dict(),
            'opt_vf': self.opt_vf.state_dict(),
            'kl_weight': self.kl_weight,
            'env': misc.env_state_dict(self.env),
            '_actor': self._actor.state_dict(),
            't': self.t
        }
        self.ckptr.save(state_dict, self.t)

    def load(self, t=None):
        """Load state dict.

        Args:
            t: Passed to the checkpointer to select the checkpoint.

        Returns:
            The restored step counter, or 0 when the checkpointer returns
            no state.
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            self.t = 0
            return self.t
        self.pi.load_state_dict(state_dict['pi'])
        self.vf.load_state_dict(state_dict['vf'])
        self.opt_pi.load_state_dict(state_dict['opt_pi'])
        self.opt_vf.load_state_dict(state_dict['opt_vf'])
        self.kl_weight = state_dict['kl_weight']
        misc.env_load_state_dict(self.env, state_dict['env'])
        self._actor.load_state_dict(state_dict['_actor'])
        self.t = state_dict['t']
        return self.t

    def close(self):
        """Close environment."""
        # Best effort: errors raised while closing are deliberately ignored.
        try:
            self.env.close()
        except Exception:
            pass
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 value_fn,
                 nenv=1,
                 opt_pi=torch.optim.Adam,
                 opt_vf=torch.optim.Adam,
                 batch_size=32,
                 rollout_length=None,
                 gamma=0.99,
                 lambda_=0.95,
                 ent_coef=0.01,
                 norm_advantages=False,
                 epochs_pi=10,
                 epochs_vf=10,
                 max_grad_norm=None,
                 kl_target=0.01,
                 alpha=1.5,
                 policy_training_start=10000,
                 eval_num_episodes=10,
                 record_num_episodes=0,
                 gpu=True):
        """Create the environment, networks, optimizers and rollout manager.

        Args:
            logdir: Run directory; checkpoints are kept in ``<logdir>/ckpts``.
            env_fn: Environment factory, called as ``env_fn(nenv=nenv)``.
            policy_fn: Policy factory, called with the wrapped environment.
            value_fn: Value-function factory, called with the wrapped
                environment.
            nenv: Number of environments requested from ``env_fn``.
            opt_pi: Optimizer factory for the policy parameters.
            opt_vf: Optimizer factory for the value-function parameters.
            batch_size, rollout_length, gamma, lambda_, norm_advantages:
                Forwarded to ``RolloutDataManager``; ``gamma`` is also given
                to the reward normalization wrapper.
            ent_coef, epochs_pi, epochs_vf, max_grad_norm, kl_target, alpha,
            policy_training_start, eval_num_episodes, record_num_episodes:
                Stored on the instance.
            gpu: Use ``cuda:0`` when True and CUDA is available, else the CPU.
        """
        # Plain configuration that is only stored for later use.
        self.logdir = logdir
        self.env_fn = env_fn
        self.nenv = nenv
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.ent_coef = ent_coef
        self.epochs_pi = epochs_pi
        self.epochs_vf = epochs_vf
        self.max_grad_norm = max_grad_norm
        self.alpha = alpha
        self.policy_training_start = policy_training_start

        # The KL weight starts from a fixed initial value.
        self.kl_target = kl_target
        self.initial_kl_weight = 0.2
        self.kl_weight = self.initial_kl_weight

        use_cuda = gpu and torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')

        # Reward normalization sits inside the episode logger.
        raw_env = env_fn(nenv=nenv)
        self.env = VecEpisodeLogger(VecRewardNormWrapper(raw_env, gamma))

        # Networks are moved to the device before their optimizers are built.
        policy = policy_fn(self.env)
        self.pi = policy.to(self.device)
        value_function = value_fn(self.env)
        self.vf = value_function.to(self.device)
        self.opt_pi = opt_pi(self.pi.parameters())
        self.opt_vf = opt_vf(self.vf.parameters())

        self._actor = ResidualPPOActor(self.pi, self.vf, policy_training_start)
        rollout_kwargs = dict(batch_size=batch_size,
                              rollout_length=rollout_length,
                              gamma=gamma,
                              lambda_=lambda_,
                              norm_advantages=norm_advantages)
        self.data_manager = RolloutDataManager(self.env,
                                               self._actor,
                                               self.device,
                                               **rollout_kwargs)

        self.mse = nn.MSELoss()

        self.t = 0
Esempio n. 12
0
class DQN(Algorithm):
    """DQN algorithm."""
    def __init__(self,
                 logdir,
                 env_fn,
                 qf_fn,
                 nenv=1,
                 optimizer=torch.optim.RMSprop,
                 buffer_size=100000,
                 frame_stack=1,
                 learning_starts=10000,
                 update_period=1,
                 gamma=0.99,
                 huber_loss=True,
                 exploration_timesteps=1000000,
                 final_eps=0.1,
                 eval_eps=0.05,
                 target_update_period=10000,
                 batch_size=32,
                 gpu=True,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 log_period=10):
        """Create the environment, Q-networks, replay buffer and data manager.

        Args:
            logdir: Run directory; checkpoints are kept in ``<logdir>/ckpts``.
            env_fn: Environment factory, called as ``env_fn(nenv=nenv)``.
            qf_fn: Q-function factory, called with a frame-stacked
                environment.
            nenv: Number of environments requested from ``env_fn``.
            optimizer: Optimizer factory, called with the Q-network
                parameters.
            buffer_size: Size passed to ``ReplayBuffer``.
            frame_stack: Frame-stack depth for the buffer and the envs.
            learning_starts: Forwarded to the data manager.
            update_period: Forwarded to the data manager; the target update
                period is rounded to a multiple of it.
            gamma: Discount used when computing the TD target.
            huber_loss: Use ``SmoothL1Loss`` when True, else ``MSELoss``.
            exploration_timesteps, final_eps: Parameters of the linear
                epsilon schedule, which uses 1.0 as its third argument.
            eval_eps: Epsilon of the epsilon-greedy wrapper in ``evaluate``.
            target_update_period: Requested period of target network syncs.
            batch_size: Batch size sampled from the replay buffer.
            gpu: Use ``cuda:0`` when True and CUDA is available, else the CPU.
            eval_num_episodes: Episodes used by ``evaluate``.
            record_num_episodes: Episodes recorded by ``evaluate``.
            log_period: Controls how often ``loss`` logs its scalars.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.gamma = gamma
        self.frame_stack = frame_stack
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.learning_starts = learning_starts
        self.update_period = update_period
        self.eval_eps = eval_eps
        # Round the target update period down to a multiple of update_period.
        # A request smaller than update_period would round down to 0 and make
        # `self.t % self.target_update_period` in step() divide by zero, so
        # it is raised to update_period instead.
        if target_update_period < self.update_period:
            self.target_update_period = self.update_period
        else:
            self.target_update_period = target_update_period - (
                target_update_period % self.update_period)
        self.log_period = log_period
        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        # NOTE(review): a second environment is created here and only handed
        # to qf_fn; it is never closed in this class -- confirm whether it
        # can be closed once the networks are built.
        stacked_env = VecFrameStack(env_fn(nenv=nenv), self.frame_stack)

        self.qf = qf_fn(stacked_env).to(self.device)
        self.qf_targ = qf_fn(stacked_env).to(self.device)
        self.opt = optimizer(self.qf.parameters())
        # Both criteria are unreduced; loss() takes the mean itself.
        if huber_loss:
            self.criterion = torch.nn.SmoothL1Loss(reduction='none')
        else:
            self.criterion = torch.nn.MSELoss(reduction='none')
        self.eps_schedule = LinearSchedule(exploration_timesteps, final_eps,
                                           1.0)
        self._actor = EpsilonGreedyActor(self.qf, self.eps_schedule,
                                         self.env.action_space)

        self.buffer = ReplayBuffer(self.buffer_size, self.frame_stack)
        self.data_manager = ReplayBufferDataManager(self.buffer, self.env,
                                                    self._actor, self.device,
                                                    self.learning_starts,
                                                    self.update_period)
        self.t = 0

    def _compute_target(self, rew, next_ob, done):
        """Return ``rew + (1 - done) * gamma * max_q`` from the target net."""
        qtarg = self.qf_targ(next_ob).max_q
        return rew + (1.0 - done) * self.gamma * qtarg

    def _get_batch(self):
        """Sample ``batch_size`` transitions from the data manager."""
        return self.data_manager.sample(self.batch_size)

    def loss(self, batch):
        """Compute loss.

        The target is computed without gradients. The max Q-value, the loss
        and the current epsilon are logged when
        ``self.t % self.log_period < self.update_period``.
        """
        q = self.qf(batch['obs'], batch['action']).value

        with torch.no_grad():
            target = self._compute_target(batch['reward'], batch['next_obs'],
                                          batch['done'])

        assert target.shape == q.shape
        loss = self.criterion(target, q).mean()
        if self.t % self.log_period < self.update_period:
            logger.add_scalar('alg/maxq',
                              torch.max(q).detach().cpu().numpy(), self.t,
                              time.time())
            logger.add_scalar('alg/loss',
                              loss.detach().cpu().numpy(), self.t, time.time())
            logger.add_scalar('alg/epsilon',
                              self.eps_schedule.value(self._actor.t), self.t,
                              time.time())
        return loss

    def step(self):
        """Collect data, sync the target network if due, and update once.

        Returns:
            The updated step counter ``self.t``.
        """
        self.t += self.data_manager.step_until_update()
        if self.t % self.target_update_period == 0:
            self.qf_targ.load_state_dict(self.qf.state_dict())

        self.opt.zero_grad()
        loss = self.loss(self._get_batch())
        loss.backward()
        self.opt.step()
        return self.t

    def evaluate(self):
        """Evaluate.

        Runs the Q-function in a frame-stacked, epsilon-greedy wrapper around
        ``self.env``. Stats go to ``<logdir>/eval`` and the recording to
        ``<logdir>/video``. Afterwards the network and the environment are
        put back in train mode and the data manager is reset.
        """
        eval_env = VecEpsilonGreedy(VecFrameStack(self.env, self.frame_stack),
                                    self.eval_eps)
        self.qf.eval()
        misc.set_env_to_eval_mode(eval_env)

        # Eval policy
        os.makedirs(os.path.join(self.logdir, 'eval'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'eval',
                               self.ckptr.format.format(self.t) + '.json')
        stats = rl_evaluate(eval_env, self.qf, self.eval_num_episodes, outfile,
                            self.device)
        logger.add_scalar('eval/mean_episode_reward', stats['mean_reward'],
                          self.t, time.time())
        logger.add_scalar('eval/mean_episode_length', stats['mean_length'],
                          self.t, time.time())

        # Record policy
        os.makedirs(os.path.join(self.logdir, 'video'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'video',
                               self.ckptr.format.format(self.t) + '.mp4')
        rl_record(eval_env, self.qf, self.record_num_episodes, outfile,
                  self.device)

        self.qf.train()
        misc.set_env_to_train_mode(self.env)
        self.data_manager.manual_reset()

    def save(self):
        """Save the networks, optimizer, actor, env state, step and buffer."""
        state_dict = {
            'qf': self.qf.state_dict(),
            # Store the target network's own weights; they differ from the
            # online network between target syncs.
            'qf_targ': self.qf_targ.state_dict(),
            'opt': self.opt.state_dict(),
            '_actor': self._actor.state_dict(),
            'env': misc.env_state_dict(self.env),
            't': self.t
        }
        buffer_dict = self.buffer.state_dict()
        state_dict['buffer_format'] = nest.get_structure(buffer_dict)
        self.ckptr.save(state_dict, self.t)

        # save buffer separately and only once (because it can be huge)
        np.savez(
            os.path.join(self.ckptr.ckptdir, 'buffer.npz'),
            **{f'{i:04d}': x
               for i, x in enumerate(nest.flatten(buffer_dict))})

    def load(self, t=None):
        """Load.

        Args:
            t: Passed to the checkpointer to select the checkpoint.

        Returns:
            The restored step counter, or 0 when the checkpointer returns
            no state.
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            self.t = 0
            return self.t
        self.qf.load_state_dict(state_dict['qf'])
        self.qf_targ.load_state_dict(state_dict['qf_targ'])
        self.opt.load_state_dict(state_dict['opt'])
        self._actor.load_state_dict(state_dict['_actor'])
        misc.env_load_state_dict(self.env, state_dict['env'])
        self.t = state_dict['t']

        buffer_format = state_dict['buffer_format']
        # Read every array while the archive is open, then close the file.
        with np.load(os.path.join(self.ckptr.ckptdir,
                                  'buffer.npz')) as archive:
            buffer_state = dict(archive)
        buffer_state = nest.flatten(buffer_state)
        self.buffer.load_state_dict(
            nest.pack_sequence_as(buffer_state, buffer_format))
        self.data_manager.manual_reset()
        return self.t

    def close(self):
        """Close environment."""
        # Best effort: errors raised while closing are deliberately ignored.
        try:
            self.env.close()
        except Exception:
            pass
Esempio n. 13
0
    def __init__(self,
                 logdir,
                 env_fn,
                 qf_fn,
                 nenv=1,
                 optimizer=torch.optim.RMSprop,
                 buffer_size=100000,
                 frame_stack=1,
                 learning_starts=10000,
                 update_period=1,
                 gamma=0.99,
                 huber_loss=True,
                 exploration_timesteps=1000000,
                 final_eps=0.1,
                 eval_eps=0.05,
                 target_update_period=10000,
                 batch_size=32,
                 gpu=True,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 log_period=10):
        """Create the environment, Q-networks, replay buffer and data manager.

        Args:
            logdir: Run directory; checkpoints are kept in ``<logdir>/ckpts``.
            env_fn: Environment factory, called as ``env_fn(nenv=nenv)``.
            qf_fn: Q-function factory, called with a frame-stacked
                environment.
            nenv: Number of environments requested from ``env_fn``.
            optimizer: Optimizer factory, called with the Q-network
                parameters.
            buffer_size: Size passed to ``ReplayBuffer``.
            frame_stack: Frame-stack depth for the buffer and the env.
            learning_starts: Forwarded to the data manager.
            update_period: Forwarded to the data manager; the target update
                period is rounded to a multiple of it.
            gamma: Stored as ``self.gamma``.
            huber_loss: Use ``SmoothL1Loss`` when True, else ``MSELoss``.
            exploration_timesteps, final_eps: Parameters of the linear
                epsilon schedule, which uses 1.0 as its third argument.
            eval_eps: Stored as ``self.eval_eps``.
            target_update_period: Requested period of target network syncs.
            batch_size: Stored as ``self.batch_size``.
            gpu: Use ``cuda:0`` when True and CUDA is available, else the CPU.
            eval_num_episodes, record_num_episodes, log_period: Stored on
                the instance.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.gamma = gamma
        self.frame_stack = frame_stack
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.learning_starts = learning_starts
        self.update_period = update_period
        self.eval_eps = eval_eps
        # Round the target update period down to a multiple of update_period.
        # A request smaller than update_period would round down to 0, which
        # breaks any later `t % target_update_period` check, so it is raised
        # to update_period instead.
        if target_update_period < self.update_period:
            self.target_update_period = self.update_period
        else:
            self.target_update_period = target_update_period - (
                target_update_period % self.update_period)
        self.log_period = log_period
        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        # NOTE(review): a second environment is created here and only handed
        # to qf_fn -- confirm whether it can be closed afterwards.
        stacked_env = VecFrameStack(env_fn(nenv=nenv), self.frame_stack)

        self.qf = qf_fn(stacked_env).to(self.device)
        self.qf_targ = qf_fn(stacked_env).to(self.device)
        self.opt = optimizer(self.qf.parameters())
        # Both criteria are unreduced: one value per element is returned.
        if huber_loss:
            self.criterion = torch.nn.SmoothL1Loss(reduction='none')
        else:
            self.criterion = torch.nn.MSELoss(reduction='none')
        self.eps_schedule = LinearSchedule(exploration_timesteps, final_eps,
                                           1.0)
        self._actor = EpsilonGreedyActor(self.qf, self.eps_schedule,
                                         self.env.action_space)

        self.buffer = ReplayBuffer(self.buffer_size, self.frame_stack)
        self.data_manager = ReplayBufferDataManager(self.buffer, self.env,
                                                    self._actor, self.device,
                                                    self.learning_starts,
                                                    self.update_period)
        self.t = 0
Esempio n. 14
0
File: ddpg.py Progetto: amackeith/dl
class DDPG(Algorithm):
    """Deep Deterministic Policy Gradient (DDPG) trainer.

    Holds a policy and a Q-function together with a target copy of each,
    gathers experience into a replay buffer through a data manager, and
    interleaves data collection with gradient updates of both networks.
    """
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 qf_fn,
                 nenv=1,
                 optimizer=torch.optim.Adam,
                 buffer_size=10000,
                 frame_stack=1,
                 learning_starts=1000,
                 update_period=1,
                 batch_size=256,
                 policy_lr=1e-4,
                 qf_lr=1e-3,
                 qf_weight_decay=0.01,
                 gamma=0.99,
                 noise_theta=0.15,
                 noise_sigma=0.2,
                 noise_sigma_final=0.01,
                 noise_decay_period=10000,
                 target_update_period=1,
                 target_smoothing_coef=0.005,
                 reward_scale=1,
                 gpu=True,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 log_period=1000):
        """Build the environment, networks, optimizers and replay buffer.

        Args:
            logdir: Root output directory; checkpoints go to ``logdir/ckpts``.
            env_fn: Callable invoked as ``env_fn(nenv=nenv)`` to build the env.
            policy_fn: Callable building a policy from the frame-stacked env.
            qf_fn: Callable building a Q-function from the frame-stacked env.
            nenv: Forwarded to ``env_fn``.
            optimizer: Optimizer class, called with the parameters and ``lr``
                (plus ``weight_decay`` for the Q-function).
            buffer_size: Forwarded to ``ReplayBuffer``.
            frame_stack: Forwarded to ``ReplayBuffer`` and ``VecFrameStack``.
            learning_starts: Forwarded to ``ReplayBufferDataManager``.
            update_period: Forwarded to ``ReplayBufferDataManager``; a
                gradient update runs whenever ``t`` is a multiple of it.
            batch_size: Number of transitions sampled per update.
            policy_lr: Learning rate of the policy optimizer.
            qf_lr: Learning rate of the Q-function optimizer.
            qf_weight_decay: Weight decay of the Q-function optimizer.
            gamma: Discount factor used in the Q target.
            noise_theta: Forwarded to ``DDPGActor``.
            noise_sigma: Forwarded to ``LinearSchedule`` together with
                ``noise_sigma_final`` and ``noise_decay_period``; presumably
                the initial exploration sigma (confirm against
                ``LinearSchedule``).
            noise_sigma_final: See ``noise_sigma``.
            noise_decay_period: See ``noise_sigma``.
            target_update_period: Interval (in units of ``t``) between soft
                target updates. Raised to ``update_period`` when smaller,
                otherwise rounded down to a multiple of ``update_period``.
            target_smoothing_coef: Forwarded to ``soft_target_update``.
            reward_scale: Multiplier applied to rewards in the Q target.
            gpu: Use ``cuda:0`` when CUDA is available, else the CPU.
            eval_num_episodes: Episodes per call to ``rl_evaluate``.
            record_num_episodes: Episodes per call to ``rl_record``.
            log_period: Interval (in units of ``t``) between loss logs.

        Raises:
            ValueError: If the environment's action space is ``Discrete``.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.frame_stack = frame_stack
        self.learning_starts = learning_starts
        self.update_period = update_period
        self.batch_size = batch_size
        # Target updates are only checked on steps where t lines up with
        # update_period, so keep the target period a multiple of it.
        if target_update_period < self.update_period:
            self.target_update_period = self.update_period
        else:
            self.target_update_period = target_update_period - (
                target_update_period % self.update_period)
        self.reward_scale = reward_scale
        self.target_smoothing_coef = target_smoothing_coef
        self.log_period = log_period

        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')
        self.t = 0

        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        # Fail fast: reject discrete action spaces before any network is
        # built, so the caller sees this message rather than whatever error
        # policy_fn/qf_fn would produce on an unsupported space.
        if self.env.action_space.__class__.__name__ == 'Discrete':
            raise ValueError("Action space must be continuous!")

        self.policy_fn = policy_fn
        self.qf_fn = qf_fn
        eval_env = VecFrameStack(self.env, self.frame_stack)
        self.pi = policy_fn(eval_env)
        self.qf = qf_fn(eval_env)
        self.target_pi = policy_fn(eval_env)
        self.target_qf = qf_fn(eval_env)

        self.pi.to(self.device)
        self.qf.to(self.device)
        self.target_pi.to(self.device)
        self.target_qf.to(self.device)

        self.optimizer = optimizer
        self.policy_lr = policy_lr
        self.qf_lr = qf_lr
        self.qf_weight_decay = qf_weight_decay
        self.opt_pi = optimizer(self.pi.parameters(), lr=policy_lr)
        self.opt_qf = optimizer(self.qf.parameters(),
                                lr=qf_lr,
                                weight_decay=qf_weight_decay)

        # Start the target networks as exact copies of the online networks.
        self.target_pi.load_state_dict(self.pi.state_dict())
        self.target_qf.load_state_dict(self.qf.state_dict())

        self.noise_schedule = LinearSchedule(noise_decay_period,
                                             noise_sigma_final, noise_sigma)
        self._actor = DDPGActor(self.pi, self.env.action_space, noise_theta,
                                self.noise_schedule.value(self.t))
        self.buffer = ReplayBuffer(buffer_size, frame_stack)
        self.data_manager = ReplayBufferDataManager(self.buffer, self.env,
                                                    self._actor, self.device,
                                                    self.learning_starts,
                                                    self.update_period)

        self.qf_criterion = torch.nn.MSELoss()

    def loss(self, batch):
        """Compute the policy and Q-function losses for one batch.

        Args:
            batch: Mapping with the keys 'obs', 'action', 'reward',
                'next_obs' and 'done'.

        Returns:
            Tuple ``(pi_loss, qf_loss)``.
        """
        # Q-function loss: the bootstrapped target comes from the target
        # networks and carries no gradient.
        with torch.no_grad():
            target_action = self.target_pi(batch['next_obs']).action
            target_q = self.target_qf(batch['next_obs'], target_action).value
            qtarg = self.reward_scale * batch['reward'].float() + (
                (1.0 - batch['done']) * self.gamma * target_q)

        q = self.qf(batch['obs'], batch['action']).value
        assert qtarg.shape == q.shape
        qf_loss = self.qf_criterion(q, qtarg)

        # Policy loss: maximise the Q-value of the policy's own action.
        action = self.pi(batch['obs'], deterministic=True).action
        q = self.qf(batch['obs'], action).value
        pi_loss = -q.mean()

        # Log when t falls inside the first update_period steps of a
        # log_period window (t presumably advances in strides, so an exact
        # multiple of log_period may never be hit).
        if self.t % self.log_period < self.update_period:
            logger.add_scalar('loss/qf', qf_loss, self.t, time.time())
            logger.add_scalar('loss/pi', pi_loss, self.t, time.time())
        return pi_loss, qf_loss

    def step(self):
        """Collect data until an update is due, then update the networks.

        Returns:
            The updated step counter ``self.t``.
        """
        self._actor.update_sigma(self.noise_schedule.value(self.t))
        self.t += self.data_manager.step_until_update()
        if self.t % self.target_update_period == 0:
            soft_target_update(self.target_pi, self.pi,
                               self.target_smoothing_coef)
            soft_target_update(self.target_qf, self.qf,
                               self.target_smoothing_coef)

        if self.t % self.update_period == 0:
            batch = self.data_manager.sample(self.batch_size)

            pi_loss, qf_loss = self.loss(batch)

            # Update the policy first. pi_loss backpropagates through
            # self.qf, so the Q-function parameters must not be modified in
            # place (by opt_qf.step()) before pi_loss.backward() runs;
            # autograd rejects a backward pass through tensors that changed
            # after the forward pass.
            self.opt_pi.zero_grad()
            pi_loss.backward()
            self.opt_pi.step()

            # zero_grad() also discards the gradients pi_loss.backward()
            # just accumulated on the Q-function parameters. qf_loss does
            # not depend on the policy parameters (its target was computed
            # under no_grad), so the policy step above cannot invalidate it.
            self.opt_qf.zero_grad()
            qf_loss.backward()
            self.opt_qf.step()
        return self.t

    def evaluate(self):
        """Evaluate and record the current policy, then resume training mode.

        Writes evaluation results to ``logdir/eval`` and a video to
        ``logdir/video``, both named after the current step counter.
        """
        eval_env = VecFrameStack(self.env, self.frame_stack)
        self.pi.eval()
        misc.set_env_to_eval_mode(eval_env)

        # Eval policy
        os.makedirs(os.path.join(self.logdir, 'eval'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'eval',
                               self.ckptr.format.format(self.t) + '.json')
        stats = rl_evaluate(eval_env, self.pi, self.eval_num_episodes, outfile,
                            self.device)
        logger.add_scalar('eval/mean_episode_reward', stats['mean_reward'],
                          self.t, time.time())
        logger.add_scalar('eval/mean_episode_length', stats['mean_length'],
                          self.t, time.time())

        # Record policy
        os.makedirs(os.path.join(self.logdir, 'video'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'video',
                               self.ckptr.format.format(self.t) + '.mp4')
        rl_record(eval_env, self.pi, self.record_num_episodes, outfile,
                  self.device)

        # Back to training: restore modes and reset the data manager, since
        # evaluation stepped the shared environment.
        self.pi.train()
        misc.set_env_to_train_mode(self.env)
        self.data_manager.manual_reset()

    def save(self):
        """Checkpoint networks, optimizers, env state and the replay buffer."""
        state_dict = {
            'pi': self.pi.state_dict(),
            'qf': self.qf.state_dict(),
            'target_pi': self.target_pi.state_dict(),
            'target_qf': self.target_qf.state_dict(),
            'opt_pi': self.opt_pi.state_dict(),
            'opt_qf': self.opt_qf.state_dict(),
            'env': misc.env_state_dict(self.env),
            't': self.t
        }
        buffer_dict = self.buffer.state_dict()
        # Only the buffer's structure goes into the checkpoint; the arrays
        # themselves are written to buffer.npz below.
        state_dict['buffer_format'] = nest.get_structure(buffer_dict)
        self.ckptr.save(state_dict, self.t)

        # save buffer separately and only once (because it can be huge)
        np.savez(
            os.path.join(self.ckptr.ckptdir, 'buffer.npz'),
            **{f'{i:04d}': x
               for i, x in enumerate(nest.flatten(buffer_dict))})

    def load(self, t=None):
        """Restore state from a checkpoint.

        Args:
            t: Checkpoint identifier forwarded to the checkpointer.

        Returns:
            The restored step counter, or 0 when no checkpoint was found.
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            self.t = 0
            return self.t
        self.pi.load_state_dict(state_dict['pi'])
        self.qf.load_state_dict(state_dict['qf'])
        self.target_pi.load_state_dict(state_dict['target_pi'])
        self.target_qf.load_state_dict(state_dict['target_qf'])

        self.opt_pi.load_state_dict(state_dict['opt_pi'])
        self.opt_qf.load_state_dict(state_dict['opt_qf'])
        misc.env_load_state_dict(self.env, state_dict['env'])
        self.t = state_dict['t']

        buffer_format = state_dict['buffer_format']
        # Read every array while the archive is open, then close the file
        # handle instead of leaving it to the garbage collector.
        with np.load(os.path.join(self.ckptr.ckptdir,
                                  'buffer.npz')) as archive:
            buffer_state = dict(archive)
        buffer_state = nest.flatten(buffer_state)
        self.buffer.load_state_dict(
            nest.pack_sequence_as(buffer_state, buffer_format))
        self.data_manager.manual_reset()
        return self.t

    def close(self):
        """Close the environment, ignoring any error raised while doing so."""
        try:
            self.env.close()
        except Exception:
            # Deliberately best-effort: close() must never raise during
            # shutdown.
            pass
Esempio n. 15
0
File: ddpg.py Progetto: amackeith/dl
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 qf_fn,
                 nenv=1,
                 optimizer=torch.optim.Adam,
                 buffer_size=10000,
                 frame_stack=1,
                 learning_starts=1000,
                 update_period=1,
                 batch_size=256,
                 policy_lr=1e-4,
                 qf_lr=1e-3,
                 qf_weight_decay=0.01,
                 gamma=0.99,
                 noise_theta=0.15,
                 noise_sigma=0.2,
                 noise_sigma_final=0.01,
                 noise_decay_period=10000,
                 target_update_period=1,
                 target_smoothing_coef=0.005,
                 reward_scale=1,
                 gpu=True,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 log_period=1000):
        """Set up the env, the online/target networks, optimizers and buffer.

        Raises:
            ValueError: If the environment's action space is ``Discrete``.
        """
        # Output locations.
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))

        # Plain hyperparameters, stored as given.
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.frame_stack = frame_stack
        self.learning_starts = learning_starts
        self.update_period = update_period
        self.batch_size = batch_size
        self.reward_scale = reward_scale
        self.target_smoothing_coef = target_smoothing_coef
        self.log_period = log_period

        # Keep the target period a multiple of update_period, and never
        # below it.
        if target_update_period >= self.update_period:
            remainder = target_update_period % self.update_period
            self.target_update_period = target_update_period - remainder
        else:
            self.target_update_period = self.update_period

        use_cuda = gpu and torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')
        self.t = 0

        # Environment plus a frame-stacked view used to size the networks.
        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        self.policy_fn = policy_fn
        self.qf_fn = qf_fn
        stacked_env = VecFrameStack(self.env, self.frame_stack)

        # Online and target networks, built in a fixed order.
        self.pi = policy_fn(stacked_env)
        self.qf = qf_fn(stacked_env)
        self.target_pi = policy_fn(stacked_env)
        self.target_qf = qf_fn(stacked_env)
        for net in (self.pi, self.qf, self.target_pi, self.target_qf):
            net.to(self.device)

        # One optimizer per online network; only the Q-function uses weight
        # decay.
        self.optimizer = optimizer
        self.policy_lr = policy_lr
        self.qf_lr = qf_lr
        self.qf_weight_decay = qf_weight_decay
        self.opt_pi = optimizer(self.pi.parameters(), lr=policy_lr)
        self.opt_qf = optimizer(self.qf.parameters(),
                                lr=qf_lr,
                                weight_decay=qf_weight_decay)

        # Targets start as exact copies of the online networks.
        for target, online in ((self.target_pi, self.pi),
                               (self.target_qf, self.qf)):
            target.load_state_dict(online.state_dict())

        # Exploration actor whose sigma follows a linear schedule.
        self.noise_schedule = LinearSchedule(noise_decay_period,
                                             noise_sigma_final, noise_sigma)
        initial_sigma = self.noise_schedule.value(self.t)
        self._actor = DDPGActor(self.pi, self.env.action_space, noise_theta,
                                initial_sigma)

        # Replay storage and the manager that fills and samples it.
        self.buffer = ReplayBuffer(buffer_size, frame_stack)
        self.data_manager = ReplayBufferDataManager(self.buffer, self.env,
                                                    self._actor, self.device,
                                                    self.learning_starts,
                                                    self.update_period)

        self.qf_criterion = torch.nn.MSELoss()

        # Discrete action spaces are rejected.
        space_kind = self.env.action_space.__class__.__name__
        if space_kind == 'Discrete':
            raise ValueError("Action space must be continuous!")
Esempio n. 16
0
class BCTrainer(object):
    """Behavioral cloning trainer.

    Fits ``model`` to (observation, action) pairs read from a demonstration
    file by minimising the negative log-likelihood of the demonstrated
    actions.
    """
    def __init__(self,
                 logdir,
                 model,
                 opt,
                 datafile,
                 batch_size,
                 num_workers,
                 gpu=True):
        """Create the data pipeline, move the model and build the optimizer.

        Args:
            logdir: Root output directory; checkpoints go to ``logdir/ckpts``.
            model: Model to train; ``model(ob)`` must expose ``log_prob``.
            opt: Optimizer factory, called as ``opt(model.parameters())``.
            datafile: Forwarded to ``DemonstrationData``.
            batch_size: Batch size of the training ``DataLoader``.
            num_workers: Worker count of the training ``DataLoader``.
            gpu: Use ``cuda:0`` when CUDA is available, else the CPU.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.data = DemonstrationData(datafile)
        self.sampler = StatefulSampler(self.data, shuffle=True)
        self.dtrain = DataLoader(self.data,
                                 sampler=self.sampler,
                                 batch_size=batch_size,
                                 num_workers=num_workers)
        self._diter = None
        self.t = 0
        self.epochs = 0
        self.batch_size = batch_size

        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')
        self.model = model
        self.model.to(self.device)
        self.opt = opt(self.model.parameters())

    def _discard_iterator(self):
        """Tear down the active data iterator, if there is one."""
        if self._diter is not None:
            # Explicit finaliser call, as in the original code; presumably
            # this shuts the loader's workers down right away rather than at
            # garbage collection.
            self._diter.__del__()
            self._diter = None

    def step(self):
        """Run one optimisation step on the next training batch.

        When the data iterator is exhausted no update is performed; the
        epoch counter is incremented and the iterator is dropped so the next
        call starts a new pass.

        Returns:
            The number of completed epochs.
        """
        # Get batch.
        if self._diter is None:
            self._diter = iter(self.dtrain)
        try:
            batch = next(self._diter)
        except StopIteration:
            self.epochs += 1
            self._diter = None
            return self.epochs
        batch = nest.map_structure(lambda x: x.to(self.device), batch)

        # compute loss
        ob, ac = batch
        self.model.train()
        loss = -self.model(ob).log_prob(ac).mean()

        logger.add_scalar('train/loss',
                          loss.detach().cpu().numpy(), self.t, time.time())

        # update model
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        # Advance the sample counter by batch_size, capped at what remains
        # of the current pass over the data.
        self.t += min(
            len(self.data) - (self.t % len(self.data)), self.batch_size)
        return self.epochs

    def evaluate(self):
        """Log the total and per-sample NLL of the model on the training set."""
        self.model.eval()

        nll = 0.
        with torch.no_grad():
            for batch in self.dtrain:
                ob, ac = nest.map_structure(lambda x: x.to(self.device), batch)
                nll += -self.model(ob).log_prob(ac).sum()
            avg_nll = nll / len(self.data)

            logger.add_scalar('train/NLL', nll, self.epochs, time.time())
            logger.add_scalar('train/AVG_NLL', avg_nll, self.epochs,
                              time.time())

    def save(self):
        """Checkpoint model, optimizer, sampler state and counters."""
        state_dict = {}
        state_dict['model'] = self.model.state_dict()
        state_dict['opt'] = self.opt.state_dict()
        state_dict['sampler'] = self.sampler.state_dict(self._diter)
        state_dict['t'] = self.t
        state_dict['epochs'] = self.epochs
        self.ckptr.save(state_dict, self.t)

    def load(self, t=None):
        """Restore state from a checkpoint.

        Args:
            t: Checkpoint identifier forwarded to the checkpointer.

        Returns:
            The restored epoch counter, or 0 when no checkpoint was found.
        """
        # Honour the requested checkpoint instead of always loading the
        # checkpointer's default.
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            self.t = 0
            self.epochs = 0
            return self.epochs
        self.model.load_state_dict(state_dict['model'])
        self.opt.load_state_dict(state_dict['opt'])
        self.sampler.load_state_dict(state_dict['sampler'])
        self.t = state_dict['t']
        self.epochs = state_dict['epochs']
        # Any live iterator was built from the pre-load sampler state.
        self._discard_iterator()
        return self.epochs

    def close(self):
        """Close data iterator."""
        self._discard_iterator()
Esempio n. 17
0
File: sac.py Progetto: takuma-ynd/dl
class SAC(Algorithm):
    """Soft Actor-Critic (SAC) trainer.

    Holds a policy, two Q-functions, a value function and a target copy of
    the value function. Experience is gathered into a replay buffer through
    a data manager, and the entropy coefficient ``alpha`` can optionally be
    tuned automatically.
    """

    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 qf_fn,
                 vf_fn,
                 nenv=1,
                 optimizer=torch.optim.Adam,
                 buffer_size=10000,
                 frame_stack=1,
                 learning_starts=1000,
                 update_period=1,
                 batch_size=256,
                 policy_lr=1e-3,
                 qf_lr=1e-3,
                 vf_lr=1e-3,
                 policy_mean_reg_weight=1e-3,
                 gamma=0.99,
                 target_update_period=1,
                 policy_update_period=1,
                 target_smoothing_coef=0.005,
                 automatic_entropy_tuning=True,
                 reparameterization_trick=True,
                 target_entropy=None,
                 reward_scale=1,
                 gpu=True,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 log_period=1000):
        """Build the environment, networks, optimizers and replay buffer.

        Args:
            logdir: Root output directory; checkpoints go to ``logdir/ckpts``.
            env_fn: Callable invoked as ``env_fn(nenv=nenv)`` to build the env.
            policy_fn: Callable building a policy from the frame-stacked env.
            qf_fn: Callable building a Q-function; called twice.
            vf_fn: Callable building a value function; called twice (online
                and target).
            nenv: Forwarded to ``env_fn``.
            optimizer: Optimizer class, called with the parameters and ``lr``.
            buffer_size: Forwarded to ``ReplayBuffer``.
            frame_stack: Forwarded to ``ReplayBuffer`` and ``VecFrameStack``.
            learning_starts: Forwarded to ``ReplayBufferDataManager``.
            update_period: Forwarded to ``ReplayBufferDataManager``; a
                gradient update runs whenever ``t`` is a multiple of it.
            batch_size: Number of transitions sampled per update.
            policy_lr: Learning rate for the policy and for ``log_alpha``.
            qf_lr: Learning rate of both Q-function optimizers.
            vf_lr: Learning rate of the value-function optimizer.
            policy_mean_reg_weight: Weight of the squared-mean penalty added
                to the continuous-action policy loss.
            gamma: Discount factor used in the Q target.
            target_update_period: Interval (in units of ``t``) between soft
                target updates; raised to ``update_period`` when smaller,
                otherwise rounded down to a multiple of it.
            policy_update_period: Interval between policy updates, adjusted
                the same way as ``target_update_period``.
            target_smoothing_coef: Forwarded to ``soft_target_update``.
            automatic_entropy_tuning: Learn ``log_alpha`` when true;
                otherwise ``alpha`` is fixed to 1.
            reparameterization_trick: Use ``rsample`` and the pathwise policy
                loss when true, else ``sample`` and the score-function loss.
            target_entropy: Entropy target for automatic tuning; ``None``
                selects a heuristic default.
            reward_scale: Multiplier applied to rewards in the Q target.
            gpu: Use ``cuda:0`` when CUDA is available, else the CPU.
            eval_num_episodes: Episodes per call to ``rl_evaluate``.
            record_num_episodes: Episodes per call to ``rl_record``.
            log_period: Interval (in units of ``t``) between loss logs.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.frame_stack = frame_stack
        self.learning_starts = learning_starts
        self.update_period = update_period
        self.batch_size = batch_size
        # Both periods are only checked on steps where t lines up with
        # update_period, so keep them multiples of it.
        if target_update_period < self.update_period:
            self.target_update_period = self.update_period
        else:
            self.target_update_period = target_update_period - (
                target_update_period % self.update_period)
        if policy_update_period < self.update_period:
            self.policy_update_period = self.update_period
        else:
            self.policy_update_period = policy_update_period - (
                policy_update_period % self.update_period)
        self.rsample = reparameterization_trick
        self.reward_scale = reward_scale
        self.target_smoothing_coef = target_smoothing_coef
        self.log_period = log_period

        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        eval_env = VecFrameStack(self.env, self.frame_stack)
        self.pi = policy_fn(eval_env)
        self.qf1 = qf_fn(eval_env)
        self.qf2 = qf_fn(eval_env)
        self.vf = vf_fn(eval_env)
        self.target_vf = vf_fn(eval_env)

        self.pi.to(self.device)
        self.qf1.to(self.device)
        self.qf2.to(self.device)
        self.vf.to(self.device)
        self.target_vf.to(self.device)

        self.opt_pi = optimizer(self.pi.parameters(), lr=policy_lr)
        self.opt_qf1 = optimizer(self.qf1.parameters(), lr=qf_lr)
        self.opt_qf2 = optimizer(self.qf2.parameters(), lr=qf_lr)
        self.opt_vf = optimizer(self.vf.parameters(), lr=vf_lr)
        self.policy_mean_reg_weight = policy_mean_reg_weight

        # Start the target value function as an exact copy of the online one.
        self.target_vf.load_state_dict(self.vf.state_dict())

        self.buffer = ReplayBuffer(buffer_size, frame_stack)
        self.data_manager = ReplayBufferDataManager(self.buffer,
                                                    self.env,
                                                    SACActor(self.pi),
                                                    self.device,
                                                    self.learning_starts,
                                                    self.update_period)

        self.discrete = self.env.action_space.__class__.__name__ == 'Discrete'
        self.automatic_entropy_tuning = automatic_entropy_tuning
        if self.automatic_entropy_tuning:
            # Compare against None so that an explicit target of 0 is kept
            # rather than replaced by the heuristic.
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                # heuristic value from Tuomas
                if self.discrete:
                    self.target_entropy = np.log(1.5)
                else:
                    self.target_entropy = -np.prod(
                        self.env.action_space.shape).item()
            self.log_alpha = torch.zeros(1, requires_grad=True,
                                         device=self.device)
            self.opt_alpha = optimizer([self.log_alpha], lr=policy_lr)
        else:
            self.target_entropy = None
            self.log_alpha = None
            self.opt_alpha = None

        self.qf_criterion = torch.nn.MSELoss()
        self.vf_criterion = torch.nn.MSELoss()

        self.t = 0

    def loss(self, batch):
        """Compute the SAC losses for one batch.

        Side effect: when automatic entropy tuning is enabled, the optimizer
        step for ``log_alpha`` is performed inside this method.

        Args:
            batch: Mapping with the keys 'obs', 'action', 'reward',
                'next_obs' and 'done'.

        Returns:
            Tuple ``(pi_loss, qf1_loss, qf2_loss, vf_loss)``. ``pi_loss`` is
            ``None`` on steps where ``t`` is not a multiple of
            ``policy_update_period``.
        """
        pi_out = self.pi(batch['obs'], reparameterization_trick=self.rsample)
        if self.discrete:
            new_ac = pi_out.action
            ent = pi_out.dist.entropy()
        else:
            assert isinstance(pi_out.dist, TanhNormal), (
                "It is strongly encouraged that you use a TanhNormal "
                "action distribution for continuous action spaces.")
            if self.rsample:
                new_ac, new_pth_ac = pi_out.dist.rsample(
                    return_pretanh_value=True)
            else:
                new_ac, new_pth_ac = pi_out.dist.sample(
                    return_pretanh_value=True)
            logp = pi_out.dist.log_prob(new_ac, new_pth_ac)
        q1 = self.qf1(batch['obs'], batch['action']).value
        q2 = self.qf2(batch['obs'], batch['action']).value
        v = self.vf(batch['obs']).value

        # alpha loss
        if self.automatic_entropy_tuning:
            if self.discrete:
                ent_error = -ent + self.target_entropy
            else:
                ent_error = logp + self.target_entropy
            alpha_loss = -(self.log_alpha * ent_error.detach()).mean()
            self.opt_alpha.zero_grad()
            alpha_loss.backward()
            self.opt_alpha.step()
            # Read alpha after the step so the losses below use the updated
            # value.
            alpha = self.log_alpha.exp()
        else:
            alpha = 1
            alpha_loss = 0

        # qf loss
        vtarg = self.target_vf(batch['next_obs']).value
        qtarg = self.reward_scale * batch['reward'].float() + (
            (1.0 - batch['done']) * self.gamma * vtarg)
        assert qtarg.shape == q1.shape
        assert qtarg.shape == q2.shape
        qf1_loss = self.qf_criterion(q1, qtarg.detach())
        qf2_loss = self.qf_criterion(q2, qtarg.detach())

        # vf loss
        q1_outs = self.qf1(batch['obs'], new_ac)
        q1_new = q1_outs.value
        q2_new = self.qf2(batch['obs'], new_ac).value
        q = torch.min(q1_new, q2_new)
        if self.discrete:
            vtarg = q + alpha * ent
        else:
            vtarg = q - alpha * logp
        assert v.shape == vtarg.shape
        vf_loss = self.vf_criterion(v, vtarg.detach())

        # pi loss
        pi_loss = None
        if self.t % self.policy_update_period == 0:
            if self.discrete:
                target_dist = CatDist(logits=q1_outs.qvals.detach())
                pi_dist = CatDist(logits=alpha * pi_out.dist.logits)
                pi_loss = pi_dist.kl(target_dist).mean()
            else:
                if self.rsample:
                    assert q.shape == logp.shape
                    pi_loss = (alpha*logp - q1_new).mean()
                else:
                    pi_targ = q1_new - v
                    assert pi_targ.shape == logp.shape
                    pi_loss = (logp * (alpha * logp - pi_targ).detach()).mean()

                pi_loss += self.policy_mean_reg_weight * (
                    pi_out.dist.normal.mean**2).mean()

            # log pi loss about as frequently as other losses
            if self.t % self.log_period < self.policy_update_period:
                logger.add_scalar('loss/pi', pi_loss, self.t, time.time())

        if self.t % self.log_period < self.update_period:
            # Entropy estimate: the exact entropy for discrete policies, the
            # negative mean log-probability of the sampled actions otherwise.
            if self.discrete:
                entropy = ent.mean().detach().cpu().numpy().item()
            else:
                entropy = -torch.mean(logp.detach()).cpu().numpy().item()

            if self.automatic_entropy_tuning:
                logger.add_scalar('ent/log_alpha',
                                  self.log_alpha.detach().cpu().numpy(),
                                  self.t, time.time())
                scalars = {"target": self.target_entropy,
                           "entropy": entropy}
                logger.add_scalars('ent/entropy', scalars, self.t, time.time())
            else:
                logger.add_scalar('ent/entropy', entropy, self.t, time.time())
            logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time())
            logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time())
            logger.add_scalar('loss/vf', vf_loss, self.t, time.time())
        return pi_loss, qf1_loss, qf2_loss, vf_loss

    def step(self):
        """Collect data until an update is due, then update the networks.

        Returns:
            The updated step counter ``self.t``.
        """
        self.t += self.data_manager.step_until_update()
        if self.t % self.target_update_period == 0:
            soft_target_update(self.target_vf, self.vf,
                               self.target_smoothing_coef)

        if self.t % self.update_period == 0:
            batch = self.data_manager.sample(self.batch_size)

            pi_loss, qf1_loss, qf2_loss, vf_loss = self.loss(batch)

            # Update the policy first. With the reparameterised loss,
            # pi_loss backpropagates through self.qf1, so its parameters
            # must not be modified in place (by opt_qf1.step()) before
            # pi_loss.backward() runs. Test against None explicitly: the
            # truth value of a tensor would skip a loss that is exactly 0.
            if pi_loss is not None:
                self.opt_pi.zero_grad()
                pi_loss.backward()
                self.opt_pi.step()

            # Each zero_grad() also discards any gradient the policy
            # backward pass accumulated on that network's parameters.
            self.opt_qf1.zero_grad()
            qf1_loss.backward()
            self.opt_qf1.step()

            self.opt_qf2.zero_grad()
            qf2_loss.backward()
            self.opt_qf2.step()

            self.opt_vf.zero_grad()
            vf_loss.backward()
            self.opt_vf.step()
        return self.t

    def evaluate(self):
        """Evaluate and record the current policy, then resume training mode.

        Writes evaluation results to ``logdir/eval`` and a video to
        ``logdir/video``, both named after the current step counter.
        """
        eval_env = VecFrameStack(self.env, self.frame_stack)
        self.pi.eval()
        misc.set_env_to_eval_mode(eval_env)

        # Eval policy
        os.makedirs(os.path.join(self.logdir, 'eval'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'eval',
                               self.ckptr.format.format(self.t) + '.json')
        stats = rl_evaluate(eval_env, self.pi, self.eval_num_episodes,
                            outfile, self.device)
        logger.add_scalar('eval/mean_episode_reward', stats['mean_reward'],
                          self.t, time.time())
        logger.add_scalar('eval/mean_episode_length', stats['mean_length'],
                          self.t, time.time())

        # Record policy
        os.makedirs(os.path.join(self.logdir, 'video'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'video',
                               self.ckptr.format.format(self.t) + '.mp4')
        rl_record(eval_env, self.pi, self.record_num_episodes, outfile,
                  self.device)

        # Back to training: restore modes and reset the data manager, since
        # evaluation stepped the shared environment.
        self.pi.train()
        misc.set_env_to_train_mode(self.env)
        self.data_manager.manual_reset()

    def save(self):
        """Checkpoint networks, optimizers, env state and the replay buffer."""
        state_dict = {
            'pi': self.pi.state_dict(),
            'qf1': self.qf1.state_dict(),
            'qf2': self.qf2.state_dict(),
            'vf': self.vf.state_dict(),
            # The target lags the online value function, so it has to be
            # stored on its own for training to resume exactly.
            'target_vf': self.target_vf.state_dict(),
            'opt_pi': self.opt_pi.state_dict(),
            'opt_qf1': self.opt_qf1.state_dict(),
            'opt_qf2': self.opt_qf2.state_dict(),
            'opt_vf': self.opt_vf.state_dict(),
            'log_alpha': (self.log_alpha if self.automatic_entropy_tuning
                          else None),
            'opt_alpha': (self.opt_alpha.state_dict()
                          if self.automatic_entropy_tuning else None),
            'env': misc.env_state_dict(self.env),
            't': self.t
        }
        buffer_dict = self.buffer.state_dict()
        # Only the buffer's structure goes into the checkpoint; the arrays
        # themselves are written to buffer.npz below.
        state_dict['buffer_format'] = nest.get_structure(buffer_dict)
        self.ckptr.save(state_dict, self.t)

        # save buffer separately and only once (because it can be huge)
        np.savez(os.path.join(self.ckptr.ckptdir, 'buffer.npz'),
                 **{f'{i:04d}': x for i, x in
                    enumerate(nest.flatten(buffer_dict))})

    def load(self, t=None):
        """Restore state from a checkpoint.

        Args:
            t: Checkpoint identifier forwarded to the checkpointer.

        Returns:
            The restored step counter, or 0 when no checkpoint was found.
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            self.t = 0
            return self.t
        self.pi.load_state_dict(state_dict['pi'])
        self.qf1.load_state_dict(state_dict['qf1'])
        self.qf2.load_state_dict(state_dict['qf2'])
        self.vf.load_state_dict(state_dict['vf'])
        # Checkpoints written before 'target_vf' was saved only contain
        # 'vf'; fall back to it so they still load.
        if 'target_vf' in state_dict:
            self.target_vf.load_state_dict(state_dict['target_vf'])
        else:
            self.target_vf.load_state_dict(state_dict['vf'])

        self.opt_pi.load_state_dict(state_dict['opt_pi'])
        self.opt_qf1.load_state_dict(state_dict['opt_qf1'])
        self.opt_qf2.load_state_dict(state_dict['opt_qf2'])
        self.opt_vf.load_state_dict(state_dict['opt_vf'])

        # Compare against None: the truth value of the tensor is False when
        # log_alpha equals 0, which would skip the restore. Also skip when
        # this instance has no log_alpha to restore into.
        saved_log_alpha = state_dict['log_alpha']
        if saved_log_alpha is not None and self.automatic_entropy_tuning:
            with torch.no_grad():
                self.log_alpha.copy_(saved_log_alpha)
            self.opt_alpha.load_state_dict(state_dict['opt_alpha'])
        misc.env_load_state_dict(self.env, state_dict['env'])
        self.t = state_dict['t']

        buffer_format = state_dict['buffer_format']
        # Read every array while the archive is open, then close the file
        # handle instead of leaving it to the garbage collector.
        with np.load(os.path.join(self.ckptr.ckptdir,
                                  'buffer.npz')) as archive:
            buffer_state = dict(archive)
        buffer_state = nest.flatten(buffer_state)
        self.buffer.load_state_dict(nest.pack_sequence_as(buffer_state,
                                                          buffer_format))
        self.data_manager.manual_reset()
        return self.t

    def close(self):
        """Close the environment, ignoring any error raised while doing so."""
        try:
            self.env.close()
        except Exception:
            # Deliberately best-effort: close() must never raise during
            # shutdown.
            pass
Esempio n. 18
0
class SAC(Algorithm):
    """SAC algorithm."""
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 qf_fn,
                 nenv=1,
                 optimizer=torch.optim.Adam,
                 buffer_size=10000,
                 frame_stack=1,
                 learning_starts=1000,
                 update_period=1,
                 batch_size=256,
                 policy_lr=1e-3,
                 qf_lr=1e-3,
                 gamma=0.99,
                 target_update_period=1,
                 policy_update_period=1,
                 target_smoothing_coef=0.005,
                 alpha=0.2,
                 automatic_entropy_tuning=True,
                 target_entropy=None,
                 gpu=True,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 log_period=1000):
        """Build networks, optimizers, replay buffers and entropy tuning.

        Args:
            logdir: Base directory; checkpoints go to ``<logdir>/ckpts``.
            env_fn: Called as ``env_fn(nenv=nenv)`` to build the env.
            policy_fn: Called with the frame-stacked env to build the policy.
            qf_fn: Called with the frame-stacked env, four times, to build
                the two Q-functions and their two target copies.
            nenv: Passed to ``env_fn``; also the number of replay buffers.
            optimizer: Optimizer class called as
                ``optimizer(params, lr=...)``.
            buffer_size: Passed to each ``ReplayBuffer``.
            frame_stack: Passed to ``VecFrameStack`` and ``ReplayBuffer``.
            learning_starts: Passed to ``ReplayBufferDataManager``.
            update_period: Passed to ``ReplayBufferDataManager``; also the
                granularity the two periods below are rounded to.
            batch_size: Stored for sampling in ``step``.
            policy_lr: Learning rate for the policy and ``log_alpha``
                optimizers.
            qf_lr: Learning rate for both Q-function optimizers.
            gamma: Discount used in the Q target.
            target_update_period: Rounded down to a multiple of
                ``update_period`` (and to at least ``update_period``).
            policy_update_period: Rounded the same way.
            target_smoothing_coef: Passed to ``soft_target_update``.
            alpha: Entropy coefficient; with automatic tuning it only sets
                the initial value ``log(alpha)`` of ``log_alpha``.
            automatic_entropy_tuning: Whether ``log_alpha`` is learned.
            target_entropy: Entropy target for automatic tuning. ``None``
                selects minus the total number of action dimensions.
            gpu: Use 'cuda:0' when CUDA is available, otherwise 'cpu'.
            eval_num_episodes: Episode count passed to ``rl_evaluate``.
            record_num_episodes: Episode count passed to ``rl_record``.
            log_period: Period (in steps) used to throttle loss logging.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.frame_stack = frame_stack
        self.learning_starts = learning_starts
        self.update_period = update_period
        self.batch_size = batch_size
        # Snap both periods to multiples of update_period so that
        # `t % period == 0` can coincide with update steps.
        if target_update_period < self.update_period:
            self.target_update_period = self.update_period
        else:
            self.target_update_period = target_update_period - (
                target_update_period % self.update_period)
        if policy_update_period < self.update_period:
            self.policy_update_period = self.update_period
        else:
            self.policy_update_period = policy_update_period - (
                policy_update_period % self.update_period)
        self.target_smoothing_coef = target_smoothing_coef
        self.log_period = log_period

        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        # The networks are built against the frame-stacked view of the env.
        eval_env = VecFrameStack(self.env, self.frame_stack)
        self.pi = policy_fn(eval_env)
        self.qf1 = qf_fn(eval_env)
        self.qf2 = qf_fn(eval_env)
        self.target_qf1 = qf_fn(eval_env)
        self.target_qf2 = qf_fn(eval_env)

        self.pi.to(self.device)
        self.qf1.to(self.device)
        self.qf2.to(self.device)
        self.target_qf1.to(self.device)
        self.target_qf2.to(self.device)

        self.opt_pi = optimizer(self.pi.parameters(), lr=policy_lr)
        self.opt_qf1 = optimizer(self.qf1.parameters(), lr=qf_lr)
        self.opt_qf2 = optimizer(self.qf2.parameters(), lr=qf_lr)

        # Targets start as exact copies of the online Q-functions.
        self.target_qf1.load_state_dict(self.qf1.state_dict())
        self.target_qf2.load_state_dict(self.qf2.state_dict())

        # One replay buffer per environment, wrapped into a batched buffer.
        self.buffer = BatchedReplayBuffer(
            *
            [ReplayBuffer(buffer_size, frame_stack) for _ in range(self.nenv)])
        self.data_manager = ReplayBufferDataManager(self.buffer, self.env,
                                                    SACActor(self.pi),
                                                    self.device,
                                                    self.learning_starts,
                                                    self.update_period)

        self.alpha = alpha
        self.automatic_entropy_tuning = automatic_entropy_tuning
        if self.automatic_entropy_tuning:
            # `is not None` (not truthiness) so that an explicit target
            # entropy of 0 is honoured instead of being replaced by the
            # heuristic below.
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                # Default: minus the number of action dimensions, summed
                # over every (possibly nested) action space.
                target_entropies = nest.map_structure(
                    lambda space: -np.prod(space.shape).item(),
                    misc.unpack_space(self.env.action_space))
                self.target_entropy = sum(nest.flatten(target_entropies))

            self.log_alpha = torch.tensor(np.log([self.alpha]),
                                          requires_grad=True,
                                          device=self.device,
                                          dtype=torch.float32)
            self.opt_alpha = optimizer([self.log_alpha], lr=policy_lr)
        else:
            self.target_entropy = None
            self.log_alpha = None
            self.opt_alpha = None

        self.mse_loss = torch.nn.MSELoss()

        self.t = 0

    def loss(self, batch):
        """Compute the SAC losses for one batch and log diagnostics.

        When automatic entropy tuning is enabled, this method also performs
        the optimizer step for ``log_alpha`` itself, before the Q-function
        and policy losses are built.

        Args:
            batch: Mapping read at the keys 'obs', 'action', 'next_obs',
                'reward' and 'done'.

        Returns:
            Tuple ``(pi_loss, qf1_loss, qf2_loss)``. ``pi_loss`` is ``None``
            when ``self.t`` is not a multiple of
            ``self.policy_update_period``.
        """
        pi_out = self.pi(batch['obs'], reparameterization_trick=True)
        logp = pi_out.dist.log_prob(pi_out.action)
        # Q-values of the actions stored in the batch (used for the Q loss).
        q1 = self.qf1(batch['obs'], batch['action']).value
        q2 = self.qf2(batch['obs'], batch['action']).value

        # alpha loss
        if self.automatic_entropy_tuning:
            # logp is detached, so only log_alpha receives a gradient here.
            ent_error = logp + self.target_entropy
            alpha_loss = -(self.log_alpha * ent_error.detach()).mean()
            self.opt_alpha.zero_grad()
            alpha_loss.backward()
            self.opt_alpha.step()
            # alpha is read after the step, i.e. it is the updated value.
            alpha = self.log_alpha.exp()
        else:
            alpha = self.alpha
            alpha_loss = 0

        # qf loss
        # The target is built without gradient tracking from the two target
        # Q-functions evaluated at an action sampled for next_obs.
        with torch.no_grad():
            next_pi_out = self.pi(batch['next_obs'])
            next_ac_logp = next_pi_out.dist.log_prob(next_pi_out.action)
            q1_next = self.target_qf1(batch['next_obs'],
                                      next_pi_out.action).value
            q2_next = self.target_qf2(batch['next_obs'],
                                      next_pi_out.action).value
            qnext = torch.min(q1_next, q2_next) - alpha * next_ac_logp
            qtarg = batch['reward'] + (1.0 -
                                       batch['done']) * self.gamma * qnext

        # Guard against silent broadcasting inside the MSE losses.
        assert qtarg.shape == q1.shape
        assert qtarg.shape == q2.shape
        qf1_loss = self.mse_loss(q1, qtarg)
        qf2_loss = self.mse_loss(q2, qtarg)

        # pi loss
        # Only built on policy-update steps; otherwise None is returned.
        pi_loss = None
        if self.t % self.policy_update_period == 0:
            # Q-values of the freshly sampled policy actions.
            q1_pi = self.qf1(batch['obs'], pi_out.action).value
            q2_pi = self.qf2(batch['obs'], pi_out.action).value
            min_q_pi = torch.min(q1_pi, q2_pi)
            assert min_q_pi.shape == logp.shape
            pi_loss = (alpha * logp - min_q_pi).mean()

            # log pi loss about as frequently as other losses
            if self.t % self.log_period < self.policy_update_period:
                logger.add_scalar('loss/pi', pi_loss, self.t, time.time())

        # Log the remaining diagnostics roughly once per log_period.
        if self.t % self.log_period < self.update_period:
            if self.automatic_entropy_tuning:
                logger.add_scalar('alg/log_alpha',
                                  self.log_alpha.detach().cpu().numpy(),
                                  self.t, time.time())
                scalars = {
                    "target": self.target_entropy,
                    "entropy": -torch.mean(logp.detach()).cpu().numpy().item()
                }
                logger.add_scalars('alg/entropy', scalars, self.t, time.time())
            else:
                logger.add_scalar(
                    'alg/entropy',
                    -torch.mean(logp.detach()).cpu().numpy().item(), self.t,
                    time.time())
            logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time())
            logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time())
            logger.add_scalar('alg/qf1',
                              q1.mean().detach().cpu().numpy(), self.t,
                              time.time())
            logger.add_scalar('alg/qf2',
                              q2.mean().detach().cpu().numpy(), self.t,
                              time.time())
        return pi_loss, qf1_loss, qf2_loss

    def step(self):
        """Collect data, then update the targets and networks when due.

        Advances ``self.t`` by the value returned from
        ``self.data_manager.step_until_update()``. The target Q-functions
        are soft-updated when ``self.t`` is a multiple of
        ``self.target_update_period``; the Q-functions (and, when
        ``loss`` returns a policy loss, the policy) are optimized when
        ``self.t`` is a multiple of ``self.update_period``.

        Returns:
            The updated step counter ``self.t``.
        """
        self.t += self.data_manager.step_until_update()
        if self.t % self.target_update_period == 0:
            soft_target_update(self.target_qf1, self.qf1,
                               self.target_smoothing_coef)
            soft_target_update(self.target_qf2, self.qf2,
                               self.target_smoothing_coef)

        if self.t % self.update_period == 0:
            batch = self.data_manager.sample(self.batch_size)

            pi_loss, qf1_loss, qf2_loss = self.loss(batch)

            # loss() signals "no policy update this step" with None. Test
            # identity rather than truthiness: a loss tensor that happens
            # to equal 0.0 is falsy and would wrongly skip the update.
            if pi_loss is not None:
                self.opt_pi.zero_grad()
                pi_loss.backward()
                self.opt_pi.step()

            self.opt_qf1.zero_grad()
            qf1_loss.backward()
            self.opt_qf1.step()

            self.opt_qf2.zero_grad()
            qf2_loss.backward()
            self.opt_qf2.step()

        return self.t

    def evaluate(self):
        """Run evaluation and video recording, then resume training mode.

        The policy is evaluated through a frame-stacked view of
        ``self.env``. Output paths are ``<logdir>/eval/<ckpt name>.json``
        and ``<logdir>/video/<ckpt name>.mp4``, where the name comes from
        ``self.ckptr.format`` applied to ``self.t``. The mean episode
        reward and length returned by ``rl_evaluate`` are logged.
        """
        stacked_env = VecFrameStack(self.env, self.frame_stack)
        self.pi.eval()
        misc.set_env_to_eval_mode(stacked_env)

        # Evaluation episodes and their summary statistics.
        eval_dir = os.path.join(self.logdir, 'eval')
        os.makedirs(eval_dir, exist_ok=True)
        report_path = os.path.join(eval_dir,
                                   self.ckptr.format.format(self.t) + '.json')
        stats = rl_evaluate(stacked_env, self.pi, self.eval_num_episodes,
                            report_path, self.device)
        for tag, key in (('eval/mean_episode_reward', 'mean_reward'),
                         ('eval/mean_episode_length', 'mean_length')):
            logger.add_scalar(tag, stats[key], self.t, time.time())

        # Recorded episodes.
        video_dir = os.path.join(self.logdir, 'video')
        os.makedirs(video_dir, exist_ok=True)
        video_path = os.path.join(video_dir,
                                  self.ckptr.format.format(self.t) + '.mp4')
        rl_record(stacked_env, self.pi, self.record_num_episodes, video_path,
                  self.device)

        # Back to training: policy mode, env mode, and a data manager reset.
        self.pi.train()
        misc.set_env_to_train_mode(self.env)
        self.data_manager.manual_reset()

    def save(self):
        """Checkpoint the training state and write the replay buffer.

        Networks, optimizers, entropy-tuning state, env state and the step
        counter go through ``self.ckptr.save``. The buffer contents are
        written to a single ``buffer.npz`` in the checkpoint directory; only
        their nested structure is kept in the checkpoint itself.
        """
        stateful = ('pi', 'qf1', 'qf2', 'target_qf1', 'target_qf2',
                    'opt_pi', 'opt_qf1', 'opt_qf2')
        state_dict = {
            name: getattr(self, name).state_dict()
            for name in stateful
        }
        # Entropy-tuning entries are None when tuning is disabled.
        if self.automatic_entropy_tuning:
            state_dict['log_alpha'] = self.log_alpha
        else:
            state_dict['log_alpha'] = None
        if self.automatic_entropy_tuning:
            state_dict['opt_alpha'] = self.opt_alpha.state_dict()
        else:
            state_dict['opt_alpha'] = None
        state_dict['env'] = misc.env_state_dict(self.env)
        state_dict['t'] = self.t

        buffer_dict = self.buffer.state_dict()
        state_dict['buffer_format'] = nest.get_structure(buffer_dict)
        self.ckptr.save(state_dict, self.t)

        # Save the buffer separately and only once (because it can be huge):
        # the flattened entries are stored under zero-padded index keys.
        buffer_path = os.path.join(self.ckptr.ckptdir, 'buffer.npz')
        flat_entries = {
            f'{i:04d}': x
            for i, x in enumerate(nest.flatten(buffer_dict))
        }
        np.savez(buffer_path, **flat_entries)

    def load(self, t=None):
        """Restore training state from a checkpoint and the saved buffer.

        Args:
            t: Forwarded unchanged to ``self.ckptr.load``.

        Returns:
            The restored step counter ``self.t``; 0 when
            ``self.ckptr.load`` returns ``None`` (nothing is restored then).
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            self.t = 0
            return self.t
        self.pi.load_state_dict(state_dict['pi'])
        self.qf1.load_state_dict(state_dict['qf1'])
        self.qf2.load_state_dict(state_dict['qf2'])
        self.target_qf1.load_state_dict(state_dict['target_qf1'])
        self.target_qf2.load_state_dict(state_dict['target_qf2'])

        self.opt_pi.load_state_dict(state_dict['opt_pi'])
        self.opt_qf1.load_state_dict(state_dict['opt_qf1'])
        self.opt_qf2.load_state_dict(state_dict['opt_qf2'])

        # Compare with None explicitly: the truth value of a one-element
        # tensor is its numeric value, so a saved log_alpha of exactly 0
        # (alpha == 1) would otherwise be skipped.
        if state_dict['log_alpha'] is not None:
            with torch.no_grad():
                self.log_alpha.copy_(state_dict['log_alpha'])
            self.opt_alpha.load_state_dict(state_dict['opt_alpha'])
        misc.env_load_state_dict(self.env, state_dict['env'])
        self.t = state_dict['t']

        # The buffer lives in a separate 'buffer.npz' next to the
        # checkpoints; 'buffer_format' carries the structure needed to
        # re-nest the flat arrays.
        buffer_format = state_dict['buffer_format']
        buffer_path = os.path.join(self.ckptr.ckptdir, 'buffer.npz')
        # NOTE(security): allow_pickle=True unpickles object arrays, which
        # can execute arbitrary code. Only load buffers from trusted runs.
        # Read every array while the archive is open, then close the file.
        with np.load(buffer_path, allow_pickle=True) as buffer_file:
            buffer_state = dict(buffer_file)
        buffer_state = nest.flatten(buffer_state)
        self.buffer.load_state_dict(
            nest.pack_sequence_as(buffer_state, buffer_format))
        self.data_manager.manual_reset()
        return self.t

    def close(self):
        """Close the environment on a best-effort basis.

        Any exception raised while accessing or closing ``self.env`` is
        reported as a warning through the standard ``logging`` module and
        is not propagated to the caller.
        """
        # Imported locally so this method does not depend on the file's
        # top-level imports.
        import logging

        try:
            self.env.close()
        except Exception as err:
            logging.getLogger(__name__).warning(
                'Failed to close environment: %r', err)
Esempio n. 19
0
class ConstrainedResidualPPO(Algorithm):
    """Constrained Residual PPO algorithm."""
    def __init__(
            self,
            logdir,
            env_fn,
            policy_fn,
            nenv=1,
            optimizer=torch.optim.Adam,
            lambda_lr=1e-4,
            lambda_init=100.,
            lr_decay_rate=1. / 3.16227766017,
            lr_decay_freq=20000000,
            l2_reg=True,
            reward_threshold=-0.05,
            rollout_length=128,
            batch_size=32,
            gamma=0.99,
            lambda_=0.95,
            norm_advantages=False,
            epochs_per_rollout=10,
            max_grad_norm=None,
            ent_coef=0.01,
            vf_coef=0.5,
            clip_param=0.2,
            base_actor_cls=None,
            policy_training_start=10000,
            lambda_training_start=100000,
            eval_num_episodes=1,
            record_num_episodes=1,
            wrapper_fn=None,  # additional wrappers for the env
            gpu=True):
        """Set up the env, policy, multiplier, optimizers and rollouts.

        The env from ``env_fn(nenv=nenv)`` is wrapped in an episode logger
        and a ``ResidualWrapper`` around ``base_actor_cls(env)``, then
        optionally in ``wrapper_fn``. The multiplier is stored as the raw
        parameter ``log_lambda_``; ``loss`` maps it through softplus.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv

        # Evaluation settings.
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes

        # PPO hyper-parameters.
        self.epochs_per_rollout = epochs_per_rollout
        self.max_grad_norm = max_grad_norm
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.clip_param = clip_param

        # Residual / constraint hyper-parameters.
        self.base_actor_cls = base_actor_cls
        self.policy_training_start = policy_training_start
        self.lambda_training_start = lambda_training_start
        self.lambda_lr = lambda_lr
        self.lr_decay_rate = lr_decay_rate
        self.lr_decay_freq = lr_decay_freq
        self.l2_reg = l2_reg
        self.reward_threshold = reward_threshold

        use_cuda = gpu and torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')

        # Environment stack: logger -> residual wrapper -> user wrappers.
        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        self.env = ResidualWrapper(self.env, self.base_actor_cls(self.env))
        if wrapper_fn:
            self.env = wrapper_fn(self.env)

        self.pi = policy_fn(self.env).to(self.device)
        self.opt = optimizer(self.pi.parameters())
        # Remember the optimizer's initial learning rate for later decay.
        self.pi_lr = self.opt.param_groups[0]['lr']

        # Below 10 the initial value is passed through the inverse of
        # softplus, log(exp(x) - 1); from 10 upwards it is used as is.
        if lambda_init < 10:
            raw_lambda = np.log(np.exp(lambda_init) - 1)
        else:
            raw_lambda = lambda_init
        self.log_lambda_ = nn.Parameter(
            torch.Tensor([raw_lambda]).to(self.device))
        self.opt_l = optimizer([self.log_lambda_], lr=lambda_lr)

        self._actor = ResidualPPOActor(self.pi, policy_training_start)
        rollout_options = {
            'rollout_length': rollout_length,
            'batch_size': batch_size,
            'gamma': gamma,
            'lambda_': lambda_,
            'norm_advantages': norm_advantages,
        }
        self.data_manager = RolloutDataManager(self.env, self._actor,
                                               self.device, **rollout_options)

        self.mse = nn.MSELoss(reduction='none')
        self.huber = nn.SmoothL1Loss()

        self.t = 0

    def loss(self, batch):
        """Build all loss terms for one minibatch.

        Before ``self.policy_training_start`` the policy, entropy and
        regularizer terms are zero tensors; before the later of the policy
        and lambda start steps the 'lambda' term is a zero tensor too.

        Args:
            batch: Mapping read at 'obs', 'action', 'logp', 'atarg',
                'vtarg', 'vpred', 'reward' and 'mask' (plus 'state' when the
                data manager is recurrent).

        Returns:
            Dict with keys 'pi', 'value', 'entropy', 'reg', 'lambda' and
            'total', in that order.
        """
        if self.data_manager.recurrent:
            outs = self.pi(batch['obs'], batch['state'], batch['mask'])
        else:
            outs = self.pi(batch['obs'])

        def zero():
            return torch.Tensor([0.0]).to(self.device)

        warmup = self.t < self.policy_training_start

        # Clipped surrogate policy objective.
        if warmup:
            pi_loss = zero()
        else:
            new_logp = outs.dist.log_prob(batch['action'])
            assert new_logp.shape == batch['logp'].shape
            ratio = torch.exp(new_logp - batch['logp'])
            assert ratio.shape == batch['atarg'].shape
            unclipped = ratio * batch['atarg']
            clipped = torch.clamp(ratio, 1.0 - self.clip_param,
                                  1.0 + self.clip_param) * batch['atarg']
            pi_loss = -torch.min(unclipped, clipped).mean()

        # Clipped value loss: element-wise worst case of the two errors.
        value_err = 0.5 * self.mse(outs.value, batch['vtarg'])
        value_clipped = batch['vpred'] + (outs.value - batch['vpred']).clamp(
            -self.clip_param, self.clip_param)
        value_err_clipped = 0.5 * self.mse(value_clipped, batch['vtarg'])
        vf_loss = torch.max(value_err, value_err_clipped).mean()

        # Entropy term and penalty on the size of sampled residual actions.
        if warmup:
            ent_loss = zero()
            reg_loss = zero()
        else:
            ent_loss = outs.dist.entropy().mean()
            residual = outs.dist.rsample()
            if self.l2_reg:
                reg_loss = residual.pow(2).sum(dim=-1).mean()
            else:  # huber loss on the residual norm
                residual_norm = torch.norm(residual, dim=-1)
                reg_loss = self.huber(residual_norm,
                                      torch.zeros_like(residual_norm))

        # Constrained part: softplus keeps the multiplier positive.
        lambda_ = F.softplus(self.log_lambda_)
        logger.add_scalar('alg/lambda', lambda_, self.t, time.time())
        logger.add_scalar('alg/lambda_', self.log_lambda_, self.t, time.time())
        lambda_start = max(self.policy_training_start,
                           self.lambda_training_start)
        if self.t < lambda_start:
            lambda_loss = zero()
        else:
            # presumably counts episode ends in the batch -- TODO confirm
            # the meaning of 'mask'
            neps = (1.0 - batch['mask']).sum()
            reward_gap = batch['reward'].sum() - self.reward_threshold * neps
            lambda_loss = lambda_ * reward_gap / batch['reward'].size()[0]

        # Once the policy trains, mix regularizer and policy objective with
        # weights 1 / (1 + lambda) and lambda / (1 + lambda).
        if warmup:
            combined_pi = pi_loss
        else:
            combined_pi = (reg_loss + lambda_ * pi_loss) / (1. + lambda_)
        total = (combined_pi + self.vf_coef * vf_loss -
                 self.ent_coef * ent_loss)

        return {
            'pi': combined_pi,
            'value': vf_loss,
            'entropy': ent_loss,
            'reg': reg_loss,
            'lambda': lambda_loss,
            'total': total,
        }

    def step(self):
        """Collect one rollout and run the optimization epochs on it.

        Learning rates of both optimizers are decayed by
        ``lr_decay_rate`` once per ``lr_decay_freq`` steps. For each
        minibatch the multiplier is updated from ``loss['lambda']`` (once
        ``self.t`` has reached both start thresholds) and the policy from
        ``loss['total']``. Mean losses and learning rates are logged.

        Returns:
            The updated step counter ``self.t``.
        """
        self.pi.train()
        # adjust learning rate
        lr_frac = self.lr_decay_rate**(self.t // self.lr_decay_freq)
        for g in self.opt.param_groups:
            g['lr'] = self.pi_lr * lr_frac
        for g in self.opt_l.param_groups:
            g['lr'] = self.lambda_lr * lr_frac

        self.data_manager.rollout()
        self.t += self.data_manager.rollout_length * self.nenv
        # self.t does not change inside the loops below.
        train_lambda = self.t >= max(self.policy_training_start,
                                     self.lambda_training_start)
        losses = {}
        for _ in range(self.epochs_per_rollout):
            for batch in self.data_manager.sampler():
                loss = self.loss(batch)
                if not losses:
                    losses = {k: [] for k in loss}
                for k, v in loss.items():
                    losses[k].append(v.detach().cpu().numpy())

                # Compute BOTH gradients before stepping either optimizer.
                # Stepping opt_l first would modify log_lambda_ in place
                # while loss['total'] still needs it for its backward pass,
                # which autograd rejects with a RuntimeError.
                lambda_grad = None
                if train_lambda:
                    self.opt_l.zero_grad()
                    loss['lambda'].backward(retain_graph=True)
                    grad = self.log_lambda_.grad
                    # Keep a copy: the backward pass of loss['total'] below
                    # also accumulates into log_lambda_.grad.
                    lambda_grad = None if grad is None else grad.clone()

                self.opt.zero_grad()
                loss['total'].backward()
                if self.max_grad_norm:
                    nn.utils.clip_grad_norm_(self.pi.parameters(),
                                             self.max_grad_norm)
                self.opt.step()

                if train_lambda:
                    # Update the multiplier from loss['lambda'] only.
                    self.log_lambda_.grad = lambda_grad
                    self.opt_l.step()
        for k, v in losses.items():
            logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time())
        logger.add_scalar('alg/lr_pi', self.opt.param_groups[0]['lr'], self.t,
                          time.time())
        logger.add_scalar('alg/lr_lambda', self.opt_l.param_groups[0]['lr'],
                          self.t, time.time())
        return self.t

    def evaluate(self):
        """Evaluate the policy on ``self.env`` and log summary statistics.

        The output path ``<logdir>/eval/<ckpt name>.json`` is handed to
        ``rl_evaluate``; the mean episode reward and length it returns are
        logged. The policy and env are switched back to training mode
        afterwards.
        """
        self.pi.eval()
        misc.set_env_to_eval_mode(self.env)

        # Evaluation episodes and their summary statistics.
        eval_dir = os.path.join(self.logdir, 'eval')
        os.makedirs(eval_dir, exist_ok=True)
        report_path = os.path.join(eval_dir,
                                   self.ckptr.format.format(self.t) + '.json')
        stats = rl_evaluate(self.env, self.pi, self.eval_num_episodes,
                            report_path, self.device)
        for tag, key in (('eval/mean_episode_reward', 'mean_reward'),
                         ('eval/mean_episode_length', 'mean_length')):
            logger.add_scalar(tag, stats[key], self.t, time.time())

        # Video recording is disabled here. The disabled code created
        # <logdir>/video, built the path <ckpt name> + '.mp4' inside it and
        # called rl_record(self.env, self.pi, self.record_num_episodes,
        # outfile, self.device).

        self.pi.train()
        misc.set_env_to_train_mode(self.env)

    def save(self):
        """Checkpoint policy, optimizers, multiplier, env, actor and step.

        The collected state is handed to ``self.ckptr.save`` together with
        the current step counter. The multiplier is stored as the
        ``log_lambda_`` parameter object itself.
        """
        snapshot = dict(
            pi=self.pi.state_dict(),
            opt=self.opt.state_dict(),
            lambda_=self.log_lambda_,
            opt_l=self.opt_l.state_dict(),
            env=misc.env_state_dict(self.env),
            _actor=self._actor.state_dict(),
            t=self.t,
        )
        self.ckptr.save(snapshot, self.t)

    def load(self, t=None):
        """Restore the state written by ``save``.

        Args:
            t: Forwarded unchanged to ``self.ckptr.load``.

        Returns:
            The restored step counter ``self.t``; 0 when
            ``self.ckptr.load`` returns ``None``.
        """
        ckpt = self.ckptr.load(t)
        if ckpt is None:
            # Nothing to restore: start counting from zero.
            self.t = 0
        else:
            for name in ('pi', 'opt', 'opt_l'):
                getattr(self, name).load_state_dict(ckpt[name])
            # Overwrite the multiplier's value in place.
            self.log_lambda_.data.copy_(ckpt['lambda_'])
            misc.env_load_state_dict(self.env, ckpt['env'])
            self._actor.load_state_dict(ckpt['_actor'])
            self.t = ckpt['t']
        return self.t

    def close(self):
        """Close the environment on a best-effort basis.

        Any exception raised while accessing or closing ``self.env`` is
        reported as a warning through the standard ``logging`` module and
        is not propagated to the caller.
        """
        # Imported locally so this method does not depend on the file's
        # top-level imports.
        import logging

        try:
            self.env.close()
        except Exception as err:
            logging.getLogger(__name__).warning(
                'Failed to close environment: %r', err)
Esempio n. 20
0
File: sac.py Progetto: takuma-ynd/dl
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 qf_fn,
                 vf_fn,
                 nenv=1,
                 optimizer=torch.optim.Adam,
                 buffer_size=10000,
                 frame_stack=1,
                 learning_starts=1000,
                 update_period=1,
                 batch_size=256,
                 policy_lr=1e-3,
                 qf_lr=1e-3,
                 vf_lr=1e-3,
                 policy_mean_reg_weight=1e-3,
                 gamma=0.99,
                 target_update_period=1,
                 policy_update_period=1,
                 target_smoothing_coef=0.005,
                 automatic_entropy_tuning=True,
                 reparameterization_trick=True,
                 target_entropy=None,
                 reward_scale=1,
                 gpu=True,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 log_period=1000):
        """Build networks, optimizers, replay buffer and entropy tuning.

        Args:
            logdir: Base directory; checkpoints go to ``<logdir>/ckpts``.
            env_fn: Called as ``env_fn(nenv=nenv)`` to build the env.
            policy_fn: Called with the frame-stacked env to build the policy.
            qf_fn: Called with the frame-stacked env, twice, to build the
                two Q-functions.
            vf_fn: Called with the frame-stacked env, twice, to build the
                value function and its target copy.
            nenv: Passed to ``env_fn``.
            optimizer: Optimizer class called as
                ``optimizer(params, lr=...)``.
            buffer_size: Passed to ``ReplayBuffer``.
            frame_stack: Passed to ``VecFrameStack`` and ``ReplayBuffer``.
            learning_starts: Passed to ``ReplayBufferDataManager``.
            update_period: Passed to ``ReplayBufferDataManager``; also the
                granularity the two periods below are rounded to.
            batch_size: Stored as ``self.batch_size``.
            policy_lr: Learning rate for the policy and ``log_alpha``
                optimizers.
            qf_lr: Learning rate for both Q-function optimizers.
            vf_lr: Learning rate for the value-function optimizer.
            policy_mean_reg_weight: Stored as
                ``self.policy_mean_reg_weight``.
            gamma: Stored as ``self.gamma``.
            target_update_period: Rounded down to a multiple of
                ``update_period`` (and to at least ``update_period``).
            policy_update_period: Rounded the same way.
            target_smoothing_coef: Stored as ``self.target_smoothing_coef``.
            automatic_entropy_tuning: Whether ``log_alpha`` is learned.
            reparameterization_trick: Stored as ``self.rsample``.
            target_entropy: Entropy target for automatic tuning. ``None``
                selects ``log(1.5)`` for a 'Discrete' action space and
                minus the number of action dimensions otherwise.
            reward_scale: Stored as ``self.reward_scale``.
            gpu: Use 'cuda:0' when CUDA is available, otherwise 'cpu'.
            eval_num_episodes: Stored as ``self.eval_num_episodes``.
            record_num_episodes: Stored as ``self.record_num_episodes``.
            log_period: Stored as ``self.log_period``.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.frame_stack = frame_stack
        self.learning_starts = learning_starts
        self.update_period = update_period
        self.batch_size = batch_size
        # Snap both periods to multiples of update_period so that
        # `t % period == 0` can coincide with update steps.
        if target_update_period < self.update_period:
            self.target_update_period = self.update_period
        else:
            self.target_update_period = target_update_period - (
                                target_update_period % self.update_period)
        if policy_update_period < self.update_period:
            self.policy_update_period = self.update_period
        else:
            self.policy_update_period = policy_update_period - (
                                policy_update_period % self.update_period)
        self.rsample = reparameterization_trick
        self.reward_scale = reward_scale
        self.target_smoothing_coef = target_smoothing_coef
        self.log_period = log_period

        self.device = torch.device('cuda:0' if gpu and torch.cuda.is_available()
                                   else 'cpu')

        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        # The networks are built against the frame-stacked view of the env.
        eval_env = VecFrameStack(self.env, self.frame_stack)
        self.pi = policy_fn(eval_env)
        self.qf1 = qf_fn(eval_env)
        self.qf2 = qf_fn(eval_env)
        self.vf = vf_fn(eval_env)
        self.target_vf = vf_fn(eval_env)

        self.pi.to(self.device)
        self.qf1.to(self.device)
        self.qf2.to(self.device)
        self.vf.to(self.device)
        self.target_vf.to(self.device)

        self.opt_pi = optimizer(self.pi.parameters(), lr=policy_lr)
        self.opt_qf1 = optimizer(self.qf1.parameters(), lr=qf_lr)
        self.opt_qf2 = optimizer(self.qf2.parameters(), lr=qf_lr)
        self.opt_vf = optimizer(self.vf.parameters(), lr=vf_lr)
        self.policy_mean_reg_weight = policy_mean_reg_weight

        # The target starts as an exact copy of the online value function.
        self.target_vf.load_state_dict(self.vf.state_dict())

        self.buffer = ReplayBuffer(buffer_size, frame_stack)
        self.data_manager = ReplayBufferDataManager(self.buffer,
                                                    self.env,
                                                    SACActor(self.pi),
                                                    self.device,
                                                    self.learning_starts,
                                                    self.update_period)

        self.discrete = self.env.action_space.__class__.__name__ == 'Discrete'
        self.automatic_entropy_tuning = automatic_entropy_tuning
        if self.automatic_entropy_tuning:
            # `is not None` (not truthiness) so that an explicit target
            # entropy of 0 is honoured instead of being replaced by the
            # heuristic below.
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                # heuristic value from Tuomas
                if self.discrete:
                    self.target_entropy = np.log(1.5)
                else:
                    self.target_entropy = -np.prod(
                        self.env.action_space.shape).item()
            # log_alpha starts at 0, i.e. alpha starts at 1.
            self.log_alpha = torch.zeros(1, requires_grad=True,
                                         device=self.device)
            self.opt_alpha = optimizer([self.log_alpha], lr=policy_lr)
        else:
            self.target_entropy = None
            self.log_alpha = None
            self.opt_alpha = None

        self.qf_criterion = torch.nn.MSELoss()
        self.vf_criterion = torch.nn.MSELoss()

        self.t = 0
Esempio n. 21
0
class PPO(Algorithm):
    """PPO algorithm.

    Trains a single network ``self.pi`` (which produces both an action
    distribution and a value estimate) on rollouts gathered by a
    ``RolloutDataManager``. The objective (see ``loss``) combines a clipped
    probability-ratio policy loss, a clipped value loss and an entropy bonus.
    """

    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 nenv=1,
                 optimizer=torch.optim.Adam,
                 batch_size=32,
                 rollout_length=None,
                 gamma=0.99,
                 lambda_=0.95,
                 norm_advantages=False,
                 epochs_per_rollout=10,
                 max_grad_norm=None,
                 ent_coef=0.01,
                 vf_coef=0.5,
                 clip_param=0.2,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 gpu=True):
        """Init.

        Args:
            logdir: root directory; checkpoints go to ``<logdir>/ckpts`` and
                ``evaluate`` writes to ``<logdir>/eval`` and
                ``<logdir>/video``.
            env_fn: callable invoked as ``env_fn(nenv=nenv)`` to build the
                environment.
            policy_fn: callable invoked as ``policy_fn(env)``; its result is
                moved to the selected device and used as the policy.
            nenv: forwarded to ``env_fn``.
            optimizer: optimizer class, called with the policy parameters
                only (library-default hyperparameters).
            batch_size: forwarded to ``RolloutDataManager``.
            rollout_length: forwarded to ``RolloutDataManager``.
            gamma: forwarded to both ``VecRewardNormWrapper`` and
                ``RolloutDataManager``.
            lambda_: forwarded to ``RolloutDataManager``.
            norm_advantages: forwarded to ``RolloutDataManager``.
            epochs_per_rollout: number of passes over the sampler per
                ``step`` call.
            max_grad_norm: if truthy, gradients are clipped to this norm.
            ent_coef: weight of the entropy term in the total loss.
            vf_coef: weight of the value term in the total loss.
            clip_param: clipping range used for both the probability ratio
                and the change in value prediction.
            eval_num_episodes: episode count passed to ``rl_evaluate``.
            record_num_episodes: episode count passed to ``rl_record``.
            gpu: use ``cuda:0`` when True and CUDA is available, else CPU.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.epochs_per_rollout = epochs_per_rollout
        self.max_grad_norm = max_grad_norm
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.clip_param = clip_param
        self.device = torch.device('cuda:0' if gpu and torch.cuda.is_available()
                                   else 'cpu')

        self.env = VecEpisodeLogger(VecRewardNormWrapper(env_fn(nenv=nenv),
                                                         gamma))

        self.pi = policy_fn(self.env).to(self.device)
        self.opt = optimizer(self.pi.parameters())
        self.data_manager = RolloutDataManager(
            self.env,
            PPOActor(self.pi),
            self.device,
            batch_size=batch_size,
            rollout_length=rollout_length,
            gamma=gamma,
            lambda_=lambda_,
            norm_advantages=norm_advantages)

        # Element-wise losses are needed so the clipped and unclipped value
        # errors can be compared per sample in ``loss``.
        self.mse = nn.MSELoss(reduction='none')

        # Number of environment steps taken so far.
        self.t = 0

    def compute_kl(self):
        """Compute KL divergence of new and old policies.

        Iterates once over the rollout sampler and returns the
        sample-weighted running mean of ``KL(old || new)``. Returns ``0`` if
        the sampler yields no batches.
        """
        kl = 0
        n = 0
        # This is a logging-only diagnostic: running it without autograd
        # avoids building (and keeping alive) a graph for every minibatch.
        with torch.no_grad():
            for batch in self.data_manager.sampler():
                outs = self.pi(batch['obs'])
                old_dist = outs.dist.from_tensors(batch['dist'])
                k = old_dist.kl(outs.dist).mean()
                # Batch size is read from the leading dim of the first
                # action tensor (actions may be a nested structure).
                s = nest.flatten(batch['action'])[0].shape[0]
                kl = (n / (n + s)) * kl + (s / (n + s)) * k
                n += s
        return kl

    def loss(self, batch):
        """Compute loss.

        Args:
            batch: mapping with keys ``obs``, ``action``, ``logp``,
                ``atarg``, ``vtarg`` and ``vpred``.

        Returns:
            dict with tensors under ``'pi'``, ``'value'``, ``'entropy'`` and
            ``'total'``; ``'total'`` is the quantity to backpropagate.
        """
        outs = self.pi(batch['obs'])
        loss = {}

        # compute policy loss: pessimistic (min) of the unclipped and clipped
        # ratio-weighted advantage, negated for gradient descent.
        logp = outs.dist.log_prob(batch['action'])
        assert logp.shape == batch['logp'].shape
        ratio = torch.exp(logp - batch['logp'])
        assert ratio.shape == batch['atarg'].shape
        ploss1 = ratio * batch['atarg']
        ploss2 = torch.clamp(ratio, 1.0-self.clip_param,
                             1.0+self.clip_param) * batch['atarg']
        pi_loss = -torch.min(ploss1, ploss2).mean()
        loss['pi'] = pi_loss

        # compute value loss: the new prediction may move at most
        # ``clip_param`` away from the rollout-time prediction; the larger of
        # the clipped / unclipped errors is used per sample.
        vloss1 = 0.5 * self.mse(outs.value, batch['vtarg'])
        vpred_clipped = batch['vpred'] + (
            outs.value - batch['vpred']).clamp(-self.clip_param,
                                               self.clip_param)
        vloss2 = 0.5 * self.mse(vpred_clipped, batch['vtarg'])
        vf_loss = torch.max(vloss1, vloss2).mean()
        loss['value'] = vf_loss

        # compute entropy loss (subtracted below, i.e. entropy is rewarded)
        ent_loss = outs.dist.entropy().mean()
        loss['entropy'] = ent_loss

        tot_loss = pi_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss
        loss['total'] = tot_loss
        return loss

    def step(self):
        """Compute rollout, loss, and update model.

        Collects one rollout, runs ``epochs_per_rollout`` passes of minibatch
        updates over it, logs mean losses and diagnostics, and returns the
        updated step counter ``self.t``.
        """
        self.pi.train()
        self.t += self.data_manager.rollout()
        losses = {}
        for _ in range(self.epochs_per_rollout):
            for batch in self.data_manager.sampler():
                self.opt.zero_grad()
                loss = self.loss(batch)
                if not losses:
                    losses = {k: [] for k in loss}
                for k, v in loss.items():
                    losses[k].append(v.detach().cpu().numpy())
                loss['total'].backward()
                if self.max_grad_norm:
                    norm = nn.utils.clip_grad_norm_(self.pi.parameters(),
                                                    self.max_grad_norm)
                    logger.add_scalar('alg/grad_norm', norm, self.t,
                                      time.time())
                    logger.add_scalar('alg/grad_norm_clipped',
                                      min(norm, self.max_grad_norm),
                                      self.t, time.time())
                self.opt.step()
        for k, v in losses.items():
            logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time())

        # Value-function diagnostics over the whole rollout.
        data = self.data_manager.storage.get_rollout()
        value_error = data['vpred'].data - data['q_mc'].data
        logger.add_scalar('alg/value_error_mean',
                          value_error.mean().cpu().numpy(), self.t, time.time())
        logger.add_scalar('alg/value_error_std',
                          value_error.std().cpu().numpy(), self.t, time.time())

        logger.add_scalar('alg/kl', self.compute_kl(), self.t, time.time())
        return self.t

    def evaluate(self):
        """Evaluate model.

        Runs ``rl_evaluate`` and ``rl_record`` on ``self.env`` with the policy
        in eval mode, logs mean episode reward/length, then restores train
        mode on both the policy and the environment.
        """
        # NOTE(review): evaluation reuses the training environment; confirm
        # the rollout manager does not need to be reset afterwards.
        self.pi.eval()
        misc.set_env_to_eval_mode(self.env)

        # Eval policy
        os.makedirs(os.path.join(self.logdir, 'eval'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'eval',
                               self.ckptr.format.format(self.t) + '.json')
        stats = rl_evaluate(self.env, self.pi, self.eval_num_episodes,
                            outfile, self.device)
        logger.add_scalar('eval/mean_episode_reward', stats['mean_reward'],
                          self.t, time.time())
        logger.add_scalar('eval/mean_episode_length', stats['mean_length'],
                          self.t, time.time())

        # Record policy
        os.makedirs(os.path.join(self.logdir, 'video'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'video',
                               self.ckptr.format.format(self.t) + '.mp4')
        rl_record(self.env, self.pi, self.record_num_episodes, outfile,
                  self.device)

        self.pi.train()
        misc.set_env_to_train_mode(self.env)

    def save(self):
        """Save policy, optimizer, env state and step count as a checkpoint."""
        state_dict = {
            'pi': self.pi.state_dict(),
            'opt': self.opt.state_dict(),
            'env': misc.env_state_dict(self.env),
            't': self.t
        }
        self.ckptr.save(state_dict, self.t)

    def load(self, t=None):
        """Load state dict.

        Args:
            t: checkpoint identifier forwarded to ``Checkpointer.load``.

        Returns:
            The restored step count, or 0 when no checkpoint was found.
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            self.t = 0
            return self.t
        self.pi.load_state_dict(state_dict['pi'])
        self.opt.load_state_dict(state_dict['opt'])
        misc.env_load_state_dict(self.env, state_dict['env'])
        self.t = state_dict['t']
        return self.t

    def close(self):
        """Close environment."""
        # Deliberately best-effort: shutdown must not raise.
        try:
            self.env.close()
        except Exception:
            pass
Esempio n. 22
0
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 value_fn,
                 rnd_net,
                 nenv=1,
                 opt_pi=torch.optim.Adam,
                 opt_vf=torch.optim.Adam,
                 opt_rnd=torch.optim.Adam,
                 batch_size=32,
                 rollout_length=128,
                 gamma_ext=0.999,
                 gamma_int=0.99,
                 lambda_=0.95,
                 ent_coef=0.01,
                 rnd_coef=0.5,
                 rnd_subsample_rate=4,
                 norm_advantages=False,
                 epochs_pi=10,
                 epochs_vf=10,
                 max_grad_norm=None,
                 kl_target=0.01,
                 alpha=1.5,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 gpu=True):
        """Set up the RND-wrapped environment, networks and rollout manager.

        The environment built by ``env_fn(nenv=nenv)`` is wrapped in an
        episode logger and then in ``RNDVecEnv``; separate policy and value
        networks (each with its own optimizer) are created from the wrapped
        environment. ``gamma_ext`` and ``gamma_int`` are packed into one
        two-element tensor that is handed to the rollout manager.
        """
        # Output locations and construction arguments kept for later use.
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn, self.nenv = env_fn, nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes

        # Loss weights and optimization schedule.
        self.ent_coef, self.rnd_coef = ent_coef, rnd_coef
        self.rnd_subsample_rate = rnd_subsample_rate
        self.rnd_update_count = 0
        self.epochs_pi, self.epochs_vf = epochs_pi, epochs_vf
        self.max_grad_norm = max_grad_norm
        self.norm_advantages = norm_advantages

        # KL penalty state: the weight starts from a fixed initial value.
        self.kl_target = kl_target
        self.kl_weight = self.initial_kl_weight = 0.2
        self.alpha = alpha

        use_cuda = gpu and torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')

        # The RND module needs the raw observation shape, so it is created
        # before the environment is wrapped with it.
        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        obs_shape = self.env.observation_space.shape
        self.rnd = RND(rnd_net, opt_rnd, gamma_int, obs_shape, self.device)
        self.env = RNDVecEnv(self.env, self.rnd)

        self.pi = policy_fn(self.env).to(self.device)
        self.vf = value_fn(self.env).to(self.device)
        self.opt_pi = opt_pi(self.pi.parameters())
        self.opt_vf = opt_vf(self.vf.parameters())

        discounts = [gamma_ext, gamma_int]
        self.gamma = torch.Tensor(discounts).to(self.device)

        # Advantage normalization is always disabled inside the manager; the
        # ``norm_advantages`` flag is only stored on ``self``.
        actor = PPOActor(self.pi, self.vf)
        self.data_manager = RolloutDataManager(self.env,
                                               actor,
                                               self.device,
                                               batch_size=batch_size,
                                               rollout_length=rollout_length,
                                               gamma=self.gamma,
                                               lambda_=lambda_,
                                               norm_advantages=False)

        self.mse = nn.MSELoss()

        self.t = 0
Esempio n. 23
0
File: td3.py Progetto: amackeith/dl
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 qf_fn,
                 nenv=1,
                 optimizer=torch.optim.Adam,
                 buffer_size=int(1e6),
                 frame_stack=1,
                 learning_starts=10000,
                 update_period=1,
                 batch_size=256,
                 lr=3e-4,
                 policy_update_period=2,
                 target_smoothing_coef=0.005,
                 reward_scale=1,
                 gamma=0.99,
                 exploration_noise=0.1,
                 policy_noise=0.2,
                 policy_noise_clip=0.5,
                 gpu=True,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 log_period=1000):
        """Init.

        Args:
            logdir: root directory; checkpoints go to ``<logdir>/ckpts``.
            env_fn: callable invoked as ``env_fn(nenv=nenv)``.
            policy_fn: callable invoked with the frame-stacked env; called
                twice (online and target policy).
            qf_fn: callable invoked with the frame-stacked env; called four
                times (two online and two target Q-functions).
            nenv: forwarded to ``env_fn``.
            optimizer: optimizer class; one instance for the policy and one
                shared instance for both Q-functions, both with ``lr``.
            buffer_size, frame_stack: forwarded to ``ReplayBuffer``.
            learning_starts, update_period: forwarded to
                ``ReplayBufferDataManager``.
            policy_update_period: stored rounded down to a multiple of
                ``update_period`` (and never below ``update_period``).
            exploration_noise: forwarded to ``TD3Actor``.
            gpu: use ``cuda:0`` when True and CUDA is available, else CPU.
            Remaining arguments are stored unchanged on ``self``.

        Raises:
            ValueError: if the environment's action space is ``Discrete``.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.frame_stack = frame_stack
        self.learning_starts = learning_starts
        self.update_period = update_period
        # Keep the policy update period aligned with the update period so
        # that ``t % policy_update_period == 0`` can actually be hit.
        if policy_update_period < self.update_period:
            self.policy_update_period = self.update_period
        else:
            self.policy_update_period = policy_update_period - (
                policy_update_period % self.update_period)
        self.reward_scale = reward_scale
        self.target_smoothing_coef = target_smoothing_coef
        self.exploration_noise = exploration_noise
        self.policy_noise = policy_noise
        self.policy_noise_clip = policy_noise_clip
        self.log_period = log_period

        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

        self.policy_fn = policy_fn
        self.qf_fn = qf_fn
        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        # Fail fast: reject unsupported action spaces before any network,
        # optimizer or replay buffer is allocated.
        if self.env.action_space.__class__.__name__ == 'Discrete':
            raise ValueError("Action space must be continuous!")

        eval_env = VecFrameStack(self.env, self.frame_stack)
        self.pi = policy_fn(eval_env)
        self.qf1 = qf_fn(eval_env)
        self.qf2 = qf_fn(eval_env)
        self.target_pi = policy_fn(eval_env)
        self.target_qf1 = qf_fn(eval_env)
        self.target_qf2 = qf_fn(eval_env)

        self.pi.to(self.device)
        self.qf1.to(self.device)
        self.qf2.to(self.device)
        self.target_pi.to(self.device)
        self.target_qf1.to(self.device)
        self.target_qf2.to(self.device)

        self.optimizer = optimizer
        self.lr = lr
        self.opt_pi = optimizer(self.pi.parameters(), lr=lr)
        self.opt_qf = optimizer(list(self.qf1.parameters()) +
                                list(self.qf2.parameters()),
                                lr=lr)

        # Target networks start as exact copies of the online networks.
        self.target_pi.load_state_dict(self.pi.state_dict())
        self.target_qf1.load_state_dict(self.qf1.state_dict())
        self.target_qf2.load_state_dict(self.qf2.state_dict())

        self._actor = TD3Actor(self.pi, self.env.action_space,
                               exploration_noise)
        self.buffer = ReplayBuffer(buffer_size, frame_stack)
        self.data_manager = ReplayBufferDataManager(self.buffer, self.env,
                                                    self._actor, self.device,
                                                    self.learning_starts,
                                                    self.update_period)

        self.qf_criterion = torch.nn.MSELoss()

        # Action bounds as tensors on the training device.
        self.low = torch.from_numpy(self.env.action_space.low).to(self.device)
        self.high = torch.from_numpy(self.env.action_space.high).to(
            self.device)

        self.t = 0
Esempio n. 24
0
File: td3.py Progetto: amackeith/dl
class TD3(Algorithm):
    """TD3 algorithm.

    Maintains a policy, two Q-functions and target copies of all three.
    Q-functions are updated on every ``step``; the policy and the target
    networks are updated only when ``t`` is a multiple of
    ``policy_update_period``.
    """
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 qf_fn,
                 nenv=1,
                 optimizer=torch.optim.Adam,
                 buffer_size=int(1e6),
                 frame_stack=1,
                 learning_starts=10000,
                 update_period=1,
                 batch_size=256,
                 lr=3e-4,
                 policy_update_period=2,
                 target_smoothing_coef=0.005,
                 reward_scale=1,
                 gamma=0.99,
                 exploration_noise=0.1,
                 policy_noise=0.2,
                 policy_noise_clip=0.5,
                 gpu=True,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 log_period=1000):
        """Init.

        Args:
            logdir: root directory; checkpoints go to ``<logdir>/ckpts``.
            env_fn: callable invoked as ``env_fn(nenv=nenv)``.
            policy_fn: callable invoked with the frame-stacked env; called
                twice (online and target policy).
            qf_fn: callable invoked with the frame-stacked env; called four
                times (two online and two target Q-functions).
            nenv: forwarded to ``env_fn``.
            optimizer: optimizer class; one instance for the policy and one
                shared instance for both Q-functions, both with ``lr``.
            buffer_size, frame_stack: forwarded to ``ReplayBuffer``.
            learning_starts, update_period: forwarded to
                ``ReplayBufferDataManager``.
            batch_size: number of transitions sampled per ``step``.
            policy_update_period: stored rounded down to a multiple of
                ``update_period`` (and never below ``update_period``).
            target_smoothing_coef: forwarded to ``soft_target_update``.
            reward_scale: multiplier applied to rewards in the Q target.
            gamma: discount used in the Q target.
            exploration_noise: forwarded to ``TD3Actor``.
            policy_noise, policy_noise_clip: scale and clip range of the
                noise added to target actions.
            gpu: use ``cuda:0`` when True and CUDA is available, else CPU.
            eval_num_episodes: episode count passed to ``rl_evaluate``.
            record_num_episodes: episode count passed to ``rl_record``.
            log_period: controls how often losses are logged.

        Raises:
            ValueError: if the environment's action space is ``Discrete``.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.frame_stack = frame_stack
        self.learning_starts = learning_starts
        self.update_period = update_period
        # Keep the policy update period aligned with the update period so
        # that ``t % policy_update_period == 0`` can actually be hit.
        if policy_update_period < self.update_period:
            self.policy_update_period = self.update_period
        else:
            self.policy_update_period = policy_update_period - (
                policy_update_period % self.update_period)
        self.reward_scale = reward_scale
        self.target_smoothing_coef = target_smoothing_coef
        self.exploration_noise = exploration_noise
        self.policy_noise = policy_noise
        self.policy_noise_clip = policy_noise_clip
        self.log_period = log_period

        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

        self.policy_fn = policy_fn
        self.qf_fn = qf_fn
        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        # Fail fast: reject unsupported action spaces before any network,
        # optimizer or replay buffer is allocated.
        if self.env.action_space.__class__.__name__ == 'Discrete':
            raise ValueError("Action space must be continuous!")

        eval_env = VecFrameStack(self.env, self.frame_stack)
        self.pi = policy_fn(eval_env)
        self.qf1 = qf_fn(eval_env)
        self.qf2 = qf_fn(eval_env)
        self.target_pi = policy_fn(eval_env)
        self.target_qf1 = qf_fn(eval_env)
        self.target_qf2 = qf_fn(eval_env)

        self.pi.to(self.device)
        self.qf1.to(self.device)
        self.qf2.to(self.device)
        self.target_pi.to(self.device)
        self.target_qf1.to(self.device)
        self.target_qf2.to(self.device)

        self.optimizer = optimizer
        self.lr = lr
        self.opt_pi = optimizer(self.pi.parameters(), lr=lr)
        self.opt_qf = optimizer(list(self.qf1.parameters()) +
                                list(self.qf2.parameters()),
                                lr=lr)

        # Target networks start as exact copies of the online networks.
        self.target_pi.load_state_dict(self.pi.state_dict())
        self.target_qf1.load_state_dict(self.qf1.state_dict())
        self.target_qf2.load_state_dict(self.qf2.state_dict())

        self._actor = TD3Actor(self.pi, self.env.action_space,
                               exploration_noise)
        self.buffer = ReplayBuffer(buffer_size, frame_stack)
        self.data_manager = ReplayBufferDataManager(self.buffer, self.env,
                                                    self._actor, self.device,
                                                    self.learning_starts,
                                                    self.update_period)

        self.qf_criterion = torch.nn.MSELoss()

        # Action bounds as tensors on the training device.
        self.low = torch.from_numpy(self.env.action_space.low).to(self.device)
        self.high = torch.from_numpy(self.env.action_space.high).to(
            self.device)

        self.t = 0

    def loss(self, batch):
        """Loss function.

        Args:
            batch: mapping with keys ``obs``, ``action``, ``reward``,
                ``next_obs`` and ``done``.

        Returns:
            ``(pi_loss, qf_loss)``. ``pi_loss`` is a zero tensor on steps
            where the policy is not scheduled for an update.
        """
        # compute QFunction loss.
        with torch.no_grad():
            target_action = self.target_pi(batch['next_obs']).action
            noise = (torch.randn_like(target_action) *
                     self.policy_noise).clamp(-self.policy_noise_clip,
                                              self.policy_noise_clip)
            # NOTE(review): actions are clamped to [-1, 1] rather than to
            # self.low / self.high; confirm actions are normalized upstream.
            target_action = (target_action + noise).clamp(-1., 1.)
            target_q1 = self.target_qf1(batch['next_obs'], target_action).value
            target_q2 = self.target_qf2(batch['next_obs'], target_action).value
            # Use the smaller of the two target estimates.
            target_q = torch.min(target_q1, target_q2)
            qtarg = self.reward_scale * batch['reward'].float() + (
                (1.0 - batch['done']) * self.gamma * target_q)

        q1 = self.qf1(batch['obs'], batch['action']).value
        q2 = self.qf2(batch['obs'], batch['action']).value
        assert qtarg.shape == q1.shape
        assert qtarg.shape == q2.shape
        qf_loss = self.qf_criterion(q1, qtarg) + self.qf_criterion(q2, qtarg)

        # compute policy loss
        if self.t % self.policy_update_period == 0:
            action = self.pi(batch['obs'], deterministic=True).action
            q = self.qf1(batch['obs'], action).value
            pi_loss = -q.mean()
        else:
            pi_loss = torch.zeros_like(qf_loss)

        # log losses
        if self.t % self.log_period < self.update_period:
            logger.add_scalar('loss/qf', qf_loss, self.t, time.time())
            if self.t % self.policy_update_period == 0:
                logger.add_scalar('loss/pi', pi_loss, self.t, time.time())
        return pi_loss, qf_loss

    def step(self):
        """Step optimization.

        Advances the environment until the next update is due, samples a
        batch, and updates the Q-functions (every call) and the policy plus
        target networks (only on scheduled steps). Returns ``self.t``.
        """
        self.t += self.data_manager.step_until_update()
        batch = self.data_manager.sample(self.batch_size)

        pi_loss, qf_loss = self.loss(batch)
        update_policy = self.t % self.policy_update_period == 0

        # The policy loss is backpropagated through qf1, so its backward pass
        # must run BEFORE the Q optimizer modifies qf1's weights in place;
        # otherwise autograd sees stale saved tensors (RuntimeError on
        # PyTorch >= 1.5, silently wrong gradients before that).
        if update_policy:
            self.opt_pi.zero_grad()
            pi_loss.backward()

        # zero_grad also discards the gradients the policy loss just wrote
        # into qf1, so the Q update only sees the Q loss.
        self.opt_qf.zero_grad()
        qf_loss.backward()
        self.opt_qf.step()

        if update_policy:
            self.opt_pi.step()

            # update target networks
            soft_target_update(self.target_pi, self.pi,
                               self.target_smoothing_coef)
            soft_target_update(self.target_qf1, self.qf1,
                               self.target_smoothing_coef)
            soft_target_update(self.target_qf2, self.qf2,
                               self.target_smoothing_coef)
        return self.t

    def evaluate(self):
        """Evaluate.

        Runs ``rl_evaluate`` and ``rl_record`` on a frame-stacked view of the
        training environment, logs mean episode reward/length, restores train
        mode and resets the data manager.
        """
        eval_env = VecFrameStack(self.env, self.frame_stack)
        self.pi.eval()
        misc.set_env_to_eval_mode(eval_env)

        # Eval policy
        os.makedirs(os.path.join(self.logdir, 'eval'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'eval',
                               self.ckptr.format.format(self.t) + '.json')
        stats = rl_evaluate(eval_env, self.pi, self.eval_num_episodes, outfile,
                            self.device)
        logger.add_scalar('eval/mean_episode_reward', stats['mean_reward'],
                          self.t, time.time())
        logger.add_scalar('eval/mean_episode_length', stats['mean_length'],
                          self.t, time.time())

        # Record policy
        os.makedirs(os.path.join(self.logdir, 'video'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'video',
                               self.ckptr.format.format(self.t) + '.mp4')
        rl_record(eval_env, self.pi, self.record_num_episodes, outfile,
                  self.device)

        self.pi.train()
        misc.set_env_to_train_mode(self.env)
        # Evaluation stepped the shared environment, so the data manager's
        # view of it has to be reset before training resumes.
        self.data_manager.manual_reset()

    def save(self):
        """Save.

        Writes networks, optimizers, env state and step count through the
        checkpointer, and the replay buffer to a single ``buffer.npz`` that
        is overwritten on every save.
        """
        state_dict = {
            'pi': self.pi.state_dict(),
            'qf1': self.qf1.state_dict(),
            'qf2': self.qf2.state_dict(),
            'target_pi': self.target_pi.state_dict(),
            'target_qf1': self.target_qf1.state_dict(),
            'target_qf2': self.target_qf2.state_dict(),
            'opt_pi': self.opt_pi.state_dict(),
            'opt_qf': self.opt_qf.state_dict(),
            'env': misc.env_state_dict(self.env),
            't': self.t
        }
        buffer_dict = self.buffer.state_dict()
        # Only the nesting structure goes into the checkpoint; the arrays
        # themselves are stored flat in the npz file below.
        state_dict['buffer_format'] = nest.get_structure(buffer_dict)
        self.ckptr.save(state_dict, self.t)

        # save buffer separately and only once (because it can be huge)
        np.savez(
            os.path.join(self.ckptr.ckptdir, 'buffer.npz'),
            **{f'{i:04d}': x
               for i, x in enumerate(nest.flatten(buffer_dict))})

    def load(self, t=None):
        """Load.

        Args:
            t: checkpoint identifier forwarded to ``Checkpointer.load``.

        Returns:
            The restored step count, or 0 when no checkpoint was found.
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            self.t = 0
            return self.t
        self.pi.load_state_dict(state_dict['pi'])
        self.qf1.load_state_dict(state_dict['qf1'])
        self.qf2.load_state_dict(state_dict['qf2'])
        self.target_pi.load_state_dict(state_dict['target_pi'])
        self.target_qf1.load_state_dict(state_dict['target_qf1'])
        self.target_qf2.load_state_dict(state_dict['target_qf2'])
        self.opt_pi.load_state_dict(state_dict['opt_pi'])
        self.opt_qf.load_state_dict(state_dict['opt_qf'])
        misc.env_load_state_dict(self.env, state_dict['env'])
        self.t = state_dict['t']

        buffer_format = state_dict['buffer_format']
        # Read all arrays eagerly inside the context manager so the npz file
        # handle is closed instead of lingering until garbage collection.
        buffer_path = os.path.join(self.ckptr.ckptdir, 'buffer.npz')
        with np.load(buffer_path) as npz:
            buffer_state = dict(npz)
        buffer_state = nest.flatten(buffer_state)
        self.buffer.load_state_dict(
            nest.pack_sequence_as(buffer_state, buffer_format))
        self.data_manager.manual_reset()
        return self.t

    def close(self):
        """Close environment."""
        # Deliberately best-effort: shutdown must not raise.
        try:
            self.env.close()
        except Exception:
            pass
Esempio n. 25
0
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 qf_fn,
                 nenv=1,
                 optimizer=torch.optim.Adam,
                 buffer_size=10000,
                 frame_stack=1,
                 learning_starts=1000,
                 update_period=1,
                 batch_size=256,
                 policy_lr=1e-3,
                 qf_lr=1e-3,
                 gamma=0.99,
                 target_update_period=1,
                 policy_update_period=1,
                 target_smoothing_coef=0.005,
                 alpha=0.2,
                 automatic_entropy_tuning=True,
                 target_entropy=None,
                 gpu=True,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 log_period=1000):
        """Init.

        Args:
            logdir: root directory; checkpoints go to ``<logdir>/ckpts``.
            env_fn: callable invoked as ``env_fn(nenv=nenv)``.
            policy_fn: callable invoked with the frame-stacked env.
            qf_fn: callable invoked with the frame-stacked env; called four
                times (two online and two target Q-functions).
            nenv: forwarded to ``env_fn``; also the number of per-env replay
                buffers created.
            optimizer: optimizer class used for the policy, each Q-function
                and (when tuning is enabled) ``log_alpha``.
            buffer_size, frame_stack: forwarded to each ``ReplayBuffer``.
            learning_starts, update_period: forwarded to
                ``ReplayBufferDataManager``.
            policy_lr: learning rate for the policy and for ``log_alpha``.
            qf_lr: learning rate for both Q-functions.
            target_update_period, policy_update_period: each stored rounded
                down to a multiple of ``update_period`` (and never below it).
            alpha: stored on ``self``; its log initializes ``log_alpha`` when
                automatic tuning is enabled.
            automatic_entropy_tuning: when True, create a learnable
                ``log_alpha`` and its optimizer.
            target_entropy: explicit entropy target. ``None`` selects the
                default: the sum over action sub-spaces of minus the number
                of action dimensions.
            gpu: use ``cuda:0`` when True and CUDA is available, else CPU.
            Remaining arguments are stored unchanged on ``self``.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.frame_stack = frame_stack
        self.learning_starts = learning_starts
        self.update_period = update_period
        self.batch_size = batch_size
        # Align both periods with ``update_period`` so the modulo tests on
        # the step counter can actually be satisfied.
        if target_update_period < self.update_period:
            self.target_update_period = self.update_period
        else:
            self.target_update_period = target_update_period - (
                target_update_period % self.update_period)
        if policy_update_period < self.update_period:
            self.policy_update_period = self.update_period
        else:
            self.policy_update_period = policy_update_period - (
                policy_update_period % self.update_period)
        self.target_smoothing_coef = target_smoothing_coef
        self.log_period = log_period

        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        eval_env = VecFrameStack(self.env, self.frame_stack)
        self.pi = policy_fn(eval_env)
        self.qf1 = qf_fn(eval_env)
        self.qf2 = qf_fn(eval_env)
        self.target_qf1 = qf_fn(eval_env)
        self.target_qf2 = qf_fn(eval_env)

        self.pi.to(self.device)
        self.qf1.to(self.device)
        self.qf2.to(self.device)
        self.target_qf1.to(self.device)
        self.target_qf2.to(self.device)

        self.opt_pi = optimizer(self.pi.parameters(), lr=policy_lr)
        self.opt_qf1 = optimizer(self.qf1.parameters(), lr=qf_lr)
        self.opt_qf2 = optimizer(self.qf2.parameters(), lr=qf_lr)

        # Target networks start as exact copies of the online networks.
        self.target_qf1.load_state_dict(self.qf1.state_dict())
        self.target_qf2.load_state_dict(self.qf2.state_dict())

        # One replay buffer per environment instance.
        self.buffer = BatchedReplayBuffer(
            *
            [ReplayBuffer(buffer_size, frame_stack) for _ in range(self.nenv)])
        self.data_manager = ReplayBufferDataManager(self.buffer, self.env,
                                                    SACActor(self.pi),
                                                    self.device,
                                                    self.learning_starts,
                                                    self.update_period)

        self.alpha = alpha
        self.automatic_entropy_tuning = automatic_entropy_tuning
        if self.automatic_entropy_tuning:
            # Compare against None rather than truthiness so an explicit
            # target of 0 is honored instead of being replaced by the
            # heuristic.
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                target_entropies = nest.map_structure(
                    lambda space: -np.prod(space.shape).item(),
                    misc.unpack_space(self.env.action_space))
                self.target_entropy = sum(nest.flatten(target_entropies))

            self.log_alpha = torch.tensor(np.log([self.alpha]),
                                          requires_grad=True,
                                          device=self.device,
                                          dtype=torch.float32)
            self.opt_alpha = optimizer([self.log_alpha], lr=policy_lr)
        else:
            self.target_entropy = None
            self.log_alpha = None
            self.opt_alpha = None

        self.mse_loss = torch.nn.MSELoss()

        self.t = 0
Esempio n. 26
0
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 value_fn,
                 rnd_net,
                 ide_embedding_net,
                 ide_prediction_net,
                 ide_loss,
                 nenv=1,
                 opt_pi=torch.optim.Adam,
                 opt_vf=torch.optim.Adam,
                 opt_rnd=torch.optim.Adam,
                 opt_ide=torch.optim.Adam,
                 batch_size=32,
                 rollout_length=128,
                 gamma_ext=0.999,
                 gamma_int=0.99,
                 lambda_=0.95,
                 ent_coef=0.01,
                 ngu_coef=0.5,
                 ngu_buffer_capacity=1024,
                 ngu_subsample_freq=32,
                 ngu_updates=4,
                 ngu_batch_size=64,
                 policy_training_starts=500000,
                 buffer_size=100000,
                 norm_advantages=False,
                 epochs_pi=10,
                 epochs_vf=10,
                 max_grad_norm=None,
                 kl_target=0.01,
                 alpha=1.5,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 gpu=True):
        """Set up the NGU-wrapped environment, networks and data stores.

        Builds the RND and inverse-dynamics modules from the raw (logged)
        environment, combines them into an ``NGU`` object, wraps the
        environment with it, and then creates separate policy and value
        networks, their optimizers, a rollout manager and a replay buffer.
        """
        # Output locations and construction arguments kept for later use.
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn, self.nenv = env_fn, nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes

        # Loss weights and NGU training schedule.
        self.ent_coef, self.ngu_coef = ent_coef, ngu_coef
        self.ngu_updates = ngu_updates
        self.ngu_batch_size = ngu_batch_size
        self.ngu_subsample_freq = ngu_subsample_freq

        # Policy / value optimization schedule.
        self.epochs_pi, self.epochs_vf = epochs_pi, epochs_vf
        self.max_grad_norm = max_grad_norm
        self.norm_advantages = norm_advantages
        self.policy_training_starts = policy_training_starts

        # KL penalty state: the weight starts from a fixed initial value.
        self.kl_target = kl_target
        self.kl_weight = self.initial_kl_weight = 0.2
        self.alpha = alpha

        use_cuda = gpu and torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')
        # Disabled experiment (per-environment NGU coefficients):
        # self.ngu_coefs = (1 + np.tanh(np.arange(-4, 4, 8/nenv))) / 2. * ngu_coef
        # self.ngu_coefs = torch.from_numpy(self.ngu_coefs).to(self.device)

        # Intrinsic-reward modules are built on the logged environment
        # before it is wrapped.
        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        obs_shape = self.env.observation_space.shape
        self.rnd = RND(rnd_net, opt_rnd, gamma_int, obs_shape, self.device)
        self.ide = InverseDynamicsEmbedding(self.env, ide_embedding_net,
                                            ide_prediction_net, ide_loss,
                                            opt_ide, self.device)
        self.ngu = NGU(self.rnd, self.ide, ngu_buffer_capacity, self.device,
                       gamma=gamma_int)
        ngu_env = NGUVecEnv(self.env, self.ngu)
        self.env = VecActionRewardInObWrapper(ngu_env, reward_shape=(2, ))

        self.pi = policy_fn(self.env).to(self.device)
        self.vf = value_fn(self.env).to(self.device)
        self.opt_pi = opt_pi(self.pi.parameters())
        self.opt_vf = opt_vf(self.vf.parameters())

        discounts = [gamma_ext, gamma_int]
        self.gamma = torch.Tensor(discounts).to(self.device)

        # Advantage normalization is always disabled inside the manager; the
        # ``norm_advantages`` flag is only stored on ``self``.
        actor = PPOActor(self.pi, self.vf)
        self.data_manager = RolloutDataManager(self.env,
                                               actor,
                                               self.device,
                                               batch_size=batch_size,
                                               rollout_length=rollout_length,
                                               gamma=self.gamma,
                                               lambda_=lambda_,
                                               norm_advantages=False)

        self.buffer_size = buffer_size
        self.buffer = ReplayBuffer(buffer_size, 1)

        self.mse = nn.MSELoss()

        self.t = 0
Esempio n. 27
0
    def __init__(
            self,
            logdir,
            env_fn,
            policy_fn,
            nenv=1,
            optimizer=torch.optim.Adam,
            lambda_lr=1e-4,
            lambda_init=100.,
            lr_decay_rate=1. / 3.16227766017,
            lr_decay_freq=20000000,
            l2_reg=True,
            reward_threshold=-0.05,
            rollout_length=128,
            batch_size=32,
            gamma=0.99,
            lambda_=0.95,
            norm_advantages=False,
            epochs_per_rollout=10,
            max_grad_norm=None,
            ent_coef=0.01,
            vf_coef=0.5,
            clip_param=0.2,
            base_actor_cls=None,
            policy_training_start=10000,
            lambda_training_start=100000,
            eval_num_episodes=1,
            record_num_episodes=1,
            wrapper_fn=None,  # additional wrappers for the env
            gpu=True):
        """Init.

        Args:
            logdir: root directory; checkpoints go to ``<logdir>/ckpts``.
            env_fn: callable invoked as ``env_fn(nenv=nenv)``.
            policy_fn: callable invoked with the fully wrapped env.
            nenv: forwarded to ``env_fn``.
            optimizer: optimizer class; one instance for the policy (library
                default hyperparameters) and one for the lambda parameter
                (learning rate ``lambda_lr``).
            lambda_init: initial value of the learnable lambda parameter.
                Values below 10 are first mapped through
                ``log(exp(x) - 1)``.
            base_actor_cls: required callable, invoked with the environment;
                its result is handed to ``ResidualWrapper``.
            policy_training_start: forwarded to ``ResidualPPOActor``.
            wrapper_fn: optional callable applied to the residual-wrapped
                environment.
            rollout_length, batch_size, gamma, lambda_, norm_advantages:
                forwarded to ``RolloutDataManager``.
            gpu: use ``cuda:0`` when True and CUDA is available, else CPU.
            Remaining arguments are stored unchanged on ``self``.

        Raises:
            TypeError: if ``base_actor_cls`` is not callable (including the
                default ``None``).
        """
        # The default (None) can never work because the base actor is always
        # constructed below. Reject it up front, with the same exception type
        # the call itself would raise, before any environment is created.
        if not callable(base_actor_cls):
            raise TypeError(
                'base_actor_cls must be a callable taking the environment, '
                f'got {base_actor_cls!r}')

        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.epochs_per_rollout = epochs_per_rollout
        self.max_grad_norm = max_grad_norm
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.clip_param = clip_param
        self.base_actor_cls = base_actor_cls
        self.policy_training_start = policy_training_start
        self.lambda_training_start = lambda_training_start
        self.lambda_lr = lambda_lr
        self.lr_decay_rate = lr_decay_rate
        self.lr_decay_freq = lr_decay_freq
        self.l2_reg = l2_reg
        self.reward_threshold = reward_threshold
        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        self.env = ResidualWrapper(self.env, self.base_actor_cls(self.env))
        if wrapper_fn:
            self.env = wrapper_fn(self.env)

        self.pi = policy_fn(self.env).to(self.device)
        self.opt = optimizer(self.pi.parameters())
        # Remember the optimizer's initial learning rate.
        self.pi_lr = self.opt.param_groups[0]['lr']
        # log(exp(x) - 1) is the inverse of softplus; presumably the stored
        # parameter is passed through softplus where it is used -- TODO
        # confirm. Larger values are stored as given.
        if lambda_init < 10:
            lambda_init = np.log(np.exp(lambda_init) - 1)
        self.log_lambda_ = nn.Parameter(
            torch.Tensor([lambda_init]).to(self.device))
        self.opt_l = optimizer([self.log_lambda_], lr=lambda_lr)
        self._actor = ResidualPPOActor(self.pi, policy_training_start)
        self.data_manager = RolloutDataManager(self.env,
                                               self._actor,
                                               self.device,
                                               rollout_length=rollout_length,
                                               batch_size=batch_size,
                                               gamma=gamma,
                                               lambda_=lambda_,
                                               norm_advantages=norm_advantages)

        self.mse = nn.MSELoss(reduction='none')
        self.huber = nn.SmoothL1Loss()

        self.t = 0
Esempio n. 28
0
class MNISTTrainer(object):
    """Trainer for mnist.

    Downloads the MNIST train/test splits, trains ``model`` with a negative
    log-likelihood loss one batch per ``step`` call, and checkpoints the
    model, optimizer, sampler and counters through ``Checkpointer``.
    """
    def __init__(self, logdir, model, opt, batch_size, num_workers, gpu=True):
        """Build datasets, loaders, model and optimizer.

        Args:
            logdir: Output directory; checkpoints go in ``logdir/ckpts``.
            model: Network to train. Its output is passed to ``F.nll_loss``.
            opt: Callable that builds an optimizer from model parameters.
            batch_size: Batch size used by both the train and test loaders.
            num_workers: ``num_workers`` passed to both data loaders.
            gpu: Use 'cuda:0' when True and CUDA is available, else 'cpu'.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        self.data_train = datasets.MNIST('./data_train',
                                         download=True,
                                         transform=self.transform)
        self.data_test = datasets.MNIST('./data_test',
                                        download=True,
                                        train=False,
                                        transform=self.transform)
        self.sampler = StatefulSampler(self.data_train, shuffle=True)
        self.dtrain = DataLoader(self.data_train,
                                 sampler=self.sampler,
                                 batch_size=batch_size,
                                 num_workers=num_workers)
        self.dtest = DataLoader(self.data_test,
                                batch_size=batch_size,
                                num_workers=num_workers)
        # Active iterator over self.dtrain; created lazily by step().
        self._diter = None
        self.t = 0
        self.epochs = 0
        self.batch_size = batch_size

        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')
        self.model = model
        self.model.to(self.device)
        self.opt = opt(self.model.parameters())

    def _close_iterator(self):
        """Drop the active training iterator, finalizing it when possible."""
        diter, self._diter = self._diter, None
        if diter is None:
            return
        # Not every iterator type defines __del__, so look it up instead of
        # calling it unconditionally (which would raise AttributeError).
        finalize = getattr(diter, '__del__', None)
        if finalize is not None:
            finalize()

    def step(self):
        """Train on one batch.

        When the training iterator is exhausted, the epoch counter is
        incremented, the iterator is dropped, and no update is performed.

        Returns:
            The number of completed epochs.
        """
        # Get batch.
        if self._diter is None:
            self._diter = iter(self.dtrain)
        try:
            batch = next(self._diter)
        except StopIteration:
            self.epochs += 1
            self._diter = None
            return self.epochs
        batch = nest.map_structure(lambda x: x.to(self.device), batch)

        # compute loss
        x, y = batch
        self.model.train()
        loss = F.nll_loss(self.model(x), y)

        logger.add_scalar('train/loss',
                          loss.detach().cpu().numpy(), self.t, time.time())

        # update model
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        # increment step: advance by one batch, but never past the next
        # multiple of the dataset size.
        self.t += min(
            len(self.data_train) - (self.t % len(self.data_train)),
            self.batch_size)
        return self.epochs

    def evaluate(self):
        """Evaluate model.

        Logs the fraction of correctly classified test samples under the
        key 'test_accuracy'. Accuracy is accumulated per sample so a short
        final batch is not over-weighted.
        """
        self.model.eval()

        n_correct = 0
        n_total = 0
        with torch.no_grad():
            for batch in self.dtest:
                x, y = nest.map_structure(lambda x: x.to(self.device), batch)
                y_hat = self.model(x).argmax(-1)
                hits = (y_hat == y)
                n_correct += hits.sum().item()
                n_total += hits.numel()

        # An empty test loader yields NaN, as averaging an empty list did.
        accuracy = n_correct / n_total if n_total else float('nan')
        logger.add_scalar('test_accuracy', accuracy, self.epochs,
                          time.time())

    def save(self):
        """Checkpoint model, optimizer, sampler state and counters."""
        state_dict = {
            'model': self.model.state_dict(),
            'opt': self.opt.state_dict(),
            'sampler': self.sampler.state_dict(self._diter),
            't': self.t,
            'epochs': self.epochs,
        }
        self.ckptr.save(state_dict, self.t)

    def load(self, t=None):
        """Restore state from a checkpoint.

        Args:
            t: Checkpoint identifier forwarded to ``Checkpointer.load``.

        Returns:
            The number of completed epochs after loading (0 when no
            checkpoint was found).
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            self.t = 0
            self.epochs = 0
            return self.epochs
        self.model.load_state_dict(state_dict['model'])
        self.opt.load_state_dict(state_dict['opt'])
        self.sampler.load_state_dict(state_dict['sampler'])
        self.t = state_dict['t']
        self.epochs = state_dict['epochs']
        # Any live iterator predates the restored sampler state; discard it
        # so step() builds a fresh one.
        self._close_iterator()
        return self.epochs

    def close(self):
        """Close data iterator."""
        self._close_iterator()
Esempio n. 29
0
class PPO2NGU(Algorithm):
    """PPO with Never Give Up style intrinsic motivation.

    This version of ppo is described in https://arxiv.org/abs/1707.02286 and
    https://github.com/joschu/modular_rl/blob/master/modular_rl/ppo.py

    The policy loss uses a KL penalty with an adaptive weight rather than
    ratio clipping. Rewards and advantages carry two columns, which
    ``loss_pi`` combines as ``atarg[:, 0] + ngu_coef * atarg[:, 1]``.
    """
    def __init__(self,
                 logdir,
                 env_fn,
                 policy_fn,
                 value_fn,
                 rnd_net,
                 ide_embedding_net,
                 ide_prediction_net,
                 ide_loss,
                 nenv=1,
                 opt_pi=torch.optim.Adam,
                 opt_vf=torch.optim.Adam,
                 opt_rnd=torch.optim.Adam,
                 opt_ide=torch.optim.Adam,
                 batch_size=32,
                 rollout_length=128,
                 gamma_ext=0.999,
                 gamma_int=0.99,
                 lambda_=0.95,
                 ent_coef=0.01,
                 ngu_coef=0.5,
                 ngu_buffer_capacity=1024,
                 ngu_subsample_freq=32,
                 ngu_updates=4,
                 ngu_batch_size=64,
                 policy_training_starts=500000,
                 buffer_size=100000,
                 norm_advantages=False,
                 epochs_pi=10,
                 epochs_vf=10,
                 max_grad_norm=None,
                 kl_target=0.01,
                 alpha=1.5,
                 eval_num_episodes=1,
                 record_num_episodes=1,
                 gpu=True):
        """Init.

        Args:
            logdir: Output directory; checkpoints go in ``logdir/ckpts``,
                evaluation results in ``logdir/eval`` and videos in
                ``logdir/video``.
            env_fn: Called as ``env_fn(nenv=nenv)`` to build the environment.
            policy_fn: Called with the wrapped env to build the policy.
            value_fn: Called with the wrapped env to build the value net.
            rnd_net: Passed to ``RND``.
            ide_embedding_net: Passed to ``InverseDynamicsEmbedding``.
            ide_prediction_net: Passed to ``InverseDynamicsEmbedding``.
            ide_loss: Passed to ``InverseDynamicsEmbedding``.
            nenv: Number of parallel environments.
            opt_pi: Optimizer constructor for the policy.
            opt_vf: Optimizer constructor for the value function.
            opt_rnd: Optimizer constructor passed to ``RND``.
            opt_ide: Optimizer constructor passed to
                ``InverseDynamicsEmbedding``.
            batch_size: Passed to ``RolloutDataManager``.
            rollout_length: Passed to ``RolloutDataManager``.
            gamma_ext: First entry of the discount tensor given to
                ``RolloutDataManager``.
            gamma_int: Second entry of that tensor; also given to ``RND``
                and ``NGU``.
            lambda_: Passed to ``RolloutDataManager``.
            ent_coef: Weight of the entropy term in the policy loss.
            ngu_coef: Weight of the second advantage/return column.
            ngu_buffer_capacity: Passed to ``NGU``.
            ngu_subsample_freq: Once the replay buffer is full, only
                environments whose index is a multiple of this value have
                their rollouts stored.
            ngu_updates: Number of RND/IDE updates per ``step`` call.
            ngu_batch_size: Replay batch size for those updates.
            policy_training_starts: Policy and value updates begin once
                ``self.t`` reaches this value.
            buffer_size: Capacity passed to ``ReplayBuffer``.
            norm_advantages: Stored on the instance only; the rollout
                manager is always built with ``norm_advantages=False``.
            epochs_pi: Maximum passes over the rollout for the policy.
            epochs_vf: Passes over the rollout for the value function.
            max_grad_norm: If truthy, gradient norms are clipped to this.
            kl_target: Target KL between old and new policies.
            alpha: Multiplicative step used to adapt ``kl_weight``.
            eval_num_episodes: Episodes used by ``evaluate``.
            record_num_episodes: Episodes recorded by ``evaluate``.
            gpu: Use 'cuda:0' when True and CUDA is available, else 'cpu'.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.env_fn = env_fn
        self.nenv = nenv
        self.eval_num_episodes = eval_num_episodes
        self.record_num_episodes = record_num_episodes
        self.ent_coef = ent_coef
        self.ngu_coef = ngu_coef
        self.ngu_updates = ngu_updates
        self.ngu_batch_size = ngu_batch_size
        self.ngu_subsample_freq = ngu_subsample_freq
        self.epochs_pi = epochs_pi
        self.epochs_vf = epochs_vf
        self.max_grad_norm = max_grad_norm
        self.norm_advantages = norm_advantages
        self.kl_target = kl_target
        self.initial_kl_weight = 0.2
        self.kl_weight = self.initial_kl_weight
        self.policy_training_starts = policy_training_starts
        self.alpha = alpha
        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')
        # self.ngu_coefs = (1 + np.tanh(np.arange(-4, 4, 8/nenv))) / 2. * ngu_coef
        # self.ngu_coefs = torch.from_numpy(self.ngu_coefs).to(self.device)

        self.env = VecEpisodeLogger(env_fn(nenv=nenv))
        self.rnd = RND(rnd_net, opt_rnd, gamma_int,
                       self.env.observation_space.shape, self.device)
        self.ide = InverseDynamicsEmbedding(self.env, ide_embedding_net,
                                            ide_prediction_net, ide_loss,
                                            opt_ide, self.device)
        self.ngu = NGU(self.rnd,
                       self.ide,
                       ngu_buffer_capacity,
                       self.device,
                       gamma=gamma_int)
        # RND and IDE are built from the bare logged env; the policy and
        # value nets below are built from the fully wrapped env.
        self.env = VecActionRewardInObWrapper(NGUVecEnv(self.env, self.ngu),
                                              reward_shape=(2, ))

        self.pi = policy_fn(self.env).to(self.device)
        self.vf = value_fn(self.env).to(self.device)
        self.opt_pi = opt_pi(self.pi.parameters())
        self.opt_vf = opt_vf(self.vf.parameters())

        self.gamma = torch.Tensor([gamma_ext, gamma_int]).to(self.device)
        self.data_manager = RolloutDataManager(self.env,
                                               PPOActor(self.pi, self.vf),
                                               self.device,
                                               batch_size=batch_size,
                                               rollout_length=rollout_length,
                                               gamma=self.gamma,
                                               lambda_=lambda_,
                                               norm_advantages=False)
        self.buffer_size = buffer_size
        self.buffer = ReplayBuffer(buffer_size, 1)

        self.mse = nn.MSELoss()

        self.t = 0

    def compute_kl(self):
        """Compute KL divergence of new and old policies.

        Returns:
            The sample-weighted running mean of the per-batch KL over the
            current rollout (0 if the sampler yields no batches).
        """
        kl = 0
        n = 0
        # Only a scalar is extracted, so no autograd graph is needed.
        with torch.no_grad():
            for batch in self.data_manager.sampler():
                outs = self.pi(batch['obs'])
                old_dist = outs.dist.from_tensors(batch['dist'])
                k = old_dist.kl(outs.dist).mean().detach().cpu().numpy()
                s = nest.flatten(batch['action'])[0].shape[0]
                # Incremental weighted mean over the batches seen so far.
                kl = (n / (n + s)) * kl + (s / (n + s)) * k
                n += s
        return kl

    def loss_pi(self, batch):
        """Compute loss.

        Args:
            batch: Rollout batch with keys 'obs', 'atarg', 'action', 'logp'
                and 'dist'.

        Returns:
            Dict with the terms 'pi', 'ent', 'kl', 'kl_pen' and their
            weighted sum under 'total'.
        """
        outs = self.pi(batch['obs'])
        atarg = batch['atarg'][:, 0] + self.ngu_coef * batch['atarg'][:, 1]
        # atarg /= (1 + self.ngu_coef)
        # atarg = batch['atarg'][:, 1]
        # compute policy loss
        logp = outs.dist.log_prob(batch['action'])
        assert logp.shape == batch['logp'].shape
        ratio = torch.exp(logp - batch['logp'])
        assert ratio.shape == atarg.shape

        old_dist = outs.dist.from_tensors(batch['dist'])
        kl = old_dist.kl(outs.dist)
        # Quadratic hinge that is non-zero only above twice the KL target.
        kl_pen = (kl - 2 * self.kl_target).clamp(min=0).pow(2)
        losses = {}
        losses['pi'] = -(ratio * atarg).mean()
        losses['ent'] = -outs.dist.entropy().mean()
        losses['kl'] = kl.mean()
        losses['kl_pen'] = kl_pen.mean()
        losses['total'] = (losses['pi'] + self.ent_coef * losses['ent'] +
                           self.kl_weight * losses['kl'] +
                           1000 * losses['kl_pen'])
        return losses

    def loss_vf(self, batch):
        """Return the MSE between predicted values and ``batch['vtarg']``."""
        v = self.vf(batch['obs']).value
        assert v.shape == batch['vtarg'].shape
        return self.mse(v, batch['vtarg'])

    def step(self):
        """Compute rollout, loss, and update model.

        Collects one rollout, updates the policy and value function (once
        ``policy_training_starts`` is reached), stores rollout transitions
        in the replay buffer, updates RND and IDE from replay samples,
        adapts the KL weight, and logs diagnostics.

        Returns:
            The updated timestep counter ``self.t``.
        """
        self.pi.train()
        self.t += self.data_manager.rollout()
        losses = {
            'pi': [],
            'vf': [],
            'ent': [],
            'kl': [],
            'total': [],
            'kl_pen': [],
            'rnd': [],
            'ide': []
        }
        # if self.norm_advantages:
        #     atarg = self.data_manager.storage.data['atarg']
        #     atarg = atarg[:, 0] + self.ngu_coef * atarg[:, 1]
        #     self.data_manager.storage.data['atarg'][:, 0] -= atarg.mean()
        #     self.data_manager.storage.data['atarg'] /= atarg.std() + 1e-5

        if self.t >= self.policy_training_starts:
            #######################
            # Update pi
            #######################
            kl_too_big = False
            for _ in range(self.epochs_pi):
                if kl_too_big:
                    break
                for batch in self.data_manager.sampler():
                    self.opt_pi.zero_grad()
                    loss = self.loss_pi(batch)
                    # break if new policy is too different from old policy
                    if loss['kl'] > 4 * self.kl_target:
                        kl_too_big = True
                        break
                    loss['total'].backward()

                    for k, v in loss.items():
                        losses[k].append(v.detach().cpu().numpy())

                    if self.max_grad_norm:
                        norm = nn.utils.clip_grad_norm_(
                            self.pi.parameters(), self.max_grad_norm)
                        logger.add_scalar('alg/grad_norm', norm, self.t,
                                          time.time())
                        logger.add_scalar('alg/grad_norm_clipped',
                                          min(norm, self.max_grad_norm),
                                          self.t, time.time())
                    self.opt_pi.step()

            #######################
            # Update value function
            #######################
            for _ in range(self.epochs_vf):
                for batch in self.data_manager.sampler():
                    self.opt_vf.zero_grad()
                    loss = self.loss_vf(batch)
                    losses['vf'].append(loss.detach().cpu().numpy())
                    loss.backward()
                    if self.max_grad_norm:
                        norm = nn.utils.clip_grad_norm_(
                            self.vf.parameters(), self.max_grad_norm)
                        logger.add_scalar('alg/vf_grad_norm', norm, self.t,
                                          time.time())
                        logger.add_scalar('alg/vf_grad_norm_clipped',
                                          min(norm, self.max_grad_norm),
                                          self.t, time.time())
                    self.opt_vf.step()

        rollout = self.data_manager.storage.data
        lens = self.data_manager.storage.sequence_lengths.int()
        #######################
        # Store rollout_data in replay_buffer
        #######################
        for i in range(self.nenv):
            # While the buffer is filling, keep every environment's data;
            # afterwards keep only every ngu_subsample_freq-th environment.
            if (self.buffer.num_in_buffer < self.buffer_size
                    or i % self.ngu_subsample_freq == 0):
                for step in range(lens[i]):

                    def _f(x):
                        return x[step][i].cpu().numpy()

                    data = nest.map_structure(_f, rollout)
                    idx = self.buffer.store_observation(data['obs'])
                    self.buffer.store_effect(
                        idx, {
                            'done': data['done'],
                            'action': data['action'],
                            'reward': data['reward']
                        })

        #######################
        # Update NGU
        #######################
        def _to_torch(data):
            if isinstance(data, np.ndarray):
                return torch.from_numpy(data).to(self.device)
            else:
                return data

        for _ in range(self.ngu_updates):
            batch = self.buffer.sample(self.ngu_batch_size)
            batch = nest.map_structure(_to_torch, batch)
            loss = self.ngu.update_rnd(batch['obs']['ob'])
            losses['rnd'].append(loss.detach().cpu().numpy())
            # The inverse-dynamics update only uses transitions that are
            # not flagged done.
            not_done = torch.logical_not(batch['done'])
            loss = self.ngu.update_ide(batch['obs']['ob'][not_done],
                                       batch['next_obs']['ob'][not_done],
                                       batch['action'][not_done].long())
            losses['ide'].append(loss.detach().cpu().numpy())

        for k, v in losses.items():
            if len(v) == 0:
                continue
            logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time())

        # update weight on kl to match kl_target.
        if self.t >= self.policy_training_starts:
            kl = self.compute_kl()
            if kl > 10.0 * self.kl_target and self.kl_weight < self.initial_kl_weight:
                self.kl_weight = self.initial_kl_weight
            elif kl > 1.3 * self.kl_target:
                self.kl_weight *= self.alpha
            elif kl < 0.7 * self.kl_target:
                self.kl_weight /= self.alpha
        else:
            kl = 0.0

        logger.add_scalar('alg/kl', kl, self.t, time.time())
        logger.add_scalar('alg/kl_weight', self.kl_weight, self.t, time.time())
        avg_return = self.data_manager.storage.data['return'][:, 0].mean(dim=0)
        avg_return = (avg_return[0] +
                      self.ngu_coef * avg_return[1]) / (1 + self.ngu_coef)
        logger.add_scalar('alg/return', avg_return, self.t, time.time())

        # log value errors
        errors = []
        for batch in self.data_manager.sampler():
            errors.append(batch['vpred'] - batch['q_mc'])
        errors = torch.cat(errors)
        logger.add_scalar('alg/value_error_mean',
                          errors.mean().cpu().numpy(), self.t, time.time())
        logger.add_scalar('alg/value_error_std',
                          errors.std().cpu().numpy(), self.t, time.time())

        return self.t

    def evaluate(self):
        """Evaluate model.

        Runs ``rl_evaluate`` and ``rl_record`` with the policy in eval mode,
        logs mean episode reward and length, then restores train mode.
        """
        self.pi.eval()
        misc.set_env_to_eval_mode(self.env)

        # Eval policy
        os.makedirs(os.path.join(self.logdir, 'eval'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'eval',
                               self.ckptr.format.format(self.t) + '.json')
        stats = rl_evaluate(self.env, self.pi, self.eval_num_episodes, outfile,
                            self.device)
        logger.add_scalar('eval/mean_episode_reward', stats['mean_reward'],
                          self.t, time.time())
        logger.add_scalar('eval/mean_episode_length', stats['mean_length'],
                          self.t, time.time())

        # Record policy
        os.makedirs(os.path.join(self.logdir, 'video'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'video',
                               self.ckptr.format.format(self.t) + '.mp4')
        rl_record(self.env, self.pi, self.record_num_episodes, outfile,
                  self.device)

        self.pi.train()
        misc.set_env_to_train_mode(self.env)

    def save(self):
        """Save a checkpoint plus the replay buffer contents."""
        state_dict = {
            'pi': self.pi.state_dict(),
            'vf': self.vf.state_dict(),
            'opt_pi': self.opt_pi.state_dict(),
            'opt_vf': self.opt_vf.state_dict(),
            'kl_weight': self.kl_weight,
            'env': misc.env_state_dict(self.env),
            'ngu': self.ngu.state_dict(),
            't': self.t
        }
        buffer_dict = self.buffer.state_dict()
        # Only the nesting structure goes in the checkpoint; the arrays are
        # written to buffer.npz below.
        state_dict['buffer_format'] = nest.get_structure(buffer_dict)
        self.ckptr.save(state_dict, self.t)

        # save buffer separately and only once (because it can be huge)
        np.savez(
            os.path.join(self.ckptr.ckptdir, 'buffer.npz'),
            **{f'{i:04d}': x
               for i, x in enumerate(nest.flatten(buffer_dict))})

    def load(self, t=None):
        """Load state dict.

        Args:
            t: Checkpoint identifier forwarded to ``Checkpointer.load``.

        Returns:
            The restored timestep counter (0 when no checkpoint exists).
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            self.t = 0
            return self.t
        self.pi.load_state_dict(state_dict['pi'])
        self.vf.load_state_dict(state_dict['vf'])
        self.opt_pi.load_state_dict(state_dict['opt_pi'])
        self.opt_vf.load_state_dict(state_dict['opt_vf'])
        self.kl_weight = state_dict['kl_weight']
        misc.env_load_state_dict(self.env, state_dict['env'])
        self.ngu.load_state_dict(state_dict['ngu'])
        self.t = state_dict['t']

        buffer_format = state_dict['buffer_format']
        # NOTE(review): allow_pickle=True unpickles object arrays; only load
        # buffer files from a trusted source.
        # The context manager closes the archive once its arrays are read.
        with np.load(os.path.join(self.ckptr.ckptdir, 'buffer.npz'),
                     allow_pickle=True) as npz:
            buffer_state = dict(npz)
        buffer_state = nest.flatten(buffer_state)
        self.buffer.load_state_dict(
            nest.pack_sequence_as(buffer_state, buffer_format))
        self.buffer.env_reset()
        return self.t

    def close(self):
        """Close environment."""
        # Best-effort shutdown: failures while closing are deliberately
        # ignored.
        try:
            self.env.close()
        except Exception:
            pass
Esempio n. 30
0
class AlphaZero(object):
    """Alpha Zero Agent.

    Generates self-play games through ``SelfPlayManager``, stores them in a
    ``GameReplay`` buffer, and trains the policy/value network on batches
    sampled from it.
    """
    def __init__(self,
                 logdir,
                 game,
                 policy,
                 optimizer=torch.optim.Adam,
                 n_simulations=100,
                 buffer_size=200,
                 batch_size=64,
                 batches_per_game=1,
                 gpu=True):
        """Init.

        Args:
            logdir: Output directory; checkpoints go in ``logdir/ckpts``.
            game: Game object used for self play and evaluation.
            policy: Network whose output exposes ``dist.logits`` and
                ``value``.
            optimizer: Optimizer constructor; it is called with ``lr=1e-2``
                and ``weight_decay=1e-4``.
            n_simulations: Value passed to ``play_game`` for every game.
            buffer_size: Capacity passed to ``GameReplay``.
            batch_size: Batch size sampled for each gradient update.
            batches_per_game: Gradient updates performed per ``step`` call.
            gpu: Use 'cuda:0' when True and CUDA is available, else 'cpu'.
        """
        self.logdir = logdir
        self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
        self.device = torch.device(
            'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

        self.game = game
        self.n_sims = n_simulations
        self.batch_size = batch_size
        self.batches_per_game = batches_per_game

        self.pi = policy.to(self.device)
        self.opt = optimizer(self.pi.parameters(), lr=1e-2, weight_decay=1e-4)

        self.buffer = GameReplay(buffer_size)
        self.data_manager = SelfPlayManager(self.pi, self.game, self.buffer,
                                            self.device)

        self.mse = nn.MSELoss()

        self.t = 0

    def loss(self, batch):
        """Compute Loss.

        Args:
            batch: Dict with keys 'state', 'prob' and 'value'.

        Returns:
            Dict with 'pi' (cross entropy between ``batch['prob']`` and the
            policy's softmax), 'value' (MSE against ``batch['value']``) and
            'total' (their sum).
        """
        loss = {}
        outs = self.pi(batch['state'])
        log_probs = F.log_softmax(outs.dist.logits, dim=1)
        loss['pi'] = -(batch['prob'] * log_probs).sum(dim=1).mean()
        loss['value'] = self.mse(batch['value'], outs.value)
        loss['total'] = loss['pi'] + loss['value']
        return loss

    def step(self):
        """Step alpha zero.

        Plays one self-play game (and more until the replay buffer is full),
        then performs ``batches_per_game`` gradient updates.

        Returns:
            The updated counter ``self.t``.
        """
        self.pi.train()
        self.t += self.data_manager.play_game(self.n_sims)

        # fill replay buffer if needed
        while not self.buffer.full():
            self.t += self.data_manager.play_game(self.n_sims)

        for _ in range(self.batches_per_game):
            batch = self.data_manager.sample(self.batch_size)
            self.opt.zero_grad()
            loss = self.loss(batch)
            loss['total'].backward()
            self.opt.step()
            for k, v in loss.items():
                logger.add_scalar(f'loss/{k}',
                                  v.detach().cpu().numpy(), self.t,
                                  time.time())
        return self.t

    def evaluate(self):
        """Evaluate.

        Plays 10 games in which the agent picks the argmax of the policy
        returned by ``data_manager.pi`` and the opponent picks a uniformly
        random valid action, printing the board and final score.
        """
        self.pi.eval()
        for _ in range(10):
            print("Play against random!")
            state, player = self.game.reset()

            while not self.game.game_over(state, player)[0]:
                print(self.game.to_string(state, player))
                cs, _ = self.game.get_canonical_state(state, player)
                valid_actions = self.game.get_valid_actions(state, player)
                with torch.no_grad():
                    p, v = self.data_manager.pi(cs, valid_actions)
                print(f"ALPHAZero thinks the value is: {v * player}")
                state, player = self.game.move(state, player, np.argmax(p))

                #############################################
                # Play against random agent. Comment out for self play
                if self.game.game_over(state, player)[0]:
                    break
                print(self.game.to_string(state, player))
                valid_actions = self.game.get_valid_actions(state, player)
                acs = [i for i, ok in enumerate(valid_actions) if ok]
                ac = np.random.choice(acs)
                state, player = self.game.move(state, player, ac)
                #############################################
            print(self.game.to_string(state, player))
            _, score = self.game.game_over(state, player)
            print(f"SCORE: {-score}")

    def save(self):
        """Save a checkpoint plus the replay buffer contents."""
        state_dict = {
            'pi': self.pi.state_dict(),
            'opt': self.opt.state_dict(),
            't': self.t
        }
        buffer_dict = self.buffer.state_dict()
        # Only the nesting structure goes in the checkpoint; the arrays are
        # written to buffer.npz below.
        state_dict['buffer_format'] = nest.get_structure(buffer_dict)
        self.ckptr.save(state_dict, self.t)

        # save buffer separately and only once (because it can be huge)
        np.savez(
            os.path.join(self.ckptr.ckptdir, 'buffer.npz'),
            **{f'{i:04d}': x
               for i, x in enumerate(nest.flatten(buffer_dict))})

    def load(self, t=None):
        """Load state dict.

        Args:
            t: Checkpoint identifier forwarded to ``Checkpointer.load``.

        Returns:
            The restored counter ``self.t`` (0 when no checkpoint exists).
        """
        state_dict = self.ckptr.load(t)
        if state_dict is None:
            self.t = 0
            return self.t
        self.pi.load_state_dict(state_dict['pi'])
        self.opt.load_state_dict(state_dict['opt'])
        self.t = state_dict['t']

        buffer_format = state_dict['buffer_format']
        # The context manager closes the archive once its arrays are read.
        with np.load(os.path.join(self.ckptr.ckptdir, 'buffer.npz')) as npz:
            buffer_state = dict(npz)
        buffer_state = nest.flatten(buffer_state)
        self.buffer.load_state_dict(
            nest.pack_sequence_as(buffer_state, buffer_format))
        return self.t

    def close(self):
        """Close."""
        pass