示例#1
0
 def download(self):
     """Resolve the dataset file, downloading and caching it when it is a URL.

     If `self.dataset_file` is already a valid local file it is returned
     as-is.  If it is a URL, the on-disk cache index is consulted first;
     on a miss (or when `self.cache_ignore` is set) the file is downloaded,
     extracted into the cache, and the index is updated.

     :return: A local path to the dataset file
     :raises RuntimeError: if the location is neither a readable file nor a valid URL
     """
     location = self.dataset_file
     if is_file_correct(location):
         return location
     if not validate_url(location):
         raise RuntimeError(
             "the file [{}] is not in cache and can not be downloaded".format(
                 location))
     # It is a web URL: check whether a validated copy already sits in the cache.
     url = location
     index_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
     cache_index = read_json(index_path)
     have_cached = url in cache_index and is_file_correct(
         cache_index[url], self.data_download_cache, url)
     if have_cached and not self.cache_ignore:
         logger.info(
             "file for {} found in cache, not downloading".format(url))
         return cache_index[url]
     # Cache miss (or ignore requested): fetch, extract, and record in the index.
     cache_dir = self.data_download_cache
     logger.info("using {} as data/embeddings cache".format(cache_dir))
     temp_file = web_downloader(url)
     extracted = extractor(filepath=temp_file,
                           cache_dir=cache_dir,
                           extractor_func=Downloader.ZIPD.get(
                               mime_type(temp_file), None))
     cache_index[url] = extracted
     write_json(cache_index,
                os.path.join(self.data_download_cache, DATA_CACHE_CONF))
     return extracted
示例#2
0
 def download(self):
     """Locate the embedding file, downloading and caching the bundle if needed.

     Checks for a valid local file first, then the cache index; otherwise
     downloads the bundle from the URL in `self.embedding_file`, optionally
     unzips it, verifies the sha1 when one was provided, and records the
     result in the cache index.

     :return: The path to the embedding file for `self.embedding_key`
     :raises RuntimeError: if the URL is invalid or the sha1 check fails
     """
     if is_file_correct(self.embedding_file):
         logger.info("embedding file location: {}".format(
             self.embedding_file))
         return self.embedding_file
     index_path = os.path.join(self.data_download_cache, DATA_CACHE_CONF)
     cache_index = read_json(index_path)
     if self.embedding_file in cache_index and not self.cache_ignore:
         cached_loc = cache_index[self.embedding_file]
         logger.info("files for {} found in cache".format(
             self.embedding_file))
         return self._get_embedding_file(cached_loc, self.embedding_key)
     # Not cached: try to download the bundle and unzip it.
     url = self.embedding_file
     if not validate_url(url):
         raise RuntimeError("can not download from the given url")
     temp_file = web_downloader(url)
     unzip_fn = None
     if self.unzip_file:
         unzip_fn = Downloader.ZIPD.get(mime_type(temp_file))
     fetched_loc = extractor(filepath=temp_file,
                             cache_dir=self.data_download_cache,
                             extractor_func=unzip_fn)
     # The downloaded file's basename is compared against the expected sha1.
     if self.sha1 is not None and os.path.split(
             fetched_loc)[-1] != self.sha1:
         raise RuntimeError(
             "The sha1 of the downloaded file does not match with the provided one"
         )
     cache_index[url] = fetched_loc
     write_json(
         cache_index,
         os.path.join(self.data_download_cache, DATA_CACHE_CONF))
     return self._get_embedding_file(fetched_loc, self.embedding_key)
示例#3
0
 def save_md(self, target):
     """Write this embedding's meta-data (vsz, dsz, vocab) as JSON to `target`.

     :param target: Where to write the JSON meta-data
     :return: None
     """
     md = {
         "vsz": self.get_vsz(),
         "dsz": self.get_dsz(),
         "vocab": self.get_vocab(),
     }
     write_json(md, target)
示例#4
0
    def save_md(self, basename):
        """Save out a `.state` file containing meta-data from these classes and any info
        registered by a user-defined derived class as a `property`, plus one
        `{basename}-{key}-md.json` meta-data file per embedding.

        NOTE(review): the previous docstring also claimed the `graph`, `saver`
        and `labels` are written here, but this implementation does not do that.

        :param basename: The file-name prefix for all outputs
        :return: None
        """

        write_json(self._state, basename + '.state')
        # One meta-data JSON per embedding, keyed by the embedding name.
        for key, embedding in self.embeddings.items():
            embedding.save_md(basename + '-{}-md.json'.format(key))
示例#5
0
    def save_md(self, basename):
        """Persist model meta-data next to the checkpoint files.

        Writes `{basename}.state` (internal state dict), `{basename}.labels`
        (the label set), and one `{basename}-{key}-md.json` per embedding.

        :param basename: The name of the model prefix
        :return: None
        """
        write_json(self._state, '{}.state'.format(basename))
        write_json(self.labels, '{}.labels'.format(basename))
        for name, embed in self.embeddings.items():
            embed.save_md('{}-{}-md.json'.format(basename, name))
示例#6
0
    def download(self):
        """Download a dataset bundle, or the individual dataset files.

        When the dataset descriptor has a `download` key, the whole bundle
        (zip/tar/tar.gz) is fetched — or reused from the cache when it still
        validates — and the descriptor is updated to point inside the
        extracted directory.  Otherwise each `*_file` entry is resolved
        individually via `SingleFileDownloader` (unless `self.enc_dec` is set;
        presumably enc-dec datasets handle their files differently — TODO confirm).

        :return: The dataset descriptor with file locations updated
        :raises RuntimeError: if the bundle URL is invalid, or the sha1 of
            the downloaded bundle does not match the descriptor's `sha1`
        """
        dload_bundle = self.dataset_desc.get("download", None)
        if dload_bundle is not None:  # download a zip/tar/tar.gz directory, look for train, dev test files inside that.
            dcache_path = os.path.join(self.data_download_cache,
                                       DATA_CACHE_CONF)
            dcache = read_json(dcache_path)
            # Reuse the cached extraction only if the directory still validates
            # and the caller did not ask to bypass the cache.
            if dload_bundle in dcache and \
                    is_dir_correct(dcache[dload_bundle], self.dataset_desc, self.data_download_cache, dload_bundle,
                                   self.enc_dec) and not self.cache_ignore:
                download_dir = dcache[dload_bundle]
                logger.info(
                    "files for {} found in cache, not downloading".format(
                        dload_bundle))
                updated = _update_md(self.dataset_desc, download_dir)
                return updated
            else:  # try to download the bundle and unzip
                if not validate_url(dload_bundle):
                    raise RuntimeError("can not download from the given url")
                else:
                    cache_dir = self.data_download_cache
                    temp_file = web_downloader(dload_bundle)

                    download_dir = extractor(
                        filepath=temp_file,
                        cache_dir=cache_dir,
                        extractor_func=Downloader.ZIPD.get(
                            mime_type(temp_file), None))
                    # The extracted directory's basename is checked against the
                    # descriptor-provided sha1 when one is present.
                    if "sha1" in self.dataset_desc:
                        if os.path.split(
                                download_dir)[-1] != self.dataset_desc["sha1"]:
                            raise RuntimeError(
                                "The sha1 of the downloaded file does not match with the provided one"
                            )
                    # Record the extraction in the cache index for next time.
                    dcache.update({dload_bundle: download_dir})
                    write_json(
                        dcache,
                        os.path.join(self.data_download_cache,
                                     DATA_CACHE_CONF))
                    updated = _update_md(self.dataset_desc, download_dir)
                    return updated
        else:  # we have download links to every file or they exist
            updated = _update_md(self.dataset_desc, None)
            if not self.enc_dec:
                updated.update({
                    k:
                    SingleFileDownloader(self.dataset_desc[k],
                                         self.data_download_cache).download()
                    for k in self.dataset_desc
                    if k.endswith("_file") and self.dataset_desc[k]
                })
            return updated
示例#7
0
 def download(self):
     """Resolve an addon file, downloading it into the addon cache if needed.

     A valid local file is returned directly.  A URL is checked against the
     cache index; on a miss it is downloaded into the addons sub-directory
     of the cache and the index is updated.

     :return: A local path to the addon file
     :raises RuntimeError: if the location is neither a file nor a valid URL
     """
     file_loc = self.dataset_file
     if is_file_correct(file_loc):
         return file_loc
     elif validate_url(
             file_loc):  # is it a web URL? check if exists in cache
         url = file_loc
         dcache_path = os.path.join(self.data_download_cache,
                                    DATA_CACHE_CONF)
         dcache = read_json(dcache_path)
         # If the file already exists in the cache
         if url in dcache and is_file_correct(
                 dcache[url], self.data_download_cache,
                 url) and not self.cache_ignore:
             logger.info(
                 "file for {} found in cache, not downloading".format(url))
             return dcache[url]
         # Otherwise, we want it to be placed in ~/.bl-cache/addons
         else:  # download the file in the cache, update the json
             cache_dir = self.data_download_cache
             addon_path = os.path.join(cache_dir,
                                       AddonDownloader.ADDON_SUBPATH)
             # exist_ok avoids the TOCTOU race between an exists() check and
             # makedirs() when two processes download addons concurrently
             os.makedirs(addon_path, exist_ok=True)
             path_to_save = os.path.join(addon_path,
                                         os.path.basename(file_loc))
             logger.info("using {} as data/addons cache".format(cache_dir))
             web_downloader(url, path_to_save)
             dcache.update({url: path_to_save})
             write_json(
                 dcache,
                 os.path.join(self.data_download_cache, DATA_CACHE_CONF))
             return path_to_save
     raise RuntimeError(
         "the file [{}] is not in cache and can not be downloaded".format(
             file_loc))
示例#8
0
def train():
    """Train a generator/discriminator (ELECTRA-style) Transformer pair.

    Parses command-line arguments, builds distributed train/valid data
    loaders, constructs a masked-LM generator and a Transformer
    discriminator, and runs the epoch loop, checkpointing both models and
    logging metrics.  Returns early (with an error log) if no mask token
    is found in the vocab, and after one step when `--convert_only` is set.

    :return: None
    """
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir",
                        type=str,
                        required=True,
                        help='Training directory')
    parser.add_argument("--valid_dir",
                        type=str,
                        required=True,
                        help='Validation directory')
    parser.add_argument(
        "--train_md",
        type=str,
        help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument(
        "--valid_md",
        type=str,
        help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--dataset_key",
                        default="tlm",
                        help="dataset key for basedir")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")

    parser.add_argument("--gen_d_model",
                        type=int,
                        default=256,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--gen_d_ff",
                        type=int,
                        default=1024,
                        help="FFN dimension")
    parser.add_argument(
        "--gen_d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--gen_num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--gen_num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument(
        '--gen_rpr_k',
        help=
        'Relative attention positional sizes pass 0 if you dont want relative attention',
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument('--windowed_ra',
                        type=str2bool,
                        default=False,
                        help="whether prevent attention beyond rpr_k")
    parser.add_argument("--gen_loss_scale",
                        type=float,
                        default=50.0,
                        help="Scaling for loss function")
    parser.add_argument("--gen_dropout",
                        type=float,
                        default=0.1,
                        help="Dropout")

    parser.add_argument(
        '--discrim_rpr_k',
        help=
        'Relative attention positional sizes pass 0 if you dont want relative attention',
        type=int,
        default=[8],
        nargs='+')

    parser.add_argument("--discrim_d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--discrim_d_ff",
                        type=int,
                        default=2048,
                        help="FFN dimension")
    parser.add_argument(
        "--discrim_d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--discrim_num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--discrim_num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--discrim_dropout",
                        type=float,
                        default=0.1,
                        help="Dropout")

    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument("--distribute",
                        type=str,
                        default="mirror",
                        choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep",
                        type=str,
                        help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--file_type",
                        default='tfrecord',
                        choices=['json', 'tfrecord'],
                        help="Glob pattern for data")
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="The BPE model file",
                        required=True)
    parser.add_argument("--subword_vocab_file",
                        type=str,
                        help="The BPE subword vocab",
                        required=True)
    parser.add_argument("--optim",
                        default="adam",
                        type=str,
                        help="Optimizer to use (defaults to adam)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=1.0e-2,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart",
        type=str2bool,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--causal",
                        type=str2bool,
                        default=False,
                        help="Use CLM (causal) instead of MLM")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=10,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument("--strategy",
                        help="Training strategy, defaults to `mirror`",
                        choices=["mirror"])
    parser.add_argument("--npz",
                        help="Should we write out NPZ files?",
                        type=str2bool,
                        default=False)
    parser.add_argument("--tb",
                        help="Turn on tensorboard?",
                        type=str2bool,
                        default=False)
    parser.add_argument(
        "--convert_only",
        help="Should we just convert this file to NPZ and exit?",
        type=str2bool,
        default=False)
    args = parser.parse_args()
    SET_TRAIN_FLAG(True)

    # convert_only implies restoring the latest checkpoint and writing NPZ
    if args.convert_only:
        args.restart = True
        args.npz = True

    if args.basedir is None:
        args.basedir = f'discrim-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"logs/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file,
                                 vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx)
    vocab = {'x': vectorizer.vocab}
    gen_preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.gen_d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)

    vocabs = gen_preproc_data['vocab']

    discrim_preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.discrim_d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)

    def dataset_train_fn(input_context):
        # Per-replica batching and sharding for the training set
        batch_size = input_context.get_per_replica_batch_size(args.batch_size)
        ds = get_dataset(args.train_dir, args.file_type,
                         args.num_train_workers).batch(batch_size)
        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    train_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_train_fn)

    def dataset_test_fn(input_context):
        # Validation set: no shuffling, same sharding scheme
        batch_size = input_context.get_per_replica_batch_size(args.batch_size)
        ds = get_dataset(args.valid_dir,
                         args.file_type,
                         args.num_train_workers,
                         shuffle=False).batch(batch_size)
        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    valid_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_test_fn)

    train_md = args.train_md if args.train_md else os.path.join(
        args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(
        args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)
    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    gen_embeddings = {'x': gen_preproc_data['embeddings']}
    discrim_embeddings = {'x': discrim_preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)
    # Normalize rpr_k CLI values: empty/0 disables relative attention,
    # a single value is passed as a scalar, multiple values as a list.
    if len(args.gen_rpr_k) == 0 or args.gen_rpr_k[0] < 1:
        gen_rpr_k = None
    elif len(args.gen_rpr_k) == 1:
        gen_rpr_k = args.gen_rpr_k[0]
    else:
        gen_rpr_k = args.gen_rpr_k

    # BUG FIX: this branch previously tested len(args.gen_rpr_k), so the
    # discriminator's rpr_k was chosen based on the generator's argument length
    if len(args.discrim_rpr_k) == 0 or args.discrim_rpr_k[0] < 1:
        discrim_rpr_k = None
    elif len(args.discrim_rpr_k) == 1:
        discrim_rpr_k = args.discrim_rpr_k[0]
    else:
        discrim_rpr_k = args.discrim_rpr_k

    gen_model = TransformerMaskedLanguageModel.create(
        gen_embeddings,
        hsz=args.gen_d_model,
        d_ff=args.gen_d_ff,
        tie_weights=True,
        dropout=args.gen_dropout,
        gpu=False,
        num_heads=args.gen_num_heads,
        layers=args.gen_num_layers,
        rpr_k=gen_rpr_k,
        d_k=args.gen_d_k,
        windowed_ra=args.windowed_ra,
        src_keys=['x'],
        tgt_key='x')

    discrim_model = TransformerDiscriminator(discrim_embeddings,
                                             d_model=args.discrim_d_model,
                                             d_ff=args.discrim_d_ff,
                                             dropout=args.discrim_dropout,
                                             num_heads=args.discrim_num_heads,
                                             layers=args.discrim_num_layers,
                                             rpr_k=discrim_rpr_k,
                                             d_k=args.discrim_d_k)

    logger.info("Loaded model and loss")
    steps_per_epoch = num_train_samples // args.batch_size
    steps_per_valid_epoch = num_valid_samples // args.batch_size
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )

    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs,
                                              lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps,
                                                    lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)

    # A masking token is required for the MLM objective; bail out early
    # rather than failing deep inside the training loop.
    mask_value = vocabs.get("[MASK]", vocabs.get("<MASK>", -1))
    if mask_value == -1:
        logger.error("We could not find a suitable masking token in the vocab")
        return

    optimizer, clip = create_keras_optimizer(**vars(args))

    discrim_checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                             model=discrim_model)
    discrim_checkpoint_manager = tf.train.CheckpointManager(
        discrim_checkpoint,
        directory=os.path.join(args.basedir, 'discrim'),
        max_to_keep=5)

    # BUG FIX: the generator checkpoint previously tracked discrim_model,
    # so the 'gen' checkpoint saved the discriminator weights twice and the
    # generator was never persisted or restored
    gen_checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                         model=gen_model)
    gen_checkpoint_manager = tf.train.CheckpointManager(gen_checkpoint,
                                                        directory=os.path.join(
                                                            args.basedir,
                                                            'gen'),
                                                        max_to_keep=5)

    if args.restart:
        # The global step gets automatically updated here
        # so we dont have to worry about our LR regimen
        gen_checkpoint.restore(gen_checkpoint_manager.latest_checkpoint)
        discrim_checkpoint.restore(
            discrim_checkpoint_manager.latest_checkpoint)

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        noised_x, labels = inputs
        with tf.GradientTape() as tape:
            gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
                noised_x, labels, gen_model, discrim_model, mask_value)
            loss_value = (args.gen_loss_scale * gen_loss_step +
                          discrim_loss_step) / num_replicas

        grads = tape.gradient(
            loss_value,
            gen_model.trainable_variables + discrim_model.trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, clip)
        optimizer.apply_gradients(
            zip(
                grads, gen_model.trainable_variables +
                discrim_model.trainable_variables))

        return loss_value, gen_loss_step, discrim_loss_step, acc

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        loss, gen_loss, discrim_loss, acc = strategy.run(
            _replicated_train_step, args=(inputs, ))
        sum_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
        sum_gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                       gen_loss,
                                       axis=None)
        sum_discrim_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           discrim_loss,
                                           axis=None)
        sum_acc = strategy.reduce(tf.distribute.ReduceOp.SUM, acc, axis=None)
        return sum_loss, sum_gen_loss, sum_discrim_loss, sum_acc

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        noised_x, labels = inputs
        gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
            noised_x, labels, gen_model, discrim_model, mask_value)
        loss_value = (args.gen_loss_scale * gen_loss_step +
                      discrim_loss_step) / num_replicas
        return loss_value, gen_loss_step, discrim_loss_step, acc

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        loss, gen_loss, discrim_loss, acc = strategy.run(_replicated_test_step,
                                                         args=(inputs, ))
        sum_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
        sum_gen_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                       gen_loss,
                                       axis=None)
        sum_discrim_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           discrim_loss,
                                           axis=None)
        sum_acc = strategy.reduce(tf.distribute.ReduceOp.SUM, acc, axis=None)
        return sum_loss, sum_gen_loss, sum_discrim_loss, sum_acc

    # This is the training loop
    start_epoch = 0
    timer = Timer()
    with strategy.scope():

        for epoch in range(start_epoch, args.epochs):
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            avg_gen_loss = Average('average_gen_loss')
            avg_discrim_loss = Average('average_discrim_loss')
            avg_acc = Average('average_train_acc')

            metrics = {}
            timer.start()
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):
                loss, gen_loss, discrim_loss, acc = _distributed_train_step(
                    next(train_iter))
                avg_loss.update(loss.numpy().item())
                avg_gen_loss.update(gen_loss.numpy().item())
                avg_discrim_loss.update(discrim_loss.numpy().item())
                avg_acc.update(acc.numpy().item())

                tf.summary.scalar("train_loss",
                                  data=loss,
                                  step=optimizer.iterations)
                tf.summary.scalar("train_gen_loss",
                                  data=gen_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar("train_discrim_loss",
                                  data=discrim_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar("train_acc",
                                  data=acc,
                                  step=optimizer.iterations)

                if args.convert_only:
                    logger.warning(
                        "Convert only flag specified.  Stopping after one step"
                    )
                    steps = optimizer.iterations.numpy()
                    npz_checkpoint = os.path.join(args.basedir,
                                                  f'discrim-step-{steps}.npz')
                    save_tlm_npz(discrim_model, npz_checkpoint)
                    npz_checkpoint = os.path.join(args.basedir,
                                                  f'gen-step-{steps}.npz')
                    save_tlm_npz(gen_model, npz_checkpoint)
                    return

                if (i + 1) % report_on == 0:
                    logging.info(avg_loss)
                    logging.info(avg_gen_loss)
                    logging.info(avg_discrim_loss)
                    logging.info(avg_acc)
                if (i + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logging.info('elapsed time this epoch %d min', elapsed)
                    logging.info('elapsed step time %f steps/min', i / elapsed)
                    gen_checkpoint_manager.save()
                    discrim_checkpoint_manager.save()

                    if args.npz:
                        steps = optimizer.iterations.numpy()
                        npz_checkpoint = os.path.join(
                            args.basedir, f'discrim-step-{steps}.npz')
                        save_tlm_npz(discrim_model, npz_checkpoint)
                        npz_checkpoint = os.path.join(args.basedir,
                                                      f'gen-step-{steps}.npz')
                        save_tlm_npz(gen_model, npz_checkpoint)

            # This is the average training token-level loss across all machines
            # This is the token-level training perplexity
            metrics['train_elapsed_min'] = timer.elapsed(True)
            metrics['average_train_loss'] = avg_loss.avg
            metrics['average_gen_loss'] = avg_gen_loss.avg
            metrics['average_discrim_loss'] = avg_discrim_loss.avg
            metrics['average_train_acc'] = avg_acc.avg
            # NOTE(review): `optimizer.global_step` is used here while the rest
            # of the loop uses `optimizer.iterations` -- presumably
            # create_keras_optimizer attaches a global_step; verify
            metrics['lr'] = float(
                lr_sched(tf.cast(optimizer.global_step,
                                 tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            avg_valid_gen_loss = Average('average_valid_gen_loss')
            avg_valid_discrim_loss = Average('average_valid_discrim_loss')
            avg_valid_acc = Average('average_valid_acc')

            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                valid_loss, valid_gen_loss, valid_discrim_loss, valid_acc = _distributed_test_step(
                    next(valid_iter))
                tf.summary.scalar('valid_loss',
                                  data=valid_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar('valid_gen_loss',
                                  data=valid_gen_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar('valid_discrim_loss',
                                  data=valid_discrim_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar('valid_acc',
                                  data=valid_acc,
                                  step=optimizer.iterations)
                avg_valid_loss.update(valid_loss.numpy().item())
                avg_valid_gen_loss.update(valid_gen_loss.numpy().item())
                avg_valid_discrim_loss.update(
                    valid_discrim_loss.numpy().item())
                avg_valid_acc.update(valid_acc.numpy().item())

            metrics['valid_elapsed_min'] = timer.elapsed(True)
            metrics['average_valid_loss'] = avg_valid_loss.avg
            metrics['average_valid_gen_loss'] = avg_valid_gen_loss.avg
            metrics['average_valid_discrim_loss'] = avg_valid_discrim_loss.avg
            metrics['average_valid_acc'] = avg_valid_acc.avg
            logger.info(json.dumps(metrics, indent=4))
def train():
    """Entry point: train a paired Transformer (dual-encoder or encoder-decoder).

    Parses command-line arguments, builds the dataset reader / vocab /
    embeddings, optionally restarts from a checkpoint, then runs the epoch
    loop with periodic step-level checkpoints and a per-epoch validation
    pass executed only on rank 0 (or in single-process mode).

    Fixes vs. previous revision:
      * the validation loop now actually consumes the validation iterator
        (it previously re-scored the last *training* batch every step, so
        the reported validation loss was meaningless)
      * ``update_on`` is clamped to >= 1 so ``saves_per_epoch`` larger than
        ``steps_per_epoch`` no longer raises ZeroDivisionError
      * logging inside the loop consistently uses ``logger``
    """
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_file",
                        type=str,
                        required=True,
                        help='File path to use for train file')
    parser.add_argument("--valid_file",
                        type=str,
                        required=True,
                        help='File path to use for valid file')
    parser.add_argument("--dataset_key",
                        default="paired",
                        help="dataset key for basedir")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument(
        "--d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--windowed_ra",
                        type=str2bool,
                        default=False,
                        help="whether prevent attention beyond rpr_k")
    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--pattern",
                        default='*.json',
                        help="Glob pattern for data")
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="The BPE model file",
                        required=True)
    parser.add_argument("--subword_vocab_file",
                        type=str,
                        help="The BPE subword vocab",
                        required=True)
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--lr_scheduler",
                        type=str,
                        default='cosine',
                        help="The type of learning rate decay scheduler")
    parser.add_argument("--lr_decay_steps",
                        type=int,
                        help="decay steps of lr scheduler")
    parser.add_argument("--lr_decay_rate",
                        type=float,
                        help="decay rate of lr scheduler")
    parser.add_argument("--lr_alpha",
                        type=float,
                        help="parameter alpha for cosine decay scheduler")
    parser.add_argument("--optim",
                        default="adamw",
                        type=str,
                        help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=1.0e-2,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart_from",
        type=str,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument(
        "--restart_tt",
        type=str,
        help="Optional param for legacy checkpoints (step|epoch)")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=10,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument("--reduction_d_k",
                        type=int,
                        default=64,
                        help="Dimensions of Key and Query in the single headed"
                        "reduction layers")
    parser.add_argument(
        "--stacking_layers",
        type=int,
        nargs='+',
        default=[1024, 1024, 1024],
        help="Hidden sizes of the dense stack (ff2 from the convert paper)")
    parser.add_argument("--ff_pdrop",
                        type=float,
                        default=0.1,
                        help="Dropout in the dense stack")
    parser.add_argument("--reader_type",
                        type=str,
                        default='preprocessed',
                        choices=['ntp', 'nsp', 'preprocessed'])
    parser.add_argument("--model_type",
                        default="dual-encoder",
                        choices=["dual-encoder", "encoder-decoder"])
    parser.add_argument("--loss",
                        type=str,
                        default='all',
                        choices=['triplet', 'all'])
    parser.add_argument(
        '--rpr_k',
        help=
        'Relative attention positional sizes pass 0 if you dont want relative attention',
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed",
                        type=str2bool,
                        default=False,
                        help="Are we doing distributed training?")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help=
        "Local rank for distributed training (-1 means use the environment variables to find)"
    )

    args = parser.parse_args()

    if args.basedir is None:
        args.basedir = '{}-{}-paired-{}-bpe-{}'.format(args.model_type,
                                                       args.reader_type,
                                                       args.dataset_key,
                                                       os.getpid())
    # Only rank 0 (or a single process) logs at INFO to avoid duplicated output
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    num_gpus = get_num_gpus_multiworker()
    args.distributed = args.distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    if args.distributed:
        args.device, updated_local_rank = init_distributed(args.local_rank)
        args.local_rank = updated_local_rank

    reader = MultiFileDatasetReader(args.nctx,
                                    args.subword_model_file,
                                    args.subword_vocab_file,
                                    args.pattern,
                                    reader_type=args.reader_type)

    vocab = reader.build_vocab()
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = preproc_data['embeddings']
    logger.info("Loaded embeddings")

    train_set = reader.load(args.train_file, vocabs)
    valid_set = reader.load(args.valid_file,
                            vocabs,
                            distribute=False,
                            shuffle=False)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=args.num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size)
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    # rpr_k semantics: None disables relative attention, a scalar applies a
    # single window size, a list gives a per-layer window size
    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        rpr_k = None
    elif len(args.rpr_k) == 1:
        rpr_k = args.rpr_k[0]
    else:
        rpr_k = args.rpr_k

    model = create_model(embeddings,
                         d_model=args.d_model,
                         d_ff=args.d_ff,
                         dropout=args.dropout,
                         num_heads=args.num_heads,
                         num_layers=args.num_layers,
                         model_type=args.model_type,
                         rpr_k=rpr_k,
                         d_k=args.d_k,
                         reduction_d_k=args.reduction_d_k,
                         stacking_layers=args.stacking_layers,
                         ff_pdrop=args.ff_pdrop,
                         windowed_ra=args.windowed_ra,
                         logger=logger)

    model.to(args.device)
    loss_function = model.create_loss(args.loss)
    loss_function.to(args.device)

    logger.info("Loaded model and loss")

    # according to pytorch, len(train_loader) will return len(train_set) when train_set is IterableDataset, so manually
    # correct it here
    steps_per_epoch = len(train_loader) // (args.batch_size * num_gpus)
    valid_steps = len(valid_loader) // args.batch_size
    # Clamp to >= 1 so a saves_per_epoch larger than steps_per_epoch cannot
    # produce update_on == 0 and crash the "% update_on" below
    update_on = max(1, steps_per_epoch // args.saves_per_epoch)
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch per GPU: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )
    lr_decay = get_lr_decay(args.lr_scheduler,
                            args.lr,
                            steps_per_epoch,
                            args.epochs,
                            logger,
                            decay_steps=args.lr_decay_steps,
                            decay_rate=args.lr_decay_rate,
                            alpha=args.lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=args.lr)

    global_step = 0
    start_epoch = 0

    if args.restart_from:
        model.load_state_dict(torch.load(args.restart_from))
        # Checkpoint names look like "...-<tick_type>-<num>.<ext>"; recover the
        # tick type and count from the filename unless overridden on the CLI
        vec = args.restart_from.split("-")

        if args.restart_tt:
            tick_type = args.restart_tt
        else:
            tick_type = vec[-2]
        step_num = int(vec[-1].split(".")[0])
        if tick_type == 'epoch':
            start_epoch = step_num
            global_step = start_epoch * steps_per_epoch

        else:
            start_epoch = step_num // steps_per_epoch
            global_step = step_num

        logger.info(
            "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d",
            args.restart_from, global_step, start_epoch + 1)
    optimizer = OptimizerManager(model,
                                 global_step,
                                 optim=args.optim,
                                 lr=args.lr,
                                 lr_function=lr_sched,
                                 weight_decay=args.weight_decay)
    logger.info("Model has {:,} parameters".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # Prepare model for distributed training if needed
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
        logger.info("Model located on %d", args.local_rank)

    model_base = os.path.join(args.basedir, 'checkpoint')
    steps = global_step

    for epoch in range(start_epoch, args.epochs):
        avg_loss = Average('average_train_loss')
        metrics = {}
        optimizer.zero_grad()
        start = time.time()
        model.train()
        train_itr = iter(train_loader)
        for i in range(steps_per_epoch):
            batch = next(train_itr)
            steps += 1
            x, y = batch
            inputs = x.to(args.device)
            labels = y.to(args.device)
            loss = loss_function(inputs, labels)
            loss.backward()
            avg_loss.update(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % report_on == 0:
                logger.info(avg_loss)
            # Step-level checkpoints only from rank 0 / single process
            if (i + 1) % update_on == 0 and args.local_rank < 1:
                elapsed = (time.time() - start) / 60
                logger.info('elapsed time this epoch %d min', elapsed)
                logger.info('elapsed step time %f steps/min', i / elapsed)
                save_checkpoint(model, model_base, steps, tick_type='step')

        # How much time elapsed in minutes
        elapsed = (time.time() - start) / 60
        train_avg_loss = avg_loss.avg
        # This is the average training token-level loss across all machines
        # This is the token-level training perplexity
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_loss'] = train_avg_loss
        if args.local_rank < 1:
            avg_valid_loss = Average('average_valid_loss')
            start = time.time()
            model.eval()
            valid_itr = iter(valid_loader)
            for j in range(valid_steps):
                # BUG FIX: previously this loop never advanced valid_itr and
                # re-scored the final *training* batch on every iteration
                batch = next(valid_itr)
                with torch.no_grad():
                    x, y = batch
                    inputs = x.to(args.device)
                    labels = y.to(args.device)
                    loss = loss_function(inputs, labels)
                    avg_valid_loss.update(loss.item())

            valid_avg_loss = avg_valid_loss.avg

            elapsed = (time.time() - start) / 60
            metrics['valid_elapsed_min'] = elapsed

            metrics['average_valid_loss'] = valid_avg_loss
            logger.info(metrics)
            save_checkpoint(model, model_base, epoch, tick_type='epoch')
示例#10
0
    VECTORIZER.vocab[args.suffix],
    Offsets.EOS,
)

# Build the label index: reserved Offsets entries first, then BIES-prefixed
# span tags for every label in the label vocab, and finally the "O" tag.
DOC2WORD = read_vocab_file(args.document_vocab)
label2word = read_vocab_file(args.label_vocab)
LABELS = {Offsets.VALUES[k]: k for k in range(Offsets.OFFSET)}
for label in label2word.values():
    for prefix in ["B", "I", "E", "S"]:
        LABELS[f"{prefix}-{label}"] = len(LABELS)

LABELS["O"] = len(LABELS)

# exist_ok=True replaces the racy exists()-then-makedirs() pattern and also
# subsumes the redundant makedir_if_none(args.output_dir) call that followed
os.makedirs(args.output_dir, exist_ok=True)
write_json(LABELS, os.path.join(args.output_dir, 'labels.json'))
valid_dir = os.path.join(args.output_dir, 'valid')
train_dir = os.path.join(args.output_dir, 'train')
makedir_if_none(train_dir)
makedir_if_none(valid_dir)

logger.info("Converting validation files")
fw_valid = create_file_writer(args.fmt, os.path.join(valid_dir, 'valid'),
                              args.fields, args.max_file_size)
write_files(VALID_FILES, args.input_files, fw_valid, valid_dir, args.pg_name)

logger.info("Converting training files")
fw_train = create_file_writer(args.fmt, os.path.join(train_dir, 'train'),
                              args.fields, args.max_file_size)
write_files(TRAIN_FILES, args.input_files, fw_train, train_dir, args.pg_name)
示例#11
0
def run(basedir=None,
        train_file=None,
        valid_file=None,
        dataset_key='tlm',
        embed_type='default',
        d_model=512,
        d_ff=2048,
        d_k=None,
        num_heads=8,
        num_layers=8,
        num_train_workers=4,
        nctx=256,
        file_type='json',
        batch_size=256,
        subword_model_file=None,
        subword_vocab_file=None,
        dropout=0.1,
        ffn_pdrop=0.0,
        layer_drop=0.0,
        lr_scheduler='cosine',
        lr_decay_steps=None,
        lr_decay_rate=None,
        lr_alpha=0.0,
        optim='adamw',
        lr=4.0e-4,
        clip=1.0,
        weight_decay=1.0e-2,
        epochs=32,
        restart_from=None,
        restart_tt=None,
        warmup_steps=10000,
        saves_per_epoch=10,
        mlm=True,
        preprocessed=True,
        rpr_k=[8],
        rpr_value_on=False,
        windowed_ra=False,
        device="cuda",
        distributed=False,
        local_rank=-1,
        extra_tokens=["[CLS]", "[MASK]"],
        do_early_stopping=False,
        model_type='transformer-mlm',
        modules=[],
        ra_type=None,
        transformer_type=None,
        **kwargs):
    """Train a Transformer language model (MLM or causal LM).

    Builds the streaming dataset reader, vocab and embeddings, constructs the
    LM via ``create_lang_model``, and runs the epoch loop.  Masking can be
    applied on demand (``mlm=True, preprocessed=False``) or assumed to be
    pre-applied in the data.  With ``do_early_stopping`` the model is only
    checkpointed when the periodic validation loss improves; otherwise it is
    checkpointed every ``update_on`` steps and at every epoch end.

    Fixes vs. previous revision:
      * missing "[MASK]" detection now works — ``dict.get`` returns ``None``,
        not ``-1``, so the old ``== -1`` check could never fire
      * ``mask_value``/``vocab_size`` are pre-initialized so a non-MLM
        ``model_type`` combined with ``mlm=True`` fails loudly here instead
        of raising NameError deep in the loop
      * ``update_on`` is clamped to >= 1 to avoid ZeroDivisionError
      * ``ra_type != None`` replaced with ``is not None``; bare ``print``
        replaced with ``logger.warning``
    """
    if basedir is None:
        basedir = 'lm-{}-bpe-{}'.format(dataset_key, os.getpid())
    # Only rank 0 (or a single process) logs at INFO to avoid duplicated output
    logging.basicConfig(
        level=logging.INFO if local_rank in [-1, 0] else logging.WARN)

    for module in modules:
        import_user_module(module)
    num_gpus = get_num_gpus_multiworker()
    distributed = distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    do_on_demand_masking = mlm and not preprocessed
    if do_on_demand_masking:
        logger.info("On-demand masking is turned on")
    if distributed:
        device, updated_local_rank = init_distributed(local_rank)
        local_rank = updated_local_rank

    if file_type == 'tfrecord':
        reader_type = 'tfrecord'
    elif preprocessed:
        reader_type = 'preprocessed'
    else:
        reader_type = 'lang'
    reader = MultiFileDatasetReader(src_nctx=nctx,
                                    model_file=subword_model_file,
                                    vocab_file=subword_vocab_file,
                                    file_type=file_type,
                                    reader_type=reader_type,
                                    record_keys=['x', 'y'] if mlm else ['x'],
                                    extra_tokens=extra_tokens)

    # This looks a bit funny but the streaming reader ignores our vocab and gives us the one from the subword_model
    # However, we do need to get counts from our dataset for validation so we can calculate the perplexity
    vocab = reader.build_vocab([valid_file])
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=embed_type)
    vocabs = preproc_data['vocab']

    os.makedirs(basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    train_set = reader.load(train_file, vocabs)
    valid_set = reader.load(valid_file,
                            vocabs,
                            distribute=False,
                            shuffle=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=batch_size)
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", embed_type)

    # Pre-initialize so a non-MLM model_type cannot leave these unbound and
    # blow up with NameError if on-demand masking is later attempted
    mask_value = None
    vocab_size = None
    if 'mlm' in model_type:
        mask_from = vocabs
        vocab_size = len(mask_from)
        mask_value = mask_from.get("[MASK]")
        # BUG FIX: dict.get returns None (not -1) when the key is absent
        if mask_value is None:
            logger.error(
                "We could not find a suitable masking token in the vocab")
            return

    if len(rpr_k) == 0 or rpr_k[0] < 1:
        rpr_k = None
    elif len(rpr_k) == 1:
        rpr_k = None if rpr_k[0] == 0 else rpr_k[0]
    # rpr_k only applies to 'shaw'-style relative attention
    if ra_type is not None and ra_type != 'shaw' and rpr_k is not None:
        logger.warning(
            f"Relative attention mismatch. You requested {ra_type} with rpr set.  Setting it to 0"
        )
        rpr_k = None

    model = create_lang_model(
        embeddings,
        hsz=d_model,
        nctx=nctx,  # Only for gMLP
        d_ff=d_ff,
        tie_weights=True,
        dropout=dropout,
        gpu=False,
        num_heads=num_heads,
        layers=num_layers,
        rpr_k=rpr_k,
        d_k=d_k,
        ffn_pdrop=ffn_pdrop,
        windowed_ra=windowed_ra,
        rpr_value_on=rpr_value_on,
        layer_drop=layer_drop,
        model_type=model_type,
        ra_type=ra_type,
        transformer_type=transformer_type,
        src_keys=['x'],
        tgt_key='x')
    model.to(device)

    loss_function = model.create_loss()
    loss_function.to(device)

    logger.info("Loaded model and loss")

    steps_per_epoch = len(train_loader) // num_gpus
    # Clamp to >= 1 so saves_per_epoch > steps_per_epoch cannot produce
    # update_on == 0 and crash the "% update_on" below
    update_on = max(1, steps_per_epoch // saves_per_epoch)
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch per GPU: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )
    lr_decay = get_lr_decay(lr_scheduler,
                            lr,
                            steps_per_epoch,
                            epochs,
                            logger,
                            decay_steps=lr_decay_steps,
                            decay_rate=lr_decay_rate,
                            alpha=lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(warmup_steps, lr=lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=lr)

    global_step = 0
    start_epoch = 0
    if restart_from:

        if restart_from.endswith('npz'):
            load_tlm_npz(model, restart_from)
        else:
            model.load_state_dict(torch.load(restart_from))
        # Checkpoint names look like "...-<tick_type>-<num>.<ext>"; recover
        # the tick type and count from the filename unless overridden
        vec = restart_from.split("-")

        if restart_tt:
            tick_type = restart_tt
        else:
            tick_type = vec[-2]
        step_num = int(vec[-1].split(".")[0])
        if tick_type == 'epoch':
            start_epoch = step_num
            global_step = start_epoch * steps_per_epoch

        elif tick_type == 'step':
            start_epoch = step_num // steps_per_epoch
            global_step = step_num
        else:
            logger.warning(
                f"The previous tick was {step_num} but command-line specifies to ignore, setting to 0"
            )

        logger.info(
            "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d",
            restart_from, global_step, start_epoch + 1)

    optimizer = OptimizerManager(model,
                                 global_step,
                                 optim=optim,
                                 lr=lr,
                                 lr_function=lr_sched,
                                 weight_decay=weight_decay)
    logger.info("Model has {:,} parameters".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # Prepare model for distributed training if needed
    if distributed:
        # This program assume pure data parallelism, each model is on a single gpu
        # If we wanted to support model and data parallelism we would need to update
        # the selection of gpus based on rank, it would need to select multiple ids
        # based on rank, here we select only a single gpu and use it for input and
        # output.
        model = DistributedDataParallel(model,
                                        device_ids=[device],
                                        output_device=device,
                                        find_unused_parameters=True)
        logger.info("Model located on %s", device)

    model_base = os.path.join(basedir, 'checkpoint')
    steps = global_step
    best_valid_loss = np.inf

    timer = Timer()
    for epoch in range(start_epoch, epochs):
        avg_loss = Average('average_train_loss')
        metrics = {}
        optimizer.zero_grad()
        timer.start()
        model.train()
        train_itr = iter(train_loader)
        for i in range(steps_per_epoch):
            batch = next(train_itr)
            steps += 1
            x, y = batch
            inputs = x.to(device)
            labels = y.to(device)
            if do_on_demand_masking:
                inputs, labels, _ = on_demand_mlm_masking(
                    inputs, labels, mask_value, vocab_size)
            inputs = {'x': inputs}

            labels = labels.contiguous()
            logits = model(inputs, None)[0].contiguous()
            if mlm:
                loss = loss_function(logits, labels)
            else:
                # Causal LM: predict token t+1 from tokens <= t
                shift_logits = logits[:, -1]
                shift_labels = labels[:, 1:]
                loss = loss_function(shift_logits, shift_labels)
            loss.backward()
            avg_loss.update(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % report_on == 0:
                logger.info(avg_loss)

            if (i + 1) % update_on == 0 and local_rank < 1:
                elapsed = timer.elapsed(True)
                logger.info('elapsed time this epoch %d min', elapsed)
                logger.info('elapsed step time %f steps/min', i / elapsed)
                logger.info('LR: %f', optimizer.current_lr)

                if not do_early_stopping:
                    save_checkpoint(model, model_base, steps, tick_type='step')
                else:
                    # Early stopping: checkpoint only on improved valid loss
                    valid_token_loss = validate(model, loss_function,
                                                valid_loader, avg_loss, timer,
                                                metrics, do_on_demand_masking,
                                                mlm, mask_value, vocab_size,
                                                device)
                    if valid_token_loss < best_valid_loss:
                        best_valid_loss = valid_token_loss
                        logger.info(
                            f"New best valid loss: {best_valid_loss}. Saving checkpoint..."
                        )
                        save_checkpoint(model,
                                        model_base,
                                        steps,
                                        tick_type='step')
                    # validate() leaves the model in eval mode
                    model.train()

        if not do_early_stopping:
            _ = validate(model, loss_function, valid_loader, avg_loss, timer,
                         metrics, do_on_demand_masking, mlm, mask_value,
                         vocab_size, device)
            save_checkpoint(model, model_base, epoch, tick_type='epoch')
def run(basedir=None,
        train_file=None,
        valid_file=None,
        dataset_key='paired',
        embed_type='default',
        d_model=512,
        d_ff=2048,
        d_k=None,
        num_heads=8,
        num_layers=8,
        num_train_workers=4,
        nctx=256,
        tgt_nctx=None,
        file_type='json',
        record_keys=['x', 'y'],
        batch_size=256,
        subword_model_file=None,
        subword_vocab_file=None,
        dropout=0.1,
        lr_scheduler='cosine',
        lr_decay_steps=None,
        lr_decay_rate=None,
        lr_alpha=None,
        optim='adamw',
        lr=4.0e-4,
        clip=1.0,
        weight_decay=1.0e-2,
        epochs=32,
        restart_from=None,
        restart_tt=None,
        warmup_steps=10000,
        saves_per_epoch=10,
        layer_drop=0.0,
        reader_type='preprocessed',
        src_begin_tok=[],
        src_end_tok=['<EOS>'],
        tgt_begin_tok=['<GO>'],
        tgt_end_tok=['<EOS>'],
        lower=False,
        rpr_k=[8],
        device='cuda',
        distributed=False,
        local_rank=-1,
        save_npz=False,
        extra_tokens=["[CLS]", "[MASK]"],
        subword_type='bpe',
        label_smoothing=None,
        ra_type=None,
        transformer_type=None,
        **kwargs):
    if basedir is None:
        basedir = f's2s-{reader_type}-paired-{dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(
        level=logging.INFO if local_rank in [-1, 0] else logging.WARN)
    num_gpus = get_num_gpus_multiworker()
    distributed = distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")
    if distributed:
        device, updated_local_rank = init_distributed(local_rank)
        local_rank = updated_local_rank
    if not tgt_nctx:
        tgt_nctx = nctx
    reader = MultiFileDatasetReader(nctx,
                                    tgt_nctx,
                                    src_begin_tok,
                                    src_end_tok,
                                    tgt_begin_tok,
                                    tgt_end_tok,
                                    subword_model_file,
                                    subword_vocab_file,
                                    file_type,
                                    reader_type=reader_type,
                                    record_keys=record_keys,
                                    lower=lower,
                                    extra_tokens=extra_tokens,
                                    subword_type=subword_type)
    vocab = reader.build_vocab()
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=embed_type)
    vocabs = preproc_data['vocab']
    os.makedirs(basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(basedir, 'vocabs.json'))
    embeddings = preproc_data['embeddings']
    logger.info("Loaded embeddings")
    train_set = reader.load(train_file, vocabs)
    valid_set = reader.load(valid_file,
                            vocabs,
                            distribute=False,
                            shuffle=False)
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=batch_size)
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", embed_type)
    if len(rpr_k) == 0 or rpr_k[0] < 1:
        rpr_k = None
    elif len(rpr_k) == 1:
        rpr_k = rpr_k[0]
    else:
        rpr_k = rpr_k

    hps = {
        "dsz": d_model,
        "hsz": d_model,
        "d_ff": d_ff,
        "dropout": dropout,
        "num_heads": num_heads,
        "layers": num_layers,
        "encoder_type": "transformer",
        "decoder_type": "transformer",
        "src_lengths_key": "x_lengths",
        "d_k": d_k,
        "layer_drop": layer_drop,
        "rpr_k": rpr_k,
        "ra_type": ra_type,
        "transformer_type": transformer_type
    }
    model = TiedEmbeddingsSeq2SeqModel({'x': embeddings}, None, **hps)
    model.to(device)

    loss_function = model.create_loss(label_smoothing=label_smoothing)
    loss_function.to(device)
    logger.info("Created model and loss")
    steps_per_epoch = len(train_loader) // num_gpus
    valid_steps = len(valid_loader)
    update_on = steps_per_epoch // saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch per GPU: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )
    lr_decay = get_lr_decay(lr_scheduler,
                            lr,
                            steps_per_epoch,
                            epochs,
                            logger,
                            decay_steps=lr_decay_steps,
                            decay_rate=lr_decay_rate,
                            alpha=lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(warmup_steps, lr=lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=lr)
    global_step = 0
    start_epoch = 0
    if restart_from:

        global_step, start_epoch = reload_from_checkpoint(
            restart_from, restart_tt, model, steps_per_epoch)
        logger.info(
            "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d",
            restart_from, global_step, start_epoch + 1)
    optimizer = OptimizerManager(model,
                                 global_step,
                                 optim=optim,
                                 lr=lr,
                                 lr_function=lr_sched,
                                 weight_decay=weight_decay)
    logger.info("Model has {:,} parameters".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    # Prepare model for distributed training if needed
    if distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[device],
                                        output_device=device)
        logger.info("Model located on %d", local_rank)
    model_base = os.path.join(basedir, 'checkpoint')
    steps = global_step
    timer = Timer()
    for epoch in range(start_epoch, epochs):
        avg_loss = Average('average_train_loss')
        metrics = {}
        optimizer.zero_grad()
        timer.start()
        model.train()
        train_itr = iter(train_loader)
        for i in range(steps_per_epoch):
            batch = next(train_itr)
            steps += 1

            x, y = batch
            loss = run_step(x, y, model, loss_function, distributed)
            loss.backward()
            avg_loss.update(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % report_on == 0:
                logging.info(avg_loss)
            if (i + 1) % update_on == 0 and local_rank < 1:
                elapsed = timer.elapsed(True)
                logging.info('elapsed time this epoch %d min', elapsed)
                logging.info('elapsed step time %f steps/min', i / elapsed)
                logging.info('LR: %f', optimizer.current_lr)
                save_checkpoint(model,
                                model_base,
                                steps,
                                tick_type='step',
                                save_npz=save_npz)

        # How much time elapsed in minutes
        elapsed = timer.elapsed(True)
        train_avg_loss = avg_loss.avg
        # This is the average training token-level loss across all machines
        # This is the token-level training perplexity
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_loss'] = train_avg_loss
        if local_rank < 1:
            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            model.eval()
            valid_itr = iter(valid_loader)
            for j in range(valid_steps):
                with torch.no_grad():
                    batch = next(valid_itr)
                    x, y = batch
                    loss = run_step(x, y, model, loss_function, distributed)
                avg_valid_loss.update(loss.item())

            valid_avg_loss = avg_valid_loss.avg

            elapsed = timer.elapsed(True)
            metrics['valid_elapsed_min'] = elapsed

            metrics['average_valid_loss'] = valid_avg_loss
            logger.info(metrics)
            save_checkpoint(model,
                            model_base,
                            epoch,
                            tick_type='epoch',
                            save_npz=save_npz)
def main() -> None:
    """Pretrain a Transformer dual-encoder ("paired") model with TensorFlow.

    Parses command-line options, builds distributed ``tf.data`` pipelines over
    paired training/validation shards (optionally curriculum-bucketed by
    sequence length), constructs a ``PairedModel`` with a contrastive or
    symmetric loss, and runs an epoch-based training loop under a
    ``tf.distribute`` strategy, saving TF checkpoints (and optionally NPZ
    snapshots) along the way.
    """
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir", type=str, required=True, help='Training directory')
    parser.add_argument("--valid_dir", type=str, required=True, help='Validation directory')
    parser.add_argument("--train_md", type=str, help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument("--valid_md", type=str, help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--dataset_key", default="tlm",
                        help="dataset key for basedir")
    parser.add_argument("--embed_type", type=str, default='default',
                        choices=["default", "positional", "learned-positional"],
                        help="register label of the embeddings")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--d_k", type=int, default=None, help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--num_train_workers", type=int, default=4, help="Number train workers")
    parser.add_argument("--distribute", type=str, default="mirror", choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep", type=str, help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx", type=int, default=256, help="Max input length (x)")
    parser.add_argument("--file_type", default='tfrecord', choices=['json', 'jsonl', 'tfrecord'], help="Glob pattern for data")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch Size")
    parser.add_argument("--subword_model_file", type=str, help="The BPE model file", required=False)
    parser.add_argument("--subword_vocab_file", type=str, help="The BPE subword vocab", required=True)
    parser.add_argument("--subword_type", type=str, choices=["bpe", "wordpiece", "sentencepiece"], default="bpe")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--layer_drop", type=float, default=0.0, help="LayerDrop to apply")
    parser.add_argument("--ff_pdrop", type=float, default=0.1, help="Dropout in the dense stack")
    parser.add_argument("--optim", default="adamw", type=str, help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr", type=float, default=4.0e-4, help="Learning rate")
    parser.add_argument("--clip", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=1.0e-2, help="Weight decay")
    parser.add_argument("--epochs", type=int, default=32, help="Num training epochs")
    parser.add_argument("--restart", type=str2bool, help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps", type=int, default=10000, help="Num warmup steps")
    parser.add_argument("--saves_per_epoch", type=int, default=10, help="The number of checkpoints to save per epoch")
    parser.add_argument('--rpr_k',
                        help='Relative attention positional sizes pass 0 if you dont want relative attention',
                        type=int, default=[8], nargs='+')
    parser.add_argument('--ra_type', type=str, help="Specify a relative attention type")
    parser.add_argument("--reduction_d_k", type=int, default=64, help="Dimensions of Key and Query in the single headed"
                                                                      "reduction layers")
    parser.add_argument("--reduction_type", type=str, default="2ha",
                        help="If using a dual encoder, specifies the reduction type")
    parser.add_argument("--stacking_layers", type=int, nargs='+', default=[])
    parser.add_argument("--loss", type=str, default='symmetric',
                        choices=['contrastive', 'symmetric'])
    parser.add_argument("--learn_temp", type=str2bool, default=True,
                        help="If 'constrastive' or 'symmetric' loss, should we learn the temperature scaling")
    parser.add_argument("--init_temp", type=float,
                        help="Initialize the temperature for 'contrastive' or 'symmetric' loss")
    parser.add_argument("--npz", help="Should we write out NPZ files?", type=str2bool, default=False)
    parser.add_argument("--tb", help="Turn on tensorboard?", type=str2bool, default=False)
    parser.add_argument("--convert_only", help="Should we just convert this file to NPZ and exit?", type=str2bool, default=False)
    parser.add_argument("--extra_tokens", help="What extra tokens should we use", nargs="+", default=["[CLS]", "[MASK]"])
    args = parser.parse_args()

    # TPUs can only consume TFRecord inputs, so fail fast on a bad combination
    if args.tpu_ep is not None and args.file_type != 'tfrecord':
        raise Exception("For TPUs, TFRecord format is required!")

    SET_TRAIN_FLAG(True)

    # Conversion mode piggy-backs on the restart path: restore the latest
    # checkpoint, run a single step, dump an NPZ, and exit (see the loop below)
    if args.convert_only:
        args.restart = True

    if args.basedir is None:
        args.basedir = 'paired-{}-bpe-{}'.format(args.dataset_key, os.getpid())
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        # Install a default tf.summary writer so the scalar calls in the loop
        # below are recorded; without it they are no-ops
        logdir = f"logs/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep, len(get_env_gpus(None)))
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    Vec1D = get_subword_vec1d(args.subword_type)
    vectorizer = Vec1D(model_file=args.subword_model_file,
                       vocab_file=args.subword_vocab_file,
                       mxlen=args.nctx,
                       extra_tokens=args.extra_tokens)
    # Build embeddings over the subword vocab; indices are preserved so the
    # saved vocab file below stays aligned with the embedding table
    preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, known_vocab=vectorizer.vocab,
                                                       preserve_vocab_indices=True,
                                                       embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    train_md = args.train_md if args.train_md else os.path.join(args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)
    # Curriculum mode: the metadata maps sequence-length buckets to sample
    # counts, and each bucket lives in a subdirectory named after its length
    is_curriculum = True if isinstance(num_train_samples, Mapping) else False

    def dataset_train_fn(input_context):
        """Per-input-pipeline dataset factory handed to the distribute strategy.

        Shards the training data across pipelines and, in curriculum mode,
        concatenates one batched dataset per sequence-length bucket, scaling
        the batch size up for shorter sequences (nctx // bucket_len).
        """
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = None
        num_shards = input_context.num_input_pipelines
        index = input_context.input_pipeline_id
        if is_curriculum:
            for sub in num_train_samples.keys():
                train_curr_dir = os.path.join(args.train_dir, str(sub))
                batchsz_scale_factor = args.nctx // sub
                this_batchsz = base_batchsz * batchsz_scale_factor
                curr_ds = get_dataset(train_curr_dir, args.file_type, args.num_train_workers, num_shards, index).batch(this_batchsz, drop_remainder=True)
                if ds is None:
                    ds = curr_ds
                else:
                    ds = ds.concatenate(curr_ds)
        else:
            ds = get_dataset(args.train_dir, args.file_type, args.num_train_workers, num_shards, index).batch(base_batchsz)
        return ds

    train_loader = strategy.experimental_distribute_datasets_from_function(dataset_train_fn)

    def dataset_test_fn(input_context):
        """Validation counterpart of `dataset_train_fn` (shuffle disabled)."""
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = None
        num_shards = input_context.num_input_pipelines
        index = input_context.input_pipeline_id
        if is_curriculum:
            for sub in num_valid_samples.keys():
                valid_curr_dir = os.path.join(args.valid_dir, str(sub))
                batchsz_scale_factor = args.nctx // sub
                this_batchsz = base_batchsz * batchsz_scale_factor
                curr_ds = get_dataset(valid_curr_dir, args.file_type, args.num_train_workers, num_shards, index, shuffle=False).batch(
                    this_batchsz, drop_remainder=True)
                if ds is None:
                    ds = curr_ds
                else:
                    ds = ds.concatenate(curr_ds)
        else:
            ds = get_dataset(args.valid_dir, args.file_type, args.num_train_workers, num_shards, index, shuffle=False).batch(base_batchsz)

        return ds
    valid_loader = strategy.experimental_distribute_datasets_from_function(dataset_test_fn)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = preproc_data['embeddings']
    logger.info("Loaded embeddings")

    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)
    # Normalize rpr_k: [0] (or empty) disables relative attention, a single
    # value is unwrapped to a scalar, multiple values are passed as a list
    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        rpr_k = None
    elif len(args.rpr_k) == 1:
        rpr_k = args.rpr_k[0]
    else:
        rpr_k = args.rpr_k

    logger.info("Creating dual encoder")
    model = PairedModel(embeddings, args.d_model, args.d_ff, args.dropout, args.num_heads, args.num_layers, rpr_k=rpr_k,
                        d_k=args.d_k, reduction_d_k=args.reduction_d_k, stacking_layers=args.stacking_layers,
                        ffn_pdrop=args.ff_pdrop, reduction_type=args.reduction_type, freeze_encoders=False,
                        ra_type=args.ra_type)

    loss_function = model.create_loss(loss_type=args.loss, init_temp=args.init_temp, learn_temp=args.learn_temp)
    logger.info("Loaded model and loss")
    if is_curriculum:
        # Per-bucket step counts: the effective batch size for a bucket of
        # length k is batch_size * (nctx / k), mirroring the dataset factories
        steps_per_epoch = 0
        steps_per_valid_epoch = 0
        for k, v in num_train_samples.items():
            steps_per_epoch += int(num_train_samples[k] // (args.batch_size * (args.nctx / k)))
        for k, v in num_valid_samples.items():
            steps_per_valid_epoch += int(num_valid_samples[k] // (args.batch_size * (args.nctx / k)))

    else:
        steps_per_epoch = num_train_samples // args.batch_size
        steps_per_valid_epoch = num_valid_samples // args.batch_size

    # NOTE(review): if steps_per_epoch < saves_per_epoch, update_on becomes 0
    # and the modulo below would raise ZeroDivisionError — assumes datasets
    # are large enough that this never happens
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps.")

    # LR regimen: linear warmup composed with cosine decay over the full run
    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs, lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)
    optimizer = EagerOptimizer(loss_function, optim=args.optim, lr_function=lr_sched, weight_decay=args.weight_decay,
                               clip=args.clip, lr=args.lr)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer.optimizer, model=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    directory=args.basedir,
                                                    max_to_keep=5)

    if args.restart:
        # The global step gets automatically updated here
        # so we dont have to worry about our LR regimen
        checkpoint.restore(checkpoint_manager.latest_checkpoint)

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = optimizer.update(model, x, y, num_replicas)
        return per_replica_loss

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs: An (x, y) pair of per-replica batch tensors
        :return: The loss summed over replicas
        """
        per_replica_loss = strategy.run(_replicated_train_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = loss_function(model, x, y) / num_replicas
        return per_replica_loss

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs: An (x, y) pair of per-replica batch tensors
        :return: The loss summed over replicas
        """
        per_replica_loss = strategy.run(_replicated_test_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

    # This is the training loop
    start_epoch = 0
    timer = Timer()
    with strategy.scope():

        for epoch in range(start_epoch, args.epochs):
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            metrics = {}
            timer.start()
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):
                loss = _distributed_train_step(next(train_iter))
                avg_loss.update(loss.numpy().item())
                tf.summary.scalar("train_loss", data=loss, step=optimizer.global_step)

                if args.convert_only:
                    # Single-step conversion path: snapshot to NPZ and bail out
                    logger.warning("Convert only flag specified.  Stopping after one step")
                    steps = optimizer.global_step.numpy()
                    npz_checkpoint = os.path.join(args.basedir, f'checkpoint-step-{steps}.npz')
                    save_transformer_de_npz(model, npz_checkpoint)
                    return

                if (i + 1) % report_on == 0:
                    logging.info(avg_loss)
                if (i + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logging.info('elapsed time this epoch %d min', elapsed)
                    logging.info('elapsed step time %f steps/min', i / elapsed)
                    checkpoint_manager.save()
                    if args.npz:
                        steps = optimizer.global_step.numpy()
                        npz_checkpoint = os.path.join(args.basedir, f'checkpoint-step-{steps}.npz')
                        save_transformer_de_npz(model, npz_checkpoint)

            # How much time elapsed in minutes
            train_token_loss = avg_loss.avg
            # This is the average training token-level loss across all machines
            # This is the token-level training perplexity
            train_token_ppl = math.exp(train_token_loss)
            metrics['train_elapsed_min'] = timer.elapsed(True)
            metrics['average_train_loss'] = train_token_loss
            metrics['train_ppl'] = train_token_ppl
            metrics['lr'] = float(lr_sched(tf.cast(optimizer.global_step, tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                valid_loss = _distributed_test_step(next(valid_iter))
                tf.summary.scalar('valid_loss', data=valid_loss, step=optimizer.global_step)
                avg_valid_loss.update(valid_loss.numpy().item())

            valid_token_loss = avg_valid_loss.avg
            valid_token_ppl = math.exp(valid_token_loss)

            metrics['valid_elapsed_min'] = timer.elapsed(True)
            metrics['average_valid_loss'] = valid_token_loss
            metrics['average_valid_word_ppl'] = valid_token_ppl
            logger.info(json.dumps(metrics, indent=4))
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--dataset_key",
                        type=str,
                        default='wikitext-2',
                        help="key from DATASETS global")
    parser.add_argument("--train_file",
                        type=str,
                        help='Optional file path to use for train file')
    parser.add_argument("--valid_file",
                        type=str,
                        help='Optional file path to use for valid file')
    parser.add_argument("--dataset_cache",
                        type=str,
                        default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--cache_features", type=str2bool, default=True)
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        help=
        "register label of the embeddings, so far support positional or learned-positional"
    )
    parser.add_argument("--d_model",
                        type=int,
                        default=410,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2100, help="FFN dimension")
    parser.add_argument(
        "--d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads",
                        type=int,
                        default=10,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=16,
                        help="Number of layers")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch Size")
    parser.add_argument("--tokens",
                        choices=["words", "chars", "bpe", "wordpiece"],
                        default="wordpiece",
                        help="What tokens to use")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="If using subwords, pass this",
                        default='bert-base-uncased')
    parser.add_argument(
        "--subword_vocab_file",
        type=str,
        help="If using subwords with separate vocab file, pass here")
    parser.add_argument(
        "--subword_special_tokens",
        type=str,
        nargs='*',
        help=
        "When using wordpiece vectorizer, this list provide special tokens to the never_split "
        "argument of BertTokenizer. These special tokens should also be in the customized vocab "
        "file so that they have their indices.")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--lr_scheduler",
                        type=str,
                        help="The type of learning rate decay scheduler",
                        default='cosine')
    parser.add_argument("--lr_decay_steps",
                        type=int,
                        help="decay steps of lr scheduler")
    parser.add_argument("--lr_decay_rate",
                        type=float,
                        help="decay rate of lr scheduler")
    parser.add_argument("--lr_alpha",
                        type=float,
                        help="parameter alpha for cosine decay scheduler")
    parser.add_argument("--optim",
                        default="adam",
                        type=str,
                        help="Optimizer to use (defaults to adam)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=0.25,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.0,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=20,
                        help="Num training epochs")
    parser.add_argument(
        "--restart_from",
        type=str,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument(
        "--restart_tt",
        type=str,
        help="Optional param for legacy checkpoints (step|epoch)")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=1000,
                        help="Num warmup steps")
    parser.add_argument(
        "--saves_per_epoch",
        type=int,
        default=5,
        help="The number of checkpoints to save within an epoch")
    parser.add_argument("--mlm",
                        type=str2bool,
                        default=False,
                        help="Use Masked Language Model (MLM) objective")
    parser.add_argument(
        '--rpr_k',
        help=
        'Relative attention positional sizes pass 0 if you dont want relative attention',
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed",
                        type=str2bool,
                        default=False,
                        help="Are we doing distributed training?")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help=
        "Local rank for distributed training (-1 means use the environment variables to find)"
    )
    parser.add_argument("--chars_per_word",
                        type=int,
                        default=40,
                        help="How many max characters per word")

    args = parser.parse_args()

    if args.train_file and not args.valid_file:
        logger.error(
            "If you provide a train_file, you must provide a valid_file")
        return

    if not args.train_file and args.valid_file:
        logger.error(
            "If you provide a valid_file, you must also provide a train_file")
        return

    if args.tokens == "chars" and args.mlm:
        logger.error(
            "Character composition cannot currently be used with the MLM objective"
        )

    if args.basedir is None:
        args.basedir = 'transformer-{}-{}-{}'.format(args.dataset_key,
                                                     args.tokens, os.getpid())
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("Cache directory [%s]", args.dataset_cache)

    num_gpus = get_num_gpus_multiworker()
    args.distributed = args.distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    if args.distributed:
        args.device = init_distributed(args.local_rank)

    if args.train_file:
        dataset = {
            'train_file': args.train_file,
            'valid_file': args.valid_file
        }
    else:
        dataset = DataDownloader(DATASETS[args.dataset_key],
                                 args.dataset_cache).download()
    if args.subword_special_tokens is None:
        special_tokens = ()
    else:
        special_tokens = tuple(args.subword_special_tokens)
    reader = create_reader(args.tokens, args.nctx, args.chars_per_word,
                           args.subword_model_file, args.subword_vocab_file,
                           special_tokens)

    preproc_data = load_embed_and_vocab(args.tokens, reader, dataset,
                                        args.dataset_key, args.embed_type,
                                        args.d_model, args.cache_features)

    vocabs = preproc_data['vocabs']
    if args.mlm:
        mask_from = vocabs['x']
        vocab_size = len(mask_from)
        mask_value = mask_from.get("[MASK]", mask_from.get("<MASK>", -1))
        if mask_value == -1:
            logger.error(
                "We could not find a suitable masking token in the vocab")
            return
    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs['x'], os.path.join(args.basedir, 'vocabs.json'))
    embeddings = preproc_data['embeddings']
    valid_num_words = preproc_data['valid_num_words']
    tgt_key = preproc_data['tgt_key']
    logger.info("Loaded embeddings")

    train_set = load_data_caching(args.tokens, reader, dataset, 'train_file',
                                  vocabs, args.cache_features, logger)
    valid_set = load_data_caching(args.tokens, reader, dataset, 'valid_file',
                                  vocabs, args.cache_features, logger)
    logger.info("valid. tokens [%s], valid. words [%s]",
                valid_set.tensors[-1].numel(), valid_num_words)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set) if args.distributed else None
    train_loader = DataLoader(train_set,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              shuffle=(not args.distributed))
    valid_loader = DataLoader(valid_set,
                              batch_size=args.batch_size,
                              shuffle=False)
    logger.info("Loaded datasets")

    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        rpr_k = None
    elif len(args.rpr_k) == 1:
        rpr_k = args.rpr_k[0]
    else:
        rpr_k = args.rpr_k

    TLM = TransformerMaskedLanguageModel if args.mlm else TransformerLanguageModel
    model = TLM.create(embeddings,
                       hsz=args.d_model,
                       d_ff=args.d_ff,
                       tie_weights=(args.tokens != 'chars'),
                       dropout=args.dropout,
                       gpu=False,
                       num_heads=args.num_heads,
                       layers=args.num_layers,
                       rpr_k=rpr_k,
                       d_k=args.d_k,
                       src_keys=['x'],
                       tgt_key=tgt_key)
    model.to(args.device)
    loss_function = model.create_loss()
    loss_function.to(args.device)

    logger.info("Loaded model and loss")

    # in this case (train_loader is not iterator) the division by number of gpus is automatically taken care of by
    # torch.DataLoader
    steps_per_epoch = len(train_loader)
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = update_on // 10
    logger.info(
        f"Steps per epoch per GPU: {steps_per_epoch}. Saving a checkpoint every {update_on} steps."
    )
    lr_decay = get_lr_decay(args.lr_scheduler,
                            args.lr,
                            steps_per_epoch,
                            args.epochs,
                            logger,
                            decay_steps=args.lr_decay_steps,
                            decay_rate=args.lr_decay_rate,
                            alpha=args.lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=args.lr)

    global_step = 0
    start_epoch = 0
    if args.restart_from:
        model.load_state_dict(torch.load(args.restart_from))
        vec = args.restart_from.split("-")

        if args.restart_tt:
            tick_type = args.restart_tt
        else:
            tick_type = vec[-2]
        step_num = int(vec[-1].split(".")[0])
        if tick_type == 'epoch':
            start_epoch = step_num
            global_step = start_epoch * steps_per_epoch

        else:
            start_epoch = step_num // steps_per_epoch
            global_step = step_num

        logger.info(
            "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d",
            args.restart_from, global_step, start_epoch + 1)

    optimizer = OptimizerManager(model,
                                 global_step,
                                 optim=args.optim,
                                 lr=args.lr,
                                 lr_function=lr_sched,
                                 weight_decay=args.weight_decay)
    logger.info("Model has {:,} parameters".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # Prepare model for distributed training if needed
    if args.distributed:
        # This program assume pure data parallelism, each model is on a single gpu
        # If we wanted to support model and data parallelism we would need to update
        # the selection of gpus based on rank, it would need to select multiple ids
        # based on rank, here we select only a single gpu and use it for input and
        # output.
        model = DistributedDataParallel(model,
                                        device_ids=[args.device],
                                        output_device=args.device)
        logger.info("Model located on %s", args.device)

    # This is the training loop
    steps = global_step
    model_base = os.path.join(args.basedir, 'checkpoint')
    for epoch in range(start_epoch, args.epochs):
        avg_loss = Average('average_train_loss')
        metrics = {}
        optimizer.zero_grad()

        if args.distributed:
            train_sampler.set_epoch(epoch)

        start = time.time()
        model.train()
        for i, batch in enumerate(train_loader):
            steps += 1
            x, y = batch
            inputs = x.to(args.device)
            labels = y.to(args.device)
            if args.mlm:
                inputs, labels, _ = on_demand_mlm_masking(
                    inputs, labels, mask_value, vocab_size)
            inputs = {'x': inputs}
            labels = labels.transpose(0, 1).contiguous()
            logits = model(inputs, None)[0].transpose(0, 1).contiguous()
            if args.mlm:
                loss = loss_function(logits, labels)
            else:
                shift_logits = logits[:-1]
                shift_labels = labels[1:]
                loss = loss_function(shift_logits, shift_labels)
            loss.backward()
            avg_loss.update(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % report_on == 0:
                logging.info(avg_loss)
            if (i + 1) % update_on == 0 and args.local_rank < 1:
                elapsed = (time.time() - start) / 60
                logging.info('elapsed time this epoch %d min', elapsed)
                logging.info('elapsed step time %f steps/min', i / elapsed)
                save_checkpoint(model, model_base, steps, tick_type='step')

        # How much time elapsed in minutes
        elapsed = (time.time() - start) / 60
        train_token_loss = avg_loss.avg
        # This is the average training token-level loss across all machines
        # This is the token-level training perplexity
        train_token_ppl = math.exp(train_token_loss)
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_loss'] = train_token_loss
        metrics['train_ppl'] = train_token_ppl
        model_base = os.path.join(args.basedir, 'checkpoint')
        avg_valid_loss = Average('average_valid_loss')
        start = time.time()
        model.eval()
        for batch in valid_loader:
            with torch.no_grad():
                x, y = batch
                inputs = x.to(args.device)
                labels = y.to(args.device)
                if args.mlm:
                    inputs, labels, _ = on_demand_mlm_masking(
                        inputs, labels, mask_value, vocab_size)
                inputs = {'x': inputs}
                labels = labels.transpose(0, 1).contiguous()
                logits = model(inputs, None)[0].transpose(0, 1).contiguous()
                if args.mlm:
                    loss = loss_function(logits, labels)
                else:
                    shift_logits = logits[:-1]
                    shift_labels = labels[1:]
                    loss = loss_function(shift_logits, shift_labels)
                avg_valid_loss.update(loss.item())

        valid_token_loss = avg_valid_loss.avg
        valid_token_ppl = math.exp(valid_token_loss)

        elapsed = (time.time() - start) / 60
        metrics['valid_elapsed_min'] = elapsed

        metrics['average_valid_loss'] = valid_token_loss
        if args.tokens in ['bpe', 'wordpiece']:
            metrics['valid_token_ppl'] = valid_token_ppl
            metrics['average_valid_word_ppl'] = math.exp(
                valid_token_loss * valid_set.tensors[-1].numel() /
                valid_num_words)
        else:
            metrics['average_valid_word_ppl'] = valid_token_ppl
        logger.info(metrics)

        if args.local_rank < 1:
            save_checkpoint(model, model_base, epoch, save_npz=True)
# Example #15
def main():
    """Pre-train a Transformer language model (MLM or causal LM) with TensorFlow.

    Parses CLI arguments, builds a subword vectorizer and embeddings, constructs
    the model via ``create_model``, then runs a distributed training loop
    (mirror/nccl/tpu strategy) with gradient accumulation, periodic
    checkpointing (TF checkpoints and optional NPZ export), and a per-epoch
    validation pass.  Supports an optional length-curriculum layout where the
    train/valid dirs contain per-length subdirectories.
    """
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir", type=str, required=True, help='Training directory')
    parser.add_argument("--valid_dir", type=str, required=True, help='Validation directory')
    parser.add_argument("--train_md", type=str, help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument("--valid_md", type=str, help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--dataset_key", default="tlm",
                        help="dataset key for basedir")
    parser.add_argument("--embed_type", type=str, default='default',
                        choices=["default", "positional", "learned-positional"],
                        help="register label of the embeddings")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--d_k", type=int, default=None, help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--num_train_workers", type=int, default=4, help="Number train workers")
    parser.add_argument("--distribute", type=str, default="mirror", choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep", type=str, help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx", type=int, default=512, help="Max input length")
    parser.add_argument("--file_type", default='tfrecord', choices=['json', 'tfrecord'], help="Glob pattern for data")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch Size")
    parser.add_argument("--subword_model_file", type=str, help="The BPE model file", required=False)
    parser.add_argument("--subword_vocab_file", type=str, help="The BPE subword vocab", required=False)
    parser.add_argument("--subword_type", type=str, choices=["bpe", "wordpiece", "sentencepiece"], default="bpe")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--ffn_pdrop", type=float, default=0.0, help="Dropout in the dense stack")
    parser.add_argument("--layer_drop", type=float, default=0.0, help="LayerDrop to apply")
    parser.add_argument("--optim", default="adamw", type=str, help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr", type=float, default=4.0e-4, help="Learning rate")
    parser.add_argument("--clip", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=1.0e-2, help="Weight decay")
    parser.add_argument("--epochs", type=int, default=32, help="Num training epochs")
    parser.add_argument("--restart", type=str2bool, help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps", type=int, default=10000, help="Num warmup steps")
    parser.add_argument("--causal", type=str2bool, default=False, help="Use CLM (causal) instead of MLM")
    parser.add_argument("--mlp", type=str2bool, default=False, help="Use Gated MLP")
    parser.add_argument("--saves_per_epoch", type=int, default=10, help="The number of checkpoints to save per epoch")
    parser.add_argument('--rpr_k',
                        help='Relative attention positional sizes pass 0 if you dont want relative attention',
                        type=int, default=[8], nargs='+')
    parser.add_argument('--rpr_value_on', type=str2bool, default=True,
                        help="In relative attention, whether add positional correction to values in addition to the "
                             "correction to attention matrix")
    parser.add_argument('--ra_type', type=str, help="Specify a relative attention type")
    parser.add_argument('--windowed_ra', type=str2bool, default=False, help="whether prevent attention beyond rpr_k")
    parser.add_argument("--strategy", help="Training strategy, defaults to `mirror`", choices=["mirror"])
    parser.add_argument("--npz", help="Should we write out NPZ files?", type=str2bool, default=False)
    parser.add_argument("--tb", help="Turn on tensorboard?", type=str2bool, default=False)
    parser.add_argument("--convert_only", help="Should we just convert this file to NPZ and exit?", type=str2bool, default=False)
    parser.add_argument("--extra_tokens", help="What extra tokens should we use", nargs="+", default=["[CLS]", "[MASK]"])
    parser.add_argument("--eps", help="Epsilon", default=1e-6, type=float)
    # Fixed copy-pasted help text (previously said "Epsilon")
    parser.add_argument("--beta2", help="Beta 2 (second-moment decay) for Adam-family optimizers", default=0.98, type=float)
    parser.add_argument("--grad_accum", help="Number of iterations to accum grads", default=1, type=int)
    parser.add_argument("--transformer_type", help="Transformer layer type")
    args = parser.parse_args()
    SET_TRAIN_FLAG(True)

    # Converting implies restoring the latest checkpoint first
    if args.convert_only:
        args.restart = True

    if args.basedir is None:
        args.basedir = f'lm-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"{args.basedir}/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    Vec1D = get_subword_vec1d(args.subword_type)
    vectorizer = Vec1D(model_file=args.subword_model_file,
                       vocab_file=args.subword_vocab_file,
                       mxlen=args.nctx,
                       extra_tokens=args.extra_tokens)

    vocab = {'x': vectorizer.vocab}
    preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, known_vocab=vocab['x'],
                                                       preserve_vocab_indices=True,
                                                       embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    train_md = args.train_md if args.train_md else os.path.join(args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)

    # Curriculum layout: the metadata maps sequence-length -> sample count
    is_curriculum = isinstance(num_train_samples, Mapping)

    def dataset_train_fn(input_context):
        """Per-input-pipeline training dataset builder for the distribute strategy."""
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = None
        num_shards = input_context.num_input_pipelines
        index = input_context.input_pipeline_id
        if is_curriculum:
            for sub in num_train_samples.keys():
                train_curr_dir = os.path.join(args.train_dir, str(sub))
                # Shorter sequences allow a proportionally larger batch
                batchsz_scale_factor = args.nctx // sub
                this_batchsz = base_batchsz * batchsz_scale_factor
                curr_ds = get_dataset(train_curr_dir, args.file_type, args.num_train_workers, num_shards, index, causal=args.causal).batch(this_batchsz, drop_remainder=True)
                if ds is None:
                    ds = curr_ds
                else:
                    ds = ds.concatenate(curr_ds)
        else:
            ds = get_dataset(args.train_dir, args.file_type, args.num_train_workers, num_shards, index, causal=args.causal).batch(base_batchsz)
        return ds
    train_loader = strategy.experimental_distribute_datasets_from_function(dataset_train_fn)

    def dataset_test_fn(input_context):
        """Per-input-pipeline validation dataset builder (no shuffle)."""
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        num_shards = input_context.num_input_pipelines
        index = input_context.input_pipeline_id
        ds = None
        if is_curriculum:
            for sub in num_valid_samples.keys():
                valid_curr_dir = os.path.join(args.valid_dir, str(sub))
                batchsz_scale_factor = args.nctx // sub
                this_batchsz = base_batchsz * batchsz_scale_factor
                curr_ds = get_dataset(valid_curr_dir, args.file_type, args.num_train_workers, num_shards, index, causal=args.causal).batch(
                    this_batchsz, drop_remainder=True)
                if ds is None:
                    ds = curr_ds
                else:
                    ds = ds.concatenate(curr_ds)
        else:
            ds = get_dataset(args.valid_dir, args.file_type, args.num_train_workers, num_shards, index, shuffle=False, causal=args.causal).batch(base_batchsz)
        return ds
    valid_loader = strategy.experimental_distribute_datasets_from_function(dataset_test_fn)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)
    model = create_model(args, embeddings)
    if isinstance(model, GatedMLPLanguageModel) and is_curriculum:
        raise Exception("Variable tensor lengths not currently supported for gMLP")
    logger.info("Loaded model and loss")

    # steps_per_epoch counts *optimizer* steps, against the effective
    # (gradient-accumulated) batch size
    eff_batch_size = args.batch_size * args.grad_accum
    logger.info(f"eff batch size: {eff_batch_size}, {args.batch_size}(b) x {args.grad_accum}(ga)")
    if is_curriculum:
        steps_per_epoch = 0
        steps_per_valid_epoch = 0
        for k, v in num_train_samples.items():
            steps_per_epoch += int(num_train_samples[k] // (eff_batch_size * (args.nctx / k)))
        for k, v in num_valid_samples.items():
            steps_per_valid_epoch += int(num_valid_samples[k] // (args.batch_size * (args.nctx / k)))

    else:
        steps_per_epoch = num_train_samples // eff_batch_size
        steps_per_valid_epoch = num_valid_samples // args.batch_size
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps.")

    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs, lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)
    optimizer = EagerOptimizer(loss_function, optim=args.optim, lr_function=lr_sched, weight_decay=args.weight_decay, clip=args.clip,
                               lr=args.lr, epsilon=args.eps, beta2=args.beta2)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer.optimizer, model=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    directory=args.basedir,
                                                    max_to_keep=5)

    grad_accum = GradientAccumulator()

    start_epoch = 0
    if args.restart:
        # The global step gets automatically updated here
        # so we dont have to worry about our LR regimen
        checkpoint.restore(checkpoint_manager.latest_checkpoint)
        current_step = optimizer.global_step
        start_epoch = current_step // steps_per_epoch

    def _replicated_forward_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_grads, per_replica_loss = optimizer.get_grads_and_loss(model, {'x': x}, y, num_replicas * args.grad_accum)
        grad_accum(per_replica_grads)
        return per_replica_loss

    def _replicated_optz_step():
        optimizer.apply_grads(model, grad_accum.gradients)

    @tf.function
    def _distributed_optz_step():
        strategy.run(_replicated_optz_step)

    @tf.function
    def _distributed_forward_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_forward_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = loss_function(model, {'x': x}, y) / num_replicas
        return per_replica_loss

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs:
        :return:
        """
        per_replica_loss = strategy.run(_replicated_test_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

    timer = Timer()
    with strategy.scope():

        for epoch in range(start_epoch, args.epochs):
            timer.start()
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            metrics = {}
            train_iter = iter(train_loader)

            step_loss = 0
            # Each optimizer step consumes `grad_accum` forward passes, and
            # steps_per_epoch was computed against the accumulated batch size,
            # so an epoch is steps_per_epoch * grad_accum forward iterations.
            # (Previously this multiplied by batch_size, which overshoots the
            # epoch by a factor of batch_size/grad_accum.)
            iterations = steps_per_epoch * args.grad_accum
            for i in range(iterations):

                try:

                    loss = _distributed_forward_step(next(train_iter))
                    step_loss += loss

                    if (i + 1) % args.grad_accum == 0:
                        # This does a gradient update
                        _distributed_optz_step()
                        # Now reset the gradient accumulator
                        grad_accum.reset()
                        # Now update the loss info
                        tf.summary.scalar("train_loss", data=step_loss, step=optimizer.global_step)
                        avg_loss.update(step_loss.numpy().item())
                        # Now reset the loss
                        step_loss = 0
                        steps = optimizer.global_step.numpy()
                        if (steps + 1) % report_on == 0:
                            logger.info(avg_loss)
                        if (steps + 1) % update_on == 0:
                            elapsed = timer.elapsed(True)
                            logger.info('elapsed time this epoch %d min', elapsed)
                            logger.info('elapsed step time %f steps/min', i/elapsed)
                            checkpoint_manager.save()
                            if args.npz:
                                npz_checkpoint = os.path.join(args.basedir, f'checkpoint-step-{steps}.npz')
                                save_tlm_npz(model, npz_checkpoint)

                except Exception as e:
                    # Best-effort: log and keep training through bad batches
                    logger.error(e)
                    logger.error(f"Exception at training iter {i+1}/{iterations}. Skipping")
                if args.convert_only:
                    logger.warning("Convert only flag specified.  Stopping after one step")
                    steps = optimizer.global_step.numpy()
                    npz_checkpoint = os.path.join(args.basedir, f'checkpoint-step-{steps}.npz')
                    save_tlm_npz(model, npz_checkpoint)
                    return

            # How much time elapsed in minutes
            train_token_loss = avg_loss.avg
            # This is the average training token-level loss across all machines
            # This is the token-level training perplexity
            train_token_ppl = math.exp(train_token_loss)
            metrics['train_elapsed_min'] = timer.elapsed(True)
            metrics['average_train_loss'] = train_token_loss
            metrics['train_ppl'] = train_token_ppl
            metrics['lr'] = float(lr_sched(tf.cast(optimizer.global_step, tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                try:
                    valid_loss = _distributed_test_step(next(valid_iter))
                    tf.summary.scalar('valid_loss', data=valid_loss, step=optimizer.global_step)
                    avg_valid_loss.update(valid_loss.numpy().item())
                except Exception as e:
                    # Best-effort: log and keep validating through bad batches
                    logger.error(f"Exception at validation step {i+1}/{steps_per_valid_epoch}. Skipping")

            valid_token_loss = avg_valid_loss.avg
            valid_token_ppl = math.exp(valid_token_loss)
            metrics['valid_elapsed_min'] = timer.elapsed(True)
            metrics['average_valid_loss'] = valid_token_loss
            metrics['average_valid_word_ppl'] = valid_token_ppl
            logger.info(json.dumps(metrics, indent=4))
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_file",
                        type=str,
                        required=True,
                        help='File path to use for train file')
    parser.add_argument("--valid_file",
                        type=str,
                        required=True,
                        help='File path to use for valid file')
    parser.add_argument("--dataset_key",
                        default="paired",
                        help="dataset key for basedir")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument(
        "--d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--windowed_ra",
                        type=str2bool,
                        default=False,
                        help="whether prevent attention beyond rpr_k")
    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--tgt_nctx",
                        type=int,
                        help="Max output length, default to args.nctx")
    parser.add_argument("--file_type", default='json', help="Suffix for data")
    parser.add_argument("--record_keys", default=['x', 'y'], nargs='+')
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="The BPE model file",
                        required=True)
    parser.add_argument("--subword_vocab_file",
                        type=str,
                        help="The BPE subword vocab",
                        required=True)
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--lr_scheduler",
                        type=str,
                        default='cosine',
                        help="The type of learning rate decay scheduler")
    parser.add_argument("--lr_decay_steps",
                        type=int,
                        help="decay steps of lr scheduler")
    parser.add_argument("--lr_decay_rate",
                        type=float,
                        help="decay rate of lr scheduler")
    parser.add_argument("--lr_alpha",
                        type=float,
                        help="parameter alpha for cosine decay scheduler")
    parser.add_argument("--optim",
                        default="adamw",
                        type=str,
                        help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=1.0e-2,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart_from",
        type=str,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument(
        "--restart_tt",
        type=str,
        help="Optional param for legacy checkpoints (step|epoch)")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=10,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument("--reduction_d_k",
                        type=int,
                        default=64,
                        help="Dimensions of Key and Query in the single headed"
                        "reduction layers")
    parser.add_argument(
        "--reduction_type",
        type=str,
        default="2ha",
        help="If using a dual encoder, specifies the reduction type")
    parser.add_argument(
        "--unfreeze_after_step",
        default=0,
        type=int,
        help=
        "Unfreeze encoders after step, ignored if we dont have a checkpoint")
    parser.add_argument(
        "--stacking_layers",
        type=int,
        nargs='+',
        default=[],
        help="Hidden sizes of the dense stack (ff2 from the convert paper)")
    parser.add_argument("--layer_drop",
                        type=float,
                        default=0.0,
                        help="LayerDrop to apply")
    parser.add_argument("--ff_pdrop",
                        type=float,
                        default=0.1,
                        help="Dropout in the dense stack")

    parser.add_argument("--reader_type",
                        type=str,
                        default='preprocessed',
                        choices=['ntp', 'nsp', 'preprocessed', 'tfrecord'])
    parser.add_argument(
        "--model_type",
        default="dual-encoder",
        choices=["dual-encoder", "encoder-decoder", "transformer-bow"])
    parser.add_argument("--src_begin_tok", type=str, nargs='+', default=[])
    parser.add_argument("--src_end_tok",
                        type=str,
                        nargs='+',
                        default=['<EOS>'])
    parser.add_argument("--tgt_begin_tok",
                        type=str,
                        nargs='+',
                        default=['<GO>'])
    parser.add_argument("--tgt_end_tok",
                        type=str,
                        nargs='+',
                        default=['<EOS>'])
    parser.add_argument('--lower', type=baseline.str2bool, default=False)
    parser.add_argument(
        "--loss",
        type=str,
        default='symmetric',
        choices=['triplet', 'all', 'all_mean', 'contrastive', 'symmetric'])
    parser.add_argument(
        "--learn_temp",
        type=str2bool,
        default=True,
        help=
        "If 'constrastive' or 'symmetric' loss, should we learn the temperature scaling"
    )
    parser.add_argument(
        "--init_temp",
        type=float,
        help="Initialize the temperature for 'contrastive' or 'symmetric' loss"
    )
    parser.add_argument(
        '--rpr_k',
        help=
        'Relative attention positional sizes pass 0 if you dont want relative attention',
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed",
                        type=str2bool,
                        default=False,
                        help="Are we doing distributed training?")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help=
        "Local rank for distributed training (-1 means use the environment variables to find)"
    )
    parser.add_argument("--save_npz",
                        type=str2bool,
                        default=False,
                        help="Whether save npz checkpoint")

    args = parser.parse_args()

    if args.basedir is None:
        args.basedir = '{}-{}-paired-{}-bpe-{}'.format(args.model_type,
                                                       args.reader_type,
                                                       args.dataset_key,
                                                       os.getpid())
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    num_gpus = get_num_gpus_multiworker()
    args.distributed = args.distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    if args.distributed:
        args.device, updated_local_rank = init_distributed(args.local_rank)
        args.local_rank = updated_local_rank

    if not args.tgt_nctx:
        args.tgt_nctx = args.nctx
    reader = MultiFileDatasetReader(args.nctx,
                                    args.tgt_nctx,
                                    args.src_begin_tok,
                                    args.src_end_tok,
                                    args.tgt_begin_tok,
                                    args.tgt_end_tok,
                                    args.subword_model_file,
                                    args.subword_vocab_file,
                                    args.file_type,
                                    reader_type=args.reader_type,
                                    record_keys=args.record_keys,
                                    lower=args.lower)

    vocab = reader.build_vocab()
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = preproc_data['embeddings']
    logger.info("Loaded embeddings")

    train_set = reader.load(args.train_file, vocabs)
    valid_set = reader.load(args.valid_file,
                            vocabs,
                            distribute=False,
                            shuffle=False)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=args.num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size)
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        rpr_k = None
    elif len(args.rpr_k) == 1:
        rpr_k = args.rpr_k[0]
    else:
        rpr_k = args.rpr_k

    model = create_model(embeddings,
                         d_model=args.d_model,
                         d_ff=args.d_ff,
                         dropout=args.dropout,
                         num_heads=args.num_heads,
                         num_layers=args.num_layers,
                         model_type=args.model_type,
                         rpr_k=rpr_k,
                         d_k=args.d_k,
                         reduction_d_k=args.reduction_d_k,
                         stacking_layers=args.stacking_layers,
                         ff_pdrop=args.ff_pdrop,
                         windowed_ra=args.windowed_ra,
                         reduction_type=args.reduction_type,
                         layer_drop=args.layer_drop,
                         logger=logger)

    model.to(args.device)
    if args.model_type == 'encoder-decoder':
        run_step = run_step_s2s
    else:
        run_step = run_step_dual
        logger.info(
            f"Creating {args.loss}, init temperature: {args.init_temp}, learnable: {args.learn_temp}"
        )
    loss_function = model.create_loss(loss_type=args.loss,
                                      init_temp=args.init_temp,
                                      learn_temp=args.learn_temp)
    loss_function.to(args.device)

    logger.info("Created model and loss")

    steps_per_epoch = len(train_loader) // num_gpus
    valid_steps = len(valid_loader)
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch per GPU: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )
    lr_decay = get_lr_decay(args.lr_scheduler,
                            args.lr,
                            steps_per_epoch,
                            args.epochs,
                            logger,
                            decay_steps=args.lr_decay_steps,
                            decay_rate=args.lr_decay_rate,
                            alpha=args.lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=args.lr)

    global_step = 0
    start_epoch = 0

    if args.restart_from:

        if args.unfreeze_after_step > 0 and args.model_type == "dual-encoder":
            logger.info(f"Encoders will be frozen until step %d",
                        args.unfreeze_after_step)
        global_step, start_epoch = reload_from_checkpoint(
            args.model_type, args.restart_from, args.restart_tt, model,
            steps_per_epoch)
        logger.info(
            "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d",
            args.restart_from, global_step, start_epoch + 1)

    target = model if args.model_type == 'encoder-decoder' else loss_function

    optimizer = OptimizerManager(target,
                                 global_step,
                                 optim=args.optim,
                                 lr=args.lr,
                                 lr_function=lr_sched,
                                 weight_decay=args.weight_decay)
    logger.info("Model has {:,} parameters".format(
        sum(p.numel() for p in target.parameters() if p.requires_grad)))
    # Prepare model for distributed training if needed
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.device],
                                        output_device=args.device)
        logger.info("Model located on %d", args.local_rank)

    model_base = os.path.join(args.basedir, 'checkpoint')
    steps = global_step
    timer = Timer()

    for epoch in range(start_epoch, args.epochs):
        avg_loss = Average('average_train_loss')
        metrics = {}
        optimizer.zero_grad()
        timer.start()
        model.train()
        train_itr = iter(train_loader)
        for i in range(steps_per_epoch):
            batch = next(train_itr)

            if steps > args.unfreeze_after_step and hasattr(
                    model, 'freeze') and model.freeze:
                logging.info("Unfreezing encoders at step %d", steps)
                model.freeze = False
            steps += 1

            x, y = batch
            loss = run_step(x, y, model, loss_function, args.device,
                            args.distributed)
            loss.backward()
            avg_loss.update(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % report_on == 0:
                logging.info(avg_loss)
            if (i + 1) % update_on == 0 and args.local_rank < 1:
                elapsed = timer.elapsed(True)
                logging.info('elapsed time this epoch %d min', elapsed)
                logging.info('elapsed step time %f steps/min', i / elapsed)
                logging.info('LR: %f', optimizer.current_lr)
                save_checkpoint(model,
                                model_base,
                                steps,
                                tick_type='step',
                                save_npz=args.save_npz)

        # How much time elapsed in minutes
        elapsed = timer.elapsed(True)
        train_avg_loss = avg_loss.avg
        # This is the average training token-level loss across all machines
        # This is the token-level training perplexity
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_loss'] = train_avg_loss
        if args.local_rank < 1:
            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            model.eval()
            valid_itr = iter(valid_loader)
            for j in range(valid_steps):
                with torch.no_grad():
                    batch = next(valid_itr)
                    x, y = batch
                    loss = run_step(x, y, model, loss_function, args.device,
                                    args.distributed)
                avg_valid_loss.update(loss.item())

            valid_avg_loss = avg_valid_loss.avg

            elapsed = timer.elapsed(True)
            metrics['valid_elapsed_min'] = elapsed

            metrics['average_valid_loss'] = valid_avg_loss
            logger.info(metrics)
            save_checkpoint(model,
                            model_base,
                            epoch,
                            tick_type='epoch',
                            save_npz=args.save_npz)
# ---- Example #17 ----
def train():
    """Pretrain a Transformer language model (MLM or causal) from the CLI.

    Parses all hyperparameters from the command line, builds the multi-file
    dataset reader, vocab and embeddings, constructs a
    ``TransformerMaskedLanguageModel`` (``--mlm``) or
    ``TransformerLanguageModel``, optionally restores weights and progress
    from ``--restart_from``, and runs a distributed-aware training loop that
    checkpoints ``--saves_per_epoch`` times per epoch and validates (and
    saves an npz checkpoint) at the end of each epoch on the lead worker.
    """
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_file",
                        type=str,
                        required=True,
                        help='File path to use for train file')
    parser.add_argument("--valid_file",
                        type=str,
                        required=True,
                        help='File path to use for valid file')
    parser.add_argument("--dataset_key",
                        default="tlm",
                        help="dataset key for basedir")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument(
        "--d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--file_type",
                        default='json',
                        help="Glob pattern for data")
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="The BPE model file",
                        required=True)
    parser.add_argument("--subword_vocab_file",
                        type=str,
                        help="The BPE subword vocab",
                        required=True)
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--ffn_pdrop",
                        type=float,
                        default=0.0,
                        help="Dropout in the dense stack")
    parser.add_argument("--layer_drop",
                        type=float,
                        default=0.0,
                        help="LayerDrop to apply")
    parser.add_argument("--lr_scheduler",
                        type=str,
                        default='cosine',
                        help="The type of learning rate decay scheduler")
    parser.add_argument("--lr_decay_steps",
                        type=int,
                        help="decay steps of lr scheduler")
    parser.add_argument("--lr_decay_rate",
                        type=float,
                        help="decay rate of lr scheduler")
    parser.add_argument("--lr_alpha",
                        type=float,
                        help="parameter alpha for cosine decay scheduler")
    parser.add_argument("--optim",
                        default="adamw",
                        type=str,
                        help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=1.0e-2,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart_from",
        type=str,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--restart_tt",
                        type=str,
                        help="Optional param for legacy checkpoints",
                        choices=['step', 'epoch', 'ignore'])
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=10,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument("--mlm",
                        type=str2bool,
                        default=True,
                        help="Use Masked Language Model (MLM) objective")
    parser.add_argument("--preprocessed",
                        type=str2bool,
                        default=True,
                        help="Has the data already been preprocessed?")
    parser.add_argument(
        '--rpr_k',
        help=
        'Relative attention positional sizes pass 0 if you dont want relative attention',
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument(
        '--rpr_value_on',
        type=str2bool,
        default=True,
        help=
        "In relative attention, whether add positional correction to values in addition to the "
        "correction to attention matrix")
    parser.add_argument("--windowed_ra",
                        type=str2bool,
                        default=False,
                        help="whether prevent attention beyond rpr_k")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed",
                        type=str2bool,
                        default=False,
                        help="Are we doing distributed training?")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help=
        "Local rank for distributed training (-1 means use the environment variables to find)"
    )

    args = parser.parse_args()

    if args.basedir is None:
        args.basedir = 'lm-{}-bpe-{}'.format(args.dataset_key, os.getpid())
    # Only the lead worker (rank -1 single-process, or rank 0) logs at INFO.
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    num_gpus = get_num_gpus_multiworker()
    # Multi-GPU implies distributed even if the flag was not passed.
    args.distributed = args.distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    # If the data was not preprocessed, apply MLM masking on the fly per batch.
    do_on_demand_masking = args.mlm and not args.preprocessed
    if do_on_demand_masking:
        logger.info("On-demand masking is turned on")
    if args.distributed:
        args.device, updated_local_rank = init_distributed(args.local_rank)
        args.local_rank = updated_local_rank

    if args.file_type == 'tfrecord':
        reader_type = 'tfrecord'
    elif args.preprocessed:
        reader_type = 'preprocessed'
    else:
        reader_type = 'lang'
    reader = MultiFileDatasetReader(
        src_nctx=args.nctx,
        model_file=args.subword_model_file,
        vocab_file=args.subword_vocab_file,
        file_type=args.file_type,
        reader_type=reader_type,
        record_keys=['x', 'y'] if args.mlm else ['x'])

    # This looks a bit funny but the streaming reader ignores our vocab and gives us the one from the subword_model
    # However, we do need to get counts from our dataset for validation so we can calculate the perplexity
    vocab = reader.build_vocab([args.valid_file])
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    train_set = reader.load(args.train_file, vocabs)
    valid_set = reader.load(args.valid_file,
                            vocabs,
                            distribute=False,
                            shuffle=False)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=args.num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size)
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    if args.mlm:
        mask_from = vocabs
        vocab_size = len(mask_from)
        # BUGFIX: dict.get() returns None (not -1) when the key is missing,
        # so the sentinel check below could never trigger.  Supply -1 as the
        # default so a vocab without "[MASK]" is caught here instead of
        # failing later inside on_demand_mlm_masking.
        mask_value = mask_from.get("[MASK]", -1)
        if mask_value == -1:
            logger.error(
                "We could not find a suitable masking token in the vocab")
            return

    # A single rpr_k <= 0 disables relative attention entirely.
    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        rpr_k = None
    elif len(args.rpr_k) == 1:
        rpr_k = args.rpr_k[0]
    else:
        rpr_k = args.rpr_k

    TLM = TransformerMaskedLanguageModel if args.mlm else TransformerLanguageModel
    model = TLM.create(embeddings,
                       hsz=args.d_model,
                       d_ff=args.d_ff,
                       tie_weights=True,
                       dropout=args.dropout,
                       gpu=False,
                       num_heads=args.num_heads,
                       layers=args.num_layers,
                       rpr_k=rpr_k,
                       d_k=args.d_k,
                       ffn_pdrop=args.ffn_pdrop,
                       windowed_ra=args.windowed_ra,
                       rpr_value_on=args.rpr_value_on,
                       layer_drop=args.layer_drop,
                       src_keys=['x'],
                       tgt_key='x')

    model.to(args.device)
    loss_function = model.create_loss()
    loss_function.to(args.device)

    logger.info("Loaded model and loss")

    steps_per_epoch = len(train_loader) // num_gpus
    valid_steps = len(valid_loader)
    # ROBUSTNESS: if steps_per_epoch < saves_per_epoch the floor division
    # yields 0 and the `% update_on` in the loop raises ZeroDivisionError.
    update_on = max(1, steps_per_epoch // args.saves_per_epoch)
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch per GPU: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )
    lr_decay = get_lr_decay(args.lr_scheduler,
                            args.lr,
                            steps_per_epoch,
                            args.epochs,
                            logger,
                            decay_steps=args.lr_decay_steps,
                            decay_rate=args.lr_decay_rate,
                            alpha=args.lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=args.lr)

    global_step = 0
    start_epoch = 0
    if args.restart_from:

        # npz checkpoints hold converted weights; others are torch state dicts.
        if args.restart_from.endswith('npz'):
            load_tlm_npz(model, args.restart_from)
        else:
            model.load_state_dict(torch.load(args.restart_from))
        # Checkpoint names end in ...-{step|epoch}-{num}.{ext}; recover the
        # tick type and count from the filename unless overridden by CLI.
        vec = args.restart_from.split("-")

        if args.restart_tt:
            tick_type = args.restart_tt
        else:
            tick_type = vec[-2]
        step_num = int(vec[-1].split(".")[0])
        if tick_type == 'epoch':
            start_epoch = step_num
            global_step = start_epoch * steps_per_epoch

        elif tick_type == 'step':
            start_epoch = step_num // steps_per_epoch
            global_step = step_num
        else:
            logger.warning(
                f"The previous tick was {step_num} but command-line specifies to ignore, setting to 0"
            )

        logger.info(
            "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d",
            args.restart_from, global_step, start_epoch + 1)

    optimizer = OptimizerManager(model,
                                 global_step,
                                 optim=args.optim,
                                 lr=args.lr,
                                 lr_function=lr_sched,
                                 weight_decay=args.weight_decay)
    logger.info("Model has {:,} parameters".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # Prepare model for distributed training if needed
    if args.distributed:
        # This program assume pure data parallelism, each model is on a single gpu
        # If we wanted to support model and data parallelism we would need to update
        # the selection of gpus based on rank, it would need to select multiple ids
        # based on rank, here we select only a single gpu and use it for input and
        # output.
        model = DistributedDataParallel(model,
                                        device_ids=[args.device],
                                        output_device=args.device)
        logger.info("Model located on %s", args.device)

    model_base = os.path.join(args.basedir, 'checkpoint')
    steps = global_step

    timer = Timer()
    for epoch in range(start_epoch, args.epochs):
        avg_loss = Average('average_train_loss')
        metrics = {}
        optimizer.zero_grad()
        timer.start()
        model.train()
        train_itr = iter(train_loader)
        for i in range(steps_per_epoch):
            batch = next(train_itr)
            steps += 1
            x, y = batch
            inputs = x.to(args.device)
            labels = y.to(args.device)
            if do_on_demand_masking:
                inputs, labels, _ = on_demand_mlm_masking(
                    inputs, labels, mask_value, vocab_size)
            inputs = {'x': inputs}

            # The model works time-major, so flip [B, T] -> [T, B].
            labels = labels.transpose(0, 1).contiguous()
            logits = model(inputs, None)[0].transpose(0, 1).contiguous()
            if args.mlm:
                loss = loss_function(logits, labels)
            else:
                # Causal LM: predict token t+1 from the prefix ending at t.
                shift_logits = logits[:-1]
                shift_labels = labels[1:]
                loss = loss_function(shift_logits, shift_labels)
            loss.backward()
            avg_loss.update(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % report_on == 0:
                # Use the module logger consistently (was root `logging.info`).
                logger.info(avg_loss)

            # Periodic intra-epoch checkpoints, lead worker only.
            if (i + 1) % update_on == 0 and args.local_rank < 1:
                elapsed = timer.elapsed(True)
                logger.info('elapsed time this epoch %d min', elapsed)
                logger.info('elapsed step time %f steps/min', i / elapsed)
                logger.info('LR: %f', optimizer.current_lr)
                save_checkpoint(model, model_base, steps, tick_type='step')

        # How much time elapsed in minutes
        elapsed = timer.elapsed(True)
        train_token_loss = avg_loss.avg
        # This is the average training token-level loss across all machines
        # This is the token-level training perplexity
        train_token_ppl = math.exp(train_token_loss)
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_loss'] = train_token_loss
        metrics['train_ppl'] = train_token_ppl
        if args.local_rank < 1:
            # Validation runs on the lead worker only, on the full valid set.
            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            model.eval()
            valid_itr = iter(valid_loader)
            for j in range(valid_steps):
                batch = next(valid_itr)
                with torch.no_grad():
                    x, y = batch
                    inputs = x.to(args.device)
                    labels = y.to(args.device)

                    if do_on_demand_masking:
                        inputs, labels, _ = on_demand_mlm_masking(
                            inputs, labels, mask_value, vocab_size)
                    inputs = {'x': inputs}
                    labels = labels.transpose(0, 1).contiguous()
                    logits = model(inputs, None)[0].transpose(0,
                                                              1).contiguous()
                    if args.mlm:
                        loss = loss_function(logits, labels)
                    else:
                        shift_logits = logits[:-1]
                        shift_labels = labels[1:]
                        loss = loss_function(shift_logits, shift_labels)
                    avg_valid_loss.update(loss.item())

            valid_token_loss = avg_valid_loss.avg
            valid_token_ppl = math.exp(valid_token_loss)

            metrics['valid_elapsed_min'] = timer.elapsed(True)
            metrics['average_valid_loss'] = valid_token_loss
            metrics['average_valid_word_ppl'] = valid_token_ppl
            logger.info(metrics)
            save_checkpoint(model, model_base, epoch, save_npz=True)
# ---- Example #18 ----
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_file",
                        type=str,
                        help='Optional file path to use for train file')
    parser.add_argument("--valid_file",
                        type=str,
                        help='Optional file path to use for valid file')
    parser.add_argument("--preprocessed",
                        type=str2bool,
                        default=True,
                        help="Has the data already been preprocessed?")

    parser.add_argument("--gen_d_model",
                        type=int,
                        default=256,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--gen_d_ff",
                        type=int,
                        default=1024,
                        help="FFN dimension")
    parser.add_argument(
        "--gen_d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--gen_num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--gen_num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--gen_dropout",
                        type=float,
                        default=0.1,
                        help="Dropout")
    parser.add_argument(
        '--gen_rpr_k',
        help=
        'Relative attention positional sizes pass 0 if you dont want relative attention',
        type=int,
        default=[8],
        nargs='+')

    parser.add_argument("--discrim_d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--discrim_d_ff",
                        type=int,
                        default=2048,
                        help="FFN dimension")
    parser.add_argument(
        "--discrim_d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--discrim_num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--discrim_num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--discrim_dropout",
                        type=float,
                        default=0.1,
                        help="Dropout")
    parser.add_argument(
        '--discrim_rpr_k',
        help=
        'Relative attention positional sizes pass 0 if you dont want relative attention',
        type=int,
        default=[8],
        nargs='+')

    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument(
        "--nctx",
        type=int,
        default=256,
        help="Max context length (for both encoder and decoder)")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")
    parser.add_argument(
        "--pattern",
        default='*.json',
        help=
        "Glob pattern for files, defaults to *.json if preprocessed, *.txt otherwise"
    )
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--dataset_key",
                        default="reddit",
                        help="dataset key for basedir")
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--lr_scheduler",
                        type=str,
                        default='cosine',
                        help="The type of learning rate decay scheduler")
    parser.add_argument("--lr_decay_steps",
                        type=int,
                        help="decay steps of lr scheduler")
    parser.add_argument("--lr_decay_rate",
                        type=float,
                        help="decay rate of lr scheduler")
    parser.add_argument("--lr_alpha",
                        type=float,
                        help="parameter alpha for cosine decay scheduler")
    parser.add_argument("--optim",
                        default="adam",
                        type=str,
                        help="Optimizer to use (defaults to adam)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--gen_loss_scale",
                        type=float,
                        default=50.0,
                        help="Scaling for loss function")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.0,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart_from",
        type=str,
        help=
        "Option allows you to restart from the latest checkpoint in a directory"
    )
    parser.add_argument(
        "--restart_tt",
        type=str,
        choices=['step', 'epoch'],
        default='step',
        help="Optional param for legacy checkpoints (step|epoch)")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=100,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument("--print",
                        type=str2bool,
                        default=True,
                        help="Print some output")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed",
                        type=str2bool,
                        default=False,
                        help="Are we doing distributed training?")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help=
        "Local rank for distributed training (-1 means use the environment variables to find)"
    )

    args = parser.parse_args()

    if args.train_file and not args.valid_file:
        logger.error(
            "If you provide a train_file, you must provide a valid_file")
        return

    if not args.train_file and args.valid_file:
        logger.error(
            "If you provide a valid_file, you must also provide a train_file")
        return

    if args.basedir is None:
        args.basedir = 'gd-{}-bpe-{}'.format(args.dataset_key, os.getpid())
    logging.basicConfig(
        format="%(name)s: %(levelname)s: %(message)s",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    num_gpus = get_num_gpus_multiworker()
    args.distributed = args.distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    if args.distributed:
        args.device, args.local_rank = init_distributed(args.local_rank)

    if not args.preprocessed:
        reader_type = "lang"
        args.pattern = "*.txt"
    else:
        reader_type = "preprocessed"
    reader = MultiFileDatasetReader(args.nctx,
                                    args.subword_model_file,
                                    args.subword_vocab_file,
                                    args.pattern,
                                    reader_type=reader_type)
    #  just return the vocab from the BPE vectorizer
    vocab = reader.build_vocab([])
    gen_embed = baseline.embeddings.load_embeddings('x',
                                                    dsz=args.gen_d_model,
                                                    known_vocab=vocab['x'],
                                                    embed_type=args.embed_type)
    vocabs = gen_embed['vocab']
    index2word = revlut(vocabs)
    discrim_embed = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.discrim_d_model,
        known_vocab=vocab['x'],
        embed_type=args.embed_type)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    gen_embeddings = {'x': gen_embed['embeddings']}
    discrim_embeddings = {'x': discrim_embed['embeddings']}
    logger.info("Loaded embeddings")

    train_set = reader.load(args.train_file, vocabs)
    valid_set = reader.load(args.valid_file, vocabs)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=args.num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size)
    train_steps_per_epoch = len(train_loader) // (args.batch_size * num_gpus)
    valid_steps_per_epoch = len(valid_loader) // args.batch_size
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    mask_value = vocabs.get("[MASK]", vocabs.get("<MASK>", -1))
    if mask_value == -1:
        logger.error("We could not find a suitable masking token in the vocab")
        return
    os.makedirs(args.basedir, exist_ok=True)
    vocab_size = len(vocabs)

    if len(args.gen_rpr_k) == 0 or args.gen_rpr_k[0] < 1:
        gen_rpr_k = None
    elif len(args.gen_rpr_k) == 1:
        gen_rpr_k = args.gen_rpr_k[0]
    else:
        gen_rpr_k = args.gen_rpr_k

    if len(args.gen_rpr_k) == 0 or args.discrim_rpr_k[0] < 1:
        discrim_rpr_k = None
    elif len(args.discrim_rpr_k) == 1:
        discrim_rpr_k = args.discrim_rpr_k[0]
    else:
        discrim_rpr_k = args.discrim_rpr_k

    gen_model = TransformerMaskedLanguageModel.create(
        gen_embeddings,
        hsz=args.gen_d_model,
        d_ff=args.gen_d_ff,
        tie_weights=True,
        dropout=args.gen_dropout,
        num_heads=args.gen_num_heads,
        layers=args.gen_num_layers,
        rpr_k=gen_rpr_k,
        d_k=args.gen_d_k,
        src_keys=['x'],
        tgt_key='x')
    discrim_model = TransformerDiscriminator(discrim_embeddings,
                                             d_model=args.discrim_d_model,
                                             d_ff=args.discrim_d_ff,
                                             dropout=args.discrim_dropout,
                                             num_heads=args.discrim_num_heads,
                                             layers=args.discrim_num_layers,
                                             activation='gelu',
                                             layer_norm_eps=1.0e-12,
                                             rpr_k=discrim_rpr_k,
                                             d_k=args.discrim_d_k)
    gen_model.to(args.device)
    gen_loss_fn = gen_model.create_loss()

    discrim_model.to(args.device)
    discrim_loss_fn = discrim_model.create_loss()
    logger.info("Loaded model and loss")

    update_on = train_steps_per_epoch // args.saves_per_epoch
    report_on = update_on // 10
    logger.info(
        f"Steps per epoch per GPU: {train_steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )
    lr_decay = get_lr_decay(args.lr_scheduler,
                            args.lr,
                            train_steps_per_epoch,
                            args.epochs,
                            logger,
                            decay_steps=args.lr_decay_steps,
                            decay_rate=args.lr_decay_rate,
                            alpha=args.lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=args.lr)

    global_step = 0
    start_epoch = 0
    if args.restart_from:
        if not os.path.isdir(args.restart_from):
            raise Exception(
                f"Cannot restart from {args.restart_from}, directory not found"
            )
        tick_type = args.restart_tt
        discrim_latest, step_num = find_latest_checkpoint(
            args.restart_from, wildcard=f'checkpoint-discrim-{tick_type}')
        gen_latest, _ = find_latest_checkpoint(
            args.restart_from, wildcard=f'checkpoint-gen-{tick_type}')
        discrim_model.load_state_dict(torch.load(discrim_latest))
        gen_model.load_state_dict(torch.load(gen_latest))
        if tick_type == 'step':
            start_epoch = step_num // train_steps_per_epoch
            global_step = step_num
        else:
            start_epoch = step_num
            global_step = train_steps_per_epoch * start_epoch

    parameters = list(discrim_model.parameters()) + list(
        gen_model.parameters())
    optz = OptimizerManager(parameters,
                            global_step,
                            optim=args.optim,
                            lr=args.lr,
                            lr_function=lr_sched,
                            weight_decay=args.weight_decay)
    logger.info("Generator has {:,} parameters".format(
        sum(p.numel() for p in gen_model.parameters() if p.requires_grad)))
    logger.info("Discriminator has {:,} parameters".format(
        sum(p.numel() for p in discrim_model.parameters() if p.requires_grad)))
    # Prepare model for distributed training if needed
    if args.distributed:
        # This program assume pure data parallelism, each model is on a single gpu
        # If we wanted to support model and data parallelism we would need to update
        # the selection of gpus based on rank, it would need to select multiple ids
        # based on rank, here we select only a single gpu and use it for input and
        # output.
        gen_model = DistributedDataParallel(gen_model,
                                            device_ids=[args.device],
                                            output_device=args.device)
        discrim_model = DistributedDataParallel(discrim_model,
                                                device_ids=[args.device],
                                                output_device=args.device)
        logger.info("Model located on %s", args.device)

    # This is the training loop
    steps = global_step
    model_base = os.path.join(args.basedir, 'checkpoint')
    discrim_base = f'{model_base}-discrim'
    gen_base = f'{model_base}-gen'
    do_on_demand_masking = not args.preprocessed
    if do_on_demand_masking:
        logger.info(f"On-demand masking is turned on")

    timer = Timer()

    for epoch in range(start_epoch, args.epochs):
        gen_model.train()
        discrim_model.train()
        avg_gen_loss = Average('average_train_gen_loss')
        avg_discrim_loss = Average('average_train_discrim_loss')
        avg_discrim_acc = Average('average_train_discrim_acc')
        avg_train_loss = Average('average5_train_loss')
        metrics = {}
        optz.zero_grad()
        timer.start()
        print(f'Starting epoch {epoch + 1}')
        train_iter = iter(train_loader)
        valid_iter = iter(valid_loader)

        for i in range(train_steps_per_epoch):
            steps += 1
            x, y = next(train_iter)
            do_report = (i + 1) % report_on == 0 and args.print
            gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
                x, y, args.device, gen_model, gen_loss_fn, discrim_model,
                discrim_loss_fn, mask_value, vocab_size, index2word, do_report,
                do_on_demand_masking)
            avg_gen_loss.update(gen_loss_step.item())
            total_loss_step = gen_loss_step + args.gen_loss_scale * discrim_loss_step
            total_loss_step.backward()
            avg_discrim_loss.update(discrim_loss_step.item())
            avg_train_loss.update(total_loss_step.item())
            avg_discrim_acc.update(acc)
            torch.nn.utils.clip_grad_norm_(parameters, args.clip)
            optz.step()
            optz.zero_grad()
            if (i + 1) % report_on == 0:
                logging.info('Loss g=%f, d=%f total=%f, Per token acc=%f',
                             avg_gen_loss.avg, avg_discrim_loss.avg,
                             avg_train_loss.avg, avg_discrim_acc.avg)

            if (i + 1) % update_on == 0 and args.local_rank < 1:
                elapsed = timer.elapsed(True)
                logging.info('elapsed time this epoch %d min', elapsed)
                logging.info('elapsed step time %f steps/min', i / elapsed)
                logging.info('LR: %f', optz.current_lr)
                save_checkpoint(gen_model, gen_base, steps, tick_type='step')
                save_checkpoint(discrim_model,
                                discrim_base,
                                steps,
                                tick_type='step')

        # How much time elapsed in minutes
        elapsed = timer.elapsed(True)
        # This is the average training token-level loss across all machines
        # This is the token-level training perplexity
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_gen_loss'] = avg_gen_loss.avg
        metrics['average_train_discrim_loss'] = avg_discrim_loss.avg
        metrics[
            'average_train_discrim_per_token_accuracy'] = avg_discrim_acc.avg
        metrics['average_train_loss'] = avg_train_loss.avg

        if args.local_rank < 1:
            avg_valid_gen_loss = Average('average_valid_gen_loss')
            avg_valid_discrim_loss = Average('average_valid_discrim_loss')
            avg_valid_discrim_acc = Average('average_valid_discrim_acc')
            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            gen_model.eval()
            discrim_model.eval()
            for i in range(valid_steps_per_epoch):
                with torch.no_grad():
                    x, y = next(valid_iter)
                    do_report = (i + 1) % report_on == 0 and args.print
                    gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
                        x, y, args.device, gen_model, gen_loss_fn,
                        discrim_model, discrim_loss_fn, mask_value, vocab_size,
                        index2word, do_report, do_on_demand_masking)
                    avg_valid_gen_loss.update(gen_loss_step.item())
                    avg_valid_discrim_acc.update(acc)
                    avg_valid_discrim_loss.update(discrim_loss_step.item())
                    total_loss_step = gen_loss_step + args.gen_loss_scale * discrim_loss_step
                    avg_valid_loss.update(total_loss_step.item())
            elapsed = timer.elapsed(True)
            metrics['valid_elapsed_min'] = elapsed
            metrics['average_valid_gen_loss'] = avg_valid_gen_loss.avg
            metrics['average_valid_discrim_loss'] = avg_valid_discrim_loss.avg
            metrics[
                'average_valid_discrim_per_token_accuracy'] = avg_valid_discrim_acc.avg
            metrics['average_valid_loss'] = avg_valid_loss.avg
            logger.info(metrics)
            save_checkpoint(discrim_model,
                            discrim_base,
                            epoch,
                            tick_type='epoch',
                            save_npz=True)
            save_checkpoint(gen_model,
                            gen_base,
                            epoch,
                            tick_type='epoch',
                            save_npz=True)
示例#19
0
def update_cache(key, data_download_cache):
    """Remove an entry from the data-download cache manifest.

    Loads the cache manifest JSON from ``data_download_cache``, deletes
    ``key`` if it is tracked, and writes the manifest back.  A no-op when
    ``key`` is not present in the manifest.

    :param key: cache key to invalidate (a URL or file location per the
        other cache helpers in this file)
    :param data_download_cache: directory that holds the cache manifest
    """
    # Hoist the manifest path: it is used for both the read and the write
    # (matches the `dcache_path` convention used by the download helpers).
    dcache_path = os.path.join(data_download_cache, DATA_CACHE_CONF)
    dcache = read_json(dcache_path)
    if key not in dcache:
        return
    del dcache[key]
    write_json(dcache, dcache_path)
示例#20
0
def main():
    """Train a TransformerTagger with TensorFlow under a distribution strategy.

    Parses CLI args, builds BPE-vectorized distributed datasets from sharded
    train/valid directories, constructs the model, and runs a train/validation
    loop with linear-warmup + cosine-decay LR scheduling, periodic
    checkpointing, and optional NPZ export and TensorBoard logging.
    """
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_dir",
                        type=str,
                        required=True,
                        help='Training directory')
    parser.add_argument("--valid_dir",
                        type=str,
                        required=True,
                        help='Validation directory')
    parser.add_argument(
        "--train_md",
        type=str,
        help="Training metadata YAML, defaults to `{train_dir}/md.yml`")
    parser.add_argument(
        "--valid_md",
        type=str,
        help="Validation metadata YAML, defaults to `{valid_dir}/md.yml`")
    parser.add_argument("--label_file",
                        type=str,
                        help="JSON file mapping labels to integers",
                        default="labels.json")
    parser.add_argument("--dataset_key",
                        default="tlm",
                        help="dataset key for basedir")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument("--distribute",
                        type=str,
                        default="mirror",
                        choices=["mirror", "tpu", "nccl"])
    parser.add_argument("--tpu_ep",
                        type=str,
                        help="The TPU endpoint if using `distribute=tpu`")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--file_type",
                        default='tfrecord',
                        choices=['json', 'tfrecord'],
                        help="Glob pattern for data")
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="The BPE model file",
                        required=True)
    parser.add_argument("--subword_vocab_file",
                        type=str,
                        help="The BPE subword vocab",
                        required=True)
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--ffn_pdrop",
                        type=float,
                        default=0.0,
                        help="Dropout in the dense stack")
    parser.add_argument("--layer_drop",
                        type=float,
                        default=0.0,
                        help="LayerDrop to apply")
    parser.add_argument("--optim",
                        default="adamw",
                        type=str,
                        help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=1.0e-2,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart",
        type=str2bool,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=10,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument(
        '--rpr_k',
        help=
        'Relative attention positional sizes pass 0 if you dont want relative attention',
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument(
        '--rpr_value_on',
        type=str2bool,
        default=True,
        help=
        "In relative attention, whether add positional correction to values in addition to the "
        "correction to attention matrix")
    parser.add_argument('--windowed_ra',
                        type=str2bool,
                        default=False,
                        help="whether prevent attention beyond rpr_k")
    parser.add_argument("--strategy",
                        help="Training strategy, defaults to `mirror`",
                        choices=["mirror"])
    parser.add_argument("--npz",
                        help="Should we write out NPZ files?",
                        type=str2bool,
                        default=False)
    parser.add_argument("--tb",
                        help="Turn on tensorboard?",
                        type=str2bool,
                        default=False)
    parser.add_argument(
        "--convert_only",
        help="Should we just convert this file to NPZ and exit?",
        type=str2bool,
        default=False)
    args = parser.parse_args()
    SET_TRAIN_FLAG(True)

    # convert_only exports the restored checkpoint as NPZ after one step,
    # so it implies --restart (we must reload the latest checkpoint first).
    if args.convert_only:
        args.restart = True

    if args.basedir is None:
        args.basedir = f'lm-{args.dataset_key}-bpe-{os.getpid()}'
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Writing results to {args.basedir}")

    if args.tb:
        logdir = f"logs/scalars/{os.getpid()}"
        file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        file_writer.set_as_default()
        logger.info(f"Set up tensorboard logdir {logdir}")

    # Distribution strategy is selected via --distribute (mirror/tpu/nccl);
    # the TPU endpoint is only consulted for the tpu case.
    strategy = create_distribute_strategy(args.distribute, args.tpu_ep)
    num_replicas = strategy.num_replicas_in_sync
    logger.info(f"Using {num_replicas} replicas in this job.")
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file,
                                 vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx)
    vocab = {'x': vectorizer.vocab}
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    # Sample counts come from the metadata YAML next to each data dir
    # (overridable via --train_md / --valid_md); they drive steps-per-epoch.
    train_md = args.train_md if args.train_md else os.path.join(
        args.train_dir, 'md.yml')
    num_train_samples = get_num_samples(train_md)
    valid_md = args.valid_md if args.valid_md else os.path.join(
        args.valid_dir, 'md.yml')
    num_valid_samples = get_num_samples(valid_md)
    labels = read_json_tf(args.label_file)
    num_labels = len(labels)

    def dataset_train_fn(input_context):
        # Batch at the per-replica size and read only this pipeline's shard.
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = get_dataset(args.train_dir, args.file_type,
                         args.num_train_workers).batch(base_batchsz)
        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    train_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_train_fn)

    def dataset_test_fn(input_context):
        # Same as the train pipeline but with shuffling disabled.
        global_batchsz = args.batch_size
        base_batchsz = input_context.get_per_replica_batch_size(global_batchsz)
        ds = get_dataset(args.valid_dir,
                         args.file_type,
                         args.num_train_workers,
                         shuffle=False).batch(base_batchsz)

        return ds.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)

    valid_loader = strategy.experimental_distribute_datasets_from_function(
        dataset_test_fn)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)
    # Normalize --rpr_k: empty or non-positive -> disabled (None),
    # a single value -> scalar, otherwise keep the per-layer list.
    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        args.rpr_k = None
    elif len(args.rpr_k) == 1:
        args.rpr_k = args.rpr_k[0]

    model = TransformerTagger(num_labels, embeddings, **vars(args))

    logger.info("Loaded model and loss")

    steps_per_epoch = num_train_samples // args.batch_size
    steps_per_valid_epoch = num_valid_samples // args.batch_size
    update_on = steps_per_epoch // args.saves_per_epoch
    # Report roughly 10x per checkpoint interval; max() keeps report_on >= 1
    # even when update_on is tiny.
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )

    # Linear warmup composed with cosine decay over the full training run.
    lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch * args.epochs,
                                              lr=args.lr)
    linear_warmup = WarmupLinearSchedulerTensorFlow(args.warmup_steps,
                                                    lr=args.lr)
    lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay)
    optimizer = EagerOptimizer(loss_function,
                               optim=args.optim,
                               lr_function=lr_sched,
                               weight_decay=args.weight_decay,
                               clip=args.clip,
                               lr=args.lr)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer.optimizer,
                                     model=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    directory=args.basedir,
                                                    max_to_keep=5)

    start_epoch = 0
    if args.restart:
        # The global step gets automatically updated here
        # so we dont have to worry about our LR regimen
        checkpoint.restore(checkpoint_manager.latest_checkpoint)
        current_step = optimizer.global_step
        start_epoch = current_step // steps_per_epoch

    def _replicated_train_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = optimizer.update(model, {'x': x}, y, num_replicas)
        return per_replica_loss

    @tf.function
    def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs: a (x, y) batch pair distributed across replicas
        :return: the summed per-replica training loss
        """
        per_replica_loss = strategy.run(_replicated_train_step,
                                        args=(inputs, ))
        return strategy.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_loss,
                               axis=None)

    def _replicated_test_step(inputs):
        """This runs on a single replica"""
        x, y = inputs
        per_replica_loss = loss_function(model, {'x': x}, y) / num_replicas
        return per_replica_loss

    @tf.function
    def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Runs across multiple replicas and aggregates the results.

        :param inputs: a (x, y) batch pair distributed across replicas
        :return: the summed per-replica validation loss
        """
        per_replica_loss = strategy.run(_replicated_test_step, args=(inputs, ))
        return strategy.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_loss,
                               axis=None)

    timer = Timer()
    with strategy.scope():

        for epoch in range(start_epoch, args.epochs):
            SET_TRAIN_FLAG(True)
            logger.info('Starting epoch %d', epoch + 1)
            avg_loss = Average('average_train_loss')
            metrics = {}
            timer.start()
            train_iter = iter(train_loader)
            for i in range(steps_per_epoch):

                try:
                    loss = _distributed_train_step(next(train_iter))
                    avg_loss.update(loss.numpy().item())
                    tf.summary.scalar("train_loss",
                                      data=loss,
                                      step=optimizer.global_step)
                # Best-effort: a failed batch is logged and skipped rather
                # than aborting the whole epoch.
                except Exception as e:
                    logger.error(
                        f"Exception at training step {i+1}/{steps_per_epoch}. Skipping"
                    )
                    pass
                if args.convert_only:
                    # Export the (restored) weights as NPZ and bail out.
                    logger.warning(
                        "Convert only flag specified.  Stopping after one step"
                    )
                    steps = optimizer.global_step.numpy()
                    npz_checkpoint = os.path.join(
                        args.basedir, f'checkpoint-step-{steps}.npz')
                    save_tlm_output_npz(model, npz_checkpoint)
                    return

                # Reporting and checkpointing key off the optimizer's global
                # step, not the local epoch counter i.
                steps = optimizer.global_step.numpy()
                if (steps + 1) % report_on == 0:
                    logger.info(avg_loss)
                if (steps + 1) % update_on == 0:
                    elapsed = timer.elapsed(True)
                    logger.info('elapsed time this epoch %d min', elapsed)
                    logger.info('elapsed step time %f steps/min', i / elapsed)
                    checkpoint_manager.save()
                    if args.npz:

                        npz_checkpoint = os.path.join(
                            args.basedir, f'checkpoint-step-{steps}.npz')
                        save_tlm_output_npz(model, npz_checkpoint)

            # How much time elapsed in minutes
            elapsed = timer.elapsed(True)
            train_token_loss = avg_loss.avg
            # This is the average training token-level loss across all machines
            # This is the token-level training perplexity
            train_token_ppl = math.exp(train_token_loss)
            metrics['train_elapsed_min'] = elapsed
            metrics['average_train_loss'] = train_token_loss
            metrics['train_ppl'] = train_token_ppl
            # Evaluate the schedule at the current global step to record the
            # effective learning rate for this epoch.
            metrics['lr'] = float(
                lr_sched(tf.cast(optimizer.global_step,
                                 tf.float32)).numpy().item())

            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            SET_TRAIN_FLAG(False)
            valid_iter = iter(valid_loader)
            for i in range(steps_per_valid_epoch):
                try:
                    valid_loss = _distributed_test_step(next(valid_iter))
                    tf.summary.scalar('valid_loss',
                                      data=valid_loss,
                                      step=optimizer.global_step)
                    avg_valid_loss.update(valid_loss.numpy().item())
                # Same best-effort policy as training: log and continue.
                except Exception as e:
                    logger.error(
                        f"Exception at validation step {i+1}/{steps_per_valid_epoch}. Skipping"
                    )
                    pass

            valid_token_loss = avg_valid_loss.avg
            valid_token_ppl = math.exp(valid_token_loss)

            elapsed = timer.elapsed(True)

            metrics['valid_elapsed_min'] = elapsed
            metrics['average_valid_loss'] = valid_token_loss
            metrics['average_valid_word_ppl'] = valid_token_ppl
            logger.info(json.dumps(metrics, indent=4))
示例#21
0
def main():
    """CLI entry point: convert text + annotation files into MLM fixed-width contexts.

    Splits the annotation directory into validation/training partitions,
    builds a BIOES label vocabulary seeded with the reserved ``Offsets``
    entries, then streams each partition through a format-specific file
    writer (json, tsv or tfrecord) into ``--output_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Convert text into MLM fixed width contexts')

    parser.add_argument(
        '--input_files',
        type=str,
        help='The text to classify as a string, or a path to a file with each line as an example')
    parser.add_argument(
        '--annot_files',
        type=str,
        help='The text to classify as a string, or a path to a file with each line as an example')
    parser.add_argument('--codes', help='BPE codes')
    parser.add_argument('--vocab', help='BPE vocab')
    parser.add_argument("--nctx", type=int, default=256,
                        help="Max input length")
    parser.add_argument("--fmt", type=str, default='json',
                        choices=['json', 'tsv', 'tfrecord'])
    parser.add_argument("--fields", type=str, nargs="+",
                        default=["x_str", "y_str"])
    parser.add_argument("--output_dir", type=str,
                        help="Output base name, e.g. /path/to/output/record")
    parser.add_argument("--max_file_size", type=int, default=100,
                        help="Shard size, defaults to 100MB")
    parser.add_argument("--stride", type=int,
                        help="Tokens to stride before next read, defaults to `nctx`")
    parser.add_argument("--tok_on_eol", type=str, default="<EOS>")
    parser.add_argument("--cased", type=baseline.str2bool, default=True)
    parser.add_argument("--document_vocab", type=str, default="document.vocab")
    parser.add_argument("--label_vocab", type=str, default="label.vocab")
    parser.add_argument("--valid_split", type=float, default=0.05)
    parser.add_argument("--prefix", default="<GO>")
    parser.add_argument("--suffix", default="<EOS>")
    parser.add_argument("--pg_name", choices=["tqdm", "default"],
                        default="default")

    args = parser.parse_args()

    # Deterministically partition the annotation files: the first
    # `valid_split` fraction of the directory listing becomes validation,
    # the rest training.
    # NOTE(review): int() truncates toward zero, so a small directory can
    # yield an empty validation split — confirm that is acceptable upstream.
    annot_files = list(Path(args.annot_files).iterdir())
    n_valid = int(len(annot_files) * args.valid_split)
    VALID_FILES = annot_files[:n_valid]
    TRAIN_FILES = annot_files[n_valid:]

    # Identity transform when cased; lowercasing otherwise.
    VECTORIZER = BPEVectorizer1D(
        transform_fn=baseline.lowercase if not args.cased else lambda x: x,
        model_file=args.codes,
        vocab_file=args.vocab,
        mxlen=1024)
    # Reserve two context slots for the surrounding GO/EOS tokens.
    # NOTE(review): NCTX, PREFIX, SUFFIX and DOC2WORD are never read again
    # in this function — they look like they were meant to be module-level
    # globals consumed by write_files(); as locals they are invisible to
    # callees. Kept here because the vocab lookups and file read have
    # observable failure behavior (KeyError / IOError) — verify intent.
    NCTX = args.nctx - 2
    PREFIX = (
        VECTORIZER.vocab[args.prefix],
        Offsets.GO,
    )
    SUFFIX = (
        VECTORIZER.vocab[args.suffix],
        Offsets.EOS,
    )

    DOC2WORD = read_vocab_file(args.document_vocab)
    label2word = read_vocab_file(args.label_vocab)
    # Seed the label table with the reserved Offsets entries, then assign a
    # fresh id to every BIOES-tagged variant of each label, plus "O" last.
    LABELS = {Offsets.VALUES[k]: k for k in range(Offsets.OFFSET)}
    for label in label2word.values():
        for span_tag in ["B", "I", "E", "S"]:
            LABELS[f"{span_tag}-{label}"] = len(LABELS)
    LABELS["O"] = len(LABELS)

    # Create the output tree exactly once. The original duplicated this:
    # an explicit os.path.exists/os.makedirs pair on output_dir followed by
    # makedir_if_none on the very same directory.
    valid_dir = os.path.join(args.output_dir, 'valid')
    train_dir = os.path.join(args.output_dir, 'train')
    makedir_if_none(args.output_dir)
    makedir_if_none(train_dir)
    makedir_if_none(valid_dir)
    write_json(LABELS, os.path.join(args.output_dir, 'labels.json'))

    logger.info("Converting validation files")
    fw_valid = create_file_writer(args.fmt, os.path.join(valid_dir, 'valid'),
                                  args.fields, args.max_file_size)
    write_files(VALID_FILES, args.input_files, fw_valid, valid_dir,
                args.pg_name)

    logger.info("Converting training files")
    fw_train = create_file_writer(args.fmt, os.path.join(train_dir, 'train'),
                                  args.fields, args.max_file_size)
    write_files(TRAIN_FILES, args.input_files, fw_train, train_dir,
                args.pg_name)