def __init__(self, src_nctx=64, tgt_nctx=64, src_begin_tok=[], src_end_tok=['<EOS>'],
             tgt_begin_tok=['<GO>'], tgt_end_tok=['<EOS>'], model_file=None, vocab_file=None,
             file_type='txt', reader_type="ntp", record_keys=None, lower=False):
    self.src_nctx = src_nctx
    self.tgt_nctx = tgt_nctx
    self.pattern = f'*.{file_type}'
    self.reader_type = reader_type
    if not src_begin_tok and self.reader_type == 'lang':
        src_begin_tok = ['[CLS]']
    self.record_keys = record_keys if record_keys else ['x', 'y']
    transform_fn = None if not lower else baseline.lowercase
    self.src_vectorizer = BPEVectorizer1D(model_file=model_file, vocab_file=vocab_file, mxlen=src_nctx,
                                          emit_begin_tok=src_begin_tok, emit_end_tok=src_end_tok,
                                          transform_fn=transform_fn)
    self.tgt_vectorizer = BPEVectorizer1D(model_file=model_file, vocab_file=vocab_file, mxlen=tgt_nctx,
                                          emit_begin_tok=tgt_begin_tok, emit_end_tok=tgt_end_tok,
                                          transform_fn=transform_fn)
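# Hedged usage sketch of the vectorizers the reader above constructs, built directly.
# The file names `codes.30k`/`vocab.30k` are placeholders for fastBPE artifacts on disk.
from baseline.vectorizers import BPEVectorizer1D

src_vec = BPEVectorizer1D(model_file='codes.30k', vocab_file='vocab.30k', mxlen=128,
                          emit_begin_tok=[], emit_end_tok=['<EOS>'])
tgt_vec = BPEVectorizer1D(model_file='codes.30k', vocab_file='vocab.30k', mxlen=128,
                          emit_begin_tok=['<GO>'], emit_end_tok=['<EOS>'])
# Running a vectorizer on pre-tokenized text returns (subword ids, valid length)
ids, valid_len = src_vec.run("hello world".split(), src_vec.vocab)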
def _create_subword_vectorizer(self, mxlen=None, model_file=None, vocab_file=None, emit_begin_tok=None,
                               emit_end_tok=None, transform_fn=None, extra_tokens=None):
    if self.subword_type == 'bpe':
        return BPEVectorizer1D(model_file=model_file, vocab_file=vocab_file, mxlen=mxlen,
                               emit_begin_tok=emit_begin_tok, emit_end_tok=emit_end_tok,
                               transform_fn=transform_fn, extra_tokens=extra_tokens)
    if self.subword_type == 'wordpiece':
        return WordpieceVectorizer1D(vocab_file=vocab_file, mxlen=mxlen,
                                     emit_begin_tok=emit_begin_tok, emit_end_tok=emit_end_tok,
                                     transform_fn=transform_fn)
    else:
        from baseline.vectorizers import SentencePieceVectorizer1D
        return SentencePieceVectorizer1D(model_file=model_file, mxlen=mxlen,
                                         emit_begin_tok=emit_begin_tok, emit_end_tok=emit_end_tok,
                                         transform_fn=transform_fn, extra_tokens=extra_tokens)
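# Hedged sketch of the non-BPE branches of the factory above, constructed directly.
# The branch taken depends on `self.subword_type` set by the enclosing class; the
# paths `bert-vocab.txt` and `spm.model` are placeholders, and the emit tokens are
# examples rather than values prescribed by the snippet.
from baseline.vectorizers import WordpieceVectorizer1D, SentencePieceVectorizer1D

wp_vec = WordpieceVectorizer1D(vocab_file='bert-vocab.txt', mxlen=256,
                               emit_begin_tok=['[CLS]'], emit_end_tok=['[SEP]'])
sp_vec = SentencePieceVectorizer1D(model_file='spm.model', mxlen=256,
                                   emit_end_tok=['</s>'])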
def run():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--checkpoint", type=str, help='Checkpoint name or directory to load')
    parser.add_argument("--sample", type=str2bool, help='Sample from the decoder? Defaults to `false`', default=0)
    parser.add_argument("--query", type=str, default='hello , <unk> are you today ?')
    parser.add_argument("--dataset_cache", type=str, default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--d_k", type=int, default=None, help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--nctx", type=int, default=128, help="Max context length (for both encoder and decoder)")
    parser.add_argument("--embed_type", type=str, default='default',
                        help="register label of the embeddings, so far supports positional or learned-positional")
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--use_cls", type=str2bool, default=False)
    parser.add_argument('--end_token', default='<EOU>')
    parser.add_argument("--activation", type=str, default='gelu')
    parser.add_argument('--rpr_k', help='Relative attention positional sizes, pass 0 if you do not want relative attention',
                        type=int, default=[8], nargs='+')
    parser.add_argument("--y_only", type=str2bool, default=False)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    args = parser.parse_args()

    if torch.cuda.device_count() == 1:
        torch.cuda.set_device(0)
        args.device = torch.device("cuda", 0)

    if os.path.isdir(args.checkpoint):
        checkpoint, _ = find_latest_checkpoint(args.checkpoint)
        logger.warning("Found latest checkpoint %s", checkpoint)
    else:
        checkpoint = args.checkpoint

    cls = None if not args.use_cls else '[CLS]'
    end = args.end_token
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx, emit_begin_tok=cls, emit_end_tok=end)
    vocab = vectorizer.vocab.copy()
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, counts=False, known_vocab=vocab,
                                                       embed_type=args.embed_type, preserve_vocab_indices=True)
    embeddings = preproc_data['embeddings']
    vocab = preproc_data['vocab']
    model = create_model(embeddings, d_model=args.d_model, d_ff=args.d_ff, num_heads=args.num_heads,
                         num_layers=args.num_layers, rpr_k=args.rpr_k, d_k=args.d_k,
                         checkpoint_name=checkpoint, activation=args.activation)
    model.to(args.device)

    index2word = revlut(vocab)
    print('[Query]', args.query)
    bpe_out = decode_sentence(model, vectorizer, args.query.split(), vocab, index2word, args.device,
                              sample=args.sample, y_only=args.y_only)
    print('[Response]', ' '.join(bpe_out))
def test_bpe_label_indices_generator():
    pytest.importorskip("fastBPE")
    num_tokens = random.randint(1, 100)
    tokens = [random_string() for _ in range(num_tokens)]
    bpe = BPEVectorizer1D(model_file=os.path.join(TEST_DATA, "codes.30k"),
                          vocab_file=os.path.join(TEST_DATA, "vocab.30k"))
    tokens = add_specials(tokens, bpe.special_tokens)
    bpe_toks, gold_indices = bpe_tokens(tokens, specials=bpe.special_tokens)
    indices = bpe.valid_label_indices((t for t in bpe_toks))
    assert len(indices) == num_tokens
    assert indices == gold_indices
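# Hedged illustration of what the test above checks: fastBPE marks non-final pieces of a
# split word with a trailing '@@', and valid_label_indices should return one index per
# original word while skipping special tokens. The pieces below are invented for the example.
pieces = ['<GO>', 'hel@@', 'lo', 'world', '<EOS>']
# expected: two indices, one for 'hello' and one for 'world' (specials excluded)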
def _create_subword_vectorizer(self, mxlen=None, model_file=None, vocab_file=None, emit_begin_tok=None,
                               emit_end_tok=None, transform_fn=None, extra_tokens=None):
    if self.subword_type == 'wordpiece':
        return WordpieceVectorizer1D(vocab_file=vocab_file, mxlen=mxlen,
                                     emit_begin_tok=emit_begin_tok, emit_end_tok=emit_end_tok,
                                     transform_fn=transform_fn)
    return BPEVectorizer1D(model_file=model_file, vocab_file=vocab_file, mxlen=mxlen,
                           emit_begin_tok=emit_begin_tok, emit_end_tok=emit_end_tok,
                           transform_fn=transform_fn, extra_tokens=extra_tokens)
def __init__(self, nctx=64, model_file=None, vocab_file=None, pattern='*.txt', reader_type="ntp"):
    self.nctx = nctx
    self.pattern = pattern
    self.reader_type = reader_type
    self.vectorizer = BPEVectorizer1D(model_file=model_file, vocab_file=vocab_file, mxlen=nctx)
def __init__(self, nctx, use_subword=None, model_file=None, vocab_file=None, special_tokens=None):
    """Create a reader with a context window that reads words

    :param nctx: The context window length
    :param use_subword: If this is not None, it should be either 'bpe' or 'wordpiece'
    """
    self.use_subword = use_subword

    if self.use_subword == 'bpe':
        vectorizer = BPEVectorizer1D(model_file=model_file, vocab_file=vocab_file)
    elif self.use_subword == 'wordpiece':
        vectorizer = WordpieceVectorizer1D(embed_file=model_file, vocab_file=vocab_file,
                                           special_tokens=special_tokens)
    else:
        vectorizer = Token1DVectorizer(transform_fn=baseline.lowercase)
    super().__init__(nctx, {'x': vectorizer})
def main():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir", type=str, required=True, help='Training directory')
    parser.add_argument("--valid_dir", type=str, required=True, help='Validation directory')
    parser.add_argument("--train_md", type=str, help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument("--valid_md", type=str, help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--label_file", type=str, help="JSON file mapping labels to integers", default="labels.json")
    parser.add_argument("--dataset_key", default="tlm", help="dataset key for basedir")
    parser.add_argument("--embed_type", type=str, default='default',
                        choices=["default", "positional", "learned-positional"],
                        help="register label of the embeddings")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--num_train_workers", type=int, default=4, help="Number train workers")
    parser.add_argument("--distribute", type=str, default="mirror", choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep", type=str, help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx", type=int, default=256, help="Max input length")
    parser.add_argument("--file_type", default='tfrecord', choices=['json', 'tfrecord'], help="Glob pattern for data")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch Size")
    parser.add_argument("--subword_model_file", type=str, help="The BPE model file", required=True)
    parser.add_argument("--subword_vocab_file", type=str, help="The BPE subword vocab", required=True)
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--ffn_pdrop", type=float, default=0.0, help="Dropout in the dense stack")
    parser.add_argument("--layer_drop", type=float, default=0.0, help="LayerDrop to apply")
    parser.add_argument("--optim", default="adamw", type=str, help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr", type=float, default=4.0e-4, help="Learning rate")
    parser.add_argument("--clip", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=1.0e-2, help="Weight decay")
    parser.add_argument("--epochs", type=int, default=32, help="Num training epochs")
    parser.add_argument("--restart", type=str2bool, help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps", type=int, default=10000, help="Num warmup steps")
    parser.add_argument("--saves_per_epoch", type=int, default=10, help="The number of checkpoints to save per epoch")
    parser.add_argument('--rpr_k', help='Relative attention positional sizes, pass 0 if you do not want relative attention',
                        type=int, default=[8], nargs='+')
    parser.add_argument('--rpr_value_on', type=str2bool, default=True,
                        help="In relative attention, whether to add the positional correction to the values "
                             "in addition to the correction to the attention matrix")
    parser.add_argument('--windowed_ra', type=str2bool, default=False, help="whether to prevent attention beyond rpr_k")
    parser.add_argument("--strategy", help="Training strategy, defaults to `mirror`", choices=["mirror"])
    parser.add_argument("--npz", help="Should we write out NPZ files?", type=str2bool, default=False)
    parser.add_argument("--tb", help="Turn on tensorboard?", type=str2bool, default=False)
    parser.add_argument("--convert_only", help="Should we just convert this file to NPZ and exit?",
                        type=str2bool, default=False)
    args = parser.parse_args()
    SET_TRAIN_FLAG(True)

    if args.convert_only:
        args.restart = True

    if args.basedir is None:
        args.basedir = f'lm-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"logs/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file, mxlen=args.nctx)
    vocab = {'x': vectorizer.vocab}
    preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, known_vocab=vocab['x'],
                                                       preserve_vocab_indices=True, embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    train_md = args.train_md if args.train_md else os.path.join(args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)

    labels = read_json_tf(args.label_file)
    num_labels = len(labels)

    def dataset_train_fn(input_context):
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = get_dataset(args.train_dir, args.file_type, args.num_train_workers).batch(base_batchsz)
        return ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    train_loader = strategy.experimental_distribute_datasets_from_function(dataset_train_fn)

    def dataset_test_fn(input_context):
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = get_dataset(args.valid_dir, args.file_type, args.num_train_workers, shuffle=False).batch(base_batchsz)
        return ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    valid_loader = strategy.experimental_distribute_datasets_from_function(dataset_test_fn)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        args.rpr_k = None
    elif len(args.rpr_k) == 1:
        args.rpr_k = args.rpr_k[0]

    model = TransformerTagger(num_labels, embeddings, **vars(args))
    logger.info("Loaded model and loss")

    steps_per_epoch = num_train_samples // args.batch_size
    steps_per_valid_epoch = num_valid_samples // args.batch_size
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps.")

    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs, lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)
    optimizer = EagerOptimizer(loss_function, optim=args.optim, lr_function=lr_sched,
                               weight_decay=args.weight_decay, clip=args.clip, lr=args.lr)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer.optimizer, model=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, directory=args.basedir, max_to_keep=5)

    start_epoch = 0
    if args.restart:
        # The global step gets automatically updated here
        # so we dont have to worry about our LR regimen
        checkpoint.restore(checkpoint_manager.latest_checkpoint)
        current_step = optimizer.global_step
        start_epoch = current_step // steps_per_epoch

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = optimizer.update(model, {'x': x}, y, num_replicas)
        return per_replica_loss

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_train_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = loss_function(model, {'x': x}, y) / num_replicas
        return per_replica_loss

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_test_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

    timer = Timer()
    with strategy.scope():
        for epoch in range(start_epoch, args.epochs):
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            metrics = {}
            timer.start()
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):
                try:
                    loss = _distributed_train_step(next(train_iter))
                    avg_loss.update(loss.numpy().item())
                    tf.summary.scalar("train_loss", data=loss, step=optimizer.global_step)
                except Exception as e:
                    logger.error(f"Exception at training step {i+1}/{steps_per_epoch}. Skipping")

                if args.convert_only:
                    logger.warning("Convert only flag specified. Stopping after one step")
                    steps = optimizer.global_step.numpy()
                    npz_checkpoint = os.path.join(args.basedir, f'checkpoint-step-{steps}.npz')
                    save_tlm_output_npz(model, npz_checkpoint)
                    return

                steps = optimizer.global_step.numpy()
                if (steps + 1) % report_on == 0:
                    logger.info(avg_loss)
                if (steps + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logger.info('elapsed time this epoch %d min', elapsed)
                    logger.info('elapsed step time %f steps/min', i / elapsed)
                    checkpoint_manager.save()
                    if args.npz:
                        npz_checkpoint = os.path.join(args.basedir, f'checkpoint-step-{steps}.npz')
                        save_tlm_output_npz(model, npz_checkpoint)

            # How much time elapsed in minutes
            elapsed = timer.elapsed(True)
            # This is the average training token-level loss across all machines
            train_token_loss = avg_loss.avg
            # This is the token-level training perplexity
            train_token_ppl = math.exp(train_token_loss)
            metrics['train_elapsed_min'] = elapsed
            metrics['average_train_loss'] = train_token_loss
            metrics['train_ppl'] = train_token_ppl
            metrics['lr'] = float(lr_sched(tf.cast(optimizer.global_step, tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                try:
                    valid_loss = _distributed_test_step(next(valid_iter))
                    tf.summary.scalar('valid_loss', data=valid_loss, step=optimizer.global_step)
                    avg_valid_loss.update(valid_loss.numpy().item())
                except Exception as e:
                    logger.error(f"Exception at validation step {i+1}/{steps_per_valid_epoch}. Skipping")

            valid_token_loss = avg_valid_loss.avg
            valid_token_ppl = math.exp(valid_token_loss)

            elapsed = timer.elapsed(True)
            metrics['valid_elapsed_min'] = elapsed
            metrics['average_valid_loss'] = valid_token_loss
            metrics['average_valid_word_ppl'] = valid_token_ppl
            logger.info(json.dumps(metrics, indent=4))
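# The trainer above hands a module-level `loss_function` to EagerOptimizer and also calls it
# directly in `_replicated_test_step`, but its definition is not part of this snippet. A hedged
# sketch of one plausible shape, assuming the tagger returns per-token logits of shape
# [B, T, num_labels] and that padded positions carry the Offsets.PAD label:
def loss_function(model, features, labels):
    logits = model(features)
    per_token = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    mask = tf.cast(labels != Offsets.PAD, per_token.dtype)  # ignore padded positions (assumption)
    return tf.reduce_sum(per_token * mask) / tf.reduce_sum(mask)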
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir", type=str, required=True, help='Training directory')
    parser.add_argument("--valid_dir", type=str, required=True, help='Validation directory')
    parser.add_argument("--train_md", type=str, help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument("--valid_md", type=str, help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--dataset_key", default="tlm", help="dataset key for basedir")
    parser.add_argument("--embed_type", type=str, default='default',
                        choices=["default", "positional", "learned-positional"],
                        help="register label of the embeddings")
    parser.add_argument("--gen_d_model", type=int, default=256, help="Model dimension (and embedding dsz)")
    parser.add_argument("--gen_d_ff", type=int, default=1024, help="FFN dimension")
    parser.add_argument("--gen_d_k", type=int, default=None, help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--gen_num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--gen_num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument('--gen_rpr_k', help='Relative attention positional sizes, pass 0 if you do not want relative attention',
                        type=int, default=[8], nargs='+')
    parser.add_argument('--windowed_ra', type=str2bool, default=False, help="whether to prevent attention beyond rpr_k")
    parser.add_argument("--gen_loss_scale", type=float, default=50.0, help="Scaling for loss function")
    parser.add_argument("--gen_dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument('--discrim_rpr_k', help='Relative attention positional sizes, pass 0 if you do not want relative attention',
                        type=int, default=[8], nargs='+')
    parser.add_argument("--discrim_d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--discrim_d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--discrim_d_k", type=int, default=None, help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--discrim_num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--discrim_num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--discrim_dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--num_train_workers", type=int, default=4, help="Number train workers")
    parser.add_argument("--distribute", type=str, default="mirror", choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep", type=str, help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx", type=int, default=256, help="Max input length")
    parser.add_argument("--file_type", default='tfrecord', choices=['json', 'tfrecord'], help="Glob pattern for data")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch Size")
    parser.add_argument("--subword_model_file", type=str, help="The BPE model file", required=True)
    parser.add_argument("--subword_vocab_file", type=str, help="The BPE subword vocab", required=True)
    parser.add_argument("--optim", default="adam", type=str, help="Optimizer to use (defaults to adam)")
    parser.add_argument("--lr", type=float, default=4.0e-4, help="Learning rate")
    parser.add_argument("--clip", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=1.0e-2, help="Weight decay")
    parser.add_argument("--epochs", type=int, default=32, help="Num training epochs")
    parser.add_argument("--restart", type=str2bool, help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps", type=int, default=10000, help="Num warmup steps")
    parser.add_argument("--causal", type=str2bool, default=False, help="Use CLM (causal) instead of MLM")
    parser.add_argument("--saves_per_epoch", type=int, default=10, help="The number of checkpoints to save per epoch")
    parser.add_argument("--strategy", help="Training strategy, defaults to `mirror`", choices=["mirror"])
    parser.add_argument("--npz", help="Should we write out NPZ files?", type=str2bool, default=False)
    parser.add_argument("--tb", help="Turn on tensorboard?", type=str2bool, default=False)
    parser.add_argument("--convert_only", help="Should we just convert this file to NPZ and exit?",
                        type=str2bool, default=False)
    args = parser.parse_args()
    SET_TRAIN_FLAG(True)

    if args.convert_only:
        args.restart = True
        args.npz = True

    if args.basedir is None:
        args.basedir = f'discrim-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"logs/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file, mxlen=args.nctx)
    vocab = {'x': vectorizer.vocab}
    gen_preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.gen_d_model, known_vocab=vocab['x'],
                                                           preserve_vocab_indices=True, embed_type=args.embed_type)
    vocabs = gen_preproc_data['vocab']
    discrim_preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.discrim_d_model, known_vocab=vocab['x'],
                                                               preserve_vocab_indices=True, embed_type=args.embed_type)

    def dataset_train_fn(input_context):
        batch_size = input_context.get_per_replica_batch_size(args.batch_size)
        ds = get_dataset(args.train_dir, args.file_type, args.num_train_workers).batch(batch_size)
        return ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    train_loader = strategy.experimental_distribute_datasets_from_function(dataset_train_fn)

    def dataset_test_fn(input_context):
        batch_size = input_context.get_per_replica_batch_size(args.batch_size)
        ds = get_dataset(args.valid_dir, args.file_type, args.num_train_workers, shuffle=False).batch(batch_size)
        return ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

    valid_loader = strategy.experimental_distribute_datasets_from_function(dataset_test_fn)

    train_md = args.train_md if args.train_md else os.path.join(args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    gen_embeddings = {'x': gen_preproc_data['embeddings']}
    discrim_embeddings = {'x': discrim_preproc_data['embeddings']}
    logger.info("Loaded embeddings")
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    if len(args.gen_rpr_k) == 0 or args.gen_rpr_k[0] < 1:
        gen_rpr_k = None
    elif len(args.gen_rpr_k) == 1:
        gen_rpr_k = args.gen_rpr_k[0]
    else:
        gen_rpr_k = args.gen_rpr_k
    if len(args.discrim_rpr_k) == 0 or args.discrim_rpr_k[0] < 1:
        discrim_rpr_k = None
    elif len(args.discrim_rpr_k) == 1:
        discrim_rpr_k = args.discrim_rpr_k[0]
    else:
        discrim_rpr_k = args.discrim_rpr_k

    gen_model = TransformerMaskedLanguageModel.create(gen_embeddings,
                                                      hsz=args.gen_d_model,
                                                      d_ff=args.gen_d_ff,
                                                      tie_weights=True,
                                                      dropout=args.gen_dropout,
                                                      gpu=False,
                                                      num_heads=args.gen_num_heads,
                                                      layers=args.gen_num_layers,
                                                      rpr_k=gen_rpr_k,
                                                      d_k=args.gen_d_k,
                                                      windowed_ra=args.windowed_ra,
                                                      src_keys=['x'],
                                                      tgt_key='x')
    discrim_model = TransformerDiscriminator(discrim_embeddings,
                                             d_model=args.discrim_d_model,
                                             d_ff=args.discrim_d_ff,
                                             dropout=args.discrim_dropout,
                                             num_heads=args.discrim_num_heads,
                                             layers=args.discrim_num_layers,
                                             rpr_k=discrim_rpr_k,
                                             d_k=args.discrim_d_k)
    logger.info("Loaded model and loss")

    steps_per_epoch = num_train_samples // args.batch_size
    steps_per_valid_epoch = num_valid_samples // args.batch_size
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps.")

    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs, lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)

    mask_value = vocabs.get("[MASK]", vocabs.get("<MASK>", -1))
    if mask_value == -1:
        logger.error("We could not find a suitable masking token in the vocab")
        return

    optimizer, clip = create_keras_optimizer(**vars(args))
    discrim_checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=discrim_model)
    discrim_checkpoint_manager = tf.train.CheckpointManager(discrim_checkpoint,
                                                            directory=os.path.join(args.basedir, 'discrim'),
                                                            max_to_keep=5)
    gen_checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=gen_model)
    gen_checkpoint_manager = tf.train.CheckpointManager(gen_checkpoint,
                                                        directory=os.path.join(args.basedir, 'gen'),
                                                        max_to_keep=5)

    if args.restart:
        # The global step gets automatically updated here
        # so we dont have to worry about our LR regimen
        gen_checkpoint.restore(gen_checkpoint_manager.latest_checkpoint)
        discrim_checkpoint.restore(discrim_checkpoint_manager.latest_checkpoint)

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        noised_x, labels = inputs
        with tf.GradientTape() as tape:
            gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(noised_x, labels, gen_model, discrim_model, mask_value)
            loss_value = (args.gen_loss_scale * gen_loss_step + discrim_loss_step) / num_replicas
        grads = tape.gradient(loss_value, gen_model.trainable_variables + discrim_model.trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, clip)
        optimizer.apply_gradients(zip(grads, gen_model.trainable_variables + discrim_model.trainable_variables))
        return loss_value, gen_loss_step, discrim_loss_step, acc

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        loss, gen_loss, discrim_loss, acc = strategy.run(_replicated_train_step, args=(inputs,))
        sum_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
        sum_gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, gen_loss, axis=None)
        sum_discrim_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, discrim_loss, axis=None)
        sum_acc = strategy.reduce(tf.distribute.ReduceOp.SUM, acc, axis=None)
        return sum_loss, sum_gen_loss, sum_discrim_loss, sum_acc

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        noised_x, labels = inputs
        gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(noised_x, labels, gen_model, discrim_model, mask_value)
        loss_value = (args.gen_loss_scale * gen_loss_step + discrim_loss_step) / num_replicas
        return loss_value, gen_loss_step, discrim_loss_step, acc

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        loss, gen_loss, discrim_loss, acc = strategy.run(_replicated_test_step, args=(inputs,))
        sum_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
        sum_gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, gen_loss, axis=None)
        sum_discrim_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, discrim_loss, axis=None)
        sum_acc = strategy.reduce(tf.distribute.ReduceOp.SUM, acc, axis=None)
        return sum_loss, sum_gen_loss, sum_discrim_loss, sum_acc

    # This is the training loop
    start_epoch = 0
    timer = Timer()
    with strategy.scope():
        for epoch in range(start_epoch, args.epochs):
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            avg_gen_loss = Average('average_gen_loss')
            avg_discrim_loss = Average('average_discrim_loss')
            avg_acc = Average('average_train_acc')
            metrics = {}
            timer.start()
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):
                loss, gen_loss, discrim_loss, acc = _distributed_train_step(next(train_iter))
                avg_loss.update(loss.numpy().item())
                avg_gen_loss.update(gen_loss.numpy().item())
                avg_discrim_loss.update(discrim_loss.numpy().item())
                avg_acc.update(acc.numpy().item())

                tf.summary.scalar("train_loss", data=loss, step=optimizer.iterations)
                tf.summary.scalar("train_gen_loss", data=gen_loss, step=optimizer.iterations)
                tf.summary.scalar("train_discrim_loss", data=discrim_loss, step=optimizer.iterations)
                tf.summary.scalar("train_acc", data=acc, step=optimizer.iterations)

                if args.convert_only:
                    logger.warning("Convert only flag specified. Stopping after one step")
                    steps = optimizer.iterations.numpy()
                    npz_checkpoint = os.path.join(args.basedir, f'discrim-step-{steps}.npz')
                    save_tlm_npz(discrim_model, npz_checkpoint)
                    npz_checkpoint = os.path.join(args.basedir, f'gen-step-{steps}.npz')
                    save_tlm_npz(gen_model, npz_checkpoint)
                    return

                if (i + 1) % report_on == 0:
                    logging.info(avg_loss)
                    logging.info(avg_gen_loss)
                    logging.info(avg_discrim_loss)
                    logging.info(avg_acc)
                if (i + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logging.info('elapsed time this epoch %d min', elapsed)
                    logging.info('elapsed step time %f steps/min', i / elapsed)
                    gen_checkpoint_manager.save()
                    discrim_checkpoint_manager.save()
                    if args.npz:
                        steps = optimizer.iterations.numpy()
                        npz_checkpoint = os.path.join(args.basedir, f'discrim-step-{steps}.npz')
                        save_tlm_npz(discrim_model, npz_checkpoint)
                        npz_checkpoint = os.path.join(args.basedir, f'gen-step-{steps}.npz')
                        save_tlm_npz(gen_model, npz_checkpoint)

            # Average training losses across all replicas for this epoch
            metrics['train_elapsed_min'] = timer.elapsed(True)
            metrics['average_train_loss'] = avg_loss.avg
            metrics['average_gen_loss'] = avg_gen_loss.avg
            metrics['average_discrim_loss'] = avg_discrim_loss.avg
            metrics['average_train_acc'] = avg_acc.avg
            metrics['lr'] = float(lr_sched(tf.cast(optimizer.iterations, tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            avg_valid_gen_loss = Average('average_valid_gen_loss')
            avg_valid_discrim_loss = Average('average_valid_discrim_loss')
            avg_valid_acc = Average('average_valid_acc')
            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                valid_loss, valid_gen_loss, valid_discrim_loss, valid_acc = _distributed_test_step(next(valid_iter))
                tf.summary.scalar('valid_loss', data=valid_loss, step=optimizer.iterations)
                tf.summary.scalar('valid_gen_loss', data=valid_gen_loss, step=optimizer.iterations)
                tf.summary.scalar('valid_discrim_loss', data=valid_discrim_loss, step=optimizer.iterations)
                tf.summary.scalar('valid_acc', data=valid_acc, step=optimizer.iterations)
                avg_valid_loss.update(valid_loss.numpy().item())
                avg_valid_gen_loss.update(valid_gen_loss.numpy().item())
                avg_valid_discrim_loss.update(valid_discrim_loss.numpy().item())
                avg_valid_acc.update(valid_acc.numpy().item())

            metrics['valid_elapsed_min'] = timer.elapsed(True)
            metrics['average_valid_loss'] = avg_valid_loss.avg
            metrics['average_valid_gen_loss'] = avg_valid_gen_loss.avg
            metrics['average_valid_discrim_loss'] = avg_valid_discrim_loss.avg
            metrics['average_valid_acc'] = avg_valid_acc.avg
            logger.info(json.dumps(metrics, indent=4))
def run():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--checkpoint", type=str, help='Checkpoint name or directory to load')
    parser.add_argument("--sample", type=str2bool, help='Sample from the decoder? Defaults to `false`', default=0)
    parser.add_argument("--vocab", type=str, help='Vocab file to load', required=False)
    parser.add_argument("--input", type=str, default='hello how are you ?')
    parser.add_argument("--dataset_cache", type=str, default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--d_k", type=int, default=None, help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--nctx", type=int, default=256, help="Max context length (for both encoder and decoder)")
    parser.add_argument("--embed_type", type=str, default='default',
                        help="register label of the embeddings, so far supports positional or learned-positional")
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--batchsz", help="Size of a batch to pass at once", default=4, type=int)
    parser.add_argument("--beamsz", help="Size of beam to use", default=4, type=int)
    parser.add_argument("--activation", type=str, default='relu')
    parser.add_argument('--rpr_k', help='Relative attention positional sizes, pass 0 if you do not want relative attention',
                        type=int, default=[8] * 8, nargs='+')
    #parser.add_argument("--go_token", default="<GO>")
    parser.add_argument("--end_token", default="<EOS>")
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--show_query", type=str2bool, default=False, help="Show the original query as well")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    args = parser.parse_args()

    if torch.cuda.device_count() == 1:
        torch.cuda.set_device(0)
        args.device = torch.device("cuda", 0)

    if os.path.isdir(args.checkpoint):
        checkpoint, _ = find_latest_checkpoint(args.checkpoint)
        logger.warning("Found latest checkpoint %s", checkpoint)
    else:
        checkpoint = args.checkpoint

    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx, emit_end_tok=args.end_token)
    vocab = vectorizer.vocab
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, counts=False, known_vocab=vocab,
                                                       embed_type=args.embed_type)
    embeddings = preproc_data['embeddings']
    vocab = preproc_data['vocab']
    model = create_model(embeddings, d_model=args.d_model, d_ff=args.d_ff, num_heads=args.num_heads,
                         num_layers=args.num_layers, rpr_k=args.rpr_k, d_k=args.d_k,
                         checkpoint_name=checkpoint, activation=args.activation, device=args.device)
    model.to(args.device)

    index2word = revlut(vocab)
    wf = None
    if args.output_file:
        wf = open(args.output_file, "w")

    batches = []
    if os.path.exists(args.input) and os.path.isfile(args.input):
        with open(args.input, 'rt', encoding='utf-8') as f:
            batch = []
            for line in f:
                text = line.strip().split()
                if len(batch) == args.batchsz:
                    batches.append(batch)
                    batch = []
                batch.append(text)
            # Flush any partial final batch
            if len(batch) > 0:
                batches.append(batch)
    else:
        batch = [args.input.split()]
        batches.append(batch)

    for queries in batches:
        outputs = decode_sentences(model, vectorizer, queries, vocab, index2word, args.beamsz)

        if args.show_query:
            for query, output in zip(queries, outputs):
                print(f"[Query] {query}")
                print(f"[Response] {output}")
        elif wf:
            for query, output in zip(queries, outputs):
                wf.write(f'{output}\n')
            wf.flush()
        else:
            for query, output in zip(queries, outputs):
                print(output)

    if wf:
        wf.close()
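# Hedged example invocation of the script above. The script/module name is a placeholder;
# only the flags shown are taken from the argument parser above, and the file paths are
# illustrative:
#
#   python decode_transformer_seq2seq.py \
#       --checkpoint ./checkpoints \
#       --subword_model_file codes.30k --subword_vocab_file vocab.30k \
#       --input queries.txt --output_file responses.txt --beamsz 4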
def run(input_files=[], input_pattern='*.txt', codes=None, vocab=None, nctx=256, fmt='json',
        fields=['x_str', 'y_str'], output=None, prefix=None, suffix=None, max_file_size=100,
        tok_on_eol="<EOS>", cased=True, mask_type="mlm", module=None, pad_y=True,
        extra_tokens=['[CLS]', '[MASK]'], world_size=1, world_offset=0, input_field='text',
        tokenizer_type=None, **kwargs):

    def parse_json_line(x):
        return json.loads(x)[input_field]

    if module:
        logger.warning("Loading custom user module %s for masking rules and tokenizers", module)
        baseline.import_user_module(module)

    get_line = lambda x: x.strip()
    if os.path.isdir(input_files):
        if '.json' in input_pattern:
            get_line = parse_json_line
        # Derive the default output location before input_files is re-bound to the glob results
        if not output:
            output = os.path.join(input_files, 'records')
        input_files = list(glob.glob(os.path.join(input_files, input_pattern)))
    else:
        if '.json' in input_files:
            get_line = parse_json_line
        if not output:
            output = f'{input_files}.records'
        input_files = [input_files]

    if len(input_files) < world_size:
        raise Exception(f"The number of input shards ({len(input_files)}) should be at least as large as the world_size: {world_size}")

    logger.info('Output [%s]', output)
    transform = baseline.lowercase if not cased else lambda x: x
    vectorizer = BPEVectorizer1D(transform_fn=transform, model_file=codes, vocab_file=vocab, mxlen=1024,
                                 extra_tokens=extra_tokens)

    lookup_indices = []
    indices2word = baseline.revlut(vectorizer.vocab)
    root_dir = os.path.dirname(output)
    tokenizer = create_tokenizer(tokenizer_type)
    masking = create_masking(mask_type, vectorizer.vocab, pad_y)
    if not os.path.exists(root_dir):
        os.makedirs(root_dir)

    if prefix:
        nctx -= 1
        prefix = vectorizer.vocab[prefix]

    if suffix:
        nctx -= 1
        suffix = vectorizer.vocab[suffix]

    fw = create_file_writer(fmt, output, fields, max_file_size, 1000 * world_offset)
    num_samples = 0
    for i, text in enumerate(input_files):
        if i % world_size != world_offset:
            continue
        with TextFile(text) as rf:
            print(f"Reading from {text}...")
            for line in rf:
                to_bpe = tokenizer(get_line(line))
                if not to_bpe:
                    continue
                to_bpe += [tok_on_eol]

                # `ids` is the vectorized subword id array, `available` the number of valid ids
                ids, available = vectorizer.run(to_bpe, vectorizer.vocab)
                while available > 0:
                    if len(lookup_indices) == nctx:
                        record = create_record(lookup_indices, indices2word, prefix, suffix, masking=masking)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []
                    needed = nctx - len(lookup_indices)
                    if available >= needed:
                        lookup_indices += ids[:needed].tolist()
                        ids = ids[needed:]
                        available -= needed
                        record = create_record(lookup_indices, indices2word, prefix, suffix, masking=masking)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []
                    # The amount available is less than what we need, so read the whole thing
                    else:
                        lookup_indices += ids[:available].tolist()
                        available = 0

    fw.close()
    f_name = f'md-{world_offset}.yml' if world_size > 1 else 'md.yml'
    write_yaml({'num_samples': num_samples}, os.path.join(root_dir, f_name))
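# Hedged usage sketch: calling run() above to shard a directory of text into MLM records.
# All paths are placeholders; `codes.30k`/`vocab.30k` are assumed fastBPE artifacts.
run(input_files='/data/corpus', input_pattern='*.txt',
    codes='codes.30k', vocab='vocab.30k',
    nctx=256, fmt='tfrecord', output='/data/records/corpus',
    prefix='[CLS]', suffix='<EOS>', mask_type='mlm')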
if os.path.isdir(args.input_files):
    import glob
    input_files = list(glob.glob(os.path.join(args.input_files, args.input_pattern)))
    if not args.output:
        args.output = os.path.join(args.input_files, 'records')
else:
    input_files = [args.input_files]
    if not args.output:
        args.output = f'{args.input_files}.records'

print(args.output)
transform = baseline.lowercase if not args.cased else lambda x: x
vectorizer = BPEVectorizer1D(transform_fn=transform, model_file=args.codes, vocab_file=args.vocab, mxlen=1024)
lookup_indices = []
words = []
indices2word = baseline.revlut(vectorizer.vocab)
vocab_size = max(vectorizer.vocab.values()) + 1
nctx = args.nctx
mask_value = vectorizer.vocab['[MASK]']
prefix = suffix = None
root_dir = os.path.dirname(args.output)
if not os.path.exists(root_dir):
    os.makedirs(root_dir)
if args.prefix:
    nctx -= 1
parser.add_argument("--valid_split", type=float, default=0.05) parser.add_argument("--prefix", default="<GO>") parser.add_argument("--suffix", default="<EOS>") parser.add_argument("--pg_name", choices=["tqdm", "default"], default="default") args = parser.parse_args() annot_files = list(Path(args.annot_files).iterdir()) valid_split = int(len(annot_files) * args.valid_split) VALID_FILES = annot_files[:valid_split] TRAIN_FILES = annot_files[valid_split:] VECTORIZER = BPEVectorizer1D( transform_fn=baseline.lowercase if not args.cased else lambda x: x, model_file=args.codes, vocab_file=args.vocab, mxlen=1024) NCTX = args.nctx - 2 PREFIX = ( VECTORIZER.vocab[args.prefix], Offsets.GO, ) SUFFIX = ( VECTORIZER.vocab[args.suffix], Offsets.EOS, ) DOC2WORD = read_vocab_file(args.document_vocab) label2word = read_vocab_file(args.label_vocab) LABELS = {Offsets.VALUES[k]: k for k in range(Offsets.OFFSET)}
def main():
    parser = argparse.ArgumentParser(description='Convert text into MLM fixed width contexts')
    parser.add_argument('--input_files',
                        help='The text to classify as a string, or a path to a file with each line as an example',
                        type=str)
    parser.add_argument('--annot_files',
                        help='Directory containing the annotation files',
                        type=str)
    parser.add_argument('--codes', help='BPE codes')
    parser.add_argument('--vocab', help='BPE vocab')
    parser.add_argument("--nctx", type=int, default=256, help="Max input length")
    parser.add_argument("--fmt", type=str, default='json', choices=['json', 'tsv', 'tfrecord'])
    parser.add_argument("--fields", type=str, nargs="+", default=["x_str", "y_str"])
    parser.add_argument("--output_dir", type=str, help="Output base name, e.g. /path/to/output/record")
    parser.add_argument("--max_file_size", type=int, default=100, help="Shard size, defaults to 100MB")
    parser.add_argument("--stride", type=int, help="Tokens to stride before next read, defaults to `nctx`")
    parser.add_argument("--tok_on_eol", type=str, default="<EOS>")
    parser.add_argument("--cased", type=baseline.str2bool, default=True)
    parser.add_argument("--document_vocab", type=str, default="document.vocab")
    parser.add_argument("--label_vocab", type=str, default="label.vocab")
    parser.add_argument("--valid_split", type=float, default=0.05)
    parser.add_argument("--prefix", default="<GO>")
    parser.add_argument("--suffix", default="<EOS>")
    parser.add_argument("--pg_name", choices=["tqdm", "default"], default="default")
    args = parser.parse_args()

    annot_files = list(Path(args.annot_files).iterdir())
    valid_split = int(len(annot_files) * args.valid_split)
    VALID_FILES = annot_files[:valid_split]
    TRAIN_FILES = annot_files[valid_split:]

    VECTORIZER = BPEVectorizer1D(transform_fn=baseline.lowercase if not args.cased else lambda x: x,
                                 model_file=args.codes, vocab_file=args.vocab, mxlen=1024)
    NCTX = args.nctx - 2
    PREFIX = (VECTORIZER.vocab[args.prefix], Offsets.GO,)
    SUFFIX = (VECTORIZER.vocab[args.suffix], Offsets.EOS,)
    DOC2WORD = read_vocab_file(args.document_vocab)
    label2word = read_vocab_file(args.label_vocab)
    LABELS = {Offsets.VALUES[k]: k for k in range(Offsets.OFFSET)}
    for label in label2word.values():
        for prefix in ["B", "I", "E", "S"]:
            LABELS[f"{prefix}-{label}"] = len(LABELS)
    LABELS["O"] = len(LABELS)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    write_json(LABELS, os.path.join(args.output_dir, 'labels.json'))
    valid_dir = os.path.join(args.output_dir, 'valid')
    train_dir = os.path.join(args.output_dir, 'train')
    makedir_if_none(args.output_dir)
    makedir_if_none(train_dir)
    makedir_if_none(valid_dir)

    logger.info("Converting validation files")
    fw_valid = create_file_writer(args.fmt, os.path.join(valid_dir, 'valid'), args.fields, args.max_file_size)
    write_files(VALID_FILES, args.input_files, fw_valid, valid_dir, args.pg_name)

    logger.info("Converting training files")
    fw_train = create_file_writer(args.fmt, os.path.join(train_dir, 'train'), args.fields, args.max_file_size)
    write_files(TRAIN_FILES, args.input_files, fw_train, train_dir, args.pg_name)