Exemple #1
0
    def __init__(self, trail=False, non_reward_trail=False):
        """Set up statistics, TF1 placeholders/graph and the session.

        Args:
            trail: forwarded unchanged to the parent constructor.
            non_reward_trail: forwarded to the parent constructor and stored on
                ``self``; when truthy the observation width used by every
                observation placeholder is doubled.

        NOTE(review): ``episodes`` and ``runs`` are not parameters of this
        method; they are resolved from an enclosing/module scope that is not
        visible here -- confirm they are defined before instantiation.
        """
        super().__init__(episodes,
                         runs,
                         trail=trail,
                         non_reward_trail=non_reward_trail
                         )  # sets environment and handles visualization
        self.step_taker = self.q_learning  # the function that provides experience and does the learning
        # observation size: square of env.window_size (presumably a square view
        # window -- confirm)
        self.window_area = self.env.window_size**2
        self.num_actions = len(
            self.env.action_space.actions
        )  # number of possible actions in the environment
        self.int_reward_stats = Welford(
        )  # maintains running mean and variance of intrinsic reward
        self.obs_stats = Welford(
        )  # maintains running mean and variance of observations (only for prediction and target networks)

        # handle non_reward_trail; must happen before the placeholders below,
        # which are sized from self.window_area
        self.non_reward_trail = non_reward_trail
        if non_reward_trail:
            self.window_area = self.window_area * 2  # extra dimension for breadcrumb observations

        # placeholders (batch dimension fixed to 1)
        self.inputs_ph = tf.placeholder(
            tf.float32,
            shape=(1, self.window_area))  # the observation at each time step
        self.aux_inputs_ph = tf.placeholder(
            tf.float32,
            shape=(1, self.window_area
                   ))  # inputs to prediction and fixed target networks
        self.targets_ext_ph = tf.placeholder(
            tf.float32,
            shape=(1,
                   self.num_actions))  # r+gamma*Q(s',a'). external. q-learning
        self.targets_int_ph = tf.placeholder(
            tf.float32,
            shape=(1,
                   self.num_actions))  # r+gamma*Q(s',a'). internal. q-learning

        # building the graph (presumably these read the placeholders above, so
        # they must be created after them -- confirm in build_graph/build_aux_graphs)
        self.Q_ext, self.Q_int, self.Q_loss = self.build_graph(
        )  # the outputs of Q network and its update operation
        self.aux_loss = self.build_aux_graphs(
        )  # the internal reward and the prediction net's update op
        self.update_op = self.update()
        self.init_op = tf.global_variables_initializer(
        )  # global initialization done after the graph is defined

        # session
        self.sess = tf.Session(
        )  # using the same session for the life of this RNDLearner object (each run).
Exemple #2
0
 def z_score_encoding(self, dataset):
     """Compute running statistics of the autoencoder's encodings.

     Each item of ``dataset`` is passed through ``self.autoencoder`` asking
     only for the ``'encoding'`` output, which is folded into a Welford
     accumulator of shape ``(self.autoencoder.encoding_size,)``. The resulting
     mean and standard deviation are stored on ``self.encoding_mean`` and
     ``self.encoding_std``.
     """
     accumulator = Welford(shape=(self.autoencoder.encoding_size,))
     for sample in dataset:
         outputs = self.autoencoder(sample, what=['encoding'])
         accumulator(outputs['encoding'])
     self.encoding_mean = accumulator.mean
     self.encoding_std = accumulator.std
Exemple #3
0
 def __init__(self, config):
     """Prepare the output directory, start the simulations and set up buffers.

     Reads ``name``, ``path``, ``n_simulations``, ``cam_resolution`` and
     ``n_steps_per_movment`` from ``config``. The output directory
     ``<original cwd>/<config.path>/<config.name>`` is deleted if it already
     exists and then recreated empty. ``table.dat`` inside it backs
     ``self.array_on_disk``.

     Two arms are added symmetrically on the x axis, plus one camera. The
     simulations are then started. A structured per-simulation buffer is
     created, together with one ``Welford`` accumulator per numeric buffer
     field and one for camera frames. Joint intervals are read from
     simulation 0.
     """
     self.name = config.name
     self.path_root = get_original_cwd() + '/' + config.path
     self.path = normpath(self.path_root + '/' + self.name + '/')
     # Start from an empty directory: previous output under this name is removed.
     if isdir(self.path):
         rmtree(self.path)
     makedirs(self.path)
     self.array_on_disk = appendable_array_file(self.path + '/table.dat')
     self.n_simulations = config.n_simulations
     self.simulations = SimulationPool(self.n_simulations)
     self.cam_resolution = config.cam_resolution
     self.n_steps_per_movment = config.n_steps_per_movment
     arm_distance = 0.75  # each arm's offset from the origin along x
     self.simulations.add_arm(position=[-arm_distance, 0.0, 0.0])
     self.simulations.add_arm(position=[arm_distance, 0.0, 0.0])
     self.cameras = self.simulations.add_camera(
         position=[0.0, 1.6, 1.0],
         orientation=[13 * np.pi / 20, 0.0, 0.0],
         resolution=self.cam_resolution,
         view_angle=90.0,
     )
     self.simulations.start_sim()
     # (name, shape) of every float32 field recorded per simulation. The same
     # table defines the buffer dtype and the Welford accumulators, so the two
     # cannot drift apart.
     numeric_fields = [
         ('arm0_end_eff', (3, )),
         ('arm0_positions', (7, )),
         ('arm0_velocities', (7, )),
         ('arm0_forces', (7, )),
         ('arm1_end_eff', (3, )),
         ('arm1_positions', (7, )),
         ('arm1_velocities', (7, )),
         ('arm1_forces', (7, )),
     ]
     # np.unicode_ was removed in NumPy 2.0; np.str_ is the same type on 1.x.
     self._buffer_dtype = np.dtype(
         [(name, np.float32, shape) for name, shape in numeric_fields] +
         [('frame_path', np.str_, 12)])
     self._buffer = np.zeros(shape=self.n_simulations,
                             dtype=self._buffer_dtype)
     self._welfords = {
         name: Welford(shape=shape) for name, shape in numeric_fields
     }
     # Shape (cam_resolution[1], cam_resolution[0], 3): presumably
     # (height, width, RGB) if cam_resolution is (width, height) -- confirm.
     self._frame_welford = Welford(shape=(self.cam_resolution[1],
                                          self.cam_resolution[0], 3))
     self._current_frame_id = 0
     with self.simulations.specific([0]):
         self._intervals = self.simulations.get_joint_intervals()[0]
 def block(self, nblocks, data):
     """Split ``data`` into ``nblocks`` equal blocks and return each block's mean.

     Every block holds ``len(data) // nblocks`` consecutive items. Trailing
     items that do not fill a complete block are ignored. Each mean comes
     from a fresh ``Welford`` accumulator.

     Raises:
         ZeroDivisionError: if ``nblocks`` is 0.
     """
     # Floor division: on Python 3 ``/`` yields a float, which is not a valid
     # slice bound for lists or numpy arrays (TypeError/IndexError).
     block_length = len(data) // nblocks
     block_avgs = []
     for i in range(nblocks):
         block_data = data[i * block_length:(i + 1) * block_length]
         stats = Welford()
         stats(block_data)
         block_avgs.append(stats.mean)
     return block_avgs
Exemple #5
0
def get_statpool(allstats, statfuncs, key_func=get_key):
    """Group factions by key and accumulate running statistics per group.

    Args:
        allstats: iterable of faction records exposing ``score_tiles`` and
            ``game_id``.
        statfuncs: callables mapping a faction to the value to accumulate; one
            ``Welford`` accumulator is kept per callable.
        key_func: maps a faction to its grouping key (defaults to ``get_key``).

    Returns:
        dict mapping each key to a list of ``Welford`` accumulators aligned
        with ``statfuncs``.

    Factions whose ``score_tiles`` do not contain ``"1"`` are reported on
    stderr and skipped.
    """
    statpool = {}
    # Template of empty accumulators, deep-copied once per new key so that
    # groups never share accumulator objects.
    statbase = [Welford() for _ in statfuncs]
    for faction in allstats:
        if "1" not in faction.score_tiles:
            # ``print >> sys.stderr, ...`` is Python 2 only (TypeError on
            # Python 3); ``sys.stderr.write`` behaves the same on both.
            sys.stderr.write("invalid score tiles: " + faction.game_id + "\n")
            continue
        key = key_func(faction)
        stats = statpool.get(key)
        if stats is None:
            # Copy only for unseen keys; setdefault() would deep-copy the
            # template on every iteration.
            stats = statpool[key] = copy.deepcopy(statbase)
        for accumulator, statfunc in zip(stats, statfuncs):
            accumulator(statfunc(faction))
    return statpool
    def scanBlocking(self, data):
        """Run a blocking scan over ``data`` and store per-level summaries.

        Each pass replaces ``data`` with ``self.blockNtimes(1, data)`` and feeds
        the result to a fresh ``Welford``. While at least two values remain,
        ``self.blockDict[level]`` (level counted from 1) receives a dict with:

        * the series ``length`` and ``mean``;
        * ``var = std**2 / (n - 1)`` with bounds
          ``var * (1 +/- sqrt(2 / (n - 1)))``;
        * ``std = sqrt(var)`` with bounds
          ``std * (1 +/- 1 / sqrt(2 * (n - 1)))``.

        The scan stops once fewer than two values are left.
        """
        self.blockDict = {}

        level = 0
        while len(data) >= 2:
            level += 1
            data = self.blockNtimes(1, data)
            running = Welford()
            running(data)

            n = len(data)
            if n < 2:
                break

            dof = float(n) - 1.0

            variance = running.std * running.std / dof
            rel_var_err = math.sqrt(2.0 / dof)

            sigma = math.sqrt(variance)
            rel_sigma_err = 1.0 / math.sqrt(2.0 * dof)

            self.blockDict[level] = {
                "length": n,
                "mean": running.mean,
                "var": variance,
                "var_plus": variance * (1.0 + rel_var_err),
                "var_minus": variance * (1.0 - rel_var_err),
                "std": sigma,
                "std_plus": sigma * (1.0 + rel_sigma_err),
                "std_minus": sigma * (1.0 - rel_sigma_err),
            }
Exemple #7
0
import numpy as np
from welford import Welford

if __name__ == '__main__':
    # Ten synthetic feature columns; column i is drawn from a normal
    # distribution with mean i and standard deviation i / 2.
    data = []
    stats = []
    np.random.seed(10)
    for i in range(1, 11):
        buffer = np.random.normal(i, i / 2, 2000)
        data.append(buffer)
        stats.append(Welford())
        print('Column %d | mean: %f | std: %f' %
              (i, buffer.mean(), buffer.std()))

    # Data matrix: 2000 observations (rows) x 10 features (columns).
    data = np.array(data).T

    # Stream the rows through the accumulators in 100 chunks of 20 rows.
    data_split = np.split(data, 100)
    print('\n\n\t RESULTS: \n')

    for sub_data in data_split:
        for j, col in enumerate(sub_data.T):
            # Every chunk, including the first, goes through the regular
            # update. k is the sample count and S the running sum of squared
            # deviations (std = sqrt(S / (k - 1))), so they cannot be seeded
            # with k=1 and a chunk's std.
            stats[j](col)

    # buffer.std() above is the population std (ddof=0), i.e. sqrt(S / k).
    for j, wf in enumerate(stats, start=1):
        print('Column %d | mean: %f | std: %f' %
              (j, wf.M, np.sqrt(wf.S / wf.k)))
Exemple #8
0
import pandas as pd
import numpy as np
from welford import Welford

if __name__ == '__main__':
    # ``join`` is not imported at module level for this script; without this
    # the first path construction raises NameError.
    from os.path import join

    # Set main folder: mf
    mf = '/run/media/ssilvari/Smith_2T_WD/Data/metaimagen_test_results/'

    print('[  INFO  ] Loading AIO data...')
    all_csv = join(mf, 'all_in_one/output/groupfile_features.csv')
    df = pd.read_csv(all_csv, index_col=0)
    # Reference statistics over the pooled table (DataFrame.std uses ddof=1).
    all_stats = {'mean': df.mean().values, 'std': df.std().values}

    print('[  INFO  ] Calculating local WF...')
    wf = Welford()

    # Stream the eight per-center feature tables through one accumulator.
    for i in range(1, 9):
        center_csv = join(mf, 'center_%d/output/groupfile_features.csv' % i)
        data = pd.read_csv(center_csv, index_col=0).values

        wf(data)
    # Sample std (ddof=1) to match DataFrame.std; assumes wf.S is the running
    # sum of squared deviations and wf.k the sample count.
    stats = {'mean': wf.M, 'std': np.sqrt(wf.S / (wf.k - 1))}

    # Mean absolute difference between streamed and pooled statistics.
    for key, value in stats.items():
        err = np.mean(np.abs(value - all_stats[key]))
        print('%s: %f' % (key, err))

    # ==== Compare with distributed results ====
    print('\n\n[  INFO  ] Comparing with distributed version...')
    # TODO: the comparison itself is not implemented; only the path is built.
    distributed_wf = join(mf, 'all_in_one/output/welford/welford_final.npz')
Exemple #9
0
class RNDLearner(DoubleQLearner):
    """Q-learner that adds an intrinsic (curiosity) reward to the environment reward.

    The intrinsic reward is the squared error between a trainable predictor
    network and a non-trainable, randomly initialised target network. Both
    are evaluated on the whitened next observation (see ``build_aux_graphs``
    and ``step``). Separate Q heads estimate extrinsic and intrinsic returns,
    and actions are chosen eps-greedily on ``Q_ext + Q_int_coeff * Q_int``.

    NOTE(review): ``max_grad_norm``, ``Q_int_coeff``, ``gamma_int``,
    ``gamma_ext``, ``eps0`` and ``episodes`` are used below but not defined in
    this class; they must exist at module level.
    """
    def __init__(self, episodes, runs):
        """Build the TF1 graph and open the session reused across runs.

        Args:
            episodes, runs: forwarded to the parent constructor.

        Order matters: the placeholders must exist before ``build_graph`` and
        ``build_aux_graphs`` read them, and ``init_op`` is created only after
        every variable has been defined.
        """
        super().__init__(
            episodes,
            runs)  # sets environment and handles visualization if turned on
        self.step_taker = self.q_learning  # the function that provides experience and does the learning
        self.num_actions = len(
            self.env.action_space.actions
        )  # number of possible actions in the environment
        self.num_states = self.env.width * self.env.height  # number of states in the environment
        self.int_reward_stats = Welford(
        )  # maintains running mean and variance of intrinsic reward
        self.obs_stats = Welford(
        )  # maintains running mean and variance of observations (only for prediction and target networks)

        # placeholders (batch dimension fixed to 1)
        self.inputs_ph = tf.placeholder(
            tf.float32,
            shape=(1, self.num_states))  # the observation at each time step
        self.aux_inputs_ph = tf.placeholder(
            tf.float32,
            shape=(1, self.num_states
                   ))  # inputs to prediction and fixed target networks
        self.targets_ext_ph = tf.placeholder(
            tf.float32,
            shape=(1,
                   self.num_actions))  # r+gamma*Q(s',a'). external. q-learning
        self.targets_int_ph = tf.placeholder(
            tf.float32,
            shape=(1,
                   self.num_actions))  # r+gamma*Q(s',a'). internal. q-learning

        # building the graph
        self.Q_ext, self.Q_int, self.Q_loss = self.build_graph(
        )  # the outputs of Q network and its update operation
        self.aux_loss = self.build_aux_graphs(
        )  # the internal reward and the prediction net's update op
        self.update_op = self.update()
        self.init_op = tf.global_variables_initializer(
        )  # global initialization done after the graph is defined

        # session
        self.sess = tf.Session(
        )  # using the same session for the life of this RNDLearner object (each run).

    # builds the computation graph for a Q network
    def build_graph(self):
        """Build the Q network and its loss.

        One shared 16-unit ReLU layer on ``inputs_ph`` feeds two linear heads
        with ``num_actions`` outputs each: ``Q_ext`` (extrinsic return) and
        ``Q_int`` (intrinsic return). The loss is the sum of both heads'
        squared errors against ``targets_ext_ph`` / ``targets_int_ph``.

        Returns:
            (Q_ext, Q_int, Q_loss) tensors.
        """
        # separate Q-value heads for estimates of extrinsic and intrinsic returns
        Q_h = tf.layers.dense(self.inputs_ph,
                              16,
                              activation=tf.nn.relu,
                              kernel_initializer=tf.initializers.random_normal,
                              name="Q_h")
        Q_ext = tf.layers.dense(
            Q_h,
            self.num_actions,
            activation=None,
            kernel_initializer=tf.initializers.random_normal,
            name="Q_ext")
        Q_int = tf.layers.dense(
            Q_h,
            self.num_actions,
            activation=None,
            kernel_initializer=tf.initializers.random_normal,
            name="Q_int")

        loss_ext = tf.reduce_sum(
            tf.square(self.targets_ext_ph -
                      Q_ext))  # error in prediction of external return
        loss_int = tf.reduce_sum(
            tf.square(self.targets_int_ph -
                      Q_int))  # error in prediction of internal return
        Q_loss = loss_ext + loss_int
        return Q_ext, Q_int, Q_loss

    # defines the graph structure used for both the prediction net and the fixed target net
    def aux_graph(self, trainable):
        """Two-layer net on ``aux_inputs_ph``: 16 ReLU units -> ``num_actions`` linear outputs.

        Layer names are fixed ("h1", "output"), so callers must wrap each call
        in a distinct variable scope. ``trainable=False`` keeps the variables
        out of the optimizer's default variable list (used for the target net).
        """
        h1 = tf.layers.dense(self.aux_inputs_ph,
                             16,
                             activation=tf.nn.relu,
                             kernel_initializer=tf.initializers.random_normal,
                             name="h1",
                             trainable=trainable)
        output = tf.layers.dense(
            h1,
            self.num_actions,
            activation=None,
            kernel_initializer=tf.initializers.random_normal,
            name="output",
            trainable=trainable)
        return output

    # returns operations for getting the prediction net loss (aka internal reward) and updating the prediction net parameters
    def build_aux_graphs(self):
        """Return the summed squared predictor-vs-target difference.

        The same tensor is used as the predictor's training loss (in
        ``update``) and as the raw intrinsic reward (in ``step``).
        """
        with tf.variable_scope("target_net"):
            target_net_out = self.aux_graph(trainable=False)
        with tf.variable_scope("predictor_net"):
            predictor_net_out = self.aux_graph(trainable=True)

        aux_loss = tf.reduce_sum(
            tf.square(predictor_net_out - target_net_out)
        )  # loss for training predictor network. also intrinsic reward
        return aux_loss

    def update(self):
        """Return one Adam step on ``Q_loss + aux_loss``.

        Gradients are clipped to a global norm of ``max_grad_norm`` (a
        module-level name). ``compute_gradients`` defaults to the trainable
        variables, so the frozen target net is not updated.
        """
        loss = self.Q_loss + self.aux_loss
        optimizer = tf.train.AdamOptimizer()
        gradients, variables = zip(*optimizer.compute_gradients(loss))
        gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm)
        update_op = optimizer.apply_gradients(zip(gradients, variables))
        return update_op

    # returns eps-greedy action with respect to Q_ext+Q_int
    def get_eps_action(self, one_hot_state, eps):
        """Pick an eps-greedy action on ``Q_ext + Q_int_coeff * Q_int``.

        Ties between maximal actions are broken uniformly at random with the
        action space's ``np_random``.

        Returns:
            (action, Q_ext, Q_int), where Q_ext/Q_int are the raw network
            outputs for ``one_hot_state``.
        """
        Q_ext, Q_int = self.sess.run(
            [self.Q_ext, self.Q_int],
            {self.inputs_ph: one_hot_state})  # outputs of Q network
        Q = Q_ext + Q_int_coeff * Q_int
        if self.env.action_space.np_random.uniform(
        ) < eps:  # take random action with probability epsilon
            action = self.env.action_space.sample()
        else:
            max_actions = np.where(
                np.ravel(Q) == Q.max())[0]  # list of optimal action indices
            # Empty only when Q contains NaN (NaN never compares equal); the
            # choice() call below will then fail on the empty array.
            if len(max_actions) == 0:  # for debugging
                print(Q)
            action = self.env.action_space.np_random.choice(
                max_actions)  # select from optimal actions randomly
        return action, Q_ext, Q_int

    # take step in environment and gather/update information
    def step(self, eps):
        """Act once in the environment and compute the intrinsic reward.

        The next observation updates ``obs_stats`` first and is then whitened
        with the updated mean/variance and clipped to [-5, 5]. The predictor
        error on that input is the raw intrinsic reward. It is divided by the
        square root of ``int_reward_stats.var``; no mean is subtracted.

        Returns:
            (state, whitened_state, action, reward_ext, reward_int,
             next_state, Q_ext, Q_int, done)
        """
        state = self.env.state
        action, Q_ext, Q_int = self.get_eps_action(
            self.one_hot(state), eps)  # take action wrt Q_ext+Q_int.
        next_state, reward_ext, done, _ = self.env.step(
            action)  # get external reward by acting in the environment

        self.obs_stats.update(
            self.one_hot(next_state))  # update observation statistics
        whitened_state = (
            self.one_hot(next_state) - self.obs_stats.mean) / np.sqrt(
                self.obs_stats.var)  # whitened obs for pred and target nets
        whitened_state = np.clip(whitened_state, -5, 5)

        reward_int = self.sess.run(
            self.aux_loss,
            {self.aux_inputs_ph: whitened_state})  # get intrinsic reward
        self.int_reward_stats.update(
            reward_int)  # update running statistics for intrinsic reward
        reward_int = reward_int / np.sqrt(
            self.int_reward_stats.var)  # normalize intrinsic reward

        return state, whitened_state, action, reward_ext, reward_int, next_state, Q_ext, Q_int, done

    def q_learning(self):
        """Generator that trains the agent and yields one transition per step.

        Initialises all variables and statistics, then runs ``episodes``
        episodes with epsilon decaying linearly from ``eps0`` towards 0. Each
        step yields ``(state, action_label, reward_ext, reward_int,
        next_state, done)`` before the parameters are updated.

        NOTE(review): ``episodes`` here is the module-level name, not the
        constructor argument (which is only forwarded to the parent) --
        confirm they agree.
        NOTE(review): the env is reset here only when the step cap is hit;
        whether it resets itself after a natural ``done`` is not visible.
        """
        self.sess.run(self.init_op)  # initialize all model parameters
        self.initialize_stats(
        )  # reset all statistics to zero and then initialize with random agent

        for episode in range(episodes):
            eps = eps0 - eps0 * episode / episodes  # decay epsilon
            done = False
            t = 0
            while not done:  # for each episode
                t += 1

                # take step in environment and gather/update information
                state, whitened_state, action, reward_ext, reward_int, next_state, Q_ext, Q_int, done = self.step(
                    eps)

                #print(episode, reward_int)

                if t > 20:  # cap episode length (checked after the step, so at most 21 steps)
                    self.env.reset()
                    done = True

                # report data for analysis/plotting/visualization
                yield state, self.env.action_space.actions[
                    action], reward_ext, reward_int, next_state, done

                # greedy next action wrt Q_ext+Q_int
                _, Q_ext_next, Q_int_next = self.get_eps_action(
                    self.one_hot(next_state), 0)

                # intrinsic reward is non-episodic
                target_value_int = reward_int + gamma_int * np.max(Q_int_next)

                # extrinsic reward is episodic
                target_value_ext = reward_ext
                if not done:
                    target_value_ext += gamma_ext * np.max(Q_ext_next)

                # The targets alias (and mutate in place) the arrays returned by
                # sess.run; every other action keeps its current prediction.
                target_Q_ext = Q_ext  # only chosen action can have nonzero error
                target_Q_ext[
                    0,
                    action] = target_value_ext  # the first index is into the zeroth (and only) batch dimension

                target_Q_int = Q_int
                target_Q_int[0, action] = target_value_int

                # update all parameters to minimize combined loss
                self.sess.run(
                    self.update_op, {
                        self.inputs_ph: self.one_hot(state),
                        self.targets_ext_ph: target_Q_ext,
                        self.targets_int_ph: target_Q_int,
                        self.aux_inputs_ph: whitened_state
                    })

    # returns one-hot representation of the given state
    def one_hot(self, state):
        """Return a (1, num_states) one-hot array for ``state``.

        The hot index is ``state[0] * env.width + state[1]``. This presumably
        treats state[0] as the row (< height) and state[1] as the column
        (< width) -- confirm against the environment.
        """
        one_hot_state = np.zeros(self.num_states)
        one_hot_state[state[0] * self.env.width + state[1]] = 1
        return np.array([one_hot_state])  # add batch dimension of length 1

    # initialize intrinsic reward and observation statistics with experience from a random agent
    def initialize_stats(self):
        """Seed ``obs_stats`` and ``int_reward_stats`` from a random-action agent.

        Random episodes are played until no element of ``obs_stats.var`` is 0.
        Every collected observation is then whitened with the final statistics
        and its intrinsic reward is added to ``int_reward_stats``.

        NOTE(review): assumes ``reset()`` leaves ``var`` at 0 elementwise so the
        loop runs at least once, and that the env accepts further steps after a
        ``done`` (no reset is issued here).
        """
        # reset statistics to zero
        self.int_reward_stats.reset()
        self.obs_stats.reset()

        # initialize observation stats first
        obs_list = []
        while np.any(
                self.obs_stats.var == 0
        ):  # act for as many episodes as it takes for every state to have nonzero variance
            done = False
            while not done:
                action = self.env.action_space.sample(
                )  # take an action to give us another observation
                next_state, _, done, _ = self.env.step(
                    action)  # use resulting next state as observation
                observation = self.one_hot(
                    next_state
                )  #TODO: change this (and all one_hot calls) when partially observing
                self.obs_stats.update(observation)  # update obs stats
                obs_list.append(
                    observation
                )  # save to list for use with intrinsic reward stats

        # use observations to create whitened states for input to the aux networks
        for observation in obs_list:
            whitened_state = (observation - self.obs_stats.mean) / np.sqrt(
                self.obs_stats.var
            )  # whitened observation for pred and target nets
            whitened_state = np.clip(whitened_state, -5, 5)
            reward_int = self.sess.run(
                self.aux_loss,
                {self.aux_inputs_ph: whitened_state})  # get intrinsic reward
            self.int_reward_stats.update(reward_int)  # update int_reward stats
Exemple #10
0
from welford import Welford
import numpy as np
from os import listdir
from os.path import isfile, join
import cv2

if __name__ == '__main__':
    # Per-channel mean / sample std of every RGB pixel in the PNG images of
    # imgs_dir, with channel values scaled to [0, 1].
    w = Welford()
    imgs_dir = '../data_dm_overlapping/imgs/'
    n_images = 0

    for img_name in listdir(imgs_dir):
        img_path = join(imgs_dir, img_name)

        if not img_name.endswith('.png') or not isfile(img_path):
            continue

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for unreadable/corrupt files instead of
            # raising; cvtColor would otherwise fail on it.
            print('Skipping unreadable image: ', img_path)
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # One (3,) sample per pixel, in row-major pixel order, added in a
        # single call instead of a Python loop over every pixel.
        w.add_all(img.reshape(-1, 3) / 255)
        n_images += 1

    if n_images == 0:
        # The accumulator has no data, so mean/var_s are undefined.
        print('No readable .png images found in', imgs_dir)
    else:
        print('Mean: ', w.mean)
        print('Std: ', np.sqrt(w.var_s))
    def __init__(
        self,
        envs,
        acmodel,
        autoencoder,
        autoencoder_opt,
        uncertainty,
        noisy_tv,
        curiosity,
        randomise_env,
        uncertainty_budget,
        environment_seed,
        reward_weighting,
        normalise_rewards,
        device=None,
        num_frames_per_proc=None,
        discount=0.99,
        lr=0.001,
        gae_lambda=0.95,
        entropy_coef=0.01,
        value_loss_coef=0.5,
        max_grad_norm=0.5,
        recurrence=4,
        adam_eps=1e-8,
        clip_eps=0.2,
        epochs=4,
        batch_size=256,
        preprocess_obss=None,
        reshape_reward=None,
    ):
        """Configure PPO together with the curiosity / noisy-TV extensions.

        The generic rollout arguments are forwarded to the parent constructor:
        ``envs``, ``acmodel``, ``device``, ``num_frames_per_proc`` (128 when
        falsy), ``discount``, ``lr``, ``gae_lambda``, ``entropy_coef``,
        ``value_loss_coef``, ``max_grad_norm``, ``recurrence``,
        ``preprocess_obss`` and ``reshape_reward``.

        Args:
            autoencoder, autoencoder_opt, uncertainty: passed to ``ICM``
                together with ``device`` and ``self.preprocess_obss``.
            noisy_tv: stored and passed to ``NoisyTVWrapper``, which wraps
                ``self.env``.
            curiosity, reward_weighting, normalise_rewards: stored on ``self``.
            clip_eps, epochs, batch_size: PPO update settings, stored on
                ``self``.
            adam_eps: epsilon of the Adam optimiser over ``acmodel``'s
                parameters.
            randomise_env, uncertainty_budget, environment_seed: accepted but
                not used by this constructor.

        Raises:
            ValueError: if ``batch_size`` is not a multiple of ``recurrence``.
        """
        num_frames_per_proc = num_frames_per_proc or 128

        super().__init__(
            envs,
            acmodel,
            device,
            num_frames_per_proc,
            discount,
            lr,
            gae_lambda,
            entropy_coef,
            value_loss_coef,
            max_grad_norm,
            recurrence,
            preprocess_obss,
            reshape_reward,
        )

        # Shape of the per-frame, per-process rollout buffers.
        shape = (self.num_frames_per_proc, self.num_procs)
        self.clip_eps = clip_eps
        self.epochs = epochs
        self.batch_size = batch_size
        self.noisy_tv = noisy_tv
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self.batch_size % self.recurrence != 0:
            raise ValueError(
                f"batch_size ({self.batch_size}) must be a multiple of "
                f"recurrence ({self.recurrence})")

        self.optimizer = torch.optim.Adam(self.acmodel.parameters(),
                                          lr,
                                          eps=adam_eps)
        self.batch_num = 0
        self.icm = ICM(
            autoencoder,
            autoencoder_opt,
            uncertainty,
            device,
            self.preprocess_obss,
        )
        # (width, height) grid of visit counts, sized from the first env.
        self.visitation_counts = np.zeros(
            (self.env.envs[0].width, self.env.envs[0].height))
        self.reward_weighting = reward_weighting
        self.curiosity = curiosity
        self.intrinsic_rewards = torch.zeros(*shape, device=self.device)
        self.uncertainties = torch.zeros(*shape, device=self.device)
        self.novel_states_visited = torch.zeros(*shape, device=self.device)
        self.normalise_rewards = normalise_rewards
        self.intrinsic_reward_buffer = []
        self.action_stats_logger = ActionStatsLogger(
            self.env.envs[0].action_space.n)
        # Running-variance accumulator (presumably for intrinsic-reward
        # normalisation -- confirm at the call sites).
        self.online_variance = Welford()
        # self.env is wrapped only after all the self.env.envs lookups above.
        self.env = NoisyTVWrapper(self.env, self.noisy_tv)
class PPOAlgo(BaseAlgo):
    """The Proximal Policy Optimization algorithm
    ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347))."""
    def __init__(
        self,
        envs,
        acmodel,
        autoencoder,
        autoencoder_opt,
        uncertainty,
        noisy_tv,
        curiosity,
        randomise_env,
        uncertainty_budget,
        environment_seed,
        reward_weighting,
        normalise_rewards,
        device=None,
        num_frames_per_proc=None,
        discount=0.99,
        lr=0.001,
        gae_lambda=0.95,
        entropy_coef=0.01,
        value_loss_coef=0.5,
        max_grad_norm=0.5,
        recurrence=4,
        adam_eps=1e-8,
        clip_eps=0.2,
        epochs=4,
        batch_size=256,
        preprocess_obss=None,
        reshape_reward=None,
    ):
        """Configure PPO together with the curiosity / noisy-TV extensions.

        The generic rollout arguments are forwarded to ``BaseAlgo.__init__``:
        ``envs``, ``acmodel``, ``device``, ``num_frames_per_proc`` (128 when
        falsy), ``discount``, ``lr``, ``gae_lambda``, ``entropy_coef``,
        ``value_loss_coef``, ``max_grad_norm``, ``recurrence``,
        ``preprocess_obss`` and ``reshape_reward``.

        Args:
            autoencoder, autoencoder_opt, uncertainty: passed to ``ICM``
                together with ``device`` and ``self.preprocess_obss``.
            noisy_tv: stored and passed to ``NoisyTVWrapper``, which wraps
                ``self.env``.
            curiosity: compared against the string ``"True"`` when collecting
                experiences to enable the intrinsic reward.
            reward_weighting: multiplier applied to the intrinsic reward.
            normalise_rewards: compared against ``"True"`` to enable dividing
                the intrinsic reward by its running standard deviation.
            clip_eps, epochs, batch_size: PPO update settings, stored on
                ``self``.
            adam_eps: epsilon of the Adam optimiser over ``acmodel``'s
                parameters.
            randomise_env, uncertainty_budget, environment_seed: accepted but
                not used by this constructor.

        Raises:
            ValueError: if ``batch_size`` is not a multiple of ``recurrence``.
        """
        num_frames_per_proc = num_frames_per_proc or 128

        super().__init__(
            envs,
            acmodel,
            device,
            num_frames_per_proc,
            discount,
            lr,
            gae_lambda,
            entropy_coef,
            value_loss_coef,
            max_grad_norm,
            recurrence,
            preprocess_obss,
            reshape_reward,
        )

        # Shape of the per-frame, per-process rollout buffers.
        shape = (self.num_frames_per_proc, self.num_procs)
        self.clip_eps = clip_eps
        self.epochs = epochs
        self.batch_size = batch_size
        self.noisy_tv = noisy_tv
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self.batch_size % self.recurrence != 0:
            raise ValueError(
                f"batch_size ({self.batch_size}) must be a multiple of "
                f"recurrence ({self.recurrence})")

        self.optimizer = torch.optim.Adam(self.acmodel.parameters(),
                                          lr,
                                          eps=adam_eps)
        self.batch_num = 0
        self.icm = ICM(
            autoencoder,
            autoencoder_opt,
            uncertainty,
            device,
            self.preprocess_obss,
        )
        # (width, height) grid of visit counts, sized from the first env and
        # indexed [agent_pos[0]][agent_pos[1]] by update_visitation_counts.
        self.visitation_counts = np.zeros(
            (self.env.envs[0].width, self.env.envs[0].height))
        self.reward_weighting = reward_weighting
        self.curiosity = curiosity
        self.intrinsic_rewards = torch.zeros(*shape, device=self.device)
        self.uncertainties = torch.zeros(*shape, device=self.device)
        self.novel_states_visited = torch.zeros(*shape, device=self.device)
        self.normalise_rewards = normalise_rewards
        self.intrinsic_reward_buffer = []
        self.action_stats_logger = ActionStatsLogger(
            self.env.envs[0].action_space.n)
        # Running variance of intrinsic rewards, used for normalisation.
        self.online_variance = Welford()
        # self.env is wrapped only after all the self.env.envs lookups above.
        self.env = NoisyTVWrapper(self.env, self.noisy_tv)

    def update_visitation_counts(self, envs):
        """
        Record one visit to the current agent cell of every environment.

        ``self.visitation_counts[agent_pos[0]][agent_pos[1]]`` is incremented
        once per environment in ``envs``. The number of non-zero entries is
        therefore the number of distinct cells visited so far, which
        ``collect_experiences`` logs as ``novel_states_visited``.
        """
        for env in envs:
            pos = env.agent_pos
            self.visitation_counts[pos[0]][pos[1]] += 1

    def collect_experiences(self):
        """Collects rollouts and computes advantages.

        Runs several environments concurrently. The next actions are computed
        in a batch mode for all environments at the same time. The rollouts
        and advantages from all environments are concatenated together.

        Returns
        -------
        exps : DictList
            Contains actions, rewards, advantages etc as attributes.
            Each attribute, e.g. `exps.reward` has a shape
            (self.num_frames_per_proc * num_envs, ...). k-th block
            of consecutive `self.num_frames_per_proc` frames contains
            data obtained from the k-th environment. Be careful not to mix
            data from different environments!
        logs : dict
            Useful stats about the training process, including the average
            reward, policy loss, value loss, etc.
        """
        # 16 threads running in parallel for 8 frames at a time before parameters
        # are updated, so gathers a total 128 frames
        loss = 0
        for i in range(self.num_frames_per_proc):
            preprocessed_obs = self.preprocess_obss(self.obs,
                                                    device=self.device)
            with torch.no_grad():
                if self.acmodel.recurrent:
                    dist, value, memory = self.acmodel(
                        preprocessed_obs, self.memory * self.mask.unsqueeze(1))
                else:
                    dist, value = self.acmodel(preprocessed_obs)
            action = dist.sample()
            obs, extrinsic_reward, done, _ = self.env.step(action)
            reward = extrinsic_reward
            self.update_visitation_counts(self.env.envs)
            self.obss[i] = self.obs
            self.obs = obs
            if self.curiosity == "True":
                (
                    loss,
                    intrinsic_reward,
                    uncertainty,
                ) = self.icm.compute_intrinsic_rewards(self.obss[i], self.obs,
                                                       action)
                if self.normalise_rewards == "True":
                    intrinsic_reward_numpy = intrinsic_reward.detach().cpu(
                    ).numpy()
                    self.online_variance.add_all(intrinsic_reward_numpy)
                    intrinsic_reward /= np.sqrt(self.online_variance.var_s)
                intrinsic_reward *= self.reward_weighting
                reward = intrinsic_reward + torch.tensor(
                    reward, dtype=torch.float).to(self.device)
                loss = torch.sum(loss)
                self.intrinsic_reward_buffer.append(intrinsic_reward)
                self.action_stats_logger.add_to_log_dicts(
                    action.detach().numpy(),
                    intrinsic_reward.detach().numpy())
                self.icm.update_curiosity_parameters(loss)
            if self.acmodel.recurrent:
                self.memories[i] = self.memory
                self.memory = memory
            self.masks[i] = self.mask
            self.mask = 1 - torch.tensor(
                done, device=self.device, dtype=torch.float)
            self.actions[i] = action
            self.values[i] = value
            if self.curiosity == "True":
                self.uncertainties[i] = uncertainty
                self.intrinsic_rewards[i] = intrinsic_reward
            else:
                self.uncertainties[i] = torch.zeros_like(action)
                self.intrinsic_rewards[i] = torch.zeros_like(action)
            self.novel_states_visited[i] = np.count_nonzero(
                self.visitation_counts)
            self.rewards[i] = torch.tensor(reward, device=self.device)
            self.log_probs[i] = dist.log_prob(action)

            # Update log values

            self.log_episode_return += torch.tensor(reward,
                                                    device=self.device,
                                                    dtype=torch.float)
            self.log_episode_reshaped_return += self.rewards[i]
            self.log_episode_num_frames += torch.ones(self.num_procs,
                                                      device=self.device)

            for i, done_ in enumerate(done):
                if done_:
                    self.log_done_counter += 1
                    self.log_return.append(self.log_episode_return[i].item())
                    self.log_reshaped_return.append(
                        self.log_episode_reshaped_return[i].item())
                    self.log_num_frames.append(
                        self.log_episode_num_frames[i].item())

            self.log_episode_return *= self.mask
            self.log_episode_reshaped_return *= self.mask
            self.log_episode_num_frames *= self.mask

        # Add advantage and return to experiences

        preprocessed_obs = self.preprocess_obss(self.obs, device=self.device)
        with torch.no_grad():
            if self.acmodel.recurrent:
                _, next_value, _ = self.acmodel(
                    preprocessed_obs, self.memory * self.mask.unsqueeze(1))
            else:
                _, next_value = self.acmodel(preprocessed_obs)

        for i in reversed(range(self.num_frames_per_proc)):
            next_mask = (self.masks[i + 1]
                         if i < self.num_frames_per_proc - 1 else self.mask)
            next_value = (self.values[i + 1]
                          if i < self.num_frames_per_proc - 1 else next_value)
            next_advantage = (self.advantages[i + 1]
                              if i < self.num_frames_per_proc - 1 else 0)

            delta = (self.rewards[i] + self.discount * next_value * next_mask -
                     self.values[i])
            self.advantages[i] = (
                delta +
                self.discount * self.gae_lambda * next_advantage * next_mask)

        # Define experiences:
        #   the whole experience is the concatenation of the experience
        #   of each process.
        # In comments below:
        #   - T is self.num_frames_per_proc,
        #   - P is self.num_procs,
        #   - D is the dimensionality.

        exps = DictList()
        exps.obs = [
            self.obss[i][j] for j in range(self.num_procs)
            for i in range(self.num_frames_per_proc)
        ]
        if self.acmodel.recurrent:
            # T x P x D -> P x T x D -> (P * T) x D
            exps.memory = self.memories.transpose(0, 1).reshape(
                -1, *self.memories.shape[2:])
            # T x P -> P x T -> (P * T) x 1
            exps.mask = self.masks.transpose(0, 1).reshape(-1).unsqueeze(1)
        # for all tensors below, T x P -> P x T -> P * T
        exps.action = self.actions.transpose(0, 1).reshape(-1)
        exps.value = self.values.transpose(0, 1).reshape(-1)
        exps.reward = self.rewards.transpose(0, 1).reshape(-1)
        intrinsic_rewards = self.intrinsic_rewards.transpose(0, 1).reshape(-1)
        uncertainties = self.uncertainties.transpose(0, 1).reshape(-1)
        novel_states_visited = self.novel_states_visited.transpose(
            0, 1).reshape(-1)
        exps.advantage = self.advantages.transpose(0, 1).reshape(-1)
        exps.returnn = exps.value + exps.advantage
        exps.log_prob = self.log_probs.transpose(0, 1).reshape(-1)

        # Preprocess experiences

        exps.obs = self.preprocess_obss(exps.obs, device=self.device)

        # Log some values

        keep = max(self.log_done_counter, self.num_procs)

        logs = {
            "uncertainties": uncertainties,
            "intrinsic_rewards": intrinsic_rewards,
            "novel_states_visited": novel_states_visited,
            "return_per_episode": self.log_return[-keep:],
            "reshaped_return_per_episode": self.log_reshaped_return[-keep:],
            "num_frames_per_episode": self.log_num_frames[-keep:],
            "num_frames": self.num_frames,
        }

        self.log_done_counter = 0
        self.log_return = self.log_return[-self.num_procs:]
        self.log_reshaped_return = self.log_reshaped_return[-self.num_procs:]
        self.log_num_frames = self.log_num_frames[-self.num_procs:]

        return exps, logs

    def update_parameters(self, exps):
        """Run the PPO optimization epochs over a batch of collected experiences.

        For each of ``self.epochs`` epochs, the experiences are split into
        mini-batches by ``self._get_batches_starting_indexes()``. For every
        mini-batch, the clipped surrogate policy loss, the clipped value loss
        and the entropy bonus are accumulated over ``self.recurrence``
        consecutive steps and averaged. The result drives one optimizer step,
        with gradients clipped to ``self.max_grad_norm``.

        Parameters
        ----------
        exps : DictList
            Flattened experiences as built by ``collect_experiences``. Fields
            read: ``obs``, ``action``, ``value``, ``advantage``, ``returnn``,
            ``log_prob``, plus ``memory`` and ``mask`` for recurrent models.
            For recurrent models ``exps.memory`` is overwritten in place with
            refreshed (detached) memories.

        Returns
        -------
        dict
            Means of entropy, value, policy loss, value loss and pre-clipping
            gradient norm over every mini-batch of every epoch. Each value is
            NaN when no mini-batch was processed (e.g. ``self.epochs == 0``).
        """
        # Initialized once, outside the epoch loop, so that the returned
        # statistics cover the whole update rather than only the final epoch,
        # and so the names exist even when no epoch runs.
        log_entropies = []
        log_values = []
        log_policy_losses = []
        log_value_losses = []
        log_grad_norms = []

        for _ in range(self.epochs):
            for inds in self._get_batches_starting_indexes():
                # Initialize batch values

                batch_entropy = 0
                batch_value = 0
                batch_policy_loss = 0
                batch_value_loss = 0
                batch_loss = 0

                # Recurrent models start each mini-batch from the memory stored
                # at its starting indexes.
                if self.acmodel.recurrent:
                    memory = exps.memory[inds]

                for i in range(self.recurrence):
                    # Sub-batch of experience i steps after each starting index.
                    sb = exps[inds + i]

                    if self.acmodel.recurrent:
                        dist, value, memory = self.acmodel(
                            sb.obs, memory * sb.mask)
                    else:
                        dist, value = self.acmodel(sb.obs)

                    entropy = dist.entropy().mean()

                    # Clipped surrogate policy objective (ratio of new to old
                    # action probabilities, clipped to [1 - eps, 1 + eps]).
                    ratio = torch.exp(dist.log_prob(sb.action) - sb.log_prob)
                    surr1 = ratio * sb.advantage
                    surr2 = (torch.clamp(ratio, 1.0 - self.clip_eps,
                                         1.0 + self.clip_eps) * sb.advantage)
                    policy_loss = -torch.min(surr1, surr2).mean()

                    # Value loss: the larger of the unclipped error and the
                    # error of a value prediction kept within clip_eps of the
                    # old value.
                    value_clipped = sb.value + torch.clamp(
                        value - sb.value, -self.clip_eps, self.clip_eps)
                    surr1 = (value - sb.returnn).pow(2)
                    surr2 = (value_clipped - sb.returnn).pow(2)
                    value_loss = torch.max(surr1, surr2).mean()

                    loss = (policy_loss - self.entropy_coef * entropy +
                            self.value_loss_coef * value_loss)

                    batch_entropy += entropy.item()
                    batch_value += value.mean().item()
                    batch_policy_loss += policy_loss.item()
                    batch_value_loss += value_loss.item()
                    batch_loss += loss

                    # Store the refreshed memory at the next step's indexes so
                    # later mini-batches/epochs starting there reuse it.
                    if self.acmodel.recurrent and i < self.recurrence - 1:
                        exps.memory[inds + i + 1] = memory.detach()

                # Average the accumulated values over the recurrence steps.
                batch_entropy /= self.recurrence
                batch_value /= self.recurrence
                batch_policy_loss /= self.recurrence
                batch_value_loss /= self.recurrence
                batch_loss /= self.recurrence

                # Update actor-critic
                self.optimizer.zero_grad()
                batch_loss.backward()
                # clip_grad_norm_ returns the total gradient norm measured
                # before clipping and ignores parameters whose grad is None
                # (the previous manual loop crashed on such parameters).
                # float() accepts both a 0-d tensor and a plain float return.
                grad_norm = float(
                    torch.nn.utils.clip_grad_norm_(self.acmodel.parameters(),
                                                   self.max_grad_norm))
                self.optimizer.step()

                log_entropies.append(batch_entropy)
                log_values.append(batch_value)
                log_policy_losses.append(batch_policy_loss)
                log_value_losses.append(batch_value_loss)
                log_grad_norms.append(grad_norm)

        def _mean(values):
            # Avoid numpy's "mean of empty slice" warning when nothing ran.
            return numpy.mean(values) if values else float("nan")

        logs = {
            "entropy": _mean(log_entropies),
            "value": _mean(log_values),
            "policy_loss": _mean(log_policy_losses),
            "value_loss": _mean(log_value_losses),
            "grad_norm": _mean(log_grad_norms),
        }

        return logs

    def _get_batches_starting_indexes(self):
        """Gives, for each batch, the indexes of the observations given to
        the model and the experiences used to compute the loss at first.

        First, the indexes are the integers from 0 to `self.num_frames` with a step of
        `self.recurrence`, shifted by `self.recurrence//2` one time in two for having
        more diverse batches. Then, the indexes are splited into the different batches.

        Returns
        -------
        batches_starting_indexes : list of list of int
            the indexes of the experiences to be used at first for each batch
        """

        indexes = numpy.arange(0, self.num_frames, self.recurrence)
        indexes = numpy.random.permutation(indexes)

        # Shift starting indexes by self.recurrence//2 half the time
        if self.batch_num % 2 == 1:
            indexes = indexes[(indexes + self.recurrence) %
                              self.num_frames_per_proc != 0]
            indexes += self.recurrence // 2
        self.batch_num += 1

        num_indexes = self.batch_size // self.recurrence
        batches_starting_indexes = [
            indexes[i:i + num_indexes]
            for i in range(0, len(indexes), num_indexes)
        ]

        return batches_starting_indexes
        else:
            mn, st = normalize
            samples = (np.stack([
                spectrogram_base(
                    samp,
                    nperseg,
                    noverlap,
                    2**logfft
                )[3]
                for samp in samples
            ]) - mn) / st
        
        yield (samples[:, :t_in, :], samples[:, t_in:, :])
        
        
# Accumulate running statistics over spectrogram rows produced by
# data_generator; mn / st are presumably used later to normalize spectrograms
# (TODO confirm against the normalize argument of data_generator).
welford_from_generator = Welford()
w_epochs = 5
w_batch_size = 5000
# 5 * (t_in + t_out) * 5000 ~= 2.5M
for x, y in data_generator(w_batch_size, 1, w_epochs):
    # Both the input window (x) and the target window (y) feed the statistics,
    # x first, then y, one sample at a time.
    for _window in (x, y):
        for spects in _window:
            welford_from_generator.add_all(spects)

# Mean and population standard deviation (sqrt of var_p), each reshaped to a
# single row.
mn = welford_from_generator.mean.reshape((1, -1))
st = np.sqrt(welford_from_generator.var_p).reshape((1, -1))


#---------------------
#---------------------