def __init__(self, emb_type, **kwargs):
    """Build the text-embedding submodule selected by *emb_type*.

    Recognized values: "identity", "vocab", "preextracted", "bilstm",
    "attention", "torch". Any other value raises NotImplementedError.

    Args:
        emb_type: Key selecting which embedding implementation to construct.
        **kwargs: Forwarded to the chosen implementation. Also consulted for
            "model_data_dir" and "embedding_dim" (both default to None).

    NOTE(review): for "identity", "vocab" and "torch" the submodule's
    ``text_out_dim`` is forced to ``embedding_dim`` from kwargs, which may be
    None if the caller omitted it — confirm callers always supply it. The
    remaining types are assumed to set ``text_out_dim`` themselves.
    """
    super().__init__()
    self.model_data_dir = kwargs.get("model_data_dir", None)
    self.embedding_dim = kwargs.get("embedding_dim", None)

    # Types whose constructor takes **kwargs and defines its own text_out_dim.
    passthrough_types = {
        "preextracted": PreExtractedEmbedding,
        "bilstm": BiLSTMTextEmbedding,
        "attention": AttentionTextEmbedding,
    }

    if emb_type == "identity":
        # Pass-through embedding; output width equals the configured dim.
        self.module = Identity()
        self.module.text_out_dim = self.embedding_dim
    elif emb_type == "vocab":
        self.module = VocabEmbedding(**kwargs)
        self.module.text_out_dim = self.embedding_dim
    elif emb_type == "torch":
        # Plain learned lookup table; nn.Embedding has no text_out_dim of
        # its own, so attach the configured one.
        self.module = nn.Embedding(kwargs["vocab_size"], kwargs["embedding_dim"])
        self.module.text_out_dim = self.embedding_dim
    elif emb_type in passthrough_types:
        self.module = passthrough_types[emb_type](**kwargs)
    else:
        raise NotImplementedError("Unknown question embedding '%s'" % emb_type)

    # Expose the submodule's output width on the wrapper for downstream use.
    self.text_out_dim = self.module.text_out_dim
def __init__(self, encoder_type, in_dim, **kwargs):
    """Build the image-encoder submodule selected by *encoder_type*.

    Recognized values: "default" (identity pass-through, out_dim == in_dim)
    and "finetune_faster_rcnn_fpn_fc7". Any other value raises
    NotImplementedError.

    Args:
        encoder_type: Key selecting which encoder implementation to construct.
        in_dim: Width of the incoming image features.
        **kwargs: Forwarded to the chosen encoder implementation.
    """
    super(ImageEncoder, self).__init__()

    if encoder_type == "finetune_faster_rcnn_fpn_fc7":
        # Assumed to define its own out_dim — not visible from here.
        self.module = FinetuneFasterRcnnFpnFc7(in_dim, **kwargs)
    elif encoder_type == "default":
        # Identity encoder: features pass through unchanged.
        identity = Identity()
        identity.in_dim = in_dim
        identity.out_dim = in_dim
        self.module = identity
    else:
        raise NotImplementedError("Unknown Image Encoder: %s" % encoder_type)

    # Expose the submodule's output width on the wrapper for downstream use.
    self.out_dim = self.module.out_dim