Example #1
    def __init__(self, config):
        """Vanilla VAE model"""
        super(VAE, self).__init__()

        self.config = config

        if config.layer == 'mlp':
            # Encoder:  q(Z|X)
            self.encoder = layers.MLP(dims=[config.x_size, config.h_size])
            self.relu = nn.ReLU()
            self.to_mu = nn.Linear(config.h_size, config.z_size)
            self.to_sigma = nn.Linear(config.h_size, config.z_size)

            # Decoder: p(X|Z)
            self.decoder = layers.MLP(dims=[config.z_size, 200, config.x_size])
            self.sigmoid = nn.Sigmoid()

        elif config.layer == 'conv':
            # Encoder:  q(Z|X)
            self.encoder = layers.ConvEncoder()
            self.relu = nn.ReLU()
            self.to_mu = nn.Linear(32 * 5 * 5, config.z_size)
            self.to_sigma = nn.Linear(32 * 5 * 5, config.z_size)

            # Decoder: p(X|Z)
            self.fc1 = layers.MLP(dims=[config.z_size, 32 * 5 * 5, 1568])
            self.decoder = layers.ConvDecoder()
            self.sigmoid = nn.Sigmoid()
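
The forward logic is not shown; below is a minimal sketch of how these modules are typically wired together via the reparameterization trick, assuming the 'mlp' branch, that to_sigma predicts the log-variance, and that torch is imported (the method names are hypothetical):

    def encode(self, x):
        # q(Z|X): hidden features -> Gaussian posterior parameters
        h = self.relu(self.encoder(x))
        return self.to_mu(h), self.to_sigma(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I); keeps sampling differentiable
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        # p(X|Z): sigmoid output suits Bernoulli/pixel likelihoods
        return self.sigmoid(self.decoder(z)), mu, logvar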
Example #2
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 mlp_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim,
                              num_heads=num_heads,
                              qkv_bias=qkv_bias,
                              qk_scale=qk_scale,
                              attn_drop=attn_drop,
                              proj_drop=mlp_drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = xlayers.DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = xlayers.MLP(in_features=dim,
                               hidden_features=mlp_hidden_dim,
                               act_layer=act_layer,
                               drop=mlp_drop)
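
The matching forward pass is not part of the example; a sketch of the standard pre-norm wiring for this kind of block, assuming inputs of shape (batch, tokens, dim):

    def forward(self, x):
        # Pre-norm residual branches; drop_path can skip an entire branch
        # per sample (stochastic depth).
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x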
Example #3
    def __init__(self, n_features, n_heads, dropout, attn_dropout, ff_dropout, max_seq_len):
        super(TransformerBlock, self).__init__()

        self.ln_1 = nn.LayerNorm(n_features)
        self.attn = layers.MultiheadAttention(n_features, n_heads, attn_dropout, ff_dropout, max_seq_len)
        self.ln_2 = nn.LayerNorm(n_features)
        self.mlp = layers.MLP(4 * n_features, n_features, ff_dropout)
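
Again the forward is omitted; a plausible pre-LN sketch (any causal masking is assumed to happen inside layers.MultiheadAttention):

    def forward(self, x):
        # Residual attention followed by a residual feed-forward block,
        # both pre-normalized.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x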
Example #4
    def __init__(self, *args):
        sizes = it2list(args)
        it = iter(sizes)
        self.start = layers.Input(next(it))
        layer = self.start
        for size in it:
            layer.append(layers.MLP(size, activations.sigmoid))
            layer = layer.next_
        self.end = layer
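
A hypothetical usage sketch, assuming this __init__ belongs to a class named Network (the class name is not shown in the example) and that append/next_ link layers as above:

# Build a 784 -> 128 -> 10 chain of sigmoid MLP layers, then walk it.
net = Network(784, 128, 10)
layer = net.start
while layer is not net.end:
    layer = layer.next_
    print(type(layer).__name__)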
Example #5
def create_model(dtype=tf.float16, eval_mode=False):
    """
    creates the GNN model
    params:
    bs -- batch size
    dtype -- data type
    eval_mode: used for checking correctness of the model function using
      pretrained weights
    """
    # Inputs follow the graph tuple ordering:
    # nodes, edges, receivers, senders, node_graph_idx, edge_graph_idx

    bs = FLAGS.batch_size
    inputs_list = [
        tf.keras.Input((bs * FLAGS.n_nodes, NODE_FEATURE_DIMS),
                       dtype=dtype,
                       batch_size=1),
        tf.keras.Input((bs * FLAGS.n_edges, EDGE_FEATURE_DIMS),
                       dtype=dtype,
                       batch_size=1),
        tf.keras.Input((bs * FLAGS.n_edges, 1), dtype=tf.int32, batch_size=1),
        tf.keras.Input((bs * FLAGS.n_edges, 1), dtype=tf.int32, batch_size=1),
        # node graph idx
        tf.keras.Input((bs * FLAGS.n_nodes,), dtype=tf.int32, batch_size=1),
        # edge graph idx
        tf.keras.Input((bs * FLAGS.n_edges,), dtype=tf.int32, batch_size=1),
    ]

    inputs_list_squeezed = [tf.squeeze(_input) for _input in inputs_list]

    if FLAGS.model == 'graph_network':
        x = layers.EncoderLayer(
            edge_model_fn=lambda: get_default_mlp(activate_final=False,
                                                  name='edge_encoder'),
            node_model_fn=lambda: get_default_mlp(activate_final=False,
                                                  name='node_encoder'),
        )(inputs_list_squeezed)
        for i in range(FLAGS.n_graph_layers):
            x = layers.GraphNetworkLayer(
                edge_model_fn=lambda: get_default_mlp(activate_final=False,
                                                      name='edge'),
                node_model_fn=lambda: get_default_mlp(activate_final=False,
                                                      name='node'),
                global_model_fn=lambda: get_default_mlp(activate_final=False,
                                                        name='global'),
                # eval mode -- load all the weights, including the redundant last global layer
                nodes_dropout=FLAGS.nodes_dropout,
                edges_dropout=FLAGS.edges_dropout,
                globals_dropout=FLAGS.globals_dropout)(x)
        output_logits = layers.DecoderLayer(
            global_model_fn=lambda: layers.MLP(n_layers=3,
                                               n_hidden=FLAGS.n_hidden,
                                               n_out=1,
                                               activate_final=False,
                                               name='output_logits'))(x)
    elif FLAGS.model == 'interaction_network':
        x = layers.EncoderLayer(
            edge_model_fn=lambda: get_default_mlp(activate_final=False,
                                                  name='edge_encoder'),
            node_model_fn=lambda: get_default_mlp(activate_final=False,
                                                  name='node_encoder'),
        )(inputs_list_squeezed)
        for i in range(FLAGS.n_graph_layers):
            x = layers.InteractionNetworkLayer(
                edge_model_fn=lambda: get_default_mlp(activate_final=False,
                                                      name='edge'),
                node_model_fn=lambda: get_default_mlp(activate_final=False,
                                                      name='node'),
                nodes_dropout=FLAGS.nodes_dropout,
                edges_dropout=FLAGS.edges_dropout,
            )(x)
        output_logits = layers.DecoderLayer(
            global_model_fn=lambda: layers.MLP(n_layers=3,
                                               n_hidden=FLAGS.n_hidden,
                                               n_out=1,
                                               activate_final=False,
                                               name='output_logits'))(x)
    elif FLAGS.model == 'graph_isomorphism':
        graph_tuple = layers.GinEncoderLayer()(inputs_list_squeezed)
        for i in range(FLAGS.n_graph_layers):
            graph_tuple = layers.GraphIsomorphismLayer(
                # final layer before output decoder is NOT activated
                get_mlp=lambda: get_default_mlp(
                    name='GIN_mlp',
                    activate_final=i < FLAGS.n_graph_layers - 1),
                use_edges=FLAGS.use_edges,
                # edge embedding dimensionality must match the input to the layer
                edge_dim=(FLAGS.n_latent
                          if i > 0 else FLAGS.n_embedding_channels),
                dropout=FLAGS.nodes_dropout)(graph_tuple)
        output_prob = layers.GinDecoderLayer()(graph_tuple)
        return tf.keras.Model(inputs_list, output_prob)
    else:
        # without this guard, an unknown model flag would hit a NameError below
        raise ValueError(f'Unknown FLAGS.model: {FLAGS.model}')

    # dummy dim needed -- see
    # https://www.tensorflow.org/tutorials/distribute/custom_training#define_the_loss_function
    output_prob = tf.reshape(tf.nn.sigmoid(output_logits), [-1, 1])
    model = tf.keras.Model(inputs_list, output_prob)
    return model
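
A usage sketch, assuming the absl-style FLAGS referenced throughout have already been parsed:

# Hypothetical: build and compile the model for binary graph classification.
model = create_model(dtype=tf.float32)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss=tf.keras.losses.BinaryCrossentropy())
model.summary()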
Example #6
def get_default_mlp(activate_final, name=None):
    return layers.MLP(n_layers=FLAGS.n_mlp_layers,
                      n_hidden=FLAGS.n_hidden,
                      n_out=FLAGS.n_latent,
                      activate_final=activate_final,
                      name=name)