class BasicBlock(nn.Module):
    """ResNet basic block (two 3x3 convs) whose residual branch is gated.

    The gate logits are a single learned parameter ``self.w`` (two values,
    shared by all inputs).  In ``forward`` a hard Gumbel-softmax sample is
    drawn from them for every sample in the batch; entry 1 of that sample
    multiplies the residual branch, so 1 means "execute" and 0 "skip".
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, test=False):
        super(BasicBlock, self).__init__()
        # Stored for a planned fast-inference path; forward() does not read it yet.
        self.test = test
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the stride or channel count changes; then
        # a 1x1 projection + BN makes the shortcut match the residual's shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes))

        # Gate logits of shape (2, 1, 1).  [0.1, 4] biases the gate strongly
        # towards "on" at initialisation (softmax([0.1, 4])[1] ~ 0.98).
        # Created on the default device so that it follows the module when
        # the model is moved with .cuda()/.to(), and so that construction
        # does not require CUDA.
        self.w = nn.Parameter(
            torch.tensor([0.1, 4.0], dtype=torch.float32).view(2, 1, 1))
        self.gs = GumbleSoftmax()
        # NOTE(review): kept as-is; whether GumbleSoftmax needs to live on
        # the GPU depends on its implementation (defined elsewhere).
        self.gs.cuda()

    def forward(self, x, temperature=1):
        """Run the block and return ``(out, gate)``.

        Args:
            x: input feature map, assumed (N, C, H, W) -- dim 0 is the batch.
            temperature: Gumbel-softmax temperature forwarded to ``self.gs``.

        Returns:
            out: ReLU(shortcut(x) + gate * residual(x)).
            gate: per-sample gate value (entry 1 of the hard sample), used by
                the caller's target-rate loss.  Shape (N, 1, 1) assuming
                ``self.gs`` preserves the shape of its input.
        """
        # One copy of the shared logits per sample, so each sample gets its
        # own Gumbel draw.
        w = self.w.expand(x.shape[0], 2, 1, 1)
        w = self.gs(w, temp=temperature, force_hard=True)
        gate = w[:, 1]

        # TODO(chi): at test time, return self.shortcut(x) directly when the
        # gate is closed instead of computing the residual branch.

        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # unsqueeze(1) inserts the channel axis so the gate broadcasts over
        # the residual's channels and spatial positions.
        out = F.relu(self.shortcut(x) + out * gate.unsqueeze(1))
        return out, gate
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with an input-dependent gate.

    A small gating head maps the globally average-pooled input to two logits
    per sample; a hard Gumbel-softmax sample of them decides whether the
    residual branch is added to the shortcut (entry 1 of the sample
    multiplies the residual, so 1 means "execute" and 0 "skip").
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, test=False):
        super(Bottleneck, self).__init__()
        # Stored for a planned fast-inference path; forward() does not read it yet.
        self.test = test
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        # Identity shortcut unless the stride or channel count changes; then
        # a 1x1 projection + BN makes the shortcut match the residual's shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

        # Gate layers: 1x1 convolutions acting as fully connected layers on
        # the (N, C, 1, 1) pooled input.
        self.fc1 = nn.Conv2d(in_planes, 16, kernel_size=1)
        # BatchNorm2d, not BatchNorm1d: fc1's output is 4-D (N, 16, 1, 1),
        # which BatchNorm1d rejects in PyTorch >= 0.4.  On a 1x1 map both
        # normalise over the batch only, and the parameter/buffer names are
        # the same, so existing checkpoints still load.
        self.fc1bn = nn.BatchNorm2d(16)
        self.fc2 = nn.Conv2d(16, 2, kernel_size=1)
        # Initialise the bias of the last layer for an initial gate-opening
        # rate of about 85% (softmax([0.1, 2])[1] ~ 0.87).
        self.fc2.bias.data[0] = 0.1
        self.fc2.bias.data[1] = 2

        self.gs = GumbleSoftmax()
        self.gs.cuda()

    def forward(self, x, temperature=1):
        """Run the block and return ``(out, gate)``.

        Args:
            x: input feature map, assumed (N, C, H, W).
            temperature: Gumbel-softmax temperature forwarded to ``self.gs``.

        Returns:
            out: ReLU(shortcut(x) + gate * residual(x)).
            gate: per-sample gate value (entry 1 of the hard sample), used by
                the caller's target-rate loss.  Shape (N, 1, 1) assuming
                ``self.gs`` preserves the shape of its input.
        """
        # Gate head: global average pool -> fc1 -> BN -> ReLU -> fc2 (2 logits).
        # adaptive_avg_pool2d(x, 1) is a true global pool for any H x W;
        # avg_pool2d(x, x.size(2)) only worked for square feature maps.
        w = F.adaptive_avg_pool2d(x, 1)
        w = F.relu(self.fc1bn(self.fc1(w)))
        w = self.fc2(w)
        # Sample a hard (one-hot) gate from the Gumbel-softmax module.
        w = self.gs(w, temp=temperature, force_hard=True)
        gate = w[:, 1]

        # TODO(chi): For fast inference, check the gate decision and return
        # self.shortcut(x) directly when the residual branch is off.

        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = self.bn3(self.conv3(out))
        # The sum is a fresh tensor, so the in-place ReLU cannot clobber x.
        out = F.relu(self.shortcut(x) + out * gate.unsqueeze(1), inplace=True)
        return out, gate
# Example #3
# 0
class BasicBlock(nn.Module):
    """ResNet basic block with an input-dependent gate on the residual branch.

    A small gating head maps the globally average-pooled input to two logits
    per sample.  Entry 1 of the resulting gate multiplies the residual
    branch, so 1 means "execute" and 0 "skip".  ``forward`` supports several
    ways of turning the logits into a gate (see ``gate_mode``).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, test=False):
        super(BasicBlock, self).__init__()
        # Stored for a planned fast-inference path; forward() does not read it yet.
        self.test = test
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the stride or channel count changes; then
        # a 1x1 projection + BN makes the shortcut match the residual's shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

        # Gate layers: 1x1 convolutions acting as fully connected layers on
        # the (N, C, 1, 1) pooled input.
        self.fc1 = nn.Conv2d(in_planes, 16, kernel_size=1)
        # BatchNorm2d, not BatchNorm1d: fc1's output is 4-D (N, 16, 1, 1),
        # which BatchNorm1d rejects in PyTorch >= 0.4.  On a 1x1 map both
        # normalise over the batch only, and the parameter/buffer names are
        # the same, so existing checkpoints still load.
        self.fc1bn = nn.BatchNorm2d(16)
        self.fc2 = nn.Conv2d(16, 2, kernel_size=1)
        # Initialise the bias of the last layer for an initial gate-opening
        # rate of about 85% (softmax([0.1, 2])[1] ~ 0.87).
        self.fc2.bias.data[0] = 0.1
        self.fc2.bias.data[1] = 2
        self.gs = GumbleSoftmax()
        self.gs.cuda()

    def forward(self, x, temperature=1, gate_mode='stochastic'):
        """Run the block and return ``(out, gate)``.

        Args:
            x: input feature map, assumed (N, C, H, W).
            temperature: Gumbel-softmax temperature (``'stochastic'`` only).
            gate_mode: how the gate logits become a gate:
                ``'stochastic'`` -- hard Gumbel-softmax sample via ``self.gs``;
                ``'argmax'`` -- 1 where logit 1 is the larger logit, else 0;
                ``'always_on'`` -- always 1.
                The ``'argmax'`` and ``'always_on'`` gates carry no gradient.

        Returns:
            out: ReLU(shortcut(x) + gate * residual(x)).
            gate: per-sample gate values of shape (N, 1, 1) (for
                ``'stochastic'``, assuming ``self.gs`` preserves its input
                shape), used by the caller's target-rate loss.

        Raises:
            AssertionError: if ``gate_mode`` is not one of the modes above.
        """
        assert gate_mode in ('stochastic', 'always_on', 'argmax'), gate_mode

        # Gate head: global average pool -> fc1 -> BN -> ReLU -> fc2 (2 logits).
        # adaptive_avg_pool2d(x, 1) is a true global pool for any H x W;
        # avg_pool2d(x, x.size(2)) only worked for square feature maps.
        w = F.adaptive_avg_pool2d(x, 1)
        w = F.relu(self.fc1bn(self.fc1(w)))
        w = self.fc2(w)

        # Each branch yields an (N, 1, 1, 1) multiplier that broadcasts over
        # the residual's channels and spatial positions.
        if gate_mode == 'argmax':
            # keepdim=True already gives (N, 1, 1, 1); cast the 0/1 indices to
            # w's floating dtype so they can scale the activations.
            _, max_value_indexes = w.detach().max(1, keepdim=True)
            output_multiplier = max_value_indexes.type_as(w)
        elif gate_mode == 'stochastic':
            w = self.gs(w, temp=temperature, force_hard=True)
            output_multiplier = w[:, 1].unsqueeze(1)
        elif gate_mode == 'always_on':
            # ones_like keeps the gate on the same device/dtype as the logits.
            output_multiplier = torch.ones_like(w[:, 1].unsqueeze(1))
        else:
            # Only reachable if a mode is added to the assert but not here.
            raise AssertionError('unhandled gate_mode: %r' % (gate_mode,))

        # TODO(chi): at test time, return self.shortcut(x) directly when the
        # gate is closed instead of computing the residual branch.

        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = F.relu(self.shortcut(x) + out * output_multiplier)
        return out, output_multiplier.squeeze(1)
# Example #4
# 0
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with an input-dependent gate.

    A small gating head maps the globally average-pooled input to two logits
    per sample.  Entry 1 of the resulting gate multiplies the residual
    branch, so 1 means "execute" and 0 "skip".  ``forward`` supports several
    ways of turning the logits into a gate (see ``gate_mode``).
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, test=False):
        super(Bottleneck, self).__init__()
        # Stored for a planned fast-inference path; forward() does not read it yet.
        self.test = test
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        # Identity shortcut unless the stride or channel count changes; then
        # a 1x1 projection + BN makes the shortcut match the residual's shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

        # Gate layers: 1x1 convolutions acting as fully connected layers on
        # the (N, C, 1, 1) pooled input.
        self.fc1 = nn.Conv2d(in_planes, 16, kernel_size=1)
        # BatchNorm2d, not BatchNorm1d: fc1's output is 4-D (N, 16, 1, 1),
        # which BatchNorm1d rejects in PyTorch >= 0.4.  On a 1x1 map both
        # normalise over the batch only, and the parameter/buffer names are
        # the same, so existing checkpoints still load.
        self.fc1bn = nn.BatchNorm2d(16)
        self.fc2 = nn.Conv2d(16, 2, kernel_size=1)
        # Initialise the bias of the last layer for an initial gate-opening
        # rate of about 85% (softmax([0.1, 2])[1] ~ 0.87).
        self.fc2.bias.data[0] = 0.1
        self.fc2.bias.data[1] = 2

        self.gs = GumbleSoftmax()
        self.gs.cuda()

    def forward(self, x, temperature=1, gate_mode='stochastic', threshold=.5):
        """Run the block and return ``(out, gate, w1bn)``.

        Args:
            x: input feature map, assumed (N, C, H, W).
            temperature: Gumbel-softmax temperature (``'stochastic'`` only).
            gate_mode: how the gate logits become a gate:
                ``'stochastic'`` -- hard Gumbel-softmax sample via ``self.gs``;
                ``'argmax'`` -- 1 where logit 1 is the larger logit, else 0;
                ``'threshold'`` -- 1 where logit 1 exceeds ``threshold``;
                ``'always_on'`` -- always 1.
                All modes except ``'stochastic'`` produce detached gates.
            threshold: cut-off applied to logit 1 in ``'threshold'`` mode.

        Returns:
            out: ReLU(shortcut(x) + gate * residual(x)).
            gate: per-sample gate values of shape (N, 1, 1) (for
                ``'stochastic'``, assuming ``self.gs`` preserves its input
                shape), used by the caller's target-rate loss.
            w1bn: output of ``fc1`` before batch norm, shape (N, 16, 1, 1).

        Raises:
            AssertionError: if ``gate_mode`` is not one of the modes above.
        """
        # 'threshold' was previously missing here, which made its branch
        # below unreachable and the ``threshold`` argument dead.
        assert gate_mode in ('stochastic', 'always_on', 'argmax', 'threshold'), gate_mode

        # Gate head: global average pool -> fc1 -> BN -> ReLU -> fc2 (2 logits).
        # adaptive_avg_pool2d(x, 1) is a true global pool for any H x W;
        # avg_pool2d(x, x.size(2)) only worked for square feature maps.
        w = F.adaptive_avg_pool2d(x, 1)
        w1bn = self.fc1(w)
        w = F.relu(self.fc1bn(w1bn))
        w = self.fc2(w)

        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = self.bn3(self.conv3(out))

        # Each branch yields an (N, 1, 1, 1) multiplier that broadcasts over
        # the residual's channels and spatial positions.
        if gate_mode == 'argmax':
            _, max_value_indexes = w.detach().max(1, keepdim=True)
            output_multiplier = max_value_indexes.type_as(out)
        elif gate_mode == 'threshold':
            # NOTE(review): compares the raw logit, not a probability, so with
            # the default threshold of .5 this differs from 'argmax' -- confirm.
            output_multiplier = torch.gt(w[:, 1].detach(), threshold).unsqueeze(1).type_as(out)
        elif gate_mode == 'stochastic':
            w = self.gs(w, temp=temperature, force_hard=True)
            output_multiplier = w[:, 1].unsqueeze(1)
        elif gate_mode == 'always_on':
            # ones_like keeps the gate on the same device/dtype as the logits.
            output_multiplier = torch.ones_like(w[:, 1].unsqueeze(1))
        else:
            # Only reachable if a mode is added to the assert but not here.
            raise AssertionError('unhandled gate_mode: %r' % (gate_mode,))

        # The sum is a fresh tensor, so the in-place ReLU cannot clobber x.
        out = F.relu(self.shortcut(x) + out * output_multiplier, inplace=True)
        return out, output_multiplier.squeeze(1), w1bn
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with a learned global gate.

    The gate logits are a single learned parameter ``self.w`` (two values,
    shared by all inputs).  Entry 1 of the resulting gate multiplies the
    residual branch, so 1 means "execute" and 0 "skip".  ``forward``
    supports several ways of turning the logits into a gate (see
    ``gate_mode``).
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, test=False):
        super(Bottleneck, self).__init__()
        # Stored for a planned fast-inference path; forward() does not read it yet.
        self.test = test
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        # Identity shortcut unless the stride or channel count changes; then
        # a 1x1 projection + BN makes the shortcut match the residual's shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes))

        # Gate logits of shape (2, 1, 1).  [0.1, 4] biases the gate strongly
        # towards "on" at initialisation (softmax([0.1, 4])[1] ~ 0.98).
        # Created on the default device so that it follows the module when
        # the model is moved with .cuda()/.to(), and so that construction
        # does not require CUDA.
        self.w = nn.Parameter(
            torch.tensor([0.1, 4.0], dtype=torch.float32).view(2, 1, 1))
        self.gs = GumbleSoftmax()
        # NOTE(review): kept as-is; whether GumbleSoftmax needs to live on
        # the GPU depends on its implementation (defined elsewhere).
        self.gs.cuda()

    def forward(self, x, temperature=1, gate_mode='stochastic', prob=1):
        """Run the block and return ``(out, gate)``.

        Args:
            x: input feature map, assumed (N, C, H, W) -- dim 0 is the batch.
            temperature: Gumbel-softmax temperature for the sampling modes.
            gate_mode: how the gate logits become a gate:
                ``'stochastic'`` -- hard Gumbel-softmax sample via ``self.gs``;
                ``'stochastic-variable'`` -- like ``'stochastic'`` but the
                logits are first scaled by ``prob`` (see note below);
                ``'argmax'`` -- 1 where logit 1 is the larger logit, else 0;
                ``'always_on'`` -- always 1.
                The ``'argmax'`` and ``'always_on'`` gates carry no gradient.
            prob: scale applied to the logits in ``'stochastic-variable'``.

        Returns:
            out: ReLU(shortcut(x) + gate * residual(x)).
            gate: per-sample gate values of shape (N, 1, 1) in every mode
                (for the sampling modes, assuming ``self.gs`` preserves its
                input shape), used by the caller's target-rate loss.

        Raises:
            AssertionError: if ``gate_mode`` is not one of the modes above.
        """
        assert gate_mode in ('stochastic', 'always_on', 'argmax',
                             'stochastic-variable'), gate_mode

        # One copy of the shared logits per sample: (N, 2, 1, 1).
        w = self.w.expand(x.shape[0], 2, 1, 1)

        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = self.bn3(self.conv3(out))

        # Each branch yields an (N, 1, 1, 1) multiplier that broadcasts over
        # the residual's channels and spatial positions.
        if gate_mode == 'argmax':
            _, max_value_indexes = w.detach().max(1, keepdim=True)
            output_multiplier = max_value_indexes.type_as(out)
        elif gate_mode == 'stochastic':
            w = self.gs(w, temp=temperature, force_hard=True)
            output_multiplier = w[:, 1].unsqueeze(1)
        elif gate_mode == 'stochastic-variable':
            # NOTE(review): the second draw uses the first hard sample as its
            # logits, i.e. Gumbel noise is applied twice -- confirm intended.
            w_prob = self.gs(w * prob, temp=temperature, force_hard=True)
            w = self.gs(w_prob, temp=temperature, force_hard=True)
            output_multiplier = w[:, 1].unsqueeze(1)
        elif gate_mode == 'always_on':
            # Same (N, 1, 1, 1) shape as the other modes, so the returned gate
            # is (N, 1, 1) here too; ones_like matches the logits' device/dtype.
            output_multiplier = torch.ones_like(w[:, 1].unsqueeze(1))
        else:
            # Only reachable if a mode is added to the assert but not here.
            raise AssertionError('unhandled gate_mode: %r' % (gate_mode,))

        # The sum is a fresh tensor, so the in-place ReLU cannot clobber x.
        out = F.relu(self.shortcut(x) + out * output_multiplier, inplace=True)
        return out, output_multiplier.squeeze(1)