Ejemplo n.º 1
0
    def __init__(self, vocab, pos, rels, enum_word, options, onto, cpos):
        """Assemble a graph model and a transition model around one shared BiLSTM.

        A single bidirectional ``nn.LSTM`` (``self.shared``) is created and the
        same instance is handed to both ``GraphModel`` and ``TransitionModel``.
        Each sub-model gets its own trainer built by
        ``get_optim(options.optim, ...)`` over that sub-model's parameters.
        A ``LinearClassifier`` over ``options.lstm_dims * 4`` features is
        optimized with Adam (lr=0.01) against ``nn.CrossEntropyLoss``.
        Sub-models and the classifier are moved to CUDA when it is available.

        Args:
            vocab, pos, rels, enum_word, onto, cpos: forwarded unchanged to
                both sub-model constructors.
            options: configuration object; this method reads the
                ``*embedding_dims`` fields, ``lstm_dims`` and ``optim``.
        """
        super(HybridModel, self).__init__()
        # Fixed seed for Python's ``random`` module.
        random.seed(2)

        # The shared LSTM consumes the concatenation of all embedding types.
        input_dims = (
            options.wembedding_dims
            + options.pembedding_dims
            + options.cembedding_dims
            + options.oembedding_dims
        )
        self.shared = nn.LSTM(
            input_dims,
            options.lstm_dims,
            batch_first=True,
            bidirectional=True,
        )

        graph_model = GraphModel(vocab, pos, rels, enum_word, options, onto,
                                 cpos, self.shared)
        transition_model = TransitionModel(vocab, pos, rels, enum_word,
                                           options, onto, cpos, self.shared)

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            graph_model = graph_model.cuda()
            transition_model = transition_model.cuda()
        self.graphModel = graph_model
        self.transitionModel = transition_model

        # One optimizer per sub-model; both parameter sets include the
        # shared LSTM's weights, since each sub-model received self.shared.
        self.graphTrainer = get_optim(options.optim,
                                      self.graphModel.parameters())
        self.transitionTrainer = get_optim(options.optim,
                                           self.transitionModel.parameters())

        head = LinearClassifier(options.lstm_dims * 4)
        self.classifier = head.cuda() if use_cuda else head
        self.loss = nn.CrossEntropyLoss()
        # Built from the pre-.cuda() reference, as before. This presumably
        # equals self.classifier, since nn.Module.cuda() returns self —
        # confirm if LinearClassifier overrides .cuda().
        self.optimizer = optim.Adam(head.parameters(), lr=0.01)
Ejemplo n.º 2
0
    def __init__(self, keyword, src_id, dst_id, in_dim, hid_c, src_len,
                 n_layers, device):
        """Compose a one-hot embedding, a keyword-selected backbone and a linear head.

        Args:
            keyword: backbone selector, one of ``"Graph"``, ``"SG"`` or
                ``"SAGE"``.
            src_id, dst_id: forwarded to the backbone constructor.
            in_dim: first argument to ``OneHotProcess`` and output width of
                the final ``nn.Linear``.
            hid_c: second argument to ``OneHotProcess``; multiplied by
                ``src_len`` to give the backbone's input/output width.
            src_len: multiplied by ``hid_c`` (see above).
            n_layers: forwarded to ``SGModel`` / ``SAGEModel`` only.
            device: forwarded to the backbone constructor.

        Raises:
            KeyError: if ``keyword`` is not one of the supported names.
        """
        super(PredictModel, self).__init__()

        self.oneHotEmbed = OneHotProcess(in_dim, hid_c)

        # Reject unknown backbones up front (after oneHotEmbed is set, as
        # in the original ordering).
        if keyword not in ("Graph", "SG", "SAGE"):
            raise KeyError("Keyword is not defined! ")

        # The backbone maps src_len * hid_c features to the same width.
        flat_dim = src_len * hid_c
        if keyword == "Graph":
            # GraphModel's constructor takes no n_layers argument.
            backbone = GraphModel(src_id, dst_id, flat_dim, flat_dim, device)
        elif keyword == "SG":
            backbone = SGModel(src_id, dst_id, flat_dim, flat_dim,
                               n_layers, device)
        else:
            backbone = SAGEModel(src_id, dst_id, flat_dim, flat_dim,
                                 n_layers, device)
        self.model = backbone

        self.linear = nn.Linear(flat_dim, in_dim)
 def _getGraph(stop=51555):
     """Build a ``GraphModel`` from the primary-form loader and process a slice.

     Args:
         stop: upper bound given to ``slice(stop)`` when calling
             ``GraphModel.processGraphs``. Defaults to 51555, the value that
             was previously hard-coded. This presumably selects the first
             ``stop`` graphs — confirm against ``processGraphs``.

     Returns:
         The ``GraphModel`` instance after ``processGraphs`` has run on it.
     """
     loader = getLoaderPrimaryForm()
     # ``degree`` is not defined in this function; it is resolved from the
     # enclosing (module) scope at call time.
     g = GraphModel(loader, degree)
     g.processGraphs(slice(stop))
     return g