Example #1
def build_model(args):
    # build model
    if args.method_type == "order":
        model = models.OrderEmbedder(1, args.hidden_dim, args)
    elif args.method_type == "mlp":
        model = models.BaselineMLP(1, args.hidden_dim, args)
    model.to(utils.get_device())
    if args.test and args.model_path:
        model.load_state_dict(torch.load(args.model_path,
            map_location=utils.get_device()))
    return model
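Most of the snippets on this page call a project-local get_device helper (sometimes as utils.get_device, sometimes with an args or cpu_force argument) whose body is not shown. A minimal sketch of the usual PyTorch pattern such a helper wraps (an assumption, not any of the repositories' actual code):

import torch

def get_device():
    # Prefer the GPU when CUDA is visible, otherwise fall back to the CPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")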
Example #2
def get_dga_sdirs(args, data, labels):
    device = get_device(args)
    sdirs = []
    for x, y in zip(data, labels):
        # dga_bs: dist grad accum. batch size
        dataloader = get_dataloader(x, y, args.dga_bs, shuffle=False)
        count = 0
        for xiter, yiter in dataloader:
            model, loss_type = get_model(args, False)
            loss_fn = get_loss_fn(loss_type)
            opt = get_optim(args, model)

            loss, _ = forward(model, xiter, yiter, opt, loss_fn, device)
            loss.backward()
            sdirs.append(get_model_grads(model, flatten=True))
            count += 1
            if count >= args.num_dga:
                break

    stacked = [[] for _ in range(len(sdirs[0]))]

    for l in range(len(sdirs[0])):
        for i in range(len(sdirs)):
            stacked[l].append(sdirs[i][l].flatten())

    sdirs = [[] for _ in range(args.ncomponent)]
    for l, layer in enumerate(stacked):
        layer = torch.stack(layer, dim=0).T.cpu().numpy()
        layer, _ = pca_transform(layer, args.ncomponent)
        for i in range(args.ncomponent):
            sdirs[i].append(layer[:, i].flatten())

    assert len(sdirs) == args.ncomponent

    return sdirs
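The pca_transform helper used above is not included in the snippet. A minimal sketch of one plausible implementation on top of scikit-learn (an assumption about its behaviour inferred only from the call site):

from sklearn.decomposition import PCA

def pca_transform(mat, n_components):
    # mat: 2-D array whose columns are individual gradient samples for one layer.
    # Returns the data projected onto the leading principal components and the
    # fitted PCA object.
    pca = PCA(n_components=n_components)
    return pca.fit_transform(mat), pca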
Example #3
def train(args, model, logger, in_queue, out_queue):
    """Train the order embedding model.

    args: Command-line arguments
    model: Model to train
    logger: logger for logging progress
    in_queue: input queue to an intersection computation worker
    out_queue: output queue to an intersection computation worker
    """
    scheduler, opt = utils.build_optimizer(args, model.parameters())
    if args.method_type == "order":
        clf_opt = optim.Adam(model.clf_model.parameters(), lr=args.lr)

    done = False
    while not done:
        data_source = make_data_source(args)
        loaders = data_source.gen_data_loaders(args.eval_interval *
            args.batch_size, args.batch_size, train=True)
        for batch_target, batch_neg_target, batch_neg_query in zip(*loaders):
            msg, _ = in_queue.get()
            if msg == "done":
                done = True
                break
            # train
            model.train()
            model.zero_grad()
            pos_a, pos_b, neg_a, neg_b = data_source.gen_batch(batch_target,
                batch_neg_target, batch_neg_query, True)
            emb_pos_a, emb_pos_b = model.emb_model(pos_a), model.emb_model(pos_b)
            emb_neg_a, emb_neg_b = model.emb_model(neg_a), model.emb_model(neg_b)
            #print(emb_pos_a.shape, emb_neg_a.shape, emb_neg_b.shape)
            emb_as = torch.cat((emb_pos_a, emb_neg_a), dim=0)
            emb_bs = torch.cat((emb_pos_b, emb_neg_b), dim=0)
            labels = torch.tensor([1]*pos_a.num_graphs + [0]*neg_a.num_graphs).to(
                utils.get_device())
            intersect_embs = None
            pred = model(emb_as, emb_bs)
            loss = model.criterion(pred, intersect_embs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            if scheduler:
                scheduler.step()

            if args.method_type == "order":
                with torch.no_grad():
                    pred = model.predict(pred)
                model.clf_model.zero_grad()
                pred = model.clf_model(pred.unsqueeze(1))
                criterion = nn.NLLLoss()
                clf_loss = criterion(pred, labels)
                clf_loss.backward()
                clf_opt.step()
            pred = pred.argmax(dim=-1)
            acc = torch.mean((pred == labels).type(torch.float))
            train_loss = loss.item()
            train_acc = acc.item()

            out_queue.put(("step", (loss.item(), acc)))
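The in_queue/out_queue pair implies that this loop runs inside a worker process that a parent feeds with "step" messages and stops with "done". A minimal sketch of such a driver (the function name and message handling are assumptions based only on the loop above):

import torch.multiprocessing as mp

def run_training(args, model, logger, n_steps):
    # Hypothetical driver: feed one "step" message per batch, then "done".
    in_queue, out_queue = mp.Queue(), mp.Queue()
    worker = mp.Process(target=train,
                        args=(args, model, logger, in_queue, out_queue))
    worker.start()
    for _ in range(n_steps):
        in_queue.put(("step", None))
        msg, (loss, acc) = out_queue.get()   # train() reports ("step", (loss, acc))
    in_queue.put(("done", None))
    worker.join()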
Example #4
def get_model(args, parallel=True, ckpt_path=False):
    if args.clf == 'fcn':
        print('Initializing FCN...')
        model = FCN(args.input_size, args.output_size)
    elif args.clf == 'mlp':
        print('Initializing MLP...')
        model = MLP(args.input_size, args.output_size)
    elif args.clf == 'svm':
        print('Initializing SVM...')
        model = SVM(args.input_size, args.output_size)
    elif args.clf == 'cnn':
        print('Initializing CNN...')
        model = CNN(nc=args.num_channels, fs=args.cnn_view)
    elif args.clf == 'resnet18':
        print('Initializing ResNet18...')
        model = resnet.resnet18(num_channels=args.num_channels,
                                num_classes=args.output_size)
    elif args.clf == 'vgg19':
        print('Initializing VGG19...')
        model = VGG(vgg_name=args.clf,
                    num_channels=args.num_channels,
                    num_classes=args.output_size)
    elif args.clf == 'unet':
        print('Initializing UNet...')
        model = UNet(in_channels=args.num_channels,
                     out_channels=args.output_size)

    num_params, num_layers = get_model_size(model)
    print("# params: {}\n# layers: {}".format(num_params, num_layers))

    if ckpt_path:
        model.load_state_dict(torch.load(ckpt_path))
        print('Load init: {}'.format(ckpt_path))

    if parallel:
        model = nn.DataParallel(model.to(get_device(args)),
                                device_ids=args.device_id)
    else:
        model = model.to(get_device(args))

    loss_type = 'hinge' if args.clf == 'svm' else args.loss_type
    print("Loss: {}".format(loss_type))

    return model, loss_type
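A hypothetical call site for the factory above; the Namespace fields mirror the attributes the function reads, and every value is made up for illustration (get_device(args) may read additional fields):

from argparse import Namespace

args = Namespace(clf='mlp', input_size=784, output_size=10,
                 num_channels=1, cnn_view=5, loss_type='nll',
                 device_id=[0])
model, loss_type = get_model(args, parallel=False)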
Example #5
    def criterion(self, pred, intersect_embs, labels):
        """Loss function for order emb.

        pred: lists of embeddings outputted by forward
        intersect_embs: not used
        labels: subgraph labels for each entry in pred
        """
        emb_as, emb_bs = pred
        e = torch.sum(torch.max(
            torch.zeros_like(emb_as, device=utils.get_device()),
            emb_bs - emb_as)**2,
                      dim=1)

        margin = self.margin
        e[labels == 0] = torch.max(
            torch.tensor(0.0, device=utils.get_device()),
            margin - e)[labels == 0]

        relation_loss = torch.sum(e)

        return relation_loss
Example #6
    def criterion(self, pred, intersect_embs, labels):
        """Loss function for order emb.
        The e term is the amount of violation (if b is a subgraph of a).
        For positive examples, the e term is minimized (close to 0); 
        for negative examples, the e term is trained to be at least greater than self.margin.

        pred: lists of embeddings outputted by forward
        intersect_embs: not used
        labels: subgraph labels for each entry in pred
        """
        emb_as, emb_bs = pred
        e = torch.sum(torch.max(
            torch.zeros_like(emb_as, device=utils.get_device()),
            emb_bs - emb_as)**2,
                      dim=1)

        margin = self.margin
        e[labels == 0] = torch.max(
            torch.tensor(0.0, device=utils.get_device()),
            margin - e)[labels == 0]

        relation_loss = torch.sum(e)

        return relation_loss
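A tiny standalone illustration of the same max-margin penalty with plain torch tensors (the margin and embedding values are arbitrary):

import torch

margin = 0.1                                # illustrative value for self.margin
emb_a = torch.tensor([[2.0, 3.0]])          # embedding of the target graph
emb_b = torch.tensor([[1.0, 3.5]])          # embedding of the candidate subgraph

# Violation of the order constraint emb_b <= emb_a (element-wise).
e = torch.sum(torch.clamp(emb_b - emb_a, min=0) ** 2, dim=1)

pos_loss = e                                # label 1: drive the violation to 0
neg_loss = torch.clamp(margin - e, min=0)   # label 0: violation should exceed the margin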
Example #7
 def step(self):
     new_beam_sets = []
     print("seeds come from", len(set(b[0][-1] for b in self.beam_sets)),
         "distinct graphs")
     analyze_embs_cur = []
     for beam_set in tqdm(self.beam_sets):
         new_beams = []
         for _, neigh, frontier, visited, graph_idx in beam_set:
             graph = self.dataset[graph_idx]
             if len(neigh) >= self.max_pattern_size or not frontier: continue
             cand_neighs, anchors = [], []
             for cand_node in frontier:
                 cand_neigh = graph.subgraph(neigh + [cand_node])
                 cand_neighs.append(cand_neigh)
                 if self.node_anchored:
                     anchors.append(neigh[0])
             cand_embs = self.model.emb_model(utils.batch_nx_graphs(
                 cand_neighs, anchors=anchors if self.node_anchored else None))
             best_score, best_node = float("inf"), None
             for cand_node, cand_emb in zip(frontier, cand_embs):
                 score, n_embs = 0, 0
                 for emb_batch in self.embs:
                     n_embs += len(emb_batch)
                     if self.model_type == "order":
                         score -= torch.sum(torch.argmax(
                             self.model.clf_model(self.model.predict((
                             emb_batch.to(utils.get_device()),
                             cand_emb)).unsqueeze(1)), axis=1)).item()
                     elif self.model_type == "mlp":
                         score += torch.sum(self.model(
                             emb_batch.to(utils.get_device()),
                             cand_emb.unsqueeze(0).expand(len(emb_batch), -1)
                             )[:,0]).item()
                     else:
                         print("unrecognized model type")
                 if score < best_score:
                     best_score = score
                     best_node = cand_node
                 new_frontier = list(((set(frontier) |
                     set(graph.neighbors(cand_node))) - visited) -
                     set([cand_node]))
                 new_beams.append((
                     score, neigh + [cand_node],
                     new_frontier, visited | set([cand_node]), graph_idx))
         new_beams = list(sorted(new_beams, key=lambda x:
             x[0]))[:self.n_beams]
         for score, neigh, frontier, visited, graph_idx in new_beams[:1]:
             graph = self.dataset[graph_idx]
             # add to record
             neigh_g = graph.subgraph(neigh).copy()
             neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
             for v in neigh_g.nodes:
                 neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
             self.cand_patterns[len(neigh_g)].append((score, neigh_g))
             if self.rank_method in ["counts", "hybrid"]:
                 self.counts[len(neigh_g)][utils.wl_hash(neigh_g,
                     node_anchored=self.node_anchored)].append(neigh_g)
             if self.analyze and len(neigh) >= 3:
                 emb = self.model.emb_model(utils.batch_nx_graphs(
                     [neigh_g], anchors=[neigh[0]] if self.node_anchored
                     else None)).squeeze(0)
                 analyze_embs_cur.append(emb.detach().cpu().numpy())
         if len(new_beams) > 0:
             new_beam_sets.append(new_beams)
     self.beam_sets = new_beam_sets
     self.analyze_embs.append(analyze_embs_cur)
Example #8
    def step(self):
        ps = np.array([len(g) for g in self.dataset], dtype=float)
        ps /= np.sum(ps)
        graph_dist = stats.rv_discrete(values=(np.arange(len(self.dataset)), ps))

        print("Size", self.max_size)
        print(len(self.visited_seed_nodes), "distinct seeds")
        for simulation_n in tqdm(range(self.n_trials //
            (self.max_pattern_size+1-self.min_pattern_size))):
            # pick seed node
            best_graph_idx, best_start_node, best_score = None, None, -float("inf")
            for cand_graph_idx, cand_start_node in self.visited_seed_nodes:
                state = cand_graph_idx, cand_start_node
                my_visit_counts = sum(self.visit_counts[state].values())
                q_score = (sum(self.cum_action_values[state].values()) /
                    (my_visit_counts or 1))
                uct_score = self.c_uct * np.sqrt(np.log(simulation_n or 1) /
                    (my_visit_counts or 1))
                node_score = q_score + uct_score
                if node_score > best_score:
                    best_score = node_score
                    best_graph_idx = cand_graph_idx
                    best_start_node = cand_start_node
            # if existing seed beats choosing a new seed
            if best_score >= self.c_uct * np.sqrt(np.log(simulation_n or 1)):
                graph_idx, start_node = best_graph_idx, best_start_node
                assert best_start_node in self.dataset[graph_idx].nodes
                graph = self.dataset[graph_idx]
            else:
                found = False
                while not found:
                    graph_idx = np.arange(len(self.dataset))[graph_dist.rvs()]
                    graph = self.dataset[graph_idx]
                    start_node = random.choice(list(graph.nodes))
                    # don't pick isolated nodes or small islands
                    if self.has_min_reachable_nodes(graph, start_node,
                        self.min_pattern_size):
                        found = True
                self.visited_seed_nodes.add((graph_idx, start_node))
            neigh = [start_node]
            frontier = list(set(graph.neighbors(start_node)) - set(neigh))
            visited = set([start_node])
            neigh_g = nx.Graph()
            neigh_g.add_node(start_node, anchor=1)
            cur_state = graph_idx, start_node
            state_list = [cur_state]
            while frontier and len(neigh) < self.max_size:
                cand_neighs, anchors = [], []
                for cand_node in frontier:
                    cand_neigh = graph.subgraph(neigh + [cand_node])
                    cand_neighs.append(cand_neigh)
                    if self.node_anchored:
                        anchors.append(neigh[0])
                cand_embs = self.model.emb_model(utils.batch_nx_graphs(
                    cand_neighs, anchors=anchors if self.node_anchored else None))
                best_v_score, best_node_score, best_node = 0, -float("inf"), None
                for cand_node, cand_emb in zip(frontier, cand_embs):
                    score, n_embs = 0, 0
                    for emb_batch in self.embs:
                        score += torch.sum(self.model.predict((
                            emb_batch.to(utils.get_device()), cand_emb))).item()
                        n_embs += len(emb_batch)
                    v_score = -np.log(score/n_embs + 1) + 1
                    # get wl hash of next state
                    neigh_g = graph.subgraph(neigh + [cand_node]).copy()
                    neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
                    for v in neigh_g.nodes:
                        neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
                    next_state = utils.wl_hash(neigh_g,
                        node_anchored=self.node_anchored)
                    # compute node score
                    parent_visit_counts = sum(self.visit_counts[cur_state].values())
                    my_visit_counts = sum(self.visit_counts[next_state].values())
                    q_score = (sum(self.cum_action_values[next_state].values()) /
                        (my_visit_counts or 1))
                    uct_score = self.c_uct * np.sqrt(np.log(parent_visit_counts or
                        1) / (my_visit_counts or 1))
                    node_score = q_score + uct_score
                    if node_score > best_node_score:
                        best_node_score = node_score
                        best_v_score = v_score
                        best_node = cand_node
                frontier = list(((set(frontier) |
                    set(graph.neighbors(best_node))) - visited) -
                    set([best_node]))
                visited.add(best_node)
                neigh.append(best_node)

                # update visit counts, wl cache
                neigh_g = graph.subgraph(neigh).copy()
                neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
                for v in neigh_g.nodes:
                    neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
                prev_state = cur_state
                cur_state = utils.wl_hash(neigh_g, node_anchored=self.node_anchored)
                state_list.append(cur_state)
                self.wl_hash_to_graphs[cur_state].append(neigh_g)

            # backprop value
            for i in range(0, len(state_list) - 1):
                self.cum_action_values[state_list[i]][
                    state_list[i+1]] += best_v_score
                self.visit_counts[state_list[i]][state_list[i+1]] += 1
        self.max_size += 1
Example #9
    def gen_batch(self, batch_target, batch_neg_target, batch_neg_query,
                  train):
        def sample_subgraph(graph,
                            offset=0,
                            use_precomp_sizes=False,
                            filter_negs=False,
                            supersample_small_graphs=False,
                            neg_target=None,
                            hard_neg_idxs=None):
            if neg_target is not None: graph_idx = graph.G.graph["idx"]
            use_hard_neg = (hard_neg_idxs is not None
                            and graph.G.graph["idx"] in hard_neg_idxs)
            done = False
            n_tries = 0
            while not done:
                if use_precomp_sizes:
                    size = graph.G.graph["subgraph_size"]
                else:
                    if train and supersample_small_graphs:
                        sizes = np.arange(self.min_size + offset,
                                          len(graph.G) + offset)
                        ps = (sizes - self.min_size + 2)**(-1.1)
                        ps /= ps.sum()
                        size = stats.rv_discrete(values=(sizes, ps)).rvs()
                    else:
                        d = 1 if train else 0
                        size = random.randint(self.min_size + offset - d,
                                              len(graph.G) - 1 + offset)
                start_node = random.choice(list(graph.G.nodes))
                neigh = [start_node]
                frontier = list(
                    set(graph.G.neighbors(start_node)) - set(neigh))
                visited = set([start_node])
                while len(neigh) < size:
                    new_node = random.choice(list(frontier))
                    assert new_node not in neigh
                    neigh.append(new_node)
                    visited.add(new_node)
                    frontier += list(graph.G.neighbors(new_node))
                    frontier = [x for x in frontier if x not in visited]
                if self.node_anchored:
                    anchor = neigh[0]
                    for v in graph.G.nodes:
                        graph.G.nodes[v]["node_feature"] = (
                            torch.ones(1) if anchor == v else torch.zeros(1))
                        #print(v, graph.G.nodes[v]["node_feature"])
                neigh = graph.G.subgraph(neigh)
                if use_hard_neg and train:
                    neigh = neigh.copy()
                    if random.random() < 1.0 or not self.node_anchored:  # add edges
                        non_edges = list(nx.non_edges(neigh))
                        if len(non_edges) > 0:
                            for u, v in random.sample(
                                    non_edges,
                                    random.randint(1, min(len(non_edges), 5))):
                                neigh.add_edge(u, v)
                    else:  # perturb anchor
                        anchor = random.choice(list(neigh.nodes))
                        for v in neigh.nodes:
                            neigh.nodes[v]["node_feature"] = (torch.ones(1) if
                                                              anchor == v else
                                                              torch.zeros(1))

                if (filter_negs and train and len(neigh) <= 6
                        and neg_target is not None):
                    matcher = nx.algorithms.isomorphism.GraphMatcher(
                        neg_target[graph_idx], neigh)
                    if not matcher.subgraph_is_isomorphic(): done = True
                else:
                    done = True

            return graph, DSGraph(neigh)

        augmenter = feature_preprocess.FeatureAugment()

        pos_target = batch_target
        pos_target, pos_query = pos_target.apply_transform_multi(
            sample_subgraph)
        neg_target = batch_neg_target
        # TODO: use hard negs
        hard_neg_idxs = set(
            random.sample(range(len(neg_target.G)),
                          int(len(neg_target.G) * 1 / 2)))
        #hard_neg_idxs = set()
        batch_neg_query = Batch.from_data_list(
            GraphDataset.list_to_graphs([
                self.generator.generate(
                    size=len(g)) if i not in hard_neg_idxs else g
                for i, g in enumerate(neg_target.G)
            ]))
        for i, g in enumerate(batch_neg_query.G):
            g.graph["idx"] = i
        _, neg_query = batch_neg_query.apply_transform_multi(
            sample_subgraph, hard_neg_idxs=hard_neg_idxs)
        if self.node_anchored:

            def add_anchor(g, anchors=None):
                if anchors is not None:
                    anchor = anchors[g.G.graph["idx"]]
                else:
                    anchor = random.choice(list(g.G.nodes))
                for v in g.G.nodes:
                    if "node_feature" not in g.G.nodes[v]:
                        g.G.nodes[v]["node_feature"] = (
                            torch.ones(1) if anchor == v else torch.zeros(1))
                return g

            neg_target = neg_target.apply_transform(add_anchor)
        pos_target = augmenter.augment(pos_target).to(utils.get_device())
        pos_query = augmenter.augment(pos_query).to(utils.get_device())
        neg_target = augmenter.augment(neg_target).to(utils.get_device())
        neg_query = augmenter.augment(neg_query).to(utils.get_device())
        #print(len(pos_target.G[0]), len(pos_query.G[0]))
        return pos_target, pos_query, neg_target, neg_query
Example #10
def validation(args, model, test_pts, logger, batch_n, epoch, verbose=False):
    # test on new motifs
    model.eval()
    all_raw_preds, all_preds, all_labels = [], [], []
    for pos_a, pos_b, neg_a, neg_b in test_pts:
        if pos_a:
            pos_a = pos_a.to(utils.get_device())
            pos_b = pos_b.to(utils.get_device())
        neg_a = neg_a.to(utils.get_device())
        neg_b = neg_b.to(utils.get_device())
        labels = torch.tensor([1]*(pos_a.num_graphs if pos_a else 0) +
            [0]*neg_a.num_graphs).to(utils.get_device())
        with torch.no_grad():
            emb_neg_a, emb_neg_b = (model.emb_model(neg_a),
                model.emb_model(neg_b))
            if pos_a:
                emb_pos_a, emb_pos_b = (model.emb_model(pos_a),
                    model.emb_model(pos_b))
                emb_as = torch.cat((emb_pos_a, emb_neg_a), dim=0)
                emb_bs = torch.cat((emb_pos_b, emb_neg_b), dim=0)
            else:
                emb_as, emb_bs = emb_neg_a, emb_neg_b
            pred = model(emb_as, emb_bs)
            raw_pred = model.predict(pred)
            if USE_ORCA_FEATS:
                import orca
                import matplotlib.pyplot as plt
                def make_feats(g):
                    counts5 = np.array(orca.orbit_counts("node", 5, g))
                    for v, n in zip(counts5, g.nodes):
                        if g.nodes[n]["node_feature"][0] > 0:
                            anchor_v = v
                            break
                    v5 = np.sum(counts5, axis=0)
                    return v5, anchor_v
                for i, (ga, gb) in enumerate(zip(neg_a.G, neg_b.G)):
                    (va, na), (vb, nb) = make_feats(ga), make_feats(gb)
                    if (va < vb).any() or (na < nb).any():
                        raw_pred[pos_a.num_graphs + i] = MAX_MARGIN_SCORE

            if args.method_type == "order":
                pred = model.clf_model(raw_pred.unsqueeze(1)).argmax(dim=-1)
                raw_pred *= -1
            elif args.method_type == "ensemble":
                pred = torch.stack([m.clf_model(
                    raw_pred.unsqueeze(1)).argmax(dim=-1) for m in model.models])
                for i in range(pred.shape[1]):
                    print(pred[:,i])
                pred = torch.min(pred, dim=0)[0]
                raw_pred *= -1
            elif args.method_type == "mlp":
                raw_pred = raw_pred[:,1]
                pred = pred.argmax(dim=-1)
        all_raw_preds.append(raw_pred)
        all_preds.append(pred)
        all_labels.append(labels)
    pred = torch.cat(all_preds, dim=-1)
    labels = torch.cat(all_labels, dim=-1)
    raw_pred = torch.cat(all_raw_preds, dim=-1)
    acc = torch.mean((pred == labels).type(torch.float))
    prec = (torch.sum(pred * labels).item() / torch.sum(pred).item() if
        torch.sum(pred) > 0 else float("NaN"))
    recall = (torch.sum(pred * labels).item() /
        torch.sum(labels).item() if torch.sum(labels) > 0 else
        float("NaN"))
    labels = labels.detach().cpu().numpy()
    raw_pred = raw_pred.detach().cpu().numpy()
    pred = pred.detach().cpu().numpy()
    auroc = roc_auc_score(labels, raw_pred)
    avg_prec = average_precision_score(labels, raw_pred)
    tn, fp, fn, tp = confusion_matrix(labels, pred).ravel()
    if verbose:
        import matplotlib.pyplot as plt
        precs, recalls, threshs = precision_recall_curve(labels, raw_pred)
        plt.plot(recalls, precs)
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.savefig("plots/precision-recall-curve.png")
        print("Saved PR curve plot in plots/precision-recall-curve.png")

    print("\n{}".format(str(datetime.now())))
    print("Validation. Epoch {}. Acc: {:.4f}. "
        "P: {:.4f}. R: {:.4f}. AUROC: {:.4f}. AP: {:.4f}.\n     "
        "TN: {}. FP: {}. FN: {}. TP: {}".format(epoch,
            acc, prec, recall, auroc, avg_prec,
            tn, fp, fn, tp))

    if not args.test:
        logger.add_scalar("Accuracy/test", acc, batch_n)
        logger.add_scalar("Precision/test", prec, batch_n)
        logger.add_scalar("Recall/test", recall, batch_n)
        logger.add_scalar("AUROC/test", auroc, batch_n)
        logger.add_scalar("AvgPrec/test", avg_prec, batch_n)
        logger.add_scalar("TP/test", tp, batch_n)
        logger.add_scalar("TN/test", tn, batch_n)
        logger.add_scalar("FP/test", fp, batch_n)
        logger.add_scalar("FN/test", fn, batch_n)
        print("Saving {}".format(args.model_path))
        torch.save(model.state_dict(), args.model_path)

    if verbose:
        conf_mat_examples = defaultdict(list)
        idx = 0
        for pos_a, pos_b, neg_a, neg_b in test_pts:
            if pos_a:
                pos_a = pos_a.to(utils.get_device())
                pos_b = pos_b.to(utils.get_device())
            neg_a = neg_a.to(utils.get_device())
            neg_b = neg_b.to(utils.get_device())
            for list_a, list_b in [(pos_a, pos_b), (neg_a, neg_b)]:
                if not list_a: continue
                for a, b in zip(list_a.G, list_b.G):
                    correct = pred[idx] == labels[idx]
                    conf_mat_examples[correct, pred[idx]].append((a, b))
                    idx += 1
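For reference, the metric and helper calls in this snippet come from standard libraries; beyond torch, numpy, and the project's utils module, the imports it relies on would be roughly:

from collections import defaultdict
from datetime import datetime

from sklearn.metrics import (average_precision_score, confusion_matrix,
                             precision_recall_curve, roc_auc_score)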
Example #11
import functools
import os
import shutil

import syft as sy
import torch
from torch.utils.tensorboard import SummaryWriter

# Project-local helpers used below (Arguments, argparser, get_device, get_paths,
# init_logger) come from the repository and are not shown in this excerpt.
from data.distributor import get_fl_graph
from data.loader import get_loader
from models.train import distributed_train, test
from models.utils import get_model
from viz.training_plots import training_plots

print = functools.partial(print, flush=True)
torch.set_printoptions(linewidth=120)

# ------------------------------------------------------------------------------
# Setups
# ------------------------------------------------------------------------------

args = Arguments(argparser())
hook = sy.TorchHook(torch)
device = get_device(args)
paths = get_paths(args, distributed=True)
log_file, std_out = init_logger(paths.log_file, args.dry_run, args.load_model)
if os.path.exists(paths.tb_path):
    shutil.rmtree(paths.tb_path)
tb = SummaryWriter(paths.tb_path)

print('+' * 80)
print(paths.model_name)
print('+' * 80)

print(args.__dict__)
print('+' * 80)

# prepare graph and data
_, workers = get_fl_graph(hook, args.num_workers)
Example #12
data_loaders = get_dataloader(image_resize=256,
                              mean=MEAN,
                              std=STD,
                              fast_train=FAST_TRAIN,
                              batch_size=BATCH_SIZE,
                              img_dir=prop('images256'))

# %%

GAIN_TARGET_LABEL = 'pigment_network'

gain_target = 4  # data_loaders[PHASE_TRAIN].dataset.labels().index(GAIN_TARGET_LABEL)

# %%

device = get_device(cpu_force=is_local_env())
print('Got device', device)

# %%
model = multi_label_resnet50(num_labels=N_CLASSES, pretrained=True)
model = MLGradCamResnet(model=model,
                        device=device,
                        cam_category=gain_target,
                        target_layer='layer4.2')

optimizer = optim.Adam(model.model.parameters(), lr=LEARNING_RATE)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=0.005)

# %%

N_COLS = 3
Example #13
import argparse
import os

import torch

from common import data, models, utils  # project-local modules; import path assumed
from subgraph_matching.config import parse_encoder

# Now we load the model and a dataset to analyze embeddings on, here ENZYMES.

from subgraph_matching.train import make_data_source

parser = argparse.ArgumentParser()

utils.parse_optimizer(parser)
parse_encoder(parser)
args = parser.parse_args("")
args.model_path = os.path.join("..", args.model_path)

print("Using dataset {}".format(args.dataset))
model = models.OrderEmbedder(1, args.hidden_dim, args)
model.to(utils.get_device())
model.eval()
model.load_state_dict(
    torch.load(args.model_path, map_location=utils.get_device()))

train, test, task = data.load_dataset("wn18")

from collections import Counter

done = False
train_accs = []
while not done:
    data_source = make_data_source(args)
    loaders = data_source.gen_data_loaders(args.eval_interval *
                                           args.batch_size,
                                           args.batch_size,
Example #14
def pattern_growth(dataset, task, args):
    # init model
    if args.method_type == "end2end":
        model = models.End2EndOrder(1, args.hidden_dim, args)
    elif args.method_type == "mlp":
        model = models.BaselineMLP(1, args.hidden_dim, args)
    else:
        model = models.OrderEmbedder(1, args.hidden_dim, args)
    model.to(utils.get_device())
    model.eval()
    model.load_state_dict(
        torch.load(args.model_path, map_location=utils.get_device()))

    if task == "graph-labeled":
        dataset, labels = dataset

    # load data
    neighs_pyg, neighs = [], []
    print(len(dataset), "graphs")
    print("search strategy:", args.search_strategy)
    if task == "graph-labeled": print("using label 0")
    graphs = []
    for i, graph in enumerate(dataset):
        if task == "graph-labeled" and labels[i] != 0: continue
        if task == "graph-truncate" and i >= 1000: break
        if not type(graph) == nx.Graph:
            graph = pyg_utils.to_networkx(graph).to_undirected()
        graphs.append(graph)
    if args.use_whole_graphs:
        neighs = graphs
    else:
        anchors = []
        if args.sample_method == "radial":
            for i, graph in enumerate(graphs):
                print(i)
                for j, node in enumerate(graph.nodes):
                    if len(dataset) <= 10 and j % 100 == 0: print(i, j)
                    if args.use_whole_graphs:
                        neigh = graph.nodes
                    else:
                        neigh = list(
                            nx.single_source_shortest_path_length(
                                graph, node, cutoff=args.radius).keys())
                        if args.subgraph_sample_size != 0:
                            neigh = random.sample(
                                neigh,
                                min(len(neigh), args.subgraph_sample_size))
                    if len(neigh) > 1:
                        neigh = graph.subgraph(neigh)
                        if args.subgraph_sample_size != 0:
                            neigh = neigh.subgraph(
                                max(nx.connected_components(neigh), key=len))
                        neigh = nx.convert_node_labels_to_integers(neigh)
                        neigh.add_edge(0, 0)
                        neighs.append(neigh)
        elif args.sample_method == "tree":
            start_time = time.time()
            for j in tqdm(range(args.n_neighborhoods)):
                graph, neigh = utils.sample_neigh(
                    graphs,
                    random.randint(args.min_neighborhood_size,
                                   args.max_neighborhood_size))
                neigh = graph.subgraph(neigh)
                neigh = nx.convert_node_labels_to_integers(neigh)
                neigh.add_edge(0, 0)
                neighs.append(neigh)
                if args.node_anchored:
                    anchors.append(
                        0)  # after converting labels, 0 will be anchor

    embs = []
    if len(neighs) % args.batch_size != 0:
        print("WARNING: number of graphs not multiple of batch size")
    for i in range(len(neighs) // args.batch_size):
        #top = min(len(neighs), (i+1)*args.batch_size)
        top = (i + 1) * args.batch_size
        with torch.no_grad():
            batch = utils.batch_nx_graphs(
                neighs[i * args.batch_size:top],
                anchors=anchors if args.node_anchored else None)
            emb = model.emb_model(batch)
            emb = emb.to(torch.device("cpu"))

        embs.append(emb)

    if args.analyze:
        embs_np = torch.stack(embs).numpy()
        plt.scatter(embs_np[:, 0], embs_np[:, 1], label="node neighborhood")

    if args.search_strategy == "mcts":
        assert args.method_type == "order"
        agent = MCTSSearchAgent(args.min_pattern_size,
                                args.max_pattern_size,
                                model,
                                graphs,
                                embs,
                                node_anchored=args.node_anchored,
                                analyze=args.analyze,
                                out_batch_size=args.out_batch_size)
    elif args.search_strategy == "greedy":
        agent = GreedySearchAgent(args.min_pattern_size,
                                  args.max_pattern_size,
                                  model,
                                  graphs,
                                  embs,
                                  node_anchored=args.node_anchored,
                                  analyze=args.analyze,
                                  model_type=args.method_type,
                                  out_batch_size=args.out_batch_size)
    out_graphs = agent.run_search(args.n_trials)
    print(time.time() - start_time, "TOTAL TIME")
    x = int(time.time() - start_time)
    print(x // 60, "mins", x % 60, "secs")

    # visualize out patterns
    count_by_size = defaultdict(int)
    for pattern in out_graphs:
        if args.node_anchored:
            colors = ["red"] + ["blue"] * (len(pattern) - 1)
            nx.draw(pattern, node_color=colors, with_labels=True)
        else:
            nx.draw(pattern)
        print("Saving plots/cluster/{}-{}.png".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.savefig("plots/cluster/{}-{}.png".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.savefig("plots/cluster/{}-{}.pdf".format(
            len(pattern), count_by_size[len(pattern)]))
        plt.close()
        count_by_size[len(pattern)] += 1

    if not os.path.exists("results"):
        os.makedirs("results")
    with open(args.out_path, "wb") as f:
        pickle.dump(out_graphs, f)
Example #15
def train(args, model, logger, in_queue, out_queue):
    """Train the order embedding model.

    args: Command-line arguments
    model: Model to train
    logger: logger for logging progress
    in_queue: input queue to an intersection computation worker
    out_queue: output queue to an intersection computation worker
    """
    scheduler, opt = utils.build_optimizer(args, model.parameters())
    if args.method_type == "order":
        clf_opt = optim.Adam(model.clf_model.parameters(), lr=args.lr)

    done = False
    while not done:
        data_source = make_data_source(args)
        loaders = data_source.gen_data_loaders(args.eval_interval *
                                               args.batch_size,
                                               args.batch_size,
                                               train=True)
        c = -1
        for batch_target, batch_neg_target, batch_neg_query in zip(*loaders):
            # msg, _ = in_queue.get()
            # if msg == "done":
            #     done = True
            #     break
            # train
            c += 1
            if c == 100:
                done = True
                print("Saving")
                torch.save(model.state_dict(), args.model_path)
                break
            model.train()
            model.zero_grad()
            pos_a, pos_b, neg_a, neg_b, _ = data_source.gen_batch(
                batch_target, batch_neg_target, batch_neg_query, True)
            emb_pos_a, emb_pos_b = model.emb_model(pos_a), model.emb_model(
                pos_b)
            emb_neg_a, emb_neg_b = model.emb_model(neg_a), model.emb_model(
                neg_b)
            #print(emb_pos_a.shape, emb_neg_a.shape, emb_neg_b.shape)
            emb_as = torch.cat((emb_pos_a, emb_neg_a), dim=0)
            emb_bs = torch.cat((emb_pos_b, emb_neg_b), dim=0)
            labels = torch.tensor([1] * pos_a.num_graphs +
                                  [0] * neg_a.num_graphs).to(
                                      utils.get_device())
            intersect_embs = None
            pred = model(emb_as, emb_bs)
            loss = model.criterion(pred, intersect_embs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            if scheduler:
                scheduler.step()

            if args.method_type == "order":
                with torch.no_grad():
                    pred = model.predict(pred)
                model.clf_model.zero_grad()
                pred = model.clf_model(pred.unsqueeze(1))
                criterion = nn.NLLLoss()
                clf_loss = criterion(pred, labels)
                clf_loss.backward()
                clf_opt.step()
            pred = pred.argmax(dim=-1)
            # import pdb; pdb.set_trace()
            acc = torch.mean(((1 - pred) == labels).type(torch.float))
            acc_ = torch.mean((pred == labels).type(torch.float))
            train_loss = loss.item()
            train_acc = acc.item()
            train_acc_ = acc_.item()
            print('Loss/ACC: ', c, train_loss, train_acc, train_acc_)
            with open('performance.txt', 'a') as f:
                f.write(' '.join([
                    'Loss/ACC: ',
                    str(c),
                    str(train_loss),
                    str(train_acc),
                    str(train_acc_)
                ]))
                f.write('\n')