Example #1
    def forward(self, sample_list):
        # Dummy answer scores; this example only exercises the decoding loop below.
        scores = torch.rand(sample_list.get_batch_size(), 3127)
        decoder = registry.get_decoder_class(self.config.inference.type)(
            self.vocab, self.config)
        sample_list = decoder.init_batch(sample_list)
        batch_size = sample_list.image_feature_0.size(0)
        data = {}
        data["texts"] = sample_list.answers.new_full((batch_size, 1),
                                                     self.vocab.SOS_INDEX,
                                                     dtype=torch.long)
        timesteps = 10
        sample_list.add_field("targets", sample_list.answers[:, 0, 1:])
        output = None
        batch_size_t = batch_size
        for t in range(timesteps):
            data, batch_size_t = self.get_data_t(data, batch_size_t)
            output = torch.randn(1, 9491)
            if t == timesteps - 1:
                # On the last step, force a large logit at index 2 (presumably
                # the EOS token) so the decoder can terminate.
                output = torch.ones(1, 9491) * -30
                output[0][2] = 10
            finish, data, batch_size_t = decoder.decode(t, data, output)
            if finish:
                break

        model_output = {"scores": scores}
        model_output["captions"] = decoder.get_result()

        return model_output
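
A minimal sketch of the decoder protocol these examples rely on: init_batch() prepares the sample list, decode(t, data, output) consumes one step of logits and reports whether decoding has finished, and get_result() returns the generated token ids. The class below is hypothetical (the name GreedyToyDecoder and the EOS index 2 are assumptions for illustration), not the registry's actual decoder implementation.

import torch


class GreedyToyDecoder:
    def __init__(self, eos_index=2, max_length=10):
        # eos_index=2 mirrors the index forced high in Example #1 (an assumption).
        self.eos_index = eos_index
        self.max_length = max_length
        self.tokens = []

    def init_batch(self, sample_list):
        # Nothing to prepare in this toy version; just reset state.
        self.tokens = []
        return sample_list

    def decode(self, t, data, output):
        # Greedily pick the highest-scoring token from this step's logits.
        next_token = output.argmax(dim=-1)          # (batch_size,)
        self.tokens.append(next_token)
        data["texts"] = next_token.unsqueeze(-1)    # fed back on the next step
        finished = bool((next_token == self.eos_index).all())
        return finished or t == self.max_length - 1, data, output.size(0)

    def get_result(self):
        # Stack the per-step tokens into a (batch_size, steps) tensor.
        return torch.stack(self.tokens, dim=1)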
Example #2
    def forward(self, sample_list):
        # Stores the output probabilities.
        scores = sample_list.answers.new_ones(
            (
                sample_list.answers.size(0),
                self.text_processor.max_length,
                self.vocab_size,
            ),
            dtype=torch.float,
        )

        if self.config["inference"]["type"] in [
                "beam_search", "nucleus_sampling"
        ]:
            decoder = registry.get_decoder_class(
                self.config["inference"]["type"])(self.vocab, self.config)
            sample_list = decoder.init_batch(sample_list)

        batch_size = sample_list.image_feature_0.size(0)
        data, sample_list, timesteps = self.prepare_data(
            sample_list, batch_size)
        output = None
        batch_size_t = batch_size
        for t in range(timesteps):
            data, batch_size_t = self.get_data_t(t, data, batch_size_t, output)
            if self.config.inference.type in [
                    "beam_search", "nucleus_sampling"
            ]:
                pi_t = data["texts"]
            else:
                # Greedy/teacher-forced path: feed the token for step t.
                pi_t = data["texts"][:, t].unsqueeze(-1)
            embedding = self.word_embedding(pi_t)
            attention_feature, _ = self.process_feature_embedding(
                "image",
                sample_list,
                embedding[:, 0, :],
                batch_size_t=batch_size_t)
            output = self.classifier(attention_feature)
            # Compute decoding
            if self.config.inference.type in [
                    "beam_search", "nucleus_sampling"
            ]:
                finish, data, batch_size_t = decoder.decode(t, data, output)
                if finish:
                    break
            else:
                # Store this step's logits for the samples still being decoded.
                scores[:batch_size_t, t] = output

        model_output = {"scores": scores}
        if self.config.inference.type in ["beam_search", "nucleus_sampling"]:
            model_output["captions"] = decoder.get_result()

        return model_output
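
The non-beam-search branch of this forward() returns raw logits in model_output["scores"] rather than decoded captions. A hedged illustration (the batch size, length, and vocabulary size below are made up) of reducing that tensor to token ids with an argmax over the vocabulary dimension:

import torch

scores = torch.randn(4, 10, 9491)   # (batch_size, max_length, vocab_size), dummy values
token_ids = scores.argmax(dim=-1)   # greedy readout per timestep
print(token_ids.shape)              # torch.Size([4, 10])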
Example #3
    def forward(self, sample_list):
        # Stores the output probabilities.
        scores = sample_list.answers.new_ones(
            (
                sample_list.answers.size(0),
                self.text_processor.max_length,
                self.vocab_size,
            ),
            dtype=torch.float,
        )

        if self.config["inference"]["type"] in [
                "beam_search", "nucleus_sampling"
        ]:
            decoder = registry.get_decoder_class(
                self.config["inference"]["type"])(self.vocab, self.config)
            sample_list = decoder.init_batch(sample_list)

        batch_size = sample_list.image_feature_0.size(0)
        data, sample_list, timesteps = self.prepare_data(
            sample_list, batch_size)
        output = None
        batch_size_t = batch_size
        for t in range(timesteps):
            data, batch_size_t = self.get_data_t(t, data, batch_size_t, output)
            if self.config.inference.type in [
                    "beam_search", "nucleus_sampling"
            ]:
                pi_t = data["texts"]
            else:
                pi_t = data["texts"][:, t].unsqueeze(-1)
            embedding = self.word_embedding(pi_t)
            attention_feature, _ = self.process_feature_embedding(
                "image",
                sample_list,
                embedding[:, 0, :],
                batch_size_t=batch_size_t)
            output = self.classifier(attention_feature)
            # Compute decoding
            if self.config.inference.type in [
                    "beam_search", "nucleus_sampling"
            ]:
                finish, data, batch_size_t = decoder.decode(t, data, output)
                if finish:
                    break
            else:
                scores[:batch_size_t, t] = output

        model_output = {}
        if self.config.inference.type in ["beam_search", "nucleus_sampling"]:
            results = decoder.get_result()
            results = torch.nn.functional.pad(
                results,
                (0, self.text_processor.max_length - results.size()[-1]),
                "constant",
                0,
            )
            model_output["captions"] = results
            model_output["losses"] = {}
            loss_key = "{}/{}".format(sample_list.dataset_name,
                                      sample_list.dataset_type)
            # Add a dummy loss so that loss calculation is not required
            model_output["losses"][loss_key + "/dummy_loss"] = torch.zeros(
                batch_size, device=sample_list.answers.device)
        else:
            model_output["scores"] = scores

        return model_output
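
Example #3 pads the generated captions on the right so every row has text_processor.max_length tokens before returning them. A minimal sketch of that torch.nn.functional.pad call with made-up shapes:

import torch
import torch.nn.functional as F

max_length = 10                            # stands in for text_processor.max_length
captions = torch.randint(0, 100, (4, 7))   # decoder output shorter than max_length

# Pad only the last dimension (on the right) with zeros up to max_length.
padded = F.pad(captions, (0, max_length - captions.size(-1)), "constant", 0)
print(padded.shape)                        # torch.Size([4, 10])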