Example #1
    def __init__(self, dim, self_attention=False, memory_gate=False):
        """
        Constructor for the write unit.

        :param dim: global 'd' hidden dimension
        :param self_attention: whether or not to use self-attention on the previous control states
        :param memory_gate: whether or not to use memory gating.

        """

        # call base constructor
        super(WriteUnit, self).__init__()

        # linear layer for the concatenation of ri & mi-1
        self.concat_layer = linear(2 * dim, dim, bias=True)

        # self-attention & memory gating optional initializations
        self.self_attention = self_attention
        self.memory_gate = memory_gate

        if self.self_attention:
            self.attn = linear(dim, 1, bias=True)
            self.mi_sa_proj = linear(dim, dim, bias=True)
            self.mi_info_proj = linear(dim, dim, bias=True)

        if self.memory_gate:
            self.control = linear(dim, 1, bias=True)
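
The `linear` helper used throughout these examples is not shown on this page. A minimal sketch, assuming it is simply a factory that wraps nn.Linear and initializes its weights (the Xavier initialization is an assumption, not the original implementation):

import torch.nn as nn
from torch.nn.init import xavier_uniform_

def linear(input_dim, output_dim, bias=True):
    # hypothetical helper: plain nn.Linear with Xavier-initialized weights
    layer = nn.Linear(input_dim, output_dim, bias=bias)
    xavier_uniform_(layer.weight)
    if bias:
        layer.bias.data.zero_()
    return layer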
Example #2
    def __init__(self, dim, max_step):
        """
        Constructor for the control unit.

        :param dim: global 'd' hidden dimension
        :param max_step: maximum number of steps -> number of MAC cells in the network

        """

        # call base constructor
        super(ControlUnit, self).__init__()

        # define the linear layers (one per step) used to make the question
        # encoding position-aware
        self.pos_aware_layers = nn.ModuleList()
        for _ in range(max_step):
            self.pos_aware_layers.append(linear(2 * dim, dim, bias=True))

        # define the linear layer used to create the cqi values
        self.ctrl_question = linear(2 * dim, dim, bias=True)

        # define the linear layer used to create the attention weights. Should
        # be one scalar weight per contextual word
        self.attn = linear(dim, 1, bias=True)
        self.step = 0
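
For context, a simplified forward pass consistent with these layers (argument names, tensor shapes and the use of `step` are assumptions; `torch` is assumed to be imported):

    def forward(self, step, contextual_words, question_encoding, ctrl_state):
        # make the question encoding position-aware for the current reasoning step
        pos_aware_question = self.pos_aware_layers[step](question_encoding)             # [B, dim]
        # combine it with the previous control state to form the cqi values
        cqi = self.ctrl_question(torch.cat([ctrl_state, pos_aware_question], dim=-1))   # [B, dim]
        # one scalar attention weight per contextual word
        weights = torch.softmax(self.attn(cqi.unsqueeze(1) * contextual_words), dim=1)  # [B, S, 1]
        # new control state: attention-weighted sum of the contextual words
        return (weights * contextual_words).sum(dim=1)                                  # [B, dim]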
Example #3
    def __init__(self, dim, embedded_dim):
        """
        Constructor for the input unit.

        :param dim: global 'd' hidden dimension
        :param embedded_dim: dimension of the word embeddings.

        """

        # call base constructor
        super(InputUnit, self).__init__()

        self.dim = dim

        # instantiate image processing (2-layer CNN)
        self.conv = ImageProcessing(dim)

        # define linear layer for the projection of the knowledge base
        self.kb_proj_layer = linear(dim, dim, bias=True)

        # create bidirectional LSTM layer
        self.lstm = nn.LSTM(input_size=embedded_dim,
                            hidden_size=self.dim,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)

        # linear layer for projecting the word encodings from 2*dim to dim
        # TODO: linear(2*self.dim, self.dim, bias=True) ?
        self.lstm_proj = nn.Linear(2 * self.dim, self.dim)
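
A sketch of how these layers could fit together in the forward pass; the permutation/flattening of the image features and the exact return values are assumptions based on the layer shapes (`torch` assumed to be imported):

    def forward(self, questions, images):
        # 2-layer CNN producing the knowledge base: [B, dim, H, W]
        kb = self.conv(images)
        # project the knowledge base feature-wise and flatten the spatial dims: [B, H*W, dim]
        kb = self.kb_proj_layer(kb.permute(0, 2, 3, 1)).flatten(1, 2)
        # bidirectional LSTM over the embedded questions: [B, S, embedded_dim] -> [B, S, 2*dim]
        lstm_out, (h, _) = self.lstm(questions)
        # contextual words projected back to dim; question encoding from both final hidden states
        contextual_words = self.lstm_proj(lstm_out)             # [B, S, dim]
        question_encoding = torch.cat([h[0], h[1]], dim=-1)     # [B, 2*dim]
        return question_encoding, contextual_words, kb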
Example #4
    def __init__(self, dim, nb_classes):
        """
        Constructor for the output unit.

        :param dim: global 'd' dimension
        :param nb_classes: number of classes to consider (classification problem)

        """

        # call base constructor
        super(OutputUnit, self).__init__()

        # define the 2-layers MLP & specify weights initialization
        self.classifier = nn.Sequential(linear(dim * 3, dim, bias=True),
                                        nn.ELU(),
                                        linear(dim, nb_classes, bias=True))
        kaiming_uniform_(self.classifier[0].weight)
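
The dim * 3 input suggests the classifier receives the concatenation of the final memory state (dim) and the bidirectional question encoding (2 * dim). A hedged usage sketch, with argument names assumed:

    def forward(self, mem_state, question_encoding):
        # concatenate final memory state [B, dim] with the question encoding [B, 2*dim]
        return self.classifier(torch.cat([mem_state, question_encoding], dim=-1))  # [B, nb_classes]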
Example #5
    def __init__(self, dim):
        """
        Constructor for the read unit.

        :param dim: global 'd' hidden dimension

        """

        # call base constructor
        super(ReadUnit, self).__init__()

        # define linear layer for the projection of the previous memory state
        self.mem_proj_layer = linear(dim, dim, bias=True)

        # linear layer to define I'(i,h,w) elements (r2 equation)
        self.concat_layer = linear(2 * dim, dim, bias=True)

        # linear layer to compute attention weights
        self.attn = linear(dim, 1, bias=True)
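
A simplified forward pass consistent with the r1/r2 equations referenced in the comments; the flattened [B, H*W, dim] knowledge-base layout and argument names are assumptions (`torch` assumed to be imported):

    def forward(self, memory_state, knowledge_base, ctrl_state):
        # r1: interactions between the projected previous memory and each knowledge-base element
        proj_mem = self.mem_proj_layer(memory_state).unsqueeze(1)                      # [B, 1, dim]
        I = proj_mem * knowledge_base                                                  # [B, H*W, dim]
        # r2: I'(i,h,w) from the concatenation of the interactions and the raw knowledge base
        I_prime = self.concat_layer(torch.cat([I, knowledge_base], dim=-1))            # [B, H*W, dim]
        # attention weights conditioned on the current control state
        weights = torch.softmax(self.attn(I_prime * ctrl_state.unsqueeze(1)), dim=1)   # [B, H*W, 1]
        # retrieved information ri: attention-weighted sum over the knowledge base
        return (weights * knowledge_base).sum(dim=1)                                   # [B, dim]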