Example #1
kwargs = {
    "state_dim": state_dim,
    "action_dim": action_dim,
    "max_action": max_action,
    "discount": args.discount,
    "tau": args.tau,
    "policy": args.policy
}

# Target policy smoothing is scaled wrt the action scale
kwargs["policy_noise"] = args.policy_noise * max_action
kwargs["noise_clip"] = args.noise_clip * max_action
kwargs["policy_freq"] = args.policy_freq
policy = TD3.TD3(**kwargs)

replay_buffer = ReplayBuffer(state_dim, action_dim, max_size=int(1e5))

# Evaluate untrained policy
evaluations = [eval_policy(policy, args.env_name, args.seed)]

state, done = env.reset(), False
episode_reward = 0
episode_timesteps = 0
episode_num = 0

for t in range(int(args.max_timesteps)):

    episode_timesteps += 1

    # Select action randomly or according to policy
    if t < args.start_timesteps:
        action = env.action_space.sample()
    else:
        action = (
            policy.select_action(np.array(state))
            + np.random.normal(0, max_action * args.expl_noise, size=action_dim)
        ).clip(-max_action, max_action)
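None of the snippets on this page include the ReplayBuffer class they construct. Below is a minimal sketch of the interface this training loop assumes: a fixed-capacity NumPy ring buffer with add(state, action, next_state, reward, done) and sample(batch_size) returning torch tensors. The field names and storage layout are assumptions rather than any of these projects' actual code, and note that the Agent class further down uses a different store_transition/sample_buffer interface.

import numpy as np
import torch


class ReplayBuffer:
    """Minimal fixed-capacity transition buffer (illustrative sketch)."""

    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        self.state = np.zeros((max_size, state_dim))
        self.action = np.zeros((max_size, action_dim))
        self.next_state = np.zeros((max_size, state_dim))
        self.reward = np.zeros((max_size, 1))
        self.not_done = np.zeros((max_size, 1))

    def add(self, state, action, next_state, reward, done):
        # overwrite the oldest transition once the buffer is full
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1.0 - done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        idx = np.random.randint(0, self.size, size=batch_size)
        batch = (self.state[idx], self.action[idx], self.next_state[idx],
                 self.reward[idx], self.not_done[idx])
        return tuple(torch.as_tensor(x, dtype=torch.float32) for x in batch)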
Example #2
    def start_training(self, env, load=False, der_activated=False):
        parser = argparse.ArgumentParser()
        parser.add_argument("--policy", default="TD3")  # Policy name (TD3, DDPG or OurDDPG)
        parser.add_argument("--env",
                            default="AlphaWorm")  # OpenAI gym environment name (not used to start env in AlphaWorm)
        parser.add_argument("--seed", default=0, type=int)  # Sets Gym, PyTorch and Numpy seeds
        parser.add_argument("--start_timesteps", default=1e6, type=int)  # Time steps initial random policy is used
        parser.add_argument("--eval_freq", default=5e3, type=int)  # How often (time steps) we evaluate
        parser.add_argument("--max_timesteps", default=1e9, type=int)  # Max time steps to run environment
        parser.add_argument("--max_env_episode_steps", default=1e3, type=int)  # Max env steps
        parser.add_argument("--expl_noise", default=0.1)  # Std of Gaussian exploration noise
        parser.add_argument("--random_policy_ratio",
                            default=1)  # Mix of random vs. policy exploration: 1 = as many random steps as policy steps, 2 = twice as many policy steps as random, ...
        parser.add_argument("--batch_size", default=256, type=int)  # Batch size for both actor and critic
        parser.add_argument("--discount", default=0.99)  # Discount factor
        parser.add_argument("--tau", default=0.005)  # Target network update rate
        parser.add_argument("--policy_noise", default=0.2)  # Noise added to target policy during critic update
        parser.add_argument("--noise_clip", default=0.5)  # Range to clip target policy noise
        parser.add_argument("--policy_freq", default=2, type=int)  # Frequency of delayed policy updates
        parser.add_argument("--save_model", default=True, action="store_true")  # Save model and optimizer parameters
        if load:
            parser.add_argument("--load_model",
                                default="default")  # Model load file name, "" doesn't load, "default" uses file_name
        else:
            parser.add_argument("--load_model",
                                default="")  # Model load file name, "" doesn't load, "default" uses file_name
        if der_activated:
            parser.add_argument("--load_replays",
                                default="buffers")  # Loads pre-trained replays to replay into the buffer "" doesn't load, "..." loads from the specified folder name
        else:
            parser.add_argument("--load_replays",
                                default="")  # Loads pre-trained replays to replay into the buffer "" doesn't load, "..." loads from the specified folder name
        parser.add_argument("--random_policy", default=False)  # Activate random policy

        args = parser.parse_args()

        file_name = f"{args.policy}_{args.env}_{args.seed}"
        print("---------------------------------------")
        print(f"{datetime.now()} \t Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}")
        print("---------------------------------------")

        if not os.path.exists("./results"):
            os.makedirs("./results")

        if args.save_model and not os.path.exists("./models"):
            os.makedirs("./models")

        if not os.path.exists("./buffers"):
            os.makedirs("./buffers")

        # Set seeds
        # env.seed(args.seed)
        env.action_space.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        max_action = float(env.action_space.high[0])

        kwargs = {
            "state_dim": state_dim,
            "action_dim": action_dim,
            "max_action": max_action,
            "discount": args.discount,
            "tau": args.tau,
        }

        # Initialize policy
        if args.policy == "TD3":
            # Target policy smoothing is scaled wrt the action scale
            kwargs["policy_noise"] = args.policy_noise * max_action
            kwargs["noise_clip"] = args.noise_clip * max_action
            kwargs["policy_freq"] = args.policy_freq
            policy = TD3(**kwargs)

        if args.load_model != "":
            policy_file = file_name if args.load_model == "default" else args.load_model
            policy.load(f"./models/{policy_file}")

        replay_buffer = ReplayBuffer(state_dim, action_dim)
        best_buffer = ReplayBuffer(state_dim, action_dim)
        der_buffer = DynamicExperienceReplay(state_dim, action_dim)

        if args.load_replays != "":
            batch = der_buffer.load(args.load_replays, True)
            if batch is not None:
                policy.train(batch, args.batch_size)
            else:
                print("No buffer batch loaded")

        # Evaluate untrained policy
        evaluations = [self.eval_policy(policy, env, args.seed)]

        state, done = env.reset(), False
        episode_reward = 0
        episode_timesteps = 0
        episode_num = 0

        for t in range(int(args.max_timesteps)):
            episode_timesteps += 1
            if args.random_policy:
                # Select action randomly or according to policy
                if t % ((args.random_policy_ratio + 1) * args.start_timesteps) < args.start_timesteps:
                    action = env.action_space.sample()
                else:
                    action = (
                            policy.select_action(np.array(state))
                            + np.random.normal(0, max_action * args.expl_noise, size=action_dim)
                    ).clip(-max_action, max_action)
            else:
                if t < args.start_timesteps:
                    action = env.action_space.sample()
                else:
                    action = (
                            policy.select_action(np.array(state))
                            + np.random.normal(0, max_action * args.expl_noise, size=action_dim)
                    ).clip(-max_action, max_action)

            # Perform action
            action = np.array(action).reshape((1, 9))  # hard-coded to this environment's 9-dimensional action space
            next_state, reward, done, _ = env.step(action)
            done = True if episode_timesteps % args.max_env_episode_steps == 0 else False
            done_bool = float(done) if episode_timesteps < args.max_env_episode_steps else 0

            # Store data in replay buffer
            replay_buffer.add(state, action, next_state, reward, done_bool)
            best_buffer.add(state, action, next_state, reward, done_bool)

            # Store buffer
            if done:
                der_buffer.add(best_buffer)
                best_buffer = ReplayBuffer(state_dim, action_dim)

            state = next_state
            episode_reward += reward

            # Train agent after collecting sufficient data
            if t >= args.start_timesteps:
                policy.train(replay_buffer, args.batch_size)

            if done:
                # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True
                print(
                    f"{datetime.now()} \t Total T: {t + 1} Episode Num: {episode_num + 1} Episode T: {episode_timesteps} Reward: {episode_reward}")
                # Reset environment
                state, done = env.reset(), False
                episode_reward = 0
                episode_timesteps = 0
                episode_num += 1

            # Evaluate episode
            if (t + 1) % args.eval_freq == 0:
                evaluations.append(self.eval_policy(policy, env, args.seed))
                np.save(f"./results/{file_name}", evaluations)
                if args.save_model: policy.save(f"./models/{file_name}")
                if args.load_replays != "":
                    batch = der_buffer.load(args.load_replays, True)
                    if batch is not None:
                        policy.train(batch, args.batch_size)
                    else:
                        print("No buffer batch loaded")

            if (t + 1) % (args.max_env_episode_steps * 100) == 0:
                der_buffer.save()
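Each training script above calls eval_policy (or self.eval_policy) to score the current policy, but the helper itself is never shown. The sketch below follows the reference TD3 training script and assumes the old-style Gym API (env.seed, reset() returning the state, step() returning a 4-tuple) used throughout these loops; the class-based examples pass the live environment and an optional render flag instead of an environment name, so treat this as an illustration rather than either project's exact helper.

import gym
import numpy as np


def eval_policy(policy, env_name, seed, eval_episodes=10):
    # run the deterministic policy for a few episodes and report the mean return
    eval_env = gym.make(env_name)
    eval_env.seed(seed + 100)  # evaluate on a seed different from the training seed
    avg_reward = 0.0
    for _ in range(eval_episodes):
        state, done = eval_env.reset(), False
        while not done:
            action = policy.select_action(np.array(state))
            state, reward, done, _ = eval_env.step(action)
            avg_reward += reward
    return avg_reward / eval_episodes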
Example #3
class Agent:
    def __init__(self, env, alpha, beta, hidden_dims, tau,
                 batch_size, gamma, d, warmup, max_size, c,
                 sigma, one_device, log_dir, checkpoint_dir,
                 img_input, in_channels, order, depth, multiplier,
                 action_embed_dim, hidden_dim, crop_dim, img_feature_dim):
        if img_input:
            input_dim = [in_channels * order, crop_dim, crop_dim]
        else:
            input_dim = env.observation_space.shape
            state_space = input_dim[0]
        n_actions = env.action_space.shape[0]

        # training params
        self.gamma = gamma
        self.tau = tau
        self.max_action = env.action_space.high[0]
        self.min_action = env.action_space.low[0]
        self.buffer = ReplayBuffer(max_size, input_dim, n_actions)
        self.batch_size = batch_size
        self.learn_step_counter = 0
        self.time_step = 0
        self.warmup = warmup
        self.n_actions = n_actions
        self.d = d
        self.c = c
        self.sigma = sigma
        self.img_input = img_input
        self.env = env

        # training device
        if one_device:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # logging/checkpointing
        self.writer = SummaryWriter(log_dir)
        self.checkpoint_dir = checkpoint_dir

        # networks & optimizers
        if img_input:
            self.actor = ImageActor(in_channels, n_actions, hidden_dim, self.max_action, order, depth, multiplier, crop_dim, 'actor').to(self.device)
            self.critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, crop_dim, img_feature_dim, 'critic_1').to(self.device)
            self.critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, crop_dim, img_feature_dim, 'critic_2').to(self.device)

            self.target_actor = ImageActor(in_channels, n_actions, hidden_dim, self.max_action, order, depth, multiplier, crop_dim, 'target_actor').to(self.device)
            self.target_critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, crop_dim, img_feature_dim, 'target_critic_1').to(self.device)
            self.target_critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, crop_dim, img_feature_dim, 'target_critic_2').to(self.device)

        # physics networks
        else:
            self.actor = Actor(state_space, hidden_dims, n_actions, self.max_action, 'actor').to(self.device)
            self.critic_1 = Critic(state_space, hidden_dims, n_actions, 'critic_1').to(self.device)
            self.critic_2 = Critic(state_space, hidden_dims, n_actions, 'critic_2').to(self.device)

            self.target_actor = Actor(state_space, hidden_dims, n_actions, self.max_action, 'target_actor').to(self.device)
            self.target_critic_1 = Critic(state_space, hidden_dims, n_actions, 'target_critic_1').to(self.device)
            self.target_critic_2 = Critic(state_space, hidden_dims, n_actions, 'target_critic_2').to(self.device)

        self.critic_optimizer = torch.optim.Adam(
                chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=beta)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=alpha)

        # copy weights
        self.update_network_parameters(tau=1)

    def _get_noise(self, clip=True):
        noise = torch.randn(self.n_actions, dtype=torch.float, device=self.device) * self.sigma
        if clip:
            noise = noise.clamp(-self.c, self.c)
        return noise

    def _clamp_action_bound(self, action):
        return action.clamp(self.min_action, self.max_action)

    def choose_action(self, observation):
        if self.time_step < self.warmup:
            mu = self.env.action_space.sample()
        else:
            state = torch.tensor(observation, dtype=torch.float).to(self.device)
            mu = self.actor(state) + self._get_noise(clip=False)
            mu = self._clamp_action_bound(mu).cpu().detach().numpy()
        self.time_step += 1
        return mu

    def remember(self, state, action, reward, state_, done):
        self.buffer.store_transition(state, action, reward, state_, done)

    def critic_step(self, state, action, reward, state_, done):
        with torch.no_grad():
            # get target actions w/ noise
            target_actions = self.target_actor(state_) + self._get_noise()
            target_actions = self._clamp_action_bound(target_actions)

            # target & online values
            q1_ = self.target_critic_1(state_, target_actions)
            q2_ = self.target_critic_2(state_, target_actions)

            # done mask
            q1_[done], q2_[done] = 0.0, 0.0

            q1_ = q1_.view(-1)
            q2_ = q2_.view(-1)        

            critic_value_ = torch.min(q1_, q2_)

            target = reward + self.gamma * critic_value_
            target = target.unsqueeze(1)

        q1 = self.critic_1(state, action)
        q2 = self.critic_2(state, action)

        critic_loss = F.mse_loss(q1, target) + F.mse_loss(q2, target)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.writer.add_scalar('Critic loss', critic_loss.item(), global_step=self.learn_step_counter)

    def actor_step(self, state):
        # calculate loss, update actor params
        actor_loss = -torch.mean(self.critic_1(state, self.actor(state)))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # update & log
        self.update_network_parameters()
        self.writer.add_scalar('Actor loss', actor_loss.item(), global_step=self.learn_step_counter)

    def learn(self):
        self.learn_step_counter += 1

        # if the buffer is not yet filled w/ enough samples
        if self.buffer.counter < self.batch_size:
            return

        # transitions
        state, action, reward, state_, done = self.buffer.sample_buffer(self.batch_size)
        reward = torch.tensor(reward, dtype=torch.float).to(self.device)
        done = torch.tensor(done).to(self.device)
        state = torch.tensor(state, dtype=torch.float).to(self.device)
        state_ = torch.tensor(state_, dtype=torch.float).to(self.device)
        action = torch.tensor(action, dtype=torch.float).to(self.device)

        self.critic_step(state, action, reward, state_, done)
        if self.learn_step_counter % self.d == 0:
            self.actor_step(state)

    def momentum_update(self, online_network, target_network, tau):
        for param_o, param_t in zip(online_network.parameters(), target_network.parameters()):
            param_t.data.copy_(tau * param_o.data + (1. - tau) * param_t.data)

    def update_network_parameters(self, tau=None):
        if tau is None:
            tau = self.tau
        self.momentum_update(self.critic_1, self.target_critic_1, tau)
        self.momentum_update(self.critic_2, self.target_critic_2, tau)
        self.momentum_update(self.actor, self.target_actor, tau)

    def add_scalar(self, tag, scalar_value, global_step=None):
        self.writer.add_scalar(tag, scalar_value, global_step=global_step)

    def save_networks(self):
        torch.save({
            'actor': self.actor.state_dict(),
            'target_actor': self.target_actor.state_dict(),
            'critic_1': self.critic_1.state_dict(),
            'critic_2': self.critic_2.state_dict(),
            'target_critic_1': self.target_critic_1.state_dict(),
            'target_critic_2': self.target_critic_2.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
        }, self.checkpoint_dir)

    def load_state_dicts(self):
        state_dict = torch.load(self.checkpoint_dir)
        self.actor.load_state_dict(state_dict['actor'])
        self.target_actor.load_state_dict(state_dict['target_actor'])
        self.critic_1.load_state_dict(state_dict['critic_1'])
        self.critic_2.load_state_dict(state_dict['critic_2'])
        self.target_critic_1.load_state_dict(state_dict['target_critic_1'])
        self.target_critic_2.load_state_dict(state_dict['target_critic_2'])
        self.critic_optimizer.load_state_dict(state_dict['critic_optimizer'])
        self.actor_optimizer.load_state_dict(state_dict['actor_optimizer'])
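A short usage sketch for this Agent class, assuming an old-style Gym environment; the environment name, hyperparameter values, and file paths below are illustrative choices, not taken from the project. The image-branch arguments are unused when img_input=False and are passed as None.

import gym

env = gym.make('Pendulum-v0')  # hypothetical continuous-control task
agent = Agent(env=env, alpha=3e-4, beta=3e-4, hidden_dims=(256, 256), tau=0.005,
              batch_size=100, gamma=0.99, d=2, warmup=1000, max_size=int(1e6),
              c=0.5, sigma=0.1, one_device=True, log_dir='runs/td3',
              checkpoint_dir='td3_checkpoint.pt', img_input=False,
              in_channels=None, order=None, depth=None, multiplier=None,
              action_embed_dim=None, hidden_dim=None, crop_dim=None,
              img_feature_dim=None)

for episode in range(1000):
    state, done, episode_reward = env.reset(), False, 0.0
    while not done:
        action = agent.choose_action(state)          # random until warmup, then actor + noise
        next_state, reward, done, _ = env.step(action)
        agent.remember(state, action, reward, next_state, done)
        agent.learn()                                # no-op until batch_size transitions are stored
        state, episode_reward = next_state, episode_reward + reward
    agent.add_scalar('Episode reward', episode_reward, global_step=episode)

agent.save_networks()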
Example #4
File: td3_agent.py Project: harwiltz/TD3
class TD3Agent:
    def __init__(self,
                 env,
                 env_kwargs=None,
                 pre_train_steps=1000,
                 max_replay_capacity=10000,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 tau=5e-3,
                 gamma=0.999,
                 batch_size=32,
                 value_delay=2):
        self._env_name = env
        if env_kwargs is None:
            self._env_fn = lambda: gym.make(self._env_name)
        else:
            self._env_fn = lambda: gym.make(self._env_name, **env_kwargs)

        env = self._env_fn()
        self._obs_dim = np.prod(env.observation_space.shape)
        self._act_dim = np.prod(env.action_space.shape)
        self._training = False
        self._total_steps = 0
        self._pre_train_steps = pre_train_steps
        self._batch_size = batch_size
        max_action = env.action_space.high
        min_action = env.action_space.low
        self._replay_buf = ReplayBuffer(self._obs_dim, self._act_dim,
                                        max_replay_capacity)
        self._td3 = TD3(self._obs_dim,
                        self._act_dim,
                        max_action,
                        min_action=min_action,
                        discount=gamma,
                        tau=tau,
                        policy_freq=value_delay)

    def rollout(self, num_rollouts=1, render=False):
        rewards = np.zeros(num_rollouts)
        for i in range(num_rollouts):
            env = self._env_fn()
            s = env.reset()
            episode_reward = 0
            done = False
            while not done:
                if render:
                    env.render()
                a = self.action(s)
                s, r, done, _ = env.step(a)
                episode_reward += r
            rewards[i] = episode_reward
        if render:
            env.close()
        return rewards

    def train(self, num_steps, win_condition=None, win_window=5, logger=None):
        env = self._env_fn()
        s = env.reset()
        episode_reward = 0
        num_episodes = 0
        if win_condition is not None:
            scores = [0. for _ in range(win_window)]
            idx = 0
        for i in range(num_steps):
            if self._training:
                a = self.action(s)
            else:
                a = env.action_space.sample()
            ns, r, d, _ = env.step(a)
            episode_reward += r
            self._replay_buf.add(s, a, ns, r, d)
            self._total_steps += 1
            if not self._training:
                if self._total_steps >= self._pre_train_steps:
                    self._training = True
            if self._training:
                losses = self.update()
                artifacts = {
                    'loss': losses,
                    'step': self._total_steps,
                    'episode': num_episodes,
                    'done': d,
                    'return': episode_reward,
                    'transition': {
                        'state': s,
                        'action': a,
                        'reward': r,
                        'next state': ns,
                        'done': d,
                    }
                }
                if logger is not None:
                    logger(self, artifacts)
            if d:
                s = env.reset()
                num_episodes += 1
                if win_condition is not None:
                    scores[idx] = episode_reward
                    idx = (idx + 1) % win_window
                    if (num_episodes >= win_window) and (np.mean(scores) >=
                                                         win_condition):
                        print("TD3 finished training: win condition reached")
                        break
                episode_reward = 0
            else:
                s = ns

    def update(self):
        self._td3.train(self._replay_buf, self._batch_size)
        return {}

    def action(self, x):
        return self._td3.select_action(x)
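A usage sketch for TD3Agent; the environment name, step budget, and win condition are illustrative assumptions, and the logger callback only demonstrates the (agent, artifacts) signature used inside train().

agent = TD3Agent('Pendulum-v0', pre_train_steps=1000, gamma=0.99, batch_size=256)

def log_fn(agent, artifacts):
    # called on every step once training has started, with the artifacts dict built in train()
    if artifacts['done']:
        print(f"episode {artifacts['episode']}: return {artifacts['return']:.1f}")

agent.train(num_steps=100_000, win_condition=-200, win_window=5, logger=log_fn)
print(agent.rollout(num_rollouts=5))  # array of per-episode rewards for the trained policy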
Example #5
    def start_training(self,
                       env,
                       render=False,
                       load=False,
                       experiment_name="td3"):
        neptune.init('sommerfe/aaml-project', neptune_api_token)

        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--policy", default="TD3")  # Policy name (TD3, DDPG or OurDDPG)
        parser.add_argument(
            "--env", default="Pendulum"
        )  # OpenAI gym environment name (not used to start env in AlphaWorm)
        parser.add_argument("--seed", default=0,
                            type=int)  # Sets Gym, PyTorch and Numpy seeds
        parser.add_argument(
            "--start_timesteps", default=1e6,
            type=int)  # Time steps initial random policy is used
        parser.add_argument("--eval_freq", default=5e3,
                            type=int)  # How often (time steps) we evaluate
        parser.add_argument("--max_timesteps", default=5e6,
                            type=int)  # Max time steps to run environment
        parser.add_argument("--max_env_episode_steps", default=1e3,
                            type=int)  # Max env steps
        parser.add_argument("--expl_noise",
                            default=0.1)  # Std of Gaussian exploration noise
        parser.add_argument(
            "--random_policy_ratio", default=1
        )  # Mix of random vs. policy exploration: 1 = as many random steps as policy steps, 2 = twice as many policy steps as random, ...
        parser.add_argument("--batch_size", default=256,
                            type=int)  # Batch size for both actor and critic
        parser.add_argument("--discount", default=0.99)  # Discount factor
        parser.add_argument("--tau",
                            default=0.005)  # Target network update rate
        parser.add_argument(
            "--policy_noise",
            default=0.2)  # Noise added to target policy during critic update
        parser.add_argument("--noise_clip",
                            default=0.5)  # Range to clip target policy noise
        parser.add_argument("--policy_freq", default=2,
                            type=int)  # Frequency of delayed policy updates
        parser.add_argument(
            "--save_model", default=True,
            action="store_true")  # Save model and optimizer parameters
        if load:
            parser.add_argument(
                "--load_model", default="default"
            )  # Model load file name, "" doesn't load, "default" uses file_name
        else:
            parser.add_argument(
                "--load_model", default=""
            )  # Model load file name, "" doesn't load, "default" uses file_name
        parser.add_argument("--random_policy",
                            default=False)  # Activate random policy

        args = parser.parse_args()

        file_name = f"{args.policy}_{args.env}_{args.seed}"
        print("---------------------------------------")
        print(
            f"{datetime.now()} \t Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}"
        )
        print("---------------------------------------")

        if not os.path.exists("./td3/results"):
            os.makedirs("./td3/results")

        if args.save_model and not os.path.exists("./td3/models"):
            os.makedirs("./td3/models")

        # Set seeds
        # env.seed(args.seed)
        env.action_space.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        max_action = float(env.action_space.high[0])

        kwargs = {
            "state_dim": state_dim,
            "action_dim": action_dim,
            "max_action": max_action,
            "discount": args.discount,
            "tau": args.tau,
        }

        # Initialize policy
        if args.policy == "TD3":
            # Target policy smoothing is scaled wrt the action scale
            kwargs["policy_noise"] = args.policy_noise * max_action
            kwargs["noise_clip"] = args.noise_clip * max_action
            kwargs["policy_freq"] = args.policy_freq
            policy = TD3(**kwargs)

        neptune.create_experiment(experiment_name, params=kwargs)
        neptune.log_text('cpu_count', str(psutil.cpu_count()))
        neptune.log_text('count_non_logical',
                         str(psutil.cpu_count(logical=False)))

        if args.load_model != "":
            policy_file = file_name if args.load_model == "default" else args.load_model
            policy.load(f"./models/{policy_file}")

        replay_buffer = ReplayBuffer(state_dim, action_dim)

        # Evaluate untrained policy
        evaluations = [self.eval_policy(policy, env, args.seed, render)]

        state, done = env.reset(), False
        episode_reward = 0
        episode_timesteps = 0
        episode_num = 0

        tic_training = time.perf_counter()
        for t in range(int(args.max_timesteps)):
            neptune.log_text('avg_cpu_load', str(psutil.getloadavg()))
            neptune.log_text('cpu_percent',
                             str(psutil.cpu_percent(interval=1, percpu=True)))
            tic_episode = time.perf_counter()
            episode_timesteps += 1

            #if t < args.start_timesteps:
            #    action = env.action_space.sample()
            #else:
            action = (policy.select_action(np.array(state)) + np.random.normal(
                0, max_action * args.expl_noise, size=action_dim)).clip(
                    -max_action, max_action)

            # Perform action
            next_state, reward, done, _ = env.step(action)
            if render:
                env.render()
            done = True if episode_timesteps % args.max_env_episode_steps == 0 else False
            done_bool = float(
                done) if episode_timesteps < args.max_env_episode_steps else 0

            # Store data in replay buffer
            replay_buffer.add(state, action, next_state, reward, done_bool)

            state = next_state
            episode_reward += reward

            # Train agent after collecting sufficient data
            if t >= args.start_timesteps:
                policy.train(replay_buffer, args.batch_size)

            if done:
                # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True
                print(
                    f"{datetime.now()} \t Total T: {t + 1} Episode Num: {episode_num + 1} Episode T: {episode_timesteps} Reward: {episode_reward}"
                )
                neptune.log_metric('episode_reward', episode_reward)
                # Reset environment
                state, done = env.reset(), False
                episode_reward = 0
                episode_timesteps = 0
                episode_num += 1
                toc_episode = time.perf_counter()
                neptune.log_metric('episode_duration',
                                   toc_episode - tic_episode)

            # Evaluate episode
            if (t + 1) % args.eval_freq == 0:
                eval_reward = self.eval_policy(policy, env, args.seed, render)
                evaluations.append(eval_reward)
                neptune.log_metric('eval_reward', eval_reward)
                np.save(f"./td3/results/{file_name}", evaluations)
                if args.save_model: policy.save(f"./td3/models/{file_name}")

        toc_training = time.perf_counter()
        neptune.log_metric('training_duration', toc_training - tic_training)
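Every script above scales policy_noise and noise_clip by max_action and passes policy_freq, but the TD3 class that consumes these values is not shown on this page. The condensed sketch below, modelled on the reference implementation, indicates where each setting enters the update: target policy smoothing, clipped double-Q targets, and delayed actor/target updates. It assumes the actor/critic networks, their target copies, and the optimizers already exist (and that the critic's forward returns both Q-values while Q1 returns only the first), so it is an illustration rather than the exact code of any project listed here.

import torch
import torch.nn.functional as F


class TD3:  # condensed sketch: constructor, networks and select_action omitted
    def train(self, replay_buffer, batch_size=256):
        self.total_it += 1
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        with torch.no_grad():
            # target policy smoothing: clipped Gaussian noise on the target action,
            # which is why policy_noise and noise_clip were scaled by max_action above
            noise = (torch.randn_like(action) * self.policy_noise
                     ).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_state) + noise
                           ).clamp(-self.max_action, self.max_action)

            # clipped double-Q learning: bootstrap from the smaller of the twin critics
            target_q1, target_q2 = self.critic_target(next_state, next_action)
            target_q = reward + not_done * self.discount * torch.min(target_q1, target_q2)

        current_q1, current_q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # delayed updates: the actor and the target networks move only every policy_freq steps
        if self.total_it % self.policy_freq == 0:
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Polyak averaging of the target networks with rate tau
            for net, target in ((self.critic, self.critic_target),
                                (self.actor, self.actor_target)):
                for p, tp in zip(net.parameters(), target.parameters()):
                    tp.data.copy_(self.tau * p.data + (1 - self.tau) * tp.data)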