Example no. 1
    def __init__(self,
                 text_processor: TextProcessor,
                 config: BertConfig = None,
                 encoder: BertModel = None,
                 enc_layer: int = 6,
                 embed_dim: int = 768,
                 intermediate_dim: int = 3072):
        super(LM, self).__init__()
        self.text_processor: TextProcessor = text_processor

        if config is not None:
            self.config = config
        else:
            self.config = lm_config.get_config(
                vocab_size=text_processor.tokenizer.get_vocab_size(),
                pad_token_id=text_processor.pad_token_id(),
                bos_token_id=text_processor.bos_token_id(),
                eos_token_id=text_processor.sep_token_id(),
                enc_layer=enc_layer,
                embed_dim=embed_dim,
                intermediate_dim=intermediate_dim)

            self.config["type_vocab_size"] = len(text_processor.languages)
            self.config = BertConfig(**self.config)

        self.masked_lm = BertOutputLayer(self.config)
        if encoder is None:
            self.encoder: BertModel = BertModel(self.config)
            self.encoder.init_weights()
        else:
            self.encoder = encoder
        self.encoder._tie_or_clone_weights(
            self.masked_lm.decoder, self.encoder.embeddings.word_embeddings)
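
The last call above ties the masked-LM output projection (masked_lm.decoder) to the encoder's input word embeddings so both use one weight matrix. Below is a minimal, self-contained sketch of that weight-tying idea in plain PyTorch; the dimensions are illustrative and this is not the repository's _tie_or_clone_weights helper.

import torch
import torch.nn as nn

vocab_size, embed_dim = 1000, 768

embeddings = nn.Embedding(vocab_size, embed_dim)
output_proj = nn.Linear(embed_dim, vocab_size, bias=False)

# Share the same Parameter object between input embeddings and output projection.
output_proj.weight = embeddings.weight

hidden = torch.randn(2, 5, embed_dim)   # stand-in for encoder states
logits = output_proj(hidden)            # (2, 5, vocab_size)
assert output_proj.weight.data_ptr() == embeddings.weight.data_ptr()
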
Example no. 2
def mask_text(mask_prob,
              pads,
              texts,
              text_processor: TextProcessor,
              mask_eos: bool = True):
    assert 0 < mask_prob < 1
    mask = torch.empty(texts.size()).uniform_(0, 1) < mask_prob
    mask[~pads] = False  # We should not mask pads.
    if not mask_eos:
        eos_idx = texts == text_processor.sep_token_id()
        # We should not mask end-of-sentence (usually in case of BART training).
        mask[eos_idx] = False

    masked_ids = texts[mask]
    def random_index():
        return random.randint(len(text_processor.special_tokens),
                              text_processor.vocab_size() - 1)

    def rand_select(r, c):
        # 80%: mask token, 10%: random non-special token, 10%: keep the original id.
        if r < 0.8:
            return text_processor.mask_token_id()
        return random_index() if r < 0.9 else int(masked_ids[c])

    replacements = [rand_select(random.random(), i) for i in range(masked_ids.size(0))]
    texts[mask] = torch.LongTensor(replacements)
    return mask, masked_ids, texts
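
A toy invocation of mask_text with a minimal stand-in for TextProcessor. All ids below are made up for illustration, and pads is built as True for real (non-pad) tokens, which is how the mask[~pads] line above expects it.

import torch

class ToyTextProcessor:
    # Hypothetical special tokens and ids, only for illustration.
    special_tokens = ["<pad>", "<s>", "</s>", "<mask>"]
    def pad_token_id(self): return 0
    def sep_token_id(self): return 2
    def mask_token_id(self): return 3
    def vocab_size(self): return 100

tp = ToyTextProcessor()
texts = torch.randint(4, 100, (2, 8))   # fake token ids, no pads in this toy batch
pads = texts != tp.pad_token_id()       # True for real tokens
mask, masked_ids, texts = mask_text(0.15, pads, texts, tp, mask_eos=False)
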
Example no. 3
    def __init__(self,
                 text_processor: TextProcessor,
                 enc_layer: int = 6,
                 embed_dim: int = 768,
                 intermediate_dim: int = 3072):
        super(SenSim, self).__init__()
        self.text_processor: TextProcessor = text_processor
        self.config = lm_config.get_config(
            vocab_size=text_processor.tokenizer.get_vocab_size(),
            pad_token_id=text_processor.pad_token_id(),
            bos_token_id=text_processor.bos_token_id(),
            eos_token_id=text_processor.sep_token_id(),
            enc_layer=enc_layer,
            embed_dim=embed_dim,
            intermediate_dim=intermediate_dim)

        self.enc_layer = enc_layer
        self.embed_dim = embed_dim
        self.intermediate_dim = intermediate_dim
        self.config["type_vocab_size"] = len(text_processor.languages)
        self.config = BertConfig(**self.config)
        self.encoder = BertEncoderModel(self.config)
        self.encoder.init_weights()
        self.input_attention = nn.Linear(self.config.hidden_size, 1)
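
The single-output input_attention layer above is typically used to score each token state and pool the encoder output into one sentence vector. The sketch below shows such attention pooling; it is an assumption about how the layer is meant to be used, not the model's actual forward pass, and the shapes are illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size = 768
input_attention = nn.Linear(hidden_size, 1)

states = torch.randn(2, 10, hidden_size)         # encoder output: (batch, seq, hidden)
pad_mask = torch.ones(2, 10, dtype=torch.bool)   # True for real tokens

scores = input_attention(states).squeeze(-1)     # (batch, seq) token scores
scores = scores.masked_fill(~pad_mask, float("-inf"))
weights = F.softmax(scores, dim=-1)              # attention weights over tokens
sentence_emb = torch.einsum("bs,bsh->bh", weights, states)  # (batch, hidden)
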
Example no. 4
    def __init__(self,
                 text_processor: TextProcessor,
                 config: ReformerConfig = None,
                 size: int = 1):
        """
        :param size: config size: 1 = small, 2 = medium, 3 = base.
        """
        super(ReformerLM, self).__init__()
        self.text_processor: TextProcessor = text_processor

        if config is not None:
            self.config = config
        else:
            config_func = _small_config if size == 1 else (
                _base_config if size == 3 else _medium_config)
            self.config = config_func(
                vocab_size=text_processor.tokenizer.get_vocab_size(),
                pad_token_id=text_processor.pad_token_id(),
                eos_token_id=text_processor.sep_token_id())
            self.config = ReformerConfig(**self.config)

        reformer = ReformerModelWithLMHead(self.config)
        self.lm_head: ReformerOnlyLMHead = reformer.lm_head
        self.encoder: ReformerModel = reformer.reformer
Example no. 5
    def __init__(self,
                 text_processor: TextProcessor,
                 lang_dec: bool = True,
                 use_proposals=False,
                 tie_embed=False,
                 enc_layer: int = 6,
                 dec_layer: int = 3,
                 embed_dim: int = 768,
                 intermediate_dim: int = 3072,
                 freeze_image: bool = False,
                 resnet_depth: int = 1):
        super(Seq2Seq, self).__init__()
        self.text_processor: TextProcessor = text_processor
        self.config = lm_config.get_config(
            vocab_size=text_processor.tokenizer.get_vocab_size(),
            pad_token_id=text_processor.pad_token_id(),
            bos_token_id=text_processor.bos_token_id(),
            eos_token_id=text_processor.sep_token_id(),
            enc_layer=enc_layer,
            embed_dim=embed_dim,
            intermediate_dim=intermediate_dim)

        self.enc_layer = enc_layer
        self.dec_layer = dec_layer
        self.embed_dim = embed_dim
        self.intermediate_dim = intermediate_dim
        self.config["type_vocab_size"] = len(text_processor.languages)
        self.config = BertConfig(**self.config)
        dec_config = copy.deepcopy(self.config)
        dec_config.num_hidden_layers = self.dec_layer

        self.encoder = BertEncoderModel(self.config)
        self.encoder.init_weights()
        self.lang_dec = lang_dec
        self.tie_embed = tie_embed
        if not lang_dec:
            self.decoder = BertDecoderModel(dec_config)
            self.encoder._tie_or_clone_weights(
                self.encoder.embeddings.position_embeddings,
                self.decoder.embeddings.position_embeddings)
            self.encoder._tie_or_clone_weights(
                self.encoder.embeddings.token_type_embeddings,
                self.decoder.embeddings.token_type_embeddings)
            self.encoder._tie_or_clone_weights(
                self.encoder.embeddings.word_embeddings,
                self.decoder.embeddings.word_embeddings)

            if tie_embed:
                self.output_layer = BertOutputLayer(dec_config)
                self.encoder._tie_or_clone_weights(
                    self.output_layer, self.encoder.embeddings.word_embeddings)
                self.encoder._tie_or_clone_weights(
                    self.encoder.embeddings.position_embeddings,
                    self.decoder.embeddings.position_embeddings)
                self.decoder._tie_or_clone_weights(
                    self.output_layer, self.decoder.embeddings.word_embeddings)
            else:
                self.output_layer = nn.ModuleList([
                    BertOutputLayer(dec_config)
                    for _ in text_processor.languages
                ])

            # Share attention sub-layers between encoder and decoder when their depths match.
            if len(self.encoder.encoder.layer) == len(self.decoder.decoder.layer):
                for i, enc_layer in enumerate(self.encoder.encoder.layer):
                    self.decoder.decoder.layer[i].attention = enc_layer.attention

        else:
            dec = BertDecoderModel(dec_config)
            self.decoder = nn.ModuleList(
                [copy.deepcopy(dec) for _ in text_processor.languages])
            self.output_layer = nn.ModuleList([
                BertOutputLayer(dec_config) for _ in text_processor.languages
            ])
            for i, dec in enumerate(self.decoder):
                if tie_embed:
                    self.encoder._tie_or_clone_weights(
                        self.output_layer[i],
                        self.encoder.embeddings.word_embeddings)
                    dec.embeddings.position_embeddings = self.encoder.embeddings.position_embeddings
                dec._tie_or_clone_weights(self.output_layer[i],
                                          dec.embeddings.word_embeddings)
                dec._tie_or_clone_weights(
                    self.encoder.embeddings.token_type_embeddings,
                    dec.embeddings.token_type_embeddings)

        self.use_proposals = use_proposals
        if self.use_proposals:
            self.proposal_embedding = self.encoder.embeddings.word_embeddings
            self.lexical_gate = nn.Parameter(torch.zeros(
                1, self.config.hidden_size).fill_(0.1),
                                             requires_grad=True)
            self.lexical_layer_norm = nn.LayerNorm(
                self.config.hidden_size, eps=self.config.layer_norm_eps)

        self.freeze_image = freeze_image
        self.resnet_depth = resnet_depth
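
With lang_dec enabled, the constructor keeps one decoder and one output layer per language in nn.ModuleList containers. The sketch below shows the same dispatch pattern with placeholder modules and a made-up language list; it is not the repository's forward code.

import copy
import torch
import torch.nn as nn

languages = ["en", "fr", "de"]        # placeholder language inventory
proto_decoder = nn.Linear(16, 16)     # stand-in for BertDecoderModel

decoders = nn.ModuleList([copy.deepcopy(proto_decoder) for _ in languages])
output_layers = nn.ModuleList([nn.Linear(16, 100) for _ in languages])

lang_id = languages.index("fr")       # choose the decoder for the target language
hidden = torch.randn(2, 7, 16)
logits = output_layers[lang_id](decoders[lang_id](hidden))  # (2, 7, 100)
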
    def train(options):
        lex_dict = None
        if options.dict_path is not None:
            lex_dict = get_lex_dict(options.dict_path)
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        assert text_processor.pad_token_id() == 0

        if options.pretrained_path is not None:
            mt_model = Seq2Seq.load(ImageCaptioning,
                                    options.pretrained_path,
                                    tok_dir=options.tokenizer_path)
        else:
            mt_model = ImageCaptioning(
                use_proposals=lex_dict is not None,
                tie_embed=options.tie_embed,
                text_processor=text_processor,
                resnet_depth=options.resnet_depth,
                lang_dec=options.lang_decoder,
                enc_layer=options.encoder_layer,
                dec_layer=options.decoder_layer,
                embed_dim=options.embed_dim,
                intermediate_dim=options.intermediate_layer_dim)

        if options.lm_path is not None:
            lm = LM(text_processor=text_processor,
                    enc_layer=options.encoder_layer,
                    embed_dim=options.embed_dim,
                    intermediate_dim=options.intermediate_layer_dim)
            mt_model.init_from_lm(lm)

        print("Model initialization done!")

        # We assume that the collator function returns a list with the size of the number of
        # gpus (in case of cpus, it returns a list of length one).
        collator = dataset.ImageTextCollator()
        num_batches = max(1, torch.cuda.device_count())

        if options.continue_train:
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            optimizer = build_optimizer(mt_model,
                                        options.learning_rate,
                                        warump_steps=options.warmup)
        trainer = ImageCaptionTrainer(
            model=mt_model,
            mask_prob=options.mask_prob,
            optimizer=optimizer,
            clip=options.clip,
            beam_width=options.beam_width,
            max_len_a=options.max_len_a,
            max_len_b=options.max_len_b,
            len_penalty_ratio=options.len_penalty_ratio,
            fp16=options.fp16,
            mm_mode=options.mm_mode)

        pin_memory = torch.cuda.is_available()
        img_train_loader = ImageMTTrainer.get_img_loader(
            collator,
            dataset.ImageCaptionDataset,
            options.train_path,
            mt_model,
            num_batches,
            options,
            pin_memory,
            lex_dict=lex_dict)

        img_dev_loader = ImageMTTrainer.get_img_loader(
            collator,
            dataset.ImageCaptionDataset,
            options.dev_path,
            mt_model,
            num_batches,
            options,
            pin_memory,
            lex_dict=lex_dict,
            shuffle=False,
            denom=2)

        trainer.reference = None
        if img_dev_loader is not None:
            trainer.reference = []
            generator = (trainer.generator.module if hasattr(
                trainer.generator, "module") else trainer.generator)
            for data in img_dev_loader:
                for batch in data:
                    captions = [b["captions"] for b in batch]
                    for caption in captions:
                        refs = get_outputs_until_eos(
                            text_processor.sep_token_id(),
                            caption,
                            remove_first_token=True)
                        ref = [
                            generator.seq2seq_model.text_processor.tokenizer.decode(r.numpy())
                            for r in refs
                        ]
                        trainer.reference += ref

        step, train_epoch = 0, 1
        while options.step > 0 and step < options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(img_data_iter=img_train_loader,
                                       img_dev_data_iter=img_dev_loader,
                                       max_step=options.step,
                                       lex_dict=lex_dict,
                                       saving_path=options.model_path,
                                       step=step)
            train_epoch += 1