コード例 #1
0
ファイル: test_microbatch.py プロジェクト: yyht/fairscale
def test_gather_tensors():
    """Gathering two (1, 1) single-tensor batches yields a (2, 1) tensor."""
    batches = [Batch(torch.zeros(1, 1)), Batch(torch.zeros(1, 1))]

    gathered = gather(batches)

    assert gathered.size() == (2, 1)
コード例 #2
0
ファイル: test_microbatch.py プロジェクト: yyht/fairscale
def test_gather_tuples():
    """Gathering tuple batches gathers each tuple position separately.

    Position 0 goes (1, 1) + (1, 1) -> (2, 1); position 1 goes
    (2, 2) + (2, 2) -> (4, 2).
    """
    def make_pair():
        return (torch.zeros(1, 1), torch.zeros(2, 2))

    gathered = gather([Batch(make_pair()), Batch(make_pair())])

    assert isinstance(gathered, tuple)
    expected_sizes = [(2, 1), (4, 2)]
    for position, expected in enumerate(expected_sizes):
        assert gathered[position].size() == expected
コード例 #3
0
    def get_batch_from_message(self, message: PipeMessage, queue_name: int) -> Batch:
        """Wrap the tensor(s) carried by a ``PipeMessage`` in a ``Batch``.

        The payload is obtained through ``AsyncRecvOperator.apply`` on a fresh,
        empty ``phony`` tensor that requires grad, so that the backward pass can
        be intercepted. If the operator yields exactly one element, that element
        is unwrapped before being handed to ``Batch``; otherwise the whole
        result is passed through. The batch index comes from
        ``message.args.microbatch_index``.
        """
        index = message.args.microbatch_index
        phony = torch.empty(0, device=self.transport.input_device, requires_grad=True)
        received = AsyncRecvOperator.apply(phony, self.transport, message, queue_name)
        payload = received[0] if len(received) == 1 else received
        return Batch(payload, index)
コード例 #4
0
ファイル: test_microbatch.py プロジェクト: yyht/fairscale
def test_batch_setitem_by_slice():
    """Assigning to ``batch[:]`` replaces every tensor but keeps the atomic flag."""
    single = Batch(torch.tensor(42))
    multi = Batch((torch.tensor(42), torch.tensor(21)))

    for batch in (single, multi):
        batch[:] = (torch.tensor(0), )

    assert single.atomic
    assert single[0].item() == 0

    assert not multi.atomic
    assert len(multi) == 1
    assert multi[0].item() == 0
コード例 #5
0
ファイル: test_microbatch.py プロジェクト: yyht/fairscale
def test_batch_setitem_by_index():
    """Assigning to ``batch[0]`` replaces only that slot and keeps the atomic flag."""
    single = Batch(torch.tensor(42))
    multi = Batch((torch.tensor(42), torch.tensor(21)))

    for batch in (single, multi):
        batch[0] = torch.tensor(0)

    assert single.atomic
    assert single[0].item() == 0

    assert not multi.atomic
    assert len(multi) == 2
    assert [multi[0].item(), multi[1].item()] == [0, 21]
コード例 #6
0
ファイル: ampnet.py プロジェクト: mrzzd/fairscale
 def compute(
     batch: Batch = batch,
     partition: nn.Sequential = partition,
     chunk_id: int = i,
     part_id: int = j,
 ) -> Batch:
     """Apply ``partition`` to ``batch`` inside a ``record_function`` range.

     The default arguments capture the loop values at definition time, so
     each closure keeps its own batch/partition/chunk/part.
     """
     label = "chunk%d-part%d" % (chunk_id, part_id)
     with record_function(label):
         output = batch.call(partition)
     return output
コード例 #7
0
 def wait_for(self, chunk: int) -> None:
     """Block until every tensor of ``chunk`` is populated in
     ``self.tensors``, then build ``self.batches[chunk]`` if it has not been
     built yet.

     Args:
         chunk: index into both ``self.tensors`` and ``self.batches``.
     """
     # ready_cv is used as a condition variable (context manager + wait());
     # presumably a threading.Condition -- confirm where it is constructed.
     with self.ready_cv:
         # Keep waiting only while the batch has not been built AND at least
         # one tensor slot of this chunk is still None.
         while self.batches[chunk] is None and any(
                 b is None for b in self.tensors[chunk]):
             self.ready_cv.wait()
         if self.batches[chunk] is None:
             # Reaching here means the loop exited because no slot is None,
             # so the cast only narrows the element type for the type checker.
             tensors = cast(List[Tensor], self.tensors[chunk])
             self.batches[chunk] = Batch(tuple(tensors), chunk)
コード例 #8
0
def make_checkpoint(function: Function, input: TensorOrTensors, index: int) -> TensorOrTensors:
    """Checkpoint ``function`` on ``input`` and immediately recompute it.

    A small convenience wrapper in the spirit of
    :func:`torch.utils.checkpoint.checkpoint`, meant only for testing or
    debugging :class:`Checkpoint` and :class:`Recompute` without boilerplate.

    Returns:
        The tensor or tensors of the checkpointed output batch.
    """
    checkpointing = Checkpointing(function, Batch(input, index))
    output = checkpointing.checkpoint()
    checkpointing.recompute(output)
    return output.tensor_or_tensors
コード例 #9
0
 def compute(
     batch: Batch = batch,
     chunk_id: int = chunk,
     rank: int = pipeline_record.rank
     if pipeline_record is not None else -1,
 ) -> Batch:
     """Run ``self.module`` on one micro-batch inside a ``record_function`` range.

     The default arguments bind ``batch``, ``chunk`` and the rank at
     definition time, so each closure keeps its own values. ``rank`` falls
     back to -1 when ``pipeline_record`` is None.

     Returns:
         A ``Batch`` indexed by ``chunk_id`` holding the module output. When
         ``self.num_outputs`` is None, the output is first wrapped in a 1-tuple.
     """
     # Label with the captured ``rank``. Re-reading ``pipeline_record.rank``
     # here raised AttributeError whenever ``pipeline_record`` was None, which
     # is exactly the case the default above is meant to tolerate.
     with record_function("chunk%d-rank%d" % (chunk_id, rank)):
         result = self.module(*batch.tensors)
         if self.num_outputs is None:
             result = (result, )
     return Batch(result, chunk_id)
コード例 #10
0
ファイル: test_microbatch.py プロジェクト: yyht/fairscale
def test_batch_atomic():
    """A Batch built from one tensor is atomic and exposes it via ``.tensor`` only."""
    tensor = torch.tensor(42)
    batch = Batch(tensor)

    assert batch.atomic

    assert batch.tensor is tensor
    with pytest.raises(AttributeError):
        batch.tensors

    assert list(batch) == [tensor]
    assert len(batch) == 1
    assert batch[0] is tensor
コード例 #11
0
def test_tensor_life_without_checkpointing():
    """A saved skip tensor starts with tensor_life 1, which drops to 0 once loaded."""
    key = (None, "test")
    layout = SkipLayout(num_partitions=2, skip_routes={key: (0, 1)})
    tracker = SkipTrackerThroughPotals(layout)

    batch = Batch(torch.tensor([1.0]))
    skipped = torch.tensor([2.0])

    tracker.save(batch, None, "test", skipped)
    assert tracker.portals[key].tensor_life == 1

    tracker.load(batch, None, "test")
    assert tracker.portals[key].tensor_life == 0
コード例 #12
0
def test_reuse_portal():
    """Saving twice under the same skip name keeps the same portal object."""
    key = (None, "test")
    tracker = SkipTrackerThroughPotals(
        SkipLayout(num_partitions=2, skip_routes={key: (0, 1)}))

    batch = Batch(torch.tensor([1.0]))
    first, second = torch.tensor([2.0]), torch.tensor([2.0])

    tracker.save(batch, None, "test", first)
    original_portal = tracker.portals[key]

    tracker.save(batch, None, "test", second)
    assert original_portal is tracker.portals[key]
コード例 #13
0
ファイル: test_microbatch.py プロジェクト: yyht/fairscale
def test_batch_non_atomic():
    """A Batch built from a tuple is non-atomic and exposes it via ``.tensors`` only."""
    first, second = torch.tensor(42), torch.tensor(21)
    batch = Batch((first, second))

    assert not batch.atomic

    with pytest.raises(AttributeError):
        batch.tensor
    assert batch.tensors == (first, second)

    assert list(batch) == [first, second]
    assert len(batch) == 2
    assert batch[0] is first
    assert batch[1] is second
コード例 #14
0
ファイル: test_microbatch.py プロジェクト: yyht/fairscale
def test_batch_call():
    """``Batch.call`` with an identity function preserves the atomic flag."""
    atomic_batch = Batch(torch.tensor(42))
    tuple_batch = Batch((torch.tensor(42), torch.tensor(21)))

    def identity(value):
        return value

    assert atomic_batch.call(identity).atomic
    assert not tuple_batch.call(identity).atomic
コード例 #15
0
ファイル: test_checkpoint.py プロジェクト: yyht/fairscale
def test_not_requires_grad():
    """Checkpointed output requires grad even when the input does not."""
    batch = Batch(torch.rand(1, requires_grad=False))
    assert not batch[0].requires_grad

    def double(tensor):
        return tensor * 2

    checkpointing = Checkpointing(double, batch)
    output = checkpointing.checkpoint()
    assert output[0].requires_grad

    checkpointing.recompute(output)
    assert output[0].requires_grad

    output.tensor.backward()
コード例 #16
0
 def fence(self, chunk: int) -> None:
     """Prepare micro-batches for computation.

     The work is done only for a chunk other than 0, and only when
     ``self.consumers`` is truthy and ``self.rank > 0``. In that case every
     phony forwarded for ``chunk - 1`` is fetched with ``to_here()`` and
     joined into the first tensor of ``self.batches[chunk]``, which is then
     rebuilt. This explicit dependency makes batches[chunk-1] execute after
     batches[chunk] in backpropagation.
     """
     # TODO: This dependency injection causes deadlock if this partition
     # gets its input from model input. 1) Figure out why 2) If we need to live
     # with this constraint, replace the 'rank > 0' requirement in the guard
     # below with a more accurate one.
     if chunk == 0 or not self.consumers or self.rank <= 0:
         return

     current = self.batches[chunk]
     assert current is not None
     tensors = list(current.tensors)
     for phony_refs in self.forwarded_phony[chunk - 1]:
         for phony_ref in phony_refs:
             phony = phony_ref.to_here()
             tensors[0] = join(tensors[0], phony)
     self.batches[chunk] = Batch(tuple(tensors), chunk)
コード例 #17
0
def test_no_copy_no_portal():
    """A skip routed (0, 1) is stored as a portal; one routed (0, 0) as a plain tensor."""
    copy_key = (None, "copy")
    local_key = (None, "not_copy")
    tracker = SkipTrackerThroughPotals(
        SkipLayout(num_partitions=2,
                   skip_routes={copy_key: (0, 1), local_key: (0, 0)}))

    batch = Batch(torch.tensor([1.0]))
    tracker.save(batch, None, "copy", torch.tensor([2.0]))
    tracker.save(batch, None, "not_copy", torch.tensor([2.0]))

    assert copy_key in tracker.portals
    assert copy_key not in tracker.tensors
    assert local_key in tracker.tensors
    assert local_key not in tracker.portals
コード例 #18
0
    def event_loop_head_across_minibatches(
        self, lm_dataloader: DataLoader, criterion: nn.Module, optimizer: Optimizer, transform_logger_object: Any
    ) -> None:
        """Handle one epoch of the asynchronous pipeline on the head stage.

        ``N`` is the number of pipeline-parallel ranks. The schedule runs in
        three phases:

        1. Warm-up: forward and send up to ``N - 1`` micro-batches.
        2. Steady state: for each further micro-batch, run one forward pass,
           receive and apply one gradient, then send that forward message.
        3. Drain: receive and apply one gradient per entry left in
           ``activations``.

        After each applied gradient, ``self.perform_optimizer_step`` decides
        whether to run ``optimizer.step()``, ``optimizer.zero_grad()`` and
        ``transform_logger_object.check_and_save_weights``.

        Only exhaustion of ``lm_dataloader`` ends a data-driven phase. A
        ``StopIteration`` raised by any other call now propagates instead of
        silently ending the phase.

        Args:
            lm_dataloader: source of raw input batches for this epoch.
            criterion: not used by this method.
            optimizer: stepped as described above. Its
                ``update_weight_using_future_predictions`` is called before
                each forward/backward when ``self.weight_prediction`` is set.
            transform_logger_object: provides ``transform_input`` for each raw
                batch and ``check_and_save_weights`` after optimizer steps.
        """
        cur_rank = self.group.rank()
        N = len(get_pipeline_parallel_ranks())  # for warmup phase
        activations = dict()
        count = 0
        num_gradients = 0
        lm_iter = iter(lm_dataloader)

        def send_forward(cur_batch: Any, index: int) -> Any:
            """Run the forward pass for micro-batch ``index``, record its
            activation, and return the message that still has to be sent."""
            reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
            batch = Batch(reqd_input, index)
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
            activations[index], message = self.async_send_inner(batch, index)
            return message

        def apply_next_gradient() -> None:
            """Receive one gradient message and run the backward pass for it."""
            message = self.transport.recv_message_header(EVENT_LOOP_GRADIENTS_QUEUE)
            args: AsyncMessageBody = message.args
            assert args.message_type is AsyncMessageType.Gradients
            if self.weight_prediction:
                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
            self.async_grad_inner(message, activations)

        def maybe_step(gradients_so_far: int) -> None:
            """Step the optimizer if ``perform_optimizer_step`` asks for it."""
            if self.perform_optimizer_step(optimizer, gradients_so_far):
                optimizer.step()
                optimizer.zero_grad()
                transform_logger_object.check_and_save_weights(gradients_so_far)

        # Filling the pipeline (warm-up): up to N - 1 forward passes.
        # NOTE(review): with N == 1 the break below never fires, so warm-up
        # consumes the whole epoch. This matches the previous behaviour; confirm
        # it is intended.
        for cur_batch in lm_iter:
            self.transport.send_message(send_forward(cur_batch, count), sync=True)
            count += 1
            if count == N - 1:
                break

        # Steady state: one forward and one backward per remaining micro-batch.
        # Iterating the same iterator resumes where warm-up stopped.
        for cur_batch in lm_iter:
            forward_message = send_forward(cur_batch, count)
            count += 1

            apply_next_gradient()

            # Send after grad
            self.transport.send_message(forward_message, sync=True)

            num_gradients += 1
            maybe_step(num_gradients)

        # Drain: remaining items for backward.
        for _ in range(len(activations)):
            apply_next_gradient()
            num_gradients += 1
            maybe_step(num_gradients)
コード例 #19
0
ファイル: test_worker.py プロジェクト: zzszmyf/fairscale
 def _42():
     """Return a ``Batch`` with index 0 holding the scalar tensor 42."""
     answer = torch.tensor(42)
     return Batch(answer, 0)
コード例 #20
0
ファイル: test_worker.py プロジェクト: zzszmyf/fairscale
 def detect_grad_enabled():
     """Return a one-element random tensor (in a ``Batch`` with index 0)
     whose ``requires_grad`` mirrors the current grad mode."""
     grad_mode = torch.is_grad_enabled()
     return Batch(torch.rand(1, requires_grad=grad_mode), 0)
コード例 #21
0
ファイル: test_worker.py プロジェクト: zzszmyf/fairscale
 def counter():
     nonlocal count
     time.sleep(0.1)
     count += 1
     return Batch((), 0)
コード例 #22
0
ファイル: test_worker.py プロジェクト: zzszmyf/fairscale
 def log_thread_id():
     """Record the calling thread's ident in ``thread_ids``; return an empty ``Batch``."""
     thread_ids.add(threading.current_thread().ident)
     return Batch((), 0)