def form_stars(sink,
               initial_mass_function=settings.stars_initial_mass_function,
               lower_mass_limit=settings.stars_lower_mass_limit,
               upper_mass_limit=settings.stars_upper_mass_limit,
               local_sound_speed=0.2 | units.kms,
               logger=None,
               randomseed=None,
               **keyword_arguments):
    """
    Let a sink form stars.

    On the first call for a sink (``sink.initialised`` is falsy) a mass is
    drawn with ``generate_next_mass`` and stored on the sink as
    ``next_primary_mass``. Stars are only formed once the sink's mass is at
    least that value. When it is, further masses are drawn with
    ``new_masses`` from the remaining sink mass; the last drawn mass is not
    turned into a star but stored as the sink's new ``next_primary_mass``.

    Parameters
    ----------
    sink
        The sink particle. It is modified in place: ``initialised``,
        ``next_primary_mass`` and ``mass`` are written.
    initial_mass_function
        Passed on to ``generate_next_mass`` and ``new_masses``.
    lower_mass_limit, upper_mass_limit
        Passed on to ``generate_next_mass`` and ``new_masses``.
    local_sound_speed
        Fallback scale of the random velocity offsets, used only when
        ``sink.u`` is not available.
    logger
        Logger to use; defaults to ``logging.getLogger(__name__)``.
    randomseed
        If not None, ``numpy.random.seed`` is called with this value.
    **keyword_arguments
        Accepted and ignored.

    Returns
    -------
    list
        ``[sink, new_stars]``. ``new_stars`` is an empty ``Particles()`` set
        when the sink is not yet massive enough for its next star.
    """

    logger = logger or logging.getLogger(__name__)
    if randomseed is not None:
        logger.info("setting random seed to %i", randomseed)
        numpy.random.seed(randomseed)

    # sink_initial_density = sink.mass / (4/3 * numpy.pi * sink.radius**3)

    initialised = sink.initialised or False
    if not initialised:
        logger.debug("Initialising sink %i for star formation", sink.key)
        next_mass = generate_next_mass(
            initial_mass_function=initial_mass_function,
            lower_mass_limit=lower_mass_limit,
            upper_mass_limit=upper_mass_limit,
        )
        # Only the primary is used; binary/triple bookkeeping
        # (next_secondary_mass, next_tertiary_mass) is not implemented here.
        sink.next_primary_mass = next_mass[0]
        sink.initialised = True
    if sink.mass < sink.next_primary_mass:
        logger.debug("Sink %i is not massive enough for the next star",
                     sink.key)
        return [sink, Particles()]

    # We now have the first star that will be formed.
    # Next, we generate a list of stellar masses, so that the last star in the
    # list is just one too many for the sink's mass.

    mass_left = sink.mass - sink.next_primary_mass
    masses = new_masses(
        stellar_mass=mass_left,
        lower_mass_limit=lower_mass_limit,
        upper_mass_limit=upper_mass_limit,
        # Honour the caller's IMF choice, consistent with the draw of the
        # first star above (previously the settings default was always used).
        initial_mass_function=initial_mass_function,
    )
    number_of_stars = len(masses)

    new_stars = Particles(number_of_stars)
    new_stars.age = 0 | units.Myr
    new_stars[0].mass = sink.next_primary_mass
    new_stars[1:].mass = masses[:-1]
    # The one-too-many mass becomes the target for the next formation event.
    sink.next_primary_mass = masses[-1]
    new_stars.position = sink.position
    new_stars.velocity = sink.velocity

    # Random position within the sink radius.
    # NOTE(review): rho is uniform in [0, radius) and phi uniform in
    # [0, pi), which is not uniform in volume / solid angle - confirm that
    # this is intended.
    radius = sink.radius
    rho = numpy.random.random(number_of_stars) * radius
    theta = (numpy.random.random(number_of_stars) * (2 * numpy.pi | units.rad))
    phi = (numpy.random.random(number_of_stars) * numpy.pi | units.rad)
    x = rho * sin(phi) * cos(theta)
    y = rho * sin(phi) * sin(theta)
    z = rho * cos(phi)
    new_stars.x += x
    new_stars.y += y
    new_stars.z += z

    # Random velocity, sample magnitude from gaussian with local sound speed
    # like Wall et al (2019)
    # temperature = 10 | units.K
    try:
        sound_speed = sink.u.sqrt()
    except AttributeError:
        # The sink carries no 'u'; fall back to the supplied value.
        sound_speed = local_sound_speed
    # or (gamma * local_pressure / density).sqrt()
    velocity_magnitude = numpy.random.normal(
        # loc=0.0,  # <- since we already added the velocity of the sink
        scale=sound_speed.value_in(units.kms),
        size=number_of_stars,
    ) | units.kms
    velocity_theta = (numpy.random.random(number_of_stars) *
                      (2 * numpy.pi | units.rad))
    velocity_phi = (numpy.random.random(number_of_stars) *
                    (numpy.pi | units.rad))
    vx = velocity_magnitude * sin(velocity_phi) * cos(velocity_theta)
    vy = velocity_magnitude * sin(velocity_phi) * sin(velocity_theta)
    vz = velocity_magnitude * cos(velocity_phi)
    new_stars.vx += vx
    new_stars.vy += vy
    new_stars.vz += vz

    new_stars.origin_cloud = sink.key
    # For Pentacle, this is the PP radius
    new_stars.radius = 0.05 | units.parsec
    sink.mass -= new_stars.total_mass()
    # TODO: fix sink's momentum etc

    # EDIT: Do not shrink the sinks at this point, but rather when finished
    # forming stars.
    # # Shrink the sink's (accretion) radius to prevent it from accreting
    # # relatively far away gas and moving a lot
    # sink.radius = (
    #     (sink.mass / sink_initial_density)
    #     / (4/3 * numpy.pi)
    # )**(1/3)

    # cleanup
    # sink.initialised = False
    new_stars.birth_mass = new_stars.mass
    return [sink, new_stars]
def form_stars_from_group_older_version(
        group_index,
        sink_particles,
        newly_removed_gas,
        lower_mass_limit=settings.stars_lower_mass_limit,
        upper_mass_limit=settings.stars_upper_mass_limit,
        local_sound_speed=0.2 | units.kms,
        minimum_sink_mass=0.01 | units.MSun,
        logger=None,
        randomseed=None,
        shrink_sinks=True,
        initial_mass_function=settings.stars_initial_mass_function,
        **keyword_arguments):
    """
    Form stars from specific group of sinks.

    NOTE: This is the older version where removed gas is
    considered as star-forming region. This is now being
    updated to the latest version.

    The sinks with ``in_group == group_index`` are treated as one mass
    reservoir. The next stellar mass for the group is kept on the sinks as
    ``group_next_primary_mass``; stars are formed only when the total group
    mass reaches it. New stars are placed on "star forming regions" (the
    group's sinks plus the removed gas accreted by them), chosen with
    probability proportional to density, and then given a random offset in
    position and velocity.

    Parameters
    ----------
    group_index
        Value of ``in_group`` selecting the sinks to use.
    sink_particles
        All sinks; every one must have ``in_group > 0``. The selected sinks
        are modified in place (``group_next_primary_mass``, ``mass`` and,
        if ``shrink_sinks``, ``radius``).
    newly_removed_gas
        Removed gas particles; those whose ``accreted_by_sink`` equals the
        key of a sink in the group are used. ``removed_gas.radius`` is set
        from ``h_smooth``.
    lower_mass_limit, upper_mass_limit
        Passed on to ``generate_next_mass`` and ``new_masses``.
    local_sound_speed
        Its square is used as the fallback ``u`` for regions without one.
    minimum_sink_mass
        Floor applied when subtracting stellar mass from individual sinks.
    logger
        Logger to use; defaults to ``logging.getLogger(__name__)``.
    randomseed
        If not None, ``numpy.random.seed`` is called with this value.
    shrink_sinks
        If true, recompute sink radii from the new mass at the sink's
        ``initial_density``.
    initial_mass_function
        Passed on to ``generate_next_mass`` and ``new_masses``. Added as a
        trailing keyword so existing positional calls are unaffected.
    **keyword_arguments
        Accepted and ignored.

    Returns
    -------
    The new stars (sorted by decreasing mass), or ``None`` when a sanity
    check fails or the group is not massive enough for its next star.
    """
    logger = logger or logging.getLogger(__name__)
    logger.info("Using form_stars_from_group on group %i", group_index)
    if randomseed is not None:
        logger.info("Setting random seed to %i", randomseed)
        numpy.random.seed(randomseed)

    # Sanity check: each sink particle must be in a group.
    ungrouped_sinks = sink_particles.select_array(lambda x: x <= 0,
                                                  ['in_group'])
    if not ungrouped_sinks.is_empty():
        logger.info(
            "WARNING: There exist ungrouped sinks. Something is wrong!")
        return None

    # Consider only group with input group index from here onwards.
    group = sink_particles[sink_particles.in_group == group_index]

    # Sanity check: group must have at least a sink
    if group.is_empty():
        logger.info(
            "WARNING: There is no sink in the group: Something is wrong!")
        return None

    number_of_sinks = len(group)
    logger.info("%i sinks found in group #%i: %s", number_of_sinks,
                group_index, group.key)
    group_mass = group.total_mass()
    logger.info("Group mass: %s", group_mass.in_(units.MSun))

    # A candidate mass is always drawn (this consumes random numbers even
    # when the group already has a stored next mass).
    next_mass = generate_next_mass(
        initial_mass_function=initial_mass_function,
        lower_mass_limit=lower_mass_limit,
        upper_mass_limit=upper_mass_limit,
    )[0][0]
    try:
        # Within a group, group_next_primary_mass values are either
        # a mass, or 0 MSun. If all values are 0 MSun, this is a
        # new group. Else, only interested on the non-zero value. The
        # non-zero values are the same.
        logger.info('SANITY CHECK: group_next_primary_mass %s',
                    group.group_next_primary_mass)
        if group.group_next_primary_mass.max() == 0 | units.MSun:
            logger.info('Initiate group #%i for star formation', group_index)
            group.group_next_primary_mass = next_mass
        else:
            next_mass = group.group_next_primary_mass.max()
    # This happens for the first ever assignment of this attribute
    except AttributeError:
        logger.info(
            'AttributeError exception: Initiate group #%i for star formation',
            group_index)
        group.group_next_primary_mass = next_mass

    logger.info("Next mass is %s", next_mass)

    if group_mass < next_mass:
        logger.info("Group #%i is not massive enough for the next star",
                    group_index)
        return None

    # Form stars from the leftover group sink mass
    mass_left = group_mass - next_mass
    masses = new_masses(
        stellar_mass=mass_left,
        lower_mass_limit=lower_mass_limit,
        upper_mass_limit=upper_mass_limit,
        initial_mass_function=initial_mass_function,
    )
    number_of_stars = len(masses)

    logger.info("%i stars created in group #%i with %i sinks", number_of_stars,
                group_index, number_of_sinks)

    new_stars = Particles(number_of_stars)
    new_stars.age = 0 | units.Myr
    new_stars[0].mass = next_mass
    new_stars[1:].mass = masses[:-1]
    # The last drawn mass is kept back as the group's next target mass.
    group.group_next_primary_mass = masses[-1]
    new_stars = new_stars.sorted_by_attribute("mass").reversed()

    logger.info("Group's next primary mass is %s",
                group.group_next_primary_mass[0])

    # Create placeholders for attributes of new_stars
    new_stars.position = [0, 0, 0] | units.pc
    new_stars.velocity = [0, 0, 0] | units.kms
    new_stars.origin_cloud = group[0].key
    new_stars.star_forming_radius = 0 | units.pc
    new_stars.star_forming_u = local_sound_speed**2

    # Find the newly removed gas in the group
    removed_gas = Particles()
    if not newly_removed_gas.is_empty():
        for s in group:
            removed_gas_by_this_sink = (
                newly_removed_gas[newly_removed_gas.accreted_by_sink == s.key])
            removed_gas.add_particles(removed_gas_by_this_sink)

    logger.info("%i removed gas found in this group", len(removed_gas))

    # Star forming regions that contain the removed gas and the group
    # of sinks
    if not removed_gas.is_empty():
        removed_gas.radius = removed_gas.h_smooth
    star_forming_regions = group.copy()
    star_forming_regions.density = (
        star_forming_regions.initial_density / 1000
    )  # /1000 to reduce likelihood of forming stars in sinks
    star_forming_regions.accreted_by_sink = star_forming_regions.key
    try:
        # Self-assignment only probes whether 'u' exists on the sinks.
        star_forming_regions.u = star_forming_regions.u
    except AttributeError:
        star_forming_regions.u = local_sound_speed**2
    star_forming_regions.add_particles(removed_gas.copy())
    # No sort by density is needed here: regions are sampled by index with
    # probabilities built from this same set, so its order is irrelevant.
    # (A previous sort at this point discarded its result and had no effect.)

    # Generate a probability list of star forming region indices the
    # stars should associate to
    probabilities = (star_forming_regions.density /
                     star_forming_regions.density.sum())
    probabilities /= probabilities.sum()  # Ensure sum is exactly 1
    logger.info("Max & min probabilities: %s, %s", probabilities.max(),
                probabilities.min())

    logger.info("%i star forming regions", len(star_forming_regions))

    def delta_positions_and_velocities(new_stars, star_forming_regions,
                                       probabilities):
        """
        Assign positions and velocities of stars in the star forming regions
        according to the probability distribution

        Each star is first given the position, velocity, origin, radius and
        u of a sampled region (written onto ``new_stars``). The returned
        values are random offsets, in pc and kms, that the caller adds on
        top.
        """
        number_of_stars = len(new_stars)

        # Create an index list of removed gas from probability list
        sample = numpy.random.choice(len(star_forming_regions),
                                     number_of_stars,
                                     p=probabilities)

        # Assign the stars to the removed gas according to the sample
        star_forming_regions_sampled = star_forming_regions[sample]
        new_stars.position = star_forming_regions_sampled.position
        new_stars.velocity = star_forming_regions_sampled.velocity
        new_stars.origin_cloud = star_forming_regions_sampled.accreted_by_sink
        new_stars.star_forming_radius = star_forming_regions_sampled.radius
        try:
            new_stars.star_forming_u = star_forming_regions_sampled.u
        except AttributeError:
            new_stars.star_forming_u = local_sound_speed**2

        # Random position of stars within the sink radius they assigned to
        rho = (numpy.random.random(number_of_stars) *
               new_stars.star_forming_radius)
        theta = (numpy.random.random(number_of_stars) *
                 (2 * numpy.pi | units.rad))
        phi = (numpy.random.random(number_of_stars) * numpy.pi | units.rad)
        x = (rho * sin(phi) * cos(theta)).value_in(units.pc)
        y = (rho * sin(phi) * sin(theta)).value_in(units.pc)
        z = (rho * cos(phi)).value_in(units.pc)

        X = list(zip(*[x, y, z])) | units.pc

        # Random velocity, sample magnitude from gaussian with local sound
        # speed like Wall et al (2019)
        # temperature = 10 | units.K

        # or (gamma * local_pressure / density).sqrt()
        velocity_magnitude = numpy.random.normal(
            # loc=0.0,  # <- since we already added the velocity of the sink
            scale=new_stars.star_forming_u.sqrt().value_in(units.kms),
            size=number_of_stars,
        ) | units.kms
        velocity_theta = (numpy.random.random(number_of_stars) *
                          (2 * numpy.pi | units.rad))
        velocity_phi = (numpy.random.random(number_of_stars) *
                        (numpy.pi | units.rad))
        vx = (velocity_magnitude * sin(velocity_phi) *
              cos(velocity_theta)).value_in(units.kms)
        vy = (velocity_magnitude * sin(velocity_phi) *
              sin(velocity_theta)).value_in(units.kms)
        vz = (velocity_magnitude * cos(velocity_phi)).value_in(units.kms)

        V = list(zip(*[vx, vy, vz])) | units.kms

        return X, V

    dX, dV = delta_positions_and_velocities(new_stars, star_forming_regions,
                                            probabilities)
    logger.info("Updating new stars...")
    new_stars.position += dX
    new_stars.velocity += dV

    # For Pentacle, this is the PP radius
    new_stars.radius = 0.05 | units.parsec

    # mass_ratio = 1 - new_stars.total_mass()/group.total_mass()
    # group.mass *= mass_ratio

    # Subtract from each sink the mass of the stars attributed to it, but
    # never take a sink below minimum_sink_mass; whatever could not be
    # subtracted is accumulated as excess and spread over the group below.
    excess_star_mass = 0 | units.MSun
    for s in group:
        logger.info('Sink mass before reduction: %s', s.mass.in_(units.MSun))
        total_star_mass_nearby = (
            new_stars[new_stars.origin_cloud == s.key]).total_mass()

        # To prevent sink mass becomes negative
        if s.mass > minimum_sink_mass:
            if (s.mass - total_star_mass_nearby) <= minimum_sink_mass:
                excess_star_mass += (total_star_mass_nearby - s.mass +
                                     minimum_sink_mass)
                logger.info('Sink mass goes below %s; excess mass is now %s',
                            minimum_sink_mass.in_(units.MSun),
                            excess_star_mass.in_(units.MSun))
                s.mass = minimum_sink_mass
            else:
                s.mass -= total_star_mass_nearby
        else:
            excess_star_mass += total_star_mass_nearby
            logger.info(
                'Sink mass is already <= minimum mass allowed; '
                'excess mass is now %s', excess_star_mass.in_(units.MSun))

        logger.info('Sink mass after reduction: %s', s.mass.in_(units.MSun))

    # Reduce all sinks in group equally with the excess star mass
    logger.info('Reducing all sink mass equally with excess star mass...')
    mass_ratio = 1 - excess_star_mass / group.total_mass()
    group.mass *= mass_ratio

    logger.info("Total sink mass in group: %s",
                group.total_mass().in_(units.MSun))

    if shrink_sinks:
        # Radius of a sphere holding the new mass at the initial density.
        group.radius = ((group.mass / group.initial_density) /
                        (4 / 3 * numpy.pi))**(1 / 3)
        logger.info("New radii: %s", group.radius.in_(units.pc))

    return new_stars