Example #1
    def __init__(self, model, **kwargs):
        super().__init__()
        if isinstance(model, dict):
            checkpoint = kwargs.get('checkpoint')
            if checkpoint:
                model['checkpoint'] = checkpoint
            model = create_model_for('tagger', **model)
        self.grad_accum = int(kwargs.get('grad_accum', 1))
        self.gpus = int(kwargs.get('gpus', 1))
        # By default, support IOB1/IOB2
        self.span_type = kwargs.get('span_type', 'iob')
        self.verbose = kwargs.get('verbose', False)

        logger.info('Setting span type %s', self.span_type)
        self.model = model
        self.idx2label = revlut(self.model.labels["tags"])
        self.idx2classlabel = revlut(self.model.labels["class_labels"])
        self.clip = float(kwargs.get('clip', 5))
        self.optimizer = OptimizerManager(self.model, **kwargs)
        if self.gpus > 1:
            logger.info(
                "Trainer for PyTorch tagger currently doesnt support multiple GPUs.  Setting to 1"
            )
            self.gpus = 1
        if self.gpus > 0 and self.model.gpu:
            self.model = model.cuda()
        else:
            logger.warning("Requested training on CPU.  This will be slow.")

        self.nsteps = kwargs.get('nsteps', six.MAXSIZE)
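Every example on this page uses `revlut` to invert a `{token: index}` vocabulary into an `{index: token}` lookup table. A minimal sketch of that behavior, assuming plain dicts on both sides (the real helper ships with the baseline library):

def revlut(lut):
    """Invert a {token: index} dict into an {index: token} dict."""
    return {index: token for token, index in lut.items()}

vocab = {"<PAD>": 0, "hello": 1, "world": 2}
index2word = revlut(vocab)
assert index2word[2] == "world"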
Example #2
def generate_text(model, start_string, temperature=1.0, num_generate=20):
    input_eval = np.array(
        [vocabs["word"].get(s) for s in start_string.split()],
        dtype=np.int32).reshape(1, -1)
    rlut = revlut(vocabs["word"])
    # list to accumulate the generated tokens, seeded with the start string
    text_generated = [start_string]

    h = None
    for i in range(num_generate):
        predictions, h = model({"word": input_eval, "h": h})
        # apply temperature and renormalize over the vocabulary
        predictions = tf.nn.softmax(predictions / temperature, axis=-1)
        # remove the batch dimension
        predictions = tf.squeeze(predictions, 0)

        # sample the next word; tf.random.categorical expects unnormalized
        # log-probabilities, so take the log of the softmaxed distribution
        predicted_id = tf.random.categorical(tf.math.log(predictions),
                                             num_samples=1)[-1, 0].numpy()
        # We pass the predicted word as the next input to the model
        # along with the previous hidden state
        input_eval = tf.expand_dims([predicted_id], 0)

        text_generated.append(rlut[predicted_id])

    return text_generated
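In the loop above, dividing the logits by `temperature` before the softmax controls how adventurous sampling is: values below 1.0 sharpen the distribution toward the argmax, values above 1.0 flatten it. A standalone illustration with toy logits (no model required):

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1]])  # one position, vocabulary of 3

for temperature in (0.5, 1.0, 2.0):
    # lower temperature -> sharper distribution; higher -> flatter
    probs = tf.nn.softmax(logits / temperature, axis=-1)
    print(temperature, probs.numpy().round(3))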
Example #3
    def __init__(self, model, span_type, verbose):
        """Construct from an existing model

        :param model: A model
        :param span_type: (`str`) The span type
        :param verbose: (`bool`) Be verbose?
        """
        self.model = model

        self.idx2label = revlut(model.labels["tags"])
        self.idx2classlabel = revlut(model.labels["class_labels"])
        self.cm = None

        self.span_type = span_type
        if verbose:
            print('Setting span type {}'.format(self.span_type))
        self.verbose = verbose
Example #4
def run():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--checkpoint", type=str, help='Checkpoint name or directory to load')
    parser.add_argument("--sample", type=str2bool, help='Sample from the decoder?  Defaults to `false`', default=0)
    parser.add_argument("--query", type=str, default='hello , <unk> are you today ?')
    parser.add_argument("--dataset_cache", type=str, default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--d_k", type=int, default=None, help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--nctx", type=int, default=128, help="Max context length (for both encoder and decoder)")
    parser.add_argument("--embed_type", type=str, default='default',
                        help="register label of the embeddings, so far support positional or learned-positional")
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--use_cls", type=str2bool, default=False)
    parser.add_argument('--end_token', default='<EOU>')
    parser.add_argument("--activation", type=str, default='gelu')
    parser.add_argument('--rpr_k', help="Relative attention positional sizes; pass 0 if you don't want relative attention",
                        type=int, default=[8], nargs='+')
    parser.add_argument("--y_only", type=str2bool, default=False)
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    args = parser.parse_args()

    if torch.cuda.device_count() == 1:
        torch.cuda.set_device(0)
        args.device = torch.device("cuda", 0)


    if os.path.isdir(args.checkpoint):
        checkpoint, _ = find_latest_checkpoint(args.checkpoint)
        logger.warning("Found latest checkpoint %s", checkpoint)
    else:
        checkpoint = args.checkpoint

    cls = None if not args.use_cls else '[CLS]'
    end = args.end_token
    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file, mxlen=args.nctx, emit_begin_tok=cls, emit_end_tok=end)
    vocab = vectorizer.vocab.copy()
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, counts=False, known_vocab=vocab, embed_type=args.embed_type, preserve_vocab_indices=True)
    embeddings = preproc_data['embeddings']
    vocab = preproc_data['vocab']
    model = create_model(embeddings, d_model=args.d_model, d_ff=args.d_ff, num_heads=args.num_heads, num_layers=args.num_layers,
                         rpr_k=args.rpr_k, d_k=args.d_k, checkpoint_name=checkpoint, activation=args.activation)
    model.to(args.device)

    index2word = revlut(vocab)
    print('[Query]', args.query)
    bpe_out = decode_sentence(model, vectorizer, args.query.split(), vocab, index2word, args.device, sample=args.sample, y_only=args.y_only)

    print('[Response]', ' '.join(bpe_out))
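Note how `--checkpoint` accepts either a concrete file or a directory; for a directory, `find_latest_checkpoint` resolves the newest checkpoint inside it. A stand-in sketch of that resolution, assuming checkpoints embed their step number in the filename (the helper name and naming scheme here are illustrative assumptions, not the library's API):

import glob
import os
import re

def latest_checkpoint(directory, pattern="checkpoint*"):
    """Pick the checkpoint file with the largest embedded step number."""
    def step_of(path):
        nums = re.findall(r"\d+", os.path.basename(path))
        return int(nums[-1]) if nums else -1
    candidates = glob.glob(os.path.join(directory, pattern))
    return max(candidates, key=step_of) if candidates else None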
Example #5
    def __init__(self,
                 filenames,
                 known_vocab,
                 counts=True,
                 unif_weight=None,
                 normalize=False,
                 **kwargs):
        uw = 0.0 if unif_weight is None else unif_weight

        self.vocab = dict()
        for i, name in enumerate(Offsets.VALUES):
            self.vocab[name] = i
        self.vsz = Offsets.OFFSET

        if counts:
            for name in Offsets.VALUES:
                known_vocab.pop(name, 0)
            attested = [v for v, cnt in known_vocab.items() if cnt > 0]
            for k, v in enumerate(attested):
                self.vocab[v] = k + Offsets.OFFSET
                self.vsz += 1
        else:
            self.vocab = known_vocab
            self.vsz = max(self.vocab.values()) + 1

        index2word = revlut(self.vocab)
        # self.vocab maps word -> index; index2word is the reverse lookup
        embeddings = []

        for file in filenames:
            embeddings.append(PretrainedEmbeddingsModel(file, known_vocab))

        self.dsz = sum(embedding.dsz for embedding in embeddings)
        self.weights = np.random.uniform(-uw, uw, (self.vsz, self.dsz)).astype(
            np.float32)

        for i, w in index2word.items():
            # concatenate this word's vector from each pretrained file
            e = [emb.lookup(w, False) for emb in embeddings]
            self.weights[i] = np.concatenate(e)
        if normalize:
            self.weights = norm_weights(self.weights)
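The final loop builds each weight row by concatenating that word's vector from every pretrained file, so the combined width is the sum of the individual `dsz` values. A toy illustration of the row-wise concatenation (hypothetical 2- and 3-dimensional lookups standing in for `PretrainedEmbeddingsModel`):

import numpy as np

emb_a = {"hello": np.array([0.1, 0.2], dtype=np.float32)}       # dsz = 2
emb_b = {"hello": np.array([0.3, 0.4, 0.5], dtype=np.float32)}  # dsz = 3

row = np.concatenate([emb_a["hello"], emb_b["hello"]])
assert row.shape == (5,)  # 2 + 3, matching dsz = sum of the parts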
Example #6
def generate_text(model, start_string, temperature=1.0, num_generate=20):
    input_eval = torch.tensor([
        vocabs["word"].get(s) for s in start_string.split()
    ]).long().view(1, -1).to(args.device)
    rlut = revlut(vocabs["word"])
    # list to accumulate the generated tokens, seeded with the start string
    text_generated = [start_string]

    h = None
    for i in range(num_generate):
        predictions, h = model({"word": input_eval, "h": h})
        # apply temperature and renormalize over the vocabulary
        predictions = torch.softmax(predictions / temperature, dim=-1)
        # remove the batch dimension
        predictions = predictions.squeeze(0)

        # sample the next word from the softmaxed distribution,
        # keeping only the draw for the final timestep
        predicted_id = torch.multinomial(predictions, num_samples=1)[-1, 0]
        # We pass the predicted word as the next input to the model
        # along with the previous hidden state
        input_eval = predicted_id.unsqueeze(0).unsqueeze(0)

        text_generated.append(rlut[predicted_id.cpu().numpy().item()])

    return text_generated
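`torch.multinomial` draws one sample per row, so with per-timestep probabilities of shape `[T, V]` the `[-1, 0]` index keeps only the draw for the final position. A standalone illustration with a toy distribution:

import torch

probs = torch.tensor([[0.1, 0.9],   # timestep 0
                      [0.8, 0.2]])  # timestep 1 (the one we keep)
draws = torch.multinomial(probs, num_samples=1)  # shape [2, 1]
next_id = draws[-1, 0]  # sample for the last timestep only
print(next_id.item())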
Example #7
def run():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--checkpoint",
                        type=str,
                        help='Checkpoint name or directory to load')
    parser.add_argument("--sample",
                        type=str2bool,
                        help='Sample from the decoder?  Defaults to `false`',
                        default=0)
    parser.add_argument("--vocab",
                        type=str,
                        help='Vocab file to load',
                        required=False)
    parser.add_argument("--input", type=str, default='hello how are you ?')
    parser.add_argument("--dataset_cache",
                        type=str,
                        default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument(
        "--d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument(
        "--nctx",
        type=int,
        default=256,
        help="Max context length (for both encoder and decoder)")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        help="register label of the embeddings; currently supports positional or learned-positional")
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--batchsz",
                        help="Size of a batch to pass at once",
                        default=4,
                        type=int)
    parser.add_argument("--beamsz",
                        help="Size of beam to use",
                        default=4,
                        type=int)
    parser.add_argument("--activation", type=str, default='relu')
    parser.add_argument(
        '--rpr_k',
        help="Relative attention positional sizes; pass 0 if you don't want relative attention",
        type=int,
        default=[8] * 8,
        nargs='+')
    #parser.add_argument("--go_token", default="<GO>")
    parser.add_argument("--end_token", default="<EOS>")
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--show_query",
                        type=str2bool,
                        default=False,
                        help="Show the original query as well")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    args = parser.parse_args()

    if torch.cuda.device_count() == 1:
        torch.cuda.set_device(0)
        args.device = torch.device("cuda", 0)

    if os.path.isdir(args.checkpoint):
        checkpoint, _ = find_latest_checkpoint(args.checkpoint)
        logger.warning("Found latest checkpoint %s", checkpoint)
    else:
        checkpoint = args.checkpoint

    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file,
                                 vocab_file=args.subword_vocab_file,
                                 mxlen=args.nctx,
                                 emit_end_tok=args.end_token)
    vocab = vectorizer.vocab
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        counts=False,
        known_vocab=vocab,
        embed_type=args.embed_type)
    embeddings = preproc_data['embeddings']
    vocab = preproc_data['vocab']
    model = create_model(embeddings,
                         d_model=args.d_model,
                         d_ff=args.d_ff,
                         num_heads=args.num_heads,
                         num_layers=args.num_layers,
                         rpr_k=args.rpr_k,
                         d_k=args.d_k,
                         checkpoint_name=checkpoint,
                         activation=args.activation,
                         device=args.device)
    model.to(args.device)

    index2word = revlut(vocab)
    wf = None
    if args.output_file:
        wf = open(args.output_file, "w")

    batches = []
    if os.path.isfile(args.input):
        with open(args.input, 'rt', encoding='utf-8') as f:
            batch = []
            for line in f:
                text = line.strip().split()
                if len(batch) == args.batchsz:
                    batches.append(batch)
                    batch = []
                batch.append(text)

            if len(batch) > 0:
                batches.append(batch)

    else:
        batch = [args.input.split()]
        batches.append(batch)

    for queries in batches:

        outputs = decode_sentences(model, vectorizer, queries, vocab,
                                   index2word, args.beamsz)

        if args.show_query:
            for query, output in zip(queries, outputs):
                print(f"[Query] {query}")
                print(f"[Response] {output}")
        elif wf:
            for query, output in zip(queries, outputs):
                wf.write(f'{output}\n')
                wf.flush()
        else:
            for query, output in zip(queries, outputs):
                print(output)
    if wf:
        wf.close()
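The file-reading loop flushes a batch whenever it reaches `batchsz` lines and keeps any remainder as a final short batch. The same chunking rule on a toy list (a sketch, not part of the script):

def chunk(items, batchsz):
    """Split items into consecutive batches of at most batchsz elements."""
    batches, batch = [], []
    for item in items:
        if len(batch) == batchsz:
            batches.append(batch)
            batch = []
        batch.append(item)
    if batch:
        batches.append(batch)
    return batches

assert chunk(list(range(5)), 2) == [[0, 1], [2, 3], [4]]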
Example #8
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_file",
                        type=str,
                        help='Optional file path to use for train file')
    parser.add_argument("--valid_file",
                        type=str,
                        help='Optional file path to use for valid file')
    parser.add_argument("--preprocessed",
                        type=str2bool,
                        default=True,
                        help="Has the data already been preprocessed?")

    parser.add_argument("--gen_d_model",
                        type=int,
                        default=256,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--gen_d_ff",
                        type=int,
                        default=1024,
                        help="FFN dimension")
    parser.add_argument(
        "--gen_d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--gen_num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--gen_num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--gen_dropout",
                        type=float,
                        default=0.1,
                        help="Dropout")
    parser.add_argument(
        '--gen_rpr_k',
        help="Relative attention positional sizes; pass 0 if you don't want relative attention",
        type=int,
        default=[8],
        nargs='+')

    parser.add_argument("--discrim_d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--discrim_d_ff",
                        type=int,
                        default=2048,
                        help="FFN dimension")
    parser.add_argument(
        "--discrim_d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--discrim_num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--discrim_num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--discrim_dropout",
                        type=float,
                        default=0.1,
                        help="Dropout")
    parser.add_argument(
        '--discrim_rpr_k',
        help="Relative attention positional sizes; pass 0 if you don't want relative attention",
        type=int,
        default=[8],
        nargs='+')

    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument(
        "--nctx",
        type=int,
        default=256,
        help="Max context length (for both encoder and decoder)")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")
    parser.add_argument(
        "--pattern",
        default='*.json',
        help=
        "Glob pattern for files, defaults to *.json if preprocessed, *.txt otherwise"
    )
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--dataset_key",
                        default="reddit",
                        help="dataset key for basedir")
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--lr_scheduler",
                        type=str,
                        default='cosine',
                        help="The type of learning rate decay scheduler")
    parser.add_argument("--lr_decay_steps",
                        type=int,
                        help="decay steps of lr scheduler")
    parser.add_argument("--lr_decay_rate",
                        type=float,
                        help="decay rate of lr scheduler")
    parser.add_argument("--lr_alpha",
                        type=float,
                        help="parameter alpha for cosine decay scheduler")
    parser.add_argument("--optim",
                        default="adam",
                        type=str,
                        help="Optimizer to use (defaults to adam)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--gen_loss_scale",
                        type=float,
                        default=50.0,
                        help="Scaling for loss function")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.0,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart_from",
        type=str,
        help=
        "Option allows you to restart from the latest checkpoint in a directory"
    )
    parser.add_argument(
        "--restart_tt",
        type=str,
        choices=['step', 'epoch'],
        default='step',
        help="Optional param for legacy checkpoints (step|epoch)")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=100,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument("--print",
                        type=str2bool,
                        default=True,
                        help="Print some output")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed",
                        type=str2bool,
                        default=False,
                        help="Are we doing distributed training?")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help=
        "Local rank for distributed training (-1 means use the environment variables to find)"
    )

    args = parser.parse_args()

    if args.train_file and not args.valid_file:
        logger.error(
            "If you provide a train_file, you must provide a valid_file")
        return

    if not args.train_file and args.valid_file:
        logger.error(
            "If you provide a valid_file, you must also provide a train_file")
        return

    if args.basedir is None:
        args.basedir = 'gd-{}-bpe-{}'.format(args.dataset_key, os.getpid())
    logging.basicConfig(
        format="%(name)s: %(levelname)s: %(message)s",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    num_gpus = get_num_gpus_multiworker()
    args.distributed = args.distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    if args.distributed:
        args.device, args.local_rank = init_distributed(args.local_rank)

    if not args.preprocessed:
        reader_type = "lang"
        args.pattern = "*.txt"
    else:
        reader_type = "preprocessed"
    reader = MultiFileDatasetReader(args.nctx,
                                    args.subword_model_file,
                                    args.subword_vocab_file,
                                    args.pattern,
                                    reader_type=reader_type)
    #  just return the vocab from the BPE vectorizer
    vocab = reader.build_vocab([])
    gen_embed = baseline.embeddings.load_embeddings('x',
                                                    dsz=args.gen_d_model,
                                                    known_vocab=vocab['x'],
                                                    embed_type=args.embed_type)
    vocabs = gen_embed['vocab']
    index2word = revlut(vocabs)
    discrim_embed = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.discrim_d_model,
        known_vocab=vocab['x'],
        embed_type=args.embed_type)

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    gen_embeddings = {'x': gen_embed['embeddings']}
    discrim_embeddings = {'x': discrim_embed['embeddings']}
    logger.info("Loaded embeddings")

    train_set = reader.load(args.train_file, vocabs)
    valid_set = reader.load(args.valid_file, vocabs)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=args.num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size)
    train_steps_per_epoch = len(train_loader) // (args.batch_size * num_gpus)
    valid_steps_per_epoch = len(valid_loader) // args.batch_size
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    mask_value = vocabs.get("[MASK]", vocabs.get("<MASK>", -1))
    if mask_value == -1:
        logger.error("We could not find a suitable masking token in the vocab")
        return
    os.makedirs(args.basedir, exist_ok=True)
    vocab_size = len(vocabs)

    if len(args.gen_rpr_k) == 0 or args.gen_rpr_k[0] < 1:
        gen_rpr_k = None
    elif len(args.gen_rpr_k) == 1:
        gen_rpr_k = args.gen_rpr_k[0]
    else:
        gen_rpr_k = args.gen_rpr_k

    if len(args.discrim_rpr_k) == 0 or args.discrim_rpr_k[0] < 1:
        discrim_rpr_k = None
    elif len(args.discrim_rpr_k) == 1:
        discrim_rpr_k = args.discrim_rpr_k[0]
    else:
        discrim_rpr_k = args.discrim_rpr_k
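The two blocks above apply the same collapsing rule to the `--gen_rpr_k` and `--discrim_rpr_k` lists: a leading 0 disables relative attention, a single value is shared by every layer, and a longer list is taken per layer. A compact restatement (a sketch, not part of the script):

def normalize_rpr_k(rpr_k):
    """Collapse a CLI rpr_k list: None disables, scalar is shared, list is per-layer."""
    if len(rpr_k) == 0 or rpr_k[0] < 1:
        return None
    if len(rpr_k) == 1:
        return rpr_k[0]
    return rpr_k

assert normalize_rpr_k([0]) is None
assert normalize_rpr_k([8]) == 8
assert normalize_rpr_k([4, 8]) == [4, 8]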

    gen_model = TransformerMaskedLanguageModel.create(
        gen_embeddings,
        hsz=args.gen_d_model,
        d_ff=args.gen_d_ff,
        tie_weights=True,
        dropout=args.gen_dropout,
        num_heads=args.gen_num_heads,
        layers=args.gen_num_layers,
        rpr_k=gen_rpr_k,
        d_k=args.gen_d_k,
        src_keys=['x'],
        tgt_key='x')
    discrim_model = TransformerDiscriminator(discrim_embeddings,
                                             d_model=args.discrim_d_model,
                                             d_ff=args.discrim_d_ff,
                                             dropout=args.discrim_dropout,
                                             num_heads=args.discrim_num_heads,
                                             layers=args.discrim_num_layers,
                                             activation='gelu',
                                             layer_norm_eps=1.0e-12,
                                             rpr_k=discrim_rpr_k,
                                             d_k=args.discrim_d_k)
    gen_model.to(args.device)
    gen_loss_fn = gen_model.create_loss()

    discrim_model.to(args.device)
    discrim_loss_fn = discrim_model.create_loss()
    logger.info("Loaded model and loss")

    # guard against zero divisors when steps_per_epoch is small
    update_on = max(1, train_steps_per_epoch // args.saves_per_epoch)
    report_on = max(1, update_on // 10)
    logger.info(
        f"Steps per epoch per GPU: {train_steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )
    lr_decay = get_lr_decay(args.lr_scheduler,
                            args.lr,
                            train_steps_per_epoch,
                            args.epochs,
                            logger,
                            decay_steps=args.lr_decay_steps,
                            decay_rate=args.lr_decay_rate,
                            alpha=args.lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=args.lr)

    global_step = 0
    start_epoch = 0
    if args.restart_from:
        if not os.path.isdir(args.restart_from):
            raise Exception(
                f"Cannot restart from {args.restart_from}, directory not found"
            )
        tick_type = args.restart_tt
        discrim_latest, step_num = find_latest_checkpoint(
            args.restart_from, wildcard=f'checkpoint-discrim-{tick_type}')
        gen_latest, _ = find_latest_checkpoint(
            args.restart_from, wildcard=f'checkpoint-gen-{tick_type}')
        discrim_model.load_state_dict(torch.load(discrim_latest))
        gen_model.load_state_dict(torch.load(gen_latest))
        if tick_type == 'step':
            start_epoch = step_num // train_steps_per_epoch
            global_step = step_num
        else:
            start_epoch = step_num
            global_step = train_steps_per_epoch * start_epoch

    parameters = list(discrim_model.parameters()) + list(
        gen_model.parameters())
    optz = OptimizerManager(parameters,
                            global_step,
                            optim=args.optim,
                            lr=args.lr,
                            lr_function=lr_sched,
                            weight_decay=args.weight_decay)
    logger.info("Generator has {:,} parameters".format(
        sum(p.numel() for p in gen_model.parameters() if p.requires_grad)))
    logger.info("Discriminator has {:,} parameters".format(
        sum(p.numel() for p in discrim_model.parameters() if p.requires_grad)))
    # Prepare model for distributed training if needed
    if args.distributed:
        # This program assume pure data parallelism, each model is on a single gpu
        # If we wanted to support model and data parallelism we would need to update
        # the selection of gpus based on rank, it would need to select multiple ids
        # based on rank, here we select only a single gpu and use it for input and
        # output.
        gen_model = DistributedDataParallel(gen_model,
                                            device_ids=[args.device],
                                            output_device=args.device)
        discrim_model = DistributedDataParallel(discrim_model,
                                                device_ids=[args.device],
                                                output_device=args.device)
        logger.info("Model located on %s", args.device)

    # This is the training loop
    steps = global_step
    model_base = os.path.join(args.basedir, 'checkpoint')
    discrim_base = f'{model_base}-discrim'
    gen_base = f'{model_base}-gen'
    do_on_demand_masking = not args.preprocessed
    if do_on_demand_masking:
        logger.info(f"On-demand masking is turned on")

    timer = Timer()

    for epoch in range(start_epoch, args.epochs):
        gen_model.train()
        discrim_model.train()
        avg_gen_loss = Average('average_train_gen_loss')
        avg_discrim_loss = Average('average_train_discrim_loss')
        avg_discrim_acc = Average('average_train_discrim_acc')
        avg_train_loss = Average('average_train_loss')
        metrics = {}
        optz.zero_grad()
        timer.start()
        print(f'Starting epoch {epoch + 1}')
        train_iter = iter(train_loader)
        valid_iter = iter(valid_loader)

        for i in range(train_steps_per_epoch):
            steps += 1
            x, y = next(train_iter)
            do_report = (i + 1) % report_on == 0 and args.print
            gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
                x, y, args.device, gen_model, gen_loss_fn, discrim_model,
                discrim_loss_fn, mask_value, vocab_size, index2word, do_report,
                do_on_demand_masking)
            avg_gen_loss.update(gen_loss_step.item())
            total_loss_step = gen_loss_step + args.gen_loss_scale * discrim_loss_step
            total_loss_step.backward()
            avg_discrim_loss.update(discrim_loss_step.item())
            avg_train_loss.update(total_loss_step.item())
            avg_discrim_acc.update(acc)
            torch.nn.utils.clip_grad_norm_(parameters, args.clip)
            optz.step()
            optz.zero_grad()
            if (i + 1) % report_on == 0:
                logging.info('Loss g=%f, d=%f total=%f, Per token acc=%f',
                             avg_gen_loss.avg, avg_discrim_loss.avg,
                             avg_train_loss.avg, avg_discrim_acc.avg)

            if (i + 1) % update_on == 0 and args.local_rank < 1:
                elapsed = timer.elapsed(True)
                logging.info('elapsed time this epoch %d min', elapsed)
                logging.info('training rate %f steps/min', i / elapsed)
                logging.info('LR: %f', optz.current_lr)
                save_checkpoint(gen_model, gen_base, steps, tick_type='step')
                save_checkpoint(discrim_model,
                                discrim_base,
                                steps,
                                tick_type='step')

        # How much time elapsed in minutes
        elapsed = timer.elapsed(True)
        # Average training metrics for this epoch (this worker's view)
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_gen_loss'] = avg_gen_loss.avg
        metrics['average_train_discrim_loss'] = avg_discrim_loss.avg
        metrics[
            'average_train_discrim_per_token_accuracy'] = avg_discrim_acc.avg
        metrics['average_train_loss'] = avg_train_loss.avg

        if args.local_rank < 1:
            avg_valid_gen_loss = Average('average_valid_gen_loss')
            avg_valid_discrim_loss = Average('average_valid_discrim_loss')
            avg_valid_discrim_acc = Average('average_valid_discrim_acc')
            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            gen_model.eval()
            discrim_model.eval()
            for i in range(valid_steps_per_epoch):
                with torch.no_grad():
                    x, y = next(valid_iter)
                    do_report = (i + 1) % report_on == 0 and args.print
                    gen_loss_step, discrim_loss_step, acc = gen_vs_discrim(
                        x, y, args.device, gen_model, gen_loss_fn,
                        discrim_model, discrim_loss_fn, mask_value, vocab_size,
                        index2word, do_report, do_on_demand_masking)
                    avg_valid_gen_loss.update(gen_loss_step.item())
                    avg_valid_discrim_acc.update(acc)
                    avg_valid_discrim_loss.update(discrim_loss_step.item())
                    total_loss_step = gen_loss_step + args.gen_loss_scale * discrim_loss_step
                    avg_valid_loss.update(total_loss_step.item())
            elapsed = timer.elapsed(True)
            metrics['valid_elapsed_min'] = elapsed
            metrics['average_valid_gen_loss'] = avg_valid_gen_loss.avg
            metrics['average_valid_discrim_loss'] = avg_valid_discrim_loss.avg
            metrics[
                'average_valid_discrim_per_token_accuracy'] = avg_valid_discrim_acc.avg
            metrics['average_valid_loss'] = avg_valid_loss.avg
            logger.info(metrics)
            save_checkpoint(discrim_model,
                            discrim_base,
                            epoch,
                            tick_type='epoch',
                            save_npz=True)
            save_checkpoint(gen_model,
                            gen_base,
                            epoch,
                            tick_type='epoch',
                            save_npz=True)
Example #9
def main():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--checkpoint",
                        type=str,
                        help='Checkpoint name or directory to load')
    parser.add_argument("--sample",
                        type=str2bool,
                        help='Sample from the decoder?  Defaults to `true`',
                        default=True)
    parser.add_argument("--query",
                        type=str,
                        default='hello , <unk> are you today ?')
    parser.add_argument("--dataset_cache",
                        type=str,
                        default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument(
        "--d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument(
        "--nctx",
        type=int,
        default=128,
        help="Max context length (for both encoder and decoder)")
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=512,
                        help="Max sequence length for LP")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        help="register label of the embeddings; currently supports positional or learned-positional")
    parser.add_argument("--subword_model_file", type=str, required=True)
    parser.add_argument("--subword_vocab_file", type=str, required=True)
    parser.add_argument("--subword_type",
                        type=str,
                        default="bpe",
                        choices=["gpt2", "bpe", "spm", "wordpiece"])
    parser.add_argument('--go_token')
    parser.add_argument('--end_token', default='<EOU>')
    parser.add_argument("--activation", type=str, default='gelu')
    parser.add_argument(
        '--rpr_k',
        help="Relative attention positional sizes; pass 0 if you don't want relative attention",
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument("--rpr_value_on",
                        help="Use different embeddings for RPV key and value",
                        type=str2bool,
                        default=False)
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--transformer_type",
                        help="What TransformerEncoder type to use")
    parser.add_argument('--temperature',
                        help='Sample temperature during generation',
                        default=1.0,
                        type=float)

    args = parser.parse_args()
    if args.sample:
        logger.info("Sampling with temperature %f", args.temperature)
    else:
        logger.info("Sampling is turned off")
    if torch.cuda.device_count() == 1:
        torch.cuda.set_device(0)
        args.device = torch.device("cuda", 0)

    if os.path.isdir(args.checkpoint):
        checkpoint, _ = find_latest_checkpoint(args.checkpoint)
        logger.warning("Found latest checkpoint %s", checkpoint)
    else:
        checkpoint = args.checkpoint

    SubwordType = get_subword_vec1d(args.subword_type)
    vectorizer = SubwordType(model_file=args.subword_model_file,
                             vocab_file=args.subword_vocab_file,
                             mxlen=args.nctx,
                             emit_begin_tok=args.go_token)
    vocab = vectorizer.vocab.copy()
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        counts=False,
        known_vocab=vocab,
        embed_type=args.embed_type,
        preserve_vocab_indices=True,
        mxlen=args.max_seq_len)
    embeddings = preproc_data['embeddings']
    vocab = preproc_data['vocab']
    model = create_model(embeddings,
                         d_model=args.d_model,
                         d_ff=args.d_ff,
                         num_heads=args.num_heads,
                         num_layers=args.num_layers,
                         rpr_k=args.rpr_k,
                         rpr_value_on=args.rpr_value_on,
                         d_k=args.d_k,
                         checkpoint_name=checkpoint,
                         activation=args.activation,
                         transformer_type=args.transformer_type)
    model.to(args.device)

    index2word = revlut(vocab)
    print('[Query]', args.query)
    bpe_out = decode_sentence(model,
                              vectorizer,
                              args.query.split(),
                              vocab,
                              index2word,
                              args.device,
                              end_token=args.end_token,
                              sample=args.sample,
                              sample_temperature=args.temperature)
    unbpe = ' '.join(bpe_out).replace('@@ ', '')
    print('[Response]', unbpe)
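The last step undoes BPE segmentation: subword pieces that continue a word carry an `@@ ` marker, so joining on spaces and deleting `@@ ` reassembles the surface words. A standalone illustration with toy subword output:

bpe_out = ["hel@@", "lo", "wor@@", "ld", "!"]
unbpe = " ".join(bpe_out).replace("@@ ", "")
assert unbpe == "hello world !"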