Example 1
def train_pointwise(**kwargs):
    """Train a policy with a specific curriculum 
    strategy (random, etc), 
    policy training method (trpo, ppo, etc), and start state sampling 
    strategy (uniform, weighted by value function, etc).
    """
    problem = kwargs["problem"];
    initial_policy = kwargs["initial_policy"];
    goal_state = kwargs["goal_state"];
    full_start_dist = kwargs["full_start_dist"];
    N_new, N_old = kwargs["N_new"], kwargs["N_old"];
    R_min, R_max = kwargs["R_min"], kwargs["R_max"];
    num_iters = kwargs["num_iters"];
    num_ppo_iters = kwargs["num_ppo_iters"];
    curriculum_strategy = kwargs["curriculum_strategy"];
    train_algo = kwargs["train_algo"];
    start_distribution = kwargs["start_distribution"];
    debug = kwargs["debug"];
    data_logger = kwargs["data_logger"];

    # Keyword arguments for the curriculum strategy.
    curric_kwargs = defaultdict(lambda: None)

    # Not used in algorithm, only for visualization.
    all_starts = [goal_state]
    
    old_starts = [goal_state]
    starts = [goal_state]
    pi_i = initial_policy
    overall_perf, overall_area = list(), list()
    ppo_perf, ppo_lens, ppo_rews = list(), list(), list()
    perf_metric = 0.0
    i = 0
    ppo_iter_count = 0
    pi_i.save_model(MODEL_DIR, iteration=i)
    while i < num_iters:
    # while perf_metric < args.finish_threshold and i < num_iters:
        print('Training Iteration %d' % i, flush=True)
        data_logger.update_indices({"overall_iter": i, "ppo_iter": ppo_iter_count})

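        # Propose N_new candidate start states from the current ones via the curriculum strategy.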
        new_starts = curriculum_strategy(starts, N_new, problem)

        if debug and len(new_starts) > 0:
            visualize_starts(new_starts, problem, 
                             figfile=os.path.join(FIGURES_DIR, 'curric_starts_iter_%d' % i))

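        # Mix the new candidates with N_old starts replayed from previous iterations.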
        from_replay = sample(old_starts, size=N_old)
        starts = new_starts + from_replay

        if debug:
            visualize_starts(new_starts, problem, 
                             old_starts=from_replay,
                             old_start_color='orange',
                             old_start_name='Replay Starts',
                             figfile=os.path.join(FIGURES_DIR, 'replay_and_curric_starts_iter_%d' % i))
            
            visualize_starts(None, problem, 
                             old_starts=all_starts,
                             old_start_color='grey',
                             old_start_name='Old Starts',
                             figfile=os.path.join(FIGURES_DIR, 'previous_starts_iter_%d' % i))

        rho_i = list(zip(starts, start_distribution(starts)))
        pi_i, rewards_map, ep_mean_lens, ep_mean_rews = train_step(rho_i, pi_i, train_algo, problem, num_ppo_iters=num_ppo_iters)

        if debug:
            if problem.env_name == 'DrivingOrigin-v0':
                visualize_starts(starts, problem, 
                                 rewards_map=rewards_map,
                                 figfile=os.path.join(FIGURES_DIR, 'start_rews_iter_%d' % i))

            visualize_rollouts([sample(starts)], pi_i, problem,
                               figfile=os.path.join(FIGURES_DIR, 'rollouts_iter_%d' % i))

        data_logger.save_to_npy('curr_starts', starts)

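        # Record every attempted start, then keep only those whose rewards fall within [R_min, R_max] as the next replay candidates.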
        all_starts.extend(starts)
        total_unique_starts = len(dedupe_list_of_np_arrays(starts))
        starts = select(starts, rewards_map, R_min, R_max, problem)
        successful_starts = len(starts)
        pct_successful = float(successful_starts) / total_unique_starts
        old_starts.extend(starts)

        ppo_perf.append(pct_successful*100.)
        ppo_lens.extend(ep_mean_lens)
        ppo_rews.extend(ep_mean_rews)

        if debug:
            plot_performance(range(len(ppo_perf)), ppo_perf, ylabel=r'% Successful Starts', xlabel=r'PPO Iteration ($\times %d$)' % num_ppo_iters, figfile=os.path.join(FIGURES_DIR, 'ppo_pct_succ_iter_%d' % i))
            plot_performance(range(len(ppo_lens)), ppo_lens, ylabel=r'Avg. Episode Length', xlabel='PPO Iteration', figfile=os.path.join(FIGURES_DIR, 'ppo_avg_lens_iter_%d' % i))
            plot_performance(range(len(ppo_rews)), ppo_rews, ylabel=r'Avg. Episode Reward', xlabel='PPO Iteration', figfile=os.path.join(FIGURES_DIR, 'ppo_avg_rews_iter_%d' % i))

        data_logger.save_to_npy('all_starts', all_starts)
        data_logger.save_to_npy('old_starts', old_starts)
        data_logger.save_to_npy('selected_starts', starts)
        data_logger.save_to_npy('new_starts', new_starts)
        data_logger.save_to_npy('from_replay', from_replay)

        ppo_iter_count += num_ppo_iters

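        # Evaluate the current policy on the full (target) start distribution.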
        perf_metric = evaluate(pi_i, 
                             full_start_dist, 
                             problem, 
                             debug=debug, 
                             figfile=os.path.join(FIGURES_DIR, 'eval_iter_%d' % i))

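        # Measure how much of the (x, y) state space the sampled starts have covered so far.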
        # Format is (min_x, max_x, min_y, max_y)
        all_starts_bbox = bounding_box(all_starts)
        min_x = problem.state_space.low[X_IDX]
        max_x = problem.state_space.high[X_IDX]
        min_y = problem.state_space.low[Y_IDX]
        max_y = problem.state_space.high[Y_IDX]
        area_coverage = bounding_box_area(all_starts_bbox) / bounding_box_area((min_x, max_x, min_y, max_y))

        overall_perf.append(perf_metric)
        overall_area.append(area_coverage*100.)

        data_logger.add_rows({'overall_perf': [perf_metric], 'overall_area': [area_coverage], 
                              'ppo_perf': [pct_successful], 'ppo_lens': ep_mean_lens, 'ppo_rews': ep_mean_rews},
                              update_indices=['ppo_iter'])

        if debug:
            plot_performance(range(len(overall_perf)), overall_perf, ylabel=r'% Successful Starts', xlabel='Iteration', figfile=os.path.join(FIGURES_DIR, 'overall_perf'))
            plot_performance(range(len(overall_area)), overall_area, ylabel=r'% State Space Sampled', xlabel='Iteration', figfile=os.path.join(FIGURES_DIR, 'overall_area'))

        print('[Overall Iter %d]: perf_metric = %.2f | Area Coverage = %.2f%%' % (i, perf_metric, area_coverage*100.))

        # Incrementing our algorithm's loop counter.
        i += 1

        data_logger.save_to_file()
        pi_i.save_model(MODEL_DIR, iteration=i)

    return pi_i
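The loop above treats several project helpers as black boxes. For orientation, below is a minimal sketch of how sample, dedupe_list_of_np_arrays, and select could behave, inferred purely from their call sites; the actual implementations in the repository may differ (in particular, the key type used by rewards_map is an assumption).

# Sketches of the helpers used by the training loop; inferred from call sites, not the project's actual code.
import numpy as np


def sample(starts, size=None):
    """With size=None, return a single start drawn uniformly at random
    (as used in visualize_rollouts([sample(starts)], ...)); otherwise
    return a list of `size` starts drawn with replacement."""
    if size is None:
        return starts[np.random.randint(len(starts))]
    idx = np.random.randint(len(starts), size=size)
    return [starts[j] for j in idx]


def dedupe_list_of_np_arrays(states):
    """Remove exact duplicates from a list of np.ndarray states."""
    seen, unique = set(), []
    for s in states:
        key = tuple(np.asarray(s).ravel())
        if key not in seen:
            seen.add(key)
            unique.append(s)
    return unique


def select(starts, rewards_map, R_min, R_max, problem):
    """Keep the starts whose measured reward lies in [R_min, R_max], i.e.
    states that are neither too hard nor already mastered. Assumes
    rewards_map maps a hashable version of each state to its mean rollout
    reward; `problem` is unused in this sketch."""
    return [s for s in starts
            if R_min <= rewards_map[tuple(np.asarray(s).ravel())] <= R_max]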
Example 2
def train_gridwise(**kwargs):
    """Train a policy with a specific curriculum 
    strategy (backward reachable, etc), 
    policy training method (trpo, ppo, etc), and start state sampling 
    strategy (uniform, weighted by value function, etc).
    """
    problem = kwargs["problem"]
    initial_policy = kwargs["initial_policy"]
    goal_state = kwargs["goal_state"]
    full_start_dist = kwargs["full_start_dist"]
    N_new, N_old = kwargs["N_new"], kwargs["N_old"]
    R_min, R_max = kwargs["R_min"], kwargs["R_max"]
    num_iters = kwargs["num_iters"]
    num_ppo_iters = kwargs["num_ppo_iters"]
    curriculum_strategy = kwargs["curriculum_strategy"]
    train_algo = kwargs["train_algo"]
    # start_distribution = kwargs["start_distribution"]
    start_distribution = uniform  # hard-coded; the "start_distribution" kwarg is ignored here
    debug = kwargs["debug"]
    data_logger = kwargs["data_logger"]

    # Keyword arguments for the curriculum strategy.
    curric_kwargs = defaultdict(lambda: None)
    curric_kwargs["debug"] = debug
    if curriculum_strategy == backward_reachable:
        br_engine = BackreachEngine()
        br_engine.reset_variables(problem, os.path.join(FIGURES_DIR, ''))
        curric_kwargs['br_engine'] = br_engine
        curric_kwargs['curr_train_iter'] = 0
        curric_kwargs['brs_sample'] = args.brs_sample
        curric_kwargs['variation'] = args.variation
        curric_kwargs['problem'] = problem

    # Not used in algorithm, only for visualization.
    all_starts = [goal_state]
    
    old_starts = [goal_state]
    starts = [goal_state]
    pi_i = initial_policy
    overall_perf, overall_area = list(), list()
    perf_metric = 0.0
    i = 0
    pi_i.save_model(MODEL_DIR, iteration=i)
    while i < num_iters:
    # while perf_metric < args.finish_threshold and i < num_iters:
        print('Training Iteration %d' % i, flush=True)
        data_logger.update_indices({"overall_iter": i})

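        # Tell the curriculum strategy which outer training iteration we are on.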
        if 'curr_train_iter' in curric_kwargs:
            curric_kwargs['curr_train_iter'] = i

        # I've split apart the following call into two separate ones.
        # new_starts = curriculum_strategy(starts, N_new, problem, **curric_kwargs)
        if curriculum_strategy == backward_reachable:
            update_back_rectangle(delta_t=0.1, **curric_kwargs)
            #data_logger.save_to_npy('brs_targets', starts);

        pct_successful = 0.0
        iter_count = 0
        ppo_perf, ppo_lens, ppo_rews = list(), list(), list()
        # Think of this as "while (haven't passed this grade)"
        while pct_successful < 0.5:
            data_logger.update_indices({"ppo_iter": iter_count})

            if curriculum_strategy == backward_reachable:
                new_starts = sample_from_back_rectangle(N_new, **curric_kwargs)

            # TODO: visualization for the rectangle-based curriculum is not finished yet;
            # the original grid-based calls are kept commented out below.
            #     if debug:
            #         br_engine.visualize_grids(os.path.join(FIGURES_DIR, ''), '_iter_%d_ppo_iter_%d' % (i, iter_count))
            #
            # if debug:
            #     visualize_starts(new_starts, problem,
            #                      figfile=os.path.join(FIGURES_DIR, 'curric_starts_iter_%d_ppo_iter_%d' % (i, iter_count)))

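            # Variation 2: once the backward-reachable set contains the environment's true start state, train from that state directly.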
            if args.variation == 2 and br_engine.check_membership(np.array([problem.env.unwrapped.start_state])):
                print('Variation 2 condition train.py!', flush=True)
                from_replay = list()
                starts = [problem.env.unwrapped.start_state]
            else:
                from_replay = sample(old_starts, size=N_old)
                starts = new_starts + from_replay

            # NOTE: the visualization below is not yet adapted to the rectangle-based curriculum.
            
            if debug:
                visualize_starts(new_starts, problem, 
                                 old_starts=from_replay,
                                 old_start_color='orange',
                                 old_start_name='Replay Starts',
                                 figfile=os.path.join(FIGURES_DIR, 'replay_and_curric_starts_iter_%d_ppo_iter_%d' % (i, iter_count)))
                
                visualize_starts(None, problem, 
                                 old_starts=all_starts,
                                 old_start_color='grey',
                                 old_start_name='Old Starts',
                                 old_start_alpha=0.2,
                                 with_arrows=False,
                                 figfile=os.path.join(FIGURES_DIR, 'previous_starts_iter_%d_ppo_iter_%d' % (i, iter_count)))

            # Uniform weights over the current starts (replaces start_distribution(starts)).
            a = [1. / len(starts)] * len(starts)
            rho_i = list(zip(starts, a))
            pi_i, rewards_map, ep_mean_lens, ep_mean_rews = train_step(rho_i, pi_i, train_algo, problem, num_ppo_iters=num_ppo_iters)

            # NOTE: the visualization below is not yet adapted to the rectangle-based curriculum.
            if debug:
                if problem.env_name == 'DrivingOrigin-v0':
                    visualize_starts(starts, problem, 
                                     rewards_map=rewards_map,
                                     figfile=os.path.join(FIGURES_DIR, 'start_rews_iter_%d_ppo_iter_%d' % (i, iter_count)))

                visualize_rollouts([sample(starts)], pi_i, problem,
                                  figfile=os.path.join(FIGURES_DIR, 'rollouts_iter_%d_ppo_iter_%d' % (i, iter_count)))

            data_logger.save_to_npy('curr_starts', starts)

            all_starts.extend(starts)
            total_unique_starts = len(dedupe_list_of_np_arrays(starts))
            starts = select(starts, rewards_map, R_min, R_max, problem)
            successful_starts = len(starts)
            pct_successful = float(successful_starts) / total_unique_starts

            ppo_perf.append(pct_successful*100.)
            ppo_lens.extend(ep_mean_lens)
            ppo_rews.extend(ep_mean_rews)

            data_logger.add_rows({'ppo_perf': [pct_successful], 'ppo_lens': ep_mean_lens, 'ppo_rews': ep_mean_rews}, update_indices=['ppo_iter'])

            if debug:
                plot_performance(range(len(ppo_perf)), ppo_perf, ylabel=r'% Successful Starts', xlabel=r'PPO Iteration ($\times %d$)' % num_ppo_iters, figfile=os.path.join(FIGURES_DIR, 'ppo_pct_succ_iter_%d' % i))
                plot_performance(range(len(ppo_lens)), ppo_lens, ylabel=r'Avg. Episode Length', xlabel='PPO Iteration', figfile=os.path.join(FIGURES_DIR, 'ppo_avg_lens_iter_%d' % i))
                plot_performance(range(len(ppo_rews)), ppo_rews, ylabel=r'Avg. Episode Reward', xlabel='PPO Iteration', figfile=os.path.join(FIGURES_DIR, 'ppo_avg_rews_iter_%d' % i))

            data_logger.save_to_npy('all_starts', all_starts)
            data_logger.save_to_npy('old_starts', old_starts)
            data_logger.save_to_npy('selected_starts', starts)
            data_logger.save_to_npy('new_starts', new_starts)
            data_logger.save_to_npy('from_replay', from_replay)

            iter_count += num_ppo_iters
            print('[PPO Iter %d]: %.2f%% Successful Starts (%d / %d)' % (iter_count, pct_successful*100., successful_starts, total_unique_starts))

        # This final update is so we get the last iter_count correctly after jumping out of the while loop.
        data_logger.update_indices({"ppo_iter": iter_count})

        # Ok, we've graduated!
        old_starts.extend(starts)
        perf_metric, overall_reward = evaluate(pi_i,
                             full_start_dist, 
                             problem, 
                             debug=debug, 
                             figfile=os.path.join(FIGURES_DIR, 'eval_iter_%d' % i))

        # Format is (min_x, max_x, min_y, max_y)
        all_starts_bbox = bounding_box(all_starts)
        min_x = problem.state_space.low[X_IDX]
        max_x = problem.state_space.high[X_IDX]
        min_y = problem.state_space.low[Y_IDX]
        max_y = problem.state_space.high[Y_IDX]
        area_coverage = bounding_box_area(all_starts_bbox) / bounding_box_area((min_x, max_x, min_y, max_y))
        
        overall_perf.append(perf_metric)
        overall_area.append(area_coverage*100.)

        data_logger.add_rows({'overall_perf': [perf_metric], 'overall_area': [area_coverage],
                          'overall_reward': [overall_reward]})

        if debug:
            plot_performance(range(len(overall_perf)), overall_perf, ylabel=r'% Successful Starts', xlabel='Iteration', figfile=os.path.join(FIGURES_DIR, 'overall_perf'))
            plot_performance(range(len(overall_area)), overall_area, ylabel=r'% State Space Sampled', xlabel='Iteration', figfile=os.path.join(FIGURES_DIR, 'overall_area'))

        print('[Overall Iter %d]: perf_metric = %.2f | Area Coverage = %.2f%%' % (i, perf_metric, area_coverage*100.))

        # Incrementing our algorithm's loop counter.
        i += 1

        data_logger.save_to_file()
        pi_i.save_model(MODEL_DIR, iteration=i)

    # Done!
    if curriculum_strategy == backward_reachable:
        br_engine.stop()
        del br_engine

    return pi_i
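Example 2 hard-codes the start distribution to `uniform` and then builds the uniform weights inline anyway (`a = [1./len(starts)]*len(starts)`). A `uniform` function matching that inline computation is a one-liner; this is only a sketch, since the project's own `uniform` is not shown here.

def uniform(starts):
    """Uniform start-state distribution: equal probability mass on every
    candidate start (matches the inline weight computation above)."""
    return [1.0 / len(starts)] * len(starts)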
Example 3
File: train.py Project: xubo92/BaRC
def train_gridwise(**kwargs):
    """Train a policy with a specific curriculum 
    strategy (backward reachable, etc), 
    policy training method (trpo, ppo, etc), and start state sampling 
    strategy (uniform, weighted by value function, etc).
    """
    problem = kwargs["problem"]
    initial_policy = kwargs["initial_policy"]
    goal_state = kwargs["goal_state"]
    full_start_dist = kwargs["full_start_dist"]
    N_new, N_old = kwargs["N_new"], kwargs["N_old"]
    R_min, R_max = kwargs["R_min"], kwargs["R_max"]
    num_iters = kwargs["num_iters"]
    num_ppo_iters = kwargs["num_ppo_iters"]
    curriculum_strategy = kwargs["curriculum_strategy"]
    train_algo = kwargs["train_algo"]
    start_distribution = kwargs["start_distribution"]
    debug = kwargs["debug"]
    data_logger = kwargs["data_logger"]

    gl_ppo_rewards = []
    gl_overall_perf = []
    gl_overall_area = []  # also initialized here; otherwise it is only assigned inside the training loop

    # Keyword arguments for the curriculum strategy.
    curric_kwargs = defaultdict(lambda: None)
    curric_kwargs["debug"] = debug
    if curriculum_strategy == backward_reachable:
        br_engine = BackreachEngine()
        br_engine.reset_variables(problem, os.path.join(FIGURES_DIR, ''))
        curric_kwargs['br_engine'] = br_engine
        curric_kwargs['curr_train_iter'] = 0
        curric_kwargs['brs_sample'] = args.brs_sample
        curric_kwargs['variation'] = args.variation
        curric_kwargs['problem'] = problem

    # Not used in algorithm, only for visualization.
    all_starts = [goal_state]

    old_starts = [goal_state]
    starts = [goal_state]
    pi_i = initial_policy
    overall_perf, overall_area = list(), list()
    perf_metric = 0.0
    i = 0
    pi_i.save_model(MODEL_DIR, iteration=i)
    # Note: a list holding one list of sample-to-start distances per algorithm iteration --> [[size = 300]*num_iter]
    gl_samples_dist = []

    # Note: open a figure object outside of the loop
    plt.figure()
    plt.ion()
    while i < num_iters:
        # while perf_metric < args.finish_threshold and i < num_iters:
        print('Training Iteration %d' % i)
        data_logger.update_indices({"overall_iter": i})

        if 'curr_train_iter' in curric_kwargs:
            curric_kwargs['curr_train_iter'] = i

        # I've split apart the following call into two separate ones.
        # new_starts = curriculum_strategy(starts, N_new, problem, **curric_kwargs)
        if curriculum_strategy == backward_reachable:
            update_backward_reachable_set(starts, **curric_kwargs)
            data_logger.save_to_npy('brs_targets', starts)

        pct_successful = 0.0
        iter_count = 0
        ppo_perf, ppo_lens, ppo_rews = list(), list(), list()
        sample_dis = list()
        # Think of this as "while (haven't passed this grade)"
        while pct_successful < 0.5:
            data_logger.update_indices({"ppo_iter": iter_count})

            if curriculum_strategy == backward_reachable:
                new_starts = sample_from_backward_reachable_set(
                    N_new, **curric_kwargs)

                if debug:
                    br_engine.visualize_grids(
                        os.path.join(FIGURES_DIR, ''),
                        '_iter_%d_ppo_iter_%d' % (i, iter_count))

            if debug:
                visualize_starts(
                    new_starts,
                    problem,
                    figfile=os.path.join(
                        FIGURES_DIR,
                        'curric_starts_iter_%d_ppo_iter_%d' % (i, iter_count)))

            if args.variation == 2 and br_engine.check_membership(
                    np.array([problem.env.unwrapped.start_state])):
                print('Variation 2 condition train.py!')
                from_replay = list()
                starts = [problem.env.unwrapped.start_state]
            else:
                from_replay = sample(old_starts, size=N_old)
                starts = new_starts + from_replay

            if debug:
                visualize_starts(
                    new_starts,
                    problem,
                    old_starts=from_replay,
                    old_start_color='orange',
                    old_start_name='Replay Starts',
                    figfile=os.path.join(
                        FIGURES_DIR,
                        'replay_and_curric_starts_iter_%d_ppo_iter_%d' %
                        (i, iter_count)))

                visualize_starts(None,
                                 problem,
                                 old_starts=all_starts,
                                 old_start_color='grey',
                                 old_start_name='Old Starts',
                                 old_start_alpha=0.2,
                                 with_arrows=False,
                                 figfile=os.path.join(
                                     FIGURES_DIR,
                                     'previous_starts_iter_%d_ppo_iter_%d' %
                                     (i, iter_count)))

            rho_i = list(zip(starts, start_distribution(starts)))
            pi_i, rewards_map, ep_mean_lens, ep_mean_rews = train_step(
                rho_i, pi_i, train_algo, problem, num_ppo_iters=num_ppo_iters)

            # Note: compute the distance between each sampled start's position and the position of the first state in full_start_dist for this algorithm iteration
            for ind in range(len(starts)):
                tmp_p1 = (starts[ind][0], starts[ind][2])
                tmp_p2 = (full_start_dist[0][0][0], full_start_dist[0][0][2])
                sample_dis.append(Euclid_dis(tmp_p1, tmp_p2))

            if debug:
                if problem.env_name == 'DrivingOrigin-v0':
                    visualize_starts(starts,
                                     problem,
                                     rewards_map=rewards_map,
                                     figfile=os.path.join(
                                         FIGURES_DIR,
                                         'start_rews_iter_%d_ppo_iter_%d' %
                                         (i, iter_count)))

                visualize_rollouts(
                    [sample(starts)],
                    pi_i,
                    problem,
                    figfile=os.path.join(
                        FIGURES_DIR,
                        'rollouts_iter_%d_ppo_iter_%d' % (i, iter_count)))

            data_logger.save_to_npy('curr_starts', starts)

            all_starts.extend(starts)
            total_unique_starts = len(dedupe_list_of_np_arrays(starts))
            starts = select(starts, rewards_map, R_min, R_max, problem)
            successful_starts = len(starts)
            pct_successful = float(successful_starts) / total_unique_starts

            ppo_perf.append(pct_successful * 100.)
            ppo_lens.extend(ep_mean_lens)
            ppo_rews.extend(ep_mean_rews)

            data_logger.add_rows(
                {
                    'ppo_perf': [pct_successful],
                    'ppo_lens': ep_mean_lens,
                    'ppo_rews': ep_mean_rews
                },
                update_indices=['ppo_iter'])

            if debug:
                plot_performance(
                    range(len(ppo_perf)),
                    ppo_perf,
                    ylabel=r'% Successful Starts',
                    xlabel=r'PPO Iteration ($\times %d$)' % num_ppo_iters,
                    figfile=os.path.join(FIGURES_DIR,
                                         'ppo_pct_succ_iter_%d' % i))
                plot_performance(range(len(ppo_lens)),
                                 ppo_lens,
                                 ylabel=r'Avg. Episode Length',
                                 xlabel='PPO Iteration',
                                 figfile=os.path.join(
                                     FIGURES_DIR, 'ppo_avg_lens_iter_%d' % i))
                plot_performance(range(len(ppo_rews)),
                                 ppo_rews,
                                 ylabel=r'Avg. Episode Reward',
                                 xlabel='PPO Iteration',
                                 figfile=os.path.join(
                                     FIGURES_DIR, 'ppo_avg_rews_iter_%d' % i))

            # NOTE: modified to save global data
            gl_ppo_rewards.extend(ppo_rews)

            data_logger.save_to_npy('all_starts', all_starts)
            data_logger.save_to_npy('old_starts', old_starts)
            data_logger.save_to_npy('selected_starts', starts)
            data_logger.save_to_npy('new_starts', new_starts)
            data_logger.save_to_npy('from_replay', from_replay)

            iter_count += num_ppo_iters
            print('[PPO Iter %d]: %.2f%% Successful Starts (%d / %d)' %
                  (iter_count, pct_successful * 100., successful_starts,
                   total_unique_starts))

        # This final update is so we get the last iter_count correctly after jumping out of the while loop.
        data_logger.update_indices({"ppo_iter": iter_count})

        # Ok, we've graduated!
        old_starts.extend(starts)
        perf_metric, rollout_count = evaluate(pi_i,
                                              full_start_dist,
                                              problem,
                                              debug=debug,
                                              figfile=os.path.join(
                                                  FIGURES_DIR,
                                                  'eval_iter_%d' % i))

        # Note: the trained policy for each algorithm iteration could also be pickled here.
        # pkl.dump(pi_i, open(os.path.join(POLICY_DIR, 'policy_iter_%d' % i), 'wb'))
        # Format is (min_x, max_x, min_y, max_y)
        all_starts_bbox = bounding_box(all_starts)
        min_x = problem.state_space.low[X_IDX]
        max_x = problem.state_space.high[X_IDX]
        min_y = problem.state_space.low[Y_IDX]
        max_y = problem.state_space.high[Y_IDX]
        area_coverage = bounding_box_area(all_starts_bbox) / bounding_box_area(
            (min_x, max_x, min_y, max_y))

        overall_perf.append(perf_metric)
        overall_area.append(area_coverage * 100.)

        data_logger.add_rows({
            'overall_perf': [perf_metric],
            'overall_area': [area_coverage]
        })

        # NOTE: modified to save global data
        gl_overall_perf = overall_perf
        gl_overall_area = overall_area
        gl_samples_dist.append(sample_dis)

        # NOTE: show the sample-distance distribution stage by stage
        if i > 0 and not (i + 1) % 5:
            plt.cla()
            show_dis = []

            for ind in range(5 * (i // 5), len(gl_samples_dist)):
                show_dis.extend(gl_samples_dist[ind])

            sns_plot = sns.distplot(show_dis, rug=True)
            sns_plot.figure.savefig(FIGURES_DIR + '/dist_iter' + str(i) +
                                    '.png')
            print("saving the distance plot")

            # sns.distplot(show_dis, rug=True)
            #print("showing the distance distribution")
            plt.pause(0.1)

        if debug:
            plot_performance(range(len(overall_perf)),
                             overall_perf,
                             ylabel=r'% Successful Starts',
                             xlabel='Iteration',
                             figfile=os.path.join(FIGURES_DIR, 'overall_perf'))
            plot_performance(range(len(overall_area)),
                             overall_area,
                             ylabel=r'% State Space Sampled',
                             xlabel='Iteration',
                             figfile=os.path.join(FIGURES_DIR, 'overall_area'))

        print(
            '[Overall Iter %d]: perf_metric = %.2f | Area Coverage = %.2f%%' %
            (i, perf_metric, area_coverage * 100.))

        # Incrementing our algorithm's loop counter.
        i += 1

        data_logger.save_to_file()
        pi_i.save_model(MODEL_DIR, iteration=i)

    # Done!
    # NOTE: global plotting performance
    plot_performance(range(len(gl_ppo_rewards)),
                     gl_ppo_rewards,
                     ylabel=r'Avg Reward per training step',
                     xlabel='Iteration',
                     figfile=os.path.join(FIGURES_DIR, 'global_ppo_rewards'),
                     pickle=True)
    plot_performance(range(len(gl_overall_perf)),
                     gl_overall_perf,
                     ylabel=r'Avg Eval Reward from start state',
                     xlabel='Iteration',
                     figfile=os.path.join(FIGURES_DIR, 'global_overall_perf'),
                     pickle=True)
    plot_performance(range(len(gl_overall_area)),
                     gl_overall_area,
                     ylabel=r'State Space Sampled',
                     xlabel='Iteration',
                     figfile=os.path.join(FIGURES_DIR, 'gl_overall_area'),
                     pickle=True)

    if curriculum_strategy == backward_reachable:
        br_engine.stop()
        del br_engine

    plt.ioff()

    return pi_i
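All three examples compute the same area-coverage metric from a bounding box over the visited starts, and Example 3 additionally records the Euclidean distance between each sampled start and the first state of the full start distribution. Below are plausible sketches of those helpers, inferred from the call sites; the X_IDX/Y_IDX values are assumptions based on how Example 3 indexes the state.

import math

# Assumed state layout: x at index 0, y at index 2 (matching starts[ind][0] / starts[ind][2] above).
X_IDX, Y_IDX = 0, 2


def bounding_box(states):
    """Axis-aligned bounding box of the (x, y) components of a list of states,
    returned as (min_x, max_x, min_y, max_y)."""
    xs = [float(s[X_IDX]) for s in states]
    ys = [float(s[Y_IDX]) for s in states]
    return (min(xs), max(xs), min(ys), max(ys))


def bounding_box_area(bbox):
    """Area of a (min_x, max_x, min_y, max_y) box."""
    min_x, max_x, min_y, max_y = bbox
    return (max_x - min_x) * (max_y - min_y)


def Euclid_dis(p1, p2):
    """Euclidean distance between two (x, y) points."""
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])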