예제 #1
0
 def forward(self, x):
     """Compute new features from ``x`` and append them to it along dim 1.

     The new features come from the parent class's ``forward`` (the
     ``super(_DenseLayer, self)`` call). When ``self.drop_rate`` is
     positive they are passed through ``F.dropout`` with
     ``p=self.drop_rate`` and ``training=self.training``.

     :param x: input passed unchanged into the concatenation
     :returns: ``torchgraph.cat([x, features], 1)``, i.e. the original input
         followed by the (optionally dropped-out) new features
     """
     features = super(_DenseLayer, self).forward(x)
     if self.drop_rate > 0:
         features = F.dropout(
             features, p=self.drop_rate, training=self.training)
     return torchgraph.cat([x, features], 1)
예제 #2
0
    def forward(self, x):
        """Run three parallel branches on ``x`` and concatenate them on dim 1.

        Branches, in output order:
          1. ``self.branch3x3`` applied once;
          2. ``self.branch3x3dbl_1`` -> ``_2`` -> ``_3`` applied in sequence;
          3. ``self.max_pool``.
        """
        single = self.branch3x3(x)

        chained = x
        for stage in (self.branch3x3dbl_1,
                      self.branch3x3dbl_2,
                      self.branch3x3dbl_3):
            chained = stage(chained)

        pooled = self.max_pool(x)

        return torchgraph.cat([single, chained, pooled], 1)
예제 #3
0
    def forward(self, x):
        """Run three parallel branches on ``x`` and concatenate them on dim 1.

        Branches, in output order:
          1. ``self.branch3x3_1`` -> ``self.branch3x3_2``;
          2. ``self.branch7x7x3_1`` through ``_4`` applied in sequence;
          3. ``self.max_pool``.
        """
        short_path = self.branch3x3_2(self.branch3x3_1(x))

        long_path = x
        for stage in (self.branch7x7x3_1,
                      self.branch7x7x3_2,
                      self.branch7x7x3_3,
                      self.branch7x7x3_4):
            long_path = stage(long_path)

        pooled = self.max_pool(x)

        return torchgraph.cat([short_path, long_path, pooled], 1)
예제 #4
0
    def forward(self, x, x_prev):
        """Build left/right inputs from ``x_prev`` and ``x``, then combine them.

        Left input: ``x_prev`` goes through ``self.relu``. The result feeds
        ``self.path_1`` directly and ``self.path_2`` after dropping the first
        index along dims 2 and 3 (so ``x_prev`` is assumed to be at least
        4-D). The two path outputs are concatenated on dim 1 and passed
        through ``self.final_path_bn``.

        Right input: ``self.conv_1x1(x)``.

        :returns: whatever ``NormalCellBranchCombine(self, left, right)``
            returns
        """
        activated = self.relu(x_prev)
        first_path = self.path_1(activated)
        # Offset-by-one view along the last two dims feeds the second path.
        second_path = self.path_2(activated[:, :, 1:, 1:])
        merged = torchgraph.cat([first_path, second_path], 1)
        left = self.final_path_bn(merged)

        right = self.conv_1x1(x)

        return NormalCellBranchCombine(self, left, right)
예제 #5
0
    def forward(self, x_conv0, x_stem_0):
        """Build left/right inputs from the two arguments, then combine them.

        Left input: ``self.conv_1x1(x_stem_0)``, computed first.

        Right input: ``x_conv0`` goes through ``self.relu``. The result feeds
        ``self.path_1`` directly and ``self.path_2`` after dropping the first
        index along dims 2 and 3 (so ``x_conv0`` is assumed to be at least
        4-D). The two path outputs are concatenated on dim 1 and passed
        through ``self.final_path_bn``.

        :returns: whatever ``ReductionCellBranchCombine(self, left, right)``
            returns
        """
        left = self.conv_1x1(x_stem_0)

        activated = self.relu(x_conv0)
        first_path = self.path_1(activated)
        # Offset-by-one view along the last two dims feeds the second path.
        second_path = self.path_2(activated[:, :, 1:, 1:])
        right = self.final_path_bn(
            torchgraph.cat([first_path, second_path], 1))

        return ReductionCellBranchCombine(self, left, right)
예제 #6
0
    def forward(self, x):
        """Run four parallel branches on ``x`` and concatenate them on dim 1.

        Branches, in output order:
          1. ``self.branch1x1``;
          2. ``self.branch3x3_1``, whose output is fed to both
             ``branch3x3_2a`` and ``branch3x3_2b``; those two results are
             concatenated on dim 1;
          3. ``self.branch3x3dbl_1`` -> ``_2``, whose output is fed to both
             ``branch3x3dbl_3a`` and ``branch3x3dbl_3b``; those two results
             are concatenated on dim 1;
          4. ``self.avg_pool`` -> ``self.branch_pool``.
        """
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torchgraph.cat(
            [self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1)

        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_dbl = torchgraph.cat(
            [self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)],
            1)

        out_pool = self.branch_pool(self.avg_pool(x))

        return torchgraph.cat([out_1x1, out_3x3, out_dbl, out_pool], 1)
예제 #7
0
    def forward(self, x):
        """Run four parallel branches on ``x`` and concatenate them on dim 1.

        Branches, in output order:
          1. ``self.branch1x1``;
          2. ``self.branch5x5_1`` -> ``self.branch5x5_2``;
          3. ``self.branch3x3dbl_1`` -> ``_2`` -> ``_3``;
          4. ``self.avg_pool`` -> ``self.branch_pool``.
        """
        first = self.branch1x1(x)
        second = self.branch5x5_2(self.branch5x5_1(x))
        third = self.branch3x3dbl_3(
            self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        fourth = self.branch_pool(self.avg_pool(x))
        return torchgraph.cat([first, second, third, fourth], 1)
예제 #8
0
    def forward(self, inputs, context, inference=False):
        """
        Execute the decoder.

        The inputs are embedded and passed through the attention RNN. The
        result then goes through ``self.rnn_layers`` in order. Before each
        layer, dropout is applied and the attention output is concatenated
        on dim 2. Every layer after the first adds its input back as a
        residual. The final output goes through ``self.classifier``.

        :param inputs: tensor with inputs to the decoder
        :param context: state of encoder, encoder sequence lengths and hidden
            state of decoder's LSTM layers (unpacked as a 3-element sequence)
        :param inference: if True stores and repackages hidden state; the
            flag is written to ``self.inference`` before any hidden-state
            helper is called
        :returns: ``(output, scores, [enc_context, enc_len, hidden])`` where
            ``scores`` comes from ``self.att_rnn`` and ``hidden`` from
            ``self.package_hidden()``
        """
        self.inference = inference

        enc_context, enc_len, prev_hidden = context
        layer_hidden = self.init_hidden(prev_hidden)

        out, state, attn, scores = self.att_rnn(
            self.embedder(inputs), layer_hidden[0], enc_context, enc_len)
        self.append_hidden(state)

        # First stacked layer: no residual connection.
        stacked = torchgraph.cat((self.dropout(out), attn), dim=2)
        out, state = self.rnn_layers[0](stacked, layer_hidden[1])
        self.append_hidden(state)

        # Remaining layers add their input back (residual connection).
        for layer_idx in range(1, len(self.rnn_layers)):
            skip = out
            stacked = torchgraph.cat((self.dropout(out), attn), dim=2)
            out, state = self.rnn_layers[layer_idx](
                stacked, layer_hidden[layer_idx + 1])
            self.append_hidden(state)
            out = out + skip

        out = self.classifier(out)
        return out, scores, [enc_context, enc_len, self.package_hidden()]
예제 #9
0
    def forward(self, x):
        """Run four parallel branches on ``x`` and concatenate them on dim 1.

        Branches, in output order:
          1. ``self.branch1x1``;
          2. ``self.branch7x7_1`` through ``_3`` applied in sequence;
          3. ``self.branch7x7dbl_1`` through ``_5`` applied in sequence;
          4. ``self.avg_pool`` -> ``self.branch_pool``.
        """
        direct = self.branch1x1(x)

        tower = x
        for stage in (self.branch7x7_1, self.branch7x7_2, self.branch7x7_3):
            tower = stage(tower)

        deep_tower = x
        for stage in (self.branch7x7dbl_1,
                      self.branch7x7dbl_2,
                      self.branch7x7dbl_3,
                      self.branch7x7dbl_4,
                      self.branch7x7dbl_5):
            deep_tower = stage(deep_tower)

        pooled = self.branch_pool(self.avg_pool(x))

        return torchgraph.cat([direct, tower, deep_tower, pooled], 1)
예제 #10
0
def ReductionCellBranchCombine(cell, x_left, x_right):
    """Combine a cell's left and right inputs through its ``comb_iter_*`` ops.

    Combinations 0-2 each sum ``comb_iter_i_left(x_left)`` and
    ``comb_iter_i_right(x_right)``. Combination 3 is
    ``comb_iter_3_right(comb0) + comb1``. Combination 4 is
    ``comb_iter_4_left(comb0) + comb_iter_4_right(x_left)``.

    Combination 0 only feeds combinations 3 and 4. It is not part of the
    output.

    :param cell: object exposing the ``comb_iter_*`` callables
    :param x_left: left input
    :param x_right: right input
    :returns: ``torchgraph.cat([comb1, comb2, comb3, comb4], 1)``
    """
    paired = []
    for left_op, right_op in (
            (cell.comb_iter_0_left, cell.comb_iter_0_right),
            (cell.comb_iter_1_left, cell.comb_iter_1_right),
            (cell.comb_iter_2_left, cell.comb_iter_2_right)):
        paired.append(left_op(x_left) + right_op(x_right))
    comb0, comb1, comb2 = paired

    comb3 = cell.comb_iter_3_right(comb0) + comb1
    comb4 = cell.comb_iter_4_left(comb0) + cell.comb_iter_4_right(x_left)

    return torchgraph.cat([comb1, comb2, comb3, comb4], 1)
예제 #11
0
 def forward(self, x):
     """Squeeze ``x``, then concatenate two expand paths along dim 1.

     ``x`` goes through ``self.squeeze`` and ``self.squeeze_activation``.
     The squeezed result is fed to both the ``expand1x1`` and ``expand3x3``
     paths, each followed by its own activation. The 1x1 path comes first
     in the output.
     """
     squeezed = self.squeeze_activation(self.squeeze(x))
     narrow = self.expand1x1_activation(self.expand1x1(squeezed))
     wide = self.expand3x3_activation(self.expand3x3(squeezed))
     return torchgraph.cat([narrow, wide], 1)