Example #1
    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path):
        assert os.path.isfile(file_path)
        print('Loading dataset...')
        self.examples = []
        with open(file_path, encoding="utf-8") as f:
            for idx, line in enumerate(tqdm(f.readlines())):
                if len(line) == 0 or line.isspace() or len(line.split(' ||| ')) != 2:
                    raise ValueError(f'Line {idx+1} is not in the correct format!')
                
                src, tgt = line.split(' ||| ')
                if src.rstrip() == '' or tgt.rstrip() == '':
                    raise ValueError(f'Line {idx+1} is not in the correct format!')
            
                sent_src, sent_tgt = src.strip().split(), tgt.strip().split()
                # Tokenize word by word so sub-tokens can be mapped back to words
                token_src = [tokenizer.tokenize(word) for word in sent_src]
                token_tgt = [tokenizer.tokenize(word) for word in sent_tgt]
                wid_src = [tokenizer.convert_tokens_to_ids(x) for x in token_src]
                wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_tgt]

                ids_src = tokenizer.prepare_for_model(
                    list(itertools.chain(*wid_src)), return_tensors='pt',
                    max_length=tokenizer.max_len)['input_ids']
                ids_tgt = tokenizer.prepare_for_model(
                    list(itertools.chain(*wid_tgt)), return_tensors='pt',
                    max_length=tokenizer.max_len)['input_ids']
                # A length of 2 means only the special tokens remain, i.e. an empty sentence
                if len(ids_src[0]) == 2 or len(ids_tgt[0]) == 2:
                    raise ValueError(f'Line {idx+1} is not in the correct format!')

                # Map every sub-token position to the index of the word it belongs to
                bpe2word_map_src = []
                for i, word_list in enumerate(token_src):
                    bpe2word_map_src += [i for x in word_list]
                bpe2word_map_tgt = []
                for i, word_list in enumerate(token_tgt):
                    bpe2word_map_tgt += [i for x in word_list]

                self.examples.append( (ids_src[0], ids_tgt[0], bpe2word_map_src, bpe2word_map_tgt) )
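
# Each input line must hold a whitespace-tokenized source and target sentence
# separated by ' ||| '; an empty side or a missing separator triggers the
# ValueError above. A minimal sketch of such a file (the file name and the
# sentences are made up for illustration):
#
#     with open('train.src-tgt', 'w', encoding='utf-8') as f:
#         f.write('we must act now . ||| wir muessen jetzt handeln .\n')
#         f.write('thank you . ||| danke schoen .\n')
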
def mask_tokens(inputs: torch.Tensor,
                tokenizer: PreTrainedTokenizer,
                args,
                langid_mask=None,
                lang_id=None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
    """

    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
        )
    mask_type = torch.bool if torch.__version__ >= '1.2' else torch.uint8

    labels = inputs.clone()
    # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability, which defaults to 0.15 in BERT/RoBERTa)
    probability_matrix = torch.full(labels.shape, args.mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
        for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask,
                                                 dtype=mask_type),
                                    value=0.0)
    if tokenizer._pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)

    if langid_mask is not None:
        padding_mask = langid_mask.eq(lang_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)

    masked_indices = torch.bernoulli(probability_matrix).to(mask_type)
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(
        labels.shape, 0.8)).to(mask_type) & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(
        tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = torch.bernoulli(torch.full(
        labels.shape, 0.5)).to(mask_type) & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer),
                                 labels.shape,
                                 dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels
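
A minimal usage sketch for mask_tokens, assuming the function above is in scope and a transformers version compatible with this code; the tokenizer name, the example sentence, and the mlm_probability value are illustrative assumptions, not taken from the examples.

from argparse import Namespace
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
args = Namespace(mlm_probability=0.15)  # the only field mask_tokens reads from args
batch = tokenizer.encode("we must act now .", return_tensors="pt")
inputs, labels = mask_tokens(batch, tokenizer, args)
# labels is -100 everywhere except the sampled positions, which keep their
# original ids so the MLM loss is computed only on masked tokens.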
Example #3
    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path):
        assert os.path.isfile(file_path)
        logger.info("Creating features from dataset file at %s", file_path)

        cache_fn = f'{file_path}.cache'
        if args.cache_data and os.path.isfile(cache_fn) and not args.overwrite_cache:
            logger.info("Loading cached data from %s", cache_fn)
            self.examples = torch.load(cache_fn)
        else:
            self.examples = []
            with open(file_path, encoding="utf-8") as f:
                for line in f.readlines():
                    if len(line) > 0 and not line.isspace() and len(line.split(' ||| ')) == 2:
                        try:
                            src, tgt = line.split(' ||| ')
                            if src.rstrip() == '' or tgt.rstrip() == '':
                                logger.info("Skipping instance %s", line)
                                continue
                        except ValueError:
                            logger.info("Skipping instance %s", line)
                            continue
                        sent_src, sent_tgt = src.strip().split(), tgt.strip().split()
                        token_src = [tokenizer.tokenize(word) for word in sent_src]
                        token_tgt = [tokenizer.tokenize(word) for word in sent_tgt]
                        wid_src = [tokenizer.convert_tokens_to_ids(x) for x in token_src]
                        wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_tgt]

                        ids_src = tokenizer.prepare_for_model(
                            list(itertools.chain(*wid_src)), return_tensors='pt',
                            max_length=tokenizer.max_len)['input_ids']
                        ids_tgt = tokenizer.prepare_for_model(
                            list(itertools.chain(*wid_tgt)), return_tensors='pt',
                            max_length=tokenizer.max_len)['input_ids']
                        if len(ids_src[0]) == 2 or len(ids_tgt[0]) == 2:
                            logger.info("Skipping instance %s", line)
                            continue

                        bpe2word_map_src = []
                        for i, word_list in enumerate(token_src):
                            bpe2word_map_src += [i for x in word_list]
                        bpe2word_map_tgt = []
                        for i, word_list in enumerate(token_tgt):
                            bpe2word_map_tgt += [i for x in word_list]

                        self.examples.append( (ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt) )

            if args.cache_data:
                logger.info("Saving cached data to %s", cache_fn)
                torch.save(self.examples, cache_fn)
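
This variant skips malformed lines instead of raising, and caches the processed examples next to the input file. A minimal sketch of the caching-related fields the constructor reads from args (a plain Namespace standing in for the parsed command-line arguments):

from argparse import Namespace

args = Namespace(
    cache_data=True,        # save/load <file_path>.cache via torch.save / torch.load
    overwrite_cache=False,  # if True, ignore an existing cache file and rebuild it
)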
Example #4
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    def collate(examples):
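        # Assembles one batch from dataset examples: padded src/tgt id tensors,
        # alignment guides computed by the current model via get_aligned_word,
        # src+tgt and tgt+src half-sequence concatenations with language ids
        # (1 for the first half, 2 for the second) for TLM, and concatenated
        # halves with 0/1 labels for parallel sentence identification (PSI).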
        model.eval()
        examples_src, examples_tgt, examples_srctgt, examples_tgtsrc = [], [], [], []
        langid_srctgt, langid_tgtsrc, psi_examples_srctgt, psi_labels = [], [], [], []
        src_len = tgt_len = 0
        bpe2word_map_src, bpe2word_map_tgt = [], []
        for example in examples:
            end_id = example[0][0][-1].view(-1)

            src_id = example[0][0][:args.block_size]
            src_id = torch.cat([src_id[:-1], end_id])
            tgt_id = example[1][0][:args.block_size]
            tgt_id = torch.cat([tgt_id[:-1], end_id])

            half_block_size = int(args.block_size/2)
            half_src_id = example[0][0][:half_block_size]
            half_src_id = torch.cat([half_src_id[:-1], end_id])
            half_tgt_id = example[1][0][:half_block_size]
            half_tgt_id = torch.cat([half_tgt_id[:-1], end_id])

            examples_src.append(src_id)
            examples_tgt.append(tgt_id)
            src_len = max(src_len, len(src_id))
            tgt_len = max(tgt_len, len(tgt_id))

            srctgt = torch.cat( [half_src_id, half_tgt_id] )
            langid = torch.cat([ torch.ones_like(half_src_id), torch.ones_like(half_tgt_id)*2] )
            examples_srctgt.append(srctgt)
            langid_srctgt.append(langid)

            tgtsrc = torch.cat( [half_tgt_id, half_src_id])
            langid = torch.cat([ torch.ones_like(half_tgt_id), torch.ones_like(half_src_id)*2] )
            examples_tgtsrc.append(tgtsrc)
            langid_tgtsrc.append(langid)

            # [neg, neg] pair
            neg_half_src_id = example[-2][0][:half_block_size]
            neg_half_src_id = torch.cat([neg_half_src_id[:-1], end_id])
            neg_half_tgt_id = example[-1][0][:half_block_size]
            neg_half_tgt_id = torch.cat([neg_half_tgt_id[:-1], end_id])
            if random.random() > 0.5:
                neg_srctgt = torch.cat( [neg_half_src_id, neg_half_tgt_id] )
            else:
                neg_srctgt = torch.cat( [neg_half_tgt_id, neg_half_src_id] )
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(1)
                
            # [pos, neg] pair
            rd = random.random()
            if rd > 0.75:
                neg_srctgt = torch.cat([half_src_id, neg_half_tgt_id])
            elif rd > 0.5:
                neg_srctgt = torch.cat([neg_half_src_id, half_tgt_id])
            elif rd > 0.25:
                neg_srctgt = torch.cat([half_tgt_id, neg_half_src_id])
            else:
                neg_srctgt = torch.cat([neg_half_tgt_id, half_src_id])
            psi_examples_srctgt.append(neg_srctgt)
            psi_labels.append(0)

            bpe2word_map_src.append(example[2])
            bpe2word_map_tgt.append(example[3])
            
        examples_src = pad_sequence(examples_src, batch_first=True, padding_value=tokenizer.pad_token_id)
        examples_tgt = pad_sequence(examples_tgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        examples_srctgt = pad_sequence(examples_srctgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        langid_srctgt = pad_sequence(langid_srctgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        examples_tgtsrc = pad_sequence(examples_tgtsrc, batch_first=True, padding_value=tokenizer.pad_token_id)
        langid_tgtsrc = pad_sequence(langid_tgtsrc, batch_first=True, padding_value=tokenizer.pad_token_id)
        psi_examples_srctgt = pad_sequence(psi_examples_srctgt, batch_first=True, padding_value=tokenizer.pad_token_id)
        psi_labels = torch.tensor(psi_labels)
        if args.n_gpu > 1 or args.local_rank != -1:
            guides = model.module.get_aligned_word(
                examples_src, examples_tgt, bpe2word_map_src, bpe2word_map_tgt,
                args.device, src_len, tgt_len, align_layer=args.align_layer,
                extraction=args.extraction, softmax_threshold=args.softmax_threshold)
        else:
            guides = model.get_aligned_word(
                examples_src, examples_tgt, bpe2word_map_src, bpe2word_map_tgt,
                args.device, src_len, tgt_len, align_layer=args.align_layer,
                extraction=args.extraction, softmax_threshold=args.softmax_threshold)

        return examples_src, examples_tgt, guides, examples_srctgt, langid_srctgt, examples_tgtsrc, langid_tgtsrc, psi_examples_srctgt, psi_labels


    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )

    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
    if args.max_steps > 0 and args.max_steps < t_total:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    
    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]

    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    # Check if continuing training from a checkpoint
    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    set_seed(args)  # Added here for reproducibility

    def backward_loss(loss, tot_loss):
        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        tot_loss += loss.item()
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        return tot_loss

    tqdm_iterator = trange(int(t_total), desc="Iteration", disable=args.local_rank not in [-1, 0])
    for _ in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            model.train()

            if args.train_so or args.train_co:
                inputs_src, inputs_tgt = batch[0].clone(), batch[1].clone()
                inputs_src, inputs_tgt = inputs_src.to(args.device), inputs_tgt.to(args.device)
                attention_mask_src, attention_mask_tgt = (inputs_src != 0), (inputs_tgt != 0)
                guide = batch[2].to(args.device)
                loss = model(
                    inputs_src=inputs_src, inputs_tgt=inputs_tgt,
                    attention_mask_src=attention_mask_src, attention_mask_tgt=attention_mask_tgt,
                    guide=guide, align_layer=args.align_layer, extraction=args.extraction,
                    softmax_threshold=args.softmax_threshold,
                    train_so=args.train_so, train_co=args.train_co)
                tr_loss = backward_loss(loss, tr_loss)

            if args.train_mlm:
                inputs_src, labels_src = mask_tokens(batch[0], tokenizer, args)
                inputs_tgt, labels_tgt = mask_tokens(batch[1], tokenizer, args)
                inputs_src, inputs_tgt = inputs_src.to(args.device), inputs_tgt.to(args.device)
                labels_src, labels_tgt = labels_src.to(args.device), labels_tgt.to(args.device)
                loss = model(inputs_src=inputs_src, labels_src=labels_src)
                tr_loss = backward_loss(loss, tr_loss)

                loss = model(inputs_src=inputs_tgt, labels_src=labels_tgt)
                tr_loss = backward_loss(loss, tr_loss)

            if args.train_tlm:
                rand_ids = [0, 1]
                if not args.train_tlm_full:
                    rand_ids = [int(random.random() > 0.5)]
                for rand_id in rand_ids:
                    select_srctgt = batch[int(3+rand_id*2)]
                    select_langid = batch[int(4+rand_id*2)]
                    for lang_id in [1, 2]:
                        inputs_srctgt, labels_srctgt = mask_tokens(select_srctgt, tokenizer, args, select_langid, lang_id)
                        inputs_srctgt, labels_srctgt = inputs_srctgt.to(args.device), labels_srctgt.to(args.device)
                        loss = model(inputs_src=inputs_srctgt, labels_src=labels_srctgt)
                        tr_loss = backward_loss(loss, tr_loss)

            if args.train_psi:
                loss = model(inputs_src=batch[7].to(args.device), labels_psi=batch[8].to(args.device), align_layer=args.align_layer+1)
                tr_loss = backward_loss(loss, tr_loss)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                tqdm_iterator.update()

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logger.info("  Step %s. Training loss = %s", str(global_step), str((tr_loss-logging_loss)/args.logging_steps))
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if global_step > t_total:
                break
        if global_step > t_total:
            break

    return global_step, tr_loss / global_step
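
The training loop above consumes the collate output positionally as batch[0] ... batch[8]. For readability, the same nine tensors can be given names; this NamedTuple is a hypothetical helper, not part of the original code.

from typing import NamedTuple
import torch

class AlignBatch(NamedTuple):
    # Field order follows the tuple returned by collate above.
    inputs_src: torch.Tensor           # batch[0], padded source ids
    inputs_tgt: torch.Tensor           # batch[1], padded target ids
    guides: torch.Tensor               # batch[2], alignment guide from get_aligned_word
    examples_srctgt: torch.Tensor      # batch[3], src+tgt half concatenations (TLM)
    langid_srctgt: torch.Tensor        # batch[4], language ids for batch[3]
    examples_tgtsrc: torch.Tensor      # batch[5], tgt+src half concatenations (TLM)
    langid_tgtsrc: torch.Tensor        # batch[6], language ids for batch[5]
    psi_examples_srctgt: torch.Tensor  # batch[7], concatenated halves for PSI
    psi_labels: torch.Tensor           # batch[8], PSI labels (1 = same pair, 0 = mixed)

# e.g. batch = AlignBatch(*next(iter(train_dataloader)))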
Example #5
    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path, gold_path):
        assert os.path.isfile(file_path)
        logger.info("Creating features from dataset file at %s", file_path)

        cache_fn = f'{file_path}.cache' if gold_path is None else f'{file_path}.gold.cache'
        if args.cache_data and os.path.isfile(cache_fn) and not args.overwrite_cache:
            logger.info("Loading cached data from %s", cache_fn)
            self.examples = torch.load(cache_fn)
        else:
            # Loading text data
            self.examples = []
            with open(file_path, encoding="utf-8") as f:
                lines = f.readlines()
            
            # Loading gold data
            if gold_path is not None:
                assert os.path.isfile(gold_path)
                logger.info("Loading gold alignments at %s", gold_path)
                with open(gold_path, encoding="utf-8") as f:
                    gold_lines = f.readlines()
                assert len(gold_lines) == len(lines)

            for line_id, line in tqdm(enumerate(lines), desc='Loading data', total=len(lines)):
                if len(line) > 0 and not line.isspace() and len(line.split(' ||| ')) == 2:
                    try:
                        src, tgt = line.split(' ||| ')
                        if src.rstrip() == '' or tgt.rstrip() == '':
                            logger.info("Skipping instance %s", line)
                            continue
                    except ValueError:
                        logger.info("Skipping instance %s", line)
                        continue
                    sent_src, sent_tgt = src.strip().split(), tgt.strip().split()
                    token_src = [tokenizer.tokenize(word) for word in sent_src]
                    token_tgt = [tokenizer.tokenize(word) for word in sent_tgt]
                    wid_src = [tokenizer.convert_tokens_to_ids(x) for x in token_src]
                    wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_tgt]

                    ids_src = tokenizer.prepare_for_model(
                        list(itertools.chain(*wid_src)), return_tensors='pt',
                        max_length=tokenizer.max_len)['input_ids']
                    ids_tgt = tokenizer.prepare_for_model(
                        list(itertools.chain(*wid_tgt)), return_tensors='pt',
                        max_length=tokenizer.max_len)['input_ids']
                    if len(ids_src[0]) == 2 or len(ids_tgt[0]) == 2:
                        logger.info("Skipping instance %s", line)
                        continue

                    bpe2word_map_src = []
                    for i, word_list in enumerate(token_src):
                        bpe2word_map_src += [i for x in word_list]
                    bpe2word_map_tgt = []
                    for i, word_list in enumerate(token_tgt):
                        bpe2word_map_tgt += [i for x in word_list]

                    if gold_path is not None:
                        try:
                            gold_line = gold_lines[line_id].strip().split()
                            gold_word_pairs = []
                            for src_tgt in gold_line:
                                if 'p' in src_tgt:
                                    if args.ignore_possible_alignments:
                                        continue
                                    wsrc, wtgt = src_tgt.split('p')
                                else:
                                    wsrc, wtgt = src_tgt.split('-')
                                wsrc, wtgt = (int(wsrc), int(wtgt)) if not args.gold_one_index else (int(wsrc)-1, int(wtgt)-1)
                                gold_word_pairs.append( (wsrc, wtgt) )
                            self.examples.append( (ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, gold_word_pairs) )
                        except ValueError:
                            logger.info("Error when processing the gold alignment %s, skipping", gold_lines[line_id].strip())
                            continue
                    else:
                        self.examples.append( (ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, None) )

            if args.cache_data:
                logger.info("Saving cached data to %s", cache_fn)
                torch.save(self.examples, cache_fn)
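
The gold file is expected to hold one line per sentence pair, with word-index pairs written as 'i-j' for sure alignments and 'ipj' for possible ones, optionally one-indexed (args.gold_one_index). A standalone sketch of the same parsing logic as above, with a made-up helper name:

def parse_gold_line(line, one_index=False, ignore_possible=False):
    """Parse e.g. '0-0 1-2 3p4' into a list of (src_word, tgt_word) index pairs."""
    pairs = []
    for src_tgt in line.strip().split():
        if 'p' in src_tgt:
            if ignore_possible:
                continue
            wsrc, wtgt = src_tgt.split('p')
        else:
            wsrc, wtgt = src_tgt.split('-')
        wsrc, wtgt = int(wsrc), int(wtgt)
        if one_index:
            wsrc, wtgt = wsrc - 1, wtgt - 1
        pairs.append((wsrc, wtgt))
    return pairs

# parse_gold_line('0-0 1-2 3p4') == [(0, 0), (1, 2), (3, 4)]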