class AttentionModel(nn.Module):
    """Attention-based video classifier.

    A ResNet-34 backbone produces per-frame features; a class-activation-map
    (CAM) attention mask re-weights them; a ConvLSTM aggregates over time and
    a dropout + linear classifier produces per-clip logits. Optionally a
    motion-segmentation head runs on the backbone features of every frame.
    """

    def __init__(self, num_classes=61, mem_size=512, no_cam=False,
                 enable_motion_segmentation=False):
        """
        Args:
            num_classes: number of output classes of the final classifier.
            mem_size: channel size of the ConvLSTM memory.
            no_cam: if True, skip CAM attention and feed raw features to the LSTM.
            enable_motion_segmentation: if True, also compute motion-segmentation
                features in ``forward``.
        """
        super(AttentionModel, self).__init__()
        self.num_classes = num_classes
        self.noCam = no_cam
        self.mem_size = mem_size
        self.enable_motion_segmentation = enable_motion_segmentation
        self.resnet = resnet34(pretrained=True, noBN=True)
        # fc weights of the backbone are reused as CAM projection weights.
        self.weight_softmax = self.resnet.fc.weight
        self.lstm_cell = ConvLSTM(512, mem_size)
        self.avgpool = nn.AvgPool2d(7)
        self.dropout = nn.Dropout(0.7)
        self.fc = nn.Linear(mem_size, self.num_classes)
        self.classifier = nn.Sequential(self.dropout, self.fc)
        self.motion_segmentation = MotionSegmentationBlock()
        self._custom_train_mode = True

    def train(self, mode=True):
        """Set the (staged) training mode.

        Accepted modes: ``True`` (full training), ``'stage1'``, ``'stage2'``
        (partial training stages, interpreted by the sub-modules), ``False``
        (evaluation).

        Raises:
            ValueError: if ``mode`` is not one of the accepted values.
        """
        correct_values = {True, 'stage2', 'stage1', False}
        if mode not in correct_values:
            # BUG FIX: the set contains booleans, so ' '.join(correct_values)
            # raised TypeError instead of the intended ValueError. Stringify
            # each value before joining.
            raise ValueError('Invalid modes, correct values are: '
                             + ' '.join(str(v) for v in correct_values))
        self._custom_train_mode = mode
        # Put the nn.Module machinery in full training mode only when
        # mode == True (string stages keep it in eval at this level).
        super().train(mode == True)
        # Sub-modules understand the staged modes themselves.
        self.resnet.train(mode)
        self.lstm_cell.train(mode)
        if mode == 'stage2' or mode == True:
            self.motion_segmentation.train(True)
        if mode != False:
            # The final classifier is trained in every stage.
            self.classifier.train(True)

    def get_training_parameters(self, name='all'):
        """Collect the parameters to optimize for the current training mode.

        Side effect: sets ``requires_grad`` to False on every parameter first,
        then the sub-modules (and this method, for the classifier) re-enable
        it only on the parameters they actually train.

        Args:
            name: 'all' (main + motion segmentation), 'main', or 'ms'.

        Returns:
            A list of parameters with ``requires_grad = True``.

        Raises:
            ValueError: if ``name`` is not one of 'all', 'main', 'ms'.
        """
        # Disable gradients everywhere first; each sub-module's
        # get_training_parameters is responsible for re-enabling its own.
        for params in self.parameters():
            params.requires_grad = False
        train_params = []
        train_params += self.resnet.get_training_parameters()
        train_params += self.lstm_cell.get_training_parameters()
        # Train the last layer in every stage, unless we are in eval mode.
        if self._custom_train_mode != False:
            for params in self.classifier.parameters():
                params.requires_grad = True
                train_params.append(params)
        train_params_ms = self.motion_segmentation.get_training_parameters()
        if name == 'all':
            return train_params + train_params_ms
        elif name == 'main':
            return train_params
        elif name == 'ms':
            return train_params_ms
        # ROBUSTNESS FIX: previously an unknown name silently returned None.
        raise ValueError("name must be one of 'all', 'main', 'ms'")

    def load_weights(self, file_path):
        """Load a checkpoint saved either as a bare state_dict or wrapped
        in a {'model_state_dict': ...} dictionary."""
        model_dict = torch.load(file_path)
        if 'model_state_dict' in model_dict:
            self.load_state_dict(model_dict['model_state_dict'])
        else:
            self.load_state_dict(model_dict)

    def forward(self, inputVariable):
        """Run the model over a clip.

        Args:
            inputVariable: frame tensor indexed as [time, batch, ...];
                assumed (T, B, C, H, W) with 7x7 backbone feature maps —
                TODO confirm against the data loader.

        Returns:
            dict with 'classifications' (B, num_classes) logits, 'ms_feats'
            (motion-segmentation features, or None when disabled) and
            'lstm_feats' (pooled ConvLSTM memory, B x mem_size).
        """
        state = (
            Variable(torch.zeros(
                (inputVariable.size(1), self.mem_size, 7, 7)).cuda()),
            Variable(torch.zeros(
                (inputVariable.size(1), self.mem_size, 7, 7)).cuda()),
        )
        ms_feats = None
        if self.enable_motion_segmentation:
            # One 49*2-dim motion-segmentation output per frame per sample.
            ms_feats = Variable(torch.zeros(
                inputVariable.size(0), inputVariable.size(1), 49 * 2).cuda())
        for t in range(inputVariable.size(0)):
            logit, feature_conv, feature_convNBN = self.resnet(
                inputVariable[t])
            bz, nc, h, w = feature_conv.size()
            feature_conv1 = feature_conv.view(bz, nc, h * w)
            # CAM: project features with the fc weights of the predicted
            # top-1 class, then softmax over spatial positions.
            _, idxs = logit.sort(1, True)
            class_idx = idxs[:, 0]
            cam = torch.bmm(self.weight_softmax[class_idx].unsqueeze(1),
                            feature_conv1)
            attentionMAP = F.softmax(cam.squeeze(1), dim=1)
            attentionMAP = attentionMAP.view(attentionMAP.size(0), 1, 7, 7)
            attentionFeat = feature_convNBN * attentionMAP.expand_as(
                feature_conv)
            if self.enable_motion_segmentation:
                ms_feats[t] = self.motion_segmentation(feature_convNBN)
            if self.noCam:
                state = self.lstm_cell(feature_convNBN, state)
            else:
                state = self.lstm_cell(attentionFeat, state)
        feats1 = self.avgpool(state[1]).view(state[1].size(0), -1)
        feats = self.classifier(feats1)
        return {
            'classifications': feats,
            'ms_feats': ms_feats,
            'lstm_feats': feats1
        }

    def get_class_activation_id(self, inputVariable):
        """Return the backbone's raw logits for a batch of frames."""
        logit, _, _ = self.resnet(inputVariable)
        return logit

    def get_cam_visualisation(self, input_pil_image, preprocess_for_viz,
                              preprocess_for_model):
        """Delegate CAM visualisation to the module-level helper."""
        return get_cam_visualisation(self.resnet, self.weight_softmax,
                                     input_pil_image, preprocess_for_viz,
                                     preprocess_for_model)
class NewAttentionModelBi(nn.Module):
    """Two-stream (RGB + optical flow) attention model.

    One shared ResNet-34 backbone processes both streams; each stream has its
    own learned attention vector used to build a spatial attention map. The
    two attended feature maps are concatenated channel-wise and aggregated by
    a ConvLSTM, then classified with dropout + a linear layer.
    """

    def __init__(self, num_classes=61, mem_size=512, no_cam=False):
        """
        Args:
            num_classes: number of output classes.
            mem_size: channel size of the ConvLSTM memory.
            no_cam: if True, skip attention and use raw backbone features.
        """
        super(NewAttentionModelBi, self).__init__()
        self.num_classes = num_classes
        self.noCam = no_cam
        self.mem_size = mem_size
        self.resnet = resnet34(pretrained=True, noBN=True)
        # NOTE(review): these are plain Variables, not nn.Parameters, so they
        # are NOT included in state_dict() and are not moved by .to()/.cuda()
        # on the module — confirm checkpoints handle them separately before
        # converting to nn.Parameter (which would change state_dict keys).
        self.attention_rgb = Variable(
            (torch.FloatTensor(512).normal_(0, .05)).unsqueeze(0).cuda())
        self.attention_flow = Variable(
            (torch.FloatTensor(512).normal_(0, .05)).unsqueeze(0).cuda())
        # 1024 input channels = 512 (RGB) + 512 (flow) concatenated.
        self.lstm_cell = ConvLSTM(1024, mem_size)
        self.avgpool = nn.AvgPool2d(7)
        self.dropout = nn.Dropout(0.7)
        self.fc = nn.Linear(mem_size, self.num_classes)
        self.classifier = nn.Sequential(self.dropout, self.fc)
        self._custom_train_mode = True

    def train(self, mode=True):
        """Set the (staged) training mode.

        Accepted modes: ``True`` (full training), ``'stage1'``, ``'stage2'``,
        ``False`` (evaluation).

        Raises:
            ValueError: if ``mode`` is not one of the accepted values.
        """
        correct_values = {True, 'stage2', 'stage1', False}
        if mode not in correct_values:
            # BUG FIX: joining a set containing booleans raised TypeError
            # instead of the intended ValueError; stringify values first.
            raise ValueError('Invalid modes, correct values are: '
                             + ' '.join(str(v) for v in correct_values))
        self._custom_train_mode = mode
        # Full nn.Module training mode only when mode == True.
        super().train(mode == True)
        self.resnet.train(mode)
        self.lstm_cell.train(mode)
        if mode != False:
            # The classifier is trained in every stage.
            self.classifier.train(True)

    def get_training_parameters(self):
        """Collect the parameters to optimize for the current training mode.

        Side effect: sets ``requires_grad`` to False on every parameter
        first; sub-modules (and this method) re-enable it only on the
        parameters they train.

        Returns:
            A list of parameters/tensors with ``requires_grad = True``.
        """
        train_params = []
        for params in self.parameters():
            params.requires_grad = False
        train_params += self.resnet.get_training_parameters()
        train_params += self.lstm_cell.get_training_parameters()
        if self._custom_train_mode != False:
            for params in self.classifier.parameters():
                params.requires_grad = True
                train_params.append(params)
            # The per-stream attention vectors are trained alongside the
            # classifier in every training stage.
            self.attention_rgb.requires_grad = True
            train_params.append(self.attention_rgb)
            self.attention_flow.requires_grad = True
            train_params.append(self.attention_flow)
        return train_params

    def load_weights(self, file_path):
        """Load a checkpoint saved either as a bare state_dict or wrapped
        in a {'model_state_dict': ...} dictionary."""
        model_dict = torch.load(file_path)
        if 'model_state_dict' in model_dict:
            self.load_state_dict(model_dict['model_state_dict'])
        else:
            self.load_state_dict(model_dict)

    def get_resnet_output_feats(self, resnet, attention, input_frames):
        """Run one stream through the backbone and apply its attention map.

        Args:
            resnet: the backbone to use (the shared self.resnet in practice).
            attention: (1, 512) learned attention vector for this stream.
            input_frames: one time-step's batch of frames.

        Returns:
            Attended (or raw, when noCam) backbone feature map.
        """
        logit, feature_conv, feature_convNBN = resnet(input_frames)
        if self.noCam:
            return feature_convNBN
        bz, nc, h, w = feature_conv.size()
        feature_conv1 = feature_conv.view(bz, nc, h * w)
        # Replicate the single attention vector across the batch, then take
        # its dot product with every spatial position of the feature map.
        cam = torch.bmm(attention[[0] * input_frames.size(0)].unsqueeze(1),
                        feature_conv1)
        attentionMAP = F.softmax(cam.squeeze(1), dim=1)
        attentionMAP = attentionMAP.view(attentionMAP.size(0), 1, 7, 7)
        return feature_convNBN * attentionMAP.expand_as(feature_conv)

    def forward(self, rgb_frames, flow_frames):
        """Run the two-stream model over a clip.

        Args:
            rgb_frames: RGB stream indexed as [time, batch, ...].
            flow_frames: flow stream with the same time/batch layout.

        Returns:
            dict with 'classifications' (B, num_classes) logits.
        """
        state = (
            Variable(torch.zeros(
                (rgb_frames.size(1), self.mem_size, 7, 7)).cuda()),
            Variable(torch.zeros(
                (rgb_frames.size(1), self.mem_size, 7, 7)).cuda()),
        )
        for t in range(rgb_frames.size(0)):
            rgb_feats = self.get_resnet_output_feats(
                self.resnet, self.attention_rgb, rgb_frames[t])
            flow_feats = self.get_resnet_output_feats(
                self.resnet, self.attention_flow, flow_frames[t])
            # Concatenate the two streams channel-wise for the ConvLSTM.
            state = self.lstm_cell(torch.cat((rgb_feats, flow_feats), dim=1),
                                   state)
        feats1 = self.avgpool(state[1]).view(state[1].size(0), -1)
        feats = self.classifier(feats1)
        return {'classifications': feats}