예제 #1
0
    def forward(self, x):
        """Run the block on a pair of inputs and return a pair of outputs.

        ``x[0]`` is pushed through both residual branches (and, when the
        channel counts differ, ``shortcut[0]``) after calling
        ``change_state(False)`` on them; ``x[1]`` is pushed through after
        calling ``change_state(True)``. What the two states mean is defined
        by the sub-modules themselves (presumably a nominal pass vs. a "min"
        pass — TODO confirm against the sub-module implementation).

        Args:
            x: indexable pair of tensors. ``x[i].size(0)`` and
                ``x[i].is_cuda`` are forwarded to ``get_alpha_beta``.

        Returns:
            tuple: ``(shortcut(x[0]) + y, shortcut(x[1]) + y_min)``, where
            each ``y`` term combines the two residual-branch outputs via
            ``shake_function``.

        Note:
            Every toggled sub-module is left in the ``change_state(True)``
            state when this method returns.
        """
        nominal_in = x[0]
        branch_a = self.residual_path1
        branch_b = self.residual_path2

        # First pass: x[0] through both residual branches in state False.
        branch_a.change_state(False)
        out_a = branch_a(nominal_in)
        branch_b.change_state(False)
        out_b = branch_b(nominal_in)

        # Second pass: x[1] through both residual branches in state True.
        min_in = x[1]
        branch_a.change_state(True)
        out_a_min = branch_a(min_in)
        branch_b.change_state(True)
        out_b_min = branch_b(min_in)

        # Outside training mode the configured shake flags are replaced by
        # an all-False tuple (presumably disabling random mixing).
        shake_config = (self.shake_config if self.training
                        else (False, False, False))

        # Coefficients for x[0] are drawn before those for x[1]; keep this
        # order so any random-number consumption is unchanged.
        alpha, beta = get_alpha_beta(nominal_in.size(0), shake_config,
                                     nominal_in.is_cuda)
        alpha_min, beta_min = get_alpha_beta(min_in.size(0), shake_config,
                                             min_in.is_cuda)
        mixed = shake_function(out_a, out_b, alpha, beta)
        mixed_min = shake_function(out_a_min, out_b_min, alpha_min, beta_min)

        # shortcut[0] is only toggled when channel counts differ; presumably
        # it is a stateful projection only in that case (TODO confirm).
        projects = self.in_channels != self.out_channels
        if projects:
            self.shortcut[0].change_state(False)
        skip = self.shortcut(nominal_in)
        if projects:
            self.shortcut[0].change_state(True)
        skip_min = self.shortcut(min_in)

        return skip + mixed, skip_min + mixed_min
예제 #2
0
    def forward(self, x):
        """Combine the two residual branches with ``shake_function`` and add the shortcut.

        Args:
            x: input tensor. ``x.size(0)`` is passed to ``get_alpha_beta`` as
                the batch size and ``x.device`` as its device argument.

        Returns:
            ``self.shortcut(x) + shake_function(b1, b2, alpha, beta)``, where
            ``b1``/``b2`` are the outputs of ``residual_path1``/``residual_path2``.
        """
        out1 = self.residual_path1(x)
        out2 = self.residual_path2(x)

        # In eval mode the configured shake flags are replaced by an
        # all-False tuple (presumably disabling random mixing).
        if not self.training:
            config = (False, False, False)
        else:
            config = self.shake_config

        alpha, beta = get_alpha_beta(x.size(0), config, x.device)
        mixed = shake_function(out1, out2, alpha, beta)

        # The shortcut is evaluated after the mixing, as in the original.
        identity = self.shortcut(x)
        return identity + mixed