def load(cls, model_path: Path, spm_path: Path, device: str = None) -> 'ModelWrapper':
    """Restore a ModelWrapper from a serialized checkpoint.

    Args:
        model_path: torch checkpoint containing 'model_params',
            'state_dict' and 'vocab' entries.
        spm_path: path to the SentencePiece ``.model`` file.
        device: target device string; when ``None``, picks 'cuda' if
            available, else 'cpu'.

    Returns:
        A wrapper instance holding the restored model, vocab and tokenizer.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Always deserialize onto CPU; the wrapper's __init__ moves the model
    # to the requested device afterwards.
    with open(model_path, 'rb') as f:
        state = torch.load(f, map_location='cpu')
    model = MemTransformerLM(**state['model_params'])
    model.load_state_dict(state['state_dict'])
    # NOTE: the checkpoint also carries 'vocab_params', but Vocab is rebuilt
    # from the symbol list alone; the unused local read of it was removed.
    vocab = Vocab.from_symbols(state['vocab'])
    sp_processor = spm.SentencePieceProcessor()
    sp_processor.Load(str(spm_path))
    return cls(model, vocab, sp_processor, device)
def __init__(self, model: MemTransformerLM, vocab: Vocab, sp_processor: spm.SentencePieceProcessor, device: str):
    """Bundle a trained model with its vocab and SentencePiece tokenizer.

    The model is transferred to ``device`` and put into eval mode so the
    wrapper is immediately ready for inference.
    """
    self.vocab = vocab
    self.sp_processor = sp_processor
    self.device = device
    # Move to the target device, then disable dropout/batch-dependent layers.
    self.model = model.to(device=self.device)
    self.model.eval()
def widen_model(strategy, ratio, model: MemTransformerLM, optimizers, va_iter, epoch, step):
    """Widen ``model`` in place (add attention heads) and patch the optimizer.

    Validation runs once before and once after widening so the effect of the
    expansion is visible in the logged metrics.
    """
    optimizer, _ = optimizers
    # pre-expansion validation
    logging(f"evaluating before widening")  # debug
    loss_before = evaluate(va_iter, model)
    log_val(loss_before, compute=model.compute, step=step)
    # expansion
    logging(f"adding {ratio} layers before starting epoch {epoch} with method {strategy}")
    model.add_heads(ratio, strategy=strategy, function=initialization_func)
    # Adam-style moment buffers must be resized to match the widened parameters.
    for p, opt_state in optimizer.state.items():
        if not isinstance(p, nn.Parameter):
            continue
        for moment_key in ("exp_avg", "exp_avg_sq"):
            opt_state[moment_key] = expand_state(p, opt_state[moment_key])
    # post-expansion validation
    logging(f"reevaluating")
    loss_after = evaluate(va_iter, model)
    log_val(loss_after, step=step, compute=model.compute)
def expand_model(strategy, integration, integration_length, n_add, model: MemTransformerLM, optimizers, schedulers, tr_iter, va_iter, epoch, step):
    """Grow the model by ``n_add`` layers and wire the new parameters into the
    optimizer/scheduler, optionally applying an integration strategy
    ("reverse_distil" and/or "freeze") over ``integration_length`` batches.
    """
    optimizer, _ = optimizers
    scheduler, _ = schedulers
    if integration:
        if not integration_length or integration_length <= 0:
            warnings.warn(f"integration {integration} passed but integration_length is {integration_length}")
        else:
            logging(f"applying integration strategy {integration} with integration length {integration_length}")
    # pre-expansion validation
    logging(f"evaluating before expanding")
    val_loss = evaluate(va_iter, model)
    log_val(val_loss, step=step, compute=model.compute)
    # infer example logits for reverse distillation
    # (captured BEFORE expansion so the new layers can be fit to the old output)
    if "reverse_distil" in integration:
        first_logits = get_original_batches(model, tr_iter, integration_length)
    # expansion
    logging(f"adding {n_add} layers before starting epoch {epoch} with method {strategy}")
    new_layers = model.expand_layers(n_add, strategy=strategy, function=initialization_func)
    # optimizer update: register the new layers as their own param group,
    # inheriting the current and initial LR of group 0, then mirror the new
    # group's initial LR into the scheduler so LR decay covers it too.
    optimizer.add_param_group({'params': new_layers.parameters(),
                              'lr': optimizer.param_groups[0]["lr"],
                              'initial_lr': optimizer.param_groups[0]["initial_lr"]})
    scheduler.base_lrs.append(optimizer.param_groups[-1]["initial_lr"])
    # training loop for reverse distillation
    if "reverse_distil" in integration:
        fit_to_previous_model(model, new_layers, tr_iter, first_logits, integration)
    # freezing parameters for frozen restart, we do this afterwards else the
    # new layers get copied also without grads
    if "freeze" in integration and integration_length > 0:
        # Freeze every pre-existing group; the freshly added group (last) keeps grads.
        for param_group in optimizer.param_groups[:-1]:
            for parameter in param_group['params']:
                parameter.requires_grad = False
        # presumably a counter the training loop decrements to unfreeze later
        # — TODO confirm against the training loop.
        model.freeze_countdown = integration_length
    # post-expansion validation
    logging(f"reevaluating")
    val_loss = evaluate(va_iter, model)
    log_val(val_loss, step=step, compute=model.compute)
def __init__(
    self,
    model: Union[str, Dict, MemTransformerLM],
    tokenizer_config: TokenizerWrapperConfig,
    device: torch.device,
    verbose: bool = False,
):
    """Accept a model as a config dict, a checkpoint path, or an instance.

    Args:
        model: hyper-parameter dict (a new MemTransformerLM is built from it),
            a path string (loaded via ``self._load_model``), or an already
            constructed MemTransformerLM.
        tokenizer_config: tokenizer settings forwarded to the base class.
        device: device the base class should place the model on.
        verbose: enable extra logging in this wrapper.

    Raises:
        TypeError: if ``model`` is none of the accepted types.
    """
    if isinstance(model, Dict):
        model = MemTransformerLM(**model)
    elif isinstance(model, str):
        model = self._load_model(model, device)
    elif not isinstance(model, MemTransformerLM):
        # Previously this branch was a no-op `model = model`; a plain guard
        # expresses the same contract without the redundant assignment.
        raise TypeError(
            f"model has type {type(model)}, but only str, Dict or MemTransformerLM allowed"
        )
    super().__init__(model, tokenizer_config, device)
    # Generate one token at a time (tgt_len=1, ext_len=0) while keeping the
    # full memory window as context.
    self._model.reset_length(1, 0, self._model.mem_len)
    self._batch_len = self._model.mem_len
    self._verbose = verbose
def do_eval(args):
    """Evaluate a MemTransformerLM checkpoint on valid/test splits (PaddlePaddle).

    Builds the model from ``args``, loads weights from
    ``args.init_from_params/mem_transformer.pdparams`` and logs length-weighted
    average loss plus bpc (enwik8/text8) or perplexity (other datasets).
    """
    assert args.ext_len >= 0, 'Extended context length must be no less than 0'

    def _evaluate(loader):
        # Length-weighted mean loss; `mem_transformer` is the closure over the
        # model built later in do_eval (bound at call time, not definition time).
        total_len, total_loss = 0, 0.
        eval_mems = tuple()
        for i, (src, target, seq_len) in enumerate(loader):
            if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                break
            # Model returns (loss, *new_mems); mems are threaded across batches.
            ret = mem_transformer(src, target, *eval_mems)
            loss, eval_mems = ret[0], ret[1:]
            eval_cur_loss = seq_len * loss.numpy()
            total_loss += eval_cur_loss
            total_len += seq_len
        return total_loss / total_len

    def _logger(loss):
        # bits-per-character for character-level corpora, else perplexity.
        if args.dataset in ['enwik8', 'text8']:
            logger_info = "loss: %f, bpc: %f" % (loss, loss / np.log(2))
        else:
            logger_info = "loss: %f, ppl: %.2f" % (loss, np.exp(loss))
        return logger_info

    if not args.use_gpu:
        paddle.set_device("cpu")

    vocab = get_lm_vocab(args)
    eval_loader = get_lm_data_loader(args, vocab, "valid")
    test_loader = get_lm_data_loader(args, vocab, "test")

    # Adaptive-softmax cutoffs are dataset specific; tie_projs[0] stays False
    # for the head cluster.
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(
        args.ntokens,
        args.n_layer,
        args.n_head,
        args.d_model,
        args.d_head,
        args.d_inner_hid,
        args.dropout,
        args.attn_dropout,
        tie_weight=args.tie_weight,
        d_embed=args.d_model,
        div_val=args.div_val,
        tie_projs=tie_projs,
        normalize_before=args.normalize_before,
        tgt_len=args.tgt_len,
        ext_len=args.ext_len,
        mem_len=args.mem_len,
        cutoffs=cutoffs,
        same_length=args.same_length,
        attn_type=args.attn_type,
        clamp_len=args.clamp_len,
        sample_softmax=args.sample_softmax)

    assert args.init_from_params, (
        "Please set init_from_params to load the infer model.")

    model_dict = paddle.load(
        os.path.join(args.init_from_params, "mem_transformer.pdparams"))
    mem_transformer.load_dict(model_dict)

    logger.info(
        "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".
        format(args.eval_batch_size, args.tgt_len, args.ext_len, args.mem_len,
               args.clamp_len))

    # Switch the model from its training lengths to the evaluation lengths.
    mem_transformer.reset_length(args.tgt_len, args.ext_len, args.mem_len)

    test_loss = None
    valid_loss = None
    if args.mode == 'all':
        test_loss = _evaluate(test_loader)
        valid_loss = _evaluate(eval_loader)
    elif args.mode == 'valid':
        valid_loss = _evaluate(eval_loader)
    elif args.mode == 'test':
        test_loss = _evaluate(test_loader)

    logger_info = ''
    if valid_loss is not None:
        logger_info = logger_info + "validation loss: " + _logger(
            valid_loss) + " | "
    if test_loss is not None:
        logger_info = logger_info + "test loss: " + _logger(test_loss) + " | "
    logger.info(logger_info)
def load(cls, model_path, sp_model_path, device, print_stats=True):
    """Build an instance from an on-disk Transformer-XL checkpoint.

    Reads hyper-parameters from ``<model_path>/params.json``, reconstructs
    the model, loads weights from ``<model_path>/valid_state_dict.pt`` onto
    ``device``, and loads the SentencePiece model from ``sp_model_path``.
    When ``print_stats`` is true, per-tensor element counts and the total
    parameter count are printed.
    """
    paramspath = os.path.join(model_path, 'params.json')
    with open(paramspath, 'r') as paramsf:
        xl_params = json.loads(paramsf.read())
    print(repr(xl_params))

    # Dropout/dropatt are hard-coded to 0.0 for inference; the inline values
    # are sample hyper-parameters from a reference run.
    model = MemTransformerLM(
        xl_params['ntokens'],  # 50000,
        xl_params['n_layer'],  # 16,
        xl_params['n_head'],  # 10,
        xl_params['d_model'],  # 410,
        xl_params['d_head'],  # 41,
        xl_params['d_inner'],  # 2100,
        0.0,  # no dropout,
        0.0,  # no dropatt,
        tie_weight=xl_params['tie_weight'],  # True,
        d_embed=xl_params['d_embed'],  # 410,
        div_val=xl_params['div_val'],  # 1,
        tie_projs=xl_params['tie_projs'],  # [False, True, True, True]
        pre_lnorm=xl_params['pre_lnorm'],  # False,
        tgt_len=xl_params['tgt_len'],  # 150,
        ext_len=xl_params['ext_len'],  # 0,
        mem_len=xl_params['mem_len'],  # 150,
        cutoffs=xl_params['cutoffs'],  # [3500, 7500, 37500],
        same_length=xl_params['same_length'],  # False,
        attn_type=xl_params['attn_type'],  # 0,
        clamp_len=xl_params['clamp_len'],  # -1,
        sample_softmax=xl_params['sample_softmax'])  # -1

    state_dict_path = os.path.join(model_path, 'valid_state_dict.pt')
    print("loading weights %s ..." % state_dict_path)
    # Map tensors straight onto the target device while deserializing.
    tensor_dict = torch.load(state_dict_path, map_location=torch.device(device))
    model.load_state_dict(tensor_dict)
    print("loading weights %s ... done." % state_dict_path)

    if print_stats:
        tensor_list = list(tensor_dict.items())
        for layer_tensor_name, tensor in tensor_list:
            print("Layer %-42s: %9d elements" %
                  (layer_tensor_name, torch.numel(tensor)))
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        print("Total # params: %d" % pytorch_total_params)

    # with open(os.path.join(MODEL_PATH, 'model.pt'), 'rb') as f:
    #     model = torch.load(f)
    # model.apply(update_dropout)
    # model.apply(update_dropatt)

    para_model = model.to(device)
    # print ("loading model %s ... done." % MODEL_PATH)

    print("loading sp model from %s ..." % sp_model_path)
    sp_model = spm.SentencePieceProcessor()
    sp_model.load(sp_model_path)
    print("loading sp model from %s ... done." % sp_model_path)

    return cls(para_model, sp_model, device)
model = torch.load(f) model = model.float() model.apply(update_dropout) model.apply(update_dropatt) else: model = MemTransformerLM(n_token_in, n_token_out, args.n_layer, args.n_head, args.d_model, args.d_head, args.d_inner, args.dropout, args.dropatt, args.conv_size, args.conv_emb, args.pre_conv, tie_weight=args.tied, d_embed=args.d_embed, tie_projs=tie_projs, tgt_len=args.tgt_len, mem_len=args.mem_len, ext_ds=args.ext_ds, cutoffs=cutoffs, same_length=args.same_length, clamp_len=args.clamp_len) model.apply(weights_init) model.word_emb.apply( weights_init ) # ensure embedding init is not overridden by out_layer in case of weight sharing args.n_all_param = sum([p.nelement() for p in model.parameters()])
def main():
    """Train Transformer-XL: set up distributed/AMP state, build model,
    optimizer and scheduler, run the epoch loop (resumable from checkpoints),
    then evaluate the best checkpoint on the test split and report benchmarks.
    """
    args = parse_args()

    # Pin worker threads to the GPU's NUMA node when requested.
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(args.local_rank,
                                                   nproc_per_node,
                                                   args.affinity)
        print(f'{args.local_rank}: thread affinity: {affinity}')

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(
        args.work_dir,
        args.dataset,
        args.append_dataset,
        args.append_time,
    )

    # Only rank 0 creates the experiment directory; other ranks wait.
    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = args.txtlog_file
    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
    )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    # Global batch size is derived from the per-worker batch size if given.
    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')

    if args.batch_size % args.batch_chunk != 0:
        raise RuntimeError('Batch size needs to be divisible by '
                           'batch chunk')

    if args.profile:
        try:
            pyprof.init(enable_function_stack=True)
        except NameError:
            warnings.warn('Called pyprof.init() but pyprof is not available')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    logging.info(f'world size: {utils.distributed.get_world_size()}')

    if not args.no_env:
        log_env_info()

    register_ignoring_timeout_handler()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    # Keep the effective attention window constant at eval time by enlarging
    # memory when eval_tgt_len is shorter than the training tgt_len.
    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len

    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
    }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum(
        [p.nelement() for p in model.layers.parameters()])

    # optimizer
    # With sampled softmax, embedding-shaped params get a sparse optimizer.
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr,
                                  momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params, lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    scaler = None
    if args.fp16:
        if args.amp == 'pytorch':
            scaler = torch.cuda.amp.GradScaler()
        elif args.amp == 'apex':
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.apex_amp_opt_level,
            )

    # `para_model` is the (possibly parallel-wrapped) handle used for forward
    # passes; `model` stays the raw module for checkpointing.
    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                              model, dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max_step - args.warmup_step, eta_min=args.eta_min)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step - args.warmup_step,
                eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                lr_lambda=lr_lambda)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(optimizer_sparse,
                                                           lr_lambda=lr_lambda)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min,
        )
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    start_epoch = 1
    last_batch = 0
    last_iter = 0
    best_val_loss = None

    # Resume training state from a checkpoint; a missing file falls back to
    # random init instead of aborting.
    if args.restart:
        try:
            checkpoint = load_checkpoint(args.restart)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['scheduler_state'])
            if args.fp16:
                if args.amp == 'pytorch':
                    scaler.load_state_dict(checkpoint['amp_state'])
                elif args.amp == 'apex':
                    amp.load_state_dict(checkpoint['amp_state'])
            train_step = checkpoint['train_step']
            start_epoch = checkpoint['epoch']
            last_batch = checkpoint['batch']
            last_iter = checkpoint['last_iter']
            best_val_loss = checkpoint['best_val_loss']

            if train_step >= args.max_step:
                logging.info(f'Loaded checkpoint after {train_step} steps, but '
                             f'this run was scheduled for a total of '
                             f'{args.max_step} steps, exiting')
                sys.exit(1)

            model.apply(functools.partial(update_dropout, args=args))
            model.apply(functools.partial(update_dropatt, args=args))
        except FileNotFoundError:
            logging.info(f'Could not load checkpoint from {args.restart}, '
                         f'starting training from random init')

    meters = {}
    # Skip the first few throughput samples while memory is still warming up.
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)

    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
        with TimeoutHandler() as timeout_handler:
            try:
                for epoch in itertools.count(start=start_epoch):
                    if args.roll:
                        tr_iter.roll(seed=args.seed + epoch)
                    train_step, best_val_loss = train(
                        tr_iter, va_iter, model, para_model, model_config,
                        optimizer, optimizer_sparse, scheduler,
                        scheduler_sparse, scaler, vocab, epoch, last_batch,
                        last_iter, train_step, best_val_loss, meters,
                        timeout_handler, device, args)

                    # Batch/iter offsets only apply to the resumed epoch.
                    last_batch = 0
                    last_iter = 0

                    if train_step == args.max_step:
                        logging.info('-' * 100)
                        logging.info('End of training')
                        break
            except KeyboardInterrupt:
                logging.info('-' * 100)
                logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and not args.no_eval and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
            test_loss = evaluate(te_iter, model, args)
            test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'
                .format(test_elapsed, test_loss, test_loss / math.log(2)))
        else:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'
                .format(test_elapsed, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

        summary.update({
            'test_elapsed': test_elapsed,
            'test_loss': test_loss,
        })

        if args.dataset in ['enwik8', 'text8']:
            summary['test_bits_per_character'] = test_loss / math.log(2)
        else:
            summary['test_perplexity'] = math.exp(test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(
        f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'valid_loss': best_val_loss,
        'valid_perplexity': val_perplexity,
    })
    dllogger.log(step=tuple(), data=summary)

    # Exit non-zero when perplexity/throughput targets are not met.
    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=val_perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['train_throughput'].avg)
    if not passed:
        sys.exit(1)
def main():
    """Train Transformer-XL (apex-AMP variant): initialize distributed state,
    build model/optimizer/scheduler, run the epoch loop, then evaluate the
    best checkpoint on the test split and check benchmark targets.
    """
    args = parse_args()

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(
        args.work_dir,
        args.dataset,
        args.append_dataset,
        args.append_time,
    )

    # Only rank 0 creates the experiment directory; other ranks wait.
    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = f'log.log'

    log_file = os.path.join(args.work_dir, log_file)
    if args.debug:
        log_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
    )
    logging.info(args)

    # Set the random seed manually for reproducibility.
    # Rank is folded in so each worker draws a different stream.
    np.random.seed(args.seed + utils.distributed.get_rank())
    torch.manual_seed(args.seed + utils.distributed.get_rank())

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
    }

    model = MemTransformerLM(**model_config)
    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum(
        [p.nelement() for p in model.layers.parameters()])

    # optimizer
    # With sampled softmax, embedding-shaped params get a sparse optimizer.
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr,
                                  momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params, lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    if args.fp16:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2',
        )

    # `para_model` is the (possibly parallel-wrapped) handle used for forward
    # passes; `model` stays the raw module for checkpointing.
    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(
            model,
            delay_allreduce=True,
        )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                              model, dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max_step, eta_min=args.eta_min)
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step, eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                lr_lambda=lr_lambda)
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min,
        )
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    best_val_loss = None

    # Resume training state from a checkpoint (no fallback: a missing file raises).
    if args.restart:
        checkpoint = load_checkpoint(args.restart)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['scheduler_state'])
        if args.fp16:
            amp.load_state_dict(checkpoint['amp_state'])
        train_step = checkpoint['train_step']
        best_val_loss = checkpoint['best_val_loss']

        model.apply(functools.partial(update_dropout, args=args))
        model.apply(functools.partial(update_dropatt, args=args))

    meters = {}
    # Skip the first few throughput samples while memory is still warming up.
    warmup = args.mem_len // args.tgt_len + 1
    meters['train_throughput'] = AverageMeter(warmup=warmup)

    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    try:
        for epoch in itertools.count(start=1):
            if args.roll:
                tr_iter.roll()
            train_step, best_val_loss = train(tr_iter, va_iter, model,
                                              para_model, model_config,
                                              optimizer, optimizer_sparse,
                                              scheduler, scheduler_sparse,
                                              vocab, epoch, train_step,
                                              best_val_loss, meters, args)

            if train_step == args.max_step:
                logging.info('-' * 100)
                logging.info('End of training')
                break
    except KeyboardInterrupt:
        logging.info('-' * 100)
        logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'
                .format(time.time() - test_start_time, test_loss,
                        test_loss / math.log(2)))
        else:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'
                .format(time.time() - test_start_time, test_loss,
                        math.exp(test_loss)))
        logging.info('=' * 100)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(
        f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    # Exit non-zero when perplexity/throughput targets are not met.
    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=val_perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['train_throughput'].avg)
    if not passed:
        sys.exit(1)
same_length=args.same_length, attn_type=args.attn_type, clamp_len=args.clamp_len, sample_softmax=args.sample_softmax) else: model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model, args.d_head, args.d_inner, args.dropout, args.dropatt, tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val, tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len, ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs, same_length=args.same_length, attn_type=args.attn_type, clamp_len=args.clamp_len, sample_softmax=args.sample_softmax) model.apply(weights_init) model.word_emb.apply( weights_init ) # ensure embedding init is not overridden by out_layer in case of weight sharing args.n_all_param = sum([p.nelement() for p in model.parameters()])
# Input vocabulary: presumably one token per merged k-mer over a 4-letter
# alphabet, doubled to 8 when methylation states are encoded — TODO confirm
# against the data pipeline. Output vocabulary is binary.
n_token_in = (4 + 4 * args.methylation)**args.merge_size
n_token_out = 2
print(args.coords[1:4])
model = MemTransformerLM(n_token_in,
                         n_token_out,
                         args.n_layer,
                         args.n_head,
                         args.d_model,
                         args.d_head,
                         args.d_inner,
                         args.dropout,
                         args.dropatt,
                         tie_weight=args.tied,
                         d_embed=args.d_embed,
                         custom_emb=None,
                         tie_projs=[False],
                         pre_lnorm=args.pre_lnorm,
                         tgt_len=args.tgt_len,
                         ext_len=args.ext_len,
                         mem_len=args.mem_len,
                         cutoffs=[],  # no adaptive softmax clusters
                         same_length=args.same_length,
                         clamp_len=args.clamp_len,
                         sample_softmax=args.sample_softmax)
# Restore weights from the restart directory, then move to the target device.
model.load_state_dict(torch.load(os.path.join(args.restart_dir, 'model.pt')))
model = model.to(device)
mem_len=args.mem_len, cutoffs=cutoffs, same_length=args.same_length, attn_type=args.attn_type, clamp_len=args.clamp_len, sample_softmax=args.sample_softmax, ) if args.restart: print('Restarting training from {args.restart_dir}') with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f: state = torch.load(f) if isinstance(state, MemTransformerLM): # old format model = state else: model_params = state['model_params'] model = MemTransformerLM(**model_params) model.load_state_dict(state['state_dict']) del state if not args.fp16: model = model.float() model.apply(update_dropout) model.apply(update_dropatt) else: model = MemTransformerLM(**model_params) model.apply(weights_init) # ensure embedding init is not overridden by out_layer # in case of weight sharing model.word_emb.apply(weights_init) args.n_all_param = sum([p.nelement() for p in model.parameters()]) args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
def train_ts(args):
    """Train a MemTransformerLM on a time-series dataset.

    Builds (or reloads) the model, optimizer and LR scheduler, runs the
    epoch loop with optional mid-training expansion/widening of the model,
    and finally evaluates on the test split.

    Args:
        args: parsed CLI namespace; read throughout via closures.
    """

    def build_scheduler(optimizers, args):
        """Return ``(scheduler, scheduler_sparse)`` for the dense optimizer."""
        optimizer, optimizer_sparse = optimizers
        scheduler_sparse = None
        if args.scheduler == "cosine":
            # here we do not set eta_min to lr_min to be backward compatible
            # because in previous versions eta_min is default to 0
            # rather than the default value of lr_min 1e-6
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, args.max_step, eta_min=args.eta_min)  # should use eta_min arg
        elif args.scheduler == "inv_sqrt":
            # originally used for Transformer (in Attention is all you need)
            def lr_lambda(step):
                # return a multiplier instead of a learning rate
                if step == 0 and args.warmup_step == 0:
                    return 1.0
                return (1.0 / (step ** 0.5) if step > args.warmup_step
                        else step / (args.warmup_step ** 1.5))

            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                    lr_lambda=lr_lambda)
        elif args.scheduler == "dev_perf":
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )
        elif args.scheduler == "constant":
            # FIX: `scheduler` was previously left unbound here, so the
            # `return` below raised NameError for the constant schedule.
            scheduler = None
        else:
            raise ValueError(f"scheduler type {args.scheduler} not recognized")
        return scheduler, scheduler_sparse

    ###########################################################################
    # Training code
    ###########################################################################

    def evaluate(eval_iter, model):
        """Return mean per-position loss of `model` over `eval_iter`."""
        # Turn on evaluation mode which disables dropout.
        model.eval()  # debug
        # NOTE(review): upstream versions call model.reset_length() here to
        # trade tgt_len for ext_len/mem_len during eval; that code was
        # commented out in this variant, so eval uses the training lengths.
        total_len, total_loss = 0, 0.0
        with torch.no_grad():
            mems = tuple()
            for i, (data, target, seq_len) in enumerate(eval_iter):
                # chained comparison: stop only when a positive cap is set
                if i >= args.max_eval_steps > 0:
                    break
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.mean()
                total_loss += seq_len * loss.float().item()
                total_len += seq_len
        # Switch back to the training mode
        model.train()
        return total_loss / total_len

    # reverse distillation util
    def get_original_batches(model, tr_iter, integration_length):
        """Record the pre-expansion logits for the first
        `integration_length` training batches (targets for reverse
        distillation).  Returns a flat list, or a per-chunk list of lists
        when gradient accumulation (batch_chunk > 1) is on."""
        model.eval()
        if args.batch_chunk > 1:
            mems = [None for _ in range(args.batch_chunk)]
            first_logits = [[] for _ in range(args.batch_chunk)]
        else:
            mems = None
            first_logits = []
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        with torch.no_grad():
            for batch, (data, target, seq_len) in enumerate(train_iter):
                if batch == integration_length:
                    break
                if args.batch_chunk > 1:
                    data_chunks = torch.chunk(data, args.batch_chunk, 1)
                    for i in range(args.batch_chunk):
                        data_i = data_chunks[i].contiguous()
                        logits, mems[i] = model._forward(data_i, mems=mems[i])
                        first_logits[i].append(logits.cpu())
                else:
                    logits, mems = model._forward(data, mems=mems)
                    first_logits.append(logits.cpu())
        return first_logits

    def build_optimizer(model, args, reload=False):
        """Create the optimizer (optionally reloading saved state).

        Returns ``(optimizer, optimizer_sparse)``; the sparse slot is
        always None in this variant.
        """
        optimizer_sparse = None
        if args.optim.lower() == "sgd":
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
        elif args.optim.lower() == "adam":
            optimizer = optim.Adam(model.parameters(), lr=args.lr)
        elif args.optim.lower() == "adagrad":
            optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        else:
            raise ValueError(f"optimizer type {args.optim} not recognized")
        if reload:
            if args.restart_from is not None:
                optim_name = f"optimizer_{args.restart_from}.pt"
            else:
                optim_name = "optimizer.pt"
            optim_file_name = os.path.join(args.restart_dir, optim_name)
            logging(f"reloading {optim_file_name}")
            if os.path.exists(os.path.join(args.restart_dir, optim_name)):
                with open(os.path.join(args.restart_dir, optim_name),
                          "rb") as optim_file:
                    opt_state_dict = torch.load(optim_file)
                    try:
                        optimizer.load_state_dict(opt_state_dict)
                    # in case the optimizer param groups aren't the same
                    # shape, merge them (was a bare `except:`)
                    except Exception:
                        logging("merging optimizer param groups")
                        opt_state_dict["param_groups"][0]["params"] = [
                            param
                            for param_group in opt_state_dict["param_groups"]
                            for param in param_group["params"]
                        ]
                        opt_state_dict["param_groups"] = [
                            opt_state_dict["param_groups"][0]
                        ]
                        optimizer.load_state_dict(opt_state_dict)
            else:
                logging("Optimizer was not saved. Start from scratch.")
        return optimizer, optimizer_sparse

    def log_val(val_loss, step, compute):
        """Log a framed validation summary (loss and bits-per-character)."""
        logging("-" * 100)
        log_str = ("| Eval {:3d} at step {:>8d} | time: {:5.2f}s "
                   "| valid loss {:5.2f}".format(
                       step // args.eval_interval,
                       step,
                       (time.time() - eval_start_time),
                       val_loss,
                   ))
        log_str += " | bpc {:9.5f}".format(val_loss / math.log(2))
        logging(log_str)
        logging("-" * 100)

    def epoch_loop(epoch, model, optimizers, schedulers):
        """Run one training epoch: forward/backward, clipping, LR
        annealing, periodic logging/eval/checkpointing."""
        nonlocal train_step
        # FIX: without `nonlocal`, the assignment at the end of the eval
        # branch created a dead local and the outer timer never reset.
        nonlocal eval_start_time
        # Turn on training mode which enables dropout.
        if isinstance(model, nn.DataParallel):
            parent_model = model.module
        else:
            parent_model = model
        optimizer, optimizer_sparse = optimizers
        scheduler, scheduler_sparse = schedulers
        train_losses = []
        model.train()
        if args.batch_chunk > 1:
            mems = [tuple() for _ in range(args.batch_chunk)]
        else:
            mems = tuple()
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        log_start_time = time.time()
        best_val_loss = float("Infinity")
        for batch, (data, target, seq_len) in enumerate(train_iter):
            model.zero_grad()
            if args.batch_chunk > 1:
                # gradient accumulation over batch_chunk sub-batches
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                target_chunks = torch.chunk(target, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    target_i = target_chunks[i].contiguous()
                    ret = model(data_i, target_i, *mems[i])
                    loss, mems[i] = ret[0], ret[1:]
                    loss = loss.float().mean().type_as(loss) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    train_losses.append(loss.float().item())
            else:
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.float().mean().type_as(loss)
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                train_losses.append(loss.float().item())
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            parent_model.compute += openai_compute(
                non_emb_param_count(parent_model, nseries), data.numel(), 1)

            # step-wise learning rate annealing
            train_step += 1
            parent_model.training_steps += 1
            # check for yet-to-thaw parameters
            if getattr(parent_model, "freeze_countdown", 0) > 0:
                parent_model.freeze_countdown -= 1
                # if this is the last step
                if parent_model.freeze_countdown == 0:
                    for parameter in parent_model.parameters():
                        parameter.requires_grad = True
                    logging("thawing all parameters")
            if args.scheduler in ["cosine", "constant", "dev_perf"]:
                # linear warmup stage
                if train_step < args.warmup_step:
                    curr_lr = args.lr * train_step / args.warmup_step
                    # FIX: previously `optimizer.param_groups = curr_lr`,
                    # which replaced the group list with a float and broke
                    # the optimizer.  Set the lr on every group (expansion
                    # may have added groups beyond [0]).
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = curr_lr
                else:
                    if args.scheduler == "cosine":
                        scheduler.step(train_step)
            elif args.scheduler == "inv_sqrt":
                scheduler.step(train_step)

            if train_step % args.log_interval == 0:
                cur_loss = np.mean(train_losses)
                elapsed = time.time() - log_start_time
                log_str = ("| epoch {:3d} step {:>8d} "
                           "| {:>6d} batches "
                           "| lr {:.3g} "
                           "| ms/batch {:5.2f} "
                           "| loss {:5.2f}".format(
                               epoch,
                               train_step,
                               batch + 1,
                               optimizer.param_groups[0]["lr"],
                               elapsed * 1000 / args.log_interval,
                               cur_loss,
                           ))
                log_str += " | bpc {:9.5f}".format(cur_loss / math.log(2))
                logging(log_str)
                train_losses = []
                log_start_time = time.time()

            if train_step % args.eval_interval == 0:
                val_loss = evaluate(va_iter, model)
                log_val(val_loss, step=train_step,
                        compute=parent_model.compute)
                # Save the model if the validation loss is the best we've
                # seen so far.
                if not best_val_loss or val_loss < best_val_loss:
                    best_val_loss = val_loss
                    if not args.debug:
                        if args.fp16:
                            with open(
                                    os.path.join(args.work_dir,
                                                 "amp_checkpoint.pt"),
                                    "wb",
                            ) as f:
                                checkpoint = {
                                    "model": model.state_dict(),
                                    "optimizer": optimizer.state_dict(),
                                    "amp": amp.state_dict(),
                                }
                                torch.save(checkpoint, f)
                        else:
                            with open(os.path.join(args.work_dir, "model.pt"),
                                      "wb") as f:
                                torch.save(parent_model, f)
                            with open(
                                    os.path.join(args.work_dir,
                                                 "optimizer.pt"),
                                    "wb",
                            ) as f:
                                torch.save(optimizer.state_dict(), f)
                # dev-performance based learning rate annealing
                if args.scheduler == "dev_perf":
                    scheduler.step(val_loss)
                eval_start_time = time.time()
            if train_step == args.max_step:
                break

    def expand_model(
            strategy,
            integration,
            integration_length,
            n_add,
            model: MemTransformerLM,
            optimizers,
            schedulers,
            tr_iter,
            va_iter,
            epoch,
            step,
    ):
        """Add `n_add` layers to `model` mid-training and wire the new
        parameters into the optimizer/scheduler; optionally run a reverse
        distillation or freeze-based integration phase."""
        optimizer, _ = optimizers
        scheduler, _ = schedulers
        if integration:
            if not integration_length or integration_length <= 0:
                warnings.warn(
                    f"integration {integration} passed but integration_length is {integration_length}"
                )
            else:
                logging(
                    f"applying integration strategy {integration} with integration length {integration_length}"
                )
        # pre-expansion validation
        logging(f"evaluating before expanding")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)
        # infer example logits for reverse distillation
        if "reverse_distil" in integration:
            first_logits = get_original_batches(model, tr_iter,
                                                integration_length)
        # expansion
        logging(
            f"adding {n_add} layers before starting epoch {epoch} with method {strategy}"
        )
        new_layers = model.expand_layers(n_add, strategy=strategy,
                                         function=initialization_func)
        # optimizer update: new layers go into their own param group so
        # they can be frozen/thawed independently
        optimizer.add_param_group({
            "params": new_layers.parameters(),
            "lr": optimizer.param_groups[0]["lr"],
            "initial_lr": optimizer.param_groups[0]["initial_lr"],
        })
        scheduler.base_lrs.append(optimizer.param_groups[-1]["initial_lr"])
        # training loop for reverse distillation
        if "reverse_distil" in integration:
            fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                                  integration)
        # freezing parameters for frozen restart, we do this afterwards else
        # the new layers get copied also without grads
        if "freeze" in integration and integration_length > 0:
            for param_group in optimizer.param_groups[:-1]:
                for parameter in param_group["params"]:
                    parameter.requires_grad = False
            model.freeze_countdown = integration_length
        # post-expansion validation
        logging(f"reevaluating")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    def expand_state(param, state):
        """Tile an Adam moment tensor to match a widened parameter.

        Assumes each new dim is an integer multiple of the old one —
        TODO confirm against add_heads.
        """
        if param.shape != state.shape:
            ratios = [
                param.shape[i] // state.shape[i]
                for i in range(len(param.shape))
            ]
            return state.repeat(*ratios)
        else:
            return state

    def widen_model(
            strategy,
            ratio,
            model: MemTransformerLM,
            optimizers,
            va_iter,
            epoch,
            step,
    ):
        """Widen `model` by `ratio` attention heads and tile the matching
        Adam statistics so optimizer state stays consistent."""
        optimizer, _ = optimizers
        # pre-expansion validation
        logging(f"evaluating before widening")  # debug
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, compute=model.compute, step=step)
        # expansion
        logging(
            f"adding {ratio} layers before starting epoch {epoch} with method {strategy}"
        )
        model.add_heads(ratio, strategy=strategy,
                        function=initialization_func)
        # optimizer update: widen the stored Adam moments in place
        for param, states in optimizer.state.items():
            if isinstance(param, nn.Parameter):
                states["exp_avg"] = expand_state(param, states["exp_avg"])
                states["exp_avg_sq"] = expand_state(param,
                                                    states["exp_avg_sq"])
        # post-expansion validation
        logging(f"reevaluating")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    # reverse distillation trainer
    def fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                              integration):
        """Train the expanded model to reproduce the pre-expansion logits
        recorded by get_original_batches (MSE on raw logits)."""
        mse_loss = torch.nn.MSELoss()
        # FIX: both calls below previously omitted the required positional
        # `args`, raising TypeError the first time distillation ran.
        if "partial" in integration:
            # only tune the freshly added layers
            distil_optimizer, distil_optimizer_sparse = build_optimizer(
                new_layers, args, reload=False)
        else:
            distil_optimizer, distil_optimizer_sparse = build_optimizer(
                model, args, reload=False)
        if args.cuda and args.fp16:
            model, distil_optimizer = amp.initialize(model,
                                                     distil_optimizer,
                                                     opt_level=args.fp16)
        model.train()
        if args.batch_chunk > 1:
            mems = [None for _ in range(args.batch_chunk)]
        else:
            mems = None
        # FIX: with batch_chunk > 1, first_logits is a list of
        # `batch_chunk` per-chunk lists, so len(first_logits) was the chunk
        # count, not the number of recorded batches.
        n_batches = (len(first_logits[0])
                     if args.batch_chunk > 1 else len(first_logits))
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        for batch, (data, _, _) in enumerate(train_iter):
            if batch == n_batches:
                break
            model.zero_grad()
            if args.batch_chunk > 1:
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    logits, mems[i] = model._forward(data_i, mems=mems[i])
                    target_logits = first_logits[i][batch].to(logits.device)
                    loss = mse_loss(logits, target_logits) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss,
                                            distil_optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
            else:
                logits, mems = model._forward(data, mems=mems)
                target_logits = first_logits[batch].to(logits.device)
                loss = mse_loss(logits, target_logits)
                if args.fp16:
                    with amp.scale_loss(loss,
                                        distil_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(distil_optimizer), args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            distil_optimizer.step()

    ###########################################################################
    # main()
    ###########################################################################
    args.tied = not args.not_tied
    if args.d_embed < 0:
        args.d_embed = args.d_model
    # Validate `--fp16` option
    if args.fp16:
        if not args.cuda:
            print("WARNING: --fp16 requires --cuda, ignoring --fp16 option")
            args.fp16 = False
        else:
            try:
                from apex import amp
                if args.fp16 == "O1":
                    amp.register_half_function(torch, "einsum")
            except Exception:
                print("WARNING: apex not installed, ignoring --fp16 option")
                args.fp16 = False
    device = torch.device("cuda" if args.cuda else "cpu")

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print("WARNING: You have a CUDA device, so you should probably run "
                  "with --cuda ")
        else:
            torch.cuda.manual_seed_all(args.seed)

    ############################################################################
    # Logging
    ############################################################################
    assert args.ext_len >= 0, "extended context length must be non-negative"
    assert args.d_batch % args.batch_chunk == 0
    args.work_dir = "{}-{}".format(args.work_dir, args.dataset)
    args.work_dir = os.path.join(args.work_dir,
                                 time.strftime("%Y%m%d-%H%M%S"))
    logging = create_exp_dir(
        args.work_dir,
        scripts_to_save=["train_ts.py", "mem_transformer.py"],
        debug=args.debug,
    )

    ############################################################################
    # Load data
    ############################################################################
    time_series = get_time_series(args.datadir, args.dataset)
    nseries = len(time_series.vocab)
    args.n_token = nseries
    eval_batch_size = 20
    tr_iter = time_series.get_iterator(
        "train",
        args.d_batch,
        args.tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    va_iter = time_series.get_iterator(
        "valid",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    te_iter = time_series.get_iterator(
        "test",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    cutoffs, tie_projs = [], [False]

    ############################################################################
    # Define model
    ############################################################################
    initialization_func = partial(
        weights_init,
        init=args.init,
        init_range=args.init_range,
        init_std=args.init_std,
        proj_init_std=args.proj_init_std,
    )
    if args.restart and not args.fp16:
        if args.restart_from is not None:
            model_name = f"model_{args.restart_from}.pt"
        else:
            model_name = "model.pt"
        model_file_name = os.path.join(args.restart_dir, model_name)
        logging(f"reloading {model_file_name}")
        with open(model_file_name, "rb") as f:
            model = torch.load(f)
        # backwards compatibility with older saves
        if isinstance(model, nn.DataParallel):
            model = model.module
        model.backward_compatible(tie_weight=args.tied, tie_projs=tie_projs)
        if not args.fp16:
            model = model.float()
        model.apply(update_dropout)
        model.apply(update_dropatt)
    else:
        model = MemTransformerLM(
            nseries,
            args.n_layer,
            args.n_head,
            args.d_model,
            args.d_head,
            args.d_inner,
            args.dropout,
            args.dropatt,
            tie_weight=args.tied,
            d_embed=args.d_embed,
            div_val=args.div_val,
            tie_projs=tie_projs,
            pre_lnorm=args.pre_lnorm,
            tgt_len=args.tgt_len,
            ext_len=args.ext_len,
            mem_len=args.mem_len,
            cutoffs=cutoffs,
            same_length=args.same_length,
            clamp_len=args.clamp_len,
        )
        model.apply(initialization_func)  # debug
        # model.word_emb.apply(initialization_func)
        # ensure embedding init is not overridden by out_layer in case of
        # weight sharing
    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = non_emb_param_count(model, nseries)

    logging("=" * 100)
    for k, v in args.__dict__.items():
        logging("    - {} : {}".format(k, v))
    logging("=" * 100)
    logging("#params = {}".format(args.n_all_param))
    logging("#non emb params = {}".format(args.n_nonemb_param))

    para_model = parallelize_model(model, args)
    optimizers = build_optimizer(para_model, args,
                                 reload=args.restart and not args.fp16)
    optimizer, optimizer_sparse = optimizers
    schedulers = build_scheduler(optimizers, args)
    scheduler, scheduler_sparse = schedulers

    if args.cuda and args.fp16:
        para_model, optimizer = amp.initialize(para_model,
                                               optimizer,
                                               opt_level=args.fp16)
        if args.restart:
            if args.restart_from is not None:
                checkpoint_name = f"amp_checkpoint_{args.restart_from}.pt"
            else:
                checkpoint_name = "amp_checkpoint.pt"
            with open(os.path.join(args.work_dir, checkpoint_name),
                      "rb") as f:
                checkpoint = torch.load(f)
                model.load_state_dict(checkpoint["model"])
                optimizer.load_state_dict(checkpoint["optimizer"])
                amp.load_state_dict(checkpoint["amp"])

    ############################################################################
    # Training loop
    ############################################################################
    # Loop over epochs.
    if args.reset_lr:
        # then they're different and we use train_step only for the new lr
        # scheduling
        train_step = 0
        optimizer.defaults["lr"] = args.lr
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.lr
            param_group["initial_lr"] = args.lr
        scheduler.base_lrs = [args.lr] * len(scheduler.base_lrs)
    else:
        train_step = model.training_steps
    best_val_loss = None
    # Reload previous step number in case of default_args.restart
    if train_step > 0:
        logging(f"restarting from step {train_step}")
    log_start_time = time.time()
    eval_start_time = time.time()

    def run_training():
        nonlocal train_step
        for epoch in itertools.count(start=first_epoch):
            # we check before the training loop; expanding at epoch 0 means
            # before training (for debug purposes)
            if args.expand and str(epoch - 1) in args.expansion_dict:
                n_add = int(args.expansion_dict[str(epoch - 1)])
                expand_model(
                    args.expand,
                    args.integration,
                    args.integration_length,
                    n_add,
                    model,
                    optimizers,
                    schedulers,
                    tr_iter,
                    va_iter,
                    epoch,
                    train_step,
                )
            if args.widen and str(epoch - 1) in args.widen_dict:
                ratio = int(args.widen_dict[str(epoch - 1)])
                widen_model(
                    args.widen,
                    ratio,
                    model,
                    optimizers,
                    va_iter,
                    epoch,
                    train_step,
                )
            epoch_loop(epoch, para_model, optimizers, schedulers)
            if train_step >= args.max_step:
                logging("-" * 100)
                logging("End of training")
                break
            if not args.debug and args.log_first_epochs:
                if epoch <= args.log_first_epochs:
                    logging(f"saving model at the end of epoch {epoch}")
                    if args.fp16:
                        with open(
                                os.path.join(args.work_dir,
                                             f"amp_checkpoint_{epoch}.pt"),
                                "wb",
                        ) as f:
                            checkpoint = {
                                "model": model.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "amp": amp.state_dict(),
                            }
                            torch.save(checkpoint, f)
                    else:
                        with open(
                                os.path.join(args.work_dir,
                                             f"model_{epoch}.pt"),
                                "wb",
                        ) as f:
                            torch.save(model, f)
                        with open(
                                os.path.join(args.work_dir,
                                             f"optimizer_{epoch}.pt"),
                                "wb",
                        ) as f:
                            torch.save(optimizer.state_dict(), f)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        if args.restart_from:
            first_epoch = args.restart_from + 1
            print(f"restarting from epoch {first_epoch}")
        else:
            first_epoch = 1
        run_training()
    except KeyboardInterrupt:
        logging("-" * 100)
        logging("Exiting from training early")

    # Load the best model.
    if args.fp16:
        with open(os.path.join(args.work_dir, "amp_checkpoint.pt"),
                  "rb") as f:
            checkpoint = torch.load(f)
            model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            amp.load_state_dict(checkpoint["amp"])
    else:
        with open(os.path.join(args.work_dir, "model.pt"), "rb") as f:
            model = torch.load(f)
    para_model = model.to(device)

    # Run on test data.
    test_loss = evaluate(te_iter, para_model)
    logging("=" * 100)
    logging("| End of training | test loss {:5.2f} | test bpc {:9.5f}".format(
        test_loss, test_loss / math.log(2)))
    logging("=" * 100)
def do_train(args):
    """Train a MemTransformerLM language model with PaddlePaddle.

    Sets up (optionally distributed) training, builds the model, LR
    scheduler and optimizer from `args`, optionally restores from a
    checkpoint or pretrained weights, then runs the epoch/step loop with
    periodic logging, validation and checkpoint saving.
    """
    # Distributed setup: rank/world size from paddle.distributed when on
    # GPU, otherwise single CPU process.
    if args.use_gpu:
        rank = dist.get_rank()
        trainer_count = dist.get_world_size()
    else:
        rank = 0
        trainer_count = 1
        paddle.set_device("cpu")
    if trainer_count > 1:
        dist.init_parallel_env()

    # NOTE(review): eval() on a CLI string — works for "None"/ints but is
    # unsafe on untrusted input; int-or-None parsing would be safer.
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    vocab = get_lm_vocab(args)
    train_loader = get_lm_data_loader(args, vocab, "train")
    eval_loader = get_lm_data_loader(args, vocab, "valid")

    # Adaptive-softmax cutoffs are dataset-specific (wt103 / lm1b only).
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(
        args.ntokens,
        args.n_layer,
        args.n_head,
        args.d_model,
        args.d_head,
        args.d_inner_hid,
        args.dropout,
        args.attn_dropout,
        tie_weight=args.tie_weight,
        d_embed=args.d_model,
        div_val=args.div_val,
        tie_projs=tie_projs,
        normalize_before=args.normalize_before,
        tgt_len=args.tgt_len,
        ext_len=args.ext_len,
        mem_len=args.mem_len,
        cutoffs=cutoffs,
        same_length=args.same_length,
        attn_type=args.attn_type,
        clamp_len=args.clamp_len,
        sample_softmax=args.sample_softmax)

    # LR schedule selection; for "constant" the raw float learning rate is
    # passed straight to the optimizer.
    if args.scheduler == 'cosine':
        scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
            learning_rate=args.learning_rate,
            T_max=args.max_step,
            eta_min=args.eta_min)
    elif args.scheduler == 'noam':
        scheduler = paddle.optimizer.lr.NoamDecay(
            d_model=args.d_model,
            warmup_steps=args.warmup_steps,
            learning_rate=args.learning_rate)
    elif args.scheduler == 'dev_perf':
        # fluid api
        scheduler = paddle.fluid.dygraph.ReduceLROnPlateau(
            learning_rate=args.learning_rate,
            decay_rate=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min)
    elif args.scheduler == 'constant':
        scheduler = args.learning_rate

    # Global-norm gradient clipping is attached to the optimizer itself.
    clip = paddle.nn.ClipGradByGlobalNorm(args.clip)
    if args.optim.lower() == 'momentum':
        optimizer = paddle.optimizer.Momentum(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            momentum=args.mom,
            grad_clip=clip)
    elif args.optim.lower() == 'adam':
        optimizer = paddle.optimizer.Adam(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            beta1=args.beta1,
            beta2=args.beta2,
            # NOTE(review): eval() on a CLI string (e.g. "1e-8"); float()
            # would be the safer parse.
            epsilon=eval(args.eps),
            grad_clip=clip)
    elif args.optim.lower() == 'adagrad':
        optimizer = paddle.optimizer.Adagrad(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            grad_clip=clip)

    # Init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        model_dict = paddle.load(
            os.path.join(args.init_from_checkpoint,
                         "mem_transformer.pdparams"))
        opt_dict = paddle.load(
            os.path.join(args.init_from_checkpoint, "mem_transformer.pdopt"))
        mem_transformer.set_state_dict(model_dict)
        optimizer.set_state_dict(opt_dict)
        print("loaded from checkpoint.")
    # Init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        model_dict = paddle.load(
            os.path.join(args.init_from_pretrain_model,
                         "mem_transformer.pdparams"))
        mem_transformer.set_state_dict(model_dict)
        print("loaded from pre-trained model.")

    if trainer_count > 1:
        mem_transformer = paddle.DataParallel(mem_transformer)

    step_idx = 0
    train_loss = 0.0

    log_start_time = time.time()

    for pass_id in range(args.epoch):
        batch_id = 0
        # Transformer-XL memory carried across batches within an epoch.
        mems = tuple()
        for input_data in train_loader:
            (src, target, seq_len) = input_data
            ret = mem_transformer(src, target, *mems)
            loss = ret[0]
            mems = ret[1:]
            train_loss += loss.numpy()

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            # Periodic training log (rank 0 only).
            if step_idx > 0 and step_idx % args.print_step == 0 and rank == 0:
                cur_loss = train_loss / args.print_step
                elapsed = time.time() - log_start_time
                if args.scheduler == "constant":
                    lr = optimizer.get_lr()
                else:
                    lr = scheduler.get_lr()
                logger_info = "step_idx: %d, epoch: %d, batch: %d, learning rate: %.8f, " \
                              "speed: %f ms/batch, loss: %f" % \
                              (step_idx, pass_id, batch_id, lr,
                               elapsed * 1000.0 / args.print_step, cur_loss)
                # char-level datasets report bits-per-char, others perplexity
                if args.dataset in ["enwik8", "text8"]:
                    logger_info = logger_info + ", bpc: %f" % (cur_loss /
                                                               np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (
                        np.exp(cur_loss))
                logger.info(logger_info)
                train_loss = 0.0
                log_start_time = time.time()

            if step_idx % args.save_step == 0 and step_idx != 0:
                # Do validation.
                mem_transformer.eval()

                # Swap tgt/ext/mem lengths to the eval configuration; the
                # DataParallel wrapper exposes the model as `_layers`.
                # TODO(FrostML): simplify this.
                if args.mem_len == 0:
                    if dist.get_world_size() == 1:
                        mem_transformer.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len + args.tgt_len -
                            args.eval_tgt_len,
                            mem_len=args.mem_len)
                    else:
                        mem_transformer._layers.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len + args.tgt_len -
                            args.eval_tgt_len,
                            mem_len=args.mem_len)
                else:
                    if dist.get_world_size() == 1:
                        mem_transformer.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len,
                            mem_len=args.mem_len + args.tgt_len -
                            args.eval_tgt_len)
                    else:
                        mem_transformer._layers.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len,
                            mem_len=args.mem_len + args.tgt_len -
                            args.eval_tgt_len)

                total_len, total_loss = 0, 0.
                eval_mems = tuple()
                with paddle.no_grad():
                    for i, (src, target, seq_len) in enumerate(eval_loader):
                        if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                            break
                        ret = mem_transformer(src, target, *eval_mems)
                        loss, eval_mems = ret[0], ret[1:]
                        seq_len = seq_len.numpy()
                        eval_cur_loss = seq_len * loss.numpy()
                        total_loss += eval_cur_loss
                        total_len += seq_len
                    # length-weighted mean loss over the validation set
                    eval_loss = total_loss / total_len

                logger_info = "Validation, step_idx: %d, validation loss: %f" % \
                              (step_idx, eval_loss)
                if args.dataset in ['enwik8', 'text8']:
                    logger_info = logger_info + ", bpc: %f" % (eval_loss /
                                                               np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (
                        np.exp(eval_loss))
                logger.info(logger_info)

                if args.save_model and rank == 0:
                    model_dir = os.path.join(
                        args.save_model,
                        "step_" + str(step_idx) + "_" + str(eval_loss))
                    if not os.path.exists(model_dir):
                        os.makedirs(model_dir)
                    paddle.save(
                        mem_transformer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdparams"))
                    paddle.save(
                        optimizer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdopt"))

                if args.scheduler == 'dev_perf':
                    scheduler.step(eval_loss)

                # Restore the training tgt/ext/mem lengths.
                # TODO(FrostML): simplify this.
                if dist.get_world_size() == 1:
                    mem_transformer.reset_length(tgt_len=args.tgt_len,
                                                 ext_len=args.ext_len,
                                                 mem_len=args.mem_len)
                else:
                    mem_transformer._layers.reset_length(
                        tgt_len=args.tgt_len,
                        ext_len=args.ext_len,
                        mem_len=args.mem_len)

                mem_transformer.train()

            step_idx += 1
            batch_id += 1
            # Step-wise LR annealing with a linear warmup phase; during
            # warmup the base LR is overwritten directly.
            if args.scheduler in ['cosine', 'dev_perf']:
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    scheduler.base_lr = curr_lr
                else:
                    if args.scheduler == 'cosine':
                        scheduler.step()
            elif args.scheduler == 'constant':
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    optimizer.set_lr(curr_lr)
            elif args.scheduler == 'noam':
                scheduler.step()
            if step_idx >= args.max_step:
                break

    # Final checkpoint after the epoch loop (rank 0 only).
    if args.save_model and rank == 0:
        model_dir = os.path.join(args.save_model, "step_final")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        paddle.save(mem_transformer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdparams"))
        paddle.save(optimizer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdopt"))
def main_loop(): util.cancel_shutdown() losses = [] args = g.args if not args.local: g.logger.info( f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}') dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=util.get_world_size()) assert (util.get_world_size() == dist.get_world_size()) g.logger.info(f"Distributed: success ({args.local_rank}/{dist.get_world_size()})") if args.load_state_fn: g.state = load_state(args.load_state_fn) g.logger.info(f"Restoring training from {args.load_state_fn}") else: g.logger.info("creating new model") g.state = TrainState(args) g.state.model = MemTransformerLM(g.ntokens, args.n_layer, args.n_head, args.d_model, args.d_head, args.d_inner, args.dropout, args.dropatt, tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val, tie_projs=g.tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len, ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=g.cutoffs, same_length=args.same_length, attn_type=args.attn_type, clamp_len=args.clamp_len, sample_softmax=args.sample_softmax) if args.checkpoint: util.restore_from_checkpoint(g.state.model, checkpoint_fn=args.checkpoint) else: g.state.model.apply(weights_init) g.state.model.word_emb.apply( weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing g.state.model.to(g.device) optimizer_setup(g.state) model: MemTransformerLM = g.state.model optimizer = g.state.optimizer # log model info # n_all_param = sum([p.nelement() for p in model.parameters()]) # log_tb('sizes/params', n_all_param) # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()]) # log_tb('sizes/non_emb_params', n_nonemb_param) # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param) # scheduler if not g.args.load_state_fn: if args.scheduler == 'cosine': # Divide by 1e6 for numerical stability. 
g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_tokens // 1e6, eta_min=args.eta_min) elif args.scheduler == 'finder': g.state.scheduler: LRFinder = LRFinder(optimizer, args.max_tokens, init_value=args.lr / 1e3) else: assert args.scheduler == 'constant' g.state.scheduler = util.NoOp() # Setup distributed model if args.local: model = nn.DataParallel(model, dim=1) else: # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding. model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) # , find_unused_parameters=True) if util.get_global_rank() == 0: if not args.test: wandb.config.update(vars(args)) # wandb.watch(model) g.event_writer.add_text('args', str(args)) # TODO: replace with log_tb accumulated_loss = 0 # At any point you can hit Ctrl + C to break out of training early. try: for epoch in itertools.count(start=g.state.last_epoch): print(f"epoch -- {epoch}, token_count -- {g.state.
def main_loop():
    """Top-level training driver.

    Creates the model, optimizer and scheduler in the global ``g.state``,
    optionally restores them from checkpoints, wraps the model for
    (distributed) data parallelism, then runs the epoch/batch training loop
    with periodic evaluation, logging and checkpointing.

    Returns:
        list[float]: per-batch training losses accumulated over the run.
    """
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(
            backend=args.dist_backend,
            #init_method=args.dist_url,
            #world_size=util.get_world_size()
        )
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    g.logger.info("creating new model")
    g.state = TrainState(args)
    g.state.model = MemTransformerLM(g.ntokens, args.n_layer, args.n_head,
                                     args.d_model, args.d_head, args.d_inner,
                                     args.dropout, args.dropatt,
                                     tie_weight=args.tied,
                                     d_embed=args.d_embed,
                                     div_val=args.div_val,
                                     tie_projs=g.tie_projs,
                                     pre_lnorm=args.pre_lnorm,
                                     tgt_len=args.tgt_len,
                                     ext_len=args.ext_len,
                                     mem_len=args.mem_len,
                                     cutoffs=g.cutoffs,
                                     same_length=args.same_length,
                                     attn_type=args.attn_type,
                                     clamp_len=args.clamp_len,
                                     sample_softmax=args.sample_softmax,
                                     freeze_below=args.freeze_below)
    g.state.model.to(g.device)
    optimizer_setup(g.state)
    if args.checkpoint:
        if args.checkpoint_secondary:
            g.logger.info(f"restoring extra checkpoint")
            util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                         args.checkpoint_secondary,
                                         args.optim_state_dict)
        # FIX: the conditional expression previously wrapped the WHOLE
        # concatenation (Python's ternary binds looser than `+`), so when
        # optim_state_dict was unset an empty string was logged. Parenthesize
        # so only the optimizer part is conditional.
        g.logger.info(f"Restoring model from {args.checkpoint}" +
                      (f" and optimizer from {args.optim_state_dict}"
                       if args.optim_state_dict else ""))
        util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                     args.checkpoint, args.optim_state_dict)
    else:
        g.state.model.apply(weights_init)
        # ensure embedding init is not overridden by out_layer in case of
        # weight sharing
        g.state.model.word_emb.apply(weights_init)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    if g.state.args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=g.state.args.static_loss_scale,
            dynamic_loss_scale=g.state.args.dynamic_loss_scale,
            dynamic_loss_args={'init_scale': 2**16},
            verbose=False)

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        g.state.scheduler: LRFinder = LRFinder(optimizer,
                                               args.max_tokens,
                                               init_value=args.lr / 1e3)
    else:
        assert args.scheduler == 'constant'
        g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for
        # adaptive embedding.
        model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)
        g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
            model.train()

            log_tb('sizes/batch_size', args.batch_size)
            log_tb('sizes/seq_size', args.tgt_len)

            if g.state.partial_epoch:
                # reuse previously loaded tr_iter and states
                assert g.state.tr_iter is not None
                assert g.state.mems is not None
            else:
                g.state.tr_iter = g.corpus.get_dist_iterator(
                    'train',
                    rank=util.get_global_rank(),
                    max_rank=util.get_world_size(),
                    bsz=args.batch_size,
                    bptt=args.tgt_len,
                    device=g.device,
                    ext_len=args.ext_len,
                    skip_files=g.args.skip_files)
                g.state.mems = tuple()
            g.state.last_epoch = epoch

            log_start_time = time.time()
            tokens_per_epoch = 0
            for batch, (data, target, seq_len) in enumerate(g.state.tr_iter):
                # assert seq_len == data.shape[0]
                # for i in range(1, data.shape[0]):
                #     assert torch.all(torch.eq(data[i], target[i - 1]))
                #     break
                # print(g.state.token_count, data)

                if g.state.train_step % args.eval_interval == 0:
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-1',
                                     generate_text=False, reset_mems_interval=1)
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-2',
                                     generate_text=False, reset_mems_interval=2)
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-3',
                                     generate_text=False, reset_mems_interval=3)
                    evaluate_and_log(model, g.va_iter, 'val')
                    if g.va_custom_iter:
                        evaluate_and_log(g.state.model, g.va_custom_iter,
                                         g.args.valid_custom,
                                         generate_text=False)

                batch_total = torch.tensor(data.shape[1]).to(g.device)
                if args.local:
                    # TODO(y): factor out (need way to see if dist was inited)
                    batch_total = batch_total.sum()
                else:
                    batch_total = util.dist_sum_tensor(
                        batch_total)  # global batch size
                batch_total = util.toscalar(batch_total)

                should_log = (g.state.train_step < args.verbose_log_steps) or \
                             (g.state.train_step + 1) % args.log_interval == 0

                model.zero_grad()

                ret = model(data, target, *g.state.mems)
                loss, g.state.mems = ret[0], ret[1:]

                loss: torch.Tensor = loss.float().mean().type_as(loss)
                with timeit('backwards', noop=not should_log):
                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()
                loss0 = util.toscalar(loss)
                util.record('loss', loss0)

                util.record('params', torch.sum(util.flat_param(model)).item())
                losses.append(loss0)
                accumulated_loss += loss0

                if args.fp16:
                    optimizer.clip_master_grads(args.clip)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.clip)

                # step-wise learning rate annealing
                if hasattr(optimizer, 'overflow') and optimizer.overflow:
                    g.logger.info("skipped iteration")
                else:
                    if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                        # linear warmup stage
                        if g.state.token_count < args.warmup_tokens:
                            curr_lr = args.lr * float(
                                g.state.token_count) / args.warmup_tokens
                            optimizer.param_groups[0]['lr'] = curr_lr
                        elif args.scheduler == 'cosine':
                            # Divide by 1e6 for numerical stability.
                            g.state.scheduler.step(g.state.token_count //
                                                   1000 // 1000)
                    else:
                        g.state.scheduler.step(g.state.token_count)

                    optimizer.step()
                    g.state.train_step += 1

                consumed_tokens = data.shape[0] * data.shape[1]
                world_size = int(os.environ.get("WORLD_SIZE", "8"))
                if world_size > 8:
                    # correction factor for multiple machines
                    consumed_tokens = consumed_tokens * (world_size // 8)

                tokens_per_epoch += consumed_tokens
                g.state.token_count += consumed_tokens
                g.token_count = g.state.token_count
                if g.state.token_count >= args.max_tokens:
                    g.state.partial_epoch = True
                    raise StopIteration  # break out of parent train loop

                if should_log:
                    elapsed_time = time.time() - log_start_time
                    elapsed_steps = g.state.train_step - g.state.last_log_step

                    # compute average loss over last logging interval
                    cur_loss = accumulated_loss / elapsed_steps
                    cur_loss_mean = util.dist_mean(cur_loss)
                    log_str = f'| epoch {epoch:3d} step {g.state.train_step:>8d} ' \
                              f'| {batch:>6d} batches ' \
                              f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                              f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                              f'| loss {cur_loss:5.2f}'
                    if args.dataset in ['enwik8', 'text8']:
                        log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
                    else:
                        log_str += f' | ppl {math.exp(cur_loss):9.3f}'
                    g.logger.info(log_str)
                    log_tb('learning/epoch', epoch)
                    log_tb('_loss', cur_loss_mean)  # the most important thing
                    log_tb('learning/loss', cur_loss_mean)
                    log_tb('learning/ppl', math.exp(cur_loss_mean))

                    # currently step timings are not synchronized in
                    # multi-machine case (see #4). Can add
                    # torch.distributed.barrier() to get more accurate
                    # timings, but this may add slowness.
                    log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
                    current_lr = optimizer.param_groups[0]['lr']
                    log_tb('learning/lr', current_lr)

                    # 32 is the "canonical" batch size
                    linear_scaling_factor = batch_total / 32  # TODO(y): merge logic from master
                    log_tb('learning/base_lr',
                           current_lr / linear_scaling_factor)
                    if args.optim == 'lamb':
                        log_lamb_rs(optimizer, g.event_writer,
                                    g.state.token_count)

                    time_per_batch = elapsed_time / elapsed_steps
                    time_per_sample = time_per_batch / args.batch_size
                    time_per_token = time_per_sample / args.tgt_len
                    log_tb('times/batches_per_sec', 1 / time_per_batch)
                    log_tb('times/samples_per_sec', 1 / time_per_sample)
                    log_tb('times/tokens_per_sec', 1 / time_per_token)

                    if str(g.device) == 'cuda':
                        log_tb("memory/allocated_gb",
                               torch.cuda.memory_allocated() / 1e9)
                        log_tb("memory/max_allocated_gb",
                               torch.cuda.max_memory_allocated() / 1e9)
                        log_tb("memory/cached_gb",
                               torch.cuda.memory_cached() / 1e9)
                        log_tb("memory/max_cached_gb",
                               torch.cuda.max_memory_cached() / 1e9)

                    accumulated_loss = 0
                    log_start_time = time.time()
                    g.state.last_log_step = g.state.train_step

            if args.checkpoint_each_epoch:
                g.logger.info(f'Saving checkpoint for epoch {epoch}')
                util.dist_save_checkpoint(model, optimizer, args.logdir,
                                          suffix=f'{epoch}')
            if tokens_per_epoch == 0:
                # FIX: was `logging.info(...)` (stdlib module) while every
                # other message in this function goes through g.logger.
                g.logger.info("Zero tokens in last epoch, breaking")
                break

            g.state.partial_epoch = False
    except KeyboardInterrupt:
        g.logger.info('-' * 100)
        g.logger.info('Exiting from training early')
    except StopIteration:
        pass

    return losses
def main():
    """Entry point for the global-state training script.

    Builds the model/optimizer/scheduler, optionally restores a checkpoint,
    wraps for (distributed) data parallelism, trains until interrupted or
    exhausted, then evaluates on validation and test splits.
    """
    global global_token_count, event_writer, train_step, train_loss, last_log_step, \
        best_val_loss, epoch, model

    if args.local_rank > 0:
        pass  # skip shutdown when rank is explicitly set + not zero rank
    else:
        os.system('shutdown -c')

    if not args.local:
        logger.info(
            f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}'
        )
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
                             args.d_head, args.d_inner, args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    # log model info
    n_all_param = sum([p.nelement() for p in model.parameters()])
    log_tb('sizes/params', n_all_param)
    n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    log_tb('sizes/non_emb_params', n_nonemb_param)
    logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # optimizer
    if args.optim.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.mom)
    elif args.optim.lower() == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd)
    else:
        assert args.optim.lower() == 'adam'
        optimizer = optim.Adam(model.parameters(), lr=args.lr,
                               weight_decay=args.wd)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         args.max_tokens // 1e6,
                                                         eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        scheduler = LRFinder(optimizer,
                             args.max_tokens,
                             init_value=args.lr / 1e3)
    elif args.scheduler == 'constant':
        # FIX: previously `pass`, which left `scheduler` unbound and raised
        # NameError at the train() call below. Use the same no-op scheduler
        # the other training driver in this file uses for 'constant'.
        scheduler = util.NoOp()

    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing

    if args.checkpoint:
        if global_rank == 0:
            util.restore_from_checkpoint(model=model,
                                         checkpoint_fn=args.checkpoint)

    model = model.to(device)
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2**16},
                                   verbose=False)

    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for
        # adaptive embedding.
        model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank)  #, find_unused_parameters=True)

    if global_rank == 0:
        event_writer = SummaryWriter(args.logdir)
        event_writer.add_text('args', str(args))

    # test checkpoint writing
    # NOTE(review): `epoch` is a module-level global here and may be undefined
    # at this point on a fresh run — confirm it is initialized elsewhere.
    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{0}')

    # Loop over epochs.
    train_step = 0
    train_loss = 0
    last_log_step = 0
    best_val_loss = None

    va_iter, te_iter = [
        corpus.get_dist_iterator(split,
                                 global_rank,
                                 max_rank,
                                 args.batch_size * 2,
                                 args.tgt_len,
                                 device=device,
                                 ext_len=args.ext_len)
        for split in ('valid', 'test')
    ]

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=1):
            train(va_iter, optimizer, scheduler)
    except KeyboardInterrupt:
        logger.info('-' * 100)
        logger.info('Exiting from training early')
    except StopIteration:
        pass

    # Eval one more time.
    evaluate_and_log(optimizer, va_iter, 'val', train_step=-1)

    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.logdir, 'model-best.pt')
    if os.path.exists(model_file):
        with open(model_file, 'rb') as model_f:
            with timeit('load'):
                if args.local:
                    model = torch.load(model_f)
                else:
                    model = torch.load(
                        model_f,
                        map_location=lambda storage, loc: storage.cuda(
                            args.local_rank))
                    model = DistributedDataParallel(
                        model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)
    else:
        # FIX: Logger.warn is deprecated in the stdlib in favor of warning().
        logger.warning('no model file, using current model for loss')

    # Run on test data.
    evaluate_and_log(optimizer, te_iter, 'test', -1)
def main():
    """Entry point for the inference/evaluation script (dllogger variant).

    Loads a checkpoint (or a manual config), builds the evaluation iterator,
    optionally converts the model to TorchScript, runs evaluation, and logs /
    saves latency and throughput statistics.
    """
    args = parse_args()

    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(args.local_rank,
                                                   nproc_per_node,
                                                   args.affinity)
        print(f'{args.local_rank}: thread affinity: {affinity}')

    # Pick the model implementation matching the requested backend.
    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'eval_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = f'eval_log.log'

    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)
    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
        filemode='a',
    )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.profile:
        try:
            pyprof.init(enable_function_stack=True)
        except NameError:
            # pyprof import is optional at module level.
            warnings.warn('Called pyprof.init() but pyprof is not available')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    if not args.no_env:
        log_env_info()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError(
            'Specify path to checkpoint using --model or --work_dir')

    if not args.manual_config:
        checkpoint = load_checkpoint(model_path)
        vocab_type = checkpoint['args'].vocab
    else:
        checkpoint = None
        vocab_type = args.manual_vocab

    if args.manual:
        # NOTE(review): with --manual_config, `checkpoint` is None here and
        # this subscript raises TypeError — confirm --manual and
        # --manual_config are mutually exclusive at arg-parse time.
        vocab = checkpoint['vocab']

        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        # NOTE(review): `iter` shadows the builtin for the rest of main().
        iter = data_utils.LMOrderedIterator(tensor,
                                            bsz=args.batch_size,
                                            bptt=args.tgt_len,
                                            device=device,
                                            ext_len=args.ext_len,
                                            warmup=False)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset, vocab_type)

        if args.split == 'valid' or args.split == 'test':
            iter = corpus.get_iterator(args.split,
                                       args.batch_size,
                                       args.tgt_len,
                                       device=device,
                                       mem_len=args.mem_len,
                                       ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    elif not args.manual_config:
        # Override the sequence-length config from the CLI before rebuilding.
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        if args.type == 'pytorch':
            model.load_state_dict(checkpoint['model_state'])
        elif args.type == 'torchscript':
            # strict=False: the TorchScript variant re-binds some tensors
            # manually below.
            model.load_state_dict(checkpoint['model_state'], strict=False)
    elif args.manual_config:
        args.manual_config['tgt_len'] = args.tgt_len
        args.manual_config['ext_len'] = args.ext_len
        args.manual_config['mem_len'] = args.mem_len
        args.manual_config['clamp_len'] = args.clamp_len
        args.manual_config['same_length'] = args.same_length
        args.manual_config['dtype'] = dtype

        model = MemTransformerLM(**args.manual_config)

    model = model.eval()
    model = model.to(device)
    model = model.to(dtype)

    if args.type == 'torchscript' and not args.manual_config:
        # Manually copy the (possibly tied) projection/bias tensors from the
        # checkpoint into the scriptable module before scripting.
        state = checkpoint['model_state']

        tie_projs = checkpoint['model_config']['tie_projs']
        tie_weight = checkpoint['model_config']['tie_weight']
        div_val = checkpoint['model_config']['div_val']
        d_model = checkpoint['model_config']['d_model']
        d_embed = checkpoint['model_config']['d_embed']

        if div_val != 1 or d_model != d_embed:
            for i in range(len(model.word_emb.emb_projs)):
                model.word_emb.emb_projs[i] = state[
                    f'word_emb.emb_projs.{i}'].to(dtype)

        for i in range(len(model.crit.out_projs)):
            if div_val == 1:
                src = 0
            else:
                src = i
            if model.crit.out_projs[i] is not None:
                if tie_projs[i]:
                    model.crit.out_projs[i] = state[
                        f'word_emb.emb_projs.{src}'].to(dtype)
                else:
                    model.crit.out_projs[i] = state[f'crit.out_projs.{i}'].to(
                        dtype)

        for i in range(len(model.crit.out_layers_biases)):
            model.crit.out_layers_biases[i] = state[
                f'crit.out_layers_biases.{i}'].to(dtype)

        if tie_weight:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[
                    f'word_emb.emb_layers.{i}.weight'].to(dtype)
        else:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[
                    f'crit.out_layers_weights.{i}'].to(dtype)

        model = torch.jit.script(model)

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
    # Skip the first few iterations while memory fills up before timing.
    warmup = args.mem_len // args.tgt_len + 2
    meters['eval_throughput'] = AverageMeter(warmup=warmup,
                                             keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
        loss = evaluate(iter, model, meters, args.log_interval, args.max_size,
                        args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    summary = {
        'eval_loss': loss,
        'eval_ppl': perplexity,
    }

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
        }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)
        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(
                f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms'
            )
        logging.info('=' * 100)

        summary.update({
            'eval_throughput': throughput_data.mean(),
            'eval_avg_latency': 1000 * latency_data.mean(),
        })
        for p in args.percentiles:
            summary[f'eval_{p}%_latency'] = 1000 * np.percentile(
                latency_data, p)

    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['eval_throughput'].avg,
    )
    if not passed:
        sys.exit(1)
# Script fragment: rebuild the model from a saved checkpoint for evaluation.
# Dropout/dropatt are zeroed and the adaptive-softmax options (div_val,
# cutoffs, tie_projs, sample_softmax) are forced to their plain-softmax
# defaults rather than taken from the checkpoint args.
with open(os.path.join(args.work_dir, 'checkpoint.pt'), 'rb') as f:
    # model_state_dict = torch.load(f)
    checkpoint = torch.load(f)
    model_args = checkpoint['args']
model = MemTransformerLM(ntokens, model_args.n_layer, model_args.n_head,
                         model_args.d_model, model_args.d_head,
                         model_args.d_inner, 0.0, 0.0,
                         tie_weight=model_args.tied,
                         d_embed=model_args.d_embed,
                         div_val=1.0,
                         tie_projs=[False],
                         pre_lnorm=model_args.pre_lnorm,
                         tgt_len=model_args.tgt_len,
                         ext_len=model_args.ext_len,
                         mem_len=model_args.mem_len,
                         cutoffs=[],
                         same_length=model_args.same_length,
                         attn_type=model_args.attn_type,
                         clamp_len=model_args.clamp_len,
                         sample_softmax=False)
model.load_state_dict(checkpoint['model'])
model.backward_compatible()
model = model.to(device)
# NOTE(review): SOURCE IS TRUNCATED HERE — the `.format(...)` continuation of
# this logging call is missing from the excerpt.
logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.
# Script fragment: start the (energy/impact) monitor, derive head geometry
# from d_model, and build a model plus the weight-initialization function
# used elsewhere in this file.
tracker.launch_impact_monitor()

# Derive attention head count/size from the model width; the FFN inner
# dimension is set equal to d_model here.
n_head, d_head = head_repartition_rule(d_model)
d_inner = d_model
model = MemTransformerLM(ntokens, n_layer, n_head, d_model, d_head, d_inner,
                         default_args.dropout, default_args.dropatt,
                         tie_weight=default_args.tied,
                         d_embed=d_model,
                         div_val=default_args.div_val,
                         tie_projs=tie_projs,
                         pre_lnorm=default_args.pre_lnorm,
                         tgt_len=default_args.tgt_len,
                         ext_len=default_args.ext_len,
                         mem_len=default_args.mem_len,
                         cutoffs=cutoffs,
                         same_length=default_args.same_length,
                         attn_type=default_args.attn_type,
                         clamp_len=default_args.clamp_len,
                         sample_softmax=default_args.sample_softmax)

# Pre-bound initializer reused by the expansion/widening helpers in this file.
initialization_func = partial(weights_init,
                              init="normal",
                              init_range=0.1,
                              init_std=0.02,
                              proj_init_std=0.01)
def main():
    """Entry point for the base inference/evaluation script.

    Loads a checkpoint, builds the evaluation iterator, runs evaluation, and
    logs / saves latency and throughput statistics; exits non-zero when the
    benchmark targets are not met.
    """
    args = parse_args()

    # Pick the model implementation matching the requested backend.
    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_base_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = f'log.log'

    log_file = os.path.join(args.work_dir, log_file)
    if args.debug:
        log_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
        filemode='a',
    )
    logging.info(args)

    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError(
            'Specify path to checkpoint using --model or --work_dir')

    checkpoint = load_checkpoint(model_path)

    if args.manual:
        # Manual text mode: score a user-provided string one token at a time.
        args.batch_size = 1

        vocab = checkpoint['vocab']

        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        # NOTE(review): `iter` shadows the builtin for the rest of main().
        iter = data_utils.LMOrderedIterator(tensor,
                                            bsz=args.batch_size,
                                            bptt=args.tgt_len,
                                            device=device,
                                            ext_len=args.ext_len)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset,
                               checkpoint['args'].vocab)

        if args.split == 'valid':
            iter = corpus.get_iterator('valid',
                                       args.batch_size,
                                       args.tgt_len,
                                       device=device,
                                       ext_len=args.ext_len)
        elif args.split == 'test':
            iter = corpus.get_iterator('test',
                                       args.batch_size,
                                       args.tgt_len,
                                       device=device,
                                       ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    else:
        # Override the sequence-length config from the CLI before rebuilding.
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        model.load_state_dict(checkpoint['model_state'])

    model = model.eval()
    model = model.to(device)

    # Cast to fp32 first, then down to fp16 if requested.
    model = model.float()
    if args.fp16:
        model = model.half()

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
    # Skip the first few iterations while memory fills up before timing.
    warmup = args.mem_len // args.tgt_len + 1
    meters['eval_throughput'] = AverageMeter(warmup=warmup,
                                             keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    loss = evaluate(iter, model, meters, args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
        }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)
        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(
                f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms'
            )
        logging.info('=' * 100)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['eval_throughput'].avg,
    )
    if not passed:
        sys.exit(1)
# NOTE(review): FRAGMENT — the `else:` below pairs with an `if` (presumably a
# restart/reload condition, as in the fuller variant of this snippet elsewhere
# in the file) that is cut off before this excerpt; this span is not
# syntactically complete on its own. Indentation is reconstructed.
if not args.fp16:
    model = model.float()
model.apply(update_dropout)
model.apply(update_dropatt)
else:
    # Fresh model path: build and initialize from scratch.
    model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
                             args.d_head, args.d_inner, args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)
    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
# NOTE(review): FRAGMENT — the `else:` below pairs with an `if` (presumably a
# restart/reload condition) whose header is cut off before this excerpt; this
# span is not syntactically complete on its own. Indentation is reconstructed.
logging(f"reloading {model_file_name}")
with open(model_file_name, 'rb') as f:
    model = torch.load(f)
# backwards compatibility with older saves
if isinstance(model, nn.DataParallel):
    model = model.module
model.backward_compatible(tie_weight=args.tied, tie_projs=tie_projs)
if not args.fp16:
    model = model.float()
model.apply(update_dropout)
model.apply(update_dropatt)
else:
    # Fresh model path: build and initialize from scratch.
    model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
                             args.d_head, args.d_inner, args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)
    model.apply(initialization_func)
    # debug
    # model.word_emb.apply(initialization_func)  # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = non_emb_param_count(model, ntokens)

# Dump the full argument set for the run log.
logging('=' * 100)
for k, v in args.__dict__.items():
    logging(' - {} : {}'.format(k, v))
logging('=' * 100)
logging('#params = {}'.format(args.n_all_param))