def run():
    """CLI entry point: restore a seq2seq checkpoint and decode one query.

    Parses the command line, resolves ``--checkpoint`` (a directory means
    "use the newest checkpoint inside it"), builds a BPE vectorizer and the
    embeddings, constructs the model, and prints the decoded response for
    ``--query``.
    """
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--checkpoint", type=str, help='Checkpoint name or directory to load')
    parser.add_argument("--sample", type=str2bool, help='Sample from the decoder? Defaults to `false`', default=0)
    parser.add_argument("--query", type=str, default='hello , <unk> are you today ?')
    parser.add_argument("--dataset_cache", type=str, default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--d_k", type=int, default=None, help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--nctx", type=int, default=128, help="Max context length (for both encoder and decoder)")
    parser.add_argument("--embed_type", type=str, default='default',
                        help="register label of the embeddings, so far support positional or learned-positional")
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--use_cls", type=str2bool, default=False)
    parser.add_argument('--end_token', default='<EOU>')
    parser.add_argument("--activation", type=str, default='gelu')
    parser.add_argument('--rpr_k',
                        help='Relative attention positional sizes pass 0 if you dont want relative attention',
                        type=int, default=[8], nargs='+')
    parser.add_argument("--y_only", type=str2bool, default=False)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    args = parser.parse_args()

    # With exactly one visible GPU, pin everything to it.
    if torch.cuda.device_count() == 1:
        torch.cuda.set_device(0)
        args.device = torch.device("cuda", 0)

    if os.path.isdir(args.checkpoint):
        checkpoint, _ = find_latest_checkpoint(args.checkpoint)
        logger.warning("Found latest checkpoint %s", checkpoint)
    else:
        checkpoint = args.checkpoint

    # Optionally emit a leading [CLS] token ahead of the query.
    begin_tok = '[CLS]' if args.use_cls else None
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file,
                                 vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx,
                                 emit_begin_tok=begin_tok,
                                 emit_end_tok=args.end_token)
    vocab = vectorizer.vocab.copy()
    # If we are not using chars, then use 'x' for both input and output
    loaded = baseline.embeddings.load_embeddings('x',
                                                 dsz=args.d_model,
                                                 counts=False,
                                                 known_vocab=vocab,
                                                 embed_type=args.embed_type,
                                                 preserve_vocab_indices=True)
    embeddings = loaded['embeddings']
    vocab = loaded['vocab']

    model = create_model(embeddings,
                         d_model=args.d_model,
                         d_ff=args.d_ff,
                         num_heads=args.num_heads,
                         num_layers=args.num_layers,
                         rpr_k=args.rpr_k,
                         d_k=args.d_k,
                         checkpoint_name=checkpoint,
                         activation=args.activation)
    model.to(args.device)

    index2word = revlut(vocab)
    print('[Query]', args.query)
    decoded = decode_sentence(model, vectorizer, args.query.split(), vocab, index2word, args.device,
                              sample=args.sample, y_only=args.y_only)
    print('[Response]', ' '.join(decoded))
def _read_batches(source, batchsz):
    """Turn ``source`` into batches of tokenized queries.

    ``source`` is either a path to a UTF-8 text file (one query per line) or a
    literal query string. Returns a list of batches, each batch a list of at
    most ``batchsz`` token lists.
    """
    batches = []
    if os.path.exists(source) and os.path.isfile(source):
        with open(source, 'rt', encoding='utf-8') as f:
            batch = []
            for line in f:
                text = line.strip().split()
                # Flush a full batch before appending the next query.
                if len(batch) == batchsz:
                    batches.append(batch)
                    batch = []
                batch.append(text)
            if len(batch) > 0:
                batches.append(batch)
    else:
        # Not a file: treat the whole string as a single one-query batch.
        batches.append([source.split()])
    return batches


def run():
    """CLI entry point: batch beam-search decoding from a checkpoint.

    ``--input`` may be a literal query or a path to a file of queries (one per
    line); queries are decoded in batches of ``--batchsz`` with beam size
    ``--beamsz``. Responses go to stdout, or to ``--output_file`` when given.
    The output file is now closed via try/finally so it cannot leak if
    decoding raises (the original only closed it on the success path).
    """
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--checkpoint", type=str, help='Checkpoint name or directory to load')
    parser.add_argument("--sample", type=str2bool, help='Sample from the decoder? Defaults to `false`', default=0)
    parser.add_argument("--vocab", type=str, help='Vocab file to load', required=False)
    parser.add_argument("--input", type=str, default='hello how are you ?')
    parser.add_argument("--dataset_cache", type=str, default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--d_k", type=int, default=None, help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--nctx", type=int, default=256, help="Max context length (for both encoder and decoder)")
    parser.add_argument("--embed_type", type=str, default='default',
                        help="register label of the embeddings, so far support positional or learned-positional")
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--batchsz", help="Size of a batch to pass at once", default=4, type=int)
    parser.add_argument("--beamsz", help="Size of beam to use", default=4, type=int)
    parser.add_argument("--activation", type=str, default='relu')
    parser.add_argument('--rpr_k',
                        help='Relative attention positional sizes pass 0 if you dont want relative attention',
                        type=int, default=[8] * 8, nargs='+')
    parser.add_argument("--end_token", default="<EOS>")
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--show_query", type=str2bool, default=False, help="Show the original query as well")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    args = parser.parse_args()

    # With exactly one visible GPU, pin everything to it.
    if torch.cuda.device_count() == 1:
        torch.cuda.set_device(0)
        args.device = torch.device("cuda", 0)

    # A directory means "use the newest checkpoint inside it".
    if os.path.isdir(args.checkpoint):
        checkpoint, _ = find_latest_checkpoint(args.checkpoint)
        logger.warning("Found latest checkpoint %s", checkpoint)
    else:
        checkpoint = args.checkpoint

    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file,
                                 vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx,
                                 emit_end_tok=args.end_token)
    vocab = vectorizer.vocab
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings('x',
                                                       dsz=args.d_model,
                                                       counts=False,
                                                       known_vocab=vocab,
                                                       embed_type=args.embed_type)
    embeddings = preproc_data['embeddings']
    vocab = preproc_data['vocab']
    model = create_model(embeddings,
                         d_model=args.d_model,
                         d_ff=args.d_ff,
                         num_heads=args.num_heads,
                         num_layers=args.num_layers,
                         rpr_k=args.rpr_k,
                         d_k=args.d_k,
                         checkpoint_name=checkpoint,
                         activation=args.activation,
                         device=args.device)
    model.to(args.device)
    index2word = revlut(vocab)

    batches = _read_batches(args.input, args.batchsz)
    wf = open(args.output_file, "w") if args.output_file else None
    try:
        for queries in batches:
            outputs = decode_sentences(model, vectorizer, queries, vocab, index2word, args.beamsz)
            if args.show_query:
                for query, output in zip(queries, outputs):
                    print(f"[Query] {query}")
                    print(f"[Response] {output}")
            elif wf:
                for query, output in zip(queries, outputs):
                    wf.write(f'{output}\n')
                    wf.flush()
            else:
                for query, output in zip(queries, outputs):
                    print(output)
    finally:
        # Guarantee the output file is closed even if decoding fails.
        if wf:
            wf.close()
def _settle_rpr_k(rpr_k):
    """Normalize a ``--*_rpr_k`` CLI value.

    Returns None when relative attention is disabled (empty list or first
    entry < 1), a bare int when one size is given, otherwise the list itself.
    """
    if len(rpr_k) == 0 or rpr_k[0] < 1:
        return None
    if len(rpr_k) == 1:
        return rpr_k[0]
    return rpr_k


def train():
    """CLI entry point: jointly train a generator MLM against a discriminator.

    Builds a masked-LM generator and a Transformer discriminator over shared
    BPE vocab, then optimizes both with a single optimizer on the combined
    loss ``gen_loss + gen_loss_scale * discrim_loss``, saving step- and
    epoch-level checkpoints for each network.

    Bug fix vs. original: the discriminator's rpr_k settling tested
    ``len(args.gen_rpr_k)`` instead of ``len(args.discrim_rpr_k)`` (copy-paste
    error); both lists are now normalized by the same helper. A duplicated
    ``os.makedirs(args.basedir, exist_ok=True)`` call was also removed.
    """
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_file", type=str, help='Optional file path to use for train file')
    parser.add_argument("--valid_file", type=str, help='Optional file path to use for valid file')
    parser.add_argument("--preprocessed", type=str2bool, default=True, help="Has the data already been preprocessed?")
    parser.add_argument("--gen_d_model", type=int, default=256, help="Model dimension (and embedding dsz)")
    parser.add_argument("--gen_d_ff", type=int, default=1024, help="FFN dimension")
    parser.add_argument("--gen_d_k", type=int, default=None,
                        help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--gen_num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--gen_num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--gen_dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument('--gen_rpr_k',
                        help='Relative attention positional sizes pass 0 if you dont want relative attention',
                        type=int, default=[8], nargs='+')
    parser.add_argument("--discrim_d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--discrim_d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--discrim_d_k", type=int, default=None,
                        help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--discrim_num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--discrim_num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--discrim_dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument('--discrim_rpr_k',
                        help='Relative attention positional sizes pass 0 if you dont want relative attention',
                        type=int, default=[8], nargs='+')
    parser.add_argument("--num_train_workers", type=int, default=4, help="Number train workers")
    parser.add_argument("--nctx", type=int, default=256, help="Max context length (for both encoder and decoder)")
    parser.add_argument("--embed_type", type=str, default='default',
                        choices=["default", "positional", "learned-positional"],
                        help="register label of the embeddings")
    parser.add_argument("--pattern", default='*.json',
                        help="Glob pattern for files, defaults to *.json if preprocessed, *.txt otherwise")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch Size")
    parser.add_argument("--dataset_key", default="reddit", help="dataset key for basedir")
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--lr_scheduler", type=str, default='cosine',
                        help="The type of learning rate decay scheduler")
    parser.add_argument("--lr_decay_steps", type=int, help="decay steps of lr scheduler")
    parser.add_argument("--lr_decay_rate", type=float, help="decay rate of lr scheduler")
    parser.add_argument("--lr_alpha", type=float, help="parameter alpha for cosine decay scheduler")
    parser.add_argument("--optim", default="adam", type=str, help="Optimizer to use (defaults to adam)")
    parser.add_argument("--lr", type=float, default=4.0e-4, help="Learning rate")
    parser.add_argument("--clip", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--gen_loss_scale", type=float, default=50.0, help="Scaling for loss function")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay")
    parser.add_argument("--epochs", type=int, default=32, help="Num training epochs")
    parser.add_argument("--restart_from", type=str,
                        help="Option allows you to restart from the latest checkpoint in a directory")
    parser.add_argument("--restart_tt", type=str, choices=['step', 'epoch'], default='step',
                        help="Optional param for legacy checkpoints (step|epoch)")
    parser.add_argument("--warmup_steps", type=int, default=10000, help="Num warmup steps")
    parser.add_argument("--saves_per_epoch", type=int, default=100,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument("--print", type=str2bool, default=True, help="Print some output")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed", type=str2bool, default=False, help="Are we doing distributed training?")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1 means use the environment variables to find)")
    args = parser.parse_args()

    # Train and valid overrides must be given together.
    if args.train_file and not args.valid_file:
        logger.error("If you provide a train_file, you must provide a valid_file")
        return
    if not args.train_file and args.valid_file:
        logger.error("If you provide a valid_file, you must also provide a train_file")
        return

    if args.basedir is None:
        args.basedir = 'gd-{}-bpe-{}'.format(args.dataset_key, os.getpid())
    # Only rank 0 (or single-process) logs at INFO; other workers stay quiet.
    logging.basicConfig(format="%(name)s: %(levelname)s: %(message)s",
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    num_gpus = get_num_gpus_multiworker()
    args.distributed = args.distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")
    if args.distributed:
        args.device, args.local_rank = init_distributed(args.local_rank)

    # Raw text ("lang") requires masking on the fly; preprocessed shards do not.
    if not args.preprocessed:
        reader_type = "lang"
        args.pattern = "*.txt"
    else:
        reader_type = "preprocessed"
    reader = MultiFileDatasetReader(args.nctx, args.subword_model_file, args.subword_vocab_file,
                                    args.pattern, reader_type=reader_type)
    # just return the vocab from the BPE vectorizer
    vocab = reader.build_vocab([])
    gen_embed = baseline.embeddings.load_embeddings('x', dsz=args.gen_d_model,
                                                    known_vocab=vocab['x'], embed_type=args.embed_type)
    vocabs = gen_embed['vocab']
    index2word = revlut(vocabs)
    discrim_embed = baseline.embeddings.load_embeddings('x', dsz=args.discrim_d_model,
                                                        known_vocab=vocab['x'], embed_type=args.embed_type)
    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    gen_embeddings = {'x': gen_embed['embeddings']}
    discrim_embeddings = {'x': discrim_embed['embeddings']}
    logger.info("Loaded embeddings")

    train_set = reader.load(args.train_file, vocabs)
    valid_set = reader.load(args.valid_file, vocabs)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=args.num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size)
    # NOTE(review): dividing len(loader) by batch size looks like it assumes an
    # iterable dataset whose loader length counts samples — confirm upstream.
    train_steps_per_epoch = len(train_loader) // (args.batch_size * num_gpus)
    valid_steps_per_epoch = len(valid_loader) // args.batch_size
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    # Either BERT-style "[MASK]" or mead-style "<MASK>" may be present.
    mask_value = vocabs.get("[MASK]", vocabs.get("<MASK>", -1))
    if mask_value == -1:
        logger.error("We could not find a suitable masking token in the vocab")
        return
    vocab_size = len(vocabs)

    # Fixed: the original settled discrim_rpr_k with len(args.gen_rpr_k).
    gen_rpr_k = _settle_rpr_k(args.gen_rpr_k)
    discrim_rpr_k = _settle_rpr_k(args.discrim_rpr_k)

    gen_model = TransformerMaskedLanguageModel.create(gen_embeddings,
                                                      hsz=args.gen_d_model,
                                                      d_ff=args.gen_d_ff,
                                                      tie_weights=True,
                                                      dropout=args.gen_dropout,
                                                      num_heads=args.gen_num_heads,
                                                      layers=args.gen_num_layers,
                                                      rpr_k=gen_rpr_k,
                                                      d_k=args.gen_d_k,
                                                      src_keys=['x'],
                                                      tgt_key='x')
    discrim_model = TransformerDiscriminator(discrim_embeddings,
                                             d_model=args.discrim_d_model,
                                             d_ff=args.discrim_d_ff,
                                             dropout=args.discrim_dropout,
                                             num_heads=args.discrim_num_heads,
                                             layers=args.discrim_num_layers,
                                             activation='gelu',
                                             layer_norm_eps=1.0e-12,
                                             rpr_k=discrim_rpr_k,
                                             d_k=args.discrim_d_k)
    gen_model.to(args.device)
    gen_loss_fn = gen_model.create_loss()
    discrim_model.to(args.device)
    discrim_loss_fn = discrim_model.create_loss()
    logger.info("Loaded model and loss")

    update_on = train_steps_per_epoch // args.saves_per_epoch
    report_on = update_on // 10
    logger.info(f"Steps per epoch per GPU: {train_steps_per_epoch}. Saving checkpoint every {update_on} steps.")
    lr_decay = get_lr_decay(args.lr_scheduler, args.lr, train_steps_per_epoch, args.epochs, logger,
                            decay_steps=args.lr_decay_steps, decay_rate=args.lr_decay_rate, alpha=args.lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=args.lr)

    global_step = 0
    start_epoch = 0
    if args.restart_from:
        if not os.path.isdir(args.restart_from):
            raise Exception(f"Cannot restart from {args.restart_from}, directory not found")
        tick_type = args.restart_tt
        discrim_latest, step_num = find_latest_checkpoint(args.restart_from,
                                                          wildcard=f'checkpoint-discrim-{tick_type}')
        gen_latest, _ = find_latest_checkpoint(args.restart_from, wildcard=f'checkpoint-gen-{tick_type}')
        discrim_model.load_state_dict(torch.load(discrim_latest))
        gen_model.load_state_dict(torch.load(gen_latest))
        if tick_type == 'step':
            start_epoch = step_num // train_steps_per_epoch
            global_step = step_num
        else:
            start_epoch = step_num
            global_step = train_steps_per_epoch * start_epoch

    # A single optimizer updates both networks jointly.
    parameters = list(discrim_model.parameters()) + list(gen_model.parameters())
    optz = OptimizerManager(parameters, global_step, optim=args.optim, lr=args.lr,
                            lr_function=lr_sched, weight_decay=args.weight_decay)
    logger.info("Generator has {:,} parameters".format(
        sum(p.numel() for p in gen_model.parameters() if p.requires_grad)))
    logger.info("Discriminator has {:,} parameters".format(
        sum(p.numel() for p in discrim_model.parameters() if p.requires_grad)))

    # Prepare model for distributed training if needed
    if args.distributed:
        # This program assume pure data parallelism, each model is on a single gpu
        # If we wanted to support model and data parallelism we would need to update
        # the selection of gpus based on rank, it would need to select multiple ids
        # based on rank, here we select only a single gpu and use it for input and
        # output.
        gen_model = DistributedDataParallel(gen_model, device_ids=[args.device], output_device=args.device)
        discrim_model = DistributedDataParallel(discrim_model, device_ids=[args.device], output_device=args.device)
        logger.info("Model located on %s", args.device)

    # This is the training loop
    steps = global_step
    model_base = os.path.join(args.basedir, 'checkpoint')
    discrim_base = f'{model_base}-discrim'
    gen_base = f'{model_base}-gen'
    do_on_demand_masking = not args.preprocessed
    if do_on_demand_masking:
        logger.info(f"On-demand masking is turned on")

    for epoch in range(start_epoch, args.epochs):
        gen_model.train()
        discrim_model.train()
        avg_gen_loss = Average('average_train_gen_loss')
        avg_discrim_loss = Average('average_train_discrim_loss')
        avg_discrim_acc = Average('average_train_discrim_acc')
        # NOTE(review): 'average5_train_loss' looks like a typo for
        # 'average_train_loss' — kept byte-for-byte for metric compatibility.
        avg_train_loss = Average('average5_train_loss')
        metrics = {}
        optz.zero_grad()
        start = time.time()
        print(f'Starting epoch {epoch + 1}')
        train_iter = iter(train_loader)
        valid_iter = iter(valid_loader)
        for i in range(train_steps_per_epoch):
            steps += 1
            x, y = next(train_iter)
            do_report = (i + 1) % report_on == 0 and args.print
            gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(x, y, args.device, gen_model, gen_loss_fn,
                                                                   discrim_model, discrim_loss_fn, mask_value,
                                                                   vocab_size, index2word, do_report,
                                                                   do_on_demand_masking)
            avg_gen_loss.update(gen_loss_step.item())
            # Joint objective: generator loss plus scaled discriminator loss.
            total_loss_step = gen_loss_step + args.gen_loss_scale * discrim_loss_step
            total_loss_step.backward()
            avg_discrim_loss.update(discrim_loss_step.item())
            avg_train_loss.update(total_loss_step.item())
            avg_discrim_acc.update(acc)
            torch.nn.utils.clip_grad_norm_(parameters, args.clip)
            optz.step()
            optz.zero_grad()
            if (i + 1) % report_on == 0:
                logging.info('Loss g=%f, d=%f total=%f, Per token acc=%f',
                             avg_gen_loss.avg, avg_discrim_loss.avg, avg_train_loss.avg, avg_discrim_acc.avg)
            # Step checkpoints only on the primary worker.
            if (i + 1) % update_on == 0 and args.local_rank < 1:
                elapsed = (time.time() - start) / 60
                logging.info('elapsed time this epoch %d min', elapsed)
                logging.info('elapsed step time %f steps/min', i / elapsed)
                logging.info('LR: %f', optz.current_lr)
                save_checkpoint(gen_model, gen_base, steps, tick_type='step')
                save_checkpoint(discrim_model, discrim_base, steps, tick_type='step')

        # How much time elapsed in minutes
        elapsed = (time.time() - start) / 60
        # This is the average training token-level loss across all machines
        # This is the token-level training perplexity
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_gen_loss'] = avg_gen_loss.avg
        metrics['average_train_discrim_loss'] = avg_discrim_loss.avg
        metrics['average_train_discrim_per_token_accuracy'] = avg_discrim_acc.avg
        metrics['average_train_loss'] = avg_train_loss.avg

        # Validation and epoch-level checkpoints run only on the primary worker.
        if args.local_rank < 1:
            avg_valid_gen_loss = Average('average_valid_gen_loss')
            avg_valid_discrim_loss = Average('average_valid_discrim_loss')
            avg_valid_discrim_acc = Average('average_valid_discrim_acc')
            avg_valid_loss = Average('average_valid_loss')
            start = time.time()
            gen_model.eval()
            discrim_model.eval()
            for i in range(valid_steps_per_epoch):
                with torch.no_grad():
                    x, y = next(valid_iter)
                    do_report = (i + 1) % report_on == 0 and args.print
                    gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(x, y, args.device, gen_model,
                                                                           gen_loss_fn, discrim_model,
                                                                           discrim_loss_fn, mask_value,
                                                                           vocab_size, index2word, do_report,
                                                                           do_on_demand_masking)
                    avg_valid_gen_loss.update(gen_loss_step.item())
                    avg_valid_discrim_acc.update(acc)
                    avg_valid_discrim_loss.update(discrim_loss_step.item())
                    total_loss_step = gen_loss_step + args.gen_loss_scale * discrim_loss_step
                    avg_valid_loss.update(total_loss_step.item())
            elapsed = (time.time() - start) / 60
            metrics['valid_elapsed_min'] = elapsed
            metrics['average_valid_gen_loss'] = avg_valid_gen_loss.avg
            metrics['average_valid_discrim_loss'] = avg_valid_discrim_loss.avg
            metrics['average_valid_discrim_per_token_accuracy'] = avg_valid_discrim_acc.avg
            metrics['average_valid_loss'] = avg_valid_loss.avg
            logger.info(metrics)
            save_checkpoint(discrim_model, discrim_base, epoch, tick_type='epoch', save_npz=True)
            save_checkpoint(gen_model, gen_base, epoch, tick_type='epoch', save_npz=True)