示例#1
0
def test_check():
    """Verify which inputs ``check`` accepts and which it rejects."""
    # A single tensor and a tuple made only of tensors must pass silently.
    accepted_inputs = (
        torch.tensor(42),
        (torch.tensor(4), torch.tensor(2)),
    )
    for accepted in accepted_inputs:
        check(accepted)

    # A bare int, a string, and a tuple holding a non-tensor element must
    # each be rejected with TypeError.
    rejected_inputs = (
        42,
        "str",
        (torch.tensor(4), 2),
    )
    for rejected in rejected_inputs:
        with pytest.raises(TypeError):
            check(rejected)
示例#2
0
    def forward(self, *inputs: Tensor) -> rpc.RRef:  # type: ignore
        """Run one mini-batch through the distributed pipeline.

        Every input is validated with ``microbatch.check`` and scattered into
        ``self.chunks`` micro-batches. One pipeline record is then created per
        partition (walking ``self.partitions`` in reverse) and wired to the
        records of the partitions that consume its outputs, each partition's
        remote handler is asked to run its record, and finally the
        micro-batches are fed to the partitions listed in
        ``self.input_consumers``.

        Args:
            *inputs: The model inputs; each one must pass ``microbatch.check``.

        Returns:
            The value returned by ``run_pipeline`` for the last partition in
            ``self.partitions``.
        """
        for model_input in inputs:
            microbatch.check(model_input)

        # Divide a mini-batch into micro-batches.
        batches_list = [
            microbatch.scatter(model_input, self.chunks)
            for model_input in inputs
        ]

        # Create a DistributedPipelineRecord, one per partition, and make
        # connections between them (i.e. set the list of consumers).
        # Partitions are visited in reverse so that the record of every
        # consumer already exists when its producer is wired up.
        pipeline_records: Dict[DistributedPipeline.Partition, rpc.RRef] = {}
        for partition in reversed(self.partitions):
            r_handler = partition.handler.remote()
            consumers = []
            # Identify consumers of the outputs of the partition.
            for consumer in partition.nodes[-1].output_consumers:
                consumer_partition = next(
                    p for p in self.partitions
                    if p.nodes[0] is consumer.consumer)
                # Index of a consumer partition should be greater than the
                # index of the partition, so its record was created earlier.
                assert consumer_partition in pipeline_records
                consumers.append(
                    DistributedPipelineRecord.DataConsumer(
                        pipeline_records[consumer_partition],
                        consumer.consumer_input_idx, consumer.output_idx))
            pipeline_records[partition] = r_handler.make_pipeline_record(
                consumers)
            # Let the pipeline-handler for the partition start processing the
            # pipeline-record for that partition.
            this_result = r_handler.run_pipeline(pipeline_records[partition])
            # If this is the last partition, the result of the model is the
            # output of this partition.
            # NOTE(review): `result` is only bound here, so an empty
            # `self.partitions` would fail at the final return -- confirm the
            # constructor guarantees at least one partition.
            if partition is self.partitions[-1]:
                result = this_result

        # The name of the current worker does not change while feeding, so
        # look it up once instead of once per (micro-batch, consumer) pair.
        local_worker_name = rpc.get_worker_info().name  # type: ignore

        # Start feeding model input to the partitions that need them.
        for chunk_idx, chunk_inputs in enumerate(zip(*batches_list)):
            for input_consumer in self.input_consumers:
                pipeline_record = pipeline_records[input_consumer.consumer]
                feed_args = (
                    chunk_idx,
                    input_consumer.consumer_input_idx,
                    chunk_inputs[input_consumer.output_idx].value,
                )
                # TODO: Debug why we need this special handling
                if pipeline_record.owner().name == local_worker_name:
                    # The record lives on this worker: call it directly.
                    pipeline_record.local_value().feed(*feed_args)
                else:
                    pipeline_record.rpc_async().feed(*feed_args)  # type: ignore

        return result