def gen_qkv(batch_size, seq_len, embed_size, attn_type):
    """Generate query/key/value batches for the given attention type.

    attn_type is one of 'self' (only q), 'encdec' (q and k), or 'arb'
    (q, k, and v); entries that are not needed are returned as None.

    """
    q = profiling.generate_batch(batch_size, seq_len, embed_size)
    if attn_type == 'self':
        k, v = None, None
    elif attn_type == 'encdec':
        k = profiling.generate_batch(batch_size, seq_len, embed_size)
        v = None
    elif attn_type == 'arb':
        k = profiling.generate_batch(batch_size, seq_len, embed_size)
        v = profiling.generate_batch(batch_size, seq_len, embed_size)
    else:
        raise ValueError(f'Unknown attn_type: {attn_type}')
    return q, k, v
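# Example usage (an illustrative sketch, not part of the benchmark driver;
# the shapes are arbitrary and profiling.generate_batch is the same helper
# used above):
#
#   q, k, v = gen_qkv(8, 128, 512, 'self')    # k and v are None
#   q, k, v = gen_qkv(8, 128, 512, 'encdec')  # v is None
#   q, k, v = gen_qkv(8, 128, 512, 'arb')     # q, k, and v are all tensors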
def time_multihead_attention(q, num_heads, k=None, v=None, mask=False,
                             mode='self', bias=True, do_backprop=True,
                             fp='fp32', use_apex=False, num_iters=100,
                             num_warmups=5):
    """Benchmark multi-head attention.

    q, k, and v are input values in (sequence, batch, embedding) order,
    passed based on the mode.
    num_heads is the number of heads in multi-head attention.
    mask is True if doing masked multi-head attention.
    mode is one of 'self', 'encdec', or 'arb'. 'self' requires only q;
    'encdec' needs q and k; and 'arb' needs q, k, and v.
    bias is whether to use bias in attention.
    do_backprop is whether to benchmark backprop.
    fp is the precision to perform operations in.
    use_apex is whether to import and use Nvidia's Apex library.
    num_iters and num_warmups are the number of benchmarking and warmup
    iterations, respectively.

    Returns the runtimes for each iteration of each function.

    """
    if use_apex:
        from apex import amp
    embed_size = q.size(2)
    attn = torch.nn.MultiheadAttention(embed_size, num_heads,
                                       bias=bias).to(profiling.cuda_device)
    attn.train()
    q = q.to(profiling.cuda_device)
    if k is not None:
        k = k.to(profiling.cuda_device)
        mask_shape = (q.size(0), k.size(0))
    else:
        mask_shape = (q.size(0), q.size(0))
    if v is not None:
        v = v.to(profiling.cuda_device)
    dy = profiling.generate_batch(q.size(1), q.size(0),
                                  embed_size).to(profiling.cuda_device)
    if mask:
        mask = profiling.gen_attention_mask(*mask_shape).to(
            profiling.cuda_device)
    else:
        mask = None
    if fp == 'fp16':
        if use_apex:
            attn = amp.initialize(attn)
        else:
            q = q.half()
            if k is not None:
                k = k.half()
            if v is not None:
                v = v.half()
            if mask is not None:
                mask = mask.half()
            attn = attn.half()
            dy = dy.half()
    result, backward_result = None, None

    def forward():
        nonlocal result
        if mode == 'self':
            result = attn.forward(q, q, q, need_weights=False,
                                  attn_mask=mask)[0]
        elif mode == 'encdec':
            result = attn.forward(q, k, k, need_weights=False,
                                  attn_mask=mask)[0]
        elif mode == 'arb':
            result = attn.forward(q, k, v, need_weights=False,
                                  attn_mask=mask)[0]

    def backward():
        nonlocal backward_result
        backward_result = result.backward(dy)

    def clear():
        attn.zero_grad()

    return profiling.time_funcs(
        [forward, backward, clear], name='MHA ' + mode,
        func_names=['forward', 'backward', 'clear'],
        num_iters=num_iters, warmups=num_warmups)
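# Example usage (an illustrative sketch; the argument values are arbitrary and
# the driver below instead pulls them from command-line arguments):
#
#   q, k, _ = gen_qkv(8, 64, 512, 'encdec')
#   times = time_multihead_attention(q, num_heads=8, k=k, mask=True,
#                                    mode='encdec', fp='fp16',
#                                    num_iters=10, num_warmups=2)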
if __name__ == '__main__':
    args = parser.parse_args()
    # Check this here first.
    if args.plot_file and os.path.exists(args.plot_file):
        print(f'{args.plot_file} exists, aborting.')
        sys.exit(1)
    q = profiling.generate_batch(args.batch_size, args.max_seq_len,
                                 args.embed_size)
    if args.attn_type == 'self':
        k = None
        v = None
    elif args.attn_type == 'encdec':
        k = profiling.generate_batch(args.batch_size, args.max_enc_seq_len,
                                     args.embed_size)
        v = None
    elif args.attn_type == 'arb':
        k = profiling.generate_batch(args.batch_size, args.max_enc_seq_len,
                                     args.embed_size)
        v = profiling.generate_batch(args.batch_size, args.max_enc_seq_len,
                                     args.embed_size)
    times = time_multihead_attention(q, args.num_heads, k=k,
             'fwd_stdev': [],
             'bwd_time': [],
             'bwd_stdev': []}
    # Print table header.
    print('Batch'.rjust(12) + '\t' + 'Seq Len'.rjust(12) + '\t' +
          'Embed'.rjust(12) + '\t' + 'Heads'.rjust(12) + '\t' +
          'Fwd'.rjust(12) + '\t' + 'Stdev'.rjust(12) + '\t' +
          'Bwd'.rjust(12) + '\t' + 'Stdev'.rjust(12) + '\t')
    # Profile all configurations.
    for batch_size in batch_sizes:
        for seq_len in seq_lens:
            for embed_size in embed_sizes:
                x = profiling.generate_batch(batch_size, seq_len, embed_size)
                if args.layer_type == 'encdec':
                    encoder_out = profiling.generate_batch(batch_size, seq_len,
                                                           embed_size)
                for num_heads in heads:
                    times['batch'].append(batch_size)
                    times['seq_len'].append(seq_len)
                    times['embed'].append(embed_size)
                    times['heads'].append(num_heads)
                    try:
                        if args.layer_type == 'encoder':
                            t_times = time_encoder(x, num_heads, fp=args.fp)
                        elif args.layer_type == 'decoder':
                            t_times = time_decoder(x, num_heads, fp=args.fp)
                        elif args.layer_type == 'encdec':
                            t_times = time_encdec(x,
def time_encoder(x, num_heads, activation='relu', bias=True, dropout=True,
                 do_backprop=True, fp='fp32', use_apex=False, num_iters=100,
                 num_warmups=5):
    """Benchmark a transformer encoder layer.

    x is the input sequence in (sequence, batch, embedding) order.
    num_heads is the number of multi-head attention heads.
    activation is the activation function to use.
    bias is whether to use bias in attention.
    dropout is whether to use dropout in attention.
    do_backprop is whether to benchmark backprop.
    fp is the precision to perform operations in.
    use_apex is whether to import and use Nvidia's Apex library.
    num_iters and num_warmups are the number of benchmarking and warmup
    iterations, respectively.

    Returns the runtimes for each iteration of each function.

    """
    if use_apex:
        from apex import amp
    embed_size = x.size(2)
    encoder = torch.nn.TransformerEncoderLayer(embed_size, num_heads,
                                               dim_feedforward=4 * embed_size,
                                               activation=activation)
    if not bias or not dropout:
        # Replace the default self-attention module, since
        # TransformerEncoderLayer does not expose bias/dropout directly.
        new_bias = bias
        new_dropout = 0.1 if dropout else 0.0
        encoder.self_attn = torch.nn.MultiheadAttention(embed_size, num_heads,
                                                        dropout=new_dropout,
                                                        bias=new_bias)
    encoder = encoder.to(profiling.cuda_device)
    encoder.train()
    x = x.to(profiling.cuda_device).requires_grad_()
    dy = profiling.generate_batch(x.size(1), x.size(0),
                                  embed_size).to(profiling.cuda_device)
    if fp == 'fp16':
        if use_apex:
            encoder = amp.initialize(encoder)
        else:
            encoder = encoder.half()
            x = x.half()
            dy = dy.half()
    result, backward_result = None, None

    def forward():
        nonlocal result
        result = encoder.forward(x)

    def backward():
        nonlocal backward_result
        backward_result = result.backward(dy)

    def clear():
        encoder.zero_grad()

    return profiling.time_funcs(
        [forward, backward, clear], name='Encoder',
        func_names=['forward', 'backward', 'clear'],
        num_iters=num_iters, warmups=num_warmups)
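# Example usage (an illustrative sketch; shapes are arbitrary and
# profiling.generate_batch is the same helper the driver uses):
#
#   x = profiling.generate_batch(8, 64, 512)
#   times = time_encoder(x, num_heads=8, fp='fp16',
#                        num_iters=10, num_warmups=2)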
def time_encdec(x, encoder_out, num_heads, activation='relu', bias=True,
                dropout=True, do_backprop=True, use_apex=False, fp='fp32',
                num_iters=100, num_warmups=5):
    """Benchmark a transformer decoder layer with encoder/decoder attention.

    x is the input sequence in (sequence, batch, embedding) order.
    encoder_out is the output from an encoder in (sequence, batch,
    embedding) order.
    num_heads is the number of multi-head attention heads.
    activation is the activation function to use.
    do_backprop is whether to benchmark backprop.
    fp is the precision to perform operations in.
    use_apex is whether to import and use Nvidia's Apex library.
    num_iters and num_warmups are the number of benchmarking and warmup
    iterations, respectively.

    Returns the runtimes for each iteration of each function.

    """
    if use_apex:
        from apex import amp
    if not bias or not dropout:
        raise ValueError('Disabling attention bias or dropout is not supported')
    embed_size = x.size(2)
    decoder = torch.nn.TransformerDecoderLayer(embed_size, num_heads,
                                               dim_feedforward=4 * embed_size,
                                               activation=activation).to(
                                                   profiling.cuda_device)
    decoder.train()
    x = x.to(profiling.cuda_device).requires_grad_()
    encoder_out = encoder_out.to(profiling.cuda_device)
    dy = profiling.generate_batch(x.size(1), x.size(0),
                                  embed_size).to(profiling.cuda_device)
    mask = profiling.gen_attention_mask(x.size(0),
                                        x.size(0)).to(profiling.cuda_device)
    if fp == 'fp16':
        if use_apex:
            decoder = amp.initialize(decoder)
        else:
            decoder = decoder.half()
            x = x.half()
            encoder_out = encoder_out.half()
            dy = dy.half()
            mask = mask.half()
    # Must compute a gradient for this.
    encoder_out = encoder_out.requires_grad_()
    result, backward_result = None, None

    def forward():
        nonlocal result
        result = decoder.forward(x, encoder_out, tgt_mask=mask)

    def backward():
        nonlocal backward_result
        backward_result = result.backward(dy)

    def clear():
        decoder.zero_grad()

    return profiling.time_funcs(
        [forward, backward, clear], name='Encdec',
        func_names=['forward', 'backward', 'clear'],
        num_iters=num_iters, warmups=num_warmups)
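# Example usage (an illustrative sketch; in the driver, encoder_out is
# generated the same way as x):
#
#   x = profiling.generate_batch(8, 64, 512)
#   encoder_out = profiling.generate_batch(8, 64, 512)
#   times = time_encdec(x, encoder_out, num_heads=8,
#                       num_iters=10, num_warmups=2)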
if __name__ == '__main__':
    args = parser.parse_args()
    # Check this here first.
    if args.plot_file and os.path.exists(args.plot_file):
        print(f'{args.plot_file} exists, aborting.')
        sys.exit(1)
    x = profiling.generate_batch(args.batch_size, args.max_seq_len,
                                 args.embed_size)
    if args.layer == 'encoder':
        times = time_encoder(x, args.num_heads,
                             activation=args.activation,
                             bias=not args.no_attn_bias,
                             dropout=not args.no_attn_dropout,
                             do_backprop=not args.no_backprop,
                             fp=args.fp,
                             use_apex=args.apex,
                             num_iters=args.num_iters,
                             num_warmups=args.num_warmups)
    elif args.layer == 'decoder':
        times = time_decoder(x, args.num_heads, activation=args.activation,