def test_check():
    # A single tensor or a tuple of tensors is a valid input.
    check(torch.tensor(42))
    check((torch.tensor(4), torch.tensor(2)))

    # Non-tensor values are rejected, even when nested inside a tuple.
    with pytest.raises(TypeError):
        check(42)

    with pytest.raises(TypeError):
        check("str")

    with pytest.raises(TypeError):
        check((torch.tensor(4), 2))
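# A minimal, hypothetical sketch of the validation behavior that test_check exercises:
# accept a single Tensor or a tuple of Tensors and raise TypeError for anything else.
# The name `check_sketch` and the error messages are assumptions for illustration only;
# the real `microbatch.check` may perform additional validation.
from typing import Tuple, Union

import torch
from torch import Tensor

TensorOrTensors = Union[Tensor, Tuple[Tensor, ...]]


def check_sketch(input: TensorOrTensors) -> None:
    # Tuples are validated element-wise; every element must be a Tensor.
    if isinstance(input, tuple):
        for x in input:
            if not isinstance(x, Tensor):
                raise TypeError(f"expected Tensor, but got {type(x).__name__}")
    elif not isinstance(input, Tensor):
        raise TypeError(f"expected Tensor, but got {type(input).__name__}")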
def forward(self, *inputs: Tensor) -> rpc.RRef:  # type: ignore
    for i, input in enumerate(inputs):
        microbatch.check(input)

    # Divide a mini-batch into micro-batches.
    batches_list = [microbatch.scatter(input, self.chunks) for input in inputs]

    # Create a DistributedPipelineRecord, one per partition, and make connections
    # between them (i.e. set the list of consumers).
    pipeline_records: Dict[DistributedPipeline.Partition, rpc.RRef] = {}
    for partition in reversed(self.partitions):
        r_handler = partition.handler.remote()
        consumers = []
        # Identify consumers of the outputs of the partition.
        for consumer in partition.nodes[-1].output_consumers:
            consumer_partition = next(p for p in self.partitions if p.nodes[0] is consumer.consumer)
            # The index of a consumer partition should be greater than the index of this partition.
            assert consumer_partition in pipeline_records
            consumers.append(
                DistributedPipelineRecord.DataConsumer(
                    pipeline_records[consumer_partition], consumer.consumer_input_idx, consumer.output_idx
                )
            )
        pipeline_records[partition] = r_handler.make_pipeline_record(consumers)
        # Let the pipeline handler for the partition start processing the pipeline record for that partition.
        this_result = r_handler.run_pipeline(pipeline_records[partition])
        # If this is the last partition, we expect the result of the model to be the output of this partition.
        if partition is self.partitions[-1]:
            result = this_result

    # Start feeding the model inputs to the partitions that need them.
    for i, b in enumerate(zip(*batches_list)):
        for input_consumer in self.input_consumers:
            pipeline_record = pipeline_records[input_consumer.consumer]
            # TODO: Debug why we need this special handling
            if pipeline_record.owner().name == rpc.get_worker_info().name:  # type: ignore
                pipeline_record.local_value().feed(
                    i, input_consumer.consumer_input_idx, b[input_consumer.output_idx].value
                )
            else:
                pipeline_record.rpc_async().feed(
                    i, input_consumer.consumer_input_idx, b[input_consumer.output_idx].value
                )  # type: ignore

    return result
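# A self-contained sketch of the micro-batch scattering pattern that forward relies on:
# each model input is split into `chunks` micro-batches, and zip(*batches_list) then
# yields, per chunk index, the matching micro-batch of every input. `scatter_sketch` is
# a hypothetical stand-in for microbatch.scatter and simply uses torch.chunk; the real
# scatter wraps each chunk in an object exposing a `.value` attribute, which is why the
# feed calls above pass `b[...].value`.
from typing import List

import torch
from torch import Tensor


def scatter_sketch(input: Tensor, chunks: int) -> List[Tensor]:
    # Split a mini-batch into `chunks` micro-batches along the batch dimension.
    return list(torch.chunk(input, chunks, dim=0))


# Two model inputs, each a mini-batch of 8 samples, split into 4 micro-batches.
example_inputs = (torch.randn(8, 3), torch.randn(8, 5))
example_batches_list = [scatter_sketch(x, 4) for x in example_inputs]

# zip(*batches_list) groups the i-th micro-batch of every input together,
# matching the iteration order used when feeding the pipeline above.
for i, b in enumerate(zip(*example_batches_list)):
    assert b[0].shape == (2, 3) and b[1].shape == (2, 5)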