Example #1
 def __init__(self,
              src_nctx=64,
              tgt_nctx=64,
              src_begin_tok=[],
              src_end_tok=['<EOS>'],
              tgt_begin_tok=['<GO>'],
              tgt_end_tok=['<EOS>'],
              model_file=None,
              vocab_file=None,
              file_type='txt',
              reader_type="ntp",
              record_keys=None,
              lower=False):
     self.src_nctx = src_nctx
     self.tgt_nctx = tgt_nctx
     self.pattern = f'*.{file_type}'
     self.reader_type = reader_type
     if not src_begin_tok and self.reader_type == 'lang':
         src_begin_tok = ['[CLS]']
     self.record_keys = record_keys if record_keys else ['x', 'y']
     transform_fn = None if not lower else baseline.lowercase
     self.src_vectorizer = BPEVectorizer1D(model_file=model_file,
                                           vocab_file=vocab_file,
                                           mxlen=src_nctx,
                                           emit_begin_tok=src_begin_tok,
                                           emit_end_tok=src_end_tok,
                                           transform_fn=transform_fn)
     self.tgt_vectorizer = BPEVectorizer1D(model_file=model_file,
                                           vocab_file=vocab_file,
                                           mxlen=tgt_nctx,
                                           emit_begin_tok=tgt_begin_tok,
                                           emit_end_tok=tgt_end_tok,
                                           transform_fn=transform_fn)
Example #2
 def _create_subword_vectorizer(self,
                                mxlen=None,
                                model_file=None,
                                vocab_file=None,
                                emit_begin_tok=None,
                                emit_end_tok=None,
                                transform_fn=None,
                                extra_tokens=None):
     if self.subword_type == 'bpe':
         return BPEVectorizer1D(model_file=model_file,
                                vocab_file=vocab_file,
                                mxlen=mxlen,
                                emit_begin_tok=emit_begin_tok,
                                emit_end_tok=emit_end_tok,
                                transform_fn=transform_fn,
                                extra_tokens=extra_tokens)
      elif self.subword_type == 'wordpiece':
         return WordpieceVectorizer1D(vocab_file=vocab_file,
                                      mxlen=mxlen,
                                      emit_begin_tok=emit_begin_tok,
                                      emit_end_tok=emit_end_tok,
                                      transform_fn=transform_fn)
     else:
         from baseline.vectorizers import SentencePieceVectorizer1D
         return SentencePieceVectorizer1D(model_file=model_file,
                                          mxlen=mxlen,
                                          emit_begin_tok=emit_begin_tok,
                                          emit_end_tok=emit_end_tok,
                                          transform_fn=transform_fn,
                                          extra_tokens=extra_tokens)
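
For reference, a hypothetical construction and use of the BPE path in the factory above. The artifact paths are placeholders (any trained fastBPE codes/vocab pair), and the import path simply mirrors the lazy SentencePiece import in the method:

from baseline.vectorizers import BPEVectorizer1D

# Placeholder subword artifacts; substitute your own fastBPE files.
vec = BPEVectorizer1D(model_file='codes.30k', vocab_file='vocab.30k', mxlen=64)
# run() maps a token list to a fixed-width id array plus the valid length
# (the same entry point Example #11 below drives).
ids, valid_length = vec.run('hello , how are you ?'.split(), vec.vocab)
# ids has shape (mxlen,) and is zero-padded; valid_length counts the real entries.
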
Example #3
def run():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--checkpoint", type=str, help='Checkpoint name or directory to load')
    parser.add_argument("--sample", type=str2bool, help='Sample from the decoder?  Defaults to `false`', default=0)
    parser.add_argument("--query", type=str, default='hello , <unk> are you today ?')
    parser.add_argument("--dataset_cache", type=str, default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--d_k", type=int, default=None, help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--nctx", type=int, default=128, help="Max context length (for both encoder and decoder)")
    parser.add_argument("--embed_type", type=str, default='default',
                        help="register label of the embeddings, so far support positional or learned-positional")
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--use_cls", type=str2bool, default=False)
    parser.add_argument('--end_token', default='<EOU>')
    parser.add_argument("--activation", type=str, default='gelu')
    parser.add_argument('--rpr_k', help='Relative attention positional sizes; pass 0 if you do not want relative attention',
                        type=int, default=[8], nargs='+')
    parser.add_argument("--y_only", type=str2bool, default=False)
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    args = parser.parse_args()

    if torch.cuda.device_count() == 1:
        torch.cuda.set_device(0)
        args.device = torch.device("cuda", 0)


    if os.path.isdir(args.checkpoint):
        checkpoint, _ = find_latest_checkpoint(args.checkpoint)
        logger.warning("Found latest checkpoint %s", checkpoint)
    else:
        checkpoint = args.checkpoint

    cls = None if not args.use_cls else '[CLS]'
    end = args.end_token
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file, mxlen=args.nctx, emit_begin_tok=cls, emit_end_tok=end)
    vocab = vectorizer.vocab.copy()
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, counts=False, known_vocab=vocab, embed_type=args.embed_type, preserve_vocab_indices=True)
    embeddings = preproc_data['embeddings']
    vocab = preproc_data['vocab']
    model = create_model(embeddings, d_model=args.d_model, d_ff=args.d_ff, num_heads=args.num_heads, num_layers=args.num_layers,
                         rpr_k=args.rpr_k, d_k=args.d_k, checkpoint_name=checkpoint, activation=args.activation)
    model.to(args.device)

    index2word = revlut(vocab)
    print('[Query]', args.query)
    bpe_out = decode_sentence(model, vectorizer, args.query.split(), vocab, index2word, args.device, sample=args.sample, y_only=args.y_only)

    print('[Response]', ' '.join(bpe_out))
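
These scripts lean on two small helpers that the examples never define: str2bool for argparse flags and revlut for inverting a vocab. Minimal sketches of what they typically look like (the real versions live in the baseline/eight_mile utilities; treat these as assumptions):

import argparse

def str2bool(v):
    # Parse common truthy/falsy strings for argparse flags.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def revlut(lut):
    # Invert a {token: index} vocab into {index: token} for decoding.
    return {v: k for k, v in lut.items()}
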
Example #4
def test_bpe_label_indices_generator():
    pytest.importorskip("fastBPE")
    num_tokens = random.randint(1, 100)
    tokens = [random_string() for _ in range(num_tokens)]
    bpe = BPEVectorizer1D(model_file=os.path.join(TEST_DATA, "codes.30k"), vocab_file=os.path.join(TEST_DATA, "vocab.30k"))
    tokens = add_specials(tokens, bpe.special_tokens)
    bpe_toks, gold_indices = bpe_tokens(tokens, specials=bpe.special_tokens)
    indices = bpe.valid_label_indices((t for t in bpe_toks))
    assert len(indices) == num_tokens
    assert indices == gold_indices
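
The idea under test: when BPE splits a word into pieces, only the first piece should carry a label, so valid_label_indices returns the position of each word's first subword. A rough, simplified sketch of that logic, assuming fastBPE-style '@@' continuation markers and ignoring the special-token filtering the real method also performs:

def first_subword_indices(subword_tokens):
    # A token ending in '@@' continues into the next piece, so the piece
    # after it is NOT the start of a new word.
    indices = []
    continues = False
    for i, tok in enumerate(subword_tokens):
        if not continues:
            indices.append(i)
        continues = tok.endswith('@@')
    return indices

print(first_subword_indices(['hel@@', 'lo', 'wor@@', 'ld']))  # [0, 2]
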
Example #5
 def _create_subword_vectorizer(self, mxlen=None, model_file=None, vocab_file=None, emit_begin_tok=None, emit_end_tok=None, transform_fn=None, extra_tokens=None):
     if self.subword_type == 'wordpiece':
         return WordpieceVectorizer1D(
             vocab_file=vocab_file,
             mxlen=mxlen,
             emit_begin_tok=emit_begin_tok,
             emit_end_tok=emit_end_tok,
             transform_fn=transform_fn)
     return BPEVectorizer1D(model_file=model_file, vocab_file=vocab_file, mxlen=mxlen,
                            emit_begin_tok=emit_begin_tok, emit_end_tok=emit_end_tok,
                            transform_fn=transform_fn, extra_tokens=extra_tokens)
Example #6
 def __init__(self,
              nctx=64,
              model_file=None,
              vocab_file=None,
              pattern='*.txt',
              reader_type="ntp"):
     self.nctx = nctx
     self.pattern = pattern
     self.reader_type = reader_type
     self.vectorizer = BPEVectorizer1D(model_file=model_file,
                                       vocab_file=vocab_file,
                                       mxlen=nctx)
Example #7
    def __init__(self, nctx, use_subword=None, model_file=None, vocab_file=None, special_tokens=None):
        """Create a reader with a context window that reads words

        :param nctx: The context window length
        :param use_subword: If this is not none, it should be either 'bpe' or 'wordpiece'
        """
        self.use_subword = use_subword

        if self.use_subword == 'bpe':
            vectorizer = BPEVectorizer1D(model_file=model_file, vocab_file=vocab_file)
        elif self.use_subword == 'wordpiece':
            vectorizer = WordpieceVectorizer1D(embed_file=model_file, vocab_file=vocab_file,
                                               special_tokens=special_tokens)
        else:
            vectorizer = Token1DVectorizer(transform_fn=baseline.lowercase)
        super().__init__(nctx, {'x': vectorizer})
Example #8
def main():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir",
                        type=str,
                        required=True,
                        help='Training directory')
    parser.add_argument("--valid_dir",
                        type=str,
                        required=True,
                        help='Validation directory')
    parser.add_argument(
        "--train_md",
        type=str,
        help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument(
        "--valid_md",
        type=str,
        help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--label_file",
                        type=str,
                        help="JSON file mapping labels to integers",
                        default="labels.json")
    parser.add_argument("--dataset_key",
                        default="tlm",
                        help="dataset key for basedir")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument("--distribute",
                        type=str,
                        default="mirror",
                        choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep",
                        type=str,
                        help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--file_type",
                        default='tfrecord',
                        choices=['json', 'tfrecord'],
                        help="Glob pattern for data")
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="The BPE model file",
                        required=True)
    parser.add_argument("--subword_vocab_file",
                        type=str,
                        help="The BPE subword vocab",
                        required=True)
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--ffn_pdrop",
                        type=float,
                        default=0.0,
                        help="Dropout in the dense stack")
    parser.add_argument("--layer_drop",
                        type=float,
                        default=0.0,
                        help="LayerDrop to apply")
    parser.add_argument("--optim",
                        default="adamw",
                        type=str,
                        help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=1.0e-2,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart",
        type=str2bool,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=10,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument(
        '--rpr_k',
        help=
        'Relative attention positional sizes; pass 0 if you do not want relative attention',
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument(
        '--rpr_value_on',
        type=str2bool,
        default=True,
        help=
        "In relative attention, whether add positional correction to values in addition to the "
        "correction to attention matrix")
    parser.add_argument('--windowed_ra',
                        type=str2bool,
                        default=False,
                        help="whether prevent attention beyond rpr_k")
    parser.add_argument("--strategy",
                        help="Training strategy, defaults to `mirror`",
                        choices=["mirror"])
    parser.add_argument("--npz",
                        help="Should we write out NPZ files?",
                        type=str2bool,
                        default=False)
    parser.add_argument("--tb",
                        help="Turn on tensorboard?",
                        type=str2bool,
                        default=False)
    parser.add_argument(
        "--convert_only",
        help="Should we just convert this file to NPZ and exit?",
        type=str2bool,
        default=False)
    args = parser.parse_args()
    SET_TRAIN_FLAG(True)

    if args.convert_only:
        args.restart = True

    if args.basedir is None:
        args.basedir = f'lm-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"logs/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file,
                                 vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx)
    vocab = {'x': vectorizer.vocab}
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    train_md = args.train_md if args.train_md else os.path.join(
        args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(
        args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)
    labels = read_json_tf(args.label_file)
    num_labels = len(labels)

    def dataset_train_fn(input_context):
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = get_dataset(args.train_dir, args.file_type,
                         args.num_train_workers).batch(base_batchsz)
        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    train_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_train_fn)

    def dataset_test_fn(input_context):
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = get_dataset(args.valid_dir,
                         args.file_type,
                         args.num_train_workers,
                         shuffle=False).batch(base_batchsz)

        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    valid_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_test_fn)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)
    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        args.rpr_k = None
    elif len(args.rpr_k) == 1:
        args.rpr_k = args.rpr_k[0]

    model = TransformerTagger(num_labels, embeddings, **vars(args))

    logger.info("Loaded model and loss")

    steps_per_epoch = num_train_samples // args.batch_size
    steps_per_valid_epoch = num_valid_samples // args.batch_size
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )

    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs,
                                              lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps,
                                                    lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)
    optimizer = EagerOptimizer(loss_function,
                               optim=args.optim,
                               lr_function=lr_sched,
                               weight_decay=args.weight_decay,
                               clip=args.clip,
                               lr=args.lr)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer.optimizer,
                                     model=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    directory=args.basedir,
                                                    max_to_keep=5)

    start_epoch = 0
    if args.restart:
        # The global step gets automatically updated here
        # so we don't have to worry about our LR regimen
        checkpoint.restore(checkpoint_manager.latest_checkpoint)
        current_step = optimizer.global_step
        start_epoch = current_step // steps_per_epoch

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = optimizer.update(model, {'x': x}, y, num_replicas)
        return per_replica_loss

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_train_step,
                                        args=(inputs, ))
        return strategy.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_loss,
                               axis=None)

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = loss_function(model, {'x': x}, y) / num_replicas
        return per_replica_loss

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_test_step, args=(inputs, ))
        return strategy.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_loss,
                               axis=None)

    timer = Timer()
    with strategy.scope():

        for epoch in range(start_epoch, args.epochs):
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            metrics = {}
            timer.start()
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):

                try:
                    loss = _distributed_train_step(next(train_iter))
                    avg_loss.update(loss.numpy().item())
                    tf.summary.scalar("train_loss",
                                      data=loss,
                                      step=optimizer.global_step)
                except Exception as e:
                    logger.error(
                        f"Exception at training step {i + 1}/{steps_per_epoch}: {e}. Skipping"
                    )
                if args.convert_only:
                    logger.warning(
                        "Convert only flag specified.  Stopping after one step"
                    )
                    steps = optimizer.global_step.numpy()
                    npz_checkpoint = os.path.join(
                        args.basedir, f'checkpoint-step-{steps}.npz')
                    save_tlm_output_npz(model, npz_checkpoint)
                    return

                steps = optimizer.global_step.numpy()
                if (steps + 1) % report_on == 0:
                    logger.info(avg_loss)
                if (steps + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logger.info('elapsed time this epoch %d min', elapsed)
                    logger.info('step rate %f steps/min', i / elapsed)
                    checkpoint_manager.save()
                    if args.npz:

                        npz_checkpoint = os.path.join(
                            args.basedir, f'checkpoint-step-{steps}.npz')
                        save_tlm_output_npz(model, npz_checkpoint)

            # How much time elapsed in minutes
            elapsed = timer.elapsed(True)
            train_token_loss = avg_loss.avg
            # This is the average training token-level loss across all machines
            # This is the token-level training perplexity
            train_token_ppl = math.exp(train_token_loss)
            metrics['train_elapsed_min'] = elapsed
            metrics['average_train_loss'] = train_token_loss
            metrics['train_ppl'] = train_token_ppl
            metrics['lr'] = float(
                lr_sched(tf.cast(optimizer.global_step,
                                 tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                try:
                    valid_loss = _distributed_test_step(next(valid_iter))
                    tf.summary.scalar('valid_loss',
                                      data=valid_loss,
                                      step=optimizer.global_step)
                    avg_valid_loss.update(valid_loss.numpy().item())
                except Exception as e:
                    logger.error(
                        f"Exception at validation step {i + 1}/{steps_per_valid_epoch}: {e}. Skipping"
                    )

            valid_token_loss = avg_valid_loss.avg
            valid_token_ppl = math.exp(valid_token_loss)

            elapsed = timer.elapsed(True)

            metrics['valid_elapsed_min'] = elapsed
            metrics['average_valid_loss'] = valid_token_loss
            metrics['average_valid_word_ppl'] = valid_token_ppl
            logger.info(json.dumps(metrics, indent=4))
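
Both training scripts size their epochs from a get_num_samples helper. Given that the preprocessing run() in Example #11 below writes md.yml files containing a num_samples key, a plausible sketch of that helper (an assumption, not the library's verbatim code) is:

import yaml

def get_num_samples(md_file):
    # md.yml is produced by the preprocessing step and holds a single
    # `num_samples` count for the shard set.
    with open(md_file) as f:
        return yaml.safe_load(f)['num_samples']
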
Example #9
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir",
                        type=str,
                        required=True,
                        help='Training directory')
    parser.add_argument("--valid_dir",
                        type=str,
                        required=True,
                        help='Validation directory')
    parser.add_argument(
        "--train_md",
        type=str,
        help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument(
        "--valid_md",
        type=str,
        help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--dataset_key",
                        default="tlm",
                        help="dataset key for basedir")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")

    parser.add_argument("--gen_d_model",
                        type=int,
                        default=256,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--gen_d_ff",
                        type=int,
                        default=1024,
                        help="FFN dimension")
    parser.add_argument(
        "--gen_d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--gen_num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--gen_num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument(
        '--gen_rpr_k',
        help=
        'Relative attention positional sizes; pass 0 if you do not want relative attention',
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument('--windowed_ra',
                        type=str2bool,
                        default=False,
                        help="whether prevent attention beyond rpr_k")
    parser.add_argument("--gen_loss_scale",
                        type=float,
                        default=50.0,
                        help="Scaling for loss function")
    parser.add_argument("--gen_dropout",
                        type=float,
                        default=0.1,
                        help="Dropout")

    parser.add_argument(
        '--discrim_rpr_k',
        help=
        'Relative attention positional sizes; pass 0 if you do not want relative attention',
        type=int,
        default=[8],
        nargs='+')

    parser.add_argument("--discrim_d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--discrim_d_ff",
                        type=int,
                        default=2048,
                        help="FFN dimension")
    parser.add_argument(
        "--discrim_d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--discrim_num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--discrim_num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--discrim_dropout",
                        type=float,
                        default=0.1,
                        help="Dropout")

    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument("--distribute",
                        type=str,
                        default="mirror",
                        choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep",
                        type=str,
                        help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--file_type",
                        default='tfrecord',
                        choices=['json', 'tfrecord'],
                        help="Glob pattern for data")
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="The BPE model file",
                        required=True)
    parser.add_argument("--subword_vocab_file",
                        type=str,
                        help="The BPE subword vocab",
                        required=True)
    parser.add_argument("--optim",
                        default="adam",
                        type=str,
                        help="Optimizer to use (defaults to adam)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=1.0e-2,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart",
        type=str2bool,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--causal",
                        type=str2bool,
                        default=False,
                        help="Use CLM (causal) instead of MLM")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=10,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument("--strategy",
                        help="Training strategy, defaults to `mirror`",
                        choices=["mirror"])
    parser.add_argument("--npz",
                        help="Should we write out NPZ files?",
                        type=str2bool,
                        default=False)
    parser.add_argument("--tb",
                        help="Turn on tensorboard?",
                        type=str2bool,
                        default=False)
    parser.add_argument(
        "--convert_only",
        help="Should we just convert this file to NPZ and exit?",
        type=str2bool,
        default=False)
    args = parser.parse_args()
    SET_TRAIN_FLAG(True)

    if args.convert_only:
        args.restart = True
        args.npz = True

    if args.basedir is None:
        args.basedir = f'discrim-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"logs/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file,
                                 vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx)
    vocab = {'x': vectorizer.vocab}
    gen_preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.gen_d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)

    vocabs = gen_preproc_data['vocab']

    discrim_preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.discrim_d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)

    def dataset_train_fn(input_context):
        batch_size = input_context.get_per_replica_batch_size(args.batch_size)
        ds = get_dataset(args.train_dir, args.file_type,
                         args.num_train_workers).batch(batch_size)
        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    train_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_train_fn)

    def dataset_test_fn(input_context):
        batch_size = input_context.get_per_replica_batch_size(args.batch_size)
        ds = get_dataset(args.valid_dir,
                         args.file_type,
                         args.num_train_workers,
                         shuffle=False).batch(batch_size)
        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    valid_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_test_fn)

    train_md = args.train_md if args.train_md else os.path.join(
        args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(
        args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)
    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    gen_embeddings = {'x': gen_preproc_data['embeddings']}
    discrim_embeddings = {'x': discrim_preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)
    if len(args.gen_rpr_k) == 0 or args.gen_rpr_k[0] < 1:
        gen_rpr_k = None
    elif len(args.gen_rpr_k) == 1:
        gen_rpr_k = args.gen_rpr_k[0]
    else:
        gen_rpr_k = args.gen_rpr_k

    if len(args.discrim_rpr_k) == 0 or args.discrim_rpr_k[0] < 1:
        discrim_rpr_k = None
    elif len(args.discrim_rpr_k) == 1:
        discrim_rpr_k = args.discrim_rpr_k[0]
    else:
        discrim_rpr_k = args.discrim_rpr_k

    gen_model = TransformerMaskedLanguageModel.create(
        gen_embeddings,
        hsz=args.gen_d_model,
        d_ff=args.gen_d_ff,
        tie_weights=True,
        dropout=args.gen_dropout,
        gpu=False,
        num_heads=args.gen_num_heads,
        layers=args.gen_num_layers,
        rpr_k=gen_rpr_k,
        d_k=args.gen_d_k,
        windowed_ra=args.windowed_ra,
        src_keys=['x'],
        tgt_key='x')

    discrim_model = TransformerDiscriminator(discrim_embeddings,
                                             d_model=args.discrim_d_model,
                                             d_ff=args.discrim_d_ff,
                                             dropout=args.discrim_dropout,
                                             num_heads=args.discrim_num_heads,
                                             layers=args.discrim_num_layers,
                                             rpr_k=discrim_rpr_k,
                                             d_k=args.discrim_d_k)

    logger.info("Loaded model and loss")
    steps_per_epoch = num_train_samples // args.batch_size
    steps_per_valid_epoch = num_valid_samples // args.batch_size
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )

    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs,
                                              lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps,
                                                    lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)

    mask_value = vocabs.get("[MASK]", vocabs.get("<MASK>", -1))
    if mask_value == -1:
        logger.error("We could not find a suitable masking token in the vocab")
        return

    optimizer, clip = create_keras_optimizer(**vars(args))

    discrim_checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                             model=discrim_model)
    discrim_checkpoint_manager = tf.train.CheckpointManager(
        discrim_checkpoint,
        directory=os.path.join(args.basedir, 'discrim'),
        max_to_keep=5)

    gen_checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                         model=gen_model)
    gen_checkpoint_manager = tf.train.CheckpointManager(gen_checkpoint,
                                                        directory=os.path.join(
                                                            args.basedir,
                                                            'gen'),
                                                        max_to_keep=5)

    mask_value = vocabs.get("[MASK]", vocabs.get("<MASK>", -1))
    if mask_value == -1:
        logger.error("We could not find a suitable masking token in the vocab")
        return

    if args.restart:
        # The global step gets automatically updated here
        # so we don't have to worry about our LR regimen
        gen_checkpoint.restore(gen_checkpoint_manager.latest_checkpoint)
        discrim_checkpoint.restore(
            discrim_checkpoint_manager.latest_checkpoint)

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        noised_x, labels = inputs
        with tf.GradientTape() as tape:
            gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
                noised_x, labels, gen_model, discrim_model, mask_value)
            loss_value = (args.gen_loss_scale * gen_loss_step +
                          discrim_loss_step) / num_replicas

        grads = tape.gradient(
            loss_value,
            gen_model.trainable_variables + discrim_model.trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, clip)
        optimizer.apply_gradients(
            zip(
                grads, gen_model.trainable_variables +
                discrim_model.trainable_variables))

        return loss_value, gen_loss_step, discrim_loss_step, acc

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        loss, gen_loss, discrim_loss, acc = strategy.run(
            _replicated_train_step, args=(inputs, ))
        sum_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
        sum_gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                       gen_loss,
                                       axis=None)
        sum_discrim_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           discrim_loss,
                                           axis=None)
        sum_acc = strategy.reduce(tf.distribute.ReduceOp.SUM, acc, axis=None)
        return sum_loss, sum_gen_loss, sum_discrim_loss, sum_acc

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        noised_x, labels = inputs
        gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
            noised_x, labels, gen_model, discrim_model, mask_value)
        loss_value = (args.gen_loss_scale * gen_loss_step +
                      discrim_loss_step) / num_replicas
        return loss_value, gen_loss_step, discrim_loss_step, acc

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        loss, gen_loss, discrim_loss, acc = strategy.run(_replicated_test_step,
                                                         args=(inputs, ))
        sum_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
        sum_gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                       gen_loss,
                                       axis=None)
        sum_discrim_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           discrim_loss,
                                           axis=None)
        sum_acc = strategy.reduce(tf.distribute.ReduceOp.SUM, acc, axis=None)
        return sum_loss, sum_gen_loss, sum_discrim_loss, sum_acc

    # This is the training loop
    start_epoch = 0
    timer = Timer()
    with strategy.scope():

        for epoch in range(start_epoch, args.epochs):
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            avg_gen_loss = Average('average_gen_loss')
            avg_discrim_loss = Average('average_discrim_loss')
            avg_acc = Average('average_train_acc')

            metrics = {}
            timer.start()
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):
                loss, gen_loss, discrim_loss, acc = _distributed_train_step(
                    next(train_iter))
                avg_loss.update(loss.numpy().item())
                avg_gen_loss.update(gen_loss.numpy().item())
                avg_discrim_loss.update(discrim_loss.numpy().item())
                avg_acc.update(acc.numpy().item())

                tf.summary.scalar("train_loss",
                                  data=loss,
                                  step=optimizer.iterations)
                tf.summary.scalar("train_gen_loss",
                                  data=gen_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar("train_discrim_loss",
                                  data=discrim_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar("train_acc",
                                  data=acc,
                                  step=optimizer.iterations)

                if args.convert_only:
                    logger.warning(
                        "Convert only flag specified.  Stopping after one step"
                    )
                    steps = optimizer.iterations.numpy()
                    npz_checkpoint = os.path.join(args.basedir,
                                                  f'discrim-step-{steps}.npz')
                    save_tlm_npz(discrim_model, npz_checkpoint)
                    npz_checkpoint = os.path.join(args.basedir,
                                                  f'gen-step-{steps}.npz')
                    save_tlm_npz(gen_model, npz_checkpoint)
                    return

                if (i + 1) % report_on == 0:
                    logger.info(avg_loss)
                    logger.info(avg_gen_loss)
                    logger.info(avg_discrim_loss)
                    logger.info(avg_acc)
                if (i + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logger.info('elapsed time this epoch %d min', elapsed)
                    logger.info('step rate %f steps/min', i / elapsed)
                    gen_checkpoint_manager.save()
                    discrim_checkpoint_manager.save()

                    if args.npz:
                        steps = optimizer.iterations.numpy()
                        npz_checkpoint = os.path.join(
                            args.basedir, f'discrim-step-{steps}.npz')
                        save_tlm_npz(discrim_model, npz_checkpoint)
                        npz_checkpoint = os.path.join(args.basedir,
                                                      f'gen-step-{steps}.npz')
                        save_tlm_npz(gen_model, npz_checkpoint)

            # Aggregate the training metrics for this epoch
            metrics['train_elapsed_min'] = timer.elapsed(True)
            metrics['average_train_loss'] = avg_loss.avg
            metrics['average_gen_loss'] = avg_gen_loss.avg
            metrics['average_discrim_loss'] = avg_discrim_loss.avg
            metrics['average_train_acc'] = avg_acc.avg
            metrics['lr'] = float(
                lr_sched(tf.cast(optimizer.iterations,
                                 tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            avg_valid_gen_loss = Average('average_valid_gen_loss')
            avg_valid_discrim_loss = Average('average_valid_discrim_loss')
            avg_valid_acc = Average('average_valid_acc')

            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                valid_loss, valid_gen_loss, valid_discrim_loss, valid_acc = _distributed_test_step(
                    next(valid_iter))
                tf.summary.scalar('valid_loss',
                                  data=valid_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar('valid_gen_loss',
                                  data=valid_gen_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar('valid_discrim_loss',
                                  data=valid_discrim_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar('valid_acc',
                                  data=valid_acc,
                                  step=optimizer.iterations)
                avg_valid_loss.update(valid_loss.numpy().item())
                avg_valid_gen_loss.update(valid_gen_loss.numpy().item())
                avg_valid_discrim_loss.update(
                    valid_discrim_loss.numpy().item())
                avg_valid_acc.update(valid_acc.numpy().item())

            metrics['valid_elapsed_min'] = timer.elapsed(True)
            metrics['average_valid_loss'] = avg_valid_loss.avg
            metrics['average_valid_gen_loss'] = avg_valid_gen_loss.avg
            metrics['average_valid_discrim_loss'] = avg_valid_discrim_loss.avg
            metrics['average_valid_acc'] = avg_valid_acc.avg
            logger.info(json.dumps(metrics, indent=4))
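
The running means in these loops come from a small Average metric with an update() method and an avg property. A minimal sketch of the helper these scripts assume (the real one lives in the eight_mile utilities):

class Average:
    """Running mean used for the train/valid loss meters above."""

    def __init__(self, name):
        self.name = name
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    @property
    def avg(self):
        return self.total / max(1, self.count)

    def __str__(self):
        return f'{self.name}: {self.avg:.6f}'
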
Example #10
def run():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--checkpoint",
                        type=str,
                        help='Checkpoint name or directory to load')
    parser.add_argument("--sample",
                        type=str2bool,
                        help='Sample from the decoder?  Defaults to `false`',
                        default=0)
    parser.add_argument("--vocab",
                        type=str,
                        help='Vocab file to load',
                        required=False)
    parser.add_argument("--input", type=str, default='hello how are you ?')
    parser.add_argument("--dataset_cache",
                        type=str,
                        default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument(
        "--d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument(
        "--nctx",
        type=int,
        default=256,
        help="Max context length (for both encoder and decoder)")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        help=
        "register label of the embeddings; currently supports positional or learned-positional"
    )
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--batchsz",
                        help="Size of a batch to pass at once",
                        default=4,
                        type=int)
    parser.add_argument("--beamsz",
                        help="Size of beam to use",
                        default=4,
                        type=int)
    parser.add_argument("--activation", type=str, default='relu')
    parser.add_argument(
        '--rpr_k',
        help=
        'Relative attention positional sizes; pass 0 if you do not want relative attention',
        type=int,
        default=[8] * 8,
        nargs='+')
    #parser.add_argument("--go_token", default="<GO>")
    parser.add_argument("--end_token", default="<EOS>")
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--show_query",
                        type=str2bool,
                        default=False,
                        help="Show the original query as well")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    args = parser.parse_args()

    if torch.cuda.device_count() == 1:
        torch.cuda.set_device(0)
        args.device = torch.device("cuda", 0)

    if os.path.isdir(args.checkpoint):
        checkpoint, _ = find_latest_checkpoint(args.checkpoint)
        logger.warning("Found latest checkpoint %s", checkpoint)
    else:
        checkpoint = args.checkpoint

    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file,
                                 vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx,
                                 emit_end_tok=args.end_token)
    vocab = vectorizer.vocab
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        counts=False,
        known_vocab=vocab,
        embed_type=args.embed_type)
    embeddings = preproc_data['embeddings']
    vocab = preproc_data['vocab']
    model = create_model(embeddings,
                         d_model=args.d_model,
                         d_ff=args.d_ff,
                         num_heads=args.num_heads,
                         num_layers=args.num_layers,
                         rpr_k=args.rpr_k,
                         d_k=args.d_k,
                         checkpoint_name=checkpoint,
                         activation=args.activation,
                         device=args.device)
    model.to(args.device)

    index2word = revlut(vocab)
    wf = None
    if args.output_file:
        wf = open(args.output_file, "w")

    batches = []
    if os.path.isfile(args.input):
        with open(args.input, 'rt', encoding='utf-8') as f:
            batch = []
            for line in f:
                text = line.strip().split()
                if len(batch) == args.batchsz:
                    batches.append(batch)
                    batch = []
                batch.append(text)

            if len(batch) > 0:
                batches.append(batch)

    else:
        batch = [args.input.split()]
        batches.append(batch)

    for queries in batches:

        outputs = decode_sentences(model, vectorizer, queries, vocab,
                                   index2word, args.beamsz)

        if args.show_query:
            for query, output in zip(queries, outputs):
                print(f"[Query] {query}")
                print(f"[Response] {output}")
        elif wf:
            for query, output in zip(queries, outputs):
                wf.write(f'{output}\n')
                wf.flush()
        else:
            for query, output in zip(queries, outputs):
                print(output)
    if wf:
        wf.close()
Example #11
def run(input_files=[], input_pattern='*.txt', codes=None, vocab=None, nctx=256, fmt='json', fields=['x_str', 'y_str'],
        output=None, prefix=None, suffix=None, max_file_size=100, tok_on_eol="<EOS>", cased=True,
        mask_type="mlm", module=None, pad_y=True, extra_tokens=['[CLS]', '[MASK]'], world_size=1, world_offset=0,
        input_field='text', tokenizer_type=None, **kwargs):

    def parse_json_line(x): return json.loads(x)[input_field]

    if module:
        logger.warning("Loading custom user module %s for masking rules and tokenizers", module)
        baseline.import_user_module(module)

    get_line = lambda x: x.strip()
    if os.path.isdir(input_files):
        if '.json' in input_pattern:
            get_line = parse_json_line
        # Capture the directory before globbing so `output` can still refer to it
        input_dir = input_files
        input_files = list(glob.glob(os.path.join(input_dir, input_pattern)))
        if not output:
            output = os.path.join(input_dir, 'records')
    else:
        if '.json' in input_files:
            get_line = parse_json_line
        if not output:
            output = f'{input_files}.records'
        input_files = [input_files]

    if len(input_files) < world_size:
        raise Exception(f"The number of input shards ({len(input_files)})should be greater than the world_size: {world_size}")

    logger.info('Output [%s]', output)
    transform = baseline.lowercase if not cased else lambda x: x
    vectorizer = BPEVectorizer1D(transform_fn=transform, model_file=codes, vocab_file=vocab, mxlen=1024, extra_tokens=extra_tokens)

    lookup_indices = []
    indices2word = baseline.revlut(vectorizer.vocab)
    root_dir = os.path.dirname(output)
    tokenizer = create_tokenizer(tokenizer_type)
    masking = create_masking(mask_type, vectorizer.vocab, pad_y)
    if not os.path.exists(root_dir):
        os.makedirs(root_dir)

    if prefix:
        nctx -= 1
        prefix = vectorizer.vocab[prefix]

    if suffix:
        nctx -= 1
        suffix = vectorizer.vocab[suffix]

    fw = create_file_writer(fmt, output, fields, max_file_size, 1000 * world_offset)
    num_samples = 0
    for i, text in enumerate(input_files):

        if i % world_size != world_offset:
            continue

        with TextFile(text) as rf:
            print(f"Reading from {text}...")
            for line in rf:
                to_bpe = tokenizer(get_line(line))
                if not to_bpe:
                    continue
                to_bpe += [tok_on_eol]

                # Use a distinct name so the `output` path computed above is not shadowed
                vec, available = vectorizer.run(to_bpe, vectorizer.vocab)
                while available > 0:
                    if len(lookup_indices) == nctx:
                        record = create_record(lookup_indices, indices2word, prefix, suffix, masking=masking)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []
                    needed = nctx - len(lookup_indices)
                    if available >= needed:
                        lookup_indices += vec[:needed].tolist()
                        vec = vec[needed:]
                        available -= needed
                        record = create_record(lookup_indices, indices2word, prefix, suffix, masking=masking)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []
                    # The amount available is less than what we need, so read the whole thing
                    else:
                        lookup_indices += vec[:available].tolist()
                        available = 0

    fw.close()
    f_name = f'md-{world_offset}.yml' if world_size > 1 else 'md.yml'
    write_yaml({'num_samples': num_samples}, os.path.join(root_dir, f_name))
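A hypothetical invocation of the run() function above; the paths and file names are placeholders, not values from the source:

run(input_files='corpus/',         # directory of raw .txt shards (placeholder)
    input_pattern='*.txt',
    codes='codes.bpe',             # BPE codes file (placeholder)
    vocab='vocab.bpe',             # BPE vocab file (placeholder)
    nctx=256,
    fmt='json',
    output='corpus/records',
    mask_type='mlm')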
Example #12
if os.path.isdir(args.input_files):
    import glob
    input_files = list(
        glob.glob(os.path.join(args.input_files, args.input_pattern)))
    if not args.output:
        args.output = os.path.join(args.input_files, 'records')
else:
    input_files = [args.input_files]
    if not args.output:
        args.output = f'{args.input_files}.records'

print(args.output)
transform = baseline.lowercase if not args.cased else lambda x: x
vectorizer = BPEVectorizer1D(transform_fn=transform,
                             model_file=args.codes,
                             vocab_file=args.vocab,
                             mxlen=1024)

lookup_indices = []
words = []
indices2word = baseline.revlut(vectorizer.vocab)
vocab_size = max(vectorizer.vocab.values()) + 1
nctx = args.nctx
mask_value = vectorizer.vocab['[MASK]']
prefix = suffix = None
root_dir = os.path.dirname(args.output)
if not os.path.exists(root_dir):
    os.makedirs(root_dir)

if args.prefix:
    nctx -= 1
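Example #12 looks up mask_value for MLM-style preprocessing; as a rough illustration only (not the script's actual create_record logic), masking 15% of positions in the standard BERT style could look like this:

import numpy as np

def mask_mlm(ids, mask_value, rate=0.15):
    # x: input ids with `rate` of positions replaced by the [MASK] id;
    # y: original ids at the masked positions, 0 (pad) everywhere else.
    ids = np.asarray(ids)
    x = ids.copy()
    y = np.zeros_like(ids)
    masked = np.random.rand(len(ids)) < rate
    x[masked] = mask_value
    y[masked] = ids[masked]
    return x, y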
Example #13
parser.add_argument("--valid_split", type=float, default=0.05)
parser.add_argument("--prefix", default="<GO>")
parser.add_argument("--suffix", default="<EOS>")
parser.add_argument("--pg_name",
                    choices=["tqdm", "default"],
                    default="default")

args = parser.parse_args()
annot_files = list(Path(args.annot_files).iterdir())
valid_split = int(len(annot_files) * args.valid_split)
VALID_FILES = annot_files[:valid_split]
TRAIN_FILES = annot_files[valid_split:]

VECTORIZER = BPEVectorizer1D(
    transform_fn=baseline.lowercase if not args.cased else lambda x: x,
    model_file=args.codes,
    vocab_file=args.vocab,
    mxlen=1024)
NCTX = args.nctx - 2
PREFIX = (
    VECTORIZER.vocab[args.prefix],
    Offsets.GO,
)
SUFFIX = (
    VECTORIZER.vocab[args.suffix],
    Offsets.EOS,
)

DOC2WORD = read_vocab_file(args.document_vocab)
label2word = read_vocab_file(args.label_vocab)
LABELS = {Offsets.VALUES[k]: k for k in range(Offsets.OFFSET)}
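Assuming the standard eight-mile Offsets (reserved slots 0-3), the seed map on the last line expands to roughly:

# Hypothetical expansion, assuming Offsets.VALUES == ['<PAD>', '<GO>', '<EOS>', '<UNK>']
# and Offsets.OFFSET == 4:
# LABELS == {'<PAD>': 0, '<GO>': 1, '<EOS>': 2, '<UNK>': 3}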
Example #14
def main():
    parser = argparse.ArgumentParser(
        description='Convert text into MLM fixed width contexts')

    parser.add_argument('--input_files',
                        type=str,
                        help='Path to the input text files')
    parser.add_argument('--annot_files',
                        type=str,
                        help='Path to a directory of annotation files')
    parser.add_argument('--codes', help='BPE codes')
    parser.add_argument('--vocab', help='BPE vocab')
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--fmt",
                        type=str,
                        default='json',
                        choices=['json', 'tsv', 'tfrecord'])
    parser.add_argument("--fields",
                        type=str,
                        nargs="+",
                        default=["x_str", "y_str"])
    parser.add_argument("--output_dir",
                        type=str,
                        help="Output base name, e.g. /path/to/output/record")
    parser.add_argument("--max_file_size",
                        type=int,
                        default=100,
                        help="Shard size, defaults to 100MB")
    parser.add_argument(
        "--stride",
        type=int,
        help="Tokens to stride before next read, defaults to `nctx`")
    parser.add_argument("--tok_on_eol", type=str, default="<EOS>")
    parser.add_argument("--cased", type=baseline.str2bool, default=True)
    parser.add_argument("--document_vocab", type=str, default="document.vocab")
    parser.add_argument("--label_vocab", type=str, default="label.vocab")
    parser.add_argument("--valid_split", type=float, default=0.05)
    parser.add_argument("--prefix", default="<GO>")
    parser.add_argument("--suffix", default="<EOS>")
    parser.add_argument("--pg_name",
                        choices=["tqdm", "default"],
                        default="default")

    args = parser.parse_args()
    annot_files = list(Path(args.annot_files).iterdir())
    valid_split = int(len(annot_files) * args.valid_split)
    VALID_FILES = annot_files[:valid_split]
    TRAIN_FILES = annot_files[valid_split:]

    VECTORIZER = BPEVectorizer1D(
        transform_fn=baseline.lowercase if not args.cased else lambda x: x,
        model_file=args.codes,
        vocab_file=args.vocab,
        mxlen=1024)
    NCTX = args.nctx - 2
    PREFIX = (
        VECTORIZER.vocab[args.prefix],
        Offsets.GO,
    )
    SUFFIX = (
        VECTORIZER.vocab[args.suffix],
        Offsets.EOS,
    )

    DOC2WORD = read_vocab_file(args.document_vocab)
    label2word = read_vocab_file(args.label_vocab)
    LABELS = {Offsets.VALUES[k]: k for k in range(Offsets.OFFSET)}
    for label in label2word.values():
        for prefix in ["B", "I", "E", "S"]:
            LABELS[f"{prefix}-{label}"] = len(LABELS)

    LABELS["O"] = len(LABELS)

    makedir_if_none(args.output_dir)
    write_json(LABELS, os.path.join(args.output_dir, 'labels.json'))
    valid_dir = os.path.join(args.output_dir, 'valid')
    train_dir = os.path.join(args.output_dir, 'train')
    makedir_if_none(train_dir)
    makedir_if_none(valid_dir)

    logger.info("Converting validation files")
    fw_valid = create_file_writer(args.fmt, os.path.join(valid_dir, 'valid'),
                                  args.fields, args.max_file_size)
    write_files(VALID_FILES, args.input_files, fw_valid, valid_dir,
                args.pg_name)

    logger.info("Converting training files")
    fw_train = create_file_writer(args.fmt, os.path.join(train_dir, 'train'),
                                  args.fields, args.max_file_size)
    write_files(TRAIN_FILES, args.input_files, fw_train, train_dir,
                args.pg_name)
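For a hypothetical label vocabulary of ['PER', 'LOC'], the LABELS map built in main() would come out as follows (assuming the reserved Offsets slots shown above):

# B/I/E/S tags are appended per label in order, then 'O' last:
# {'<PAD>': 0, '<GO>': 1, '<EOS>': 2, '<UNK>': 3,
#  'B-PER': 4, 'I-PER': 5, 'E-PER': 6, 'S-PER': 7,
#  'B-LOC': 8, 'I-LOC': 9, 'E-LOC': 10, 'S-LOC': 11,
#  'O': 12}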