コード例 #1
0
    def _test_get_thread_id_not_parallel(self):
        """get_thread_id() must report 0 whenever no parallel region runs.

        Three situations are checked:
        - plain interpreted Python;
        - an njit function compiled with parallel=False that loops with prange;
        - an njit function compiled with parallel=True that loops with range.
        """
        host_tid = get_thread_id()
        n_elems = 8

        @njit(parallel=False)
        def par_false(size):
            # parallel=False: the test expects this prange loop to run on
            # the calling thread only.
            outer = get_thread_id()
            per_iter = np.ones(size)
            for k in prange(size):
                per_iter[k] = get_thread_id()
            return outer, per_iter

        @njit(parallel=True)
        def par_true(size):
            # parallel=True but an ordinary range loop: the test expects
            # every iteration to run on the calling thread.
            outer = get_thread_id()
            per_iter = np.ones(size)
            for k in range(size):
                per_iter[k] = get_thread_id()
            return outer, per_iter

        self.assertEqual(host_tid, 0)
        for kernel in (par_false, par_true):
            outer_tid, loop_tids = kernel(n_elems)
            self.assertEqual(outer_tid, 0)
            np.testing.assert_equal(loop_tids, 0)
コード例 #2
0
 def test_func():
     """Fill a large buffer with the id of the thread that wrote each slot.

     ``mask`` is read from the enclosing scope and passed to
     set_num_threads() first. Returns a pair: the number of distinct thread
     ids found in the buffer, and get_num_threads() after the loop.
     """
     set_num_threads(mask)
     n_slots = 5_000_000
     slot_owner = np.empty((n_slots, ))
     for slot in prange(n_slots):
         slot_owner[slot] = get_thread_id()
     distinct_threads = len(np.unique(slot_owner))
     return distinct_threads, get_num_threads()
コード例 #3
0
 def work(local_nt):
     """Count prange iterations and record which thread executed each one.

     ``local_nt`` is handed to set_num_threads() before the parallel loop.
     Returns the iteration count (accumulated as a prange reduction) and the
     sorted unique thread ids stored per iteration. ``BIG`` (the loop length)
     comes from the enclosing scope.
     """
     thread_of = np.zeros(BIG)
     total = 0
     set_num_threads(local_nt)
     for k in prange(BIG):
         total += 1
         thread_of[k] = get_thread_id()
     return total, np.unique(thread_of)
コード例 #4
0
ファイル: base.py プロジェクト: tardis-sn/tardis
def montecarlo_main_loop(
    packet_collection,
    numba_model,
    numba_plasma,
    estimators,
    spectrum_frequency,
    number_of_vpackets,
    packet_seeds,
    virtual_packet_logging,
    iteration,
    show_progress_bars,
    no_of_packets,
    total_iterations,
):
    """
    This is the main loop of the MonteCarlo routine that generates packets
    and sends them through the ejecta.

    Packets are propagated in a ``prange`` loop. Each worker thread
    accumulates into its own zero-initialised ``Estimators`` instance. These
    are added onto ``estimators`` once the parallel loop has finished. The
    virtual-packet spectrum is also binned after the parallel loop. As a
    result, no two threads ever update the same shared array element.

    Parameters
    ----------
    packet_collection : PacketCollection
        Source of the initial packet properties. Its ``packets_output_nu``
        and ``packets_output_energy`` arrays are overwritten with the results.
    numba_model : NumbaModel
    numba_plasma : NumbaPlasma
    estimators : NumbaEstimators
        Incremented in place with the contributions of all packets.
    spectrum_frequency : astropy.units.Quantity
        Frequency bins. The bin width is taken from the first two entries,
        so the grid is assumed to be uniform.
    number_of_vpackets : int
        VPackets released per interaction
    packet_seeds : numpy.array
        Random seed for each packet (indexed by packet number).
    virtual_packet_logging : bool
        Option to enable virtual packet logging.
    iteration
        Current iteration; only forwarded to ``update_packet_pbar``.
    show_progress_bars : bool
        Whether to update the packet progress bar.
    no_of_packets
        Only forwarded to ``update_packet_pbar``.
    total_iterations
        Only forwarded to ``update_packet_pbar``.

    Returns
    -------
    tuple
        The tuple contains, in order:

        - ``v_packets_energy_hist``
        - ``last_interaction_types``
        - ``last_interaction_in_nus``
        - ``last_line_interaction_in_ids``
        - ``last_line_interaction_out_ids``
        - ``virt_packet_nus``
        - ``virt_packet_energies``
        - ``virt_packet_initial_mus``
        - ``virt_packet_initial_rs``
        - ``virt_packet_last_interaction_in_nu``
        - ``virt_packet_last_interaction_type``
        - ``virt_packet_last_line_interaction_in_id``
        - ``virt_packet_last_line_interaction_out_id``
        - ``rpacket_trackers``

        The ``virt_packet_*`` lists stay empty unless
        ``virtual_packet_logging`` is set.
    """
    output_nus = np.empty_like(packet_collection.packets_input_nu)
    last_interaction_types = (
        np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1)
    output_energies = np.empty_like(packet_collection.packets_output_nu)

    last_interaction_in_nus = np.empty_like(
        packet_collection.packets_output_nu)
    last_line_interaction_in_ids = (
        np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1)
    last_line_interaction_out_ids = (
        np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1)

    v_packets_energy_hist = np.zeros_like(spectrum_frequency)
    # Bin width from the first two entries: assumes a uniform frequency grid.
    delta_nu = spectrum_frequency[1] - spectrum_frequency[0]

    # Pre-allocate one vpacket collection and one r-packet tracker per packet
    # so that parallel iterations never share them.
    vpacket_collections = List()
    rpacket_trackers = List()
    for i in range(len(output_nus)):
        vpacket_collections.append(
            VPacketCollection(
                i,
                spectrum_frequency,
                montecarlo_configuration.v_packet_spawn_start_frequency,
                montecarlo_configuration.v_packet_spawn_end_frequency,
                number_of_vpackets,
                montecarlo_configuration.temporary_v_packet_bins,
            ))
        rpacket_trackers.append(RPacketTracker())

    main_thread_id = get_thread_id()
    n_threads = get_num_threads()

    # One private accumulator per worker thread, so packets running in
    # parallel never write to the same estimator arrays. They start at zero
    # because each one is added onto ``estimators`` after the loop. Copying
    # the incoming values instead would count them once per thread.
    # NOTE(review): indexing by get_thread_id() assumes ids inside the prange
    # loop lie in [0, n_threads).
    estimator_list = List()
    for _ in range(n_threads):
        estimator_list.append(
            Estimators(
                np.zeros_like(estimators.j_estimator),
                np.zeros_like(estimators.nu_bar_estimator),
                np.zeros_like(estimators.j_blue_estimator),
                np.zeros_like(estimators.Edotlu_estimator),
                np.zeros_like(estimators.photo_ion_estimator),
                np.zeros_like(estimators.stim_recomb_estimator),
                np.zeros_like(estimators.bf_heating_estimator),
                np.zeros_like(estimators.stim_recomb_cooling_estimator),
                np.zeros_like(estimators.photo_ion_estimator_statistics),
            ))

    # Lists for vpacket logging (filled only if virtual_packet_logging).
    virt_packet_nus = []
    virt_packet_energies = []
    virt_packet_initial_mus = []
    virt_packet_initial_rs = []
    virt_packet_last_interaction_in_nu = []
    virt_packet_last_interaction_type = []
    virt_packet_last_line_interaction_in_id = []
    virt_packet_last_line_interaction_out_id = []

    for i in prange(len(output_nus)):
        tid = get_thread_id()
        # Only the thread whose id matches the caller's drives the progress
        # bar. It advances by n_threads per packet as a rough estimate of
        # overall progress.
        if show_progress_bars and tid == main_thread_id:
            with objmode:
                update_amount = 1 * n_threads
                update_packet_pbar(
                    update_amount,
                    current_iteration=iteration,
                    no_of_packets=no_of_packets,
                    total_iterations=total_iterations,
                )

        seed = packet_seeds[i]
        np.random.seed(seed)
        r_packet = RPacket(
            numba_model.r_inner[0],
            packet_collection.packets_input_mu[i],
            packet_collection.packets_input_nu[i],
            packet_collection.packets_input_energy[i],
            seed,
            i,
        )
        local_estimators = estimator_list[tid]
        vpacket_collection = vpacket_collections[i]
        rpacket_tracker = rpacket_trackers[i]

        # Pass this thread's private estimators. Passing the shared
        # ``estimators`` here would let threads race on the same arrays.
        single_packet_loop(
            r_packet,
            numba_model,
            numba_plasma,
            local_estimators,
            vpacket_collection,
            rpacket_tracker,
        )

        output_nus[i] = r_packet.nu
        last_interaction_in_nus[i] = r_packet.last_interaction_in_nu
        last_line_interaction_in_ids[i] = r_packet.last_line_interaction_in_id
        last_line_interaction_out_ids[
            i] = r_packet.last_line_interaction_out_id

        # Reabsorbed packets are flagged by a negative output energy.
        # NOTE(review): a packet ending in any other status keeps an unset
        # output energy (np.empty_like) and interaction type -1.
        if r_packet.status == PacketStatus.REABSORBED:
            output_energies[i] = -r_packet.energy
            last_interaction_types[i] = r_packet.last_interaction_type
        elif r_packet.status == PacketStatus.EMITTED:
            output_energies[i] = r_packet.energy
            last_interaction_types[i] = r_packet.last_interaction_type

    # Merge the per-thread estimators now that all threads are done.
    for sub_estimator in estimator_list:
        estimators.increment(sub_estimator)

    # Bin the virtual packets serially. Inside the prange loop, different
    # threads could add into the same histogram bin at the same time (a race
    # on ``v_packets_energy_hist[idx]``).
    for vpacket_collection in vpacket_collections:
        n_vpackets = vpacket_collection.idx
        vpackets_nu = vpacket_collection.nus[:n_vpackets]
        vpackets_energy = vpacket_collection.energies[:n_vpackets]

        v_packets_idx = np.floor(
            (vpackets_nu - spectrum_frequency[0]) / delta_nu).astype(np.int64)

        for j, idx in enumerate(v_packets_idx):
            # Skip virtual packets outside the spectrum's frequency range.
            if (vpackets_nu[j] < spectrum_frequency[0]) or (
                    vpackets_nu[j] > spectrum_frequency[-1]):
                continue
            v_packets_energy_hist[idx] += vpackets_energy[j]

        if virtual_packet_logging:
            virt_packet_nus.append(np.ascontiguousarray(vpackets_nu))
            virt_packet_energies.append(np.ascontiguousarray(vpackets_energy))
            virt_packet_initial_mus.append(
                np.ascontiguousarray(
                    vpacket_collection.initial_mus[:n_vpackets]))
            virt_packet_initial_rs.append(
                np.ascontiguousarray(
                    vpacket_collection.initial_rs[:n_vpackets]))
            virt_packet_last_interaction_in_nu.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_in_nu[:n_vpackets]))
            virt_packet_last_interaction_type.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_type[:n_vpackets]))
            virt_packet_last_line_interaction_in_id.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_in_id[:n_vpackets]))
            virt_packet_last_line_interaction_out_id.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_out_id[:n_vpackets]))

    if montecarlo_configuration.RPACKET_TRACKING:
        for rpacket_tracker in rpacket_trackers:
            rpacket_tracker.finalize_array()

    packet_collection.packets_output_energy[:] = output_energies[:]
    packet_collection.packets_output_nu[:] = output_nus[:]
    return (
        v_packets_energy_hist,
        last_interaction_types,
        last_interaction_in_nus,
        last_line_interaction_in_ids,
        last_line_interaction_out_ids,
        virt_packet_nus,
        virt_packet_energies,
        virt_packet_initial_mus,
        virt_packet_initial_rs,
        virt_packet_last_interaction_in_nu,
        virt_packet_last_interaction_type,
        virt_packet_last_line_interaction_in_id,
        virt_packet_last_line_interaction_out_id,
        rpacket_trackers,
    )
コード例 #5
0
 def par_true(size):
     """Return get_thread_id() before and during an ordinary range loop.

     The loop uses range rather than prange. Each slot of the returned float
     array holds the id reported while that iteration ran.
     """
     outer = get_thread_id()
     per_iter = np.ones(size)
     for k in range(size):
         per_iter[k] = get_thread_id()
     return outer, per_iter
コード例 #6
0
 def test_gufunc(x, out):
     """Overwrite ``x`` with the current thread id and report the thread count.

     set_num_threads(mask) is called first (``mask`` comes from the enclosing
     scope). Afterwards out[0] receives get_num_threads().
     """
     set_num_threads(mask)
     current_tid = get_thread_id()
     x[:] = current_tid
     out[0] = get_num_threads()