def set_periodic_checkpoint(sim, period):
    """
    Set up periodic checkpoints of the simulation

    The checkpoints are saved in openPMD format, in the directory
    `./checkpoints`, with one subdirectory per process.
    All the field and particle information of each processor is saved.

    NB: Checkpoints are registered among the list of diagnostics
    `diags` of the Simulation object `sim`.

    Parameters
    ----------
    sim: a Simulation object
       The simulation that is to be saved in checkpoints

    period: integer
       The number of PIC iterations between two consecutive checkpoints.
    """
    # Only processor 0 creates a directory where checkpoints will be stored
    # Make sure that all processors wait until this directory is created
    # (Use the global MPI communicator instead of the `BoundaryCommunicator`
    # so that this still works in the case `use_all_ranks=False`)
    if comm.rank == 0:
        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')
    comm.Barrier()

    # Choose the name of the directory: one directory per processor
    write_dir = 'checkpoints/proc%d/' % comm.rank

    # Register a periodic FieldDiagnostic in the diagnostics of the simulation
    sim.diags.append(FieldDiagnostic(period, sim.fld, write_dir=write_dir))

    # Register a periodic ParticleDiagnostic, which contains all
    # the particles which are present in the simulation
    particle_dict = {}
    for i in range(len(sim.ptcl)):
        particle_dict['species %d' % i] = sim.ptcl[i]
    sim.diags.append(
        ParticleDiagnostic(period, particle_dict, write_dir=write_dir))
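A minimal usage sketch (hypothetical: assumes `sim` is an already-configured
fbpic Simulation object and that this module's imports are in place):

set_periodic_checkpoint(sim, 500)  # checkpoint every 500 PIC iterations
sim.step(2000)  # checkpoints appear under ./checkpoints/proc<rank>/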
Example #2
    def test_synchronize(self):
        """ Make sure that we can make the overlap spaces accurate. """
        for case in self.cases:
            space.initialize_space(case['shape'])
            data = np.random.randn(*case['shape']).astype(case['dtype'])
            cpu_data = np.empty_like(data)
            comm.Allreduce(data, cpu_data)
            g = Grid(case['dtype'])
            self.assertRaises(TypeError, g.synchronize)  # No overlap.
            # Test with-overlap cases as well.
            for k in range(1, 4):
                g = Grid(case['dtype'], x_overlap=k)

                # Overwrite entire grid
                data = np.random.randn(*case['shape']).astype(case['dtype'])
                cpu_data = np.empty_like(data)
                comm.Allreduce(data, cpu_data)
                cpu_raw_bad = get_cpu_raw(cpu_data, k)
                cpu_raw_bad[:k, :, :] += 1  # Mess up padding areas.
                cpu_raw_bad[-k:, :, :] += 1
                drv.memcpy_htod(g.data.ptr, cpu_raw_bad)

                # Prove that the data is not synchronized at this time.
                cpu_raw = get_cpu_raw(cpu_data, k)
                xx = case['shape'][0]
                gd = g._get_raw()
                self.assertTrue((gd[:k, :, :] != cpu_raw[:k, :, :]).all())
                self.assertTrue((gd[-k:, :, :] != cpu_raw[-k:, :, :]).all())

                g.synchronize()  # Synchronize the overlapping data.

                # Make sure that the overlap data is accurate.
                gd = g._get_raw()
                self.assertTrue((gd[:k, :, :] == cpu_raw[:k, :, :]).all())
                self.assertTrue((gd[-k:, :, :] == cpu_raw[-k:, :, :]).all())

                comm.Barrier()  # Wait for other mpi nodes to finish.
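The `Allreduce` calls above are only a device for giving every rank an
identical reference array: each rank contributes random data and receives the
element-wise sum. A standalone sketch of the idiom (assuming mpi4py):

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
local = np.random.randn(8)
shared = np.empty_like(local)
comm.Allreduce(local, shared)  # default op=SUM; all ranks now hold the same array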
Example #3
def flatten(x):
    # Recursively flatten nested iterables into a flat list.
    # (`collections.Iterable` was removed in Python 3.10; use
    # `collections.abc.Iterable` and `import collections.abc`.)
    if isinstance(x, collections.abc.Iterable):
        return [a for i in x for a in flatten(i)]
    else:
        return [x]
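A quick sanity check of the helper above (hypothetical call, not from the source):

assert flatten([1, [2, [3, 4]], 5]) == [1, 2, 3, 4, 5]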


rank = CW.Get_rank()
size = CW.Get_size()

two_col_width = 7.20472  #inches
single_col_width = 3.50394  #inches
page_height = 10.62472
font_size = 10

sys.stdout.flush()
CW.Barrier()

pickle_file = sys.argv[1]
true_birth_con_pickle = sys.argv[2]
plot_gradient = False
read_pickle = sys.argv[3] == 'True'  # bool() of any non-empty string is True
baseline_yr = float(sys.argv[4])
#plot_key = sys.argv[2]
plot_keys = ['System_semimajor']  #, 'System_ecc', 'System_energies']

sys.stdout.flush()
CW.Barrier()

#check_sub_sys = [[21, 22], [23,97], [23,24], [23,97], [72,76], [78,235]]
jump_time_start = []
jump_time_end = []
Example #4
    def __init__(self, array_or_dtype, x_overlap=0):
        """ Create a spatial grid on the GPU(s).

        Input variables
        array_or_dtype -- can either be a numpy array of the same shape as
            the global space, or a numpy dtype. If a valid array is passed, 
            it will be loaded on to the GPU. If a dtype is passed, then
            an array of zeros, of that dtype will be loaded onto the GPU.

        Optional variables
        x_overlap -- the number of adjacent cells in either the negative or
            positive x-direction that need to simultaneously be accessed along
            with the current cell. Must be a non-negative integer. Default
            value is 0.

        """

        shape = get_space_info()['shape']  # Get the shape of the space.
        xr = get_space_info()['x_range']  # Get the local x_range.
        all_x_ranges = get_space_info()[
            'all_x_ranges']  # Get all of the x_ranges.
        local_shape = (xr[1] - xr[0], shape[1], shape[2])

        self._set_gce_type('grid')  # Set the gce type to grid.

        # Make sure overlap option is valid.
        if type(x_overlap) is not int:
            raise TypeError('x_overlap must be an integer.')
        elif x_overlap < 0:
            raise TypeError('x_overlap must be a non-negative integer.')

        if comm.rank == 0:
            # Process the array_or_dtype input variable.
            if type(array_or_dtype) is np.ndarray:  # Input is an array.
                array = array_or_dtype

                # Make sure the array is of the correct shape.
                if array.shape != shape:
                    raise TypeError(
                        'Shape of array does not match shape of space.')

                # Make sure the array is of a valid datatype.
                self._get_dtype(array.dtype.type)

            elif type(array_or_dtype) is type:  # Input is a datatype.
                self._get_dtype(array_or_dtype)  # Validate the dtype.
                array = np.zeros(shape,
                                 dtype=self.dtype)  # Make a zeros array.

            else:  # Invalid input.
                raise TypeError(
                    'Input variable must be a numpy array or dtype')

            # Prepare array to be scattered.
            array = [array[r[0]:r[1], :, :] for r in all_x_ranges]

        else:
            array = None

        array = comm.scatter(array)
        self._get_dtype(array.dtype.type)

        #         # Narrow down the array to local x_range.
        #         array = array[xr[0]:xr[1],:,:]

        # Add padding to array, if needed.
        self._xlap = x_overlap
        if self._xlap != 0:
            padding = np.empty((self._xlap, ) + shape[1:3], dtype=array.dtype)
            array = np.concatenate((padding, array, padding), axis=0)

        self.to_gpu(array)  # Load onto device.

        # Determine information needed for synchronization.
        if self._xlap != 0:
            # Calculates the pointer to the x offset in a grid.
            ptr_dx = lambda x_pos: self.data.ptr + self.data.dtype.itemsize * \
                                        x_pos * shape[1] * shape[2]

            # Pointers to different sections of the grid that are relevant
            # for synchronization.
            self._sync_ptrs = { 'forw_src': ptr_dx(xr[1]-xr[0]), \
                                'back_dest': ptr_dx(0), \
                                'back_src': ptr_dx(self._xlap), \
                                'forw_dest': ptr_dx(xr[1]-xr[0] + self._xlap)}

            # Buffers used during synchronization.
            self._sync_buffers = [drv.pagelocked_empty( \
                                    (self._xlap, shape[1], shape[2]), \
                                    self.dtype) for k in range(4)]

            # Streams used during synchronization.
            self._sync_streams = [drv.Stream() for k in range(4)]

            # Used to identify neighboring MPI nodes with whom to synchronize.
            self._sync_adj = get_space_info()['mpi_adj']

            # Offset in bytes to the true start of the grid.
            # This is used to "hide" overlap areas from the kernel.
            self._xlap_offset = self.data.dtype.itemsize * \
                                self._xlap * shape[1] * shape[2]

            self.synchronize()  # Synchronize the grid.
            comm.Barrier(
            )  # Wait for all grids to synchronize before proceeding.
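A hypothetical instantiation, reusing names from the test snippet above
(`space.initialize_space`, `Grid`); a sketch, not part of the original source:

# Initialize the global space on every rank, then build the local grid.
space.initialize_space((128, 64, 64))
g = Grid(np.float32, x_overlap=1)  # zeros-filled grid, 1-cell x padding
g.synchronize()  # refresh the overlap planes (also done at construction)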
Example #5
def main():

    rank = CW.Get_rank()
    size = CW.Get_size()

    args = parse_inputs()

    n_orb = int(args.no_orbits)
    n_systems = int(args.no_systems)
    q_min = 0.05
    my_orb = bo.random_orbits(n_orb=n_orb)
    US_group_vel = 10.
    UCL_group_vel = 4.
    #Madsen, 2002 gives the STANDARD ERROR of the US and UCL velocities to be 1.3 and 1.9km/s
    US_group_std = 1.3 * args.group_velocity_sigma  #From Preibisch et al., 2008
    UCL_group_std = 1.3 * args.group_velocity_sigma
    standard_std = {'F': 1.08, 'G': 0.63, 'K': 1.43, 'M': 2.27}  # 2.0
    astrophysical_std = args.astrophysical_std  #Astrophysical radial velocity uncertainty

    Object = []
    Region = []
    IR_excess = []
    Temp_sptype = []
    Pref_template = []
    Obs_info = []
    all_bayes = [[], []]

    RV_standard_info = {}

    sys.stdout.flush()
    CW.Barrier()

    #Read in RV standard list
    header = 0
    with open('/home/100/rlk100/RV_standard_list.csv', 'r') as f:
        reader = csv.reader(f)
        for row in reader:
            if header != 0:
                RV_standard_info[row[0]] = (float(row[5]), float(row[6]),
                                            float(row[7]))
            else:
                header = 1
        f.close()

    sys.stdout.flush()
    CW.Barrier()

    print("Reading in current spreadsheet", args.input_file)
    header = 0
    reshape_len = -1
    with open(args.input_file, 'r') as f:
        reader = csv.reader(f)
        for row in reader:
            if header != 0:
                if 'U4' in row[0]:
                    row[0] = 'UCAC4' + row[0].split('U4')[-1]
                Object.append(row[0])
                Region.append(row[1])
                IR_excess.append(row[5])
                Pref_template.append(row[15])  #row[18])
                Temp_sptype.append(row[16])  #row[19])
                if len(row) > 17:
                    Obs = np.array(row[17:])
                    Obs = np.delete(Obs, np.where(Obs == ''))
                    if reshape_len == -1:
                        for ob in Obs:
                            reshape_len = reshape_len + 1
                            if '/' in ob and ob != Obs[0]:
                                break
                    #if len(Obs) > 5:
                    #    Obs = np.reshape(Obs, (len(Obs)/5, 5))
                    Obs = np.reshape(Obs,
                                     (len(Obs) // reshape_len, reshape_len))
                    for ind_obs in Obs:
                        if '/' in ind_obs[0]:
                            new_format = '20' + ind_obs[0].split('/')[
                                -1] + '-' + ind_obs[0].split('/')[-2] + '-' + (
                                    "%02d" % int(ind_obs[0].split('/')[-3]))
                            ind_obs[0] = new_format
                else:
                    Obs = np.array([])
                Obs_info.append(Obs)
            if header == 0:
                header = 1
        f.close()
    del header

    sys.stdout.flush()
    CW.Barrier()

    Obj_bayes = np.nan * np.zeros(len(Object))

    #Read in currently calculated Bayes Factors:
    if args.restart_calc != 'False':
        print("Reading in calulated Bayes factors")
        header = 0
        with open(args.bayes_file, 'r') as f:
            reader = csv.reader(f)
            for row in reader:
                if header != 0:
                    ind = Object.index(row[0])
                    Obj_bayes[ind] = float(row[2])
                    if row[1] == 'US':
                        all_bayes[0].append(float(row[2]))
                    else:
                        all_bayes[1].append(float(row[2]))
                    del ind
                else:
                    header = 1
            f.close()
        del header

    sys.stdout.flush()
    CW.Barrier()

    if args.restart_calc != 'False' and rank == 0:
        print("Creating new bayes file")
        f = open(args.bayes_file, 'w')
        f.write('Object,Region,Bayes_factor\n')
        f.close()

    sys.stdout.flush()
    CW.Barrier()

    inds = list(range(len(Object)))
    skip_inds = np.where(np.array(IR_excess) == 'NN')[0]
    for skit in skip_inds:
        inds.remove(skit)
    skip_inds = np.where(np.array(Pref_template) == '')[0]
    for skit in skip_inds:
        inds.remove(skit)
    del skip_inds
    del IR_excess

    rit = 0
    sys.stdout.flush()
    CW.Barrier()
    for obj in inds:
        Pref_template_name = Pref_template[obj].split('_')[0]
        if np.isnan(Obj_bayes[obj]) and rank == rit:
            print("Doing object:", Object[obj], "on rank:", rank)
            likelihoods = []
            single_likelihoods = []

            #Produces masses within +/- 10% of the mass of the template.
            #!!! Mike suggests a single mass.
            M_1 = (np.random.random(n_systems) *
                   (RV_standard_info[Pref_template_name][1] -
                    RV_standard_info[Pref_template_name][0])
                   ) + RV_standard_info[Pref_template_name][0]

            #Generates mass ratios with a minimum mass ratio of q_min (default 0.01?; should this depend on the primary mass? Low mass ratios can sometimes give very low-mass companions, i.e. BD mass...)
            #!!! Mike suggests 0.05 due to the brown dwarf desert.
            q = (np.random.random(n_systems) * (1 - q_min)) + q_min

            #from Primary masses and mass ratios, secondary masses can get calculated
            M_2 = M_1 * q

            #Get dates of the observations of the object
            jds = Obs_info[obj][:, 1].astype(float)  # np.float was removed in NumPy 1.24

            #get observed data, and add in the error in the standards in quadrature.
            #This relates to the spectrograph stability
            #There is also an astrophysical error due to these objects being rapid rotators etc.
            RV_standard_err = standard_std[Temp_sptype[obj][0]]
            err = np.sqrt(Obs_info[obj][:, 3].astype(float)**2 +
                          RV_standard_err**2 + astrophysical_std**2)
            observed_rv = Obs_info[obj][:, 2].astype(float)

            #IN A LOOP iterate over random orbits:
            for orb in range(n_orb):
                #FIXME: Figure out which velocity to use!
                if Region[obj] == 'US':
                    if args.group_velocity == 'True':
                        v_group = np.random.normal(
                            US_group_vel,
                            np.sqrt(US_group_std**2 + RV_standard_err**2),
                            n_systems)
                    else:
                        v_group = np.random.normal(
                            np.mean(observed_rv),
                            np.sqrt(US_group_std**2 + RV_standard_err**2),
                            n_systems)
                else:
                    if args.group_velocity == 'True':
                        v_group = np.random.normal(
                            UCL_group_vel,
                            np.sqrt(UCL_group_std**2 + RV_standard_err**2),
                            n_systems)
                    else:
                        v_group = np.random.normal(
                            np.mean(observed_rv),
                            np.sqrt(UCL_group_std**2 + RV_standard_err**2),
                            n_systems)

                #generate orbit?
                #!!! Find just one set of orbital parameters at a time, and
                #scale the RVs. OR, if you really want, you can compute a, i etc.
                #yourself and plug these into my_orb, but some RV scaling is still needed.
                rho, theta, normalised_vr = bo.binary_orbit(my_orb,
                                                            jds,
                                                            plot_orbit_no=orb)
                for system in range(n_systems):
                    actual_vr = bo.scale_rv(normalised_vr,
                                            my_orb['P'][orb],
                                            M_1[system],
                                            M_2[system],
                                            my_orb['i'][orb],
                                            group_velocity=v_group[system])

                    this_likelihood = bo.calc_likelihood(
                        actual_vr, observed_rv, err)
                    likelihoods.append(this_likelihood)
                    #THEN CALCULATE PROBABILITY OF BEING A SINGLE STAR
                    single_likelihoods.append(
                        bo.calc_likelihood(v_group[system], observed_rv, err))
                    del actual_vr
                    del this_likelihood
                del v_group
            del M_1
            del q
            del M_2
            del jds
            del RV_standard_err
            del err
            del observed_rv

            #THEN CALCULATE BAYES FACTOR
            bayes_factor = np.mean(likelihoods) / np.mean(single_likelihoods)
            print(("Bayes Factor: {0:5.2f} for ".format(bayes_factor) +
                   Object[obj]), "on rank", rank, "with SpT", Temp_sptype[obj])
            del likelihoods
            del single_likelihoods
            if Region[obj] == 'US':
                send_data = [0.0, float(obj), bayes_factor, Temp_sptype[obj]]
                #print "Sending data:", send_data, "from rank:", rank
                if rank == 0:
                    bayes_update = send_data
                else:
                    CW.send(send_data, dest=0, tag=rank)
            else:
                send_data = [1.0, float(obj), bayes_factor, Temp_sptype[obj]]
                #print "Sending data:", send_data, "from rank:", rank
                if rank == 0:
                    bayes_update = send_data
                else:
                    CW.send(send_data, dest=0, tag=rank)
            del send_data
            if rank == 0:
                all_bayes[int(bayes_update[0])].append(bayes_update[2])
                Obj_bayes[int(bayes_update[1])] = bayes_update[2]
                print("Updated Bayes factors retrieved from rank 0 for object",
                      Object[int(bayes_update[1])])
                f = open(args.bayes_file, 'a')
                write_string = Object[int(bayes_update[1])] + ',' + Region[int(
                    bayes_update[1])] + ',' + str(bayes_update[2]) + ',' + str(
                        bayes_update[3]) + '\n'
                f.write(write_string)
                f.close()
                del bayes_update
                del write_string

        rit = rit + 1
        if rit == size:
            sys.stdout.flush()
            CW.Barrier()
            rit = 0
            if rank == 0:

                print("UPDATING CALCULATED BAYES VALUES")
                for orit in range(1, size):
                    bayes_update = CW.recv(source=orit, tag=orit)
                    all_bayes[int(bayes_update[0])].append(bayes_update[2])
                    Obj_bayes[int(bayes_update[1])] = bayes_update[2]
                    print("Updated Bayes factors retrieved from rank", orit,
                          "for object", Object[int(bayes_update[1])])
                    f = open(args.bayes_file, 'a')
                    write_string = Object[int(
                        bayes_update[1])] + ',' + Region[int(
                            bayes_update[1])] + ',' + str(
                                bayes_update[2]) + ',' + str(
                                    bayes_update[3]) + '\n'
                    f.write(write_string)
                    f.close()
                    del bayes_update
                    del write_string
            sys.stdout.flush()
            CW.Barrier()

    sys.stdout.flush()
    CW.Barrier()
    if rank == 0:

        print("UPDATING CALCULATED BAYES VALUES")
        for orit in range(1, size):
            bayes_update = CW.recv(source=orit, tag=orit)
            all_bayes[int(bayes_update[0])].append(bayes_update[2])
            Obj_bayes[int(bayes_update[1])] = bayes_update[2]
            print("Updated Bayes factors retrieved from rank", orit,
                  "for object", Object[int(bayes_update[1])])
            f = open(args.bayes_file, 'a')
            write_string = Object[int(bayes_update[1])] + ',' + Region[int(
                bayes_update[1])] + ',' + str(bayes_update[2]) + ',' + str(
                    bayes_update[3]) + '\n'
            f.write(write_string)
            f.close()
            del bayes_update
            del write_string
        sys.stdout.flush()
        CW.Barrier()
    print("Finished Calculating bayes factors!")
Example #6
def restart_from_checkpoint(sim, iteration=None):
    """
    Fills the Simulation object `sim` with data saved in a checkpoint.

    More precisely, the following data from `sim` is overwritten:

    - Current time and iteration number of the simulation
    - Position of the boundaries of the simulation box
    - Values of the field arrays
    - Size and values of the particle arrays

    Any other information (e.g. diagnostics of the simulation, presence of a
    moving window, presence of a laser antenna, etc.) needs to be set by hand.

    For this reason, a successful restart will often require modifying the
    original input script that produced the checkpoint, rather than starting
    a new input script from scratch.

    NB: This function should always be called *before* the initialization
    of the moving window, since the moving window infers the position of
    particle injection from the existing particle data.

    Parameters
    ----------
    sim: a Simulation object
       The Simulation object into which the checkpoint should be loaded

    iteration: integer (optional)
       The iteration number of the checkpoint from which to restart
       If None, the latest checkpoint available will be used.
    """
    # Import openPMD-viewer
    try:
        from opmd_viewer import OpenPMDTimeSeries
    except ImportError:
        raise ImportError(
            'The package `opmd_viewer` is required to restart from checkpoints.'
            '\nPlease install it from https://github.com/openPMD/openPMD-viewer'
        )

    # Verify that the restart is valid (only for the first processor)
    # (Use the global MPI communicator instead of the `BoundaryCommunicator`,
    # so that this also works for `use_all_ranks=False`)
    if comm.rank == 0:
        check_restart(sim, iteration)
    comm.Barrier()

    # Choose the name of the directory from which to restart:
    # one directory per processor
    checkpoint_dir = 'checkpoints/proc%d/hdf5' % comm.rank
    ts = OpenPMDTimeSeries(checkpoint_dir)
    # Select the iteration, and its index
    if iteration is None:
        iteration = ts.iterations[-1]
    i_iteration = ts.iterations.index(iteration)

    # Modify parameters of the simulation
    sim.iteration = iteration
    sim.time = ts.t[i_iteration]

    # Load the particles
    # Loop through the different species
    for i in range(len(sim.ptcl)):
        name = 'species %d' % i
        load_species(sim.ptcl[i], name, ts, iteration, sim.comm)

    # Load the fields
    # Loop through the different modes
    for m in range(sim.fld.Nm):
        # Load the fields E and B
        for fieldtype in ['E', 'B', 'J']:
            for coord in ['r', 't', 'z']:
                load_fields(sim.fld.interp[m], fieldtype, coord, ts, iteration)
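A hypothetical restart flow, assuming `sim` is rebuilt with the same
parameters as the run that wrote ./checkpoints, and called before any moving
window is initialized (per the docstring's note):

restart_from_checkpoint(sim)       # load the latest checkpoint
set_periodic_checkpoint(sim, 500)  # re-register checkpointing if desired
sim.step(1000)                     # continue the run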
Example #7
        def execute(cfg, *args, **kwargs):

            # Parse keyword arguments.
            post_sync_grids = kwargs.get('post_sync', None)

            # Parse the inputs.
            gpu_params = []
            for k in range(len(params)):
                if params[k]['gce_type'] == 'number':
                    gpu_params.append(params[k]['dtype'](args[k]))
                elif params[k]['gce_type'] == 'const':  # Load Const.
                    gpu_params.append(args[k].data.ptr)
                    # Const no longer actually "const" in cuda code.
                    # d_ptr, size_in_bytes = my_get_global(params[k]['name'])
                    # drv.memcpy_dtod(d_ptr, args[k].data.gpudata, size_in_bytes)
                elif params[k]['gce_type'] == 'grid':
                    if args[k]._xlap == 0:
                        gpu_params.append(args[k].data.ptr)
                    else:
                        gpu_params.append(args[k].data.ptr + \
                                            args[k]._xlap_offset)
                elif params[k]['gce_type'] == 'out':
                    args[k].data.fill(args[k].dtype(0))  # Initialize the Out.
                    gpu_params.append(args[k].data.ptr)
                else:
                    raise TypeError('Invalid input type.')

            # See if we need to synchronize grids after kernel execution.
            if post_sync_grids is None:
                sync_pad = 0
            else:
                sync_pad = max([g._xlap for g in post_sync_grids])

            start2.record(stream)
            comm.Barrier()
            start.record(stream)

            # Execute kernel in padded regions first.
            execute_range(x_start, x_start + sync_pad, gpu_params, cfg, stream)
            execute_range(x_end - sync_pad, x_end, gpu_params, cfg, stream)
            pad_done.record(stream)  # Just for timing purposes.
            stream.synchronize()  # Wait for execution to finish.

            # Begin kernel execution in remaining "core" region.
            execute_range(x_start + sync_pad, x_end - sync_pad, gpu_params,
                          cfg, stream)
            comp_done.record(stream)  # Timing only.

            # While core kernel is executing, perform synchronization.
            if post_sync_grids is not None:  # Synchronization needed.
                for grid in post_sync_grids:
                    grid.synchronize_start()  # Start synchronization.

                # Keep on checking until everything is done.
                while not (all([grid.synchronize_isdone() \
                                for grid in post_sync_grids]) and \
                        stream.is_done()):
                    pass

            else:  # Nothing to synchronize.
                stream.synchronize()  # Just wait for execution to finish.

            sync_done.record()  # Timing.

            # Obtain the result for all Outs.
            batch_reduce(*[args[k] for k in range(len(params)) \
                                if params[k]['gce_type'] == 'out'])
            all_done.record()  # Timing.
            all_done.synchronize()

            return comp_done.time_since(
                start)  # Return time needed to execute the function.
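The ordering above is the standard compute/communication overlap for
halo-exchange codes; schematically (descriptive comments only, no new API):

# 1. Run the kernel on the `sync_pad` boundary slabs at each x face.
# 2. Start the asynchronous overlap synchronization for those slabs.
# 3. Run the kernel on the interior "core" region while the exchange is
#    in flight, polling synchronize_isdone() and stream.is_done() until
#    both the exchange and the core kernel have finished.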
Example #8
if rank == 0:
    solO = opticalpath(np.loadtxt("dbr.txt")[:, 1])
    sol = neighb(solO, n, Neighb)
    wyn = solO
    cel_wyn = goal(wyn, w)
    mkdir(path)
else:
    sol = np.array([])

sol = mpi.bcast(sol, root=0)
j = 0
start = timeit.default_timer()
stop = start
end = 500
mpi.Barrier()

while (j <= end) and ((stop - start) < time_limit):
    best_sol = np.array([])
    good_sol = np.array([])
    m = 0
    e = 0
    cel = np.array([])

    period = n // size
    full_range = np.linspace(0, size, size + 1) * period
    if full_range[-1] < n:
        full_range[-1] = n

    tempsol = sol[int(full_range[rank]):int(full_range[rank + 1])]
    cel = goal(tempsol, w)
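As a concrete check of the partition above (hypothetical values): with
n = 10 and size = 4, period = 2 and full_range starts as [0, 2, 4, 6, 8];
the final entry is then stretched to n, giving [0, 2, 4, 6, 10], so rank 3
absorbs the remainder and handles indices 6..9.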
Example #9
def main():
    rank = CW.Get_rank()
    size = CW.Get_size()
    args = parse_inputs()
    prev_args = args
    print("Starting mosaic_mod_script on rank", rank)

    # Read in directories:
    input_file = args.input_file
    #save_dir = args.save_directory
    #if os.path.exists(save_dir) == False:
    #    os.makedirs(save_dir)

    # Read in input file
    print("Reading in input mosaic file on rank", rank)
    positions = []
    paths = []
    args_dict = []
    with open(input_file, 'r') as mosaic_file:
        reader = csv.reader(mosaic_file)
        for row in reader:
            if row[0] == 'Grid_inputs:':
                glr = float(row[1])
                grl = float(row[2])
                glw = float(row[3])
                ghspace = float(row[4])
            elif row[0][0] != '#':
                positions.append((int(row[0]), int(row[1])))
                paths.append(row[2])
                dict = ""
                for col in row[3:]:
                    dict = dict + col
                    if col != row[-1]:
                        dict = dict + ','
                dict = ast.literal_eval(dict)
                args_temp = argparse.Namespace(**vars(args))
                for key in list(dict.keys()):
                    if key in args:
                        exec("args_temp."+ key + " = " + "str(dict[key])")
                args_dict.append(args_temp)
                del args_temp
                args = prev_args
                

    positions = np.array(positions)

    c = define_constants()
    mym.set_global_font_size(args.text_font)
    files = []
    simfo = []
    X = []
    Y = []
    X_vel = []
    Y_vel = []
    sim_files = []
    L = None
    for pit in range(len(paths)):
        fs = get_files(paths[pit], args_dict[pit])
        files.append(fs)

        #print "paths =", paths
        #print "fs =", fs
        #print "args_dict =", args_dict
        sfo = sim_info(paths[pit], fs[-1], args_dict[pit])
        simfo.append(sfo)

        if args_dict[pit].yt_proj == False:
            x, y, x_vel, y_vel, cl = mym.initialise_grid(files[pit][-1], zoom_times=args_dict[pit].zoom_times)
            X.append(x)
            Y.append(y)
            X_vel.append(x_vel)
            Y_vel.append(y_vel)
        else:
            x = np.linspace(sfo['xmin'], sfo['xmax'], sfo['dimension'])
            y = np.linspace(sfo['ymin'], sfo['ymax'], sfo['dimension'])
            x, y  = np.meshgrid(x, y)
            
            annotate_space = (simfo[pit]['xmax'] - simfo[pit]['xmin'])/31.
            x_ind = []
            y_ind = []
            counter = 0
            while counter < 31:
                val = annotate_space*counter + annotate_space/2. + simfo[pit]['xmin']
                x_ind.append(int(val))
                y_ind.append(int(val))
                counter = counter + 1
            x_vel, y_vel = np.meshgrid(x_ind, y_ind)
            if args_dict[pit].projection_orientation != None:
                y_val = 1./np.tan(np.deg2rad(float(args_dict[pit].projection_orientation)))
                if np.isinf(y_val):
                    y_val = 0.0
                L = [1.0, y_val, 0.0]
            else:
                if has_particles == False or len(dd['particle_posx']) == 1:
                    L = [0.0, 1.0, 0.0]
                else:
                    pos_vec = [np.diff(dd['particle_posx'].value)[0], np.diff(dd['particle_posy'].value)[0]]
                    L = [-1*pos_vec[-1], pos_vec[0]]
                    L.append(0.0)
                    if L[0] > 0.0:
                        L = [-1.0*L[0], -1.0*L[1], 0.0]
            print("SET PROJECTION ORIENTATION L=", L)
            L = np.array(L)
            X.append(x)
            Y.append(y)
            X_vel.append(x_vel)
            Y_vel.append(y_vel)
        if rank == 0:
            print("shape of x, y", np.shape(x), np.shape(y))

        if args_dict[pit].yt_proj == False and args_dict[pit].image_center != 0:
            sim_fs = sorted(glob.glob(paths[pit] + 'WIND_hdf5_plt_cnt*'))
        elif args_dict[pit].yt_proj != False and args_dict[pit].image_center != 0:
            sim_fs = files
        else:
            sim_fs = []
        sim_files.append(sim_fs)
    #myf.set_normal(L)
    #print "SET PROJECTION ORIENTATION L=", myf.get_normal()

    # Initialise Grid and build lists
    if args.plot_time != None:
        m_times = [args.plot_time]
    else:
        m_times = mym.generate_frame_times(files[0], args.time_step, presink_frames=args.presink_frames, end_time=args.end_time)
    no_frames = len(m_times)
    m_times = m_times[args.start_frame:]
    sys.stdout.flush()
    CW.Barrier()

    usable_files = []
    usable_sim_files = []
    for pit in range(len(paths)):
        usable_fs = mym.find_files(m_times, files[pit])
        usable_files.append(usable_fs)
        if args_dict[pit].image_center != 0 and args_dict[pit].yt_proj == False:
            usable_sfs = mym.find_files(m_times, sim_files[pit])
            usable_sim_files.append(usable_sfs)  # the matched sim files, not usable_fs
            del sim_files[pit]
        else:
            usable_sim_files.append([])
    sys.stdout.flush()
    CW.Barrier()
    frames = list(range(args.start_frame, no_frames))

    sink_form_time = []
    for pit in range(len(paths)):
        sink_form = mym.find_sink_formation_time(files[pit])
        print("sink_form_time", sink_form_time)
        sink_form_time.append(sink_form)
    del files

    # Define colourbar bounds
    cbar_max = args.colourbar_max
    cbar_min = args.colourbar_min

    if L is None:
        if args.axis == 'xy':
            L = [0.0, 0.0, 1.0]
        else:
            L = [1.0, 0.0, 0.0]
        L = np.array(L)
    if args.axis == 'xy':
        y_int = 1
    else:
        y_int = 2

    sys.stdout.flush()
    CW.Barrier()
    rit = args.working_rank
    for frame_val in range(len(frames)):
        if rank == rit:
            time_val = m_times[frame_val]
            plt.clf()
            columns = np.max(positions[:,0])
            rows = np.max(positions[:,1])

            width = float(columns)*(14.5/3.)
            height = float(rows)*(17./4.)
            fig = plt.figure(figsize=(width, height))
            
            gs_left = gridspec.GridSpec(rows, columns-1)
            gs_right = gridspec.GridSpec(rows, 1)

            gs_left.update(right=glr, wspace=glw, hspace=ghspace)
            gs_right.update(left=grl, hspace=ghspace)
            
            axes_dict = {}
            counter = 1

            for pit in range(len(paths)):
                
                try:
                    title_parts = args_dict[pit].title
                except:
                    title_parts = args_dict[pit]['title']
                title = ''
                for part in title_parts:
                    if part != title_parts[-1]:
                        title = title + part + ' '
                    else:
                        title = title + part
            
                ax_label = 'ax' + str(counter)
                yit = np.where(positions[:,1] == positions[pit][1])[0][0]
                if positions[pit][0] == 1 and positions[pit][1] == 1:
                    if columns > 1:
                        axes_dict.update({ax_label:fig.add_subplot(gs_left[0,0])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    else:
                        axes_dict.update({ax_label:fig.add_subplot(gs_right[0,0])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                elif positions[pit][0] != columns:
                    if args.share_x and args.share_y:
                        if yit >= len(axes_dict):
                            axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[pit][1]-1,positions[pit][0]-1], sharex=axes_dict['ax1'])})
                            #print "ADDED SUBPLOT:", counter, "on rank", rank
                        else:
                            axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[pit][1]-1,positions[pit][0]-1], sharex=axes_dict['ax1'], sharey=axes_dict[list(axes_dict.keys())[yit]])})
                            #print "ADDED SUBPLOT:", counter, "on rank", rank
                    elif args.share_x:
                        axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[pit][1]-1,positions[pit][0]-1], sharex=axes_dict['ax1'])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    elif args.share_y and positions[pit][0]!=1:
                        yit = np.where(positions[:,1] == positions[pit][1])[0][0]
                        axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[pit][1]-1,positions[pit][0]-1], sharey=axes_dict[list(axes_dict.keys())[yit]])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    elif args.share_y:
                        axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[pit][1]-1,positions[pit][0]-1])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    else:
                        axes_dict.update({ax_label:fig.add_subplot(gs_left[positions[pit][1]-1,positions[pit][0]-1])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                else:
                    if args.share_x and args.share_y:
                        yit = np.where(positions[:,1] == positions[pit][1])[0][0]
                        axes_dict.update({ax_label:fig.add_subplot(gs_right[positions[pit][1]-1,0], sharex=axes_dict['ax1'], sharey=axes_dict[list(axes_dict.keys())[yit]])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    elif args.share_x:
                        axes_dict.update({ax_label:fig.add_subplot(gs_right[positions[pit][1]-1,0], sharex=axes_dict['ax1'])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    elif args.share_y:
                        yit = np.where(positions[:,1] == positions[pit][1])[0][0]
                        axes_dict.update({ax_label:fig.add_subplot(gs_right[positions[pit][1]-1,0], sharey=axes_dict[list(axes_dict.keys())[yit]])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank
                    else:
                        axes_dict.update({ax_label:fig.add_subplot(gs_right[positions[pit][1]-1,0])})
                        #print "ADDED SUBPLOT:", counter, "on rank", rank

                counter = counter + 1
                axes_dict[ax_label].set(adjustable='box', aspect='equal')  # 'box-forced' was removed in Matplotlib 2.2+
                

                if args.yt_proj and args.plot_time==None and os.path.isfile(paths[pit] + "movie_frame_" + ("%06d" % frames[frame_val]) + ".pkl"):
                    pickle_file = paths[pit] + "movie_frame_" + ("%06d" % frames[frame_val]) + ".pkl"
                    print("USING PICKLED FILE:", pickle_file)
                    file = open(pickle_file, 'rb')  # pickles must be read in binary mode
                    #weight_fieldstuff = pickle.load(file)
                    X[pit], Y[pit], image, magx, magy, X_vel[pit], Y_vel[pit], velx, vely, part_info, args_dict[pit], simfo[pit] = pickle.load(file)

                    #file_time = stuff[17]
                    file.close()

                else:
                    time_val = m_times[frame_val]
                    print("FILE =", usable_files[pit][frame_val])
                    has_particles = has_sinks(usable_files[pit][frame_val])
                    if has_particles:
                        part_info = mym.get_particle_data(usable_files[pit][frame_val], args_dict[pit].axis, proj_or=L)
                    else:
                        part_info = {}
                    center_vel = [0.0, 0.0, 0.0]
                    if args.image_center != 0 and has_particles:
                        original_positions = [X[pit], Y[pit], X_vel[pit], Y_vel[pit]]
                        x_pos = np.round(part_info['particle_position'][0][args.image_center - 1]/cl)*cl
                        y_pos = np.round(part_info['particle_position'][1][args.image_center - 1]/cl)*cl
                        pos = np.array([part_info['particle_position'][0][args.image_center - 1], part_info['particle_position'][1][args.image_center - 1]])
                        X[pit] = X[pit] + x_pos
                        Y[pit] = Y[pit] + y_pos
                        X_vel[pit] = X_vel[pit] + x_pos
                        Y_vel[pit] = Y_vel[pit] + y_pos
                        if args.yt_proj == False:
                            sim_file = usable_sim_files[frame_val][:-12] + 'part' + usable_sim_files[frame_val][-5:]
                        else:
                            sim_file = part_file
                        if len(part_info['particle_mass']) == 1:
                            part_ind = 0
                        else:
                            min_dist = 1000.0
                            for part in range(len(part_info['particle_mass'])):
                                f = h5py.File(sim_file, 'r')
                                temp_pos = np.array([f[list(f.keys())[11]][part][13]/c['au'], f[list(f.keys())[11]][part][13+y_int]/c['au']])
                                f.close()
                                dist = np.sqrt(np.abs(np.diff((temp_pos - pos)**2)))[0]
                                if dist < min_dist:
                                    min_dist = dist
                                    part_ind = part
                        f = h5py.File(sim_file, 'r')
                        center_vel = [f[list(f.keys())[11]][part_ind][18], f[list(f.keys())[11]][part_ind][19], f[list(f.keys())[11]][part_ind][20]]
                        f.close()
                    xabel, yabel, xlim, ylim = image_properties(X[pit], Y[pit], args_dict[pit], simfo[pit])
                    if args_dict[pit].axis == 'xy':
                        center_vel=center_vel[:2]
                    else:
                        center_vel=center_vel[::2]
                    
                    if args_dict[pit].ax_lim != None:
                        if has_particles and args_dict[pit].image_center != 0:
                            xlim = [-1*args_dict[pit].ax_lim + part_info['particle_position'][0][args_dict[pit].image_center - 1], args_dict[pit].ax_lim + part_info['particle_position'][0][args_dict[pit].image_center - 1]]
                            ylim = [-1*args_dict[pit].ax_lim + part_info['particle_position'][1][args_dict[pit].image_center - 1], args_dict[pit].ax_lim + part_info['particle_position'][1][args_dict[pit].image_center - 1]]
                        else:
                            xlim = [-1*args_dict[pit].ax_lim, args_dict[pit].ax_lim]
                            ylim = [-1*args_dict[pit].ax_lim, args_dict[pit].ax_lim]

                    if args.yt_proj == False:
                        f = h5py.File(usable_files[pit][frame_val], 'r')
                        image = get_image_arrays(f, simfo[pit]['field'], simfo[pit], args_dict[pit], X[pit], Y[pit])
                        magx = get_image_arrays(f, 'mag'+args.axis[0]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis, simfo[pit], args_dict[pit], X[pit], Y[pit])
                        magy = get_image_arrays(f, 'mag'+args.axis[1]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis, simfo[pit], args_dict[pit], X[pit], Y[pit])
                        x_pos_min = int(np.round(np.min(X[pit]) - simfo[pit]['xmin_full'])/simfo[pit]['cell_length'])
                        y_pos_min = int(np.round(np.min(Y[pit]) - simfo[pit]['xmin_full'])/simfo[pit]['cell_length'])
                        if np.shape(f['vel'+args.axis[0]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis]) == (2048, 2048):
                            velocity_data = [f['vel'+args.axis[0]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis], f['vel'+args.axis[1]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis]]
                        elif args.axis == 'xy':
                            velocity_data = [f['vel'+args.axis[0]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis][:,:,0], f['vel'+args.axis[1]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis][:,:,0]]
                        else:
                            velocity_data = [f['vel'+args.axis[0]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis][:,0,:], f['vel'+args.axis[1]+'_'+simfo[pit]['movie_file_type']+'_'+args.axis][:,0,:]]
                        velx, vely = mym.get_quiver_arrays(y_pos_min, x_pos_min, X[pit], velocity_data[0], velocity_data[1], center_vel=center_vel)
                    else:
                        if args_dict[pit].image_center == 0 or has_particles == False:
                            center_pos = np.array([0.0, 0.0, 0.0])
                        else:
                            dd = f.all_data()
                            center_pos = np.array([dd['particle_posx'][args.image_center-1].in_units('AU'), dd['particle_posy'][args.image_center-1].in_units('AU'), dd['particle_posz'][args.image_center-1].in_units('AU')])
                        x_width = (xlim[1] -xlim[0])
                        y_width = (ylim[1] -ylim[0])
                        thickness = yt.YTArray(args.slice_thickness, 'AU')
                        
                        proj = yt.OffAxisProjectionPlot(f, L, [simfo[pit]['field'], 'cell_mass', 'velz_mw', 'magz_mw', 'Projected_Magnetic_Field_mw', 'Projected_Velocity_mw'], center=(center_pos, 'AU'), width=(x_width, 'AU'), depth=(args.slice_thickness, 'AU'))
                        image = (proj.frb.data[simfo[pit]['field']]/thickness.in_units('cm')).value
                        velx_full = (proj.frb.data[('gas', 'Projected_Velocity_mw')].in_units('g*cm**2/s')/thickness.in_units('cm')).value
                        vely_full = (proj.frb.data[('gas', 'velz_mw')].in_units('g*cm**2/s')/thickness.in_units('cm')).value
                        magx = (proj.frb.data[('gas', 'Projected_Magnetic_Field_mw')].in_units('g*gauss*cm')/thickness.in_units('cm')).value
                        magy = (proj.frb.data[('gas', 'magz_mw')].in_units('g*gauss*cm')/thickness.in_units('cm')).value
                        mass = (proj.frb.data[('gas', 'cell_mass')].in_units('cm*g')/thickness.in_units('cm')).value
                        
                        velx_full = velx_full/mass
                        vely_full = vely_full/mass
                        magx = magx/mass
                        magy = magy/mass
                        del mass

                        velx, vely = mym.get_quiver_arrays(0.0, 0.0, X[pit], velx_full, vely_full, center_vel=center_vel)
                        del velx_full
                        del vely_full

                        if len(frames) == 1:
                            if rank == 0:
                                pickle_file = paths[pit] + "movie_frame_" + ("%06d" % frames[frame_val]) + ".pkl"
                                file = open(pickle_file, 'wb')  # pickles must be written in binary mode
                                pickle.dump((X[pit], Y[pit], image, magx, magy, X_vel[pit], Y_vel[pit], velx, vely, xlim, ylim, has_particles, part_info, simfo[pit], time_val,xabel, yabel), file)
                                file.close()
                                print("Created Pickle:", pickle_file, "for  file:", usable_files[pit][frame_val])
                        else:
                            pickle_file = paths[pit] + "movie_frame_" + ("%06d" % frames[frame_val]) + ".pkl"
                            file = open(pickle_file, 'wb')  # pickles must be written in binary mode
                            pickle.dump((X[pit], Y[pit], image, magx, magy, X_vel[pit], Y_vel[pit], velx, vely, xlim, ylim, has_particles, part_info, simfo[pit], time_val,xabel, yabel), file)
                            file.close()
                            print("Created Pickle:", pickle_file, "for  file:", usable_files[pit][frame_val])
                    
                    f.close()

                plot = axes_dict[ax_label].pcolormesh(X[pit], Y[pit], image, cmap=plt.cm.gist_heat, norm=LogNorm(vmin=cbar_min, vmax=cbar_max), rasterized=True)
                plt.gca().set_aspect('equal')
                if frame_val > 0 or time_val > -1.0:
                    axes_dict[ax_label].streamplot(X[pit], Y[pit], magx, magy, density=4, linewidth=0.25, arrowstyle='-', minlength=0.5)
                else:
                    axes_dict[ax_label].streamplot(X[pit], Y[pit], magx, magy, density=4, linewidth=0.25, minlength=0.5)

                xlim = args_dict[pit]['xlim']
                ylim = args_dict[pit]['ylim']
                mym.my_own_quiver_function(axes_dict[ax_label], X_vel[pit], Y_vel[pit], velx, vely, plot_velocity_legend=bool(args_dict[pit]['annotate_velocity']), limits=[xlim, ylim], standard_vel=args.standard_vel)
                if args_dict[pit]['has_particles']:
                    if args.annotate_particles_mass == True:
                        mym.annotate_particles(axes_dict[ax_label], part_info['particle_position'], part_info['accretion_rad'], limits=[xlim, ylim], annotate_field=part_info['particle_mass'])
                    else:
                        mym.annotate_particles(axes_dict[ax_label], part_info['particle_position'], part_info['accretion_rad'], limits=[xlim, ylim], annotate_field=None)
                if args.plot_lref == True:
                    r_acc = np.round(part_info['accretion_rad'])
                    axes_dict[ax_label].annotate('$r_{acc}$='+str(r_acc)+'AU', xy=(0.98*simfo[pit]['xmax'], 0.93*simfo[pit]['ymax']), va="center", ha="right", color='w', fontsize=args_dict[pit].text_font)
                if args.annotate_time == "True" and pit == 0:
                    print("ANNONTATING TIME:", str(int(time_val))+'yr')
                    time_text = axes_dict[ax_label].text((xlim[0]+0.01*(xlim[1]-xlim[0])), (ylim[1]-0.03*(ylim[1]-ylim[0])), '$t$='+str(int(time_val))+'yr', va="center", ha="left", color='w', fontsize=args.text_font)
                    time_text.set_path_effects([path_effects.Stroke(linewidth=3, foreground='black'), path_effects.Normal()])
                    #ax.annotate('$t$='+str(int(time_val))+'yr', xy=(xlim[0]+0.01*(xlim[1]-xlim[0]), ylim[1]-0.03*(ylim[1]-ylim[0])), va="center", ha="left", color='w', fontsize=args.text_font)
                title_text = axes_dict[ax_label].text((np.mean(xlim)), (ylim[1]-0.03*(ylim[1]-ylim[0])), title, va="center", ha="center", color='w', fontsize=(args.text_font+2))
                title_text.set_path_effects([path_effects.Stroke(linewidth=3, foreground='black'), path_effects.Normal()])

                if positions[pit][0] == columns:
                    cbar = plt.colorbar(plot, pad=0.0, ax=axes_dict[ax_label])
                    cbar.set_label('Density (gcm$^{-3}$)', rotation=270, labelpad=14, size=args.text_font)
                axes_dict[ax_label].set_xlabel(args_dict[pit]['xabel'], labelpad=-1, fontsize=args.text_font)
                if positions[pit][0] == 1:
                    axes_dict[ax_label].set_ylabel(args_dict[pit]['yabel'], labelpad=-20, fontsize=args.text_font)
                axes_dict[ax_label].set_xlim(xlim)
                axes_dict[ax_label].set_ylim(ylim)
                for line in axes_dict[ax_label].xaxis.get_ticklines():
                    line.set_color('white')
                for line in axes_dict[ax_label].yaxis.get_ticklines():
                    line.set_color('white')

                plt.tick_params(axis='both', which='major', labelsize=16)
                for line in axes_dict[ax_label].xaxis.get_ticklines():
                    line.set_color('white')
                for line in axes_dict[ax_label].yaxis.get_ticklines():
                    line.set_color('white')

                if positions[pit][0] != 1:
                    yticklabels = axes_dict[ax_label].get_yticklabels()
                    plt.setp(yticklabels, visible=False)

                if positions[pit][0] == 1:
                    axes_dict[ax_label].tick_params(axis='y', which='major', labelsize=args.text_font)
                if positions[pit][1] == rows:
                    axes_dict[ax_label].tick_params(axis='x', which='major', labelsize=args.text_font)
                    if positions[pit][0] != 1:
                        xticklabels = axes_dict[ax_label].get_xticklabels()
                        plt.setp(xticklabels[0], visible=False)

                if len(usable_files[pit]) > 1:
                    if args.output_filename == None:
                        import pdb
                        pdb.set_trace()
                        file_name = save_dir + "movie_frame_" + ("%06d" % frames[frame_val])
                    else:
                        file_name = args.output_filename + "_" + str(int(time_val))
                else:
                    if args.output_filename != None:
                        file_name = args.output_filename
                    else:
                        import pdb
                        pdb.set_trace()
                        file_name = save_dir + "time_" + str(args.plot_time)

                plt.savefig(file_name + ".eps", format='eps', bbox_inches='tight')
                #plt.savefig(file_name + ".pdf", format='pdf', bbox_inches='tight')
                
                #plt.savefig(file_name + ".jpg", format='jpeg', bbox_inches='tight')
                call(['convert', '-antialias', '-quality', '100', '-density', '200', '-resize', '100%', '-flatten', file_name+'.eps', file_name+'.jpg'])
                os.remove(file_name + '.eps')

                del image
                del magx
                del magy
                del velx
                del vely
                
                if args.image_center != 0 and has_particles:
                    X[pit], Y[pit], X_vel[pit], Y_vel[pit] = original_positions
            print('Created frame', (frames[frame_val]), 'of', str(frames[-1]), 'on rank', rank, 'at time of', str(time_val), 'to save_dir:', file_name + '.eps')

        rit = rit +1
        if rit == size:
            rit = 0

    print("completed making movie frames on rank", rank)
Example #10
rank = CW.Get_rank()
size = CW.Get_size()
if rank == 0:
    print("size =", size)

args = parse_inputs()

#Define relevant directories
input_dir = sys.argv[1]
save_dir = sys.argv[2]
global_pickle = sys.argv[3]
if not os.path.exists(save_dir) and rank == 0:
    os.makedirs(save_dir)

sys.stdout.flush()
CW.Barrier()

#Set some plot variables independent of the data files
#Find files
sink_files = sorted(glob.glob(input_dir + "output*/*.dat"))
files = sorted(glob.glob(input_dir + "*/info*.txt"))
rm_files = []
for info_name in files:
    sink_file = info_name.split('info')[0] + 'stars_output.dat'
    if sink_file not in sink_files:
        rm_files.append(info_name)
for rm_file in rm_files:
    files.remove(rm_file)
del sink_files
gc.collect()
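An equivalent, more direct filtering of the info files (same behavior,
assuming the same directory layout; a sketch, not from the source):

sink_set = set(sorted(glob.glob(input_dir + "output*/*.dat")))
files = [f for f in files
         if f.split('info')[0] + 'stars_output.dat' in sink_set]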
Example #11
from distributed import Client
from mpi4py.MPI import COMM_WORLD as world

from dask_mpi import initialize, send_close_signal

# Split our MPI world into two pieces, one consisting just of
# the old rank 3 process and the other with everything else
new_comm_assignment = 1 if world.rank == 3 else 0
comm = world.Split(new_comm_assignment)

if world.rank != 3:
    # run tests with rest of comm
    is_client = initialize(comm=comm, exit=False)

    if is_client:
        with Client() as c:
            assert c.submit(lambda x: x + 1, 10).result() == 11
            assert c.submit(lambda x: x + 1, 20).result() == 21
        send_close_signal()

# check that our original comm is intact
world.Barrier()
x = 100 if world.rank == 0 else 200
x = world.bcast(x)
assert x == 100
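A hypothetical launch for the snippet above (the script name is illustrative;
it needs an MPI launcher and at least 4 ranks, since rank 3 is split off):

# mpirun -n 4 python test_split_comm.py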