def main():
    args = parse_args()

    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(
            args.local_rank,
            nproc_per_node,
            args.affinity
        )
        print(f'{args.local_rank}: thread affinity: {affinity}')

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(
        args.work_dir,
        args.dataset,
        args.append_dataset,
        args.append_time,
    )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = args.txtlog_file
    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')

    if args.batch_size % args.batch_chunk != 0:
        raise RuntimeError('Batch size needs to be divisible by batch chunk')

    if args.profile:
        try:
            pyprof.init(enable_function_stack=True)
        except NameError:
            warnings.warn('Called pyprof.init() but pyprof is not available')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    logging.info(f'world size: {utils.distributed.get_world_size()}')

    if not args.no_env:
        log_env_info()

    register_ignoring_timeout_handler()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len

    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
        }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params, lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    scaler = None
    if args.fp16:
        if args.amp == 'pytorch':
            scaler = torch.cuda.amp.GradScaler()
        elif args.amp == 'apex':
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.apex_amp_opt_level,
                )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
            )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                              model, dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max_step - args.warmup_step, eta_min=args.eta_min)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step - args.warmup_step,
                eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(optimizer_sparse,
                                                           lr_lambda=lr_lambda)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.decay_rate, patience=args.patience,
            min_lr=args.lr_min,
            )
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse, factor=args.decay_rate,
                patience=args.patience, min_lr=args.lr_min,
                )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    start_epoch = 1
    last_batch = 0
    last_iter = 0
    best_val_loss = None

    if args.restart:
        try:
            checkpoint = load_checkpoint(args.restart)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['scheduler_state'])
            if args.fp16:
                if args.amp == 'pytorch':
                    scaler.load_state_dict(checkpoint['amp_state'])
                elif args.amp == 'apex':
                    amp.load_state_dict(checkpoint['amp_state'])
            train_step = checkpoint['train_step']
            start_epoch = checkpoint['epoch']
            last_batch = checkpoint['batch']
            last_iter = checkpoint['last_iter']
            best_val_loss = checkpoint['best_val_loss']

            if train_step >= args.max_step:
                logging.info(f'Loaded checkpoint after {train_step} steps, '
                             f'but this run was scheduled for a total of '
                             f'{args.max_step} steps, exiting')
                sys.exit(1)

            model.apply(functools.partial(update_dropout, args=args))
            model.apply(functools.partial(update_dropatt, args=args))
        except FileNotFoundError:
            logging.info(f'Could not load checkpoint from {args.restart}, '
                         f'starting training from random init')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)

    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
        with TimeoutHandler() as timeout_handler:
            try:
                for epoch in itertools.count(start=start_epoch):
                    if args.roll:
                        tr_iter.roll(seed=args.seed + epoch)
                    train_step, best_val_loss = train(
                        tr_iter, va_iter, model, para_model, model_config,
                        optimizer, optimizer_sparse, scheduler,
                        scheduler_sparse, scaler, vocab, epoch, last_batch,
                        last_iter, train_step, best_val_loss, meters,
                        timeout_handler, device, args
                        )

                    last_batch = 0
                    last_iter = 0

                    if train_step == args.max_step:
                        logging.info('-' * 100)
                        logging.info('End of training')
                        break
            except KeyboardInterrupt:
                logging.info('-' * 100)
                logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and not args.no_eval and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
            test_loss = evaluate(te_iter, model, args)
            test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info('| End of training | test time: {:5.2f}s '
                         '| test loss {:5.2f} | test bpc {:9.5f}'.format(
                             test_elapsed, test_loss, test_loss / math.log(2)))
        else:
            logging.info('| End of training | test time: {:5.2f}s '
                         '| test loss {:5.2f} | test ppl {:9.3f}'.format(
                             test_elapsed, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

        summary.update({
            'test_elapsed': test_elapsed,
            'test_loss': test_loss,
            })

        if args.dataset in ['enwik8', 'text8']:
            summary['test_bits_per_character'] = test_loss / math.log(2)
        else:
            summary['test_perplexity'] = math.exp(test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'valid_loss': best_val_loss,
        'valid_perplexity': val_perplexity,
        })
    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=val_perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['train_throughput'].avg,
        )
    if not passed:
        sys.exit(1)
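# A standalone sketch (not part of the training script above) of the
# 'inv_sqrt' schedule it defines: the multiplier grows linearly during
# warmup, then decays as 1/sqrt(step). The warmup length below is
# illustrative, not taken from the original configuration.
def inv_sqrt_multiplier(step: int, warmup_step: int) -> float:
    if step == 0 and warmup_step == 0:
        return 1.0
    if step > warmup_step:
        return 1.0 / (step ** 0.5)
    return step / (warmup_step ** 1.5)

if __name__ == '__main__':
    for s in (0, 100, 1000, 4000, 16000):
        # at step == warmup_step the two branches agree: 1/sqrt(warmup_step)
        print(s, inv_sqrt_multiplier(s, warmup_step=1000))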
# Fragment from an earlier variant of the model setup; the head of the
# MemTransformerLM call is reconstructed here from the full calls in the
# other snippets.
model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
                         args.d_head, args.d_inner, args.dropout,
                         args.dropatt, tie_weight=args.tied,
                         d_embed=args.d_embed, div_val=args.div_val,
                         tie_projs=tie_projs, pre_lnorm=args.pre_lnorm,
                         tgt_len=args.tgt_len, ext_len=args.ext_len,
                         mem_len=args.mem_len, cutoffs=cutoffs,
                         same_length=args.same_length,
                         attn_type=args.attn_type, clamp_len=args.clamp_len,
                         sample_softmax=args.sample_softmax)
model.apply(weights_init)
# ensure embedding init is not overridden by out_layer in case of weight sharing
model.word_emb.apply(weights_init)

args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

if args.fp16:
    model = model.half()

if args.multi_gpu:
    model = model.to(device)
    if args.gpu0_bsz >= 0:
        para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                          model, dim=1).to(device)
    else:
        para_model = nn.DataParallel(model, dim=1).to(device)
else:
    para_model = model.to(device)
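# Standalone sketch of why the DataParallel wrappers above pass dim=1:
# Transformer-XL batches are laid out [tgt_len, batch], so scattering across
# GPUs must split the second axis. torch.chunk mimics what the wrapper does
# per device; shapes are illustrative.
import torch

data = torch.zeros(8, 4, dtype=torch.long)       # [tgt_len, batch]
per_gpu = torch.chunk(data, 2, dim=1)            # split along the batch axis
print([tuple(t.shape) for t in per_gpu])         # [(8, 2), (8, 2)]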
def do_train(args):
    if args.use_gpu:
        rank = dist.get_rank()
        trainer_count = dist.get_world_size()
    else:
        rank = 0
        trainer_count = 1
        paddle.set_device("cpu")

    if trainer_count > 1:
        dist.init_parallel_env()

    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    vocab = get_lm_vocab(args)
    train_loader = get_lm_data_loader(args, vocab, "train")
    eval_loader = get_lm_data_loader(args, vocab, "valid")

    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(
        args.ntokens,
        args.n_layer,
        args.n_head,
        args.d_model,
        args.d_head,
        args.d_inner_hid,
        args.dropout,
        args.attn_dropout,
        tie_weight=args.tie_weight,
        d_embed=args.d_model,
        div_val=args.div_val,
        tie_projs=tie_projs,
        normalize_before=args.normalize_before,
        tgt_len=args.tgt_len,
        ext_len=args.ext_len,
        mem_len=args.mem_len,
        cutoffs=cutoffs,
        same_length=args.same_length,
        attn_type=args.attn_type,
        clamp_len=args.clamp_len,
        sample_softmax=args.sample_softmax)

    if args.scheduler == 'cosine':
        scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
            learning_rate=args.learning_rate,
            T_max=args.max_step,
            eta_min=args.eta_min)
    elif args.scheduler == 'noam':
        scheduler = paddle.optimizer.lr.NoamDecay(
            d_model=args.d_model,
            warmup_steps=args.warmup_steps,
            learning_rate=args.learning_rate)
    elif args.scheduler == 'dev_perf':
        # fluid api
        scheduler = paddle.fluid.dygraph.ReduceLROnPlateau(
            learning_rate=args.learning_rate,
            decay_rate=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min)
    elif args.scheduler == 'constant':
        scheduler = args.learning_rate

    clip = paddle.nn.ClipGradByGlobalNorm(args.clip)
    if args.optim.lower() == 'momentum':
        optimizer = paddle.optimizer.Momentum(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            momentum=args.mom,
            grad_clip=clip)
    elif args.optim.lower() == 'adam':
        optimizer = paddle.optimizer.Adam(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=eval(args.eps),
            grad_clip=clip)
    elif args.optim.lower() == 'adagrad':
        optimizer = paddle.optimizer.Adagrad(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            grad_clip=clip)

    # Init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        model_dict = paddle.load(
            os.path.join(args.init_from_checkpoint,
                         "mem_transformer.pdparams"))
        opt_dict = paddle.load(
            os.path.join(args.init_from_checkpoint, "mem_transformer.pdopt"))
        mem_transformer.set_state_dict(model_dict)
        optimizer.set_state_dict(opt_dict)
        print("loaded from checkpoint.")
    # Init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        model_dict = paddle.load(
            os.path.join(args.init_from_pretrain_model,
                         "mem_transformer.pdparams"))
        mem_transformer.set_state_dict(model_dict)
        print("loaded from pre-trained model.")

    if trainer_count > 1:
        mem_transformer = paddle.DataParallel(mem_transformer)

    step_idx = 0
    train_loss = 0.0
    log_start_time = time.time()

    for pass_id in range(args.epoch):
        batch_id = 0
        mems = tuple()
        for input_data in train_loader:
            (src, target, seq_len) = input_data
            ret = mem_transformer(src, target, *mems)
            loss = ret[0]
            mems = ret[1:]
            train_loss += loss.numpy()

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            if step_idx > 0 and step_idx % args.print_step == 0 and rank == 0:
                cur_loss = train_loss / args.print_step
                elapsed = time.time() - log_start_time
                if args.scheduler == "constant":
                    lr = optimizer.get_lr()
                else:
                    lr = scheduler.get_lr()
                logger_info = "step_idx: %d, epoch: %d, batch: %d, " \
                              "learning rate: %.8f, speed: %f ms/batch, " \
                              "loss: %f" % \
                              (step_idx, pass_id, batch_id, lr,
                               elapsed * 1000.0 / args.print_step, cur_loss)
                if args.dataset in ["enwik8", "text8"]:
                    logger_info = logger_info + ", bpc: %f" % (cur_loss /
                                                               np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (
                        np.exp(cur_loss))
                logger.info(logger_info)
                train_loss = 0.0
                log_start_time = time.time()

            if step_idx % args.save_step == 0 and step_idx != 0:
                # Do validation.
                mem_transformer.eval()

                # TODO(FrostML): simplify this.
                if args.mem_len == 0:
                    if dist.get_world_size() == 1:
                        mem_transformer.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len + args.tgt_len -
                            args.eval_tgt_len,
                            mem_len=args.mem_len)
                    else:
                        mem_transformer._layers.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len + args.tgt_len -
                            args.eval_tgt_len,
                            mem_len=args.mem_len)
                else:
                    if dist.get_world_size() == 1:
                        mem_transformer.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len,
                            mem_len=args.mem_len + args.tgt_len -
                            args.eval_tgt_len)
                    else:
                        mem_transformer._layers.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len,
                            mem_len=args.mem_len + args.tgt_len -
                            args.eval_tgt_len)

                total_len, total_loss = 0, 0.
                eval_mems = tuple()
                with paddle.no_grad():
                    for i, (src, target, seq_len) in enumerate(eval_loader):
                        if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                            break
                        ret = mem_transformer(src, target, *eval_mems)
                        loss, eval_mems = ret[0], ret[1:]
                        seq_len = seq_len.numpy()
                        eval_cur_loss = seq_len * loss.numpy()
                        total_loss += eval_cur_loss
                        total_len += seq_len
                    eval_loss = total_loss / total_len

                logger_info = "Validation, step_idx: %d, validation loss: %f" % \
                              (step_idx, eval_loss)
                if args.dataset in ['enwik8', 'text8']:
                    logger_info = logger_info + ", bpc: %f" % (eval_loss /
                                                               np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (
                        np.exp(eval_loss))
                logger.info(logger_info)

                if args.save_model and rank == 0:
                    model_dir = os.path.join(
                        args.save_model,
                        "step_" + str(step_idx) + "_" + str(eval_loss))
                    if not os.path.exists(model_dir):
                        os.makedirs(model_dir)
                    paddle.save(
                        mem_transformer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdparams"))
                    paddle.save(
                        optimizer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdopt"))

                if args.scheduler == 'dev_perf':
                    scheduler.step(eval_loss)

                # TODO(FrostML): simplify this.
                if dist.get_world_size() == 1:
                    mem_transformer.reset_length(tgt_len=args.tgt_len,
                                                 ext_len=args.ext_len,
                                                 mem_len=args.mem_len)
                else:
                    mem_transformer._layers.reset_length(
                        tgt_len=args.tgt_len,
                        ext_len=args.ext_len,
                        mem_len=args.mem_len)

                mem_transformer.train()

            step_idx += 1
            batch_id += 1
            if args.scheduler in ['cosine', 'dev_perf']:
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    scheduler.base_lr = curr_lr
                else:
                    if args.scheduler == 'cosine':
                        scheduler.step()
            elif args.scheduler == 'constant':
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    optimizer.set_lr(curr_lr)
            elif args.scheduler == 'noam':
                scheduler.step()
            if step_idx >= args.max_step:
                break

    if args.save_model and rank == 0:
        model_dir = os.path.join(args.save_model, "step_final")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        paddle.save(mem_transformer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdparams"))
        paddle.save(optimizer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdopt"))
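# Standalone sketch of the reset_length bookkeeping used for validation
# above: with a shorter eval target length, the freed positions go to
# ext_len when mem_len == 0 and to mem_len otherwise, so the total context
# tgt_len + ext_len + mem_len is preserved. Values are illustrative.
def eval_lengths(tgt_len, ext_len, mem_len, eval_tgt_len):
    if mem_len == 0:
        return eval_tgt_len, ext_len + tgt_len - eval_tgt_len, mem_len
    return eval_tgt_len, ext_len, mem_len + tgt_len - eval_tgt_len

print(eval_lengths(tgt_len=150, ext_len=0, mem_len=150, eval_tgt_len=64))
# -> (64, 0, 236); 64 + 0 + 236 == 150 + 0 + 150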
def main():
    args = parse_args()

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(
        args.work_dir,
        args.dataset,
        args.append_dataset,
        args.append_time,
    )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'log.log'

    log_file = os.path.join(args.work_dir, log_file)
    if args.debug:
        log_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  )
    logging.info(args)

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed + utils.distributed.get_rank())
    torch.manual_seed(args.seed + utils.distributed.get_rank())

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
        }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params, lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)
    if args.fp16:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2',
            )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(
            model,
            delay_allreduce=True,
            )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                              model, dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max_step, eta_min=args.eta_min)
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step, eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.decay_rate, patience=args.patience,
            min_lr=args.lr_min,
            )
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse, factor=args.decay_rate,
                patience=args.patience, min_lr=args.lr_min,
                )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    best_val_loss = None
    if args.restart:
        checkpoint = load_checkpoint(args.restart)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['scheduler_state'])
        if args.fp16:
            amp.load_state_dict(checkpoint['amp_state'])
        train_step = checkpoint['train_step']
        best_val_loss = checkpoint['best_val_loss']

        model.apply(functools.partial(update_dropout, args=args))
        model.apply(functools.partial(update_dropatt, args=args))

    meters = {}
    warmup = args.mem_len // args.tgt_len + 1
    meters['train_throughput'] = AverageMeter(warmup=warmup)

    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    try:
        for epoch in itertools.count(start=1):
            if args.roll:
                tr_iter.roll()
            train_step, best_val_loss = train(tr_iter, va_iter, model,
                                              para_model, model_config,
                                              optimizer, optimizer_sparse,
                                              scheduler, scheduler_sparse,
                                              vocab, epoch, train_step,
                                              best_val_loss, meters, args)

            if train_step == args.max_step:
                logging.info('-' * 100)
                logging.info('End of training')
                break
    except KeyboardInterrupt:
        logging.info('-' * 100)
        logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info('| End of training | test time: {:5.2f}s '
                         '| test loss {:5.2f} | test bpc {:9.5f}'.format(
                             time.time() - test_start_time, test_loss,
                             test_loss / math.log(2)))
        else:
            logging.info('| End of training | test time: {:5.2f}s '
                         '| test loss {:5.2f} | test ppl {:9.3f}'.format(
                             time.time() - test_start_time, test_loss,
                             math.exp(test_loss)))
        logging.info('=' * 100)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=val_perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['train_throughput'].avg,
        )
    if not passed:
        sys.exit(1)
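# Standalone sketch of the metric conversions in the logs above: the
# evaluation loss is a mean cross-entropy in nats, so perplexity is
# exp(loss) and bits-per-character is loss / ln(2). The loss value is
# illustrative.
import math

test_loss = 3.18                            # mean cross-entropy in nats
print('ppl:', math.exp(test_loss))          # word-level datasets (e.g. wt103)
print('bpc:', test_loss / math.log(2))      # character-level (enwik8, text8)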
def load(cls, model_path, sp_model_path, device, print_stats=True):
    paramspath = os.path.join(model_path, 'params.json')
    with open(paramspath, 'r') as paramsf:
        xl_params = json.loads(paramsf.read())
        print(repr(xl_params))

    model = MemTransformerLM(
        xl_params['ntokens'],                        # 50000,
        xl_params['n_layer'],                        # 16,
        xl_params['n_head'],                         # 10,
        xl_params['d_model'],                        # 410,
        xl_params['d_head'],                         # 41,
        xl_params['d_inner'],                        # 2100,
        0.0,                                         # no dropout,
        0.0,                                         # no dropatt,
        tie_weight=xl_params['tie_weight'],          # True,
        d_embed=xl_params['d_embed'],                # 410,
        div_val=xl_params['div_val'],                # 1,
        tie_projs=xl_params['tie_projs'],            # [False, True, True, True]
        pre_lnorm=xl_params['pre_lnorm'],            # False,
        tgt_len=xl_params['tgt_len'],                # 150,
        ext_len=xl_params['ext_len'],                # 0,
        mem_len=xl_params['mem_len'],                # 150,
        cutoffs=xl_params['cutoffs'],                # [3500, 7500, 37500],
        same_length=xl_params['same_length'],        # False,
        attn_type=xl_params['attn_type'],            # 0,
        clamp_len=xl_params['clamp_len'],            # -1,
        sample_softmax=xl_params['sample_softmax'])  # -1

    state_dict_path = os.path.join(model_path, 'valid_state_dict.pt')
    print("loading weights %s ..." % state_dict_path)
    tensor_dict = torch.load(state_dict_path,
                             map_location=torch.device(device))
    model.load_state_dict(tensor_dict)
    print("loading weights %s ... done." % state_dict_path)

    if print_stats:
        tensor_list = list(tensor_dict.items())
        for layer_tensor_name, tensor in tensor_list:
            print("Layer %-42s: %9d elements" % (layer_tensor_name,
                                                 torch.numel(tensor)))
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        print("Total # params: %d" % pytorch_total_params)

    # with open(os.path.join(MODEL_PATH, 'model.pt'), 'rb') as f:
    #     model = torch.load(f)
    # model.apply(update_dropout)
    # model.apply(update_dropatt)

    para_model = model.to(device)
    # print("loading model %s ... done." % MODEL_PATH)

    print("loading sp model from %s ..." % sp_model_path)
    sp_model = spm.SentencePieceProcessor()
    sp_model.load(sp_model_path)
    print("loading sp model from %s ... done." % sp_model_path)

    return cls(para_model, sp_model, device)
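# Standalone sketch of the params.json contract assumed by `load` above:
# one JSON object holding the MemTransformerLM constructor arguments. The
# values mirror the inline comments and are illustrative.
import json

example_params = {
    'ntokens': 50000, 'n_layer': 16, 'n_head': 10, 'd_model': 410,
    'd_head': 41, 'd_inner': 2100, 'tie_weight': True, 'd_embed': 410,
}
print(json.dumps(example_params, indent=2))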
def main():
    global global_token_count, event_writer, train_step, train_loss, \
        last_log_step, best_val_loss, epoch, model

    if args.local_rank > 0:
        pass  # skip shutdown when rank is explicitly set + not zero rank
    else:
        os.system('shutdown -c')

    if not args.local:
        logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})")

    model = MemTransformerLM(ntokens, args.n_layer, args.n_head,
                             args.d_model, args.d_head, args.d_inner,
                             args.dropout, args.dropatt,
                             tie_weight=args.tied, d_embed=args.d_embed,
                             div_val=args.div_val, tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
                             ext_len=args.ext_len, mem_len=args.mem_len,
                             cutoffs=cutoffs, same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    # log model info
    n_all_param = sum([p.nelement() for p in model.parameters()])
    log_tb('sizes/params', n_all_param)
    n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    log_tb('sizes/non_emb_params', n_nonemb_param)
    logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # optimizer
    if args.optim.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              momentum=args.mom)
    elif args.optim.lower() == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd)
    else:
        assert args.optim.lower() == 'adam'
        optimizer = optim.Adam(model.parameters(), lr=args.lr,
                               weight_decay=args.wd)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        scheduler = LRFinder(optimizer, args.max_tokens,
                             init_value=args.lr / 1e3)
    elif args.scheduler == 'constant':
        pass

    model.apply(weights_init)
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(weights_init)

    if args.checkpoint:
        if global_rank == 0:
            util.restore_from_checkpoint(model=model,
                                         checkpoint_fn=args.checkpoint)

    model = model.to(device)
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2 ** 16},
                                   verbose=False)

    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for
        # adaptive embedding.
        model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if global_rank == 0:
        event_writer = SummaryWriter(args.logdir)
        event_writer.add_text('args', str(args))
        # test checkpoint writing
        if args.checkpoint_each_epoch:
            logger.info(f'Saving checkpoint for epoch {epoch}')
            util.dist_save_checkpoint(model, optimizer, args.logdir,
                                      suffix=f'{0}')

    # Loop over epochs.
    train_step = 0
    train_loss = 0
    last_log_step = 0
    best_val_loss = None

    va_iter, te_iter = [
        corpus.get_dist_iterator(split, global_rank, max_rank,
                                 args.batch_size * 2, args.tgt_len,
                                 device=device, ext_len=args.ext_len)
        for split in ('valid', 'test')
    ]

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=1):
            train(va_iter, optimizer, scheduler)
    except KeyboardInterrupt:
        logger.info('-' * 100)
        logger.info('Exiting from training early')
    except StopIteration:
        pass

    # Eval one more time.
    evaluate_and_log(optimizer, va_iter, 'val', train_step=-1)

    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.logdir, 'model-best.pt')
    if os.path.exists(model_file):
        with open(model_file, 'rb') as model_f:
            with timeit('load'):
                if args.local:
                    model = torch.load(model_f)
                else:
                    model = torch.load(
                        model_f,
                        map_location=lambda storage, loc:
                            storage.cuda(args.local_rank))
                    model = DistributedDataParallel(
                        model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)
    else:
        logger.warning('no model file, using current model for loss')

    # Run on test data.
    evaluate_and_log(optimizer, te_iter, 'test', -1)
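# Standalone sketch of the token-based cosine schedule above. T_max is the
# token budget divided by 1e6, which suggests the scheduler is stepped once
# per ~1e6 tokens consumed rather than once per batch; that stepping policy
# is an assumption here, and the budget below is illustrative.
import torch
from torch import optim

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.SGD(params, lr=0.01)
max_tokens = 5_000_000
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=max_tokens // 1e6, eta_min=0.0)
for _ in range(5):                     # one step per ~1e6 tokens consumed
    print(scheduler.get_last_lr())
    optimizer.step()
    scheduler.step()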
def train_ts(args):
    def build_scheduler(optimizers, args):
        optimizer, optimizer_sparse = optimizers
        scheduler_sparse = None
        if args.scheduler == "cosine":
            # here we do not set eta_min to lr_min to be backward compatible
            # because in previous versions eta_min is default to 0
            # rather than the default value of lr_min 1e-6
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, args.max_step,
                eta_min=args.eta_min)  # should use eta_min arg
        elif args.scheduler == "inv_sqrt":
            # originally used for Transformer (in Attention is all you need)
            def lr_lambda(step):
                # return a multiplier instead of a learning rate
                if step == 0 and args.warmup_step == 0:
                    return 1.0
                else:
                    return (1.0 / (step ** 0.5) if step > args.warmup_step
                            else step / (args.warmup_step ** 1.5))

            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                    lr_lambda=lr_lambda)
        elif args.scheduler == "dev_perf":
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )
        elif args.scheduler == "constant":
            pass
        else:
            raise ValueError(f"scheduler type {args.scheduler} not recognized")
        return scheduler, scheduler_sparse

    ###########################################################################
    # Training code
    ###########################################################################

    def evaluate(eval_iter, model):
        # Turn on evaluation mode which disables dropout.
        model.eval()

        # debug
        # If the model does not use memory at all, make the ext_len longer.
        # Otherwise, make the mem_len longer and keep the ext_len the same.
        # if default_args.mem_len == 0:
        #     model.reset_length(default_args.eval_tgt_len,
        #                        default_args.ext_len + default_args.tgt_len -
        #                        default_args.eval_tgt_len,
        #                        default_args.mem_len)
        # else:
        #     model.reset_length(default_args.eval_tgt_len,
        #                        default_args.ext_len,
        #                        default_args.mem_len + default_args.tgt_len -
        #                        default_args.eval_tgt_len)

        # Evaluation
        total_len, total_loss = 0, 0.0
        with torch.no_grad():
            mems = tuple()
            for i, (data, target, seq_len) in enumerate(eval_iter):
                if i >= args.max_eval_steps > 0:
                    break
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.mean()
                total_loss += seq_len * loss.float().item()
                total_len += seq_len

        # Switch back to the training mode
        # model.reset_length(default_args.tgt_len, default_args.ext_len,
        #                    default_args.mem_len)
        model.train()

        return total_loss / total_len

    # reverse distillation util
    def get_original_batches(model, tr_iter, integration_length):
        model.eval()
        if args.batch_chunk > 1:
            mems = [None for _ in range(args.batch_chunk)]
            first_logits = [[] for _ in range(args.batch_chunk)]
        else:
            mems = None
            first_logits = []
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        with torch.no_grad():
            for batch, (data, target, seq_len) in enumerate(train_iter):
                if batch == integration_length:
                    break
                if args.batch_chunk > 1:
                    data_chunks = torch.chunk(data, args.batch_chunk, 1)
                    for i in range(args.batch_chunk):
                        data_i = data_chunks[i].contiguous()
                        logits, mems[i] = model._forward(data_i, mems=mems[i])
                        first_logits[i].append(logits.cpu())
                else:
                    logits, mems = model._forward(data, mems=mems)
                    first_logits.append(logits.cpu())
        return first_logits

    def build_optimizer(model, args, reload=False):
        optimizer_sparse = None
        if args.optim.lower() == "sgd":
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
        elif args.optim.lower() == "adam":
            optimizer = optim.Adam(model.parameters(), lr=args.lr)
        elif args.optim.lower() == "adagrad":
            optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        else:
            raise ValueError(f"optimizer type {args.optim} not recognized")

        if reload:
            if args.restart_from is not None:
                optim_name = f"optimizer_{args.restart_from}.pt"
            else:
                optim_name = "optimizer.pt"
            optim_file_name = os.path.join(args.restart_dir, optim_name)
            logging(f"reloading {optim_file_name}")
            if os.path.exists(os.path.join(args.restart_dir, optim_name)):
                with open(os.path.join(args.restart_dir, optim_name),
                          "rb") as optim_file:
                    opt_state_dict = torch.load(optim_file)
                    try:
                        optimizer.load_state_dict(opt_state_dict)
                    # in case the optimizer param groups aren't the same
                    # shape, merge them
                    except Exception:
                        logging("merging optimizer param groups")
                        opt_state_dict["param_groups"][0]["params"] = [
                            param
                            for param_group in opt_state_dict["param_groups"]
                            for param in param_group["params"]
                        ]
                        opt_state_dict["param_groups"] = [
                            opt_state_dict["param_groups"][0]
                        ]
                        optimizer.load_state_dict(opt_state_dict)
            else:
                logging("Optimizer was not saved. Start from scratch.")

        return optimizer, optimizer_sparse

    def log_val(val_loss, step, compute):
        logging("-" * 100)
        log_str = ("| Eval {:3d} at step {:>8d} | time: {:5.2f}s "
                   "| valid loss {:5.2f}".format(
                       step // args.eval_interval, step,
                       (time.time() - eval_start_time), val_loss))
        log_str += " | bpc {:9.5f}".format(val_loss / math.log(2))
        logging(log_str)
        logging("-" * 100)

    def epoch_loop(epoch, model, optimizers, schedulers):
        nonlocal train_step

        # Turn on training mode which enables dropout.
        if isinstance(model, nn.DataParallel):
            parent_model = model.module
        else:
            parent_model = model
        optimizer, optimizer_sparse = optimizers
        scheduler, scheduler_sparse = schedulers

        # global train_step, best_val_loss, eval_start_time, log_start_time
        train_losses = []
        model.train()
        if args.batch_chunk > 1:
            mems = [tuple() for _ in range(args.batch_chunk)]
        else:
            mems = tuple()

        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        log_start_time = time.time()
        best_val_loss = float("Infinity")

        for batch, (data, target, seq_len) in enumerate(train_iter):
            model.zero_grad()
            if args.batch_chunk > 1:
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                target_chunks = torch.chunk(target, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    target_i = target_chunks[i].contiguous()
                    ret = model(data_i, target_i, *mems[i])
                    loss, mems[i] = ret[0], ret[1:]
                    loss = loss.float().mean().type_as(loss) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    train_losses.append(loss.float().item())
            else:
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.float().mean().type_as(loss)
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                train_losses.append(loss.float().item())

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            parent_model.compute += openai_compute(
                non_emb_param_count(parent_model, nseries), data.numel(), 1)

            # step-wise learning rate annealing
            train_step += 1
            parent_model.training_steps += 1

            # check for yet-to-thaw parameters
            if getattr(parent_model, "freeze_countdown", 0) > 0:
                parent_model.freeze_countdown -= 1
                # if this is the last step
                if parent_model.freeze_countdown == 0:
                    for parameter in parent_model.parameters():
                        parameter.requires_grad = True
                    logging("thawing all parameters")

            if args.scheduler in ["cosine", "constant", "dev_perf"]:
                # linear warmup stage
                if train_step < args.warmup_step:
                    curr_lr = args.lr * train_step / args.warmup_step
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = curr_lr
                else:
                    if args.scheduler == "cosine":
                        scheduler.step(train_step)
            elif args.scheduler == "inv_sqrt":
                scheduler.step(train_step)

            if train_step % args.log_interval == 0:
                cur_loss = np.mean(train_losses)
                elapsed = time.time() - log_start_time
                log_str = ("| epoch {:3d} step {:>8d} "
                           "| {:>6d} batches "
                           "| lr {:.3g} "
                           "| ms/batch {:5.2f} "
                           "| loss {:5.2f}".format(
                               epoch, train_step, batch + 1,
                               optimizer.param_groups[0]["lr"],
                               elapsed * 1000 / args.log_interval, cur_loss))
                log_str += " | bpc {:9.5f}".format(cur_loss / math.log(2))
                logging(log_str)
                train_losses = []
                log_start_time = time.time()

            if train_step % args.eval_interval == 0:
                val_loss = evaluate(va_iter, model)
                log_val(val_loss, step=train_step,
                        compute=parent_model.compute)
                # Save the model if the validation loss is the best we've
                # seen so far.
                if not best_val_loss or val_loss < best_val_loss:
                    best_val_loss = val_loss
                    if not args.debug:
                        if args.fp16:
                            with open(os.path.join(args.work_dir,
                                                   "amp_checkpoint.pt"),
                                      "wb") as f:
                                checkpoint = {
                                    "model": model.state_dict(),
                                    "optimizer": optimizer.state_dict(),
                                    "amp": amp.state_dict(),
                                }
                                torch.save(checkpoint, f)
                        else:
                            with open(os.path.join(args.work_dir, "model.pt"),
                                      "wb") as f:
                                torch.save(parent_model, f)
                            with open(os.path.join(args.work_dir,
                                                   "optimizer.pt"),
                                      "wb") as f:
                                torch.save(optimizer.state_dict(), f)

                # dev-performance based learning rate annealing
                if args.scheduler == "dev_perf":
                    scheduler.step(val_loss)

                eval_start_time = time.time()

            if train_step == args.max_step:
                break

    def expand_model(strategy, integration, integration_length, n_add,
                     model: MemTransformerLM, optimizers, schedulers,
                     tr_iter, va_iter, epoch, step):
        optimizer, _ = optimizers
        scheduler, _ = schedulers

        if integration:
            if not integration_length or integration_length <= 0:
                warnings.warn(f"integration {integration} passed but "
                              f"integration_length is {integration_length}")
            else:
                logging(f"applying integration strategy {integration} "
                        f"with integration length {integration_length}")

        # pre-expansion validation
        logging("evaluating before expanding")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

        # infer example logits for reverse distillation
        if "reverse_distil" in integration:
            first_logits = get_original_batches(model, tr_iter,
                                                integration_length)

        # expansion
        logging(f"adding {n_add} layers before starting epoch {epoch} "
                f"with method {strategy}")
        new_layers = model.expand_layers(n_add, strategy=strategy,
                                         function=initialization_func)

        # optimizer update
        optimizer.add_param_group({
            "params": new_layers.parameters(),
            "lr": optimizer.param_groups[0]["lr"],
            "initial_lr": optimizer.param_groups[0]["initial_lr"],
        })
        scheduler.base_lrs.append(optimizer.param_groups[-1]["initial_lr"])

        # training loop for reverse distillation
        if "reverse_distil" in integration:
            fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                                  integration)

        # freezing parameters for frozen restart; we do this afterwards, else
        # the new layers get copied also without grads
        if "freeze" in integration and integration_length > 0:
            for param_group in optimizer.param_groups[:-1]:
                for parameter in param_group["params"]:
                    parameter.requires_grad = False
            model.freeze_countdown = integration_length

        # post-expansion validation
        logging("reevaluating")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    def expand_state(param, state):
        if param.shape != state.shape:
            ratios = [param.shape[i] // state.shape[i]
                      for i in range(len(param.shape))]
            return state.repeat(*ratios)
        else:
            return state

    def widen_model(strategy, ratio, model: MemTransformerLM, optimizers,
                    va_iter, epoch, step):
        optimizer, _ = optimizers

        # pre-expansion validation
        logging("evaluating before widening")  # debug
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, compute=model.compute, step=step)

        # infer example logits for reverse distillation
        # expansion
        logging(f"adding {ratio} heads before starting epoch {epoch} "
                f"with method {strategy}")
        model.add_heads(ratio, strategy=strategy,
                        function=initialization_func)

        # optimizer update
        for param, states in optimizer.state.items():
            if isinstance(param, nn.Parameter):
                states["exp_avg"] = expand_state(param, states["exp_avg"])
                states["exp_avg_sq"] = expand_state(param,
                                                    states["exp_avg_sq"])

        # training loop for reverse distillation
        # post-expansion validation
        logging("reevaluating")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    # reverse distillation trainer
    def fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                              integration):
        mse_loss = torch.nn.MSELoss()
        if "partial" in integration:
            distil_optimizer, distil_optimizer_sparse = build_optimizer(
                new_layers, args, reload=False)
        else:
            distil_optimizer, distil_optimizer_sparse = build_optimizer(
                model, args, reload=False)
        if args.cuda and args.fp16:
            model, distil_optimizer = amp.initialize(model, distil_optimizer,
                                                     opt_level=args.fp16)

        model.train()
        if args.batch_chunk > 1:
            mems = [None for _ in range(args.batch_chunk)]
        else:
            mems = None
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter

        for batch, (data, _, _) in enumerate(train_iter):
            if batch == len(first_logits):
                break
            model.zero_grad()
            if args.batch_chunk > 1:
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    logits, mems[i] = model._forward(data_i, mems=mems[i])
                    target_logits = first_logits[i][batch].to(logits.device)
                    loss = mse_loss(logits, target_logits) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss,
                                            distil_optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
            else:
                logits, mems = model._forward(data, mems=mems)
                target_logits = first_logits[batch].to(logits.device)
                loss = mse_loss(logits, target_logits)
                if args.fp16:
                    with amp.scale_loss(loss,
                                        distil_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(distil_optimizer), args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            distil_optimizer.step()

    ###########################################################################
    #
    # main()
    #
    ###########################################################################

    args.tied = not args.not_tied
    if args.d_embed < 0:
        args.d_embed = args.d_model

    # Validate `--fp16` option
    if args.fp16:
        if not args.cuda:
            print("WARNING: --fp16 requires --cuda, ignoring --fp16 option")
            args.fp16 = False
        else:
            try:
                from apex import amp
                if args.fp16 == "O1":
                    amp.register_half_function(torch, "einsum")
            except Exception:
                print("WARNING: apex not installed, ignoring --fp16 option")
                args.fp16 = False

    device = torch.device("cuda" if args.cuda else "cpu")

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print("WARNING: You have a CUDA device, so you should probably "
                  "run with --cuda")
        else:
            torch.cuda.manual_seed_all(args.seed)

    ###########################################################################
    # Logging
    ###########################################################################
    assert args.ext_len >= 0, "extended context length must be non-negative"
    assert args.d_batch % args.batch_chunk == 0

    args.work_dir = "{}-{}".format(args.work_dir, args.dataset)
    args.work_dir = os.path.join(args.work_dir,
                                 time.strftime("%Y%m%d-%H%M%S"))
    logging = create_exp_dir(
        args.work_dir,
        scripts_to_save=["train_ts.py", "mem_transformer.py"],
        debug=args.debug,
    )

    ###########################################################################
    # Load data
    ###########################################################################
    time_series = get_time_series(args.datadir, args.dataset)
    nseries = len(time_series.vocab)
    args.n_token = nseries

    eval_batch_size = 20
    tr_iter = time_series.get_iterator(
        "train",
        args.d_batch,
        args.tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    va_iter = time_series.get_iterator(
        "valid",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    te_iter = time_series.get_iterator(
        "test",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )

    cutoffs, tie_projs = [], [False]

    ###########################################################################
    # Define model
    ###########################################################################
    initialization_func = partial(
        weights_init,
        init=args.init,
        init_range=args.init_range,
        init_std=args.init_std,
        proj_init_std=args.proj_init_std,
    )

    if args.restart and not args.fp16:
        if args.restart_from is not None:
            model_name = f"model_{args.restart_from}.pt"
        else:
            model_name = "model.pt"
        model_file_name = os.path.join(args.restart_dir, model_name)
        logging(f"reloading {model_file_name}")
        with open(model_file_name, "rb") as f:
            model = torch.load(f)
        # backwards compatibility with older saves
        if isinstance(model, nn.DataParallel):
            model = model.module
        model.backward_compatible(tie_weight=args.tied, tie_projs=tie_projs)
        if not args.fp16:
            model = model.float()
        model.apply(update_dropout)
        model.apply(update_dropatt)
    else:
        model = MemTransformerLM(
            nseries,
            args.n_layer,
            args.n_head,
            args.d_model,
            args.d_head,
            args.d_inner,
            args.dropout,
            args.dropatt,
            tie_weight=args.tied,
            d_embed=args.d_embed,
            div_val=args.div_val,
            tie_projs=tie_projs,
            pre_lnorm=args.pre_lnorm,
            tgt_len=args.tgt_len,
            ext_len=args.ext_len,
            mem_len=args.mem_len,
            cutoffs=cutoffs,
            same_length=args.same_length,
            clamp_len=args.clamp_len,
        )
        model.apply(initialization_func)  # debug
        # model.word_emb.apply(initialization_func)
        # ensure embedding init is not overridden by out_layer in case of
        # weight sharing

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = non_emb_param_count(model, nseries)

    logging("=" * 100)
    for k, v in args.__dict__.items():
        logging("    - {} : {}".format(k, v))
    logging("=" * 100)
    logging("#params = {}".format(args.n_all_param))
    logging("#non emb params = {}".format(args.n_nonemb_param))

    para_model = parallelize_model(model, args)
    optimizers = build_optimizer(para_model, args,
                                 reload=args.restart and not args.fp16)
    optimizer, optimizer_sparse = optimizers
    schedulers = build_scheduler(optimizers, args)
    scheduler, scheduler_sparse = schedulers

    if args.cuda and args.fp16:
        para_model, optimizer = amp.initialize(para_model, optimizer,
                                               opt_level=args.fp16)
        if args.restart:
            if args.restart_from is not None:
                checkpoint_name = f"amp_checkpoint_{args.restart_from}.pt"
            else:
                checkpoint_name = "amp_checkpoint.pt"
            with open(os.path.join(args.work_dir, checkpoint_name),
                      "rb") as f:
                checkpoint = torch.load(f)
                model.load_state_dict(checkpoint["model"])
                optimizer.load_state_dict(checkpoint["optimizer"])
                amp.load_state_dict(checkpoint["amp"])

    ###########################################################################
    # Training loop
    ###########################################################################
    # Loop over epochs.
    if args.reset_lr:
        # then they're different and we use train_step only for the new lr
        # scheduling
        train_step = 0
        optimizer.defaults["lr"] = args.lr
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.lr
            param_group["initial_lr"] = args.lr
        scheduler.base_lrs = [args.lr] * len(scheduler.base_lrs)
    else:
        train_step = model.training_steps
    best_val_loss = None

    # Reload previous step number in case of default_args.restart
    if train_step > 0:
        logging(f"restarting from step {train_step}")

    log_start_time = time.time()
    eval_start_time = time.time()

    def run_training():
        nonlocal train_step

        for epoch in itertools.count(start=first_epoch):
            # we check before the training loop; expanding at epoch 0 means
            # before training (for debug purposes)
            if args.expand and str(epoch - 1) in args.expansion_dict:
                n_add = int(args.expansion_dict[str(epoch - 1)])
                expand_model(args.expand, args.integration,
                             args.integration_length, n_add, model,
                             optimizers, schedulers, tr_iter, va_iter,
                             epoch, train_step)
            if args.widen and str(epoch - 1) in args.widen_dict:
                ratio = int(args.widen_dict[str(epoch - 1)])
                widen_model(args.widen, ratio, model, optimizers, va_iter,
                            epoch, train_step)

            epoch_loop(epoch, para_model, optimizers, schedulers)
            if train_step >= args.max_step:
                logging("-" * 100)
                logging("End of training")
                break

            if not args.debug and args.log_first_epochs:
                if epoch <= args.log_first_epochs:
                    logging(f"saving model at the end of epoch {epoch}")
                    if args.fp16:
                        with open(os.path.join(args.work_dir,
                                               f"amp_checkpoint_{epoch}.pt"),
                                  "wb") as f:
                            checkpoint = {
                                "model": model.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "amp": amp.state_dict(),
                            }
                            torch.save(checkpoint, f)
                    else:
                        with open(os.path.join(args.work_dir,
                                               f"model_{epoch}.pt"),
                                  "wb") as f:
                            torch.save(model, f)
                        with open(os.path.join(args.work_dir,
                                               f"optimizer_{epoch}.pt"),
                                  "wb") as f:
                            torch.save(optimizer.state_dict(), f)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        if args.restart_from:
            first_epoch = args.restart_from + 1
            print(f"restarting from epoch {first_epoch}")
        else:
            first_epoch = 1
        run_training()
    except KeyboardInterrupt:
        logging("-" * 100)
        logging("Exiting from training early")

    # Load the best model.
    if args.fp16:
        with open(os.path.join(args.work_dir, "amp_checkpoint.pt"),
                  "rb") as f:
            checkpoint = torch.load(f)
            model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            amp.load_state_dict(checkpoint["amp"])
    else:
        with open(os.path.join(args.work_dir, "model.pt"), "rb") as f:
            model = torch.load(f)
    para_model = model.to(device)

    # Run on test data.
    test_loss = evaluate(te_iter, para_model)
    logging("=" * 100)
    logging("| End of training | test loss {:5.2f} | test bpc {:9.5f}".format(
        test_loss, test_loss / math.log(2)))
    logging("=" * 100)
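# Standalone sketch of the `expand_state` helper defined above: when a
# parameter is widened, its Adam moments are tiled by the per-dimension
# growth ratio so their shapes match the new parameter. Shapes are
# illustrative.
import torch

def expand_state(param, state):
    if param.shape != state.shape:
        ratios = [param.shape[i] // state.shape[i]
                  for i in range(len(param.shape))]
        return state.repeat(*ratios)
    return state

old_exp_avg = torch.ones(2, 4)         # moment before widening
new_param = torch.zeros(4, 4)          # parameter after doubling dim 0
print(expand_state(new_param, old_exp_avg).shape)  # torch.Size([4, 4])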
if args.multi_gpu:
    model = model.to(device)
    if args.gpu0_bsz >= 0:
        para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                          model, dim=1).to(device)
    else:
        para_model = nn.DataParallel(model, dim=1).to(device)
else:
    para_model = model.to(device)

#### optimizer
#### add new optimizers
if args.optim.lower() == 'adam':
    if args.sample_softmax > 0:
        dense_params, sparse_params = [], []
        for param in model.parameters():
            if param.size() == model.word_emb.weight.size():
                sparse_params.append(param)
            else:
                dense_params.append(param)
        optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
        optimizer = optim.Adam(dense_params, lr=args.lr)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr)

#### scheduler
if args.scheduler == 'cosine':
    # here we do not set eta_min to lr_min to be backward compatible
    # because in previous versions eta_min is default to 0
    # rather than the default value of lr_min 1e-6
    # (the tail of this call is cut off in the source; the arguments are
    # completed here from the identical call in the snippet above)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.max_step, eta_min=args.eta_min)
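# Standalone sketch of the batch_chunk accumulation pattern shared by the
# training loops above: a [tgt_len, batch] tensor is split along dim 1 and
# each chunk's loss is scaled by 1/batch_chunk, so the accumulated gradient
# matches the full-batch loss. The tensor and loss are illustrative.
import torch

batch_chunk = 4
data = torch.randn(8, 16)                        # [tgt_len, batch]
chunks = torch.chunk(data, batch_chunk, dim=1)
loss = sum(c.pow(2).mean() / batch_chunk for c in chunks)
print(len(chunks), tuple(chunks[0].shape), float(loss))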
# Fragment continuing a restart branch; the `if args.restart:` head is
# reconstructed here from the parallel snippet above.
if args.restart:
    # (checkpoint reloading elided in this excerpt)
    model = model.float()
    model.apply(update_dropout)
    model.apply(update_dropatt)
else:
    model = MemTransformerLM(ntokens, args.n_layer, args.n_head,
                             args.d_model, args.d_head, args.d_inner,
                             args.dropout, args.dropatt,
                             tie_weight=args.tied, d_embed=args.d_embed,
                             div_val=args.div_val, tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
                             ext_len=args.ext_len, mem_len=args.mem_len,
                             cutoffs=cutoffs, same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)
    model.apply(initialization_func)  # debug
    # model.word_emb.apply(initialization_func)
    # ensure embedding init is not overridden by out_layer in case of
    # weight sharing

args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = non_emb_param_count(model, ntokens)

logging('=' * 100)
for k, v in args.__dict__.items():
    logging('    - {} : {}'.format(k, v))
logging('=' * 100)
logging('#params = {}'.format(args.n_all_param))
logging('#non emb params = {}'.format(args.n_nonemb_param))

para_model = parallelize_model(model, args)
optimizers = build_optimizer(para_model, args,
                             reload=args.restart and not args.fp16)
optimizer, optimizer_sparse = optimizers
schedulers = build_scheduler(optimizers, args)
scheduler, scheduler_sparse = schedulers
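# Standalone sketch of the dense/sparse parameter split used with
# sample_softmax above: parameters whose shape matches the word embedding
# go to SparseAdam, the rest to Adam. TinyLM is a hypothetical stand-in for
# MemTransformerLM exposing `word_emb.weight`; sizes are illustrative.
import torch
from torch import nn, optim

class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.word_emb = nn.Embedding(100, 16, sparse=True)
        self.proj = nn.Linear(16, 16)

model = TinyLM()
dense_params, sparse_params = [], []
for param in model.parameters():
    if param.size() == model.word_emb.weight.size():
        sparse_params.append(param)
    else:
        dense_params.append(param)
optimizer_sparse = optim.SparseAdam(sparse_params, lr=1e-3)
optimizer = optim.Adam(dense_params, lr=1e-3)
print(len(sparse_params), len(dense_params))  # 1 embedding vs. 2 dense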