def test_permutation(self, z, adjacency, length):
    """Check that the flows after node embedding are permutation-equivariant.

    Runs the same batch twice -- once in original node order and once with the
    node dimension shuffled (with the adjacency matrix shuffled consistently)
    -- then un-shuffles the second result and verifies it matches the first
    (both z and ldj) within numerical tolerance.

    Returns:
        True when the outputs agree; False otherwise (diagnostics printed).
    """
    ldj = z.new_zeros(z.size(0), dtype=torch.float32)
    kwargs = dict()
    kwargs["length"] = length
    kwargs["src_key_padding_mask"] = create_transformer_mask(length)
    kwargs["channel_padding_mask"] = create_channel_mask(length)
    # First layer embeds the discrete nodes; run it before shuffling so both
    # branches start from the same continuous representation.
    z, ldj = self._run_layer(self.node_embed_flow, z, reverse=False, ldj=ldj,
                             ldj_per_layer=None, adjacency=adjacency, **kwargs)
    # Negative random noise for node positions that are valid in *every*
    # batch element (mask sum over batch == batch size); padded positions get
    # 0 and therefore sort behind all shuffled positions.
    shuffle_noise = torch.rand(z.size(1)).to(z.device) - 2
    shuffle_noise = shuffle_noise * (kwargs["channel_padding_mask"].sum(dim=[0, 2]) == z.size(0)).float()
    # Tiny increasing offset keeps the relative order of padded positions stable.
    shuffle_noise = shuffle_noise + 0.001 * torch.arange(z.size(1), device=z.device)
    _, shuffle_indices = shuffle_noise.sort(dim=0, descending=False)
    # Sorting the shuffle permutation yields its inverse permutation.
    _, unshuffle_indices = shuffle_indices.sort(dim=0, descending=False)
    z_shuffled = z[:, shuffle_indices]
    # Shuffle adjacency rows and columns consistently with the node order.
    adjacency_shuffled = adjacency[:, shuffle_indices][:, :, shuffle_indices]
    ldj_shuffled = ldj
    for flow in self.flow_layers[1:]:
        z, ldj = self._run_layer(flow, z, adjacency=adjacency, reverse=False, ldj=ldj, **kwargs)
        z_shuffled, ldj_shuffled = self._run_layer(
            flow, z_shuffled, adjacency=adjacency_shuffled, reverse=False,
            ldj=ldj_shuffled, **kwargs)
    z_unshuffled = z_shuffled[:, unshuffle_indices]
    ldj_unshuffled = ldj_shuffled
    # Count elements whose difference exceeds the numerical tolerance.
    z_diff = ((z - z_unshuffled).abs() > 1e-4).sum()
    ldj_diff = ((ldj - ldj_unshuffled).abs() > 1e-3).sum()
    if z_diff > 0 or ldj_diff > 0:
        print("Differences z: %s, ldj: %s" % (str(z_diff.item()), str(ldj_diff.item())))
        print("Z", z[0, :, 0])
        print("Z shuffled", z_shuffled[0, :, 0])
        print("Z unshuffled", z_unshuffled[0, :, 0])
        print("LDJ", ldj[0:5])
        print("LDJ unshuffled", ldj_unshuffled[0:5])
        print("LDJ diff", (ldj - ldj_unshuffled).abs())
        return False
    else:
        print("Shuffle test succeeded!")
        return True
def initialize_data_dependent(self, batch_list):
    """Perform data-dependent initialization of all flow layers.

    Args:
        batch_list: list of (z, kwargs) tuples; each kwargs must contain
            "length". Padding masks are added to each kwargs in place.
    """
    print("Initializing data dependent...")
    with torch.no_grad():
        # Precompute the padding masks once per batch, stored in its kwargs.
        for _, batch_kwargs in batch_list:
            seq_lengths = batch_kwargs["length"]
            batch_kwargs["src_key_padding_mask"] = create_transformer_mask(seq_lengths)
            batch_kwargs["channel_padding_mask"] = create_channel_mask(seq_lengths)
        # Initialize layers sequentially: each call transforms the batches so
        # the next layer sees its actual input distribution.
        for layer in self.flow_layers:
            batch_list = FlowModel.run_data_init_layer(batch_list, layer)
def forward(self, z, ldj=None, reverse=False, length=None, **kwargs):
    """Delegate to the parent forward, first deriving padding masks from `length`."""
    if length is not None:
        kwargs.update({
            "src_key_padding_mask": create_transformer_mask(length),
            "channel_padding_mask": create_channel_mask(length),
        })
    return super().forward(z, ldj=ldj, reverse=reverse, length=length, **kwargs)
def forward(self, z, adjacency=None, ldj=None, reverse=False, length=None, sample_temp=1.0, **kwargs):
    """Autoregressive model over discrete node types on a graph.

    reverse=False: accumulates the log-likelihood of the given nodes `z`
        under the autoregressive predictions into `ldj` and returns it.
    reverse=True: samples node types position by position and returns
        (sampled_nodes, None).

    Args:
        z: node-type indices; assumed shape (batch, num_nodes) -- TODO confirm.
        adjacency: adjacency tensor forwarded to self.graph_layer.
        ldj: running log-det-jacobian per batch element; created if None.
        length: tensor of valid node counts per batch element.
        sample_temp: softmax temperature for sampling; <= 0 uses greedy argmax.
    """
    if ldj is None:
        ldj = z.new_zeros(z.size(0), dtype=torch.float32)
    if length is not None:
        kwargs["length"] = length
        kwargs["src_key_padding_mask"] = create_transformer_mask(length, max_len=z.size(1))
        kwargs["channel_padding_mask"] = create_channel_mask(length, max_len=z.size(1))
    ldj_per_layer = []  # NOTE(review): currently unused in this method
    if not reverse:
        orig_nodes = z
        # One entry per masking configuration: (nodes_mask, labels_one_hot).
        label_list = self._create_labels(z, length, max_batch_len=z.shape[1])
        batch_ldj = z.new_zeros(z.size(0), dtype=torch.float32)
        for nodes_mask, labels_one_hot in label_list:
            # Shift indices by +1 so masked positions become 0 ("unknown");
            # visible positions keep value index+1.
            z = (orig_nodes + 1) * nodes_mask
            ## Run RNN
            z_nodes = self.embed_layer(z)
            out_pred = self.graph_layer(z_nodes, adjacency, **kwargs)
            out_pred = F.log_softmax(out_pred, dim=-1)
            ## Calculate loss
            # Log-probability of the true labels, summed over nodes.
            class_ldj = (out_pred * labels_one_hot).sum(dim=-1)
            batch_ldj = batch_ldj + class_ldj.sum(dim=1)
        if len(label_list) > 1:
            # Normalizes over multiple masking configurations by sequence
            # length -- presumably an average; TODO confirm intended semantics.
            batch_ldj = batch_ldj / length.float()
        ldj = ldj + batch_ldj
        return ldj
    else:
        # Start from all-zero ("unknown") tokens and decode left to right.
        z_nodes = z.new_zeros(z.size(0), z.size(1), dtype=torch.long)
        for rnn_iter in range(length.max()):
            node_embed = self.embed_layer(z_nodes)
            out_pred = self.graph_layer(node_embed, adjacency, **kwargs)
            out_pred = F.log_softmax(out_pred, dim=-1)
            # Only the prediction at the current position is used.
            out_pred = out_pred[:, rnn_iter, :]
            if sample_temp > 0.0:
                # Temperature-scaled stochastic sampling.
                out_pred = out_pred / sample_temp
                out_pred = torch.softmax(out_pred, dim=-1)
                out_sample = torch.multinomial(out_pred, num_samples=1, replacement=True).squeeze()
            else:
                # Greedy decoding.
                out_sample = torch.argmax(out_pred, dim=-1)
            # Store shifted by +1; 0 remains reserved for "not generated yet".
            z_nodes[:, rnn_iter] = out_sample + 1
        # Undo the +1 shift and zero out padded positions.
        z_nodes = (z_nodes - 1) * kwargs["channel_padding_mask"].squeeze(dim=-1).long()
        return z_nodes, None
def initialize_data_dependent(self, batch_list):
    """Perform data-dependent initialization of all flow layers.

    Args:
        batch_list: list of (z, kwargs) tuples; each kwargs must contain
            "length" (and the adjacency matrix). Padding masks are added to
            each kwargs in place, padded to the batch's node dimension.
    """
    with torch.no_grad():
        # Attach padding masks sized to each batch's maximum length.
        for batch, batch_kwargs in batch_list:
            max_len = batch.shape[1]
            seq_lengths = batch_kwargs["length"]
            batch_kwargs["src_key_padding_mask"] = create_transformer_mask(seq_lengths, max_len=max_len)
            batch_kwargs["channel_padding_mask"] = create_channel_mask(seq_lengths, max_len=max_len)
        # Initialize layers one after another; each call transforms the
        # batches so the next layer sees its actual input distribution.
        for layer in self.flow_layers:
            batch_list = FlowModel.run_data_init_layer(batch_list, layer)
def forward(self, z, ldj=None, reverse=False, length=None, **kwargs):
    """Run the flow stack over one-hot encoded categorical input.

    reverse=False: one-hot encode z, apply all flows, and score the result
    under the learned prior (masked by padding).
    reverse=True: apply the flows in reverse order.

    Returns:
        (z, ldj) tuple.
    """
    if ldj is None:
        ldj = z.new_zeros(z.size(0), dtype=torch.float32)
    if length is not None:
        kwargs["src_key_padding_mask"] = create_transformer_mask(length)
        kwargs["channel_padding_mask"] = create_channel_mask(length)
    if not reverse:
        # Discrete indices -> one-hot vectors before the flows.
        z = one_hot(z, num_classes=self.vocab_size)
        for flow in self.flow_layers:
            z, ldj = flow(z, ldj, reverse=reverse, length=length, **kwargs)
        prior = F.log_softmax(self.prior, dim=-1)
        # NOTE(review): this assignment *replaces* the ldj accumulated by the
        # flow loop above with the prior log-likelihood alone -- confirm this
        # is intended rather than `ldj = ldj + (...)`.
        # Also assumes "channel_padding_mask" is in kwargs when length is None.
        ldj = (z * prior[None, :z.size(1)] * kwargs["channel_padding_mask"]).sum(dim=[1, 2])
    else:
        for flow in reversed(self.flow_layers):
            z, ldj = flow(z, ldj, reverse=reverse, length=length, **kwargs)
    return z, ldj
def get_inner_activations(self, z, length=None, return_names=False, **kwargs):
    """Run z through every flow layer and collect each layer's detached output.

    Args:
        z: input tensor to transform.
        length: optional sequence lengths; used to derive padding masks.
        return_names: when True, also return the layer class names.

    Returns:
        List of per-layer outputs, or (outputs, names) if return_names.
    """
    if length is not None:
        kwargs["length"] = length
        kwargs["src_key_padding_mask"] = create_transformer_mask(length)
        kwargs["channel_padding_mask"] = create_channel_mask(length)
    activations = []
    names = []
    for layer in self.flow_layers:
        # Only the transformed tensor is needed; the ldj output is dropped.
        z = self._run_layer(layer, z, reverse=False, **kwargs)[0]
        activations.append(z.detach())
        names.append(type(layer).__name__)
    return (activations, names) if return_names else activations
def forward(self, z, adjacency=None, ldj=None, reverse=False, get_ldj_per_layer=False, length=None, sample_temp=1.0, gamma=1.0, **kwargs):
    """Graph VAE step: ELBO evaluation (reverse=False) or sampling (reverse=True).

    reverse=False: encodes z, reparameterizes, decodes, and subtracts
        (reconstruction loss + KL divergence) from ldj; returns
        (ldj, ldj_per_layer) with the two loss terms logged per layer.
    reverse=True: decodes from a standard-normal latent and samples node
        types; returns (sampled_nodes, None).

    Args:
        z: node-type indices; assumed shape (batch, num_nodes) -- TODO confirm.
        adjacency: adjacency tensor for the graph encoder/decoder.
        ldj: running objective accumulator per batch element; created if None.
        length: tensor of valid node counts per batch element.
        sample_temp: softmax temperature used when sampling.
        gamma: scales the KL *gradient* while keeping its value unchanged
            (gamma * KL + (1 - gamma) * KL.detach()).
    """
    if ldj is None:
        ldj = z.new_zeros(z.size(0), dtype=torch.float32)
    if length is not None:
        kwargs["length"] = length
        kwargs["src_key_padding_mask"] = create_transformer_mask(length, max_len=z.size(1))
        kwargs["channel_padding_mask"] = create_channel_mask(length, max_len=z.size(1))
    z_nodes = None
    ldj_per_layer = []
    if not reverse:
        orig_nodes = z
        ## Encoder
        z_embed = self.embed_layer(z)
        z_enc = self.graph_encoder(z_embed, adjacency, **kwargs)
        # Encoder outputs mean and log-std of the approximate posterior.
        z_mu, z_log_std = z_enc.chunk(2, dim=-1)
        z_std = z_log_std.exp()
        # Reparameterization trick.
        z_latent = torch.randn_like(z_mu) * z_std + z_mu
        ## Decoder
        z_dec = self.graph_decoder(z_latent, adjacency, **kwargs)
        z_rec = F.log_softmax(z_dec, dim=-1)
        ## Loss calculation
        loss_mask = kwargs["channel_padding_mask"].squeeze(dim=-1)
        # Per-node NLL, masked to valid nodes and normalized by graph size.
        reconstruction_loss = F.nll_loss(z_rec.view(-1, self.num_node_types), z.view(-1),
                                         reduction='none').view(z.shape) * loss_mask
        reconstruction_loss = reconstruction_loss.sum(dim=-1) / length.float()
        # KL(q(z|x) || N(0, I)) for a diagonal Gaussian, per node, masked.
        KL_div = (- z_log_std + (z_std ** 2 - 1 + z_mu ** 2) / 2).sum(dim=-1) * loss_mask
        KL_div = KL_div.sum(dim=-1) / length.float()
        # gamma anneals the KL gradient; the detach term keeps the reported
        # value identical regardless of gamma.
        ldj = ldj - (reconstruction_loss + (gamma * KL_div + (1 - gamma) * KL_div.detach()))
        ldj_per_layer.append({"KL_div": -KL_div, "reconstruction_loss": -reconstruction_loss})
        return ldj, ldj_per_layer
    else:
        # Sample latent from the standard-normal prior.
        z_latent = torch.randn_like(z)
        ## Decoder
        z_dec = self.graph_decoder(z_latent, adjacency, **kwargs)
        ## Sampling
        z_dec = z_dec / sample_temp
        out_pred = torch.softmax(z_dec, dim=-1)
        # Sample one node type per position from the categorical output.
        z_nodes = torch.multinomial(out_pred.view(-1, out_pred.size(-1)),
                                    num_samples=1, replacement=True).view(out_pred.shape[:-1])
        return z_nodes, None
def test_reversibility(self, z, adjacency, length):
    """Verify that the flows after node embedding invert exactly.

    Runs all layers forward then backward and compares against the embedded
    input. On mismatch, re-runs layer by layer to localize the faulty flow
    and prints diagnostics.

    Returns:
        True when reversible (or when the only errors look attributable to
        out-of-range values in the mixture layer); False otherwise.
    """
    ldj = z.new_zeros(z.size(0), dtype=torch.float32)
    kwargs = dict()
    kwargs["length"] = length
    kwargs["src_key_padding_mask"] = create_transformer_mask(length, max_len=z.size(1))
    kwargs["channel_padding_mask"] = create_channel_mask(length, max_len=z.size(1))
    ## Embed nodes
    z_nodes, ldj = self._run_layer(self.node_embed_flow, z, reverse=False, ldj=ldj,
                                   adjacency=adjacency, **kwargs)
    # Reference point for the round-trip comparison.
    z_nodes_embed = z_nodes
    ldj_embed = ldj
    ## Testing RGCN flows
    for flow in self.flow_layers[1:]:
        z_nodes, ldj = self._run_layer(flow, z_nodes, reverse=False, ldj=ldj,
                                       adjacency=adjacency, **kwargs)
    z_nodes_reversed, ldj_reversed = z_nodes, ldj
    for flow in reversed(self.flow_layers[1:]):
        z_nodes_reversed, ldj_reversed = self._run_layer(
            flow, z_nodes_reversed, reverse=True, ldj=ldj_reversed,
            adjacency=adjacency, **kwargs)
    # Round trip succeeds if both node values and ldj match within tolerance.
    reverse_succeeded = (
        (z_nodes_reversed - z_nodes_embed).abs() > 1e-2).sum() == 0 and (
        (ldj_reversed - ldj_embed).abs() > 1e-1).sum() == 0
    if not reverse_succeeded:
        print("[!] ERROR: Coupling layer with given adjacency matrix are not reversible. Max diffs:\n" + \
              "Nodes: %s\n" % str(torch.max((z_nodes_reversed - z_nodes_embed).abs())) + \
              "LDJ: %s\n" % str(torch.max((ldj_reversed - ldj_embed).abs())))
        # Re-run layer by layer to localize which flow breaks invertibility.
        z_nodes = z_nodes_embed
        ldj = ldj_embed
        large_error = False
        for flow_index, flow in enumerate(self.flow_layers[1:]):
            z_nodes_forward, ldj_forward = self._run_layer(
                flow, z_nodes, reverse=False, ldj=ldj, adjacency=adjacency, **kwargs)
            z_nodes_backward, ldj_backward = self._run_layer(
                flow, z_nodes_forward, reverse=True, ldj=ldj_forward,
                adjacency=adjacency, **kwargs)
            node_diff = (z_nodes_backward - z_nodes).abs()
            ldj_diff = (ldj_backward - ldj).abs()
            max_node_diff = torch.max(node_diff)
            max_ldj_diff = torch.max(ldj_diff)
            mean_node_diff = torch.mean(node_diff)
            mean_ldj_diff = torch.mean(ldj_diff)
            print("Flow [%i]: %s" % (flow_index + 1, flow.info()))
            print("-> Max node diff: %s" % str(max_node_diff))
            print("-> Max ldj diff: %s" % str(max_ldj_diff))
            print("-> Mean node diff: %s" % str(mean_node_diff))
            print("-> Mean ldj diff: %s" % str(mean_ldj_diff))
            if max_node_diff > 1e-2:
                # Inspect the batch element with the largest ldj deviation.
                batch_index = torch.argmax(ldj_diff).item()
                print("-> Batch index", batch_index)
                print("-> Nodes with max diff:")
                print(node_diff[batch_index])
                print(z_nodes_backward[batch_index])
                print(z_nodes[batch_index])
                print(z_nodes_forward[batch_index])
                node_mask = (node_diff[batch_index] > 1e-4)
                faulty_nodes = z_nodes_forward[batch_index].masked_select(node_mask)
                # Faulty values of small magnitude cannot be explained by
                # out-of-range mixture behaviour -> treat as a real error.
                num_small_faulty_nodes = (faulty_nodes.abs() < 1.0).sum().item()
                large_error = large_error or (num_small_faulty_nodes > 0)
            z_nodes = z_nodes_forward
            ldj = ldj_forward
        if not large_error:
            print("-" * 50)
            print(
                "Error probably caused by large values out of range in the mixture layer. Ignored for now."
            )
        return (not large_error)
    else:
        print("Reversibility test passed")
        return True