Example #1
def add_dense_layer(network, input, weight_map, lname):
    bn1 = add_batch_norm_2d(network, weight_map, input, lname + ".norm1")

    relu1 = network.add_activation(bn1.get_output(0),
                                   type=trt.ActivationType.RELU)
    assert relu1

    conv1 = network.add_convolution(input=relu1.get_output(0),
                                    num_output_maps=128,
                                    kernel_shape=(1, 1),
                                    kernel=weight_map[lname + ".conv1.weight"],
                                    bias=trt.Weights())
    assert conv1
    conv1.stride = (1, 1)

    bn2 = add_batch_norm_2d(network, weight_map, conv1.get_output(0),
                            lname + ".norm2")

    relu2 = network.add_activation(bn2.get_output(0),
                                   type=trt.ActivationType.RELU)
    assert relu2

    conv2 = network.add_convolution(input=relu2.get_output(0),
                                    num_output_maps=32,
                                    kernel_shape=(3, 3),
                                    kernel=weight_map[lname + ".conv2.weight"],
                                    bias=trt.Weights())
    assert conv2
    conv2.stride = (1, 1)
    conv2.padding = (1, 1)

    return conv2
Example #2
def load_onnx_weights_and_quant(path, config):
    """
    Load the weights from the onnx checkpoint
    """
    N = config.num_attention_heads
    H = config.head_size
    hidden_size = config.hidden_size

    model = onnx.load(path)
    weights = model.graph.initializer
    tensor_dict = dict([(onnx_to_trt_name(w.name),
                         np.frombuffer(w.raw_data, np.float32).reshape(w.dims))
                        for w in weights])

    weights_dict = dict()
    for outname, tensor in tensor_dict.items():
        if outname.find("_amax") != -1:
            weights_dict[outname] = tensor
        elif outname.find(BQ) != -1:
            prefix = outname[:outname.find(BQ)]

            Wqkv = np.zeros((3, hidden_size, hidden_size), np.float32)
            Bqkv = np.zeros((3, hidden_size), np.float32)

            Wqkv[0, :, :] = tensor_dict[prefix + WQ]
            Wqkv[1, :, :] = tensor_dict[prefix + WK]
            Wqkv[2, :, :] = tensor_dict[prefix + WV]
            Bqkv[0, :] = tensor
            Bqkv[1, :] = tensor_dict[prefix + BK]
            Bqkv[2, :] = tensor_dict[prefix + BV]

            if config.use_int8 and config.interleaved:
                Wqkv = np.ascontiguousarray(Wqkv.reshape((3, N, H, N, H)))
                Bqkv = np.ascontiguousarray(Bqkv.reshape((3, N, H)))
            else:
                Wqkv = np.ascontiguousarray(
                    Wqkv.reshape((3, N, H, N, H)).transpose((1, 0, 2, 3, 4)))
                Bqkv = np.ascontiguousarray(
                    Bqkv.reshape((3, N, H)).transpose((1, 0, 2)))

            weights_dict[prefix + WQKV] = trt.Weights(Wqkv)
            weights_dict[prefix + BQKV] = trt.Weights(Bqkv)
            weights_dict[prefix + WQKV + "_notrans"] = trt.Weights(Wqkv.T)

        elif outname.find(BK) != -1 or outname.find(BV) != -1 or outname.find(
                WQ) != -1 or outname.find(WK) != -1 or outname.find(WV) != -1:
            pass
        else:
            flat_tensor = np.ascontiguousarray(tensor).flatten()
            weights_dict[outname] = trt.Weights(flat_tensor)

            if outname.find("kernel") != -1:
                tensor = np.transpose(tensor)
                weights_dict[outname + "_notrans"] = trt.Weights(
                    np.ascontiguousarray(tensor).flatten())

    TRT_LOGGER.log(TRT_LOGGER.INFO,
                   "Found {:} entries in weight map".format(len(weights_dict)))
    return weights_dict
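A minimal numpy sketch (not from the original source) of what the non-interleaved QKV packing above produces: the query, key and value matrices are stacked, then regrouped so the leading axis is the attention head and each head carries its own [Q, K, V] block. All sizes here are made up for illustration.

import numpy as np

N, H = 2, 4                # num_attention_heads, head_size (illustrative values)
hidden = N * H

Wq = np.random.rand(hidden, hidden).astype(np.float32)
Wk = np.random.rand(hidden, hidden).astype(np.float32)
Wv = np.random.rand(hidden, hidden).astype(np.float32)

Wqkv = np.stack([Wq, Wk, Wv])                                  # (3, hidden, hidden)
packed = Wqkv.reshape(3, N, H, N, H).transpose(1, 0, 2, 3, 4)  # (N, 3, H, N, H)
print(packed.shape)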
Example #3
def acc_ops_layer_norm(network, target, args, kwargs, name):
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(f"LayerNorm received input {input_val} that is not part "
                           "of the TensorRT region!")

    shape = kwargs["weight"].shape
    broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape
    gamma = to_numpy(kwargs["weight"].reshape(*shape))
    beta = to_numpy(kwargs["bias"].reshape(*shape))
    eps = kwargs["eps"]
    normalized_shape = kwargs["normalized_shape"]

    axes = 0
    for d in range(len(normalized_shape)):
        axes |= 1 << (len(input_val.shape) - d - 1)

    # E[x]
    mean_expected_layer = network.add_reduce(input_val, trt.ReduceOperation.AVG, axes, keep_dims=True)
    mean_expected_layer.name = f"{name}_mean_expected"
    # X-E[x]
    sub_trt = add_binary_elementwise_layer(
        network, input_val, mean_expected_layer.get_output(0), trt.ElementWiseOperation.SUB, f"{name}_sub"
    )
    # Variance = mean(pow(x_sub_mean,2))
    pow_tensor = network.add_constant(
        (1,) * len(input_val.shape), trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32))
    )
    pow_tensor.name = f"{name}_power"
    pow_var = add_binary_elementwise_layer(
        network, sub_trt, pow_tensor.get_output(0), trt.ElementWiseOperation.POW, f"{name}_pow_var"
    )
    mean_trt_layer = network.add_reduce(pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True)
    mean_trt_layer.name = f"{name}_mean"
    # Variance + eps
    eps_tensor = network.add_constant(
        (1,) * len(input_val.shape), trt.Weights(np.ascontiguousarray([eps], dtype=np.float32))
    )
    eps_tensor.name = f"{name}_eps"
    add_trt = add_binary_elementwise_layer(
        network, mean_trt_layer.get_output(0), eps_tensor.get_output(0), trt.ElementWiseOperation.SUM, f"{name}_add"
    )
    # SQRT((Var + eps))
    sqrt_trt = add_unary_layer(network, add_trt, trt.UnaryOperation.SQRT, f"{name}_sqrt")
    # (x - E[x]) / sqrt((var + eps))
    div_trt = add_binary_elementwise_layer(network, sub_trt, sqrt_trt, trt.ElementWiseOperation.DIV, f"{name}_div_trt")

    gamma_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(gamma)))
    gamma_tensor.name = f"{name}_gamma"
    beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta)))
    beta_tensor.name = f"{name}_beta"
    # y * gamma + beta
    scale_layer = add_binary_elementwise_layer(
        network, div_trt, gamma_tensor.get_output(0), trt.ElementWiseOperation.PROD, f"{name}_scale"
    )
    return add_binary_elementwise_layer(
        network, scale_layer, beta_tensor.get_output(0), trt.ElementWiseOperation.SUM, name
    )
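A small sketch (assumed sizes, not part of the converter) of the reduce-axes bitmask computed above: TensorRT's add_reduce takes one bit per axis, and LayerNorm reduces over the trailing len(normalized_shape) dimensions of the input.

input_rank = 4               # e.g. an (N, C, H, W) input; illustrative value
normalized_shape = (16, 32)  # normalize over the last two dimensions; illustrative

axes = 0
for d in range(len(normalized_shape)):
    axes |= 1 << (input_rank - d - 1)

print(bin(axes))  # 0b1100: bits 2 and 3 set, i.e. the last two axes are reduced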
def FC(network, input, insize, outsize, model_dict, prefix):
    weight = trt.Weights(model_dict[prefix + ".weight"])
    bias = trt.Weights(model_dict[prefix + ".bias"])
    w = network.add_constant(shape=(outsize, insize), weights=weight).get_output(0)
    b = network.add_constant(shape=(1, outsize), weights=bias).get_output(0)
    fc = network.add_matrix_multiply(input, trt.MatrixOperation.NONE, w,
                                     trt.MatrixOperation.TRANSPOSE).get_output(0)
    fb = network.add_elementwise(fc, b, trt.ElementWiseOperation.SUM).get_output(0)
    return fb
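As a quick numpy check (illustrative only), the FC helper above computes a standard linear layer: with MatrixOperation.TRANSPOSE applied to the (outsize, insize) weight, the result is y = x @ W.T + b.

import numpy as np

x = np.random.rand(5, 8).astype(np.float32)   # (batch, insize)
W = np.random.rand(3, 8).astype(np.float32)   # (outsize, insize)
b = np.random.rand(1, 3).astype(np.float32)

y = x @ W.T + b
print(y.shape)  # (5, 3)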
def bottleneck(network, weight_map, input, in_channels, out_channels, stride,
               layer_name):
    conv1 = network.add_convolution(input=input,
                                    num_output_maps=out_channels,
                                    kernel_shape=(1, 1),
                                    kernel=weight_map[layer_name + "conv1.weight"],
                                    bias=trt.Weights())
    assert conv1

    bn1 = addBatchNorm2d(network, weight_map, conv1.get_output(0),
                         layer_name + "bn1", EPS)
    assert bn1

    relu1 = network.add_activation(bn1.get_output(0),
                                   type=trt.ActivationType.RELU)
    assert relu1

    conv2 = network.add_convolution(input=relu1.get_output(0),
                                    num_output_maps=out_channels,
                                    kernel_shape=(3, 3),
                                    kernel=weight_map[layer_name + "conv2.weight"],
                                    bias=trt.Weights())
    assert conv2
    conv2.stride = (stride, stride)
    conv2.padding = (1, 1)

    bn2 = addBatchNorm2d(network, weight_map, conv2.get_output(0),
                         layer_name + "bn2", EPS)
    assert bn2

    relu2 = network.add_activation(bn2.get_output(0),
                                   type=trt.ActivationType.RELU)
    assert relu2

    conv3 = network.add_convolution(input=relu2.get_output(0),
                                    num_output_maps=out_channels * 4,
                                    kernel_shape=(1, 1),
                                    kernel=weight_map[layer_name + "conv3.weight"],
                                    bias=trt.Weights())
    assert conv3

    bn3 = addBatchNorm2d(network, weight_map, conv3.get_output(0),
                         layer_name + "bn3", EPS)

    if stride != 1 or in_channels != 4 * out_channels:
        conv4 = network.add_convolution(
            input=input,
            num_output_maps=out_channels * 4,
            kernel_shape=(1, 1),
            kernel=weight_map[layer_name + "downsample.0.weight"],
            bias=trt.Weights())
        assert conv4
        conv4.stride = (stride, stride)

        bn4 = addBatchNorm2d(network, weight_map, conv4.get_output(0),
                             layer_name + "downsample.1", EPS)
        assert bn4

        ew1 = network.add_elementwise(bn4.get_output(0), bn3.get_output(0),
                                      trt.ElementWiseOperation.SUM)
    else:
        ew1 = network.add_elementwise(input, bn3.get_output(0),
                                      trt.ElementWiseOperation.SUM)
    assert ew1

    relu3 = network.add_activation(ew1.get_output(0),
                                   type=trt.ActivationType.RELU)
    assert relu3

    return relu3
Example #6
def sequence_class_output(prefix,
                          init_dict,
                          network,
                          input_tensor,
                          softmax=True):
    logging.info(input_tensor.shape)
    seq_len = input_tensor.shape[1]
    hidden_size = input_tensor.shape[2]

    shuf = network.add_shuffle(input_tensor)
    shuf.first_transpose = (0, 3, 4, 1, 2)
    logging.info("seq class in: ", shuf.get_output(0).shape)

    in_shape_tensor = network.add_shape(shuf.get_output(0)).get_output(0)
    out_shape_tensor = network.add_gather(
        in_shape_tensor,
        network.add_constant(
            (5, ), trt.Weights(np.array([0, 1, 2, 2,
                                         4]).astype(np.int32))).get_output(0),
        0,
    ).get_output(0)

    first_token_tensor = network.add_slice(
        shuf.get_output(0),
        start=(0, 0, 0, 0, 0),
        shape=(-1, 1, 1, 1, hidden_size),
        stride=(1, 1, 1, 1, 1),
    )
    first_token_tensor.set_input(
        1,
        network.add_constant(
            (5, ), trt.Weights(np.array([0, 0, 0, 0,
                                         0]).astype(np.int32))).get_output(0),
    )
    first_token_tensor.set_input(2, out_shape_tensor)

    W_out = init_dict[prefix + "mlp.layer0." + SQD_W]
    B_out = init_dict[prefix + "mlp.layer0." + SQD_B]
    dense = network.add_fully_connected(first_token_tensor.get_output(0),
                                        W_out.shape[0], W_out, B_out)
    dense_relu = network.add_activation(dense.get_output(0),
                                        trt.ActivationType.RELU)
    W_out = init_dict[prefix + "mlp.layer2." + SQD_W]
    B_out = init_dict[prefix + "mlp.layer2." + SQD_B]
    classifier = network.add_fully_connected(dense_relu.get_output(0),
                                             W_out.shape[0], W_out, B_out)
    if softmax:
        probs = network.add_softmax(classifier.get_output(0))
        probs.axes = 4  # last dimension
        classifier = probs
    classifier = network.add_shuffle(classifier.get_output(0))
    classifier.reshape_dims = trt.Dims([0, -1])

    set_layer_name(classifier, prefix, "classifier")
    logging.info("seq class: ", classifier.get_output(0).shape)
    return classifier
Example #7
def conv_seq_2(network, weight_map, input, output, hdim, k, s, use_se, use_hs,
               w, lname):
    p = (k - 1) // 2
    conv1 = network.add_convolution(input=input,
                                    num_output_maps=hdim,
                                    kernel_shape=(1, 1),
                                    kernel=weight_map[lname + "0.weight"],
                                    bias=trt.Weights())
    bn1 = add_batch_norm_2d(network, weight_map, conv1.get_output(0),
                            lname + "1", EPS)

    if use_hs:
        hsw1 = add_h_swish(network, bn1.get_output(0))
        tensor3 = hsw1.get_output(0)
    else:
        relu1 = network.add_activation(bn1.get_output(0),
                                       type=trt.ActivationType.RELU)
        tensor3 = relu1.get_output(0)

    conv2 = network.add_convolution(input=tensor3,
                                    num_output_maps=hdim,
                                    kernel_shape=(k, k),
                                    kernel=weight_map[lname + "3.weight"],
                                    bias=trt.Weights())
    conv2.stride = (s, s)
    conv2.padding = (p, p)
    conv2.num_groups = hdim
    bn2 = add_batch_norm_2d(network, weight_map, conv2.get_output(0),
                            lname + "4", EPS)

    if use_se:
        se1 = add_se_layer(network, weight_map, bn2.get_output(0), hdim, w,
                           lname + "5.")
        tensor6 = se1.get_output(0)
    else:
        tensor6 = bn2.get_output(0)

    if use_hs:
        hsw2 = add_h_swish(network, tensor6)
        tensor7 = hsw2.get_output(0)
    else:
        relu2 = network.add_activation(tensor6, type=trt.ActivationType.RELU)
        tensor7 = relu2.get_output(0)

    conv3 = network.add_convolution(input=tensor7,
                                    num_output_maps=output,
                                    kernel_shape=(1, 1),
                                    kernel=weight_map[lname + "7.weight"],
                                    bias=trt.Weights())
    bn3 = add_batch_norm_2d(network, weight_map, conv3.get_output(0),
                            lname + "8", EPS)
    assert bn3

    return bn3
Example #8
def add_conv_relu(reader, network, input, outch, kernel, stride, lname):
    w = reader.get_tensor(lname + "weights").transpose(3, 2, 0, 1).reshape(-1)
    b = reader.get_tensor(lname + "biases")
    conv = network.add_convolution(input, outch, (kernel, kernel),
                                   trt.Weights(w), trt.Weights(b))
    conv.stride = (stride, stride)
    if kernel == 3:
        conv.padding = (1, 1)

    ac = network.add_activation(conv.get_output(0), trt.ActivationType.RELU)
    return ac
Example #9
def make_gelu_layer(prefix, config, network, input_tensor):
    POW = network.add_constant(
        (1, 1, 1, 1, 1),
        trt.Weights(np.ascontiguousarray([3.0], dtype=np.float32)))
    MULTIPLY = network.add_constant(
        (1, 1, 1, 1, 1),
        trt.Weights(np.ascontiguousarray([0.044715], dtype=np.float32)))
    SQRT = network.add_constant(
        (1, 1, 1, 1, 1),
        trt.Weights((np.ascontiguousarray([0.79788456080286535587989211986876],
                                          dtype=np.float32))))
    ONE = network.add_constant((1, 1, 1, 1, 1),
                               trt.Weights(
                                   (np.ascontiguousarray([1.0],
                                                         dtype=np.float32))))
    HALF = network.add_constant((1, 1, 1, 1, 1),
                                trt.Weights(
                                    (np.ascontiguousarray([0.5],
                                                          dtype=np.float32))))
    X_pow = network.add_elementwise(input_tensor, POW.get_output(0),
                                    trt.ElementWiseOperation.POW)
    X_pow_t = X_pow.get_output(0)
    X_mul = network.add_elementwise(X_pow_t, MULTIPLY.get_output(0),
                                    trt.ElementWiseOperation.PROD)
    X_add = network.add_elementwise(input_tensor, X_mul.get_output(0),
                                    trt.ElementWiseOperation.SUM)
    X_sqrt = network.add_elementwise(X_add.get_output(0), SQRT.get_output(0),
                                     trt.ElementWiseOperation.PROD)
    X_sqrt_tensor = X_sqrt.get_output(0)
    X_tanh = network.add_activation(X_sqrt_tensor, trt.ActivationType.TANH)
    X_tanh_tensor = X_tanh.get_output(0)
    X_one = network.add_elementwise(X_tanh_tensor, ONE.get_output(0),
                                    trt.ElementWiseOperation.SUM)
    CDF = network.add_elementwise(X_one.get_output(0), HALF.get_output(0),
                                  trt.ElementWiseOperation.PROD)
    gelu_layer = network.add_elementwise(CDF.get_output(0), input_tensor,
                                         trt.ElementWiseOperation.PROD)

    # enable elementwise fusing for int8 && fp16
    POW.precision = trt.DataType.FLOAT
    MULTIPLY.precision = trt.DataType.FLOAT
    SQRT.precision = trt.DataType.FLOAT
    ONE.precision = trt.DataType.FLOAT
    HALF.precision = trt.DataType.FLOAT
    X_pow.precision = trt.DataType.FLOAT
    X_mul.precision = trt.DataType.FLOAT
    X_add.precision = trt.DataType.FLOAT
    X_sqrt.precision = trt.DataType.FLOAT
    X_tanh.precision = trt.DataType.FLOAT
    X_one.precision = trt.DataType.FLOAT
    CDF.precision = trt.DataType.FLOAT
    gelu_layer.precision = trt.DataType.FLOAT
    return gelu_layer
Example #10
def add_gelu(network, input_tensor):
    """
    Adds elementwise GELU, and will trigger FC+GELU fusion in TRT
    """
    shape = (1, ) * len(input_tensor.shape)
    POW = network.add_constant(
        shape, trt.Weights(np.ascontiguousarray([3.0], dtype=np.float32)))
    MULTIPLY = network.add_constant(
        shape, trt.Weights(np.ascontiguousarray([0.044715], dtype=np.float32)))
    SQRT = network.add_constant(
        shape,
        trt.Weights((np.ascontiguousarray([0.79788456080286535587989211986876],
                                          dtype=np.float32))))
    ONE = network.add_constant(
        shape, trt.Weights((np.ascontiguousarray([1.0], dtype=np.float32))))
    HALF = network.add_constant(
        shape, trt.Weights((np.ascontiguousarray([0.5], dtype=np.float32))))
    X_pow = network.add_elementwise(input_tensor, POW.get_output(0),
                                    trt.ElementWiseOperation.POW)
    X_pow_t = X_pow.get_output(0)
    X_mul = network.add_elementwise(X_pow_t, MULTIPLY.get_output(0),
                                    trt.ElementWiseOperation.PROD)
    X_add = network.add_elementwise(input_tensor, X_mul.get_output(0),
                                    trt.ElementWiseOperation.SUM)
    X_sqrt = network.add_elementwise(X_add.get_output(0), SQRT.get_output(0),
                                     trt.ElementWiseOperation.PROD)
    X_sqrt_tensor = X_sqrt.get_output(0)
    X_tanh = network.add_activation(X_sqrt_tensor, trt.ActivationType.TANH)
    X_tanh_tensor = X_tanh.get_output(0)
    X_one = network.add_elementwise(X_tanh_tensor, ONE.get_output(0),
                                    trt.ElementWiseOperation.SUM)
    CDF = network.add_elementwise(X_one.get_output(0), HALF.get_output(0),
                                  trt.ElementWiseOperation.PROD)
    gelu_layer = network.add_elementwise(CDF.get_output(0), input_tensor,
                                         trt.ElementWiseOperation.PROD)

    # enable elementwise fusing for int8 && fp16
    POW.precision = trt.DataType.FLOAT
    MULTIPLY.precision = trt.DataType.FLOAT
    SQRT.precision = trt.DataType.FLOAT
    ONE.precision = trt.DataType.FLOAT
    HALF.precision = trt.DataType.FLOAT
    X_pow.precision = trt.DataType.FLOAT
    X_mul.precision = trt.DataType.FLOAT
    X_add.precision = trt.DataType.FLOAT
    X_sqrt.precision = trt.DataType.FLOAT
    X_tanh.precision = trt.DataType.FLOAT
    X_one.precision = trt.DataType.FLOAT
    CDF.precision = trt.DataType.FLOAT
    gelu_layer.precision = trt.DataType.FLOAT
    return gelu_layer
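A numpy sanity check (not part of the original code) that the constant ladder above implements the tanh approximation of GELU, gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))), which stays close to the exact erf-based definition:

import math
import numpy as np

x = np.linspace(-4, 4, 9, dtype=np.float32)
approx = 0.5 * x * (1.0 + np.tanh(0.7978845608028654 * (x + 0.044715 * x ** 3)))
exact = np.array([0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in x])
print(np.max(np.abs(approx - exact)))  # small, well under 1e-2 on this range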
Example #11
def get_onnx_weight_dict(tensor_dict, config):
    N = config.num_attention_heads
    H = config.head_size
    hidden_size = config.hidden_size

    weights_dict = dict()
    for outname, tensor in tensor_dict.items():
        if outname.find("_amax") != -1:
            weights_dict[outname] = tensor
        elif outname.find(BQ) != -1:
            prefix = outname[:outname.find(BQ)]

            Wqkv = np.zeros((3, hidden_size, hidden_size), np.float32)
            Bqkv = np.zeros((3, hidden_size), np.float32)

            Wqkv[0, :, :] = tensor_dict[prefix + WQ]
            Wqkv[1, :, :] = tensor_dict[prefix + WK]
            Wqkv[2, :, :] = tensor_dict[prefix + WV]
            Bqkv[0, :] = tensor
            Bqkv[1, :] = tensor_dict[prefix + BK]
            Bqkv[2, :] = tensor_dict[prefix + BV]

            if config.use_int8 and getattr(config, 'interleaved', False):
                Wqkv = np.ascontiguousarray(Wqkv.reshape((3, N, H, N, H)))
                Bqkv = np.ascontiguousarray(Bqkv.reshape((3, N, H)))
            else:
                Wqkv = np.ascontiguousarray(
                    Wqkv.reshape((3, N, H, N, H)).transpose((1, 0, 2, 3, 4)))
                Bqkv = np.ascontiguousarray(
                    Bqkv.reshape((3, N, H)).transpose((1, 0, 2)))

            weights_dict[prefix + WQKV] = trt.Weights(Wqkv)
            weights_dict[prefix + BQKV] = trt.Weights(Bqkv)
            weights_dict[prefix + WQKV + "_notrans"] = trt.Weights(Wqkv.T)

        elif outname.find(BK) != -1 or outname.find(BV) != -1 or outname.find(
                WQ) != -1 or outname.find(WK) != -1 or outname.find(WV) != -1:
            pass
        else:
            flat_tensor = np.ascontiguousarray(tensor).flatten()
            weights_dict[outname] = trt.Weights(flat_tensor)

            if outname.find("kernel") != -1:
                tensor = np.transpose(tensor)
                weights_dict[outname + "_notrans"] = trt.Weights(
                    np.ascontiguousarray(tensor).flatten())

    TRT_LOGGER.log(TRT_LOGGER.INFO,
                   "Found {:} entries in weight map".format(len(weights_dict)))
    return weights_dict
Example #12
def bottleneck(reader, network, input, ch, stride, lname, branch_type):

    w = reader.get_tensor(lname + "conv1/weights").transpose(3, 2, 0,
                                                             1).reshape(-1)
    b = np.zeros(ch, dtype=np.float32)
    conv1 = network.add_convolution(input, ch, (1, 1), trt.Weights(w),
                                    trt.Weights(b))

    bn1 = add_batchnorm(reader, network, conv1.get_output(0),
                        lname + "conv1/BatchNorm/", 1e-5)

    relu1 = network.add_activation(bn1.get_output(0), trt.ActivationType.RELU)

    w = reader.get_tensor(lname + "conv2/weights").transpose(3, 2, 0,
                                                             1).reshape(-1)
    b = np.zeros(ch, dtype=np.float32)
    conv2 = network.add_convolution(relu1.get_output(0), ch, (3, 3),
                                    trt.Weights(w), trt.Weights(b))
    conv2.stride = (stride, stride)
    conv2.padding = (1, 1)

    bn2 = add_batchnorm(reader, network, conv2.get_output(0),
                        lname + "conv2/BatchNorm/", 1e-5)

    relu2 = network.add_activation(bn2.get_output(0), trt.ActivationType.RELU)

    w = reader.get_tensor(lname + "conv3/weights").transpose(3, 2, 0,
                                                             1).reshape(-1)
    b = np.zeros(ch * 4, dtype=np.float32)
    conv3 = network.add_convolution(relu2.get_output(0), ch * 4, (1, 1),
                                    trt.Weights(w), trt.Weights(b))

    bn3 = add_batchnorm(reader, network, conv3.get_output(0),
                        lname + "conv3/BatchNorm/", 1e-5)

    # branch_type: 0 = identity shortcut, 1 = conv + bn shortcut, 2 = maxpool shortcut
    if branch_type == 0:
        ew1 = network.add_elementwise(input, bn3.get_output(0),
                                      trt.ElementWiseOperation.SUM)
    elif branch_type == 1:
        w = reader.get_tensor(lname + "shortcut/weights").transpose(
            3, 2, 0, 1).reshape(-1)
        b = np.zeros(ch * 4, dtype=np.float32)
        conv4 = network.add_convolution(input, ch * 4, (1, 1), trt.Weights(w),
                                        trt.Weights(b))
        conv4.stride = (stride, stride)
        bn4 = add_batchnorm(reader, network, conv4.get_output(0),
                            lname + "shortcut/BatchNorm/", 1e-5)
        ew1 = network.add_elementwise(bn4.get_output(0), bn3.get_output(0),
                                      trt.ElementWiseOperation.SUM)
    else:
        pool = network.add_pooling(input, trt.PoolingType.MAX, (1, 1))
        pool.stride = (2, 2)
        ew1 = network.add_elementwise(pool.get_output(0), bn3.get_output(0),
                                      trt.ElementWiseOperation.SUM)

    relu3 = network.add_activation(ew1.get_output(0), trt.ActivationType.RELU)

    return relu3
Example #13
def acc_ops_quantize_per_tensor(network, target, args, kwargs, name):
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(f"{name} received input {input_val} that is not part "
                           "of the TensorRT region!")

    q_scale = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "q_scale")
    q_zero_point = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "q_zero_point")
    dtype = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "dtype")
    if dtype not in (torch.quint8, torch.qint8, torch.qint32):
        raise RuntimeError("Only support (torch.quint8, torch.qint8, torch.qint32) "
                           f"quantized type in quantize_per_tensor, get {dtype}.")

    if q_zero_point != 0:
        raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")

    # Temporarily set q_scale to 1 so that the quantize and dequantize layers
    # end up with different scales, which avoids the error.
    # TODO: follow up with nvidia TensorRT team to repro and fix the problem
    q_scale = 1
    scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)))
    scale_layer.name = input_val.name + ".quant.scale"
    scale = scale_layer.get_output(0)
    assert trt.__version__ > "8.0", "Explicit quantize op is only supported in "
    "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
    layer = network.add_quantize(input=input_val, scale=scale)
    layer.axis = 0
    layer.name = input_val.name + ".quant"
    return layer.get_output(0)
Example #14
def acc_ops_dequantize(network, target, args, kwargs, name):
    """
    Currently just a no-op.
    """
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(f"{name} received input {input_val} that is not part "
                           "of the TensorRT region!")

    q_scale = acc_utils.get_field_from_acc_out_ty(kwargs["input_tensor_meta"], "q_scale")
    q_zero_point = acc_utils.get_field_from_acc_out_ty(kwargs["input_tensor_meta"], "q_zero_point")
    dtype = acc_utils.get_field_from_acc_out_ty(kwargs["input_tensor_meta"], "dtype")

    if dtype not in (torch.quint8, torch.qint8, torch.qint32):
        raise RuntimeError("Only support (torch.quint8, torch.qint8, torch.qint32) "
                           f"quantized type in dequantize, get {dtype}.")

    if q_zero_point != 0:
        raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}")

    scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([q_scale], dtype=np.float32)))
    scale_layer.name = input_val.name + ".dequant.scale"
    scale = scale_layer.get_output(0)
    assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in "
    "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__
    layer = network.add_dequantize(input=input_val, scale=scale)
    layer.name = input_val.name + ".dequant"
    layer.axis = 0
    return layer.get_output(0)
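The quantize/dequantize converters above only accept symmetric per-tensor quantization with zero_point == 0. A small numpy sketch (illustrative values, not TensorRT code) of the math the two layers perform:

import numpy as np

scale = 0.05
x = np.array([-3.2, -0.01, 0.0, 0.7, 6.3], dtype=np.float32)

q = np.clip(np.round(x / scale), -128, 127).astype(np.int8)  # roughly what add_quantize does
x_hat = q.astype(np.float32) * scale                          # roughly what add_dequantize does
print(q)
print(x_hat)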
Example #15
def convert_Linear(ctx):
    module = ctx.method_args[0]
    input = ctx.method_args[1]
    input_trt = trt_(ctx.network, input)
    output = ctx.method_return

    # reshape to ...xNx1x1
    layer = ctx.network.add_shuffle(input_trt)
    layer.reshape_dims = (0, ) * len(input_trt.shape) + (1, 1)

    # add fully connected
    bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype))
    if module.bias is not None:
        bias = module.bias.detach().cpu().numpy()

    layer = ctx.network.add_convolution(
        input=layer.get_output(0),
        num_output_maps=module.out_features,
        kernel_shape=(1, 1),
        kernel=module.weight.detach().cpu().numpy(),
        bias=bias)

    # reshape back to N
    layer = ctx.network.add_shuffle(layer.get_output(0))
    # layer.reshape_dims = tuple(output.shape[1:])
    layer.reshape_dims = (0, ) * len(input_trt.shape)

    output._trt = layer.get_output(0)
Example #16
def aten_repeat(inputs, attributes, scope):
    inp, params = inputs
    ctx = current_context()
    net = ctx.network
    if ctx.is_tensorrt and has_trt_tensor(inputs):
        assert params[0] == 1
        assert len(params) > 1
        assert len(params) == len(inp.shape) + 1
        # implement repeat with a series of gather operations (slower than a native repeat)
        i = 0
        for p, s in zip(params[1:], inp.shape):
            if p > 1:
                repeat_weights = np.tile(np.arange(0, s), [p]).astype(np.int32)
                layer = net.add_constant([1, s * p],
                                         trt.Weights(repeat_weights))
                layer.name = scope + "/" + "constant_{}".format(i)
                gather_inds = layer.get_output(0)
                gather_inds.name = scope + "/" + "constant_{}".format(i)
                layer = net.add_gather(inp, gather_inds, i)
                layer.name = scope + "/" + "gather_{}".format(i)
                out = layer.get_output(0)
                out.name = scope + "/" + "gather_{}".format(i)
            i += 1
        return [out]
    elif ctx.is_tvm and has_tvm_tensor(inputs):
        raise NotImplementedError

    return [inp.repeat(*params)]
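A numpy sketch (values are made up) of the gather trick used above: repeating dimension i by a factor p is the same as gathering along that dimension with the index vector tile(arange(s), p).

import numpy as np

x = np.arange(6).reshape(2, 3)                 # (2, 3)
p, axis = 2, 0                                 # repeat the first dimension twice
idx = np.tile(np.arange(x.shape[axis]), p)     # [0, 1, 0, 1]
repeated = np.take(x, idx, axis=axis)
print(np.array_equal(repeated, np.tile(x, (p, 1))))  # True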
Example #17
def aten_matmul(inputs, attributes, scope):
    mat1, mat2 = inputs[:2]
    ctx = current_context()
    net = ctx.network
    if ctx.is_tensorrt and has_trt_tensor(inputs):
        assert isinstance(mat2, torch.Tensor)
        inp = mat1
        weight = mat2.t().detach().cpu().numpy()
        C = weight.shape[0]
        # use fc to implement this
        if len(inp.shape) < 3:
            inp = _trt_reshape(net, inp, [-1, 1, 1], scope + "/reshape")
        layer = net.add_fully_connected(inp, C, weight, trt.Weights())
        output = layer.get_output(0)
        output.name = scope
        layer.name = scope
        ctx.refit_weight_dict[layer.name] = {
            "type": "Linear",
            "weight": inputs[1].__torch2trt_weight_name,
        }
        return [output]
    elif ctx.is_tvm and has_tvm_tensor(inputs):
        inp = mat1
        weight = mat2.t().detach().cpu().numpy()
        C = weight.shape[0]
        weight_t = _expr.var(scope + "/weight",
                             shape=weight.shape,
                             dtype="float32")
        ctx.tvm_weight_dict[weight_t] = weight
        res = _op.nn.dense(inputs[0], weight_t, units=C)
        return [res]
    res = torch.matmul(mat1, mat2)
    return [res]
Example #18
def torch_nn_modules_linear_Linear(network, submod, args, kwargs, name):
    # args/kwargs should have already been normalized to kwargs
    assert len(args) == 0
    input_val = kwargs["input"]

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"Linear received input {input_val} that is not part "
            "of the TensorRT region!")

    layer = network.add_shuffle(input_val)
    layer.reshape_dims = tuple(input_val.shape) + (1, 1)
    layer.name = f"{name}_pre_shuffle"

    bias = trt.Weights(torch_dtype_to_trt(submod.weight.dtype))
    if submod.bias is not None:
        bias = to_numpy(submod.bias)

    # add fully connected
    layer = network.add_fully_connected(input=layer.get_output(0),
                                        num_outputs=submod.out_features,
                                        kernel=to_numpy(submod.weight),
                                        bias=bias)
    layer.name = f"{name}_linear"

    # reshape back
    layer = network.add_shuffle(layer.get_output(0))
    layer.reshape_dims = tuple(input_val.shape[:-1]) + (submod.out_features, )
    layer.name = f"{name}_post_shuffle"
    return layer.get_output(0)
Example #19
def load_megatron_pickle_weights(path, config):
    N = config.num_attention_heads
    H = config.head_size

    with open(path, 'rb') as f:
        tensor_dict = pickle.load(f)

    weight_dict = {}
    for name, tensor in tensor_dict.items():
        if 'scale' in name:
            continue

        name = (onnx_to_trt_name(name).replace(
            'embedding_',
            'embeddings_').replace('tokentype_', 'token_type_').replace(
                '_av', '_self_av').replace('_qv', '_self_qv').replace(
                    'query_key_value', 'self_qkv'))

        if name.endswith('self_qkv_kernel'):
            tensor = np.ascontiguousarray(tensor.reshape(
                (3, N, H, N, H))).astype(np.float32)
            weight_dict[name] = trt.Weights(tensor)
        elif name.endswith('self_qkv_bias'):
            tensor = np.ascontiguousarray(tensor.reshape(
                (3, N, H))).astype(np.float32)
            weight_dict[name] = trt.Weights(tensor)
        elif name == 'l{}_output_layernorm_output_quantizer_amax'.format(
                config.num_hidden_layers - 1):
            weight_dict['bert_encoder_final_input_quantizer_amax'] = tensor
        elif name.endswith('_amax'):
            weight_dict[name] = tensor
            if name.endswith('_qkv_input_amax'):
                weight_dict[name.replace('_qkv_input_amax',
                                         '_query_input_amax')] = tensor
                weight_dict[name.replace('_qkv_input_amax',
                                         '_key_input_amax')] = tensor
                weight_dict[name.replace('_qkv_input_amax',
                                         '_value_input_amax')] = tensor
        else:
            flat_tensor = np.ascontiguousarray(tensor).flatten().astype(
                np.float32)
            weight_dict[name] = trt.Weights(flat_tensor)

    TRT_LOGGER.log(TRT_LOGGER.INFO,
                   "Found {:} entries in weight map".format(len(weight_dict)))
    return weight_dict
Example #20
def add_batchnorm(reader, network, input, lname, eps):
    gamma = reader.get_tensor(lname + "gamma")
    beta = reader.get_tensor(lname + "beta")
    mean = reader.get_tensor(lname + "moving_mean")
    var = reader.get_tensor(lname + "moving_variance")

    scale = gamma / np.sqrt(var + eps)
    shift = -mean / np.sqrt(var + eps) * gamma + beta
    power = np.ones(len(gamma), dtype=np.float32)

    bn = network.add_scale(
        input,
        trt.ScaleMode.CHANNEL,
        trt.Weights(shift),
        trt.Weights(scale),
        trt.Weights(power),
    )
    return bn
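A small numpy check (illustrative values) of the batch-norm folding above: gamma * (x - mean) / sqrt(var + eps) + beta equals x * scale + shift with the scale and shift computed as in add_batchnorm.

import numpy as np

eps = 1e-5
gamma = np.array([1.5, 0.8], dtype=np.float32)
beta = np.array([0.1, -0.2], dtype=np.float32)
mean = np.array([0.3, 1.2], dtype=np.float32)
var = np.array([0.9, 0.5], dtype=np.float32)

scale = gamma / np.sqrt(var + eps)
shift = -mean / np.sqrt(var + eps) * gamma + beta

x = np.random.rand(4, 2).astype(np.float32)    # (batch, channels)
ref = gamma * (x - mean) / np.sqrt(var + eps) + beta
print(np.allclose(x * scale + shift, ref))     # True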
Example #21
def _scale_or_elementwise(net, lfs, rfs, op, name):
    """pytorch elementwise may contains constants.
    if contains constant, use add_scale, otherwise use add_elementwise
    """
    trt_op = {
        "add": trt.ElementWiseOperation.SUM,
        "sub": trt.ElementWiseOperation.SUB,
        "mul": trt.ElementWiseOperation.PROD,
        "div": trt.ElementWiseOperation.DIV,
    }
    assert op in trt_op
    assert not all([isinstance(t, torch.Tensor) for t in [lfs, rfs]])
    if all([isinstance(t, trt.ITensor) for t in [lfs, rfs]]):
        layer = net.add_elementwise(lfs, rfs, trt_op[op])
        layer.name = name
        output = layer.get_output(0)
        return output
    if isinstance(rfs, torch.Tensor):
        val = rfs.detach().cpu().numpy()
        main = lfs
        scale = val
        if val.size == 1:
            # use scale implementation
            if op == "add":
                shift = trt.Weights(np.array(scale, dtype=np.float32))
                scale = trt.Weights(np.array(1, dtype=np.float32))
            elif op == "sub":
                shift = trt.Weights(np.array(-scale, dtype=np.float32))
                scale = trt.Weights(np.array(1, dtype=np.float32))
            elif op == "mul":
                shift = trt.Weights(np.array(0, dtype=np.float32))
                scale = trt.Weights(np.array(scale, dtype=np.float32))
            elif op == "div":
                shift = trt.Weights(np.array(0, dtype=np.float32))
                scale = trt.Weights(np.array(1 / scale, dtype=np.float32))
            else:
                raise NotImplementedError
            power = trt.Weights(np.array(1, dtype=np.float32))
            layer = net.add_scale(main, trt.ScaleMode.UNIFORM, shift, scale,
                                  power)
        else:
            lfs, rfs = try_convert_to_constant(net, [lfs, rfs])
            layer = net.add_elementwise(lfs, rfs, trt_op[op])
    else:
        lfs, rfs = try_convert_to_constant(net, [lfs, rfs])
        layer = net.add_elementwise(lfs, rfs, trt_op[op])
    layer.name = name
    output = layer.get_output(0)
    return output
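TensorRT's scale layer computes (x * scale + shift) ** power, which is why the scalar-constant branches above fold the arithmetic into shift/scale: a division by a constant c, for example, becomes scale = 1/c with shift = 0 and power = 1. A quick numpy check with made-up values:

import numpy as np

c = 4.0
x = np.array([1.0, -2.0, 8.0], dtype=np.float32)
print(np.allclose((x * (1.0 / c) + 0.0) ** 1.0, x / c))  # True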
Example #22
def convert_ConvTranspose1d(ctx):
    module = ctx.method_args[0]
    input = ctx.method_args[1]
    input_trt = trt_(ctx.network, input)
    output = ctx.method_return

    kernel_size = module.kernel_size
    if not isinstance(kernel_size, tuple):
        kernel_size = (kernel_size, 1)
    else:
        kernel_size = kernel_size + (1, )

    stride = module.stride
    if not isinstance(stride, tuple):
        stride = (stride, 1)
    else:
        stride = stride + (1, )

    padding = module.padding
    if not isinstance(padding, tuple):
        padding = (padding, 0)
    else:
        padding = padding + (0, )

    kernel = module.weight.detach().cpu().numpy()[..., None]

    bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype))
    if module.bias is not None:
        bias = module.bias.detach().cpu().numpy()[..., None]

    # unsqueeze(3)
    layer = ctx.network.add_shuffle(input_trt)
    layer.reshape_dims = (0, 0, 0, 1)
    input_trt = layer.get_output(0)

    # deconv
    layer = ctx.network.add_deconvolution(input=input_trt,
                                          num_output_maps=module.out_channels,
                                          kernel_shape=kernel_size,
                                          kernel=kernel,
                                          bias=bias)
    layer.stride = stride
    layer.padding = padding

    if module.groups is not None:
        layer.num_groups = module.groups

    output_trt = layer.get_output(0)

    # squeeze(3)
    layer = ctx.network.add_shuffle(output_trt)
    layer.reshape_dims = (0, 0, 0)
    output_trt = layer.get_output(0)

    output._trt = output_trt
Example #23
def inverted_res(network, weight_map, input, lname, inch, outch, s, exp):
    hidden = inch * exp
    use_res_connect = (s == 1 and inch == outch)

    if exp != 1:
        ew1 = conv_bn_relu(network, weight_map, input, hidden, 1, 1, 1,
                           lname + "conv.0.")
        ew2 = conv_bn_relu(network, weight_map, ew1.get_output(0), hidden, 3,
                           s, hidden, lname + "conv.1.")
        conv1 = network.add_convolution(input=ew2.get_output(0),
                                        num_output_maps=outch,
                                        kernel_shape=(1, 1),
                                        kernel=weight_map[lname +
                                                          "conv.2.weight"],
                                        bias=trt.Weights())
        assert conv1
        bn1 = add_batch_norm_2d(network, weight_map, conv1.get_output(0),
                                lname + "conv.3", EPS)
    else:
        ew1 = conv_bn_relu(network, weight_map, input, hidden, 3, s, hidden,
                           lname + "conv.0.")
        conv1 = network.add_convolution(input=ew1.get_output(0),
                                        num_output_maps=outch,
                                        kernel_shape=(1, 1),
                                        kernel=weight_map[lname +
                                                          "conv.1.weight"],
                                        bias=trt.Weights())
        assert conv1
        bn1 = add_batch_norm_2d(network, weight_map, conv1.get_output(0),
                                lname + "conv.2", EPS)

    if not use_res_connect:
        return bn1

    ew3 = network.add_elementwise(input, bn1.get_output(0),
                                  trt.ElementWiseOperation.SUM)
    assert ew3

    return ew3
Example #24
def convert_Conv1d(ctx):

    module = ctx.method_args[0]
    input = ctx.method_args[1]
    input_trt = trt_(ctx.network, input)
    output = ctx.method_return

    kernel_size = (module.kernel_size[0], 1)
    stride = (module.stride[0], 1)
    padding = (module.padding[0], 0)
    dilation = (module.dilation[0], 1)

    kernel = module.weight.detach().cpu().numpy()[..., None]

    bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype))
    if module.bias is not None:
        bias = module.bias.detach().cpu().numpy()

    # reshape to 2D
    input_shape_trt = ctx.network.add_shape(input_trt).get_output(0)
    one_trt = trt_(ctx.network,
                   torch.tensor([1], dtype=torch.int32).to(input.device))
    new_input_shape_trt = ctx.network.add_concatenation(
        [input_shape_trt, one_trt]).get_output(0)
    layer = ctx.network.add_shuffle(input_trt)
    layer.set_input(1, new_input_shape_trt)

    layer = ctx.network.add_convolution(input=layer.get_output(0),
                                        num_output_maps=module.out_channels,
                                        kernel_shape=kernel_size,
                                        kernel=kernel,
                                        bias=bias)
    layer.stride = stride
    layer.padding = padding
    layer.dilation = dilation

    if module.groups is not None:
        layer.num_groups = module.groups

    # reshape back to 1D
    conv_out_trt = layer.get_output(0)
    out_shape_trt = ctx.network.add_shape(conv_out_trt).get_output(0)
    new_out_shape_trt = ctx.network.add_slice(out_shape_trt, [0], [3],
                                              [1]).get_output(0)
    layer = ctx.network.add_shuffle(conv_out_trt)
    layer.set_input(1, new_out_shape_trt)

    output._trt = layer.get_output(0)
Example #25
def try_convert_to_constant(net, inputs):
    res = []
    ref_shape = None
    for inp in inputs:
        if isinstance(inp, trt.ITensor):
            ref_shape = inp.shape
    for inp in inputs:
        if isinstance(inp, torch.Tensor):
            inp = inp.detach().cpu().numpy()
            if inp.dtype == np.float64:
                inp = inp.astype(np.float32)
            if len(inp.shape) == 0:
                inp = inp.reshape(*([1] * len(ref_shape)))
            layer = net.add_constant(inp.shape, trt.Weights(inp))
            inp = layer.get_output(0)
        res.append(inp)
    return res
Example #26
def conv_bn_h_swish(network, weight_map, input, outch, ksize, s, g, lname):
    p = (ksize - 1) // 2
    conv1 = network.add_convolution(input=input,
                                    num_output_maps=outch,
                                    kernel_shape=(ksize, ksize),
                                    kernel=weight_map[lname + "0.weight"],
                                    bias=trt.Weights())
    assert conv1
    conv1.stride = (s, s)
    conv1.padding = (p, p)
    conv1.num_groups = g

    bn1 = add_batch_norm_2d(network, weight_map, conv1.get_output(0),
                            lname + "1", EPS)
    hsw = add_h_swish(network, bn1.get_output(0))
    assert hsw

    return hsw
Example #27
def aten_convolution(inputs, attributes, scope):
    inp, weight, bias = inputs[:3]
    stride, pad, dilation = inputs[3:6]
    transposed, output_padding, groups = inputs[6:9]
    net = current_network()
    if net is not None and has_trt_tensor(inputs):
        assert all(e == 0 for e in output_padding), \
            "TensorRT does not support output padding"
        if transposed:
            I, O_groups, *ksize = weight.shape
            O = O_groups * groups
        else:
            O, I_groups, *ksize = weight.shape
            I = I_groups * groups
        ndim = len(ksize)
        assert ndim == 2, "tensorrt only support 2d conv"
        # trt weight format: GKCRS: [num_groups, O_groups, I, H, W]
        weight = weight.detach().cpu().numpy()
        if bias is not None:
            bias = bias.detach().cpu().numpy()
        else:
            bias = trt.Weights()
        if transposed:
            layer = net.add_deconvolution(inputs[0], O, tuple(ksize), weight,
                                          bias)
        else:
            layer = net.add_convolution(inputs[0], O, tuple(ksize), weight,
                                        bias)
            layer.dilation = tuple(dilation)
        layer.stride = tuple(stride)
        layer.padding = tuple(pad)
        layer.num_groups = groups
        output = layer.get_output(0)
        output.name = scope
        layer.name = scope
        return [output]
    ndim = len(inputs[3])
    assert ndim == 2
    if transposed:
        res = F.conv_transpose2d(inp, weight, bias, stride, pad,
                                 output_padding, groups, dilation)
    else:
        res = F.conv2d(inp, weight, bias, stride, pad, dilation, groups)
    return [res]
    def add_conv(self, network, inp, padding=None, stride=None, dilation=None):
        # Kernel should never be missing.
        kernel = self.pop_weights("weight")
        # Kernel is always NCHW
        kernel_N = kernel.shape[0]
        kernel_HW = kernel.shape[2:4]
        # Bias can be missing.
        bias = self.pop_weights("bias")
        bias = bias if bias is not None else trt.Weights()
        conv = network.add_convolution(inp,
                                       num_output_maps=kernel_N,
                                       kernel_shape=kernel_HW,
                                       kernel=kernel,
                                       bias=bias)
        conv.stride = stride or conv.stride
        conv.padding = padding or conv.padding
        conv.dilation = dilation or conv.dilation

        return conv.get_output(0)
Example #29
def convert_Conv2d(ctx):
    module = ctx.method_args[0]
    input = ctx.method_args[1]
    input_trt = trt_(ctx.network, input)
    output = ctx.method_return

    kernel_size = module.kernel_size
    if not isinstance(kernel_size, tuple):
        kernel_size = (kernel_size, ) * 2

    stride = module.stride
    if not isinstance(stride, tuple):
        stride = (stride, ) * 2

    padding = module.padding
    if not isinstance(padding, tuple):
        padding = (padding, ) * 2

    dilation = module.dilation
    if not isinstance(dilation, tuple):
        dilation = (dilation, ) * 2

    kernel = module.weight.detach().cpu().numpy()

    bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype))
    if module.bias is not None:
        bias = module.bias.detach().cpu().numpy()

    layer = ctx.network.add_convolution(input=input_trt,
                                        num_output_maps=module.out_channels,
                                        kernel_shape=kernel_size,
                                        kernel=kernel,
                                        bias=bias)
    layer.stride = stride
    layer.padding = padding
    layer.dilation = dilation

    if module.groups is not None:
        layer.num_groups = module.groups

    output._trt = layer.get_output(0)
Example #30
def conv_bn_relu(network, weight_map, input, outch, ksize, s, g, lname):
    p = (ksize - 1) // 2

    conv1 = network.add_convolution(input=input,
                                    num_output_maps=outch,
                                    kernel_shape=(ksize, ksize),
                                    kernel=weight_map[lname + "0.weight"],
                                    bias=trt.Weights())
    assert conv1
    conv1.stride = (s, s)
    conv1.padding = (p, p)
    conv1.num_groups = g

    bn1 = add_batch_norm_2d(network, weight_map, conv1.get_output(0),
                            lname + "1", EPS)
    assert bn1

    relu1 = network.add_activation(bn1.get_output(0),
                                   type=trt.ActivationType.RELU)
    assert relu1

    shift = np.array(-6.0, dtype=np.float32)
    scale = np.array(1.0, dtype=np.float32)
    power = np.array(1.0, dtype=np.float32)
    scale1 = network.add_scale(input=bn1.get_output(0),
                               mode=trt.ScaleMode.UNIFORM,
                               shift=shift,
                               scale=scale,
                               power=power)
    assert scale1

    relu2 = network.add_activation(scale1.get_output(0),
                                   type=trt.ActivationType.RELU)
    assert relu2

    ew1 = network.add_elementwise(relu1.get_output(0), relu2.get_output(0),
                                  trt.ElementWiseOperation.SUB)
    assert ew1

    return ew1
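Despite its name, the tail of conv_bn_relu above appears to build ReLU6 (the clipped activation used by MobileNet): relu1 is relu(bn), relu2 is relu(bn - 6), and their difference relu(x) - relu(x - 6) equals clip(x, 0, 6). A quick numpy check of that identity (values are illustrative):

import numpy as np

x = np.array([-3.0, 0.5, 5.9, 6.0, 11.0], dtype=np.float32)
relu6_via_sub = np.maximum(x, 0.0) - np.maximum(x - 6.0, 0.0)
print(np.array_equal(relu6_via_sub, np.clip(x, 0.0, 6.0)))  # True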