コード例 #1
0
class SACAgent:
    """Soft Actor-Critic agent (2018 variant) with a separate state-value network.

    Holds a value network V with a soft-updated target copy, twin soft Q
    networks and a Gaussian policy whose samples are squashed with ``tanh``.
    The policy and the target value network are updated once every
    ``delay_step`` calls to :meth:`update`.
    """

    def __init__(self, env, gamma, tau, v_lr, q_lr, policy_lr, buffer_maxlen):
        """Build networks, optimizers and the replay buffer.

        Args:
            env: Gym-style environment; observation/action sizes are read from
                ``observation_space.shape[0]`` and ``action_space.shape[0]``.
            gamma: Discount factor.
            tau: Soft-update coefficient for the target value network.
            v_lr: Adam learning rate of the value network.
            q_lr: Adam learning rate of both Q networks.
            policy_lr: Adam learning rate of the policy network.
            buffer_maxlen: Capacity passed to ``BasicBuffer``.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.env = env
        self.action_range = [env.action_space.low, env.action_space.high]
        self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.update_step = 0
        self.delay_step = 2

        # initialize networks
        self.value_net = ValueNetwork(self.obs_dim, 1).to(self.device)
        self.target_value_net = ValueNetwork(self.obs_dim, 1).to(self.device)
        self.q_net1 = SoftQNetwork(self.obs_dim,
                                   self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.obs_dim,
                                   self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.obs_dim,
                                        self.action_dim).to(self.device)

        # start the target value network as an exact copy of the value network
        for target_param, param in zip(self.target_value_net.parameters(),
                                       self.value_net.parameters()):
            target_param.data.copy_(param)

        # initialize optimizers
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=v_lr)
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(),
                                           lr=policy_lr)

        self.replay_buffer = BasicBuffer(buffer_maxlen)

    def get_action(self, state):
        """Sample a stochastic action for one observation.

        The Gaussian sample is squashed with ``tanh`` into [-1, 1] and then
        mapped to the environment bounds by :meth:`rescale_action`.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        # acting never needs gradients
        with torch.no_grad():
            mean, log_std = self.policy_net.forward(state)
            std = log_std.exp()
            z = Normal(mean, std).sample()
            action = torch.tanh(z)
        action = action.cpu().squeeze(0).numpy()

        return self.rescale_action(action)

    def rescale_action(self, action):
        """Affinely map ``action`` from [-1, 1] to [low, high] of ``action_range``."""
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
            (self.action_range[1] + self.action_range[0]) / 2.0

    def update(self, batch_size):
        """Run one gradient step on V and both Q networks.

        Every ``delay_step`` calls, also update the policy and soft-update
        the target value network.

        Args:
            batch_size: Number of transitions sampled from the replay buffer.
        """
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size)
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        dones = dones.view(dones.size(0), -1)

        # Value loss: V(s_t) regresses onto E_{a~pi(.|s_t)}[min Q(s_t, a) -
        # log pi(a|s_t)]. The actions must be sampled at the *current* states;
        # the previous version sampled at next_states, which mismatches V(s_t).
        with torch.no_grad():
            sampled_actions, sampled_log_pi = self.policy_net.sample(states)
            min_q_sampled = torch.min(self.q_net1(states, sampled_actions),
                                      self.q_net2(states, sampled_actions))
            v_target = min_q_sampled - sampled_log_pi
            # bootstrap value of s_{t+1} from the slow-moving target network
            next_v = self.target_value_net(next_states)
        curr_v = self.value_net.forward(states)
        v_loss = F.mse_loss(curr_v, v_target)

        # Q loss: r + gamma * (1 - done) * V_target(s_{t+1})
        curr_q1 = self.q_net1.forward(states, actions)
        curr_q2 = self.q_net2.forward(states, actions)
        expected_q = rewards + (1 - dones) * self.gamma * next_v
        q1_loss = F.mse_loss(curr_q1, expected_q)
        q2_loss = F.mse_loss(curr_q2, expected_q)

        # update value network and q networks
        self.value_optimizer.zero_grad()
        v_loss.backward()
        self.value_optimizer.step()

        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # delayed update for policy net and target value net
        if self.update_step % self.delay_step == 0:
            new_actions, log_pi = self.policy_net.sample(states)
            min_q = torch.min(self.q_net1.forward(states, new_actions),
                              self.q_net2.forward(states, new_actions))
            policy_loss = (log_pi - min_q).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            # Polyak averaging: target <- tau * online + (1 - tau) * target
            with torch.no_grad():
                for target_param, param in zip(
                        self.target_value_net.parameters(),
                        self.value_net.parameters()):
                    target_param.data.copy_(self.tau * param.data +
                                            (1 - self.tau) * target_param.data)

        self.update_step += 1
コード例 #2
0
class OldSACAgent:
    """Soft Actor-Critic agent (value-network variant) driven by a YAML config.

    Holds the environment, a replay buffer, a Q network, a state-value network
    with a soft-updated target copy, and a policy network. ``train`` runs the
    interaction loop, logs to TensorBoard, and saves weights plus a reward plot
    in the run folder created from ``config_info``.
    """

    def __init__(self, env, render, config_info):
        """Create the run folder, networks, optimizers and replay buffer.

        Args:
            env: Gym-style environment using the 4-tuple ``step`` API.
            render: If truthy, ``env.render()`` is called at every step.
            config_info: Passed to ``create_run_folder``; its ``"config_param"``
                entry is given to ``load_training_parameters``.
        """
        self.env = env
        self.render = render
        self._reset_env()

        # Create run folder to store parameters, figures, and tensorboard logs
        self.path_runs = create_run_folder(config_info)

        # Extract training parameters from yaml config file
        param = load_training_parameters(config_info["config_param"])
        self.train_param = param["training"]

        # Define device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Device in use : {self.device}")

        # Define state and action dimension spaces
        state_dim = env.observation_space.shape[0]
        num_actions = env.action_space.shape[0]

        # Define models; the target value network starts as a copy of v_net
        hidden_size = param["model"]["hidden_size"]
        self.q_net = QNetwork(state_dim, num_actions, hidden_size).to(self.device)
        self.v_net = VNetwork(state_dim, hidden_size).to(self.device)
        self.target_v_net = VNetwork(state_dim, hidden_size).to(self.device)
        self.target_v_net.load_state_dict(self.v_net.state_dict())
        self.policy_net = PolicyNetwork(state_dim, num_actions, hidden_size).to(
            self.device
        )

        # Define loss criterion
        self.q_criterion = nn.MSELoss()
        self.v_criterion = nn.MSELoss()

        # Define optimizers (all networks share one learning rate)
        lr = float(param["optimizer"]["learning_rate"])
        self.q_opt = optim.Adam(self.q_net.parameters(), lr=lr)
        self.v_opt = optim.Adam(self.v_net.parameters(), lr=lr)
        self.policy_opt = optim.Adam(self.policy_net.parameters(), lr=lr)

        # Initialize replay buffer
        self.replay_buffer = ReplayBuffer(param["training"]["replay_size"])

        self.transition = namedtuple(
            "transition",
            field_names=["state", "action", "reward", "done", "next_state"],
        )

        # Useful variables. Float hyperparameters are coerced with float(), as
        # the learning rate already is, because YAML loaders can return values
        # written like ``5e-3`` as strings.
        self.batch_size = param["training"]["batch_size"]
        self.gamma = float(param["training"]["gamma"])
        self.tau = float(param["training"]["tau"])
        self.start_step = param["training"]["start_step"]
        self.max_timesteps = param["training"]["max_timesteps"]
        self.alpha = float(param["training"]["alpha"])

    def _reset_env(self):
        """Reset the environment and the per-episode counters."""
        self.state, self.done = self.env.reset(), False
        self.episode_reward = 0.0
        self.episode_step = 0

    @staticmethod
    def _record_episode_reward(episode_rewards, mean_rewards, episode_reward,
                               window=100):
        """Append ``episode_reward`` and its trailing-``window`` mean; return the mean.

        The reward is appended *before* averaging, so the mean includes the
        episode just finished and is never taken over an empty list.
        """
        episode_rewards.append(episode_reward)
        mean_reward = np.mean(episode_rewards[-window:])
        mean_rewards.append(mean_reward)
        return mean_reward

    def train(self):
        """Run the training loop until more than ``max_timesteps`` steps are taken.

        When training finishes, save the critic and actor weights to the run
        folder, plot the reward curves, and close the writer and environment.
        """
        total_timestep = 0
        all_episode_rewards = []
        all_mean_rewards = []
        update = 0

        # Create tensorboard writer
        writer = SummaryWriter(log_dir=self.path_runs, comment="-sac")

        for episode in itertools.count(1, 1):
            self._reset_env()

            while not self.done:
                # trick to improve exploration at the start of training
                if self.start_step > total_timestep:
                    action = self.env.action_space.sample()  # Sample random action
                else:
                    action = self.policy_net.get_action(
                        self.state, self.device
                    )  # Sample action from policy

                # Learn only once the buffer holds more than one batch
                if len(self.replay_buffer) > self.batch_size:
                    batch = self.replay_buffer.sample_buffer(self.batch_size)

                    # Update parameters of all the networks
                    q_loss, v_loss, policy_loss = self.train_on_batch(batch)
                    writer.add_scalar("loss/q", q_loss, update)
                    writer.add_scalar("loss/v", v_loss, update)
                    writer.add_scalar("loss/policy", policy_loss, update)
                    update += 1

                if self.render:
                    self.env.render()

                # Perform one step in the environment
                next_state, reward, self.done, _ = self.env.step(action)
                total_timestep += 1
                self.episode_step += 1
                self.episode_reward += reward

                # Store the transition in the replay buffer
                new_transition = self.transition(
                    self.state, action, reward, self.done, next_state
                )
                self.replay_buffer.store_transition(new_transition)

                self.state = next_state

            if total_timestep > self.max_timesteps:
                break

            mean_reward = self._record_episode_reward(
                all_episode_rewards, all_mean_rewards, self.episode_reward
            )

            print(
                "Episode n°{} ; total timestep [{}/{}] ; episode steps {} ; "
                "reward {} ; mean reward {}".format(
                    episode,
                    total_timestep,
                    self.max_timesteps,
                    self.episode_step,
                    round(self.episode_reward, 2),
                    round(mean_reward, 2),
                )
            )

            writer.add_scalar("reward", self.episode_reward, episode)
            writer.add_scalar("mean reward", mean_reward, episode)

        # Save networks' weights
        path_critic = os.path.join(self.path_runs, "critic.pth")
        path_actor = os.path.join(self.path_runs, "actor.pth")
        torch.save(self.q_net.state_dict(), path_critic)
        torch.save(self.policy_net.state_dict(), path_actor)

        # Plot reward
        self.plot_reward(all_episode_rewards, all_mean_rewards)

        # Close all
        writer.close()
        self.env.close()

    def train_on_batch(self, batch_samples):
        """Do one optimisation step of Q, V and the policy on a sampled batch.

        Args:
            batch_samples: ``(states, actions, rewards, dones, next_states)``
                as returned by ``ReplayBuffer.sample_buffer``.

        Returns:
            ``(q_loss, v_loss, policy_loss)`` as Python floats.
        """
        (
            state_batch,
            action_batch,
            reward_batch,
            done_int_batch,
            next_state_batch,
        ) = batch_samples

        # Convert to float32 tensors on the device: float64 arrays would clash
        # with float32 network weights, and boolean done flags do not support
        # the ``1 - done`` subtraction below.
        def to_tensor(array):
            return torch.as_tensor(array, dtype=torch.float32, device=self.device)

        state_batch = to_tensor(state_batch)
        next_state_batch = to_tensor(next_state_batch)
        action_batch = to_tensor(action_batch)
        reward_batch = to_tensor(reward_batch).unsqueeze(1)
        done_int_batch = to_tensor(done_int_batch).unsqueeze(1)

        q_value, _ = self.q_net(state_batch, action_batch)
        value = self.v_net(state_batch)
        pi, log_pi = self.policy_net.sample(state_batch)

        ### Q target: r + gamma * (1 - done) * V_target(s')
        target_next_value = self.target_v_net(next_state_batch)
        next_q_value = (
            reward_batch + (1 - done_int_batch) * self.gamma * target_next_value
        )
        q_loss = self.q_criterion(q_value, next_q_value.detach())

        ### V target: Q(s, a~pi) - alpha * log pi(a|s)
        q_pi, _ = self.q_net(state_batch, pi)
        next_value = q_pi - self.alpha * log_pi
        v_loss = self.v_criterion(value, next_value.detach())

        ### Policy: likelihood-ratio gradient of E[alpha * log pi - Q], using
        ### V(s) as a baseline. With alpha == 1 this equals the former loss.
        policy_loss = (
            log_pi * (self.alpha * log_pi - q_pi + value).detach()
        ).mean()

        # Losses and optimizers
        self.q_opt.zero_grad()
        q_loss.backward()
        self.q_opt.step()

        self.v_opt.zero_grad()
        v_loss.backward()
        self.v_opt.step()

        self.policy_opt.zero_grad()
        policy_loss.backward()
        self.policy_opt.step()

        soft_update(self.target_v_net, self.v_net, self.tau)

        return q_loss.item(), v_loss.item(), policy_loss.item()

    def plot_reward(self, data, mean_data):
        """Plot per-episode and running-mean rewards, save to the run folder, and show."""
        plt.plot(data, label="reward")
        plt.plot(mean_data, label="mean reward")
        plt.xlabel("Episode")
        plt.ylabel("Reward")
        plt.title(f"Reward evolution for {self.env.unwrapped.spec.id} Gym environment")
        plt.tight_layout()
        plt.legend()

        path_fig = os.path.join(self.path_runs, "figure.png")
        plt.savefig(path_fig)
        print(f"Figure saved to {path_fig}")

        plt.show()
コード例 #3
0
class SACAgent:
    """Soft Actor-Critic agent with twin target Q networks and learned temperature.

    The policy and the target Q networks are updated once every
    ``delay_step`` calls to :meth:`update`. The entropy temperature ``alpha``
    is tuned on every call towards ``target_entropy``.
    """

    def __init__(self, env, gamma, tau, alpha, q_lr, policy_lr, a_lr,
                 buffer_maxlen, action_range=None):
        """Build networks, optimizers, temperature parameter and replay buffer.

        Args:
            env: Environment exposing ``state_dim`` and ``action_dim``.
            gamma: Discount factor.
            tau: Soft-update coefficient for the target Q networks.
            alpha: Initial entropy temperature (expected to be > 0).
            q_lr: Adam learning rate of both Q networks.
            policy_lr: Adam learning rate of the policy network.
            a_lr: Adam learning rate of ``log_alpha``.
            buffer_maxlen: Capacity passed to ``BasicBuffer``.
            action_range: Optional ``[low, high]`` action bounds; defaults to
                the previously hard-coded ``[0, 250]``.
        """
        import math

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.env = env
        self.action_range = [0, 250] if action_range is None else list(action_range)
        self.obs_dim = env.state_dim
        self.action_dim = env.action_dim

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.update_step = 0
        self.delay_step = 2

        # initialize networks
        self.q_net1 = SoftQNetwork(self.obs_dim,
                                   self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.obs_dim,
                                   self.action_dim).to(self.device)
        self.target_q_net1 = SoftQNetwork(self.obs_dim,
                                          self.action_dim).to(self.device)
        self.target_q_net2 = SoftQNetwork(self.obs_dim,
                                          self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.obs_dim,
                                        self.action_dim).to(self.device)

        # start both target Q networks as exact copies of the online ones
        for target_param, param in zip(self.target_q_net1.parameters(),
                                       self.q_net1.parameters()):
            target_param.data.copy_(param)

        for target_param, param in zip(self.target_q_net2.parameters(),
                                       self.q_net2.parameters()):
            target_param.data.copy_(param)

        # initialize optimizers
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(),
                                           lr=policy_lr)

        # entropy temperature; target entropy is -action_dim
        self.alpha = alpha
        self.target_entropy = -torch.prod(
            torch.Tensor([self.action_dim, 1]).to(self.device)).item()
        # Start log_alpha at log(alpha) so the learned temperature continues
        # from the requested value instead of jumping to exp(0) == 1 after the
        # first update. A non-positive alpha has no logarithm; fall back to the
        # previous zero initialisation in that case.
        init_log_alpha = math.log(alpha) if alpha > 0 else 0.0
        self.log_alpha = torch.tensor([init_log_alpha], requires_grad=True,
                                      device=self.device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=a_lr)

        self.replay_buffer = BasicBuffer(buffer_maxlen)

    def get_action(self, state):
        """Sample a stochastic action for one observation, rescaled to ``action_range``."""
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        # acting never needs gradients
        with torch.no_grad():
            mean, log_std = self.policy_net.forward(state)
            std = log_std.exp()
            z = Normal(mean, std).sample()
            action = torch.tanh(z)
        action = action.cpu().squeeze(0).numpy()

        return self.rescale_action(action)

    def rescale_action(self, action):
        """Affinely map ``action`` from [-1, 1] to [low, high] of ``action_range``."""
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
            (self.action_range[1] + self.action_range[0]) / 2.0

    def update(self, batch_size):
        """Run one SAC update step.

        Both Q networks and the temperature are updated on every call. The
        policy and the target Q networks are updated every ``delay_step`` calls.

        Args:
            batch_size: Number of transitions sampled from the replay buffer.
        """
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size)
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        dones = dones.view(dones.size(0), -1)

        # soft Bellman target from the target Q networks (no gradient needed)
        with torch.no_grad():
            next_actions, next_log_pi = self.policy_net.sample(next_states)
            next_q1 = self.target_q_net1(next_states, next_actions)
            next_q2 = self.target_q_net2(next_states, next_actions)
            next_q_target = torch.min(next_q1, next_q2) - self.alpha * next_log_pi
            expected_q = rewards + (1 - dones) * self.gamma * next_q_target

        # q loss
        curr_q1 = self.q_net1.forward(states, actions)
        curr_q2 = self.q_net2.forward(states, actions)
        q1_loss = F.mse_loss(curr_q1, expected_q)
        q2_loss = F.mse_loss(curr_q2, expected_q)

        # update q networks
        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # fresh policy sample; its log-probs also drive the temperature update
        new_actions, log_pi = self.policy_net.sample(states)

        # delayed update for policy network and target q networks
        if self.update_step % self.delay_step == 0:
            min_q = torch.min(self.q_net1.forward(states, new_actions),
                              self.q_net2.forward(states, new_actions))
            policy_loss = (self.alpha * log_pi - min_q).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            # Polyak averaging: target <- tau * online + (1 - tau) * target
            with torch.no_grad():
                for target_param, param in zip(self.target_q_net1.parameters(),
                                               self.q_net1.parameters()):
                    target_param.data.copy_(self.tau * param.data +
                                            (1 - self.tau) * target_param.data)

                for target_param, param in zip(self.target_q_net2.parameters(),
                                               self.q_net2.parameters()):
                    target_param.data.copy_(self.tau * param.data +
                                            (1 - self.tau) * target_param.data)

        # update temperature
        alpha_loss = (self.log_alpha *
                      (-log_pi - self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        # Detach so the next policy loss does not backpropagate into log_alpha.
        self.alpha = self.log_alpha.exp().detach()

        self.update_step += 1
コード例 #4
0
class SACAgent():
    """Soft Actor-Critic agent with twin target Q networks and reward scaling.

    There is no explicit entropy temperature. Rewards are multiplied by
    ``reward_scale`` and the entropy term enters with weight 1.
    """

    def __init__(self, env: object, gamma: float, tau: float,
                 buffer_maxlen: int, critic_lr: float, actor_lr: float,
                 reward_scale: int):
        """Build networks, optimizers and the replay buffer.

        Args:
            env: Gym-style environment with ``Box`` observation/action spaces.
            gamma: Discount factor.
            tau: Soft-update coefficient for the target Q networks.
            buffer_maxlen: Capacity passed to ``BasicBuffer``.
            critic_lr: Adam learning rate of both Q networks.
            actor_lr: Adam learning rate of the policy network.
            reward_scale: Factor applied to sampled rewards in :meth:`update`.
        """
        # Selecting the device to use: CUDA (GPU) if available, otherwise CPU
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.env = env
        # Per-dimension min and max action values of this environment
        self.action_range = [
            self.env.action_space.low, self.env.action_space.high
        ]
        # Dimensions of the state and the action
        self.obs_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.shape[0]

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.critic_lr = critic_lr
        self.actor_lr = actor_lr
        self.buffer_maxlen = buffer_maxlen
        self.reward_scale = reward_scale

        # Scale and bias mapping the policy's [-1, 1] output to each action
        # dimension's [low, high] range (environments differ in their bounds)
        self.scale = (self.action_range[1] - self.action_range[0]) / 2.0
        self.bias = (self.action_range[1] + self.action_range[0]) / 2.0

        # initialize networks
        self.q_net1 = SoftQNetwork(self.obs_dim,
                                   self.action_dim).to(self.device)
        self.target_q_net1 = SoftQNetwork(self.obs_dim,
                                          self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.obs_dim,
                                   self.action_dim).to(self.device)
        self.target_q_net2 = SoftQNetwork(self.obs_dim,
                                          self.action_dim).to(self.device)
        self.policy = PolicyNetwork(self.obs_dim,
                                    self.action_dim).to(self.device)

        # copy weight parameters to the target Q networks
        for target_param, param in zip(self.target_q_net1.parameters(),
                                       self.q_net1.parameters()):
            target_param.data.copy_(param)

        for target_param, param in zip(self.target_q_net2.parameters(),
                                       self.q_net2.parameters()):
            target_param.data.copy_(param)

        # initialize optimizers
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(),
                                       lr=self.critic_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(),
                                       lr=self.critic_lr)
        self.policy_optimizer = optim.Adam(self.policy.parameters(),
                                           lr=self.actor_lr)

        # Create a replay buffer
        self.replay_buffer = BasicBuffer(self.buffer_maxlen)

    def update(self, batch_size: int):
        """Run one update of both Q networks, the policy and the target networks.

        Args:
            batch_size: Number of transitions sampled from the replay buffer.
        """
        # Sampling experiences from the replay buffer
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size)

        # Convert numpy arrays of experience tuples into pytorch tensors
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = self.reward_scale * torch.FloatTensor(rewards).to(
            self.device)  # in SAC we do reward scaling for the sampled rewards
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        dones = dones.view(dones.size(0), -1)

        # Critic update -- see equation (6) of the SAC paper.
        # Sample actions for the next states (s_t+1) using the current policy
        next_actions, next_log_pi, _, _ = self.policy.sample(
            next_states, self.scale)
        next_actions = self.rescale_action(next_actions)

        # Q(s_t+1, a_t+1) from both target Q networks; keep the minimum
        next_q1 = self.target_q_net1(next_states, next_actions)
        next_q2 = self.target_q_net2(next_states, next_actions)
        min_q = torch.min(next_q1, next_q2)

        # Soft next value: min Q'(s_t+1, a_t+1) - log pi(a_t+1 | s_t+1)
        # (no explicit temperature; its role is played by reward_scale)
        next_q_target = (min_q - next_log_pi)

        # Q(s_t, a_t) using s_t and a_t from the replay buffer
        curr_q1 = self.q_net1.forward(states, actions)
        curr_q2 = self.q_net2.forward(states, actions)

        # Expected Q: r_t + gamma * (1 - done) * next_q_target
        expected_q = rewards + (1 - dones) * self.gamma * next_q_target

        # Loss between each Q network and the (detached) expected Q
        q1_loss = F.mse_loss(curr_q1, expected_q.detach())
        q2_loss = F.mse_loss(curr_q2, expected_q.detach())

        # Backpropagate the losses and update Q network parameters
        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # Policy update: sample new actions for the current states (s_t)
        new_actions, log_pi, _, _ = self.policy.sample(states, self.scale)
        new_actions = self.rescale_action(new_actions)

        # Q(s_t, a~pi) from both Q networks; keep the minimum
        new_q1 = self.q_net1.forward(states, new_actions)
        new_q2 = self.q_net2.forward(states, new_actions)
        min_q = torch.min(new_q1, new_q2)

        # Policy loss: log_pi - Q(s_t, a), cf. eq. (7) with temperature 1
        policy_loss = (log_pi - min_q).mean()

        # Backpropagate the loss and update policy network parameters
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Soft update of the target networks with rate tau
        for target_param, param in zip(self.target_q_net1.parameters(),
                                       self.q_net1.parameters()):
            target_param.data.copy_(self.tau * param +
                                    (1 - self.tau) * target_param)

        for target_param, param in zip(self.target_q_net2.parameters(),
                                       self.q_net2.parameters()):
            target_param.data.copy_(self.tau * param +
                                    (1 - self.tau) * target_param)

    def get_action(
            self, state: np.ndarray,
            stochastic: bool) -> Tuple[np.ndarray, torch.Tensor, torch.Tensor]:
        """Select an action for a single state.

        Args:
            state: Observation for one time step.
            stochastic: If True (training), sample ``tanh(z)`` with
                ``z ~ N(mean, std)``. If False (evaluation), return the
                deterministic ``tanh(mean)``.

        Returns:
            The action rescaled to the environment's bounds, and the policy's
            mean and standard-deviation tensors.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        # Get mean and sigma from the policy network
        mean, log_std = self.policy.forward(state)
        std = log_std.exp()

        if stochastic:
            z = Normal(mean, std).sample()
        else:
            # A Normal with zero scale is rejected by torch.distributions; the
            # noiseless action is simply the squashed mean.
            z = mean
        action = torch.tanh(z).cpu().detach().squeeze(0).numpy()

        # The policy outputs lie in [-1, 1]; map them to the env's bounds
        return self.rescale_action(action), mean, std

    def rescale_action(self, action: np.ndarray) -> np.ndarray:
        """Map an action from [-1, 1] to the environment's per-dimension bounds.

        Each dimension uses its own scale/bias, so environments whose action
        dimensions have different bounds are handled correctly. Accepts NumPy
        arrays (from :meth:`get_action`) and torch tensors (from :meth:`update`);
        for tensors, scale/bias are converted to the action's dtype and device.
        """
        scale, bias = self.scale, self.bias
        if isinstance(action, torch.Tensor):
            scale = torch.as_tensor(scale, dtype=action.dtype, device=action.device)
            bias = torch.as_tensor(bias, dtype=action.dtype, device=action.device)
        return action * scale + bias

    def _actor_path(self, WORKSPACE: str) -> str:
        """Return the policy checkpoint path inside ``WORKSPACE``."""
        import os
        # os.path.join tolerates WORKSPACE with or without a trailing separator
        return os.path.join(WORKSPACE, "policy_model5_Hop_.pth")

    def Actor_save(self, WORKSPACE: str):
        """Save the policy network's state dict into ``WORKSPACE``."""
        print("Save the torch model")
        savePath = self._actor_path(WORKSPACE)
        torch.save(self.policy.state_dict(), savePath)

    def Actor_load(self, WORKSPACE: str):
        """Rebuild the policy network and load its weights from ``WORKSPACE``."""
        print("load the torch model")
        savePath = self._actor_path(WORKSPACE)  # Best
        self.policy = PolicyNetwork(self.obs_dim,
                                    self.action_dim).to(self.device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host
        self.policy.load_state_dict(torch.load(savePath, map_location=self.device))
        # The old optimizer still points at the replaced network's parameters;
        # rebuild it so further training updates the loaded policy.
        self.policy_optimizer = optim.Adam(self.policy.parameters(),
                                           lr=self.actor_lr)
コード例 #5
0
ファイル: sac.py プロジェクト: arcosin/ANP_Robot_RL_Sim
class SACAgent:
    """Soft Actor-Critic agent (value-network variant) operating on RGB images.

    Image observations of shape ``img_size`` (H, W, C) are encoded by a
    convolutional ``FeatureExtractor``. The flattened features feed a value
    network with a soft-updated target, twin soft Q networks and a policy.
    """

    def __init__(self, env, gamma, tau, v_lr, q_lr, policy_lr, buffer_maxlen):
        """Build the feature extractor, networks, optimizers and replay buffer.

        Args:
            env: Gym-style environment with a ``Box`` action space.
            gamma: Discount factor.
            tau: Soft-update coefficient for the target value network.
            v_lr: Adam learning rate of the value network.
            q_lr: Adam learning rate of both Q networks.
            policy_lr: Adam learning rate of the policy network.
            buffer_maxlen: Capacity passed to ``BasicBuffer``.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.firsttime = 0

        self.env = env
        self.action_range = [env.action_space.low, env.action_space.high]
        self.action_dim = env.action_space.shape[0]

        self.conv_channels = 4
        self.kernel_size = (3, 3)

        self.img_size = (500, 500, 3)

        print("Diagnostics:")
        print(f"action_range: {self.action_range}")
        print(f"action_dim: {self.action_dim}")

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.update_step = 0
        self.delay_step = 2

        # initialize networks
        self.feature_net = FeatureExtractor(self.img_size[2],
                                            self.conv_channels,
                                            self.kernel_size).to(self.device)
        print("Feature net init'd successfully")

        # Flattened feature size for one image, used as every head's input size
        input_dim = self.feature_net.get_output_size(self.img_size)
        self.input_size = input_dim[0] * input_dim[1] * input_dim[2]
        print(f"input_size: {self.input_size}")

        self.value_net = ValueNetwork(self.input_size, 1).to(self.device)
        self.target_value_net = ValueNetwork(self.input_size,
                                             1).to(self.device)
        self.q_net1 = SoftQNetwork(self.input_size,
                                   self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.input_size,
                                   self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.input_size,
                                        self.action_dim).to(self.device)

        print("Finished initing all nets")

        # start the target value network as an exact copy of the value network
        for target_param, param in zip(self.target_value_net.parameters(),
                                       self.value_net.parameters()):
            target_param.data.copy_(param)

        print("Finished copying targets")

        # initialize optimizers
        # NOTE(review): feature_net has no optimizer, so the encoder is never
        # trained, although gradients accumulate in its .grad -- confirm intent.
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=v_lr)
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(),
                                           lr=policy_lr)

        print("Finished initing optimizers")

        self.replay_buffer = BasicBuffer(buffer_maxlen)
        print("End of init")

    def get_action(self, state):
        """Sample an action for one HxWxC image, or return None on a shape mismatch."""
        if state.shape != self.img_size:
            print(
                f"Invalid size, expected shape {self.img_size}, got {state.shape}"
            )
            return None

        # HWC -> NCHW with a batch of one; acting never needs gradients
        inp = torch.from_numpy(state).float().permute(2, 0, 1).unsqueeze(0).to(
            self.device)
        with torch.no_grad():
            features = self.feature_net(inp)
            features = features.view(-1, self.input_size)

            mean, log_std = self.policy_net.forward(features)
            std = log_std.exp()
            z = Normal(mean, std).sample()
            action = torch.tanh(z)
        action = action.cpu().squeeze(0).numpy()

        return self.rescale_action(action)

    def rescale_action(self, action):
        """Affinely map ``action`` from [-1, 1] to [low, high] of ``action_range``."""
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
            (self.action_range[1] + self.action_range[0]) / 2.0

    def update(self, batch_size):
        """Run one gradient step on V and both Q networks.

        Every ``delay_step`` calls, also update the policy and soft-update
        the target value network.

        Args:
            batch_size: Number of transitions sampled from the replay buffer.
        """
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size)

        # states and next states are lists of ndarrays, np.stack converts them to
        # ndarrays of shape (batch_size, height, width, num_channels)
        states = np.stack(states)
        next_states = np.stack(next_states)

        states = torch.FloatTensor(states).permute(0, 3, 1, 2).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).permute(0, 3, 1,
                                                             2).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        dones = dones.view(dones.size(0), -1)

        # Encode current images. Flatten with the real batch dimension instead
        # of a hard-coded 64, so any batch_size works.
        features = self.feature_net(states)
        features = features.reshape(features.size(0), self.input_size)

        with torch.no_grad():
            # next-state features only feed the detached bootstrap target
            next_features = self.feature_net(next_states)
            next_features = next_features.reshape(next_features.size(0),
                                                  self.input_size)
            next_v = self.target_value_net(next_features)

            # Value target: E_{a~pi(.|s_t)}[min Q(s_t, a) - log pi(a|s_t)],
            # sampled at the *current* states to match V(s_t).
            sampled_actions, sampled_log_pi = self.policy_net.sample(features)
            v_target = torch.min(
                self.q_net1(features, sampled_actions),
                self.q_net2(features, sampled_actions)) - sampled_log_pi

        curr_v = self.value_net.forward(features)
        v_loss = F.mse_loss(curr_v, v_target)

        # q loss: r + gamma * (1 - done) * V_target(s_{t+1})
        expected_q = rewards + (1 - dones) * self.gamma * next_v
        curr_q1 = self.q_net1.forward(features, actions)
        curr_q2 = self.q_net2.forward(features, actions)
        q1_loss = F.mse_loss(curr_q1, expected_q)
        q2_loss = F.mse_loss(curr_q2, expected_q)

        # update value and q networks; the graph through feature_net is shared
        # by all losses, hence retain_graph=True
        self.value_optimizer.zero_grad()
        v_loss.backward(retain_graph=True)
        self.value_optimizer.step()

        self.q1_optimizer.zero_grad()
        q1_loss.backward(retain_graph=True)
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward(retain_graph=True)
        self.q2_optimizer.step()

        # delayed update for policy network and target value network
        if self.update_step % self.delay_step == 0:
            new_actions, log_pi = self.policy_net.sample(features)
            min_q = torch.min(self.q_net1.forward(features, new_actions),
                              self.q_net2.forward(features, new_actions))
            policy_loss = (log_pi - min_q).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward(retain_graph=True)
            self.policy_optimizer.step()

            # Polyak averaging: target <- tau * online + (1 - tau) * target
            with torch.no_grad():
                for target_param, param in zip(
                        self.target_value_net.parameters(),
                        self.value_net.parameters()):
                    target_param.data.copy_(self.tau * param.data +
                                            (1 - self.tau) * target_param.data)

        self.update_step += 1