Example #1
    def update(self, m, algorithm):
        """ Run dual gradient decent to optimize trajectories. """
        T = algorithm.T
        eta = algorithm.cur[m].eta
        step_mult = algorithm.cur[m].step_mult
        traj_info = algorithm.cur[m].traj_info

        prev_traj_distr = algorithm.cur[m].traj_distr

        # Set KL-divergence step size (epsilon).
        kl_step = T * algorithm.base_kl_step * step_mult

        # We assume at min_eta, kl_div > kl_step, opposite for max_eta.
        min_eta = self._hyperparams['min_eta']
        max_eta = self._hyperparams['max_eta']

        LOGGER.debug("Running DGD for trajectory %d, eta: %f", m, eta)
        for itr in range(DGD_MAX_ITER):
            LOGGER.debug("Iteration %i, bracket: (%.2e , %.2e , %.2e)", itr,
                         min_eta, eta, max_eta)

            # Run fwd/bwd pass, note that eta may be updated.
            # NOTE: we can simply ignore the case where the new eta is larger.
            traj_distr, eta = self.backward(prev_traj_distr, traj_info, eta,
                                            algorithm, m)
            new_mu, new_sigma = self.forward(traj_distr, traj_info)

            # Compute KL divergence constraint violation.
            kl_div = traj_distr_kl(new_mu, new_sigma, traj_distr,
                                   prev_traj_distr)
            con = kl_div - kl_step

            # Convergence check - constraint satisfaction.
            if (abs(con) < 0.1 * kl_step):
                LOGGER.debug("KL: %f / %f, converged iteration %i", kl_div,
                             kl_step, itr)
                break

            # Choose new eta (bisect bracket or multiply by constant)
            if con < 0:  # Eta was too big.
                max_eta = eta
                geom = np.sqrt(min_eta * max_eta)  # Geometric mean.
                new_eta = max(geom, 0.1 * max_eta)
                LOGGER.debug("KL: %f / %f, eta too big, new eta: %f", kl_div,
                             kl_step, new_eta)
            else:  # Eta was too small.
                min_eta = eta
                geom = np.sqrt(min_eta * max_eta)  # Geometric mean.
                new_eta = min(geom, 10.0 * min_eta)
                LOGGER.debug("KL: %f / %f, eta too small, new eta: %f", kl_div,
                             kl_step, new_eta)

            # Alternative to the geometric mean: log_mean(x, y) = (y - x)/(log(y) - log(x)).
            eta = new_eta

        if kl_div > kl_step and abs(kl_div - kl_step) > 0.1 * kl_step:
            LOGGER.warning(
                "Final KL divergence after DGD convergence is too high.")

        return traj_distr, eta
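The eta update inside the loop above bisects the bracket [min_eta, max_eta] in log space (a geometric mean), while capping each move so eta never changes by more than a factor of ten. A minimal standalone sketch of that rule, assuming a toy monotone kl(eta) in place of the real forward/backward pass (the function name and the toy model are illustrative, not part of the original code):

import numpy as np

def bracket_eta(con, eta, min_eta, max_eta):
    """One bracketing step for the dual variable eta.

    con = kl_div - kl_step: negative means eta was too big (the KL is
    already below the step), positive means eta was too small.
    """
    if con < 0:  # Eta was too big: shrink toward min_eta.
        max_eta = eta
        new_eta = max(np.sqrt(min_eta * max_eta), 0.1 * max_eta)
    else:        # Eta was too small: grow toward max_eta.
        min_eta = eta
        new_eta = min(np.sqrt(min_eta * max_eta), 10.0 * min_eta)
    return new_eta, min_eta, max_eta

# Toy usage: drive a made-up kl(eta) = 50 / eta toward kl_step = 1.0.
kl = lambda e: 50.0 / e
eta, lo, hi = 1.0, 1e-8, 1e16
for _ in range(50):
    con = kl(eta) - 1.0
    if abs(con) < 0.1:  # Same 10% tolerance as the convergence check above.
        break
    eta, lo, hi = bracket_eta(con, eta, lo, hi)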
Example #2
    def iteration(self, sample_lists):
        """
        Run iteration of LQR.
        Args:
            sample_lists: List of SampleList objects for each condition.
        """
        self.N = sum(len(self.sample_list[i]) for i in self.sample_list.keys())
        for m in range(self.M):
            self.cur[m].sample_list = sample_lists[m]
            prev_samples = self.sample_list[m].get_samples()
            prev_samples.extend(sample_lists[m].get_samples())
            self.sample_list[m] = SampleList(prev_samples)
            self.N += len(sample_lists[m])
        # Update dynamics model using all samples.
        self._update_dynamics()

        # Update the cost during learning if we use IOC.
        if self._hyperparams['ioc']:
            self._update_cost()

        self._update_step_size()  # KL Divergence step size.

        # Run inner loop to compute new policies.
        for _ in range(self._hyperparams['inner_iterations']):
            self._update_trajectories()

        # Computing KL-divergence between sample distribution and demo distribution
        itr = self.iteration_count
        if self._hyperparams['ioc']:
            for i in xrange(self.M):
                mu, sigma = self.traj_opt.forward(self.traj_distr[itr][i], self.traj_info[itr][i])
                # KL divergence between current traj. distribution and gt distribution
                self.kl_div[itr].append(traj_distr_kl(mu, sigma, self.traj_distr[itr][i], self.demo_traj[i]))

        if self._hyperparams['learning_from_prior']:
            for i in xrange(self.M):
                target_position = self._hyperparams['target_end_effector'][:3]
                cur_samples = sample_lists[i].get_samples()
                sample_end_effectors = [sample.get(END_EFFECTOR_POINTS)
                                        for sample in cur_samples]
                dists = [np.amin(np.sqrt(np.sum(
                             (ee[:, :3] - target_position.reshape(1, -1))**2,
                             axis=1)), axis=0)
                         for ee in sample_end_effectors]
                self.dists_to_target[itr].append(sum(dists) / len(cur_samples))
        self._advance_iteration_variables()
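The learning_from_prior branch above reduces each sample's end-effector trajectory to its closest approach to a fixed target point, then averages that over the condition's samples. A self-contained sketch of the per-sample reduction on a dummy (T, 3) trajectory; the function name and array shapes are assumptions for illustration:

import numpy as np

def min_dist_to_target(ee_points, target_position):
    """Minimum Euclidean distance over time from the leading end-effector
    point (columns 0:3 of ee_points, shape (T, >=3)) to the target."""
    diffs = ee_points[:, :3] - target_position.reshape(1, -1)
    return np.amin(np.sqrt(np.sum(diffs ** 2, axis=1)), axis=0)

# Example: a straight-line trajectory that passes 0.5 above the origin.
traj = np.linspace([-1.0, -1.0, 0.5], [1.0, 1.0, 0.5], 100)
print(min_dist_to_target(traj, np.zeros(3)))  # ~0.5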
Example #3
    def update(self, m, algorithm):
        """ Run dual gradient decent to optimize trajectories. """
        T = algorithm.T
        # T is the trajectory length (number of time steps).
        eta = algorithm.cur[m].eta
        # eta is the dual variable for this condition's KL constraint.
        if self.cons_per_step and type(eta) in (int, float):
            eta = np.ones(T) * eta
        step_mult = algorithm.cur[m].step_mult
        traj_info = algorithm.cur[m].traj_info
        # traj_info holds the fitted dynamics and related trajectory statistics.

        if isinstance(algorithm, AlgorithmMDGPS):
            # For MDGPS, constrain to previous NN linearization
            prev_traj_distr = algorithm.cur[m].pol_info.traj_distr()
        else:
            # For BADMM/trajopt, constrain to previous LG controller
            prev_traj_distr = algorithm.cur[m].traj_distr
            # The trajectory distribution is a time-varying linear-Gaussian controller.
        # step_mult lives in the IterationData class and scales the KL step below.
        # Set KL-divergence step size (epsilon).
        kl_step = algorithm.base_kl_step * step_mult
        if not self.cons_per_step:
            kl_step *= T

        # We assume at min_eta, kl_div > kl_step, opposite for max_eta.
        if not self.cons_per_step:
            min_eta = self._hyperparams['min_eta']
            max_eta = self._hyperparams['max_eta']
            LOGGER.debug("Running DGD for trajectory %d, eta: %f", m, eta)
        else:
            min_eta = np.ones(T) * self._hyperparams['min_eta']
            max_eta = np.ones(T) * self._hyperparams['max_eta']
            LOGGER.debug("Running DGD for trajectory %d, avg eta: %f", m,
                         np.mean(eta[:-1]))

        max_itr = (DGD_MAX_LS_ITER if self.cons_per_step else
                   DGD_MAX_ITER)
        for itr in range(max_itr):
            if not self.cons_per_step:
                LOGGER.debug("Iteration %d, bracket: (%.2e , %.2e , %.2e)", itr,
                             min_eta, eta, max_eta)

            # Run fwd/bwd pass, note that eta may be updated.
            # Compute KL divergence constraint violation.
            # The new distribution is constrained against the previous one;
            # traj_info supplies the fitted dynamics used by both passes.
            traj_distr, eta = self.backward(prev_traj_distr, traj_info,
                                            eta, algorithm, m)

            if not self._use_prev_distr:
                new_mu, new_sigma = self.forward(traj_distr, traj_info)
                kl_div = traj_distr_kl(
                        new_mu, new_sigma, traj_distr, prev_traj_distr,
                        tot=(not self.cons_per_step)
                )
            else:
                prev_mu, prev_sigma = self.forward(prev_traj_distr, traj_info)
                kl_div = traj_distr_kl_alt(
                        prev_mu, prev_sigma, traj_distr, prev_traj_distr,
                        tot=(not self.cons_per_step)
                )

            con = kl_div - kl_step

            # Convergence check - constraint satisfaction.
            if self._conv_check(con, kl_step):
                if not self.cons_per_step:
                    LOGGER.debug("KL: %f / %f, converged iteration %d", kl_div,
                                 kl_step, itr)
                else:
                    LOGGER.debug(
                            "KL: %f / %f, converged iteration %d",
                            np.mean(kl_div[:-1]), np.mean(kl_step[:-1]), itr
                    )
                break

            if not self.cons_per_step:
                # Choose new eta (bisect bracket or multiply by constant)
                if con < 0: # Eta was too big.
                    max_eta = eta
                    geom = np.sqrt(min_eta*max_eta)  # Geometric mean.
                    new_eta = max(geom, 0.1*max_eta)
                    LOGGER.debug("KL: %f / %f, eta too big, new eta: %f",
                                 kl_div, kl_step, new_eta)
                else: # Eta was too small.
                    min_eta = eta
                    geom = np.sqrt(min_eta*max_eta)  # Geometric mean.
                    new_eta = min(geom, 10.0*min_eta)
                    LOGGER.debug("KL: %f / %f, eta too small, new eta: %f",
                                 kl_div, kl_step, new_eta)

                # Alternative to the geometric mean: log_mean(x, y) = (y - x)/(log(y) - log(x)).
                eta = new_eta
            else:
                for t in range(T):
                    if con[t] < 0:
                        max_eta[t] = eta[t]
                        geom = np.sqrt(min_eta[t]*max_eta[t])
                        eta[t] = max(geom, 0.1*max_eta[t])
                    else:
                        min_eta[t] = eta[t]
                        geom = np.sqrt(min_eta[t]*max_eta[t])
                        eta[t] = min(geom, 10.0*min_eta[t])
                if itr % 10 == 0:
                    LOGGER.debug("avg KL: %f / %f, avg new eta: %f",
                                 np.mean(kl_div[:-1]), np.mean(kl_step[:-1]),
                                 np.mean(eta[:-1]))

        if (self.cons_per_step and not self._conv_check(con, kl_step)):
            m_b, v_b = np.zeros(T-1), np.zeros(T-1)

            for itr in range(DGD_MAX_GD_ITER):
                traj_distr, eta = self.backward(prev_traj_distr, traj_info,
                                                eta, algorithm, m)

                if not self._use_prev_distr:
                    new_mu, new_sigma = self.forward(traj_distr, traj_info)
                    kl_div = traj_distr_kl(
                            new_mu, new_sigma, traj_distr, prev_traj_distr,
                            tot=False
                    )
                else:
                    prev_mu, prev_sigma = self.forward(prev_traj_distr,
                                                       traj_info)
                    kl_div = traj_distr_kl_alt(
                            prev_mu, prev_sigma, traj_distr, prev_traj_distr,
                            tot=False
                    )

                con = kl_div - kl_step
                if self._conv_check(con, kl_step):
                    LOGGER.debug(
                            "KL: %f / %f, converged iteration %d",
                            np.mean(kl_div[:-1]), np.mean(kl_step[:-1]), itr
                    )
                    break

                m_b = (BETA1 * m_b + (1-BETA1) * con[:-1])
                m_u = m_b / (1 - BETA1 ** (itr+1))
                v_b = (BETA2 * v_b + (1-BETA2) * np.square(con[:-1]))
                v_u = v_b / (1 - BETA2 ** (itr+1))
                eta[:-1] = np.minimum(
                        np.maximum(eta[:-1] + ALPHA * m_u / (np.sqrt(v_u) + EPS),
                                   self._hyperparams['min_eta']),
                        self._hyperparams['max_eta']
                )

                if itr % 10 == 0:
                    LOGGER.debug("avg KL: %f / %f, avg new eta: %f",
                                 np.mean(kl_div[:-1]), np.mean(kl_step[:-1]),
                                 np.mean(eta[:-1]))

        if (np.mean(kl_div) > np.mean(kl_step) and
            not self._conv_check(con, kl_step)):
            LOGGER.warning(
                    "Final KL divergence after DGD convergence is too high."
            )
        return traj_distr, eta
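When the per-time-step bracketing loop fails to satisfy the constraint, the code above falls back to an Adam-style ascent on the per-time-step duals eta[t]. A minimal sketch of one such update; the constant values below (ALPHA, BETA1, BETA2, EPS) are assumptions for illustration, not necessarily those defined in the codebase:

import numpy as np

ALPHA, BETA1, BETA2, EPS = 0.005, 0.9, 0.999, 1e-8

def adam_eta_step(eta, con, m_b, v_b, itr, min_eta=1e-8, max_eta=1e16):
    """One Adam step on the duals eta[:-1].

    con = kl_div - kl_step per time step; a positive violation pushes the
    corresponding eta up, a negative one pushes it down.
    """
    m_b = BETA1 * m_b + (1 - BETA1) * con[:-1]              # first moment
    v_b = BETA2 * v_b + (1 - BETA2) * np.square(con[:-1])   # second moment
    m_u = m_b / (1 - BETA1 ** (itr + 1))                    # bias correction
    v_u = v_b / (1 - BETA2 ** (itr + 1))
    eta = eta.copy()
    eta[:-1] = np.clip(eta[:-1] + ALPHA * m_u / (np.sqrt(v_u) + EPS),
                       min_eta, max_eta)
    return eta, m_b, v_b

# One step on a length-5 trajectory with mixed violations.
eta = np.ones(5)
con = np.array([0.3, -0.1, 0.2, 0.05, 0.0])
m_b, v_b = np.zeros(4), np.zeros(4)
eta, m_b, v_b = adam_eta_step(eta, con, m_b, v_b, itr=0)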
Example #4
    def update(self, m, algorithm):
        """ Run dual gradient decent to optimize trajectories. """
        T = algorithm.T
        eta = algorithm.cur[m].eta
        if self.cons_per_step and type(eta) in (int, float):
            eta = np.ones(T) * eta
        step_mult = algorithm.cur[m].step_mult
        traj_info = algorithm.cur[m].traj_info

        # For MDGPS, constrain to previous NN linearization
        prev_traj_distr = algorithm.cur[m].pol_info.traj_distr()

        # Set KL-divergence step size (epsilon).
        kl_step = algorithm.base_kl_step * step_mult
        if not self.cons_per_step:
            kl_step *= T

        # We assume at min_eta, kl_div > kl_step, opposite for max_eta.
        min_eta = self._hyperparams['min_eta']
        max_eta = self._hyperparams['max_eta']
        LOGGER.debug("Running DGD for trajectory %d, eta: %f", m, eta)

        max_itr = (DGD_MAX_LS_ITER if self.cons_per_step else DGD_MAX_ITER)
        for itr in range(max_itr):
            LOGGER.debug("[DEBUG] Iteration %d, bracket: (%.2e , %.2e , %.2e)",
                         itr, min_eta, eta, max_eta)

            # Run fwd/bwd pass, note that eta may be updated.
            # Compute KL divergence constraint violation.
            traj_distr, eta = self.backward(prev_traj_distr, traj_info, eta,
                                            algorithm, m)

            new_mu, new_sigma = self.forward(traj_distr, traj_info)
            kl_div = traj_distr_kl(new_mu,
                                   new_sigma,
                                   traj_distr,
                                   prev_traj_distr,
                                   tot=(not self.cons_per_step))
            con = kl_div - kl_step
            LOGGER.debug("[DEBUG] KL (%.2e , %.2e)", kl_div, kl_step)

            # Convergence check - constraint satisfaction.
            if self._conv_check(con, kl_step):
                LOGGER.debug("KL: %f / %f, converged iteration %d", kl_div,
                             kl_step, itr)
                break

            # Choose new eta (bisect bracket or multiply by constant)
            if con < 0:  # Eta was too big.
                max_eta = eta
                geom = np.sqrt(min_eta * max_eta)  # Geometric mean.
                new_eta = max(geom, 0.1 * max_eta)
                LOGGER.debug("KL: %f / %f, eta too big, new eta: %f", kl_div,
                             kl_step, new_eta)
            else:  # Eta was too small.
                min_eta = eta
                geom = np.sqrt(min_eta * max_eta)  # Geometric mean.
                new_eta = min(geom, 10.0 * min_eta)
                LOGGER.debug("KL: %f / %f, eta too small, new eta: %f", kl_div,
                             kl_step, new_eta)

            # Alternative to the geometric mean: log_mean(x, y) = (y - x)/(log(y) - log(x)).
            eta = new_eta
            LOGGER.debug("Trajectory %d, new eta: %f", m, eta)

        if (np.mean(kl_div) > np.mean(kl_step)
                and not self._conv_check(con, kl_step)):
            LOGGER.warning(
                "Final KL divergence after DGD convergence is too high.")
        return traj_distr, eta
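All of the examples that branch on self._conv_check rely on a helper that is not shown on this page. A plausible reading, consistent with the explicit check abs(con) < 0.1 * kl_step in Example #1, is a relative tolerance on the constraint violation, applied to every time step when cons_per_step is set; the sketch below is an assumption, not the project's actual implementation:

import numpy as np

def conv_check(con, kl_step, cons_per_step=False, tol=0.1):
    """Constraint satisfaction: |kl_div - kl_step| within tol * kl_step."""
    if cons_per_step:
        # Per-time-step constraint: every step must be within tolerance.
        return np.all(np.abs(con) < tol * np.asarray(kl_step))
    # Total-KL constraint: a single scalar check.
    return abs(con) < tol * kl_step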
Example #5
    def update(self, m, algorithm):
        """ Run dual gradient decent to optimize trajectories. """
        T = algorithm.T
        step_mult = algorithm.cur[m].step_mult
        prev_eta = algorithm.cur[m].eta
        traj_info = algorithm.cur[m].traj_info

        if isinstance(algorithm, AlgorithmMDGPS):
            # For MDGPS, constrain to previous NN linearization
            prev_traj_distr = algorithm.cur[m].pol_info.traj_distr()
        else:
            # For BADMM/trajopt, constrain to previous LG controller
            prev_traj_distr = algorithm.cur[m].traj_distr

        # Set KL-divergence step size (epsilon).
        kl_step = algorithm.base_kl_step * step_mult

        line_search = LineSearch(self._hyperparams['min_eta'])
        min_eta = -np.Inf

        for itr in range(DGD_MAX_ITER):
            traj_distr, new_eta = self.backward(prev_traj_distr, traj_info,
                                                prev_eta, algorithm, m)
            new_mu, new_sigma = self.forward(traj_distr, traj_info)

            # Update min eta if we had a correction after running bwd.
            if new_eta > prev_eta:
                min_eta = new_eta

            # Compute KL divergence between prev and new distribution.
            kl_div = traj_distr_kl(new_mu, new_sigma, traj_distr,
                                   prev_traj_distr)

            traj_info.last_kl_step = kl_div

            # Main convergence check - constraint satisfaction.
            if (abs(kl_div - kl_step * T) < 0.1 * kl_step * T
                    or (itr >= 20 and kl_div < kl_step * T)):
                LOGGER.debug("Iteration %i, KL: %f / %f converged", itr,
                             kl_div, kl_step * T)
                eta = prev_eta  # TODO - Should this be here?
                break

            # Adjust eta using bracketing line search.
            eta = line_search.bracketing_line_search(kl_div - kl_step * T,
                                                     new_eta, min_eta)

            # Convergence check - dual variable change when min_eta hit.
            if (abs(prev_eta - eta) < THRESHA
                    and eta == max(min_eta, self._hyperparams['min_eta'])):
                LOGGER.debug("Iteration %i, KL: %f / %f converged (eta limit)",
                             itr, kl_div, kl_step * T)
                break

            # Convergence check - constraint satisfaction, KL not
            # changing much.
            if (itr > 2 and abs(kl_div - prev_kl_div) < THRESHB
                    and kl_div < kl_step * T):
                LOGGER.debug("Iteration %i, KL: %f / %f converged (no change)",
                             itr, kl_div, kl_step * T)
                break

            prev_kl_div = kl_div
            LOGGER.debug('Iteration %i, KL: %f / %f eta: %f -> %f', itr,
                         kl_div, kl_step * T, prev_eta, eta)
            prev_eta = eta

        if kl_div > kl_step * T and abs(kl_div -
                                        kl_step * T) > 0.1 * kl_step * T:
            LOGGER.warning(
                "Final KL divergence after DGD convergence is too high.")

        return traj_distr, eta
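Examples #5 and #6 delegate the eta update to a LineSearch object whose implementation is not shown here. The sketch below only illustrates the bracketing idea under the same call signature (constraint violation, current eta, externally imposed lower bound); it is not the project's LineSearch class:

import numpy as np

class BracketingLineSearch(object):
    """Keeps a running bracket [low, high] on the dual variable and
    bisects it geometrically based on the sign of the violation."""

    def __init__(self, min_eta, max_eta=1e16):
        self.low = min_eta
        self.high = max_eta

    def bracketing_line_search(self, con, eta, min_eta):
        # Respect an externally imposed lower bound, e.g. a correction
        # returned by the backward pass (min_eta may be -inf initially).
        if np.isfinite(min_eta):
            self.low = max(self.low, min_eta)
        if con < 0:   # KL below the step: eta was too big.
            self.high = eta
        else:         # KL above the step: eta was too small.
            self.low = max(self.low, eta)
        return np.sqrt(self.low * self.high)  # geometric midpoint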
Example #6
    def update(self, m, algorithm):
        """ Run dual gradient decent to optimize trajectories. """
        T = algorithm.T
        step_mult = algorithm.cur[m].step_mult
        prev_eta = algorithm.cur[m].eta
        traj_info = algorithm.cur[m].traj_info
        prev_traj_distr = algorithm.cur[m].traj_distr

        # Set KL-divergence step size (epsilon).
        kl_step = algorithm.base_kl_step * step_mult

        line_search = LineSearch(self._hyperparams['min_eta'])
        min_eta = -np.Inf

        for itr in range(DGD_MAX_ITER):
            traj_distr, new_eta = self.backward(prev_traj_distr, traj_info,
                                                prev_eta, algorithm, m)
            new_mu, new_sigma = self.forward(traj_distr, traj_info)

            # Update min eta if we had a correction after running bwd.
            if new_eta > prev_eta:
                min_eta = new_eta

            # Compute KL divergence between prev and new distribution.
            kl_div = traj_distr_kl(new_mu, new_sigma,
                                   traj_distr, prev_traj_distr)

            traj_info.last_kl_step = kl_div

            # Main convergence check - constraint satisfaction.
            if (abs(kl_div - kl_step*T) < 0.1*kl_step*T or
                    (itr >= 20 and kl_div < kl_step*T)):
                LOGGER.debug("Iteration %i, KL: %f / %f converged",
                             itr, kl_div, kl_step * T)
                eta = prev_eta  # TODO - Should this be here?
                break

            # Adjust eta using bracketing line search.
            eta = line_search.bracketing_line_search(kl_div - kl_step*T,
                                                     new_eta, min_eta)

            # Convergence check - dual variable change when min_eta hit.
            if (abs(prev_eta - eta) < THRESHA and
                    eta == max(min_eta, self._hyperparams['min_eta'])):
                LOGGER.debug("Iteration %i, KL: %f / %f converged (eta limit)",
                             itr, kl_div, kl_step * T)
                break

            # Convergence check - constraint satisfaction, KL not
            # changing much.
            if (itr > 2 and abs(kl_div - prev_kl_div) < THRESHB and
                    kl_div < kl_step*T):
                LOGGER.debug("Iteration %i, KL: %f / %f converged (no change)",
                             itr, kl_div, kl_step * T)
                break

            prev_kl_div = kl_div
            LOGGER.debug('Iteration %i, KL: %f / %f eta: %f -> %f',
                         itr, kl_div, kl_step * T, prev_eta, eta)
            prev_eta = eta

        if kl_div > kl_step*T and abs(kl_div - kl_step*T) > 0.1*kl_step*T:
            LOGGER.warning(
                "Final KL divergence after DGD convergence is too high."
            )

        return traj_distr, eta
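traj_distr_kl and traj_distr_kl_alt, used throughout these examples, measure the KL divergence between two time-varying linear-Gaussian controllers under the state marginals produced by the forward pass. As background, a minimal sketch of the single-Gaussian formula such a quantity builds on (this standalone function is illustrative, not the codebase's implementation):

import numpy as np

def gaussian_kl(mu0, sig0, mu1, sig1):
    """KL(N(mu0, sig0) || N(mu1, sig1)) for multivariate Gaussians."""
    k = mu0.shape[0]
    sig1_inv = np.linalg.inv(sig1)
    diff = mu1 - mu0
    return 0.5 * (np.trace(sig1_inv.dot(sig0))       # trace term
                  + diff.dot(sig1_inv).dot(diff)     # mean-difference term
                  - k                                # dimensionality
                  + np.log(np.linalg.det(sig1) / np.linalg.det(sig0)))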
Example #7
    def update(self, m, algorithm):
        """ Run dual gradient decent to optimize trajectories. """
        T = algorithm.T
        eta = algorithm.cur[m].eta
        if self.cons_per_step and type(eta) in (int, float):
            eta = np.ones(T) * eta
        step_mult = algorithm.cur[m].step_mult
        traj_info = algorithm.cur[m].traj_info

        if isinstance(algorithm, AlgorithmMDGPS):
            # For MDGPS, constrain to previous NN linearization
            prev_traj_distr = algorithm.cur[m].pol_info.traj_distr()
        else:
            # For BADMM/trajopt, constrain to previous LG controller
            prev_traj_distr = algorithm.cur[m].traj_distr

        # Set KL-divergence step size (epsilon).
        kl_step = algorithm.base_kl_step * step_mult
        if not self.cons_per_step:
            kl_step *= T

        # We assume at min_eta, kl_div > kl_step, opposite for max_eta.
        if not self.cons_per_step:
            min_eta = self._hyperparams['min_eta']
            max_eta = self._hyperparams['max_eta']
            LOGGER.debug("Running DGD for trajectory %d, eta: %f", m, eta)
        else:
            min_eta = np.ones(T) * self._hyperparams['min_eta']
            max_eta = np.ones(T) * self._hyperparams['max_eta']
            LOGGER.debug("Running DGD for trajectory %d, avg eta: %f", m,
                         np.mean(eta[:-1]))

        max_itr = (DGD_MAX_LS_ITER if self.cons_per_step else
                   DGD_MAX_ITER)
        for itr in range(max_itr):
            if not self.cons_per_step:
                LOGGER.debug("Iteration %d, bracket: (%.2e , %.2e , %.2e)", itr,
                             min_eta, eta, max_eta)

            # Run fwd/bwd pass, note that eta may be updated.
            # Compute KL divergence constraint violation.
            traj_distr, eta = self.backward(prev_traj_distr, traj_info,
                                            eta, algorithm, m)

            if not self._use_prev_distr:
                new_mu, new_sigma = self.forward(traj_distr, traj_info)
                kl_div = traj_distr_kl(
                        new_mu, new_sigma, traj_distr, prev_traj_distr,
                        tot=(not self.cons_per_step)
                )
            else:
                prev_mu, prev_sigma = self.forward(prev_traj_distr, traj_info)
                kl_div = traj_distr_kl_alt(
                        prev_mu, prev_sigma, traj_distr, prev_traj_distr,
                        tot=(not self.cons_per_step)
                )

            con = kl_div - kl_step

            # Convergence check - constraint satisfaction.
            if self._conv_check(con, kl_step):
                if not self.cons_per_step:
                    LOGGER.debug("KL: %f / %f, converged iteration %d", kl_div,
                                 kl_step, itr)
                else:
                    LOGGER.debug(
                            "KL: %f / %f, converged iteration %d",
                            np.mean(kl_div[:-1]), np.mean(kl_step[:-1]), itr
                    )
                break

            if not self.cons_per_step:
                # Choose new eta (bisect bracket or multiply by constant)
                if con < 0: # Eta was too big.
                    max_eta = eta
                    geom = np.sqrt(min_eta*max_eta)  # Geometric mean.
                    new_eta = max(geom, 0.1*max_eta)
                    LOGGER.debug("KL: %f / %f, eta too big, new eta: %f",
                                 kl_div, kl_step, new_eta)
                else: # Eta was too small.
                    min_eta = eta
                    geom = np.sqrt(min_eta*max_eta)  # Geometric mean.
                    new_eta = min(geom, 10.0*min_eta)
                    LOGGER.debug("KL: %f / %f, eta too small, new eta: %f",
                                 kl_div, kl_step, new_eta)

                # Alternative to the geometric mean: log_mean(x, y) = (y - x)/(log(y) - log(x)).
                eta = new_eta
            else:
                for t in range(T):
                    if con[t] < 0:
                        max_eta[t] = eta[t]
                        geom = np.sqrt(min_eta[t]*max_eta[t])
                        eta[t] = max(geom, 0.1*max_eta[t])
                    else:
                        min_eta[t] = eta[t]
                        geom = np.sqrt(min_eta[t]*max_eta[t])
                        eta[t] = min(geom, 10.0*min_eta[t])
                if itr % 10 == 0:
                    LOGGER.debug("avg KL: %f / %f, avg new eta: %f",
                                 np.mean(kl_div[:-1]), np.mean(kl_step[:-1]),
                                 np.mean(eta[:-1]))

        if (self.cons_per_step and not self._conv_check(con, kl_step)):
            m_b, v_b = np.zeros(T-1), np.zeros(T-1)

            for itr in range(DGD_MAX_GD_ITER):
                traj_distr, eta = self.backward(prev_traj_distr, traj_info,
                                                eta, algorithm, m)

                if not self._use_prev_distr:
                    new_mu, new_sigma = self.forward(traj_distr, traj_info)
                    kl_div = traj_distr_kl(
                            new_mu, new_sigma, traj_distr, prev_traj_distr,
                            tot=False
                    )
                else:
                    prev_mu, prev_sigma = self.forward(prev_traj_distr,
                                                       traj_info)
                    kl_div = traj_distr_kl_alt(
                            prev_mu, prev_sigma, traj_distr, prev_traj_distr,
                            tot=False
                    )

                con = kl_div - kl_step
                if self._conv_check(con, kl_step):
                    LOGGER.debug(
                            "KL: %f / %f, converged iteration %d",
                            np.mean(kl_div[:-1]), np.mean(kl_step[:-1]), itr
                    )
                    break

                m_b = (BETA1 * m_b + (1-BETA1) * con[:-1])
                m_u = m_b / (1 - BETA1 ** (itr+1))
                v_b = (BETA2 * v_b + (1-BETA2) * np.square(con[:-1]))
                v_u = v_b / (1 - BETA2 ** (itr+1))
                eta[:-1] = np.minimum(
                        np.maximum(eta[:-1] + ALPHA * m_u / (np.sqrt(v_u) + EPS),
                                   self._hyperparams['min_eta']),
                        self._hyperparams['max_eta']
                )

                if itr % 10 == 0:
                    LOGGER.debug("avg KL: %f / %f, avg new eta: %f",
                                 np.mean(kl_div[:-1]), np.mean(kl_step[:-1]),
                                 np.mean(eta[:-1]))

        if (np.mean(kl_div) > np.mean(kl_step) and
            not self._conv_check(con, kl_step)):
            LOGGER.warning(
                    "Final KL divergence after DGD convergence is too high."
            )
        return traj_distr, eta
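In every example, update() is called once per condition from the algorithm's inner loop, and the returned distribution and dual variable are written back into the iteration data. A schematic sketch of that calling pattern, assuming illustrative attribute names for the write-back targets (the surrounding loop is not code from this page):

def update_trajectories(algorithm, traj_opt):
    """Re-optimize the local linear-Gaussian controller for each condition."""
    for cond in range(algorithm.M):
        traj_distr, eta = traj_opt.update(cond, algorithm)
        # Where the results are stored varies by algorithm variant; these
        # attribute names are placeholders for illustration.
        algorithm.cur[cond].traj_distr = traj_distr
        algorithm.cur[cond].eta = eta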