Ejemplo n.º 1
0
    def optimize_policy(self, all_samples_data, log=True):
        """
        Performs the MAML outer (meta) optimization step.

        The optimizer's loss and constraint value are evaluated on the same
        inputs before and after the update so that the effect of the step can
        be reported.

        Args:
            all_samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to record the statistics via ``logger.logkv``; the progress messages are emitted
             regardless of this flag

        Returns:
            None
        """
        feed = self._extract_input_dict_meta_op(all_samples_data,
                                                self._optimization_keys)
        optimizer = self.optimizer

        # Measurements taken before the update
        logger.log("Computing KL before")
        kl_pre = optimizer.constraint_val(feed)

        logger.log("Computing loss before")
        loss_pre = optimizer.loss(feed)

        # The update itself
        logger.log("Optimizing")
        optimizer.optimize(feed)

        # Same measurements after the update
        logger.log("Computing loss after")
        loss_post = optimizer.loss(feed)

        logger.log("Computing KL after")
        kl_post = optimizer.constraint_val(feed)

        if not log:
            return

        statistics = (
            ('MeanKLBefore', kl_pre),
            ('MeanKL', kl_post),
            ('LossBefore', loss_pre),
            ('LossAfter', loss_post),
            ('dLoss', loss_pre - loss_post),
        )
        for key, value in statistics:
            logger.logkv(key, value)
Ejemplo n.º 2
0
    def optimize_policy(self, all_samples_data, log=True):
        """
        Performs the MAML outer (meta) optimization step.

        Args:
            all_samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to emit progress messages and record the loss statistics

        Returns:
            None
        """

        def _announce(message):
            # Progress messages are only emitted when logging is requested
            if log:
                logger.log(message)

        feed = self._extract_input_dict_meta_op(all_samples_data,
                                                self._optimization_keys)

        _announce("Optimizing")
        # optimize() hands back the loss it measured before updating
        pre_update_loss = self.optimizer.optimize(input_val_dict=feed)

        _announce("Computing statistics")
        post_update_loss = self.optimizer.loss(input_val_dict=feed)

        if not log:
            return
        logger.logkv('LossBefore', pre_update_loss)
        logger.logkv('LossAfter', post_update_loss)
    def optimize(self, input_val_dict):
        """
        Carries out the optimization step

        For every epoch the inputs are sliced along their second axis into
        chunks of ``self._backprop_steps`` steps. Each chunk is run through the
        graph while the hidden state returned by one chunk is fed into the
        next. The per-chunk gradients are averaged and applied with a single
        run of ``self._train_op``.

        Args:
            input_val_dict (dict): dict containing the values to be fed into the computation graph. The first two
             dimensions of the first value are read as (batch size, sequence length).

        Returns:
            (float) mean chunk loss of the first epoch, i.e. the loss measured before any update was applied;
            ``None`` if ``self._max_epochs`` is 0

        """

        sess = tf.compat.v1.get_default_session()
        batch_size, seq_len, *_ = list(input_val_dict.values())[0].shape

        loss_before_opt = None
        for epoch in range(self._max_epochs):
            # Every epoch starts again from the zero hidden state
            hidden_batch = self._target.get_zero_state(batch_size)
            if self._verbose:
                logger.log("Epoch %d" % epoch)

            chunk_losses = []
            chunk_grads = []

            for start in range(0, seq_len, self._backprop_steps):
                stop = start + self._backprop_steps
                feed_dict = {
                    placeholder: input_val_dict[key][:, start:stop]
                    for key, placeholder in self._input_ph_dict.items()
                }
                feed_dict[self._hidden_ph] = hidden_batch
                # The hidden state produced by this chunk seeds the next one
                batch_loss, grads, hidden_batch = sess.run(
                    [self._loss, self._gradients_var, self._next_hidden_var],
                    feed_dict=feed_dict)
                chunk_losses.append(batch_loss)
                chunk_grads.append(grads)

            # zip(*...) regroups the gradients per variable so each one is averaged over the chunks
            mean_grads = [np.mean(grad, axis=0) for grad in zip(*chunk_grads)]
            sess.run(self._train_op,
                     feed_dict=dict(zip(self._gradients_ph, mean_grads)))

            # Compare against None explicitly: a first-epoch loss of exactly 0.0 is falsy and
            # would otherwise be overwritten by the loss of a later epoch.
            if loss_before_opt is None:
                loss_before_opt = np.mean(chunk_losses)

        return loss_before_opt
Ejemplo n.º 4
0
    def optimize_policy(self, all_samples_data, log=True):
        """
        Performs the MAML outer (meta) optimization step.

        Besides the extracted sample data, the optimizer inputs carry the
        current inner KL coefficients and either the clipping epsilon (when
        ``self.clip_outer`` is set) or the outer KL coefficient. After the
        step, the KL coefficients are re-adapted if the corresponding adaptive
        penalty flags are enabled.

        Args:
            all_samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to emit progress messages and record statistics

        Returns:
            None
        """

        def _announce(message):
            # Progress messages are only emitted when logging is requested
            if log:
                logger.log(message)

        feed = self._extract_input_dict_meta_op(all_samples_data,
                                                self._optimization_keys)

        # Extend the optimizer inputs with the KL coefficients / clipping epsilon
        feed['inner_kl_coeff'] = self.inner_kl_coeff
        if self.clip_outer:
            extra_key, extra_value = 'clip_eps', self.clip_eps
        else:
            extra_key, extra_value = 'outer_kl_coeff', self.outer_kl_coeff
        feed[extra_key] = extra_value

        _announce("Optimizing")
        pre_update_loss = self.optimizer.optimize(input_val_dict=feed)

        _announce("Computing statistics")
        stats = self.optimizer.compute_stats(input_val_dict=feed)
        post_update_loss, inner_kls, outer_kl = stats

        if self.adaptive_inner_kl_penalty:
            _announce("Updating inner KL loss coefficients")
            self.inner_kl_coeff = self.adapt_kl_coeff(
                self.inner_kl_coeff, inner_kls, self.target_inner_step)

        if self.adaptive_outer_kl_penalty:
            _announce("Updating outer KL loss coefficients")
            self.outer_kl_coeff = self.adapt_kl_coeff(
                self.outer_kl_coeff, outer_kl, self.target_outer_step)

        if not log:
            return

        logger.logkv('LossBefore', pre_update_loss)
        logger.logkv('LossAfter', post_update_loss)
        logger.logkv('KLInner', np.mean(inner_kls))
        logger.logkv('KLCoeffInner', np.mean(self.inner_kl_coeff))
        # The outer KL is only reported when it is not replaced by clipping
        if not self.clip_outer:
            logger.logkv('KLOuter', outer_kl)
Ejemplo n.º 5
0
    def optimize(self, input_val_dict):
        """
        Carries out the optimization step

        Runs ``self._train_op`` once per epoch on the full set of inputs and
        remembers the loss fetched during the first epoch.

        Args:
            input_val_dict (dict): dict containing the values to be fed into the computation graph

        Returns:
            (float) loss fetched together with the first training step (reported by callers as the loss before
            optimization); ``None`` if ``self._max_epochs`` is 0

        """

        sess = tf.compat.v1.get_default_session()
        feed_dict = self.create_feed_dict(input_val_dict)

        # TODO: reimplement minibatches (every epoch currently feeds the complete input set)

        loss_before_opt = None
        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % epoch)

            loss, _ = sess.run([self._loss, self._train_op], feed_dict)

            # Compare against None explicitly: a first-epoch loss of exactly 0.0 is falsy and
            # would otherwise be overwritten by the loss of a later epoch.
            if loss_before_opt is None:
                loss_before_opt = loss

        return loss_before_opt
Ejemplo n.º 6
0
    def validate(self, sess):
        """
        Tests policy on env using algo

        For each of ``self.n_itr`` iterations the policy is switched back to
        its pre-update state, new tasks are drawn, and the policy is adapted
        for ``self.num_inner_grad_steps`` inner steps. No outer policy update
        is performed. The average return reported by the sample processor
        after the last step of each iteration is collected.

        Pseudocode:
            for itr in n_itr:
                sampler.update_tasks()
                for step in num_inner_grad_steps + 1:
                    sampler.obtain_samples()
                    sample_processor.process_samples()
                    algo._adapt()        # skipped on the last step

        Args:
            sess : session handed over by the caller (not used by this method's body)

        Returns:
            mean of the collected per-iteration average returns
        """
        collected_returns = []
        last_step = self.num_inner_grad_steps

        for itr in range(self.n_itr):
            logger.log("\n ---------------- Iteration %d ----------------" %
                       itr)

            self.policy.switch_to_pre_update()  # Switch to pre-update policy
            self.sampler.update_tasks()

            for step in range(last_step + 1):
                step_prefix = 'Step_%d-' % step
                logger.log('** Step ' + str(step) + ' **')

                # -------------------- Sampling --------------------
                paths = self.sampler.obtain_samples(log=False)

                # ---------------- Processing Samples ----------------
                samples_data = self.sample_processor.process_samples(
                    paths, log='all', log_prefix=step_prefix)
                # Concatenate the per-task path lists into one flat list
                self.log_diagnostics(sum(paths.values(), []),
                                     prefix=step_prefix)

                # ---------------- Inner Policy Update ----------------
                if step < last_step:
                    self.algo._adapt(samples_data)

            collected_returns.append(self.sample_processor.avg_return)

        mean_return = np.mean(collected_returns)
        logger.log(f'Average validation reward: {mean_return}')
        return mean_return
Ejemplo n.º 7
0
    def train(self, tester):
        """
        Trains policy on env using algo and validates it on selected iterations

        Pseudocode:
            for itr in range(start_itr, n_itr):
                sampler.update_tasks()
                for step in range(num_inner_grad_steps + 1):
                    sampler.obtain_samples()
                    sample_processor.process_samples()
                    algo._adapt()            # skipped on the last step
                algo.optimize_policy()
                tester.validate()            # only on selected iterations

        Args:
            tester : object whose ``validate(sess)`` method is called with the training session and whose return
             value is compared against ``self.best_val_reward``

        Returns:
            ``self.best_itr``, i.e. the iteration at which ``tester.validate`` returned the highest reward so far
        """
        # Highest training return seen at a validation point; used in the validation trigger below
        best_train_reward = -np.inf
        with self.sess.as_default() as sess:

            # initialize uninitialized vars  (only initialize vars that were not loaded)
            uninit_vars = [
                var for var in tf.compat.v1.global_variables()
                if not sess.run(tf.compat.v1.is_variable_initialized(var))
            ]
            sess.run(tf.compat.v1.variables_initializer(uninit_vars))

            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)
                logger.log(
                    "Sampling set of tasks/goals for this meta-batch...")

                self.sampler.update_tasks()
                self.policy.switch_to_pre_update(
                )  # Switch to pre-update policy

                all_samples_data, all_paths = [], []
                # NOTE(review): list_outer_step_time is never filled or read in this method
                list_sampling_time, list_inner_step_time, list_outer_step_time, list_proc_samples_time = [], [], [], []
                start_total_inner_time = time.time()
                # One extra pass (num_inner_grad_steps + 1): the last pass only samples and
                # processes data, the adaptation below is skipped for it
                for step in range(self.num_inner_grad_steps + 1):
                    logger.log('** Step ' + str(step) + ' **')
                    """ -------------------- Sampling --------------------------"""

                    logger.log("Obtaining samples...")
                    time_env_sampling_start = time.time()
                    paths = self.sampler.obtain_samples(log=True,
                                                        log_prefix='Step_%d-' %
                                                        step)
                    list_sampling_time.append(time.time() -
                                              time_env_sampling_start)
                    all_paths.append(paths)
                    """ ----------------- Processing Samples ---------------------"""

                    logger.log("Processing samples...")
                    time_proc_samples_start = time.time()
                    samples_data = self.sample_processor.process_samples(
                        paths, log='all', log_prefix='Step_%d-' % step)
                    all_samples_data.append(samples_data)
                    list_proc_samples_time.append(time.time() -
                                                  time_proc_samples_start)

                    # sum(..., []) concatenates the per-key path lists of ``paths`` into one flat list
                    self.log_diagnostics(sum(list(paths.values()), []),
                                         prefix='Step_%d-' % step)
                    """ ------------------- Inner Policy Update --------------------"""

                    time_inner_step_start = time.time()
                    if step < self.num_inner_grad_steps:
                        logger.log("Computing inner policy updates...")
                        self.algo._adapt(samples_data)
                    # train_writer = tf.summary.FileWriter('/home/ignasi/Desktop/maml_zoo_graph',
                    #                                      sess.graph)
                    list_inner_step_time.append(time.time() -
                                                time_inner_step_start)
                total_inner_time = time.time() - start_total_inner_time

                # Started right before time_outer_step_start; unlike that timer it is read only
                # after all the logkv calls below ('Time-MAMLSteps')
                time_maml_opt_start = time.time()
                """ ------------------ Outer Policy Update ---------------------"""

                logger.log("Optimizing policy...")
                # This needs to take all samples_data so that it can construct graph for meta-optimization.
                time_outer_step_start = time.time()
                self.algo.optimize_policy(all_samples_data)
                """ ------------------- Logging Stuff --------------------------"""
                logger.logkv('Itr', itr)
                logger.logkv('n_timesteps',
                             self.sampler.total_timesteps_sampled)

                logger.logkv('Time-OuterStep',
                             time.time() - time_outer_step_start)
                logger.logkv('Time-TotalInner', total_inner_time)
                logger.logkv('Time-InnerStep', np.sum(list_inner_step_time))
                logger.logkv('Time-SampleProc', np.sum(list_proc_samples_time))
                logger.logkv('Time-Sampling', np.sum(list_sampling_time))

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)
                logger.logkv('Time-MAMLSteps',
                             time.time() - time_maml_opt_start)

                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, params)
                logger.log("Saved")

                logger.dumpkvs()
                # The graph is finalized after iteration 0 only; if start_itr > 0 this branch
                # is never taken
                if itr == 0:
                    sess.graph.finalize()

                # Validation trigger. ``and`` binds tighter than ``or``, so this reads as:
                #   every 80th iteration (itr % 80 == 79)
                #   OR (avg_return > 1.05 * best_train_reward AND itr > 50 AND avg_return > 300.0)
                if (itr % 80 == 79) or (
                    (self.sample_processor.avg_return >
                     best_train_reward * 1.05) and
                    (itr > 50)) and self.sample_processor.avg_return > 300.0:
                    if self.sample_processor.avg_return > best_train_reward:
                        best_train_reward = self.sample_processor.avg_return
                    print('TESTING')
                    delim = '====================\n'
                    print(delim, delim, delim)
                    val_reward = tester.validate(sess)
                    # self.best_val_reward / self.best_itr are not set in this method;
                    # presumably they are initialized in the constructor - verify
                    if self.best_val_reward < val_reward:
                        self.best_itr = itr
                        self.best_val_reward = val_reward
                    print(delim, delim, delim)

        logger.log("Training finished")
        logger.log(
            f"Best iteration is {self.best_itr} with reward {self.best_val_reward}"
        )
        # NOTE(review): the session is closed only on normal completion; an exception raised
        # inside the loop skips this call
        self.sess.close()
        return self.best_itr
Ejemplo n.º 8
0
    def optimize(self, input_val_dict):
        """
        Carries out the optimization step

        A descent direction is obtained with ``conjugate_gradients`` from the
        gradient and the evaluator built by ``self._hvp_approach``. The
        parameters of ``self._target`` are then moved along that direction
        with a backtracking line search: the step is shrunk by powers of
        ``self._backtrack_ratio`` until the loss improves and the constraint
        value does not exceed ``self._max_constraint_val``, or until
        ``self._max_backtracks`` candidates have been tried.

        If the resulting step violates a line-search condition and
        ``self._accept_violation`` is false, the previous parameters are
        restored. If the initial step size is NaN the step is rejected before
        any parameter is touched.

        Args:
            input_val_dict (dict): dict containing the values to be fed into the computation graph

        Returns:
            None

        """
        logger.log("Start CG optimization")

        logger.log("computing loss before")
        loss_before = self.loss(input_val_dict)

        logger.log("performing update")

        logger.log("computing gradient")
        gradient = self.gradient(input_val_dict)
        logger.log("gradient computed")

        logger.log("computing descent direction")
        Hx = self._hvp_approach.build_eval(input_val_dict)
        descent_direction = conjugate_gradients(Hx,
                                                gradient,
                                                cg_iters=self._cg_iters)

        # Largest step along the direction d: sqrt(2 * max_constraint / (d . Hx(d))).
        # The 1e-8 term guards against division by zero.
        rat = 1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8)
        initial_step_size = np.sqrt(2.0 * self._max_constraint_val * rat)

        if np.isnan(initial_step_size):
            logger.log("Initial step size is NaN! Rejecting the step!")
            return

        initial_descent_step = initial_step_size * descent_direction
        logger.log("descent direction computed")

        prev_params = self._target.get_param_values()
        prev_params_values = _flatten_params(prev_params)

        # Backtracking line search: try ratio**0, ratio**1, ... of the full step
        loss, constraint_val, n_iter, violated = 0, 0, 0, False
        backtrack_ratios = self._backtrack_ratio**np.arange(
            self._max_backtracks)
        for n_iter, ratio in enumerate(backtrack_ratios):
            cur_step = ratio * initial_descent_step
            cur_params_values = prev_params_values - cur_step
            cur_params = _unflatten_params(cur_params_values,
                                           params_example=prev_params)
            self._target.set_params(cur_params)

            loss = self.loss(input_val_dict)
            constraint_val = self.constraint_val(input_val_dict)
            if loss < loss_before and constraint_val <= self._max_constraint_val:
                break

        # ------------------- Logging Stuff --------------------------
        if np.isnan(loss):
            violated = True
            logger.log("Line search violated because loss is NaN")
        if np.isnan(constraint_val):
            violated = True
            logger.log("Line search violated because constraint %s is NaN" %
                       self._constraint_name)
        if loss >= loss_before:
            violated = True
            logger.log("Line search violated because loss not improving")
        # Strictly greater: the line search above accepts a constraint value equal to the
        # bound, so that same value must not be reported as a violation afterwards.
        if constraint_val > self._max_constraint_val:
            violated = True
            logger.log(
                "Line search violated because constraint %s is violated" %
                self._constraint_name)

        if violated and not self._accept_violation:
            logger.log("Line search condition violated. Rejecting the step!")
            self._target.set_params(prev_params)

        logger.log("backtrack iters: %d" % n_iter)
        logger.log("computing loss after")
        logger.log("optimization finished")
Ejemplo n.º 9
0
    def train(self):
        """
        Trains policy on env using algo

        Pseudocode:
            for itr in range(start_itr, n_itr):
                tasks = env.sample_tasks()
                sampler.set_tasks(tasks)
                sampler.obtain_samples()
                sample_processor.process_samples()
                algo.optimize_policy()

        Returns:
            None
        """
        log_prefix = 'train-'

        with self.sess.as_default() as sess:

            # Initialize only the variables that are still uninitialized
            # (i.e. those that were not loaded)
            pending_vars = []
            for variable in tf.compat.v1.global_variables():
                if not sess.run(
                        tf.compat.v1.is_variable_initialized(variable)):
                    pending_vars.append(variable)
            sess.run(tf.compat.v1.variables_initializer(pending_vars))

            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                # Draw a fresh set of tasks for this iteration
                self.task = self.env.sample_tasks(self.sampler.meta_batch_size)
                self.sampler.set_tasks(self.task)
                itr_start_time = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)
                logger.log(
                    "Sampling set of tasks/goals for this meta-batch...")

                # -------------------- Sampling --------------------
                logger.log("Obtaining samples...")
                sampling_started = time.time()
                paths = self.sampler.obtain_samples(log=True,
                                                    log_prefix=log_prefix)
                sampling_time = time.time() - sampling_started

                # ---------------- Processing Samples ----------------
                logger.log("Processing samples...")
                processing_started = time.time()
                samples_data = self.sample_processor.process_samples(
                    paths, log='all', log_prefix=log_prefix)
                proc_samples_time = time.time() - processing_started

                # Concatenate the per-task path lists into one flat list
                flat_paths = sum(paths.values(), [])
                self.log_diagnostics(flat_paths, prefix=log_prefix)

                # ------------------ Policy Update ------------------
                logger.log("Optimizing policy...")
                optimization_started = time.time()
                self.algo.optimize_policy(samples_data)

                # ------------------ Logging Stuff ------------------
                logger.logkv('Itr', itr)
                logger.logkv('n_timesteps',
                             self.sampler.total_timesteps_sampled)

                logger.logkv('Time-Optimization',
                             time.time() - optimization_started)
                logger.logkv('Time-SampleProc', np.sum(proc_samples_time))
                logger.logkv('Time-Sampling', sampling_time)

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)

                logger.log("Saving snapshot...")
                snapshot = self.get_itr_snapshot(itr)
                logger.save_itr_params(itr, snapshot)
                logger.log("Saved")

                logger.dumpkvs()
                # The graph is finalized once, after iteration 0
                if itr == 0:
                    sess.graph.finalize()

        logger.log("Training finished")
        self.sess.close()