Exemplo n.º 1
0
 def build_net(self, wdict=None):
     """Build the classifier head inside this module's variable scope.

     Args:
       wdict: Passed through unchanged as the `wdict` argument of the layer
         constructor. Defaults to None.

     Returns:
       A `CosineLastMLP` when `config.cosine_classifier` is truthy, otherwise
       an `MLP`. In both cases the layer sizes are `config.num_filters`
       followed by the number of output units.
     """
     with variable_scope(self._name):
         cfg = self.config
         # The "mix" inner-loop loss gets one output unit on top of the
         # regular classes.
         num_outputs = cfg.num_classes
         if cfg.inner_loop_loss == "mix":
             num_outputs += 1
         layer_sizes = list(cfg.num_filters) + [num_outputs]

         if cfg.cosine_classifier:
             net = CosineLastMLP("mlp", layer_sizes, temp=10.0, wdict=wdict)
         else:
             net = MLP("mlp",
                       layer_sizes,
                       wdict=wdict,
                       add_bias=cfg.classifier_bias)
     return net
Exemplo n.º 2
0
    def __init__(self, config, wdict=None):
        """Create a C4 backbone followed by a fully connected head.

        Args:
          config: Configuration object. `num_filters[-1]` is used as the
            input width of the head and `num_fc_dim` describes its output
            layer(s).
          wdict: Passed through unchanged to the head's constructor.
        """
        super(C4PlusFCBackbone, self).__init__(config)
        self.backbone = C4Backbone(config)

        fc_dims = self.config.num_fc_dim
        fc_in_dim = config.num_filters[-1]
        if len(fc_dims) <= 1:
            # Hard coded for now.
            # NOTE(review): the whole `num_fc_dim` sequence (not its single
            # element) is handed to `Linear` as the output size -- confirm
            # that `Linear` accepts this.
            self.fc = Linear('fc', fc_in_dim, fc_dims, wdict=wdict)
        else:
            self.fc = MLP('fc', [fc_in_dim] + list(fc_dims), wdict=wdict)
Exemplo n.º 3
0
    def __init__(self,
                 name,
                 rnn_memory,
                 proto_memory,
                 readout_type='linear',
                 use_pred_beta_gamma=True,
                 use_feature_fuse=True,
                 use_feature_fuse_gate=True,
                 use_feature_scaling=True,
                 use_feature_memory_only=False,
                 skip_unk_memory_update=False,
                 use_ssl=True,
                 use_ssl_beta_gamma_write=True,
                 use_ssl_temp=True,
                 dtype=tf.float32):
        """Store the memory modules and ablation flags, then build the readout.

        Args:
          name: String. Module name. NOTE(review): not used inside this
            constructor.
          rnn_memory: Recurrent memory module. Its `memory_dim` is the readout
            input size and its `in_dim` sets the width of the two vector
            segments of the readout output.
          proto_memory: Prototype memory module. Its `_radius_init` and
            `_radius_init_write` attributes seed two of the readout biases.
          readout_type: String. One of `linear`, `mlp` or `resmlp`.
          use_pred_beta_gamma: Bool. Ablation flag, stored on the instance.
          use_feature_fuse: Bool. Ablation flag, stored on the instance.
          use_feature_fuse_gate: Bool. Ablation flag, stored on the instance.
          use_feature_scaling: Bool. Ablation flag, stored on the instance.
          use_feature_memory_only: Bool. Ablation flag, stored on the instance.
          skip_unk_memory_update: Bool. Ablation flag, stored on the instance.
          use_ssl: Bool. Ablation flag, stored on the instance.
          use_ssl_beta_gamma_write: Bool. Ablation flag, stored on the
            instance.
          use_ssl_temp: Bool. Ablation flag, stored on the instance.
          dtype: Data type.

        Raises:
          ValueError: If `readout_type` is not one of the supported values.
        """
        super(RNNEncoder, self).__init__(dtype=dtype)
        self._rnn_memory = rnn_memory
        self._proto_memory = proto_memory

        # ------------- Feature Fusing Capability Ablation --------------
        self._use_pred_beta_gamma = use_pred_beta_gamma  # CHECK
        self._use_feature_fuse = use_feature_fuse  # CHECK
        self._use_feature_fuse_gate = use_feature_fuse_gate  # CHECK
        self._use_feature_scaling = use_feature_scaling  # CHECK
        self._use_feature_memory_only = use_feature_memory_only  # CHECK

        # ------------- SSL Capability Ablation --------------
        self._skip_unk_memory_update = skip_unk_memory_update  # CHECK
        self._use_ssl = use_ssl  # CHECK
        self._use_ssl_beta_gamma_write = use_ssl_beta_gamma_write  # CHECK
        self._use_ssl_temp = use_ssl_temp  # CHECK

        D_in = self._rnn_memory.memory_dim
        D = self._rnn_memory.in_dim
        self._dim = D

        # Layout of the readout output; the bias segments below follow the
        # same order.
        # h        [D]
        # scale    [D]
        # temp     [1]
        # gamma2   [1]
        # beta2    [1]
        # gamma    [1]
        # beta     [1]
        # x_gate   [1]
        # h_gate   [1]
        bias_init = [
            tf.zeros(D),
            tf.zeros(D),
            tf.zeros([1]),
            tf.zeros([1]),
            tf.zeros([1]) + proto_memory._radius_init,
            tf.zeros([1]),
            tf.zeros([1]) + proto_memory._radius_init_write,
            tf.zeros([1]) + 1.0,
            tf.zeros([1]) - 1.0
        ]
        bias_init = tf.concat(bias_init, axis=0)

        # The readout width is whatever the concatenated bias adds up to.
        D_out = bias_init.shape[-1]

        def b_init():
            return bias_init

        if readout_type == 'linear':
            log.info("Using linear readout")
            self._readout = Linear('readout', D_in, D_out, b_init=b_init)
        elif readout_type == 'mlp':
            log.info("Using MLP readout")
            self._readout = MLP('readout_mlp', [D_in, D_out, D_out],
                                bias_init=[None, bias_init],
                                act_func=[tf.math.tanh])
        elif readout_type == 'resmlp':
            log.info("Using ResMLP readout")
            self._readout = ResMLP('readout_mlp', [D_in, D_out, D_out, D_out],
                                   bias_init=[None, None, bias_init],
                                   act_func=[swish, swish, None])
        else:
            # Fail here instead of leaving `self._readout` undefined, which
            # would only surface later as an unrelated AttributeError.
            raise ValueError(
                'Unknown readout type: {}'.format(readout_type))
Exemplo n.º 4
0
    def __init__(self, name, in_dim, config, dtype=tf.float32):
        """Initialize a DNC module.

        Args:
          name: String. Name of the module, used as its variable scope.
          in_dim: Int. Input dimension.
          config: Configuration object. The attributes read here are
            `memory_dim`, `controller_dim`, `num_slots`, `num_reads`,
            `num_writes`, `controller_nstack`, `controller_type` (`lstm`,
            `stack_lstm` or `mlp`), `controller_layernorm`,
            `similarity_type` and `memory_layernorm`.
          dtype: Data type.
        """
        super(DNC, self).__init__(dtype=dtype)
        log.info('Currently using MANN with separate write attention')
        log.info('Currently using MANN with decay')

        self._in_dim = in_dim
        self._memory_dim = config.memory_dim
        self._controller_dim = config.controller_dim
        self._nslot = config.num_slots
        self._nread = config.num_reads
        self._nwrite = config.num_writes
        self._controller_nstack = config.controller_nstack
        self._controller_type = config.controller_type
        self._similarity_type = config.similarity_type

        with variable_scope(name):
            ctrl_type = config.controller_type
            ctrl_dim = config.controller_dim
            ctrl_layernorm = config.controller_layernorm
            if ctrl_layernorm:
                log.info('Using LayerNorm in controller module.')

            # Controller. Any other controller type leaves `_controller`
            # unset, exactly as before.
            if ctrl_type == 'lstm':
                self._controller = LSTM("controller_lstm",
                                        in_dim,
                                        ctrl_dim,
                                        layernorm=ctrl_layernorm,
                                        dtype=dtype)
            elif ctrl_type == 'stack_lstm':
                log.info('Use {}-stack LSTM'.format(config.controller_nstack))
                self._controller = StackLSTM("stack_controller_lstm",
                                             in_dim,
                                             ctrl_dim,
                                             config.controller_nstack,
                                             layernorm=ctrl_layernorm,
                                             dtype=dtype)
            elif ctrl_type == 'mlp':
                log.info('Use MLP')
                self._controller = MLP("controller_mlp",
                                       [in_dim, ctrl_dim, ctrl_dim],
                                       layernorm=ctrl_layernorm,
                                       dtype=dtype)

            self._rnd = np.random.RandomState(0)

            # Every memory slot starts at a small constant value.
            self._memory_init = 1e-5 * tf.ones(
                [config.num_slots, config.memory_dim],
                name="memory_init",
                dtype=dtype)

            num_read = self._nread
            num_write = self._nwrite
            mem_dim = self._memory_dim
            num_slot = self._nslot

            # Layout of the controller-to-memory projection as
            # (segment size, initial bias value). The projection width and
            # the bias initializer are both derived from this table.
            segments = [
                # read query, write query and write content
                (num_read * mem_dim + 2 * num_write * mem_dim, 0.0),
                (num_read, -2.0),  # forget gate: no forget after read
                (num_write, 2.0),  # write gate: always write
                (num_write, -2.0),  # interp gate: always use LRU
                (num_read, 0.0),  # read temp: default 1.0
                (num_write, 0.0),  # write temp: default 1.0
                (num_slot, -2.0),  # erase: default no erase
            ]

            def ctrl2mem_bias_init():
                pieces = [
                    value * tf.ones([size], dtype=self.dtype)
                    for size, value in segments
                ]
                return tf.concat(pieces, axis=0)

            self._ctrl2mem = Linear("ctrl2mem",
                                    ctrl_dim,
                                    sum(size for size, _ in segments),
                                    b_init=ctrl2mem_bias_init)

            self._mem_layernorm = None
            if config.memory_layernorm:
                log.info('Using LayerNorm for each memory iteration.')
                self._mem_layernorm = LayerNorm("memory_layernorm",
                                                mem_dim,
                                                dtype=dtype)
Exemplo n.º 5
0
class DNC(ContainerModule):
    """Memory-augmented recurrent module with content and LRU addressing.

    A controller (LSTM, stacked LSTM or MLP) emits, through one linear
    projection, the read/write queries, write content, gates, temperatures
    and erase vector that drive one read and one write of the memory per
    timestep.
    """

    def __init__(self, name, in_dim, config, dtype=tf.float32):
        """Initialize a DNC module.

        Args:
          name: String. Name of the module, used as its variable scope.
          in_dim: Int. Input dimension.
          config: Configuration object. The attributes read here are
            `memory_dim`, `controller_dim`, `num_slots`, `num_reads`,
            `num_writes`, `controller_nstack`, `controller_type` (`lstm`,
            `stack_lstm` or `mlp`), `controller_layernorm`,
            `similarity_type` (`cosine` or `dot_product`) and
            `memory_layernorm`.
          dtype: Data type.
        """
        super(DNC, self).__init__(dtype=dtype)
        log.info('Currently using MANN with separate write attention')
        log.info('Currently using MANN with decay')

        self._in_dim = in_dim
        self._memory_dim = config.memory_dim
        self._controller_dim = config.controller_dim
        self._nslot = config.num_slots
        self._nread = config.num_reads
        self._nwrite = config.num_writes
        self._controller_nstack = config.controller_nstack
        self._controller_type = config.controller_type
        self._similarity_type = config.similarity_type

        with variable_scope(name):
            ctrl_type = config.controller_type
            ctrl_dim = config.controller_dim
            ctrl_layernorm = config.controller_layernorm
            if ctrl_layernorm:
                log.info('Using LayerNorm in controller module.')

            # Controller. Any other controller type leaves `_controller`
            # unset; `get_initial_state` rejects such types.
            if ctrl_type == 'lstm':
                self._controller = LSTM("controller_lstm",
                                        in_dim,
                                        ctrl_dim,
                                        layernorm=ctrl_layernorm,
                                        dtype=dtype)
            elif ctrl_type == 'stack_lstm':
                log.info('Use {}-stack LSTM'.format(config.controller_nstack))
                self._controller = StackLSTM("stack_controller_lstm",
                                             in_dim,
                                             ctrl_dim,
                                             config.controller_nstack,
                                             layernorm=ctrl_layernorm,
                                             dtype=dtype)
            elif ctrl_type == 'mlp':
                log.info('Use MLP')
                self._controller = MLP("controller_mlp",
                                       [in_dim, ctrl_dim, ctrl_dim],
                                       layernorm=ctrl_layernorm,
                                       dtype=dtype)

            self._rnd = np.random.RandomState(0)

            # Every memory slot starts at a small constant value.
            self._memory_init = 1e-5 * tf.ones(
                [config.num_slots, config.memory_dim],
                name="memory_init",
                dtype=dtype)

            num_read = self._nread
            num_write = self._nwrite
            mem_dim = self._memory_dim
            num_slot = self._nslot

            # Layout of the controller-to-memory projection as
            # (segment size, initial bias value). The projection width and
            # the bias initializer are both derived from this table, and
            # `forward` splits the projection in the same order.
            segments = [
                # read query, write query and write content
                (num_read * mem_dim + 2 * num_write * mem_dim, 0.0),
                (num_read, -2.0),  # forget gate: no forget after read
                (num_write, 2.0),  # write gate: always write
                (num_write, -2.0),  # interp gate: always use LRU
                (num_read, 0.0),  # read temp: default 1.0
                (num_write, 0.0),  # write temp: default 1.0
                (num_slot, -2.0),  # erase: default no erase
            ]

            def ctrl2mem_bias_init():
                pieces = [
                    value * tf.ones([size], dtype=self.dtype)
                    for size, value in segments
                ]
                return tf.concat(pieces, axis=0)

            self._ctrl2mem = Linear("ctrl2mem",
                                    ctrl_dim,
                                    sum(size for size, _ in segments),
                                    b_init=ctrl2mem_bias_init)

            self._mem_layernorm = None
            if config.memory_layernorm:
                log.info('Using LayerNorm for each memory iteration.')
                self._mem_layernorm = LayerNorm("memory_layernorm",
                                                mem_dim,
                                                dtype=dtype)

    def slice_vec_2d(self, vec, size):
        """Split a 2-D tensor along axis 1 into pieces of the given sizes."""
        return tf.split(vec, size, axis=1)

    def oneplus(self, x):
        """Return `1 + log(1 + exp(x))`.

        Uses the numerically stable softplus so that large inputs do not
        overflow `exp` and turn the result into `inf`.
        """
        return 1 + tf.nn.softplus(x)

    def get_lru(self, usage, Nw):
        """Get least used soft flag.

        Args:
          usage: [B, M]. Memory usage. The batch dimension must be statically
            known because it is turned into a constant.
          Nw: Int. Number of write heads.

        Returns:
          least_used: [B, Nw, M]. One soft least-used weighting per write
            head, expressed in the original (unsorted) slot order.
        """
        B = tf.constant(usage.shape[0])
        M = usage.shape[1]
        # Sort slots by ascending usage.
        usage_sort, usage_idx = tf.nn.top_k(-usage, M)  # [B, M] [B, M]
        usage_sort = -usage_sort
        least_used = tf.TensorArray(self.dtype, size=Nw)
        for s in range(Nw):
            # Head `s` ignores the `s` least used slots (zero weight); each
            # later slot is weighted by the product of the usages sorted
            # before it, starting from rank `s`.
            AA = tf.zeros([B, s], dtype=self.dtype)
            CC = tf.concat(
                [tf.ones([B, 1], dtype=self.dtype), usage_sort[:, s:]], axis=1)
            BB = tf.math.cumprod(CC, axis=1)[:, :-1]
            least_used_ = tf.concat([AA, BB], axis=-1)  # [B, M]
            least_used_ *= 1.0 - usage_sort
            least_used = least_used.write(s, least_used_)
        least_used = tf.transpose(least_used.stack(), [1, 0, 2])  # [B, Nw, M]

        # Scatter the weights back from sorted order to slot order.
        batch_idx = tf.tile(tf.reshape(tf.range(B), [-1, 1, 1]), [1, Nw, M])
        nw_idx = tf.tile(tf.reshape(tf.range(Nw), [1, -1, 1]), [B, 1, M])
        usage_idx_ = tf.tile(tf.reshape(usage_idx, [-1, 1, M]), [1, Nw, 1])
        unsrt_idx = tf.stack([batch_idx, nw_idx, usage_idx_],
                             -1)  # [B, Nw, M, 3]
        least_used = tf.scatter_nd(unsrt_idx, least_used, [B, Nw, M])
        return least_used

    def content_attention(self, query, key, temp):
        """Attention to memory content.

        Args:
          query: [B, N, D] Query vector.
          key: [B, M, D] Key vector.
          temp: Temperature multiplied into the similarity before the
            softmax; `forward` passes it as [B, N, 1].

        Returns:
          attn: [B, N, M] Attention.
        """
        return tf.nn.softmax(self.similarity(query, key) * temp)

    def forward(self, x, memory, usage, t, *args):
        """Forward one timestep.

        Args:
          x: Input, fed directly to the controller.
          memory: [B, M, D]. Memory slots.
          usage: [B, M]. Memory usage.
          t: Int. Timestep counter.
          *args: `(ctrl_c, ctrl_h)`, the last step controller state. Only
            unpacked when the controller type is `lstm` or `stack_lstm`.

        Returns:
          y: [B, D]. Content read from the memory as it was before this
            step's write, summed over the read heads.
          state: `(memory, usage, t, ctrl_c, ctrl_h)` for LSTM controllers,
            `(memory, usage, t)` otherwise.
        """
        N = self.nread
        Nw = self.nwrite
        M = self.nslot
        D = self.memory_dim
        if self._controller_type in ['lstm', 'stack_lstm']:
            ctrl_c, ctrl_h = args
            ctrl_out = self._controller(x, ctrl_c, ctrl_h)
            ctrl_out, (ctrl_c, ctrl_h) = ctrl_out
        else:
            ctrl_out = self._controller(x)
        # [B, N * D + 2 * Nw * D + 2 * N + 3 * Nw + M]
        ctrl_mem = self._ctrl2mem(ctrl_out)
        (read_query, write_query, write_content, forget, write_gate, sigma,
         temp_read, temp_write, erase) = self.slice_vec_2d(
             ctrl_mem, [N * D, Nw * D, Nw * D, N, Nw, Nw, N, Nw, M])
        temp_read = tf.expand_dims(self.oneplus(temp_read), -1)  # [B, N, 1]
        temp_write = tf.expand_dims(self.oneplus(temp_write), -1)  # [B, Nw, 1]
        read_query = tf.reshape(read_query, [-1, N, D])  # [B, N, D]
        write_query = tf.reshape(write_query, [-1, Nw, D])  # [B, Nw, D]
        write_content = tf.reshape(write_content, [-1, Nw, D])  # [B, Nw, D]
        # [B, N, M]
        read_vec = self.content_attention(read_query, memory, temp_read)
        write_vec = self.content_attention(write_query, memory,
                                           temp_write)  # [B, Nw, M]

        forget = tf.expand_dims(tf.math.sigmoid(forget), -1)  # [B, N, 1]
        free = tf.reduce_prod(1.0 - forget * read_vec, [1])  # [B, M]

        # Read memory content
        y = tf.reduce_sum(tf.matmul(read_vec, memory),
                          [1])  # [B, N, M] x [B, M, D] = [B, N, D] => [B, D]

        # Write memory content
        interp_gate = tf.expand_dims(tf.math.sigmoid(sigma), -1)  # [B, Nw, 1]
        write_gate = tf.expand_dims(tf.math.sigmoid(write_gate),
                                    -1)  # [B, Nw, 1]
        erase = tf.expand_dims(tf.math.sigmoid(erase), -1)  # [B, M, 1]
        least_used = self.get_lru(usage, Nw)  # [B, Nw, M]

        # Interpolate between content-based and least-used addressing.
        write_vec = write_gate * (interp_gate * write_vec +
                                  (1 - interp_gate) * least_used)  # [B, Nw, M]
        all_write = tf.reduce_max(write_vec, [1])
        usage = (usage + all_write - all_write * usage) * free  # [B, M]

        # Erase memory content. # [B, M, D]
        # A broadcast scalar keeps the module dtype and needs no static
        # batch size, unlike an explicit float32 `tf.ones([B, M, 1])`.
        memory *= 1.0 - erase * tf.expand_dims(all_write, -1)

        # Write new content.
        memory += tf.matmul(
            write_vec, write_content,
            transpose_a=True)  # [B, Nw, M] x [B, Nw, D] = [B, M, D]

        # Optional memory layer norm.
        if self._mem_layernorm is not None:
            memory = self._mem_layernorm(memory)

        # Time step.
        t += tf.constant(1)

        if self._controller_type in ['lstm', 'stack_lstm']:
            return y, (memory, usage, t, ctrl_c, ctrl_h)
        else:
            return y, (memory, usage, t)

    def get_initial_state(self, bsize):
        """Initialize hidden state.

        Args:
          bsize: Int. Batch size.

        Returns:
          `(memory, usage, t, ctrl_c, ctrl_h)` for LSTM controllers and
          `(memory, usage, t)` for the MLP controller.

        Raises:
          ValueError: If the controller type is not supported.
        """
        memory = tf.tile(tf.expand_dims(self._memory_init, 0), [bsize, 1, 1])
        usage = tf.zeros([bsize, self.nslot], dtype=self.dtype)
        t = tf.constant(0, dtype=tf.int32)
        if self._controller_type in ['lstm', 'stack_lstm']:
            ctrl_c, ctrl_h = self._controller.get_initial_state(bsize)
            return memory, usage, t, ctrl_c, ctrl_h
        elif self._controller_type in ['mlp']:
            return memory, usage, t
        else:
            # An explicit raise survives `python -O`, unlike `assert False`.
            raise ValueError('Unknown controller type: {}'.format(
                self._controller_type))

    def _expand(self, num, x):
        """Expand one variable.

        Repeats every entry along the first axis `num` times, so a tensor of
        shape [B, ...] becomes [B * num, ...].
        """
        tile = [1] + [num] + [1] * (len(x.shape) - 1)
        reshape = [-1] + list(x.shape[1:])
        return tf.reshape(tf.tile(tf.expand_dims(x, 1), tile), reshape)

    def expand_state(self, num, ctrl_c, ctrl_h, memory, usage, t):
        """Expand the hidden state for query set.

        `memory`, `usage` and (for LSTM controllers) `ctrl_c` / `ctrl_h` are
        repeated `num` times along the batch axis; `t` is returned unchanged.

        NOTE(review): the parameter order here (controller state first)
        differs from the state tuple order returned by `get_initial_state`
        and `forward` (`memory, usage, t, ctrl_c, ctrl_h`) -- confirm that
        call sites pass the arguments in this order.
        """

        memory = self._expand(num, memory)
        usage = self._expand(num, usage)
        if self._controller_type in ['lstm', 'stack_lstm']:
            ctrl_c = self._expand(num, ctrl_c)
            ctrl_h = self._expand(num, ctrl_h)
            return memory, usage, t, ctrl_c, ctrl_h
        else:
            return memory, usage, t

    def similarity(self, query, key):
        """Query the memory with a key using the configured similarity.

        Args:
          query: [B, N, D]. B: batch size, N: number of reads, D: dimension.
          key: [B, M, D]. B: batch size, M: number of slots, D: dimension.

        Returns:
          sim: [B, N, M]

        Raises:
          ValueError: If the similarity type is neither `cosine` nor
            `dot_product`.
        """
        eps = 1e-7
        if self._similarity_type == 'cosine':
            q_norm = tf.sqrt(
                tf.reduce_sum(tf.square(query), [-1],
                              keepdims=True))  # [B, N, 1]
            q_ = query / (q_norm + eps)  # [B, N, D]
            k_norm = tf.sqrt(tf.reduce_sum(tf.square(key), [-1],
                                           keepdims=True))  # [B, M, 1]
            k_ = key / (k_norm + eps)  # [B, M, D]
            sim = tf.matmul(q_, k_, transpose_b=True)
        elif self._similarity_type == 'dot_product':
            sim = tf.matmul(query, key, transpose_b=True)
        else:
            # Previously this fell through and failed with an
            # UnboundLocalError on `sim`.
            raise ValueError('Unknown similarity type: {}'.format(
                self._similarity_type))
        return sim

    def end_iteration(self, h_last):
        """End recurrent iterations."""
        return h_last

    @property
    def in_dim(self):
        return self._in_dim

    @property
    def memory_dim(self):
        return self._memory_dim

    @property
    def nin(self):
        return self._in_dim

    @property
    def nout(self):
        return self._memory_dim

    @property
    def controller_dim(self):
        return self._controller_dim

    @property
    def nslot(self):
        return self._nslot

    @property
    def nread(self):
        return self._nread

    @property
    def nwrite(self):
        return self._nwrite