def concat_entity_masks(inps, masks):
    '''
        Concats masks together along the last (entity) dimension. If a mask is None,
            a tensor of 1's is created for it from the shape of the matching input, so
            the result has shape (BS, T, NE).
        Args:
            inps (list of tensors): tensors that masks apply to. A rank 4 input is
                treated as an entity tensor; a rank 3 input is treated as a single
                entity (pooled or main tensor).
            masks (list of tensors): corresponding masks. Entries may be None.
        Returns:
            The given and generated masks concatenated along the last axis.
        Raises:
            AssertionError: inps and masks have different lengths.
            ValueError: a mask is None and its input is neither rank 3 nor rank 4.
    '''
    assert len(inps) == len(
        masks), "There should be the same number of inputs as masks"
    with tf.variable_scope('concat_masks'):
        new_masks = []
        for inp, mask in zip(inps, masks):
            if mask is not None:
                new_masks.append(mask)
                continue

            inp_shape = shape_list(inp)
            if len(inp_shape) == 4:  # this is an entity tensor
                new_masks.append(tf.ones(inp_shape[:3]))
            elif len(inp_shape) == 3:
                # this is a pooled or main tensor. Set NE (outer dimension) to 1
                new_masks.append(tf.ones(inp_shape[:2] + [1]))
            else:
                # Skipping the input here would silently produce a mask that is
                # narrower than the concatenated entities, so fail loudly instead.
                raise ValueError(
                    f"Cannot build a default mask for an input of rank {len(inp_shape)}; "
                    f"expected rank 3 or 4. Input shape: {inp_shape}")
        new_mask = tf.concat(new_masks, -1)
    return new_mask
def qkv_embed(inp,
              heads,
              n_embd,
              layer_norm=False,
              qk_w=1.0,
              v_w=0.01,
              reuse=False):
    '''
        Project the input into per-head queries, keys, and values.
        Args:
            inp (tf) -- tensor w/ shape (bs, T, NE, features)
            heads (int) -- number of attention heads
            n_embd (int) -- total embedding size; each head gets n_embd // heads
            layer_norm (bool) -- layer-normalize inp over its last axis before projecting
            qk_w (float) -- init scale for the shared query/key projection. The
                initializer stddev is sqrt(qk_w / #input features)
            v_w (float) -- init scale for the value projection. The initializer
                stddev is sqrt(v_w / #input features)
            reuse (bool) -- tf reuse
        Returns:
            query (bs, T, heads, NE, n_embd // heads)
            key   (bs, T, heads, n_embd // heads, NE)  -- already transposed for matmul
            value (bs, T, heads, NE, n_embd // heads)
    '''
    with tf.variable_scope('qkv_embed'):
        bs, T, NE, features = shape_list(inp)
        head_dim = n_embd // heads

        if layer_norm:
            with tf.variable_scope('pre_sa_layer_norm'):
                inp = tf.contrib.layers.layer_norm(inp, begin_norm_axis=3)

        # Queries and keys come out of a single dense layer: (bs, T, NE, 2 * n_embd)
        qk_init = tf.random_normal_initializer(
            stddev=np.sqrt(qk_w / features))
        qk = tf.layers.dense(inp,
                             n_embd * 2,
                             kernel_initializer=qk_init,
                             reuse=reuse,
                             name="qk_embed")
        # Split the projection into heads; the trailing axis of size 2 separates q from k
        qk = tf.reshape(qk, (bs, T, NE, heads, head_dim, 2))
        q_part, k_part = tf.split(qk, 2, -1)
        query = tf.squeeze(q_part, -1)  # (bs, T, NE, heads, head_dim)
        key = tf.squeeze(k_part, -1)  # (bs, T, NE, heads, head_dim)

        # Values get their own dense layer: (bs, T, NE, n_embd)
        v_init = tf.random_normal_initializer(stddev=np.sqrt(v_w / features))
        value = tf.layers.dense(inp,
                                n_embd,
                                kernel_initializer=v_init,
                                reuse=reuse,
                                name="v_embed")
        value = tf.reshape(value, (bs, T, NE, heads, head_dim))

        # Move the head axis in front of the entity axis
        query = tf.transpose(query, (0, 1, 3, 2, 4), name="transpose_query")
        key = tf.transpose(key, (0, 1, 3, 4, 2), name="transpose_key")
        value = tf.transpose(value, (0, 1, 3, 2, 4), name="transpose_value")

    return query, key, value
def entity_concat(inps):
    '''
        Join tensors along the entity (second to last) dimension.
            4D tensors are used as they are; any other tensor is treated as a
            single entity and gets a new axis inserted at position 2 first.
        Args:
            inps (list of tensors): tensors to concatenate
        Returns:
            The concatenation of all (expanded) inputs along axis -2.
    '''
    with tf.variable_scope('concat_entities'):
        ranks = [len(shape_list(tensor)) for tensor in inps]

        # Give inputs without an entity dimension one of size 1.
        expanded = []
        for tensor, rank in zip(inps, ranks):
            if rank == 4:
                expanded.append(tensor)
            else:
                expanded.append(tf.expand_dims(tensor, 2))

        # Every input must agree on the feature (last) dimension.
        shapes = [shape_list(tensor) for tensor in expanded]
        assert np.all([dims[-1] == shapes[0][-1] for dims in shapes]),\
            f"Some entities don't have the same outer or inner dimensions {shapes}"

        out = tf.concat(expanded, -2)
    return out
def circ_conv1d(inp, **conv_kwargs):
    '''
        Circular 1d convolution over the entity (second to last) dimension.
            The input is wrapped around on both ends so that a 'valid' convolution
            returns as many positions as went in.
        Args:
            inp (tf) -- tensor w/ shape (BS, T, NE, feats)
            conv_kwargs -- keyword arguments forwarded to tf.layers.conv1d.
                'kernel_size' (int) is required.
                'filters' (int) is required; it is also used for the output shape.
                'activation' is one of 'relu', 'tanh', or '' for no activation.
                    Defaults to '' when omitted.
        Returns:
            tensor w/ shape (BS, T, NE, filters)
        Raises:
            AssertionError: 'kernel_size' was not given.
            KeyError: 'activation' is not one of the supported names, or 'filters'
                is missing.
    '''
    assert 'kernel_size' in conv_kwargs, "Kernel size needs to be specified for circular convolution layer."
    valid_activations = {'relu': tf.nn.relu, 'tanh': tf.tanh, '': None}
    conv_kwargs['activation'] = valid_activations[conv_kwargs.get(
        'activation', '')]

    # A 'valid' convolution drops kernel_size - 1 positions, so that many wrapped
    # positions are added in total. For odd kernels this is kernel_size // 2 per
    # side; for even kernels the extra position goes on the right.
    kernel_size = conv_kwargs['kernel_size']
    pad_left = (kernel_size - 1) // 2
    pad_right = kernel_size // 2

    inp_shape = shape_list(inp)
    inp_rs = tf.reshape(inp,
                        shape=[inp_shape[0] * inp_shape[1]] +
                        inp_shape[2:])  #  (BS * T, NE, feats)

    # concatenate input for circular convolution. A zero-width pad must be skipped
    # explicitly: the slice [-0:] selects the whole axis rather than nothing.
    pieces = []
    if pad_left > 0:
        pieces.append(inp_rs[..., -pad_left:, :])
    pieces.append(inp_rs)
    if pad_right > 0:
        pieces.append(inp_rs[..., :pad_right, :])
    inp_padded = tf.concat(pieces, -2) if len(pieces) > 1 else inp_rs

    out = tf.layers.conv1d(
        inp_padded,
        kernel_initializer=tf.contrib.layers.xavier_initializer(),
        padding='valid',
        **conv_kwargs)

    out = tf.reshape(out, shape=inp_shape[:3] + [conv_kwargs['filters']])
    return out
def self_attention(inp,
                   mask,
                   heads,
                   n_embd,
                   layer_norm=False,
                   qk_w=1.0,
                   v_w=0.01,
                   scope='',
                   reuse=False):
    '''
        Multi-head self attention over the entity dimension.
        Notation:
            T  - Time
            NE - Number entities
        Args:
            inp (tf) -- tensor w/ shape (bs, T, NE, features)
            mask (tf) -- tensor whose static shape equals the first 3 dimensions of
                            inp, i.e. (bs, T, NE), or None for no masking
            heads (int) -- number of attention heads
            n_embd (int) -- dimension of queries, keys, and values will be n_embd / heads
            layer_norm (bool) -- normalize embedding prior to computing qkv
            qk_w, v_w (float) -- init scales passed on to qkv_embed for the
                key/query and value projections
            scope (string) -- tf scope
            reuse (bool) -- tf reuse
        Returns:
            tensor w/ shape (bs, T, NE, n_embd)
    '''
    with tf.variable_scope(scope, reuse=reuse):
        bs, T, _, _ = shape_list(inp)

        # Give the mask an extra axis so it lines up with the logit matrix
        if mask is not None:
            with tf.variable_scope('expand_mask'):
                assert np.all(
                    np.array(mask.get_shape().as_list()) ==
                    np.array(inp.get_shape().as_list()[:3])
                ), f"Mask and input should have the same first 3 dimensions. {shape_list(mask)} -- {shape_list(inp)}"
                mask = tf.expand_dims(mask, -2)  # (BS, T, 1, NE)

        query, key, value = qkv_embed(inp,
                                      heads,
                                      n_embd,
                                      layer_norm=layer_norm,
                                      qk_w=qk_w,
                                      v_w=v_w,
                                      reuse=reuse)

        # Scaled dot-product scores between every pair of entities
        scores = tf.matmul(query, key,
                           name="matmul_qk_parallel")  # (bs, T, heads, NE, NE)
        scores = scores / np.sqrt(n_embd / heads)
        weights = stable_masked_softmax(scores, mask)

        # Weighted sum of values per head: (bs, T, heads, NE, n_embd / heads)
        attended = tf.matmul(weights, value, name="matmul_softmax_value")

        with tf.variable_scope('flatten_heads'):
            # Bring entities back in front of heads, then merge the heads
            per_entity = tf.transpose(attended, (0, 1, 3, 2, 4))
            n_output_entities = shape_list(per_entity)[2]
            out = tf.reshape(per_entity,
                             (bs, T, n_output_entities, n_embd))

        return out
Example #6
0
def construct_tf_graph(
    all_inputs,
    spec,
    act,
    scope='',
    reuse=False,
):
    '''
        Construct tensorflow graph from spec. See mas/ppo/base-architectures.jsonnet for examples.
        Args:
            all_inputs (dict of tf) -- input activations keyed by node name. Must contain
                'observation_self', which becomes the initial 'main' node. LSTM layers
                additionally read their initial state from the keys
                scope + '_lstm{i}_state_c' and scope + '_lstm{i}_state_h', where i is the
                index of the layer in spec. The dict itself is not modified.
            spec (list of dicts) -- network specification. see Usage below
            act -- not used by this function
            scope (string) -- tf variable scope
            reuse (bool) -- tensorflow reuse flag
        Returns:
            inp (dict of tf) -- copy of all_inputs plus 'main' and every node written by a layer
            state_variables (OrderedDict) -- LSTM output states, keyed like the input state keys
            reset_ops (list) -- always empty in this function
        Errors:
            Any exception raised while building a layer is printed to stdout together with
            the layer spec, and the process then exits with status 1.
        Usage:
            Each layer spec has optional arguments: nodes_in and nodes_out. If these arguments
                are omitted, then the default in and out nodes will be 'main'. For layers such as
                concatenation, these arguments must be specified.
            Each layer spec may also give 'scope' (an extra variable scope wrapped around the
                layer) and 'layer_name' (overrides the default name, which is derived from the
                layer type and index).
            Dense layer (MLP) --
            {
                'layer_type': 'dense'
                'units': int (number of neurons)
                'activation': 'relu', 'tanh', or '' for no activation
            }
            LSTM layer --
            {
                'layer_type': 'lstm'
                'units': int (hidden state size)
            }
            Concat layer --
            Two use cases.
                First: the first input has one less dimension than the second input. In this case,
                    broadcast the first input along the second to last dimension and concatenated
                    along last dimension
                Second: Both inputs have the same dimension, and will be concatenated along last
                    dimension
            {
                'layer_type': 'concat'
                'nodes_in': ['node_one', 'node_two']
                'nodes_out': ['node_out']
            }
            Entity Concat Layer --
            Concatenate along entity dimension (second to last)
            {
                'layer_type': 'entity_concat'
                'nodes_in': ['node_one', 'node_two']
                'nodes_out': ['node_out']
                'masks_in': (optional) list of mask node names, None for inputs without a mask
                'mask_out': node name for the concatenated mask (read when 'masks_in' is given)
            }
            Entity Self Attention --
            Self attention over entity dimension (second to last)
            See policy.utils:residual_sa_block for args
            {
                'layer_type': 'residual_sa_block'
                'nodes_in': ['node_one']
                'nodes_out': ['node_out']
                'mask': (optional) name of the mask node
                ...
            }
            Entity Pooling --
            Pooling along entity dimension (second to last)
            {
                'layer_type': 'entity_pooling'
                'nodes_in': ['node_one', 'node_two']
                'nodes_out': ['node_out']
                'type': (optional string, default 'avg_pooling') type of pooling
                         Current options are 'avg_pooling' and 'max_pooling'
                'mask': (optional) name of the mask node
            }
            Circular 1d convolution layer (second to last dimension) --
            {
                'layer_type': 'circ_conv1d',
                'filters': number of filters
                'kernel_size': kernel size
                'activation': 'relu', 'tanh', or '' for no activation
            }
            Flatten outer dimension --
            Flatten all dimensions higher or equal to 3 (necessary after conv layer)
            {
                'layer_type': 'flatten_outer',
            }
            Layernorm --
            Layer normalization over all dimensions from the third one on
            {
                'layer_type': 'layernorm',
            }
    '''
    # Make a new dict to not overwrite input
    inp = {k: v for k, v in all_inputs.items()}
    inp['main'] = inp['observation_self']

    valid_activations = {'relu': tf.nn.relu, 'tanh': tf.tanh, '': None}
    state_variables = OrderedDict()
    logger.info(f"Spec:\n{spec}")
    # NOTE(review): entity_locations is filled in by entity_concat layers but is neither
    # returned nor read anywhere in this function.
    entity_locations = {}
    reset_ops = []
    with tf.variable_scope(scope, reuse=reuse):
        for i, layer in enumerate(spec):
            # Reset per-layer bookkeeping so the error report below never reads an
            # unbound name or a stale value left over from the previous layer.
            layer_type = nodes_in = nodes_out = None
            try:
                layer = deepcopy(layer)
                layer_type = layer.pop('layer_type')
                extra_layer_scope = layer.pop('scope', '')
                nodes_in = layer.pop('nodes_in', ['main'])
                nodes_out = layer.pop('nodes_out', ['main'])
                with tf.variable_scope(extra_layer_scope, reuse=reuse):
                    if layer_type == 'dense':
                        assert len(nodes_in) == len(
                            nodes_out
                        ), f"Dense layer must have same number of nodes in as nodes out. \
                            Nodes in: {nodes_in}, Nodes out {nodes_out}"

                        layer['activation'] = valid_activations[
                            layer['activation']]
                        layer_name = layer.pop('layer_name', f'dense{i}')
                        # One independent dense layer per (node in, node out) pair
                        for j in range(len(nodes_in)):
                            inp[nodes_out[j]] = tf.layers.dense(
                                inp[nodes_in[j]],
                                name=f'{layer_name}-{j}',
                                kernel_initializer=tf.contrib.layers.
                                xavier_initializer(),
                                reuse=reuse,
                                **layer)
                    elif layer_type == 'lstm':
                        layer_name = layer.pop('layer_name', f'lstm{i}')
                        with tf.variable_scope(layer_name, reuse=reuse):
                            assert len(nodes_in) == len(nodes_out) == 1
                            cell = tf.contrib.rnn.BasicLSTMCell(layer['units'])
                            initial_state = tf.contrib.rnn.LSTMStateTuple(
                                inp[scope + f'_lstm{i}_state_c'],
                                inp[scope + f'_lstm{i}_state_h'])
                            inp[nodes_out[0]], state_out = tf.nn.dynamic_rnn(
                                cell,
                                inp[nodes_in[0]],
                                initial_state=initial_state)
                            # Expose the final state under the same keys the initial
                            # state was read from
                            state_variables[scope +
                                            f'_lstm{i}_state_c'] = state_out.c
                            state_variables[scope +
                                            f'_lstm{i}_state_h'] = state_out.h
                    elif layer_type == 'concat':
                        layer_name = layer.pop('layer_name', f'concat{i}')
                        with tf.variable_scope(layer_name):
                            assert len(
                                nodes_out
                            ) == 1, f"Concat op must only have one node out. Nodes Out: {nodes_out}"
                            assert len(
                                nodes_in
                            ) == 2, f"Concat op must have two nodes in. Nodes In: {nodes_in}"
                            assert (len(shape_list(inp[nodes_in[0]])) == len(shape_list(inp[nodes_in[1]])) or
                                    len(shape_list(inp[nodes_in[0]])) == len(shape_list(inp[nodes_in[1]])) - 1),\
                                f"shapes were {nodes_in[0]}:{shape_list(inp[nodes_in[0]])}, {nodes_in[1]}:{shape_list(inp[nodes_in[1]])}"

                            inp0, inp1 = inp[nodes_in[0]], inp[nodes_in[1]]
                            # tile inp0 along second to last dimension to match inp1
                            if len(shape_list(inp[nodes_in[0]])) == len(
                                    shape_list(inp1)) - 1:
                                inp0 = tf.expand_dims(inp[nodes_in[0]], -2)
                                tile_dims = [1] * len(shape_list(inp0))
                                tile_dims[-2] = shape_list(inp1)[-2]
                                inp0 = tf.tile(inp0, tile_dims)
                            inp[nodes_out[0]] = tf.concat([inp0, inp1], -1)
                    elif layer_type == 'entity_concat':
                        layer_name = layer.pop('layer_name',
                                               f'entity-concat{i}')
                        with tf.variable_scope(layer_name):
                            ec_inps = [inp[node_in] for node_in in nodes_in]
                            inp[nodes_out[0]] = entity_concat(ec_inps)
                            if "masks_in" in layer:
                                masks_in = [
                                    inp[_m] if _m is not None else None
                                    for _m in layer["masks_in"]
                                ]
                                inp[layer["mask_out"]] = concat_entity_masks(
                                    ec_inps, masks_in)
                            # Store where the entities are. We'll store with key nodes_out[0]
                            _ent_locs = {}
                            loc = 0
                            for node_in in nodes_in:
                                shape_in = shape_list(inp[node_in])
                                n_ent = shape_in[2] if len(
                                    shape_in) == 4 else 1
                                _ent_locs[node_in] = slice(loc, loc + n_ent)
                                loc += n_ent
                            entity_locations[nodes_out[0]] = _ent_locs
                    elif layer_type == 'residual_sa_block':
                        layer_name = layer.pop('layer_name',
                                               f'self-attention{i}')
                        with tf.variable_scope(layer_name):
                            assert len(
                                nodes_in
                            ) == 1, "self attention should only have one input"
                            sa_inp = inp[nodes_in[0]]

                            mask = inp[layer.pop(
                                'mask')] if 'mask' in layer else None
                            internal_layer_name = layer.pop(
                                'internal_layer_name', f'residual_sa_block{i}')
                            # Everything still left in the layer spec is forwarded
                            # to residual_sa_block as keyword arguments
                            inp[nodes_out[0]] = residual_sa_block(
                                sa_inp,
                                mask,
                                **layer,
                                scope=internal_layer_name,
                                reuse=reuse)
                    elif layer_type == 'entity_pooling':
                        pool_type = layer.get('type', 'avg_pooling')
                        assert pool_type in ['avg_pooling', 'max_pooling'
                                             ], f"Pooling type {pool_type} \
                            not available. Pooling type must be either 'avg_pooling' or 'max_pooling'."

                        layer_name = layer.pop(
                            'layer_name', f'entity-{pool_type}-pooling{i}')
                        with tf.variable_scope(layer_name):
                            if 'mask' in layer:
                                mask = inp[layer.pop('mask')]
                                assert mask.get_shape()[-1] == inp[nodes_in[0]].get_shape()[-2], \
                                    f"Outer dim of mask must match second to last dim of input. \
                                     Mask shape: {mask.get_shape()}. Input shape: {inp[nodes_in[0]].get_shape()}"

                                if pool_type == 'avg_pooling':
                                    inp[nodes_out[
                                        0]] = entity_avg_pooling_masked(
                                            inp[nodes_in[0]], mask)
                                elif pool_type == 'max_pooling':
                                    inp[nodes_out[
                                        0]] = entity_max_pooling_masked(
                                            inp[nodes_in[0]], mask)
                            else:
                                if pool_type == 'avg_pooling':
                                    inp[nodes_out[0]] = tf.reduce_mean(
                                        inp[nodes_in[0]], -2)
                                elif pool_type == 'max_pooling':
                                    inp[nodes_out[0]] = tf.reduce_max(
                                        inp[nodes_in[0]], -2)
                    elif layer_type == 'circ_conv1d':
                        assert len(nodes_in) == len(
                            nodes_out
                        ) == 1, f"Circular convolution layer must have one nodes and one nodes out. \
                            Nodes in: {nodes_in}, Nodes out {nodes_out}"

                        layer_name = layer.pop('layer_name', f'circ_conv1d{i}')
                        with tf.variable_scope(layer_name, reuse=reuse):
                            inp[nodes_out[0]] = circ_conv1d(
                                inp[nodes_in[0]], **layer)
                    elif layer_type == 'flatten_outer':
                        layer_name = layer.pop('layer_name',
                                               f'flatten_outer{i}')
                        with tf.variable_scope(layer_name, reuse=reuse):
                            # flatten all dimensions higher or equal to 3
                            inp0 = inp[nodes_in[0]]
                            inp0_shape = shape_list(inp0)
                            inp[nodes_out[0]] = tf.reshape(
                                inp0,
                                shape=inp0_shape[0:2] +
                                [np.prod(inp0_shape[2:])])
                    elif layer_type == "layernorm":
                        layer_name = layer.pop('layer_name', f'layernorm{i}')
                        with tf.variable_scope(layer_name, reuse=reuse):
                            inp[nodes_out[0]] = tf.contrib.layers.layer_norm(
                                inp[nodes_in[0]], begin_norm_axis=2)
                    else:
                        raise NotImplementedError(
                            f"Layer type -- {layer_type} -- not yet implemented"
                        )
            except Exception:
                # Report which layer failed, then abort. A bare sys.exit() would end
                # the process with status 0, which looks like success to whatever
                # launched it, so exit with a failure status instead.
                traceback.print_exc(file=sys.stdout)
                print(
                    f"Error in {layer_type} layer: \n{layer}\nNodes in: {nodes_in}, Nodes out: {nodes_out}"
                )
                sys.exit(1)

    return inp, state_variables, reset_ops
Example #7
0
    def _init_policy_out(self, pi, taken_actions):
        '''
            Build the action distributions and the tensors derived from them.
            Sets self.pdparams, self.pds, self.sampled_action, self.sampled_action_logp,
                self.entropy and self.taken_action_logp.
            Args:
                pi (dict of tf) -- network outputs. 'main' feeds the dense heads; an entry
                    stored under an action key is used directly as that action's
                    parameters (entity specific actions)
                taken_actions (dict of tf) -- taken action for every key of self.pdtypes
        '''
        with tf.variable_scope('policy_out'):
            self.pdparams = {}
            for ac_key, pdtype in self.pdtypes.items():
                with tf.variable_scope(ac_key):
                    if self.gaussian_fixed_var and isinstance(
                            self.ac_space.spaces[ac_key], gym.spaces.Box):
                        # Mean comes from the network; log-std is a free variable
                        # broadcast to the shape of the mean
                        n_mean = pdtype.param_shape()[0] // 2
                        mean = tf.layers.dense(
                            pi["main"],
                            n_mean,
                            kernel_initializer=normc_initializer(0.01),
                            activation=None)
                        logstd = tf.get_variable(
                            name="logstd",
                            shape=[1, n_mean],
                            initializer=tf.zeros_initializer())
                        self.pdparams[ac_key] = tf.concat(
                            [mean, mean * 0.0 + logstd], axis=2)
                    elif ac_key in pi:
                        # This is just for the case of entity specific actions
                        space = self.ac_space.spaces[ac_key]
                        head = pi[ac_key]
                        if isinstance(space, gym.spaces.Discrete):
                            # Drop the trailing axis of size 1
                            assert head.get_shape()[-1] == 1
                            self.pdparams[ac_key] = head[..., 0]
                        elif isinstance(space, gym.spaces.MultiDiscrete):
                            # Merge the last two axes into one flat parameter axis
                            assert np.prod(head.get_shape()[-2:]) == pdtype.param_shape()[0],\
                                f"policy had shape {head.get_shape()} for action {ac_key}, but required {pdtype.param_shape()}"
                            leading_dims = shape_list(head)[:-2]
                            flat_dim = np.prod(head.get_shape()[-2:]).value
                            self.pdparams[ac_key] = tf.reshape(
                                head, shape=leading_dims + [flat_dim])
                        else:
                            assert False
                    else:
                        self.pdparams[ac_key] = tf.layers.dense(
                            pi["main"],
                            pdtype.param_shape()[0],
                            kernel_initializer=normc_initializer(0.01),
                            activation=None)

            with tf.variable_scope('pds'):
                self.pds = {}
                for ac_key, pdtype in self.pdtypes.items():
                    self.pds[ac_key] = pdtype.pdfromflat(self.pdparams[ac_key])

            with tf.variable_scope('sampled_action'):
                self.sampled_action = {}
                for ac_key, dist in self.pds.items():
                    if self.stochastic:
                        self.sampled_action[ac_key] = dist.sample()
                    else:
                        self.sampled_action[ac_key] = dist.mode()

            with tf.variable_scope('sampled_action_logp'):
                sampled_logps = [
                    self.pds[ac_key].logp(self.sampled_action[ac_key])
                    for ac_key in self.pdtypes
                ]
                self.sampled_action_logp = sum(sampled_logps)

            with tf.variable_scope('entropy'):
                entropies = [dist.entropy() for dist in self.pds.values()]
                self.entropy = sum(entropies)

            with tf.variable_scope('taken_action_logp'):
                taken_logps = [
                    self.pds[ac_key].logp(taken_actions[ac_key])
                    for ac_key in self.pdtypes
                ]
                self.taken_action_logp = sum(taken_logps)