Example #1
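# Imports assumed by this example: torch is standard, while resnet34 (a
# variant that also returns conv features and accepts a noBN flag), ConvLSTM,
# MotionSegmentationBlock and get_cam_visualisation are project-local modules
# whose import paths below are guesses.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # no-op wrapper in PyTorch >= 0.4

from resnet import resnet34                               # assumed path
from conv_lstm import ConvLSTM                            # assumed path
from motion_segmentation import MotionSegmentationBlock   # assumed path
from cam_utils import get_cam_visualisation               # assumed path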
class AttentionModel(nn.Module):
    def __init__(self,
                 num_classes=61,
                 mem_size=512,
                 no_cam=False,
                 enable_motion_segmentation=False):
        super(AttentionModel, self).__init__()

        self.num_classes = num_classes
        self.noCam = no_cam
        self.mem_size = mem_size
        self.enable_motion_segmentation = enable_motion_segmentation

        # Backbone that returns logits plus conv features (with and without BN)
        self.resnet = resnet34(pretrained=True, noBN=True)
        # Reuse the fc weights to compute Class Activation Maps (CAM)
        self.weight_softmax = self.resnet.fc.weight
        self.lstm_cell = ConvLSTM(512, mem_size)
        self.avgpool = nn.AvgPool2d(7)
        self.dropout = nn.Dropout(0.7)
        self.fc = nn.Linear(mem_size, self.num_classes)
        self.classifier = nn.Sequential(self.dropout, self.fc)

        self.motion_segmentation = MotionSegmentationBlock()

        self._custom_train_mode = True

    def train(self, mode=True):
        # Valid modes: True (full training), 'stage1', 'stage2', False (eval)
        correct_values = {True, 'stage2', 'stage1', False}

        if mode not in correct_values:
            raise ValueError('Invalid mode, correct values are: ' +
                             ', '.join(map(str, correct_values)))

        self._custom_train_mode = mode

        # Enable full training mode only when mode is True
        super().train(mode is True)

        self.resnet.train(mode)
        self.lstm_cell.train(mode)
        if mode == 'stage2' or mode is True:
            self.motion_segmentation.train(True)
        if mode is not False:
            self.classifier.train(True)

    def get_training_parameters(self, name='all'):
        train_params = []
        train_params_ms = []

        # First disable gradients everywhere, then re-enable them only on the
        # parameters that are actually being trained
        for params in self.parameters():
            params.requires_grad = False

        # Each submodule's get_training_parameters() is responsible for
        # re-enabling gradients on its own parameters
        train_params += self.resnet.get_training_parameters()
        train_params += self.lstm_cell.get_training_parameters()
        # Train the final layer in every stage, unless we are not training at all
        if self._custom_train_mode is not False:
            for params in self.classifier.parameters():
                params.requires_grad = True
                train_params += [params]

        train_params_ms = self.motion_segmentation.get_training_parameters()

        if name == 'all':
            return train_params + train_params_ms
        elif name == 'main':
            return train_params
        elif name == 'ms':
            return train_params_ms
        else:
            raise ValueError('Invalid parameter group name: ' + name)

    def load_weights(self, file_path):
        model_dict = torch.load(file_path)
        if 'model_state_dict' in model_dict:
            self.load_state_dict(model_dict['model_state_dict'])
        else:
            self.load_state_dict(model_dict)

    def forward(self, inputVariable):
        # inputVariable: (seq_len, batch, C, H, W); the ConvLSTM state is a
        # pair of (batch, mem_size, 7, 7) tensors (hidden and cell state)
        state = (Variable(
            torch.zeros((inputVariable.size(1), self.mem_size, 7, 7)).cuda()),
                 Variable(
                     torch.zeros(
                         (inputVariable.size(1), self.mem_size, 7, 7)).cuda()))

        ms_feats = None
        if self.enable_motion_segmentation:
            # Per-timestep motion-segmentation outputs over the 7x7 grid
            # (49 spatial locations, 2 channels)
            ms_feats = Variable(
                torch.zeros(inputVariable.size(0), inputVariable.size(1),
                            49 * 2).cuda())

        for t in range(inputVariable.size(0)):
            logit, feature_conv, feature_convNBN = self.resnet(
                inputVariable[t])

            # CAM-style attention: weight the conv features with the fc row of
            # the top-scoring class, then softmax over the 7x7 locations
            bz, nc, h, w = feature_conv.size()
            feature_conv1 = feature_conv.view(bz, nc, h * w)
            probs, idxs = logit.sort(1, True)
            class_idx = idxs[:, 0]
            cam = torch.bmm(self.weight_softmax[class_idx].unsqueeze(1),
                            feature_conv1)
            attentionMAP = F.softmax(cam.squeeze(1), dim=1)
            attentionMAP = attentionMAP.view(attentionMAP.size(0), 1, 7, 7)
            attentionFeat = feature_convNBN * attentionMAP.expand_as(
                feature_conv)

            if self.enable_motion_segmentation:
                ms_feats[t] = self.motion_segmentation(feature_convNBN)

            # Feed either raw or attention-weighted features to the ConvLSTM
            if self.noCam:
                state = self.lstm_cell(feature_convNBN, state)
            else:
                state = self.lstm_cell(attentionFeat, state)

        # Classify from the final cell state, average-pooled to a flat vector
        feats1 = self.avgpool(state[1]).view(state[1].size(0), -1)
        feats = self.classifier(feats1)

        return {
            'classifications': feats,
            'ms_feats': ms_feats,
            'lstm_feats': feats1
        }

    def get_class_activation_id(self, inputVariable):
        logit, _, _ = self.resnet(inputVariable)
        return logit

    def get_cam_visualisation(self, input_pil_image, preprocess_for_viz,
                              preprocess_for_model):
        return get_cam_visualisation(self.resnet, self.weight_softmax,
                                     input_pil_image, preprocess_for_viz,
                                     preprocess_for_model)
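
# Usage sketch (not part of the original source): a minimal smoke test,
# assuming a CUDA device and 224x224 inputs so the backbone yields 7x7 maps.
if __name__ == '__main__':
    model = AttentionModel(num_classes=61, mem_size=512).cuda()
    model.train('stage1')                  # one of: True, 'stage1', 'stage2', False
    params = model.get_training_parameters(name='main')
    frames = torch.randn(8, 4, 3, 224, 224).cuda()  # 8 timesteps, batch of 4
    out = model(frames)
    print(out['classifications'].shape)    # expected: torch.Size([4, 61])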
Example #2
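# Imports assumed by this example (same caveats as in Example #1): resnet34
# and ConvLSTM are project-local modules with guessed import paths.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from resnet import resnet34      # assumed path
from conv_lstm import ConvLSTM   # assumed path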
class NewAttentionModelBi(nn.Module):
    def __init__(self, num_classes=61, mem_size=512, no_cam=False):
        super(NewAttentionModelBi, self).__init__()

        self.num_classes = num_classes
        self.noCam = no_cam
        self.mem_size = mem_size

        self.resnet = resnet34(pretrained=True, noBN=True)

        # Learned per-stream attention vectors. Note: these are plain tensors
        # rather than nn.Parameter, so they are excluded from state_dict()
        # and are added to the training parameters manually below.
        self.attention_rgb = Variable(
            (torch.FloatTensor(512).normal_(0, .05)).unsqueeze(0).cuda())
        self.attention_flow = Variable(
            (torch.FloatTensor(512).normal_(0, .05)).unsqueeze(0).cuda())

        # RGB and flow features are concatenated, hence 1024 input channels
        self.lstm_cell = ConvLSTM(1024, mem_size)

        self.avgpool = nn.AvgPool2d(7)
        self.dropout = nn.Dropout(0.7)
        self.fc = nn.Linear(mem_size, self.num_classes)
        self.classifier = nn.Sequential(self.dropout, self.fc)

        self._custom_train_mode = True

    def train(self, mode=True):
        # Valid modes: True (full training), 'stage1', 'stage2', False (eval)
        correct_values = {True, 'stage2', 'stage1', False}

        if mode not in correct_values:
            raise ValueError('Invalid mode, correct values are: ' +
                             ', '.join(map(str, correct_values)))

        self._custom_train_mode = mode

        super().train(mode is True)

        self.resnet.train(mode)
        self.lstm_cell.train(mode)
        if mode is not False:
            self.classifier.train(True)

    def get_training_parameters(self):
        train_params = []

        for params in self.parameters():
            params.requires_grad = False

        train_params += self.resnet.get_training_parameters()

        train_params += self.lstm_cell.get_training_parameters()

        if self._custom_train_mode is not False:
            for params in self.classifier.parameters():
                params.requires_grad = True
                train_params += [params]
            self.attention_rgb.requires_grad = True
            train_params += [self.attention_rgb]
            self.attention_flow.requires_grad = True
            train_params += [self.attention_flow]

        return train_params

    def load_weights(self, file_path):
        model_dict = torch.load(file_path)
        if 'model_state_dict' in model_dict:
            self.load_state_dict(model_dict['model_state_dict'])
        else:
            self.load_state_dict(model_dict)

    def get_resnet_output_feats(self, resnet, attention, input_frames):
        logit, feature_conv, feature_convNBN = resnet(input_frames)

        # With CAM disabled, return the raw (non-BN) conv features
        if self.noCam:
            return feature_convNBN

        bz, nc, h, w = feature_conv.size()
        feature_conv1 = feature_conv.view(bz, nc, h * w)

        # Repeat the single learned attention vector across the batch and
        # compute a CAM-style spatial attention map over the 7x7 locations
        cam = torch.bmm(attention[[0] * input_frames.size(0)].unsqueeze(1),
                        feature_conv1)
        attentionMAP = F.softmax(cam.squeeze(1), dim=1)
        attentionMAP = attentionMAP.view(attentionMAP.size(0), 1, 7, 7)
        attentionFeat = feature_convNBN * attentionMAP.expand_as(feature_conv)

        return attentionFeat

    def forward(self, rgb_frames, flow_frames):
        # rgb_frames / flow_frames: (seq_len, batch, C, H, W); the ConvLSTM
        # state is a pair of (batch, mem_size, 7, 7) tensors
        state = (Variable(
            torch.zeros((rgb_frames.size(1), self.mem_size, 7, 7)).cuda()),
                 Variable(
                     torch.zeros(
                         (rgb_frames.size(1), self.mem_size, 7, 7)).cuda()))

        for t in range(rgb_frames.size(0)):
            rgb_feats = self.get_resnet_output_feats(self.resnet,
                                                     self.attention_rgb,
                                                     rgb_frames[t])
            flow_feats = self.get_resnet_output_feats(self.resnet,
                                                      self.attention_flow,
                                                      flow_frames[t])
            # Fuse the two streams channel-wise before the ConvLSTM update
            state = self.lstm_cell(torch.cat((rgb_feats, flow_feats), dim=1),
                                   state)

        feats1 = self.avgpool(state[1]).view(state[1].size(0), -1)
        feats = self.classifier(feats1)

        return {'classifications': feats}
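
# Usage sketch (not part of the original source): a minimal two-stream smoke
# test, assuming a CUDA device; both streams use (seq_len, batch, 3, 224, 224).
if __name__ == '__main__':
    model = NewAttentionModelBi(num_classes=61, mem_size=512).cuda()
    model.train(True)
    rgb = torch.randn(8, 4, 3, 224, 224).cuda()
    flow = torch.randn(8, 4, 3, 224, 224).cuda()
    out = model(rgb, flow)
    print(out['classifications'].shape)  # expected: torch.Size([4, 61])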