示例#1
0
    def _scatter_samples(self, batches, replace_empty_samples=False):
        """Split and distribute a list of batches across the replicas.

        Replica ``rank`` receives ``batches[rank]`` through an asynchronous
        ``_async_prepare_batch`` call; the resulting futures are passed to
        ``Future.gen_list``.

        Args:
            batches: sequence of per-replica batches, at most
                ``self.num_replicas`` long.
            replace_empty_samples: when False, missing slots are padded with
                ``None``; when True, they are filled by cycling through the
                given batches so that every replica receives a real batch.

        Raises:
            ValueError: if more batches than replicas are given, or if
                ``replace_empty_samples`` is True but ``batches`` is empty
                (there is nothing to cycle through).
        """
        # Work on a list copy: accepts tuples/other sequences and never
        # mutates the caller's object.
        batches = list(batches)
        if len(batches) > self.num_replicas:
            # Raise explicitly instead of relying on ``assert`` (stripped under
            # -O); otherwise the surplus batches would be dropped silently.
            raise ValueError('got %d batches for %d replicas'
                             % (len(batches), self.num_replicas))

        if replace_empty_samples:
            # pad by cycling through the given samples
            batches = list(islice(cycle(batches), self.num_replicas))
        else:
            # pad with None until its size is equal to the number of replicas
            batches.extend([None] * (self.num_replicas - len(batches)))

        if len(batches) != self.num_replicas:
            # Only reachable when cycling through an empty list.
            raise ValueError('replace_empty_samples=True requires at least '
                             'one batch to cycle through')

        Future.gen_list([
            self.call_async(rank, '_async_prepare_batch', batch=batch)
            for rank, batch in enumerate(batches)
        ])

    def load_checkpoint(self, filename):
        """Ask every replica process to load the checkpoint at ``filename``.

        One ``_async_load_checkpoint`` call is issued per rank, the futures
        are gathered with ``Future.gen_list``, and the first gathered result
        is returned (presumably rank 0's, assuming ``gen_list`` preserves the
        input order).
        """
        pending = []
        for replica in range(self.num_replicas):
            pending.append(
                self.call_async(replica, '_async_load_checkpoint',
                                filename=filename))
        gathered = Future.gen_list(pending)
        return gathered[0]
示例#3
0
    def load_optim_state_dict(self, optim_state_dict):
        """Send ``optim_state_dict`` to every replica process.

        Each rank receives the same object through an asynchronous
        ``_async_load_optim_state_dict`` call; the futures are gathered with
        ``Future.gen_list``.

        Args:
            optim_state_dict: state to load on each replica (presumably the
                optimizer's state dict, per the method name); it is passed
                unchanged to every rank.

        Returns:
            The first gathered result (presumably rank 0's, assuming
            ``gen_list`` preserves the input order).
        """
        results = Future.gen_list([
            self.call_async(rank,
                            '_async_load_optim_state_dict',
                            optim_state_dict=optim_state_dict)
            for rank in range(self.num_replicas)
        ])

        return results[0]

    def __init__(self,
                 opt,
                 model,
                 loss_function,
                 device_ids=None,
                 multiprocessing_method='spawn'):
        """Set up one training replica per device and initialize them.

        Args:
            opt: options object; forwarded to each replica as ``args`` and
                ``opt.seed`` is used to seed the replicas.
            model: model to replicate; ``model.share_memory()`` is called on
                it before it is sent to the replicas.
            loss_function: passed to every replica and stored on
                ``self.loss_function``.
            device_ids: device ids to use; defaults to every device reported
                by ``torch.cuda.device_count()``.
            multiprocessing_method: forwarded unchanged to the base class
                (default ``'spawn'``).

        Raises:
            NotImplementedError: if CUDA is not available.
        """
        # Check CUDA before initializing the base class: the base class sets
        # up the worker side that ``call_async`` talks to, and raising after
        # that would leave it behind.
        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        if device_ids is None:
            device_ids = tuple(range(torch.cuda.device_count()))

        super().__init__(device_ids, multiprocessing_method)

        print("Initializing multi-gpu training with %d devices" %
              self.num_replicas)

        model = model.share_memory()
        # A single id shared by every replica's ``_async_init``.
        nccl_uid = nccl.get_unique_id()
        self.loss_function = loss_function

        Future.gen_list([
            self.call_async(rank,
                            '_async_init',
                            args=opt,
                            model=model,
                            loss_function=loss_function,
                            nccl_uid=nccl_uid)
            for rank in range(self.num_replicas)
        ])

        self._grads_initialized = False

        self.initialize_gradients()

        self.set_seed(opt.seed)

 def set_seed(self, seed):
     """Seed every replica process with the same ``seed``.

     Issues one ``_async_set_seed`` call per rank and passes the collected
     futures to ``Future.gen_list``.
     """
     seed_calls = []
     for replica in range(self.num_replicas):
         seed_calls.append(
             self.call_async(replica, '_async_set_seed', seed=seed))
     Future.gen_list(seed_calls)