Example #1
class TD3Agent(Agent):
    def __init__(self,
                 action_space,
                 observation_space,
                 gamma=0.99,
                 nb_steps_warmup=2000,
                 sigma=0.3,
                 polyak=0.995,
                 pi_lr=0.001,
                 q_lr=0.001,
                 batch_size=100,
                 action_noise=0.1,
                 target_noise=0.2,
                 noise_clip=0.5,
                 policy_delay=2,
                 memory_size=10000,
                 training=True):
        super().__init__()
        self.gamma = gamma
        self.sigma = sigma
        self.polyak = polyak
        self.pi_lr = pi_lr
        self.q_lr = q_lr
        self.batch_size = batch_size
        self.action_noise = action_noise
        self.target_noise = target_noise
        self.noise_clip = noise_clip
        self.policy_delay = policy_delay

        self.action_space = action_space
        self.nb_actions = action_space.shape[0]
        self.observation_shape = observation_space.shape
        self.nb_steps_warmup = nb_steps_warmup
        self.training = training

        self.memory = Memory(capacity=memory_size,
                             observation_shape=self.observation_shape,
                             action_shape=self.action_space.shape)

        self.actor_model, self.critic_model1, self.critic_model2 = self._build_network(
        )

        self.target_actor_model, self.target_critic_model1, self.target_critic_model2 = self._build_network(
        )

        self.target_actor_model.set_weights(self.actor_model.get_weights())
        self.target_critic_model1.set_weights(self.critic_model1.get_weights())
        self.target_critic_model2.set_weights(self.critic_model2.get_weights())

        self.step_count = 0

    def _build_network(self):
        action_tensor = tf.keras.layers.Input(shape=(self.nb_actions, ),
                                              dtype=tf.float64)
        observation_tensor = tf.keras.layers.Input(
            shape=self.observation_shape, dtype=tf.float64)

        # Build the actor model
        y = tf.keras.layers.Flatten()(observation_tensor)
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(self.nb_actions, activation='tanh')(y)

        actor_model = tf.keras.Model(inputs=observation_tensor, outputs=y)
        actor_model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=self.pi_lr),
            loss='mse')

        # Build critic model 1
        critic_model1 = self._build_critic_network(observation_tensor,
                                                   action_tensor)
        # Build critic model 2
        critic_model2 = self._build_critic_network(observation_tensor,
                                                   action_tensor)

        return actor_model, critic_model1, critic_model2

    def _build_critic_network(self, observation_tensor, action_tensor):
        y = tf.keras.layers.Concatenate()([observation_tensor, action_tensor])
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(1, activation='linear')(y)

        critic_model = tf.keras.Model(
            inputs=[observation_tensor, action_tensor], outputs=y)
        critic_model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=self.q_lr),
            loss='mse')
        return critic_model

    def forward(self, observation):
        self.step_count += 1

        if self.step_count < self.nb_steps_warmup:
            return self.action_space.sample()
        else:
            observation = np.expand_dims(observation, axis=0)
            action = self.actor_model.predict(observation)
            action = action.reshape(self.nb_actions)
            if self.training:
                # Gaussian exploration noise, with the final action clipped
                # to the action-space bounds.
                action = action + np.random.normal(0.0, self.action_noise,
                                                   self.nb_actions)
                action = np.clip(action, self.action_space.low,
                                 self.action_space.high)
            return action

    def backward(self, observation, action, reward, terminal,
                 next_observation):
        self.memory.store_transition(observation, action, reward, terminal,
                                     next_observation)

        if self.step_count < self.nb_steps_warmup:
            return
        else:
            self._update()

    def _update(self):
        observations, actions, rewards, terminals, next_observations = self.memory.sample_batch(
        )

        self._update_critic(observations, actions, rewards, terminals,
                            next_observations)

        # Delayed policy updates: the actor and the target networks are
        # updated less frequently than the critics.
        if self.step_count % self.policy_delay == 0:
            self._update_actor(observations)

            # Update the critic target networks
            new_target_critic_weights_list = polyak_averaging(
                self.critic_model1.get_weights(),
                self.target_critic_model1.get_weights(), self.polyak)
            self.target_critic_model1.set_weights(
                new_target_critic_weights_list)
            new_target_critic_weights_list = polyak_averaging(
                self.critic_model2.get_weights(),
                self.target_critic_model2.get_weights(), self.polyak)
            self.target_critic_model2.set_weights(
                new_target_critic_weights_list)

            # Update the actor target network
            new_target_actor_weights_list = polyak_averaging(
                self.actor_model.get_weights(),
                self.target_actor_model.get_weights(), self.polyak)
            self.target_actor_model.set_weights(new_target_actor_weights_list)

    def _update_critic(self, observations, actions, rewards, terminals,
                       next_observations):
        batch_size = observations.shape[0]

        # Target policy smoothing: add clipped noise to the target policy's
        # action before evaluating the target critics.
        next_actions = self.target_actor_model(next_observations)
        target_noise = tf.clip_by_value(
            tf.random.normal(mean=0.0,
                             stddev=self.target_noise,
                             shape=(batch_size, self.nb_actions),
                             dtype=next_actions.dtype), -self.noise_clip,
            self.noise_clip)
        next_actions = next_actions + target_noise

        # Clipped double-Q learning: use the smaller of the two target
        # critics' estimates to form the TD target.
        q_values_next1 = self.target_critic_model1(
            [next_observations, next_actions])
        q_values_next2 = self.target_critic_model2(
            [next_observations, next_actions])
        q_values_next = tf.minimum(q_values_next1, q_values_next2)

        target_q_values = rewards + self.gamma * tf.cast(
            1.0 - terminals, q_values_next.dtype) * q_values_next

        self.critic_model1.fit([observations, actions],
                               target_q_values,
                               verbose=0)
        self.critic_model2.fit([observations, actions],
                               target_q_values,
                               verbose=0)

    @tf.function
    def _update_actor(self, observations):
        with tf.GradientTape() as tape:
            tape.watch(self.actor_model.trainable_weights)
            # Maximize the online critic's Q1 estimate of the actor's action.
            q_values = self.critic_model1(
                [observations, self.actor_model(observations)])
            loss = -tf.reduce_mean(q_values)

        actor_grads = tape.gradient(loss, self.actor_model.trainable_weights)
        self.actor_model.optimizer.apply_gradients(
            zip(actor_grads, self.actor_model.trainable_weights))
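
Examples #1, #5 and #6 call a module-level polyak_averaging helper that is not shown in these snippets. A minimal sketch that matches how it is called (online weights list, target weights list, polyak coefficient), and mirrors the equivalent method defined inside DDPGAgent in Example #4, could look like this:

def polyak_averaging(weights_list, target_weights_list, polyak):
    # Soft target update: target <- polyak * target + (1 - polyak) * online.
    return [
        polyak * target_weights + (1.0 - polyak) * weights
        for weights, target_weights in zip(weights_list, target_weights_list)
    ]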
Example #2
class DqnCacheEnv(CacheEnv):
    def __init__(self,
                 capacity,
                 request_path,
                 top_k,
                 time_slot_length,
                 gamma=0.99,
                 memory_size=1000,
                 target_model_update=10):
        super().__init__(capacity, request_path, top_k, time_slot_length)
        self.nb_actions = capacity + top_k
        self.observation_shape = (self.nb_actions, )

        # DQN hyperparameters
        self.gamma = gamma
        self.memory_size = memory_size
        self.target_model_update = target_model_update

        # ReplayBuffer
        self.memory = Memory(capacity=memory_size,
                             observation_shape=self.observation_shape,
                             action_shape=(self.nb_actions, ))

        # Build the online and target Q-networks
        self.q_net = self._build_q_net()
        self.target_q_net = self._build_q_net()
        self.target_q_net.set_weights(self.q_net.get_weights())

        self.step_count = 0
        self.last_observation = None
        self.last_action = None
        self.eps = 1
        self.eps_decay = 0.95
        self.min_eps = 0.05

    def _build_q_net(self):
        model = tf.keras.Sequential([
            tf.keras.layers.Flatten(input_shape=self.observation_shape),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(32, activation='relu'),
            tf.keras.layers.Dense(self.nb_actions, activation='linear')
        ])
        model.compile(optimizer='adam',
                      loss=tf.keras.losses.mean_squared_error)
        return model

    def _get_new_cache_content(self, cache_content, top_k_missed_videos,
                               hit_rate):
        candidates = np.concatenate([cache_content, top_k_missed_videos])
        observation = self.loader.get_frequencies(candidates)
        reward = hit_rate

        if self.last_observation is not None and self.last_action is not None:
            self.backward(self.last_observation, self.last_action, reward,
                          False, observation)

        action = self.forward(observation)
        new_cache_content = candidates[action]

        self.last_action = action
        self.last_observation = observation

        return new_cache_content

    def forward(self, observation):
        self.step_count += 1
        action_mask = np.zeros_like(observation, dtype=bool)
        if np.random.random() < max(self.min_eps, self.eps):
            action = np.random.choice(np.arange(self.capacity + self.top_k),
                                      self.capacity,
                                      replace=False)
        else:
            self.eps *= self.eps_decay
            if observation.ndim == 1:
                observation = np.expand_dims(observation, axis=0)
            q_values = self.q_net.predict(observation).squeeze(0)
            action = np.argpartition(q_values, -self.capacity)[-self.capacity:]

        action_mask[action] = True
        return action_mask

    def backward(self, observation, action, reward, terminal,
                 next_observation):
        self.memory.store_transition(observation, action, reward, terminal,
                                     next_observation)

        observations, actions, rewards, terminals, next_observations = self.memory.sample_batch(
        )

        q_values = self.q_net.predict(observations)
        actions = actions.astype(bool)

        target_q_values = self.target_q_net.predict(next_observations)
        new_q_values = rewards + self.gamma * target_q_values * (~terminals)
        q_values[actions] = new_q_values[actions]

        self.q_net.fit(observations, q_values, verbose=0)

        if self.step_count % self.target_model_update == 0:
            self.target_q_net.set_weights(self.q_net.get_weights())
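
Every example here depends on a Memory replay buffer that is not included. The sketch below is only an assumption consistent with the calls made above: the constructor takes capacity, observation_shape and action_shape, store_transition appends one transition, and sample_batch returns five arrays; the batch_size default is invented for illustration.

import numpy as np


class Memory:
    def __init__(self, capacity, observation_shape, action_shape,
                 batch_size=32):
        # Pre-allocated ring buffer for transitions.
        self.capacity = capacity
        self.batch_size = batch_size
        self.observations = np.zeros((capacity,) + observation_shape)
        self.actions = np.zeros((capacity,) + action_shape)
        self.rewards = np.zeros((capacity, 1))
        self.terminals = np.zeros((capacity, 1), dtype=bool)
        self.next_observations = np.zeros((capacity,) + observation_shape)
        self.index = 0
        self.size = 0

    def store_transition(self, observation, action, reward, terminal,
                         next_observation):
        i = self.index
        self.observations[i] = observation
        self.actions[i] = action
        self.rewards[i] = reward
        self.terminals[i] = terminal
        self.next_observations[i] = next_observation
        self.index = (self.index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample_batch(self):
        # Uniform random minibatch (with replacement) from the filled part.
        idx = np.random.randint(0, self.size, size=self.batch_size)
        return (self.observations[idx], self.actions[idx], self.rewards[idx],
                self.terminals[idx], self.next_observations[idx])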
Example #3
class DQNAgent(Agent):
    def __init__(self,
                 action_space,
                 observation_space,
                 gamma=0.99,
                 target_model_update=100,
                 memory_size=10000):
        super().__init__()
        self.gamma = gamma
        self.action_space = action_space
        self.observation_space = observation_space

        self.nb_actions = action_space.n
        self.observation_shape = observation_space.shape
        self.target_model_update = target_model_update

        self.memory = Memory(capacity=memory_size,
                             action_shape=(1, ),
                             observation_shape=self.observation_shape)

        self.model = self._build_network()
        self.target_model = self._build_network()
        self.target_model.set_weights(self.model.get_weights())

        self.update_count = 0

    def _build_network(self):
        model = tf.keras.Sequential([
            tf.keras.layers.Flatten(input_shape=self.observation_shape),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(32, activation='relu'),
            tf.keras.layers.Dense(self.nb_actions, activation='linear')
        ])
        model.compile(optimizer='adam',
                      metrics=['mse'],
                      loss=tf.keras.losses.mean_squared_error)
        return model

    def forward(self, observation):
        observation = np.expand_dims(observation, axis=0)
        q_values = self.model.predict(observation)
        # Greedy action with respect to the predicted Q-values
        return np.argmax(q_values)

    def backward(self, observation, action, reward, terminal,
                 next_observation):
        self.memory.store_transition(observation, action, reward, terminal,
                                     next_observation)

        observations, actions, rewards, terminals, next_observations = self.memory.sample_batch(
        )

        actions = tf.keras.utils.to_categorical(
            actions, num_classes=self.nb_actions).astype(bool)

        q_values = self.model.predict(observations)

        target_q_values = np.max(self.target_model.predict(next_observations),
                                 axis=1,
                                 keepdims=True)
        q_values[
            actions,
            np.newaxis] = rewards + self.gamma * target_q_values * (~terminals)

        self.model.fit(observations, q_values, verbose=0)

        self.update_count += 1
        if self.update_count % self.target_model_update == 0:
            self.target_model.set_weights(self.model.get_weights())
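
The Agent subclasses above only expose forward (pick an action) and backward (store the transition and learn). A hypothetical training loop for the DQNAgent, assuming a Gym-style discrete-control environment and the older env.step API that returns four values, might look like this:

import gym

env = gym.make('CartPole-v1')
agent = DQNAgent(action_space=env.action_space,
                 observation_space=env.observation_space)

for episode in range(200):
    observation = env.reset()
    done = False
    while not done:
        action = agent.forward(observation)
        next_observation, reward, done, _ = env.step(action)
        agent.backward(observation, action, reward, done, next_observation)
        observation = next_observation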
Example #4
class DDPGAgent(Agent):
    def __init__(self,
                 action_space,
                 observation_space,
                 gamma=0.99,
                 nb_steps_warmup=2000,
                 training=True,
                 polyak=0.99,
                 memory_size=10000):
        super().__init__()
        self.gamma = gamma
        self.polyak = polyak

        self.action_space = action_space
        self.nb_actions = action_space.shape[0]
        self.observation_shape = observation_space.shape
        self.nb_steps_warmup = nb_steps_warmup
        self.training = training

        self.memory = Memory(capacity=memory_size,
                             observation_shape=self.observation_shape,
                             action_shape=self.action_space.shape)

        self.actor_model, self.critic_model = self._build_network()
        self.target_actor_model, self.target_critic_model = self._build_network(
        )
        self.target_actor_model.set_weights(self.actor_model.get_weights())
        self.target_critic_model.set_weights(self.critic_model.get_weights())

        self.step_count = 0
        self.random_process = OrnsteinUhlenbeckProcess(size=self.nb_actions,
                                                       theta=0.15,
                                                       mu=0.,
                                                       sigma=0.3)

    def _build_network(self):
        action_tensor = tf.keras.layers.Input(shape=(self.nb_actions, ),
                                              dtype=tf.float64)
        observation_tensor = tf.keras.layers.Input(
            shape=self.observation_shape, dtype=tf.float64)

        # Build the actor model
        y = tf.keras.layers.Flatten()(observation_tensor)
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(self.nb_actions, activation='tanh')(y)

        actor_model = tf.keras.Model(inputs=observation_tensor, outputs=y)
        actor_model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
            loss='mse')

        # Build the critic model
        y = tf.keras.layers.Concatenate()([observation_tensor, action_tensor])
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(32, activation='relu')(y)
        y = tf.keras.layers.Dense(1, activation='linear')(y)

        critic_model = tf.keras.Model(
            inputs=[observation_tensor, action_tensor], outputs=y)
        critic_model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
            loss='mse')

        return actor_model, critic_model

    def forward(self, observation):
        self.step_count += 1

        if self.step_count < self.nb_steps_warmup:
            return self.action_space.sample()
        else:
            observation = np.expand_dims(observation, axis=0)
            action = self.actor_model.predict(observation)
            action = action.reshape(self.nb_actions)
            if self.training:
                action = action + self.random_process.sample()
            return action

    def backward(self, observation, action, reward, terminal,
                 next_observation):
        self.memory.store_transition(observation, action, reward, terminal,
                                     next_observation)

        if self.step_count < self.nb_steps_warmup:
            return
        else:
            self._update()

    def _update(self):
        observations, actions, rewards, terminals, next_observations = self.memory.sample_batch(
        )

        self._update_critic(observations, actions, rewards, terminals,
                            next_observations)
        self._update_actor(observations)

        # Update the critic target network
        new_target_critic_weights_list = self.polyak_averaging(
            self.critic_model.get_weights(),
            self.target_critic_model.get_weights())
        self.target_critic_model.set_weights(new_target_critic_weights_list)

        # Update the actor target network
        new_target_actor_weights_list = self.polyak_averaging(
            self.actor_model.get_weights(),
            self.target_actor_model.get_weights())
        self.target_actor_model.set_weights(new_target_actor_weights_list)

    def polyak_averaging(self, weights_list, target_weights_list):
        new_target_weights_list = []
        for weights, target_weights in zip(weights_list, target_weights_list):
            new_target_weights = self.polyak * target_weights + (
                1 - self.polyak) * weights
            new_target_weights_list.append(new_target_weights)
        return new_target_weights_list

    def _update_critic(self, observations, actions, rewards, terminals,
                       next_observations):
        # Evaluate the target critic at the target actor's action and do not
        # bootstrap through terminal transitions.
        q_values_next = self.target_critic_model(
            [next_observations,
             self.target_actor_model(next_observations)])
        target_q_values = rewards + self.gamma * tf.cast(
            1.0 - terminals, q_values_next.dtype) * q_values_next
        self.critic_model.fit([observations, actions],
                              target_q_values,
                              verbose=0)

    @tf.function
    def _update_actor(self, observations):
        with tf.GradientTape() as tape:
            tape.watch(self.actor_model.trainable_weights)
            # Maximize the online critic's estimate of the actor's action.
            q_values = self.critic_model(
                [observations, self.actor_model(observations)])
            loss = -tf.reduce_mean(q_values)

        actor_grads = tape.gradient(loss, self.actor_model.trainable_weights)
        self.actor_model.optimizer.apply_gradients(
            zip(actor_grads, self.actor_model.trainable_weights))
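
The DDPGAgent relies on an OrnsteinUhlenbeckProcess exploration-noise generator that is not shown. Below is an illustrative sketch matching the constructor arguments used above (size, theta, mu, sigma); the dt parameter is an added assumption:

import numpy as np


class OrnsteinUhlenbeckProcess:
    def __init__(self, size, theta=0.15, mu=0.0, sigma=0.3, dt=1e-2):
        self.size = size
        self.theta = theta
        self.mu = mu
        self.sigma = sigma
        self.dt = dt
        self.x_prev = np.ones(size) * mu

    def sample(self):
        # dx = theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, 1)
        x = (self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt +
             self.sigma * np.sqrt(self.dt) *
             np.random.normal(size=self.size))
        self.x_prev = x
        return x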
Example #5
class SacCacheEnv(CacheEnv):
    def __init__(self, capacity, request_path, top_k, time_slot_length,
                 gamma=0.99,
                 nb_steps_warmup=100,
                 alpha=0.2,
                 polyak=0.995,
                 lr=3e-4,
                 log_std_min=-20,
                 log_std_max=2,
                 memory_size=10000
                 ):
        super().__init__(capacity, request_path, top_k, time_slot_length)
        
        self.nb_actions = capacity + top_k
        self.action_dim = 3
        self.observation_shape = (self.nb_actions, self.action_dim)
        
        self.gamma = gamma
        self.alpha = alpha
        self.polyak = polyak
        self.nb_steps_warmup = nb_steps_warmup
        self.lr = lr
        self.step_count = 0
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.memory_size = memory_size
        self.memory = Memory(capacity=memory_size,
                             observation_shape=self.observation_shape,
                             action_shape=(self.nb_actions,))
        
        self.policy_net = self._build_policy_network()
        
        self.soft_q_net1 = self._build_q_net()
        self.soft_q_net2 = self._build_q_net()
        
        self.value_net = self._build_value_net()
        self.target_value_net = self._build_value_net()
        self.target_value_net.set_weights(self.value_net.get_weights())
        
        self.step_count = 0
        self.last_observation = None
        self.last_action = None
    
    def _get_new_cache_content(self, cache_content, top_k_missed_videos, hit_rate):
        candidates = np.concatenate([cache_content, top_k_missed_videos])
        observation = self.loader.get_frequencies2(candidates, [200, 1000, 2000])
        reward = hit_rate
        
        if self.last_observation is not None and self.last_action is not None:
            self.backward(self.last_observation, self.last_action, reward, False, observation)
        
        action = self.forward(observation)
        new_cache_content = candidates[action]
        
        self.last_action = action
        self.last_observation = observation
        
        return new_cache_content
    
    def forward(self, observation):
        self.step_count += 1
        action_mask = np.zeros((self.observation_shape[0],), dtype=bool)
        
        if observation.ndim == 2:
            observation = np.expand_dims(observation, axis=0)
        
        mean, log_std = self.policy_net.predict(observation)
        mean = mean.squeeze(0)
        log_std = log_std.squeeze(0)
        log_prob = gaussian_likelihood(observation.squeeze(0), mean, log_std)
        
        action = np.random.choice(np.arange(self.nb_actions), size=self.capacity,
                                  p=softmax(log_prob),
                                  replace=False)
        action_mask[action] = True
        return action_mask
    
    def backward(self, observation, action, reward, terminal, next_observation):
        self.memory.store_transition(observation, action, reward, terminal, next_observation)
        
        if self.step_count >= self.nb_steps_warmup:
            self._update()
            
            new_target_weights = polyak_averaging(
                self.value_net.get_weights(),
                self.target_value_net.get_weights(),
                self.polyak
            )
            self.target_value_net.set_weights(new_target_weights)
    
    def _build_policy_network(self):
        layers = tf.keras.layers
        observation_tensor = layers.Input(shape=self.observation_shape)
        y = layers.Flatten()(observation_tensor)
        y = layers.Dense(32, activation='relu')(y)
        y = layers.Dense(32, activation='relu')(y)
        y = layers.Dense(32, activation='relu')(y)
        
        mean = layers.Dense(self.action_dim, activation='tanh')(y)
        log_std = layers.Dense(self.action_dim, activation='tanh')(y)
        
        log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1)
        
        model = tf.keras.models.Model(inputs=observation_tensor, outputs=[mean, log_std])
        model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(self.lr))
        
        return model
    
    def _build_q_net(self):
        layers = tf.keras.layers
        observation_tensor = layers.Input(shape=self.observation_shape)
        action_tensor = layers.Input(shape=(self.action_dim,))
        
        observation_tensor_flattened = layers.Flatten()(observation_tensor)
        y = layers.Concatenate()([observation_tensor_flattened, action_tensor])
        y = layers.Dense(32, activation='relu')(y)
        y = layers.Dense(32, activation='relu')(y)
        y = layers.Dense(32, activation='relu')(y)
        y = layers.Dense(1)(y)
        
        model = tf.keras.models.Model(inputs=[observation_tensor, action_tensor], outputs=y)
        model.compile(loss='mse',
                      optimizer=tf.keras.optimizers.Adam(learning_rate=self.lr))
        
        return model
    
    def _build_value_net(self):
        layers = tf.keras.layers
        
        model = tf.keras.models.Sequential([
            layers.Flatten(input_shape=self.observation_shape),
            layers.Dense(32, activation='relu'),
            layers.Dense(32, activation='relu'),
            layers.Dense(32, activation='relu'),
            layers.Dense(1)
        ])
        model.compile(loss='mse',
                      optimizer=tf.keras.optimizers.Adam(learning_rate=self.lr))
        
        return model
    
    @tf.function
    def _evaluate(self, observations):
        mean, log_std = self.policy_net(observations)
        
        std = tf.math.exp(log_std)
        z = mean + tf.random.normal(tf.shape(mean)) * std
        action = tf.math.tanh(z)
        log_prob = gaussian_likelihood(z, mean, log_std)
        log_prob -= tf.math.reduce_sum(tf.math.log(1 - action ** 2 + 1e-6), axis=1)
        # Keep a trailing singleton axis so log_prob broadcasts against the
        # (batch, 1) soft-Q and value outputs.
        log_prob = tf.expand_dims(log_prob, axis=-1)

        action = tf.cast(action, dtype=tf.float64)

        return action, log_prob
    
    def _update(self):
        observations, actions, rewards, terminals, next_observations = self.memory.sample_batch()
        
        target_q_value = rewards + self.gamma * (
            ~terminals) * self.target_value_net.predict(next_observations)
        
        soft_actions, log_probs = self._evaluate(observations)
        
        soft_q_value1 = self.soft_q_net1.predict([observations, soft_actions])
        soft_q_value2 = self.soft_q_net2.predict([observations, soft_actions])
        
        target_value = tf.minimum(soft_q_value1, soft_q_value2) - self.alpha * log_probs
        
        # Update soft Q network
        batch_size = observations.shape[0]
        actions = actions.astype(bool)
        actions = observations[actions].reshape((batch_size, self.capacity, self.action_dim)).mean(axis=1)
        
        self.soft_q_net1.fit([observations, actions], target_q_value, verbose=0)
        self.soft_q_net2.fit([observations, actions], target_q_value, verbose=0)
        
        # Update value network
        self.value_net.fit(observations, target_value, verbose=0)
        
        # Update policy network
        with tf.GradientTape() as tape:
            tape.watch(self.policy_net.trainable_weights)
            
            soft_actions, log_probs = self._evaluate(observations)
            
            soft_q_value = self.soft_q_net1([observations, soft_actions])
            
            loss = -tf.reduce_mean(soft_q_value - self.alpha * log_probs)
        
        actor_grads = tape.gradient(loss, self.policy_net.trainable_weights)
        self.policy_net.optimizer.apply_gradients(zip(actor_grads, self.policy_net.trainable_weights))
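
The two SAC examples use gaussian_likelihood and softmax helpers that are not shown. The sketches below are assumptions: gaussian_likelihood as the usual diagonal-Gaussian log-density summed over the last axis, and softmax as a numerically stable softmax over a 1-D score vector.

import math

import numpy as np
import tensorflow as tf

EPS = 1e-8


def gaussian_likelihood(x, mu, log_std):
    # Log-density of a diagonal Gaussian N(mu, exp(log_std)^2), summed over
    # the last axis.
    pre_sum = -0.5 * (((x - mu) / (tf.exp(log_std) + EPS))**2 +
                      2.0 * log_std + math.log(2.0 * math.pi))
    return tf.reduce_sum(pre_sum, axis=-1)


def softmax(x):
    # Numerically stable softmax over a 1-D array of scores.
    z = np.exp(x - np.max(x))
    return z / np.sum(z)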
Example #6
class SACAgent(Agent):
    
    def __init__(self,
                 action_space,
                 observation_space,
                 gamma=0.99,
                 nb_steps_warmup=2000,
                 alpha=0.2,
                 polyak=0.995,
                 lr=3e-4,
                 log_std_min=-20,
                 log_std_max=2,
                 memory_size=10000
                 ):
        super().__init__()
        self.action_space = action_space
        self.observation_space = observation_space
        self.gamma = gamma
        self.alpha = alpha
        self.polyak = polyak
        self.nb_steps_warmup = nb_steps_warmup
        self.lr = lr
        self.step_count = 0
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.min_entropy = -self.action_space.shape[0]
        
        self.memory = Memory(capacity=memory_size,
                             observation_shape=observation_space.shape,
                             action_shape=action_space.shape)
        
        self.value_net = self._build_value_network()
        self.target_value_net = self._build_value_network()
        self.target_value_net.set_weights(self.value_net.get_weights())
        
        self.soft_q_net1 = self._build_soft_q_network()
        self.soft_q_net2 = self._build_soft_q_network()
        
        self.policy_net = self._build_policy_network()
    
    def forward(self, observation):
        self.step_count += 1
        
        if self.step_count < self.nb_steps_warmup:
            return self.action_space.sample()
        else:
            if observation.ndim == 1:
                observation = np.expand_dims(observation, axis=0)
            mean, log_std = self.policy_net.predict(observation)
            
            std = tf.math.exp(log_std)
            action = mean + tf.random.normal(tf.shape(mean)) * std
            return action
    
    def backward(self, observation, action, reward, terminal, next_observation):
        self.memory.store_transition(observation, action, reward, terminal, next_observation)
        
        if self.step_count >= self.nb_steps_warmup:
            self._update()
            
            new_target_weights = polyak_averaging(
                self.value_net.get_weights(),
                self.target_value_net.get_weights(),
                self.polyak
            )
            self.target_value_net.set_weights(new_target_weights)
    
    def _update(self):
        observations, actions, rewards, terminals, next_observations = self.memory.sample_batch()
        
        target_q_value = rewards + self.gamma * (
            ~terminals) * self.target_value_net.predict(next_observations)
        
        soft_actions, log_probs = self.evaluate(observations)
        
        soft_q_value1 = self.soft_q_net1.predict([observations, soft_actions])
        soft_q_value2 = self.soft_q_net2.predict([observations, soft_actions])
        
        target_value = tf.minimum(soft_q_value1, soft_q_value2) - self.alpha * log_probs
        
        # Update soft Q network
        self.soft_q_net1.fit([observations, actions], target_q_value, verbose=0)
        self.soft_q_net2.fit([observations, actions], target_q_value, verbose=0)
        
        # Update value network
        self.value_net.fit(observations, target_value, verbose=0)
        
        # Update policy network
        with tf.GradientTape() as tape:
            tape.watch(self.policy_net.trainable_weights)
            
            soft_actions, log_probs = self.evaluate(observations)
            
            soft_q_value = self.soft_q_net1([observations, soft_actions])
            
            loss = -tf.reduce_mean(soft_q_value - self.alpha * log_probs)
        
        actor_grads = tape.gradient(loss, self.policy_net.trainable_weights)
        self.policy_net.optimizer.apply_gradients(zip(actor_grads, self.policy_net.trainable_weights))
    
    @tf.function
    def evaluate(self, observations):
        mean, log_std = self.policy_net(observations)
        
        std = tf.math.exp(log_std)
        z = mean + tf.random.normal(tf.shape(mean)) * std
        action = tf.math.tanh(z)
        log_prob = gaussian_likelihood(z, mean, log_std)
        log_prob -= tf.math.reduce_sum(tf.math.log(1 - action ** 2 + 1e-6), axis=1)
        # Keep a trailing singleton axis so log_prob broadcasts against the
        # (batch, 1) soft-Q and value outputs.
        log_prob = tf.expand_dims(log_prob, axis=-1)

        action = tf.cast(action, dtype=tf.float64)

        return action, log_prob
    
    def _build_value_network(self):
        observation_shape = self.observation_space.shape
        
        layers = tf.keras.layers
        
        model = tf.keras.models.Sequential([
            layers.Flatten(input_shape=observation_shape),
            layers.Dense(32, activation='relu'),
            layers.Dense(32, activation='relu'),
            layers.Dense(32, activation='relu'),
            layers.Dense(1)
        ])
        model.compile(loss='mse',
                      optimizer=tf.keras.optimizers.Adam(learning_rate=self.lr))
        
        return model
    
    def _build_soft_q_network(self):
        observation_shape = self.observation_space.shape
        nb_actions = self.action_space.shape[0]
        
        layers = tf.keras.layers
        observation_tensor = layers.Input(shape=observation_shape)
        action_tensor = layers.Input(shape=(nb_actions,))
        
        observation_tensor_flattened = layers.Flatten()(observation_tensor)
        y = layers.Concatenate()([observation_tensor_flattened, action_tensor])
        y = layers.Dense(32, activation='relu')(y)
        y = layers.Dense(32, activation='relu')(y)
        y = layers.Dense(32, activation='relu')(y)
        y = layers.Dense(1)(y)
        
        model = tf.keras.models.Model(inputs=[observation_tensor, action_tensor], outputs=y)
        model.compile(loss='mse',
                      optimizer=tf.keras.optimizers.Adam(learning_rate=self.lr))
        
        return model
    
    def _build_policy_network(self):
        observation_shape = self.observation_space.shape
        nb_actions = self.action_space.shape[0]
        
        layers = tf.keras.layers
        observation_tensor = layers.Input(shape=observation_shape)
        y = layers.Flatten()(observation_tensor)
        y = layers.Dense(32, activation='relu')(y)
        y = layers.Dense(32, activation='relu')(y)
        y = layers.Dense(32, activation='relu')(y)
        
        mean = layers.Dense(nb_actions, activation='tanh')(y)
        log_std = layers.Dense(nb_actions, activation='tanh')(y)
        
        log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1)
        
        model = tf.keras.models.Model(inputs=observation_tensor, outputs=[mean, log_std])
        model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(self.lr))
        
        return model
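
A hypothetical training loop for the SACAgent, assuming a Gym-style continuous-control environment (e.g. Pendulum-v1) and the older four-value env.step API; the reshape and clip lines are only there to coerce the agent's output into the environment's action format:

import gym
import numpy as np

env = gym.make('Pendulum-v1')
agent = SACAgent(action_space=env.action_space,
                 observation_space=env.observation_space)

for episode in range(100):
    observation = env.reset()
    done = False
    while not done:
        action = np.asarray(agent.forward(observation)).reshape(
            env.action_space.shape)
        action = np.clip(action, env.action_space.low, env.action_space.high)
        next_observation, reward, done, _ = env.step(action)
        agent.backward(observation, action, reward, done, next_observation)
        observation = next_observation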