def on_step(self, task) -> None:
     # Per-step hook that advances `self.current_samples` by the per-replica
     # batch size of the value produced by `recursive_copy_to_device`.
     #
     # NOTE(review): this excerpt is incomplete -- the guard below has no body
     # and the `recursive_copy_to_device(` call is missing its arguments and
     # closing parenthesis, so the function does not parse as shown. Restore
     # the elided lines from the original file before relying on it.

     # Guard for non-training tasks, or once `num_samples` has been reached
     # (presumably an early return -- the body is not visible here; confirm).
     if not task.train or self.current_samples >= self.num_samples:
     # NOTE(review): the local name `input` shadows the Python builtin.
     input = recursive_copy_to_device(
     # Count the samples contributed by this step.
     self.current_samples += get_batchsize_per_replica(input)
Example #2
    def on_step(self, task) -> None:
        # Per-step hook with two modes selected by `self.cache_samples`: one
        # that tracks `self.current_samples`, and one that records
        # `self.batch_size` from the last batch's "input" entry.
        #
        # NOTE(review): this excerpt is incomplete -- three guards have no
        # body, the `recursive_copy_to_device(` call has no arguments or
        # closing parenthesis, and the branch separating the two modes
        # (presumably an `else:`) is not visible. Restore the elided lines
        # from the original file before relying on it.

        # Guard for non-training tasks (body elided; presumably returns).
        if not task.train:

        if self.cache_samples:
            # Guard once `num_samples` has been reached (body elided).
            if self.current_samples >= self.num_samples:
            sample = recursive_copy_to_device(
            # Count the samples contributed by this step.
            self.current_samples += get_batchsize_per_replica(sample)

            # Guard for an already-recorded batch size (body elided).
            # NOTE(review): these lines look like they belong to the
            # non-caching branch -- confirm against the original.
            if self.batch_size is not None:

            # Record the per-replica batch size of the last batch's input.
            self.batch_size = get_batchsize_per_replica(task.last_batch.sample["input"])
Example #3
def _expand_dims(value: Any, dims: int) -> tuple:
    """Return ``value`` as a ``dims``-tuple; a lone int or 1-tuple is repeated."""
    if isinstance(value, int):
        return (value,) * dims
    value = tuple(value)
    return value * dims if len(value) == 1 else value


def _sliding_out_size(in_size: int, padding: int, kernel: int, stride: int) -> int:
    """Output length of a strided window op along one axis (dilation ignored)."""
    return (in_size + 2 * padding - kernel) // stride + 1


def _conv_flops(x: Any, conv: Any, batchsize: int, dims: int, divisor: int) -> float:
    """FLOPs of a ``dims``-dimensional convolution ``conv`` applied to ``x``."""
    volume = 1
    for d in range(dims):
        out_size = _sliding_out_size(
            x.size()[2 + d], conv.padding[d], conv.kernel_size[d], conv.stride[d]
        )
        volume *= conv.kernel_size[d] * out_size
    return batchsize * conv.in_channels * conv.out_channels * volume / divisor


def _pool_flops(x: Any, layer: Any, leading: int, dims: int) -> int:
    """FLOPs of a ``dims``-dimensional avg/max pooling layer applied to ``x``."""
    # Normalise into locals so the caller's module is never modified.
    kernel = _expand_dims(layer.kernel_size, dims)
    padding = _expand_dims(layer.padding, dims)
    stride = _expand_dims(layer.stride, dims)
    flops = leading
    for d in range(dims):
        out_size = _sliding_out_size(x.size()[2 + d], padding[d], kernel[d], stride[d])
        flops *= kernel[d] * out_size
    return flops


def _adaptive_pool_flops(x: Any, layer: Any, dims: int) -> int:
    """Approximate FLOPs of adaptive average pooling (downsampling only)."""
    in_sizes = tuple(x.size()[2 + d] for d in range(dims))
    requested = layer.output_size
    if requested is None or isinstance(requested, int):
        requested = (requested,) * dims
    else:
        requested = _expand_dims(requested, dims)
    # A ``None`` entry means "keep the input size" along that axis.
    out_sizes = tuple(i if o is None else o for o, i in zip(requested, in_sizes))
    if any(o > i for o, i in zip(out_sizes, in_sizes)):
        # Upsampling is not covered by this approximation.
        raise ClassyProfilerNotImplementedError(layer)
    flops = x.size()[0] * x.size()[1]
    for out_size, in_size in zip(out_sizes, in_sizes):
        # Each output element averages a window of (in - out + 1) inputs.
        flops *= out_size * (in_size - out_size + 1)
    return flops


def _layer_flops(layer: nn.Module, layer_args: List[Any], y: Any) -> int:
    """
    Computes the number of FLOPs required for a single layer.

    For common layers, such as Conv1d, the flop compute is implemented in this
    centralized place.
    For other layers, if it defines a method to compute flops with the signature
    below, we will use it to compute flops.

    Class MyModule(nn.Module):
        def flops(self, x):
            ...

    Args:
        layer: the module being profiled; its type is taken from the text
            before the first "(" in its ``repr``.
        layer_args: positional inputs of the layer's forward call. Only the
            first one is used for the built-in rules; all of them are
            forwarded to a custom ``layer.flops`` method.
        y: the layer's output, used only for the debug log line.

    Returns:
        The FLOP count, truncated to an int.

    Raises:
        ClassyProfilerNotImplementedError: if no rule applies to the layer, or
            an adaptive pooling layer would upsample its input.
    """
    # Local import keeps this block self-contained.
    import logging

    x = layer_args[0]
    # get layer type:
    typestr = layer.__repr__()
    layer_type = typestr[: typestr.find("(")].strip()
    batchsize_per_replica = get_batchsize_per_replica(x)

    flops = None
    # 1D convolution (x shape is N x C x W):
    if layer_type == "Conv1d":
        flops = _conv_flops(x, layer, batchsize_per_replica, 1, layer.groups)

    # 2D convolution:
    elif layer_type == "Conv2d":
        flops = _conv_flops(x, layer, batchsize_per_replica, 2, layer.groups)

    # learned group convolution:
    elif layer_type == "LearnedGroupConv":
        # relu and norm are element-wise, so x also stands in for their output
        # (the output is only used for logging in the recursive call).
        count1 = _layer_flops(layer.relu, [x], x) + _layer_flops(layer.norm, [x], x)
        count2 = _conv_flops(
            x, layer.conv, batchsize_per_replica, 2, layer.condense_factor
        )
        flops = count1 + count2

    # non-linearities:
    elif layer_type in ("ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax", "SiLU"):
        flops = x.numel()

    # 2D pooling layers:
    elif layer_type in ("AvgPool2d", "MaxPool2d"):
        flops = _pool_flops(x, layer, x.size()[0] * x.size()[1], 2)

    # adaptive avg pool2d
    # This is approximate and works only for downsampling without padding
    # based on aten/src/ATen/native/AdaptiveAveragePooling.cpp
    elif layer_type == "AdaptiveAvgPool2d":
        flops = _adaptive_pool_flops(x, layer, 2)

    # linear layer:
    elif layer_type == "Linear":
        weight_ops = layer.weight.numel()
        bias_ops = layer.bias.numel() if layer.bias is not None else 0
        flops = x.size()[0] * (weight_ops + bias_ops)

    # batch normalization / layer normalization:
    elif layer_type in (
        "BatchNorm1d",
        "BatchNorm2d",
        "BatchNorm3d",
        "SyncBatchNorm",
        "LayerNorm",
    ):
        flops = 2 * x.numel()

    # 3D convolution
    elif layer_type == "Conv3d":
        flops = _conv_flops(x, layer, batchsize_per_replica, 3, layer.groups)

    # 3D pooling layers
    elif layer_type in ("AvgPool3d", "MaxPool3d"):
        flops = _pool_flops(x, layer, batchsize_per_replica * x.size()[1], 3)

    # adaptive avg pool3d
    # This is approximate and works only for downsampling without padding
    # based on aten/src/ATen/native/AdaptiveAveragePooling3d.cpp
    elif layer_type == "AdaptiveAvgPool3d":
        flops = _adaptive_pool_flops(x, layer, 3)

    # dropout layer
    elif layer_type == "Dropout":
        # At test time, we do not drop values but scale the feature map by the
        # dropout ratio
        flops = 1
        for dim_size in x.size():
            flops *= dim_size

    elif layer_type == "Identity":
        flops = 0

    elif hasattr(layer, "flops"):
        # If the module already defines a method to compute flops with the signature
        # below, we use it to compute flops
        #   Class MyModule(nn.Module):
        #     def flops(self, x):
        #       ...
        #   or
        #   Class MyModule(nn.Module):
        #     def flops(self, x1, x2):
        #       ...
        flops = layer.flops(*layer_args)

    if flops is None:
        raise ClassyProfilerNotImplementedError(layer)

    message = [
        f"module type: {typestr}",
        f"input size: {get_shape(x)}",
        f"output size: {get_shape(y)}",
        f"params(M): {count_params(layer) / 1e6}",
        f"flops(M): {int(flops) / 1e6}",
    ]
    # Emit the per-layer summary at debug level.
    logging.debug("\t".join(message))
    return int(flops)