def embedding_network(state, mask, seed=123):
    """Build a block-wise invariant embedding followed by a dense head.

    The embedding part is organised in blocks of ``invariant_layer`` calls.
    Every block restarts from ``state``; after each block the result is
    reduced with ``mask_and_pool`` and that pooled value is handed as
    ``context`` to the first layer of the following block. The pooled output
    of the last block is then passed through ``linear`` layers with a ReLU
    activation.

    Args:
        state: Input elements fed to the first layer of every block.
        mask: Mask forwarded to ``mask_and_pool``.
        seed: Base seed; each invariant layer receives
            ``seed + block_index + layer_index``.

    Returns:
        A tuple ``(embedding, [])``; the second item is always an empty list
        standing in for the network parameters.
    """
    # Placeholder layer sizes: one inner list per embedding block.
    block_widths = [[64], [64, 128]]
    head_widths = [128]

    # Embedding part.
    pooled = None
    for block_idx, widths in enumerate(block_widths):
        elems = state  # every block starts again from the raw input
        for layer_idx, width in enumerate(widths):
            # Only the first layer of a non-first block receives the pooled
            # output of the previous block as context.
            if layer_idx == 0 and block_idx != 0:
                context = pooled
            else:
                context = None
            elems, _ = invariant_layer(
                elems,
                width,
                context=context,
                name='l{}_{}'.format(block_idx, layer_idx),
                seed=seed + block_idx + layer_idx)

        # Pool this block's output; it becomes the next block's context.
        pooled = mask_and_pool(elems, mask)

    # Fully connected part.
    hidden = pooled
    for head_idx, width in enumerate(head_widths):
        hidden, _, _ = linear(hidden,
                              width,
                              activation_fn=tf.nn.relu,
                              name='lO_{}'.format(head_idx))

    # Network output plus an (empty) parameter list.
    return hidden, []
def relation_network(state, mask, seed=123):
    """Build a stack of relation layers followed by a dense head.

    The relation layers are chained: each ``relation_layer`` consumes the
    output of the previous one (the first consumes ``state``). The final
    element representation is reduced with ``mask_and_pool`` and passed
    through ``linear`` layers.

    Args:
        state: Input elements fed to the first relation layer.
        mask: Mask forwarded to every ``relation_layer`` and to
            ``mask_and_pool``.
        seed: Accepted for signature compatibility; this function does not
            forward it to any layer.

    Returns:
        A tuple ``(embedding, [])``; the second item is always an empty list
        standing in for the network parameters.
    """
    # Placeholder layer sizes
    d_e = [64, 64, 64]
    d_o = [128, 128]

    # Embedding part.
    # Initialise ONCE, outside the loop: resetting to the raw input on every
    # iteration would feed each layer the original state and leave all but
    # the last relation layer disconnected from the output.
    el = state
    for i, layer in enumerate(d_e):
        el, _ = relation_layer(layer, el, mask, name='l' + str(i))

    # Pool the element representations into a single vector for the head.
    c = mask_and_pool(el, mask)

    # Fully connected part
    fc = c
    for i, layer in enumerate(d_o):
        fc, _, _ = linear(fc, layer, name='lO_' + str(i))

    # Returns the network output and parameters
    return fc, []
def embedding_network(state, mask):
    """Build a block-wise set embedding and return its pooled output.

    Each block applies ``set_layer`` repeatedly starting from ``state`` and
    is then reduced with ``mask_and_pool``; the pooled value of a block is
    given as ``context`` to the first layer of the next block. The pooled
    output of the last block is returned directly.

    NOTE(review): the ``mask`` argument is replaced by
    ``ops.get_mask(state)`` before it is used, so the value passed by the
    caller has no effect.
    NOTE(review): another function with this same name appears earlier in
    this source; if both live in one module, this later definition is the
    one that remains bound.

    Args:
        state: Input elements; also used to derive the mask.
        mask: Ignored (overwritten, see note above).

    Returns:
        The pooled embedding of the last block.
    """
    # Placeholder layer sizes: one inner list per embedding block.
    block_widths = [[128, 256]]
    d_o = [128]  # kept from the original; not used by this function

    # The mask is always derived from the state itself.
    mask = ops.get_mask(state)

    # Embedding part.
    pooled = None
    for block_idx, widths in enumerate(block_widths):
        elems = state  # every block starts again from the raw input
        for layer_idx, width in enumerate(widths):
            # Context is only supplied to the first layer of a later block.
            if layer_idx == 0 and block_idx != 0:
                context = pooled
            else:
                context = None
            elems = set_layer(elems,
                              width,
                              context=context,
                              name='l{}_{}'.format(block_idx, layer_idx))

        # Pool this block's output; it becomes the next block's context.
        pooled = mask_and_pool(elems, mask)

    return pooled
def object_embedding_network2(state, l_e, l_o):
    """Build an invariant-layer embedding with pooled-mean centring.

    The first ``invariant_layer`` is applied to ``state`` as is. Before every
    further invariant layer, the pooled value from ``mask_and_pool``
    (expanded along axis 1) is subtracted from the current elements. The
    final elements are pooled once more and passed through ``linear`` layers
    with a ReLU activation.

    Args:
        state: Input elements; the mask is derived from it via ``get_mask``.
        l_e: Sequence of invariant-layer sizes; must contain at least one
            entry because ``l_e[0]`` is read unconditionally.
        l_o: Sequence of sizes for the fully connected head (may be empty).

    Returns:
        A tuple ``(embedding, [])``; the second item is always an empty list.
    """
    mask = get_mask(state)

    # Embedding part: the first layer operates on the untouched input.
    elems, _ = invariant_layer(state, l_e[0], name='l0')

    for depth, width in enumerate(l_e[1:], start=1):
        # Subtract the pooled value from every element before the next layer.
        pooled = mask_and_pool(elems, mask)
        centred = elems - tf.expand_dims(pooled, axis=1)
        elems, _ = invariant_layer(centred, width, name='l{}'.format(depth))

    # Fully connected part, starting from the pooled element representation.
    hidden = mask_and_pool(elems, mask)
    for idx, width in enumerate(l_o):
        hidden, _, _ = linear(hidden,
                              width,
                              activation_fn=tf.nn.relu,
                              name='lO_{}'.format(idx))

    return hidden, []
def rav_layer(x, mask, out_size, **kwargs):
    """Subtract the pooled value from ``x`` and apply ``set_layer``.

    Args:
        x: Input elements.
        mask: Mask forwarded to ``mask_and_pool``.
        out_size: Layer size, forwarded to ``set_layer`` as its second
            positional argument.
        **kwargs: Extra keyword arguments passed through to ``set_layer``
            (for example ``name`` or ``context``).

    Returns:
        Whatever ``set_layer`` returns for the centred input.
    """
    # NOTE(review): other code in this source expands the pooled value with
    # tf.expand_dims(..., axis=1) before subtracting it; confirm that
    # mask_and_pool's output broadcasts against x as intended here.
    centred = x - mask_and_pool(x, mask)

    # out_size must be forwarded explicitly; it is a named parameter of this
    # function and therefore never ends up in **kwargs.
    return set_layer(centred, out_size, **kwargs)