Example 1
    def compute_gradients(self,
                          postprocessed_batch: SampleBatch) -> ModelGradients:

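        # If not already zero-padded, pad the batch so that all sequences in it
        # have the same (max.) length (needed e.g. for recurrent models).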
        if not isinstance(postprocessed_batch, SampleBatch) or \
                not postprocessed_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                postprocessed_batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

        # Mark the batch as "is_training" so the Model can use this
        # information.
        postprocessed_batch.is_training = True

        # Single device case: Use batch as-is (no slicing).
        if len(self.devices) == 1:
            batches = [self._lazy_tensor_dict(postprocessed_batch)]
        # Multi-GPU case: Slice inputs into n (roughly) equal batches.
        else:
            len_ = len(postprocessed_batch)
            batches = []
            start = 0
            for i, device in enumerate(self.devices):
                shard_len = len_ // (len(self.devices) - i)
                batch = self._lazy_tensor_dict(
                    postprocessed_batch.slice(start, start + shard_len),
                    device=device)
                batches.append(batch)
                len_ -= shard_len
                start += shard_len

            # Copy weights of main model to all towers.
            state_dict = self.model.state_dict()
            for tower in self.model_gpu_towers:
                tower.load_state_dict(state_dict)

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc(batches)

        # Multi device (GPU) case.
        if len(self.devices) > 1:
            # Mean-reduce over GPU-towers.
            all_grads = []
            for i in range(len(tower_outputs[0][0])):
                if tower_outputs[0][0][i] is not None:
                    all_grads.append(
                        torch.mean(torch.stack(
                            [t[0][i].to(self.device) for t in tower_outputs]),
                                   dim=0))
                else:
                    all_grads.append(None)
            # Set main model's grads to mean-reduced values.
            for i, p in enumerate(self.model.parameters()):
                p.grad = all_grads[i]
            # Reduce stats over towers as well.
            from ray.rllib.execution.train_ops import all_tower_reduce
            grad_info = tree.map_structure_with_path(
                lambda p, *t: all_tower_reduce(p, *t),
                *[t[1] for t in tower_outputs])
        # Single device case.
        else:
            all_grads, grad_info = tower_outputs[0]

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(postprocessed_batch))

        fetches = self.extra_compute_grad_fetches()

        return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
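
The multi-GPU branch above computes each shard's size as shard_len = len_ // (len(self.devices) - i), i.e. the rows still to be assigned divided by the devices still to be served. A minimal standalone sketch of that arithmetic (the helper name split_lengths is made up for illustration and is not part of RLlib) shows that the shards come out roughly equal and always sum to the full batch length:

def split_lengths(total: int, num_devices: int) -> list:
    """Shard lengths produced by the remaining-rows // remaining-devices rule."""
    lengths = []
    remaining = total
    for i in range(num_devices):
        shard_len = remaining // (num_devices - i)
        lengths.append(shard_len)
        remaining -= shard_len
    return lengths

# Shards are (roughly) equal and always cover the whole batch.
assert split_lengths(10, 3) == [3, 3, 4]
assert sum(split_lengths(1000, 7)) == 1000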
Example 2
    def compute_gradients(self,
                          postprocessed_batch: SampleBatch) -> ModelGradients:

        # For multi-GPU, split the batch into n slices (n=#GPUs).
        if len(self.devices) == 1:
            batches = [postprocessed_batch]
        else:
            from ray.rllib.utils.sgd import minibatches
            batches = list(
                minibatches(postprocessed_batch,
                            len(postprocessed_batch) // len(self.devices),
                            shuffle=False))

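        # If the original batch was not already zero-padded, pad each shard so
        # that all sequences in it have the same (max.) length (needed e.g. for
        # recurrent models).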
        if not isinstance(postprocessed_batch, SampleBatch) or \
                not postprocessed_batch.zero_padded:
            for b in batches:
                pad_batch_to_sequences_of_same_size(
                    b,
                    max_seq_len=self.max_seq_len,
                    shuffle=False,
                    batch_divisibility_req=self.batch_divisibility_req,
                    view_requirements=self.view_requirements,
                )

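        # Mark each shard as training data and have its contents lazily
        # converted to torch tensors on the shard's assigned device.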
        for b, d in zip(batches, self.devices):
            b.is_training = True
            self._lazy_tensor_dict(b, device=d)

        # Multi-GPU case: Slice inputs into n (roughly) equal batches.
        if len(self.devices) > 1:
            # Copy weights of main model to all towers.
            state_dict = self.model.state_dict()
            for tower in self.model_gpu_towers:
                tower.load_state_dict(state_dict)

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc(batches)

        # Multi device (GPU) case.
        if len(self.devices) > 1:
            # Mean-reduce over GPU-towers.
            all_grads = []
            for i in range(len(tower_outputs[0][0])):
                if tower_outputs[0][0][i] is not None:
                    all_grads.append(
                        torch.mean(torch.stack(
                            [t[0][i].to(self.device) for t in tower_outputs]),
                                   dim=0))
                else:
                    all_grads.append(None)
            # Set main model's grads to mean-reduced values.
            for i, p in enumerate(self.model.parameters()):
                p.grad = all_grads[i]
            # Reduce stats over towers as well.
            from ray.rllib.execution.train_ops import all_tower_reduce
            grad_info = tree.map_structure_with_path(
                lambda p, *t: all_tower_reduce(p, *t),
                *[t[1] for t in tower_outputs])
        # Single device case.
        else:
            all_grads, grad_info = tower_outputs[0]

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(postprocessed_batch))

        fetches = self.extra_compute_grad_fetches()

        return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
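
Both examples share the same tail: per-parameter gradients from all towers are mean-reduced onto the main model, and the per-tower stats dicts are reduced as well. Below is a self-contained sketch of the gradient reduction (the function name mean_reduce_grads and the toy tensors are illustrative only and not part of RLlib; plain CPU torch is assumed), mirroring the torch.mean(torch.stack(...), dim=0) pattern above:

import torch

def mean_reduce_grads(tower_grads, device="cpu"):
    """tower_grads: one list of per-parameter grads (or None) per tower."""
    reduced = []
    for i in range(len(tower_grads[0])):
        if tower_grads[0][i] is not None:
            # Stack the i-th gradient of every tower and average element-wise.
            reduced.append(
                torch.mean(
                    torch.stack([g[i].to(device) for g in tower_grads]),
                    dim=0))
        else:
            # Parameters without a gradient stay None on the main model, too.
            reduced.append(None)
    return reduced

# Two "towers", two parameters each (the second parameter has no gradient).
tower_0 = [torch.tensor([1.0, 2.0]), None]
tower_1 = [torch.tensor([3.0, 4.0]), None]
print(mean_reduce_grads([tower_0, tower_1]))  # [tensor([2., 3.]), None]

The stats reduction goes through RLlib's all_tower_reduce helper via tree.map_structure_with_path. As a rough approximation of the idea only (the real helper may treat individual keys differently than a plain mean), averaging per-tower stats dicts with dm-tree looks like this:

import tree  # dm-tree, the same package RLlib uses

tower_stats = [{"policy_loss": 0.5, "vf_loss": 1.0},
               {"policy_loss": 0.7, "vf_loss": 0.8}]
mean_stats = tree.map_structure(lambda *vals: sum(vals) / len(vals), *tower_stats)
print(mean_stats)  # {'policy_loss': 0.6, 'vf_loss': 0.9}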