Example #1
def score_comms(self, nodes: List[List[int]]):
    """Score each candidate community with the graph classifier."""
    # Batch the per-community subgraphs, carrying node attributes into the batch.
    if self.with_attr:
        batch_g = dgl.BatchedDGLGraph([self.prepare_graph(x) for x in nodes], ['state', 'feats'], None)
    else:
        batch_g = dgl.BatchedDGLGraph([self.prepare_graph(x) for x in nodes], 'state', None)
    self.model.eval()
    with torch.no_grad():
        logits = self.model(batch_g)
        # Probability of the positive class (the model outputs log-probabilities).
        p = torch.exp(logits[:, 1]).cpu().numpy()
    if self.log_reward:
        # Map probabilities to a bounded log-scale reward.
        r = -np.log(1 - p + 1e-9)
        r = np.clip(r, 0, 5.)
        return r
    else:
        return p
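
For context, here is a minimal, self-contained sketch of the same batching-and-readout flow with the legacy DGL 0.x API; the toy path graphs, the 1-dim 'state' feature, and the mean_nodes readout are illustrative assumptions standing in for prepare_graph and the author's model:

import dgl
import torch

def toy_graph(n):
    # A tiny path graph with a 1-dim 'state' feature per node.
    g = dgl.DGLGraph()
    g.add_nodes(n)
    g.add_edges(list(range(n - 1)), list(range(1, n)))
    g.ndata['state'] = torch.randn(n, 1)
    return g

graphs = [toy_graph(3), toy_graph(5)]
# Legacy constructor: (graph_list, node_attrs, edge_attrs).
batch_g = dgl.BatchedDGLGraph(graphs, 'state', None)

# Graph-level readout gives one row per input graph, which is the shape a
# community classifier on top of the batch would consume.
readout = dgl.mean_nodes(batch_g, 'state')
print(batch_g.batch_size, readout.shape)  # 2, torch.Size([2, 1])
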
Example #2
def score_comms(self, nodes: List[List[int]]):
    """Score each community by how highly its first node ranks among its members."""
    comms = [list(x) for x in nodes]
    subgs = [self.prepare_graph(x) for x in comms]
    if self.with_attr:
        batched_graph = dgl.BatchedDGLGraph(subgs, ['state', 'feats'], None)
    else:
        batched_graph = dgl.BatchedDGLGraph(subgs, 'state', None)
    self.model.eval()
    all_logits, values = self.model(batched_graph)
    offset = 0
    rewards = []
    # Node-level logits are concatenated across the batch; slice them back
    # per community using the per-graph node counts.
    for comm, n_nodes in zip(comms, batched_graph.batch_num_nodes):
        scores = all_logits[offset:offset + len(comm)]
        offset += n_nodes
        # Fraction of member nodes whose score does not exceed the first node's.
        rewards.append((scores <= scores[0]).float().mean().item())
    rewards = np.array(rewards)
    return rewards
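
The batched graph concatenates every subgraph's nodes along one axis, so node-level outputs have to be sliced back per community; batch_num_nodes supplies the per-graph counts used as offsets. Note that the loop above reads only the first len(comm) rows but advances the offset by the full n_nodes, which suggests prepare_graph may add extra nodes (e.g. boundary nodes) beyond the community members. A stripped-down illustration of that bookkeeping, with a made-up logits tensor:

import torch

batch_num_nodes = [3, 2]          # nodes per subgraph in the batch
all_logits = torch.arange(5.)     # stand-in for per-node logits

offset = 0
per_graph = []
for n_nodes in batch_num_nodes:
    per_graph.append(all_logits[offset:offset + n_nodes])
    offset += n_nodes

print([t.tolist() for t in per_graph])  # [[0.0, 1.0, 2.0], [3.0, 4.0]]
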
Example #3
def train_step(self, pos_comms, neg_comms):
    """One supervised step: classify positive vs. negative communities."""
    # Build one subgraph per community, positives first so the label order below matches.
    subgs = []
    for nodes in itertools.chain(pos_comms, neg_comms):
        subg = self.prepare_graph(nodes)
        subgs.append(subg)
    if self.with_attr:
        batch_graph = dgl.BatchedDGLGraph(subgs, ['state', 'feats'], None)
    else:
        batch_graph = dgl.BatchedDGLGraph(subgs, 'state', None)
    # Labels: 1 for positive communities, 0 for negative ones.
    labels = torch.cat([torch.ones(len(pos_comms)), torch.zeros(len(neg_comms))]).long().to(self.device)
    self.model.train()
    self.optimizer.zero_grad()
    logits = self.model(batch_graph)
    # The model returns log-probabilities, so NLL loss applies directly.
    loss = F.nll_loss(logits, labels)
    loss.backward()
    self.optimizer.step()
    acc = (torch.argmax(logits, 1) == labels).float().mean().item()
    return loss.item(), {'acc': acc}
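
F.nll_loss expects log-probabilities, so this step relies on the model ending in a log_softmax, consistent with Example #1 exponentiating logits[:, 1] to recover a probability. A minimal reminder of that contract, with made-up scores:

import torch
import torch.nn.functional as F

raw_scores = torch.tensor([[1.2, -0.3], [0.1, 0.8]])   # made-up graph-level scores
labels = torch.tensor([0, 1])

log_probs = F.log_softmax(raw_scores, dim=1)
loss = F.nll_loss(log_probs, labels)
# Equivalent to cross_entropy on the raw scores:
assert torch.allclose(loss, F.cross_entropy(raw_scores, labels))
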
Example #4
def train_step(self, comms, fn):
    """One REINFORCE step: sample a seed per community, regenerate the community
    with `fn`, and reward overlap with the original (dual learning)."""
    comms = [list(x) for x in comms]
    subgs = [self.prepare_graph(x) for x in comms]
    if self.with_attr:
        batched_graph = dgl.BatchedDGLGraph(subgs, ['state', 'feats'], None)
    else:
        batched_graph = dgl.BatchedDGLGraph(subgs, 'state', None)
    self.model.train()
    self.optimizer.zero_grad()
    all_logits, values = self.model(batched_graph)
    offset = 0
    seeds = []
    logps = []
    for comm, n_nodes in zip(comms, batched_graph.batch_num_nodes):
        # Per-community log-softmax over the member-node logits.
        logits = torch.log_softmax(all_logits[offset:offset + len(comm)], 0)
        offset += n_nodes
        ps = torch.exp(logits.detach())
        # Sample one seed node per community and keep its log-probability.
        seed_idx = torch.multinomial(ps, 1).item()
        seeds.append(comm[seed_idx])
        logps.append(logits[seed_idx])
    logps = torch.stack(logps)
    # dual learning: regenerate communities from the sampled seeds,
    # dropping a trailing 'EOS' token if the generator emits one.
    generated_comms = [x[:-1] if x[-1] == 'EOS' else x for x in fn(seeds)]
    # Reward = Jaccard similarity between original and regenerated community.
    rewards = []
    for x, y in zip(comms, generated_comms):
        a = set(x)
        b = set(y)
        rewards.append(len(a & b) / len(a | b))
    rewards = torch.tensor(rewards, device=self.device, dtype=torch.float)
    # advantages = rewards - values.detach()
    value_loss = ((values - rewards)**2).mean()
    # policy_loss = -(advantages * logps).mean()
    policy_loss = -(rewards * logps).mean()
    loss = policy_loss + .5 * value_loss
    loss.backward()
    self.optimizer.step()
    return policy_loss.item(), value_loss.item(), rewards.mean().item()
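
The dual-learning reward above is the Jaccard similarity |A ∩ B| / |A ∪ B| between the original community and the one regenerated from its sampled seed, e.g.:

a = {1, 2, 3, 4}          # original community
b = {2, 3, 4, 5, 6}       # community regenerated from the sampled seed
reward = len(a & b) / len(a | b)   # 3 / 6 = 0.5
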
Example #5
def load_data(args):
    '''Wrap DGL's load_data utility to handle the PPI special case.'''
    if args.dataset != 'ppi':
        return _load_data(args)
    train_dataset = PPIDataset('train')
    val_dataset = PPIDataset('valid')
    test_dataset = PPIDataset('test')
    PPIDataType = namedtuple('PPIDataset', [
        'train_mask', 'test_mask', 'val_mask', 'features', 'labels',
        'num_labels', 'graph'
    ])
    G = dgl.BatchedDGLGraph(
        [train_dataset.graph, val_dataset.graph, test_dataset.graph],
        edge_attrs=None,
        node_attrs=None)
    G = G.to_networkx()
    # hack to work around potential bugs in to_networkx
    for (n1, n2, d) in G.edges(data=True):
        d.clear()
    train_nodes_num = train_dataset.graph.number_of_nodes()
    test_nodes_num = test_dataset.graph.number_of_nodes()
    val_nodes_num = val_dataset.graph.number_of_nodes()
    nodes_num = G.number_of_nodes()
    assert (nodes_num == (train_nodes_num + test_nodes_num + val_nodes_num))
    # construct mask
    mask = np.zeros((nodes_num, ), dtype=bool)
    train_mask = mask.copy()
    train_mask[:train_nodes_num] = True
    val_mask = mask.copy()
    val_mask[train_nodes_num:-test_nodes_num] = True
    test_mask = mask.copy()
    test_mask[-test_nodes_num:] = True

    # construct features
    features = np.concatenate(
        [train_dataset.features, val_dataset.features, test_dataset.features],
        axis=0)

    labels = np.concatenate(
        [train_dataset.labels, val_dataset.labels, test_dataset.labels],
        axis=0)

    data = PPIDataType(graph=G,
                       train_mask=train_mask,
                       test_mask=test_mask,
                       val_mask=val_mask,
                       features=features,
                       labels=labels,
                       num_labels=121)
    return data
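
A hypothetical call site, assuming args carries a dataset field as in the DGL example scripts this wraps; the returned masks simply partition the node axis of the batched PPI graph:

import argparse

args = argparse.Namespace(dataset='ppi')      # hypothetical args object
data = load_data(args)

train_x = data.features[data.train_mask]      # features of the training split
train_y = data.labels[data.train_mask]        # multi-label targets (121 classes)
print(train_x.shape, train_y.shape, data.graph.number_of_nodes())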