def __init__(self, in_planes, planes, filter_info_inp, filter_info_mid, filter_info_oup, filter_info_shortcut, num_gates_fixed_open, dependent_gates=True, stride=1):
        super(PlainBasicBlock, self).__init__()
        #print 'stride', stride
        self.num_gates_per_sht = 0
        self.shortcut = nn.Sequential()

        num_filters_per_inp, num_gates_per_inp = filter_info_inp
        num_filters_per_mid, num_gates_per_mid = filter_info_mid
        num_filters_per_oup, num_gates_per_oup = filter_info_oup
        self.num_gates_per_inp = num_gates_per_inp
        self.num_gates_per_mid = num_gates_per_mid

        self.gating_check(in_planes, num_filters_per_inp, num_gates_per_inp, num_gates_fixed_open)
        self.gating_check(planes, num_filters_per_mid, num_gates_per_mid, num_gates_fixed_open)
        self.gating_check(planes, num_filters_per_oup, num_gates_per_oup, num_gates_fixed_open)

        self.conv = Sequential_withGumble(
            SpecialGumble(num_gates_fixed_open, num_gates_per_inp, num_filters_per_inp),
            nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(),
            SpecialGumble(num_gates_fixed_open, num_gates_per_mid, num_filters_per_mid),
            nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            SpecialGumble(num_gates_fixed_open, num_gates_per_oup, num_filters_per_oup),
        )

        self.total_gates = self.num_gates_per_sht + num_gates_per_inp + num_gates_per_mid + num_gates_per_oup
        self.splits = (0, self.num_gates_per_inp, self.num_gates_per_inp + self.num_gates_per_mid, self.total_gates)
        print 'dependent_gates', dependent_gates
        if dependent_gates:
            self.gate = gates.DependentGate_Block(in_planes, self.total_gates)
        else:
            self.gate = gates.IndependentGate(in_planes, self.total_gates)
    def __init__(self, in_planes, planes, filter_infos, dependent_gates=True, stride=1):
        super(Bottleneck, self).__init__()

        self.filter_info_list, self.num_gates_fixed_open = filter_infos

        assert(len(self.filter_info_list) == 5)
        for i in range(len(self.filter_info_list)):
            assert(len(self.filter_info_list[i]) == 2)

        self.conv = Sequential_withGumble(
            SpecialGumble(self.num_gates_fixed_open, self.filter_info_list[1][1], self.filter_info_list[1][0]), # expand
            nn.Conv2d(in_planes, planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(),
            SpecialGumble(self.num_gates_fixed_open, self.filter_info_list[2][1], self.filter_info_list[2][0]), # contract
            nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(),
            SpecialGumble(self.num_gates_fixed_open, self.filter_info_list[3][1], self.filter_info_list[3][0]), # contract
            nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.expansion*planes),
            SpecialGumble(self.num_gates_fixed_open, self.filter_info_list[4][1], self.filter_info_list[4][0]), # expand
        )

        self.shortcut = nn.Sequential()
        self.has_shortcut = stride != 1 or in_planes != self.expansion*planes
        if self.has_shortcut:
            print 'has shortcut!'
            num_filters_per_sht, self.num_gates_per_sht = self.filter_info_list[0]
            self.gating_check(self.expansion*planes, num_filters_per_sht, self.num_gates_per_sht, self.num_gates_fixed_open)
            self.shortcut = Sequential_withGumble(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes),
                SpecialGumble(self.num_gates_fixed_open, self.num_gates_per_sht, num_filters_per_sht),
            )
        else:
            self.num_gates_per_sht = 0

        filters_per = [in_planes, planes, planes, self.expansion*planes]

        # Gate layers
        self.total_gates = self.num_gates_per_sht
        self.splits = [0]
        for i, filter_info in enumerate(self.filter_info_list[1:]):
            num_filters_per_gatingstructure, num_gates_per_gatingstructure = filter_info
            self.total_gates += num_gates_per_gatingstructure
            self.gating_check(filters_per[i], num_filters_per_gatingstructure, num_gates_per_gatingstructure, self.num_gates_fixed_open)
            self.splits.append(self.splits[-1] + num_gates_per_gatingstructure)

        assert(self.splits[-1] == self.total_gates - self.num_gates_per_sht), str(self.splits) + ' ' + str(self.total_gates) + ' ' + str(self.num_gates_per_sht)
        self.shortcut_splits = (0, self.num_gates_per_sht)

        print 'dependent_gates', dependent_gates
        if dependent_gates:
            self.gate = gates.DependentGate_Block(in_planes, self.total_gates)
        else:
            self.gate = gates.IndependentGate(in_planes, self.total_gates)