Example #1
    def __init__(
            self,
            env,
            replay_buffer: ReplayBuffer,
            policy_net: torch.nn.Module,  # actor
            q_net1: torch.nn.Module,  # critic
            q_net2: torch.nn.Module,
            policy_optimizer: torch.optim.Optimizer,
            q_optimizer1: torch.optim.Optimizer,
            q_optimizer2: torch.optim.Optimizer,
            gamma=0.99,
            tau=0.05,
            alpha=0.5,
            automatic_entropy_tuning=False,
            explore_step=2000,
            max_train_step=50000,
            train_id="sac_test",
            log_interval=1000,
            resume=False):

        self.env = env
        self.replay_buffer = replay_buffer

        # the networks and optimizers
        self.policy_net = policy_net
        self.q_net1 = q_net1
        self.q_net2 = q_net2
        self.target_q_net1 = copy.deepcopy(self.q_net1)
        self.target_q_net2 = copy.deepcopy(self.q_net2)
        self.policy_optimizer = policy_optimizer
        self.q_optimizer1 = q_optimizer1
        self.q_optimizer2 = q_optimizer2

        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.automatic_entropy_tuning = automatic_entropy_tuning

        if self.automatic_entropy_tuning:
            self.target_entropy = -np.prod(self.env.action_space.shape).item()
            self.log_alpha = torch.zeros(1, requires_grad=True)
            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=3e-3)
            self.alpha = torch.exp(self.log_alpha)

        self.explore_step = explore_step
        self.max_train_step = max_train_step

        self.train_step = 0
        self.episode_num = 0

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join("./results", train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
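
Not part of the snippet above: the target critics created here are normally kept in sync with a soft (Polyak) update driven by tau. A minimal sketch under that assumption (soft_update is an illustrative helper, not a method of this class):

import torch

def soft_update(target_net: torch.nn.Module, net: torch.nn.Module, tau: float) -> None:
    # w_target <- tau * w_online + (1 - tau) * w_target, applied parameter-wise
    with torch.no_grad():
        for target_param, param in zip(target_net.parameters(), net.parameters()):
            target_param.mul_(1.0 - tau).add_(tau * param)

After each critic step one would typically call soft_update(self.target_q_net1, self.q_net1, self.tau), and likewise for q_net2.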
Example #2
    def __init__(
            self,
            env: Env,
            trajectory_buffer: TrajectoryBuffer,
            actor_net: torch.nn.Module,
            critic_net: torch.nn.Module,
            actor_lr=1e-4,
            critic_lr=1e-3,
            gamma=0.99,
            gae_lambda=0.95,
            gae_normalize=False,
            clip_pram=0.2,
            trajectory_length=128,  # the length of a trajectory
            train_actor_iters=10,
            train_critic_iters=10,
            eval_freq=1000,  # the agent will not be evaluated during training if eval_freq < 0
            max_time_step=10000,
            train_id="PPO_CarPole_test",
            log_interval=1000,
            resume=False,  # if True, train from last checkpoint
            device='cpu'):
        self.env = env

        self.trajectory_buffer = trajectory_buffer

        self.device = torch.device(device)

        # the networks and optimizers
        self.actor_net = actor_net.to(self.device)
        self.critic_net = critic_net.to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic_net.parameters(),
                                                 lr=critic_lr)

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.gae_normalize = gae_normalize
        self.trajectory_length = trajectory_length
        self.train_actor_iters = train_actor_iters
        self.train_critic_iters = train_critic_iters
        self.clip_pram = clip_pram

        self.eval_freq = eval_freq
        self.max_time_step = max_time_step

        self.time_step = 0
        self.episode_num = 0

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join(log_tools.ROOT_DIR, "run/results",
                                       train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
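
For context, a minimal sketch of how clip_pram and gae_normalize typically enter the PPO actor loss; the function and tensor names are illustrative, not part of this class:

import torch

def ppo_clip_loss(log_probs: torch.Tensor,
                  old_log_probs: torch.Tensor,
                  advantages: torch.Tensor,
                  clip_pram: float = 0.2,
                  gae_normalize: bool = False) -> torch.Tensor:
    # Clipped surrogate objective; all tensors have shape (batch,)
    if gae_normalize:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    ratio = torch.exp(log_probs - old_log_probs)
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_pram, 1.0 + clip_pram)
    return -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

This loss would be minimized for train_actor_iters passes over each collected trajectory.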
Example #3
    def __init__(
            self,
            env: Env,
            replay_buffer: ReplayBuffer,
            actor_net: torch.nn.Module,
            critic_net: torch.nn.Module,
            actor_lr=1e-4,
            critic_lr=1e-3,
            gamma=0.99,
            tau=0.005,  # used to update target network, w' = tau*w + (1-tau)*w'
            explore_step=128,
            eval_freq=1000,  # the agent will not be evaluated during training if eval_freq < 0
            max_train_step=10000,
            gaussian_noise_sigma=0.2,
            train_id="ddpg_Pendulum_test",
            log_interval=1000,
            resume=False,  # if True, train from last checkpoint
            device='cpu'):

        self.env = env
        self.action_num = env.action_space.shape[0]
        self.action_bound = env.action_space.high[0]
        self.replay_buffer = replay_buffer

        self.device = torch.device(device)

        # the networks and optimizers
        self.actor_net = actor_net.to(self.device)
        self.target_actor_net = copy.deepcopy(self.actor_net).to(self.device)
        self.critic_net = critic_net.to(self.device)
        self.target_critic_net = copy.deepcopy(self.critic_net).to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic_net.parameters(),
                                                 lr=critic_lr)

        self.gamma = gamma
        self.tau = tau
        self.gaussian_noise_sigma = gaussian_noise_sigma

        self.explore_step = explore_step
        self.eval_freq = eval_freq
        self.max_train_step = max_train_step

        self.train_step = 0
        self.episode_num = 0

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join(log_tools.ROOT_DIR, "run/results",
                                       train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
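
A sketch of how gaussian_noise_sigma and action_bound are commonly combined for exploration once explore_step random steps have passed; the helper name and the scaling of the noise by action_bound are assumptions, not taken from this class:

import numpy as np
import torch

def noisy_action(actor_net, obs, sigma, action_bound, device):
    # Deterministic actor output plus Gaussian exploration noise, clipped to the valid range
    obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        action = actor_net(obs_t).cpu().numpy().flatten()
    action += np.random.normal(0.0, sigma * action_bound, size=action.shape)
    return np.clip(action, -action_bound, action_bound)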
Example #4
    def __init__(
            self,
            env: Env,
            replay_buffer: ReplayBuffer,
            Q_net: torch.nn.Module,
            qf_lr=0.001,
            gamma=0.99,
            initial_eps=0.1,
            end_eps=0.001,
            eps_decay_period=2000,
            eval_eps=0.001,
            target_update_freq=10,
            train_interval: int = 1,
            explore_step=500,
            eval_freq=1000,  # the agent will not be evaluated during training if eval_freq < 0
            max_train_step=10000,
            train_id="ddqn_CartPole_test",
            log_interval=1000,
            resume=False,  # if True, train from last checkpoint
            device='cpu'):
        self.env = env
        self.replay_buffer = replay_buffer

        self.explore_step = explore_step
        self.eval_freq = eval_freq
        self.max_train_step = max_train_step
        self.train_interval = train_interval
        self.target_update_freq = target_update_freq

        self.device = torch.device(device)

        self.Q_net = Q_net.to(self.device)
        self.target_Q_net = copy.deepcopy(self.Q_net).to(self.device)
        self.optimizer = torch.optim.Adam(self.Q_net.parameters(), lr=qf_lr)

        # Decay for epsilon
        self.initial_eps = initial_eps
        self.end_eps = end_eps
        self.slope = (self.end_eps - self.initial_eps) / eps_decay_period
        self.eval_eps = eval_eps

        self.gamma = gamma
        self.train_step = 0
        self.episode_num = 0

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join(log_tools.ROOT_DIR, "run/results",
                                       train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
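
self.slope above defines a linear epsilon schedule. A small sketch of how it would typically be evaluated at a given train_step (current_eps is an illustrative name; eval_eps would replace the decayed value during evaluation):

def current_eps(train_step: int, initial_eps: float, end_eps: float, slope: float) -> float:
    # slope is negative, so epsilon decays linearly from initial_eps and is floored at end_eps
    return max(slope * train_step + initial_eps, end_eps)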
Example #5
    def __init__(self,
                 env,
                 data_buffer,
                 Q_net: torch.nn.Module,
                 qf_lr=0.001,
                 gamma=0.99,
                 eval_eps=0.001,
                 target_update_freq=8000,
                 train_interval: int = 1,

                 # CQL
                 min_q_weight=5.0,  # the value of alpha in the CQL loss; set to 5.0 or 10.0 if not using Lagrange tuning

                 max_train_step=2000000,
                 log_interval=1000,
                 eval_freq=5000,
                 train_id="sac_Pendulum_test",
                 resume=False,  # if True, train from last checkpoint
                 device='cpu',
                 ):

        self.env = env
        self.data_buffer = data_buffer

        self.max_train_step = max_train_step
        self.train_interval = train_interval
        self.target_update_freq = target_update_freq

        self.device = torch.device(device)

        self.Q_net = Q_net.to(self.device)
        self.target_Q_net = copy.deepcopy(self.Q_net).to(self.device)
        self.optimizer = torch.optim.Adam(self.Q_net.parameters(), lr=qf_lr)

        self.min_q_weight = min_q_weight

        self.eval_eps = eval_eps

        self.gamma = gamma
        self.train_step = 0
        self.eval_freq = eval_freq

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join(log_tools.ROOT_DIR, "run/results", train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
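
min_q_weight is the coefficient of the conservative penalty. A minimal sketch of the discrete-action CQL(H) regularizer it typically scales, added on top of the usual TD loss; names and shapes are assumptions:

import torch

def cql_penalty(q_values: torch.Tensor,      # (batch, num_actions) output of Q_net
                data_actions: torch.Tensor,  # (batch,) long tensor of dataset actions
                min_q_weight: float) -> torch.Tensor:
    # Push Q down on all actions (logsumexp) and back up on the actions actually in the buffer
    logsumexp_q = torch.logsumexp(q_values, dim=1)
    data_q = q_values.gather(1, data_actions.unsqueeze(1)).squeeze(1)
    return min_q_weight * (logsumexp_q - data_q).mean()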
Example #6
    def __init__(
            self,
            env: Env,
            replay_buffer: ReplayBuffer,
            actor_net: torch.nn.Module,
            critic_net: torch.nn.Module,
            actor_optimizer: torch.optim.Optimizer,
            critic_optimizer: torch.optim.Optimizer,
            gamma=0.99,
            tau=0.005,  # used to update target network, w' = tau*w + (1-tau)*w'
            explore_step=128,
            max_train_step=10000,
            gaussian_noise_sigma=0.2,
            train_id="ddpg_Pendulum_test",
            log_interval=1000,
            resume=False):

        self.env = env
        self.action_num = env.action_space.shape[0]
        self.action_bound = env.action_space.high[0]
        self.replay_buffer = replay_buffer

        # the networks and optimizers
        self.actor_net = actor_net
        self.target_actor_net = copy.deepcopy(self.actor_net)
        self.critic_net = critic_net
        self.target_critic_net = copy.deepcopy(self.critic_net)
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer

        self.gamma = gamma
        self.tau = tau
        self.gaussian_noise_sigma = gaussian_noise_sigma

        self.explore_step = explore_step
        self.max_train_step = max_train_step

        self.train_step = 0
        self.episode_num = 0

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join("./results", train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
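
For reference, a sketch of the bootstrapped critic target that the target networks above typically serve; it assumes the critic takes (state, action) pairs, which this snippet does not guarantee:

import torch

def ddpg_target(target_actor_net, target_critic_net, rewards, next_obs, dones, gamma):
    # y = r + gamma * (1 - done) * Q'(s', mu'(s'))
    with torch.no_grad():
        next_actions = target_actor_net(next_obs)
        next_q = target_critic_net(next_obs, next_actions).squeeze(-1)
        return rewards + gamma * (1.0 - dones) * next_q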
Example #7
    def __init__(self,
                 env: Env,
                 replay_buffer: ReplayBuffer,
                 Q_net: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 gamma=0.99,
                 initial_eps=0.1,
                 end_eps=0.001,
                 eps_decay_period=25e4,
                 eval_eps=0.001,
                 target_update_freq=10,
                 train_interval: int = 1,
                 explore_step=500,
                 max_train_step=10000,
                 train_id="dqn_CartPole_test",
                 log_interval=1000,
                 resume=False):
        self.env = env
        self.replay_buffer = replay_buffer

        self.explore_step = explore_step
        self.max_train_step = max_train_step
        self.train_interval = train_interval
        self.target_update_freq = target_update_freq

        self.Q_net = Q_net
        self.target_Q_net = copy.deepcopy(self.Q_net)
        self.optimizer = optimizer

        # Decay for epsilon
        self.initial_eps = initial_eps
        self.end_eps = end_eps
        self.slope = (self.end_eps - self.initial_eps) / eps_decay_period
        self.eval_eps = eval_eps

        self.gamma = gamma
        self.train_step = 0
        self.episode_num = 0

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join("./results", train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
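
A sketch of the epsilon-greedy selection that the decayed epsilon (or eval_eps at evaluation time) usually feeds into; select_action and num_actions are illustrative names:

import random
import torch

def select_action(q_net: torch.nn.Module, obs: torch.Tensor, eps: float, num_actions: int) -> int:
    # Random action with probability eps, otherwise the greedy action from Q_net
    if random.random() < eps:
        return random.randrange(num_actions)
    with torch.no_grad():
        return int(q_net(obs.unsqueeze(0)).argmax(dim=1).item())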
Example #8
    def __init__(
        self,
        env,
        replay_buffer: OfflineBuffer,
        actor_net: torch.nn.Module,
        critic_net1: torch.nn.Module,
        critic_net2: torch.nn.Module,
        actor_lr=3e-4,
        critic_lr=3e-4,
        gamma=0.99,
        tau=0.005,  # used to update target network, w' = tau*w + (1-tau)*w'
        policy_noise=0.2,  # Noise added to target policy during critic update
        noise_clip=0.5,  # Range to clip target policy noise
        policy_delay=2,  # Frequency of delayed policy updates
        alpha=2.5,  # the alpha used to compute the TD3+BC weighting lambda
        max_train_step=1000000,
        log_interval=1000,
        eval_freq=5000,
        train_id="td3bc_test",
        resume=False,  # if True, train from last checkpoint
        device='cpu',
    ):

        self.env = env
        self.action_num = env.action_space.shape[0]
        self.action_bound = env.action_space.high[0]
        self.replay_buffer = replay_buffer

        self.device = torch.device(device)

        # the networks and optimizers
        self.actor_net = actor_net.to(self.device)
        self.target_actor_net = copy.deepcopy(self.actor_net).to(self.device)
        self.critic_net1 = critic_net1.to(self.device)
        self.target_critic_net1 = copy.deepcopy(self.critic_net1).to(
            self.device)
        self.critic_net2 = critic_net2.to(self.device)
        self.target_critic_net2 = copy.deepcopy(self.critic_net2).to(
            self.device)

        self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer1 = torch.optim.Adam(
            self.critic_net1.parameters(), lr=critic_lr)
        self.critic_optimizer2 = torch.optim.Adam(
            self.critic_net2.parameters(), lr=critic_lr)

        self.gamma = gamma
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_delay = policy_delay
        self.alpha = alpha

        self.actor_loss = 0
        self.eval_freq = eval_freq
        self.max_train_step = max_train_step
        self.train_step = 0

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join(log_tools.ROOT_DIR, "run/results",
                                       train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
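
alpha above is the TD3+BC weighting coefficient. A minimal sketch of the actor loss it typically appears in, assuming the critic takes (state, action) pairs:

import torch
import torch.nn.functional as F

def td3bc_actor_loss(actor_net, critic_net1, obs, data_actions, alpha):
    # lambda = alpha / mean|Q(s, pi(s))|; maximize lambda * Q while imitating the dataset actions
    pi = actor_net(obs)
    q = critic_net1(obs, pi)
    lmbda = alpha / q.abs().mean().detach()
    return -lmbda * q.mean() + F.mse_loss(pi, data_actions)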
Example #9
    def __init__(
        self,
        env,
        data_buffer: OfflineBuffer,
        policy_net: MLPSquashedReparamGaussianPolicy,  # actor
        q_net1: MLPQsaNet,  # critic
        q_net2: MLPQsaNet,
        cvae_net: CVAE,
        policy_lr=1e-4,
        qf_lr=3e-4,
        cvae_lr=3e-4,
        gamma=0.99,
        tau=0.05,

        # BEAR
        lmbda=0.75,  # used for soft clipped double Q-learning
        mmd_sigma=20.0,  # the sigma used in the MMD kernel
        kernel_type='gaussian',  # the type of MMD kernel (gaussian or laplacian)
        lagrange_thresh=0.05,  # the threshold used when automatically tuning the Lagrange multiplier on the MMD constraint
        n_action_samples=100,  # the number of action samples used to pick the best action when choosing an action
        n_target_samples=10,  # the number of action samples used to compute the BCQ-like target value
        n_mmd_action_samples=4,  # the number of action samples used to compute the MMD
        warmup_step=40000,  # do support matching with a warm start before training the policy (actor)
        max_train_step=1000000,
        log_interval=1000,
        eval_freq=5000,
        train_id="bear_hopper-medium-v2_test",
        resume=False,  # if True, train from last checkpoint
        device='cpu',
    ):

        self.env = env
        self.data_buffer = data_buffer

        self.device = torch.device(device)

        # the networks and optimizers
        self.policy_net = policy_net.to(self.device)
        self.q_net1 = q_net1.to(self.device)
        self.q_net2 = q_net2.to(self.device)
        self.target_q_net1 = copy.deepcopy(self.q_net1).to(self.device)
        self.target_q_net2 = copy.deepcopy(self.q_net2).to(self.device)
        self.cvae_net = cvae_net.to(self.device)
        self.policy_optimizer = torch.optim.Adam(self.policy_net.parameters(),
                                                 lr=policy_lr)
        self.q_optimizer1 = torch.optim.Adam(self.q_net1.parameters(),
                                             lr=qf_lr)
        self.q_optimizer2 = torch.optim.Adam(self.q_net2.parameters(),
                                             lr=qf_lr)
        self.cvae_optimizer = torch.optim.Adam(self.cvae_net.parameters(),
                                               lr=cvae_lr)

        self.gamma = gamma
        self.tau = tau

        self.max_train_step = max_train_step
        self.eval_freq = eval_freq
        self.train_step = 0

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # BEAR
        self.lmbda = lmbda
        self.mmd_sigma = mmd_sigma
        self.kernel_type = kernel_type
        self.lagrange_thresh = lagrange_thresh
        self.n_action_samples = n_action_samples
        self.n_target_samples = n_target_samples
        self.n_mmd_action_samples = n_mmd_action_samples
        self.warmup_step = warmup_step

        # the Lagrange multiplier (temperature) on the MMD loss
        self.log_alpha_prime = torch.zeros(1,
                                           requires_grad=True,
                                           device=self.device)
        self.alpha_prime_optimizer = torch.optim.Adam([self.log_alpha_prime],
                                                      lr=1e-3)

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join(log_tools.ROOT_DIR, "run/results",
                                       train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
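
mmd_sigma and kernel_type parameterize the MMD penalty between policy actions and actions reconstructed by the CVAE. A sketch of that estimate under assumed (batch, n_samples, action_dim) shapes; the exact bandwidth convention is an assumption:

import torch

def mmd_loss(x: torch.Tensor, y: torch.Tensor, sigma: float, kernel_type: str = 'gaussian') -> torch.Tensor:
    # x, y: (batch, n, action_dim) action samples from two distributions
    def k(a, b):
        diff = a.unsqueeze(2) - b.unsqueeze(1)  # (batch, n, m, action_dim)
        if kernel_type == 'gaussian':
            return torch.exp(-diff.pow(2).sum(-1) / (2.0 * sigma))
        return torch.exp(-diff.abs().sum(-1) / (2.0 * sigma))  # 'laplacian'
    mmd_sq = k(x, x).mean(dim=(1, 2)) + k(y, y).mean(dim=(1, 2)) - 2.0 * k(x, y).mean(dim=(1, 2))
    return (mmd_sq + 1e-6).sqrt().mean()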
Example #10
    def __init__(self,
                 env,
                 data_buffer: OfflineBuffer,
                 policy_net: torch.nn.Module,  # actor
                 q_net1: torch.nn.Module,  # critic
                 q_net2: torch.nn.Module,
                 policy_lr=3e-4,
                 qf_lr=3e-4,
                 gamma=0.99,
                 tau=0.05,
                 alpha=0.5,
                 auto_alpha_tuning=False,

                 # CQL
                 min_q_weight=5.0,  # the value of alpha in the CQL loss; set to 5.0 or 10.0 if not using Lagrange tuning
                 entropy_backup=False,  # whether to use a SAC-style target Q with entropy
                 max_q_backup=False,  # whether to use the max Q backup
                 with_lagrange=False,  # whether to automatically tune alpha in the conservative Q loss (different from the alpha in SAC)
                 lagrange_thresh=0.0,  # the threshold used when automatically tuning alpha in the CQL loss
                 n_action_samples=10,  # the number of actions sampled for importance sampling

                 max_train_step=2000000,
                 log_interval=1000,
                 eval_freq=5000,
                 train_id="cql_hopper-medium-v2_test",
                 resume=False,  # if True, train from last checkpoint
                 device='cpu',
                 ):

        self.env = env
        self.data_buffer = data_buffer

        self.device = torch.device(device)

        # the networks and optimizers
        self.policy_net = policy_net.to(self.device)
        self.q_net1 = q_net1.to(self.device)
        self.q_net2 = q_net2.to(self.device)
        self.target_q_net1 = copy.deepcopy(self.q_net1).to(self.device)
        self.target_q_net2 = copy.deepcopy(self.q_net2).to(self.device)
        self.policy_optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=policy_lr)
        self.q_optimizer1 = torch.optim.Adam(self.q_net1.parameters(), lr=qf_lr)
        self.q_optimizer2 = torch.optim.Adam(self.q_net2.parameters(), lr=qf_lr)

        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.auto_alpha_tuning = auto_alpha_tuning

        if self.auto_alpha_tuning:
            self.target_entropy = -np.prod(self.env.action_space.shape).item()
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=policy_lr)
            self.alpha = torch.exp(self.log_alpha)

        self.max_train_step = max_train_step
        self.eval_freq = eval_freq
        self.train_step = 0

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # CQL
        self.min_q_weight = min_q_weight
        self.entropy_backup = entropy_backup
        self.max_q_backup = max_q_backup
        self.with_lagrange = with_lagrange
        self.lagrange_thresh = lagrange_thresh
        self.n_action_samples = n_action_samples

        if self.with_lagrange:
            self.log_alpha_prime = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_prime_optimizer = torch.optim.Adam([self.log_alpha_prime], lr=qf_lr)

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join(log_tools.ROOT_DIR, "run/results", train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
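
When with_lagrange is set, log_alpha_prime is trained by dual ascent against lagrange_thresh. A minimal sketch of one such step, assuming the conservative gaps (logsumexp Q minus dataset Q, one per critic) are passed in as detached scalars:

import torch

def alpha_prime_step(log_alpha_prime, alpha_prime_optimizer, cql_gap1, cql_gap2, lagrange_thresh):
    # alpha' grows while the conservative gap exceeds lagrange_thresh and shrinks otherwise
    alpha_prime = torch.clamp(log_alpha_prime.exp(), max=1e6)
    alpha_prime_loss = -0.5 * alpha_prime * ((cql_gap1 - lagrange_thresh) + (cql_gap2 - lagrange_thresh))
    alpha_prime_optimizer.zero_grad()
    alpha_prime_loss.backward()
    alpha_prime_optimizer.step()
    return log_alpha_prime.exp().detach()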
Example #11
    def __init__(
        self,
        env,
        data_buffer: OfflineBuffer,
        critic_net1: MLPQsaNet,
        critic_net2: MLPQsaNet,
        actor_net: PLAS_Actor,
        cvae_net: CVAE,  # generative model
        critic_lr=1e-3,
        actor_lr=1e-4,
        cvae_lr=1e-4,
        gamma=0.99,
        tau=0.005,
        lmbda=0.75,  # used for soft clipped double Q-learning
        max_cvae_iterations=500000,  # the number of iterations used to train the CVAE model
        max_train_step=2000000,
        log_interval=1000,
        eval_freq=5000,
        train_id="plas_test",
        resume=False,  # if True, train from last checkpoint
        device='cpu',
    ):
        self.env = env
        self.data_buffer = data_buffer
        self.device = torch.device(device)

        self.critic_net1 = critic_net1.to(self.device)
        self.critic_net2 = critic_net2.to(self.device)
        self.target_critic_net1 = copy.deepcopy(self.critic_net1).to(
            self.device)
        self.target_critic_net2 = copy.deepcopy(self.critic_net2).to(
            self.device)
        self.actor_net = actor_net.to(self.device)
        self.target_actor_net = copy.deepcopy(self.actor_net).to(self.device)
        self.cvae_net = cvae_net.to(self.device)
        self.critic_optimizer1 = torch.optim.Adam(
            self.critic_net1.parameters(), lr=critic_lr)
        self.critic_optimizer2 = torch.optim.Adam(
            self.critic_net2.parameters(), lr=critic_lr)
        self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(),
                                                lr=actor_lr)
        self.cvae_optimizer = torch.optim.Adam(self.cvae_net.parameters(),
                                               lr=cvae_lr)

        self.gamma = gamma
        self.tau = tau
        self.lmbda = lmbda

        self.max_cvae_iterations = max_cvae_iterations
        self.max_train_step = max_train_step
        self.eval_freq = eval_freq
        self.cvae_iterations = 0
        self.train_step = 0

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join(log_tools.ROOT_DIR, "run/results",
                                       train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
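
lmbda above weights the soft clipped double Q-learning value used for the critic target. A small sketch of that combination (the target would then be r + gamma * (1 - done) * value):

import torch

def soft_clipped_q(q1: torch.Tensor, q2: torch.Tensor, lmbda: float) -> torch.Tensor:
    # lmbda * min(Q1, Q2) + (1 - lmbda) * max(Q1, Q2), element-wise over the batch
    return lmbda * torch.min(q1, q2) + (1.0 - lmbda) * torch.max(q1, q2)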
Example #12
    def __init__(
        self,
        env,
        data_buffer: OfflineBuffer,
        policy_net: torch.nn.Module,  # actor
        q_net1: torch.nn.Module,  # critic
        q_net2: torch.nn.Module,
        policy_lr=3e-4,
        qf_lr=3e-4,
        gamma=0.99,
        tau=0.05,
        alpha=0.5,
        auto_alpha_tuning=False,
        max_train_step=2000000,
        log_interval=1000,
        eval_freq=5000,
        train_id="sac_Pendulum_test",
        resume=False,  # if True, train from last checkpoint
        device='cpu',
    ):
        self.env = env
        self.data_buffer = data_buffer

        self.device = torch.device(device)

        # the networks and optimizers
        self.policy_net = policy_net.to(self.device)
        self.q_net1 = q_net1.to(self.device)
        self.q_net2 = q_net2.to(self.device)
        self.target_q_net1 = copy.deepcopy(self.q_net1).to(self.device)
        self.target_q_net2 = copy.deepcopy(self.q_net2).to(self.device)
        self.policy_optimizer = torch.optim.Adam(self.policy_net.parameters(),
                                                 lr=policy_lr)
        self.q_optimizer1 = torch.optim.Adam(self.q_net1.parameters(),
                                             lr=qf_lr)
        self.q_optimizer2 = torch.optim.Adam(self.q_net2.parameters(),
                                             lr=qf_lr)

        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.auto_alpha_tuning = auto_alpha_tuning

        if self.auto_alpha_tuning:
            self.target_entropy = -np.prod(self.env.action_space.shape).item()
            self.log_alpha = torch.zeros(1,
                                         requires_grad=True,
                                         device=self.device)
            self.alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=policy_lr)
            self.alpha = torch.exp(self.log_alpha)

        self.max_train_step = max_train_step
        self.eval_freq = eval_freq
        self.train_step = 0

        self.resume = resume  # whether to load a checkpoint and resume training from last time

        # log dir and interval
        self.log_interval = log_interval
        self.result_dir = os.path.join(log_tools.ROOT_DIR, "run/results",
                                       train_id)
        log_tools.make_dir(self.result_dir)
        self.checkpoint_path = os.path.join(self.result_dir, "checkpoint.pth")
        self.tensorboard_writer = log_tools.TensorboardLogger(self.result_dir)
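
When auto_alpha_tuning is enabled, log_alpha is trained so that the policy entropy tracks target_entropy. A minimal sketch of that temperature update; conventions vary between implementations, and the names here are illustrative:

import torch

def alpha_update(log_alpha, alpha_optimizer, log_probs, target_entropy):
    # Increase alpha when entropy is below target (log_probs + target_entropy > 0), decrease otherwise
    alpha_loss = -(log_alpha * (log_probs + target_entropy).detach()).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    return log_alpha.exp().detach()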