Example #1
def learn(
        env,
        make_policy,
        *,
        n_episodes,
        horizon,
        delta,
        gamma,
        max_iters,
        sampler=None,
        use_natural_gradient=False,  # can also be 'exact' or 'approximate'
        fisher_reg=1e-2,
        iw_method='is',
        iw_norm='none',
        bound='J',
        line_search_type='parabola',
        save_weights=0,
        improvement_tol=0.,
        center_return=False,
        render_after=None,
        max_offline_iters=100,
        callback=None,
        clipping=False,
        entropy='none',
        positive_return=False,
        reward_clustering='none',
        capacity=10,
        warm_start=True):

    np.set_printoptions(precision=3)
    max_samples = horizon * n_episodes

    if line_search_type == 'binary':
        line_search = line_search_binary
    elif line_search_type == 'parabola':
        line_search = line_search_parabola
    else:
        raise ValueError('Unknown line_search_type: %s' % line_search_type)

    # Reading the environment's observation and action spaces
    ob_space = env.observation_space
    ac_space = env.action_space

    # Creating the memory buffer
    memory = Memory(capacity=capacity,
                    batch_size=n_episodes,
                    horizon=horizon,
                    ob_space=ob_space,
                    ac_space=ac_space)

    # Building the target policy and saving its parameters
    pi = make_policy('pi', ob_space, ac_space)
    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list if v.name.split('/')[1].startswith('pol')
    ]
    shapes = [U.intprod(var.get_shape().as_list()) for var in var_list]
    n_parameters = sum(shapes)

    # Building a set of behavioral policies
    behavioral_policies = memory.build_policies(make_policy, pi)

    # Placeholders
    ob_ = ob = U.get_placeholder_cached(name='ob')
    ac_ = pi.pdtype.sample_placeholder([None], name='ac')
    mask_ = tf.placeholder(dtype=tf.float32, shape=(None,), name='mask')
    rew_ = tf.placeholder(dtype=tf.float32, shape=(None,), name='rew')
    disc_rew_ = tf.placeholder(dtype=tf.float32, shape=(None,), name='disc_rew')
    clustered_rew_ = tf.placeholder(dtype=tf.float32, shape=(None,))
    gradient_ = tf.placeholder(dtype=tf.float32,
                               shape=(n_parameters, 1),
                               name='gradient')
    iter_number_ = tf.placeholder(dtype=tf.int32, name='iter_number')
    active_policies = tf.placeholder(dtype=tf.float32,
                                     shape=(capacity,),
                                     name='active_policies')
    losses_with_name = []

    # Total number of trajectories
    N_total = tf.reduce_sum(active_policies) * n_episodes

    # Split operations
    disc_rew_split = tf.reshape(disc_rew_ * mask_, [-1, horizon])
    rew_split = tf.reshape(rew_ * mask_, [-1, horizon])
    mask_split = tf.reshape(mask_, [-1, horizon])

    # Policy densities
    target_log_pdf = pi.pd.logp(ac_) * mask_
    target_log_pdf_split = tf.reshape(target_log_pdf, [-1, horizon])
    behavioral_log_pdfs = tf.stack([
        bpi.pd.logp(ac_) * mask_ for bpi in memory.policies
    ])  # Shape is (capacity, ntraj*horizon)
    behavioral_log_pdfs_split = tf.reshape(behavioral_log_pdfs,
                                           [memory.capacity, -1, horizon])

    # Compute Rényi divergences, sum over time, then exponentiate
    emp_d2_split = tf.reshape(
        tf.stack([pi.pd.renyi(bpi.pd, 2) * mask_ for bpi in memory.policies]),
        [memory.capacity, -1, horizon])
    emp_d2_split_cum = tf.exp(tf.reduce_sum(emp_d2_split, axis=2))
    # Compute arithmetic and harmonic mean of emp_d2
    emp_d2_mean = tf.reduce_mean(emp_d2_split_cum, axis=1)
    emp_d2_arithmetic = tf.reduce_sum(
        emp_d2_mean * active_policies) / tf.reduce_sum(active_policies)
    emp_d2_harmonic = tf.reduce_sum(active_policies) / tf.reduce_sum(
        1 / emp_d2_mean)
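    # Note: unlike the arithmetic mean above, the denominator here is not
    # masked by active_policies; inactive slots (whose cumulative d2 is 1)
    # also enter the harmonic mean, which matters only while the memory is
    # not yet full.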

    # Return processing: clipping, centering, discounting
    ep_return = clustered_rew_  #tf.reduce_sum(mask_split * disc_rew_split, axis=1)
    if clipping:
        rew_split = tf.clip_by_value(rew_split, -1, 1)
    if center_return:
        ep_return = ep_return - tf.reduce_mean(ep_return)
        rew_split = rew_split - (tf.reduce_sum(rew_split) /
                                 (tf.reduce_sum(mask_split) + 1e-24))
    discounter = [pow(gamma, i) for i in range(0, horizon)]  # Geometric discount factors gamma^i
    discounter_tf = tf.constant(discounter)
    disc_rew_split = rew_split * discounter_tf

    # Return statistics
    return_mean = tf.reduce_mean(ep_return)
    return_std = U.reduce_std(ep_return)
    return_max = tf.reduce_max(ep_return)
    return_min = tf.reduce_min(ep_return)
    return_abs_max = tf.reduce_max(tf.abs(ep_return))
    return_step_max = tf.reduce_max(tf.abs(rew_split))  # Max step reward
    return_step_mean = tf.abs(tf.reduce_mean(rew_split))
    positive_step_return_max = tf.maximum(0.0, tf.reduce_max(rew_split))
    negative_step_return_max = tf.maximum(0.0, tf.reduce_max(-rew_split))
    return_step_maxmin = tf.abs(positive_step_return_max -
                                negative_step_return_max)
    losses_with_name.extend([(return_mean, 'InitialReturnMean'),
                             (return_max, 'InitialReturnMax'),
                             (return_min, 'InitialReturnMin'),
                             (return_std, 'InitialReturnStd'),
                             (emp_d2_arithmetic, 'EmpiricalD2Arithmetic'),
                             (emp_d2_harmonic, 'EmpiricalD2Harmonic'),
                             (return_step_max, 'ReturnStepMax'),
                             (return_step_maxmin, 'ReturnStepMaxmin')])

    if iw_method == 'is':
        # Sum the log-probabilities over time. Shapes: target (Nep, H), behavioral (Cap, Nep, H)
        target_log_pdf_episode = tf.reduce_sum(target_log_pdf_split, axis=1)
        behavioral_log_pdf_episode = tf.reduce_sum(behavioral_log_pdfs_split,
                                                   axis=2)
        # To avoid numerical instability, work with the inverse ratio
        log_ratio = target_log_pdf_split - behavioral_log_pdfs_split
        inverse_log_ratio_episode = -tf.reduce_sum(log_ratio, axis=2)

        iw = 1 / tf.reduce_sum(tf.exp(inverse_log_ratio_episode) *
                               tf.expand_dims(active_policies, -1),
                               axis=0)
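        # Algebraically, iw = p(tau) / sum_k active_k * q_k(tau): the
        # balance-heuristic (multiple importance sampling) weight, obtained by
        # inverting the sum of the inverse ratios exp(log q_k - log p) rather
        # than exponentiating the potentially large ratios p/q_k directly.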

        # Per-behavioral-policy statistics of the balance-heuristic weights
        iw_split = tf.reshape(iw, (memory.capacity, -1))
        iw_by_behavioral = tf.reduce_mean(iw_split, axis=1)
        losses_with_name.append(
            (iw_by_behavioral[0] / tf.reduce_sum(iw_by_behavioral),
             'MultiIWFirstRatio'))
        losses_with_name.append(
            (tf.reduce_max(iw_by_behavioral), 'MultiIWMax'))
        losses_with_name.append(
            (tf.reduce_sum(iw_by_behavioral), 'MultiIWSum'))
        losses_with_name.append(
            (tf.reduce_min(iw_by_behavioral), 'MultiIWMin'))

        # Get the probability by exponentiation
        #target_pdf_episode = tf.exp(target_log_pdf_episode)
        #behavioral_pdf_episode = tf.exp(behavioral_log_pdf_episode)
        # Get the denominator by averaging over behavioral policies
        #behavioral_pdf_mixture = tf.reduce_mean(behavioral_pdf_episode, axis=0) + 1e-24
        #iw = target_pdf_episode / behavioral_pdf_mixture
        iwn = iw / n_episodes

        # Compute the importance-weighted return estimate J
        w_return_mean = tf.reduce_sum(ep_return * iwn)
        # Empirical D2 of the mixture and relative ESS
        ess_renyi_arithmetic = N_total / emp_d2_arithmetic
        ess_renyi_harmonic = N_total / emp_d2_harmonic
        # Log quantities
        losses_with_name.extend([
            (tf.reduce_max(iw), 'MaxIW'), (tf.reduce_min(iw), 'MinIW'),
            (tf.reduce_mean(iw), 'MeanIW'), (U.reduce_std(iw), 'StdIW'),
            (tf.reduce_min(target_log_pdf_episode), 'MinTargetPdf'),
            (tf.reduce_min(behavioral_log_pdf_episode), 'MinBehavPdf'),
            (ess_renyi_arithmetic, 'ESSRenyiArithmetic'),
            (ess_renyi_harmonic, 'ESSRenyiHarmonic')
        ])
    else:
        raise NotImplementedError()

    if bound == 'J':
        bound_ = w_return_mean
    elif bound == 'max-d2-harmonic':
        bound_ = w_return_mean - tf.sqrt(
            (1 - delta) / (delta * ess_renyi_harmonic)) * return_abs_max
    elif bound == 'max-d2-arithmetic':
        bound_ = w_return_mean - tf.sqrt(
            (1 - delta) / (delta * ess_renyi_arithmetic)) * return_abs_max
    else:
        raise NotImplementedError()

    # Policy entropy for exploration
    ent = pi.pd.entropy()
    meanent = tf.reduce_mean(ent)
    losses_with_name.append((meanent, 'MeanEntropy'))
    # Add policy entropy bonus
    if entropy != 'none':
        scheme, v1, v2 = entropy.split(':')
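        # e.g. entropy='step:0.01:50' keeps coefficient 0.01 for the first
        # 50 iterations, then drops it to zero.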
        if scheme == 'step':
            entcoeff = tf.cond(iter_number_ < int(v2), lambda: float(v1),
                               lambda: float(0.0))
            losses_with_name.append((entcoeff, 'EntropyCoefficient'))
            entbonus = entcoeff * meanent
            bound_ = bound_ + entbonus
        elif scheme == 'lin':
            ip = tf.cast(iter_number_ / max_iters, tf.float32)
            entcoeff_decay = tf.maximum(
                0.0,
                float(v2) + (float(v1) - float(v2)) * (1.0 - ip))
            losses_with_name.append((entcoeff_decay, 'EntropyCoefficient'))
            entbonus = entcoeff_decay * meanent
            bound_ = bound_ + entbonus
        elif scheme == 'exp':
            ent_f = tf.exp(
                -tf.abs(tf.reduce_mean(iw) - 1) * float(v2)) * float(v1)
            losses_with_name.append((ent_f, 'EntropyCoefficient'))
            bound_ = bound_ + ent_f * meanent
        else:
            raise ValueError('Unrecognized entropy scheme.')

    losses_with_name.append((w_return_mean, 'ReturnMeanIW'))
    losses_with_name.append((bound_, 'Bound'))
    losses, loss_names = map(list, zip(*losses_with_name))
    '''
    if use_natural_gradient:
        p = tf.placeholder(dtype=tf.float32, shape=[None])
        target_logpdf_episode = tf.reduce_sum(target_log_pdf_split * mask_split, axis=1)
        grad_logprob = U.flatgrad(tf.stop_gradient(iwn) * target_logpdf_episode, var_list)
        dot_product = tf.reduce_sum(grad_logprob * p)
        hess_logprob = U.flatgrad(dot_product, var_list)
        compute_linear_operator = U.function([p, ob_, ac_, disc_rew_, mask_], [-hess_logprob])
    '''
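    # Note: with the natural-gradient block above left commented out,
    # compute_linear_operator is never defined, so calling learn(...) with
    # use_natural_gradient=True would raise a NameError further below.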

    assert_ops = tf.group(*tf.get_collection('asserts'))
    print_ops = tf.group(*tf.get_collection('prints'))

    compute_lossandgrad = U.function([
        ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_,
        active_policies
    ], losses + [U.flatgrad(bound_, var_list), assert_ops, print_ops])
    compute_grad = U.function([
        ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_,
        active_policies
    ], [U.flatgrad(bound_, var_list), assert_ops, print_ops])
    compute_bound = U.function([
        ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_,
        active_policies
    ], [bound_, assert_ops, print_ops])
    compute_losses = U.function([
        ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_,
        active_policies
    ], losses)
    #compute_temp = U.function([ob_, ac_, rew_, disc_rew_, clustered_rew_, mask_, iter_number_, active_policies], [log_inverse_ratio, abc, iw])

    set_parameter = U.SetFromFlat(var_list)
    get_parameter = U.GetFlat(var_list)
    policy_reinit = tf.variables_initializer(var_list)

    if sampler is None:
        seg_gen = traj_segment_generator(pi,
                                         env,
                                         n_episodes,
                                         horizon,
                                         stochastic=True,
                                         gamma=gamma)
        sampler = type("SequentialSampler", (object, ), {
            "collect": lambda self, _: seg_gen.__next__()
        })()

    U.initialize()

    # Starting the optimization
    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=n_episodes)
    rewbuffer = deque(maxlen=n_episodes)

    while True:

        iters_so_far += 1

        if render_after is not None and iters_so_far % render_after == 0:
            if hasattr(env, 'render'):
                render(env, pi, horizon)

        if callback:
            callback(locals(), globals())

        if iters_so_far >= max_iters:
            print('Finished...')
            break

        logger.log('********** Iteration %i ************' % iters_so_far)

        theta = get_parameter()

        with timed('sampling'):
            seg = sampler.collect(theta)

        lens, rets = seg['ep_lens'], seg['ep_rets']
        lenbuffer.extend(lens)
        rewbuffer.extend(rets)
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)

        # Adding batch of trajectories to memory
        memory.add_trajectory_batch(seg)

        # Get multiple batches from memory
        seg_with_memory = memory.get_trajectories()

        # Get clustered reward
        reward_matrix = np.reshape(
            seg_with_memory['disc_rew'] * seg_with_memory['mask'],
            (-1, horizon))
        ep_reward = np.sum(reward_matrix, axis=1)
        ep_reward = cluster_rewards(ep_reward, reward_clustering)

        args = ob, ac, rew, disc_rew, clustered_rew, mask, iter_number, active_policies = (
            seg_with_memory['ob'], seg_with_memory['ac'],
            seg_with_memory['rew'], seg_with_memory['disc_rew'], ep_reward,
            seg_with_memory['mask'], iters_so_far,
            memory.get_active_policies_mask())

        def evaluate_loss():
            loss = compute_bound(*args)
            return loss[0]

        def evaluate_gradient():
            gradient = compute_grad(*args)
            return gradient[0]

        if use_natural_gradient:

            def evaluate_fisher_vector_prod(x):
                return compute_linear_operator(x, *args)[0] + fisher_reg * x

            def evaluate_natural_gradient(g):
                return cg(evaluate_fisher_vector_prod,
                          g,
                          cg_iters=10,
                          verbose=0)
        else:
            evaluate_natural_gradient = None

        with timed('summaries before'):
            logger.record_tabular("Iteration", iters_so_far)
            logger.record_tabular("InitialBound", evaluate_loss())
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)

        if save_weights > 0 and iters_so_far % save_weights == 0:
            logger.record_tabular('Weights', str(get_parameter()))
            import pickle
            with open('checkpoint' + str(iters_so_far) + '.pkl', 'wb') as file:
                pickle.dump(theta, file)

        if not warm_start or memory.get_current_load() == capacity:
            # Optimize
            with timed("offline optimization"):
                theta, improvement = optimize_offline(
                    theta,
                    set_parameter,
                    line_search,
                    evaluate_loss,
                    evaluate_gradient,
                    evaluate_natural_gradient,
                    max_offline_ite=max_offline_iters)

            set_parameter(theta)
            print(theta)

            with timed('summaries after'):
                meanlosses = np.array(compute_losses(*args))
                for (lossname, lossval) in zip(loss_names, meanlosses):
                    logger.record_tabular(lossname, lossval)
        else:
            # Reinitialize the policy
            tf.get_default_session().run(policy_reinit)

        logger.dump_tabular()

    env.close()
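A minimal NumPy sketch (not part of the original module) of the balance-heuristic weights and the max-d2 bound that Example #1 assembles in TensorFlow; every name below is an illustrative assumption:

import numpy as np

def balance_heuristic_weights(log_p_target, log_p_behavioral, active):
    # log_p_target: (N,) episode log-densities under the target policy.
    # log_p_behavioral: (K, N) episode log-densities under the K behavioral
    # policies held in memory; active: (K,) 0/1 mask of the live slots.
    # iw_i = p(tau_i) / sum_k active_k * q_k(tau_i), computed through the
    # inverse ratios exp(log q_k - log p), as in the graph above.
    inverse_ratios = np.exp(log_p_behavioral - log_p_target[None, :])  # (K, N)
    return 1.0 / np.sum(inverse_ratios * active[:, None], axis=0)

def max_d2_bound(returns, iw, n_episodes, n_total, emp_d2, delta):
    # MIS estimate of J: each behavioral policy contributed n_episodes
    # trajectories, so the count weighting folds into a single division
    # (mirroring `iwn = iw / n_episodes` above).
    w_return_mean = np.sum(returns * iw) / n_episodes
    # Penalty term shrinks as the Renyi-based ESS (N / d2) grows.
    ess_renyi = n_total / emp_d2
    return w_return_mean - np.sqrt(
        (1 - delta) / (delta * ess_renyi)) * np.max(np.abs(returns))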
Example #2
def learn(
        env,
        make_policy,
        *,
        n_episodes,
        horizon,
        delta,
        gamma,
        max_iters,
        sampler=None,
        use_natural_gradient=False,  # can also be 'exact' or 'approximate'
        fisher_reg=1e-2,
        iw_method='is',
        iw_norm='none',
        bound='J',
        line_search_type='parabola',
        save_weights=0,
        improvement_tol=0.,
        center_return=False,
        render_after=None,
        max_offline_iters=100,
        callback=None,
        clipping=False,
        entropy='none',
        positive_return=False,
        reward_clustering='none'):

    np.set_printoptions(precision=3)
    max_samples = horizon * n_episodes

    if line_search_type == 'binary':
        line_search = line_search_binary
    elif line_search_type == 'parabola':
        line_search = line_search_parabola
    else:
        raise ValueError('Unknown line_search_type: %s' % line_search_type)

    # Reading the environment's observation and action spaces
    ob_space = env.observation_space
    ac_space = env.action_space

    # Building the target policy and the old (behavioral) policy
    pi = make_policy('pi', ob_space, ac_space)
    oldpi = make_policy('oldpi', ob_space, ac_space)

    all_var_list = pi.get_trainable_variables()
    var_list = [
        v for v in all_var_list if v.name.split('/')[1].startswith('pol')
    ]

    shapes = [U.intprod(var.get_shape().as_list()) for var in var_list]
    n_parameters = sum(shapes)

    # Placeholders
    ob_ = ob = U.get_placeholder_cached(name='ob')
    ac_ = pi.pdtype.sample_placeholder([max_samples], name='ac')
    mask_ = tf.placeholder(dtype=tf.float32, shape=(max_samples,), name='mask')
    rew_ = tf.placeholder(dtype=tf.float32, shape=(max_samples,), name='rew')
    disc_rew_ = tf.placeholder(dtype=tf.float32,
                               shape=(max_samples,),
                               name='disc_rew')
    gradient_ = tf.placeholder(dtype=tf.float32,
                               shape=(n_parameters, 1),
                               name='gradient')
    iter_number_ = tf.placeholder(dtype=tf.int32, name='iter_number')
    losses_with_name = []

    # Policy densities
    target_log_pdf = pi.pd.logp(ac_)
    behavioral_log_pdf = oldpi.pd.logp(ac_)
    log_ratio = target_log_pdf - behavioral_log_pdf

    # Split operations
    disc_rew_split = tf.stack(tf.split(disc_rew_ * mask_, n_episodes))
    rew_split = tf.stack(tf.split(rew_ * mask_, n_episodes))
    target_log_pdf_split = tf.stack(
        tf.split(target_log_pdf * mask_, n_episodes))
    behavioral_log_pdf_split = tf.stack(
        tf.split(behavioral_log_pdf * mask_, n_episodes))
    log_ratio_split = target_log_pdf_split - behavioral_log_pdf_split
    mask_split = tf.stack(tf.split(mask_, n_episodes))

    # Rényi divergence
    emp_d2_split = tf.stack(
        tf.split(pi.pd.renyi(oldpi.pd, 2) * mask_, n_episodes))
    emp_d2_cum_split = tf.reduce_sum(emp_d2_split, axis=1)
    empirical_d2 = tf.reduce_mean(tf.exp(emp_d2_cum_split))

    # Return
    ep_return = tf.reduce_sum(mask_split * disc_rew_split, axis=1)
    if clipping:
        rew_split = tf.clip_by_value(rew_split, -1, 1)

    if center_return:
        ep_return = ep_return - tf.reduce_mean(ep_return)
        rew_split = rew_split - (tf.reduce_sum(rew_split) /
                                 (tf.reduce_sum(mask_split) + 1e-24))

    discounter = [pow(gamma, i) for i in range(0, horizon)]  # Geometric discount factors gamma^i
    discounter_tf = tf.constant(discounter)
    disc_rew_split = rew_split * discounter_tf

    return_mean = tf.reduce_mean(ep_return)
    return_std = U.reduce_std(ep_return)
    return_max = tf.reduce_max(ep_return)
    return_min = tf.reduce_min(ep_return)
    return_abs_max = tf.reduce_max(tf.abs(ep_return))
    return_step_max = tf.reduce_max(tf.abs(rew_split))  # Max step reward
    return_step_mean = tf.abs(tf.reduce_mean(rew_split))
    positive_step_return_max = tf.maximum(0.0, tf.reduce_max(rew_split))
    negative_step_return_max = tf.maximum(0.0, tf.reduce_max(-rew_split))
    return_step_maxmin = tf.abs(positive_step_return_max -
                                negative_step_return_max)

    losses_with_name.extend([(return_mean, 'InitialReturnMean'),
                             (return_max, 'InitialReturnMax'),
                             (return_min, 'InitialReturnMin'),
                             (return_std, 'InitialReturnStd'),
                             (empirical_d2, 'EmpiricalD2'),
                             (return_step_max, 'ReturnStepMax'),
                             (return_step_maxmin, 'ReturnStepMaxmin')])

    if iw_method == 'pdis':
        # Cumulative sum of the log ratios over time
        log_ratio_cumsum = tf.cumsum(log_ratio_split, axis=1)
        # Exponentiate
        ratio_cumsum = tf.exp(log_ratio_cumsum)
        # Multiply by the step-wise discounted reward (not the episode return)
        ratio_reward = ratio_cumsum * disc_rew_split
        # Average on episodes
        ratio_reward_per_episode = tf.reduce_sum(ratio_reward, axis=1)
        w_return_mean = tf.reduce_sum(ratio_reward_per_episode,
                                      axis=0) / n_episodes
        # Get d2(w0:t) with mask
        d2_w_0t = tf.exp(tf.cumsum(emp_d2_split,
                                   axis=1)) * mask_split  # LEAVE THIS OUTSIDE
        # Sum d2(w0:t) over timesteps
        episode_d2_0t = tf.reduce_sum(d2_w_0t, axis=1)
        # Sample variance
        J_sample_variance = (1 / (n_episodes - 1)) * tf.reduce_sum(
            tf.square(ratio_reward_per_episode - w_return_mean))
        losses_with_name.append((J_sample_variance, 'J_sample_variance'))
        losses_with_name.extend([(tf.reduce_max(ratio_cumsum), 'MaxIW'),
                                 (tf.reduce_min(ratio_cumsum), 'MinIW'),
                                 (tf.reduce_mean(ratio_cumsum), 'MeanIW'),
                                 (U.reduce_std(ratio_cumsum), 'StdIW')])
        losses_with_name.extend([(tf.reduce_max(d2_w_0t), 'MaxD2w0t'),
                                 (tf.reduce_min(d2_w_0t), 'MinD2w0t'),
                                 (tf.reduce_mean(d2_w_0t), 'MeanD2w0t'),
                                 (U.reduce_std(d2_w_0t), 'StdD2w0t')])

    elif iw_method == 'is':
        iw = tf.exp(tf.reduce_sum(log_ratio_split, axis=1))
        if iw_norm == 'none':
            iwn = iw / n_episodes
            w_return_mean = tf.reduce_sum(iwn * ep_return)
            J_sample_variance = (1 / (n_episodes - 1)) * tf.reduce_sum(
                tf.square(iw * ep_return - w_return_mean))
            losses_with_name.append((J_sample_variance, 'J_sample_variance'))
        elif iw_norm == 'sn':
            iwn = iw / tf.reduce_sum(iw)
            w_return_mean = tf.reduce_sum(iwn * ep_return)
        elif iw_norm == 'regression':
            iwn = iw / n_episodes
            mean_iw = tf.reduce_mean(iw)
            beta = tf.reduce_sum(
                (iw - mean_iw) * ep_return * iw) / (tf.reduce_sum(
                    (iw - mean_iw)**2) + 1e-24)
            w_return_mean = tf.reduce_mean(iw * ep_return - beta * (iw - 1))
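            # beta is the sample estimate of Cov(iw * R, iw) / Var(iw); since
            # E[iw] = 1 under the behavioral policy, (iw - 1) is a zero-mean
            # control variate, and subtracting beta * (iw - 1) lowers variance
            # while leaving the expectation (asymptotically) unchanged.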
        else:
            raise NotImplementedError()
        ess_classic = tf.linalg.norm(iw, 1)**2 / tf.linalg.norm(iw, 2)**2
        sqrt_ess_classic = tf.linalg.norm(iw, 1) / tf.linalg.norm(iw, 2)
        ess_renyi = n_episodes / empirical_d2
        losses_with_name.extend([(tf.reduce_max(iwn), 'MaxIWNorm'),
                                 (tf.reduce_min(iwn), 'MinIWNorm'),
                                 (tf.reduce_mean(iwn), 'MeanIWNorm'),
                                 (U.reduce_std(iwn), 'StdIWNorm'),
                                 (tf.reduce_max(iw), 'MaxIW'),
                                 (tf.reduce_min(iw), 'MinIW'),
                                 (tf.reduce_mean(iw), 'MeanIW'),
                                 (U.reduce_std(iw), 'StdIW'),
                                 (ess_classic, 'ESSClassic'),
                                 (ess_renyi, 'ESSRenyi')])
    elif iw_method == 'rbis':
        # Get pdfs for episodes
        target_log_pdf_episode = tf.reduce_sum(target_log_pdf_split, axis=1)
        behavioral_log_pdf_episode = tf.reduce_sum(behavioral_log_pdf_split,
                                                   axis=1)
        # Normalize the log-probabilities (to limit overflow as much as possible)
        normalization_factor = tf.reduce_mean(
            tf.stack([target_log_pdf_episode, behavioral_log_pdf_episode]))
        target_norm_log_pdf_episode = target_log_pdf_episode - normalization_factor
        behavioral_norm_log_pdf_episode = behavioral_log_pdf_episode - normalization_factor
        # Exponentiate
        target_pdf_episode = tf.clip_by_value(
            tf.cast(tf.exp(target_norm_log_pdf_episode), tf.float64), 1e-300,
            1e+300)
        behavioral_pdf_episode = tf.clip_by_value(
            tf.cast(tf.exp(behavioral_norm_log_pdf_episode), tf.float64),
            1e-300, 1e+300)
        tf.add_to_collection(
            'asserts',
            tf.assert_positive(target_pdf_episode, name='target_pdf_positive'))
        tf.add_to_collection(
            'asserts',
            tf.assert_positive(behavioral_pdf_episode,
                               name='behavioral_pdf_positive'))
        # Compute the merging matrix (reward-clustering) and the number of clusters
        reward_unique, reward_indexes = tf.unique(ep_return)
        episode_clustering_matrix = tf.cast(
            tf.one_hot(reward_indexes, n_episodes), tf.float64)
        max_index = tf.reduce_max(reward_indexes) + 1
        tf.add_to_collection(
            'asserts',
            tf.assert_positive(tf.reduce_sum(episode_clustering_matrix,
                                             axis=0)[:max_index],
                               name='clustering_matrix'))
        # Get the clustered pdfs
        clustered_target_pdf = tf.matmul(
            tf.reshape(target_pdf_episode, (1, -1)),
            episode_clustering_matrix)[0][:max_index]
        clustered_behavioral_pdf = tf.matmul(
            tf.reshape(behavioral_pdf_episode, (1, -1)),
            episode_clustering_matrix)[0][:max_index]
        tf.add_to_collection(
            'asserts',
            tf.assert_positive(clustered_target_pdf,
                               name='clust_target_pdf_positive'))
        tf.add_to_collection(
            'asserts',
            tf.assert_positive(clustered_behavioral_pdf,
                               name='clust_behavioral_pdf_positive'))
        # Compute the importance-weighted return estimate J
        ratio_clustered = clustered_target_pdf / clustered_behavioral_pdf
        ratio_reward = tf.cast(ratio_clustered, tf.float32) * reward_unique
        w_return_mean = tf.reduce_sum(ratio_reward) / tf.cast(
            max_index, tf.float32)
        # Effective sample sizes
        ess_classic = tf.linalg.norm(ratio_reward, 1)**2 / tf.linalg.norm(
            ratio_reward, 2)**2
        sqrt_ess_classic = tf.linalg.norm(ratio_reward, 1) / tf.linalg.norm(
            ratio_reward, 2)
        ess_renyi = n_episodes / empirical_d2
        # Summaries
        losses_with_name.extend([(tf.reduce_max(ratio_clustered), 'MaxIW'),
                                 (tf.reduce_min(ratio_clustered), 'MinIW'),
                                 (tf.reduce_mean(ratio_clustered), 'MeanIW'),
                                 (U.reduce_std(ratio_clustered), 'StdIW'),
                                 (1 - (max_index / n_episodes),
                                  'RewardCompression'),
                                 (ess_classic, 'ESSClassic'),
                                 (ess_renyi, 'ESSRenyi')])
    else:
        raise NotImplementedError()

    if bound == 'J':
        bound_ = w_return_mean
    elif bound == 'std-d2':
        bound_ = w_return_mean - tf.sqrt(
            (1 - delta) / (delta * ess_renyi)) * return_std
    elif bound == 'max-d2':
        bound_ = w_return_mean - tf.sqrt(
            (1 - delta) / (delta * ess_renyi)) * return_abs_max
    elif bound == 'max-ess':
        bound_ = w_return_mean - tf.sqrt(
            (1 - delta) / delta) / sqrt_ess_classic * return_abs_max
    elif bound == 'std-ess':
        bound_ = w_return_mean - tf.sqrt(
            (1 - delta) / delta) / sqrt_ess_classic * return_std
    elif bound == 'pdis-max-d2':
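        # Note: d2_w_0t below is defined only in the iw_method == 'pdis'
        # branch above, so the pdis-* bounds assume iw_method='pdis'.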
        # Discount factor
        if gamma >= 1:
            discounter = [
                float(1 + 2 * (horizon - t - 1)) for t in range(0, horizon)
            ]
        else:

            def f(t):
                return pow(gamma, t) * (pow(gamma, t) + pow(gamma, t + 1) -
                                        2 * pow(gamma, horizon)) / (1 - gamma)

            discounter = [f(t) for t in range(0, horizon)]
        discounter_tf = tf.constant(discounter)
        mean_episode_d2 = tf.reduce_sum(
            d2_w_0t, axis=0) / (tf.reduce_sum(mask_split, axis=0) + 1e-24)
        discounted_d2 = mean_episode_d2 * discounter_tf  # Discounted d2
        discounted_total_d2 = tf.reduce_sum(discounted_d2,
                                            axis=0)  # Sum over time
        bound_ = w_return_mean - tf.sqrt(
            (1 - delta) * discounted_total_d2 /
            (delta * n_episodes)) * return_step_max
    elif bound == 'pdis-mean-d2':
        # Discount factor
        if gamma >= 1:
            discounter = [
                float(1 + 2 * (horizon - t - 1)) for t in range(0, horizon)
            ]
        else:

            def f(t):
                return pow(gamma, t) * (pow(gamma, t) + pow(gamma, t + 1) -
                                        2 * pow(gamma, horizon)) / (1 - gamma)

            discounter = [f(t) for t in range(0, horizon)]
        discounter_tf = tf.constant(discounter)
        mean_episode_d2 = tf.reduce_sum(
            d2_w_0t, axis=0) / (tf.reduce_sum(mask_split, axis=0) + 1e-24)
        discounted_d2 = mean_episode_d2 * discounter_tf  # Discounted d2
        discounted_total_d2 = tf.reduce_sum(discounted_d2,
                                            axis=0)  # Sum over time
        bound_ = w_return_mean - tf.sqrt(
            (1 - delta) * discounted_total_d2 /
            (delta * n_episodes)) * return_step_mean
    else:
        raise NotImplementedError()

    # Policy entropy for exploration
    ent = pi.pd.entropy()
    meanent = tf.reduce_mean(ent)
    losses_with_name.append((meanent, 'MeanEntropy'))
    # Add policy entropy bonus
    if entropy != 'none':
        scheme, v1, v2 = entropy.split(':')
        if scheme == 'step':
            entcoeff = tf.cond(iter_number_ < int(v2), lambda: float(v1),
                               lambda: float(0.0))
            losses_with_name.append((entcoeff, 'EntropyCoefficient'))
            entbonus = entcoeff * meanent
            bound_ = bound_ + entbonus
        elif scheme == 'lin':
            ip = tf.cast(iter_number_ / max_iters, tf.float32)
            entcoeff_decay = tf.maximum(
                0.0,
                float(v2) + (float(v1) - float(v2)) * (1.0 - ip))
            losses_with_name.append((entcoeff_decay, 'EntropyCoefficient'))
            entbonus = entcoeff_decay * meanent
            bound_ = bound_ + entbonus
        elif scheme == 'exp':
            ent_f = tf.exp(
                -tf.abs(tf.reduce_mean(iw) - 1) * float(v2)) * float(v1)
            losses_with_name.append((ent_f, 'EntropyCoefficient'))
            bound_ = bound_ + ent_f * meanent
        else:
            raise ValueError('Unrecognized entropy scheme.')

    losses_with_name.append((w_return_mean, 'ReturnMeanIW'))
    losses_with_name.append((bound_, 'Bound'))
    losses, loss_names = map(list, zip(*losses_with_name))

    if use_natural_gradient:
        p = tf.placeholder(dtype=tf.float32, shape=[None])
        target_logpdf_episode = tf.reduce_sum(target_log_pdf_split *
                                              mask_split,
                                              axis=1)
        grad_logprob = U.flatgrad(
            tf.stop_gradient(iwn) * target_logpdf_episode, var_list)
        dot_product = tf.reduce_sum(grad_logprob * p)
        hess_logprob = U.flatgrad(dot_product, var_list)
        compute_linear_operator = U.function([p, ob_, ac_, disc_rew_, mask_],
                                             [-hess_logprob])

    assign_old_eq_new = U.function(
        [], [],
        updates=[
            tf.assign(oldv, newv)
            for (oldv,
                 newv) in zipsame(oldpi.get_variables(), pi.get_variables())
        ])

    assert_ops = tf.group(*tf.get_collection('asserts'))
    print_ops = tf.group(*tf.get_collection('prints'))

    compute_lossandgrad = U.function(
        [ob_, ac_, rew_, disc_rew_, mask_, iter_number_],
        losses + [U.flatgrad(bound_, var_list), assert_ops, print_ops])
    compute_grad = U.function(
        [ob_, ac_, rew_, disc_rew_, mask_, iter_number_],
        [U.flatgrad(bound_, var_list), assert_ops, print_ops])
    compute_bound = U.function(
        [ob_, ac_, rew_, disc_rew_, mask_, iter_number_],
        [bound_, assert_ops, print_ops])
    compute_losses = U.function(
        [ob_, ac_, rew_, disc_rew_, mask_, iter_number_], losses)
    #compute_temp = U.function([ob_, ac_, rew_, disc_rew_, mask_], [ratio_cumsum, discounted_ratio])

    set_parameter = U.SetFromFlat(var_list)
    get_parameter = U.GetFlat(var_list)

    if sampler is None:
        seg_gen = traj_segment_generator(pi,
                                         env,
                                         n_episodes,
                                         horizon,
                                         stochastic=True,
                                         gamma=gamma)
        sampler = type("SequentialSampler", (object, ), {
            "collect": lambda self, _: seg_gen.__next__()
        })()

    U.initialize()

    # Starting the optimization

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=n_episodes)
    rewbuffer = deque(maxlen=n_episodes)

    while True:

        iters_so_far += 1

        if render_after is not None and iters_so_far % render_after == 0:
            if hasattr(env, 'render'):
                render(env, pi, horizon)

        if callback:
            callback(locals(), globals())

        if iters_so_far >= max_iters:
            print('Finished...')
            break

        logger.log('********** Iteration %i ************' % iters_so_far)

        theta = get_parameter()

        with timed('sampling'):
            seg = sampler.collect(theta)

        lens, rets = seg['ep_lens'], seg['ep_rets']

        lenbuffer.extend(lens)
        rewbuffer.extend(rets)
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)

        args = ob, ac, rew, disc_rew, mask, iter_number = seg['ob'], seg[
            'ac'], seg['rew'], seg['disc_rew'], seg['mask'], iters_so_far

        assign_old_eq_new()

        def evaluate_loss():
            loss = compute_bound(*args)
            return loss[0]

        def evaluate_gradient():
            gradient = compute_grad(*args)
            return gradient[0]

        if use_natural_gradient:

            def evaluate_fisher_vector_prod(x):
                return compute_linear_operator(x, *args)[0] + fisher_reg * x

            def evaluate_natural_gradient(g):
                return cg(evaluate_fisher_vector_prod,
                          g,
                          cg_iters=10,
                          verbose=0)
        else:
            evaluate_natural_gradient = None

        with timed('summaries before'):
            logger.record_tabular("Itaration", iters_so_far)
            logger.record_tabular("InitialBound", evaluate_loss())
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)
            # Time records
            logger.record_tabular("TotalTime", seg['total_time'])
            logger.record_tabular("PolicyTime", seg['policy_time'])
            logger.record_tabular("EnvTime", seg['env_time'])
            logger.record_tabular("PolicyRatio",
                                  seg['policy_time'] / seg['total_time'])
            logger.record_tabular("EnvRatio",
                                  seg['env_time'] / seg['total_time'])

        if save_weights > 0 and iters_so_far % save_weights == 0:
            logger.record_tabular('Weights', str(get_parameter()))
            import pickle
            with open('checkpoint' + str(iters_so_far) + '.pkl', 'wb') as file:
                pickle.dump(theta, file)

        with timed("offline optimization"):

            theta, improvement = optimize_offline(
                theta,
                set_parameter,
                line_search,
                evaluate_loss,
                evaluate_gradient,
                evaluate_natural_gradient,
                max_offline_ite=max_offline_iters)

        set_parameter(theta)
        print(theta)

        with timed('summaries after'):
            meanlosses = np.array(compute_losses(*args))
            for (lossname, lossval) in zip(loss_names, meanlosses):
                logger.record_tabular(lossname, lossval)

        logger.dump_tabular()

    env.close()
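A minimal NumPy sketch (not part of the original module) of the per-decision importance sampling (PDIS) estimate built in the 'pdis' branch above; the array names are illustrative assumptions:

import numpy as np

def pdis_return_estimate(log_ratio, disc_rew, mask):
    # log_ratio, disc_rew, mask: arrays of shape (n_episodes, horizon),
    # with mask zeroing the padded timesteps of shorter episodes.
    # w_{0:t} = prod_{s<=t} pi(a_s|o_s) / oldpi(a_s|o_s), obtained via a
    # cumulative sum of log ratios (mirrors tf.cumsum + tf.exp above).
    ratio_cumsum = np.exp(np.cumsum(log_ratio * mask, axis=1))
    # Weight each discounted step reward by its own prefix ratio,
    # then average the per-episode sums.
    ratio_reward = ratio_cumsum * disc_rew * mask
    return np.sum(ratio_reward, axis=1).mean()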