Example #1
 def test_trace_with_geometric_objects_1(self):
     """ Only one of the geometric objects has a transformation applied.
     """
     root = Node(name="Root", geometry=Sphere(radius=10.0))
     a = Node(name="A", parent=root, geometry=Sphere(radius=1.0))
     b = Node(name="B", parent=root, geometry=Sphere(radius=1.0))
     b.translate((5.0, 0.0, 0.0))
     scene = Scene(root)
     tracer = PhotonTracer(scene)
     position = (-3.0, 0.0, 0.0)
     direction = (1.0, 0.0, 0.0)
     initial_ray = Ray(
         position=position, direction=direction, wavelength=555.0, is_alive=True
     )
     expected_history = [
         initial_ray,  # Starting ray
         replace(initial_ray, position=(-1.0, 0.0, 0.0)), # Moved to intersection
         replace(initial_ray, position=(1.0, 0.0, 0.0)),  # Moved to intersection
         replace(initial_ray, position=(4.0, 0.0, 0.0)),  # Moved to intersection
         replace(initial_ray, position=(6.0, 0.0, 0.0)),  # Moved to intersection
         replace(initial_ray, position=(10.0, 0.0, 0.0), is_alive=False),  # Exit ray
     ]
     history = tracer.follow(initial_ray)
     for result, expected in zip(history, expected_history):
         # Avoid shadowing the scene nodes `a` and `b` defined above.
         print("Testing {} {}".format(result.position, expected.position))
         assert np.allclose(result.position, expected.position)
Example #2
def structural(p):
    print(len(p))
    dataclasses.fields(p)

    dataclasses.asdict(p)
    dataclasses.astuple(p)
    dataclasses.replace(p)
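
The structural() helper above just exercises the dataclasses introspection functions. For reference, a minimal, self-contained sketch of what those calls return for an ordinary dataclass instance (the Point class is illustrative, not part of the example above):

from dataclasses import dataclass, fields, asdict, astuple, replace

@dataclass
class Point:
    x: int
    y: int = 0

p = Point(1, 2)
print([f.name for f in fields(p)])  # ['x', 'y']
print(asdict(p))                    # {'x': 1, 'y': 2}
print(astuple(p))                   # (1, 2)
print(replace(p, y=5))              # Point(x=1, y=5); p itself is unchanged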
Example #3
 def test_trace_with_material_object(self):
     """ Root node and test object has a material attached.
     """
     np.random.seed(1) # No reflections
     # np.random.seed(2)  # Reflection at last inteface
     root = Node(name="Root", geometry=Sphere(radius=10.0, material=Dielectric.make_constant((400, 800), 1.0)))
     b = Node(name="B", parent=root, geometry=Sphere(radius=1.0, material=Dielectric.make_constant((400, 800), 1.5)))
     b.translate((5.0, 0.0, 0.0))
     scene = Scene(root)
     tracer = PhotonTracer(scene)
     position = (-3.0, 0.0, 0.0)
     direction = (1.0, 0.0, 0.0)
     initial_ray = Ray(
         position=position, direction=direction, wavelength=555.0, is_alive=True
     )
     expected_history = [
         initial_ray,  # Starting ray
         replace(initial_ray, position=(4.0, 0.0, 0.0)), # Moved to intersection
         replace(initial_ray, position=(4.0, 0.0, 0.0)), # Refracted into A
         replace(initial_ray, position=(6.0, 0.0, 0.0)),  # Moved to intersection
         replace(initial_ray, position=(6.0, 0.0, 0.0)),  # Refracted out of A
         replace(initial_ray, position=(10.0, 0.0, 0.0), is_alive=False),  # Exit ray
     ]
     history = tracer.follow(initial_ray)
     for result, expected in zip(history, expected_history):
         # Avoid shadowing the scene node `b` defined above.
         print("Testing {} {}".format(result.position, expected.position))
         assert np.allclose(result.position, expected.position)
Example #4
 def trace_path(self, ray: Ray, container_node: Node, intersection_node: Node, intersection_point: tuple) -> Tuple[Ray, Node]:
     """ Determines if the ray is absorbed or scattered along it's path and returns
     a new ray is these processes occur. If not, or the container node does not have
     a material attached the ray is moved to the next intersection point.
     """
     # Exit early if possible
     if any([node.geometry is None for node in (container_node, intersection_node)]):
         raise TraceError("Node is missing a geometry.")
     elif container_node.geometry.material is None:
         logger.debug("Container node is missing a material. Will propagate ray the full path length.")
         new_ray = replace(ray, position=intersection_point)
         hit_node = intersection_node
         return (new_ray, hit_node)
     
     # Have a proper material
     root = self.scene.root
     distance = distance_between(ray.position, intersection_point)
     volume = make_volume(container_node, distance)
     new_ray = volume.trace(ray)
     position_in_intersection_node_frame = root.point_to_node(new_ray.position, intersection_node)
     if intersection_node.geometry.is_on_surface(position_in_intersection_node_frame):
         # Hit interface between container node and intersection node
         hit_node = intersection_node
     else:
         # Hit molecule in container node (i.e. was absorbed or scattered)
         hit_node = container_node
     return new_ray, hit_node
Example #5
 def _update_row_diff(self, rowid: int, **kwds: Any) -> None:
     row = self.get_row(rowid)
     if not row:
         raise RowLookupError(rowid)
     upd = dataclasses.replace(row, **kwds)
     upd.id = rowid
     upd.state = row.state
     self._diff[rowid] = upd
Example #6
 def test_no_interaction(self):
     np.random.seed(0)
     mat = Dielectric.make_constant((400, 800), 1.0)
     root = Node(name="Root", parent=None)
     root.geometry = Sphere(radius=10.0, material=mat)
     a = Node(name="A", parent=root)
     a.geometry = Sphere(radius=1.0, material=mat)
     ray = Ray(position=(-1.0, 0.0, 0.0), direction=(1.0, 0.0, 0.0), wavelength=555.0, is_alive=True)
     volume = Volume(a, 2.0)
     new_ray = volume.trace(ray)
     expected = replace(ray, position=(1.0, 0.0, 0.0))
     assert new_ray == expected
Example #7
 def transform(self, ray: Ray) -> Ray:
     """ Transform ray according to the physics of the interaction.
     """
     context = self.context
     normal = np.array(context.normal)
     ray_ = ray.representation(context.normal_node.root, context.normal_node)
     vec = np.array(ray_.direction)
     d = np.dot(normal, vec)
     reflected_direction = vec - 2 * d * normal
     new_ray_ = replace(ray_, direction=tuple(reflected_direction.tolist()))
     new_ray = new_ray_.representation(context.normal_node, context.normal_node.root)
     return new_ray  # back to world node
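
The reflection above is the standard specular formula r = v - 2(v·n)n, expressed through replace() on the ray's direction. A minimal NumPy sketch of just the direction calculation, independent of the Ray/Node classes used in the example:

import numpy as np

def reflect(direction, normal):
    """Reflect a direction vector about a surface normal."""
    v = np.asarray(direction, dtype=float)
    n = np.asarray(normal, dtype=float)
    n = n / np.linalg.norm(n)  # normalise, in case the normal is not unit length
    return v - 2.0 * np.dot(v, n) * n

print(reflect((1.0, -1.0, 0.0), (0.0, 1.0, 0.0)))  # [1. 1. 0.]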
Example #8
 def transform(self, ray: Ray) -> Ray:
     """ Transform ray according to the physics of the interaction.
     """
     context = self.context
     material = context.interaction_material
     if not isinstance(material, Emissive):
         AppError("Need an emissive material.")
     new_wavelength = material.redshift_wavelength(ray.wavelength)
     new_direction = material.emission_direction()
     
     logger.debug("Wavelength was {} and is now {}".format(ray.wavelength, new_wavelength))
     new_ray = replace(ray, wavelength=new_wavelength, direction=new_direction)
     return new_ray
Example #9
 def test_trace_with_translated_geometric_object(self):
     """ Single translated geometric objects.
     """
     root = Node(name="Root", geometry=Sphere(radius=10.0))
     a = Node(name="A", parent=root, geometry=Sphere(radius=1.0))
     a.translate((5.0, 0.0, 0.0))
     scene = Scene(root)
     tracer = PhotonTracer(scene)
     position = (-2.0, 0.0, 0.0)
     direction = (1.0, 0.0, 0.0)
     initial_ray = Ray(
         position=position, direction=direction, wavelength=555.0, is_alive=True
     )
     expected_history = [
         initial_ray,  # Starting ray
         replace(initial_ray, position=(4.0, 0.0, 0.0)),  # First intersection
         replace(initial_ray, position=(6.0, 0.0, 0.0)),  # Second intersection
         replace(initial_ray, position=(10.0, 0.0, 0.0), is_alive=False),  # Exit ray
     ]
     history = tracer.follow(initial_ray)
     for pair in zip(history, expected_history):
         assert pair[0] == pair[1]
Example #10
 def test_trace_with_geometric_object(self):
     """ Contains a single object without an attached material: a geometric object. 
     In this case we expect the tracer to just return the intersection points with
     the scene.
     """
     root = Node(name="Root", geometry=Sphere(radius=10.0))
     a = Node(name="A", parent=root, geometry=Sphere(radius=1.0))
     scene = Scene(root)
     tracer = PhotonTracer(scene)
     position = (-2.0, 0.0, 0.0)
     direction = (1.0, 0.0, 0.0)
     initial_ray = Ray(
         position=position, direction=direction, wavelength=555.0, is_alive=True
     )
     expected_history = [
         initial_ray,  # Starting ray
         replace(initial_ray, position=(-1.0, 0.0, 0.0)),  # Moved to intersection
         replace(initial_ray, position=(1.0, 0.0, 0.0)),  # Moved to intersection
         replace(initial_ray, position=(10.0, 0.0, 0.0), is_alive=False),  # Exit ray
     ]
     history = tracer.follow(initial_ray)
     for pair in zip(history, expected_history):
         assert pair[0] == pair[1]
Example #11
 def propagate(self, distance: float) -> Ray:
     """ Returns a new ray which has been moved the specified distance along
     its direction.
     
     Parameters
     ----------
     distance : float
         The distance to move the ray. Can be negative, in which case the new
         ray will be moved backwards.
     """
     if not self.is_alive:
         raise ValueError('Ray is not alive.')
     new_position = np.array(self.position) + np.array(self.direction) * distance
     new_position = tuple(new_position.tolist())
     new_ray = replace(self, position=new_position)
     return new_ray
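
propagate() only works cleanly because the ray is an immutable dataclass, so replace() is the way to "move" it. A minimal sketch of that idea with a simplified, hypothetical frozen Ray (field names copied from these examples, everything else assumed):

from dataclasses import dataclass, replace

import numpy as np

@dataclass(frozen=True)
class Ray:  # simplified stand-in, not the Ray class used in these examples
    position: tuple
    direction: tuple
    wavelength: float
    is_alive: bool = True

    def propagate(self, distance: float) -> "Ray":
        if not self.is_alive:
            raise ValueError("Ray is not alive.")
        new_position = np.array(self.position) + np.array(self.direction) * distance
        # frozen=True forbids mutation, so replace() returns a brand new Ray
        return replace(self, position=tuple(new_position.tolist()))

ray = Ray(position=(-2.0, 0.0, 0.0), direction=(1.0, 0.0, 0.0), wavelength=555.0)
print(ray.propagate(3.0).position)  # (1.0, 0.0, 0.0)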
Example #12
 def test_trace_without_objects(self):
     """ Trace empty scene. Should return intersections with root node.
     """
     root = Node(name="Root", geometry=Sphere(radius=10.0))
     scene = Scene(root)
     tracer = PhotonTracer(scene)
     position = (-2.0, 0.0, 0.0)
     direction = (1.0, 0.0, 0.0)
     initial_ray = Ray(
         position=position, direction=direction, wavelength=555.0, is_alive=True
     )
     expected_history = [
         initial_ray,  # Starting ray
         replace(initial_ray, position=(10.0, 0.0, 0.0), is_alive=False),  # Exit ray
     ]
     history = tracer.follow(initial_ray)
     for pair in zip(history, expected_history):
         assert pair[0] == pair[1]
Example #13
    def _get_login_response_authn(self, ticket: SSOLoginData, user: IdPUser) -> AuthnInfo:
        """
        Figure out what AuthnContext to assert in the SAML response.

        The 'highest' Assurance-Level (AL) asserted is basically min(ID-proofing-AL, Authentication-AL).

        What AuthnContext is asserted is also heavily influenced by what the SP requested.

        :param ticket: State for this request
        :param user: The user for whom the assertion will be made
        :return: Authn information
        """
        self.logger.debug('MFA credentials logged in the ticket: {}'.format(ticket.mfa_action_creds))
        self.logger.debug('External MFA credential logged in the ticket: {}'.format(ticket.mfa_action_external))
        self.logger.debug('Credentials used in this SSO session:\n{}'.format(self.sso_session.authn_credentials))
        self.logger.debug('User credentials:\n{}'.format(user.credentials.to_list()))

        # Decide what AuthnContext to assert based on the one requested in the request
        # and the authentication performed

        req_authn_context = get_requested_authn_context(self.context.idp, ticket.saml_req, self.logger)

        try:
            resp_authn = eduid_idp.assurance.response_authn(req_authn_context, user, self.sso_session, self.logger)
        except WrongMultiFactor as exc:
            self.logger.info('Assurance not possible: {!r}'.format(exc))
            raise eduid_idp.error.Forbidden('SWAMID_MFA_REQUIRED')
        except MissingMultiFactor as exc:
            self.logger.info('Assurance not possible: {!r}'.format(exc))
            raise eduid_idp.error.Forbidden('MFA_REQUIRED')
        except AssuranceException as exc:
            self.logger.info('Assurance not possible: {!r}'.format(exc))
            raise MustAuthenticate()

        self.logger.debug("Response Authn context class: {!r}".format(resp_authn))

        try:
            self.logger.debug("Asserting AuthnContext {!r} (requested: {!r})".format(
                resp_authn, req_authn_context))
        except AttributeError:
            self.logger.debug("Asserting AuthnContext {!r} (none requested)".format(resp_authn))

        # Augment the AuthnInfo with the authn_timestamp before returning it
        return replace(resp_authn, instant=self.sso_session.authn_timestamp)
Example #14
 def representation(self, from_node: Node, to_node: Node) -> Ray:
     """ Representation of the ray in another coordinate system.
     
     Parameters
     ----------
     from_node : Node
         The node which represents the ray's current coordinate system
     to_node : Node
         The node in which the new ray should be represented.
     
     Notes
     -----
     Use this method to express the ray location and direction as viewed in the 
     `to_node` coordinate system.
     """
     new_position = from_node.point_to_node(self.position, to_node)
     new_direction = from_node.vector_to_node(self.direction, to_node)
     new_ray = replace(self, position=new_position, direction=new_direction)
     return new_ray
Example #15
 def transform(self, ray: Ray) -> Ray:
     """ Transform ray according to the physics of the interaction.
     """
     context = self.context
     n1 = context.n1
     n2 = context.n2
     ray_ = ray.representation(context.normal_node.root, context.normal_node)
     normal = np.array(context.normal)
     vector = np.array(ray_.direction)
     n = n1/n2
     dot = np.dot(vector, normal)
     c = np.sqrt(1 - n**2 * (1 - dot**2))
     sign = 1
     if dot < 0.0:
         sign = -1
     refracted_direction = n * vector + sign*(c - sign*n*dot) * normal
     new_ray_ = replace(ray_, direction=tuple(refracted_direction.tolist()))
     new_ray = new_ray_.representation(context.normal_node, context.normal_node.root)
     return new_ray
Example #16
    def step(self, ray: Ray) -> Ray:
        """ Steps the ray one event forward.
        """
        ctx = self.make_step_context(ray)
        logger.debug('Ray is in {}.'.format(ctx.container))
        initial_ray = ray
        next_intersection = ctx.next_intersection()
        logger.debug("Tracing along path.")
        ray, hit_node = self.trace_path(ray, ctx.container, next_intersection.hit, next_intersection.point)
        if hit_node == self.scene.root:            
            ray = replace(ray, is_alive=False)
            logger.debug("Ray died {}".format(ray))
            yield ray
        else:
            logger.debug('Ray moved and hit or is inside {}.'.format(hit_node))
            yield ray # yield ray motion step
            ray_on_hit_node = hit_node.geometry.is_on_surface(self.scene.root.point_to_node(ray.position, hit_node))
            if ray_on_hit_node:
                if np.allclose(ray.direction, initial_ray.direction):
                    interface = (ctx.container, hit_node)
                    if ctx.container == hit_node:
                        interface = (ctx.container, ctx.all_intersections[ctx.next_index+1].hit)
                else:
                    logger.debug('Direction changed. Re-calculating intersections.')
                    ctx = self.make_step_context(ray)
                    interface = (ctx.container, ctx.next_intersection().hit)
                
                logger.debug("Interface {}".format(interface))

                # Check that a material exists on both sides of the interface
                has_material = [side.geometry.material is not None for side in interface]
                if all(has_material):
                    logger.debug("Tracing interface.")
                    ray = self.trace_interface(ray, interface)
                    yield ray # yield ray interface step and exit
                elif all(np.logical_not(has_material)):
                    logger.debug("Interface does not have materials.")
                elif len(set(has_material)) == 2:
                    logger.debug(traceback.format_exc())
                    raise TraceError("Both interface nodes must have a material. At interface {} {}".format(interface[0].geometry.material, interface[1].geometry.material))
Example #17
 def refresh_settings(self, ovrIP=False):
     self.settings = dataclasses.replace(
         self.settings, **ValidateSettings.get_current_settings(self.launcher, ovrIP=ovrIP))
     self.ipPortCombo = f'{self.settings.PublicIP}:{self.settings.Port}'
Example #18
 def clamp(message: AnalysisMessage):
     if not samefile(message.filename, cls_file):
         return replace(message, filename=cls_file, line=cls_start_line)
     else:
         return message
Example #19
 def transform(self, ray: Ray) -> Ray:
     new_ray = replace(ray, is_alive=False)
     return new_ray
Example #20
    async def main_writer(self):
        """
        Main coroutine for submitting commands.
        """
        LOG.info("Writer loop for party %s is starting...", self.party)
        ledger_fut = ensure_future(self._pool.ledger())

        client = await self._client_fut  # type: LedgerClient
        metadata = await ledger_fut  # type: LedgerMetadata
        validator = ValidateSerializer(self.parent.lookup)

        self._writer.pending_commands.start()

        # Asynchronously iterate over all pending commands that a user has sent (or will send).
        # This asynchronous loop "blocks" when pending_commands is empty and is "woken up"
        # immediately when new data is added to it. The loop terminates when the pending_commands
        # ServiceQueue is stopped.
        async for p in self._writer.pending_commands:
            LOG.debug("Sending a command: %s", p)
            command_payloads = [
            ]  # type: List[Tuple[_PendingCommand, Sequence[CommandPayload]]]

            if p.future.done():
                # PendingCommand instances that are already marked as done have either been marked
                # as a failure or cancelled by the caller. Do NOT send the corresponding Ledger API
                # command because the PendingCommand() has effectively been aborted.
                continue

            self._writer.inflight_commands.append(p)
            try:
                defaults = CommandDefaults(
                    default_party=self.party,
                    default_ledger_id=metadata.ledger_id,
                    default_workflow_id=None,
                    default_application_id=self._config.application_name,
                    default_command_id=None,
                )
                cps = p.command.build(defaults)
                if cps:
                    commands = await metadata.package_loader.do_with_retry(
                        lambda: [
                            replace(cp,
                                    commands=validator.serialize_commands(
                                        cp.commands)) for cp in cps
                        ])
                    command_payloads.append((p, commands))
                    await submit_command_async(client, p, commands)
                else:
                    # This is a "null command"; don't even bother sending to the server. Immediately
                    # resolve the future successfully and discard
                    if not p.future.done():
                        p.future.set_result(None)
            except Exception as ex:
                LOG.exception("Tried to send a command and failed!")
                p.notify_read_fail(ex)

        LOG.info("Writer loop for party %s is winding down.", self.party)

        # After the pending command list is fully empty (and never to be filled again), wait for
        # all outstanding commands.
        done, pending = await wait(
            [pc.future for pc in self._writer.inflight_commands],
            timeout=5,
            return_when=ALL_COMPLETED,
        )

        if pending:
            LOG.warning(
                "Writer loop for party %s has NOT fully finished, "
                "but will be terminated anyway (%d futures still pending).",
                self.party,
                len(pending),
            )
        else:
            LOG.info("Writer loop for party %s is finished.", self.party)
Example #21
def union2(p: Union[Type[A], Type[B]]):
    dataclasses.fields(p)

    dataclasses.asdict(<warning descr="'dataclasses.asdict' method should be called on dataclass instances">p</warning>)
    dataclasses.astuple(<warning descr="'dataclasses.astuple' method should be called on dataclass instances">p</warning>)
    dataclasses.replace(<warning descr="'dataclasses.replace' method should be called on dataclass instances">p</warning>)
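
The warnings in the snippet above point at a runtime pitfall: dataclasses.replace (like asdict and astuple) only accepts dataclass instances, while dataclasses.fields also accepts the class itself. A minimal sketch of the failure mode, with an illustrative class:

import dataclasses

@dataclasses.dataclass
class Item:
    x: int = 0

print(dataclasses.replace(Item(1), x=2))  # fine: Item(x=2)

try:
    dataclasses.replace(Item, x=2)        # the class itself, not an instance
except TypeError as exc:
    print(exc)                            # replace() should be called on dataclass instances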
Example #22
async def get_helm_chart(request: HelmChartRequest,
                         subsystem: HelmSubsystem) -> HelmChart:
    dependencies, source_files, metadata = await MultiGet(
        Get(Targets, DependenciesRequest(request.field_set.dependencies)),
        Get(
            HelmChartSourceFiles,
            HelmChartSourceFilesRequest,
            HelmChartSourceFilesRequest.for_field_set(
                request.field_set,
                include_metadata=False,
                include_resources=True,
                include_files=True,
            ),
        ),
        Get(HelmChartMetadata, HelmChartMetaSourceField,
            request.field_set.chart),
    )

    third_party_artifacts = await Get(
        FetchedHelmArtifacts,
        FetchHelmArfifactsRequest,
        FetchHelmArfifactsRequest.for_targets(
            dependencies,
            description_of_origin=request.field_set.address.spec),
    )

    first_party_subcharts = await MultiGet(
        Get(HelmChart, HelmChartRequest, HelmChartRequest.from_target(target))
        for target in dependencies if HelmChartFieldSet.is_applicable(target))
    third_party_charts = await MultiGet(
        Get(HelmChart, FetchedHelmArtifact, artifact)
        for artifact in third_party_artifacts)

    subcharts = [*first_party_subcharts, *third_party_charts]
    subcharts_digest = EMPTY_DIGEST
    if subcharts:
        logger.debug(
            f"Found {pluralize(len(subcharts), 'subchart')} as direct dependencies on Helm chart at: {request.field_set.address}"
        )

        merged_subcharts = await Get(
            Digest,
            MergeDigests([chart.snapshot.digest for chart in subcharts]))
        subcharts_digest = await Get(Digest,
                                     AddPrefix(merged_subcharts, "charts"))

        # Update subchart dependencies in the metadata and re-render it.
        remotes = subsystem.remotes()
        subchart_map: dict[str, HelmChart] = {
            chart.metadata.name: chart
            for chart in subcharts
        }
        updated_dependencies: OrderedSet[HelmChartDependency] = OrderedSet()
        for dep in metadata.dependencies:
            updated_dep = dep

            if not dep.repository and remotes.default_registry:
                # If the dependency hasn't specified a repository, then we choose the registry with the 'default' alias.
                default_remote = remotes.default_registry
                updated_dep = dataclasses.replace(
                    updated_dep, repository=default_remote.address)
            elif dep.repository and dep.repository.startswith("@"):
                remote = next(remotes.get(dep.repository))
                updated_dep = dataclasses.replace(updated_dep,
                                                  repository=remote.address)

            if dep.name in subchart_map:
                updated_dep = dataclasses.replace(
                    updated_dep,
                    version=subchart_map[dep.name].metadata.version)

            updated_dependencies.add(updated_dep)

        # Include the explicitly provided subcharts in the set of dependencies if not already present.
        updated_dependencies_names = {dep.name for dep in updated_dependencies}
        remaining_subcharts = [
            chart for chart in subcharts
            if chart.metadata.name not in updated_dependencies_names
        ]
        for chart in remaining_subcharts:
            if chart.artifact:
                dependency = HelmChartDependency(
                    name=chart.artifact.name,
                    version=chart.artifact.version,
                    repository=chart.artifact.location_url,
                )
            else:
                dependency = HelmChartDependency(
                    name=chart.metadata.name, version=chart.metadata.version)
            updated_dependencies.add(dependency)

        # Update metadata with the information about charts' dependencies.
        metadata = dataclasses.replace(
            metadata, dependencies=tuple(updated_dependencies))

    # Re-render the Chart.yaml file with the updated dependencies.
    metadata_digest, sources_without_metadata = await MultiGet(
        Get(Digest, HelmChartMetadata, metadata),
        Get(
            Digest,
            DigestSubset(
                source_files.snapshot.digest,
                PathGlobs([
                    "**/*", *(f"!**/{filename}"
                              for filename in HELM_CHART_METADATA_FILENAMES)
                ]),
            ),
        ),
    )

    # Merge all digests that conform chart's content.
    content_digest = await Get(
        Digest,
        MergeDigests(
            [metadata_digest, sources_without_metadata, subcharts_digest]))

    chart_snapshot = await Get(Snapshot,
                               AddPrefix(content_digest, metadata.name))
    return HelmChart(address=request.field_set.address,
                     metadata=metadata,
                     snapshot=chart_snapshot)
Example #23
 def replace(self, **changes) -> 'CircuitOperation':
     """Returns a copy of this operation with the specified changes."""
     return dataclasses.replace(self, **changes)
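
Wrapping dataclasses.replace in a method, as CircuitOperation does here (and Examples 25 and 26 do below), is a common convenience on frozen dataclasses: callers get a copy-with-changes API without importing dataclasses themselves. A minimal sketch of the pattern with an illustrative class (not cirq's actual CircuitOperation):

import dataclasses

@dataclasses.dataclass(frozen=True)
class Operation:
    repetitions: int = 1
    tag: str = ""

    def replace(self, **changes) -> "Operation":
        """Returns a copy of this operation with the specified changes."""
        return dataclasses.replace(self, **changes)

op = Operation()
print(op.replace(repetitions=3))  # Operation(repetitions=3, tag='')
print(op)                         # Operation(repetitions=1, tag='') -- unchanged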
Example #24
def decode_single(player_index: int, num_players: int, game_modifications: dict,
                  configuration: LayoutConfiguration) -> GamePatches:
    """
    Decodes a dict created by `serialize` back into a GamePatches.
    :param game_modifications:
    :param player_index:
    :param num_players:
    :param configuration:
    :return:
    """
    game = data_reader.decode_data(configuration.game_data)
    game_specific = dataclasses.replace(
        game.game_specific,
        energy_per_tank=configuration.energy_per_tank,
        beam_configurations=configuration.beam_configuration.create_game_specific(game.resource_database))

    world_list = game.world_list

    # Starting Location
    starting_location = _area_name_to_area_location(world_list, game_modifications["starting_location"])

    # Initial items
    starting_items = {
        find_resource_info_with_long_name(game.resource_database.item, resource_name): quantity
        for resource_name, quantity in game_modifications["starting_items"].items()
    }

    # Elevators
    elevator_connection = {}
    for source_name, target_name in game_modifications["elevators"].items():
        source_area = _area_name_to_area_location(world_list, source_name)
        target_area = _area_name_to_area_location(world_list, target_name)

        potential_source_nodes = [
            node
            for node in world_list.area_by_area_location(source_area).nodes
            if isinstance(node, TeleporterNode)
        ]
        assert len(potential_source_nodes) == 1
        source_node = potential_source_nodes[0]
        elevator_connection[source_node.teleporter_instance_id] = target_area

    # Translator Gates
    translator_gates = {
        _find_gate_with_name(gate_name): find_resource_info_with_long_name(game.resource_database.item, resource_name)
        for gate_name, resource_name in game_modifications["translators"].items()
    }

    # Pickups
    target_name_re = re.compile(r"(.*) for Player \d+")

    index_to_pickup_name = {}
    for world_name, world_data in game_modifications["locations"].items():
        for area_node_name, target_name in world_data.items():
            if target_name == _ETM_NAME:
                continue

            pickup_name_match = target_name_re.match(target_name)
            if pickup_name_match is not None:
                pickup_name = pickup_name_match.group(1)
            else:
                pickup_name = target_name

            node = world_list.node_from_name(f"{world_name}/{area_node_name}")
            assert isinstance(node, PickupNode)
            index_to_pickup_name[node.pickup_index] = pickup_name

    decoder = BitPackDecoder(base64.b64decode(game_modifications["_locations_internal"].encode("utf-8"), validate=True))
    pickup_assignment = dict(BitPackPickupEntryList.bit_pack_unpack(decoder, {
        "index_mapping": index_to_pickup_name,
        "num_players": num_players,
        "database": game.resource_database,
    }).value)

    # Hints
    hints = {}
    for asset_id, hint in game_modifications["hints"].items():
        hints[LogbookAsset(int(asset_id))] = Hint.from_json(hint)

    return GamePatches(
        player_index=player_index,
        pickup_assignment=pickup_assignment,  # PickupAssignment
        elevator_connection=elevator_connection,  # Dict[int, AreaLocation]
        dock_connection={},  # Dict[Tuple[int, int], DockConnection]
        dock_weakness={},  # Dict[Tuple[int, int], DockWeakness]
        translator_gates=translator_gates,
        starting_items=starting_items,  # ResourceGainTuple
        starting_location=starting_location,  # AreaLocation
        hints=hints,
        game_specific=game_specific,
    )
Example #25
 def _replace(self, **kwargs):
     return replace(self, **kwargs)
Example #26
File: mlir.py Project: 0x0is1/jax
 def replace(self, **kw):
     return dataclasses.replace(self, **kw)
Example #27
from dataclasses import dataclass, field, InitVar, replace


@dataclass
class A:
    a: int
    b: str = "str"


replace(A(1))
replace(A(1), a=1, b="abc")
replace(A(1), <warning descr="Expected type 'int', got 'str' instead">a="str"</warning>, <warning descr="Expected type 'str', got 'int' instead">b=1</warning>)


@dataclass
class B:
    a: int
    b: str = field(default="str", init=False)


replace(B(1))
replace(B(1), a=1)
replace(B(1), <warning descr="Expected type 'int', got 'str' instead">a="str"</warning>)


@dataclass
class C:
    a: int
    b: InitVar[str] = "str"
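
Beyond the type-checker warnings shown above, class B also illustrates a runtime restriction worth remembering: replace() re-runs __init__, so a field declared with init=False cannot be supplied as a change and is re-initialised from its default. A minimal sketch of that behaviour (separate from the fixture code above):

from dataclasses import dataclass, field, replace

@dataclass
class Row:
    a: int
    b: str = field(default="str", init=False)

row = Row(1)
print(replace(row, a=2))     # Row(a=2, b='str')

try:
    replace(row, b="other")  # b is declared init=False, so replace() rejects it
except ValueError as exc:
    print(exc)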

Example #28
    def mangle_column_value(exp: Expression) -> Expression:
        if not isinstance(exp, Column):
            return exp

        return replace(exp, column_name=f"{alias_prefix}{exp.column_name}")
Example #29
    def mangle_aliases(exp: Expression) -> Expression:
        alias = exp.alias
        if alias is not None:
            return replace(exp, alias=f"{alias_prefix}{alias}")

        return exp
Example #30
 def command_args(self, args: ty.Dict):
     self.inputs = dc.replace(self.inputs, **args)
Example #31
def main():
    parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments,
                               DynamicTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    if 'prompt' in model_args.few_shot_type:
        data_args.prompt = True

    if training_args.no_train:
        training_args.do_train = False
    if training_args.no_predict:
        training_args.do_predict = False

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if training_args.local_rank in [-1, 0] else logging.WARN,
    )

    # Load prompt/template/mapping file
    if data_args.prompt:
        if data_args.prompt_path is not None:
            assert data_args.prompt_id is not None
            prompt_list = []
            with open(data_args.prompt_path) as f:
                for line in f:
                    line = line.strip()
                    template, mapping = line.split('\t')
                    prompt_list.append((template, mapping))

            data_args.template, data_args.mapping = prompt_list[
                data_args.prompt_id]
            logger.info(
                "Specify load the %d-th prompt: %s | %s" %
                (data_args.prompt_id, data_args.template, data_args.mapping))
        else:
            if data_args.template_path is not None:
                with open(data_args.template_path) as f:
                    data_args.template_list = []
                    for line in f:
                        line = line.strip()
                        if len(line) > 0:
                            data_args.template_list.append(line)

                # Load top-n templates
                if data_args.top_n_template is not None:
                    data_args.template_list = data_args.template_list[:data_args.top_n_template]
                logger.info(
                    "Load top-%d templates from %s" %
                    (len(data_args.template_list), data_args.template_path))

                # ... or load i-th template
                if data_args.template_id is not None:
                    data_args.template = data_args.template_list[
                        data_args.template_id]
                    data_args.template_list = None
                    logger.info("Specify load the %d-th template: %s" %
                                (data_args.template_id, data_args.template))

            if data_args.mapping_path is not None:
                assert data_args.mapping_id is not None  # Can only use one label word mapping
                with open(data_args.mapping_path) as f:
                    mapping_list = []
                    for line in f:
                        line = line.strip()
                        mapping_list.append(line)

                data_args.mapping = mapping_list[data_args.mapping_id]
                logger.info("Specify using the %d-th mapping: %s" %
                            (data_args.mapping_id, data_args.mapping))

    # Check save path
    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists.")

    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        num_labels = num_labels_mapping[data_args.task_name]
        output_mode = output_modes_mapping[data_args.task_name]
        logger.info(
            "Task name: {}, number of labels: {}, output mode: {}".format(
                data_args.task_name, num_labels, output_mode))
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    # Automatically generate template for using demonstrations
    if data_args.auto_demo and model_args.few_shot_type == 'prompt-demo':
        # GPT-3's in-context learning
        if data_args.gpt3_in_context_head or data_args.gpt3_in_context_tail:
            logger.info(
                "Automatically convert the template to GPT-3's in-context learning."
            )
            assert data_args.template_list is None

            old_template = data_args.template
            new_template = old_template + ''
            old_template = old_template.replace('*cls*', '')
            # Single sentence or sentence pair?
            sent_num = 1
            if "_1" in old_template:
                sent_num = 2
            for instance_id in range(data_args.gpt3_in_context_num):
                sub_template = old_template + ''
                # Replace sent_id
                for sent_id in range(sent_num):
                    sub_template = sub_template.replace(
                        "_{}*".format(sent_id),
                        "_{}*".format(sent_num + sent_num * instance_id +
                                      sent_id))
                # Replace mask
                sub_template = sub_template.replace(
                    "*mask*", "*labelx_{}*".format(instance_id))
                if data_args.gpt3_in_context_tail:
                    new_template = new_template + sub_template  # Put context at the end
                else:
                    new_template = sub_template + new_template  # Put context at the beginning
            logger.info("| {} => {}".format(data_args.template, new_template))
            data_args.template = new_template
        else:
            logger.info(
                "Automatically convert the template to using demonstrations.")
            if data_args.template_list is not None:
                for i in range(len(data_args.template_list)):
                    old_template = data_args.template_list[i]
                    new_template = old_template + ''
                    old_template = old_template.replace('*cls*', '')
                    # Single sentence or sentence pair?
                    sent_num = 1
                    if "_1" in old_template:
                        sent_num = 2
                    for label_id in range(num_labels):
                        sub_template = old_template + ''
                        # Replace sent id
                        for sent_id in range(sent_num):
                            sub_template = sub_template.replace(
                                "_{}*".format(sent_id),
                                "_{}*".format(sent_num + sent_num * label_id +
                                              sent_id))
                        # Replace mask
                        sub_template = sub_template.replace(
                            "*mask*", "*label_{}*".format(label_id))
                        new_template = new_template + sub_template
                    logger.info("| {} => {}".format(data_args.template_list[i],
                                                    new_template))
                    data_args.template_list[i] = new_template
            else:
                old_template = data_args.template
                new_template = old_template + ''
                old_template = old_template.replace('*cls*', '')
                # Single sentence or sentence pair?
                sent_num = 1
                if "_1" in old_template:
                    sent_num = 2
                for label_id in range(num_labels):
                    sub_template = old_template + ''
                    # Replace sent id
                    for sent_id in range(sent_num):
                        sub_template = sub_template.replace(
                            "_{}".format(sent_id),
                            "_{}".format(sent_num + sent_num * label_id +
                                         sent_id))
                    # Replace mask
                    sub_template = sub_template.replace(
                        "*mask*", "*label_{}*".format(label_id))
                    new_template = new_template + sub_template
                logger.info("| {} => {}".format(data_args.template,
                                                new_template))
                data_args.template = new_template

    # Create config
    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )

    if 'prompt' in model_args.few_shot_type:
        if config.model_type == 'roberta':
            model_fn = RobertaForPromptFinetuning
        elif config.model_type == 'bert':
            model_fn = BertForPromptFinetuning
        else:
            raise NotImplementedError
    elif model_args.few_shot_type == 'finetune':
        model_fn = AutoModelForSequenceClassification
    else:
        raise NotImplementedError
    special_tokens = []

    # Create tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        additional_special_tokens=special_tokens,
        cache_dir=model_args.cache_dir,
    )

    # Get our special datasets.
    train_dataset = (FewShotDataset(data_args,
                                    tokenizer=tokenizer,
                                    mode="train",
                                    use_demo=("demo"
                                              in model_args.few_shot_type)))
    eval_dataset = (FewShotDataset(data_args,
                                   tokenizer=tokenizer,
                                   mode="dev",
                                   use_demo=("demo"
                                             in model_args.few_shot_type))
                    if training_args.do_eval else None)
    test_dataset = (FewShotDataset(data_args,
                                   tokenizer=tokenizer,
                                   mode="test",
                                   use_demo=("demo"
                                             in model_args.few_shot_type))
                    if training_args.do_predict else None)

    set_seed(training_args.seed)

    model = model_fn.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # For BERT, increase the size of the segment (token type) embeddings
    if config.model_type == 'bert':
        model.resize_token_embeddings(len(tokenizer))
        resize_token_type_embeddings(model,
                                     new_num_types=10,
                                     random_segment=model_args.random_segment)

    # Pass dataset and argument information to the model
    if data_args.prompt:
        model.label_word_list = torch.tensor(
            train_dataset.label_word_list).long().cuda()
    if output_modes_mapping[data_args.task_name] == 'regression':
        # lower / upper bounds
        model.lb, model.ub = bound_mapping[data_args.task_name]
    model.model_args = model_args
    model.data_args = data_args
    model.tokenizer = tokenizer

    # Build metric
    def build_compute_metrics_fn(
            task_name: str) -> Callable[[EvalPrediction], Dict]:
        def compute_metrics_fn(p: EvalPrediction):
            # Note: the eval dataloader is sequential, so the examples are in order.
            # We average the logits over each sample for using demonstrations.
            predictions = p.predictions
            num_logits = predictions.shape[-1]
            logits = predictions.reshape(
                [eval_dataset.num_sample, -1, num_logits])
            logits = logits.mean(axis=0)

            if num_logits == 1:
                preds = np.squeeze(logits)
            else:
                preds = np.argmax(logits, axis=1)

            # Just for sanity, assert label ids are the same.
            label_ids = p.label_ids.reshape([eval_dataset.num_sample, -1])
            label_ids_avg = label_ids.mean(axis=0)
            label_ids_avg = label_ids_avg.astype(p.label_ids.dtype)
            assert (label_ids_avg - label_ids[0]).mean() < 1e-2
            label_ids = label_ids[0]

            return compute_metrics_mapping[task_name](task_name, preds,
                                                      label_ids)

        return compute_metrics_fn

    # Initialize our Trainer
    trainer = Trainer(model=model,
                      args=training_args,
                      train_dataset=train_dataset,
                      eval_dataset=eval_dataset,
                      compute_metrics=build_compute_metrics_fn(
                          data_args.task_name))

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path
                      if os.path.isdir(model_args.model_name_or_path) else None)
        # Early stopping is used, so do not save the model at the end (unless save_at_last is specified)
        if training_args.save_at_last:
            trainer.save_model(training_args.output_dir)

        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)
            torch.save(
                model_args,
                os.path.join(training_args.output_dir, "model_args.bin"))
            torch.save(data_args,
                       os.path.join(training_args.output_dir, "data_args.bin"))

        # Reload the best checkpoint (for eval)
        model = model_fn.from_pretrained(training_args.output_dir)
        model = model.to(training_args.device)
        trainer.model = model
        if data_args.prompt:
            model.label_word_list = torch.tensor(
                train_dataset.label_word_list).long().cuda()
        if output_modes_mapping[data_args.task_name] == 'regression':
            # lower / upper bounds
            model.lb, model.ub = bound_mapping[data_args.task_name]
        model.model_args = model_args
        model.data_args = data_args
        model.tokenizer = tokenizer

    # Evaluation
    final_result = {
        'time': str(datetime.today()),
    }

    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Validate ***")

        eval_datasets = [eval_dataset]

        for eval_dataset in eval_datasets:
            trainer.compute_metrics = build_compute_metrics_fn(
                eval_dataset.args.task_name)
            output = trainer.evaluate(eval_dataset=eval_dataset)
            eval_result = output.metrics

            output_eval_file = os.path.join(
                training_args.output_dir,
                f"eval_results_{eval_dataset.args.task_name}.txt")
            if trainer.is_world_master():
                with open(output_eval_file, "w") as writer:
                    logger.info("***** Eval results {} *****".format(
                        eval_dataset.args.task_name))
                    for key, value in eval_result.items():
                        logger.info("  %s = %s", key, value)
                        writer.write("%s = %s\n" % (key, value))
                        final_result[eval_dataset.args.task_name + '_dev_' +
                                     key] = value
            eval_results.update(eval_result)

    test_results = {}
    if training_args.do_predict:
        logging.info("*** Test ***")
        test_datasets = [test_dataset]
        if data_args.task_name == "mnli":
            mnli_mm_data_args = dataclasses.replace(data_args,
                                                    task_name="mnli-mm")
            test_datasets.append(
                FewShotDataset(mnli_mm_data_args,
                               tokenizer=tokenizer,
                               mode="test",
                               use_demo=('demo' in model_args.few_shot_type)))

        for test_dataset in test_datasets:
            trainer.compute_metrics = build_compute_metrics_fn(
                test_dataset.args.task_name)
            output = trainer.evaluate(eval_dataset=test_dataset)
            test_result = output.metrics

            output_test_file = os.path.join(
                training_args.output_dir,
                f"test_results_{test_dataset.args.task_name}.txt")
            if trainer.is_world_master():
                with open(output_test_file, "w") as writer:
                    logger.info("***** Test results {} *****".format(
                        test_dataset.args.task_name))
                    for key, value in test_result.items():
                        logger.info("  %s = %s", key, value)
                        writer.write("%s = %s\n" % (key, value))
                        final_result[test_dataset.args.task_name + '_test_' +
                                     key] = value

                if training_args.save_logit:
                    predictions = output.predictions
                    num_logits = predictions.shape[-1]
                    logits = predictions.reshape(
                        [test_dataset.num_sample, -1, num_logits]).mean(axis=0)
                    np.save(
                        os.path.join(
                            training_args.save_logit_dir,
                            "{}-{}-{}.npy".format(test_dataset.task_name,
                                                  training_args.model_id,
                                                  training_args.array_id)),
                        logits)

            test_results.update(test_result)

    with FileLock('log.lock'):
        with open('log', 'a') as f:
            final_result.update(vars(model_args))
            final_result.update(vars(training_args))
            final_result.update(vars(data_args))
            if 'evaluation_strategy' in final_result:
                final_result.pop('evaluation_strategy')
            f.write(str(final_result) + '\n')

    return eval_results
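
Buried in this training script, the one call relevant to this listing is dataclasses.replace(data_args, task_name="mnli-mm"): it derives a second evaluation configuration from the existing arguments object without mutating it. A minimal sketch of that pattern with a made-up arguments dataclass (DataArguments here is hypothetical, not the Hugging Face class used above):

import dataclasses

@dataclasses.dataclass
class DataArguments:
    task_name: str
    max_seq_length: int = 128

data_args = DataArguments(task_name="mnli")
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
print(data_args.task_name, mnli_mm_data_args.task_name)  # mnli mnli-mm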
Example #32
def analyze_calltree(fn: Callable, options: AnalysisOptions,
                     conditions: Conditions) -> CallTreeAnalysis:
    debug('Begin analyze calltree ', fn.__name__)

    all_messages = MessageCollector()
    search_root = SinglePathNode(True)
    space_exhausted = False
    failing_precondition: Optional[
        ConditionExpr] = conditions.pre[0] if conditions.pre else None
    failing_precondition_reason: str = ''
    num_confirmed_paths = 0

    cur_space: List[StateSpace] = [cast(StateSpace, None)]
    short_circuit = ShortCircuitingContext(lambda: cur_space[0])
    _ = get_subclass_map()  # ensure loaded
    top_analysis: Optional[CallAnalysis] = None
    enforced_conditions = EnforcedConditions(
        fn_globals(fn),
        builtin_patches(),
        interceptor=short_circuit.make_interceptor)

    def in_symbolic_mode():
        return (cur_space[0] is not None
                and not cur_space[0].running_framework_code)

    patched_builtins = Patched(_PATCH_REGISTRATIONS, in_symbolic_mode)
    with enforced_conditions, patched_builtins, enforced_conditions.disabled_enforcement(
    ):
        for i in itertools.count(1):
            start = time.time()
            if start > options.deadline:
                debug('Exceeded condition timeout, stopping')
                break
            options.incr('num_paths')
            debug('Iteration ', i)
            space = TrackingStateSpace(
                execution_deadline=start + options.per_path_timeout,
                model_check_timeout=options.per_path_timeout / 2,
                search_root=search_root)
            cur_space[0] = space
            try:
                # The real work happens here!:
                call_analysis = attempt_call(conditions, space, fn,
                                             short_circuit,
                                             enforced_conditions)
                if failing_precondition is not None:
                    cur_precondition = call_analysis.failing_precondition
                    if cur_precondition is None:
                        if call_analysis.verification_status is not None:
                            # We escaped all of the preconditions on this try:
                            failing_precondition = None
                    elif (cur_precondition.line == failing_precondition.line
                          and call_analysis.failing_precondition_reason):
                        failing_precondition_reason = call_analysis.failing_precondition_reason
                    elif cur_precondition.line > failing_precondition.line:
                        failing_precondition = cur_precondition
                        failing_precondition_reason = call_analysis.failing_precondition_reason

            except UnexploredPath:
                call_analysis = CallAnalysis(VerificationStatus.UNKNOWN)
            except IgnoreAttempt:
                call_analysis = CallAnalysis()
            status = call_analysis.verification_status
            if status == VerificationStatus.CONFIRMED:
                num_confirmed_paths += 1
            top_analysis, space_exhausted = space.bubble_status(call_analysis)
            overall_status = top_analysis.verification_status if top_analysis else None
            debug('Iter complete. Worst status found so far:',
                  overall_status.name if overall_status else 'None')
            if space_exhausted or top_analysis == VerificationStatus.REFUTED:
                break
    top_analysis = search_root.child.get_result()
    if top_analysis.messages:
        #log = space.execution_log()
        all_messages.extend(
            replace(
                m,
                #execution_log=log,
                test_fn=fn.__qualname__,
                condition_src=conditions.post[0].expr_source)
            for m in top_analysis.messages)
    if top_analysis.verification_status is None:
        top_analysis.verification_status = VerificationStatus.UNKNOWN
    if failing_precondition:
        assert num_confirmed_paths == 0
        addl_ctx = ' ' + failing_precondition.addl_context if failing_precondition.addl_context else ''
        message = f'Unable to meet precondition {addl_ctx}'
        if failing_precondition_reason:
            message += f' (possibly because {failing_precondition_reason}?)'
        all_messages.extend([
            AnalysisMessage(MessageType.PRE_UNSAT, message,
                            failing_precondition.filename,
                            failing_precondition.line, 0, '')
        ])
        top_analysis = CallAnalysis(VerificationStatus.REFUTED)

    assert top_analysis.verification_status is not None
    debug(
        ('Exhausted' if space_exhausted else 'Aborted'),
        ' calltree search with', top_analysis.verification_status.name, 'and',
        len(all_messages.get()), 'messages.', 'Number of iterations: ', i)
    return CallTreeAnalysis(
        messages=all_messages.get(),
        verification_status=top_analysis.verification_status,
        num_confirmed_paths=num_confirmed_paths)
Example #33
 def visit_lambda(self, exp: Lambda) -> Expression:
     self.__level += 1
     res = replace(exp, transformation=exp.transformation.accept(self))
     self.__level -= 1
     return res
Example #34
    def transform_state(
        cls,
        state: MaskProjectTuple,
        new_roi_info: ROIInfo,
        new_roi_extraction_parameters: typing.Dict[
            int, typing.Optional[ROIExtractionProfile]],
        list_of_components: typing.List[int],
        save_chosen: bool = True,
    ) -> MaskProjectTuple:
        """

        :param MaskProjectTuple state: state to be transformed
        :param ROIInfo new_roi_info: roi description
        :param typing.Dict[int, typing.Optional[ROIExtractionProfile]] new_roi_extraction_parameters:
            Parameters used to extract roi
        :param typing.List[int] list_of_components: list of components from new_roi which should be selected
        :param bool save_chosen: if save currently selected components
        :return: new state
        """

        # TODO Refactor
        if not save_chosen or state.roi_info.roi is None or len(
                state.selected_components) == 0:
            return dataclasses.replace(
                state,
                roi_info=new_roi_info,
                selected_components=list_of_components,
                roi_extraction_parameters={
                    i: new_roi_extraction_parameters[i]
                    for i in new_roi_info.bound_info
                },
            )
        if list_of_components is None:
            list_of_components = []
        if new_roi_extraction_parameters is None:
            new_roi_extraction_parameters = defaultdict(lambda: None)
        segmentation_count = len(state.roi_info.bound_info)
        new_segmentation_count = len(new_roi_info.bound_info)
        segmentation_dtype = minimal_dtype(segmentation_count +
                                           new_segmentation_count)
        roi_base = reduce_array(state.roi_info.roi,
                                state.selected_components,
                                dtype=segmentation_dtype)
        annotation_base = {
            i: state.roi_info.annotations.get(x)
            for i, x in enumerate(state.selected_components, start=1)
        }
        alternative_base = {
            name: cls._clip_data_array(roi_base, array)
            for name, array in state.roi_info.alternative.items()
        }
        components_parameters_dict = {
            i: state.roi_extraction_parameters[val]
            for i, val in enumerate(sorted(state.selected_components), 1)
        }

        base_chose = list(annotation_base.keys())

        if new_segmentation_count == 0:
            return dataclasses.replace(
                state,
                roi_info=ROIInfo(roi=roi_base,
                                 annotations=annotation_base,
                                 alternative=alternative_base),
                selected_components=base_chose,
                roi_extraction_parameters=components_parameters_dict,
            )

        new_segmentation = np.copy(new_roi_info.roi)
        new_segmentation[roi_base > 0] = 0
        left_component_list = np.unique(new_segmentation.flat)
        if left_component_list[0] == 0:
            left_component_list = left_component_list[1:]
        new_segmentation = reduce_array(new_segmentation,
                                        dtype=segmentation_dtype)
        roi_base[new_segmentation >
                 0] = new_segmentation[new_segmentation > 0] + len(base_chose)
        for name, array in new_roi_info.alternative.items():
            if name in alternative_base:
                alternative_base[name][new_segmentation > 0] = array[
                    new_segmentation > 0]
            else:
                alternative_base[name] = cls._clip_data_array(
                    new_segmentation, array)
        for i, el in enumerate(left_component_list, start=len(base_chose) + 1):
            annotation_base[i] = new_roi_info.annotations.get(el)
            if el in list_of_components:
                base_chose.append(i)
            components_parameters_dict[i] = new_roi_extraction_parameters[el]

        roi_info = ROIInfo(roi=roi_base,
                           annotations=annotation_base,
                           alternative=alternative_base)

        return dataclasses.replace(
            state,
            roi_info=roi_info,
            selected_components=base_chose,
            roi_extraction_parameters=components_parameters_dict,
        )
Example No. 35
    def fit(self,
            X: torch.Tensor,
            Y: torch.Tensor,
            Xts: Optional[torch.Tensor] = None,
            Yts: Optional[torch.Tensor] = None):
        """Fits the Falkon KRR model.

        Parameters
        -----------
        X : torch.Tensor
            The tensor of training data, of shape [num_samples, num_dimensions].
            If X is in Fortran order (i.e. column-contiguous) then we can avoid
            an extra copy of the data.
        Y : torch.Tensor
            The tensor of training targets, of shape [num_samples, num_outputs].
            If X and Y represent a classification problem, Y can be encoded as a one-hot
            vector.
            If Y is in Fortran order (i.e. column-contiguous) then we can avoid an
            extra copy of the data.
        Xts : torch.Tensor or None
            Tensor of validation data, of shape [num_test_samples, num_dimensions].
            If validation data is provided and `error_fn` was specified when
            creating the model, they will be used to print the validation error
            during the optimization iterations.
            If Xts is in Fortran order (i.e. column-contiguous) then we can avoid an
            extra copy of the data.
        Yts : torch.Tensor or None
            Tensor of validation targets, of shape [num_test_samples, num_outputs].
            If validation data is provided and `error_fn` was specified when
            creating the model, they will be used to print the validation error
            during the optimization iterations.
            If Yts is in Fortran order (i.e. column-contiguous) then we can avoid an
            extra copy of the data.

        Returns
        --------
        model: Falkon
            The fitted model
        """
        X, Y, Xts, Yts = self._check_fit_inputs(X, Y, Xts, Yts)

        dtype = X.dtype

        # Decide whether to use CUDA for preconditioning based on M
        _use_cuda_preconditioner = (
            self.use_cuda_ and
            (not self.options.cpu_preconditioner) and
            self.M >= get_min_cuda_preconditioner_size(dtype, self.options)
        )
        _use_cuda_mmv = (
            self.use_cuda_ and
            X.shape[0] * X.shape[1] * self.M / self.num_gpus >= get_min_cuda_mmv_size(dtype, self.options)
        )

        self.fit_times_ = []
        self.ny_points_ = None
        self.alpha_ = None

        t_s = time.time()
        # noinspection PyTypeChecker
        ny_points: Union[torch.Tensor, falkon.sparse.SparseTensor] = self.center_selection.select(X, None, self.M)
        if self.use_cuda_:
            ny_points = ny_points.pin_memory()

        with TicToc("Calcuating Preconditioner of size %d" % (self.M), debug=self.options.debug):
            pc_opt: FalkonOptions = dataclasses.replace(self.options,
                                                        use_cpu=not _use_cuda_preconditioner)
            if pc_opt.debug:
                print("Preconditioner will run on %s" %
                      ("CPU" if pc_opt.use_cpu else ("%d GPUs" % self.num_gpus)))
            precond = falkon.preconditioner.FalkonPreconditioner(self.penalty, self.kernel, pc_opt)
            precond.init(ny_points)

        if _use_cuda_mmv:
            # Cache must be emptied to ensure enough memory is visible to the optimizer
            torch.cuda.empty_cache()
            X = X.pin_memory()

        # K_NM storage decision
        k_opt = dataclasses.replace(self.options, use_cpu=True)
        cpu_info = get_device_info(k_opt)
        available_ram = min(k_opt.max_cpu_mem, cpu_info[-1].free_memory) * 0.9
        if self._can_store_knm(X, ny_points, available_ram):
            Knm = self.kernel(X, ny_points, opt=self.options)
        else:
            Knm = None
        self.fit_times_.append(time.time() - t_s)  # Preparation time

        # Here we define the callback function which will run at the end
        # of conjugate gradient iterations. This function computes and
        # displays the validation error.
        validation_cback = None
        if self.error_fn is not None and self.error_every is not None:
            validation_cback = self._get_callback_fn(X, Y, Xts, Yts, ny_points, precond)

        # Start with the falkon algorithm
        with TicToc('Computing Falkon iterations', debug=self.options.debug):
            o_opt: FalkonOptions = dataclasses.replace(self.options, use_cpu=not _use_cuda_mmv)
            if o_opt.debug:
                print("Optimizer will run on %s" %
                      ("CPU" if o_opt.use_cpu else ("%d GPUs" % self.num_gpus)), flush=True)
            optim = falkon.optim.FalkonConjugateGradient(self.kernel, precond, o_opt)
            if Knm is not None:
                beta = optim.solve(
                    Knm, None, Y, self.penalty, initial_solution=None,
                    max_iter=self.maxiter, callback=validation_cback)
            else:
                beta = optim.solve(
                    X, ny_points, Y, self.penalty, initial_solution=None,
                    max_iter=self.maxiter, callback=validation_cback)

            self.alpha_ = precond.apply(beta)
            self.ny_points_ = ny_points
        return self
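Note how `fit` never mutates `self.options`: each stage (preconditioner, optimizer, K_NM storage) derives its own variant with `dataclasses.replace`. A minimal sketch of that pattern with an invented options dataclass:

import dataclasses

@dataclasses.dataclass(frozen=True)
class SolverOptions:        # illustrative, not the real FalkonOptions
    use_cpu: bool = False
    debug: bool = False

base = SolverOptions(debug=True)
pc_opt = dataclasses.replace(base, use_cpu=True)    # run the preconditioner on CPU
o_opt = dataclasses.replace(base, use_cpu=False)    # run the optimizer on GPU

assert base.use_cpu is False and pc_opt.use_cpu is True
assert pc_opt.debug is True                         # untouched fields carry over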
 def update_cmd(cmd):
     cmd = replace_command_sender(cmd, address=new_sender_address)
     return dataclasses.replace(cmd, my_actor_address=new_sender_address)
Example No. 37
    def process_msg_and_check(self, message: Message) -> bool:
        """
        Returns True if the message can be processed successfully, False if a rate limit would be exceeded.
        """

        current_minute = int(time.time() // self.reset_seconds)
        if current_minute != self.current_minute:
            self.current_minute = current_minute
            self.message_counts = Counter()
            self.message_cumulative_sizes = Counter()
            self.non_tx_message_counts = 0
            self.non_tx_cumulative_size = 0
        try:
            message_type = ProtocolMessageTypes(message.type)
        except Exception as e:
            log.warning(f"Invalid message: {message.type}, {e}")
            return True

        new_message_counts: int = self.message_counts[message_type] + 1
        new_cumulative_size: int = self.message_cumulative_sizes[message_type] + len(message.data)
        new_non_tx_count: int = self.non_tx_message_counts
        new_non_tx_size: int = self.non_tx_cumulative_size
        proportion_of_limit: float = self.percentage_of_limit / 100

        ret: bool = False
        try:

            limits = DEFAULT_SETTINGS
            if message_type in rate_limits_tx:
                limits = rate_limits_tx[message_type]
            elif message_type in rate_limits_other:
                limits = rate_limits_other[message_type]
                new_non_tx_count = self.non_tx_message_counts + 1
                new_non_tx_size = self.non_tx_cumulative_size + len(message.data)
                if new_non_tx_count > NON_TX_FREQ * proportion_of_limit:
                    return False
                if new_non_tx_size > NON_TX_MAX_TOTAL_SIZE * proportion_of_limit:
                    return False
            else:
                log.warning(f"Message type {message_type} not found in rate limits")

            if limits.max_total_size is None:
                limits = dataclasses.replace(limits, max_total_size=limits.frequency * limits.max_size)
            assert limits.max_total_size is not None

            if new_message_counts > limits.frequency * proportion_of_limit:
                return False
            if len(message.data) > limits.max_size:
                return False
            if new_cumulative_size > limits.max_total_size * proportion_of_limit:
                return False

            ret = True
            return True
        finally:
            if self.incoming or ret:
                # now that we determined that it's OK to send the message, commit the
                # updates to the counters. Alternatively, if this was an
                # incoming message, we already received it and it should
                # increment the counters unconditionally
                self.message_counts[message_type] = new_message_counts
                self.message_cumulative_sizes[message_type] = new_cumulative_size
                self.non_tx_message_counts = new_non_tx_count
                self.non_tx_cumulative_size = new_non_tx_size
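When a limit record has no `max_total_size`, the code above fills in a derived value on a copy via `dataclasses.replace` rather than mutating the (possibly shared) defaults. A self-contained sketch of that idiom with an illustrative `RateLimit` class:

import dataclasses
from typing import Optional

@dataclasses.dataclass(frozen=True)
class RateLimit:                       # illustrative stand-in for the settings record
    frequency: int                     # messages allowed per window
    max_size: int                      # max bytes per message
    max_total_size: Optional[int] = None

limits = RateLimit(frequency=100, max_size=1024)
if limits.max_total_size is None:
    # Derive the missing field on a copy; the shared default object stays pristine.
    limits = dataclasses.replace(limits, max_total_size=limits.frequency * limits.max_size)
assert limits.max_total_size == 102400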
def replace_actor(cmd, name, actor, **changes):
    new_actor = dataclasses.replace(actor, **changes)
    return replace_command_payment(cmd, **{name: new_actor})
Example No. 39
def register_spec_with_docutils(spec: specparser.Spec,
                                default_domain: Optional[str]) -> Registry:
    """Register all of the definitions in the spec with docutils, overwriting the previous
    call to this function. This function should only be called once in the
    process lifecycle."""

    builder = Registry.Builder()
    directives = list(spec.directive.items())
    roles = list(spec.role.items())

    # Define rstobjects
    for name, rst_object in spec.rstobject.items():
        directive = rst_object.create_directive()
        directives.append((name, directive))
        role = rst_object.create_role()
        roles.append((name, role))

    for name, directive in directives:
        # Skip abstract base directives
        if name.startswith("_"):
            continue

        options: Dict[str, object] = {
            option_name: spec.get_validator(option)
            for option_name, option in directive.options.items()
        }

        base_class: Any = BaseDocutilsDirective

        # Tabs have special handling because of the need to support legacy syntax
        if name == "tabs":
            base_class = BaseTabsDirective
        elif name in SPECIAL_DIRECTIVE_HANDLERS:
            base_class = SPECIAL_DIRECTIVE_HANDLERS[name]

        DocutilsDirective = make_docutils_directive_handler(
            directive, base_class, name, options)
        builder.add_directive(name, DocutilsDirective)

    # reference tabs directive declaration as first step in registering tabs-* with docutils
    tabs_directive = spec.directive["tabs"]

    # Define tabsets
    for name in spec.tabs:
        tabs_base_class: Any = BaseTabsDirective
        tabs_name = "tabs-" + name

        # copy and modify the tabs directive to update its name to match the deprecated tabs-* naming convention
        modified_tabs_directive = dataclasses.replace(tabs_directive,
                                                      name=tabs_name)

        tabs_options: Dict[str, object] = {
            option_name: spec.get_validator(option)
            for option_name, option in tabs_directive.options.items()
        }

        DocutilsDirective = make_docutils_directive_handler(
            modified_tabs_directive, tabs_base_class, "tabs", tabs_options)

        builder.add_directive(tabs_name, DocutilsDirective)

    # Docutils builtins
    builder.add_directive("unicode",
                          docutils.parsers.rst.directives.misc.Unicode)
    builder.add_directive("replace",
                          docutils.parsers.rst.directives.misc.Replace)

    # Define roles
    builder.add_role("", handle_role_null)
    for name, role_spec in roles:
        handler: Optional[RoleHandlerType] = None
        domain = role_spec.domain or ""
        if not role_spec.type or role_spec.type == specparser.PrimitiveRoleType.text:
            handler = TextRoleHandler(domain)
        elif isinstance(role_spec.type, specparser.LinkRoleType):
            handler = LinkRoleHandler(
                role_spec.type.link,
                role_spec.type.ensure_trailing_slash == True,
                role_spec.type.format,
            )
        elif isinstance(role_spec.type, specparser.RefRoleType):
            handler = RefRoleHandler(
                role_spec.type.domain or domain,
                role_spec.type.name,
                role_spec.type.tag,
                role_spec.rstobject.type
                if role_spec.rstobject else specparser.TargetType.plain,
                role_spec.type.format,
            )
        elif role_spec.type == specparser.PrimitiveRoleType.explicit_title:
            handler = ExplicitTitleRoleHandler(domain)

        if not handler:
            raise ValueError('Unknown role type "{}"'.format(role_spec.type))

        builder.add_role(name, handler)

    return builder.build(default_domain)
def replace_command_payment(cmd, **changes):
    new_payment = dataclasses.replace(cmd.payment, **changes)
    return dataclasses.replace(cmd, payment=new_payment)
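The two helpers above chain `dataclasses.replace` to update a field buried inside a nested immutable record: replace the inner dataclass first, then swap it into a copy of the outer one. A minimal self-contained version with invented `Payment`/`Command` classes:

import dataclasses

@dataclasses.dataclass(frozen=True)
class Payment:                         # illustrative inner record
    amount: int
    currency: str

@dataclasses.dataclass(frozen=True)
class Command:                         # illustrative outer record
    sender: str
    payment: Payment

def replace_command_payment(cmd, **changes):
    new_payment = dataclasses.replace(cmd.payment, **changes)
    return dataclasses.replace(cmd, payment=new_payment)

cmd = Command(sender="alice", payment=Payment(amount=10, currency="XUS"))
updated = replace_command_payment(cmd, amount=25)
assert updated.payment.amount == 25 and cmd.payment.amount == 10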
import dataclasses
from typing import Type, Union


class A:
    pass


dataclasses.fields(<warning descr="'dataclasses.fields' method should be called on dataclass instances or types">A</warning>)
dataclasses.fields(<warning descr="'dataclasses.fields' method should be called on dataclass instances or types">A()</warning>)

dataclasses.asdict(<warning descr="'dataclasses.asdict' method should be called on dataclass instances">A()</warning>)
dataclasses.astuple(<warning descr="'dataclasses.astuple' method should be called on dataclass instances">A()</warning>)
dataclasses.replace(<warning descr="'dataclasses.replace' method should be called on dataclass instances">A()</warning>)


@dataclasses.dataclass
class B:
    pass


dataclasses.fields(B)
dataclasses.fields(B())

dataclasses.asdict(B())
dataclasses.astuple(B())
dataclasses.replace(B())

dataclasses.asdict(<warning descr="'dataclasses.asdict' method should be called on dataclass instances">B</warning>)
dataclasses.astuple(<warning descr="'dataclasses.astuple' method should be called on dataclass instances">B</warning>)
dataclasses.replace(<warning descr="'dataclasses.replace' method should be called on dataclass instances">B</warning>)
def union1(p: Union[A, B]):
    dataclasses.fields(p)

    dataclasses.asdict(p)
    dataclasses.astuple(p)
    dataclasses.replace(p)
Example No. 44
def merge_schema(default: Schema, override: Schema) -> Schema:
    if override.override:
        return override
    return replace(override, child=merge_schema(default, override.child))
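The merge recurses down the `child` chain and rebuilds each non-overriding node with `replace`, so the rest of the node's fields are preserved. A self-contained toy variant (the `Schema` class and the fall-back-to-default rule below are illustrative, not the library's actual semantics):

from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class Schema:                          # toy stand-in
    override: bool
    name: str
    child: Optional["Schema"] = None

def merge_schema(default: Schema, override: Schema) -> Schema:
    if override.override:
        return override                # keep an overriding subtree as-is
    if override.child is None:
        return default                 # nothing set here: fall back to the default
    # Rebuild this node with its child merged; replace() keeps the other fields.
    return replace(override, child=merge_schema(default, override.child))

default = Schema(override=False, name="default-leaf")
tree = Schema(override=False, name="root", child=Schema(override=False, name="unset"))
merged = merge_schema(default, tree)
assert merged.name == "root" and merged.child.name == "default-leaf"
assert tree.child.name == "unset"      # the original tree is untouched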
Example No. 45
from dataclasses import dataclass, field, InitVar, replace


@dataclass
class A:
    a: int
    b: str = "str"


replace(A(1), <arg1>)


@dataclass
class B:
    a: int
    b: str = field(default="str", init=False)


replace(B(1), <arg2>)


@dataclass
class C:
    a: int
    b: InitVar[str] = "str"


replace(C(1), <arg3>)


class D:
Example No. 46
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        num_labels = glue_tasks_num_labels[data_args.task_name]
        output_mode = glue_output_modes[data_args.task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = GlueDataset(
        data_args, tokenizer=tokenizer) if training_args.do_train else None
    eval_dataset = GlueDataset(
        data_args, tokenizer=tokenizer,
        evaluate=True) if training_args.do_eval else None

    def compute_metrics(p: EvalPrediction) -> Dict:
        if output_mode == "classification":
            preds = np.argmax(p.predictions, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(p.predictions)
        return glue_compute_metrics(data_args.task_name, preds, p.label_ids)

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path if os.path.
                      isdir(model_args.model_name_or_path) else None)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        eval_datasets = [eval_dataset]
        if data_args.task_name == "mnli":
            mnli_mm_data_args = dataclasses.replace(data_args,
                                                    task_name="mnli-mm")
            eval_datasets.append(
                GlueDataset(mnli_mm_data_args,
                            tokenizer=tokenizer,
                            evaluate=True))

        for eval_dataset in eval_datasets:
            result = trainer.evaluate(eval_dataset=eval_dataset)
            predictions = trainer.predict(
                test_dataset=eval_dataset).predictions
            with open('/dccstor/tuhinstor/tuhin/output1.txt', 'w') as fil:
                for p in predictions:
                    fil.write(str(p) + '\n')

            output_eval_file = os.path.join(
                training_args.output_dir,
                f"eval_results_{eval_dataset.args.task_name}.txt")
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results {} *****".format(
                    eval_dataset.args.task_name))
                for key, value in result.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))

            results.update(result)

    return results
Example No. 47
 def transform(self, ray: Ray) -> Ray:
     new_ray = replace(ray, position=self.context.end_path)
     return new_ray
Example No. 48
 def make_scalar(self, new_value):
     return dc.replace(self, value=new_value)
Example No. 49
 def set_serialized_creds(self, serialized_creds):
   return dataclasses.replace(self, serialized_creds=serialized_creds)
Example No. 50
    def run(
        self,
        universal_newlines: Optional[bool] = None,
        *,
        capture_output: bool = False,
        check: bool = False,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        input: Optional[Union[str, bytes]] = None,
        text: Optional[bool] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> subprocess.CompletedProcess[Any]:
        r"""Run command in :func:`subprocess.run`, optionally overrides via kwargs.

        Parameters
        ----------
        input : Union[bytes, str], optional
            pass string to subprocess's stdin. Bytes by default, str in text mode.

            Text mode is triggered by setting any of text, encoding, errors or
            universal_newlines.

        check : bool
            If True and the exit code was non-zero, it raises a
            :exc:`subprocess.CalledProcessError`. The CalledProcessError object will
            have the return code in the returncode attribute, and output & stderr
            attributes if those streams were captured.

        timeout : float, optional
            If given and the process takes too long, a :exc:`subprocess.TimeoutExpired`
            exception is raised.

        **kwargs : dict, optional
            Overrides existing attributes for :func:`subprocess.run`

        Examples
        --------
        >>> import subprocess
        >>> cmd = SubprocessCommand(
        ...     ["/bin/sh", "-c", "ls -l non_existent_file ; exit 0"])
        >>> cmd.run()
        CompletedProcess(args=['/bin/sh', '-c', 'ls -l non_existent_file ; exit 0'],
                         returncode=0)

        >>> import subprocess
        >>> cmd = SubprocessCommand(
        ...     ["/bin/sh", "-c", "ls -l non_existent_file ; exit 0"])
        >>> cmd.run(check=True)
        CompletedProcess(args=['/bin/sh', '-c', 'ls -l non_existent_file ; exit 0'],
                         returncode=0)

        >>> cmd = SubprocessCommand(["sed", "-e", "s/foo/bar/"])
        >>> completed = cmd.run(input=b"when in the course of fooman events\n")
        >>> completed
        CompletedProcess(args=['sed', '-e', 's/foo/bar/'], returncode=0)
        >>> completed.stderr

        >>> cmd = SubprocessCommand(["sed", "-e", "s/foo/bar/"])
        >>> completed = cmd.run(input=b"when in the course of fooman events\n",
        ...                     capture_output=True)
        >>> completed
        CompletedProcess(args=['sed', '-e', 's/foo/bar/'], returncode=0,
                        stdout=b'when in the course of barman events\n', stderr=b'')
        >>> completed.stdout
        b'when in the course of barman events\n'
        >>> completed.stderr
        b''
        """
        return subprocess.run(
            **dataclasses.replace(
                self,
                universal_newlines=universal_newlines,
                errors=errors,
                text=text,
                **kwargs,
            ).__dict__,
            check=check,
            capture_output=capture_output,
            input=input,
            timeout=timeout,
        )
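The `run` above builds its `subprocess.run` keyword arguments by copying itself with the per-call overrides and splatting the resulting fields. A cut-down sketch of that idea with a hypothetical, much smaller command holder:

import dataclasses
import subprocess
from typing import Optional, Sequence

@dataclasses.dataclass
class Cmd:                             # hypothetical minimal command holder
    args: Sequence[str]
    cwd: Optional[str] = None

    def run(self, **overrides) -> subprocess.CompletedProcess:
        # Copy self with the overrides applied, then pass the fields
        # straight through to subprocess.run().
        merged = dataclasses.replace(self, **overrides)
        return subprocess.run(**dataclasses.asdict(merged), check=False)

echo = Cmd(["echo", "hello"])
print(echo.run(cwd="/tmp").returncode)  # 0 on a POSIX system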
Example No. 51
from dataclasses import dataclass, replace

@dataclass
class A1:
    a: int

class B1(A1):
    b: str

replace(B1(1), <arg1>)


class A2:
    a: int

@dataclass
class B2(A2):
    b: str

replace(B2("1"), <arg2>)


@dataclass
class A3:
    a: int

class B3(A3):
    def __init__(self, b: str):
        self.a = 10

replace(B3("1"), <arg3>)
Example No. 52
def merge_framework_metadata(
    exception_info: datamodel.ExceptionData,
    framework_infos: Sequence[datamodel.FrameworkMetadata],
) -> datamodel.FrameworkMetadata:
    """
    Return the merger between the *frameworks* infos and *exception_info*.

    This will raise an exception when the information cannot be merged.
    """
    result = datamodel.FrameworkMetadata()

    for info in framework_infos:
        result.architectures.update(info.architectures)

    # enum_type
    result = replace(
        result,
        enum_type=merge_enum_type(
            exception_info.enum_type,
            [(next(iter(info.architectures)), info.enum_type)
             for info in framework_infos],
        ),
    )

    # enum
    result = replace(
        result,
        enum=merge_enum(
            exception_info.enum,
            [(next(iter(info.architectures)), info.enum)
             for info in framework_infos],
        ),
    )

    # structs
    ...

    # externs
    result = replace(
        result,
        externs=merge_externs(
            exception_info.externs,
            [(next(iter(info.architectures)), info.externs)
             for info in framework_infos],
        ),
    )

    # cftypes
    ...

    # literals
    result = replace(
        result,
        literals=merge_literals(
            exception_info.literals,
            [(next(iter(info.architectures)), info.literals)
             for info in framework_infos],
        ),
    )

    # formal_protocols
    ...

    # informal_protocols
    ...

    # classes
    ...

    # aliases
    result = replace(
        result,
        aliases=merge_aliases(
            exception_info.aliases,
            [(next(iter(info.architectures)), info.aliases)
             for info in framework_infos],
        ),
    )

    # expressions
    result = replace(
        result,
        expressions=merge_expressions(
            exception_info.expressions,
            [(next(iter(info.architectures)), info.expressions)
             for info in framework_infos],
        ),
    )

    # func_macros
    result = replace(
        result,
        func_macros=merge_func_macros(
            exception_info.func_macros,
            [(next(iter(info.architectures)), info.func_macros)
             for info in framework_infos],
        ),
    )

    # functions
    ...

    return result
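Each section above folds its merged field back into the accumulator with `result = replace(result, ...)`, which is the standard way to build up a frozen record step by step. A tiny illustrative sketch:

from dataclasses import dataclass, field, replace
from typing import Dict

@dataclass(frozen=True)
class Metadata:                        # illustrative accumulator
    enum: Dict[str, int] = field(default_factory=dict)
    aliases: Dict[str, str] = field(default_factory=dict)

result = Metadata()
result = replace(result, enum={"NSOrderedSame": 0})
result = replace(result, aliases={"NSUInteger": "unsigned long"})
assert result.enum and result.aliases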
Example No. 53
from dataclasses import dataclass, field, InitVar, replace


@dataclass
class A:
    a: int
    b: str = "str"


replace(A(1))
replace(A(1), a=1)
replace(A(1), a=1, b="abc")
replace(A(1), a=1, b="abc", <warning descr="Unexpected argument">c=2</warning>)


@dataclass
class B:
    a: int
    b: str = field(default="str", init=False)


replace(B(1))
replace(B(1), a=1)
replace(B(1), a=1, <warning descr="Unexpected argument">b="abc"</warning>)
replace(B(1), a=1, <warning descr="Unexpected argument">b="abc"</warning>, <warning descr="Unexpected argument">c=2</warning>)


@dataclass
class C:
    a: int
    b: InitVar[str] = "str"
@dataclasses.dataclass
class Base:
    pass


class A(Base):
    pass


dataclasses.fields(A)
dataclasses.fields(A())

dataclasses.asdict(<warning descr="'dataclasses.asdict' method should be called on dataclass instances">A()</warning>)
dataclasses.astuple(A())
dataclasses.replace(A())


@dataclasses.dataclass
class B(Base):
    pass


dataclasses.fields(B)
dataclasses.fields(B())

dataclasses.asdict(B())
dataclasses.astuple(B())
dataclasses.replace(B())

dataclasses.asdict(<warning descr="'dataclasses.asdict' method should be called on dataclass instances">B</warning>)
Example No. 55
 def copy(self):
     return replace(self)
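Calling `replace(self)` with no changes is a convenient shallow copy, but it re-runs `__init__`, so fields declared with `init=False` are rebuilt from their defaults rather than copied (the standard library documents this pitfall). A small illustrative sketch with a made-up class:

from dataclasses import dataclass, field, replace

@dataclass(frozen=True)
class Config:                          # illustrative example class
    host: str
    port: int = 8080
    cache: dict = field(default_factory=dict, init=False)

a = Config("localhost")
a.cache["ttl"] = 30                    # mutating the dict is allowed even when frozen
b = replace(a)                         # copies host/port, but cache is rebuilt fresh
assert b.host == a.host and b.port == a.port
assert b.cache == {} and a.cache == {"ttl": 30}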
Example No. 56
 async def set_name(self, new_name: str):
     new_info = replace(self.wallet_info, name=new_name)
     self.wallet_info = new_info
     await self.wallet_state_manager.user_store.update_wallet(
         self.wallet_info)