Example #1
def sample_agent_data(N, args, env, model, obs_normalizer, customers,
                      customer_states):
    agent_states = []
    agent_actions = []

    closest_expert = N * [0]

    for i in range(N):
        # Initialize agent with data from ith expert
        initial_state = random.choice(customer_states[i])
        # Note: `and` binds tighter than `or`; group the condition explicitly
        if args['state_rep'] in (22, 221, 23) and i >= n_experts:
            # Find closest expert
            c = customers[i]
            distances = [wd(c, e) for e in experts]
            dummy = np.argmin(distances)  # index of the closest expert
            closest_expert[i] = dummy
            initial_state[dummy] = 1
        states, actions = pe.sample_from_policy(env,
                                                model,
                                                obs_normalizer,
                                                initial_state=initial_state)
        agent_states.append(states)
        agent_actions.append(actions)
    agent_states = np.array(agent_states)
    agent_actions = np.array(agent_actions)
    return agent_states, agent_actions, closest_expert
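
A minimal usage sketch for Example #1 (not part of the source): sample_agent_data reads the module-level names n_experts, experts, wd and pe, so a caller must define them first. The names and values below are assumptions.

# Hypothetical setup for sample_agent_data; every value here is an assumption.
import numpy as np
from scipy.stats import wasserstein_distance as wd  # assumed metric behind `wd`

n_experts = 10  # assumed module-level global read inside the function
experts = [np.random.rand(20) for _ in range(n_experts)]  # assumed per-expert reference data

# states, actions, closest = sample_agent_data(
#     N=50, args=args, env=env, model=model, obs_normalizer=obs_normalizer,
#     customers=customers, customer_states=customer_states)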
Example #2
def purchase_ratio(args, model_dir_path):
    sample_length = 10000  # defined locally so the calls below do not rely on an undefined name
    env, model, obs_normalizer = pe.get_env_and_model(args, model_dir_path, sample_length=sample_length, model_path=model_path)

    expert_trajectories = env.generate_expert_trajectories(out_dir=None, n_demos_per_expert=1, n_expert_time_steps=sample_length)
    expert_states = expert_trajectories['states']
    expert_actions = expert_trajectories['actions']
    sex = ['F' if s == 1 else 'M' for s in expert_trajectories['sex']]
    age = [int(a) for a in expert_trajectories['age']]

    for i, (e_states, e_actions) in enumerate(zip(expert_states, expert_actions)):  # Loop over experts
        print('Expert %d' % (i+1))

        # Sample data from agent
        initial_state = e_states[0]  # random.choice(e_states)
        _, agent_actions = pe.sample_from_policy(env, model, obs_normalizer, initial_state=initial_state)

        sample_sizes = [100, 500, 1000, 5000, 10000]
        expert_ratios = []
        agent_ratios = []

        for n in sample_sizes:
            expert_ratios.append(get_pr(e_actions, n))
            agent_ratios.append(get_pr(agent_actions, n))

        fig, ax = plt.subplots()
        ax.plot(sample_sizes, expert_ratios, label='Expert')
        ax.plot(sample_sizes, agent_ratios, label='Agent')
        ax.legend()

    plt.show()
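
get_pr is not defined in this listing. Going by the function name and the "p.r." (purchase ratio) shorthand used in later examples, a plausible stand-in computes the fraction of time steps with a purchase among the first n actions. This is an assumption, not the original helper.

# Hypothetical stand-in for the undefined get_pr helper.
import numpy as np

def get_pr(actions, n):
    # Assumed behavior: purchase ratio, i.e. the share of non-zero
    # actions over the first n time steps of a trajectory.
    a = np.asarray(actions)[:n]
    return np.count_nonzero(a) / len(a)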
Example #3
def evaluate_on_pop_level(args, model_path, avg_expert, compare_features):
    n_experts = args['n_experts']
        
    env, model, obs_normalizer = pe.get_env_and_model(args, model_path, sample_length, only_env=False)

    metric = ed if compare_features else wd

    agent_states = []
    agent_actions = []
    for i in range(n_experts):
        # Initialize agent with data from ith expert
        env.model.spawn_new_customer(i)
        sample = env.case.get_sample(
            n_demos_per_expert=1, 
            n_historical_events=args['n_historical_events'], 
            n_time_steps=1000
            )
        all_data = np.hstack(sample[0])  # history, data = sample[0]
        j = np.random.randint(0, all_data.shape[1] - args['n_historical_events'])
        history = all_data[:, j:j + args['n_historical_events']]
        if args['state_rep'] == 71:
            adam_basket = np.random.permutation(env.case.adam_baskets[i])
            env.case.i_expert = i
            initial_state = env.case.get_initial_state(history, adam_basket[0])
        else:
            initial_state = env.case.get_initial_state(history, i)

        states, actions = pe.sample_from_policy(env, model, obs_normalizer, initial_state=initial_state)
        agent_states.append(states)
        agent_actions.append(actions)
    agent_states = np.array(agent_states)
    agent_actions = np.array(agent_actions)

    avg_agent = get_features(agent_actions, average=True) if compare_features else pe.get_distrib(agent_states, agent_actions)

    distance = metric(avg_agent, avg_expert)

    return distance
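
evaluate_on_pop_level switches between two module-level metrics: ed when comparing feature vectors and wd when comparing distributions. Their definitions are outside this listing; since Example #9 calls scipy's wasserstein_distance directly, the following aliases are plausible, offered as assumptions only.

# Assumed definitions of the module-level metrics `wd` and `ed`.
import numpy as np
from scipy.stats import wasserstein_distance

wd = wasserstein_distance  # assumed: Wasserstein distance (EMD) between distributions

def ed(u, v):
    # Assumed: Euclidean distance between feature vectors.
    return float(np.linalg.norm(np.asarray(u) - np.asarray(v)))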
Example #4
def evaluate(args, model_path, n_new_customers, sample_length, N,
             customer_states):
    n_experts = args['n_experts']

    if args['state_rep'] == 71:
        env, model, obs_normalizer = pe.get_env_and_model(
            args,
            model_path,
            sample_length,
            only_env=False,
            n_experts_in_adam_basket=n_experts + n_new_customers)
    else:
        env, model, obs_normalizer = pe.get_env_and_model(args,
                                                          model_path,
                                                          sample_length,
                                                          only_env=False)

    agents = []

    for i in range(n_experts + n_new_customers):
        temp_agents = []
        for j in range(N):
            # Note: `and` binds tighter than `or`; group the condition explicitly
            if args['state_rep'] in (22, 221, 23) and i >= n_experts:
                raise NotImplementedError
            else:
                initial_state = random.choice(customer_states[i])
            states, actions = pe.sample_from_policy(
                env, model, obs_normalizer, initial_state=initial_state)
            states = np.array(states)
            actions = np.array(actions)

            a = pe.get_distrib(states, actions)

            temp_agents.append(a)

        agents.append(temp_agents)

    # for seed in range(n_experts + n_new_customers):
    #     temp_agents = []

    #     if args['state_rep'] == 71:
    #         adam_basket = np.random.permutation(env.case.adam_baskets[seed])
    #         env.case.i_expert = seed

    #     env.model.spawn_new_customer(seed)
    #     sample = env.case.get_sample(
    #         n_demos_per_expert=1,
    #         n_historical_events=args['n_historical_events'],
    #         n_time_steps=1000
    #         )
    #     all_data = np.hstack(sample[0])  # history, data = sample[0]

    #     for i in range(N):
    #         j = np.random.randint(0, all_data.shape[1] - args['n_historical_events'])
    #         history = all_data[:, j:j + args['n_historical_events']]
    #         if args['state_rep'] == 71:
    #            initial_state = env.case.get_initial_state(history, adam_basket[i])
    #         else:
    #             raise NotImplementedError

    #         states, actions = pe.sample_from_policy2(env, model, obs_normalizer, initial_state=initial_state)
    #         states = np.array(states)
    #         actions = np.array(actions)

    #         a = pe.get_distrib(states, actions)

    #         temp_agents.append(a)

    #     agents.append(temp_agents)

    return agents
Example #5
def evaluate_on_new_customers(args, model_path, experts, new_customers, compare_features, n_new_customers=None):
    global k, N

    n_experts = args['n_experts']

    if n_new_customers is not None:
        env, model, obs_normalizer = pe.get_env_and_model(args, model_path, sample_length, only_env=False, n_experts_in_adam_basket=n_experts+n_new_customers)
    else:
        env, model, obs_normalizer = pe.get_env_and_model(args, model_path, sample_length, only_env=False)

    agents = []
    abs_diffs = []
    errors = []

    metric = ed if compare_features else wd

    for i, nc in enumerate(new_customers):
        distances = [metric(nc, e) for e in experts]
        closest_experts = np.argsort(distances)[:k]
        dummy = closest_experts[0]

        temp_agents = []
        temp_abs_diffs = []
        n_errors = 0

        seed = n_experts + i

        if args['state_rep'] == 71: 
            adam_basket = np.random.permutation(env.case.adam_baskets[seed])
            env.case.i_expert = seed

        env.model.spawn_new_customer(seed)
        sample = env.case.get_sample(
            n_demos_per_expert=1, 
            n_historical_events=args['n_historical_events'], 
            n_time_steps=1000
            )
        all_data = np.hstack(sample[0])  # history, data = sample[0]

        for m in range(N):
            j = np.random.randint(0, all_data.shape[1] - args['n_historical_events'])
            history = all_data[:, j:j + args['n_historical_events']]
            if args['state_rep'] == 71:
                initial_state = env.case.get_initial_state(history, adam_basket[m])
            else:
                initial_state = env.case.get_initial_state(history, dummy)  # dummy is the closest expert

            states, actions = pe.sample_from_policy(env, model, obs_normalizer, initial_state=initial_state)
            states = np.array(states)
            actions = np.array(actions)

            a = get_features([actions]) if compare_features else pe.get_distrib(states, actions)

            temp_agents.append(a)
            temp_abs_diffs.append(metric(a, nc))

            distances = [metric(a, e) for e in experts]
            if np.argmin(distances) not in closest_experts:
                n_errors += 1

        agents.append(temp_agents)
        abs_diffs.append(temp_abs_diffs)
        errors.append(n_errors / N)

    return agents, abs_diffs, errors
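
Example #5 reads the module-level globals k (a rollout counts as correctly matched if its nearest expert is among the k experts closest to the new customer), N (rollouts per new customer) and sample_length. A caller sets them before invoking the function; the values below are hypothetical.

# Hypothetical module-level configuration assumed by evaluate_on_new_customers.
k = 3               # tolerated rank of the matched expert
N = 10              # policy rollouts sampled per new customer
sample_length = 10000

# agents, abs_diffs, error_rates = evaluate_on_new_customers(
#     args, model_path, experts, new_customers,
#     compare_features=False, n_new_customers=5)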
Example #6
def evaluate_policy_at_population_level_multiple_categories(args, model_dir_path, ending_eps, ending_png, info):
    # Load environment, model and observation normalizer
    env, model, obs_normalizer = pe.get_env_and_model(args, model_dir_path, sample_length, model_path=model_path)

    # Get possible validation states
    possible_val_states = pe.get_possible_val_states(n_last_days, max_n_purchases_per_n_last_days)

    # Sample agent data from both categories
    agent_states_1 = []
    agent_states_2 = []
    agent_actions_1 = []
    agent_actions_2 = []
    for i in range(args['n_experts']):
        # What happens if the agent is repeatedly initialized with history from expert with unique behavior?
        # Should we collect more than n_experts samples (especially if n_experts <= 10)?
        temp_s_1 = []
        temp_s_2 = []
        temp_a_1 = []
        temp_a_2 = []
        temp_states, temp_actions = pe.sample_from_policy(env, model, obs_normalizer)
        for state, action in zip(temp_states, temp_actions):
            if len(state) == args['n_experts'] + 2*args['n_historical_events']:  # dummy variables == n_experts
                s_1, s_2 = np.split(state[args['n_experts']:], 2)
            elif len(state) == 10 + 2*args['n_historical_events']:  # dummy variables == 10
                s_1, s_2 = np.split(state[10:], 2)
            else:
                raise ValueError('Unexpected state length: %d' % len(state))
            temp_s_1.append(s_1)
            temp_s_2.append(s_2)
            a1, a2 = convert_action(action)
            temp_a_1.append(a1)
            temp_a_2.append(a2)

        agent_states_1.append(temp_s_1)
        agent_states_2.append(temp_s_2)
        agent_actions_1.append(temp_a_1)
        agent_actions_2.append(temp_a_2)

    agent_purchase_1, agent_no_purchase_1, agent_n_shopping_days_1 = pe.get_cond_distribs(
        agent_states_1, 
        agent_actions_1, 
        n_last_days, 
        max_n_purchases_per_n_last_days, 
        normalize,
        case=args['state_rep']
        )
    agent_purchase_2, agent_no_purchase_2, agent_n_shopping_days_2 = pe.get_cond_distribs(
        agent_states_2, 
        agent_actions_2, 
        n_last_days, 
        max_n_purchases_per_n_last_days, 
        normalize,
        case=args['state_rep']
        )

    # Sample expert data
    expert_trajectories = env.generate_expert_trajectories(out_dir=None)
    expert_states = expert_trajectories['states']
    expert_actions = expert_trajectories['actions']

    expert_states_1 = []
    expert_states_2 = []
    expert_actions_1 = []
    expert_actions_2 = []

    for i in range(args['n_experts']):
        temp_s_1 = []
        temp_s_2 = []
        temp_a_1 = []
        temp_a_2 = []
        temp_states = expert_states[i]
        temp_actions = expert_actions[i]
        for state, action in zip(temp_states, temp_actions):
            if len(state) == args['n_experts'] + 2*args['n_historical_events']:  # dummy variables == n_experts
                s_1, s_2 = np.split(state[args['n_experts']:], 2)
            elif len(state) == 10 + 2*args['n_historical_events']:  # dummy variables == 10
                s_1, s_2 = np.split(state[10:], 2)
            else:
                raise ValueError('Unexpected state length: %d' % len(state))
            temp_s_1.append(s_1)
            temp_s_2.append(s_2)
            a1, a2 = convert_action(action)
            temp_a_1.append(a1)
            temp_a_2.append(a2)

        expert_states_1.append(temp_s_1)
        expert_states_2.append(temp_s_2)
        expert_actions_1.append(temp_a_1)
        expert_actions_2.append(temp_a_2)

    expert_purchase_1, expert_no_purchase_1, expert_n_shopping_days_1 = pe.get_cond_distribs(
        expert_states_1, 
        expert_actions_1, 
        n_last_days, 
        max_n_purchases_per_n_last_days, 
        normalize,
        case=args['state_rep']
        )

    expert_purchase_2, expert_no_purchase_2, expert_n_shopping_days_2 = pe.get_cond_distribs(
        expert_states_2, 
        expert_actions_2, 
        n_last_days, 
        max_n_purchases_per_n_last_days, 
        normalize,
        case=args['state_rep']
        )

    plot_distr_categories(
        args, 
        info, 
        ending_png, 
        expert_actions_1, 
        agent_actions_1, 
        expert_purchase_1, 
        agent_purchase_1, 
        expert_no_purchase_1, 
        agent_no_purchase_1, 
        agent_n_shopping_days_1, 
        expert_n_shopping_days_1, 
        category=1
        )
    plot_distr_categories(args, 
        info, 
        ending_png, 
        expert_actions_2, 
        agent_actions_2, 
        expert_purchase_2,
        agent_purchase_2, 
        expert_no_purchase_2, 
        agent_no_purchase_2, 
        agent_n_shopping_days_2, 
        expert_n_shopping_days_2, 
        category=2
        )
    plt.show()
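
convert_action is not defined in this listing. Since each state is split into two per-category halves, a plausible stand-in decodes one combined action into one action per category; the packing below (action = a1 + 2*a2 with binary per-category purchases) is purely an assumption.

# Hypothetical stand-in for the undefined convert_action helper.
def convert_action(action):
    # Assumed encoding: two binary per-category purchase decisions
    # packed into a single integer action in {0, 1, 2, 3}.
    a1 = int(action) % 2   # category 1: purchase (1) or not (0)
    a2 = int(action) // 2  # category 2: purchase (1) or not (0)
    return a1, a2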
Example #7
def evaluate_policy_at_population_level(args, model_dir_path, ending_eps, ending_png, info):
    # Load environment, model and observation normalizer
    env, model, obs_normalizer = pe.get_env_and_model(args, model_dir_path, sample_length, model_path=model_path)

    # Get possible validation states
    possible_val_states = pe.get_possible_val_states(n_last_days, max_n_purchases_per_n_last_days)

    # Sample expert data
    expert_trajectories = env.generate_expert_trajectories(out_dir=None)
    expert_states = expert_trajectories['states']
    expert_actions = expert_trajectories['actions']

    expert_purchase, expert_no_purchase, expert_n_shopping_days = pe.get_cond_distribs(
        expert_states, 
        expert_actions, 
        n_last_days, 
        max_n_purchases_per_n_last_days, 
        normalize,
        case=args['state_rep']
        )

    n_experts = 2 if (args['state_rep'] == 24 or args['state_rep'] == 31) else args['n_experts']

    # Sample agent data
    agent_states = []
    agent_actions = []
    for i in range(n_experts):
        initial_state = random.choice(expert_states[i])
        # What happens if the agent is repeatedly initialized with history from expert with unique behavior?
        # Should we collect more than n_experts samples (especially if n_experts <= 10)?
        temp_states, temp_actions = pe.sample_from_policy(env, model, obs_normalizer, initial_state=initial_state)
        agent_states.append(temp_states)
        agent_actions.append(temp_actions)

    agent_purchase, agent_no_purchase, agent_n_shopping_days = pe.get_cond_distribs(
        agent_states, 
        agent_actions, 
        n_last_days, 
        max_n_purchases_per_n_last_days, 
        normalize,
        case=args['state_rep']
        )

    # Calculate Wasserstein distances
    wd_purchase = pe.get_wd(expert_purchase, agent_purchase, normalize)
    wd_no_purchase = pe.get_wd(expert_no_purchase, agent_no_purchase, normalize)
    
    agent_shopping_ratio = format(agent_n_shopping_days / (n_experts * sample_length), '.3f')
    expert_shopping_ratio = format(expert_n_shopping_days / (n_experts * sample_length), '.3f')
    expert_str = 'Expert (p.r.: ' + str(expert_shopping_ratio) + ')'
    agent_str = 'Agent (p.r.: ' + str(agent_shopping_ratio) + ')'

    fig, (ax1, ax2) = plt.subplots(1, 2)
    fig.suptitle('Comparison at population level')

    # Plot (purchase)
    data = {expert_str: expert_purchase, agent_str: agent_purchase}

    pe.bar_plot(ax1, data, colors=None, total_width=0.7)
    ax1.set_xticks([], [])
    ax1.set_title('Purchase | EMD: {:.5f}'.format(wd_purchase))
    # ax1.set_title('Last week | Purchase today')
    ax1.set_ylabel('Probability')

    # Plot (no purchase)
    data = {expert_str: expert_no_purchase, agent_str: agent_no_purchase}
    pe.bar_plot(ax2, data, colors=None, total_width=0.7)
    ax2.set_xticks([], [])
    ax2.set_title('No purchase | EMD: {:.5f}'.format(wd_no_purchase))
    # ax2.set_title('Last week | No purchase today')
    ax2.set_ylabel('Probability')
    
    if show_info: fig.text(0.5, 0.025, info, ha='center')
    if save_plots: save_plt_as_png(fig, path=join(dir_path, 'figs', 'pop' + ending_png))
    if not show_plots: plt.close(fig)

    if args['state_rep'] == 23:
        # Plot histogram of purchase amounts
        expert_amounts = np.ravel(expert_actions)[np.flatnonzero(expert_actions)]
        agent_amounts = np.ravel(agent_actions)[np.flatnonzero(agent_actions)]

        fig, ax = plt.subplots()
        ax.hist(expert_amounts, bins=np.arange(1, 11), alpha=0.8, density=True, label='Expert')
        ax.hist(agent_amounts, bins=np.arange(1, 11), alpha=0.8, density=True, label='Agent')
        ax.set_xlabel('Purchase amount')
        ax.set_ylabel('Normalized frequency')
        ax.legend()

        if show_info: fig.text(0.5, 0.025, info, ha='center')
        if save_plots: save_plt_as_png(fig, path=join(dir_path, 'figs', 'pop_amounts' + ending_png))
        if not show_plots: plt.close(fig)

    if show_plots: plt.show()
Example #8
def evaluate_policy_at_individual_level(args, model_dir_path, ending_eps, ending_png, info):
    # Load environment, model and observation normalizer
    env, model, obs_normalizer = pe.get_env_and_model(args, model_dir_path, sample_length, model_path=model_path)

    # Get possible validation states
    possible_val_states = pe.get_possible_val_states(n_last_days, max_n_purchases_per_n_last_days)

    # Sample expert data to calculate average expert behavior
    expert_trajectories = env.generate_expert_trajectories(out_dir=None, n_demos_per_expert=1, n_expert_time_steps=sample_length)
    expert_states = expert_trajectories['states']
    expert_actions = expert_trajectories['actions']
    sex = ['F' if s == 1 else 'M' for s in expert_trajectories['sex']]
    age = [int(a) for a in expert_trajectories['age']]

    avg_expert_purchase, avg_expert_no_purchase, avg_expert_n_shopping_days = pe.get_cond_distribs(
        expert_states, 
        expert_actions, 
        n_last_days, 
        max_n_purchases_per_n_last_days, 
        normalize,
        case=args['state_rep']
        )
    n_experts = 2 if (args['state_rep'] == 24 or args['state_rep'] == 31) else args['n_experts']
    avg_expert_shopping_ratio = format(avg_expert_n_shopping_days / (n_experts * sample_length), '.2f')

    # Slicing clamps at the end of a range, so no IndexError handling is needed
    all_expert_indices = range(n_experts)
    expert_indices_list = [all_expert_indices[i:i + 4] for i in range(0, n_experts, 4)]

    for j, expert_indices in enumerate(expert_indices_list):
        fig1, axes1 = plt.subplots(2, 2, sharex='col')
        fig2, axes2 = plt.subplots(2, 2, sharex='col')
        
        for i, ax1, ax2 in zip(expert_indices, axes1.flat, axes2.flat):
            # Sample agent data starting with expert's history
            initial_state = random.choice(expert_states[i])

            agent_states, agent_actions = pe.sample_from_policy(env, model, obs_normalizer, initial_state=initial_state)

            agent_purchase, agent_no_purchase, agent_n_shopping_days = pe.get_cond_distribs(
                [agent_states], 
                [agent_actions], 
                n_last_days, 
                max_n_purchases_per_n_last_days, 
                normalize,
                case=args['state_rep']
                )
            agent_shopping_ratio = format(agent_n_shopping_days / sample_length, '.3f')

            expert_purchase, expert_no_purchase, expert_n_shopping_days = pe.get_cond_distribs(
                [expert_states[i]], 
                [expert_actions[i]], 
                n_last_days, 
                max_n_purchases_per_n_last_days, 
                normalize,
                case=args['state_rep']
                )
            expert_shopping_ratio = format(expert_n_shopping_days / sample_length, '.3f')

            if args['state_rep'] == 23:
                expert_histo, _ = np.histogram(expert_actions[i], bins=range(11))
                agent_histo, _ = np.histogram(agent_actions, bins=range(11))

            # Calculate Wasserstein distances
            wd_purchase = pe.get_wd(expert_purchase, agent_purchase, normalize)
            wd_purchase_avg = pe.get_wd(avg_expert_purchase, agent_purchase, normalize)
            wd_no_purchase = pe.get_wd(expert_no_purchase, agent_no_purchase, normalize)
            wd_no_purchase_avg = pe.get_wd(avg_expert_no_purchase, agent_no_purchase, normalize)

            expert_str = 'Expert (p.r.: ' + str(expert_shopping_ratio) + ')'
            agent_str = 'Agent (p.r.: ' + str(agent_shopping_ratio) + ')'
            avg_expert_str = 'Avg. expert (p.r.: ' + str(avg_expert_shopping_ratio) + ')'

            # Plot (purchase)
            data = {expert_str: expert_purchase, agent_str: agent_purchase, avg_expert_str: avg_expert_purchase}
            pe.bar_plot(ax1, data, colors=None, total_width=0.7)
            ax1.set_title('Expert {}: {}, age {}\nEMD (expert): {:.5f} | EMD (avg. expert): {:.5f}'.format(i+1, sex[i], age[i], wd_purchase, wd_purchase_avg))

            # Plot (no purchase)
            data = {expert_str: expert_no_purchase, agent_str: agent_no_purchase, avg_expert_str: avg_expert_no_purchase}
            pe.bar_plot(ax2, data, colors=None, total_width=0.7)
            ax2.set_title('Expert {}: {}, age {}\nEMD (expert): {:.5f} | EMD (avg. expert): {:.5f}'.format(i+1, sex[i], age[i], wd_no_purchase, wd_no_purchase_avg))

        fig1.suptitle('Comparison at individual level (purchase)')
        fig2.suptitle('Comparison at individual level (no purchase)')

        if show_info:
            for ax1, ax2 in zip(axes1[1][:], axes2[1][:]):
                ax1.set_xticks([], [])
                ax2.set_xticks([], [])
            fig1.text(0.5, 0.025, info, ha='center')
            fig2.text(0.5, 0.025, info, ha='center')
        else:
            fig1.subplots_adjust(bottom=0.2)
            fig2.subplots_adjust(bottom=0.2)
            for ax1, ax2 in zip(axes1[1][:], axes2[1][:]):
                pe.set_xticks(ax1, possible_val_states, max_n_purchases_per_n_last_days)
                pe.set_xticks(ax2, possible_val_states, max_n_purchases_per_n_last_days)

        if save_plots: save_plt_as_png(fig1, path=join(dir_path, 'figs', 'ind_purchase_' + str(j+1) + ending_png))
        if save_plots: save_plt_as_png(fig2, path=join(dir_path, 'figs', 'ind_no_purchase_' + str(j+1) + ending_png))

        if not show_plots: 
            plt.close(fig1)
            plt.close(fig2)

        if show_plots: plt.show()
Example #9
def compare_clusters(args, model_dir_path, ending_eps, ending_png, info):
    # Load environment, model and observation normalizer
    env, model, obs_normalizer = pe.get_env_and_model(args, model_dir_path, sample_length, model_path=model_path)

    # Get possible validation states
    possible_val_states = pe.get_possible_val_states(n_last_days, max_n_purchases_per_n_last_days)

    # Get multiple samples from each expert
    assert (sample_length % n_demos_per_expert) == 0
    expert_trajectories = env.generate_expert_trajectories(
        out_dir=None, 
        n_demos_per_expert=n_demos_per_expert, 
        n_expert_time_steps=int(sample_length / n_demos_per_expert)
        )
    expert_states = np.array(expert_trajectories['states'])
    expert_actions = np.array(expert_trajectories['actions'])
    sex = ['F' if s == 1 else 'M' for s in expert_trajectories['sex']]
    age = [int(a) for a in expert_trajectories['age']]

    n_experts = 2 if (args['state_rep'] == 24 or args['state_rep'] == 31) else args['n_experts']
    
    experts = []
    for states, actions in zip(np.split(expert_states, n_experts), np.split(expert_actions, n_experts)):  # Loop over experts
        purchases = []
        no_purchases = []

        for s, a in zip(states, actions):  # Loop over demonstrations       
            temp_purchase, temp_no_purchase, _ = pe.get_cond_distribs(
                [s], 
                [a], 
                n_last_days, 
                max_n_purchases_per_n_last_days, 
                normalize,
                case=args['state_rep']
                )
            purchases.append(temp_purchase)
            no_purchases.append(temp_no_purchase)

        avg_purchase, avg_no_purchase, _ = pe.get_cond_distribs(
            states, 
            actions, 
            n_last_days, 
            max_n_purchases_per_n_last_days, 
            normalize,
            case=args['state_rep']
            )

        experts.append(Expert(purchases, no_purchases, avg_purchase, avg_no_purchase))

    # Calculate average expert behavior
    expert_trajectories = env.generate_expert_trajectories(out_dir=None, n_demos_per_expert=1, n_expert_time_steps=sample_length)
    expert_states = expert_trajectories['states']
    expert_actions = expert_trajectories['actions']

    avg_expert_purchase, avg_expert_no_purchase, _ = pe.get_cond_distribs(
        expert_states, 
        expert_actions, 
        n_last_days, 
        max_n_purchases_per_n_last_days, 
        normalize,
        case=args['state_rep']
        )

    if cluster_comparison and args['state_rep'] not in (24, 31):
        # Cluster expert data (purchase)
        X = np.array([e.avg_purchase for e in experts])
        T_purchase = fclusterdata(X, 3, 'maxclust', method='single', metric=lambda u, v: wasserstein_distance(u, v))
        T_purchase = pe.get_cluster_labels(T_purchase)

        # Cluster expert data (no purchase)
        X = np.array([e.avg_no_purchase for e in experts])
        T_no_purchase = fclusterdata(X, 3, 'maxclust', method='single', metric=lambda u, v: wasserstein_distance(u, v))
        T_no_purchase = pe.get_cluster_labels(T_no_purchase)

        assert np.array_equal(T_purchase, T_no_purchase)
        cluster_indices = [np.argwhere(T_purchase == i) for i in [1, 2, 3]]

        distances_purchase = []
        distances_no_purchase = []
    
    all_distances_purchase = []
    all_distances_no_purchase = []

    for i in range(n_experts):
        # Sample agent data starting with expert's history
        initial_state = random.choice(expert_states[i])
        agent_states, agent_actions = pe.sample_from_policy(env, model, obs_normalizer, initial_state=initial_state)

        agent_purchase, agent_no_purchase, _ = pe.get_cond_distribs(
            [agent_states], 
            [agent_actions], 
            n_last_days, 
            max_n_purchases_per_n_last_days, 
            normalize,
            case=args['state_rep']
            )

        e = experts[i]

        # Compare distributions (purchase)
        if cluster_comparison and args['state_rep'] not in (24, 31):
            temp = [e.avg_dist_purchase]
            temp.append(pe.get_wd(e.avg_purchase, agent_purchase, normalize))
            temp.append(pe.get_wd(avg_expert_purchase, agent_purchase, normalize))
            temp.append(np.mean([pe.get_wd(experts[j[0]].avg_purchase, agent_purchase, normalize) for j in cluster_indices[0]]))
            temp.append(np.mean([pe.get_wd(experts[j[0]].avg_purchase, agent_purchase, normalize) for j in cluster_indices[1]]))
            temp.append(np.mean([pe.get_wd(experts[j[0]].avg_purchase, agent_purchase, normalize) for j in cluster_indices[2]]))
            distances_purchase.append(temp)

        temp = [pe.get_wd(e.avg_purchase, agent_purchase, normalize) for e in experts]
        temp.append(pe.get_wd(avg_expert_purchase, agent_purchase, normalize))
        all_distances_purchase.append(temp)

        # Compare distributions (no purchase)
        if cluster_comparison and args['state_rep'] not in (24, 31):
            temp = [e.avg_dist_no_purchase]
            temp.append(pe.get_wd(e.avg_no_purchase, agent_no_purchase, normalize))
            temp.append(pe.get_wd(avg_expert_no_purchase, agent_no_purchase, normalize))
            temp.append(np.mean([pe.get_wd(experts[j[0]].avg_no_purchase, agent_no_purchase, normalize) for j in cluster_indices[0]]))
            temp.append(np.mean([pe.get_wd(experts[j[0]].avg_no_purchase, agent_no_purchase, normalize) for j in cluster_indices[1]]))
            temp.append(np.mean([pe.get_wd(experts[j[0]].avg_no_purchase, agent_no_purchase, normalize) for j in cluster_indices[2]]))
            distances_no_purchase.append(temp)

        temp = [pe.get_wd(e.avg_no_purchase, agent_no_purchase, normalize) for e in experts]
        temp.append(pe.get_wd(avg_expert_no_purchase, agent_no_purchase, normalize))
        all_distances_no_purchase.append(temp)

    if cluster_comparison and args['state_rep'] not in (24, 31):
        ##### Plot distance to one expert #####
        columns = ['Var. in expert cluster', 'Dist. to expert', 'Dist. to avg. expert', 'Dist. to 1st cluster', 'Dist. to 2nd cluster', 'Dist. to 3rd cluster']
        index = ['E{}\n({})'.format(i + 1, int(T_purchase[i])) for i in range(n_experts)]

        fig, (ax1, ax2) = plt.subplots(1, 2)
        fig.subplots_adjust(bottom=0.30)

        # Plot distance (purchase)
        distances_purchase = pd.DataFrame(distances_purchase, columns=columns, index=index)
        seaborn.heatmap(distances_purchase, cmap='BuPu', ax=ax1, linewidth=1, cbar_kws={'label': 'EMD'})
        ax1.set_title('Purchase')

        distances_no_purchase = pd.DataFrame(distances_no_purchase, columns=columns, index=index)
        seaborn.heatmap(distances_no_purchase, cmap='BuPu', ax=ax2, linewidth=1, cbar_kws={'label': 'EMD'})
        ax2.set_title('No purchase')

        if show_info: fig.text(0.5, 0.025, info, ha='center')
        if save_plots: save_plt_as_png(fig, path=join(dir_path, 'figs', 'heatmap' + ending_png))
        if not show_plots: plt.close(fig)

    ##### Plot distance to all experts #####
    fig, (ax1, ax2) = plt.subplots(1, 2, sharey='row')
    fig.subplots_adjust(bottom=0.25)

    columns = ['Expert {}'.format(i + 1) for i in range(n_experts)]
    columns.append('Avg. expert')
    index = ['Agent {}'.format(i + 1) for i in range(n_experts)]

    # Plot the distance from each agent to every expert (purchase)
    all_distances_purchase = pd.DataFrame(all_distances_purchase, columns=columns, index=index)
    seaborn.heatmap(all_distances_purchase, cmap='BuPu', ax=ax1, linewidth=1, cbar_kws={'label': 'EMD'})
    ax1.set_title('Purchase')

    # Plot the distance from each agent to every expert (no purchase)
    all_distances_no_purchase = pd.DataFrame(all_distances_no_purchase, columns=columns, index=index)
    seaborn.heatmap(all_distances_no_purchase, cmap='BuPu', ax=ax2, linewidth=1, cbar_kws={'label': 'EMD'})
    ax2.set_title('No purchase')

    if show_info: fig.text(0.5, 0.025, info, ha='center')
    if save_plots: save_plt_as_png(fig, path=join(dir_path, 'figs', 'heatmap_all' + ending_png))
    if not show_plots: plt.close(fig)
    
    if show_plots: plt.show()
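
All nine examples assume a shared module-level preamble that this listing omits. A plausible set of imports and configuration covering the names used above is sketched below; the module behind the `pe` alias and all configuration values are guesses, and the Expert container of Example #9 (which also exposes avg_dist_purchase / avg_dist_no_purchase) is defined elsewhere in the source.

# Assumed common preamble; names and values are hypothetical.
import random
from os.path import join

import numpy as np
import pandas as pd
import seaborn
import matplotlib.pyplot as plt
from scipy.stats import wasserstein_distance
from scipy.cluster.hierarchy import fclusterdata

import policy_evaluation as pe  # hypothetical name for the shared evaluation helpers

# Module-level configuration read by several examples (values hypothetical):
sample_length = 10000
n_last_days = 7
max_n_purchases_per_n_last_days = 2
normalize = True
n_demos_per_expert = 10  # must divide sample_length (see Example #9)
show_info = show_plots = save_plots = False
cluster_comparison = True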
