def train(base_class_name, demo_numbers, program_generation_step_size, num_programs, num_dts, max_num_particles):
    """
    Learn a policy from demonstrations.

    Builds the program set, runs the programs on the demonstrations to get a
    classification dataset, learns PLPs (with priors) from that dataset, scores
    every PLP by prior + likelihood on the demonstrations, and wraps the
    selected top particles in a PLPPolicy.

    Parameters
    ----------
    base_class_name : str
    demo_numbers : [ int ]
        Passed through to the demonstration loaders.
    program_generation_step_size : int
        Passed through to learn_plps.
    num_programs : int
    num_dts : int
        Passed through to learn_plps.
    max_num_particles : int
        Passed through to select_particles.

    Returns
    -------
    policy : PLPPolicy
        Policy over the selected particles with normalized probabilities, or a
        single "False" program with probability 1.0 when no particle is selected.
    """
    programs, program_prior_log_probs = get_program_set(base_class_name, num_programs)

    X, y = run_all_programs_on_demonstrations(base_class_name, num_programs, demo_numbers)
    plps, plp_priors = learn_plps(X, y, programs, program_prior_log_probs, num_dts=num_dts,
        program_generation_step_size=program_generation_step_size)

    demonstrations = get_demonstrations(base_class_name, demo_numbers=demo_numbers)
    likelihoods = compute_likelihood_plps(plps, demonstrations)

    particles = []
    particle_log_probs = []

    # Unnormalized log posterior of each candidate: log prior + log likelihood.
    for plp, prior, likelihood in zip(plps, plp_priors, likelihoods):
        particles.append(plp)
        particle_log_probs.append(prior + likelihood)

    print("\nDone!")
    # np.argmax raises ValueError on an empty sequence, so only report the MAP
    # program when at least one candidate exists; otherwise fall through to the
    # fallback policy below instead of crashing.
    if particle_log_probs:
        map_idx = int(np.argmax(particle_log_probs))
        print("MAP program ({}):".format(particle_log_probs[map_idx]))
        print(particles[map_idx])

    top_particles, top_particle_log_probs = select_particles(particles, particle_log_probs, max_num_particles)
    if len(top_particle_log_probs) > 0:
        # Normalize in log space before exponentiating so the probabilities sum to 1.
        top_particle_log_probs = np.array(top_particle_log_probs) - logsumexp(top_particle_log_probs)
        top_particle_probs = np.exp(top_particle_log_probs)
        print("top_particle_probs:", top_particle_probs)
        policy = PLPPolicy(top_particles, top_particle_probs)
    else:
        print("no nontrivial particles found")
        policy = PLPPolicy([StateActionProgram("False")], [1.0])

    return policy
def run_all_programs_on_single_demonstration(base_class_name, num_programs, demo_number, program_interval=1000):
    """
    Run all programs up to some iteration on one demonstration.

    Expensive in general because programs can be slow and numerous, so caching can be very helpful.

    Parallelization is designed to save time in the regime of many programs.

    Care is taken to avoid memory issues, which are a serious problem when num_programs exceeds 50,000.

    Returns classification dataset X, y.

    Parameters
    ----------
    base_class_name : str
    num_programs : int
    demo_number : int
    program_interval : int
        This interval splits up program batches for parallelization.

    Returns
    -------
    X : csr_matrix
        X.shape = (num_demo_items, num_programs), dtype bool.
    y : [ int ]
        len(y) = num_demo_items; 1 for each positive example followed by
        0 for each negative example.
    """

    print("Running all programs on {}, {}".format(base_class_name, demo_number))

    programs, _ = get_program_set(base_class_name, num_programs)

    demonstration = get_demonstrations(base_class_name, demo_numbers=(demo_number,))
    positive_examples, negative_examples = extract_examples_from_demonstration(demonstration)
    y = [1] * len(positive_examples) + [0] * len(negative_examples)

    num_data = len(y)
    # Use the actual size of the returned program set from here on.
    num_programs = len(programs)

    # lil_matrix supports efficient incremental slice assignment; converted to CSR at the end.
    X = lil_matrix((num_data, num_programs), dtype=bool)

    # Loop-invariant: row order matches y (positives first, then negatives).
    fn_inputs = positive_examples + negative_examples
    num_workers = multiprocessing.cpu_count()

    # This loop avoids memory issues
    for i in range(0, num_programs, program_interval):
        end = min(i + program_interval, num_programs)
        print('Iteration {} of {}'.format(i, num_programs), end='\r')

        fn = partial(apply_programs, programs[i:end])

        # A fresh pool per batch, as before (presumably so worker memory is
        # released between batches). The context manager guarantees the
        # workers are shut down even if pool.map raises.
        with multiprocessing.Pool(num_workers) as pool:
            results = pool.map(fn, fn_inputs)

        # Each result fills one example's row for this batch of program columns.
        for X_idx, x in enumerate(results):
            X[X_idx, i:end] = x

    X = X.tocsr()

    # Move past the '\r' progress line.
    print()
    return X, y