Example No. 1
    def __init__(self, horizon, obs_dim, action_dim):
        self._T = horizon
        Transition.__init__(self, obs_dim=obs_dim, action_dim=action_dim)

        self._serializable_initialized = False
        Serializable.quick_init(self, locals())
        super(TVLGDynamics, self).__init__()

        # Time-varying linear-Gaussian dynamics parameters:
        #   x_{t+1} ~ N(Fm[t] @ [x_t; u_t] + fv[t], dyn_covar[t])
        self.Fm = nn.Parameter(
            ptu.zeros(horizon, obs_dim, obs_dim + action_dim))
        self.fv = nn.Parameter(ptu.ones(horizon, obs_dim))
        self.dyn_covar = nn.Parameter(ptu.zeros(horizon, obs_dim, obs_dim))

        # Prior
        self._prior = None
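
The three nn.Parameter tensors above parametrize a time-varying linear-Gaussian (TVLG) model, x_{t+1} ~ N(Fm[t] [x_t; u_t] + fv[t], dyn_covar[t]). Below is a minimal sketch in plain PyTorch (not the project's ptu helpers; shapes are illustrative) of how such parameters yield a one-step prediction.

import torch

horizon, obs_dim, action_dim = 5, 3, 2
Fm = torch.zeros(horizon, obs_dim, obs_dim + action_dim)               # linear term
fv = torch.zeros(horizon, obs_dim)                                     # affine term
dyn_covar = torch.eye(obs_dim).expand(horizon, obs_dim, obs_dim)       # noise covariance

def predict_next_state(t, x, u):
    # Mean and covariance of x_{t+1} given (x_t, u_t) at time t
    xu = torch.cat((x, u))
    return Fm[t] @ xu + fv[t], dyn_covar[t]

mean, cov = predict_next_state(0, torch.randn(obs_dim), torch.randn(action_dim))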
Example No. 2
    def fit(self, States, Actions, regularization=1e-6):
        """ Fit dynamics. """
        N, T, dS = States.shape
        dA = Actions.shape[2]

        if N == 1:
            raise ValueError("Cannot fit dynamics on 1 sample")

        it = slice(dS + dA)  # indices of the [x_t; u_t] block

        # Fit dynamics with least squares regression.
        dwts = (1.0 / N) * ptu.ones(N)

        for t in range(T - 1):
            Ys = torch.cat(
                (States[:, t, :], Actions[:, t, :], States[:, t + 1, :]),
                dim=-1)

            # Obtain Normal-inverse-Wishart prior.
            mu0, Phi, mm, n0 = self._prior.eval(dS, dA, Ys)
            sig_reg = ptu.zeros((dS + dA + dS, dS + dA + dS))
            sig_reg[it, it] = regularization

            Fm, fv, dyn_covar = \
                gauss_fit_joint_prior(Ys, mu0, Phi, mm, n0,
                                      dwts, dS+dA, dS, sig_reg)
            self.Fm[t, :, :] = Fm
            self.fv[t, :] = fv
            self.dyn_covar[t, :, :] = dyn_covar
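
The heavy lifting above is done by gauss_fit_joint_prior, which folds the Normal-inverse-Wishart prior into the empirical moments of Ys before extracting Fm, fv and dyn_covar. As a rough, prior-free reference, the sketch below (plain PyTorch, not the project's implementation) performs the same per-timestep regression with ordinary least squares plus a small ridge term and an empirical residual covariance.

import torch

def fit_tvlg_step(X_t, U_t, X_t1, reg=1e-6):
    # X_t: (N, dS), U_t: (N, dA), X_t1: (N, dS); returns Fm, fv, covar
    N = X_t.shape[0]
    XU = torch.cat((X_t, U_t), dim=1)                   # (N, dS+dA)
    design = torch.cat((XU, torch.ones(N, 1)), dim=1)   # append bias column
    # Regularized normal equations: (D^T D + reg*I) W = D^T X_{t+1}
    A = design.T @ design + reg * torch.eye(design.shape[1])
    B = design.T @ X_t1
    W = torch.linalg.solve(A, B)                        # (dS+dA+1, dS)
    Fm = W[:-1].T                                       # (dS, dS+dA)
    fv = W[-1]                                          # (dS,)
    resid = X_t1 - design @ W
    covar = resid.T @ resid / N                         # (dS, dS)
    return Fm, fv, covar

N, dS, dA = 20, 4, 2
Fm, fv, covar = fit_tvlg_step(torch.randn(N, dS), torch.randn(N, dA),
                              torch.randn(N, dS))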
Example No. 3
    def __init__(self,
                 obs_dim,
                 action_dim,
                 hidden_sizes,
                 std=None,
                 hidden_w_init='xavier_normal',
                 hidden_b_init_val=0,
                 output_w_init='xavier_normal',
                 output_b_init_val=0,
                 **kwargs):
        """

        Args:
            obs_dim:
            action_dim:
            hidden_sizes:
            std:
            hidden_w_init:
            hidden_b_init_val:
            output_w_init:
            output_b_init_val:
            **kwargs:
        """
        self.save_init_params(locals())
        super(TanhGaussianPolicy,
              self).__init__(hidden_sizes,
                             input_size=obs_dim,
                             output_size=action_dim,
                             hidden_w_init=hidden_w_init,
                             hidden_b_init_val=hidden_b_init_val,
                             output_w_init=output_w_init,
                             output_b_init_val=output_b_init_val,
                             **kwargs)
        ExplorationPolicy.__init__(self, action_dim)

        self.log_std = None
        self.std = std
        if std is None:
            last_hidden_size = obs_dim
            if len(hidden_sizes) > 0:
                last_hidden_size = hidden_sizes[-1]
            self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim)
            ptu.layer_init(layer=self.last_fc_log_std,
                           option=output_w_init,
                           activation='linear',
                           b=output_b_init_val)
        else:
            self.log_std = math.log(std)
            assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX

        self._normal_dist = Normal(loc=ptu.zeros(action_dim),
                                   scale=ptu.ones(action_dim))
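
At sampling time, a policy of this kind draws a Gaussian pre-activation, squashes it with tanh, and corrects the log-probability for the change of variables. A minimal, self-contained sketch in plain PyTorch (this is the standard tanh-Gaussian recipe, not this class's actual forward method):

import torch
from torch.distributions import Normal

def sample_tanh_gaussian(mean, log_std, epsilon=1e-6):
    std = log_std.exp()
    normal = Normal(mean, std)
    z = normal.rsample()                   # reparametrized Gaussian sample
    action = torch.tanh(z)                 # squash into (-1, 1)
    # log pi(a) = log N(z) - sum_i log(1 - tanh(z_i)^2)
    log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
    return action, log_prob.sum(dim=-1)

mean = torch.zeros(2, 3)                   # batch of 2, action_dim = 3
log_std = torch.full((2, 3), -0.5)
action, log_prob = sample_tanh_gaussian(mean, log_std)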
Example No. 4
    def _update_global_policy(self):
        """
        Computes(updates) a new global policy.
        :return:
        """
        dU, dO, T = self.dU, self.dO, self.T
        # Compute target mean, cov(precision), and weight for each sample;
        # and concatenate them.
        obs_data, tgt_mu = ptu.zeros((0, T, dO)), ptu.zeros((0, T, dU))
        tgt_prc, tgt_wt = ptu.zeros((0, T, dU, dU)), ptu.zeros((0, T))
        for m in range(self.M):
            samples = self.cur[m].sample_list
            X = samples['observations']
            N = len(samples)
            traj = self.new_traj_distr[m]
            pol_info = self.cur[m].pol_info
            mu = ptu.zeros((N, T, dU))
            prc = ptu.zeros((N, T, dU, dU))
            wt = ptu.zeros((N, T))
            obs = ptu.FloatTensor(samples['observations'])
            # Get time-indexed actions.
            for t in range(T):
                # Compute actions along this trajectory.
                prc[:, t, :, :] = ptu.FloatTensor(
                    np.tile(traj.inv_pol_covar[t, :, :], [N, 1, 1]))
                for i in range(N):
                    mu[i, t, :] = ptu.FloatTensor(
                        traj.K[t, :, :].dot(X[i, t, :]) + traj.k[t, :])
                wt[:, t] = pol_info.pol_wt[t]

            tgt_mu = torch.cat((tgt_mu, mu))
            tgt_prc = torch.cat((tgt_prc, prc))
            tgt_wt = torch.cat((tgt_wt, wt))
            obs_data = torch.cat((obs_data, obs))

        self.global_policy_optimization(obs_data, tgt_mu, tgt_prc, tgt_wt)
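
The inner loop above evaluates the local time-varying linear-Gaussian controller u_t = K[t] x_t + k[t] once per sample. As a side note, that per-sample loop is equivalent to one matrix product per time step; the sketch below (plain PyTorch/NumPy with made-up shapes, not the code used here) shows the vectorized form.

import numpy as np
import torch

N, T, dO, dU = 4, 6, 5, 2
X = np.random.randn(N, T, dO)     # sampled observations
K = np.random.randn(T, dU, dO)    # time-indexed feedback gains
k = np.random.randn(T, dU)        # time-indexed offsets

mu = torch.zeros(N, T, dU)
for t in range(T):
    # Same as: for i in range(N): mu[i, t] = K[t] @ X[i, t] + k[t]
    mu[:, t, :] = torch.from_numpy(X[:, t, :] @ K[t].T + k[t]).float()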
Example No. 5
    def __init__(self,
                 obs_dim,
                 action_dim,
                 n_policies,
                 latent_dim,
                 shared_hidden_sizes=None,
                 unshared_hidden_sizes=None,
                 unshared_mix_hidden_sizes=None,
                 unshared_policy_hidden_sizes=None,
                 stds=None,
                 hidden_activation='relu',
                 hidden_w_init='xavier_normal',
                 hidden_b_init_val=1e-2,
                 output_w_init='xavier_normal',
                 output_b_init_val=1e-2,
                 pol_output_activation='linear',
                 mix_output_activation='linear',
                 final_pol_output_activation='linear',
                 input_norm=False,
                 shared_layer_norm=False,
                 policies_layer_norm=False,
                 mixture_layer_norm=False,
                 final_policy_layer_norm=False,
                 epsilon=1e-6,
                 softmax_weights=False,
                 **kwargs):
        self.save_init_params(locals())
        super(TanhGaussianComposedMultiPolicy, self).__init__()
        ExplorationPolicy.__init__(self, action_dim)

        self._input_size = obs_dim
        self._output_sizes = action_dim
        self._n_subpolicies = n_policies
        self._latent_size = latent_dim
        # Activation Fcns
        self._hidden_activation = ptu.get_activation(hidden_activation)
        self._pol_output_activation = ptu.get_activation(pol_output_activation)
        self._mix_output_activation = ptu.get_activation(mix_output_activation)
        self._final_pol_output_activation = ptu.get_activation(
            final_pol_output_activation)
        # Normalization Layer Flags
        self._shared_layer_norm = shared_layer_norm
        self._policies_layer_norm = policies_layer_norm
        self._mixture_layer_norm = mixture_layer_norm
        self._final_policy_layer_norm = final_policy_layer_norm
        # Layers Lists
        self._sfcs = []  # Shared Layers
        self._sfc_norms = []  # Norm. Shared Layers
        self._pfcs = [list()
                      for _ in range(self._n_subpolicies)]  # Policies Layers
        self._pfc_norms = [list()
                           for _ in range(self._n_subpolicies)]  # N. Pol. L.
        self._pfc_lasts = []  # Last Policies Layers
        self._mfcs = []  # Mixing Layers
        self._norm_mfcs = []  # Norm. Mixing Layers
        # self.mfc_last = None  # Below is instantiated
        self._fpfcs = []  # Final Policy Layers
        self._norm_fpfcs = []  # Norm. Final Policy Layers

        self._softmax_weights = softmax_weights

        # Initial size = Obs size
        in_size = self._input_size

        # Ordered Dictionaries for specific modules/parameters
        self._shared_modules = OrderedDict()
        self._shared_parameters = OrderedDict()
        self._policies_modules = [OrderedDict() for _ in range(n_policies)]
        self._policies_parameters = [OrderedDict() for _ in range(n_policies)]
        self._mixing_modules = OrderedDict()
        self._mixing_parameters = OrderedDict()
        self._final_policy_modules = OrderedDict()
        self._final_policy_parameters = OrderedDict()

        # ############# #
        # Shared Layers #
        # ############# #
        if input_norm:
            ln = nn.BatchNorm1d(in_size)
            self.sfc_input = ln
            self.add_shared_module("sfc_input", ln)
            self.__setattr__("sfc_input", ln)
        else:
            self.sfc_input = None

        if shared_hidden_sizes is not None:
            for ii, next_size in enumerate(shared_hidden_sizes):
                sfc = nn.Linear(in_size, next_size)
                ptu.layer_init(
                    layer=sfc,
                    option=hidden_w_init,
                    activation=hidden_activation,
                    b=hidden_b_init_val,
                )
                self.__setattr__("sfc{}".format(ii), sfc)
                self._sfcs.append(sfc)
                self.add_shared_module("sfc{}".format(ii), sfc)

                if self._shared_layer_norm:
                    ln = LayerNorm(next_size)
                    # ln = nn.BatchNorm1d(next_size)
                    self.__setattr__("sfc{}_norm".format(ii), ln)
                    self._sfc_norms.append(ln)
                    self.add_shared_module("sfc{}_norm".format(ii), ln)
                in_size = next_size

        # Get the output_size of the shared layers (assume same for all)
        multipol_in_size = in_size

        # ############### #
        # Unshared Layers #
        # ############### #
        # Unshared Multi-Policy Hidden Layers
        if unshared_hidden_sizes is not None:
            for ii, next_size in enumerate(unshared_hidden_sizes):
                for pol_idx in range(self._n_subpolicies):
                    pfc = nn.Linear(multipol_in_size, next_size)
                    ptu.layer_init(layer=pfc,
                                   option=hidden_w_init,
                                   activation=hidden_activation,
                                   b=hidden_b_init_val)
                    self.__setattr__("pfc{}_{}".format(pol_idx, ii), pfc)
                    self._pfcs[pol_idx].append(pfc)
                    self.add_policies_module("pfc{}_{}".format(pol_idx, ii),
                                             pfc,
                                             idx=pol_idx)

                    if self._policies_layer_norm:
                        ln = LayerNorm(next_size)
                        # ln = nn.BatchNorm1d(next_size)
                        self.__setattr__("pfc{}_{}_norm".format(pol_idx, ii),
                                         ln)
                        self._pfc_norms[pol_idx].append(ln)
                        self.add_policies_module("pfc{}_{}_norm".format(
                            pol_idx, ii),
                                                 ln,
                                                 idx=pol_idx)
                multipol_in_size = next_size

        # Multi-Policy Last Layers
        for pol_idx in range(self._n_subpolicies):
            last_pfc = nn.Linear(multipol_in_size, latent_dim)
            ptu.layer_init(layer=last_pfc,
                           option=output_w_init,
                           activation=pol_output_activation,
                           b=output_b_init_val)
            self.__setattr__("pfc{}_last".format(pol_idx), last_pfc)
            self._pfc_lasts.append(last_pfc)
            self.add_policies_module("pfc{}_last".format(pol_idx),
                                     last_pfc,
                                     idx=pol_idx)

        # ############# #
        # Mixing Layers #
        # ############# #
        mixture_in_size = in_size + latent_dim * self._n_subpolicies
        # Unshared Mixing-Weights Hidden Layers
        if unshared_mix_hidden_sizes is not None:
            for ii, next_size in enumerate(unshared_mix_hidden_sizes):
                mfc = nn.Linear(mixture_in_size, next_size)
                ptu.layer_init(
                    layer=mfc,
                    option=hidden_w_init,
                    activation=hidden_activation,
                    b=hidden_b_init_val,
                )
                self.__setattr__("mfc{}".format(ii), mfc)
                self._mfcs.append(mfc)
                # Add it to specific dictionaries
                self.add_mixing_module("mfc{}".format(ii), mfc)

                if self._mixture_layer_norm:
                    ln = LayerNorm(next_size)
                    # ln = nn.BatchNorm1d(next_size)
                    self.__setattr__("mfc{}_norm".format(ii), ln)
                    self._norm_mfcs.append(ln)
                    self.add_mixing_module("mfc{}_norm".format(ii), ln)
                mixture_in_size = next_size

        # Unshared Mixing-Weights Last Layers
        mfc_last = nn.Linear(mixture_in_size, latent_dim)
        ptu.layer_init(
            layer=mfc_last,
            option=output_w_init,
            activation=mix_output_activation,
            b=output_b_init_val,
        )
        self.__setattr__("mfc_last", mfc_last)
        self.mfc_last = mfc_last
        # Add it to specific dictionaries
        self.add_mixing_module("mfc_last", mfc_last)

        if softmax_weights:
            raise ValueError("Check whether a softmax is correct here")
            # self.mfc_softmax = nn.Softmax(dim=1)
        else:
            self.mfc_softmax = None

        # ################### #
        # Final Policy Layers #
        # ################### #
        final_pol_in_size = latent_dim
        if unshared_policy_hidden_sizes is not None:
            for ii, next_size in enumerate(unshared_policy_hidden_sizes):
                fpfc = nn.Linear(final_pol_in_size, next_size)
                ptu.layer_init(layer=fpfc,
                               option=hidden_w_init,
                               activation=hidden_activation,
                               b=hidden_b_init_val)
                self.__setattr__("fpfc{}".format(ii), fpfc)
                self._fpfcs.append(fpfc)
                # Add it to specific dictionaries
                self.add_final_policy_module("fpfc{}".format(ii), fpfc)

                if self._mixture_layer_norm:
                    ln = LayerNorm(next_size)
                    # ln = nn.BatchNorm1d(next_size)
                    self.__setattr__("fpfc{}_norm".format(ii), ln)
                    self._norm_fpfcs.append(ln)
                    self.add_final_policy_module("fpfc{}_norm".format(ii), ln)
                final_pol_in_size = next_size

        # Unshared Final Policy Last Layer
        fpfc_last = nn.Linear(final_pol_in_size, action_dim)
        ptu.layer_init(layer=fpfc_last,
                       option=output_w_init,
                       activation=final_pol_output_activation,
                       b=output_b_init_val)
        self.__setattr__("fpfc_last", fpfc_last)
        self.fpfc_last = fpfc_last
        # Add it to specific dictionaries
        self.add_final_policy_module("fpfc_last", fpfc_last)

        # ########## #
        # Std Layers #
        # ########## #
        # Multi-Policy Log-Stds Last Layers
        fpfc_last_log_std = nn.Linear(final_pol_in_size, action_dim)
        ptu.layer_init(layer=fpfc_last_log_std,
                       option=output_w_init,
                       activation=final_pol_output_activation,
                       b=output_b_init_val)
        self.__setattr__("fpfc_last_log_std", fpfc_last_log_std)
        self.fpfc_last_log_std = fpfc_last_log_std
        # Add it to specific dictionaries
        self.add_final_policy_module("fpfc_last_log_std", fpfc_last_log_std)

        self._normal_dist = Normal(loc=ptu.zeros(action_dim),
                                   scale=ptu.ones(action_dim))
        self._epsilon = epsilon

        self._pols_idxs = ptu.arange(self._n_subpolicies)
        self._compo_pol_idx = torch.tensor([self._n_subpolicies],
                                           dtype=torch.int64,
                                           device=ptu.device)
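
The constructor above only builds the layers; the forward pass is not shown. Judging from the layer sizes (the mixing input is the shared features concatenated with n_policies latent embeddings, and the final head maps a latent of size latent_dim to action_dim), a plausible simplified forward-pass structure is sketched below in plain PyTorch. This is an illustrative assumption, not the class's actual forward method.

import torch
import torch.nn as nn

obs_dim, action_dim, n_policies, latent_dim = 6, 2, 3, 4

shared = nn.Linear(obs_dim, 32)                                   # shared trunk
subpolicies = nn.ModuleList(
    [nn.Linear(32, latent_dim) for _ in range(n_policies)])       # per-subpolicy embeddings
mixing = nn.Linear(32 + n_policies * latent_dim, latent_dim)      # mixing network
final_policy = nn.Linear(latent_dim, action_dim)                  # final policy head

def forward(obs):
    h = torch.relu(shared(obs))
    embeddings = [pol(h) for pol in subpolicies]
    mixed_latent = mixing(torch.cat([h] + embeddings, dim=-1))
    return torch.tanh(final_policy(mixed_latent))                 # squashed action mean

action = forward(torch.randn(1, obs_dim))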
Example No. 6
    def __init__(self,
                 obs_dim,
                 action_dim,
                 n_policies,
                 shared_hidden_sizes=None,
                 unshared_hidden_sizes=None,
                 unshared_mix_hidden_sizes=None,
                 stds=None,
                 hidden_activation='relu',
                 hidden_w_init='xavier_normal',
                 hidden_b_init_val=1e-2,
                 output_w_init='xavier_normal',
                 output_b_init_val=1e-2,
                 pol_output_activation='linear',
                 mix_output_activation='linear',
                 input_norm=False,
                 shared_layer_norm=False,
                 policies_layer_norm=False,
                 mixture_layer_norm=False,
                 epsilon=1e-6,
        ):
        self.save_init_params(locals())
        super(TanhGaussianMixtureMultiPolicy, self).__init__()
        ExplorationPolicy.__init__(self, action_dim)

        self._input_size = obs_dim
        self._output_sizes = action_dim
        self._n_subpolicies = n_policies
        # Activation Fcns
        self._hidden_activation = ptu.get_activation(hidden_activation)
        self._pol_output_activation = ptu.get_activation(pol_output_activation)
        self._mix_output_activation = ptu.get_activation(mix_output_activation)
        # Normalization Layer Flags
        self._shared_layer_norm = shared_layer_norm
        self._policies_layer_norm = policies_layer_norm
        self._mixture_layer_norm = mixture_layer_norm
        # Layers Lists
        self._sfcs = []  # Shared Layers
        self._sfc_norms = []  # Norm. Shared Layers
        self._pfcs = [list() for _ in range(self._n_subpolicies)]  # Policies Layers
        self._pfc_norms = [list() for _ in range(self._n_subpolicies)]  # N. Pol. L.
        self._pfc_lasts = []  # Last Policies Layers
        self._mfcs = []  # Mixing Layers
        self._norm_mfcs = []  # Norm. Mixing Layers
        # self.mfc_last = None  # Below is instantiated

        # Initial size = Obs size
        in_size = self._input_size

        # Ordered Dictionaries for specific modules/parameters
        self._shared_modules = OrderedDict()
        self._shared_parameters = OrderedDict()
        self._policies_modules = [OrderedDict() for _ in range(n_policies)]
        self._policies_parameters = [OrderedDict() for _ in range(n_policies)]
        self._mixing_modules = OrderedDict()
        self._mixing_parameters = OrderedDict()

        # ############# #
        # Shared Layers #
        # ############# #
        if input_norm:
            ln = nn.BatchNorm1d(in_size)
            self.sfc_input = ln
            self.add_shared_module("sfc_input", ln)
        else:
            self.sfc_input = None

        if shared_hidden_sizes is not None:
            for ii, next_size in enumerate(shared_hidden_sizes):
                sfc = nn.Linear(in_size, next_size)
                ptu.layer_init(
                    layer=sfc,
                    option=hidden_w_init,
                    activation=hidden_activation,
                    b=hidden_b_init_val,
                )
                self.__setattr__("sfc{}".format(ii), sfc)
                self._sfcs.append(sfc)
                self.add_shared_module("sfc{}".format(ii), sfc)

                if self._shared_layer_norm:
                    ln = LayerNorm(next_size)
                    # ln = nn.BatchNorm1d(next_size)
                    self.__setattr__("sfc{}_norm".format(ii), ln)
                    self._sfc_norms.append(ln)
                    self.add_shared_module("sfc{}_norm".format(ii), ln)
                in_size = next_size

        # Get the output_size of the shared layers (assume same for all)
        multipol_in_size = in_size
        mixture_in_size = in_size

        # ############### #
        # Unshared Layers #
        # ############### #
        # Unshared Multi-Policy Hidden Layers
        if unshared_hidden_sizes is not None:
            for ii, next_size in enumerate(unshared_hidden_sizes):
                for pol_idx in range(self._n_subpolicies):
                    pfc = nn.Linear(multipol_in_size, next_size)
                    ptu.layer_init(
                        layer=pfc,
                        option=hidden_w_init,
                        activation=hidden_activation,
                        b=hidden_b_init_val
                    )
                    self.__setattr__("pfc{}_{}".format(pol_idx, ii), pfc)
                    self._pfcs[pol_idx].append(pfc)
                    self.add_policies_module("pfc{}_{}".format(pol_idx, ii),
                                             pfc, idx=pol_idx)

                    if self._policies_layer_norm:
                        ln = LayerNorm(next_size)
                        # ln = nn.BatchNorm1d(next_size)
                        self.__setattr__("pfc{}_{}_norm".format(pol_idx, ii),
                                         ln)
                        self._pfc_norms[pol_idx].append(ln)
                        self.add_policies_module("pfc{}_{}_norm".format(pol_idx,
                                                                        ii),
                                                 ln, idx=pol_idx)
                multipol_in_size = next_size

        # Multi-Policy Last Layers
        for pol_idx in range(self._n_subpolicies):
            last_pfc = nn.Linear(multipol_in_size, action_dim)
            ptu.layer_init(
                layer=last_pfc,
                option=output_w_init,
                activation=pol_output_activation,
                b=output_b_init_val
            )
            self.__setattr__("pfc{}_last".format(pol_idx), last_pfc)
            self._pfc_lasts.append(last_pfc)
            self.add_policies_module("pfc{}_last".format(pol_idx), last_pfc,
                                     idx=pol_idx)

        # Multi-Policy Log-Stds Last Layers
        self.stds = stds
        self.log_std = list()
        if stds is None:
            self._pfc_log_std_lasts = list()
            for pol_idx in range(self._n_subpolicies):
                last_pfc_log_std = nn.Linear(multipol_in_size, action_dim)
                ptu.layer_init(
                    layer=last_pfc_log_std,
                    option=output_w_init,
                    activation=pol_output_activation,
                    b=output_b_init_val
                )
                self.__setattr__("pfc{}_log_std_last".format(pol_idx),
                                 last_pfc_log_std)
                self._pfc_log_std_lasts.append(last_pfc_log_std)
                self.add_policies_module("pfc{}_log_std_last".format(pol_idx),
                                         last_pfc_log_std, idx=pol_idx)

        else:
            for std in stds:
                self.log_std.append(torch.log(std))
                assert LOG_SIG_MIN <= self.log_std[-1] <= LOG_SIG_MAX

        # ############# #
        # Mixing Layers #
        # ############# #
        # Unshared Mixing-Weights Hidden Layers
        if unshared_mix_hidden_sizes is not None:
            for ii, next_size in enumerate(unshared_mix_hidden_sizes):
                mfc = nn.Linear(mixture_in_size, next_size)
                ptu.layer_init(
                    layer=mfc,
                    option=hidden_w_init,
                    activation=hidden_activation,
                    b=hidden_b_init_val
                )
                self.__setattr__("mfc{}".format(ii), mfc)
                self._mfcs.append(mfc)
                # Add it to specific dictionaries
                self.add_mixing_module("mfc{}".format(ii), mfc)

                if self._mixture_layer_norm:
                    ln = LayerNorm(next_size)
                    # ln = nn.BatchNorm1d(next_size)
                    self.__setattr__("mfc{}_norm".format(ii), ln)
                    self._norm_mfcs.append(ln)
                    self.add_mixing_module("mfc{}_norm".format(ii), ln)
                mixture_in_size = next_size

        # Unshared Mixing-Weights Last Layers
        mfc_last = nn.Linear(mixture_in_size, self._n_subpolicies * action_dim)
        ptu.layer_init(
            layer=mfc_last,
            option=output_w_init,
            activation=mix_output_activation,
            b=output_b_init_val
        )
        self.__setattr__("mfc_last", mfc_last)
        self.mfc_last = mfc_last
        # Add it to specific dictionaries
        self.add_mixing_module("mfc_last", mfc_last)

        softmax_weights = True
        if softmax_weights:
            self.mfc_softmax = nn.Softmax(dim=1)
        else:
            self.mfc_softmax = None

        self._normal_dist = Normal(loc=ptu.zeros(action_dim),
                                   scale=ptu.ones(action_dim))
        self._epsilon = epsilon

        self._pols_idxs = ptu.arange(self._n_subpolicies)
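
Here mfc_last emits n_subpolicies * action_dim mixing values and mfc_softmax normalizes them. One plausible way such weights combine the per-subpolicy action means is sketched below in plain PyTorch; the reshape and the normalization axis are assumptions, since the forward method is not shown.

import torch

batch, n_policies, action_dim = 5, 3, 2
subpolicy_means = torch.randn(batch, n_policies, action_dim)      # per-subpolicy outputs
mix_logits = torch.randn(batch, n_policies * action_dim)          # mfc_last-style output

# Normalize across subpolicies so the weights per action dimension sum to one
weights = torch.softmax(mix_logits.view(batch, n_policies, action_dim), dim=1)
mixed_mean = (weights * subpolicy_means).sum(dim=1)               # (batch, action_dim)
action = torch.tanh(mixed_mean)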
Example No. 7
    def global_policy_optimization(self, obs, tgt_mu, tgt_prc, tgt_wt):
        """
        Update policy.
        :param obs: Numpy array of observations, N x T x dO.
        :param tgt_mu: Numpy array of mean controller outputs, N x T x dU.
        :param tgt_prc: Numpy array of precision matrices, N x T x dU x dU.
        :param tgt_wt: Numpy array of weights, N x T.
        """
        N, T = obs.shape[:2]
        dU = self.dU
        dO = self.dO

        # Save original tgt_prc.
        tgt_prc_orig = torch.reshape(tgt_prc, [N * T, dU, dU])

        # Renormalize weights.
        tgt_wt *= (float(N * T) / torch.sum(tgt_wt))
        # Allow weights to be at most twice the robust median.
        mn = torch.median(tgt_wt[tgt_wt > 1e-2])
        tgt_wt = torch.clamp(tgt_wt, max=2 * mn)
        # Robust median should be around one.
        tgt_wt /= mn

        # Reshape inputs.
        obs = torch.reshape(obs, (N * T, dO))
        tgt_mu = torch.reshape(tgt_mu, (N * T, dU))
        tgt_prc = torch.reshape(tgt_prc, (N * T, dU, dU))
        tgt_wt = torch.reshape(tgt_wt, (N * T, 1, 1))

        # Fold weights into tgt_prc.
        tgt_prc = tgt_wt * tgt_prc

        # TODO: DO THIS MORE THAN ONCE!!
        if not hasattr(self.global_policy, 'scale') or not hasattr(
                self.global_policy, 'bias'):
            # 1e-3 to avoid infs if some state dimensions don't change in the
            # first batch of samples
            self.global_policy.scale = ptu.zeros(self.explo_env.obs_dim)
            self.global_policy.bias = ptu.zeros(self.explo_env.obs_dim)

        m = self._global_samples_counter
        n = m + N * T

        scale_obs = torch.diag(1.0 /
                               torch.clamp(torch.std(obs, dim=0), min=1e-3))
        var_obs = scale_obs**2
        var_prev = self.global_policy.scale**2

        bias_obs = -torch.mean(obs.matmul(scale_obs), dim=0)
        bias_prev = self.global_policy.bias
        bias_new = (float(n / (m + n)) * bias_obs
                    + float(m / (m + n)) * bias_prev)

        var_new = (float(n / (m + n)) * var_obs
                   + float(m / (m + n)) * var_prev
                   - float((m * n) / (m + n) ** 2) * (bias_prev - bias_new) ** 2)
        self.global_policy.scale = torch.sqrt(var_new)
        self.global_policy.bias = bias_new

        # self.global_policy.scale = ptu.eye(self.env.obs_dim)
        # self.global_policy.bias = ptu.zeros(self.env.obs_dim)

        # Normalize Inputs
        obs = obs.matmul(self.global_policy.scale) + self.global_policy.bias

        # # Global Policy Optimization
        # self.global_pol_optimizer = torch.optim.Adam(
        #     self.global_policy.parameters(),
        #     lr=self._global_opt_lr,
        #     betas=(0.9, 0.999),
        #     eps=1e-08,  # Term added to the denominator for numerical stability
        #     # weight_decay=0.005,
        #     weight_decay=0.5,
        #     amsgrad=True,
        # )

        # Assuming that N*T >= self._global_opt_batch_size.
        batches_per_epoch = math.floor(N * T / self._global_opt_batch_size)
        idx = list(range(N * T))
        average_loss = 0
        np.random.shuffle(idx)

        if torch.any(torch.isnan(obs)):
            raise ValueError('GIVING NaN OBSERVATIONS to PYTORCH')
        if torch.any(torch.isnan(tgt_mu)):
            raise ValueError('GIVING NaN ACTIONS to PYTORCH')
        if torch.any(torch.isnan(tgt_prc)):
            raise ValueError('GIVING NaN PRECISION to PYTORCH')

        for oo in range(1):
            print('$$$$\n' * 2)
            print('GLOBAL_OPT %02d' % oo)
            print('$$$$\n' * 2)
            # # Global Policy Optimization
            # self.global_pol_optimizer = torch.optim.Adam(
            #     self.global_policy.parameters(),
            #     lr=self._global_opt_lr,
            #     betas=(0.9, 0.999),
            #     eps=1e-08,  # Term added to the denominator for numerical stability
            #     # weight_decay=0.005,
            #     weight_decay=0.5,
            #     amsgrad=True,
            # )

            for ii in range(self._global_opt_iters):
                # # Load in data for this batch.
                # start_idx = int(ii * self._global_opt_batch_size %
                #                 (batches_per_epoch * self._global_opt_batch_size))
                # idx_i = idx[start_idx:start_idx+self._global_opt_batch_size]

                # Load in data for this batch.
                idx_i = np.random.choice(N * T, self._global_opt_batch_size)

                self.global_pol_optimizer.zero_grad()

                pol_output = self.global_policy(obs[idx_i],
                                                deterministic=True)[0]

                train_loss = euclidean_loss(
                    mlp_out=pol_output,
                    action=tgt_mu[idx_i],
                    precision=tgt_prc[idx_i],
                    batch_size=self._global_opt_batch_size)

                train_loss.backward()
                self.global_pol_optimizer.step()

                average_loss += train_loss.item()

                # del pol_output
                # del train_loss
                loss_tolerance = 5e-10

                if (ii + 1) % 50 == 0:
                    print('PolOpt iteration %d, average loss %f' %
                          (ii + 1, average_loss / 50))
                    average_loss = 0

                if train_loss <= loss_tolerance:
                    print("It converged! loss:", train_loss)
                    break

            if train_loss <= loss_tolerance:
                break

        # Optimize variance.
        A = torch.sum(tgt_prc_orig, dim=0) \
            + 2 * N * T * self._global_opt_ent_reg * ptu.ones((dU, dU))
        A = A / torch.sum(tgt_wt)

        # TODO - Use dense covariance?
        self.global_policy.std = torch.diag(torch.sqrt(A))
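
euclidean_loss (not shown here) acts as a precision-weighted regression loss between the policy output and the target actions. A minimal sketch of that kind of loss in plain PyTorch follows; the 0.5 factor and the batch mean are assumptions, not the project's exact definition.

import torch

def precision_weighted_loss(pol_output, tgt_mu, tgt_prc):
    # pol_output, tgt_mu: (B, dU); tgt_prc: (B, dU, dU)
    err = (pol_output - tgt_mu).unsqueeze(-1)          # (B, dU, 1)
    quad = err.transpose(1, 2) @ tgt_prc @ err         # (B, 1, 1): err^T Prc err
    return 0.5 * quad.mean()

B, dU = 32, 4
loss = precision_weighted_loss(torch.randn(B, dU), torch.randn(B, dU),
                               torch.eye(dU).expand(B, dU, dU))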