Example #1
    def __init__(self,
                 observation_space,
                 action_space,
                 learning_rate=0.001,
                 update_period=100,
                 embedding_dim=10,
                 net_fn=None,
                 net_kwargs=None,
                 device="cuda:best",
                 rate_power=0.5,
                 batch_size=10,
                 memory_size=10000,
                 with_action=False,
                 **kwargs):
        assert isinstance(observation_space, spaces.Box)
        UncertaintyEstimator.__init__(self, observation_space, action_space)
        self.learning_rate = learning_rate
        self.loss_fn = F.mse_loss
        self.update_period = update_period
        self.embedding_dim = embedding_dim
        out_size = embedding_dim * action_space.n if with_action else embedding_dim
        self.net_fn = load(net_fn) if isinstance(net_fn, str) else \
            net_fn or partial(get_network, shape=observation_space.shape, embedding_dim=out_size)
        self.net_kwargs = net_kwargs or {}
        if "out_size" in self.net_kwargs:
            self.net_kwargs["out_size"] = out_size
        self.device = choose_device(device)
        self.rate_power = rate_power
        self.batch_size = batch_size
        self.memory = ReplayMemory(capacity=memory_size)
        self.with_action = with_action
        self.reset()
Example #2
    def __init__(self, env,
                 n_episodes=4000,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 vf_coef=0.,
                 avec_coef=1.,
                 learning_rate=0.0003,
                 optimizer_type='ADAM',
                 eps_clip=0.2,
                 k_epochs=10,
                 policy_net_fn=None,
                 value_net_fn=None,
                 policy_net_kwargs=None,
                 value_net_kwargs=None,
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 device="cuda:best",
                 **kwargs):
        self.use_bonus = use_bonus
        if self.use_bonus:
            env = UncertaintyEstimatorWrapper(env,
                                              **uncertainty_estimator_kwargs)
        IncrementalAgent.__init__(self, env, **kwargs)

        self.learning_rate = learning_rate
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.vf_coef = vf_coef
        self.avec_coef = avec_coef
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs
        self.horizon = horizon
        self.n_episodes = n_episodes
        self.batch_size = batch_size
        self.device = choose_device(device)

        self.policy_net_kwargs = policy_net_kwargs or {}
        self.value_net_kwargs = value_net_kwargs or {}

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        #
        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.value_net_fn = value_net_fn or default_value_net_fn

        self.optimizer_kwargs = {'optimizer_type': optimizer_type,
                                 'lr': learning_rate}

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.cat_policy = None  # categorical policy function

        # initialize
        self.reset()
Example #3
    def __init__(self,
                 env,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 vf_coef=0.0,
                 avec_coef=1.0,
                 learning_rate=0.0003,
                 optimizer_type="ADAM",
                 eps_clip=0.2,
                 k_epochs=10,
                 policy_net_fn=None,
                 value_net_fn=None,
                 policy_net_kwargs=None,
                 value_net_kwargs=None,
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 device="cuda:best",
                 **kwargs):
        # For all parameters, define self.param = param
        _, _, _, values = inspect.getargvalues(inspect.currentframe())
        values.pop("self")
        for arg, val in values.items():
            setattr(self, arg, val)

        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        self.use_bonus = use_bonus
        if self.use_bonus:
            self.env = UncertaintyEstimatorWrapper(
                self.env, **uncertainty_estimator_kwargs)

        self.device = choose_device(device)

        self.policy_net_kwargs = policy_net_kwargs or {}
        self.value_net_kwargs = value_net_kwargs or {}

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        #
        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.value_net_fn = value_net_fn or default_value_net_fn

        self.optimizer_kwargs = {
            "optimizer_type": optimizer_type,
            "lr": learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.cat_policy = None  # categorical policy function

        # initialize
        self.reset()
Example #4
File: a2c.py Project: omardrwch/rlberry
    def __init__(self,
                 env,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 learning_rate=0.01,
                 optimizer_type="ADAM",
                 k_epochs=5,
                 policy_net_fn=None,
                 value_net_fn=None,
                 policy_net_kwargs=None,
                 value_net_kwargs=None,
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 device="cuda:best",
                 **kwargs):

        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        self.use_bonus = use_bonus
        if self.use_bonus:
            self.env = UncertaintyEstimatorWrapper(
                self.env, **uncertainty_estimator_kwargs)

        self.batch_size = batch_size
        self.horizon = horizon
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.learning_rate = learning_rate
        self.k_epochs = k_epochs
        self.device = choose_device(device)

        self.policy_net_kwargs = policy_net_kwargs or {}
        self.value_net_kwargs = value_net_kwargs or {}

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        #
        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.value_net_fn = value_net_fn or default_value_net_fn

        self.optimizer_kwargs = {
            "optimizer_type": optimizer_type,
            "lr": learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.cat_policy = None  # categorical policy function

        # initialize
        self.reset()
Example #5
    def __init__(self,
                 env,
                 n_episodes=4000,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 learning_rate=0.0001,
                 normalize=True,
                 optimizer_type='ADAM',
                 policy_net_fn=None,
                 policy_net_kwargs=None,
                 use_bonus_if_available=False,
                 device="cuda:best",
                 **kwargs):
        IncrementalAgent.__init__(self, env, **kwargs)

        self.n_episodes = n_episodes
        self.batch_size = batch_size
        self.horizon = horizon
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.learning_rate = learning_rate
        self.normalize = normalize
        self.use_bonus_if_available = use_bonus_if_available
        self.device = choose_device(device)

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        self.policy_net_kwargs = policy_net_kwargs or {}

        #
        self.policy_net_fn = policy_net_fn or default_policy_net_fn

        self.optimizer_kwargs = {
            'optimizer_type': optimizer_type,
            'lr': learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.policy_net = None  # policy network

        # initialize
        self.reset()
Example #6
    def __init__(self,
                 env,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 learning_rate=0.0001,
                 normalize=True,
                 optimizer_type="ADAM",
                 policy_net_fn=None,
                 policy_net_kwargs=None,
                 use_bonus_if_available=False,
                 device="cuda:best",
                 **kwargs):

        # For all parameters, define self.param = param
        _, _, _, values = inspect.getargvalues(inspect.currentframe())
        values.pop("self")
        for arg, val in values.items():
            setattr(self, arg, val)

        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        self.device = choose_device(device)

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        self.policy_net_kwargs = policy_net_kwargs or {}

        #
        self.policy_net_fn = policy_net_fn or default_policy_net_fn

        self.optimizer_kwargs = {
            "optimizer_type": optimizer_type,
            "lr": learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.policy_net = None  # policy network

        # initialize
        self.reset()
Example #7
    def __init__(self,
                 env,
                 n_episodes=1000,
                 horizon=256,
                 gamma=0.99,
                 loss_function="l2",
                 batch_size=100,
                 device="cuda:best",
                 target_update=1,
                 learning_rate=0.001,
                 optimizer_type='ADAM',
                 qvalue_net_fn=None,
                 double=True,
                 exploration_kwargs=None,
                 memory_kwargs=None,
                 **kwargs):
        # Wrap arguments and initialize base class
        memory_kwargs = memory_kwargs or {}
        memory_kwargs['gamma'] = gamma
        base_args = (env, horizon, exploration_kwargs, memory_kwargs,
                     n_episodes, batch_size, target_update, double)
        AbstractDQNAgent.__init__(self, *base_args, **kwargs)

        # init
        self.optimizer_kwargs = {'optimizer_type': optimizer_type,
                                 'lr': learning_rate}
        self.device = device
        self.loss_function = loss_function
        self.gamma = gamma
        #
        qvalue_net_fn = qvalue_net_fn \
            or (lambda: default_qvalue_net_fn(self.env))
        self.value_net = qvalue_net_fn()
        self.target_net = qvalue_net_fn()
        #
        self.target_net.load_state_dict(self.value_net.state_dict())
        self.target_net.eval()
        logger.debug("Number of trainable parameters: {}"
                     .format(trainable_parameters(self.value_net)))
        self.device = choose_device(self.device)
        self.value_net.to(self.device)
        self.target_net.to(self.device)
        self.loss_function = loss_function_factory(self.loss_function)
        self.optimizer = optimizer_factory(self.value_net.parameters(),
                                           **self.optimizer_kwargs)
        self.steps = 0
Example #8
    def __init__(self,
                 env,
                 n_episodes=1000,
                 horizon=100,
                 gamma=0.99,
                 entr_coef=0.1,
                 batch_size=16,
                 percentile=70,
                 learning_rate=0.01,
                 optimizer_type='ADAM',
                 on_policy=False,
                 policy_net_fn=None,
                 policy_net_kwargs=None,
                 device="cuda:best",
                 **kwargs):
        IncrementalAgent.__init__(self, env, **kwargs)

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        # parameters
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.batch_size = batch_size
        self.n_episodes = n_episodes
        self.percentile = percentile
        self.learning_rate = learning_rate
        self.horizon = horizon
        self.on_policy = on_policy
        self.policy_net_kwargs = policy_net_kwargs or {}
        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.optimizer_kwargs = {
            'optimizer_type': optimizer_type,
            'lr': learning_rate
        }
        self.device = choose_device(device)
        self.reset()
Example #9
File: dqn.py Project: omardrwch/rlberry
class DQNAgent(AgentWithSimplePolicy):
    """DQN Agent based on PyTorch.

    Notes
    -----
    Uses Q(lambda) for computing targets by default. To recover
    the standard DQN, set :code:`lambda_ = 0.0` and :code:`chunk_size = 1`.

    Parameters
    ----------
    env: :class:`~rlberry.types.Env`
        Environment, can be a tuple (constructor, kwargs)
    gamma: float, default = 0.99
        Discount factor.
    batch_size: int, default=32
        Batch size.
    chunk_size: int, default=8
        Length of sub-trajectories sampled from the replay buffer.
    lambda_: float, default=0.5
        Q(lambda) parameter.
    target_update_parameter : int or float
        If int: interval (in number of online updates) between updates of the target network.
        If float: soft update coefficient.
    device: str
        Torch device, see :func:`~rlberry.utils.torch.choose_device`
    learning_rate : float, default = 1e-3
        Optimizer learning rate.
    loss_function: {"l1", "l2", "smooth_l1"}, default: "l2"
        Loss function used to compute Bellman error.
    epsilon_init: float, default = 1.0
        Initial epsilon value for epsilon-greedy exploration.
    epsilon_final: float, default = 0.1
        Final epsilon value for epsilon-greedy exploration.
    epsilon_decay_interval : int
        After :code:`epsilon_decay_interval` timesteps, epsilon approaches :code:`epsilon_final`.
    optimizer_type : {"ADAM", "RMS_PROP"}
        Optimization algorithm.
    q_net_constructor : Callable
        Function/constructor that returns a torch module for the Q-network:
        :code:`qnet = q_net_constructor(env, **kwargs)`.

        Module (Q-network) requirements:

        * Input shape = (batch_dim, chunk_size, obs_dims)

        * Output shape = (batch_dim, chunk_size, number_of_actions)

    q_net_kwargs : optional, dict
        Parameters for q_net_constructor.
    use_double_dqn : bool, default = False
        If True, use Double DQN.
    use_prioritized_replay : bool, default = False
        If True, use Prioritized Experience Replay.
    train_interval: int
        Update the model every :code:`train_interval` steps.
        If -1, train only at the end of the episodes.
    gradient_steps: int
        How many gradient steps to do at each update.
        If -1, take the number of timesteps since last update.
    max_replay_size : int
        Maximum number of transitions in the replay buffer.
    learning_starts : int
        How many steps to collect transitions for before learning starts.
    eval_interval : int, default = None
        Interval (in number of transitions) between agent evaluations in fit().
        If None, never evaluate.
    """

    name = "DQN"

    def __init__(
        self,
        env: types.Env,
        gamma: float = 0.99,
        batch_size: int = 32,
        chunk_size: int = 8,
        lambda_: float = 0.5,
        target_update_parameter: Union[int, float] = 0.005,
        device: str = "cuda:best",
        learning_rate: float = 1e-3,
        epsilon_init: float = 1.0,
        epsilon_final: float = 0.1,
        epsilon_decay_interval: int = 20_000,
        loss_function: str = "l2",
        optimizer_type: str = "ADAM",
        q_net_constructor: Optional[Callable[..., torch.nn.Module]] = None,
        q_net_kwargs: Optional[dict] = None,
        use_double_dqn: bool = False,
        use_prioritized_replay: bool = False,
        train_interval: int = 10,
        gradient_steps: int = -1,
        max_replay_size: int = 200_000,
        learning_starts: int = 5_000,
        eval_interval: Optional[int] = None,
        **kwargs,
    ):
        # For all parameters, define self.param = param
        _, _, _, values = inspect.getargvalues(inspect.currentframe())
        values.pop("self")
        for arg, val in values.items():
            setattr(self, arg, val)

        AgentWithSimplePolicy.__init__(self, env, **kwargs)
        env = self.env
        assert isinstance(env.observation_space, spaces.Box)
        assert isinstance(env.action_space, spaces.Discrete)

        # DQN parameters

        # Online and target Q networks, torch device
        self._device = choose_device(device)
        if isinstance(q_net_constructor, str):
            q_net_ctor = load(q_net_constructor)
        elif q_net_constructor is None:
            q_net_ctor = default_q_net_fn
        else:
            q_net_ctor = q_net_constructor
        q_net_kwargs = q_net_kwargs or dict()
        self._qnet_online = q_net_ctor(env, **q_net_kwargs).to(self._device)
        self._qnet_target = q_net_ctor(env, **q_net_kwargs).to(self._device)

        # Optimizer and loss
        optimizer_kwargs = {
            "optimizer_type": optimizer_type,
            "lr": learning_rate
        }
        self._optimizer = optimizer_factory(self._qnet_online.parameters(),
                                            **optimizer_kwargs)
        self._loss_function = loss_function_factory(loss_function,
                                                    reduction="none")

        # Training params
        self._train_interval = train_interval
        self._gradient_steps = gradient_steps
        self._learning_starts = learning_starts
        self._eval_interval = eval_interval

        # Setup replay buffer
        if hasattr(self.env, "_max_episode_steps"):
            max_episode_steps = self.env._max_episode_steps
        else:
            max_episode_steps = np.inf
        self._max_episode_steps = max_episode_steps

        self._replay_buffer = replay.ReplayBuffer(
            max_replay_size=max_replay_size,
            rng=self.rng,
            max_episode_steps=self._max_episode_steps,
            enable_prioritized=use_prioritized_replay,
        )
        self._replay_buffer.setup_entry("observations", np.float32)
        self._replay_buffer.setup_entry("next_observations", np.float32)
        self._replay_buffer.setup_entry("actions", np.int32)
        self._replay_buffer.setup_entry("rewards", np.float32)
        self._replay_buffer.setup_entry("dones", bool)

        # Counters
        self._total_timesteps = 0
        self._total_episodes = 0
        self._total_updates = 0
        self._timesteps_since_last_update = 0

        # epsilon scheduling
        self._epsilon_schedule = polynomial_schedule(
            self.epsilon_init,
            self.epsilon_final,
            power=1.0,
            transition_steps=self.epsilon_decay_interval,
            transition_begin=0,
        )
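
A short usage sketch makes the Notes and the q_net_constructor shape contract above concrete. This is an illustrative sketch only: the import paths, the gym_make/"CartPole-v1" environment tuple, the mlp_q_net_constructor helper, and the fit(budget=...) call are assumptions, not code taken from the examples on this page.

import torch.nn as nn

from rlberry.envs import gym_make          # assumed constructor helper
from rlberry.agents.torch import DQNAgent  # assumed import path


def mlp_q_net_constructor(env, hidden_size=64):
    # Hypothetical helper satisfying the documented contract:
    # input (batch_dim, chunk_size, obs_dims) -> output (batch_dim, chunk_size, n_actions).
    # nn.Linear acts on the last dimension, so the leading dimensions pass through unchanged.
    obs_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n
    return nn.Sequential(
        nn.Linear(obs_dim, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, n_actions),
    )


# Per the Notes above, standard one-step DQN is recovered with
# lambda_ = 0.0 and chunk_size = 1.
agent = DQNAgent(
    env=(gym_make, dict(id="CartPole-v1")),
    lambda_=0.0,
    chunk_size=1,
    q_net_constructor=mlp_q_net_constructor,
    q_net_kwargs=dict(hidden_size=64),
)
agent.fit(budget=10_000)  # budget keyword is an assumption about the fit() API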
Example #10
    def __init__(self,
                 env,
                 n_episodes=1000,
                 horizon=256,
                 gamma=0.99,
                 loss_function="l2",
                 batch_size=100,
                 device="cuda:best",
                 target_update=1,
                 learning_rate=0.001,
                 epsilon_init=1.0,
                 epsilon_final=0.1,
                 epsilon_decay=5000,
                 optimizer_type='ADAM',
                 qvalue_net_fn=None,
                 qvalue_net_kwargs=None,
                 double=True,
                 memory_capacity=10000,
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 prioritized_replay=True,
                 update_frequency=1,
                 **kwargs):
        # Wrap arguments and initialize base class
        memory_kwargs = {
            'capacity': memory_capacity,
            'n_steps': 1,
            'gamma': gamma
        }
        exploration_kwargs = {
            'method': "EpsilonGreedy",
            'temperature': epsilon_init,
            'final_temperature': epsilon_final,
            'tau': epsilon_decay,
        }
        self.use_bonus = use_bonus
        if self.use_bonus:
            env = UncertaintyEstimatorWrapper(env,
                                              **uncertainty_estimator_kwargs)
        IncrementalAgent.__init__(self, env, **kwargs)
        self.horizon = horizon
        self.exploration_kwargs = exploration_kwargs or {}
        self.memory_kwargs = memory_kwargs or {}
        self.n_episodes = n_episodes
        self.batch_size = batch_size
        self.target_update = target_update
        self.double = double

        assert isinstance(env.action_space, spaces.Discrete), \
            "Only compatible with Discrete action spaces."

        self.prioritized_replay = prioritized_replay
        memory_class = PrioritizedReplayMemory if prioritized_replay else TransitionReplayMemory
        self.memory = memory_class(**self.memory_kwargs)
        self.exploration_policy = \
            exploration_factory(self.env.action_space,
                                **self.exploration_kwargs)
        self.training = True
        self.steps = 0
        self.episode = 0
        self.writer = None

        self.optimizer_kwargs = {
            'optimizer_type': optimizer_type,
            'lr': learning_rate
        }
        self.device = choose_device(device)
        self.loss_function = loss_function
        self.gamma = gamma

        qvalue_net_kwargs = qvalue_net_kwargs or {}
        qvalue_net_fn = load(qvalue_net_fn) if isinstance(qvalue_net_fn, str) else \
            qvalue_net_fn or default_qvalue_net_fn
        self.value_net = qvalue_net_fn(self.env, **qvalue_net_kwargs)
        self.target_net = qvalue_net_fn(self.env, **qvalue_net_kwargs)

        self.target_net.load_state_dict(self.value_net.state_dict())
        self.target_net.eval()
        logger.info("Number of trainable parameters: {}".format(
            trainable_parameters(self.value_net)))
        self.value_net.to(self.device)
        self.target_net.to(self.device)
        self.loss_function = loss_function_factory(self.loss_function)
        self.optimizer = optimizer_factory(self.value_net.parameters(),
                                           **self.optimizer_kwargs)
        self.update_frequency = update_frequency
        self.steps = 0
Example #11
    def __init__(self,
                 env,
                 batch_size=64,
                 update_frequency=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 vf_coef=0.5,
                 learning_rate=0.01,
                 optimizer_type="ADAM",
                 eps_clip=0.2,
                 k_epochs=5,
                 use_gae=True,
                 gae_lambda=0.95,
                 policy_net_fn=None,
                 value_net_fn=None,
                 policy_net_kwargs=None,
                 value_net_kwargs=None,
                 device="cuda:best",
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 **kwargs):  # TODO: sort arguments

        # For all parameters, define self.param = param
        _, _, _, values = inspect.getargvalues(inspect.currentframe())
        values.pop("self")
        for arg, val in values.items():
            setattr(self, arg, val)
        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        # bonus
        self.use_bonus = use_bonus
        if self.use_bonus:
            self.env = UncertaintyEstimatorWrapper(
                self.env, **uncertainty_estimator_kwargs)

        # algorithm parameters

        # options
        # TODO: add reward normalization option
        #       add observation normalization option
        #       add orthogonal weight initialization option
        #       add value function clip option
        #       add ... ?
        self.normalize_advantages = True  # TODO: turn into argument

        self.use_gae = use_gae
        self.gae_lambda = gae_lambda

        # function approximators
        self.policy_net_kwargs = policy_net_kwargs or {}
        self.value_net_kwargs = value_net_kwargs or {}

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        #
        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.value_net_fn = value_net_fn or default_value_net_fn

        self.device = choose_device(device)

        self.optimizer_kwargs = {
            "optimizer_type": optimizer_type,
            "lr": learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        self.cat_policy = None  # categorical policy function

        # initialize
        self.reset()
Example #12
    def __init__(self,
                 env,
                 n_episodes=4000,
                 batch_size=8,
                 horizon=256,
                 gamma=0.99,
                 entr_coef=0.01,
                 vf_coef=0.5,
                 learning_rate=0.01,
                 optimizer_type='ADAM',
                 k_epochs=5,
                 use_gae=True,
                 gae_lambda=0.95,
                 policy_net_fn=None,
                 value_net_fn=None,
                 policy_net_kwargs=None,
                 value_net_kwargs=None,
                 device="cuda:best",
                 use_bonus=False,
                 uncertainty_estimator_kwargs=None,
                 **kwargs):
        self.use_bonus = use_bonus
        if self.use_bonus:
            env = UncertaintyEstimatorWrapper(env,
                                              **uncertainty_estimator_kwargs)
        IncrementalAgent.__init__(self, env, **kwargs)

        self.n_episodes = n_episodes
        self.batch_size = batch_size
        self.horizon = horizon
        self.gamma = gamma
        self.entr_coef = entr_coef
        self.vf_coef = vf_coef
        self.learning_rate = learning_rate
        self.k_epochs = k_epochs
        self.use_gae = use_gae
        self.gae_lambda = gae_lambda
        self.damping = 0  # TODO: turn into argument
        self.max_kl = 0.1  # TODO: turn into argument
        self.use_entropy = False  # TODO: test, and eventually turn into argument
        self.normalize_advantage = True  # TODO: turn into argument
        self.normalize_reward = False  # TODO: turn into argument

        self.policy_net_kwargs = policy_net_kwargs or {}
        self.value_net_kwargs = value_net_kwargs or {}

        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        #
        self.policy_net_fn = policy_net_fn or default_policy_net_fn
        self.value_net_fn = value_net_fn or default_value_net_fn

        self.device = choose_device(device)

        self.optimizer_kwargs = {
            'optimizer_type': optimizer_type,
            'lr': learning_rate
        }

        # check environment
        assert isinstance(self.env.observation_space, spaces.Box)
        assert isinstance(self.env.action_space, spaces.Discrete)

        # TODO: check
        self.cat_policy = None  # categorical policy function
        self.policy_optimizer = None

        self.value_net = None
        self.value_optimizer = None

        self.cat_policy_old = None

        self.value_loss_fn = None

        self.memory = None

        self.episode = 0

        self._rewards = None
        self._cumul_rewards = None

        # initialize
        self.reset()