Example #1
def linesearch2(f,
                x,
                fullstep,
                max_kl,
                max_backtracks=20,
                backtrack_ratio=.8,
                allow_backwards_steps=False):
    fval, kl = f(x)
    logger.log("fval/kl before %f/%f" % (fval, kl))

    for (_n_backtracks,
         stepfrac) in enumerate(backtrack_ratio**np.arange(max_backtracks)):
        xnew = x - stepfrac * fullstep
        newfval, newkl = f(xnew)
        actual_improve = fval - newfval
        # expected_improve = expected_improve_rate * stepfrac
        # ratio = actual_improve / expected_improve
        # if ratio > accept_ratio and actual_improve > 0:
        logger.log(("a/kl %f/%f" % (actual_improve, newkl)))
        if not allow_backwards_steps:
            if newfval < fval and newkl <= max_kl:
                logger.log("backtrack iters: %d" % _n_backtracks)
                return xnew
        else:
            if newkl <= max_kl:
                logger.log("backtrack iters: %d" % _n_backtracks)
                return xnew
        logger.log("backtrack iters: %d" % _n_backtracks)
    return x
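
Note that linesearch2 expects f to return a (loss, KL) pair for a flat parameter vector, and it accepts the first candidate that both lowers the loss and respects max_kl (unless allow_backwards_steps is set). The following is a minimal, self-contained sketch of that calling convention with a hypothetical quadratic objective; the function and numbers are illustrative, not from the project.

import numpy as np

def toy_objective(x):
    # Hypothetical objective: returns (loss, kl) like the f passed to linesearch2.
    loss = float(np.sum((x - 1.0) ** 2))
    kl = float(0.01 * np.sum(x ** 2))
    return loss, kl

x0 = np.zeros(3)
fullstep = -np.ones(3)          # linesearch2 steps along x - stepfrac * fullstep
fval0, _ = toy_objective(x0)

# The same backtracking schedule as above: 1, 0.8, 0.8**2, ...
for stepfrac in 0.8 ** np.arange(20):
    xnew = x0 - stepfrac * fullstep
    newfval, newkl = toy_objective(xnew)
    if newfval < fval0 and newkl <= 0.03:   # mirrors the accept test with max_kl = 0.03
        break
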
Example #2
    def step(self):
        disc_process = time.time()
        mins, maxs, aves, stds = [], [], [], []
        if self.inverse:
            logger.log("Training policy and extracting rewards...")
            itera, paths = self.algo.step(
                paths_processor=lambda x:
                process_samples_with_reward_extractor(
                    x, self.discriminator, batch_size=50000))
        else:
            num_updates = 1
            if self.is_first_noninverse:
                num_updates = 1
                self.is_first_noninverse = False
            itera, paths = self.algo.step(train_steps=num_updates)
            self.iterations += 1
            return self.iterations

        # unroll and stack novice observations
        logger.log("Processing rollouts for discriminator....")
        disc_process = time.time()
        observations = [
            step for path in paths for step in path["observations"]
        ]
        observations = np.vstack(observations)
        novice_section = len(observations)

        # subsample experts according to size of observations
        #idx = np.random.randint(len(self.expert_rollouts_tensor), size=novice_section)
        useful_expert_rollouts = self.expert_rollouts_tensor  #[idx, :]

        observations = np.concatenate([observations, useful_expert_rollouts],
                                      axis=0)
        if hasattr(self.discriminator, "ob_rms"):
            self.discriminator.ob_rms.update(
                observations, self.discriminator.session
            )  # update running mean/std for policy

        labels = np.zeros((len(observations), ))
        labels[novice_section:] = 1.0
        labels = labels.reshape((-1, 1))

        #observations, labels = shuffle(observations, labels)
        disc_process_time = (time.time() - disc_process) / 60.0
        logger.log("Processed rollouts for disc in %f mins" % disc_process_time)

        # TODO: merge rollouts with experts and add labels
        logger.log("Updating disc with %d rollouts" % novice_section)
        disc_process = time.time()
        self.discriminator.step(observations, labels)
        disc_process_time = (time.time() - disc_process) / 60.0
        logger.log("Updated disc in %f mins" % disc_process_time)

        external_parameters = self.discriminator.get_external_parameters()

        if external_parameters is not None:
            self.algo.set_external_parameters(external_parameters)

        self.iterations += 1
        return self.iterations
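
The labeling convention here matters for the discriminator examples below: novice (current policy) observations are stacked first and labeled 0, and expert observations are appended after them and labeled 1, which is why later code slices the expert rows with [-num_experts:]. A tiny NumPy illustration of that layout (the shapes are made up):

import numpy as np

novice = np.random.randn(4, 3)   # policy observations, label 0
expert = np.random.randn(2, 3)   # expert observations, label 1
observations = np.concatenate([novice, expert], axis=0)

labels = np.zeros((len(observations), 1))
labels[len(novice):] = 1.0       # expert rows form the tail of the batch
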
Example #3
def process_samples_with_reward_extractor(samples, discriminator, batch_size=50000, obs_filter=lambda x: x):
    t0 = time.time()
    # convert all the data to the proper format, concat frames if needed
    # (obs_filter is applied once per step here; the filtered observations are then stacked)
    for sample in samples:
        sample["observations"] = np.vstack([obs_filter(step) for step in sample["observations"]])
    super_all_datas = np.vstack([sample["observations"] for sample in samples])

    extracted_rewards = []
    for batch in batchify_list(super_all_datas, batch_size):
        extracted_rewards.extend(discriminator.get_reward(batch)) #TODO: unnecessary computation here

    index = 0
    extracted_rewards = np.vstack(extracted_rewards)
    for sample in samples:
        sample['true_rewards'] = sample['raw_rewards']
        num_obs = len(sample['observations'])
        sample['raw_rewards'] = select_from_tensor(extracted_rewards, index, index+num_obs).reshape(-1)
        sample['rewards'] = select_from_tensor(extracted_rewards, index, index+num_obs).reshape(-1)
        if len(sample['true_rewards']) != len(sample['rewards']):
            raise Exception("Problem, extracted rewards not equal in length to old rewards!")
        index += num_obs

    t1 = time.time()
    logger.log("Time to process samples: %f" % (t1-t0))
    return samples
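
batchify_list and select_from_tensor are project helpers that are not shown in these examples. The sketch below captures only what this function appears to assume about them (chunked iteration over rows, and row slicing); it is my assumption rather than the project's actual implementation.

import numpy as np

def batchify_list(data, batch_size):
    # Assumed behavior: yield consecutive row-chunks of at most batch_size rows.
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]

def select_from_tensor(tensor, start, end):
    # Assumed behavior: rows [start, end) of a stacked 2-D array.
    return tensor[start:end]
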
Example #4
    def run(self):
        config = tf.ConfigProto(
            device_count = {'GPU': 0},
            #gpu_options = tf.GPUOptions(allow_growth=True)
        )
        self.session = tf.Session(config=config)
        self.make_model()
        stats = {}  # populated by LEARN_PATHS tasks; REBUILD_NET replies reuse the latest stats

        while True:
            task = self.task_q.get()
            if task.code == LearnerTask.KILL_CODE:
                # TODO: self.terminate?
                self.task_q.cancel_join_thread()
                self.session.close()
                return
            elif task.code == LearnerTask.SET_EXTERNAL_POLICY_VALUES:
                weights = task.extra_params["weights"]
                print("Setting extrernal policy params learning task")
                self.policy.set_external_values(self.session, weights)
                time.sleep(0.1)
                self.task_q.task_done()
            elif task.code == LearnerTask.GET_PARAMS_CODE:
                self.task_q.task_done()
                self.result_q.put(LearnerResult(self.policy.get_param_values(self.session, optimizer_params=False)))
                print("Getting model %d params" % len(self.policy.get_param_values(self.session, optimizer_params=False)))
            elif task.code == LearnerTask.PUT_PARAMS_CODE:
                params = task.extra_params["weights"]
                logger.log("Setting model %d params" % len(params))
                # print(params)
                # the task is to set parameters of the actor policy
                self.policy.set_param_values(self.session, params)
                # super hacky method to make sure when we fill the queue with set parameter tasks,
                # an actor doesn't finish updating before the other actors can accept their own tasks.
                time.sleep(0.1)
                self.task_q.task_done()
            elif task.code == LearnerTask.ADJUST_MAX_KL:
                self.task_q.task_done()
                self.args.max_kl = task.extra_params['max_kl']
            elif task.code == LearnerTask.LEARN_PATHS:
                paths = task.extra_params['paths']
                stats = self.learn(paths)
                self.task_q.task_done()
                self.result_q.put(LearnerResult(self.policy.get_param_values(self.session, optimizer_params=False), stats))
            elif task.code == LearnerTask.REBUILD_NET:
                self.timesteps_so_far = 0 
                self.policy.rebuild_net(**task.extra_params)
                initialize_uninitialized(self.session)
                self.make_model()
                self.task_q.task_done()
                self.result_q.put(LearnerResult(self.policy.get_param_values(self.session, optimizer_params=False), stats))
            else:
                logger.log("Received unknown code! (%d)" % task.code)

        return
Example #5
def load_expert_rollouts(filepath, max_traj_len = -1, num_expert_rollouts = 10):
    # why encoding? http://stackoverflow.com/questions/11305790/pickle-incompatability-of-numpy-arrays-between-python-2-and-3
    expert_rollouts = pickle.load(open(filepath, "rb"), encoding='latin1')

    # In the case that we only have one expert rollout in the file
    if type(expert_rollouts) is dict:
        expert_rollouts = [expert_rollouts]

    expert_rollouts = expert_rollouts[:min(len(expert_rollouts), num_expert_rollouts)]

    if max_traj_len > 0:
        expert_rollouts = [shorten_tensor_dict(x, max_traj_len) for x in expert_rollouts]

    logger.log("Average reward for expert rollouts: %f" % np.mean([np.sum(p['rewards']) for p in expert_rollouts]))
    return expert_rollouts
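
A hedged usage sketch follows; the file path and the "observations" key are illustrative assumptions, while "rewards" is the key the function itself relies on.

import numpy as np

# Hypothetical call; assumes the pickle holds a list of rollout dicts.
rollouts = load_expert_rollouts("experts/rollouts.pkl",
                                max_traj_len=1000,
                                num_expert_rollouts=5)
expert_obs = np.vstack([r["observations"] for r in rollouts])
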
Example #6
    def step(self, observations, labels, aux_logging=True, clip_weights=False):
        if self.normalize_obs:
            self.inputnorm.update(observations)

        ops = [
            self.train_op, self.loss, self.accuracy, self.accuracy_for_expert,
            self.accuracy_for_currpolicy, self.l2_loss
        ]

        for i in range(self.num_epochs_per_step):
            op_returns = self.session.run(ops,
                                          feed_dict={
                                              self.obs: observations,
                                              self.targets: labels
                                          })
            if clip_weights:
                self.session.run([self.clip_disc_weights_op])

        logger.log("Loss: %f" % op_returns[1])
        logger.log("LossL2: %f" % op_returns[5])
        logger.log("Accuracy: %f" % op_returns[2])
        logger.log("Accuracy (policy): %f" % op_returns[4])
        logger.log("Accuracy (expert): %f" % op_returns[3])
Example #7
def linesearch(f,
               x,
               fullstep,
               expected_improve_rate,
               max_backtracks=10,
               accept_ratio=.1,
               backtrack_ratio=.5):
    fval, kl = f(x)
    logger.log("fval before %f" % fval)

    for (_n_backtracks,
         stepfrac) in enumerate(backtrack_ratio**np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        newfval, kl = f(xnew)
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * stepfrac
        ratio = actual_improve / expected_improve
        logger.log(
            ("a/e/r %f/%f/%f" % (actual_improve, expected_improve, ratio)))
        if ratio > accept_ratio and actual_improve > 0:
            logger.log("fval after: %f" % (newfval))
            return xnew
    return x
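
In standard TRPO usage (an assumption about how this function is called, since its caller is not shown here), expected_improve_rate is the linearized improvement g·fullstep, where g is the flat policy gradient and fullstep is the conjugate-gradient step scaled by the Lagrange multiplier. A small NumPy sketch of those quantities with made-up values:

import numpy as np

max_kl = 0.01
g = np.array([0.3, -0.1, 0.2])      # flat policy gradient (illustrative numbers)
stepdir = g.copy()                  # stand-in for conjugate_gradient(F, g)
shs = 0.5 * stepdir.dot(g)          # 0.5 * s^T F s, with F approximated by I here
lm = np.sqrt(shs / max_kl)
fullstep = stepdir / lm
expected_improve_rate = g.dot(fullstep)
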
Example #8
File: model.py Project: Theling/OptionGAN
    def learn(self, paths):
        # is it possible to replace A(s,a) with Q(s,a)?
        for path in paths:
            b = path["baseline"] = self.vf.predict(path, self.session)
            b1 = np.append(b, 0 if path["terminated"] else b[-1])
            deltas = path["rewards"] + self.args.gamma * b1[1:] - b1[:-1]
            path["advantage"] = discount(deltas,
                                         self.args.gamma * self.args.lam)
            path["returns"] = discount(path["rewards"], self.args.gamma)

        alladv = np.concatenate([path["advantage"] for path in paths])
        # Standardize advantage
        std = alladv.std()
        mean = alladv.mean()
        for path in paths:
            path["advantage"] = (path["advantage"] - mean) / (std + 1e-8)
        advant_n = np.concatenate([path["advantage"] for path in paths])

        # puts all the experiences in a matrix: total_timesteps x options
        # TODO: make this policy dependent like in rllab
        paths_concated = concat_tensor_dict_list(paths)

        action_dist_mu = paths_concated["info"][
            "action_dist_mu"]  #np.concatenate([path["info"]["action_dist_mu"] for path in paths])
        action_dist_logstd = paths_concated["info"][
            "action_dist_logstd"]  #np.concatenate([path["info"]["action_dist_logstd"] for path in paths])

        obs_n = paths_concated[
            "observations"]  #np.concatenate([path["observations"] for path in paths])
        action_n = paths_concated[
            "actions"]  #np.concatenate([path["actions"] for path in paths])

        # train value function / baseline on rollout paths
        self.vf.fit(paths, self.session)
        if hasattr(self.policy, "ob_rms") and not isinstance(
                self.policy, GatedGaussianMLPPolicy):
            # In the case of a GatedGaussian policy, we're going to share the gate/filter provided to us
            self.policy.ob_rms.update(obs_n, self.session)

        # TODO: make this policy dependent like in rllab
        feed_dict = {
            self.obs: obs_n,
            self.action: action_n,
            self.advantage: advant_n,
            self.oldaction_dist_mu: action_dist_mu,
            self.oldaction_dist_logstd: action_dist_logstd
        }

        feed_dict.update(
            self.policy.get_extra_inputs(self.session, obs_n,
                                         paths_concated["info"]))

        loss_before, kl_before, entropy_before = self.session.run(
            self.losses, feed_dict)

        logger.log("loss_before, kl_before, ent_before : %f,%f,%f" %
                   (loss_before, kl_before, np.mean(entropy_before)))

        # parameters
        thprev = self.gf()

        fisher_vector_product = lambda x: self.hvp_func(
            x, session=self.session, feed_dict=feed_dict)

        g = self.session.run(self.pg, feed_dict)

        if np.allclose(g, 0):
            print("got zero gradient. not updating")
            return {}

        # solve Ax = g, where A is the Fisher information matrix and g is the gradient of the parameters
        # stepdir = A_inverse * g = x
        stepdir = conjugate_gradient(fisher_vector_product, g)

        # let stepdir =  change in theta / direction that theta changes in
        # KL divergence approximated by 0.5 x stepdir_transpose * [Fisher Information Matrix] * stepdir
        # where the [Fisher Information Matrix] acts like a metric
        # ([Fisher Information Matrix] * stepdir) is computed using the function,
        # and then stepdir * [above] is computed manually.

        shs = 0.5 * stepdir.dot(fisher_vector_product(stepdir))

        assert shs > 0

        lm = np.sqrt(shs / self.args.max_kl)

        fullstep = stepdir / lm

        logger.log("lagrange multiplier: %f gnorm: %f" %
                   (lm, np.linalg.norm(g)))

        def loss(th):
            self.sff(th)
            # surrogate loss: policy gradient loss
            return self.session.run(self.losses[:2], feed_dict)

        theta = linesearch2(loss, thprev, fullstep,
                            self.args.max_kl)  #negative_g_dot_steppdir / lm)
        self.sff(theta)

        surrogate_after, kl_after, entropy_after = self.session.run(
            self.losses, feed_dict)
        # print("new", self.session.run(self.action_dist_mu, feed_dict))

        if kl_after >= self.args.max_kl:
            logger.log(
                "Violated KL constraint, rejecting step! KL-After (%f)" %
                kl_after)
            self.sff(thprev)
        if np.isnan(surrogate_after) or np.isnan(kl_after):
            logger.log("Violated because loss or KL is NaN")
            self.sff(thprev)
        if loss_before <= surrogate_after:
            logger.log(
                "Violated because loss not improving... Prev (%f) After (%f)" %
                (loss_before, surrogate_after))
            self.sff(thprev)

        episoderewards = np.array(
            [path["raw_rewards"].sum() for path in paths])
        realepisoderewards = np.array(
            [path["rewards"].sum() for path in paths])
        stats = {}

        if "true_rewards" in paths[0]:
            truerewards = np.array(
                [path["true_rewards"].sum() for path in paths])
            stats["TrueAverageReturn"] = truerewards.mean()
            stats["TrueStdReturn"] = truerewards.std()
            stats["TrueMaxReturn"] = truerewards.max()
            stats["TrueMinReturn"] = truerewards.min()

        logger.log("Min return agent_id: %d" %
                   min(paths, key=lambda x: x["rewards"].sum())["agentid"])

        stats["ProcessedAverageReturn"] = realepisoderewards.mean()
        stats["ProcessedStdReturn"] = realepisoderewards.std()
        stats["ProcessedMaxReturn"] = realepisoderewards.max()
        stats["ProcessedMinReturn"] = realepisoderewards.min()
        baseline_paths = np.array([path['baseline'].sum() for path in paths])
        advantage_paths = np.array([path['advantage'].sum() for path in paths])
        stats["BaselineAverage"] = baseline_paths.mean()
        stats["AdvantageAverage"] = advantage_paths.mean()
        stats["Episodes"] = len(paths)
        stats["EpisodeAveLength"] = np.mean(
            [len(path["rewards"]) for path in paths])
        stats["EpisodeStdLength"] = np.std(
            [len(path["rewards"]) for path in paths])
        stats["EpisodeMinLength"] = np.min(
            [len(path["rewards"]) for path in paths])
        stats["EpisodeMaxLength"] = np.max(
            [len(path["rewards"]) for path in paths])
        stats["RawAverageReturn"] = episoderewards.mean()
        stats["RawStdReturn"] = episoderewards.std()
        stats["RawMaxReturn"] = episoderewards.max()
        stats["RawMinReturn"] = episoderewards.min()
        stats["Entropy"] = entropy_after
        stats["MaxKL"] = self.args.max_kl
        stats["Timesteps"] = sum([len(path["raw_rewards"]) for path in paths])
        # stats["Time elapsed"] = "%.2f mins" % ((time.time() - start_time) / 60.0)
        stats["KLDifference"] = kl_after
        stats["SurrogateLoss"] = surrogate_after
        # print ("\n********** Iteration {} ************".format(i))
        for k, v in sorted(stats.items()):
            logger.record_tabular(k, v)
        logger.dump_tabular()

        return stats
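
conjugate_gradient is referenced above but not shown. The standard TRPO version solves F x = g using only Fisher-vector products, which is exactly what fisher_vector_product supplies. Below is a minimal sketch of that routine (the standard algorithm, assumed to be close to the project's helper; the iteration count and tolerance are illustrative). With stepdir in hand, lm = sqrt(shs / max_kl) rescales the step so that the quadratic KL estimate 0.5 * fullstep^T F fullstep comes out at max_kl, which is the rescaling performed above.

import numpy as np

def conjugate_gradient_sketch(f_Ax, b, cg_iters=10, residual_tol=1e-10):
    # Solve A x = b given only the matrix-vector product f_Ax(v) = A v.
    x = np.zeros_like(b)
    r = b.copy()          # residual b - A x, with x = 0 initially
    p = b.copy()          # search direction
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        Ap = f_Ax(p)
        alpha = rdotr / p.dot(Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x
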
Example #9
File: model.py Project: Theling/OptionGAN
    def evaluate(self, epoch):
        logger.log("Collecting samples for evaluation")

        num_samples = 0
        paths = []
        # import pdb; pdb.set_trace()

        while num_samples < self.eval_samples:
            path = rollout(self.env, self.policy, self.max_path_length,
                           self.sess)
            num_samples += len(path["rewards"])
            paths.append(path)

        self.env.reset()

        returns = [np.sum(path["rewards"]) for path in paths]

        average_discounted_return = np.mean([
            discount_return(path["rewards"], self.discount_gamma)
            for path in paths
        ])

        all_qs = np.concatenate(self.q_averages)
        all_ys = np.concatenate(self.y_averages)

        average_q_loss = np.mean(self.qf_loss_averages)
        average_policy_surr = np.mean(self.policy_surr_averages)
        average_action = np.mean(
            np.square(np.concatenate([path["actions"] for path in paths])))

        # policy_reg_param_norm = np.linalg.norm(self.policy.get_param_values(self.sess))
        # qfun_reg_param_norm = np.linalg.norm(self.qf.get_param_values(self.sess))

        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Iteration', epoch)
        logger.record_tabular('AverageReturn', np.mean(returns))
        logger.record_tabular('StdReturn', np.std(returns))
        logger.record_tabular('MaxReturn', np.max(returns))
        logger.record_tabular('MinReturn', np.min(returns))
        if len(self.termination_averages) > 0:
            logger.record_tabular('TerminationVal',
                                  np.mean(self.termination_averages))
        if len(self.es_path_returns) > 0:
            logger.record_tabular('AverageEsReturn',
                                  np.mean(self.es_path_returns))
            logger.record_tabular('StdEsReturn', np.std(self.es_path_returns))
            logger.record_tabular('MaxEsReturn', np.max(self.es_path_returns))
            logger.record_tabular('MinEsReturn', np.min(self.es_path_returns))
        if hasattr(self.es, 'get_and_clear_losses'):
            logger.record_tabular('ESLoss', self.es.get_and_clear_losses())
        logger.record_tabular('AverageDiscountedReturn',
                              average_discounted_return)
        logger.record_tabular('AverageQLoss', average_q_loss)
        logger.record_tabular('AverageQL2Loss',
                              np.mean(self.qf_l2_loss_averages))
        logger.record_tabular('AveragePolicySurr', average_policy_surr)
        logger.record_tabular('AveragePolicyL2Loss',
                              np.mean(self.policy_l2_loss_averages))
        logger.record_tabular('AverageQ', np.mean(all_qs))
        logger.record_tabular('AverageAbsQ', np.mean(np.abs(all_qs)))
        logger.record_tabular('AverageY', np.mean(all_ys))
        logger.record_tabular('AverageAbsY', np.mean(np.abs(all_ys)))
        logger.record_tabular('AverageAbsQYDiff',
                              np.mean(np.abs(all_qs - all_ys)))
        if self.dual_asynchronous_q:
            all_q2s = np.concatenate(self.qf2values)
            logger.record_tabular('AverageQ2Loss', np.mean(self.qf2losses))
            logger.record_tabular('AverageQ2', np.mean(all_q2s))
            logger.record_tabular('AverageAbsQ2', np.mean(np.abs(all_q2s)))
        logger.record_tabular('AverageAction', average_action)
        logger.record_tabular('TotalTimesteps', self.total_timesteps)

        self.qf_loss_averages = []
        self.policy_surr_averages = []

        self.q_averages = []
        self.y_averages = []
        self.termination_averages = []
        self.es_path_returns = []
        return paths
Example #10
File: model.py Project: Theling/OptionGAN
    def step(self):
        """
        Step in this case counts as an epoch
        """
        itr = self.itr
        path_length = 0
        path_return = 0
        terminal = False
        initial = False
        self.epoch += 1
        observation = self.env.reset()

        logger.push_prefix('epoch #%d | ' % self.epoch)
        logger.log("Training started")
        train_qf_itr, train_policy_itr = 0, 0

        path = []
        for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
            # Execute policy
            if terminal:
                # print("terminal")
                # Note that if the last time step ends an episode, the very
                # last state and observation will be ignored and not added
                # to the replay pool
                observation = self.env.reset()
                if self.dual_asynchronous_q:
                    self.train_second_q(path)
                # sample_policy.reset()
                self.es_path_returns.append(path_return)
                self.es.reset()
                path_length = 0
                path_return = 0
                initial = True
            else:
                initial = False

            on_policy = False
            if (not self.average_on_policy_updates) or (
                    self.average_on_policy_updates
                    and bool(random.getrandbits(1))):
                action = self.es.act(itr,
                                     observation,
                                     sess=self.sess,
                                     policy=self.sample_policy)  # qf=qf)
            else:
                on_policy = True
                action, _ = self.sample_policy.act(
                    observation, self.sess
                )  # self.es.act(itr, observation, sess=self.sess, policy=)  # qf=qf)

            next_observation, reward, terminal, _ = self.env.step(action)
            self.total_timesteps += 1
            path_length += 1
            path_return += reward

            if not terminal and path_length >= self.max_path_length:
                terminal = True
                # TODO: fix this, This is only true of tasks where ending early is bad
                failure = False
                # only include the terminal transition in this case if the flag was set
                if self.include_horizon_terminal_transitions:
                    # TODO: initial?
                    if on_policy:
                        self.on_policy_pool.add(observation, action,
                                                reward * self.scale_reward,
                                                next_observation, terminal,
                                                failure)
                    else:
                        self.pool.add(observation, action,
                                      reward * self.scale_reward,
                                      next_observation, terminal, failure)
            else:
                failure = True
                sample = (observation, action, reward * self.scale_reward,
                          next_observation, terminal, failure)
                if on_policy:
                    self.on_policy_pool.add(*sample)
                # every transition goes to the main pool once; on-policy
                # transitions additionally go to the on-policy pool above
                self.pool.add(*sample)
                path.append(sample)

            observation = next_observation

            if len(self.pool) >= self.min_pool_size and (
                (not self.average_on_policy_updates)
                    or len(self.on_policy_pool) >= self.min_pool_size):
                if not self.soft_target and self.total_timesteps % self.target_network_update_freq == 0:
                    self.sess.run(self.update_target_policy)
                    self.sess.run(self.update_target_q)

                for update_itr in range(self.n_updates_per_sample):
                    # Train policy
                    if hasattr(self, "beta_schedule"):
                        batch = self.pool.sample(self.batch_size,
                                                 beta=self.beta_schedule.value(
                                                     self.total_timesteps))
                    else:
                        batch = self.pool.sample(self.batch_size)

                    on_polbatch = None

                    if self.average_on_policy_updates:
                        if hasattr(self, "beta_schedule"):
                            on_polbatch = self.on_policy_pool.sample(
                                self.batch_size,
                                beta=self.beta_schedule.value(
                                    self.total_timesteps))
                        else:
                            on_polbatch = self.on_policy_pool.sample(
                                self.batch_size)

                    itrs = self.do_training(itr, batch, on_polbatch)

                    train_qf_itr += itrs[0]
                    train_policy_itr += itrs[1]
                self.sess.run(self.update_sample_policy)

            itr += 1
            if time.time() - self.gc_dump_time > 100:
                gc.collect()
                self.gc_dump_time = time.time()

        logger.log("Training finished")
        logger.log("Trained qf %d steps, policy %d steps" %
                   (train_qf_itr, train_policy_itr))
        rollouts = []
        if len(self.pool) >= self.min_pool_size and (
            (not self.average_on_policy_updates)
                or len(self.on_policy_pool) >= self.min_pool_size):
            rollouts = self.evaluate(self.epoch)
        logger.dump_tabular(with_prefix=False)
        logger.pop_prefix()

        self.itr = itr
        return itr, rollouts
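
The replay pools used above (self.pool and self.on_policy_pool) are not defined in these examples. The sketch below only mirrors the interface this loop relies on, namely add(observation, action, reward, next_observation, terminal, failure), sample(batch_size), and len(); the project's actual pools (including the beta-weighted prioritized variant hinted at by beta_schedule) are certainly richer than this.

import random
from collections import deque

class SimpleReplayPoolSketch:
    # Assumed minimal interface, for illustration only.
    def __init__(self, max_size=1000000):
        self._data = deque(maxlen=max_size)

    def add(self, observation, action, reward, next_observation, terminal, failure):
        self._data.append((observation, action, reward, next_observation, terminal, failure))

    def sample(self, batch_size):
        return random.sample(list(self._data), min(batch_size, len(self._data)))

    def __len__(self):
        return len(self._data)
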
Example #11
    def run(self):
        self.env = gym.make(self.args.task)

        self.env.seed(randint(0, 999999))
        if self.monitor:
            self.env.monitor.start('monitor/', force=True)

        # self.observation_filter, self.reward_filter = get_filters(self.args, self.env.observation_space)

        # tensorflow variables (same as in model.py)
        self.observation_size = self.env.observation_space.shape[0]
        self.action_size = np.prod(self.env.action_space.shape)

        # tensorflow model of the policy
        self.obs = self.policy.obs
        self.debug = tf.constant([2, 2])
        # self.action_dist_mu = self.policy.action_dist_mu
        # self.action_dist_logstd = self.policy.action_dist_logstd

        config = tf.ConfigProto(device_count={'GPU': 0})
        self.session = tf.Session(config=config)
        initialize_uninitialized(self.session)
        # self.session.run(tf.global_variables_initializer())

        while True:
            # get a task, or wait until it gets one
            next_task = self.task_q.get(block=True)

            if next_task.code == SamplingTask.COLLECT_SAMPLES_CODE:
                # the task is an actor request to collect experience
                path = self.rollout()
                self.task_q.task_done()
                self.result_q.put(SamplingResult(path))
            elif next_task.code == SamplingTask.SET_EXTERNAL_POLICY_VALUES:
                weights = next_task.extra_params["weights"]
                print("Setting extrernal policy params")
                self.policy.set_external_values(self.session, weights)
                time.sleep(0.1)
                self.task_q.task_done()
            elif next_task.code == SamplingTask.KILL_CODE:
                logger.log("kill message")
                if self.monitor:
                    self.env.monitor.close()
                self.task_q.task_done()
                return
            elif next_task.code == SamplingTask.SET_ENV_TASK:
                env_name = next_task.extra_params["env"]
                print("setting new env! %s" % env_name)
                self.set_env(env_name)
                time.sleep(0.2)
                self.task_q.task_done()
            elif next_task.code == SamplingTask.PUT_PARAMS_CODE:
                params = next_task.extra_params["policy"]
                logger.log("Setting model %d params" % len(params))
                # print(params)
                # the task is to set parameters of the actor policy
                print("Setting rollout policy values of",
                      [x.name for x in self.policy.get_params()])
                self.policy.set_param_values(self.session, params)
                # super hacky method to make sure when we fill the queue with set parameter tasks,
                # an actor doesn't finish updating before the other actors can accept their own tasks.
                time.sleep(0.1)
                self.task_q.task_done()
            elif next_task.code == SamplingTask.REBUILD_NET:
                print("Rebuilding net with args: ", next_task.extra_params)
                self.policy.rebuild_net(**next_task.extra_params)
                initialize_uninitialized(self.session)
                time.sleep(0.2)
                self.task_q.task_done()
            else:
                logger.log("Rollout thread got unknown task...")
        logger.log("Rollout thread dying")
        return
Example #12
    def step(self, paths=None, paths_processor=lambda x: x, train_steps=1):
        logger.log("................Starting iteration................")
        with logger.prefix('itr #%d | ' % self.iteration):
            # runs a bunch of async processes that collect rollouts
            logger.log("Iteration %d" % self.iteration)
            if paths is None:
                rollout_start = time.time()

                paths = self.rollouts.rollout()

                rollout_time = (time.time() - rollout_start) / 60.0

            self.total_episodes += len(paths)
            self.total_timesteps += sum(
                [len(path["raw_rewards"]) for path in paths])
            logger.log("CumulativeEpisodes: %d" % self.total_episodes)
            logger.log("CumulativeTimesteps: %d" % self.total_timesteps)

            paths = paths_processor(
                paths
            )  # reward from discriminator is replaced in this function

            # Why is the learner in an async process?
            # Well, it turns out tensorflow has an issue: when there's a tf.Session in the main thread
            # and an async process creates another tf.Session, it will freeze up.
            # To solve this, we just make the learner's tf.Session in its own async process,
            # and wait until the learner's done before continuing the main thread.
            learn_start = time.time()
            for i in range(train_steps):
                self.learner_tasks.put(LearnFromPathsTask(paths))
                self.learner_tasks.join()
            results = self.learner_results.get()
            new_policy_weights, stats = results.policy, results.stats

            mean_reward = stats["RawAverageReturn"]
            std_reward = stats["RawStdReturn"]
            if "gate_dist" in paths[0]["info"]:
                gate_dists = []
                maxgate_dists = []
                mingate_dists = []
                for path in paths:
                    gate_dist = np.mean(path["info"]["gate_dist"], axis=0)
                    maxgate_dist = np.max(path["info"]["gate_dist"], axis=0)
                    mingate_dist = np.min(path["info"]["gate_dist"], axis=0)
                    gate_dists.append(gate_dist)
                    maxgate_dists.append(maxgate_dist)
                    mingate_dists.append(mingate_dist)
                gate_dists = np.vstack(gate_dists)
                logger.record_tabular("MeanGateDist",
                                      np.mean(gate_dists, axis=0))
                logger.record_tabular("MinGateDist",
                                      np.mean(mingate_dists, axis=0))
                logger.record_tabular("MaxGateDist",
                                      np.mean(maxgate_dists, axis=0))
                print(paths[0]["info"]["gate_dist"])

            learn_time = (time.time() - learn_start) / 60.0

            self.recent_total_reward += mean_reward

            logger.log("Total time: %.2f mins" %
                       ((time.time() - self.start_time) / 60.0))
            logger.log("Current steps is " + str(self.timesteps_per_batch) +
                       " and KL is " + str(self.max_kl))

            self.totalsteps += self.timesteps_per_batch
            self.prev_mean_reward = mean_reward
            self.prev_std_reward = std_reward
            self.iteration += 1
            logger.log("%d total steps have happened" % self.totalsteps)

            self.rollouts.set_policy_weights(
                new_policy_weights)  # Update weights for each process
            self.policy_weights = new_policy_weights

            logger.dump_tabular(with_prefix=False)
        return self.iteration, paths
Example #13
File: model.py Project: Theling/OptionGAN
    def learn(self, paths):

        if self.schedule == 'constant':
            cur_lrmult = 1.0
            print("Using constant schedule")
        elif self.schedule == 'linear':
            cur_lrmult = max(
                1.0 - float(self.timesteps_so_far) / self.max_timesteps, 0.01)
            print("Using linear schedule")
        elif self.schedule == 'quadratic':
            cur_lrmult = 1.0
            if self.timesteps_so_far > 0 and self.timesteps_so_far % 50e4 == 0:
                cur_lrmult *= .5
        else:
            raise NotImplementedError

        for path in paths:
            b = path["baseline"] = self.vf.predict(path, self.session)
            b1 = np.append(b, 0 if path["terminated"] else b[-1])
            deltas = path["rewards"] + self.args.gamma * b1[1:] - b1[:-1]
            path["advantage"] = discount(deltas,
                                         self.args.gamma * self.args.lam)
            path["returns"] = discount(path["rewards"], self.args.gamma)

        alladv = np.concatenate([path["advantage"] for path in paths])
        # Standardize advantage
        std = alladv.std()
        mean = alladv.mean()
        for path in paths:
            path["advantage"] = (path["advantage"] - mean) / (std + 1e-8)
        advant_n = np.concatenate([path["advantage"] for path in paths])

        # puts all the experiences in a matrix: total_timesteps x options
        # TODO: make this policy dependent like in rllab
        paths_concated = concat_tensor_dict_list(paths)

        action_dist_mu = paths_concated["info"][
            "action_dist_mu"]  #np.concatenate([path["info"]["action_dist_mu"] for path in paths])
        action_dist_logstd = paths_concated["info"][
            "action_dist_logstd"]  #np.concatenate([path["info"]["action_dist_logstd"] for path in paths])
        obs_n = paths_concated[
            "observations"]  #np.concatenate([path["observations"] for path in paths])
        action_n = paths_concated[
            "actions"]  #np.concatenate([path["actions"] for path in paths])

        # TODO: make this policy dependent like in rllab
        feed_dict = {
            self.obs: obs_n,
            self.action: action_n,
            self.advantage: advant_n,
            self.oldaction_dist_mu: action_dist_mu,
            self.oldaction_dist_logstd: action_dist_logstd
        }

        feed_dict.update(
            self.policy.get_extra_inputs(self.session, obs_n,
                                         paths_concated["info"]))
        feed_dict.update({self.learning_rate_multiplier: cur_lrmult})
        # if isinstance(self.vf, MLPConstrainedValueFunction):
        #     feed_dict.update(self.vf.get_feed_vals(paths, self.session))

        if not isinstance(self.vf, MLPConstrainedValueFunction):
            self.vf.fit(paths, self.session)

        # if hasattr(self.policy, "ob_rms") and not isinstance(self.policy, GatedGaussianMLPPolicy):
        #     # In the case of a GatedGaussian policy, we're going to share the gate/filter provided to us
        #     self.policy.ob_rms.update(obs_n, self.session)

        losses = []
        for _ in range(self.optim_epochs):
            # losses = [] # list of tuples, each of which gives the loss for a minibatch
            # for i in range(self.optim_epochs):
            # TODO: batchify
            _, newlosses = self.session.run([self.train_op, self.losses],
                                            feed_dict)
            losses.append(newlosses)

        surrogate_after, kl_after, entropy_after = self.session.run(
            self.losses, feed_dict)

        episoderewards = np.array(
            [path["raw_rewards"].sum() for path in paths])
        realepisoderewards = np.array(
            [path["rewards"].sum() for path in paths])
        stats = {}

        if "true_rewards" in paths[0]:
            truerewards = np.array(
                [path["true_rewards"].sum() for path in paths])
            stats["TrueAverageReturn"] = truerewards.mean()
            stats["TrueStdReturn"] = truerewards.std()
            stats["TrueMaxReturn"] = truerewards.max()
            stats["TrueMinReturn"] = truerewards.min()

        logger.log("Min return agent_id: %d" %
                   min(paths, key=lambda x: x["rewards"].sum())["agentid"])

        stats["ProcessedAverageReturn"] = realepisoderewards.mean()
        stats["ProcessedStdReturn"] = realepisoderewards.std()
        stats["ProcessedMaxReturn"] = realepisoderewards.max()
        stats["ProcessedMinReturn"] = realepisoderewards.min()
        baseline_paths = np.array([path['baseline'].sum() for path in paths])
        advantage_paths = np.array([path['advantage'].sum() for path in paths])
        stats["BaselineAverage"] = baseline_paths.mean()
        stats["AdvantageAverage"] = advantage_paths.mean()
        stats["Episodes"] = len(paths)
        stats["EpisodeAveLength"] = np.mean(
            [len(path["rewards"]) for path in paths])
        stats["EpisodeStdLength"] = np.std(
            [len(path["rewards"]) for path in paths])
        stats["EpisodeMinLength"] = np.min(
            [len(path["rewards"]) for path in paths])
        stats["EpisodeMaxLength"] = np.max(
            [len(path["rewards"]) for path in paths])
        stats["RawAverageReturn"] = episoderewards.mean()
        stats["RawStdReturn"] = episoderewards.std()
        stats["RawMaxReturn"] = episoderewards.max()
        stats["RawMinReturn"] = episoderewards.min()
        stats["Entropy"] = entropy_after
        stats["MaxKL"] = self.args.max_kl
        stats["Timesteps"] = sum([len(path["raw_rewards"]) for path in paths])
        # stats["Time elapsed"] = "%.2f mins" % ((time.time() - start_time) / 60.0)
        stats["KLDifference"] = kl_after
        stats["SurrogateLoss"] = surrogate_after
        # print ("\n********** Iteration {} ************".format(i))
        for k, v in sorted(stats.items()):
            logger.record_tabular(k, v)
        logger.dump_tabular()

        return stats
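
Both learn() variants rely on a discount() helper that is not shown; it is the usual discounted cumulative sum used for returns and for the GAE(lambda) advantages (the deltas discounted by gamma * lam). A sketch of the assumed behavior:

import numpy as np

def discount_sketch(x, gamma):
    # Assumed behavior of discount(): y[t] = sum_k gamma**k * x[t + k].
    y = np.zeros(len(x), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + gamma * running
        y[t] = running
    return y

# e.g. discount_sketch(deltas, gamma * lam) would give the GAE advantage estimates.
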
Example #14
    def __init__(self,
                 observation_size,
                 hidden_sizes=(400, 300),
                 activation=tf.nn.tanh,
                 learning_rate=1e-4,
                 scope="discriminator",
                 normalize_obs=False,
                 ent_reg_weight=0.0,
                 gradient_penalty_weight=0.0,
                 l2_penalty_weight=0.001,
                 objective="regular",
                 use_rms_filter=False,
                 num_epochs_per_step=3):
        self.observation_size = observation_size
        # self.target_size = target_size
        self.normalize_obs = normalize_obs
        self.obs = tf.placeholder(tf.float32, [None, self.observation_size])
        self.targets = tf.placeholder(tf.float32, [None, 1])
        self.ent_reg_weight = ent_reg_weight
        self.num_epochs_per_step = num_epochs_per_step

        config = tf.ConfigProto(device_count={'GPU': 0})

        self.session = tf.Session(config=config)

        with tf.variable_scope(scope):
            if use_rms_filter:
                with tf.variable_scope("obfilter"):
                    self.ob_rms = RunningMeanStd(shape=(observation_size, ))

                net_input = tf.clip_by_value(
                    (self.obs - self.ob_rms.mean) / self.ob_rms.std, -5.0, 5.0)
            else:
                net_input = self.obs

            net = net_input
            for i, x in enumerate(hidden_sizes):
                net = tf.layers.dense(
                    inputs=net,
                    units=x,
                    activation=activation,
                    kernel_initializer=tf.random_uniform_initializer(
                        -0.05, 0.05),
                    name="discriminator_h%d" % i)
            net = tf.layers.dense(
                inputs=net,
                units=1,
                activation=None,
                kernel_initializer=tf.random_uniform_initializer(-0.05, 0.05),
                name="discriminator_outlayer")

        self.net_out = net
        # TODO: maybe make this non-deterministic? MC dropout and then use a normal distribution?
        # action_dist_logstd_param = tf.Variable((.01*np.random.randn(1, self.action_size)).astype(np.float32), name="policy_logstd")

        # loss function
        self.learning_rate = learning_rate

        var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope)

        # Possible clipping from WGAN
        clip_ops = []
        for var in var_list:
            clip_bounds = [-.01, .01]
            clip_ops.append(
                tf.assign(
                    var, tf.clip_by_value(var, clip_bounds[0],
                                          clip_bounds[1])))
        self.clip_disc_weights_op = tf.group(*clip_ops)

        self.pred = tf.sigmoid(
            self.net_out)  #-tf.log(1.-tf.sigmoid(self.net_out))
        self.reward = tf.sigmoid(
            self.net_out)  #-tf.log(1.-tf.sigmoid(self.net_out))

        num_experts = tf.cast(tf.count_nonzero(self.targets), tf.int32)
        batch_size = tf.shape(self.obs)[0]

        #weights_B = tf.zeros(tf.shape(self.targets), tf.float32)
        #weights_bexp = weights_B[-num_experts:] + 1.0/(tf.cast(num_experts, tf.float32))
        #weights_bnovice = weights_B[:-num_experts] + 1.0/(tf.cast(batch_size - num_experts, tf.float32))
        #weights_B = tf.concat([weights_bnovice, weights_bexp], axis=0)

        if objective != "wgan":
            logger.log("Using sigmoid cross entropy discriminator objective")
            ent_B = logit_bernoulli_entropy(self.net_out)
            cross_entropy = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.net_out[:-num_experts],
                    labels=self.targets[:-num_experts]) -
                self.ent_reg_weight * ent_B[:-num_experts],
                axis=0)
            cross_entropy += tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.net_out[-num_experts:],
                    labels=self.targets[-num_experts:]) -
                self.ent_reg_weight * ent_B[-num_experts:],
                axis=0)
            cross_entropy /= 2.

            #cross_entropy = tf.reduce_sum((cross_entropy - self.ent_reg_weight*ent_B)*weights_B, axis=0)
            self.loss = cross_entropy * 2.0  # reweighting to make focus on this instead of other penalty terms
        else:
            logger.log("Using wgan objective")
            disc_fake = self.net_out[:-num_experts]
            disc_real = self.net_out[-num_experts:]
            self.loss = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

        self.l2_loss = tf.constant(0.0)

        if l2_penalty_weight > 0.0:
            loss_l2 = tf.add_n([
                tf.nn.l2_loss(v) for v in var_list
                if 'kernel' in v.name and not ('Adam' in v.name)
            ]) / float(len(var_list)) * l2_penalty_weight
            self.l2_loss = loss_l2
            self.loss += loss_l2

        if gradient_penalty_weight > 0.0:
            batch_size = tf.shape(self.obs)[0]
            smallest = tf.minimum(num_experts, batch_size - num_experts)

            alpha = tf.random_uniform(shape=[smallest, 1],
                                      minval=0.,
                                      maxval=1.)

            alpha_in = alpha * self.obs[-smallest:]
            beta_in = ((1 - alpha) * self.obs[:smallest])
            interpolates = alpha_in + beta_in
            net2 = interpolates
            with tf.variable_scope(scope, reuse=True):
                for i, x in enumerate(hidden_sizes):
                    net2 = tf.layers.dense(
                        inputs=net2,
                        units=x,
                        activation=tf.tanh,
                        kernel_initializer=tf.random_uniform_initializer(
                            -0.05, 0.05),
                        name="discriminator_h%d" % i)
                net2 = tf.layers.dense(
                    inputs=net2,
                    units=1,
                    activation=None,
                    kernel_initializer=tf.random_uniform_initializer(
                        -0.05, 0.05),
                    name="discriminator_outlayer")

            gradients = tf.gradients(net2, [interpolates])[0]
            gradients = tf.clip_by_value(gradients, -10., 10.)
            slopes = tf.sqrt(
                tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
            gradient_penalty = gradient_penalty_weight * tf.reduce_mean(
                (slopes - 1)**2)
            self.loss += gradient_penalty

        self.train_op = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(self.loss)
        comparison = tf.less(self.pred, tf.constant(0.5))
        comparison2 = tf.less(self.targets, tf.constant(0.5))
        overall = tf.cast(tf.equal(comparison, comparison2), tf.float32)
        accuracy = tf.reduce_mean(overall)  #, tf.ones_like(self.targets)))
        accuracy_for_currpolicy = tf.reduce_mean(
            overall[:-num_experts])  #, tf.ones_like(self.targets)))
        accuracy_for_expert = tf.reduce_mean(
            overall[-num_experts:])  #, tf.ones_like(self.targets)))
        self.accuracy = accuracy
        self.accuracy_for_currpolicy = accuracy_for_currpolicy
        self.accuracy_for_expert = accuracy_for_expert

        # aux values
        # label_accuracy = tf.equal(tf.round(self.pred), tf.round(self.targets))
        # self.label_accuracy = tf.reduce_mean(tf.cast(label_accuracy, tf.float32))
        # self.mse = tf.reduce_mean(tf.nn.l2_loss(self.pred - self.targets))
        # ones = tf.ones_like(self.targets)
        #
        # true_positives = tf.round(self.pred) * tf.round(self.targets)
        # predicted_positives = tf.round(self.pred)
        #
        # false_negatives = tf.logical_not(tf.logical_xor(tf.equal(tf.round(self.pred), ones), tf.equal(tf.round(self.targets), ones)))
        #
        # self.label_precision = tf.reduce_sum(tf.cast(true_positives, tf.float32)) / tf.reduce_sum(tf.cast(predicted_positives, tf.float32))
        # self.label_recall = tf.reduce_sum(tf.cast(true_positives, tf.float32)) / (tf.reduce_sum(tf.cast(true_positives, tf.float32)) + tf.reduce_sum(tf.cast(false_negatives, tf.float32)))

        initialize_uninitialized(self.session)
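
logit_bernoulli_entropy is not defined in these examples; in GAIL-style discriminators it is usually the entropy of Bernoulli(sigmoid(logits)), computed directly from the logits for numerical stability. A sketch under that assumption (TF1-style, matching the rest of the code):

import tensorflow as tf

def logit_bernoulli_entropy_sketch(logits):
    # Assumed behavior: entropy of Bernoulli(p) with p = sigmoid(z),
    # rewritten in terms of the logits z as (1 - p) * z + softplus(-z).
    return (1.0 - tf.sigmoid(logits)) * logits + tf.nn.softplus(-logits)
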
Example #15
    def step(self, observations, labels, aux_logging=True):
        if self.normalize_obs:
            self.inputnorm.update(observations)

        ops = [self.train_op,
               self.loss,
               self.accuracy,
               self.accuracy_for_expert,
               self.accuracy_for_currpolicy,
               self.mi,
               self.cv,
               self.lambda_s_loss,
               self.lambda_v_loss,
               self.termination_function,
               self.l2_loss,
               self.cross_entropy_loss,
               self.gate_change,
               self.confidence]

        old_gate_vals = self.session.run(self.termination_function, feed_dict={self.obs : observations, self.targets : labels})

        for i in range(self.num_epochs_per_step):
            op_returns = self.session.run(ops, feed_dict={self.obs : observations, self.targets : labels, self.old_gating_output : old_gate_vals})

        logger.log("Loss: %f" % op_returns[1])
        logger.log("Accuracy: %f" % op_returns[2])
        logger.log("Accuracy (policy): %f" % op_returns[4])
        logger.log("Accuracy (expert): %f" % op_returns[3])
        logger.log("MI: %f" % op_returns[5])
        logger.log("cv: %f" % op_returns[6])
        logger.log("l2_loss: %f" % op_returns[10])
        logger.log("ce_loss: %f" % op_returns[11])
        logger.log("gate_change: %f" % op_returns[12])
        logger.log("lambda_s_loss: %f" % op_returns[7])
        logger.log("lambda_v_loss: %f" % op_returns[8])
        logger.log("Importance: {}".format(str(np.mean(np.array(op_returns[9]), axis=0))))
        logger.log("Gate Confidence: %f" % op_returns[13])
        print(op_returns[9])
Example #16
    def build_network(self, reuse=None, extra_options=0, stop_old_option_gradients=False):
        # build options
        scope = self.scope
        gate_hidden_sizes = self.gate_hidden_sizes
        hidden_sizes = self.hidden_sizes
        activation = self.activation
        num_options = self.num_options
        use_rms_filter = self.use_rms_filter
        use_gated_trust_region = self.use_gated_trust_region
        use_shared_layer = self.use_shared_layer
        l2_penalty_weight = self.l2_penalty_weight
        cv_penalty_weight = self.cv_penalty_weight
        mutual_info_penalty_weight = self.mutual_info_penalty_weight
        cross_entropy_reweighting = self.cross_entropy_reweighting
        gate_change_penalty = self.gate_change_penalty
        lambda_s = self.lambda_s
        lambda_v = self.lambda_v
        tau = self.tau
        self.options = []
        learning_rate = self.learning_rate
        with tf.variable_scope(scope):
            if use_rms_filter:
                if not hasattr(self, "ob_rms"):
                    with tf.variable_scope("obfilter"):
                        self.ob_rms = RunningMeanStd(shape=(self.observation_size,))

                net_input = tf.clip_by_value((self.obs - self.ob_rms.mean) / self.ob_rms.std, -5.0, 5.0)
            else:
                net_input = self.obs

            if use_shared_layer:
                shared_net = net_input
                for i, x in enumerate(hidden_sizes):
                    shared_net = tf.layers.dense(
                        inputs=shared_net,
                        units=x,
                        activation=activation,
                        kernel_initializer=tf.random_uniform_initializer(-0.05, 0.05),
                        name="discriminator_h%d" % i,
                        reuse=reuse)
                for o in range(num_options):
                    with tf.variable_scope("option%d" % o):
                        net = tf.layers.dense(
                            inputs=shared_net,
                            units=1,
                            activation=None,
                            kernel_initializer=tf.random_uniform_initializer(-0.05, 0.05),
                            name="discriminator_outlayer",
                            reuse=(reuse and o < num_options - extra_options))
                        if stop_old_option_gradients and o < num_options - extra_options:
                            net = tf.stop_gradient(net)
                        self.options.append(net)
            else:
                for o in range(num_options):
                    with tf.variable_scope("option%d" % o):
                        net = net_input
                        for i, x in enumerate(hidden_sizes):
                            net = tf.layers.dense(
                                inputs=net,
                                units=x,
                                activation=activation,
                                kernel_initializer=tf.random_uniform_initializer(-0.05, 0.05),
                                name="discriminator_h%d" % i,
                                reuse=(reuse and o < num_options - extra_options))
                        net = tf.layers.dense(
                            inputs=net,
                            units=1,
                            activation=None,
                            kernel_initializer=tf.random_uniform_initializer(-0.05, 0.05),
                            name="discriminator_outlayer",
                            reuse=(reuse and o < num_options - extra_options))
                        if stop_old_option_gradients and o < num_options - extra_options:
                            net = tf.stop_gradient(net)
                        self.options.append(net)

            # build gate
            with tf.variable_scope("gate"):
                gating_network = net_input
                for i, x in enumerate(gate_hidden_sizes):
                    gating_network = tf.layers.dense(
                        inputs=gating_network,
                        units=x,
                        activation=activation,
                        kernel_initializer=tf.random_uniform_initializer(-0.05, 0.05),
                        name="gating_hidden%d" % i,
                        reuse=reuse)
                gating_network = tf.layers.dense(
                    inputs=gating_network,
                    units=num_options,
                    activation=tf.nn.softmax,
                    kernel_initializer=tf.random_uniform_initializer(-0.05, 0.05),
                    name="gating_outlayer_%d" % extra_options,
                    reuse=False)
                self.termination_function = gating_network

            combined_options = tf.concat(self.options, axis=1)
            self.net_out = tf.reshape(tf.reduce_sum(combined_options * gating_network, axis=1), [-1, 1])

            self.termination_importance_values = tf.reduce_sum(self.termination_function, axis=0)

        var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)

        self.pred = tf.sigmoid(self.net_out) #-tf.log(1.-tf.sigmoid(self.net_out))
        self.reward = tf.sigmoid(self.net_out)#-tf.log(1.-tf.sigmoid(self.net_out))

        num_experts = tf.cast(tf.count_nonzero(self.targets), tf.int32)
        batch_size = tf.shape(self.obs)[0]

        #weights_B = tf.zeros(tf.shape(self.targets), tf.float32)
        #weights_bexp = weights_B[-num_experts:] + 1.0/(tf.cast(num_experts, tf.float32))
        #weights_bnovice = weights_B[:-num_experts] + 1.0/(tf.cast(batch_size - num_experts, tf.float32))
        #weights_B = tf.concat([weights_bnovice, weights_bexp], axis=0)
        self.cross_entropy_loss = tf.constant(0.0)
        self.old_gating_output = tf.placeholder(tf.float32, [None, num_options])

        logger.log("Using sigmoid cross entropy discriminator objective")
        option_losses_exp = []
        option_losses_nov = []

        # http://www.cs.utoronto.ca/~fidler/teaching/2015/slides/CSC411/18_mixture.pdf
        for option in self.options:
            ent_B = logit_bernoulli_entropy(self.net_out)
            cross_entropy_nov = tf.nn.sigmoid_cross_entropy_with_logits(logits=option[:-num_experts], labels=self.targets[:-num_experts]) - self.ent_reg_weight * ent_B[:-num_experts]
            cross_entropy_exp = tf.nn.sigmoid_cross_entropy_with_logits(logits=option[-num_experts:], labels=self.targets[-num_experts:]) - self.ent_reg_weight * ent_B[-num_experts:]
            option_losses_exp.append(cross_entropy_exp)
            option_losses_nov.append(cross_entropy_nov)

        combined_losses_exp = tf.concat(option_losses_exp, axis=1)
        combined_losses_nov = tf.concat(option_losses_nov, axis=1)
        cross_entropy_nov = tf.reshape(tf.reduce_mean(combined_losses_nov * gating_network[:-num_experts], axis=1), [-1, 1])
        cross_entropy_exp = tf.reshape(tf.reduce_mean(combined_losses_exp * gating_network[-num_experts:], axis=1), [-1, 1])

        self.confidence = tf.sqrt(tf.reduce_mean(tf.square(self.old_gating_output - tf.reduce_mean(self.old_gating_output, axis=0))))
        clip_param = 1.0 - self.confidence # if very confident predictions generally, should clip changes very small.
        if use_gated_trust_region:

            clipped_gating_network = self.old_gating_output + tf.clip_by_value(gating_network - self.old_gating_output, - clip_param, clip_param)
            clipped_cross_entropy_nov = tf.reshape(tf.reduce_mean(combined_losses_nov * clipped_gating_network[:-num_experts], axis=1), [-1, 1])
            clipped_cross_entropy_exp = tf.reshape(tf.reduce_mean(combined_losses_exp * clipped_gating_network[-num_experts:], axis=1), [-1, 1])
            cross_entropy_nov = clipped_cross_entropy_nov#tf.maximum(clipped_cross_entropy_nov, cross_entropy_nov)
            cross_entropy_exp = clipped_cross_entropy_exp#tf.maximum(clipped_cross_entropy_exp, cross_entropy_exp)

        self.cross_entropy_loss = (tf.reduce_mean(cross_entropy_exp, axis=0) + tf.reduce_mean(cross_entropy_nov, axis=0)) / 2.0
        self.cross_entropy_loss *= cross_entropy_reweighting # reweight cross-entropy loss so other penalties don't affect it so much
        self.loss = self.cross_entropy_loss

        self.l2_loss = tf.constant(0.0)

        if l2_penalty_weight > 0.0:
            self.l2_loss = tf.add_n([ tf.nn.l2_loss(v) for v in var_list if 'kernel' in v.name and not ('Adam' in v.name)]) / float(len(var_list)) * l2_penalty_weight
            self.loss += self.l2_loss

        self.mi = tf.constant(0.0)

        if mutual_info_penalty_weight > 0.0:
            #TODO: fix mutual info so that each option is bounded here? maybe doesn't matter
            self.mi = get_mutual_info_penalty(self.options) * mutual_info_penalty_weight
            self.loss += self.mi

        self.cv = tf.constant(0.0)

        if cv_penalty_weight > 0.0:
            self.cv = get_cv_penalty(self.termination_importance_values) * cv_penalty_weight
            self.loss += self.cv

        self.gate_change = tf.constant(0.0)
        if gate_change_penalty > 0.0:
            self.gate_dist = Categorical(num_options)
            self.gate_change = gate_change_penalty * tf.reduce_mean(self.gate_dist.kl_sym(self.old_gating_output, gating_network))
            self.loss += self.gate_change


        # https://arxiv.org/abs/1511.06297
        # These two terms in ensemble encourage diversity and sparsity, while load balancing.
        # it's pretty amazing/awesome actually

        self.lambda_s_loss = tf.constant(0.0)

        if lambda_s > 0.0:
            gate = self.termination_function
            self.lambda_s_loss = lambda_s * (self.uniform_distribution_rescale * tf.reduce_mean((tf.reduce_mean(gate, axis=0) - tau)**2) +
                                    tf.reduce_mean((tf.reduce_mean(gate, axis=1) - tau)**2))
            self.loss += self.lambda_s_loss

        self.lambda_v_loss = tf.constant(0.0)

        if lambda_v > 0.0:
            gate = self.termination_function
            if use_gated_trust_region:
                gate = self.old_gating_output + tf.clip_by_value(self.termination_function - self.old_gating_output, - clip_param, clip_param)
            mean0, var0 = tf.nn.moments(gate, axes=[0])
            mean, var1 = tf.nn.moments(gate, axes=[1])
            self.lambda_v_loss = - lambda_v * (tf.reduce_mean(var0) + tf.reduce_mean(var1))
            self.loss += self.lambda_v_loss

        self.train_op = self.optimizer.minimize(self.loss)
        comparison = tf.less(self.pred, tf.constant(0.5) )
        comparison2 = tf.less(self.targets, tf.constant(0.5) )
        overall = tf.cast(tf.equal(comparison, comparison2), tf.float32)
        accuracy = tf.reduce_mean(overall)#, tf.ones_like(self.targets)))
        accuracy_for_currpolicy = tf.reduce_mean(overall[:-num_experts])#, tf.ones_like(self.targets)))
        accuracy_for_expert = tf.reduce_mean(overall[-num_experts:])#, tf.ones_like(self.targets)))
        self.accuracy = accuracy
        self.accuracy_for_currpolicy = accuracy_for_currpolicy
        self.accuracy_for_expert = accuracy_for_expert

        initialize_uninitialized(self.session)
        return gating_network, net
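
To make the mixture structure above concrete: each option head emits one logit per observation and the softmax gate mixes them into a single discriminator logit, which is what self.net_out reshapes. A toy NumPy illustration (the shapes are made up):

import numpy as np

batch, num_options = 5, 3
option_logits = np.random.randn(batch, num_options)            # like combined_options
gate = np.random.dirichlet(np.ones(num_options), size=batch)   # like the softmax gating_network
net_out = np.sum(option_logits * gate, axis=1).reshape(-1, 1)  # mirrors self.net_out
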