Exemple #1
0
    def on_estimateButton_clicked(self):
        """Estimate a score for the current world and display it on the slider.

        On first use the GAT model is built lazily: constructor arguments are
        unpickled from ``model.prms`` and weights loaded from ``model.tch``.
        ``self.world`` is then serialized into a ``socnav.SocNavDataset``. For
        every graph the dataloader yields, the model output is passed through
        a sigmoid, scaled to 0-100, and written to ``self.ui.slider``.
        """
        import math

        import torch
        import pickle
        import gat
        # NOTE: a GCN backbone (``gcn.GCN(*params)``) was used previously; its
        # layers live in ``self.model.layers`` rather than ``gat_layers``.

        self.ui.statusbar.showMessage("estimate")
        device = torch.device("cpu")
        if self.model is None:
            # NOTE(review): unpickling can execute arbitrary code; only load
            # trusted parameter files.
            with open('model.prms', 'rb') as prms_file:
                params = pickle.load(prms_file)
            self.model = gat.GAT(*params)
            # map_location lets weights saved on a GPU load on this CPU-only path.
            self.model.load_state_dict(torch.load('model.tch', map_location=device))
            self.model.eval()

        structure = self.world.serialize()
        train_dataset = socnav.SocNavDataset(structure, mode='train')
        train_dataloader = DataLoader(train_dataset, batch_size=1, collate_fn=collate)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            for subgraph, feats, _labels in train_dataloader:
                feats = feats.to(device)
                # The model and each GAT layer read the graph from a ``g``
                # attribute, so attach the current subgraph before calling it.
                self.model.g = subgraph
                for layer in self.model.gat_layers:
                    layer.g = subgraph
                logit = float(self.model(feats.float())[0].detach().numpy()[0])
                # Numerically stable sigmoid: math.exp(-logit) overflows
                # (OverflowError) for very negative logits, so branch on sign.
                if logit >= 0:
                    probability = 1. / (1. + math.exp(-logit))
                else:
                    exp_logit = math.exp(logit)
                    probability = exp_logit / (1. + exp_logit)
                translate = min(max(probability * 100, 0), 100)
                self.ui.slider.setValue(int(translate))
Exemple #2
0
 def on_estimateButton_clicked(self):
     """Score the scenario held in ``self.current_line`` and show it on the slider.

     On first use this builds a ``pg_rgcn_gat.PRGAT``. Its hyper-parameters are
     unpickled from ``../SNGNN_PARAMETERS.prms`` and its weights are loaded
     onto the CPU from ``../SNGNN_MODEL.tch``.

     The JSON scenario is printed and turned into a relational
     ``socnav.SocNavDataset``. Each graph is scored, and the model output is
     scaled by 100, clamped to [0, 100] and written to ``self.ui.slider``.
     """
     self.ui.statusbar.showMessage("estimate")
     # Cache on the attribute this branch actually populates. Testing
     # ``self.model`` never changed state here, so every click re-read the
     # parameters and weights from disk.
     if getattr(self, 'GNNmodel', None) is None:
         self.device = torch.device('cpu')
         # NOTE(review): unpickling can execute arbitrary code; only load
         # trusted parameter files.
         with open('../SNGNN_PARAMETERS.prms', 'rb') as prms_file:
             self.params = pickle.load(prms_file, fix_imports=True)
         # NOTE(review): PRGAT arguments come from fixed indices of the pickled
         # parameter list; their meaning is defined by the training script.
         self.GNNmodel = pg_rgcn_gat.PRGAT(self.params[5],
                                           self.params[7],
                                           self.params[8][0],
                                           self.params[14],
                                           self.params[14],
                                           self.params[6],
                                           int(self.params[4] / 2),
                                           int(self.params[4] / 2),
                                           self.params[10],
                                           self.params[9],
                                           self.params[12],
                                           bias=True)
         self.GNNmodel.load_state_dict(
             torch.load('../SNGNN_MODEL.tch', map_location='cpu'))
         self.GNNmodel.eval()
     scenario = json.loads(self.current_line)  # parse once, reuse below
     print(scenario)
     graph_type = 'relational'
     train_dataset = socnav.SocNavDataset(scenario,
                                          mode='train',
                                          alt=graph_type,
                                          verbose=False)
     train_dataloader = DataLoader(train_dataset,
                                   batch_size=1,
                                   collate_fn=collate)
     with torch.no_grad():  # inference only: skip autograd bookkeeping
         for subgraph, feats, _labels in train_dataloader:
             feats = feats.to(self.device)
             data = Data(
                 x=feats.float(),
                 edge_index=torch.stack(subgraph.edges()).to(self.device),
                 edge_type=subgraph.edata['rel_type'].squeeze().to(self.device))
             logits = self.GNNmodel(data)[0].detach().numpy()[0]
             score = logits * 100
             if score > 100:
                 score = 100
             elif score < 0:
                 score = 0
             self.ui.slider.setValue(int(score))
Exemple #3
0
 def predict(self, sn_scenario):
     """Return the model's score for ``sn_scenario``.

     The scenario is converted with ``self.makeJson`` into a relational
     ``socnav.SocNavDataset``. Each resulting graph is run through
     ``self.GNNmodel`` on ``self.device``. The model output is scaled by 100
     and clamped to [0, 100]. If the dataset yields several graphs, the score
     of the last one is returned.

     Raises:
         ValueError: if the scenario produces no graph to score. Previously
             this surfaced as an UnboundLocalError on ``score``.
     """
     jsonmodel = self.makeJson(sn_scenario)
     graph_type = 'relational'
     train_dataset = socnav.SocNavDataset(jsonmodel,
                                          mode='train',
                                          alt=graph_type)
     train_dataloader = DataLoader(train_dataset,
                                   batch_size=1,
                                   collate_fn=collate)
     score = None
     with torch.no_grad():  # inference only: skip autograd bookkeeping
         for subgraph, feats, _labels in train_dataloader:
             feats = feats.to(self.device)
             data = Data(
                 x=feats.float(),
                 edge_index=torch.stack(subgraph.edges()).to(self.device),
                 edge_type=subgraph.edata['rel_type'].squeeze().to(self.device))
             logits = self.GNNmodel(data)[0].detach().numpy()[0]
             score = logits * 100
             if score > 100:
                 score = 100
             elif score < 0:
                 score = 0
     if score is None:
         raise ValueError('scenario produced no graph to score')
     return score