Example #1
    def _env_fn():
        nonlocal frame_stack

        env = gym.make(env_id)

        if episode_time_limit is not None:
            env._max_episode_steps = episode_time_limit

        if seed is not None:
            env.seed(seed)

        env = MonitorEpisodeWrapper(env)

        if reward_scale is not None:
            env = ScaleRewardsWrapper(env, reward_scale)

        if env_type == "atari":
            env = AtariPreprocessing(env)
            frame_stack = 4

        if frame_stack is not None:
            env = FrameStack(env, frame_stack)

        if isinstance(env.action_space, gym.spaces.Box):
            env = ClipActionsWrapper(env)

        return env
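A factory closure like _env_fn is normally handed to a vectorized-environment constructor rather than called directly. A minimal usage sketch, assuming plain Gym's gym.vector.AsyncVectorEnv (the surrounding project may use its own vectorizer):

import gym

# Hypothetical illustration: build 8 independent copies from the factory closure.
# AsyncVectorEnv calls each function once to construct its own environment.
vec_env = gym.vector.AsyncVectorEnv([_env_fn for _ in range(8)])
obs = vec_env.reset()  # batched observations with a leading dimension of 8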
Example #2
def make_env(env_id,  # environment id
             noop_max=30,  # maximum number of no-op steps executed at reset
             frame_skip=4,  # number of frames to skip per action
             screen_size=84,  # resized frame size
             terminal_on_life_loss=True,  # end the episode when a life is lost
             grayscale_obs=True,  # return grayscale frames if True, otherwise RGB
             grayscale_newaxis=False,  # add a channel axis to grayscale frames (2-D -> 3-D)
             scale_obs=True,  # scale observations to [0, 1]
             num_stack=4,  # number of frames to stack
             lz4_compress=False,  # compress stacked frames with lz4
             obs_LazyFramesToNumpy=True,  # convert the LazyFrames output to a numpy array
             ):

    assert gym.envs.registry.spec(env_id).entry_point == 'gym.envs.atari:AtariEnv', "env is not Atari"

    env = gym.make(env_id)
    env = atari_preprocessing.AtariPreprocessing(env=env,
                                                 noop_max=noop_max,
                                                 frame_skip=frame_skip,
                                                 screen_size=screen_size,
                                                 terminal_on_life_loss=terminal_on_life_loss,
                                                 grayscale_obs=grayscale_obs,
                                                 grayscale_newaxis=grayscale_newaxis,
                                                 scale_obs=scale_obs)
    env = FrameStack(env, num_stack=num_stack, lz4_compress=lz4_compress)
    if obs_LazyFramesToNumpy:
        env = ObsLazyFramesToNumpy(env)
    return env
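A usage sketch for make_env above, assuming an Atari id such as PongNoFrameskip-v4 and taking the project-specific ObsLazyFramesToNumpy wrapper as given:

# Hypothetical call: with the defaults above (grayscale 84x84 frames, scale_obs=True,
# a 4-frame stack) the returned observation is a float32 array of shape (4, 84, 84).
env = make_env("PongNoFrameskip-v4")
obs = env.reset()
print(obs.shape, obs.dtype)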
Example #3
 def __init__(self,
              env,
              stacked_frames=STACKED_FRAMES,
              screen_size=DEFAULT_SCREEN_SIZE,
              charge_policy=True,
              train_decoder_first=False):
     """
     Generates
     :param env:
     :param results_path:
     """
     if isinstance(env, str):
         env = gym.make(env)
     self.env_name = env.spec.id
     if len(env.observation_space.shape) == 3:
         env = AtariPreprocessing(env,
                                  noop_max=10,
                                  frame_skip=1,
                                  scale_obs=True,
                                  screen_size=screen_size,
                                  terminal_on_life_loss=True)
     self.env = FrameStack(env, num_stack=stacked_frames)
     #self.env = FrameStack(env, num_stack=4)
     self.policy_path = os.path.join(SPECIFIC_NETWORKS_PATH, self.env_name +
                                     '.pth') if charge_policy else None
     self.train_encoder_first = train_decoder_first
     self.policy = Network(input_size=self.env.observation_space.shape,
                           num_actions=env.action_space.n,
                           load_from_path=self.policy_path,
                           prepare_conv=train_decoder_first)
Example #4
def test_env(env, model, device, deterministic=True):
    env = gym_super_mario_bros.make('SuperMarioBros-v0')
    env = JoypadSpace(env, SIMPLE_MOVEMENT)
    env = RewardScalar(env)
    env = WarpFrame(env)
    env = FrameStack(env, 4)
    env = StochasticFrameSkip(env, 4, 0.5)
    env = ScaledFloatFrame(env)
    # env=gym.wrappers.Monitor(env, 'recording/PPORB5/{}'.format(str(num)), video_callable=lambda episode_id: True, force=True)
    state = env.reset()
    done = False
    total_reward = 0
    distance = []
    print("yes")
    for i in range(2000):
        state = torch.FloatTensor(state).to(device)
        state = state.float()
        state = state.permute(3, 0, 1, 2)
        dist, _ = model(state)
        policy = dist
        policy = Categorical(F.softmax(policy, dim=-1).data.cpu())
        actionLog = policy.sample()
        action = actionLog.numpy()
        next_state, reward, done, info = env.step(action[0])
        distance.append(info['x_pos'])
        state = next_state
        total_reward += reward
        env.render()
        if done:
            break

    print(total_reward)
    print(max(distance))
Example #5
    def __init__(self, env_name):
        env = gym.make(env_name)

        self.env = GrayScaleObservation(env)
        self.env = NormalizeObservation(self.env)
        self.env = FrameStack(self.env, 4)

        gym.Wrapper.__init__(self, self.env)
Example #6
 def make_atari_env(game: str, seed: Optional[int] = None, framestack: int = 4) -> Env:
     env = gym.make(f'{game}NoFrameskip-v4')
     if seed is not None:
         env.seed(seed)
     env = RemoveALEInfo(env)
     env = AtariPreprocessing(env, frame_skip=4, screen_size=96, terminal_on_life_loss=False, grayscale_obs=True)
     env = FrameStack(env, num_stack=framestack)
     return env
Example #7
    def __init__(self, episodes):
        self.current_episode = 0
        self.episodes = episodes

        self.episode_score = []
        self.episode_qs = []
        self.episode_distance = []
        self.episode_loss = []

        self.fig, self.ax = plt.subplots(2, 2)
        self.fig.canvas.draw()
        plt.show(block=False)

        self.env = gym_super_mario_bros.make('SuperMarioBros-v0')
        # Apply Observation Wrappers
        self.env = GrayScaleObservation(self.env)
        self.env = ResizeObservation(self.env, 84)
        # Apply Control Wrappers
        self.env = JoypadSpace(self.env, SIMPLE_MOVEMENT)
        self.env = NoopResetEnv(self.env)
        # Apply Frame Wrappers
        self.env = SkipFrame(self.env, 4)
        self.env = FrameStack(self.env, 4)

        self.agent = DQNAgent(stateShape=(84, 84, 4),
                              actionSpace=self.env.action_space, numPicks=32, memorySize=100000)
Example #8
    def __init__(self, episodes, checkpoint, current_episode, epsilon):
        self.current_episode = current_episode
        self.episodes = episodes

        self.episode_score = []
        self.episode_qs = []
        self.episode_distance = []
        self.episode_loss = []
        self.episode_policies = []

        self.fig, self.ax = plt.subplots(1, 2, figsize=(12, 4))
        self.fig.canvas.draw()

        self.env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
        # Apply Observation Wrappers
        self.env = GrayScaleObservation(self.env)
        self.env = ResizeObservation(self.env, 84)
        # Apply Control Wrappers
        self.env = JoypadSpace(self.env, SIMPLE_MOVEMENT)
        self.env = NoopResetEnv(self.env)
        # Apply Frame Wrappers
        self.env = SkipFrame(self.env, 4)
        self.env = FrameStack(self.env, 4)

        self.agent = DQNAgent(stateShape=(4, 84, 84),
                              actionSpace=self.env.action_space,
                              numPicks=32,
                              memorySize=20000,
                              numRewards=4,
                              epsilon=epsilon,
                              checkpoint=checkpoint)
Example #9
def create_env(name, train=True):
    env = gym.make(name)
    # env = EpisodicLifeEnv(env)
    if train:
        env = NoopResetEnv(env, 50)
    env = ScaledFloatFrame(env)
    env = FrameStack(env, 4)

    return env
Example #10
 def _thunk():
     env = gym_super_mario_bros.make('SuperMarioBros-v0')
     env = JoypadSpace(env, SIMPLE_MOVEMENT)
     env = RewardScalar(env)
     env = WarpFrame(env)
     env = StochasticFrameSkip(env, 4, 0.5)
     env = FrameStack(env, 4)
     env = ScaledFloatFrame(env)
     return env
Example #11
def test_frame_stack(env_id, num_stack, lz4_compress):
    env = gym.make(env_id)
    shape = env.observation_space.shape
    env = FrameStack(env, num_stack, lz4_compress)
    assert env.observation_space.shape == (num_stack, ) + shape

    obs = env.reset()
    obs = np.asarray(obs)
    assert obs.shape == (num_stack, ) + shape
    for i in range(1, num_stack):
        assert np.allclose(obs[i - 1], obs[i])

    obs, _, _, _ = env.step(env.action_space.sample())
    obs = np.asarray(obs)
    assert obs.shape == (num_stack, ) + shape
    for i in range(1, num_stack - 1):
        assert np.allclose(obs[i - 1], obs[i])
    assert not np.allclose(obs[-1], obs[-2])
Example #12
def test_frame_stack(env_id, num_stack, lz4_compress):
    env = gym.make(env_id)
    shape = env.observation_space.shape
    env = FrameStack(env, num_stack, lz4_compress)
    assert env.observation_space.shape == (num_stack, ) + shape
    assert env.observation_space.dtype == env.env.observation_space.dtype

    dup = gym.make(env_id)

    obs = env.reset(seed=0)
    dup_obs = dup.reset(seed=0)
    assert np.allclose(obs[-1], dup_obs)

    for _ in range(num_stack**2):
        action = env.action_space.sample()
        dup_obs, _, _, _ = dup.step(action)
        obs, _, _, _ = env.step(action)
        assert np.allclose(obs[-1], dup_obs)

    assert len(obs) == num_stack
Example #13
File: wrappers.py  Project: hw26/MadMario
def wrapper(env):
    # skip to every 4th frame. Remove redundant info. to speed up training
    env = SkipEnv(env, skip=4)
    # rgb to gray. Reduce input dimension thus model size
    env = GrayScaleObservation(env, keep_dim=False)
    # resize to 84 x 84. Reduce input dimension thus model size
    env = ResizeObservation(env, shape=84)
    # make obs a stack of previous 3 frames. Need consecutive frames
    # to differentiate landing vs. taking off
    env = FrameStack(env, num_stack=4)
    return env
Example #14
File: train.py  Project: nik-sm/dqn
    def __init__(self,
                 game: str,
                 replay_buffer_capacity: int,
                 replay_start_size: int,
                 batch_size: int,
                 discount_factor: float,
                 lr: float,
                 device: str = 'cuda:0',
                 env_seed: int = 0,
                 frame_buffer_size: int = 4,
                 print_self=True):

        self.device = device
        self.discount_factor = discount_factor
        self.game = game
        self.batch_size = batch_size

        self.replay_buf = ReplayBuffer(capacity=replay_buffer_capacity)

        self.env = FrameStack(
            AtariPreprocessing(
                gym.make(self.game),
                # noop_max=0,
                # terminal_on_life_loss=True,
                scale_obs=False),
            num_stack=frame_buffer_size)
        self.env.seed(env_seed)
        self.reset()

        self.n_action = self.env.action_space.n
        self.policy_net = DQN(self.n_action).to(self.device)
        self.target_net = DQN(self.n_action).to(self.device).eval()
        self.optimizer = RMSprop(
            self.policy_net.parameters(),
            alpha=0.95,
            # momentum=0.95,
            eps=0.01)

        if print_self:
            print(self)
        self._fill_replay_buf(replay_start_size)
Example #15
def make_env(state, stacks, size, game="SuperMarioKart-Snes", record=False):
    env = retro.make(game=game, use_restricted_actions=retro.Actions.ALL, state=state)
    if record:
        env = KartRecorder(env=env, size=size)
    env = KartMultiDiscretizer(env)
    env = KartObservation(env, size=size)
    env = FrameStack(env, num_stack=stacks)
    # Careful, this has to be done after the stack
    env = KartSkipper(env, skip=5)
    # Has to be done after skipper
    env = KartReward(env)
    return env
Example #16
 def get_env(*args, **kwargs):
     # Base environment: a bsuite env converted to gym, or a plain gym env.
     base_env = (gym_wrapper.GymFromDMEnv(
         bsuite.load_and_record_to_csv(
             bsuite_id=bsuite_id,
             results_dir=results_dir,
             overwrite=True,
         )) if not gym_id else gym.make(gym_id))
     # Stack 4 frames, then flatten the LazyFrames into a 1-D numpy array.
     return GymEnvWrapper(
         TransformObservation(
             env=FrameStack(env=base_env, num_stack=4),
             f=lambda lazy_frames: np.reshape(np.stack(lazy_frames._frames), -1)))
Example #17
    def __init__(self, env: Union[Env, Wrapper, str],
                 noop_max: int = 30,
                 frameskip: int = 4,
                 stacked_frames: int = 4,
                 screen_size: int = 84,
                 grayscale_obs: bool = True,
                 scale_obs: bool = False,
                 clip_reward: bool = True,
                 min_reward_clip: float = -1.0,
                 max_reward_clip: float = 1.0,
                 terminal_on_life_loss: bool = True,
                 max_steps_per_episode: Optional[int] = None,
                 compress_frames=True):

        if isinstance(env, str):
            if "NoFrameskip" in env:
                env = gym.make(env)
            else:
                env = gym.make(env, frameskip=1)
                env.spec.id = f"{env.spec.id}(NoFrameskip)"

        # Main Wrapping (OpenAI)
        # > No Ops
        # > Explicit frameskipping
        # > Resized
        # > Terminal on ale lives -= 1 or == 0
        # > Grayscaled
        # > Scaled
        env = AtariPreprocessing(env, noop_max, frameskip, screen_size, terminal_on_life_loss, grayscale_obs, scale_obs)

        # Time (Step) Limit Wrapper
        if max_steps_per_episode is not None:
            env = TimeLimit(env, max_episode_steps=max_steps_per_episode)

        # Fire to reset
        if 'FIRE' in env.unwrapped.get_action_meanings():
            from yaaf.environments.wrappers import FireToResetWrapper
            env = FireToResetWrapper(env)

        # Clipped reward -1, 0, +1
        if clip_reward:
            from yaaf.environments.wrappers import ClippedRewardWrapper
            env = ClippedRewardWrapper(env, min_reward_clip, max_reward_clip)

        # Stacked frames for history length
        if stacked_frames is not None:
            env = FrameStack(env, stacked_frames)

        super(DeepMindAtari2600Wrapper, self).__init__(env)

        self._compress_frames = compress_frames
Example #18
    def __init__(self, env_name, max_steps=10000, **kwargs):
        self.env_fn = lambda: FrameStack(AtariPreprocessing(
            env=gym.make(str(env_name) + "NoFrameskip-v4"),
            terminal_on_life_loss=True,
            frame_skip=4),
                                         num_stack=4)

        env = self.env_fn()
        # Call the parent constructor, so we can access self.env later
        super(Atari_continuous, self).__init__(env)

        self.env_name = env_name
        self.max_steps = max_steps
        # Counter of steps per episode
        self.current_step = 0
Example #19
def test_actor(i, eps, replay_memory, parameter_server):
    env = gym.make("BreakoutNoFrameskip-v4")
    env = AtariPreprocessing(env)
    env = FrameStack(env, num_stack=4)

    assert env.action_space.n == 4

    actor = Actor(env, env.action_space.n, eps, replay_memory,
                  parameter_server)

    t = 0
    score_sum = 0.0
    while True:
        t += 1
        score_sum += actor.run_episode()
        if t % 100 == 0:
            print("Actor", i, "got", score_sum / 100)
Example #20
def create_env(env_name="SuperMarioBros-1-1-v0"):
    env = gym_super_mario_bros.make(env_name)

    # Restricts action space to only "right" and "jump + right"
    env = JoypadSpace(env, [["right"], ["right", "A"]])
    # Accumulates rewards every 4th frame
    env = SkipFrame(env, skip=4)
    # Transform RGB image to graycale, [240, 256]
    env = GrayScaleObservation(env)
    # Downsample to new size, [1, 84, 84]
    env = ResizeObservation(env, shape=84)
    # Add extra precision to np.array state
    env = TransformObservation(env, f=lambda x: x / 255.)
    # Squash 4 consecutive frames of the environment into a
    # single observation point to feed to our learning model, [4, 84, 84]
    env = FrameStack(env, num_stack=4)
    return env
Example #21
def launch_env(map_name,
               randomize_maps_on_reset=False,
               domain_rand=False,
               frame_stacking=1):
    environment = DuckietownEnv(domain_rand=domain_rand,
                                max_steps=math.inf,
                                map_name=map_name,
                                randomize_maps_on_reset=randomize_maps_on_reset)

    tmp = environment._get_tile

    if frame_stacking != 1:
        environment = FrameStack(environment, frame_stacking)
        environment._get_tile = tmp
        environment.reset()  # Force reset to get fake frame buffer

    return environment
Example #22
def make_atari_env(env_name):
    env = gym.make(env_name)
    env = FireReset(env)
    env = AtariPreprocessing(env,
                             noop_max=30,
                             frame_skip=2,
                             screen_size=84,
                             terminal_on_life_loss=False,
                             grayscale_obs=True,
                             scale_obs=False)
    env = PyTorchImageWrapper(env)
    env = FrameStack(env, num_stack=4)
    env = TransformObservation(env, f=np.array)
    env = ConcatWrapper(env, axis=0)
    env = TransformObservation(
        env, f=lambda obs: np.asarray(obs, dtype=np.float32) / 255.0)

    return env
Example #23
def getMarioEnv():
    r"""Return an `env`. Each time Mario makes an action, the environment responds with a state.

    Returns:
        env (gym environment): it returns a state of 3D array of size (4, 84, 84) representing a 4 consecutive frames stacked state.
    """
    # Initialize Super Mario environment
    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0")
    # Limit the action-space to
    #   0. walk right
    #   1. jump right
    env = JoypadSpace(env, [["right"], ["right", "A"]])
    env.reset()

    env = SkipFrame(env, skip=4)
    env = GrayScaleObservation(env)
    env = ResizeObservation(env, shape=84)
    env = FrameStack(env, num_stack=4)

    return env
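A short usage sketch for getMarioEnv, following its docstring (gym_super_mario_bros and the custom wrappers above are assumed importable):

import numpy as np

env = getMarioEnv()
state = env.reset()
print(np.asarray(state).shape)  # per the docstring: (4, 84, 84)
action = env.action_space.sample()  # 0: walk right, 1: jump right
next_state, reward, done, info = env.step(action)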
Example #24
def run_agent(layout: str):
    env = PacmanEnv(layout)
    env = SkipFrame(env, skip=4)
    env = GrayScaleObservation(env)
    env = ResizeObservation(env, shape=84)
    env = FrameStack(env, num_stack=4)
    screen = env.reset(mode='rgb_array')
    n_actions = env.action_space.n

    model = load_model(screen.shape, n_actions, 'pacman.pth')

    for i in range(10):

        env.render(mode='human')
        screen = env.reset(mode='rgb_array')

        for _ in count():
            env.render(mode='human')
            action = select_action(screen, 0, model, n_actions)
            screen, reward, done, info = env.step(action)

            if done:
                break
Example #25
    def __init__(self,
                 game,
                 stack=False,
                 sticky_action=False,
                 clip_reward=False,
                 terminal_on_life_loss=False,
                 **kwargs):
        # set action_probability=0.25 if sticky_action=True
        env_id = '{}NoFrameskip-v{}'.format(game, 0 if sticky_action else 4)

        # use official atari wrapper
        env = AtariPreprocessing(gym.make(env_id),
                                 terminal_on_life_loss=terminal_on_life_loss)

        if stack:
            env = FrameStack(env, num_stack=4)

        if clip_reward:
            env = TransformReward(env, lambda r: np.clip(r, -1.0, 1.0))

        self._env = env

        self.observation_space = env.observation_space
        self.action_space = env.action_space
Example #26
def _enjoy(args):
    # Launch the env with our helper function
    env = launch_env()
    print("Initialized environment")

    # Wrappers
    env = ResizeWrapper(env)
    env = GrayscaleWrapper(env)
    env = NormalizeWrapper(env)
    env = FrameStack(env, 4)
    env = DtRewardWrapper(env)
    env = ActionWrapper(env)
    print("Initialized Wrappers")

    state_dim = env.observation_space.shape
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    # Initialize policy
    # policy = TD3(state_dim, action_dim, max_action, net_type="cnn")
    # policy.load(filename=args.policy, directory='reinforcement/pytorch/models/')

    policy = policies[args.policy](state_dim, action_dim, max_action)
    policy.load("reinforcement/pytorch/models/", args.policy)

    obs = env.reset()
    done = False

    while True:
        while not done:
            action = policy.predict(np.array(obs))
            # Perform action
            obs, reward, done, _ = env.step(action)
            env.render()
        done = False
        obs = env.reset()
Example #27
    def __init__(self, episodes, checkpoint, current_episode, epsilon):
        self.current_episode = current_episode
        self.episodes = episodes

        self.episode_score = []
        self.episode_qs = []
        self.episode_distance = []
        self.episode_loss = []

        self.env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
        self.env = JoypadSpace(self.env, SIMPLE_MOVEMENT)

        # Apply Frame Wrappers
        self.env = SkipFrame(self.env, 4)
        self.env = GrayScaleObservation(self.env)
        self.env = ResizeObservation(self.env, 84)
        self.env = FrameStack(self.env, 4)

        self.agent = DQNAgent(stateShape=(4, 84, 84),
                              actionSpace=self.env.action_space,
                              numPicks=32,
                              memorySize=20000,
                              epsilon=epsilon,
                              checkpoint=checkpoint)
Example #28
from gym.wrappers import FrameStack
from ray.tune import registry

from envs.procgen_env_wrapper import ProcgenEnvWrapper

# Register Env in Ray
registry.register_env(
    "stacked_procgen_env",  # This should be different from procgen_env_wrapper
    lambda config: FrameStack(ProcgenEnvWrapper(config), 4),
)
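Once registered, the stacked env can be referenced by name in an RLlib/Tune config. A minimal sketch, assuming a Ray 1.x-style tune.run call; the trainer name and env_config keys are illustrative:

from ray import tune

# Hypothetical training call: "env" points at the name registered above,
# and "env_config" is forwarded to ProcgenEnvWrapper(config).
tune.run(
    "PPO",
    config={
        "env": "stacked_procgen_env",
        "env_config": {"env_name": "coinrun"},
        "num_workers": 1,
    },
)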
Example #29
    def observation(self, observation):
        transforms = T.Compose([
            T.ToPILImage(),
            T.Resize(self.shape),
            T.ToTensor(),
            T.Normalize((0, ), (255, ))
        ])
        observation = transforms(observation).squeeze(0)
        return observation


env = SkipFrame(env, skip=4)
env = GrayScaleObservation(env)
env = ResizeObservation(env, shape=84)
env = FrameStack(env, num_stack=4)


class Mario:
    def __init__(self, state_dim, action_dim, save_dir):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.save_dir = save_dir
        self.use_cuda = torch.cuda.is_available()
        self.net = DDQNet(self.state_dim, self.action_dim).float()
        if self.use_cuda:
            self.net = self.net.to(device="cuda")

        self.exploration_rate = 1
        self.exploration_rate_decay = 0.99999975
        self.exploration_rate_min = 0.1
Example #30
File: train.py  Project: nik-sm/dqn
class Agent:
    def __init__(self,
                 game: str,
                 replay_buffer_capacity: int,
                 replay_start_size: int,
                 batch_size: int,
                 discount_factor: float,
                 lr: float,
                 device: str = 'cuda:0',
                 env_seed: int = 0,
                 frame_buffer_size: int = 4,
                 print_self=True):

        self.device = device
        self.discount_factor = discount_factor
        self.game = game
        self.batch_size = batch_size

        self.replay_buf = ReplayBuffer(capacity=replay_buffer_capacity)

        self.env = FrameStack(
            AtariPreprocessing(
                gym.make(self.game),
                # noop_max=0,
                # terminal_on_life_loss=True,
                scale_obs=False),
            num_stack=frame_buffer_size)
        self.env.seed(env_seed)
        self.reset()

        self.n_action = self.env.action_space.n
        self.policy_net = DQN(self.n_action).to(self.device)
        self.target_net = DQN(self.n_action).to(self.device).eval()
        self.optimizer = RMSprop(
            self.policy_net.parameters(),
            alpha=0.95,
            # momentum=0.95,
            eps=0.01)

        if print_self:
            print(self)
        self._fill_replay_buf(replay_start_size)

    def __repr__(self):
        return '\n'.join([
            'Agent:', f'Game: {self.game}', f'Device: {self.device}',
            f'Policy net: {self.policy_net}', f'Target net: {self.target_net}',
            f'Replay buf: {self.replay_buf}'
        ])

    def _fill_replay_buf(self, replay_start_size):
        for _ in trange(replay_start_size,
                        desc='Fill replay_buf randomly',
                        leave=True):
            self.step(1.0)

    def reset(self):
        """Reset the end, pre-populate self.frame_buf and self.state"""
        self.state = self.env.reset()

    @torch.no_grad()
    def step(self, epsilon, clip_reward=True):
        """
        Choose an action based on current state and epsilon-greedy policy
        """
        # Choose action
        if random.random() <= epsilon:
            q_values = None
            action = self.env.action_space.sample()
        else:
            torch_state = torch.tensor(self.state,
                                       dtype=torch.float32,
                                       device=self.device).unsqueeze(0) / 255.0
            q_values = self.policy_net(torch_state)
            action = int(q_values.argmax(dim=1).item())

        # Apply action
        next_state, reward, done, _ = self.env.step(action)
        if clip_reward:
            reward = max(-1.0, min(reward, 1.0))

        # Store into replay buffer
        self.replay_buf.append(
            (torch.tensor(
                np.array(self.state), dtype=torch.float32, device="cpu") /
             255., action, reward,
             torch.tensor(
                 np.array(next_state), dtype=torch.float32, device="cpu") /
             255., done))

        # Advance to next state
        self.state = next_state
        if done:
            self.reset()

        return reward, q_values, done

    def q_update(self):
        self.optimizer.zero_grad()
        states, actions, rewards, next_states, dones = [
            x.to(self.device) for x in self.replay_buf.sample(self.batch_size)
        ]

        with torch.no_grad():
            y = torch.where(
                dones, rewards, rewards +
                self.discount_factor * self.target_net(next_states).max(1)[0])

        predicted_values = self.policy_net(states).gather(
            1, actions.unsqueeze(-1)).squeeze(-1)
        loss = huber(y, predicted_values, 2.)
        loss.backward()
        self.optimizer.step()
        return (y - predicted_values).abs().mean()