def build_model(args):
    # build model
    if args.method_type == "order":
        model = models.OrderEmbedder(1, args.hidden_dim, args)
    elif args.method_type == "mlp":
        model = models.BaselineMLP(1, args.hidden_dim, args)
    model.to(utils.get_device())
    if args.test and args.model_path:
        model.load_state_dict(torch.load(args.model_path,
            map_location=utils.get_device()))
    return model
def get_dga_sdirs(args, data, labels):
    device = get_device(args)
    # collect flattened per-layer gradients from a few accumulation batches
    sdirs = []
    for x, y in zip(data, labels):
        # dga_bs: dist. grad. accum. batch size
        dataloader = get_dataloader(x, y, args.dga_bs, shuffle=False)
        count = 0
        for xiter, yiter in dataloader:
            model, loss_type = get_model(args, False)
            loss_fn = get_loss_fn(loss_type)
            opt = get_optim(args, model)
            loss, _ = forward(model, xiter, yiter, opt, loss_fn, device)
            loss.backward()
            sdirs.append(get_model_grads(model, flatten=True))
            count += 1
            if count >= args.num_dga:
                break
    # regroup: stacked[l] collects the l-th layer's gradient from every sample
    stacked = [[] for _ in range(len(sdirs[0]))]
    for l in range(len(sdirs[0])):
        for i in range(len(sdirs)):
            stacked[l].append(sdirs[i][l].flatten())
    # per layer, keep the top `ncomponent` principal gradient directions
    sdirs = [[] for _ in range(args.ncomponent)]
    for l, layer in enumerate(stacked):
        layer = torch.stack(layer, dim=0).T.cpu().numpy()
        layer, _ = pca_transform(layer, args.ncomponent)
        for i in range(args.ncomponent):
            sdirs[i].append(layer[:, i].flatten())
    assert len(sdirs) == args.ncomponent
    return sdirs
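# `pca_transform` above is a project helper that is not shown in this excerpt.
# A minimal sketch of what it is assumed to do -- project each row of the
# (n_params, n_samples) gradient matrix onto its top principal components,
# returning the transformed matrix and the fitted model -- written with
# scikit-learn. The name and return signature are assumptions, not repo code.
import numpy as np
from sklearn.decomposition import PCA

def pca_transform_sketch(mat, ncomponent):
    """Return (array of shape (n_rows, ncomponent), fitted PCA object)."""
    pca = PCA(n_components=ncomponent)
    transformed = pca.fit_transform(mat)  # rows are treated as samples
    return transformed, pca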
def train(args, model, logger, in_queue, out_queue):
    """Train the order embedding model.

    args: Commandline arguments
    logger: logger for logging progress
    in_queue: input queue to an intersection computation worker
    out_queue: output queue to an intersection computation worker
    """
    scheduler, opt = utils.build_optimizer(args, model.parameters())
    if args.method_type == "order":
        clf_opt = optim.Adam(model.clf_model.parameters(), lr=args.lr)

    done = False
    while not done:
        data_source = make_data_source(args)
        loaders = data_source.gen_data_loaders(args.eval_interval *
            args.batch_size, args.batch_size, train=True)
        for batch_target, batch_neg_target, batch_neg_query in zip(*loaders):
            msg, _ = in_queue.get()
            if msg == "done":
                done = True
                break
            # train
            model.train()
            model.zero_grad()
            pos_a, pos_b, neg_a, neg_b = data_source.gen_batch(batch_target,
                batch_neg_target, batch_neg_query, True)
            emb_pos_a, emb_pos_b = model.emb_model(pos_a), model.emb_model(pos_b)
            emb_neg_a, emb_neg_b = model.emb_model(neg_a), model.emb_model(neg_b)
            emb_as = torch.cat((emb_pos_a, emb_neg_a), dim=0)
            emb_bs = torch.cat((emb_pos_b, emb_neg_b), dim=0)
            labels = torch.tensor([1]*pos_a.num_graphs +
                [0]*neg_a.num_graphs).to(utils.get_device())
            intersect_embs = None
            pred = model(emb_as, emb_bs)
            loss = model.criterion(pred, intersect_embs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            if scheduler:
                scheduler.step()

            if args.method_type == "order":
                # second stage: fit the small classifier on the violation score
                with torch.no_grad():
                    pred = model.predict(pred)
                model.clf_model.zero_grad()
                pred = model.clf_model(pred.unsqueeze(1))
                criterion = nn.NLLLoss()
                clf_loss = criterion(pred, labels)
                clf_loss.backward()
                clf_opt.step()

            pred = pred.argmax(dim=-1)
            acc = torch.mean((pred == labels).type(torch.float))
            train_loss = loss.item()
            train_acc = acc.item()

            out_queue.put(("step", (train_loss, train_acc)))
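# The classifier stage above learns a decision threshold on the scalar
# violation score returned by model.predict. A self-contained sketch of that
# sub-step; the Linear(1, 2) + LogSoftmax shape of `clf_model` is an
# assumption for illustration, not necessarily the repo's definition.
import torch
import torch.nn as nn

clf_model = nn.Sequential(nn.Linear(1, 2), nn.LogSoftmax(dim=-1))
violation = torch.tensor([[0.01], [3.2]])   # low violation = likely subgraph
labels = torch.tensor([1, 0])               # 1 = positive (subgraph) pair
loss = nn.NLLLoss()(clf_model(violation), labels)
loss.backward()                              # trains only the threshold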
def get_model(args, parallel=True, ckpt_path=False):
    if args.clf == 'fcn':
        print('Initializing FCN...')
        model = FCN(args.input_size, args.output_size)
    elif args.clf == 'mlp':
        print('Initializing MLP...')
        model = MLP(args.input_size, args.output_size)
    elif args.clf == 'svm':
        print('Initializing SVM...')
        model = SVM(args.input_size, args.output_size)
    elif args.clf == 'cnn':
        print('Initializing CNN...')
        model = CNN(nc=args.num_channels, fs=args.cnn_view)
    elif args.clf == 'resnet18':
        print('Initializing ResNet18...')
        model = resnet.resnet18(num_channels=args.num_channels,
                                num_classes=args.output_size)
    elif args.clf == 'vgg19':
        print('Initializing VGG19...')
        model = VGG(vgg_name=args.clf, num_channels=args.num_channels,
                    num_classes=args.output_size)
    elif args.clf == 'unet':
        print('Initializing UNet...')
        model = UNet(in_channels=args.num_channels,
                     out_channels=args.output_size)

    num_params, num_layers = get_model_size(model)
    print("# params: {}\n# layers: {}".format(num_params, num_layers))

    if ckpt_path:
        model.load_state_dict(torch.load(ckpt_path))
        print('Load init: {}'.format(ckpt_path))

    if parallel:
        model = nn.DataParallel(model.to(get_device(args)),
                                device_ids=args.device_id)
    else:
        model = model.to(get_device(args))

    loss_type = 'hinge' if args.clf == 'svm' else args.loss_type
    print("Loss: {}".format(loss_type))

    return model, loss_type
def criterion(self, pred, intersect_embs, labels):
    """Loss function for order emb.

    pred: lists of embeddings outputted by forward
    intersect_embs: not used
    labels: subgraph labels for each entry in pred
    """
    emb_as, emb_bs = pred
    e = torch.sum(torch.max(
        torch.zeros_like(emb_as, device=utils.get_device()),
        emb_bs - emb_as)**2, dim=1)

    margin = self.margin
    e[labels == 0] = torch.max(
        torch.tensor(0.0, device=utils.get_device()), margin - e)[labels == 0]

    relation_loss = torch.sum(e)
    return relation_loss
def criterion(self, pred, intersect_embs, labels):
    """Loss function for order emb.

    The e term is the amount of violation (if b is a subgraph of a).
    For positive examples, the e term is minimized (close to 0);
    for negative examples, the e term is trained to be at least greater than
    self.margin.

    pred: lists of embeddings outputted by forward
    intersect_embs: not used
    labels: subgraph labels for each entry in pred
    """
    emb_as, emb_bs = pred
    e = torch.sum(torch.max(
        torch.zeros_like(emb_as, device=utils.get_device()),
        emb_bs - emb_as)**2, dim=1)

    margin = self.margin
    e[labels == 0] = torch.max(
        torch.tensor(0.0, device=utils.get_device()), margin - e)[labels == 0]

    relation_loss = torch.sum(e)
    return relation_loss
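# A self-contained sanity check of the penalty in the criterion above (a
# restatement for reference, not repo code): the violation is
# e = ||max(0, z_b - z_a)||^2, which is 0 exactly when z_b <= z_a
# element-wise, and negative pairs contribute max(0, margin - e).
import torch

z_a = torch.tensor([[1.0, 2.0], [1.0, 2.0]])
z_b = torch.tensor([[0.5, 1.0],   # contained: no violation
                    [2.0, 3.0]])  # not contained: positive violation
e = torch.sum(torch.clamp(z_b - z_a, min=0) ** 2, dim=1)
print(e)                                 # tensor([0., 2.])
margin = 0.1
print(torch.clamp(margin - e, min=0))    # negative-pair terms: tensor([0.1, 0.])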
def step(self):
    new_beam_sets = []
    print("seeds come from", len(set(b[0][-1] for b in self.beam_sets)),
        "distinct graphs")
    analyze_embs_cur = []
    for beam_set in tqdm(self.beam_sets):
        new_beams = []
        for _, neigh, frontier, visited, graph_idx in beam_set:
            graph = self.dataset[graph_idx]
            if len(neigh) >= self.max_pattern_size or not frontier:
                continue
            # embed every candidate one-node extension of this neighborhood
            cand_neighs, anchors = [], []
            for cand_node in frontier:
                cand_neigh = graph.subgraph(neigh + [cand_node])
                cand_neighs.append(cand_neigh)
                if self.node_anchored:
                    anchors.append(neigh[0])
            cand_embs = self.model.emb_model(utils.batch_nx_graphs(
                cand_neighs, anchors=anchors if self.node_anchored else None))
            best_score, best_node = float("inf"), None
            for cand_node, cand_emb in zip(frontier, cand_embs):
                # lower score = more target embeddings predict the candidate
                # as a subgraph
                score, n_embs = 0, 0
                for emb_batch in self.embs:
                    n_embs += len(emb_batch)
                    if self.model_type == "order":
                        score -= torch.sum(torch.argmax(
                            self.model.clf_model(self.model.predict((
                            emb_batch.to(utils.get_device()),
                            cand_emb)).unsqueeze(1)), axis=1)).item()
                    elif self.model_type == "mlp":
                        score += torch.sum(self.model(
                            emb_batch.to(utils.get_device()),
                            cand_emb.unsqueeze(0).expand(len(emb_batch), -1)
                            )[:,0]).item()
                    else:
                        print("unrecognized model type")
                if score < best_score:
                    best_score = score
                    best_node = cand_node
                new_frontier = list(((set(frontier) |
                    set(graph.neighbors(cand_node))) - visited) -
                    set([cand_node]))
                new_beams.append((
                    score, neigh + [cand_node],
                    new_frontier, visited | set([cand_node]), graph_idx))
        # keep only the n_beams best-scoring beams for this seed
        new_beams = list(sorted(new_beams, key=lambda x: x[0]))[:self.n_beams]
        # record the pattern grown by the single best beam
        for score, neigh, frontier, visited, graph_idx in new_beams[:1]:
            graph = self.dataset[graph_idx]
            # add to record
            neigh_g = graph.subgraph(neigh).copy()
            neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
            for v in neigh_g.nodes:
                neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
            self.cand_patterns[len(neigh_g)].append((score, neigh_g))
            if self.rank_method in ["counts", "hybrid"]:
                self.counts[len(neigh_g)][utils.wl_hash(neigh_g,
                    node_anchored=self.node_anchored)].append(neigh_g)
            if self.analyze and len(neigh) >= 3:
                emb = self.model.emb_model(utils.batch_nx_graphs(
                    [neigh_g], anchors=[neigh[0]] if self.node_anchored
                    else None)).squeeze(0)
                analyze_embs_cur.append(emb.detach().cpu().numpy())
        if len(new_beams) > 0:
            new_beam_sets.append(new_beams)
    self.beam_sets = new_beam_sets
    self.analyze_embs.append(analyze_embs_cur)
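# `utils.wl_hash` above buckets candidate patterns by (approximate)
# isomorphism class. A minimal sketch of such a helper using networkx's
# built-in Weisfeiler-Lehman hash rather than the repo's own implementation;
# it assumes every node carries the "anchor" attribute set above, so that
# anchored graphs that differ only in anchor placement hash differently.
import networkx as nx

def wl_hash_sketch(g, node_anchored=False):
    return nx.weisfeiler_lehman_graph_hash(
        g, node_attr="anchor" if node_anchored else None)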
def step(self):
    # sample graphs proportionally to their size
    ps = np.array([len(g) for g in self.dataset], dtype=float)
    ps /= np.sum(ps)
    graph_dist = stats.rv_discrete(values=(np.arange(len(self.dataset)), ps))

    print("Size", self.max_size)
    print(len(self.visited_seed_nodes), "distinct seeds")
    for simulation_n in tqdm(range(self.n_trials //
        (self.max_pattern_size+1-self.min_pattern_size))):
        # pick seed node: UCT over previously visited seeds
        best_graph_idx, best_start_node, best_score = None, None, -float("inf")
        for cand_graph_idx, cand_start_node in self.visited_seed_nodes:
            state = cand_graph_idx, cand_start_node
            my_visit_counts = sum(self.visit_counts[state].values())
            q_score = (sum(self.cum_action_values[state].values()) /
                (my_visit_counts or 1))
            uct_score = self.c_uct * np.sqrt(np.log(simulation_n or 1) /
                (my_visit_counts or 1))
            node_score = q_score + uct_score
            if node_score > best_score:
                best_score = node_score
                best_graph_idx = cand_graph_idx
                best_start_node = cand_start_node
        # if existing seed beats choosing a new seed
        if best_score >= self.c_uct * np.sqrt(np.log(simulation_n or 1)):
            graph_idx, start_node = best_graph_idx, best_start_node
            assert best_start_node in self.dataset[graph_idx].nodes
            graph = self.dataset[graph_idx]
        else:
            found = False
            while not found:
                graph_idx = np.arange(len(self.dataset))[graph_dist.rvs()]
                graph = self.dataset[graph_idx]
                start_node = random.choice(list(graph.nodes))
                # don't pick isolated nodes or small islands
                if self.has_min_reachable_nodes(graph, start_node,
                    self.min_pattern_size):
                    found = True
            self.visited_seed_nodes.add((graph_idx, start_node))
        neigh = [start_node]
        frontier = list(set(graph.neighbors(start_node)) - set(neigh))
        visited = set([start_node])
        neigh_g = nx.Graph()
        neigh_g.add_node(start_node, anchor=1)
        cur_state = graph_idx, start_node
        state_list = [cur_state]
        while frontier and len(neigh) < self.max_size:
            cand_neighs, anchors = [], []
            for cand_node in frontier:
                cand_neigh = graph.subgraph(neigh + [cand_node])
                cand_neighs.append(cand_neigh)
                if self.node_anchored:
                    anchors.append(neigh[0])
            cand_embs = self.model.emb_model(utils.batch_nx_graphs(
                cand_neighs, anchors=anchors if self.node_anchored else None))
            best_v_score, best_node_score, best_node = 0, -float("inf"), None
            for cand_node, cand_emb in zip(frontier, cand_embs):
                score, n_embs = 0, 0
                for emb_batch in self.embs:
                    score += torch.sum(self.model.predict((
                        emb_batch.to(utils.get_device()), cand_emb))).item()
                    n_embs += len(emb_batch)
                v_score = -np.log(score/n_embs + 1) + 1
                # get wl hash of next state
                neigh_g = graph.subgraph(neigh + [cand_node]).copy()
                neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
                for v in neigh_g.nodes:
                    neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
                next_state = utils.wl_hash(neigh_g,
                    node_anchored=self.node_anchored)
                # compute node score
                parent_visit_counts = sum(self.visit_counts[cur_state].values())
                my_visit_counts = sum(self.visit_counts[next_state].values())
                q_score = (sum(self.cum_action_values[next_state].values()) /
                    (my_visit_counts or 1))
                uct_score = self.c_uct * np.sqrt(
                    np.log(parent_visit_counts or 1) / (my_visit_counts or 1))
                node_score = q_score + uct_score
                if node_score > best_node_score:
                    best_node_score = node_score
                    best_v_score = v_score
                    best_node = cand_node
            frontier = list(((set(frontier) | set(graph.neighbors(best_node)))
                - visited) - set([best_node]))
            visited.add(best_node)
            neigh.append(best_node)
            # update visit counts, wl cache
            neigh_g = graph.subgraph(neigh).copy()
            neigh_g.remove_edges_from(nx.selfloop_edges(neigh_g))
            for v in neigh_g.nodes:
                neigh_g.nodes[v]["anchor"] = 1 if v == neigh[0] else 0
            prev_state = cur_state
            cur_state = utils.wl_hash(neigh_g,
                node_anchored=self.node_anchored)
            state_list.append(cur_state)
            self.wl_hash_to_graphs[cur_state].append(neigh_g)
        # backprop value
        for i in range(0, len(state_list) - 1):
            self.cum_action_values[state_list[i]][
                state_list[i+1]] += best_v_score
            self.visit_counts[state_list[i]][state_list[i+1]] += 1
    self.max_size += 1
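# Both selection steps above use UCT (upper confidence bound for trees):
# score = Q(s, a) + c * sqrt(log(N_parent) / N(s, a)). A self-contained
# restatement for reference (not repo code); the `or 1` guards above play the
# role of the max(n, 1) terms here.
import math

def uct_score(total_value, visit_count, parent_visit_count, c_uct):
    q = total_value / max(visit_count, 1)            # exploitation
    u = c_uct * math.sqrt(math.log(max(parent_visit_count, 1))
                          / max(visit_count, 1))     # exploration bonus
    return q + u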
def gen_batch(self, batch_target, batch_neg_target, batch_neg_query, train):
    def sample_subgraph(graph, offset=0, use_precomp_sizes=False,
        filter_negs=False, supersample_small_graphs=False, neg_target=None,
        hard_neg_idxs=None):
        if neg_target is not None:
            graph_idx = graph.G.graph["idx"]
        use_hard_neg = (hard_neg_idxs is not None and
            graph.G.graph["idx"] in hard_neg_idxs)
        done = False
        n_tries = 0
        while not done:
            if use_precomp_sizes:
                size = graph.G.graph["subgraph_size"]
            else:
                if train and supersample_small_graphs:
                    sizes = np.arange(self.min_size + offset,
                        len(graph.G) + offset)
                    ps = (sizes - self.min_size + 2) ** (-1.1)
                    ps /= ps.sum()
                    size = stats.rv_discrete(values=(sizes, ps)).rvs()
                else:
                    d = 1 if train else 0
                    size = random.randint(self.min_size + offset - d,
                        len(graph.G) - 1 + offset)
            # grow a random connected subgraph of the chosen size
            start_node = random.choice(list(graph.G.nodes))
            neigh = [start_node]
            frontier = list(set(graph.G.neighbors(start_node)) - set(neigh))
            visited = set([start_node])
            while len(neigh) < size:
                new_node = random.choice(list(frontier))
                assert new_node not in neigh
                neigh.append(new_node)
                visited.add(new_node)
                frontier += list(graph.G.neighbors(new_node))
                frontier = [x for x in frontier if x not in visited]
            if self.node_anchored:
                anchor = neigh[0]
                for v in graph.G.nodes:
                    graph.G.nodes[v]["node_feature"] = (torch.ones(1)
                        if anchor == v else torch.zeros(1))
            neigh = graph.G.subgraph(neigh)
            if use_hard_neg and train:
                neigh = neigh.copy()
                if random.random() < 1.0 or not self.node_anchored:
                    # add edges
                    non_edges = list(nx.non_edges(neigh))
                    if len(non_edges) > 0:
                        for u, v in random.sample(non_edges,
                            random.randint(1, min(len(non_edges), 5))):
                            neigh.add_edge(u, v)
                else:
                    # perturb anchor
                    anchor = random.choice(list(neigh.nodes))
                    for v in neigh.nodes:
                        neigh.nodes[v]["node_feature"] = (torch.ones(1)
                            if anchor == v else torch.zeros(1))

            if (filter_negs and train and len(neigh) <= 6 and
                neg_target is not None):
                # reject the negative if it actually occurs in the target
                matcher = nx.algorithms.isomorphism.GraphMatcher(
                    neg_target[graph_idx], neigh)
                if not matcher.subgraph_is_isomorphic():
                    done = True
            else:
                done = True
        return graph, DSGraph(neigh)

    augmenter = feature_preprocess.FeatureAugment()

    pos_target = batch_target
    pos_target, pos_query = pos_target.apply_transform_multi(sample_subgraph)
    neg_target = batch_neg_target
    # TODO: use hard negs
    hard_neg_idxs = set(random.sample(range(len(neg_target.G)),
        int(len(neg_target.G) * 1 / 2)))
    batch_neg_query = Batch.from_data_list(
        GraphDataset.list_to_graphs([self.generator.generate(size=len(g))
            if i not in hard_neg_idxs else g
            for i, g in enumerate(neg_target.G)]))
    for i, g in enumerate(batch_neg_query.G):
        g.graph["idx"] = i
    _, neg_query = batch_neg_query.apply_transform_multi(sample_subgraph,
        hard_neg_idxs=hard_neg_idxs)
    if self.node_anchored:
        def add_anchor(g, anchors=None):
            if anchors is not None:
                anchor = anchors[g.G.graph["idx"]]
            else:
                anchor = random.choice(list(g.G.nodes))
            for v in g.G.nodes:
                if "node_feature" not in g.G.nodes[v]:
                    g.G.nodes[v]["node_feature"] = (torch.ones(1)
                        if anchor == v else torch.zeros(1))
            return g
        neg_target = neg_target.apply_transform(add_anchor)
    pos_target = augmenter.augment(pos_target).to(utils.get_device())
    pos_query = augmenter.augment(pos_query).to(utils.get_device())
    neg_target = augmenter.augment(neg_target).to(utils.get_device())
    neg_query = augmenter.augment(neg_query).to(utils.get_device())
    return pos_target, pos_query, neg_target, neg_query
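# The inner loop of sample_subgraph grows a random connected subgraph by
# repeatedly pulling a node from the frontier. A self-contained restatement on
# a plain networkx graph (illustrative only; names are not from the repo). The
# `and frontier` guard stops early if the component is smaller than `size`.
import random
import networkx as nx

def sample_connected_subgraph(g, size):
    start = random.choice(list(g.nodes))
    neigh, visited = [start], {start}
    frontier = list(set(g.neighbors(start)) - visited)
    while len(neigh) < size and frontier:
        new_node = random.choice(frontier)
        neigh.append(new_node)
        visited.add(new_node)
        frontier = [x for x in set(frontier) | set(g.neighbors(new_node))
                    if x not in visited]
    return g.subgraph(neigh)

print(sample_connected_subgraph(nx.erdos_renyi_graph(50, 0.1), 5).nodes)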
def validation(args, model, test_pts, logger, batch_n, epoch, verbose=False):
    # test on new motifs
    model.eval()
    all_raw_preds, all_preds, all_labels = [], [], []
    for pos_a, pos_b, neg_a, neg_b in test_pts:
        if pos_a:
            pos_a = pos_a.to(utils.get_device())
            pos_b = pos_b.to(utils.get_device())
        neg_a = neg_a.to(utils.get_device())
        neg_b = neg_b.to(utils.get_device())
        labels = torch.tensor([1]*(pos_a.num_graphs if pos_a else 0) +
            [0]*neg_a.num_graphs).to(utils.get_device())
        with torch.no_grad():
            emb_neg_a, emb_neg_b = (model.emb_model(neg_a),
                model.emb_model(neg_b))
            if pos_a:
                emb_pos_a, emb_pos_b = (model.emb_model(pos_a),
                    model.emb_model(pos_b))
                emb_as = torch.cat((emb_pos_a, emb_neg_a), dim=0)
                emb_bs = torch.cat((emb_pos_b, emb_neg_b), dim=0)
            else:
                emb_as, emb_bs = emb_neg_a, emb_neg_b
            pred = model(emb_as, emb_bs)
            raw_pred = model.predict(pred)
            if USE_ORCA_FEATS:
                import orca
                import matplotlib.pyplot as plt
                def make_feats(g):
                    counts5 = np.array(orca.orbit_counts("node", 5, g))
                    for v, n in zip(counts5, g.nodes):
                        if g.nodes[n]["node_feature"][0] > 0:
                            anchor_v = v
                            break
                    v5 = np.sum(counts5, axis=0)
                    return v5, anchor_v
                for i, (ga, gb) in enumerate(zip(neg_a.G, neg_b.G)):
                    (va, na), (vb, nb) = make_feats(ga), make_feats(gb)
                    if (va < vb).any() or (na < nb).any():
                        raw_pred[pos_a.num_graphs + i] = MAX_MARGIN_SCORE
            if args.method_type == "order":
                pred = model.clf_model(raw_pred.unsqueeze(1)).argmax(dim=-1)
                raw_pred *= -1
            elif args.method_type == "ensemble":
                pred = torch.stack([m.clf_model(
                    raw_pred.unsqueeze(1)).argmax(dim=-1)
                    for m in model.models])
                for i in range(pred.shape[1]):
                    print(pred[:,i])
                pred = torch.min(pred, dim=0)[0]
                raw_pred *= -1
            elif args.method_type == "mlp":
                raw_pred = raw_pred[:,1]
                pred = pred.argmax(dim=-1)
        all_raw_preds.append(raw_pred)
        all_preds.append(pred)
        all_labels.append(labels)
    pred = torch.cat(all_preds, dim=-1)
    labels = torch.cat(all_labels, dim=-1)
    raw_pred = torch.cat(all_raw_preds, dim=-1)
    acc = torch.mean((pred == labels).type(torch.float))
    prec = (torch.sum(pred * labels).item() / torch.sum(pred).item()
        if torch.sum(pred) > 0 else float("NaN"))
    recall = (torch.sum(pred * labels).item() / torch.sum(labels).item()
        if torch.sum(labels) > 0 else float("NaN"))
    labels = labels.detach().cpu().numpy()
    raw_pred = raw_pred.detach().cpu().numpy()
    pred = pred.detach().cpu().numpy()
    auroc = roc_auc_score(labels, raw_pred)
    avg_prec = average_precision_score(labels, raw_pred)
    tn, fp, fn, tp = confusion_matrix(labels, pred).ravel()
    if verbose:
        import matplotlib.pyplot as plt
        precs, recalls, threshs = precision_recall_curve(labels, raw_pred)
        plt.plot(recalls, precs)
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.savefig("plots/precision-recall-curve.png")
        print("Saved PR curve plot in plots/precision-recall-curve.png")

    print("\n{}".format(str(datetime.now())))
    print("Validation. Epoch {}. Acc: {:.4f}. "
        "P: {:.4f}. R: {:.4f}. AUROC: {:.4f}. AP: {:.4f}.\n "
        "TN: {}. FP: {}. FN: {}. TP: {}".format(epoch, acc, prec, recall,
            auroc, avg_prec, tn, fp, fn, tp))

    if not args.test:
        logger.add_scalar("Accuracy/test", acc, batch_n)
        logger.add_scalar("Precision/test", prec, batch_n)
        logger.add_scalar("Recall/test", recall, batch_n)
        logger.add_scalar("AUROC/test", auroc, batch_n)
        logger.add_scalar("AvgPrec/test", avg_prec, batch_n)
        logger.add_scalar("TP/test", tp, batch_n)
        logger.add_scalar("TN/test", tn, batch_n)
        logger.add_scalar("FP/test", fp, batch_n)
        logger.add_scalar("FN/test", fn, batch_n)
        print("Saving {}".format(args.model_path))
        torch.save(model.state_dict(), args.model_path)

    if verbose:
        conf_mat_examples = defaultdict(list)
        idx = 0
        for pos_a, pos_b, neg_a, neg_b in test_pts:
            if pos_a:
                pos_a = pos_a.to(utils.get_device())
                pos_b = pos_b.to(utils.get_device())
            neg_a = neg_a.to(utils.get_device())
            neg_b = neg_b.to(utils.get_device())
            for list_a, list_b in [(pos_a, pos_b), (neg_a, neg_b)]:
                if not list_a:
                    continue
                for a, b in zip(list_a.G, list_b.G):
                    correct = pred[idx] == labels[idx]
                    conf_mat_examples[correct, pred[idx]].append((a, b))
                    idx += 1
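# The metric block above mixes threshold-based counts with ranking metrics.
# A self-contained toy run of the same sklearn calls (illustrative only):
import numpy as np
from sklearn.metrics import (roc_auc_score, average_precision_score,
                             confusion_matrix)

labels = np.array([1, 1, 0, 0])
raw_pred = np.array([0.9, 0.4, 0.6, 0.1])   # scores; higher = more positive
pred = (raw_pred > 0.5).astype(int)

print(roc_auc_score(labels, raw_pred))       # 0.75
print(average_precision_score(labels, raw_pred))
tn, fp, fn, tp = confusion_matrix(labels, pred).ravel()
print(tn, fp, fn, tp)                        # 1 1 1 1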
import functools
import os
import shutil

import syft as sy
import torch
from torch.utils.tensorboard import SummaryWriter

from data.distributor import get_fl_graph
from data.loader import get_loader
from models.train import distributed_train, test
from models.utils import get_model
from viz.training_plots import training_plots
# Arguments, argparser, get_device, get_paths, init_logger are project
# helpers; their imports are elided in this excerpt.

print = functools.partial(print, flush=True)
torch.set_printoptions(linewidth=120)

# ------------------------------------------------------------------------------
# Setups
# ------------------------------------------------------------------------------
args = Arguments(argparser())
hook = sy.TorchHook(torch)
device = get_device(args)
paths = get_paths(args, distributed=True)
log_file, std_out = init_logger(paths.log_file, args.dry_run, args.load_model)

if os.path.exists(paths.tb_path):
    shutil.rmtree(paths.tb_path)
tb = SummaryWriter(paths.tb_path)

print('+' * 80)
print(paths.model_name)
print('+' * 80)
print(args.__dict__)
print('+' * 80)

# prepare graph and data
_, workers = get_fl_graph(hook, args.num_workers)
data_loaders = get_dataloader(image_resize=256, mean=MEAN, std=STD,
                              fast_train=FAST_TRAIN, batch_size=BATCH_SIZE,
                              img_dir=prop('images256'))

# %%
GAIN_TARGET_LABEL = 'pigment_network'
gain_target = 4  # data_loaders[PHASE_TRAIN].dataset.labels().index(GAIN_TARGET_LABEL)

# %%
device = get_device(cpu_force=is_local_env())
print('Got device', device)

# %%
model = multi_label_resnet50(num_labels=N_CLASSES, pretrained=True)
model = MLGradCamResnet(model=model, device=device, cam_category=gain_target,
                        target_layer='layer4.2')

optimizer = optim.Adam(model.model.parameters(), lr=LEARNING_RATE)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=0.005)

# %%
N_COLS = 3
from subgraph_matching.config import parse_encoder
# Now we load the model and a dataset to analyze embeddings on, here ENZYMES.
from subgraph_matching.train import make_data_source

parser = argparse.ArgumentParser()
utils.parse_optimizer(parser)
parse_encoder(parser)
args = parser.parse_args("")
args.model_path = os.path.join("..", args.model_path)
print("Using dataset {}".format(args.dataset))

model = models.OrderEmbedder(1, args.hidden_dim, args)
model.to(utils.get_device())
model.eval()
model.load_state_dict(torch.load(args.model_path,
    map_location=utils.get_device()))

train, test, task = data.load_dataset("wn18")

from collections import Counter

done = False
train_accs = []
while not done:
    data_source = make_data_source(args)
    loaders = data_source.gen_data_loaders(args.eval_interval *
        args.batch_size, args.batch_size,
        train=False)  # assumed; the excerpt is truncated at this call
def pattern_growth(dataset, task, args):
    # init model
    if args.method_type == "end2end":
        model = models.End2EndOrder(1, args.hidden_dim, args)
    elif args.method_type == "mlp":
        model = models.BaselineMLP(1, args.hidden_dim, args)
    else:
        model = models.OrderEmbedder(1, args.hidden_dim, args)
    model.to(utils.get_device())
    model.eval()
    model.load_state_dict(torch.load(args.model_path,
        map_location=utils.get_device()))

    if task == "graph-labeled":
        dataset, labels = dataset

    # load data
    neighs_pyg, neighs = [], []
    print(len(dataset), "graphs")
    print("search strategy:", args.search_strategy)
    if task == "graph-labeled":
        print("using label 0")
    graphs = []
    for i, graph in enumerate(dataset):
        if task == "graph-labeled" and labels[i] != 0:
            continue
        if task == "graph-truncate" and i >= 1000:
            break
        if not type(graph) == nx.Graph:
            graph = pyg_utils.to_networkx(graph).to_undirected()
        graphs.append(graph)

    start_time = time.time()
    if args.use_whole_graphs:
        neighs = graphs
    else:
        anchors = []
        if args.sample_method == "radial":
            for i, graph in enumerate(graphs):
                print(i)
                for j, node in enumerate(graph.nodes):
                    if len(dataset) <= 10 and j % 100 == 0:
                        print(i, j)
                    if args.use_whole_graphs:
                        neigh = graph.nodes
                    else:
                        neigh = list(nx.single_source_shortest_path_length(
                            graph, node, cutoff=args.radius).keys())
                        if args.subgraph_sample_size != 0:
                            neigh = random.sample(neigh, min(len(neigh),
                                args.subgraph_sample_size))
                    if len(neigh) > 1:
                        neigh = graph.subgraph(neigh)
                        if args.subgraph_sample_size != 0:
                            # keep the largest connected component
                            neigh = neigh.subgraph(max(
                                nx.connected_components(neigh), key=len))
                        neigh = nx.convert_node_labels_to_integers(neigh)
                        neigh.add_edge(0, 0)
                        neighs.append(neigh)
        elif args.sample_method == "tree":
            for j in tqdm(range(args.n_neighborhoods)):
                graph, neigh = utils.sample_neigh(graphs,
                    random.randint(args.min_neighborhood_size,
                        args.max_neighborhood_size))
                neigh = graph.subgraph(neigh)
                neigh = nx.convert_node_labels_to_integers(neigh)
                neigh.add_edge(0, 0)
                neighs.append(neigh)
                if args.node_anchored:
                    # after converting labels, 0 will be anchor
                    anchors.append(0)

    embs = []
    if len(neighs) % args.batch_size != 0:
        print("WARNING: number of graphs not multiple of batch size")
    for i in range(len(neighs) // args.batch_size):
        top = (i + 1) * args.batch_size
        with torch.no_grad():
            batch = utils.batch_nx_graphs(neighs[i*args.batch_size:top],
                anchors=anchors if args.node_anchored else None)
            emb = model.emb_model(batch)
            emb = emb.to(torch.device("cpu"))
        embs.append(emb)

    if args.analyze:
        # concatenate per-batch embeddings into an (n_graphs, hidden_dim) matrix
        embs_np = torch.cat(embs, dim=0).numpy()
        plt.scatter(embs_np[:, 0], embs_np[:, 1], label="node neighborhood")

    if args.search_strategy == "mcts":
        assert args.method_type == "order"
        agent = MCTSSearchAgent(args.min_pattern_size, args.max_pattern_size,
            model, graphs, embs, node_anchored=args.node_anchored,
            analyze=args.analyze, out_batch_size=args.out_batch_size)
    elif args.search_strategy == "greedy":
        agent = GreedySearchAgent(args.min_pattern_size, args.max_pattern_size,
            model, graphs, embs, node_anchored=args.node_anchored,
            analyze=args.analyze, model_type=args.method_type,
            out_batch_size=args.out_batch_size)
    out_graphs = agent.run_search(args.n_trials)
    print(time.time() - start_time, "TOTAL TIME")
    x = int(time.time() - start_time)
    print(x // 60, "mins", x % 60, "secs")

    # visualize out patterns
    count_by_size = defaultdict(int)
    for pattern in out_graphs:
        if args.node_anchored:
            colors = ["red"] + ["blue"]*(len(pattern) - 1)
            nx.draw(pattern, node_color=colors, with_labels=True)
        else:
            nx.draw(pattern)
        print("Saving plots/cluster/{}-{}.png".format(len(pattern),
            count_by_size[len(pattern)]))
        plt.savefig("plots/cluster/{}-{}.png".format(len(pattern),
            count_by_size[len(pattern)]))
        plt.savefig("plots/cluster/{}-{}.pdf".format(len(pattern),
            count_by_size[len(pattern)]))
        plt.close()
        count_by_size[len(pattern)] += 1

    if not os.path.exists("results"):
        os.makedirs("results")
    with open(args.out_path, "wb") as f:
        pickle.dump(out_graphs, f)
def train(args, model, logger, in_queue, out_queue):
    """Train the order embedding model.

    args: Commandline arguments
    logger: logger for logging progress
    in_queue: input queue to an intersection computation worker
    out_queue: output queue to an intersection computation worker
    """
    scheduler, opt = utils.build_optimizer(args, model.parameters())
    if args.method_type == "order":
        clf_opt = optim.Adam(model.clf_model.parameters(), lr=args.lr)

    done = False
    while not done:
        data_source = make_data_source(args)
        loaders = data_source.gen_data_loaders(args.eval_interval *
            args.batch_size, args.batch_size, train=True)
        c = -1
        for batch_target, batch_neg_target, batch_neg_query in zip(*loaders):
            # msg, _ = in_queue.get()
            # if msg == "done":
            #     done = True
            #     break

            # train
            c += 1
            if c == 100:
                done = True
                print("Saving")
                torch.save(model.state_dict(), args.model_path)
                break
            model.train()
            model.zero_grad()
            pos_a, pos_b, neg_a, neg_b, _ = data_source.gen_batch(
                batch_target, batch_neg_target, batch_neg_query, True)
            emb_pos_a, emb_pos_b = model.emb_model(pos_a), model.emb_model(
                pos_b)
            emb_neg_a, emb_neg_b = model.emb_model(neg_a), model.emb_model(
                neg_b)
            emb_as = torch.cat((emb_pos_a, emb_neg_a), dim=0)
            emb_bs = torch.cat((emb_pos_b, emb_neg_b), dim=0)
            labels = torch.tensor([1] * pos_a.num_graphs +
                [0] * neg_a.num_graphs).to(utils.get_device())
            intersect_embs = None
            pred = model(emb_as, emb_bs)
            loss = model.criterion(pred, intersect_embs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            if scheduler:
                scheduler.step()

            if args.method_type == "order":
                with torch.no_grad():
                    pred = model.predict(pred)
                model.clf_model.zero_grad()
                pred = model.clf_model(pred.unsqueeze(1))
                criterion = nn.NLLLoss()
                clf_loss = criterion(pred, labels)
                clf_loss.backward()
                clf_opt.step()

            pred = pred.argmax(dim=-1)
            acc = torch.mean(((1 - pred) == labels).type(torch.float))
            acc_ = torch.mean((pred == labels).type(torch.float))
            train_loss = loss.item()
            train_acc = acc.item()
            train_acc_ = acc_.item()
            print('Loss/ACC: ', c, train_loss, train_acc, train_acc_)
            with open('performance.txt', 'a') as f:
                f.write(' '.join(['Loss/ACC: ', str(c), str(train_loss),
                    str(train_acc), str(train_acc_)]))
                f.write('\n')