Example #1
    def forward(self, x):
        """ Encodes the given input into a q_phi(z|x) probability distribution,
        samples a latent vector from that distribution, and finally calls the decoder network.

        For compatibility, it returns zK_sampled = z_sampled and the log abs det jacobian(T) = 0.0
        (T = identity)

        :returns: z_mu_logvar, z_sampled, zK_sampled=z_sampled, logabsdetjacT=0.0, x_out (reconstructed spectrogram)
        """
        with profiler.record_function(
                "ENCODING") if self.is_profiled else contextlib.nullcontext():
            z_mu_logvar = self.encoder(x)
            n_minibatch = z_mu_logvar.size()[0]
            mu = z_mu_logvar[:, 0, :]
            sigma = torch.exp(z_mu_logvar[:, 1, :] / 2.0)
        with profiler.record_function(
                "LATENT_SAMPLING"
        ) if self.is_profiled else contextlib.nullcontext():
            if self.training:
                # Sampling from the q_phi(z|x) probability distribution - with re-parametrization trick
                eps = Normal(
                    torch.zeros(n_minibatch, self.dim_z, device=mu.device),
                    torch.ones(n_minibatch, self.dim_z,
                               device=mu.device)).sample()
                z_sampled = mu + sigma * eps
            else:  # eval mode: no random sampling
                z_sampled = mu
        with profiler.record_function(
                "DECODING") if self.is_profiled else contextlib.nullcontext():
            x_out = self.decoder(z_sampled)
        return z_mu_logvar, z_sampled, z_sampled, torch.zeros(
            (z_sampled.shape[0], 1), device=x.device), x_out
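The forward pass above only enters `profiler.record_function(...)` when `self.is_profiled` is set, and falls back to `contextlib.nullcontext()` otherwise, so profiling can be switched off without touching the code. A minimal, self-contained sketch of that idiom (the toggle and module here are illustrative, not from the project above):

import contextlib
import torch
import torch.autograd.profiler as profiler

is_profiled = True  # flip to False to make the record_function block a no-op
linear = torch.nn.Linear(16, 4)
x = torch.randn(8, 16)

with profiler.profile() as prof:
    # The labelled range is only recorded when profiling is requested.
    with profiler.record_function("ENCODING") if is_profiled else contextlib.nullcontext():
        y = linear(x)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))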
Example #2
File: oss.py  Project: ncilfone/fairscale
    def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
        """Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note: Any extra parameter is passed to the base optimizer as-is"""

        # Sync oss param_groups attributes in case they've been updated by a scheduler.
        OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

        # Catch a possible change of devices in between OSS construction and step()
        with profiler.record_function("fairscale::oss::refresh_trainable"):
            if self._default_device.type != self.param_groups[0]["params"][0].device.type:
                logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
                self._clear_cache()
                self.refresh_trainable()

        # Run the optimizer step on this shard only:
        with profiler.record_function("fairscale::oss::optim_step"):
            if closure is not None:
                loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
            else:
                loss = self.optim.step(**kwargs)

        # Sync all the updated shards in between the ranks
        self._broadcast_params()

        # Sync hypothetical new results from the wrapped optimizer to the exposed param_groups
        OSS._sync_param_groups(self.optim.param_groups, self.param_groups)

        return loss
Example #3
    def calculate_loss(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        label = interaction[self.LABEL]

        with profiler.record_function("REC output and loss"):
            output_rec = self.forward_rec(user, item)
            loss_rec = self.loss_rec(output_rec, label)

        with profiler.record_function("LM output"):
            output_lm = self.forward_lm(item)

        if self.variant == 3:
            label_lm = self.lm_gt[item].to_dense().to(self.device)

        if self.variant == 2:
            label_lm = self.lm_gt[item].to(self.device).to_dense()

        if self.variant == 1:
            with profiler.record_function("LM making label on GPU"):
                label_lm = self.lm_gt[item].to_dense()
                # label_lm = torch.zeros(len(item), self.vocab_size, device=self.device)
                # for i in range(len(item)):
                #     item_id = item[i]
                #     label_lm[i] = self.lm_gt[item_id].to_dense()

        with profiler.record_function("LM loss"):
            loss_lm = self.loss_lm(output_lm, label_lm)

        return loss_rec, self.alpha * loss_lm
Example #4
def benchmark_model(model, optimizer, parameters, name):
    # Run
    step_times = []
    # Autograd profiler adds some overhead, so we time the forward pass with
    # and without enabling it.
    for profile_autograd in [False, True]:
        with profiler.profile(enabled=profile_autograd, use_cuda=(device == "cuda")) as prof:
            for i in range(15):
                # Warm up for five steps, reset step_times after this.
                if i == 5:
                    step_times = []
                with profiler.record_function("forward"):
                    loss = model(x).sum()
                with profiler.record_function("backward"):
                    loss.backward()
                if device == 'cuda':
                    torch.cuda.synchronize()
                start = time.time()
                with profiler.record_function("gradient_norm"):
                    torch.nn.utils.clip_grad_norm_(parameters, 0.1)
                with profiler.record_function("step"):
                    optimizer.step()
                with profiler.record_function("zero_grad"):
                    optimizer.zero_grad()
                if device == 'cuda':
                    torch.cuda.synchronize()
                step_times.append(time.time() - start)
            print(f"Mean step time: {sum(step_times) / 10} seconds. "
                  f"(Autograd profiler enabled: {profile_autograd})")
    prof.export_chrome_trace(f"{name}_timeline.json")
Example #5
def func(model, device, input_size, epochs=100, dendrite=False):
    batch_size = 4096
    use_cuda = device.type == "cuda"
    dummy_tensor = torch.rand((batch_size, input_size), device=device)
    wall_clock = 0.0
    for _ in range(epochs):
        if dendrite:
            dummy_context = torch.rand((batch_size, model.dim_context), device=device)

            s = time.time()
            with profiler.profile(record_shapes=True, use_cuda=use_cuda) as prof:
                with profiler.record_function("model_inference"):
                    res = model(dummy_tensor, dummy_context)
        else:
            s = time.time()
            with profiler.profile(record_shapes=True, use_cuda=use_cuda) as prof:
                with profiler.record_function("model_inference"):
                    res = model(dummy_tensor)

        wall_clock += time.time() - s
    print("Wall clock:", wall_clock / epochs)
    if device.type == "cuda":
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    else:
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

    dense_params, sparse_params = count_nonzero_params(model)
    print(f"Total params:{dense_params}, non-zero params:{sparse_params}")

    if res.sum() == 0:  # Just to make Python think we need res
        print(res.sum())

    return wall_clock / epochs
Example #6
    def calculate_loss(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        label = interaction[self.LABEL]

        with profiler.record_function("REC output and loss"):
            output_rec = self.forward_rec(user, item)
            loss_rec = self.loss_rec(output_rec, label)

        with profiler.record_function("LM output"):
            output_lm = self.forward_lm(item)

        if self.variant == 1:
            with profiler.record_function("LM making tensor on GPU"):
                label_lm_k = self.lm_gt[item].to(device=self.device)
                label_lm_len = self.lm_gt_len[item].to(device=self.device)
                label_lm = (label_lm_k.T / label_lm_len).T
        elif self.variant == 2:
            with profiler.record_function("LM making tensor on CPU"):
                label_lm_k = self.lm_gt[item]
                label_lm_len = self.lm_gt_len[item]
                label_lm_k = label_lm_k.T / label_lm_len
                label_lm = label_lm_k.to(device=self.device).T

        with profiler.record_function("LM loss"):
            loss_lm = self.loss_lm(output_lm, label_lm)

        return loss_rec, self.alpha * loss_lm
Example #7
    def calculate_loss(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        label = interaction[self.LABEL]

        with profiler.record_function("REC output and loss"):
            output_rec = self.forward_rec(user, item)
            loss_rec = self.loss_rec(output_rec, label)

        with profiler.record_function("LM output"):
            output_lm = self.forward_lm(item)

        if self.variant == 3:
            label_lm = self.lm_gt[item].to_dense().to(self.device)

        if self.variant == 2:
            label_lm = self.lm_gt[item].to(self.device).to_dense()

        if self.variant == 1:
            with profiler.record_function("LM making label on GPU"):
                label_lm = self.lm_gt[item].to_dense()

        with profiler.record_function("LM loss"):
            loss_lm = self.loss_lm(output_lm, label_lm)

        return loss_rec, self.alpha * loss_lm
Example #8
    def index(self, uv, cam_z=None, image_size=(), z_bounds=None):
        """
        Get pixel-aligned image features at 2D image coordinates
        :param uv (B, N, 2) image points (x,y)
        :param cam_z ignored (for compatibility)
        :param image_size image size, either (width, height) or single int.
        if not specified, assumes coords are in [-1, 1]
        :param z_bounds ignored (for compatibility)
        :return (B, L, N) L is latent size
        """
        with profiler.record_function("encoder_index"):
            if uv.shape[0] == 1 and self.latent.shape[0] > 1:
                uv = uv.expand(self.latent.shape[0], -1, -1)

            with profiler.record_function("encoder_index_pre"):
                if len(image_size) > 0:
                    if len(image_size) == 1:
                        image_size = (image_size, image_size)
                    scale = self.latent_scaling / image_size
                    uv = uv * scale - 1.0

            uv = uv.unsqueeze(2)  # (B, N, 1, 2)
            samples = F.grid_sample(
                self.latent,
                uv,
                align_corners=True,
                mode=self.index_interp,
                padding_mode=self.index_padding,
            )
            return samples[:, :, :, 0]  # (B, C, N)
Example #9
    def forward(self, input, mask):
        with profiler.record_function("LABEL1: linear pass"):
            out = self.linear(input)

        with profiler.record_function("LABEL2: masking"):
            threshold = out.sum(axis=1).mean()  # removed .item()
            hi_idx = (mask > threshold).nonzero(as_tuple=True)
        return out, hi_idx
Example #10
 def background_idf(self, term):
     with profiler.record_function("jterm"):
         jterm = self.JTerm("contents", term)
     with profiler.record_function("df"):
         df = self.reader.docFreq(jterm)
     with profiler.record_function("idf"):
         idf = np.log10((self.numdocs - df + 0.5) / (df + 0.5))
     return idf
Example #11
    def forward(self, input, mask):
        with profiler.record_function("LINEAR PASS"):
            out = self.linear(input)

        with profiler.record_function("MASK INDICES"):
            threshold = out.sum(axis=1).mean()
            hi_idx = (mask > threshold).nonzero(as_tuple=True)

        return out, hi_idx
Example #12
    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        with profiler.record_function("Encoder_out"):
            encoder_out = self.encoder(src_tokens=src_tokens,
                                       src_lengths=src_lengths)

        with profiler.record_function("Decoder_out"):
            decoder_out = self.decoder(prev_output_tokens=prev_output_tokens,
                                       encoder_out=encoder_out)
        return decoder_out
Example #13
    def forward(self, x, mask):
        with profiler.record_function('Linear'):
            out = self.linear(x)

        with profiler.record_function('Mask'):
            threshold = out.sum(dim=1).mean().item()
            idx = np.argwhere(mask.cpu().numpy() > threshold)
            idx = torch.from_numpy(idx).cuda()

        return out, idx
Example #14
    def forward(self, input, mask):
        with profiler.record_function("LINEAR PASS"):
            out = self.linear(input)

        with profiler.record_function("MASK INDICES"):
            threshold = out.sum(axis=1).mean().item()
            hi_idx = np.argwhere(mask.cpu().numpy() > threshold)
            hi_idx = torch.from_numpy(hi_idx).cuda()

        return out, hi_idx
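Examples #11, #13, and #14 are variants of the same linear-plus-masking forward pass; a minimal sketch of how such a module is usually driven and reported, with illustrative names and shapes (not taken from any of the quoted projects):

import torch
import torch.autograd.profiler as profiler

class MaskedLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, input, mask):
        with profiler.record_function("LINEAR PASS"):
            out = self.linear(input)
        with profiler.record_function("MASK INDICES"):
            threshold = out.sum(axis=1).mean()
            hi_idx = (mask > threshold).nonzero(as_tuple=True)
        return out, hi_idx

model = MaskedLinear(500, 10)
input = torch.rand(128, 500)
mask = torch.rand(128, 10)

with profiler.profile(profile_memory=True, record_shapes=True) as prof:
    out, hi_idx = model(input, mask)

# Each record_function label appears as its own row in the report.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))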
Example #15
    def forward(self, x, sample_info=None):
        """ Encodes the given input into a q_Z0(z_0|x) probability distribution,
        samples a latent vector from that distribution,
        transforms it into q_ZK(z_K|x) using an invertible normalizing flow,
        and finally calls the decoder network using the z_K samples.

        :param x: Single- or Multi-channel spectrogram tensor
        :param sample_info: Required for MIDI pitch and velocity to be appended to the latent vector. On the last dim,
            index 0 should be a preset UID, index 1 a MIDI pitch, index 2 a MIDI velocity.

        :returns: z0_mu_logvar, z0_sampled, zK_sampled, logabsdetjacT, x_out (reconstructed spectrogram)
        """
        n_minibatch = x.size()[0]
        with profiler.record_function(
                "ENCODING") if self.is_profiled else contextlib.nullcontext():
            # Don't ask for requires_grad or this tensor becomes a leaf variable (it will require grad later)
            z_0_mu_logvar = torch.empty((n_minibatch, 2, self.dim_z),
                                        device=x.device,
                                        requires_grad=False)
            if not self.concat_midi_to_z0:
                z_0_mu_logvar = self.encoder(x)
            else:  # insert midi notes if required
                z_0_mu_logvar[:, :, 2:] = self.encoder(x)
                if sample_info is None:  # missing MIDI notes are tolerated for graphs and summaries
                    z_0_mu_logvar[:, :, [0, 1]] = 0.0
                else:  # MIDI pitch and velocity models: free-mean and unit-variance scaled in [-1.0, 1.0]
                    # TODO extend this to work with multiple MIDI notes?
                    # Mean is simply scaled to [-1.0, 1.0] (min/max normalization)
                    midi_pitch_and_vel_mu = -1.0 + 2.0 * sample_info[:, [
                        1, 2
                    ]].float() / 127.0
                    z_0_mu_logvar[:, 0, [0, 1]] = midi_pitch_and_vel_mu
                    # log(var) corresponds to a unit standard deviation in the original [0, 127] MIDI domain
                    z_0_mu_logvar[:, 1, [0, 1]] = np.log(4.0 / (127**2))
            # Separate mean and standard deviation
            mu0 = z_0_mu_logvar[:, 0, :]
            sigma0 = torch.exp(z_0_mu_logvar[:, 1, :] / 2.0)
        with profiler.record_function(
                "LATENT_FLOW") if self.is_profiled else contextlib.nullcontext(
                ):
            if self.training:
                # Sampling from the q_phi(z|x) probability distribution - with re-parametrization trick
                eps = Normal(
                    torch.zeros(n_minibatch, self.dim_z, device=mu0.device),
                    torch.ones(n_minibatch, self.dim_z,
                               device=mu0.device)).sample()
                z_0_sampled = mu0 + sigma0 * eps
            else:  # eval mode: no random sampling
                z_0_sampled = mu0
            # Forward flow (fast with nflows MAF implementation - always fast with RealNVP)
            z_K_sampled, log_abs_det_jac = self.flow_transform(z_0_sampled)
        with profiler.record_function(
                "DECODING") if self.is_profiled else contextlib.nullcontext():
            x_out = self.decoder(z_K_sampled)
        return z_0_mu_logvar, z_0_sampled, z_K_sampled, log_abs_det_jac, x_out
Example #16
    def start(self, action_name: str) -> None:
        if self.profiler is None:
            # close profiler if it is already opened. might happen if 2 profilers
            # are created and the first one did not call `describe`
            if torch.autograd._profiler_enabled():
                torch.autograd._disable_profiler()

            if self._schedule is not None:
                self._schedule.setup(action_name)

            self._create_profilers()

            profiler = self.profiler.__enter__()
            if profiler is not None:
                self.profiler = profiler

            if self._parent_profiler is not None:
                self._parent_profiler.__enter__()

        if self._lightning_module is not None and self._register is None and self._record_module_names:
            self._register = RegisterRecordFunction(self._lightning_module)
            self._register.__enter__()

        if self.profiler is not None and action_name not in self._recording_map:
            # Add [pl][profile] in name for pytorch profiler to recognize
            recording = record_function("[pl][profile]" + action_name)
            recording.__enter__()
            self._recording_map[action_name] = recording
Example #17
def _broadcast(values, group):
    with torch.no_grad():
        coalesced = get_apex_wrapper().flatten(values)
        with record_function("torch.distributed.broadcast"):
            dist.broadcast(coalesced, 0, group=group)
        get_apex_wrapper().multi_tensor_scale(
            get_apex_wrapper().unflatten(coalesced, values), values, 1.0)
Example #18
def dummy_train_3x1(device):
    model = nn.Sequential(
        spnn.Conv3d(4, 32, kernel_size=(3, 1, 3), stride=1),
        spnn.Conv3d(32, 64, kernel_size=(1, 3, 3), stride=1),
        spnn.Conv3d(64, 128, kernel_size=(3, 1, 3), stride=1),
        spnn.Conv3d(128, 256, kernel_size=(1, 3, 3), stride=1),
        spnn.Conv3d(256, 128, kernel_size=(3, 1, 3), stride=1, transpose=True),
        spnn.Conv3d(128, 64, kernel_size=(1, 3, 3), stride=1, transpose=True),
        spnn.Conv3d(64, 32, kernel_size=(3, 1, 3), stride=1, transpose=True),
        spnn.Conv3d(32, 10, kernel_size=(1, 3, 3), stride=1, transpose=True),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss().to(device)

    print('Starting dummy_train_3x1 ...')
    time = datetime.now()
    with profiler.profile(profile_memory=True, use_cuda=True) as prof:
        with profiler.record_function("model_inference"):
            for i in range(10):
                feed_dict = generate_batched_random_point_clouds()
                inputs = feed_dict['lidar'].to(device)
                targets = feed_dict['targets'].F.to(device).long()
                outputs = model(inputs)
                optimizer.zero_grad()
                loss = criterion(outputs.F, targets)
                loss.backward()
                optimizer.step()
                # print('[step %d] loss = %f.'%(i, loss.item()))
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    prof.export_chrome_trace("trace_dummy_3x1.json")

    time = datetime.now() - time
    print('Finished dummy_train_3x1 in ', time)
Example #19
    def start(self, action_name: str) -> None:
        if self.profiler is None and action_name in self._record_functions_start:

            # close profiler if it is already opened. might happen if 2 profilers
            # are created and the first one did not call `describe`
            try:
                torch.autograd._disable_profiler()  # noqa
            except (AttributeError, RuntimeError):
                pass

            if self._schedule is not None:
                self._schedule.setup(action_name)

            self._create_profilers()

            profiler = self.profiler.__enter__()
            if profiler is not None:
                self.profiler = profiler

            if self._parent_profiler is not None:
                self._parent_profiler.__enter__()

            if self._register is not None:
                self._register.__enter__()

        if (self.profiler is not None
                and (action_name in self._record_functions
                     or action_name.startswith(self.RECORD_FUNCTION_PREFIX))
                and action_name not in self._recording_map):
            recording = record_function(action_name)
            recording.__enter__()
            self._recording_map[action_name] = recording
Example #20
def eval():
    data_loader, _ = get_data_loader()
    lstm = load_model().to(device)
    acc = 0
    profiler_x = None

    lstm.eval()
    with torch.no_grad():
        for d in data_loader:
            x = d['x'].to(device)
            label = d['label'].to(device)
            if profiler_x is None:
                profiler_x = x
            pred = lstm(x).to(device)

            _acc = torch.true_divide((torch.argmax(pred, 1) == label).double().sum().item(), 8)
            acc += _acc.item()

        print(torch.true_divide(acc, len(data_loader)), len(data_loader))

    with profiler.profile(record_shapes=True, profile_memory=True, use_cuda=True) as prof:
        with profiler.record_function("model_inference"):
            lstm(profiler_x)
    print(prof.key_averages().table(row_limit=100))
    prof.export_chrome_trace("trace.json")
Example #21
    def step(self):
        """
        Signals the profiler that the next profiling step has started.
        """
        if self.record_steps and self.step_rec_fn:
            self.step_rec_fn.__exit__(None, None, None)
        prev_action = self.current_action
        self.step_num += 1
        self.current_action = self.schedule(self.step_num)

        if self.current_action == ProfilerAction.NONE:
            if prev_action == ProfilerAction.NONE:
                pass
            elif prev_action == ProfilerAction.WARMUP:
                warn("Incorrect schedule: WARMUP followed by NONE")
                self._start_trace()
                self._stop_trace()
            elif prev_action == ProfilerAction.RECORD:
                warn("Incorrect schedule: RECORD followed by NONE")
                self._stop_trace()
            else:
                assert prev_action == ProfilerAction.RECORD_AND_SAVE
                self._stop_trace()
                if self.on_trace_ready:
                    self.on_trace_ready(self)
        elif self.current_action == ProfilerAction.WARMUP:
            if prev_action == ProfilerAction.NONE:
                self._start_warmup()
            elif prev_action == ProfilerAction.WARMUP:
                pass
            elif prev_action == ProfilerAction.RECORD:
                warn("Incorrect schedule: RECORD followed by WARMUP")
                self._stop_trace()
            else:
                assert prev_action == ProfilerAction.RECORD_AND_SAVE
                self._stop_trace()
                if self.on_trace_ready:
                    self.on_trace_ready(self)
                self._start_warmup()
        elif self.current_action in \
                [ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE]:
            if prev_action == ProfilerAction.NONE:
                self._start_warmup()
                self._start_trace()
            elif prev_action == ProfilerAction.WARMUP:
                self._start_trace()
            elif prev_action == ProfilerAction.RECORD:
                pass
            else:
                assert prev_action == ProfilerAction.RECORD_AND_SAVE
                self._stop_trace()
                if self.on_trace_ready:
                    self.on_trace_ready(self)
                self._start_warmup()
                self._start_trace()

        if self.record_steps:
            self.step_rec_fn = prof.record_function("ProfilerStep#" +
                                                    str(self.step_num))
            self.step_rec_fn.__enter__()
Example #22
 def __enter__(self):
     self._enter_actions()
     if self.record_steps:
         self.step_rec_fn = prof.record_function("ProfilerStep#" +
                                                 str(self.step_num))
         self.step_rec_fn.__enter__()
     return self
Example #23
    def refresh_trainable(self) -> None:
        """ If the module trainability has changed, update all the assumptions """

        # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
        if functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced,
                            False):
            logging.warning(
                "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
            )

        with profiler.record_function("fairscale::sdp::refresh_trainable"):
            self._trainable_params = list(
                filter(lambda x: x.requires_grad, self._all_params))
            self._trainable_params.sort(key=lambda x: x.numel())

            self._trainable_param_to_rank = {}
            for optim in self._sharded_optimizers:
                # OSS may need to change the communication pattern
                optim.refresh_trainable()

                # Update ShardedDDP given the new partitions
                for (device_per_rank_params
                     ) in optim._per_device_params.values(
                     ):  # all the params on this device (inc all ranks)
                    for device_params in device_per_rank_params:
                        for param in filter(lambda x: x.requires_grad,
                                            device_params):
                            self._trainable_param_to_rank[
                                param] = optim._param_to_rank[param]

            self._setup_bucket_strategy()
            self._setup_backward_hooks()
Example #24
    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        """
        Module forward pass, handles any DDP-specific work in the background. Primes the
        backward pass for gradient reduction to the proper ranks.
        """

        with profiler.record_function("fairscale::sdp::forward"):
            # Deferred initialization, or change detection
            needs_setup = len(self._grad_hooks) == 0 and self.training

            if self._auto_refresh_trainable:
                needs_setup |= self._detect_train_change()

            if needs_setup:
                self.refresh_trainable()

            if self._enable_broadcast_buffers:
                # NCCL communications are on a different stream, needs to be blocking
                # for the subsequent FW to be correct
                self.sync_buffers(blocking=True)

            # Reset all the grad reduce and bucket state flags
            self._clear_counters()

            # Normal FW on the base model
            return self.module(*inputs, **kwargs)
Example #25
def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor:
    t: Tensor = torch.ones(1)
    with record_function(block) as rf:
        fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
        fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
        res = fut1.wait() + fut2.wait()
    return res
Example #26
    def _benchmark_op(self, op: Callable, args: List, kwargs: Dict[str, Any],
                      tag: str, label_str: str) -> Tuple[float, float]:
        logger.debug(f"benchmarking {label_str}")
        gpu_memory = 0
        timer = Timer(self.device)

        # flush cache
        if self.use_cuda:
            if not self.cuda_l2_cache:
                _clear_cache()
            # Reset to measure peak memory usage
            torch.cuda.reset_peak_memory_stats()

        with record_function(label_str):
            timer.start()
            if self.use_cuda:
                op_run_id_range = torch.cuda.nvtx.range_start(label_str)
            op(*args, **kwargs)
            timer.stop()
            if self.use_cuda:
                torch.cuda.nvtx.range_end(op_run_id_range)
                # Memory size in MB
                gpu_memory = torch.cuda.max_memory_allocated() / (1048576)

        # Return result in milliseconds.
        return timer.elapsed_time_ms(), gpu_memory
Example #27
 def _start_recording_forward(self, _: nn.Module, input: Tensor,
                              record_name: str) -> Tensor:
     # Add [pl][module] in name for pytorch profiler to recognize
     record = record_function("[pl][module]" + record_name)
     record.__enter__()
     self._records[record_name] = record
     return input
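Examples #16, #19, #22, and #27 open `record_function` ranges with `__enter__` so a range started in one callback can be closed later from another. A minimal sketch of that pairing (the function and dictionary names are illustrative, not Lightning's actual internals):

from torch.profiler import record_function

_records = {}

def start_region(name: str) -> None:
    # Open the labelled range; it stays open until stop_region() is called.
    rec = record_function(name)
    rec.__enter__()
    _records[name] = rec

def stop_region(name: str) -> None:
    # Close the matching range; call exactly once per start_region().
    _records.pop(name).__exit__(None, None, None)

start_region("my_region")
# ... profiled work would run here ...
stop_region("my_region")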
Example #28
    def forward(ctx, a2a_info, *inputs):
        global myreq
        with record_function("DLRM alltoall_req_fwd_single"):
            batch_split_lengths = a2a_info.global_batch_partition_slices
            if batch_split_lengths:
                batch_split_lengths = [
                    m * a2a_info.emb_dim * a2a_info.local_table_num
                    for m in batch_split_lengths
                ]
            table_split_lengths = a2a_info.global_table_wise_parition_slices
            if table_split_lengths:
                table_split_lengths = [
                    a2a_info.local_batch_num * e * a2a_info.emb_dim
                    for e in table_split_lengths
                ]
            input = torch.cat(inputs, dim=1).view([-1])
            output = input.new_empty([
                a2a_info.global_table_num * a2a_info.local_batch_num *
                a2a_info.emb_dim
            ])
            req = dist.all_to_all_single(output,
                                         input,
                                         table_split_lengths,
                                         batch_split_lengths,
                                         async_op=True)

            myreq.req = req
            myreq.tensor = []
            myreq.tensor.append(output)
            myreq.tensor = tuple(myreq.tensor)
            a2a_info.batch_split_lengths = batch_split_lengths
            a2a_info.table_split_lengths = table_split_lengths
            myreq.a2a_info = a2a_info
            ctx.a2a_info = a2a_info
            return myreq.tensor
Example #29
 def block(self,
           tag: str,
           time_tag: Optional[str] = None,
           monitor_gpu_utils: bool = False):
     self.debug(f"start {tag}")
     time_tag = time_tag or tag
     if not torch.cuda.is_available():
         monitor_gpu_utils = False
     if monitor_gpu_utils:
         torch.cuda.synchronize()
     begin = time.time()
     try:
         with record_function(f"{self.name}:{tag}"):
             yield
     finally:
         self.debug(f"finish {tag}")
         if time_tag not in self.elapsed_time_log:
             self.elapsed_time_log[time_tag] = (0, 0.0)
         if monitor_gpu_utils and time_tag not in self.gpu_elapsed_time_log:
             self.gpu_elapsed_time_log[time_tag] = (0, 0.0)
         elapsed_time = time.time() - begin
         n, t = self.elapsed_time_log[time_tag]
         self.elapsed_time_log[time_tag] = (n + 1, t + elapsed_time)
         if monitor_gpu_utils:
             torch.cuda.synchronize()
             gpu_elapsed_time = time.time() - begin
             n, t = self.gpu_elapsed_time_log[time_tag]
             self.gpu_elapsed_time_log[time_tag] = \
                 (n + 1, t + gpu_elapsed_time)
Example #30
File: oss.py  Project: ncilfone/fairscale
    def clip_grad_norm(
        self,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        filter_params_fn: Callable[[Any], Any] = None,
    ) -> torch.Tensor:
        """
        Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.

        Arguments:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.

        Returns:
            Total norm of the parameters (viewed as a single vector).

        .. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple devices per rank
            under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads
            in the model, so calling it in the OSS context would lead to different scaling being applied per subset of model parameters

        .. warning: This needs to be called on all ranks, since synchronization primitives will be used

        """

        # Compute the max norm for this shard's worth of gradients
        max_norm = float(max_norm)
        norm_type = float(norm_type)

        with profiler.record_function("fairscale::oss::clip_grad_norm"):
            # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
            # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
            # 'model_parallel' flag is set in Megatron-LM:
            # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
            local_params = filter_params_fn(self._local_params) if filter_params_fn is not None else self._local_params

            local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
            # Compute the norm on this grad set,
            # then sync all the norms from all ranks
            if norm_type == inf:
                total_norm = local_norm
                # all reduce over data parallel and model parallel workers
                dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
            else:
                # local norm result can be accumulated with the remote ones if put to the right power
                # n_i = sum_rank(a^p)^1/p
                # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
                # all reduce over data parallel and model parallel workers
                total_norm = local_norm ** norm_type
                dist.all_reduce(total_norm)
                total_norm = total_norm ** (1.0 / norm_type)

            clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
            if clip_coef < 1:
                for device, device_params in self._per_device_params.items():
                    for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
                        p.grad.detach().mul_(clip_coef.to(device))  # type: ignore   # mypy trips on the filter

        return total_norm
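The comment inside the non-inf branch relies on the identity that per-shard p-norms can be combined as (sum_i n_i^p)^(1/p) to recover the global p-norm. A small single-process sanity check of that identity (the shard layout and names are made up for illustration):

import torch

p = 2.0
grads = [torch.randn(10) for _ in range(4)]   # pretend: all gradients of the model
shards = [grads[:2], grads[2:]]               # pretend: two ranks, each owning a shard

# Per-shard norms, combined the same way the all_reduce above combines them.
local_norms = [torch.linalg.vector_norm(torch.cat(s), ord=p) for s in shards]
combined = sum(n ** p for n in local_norms) ** (1.0 / p)

# Reference: the norm computed over all gradients at once.
global_norm = torch.linalg.vector_norm(torch.cat(grads), ord=p)
print(torch.allclose(combined, global_norm))  # True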