Example #1
def build(x):
    # mean_val, variance_val, gamma_val, and beta_val are constant arrays
    # assumed to be defined in the enclosing scope.
    return [
        # Minimal call: only the required mean and variance inputs.
        mb.batch_norm(x=x, mean=mean_val, variance=variance_val),
        # Full call: adds the scale (gamma), offset (beta), and an explicit epsilon.
        mb.batch_norm(
            x=x,
            mean=mean_val,
            variance=variance_val,
            gamma=gamma_val,
            beta=beta_val,
            epsilon=1e-4,
        ),
    ]
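For context, mb is the MIL builder from coremltools, and a build function like this is normally wrapped by the mb.program decorator. Below is a minimal runnable sketch, with illustrative shapes and constant values standing in for the names the snippet leaves undefined (mean_val, variance_val, gamma_val, beta_val):

import numpy as np
from coremltools.converters.mil import Builder as mb

# Illustrative per-channel constants for a 3-channel input.
mean_val = np.zeros((3,), dtype=np.float32)
variance_val = np.ones((3,), dtype=np.float32)
gamma_val = np.ones((3,), dtype=np.float32)
beta_val = np.zeros((3,), dtype=np.float32)

@mb.program(input_specs=[mb.TensorSpec(shape=(1, 3, 4, 4))])
def prog(x):
    # The fully-parameterized variant from build(x) above.
    return mb.batch_norm(x=x, mean=mean_val, variance=variance_val,
                         gamma=gamma_val, beta=beta_val, epsilon=1e-4)

print(prog)  # prints the resulting MIL program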
Example #2
    def _add_batch_norm(x, mean, variance, scale, offset, epsilon, name):
        # Emit a single batch_norm op when precomputed mean/variance are
        # available; otherwise build the normalization from elementary MIL ops.
        if mean.shape[0] != 0 and variance.shape[0] != 0:
            # In this case, we can use the mb.batch_norm directly
            x = mb.batch_norm(x=x,
                              mean=mean,
                              variance=variance,
                              gamma=scale,
                              beta=offset,
                              epsilon=epsilon,
                              name=name)
        else:
            # mean/variance are empty, so compute the batch norm manually from
            # the statistics of the current batch.
            axes = [axis for axis in range(x.rank) if axis != 1]
            mean = mb.reduce_mean(x=x, axes=axes, keep_dims=True)
            num = mb.sub(x=x, y=mean)
            square = mb.mul(x=num, y=num)
            variance = mb.reduce_mean(x=square, axes=axes, keep_dims=True)
            variance_add_epsilon = mb.add(x=variance, y=epsilon)
            sqrt = mb.sqrt(x=variance_add_epsilon)
            x = mb.real_div(x=num, y=sqrt)

            # Reshape scale/offset to (1, C, 1, ..., 1) so they broadcast along
            # the channel axis; any_symbolic (imported elsewhere) falls back to
            # -1 when the channel dimension is symbolic.
            shape = [1] * x.rank
            shape[1] = -1 if any_symbolic(scale.shape) else scale.shape[0]
            scale_reshape = mb.reshape(x=scale, shape=shape)
            offset_reshape = mb.reshape(x=offset, shape=shape)

            x = mb.mul(x=x, y=scale_reshape)
            x = mb.add(x=x, y=offset_reshape, name=name)

        return x
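The else branch above simply re-derives batch normalization from its definition, y = gamma * (x - mean(x)) / sqrt(var(x) + epsilon) + beta, with the statistics taken over every axis except the channel axis. A plain-numpy sketch of the same computation, assuming an NCHW tensor (shapes are illustrative):

import numpy as np

x = np.random.rand(2, 3, 5, 5).astype(np.float32)   # NCHW input
scale = np.random.rand(3).astype(np.float32)         # gamma, one value per channel
offset = np.random.rand(3).astype(np.float32)        # beta, one value per channel
epsilon = 1e-5

axes = tuple(axis for axis in range(x.ndim) if axis != 1)  # all axes except channel
mean = x.mean(axis=axes, keepdims=True)
variance = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
y = (x - mean) / np.sqrt(variance + epsilon)
y = y * scale.reshape(1, -1, 1, 1) + offset.reshape(1, -1, 1, 1)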
    def prog(x):
        # is_deconv, inputs (a mapping of constant arrays), and groups are
        # provided by the enclosing scope.
        if is_deconv:
            x = mb.conv_transpose(
                    x=x,
                    weight=inputs["conv_weight"],
                    bias=inputs["conv_bias"],
                    groups=groups,
                )
        else:
            x = mb.conv(
                    x=x,
                    weight=inputs["conv_weight"],
                    bias=inputs["conv_bias"],
                    groups=groups,
                )

        x = mb.batch_norm(
                x=x,
                mean=inputs["mean"],
                variance=inputs["variance"],
                gamma=inputs["gamma"],
                beta=inputs["beta"],
                epsilon=inputs["epsilon"],
            )
        return x
    def prog(x):
        # Cin, Cout, groups, rank, and has_bias are provided by the enclosing
        # scope; numpy is imported as np.
        # conv layer
        conv_weight = (np.random.rand(Cin, Cout // groups, 2) if rank == 3
                       else np.random.rand(Cin, Cout // groups, 2, 3))
        conv_bias = np.random.rand(Cout) if has_bias else None
        x = mb.conv_transpose(
                x=x,
                weight=conv_weight,
                bias=conv_bias,
                groups=groups,
            )

        # batch_norm layer
        gamma = np.random.rand(Cout)
        beta = np.random.rand(Cout)
        mean = np.random.rand(Cout)
        variance = np.random.rand(Cout)

        epsilon = 1e-5
        x = mb.batch_norm(
                x=x,
                mean=mean,
                variance=variance,
                gamma=gamma,
                beta=beta,
                epsilon=epsilon,
            )
        return x
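As with the earlier build function, this prog presumably sits under an mb.program decorator in its original context. The shape relationships worth noting: the conv_transpose weight here is laid out as (Cin, Cout // groups, kernel dims), and the batch_norm parameters are each length-Cout constants matching the convolution's output channels. A minimal sketch with concrete values for the free variables (Cin, Cout, groups, rank, and has_bias are assumptions here):

import numpy as np
from coremltools.converters.mil import Builder as mb

Cin, Cout, groups, rank, has_bias = 4, 6, 2, 3, True  # illustrative values

@mb.program(input_specs=[mb.TensorSpec(shape=(1, Cin, 10))])  # rank 3 -> (N, Cin, L)
def prog(x):
    conv_weight = (np.random.rand(Cin, Cout // groups, 2) if rank == 3
                   else np.random.rand(Cin, Cout // groups, 2, 3))
    conv_bias = np.random.rand(Cout) if has_bias else None
    x = mb.conv_transpose(x=x, weight=conv_weight, bias=conv_bias, groups=groups)
    x = mb.batch_norm(
        x=x,
        mean=np.random.rand(Cout),
        variance=np.random.rand(Cout),
        gamma=np.random.rand(Cout),
        beta=np.random.rand(Cout),
        epsilon=1e-5,
    )
    return x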
def try_to_transform(mul_op, add_op, block):
    # Fuse an elementwise mul (per-channel scale) followed by an add
    # (per-channel offset) into a single batch_norm op. _find_const_input_val
    # and _check_shape are helpers defined elsewhere in the same pass.
    non_const_input_mul = mul_op.x if mul_op.x.val is None else mul_op.y
    if non_const_input_mul.rank != 4:
        return False

    gamma = _find_const_input_val(mul_op)
    beta = _find_const_input_val(add_op)
    if gamma is None or beta is None:
        return False

    if not (isinstance(gamma, np.ndarray) and isinstance(beta, np.ndarray)):
        return False

    # Check that gamma and beta have shape (1, C, 1, 1) or (C, 1, 1), i.e. they
    # apply a per-channel (axis=-3) scale/offset, which is exactly what the
    # batch_norm layer does (the batch_norm layer here only works on rank-4
    # input tensors).
    if not (_check_shape(gamma) and _check_shape(beta)):
        return False

    C = gamma.shape[-3]
    if C == 1:
        return False

    out_name = add_op.outputs[0].name
    x = mb.batch_norm(
        x=non_const_input_mul,
        mean=np.zeros((C, ), np.float32),
        variance=np.ones((C, ), np.float32),
        gamma=np.squeeze(gamma),
        beta=np.squeeze(beta),
        name=out_name,
        before_op=mul_op,
    )

    add_op.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=add_op, old_var=add_op.outputs[0], new_var=x)
    # Remove all the ops at once
    block.remove_ops([mul_op, add_op])
    return True
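The replacement works because a batch_norm with mean=0 and variance=1 collapses to a per-channel affine map, gamma * (x - 0) / sqrt(1 + epsilon) + beta, which for a tiny epsilon is just gamma * x + beta, i.e. the original mul followed by add. A small numpy check of that equivalence (independent of coremltools, shapes illustrative):

import numpy as np

x = np.random.rand(1, 4, 8, 8).astype(np.float32)       # rank-4 input, C = 4
gamma = np.random.rand(4, 1, 1).astype(np.float32)      # per-channel scale, shape (C, 1, 1)
beta = np.random.rand(4, 1, 1).astype(np.float32)       # per-channel offset, shape (C, 1, 1)

mul_add = x * gamma + beta                               # what the original mul + add compute

g = np.squeeze(gamma)[None, :, None, None]               # (1, C, 1, 1)
b = np.squeeze(beta)[None, :, None, None]
epsilon = 0.0                                            # the fused op keeps a small default epsilon
batch_norm = g * (x - 0.0) / np.sqrt(1.0 + epsilon) + b  # mean = 0, variance = 1

assert np.allclose(mul_add, batch_norm)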