예제 #1
0
    def __init__(self,
                 action_size: int,
                 state_dim: int,
                 action_dim: int,
                 gamma: float,
                 sess: tf.Session,
                 optimizer: "tf.train.Optimizer | None" = None,
                 max_tf_checkpoints_to_keep: int = 3,
                 experience_size: int = 1000,
                 per: bool = False,
                 batch_size: int = 64,
                 start_steps: int = 2000):
        """Build the actor/critic pair, a checkpoint saver and replay memory.

        Args:
            action_size: Forwarded to ``ActorNetwork`` and ``CriticNetwork``.
            state_dim: Forwarded to both networks.
            action_dim: Forwarded to both networks.
            gamma: Discount factor; stored and forwarded to ``CriticNetwork``.
            sess: TensorFlow session shared by both networks.
            optimizer: Optimizer shared by the actor and the critic. When
                ``None`` (the default) a new
                ``tf.train.AdamOptimizer(learning_rate=0.001)`` is created for
                this agent.
            max_tf_checkpoints_to_keep: ``max_to_keep`` for ``tf.train.Saver``.
            experience_size: Capacity handed to the replay memory.
            per: If True use ``PER`` replay memory, otherwise ``ReplayBuffer``.
            batch_size: Stored as ``self.batch_size``.
            start_steps: Stored as ``self.start_steps``.
        """
        if optimizer is None:
            # Created per call on purpose: a default expression in the
            # signature is evaluated once at definition time, so every agent
            # relying on the default would share a single optimizer instance.
            optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

        self.optimizer = optimizer
        self.sess = sess
        self.gamma = gamma
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.action_size = action_size
        self.per = per

        # The actor and the critic share the same session and optimizer.
        self.actor = ActorNetwork(action_size=action_size,
                                  state_dim=state_dim,
                                  action_dim=action_dim,
                                  sess=sess,
                                  optimizer=optimizer)

        self.critic = CriticNetwork(action_size=action_size,
                                    state_dim=state_dim,
                                    action_dim=action_dim,
                                    sess=sess,
                                    optimizer=optimizer,
                                    gamma=gamma)

        self.eval_mode = False
        self.t = 0
        self.start_steps = start_steps
        self.training_steps = 0
        self.epsilon = 1
        self.batch_size = batch_size

        # Created after the networks: without a var_list, tf.train.Saver only
        # covers variables that already exist when it is constructed
        # (presumably the networks create theirs in their constructors).
        self._saver = tf.train.Saver(max_to_keep=max_tf_checkpoints_to_keep)

        # PER (presumably prioritized experience replay) or a plain buffer.
        if self.per:
            self._replay = PER(experience_size)
        else:
            self._replay = ReplayBuffer(experience_size)

        # Previous observation/action, filled in while interacting.
        self._last_state = None
        self._last_items = None
        self._last_action = None

        # Training diagnostics.
        self.td_losses = []
        self.qvalues = []
예제 #2
0
    def __init__(self,
                 sess,
                 state_dim,
                 action_dim,
                 epsilon=0.4,
                 action_size=4,
                 logdir='./logs/',
                 replay_size=1000,
                 batch_size=64):
        """Set up online/target Q networks, training ops, logging and replay.

        Args:
            sess: TensorFlow session; its graph is written to TensorBoard.
            state_dim: State shape handed to ``_create_network``.
            action_dim: Action shape handed to ``_create_network``.
            epsilon: Exploration probability stored on ``self.epsilon``.
            action_size: Stored as ``self._action_size``.
            logdir: Output directory of the ``tf.summary.FileWriter``.
            replay_size: Capacity of the ``ReplayBuffer``.
            batch_size: Transitions sampled per training step.
        """
        # Shapes and handles needed before any graph is built.
        self._state_dim, self._action_dim = state_dim, action_dim
        self._action_size = action_size
        self._logdir, self._sess = logdir, sess

        # Fixed hyper-parameters and optimizer.
        self.epsilon = epsilon
        self.gamma, self.lr = 0.9, 1e-4
        self.optimizer = tf.train.AdadeltaOptimizer(self.lr)

        # Online network: Q-values of concatenated (state, action) inputs.
        online_parts = self._create_network('agent')
        self.state, self.action, self.agent, self.weights = online_parts
        online_input = tf.concat([self.state, self.action], axis=-1)
        self.qvalues = self.agent(online_input)

        # Target network, wired up the same way.
        target_parts = self._create_network('target')
        (self.target_state, self.target_action,
         self.target, self.target_weights) = target_parts
        target_input = tf.concat([self.target_state, self.target_action],
                                 axis=-1)
        self.target_qvalues = self.target(target_input)

        # Optimisation step and online->target weight copy.
        self.train_op, self.td_loss = self._create_train_op()
        self.target_update_op = self._create_target_update_op()

        # TensorBoard: merge all summaries defined so far and log the graph.
        self.merged = tf.summary.merge_all()
        self.train_writer = tf.summary.FileWriter(self._logdir, self._sess.graph)
        self.summary = None

        # Experience replay.
        self._replay = ReplayBuffer(replay_size)
        self.batch_size = batch_size
        self.td_losses = []

        # Interaction state and counters.
        self._last_state = self._last_items = self._last_action = None
        self.eval_mode = False
        self.training_steps = 0
예제 #3
0
파일: agent.py 프로젝트: shintay/rl
    def __init__(self, task):
        """Create DDPG networks, exploration noise and replay memory for *task*.

        ``task`` must expose ``state_size``, ``action_size``, ``action_low``
        and ``action_high``.
        """
        self.task = task
        self.state_size, self.action_size = task.state_size, task.action_size
        self.action_low, self.action_high = task.action_low, task.action_high

        # Actor (policy) models: online copy and target copy.
        actor_args = (self.state_size, self.action_size,
                      self.action_low, self.action_high)
        self.actor_local = Actor(*actor_args)
        self.actor_target = Actor(*actor_args)

        # Critic (value) models: online copy and target copy.
        self.critic_local = Critic(self.state_size, self.action_size)
        self.critic_target = Critic(self.state_size, self.action_size)

        # Start each target model from its local model's weights.
        for source, dest in ((self.critic_local, self.critic_target),
                             (self.actor_local, self.actor_target)):
            dest.model.set_weights(source.model.get_weights())

        # Exploration noise (OUNoise) parameters.
        # Previously tried: theta 0.14 / 0.1, sigma 0.001 / 0.2.
        self.exploration_mu = 0.0
        self.exploration_theta = 0.125
        self.exploration_sigma = 0.0009
        self.noise = OUNoise(self.action_size, self.exploration_mu,
                             self.exploration_theta, self.exploration_sigma)

        # Replay memory.
        self.buffer_size, self.batch_size = 100000, 64
        self.memory = ReplayBuffer(self.buffer_size, self.batch_size)

        # Algorithm parameters.
        self.gamma = 0.998  # discount factor (also tried 0.99, 0.9)
        self.tau = 0.099  # target soft-update rate (also tried 0.001, 0.01, 0.1, 0.05)

        # Score tracking.
        self.best_score, self.score = -np.inf, 0
예제 #4
0
    def __init__(
        self,
        state_size,
        network_id: str,
        buffer_size: int = int(20000),
        batch_size: int = 64,
        gamma: float = 0.99,
        lr: float = 1e-4,
        train_freq: int = 4,
        target_update_freq: int = 1000,
        min_epsilon: float = 0.05,
        epsilon_decay: float = 0.0005,
        scheduler_id: str = 'linear',
        **kwargs,
    ):
        """Initialise replay memory, epsilon schedule and the state autoencoder.

        ``network_id`` and ``target_update_freq`` are accepted but not used
        here; any extra keyword arguments go to the base class.
        """
        super().__init__(**kwargs)
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda:0" if use_cuda else "cpu")

        # Training hyper-parameters.
        self.batch_size, self.gamma, self.lr = batch_size, gamma, lr
        self.train_freq = train_freq

        # Positional args (1, epsilon_decay, min_epsilon): presumably start
        # value, decay and floor -- confirm against scheduler_hub.
        make_scheduler = scheduler_hub[scheduler_id]
        self.epsilon_scheduler = make_scheduler(1, epsilon_decay, min_epsilon)

        self.t_step = 0
        self.memory = ReplayBuffer(buffer_size, batch_size, self.device)
        self.episode_logs = {}

        # State autoencoder; encoder and decoder share one RMSprop optimizer.
        self.state_encoder = StateEncoder(state_size).to(self.device)
        self.state_decoder = StateDecoder(state_size).to(self.device)
        ae_params = [*self.state_encoder.parameters(),
                     *self.state_decoder.parameters()]
        self.opt_ae = optim.RMSprop(ae_params, lr=self.lr)
예제 #5
0
class DDPG:
    """Reinforcement Learning agent that learns using DDPG."""

    def __init__(self, task):
        """Create actor/critic networks, exploration noise and replay memory.

        ``task`` must provide ``state_size``, ``action_size``, ``action_low``,
        ``action_high`` and a ``reset()`` method.
        """
        self.task = task
        self.state_size = task.state_size
        self.action_size = task.action_size
        self.action_low = task.action_low
        self.action_high = task.action_high

        # Actor (Policy) Model
        self.actor_local = Actor(self.state_size, self.action_size, self.action_low, self.action_high)
        self.actor_target = Actor(self.state_size, self.action_size, self.action_low, self.action_high)

        # Critic (Value) Model
        self.critic_local = Critic(self.state_size, self.action_size)
        self.critic_target = Critic(self.state_size, self.action_size)

        # Initialize target model parameters with local model parameters
        self.critic_target.model.set_weights(self.critic_local.model.get_weights())
        self.actor_target.model.set_weights(self.actor_local.model.get_weights())

        # Noise process
        self.exploration_mu = 0
        self.exploration_theta = 0.15
        self.exploration_sigma = 0.2
        self.noise = OUNoise(self.action_size, self.exploration_mu, self.exploration_theta, self.exploration_sigma)

        # Replay memory
        self.buffer_size = 100000
        self.batch_size = 64
        self.memory = ReplayBuffer(self.buffer_size, self.batch_size)

        # Algorithm parameters
        self.gamma = 0.9  # discount factor
        self.tau = 0.001  # for soft update of target parameters

        # Episode statistics used to compute self.score in learn().
        self.total_reward = 0
        self.count = 0
        self.best_score = -np.inf
        self.score = 0

    def reset_episode(self):
        """Reset episode statistics and noise; return the task's initial state."""
        self.total_reward = 0
        self.count = 0
        self.score = 0
        self.noise.reset()
        state = self.task.reset()
        self.last_state = state
        return state

    def step(self, action, reward, next_state, done):
        """Store the transition and learn once memory exceeds one batch."""
        self.total_reward += reward
        self.count += 1
        # Save experience / reward
        self.memory.add(self.last_state, action, reward, next_state, done)

        # Learn, if enough samples are available in memory
        if len(self.memory) > self.batch_size:
            experiences = self.memory.sample()
            self.learn(experiences)

        # Roll over last state
        self.last_state = next_state

    def act(self, states):
        """Return the policy action for the first state in *states*, plus exploration noise."""
        states = np.reshape(states, [-1, self.state_size])
        action = self.actor_local.model.predict(states)[0]
        return list(action + self.noise.sample())  # add some noise for exploration

    def learn(self, experiences):
        """Update policy and value parameters using given batch of experience tuples."""
        self.score = self.total_reward / float(self.count) if self.count else 0.0
        if self.score > self.best_score:
            self.best_score = self.score

        # Drop empty entries once, then split the tuples into per-field arrays.
        experiences = [e for e in experiences if e is not None]
        states = np.vstack([e.state for e in experiences])
        actions = np.array([e.action for e in experiences]).astype(np.float32).reshape(-1, self.action_size)
        rewards = np.array([e.reward for e in experiences]).astype(np.float32).reshape(-1, 1)
        dones = np.array([e.done for e in experiences]).astype(np.uint8).reshape(-1, 1)
        next_states = np.vstack([e.next_state for e in experiences])

        # Q_targets_next = critic_target(next_state, actor_target(next_state))
        actions_next = self.actor_target.model.predict_on_batch(next_states)
        Q_targets_next = self.critic_target.model.predict_on_batch([next_states, actions_next])

        # Compute Q targets for current states (no bootstrap past done) and
        # train the local critic.
        Q_targets = rewards + self.gamma * Q_targets_next * (1 - dones)
        self.critic_local.model.train_on_batch(x=[states, actions], y=Q_targets)

        # Train actor model (local)
        action_gradients = np.reshape(self.critic_local.get_action_gradients([states, actions, 0]), (-1, self.action_size))
        self.actor_local.train_fn([states, action_gradients, 1])  # custom training function

        # Soft-update target models
        self.soft_update(self.critic_local.model, self.critic_target.model)
        self.soft_update(self.actor_local.model, self.actor_target.model)

    def soft_update(self, local_model, target_model):
        """Soft update model parameters.

        Each target weight array becomes
        ``tau * local + (1 - tau) * target``.

        Raises:
            ValueError: if the models have a different number of weight arrays.
        """
        target_model.set_weights(
            self._blend_weights(local_model.get_weights(),
                                target_model.get_weights(),
                                self.tau))

    @staticmethod
    def _blend_weights(local_weights, target_weights, tau):
        """Return ``tau * local + (1 - tau) * target`` for each pair of weight arrays.

        Works array by array: wrapping a list of differently-shaped arrays in
        ``np.array`` makes a ragged array, which NumPy >= 1.24 rejects.
        """
        if len(local_weights) != len(target_weights):
            raise ValueError("Local and target model parameters must have the same size")
        return [tau * np.asarray(local) + (1 - tau) * np.asarray(target)
                for local, target in zip(local_weights, target_weights)]
예제 #6
0
    gif_req_m = mp.Value('i', -1)
    data_proc_list = []
    for _ in range(hp.N_ROLLOUT_PROCESSES):
        data_proc = mp.Process(target=data_func,
                               args=(pi, device, exp_queue, finish_event,
                                     gif_req_m, hp))
        data_proc.start()
        data_proc_list.append(data_proc)

    # Training
    tgt_Q = TargetCritic(Q)
    pi_opt = optim.Adam(pi.parameters(), lr=hp.LEARNING_RATE)
    Q_opt = optim.Adam(Q.parameters(), lr=hp.LEARNING_RATE)
    alpha_optim = optim.Adam([log_alpha], lr=hp.LEARNING_RATE)
    buffer = ReplayBuffer(buffer_size=hp.REPLAY_SIZE,
                          observation_space=hp.observation_space,
                          action_space=hp.action_space,
                          device=hp.DEVICE)
    n_grads = 0
    n_samples = 0
    n_episodes = 0
    best_reward = None
    last_gif = None

    try:
        while n_grads < hp.TOTAL_GRAD_STEPS:
            metrics = {}
            ep_infos = list()
            st_time = time.perf_counter()
            # Collect EXP_GRAD_RATIO sample for each grad step
            new_samples = 0
            while new_samples < hp.EXP_GRAD_RATIO:
예제 #7
0
def main(args):
    """Train a SAC agent fed by one rollout process through a multiprocessing queue.

    The main process consumes experience from ``exp_queue`` and takes one
    gradient step per ``EXP_GRAD_RATIO`` new samples once the replay buffer
    holds ``REPLAY_INITIAL`` transitions. Metrics go to wandb; checkpoints and
    GIF requests are triggered every ``SAVE_FREQUENCY`` / ``GIF_FREQUENCY``
    gradient steps.
    """
    import queue  # for queue.Empty when draining exp_queue on shutdown

    device = "cuda" if args.cuda else "cpu"
    mp.set_start_method('spawn')
    # Input Experiment Hyperparameters
    hp = SACHP(EXP_NAME=args.name,
               DEVICE=device,
               ENV_NAME=args.env,
               N_ROLLOUT_PROCESSES=3,
               LEARNING_RATE=0.0001,
               EXP_GRAD_RATIO=10,
               BATCH_SIZE=256,
               GAMMA=0.95,
               REWARD_STEPS=3,
               ALPHA=0.015,
               LOG_SIG_MAX=2,
               LOG_SIG_MIN=-20,
               EPSILON=1e-6,
               REPLAY_SIZE=100000,
               REPLAY_INITIAL=512,
               SAVE_FREQUENCY=100000,
               GIF_FREQUENCY=100000,
               TOTAL_GRAD_STEPS=1000000)
    wandb.init(project='RoboCIn-RL',
               name=hp.EXP_NAME,
               entity='robocin',
               config=hp.to_dict())

    # Training
    sac = SAC(hp)
    buffer = ReplayBuffer(buffer_size=hp.REPLAY_SIZE,
                          observation_space=hp.observation_space,
                          action_space=hp.action_space,
                          device=hp.DEVICE)

    # Playing: a single rollout process shares the SAC model and feeds
    # experience through exp_queue.
    sac.share_memory()
    exp_queue = mp.Queue(maxsize=hp.EXP_GRAD_RATIO)
    finish_event = mp.Event()
    gif_req_m = mp.Value('i', -1)
    data_proc = mp.Process(target=rollout,
                           args=(sac, device, exp_queue, finish_event,
                                 gif_req_m, hp))
    data_proc.start()

    n_grads = 0
    n_samples = 0
    n_episodes = 0
    try:
        while n_grads < hp.TOTAL_GRAD_STEPS:
            metrics = {}
            ep_infos = []
            st_time = time.perf_counter()
            # Collect EXP_GRAD_RATIO samples for each grad step
            new_samples = 0
            while new_samples < hp.EXP_GRAD_RATIO:
                exp = exp_queue.get()
                if exp is None:
                    raise RuntimeError("got None value in experience queue")
                # Keep a private copy of the queued item, then release it.
                safe_exp = copy.deepcopy(exp)
                del exp

                if isinstance(safe_exp, dict):
                    # Dict is returned with end of episode info
                    ep_infos.append({
                        "ep_info/" + key: value
                        for key, value in safe_exp.items()
                        if 'truncated' not in key
                    })
                    n_episodes += 1
                else:
                    # A missing last_state marks a terminal transition; the
                    # current state is stored as next_obs and done=True.
                    terminal = safe_exp.last_state is None
                    buffer.add(obs=safe_exp.state,
                               next_obs=safe_exp.state if terminal else safe_exp.last_state,
                               action=safe_exp.action,
                               reward=safe_exp.reward,
                               done=terminal)
                    new_samples += 1
            n_samples += new_samples
            sample_time = time.perf_counter()

            # Only start training after buffer is larger than initial value
            if buffer.size() < hp.REPLAY_INITIAL:
                continue

            # Sample a batch and load it as a tensor on device
            batch = buffer.sample(hp.BATCH_SIZE)
            metrics["train/loss_pi"], metrics["train/loss_Q1"], \
                metrics["train/loss_Q2"], metrics["train/loss_alpha"], \
                metrics["train/alpha"] = sac.update(batch=batch,
                                                    metrics=metrics)

            n_grads += 1
            grad_time = time.perf_counter()
            metrics['speed/samples'] = new_samples / (sample_time - st_time)
            metrics['speed/grad'] = 1 / (grad_time - sample_time)
            metrics['speed/total'] = 1 / (grad_time - st_time)
            metrics['counters/samples'] = n_samples
            metrics['counters/grads'] = n_grads
            metrics['counters/episodes'] = n_episodes
            metrics["counters/buffer_len"] = buffer.size()

            # Average end-of-episode info over the episodes seen this step.
            if ep_infos:
                for key in ep_infos[0]:
                    metrics[key] = np.mean([info[key] for info in ep_infos])

            # Log metrics
            wandb.log(metrics)
            if hp.SAVE_FREQUENCY and n_grads % hp.SAVE_FREQUENCY == 0:
                save_checkpoint(hp=hp,
                                metrics={
                                    'alpha': sac.alpha,
                                    'n_samples': n_samples,
                                    'n_grads': n_grads,
                                    'n_episodes': n_episodes
                                },
                                pi=sac.pi,
                                Q=sac.Q,
                                pi_opt=sac.pi_opt,
                                Q_opt=sac.Q_opt)

            if hp.GIF_FREQUENCY and n_grads % hp.GIF_FREQUENCY == 0:
                gif_req_m.value = n_grads

    except KeyboardInterrupt:
        print("...Finishing...")
        finish_event.set()

    finally:
        # Drain leftover experience before stopping the rollout process.
        # Queue.qsize() raises NotImplementedError on some platforms (e.g.
        # macOS), which would mask the original error here, so poll with
        # get_nowait() until the queue reports empty.
        while True:
            try:
                exp_queue.get_nowait()
            except queue.Empty:
                break

        print('queue is empty')

        print("Waiting for threads to finish...")
        data_proc.terminate()
        data_proc.join()

        del exp_queue
        del sac

        finish_event.set()
예제 #8
0
class Qagent(Agent):
    """Q-learning agent that picks ``action_size`` items per step.

    An environment action is an array of items, while the Q function scores
    one item at a time: the online network maps ``concat(state, item)`` to a
    scalar. A separate target network, synced by ``update_target_weights``,
    scores the candidate items of the next state.
    """

    def __init__(self,
                 sess,
                 state_dim,
                 action_dim,
                 epsilon=0.4,
                 action_size=4,
                 logdir='./logs/',
                 replay_size=1000,
                 batch_size=64):
        """Build both networks, the training/target-sync ops, logging and replay.

        Args:
            sess: TensorFlow session used to run all ops.
            state_dim: State shape as a tuple (concatenated to ``(None,)``).
            action_dim: Item shape as a tuple.
            epsilon: Probability of a random slate in ``_sample_action``.
            action_size: Number of items chosen per step.
            logdir: TensorBoard output directory.
            replay_size: Replay buffer capacity.
            batch_size: Transitions per training step.
        """
        self._state_dim = state_dim
        self._action_dim = action_dim
        self._action_size = action_size

        self._logdir = logdir

        self._sess = sess

        self.epsilon = epsilon
        self.gamma = 0.9
        self.lr = 1e-4
        self.optimizer = tf.train.AdadeltaOptimizer(self.lr)

        self.state, self.action, self.agent, self.weights = self._create_network(
            'agent')

        self.qvalues = self.agent(tf.concat([self.state, self.action],
                                            axis=-1))

        self.target_state, self.target_action, self.target, self.target_weights = self._create_network(
            'target')
        self.target_qvalues = self.target(
            tf.concat([self.target_state, self.target_action], axis=-1))

        self.train_op, self.td_loss = self._create_train_op()
        self.target_update_op = self._create_target_update_op()

        # Merge the summaries defined by _create_train_op; the latest result
        # of running them is kept in self.summary by _train.
        self.merged = tf.summary.merge_all()
        self.train_writer = tf.summary.FileWriter(self._logdir,
                                                  self._sess.graph)
        self.summary = None

        self._replay = ReplayBuffer(replay_size)
        self.batch_size = batch_size
        self.td_losses = []

        self._last_state = None
        self._last_items = None
        self._last_action = None
        self.eval_mode = False
        self.training_steps = 0

    def _create_network(self, name):
        """Create input placeholders and a 50-20-1 ReLU MLP under scope *name*.

        Returns:
            (state placeholder, action placeholder, Keras model, trainable
            variables collected under the scope).
        """
        with tf.variable_scope(name_or_scope=name):
            state_ph = tf.placeholder('float32',
                                      shape=(None, ) + self._state_dim,
                                      name='state')
            action_ph = tf.placeholder('float32',
                                       shape=(None, ) + self._action_dim,
                                       name='action')

            net = Sequential(layers=[
                InputLayer(input_shape=(self._state_dim[0] +
                                        self._action_dim[0], )),
                Dense(50, activation='relu'),
                Dense(20, activation='relu'),
                Dense(1)
            ])
            weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                        scope=name)

            return state_ph, action_ph, net, weights

    def _create_train_op(self):
        """Build the training placeholders, the TD loss and the optimizer step.

        The value of the chosen items is the sum of their online Q values.
        The target is ``r + gamma * (1 - done) * S``, where ``S`` is the sum of
        the ``action_size`` highest target-network Q values among the next
        state's candidate items.

        Returns:
            (train_op, td_loss)
        """
        with tf.variable_scope(name_or_scope='train'):

            # s a r s' A
            self.s_ph = tf.placeholder(tf.float32,
                                       shape=(None, ) + self._state_dim,
                                       name='s')
            self.a_ph = tf.placeholder(tf.float32,
                                       shape=(None, self._action_size) +
                                       self._action_dim,
                                       name='a')
            self.r_ph = tf.placeholder(tf.float32, shape=[None], name='r')
            self.done_ph = tf.placeholder(tf.float32,
                                          shape=[None],
                                          name='done')
            self.next_s_ph = tf.placeholder(tf.float32,
                                            shape=(None, ) + self._state_dim,
                                            name='next_s')
            # Pool of candidate items at the next state (variable count).
            self.next_as_ph = tf.placeholder(tf.float32,
                                             shape=(
                                                 None,
                                                 None,
                                             ) + self._action_dim,
                                             name='next_as')

            # Pair the state with each of the action_size chosen items.
            repeat_current_state = tf.expand_dims(self.s_ph, 1)
            repeat_current_state = tf.tile(repeat_current_state,
                                           multiples=[1, self._action_size, 1])
            current_qvalue = self.agent(
                tf.concat([repeat_current_state, self.a_ph], axis=-1))

            # (batch, action_size, 1) -> (batch, action_size) -> (batch,).
            # Squeeze the trailing unit axis, not the batch axis, so this works
            # for any batch size and the sum really runs over the chosen items.
            current_qvalue = tf.squeeze(current_qvalue, axis=-1)
            current_qvalue = tf.reduce_sum(current_qvalue, axis=-1)

            # Pair the next state with each candidate item, scored by target.
            repeat_states = tf.expand_dims(self.next_s_ph, 1)
            repeat_states = tf.tile(
                repeat_states, multiples=[1,
                                          tf.shape(self.next_as_ph)[1], 1])

            next_qvalues = self.target(
                tf.concat([repeat_states, self.next_as_ph], axis=-1))
            next_qvalues = tf.squeeze(next_qvalues, axis=-1)

            k_max_next_qvalues, _ = tf.nn.top_k(next_qvalues,
                                                k=self._action_size)
            # Sum over the top-k items, per batch row.
            next_max_qvalue = tf.reduce_sum(k_max_next_qvalues, axis=-1)

            # Do not bootstrap past terminal transitions.
            reference = self.r_ph + self.gamma * (
                1.0 - self.done_ph) * next_max_qvalue

            td_loss = (current_qvalue - reference)**2
            td_loss = tf.reduce_mean(td_loss)

            tf.summary.histogram('next_max_qvalue', next_max_qvalue)
            tf.summary.histogram('topk', k_max_next_qvalues)
            tf.summary.histogram('target', reference)
            tf.summary.histogram('qvalue', current_qvalue)

            tf.summary.scalar('td_loss', td_loss)

        # Only the online network's variables are optimised.
        return self.optimizer.minimize(td_loss, var_list=self.weights), td_loss

    def _create_target_update_op(self):
        """ assign target_network.weights variables to their respective agent.weights values. """
        assigns = []
        for w_agent, w_target in zip(self.weights, self.target_weights):
            assigns.append(tf.assign(w_target, w_agent, validate_shape=True))
        return assigns

    def rank_action(self, state, actions):
        """Return online-network Q values for *state* paired with each row of *actions*."""
        qvalues = self._sess.run(
            self.qvalues, {
                self.state:
                np.repeat(state.reshape(-1, self._state_dim[0]),
                          actions.shape[0],
                          axis=0),
                self.action:
                actions
            })
        return qvalues

    # The action for the environment is an array of items; the action for
    # the Q function is a single item.
    def target_rank_action(self, state, actions):
        """Return target-network Q values for *state* paired with each row of *actions*."""
        return self._sess.run(
            self.target_qvalues, {
                self.target_state:
                np.repeat(state.reshape(-1, self._state_dim[0]),
                          actions.shape[0],
                          axis=0),
                self.target_action:
                actions,
            })

    def _train(self, batch):
        """Run one optimizer step per transition in *batch*; return the mean TD loss.

        ``batch`` holds aligned sequences ``(s, a, r, next_s, next candidate
        items, done)``, as sampled from ``self._replay``. Each transition is
        fed as a batch of one. The merged summary of the last step is kept in
        ``self.summary`` so ``write_summary`` has something to write.
        """
        losses = []
        for s, a, r, next_s, actions, done in zip(*batch):

            _, loss, self.summary = self._sess.run(
                [self.train_op, self.td_loss, self.merged], {
                    self.s_ph: s[None],
                    self.a_ph: a[None],
                    self.r_ph: r[None],
                    self.done_ph: done[None],
                    self.next_s_ph: next_s[None],
                    self.next_as_ph: actions[None]
                })
            losses.append(loss)

        return np.mean(losses)

    def update_target_weights(self):
        """Copy the online network's weights into the target network."""
        self._sess.run(self.target_update_op)

    def write_summary(self, step):
        """Write the most recent training summary, if any, at *step*."""
        if self.summary:
            self.train_writer.add_summary(self.summary, global_step=step)

    def _sample_action(self, state, actions):
        """Epsilon-greedy choice of ``action_size`` items from *actions*.

        Returns:
            (chosen indices, chosen rows of ``np.array(actions)``)
        """
        actions = np.array(actions)

        if np.random.rand() > self.epsilon:
            qvalues = self.rank_action(state, actions).reshape(-1)
            idxs = qvalues.argsort()[::-1][:self._action_size]
        else:
            # Random items without repeats, like the greedy branch; only allow
            # repeats when there are fewer candidates than slots.
            n_candidates = actions.shape[0]
            idxs = np.random.choice(n_candidates,
                                    size=self._action_size,
                                    replace=n_candidates < self._action_size)

        return idxs, actions[idxs]

    def _train_step(self):
        """Train on one sampled batch once the replay holds ``batch_size`` items."""
        if len(self._replay) >= self.batch_size:
            self.training_steps += 1

            batch = self._replay.sample(self.batch_size)
            td_loss = self._train(batch)
            self.td_losses.append(td_loss)

    def begin_episode(self, observation):
        """Record the first ``(state, items)`` observation and return chosen item indices."""
        state, items = observation
        self._last_state = state
        self._last_items = items
        actions_ids, action = self._sample_action(state, items)
        self._last_action = action
        return actions_ids

    def step(self, reward, observation):
        """Store the transition, train unless in eval mode, and choose the next items."""
        state, items = observation
        self._replay.add(self._last_state, self._last_action, reward, state,
                         np.array(items), False)

        if not self.eval_mode:
            self._train_step()

        self._last_state = state
        self._last_items = items
        actions_ids, action = self._sample_action(state, items)
        self._last_action = action
        return actions_ids

    def end_episode(self, reward):
        return super().end_episode(reward)

    def bundle_and_checkpoint(self, directory, iteration):
        return super().bundle_and_checkpoint(directory, iteration)

    def unbundle(self, directory, iteration, dictionary):
        return super().unbundle(directory, iteration, dictionary)
예제 #9
0
def _observe_state(env):
    """Flatten all ``env.get_cases()`` values and append the current vaccine count."""
    state = [value for items in env.get_cases().values() for value in items]
    state.append(env.get_vaccines()[0].num)
    return state


def _build_vaccination_plan(action, available_vaccines):
    """Turn a flat action vector into a list of single-plan lists.

    Consecutive entries ``action[2k]``, ``action[2k + 1]`` are scaled by
    ``available_vaccines`` and truncated to ``int``.
    """
    return [
        [Vaccination_Plan(JOHNSON,
                          int(action[index] * available_vaccines),
                          int(action[index + 1] * available_vaccines))]
        for index in range(0, len(action), 2)
    ]


def _reward_from_info(info):
    """Return the negated sum of ``values[1]`` over all entries of ``info``.

    NOTE(review): the original warm-up negated this sum but the training loop
    did not, so the replay buffer mixed both signs. The warm-up (negated)
    convention is used everywhere; flip it here if the positive one was meant.
    """
    return -sum(values[1] for values in info.values())


def training_loop(agent_config, env_config, vaccination_schedule,
                  output_path="P:/Dokumente/3 Uni/WiSe2021/Hackathon/Hackathon_KU/exp/performance.csv"):
    """Train a TD3 agent on the vaccination environment and save the learning curve.

    Fills the replay buffer with ``2 * batch_size`` warm-up transitions, then
    runs ``episodes`` episodes with one training step per environment step.

    Args:
        agent_config: Currently unused.
        env_config: Passed to ``VaccinationEnvironment``; ``"groups"`` and
            ``"max_time_steps"`` are read here.
        vaccination_schedule: Passed to the environment; its length sizes the
            state and action vectors.
        output_path: CSV file receiving per-episode total reward and mean loss.
    """
    n_groups = len(env_config["groups"])
    n_schedule = len(vaccination_schedule)
    action_dim = n_groups * 2 * n_schedule
    state_dim = n_groups + n_schedule * n_groups + n_schedule
    batch_size = 100
    episodes = 150
    max_iterations = env_config["max_time_steps"]

    env = VaccinationEnvironment(env_config, vaccination_schedule)
    agent = TD3Agent([state_dim + action_dim, 256, 256, 1],
                     [state_dim, 256, 256, action_dim])
    replay_buffer = ReplayBuffer(state_dim, action_dim)

    # Warm-up: sample enough transitions before training starts.
    state = _observe_state(env)
    for _ in range(2 * batch_size):
        available_vaccines = env.get_vaccines()[0].num
        action = agent.act(state)

        info, done = env.step(_build_vaccination_plan(action, available_vaccines), False)
        reward = _reward_from_info(info)

        # Observed after the step, so the vaccine count is current -- the
        # same as for the initial state.
        next_state = _observe_state(env)
        replay_buffer.add(state, action, next_state, reward, done)

        if done:
            # Do not keep stepping an environment whose episode has ended.
            env.reset()
            next_state = _observe_state(env)
        state = next_state

    # Training
    rewards = []
    losses = []
    for _ in tqdm(range(episodes)):
        episode_reward = []
        episode_loss = []

        env.reset()
        state = _observe_state(env)

        for _ in range(max_iterations):
            available_vaccines = env.get_vaccines()[0].num
            action = agent.act(state)

            info, done = env.step(_build_vaccination_plan(action, available_vaccines), True)
            reward = _reward_from_info(info)
            episode_reward.append(reward)

            next_state = _observe_state(env)
            replay_buffer.add(state, action, next_state, reward, done)
            state = next_state

            # One training step per environment step.
            train_states, train_actions, train_next_states, train_reward, train_done = replay_buffer.sample(
                batch_size)
            loss = agent.train(train_states, train_actions, train_next_states,
                               train_reward, train_done)
            episode_loss.append(loss.detach().numpy())

            if done:
                break

        # Record every episode, including ones that hit max_iterations
        # without reporting done.
        rewards.append(sum(episode_reward))
        losses.append(sum(episode_loss) / len(episode_loss)
                      if episode_loss else float("nan"))

    # Finally save data.
    data = pd.DataFrame(data={"rewards": rewards, "losses": losses})
    data.to_csv(output_path, sep=",", index=False)
예제 #10
0
class FutureRealityAgent(BaseAgent):
    """Work-in-progress agent: trains a state autoencoder and acts randomly.

    Only the state encoder/decoder pair is constructed and trained (in
    ``end_timestep``). The future, reward and keypoint networks used by the
    ``*_loss`` helpers are not created anywhere in this class yet, so calling
    those helpers raises ``AttributeError``. ``get_action`` always defers to
    the base class's ``get_random_action``.
    """

    def __init__(
        self,
        state_size,
        network_id: str,
        buffer_size: int = int(20000),
        batch_size: int = 64,
        gamma: float = 0.99,
        lr: float = 1e-4,
        train_freq: int = 4,
        target_update_freq: int = 1000,
        min_epsilon: float = 0.05,
        epsilon_decay: float = 0.0005,
        scheduler_id: str = 'linear',
        **kwargs,
    ):
        """Set up hyperparameters, replay memory, epsilon schedule and autoencoder.

        ``network_id`` and ``target_update_freq`` are accepted but not used yet.
        Any other keyword arguments are forwarded to ``BaseAgent.__init__``.
        """
        super().__init__(**kwargs)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.batch_size = batch_size
        self.gamma = gamma
        self.lr = lr
        self.train_freq = train_freq

        # Arguments are presumably (start, decay, minimum) - confirm against
        # scheduler_hub. The schedule is stepped once per episode in end_episode.
        make_scheduler = scheduler_hub[scheduler_id]
        self.epsilon_scheduler = make_scheduler(1, epsilon_decay, min_epsilon)

        self.t_step = 0  # number of end_timestep calls so far
        self.memory = ReplayBuffer(buffer_size, batch_size, self.device)
        self.episode_logs = {}  # metrics gathered during the current episode

        # TODO: build FutureNetwork, RewardNetwork and KeypointNetwork together
        # with their RMSprop optimizers (lr=self.lr). NOTE(review): the earlier
        # draft created the keypoint optimizer over future_network.parameters();
        # it presumably should optimize keypoint_network.parameters().
        self.state_encoder = StateEncoder(state_size).to(self.device)
        self.state_decoder = StateDecoder(state_size).to(self.device)
        autoencoder_params = [
            *self.state_encoder.parameters(),
            *self.state_decoder.parameters(),
        ]
        self.opt_ae = optim.RMSprop(autoencoder_params, lr=self.lr)

    def end_timestep(self, info, **kwargs):
        """Store one transition; every ``train_freq`` steps train the autoencoder.

        ``info`` is accepted for interface compatibility and ignored. All keyword
        arguments are forwarded unchanged to ``self.memory.add``. Training only
        happens once the memory holds more than ``batch_size`` transitions.
        """
        self.t_step += 1
        self.memory.add(**kwargs)

        enough_data = len(self.memory) > self.batch_size
        if not enough_data or self.t_step % self.train_freq != 0:
            return

        states, actions, rewards, next_states, dones = self.memory.sample()

        # One reconstruction step on the sampled states.
        self.opt_ae.zero_grad()
        ae_loss, latent_states = self.state_autoencoder_loss(states)
        ae_loss.backward()
        self.opt_ae.step()

        # TODO: once the future/reward networks exist, train them here on
        # latent_states (future + reward loss) with the encoder in eval mode.
        self.episode_logs = dict(self.episode_logs, ae_loss=ae_loss.item())

    def state_autoencoder_loss(self, states):
        """Return ``(reconstruction MSE, latent codes)`` for a batch of states."""
        codes = self.state_encoder(states)
        reconstruction = self.state_decoder(codes)
        criterion = nn.MSELoss()
        return criterion(reconstruction, states), codes

    def future_network_loss(self, latent_states):
        """Negated predicted accumulated reward of imagined futures (minimized).

        NOTE(review): depends on future_network, keypoint_network and
        reward_network, which __init__ does not create yet. The value is
        returned unreduced; if it is not a scalar, a reduction is needed before
        calling ``backward()`` on it.
        """
        reality_fn, _ = self.future_network(latent_states)
        keypoints = self.keypoint_network(latent_states)
        imagined_states = reality_fn(keypoints)
        return -self.reward_network(imagined_states)

    def reward_network_loss(self, latent_states, rewards):
        """MSE between observed rewards and reward_network's predictions."""
        predicted = self.reward_network(latent_states)
        return nn.MSELoss()(rewards, predicted)

    def keypoint_network_loss(self, latent_state, rewards):
        """MSE between keypoints derived from rewards and keypoint_network's output."""
        targets = self.get_keypoints(rewards)
        predicted = self.keypoint_network(latent_state)
        return nn.MSELoss()(targets, predicted)

    @staticmethod
    def get_keypoints(rewards):
        """Placeholder for extracting keypoint targets from rewards; returns None."""
        return None

    def end_episode(self) -> dict:
        """Advance the epsilon schedule and hand back this episode's logs.

        Returns a copy of the collected metrics plus the current ``epsilon``;
        the internal log dict is emptied for the next episode.
        """
        self.epsilon_scheduler.step()
        # TODO: train keypoint_network here.
        logs = copy.deepcopy(self.episode_logs)
        self.episode_logs.clear()
        logs['epsilon'] = self.epsilon_scheduler.get_epsilon()
        return logs

    def get_action(self, state):
        """Return an action for ``state``; currently always a random one.

        TODO: encode the state, query future_network for action preferences and
        choose epsilon-greedily (argmax vs. random) once that network exists.
        """
        return self.get_random_action()
예제 #11
0
파일: ddpg.py 프로젝트: kiminh/rec_gym
class DDPGAgent(Agent):
    """DDPG agent whose actions are selections of item vectors.

    Observations are ``(state, items)`` pairs. For the first ``start_steps``
    action selections the agent samples items uniformly at random; afterwards
    it asks the actor network. Transitions go into a ``ReplayBuffer`` or, when
    ``per`` is true, a prioritized ``PER`` buffer, and one actor/critic update
    is run per ``step`` unless ``eval_mode`` is set.
    """

    def __init__(self,
                 action_size: int,
                 state_dim: int,
                 action_dim: int,
                 gamma: float,
                 sess: tf.Session,
                 optimizer: tf.train.Optimizer = None,
                 max_tf_checkpoints_to_keep: int = 3,
                 experience_size: int = 1000,
                 per: bool = False,
                 batch_size: int = 64,
                 start_steps: int = 2000
                 ):
        """Build actor/critic networks, the checkpoint saver and replay memory.

        Args:
          action_size: number of items selected per action.
          state_dim: state dimensionality, forwarded to both networks.
          action_dim: action dimensionality, forwarded to both networks.
          gamma: discount factor, forwarded to the critic.
          sess: TensorFlow session used by the networks and for checkpoints.
          optimizer: optimizer handed to both actor and critic. When omitted,
            a new ``tf.train.AdamOptimizer(learning_rate=0.001)`` is created for
            this agent. (The former default instance was built once at import
            time and silently shared by every agent constructed without one.)
          max_tf_checkpoints_to_keep: ``max_to_keep`` of the ``tf.train.Saver``.
          experience_size: replay buffer capacity.
          per: use prioritized experience replay when true.
          batch_size: minibatch size; training starts once the buffer holds
            at least this many transitions.
          start_steps: number of initial action selections made at random.
        """
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

        self.optimizer = optimizer
        self.sess = sess
        self.gamma = gamma
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.action_size = action_size
        self.per = per

        self.actor = ActorNetwork(action_size=action_size, state_dim=state_dim,
                                  action_dim=action_dim, sess=sess, optimizer=optimizer)

        self.critic = CriticNetwork(action_size=action_size, state_dim=state_dim,
                                    action_dim=action_dim, sess=sess, optimizer=optimizer, gamma=gamma)

        self.eval_mode = False
        self.t = 0  # action selections made so far during the random warm-up
        self.start_steps = start_steps
        self.training_steps = 0
        self.epsilon = 1
        self.batch_size = batch_size

        self._saver = tf.train.Saver(max_to_keep=max_tf_checkpoints_to_keep)

        if self.per:
            self._replay = PER(experience_size)
        else:
            self._replay = ReplayBuffer(experience_size)

        # Previous observation/action, needed to form the next transition.
        self._last_state = None
        self._last_items = None
        self._last_action = None

        self.td_losses = []
        self.qvalues = []

    def begin_episode(self, observation):
        """Remember the first ``(state, items)`` observation and pick an action.

        Returns:
          The indices of the chosen items.
        """
        state, items = observation
        self._last_state = state
        self._last_items = items
        actions_ids, action = self._sample_action(state, items)
        self._last_action = action
        return actions_ids

    def step(self, reward, observation):
        """Store the last transition, train, and choose the next action.

        The transition ``(last_state, last_action, reward, state, items, False)``
        is added to replay; one training step runs unless ``eval_mode`` is set.
        The critic's Q estimate for the newly chosen action is appended to
        ``self.qvalues``.

        Returns:
          The indices of the chosen items.
        """
        state, items = observation

        if self.per:
            experience = (self._last_state, self._last_action, reward, state, items, False)
            self._replay.store(experience)
        else:
            self._replay.add(self._last_state, self._last_action, reward, state, items, False)

        if not self.eval_mode:
            self._train_step()

        self._last_state = state
        self._last_items = items
        actions_ids, action = self._sample_action(state, items)
        self._last_action = action

        self.qvalues.append(self.critic.predict_qvalue([self._last_state],
                                                       [np.array(self._last_action).reshape(-1)]))

        return actions_ids

    def end_episode(self, reward):
        return super().end_episode(reward)

    def bundle_and_checkpoint(self, checkpoint_dir, iteration_number):
        """Returns a self-contained bundle of the agent's state.

        This is used for checkpointing. It will return a dictionary containing all
        non-TensorFlow objects (to be saved into a file by the caller), and it saves
        all TensorFlow objects into a checkpoint file.

        Args:
          checkpoint_dir: str, directory where TensorFlow objects will be saved.
          iteration_number: int, iteration number to use for naming the checkpoint
            file.

        Returns:
          A dict containing additional Python objects to be checkpointed by the
            experiment. If the checkpoint directory does not exist, returns None.
        """
        if not tf.gfile.Exists(checkpoint_dir):
            return None
        # Call the Tensorflow saver to checkpoint the graph. The session lives
        # in self.sess (self._sess is never assigned by this class).
        self._saver.save(
            self.sess,
            os.path.join(checkpoint_dir, 'tf_ckpt'),
            global_step=iteration_number)
        # TODO: checkpoint the out-of-graph replay buffer and the Python-side
        # state (eval_mode, training_steps) once they need restoring.
        bundle_dictionary = {}
        return bundle_dictionary

    def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary):
        """Restores the agent from a checkpoint.

        Restores the agent's Python objects to those specified in bundle_dictionary,
        and restores the TensorFlow objects to those specified in the
        checkpoint_dir.

        Args:
          checkpoint_dir: str, path to the checkpoint saved by tf.Save.
          iteration_number: int, checkpoint version, used when restoring replay
            buffer.
          bundle_dictionary: dict, containing additional Python objects owned by
            the agent.

        Returns:
          bool, True if unbundling was successful.
        """
        # TODO: restore the replay buffer here once it is checkpointed; its
        # loader may raise tf.errors.NotFoundError, in which case return False.
        for key in self.__dict__:
            if key in bundle_dictionary:
                self.__dict__[key] = bundle_dictionary[key]
        # Restore the agent's TensorFlow graph (session is self.sess).
        self._saver.restore(self.sess,
                            os.path.join(checkpoint_dir,
                                         'tf_ckpt-{}'.format(iteration_number)))
        return True

    def _train_step(self):
        """Train once, provided replay holds at least one minibatch."""
        if len(self._replay) >= self.batch_size:
            self.training_steps += 1
            td_loss = self._train()
            self.td_losses.append(td_loss)

    def _train(self):
        """Run one critic and one actor update on a sampled minibatch.

        Returns:
          The TD loss reported by ``self.critic.train``.
        """
        if self.per:
            b_idx, batch, b_ISWeights = self._replay.sample(self.batch_size)
            # PER yields a sequence of transition tuples; transpose it into one
            # array per field, the layout ReplayBuffer.sample returns.
            # NOTE(review): b_ISWeights (importance-sampling weights) is never
            # passed to the critic, so the PER sampling bias is not corrected.
            batch = tuple(np.array(field) for field in zip(*batch))
        else:
            batch = self._replay.sample(self.batch_size)

        state, action, r, next_s, items, done = batch
        action = [np.reshape(a, -1) for a in action]

        # The critic's target needs the actor's action for the NEXT state,
        # Q(s', mu(s')). `items` are the candidates observed together with
        # next_s (see `step`), so they must be paired with next_s, not with s.
        a_next = []
        for next_state, next_items in zip(next_s, items):
            _, next_action = self.actor.predict_action(state=next_state, items=next_items)
            a_next.append(np.reshape(next_action, -1))

        # NOTE(review): `done` is not forwarded to the critic; every transition
        # stored by `step` currently has done=False anyway.
        td_loss, action_gradients, errors = self.critic.train([state, action, r, next_s, a_next])

        if self.per:
            self._replay.batch_update(b_idx, np.abs(errors))

        self.actor.train(state=state, action_gradients=action_gradients)
        return td_loss

    def _sample_action(self, observation, items):
        """Pick ``action_size`` items for ``observation``.

        During the first ``start_steps`` calls the indices are drawn uniformly
        at random (with replacement); afterwards the actor decides.

        Returns:
          ``(item indices, list of chosen item vectors / actor action)``.
        """
        if self.t < self.start_steps:
            self.t += 1

            actions_ids = np.random.choice(range(len(items)), size=self.action_size)
            return actions_ids, [items[i] for i in actions_ids]

        actions_ids, action = self.actor.predict_action(observation, items)
        return actions_ids, action

    def _update_target_weights(self):
        """Synchronise the actor's and critic's target networks."""
        self.actor.sync_target()
        self.critic.sync_target()
예제 #12
0
    def __init__(self,
                 state_size,
                 action_size,
                 num_agents,
                 device,
                 seed=23520,
                 GRADIENT_CLIP=1,
                 ACTIVATION=F.relu,
                 BOOTSTRAP_SIZE=5,
                 GAMMA=0.99,
                 TAU=1e-3,
                 LR_CRITIC=5e-4,
                 LR_ACTOR=5e-4,
                 UPDATE_EVERY=1,
                 TRANSFER_EVERY=2,
                 UPDATE_LOOP=10,
                 ADD_NOISE_EVERY=5,
                 WEIGHT_DECAY=0,
                 MEMORY_SIZE=5e4,
                 BATCH_SIZE=64):
        """Create the actor/critic network pairs, exploration noise and replay memory.

        Params
        ======
            state_size  : dimension of one agent's state (critics are built with
                          an input size of ``state_size * 2``)
            action_size : dimension of each action
            num_agents  : number of running agents (one noise process each)
            device      : torch device the networks are moved to, e.g. cpu or cuda:0
            seed        : seed handed to the replay buffer
            -----Hyperparameters (stored on the instance)-----
            GRADIENT_CLIP       : limit used when clipping gradients
            ACTIVATION          : accepted but not used by this constructor
            BOOTSTRAP_SIZE      : how far ahead to bootstrap; length of the
                                  reward/state/action windows
            GAMMA               : discount factor
            TAU                 : factor for soft updates of the target parameters
            LR_CRITIC, LR_ACTOR : learning rates of the critic / actor optimizers
            UPDATE_EVERY        : how often to update the networks
            TRANSFER_EVERY      : how often to transfer weights from local to target
            UPDATE_LOOP         : number of iterations per network update
            ADD_NOISE_EVERY     : how often to add exploration noise
            WEIGHT_DECAY        : L2 weight decay of BOTH Adam optimizers (actor and critic)
            MEMORY_SIZE         : replay capacity (converted with int())
            BATCH_SIZE          : replay minibatch size
        """
        # Actor pair; hard_update presumably makes the two start identical.
        actor_io = (state_size, action_size)
        self.actor_local = Actor(*actor_io).to(device)
        self.actor_target = Actor(*actor_io).to(device)
        self.actor_optim = optim.Adam(self.actor_local.parameters(),
                                      lr=LR_ACTOR,
                                      weight_decay=WEIGHT_DECAY)
        hard_update(self.actor_local, self.actor_target)

        # Critic pair. NOTE(review): the input size is state_size * 2, which
        # presumably means "two agents' states concatenated"; confirm it matches
        # num_agents and the way full states are built for the critic.
        critic_io = (state_size * 2, action_size)
        self.critic_local = Critic(*critic_io).to(device)
        self.critic_target = Critic(*critic_io).to(device)
        self.critic_optim = optim.Adam(self.critic_local.parameters(),
                                       lr=LR_CRITIC,
                                       weight_decay=WEIGHT_DECAY)
        hard_update(self.critic_local, self.critic_target)

        self.device, self.num_agents = device, num_agents

        # One SimpleNoise process per agent (used instead of OU noise).
        self.noise = [SimpleNoise(action_size, scale=1) for _ in range(num_agents)]

        self.memory = ReplayBuffer(action_size, device, int(MEMORY_SIZE),
                                   BATCH_SIZE, seed)

        # Counters driving the UPDATE_EVERY / ADD_NOISE_EVERY schedules.
        self.u_step = self.n_step = 0

        hyperparameters = (
            ("BOOTSTRAP_SIZE", BOOTSTRAP_SIZE),
            ("GAMMA", GAMMA),
            ("TAU", TAU),
            ("LR_CRITIC", LR_CRITIC),
            ("LR_ACTOR", LR_ACTOR),
            ("UPDATE_EVERY", UPDATE_EVERY),
            ("TRANSFER_EVERY", TRANSFER_EVERY),
            ("UPDATE_LOOP", UPDATE_LOOP),
            ("ADD_NOISE_EVERY", ADD_NOISE_EVERY),
            ("GRADIENT_CLIP", GRADIENT_CLIP),
        )
        for attr_name, attr_value in hyperparameters:
            setattr(self, attr_name, attr_value)

        # Sliding windows of the last BOOTSTRAP_SIZE rewards/states/actions.
        self.rewards, self.states, self.actions = (
            deque(maxlen=BOOTSTRAP_SIZE) for _ in range(3))
        # Row k holds GAMMA**k, repeated once per agent.
        discounts = [GAMMA**k for k in range(BOOTSTRAP_SIZE)]
        self.gammas = np.array([[d] * num_agents for d in discounts])

        self.loss_function = torch.nn.SmoothL1Loss()
예제 #13
0
class MADDPG():
    """
    Class definition of MADDPG agent. Interacts with and learns from the environment
    Comprises of a pair of Actor-Critic network ad implements centralized training and decentralized exeution (learn function)
    """
    def __init__(self,
                 state_size,
                 action_size,
                 num_agents,
                 device,
                 seed=23520,
                 GRADIENT_CLIP=1,
                 ACTIVATION=F.relu,
                 BOOTSTRAP_SIZE=5,
                 GAMMA=0.99,
                 TAU=1e-3,
                 LR_CRITIC=5e-4,
                 LR_ACTOR=5e-4,
                 UPDATE_EVERY=1,
                 TRANSFER_EVERY=2,
                 UPDATE_LOOP=10,
                 ADD_NOISE_EVERY=5,
                 WEIGHT_DECAY=0,
                 MEMORY_SIZE=5e4,
                 BATCH_SIZE=64):
        """Initialize an Agent object.

        Params
        ======
            state_size  : dimension of a single agent's state
            action_size : dimension of each action
            num_agents  : number of running agents
            device      : cpu or cuda:0 if available
            seed        : random seed forwarded to the replay buffer
            -----These are hyperparameters----
            GRADIENT_CLIP       : max gradient norm used when clipping the critic update
            ACTIVATION          : accepted for interface compatibility; not used here
            BOOTSTRAP_SIZE      : how far ahead to bootstrap (n of the n-step return)
            GAMMA               : discount factor
            TAU                 : parameter for soft updates of target parameters
            LR_CRITIC, LR_ACTOR : learning rates of the critic / actor optimizers
            UPDATE_EVERY        : learn only on every UPDATE_EVERY-th call to step()
            TRANSFER_EVERY      : soft-update the targets after every TRANSFER_EVERY-th
                                  learn() inside one update
            UPDATE_LOOP         : number of learn() calls per update
            ADD_NOISE_EVERY     : act() adds noise only on every ADD_NOISE_EVERY-th call
            WEIGHT_DECAY        : L2 weight decay of BOTH Adam optimizers (actor and critic)
            MEMORY_SIZE         : replay buffer capacity (converted with int())
            BATCH_SIZE          : replay buffer minibatch size
        """

        # Actor networks: input is a single agent's state.
        self.actor_local = Actor(state_size, action_size).to(device)
        self.actor_target = Actor(state_size, action_size).to(device)
        self.actor_optim = optim.Adam(self.actor_local.parameters(),
                                      lr=LR_ACTOR,
                                      weight_decay=WEIGHT_DECAY)
        hard_update(self.actor_local, self.actor_target)

        # Critic networks: they are fed transform_states() output, whose last
        # dimension is state_size * num_agents, so size their input to match.
        # (Previously hard-coded to state_size * 2, correct only for 2 agents.)
        critic_input_size = state_size * num_agents
        self.critic_local = Critic(critic_input_size, action_size).to(device)
        self.critic_target = Critic(critic_input_size, action_size).to(device)
        self.critic_optim = optim.Adam(self.critic_local.parameters(),
                                       lr=LR_CRITIC,
                                       weight_decay=WEIGHT_DECAY)
        hard_update(self.critic_local, self.critic_target)

        self.device = device
        self.num_agents = num_agents

        # Noise : using simple noise instead of OUNoise, one process per agent
        self.noise = [
            SimpleNoise(action_size, scale=1) for _ in range(num_agents)
        ]

        # Replay memory
        self.memory = ReplayBuffer(action_size, device, int(MEMORY_SIZE),
                                   BATCH_SIZE, seed)

        # Step counters for the UPDATE_EVERY / ADD_NOISE_EVERY schedules
        self.u_step = 0
        self.n_step = 0

        # Keep hyperparameters on the instance
        self.BOOTSTRAP_SIZE = BOOTSTRAP_SIZE
        self.GAMMA = GAMMA
        self.TAU = TAU
        self.LR_CRITIC = LR_CRITIC
        self.LR_ACTOR = LR_ACTOR
        self.UPDATE_EVERY = UPDATE_EVERY
        self.TRANSFER_EVERY = TRANSFER_EVERY
        self.UPDATE_LOOP = UPDATE_LOOP
        self.ADD_NOISE_EVERY = ADD_NOISE_EVERY
        self.GRADIENT_CLIP = GRADIENT_CLIP

        # Windows over the last BOOTSTRAP_SIZE steps, used by step() to build
        # n-step transitions; gammas[k, :] == GAMMA**k for every agent.
        self.rewards = deque(maxlen=BOOTSTRAP_SIZE)
        self.states = deque(maxlen=BOOTSTRAP_SIZE)
        self.actions = deque(maxlen=BOOTSTRAP_SIZE)
        self.gammas = np.array([[GAMMA**i for _ in range(num_agents)]
                                for i in range(BOOTSTRAP_SIZE)])

        self.loss_function = torch.nn.SmoothL1Loss()

    def reset(self):
        if self.noise:
            for n in self.noise:
                n.reset()

    def set_noise(self, noise):
        self.noise = noise

    def save(self, filename):
        torch.save(self.actor_local.state_dict(),
                   "{}_actor_local.pth".format(filename))
        torch.save(self.actor_target.state_dict(),
                   "{}_actor_target.pth".format(filename))
        torch.save(self.critic_local.state_dict(),
                   "{}_critic_local.pth".format(filename))
        torch.save(self.critic_target.state_dict(),
                   "{}_critic_target.pth".format(filename))

    def load(self, path):
        self.actor_local.load_state_dict(torch.load(path + "_actor_local.pth"))
        self.actor_target.load_state_dict(
            torch.load(path + "_actor_target.pth"))
        self.critic_local.load_state_dict(
            torch.load(path + "_critic_local.pth"))
        self.critic_target.load_state_dict(
            torch.load(path + "_critic_target.pth"))

    def act(self, states, noise=0.0):
        """
        Returns actions of each actor for given states.
        
        Params
        ======
            state    : current states
            add_noise: Introduce some noise in agent's action or not. During training, this is necessary to promote the exploration but should not be used during validation
        """
        actions = None

        self.n_step = (self.n_step + 1) % self.ADD_NOISE_EVERY

        with torch.no_grad():
            self.actor_local.eval()
            states = torch.from_numpy(states).float().unsqueeze(0).to(
                self.device)
            actions = self.actor_local(states).squeeze().cpu().data.numpy()
            self.actor_local.train()
            if self.n_step == 0:
                for i in range(len(actions)):
                    actions[i] += noise * self.noise[i].sample()
        return actions

    def step(self, states, actions, rewards, next_states, dones):
        """
        Take a step for the current episode
        1. Save the experience
        2. Bootstrap the rewards
        3. If update conditions are statisfied, perform learning on required number of loops
        """

        # Save experience in replay memory
        self.rewards.append(rewards)
        self.states.append(states)
        self.actions.append(actions)

        if len(self.rewards) == self.BOOTSTRAP_SIZE:
            # get the bootstrapped sum of rewards per agents
            reward = np.sum(self.rewards * self.gammas, axis=-2)
            self.memory.add(self.states[0], self.actions[0], reward,
                            next_states, dones)

        if np.any(dones):
            self.rewards.clear()
            self.states.clear()
            self.actions.clear()

        # Learn every UPDATE_EVERY timesteps
        self.u_step = (self.u_step + 1) % self.UPDATE_EVERY

        t_step = 0
        if len(self.memory) > self.memory.batch_size and self.u_step == 0:
            for _ in range(self.UPDATE_LOOP):
                self.learn()
                # transfer the weights as specified
                t_step = (t_step + 1) % self.TRANSFER_EVERY
                if t_step == 0:
                    soft_update(self.actor_local, self.actor_target, self.TAU)
                    soft_update(self.critic_local, self.critic_target,
                                self.TAU)

    def transform_states(self, states):
        """
        Transforms states to full states so that both agents can see each others state via the critic network
        """
        batch_size = states.shape[0]
        state_size = states.shape[-1]
        num_agents = states.shape[-2]
        transformed_states = torch.zeros(
            (batch_size, num_agents, state_size * num_agents)).to(self.device)
        for i in range(num_agents):
            start = 0
            for j in range(num_agents):
                transformed_states[:, i, start:start + state_size] += states[:,
                                                                             j]
                start += state_size
        return transformed_states

    def learn(self):
        """
        Run one update of the local critic and the local actor.

        Takes no arguments: a minibatch ``(states, actions, rewards, next_states,
        dones)`` is sampled from ``self.memory``. The stored ``rewards`` are the
        discounted n-step sums built in ``step``, so the critic regresses
        ``critic_local(full_states, actions)`` with ``self.loss_function`` onto
        ``rewards + GAMMA**BOOTSTRAP_SIZE * Q_target(s', actor_target(s')) * (1 - dones)``.
        The critic's gradient norm is clipped to GRADIENT_CLIP. The actor is then
        updated to maximise the local critic's mean value of its own actions.

        NOTE(review): assumes memory.sample() returns tensors on self.device, with
        states shaped (batch, num_agents, state_size) as transform_states
        expects - confirm against ReplayBuffer.
        """
        # sample the memory to disrupt the internal correlation
        states, actions, rewards, next_states, dones = self.memory.sample()
        full_states = self.transform_states(states)

        # The critic should estimate the value of the states to be equal to rewards plus
        # the estimation of the next_states value according to the critic_target and actor_target
        with torch.no_grad():
            # NOTE(review): the target networks are put in eval() mode here and
            # not switched back to train() in the lines shown; this only matters
            # if they contain dropout / batch-norm layers.
            self.actor_target.eval()
            self.critic_target.eval()
            # obtain next actions as given by the target network and get transformed states for the critic
            next_actions = self.actor_target(next_states)
            next_full_states = self.transform_states(next_states)
            # Q value of (next full state, target action) from the target critic
            q_next = self.critic_target(next_full_states,
                                        next_actions).squeeze(-1)
            # n-step TD target: the bootstrap term is discounted by GAMMA**BOOTSTRAP_SIZE
            # because `rewards` already sums BOOTSTRAP_SIZE discounted rewards;
            # (1 - dones) removes the bootstrap term for terminal transitions.
            targeted_value = rewards + (
                self.GAMMA**self.BOOTSTRAP_SIZE) * q_next * (1 - dones)

        current_value = self.critic_local(full_states, actions).squeeze(-1)
        loss = self.loss_function(current_value, targeted_value)
        # The critic learns how far its estimate is from the TD target; the actor
        # is then judged by this critic's Q-value for the actions it proposes.

        # calculate the loss of the critic network and backpropagate
        self.critic_optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(),
                                       self.GRADIENT_CLIP)
        self.critic_optim.step()

        # optimize the actor by having the critic evaluating the value of the actor's decision
        # NOTE(review): (-mean).backward() also accumulates gradients in
        # critic_local's parameters; they are cleared by critic_optim.zero_grad()
        # before the next critic step, so only actor_local changes here. The actor
        # update is not gradient-clipped.
        self.actor_optim.zero_grad()
        actions_pred = self.actor_local(states)
        mean = self.critic_local(full_states, actions_pred).mean()
        (-mean).backward()
        self.actor_optim.step()