def do_train(args):
    if args.use_gpu:
        rank = dist.get_rank()
        trainer_count = dist.get_world_size()
    else:
        rank = 0
        trainer_count = 1
        paddle.set_device("cpu")

    if trainer_count > 1:
        dist.init_parallel_env()

    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    vocab = get_lm_vocab(args)
    train_loader = get_lm_data_loader(args, vocab, "train")
    eval_loader = get_lm_data_loader(args, vocab, "valid")

    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(
        args.ntokens,
        args.n_layer,
        args.n_head,
        args.d_model,
        args.d_head,
        args.d_inner_hid,
        args.dropout,
        args.attn_dropout,
        tie_weight=args.tie_weight,
        d_embed=args.d_model,
        div_val=args.div_val,
        tie_projs=tie_projs,
        normalize_before=args.normalize_before,
        tgt_len=args.tgt_len,
        ext_len=args.ext_len,
        mem_len=args.mem_len,
        cutoffs=cutoffs,
        same_length=args.same_length,
        attn_type=args.attn_type,
        clamp_len=args.clamp_len,
        sample_softmax=args.sample_softmax)

    if args.scheduler == 'cosine':
        scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
            learning_rate=args.learning_rate,
            T_max=args.max_step,
            eta_min=args.eta_min)
    elif args.scheduler == 'noam':
        scheduler = paddle.optimizer.lr.NoamDecay(
            d_model=args.d_model,
            warmup_steps=args.warmup_steps,
            learning_rate=args.learning_rate)
    elif args.scheduler == 'dev_perf':
        # fluid api
        scheduler = paddle.fluid.dygraph.ReduceLROnPlateau(
            learning_rate=args.learning_rate,
            decay_rate=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min)
    elif args.scheduler == 'constant':
        scheduler = args.learning_rate

    clip = paddle.nn.ClipGradByGlobalNorm(args.clip)
    if args.optim.lower() == 'momentum':
        optimizer = paddle.optimizer.Momentum(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            momentum=args.mom,
            grad_clip=clip)
    elif args.optim.lower() == 'adam':
        optimizer = paddle.optimizer.Adam(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=eval(args.eps),
            grad_clip=clip)
    elif args.optim.lower() == 'adagrad':
        optimizer = paddle.optimizer.Adagrad(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            grad_clip=clip)

    # Init from a checkpoint to resume previous training.
    if args.init_from_checkpoint:
        model_dict = paddle.load(
            os.path.join(args.init_from_checkpoint,
                         "mem_transformer.pdparams"))
        opt_dict = paddle.load(
            os.path.join(args.init_from_checkpoint, "mem_transformer.pdopt"))
        mem_transformer.set_state_dict(model_dict)
        optimizer.set_state_dict(opt_dict)
        print("loaded from checkpoint.")

    # Init from a pretrained model to better solve the current task.
    if args.init_from_pretrain_model:
        model_dict = paddle.load(
            os.path.join(args.init_from_pretrain_model,
                         "mem_transformer.pdparams"))
        mem_transformer.set_state_dict(model_dict)
        print("loaded from pre-trained model.")

    if trainer_count > 1:
        mem_transformer = paddle.DataParallel(mem_transformer)

    step_idx = 0
    train_loss = 0.0
    log_start_time = time.time()

    for pass_id in range(args.epoch):
        batch_id = 0
        mems = tuple()
        for input_data in train_loader:
            (src, target, seq_len) = input_data
            ret = mem_transformer(src, target, *mems)
            loss = ret[0]
            mems = ret[1:]
            train_loss += loss.numpy()

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            if step_idx > 0 and step_idx % args.print_step == 0 and rank == 0:
                cur_loss = train_loss / args.print_step
                elapsed = time.time() - log_start_time
                if args.scheduler == "constant":
                    lr = optimizer.get_lr()
                else:
                    lr = scheduler.get_lr()
                logger_info = "step_idx: %d, epoch: %d, batch: %d, learning rate: %.8f, " \
                              "speed: %f ms/batch, loss: %f" % \
                              (step_idx, pass_id, batch_id, lr,
                               elapsed * 1000.0 / args.print_step, cur_loss)
                if args.dataset in ["enwik8", "text8"]:
                    logger_info = logger_info + ", bpc: %f" % (cur_loss / np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (np.exp(cur_loss))
                logger.info(logger_info)
                train_loss = 0.0
                log_start_time = time.time()

            if step_idx % args.save_step == 0 and step_idx != 0:
                # Do validation.
                mem_transformer.eval()

                # TODO(FrostML): simplify this.
                if args.mem_len == 0:
                    if dist.get_world_size() == 1:
                        mem_transformer.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len + args.tgt_len - args.eval_tgt_len,
                            mem_len=args.mem_len)
                    else:
                        mem_transformer._layers.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len + args.tgt_len - args.eval_tgt_len,
                            mem_len=args.mem_len)
                else:
                    if dist.get_world_size() == 1:
                        mem_transformer.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len,
                            mem_len=args.mem_len + args.tgt_len - args.eval_tgt_len)
                    else:
                        mem_transformer._layers.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len,
                            mem_len=args.mem_len + args.tgt_len - args.eval_tgt_len)

                total_len, total_loss = 0, 0.
                eval_mems = tuple()
                with paddle.no_grad():
                    for i, (src, target, seq_len) in enumerate(eval_loader):
                        if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                            break
                        ret = mem_transformer(src, target, *eval_mems)
                        loss, eval_mems = ret[0], ret[1:]
                        seq_len = seq_len.numpy()
                        eval_cur_loss = seq_len * loss.numpy()
                        total_loss += eval_cur_loss
                        total_len += seq_len
                    eval_loss = total_loss / total_len

                logger_info = "Validation, step_idx: %d, validation loss: %f" % \
                              (step_idx, eval_loss)
                if args.dataset in ['enwik8', 'text8']:
                    logger_info = logger_info + ", bpc: %f" % (eval_loss / np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (np.exp(eval_loss))
                logger.info(logger_info)

                if args.save_model and rank == 0:
                    model_dir = os.path.join(
                        args.save_model,
                        "step_" + str(step_idx) + "_" + str(eval_loss))
                    if not os.path.exists(model_dir):
                        os.makedirs(model_dir)
                    paddle.save(
                        mem_transformer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdparams"))
                    paddle.save(
                        optimizer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdopt"))

                if args.scheduler == 'dev_perf':
                    scheduler.step(eval_loss)

                # TODO(FrostML): simplify this.
                if dist.get_world_size() == 1:
                    mem_transformer.reset_length(
                        tgt_len=args.tgt_len,
                        ext_len=args.ext_len,
                        mem_len=args.mem_len)
                else:
                    mem_transformer._layers.reset_length(
                        tgt_len=args.tgt_len,
                        ext_len=args.ext_len,
                        mem_len=args.mem_len)

                mem_transformer.train()

            step_idx += 1
            batch_id += 1

            if args.scheduler in ['cosine', 'dev_perf']:
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    scheduler.base_lr = curr_lr
                else:
                    if args.scheduler == 'cosine':
                        scheduler.step()
            elif args.scheduler == 'constant':
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    optimizer.set_lr(curr_lr)
            elif args.scheduler == 'noam':
                scheduler.step()

            if step_idx >= args.max_step:
                break

    if args.save_model and rank == 0:
        model_dir = os.path.join(args.save_model, "step_final")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        paddle.save(mem_transformer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdparams"))
        paddle.save(optimizer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdopt"))
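
# A minimal illustrative sketch (not part of the training script above; the
# helper name and example values are hypothetical). It mirrors the
# reset_length arithmetic used around validation in do_train: the total
# attention span tgt_len + ext_len + mem_len is kept constant, so the
# difference tgt_len - eval_tgt_len is folded into ext_len when mem_len == 0,
# and into mem_len otherwise.
def _eval_lengths(tgt_len, ext_len, mem_len, eval_tgt_len):
    """Return the (tgt_len, ext_len, mem_len) used during validation."""
    if mem_len == 0:
        return eval_tgt_len, ext_len + tgt_len - eval_tgt_len, mem_len
    return eval_tgt_len, ext_len, mem_len + tgt_len - eval_tgt_len


# Example: training with tgt_len=150, ext_len=0, mem_len=150 and
# eval_tgt_len=64 evaluates with a memory of 150 + 150 - 64 = 236 tokens.
assert _eval_lengths(150, 0, 150, 64) == (64, 0, 236)
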
def main():
    args = parse_args()

    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(args.local_rank,
                                                   nproc_per_node,
                                                   args.affinity)
        print(f'{args.local_rank}: thread affinity: {affinity}')

    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'eval_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = f'eval_log.log'

    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)
    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
        filemode='a',
    )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.profile:
        try:
            pyprof.init(enable_function_stack=True)
        except NameError:
            warnings.warn('Called pyprof.init() but pyprof is not available')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    if not args.no_env:
        log_env_info()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError(
            'Specify path to checkpoint using --model or --work_dir')

    if not args.manual_config:
        checkpoint = load_checkpoint(model_path)
        vocab_type = checkpoint['args'].vocab
    else:
        checkpoint = None
        vocab_type = args.manual_vocab

    if args.manual:
        vocab = checkpoint['vocab']

        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        iter = data_utils.LMOrderedIterator(tensor, bsz=args.batch_size,
                                            bptt=args.tgt_len, device=device,
                                            ext_len=args.ext_len,
                                            warmup=False)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset, vocab_type)

        if args.split == 'valid' or args.split == 'test':
            iter = corpus.get_iterator(args.split, args.batch_size,
                                       args.tgt_len, device=device,
                                       mem_len=args.mem_len,
                                       ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    elif not args.manual_config:
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        if args.type == 'pytorch':
            model.load_state_dict(checkpoint['model_state'])
        elif args.type == 'torchscript':
            model.load_state_dict(checkpoint['model_state'], strict=False)
    elif args.manual_config:
        args.manual_config['tgt_len'] = args.tgt_len
        args.manual_config['ext_len'] = args.ext_len
        args.manual_config['mem_len'] = args.mem_len
        args.manual_config['clamp_len'] = args.clamp_len
        args.manual_config['same_length'] = args.same_length
        args.manual_config['dtype'] = dtype

        model = MemTransformerLM(**args.manual_config)

    model = model.eval()
    model = model.to(device)
    model = model.to(dtype)

    if args.type == 'torchscript' and not args.manual_config:
        state = checkpoint['model_state']

        tie_projs = checkpoint['model_config']['tie_projs']
        tie_weight = checkpoint['model_config']['tie_weight']
        div_val = checkpoint['model_config']['div_val']
        d_model = checkpoint['model_config']['d_model']
        d_embed = checkpoint['model_config']['d_embed']

        if div_val != 1 or d_model != d_embed:
            for i in range(len(model.word_emb.emb_projs)):
                model.word_emb.emb_projs[i] = state[
                    f'word_emb.emb_projs.{i}'].to(dtype)

        for i in range(len(model.crit.out_projs)):
            if div_val == 1:
                src = 0
            else:
                src = i
            if model.crit.out_projs[i] is not None:
                if tie_projs[i]:
                    model.crit.out_projs[i] = state[
                        f'word_emb.emb_projs.{src}'].to(dtype)
                else:
                    model.crit.out_projs[i] = state[
                        f'crit.out_projs.{i}'].to(dtype)

        for i in range(len(model.crit.out_layers_biases)):
            model.crit.out_layers_biases[i] = state[
                f'crit.out_layers_biases.{i}'].to(dtype)

        if tie_weight:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[
                    f'word_emb.emb_layers.{i}.weight'].to(dtype)
        else:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[
                    f'crit.out_layers_weights.{i}'].to(dtype)

        model = torch.jit.script(model)

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['eval_throughput'] = AverageMeter(warmup=warmup,
                                             keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
        loss = evaluate(iter, model, meters, args.log_interval,
                        args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    summary = {
        'eval_loss': loss,
        'eval_ppl': perplexity,
    }

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
        }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)

        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(
                f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms'
            )
        logging.info('=' * 100)

        summary.update({
            'eval_throughput': throughput_data.mean(),
            'eval_avg_latency': 1000 * latency_data.mean(),
        })
        for p in args.percentiles:
            summary[f'eval_{p}%_latency'] = 1000 * np.percentile(
                latency_data, p)

    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['eval_throughput'].avg,
    )
    if not passed:
        sys.exit(1)
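
# A minimal sketch of the AverageMeter used above. This is an assumption
# inferred from usage (AverageMeter(warmup=..., keep=...), .vals, .avg), not
# the repo's utils implementation. The first `warmup` updates are discarded
# so that iterations which merely fill the Transformer-XL memory do not skew
# the latency/throughput statistics; with keep=True every retained sample is
# stored so percentiles can be reported later.
class _AverageMeterSketch:
    def __init__(self, warmup=0, keep=False):
        self.warmup = warmup
        self.keep = keep
        self.n_updates = 0   # all updates, including warmup ones
        self.count = 0       # retained updates only
        self.sum = 0.0
        self.vals = []       # retained samples when keep=True

    def update(self, val):
        self.n_updates += 1
        if self.n_updates <= self.warmup:
            return           # skip warmup measurements
        self.count += 1
        self.sum += val
        if self.keep:
            self.vals.append(val)

    @property
    def avg(self):
        return self.sum / max(self.count, 1)


# The second main() below is a separate, simpler evaluation entry point: it
# imports mem_transformer_base_jit and omits the DLLogger/profiling setup.
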
def main():
    args = parse_args()

    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_base_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = f'log.log'
    log_file = os.path.join(args.work_dir, log_file)
    if args.debug:
        log_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
        filemode='a',
    )
    logging.info(args)

    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError(
            'Specify path to checkpoint using --model or --work_dir')

    checkpoint = load_checkpoint(model_path)

    if args.manual:
        args.batch_size = 1
        vocab = checkpoint['vocab']

        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        iter = data_utils.LMOrderedIterator(tensor, bsz=args.batch_size,
                                            bptt=args.tgt_len, device=device,
                                            ext_len=args.ext_len)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset,
                               checkpoint['args'].vocab)

        if args.split == 'valid':
            iter = corpus.get_iterator('valid', args.batch_size,
                                       args.tgt_len, device=device,
                                       ext_len=args.ext_len)
        elif args.split == 'test':
            iter = corpus.get_iterator('test', args.batch_size,
                                       args.tgt_len, device=device,
                                       ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    else:
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        model.load_state_dict(checkpoint['model_state'])

    model = model.eval()
    model = model.to(device)

    model = model.float()
    if args.fp16:
        model = model.half()

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 1
    meters['eval_throughput'] = AverageMeter(warmup=warmup,
                                             keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    loss = evaluate(iter, model, meters, args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
        }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)

        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(
                f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms'
            )
        logging.info('=' * 100)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['eval_throughput'].avg,
    )
    if not passed:
        sys.exit(1)
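
# A minimal sketch (an assumption, not the repo's benchmark() implementation)
# of the pass/fail check both evaluation entry points end with: a run passes
# when the measured perplexity is no worse than the target and the measured
# throughput is no lower than the target; a target left as None is not
# enforced.
def _benchmark_sketch(target_perplexity, test_perplexity,
                      target_throughput, test_throughput):
    passed = True
    if target_perplexity is not None:
        passed = passed and test_perplexity <= target_perplexity
    if target_throughput is not None:
        passed = passed and test_throughput >= target_throughput
    return passed


# Usage mirrors the call above: exit nonzero when the quality/perf gate fails.
# if not _benchmark_sketch(args.target_perplexity, perplexity,
#                          args.target_throughput,
#                          meters['eval_throughput'].avg):
#     sys.exit(1)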