Example #1
File: Gate.py, Project: aghriss/TRHPO
    def __init__(self,
                 env,
                 gatepolicy,
                 policy_func,
                 value_func,
                 n_options,
                 option_len=3,
                 timesteps_per_batch=1000,
                 gamma=0.99,
                 lam=0.97,
                 MI_lambda=1e-3,
                 gate_max_kl=1e-2,
                 option_max_kl=1e-2,
                 cg_iters=10,
                 cg_damping=1e-2,
                 vf_iters=2,
                 max_train=1000,
                 ls_step=0.5,
                 checkpoint_freq=50):

        super(GateTRPO, self).__init__(name=env.name)

        self.n_options = n_options
        #self.name = self.name
        self.env = env
        self.gamma = gamma
        self.lam = lam
        self.MI_lambda = MI_lambda
        self.current_option = 0
        self.timesteps_per_batch = timesteps_per_batch
        self.gate_max_kl = gate_max_kl
        self.cg_iters = cg_iters
        self.cg_damping = cg_damping
        self.max_train = max_train
        self.ls_step = ls_step
        self.checkpoint_freq = checkpoint_freq

        self.policy = gatepolicy(env.observation_space.shape,
                                 env.action_space.n)
        self.oldpolicy = gatepolicy(env.observation_space.shape,
                                    env.action_space.n,
                                    verbose=0)
        self.oldpolicy.disable_grad()
        self.progbar = Progbar(self.timesteps_per_batch)

        self.path_generator = self.roller()
        self.episodes_reward = collections.deque([], 5)
        self.episodes_len = collections.deque([], 5)
        self.done = 0
        self.functions = [self.policy]

        self.options = [
            OptionTRPO(env.name, i, env, policy_func, value_func, gamma, lam,
                       option_len, option_max_kl, cg_iters, cg_damping,
                       vf_iters, ls_step, self.logger, checkpoint_freq)
            for i in range(n_options)
        ]
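
A hedged construction sketch for this initializer. make_env, GatePolicy, make_policy and make_value below are hypothetical placeholders for the project's own environment wrapper, gate-policy class and per-option factories; only the call signatures are taken from the code above.

# Hypothetical usage sketch; the factory names are placeholders, not project code.
env = make_env()  # Gym-style env exposing .name, .observation_space, .action_space

agent = GateTRPO(env,
                 gatepolicy=GatePolicy,    # instantiated as GatePolicy(obs_shape, n_actions[, verbose=0])
                 policy_func=make_policy,  # per-option policy factory, forwarded to OptionTRPO
                 value_func=make_value,    # per-option value-function factory
                 n_options=4,
                 option_len=3,
                 timesteps_per_batch=1000)
agent.train()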
Example #2
File: trpo.py, Project: aghriss/RL
    def __init__(
            self,
            env,
            policy_func,
            value_func,
            timesteps_per_batch=1000,  # what to train on
            gamma=0.99,
            lam=0.97,  # advantage estimation
            max_kl=1e-2,
            cg_iters=10,
            entropy_coeff=0.0,
            cg_damping=1e-2,
            vf_iters=3,
            max_train=1000,
            checkpoint_freq=50):

        super(TRPO, self).__init__()

        self.env = env
        self.gamma = gamma
        self.lam = lam

        self.timesteps_per_batch = timesteps_per_batch
        self.max_kl = max_kl
        self.cg_iters = cg_iters
        self.entropy_coeff = entropy_coeff
        self.cg_damping = cg_damping
        self.vf_iters = vf_iters
        self.max_train = max_train
        self.checkpoint_freq = checkpoint_freq

        self.policy = policy_func(env)
        self.oldpolicy = policy_func(env)
        self.value_function = value_func(self.env)
        self.progbar = Progbar(self.timesteps_per_batch)

        self.path_generator = self.roller()
        self.value_function.summary()

        self.episodes_reward = collections.deque([], 50)
        self.episodes_len = collections.deque([], 50)
        self.done = 0

        self.functions = [self.policy, self.value_function]
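
The two factory arguments are only required to provide the interfaces that the full TRPO class in example #6 below actually calls. A hedged sketch, where MyPolicy and MyValue are hypothetical network classes and only the listed method names are taken from that code:

# Hedged sketch; MyPolicy and MyValue are placeholders, not project classes.
def policy_func(env):
    # needs: __call__(states) -> logits, sample(state), act(state),
    # copy(other_policy), and a .flaten parameter flattener
    return MyPolicy(env.observation_space.shape, env.action_space.n)

def value_func(env):
    # needs: predict(state), fit(X, Y, batch_size=..., epochs=...), summary()
    return MyValue(env.observation_space.shape, (1,))

agent = TRPO(env, policy_func, value_func)
agent.train()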
Example #3
    def __init__(self,
                 env,
                 deep_func,
                 gamma,
                 batch_size,
                 memory_min,
                 memory_max,
                 update_double=10000,
                 train_steps=1000000,
                 log_freq=1000,
                 eps_start=1,
                 eps_decay=-1,
                 eps_min=0.1):

        super(DDQN, self).__init__()

        self.env = env
        self.Q = self.model = deep_func(env)
        self.target_Q = deep_func(env)

        self.discount = gamma
        self.memory_min = memory_min
        self.memory_max = memory_max
        self.eps = eps_start
        self.train_steps = train_steps
        self.batch_size = batch_size
        self.done = 0
        self.log_freq = log_freq
        self.progbar = Progbar(self.memory_max)
        self.memory = ReplayMemory(
            self.memory_max,
            ["state", "action", "reward", "next_state", "terminated"])

        self.eps_decay = eps_decay
        if eps_decay == -1:
            self.eps_decay = 1 / train_steps

        self.eps_min = eps_min
        self.update_double = update_double
        self.actions = []
        self.path_generator = self.roller()
        self.past_rewards = collections.deque([], 50)
        self.functions = [self.Q]
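
With these defaults the exploration schedule is linear: leaving eps_decay at -1 resolves it to 1 / train_steps, and set_eps in example #7 clamps the result at eps_min. A small arithmetic illustration using the constructor defaults:

# Illustration only: what the default epsilon schedule works out to.
eps_start, eps_min, train_steps = 1.0, 0.1, 1_000_000
eps_decay = 1 / train_steps              # the -1 default resolves to this
# after k environment steps: eps = max(eps_start - k * eps_decay, eps_min)
# k = 500_000 -> eps = 0.5; from k = 900_000 onward eps stays at eps_min = 0.1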
Example #4
    def __init__(self, input_shape, output_shape):
        super(BaseNetwork, self).__init__()

        self.input_shape = input_shape
        self.output_shape = output_shape
        self.progbar = Progbar(100)
Example #5
class BaseNetwork(nn.Module):
    """
        Base class for our Neural networks
    """
    name = "BaseNetwork"

    def __init__(self, input_shape, output_shape):
        super(BaseNetwork, self).__init__()

        self.input_shape = input_shape
        self.output_shape = output_shape
        self.progbar = Progbar(100)

    def forward(self, x):
        return self.model(x)

    def predict(self, x):
        x = U.torchify(x)
        if len(x.shape) == len(self.input_shape):
            x.unsqueeze_(0)
        return U.get(self.forward(x).squeeze())

    def optimize(self, l, clip=False):
        self.optimizer.zero_grad()
        l.backward()
        if clip: nn.utils.clip_grad_norm_(self.parameters(), 1.0)
        self.optimizer.step()

    def fit(self, X, Y, batch_size=50, epochs=1, clip=False):
        Xtmp, Ytmp = X.split(batch_size), Y.split(batch_size)
        for _ in range(epochs):
            self.progbar.__init__(len(Xtmp))
            for x, y in zip(Xtmp, Ytmp):
                #self.optimizer.zero_grad()
                loss = self.loss(self.forward(x), y)
                self.optimize(loss, clip)
                new_loss = self.loss(self.forward(x).detach(), y)
                self.progbar.add(1,
                                 values=[("old", U.get(loss)),
                                         ("new", U.get(new_loss))])

    def step(self, grad):
        self.optimizer.zero_grad()
        self.flaten.set_grad(grad)
        self.optimizer.step()

    def set_learning_rate(self, rate, verbose=0):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = rate
            if verbose: print("\n New Learning Rate ", param_group['lr'], "\n")

    def get_learning_rate(self):
        for param_group in self.optimizer.param_groups:
            return param_group['lr']

    def compile(self):
        self.loss = nn.SmoothL1Loss()
        self.optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                    self.parameters()),
                                             lr=1e-4)
        self.flaten = Flattener(self.parameters)
        self.summary()

    def load(self, fname):
        print("Loading %s.%s" % (fname, self.name))
        dic = torch.load(U.CHECK_PATH + fname + "." + self.name)
        super(BaseNetwork, self).load_state_dict(dic)

    def save(self, fname):
        print("Saving to %s.%s" % (fname, self.name))
        dic = super(BaseNetwork, self).state_dict()
        torch.save(dic, "%s.%s" % (U.CHECK_PATH + fname, self.name))

    def copy(self, X):
        self.flaten.set(X.flaten.get())

    def summary(self):
        U.summary(self, self.input_shape)
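
A hedged sketch of a concrete subclass, assuming the torch/nn/np imports and the repo utilities (U, Progbar, Flattener) already relied on above; the QNet name and layer sizes are illustrative only. The only contract is that a subclass assigns self.model (used by forward) and then calls compile():

# Hypothetical subclass, not taken from the project.
class QNet(BaseNetwork):
    name = "QNet"

    def __init__(self, input_shape, output_shape):
        super(QNet, self).__init__(input_shape, output_shape)
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(input_shape)), 64),
            nn.ReLU(),
            nn.Linear(64, int(np.prod(output_shape))),
        )
        self.compile()  # attaches SmoothL1Loss, RMSprop and the Flattener


# e.g. fit a random regression batch
net = QNet((4,), (2,))
net.fit(torch.randn(256, 4), torch.randn(256, 2), batch_size=50, epochs=1)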
Example #6
File: trpo.py, Project: aghriss/RL
class TRPO(BaseAgent):

    name = "TRPO"

    def __init__(
            self,
            env,
            policy_func,
            value_func,
            timesteps_per_batch=1000,  # what to train on
            gamma=0.99,
            lam=0.97,  # advantage estimation
            max_kl=1e-2,
            cg_iters=10,
            entropy_coeff=0.0,
            cg_damping=1e-2,
            vf_iters=3,
            max_train=1000,
            checkpoint_freq=50):

        super(TRPO, self).__init__()

        self.env = env
        self.gamma = gamma
        self.lam = lam

        self.timesteps_per_batch = timesteps_per_batch
        self.max_kl = max_kl
        self.cg_iters = cg_iters
        self.entropy_coeff = entropy_coeff
        self.cg_damping = cg_damping
        self.vf_iters = vf_iters
        self.max_train = max_train
        self.checkpoint_freq = checkpoint_freq

        self.policy = policy_func(env)
        self.oldpolicy = policy_func(env)
        self.value_function = value_func(self.env)
        self.progbar = Progbar(self.timesteps_per_batch)

        self.path_generator = self.roller()
        self.value_function.summary()

        self.episodes_reward = collections.deque([], 50)
        self.episodes_len = collections.deque([], 50)
        self.done = 0

        self.functions = [self.policy, self.value_function]

    def act(self, state, train=True):
        if train:
            return self.policy.sample(state)
        return self.policy.act(state)

    def calculate_losses(self, states, actions, advantages, tdlamret):

        pi = self.policy(states)
        old_pi = self.oldpolicy(states).detach()

        kl_old_new = m_utils.kl_logits(old_pi, pi)
        entropy = m_utils.entropy_logits(pi)
        mean_kl = kl_old_new.mean()

        mean_entropy = entropy.mean()

        ratio = torch.exp(
            m_utils.logp(pi, actions) -
            m_utils.logp(old_pi, actions))  # advantage * pnew / pold
        surrogate_gain = (ratio * advantages).mean()

        optimization_gain = surrogate_gain + self.entropy_coeff * mean_entropy

        losses = {
            "gain": optimization_gain,
            "meankl": mean_kl,
            "entropy": mean_entropy,
            "surrogate": surrogate_gain,
            "mean_entropy": mean_entropy
        }
        return losses

    def train(self):
        while self.done < self.max_train:
            print("=" * 40)
            print(" " * 15, self.done, "\n")
            self._train()
            if not self.done % self.checkpoint_freq:
                self.save()
            self.done = self.done + 1
        self.done = 0

    def _train(self):

        # Prepare for rollouts
        # ----------------------------------------

        self.oldpolicy.copy(self.policy)

        path = self.path_generator.__next__()

        states = U.torchify(path["state"])
        actions = U.torchify(path["action"]).long()
        advantages = U.torchify(path["advantage"])
        tdlamret = U.torchify(path["tdlamret"])
        vpred = U.torchify(path["vf"])  # predicted value function before update
        # standardized advantage function estimate
        advantages = (advantages - advantages.mean()) / advantages.std()

        losses = self.calculate_losses(states, actions, advantages, tdlamret)
        kl = losses["meankl"]
        optimization_gain = losses["gain"]

        loss_grad = self.policy.flaten.flatgrad(optimization_gain, retain=True)
        grad_kl = self.policy.flaten.flatgrad(kl, create=True, retain=True)

        theta_before = self.policy.flaten.get()
        self.log("Init param sum", theta_before.sum())
        # explained variance of the value prediction: 1 - Var(residual) / Var(target)
        self.log("explained variance",
                 1 - (vpred - tdlamret.view(-1)).var() / tdlamret.var())

        if np.allclose(loss_grad.detach().cpu().numpy(), 0, atol=1e-15):
            print("Got zero gradient. not updating")
        else:
            print("Conjugate Gradient", end="")
            start = time.time()
            stepdir = m_utils.conjugate_gradient(self.Fvp(grad_kl),
                                                 loss_grad,
                                                 cg_iters=self.cg_iters)
            elapsed = time.time() - start
            print(", Done in %.3f" % elapsed)
            self.log("Conjugate Gradient in s", elapsed)
            assert stepdir.sum() != float("Inf")
            # scale the CG direction so the quadratic KL estimate
            # 0.5 * step^T H step equals max_kl
            shs = .5 * stepdir.dot(self.Fvp(grad_kl)(stepdir))
            lm = torch.sqrt(shs / self.max_kl)
            self.log("lagrange multiplier:", lm)
            self.log("gnorm:",
                     np.linalg.norm(loss_grad.cpu().detach().numpy()))
            fullstep = stepdir / lm
            expected_improve = loss_grad.dot(fullstep)
            surrogate_before = losses["surrogate"]
            stepsize = 1.0

            print("Line Search", end="")
            start = time.time()
            for _ in range(10):
                theta_new = theta_before + fullstep * stepsize
                self.policy.flaten.set(theta_new)
                losses = self.calculate_losses(states, actions, advantages,
                                               tdlamret)
                surr = losses["surrogate"]
                improve = surr - surrogate_before
                kl = losses["meankl"]
                if surr == float("Inf") or kl == float("Inf"):
                    print("Infinite value of losses")
                elif kl > self.max_kl:
                    print("Violated KL")
                elif improve < 0:
                    print("Surrogate didn't improve. shrinking step.")
                else:
                    print("Expected: %.3f Actual: %.3f" %
                          (expected_improve, improve))
                    print("Stepsize OK!")
                    self.log("Line Search", "OK")
                    break
                stepsize *= .5
            else:
                print("couldn't compute a good step")
                self.log("Line Search", "NOPE")
                self.policy.flaten.set(theta_before)
            elapsed = time.time() - start
            print(", Done in %.3f" % elapsed)
            self.log("Line Search in s", elapsed)
            self.log("KL", kl)
            self.log("Surrogate", surr)
        start = time.time()
        print("Value Function Update", end="")
        self.value_function.fit(states[::5],
                                tdlamret[::5],
                                batch_size=50,
                                epochs=self.vf_iters)
        elapsed = time.time() - start
        print(", Done in %.3f" % elapsed)
        self.log("Value Function Fitting in s", elapsed)
        self.log("TDlamret mean", tdlamret.mean())
        self.log("Last 50 rolls mean rew", np.mean(self.episodes_reward))
        self.log("Last 50 rolls mean len", np.mean(self.episodes_len))
        self.print()

    def roller(self):

        state = self.env.reset()
        ep_rews = 0
        ep_len = 0
        while True:

            path = {
                s: []
                for s in
                ["state", "action", "reward", "terminated", "vf", "next_vf"]
            }
            self.progbar.__init__(self.timesteps_per_batch)
            for _ in range(self.timesteps_per_batch):

                path["state"].append(state)
                # act
                action = self.act(state)
                vf = self.value_function.predict(state)
                state, rew, done, _ = self.env.step(action)
                path["action"].append(action)
                path["reward"].append(rew)
                path["vf"].append(vf)
                path["terminated"].append(done * 1.0)
                path["next_vf"].append((1 - done) * vf)

                ep_rews += rew
                ep_len += 1

                if done:
                    state = self.env.reset()
                    self.episodes_reward.append(ep_rews)
                    self.episodes_len.append(ep_len)
                    ep_rews = 0
                    ep_len = 0
                self.progbar.add(1)

            for k, v in path.items():
                path[k] = np.array(v)

            self.add_vtarg_and_adv(path)
            yield path

    def Fvp(self, grad_kl):
        # Fisher-vector product via double backprop: Hv = d/dtheta (grad_kl . v),
        # with damping added for conjugate-gradient stability.
        def fisher_product(v):
            kl_v = (grad_kl * v).sum()
            grad_grad_kl = self.policy.flaten.flatgrad(kl_v, retain=True)
            return grad_grad_kl + v * self.cg_damping

        return fisher_product

    def add_vtarg_and_adv(self, path):
        # General Advantage Estimation
        terminal = np.append(path["terminated"], 0)
        vpred = np.append(path["vf"], path["next_vf"])
        T = len(path["reward"])
        path["advantage"] = np.empty(T, 'float32')
        lastgaelam = 0
        for t in reversed(range(T)):
            nonterminal = 1 - terminal[t + 1]
            delta = (path["reward"][t]
                     + self.gamma * vpred[t + 1] * nonterminal - vpred[t])
            lastgaelam = delta + self.gamma * self.lam * nonterminal * lastgaelam
            path["advantage"][t] = lastgaelam
        path["tdlamret"] = (path["advantage"] + path["vf"]).reshape(-1, 1)
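
For reference, add_vtarg_and_adv above is the standard GAE(lambda) bookkeeping, restated here in the usual notation (the code applies the termination mask through its nonterminal variable):

\delta_t = r_t + \gamma\,(1 - d_t)\,V(s_{t+1}) - V(s_t)
\hat{A}_t = \delta_t + \gamma\lambda\,(1 - d_t)\,\hat{A}_{t+1}, \qquad \hat{A}_T = 0
\mathrm{tdlamret}_t = \hat{A}_t + V(s_t)

with gamma and lam the constructor arguments and d_t marking terminal transitions; tdlamret is the regression target for the value function, and the advantages are standardized before the policy step.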
Example #7
class DDQN(BaseAgent):
    """
    Double Deep Q Networks
    """
    name = "DDQN"

    def __init__(self,
                 env,
                 deep_func,
                 gamma,
                 batch_size,
                 memory_min,
                 memory_max,
                 update_double=10000,
                 train_steps=1000000,
                 log_freq=1000,
                 eps_start=1,
                 eps_decay=-1,
                 eps_min=0.1):

        super(DDQN, self).__init__()

        self.env = env
        self.Q = self.model = deep_func(env)
        self.target_Q = deep_func(env)

        self.discount = gamma
        self.memory_min = memory_min
        self.memory_max = memory_max
        self.eps = eps_start
        self.train_steps = train_steps
        self.batch_size = batch_size
        self.done = 0
        self.log_freq = log_freq
        self.progbar = Progbar(self.memory_max)
        self.memory = ReplayMemory(
            self.memory_max,
            ["state", "action", "reward", "next_state", "terminated"])

        self.eps_decay = eps_decay
        if eps_decay == -1:
            self.eps_decay = 1 / train_steps

        self.eps_min = eps_min
        self.update_double = update_double
        self.actions = []
        self.path_generator = self.roller()
        self.past_rewards = collections.deque([], 50)
        self.functions = [self.Q]

    def act(self, state, train=True):

        if train:
            if np.random.rand() < self.eps:
                return np.random.randint(self.env.action_space.n)

        return np.argmax(self.Q.predict(state))

    def train(self):

        self.progbar.__init__(self.memory_min)
        while (self.memory.size < self.memory_min):
            self.path_generator.__next__()

        while (self.done < self.train_steps):

            to_log = 0
            self.progbar.__init__(self.update_double)
            old_theta = self.Q.flaten.get()
            self.target_Q.copy(self.Q)
            while to_log < self.update_double:

                self.path_generator.__next__()

                rollout = self.memory.sample(self.batch_size)
                state_batch = U.torchify(rollout["state"])
                action_batch = U.torchify(rollout["action"]).long()
                reward_batch = U.torchify(rollout["reward"])

                non_final_batch = U.torchify(1 - rollout["terminated"])
                next_state_batch = U.torchify(rollout["next_state"])

                #current_q = self.Q(state_batch)

                # Q-values of the actions actually taken, from the online network
                current_q = self.Q(state_batch).gather(
                    1, action_batch.unsqueeze(1)).view(-1)
                # Double DQN: the online network selects the next action ...
                _, a_prime = self.Q(next_state_batch).max(1)

                # ... and the frozen target network evaluates it
                next_max_q = self.target_Q(next_state_batch).gather(
                    1, a_prime.unsqueeze(1)).view(-1)
                target_q = (reward_batch +
                            self.discount * non_final_batch * next_max_q)

                # Compute loss against the detached target
                loss = self.Q.loss(current_q, target_q.detach())
                # loss = self.Q.total_loss(current_q, target_q)

                # Optimize the model
                self.Q.optimize(loss, clip=True)

                self.progbar.add(self.batch_size,
                                 values=[("Loss", U.get(loss))])

                to_log += self.batch_size

            self.target_Q.copy(self.Q)
            new_theta = self.Q.flaten.get()

            self.log("Delta Theta L1",
                     U.get((new_theta - old_theta).abs().mean()))
            self.log("Av 50ep  rew", np.mean(self.past_rewards))
            self.log("Max 50ep rew", np.max(self.past_rewards))
            self.log("Min 50ep rew", np.min(self.past_rewards))
            self.log("Epsilon", self.eps)
            self.log("Done", self.done)
            self.log("Total", self.train_steps)
            self.target_Q.copy(self.Q)
            self.print()
            #self.play()
            self.save()

    def set_eps(self, x):
        self.eps = max(x, self.eps_min)

    def roller(self):

        state = self.env.reset()
        ep_reward = 0
        while True:
            episode = self.memory.empty_episode()
            for i in range(self.batch_size):

                # save current state
                episode["state"].append(state)

                # act
                action = self.act(state)
                self.actions.append(action)
                state, rew, done, info = self.env.step(action)

                episode["next_state"].append(state)
                episode["action"].append(action)
                episode["reward"].append(rew)
                episode["terminated"].append(done)

                ep_reward += rew
                self.set_eps(self.eps - self.eps_decay)

                if done:
                    self.past_rewards.append(ep_reward)
                    state = self.env.reset()
                    ep_reward = 0
                self.done += 1
                if not (self.done) % self.update_double:
                    self.update = True

            # record the episodes
            self.memory.record(episode)
            if self.memory.size < self.memory_min:
                self.progbar.add(self.batch_size, values=[("Loss", 0.0)])
            yield True
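
The inner loop above is the usual Double DQN update: the online network Q selects the greedy next action and the frozen target_Q evaluates it, with the target detached so gradients flow only through the online network. As a restatement of the code (the Smooth L1 loss is what BaseNetwork.compile in example #5 attaches, assuming deep_func returns such a network):

y = r + \gamma\,(1 - d)\,Q_{\mathrm{target}}\!\big(s',\ \arg\max_{a'} Q(s', a')\big)
L = \mathrm{SmoothL1}\big(Q(s, a),\ y\big)

where d = 1 on terminal transitions (non_final_batch is 1 - d).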
Example #8
File: Gate.py, Project: aghriss/TRHPO
class GateTRPO(BaseAgent):

    name = "GateTRPO"

    def __init__(self,
                 env,
                 gatepolicy,
                 policy_func,
                 value_func,
                 n_options,
                 option_len=3,
                 timesteps_per_batch=1000,
                 gamma=0.99,
                 lam=0.97,
                 MI_lambda=1e-3,
                 gate_max_kl=1e-2,
                 option_max_kl=1e-2,
                 cg_iters=10,
                 cg_damping=1e-2,
                 vf_iters=2,
                 max_train=1000,
                 ls_step=0.5,
                 checkpoint_freq=50):

        super(GateTRPO, self).__init__(name=env.name)

        self.n_options = n_options
        #self.name = self.name
        self.env = env
        self.gamma = gamma
        self.lam = lam
        self.MI_lambda = MI_lambda
        self.current_option = 0
        self.timesteps_per_batch = timesteps_per_batch
        self.gate_max_kl = gate_max_kl
        self.cg_iters = cg_iters
        self.cg_damping = cg_damping
        self.max_train = max_train
        self.ls_step = ls_step
        self.checkpoint_freq = checkpoint_freq

        self.policy = gatepolicy(env.observation_space.shape,
                                 env.action_space.n)
        self.oldpolicy = gatepolicy(env.observation_space.shape,
                                    env.action_space.n,
                                    verbose=0)
        self.oldpolicy.disable_grad()
        self.progbar = Progbar(self.timesteps_per_batch)

        self.path_generator = self.roller()
        self.episodes_reward = collections.deque([], 5)
        self.episodes_len = collections.deque([], 5)
        self.done = 0
        self.functions = [self.policy]

        self.options = [
            OptionTRPO(env.name, i, env, policy_func, value_func, gamma, lam,
                       option_len, option_max_kl, cg_iters, cg_damping,
                       vf_iters, ls_step, self.logger, checkpoint_freq)
            for i in range(n_options)
        ]

    def act(self, state, train=True):
        if train:
            return self.policy.sample(state)
        return self.policy.act(state)

    def calculate_losses(self, states, options, actions, advantages):

        RIM = self.KLRIM(states, options, actions)
        old_pi = RIM["old_log_pi_oia_s"]
        pi = RIM["old_log_pi_oia_s"]

        ratio = torch.exp(
            m_utils.logp(pi, actions) -
            m_utils.logp(old_pi, actions))  # advantage * pnew / pold
        surrogate_gain = (ratio * advantages).mean()

        optimization_gain = surrogate_gain - self.MI_lambda * RIM["MI"]

        def surr_get(grad=False):
            Id, pid = RIM["MI_get"](grad)
            return (torch.exp(
                m_utils.logp(pid, actions) - m_utils.logp(old_pi, actions)) *
                    advantages).mean() - self.MI_lambda * Id

        RIM["gain"] = optimization_gain
        RIM["surr_get"] = surr_get
        return RIM

    def train(self):

        while self.done < self.max_train:
            print("=" * 40)
            print(" " * 15, self.done, "\n")
            self.logger.step()
            path = self.path_generator.__next__()
            self.oldpolicy.copy(self.policy)
            for p in self.options:
                p.oldpolicy.copy(p.policy)
            self._train(path)
            self.logger.display()
            if not self.done % self.checkpoint_freq:
                self.save()
                for p in self.options:
                    p.save()
            self.done = self.done + 1
        self.done = 0

    def _train(self, path):

        states = U.torchify(path["states"])
        options = U.torchify(path["options"]).long()
        actions = U.torchify(path["actions"]).long()
        advantages = U.torchify(path["baseline"])
        tdlamret = U.torchify(path["tdlamret"])
        vpred = U.torchify(path["vf"])  # predicted value function before update
        #advantages = (advantages - advantages.mean()) / advantages.std()  # standardized advantage function estimate

        losses = self.calculate_losses(states, options, actions, advantages)
        kl = losses["gate_meankl"]
        optimization_gain = losses["gain"]

        loss_grad = self.policy.flaten.flatgrad(optimization_gain, retain=True)
        grad_kl = self.policy.flaten.flatgrad(kl, create=True, retain=True)

        theta_before = self.policy.flaten.get()
        self.log("Init param sum", theta_before.sum())
        # explained variance of the value prediction: 1 - Var(residual) / Var(target)
        self.log("explained variance",
                 1 - (vpred - tdlamret.view(-1)).var() / tdlamret.var())

        if np.allclose(loss_grad.detach().cpu().numpy(), 0, atol=1e-19):
            print("Got zero gradient. not updating")
        else:
            with C.timeit("Conjugate Gradient"):
                stepdir = m_utils.conjugate_gradient(self.Fvp(grad_kl),
                                                     loss_grad,
                                                     cg_iters=self.cg_iters)

            self.log("Conjugate Gradient in s", C.elapsed)
            assert stepdir.sum() != float("Inf")
            shs = .5 * stepdir.dot(self.Fvp(grad_kl)(stepdir))
            lm = torch.sqrt(shs / self.gate_max_kl)
            self.log("lagrange multiplier:", lm)
            self.log("gnorm:",
                     np.linalg.norm(loss_grad.cpu().detach().numpy()))
            fullstep = stepdir / lm
            expected_improve = loss_grad.dot(fullstep)
            surrogate_before = losses["gain"].detach()

            with C.timeit("Line Search"):
                stepsize = 1.0
                for i in range(10):
                    theta_new = theta_before + fullstep * stepsize
                    self.policy.flaten.set(theta_new)
                    surr = losses["surr_get"]()
                    improve = surr - surrogate_before
                    kl = losses["KL_gate_get"]()
                    if surr == float("Inf") or kl == float("Inf"):
                        C.warning("Infinite value of losses")
                    elif kl > self.gate_max_kl:
                        C.warning("Violated KL")
                    elif improve < 0:
                        stepsize *= self.ls_step
                    else:
                        self.log("Line Search", "OK")
                        break
                else:
                    improve = 0
                    self.log("Line Search", "NOPE")
                    self.policy.flaten.set(theta_before)

            for op in self.options:
                losses["gain"] = losses["surr_get"](grad=True)
                op.train(states, options, actions, advantages, tdlamret,
                         losses)

            surr = losses["surr_get"]()
            improve = surr - surrogate_before
            self.log("Expected", expected_improve)
            self.log("Actual", improve)
            self.log("Line Search in s", C.elapsed)
            self.log("LS Steps", i)
            self.log("KL", kl)
            self.log("MI", -losses["MI"])
            self.log("MI improve", -losses["MI_get"]()[0] + losses["MI"])
            self.log("Surrogate", surr)
            self.log("Gate KL", losses["KL_gate_get"]())
            self.log("HRL KL", losses["KL_get"]())
            self.log("TDlamret mean", tdlamret.mean())
            del (improve, surr, kl)
        self.log("Last %i rolls mean rew" % len(self.episodes_reward),
                 np.mean(self.episodes_reward))
        self.log("Last %i rolls mean len" % len(self.episodes_len),
                 np.mean(self.episodes_len))
        del (losses, states, options, actions, advantages, tdlamret, vpred,
             optimization_gain, loss_grad, grad_kl)
        for _ in range(10):
            gc.collect()

    def roller(self):

        state = self.env.reset()
        path = {
            "states":
            np.array([state for _ in range(self.timesteps_per_batch)]),
            "options": np.zeros(self.timesteps_per_batch).astype(int),
            "actions": np.zeros(self.timesteps_per_batch).astype(int),
            "rewards": np.zeros(self.timesteps_per_batch),
            "terminated": np.zeros(self.timesteps_per_batch),
            "vf": np.zeros(self.timesteps_per_batch)
        }
        self.current_option = self.act(state)
        self.options[self.current_option].select()
        ep_rews = 0
        ep_len = 0
        t = 0
        done = True
        rew = 0.0
        self.progbar.__init__(self.timesteps_per_batch)
        while True:
            if self.options[self.current_option].finished:
                self.current_option = self.act(state)

            action = self.options[self.current_option].act(state)
            vf = self.options[self.current_option].value_function.predict(
                state)
            if t > self.timesteps_per_batch - 1:
                path["next_vf"] = vf * (1 - done * 1.0)
                self.add_vtarg_and_adv(path)
                yield path
                t = 0
                self.progbar.__init__(self.timesteps_per_batch)

            path["states"][t] = state
            state, rew, done, _ = self.env.step(action)
            path["options"][t] = self.options[self.current_option].option_n
            path["actions"][t] = action
            path["rewards"][t] = rew
            path["vf"][t] = vf
            path["terminated"][t] = done * 1.0
            ep_rews += rew
            ep_len += 1
            t += 1
            self.progbar.add(1)

            if done:
                state = self.env.reset()
                self.episodes_reward.append(ep_rews)
                self.episodes_len.append(ep_len)
                ep_rews = 0
                ep_len = 0

    def add_vtarg_and_adv(self, path):
        # General Advantage Estimation
        terminal = np.append(path["terminated"], 0)
        vpred = np.append(path["vf"], path["next_vf"])
        T = len(path["rewards"])
        path["advantage"] = np.empty(T, 'float32')
        lastgaelam = 0
        for t in reversed(range(T)):
            nonterminal = 1 - terminal[t + 1]
            delta = (path["rewards"][t]
                     + self.gamma * vpred[t + 1] * nonterminal - vpred[t])
            lastgaelam = delta + self.gamma * self.lam * nonterminal * lastgaelam
            path["advantage"][t] = lastgaelam
        path["tdlamret"] = (path["advantage"] + path["vf"]).reshape(-1, 1)
        path["baseline"] = ((path["advantage"] - np.mean(path["advantage"]))
                            / np.std(path["advantage"]))

    def Fvp(self, grad_kl):
        def fisher_product(v):
            kl_v = (grad_kl * v).sum()
            grad_grad_kl = self.policy.flaten.flatgrad(kl_v, retain=True)
            return grad_grad_kl + v * self.cg_damping

        return fisher_product

    def KLRIM(self, states, options, actions):
        r"""
        pg      : \pi_g
        pi_a_so : \pi(a|s,o)
        pi_oa_s : \pi(o,a|s)
        pi_o_as : \pi(o|a,s)
        pi_a_s  : \pi(a|s)
        old     : \tilde{\pi}
        """

        old_log_pi_a_so = torch.cat([
            p.oldpolicy.logsoftmax(states).unsqueeze(1).detach()
            for p in self.options
        ],
                                    dim=1)
        old_log_pg_o_s = self.oldpolicy.logsoftmax(states).detach()
        old_log_pi_oa_s = old_log_pi_a_so + old_log_pg_o_s.unsqueeze(-1)
        old_log_pi_a_s = old_log_pi_oa_s.exp().sum(1).log()
        old_log_pi_oia_s = old_log_pi_oa_s[np.arange(states.shape[0]), options]

    def calculate_surr(self, states, options, actions, advantages, grad=False):
        if grad:
            log_pi_a_so = torch.cat([
                p.policy.logsoftmax(states).unsqueeze(1) for p in self.options
            ],
                                    dim=1)
            log_pg_o_s = self.policy.logsoftmax(states)
        else:
            with torch.set_grad_enabled(False):
                log_pi_a_so = torch.cat([
                    p.policy.logsoftmax(states).unsqueeze(1)
                    for p in self.options
                ],
                                        dim=1)
                log_pg_o_s = self.policy.logsoftmax(states)

        log_pi_oa_s = log_pi_a_so + log_pg_o_s.unsqueeze(-1)
        log_pi_a_s = log_pi_oa_s.exp().sum(1).log()
        log_pi_o_as = log_pi_oa_s - log_pi_a_s.unsqueeze(1)

        H_O_AS = -(log_pi_a_s.exp() *
                   (log_pi_o_as * log_pi_o_as.exp()).sum(1)).sum(-1).mean()
        H_O = m_utils.entropy_logits(log_pg_o_s).mean()
        log_pi_o_ais = log_pi_o_as[np.arange(states.shape[0]), :,
                                   actions].exp().mean(0).log()
        log_pi_oi_ais = log_pi_o_as[np.arange(states.shape[0]), options,
                                    actions]
        log_pi_oia_s = log_pi_oa_s[np.arange(states.shape[0]), options]
        MI = m_utils.entropy_logits(log_pi_o_ais) - m_utils.entropy_logits(
            log_pi_oi_ais)

        ratio = torch.exp(
            m_utils.logp(pi, actions) - m_utils.logp(old_pi, actions))
        surrogate_gain = (ratio * advantages).mean()

        optimization_gain = surrogate_gain - self.MI_lambda * MI

#    def surr_get(self,grad=False):
#        Id,pid = RIM["MI_get"](grad)
#        return (torch.exp(m_utils.logp(pid,actions) - m_utils.logp(old_pi,actions))*advantages).mean() - self.MI_lambda*Id
#
#        RIM["gain"] = optimization_gain
#        RIM["surr_get"] = surr_get
#        return RIM
#
#        return MI
#
#        log_pi_a_so = torch.cat([p.policy.logsoftmax(states).unsqueeze(1) for p in self.options],dim=1)
#        log_pg_o_s = self.policy.logsoftmax(states)
#        log_pi_oa_s = log_pi_a_so+log_pg_o_s.unsqueeze(-1)
#        log_pi_a_s = log_pi_oa_s.exp().sum(1).log()
#        log_pi_o_as = log_pi_oa_s - log_pi_a_s.unsqueeze(1)
#
#
#        log_pi_o_ais = log_pi_o_as[np.arange(states.shape[0]),:,actions].exp().mean(0).log()
#        log_pi_oi_ais = log_pi_o_as[np.arange(states.shape[0]),options,actions]

    def mean_HKL(self, states, old_log_pi_a_s, grad=False):

        if grad:
            log_pi_a_so = torch.cat([
                p.policy.logsoftmax(states).unsqueeze(1) for p in self.options
            ],
                                    dim=1)
            log_pg_o_s = self.policy.logsoftmax(states)
        else:
            log_pi_a_so = torch.cat([
                p.policy.logsoftmax(states).detach().unsqueeze(1)
                for p in self.options
            ],
                                    dim=1)
            log_pg_o_s = self.policy.logsoftmax(states).detach()
        log_pi_a_s = (log_pi_a_so +
                      log_pg_o_s.unsqueeze(-1)).exp().sum(1).log()
        mean_kl_new_old = m_utils.kl_logits(old_log_pi_a_s, log_pi_a_s).mean()
        return mean_kl_new_old

    def mean_KL_gate(self, states, old_log_pg_o_s, grad=False):
        if grad:
            log_pg_o_s = self.policy.logsoftmax(states)
        else:
            log_pg_o_s = self.policy.logsoftmax(states).detach()
        return m_utils.kl_logits(old_log_pg_o_s, log_pg_o_s).mean()

    def load(self):
        super(GateTRPO, self).load()

        for p in self.options:
            p.load()
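
The log-space tensor algebra in KLRIM, calculate_surr and mean_HKL implements the factorization named in the KLRIM docstring; written out as a restatement, with pi_g the gate policy over options:

\pi(o, a \mid s) = \pi_g(o \mid s)\,\pi(a \mid s, o), \qquad
\pi(a \mid s) = \sum_o \pi(o, a \mid s), \qquad
\pi(o \mid a, s) = \frac{\pi(o, a \mid s)}{\pi(a \mid s)}

so log_pi_oa_s = log_pi_a_so + log_pg_o_s.unsqueeze(-1) and log_pi_a_s = log_pi_oa_s.exp().sum(1).log() are these identities in log space, and MI_lambda weights the mutual-information term between options and the actions they emit in the optimization gain.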