def main():
    """Evaluate a MemTransformerLM checkpoint and check it against benchmark targets.

    Steps:
      1. Parse CLI arguments and pick the ``MemTransformerLM`` implementation
         (``mem_transformer`` for ``--type pytorch``, otherwise the JIT-friendly
         ``inference.mem_transformer_base_jit``).
      2. Initialise distributed state, create the work dir on rank 0 and
         configure logging.
      3. Load the checkpoint (``--model`` or ``<work_dir>/checkpoint_best.pt``).
      4. Build an evaluation iterator, either from ``--manual`` text or from
         the requested corpus split.
      5. Build (or ``torch.jit.load``) the model, cast it to fp32/fp16,
         optionally compile it and save a TorchScript copy.
      6. Run ``evaluate``, log loss/perplexity and, if ``--save_data`` is set,
         pickle the per-iteration latency/throughput samples and log summary
         statistics.
      7. Call ``benchmark`` and exit with status 1 if the targets are missed.

    Raises:
        RuntimeError: if neither ``--model`` nor ``--work_dir`` is given, or if
            ``--split`` is neither ``'valid'`` nor ``'test'``.
    """
    args = parse_args()

    # The implementation is chosen at runtime, so the import stays local.
    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_base_jit import MemTransformerLM

    # Only bind a CUDA device when running on CUDA; touching torch.cuda on a
    # CPU-only run fails on machines without a usable GPU.
    if args.cuda:
        torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    # Only rank 0 creates the experiment directory; the other workers wait.
    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'log.log'

    log_file = os.path.join(args.work_dir, log_file)
    if args.debug:
        log_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
        filemode='a',
    )
    logging.info(args)

    # Resolve which checkpoint to evaluate.
    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError(
            'Specify path to checkpoint using --model or --work_dir')

    checkpoint = load_checkpoint(model_path)

    if args.manual:
        # Free-form text given on the command line: a single stream, batch 1.
        args.batch_size = 1
        vocab = checkpoint['vocab']

        # Backfill unk_idx when the stored vocab object lacks it
        # (presumably vocabs pickled by older code — TODO confirm).
        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        eval_iter = data_utils.LMOrderedIterator(
            tensor,
            bsz=args.batch_size,
            bptt=args.tgt_len,
            device=device,
            ext_len=args.ext_len,
        )
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset,
                               checkpoint['args'].vocab)

        if args.split not in ('valid', 'test'):
            raise RuntimeError('Unknown split')
        eval_iter = corpus.get_iterator(args.split, args.batch_size,
                                        args.tgt_len, device=device,
                                        ext_len=args.ext_len)

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    else:
        # Override the stored model config with the evaluation-time settings.
        checkpoint['model_config'].update(
            tgt_len=args.tgt_len,
            ext_len=args.ext_len,
            mem_len=args.mem_len,
            clamp_len=args.clamp_len,
            same_length=args.same_length,
            dtype=dtype,
        )
        model = MemTransformerLM(**checkpoint['model_config'])
        model.load_state_dict(checkpoint['model_state'])

    model = model.eval()
    model = model.to(device)

    # Normalise to fp32 first, then downcast if fp16 was requested.
    model = model.float()
    if args.fp16:
        model = model.half()

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    # Number of initial measurements the meters skip; presumably the segments
    # needed before the memory of length mem_len is filled — TODO confirm.
    warmup = args.mem_len // args.tgt_len + 1
    meters = {
        'eval_throughput': AverageMeter(warmup=warmup, keep=args.save_data),
        'eval_latency': AverageMeter(warmup=warmup, keep=args.save_data),
    }

    loss = evaluate(eval_iter, model, meters, args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        data_fname = f'eval_data_{args.batch_size}_{math_str}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
        }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)
        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(
                f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms'
            )

        logging.info('=' * 100)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['eval_throughput'].avg,
    )
    if not passed:
        sys.exit(1)
clamp_len=args.clamp_len, sample_softmax=args.sample_softmax, ) if args.restart: print('Restarting training from {args.restart_dir}') with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f: state = torch.load(f) if isinstance(state, MemTransformerLM): # old format model = state else: model_params = state['model_params'] model = MemTransformerLM(**model_params) model.load_state_dict(state['state_dict']) del state if not args.fp16: model = model.float() model.apply(update_dropout) model.apply(update_dropatt) else: model = MemTransformerLM(**model_params) model.apply(weights_init) # ensure embedding init is not overridden by out_layer # in case of weight sharing model.word_emb.apply(weights_init) args.n_all_param = sum([p.nelement() for p in model.parameters()]) args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()]) if args.fp16: model = model.half() if args.multi_gpu: