예제 #1
0
    def __init__(self, embed_dim, texts):
        """Assemble the image encoder, text encoder and composition transform.

        Args:
            embed_dim: Output width of both encoder heads. Also used as the
                LSTM hidden size and passed to ``MTirgTransform``.
            texts: Forwarded to ``TextLSTMModel`` as ``texts_to_build_vocab``.
        """
        super(ImageTextEncodeTransformModel, self).__init__()
        nn = torch.nn
        # Input width of ResNet-50's stock `fc` layer; reused as the hidden
        # width of both MLP heads below.
        hidden_dim = 2048

        self.snorm = torch_functions.NormalizationLayer(
            normalize_scale=4.0, learn_scale=True)

        # Image branch: pretrained ResNet-50 whose final `fc` layer is swapped
        # for an MLP head that projects to `embed_dim`.
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=` -- confirm the pinned torchvision version.
        backbone = torchvision.models.resnet50(pretrained=True)
        image_head = [
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        ]
        backbone.fc = nn.Sequential(*image_head)
        self.img_encoder = backbone

        # Text branch: LSTM model with an MLP assigned as its `fc_output`.
        # NOTE(review): presumably this overrides TextLSTMModel's default
        # output layer -- confirm against text_model.
        lstm = text_model.TextLSTMModel(
            texts_to_build_vocab=texts,
            word_embed_dim=256,
            lstm_hidden_dim=embed_dim)
        text_head = [
            nn.Dropout(0.1),
            nn.Linear(embed_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        ]
        lstm.fc_output = nn.Sequential(*text_head)
        self.text_encoder = lstm

        # Composition module operating at the shared embedding width.
        self.transformer = MTirgTransform(embed_dim)
예제 #2
0
 def __init__(self, embed_dim):
     """Build the MLP and the normalization layer used by this transform.

     Args:
         embed_dim: Embedding width. The MLP takes ``embed_dim * 3`` input
             features, uses ``embed_dim * 5`` hidden units, and outputs
             ``embed_dim`` features.
     """
     super(ConcatTransform, self).__init__()
     nn = torch.nn
     in_dim = embed_dim * 3
     hidden_dim = embed_dim * 5

     # Layers are listed in application order; batch norm sits only after
     # the second linear layer.
     layers = [
         nn.Linear(in_dim, hidden_dim),
         nn.ReLU(),
         nn.Linear(hidden_dim, hidden_dim),
         nn.BatchNorm1d(hidden_dim),
         nn.ReLU(),
         nn.Linear(hidden_dim, embed_dim),
     ]
     self.m = nn.Sequential(*layers)

     # Normalization layer created with a non-learnable scale.
     self.norm = torch_functions.NormalizationLayer(learn_scale=False)
예제 #3
0
 def __init__(self, embed_dim):
     """Build the MLP, normalization layer and the two-element parameter ``a``.

     Args:
         embed_dim: Embedding width. The MLP takes ``embed_dim * 3`` input
             features, uses ``embed_dim * 5`` hidden units, and outputs
             ``embed_dim`` features.
     """
     super(MTirgTransform, self).__init__()
     nn = torch.nn
     in_dim = embed_dim * 3
     hidden_dim = embed_dim * 5

     # Layers are listed in application order; batch norm sits only after
     # the second linear layer.
     layers = [
         nn.Linear(in_dim, hidden_dim),
         nn.ReLU(),
         nn.Linear(hidden_dim, hidden_dim),
         nn.BatchNorm1d(hidden_dim),
         nn.ReLU(),
         nn.Linear(hidden_dim, embed_dim),
     ]
     self.m = nn.Sequential(*layers)

     # Normalization layer created with a non-learnable scale.
     self.norm = torch_functions.NormalizationLayer(learn_scale=False)

     # Learnable two-element tensor initialised to [1.0, 0.1].
     # NOTE(review): presumably mixing weights used in forward() -- confirm.
     initial_a = torch.tensor([1.0, 0.1])
     self.a = nn.Parameter(initial_a)
예제 #4
0
 def __init__(self):
     """Create the normalization layer and triplet loss used by this model."""
     super(ImgTextCompositionBase, self).__init__()

     # Normalization layer created with scale 4.0 and `learn_scale=True`.
     normalizer = torch_functions.NormalizationLayer(
         normalize_scale=4.0,
         learn_scale=True,
     )
     self.normalization_layer = normalizer

     # Triplet loss built with torch_functions' default arguments.
     self.soft_triplet_loss = torch_functions.TripletLoss()