Example #1
import torch
from functorch.compile import clear_compile_cache  # assumed import path for clear_compile_cache


def bias_gelu_dropout(input, bias):
    a = torch.add(input, bias)
    b = torch.nn.functional.gelu(a)
    c = torch.nn.functional.dropout(b, p=0.6, training=True)
    return c


def aot_fn(input, bias):
    a = torch.add(input, bias)
    # tanh approximation of GELU, written out as explicit pointwise ops
    b = a * 0.5 * (1.0 + torch.tanh(0.79788456 * a * (1 + 0.044715 * a * a)))
    c = torch.nn.functional.dropout(b, p=0.6, training=True)
    return c


fn = bias_gelu_dropout

clear_compile_cache()

# Set inputs
device = "cuda"
dtype = torch.float16
batch_size = 32
seq_len = 196
intermediate_size = 4096
# batch_size = 2
# seq_len = 4
# intermediate_size = 3
input = torch.randn(
    batch_size,
    seq_len,
    intermediate_size,
    requires_grad=True,
    device=device,
    dtype=dtype,
)
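The excerpt above stops after building the input tensor. A minimal sketch of how the set-up typically continues is shown below; the bias tensor, the memory_efficient_fusion call from functorch.compile, and the forward/backward driver lines are assumptions for illustration and not part of the original example.

# Hypothetical continuation: build the remaining input, compile fn with
# AOT Autograd, and run one forward/backward pass through the compiled graph.
from functorch.compile import memory_efficient_fusion

bias = torch.randn(
    intermediate_size, requires_grad=True, device=device, dtype=dtype
)

compiled_fn = memory_efficient_fusion(fn)

out = compiled_fn(input, bias)  # forward through the AOT-compiled graph
out.sum().backward()            # backward was compiled ahead of time as well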
Example #2
    def setUp(self):
        super().setUp()
        # NB: We cache on function id, which is unreliable
        # Can fix by using weakrefs, but not sure if it matters
        clear_compile_cache()
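For context, a minimal sketch of the kind of test class this setUp would belong to follows; the class name, the no-op compiler, and the test body are hypothetical and only illustrate why the function-id-keyed compile cache is cleared before each test.

# Hypothetical surrounding test class; only setUp mirrors the original snippet.
import unittest

import torch
from functorch.compile import aot_function, clear_compile_cache


def noop_compiler(fx_module, example_inputs):
    # Trivial "compiler": hand back the traced FX graph unchanged.
    return fx_module


class TestCompileCache(unittest.TestCase):
    def setUp(self):
        super().setUp()
        # Reset the function-id-keyed cache so entries do not leak between tests.
        clear_compile_cache()

    def test_matches_eager(self):
        def fn(x, y):
            return (x + y).sin()

        compiled_fn = aot_function(fn, noop_compiler)
        x = torch.randn(4, requires_grad=True)
        y = torch.randn(4, requires_grad=True)
        torch.testing.assert_close(compiled_fn(x, y), fn(x, y))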