def gen_qkv(batch_size, seq_len, embed_size, attn_type):
    """Generate query, key, and value batches for the given attention type.

    'self' requires only q; 'encdec' needs q and k; and 'arb' needs
    q, k, and v. Unused tensors are returned as None.
    """
    q = profiling.generate_batch(batch_size, seq_len, embed_size)
    if attn_type == 'self':
        k, v = None, None
    elif attn_type == 'encdec':
        k = profiling.generate_batch(batch_size, seq_len, embed_size)
        v = None
    elif attn_type == 'arb':
        k = profiling.generate_batch(batch_size, seq_len, embed_size)
        v = profiling.generate_batch(batch_size, seq_len, embed_size)
    else:
        raise ValueError(f'Unknown attention type: {attn_type}')
    return q, k, v
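
As a quick illustration, here is a hedged sketch of driving gen_qkv for each
attention type. The sizes are placeholders, and it assumes
profiling.generate_batch returns a tensor in (sequence, batch, embedding)
order, as the docstrings below describe.

# Illustrative sizes only; not the benchmark defaults.
batch_size, seq_len, embed_size = 8, 128, 512
for attn_type in ('self', 'encdec', 'arb'):
    q, k, v = gen_qkv(batch_size, seq_len, embed_size, attn_type)
    # 'self' leaves k and v as None; 'encdec' leaves only v as None.
    print(attn_type, q.shape,
          None if k is None else k.shape,
          None if v is None else v.shape)
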
Example #2
def time_multihead_attention(q,
                             num_heads,
                             k=None,
                             v=None,
                             mask=False,
                             mode='self',
                             bias=True,
                             do_backprop=True,
                             fp='fp32',
                             use_apex=False,
                             num_iters=100,
                             num_warmups=5):
    """Benchmark multi-head attention.

    q, k, v are input values in (sequence, batch, embedding) order,
    passed based on the mode.

    num_heads is the number of heads in multi-head attention.

    mask is True if doing masked multi-head attention.

    mode is one of 'self', 'encdec', or 'arb'. 'self' requires only
    q; 'encdec' needs q and k; and 'arb' needs q, k, and v.

    bias is whether to use bias in attention.

    do_backprop is whether to benchmark backprop.

    fp is the precision to perform operations in.

    use_apex is whether to import and use Nvidia's Apex library.

    num_iters and num_warmups are the number of benchmarking and warmup
    iterations, respectively.

    Returns the runtimes for each iteration of each function.

    """
    if use_apex:
        from apex import amp
    embed_size = q.size(2)
    attn = torch.nn.MultiheadAttention(embed_size, num_heads,
                                       bias=bias).to(profiling.cuda_device)
    attn.train()
    q = q.to(profiling.cuda_device)
    if k is not None:
        k = k.to(profiling.cuda_device)
        mask_shape = (q.size(0), k.size(0))
    else:
        mask_shape = (q.size(0), q.size(0))
    if v is not None:
        v = v.to(profiling.cuda_device)
    dy = profiling.generate_batch(q.size(1), q.size(0),
                                  embed_size).to(profiling.cuda_device)
    if mask:
        mask = profiling.gen_attention_mask(*mask_shape).to(
            profiling.cuda_device)
    else:
        mask = None
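    # Run in half precision, either via Apex AMP or by casting the module
    # and tensors to half manually.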
    if fp == 'fp16':
        if use_apex:
            attn = amp.initialize(attn)
        else:
            q = q.half()
            if k is not None:
                k = k.half()
            if v is not None:
                v = v.half()
            if mask is not None:
                mask = mask.half()
            attn = attn.half()
            dy = dy.half()
    result, backward_result = None, None

    def forward():
        nonlocal result
        if mode == 'self':
            result = attn.forward(q, q, q, need_weights=False,
                                  attn_mask=mask)[0]
        elif mode == 'encdec':
            result = attn.forward(q, k, k, need_weights=False,
                                  attn_mask=mask)[0]
        elif mode == 'arb':
            result = attn.forward(q, k, v, need_weights=False,
                                  attn_mask=mask)[0]

    def backward():
        nonlocal backward_result
        backward_result = result.backward(dy)

    def clear():
        attn.zero_grad()

    return profiling.time_funcs([forward, backward, clear],
                                name='MHA ' + mode,
                                func_names=['forward', 'backward', 'clear'],
                                num_iters=num_iters,
                                warmups=num_warmups)
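
A minimal usage sketch, assuming gen_qkv from the first example and the
project's profiling module are available; the sizes and settings here are
placeholders, not the benchmark defaults.

# Time fp16 masked self-attention on illustrative sizes.
q, k, v = gen_qkv(8, 128, 512, 'self')
times = time_multihead_attention(q, 8, k=k, v=v,
                                 mask=True, mode='self', fp='fp16',
                                 num_iters=50, num_warmups=5)
# times holds per-iteration runtimes for 'forward', 'backward', and 'clear',
# in whatever structure profiling.time_funcs returns.
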
Example #3
        attn.zero_grad()

    return profiling.time_funcs([forward, backward, clear],
                                name='MHA ' + mode,
                                func_names=['forward', 'backward', 'clear'],
                                num_iters=num_iters,
                                warmups=num_warmups)


if __name__ == '__main__':
    args = parser.parse_args()
    # Check this here first.
    if args.plot_file and os.path.exists(args.plot_file):
        print(f'{args.plot_file} exists, aborting.')
        sys.exit(1)
    q = profiling.generate_batch(args.batch_size, args.max_seq_len,
                                 args.embed_size)
    if args.attn_type == 'self':
        k = None
        v = None
    elif args.attn_type == 'encdec':
        k = profiling.generate_batch(args.batch_size, args.max_enc_seq_len,
                                     args.embed_size)
        v = None
    elif args.attn_type == 'arb':
        k = profiling.generate_batch(args.batch_size, args.max_enc_seq_len,
                                     args.embed_size)
        v = profiling.generate_batch(args.batch_size, args.max_enc_seq_len,
                                     args.embed_size)
    times = time_multihead_attention(q,
                                     args.num_heads,
                                     k=k,
    'fwd_stdev': [],
    'bwd_time': [],
    'bwd_stdev': []
}

# Print table header.
print('Batch'.rjust(12) + '\t' + 'Seq Len'.rjust(12) + '\t' +
      'Embed'.rjust(12) + '\t' + 'Heads'.rjust(12) + '\t' + 'Fwd'.rjust(12) +
      '\t' + 'Stdev'.rjust(12) + '\t' + 'Bwd'.rjust(12) + '\t' +
      'Stdev'.rjust(12) + '\t')

# Profile all configurations.
for batch_size in batch_sizes:
    for seq_len in seq_lens:
        for embed_size in embed_sizes:
            x = profiling.generate_batch(batch_size, seq_len, embed_size)
            if args.layer_type == 'encdec':
                encoder_out = profiling.generate_batch(batch_size, seq_len,
                                                       embed_size)
            for num_heads in heads:
                times['batch'].append(batch_size)
                times['seq_len'].append(seq_len)
                times['embed'].append(embed_size)
                times['heads'].append(num_heads)
                try:
                    if args.layer_type == 'encoder':
                        t_times = time_encoder(x, num_heads, fp=args.fp)
                    elif args.layer_type == 'decoder':
                        t_times = time_decoder(x, num_heads, fp=args.fp)
                    elif args.layer_type == 'encdec':
                        t_times = time_encdec(x,
Example #5
def time_encoder(x,
                 num_heads,
                 activation='relu',
                 bias=True,
                 dropout=True,
                 do_backprop=True,
                 fp='fp32',
                 use_apex=False,
                 num_iters=100,
                 num_warmups=5):
    """Benchmark a transformer encoder layer.

    x is the input sequence in (sequence, batch, embedding) order.

    num_heads is the number of multi-head attention heads.

    activation is the activation function to use.

    bias is whether to use bias in attention.

    dropout is whether to use dropout in attention.

    do_backprop is whether to benchmark backprop.

    fp is the precision to perform operations in.

    use_apex is whether to import and use Nvidia's Apex library.

    num_iters and num_warmups are the number of benchmarking and warmup
    iterations, respectively.

    Returns the runtimes for each iteration of each function.

    """
    if use_apex:
        from apex import amp
    embed_size = x.size(2)
    encoder = torch.nn.TransformerEncoderLayer(embed_size,
                                               num_heads,
                                               dim_feedforward=4 * embed_size,
                                               activation=activation)
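    # Replace the self-attention module if bias or dropout differs from the
    # layer's defaults.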
    if not bias or not dropout:
        new_bias = bias
        new_dropout = 0.1 if dropout else 0.0
        encoder.self_attn = torch.nn.MultiheadAttention(embed_size,
                                                        num_heads,
                                                        dropout=new_dropout,
                                                        bias=new_bias)
    encoder = encoder.to(profiling.cuda_device)
    encoder.train()
    x = x.to(profiling.cuda_device).requires_grad_()
    dy = profiling.generate_batch(x.size(1), x.size(0),
                                  embed_size).to(profiling.cuda_device)
    if fp == 'fp16':
        if use_apex:
            encoder = amp.initialize(encoder)
        else:
            encoder = encoder.half()
            x = x.half()
            dy = dy.half()
    result, backward_result = None, None

    def forward():
        nonlocal result
        result = encoder.forward(x)

    def backward():
        nonlocal backward_result
        backward_result = result.backward(dy)

    def clear():
        encoder.zero_grad()

    return profiling.time_funcs([forward, backward, clear],
                                name='Encoder',
                                func_names=['forward', 'backward', 'clear'],
                                num_iters=num_iters,
                                warmups=num_warmups)
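
A hedged usage sketch for time_encoder. The arguments to
profiling.generate_batch are (batch, seq, embed), while the returned tensor
is assumed to be in (sequence, batch, embedding) order per the docstring;
the sizes are placeholders.

# Time a single encoder layer on illustrative sizes.
x = profiling.generate_batch(8, 128, 512)
times = time_encoder(x, 8, activation='relu', fp='fp32',
                     num_iters=100, num_warmups=5)
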
Example #6
def time_encdec(x,
                encoder_out,
                num_heads,
                activation='relu',
                bias=True,
                dropout=True,
                do_backprop=True,
                use_apex=False,
                fp='fp32',
                num_iters=100,
                num_warmups=5):
    """Benchmark a transformer decoder layer with encoder/decoder attention.

    x is the input sequence in (sequence, batch, embedding) order.

    encoder_out is the output from an encoder in (sequence, batch, embedding)
    order.

    num_heads is the number of multi-head attention heads.

    activation is the activation function to use.

    bias and dropout must be left at their defaults; other values are not
    supported for this layer.

    do_backprop is whether to benchmark backprop.

    fp is the precision to perform operations in.

    use_apex is whether to import and use Nvidia's Apex library.

    num_iters and num_warmups are the number of benchmarking and warmup
    iterations, respectively.

    Returns the runtimes for each iteration of each function.

    """
    if use_apex:
        from apex import amp
    if not bias or not dropout:
        raise ValueError('Disabling bias or dropout is not supported')
    embed_size = x.size(2)
    decoder = torch.nn.TransformerDecoderLayer(embed_size,
                                               num_heads,
                                               dim_feedforward=4 * embed_size,
                                               activation=activation).to(
                                                   profiling.cuda_device)
    decoder.train()
    x = x.to(profiling.cuda_device).requires_grad_()
    encoder_out = encoder_out.to(profiling.cuda_device)
    dy = profiling.generate_batch(x.size(1), x.size(0),
                                  embed_size).to(profiling.cuda_device)
    mask = profiling.gen_attention_mask(x.size(0),
                                        x.size(0)).to(profiling.cuda_device)
    if fp == 'fp16':
        if use_apex:
            decoder = amp.initialize(decoder)
        else:
            decoder = decoder.half()
            x = x.half()
            encoder_out = encoder_out.half()
            dy = dy.half()
            mask = mask.half()
    # Must compute a gradient for this.
    encoder_out = encoder_out.requires_grad_()
    result, backward_result = None, None

    def forward():
        nonlocal result
        result = decoder.forward(x, encoder_out, tgt_mask=mask)

    def backward():
        nonlocal backward_result
        backward_result = result.backward(dy)

    def clear():
        decoder.zero_grad()

    return profiling.time_funcs([forward, backward, clear],
                                name='Encdec',
                                func_names=['forward', 'backward', 'clear'],
                                num_iters=num_iters,
                                warmups=num_warmups)
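
And a hedged usage sketch for time_encdec, with separate (placeholder)
decoder and encoder sequence lengths, mirroring how the profiling script
builds encoder_out with its own encoder sequence length.

# Time a decoder layer with encoder/decoder attention on illustrative sizes.
x = profiling.generate_batch(8, 64, 512)              # decoder-side input
encoder_out = profiling.generate_batch(8, 128, 512)   # encoder-side memory
times = time_encdec(x, encoder_out, 8, fp='fp32',
                    num_iters=100, num_warmups=5)
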
Example #7
        decoder.zero_grad()

    return profiling.time_funcs([forward, backward, clear],
                                name='Encdec',
                                func_names=['forward', 'backward', 'clear'],
                                num_iters=num_iters,
                                warmups=num_warmups)


if __name__ == '__main__':
    args = parser.parse_args()
    # Check this here first.
    if args.plot_file and os.path.exists(args.plot_file):
        print(f'{args.plot_file} exists, aborting.')
        sys.exit(1)
    x = profiling.generate_batch(args.batch_size, args.max_seq_len,
                                 args.embed_size)
    if args.layer == 'encoder':
        times = time_encoder(x,
                             args.num_heads,
                             activation=args.activation,
                             bias=not args.no_attn_bias,
                             dropout=not args.no_attn_dropout,
                             do_backprop=not args.no_backprop,
                             fp=args.fp,
                             use_apex=args.apex,
                             num_iters=args.num_iters,
                             num_warmups=args.num_warmups)
    elif args.layer == 'decoder':
        times = time_decoder(x,
                             args.num_heads,
                             activation=args.activation,