Example #1
    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100,
              tb_log_name="TRPO",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:
            self._setup_learn(seed)

            with self.sess.as_default():
                seg_gen = traj_segment_generator(
                    self.policy_pi,
                    self.env,
                    self.timesteps_per_batch,
                    reward_giver=self.reward_giver,
                    gail=self.using_gail)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                len_buffer = deque(
                    maxlen=40)  # rolling buffer for episode lengths
                reward_buffer = deque(
                    maxlen=40)  # rolling buffer for episode rewards
                self.episode_reward = np.zeros((self.n_envs, ))

                true_reward_buffer = None
                if self.using_gail:
                    true_reward_buffer = deque(maxlen=40)

                    # Initialize dataloader
                    batchsize = self.timesteps_per_batch // self.d_step
                    self.expert_dataset.init_dataloader(batchsize)

                    #  Stats not used for now
                    # TODO: replace with normal tb logging
                    #  g_loss_stats = Stats(loss_names)
                    #  d_loss_stats = Stats(reward_giver.loss_name)
                    #  ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"])

                while True:
                    if callback is not None:
                        # Only stop training if return value is False, not when it is None. This is for backwards
                        # compatibility with callbacks that have no return statement.
                        if callback(locals(), globals()) is False:
                            break
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    def fisher_vector_product(vec):
                        return self.allmean(
                            self.compute_fvp(
                                vec, *fvpargs,
                                sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for k in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()
                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        observation, action = seg["ob"], seg["ac"]
                        atarg, tdlamret = seg["adv"], seg["tdlamret"]
                        vpredbefore = seg["vpred"]  # predicted value function before update
                        # standardized advantage function estimate
                        atarg = (atarg - atarg.mean()) / atarg.std()

                        # true_rew is the reward without discount
                        if writer is not None:
                            self.episode_reward = total_episode_reward_logger(
                                self.episode_reward, seg["true_rew"].reshape(
                                    (self.n_envs, -1)), seg["dones"].reshape(
                                        (self.n_envs, -1)), writer,
                                self.num_timesteps)

                        args = seg["ob"], seg["ob"], seg["ac"], atarg
                        fvpargs = [arr[::5] for arr in args]

                        self.assign_old_eq_new(sess=self.sess)

                        with self.timed("computegrad"):
                            steps = self.num_timesteps + (k + 1) * (
                                seg["total_timestep"] / self.g_step)
                            run_options = tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE)
                            run_metadata = tf.RunMetadata() if self.full_tensorboard_log else None
                            # run loss backprop with summary, and save the metadata (memory, compute time, ...)
                            if writer is not None:
                                summary, grad, *lossbefore = self.compute_lossandgrad(
                                    *args,
                                    tdlamret,
                                    sess=self.sess,
                                    options=run_options,
                                    run_metadata=run_metadata)
                                if self.full_tensorboard_log:
                                    writer.add_run_metadata(
                                        run_metadata, 'step%d' % steps)
                                writer.add_summary(summary, steps)
                            else:
                                _, grad, *lossbefore = self.compute_lossandgrad(
                                    *args,
                                    tdlamret,
                                    sess=self.sess,
                                    options=run_options,
                                    run_metadata=run_metadata)

                        lossbefore = self.allmean(np.array(lossbefore))
                        grad = self.allmean(grad)
                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            with self.timed("conjugate_gradient"):
                                stepdir = conjugate_gradient(
                                    fisher_vector_product,
                                    grad,
                                    cg_iters=self.cg_iters,
                                    verbose=self.rank == 0
                                    and self.verbose >= 1)
                            assert np.isfinite(stepdir).all()
                            shs = .5 * stepdir.dot(
                                fisher_vector_product(stepdir))
                            # abs(shs) to avoid taking square root of negative values
                            lagrange_multiplier = np.sqrt(
                                abs(shs) / self.max_kl)
                            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                            fullstep = stepdir / lagrange_multiplier
                            expectedimprove = grad.dot(fullstep)
                            surrbefore = lossbefore[0]
                            stepsize = 1.0
                            thbefore = self.get_flat()
                            thnew = None
                            for _ in range(10):
                                thnew = thbefore + fullstep * stepsize
                                self.set_from_flat(thnew)
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(
                                        self.compute_losses(*args,
                                                            sess=self.sess)))
                                improve = surr - surrbefore
                                logger.log("Expected: %.3f Actual: %.3f" %
                                           (expectedimprove, improve))
                                if not np.isfinite(mean_losses).all():
                                    logger.log(
                                        "Got non-finite value of losses -- bad!"
                                    )
                                elif kl_loss > self.max_kl * 1.5:
                                    logger.log(
                                        "violated KL constraint. shrinking step."
                                    )
                                elif improve < 0:
                                    logger.log(
                                        "surrogate didn't improve. shrinking step."
                                    )
                                else:
                                    logger.log("Stepsize OK!")
                                    break
                                stepsize *= .5
                            else:
                                logger.log("couldn't compute a good step")
                                self.set_from_flat(thbefore)
                            if self.nworkers > 1 and iters_so_far % 20 == 0:
                                # list of tuples
                                paramsums = MPI.COMM_WORLD.allgather(
                                    (thnew.sum(), self.vfadam.getflat().sum()))
                                assert all(
                                    np.allclose(ps, paramsums[0])
                                    for ps in paramsums[1:])

                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                # NOTE: for recurrent policies, use shuffle=False?
                                for (mbob, mbret) in dataset.iterbatches(
                                    (seg["ob"], seg["tdlamret"]),
                                        include_final_partial_batch=False,
                                        batch_size=128,
                                        shuffle=True):
                                    grad = self.allmean(
                                        self.compute_vflossandgrad(
                                            mbob, mbob, mbret, sess=self.sess))
                                    self.vfadam.update(grad, self.vf_stepsize)

                    for (loss_name, loss_val) in zip(self.loss_names,
                                                     mean_losses):
                        logger.record_tabular(loss_name, loss_val)

                    logger.record_tabular(
                        "explained_variance_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    if self.using_gail:
                        # ------------------ Update D ------------------
                        logger.log("Optimizing Discriminator...")
                        logger.log(fmt_row(13, self.reward_giver.loss_name))
                        assert len(observation) == self.timesteps_per_batch
                        batch_size = self.timesteps_per_batch // self.d_step

                        # NOTE: uses only the last g step for observation
                        # list of tuples, each of which gives the loss for a minibatch
                        d_losses = []
                        # NOTE: for recurrent policies, use shuffle=False?
                        for ob_batch, ac_batch in dataset.iterbatches(
                            (observation, action),
                                include_final_partial_batch=False,
                                batch_size=batch_size,
                                shuffle=True):
                            ob_expert, ac_expert = self.expert_dataset.get_next_batch()
                            # update running mean/std for reward_giver
                            if self.reward_giver.normalize:
                                self.reward_giver.obs_rms.update(
                                    np.concatenate((ob_batch, ob_expert), 0))

                            # Reshape actions if needed when using discrete actions
                            if isinstance(self.action_space,
                                          gym.spaces.Discrete):
                                if len(ac_batch.shape) == 2:
                                    ac_batch = ac_batch[:, 0]
                                if len(ac_expert.shape) == 2:
                                    ac_expert = ac_expert[:, 0]
                            *newlosses, grad = self.reward_giver.lossandgrad(
                                ob_batch, ac_batch, ob_expert, ac_expert)
                            self.d_adam.update(self.allmean(grad),
                                               self.d_stepsize)
                            d_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

                        # lr: lengths and rewards
                        lr_local = (seg["ep_lens"], seg["ep_rets"],
                                    seg["ep_true_rets"])  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(
                            lr_local)  # list of tuples
                        lens, rews, true_rets = map(flatten_lists,
                                                    zip(*list_lr_pairs))
                        true_reward_buffer.extend(true_rets)
                    else:
                        # lr: lengths and rewards
                        lr_local = (seg["ep_lens"], seg["ep_rets"]
                                    )  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(
                            lr_local)  # list of tuples
                        lens, rews = map(flatten_lists, zip(*list_lr_pairs))
                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)

                    if len(len_buffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))
                        logger.record_tabular("EpRewMean",
                                              np.mean(reward_buffer))
                    if self.using_gail:
                        logger.record_tabular("EpTrueRewMean",
                                              np.mean(true_reward_buffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(
                        seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        return self
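
The method above is a TRPO/GAIL-style learn() loop in the style of the stable-baselines (TensorFlow 1.x) code base: it alternates sampling via traj_segment_generator, a conjugate-gradient policy step with a KL-constrained line search, and Adam updates of the value function. A minimal usage sketch is shown below; it assumes stable-baselines 2.x is installed, and the environment id, policy name, and hyperparameters are illustrative only.

import gym
from stable_baselines import TRPO

env = gym.make("CartPole-v1")
model = TRPO("MlpPolicy", env,
             timesteps_per_batch=1024,  # batch size collected by traj_segment_generator
             max_kl=0.01,               # trust-region radius enforced by the line search
             verbose=1)
model.learn(total_timesteps=25000, tb_log_name="TRPO")
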
Example #2
    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100,
              tb_log_name="TRPO",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        print("args", self.kappa, self.vf_phi_update_interval, seed)

        with SetVerbosity(self.verbose):
            #self._setup_learn(seed)

            with self.sess.as_default():
                seg_gen = traj_segment_generator(
                    self.policy_pi,
                    self.policy_phi_old,
                    self.env,
                    self.timesteps_per_batch,
                    self.gamma,
                    self.kappa,
                    reward_giver=self.reward_giver)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                len_buffer = deque(
                    maxlen=40)  # rolling buffer for episode lengths
                reward_buffer = deque(
                    maxlen=40)  # rolling buffer for episode rewards
                vf_loss_buffer = deque(maxlen=80)  # rolling buffer for vf loss
                vf_phi_loss_buffer = deque(
                    maxlen=80)  # rolling buffer for vf phi loss
                self.episode_reward = np.zeros((self.n_envs, ))

                true_reward_buffer = None

                while True:
                    if callback is not None:
                        # Only stop training if return value is False, not when it is None. This is for backwards
                        # compatibility with callbacks that have no return statement.
                        if callback(locals(), globals()) is False:
                            break
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    def fisher_vector_product(vec):
                        return self.allmean(
                            self.compute_fvp(
                                vec, *fvpargs,
                                sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for k in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()
                        add_vtarg_and_adv(seg, self.gamma, self.kappa)
                        #print("seg is", seg["tdlamret"], seg["tdlamret_phi"])
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        iters_val = int(
                            (2 * int(iters_so_far / self.vf_phi_update_interval) + 1)
                            * self.vf_phi_update_interval / 2)
                        print("iters_val", iters_val,
                              int(iters_so_far / self.vf_phi_update_interval),
                              self.vf_phi_update_interval)
                        vf_loss = []

                        #if iters_so_far < iters_val:
                        print("optimizing surrogate mdp...")
                        observation, action = seg["observations"], seg["actions"]
                        atarg, tdlamret = seg["adv"], seg["tdlamret"]

                        vpredbefore = seg["vpred"]  # predicted value function before update
                        # standardized advantage function estimate
                        atarg = (atarg - atarg.mean()) / atarg.std()

                        args = seg["observations"], seg["observations"], seg["actions"], atarg
                        # Subsampling: see p40-42 of John Schulman thesis
                        # http://joschu.net/docs/thesis.pdf
                        fvpargs = [arr[::5] for arr in args]

                        self.assign_old_eq_new(sess=self.sess)

                        with self.timed("computegrad"):
                            steps = self.num_timesteps + (k + 1) * (
                                seg["total_timestep"] / self.g_step)
                            run_options = tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE)
                            run_metadata = None
                            # run loss backprop with summary, and save the metadata (memory, compute time, ...)
                            _, grad, *lossbefore = self.compute_lossandgrad(
                                *args,
                                tdlamret,
                                sess=self.sess,
                                options=run_options,
                                run_metadata=run_metadata)

                        lossbefore = self.allmean(np.array(lossbefore))
                        grad = self.allmean(grad)
                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            with self.timed("conjugate_gradient"):
                                stepdir = conjugate_gradient(
                                    fisher_vector_product,
                                    grad,
                                    cg_iters=self.cg_iters,
                                    verbose=self.rank == 0
                                    and self.verbose >= 1)
                            assert np.isfinite(stepdir).all()
                            shs = .5 * stepdir.dot(
                                fisher_vector_product(stepdir))
                            # abs(shs) to avoid taking square root of negative values
                            lagrange_multiplier = np.sqrt(
                                abs(shs) / self.max_kl)
                            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                            fullstep = stepdir / lagrange_multiplier
                            expectedimprove = grad.dot(fullstep)
                            surrbefore = lossbefore[0]
                            stepsize = 1.0
                            thbefore = self.get_flat()
                            thnew = None
                            for _ in range(10):
                                thnew = thbefore + fullstep * stepsize
                                self.set_from_flat(thnew)
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(
                                        self.compute_losses(*args,
                                                            sess=self.sess)))
                                improve = surr - surrbefore
                                logger.log("Expected: %.3f Actual: %.3f" %
                                           (expectedimprove, improve))
                                if not np.isfinite(mean_losses).all():
                                    logger.log(
                                        "Got non-finite value of losses -- bad!"
                                    )
                                elif kl_loss > self.max_kl * 1.5:
                                    logger.log(
                                        "violated KL constraint. shrinking step."
                                    )
                                elif improve < 0:
                                    logger.log(
                                        "surrogate didn't improve. shrinking step."
                                    )
                                else:
                                    logger.log("Stepsize OK!")
                                    break
                                stepsize *= .5
                            else:
                                logger.log("couldn't compute a good step")
                                self.set_from_flat(thbefore)
                            if self.nworkers > 1 and iters_so_far % 20 == 0:
                                # list of tuples
                                paramsums = MPI.COMM_WORLD.allgather(
                                    (thnew.sum(), self.vfadam.getflat().sum()))
                                assert all(
                                    np.allclose(ps, paramsums[0])
                                    for ps in paramsums[1:])

                        with self.timed("vf"):
                            #vf_loss = []
                            for _ in range(self.vf_iters):
                                # NOTE: for recurrent policies, use shuffle=False?
                                for (mbob, mbret) in dataset.iterbatches(
                                    (seg["observations"], seg["tdlamret"]),
                                        include_final_partial_batch=False,
                                        batch_size=128,
                                        shuffle=True):
                                    grad = self.allmean(
                                        self.compute_vflossandgrad(
                                            mbob, mbob, mbret, sess=self.sess))
                                    self.vfadam.update(grad, self.vf_stepsize)
                                    vf_loss.append(
                                        self.compute_vf_loss(mbob,
                                                             mbob,
                                                             mbret,
                                                             sess=self.sess))

                        vf_phi_loss = []
                        #if iters_so_far >= iters_val: #self.vf_phi_update_interval == 0:
                        print("evaluating policy...")
                        with self.timed("vf_phi"):
                            #vf_phi_loss = []
                            for _ in range(self.vf_iters):  # self.vf_phi_update_interval
                                with self.temp_seed(self.phi_seed):
                                    for (mbob, mbret) in dataset.iterbatches(
                                        (seg["observations"],
                                         seg["tdlamret_phi"]),
                                            include_final_partial_batch=False,
                                            batch_size=128,
                                            shuffle=True):
                                        grad = self.allmean(
                                            self.compute_vf_phi_lossandgrad(
                                                mbob,
                                                mbob,
                                                mbret,
                                                sess=self.sess))
                                        #print("vf phi loss before", self.compute_vf_phi_loss(mbob, mbob, mbret, sess=self.sess))
                                        #print("vf loss before phi is updated", self.compute_vf_loss(mbob, mbob, mbret_vf, sess=self.sess))
                                        self.vf_phi_adam.update(
                                            grad, self.vf_stepsize)
                                        vf_phi_loss.append(
                                            self.compute_vf_phi_loss(
                                                mbob,
                                                mbob,
                                                mbret,
                                                sess=self.sess))
                                        #print("gradient value", grad)
                                        #print("vf phi loss after", self.compute_vf_phi_loss(mbob, mbob, mbret, sess=self.sess))
                                        #print("vf loss after phi is updated", self.compute_vf_loss(mbob, mbob, mbret_vf, sess=self.sess))
                                    self.phi_seed += 1

                        #print("vf phi old loss", self.compute_vf_phi_old_loss(self.phi_old_obs, sess=self.sess))
                        if iters_so_far % self.vf_phi_update_interval == 0:
                            with self.timed("vf_phi_old_update"):
                                self.update_phi_old(sess=self.sess)
                                #update_variables = set(self.policy_vars + self.oldpolicy_vars)
                                #self.sess.run(tf.variables_initializer(update_variables))

                        if iters_so_far % self.eval_freq == 0:
                            if self.test_env is not None:
                                self.eval_reward = self.evaluate_agent(
                                    self.test_env, self.eval_freq)

                    if mean_losses is None:
                        mean_losses = [None, None, None, None, None]
                    for (loss_name, loss_val) in zip(self.loss_names,
                                                     mean_losses):
                        logger.record_tabular(loss_name, loss_val)

                    if vpredbefore is not None:
                        logger.record_tabular(
                            "explained_variance_tdlam_before",
                            explained_variance(vpredbefore, tdlamret))
                    else:
                        logger.record_tabular(
                            "explained_variance_tdlam_before", None)

                    # lr: lengths and rewards
                    lr_local = (seg["ep_lens"], seg["ep_rets"])  # local values
                    list_lr_pairs = MPI.COMM_WORLD.allgather(
                        lr_local)  # list of tuples
                    lens, rews = map(flatten_lists, zip(*list_lr_pairs))
                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)
                    vf_loss_buffer.extend(vf_loss)
                    vf_phi_loss_buffer.extend(vf_phi_loss)

                    if len(len_buffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))
                        logger.record_tabular("EpRewMean",
                                              np.mean(reward_buffer))
                        logger.record_tabular("VfLoss",
                                              np.mean(vf_loss_buffer))
                        logger.record_tabular("VfPhiLoss",
                                              np.mean(vf_phi_loss_buffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(
                        seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)
                    logger.record_tabular("EvaluationScore", self.eval_reward)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        return self
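
Example #2 is a modified TRPO loop that maintains a second value function ("phi"): it is trained on seg["tdlamret_phi"] targets inside self.temp_seed(self.phi_seed) so the minibatch shuffling stays reproducible, and its "old" copy is refreshed every vf_phi_update_interval iterations via update_phi_old. The temp_seed helper is not shown here; a plausible sketch, assuming it only needs to scope NumPy's global RNG, is:

import contextlib
import numpy as np

@contextlib.contextmanager
def temp_seed(seed):
    # Hypothetical stand-in for self.temp_seed: seed NumPy for the duration
    # of the block, then restore the previous RNG state so later sampling
    # (e.g. the environment rollouts) is unaffected.
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)

Because the loop increments self.phi_seed after each pass, successive vf_phi updates would still see different, but reproducible, minibatch orderings.
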
Example #3
    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100):
        with SetVerbosity(self.verbose):
            self._setup_learn(seed)

            with self.sess.as_default():
                seg_gen = traj_segment_generator(
                    self.policy_pi,
                    self.env,
                    self.timesteps_per_batch,
                    reward_giver=self.reward_giver,
                    gail=self.using_gail)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                lenbuffer = deque(
                    maxlen=40)  # rolling buffer for episode lengths
                rewbuffer = deque(
                    maxlen=40)  # rolling buffer for episode rewards

                true_rewbuffer = None
                if self.using_gail:
                    true_rewbuffer = deque(maxlen=40)
                    #  Stats not used for now
                    #  g_loss_stats = Stats(loss_names)
                    #  d_loss_stats = Stats(reward_giver.loss_name)
                    #  ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"])

                    # if provide pretrained weight
                    if self.pretrained_weight is not None:
                        tf_util.load_state(
                            self.pretrained_weight,
                            var_list=tf_util.get_globals_vars("pi"),
                            sess=self.sess)

                while True:
                    if callback:
                        callback(locals(), globals())
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    def fisher_vector_product(vec):
                        return self.allmean(
                            self.compute_fvp(
                                vec, *fvpargs,
                                sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for _ in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()
                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        observation, action = seg["ob"], seg["ac"]
                        atarg, tdlamret = seg["adv"], seg["tdlamret"]
                        vpredbefore = seg["vpred"]  # predicted value function before update
                        # standardized advantage function estimate
                        atarg = (atarg - atarg.mean()) / atarg.std()

                        args = seg["ob"], seg["ob"], seg["ac"], atarg
                        fvpargs = [arr[::5] for arr in args]

                        self.assign_old_eq_new(sess=self.sess)

                        with self.timed("computegrad"):
                            *lossbefore, grad = self.compute_lossandgrad(
                                *args, sess=self.sess)
                        lossbefore = self.allmean(np.array(lossbefore))
                        grad = self.allmean(grad)
                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            with self.timed("cg"):
                                stepdir = conjugate_gradient(
                                    fisher_vector_product,
                                    grad,
                                    cg_iters=self.cg_iters,
                                    verbose=self.rank == 0
                                    and self.verbose >= 1)
                            assert np.isfinite(stepdir).all()
                            shs = .5 * stepdir.dot(
                                fisher_vector_product(stepdir))
                            # abs(shs) to avoid taking square root of negative values
                            lagrange_multiplier = np.sqrt(
                                abs(shs) / self.max_kl)
                            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                            fullstep = stepdir / lagrange_multiplier
                            expectedimprove = grad.dot(fullstep)
                            surrbefore = lossbefore[0]
                            stepsize = 1.0
                            thbefore = self.get_flat()
                            thnew = None
                            for _ in range(10):
                                thnew = thbefore + fullstep * stepsize
                                self.set_from_flat(thnew)
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(
                                        self.compute_losses(*args,
                                                            sess=self.sess)))
                                improve = surr - surrbefore
                                logger.log("Expected: %.3f Actual: %.3f" %
                                           (expectedimprove, improve))
                                if not np.isfinite(mean_losses).all():
                                    logger.log(
                                        "Got non-finite value of losses -- bad!"
                                    )
                                elif kl_loss > self.max_kl * 1.5:
                                    logger.log(
                                        "violated KL constraint. shrinking step."
                                    )
                                elif improve < 0:
                                    logger.log(
                                        "surrogate didn't improve. shrinking step."
                                    )
                                else:
                                    logger.log("Stepsize OK!")
                                    break
                                stepsize *= .5
                            else:
                                logger.log("couldn't compute a good step")
                                self.set_from_flat(thbefore)
                            if self.nworkers > 1 and iters_so_far % 20 == 0:
                                # list of tuples
                                paramsums = MPI.COMM_WORLD.allgather(
                                    (thnew.sum(), self.vfadam.getflat().sum()))
                                assert all(
                                    np.allclose(ps, paramsums[0])
                                    for ps in paramsums[1:])

                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                for (mbob, mbret) in dataset.iterbatches(
                                    (seg["ob"], seg["tdlamret"]),
                                        include_final_partial_batch=False,
                                        batch_size=128):
                                    grad = self.allmean(
                                        self.compute_vflossandgrad(
                                            mbob, mbob, mbret, sess=self.sess))
                                    self.vfadam.update(grad, self.vf_stepsize)

                    for (loss_name, loss_val) in zip(self.loss_names,
                                                     mean_losses):
                        logger.record_tabular(loss_name, loss_val)

                    logger.record_tabular(
                        "ev_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    if self.using_gail:
                        # ------------------ Update D ------------------
                        logger.log("Optimizing Discriminator...")
                        logger.log(fmt_row(13, self.reward_giver.loss_name))
                        ob_expert, ac_expert = self.expert_dataset.get_next_batch(
                            len(observation))
                        batch_size = len(observation) // self.d_step
                        # list of tuples, each of which gives the loss for a minibatch
                        d_losses = []
                        for ob_batch, ac_batch in dataset.iterbatches(
                            (observation, action),
                                include_final_partial_batch=False,
                                batch_size=batch_size):
                            ob_expert, ac_expert = self.expert_dataset.get_next_batch(
                                len(ob_batch))
                            # update running mean/std for reward_giver
                            if hasattr(self.reward_giver, "obs_rms"):
                                self.reward_giver.obs_rms.update(
                                    np.concatenate((ob_batch, ob_expert), 0))
                            *newlosses, grad = self.reward_giver.lossandgrad(
                                ob_batch, ac_batch, ob_expert, ac_expert)
                            self.d_adam.update(self.allmean(grad),
                                               self.d_stepsize)
                            d_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

                        lrlocal = (seg["ep_lens"], seg["ep_rets"],
                                   seg["ep_true_rets"])  # local values
                        listoflrpairs = MPI.COMM_WORLD.allgather(
                            lrlocal)  # list of tuples
                        lens, rews, true_rets = map(flatten_lists,
                                                    zip(*listoflrpairs))
                        true_rewbuffer.extend(true_rets)
                    else:
                        lrlocal = (seg["ep_lens"], seg["ep_rets"]
                                   )  # local values
                        listoflrpairs = MPI.COMM_WORLD.allgather(
                            lrlocal)  # list of tuples
                        lens, rews = map(flatten_lists, zip(*listoflrpairs))
                    lenbuffer.extend(lens)
                    rewbuffer.extend(rews)

                    logger.record_tabular("EpLenMean", np.mean(lenbuffer))
                    logger.record_tabular("EpRewMean", np.mean(rewbuffer))
                    if self.using_gail:
                        logger.record_tabular("EpTrueRewMean",
                                              np.mean(true_rewbuffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    timesteps_so_far += seg["total_timestep"]
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", timesteps_so_far)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        return self
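
All of these variants poll an optional callback(locals(), globals()) at the top of every iteration; Examples #1, #2 and #4 stop training only when it returns False (None is ignored for backwards compatibility), while Example #3 discards the return value. A minimal early-stopping callback compatible with that legacy convention could look like the sketch below; the 500000-step threshold is purely illustrative.

def stop_after_500k(locals_, globals_):
    # Legacy stable-baselines callback signature: invoked as callback(locals(), globals()).
    # Returning False stops training; returning None (no return statement) keeps going.
    return locals_.get("timesteps_so_far", 0) < 500000

# model.learn(total_timesteps=1000000, callback=stop_after_500k)
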
Example #4
    def learn(self,
              total_timesteps,
              callback=None,
              seed=None,
              log_interval=100,
              tb_log_name="MDPO",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        print("got seed, sgd_steps", seed, self.sgd_steps)

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:

            with self.sess.as_default():
                seg_gen = traj_segment_generator(
                    self.old_policy,
                    self.env,
                    self.timesteps_per_batch,
                    reward_giver=self.reward_giver,
                    gail=self.using_gail,
                    entcoeff=self.entcoeff)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                len_buffer = deque(
                    maxlen=40)  # rolling buffer for episode lengths
                reward_buffer = deque(
                    maxlen=40)  # rolling buffer for episode rewards
                self.episode_reward = np.zeros((self.n_envs, ))
                self.outer_learning_rate = get_schedule_fn(3e-4)
                self.cliprange_vf = get_schedule_fn(0.2)

                true_reward_buffer = None
                if self.using_gail:
                    true_reward_buffer = deque(maxlen=40)

                    # Initialize dataloader
                    batchsize = self.timesteps_per_batch // self.d_step
                    self.expert_dataset.init_dataloader(batchsize)

                    #  Stats not used for now
                    # TODO: replace with normal tb logging
                    #  g_loss_stats = Stats(loss_names)
                    #  d_loss_stats = Stats(reward_giver.loss_name)
                    #  ep_stats = Stats(["True_rewards", "Rewards", "Episode_length"])

                while True:
                    if callback is not None:
                        # Only stop training if return value is False, not when it is None. This is for backwards
                        # compatibility with callbacks that have no return statement.
                        if callback(locals(), globals()) is False:
                            break
                    if total_timesteps and timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

                    #def fisher_vector_product(vec):
                    #    return self.allmean(self.compute_fvp(vec, *fvpargs, sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for k in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()
                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        observation, action = seg["observations"], seg["actions"]
                        atarg, tdlamret = seg["adv"], seg["tdlamret"]

                        vpredbefore = seg["vpred"]  # predicted value function before update
                        # standardized advantage function estimate
                        atarg = (atarg - atarg.mean()) / atarg.std()

                        # true_rew is the reward without discount
                        if writer is not None:
                            self.episode_reward = total_episode_reward_logger(
                                self.episode_reward,
                                seg["true_rewards"].reshape(
                                    (self.n_envs, -1)), seg["dones"].reshape(
                                        (self.n_envs, -1)), writer,
                                self.num_timesteps)

                        n_updates = int(total_timesteps /
                                        self.timesteps_per_batch)
                        lr_now = 1.0 - (iters_so_far - 1.0) / n_updates
                        outer_lr_now = self.outer_learning_rate(
                            1.0 - (iters_so_far - 1.0) / n_updates)
                        clip_now = self.cliprange_vf(1.0 -
                                                     (iters_so_far - 1.0) /
                                                     n_updates)
                        args = seg["observations"], seg["observations"], seg["actions"], atarg
                        # Subsampling: see p40-42 of John Schulman thesis
                        # http://joschu.net/docs/thesis.pdf
                        #fvpargs = [arr[::5] for arr in args]

                        with self.timed("computegrad"):
                            steps = self.num_timesteps + (k + 1) * (
                                seg["total_timestep"] / self.g_step)
                            run_options = tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE)
                            run_metadata = tf.RunMetadata() if self.full_tensorboard_log else None
                            # run loss backprop with summary, and save the metadata (memory, compute time, ...)
                            if writer is not None:
                                summary, grad, *lossbefore = self.compute_lossandgrad(
                                    *args,
                                    tdlamret,
                                    sess=self.sess,
                                    options=run_options,
                                    run_metadata=run_metadata)
                                if self.full_tensorboard_log:
                                    writer.add_run_metadata(
                                        run_metadata, 'step%d' % steps)
                                writer.add_summary(summary, steps)
                            else:
                                _, grad, *lossbefore = self.compute_lossandgrad(
                                    *args,
                                    tdlamret,
                                    lr_now,
                                    seg["vpred"],
                                    seg["observations"],
                                    sess=self.sess,
                                    options=run_options,
                                    run_metadata=run_metadata)
                                td_map = {
                                    self.policy_pi.obs_ph: seg["observations"],
                                    self.old_policy.obs_ph:
                                    seg["observations"],
                                    self.closed_policy.obs_ph:
                                    seg["observations"],
                                    self.action: seg["actions"],
                                    self.atarg: atarg,
                                    self.ret: tdlamret,
                                    self.learning_rate_ph: lr_now,
                                    self.outer_learning_rate_ph: outer_lr_now,
                                    self.vtarg: seg["vpred"]
                                }
                                for _ in range(int(self.sgd_steps)):
                                    _ = self.sess.run(self._train, td_map)
                                    #if self.method == "closed-KL":
                                    #    _ = self.sess.run(self._train_policy, td_map)

                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            for _ in range(1):
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(
                                        self.compute_losses(*args,
                                                            lr_now,
                                                            seg["vpred"],
                                                            sess=self.sess)))

                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                # NOTE: for recurrent policies, use shuffle=False?
                                for (mbob, mbret,
                                     mbval) in dataset.iterbatches(
                                         (seg["observations"], seg["tdlamret"],
                                          seg["vpred"]),
                                         include_final_partial_batch=False,
                                         batch_size=128,
                                         shuffle=True):
                                    grad = self.allmean(
                                        self.compute_vflossandgrad(
                                            mbob,
                                            mbob,
                                            mbret,
                                            mbval,
                                            clip_now,
                                            sess=self.sess))
                                    self.vfadam.update(
                                        grad, outer_lr_now)  #self.vf_stepsize)

                        if iters_so_far % 1 == 0:
                            print("updating theta now")
                            self.assign_old_eq_new(sess=self.sess)

                    for (loss_name, loss_val) in zip(self.loss_names,
                                                     mean_losses):
                        logger.record_tabular(loss_name, loss_val)

                    logger.record_tabular(
                        "explained_variance_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    if self.using_gail:
                        # ------------------ Update D ------------------
                        logger.log("Optimizing Discriminator...")
                        logger.log(fmt_row(13, self.reward_giver.loss_name))
                        assert len(observation) == self.timesteps_per_batch
                        batch_size = self.timesteps_per_batch // self.d_step

                        # NOTE: uses only the last g step for observation
                        # list of tuples, each of which gives the loss for a minibatch
                        d_losses = []
                        # NOTE: for recurrent policies, use shuffle=False?
                        for ob_batch, ac_batch in dataset.iterbatches(
                            (observation, action),
                                include_final_partial_batch=False,
                                batch_size=batch_size,
                                shuffle=True):
                            ob_expert, ac_expert = self.expert_dataset.get_next_batch()
                            # update running mean/std for reward_giver
                            if self.reward_giver.normalize:
                                self.reward_giver.obs_rms.update(
                                    np.concatenate((ob_batch, ob_expert), 0))

                            # Reshape actions if needed when using discrete actions
                            if isinstance(self.action_space,
                                          gym.spaces.Discrete):
                                if len(ac_batch.shape) == 2:
                                    ac_batch = ac_batch[:, 0]
                                if len(ac_expert.shape) == 2:
                                    ac_expert = ac_expert[:, 0]
                            *newlosses, grad = self.reward_giver.lossandgrad(
                                ob_batch, ac_batch, ob_expert, ac_expert)
                            self.d_adam.update(self.allmean(grad),
                                               self.d_stepsize)
                            d_losses.append(newlosses)
                        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

                        # lr: lengths and rewards
                        lr_local = (seg["ep_lens"], seg["ep_rets"],
                                    seg["ep_true_rets"])  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(
                            lr_local)  # list of tuples
                        lens, rews, true_rets = map(flatten_lists,
                                                    zip(*list_lr_pairs))
                        true_reward_buffer.extend(true_rets)
                    else:
                        # lr: lengths and rewards
                        lr_local = (seg["ep_lens"], seg["ep_rets"]
                                    )  # local values
                        list_lr_pairs = MPI.COMM_WORLD.allgather(
                            lr_local)  # list of tuples
                        lens, rews = map(flatten_lists, zip(*list_lr_pairs))
                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)

                    if len(len_buffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))
                        logger.record_tabular("EpRewMean",
                                              np.mean(reward_buffer))
                    if self.using_gail:
                        logger.record_tabular("EpTrueRewMean",
                                              np.mean(true_reward_buffer))
                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(
                        seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)
                    logger.record_tabular("Tsallis-q", self.tsallis_q)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        return self
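Usage note for the example above: the GAIL branch assumes an expert dataset and a discriminator (reward_giver) have already been attached to the model before learn() is called. A minimal sketch of how that is typically wired up with the upstream stable-baselines 2.x API; the expert file name, environment, and step count are placeholders, and this modified TRPO class may expose a slightly different constructor:

from stable_baselines import GAIL
from stable_baselines.gail import ExpertDataset

# Expert demonstrations recorded beforehand (e.g. with generate_expert_traj);
# 'expert_pendulum.npz' is a placeholder path.
dataset = ExpertDataset(expert_path='expert_pendulum.npz',
                        traj_limitation=10, verbose=1)

# The GAIL wrapper builds the discriminator (reward_giver), attaches the
# expert dataset, and then runs a training loop like the one shown above.
model = GAIL('MlpPolicy', 'Pendulum-v0', dataset, verbose=1)
model.learn(total_timesteps=1000000)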
Example #5
0
    def learn(self,
              total_timesteps,
              callback=None,
              log_interval=100,
              tb_log_name="TRPO",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        with SetVerbosity(self.verbose), TensorboardWriter(
                self.graph, self.tensorboard_log, tb_log_name,
                new_tb_log) as writer:
            self._setup_learn()

            with self.sess.as_default():
                callback.on_training_start(locals(), globals())

                seg_gen = traj_segment_generator(self.policy_pi,
                                                 self.env,
                                                 self.timesteps_per_batch,
                                                 callback=callback)

                episodes_so_far = 0
                timesteps_so_far = 0
                iters_so_far = 0
                t_start = time.time()
                len_buffer = deque(
                    maxlen=40)  # rolling buffer for episode lengths
                reward_buffer = deque(
                    maxlen=40)  # rolling buffer for episode rewards

                while True:
                    if timesteps_so_far >= total_timesteps:
                        break

                    logger.log("********** Iteration %i ************" %
                               iters_so_far)

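                    # Fisher-vector product: Hessian of the mean KL divergence w.r.t.
                    # the policy parameters applied to `vec`, averaged over MPI workers,
                    # plus a damping term; used by the conjugate-gradient solver below.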
                    def fisher_vector_product(vec):
                        return self.allmean(
                            self.compute_fvp(
                                vec, *fvpargs,
                                sess=self.sess)) + self.cg_damping * vec

                    # ------------------ Update G ------------------
                    logger.log("Optimizing Policy...")
                    # g_step = 1 when not using GAIL
                    mean_losses = None
                    vpredbefore = None
                    tdlamret = None
                    observation = None
                    action = None
                    seg = None
                    for k in range(self.g_step):
                        with self.timed("sampling"):
                            seg = seg_gen.__next__()

                        # Stop training early (triggered by the callback)
                        if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                            break

                        add_vtarg_and_adv(seg, self.gamma, self.lam)
                        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
                        observation, action = seg["observations"], seg[
                            "actions"]
                        atarg, tdlamret = seg["adv"], seg["tdlamret"]

                        vpredbefore = seg[
                            "vpred"]  # predicted value function before update
                        atarg = (atarg - atarg.mean()) / (
                            atarg.std() + 1e-8
                        )  # standardized advantage function estimate

                        print('advantages: ', np.min(atarg), np.max(atarg),
                              np.mean(atarg))
                        # true_rew is the reward without discount
                        if writer is not None:
                            total_episode_reward_logger(
                                self.episode_reward,
                                seg["true_rewards"].reshape(
                                    (self.n_envs, -1)), seg["dones"].reshape(
                                        (self.n_envs, -1)), writer,
                                self.num_timesteps)

                        args = seg["observations"], seg["observations"], seg[
                            "actions"], atarg
                        # Subsampling: see p40-42 of John Schulman thesis
                        # http://joschu.net/docs/thesis.pdf
                        fvpargs = [arr[::5] for arr in args]

                        self.assign_old_eq_new(sess=self.sess)

                        with self.timed("computegrad"):
                            steps = self.num_timesteps + (k + 1) * (
                                seg["total_timestep"] / self.g_step)
                            run_options = tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE
                            ) if self.full_tensorboard_log else None
                            run_metadata = tf.RunMetadata() if self.full_tensorboard_log else None

                            _, grad, *lossbefore = self.compute_lossandgrad(
                                *args,
                                tdlamret,
                                sess=self.sess,
                                options=run_options,
                                run_metadata=run_metadata)

                        print('losses before:', lossbefore)
                        lossbefore = self.allmean(np.array(lossbefore))
                        grad = self.allmean(grad)
                        if np.allclose(grad, 0):
                            logger.log("Got zero gradient. not updating")
                        else:
                            with self.timed("conjugate_gradient"):
                                stepdir = conjugate_gradient(
                                    fisher_vector_product,
                                    grad,
                                    cg_iters=self.cg_iters,
                                    verbose=self.rank == 0
                                    and self.verbose >= 1)
                            assert np.isfinite(stepdir).all()
                            shs = .5 * stepdir.dot(
                                fisher_vector_product(stepdir))
                            # abs(shs) to avoid taking square root of negative values
                            lagrange_multiplier = np.sqrt(
                                abs(shs) / self.max_kl)
                            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
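                            # Rescale so the full step saturates the trust region: with
                            # shs = 0.5 * s^T H s, dividing by sqrt(shs / max_kl) yields a
                            # step whose quadratic KL estimate equals max_kl.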
                            fullstep = stepdir / lagrange_multiplier
                            expectedimprove = grad.dot(fullstep)
                            surrbefore = lossbefore[0]
                            stepsize = 1.0
                            thbefore = self.get_flat()
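                            # Backtracking line search: halve the step up to 10 times
                            # until the surrogate improves without violating the (relaxed)
                            # KL constraint; otherwise restore the old parameters.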
                            for _ in range(10):
                                thnew = thbefore + fullstep * stepsize
                                self.set_from_flat(thnew)
                                mean_losses = surr, kl_loss, *_ = self.allmean(
                                    np.array(
                                        self.compute_losses(*args,
                                                            sess=self.sess)))
                                improve = surr - surrbefore
                                logger.log("Expected: %.3f Actual: %.3f" %
                                           (expectedimprove, improve))
                                if not np.isfinite(mean_losses).all():
                                    logger.log(
                                        "Got non-finite value of losses -- bad!"
                                    )
                                elif kl_loss > self.max_kl * 1.5:
                                    logger.log(
                                        "violated KL constraint. shrinking step."
                                    )
                                elif improve < 0:
                                    logger.log(
                                        "surrogate didn't improve. shrinking step."
                                    )
                                else:
                                    logger.log("Stepsize OK!")
                                    break
                                stepsize *= .5
                            else:
                                logger.log("couldn't compute a good step")
                                self.set_from_flat(thbefore)
                            if self.nworkers > 1 and iters_so_far % 20 == 0:
                                # list of tuples
                                paramsums = MPI.COMM_WORLD.allgather(
                                    (thnew.sum(), self.vfadam.getflat().sum()))
                                assert all(
                                    np.allclose(ps, paramsums[0])
                                    for ps in paramsums[1:])

                            for (loss_name,
                                 loss_val) in zip(self.loss_names,
                                                  mean_losses):
                                logger.record_tabular(loss_name, loss_val)

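                        # Value-function update: vf_iters epochs of minibatch Adam
                        # regression of the value network onto the TD(lambda) returns.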
                        with self.timed("vf"):
                            for _ in range(self.vf_iters):
                                # NOTE: for recurrent policies, use shuffle=False?
                                for (mbob, mbret) in dataset.iterbatches(
                                    (seg["observations"], seg["tdlamret"]),
                                        include_final_partial_batch=False,
                                        batch_size=128,
                                        shuffle=True):
                                    grad = self.allmean(
                                        self.compute_vflossandgrad(
                                            mbob, mbob, mbret, sess=self.sess))
                                    self.vfadam.update(grad, self.vf_stepsize)

                    # Stop training early (triggered by the callback)
                    if not seg.get('continue_training', True):  # pytype: disable=attribute-error
                        break

                    logger.record_tabular(
                        "explained_variance_tdlam_before",
                        explained_variance(vpredbefore, tdlamret))

                    # lr: lengths and rewards
                    lr_local = (seg["ep_lens"], seg["ep_rets"])  # local values
                    list_lr_pairs = MPI.COMM_WORLD.allgather(
                        lr_local)  # list of tuples
                    lens, rews = map(flatten_lists, zip(*list_lr_pairs))

                    len_buffer.extend(lens)
                    reward_buffer.extend(rews)

                    if len(len_buffer) > 0:
                        logger.record_tabular("EpLenMean", np.mean(len_buffer))
                        logger.record_tabular("EpRewMean",
                                              np.mean(reward_buffer))

                    logger.record_tabular("EpThisIter", len(lens))
                    episodes_so_far += len(lens)
                    current_it_timesteps = MPI.COMM_WORLD.allreduce(
                        seg["total_timestep"])
                    timesteps_so_far += current_it_timesteps
                    self.num_timesteps += current_it_timesteps
                    iters_so_far += 1

                    logger.record_tabular("EpisodesSoFar", episodes_so_far)
                    logger.record_tabular("TimestepsSoFar", self.num_timesteps)
                    logger.record_tabular("TimeElapsed", time.time() - t_start)

                    if self.verbose >= 1 and self.rank == 0:
                        logger.dump_tabular()

        callback.on_training_end()
        return self
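Usage note for the example above: this variant stops training early when the segment generator marks a segment with continue_training=False, which happens once the callback's _on_step() returns False. A minimal sketch, assuming the stable-baselines 2.10 BaseCallback interface and a hypothetical model instance of the class above; the 50000-step threshold is an arbitrary placeholder:

from stable_baselines.common.callbacks import BaseCallback


class StopAfterNSteps(BaseCallback):
    """Stops training once a fixed number of environment steps is reached."""

    def __init__(self, max_steps, verbose=0):
        super().__init__(verbose)
        self.max_steps = max_steps

    def _on_step(self) -> bool:
        # Returning False makes traj_segment_generator set
        # continue_training=False on the current segment, which breaks
        # out of the training loop shown above.
        return self.num_timesteps < self.max_steps


model.learn(total_timesteps=1000000, callback=StopAfterNSteps(50000))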