def test_batching_of_batches():
    """Check that a batch built from two batches holds all four graphs in order."""
    graph = Data(x=torch.randn(2, 16))
    inner = Batch.from_data_list([graph, graph])
    batch = Batch.from_data_list([inner, inner])

    assert len(batch) == 2

    # Every consecutive pair of rows must be a copy of the original features.
    expected_rows = graph.x.tolist()
    for start in range(0, 8, 2):
        assert batch.x[start:start + 2].tolist() == expected_rows

    assert batch.batch.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
def score(self):
    """
    Evaluate the model on every (testing graph, training graph) pair.

    For each testing graph the squared error, Spearman and Kendall ranking
    correlations and precision at 10 / 20 are computed against all training
    graphs. The averages are stored on ``self.rho``, ``self.tau``,
    ``self.prec_at_10``, ``self.prec_at_20`` and ``self.model_error`` before
    ``self.print_evaluation()`` is called.
    """
    print("\n\nModel evaluation.\n")
    self.model.eval()

    num_test = len(self.testing_graphs)
    num_train = len(self.training_graphs)
    matrix_shape = (num_test, num_train)

    scores = np.empty(matrix_shape)
    ground_truth = np.empty(matrix_shape)
    prediction_mat = np.empty(matrix_shape)

    rho_list, tau_list = [], []
    prec_at_10_list, prec_at_20_list = [], []

    t = tqdm(total=num_test * num_train)

    for row, test_graph in enumerate(self.testing_graphs):
        # Pair the current testing graph with every training graph.
        source_batch = Batch.from_data_list([test_graph] * num_train)
        target_batch = Batch.from_data_list(self.training_graphs)

        data = self.transform((source_batch, target_batch))
        target = data["target"]
        ground_truth[row] = target

        prediction = self.model(data)
        prediction_mat[row] = prediction.detach().numpy()

        squared_error = F.mse_loss(prediction, target, reduction="none")
        scores[row] = squared_error.detach().numpy()

        predicted_row = prediction_mat[row]
        truth_row = ground_truth[row]
        rho_list.append(
            calculate_ranking_correlation(spearmanr, predicted_row, truth_row))
        tau_list.append(
            calculate_ranking_correlation(kendalltau, predicted_row, truth_row))
        prec_at_10_list.append(calculate_prec_at_k(10, predicted_row, truth_row))
        prec_at_20_list.append(calculate_prec_at_k(20, predicted_row, truth_row))

        t.update(num_train)

    self.rho = np.mean(rho_list).item()
    self.tau = np.mean(tau_list).item()
    self.prec_at_10 = np.mean(prec_at_10_list).item()
    self.prec_at_20 = np.mean(prec_at_20_list).item()
    self.model_error = np.mean(scores).item()
    self.print_evaluation()
def from_data_list(data_list):
    r"""
    Build a PairBatch from a list of Pair objects.

    Every pair is split with ``to_data()``; the first components and the
    second components are batched separately and then recombined with
    ``PairBatch.make_pair``.

    Warning : follow_batch is not here yet...
    """
    assert isinstance(data_list[0], Pair)

    split_pairs = [pair.to_data() for pair in data_list]
    # Transpose [(s0, t0), (s1, t1), ...] into the two per-side lists.
    data_list_s, data_list_t = (list(side) for side in zip(*split_pairs))

    batch_s = Batch.from_data_list(data_list_s)
    batch_t = Batch.from_data_list(data_list_t)
    paired = PairBatch.make_pair(batch_s, batch_t)
    return paired.contiguous()
def sample(self, batch_size):
    """Draw ``batch_size`` distinct stored transitions uniformly at random.

    Indices are sampled without replacement from the filled part of the
    memory, i.e. ``min(self.mem_cntr, self.mem_size)`` slots.

    Returns:
        Tuple of (batched pre-graphs, batched later-graphs, actions tensor,
        rewards tensor), all taken at the same sampled indices.
    """
    filled = min(self.mem_cntr, self.mem_size)
    picked = np.random.choice(filled, batch_size, replace=False)

    def gather(storage):
        # Collect the sampled entries of one memory array, in sampled order.
        return [storage[idx] for idx in picked]

    graphs_pre_batch = Batch.from_data_list(gather(self.graphs_pre))
    graphs_later_batch = Batch.from_data_list(gather(self.graphs_later))
    actions_batch = T.tensor(gather(self.actions))
    rewards_batch = T.tensor(gather(self.rewards))

    return graphs_pre_batch, graphs_later_batch, actions_batch, rewards_batch
def sample(self, batch_size):
    """Draw ``batch_size`` distinct transitions with reward-proportional probability.

    The sampling weights are ``self.rewards`` normalised by their sum, and
    indices are drawn without replacement from ``range(self.mem_size)``.

    Returns:
        Tuple of (batched former-graphs, batched later-graphs, actions,
        rewards, done flags); the last three are ``torch.Tensor`` objects.
    """
    # NOTE(review): this value is computed but not used; sampling below runs
    # over the whole memory (self.mem_size) — confirm that is intended.
    max_mem = min(self.mem_cntr, self.mem_size)

    weights = np.array(self.rewards) / np.sum(self.rewards)
    picked = np.random.choice(
        self.mem_size, batch_size, replace=False, p=weights)

    def gather(storage):
        # Collect the sampled entries of one memory array, in sampled order.
        return [storage[idx] for idx in picked]

    graphs_former_batch = Batch.from_data_list(gather(self.graphs_former))
    graphs_later_batch = Batch.from_data_list(gather(self.graphs_later))
    actions_batch = torch.Tensor(gather(self.actions))
    rewards_batch = torch.Tensor(gather(self.rewards))
    done_batch = torch.Tensor(gather(self.done))

    return (graphs_former_batch, graphs_later_batch, actions_batch,
            rewards_batch, done_batch)
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Train ``model`` for one epoch on graph mini-batches.

    Each item yielded by ``train_loader`` is a sequence of graph objects.
    Their ``edge_attr`` is cleared, then three batches are built: the
    un-augmented graphs and two batches produced by two separate
    ``random_augmentation`` calls (passed to the model as ``im_q``/``im_k``).

    Args:
        train_loader: Iterable of graph sequences; must support ``len()``.
        model: Called as ``model(im_q=..., im_k=..., image=...)`` and expected
            to return ``(output, target, q_cls)``.
        criterion: Loss callable applied to ``(output, target)``.
        optimizer: Optimizer stepped once per mini-batch.
        epoch: Epoch number, used only in the progress prefix.
        args: Namespace providing ``gpu`` (device or ``None``) and
            ``print_freq``.

    Side effects:
        Appends every step's loss to ``loss_list``, which is not defined in
        this function and must exist in an enclosing scope.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    for i, images in enumerate(train_loader):
        # The timer starts after the loader has yielded, so ``data_time``
        # covers augmentation/batching only and ``batch_time`` one full step.
        end = time.time()

        for im in images:
            im.edge_attr = None
        images_cls = Batch.from_data_list(images)
        im_q = Batch.from_data_list(random_augmentation(images))
        im_k = Batch.from_data_list(random_augmentation(images))

        data_time.update(time.time() - end)

        if args.gpu is not None:
            im_q = im_q.to(args.gpu)
            im_k = im_k.to(args.gpu)
            images_cls = images_cls.to(args.gpu)

        output, target, q_cls = model(im_q=im_q, im_k=im_k, image=images_cls)
        if args.gpu is not None:
            target = target.to(args.gpu)

        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        losses.update(loss.item(), len(images))
        loss_list.append(loss.item())
        top1.update(acc1[0], len(images))
        top5.update(acc5[0], len(images))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)

        if i % args.print_freq == 0:
            progress.display(i)
def __init__(self, X: torch.Tensor, edge_index: torch.Tensor, num_hops: int,
             n_rollout: int = 10, min_atoms: int = 3, c_puct: float = 10.0,
             expand_atoms: int = 14, high2low: bool = False, node_idx: int = None,
             score_func: Callable = None):
    """Set up the search over a graph given by node features and edges.

    The graph is built from ``X`` and ``edge_index``. When ``node_idx`` is
    given, the search is restricted to the sub-graph returned by
    ``self.__subgraph__`` and its nodes are relabelled to ``0..n-1``.
    The root search node covers every node of the resulting graph.
    """
    self.X = X
    self.edge_index = edge_index
    self.num_hops = num_hops

    full_graph_data = Data(x=self.X, edge_index=self.edge_index)
    self.graph = to_networkx(full_graph_data, to_undirected=True)
    self.data = Batch.from_data_list([full_graph_data])
    self.num_nodes = self.graph.number_of_nodes()

    # search hyper-parameters
    self.score_func = score_func
    self.n_rollout = n_rollout
    self.min_atoms = min_atoms
    self.c_puct = c_puct
    self.expand_atoms = expand_atoms
    self.high2low = high2low

    # extract the sub-graph and change the node indices.
    if node_idx is not None:
        self.ori_node_idx = node_idx
        self.ori_graph = copy.copy(self.graph)

        sub_x, sub_edge_index, subset, _edge_mask, _extra = self.__subgraph__(
            node_idx, self.X, self.edge_index, self.num_hops)

        self.data = Batch.from_data_list(
            [Data(x=sub_x, edge_index=sub_edge_index)])

        # Relabel original node ids to their position inside ``subset``.
        induced = self.ori_graph.subgraph(subset.tolist())
        relabel = {int(old_id): new_id for new_id, old_id in enumerate(subset)}
        self.graph = nx.relabel_nodes(induced, relabel)

        self.node_idx = torch.where(subset == self.ori_node_idx)[0]
        self.num_nodes = self.graph.number_of_nodes()
        self.subset = subset

    self.root_coalition = sorted(range(self.num_nodes))
    self.MCTSNodeClass = partial(
        MCTSNode, data=self.data, ori_graph=self.graph, c_puct=self.c_puct)
    self.root = self.MCTSNodeClass(self.root_coalition)
    self.state_map = {str(self.root.coalition): self.root}
def collate_fn(samples):
    """Collate dataset samples into batches, dropping ``None`` entries.

    The type of the first remaining sample selects the layout:

    * ``list`` -- one entry per transform. The result is a list with one item
      per transform: ``None`` when every entry for it is ``None``, a tuple of
      batches when the entries are tuples, otherwise a single batch.
    * ``tuple`` -- the flattened items are split into even and odd positions
      and returned as two batches.
    * anything else -- a single batch of all samples.

    Returns ``None`` when no sample is left after filtering.
    """
    valid = [item for item in samples if item is not None]
    if not valid:
        return None

    head = valid[0]

    if isinstance(head, list):
        num_transforms = len(head)
        flat = [entry for item in valid for entry in item]  # bs * num_transforms

        collated = []
        for offset in range(num_transforms):
            # Every num_transforms-th entry belongs to the same transform.
            group = [entry for entry in flat[offset::num_transforms]
                     if entry is not None]
            if not group:
                # transformed data is all None and filtered out
                collated.append(None)
            elif isinstance(group[0], tuple):
                width = len(group[0])
                parts = [part for entry in group for part in entry]
                collated.append(tuple(
                    Batch.from_data_list(parts[position::width])
                    for position in range(width)))
            else:
                collated.append(Batch.from_data_list(group))
        return collated

    if isinstance(head, tuple):
        flat = [part for item in valid for part in item]
        left, right = flat[::2], flat[1::2]
        return Batch.from_data_list(left), Batch.from_data_list(right)

    return Batch.from_data_list(valid)
def test_pair_data_batching():
    """Check that per-key ``__inc__`` offsets are applied when batching pair data."""

    class PairData(Data):
        def __inc__(self, key, value, *args, **kwargs):
            # Each edge index is shifted by the node count of its own side.
            if key == 'edge_index_s':
                return self.x_s.size(0)
            if key == 'edge_index_t':
                return self.x_t.size(0)
            return super().__inc__(key, value, *args, **kwargs)

    x_s = torch.randn(5, 16)
    x_t = torch.randn(4, 16)
    edge_index_s = torch.tensor([
        [0, 0, 0, 0],
        [1, 2, 3, 4],
    ])
    edge_index_t = torch.tensor([
        [0, 0, 0],
        [1, 2, 3],
    ])

    pair = PairData(x_s=x_s, edge_index_s=edge_index_s,
                    x_t=x_t, edge_index_t=edge_index_t)
    batch = Batch.from_data_list([pair, pair])

    expected_x_s = torch.cat([x_s, x_s], dim=0)
    assert torch.allclose(batch.x_s, expected_x_s)
    assert batch.edge_index_s.tolist() == [[0, 0, 0, 0, 5, 5, 5, 5],
                                           [1, 2, 3, 4, 6, 7, 8, 9]]

    expected_x_t = torch.cat([x_t, x_t], dim=0)
    assert torch.allclose(batch.x_t, expected_x_t)
    assert batch.edge_index_t.tolist() == [[0, 0, 0, 4, 4, 4],
                                           [1, 2, 3, 5, 6, 7]]
def validate(model, validate_loader, batch_size):
    """Return the mean MSE loss of ``model`` over ordered validation mini-batches.

    ``validate_loader`` is indexed directly (``validate_loader[j]``) and split
    into ``len(validate_loader) // batch_size`` consecutive mini-batches; a
    trailing incomplete mini-batch is not evaluated.

    Args:
        model: Called with a batched graph; its output is compared with
            ``batch.y``.
        validate_loader: Indexable collection of graphs supporting ``len()``.
        batch_size: Number of graphs per mini-batch.

    Returns:
        The sum of per-batch losses divided by the number of mini-batches.

    Raises:
        ZeroDivisionError: If there are fewer than ``batch_size`` samples.
    """
    model.eval()
    loss_fn = nn.MSELoss()
    num_batches = len(validate_loader) // batch_size

    loss_all = 0
    for i in range(num_batches):
        # ordered mini-batch
        start = i * batch_size
        batch = [validate_loader[j] for j in range(start, start + batch_size)]
        batch = Batch.from_data_list(batch).to(device=CUDA_DEVICE)

        pred_ = model(batch)
        loss = loss_fn(pred_, batch.y)
        loss_all += loss.item()

        # conserve gpu memory: drop this step's tensors before the next
        # mini-batch is built
        del pred_, batch, loss

    return loss_all / num_batches
def sample_buffer(self, batch_size):
    """Return a batch of ``batch_size`` distinct graphs sampled from memory.

    Indices are drawn uniformly without replacement from the filled part of
    the memory, i.e. ``min(self.mem_cntr, self.mem_size)`` slots.
    """
    max_mem = min(self.mem_cntr, self.mem_size)
    batch = np.random.choice(max_mem, batch_size, replace=False)
    return Batch.from_data_list([self.graph_memory[b] for b in batch])
def explain_node(self, node_idx, x, edge_index, **kwargs):
    """Compute the edge mask for the graph given by ``x`` and ``edge_index``.

    The graph is wrapped in a single-graph batch on ``self.device``; node
    embeddings come from ``self.get_model_output`` and are passed to
    ``self.forward`` with ``training=False``. Runs without gradient tracking.
    ``node_idx`` and ``kwargs`` are accepted but not used here.
    """
    graph = Data(x=x, edge_index=edge_index)
    batched = Batch.from_data_list([graph]).to(self.device)

    with torch.no_grad():
        _, _prob, emb = self.get_model_output(batched.x, batched.edge_index)
        explainer_inputs = (batched.x, emb, batched.edge_index, 1.0)
        _, edge_mask = self.forward(explainer_inputs, training=False)

    return edge_mask
def test_parallel(model, loader, total, batch_size, loss_ftn_obj):
    """Evaluate ``model`` over ``loader`` and return the mean per-step loss.

    Each item from ``loader`` is a list of graphs passed to the model as-is.
    How the loss is obtained depends on ``loss_ftn_obj.name``:
    ``'vae_loss'`` also passes ``mu``/``log_var``; ``'emd_loss'`` and
    ``'chamfer_loss'`` receive the batch assignment vector and are averaged;
    ``'emd_loss_layer'`` takes the loss returned by the model itself; any
    other name calls ``loss_ftn`` on output and concatenated node features.
    """
    model.eval()
    running_total = 0.
    progress = tqdm.tqdm(enumerate(loader), total=total / batch_size)

    for step, graphs in progress:
        loss_name = loss_ftn_obj.name

        # forward and loss
        if loss_name == 'vae_loss':
            reconstruction, mu, log_var = model(graphs)
            target = torch.cat([graph.x for graph in graphs]).to(device)
            step_loss = loss_ftn_obj.loss_ftn(
                reconstruction, target, mu, log_var).item()
        elif loss_name in ('emd_loss', 'chamfer_loss'):
            reconstruction = model(graphs)
            merged = Batch.from_data_list(graphs).to(device)
            per_graph = loss_ftn_obj.loss_ftn(
                reconstruction, merged.x, merged.batch)
            step_loss = per_graph.mean().item()
        elif loss_name == 'emd_loss_layer':
            _, per_graph = model(graphs)
            step_loss = per_graph.mean().item()
        else:
            reconstruction = model(graphs)
            target = torch.cat([graph.x for graph in graphs]).to(device)
            step_loss = loss_ftn_obj.loss_ftn(reconstruction, target).item()

        running_total += step_loss
        progress.set_description('eval loss = %.5f' % (step_loss))
        progress.refresh()  # to show immediately the update

    return running_total / (step + 1)
def _one_test_case(layer_generator):
    """Run one randomised check of a layer built by ``layer_generator``.

    Verifies the output width, optionally the homogeneity of the layer under
    input scaling, and that 1000 Adam steps on ``sum(output ** 2)`` give a
    loss history that trends downwards whenever it is not flat.
    """
    # Four dims are drawn, but the output width is tied to x_dim.
    x_dim, input_x_dim, edge_attr_dim, _ = np.random.randint(
        1, 100, size=4).tolist()
    output_x_dim = x_dim

    graphs = [
        _generate_data(x_dim, input_x_dim, edge_attr_dim, output_x_dim)
        for _ in range(10)
    ]
    data = Batch.from_data_list(graphs)

    layer = layer_generator(x_dim=x_dim, input_x_dim=input_x_dim,
                            output_x_dim=output_x_dim,
                            edge_attr_dim=edge_attr_dim)
    layer = layer.to(data.x.device)
    output_x = layer(data)

    # Test output dimensionality
    assert output_x.size(1) == output_x_dim

    # Test homogeneous
    non_homogeneous = ['GATConv', 'GINConv', 'EpsGINConv', 'MPNNConv']
    if (layer.homogeneous_flag and
            layer.module.gnn_module.__class__.__name__ not in non_homogeneous):
        s = np.random.rand()*1000.
        data.x, data.input_x, data.edge_attr = (
            data.x*s, data.input_x*s, data.edge_attr*s)
        max_deviation = torch.max(torch.abs(output_x*s - layer(data)))
        assert max_deviation < 1e-3
        # Undo the scaling so the backward test sees the original magnitudes.
        data.x, data.input_x, data.edge_attr = (
            data.x/s, data.input_x/s, data.edge_attr/s)

    # Test backward
    loss_hist = []
    optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3, eps=1e-5)
    for _ in range(1000):
        optimizer.zero_grad()
        loss = torch.sum(layer(data)**2)
        loss_hist.append(loss.item())
        loss.backward()
        optimizer.step()
    print('**', loss_hist[::100])

    if np.std(loss_hist) > 1e-4:
        steps = np.arange(len(loss_hist))
        corr = np.corrcoef(steps, loss_hist)[0, 1]
        assert corr < -1e-3
def _process(self, data_list):
    """Merge ``data_list`` into one object without batch bookkeeping.

    An empty list yields an empty ``Data()``. Otherwise the list is batched
    and the ``batch`` and ``ptr`` attributes are removed from the result.
    """
    if len(data_list) > 0:
        merged = Batch.from_data_list(data_list)
        for bookkeeping_attr in ("batch", "ptr"):
            delattr(merged, bookkeeping_attr)
        return merged
    return Data()
def __init__(self, node_idx: int, X: torch.Tensor, edge_index: torch.Tensor,
             ori_graph: nx.Graph, n_rollout: int, min_atoms: int, c_puct: float,
             expand_atoms: int, score_func=None, num_hops: int = 3):
    """Set up the search on the sub-graph extracted around ``node_idx``.

    ``self.__subgraph__`` selects the sub-graph; its nodes are relabelled to
    ``0..n-1`` and the root search node covers all of them.
    """
    self.X = X
    self.edge_index = edge_index
    self.num_hops = num_hops

    self.ori_graph = ori_graph
    self.ori_node_idx = node_idx
    self.ori_num_nodes = self.ori_graph.number_of_nodes()

    # search hyper-parameters
    self.n_rollout = n_rollout
    self.min_atoms = min_atoms
    self.c_puct = c_puct
    self.expand_atoms = expand_atoms
    self.score_func = score_func

    # extract the sub-graph and change the node indices.
    sub_x, sub_edge_index, subset, _edge_mask, _extra = self.__subgraph__(
        node_idx, self.X, self.edge_index)
    self.data = Batch.from_data_list([Data(x=sub_x, edge_index=sub_edge_index)])

    # Relabel original node ids to their position inside ``subset``.
    induced = self.ori_graph.subgraph(subset.tolist())
    relabel = {int(old_id): new_id for new_id, old_id in enumerate(subset)}
    self.graph = nx.relabel_nodes(induced, relabel)

    self.node_idx = torch.where(subset == self.ori_node_idx)[0]
    self.num_nodes = self.graph.number_of_nodes()

    self.root_coalition = list(range(self.num_nodes))
    self.MCTSNodeClass = partial(
        MCTSNode, data=self.data, ori_graph=self.graph, c_puct=self.c_puct)
    self.root = self.MCTSNodeClass(self.root_coalition)
    self.state_map = {str(sorted(self.root.coalition)): self.root}
def load_pyg_batch_from_network_list(network_list):
    """Convert every network with ``load_pyg_data_from_network`` and batch the results."""
    converted = [load_pyg_data_from_network(network) for network in network_list]
    return Batch.from_data_list(converted)
def plot_reconstructions(self, index=0, path=None, name=None):
    """Plot reconstruction bar chart with validation data.

    The sample is encoded and decoded without gradient tracking; the decoder
    output is passed through a sigmoid and rounded. The label and the
    reconstruction are drawn as two stacked bar charts and saved to disk.

    Args:
        index (int): optional. The index of the validation data to use.
            Default is 0.
        path (str): optional. Path to save the plottings. Default is the
            current working directory.
        name (str): optional. Name of the saved plotting. Default is
            "reconstructions.png".
    """
    root = self._rooting(path)
    self._setup_models("eval")
    filep = os.path.join(root, "reconstructions.png" if name is None else name)

    data = self.dataloader.val_loader.dataset[index]
    batch = Batch.from_data_list([data]).to(self.device)
    label = batch.y[0].to("cpu").detach()

    with torch.no_grad():
        encoder_out = self.encoder(batch)
        # Some encoders return extra values; only the first one is decoded.
        if isinstance(encoder_out, tuple):
            encoder_out, *_ = encoder_out
        out = self.decoder(encoder_out)
    out = torch.round(torch.sigmoid(out))[0].to("cpu").detach().numpy()

    fig, axes = plt.subplots(2, 1, figsize=(8.0, 12.0))
    ax1, ax2 = axes.flatten()
    positions = list(range(out.shape[0]))
    ax1.bar(positions, label)
    ax1.set_xlabel("PubChem Fingerprint")
    ax2.bar(positions, out)
    ax2.set_xlabel("Reconstructed Fingerprint")
    fig.savefig(filep, dpi=300, bbox_inches="tight")
    # Close exactly the figure created here.
    plt.close(fig)
def get_pos_neg_pairs(data_list):
    """Pick one same-label and one different-label partner for every sample.

    For each element of ``data_list`` a random element with an equal ``y``
    (possibly the element itself) and a random element with a different ``y``
    are chosen.

    Returns:
        Tuple ``(positives_batch, negatives_batch)`` aligned with ``data_list``.
    """
    labels = torch.tensor([item.y for item in data_list], dtype=torch.long)

    def pick_random(mask):
        # Uniformly choose one position where ``mask`` is set.
        candidates = mask.nonzero()
        chosen = candidates[torch.randint(0, candidates.shape[0], (1,)), 0]
        return data_list[chosen]

    pos_list, neg_list = [], []
    for anchor_label in labels:
        pos_list.append(pick_random(labels == anchor_label))
        neg_list.append(pick_random(labels != anchor_label))

    return Batch.from_data_list(pos_list), Batch.from_data_list(neg_list)
def test_batch():
    """Check batching of two small graphs and the round trip back to a list.

    Debug mode is enabled for the duration of the test and switched off again
    afterwards, even if an assertion fails, so it does not leak into other
    tests.
    """
    torch_geometric.set_debug(True)
    try:
        x1 = torch.tensor([1, 2, 3], dtype=torch.float)
        e1 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
        s1 = '1'
        x2 = torch.tensor([1, 2], dtype=torch.float)
        e2 = torch.tensor([[0, 1], [1, 0]])
        s2 = '2'

        data = Batch.from_data_list([Data(x1, e1, s=s1), Data(x2, e2, s=s2)])

        assert data.__repr__() == (
            'Batch(batch=[5], edge_index=[2, 6], ptr=[3], s=[2], x=[5])')
        assert len(data) == 5
        assert data.x.tolist() == [1, 2, 3, 1, 2]
        # Edges of the second graph are shifted by the 3 nodes of the first.
        assert data.edge_index.tolist() == [[0, 1, 1, 2, 3, 4],
                                            [1, 0, 2, 1, 4, 3]]
        assert data.s == ['1', '2']
        assert data.batch.tolist() == [0, 0, 0, 1, 1]
        assert data.ptr.tolist() == [0, 3, 5]
        assert data.num_graphs == 2

        data_list = data.to_data_list()
        assert len(data_list) == 2
        assert len(data_list[0]) == 3
        assert data_list[0].x.tolist() == [1, 2, 3]
        assert data_list[0].edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
        assert data_list[0].s == '1'
        assert len(data_list[1]) == 3
        assert data_list[1].x.tolist() == [1, 2]
        assert data_list[1].edge_index.tolist() == [[0, 1], [1, 0]]
        assert data_list[1].s == '2'
    finally:
        torch_geometric.set_debug(False)
def _save_graphs(sharded, shard_num, out_dir):
    """Build graph pairs for one shard and save each pair to ``out_dir``.

    For every ensemble in the shard, positive neighbours are read from the
    shard's ``'neighbors'`` table, negatives are generated, and labels are
    drawn with ``create_labels``. Each labelled row yields one graph per
    bound subunit; the two graphs are batched together and written to
    ``data_{shard_num}_{idx}.pt``. Rows for which either graph is ``None``
    are skipped and do not consume a file index.

    Args:
        sharded: Object providing ``read_shard(shard_num[, key])``.
        shard_num: Index of the shard to process.
        out_dir: Directory receiving the ``.pt`` files.
    """
    print(f'Processing shard {shard_num:}')
    shard = sharded.read_shard(shard_num)
    neighbors = sharded.read_shard(shard_num, 'neighbors')

    curr_idx = 0
    # Group by the column name, not a one-element list: with a list key,
    # recent pandas versions yield 1-tuples, which would never match the
    # scalar values compared against below.
    for ensemble_name, target_df in shard.groupby('ensemble'):
        _, (bound1, bound2, _, _) = nb.get_subunits(target_df)
        positives = neighbors[neighbors.ensemble0 == ensemble_name]
        negatives = nb.get_negatives(positives, bound1, bound2)
        negatives['label'] = 0
        labels = create_labels(positives, negatives, num_pos=10,
                               neg_pos_ratio=1)

        for _, row in labels.iterrows():
            label = float(row['label'])
            chain_res1 = row[['chain0', 'residue0']].values
            chain_res2 = row[['chain1', 'residue1']].values
            graph1 = df_to_graph(bound1, chain_res1, label)
            graph2 = df_to_graph(bound2, chain_res2, label)
            if (graph1 is None) or (graph2 is None):
                continue

            pair = Batch.from_data_list([graph1, graph2])
            torch.save(pair, os.path.join(
                out_dir, f'data_{shard_num}_{curr_idx}.pt'))
            curr_idx += 1
def _obs(self) -> Tuple[Batch, List[List[int]]]:
    """
    Build the observation for the current molecule.

    Hydrogens are removed first. Node features are the atom-type features
    followed by the atom coordinates; after centering and rotation
    normalisation the last three feature columns are overwritten with the
    transformed positions.

    returns
    -------
    Tuple[Batch, List[List[int]]
        The Batch object contains the Pytorch Geometric graph representing the molecule.
        The list of lists of integers is a list of all the torsions of the molecule, where
        each torsion is represented by a list of four integers, where the integers are the
        indices of the four atoms making up the torsion.
    """
    heavy_mol = Chem.rdmolops.RemoveHs(self.mol)
    conformer = heavy_mol.GetConformer()

    node_features = [
        molecule_features.atom_type_CO(atom)
        + molecule_features.atom_coords(atom, conformer)
        for atom in heavy_mol.GetAtoms()
    ]
    edge_indices = molecule_features.get_bond_pairs(heavy_mol)
    # The per-bond attribute list is repeated twice to line up with edge_indices.
    bond_types = [molecule_features.bond_type(bond)
                  for bond in heavy_mol.GetBonds()]
    edge_attributes = bond_types * 2

    graph = Data(
        x=torch.tensor(node_features, dtype=torch.float),
        edge_index=torch.tensor(edge_indices, dtype=torch.long),
        edge_attr=torch.tensor(edge_attributes, dtype=torch.float),
        pos=torch.Tensor(conformer.GetPositions()),
    )

    for transform in (Center(), NormalizeRotation()):
        graph = transform(graph)
    graph.x[:, -3:] = graph.pos

    return Batch.from_data_list([graph]), self.nonring
def construct_hidden_graph(self, bsize: int, num_agent: int, hidden_size: int):
    """Create a batch of ``bsize`` fully connected, zero-initialised graphs.

    Every graph has ``num_agent`` nodes, one directed edge per ordered pair
    of distinct nodes, and all-zero node features, edge attributes and
    global vector ``u`` of width ``hidden_size``. The batch is returned on
    ``self.device``.
    """
    # Edge connections: every ordered pair of distinct agents.
    ordered_pairs = list(permutations(range(num_agent), 2))
    edge_index = torch.tensor(ordered_pairs, dtype=torch.long)
    edge_index = edge_index.t().contiguous()  # Shape: [2 x E], E = n * (n - 1)
    num_edges = edge_index.shape[1]

    zeros_kwargs = dict(dtype=torch.float32, device=self.device)
    x = torch.zeros((num_agent, hidden_size), **zeros_kwargs)
    # U vector. |U|-dimensional 0-vector
    u = torch.zeros((1, hidden_size), **zeros_kwargs)
    edge_attr = torch.zeros((num_edges, hidden_size), **zeros_kwargs)

    # One Data object per batch element; the feature tensors are cloned so
    # the graphs do not share storage, while edge_index is shared.
    data_objs = []
    for _ in range(bsize):
        data_objs.append(Data(x=x.clone(), edge_index=edge_index,
                              edge_attr=edge_attr.clone(), u=u.clone()))

    return Batch.from_data_list(data_objs).to(x.device)
def _prepare_batch(self, batch):
    """Create batch data for GAT.

    Parameters
    ----------
    batch: Tuple
        The tuple are `(inputs, labels, weights)`.

    Returns
    -------
    inputs: torch_geometric.data.Batch
        A mini-batch graph data for PyTorch Geometric models.
    labels: List[torch.Tensor] or None
        The labels converted to torch.Tensor.
    weights: List[torch.Tensor] or None
        The weights for each sample or sample/task pair converted to torch.Tensor.

    Raises
    ------
    ValueError
        If PyTorch Geometric cannot be imported.
    """
    try:
        from torch_geometric.data import Batch
    except ImportError as err:
        # Only a failed import is translated; anything else propagates as-is.
        raise ValueError(
            "This class requires PyTorch Geometric to be installed.") from err

    inputs, labels, weights = batch
    pyg_graphs = [graph.to_pyg_graph() for graph in inputs[0]]
    inputs = Batch.from_data_list(pyg_graphs)
    inputs = inputs.to(self.device)
    # The parent class only converts labels and weights; the inputs were
    # already handled above, so an empty list is passed in their place.
    _, labels, weights = super(GATModel, self)._prepare_batch(
        ([], labels, weights))
    return inputs, labels, weights
def _get_schema_graph_encoding(
        self, worlds: List[SpiderWorld], initial_graph_embeddings: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the GNN over the schema graph of every world in the batch.

    One graph is built per world from its row of ``initial_graph_embeddings``
    and one ``edge_index_{i}`` attribute per adjacency list returned by
    ``self._get_graph_adj_lists``. The graphs are batched, encoded with
    ``self._gnn`` and reshaped to ``(batch_size, max_num_entities, -1)``.

    NOTE(review): the annotation declares a tuple, but a single tensor is
    returned -- confirm which one callers rely on.
    """
    entity_counts = [
        len(world.db_context.knowledge_graph.entities) for world in worlds
    ]
    max_num_entities = max(entity_counts)
    batch_size = initial_graph_embeddings.size(0)
    device = initial_graph_embeddings.device

    graph_data_list = []
    for batch_index, world in enumerate(worlds):
        node_features = initial_graph_embeddings[batch_index]
        adjacency_lists = self._get_graph_adj_lists(
            device, world, initial_graph_embeddings.size(1) - 1)

        graph_data = Data(node_features)
        for edge_type, edges in enumerate(adjacency_lists):
            graph_data[f'edge_index_{edge_type}'] = edges
        graph_data_list.append(graph_data)

    batch = Batch.from_data_list(graph_data_list)
    edge_indices = [
        batch[f'edge_index_{edge_type}']
        for edge_type in range(self._gnn.num_edge_types)
    ]
    gnn_output = self._gnn(batch.x, edge_indices)

    gnn_output = gnn_output.view(batch_size, max_num_entities, -1)
    entities_encodings = gnn_output[:, :max_num_entities]
    return entities_encodings
def forward(self, inputs, training=None):
    """Predict an edge mask from node embeddings and re-run the model with it.

    Args:
        inputs: Tuple ``(x, embed, edge_index, tmp)``: node features, node
            embeddings, edge index, and the value passed to
            ``self.concrete_sample`` as ``beta``.
        training: Forwarded to ``self.concrete_sample``.

    Returns:
        Tuple ``(outputs[1].squeeze(), edge_mask)`` where ``outputs`` is the
        result of ``self.model`` on the masked graph and ``edge_mask`` holds
        the inverse-sigmoid mask values, one per edge.
    """
    x, embed, edge_index, tmp = inputs
    nodesize = embed.shape[0]
    feature_dim = embed.shape[1]

    # Row i * nodesize + j of the concatenation is [embed[i], embed[j]], so
    # every ordered node pair gets one input row.
    f1 = embed.unsqueeze(1).repeat(1, nodesize, 1).reshape(-1, feature_dim)
    f2 = embed.unsqueeze(0).repeat(nodesize, 1, 1).reshape(-1, feature_dim)

    # using the node embedding to calculate the edge weight
    f12self = torch.cat([f1, f2], dim=-1)
    h = f12self.to(self.device)
    for elayer in self.elayers:
        h = elayer(h)
    values = h.reshape(-1)
    values = self.concrete_sample(values, beta=tmp, training=training)
    self.mask_sigmoid = values.reshape(nodesize, nodesize)

    # set the symmetric edge weights
    sym_mask = (self.mask_sigmoid + self.mask_sigmoid.transpose(0, 1)) / 2
    edge_mask = sym_mask[edge_index[0], edge_index[1]]

    # inverse the weights before sigmoid in MessagePassing Module
    edge_mask = inv_sigmoid(edge_mask)
    self.__clear_masks__()
    self.__set_masks__(x, edge_index, edge_mask)

    # the model prediction with edge mask
    data = Batch.from_data_list([Data(x=x, edge_index=edge_index)])
    # Keep the returned object instead of relying on in-place movement.
    data = data.to(self.device)
    outputs = self.model(data)
    return outputs[1].squeeze(), edge_mask
def stack(inp):
    """Recursively collate a sequence of samples into one batched structure.

    The type of the first element selects the strategy; subclasses are
    accepted for every type:

    * ``list`` -- the samples are zipped position-wise and each position is
      stacked recursively; the result is a list.
    * ``dict`` -- the items are zipped in iteration order, the keys must
      agree across samples, and each value group is stacked recursively.
    * ``torch.Tensor`` -- passed through ``pad_tensor`` and stacked on dim 0.
    * ``np.ndarray`` -- converted to tensors, then handled like tensors.
    * ``str`` -- ``inp`` is returned unchanged.
    * ``Data`` -- batched with ``Batch.from_data_list``.

    Raises:
        ValueError: If the first element has an unsupported type.
        AssertionError: If dict samples disagree on a key.
    """
    first = inp[0]

    if isinstance(first, list):
        ret = [stack(vs) for vs in zip(*inp)]
    elif isinstance(first, dict):
        ret = {}
        for kvs in zip(*[x.items() for x in inp]):
            ks, vs = zip(*kvs)
            for k in ks:
                assert k == ks[0], "Key value mismatch."
            ret[ks[0]] = stack(vs)
    elif isinstance(first, torch.Tensor):
        ret = torch.stack(pad_tensor(inp), 0)
    elif isinstance(first, np.ndarray):
        ret = torch.stack(pad_tensor([torch.from_numpy(x) for x in inp]), 0)
    elif isinstance(first, str):
        ret = inp
    elif isinstance(first, Data):
        # Graph from torch.geometric, create a batch
        ret = Batch.from_data_list(inp)
    else:
        raise ValueError("Cannot handle type {}".format(type(first)))
    return ret
def train_parallel(model, optimizer, loader, total, batch_size, loss_ftn_obj):
    """Train ``model`` for one pass over ``loader`` and return the mean step loss.

    Each item from ``loader`` is a list of graphs passed to the model as-is.
    How the loss is obtained depends on ``loss_ftn_obj.name``:
    ``'vae_loss'`` also passes ``mu``/``log_var``; ``'emd_loss'`` and
    ``'chamfer_loss'`` receive the batch assignment vector and are averaged;
    ``'emd_loss_layer'`` takes the loss returned by the model itself; any
    other name calls ``loss_ftn`` on output and concatenated node features.
    """
    model.train()
    running_total = 0.
    progress = tqdm.tqdm(enumerate(loader), total=total / batch_size)

    for step, graphs in progress:
        optimizer.zero_grad()
        loss_name = loss_ftn_obj.name

        if loss_name == 'vae_loss':
            reconstruction, mu, log_var = model(graphs)
            target = torch.cat([graph.x for graph in graphs]).to(device)
            step_loss = loss_ftn_obj.loss_ftn(
                reconstruction, target, mu, log_var)
        elif loss_name in ('emd_loss', 'chamfer_loss'):
            reconstruction = model(graphs)
            merged = Batch.from_data_list(graphs).to(device)
            step_loss = loss_ftn_obj.loss_ftn(
                reconstruction, merged.x, merged.batch).mean()
        elif loss_name == 'emd_loss_layer':
            _, layer_loss = model(graphs)
            step_loss = layer_loss.mean()
        else:
            reconstruction = model(graphs)
            target = torch.cat([graph.x for graph in graphs]).to(device)
            step_loss = loss_ftn_obj.loss_ftn(reconstruction, target)

        step_loss.backward()
        step_loss_value = step_loss.item()
        progress.set_description('train loss = %.5f' % step_loss_value)
        progress.refresh()  # to show immediately the update
        running_total += step_loss_value
        optimizer.step()

    return running_total / (step + 1)
def collate(output_list):
    """Merge a list of per-step outputs into one output.

    Tensors are concatenated along dim 0. Graph ``Data`` objects are batched
    when the ``geometric`` flag is set. Anything else is treated as a
    sequence of outputs and collated position by position.
    """
    head = output_list[0]
    if isinstance(head, torch.Tensor):
        return torch.cat(output_list, dim=0)
    if geometric and isinstance(head, Data):
        return Batch.from_data_list(output_list)
    return [collate(position) for position in zip(*output_list)]
def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
    """Create a data loader that collates samples with ``Batch.from_data_list``.

    Args:
        dataset: Dataset forwarded to the parent loader.
        batch_size: Number of samples per mini-batch.
        shuffle: Whether the parent loader reshuffles the data.
        **kwargs: Extra keyword arguments forwarded to the parent loader.
            ``collate_fn`` must not be among them, since it is set here.
    """
    # The collate function is passed by reference rather than wrapped in a
    # lambda: lambdas cannot be pickled, which breaks worker processes that
    # are started with the "spawn" method.
    super(DataLoader, self).__init__(
        dataset, batch_size, shuffle,
        collate_fn=Batch.from_data_list, **kwargs)