Пример #1
0
    def test_permutation(self, z, adjacency, length):
        """Check that the flow layers are equivariant to node permutations.

        The input is embedded once, then a random permutation is applied to
        the node positions (and consistently to the adjacency matrix).
        Running the remaining flow layers on the original and the permuted
        input must give results that agree after undoing the permutation;
        otherwise diagnostic information is printed.

        Returns True if the outputs match within tolerance, else False.
        """
        ldj = z.new_zeros(z.size(0), dtype=torch.float32)
        kwargs = dict()
        kwargs["length"] = length
        kwargs["src_key_padding_mask"] = create_transformer_mask(length)
        kwargs["channel_padding_mask"] = create_channel_mask(length)

        # Embed the discrete nodes with the first flow before shuffling.
        z, ldj = self._run_layer(self.node_embed_flow,
                                 z,
                                 reverse=False,
                                 ldj=ldj,
                                 ldj_per_layer=None,
                                 adjacency=adjacency,
                                 **kwargs)

        # Noise in [-2, -1) is kept only for positions that are valid in
        # *every* batch element (mask sum over batch == batch size); other
        # positions are zeroed.  The small increasing offset keeps the
        # relative order of those positions stable, so sorting permutes only
        # the fully valid positions.
        shuffle_noise = torch.rand(z.size(1)).to(z.device) - 2
        shuffle_noise = shuffle_noise * (kwargs["channel_padding_mask"].sum(
            dim=[0, 2]) == z.size(0)).float()
        shuffle_noise = shuffle_noise + 0.001 * torch.arange(z.size(1),
                                                             device=z.device)
        _, shuffle_indices = shuffle_noise.sort(dim=0, descending=False)
        # Sorting the permutation itself yields its inverse.
        _, unshuffle_indices = shuffle_indices.sort(dim=0, descending=False)
        z_shuffled = z[:, shuffle_indices]
        # Permute adjacency rows and columns consistently with the nodes.
        adjacency_shuffled = adjacency[:, shuffle_indices][:, :,
                                                           shuffle_indices]
        ldj_shuffled = ldj

        # Run the remaining flows on the original and the permuted inputs.
        for flow in self.flow_layers[1:]:
            z, ldj = self._run_layer(flow,
                                     z,
                                     adjacency=adjacency,
                                     reverse=False,
                                     ldj=ldj,
                                     **kwargs)
            z_shuffled, ldj_shuffled = self._run_layer(
                flow,
                z_shuffled,
                adjacency=adjacency_shuffled,
                reverse=False,
                ldj=ldj_shuffled,
                **kwargs)
        z_unshuffled = z_shuffled[:, unshuffle_indices]
        ldj_unshuffled = ldj_shuffled

        # Count elements that differ beyond tolerance.
        z_diff = ((z - z_unshuffled).abs() > 1e-4).sum()
        ldj_diff = ((ldj - ldj_unshuffled).abs() > 1e-3).sum()

        if z_diff > 0 or ldj_diff > 0:
            print("Differences z: %s, ldj: %s" %
                  (str(z_diff.item()), str(ldj_diff.item())))
            print("Z", z[0, :, 0])
            print("Z shuffled", z_shuffled[0, :, 0])
            print("Z unshuffled", z_unshuffled[0, :, 0])
            print("LDJ", ldj[0:5])
            print("LDJ unshuffled", ldj_unshuffled[0:5])
            print("LDJ diff", (ldj - ldj_unshuffled).abs())
            return False
        else:
            print("Shuffle test succeeded!")
            return True
Пример #2
0
	def _preprocess_batch(self, batch, length_clipping=True):
		"""Unpack a batch and build the channel padding mask.

		With ``length_clipping`` enabled, node positions beyond the longest
		graph in the batch are cut off before further processing.
		"""
		nodes, adjacency, lengths = batch
		if length_clipping:
			longest = lengths.max()
			nodes = nodes[:, :longest].contiguous()
			adjacency = adjacency[:, :longest, :longest].contiguous()
		channel_mask = create_channel_mask(lengths, max_len=nodes.shape[1])
		return nodes, adjacency, lengths, channel_mask
Пример #3
0
	def _preprocess_batch(self, batch, length_clipping=True):
		"""Unpack a batch and build the per-node channel padding mask.

		``x_length`` is the number of real nodes per graph, excluding the
		virtual (padding) nodes appended for batching.
		"""
		x_in, x_adjacency, x_length = batch # the length is the number of real nodes in a graph, excluding virtual nodes
		if length_clipping: # clip the virtual-node part shared by every element in this batch
			max_len = x_length.max()
			x_in = x_in[:,:max_len].contiguous() # the size of x is [batch_size, max_len]
			x_adjacency = x_adjacency[:,:max_len,:max_len].contiguous()
		x_channel_mask = create_channel_mask(x_length, max_len=x_in.shape[1]) # the shape: [batch_size, max_len, 1]
		return x_in, x_adjacency, x_length, x_channel_mask
Пример #4
0
	def initialize_data_dependent(self, batch_list):
		"""Run data-dependent initialization over all flow layers.

		Each entry of ``batch_list`` is a ``(z, kwargs)`` tuple; the padding
		masks are attached to ``kwargs`` before the layers are initialized.
		"""
		print("Initializing data dependent...")
		with torch.no_grad():
			for _, layer_kwargs in batch_list:
				seq_length = layer_kwargs["length"]
				layer_kwargs["src_key_padding_mask"] = create_transformer_mask(seq_length)
				layer_kwargs["channel_padding_mask"] = create_channel_mask(seq_length)
			for layer in self.flow_layers:
				batch_list = FlowModel.run_data_init_layer(batch_list, layer)
Пример #5
0
 def forward(self, z, ldj=None, reverse=False, length=None, **kwargs):
     """Attach padding masks derived from ``length``, then delegate to the parent forward."""
     if length is not None:
         kwargs["src_key_padding_mask"] = create_transformer_mask(length)
         kwargs["channel_padding_mask"] = create_channel_mask(length)
     return super().forward(z, ldj=ldj, reverse=reverse, length=length, **kwargs)
Пример #6
0
    def forward(self,
                z,
                adjacency=None,
                ldj=None,
                reverse=False,
                length=None,
                sample_temp=1.0,
                **kwargs):
        """Autoregressive model over discrete node types of a graph.

        Forward (reverse=False): sums the log-likelihood of the true node
        types under the autoregressive predictions and adds it to ``ldj``;
        only ``ldj`` is returned.  Reverse: samples node types one position
        at a time, feeding previous samples back into the graph network,
        and returns ``(z_nodes, None)``.

        Args:
            z: Discrete node types (long tensor, batch x nodes — presumably;
                confirm against callers).
            adjacency: Adjacency matrix passed to the graph network.
            ldj: Accumulated log-determinant-jacobian per batch element;
                created as zeros when None.
            reverse: Switch between likelihood computation and sampling.
            length: Number of valid nodes per graph; used to build masks.
            sample_temp: Softmax temperature for sampling; <= 0 falls back
                to greedy argmax decoding.
        """
        if ldj is None:
            ldj = z.new_zeros(z.size(0), dtype=torch.float32)
        if length is not None:
            kwargs["length"] = length
            kwargs["src_key_padding_mask"] = create_transformer_mask(
                length, max_len=z.size(1))
            kwargs["channel_padding_mask"] = create_channel_mask(
                length, max_len=z.size(1))

        if not reverse:
            orig_nodes = z
            label_list = self._create_labels(z,
                                             length,
                                             max_batch_len=z.shape[1])
            batch_ldj = z.new_zeros(z.size(0), dtype=torch.float32)
            for nodes_mask, labels_one_hot in label_list:
                # Shift node indices by one so that masked positions map to 0.
                z = (orig_nodes + 1) * nodes_mask
                ## Run the graph network on the (partially masked) nodes
                z_nodes = self.embed_layer(z)
                out_pred = self.graph_layer(z_nodes, adjacency, **kwargs)
                out_pred = F.log_softmax(out_pred, dim=-1)
                ## Log-likelihood of the true labels under the prediction
                class_ldj = (out_pred * labels_one_hot).sum(dim=-1)
                batch_ldj = batch_ldj + class_ldj.sum(dim=1)
            if len(label_list) > 1:
                # Normalize by the number of valid nodes when several
                # maskings were evaluated.
                batch_ldj = batch_ldj / length.float()
            ldj = ldj + batch_ldj
            return ldj
        else:
            z_nodes = z.new_zeros(z.size(0), z.size(1), dtype=torch.long)
            for rnn_iter in range(length.max()):
                node_embed = self.embed_layer(z_nodes)
                out_pred = self.graph_layer(node_embed, adjacency, **kwargs)
                out_pred = F.log_softmax(out_pred, dim=-1)
                out_pred = out_pred[:, rnn_iter, :]
                if sample_temp > 0.0:
                    out_pred = out_pred / sample_temp
                    out_pred = torch.softmax(out_pred, dim=-1)
                    # Squeeze only the sample dimension; a bare squeeze()
                    # would also collapse the batch dim for batch size 1.
                    out_sample = torch.multinomial(
                        out_pred, num_samples=1,
                        replacement=True).squeeze(dim=-1)
                else:
                    out_sample = torch.argmax(out_pred, dim=-1)
                z_nodes[:, rnn_iter] = out_sample + 1
            # Undo the +1 shift and zero out padded positions.
            z_nodes = (z_nodes - 1) * kwargs["channel_padding_mask"].squeeze(
                dim=-1).long()
            return z_nodes, None
Пример #7
0
 def initialize_data_dependent(self, batch_list):
     """Data-dependent initialization of all flow layers.

     Each ``batch_list`` entry is a ``(z, kwargs)`` tuple; ``kwargs`` also
     carries the adjacency matrix.  Padding masks are attached first.
     """
     with torch.no_grad():
         for data, data_kwargs in batch_list:
             seq_len = data.shape[1]
             data_kwargs["src_key_padding_mask"] = create_transformer_mask(
                 data_kwargs["length"], max_len=seq_len)
             data_kwargs["channel_padding_mask"] = create_channel_mask(
                 data_kwargs["length"], max_len=seq_len)
         for layer in self.flow_layers:
             batch_list = FlowModel.run_data_init_layer(batch_list, layer)
Пример #8
0
 def _preprocess_batch(self, batch):
     if isinstance(batch, tuple):
         x_in, x_length = batch
         x_in = x_in[:, :x_length.max()]
         x_channel_mask = create_channel_mask(x_length,
                                              max_len=x_in.size(1))
     else:
         x_in = batch
         x_length = x_in.new_zeros(x_in.size(0),
                                   dtype=torch.long) + x_in.size(1)
         x_channel_mask = x_in.new_ones(x_in.size(0),
                                        x_in.size(1),
                                        1,
                                        dtype=torch.float32)
     return x_in, x_length, x_channel_mask
Пример #9
0
    def forward(self, z, ldj=None, reverse=False, length=None, **kwargs):
        """Forward/inverse pass of the categorical flow.

        Forward (reverse=False): one-hot encodes the discrete input, runs
        all flow layers, and sets ``ldj`` to the masked log-probability of
        the transformed representation under the learned prior.
        Inverse (reverse=True): runs the layers in reverse order.
        """
        if ldj is None:
            ldj = z.new_zeros(z.size(0), dtype=torch.float32)
        if length is not None:
            kwargs["src_key_padding_mask"] = create_transformer_mask(length)
            kwargs["channel_padding_mask"] = create_channel_mask(length)

        if not reverse:
            z = one_hot(z, num_classes=self.vocab_size)
            for flow in self.flow_layers:
                z, ldj = flow(z, ldj, reverse=reverse, length=length, **kwargs)
            prior = F.log_softmax(self.prior, dim=-1)
            # NOTE(review): the ldj accumulated by the flow layers above is
            # overwritten here rather than added to — confirm this is
            # intentional.  Also note kwargs["channel_padding_mask"] is only
            # set when length is not None (or supplied by the caller).
            ldj = (z * prior[None, :z.size(1)] *
                   kwargs["channel_padding_mask"]).sum(dim=[1, 2])
        else:
            for flow in reversed(self.flow_layers):
                z, ldj = flow(z, ldj, reverse=reverse, length=length, **kwargs)
        return z, ldj
Пример #10
0
    def get_inner_activations(self,
                              z,
                              length=None,
                              return_names=False,
                              **kwargs):
        """Collect the (detached) output of every flow layer on the forward pass.

        Returns the list of per-layer activations, optionally paired with
        the class name of each layer when ``return_names`` is set.
        """
        if length is not None:
            kwargs["length"] = length
            kwargs["src_key_padding_mask"] = create_transformer_mask(length)
            kwargs["channel_padding_mask"] = create_channel_mask(length)

        activations, names = [], []
        for layer in self.flow_layers:
            z = self._run_layer(layer, z, reverse=False, **kwargs)[0]
            activations.append(z.detach())
            names.append(type(layer).__name__)

        return (activations, names) if return_names else activations
Пример #11
0
	def forward(self, z, adjacency=None, ldj=None, reverse=False, get_ldj_per_layer=False, length=None, sample_temp=1.0, gamma=1.0, **kwargs):
		"""VAE over graph node types, conditioned on the adjacency matrix.

		Forward (reverse=False): encodes the nodes, samples a latent via the
		reparameterization trick, decodes, and subtracts the negative ELBO
		terms (reconstruction loss and KL divergence) from ``ldj``.
		``gamma`` scales only the KL *gradient*; the KL value entering the
		objective is unchanged.  Returns ``(ldj, ldj_per_layer)``.
		Reverse: decodes node types from a standard-normal latent, sampling
		with softmax temperature ``sample_temp``; returns ``(z_nodes, None)``.
		"""
		if ldj is None:
			ldj = z.new_zeros(z.size(0), dtype=torch.float32)
		if length is not None:
			kwargs["length"] = length
			kwargs["src_key_padding_mask"] = create_transformer_mask(length, max_len=z.size(1))
			kwargs["channel_padding_mask"] = create_channel_mask(length, max_len=z.size(1))

		z_nodes = None
		ldj_per_layer = []
		if not reverse:
			orig_nodes = z
			## Encoder
			z_embed = self.embed_layer(z)
			z_enc = self.graph_encoder(z_embed, adjacency, **kwargs)
			z_mu, z_log_std = z_enc.chunk(2, dim=-1)
			z_std = z_log_std.exp()
			# Reparameterization trick
			z_latent = torch.randn_like(z_mu) * z_std + z_mu
			## Decoder
			z_dec = self.graph_decoder(z_latent, adjacency, **kwargs)
			z_rec = F.log_softmax(z_dec, dim=-1)
			## Loss calculation, masked to the valid nodes and normalized by graph length
			loss_mask = kwargs["channel_padding_mask"].squeeze(dim=-1)
			reconstruction_loss = F.nll_loss(z_rec.view(-1, self.num_node_types), z.view(-1), reduction='none').view(z.shape) * loss_mask
			reconstruction_loss = reconstruction_loss.sum(dim=-1) / length.float()
			# Closed-form KL(N(mu, std) || N(0, 1)) summed over latent dims
			KL_div = (- z_log_std + (z_std ** 2 - 1 + z_mu ** 2) / 2).sum(dim=-1) * loss_mask
			KL_div = KL_div.sum(dim=-1) / length.float()

			# gamma * KL + (1-gamma) * KL.detach() keeps the KL value but scales its gradient by gamma
			ldj = ldj - (reconstruction_loss + (gamma * KL_div + (1-gamma) * KL_div.detach()))
			ldj_per_layer.append({"KL_div": -KL_div, "reconstruction_loss": -reconstruction_loss})
			return ldj, ldj_per_layer
		else:
			z_latent = torch.randn_like(z)
			## Decoder
			z_dec = self.graph_decoder(z_latent, adjacency, **kwargs)
			## Sampling with temperature
			z_dec = z_dec / sample_temp
			out_pred = torch.softmax(z_dec, dim=-1)
			z_nodes = torch.multinomial(out_pred.view(-1,out_pred.size(-1)), num_samples=1, replacement=True).view(out_pred.shape[:-1])
			return z_nodes, None
        s = "Autoregressive Mixture CDF Coupling Layer - Input size %i" % (
            self.c_in)
        if self.block_type is not None:
            s += ", block type %s" % (self.block_type)
        return s


if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)

    batch_size, seq_len, c_in = 1, 3, 3
    hidden_size = 8
    _inp = torch.randn(batch_size, seq_len, c_in)
    lengths = torch.LongTensor([seq_len] * batch_size)
    channel_padding_mask = create_channel_mask(length=lengths, max_len=seq_len)
    time_embed = nn.Linear(2 * seq_len, 2)

    module = AutoregressiveMixtureCDFCoupling1D(c_in=c_in,
                                                hidden_size=hidden_size,
                                                num_mixtures=4,
                                                time_embed=time_embed,
                                                autoreg_hidden=True)

    orig_out, _ = module(z=_inp,
                         length=lengths,
                         channel_padding_mask=channel_padding_mask)
    print("Out", orig_out)

    _inp[0, 1, 1] = 10
    alt_out, _ = module(z=_inp,
Пример #13
0
    def test_reversibility(self, z_nodes, adjacency, length, **kwargs):
        """Check that every coupling step of the three-step flow is invertible.

        For each step (node flows, edge-attribute flows, virtual-edge flows)
        the flows are applied forward and then in reverse; the reconstructed
        tensors and ldj must match the originals within tolerance.  A
        warning is printed for any step that is not precisely reversible.
        """
        ldj = z_nodes.new_zeros(z_nodes.size(0), dtype=torch.float32)
        if length is not None:
            kwargs["length"] = length
            kwargs["channel_padding_mask"] = create_channel_mask(
                length, max_len=z_nodes.size(1))

        ## Performing encoding of step 1
        z_nodes, ldj = self._run_layer(self.node_encoding,
                                       z_nodes,
                                       False,
                                       ldj=ldj,
                                       **kwargs)
        # Encoded state is the reference for the step-1 reversibility check.
        z_nodes_embed = z_nodes
        ldj_embed = ldj

        ## Testing step 1 flows
        for flow in self.step1_flows:
            z_nodes, ldj = self._run_layer(flow,
                                           z_nodes,
                                           reverse=False,
                                           ldj=ldj,
                                           adjacency=adjacency,
                                           **kwargs)
        z_nodes_reversed, ldj_reversed = z_nodes, ldj
        for flow in reversed(self.step1_flows):
            z_nodes_reversed, ldj_reversed = self._run_layer(
                flow,
                z_nodes_reversed,
                reverse=True,
                ldj=ldj_reversed,
                adjacency=adjacency,
                **kwargs)
        rev_node = (
            (z_nodes_reversed - z_nodes_embed).abs() > 1e-3).sum() == 0 and (
                (ldj_reversed - ldj_embed).abs() > 1e-1).sum() == 0
        if not rev_node:
            print("[#] WARNING: Step 1 - Coupling layers are not precisely reversible. Max diffs:\n" + \
              "Nodes: %s\n" % str(torch.max((z_nodes_reversed - z_nodes_embed).abs())) + \
              "LDJ: %s" % str(torch.max((ldj_reversed - ldj_embed).abs())))

        ## Performing encoding of step 2
        # Edges are flattened to node pairs; mask_valid marks real pairs.
        z_edges_disc, x_indices, mask_valid = adjacency2pairs(
            adjacency=adjacency, length=length)
        # Restrict to pairs that actually carry an edge (type != 0).
        kwargs["mask_valid"] = mask_valid * (z_edges_disc != 0).to(
            mask_valid.dtype)
        kwargs["x_indices"] = x_indices
        binary_adjacency = (adjacency > 0).long()
        kwargs_edge_embed = kwargs.copy()
        kwargs_edge_embed["channel_padding_mask"] = kwargs[
            "mask_valid"].unsqueeze(dim=-1)
        # Shift edge types by -1 so the first real edge type is 0.
        z_attr = (z_edges_disc - 1).clamp(min=0)
        z_edges, ldj = self._run_layer(self.edge_attr_encoding, z_attr, False,
                                       ldj, **kwargs_edge_embed)

        ## Testing step 2 flows
        z_nodes_orig, z_edges_orig, ldj_orig = z_nodes, z_edges, ldj
        for flow in self.step2_flows:
            z_nodes, z_edges, ldj = self._run_node_edge_layer(
                flow,
                z_nodes,
                z_edges,
                False,
                ldj,
                binary_adjacency=binary_adjacency,
                **kwargs)
        z_nodes_rev, z_edges_rev, ldj_rev = z_nodes, z_edges, ldj
        for flow in reversed(self.step2_flows):
            z_nodes_rev, z_edges_rev, ldj_rev = self._run_node_edge_layer(
                flow,
                z_nodes_rev,
                z_edges_rev,
                True,
                ldj_rev,
                binary_adjacency=binary_adjacency,
                **kwargs)
        rev_edge_attr = ((z_nodes_rev - z_nodes_orig).abs() > 1e-3).sum() == 0 and \
            ((z_edges_rev - z_edges_orig).abs() > 1e-3).sum() == 0 and \
            ((ldj_rev - ldj_orig).abs() > 1e-1).sum() == 0
        if not rev_edge_attr:
            print("[#] WARNING: Step 2 - Coupling layers are not precisely reversible. Max diffs:\n" + \
              "Nodes: %s\n" % str(torch.max((z_nodes_rev - z_nodes_orig).abs())) + \
              "Edges: %s\n" % str(torch.max((z_edges_rev - z_edges_orig).abs())) + \
              "LDJ: %s" % str(torch.max((ldj_rev - ldj_orig).abs())))

        ## Performing encoding of step 3
        kwargs["mask_valid"] = mask_valid
        # Valid pairs *without* an edge are encoded as virtual edges.
        virtual_edge_mask = mask_valid * (z_edges_disc == 0).float()
        kwargs_no_edge_embed = kwargs.copy()
        kwargs_no_edge_embed[
            "channel_padding_mask"] = virtual_edge_mask.unsqueeze(dim=-1)
        virt_edges = z_edges.new_zeros(z_edges.shape[:-1], dtype=torch.long)
        z_virtual_edges, ldj = self._run_layer(self.edge_virtual_encoding,
                                               virt_edges, False, ldj,
                                               **kwargs_no_edge_embed)
        # Merge: virtual-edge encodings where no edge exists, attribute
        # encodings everywhere else.
        z_edges = torch.where(
            virtual_edge_mask.unsqueeze(dim=-1) == 1, z_virtual_edges, z_edges)

        ## Testing step 3 flows
        z_nodes_orig, z_edges_orig, ldj_orig = z_nodes, z_edges, ldj
        for flow in self.step3_flows:
            z_nodes, z_edges, ldj = self._run_node_edge_layer(
                flow, z_nodes, z_edges, False, ldj, **kwargs)
        z_nodes_rev, z_edges_rev, ldj_rev = z_nodes, z_edges, ldj
        for flow in reversed(self.step3_flows):
            z_nodes_rev, z_edges_rev, ldj_rev = self._run_node_edge_layer(
                flow, z_nodes_rev, z_edges_rev, True, ldj_rev, **kwargs)
        rev_edge_virt = ((z_nodes_rev - z_nodes_orig).abs() > 1e-3).sum() == 0 and \
            ((z_edges_rev - z_edges_orig).abs() > 1e-3).sum() == 0 and \
            ((ldj_rev - ldj_orig).abs() > 1e-1).sum() == 0
        if not rev_edge_virt:
            print("[#] WARNING: Step 3 - Coupling layers are not precisely reversible. Max diffs:\n" + \
              "Nodes: %s\n" % str(torch.max((z_nodes_rev - z_nodes_orig).abs())) + \
              "Edges: %s\n" % str(torch.max((z_edges_rev - z_edges_orig).abs())) + \
              "LDJ: %s" % str(torch.max((ldj_rev - ldj_orig).abs())))

        if rev_node and rev_edge_attr and rev_edge_virt:
            print("Reversibility test succeeded!")
        else:
            print(
                "Reversibility test finished with warnings. Non-reversibility can be due to limited precision in mixture coupling layers"
            )
Пример #14
0
    def initialize_data_dependent(self, batch_list):
        """Data-dependent initialization of the three-step graph flow.

        Batch list needs to consist of tuples: (z, kwargs); kwargs contains
        the adjacency matrix as well.  Layers are initialized in the same
        order they are applied in the forward pass: node encoding and
        step-1 flows, then edge-attribute flows (step 2), then
        virtual-edge flows (step 3).
        """
        with torch.no_grad():

            for batch, kwargs in batch_list:
                kwargs["channel_padding_mask"] = create_channel_mask(
                    kwargs["length"], max_len=batch.shape[1])

            ## Initialize node encoding and step-1 flows
            for module_index, module_list in enumerate([[self.node_encoding],
                                                        self.step1_flows]):
                for layer_index, layer in enumerate(module_list):
                    print("Processing layer %i (module %i)..." %
                          (layer_index + 1, module_index + 1),
                          end="\r")
                    if isinstance(layer, FlowLayer):
                        batch_list = FlowModel.run_data_init_layer(
                            batch_list, layer)
                    elif isinstance(layer, FlowModel):
                        batch_list = layer.initialize_data_dependent(
                            batch_list)
                    else:
                        print("[!] ERROR: Unknown layer type", layer)
                        sys.exit(1)

            ## Initialize main flow
            for i in range(len(batch_list)):
                z_nodes, kwargs = batch_list[i]
                # Flatten the adjacency matrix to a list of node pairs.
                z_adjacency, x_indices, mask_valid = adjacency2pairs(
                    adjacency=kwargs["adjacency"], length=kwargs["length"])
                # Restrict to pairs that actually carry an edge (type != 0).
                attr_mask_valid = mask_valid * (z_adjacency != 0).to(
                    mask_valid.dtype)
                z_edges, _, _ = self.edge_attr_encoding(
                    (z_adjacency - 1).clamp(min=0),
                    reverse=False,
                    channel_padding_mask=attr_mask_valid.unsqueeze(dim=-1))
                kwargs["original_z_adjacency"] = z_adjacency
                kwargs["binary_adjacency"] = (kwargs["adjacency"] > 0).long()
                kwargs["original_mask_valid"] = mask_valid
                kwargs["mask_valid"] = attr_mask_valid
                kwargs["x_indices"] = x_indices
                batch_list[i] = ([z_nodes, z_edges], kwargs)

            for layer_index, layer in enumerate(self.step2_flows):
                batch_list = FlowModel.run_data_init_layer(batch_list, layer)

            for i in range(len(batch_list)):
                z, kwargs = batch_list[i]
                z_nodes, z_edges = z[0], z[1]
                # Valid node pairs that have *no* edge between them.
                no_edge_mask_valid = kwargs["original_mask_valid"] * (
                    kwargs["original_z_adjacency"] == 0).float()
                z_no_edges, _, _ = self.edge_virtual_encoding(
                    torch.zeros_like(kwargs["original_z_adjacency"]),
                    reverse=False,
                    channel_padding_mask=no_edge_mask_valid.unsqueeze(dim=-1))
                # Merge: keep attribute encodings where an edge exists,
                # virtual-edge encodings elsewhere.
                z_edges = z_edges * (1 - no_edge_mask_valid)[
                    ..., None] + z_no_edges * no_edge_mask_valid[..., None]
                kwargs["mask_valid"] = kwargs["original_mask_valid"]
                kwargs.pop("binary_adjacency")
                batch_list[i] = ([z_nodes, z_edges], kwargs)

            for layer_index, layer in enumerate(self.step3_flows):
                batch_list = FlowModel.run_data_init_layer(batch_list, layer)
Пример #15
0
    def forward(self,
                z,
                adjacency=None,
                ldj=None,
                reverse=False,
                get_ldj_per_layer=False,
                length=None,
                sample_temp=1.0,
                **kwargs):
        """Three-step normalizing flow over graphs.

        Forward (reverse=False): maps discrete node types and the adjacency
        matrix to latents in three steps (nodes, edge attributes, virtual
        edges) while accumulating the log-determinant-jacobian.
        Reverse: samples latent edges from the prior and decodes an
        adjacency matrix plus node types; the first return value is then
        the tuple ``(z_nodes, adjacency)``.
        """
        z_nodes = z  # Renaming as argument "z" is usually used for the flows, but here it represents the discrete node types
        if ldj is None:
            ldj = z_nodes.new_zeros(z_nodes.size(0), dtype=torch.float32)
        if length is not None:
            kwargs["length"] = length
            kwargs["channel_padding_mask"] = create_channel_mask(
                length, max_len=z_nodes.size(1))

        ldj_per_layer = []
        if not reverse:
            ## Step 1 => Encode nodes in latent space and apply RGCN flows
            z_nodes, ldj = self._step1_forward(z_nodes, adjacency, ldj, False,
                                               ldj_per_layer, **kwargs)
            ## Edges are represented as a list (1D tensor), not as a matrix, because we do not want to consider each edge twice
            # X_indices is a tuple, where each element has the size of the edge tensor and states the node indices of the corresponding edge
            # Mask_valid is a tensor of the same size, but contains 0 for those edges that are not "valid".
            # This is when edges are padding elements for graphs of different sizes
            # z_edges_disc holds the discrete edge type per node pair; mask_valid marks valid pairs that do not involve a virtual node
            z_edges_disc, x_indices, mask_valid = adjacency2pairs(
                adjacency=adjacency, length=length)
            kwargs["mask_valid"] = mask_valid * (z_edges_disc != 0).to(
                mask_valid.dtype
            )  # rule out the non-existing edges between valid nodes
            kwargs["x_indices"] = x_indices
            binary_adjacency = (adjacency > 0).long(
            )  # the adjacency matrix itself is type-specific, while the binary matrix only encodes connectivity
            ## Step 2 => Encode edge attributes in latent space and apply first EdgeGNN flows
            z_nodes, z_edges, ldj = self._step2_forward(
                z_nodes,
                z_edges_disc,
                ldj,
                False,
                ldj_per_layer,
                binary_adjacency=binary_adjacency,
                **kwargs)

            ## Step 3 => Encode virtual edges in latent space and apply final EdgeGNN flows
            kwargs["mask_valid"] = mask_valid
            virtual_edge_mask = mask_valid * (z_edges_disc == 0).float()
            z_nodes, z_edges, ldj = self._step3_forward(
                z_nodes, z_edges, ldj, False, ldj_per_layer, virtual_edge_mask,
                **kwargs)

            ## Add log probability of adjacency matrix to ldj. Only nodes are considered in the task object
            adjacency_log_prob = (self.prior_distribution.log_prob(z_edges) *
                                  mask_valid.unsqueeze(dim=-1)).sum(dim=[1, 2])
            ldj = ldj + adjacency_log_prob
            ldj_per_layer.append({"adjacency_log_prob": adjacency_log_prob})
        else:
            z_nodes = z
            batch_size, num_nodes = z_nodes.size(0), z_nodes.size(1)
            ## Sample latent variables for adjacency matrix
            mask_valid, x_indices = get_adjacency_indices(num_nodes=num_nodes,
                                                          length=length)
            kwargs["mask_valid"] = mask_valid
            kwargs["x_indices"] = x_indices
            z_edges = self.prior_distribution.sample(
                shape=(batch_size, mask_valid.size(1),
                       self.encoding_dim_edges),
                temp=sample_temp).to(z.device)
            ## Reverse step 3 => decode virtual edges
            z_nodes, z_edges, ldj, mask_valid = self._step3_forward(
                z_nodes, z_edges, ldj, True, ldj_per_layer, **kwargs)
            binary_adjacency = pairs2adjacency(num_nodes=num_nodes,
                                               pairs=mask_valid,
                                               length=length,
                                               x_indices=x_indices)
            ## Reverse step 2 => decode edge attributes
            kwargs["mask_valid"] = mask_valid
            z_nodes, z_edges, ldj = self._step2_forward(
                z_nodes,
                z_edges,
                ldj,
                True,
                ldj_per_layer,
                binary_adjacency=binary_adjacency,
                **kwargs)
            adjacency = pairs2adjacency(num_nodes=num_nodes,
                                        pairs=z_edges,
                                        length=length,
                                        x_indices=x_indices)
            ## Reverse step 1 => decode node types
            z_nodes, ldj = self._step1_forward(z_nodes,
                                               adjacency,
                                               ldj,
                                               reverse=True,
                                               ldj_per_layer=ldj_per_layer,
                                               **kwargs)
            z_nodes = (z_nodes, adjacency)

        if get_ldj_per_layer:
            return z_nodes, ldj, ldj_per_layer
        else:
            return z_nodes, ldj
Пример #16
0
    def test_reversibility(self, z, adjacency, length):
        """Check that the coupling layers are invertible given the adjacency matrix.

        The nodes are embedded once, the remaining flows are applied forward
        and then in reverse, and the reconstruction is compared against the
        embedded input.  On failure, each flow is probed individually and
        per-flow diagnostics are printed; mismatches on small activations
        are treated as real failures, everything else is attributed to
        limited precision in the mixture layers.

        Returns True when reversibility holds (or only precision-related
        errors were found), else False.
        """
        ldj = z.new_zeros(z.size(0), dtype=torch.float32)
        kwargs = dict()
        kwargs["length"] = length
        kwargs["src_key_padding_mask"] = create_transformer_mask(
            length, max_len=z.size(1))
        kwargs["channel_padding_mask"] = create_channel_mask(length,
                                                             max_len=z.size(1))

        ## Embed nodes
        z_nodes, ldj = self._run_layer(self.node_embed_flow,
                                       z,
                                       reverse=False,
                                       ldj=ldj,
                                       adjacency=adjacency,
                                       **kwargs)
        # Reference state for the reversibility comparison.
        z_nodes_embed = z_nodes
        ldj_embed = ldj
        ## Testing RGCN flows
        for flow in self.flow_layers[1:]:
            z_nodes, ldj = self._run_layer(flow,
                                           z_nodes,
                                           reverse=False,
                                           ldj=ldj,
                                           adjacency=adjacency,
                                           **kwargs)
        z_nodes_reversed, ldj_reversed = z_nodes, ldj
        for flow in reversed(self.flow_layers[1:]):
            z_nodes_reversed, ldj_reversed = self._run_layer(
                flow,
                z_nodes_reversed,
                reverse=True,
                ldj=ldj_reversed,
                adjacency=adjacency,
                **kwargs)
        reverse_succeeded = (
            (z_nodes_reversed - z_nodes_embed).abs() > 1e-2).sum() == 0 and (
                (ldj_reversed - ldj_embed).abs() > 1e-1).sum() == 0
        if not reverse_succeeded:
            print("[!] ERROR: Coupling layer with given adjacency matrix are not reversible. Max diffs:\n" + \
             "Nodes: %s\n" % str(torch.max((z_nodes_reversed - z_nodes_embed).abs())) + \
             "LDJ: %s\n" % str(torch.max((ldj_reversed - ldj_embed).abs())))
            # Probe each flow separately to locate the non-reversible layer.
            z_nodes = z_nodes_embed
            ldj = ldj_embed
            large_error = False
            for flow_index, flow in enumerate(self.flow_layers[1:]):
                z_nodes_forward, ldj_forward = self._run_layer(
                    flow,
                    z_nodes,
                    reverse=False,
                    ldj=ldj,
                    adjacency=adjacency,
                    **kwargs)
                z_nodes_backward, ldj_backward = self._run_layer(
                    flow,
                    z_nodes_forward,
                    reverse=True,
                    ldj=ldj_forward,
                    adjacency=adjacency,
                    **kwargs)

                node_diff = (z_nodes_backward - z_nodes).abs()
                ldj_diff = (ldj_backward - ldj).abs()
                max_node_diff = torch.max(node_diff)
                max_ldj_diff = torch.max(ldj_diff)
                mean_node_diff = torch.mean(node_diff)
                mean_ldj_diff = torch.mean(ldj_diff)
                print("Flow [%i]: %s" % (flow_index + 1, flow.info()))
                print("-> Max node diff: %s" % str(max_node_diff))
                print("-> Max ldj diff: %s" % str(max_ldj_diff))
                print("-> Mean node diff: %s" % str(mean_node_diff))
                print("-> Mean ldj diff: %s" % str(mean_ldj_diff))
                if max_node_diff > 1e-2:
                    batch_index = torch.argmax(ldj_diff).item()
                    print("-> Batch index", batch_index)
                    print("-> Nodes with max diff:")
                    print(node_diff[batch_index])
                    print(z_nodes_backward[batch_index])
                    print(z_nodes[batch_index])
                    print(z_nodes_forward[batch_index])
                    node_mask = (node_diff[batch_index] > 1e-4)
                    faulty_nodes = z_nodes_forward[batch_index].masked_select(
                        node_mask)
                    # Mismatches on small activations cannot be explained by
                    # out-of-range values in the mixture layer.
                    num_small_faulty_nodes = (faulty_nodes.abs() <
                                              1.0).sum().item()
                    large_error = large_error or (num_small_faulty_nodes > 0)

                z_nodes = z_nodes_forward
                ldj = ldj_forward
            if not large_error:
                print("-" * 50)
                print(
                    "Error probably caused by large values out of range in the mixture layer. Ignored for now."
                )
            return (not large_error)
        else:
            print("Reversibility test passed")
            return True