import tensorflow as tf

# invariant_layer, relation_layer, set_layer, mask_and_pool, linear and
# get_mask are helper ops assumed to be defined elsewhere in this repo.


def embedding_network(state, mask, seed=123):
    # Placeholder layer sizes
    d_e = [[64], [64, 128]]
    d_o = [128]

    # Build graph:
    initial_elems = state

    # Embedding Part
    for i, block in enumerate(d_e):
        el = initial_elems
        for j, layer in enumerate(block):
            # Only the first layer of each block after the first receives the
            # pooled context from the previous block.
            context = c if j == 0 and not i == 0 else None
            el, _ = invariant_layer(el, layer, context=context,
                                    name='l' + str(i) + '_' + str(j),
                                    seed=seed + i + j)
        c = mask_and_pool(el, mask)  # pool to get context for next block

    # Fully connected part
    fc = c
    for i, layer in enumerate(d_o):
        fc, _, _ = linear(fc, layer, activation_fn=tf.nn.relu, name='lO_' + str(i))

    # Output
    embedding = fc

    # Returns the network output and parameters
    return embedding, []
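
# Illustrative sketch: mask_and_pool is used above but defined elsewhere in the
# repo. A masked max pool like the one below is one plausible implementation;
# the name masked_max_pool and the [batch, n, 1] mask shape are assumptions.
def masked_max_pool(elements, mask):
    # elements: [batch, n, d] padded set; mask: 1.0 for real rows, 0.0 for padding
    masked = elements * mask + (1.0 - mask) * -1e9  # padded rows can never win the max
    return tf.reduce_max(masked, axis=1)            # -> [batch, d]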
def relation_network(state, mask, seed=123):
    # Placeholder layer sizes
    d_e = [64, 64, 64]
    d_o = [128, 128]

    # Build graph:
    initial_elems = state

    # Embedding Part
    for i, layer in enumerate(d_e):
        el = initial_elems
        el, _ = relation_layer(layer, el, mask, name='l' + str(i))
        c = mask_and_pool(el, mask)  # pool to get context for next block

    # Fully connected part
    fc = c
    for i, layer in enumerate(d_o):
        fc, _, _ = linear(fc, layer, name='lO_' + str(i))

    # Output
    embedding = fc

    # Returns the network output and parameters
    return embedding, []
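
# Illustrative sketch: relation_layer is defined elsewhere in the repo. A masked
# pairwise layer in the Relation Networks style is one plausible form; the name,
# shapes and the use of tf.layers.dense in place of the repo's linear helper are
# all assumptions, not the repo's actual implementation.
def pairwise_relation_sketch(elements, mask, out_size, name='relation'):
    # elements: [batch, n, d]; mask: [batch, n, 1]
    n = tf.shape(elements)[1]
    x_i = tf.tile(tf.expand_dims(elements, 2), tf.stack([1, 1, n, 1]))  # [batch, n, n, d]
    x_j = tf.tile(tf.expand_dims(elements, 1), tf.stack([1, n, 1, 1]))  # [batch, n, n, d]
    pairs = tf.concat([x_i, x_j], axis=-1)                              # all ordered pairs
    g = tf.layers.dense(pairs, out_size, activation=tf.nn.relu, name=name)
    pair_mask = tf.expand_dims(mask, 1) * tf.expand_dims(mask, 2)       # zero out padded pairs
    g = g * pair_mask
    denom = tf.maximum(tf.reduce_sum(pair_mask, axis=2), 1.0)
    return tf.reduce_sum(g, axis=2) / denom                             # mean over partners j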
def embedding_network(state, mask):
    # Placeholder layer sizes
    d_e = [[128, 256]]
    d_o = [128]

    # Build graph:
    initial_elems = state

    # Get mask
    mask = ops.get_mask(state)

    # Embedding Part
    for i, block in enumerate(d_e):
        el = initial_elems
        for j, layer in enumerate(block):
            context = c if j == 0 and not i == 0 else None
            el = set_layer(el, layer, context=context, name='l' + str(i) + '_' + str(j))
        c = mask_and_pool(el, mask)  # pool to get context for next block

    # Output
    embedding = c

    return embedding
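
# Illustrative sketch: ops.get_mask / get_mask are defined elsewhere in the repo.
# One plausible implementation, assuming padded object rows are all-zero vectors
# (the name get_mask_sketch is hypothetical):
def get_mask_sketch(state):
    # state: [batch, n, d]; a row counts as a real object if any feature is non-zero
    return tf.cast(tf.reduce_sum(tf.abs(state), axis=-1, keepdims=True) > 0, tf.float32)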
def object_embedding_network2(state, l_e, l_o):
    mask = get_mask(state)

    # Embedding Part
    el = state
    el, _ = invariant_layer(el, l_e[0], name='l' + str(0))
    for i, l in enumerate(l_e[1:]):
        # Subtract the pooled set summary from every element before the next layer
        el = el - tf.expand_dims(mask_and_pool(el, mask), axis=1)
        el, _ = invariant_layer(el, l, name='l' + str(i + 1))
    c = mask_and_pool(el, mask)

    # Fully connected part
    fc = c
    for i, layer in enumerate(l_o):
        fc, _, _ = linear(fc, layer, activation_fn=tf.nn.relu, name='lO_' + str(i))

    return fc, []
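
# Hypothetical usage sketch for object_embedding_network2; the placeholder
# shapes and layer sizes below are made up, and zero-padded object rows are
# assumed so get_mask can recover the mask.
state_ph = tf.placeholder(tf.float32, [None, 20, 8])  # up to 20 objects, 8 features each
embedding, _ = object_embedding_network2(state_ph, l_e=[64, 64], l_o=[128])  # -> [batch, 128]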
def rav_layer(x, mask, out_size, **kwargs):
    # Centre each element on the pooled set summary, then apply a set_layer of
    # the requested size. The expand_dims lets the [batch, d] pooled vector
    # broadcast over the element axis, as in object_embedding_network2 above.
    x = x - tf.expand_dims(mask_and_pool(x, mask), axis=1)
    out = set_layer(x, out_size, **kwargs)
    return out
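
# Hypothetical usage sketch for rav_layer: stack two layers and pool into a set
# embedding. The placeholder shapes, layer sizes and names are made up, and the
# keyword pass-through to set_layer's name argument is an assumption.
obs = tf.placeholder(tf.float32, [None, 20, 8])
obs_mask = get_mask(obs)
el = rav_layer(obs, obs_mask, 64, name='rav_0')
el = rav_layer(el, obs_mask, 64, name='rav_1')
set_embedding = mask_and_pool(el, obs_mask)  # -> [batch, 64], assuming pooling over elements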