Code Example #1
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.hparams = config

        self.env = env_selector(self.hparams)
        self.Da = self.env.action_space.flat_dim
        self.Do = self.env.observation_space.flat_dim
        self.q1 = ValueFunction(self.Do + self.Da,
                                [config.layer_size, config.layer_size])
        self.q2 = ValueFunction(self.Do + self.Da,
                                [config.layer_size, config.layer_size])
        # Constructs a value function mlp with Relu hidden non-linearities, no output non-linearity and with xavier
        # init for weights and zero init for biases.
        self.q1_target = ValueFunction(self.Do + self.Da,
                                       [config.layer_size, config.layer_size])
        self.q2_target = ValueFunction(self.Do + self.Da,
                                       [config.layer_size, config.layer_size])

        self.pool_train = SimpleReplayBuffer(
            env_spec=self.env.spec,
            max_replay_buffer_size=config.max_pool_size,
        )  # create a replay buffer for state+skill and action.

        self.pool_val = SimpleReplayBuffer(
            env_spec=self.env.spec,
            max_replay_buffer_size=config.max_pool_size,
        )
        self.policy = GMMPolicy(
            env_spec=self.env.spec,
            K=config.K,
            hidden_layer_sizes=[config.layer_size, config.layer_size],
            # TODO: pass both q functions to use policy in deterministic mode
            qf=self.q1_target,
            reg=config.reg,
            device=self.hparams.device,
            reparametrization=True
        )  # GMM policy with K mixtures, reparametrization trick enabled, regularization

        # TODO: add assertion to test qf of policy and qf of model.

        self._policy_lr = config.lr
        self._qf_lr = config.lr
        self._vf_lr = config.lr
        # TODO: fix variable naming with _
        self._scale_reward = config.scale_reward
        self._discount = config.discount
        self._tau = config.tau
        self.max_path_return = -np.inf
        self.last_path_return = 0
        self.val_path_return = 0
        self._scale_entropy = config.scale_entropy

        self._save_full_state = config.save_full_state
        self.modules = [
            "Policy", self.policy, "Q1", self.q1, "Q2", self.q2, "Q1_target",
            self.q1_target, "Q2_target", self.q2_target
        ]
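
The twin target networks (q1_target, q2_target) and the stored self._tau imply a soft (Polyak) target update after each optimization step, as the on_batch_end hook in Code Example #10 shows. A minimal sketch of that update as a reusable helper (the name soft_update is illustrative, not part of the original code):

import torch

def soft_update(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    """In-place Polyak averaging: target <- (1 - tau) * target + tau * source."""
    with torch.no_grad():
        for src, targ in zip(source.parameters(), target.parameters()):
            targ.data.mul_(1.0 - tau)
            targ.data.add_(tau * src.data)

# Usage inside a training hook (sketch):
#   soft_update(self.q1, self.q1_target, self.hparams.tau)
#   soft_update(self.q2, self.q2_target, self.hparams.tau)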
Code Example #2
 def on_sanity_check_start(self) -> None:
     q1 = ValueFunction(self.Do + self.Da,
                        [self.hparams.layer_size, self.hparams.layer_size])
     self.q1.load_state_dict(q1.state_dict())
     q2 = ValueFunction(self.Do + self.Da,
                        [self.hparams.layer_size, self.hparams.layer_size])
     self.q2.load_state_dict(q2.state_dict())
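
Code Example #2 wipes checkpoint-loaded weights by constructing fresh ValueFunction instances and copying their state dicts. An equivalent in-place re-initialization sketch, assuming ValueFunction is an MLP of torch.nn.Linear layers with the Xavier-weight / zero-bias scheme described in the comments (the helper name reinit_value_function is illustrative, not part of the original code):

import torch

def reinit_value_function(net: torch.nn.Module) -> None:
    """Re-initialize every Linear layer in place: Xavier weights, zero biases."""
    for m in net.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            torch.nn.init.zeros_(m.bias)

# Usage (sketch):
#   reinit_value_function(self.q1)
#   reinit_value_function(self.q2)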
Code Example #3
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.hparams = config

        self.env = env_selector(
            self.hparams
        )  # TODO: normalization is not required but will it be needed?
        self.eval_env = env_selector(self.hparams, config.seed + 1)
        self.Da = self.env.action_space.flat_dim
        self.Do = self.env.observation_space.flat_dim  # includes skill in case env is option wrapped
        self.qf = ValueFunction(self.Do + self.Da,
                                [config.layer_size, config.layer_size])
        # Constructs a value function mlp with Relu hidden non-linearities, no output non-linearity and with xavier
        # init for weights and zero init for biases.
        self.vf = ValueFunction(self.Do,
                                [config.layer_size, config.layer_size])
        self.vf_target = ValueFunction(self.Do,
                                       [config.layer_size, config.layer_size])
        self.vf_target.load_state_dict(self.vf.state_dict())

        self.pool = SimpleReplayBuffer(
            env_spec=self.env.spec,
            max_replay_buffer_size=config.max_pool_size,
        )  # create a replay buffer for state+skill and action.

        self.policy = GMMPolicy(
            env_spec=self.env.spec,
            K=config.K,
            hidden_layer_sizes=[config.layer_size, config.layer_size],
            qf=self.qf,
            reg=config.reg,
            device=self.hparams.device
        )  # GMM policy with K mixtures, no reparametrization trick, regularization
        self.modules = [
            "Policy", self.policy, "QF", self.qf, "VF", self.vf, "VF_Target",
            self.vf_target
        ]

        # TODO: add assertion to test qf of policy and qf of model.

        self.sampler = Sampler(self.env, config.max_path_length)

        self._policy_lr = config.lr
        self._qf_lr = config.lr
        self._vf_lr = config.lr
        # TODO: fix variable naming with _
        self._scale_reward = config.scale_reward
        self._discount = config.discount
        self._tau = config.tau
        self.max_path_return = -np.inf
        self.last_path_return = 0
        self.val_path_return = 0
        self._scale_entropy = config.scale_entropy

        self._save_full_state = config.save_full_state
        # Runs on CPU: models are moved to the GPU by the trainer only after the LightningModule
        # is initialized, so sampling was moved to on_train_start to avoid a bug in DIAYN and to
        # use the GPU instead of the CPU (no device logic needed here).
        # TODO: remove device logic in Policy
        # This is also why the wandb logger is not available yet.
        self.batch_idx = None
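
The "add assertion to test qf of policy and qf of model" TODO above could be a one-line identity check. The attribute name _qf is inferred from the commented-out assertion self.policy._qf(states, actions) == self.q_values in the training code, so treat it as an assumption about GMMPolicy's internals:

def check_shared_qf(model) -> None:
    """Sketch of the TODO above: assert the policy reuses the model's Q-function object."""
    # The attribute name _qf is an assumption inferred from the commented-out assertion
    # in the training code, not a verified part of the GMMPolicy API.
    assert model.policy._qf is model.qf, "GMMPolicy must share the model's Q-function"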
Code Example #4
 def on_sanity_check_start(self):
     if self.hparams.distill:
         self.stage = 0
         self.distiller.append(self.discriminator)
         self.discriminator = self.distiller[self.stage]
     print(self.discriminator.state_dict)
     # Reinitializing to wipe out values loaded from checkpoint.
     qf = ValueFunction(self.Do + self.Da,
                        [self.hparams.layer_size, self.hparams.layer_size])
     self.qf.load_state_dict(qf.state_dict())
     vf = ValueFunction(self.Do,
                        [self.hparams.layer_size, self.hparams.layer_size])
     self.vf.load_state_dict(vf.state_dict())
     self.vf_target.load_state_dict(self.vf.state_dict())
     # TODO: figure out why vf doesn't need load_state_dict, i.e. doesn't throw a 'tensors do not require grad' error.
     policy = GMMPolicy(
         env_spec=self.env.spec,
         K=self.hparams.K,
         hidden_layer_sizes=[
             self.hparams.layer_size, self.hparams.layer_size
         ],
         qf=self.qf,
         reg=self.hparams.reg,
         device=self.hparams.device
     )  # GMM policy with K mixtures, no reparametrization trick, regularization
     self.policy.load_state_dict(policy.state_dict())
     # Verified by loading trained checkpoint that policy qf and self.qf are the exact same
     # TODO: hack to load only the discriminator; remove. Also, is the line below needed?
     self.modules = [
         "Policy", self.policy, "QF", self.qf, "VF", self.vf, "VF_Target",
         self.vf_target, "Discriminator", self.discriminator
     ]
     self.pool.add_samples(
         self.sampler.sample(self.hparams.min_pool_size, self.policy))
     print("Initialized Replay Buffer with %d samples" % self.pool.size)
     if self.on_gpu:
         self._p_z = self._p_z.cuda(self.hparams.device)
         for i in range(len(self.hparams.disc_size)):
             self.distiller[i].cuda(self.hparams.device)
         print("Moving p_z and distillers to GPU")
Code Example #5
class SAC(pl.LightningModule):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.hparams = config

        self.env = env_selector(
            self.hparams
        )  # TODO: normalization is not required but will it be needed?
        self.eval_env = env_selector(self.hparams, config.seed + 1)
        self.Da = self.env.action_space.flat_dim
        self.Do = self.env.observation_space.flat_dim  # includes skill in case env is option wrapped
        self.qf = ValueFunction(self.Do + self.Da,
                                [config.layer_size, config.layer_size])
        # Constructs a value function mlp with Relu hidden non-linearities, no output non-linearity and with xavier
        # init for weights and zero init for biases.
        self.vf = ValueFunction(self.Do,
                                [config.layer_size, config.layer_size])
        self.vf_target = ValueFunction(self.Do,
                                       [config.layer_size, config.layer_size])
        self.vf_target.load_state_dict(self.vf.state_dict())

        self.pool = SimpleReplayBuffer(
            env_spec=self.env.spec,
            max_replay_buffer_size=config.max_pool_size,
        )  # create a replay buffer for state+skill and action.

        self.policy = GMMPolicy(
            env_spec=self.env.spec,
            K=config.K,
            hidden_layer_sizes=[config.layer_size, config.layer_size],
            qf=self.qf,
            reg=config.reg,
            device=self.hparams.device
        )  # GMM policy with K mixtures, no reparametrization trick, regularization
        self.modules = [
            "Policy", self.policy, "QF", self.qf, "VF", self.vf, "VF_Target",
            self.vf_target
        ]

        # TODO: add assertion to test qf of policy and qf of model.

        self.sampler = Sampler(self.env, config.max_path_length)

        self._policy_lr = config.lr
        self._qf_lr = config.lr
        self._vf_lr = config.lr
        # TODO: fix variable naming with _
        self._scale_reward = config.scale_reward
        self._discount = config.discount
        self._tau = config.tau
        self.max_path_return = -np.inf
        self.last_path_return = 0
        self.val_path_return = 0
        self._scale_entropy = config.scale_entropy

        self._save_full_state = config.save_full_state
        # Runs on CPU: models are moved to the GPU by the trainer only after the LightningModule
        # is initialized, so sampling was moved to on_train_start to avoid a bug in DIAYN and to
        # use the GPU instead of the CPU (no device logic needed here).
        # TODO: remove device logic in Policy
        # This is also why the wandb logger is not available yet.
        self.batch_idx = None
        # torch.autograd.set_detect_anomaly(True) #TODO: disable if compute overhead

    def get_best_skill(self,
                       policy,
                       env,
                       num_skills,
                       max_path_length,
                       n_paths=1):
        print('Finding best skill...')
        reward_list = []
        with policy.deterministic(self.hparams.deterministic_eval):
            for z in range(num_skills):
                env.reset(state=None, skill=z)
                total_returns = 0
                sampler = Sampler(env, max_path_length)
                for p in range(n_paths):
                    new_paths = sampler.sample(max_path_length, policy)
                    total_returns += new_paths[-1]['path_return']
                print('Reward for skill %d = %.3f' % (z, total_returns))
                reward_list.append(total_returns)

        best_z = np.argmax(reward_list)
        print('Best skill found: z = %d, reward = %d, seed = %d' %
              (best_z, reward_list[best_z], self.hparams.seed))
        return best_z

    def on_sanity_check_start(self) -> None:
        self.pool.add_samples(
            self.sampler.sample(self.hparams.min_pool_size, self.policy))
        print("Initialized Replay Buffer with %d samples" % self.pool.size)

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.pool, self.hparams.epoch_length,
                            self.hparams.batch_size)

        # TODO: figure out why the reference code uses episode length above instead of batch size

        def _init_fn(worker_id):
            np.random.seed(self.hparams.seed + worker_id)

        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.hparams.batch_size,
                                num_workers=self.hparams.num_workers,
                                worker_init_fn=_init_fn)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self.__dataloader()

    def val_dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.pool, 1, 1)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=1,
            # num_workers=5
        )
        return dataloader

    # def _split_obs(self,t):
    # TODO remove from DIAYN, herf, and v2?
    #     # TODO: verify that dim is 1, assert shape
    #     return torch.split(t, [self._Do, self._num_skills], 1)

    def training_step(self, batch, batch_idx, optimizer_idx) -> OrderedDict:

        states, actions, rewards, dones, next_states = batch
        self.batch_idx = batch_idx
        # print(states[0], batch_idx)

        # print(self.pool.size,optimizer_idx,batch_idx,states[0])
        # print("Running train",states.shape,batch_idx,optimizer_idx)

        # TODO: vars are already floatTensors.
        # Train Policy
        if optimizer_idx == 0:
            # for param in self.policy.parameters():
            #     print(param.names, param.size(), param.requires_grad)
            # print("Done")
            # for param in self.vf.parameters():
            #     print(param.names, param.size(), param.requires_grad)
            # print("Donevf")
            # print(torch.max(rewards),torch.min(rewards),torch.mean(rewards))
            samples = self.sampler.sample(
                1, self.policy)  # TODO remove magic numbers
            self.pool.add_samples(samples)

            if samples[0]['done'] or samples[0][
                    'path_length'] == self.hparams.max_path_length:
                self.max_path_return = max(self.max_path_return,
                                           samples[0]['path_return'])
                self.last_path_return = samples[0]['path_return']

            distributions, action_samples, log_probs, corr, reg_loss = self.policy(
                states)
            assert log_probs.shape == torch.Size([action_samples.shape[0]])
            # TODO: figure out why squash correction is not done in policy as kl_surrogate seems
            # to need uncorrected log probs?
            self.values = self.vf(states)
            # print(action_samples.shape,log_probs.shape,reg_loss.shape,states.shape) #TODO assert shapes

            with torch.no_grad():
                self.log_targets = self.qf(states, action_samples)
                self.scaled_log_pi = self._scale_entropy * (log_probs - corr)

            # How is this kl surrogate loss derived?
            self._kl_surrogate_loss = torch.mean(
                log_probs *
                (self.scaled_log_pi - self.log_targets + self.values.detach()))
            self._policy_loss = reg_loss + self._kl_surrogate_loss
            self._vf_loss = 0.5 * torch.mean(
                (self.values - self.log_targets + self.scaled_log_pi)**2)

            log = {
                'max_path_return':
                self.max_path_return,
                'train_loss':
                self._policy_loss.detach().cpu().numpy(),
                'kl_loss':
                self._kl_surrogate_loss.detach().cpu().numpy(),
                'reg_loss':
                reg_loss.detach().cpu().numpy(),
                'gmm_means':
                torch.mean(distributions.component_distribution.mean).detach().
                cpu().numpy(),
                'gmm_sigmas':
                torch.mean(distributions.component_distribution.stddev).detach(
                ).cpu().numpy(),
                'vf_loss':
                self._vf_loss.detach().cpu().numpy(),
                'vf_value':
                torch.mean(self.values).detach().cpu().numpy(),
                'scaled_log_pi':
                torch.mean(self.scaled_log_pi).detach().cpu().numpy()
            }
            status = {
                'train_loss':
                self._policy_loss.detach().cpu().numpy(),
                # 'vf_loss': self._vf_loss,
                # 'steps': torch.tensor(self.global_step),  # .to(device)  # Where does this global_step come from? Is it built into PL?
                'max_ret':
                self.max_path_return,
                'last_ret':
                self.last_path_return,
                'gmm_mu':
                torch.mean(distributions.component_distribution.mean).detach().
                cpu().numpy(),
                'gmm_sig':
                torch.mean(distributions.component_distribution.stddev).detach(
                ).cpu().numpy(),
                'vf_loss':
                self._vf_loss.detach().cpu().numpy(),
                'vf_mu':
                torch.mean(self.values).detach().cpu().numpy()
            }

            return OrderedDict({
                'loss': self._policy_loss + self._vf_loss,
                'log': log,
                'progress_bar': status
            })

        # TODO is it faster if qf is also optimized simultaneously along with vf and policy?

        # Train QF
        if optimizer_idx == 1:
            # for param in self.qf.parameters():
            #     print(param.names, param.size(), param.requires_grad)
            # print("Doneqf")
            self.q_values = self.qf(states, actions)
            # assert (self.policy._qf(states,actions)==self.q_values).all()
            with torch.no_grad():
                vf_next_target = self.vf_target(next_states)  # N
                ys = self._scale_reward * rewards + (
                    1 - dones) * self._discount * vf_next_target  # N

            self._td_loss = 0.5 * torch.mean((ys - self.q_values)**2)

            return OrderedDict({
                'loss': self._td_loss,
                'log': {
                    'qf_loss': self._td_loss.detach().cpu().numpy(),
                    'qf_value':
                    torch.mean(self.q_values).detach().cpu().numpy(),
                    'rewards': torch.mean(rewards).detach().cpu().numpy()
                },
                'progress_bar': {
                    'qf_loss': self._td_loss,
                    'rewards': torch.mean(rewards).detach().cpu().numpy(),
                    'qf_mu': torch.mean(self.q_values).detach().cpu().numpy()
                }
            })

        # if self.trainer.use_dp or self.trainer.use_ddp2:
        #     loss = loss.unsqueeze(0)

    def on_batch_end(self) -> None:
        with torch.no_grad():
            for vf, vf_targ in zip(self.vf.parameters(),
                                   self.vf_target.parameters()):
                vf_targ.data.mul_(1 - self.hparams.tau)
                vf_targ.data.add_(self.hparams.tau * vf.data)

    def validation_step(self, batch, batch_idx) -> OrderedDict:
        # state = self.eval_env.reset()
        # print("Running Validation step")
        # path_return = 0
        # path_length = 0
        # for i in range(self.config.max_path_length):
        #     action = self.policy.get_actions(state.reshape((1, -1)))
        #     next_ob, reward, terminal, info = self.env.step(action)
        #     state = next_ob
        #     path_return += reward
        #     path_length += 1
        #     if(terminal):
        #         break

        return OrderedDict({'val_ret': 0, 'path_len': 0})

    def validation_epoch_end(self, outputs) -> OrderedDict:
        gc.collect()
        state = self.eval_env.reset()
        print(
            datetime.datetime.now(
                dateutil.tz.tzlocal()).strftime('%Y-%m-%d-%H-%M-%S-%f-%Z'))
        # print("Running Validation")
        path_return = 0
        path_length = 0
        self.ims = []
        with self.policy.deterministic(self.hparams.deterministic_eval):
            # TODO add support for n_eval_iters
            for i in range(self.hparams.max_path_length):
                action = self.policy.get_actions(state.reshape((1, -1)))
                next_ob, reward, done, info = self.eval_env.step(action)
                if self.hparams.render_validation:
                    # TODO use common resizing everywhere
                    self.ims.append(
                        cv2.resize(self.eval_env.render(mode='rgb_array'),
                                   (500, 500)))
                    # print(self.ims[0].shape)#config={'height':500,'width':500,'xpos':0,'ypos':0,'title':'validation'}
                state = next_ob
                path_return += reward
                path_length += 1
                if done:
                    break

        self.val_path_return = path_return  # TODO: remove the print callback for this; it is already printed in the progress bar
        return OrderedDict({
            'log': {
                'path_return': path_return,
                'path_length': path_length
            },
            'progress_bar': {
                'val_ret': path_return,
                'path_len': path_length
            }
        })

    def configure_optimizers(self) -> List[Optimizer]:
        """ Initialize Adam optimizer"""
        optimizers = []
        # TODO: combining vf and policy; figure out a more elegant way to have unlinked learning rates than as
        # a multiplication factor in the loss sum. Also figure out why having them separate doesn't increase
        # compute time by the expected amount.
        optimizers.append(
            optim.Adam(list(self.policy.parameters()) +
                       list(self.vf.parameters()),
                       lr=self._policy_lr))
        # optimizers.append(optim.Adam(self.vf.parameters(), lr=self._vf_lr))
        optimizers.append(optim.Adam(self.qf.parameters(), lr=self._qf_lr))
        return optimizers

    def forward(self, *args, **kwargs):
        return None

    def check_modules(self):
        self.policy.cuda(self.hparams.device)
        self.vf.cuda(self.hparams.device)
        self.qf.cuda(self.hparams.device)
        self.vf_target.cuda(self.hparams.device)
        for param in self.policy.parameters():
            print(param.data.shape, param.data.mean(), param.data.max(),
                  param.data.min(), param.data.std())
        for param in self.vf.parameters():
            print(param.data.shape, param.data.mean(), param.data.max(),
                  param.data.min(), param.data.std())
        for param in self.qf.parameters():
            print(param.data.shape, param.data.mean(), param.data.max(),
                  param.data.min(), param.data.std())
        for param in self.vf_target.parameters():
            print(param.data.shape, param.data.mean(), param.data.max(),
                  param.data.min(), param.data.std())
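
The __dataloader method above wraps the replay buffer in an RLDataset, whose implementation is not shown here. A minimal sketch of such a dataset as a torch IterableDataset, assuming SimpleReplayBuffer exposes random_batch with the keys used in Code Example #7 ('observations', 'actions', 'rewards', 'dones', 'next_observations'); the split of batching work between RLDataset and the DataLoader is an assumption (this sketch yields single transitions and lets the DataLoader re-batch them):

from torch.utils.data import IterableDataset

class RLDatasetSketch(IterableDataset):
    """Yields epoch_length * batch_size transitions per pass, sampled uniformly from the pool."""

    def __init__(self, pool, epoch_length: int, batch_size: int) -> None:
        self.pool = pool
        self.epoch_length = epoch_length
        self.batch_size = batch_size

    def __iter__(self):
        for _ in range(self.epoch_length):
            batch = self.pool.random_batch(self.batch_size)
            for i in range(self.batch_size):
                yield (batch['observations'][i], batch['actions'][i],
                       batch['rewards'][i], batch['dones'][i],
                       batch['next_observations'][i])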
Code Example #6
    def __init__(self, config: Config) -> None:
        self.hparams = config
        self.env = env_selector(self.hparams)  # TODO: ensure normalization is not required
        self.eval_env = env_selector(self.hparams, config.seed + 1)  # TODO: add functionality to optionwrap for DIAYN
        # TODO: check all config.names to ensure they are in dict
        self.Da = self.env.action_space.flat_dim
        self.Do = self.env.observation_space.flat_dim
        self.qf = ValueFunction(self.Do + self.Da, [config.layer_size, config.layer_size])
        # Constructs a value function mlp with Relu hidden non-linearities, no output non-linearity and with xavier
        # init for weights and zero init for biases.
        self.vf = ValueFunction(self.Do, [config.layer_size, config.layer_size])
        self.vf_target = ValueFunction(self.Do, [config.layer_size, config.layer_size])
        self.vf_target.load_state_dict(self.vf.state_dict())

        self.pool = SimpleReplayBuffer(
            env_spec=self.env.spec,
            max_replay_buffer_size=config.max_pool_size,
        )  # create a replay buffer for state+skill and action.

        self.policy = GMMPolicy(
            env_spec=self.env.spec,
            K=config.K,
            hidden_layer_sizes=[config.layer_size, config.layer_size],
            qf=self.qf,
            reg=config.reg,
            device="cpu"
        )  # GMM policy with K mixtures, no reparametrization trick, regularization

        # self.policy.cuda(config.device)
        # self.vf.cuda(config.device)
        # self.qf.cuda(config.device)
        # self.vf_target.cuda(config.device)

        # TODO: add assertion to test qf of policy and qf of model.

        self.sampler = Sampler(self.env, config.max_path_length)

        self._policy_lr = config.lr
        self._qf_lr = config.lr
        self._vf_lr = config.lr
        # TODO: fix variable naming with _
        self._scale_reward = config.scale_reward
        self._discount = config.discount
        self._tau = config.tau
        self.max_path_return = -np.inf
        self.last_path_return = 0
        self.val_path_return = 0
        self._scale_entropy = config.scale_entropy

        self._save_full_state = config.save_full_state
        # self.z = self.get_best_skill(self.policy, self.env, self.config.num_skills, self.config.max_path_length)
        # self.env.reset(None,self.z)

        # Runs on CPU, as models are transferred to the GPU only by the trainer, which happens
        # after the LightningModule init.
        # This is also why the wandb logger is not available yet.
        self.pool.add_samples(self.sampler.sample(config.min_pool_size, self.policy))
        # self.optimizers = []
        # TODO: combining vf and policy; figure out a more elegant way to have unlinked learning rates than as
        # a multiplication factor in the loss sum. Also figure out why having them separate doesn't increase
        # compute time by the expected amount.
        self.optimizer_policy = optim.Adam(list(self.policy.parameters())  # +list(self.vf.parameters())
                                           , lr=self._policy_lr)
        self.optimizer_vf = optim.Adam(self.vf.parameters(), lr=self._vf_lr)
        self.optimizer_qf = optim.Adam(self.qf.parameters(), lr=self._qf_lr)
        self.optimizer = optim.Adam(list(self.policy.parameters())+
                                    list(self.vf.parameters())+
                                    list(self.qf.parameters()), lr=self._policy_lr)
Code Example #7
class SAC():

    def __init__(self, config: Config) -> None:
        self.hparams = config
        self.env = env_selector(self.hparams)  # TODO: ensure normalization is not required
        self.eval_env = env_selector(self.hparams, config.seed + 1)  # TODO: add functionality to optionwrap for DIAYN
        # TODO: check all config.names to ensure they are in dict
        self.Da = self.env.action_space.flat_dim
        self.Do = self.env.observation_space.flat_dim
        self.qf = ValueFunction(self.Do + self.Da, [config.layer_size, config.layer_size])
        # Constructs a value function mlp with Relu hidden non-linearities, no output non-linearity and with xavier
        # init for weights and zero init for biases.
        self.vf = ValueFunction(self.Do, [config.layer_size, config.layer_size])
        self.vf_target = ValueFunction(self.Do, [config.layer_size, config.layer_size])
        self.vf_target.load_state_dict(self.vf.state_dict())

        self.pool = SimpleReplayBuffer(
            env_spec=self.env.spec,
            max_replay_buffer_size=config.max_pool_size,
        )  # create a replay buffer for state+skill and action.

        self.policy = GMMPolicy(
            env_spec=self.env.spec,
            K=config.K,
            hidden_layer_sizes=[config.layer_size, config.layer_size],
            qf=self.qf,
            reg=config.reg,
            device="cpu"
        )  # GMM policy with K mixtures, no reparametrization trick, regularization

        # self.policy.cuda(config.device)
        # self.vf.cuda(config.device)
        # self.qf.cuda(config.device)
        # self.vf_target.cuda(config.device)

        # TODO: add assertion to test qf of policy and qf of model.

        self.sampler = Sampler(self.env, config.max_path_length)

        self._policy_lr = config.lr
        self._qf_lr = config.lr
        self._vf_lr = config.lr
        # TODO: fix variable naming with _
        self._scale_reward = config.scale_reward
        self._discount = config.discount
        self._tau = config.tau
        self.max_path_return = -np.inf
        self.last_path_return = 0
        self.val_path_return = 0
        self._scale_entropy = config.scale_entropy

        self._save_full_state = config.save_full_state
        # self.z = self.get_best_skill(self.policy, self.env, self.config.num_skills, self.config.max_path_length)
        # self.env.reset(None,self.z)

        # Runs on CPU, as models are transferred to the GPU only by the trainer, which happens
        # after the LightningModule init.
        # This is also why the wandb logger is not available yet.
        self.pool.add_samples(self.sampler.sample(config.min_pool_size, self.policy))
        # self.optimizers = []
        # TODO: combining vf and policy; figure out a more elegant way to have unlinked learning rates than as
        # a multiplication factor in the loss sum. Also figure out why having them separate doesn't increase
        # compute time by the expected amount.
        self.optimizer_policy = optim.Adam(list(self.policy.parameters())  # +list(self.vf.parameters())
                                           , lr=self._policy_lr)
        self.optimizer_vf = optim.Adam(self.vf.parameters(), lr=self._vf_lr)
        self.optimizer_qf = optim.Adam(self.qf.parameters(), lr=self._qf_lr)
        self.optimizer = optim.Adam(list(self.policy.parameters())+
                                    list(self.vf.parameters())+
                                    list(self.qf.parameters()), lr=self._policy_lr)
        # torch.autograd.set_detect_anomaly(True)

    @staticmethod
    def _squash_correction(t):
        """receives action samples from gmm of shape batchsize x dim_action. For each action, the log probability
         correction requires a product by the inverse of the jacobian determinant. In log, it reduces to a sum, including
         the determinant of the diagonal jacobian. Adding epsilon to avoid overflow due to log
         Should return a tensor of batchsize x 1"""
        # TODO: Refer to OpenAI implementation for more numerically stable correction
        # https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/sac/core.py
        return torch.sum(torch.log(1 - (t ** 2) + EPS), dim=1)
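
    # A potentially more numerically stable squash correction, following the form used in the
    # OpenAI Spinning Up SAC reference linked above: log(1 - tanh(u)^2) = 2*(log(2) - u - softplus(-2u)).
    # Note it expects the pre-tanh samples u rather than the squashed actions t, so it is an
    # illustrative alternative (an assumption about the available inputs), not part of the original code.
    @staticmethod
    def _squash_correction_stable(u):
        return torch.sum(2.0 * (np.log(2.0) - u - torch.nn.functional.softplus(-2.0 * u)), dim=1)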

    def train(self):
        for epoch in range(self.hparams.max_epochs):
            for step in range(self.hparams.epoch_length):

                samples = self.sampler.sample(1, self.policy)  # TODO remove magic numbers
                self.pool.add_samples(samples)
                # print(samples[0]['done'])
                if samples[0]['done'] or samples[0]['path_length'] == self.hparams.max_path_length:
                    self.max_path_return = max(self.max_path_return, samples[0]['path_return'])
                    self.last_path_return = samples[0]['path_return']

                batch = self.pool.random_batch(self.hparams.batch_size)
                states, rewards, actions, dones, next_states = torch.FloatTensor(
                    batch['observations']), torch.FloatTensor(batch['rewards']), torch.FloatTensor(
                    batch['actions']), torch.FloatTensor(batch['dones']), torch.FloatTensor(batch['next_observations'])
                # self.optimizer_policy.zero_grad()
                self.optimizer.zero_grad()
                distributions, action_samples, log_probs, reg_loss = self.policy(states)
                # print(log_probs.shape)
                # assert log_probs.shape == torch.Size([action_samples.shape[0]])
                # TODO: figure out why squash correction is not done in policy as kl_surrogate seems
                # to need uncorrected log probs?
                self.values = self.vf(states)
                # print(action_samples.shape,log_probs.shape,reg_loss.shape,states.shape) #TODO assert shapes

                with torch.no_grad():

                    self.log_targets = self.qf(states, action_samples)
                    # The probability of the squashed action is not the same as that of the unsquashed action.
                    corr = self._squash_correction(action_samples)
                    # print(log_probs.shape,corr.shape)
                    # assert not torch.isnan(corr).any() and not torch.isinf(corr).any()
                    # The correction must be subtracted from log_probs since we need the inverse of the Jacobian determinant.
                    self.scaled_log_pi = self._scale_entropy * (log_probs - corr)


                # self._vf_loss = 0.5 * torch.mean(
                #             (self.values - self.log_targets - self.scaled_log_pi) ** 2)
                ## How is this kl surrogate loss derived?
                self._kl_surrogate_loss = torch.mean(log_probs * (
                        self.scaled_log_pi - self.log_targets + self.values.detach()))
                self._policy_loss = reg_loss + self._kl_surrogate_loss

                # self._policy_loss.backward()
                # self.optimizer_policy.step()
                #
                # self.optimizer_vf.zero_grad()
                # self.values = self.vf(states)
                self._vf_loss = 0.5 * torch.mean(
                    (self.values - self.log_targets + self.scaled_log_pi) ** 2)



                # self._vf_loss.backward()
                # self.optimizer_vf.step()
                #
                # self.optimizer_qf.zero_grad()
                self.q_values = self.qf(states, actions)
                # assert (self.policy._qf(states,actions)==self.q_values).all()
                with torch.no_grad():
                    vf_next_target = self.vf_target(next_states)  # N
                    # self._vf_target_params = self._vf.get_params_internal()

                    ys = self._scale_reward * rewards + (1 - dones) * self._discount * vf_next_target  # N


                self._td_loss = 0.5 * torch.mean((ys - self.q_values) ** 2)

                # TODO: code not working, need to fix bug
                self.loss = self._policy_loss + self._vf_loss + self._td_loss
                self.loss.backward()
                self.optimizer.step()

                with torch.no_grad():
                    for vf, vf_targ in zip(self.vf.parameters(), self.vf_target.parameters()):
                        vf_targ.data.mul_(1 - self.hparams.tau)
                        vf_targ.data.add_((self.hparams.tau) * vf.data)


            print('train_loss: ', self._policy_loss.detach().numpy(),
                  'epoch: ', epoch,
                  # 'vf_loss': self._vf_loss,
                  # 'steps': torch.tensor(self.global_step),  # .to(device)  # Where does this global_step come from? Is it built into PL?
                  'max_return: ', (self.max_path_return),
                  'last_return: ', (self.last_path_return),
                  # 'gmm_means: ', torch.mean(distributions.component_distribution.mean).detach().numpy(),
                  # 'gmm_sigmas: ', torch.mean(distributions.component_distribution.stddev).detach().numpy(),
                  'vf_loss: ', self._vf_loss.detach().numpy(),
                  'vf_value: ', torch.mean(self.values).detach().numpy(),
                  'qf_loss: ', self._td_loss.detach().numpy(),
                  'rewards: ', torch.mean(rewards).detach().numpy(),
                  'actions: ', torch.mean(actions).detach().numpy(),
                  'qf_value: ', torch.mean(self.q_values).detach().numpy()
                  )

            state = self.eval_env.reset()
            # print("Running Validation")
            path_return = 0
            path_length = 0
            self.ims = []
            print(datetime.datetime.now(dateutil.tz.tzlocal()).strftime('%Y-%m-%d-%H-%M-%S-%f-%Z'))
            # with self.policy.deterministic(True):
            #     for i in range(self.hparams.max_path_length):
            #         action = self.policy.get_actions(state.reshape((1, -1)))
            #         next_ob, reward, done, info = self.eval_env.step(action)
            #         if self.hparams.render_validation:
            #             self.ims.append(self.eval_env.render(mode='rgb_array'))
            #             # print(self.ims[0].shape)#config={'height':500,'width':500,'xpos':0,'ypos':0,'title':'validation'}
            #         # print(reward)
            #         state = next_ob
            #         path_return += reward
            #         path_length += 1
            #         if (done):
            #             break

            self.val_path_return = path_return
            print('path_return: ', path_return,
                  'path_length: ', path_length)
Code Example #8
class DISTILL_Q(pl.LightningModule):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.hparams = config

        self.env = env_selector(self.hparams)
        self.Da = self.env.action_space.flat_dim
        self.Do = self.env.observation_space.flat_dim
        self.q1 = ValueFunction(self.Do + self.Da,
                                [config.layer_size, config.layer_size])
        self.q2 = ValueFunction(self.Do + self.Da,
                                [config.layer_size, config.layer_size])
        # Constructs a value function mlp with Relu hidden non-linearities, no output non-linearity and with xavier
        # init for weights and zero init for biases.
        self.q1_target = ValueFunction(self.Do + self.Da,
                                       [config.layer_size, config.layer_size])
        self.q2_target = ValueFunction(self.Do + self.Da,
                                       [config.layer_size, config.layer_size])
        self.stage = None

        self.pool_train = SimpleReplayBuffer(
            env_spec=self.env.spec,
            max_replay_buffer_size=config.max_pool_size,
        )  # create a replay buffer for state+skill and action.

        self.pool_val = SimpleReplayBuffer(
            env_spec=self.env.spec,
            max_replay_buffer_size=config.max_pool_size,
        )

        self.policy = GMMPolicy(
            env_spec=self.env.spec,
            K=config.K,
            hidden_layer_sizes=[config.layer_size, config.layer_size],
            #TODO: pass both q functions to use policy in deterministic mode
            qf=self.q1_target,
            reg=config.reg,
            device=self.hparams.device,
            reparametrization=True
        )  # GMM policy with K mixtures, reparametrization trick enabled, regularization

        # TODO: add assertion to test qf of policy and qf of model.

        self._policy_lr = config.lr
        self._qf_lr = config.lr
        self._vf_lr = config.lr
        # TODO: fix variable naming with _
        self._scale_reward = config.scale_reward
        self._discount = config.discount
        self._tau = config.tau
        self.max_path_return = -np.inf
        self.last_path_return = 0
        self.val_path_return = 0
        self._scale_entropy = config.scale_entropy

        self._save_full_state = config.save_full_state
        self.modules = [
            "Policy", self.policy, "Q1", self.q1, "Q2", self.q2, "Q1_target",
            self.q1_target, "Q2_target", self.q2_target
        ]
        # self.z = self.get_best_skill(self.policy, self.env, self.config.num_skills, self.config.max_path_length)
        # self.env.reset(None,self.z)

        # Runs on CPU, as models are transferred to the GPU only by the trainer, which happens
        # after the LightningModule init.
        # This is also why the wandb logger is not available yet.

    def on_sanity_check_start(self) -> None:
        q1 = ValueFunction(self.Do + self.Da,
                           [self.hparams.layer_size, self.hparams.layer_size])
        self.q1.load_state_dict(q1.state_dict())
        q2 = ValueFunction(self.Do + self.Da,
                           [self.hparams.layer_size, self.hparams.layer_size])
        self.q2.load_state_dict(q2.state_dict())

    def on_epoch_start(self) -> None:
        print("Distilling epoch %d" % self.current_epoch)

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.pool_train, self.hparams.epoch_length,
                            self.hparams.batch_size)

        # TODO: figure out why the reference code uses episode length above instead of batch size

        def _init_fn(worker_id):
            np.random.seed(self.hparams.seed + worker_id)

        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.hparams.batch_size,
                                num_workers=self.hparams.num_workers,
                                worker_init_fn=_init_fn)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self.__dataloader()

    def val_dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.pool_val, self.hparams.epoch_length,
                            self.hparams.batch_size)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            # num_workers=5
        )
        return dataloader

    def training_step(self, batch, batch_idx) -> OrderedDict:

        states, actions, rewards, dones, next_states = batch

        self.q1_values = self.q1(states, actions)
        self.q2_values = self.q2(states, actions)
        # assert (self.policy._qf(states,actions)==self.q_values).all()
        with torch.no_grad():
            q1_next_target = self.q1_target(states, actions)  # N
            q2_next_target = self.q2_target(states, actions)

        self._td1_loss = torch.mean((q1_next_target - self.q1_values)**2)
        self._td2_loss = torch.mean((q2_next_target - self.q2_values)**2)

        return OrderedDict({
            'loss': self._td1_loss + self._td2_loss,
            'log': {
                'qf_distill_loss_%d' % self.stage:
                self._td1_loss + self._td2_loss,
                'qf_distill_value_%d' % self.stage: torch.mean(self.q1_values)
            },
            'progress_bar': {
                'qf_loss': self._td1_loss + self._td2_loss,
                'qf_mu': torch.mean(self.q1_values)
            }
        })

        # if self.trainer.use_dp or self.trainer.use_ddp2:
        #     loss = loss.unsqueeze(0)

    def configure_optimizers(self) -> List[Optimizer]:
        """ Initialize Adam optimizer"""
        optimizers = []
        optimizers.append(
            optim.Adam(list(self.q1.parameters()) + list(self.q2.parameters()),
                       lr=self._qf_lr))
        return optimizers

    def forward(self, *args, **kwargs):
        return None

    def validation_step(self, batch, batch_idx) -> OrderedDict:
        states, actions, rewards, dones, next_states = batch
        q1_values = self.q1(states, actions)
        q2_values = self.q2(states, actions)
        # assert (self.policy._qf(states,actions)==self.q_values).all()
        with torch.no_grad():
            q1_next_target = self.q1_target(states, actions)  # N
            q2_next_target = self.q2_target(states, actions)

        td1_loss = torch.mean((q1_next_target - q1_values)**2)
        td2_loss = torch.mean((q2_next_target - q2_values)**2)

        return OrderedDict({'val_loss': td2_loss + td1_loss})

    def validation_epoch_end(self, outputs) -> OrderedDict:
        # called at the end of a validation epoch
        # outputs is an array with what you returned in validation_step for each batch
        # outputs = [{'loss': batch_0_loss}, {'loss': batch_1_loss}, ..., {'loss': batch_n_loss}]

        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        log = {'val_loss_distill_qf_%d' % self.stage: avg_loss}
        return {'val_loss': avg_loss, 'log': log}
Code Example #9
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.hparams = config

        self.env = env_selector(
            self.hparams
        )  # TODO: normalization is not required but will it be needed?
        self.eval_env = env_selector(self.hparams, config.seed + 1)
        # TODO: check all config.names to ensure they are in dict
        self.Da = self.env.action_space.flat_dim
        self.Do = self.env.observation_space.flat_dim
        self.q1 = ValueFunction(self.Do + self.Da,
                                [config.layer_size, config.layer_size])
        self.q2 = ValueFunction(self.Do + self.Da,
                                [config.layer_size, config.layer_size])
        # Constructs a value function mlp with Relu hidden non-linearities, no output non-linearity and with xavier
        # init for weights and zero init for biases.
        self.q1_target = ValueFunction(self.Do + self.Da,
                                       [config.layer_size, config.layer_size])
        self.q2_target = ValueFunction(self.Do + self.Da,
                                       [config.layer_size, config.layer_size])

        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

        self.pool = SimpleReplayBuffer(
            env_spec=self.env.spec,
            max_replay_buffer_size=config.max_pool_size,
        )  # create a replay buffer for state+skill and action.

        self.policy = GMMPolicy(
            env_spec=self.env.spec,
            K=config.K,
            hidden_layer_sizes=[config.layer_size, config.layer_size],
            #TODO: pass both q functions to use policy in deterministic mode
            qf=self.q1_target,
            reg=config.reg,
            device=self.hparams.device,
            reparametrization=True
        )  # GMM policy with K mixtures, reparametrization trick enabled, regularization

        # TODO: add assertion to test qf of policy and qf of model.

        self.sampler = Sampler(self.env, config.max_path_length)

        self._policy_lr = config.lr
        self._qf_lr = config.lr
        self._vf_lr = config.lr
        # TODO: fix variable naming with _
        self._scale_reward = config.scale_reward
        self._discount = config.discount
        self._tau = config.tau
        self.max_path_return = -np.inf
        self.last_path_return = 0
        self.val_path_return = 0
        self._scale_entropy = config.scale_entropy

        self._save_full_state = config.save_full_state
        self.modules = [
            "Policy", self.policy, "Q1", self.q1, "Q2", self.q2, "Q1_target",
            self.q1_target, "Q2_target", self.q2_target
        ]
        # self.z = self.get_best_skill(self.policy, self.env, self.config.num_skills, self.config.max_path_length)
        # self.env.reset(None,self.z)

        # Runs on CPU, as models are transferred to the GPU only by the trainer, which happens
        # after the LightningModule init.
        # This is also why the wandb logger is not available yet.
        self.batch_idx = None
Code Example #10
class SAC(pl.LightningModule):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.hparams = config

        self.env = env_selector(
            self.hparams
        )  # TODO: normalization is not required but will it be needed?
        self.eval_env = env_selector(self.hparams, config.seed + 1)
        # TODO: check all config.names to ensure they are in dict
        self.Da = self.env.action_space.flat_dim
        self.Do = self.env.observation_space.flat_dim
        self.q1 = ValueFunction(self.Do + self.Da,
                                [config.layer_size, config.layer_size])
        self.q2 = ValueFunction(self.Do + self.Da,
                                [config.layer_size, config.layer_size])
        # Constructs a value function mlp with Relu hidden non-linearities, no output non-linearity and with xavier
        # init for weights and zero init for biases.
        self.q1_target = ValueFunction(self.Do + self.Da,
                                       [config.layer_size, config.layer_size])
        self.q2_target = ValueFunction(self.Do + self.Da,
                                       [config.layer_size, config.layer_size])

        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

        self.pool = SimpleReplayBuffer(
            env_spec=self.env.spec,
            max_replay_buffer_size=config.max_pool_size,
        )  # create a replay buffer for state+skill and action.

        self.policy = GMMPolicy(
            env_spec=self.env.spec,
            K=config.K,
            hidden_layer_sizes=[config.layer_size, config.layer_size],
            #TODO: pass both q functions to use policy in deterministic mode
            qf=self.q1_target,
            reg=config.reg,
            device=self.hparams.device,
            reparametrization=True
        )  # GMM policy with K mixtures, reparametrization trick enabled, regularization

        # TODO: add assertion to test qf of policy and qf of model.

        self.sampler = Sampler(self.env, config.max_path_length)

        self._policy_lr = config.lr
        self._qf_lr = config.lr
        self._vf_lr = config.lr
        # TODO: fix variable naming with _
        self._scale_reward = config.scale_reward
        self._discount = config.discount
        self._tau = config.tau
        self.max_path_return = -np.inf
        self.last_path_return = 0
        self.val_path_return = 0
        self._scale_entropy = config.scale_entropy

        self._save_full_state = config.save_full_state
        self.modules = [
            "Policy", self.policy, "Q1", self.q1, "Q2", self.q2, "Q1_target",
            self.q1_target, "Q2_target", self.q2_target
        ]
        # self.z = self.get_best_skill(self.policy, self.env, self.config.num_skills, self.config.max_path_length)
        # self.env.reset(None,self.z)

        # Runs on CPU, as models are transferred to the GPU only by the trainer, which happens
        # after the LightningModule init.
        # This is also why the wandb logger is not available yet.
        self.batch_idx = None
        # torch.autograd.set_detect_anomaly(True) #TODO: disable if compute overhead

    def get_best_skill(self,
                       policy,
                       env,
                       num_skills,
                       max_path_length,
                       n_paths=1):
        print('Finding best skill...')
        reward_list = []
        with policy.deterministic(self.hparams.deterministic_eval):
            for z in range(num_skills):
                env.reset(state=None, skill=z)
                total_returns = 0
                sampler = Sampler(env, max_path_length)
                for p in range(n_paths):
                    new_paths = sampler.sample(max_path_length, policy)
                    total_returns += new_paths[-1]['path_return']
                print('Reward for skill %d = %.3f' % (z, total_returns))
                reward_list.append(total_returns)

        best_z = np.argmax(reward_list)
        print('Best skill found: z = %d, reward = %d, seed = %d' %
              (best_z, reward_list[best_z], self.hparams.seed))
        return best_z

    def on_sanity_check_start(self) -> None:
        # self.z = self.get_best_skill(self.policy, self.env, self.hparams.num_skills, self.hparams.max_path_length,
        #                              self.hparams.num_runs)
        # self._num_skills = self.hparams.num_skills
        # self.env.reset(state=None, skill=self.z)
        # self.eval_env.reset(state=None, skill=self.z)
        # # TODO sampler reset logic and epoch length interaction seems adhoc
        # self.sampler.reset()
        if self.pool.size < self.hparams.min_pool_size:
            self.pool.add_samples(
                self.sampler.sample(self.hparams.min_pool_size, None))
            print("Initialized Replay Buffer with %d samples" % self.pool.size)

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.pool, self.hparams.epoch_length,
                            self.hparams.batch_size)

        # TODO: figure out why the reference code uses episode length above instead of batch size

        def _init_fn(worker_id):
            np.random.seed(self.hparams.seed + worker_id)

        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.hparams.batch_size,
                                num_workers=self.hparams.num_workers,
                                worker_init_fn=_init_fn)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self.__dataloader()

    def val_dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.pool, 1, 1)
        # TODO: figure out why the reference code uses episode length above instead of batch size
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=1,
            # num_workers=5
        )
        return dataloader

    # def _split_obs(self,t):
    #     # TODO: verify that dim is 1, assert shape
    #     return torch.split(t, [self._Do, self._num_skills], 1)

    def training_step(self, batch, batch_idx, optimizer_idx) -> OrderedDict:

        states, actions, rewards, dones, next_states = batch
        self.batch_idx = batch_idx
        # print(states[0], batch_idx)

        # print(self.pool.size,optimizer_idx,batch_idx,states[0])
        # print("Running train",states.shape,batch_idx,optimizer_idx)

        # TODO: vars are already floatTensors.
        # Train Policy
        if optimizer_idx == 1:
            # for param in self.policy.parameters():
            #     print(param.names, param.size(), param.requires_grad)
            # print("Done")
            # for param in self.vf.parameters():
            #     print(param.names, param.size(), param.requires_grad)
            # print("Donevf")
            # print(torch.max(rewards),torch.min(rewards),torch.mean(rewards))
            samples = self.sampler.sample(
                1, self.policy)  # TODO remove magic numbers
            self.pool.add_samples(samples)

            if samples[0]['done'] or samples[0][
                    'path_length'] == self.hparams.max_path_length:
                self.max_path_return = max(self.max_path_return,
                                           samples[0]['path_return'])
                self.last_path_return = samples[0]['path_return']

            distributions, action_samples, log_probs, corr, reg_loss = self.policy(
                states)
            # print(log_probs.shape)
            assert log_probs.shape == torch.Size([action_samples.shape[0]])
            values1 = self.q1(states, action_samples)
            values2 = self.q2(states, action_samples)
            self.value = torch.min(values1, values2)  # N
            # print(action_samples.shape,log_probs.shape,reg_loss.shape,states.shape) #TODO assert shapes

            # with torch.no_grad():
            # TODO : check grad
            self.scaled_log_pi = self._scale_entropy * (log_probs - corr)
            self._policy_loss = torch.mean(self.scaled_log_pi - self.value)

            log = {
                'max_path_return': torch.tensor(self.max_path_return),
                'train_loss': self._policy_loss,
                'reg_loss': reg_loss,
                'vf_value': torch.mean(self.value)
            }
            status = {
                'train_loss': self._policy_loss,
                'max_ret': torch.tensor(self.max_path_return),
                'last_ret': torch.tensor(self.last_path_return),
                'vf_mu': torch.mean(self.value)
            }

            return OrderedDict({
                'loss': self._policy_loss,
                'log': log,
                'progress_bar': status
            })

        # Train QF
        if optimizer_idx == 0:
            # for param in self.qf.parameters():
            #     print(param.names, param.size(), param.requires_grad)
            # print("Doneqf")
            self.q1_values = self.q1(states, actions)
            self.q2_values = self.q2(states, actions)
            # assert (self.policy._qf(states,actions)==self.q_values).all()
            with torch.no_grad():
                distributions, action_samples, log_probs, corr, reg_loss = self.policy(
                    next_states)
                q1_next_target = self.q1_target(next_states,
                                                action_samples)  # N
                q2_next_target = self.q2_target(next_states, action_samples)
                q_next_target = torch.min(q1_next_target, q2_next_target)  # N

                ys = self._scale_reward * rewards + (1 - dones) * self._discount * \
                     (q_next_target-self._scale_entropy*(log_probs - corr))  # N

            self._td1_loss = torch.mean((ys - self.q1_values)**2)
            self._td2_loss = torch.mean((ys - self.q2_values)**2)

            return OrderedDict({
                'loss': self._td1_loss + self._td2_loss,
                'log': {
                    'qf_loss': self._td1_loss + self._td2_loss,
                    'qf_value': torch.mean(self.q1_values),
                    'rewards': torch.mean(rewards)
                },
                'progress_bar': {
                    'qf_loss': self._td1_loss + self._td2_loss,
                    'rewards': torch.mean(rewards),
                    'qf_mu': torch.mean(self.q1_values),
                    'log_probs': torch.mean(log_probs - corr)
                }
            })

        # if self.trainer.use_dp or self.trainer.use_ddp2:
        #     loss = loss.unsqueeze(0)

    def on_batch_end(self) -> None:
        with torch.no_grad():
            for q1, q1_targ in zip(self.q1.parameters(),
                                   self.q1_target.parameters()):
                q1_targ.data.mul_(1 - self.hparams.tau)
                q1_targ.data.add_((self.hparams.tau) * q1.data)
            for q2, q2_targ in zip(self.q2.parameters(),
                                   self.q2_target.parameters()):
                q2_targ.data.mul_(1 - self.hparams.tau)
                q2_targ.data.add_((self.hparams.tau) * q2.data)

    def validation_step(self, batch, batch_idx) -> OrderedDict:
        # state = self.eval_env.reset()
        # print("Running Validation step")
        # path_return = 0
        # path_length = 0
        # for i in range(self.config.max_path_length):
        #     action = self.policy.get_actions(state.reshape((1, -1)))
        #     next_ob, reward, terminal, info = self.env.step(action)
        #     state = next_ob
        #     path_return += reward
        #     path_length += 1
        #     if(terminal):
        #         break

        return OrderedDict({'val_ret': 0, 'path_len': 0})

    def validation_epoch_end(self, outputs) -> OrderedDict:
        state = self.eval_env.reset()
        print(
            datetime.datetime.now(
                dateutil.tz.tzlocal()).strftime('%Y-%m-%d-%H-%M-%S-%f-%Z'))
        # print("Running Validation")
        path_return = 0
        path_length = 0
        self.ims = []
        with self.policy.deterministic(self.hparams.deterministic_eval):
            for i in range(self.hparams.max_path_length):
                action = self.policy.get_actions(state.reshape((1, -1)))
                next_ob, reward, done, info = self.eval_env.step(action)
                # self.eval_env.render(mode='human')
                if self.hparams.render_validation:
                    self.ims.append(self.eval_env.render(mode='rgb_array'))
                    # print(self.ims[0].shape)#config={'height':500,'width':500,'xpos':0,'ypos':0,'title':'validation'}
                # print(reward)
                state = next_ob
                path_return += reward
                path_length += 1
                if (done):
                    break

        self.val_path_return = path_return  # TODO: remove the print callback for this; it is already printed in the progress bar
        return OrderedDict({
            'log': {
                'path_return': path_return,
                'path_length': path_length
            },
            'progress_bar': {
                'val_ret': path_return,
                'path_len': path_length
            }
        })

    def configure_optimizers(self) -> List[Optimizer]:
        """ Initialize Adam optimizer"""
        optimizers = []
        # TODO: combining vf and policy; figure out a more elegant way to have unlinked learning rates than as
        # a multiplication factor in the loss sum. Also figure out why having them separate doesn't increase
        # compute time by the expected amount.
        optimizers.append(
            optim.Adam(list(self.q1.parameters()) + list(self.q2.parameters()),
                       lr=self._qf_lr))
        # optimizers.append(optim.Adam(self.vf.parameters(), lr=self._vf_lr))
        optimizers.append(
            optim.Adam(self.policy.parameters(), lr=self._policy_lr))
        return optimizers

    def forward(self, *args, **kwargs):
        return None
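
A minimal usage sketch for the LightningModule above, assuming a Config object that provides the fields referenced in __init__ (lr, tau, layer_size, max_epochs, device, ...); the Trainer arguments are illustrative and not taken from the original code:

import pytorch_lightning as pl

# `config` is assumed to be a Config instance carrying the fields used in __init__.
model = SAC(config)
trainer = pl.Trainer(
    max_epochs=config.max_epochs,                # each epoch runs epoch_length optimizer steps via RLDataset
    gpus=1 if config.device != "cpu" else 0,     # older Lightning API, matching the hooks used above
)
trainer.fit(model)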