class CAC(BaseModel):
    """Context-aware semi-supervised segmentation model.

    Wraps an encoder backbone ('deeplab_v3+' or 'psp') with a 1x1 pixel
    classifier.  In 'semi' mode it additionally owns a projection head and
    computes a directional pixel-wise contrastive loss on the overlapping
    region of two crops of each unlabeled image, using a FIFO feature/
    pseudo-label memory bank gathered across GPUs.
    """

    def __init__(self, num_classes, conf, sup_loss=None, ignore_index=None,
                 testing=False, pretrained=True):
        """
        Args:
            num_classes: number of segmentation classes.
            conf: dict-like config; must set exactly one of conf['supervised']
                / conf['semi'], plus backbone/loss hyper-parameters.
            sup_loss: supervised loss callable
                (output, target, ignore_index=..., temperature=...).
            ignore_index: label value ignored by the supervised loss.
            testing: unused here; kept for interface compatibility.
            pretrained: forwarded to the 'psp' Encoder.
        """
        super(CAC, self).__init__()
        assert int(conf['supervised']) + int(conf['semi']) == 1, 'one mode only'
        if conf['supervised']:
            self.mode = 'supervised'
        elif conf['semi']:
            self.mode = 'semi'
        else:
            # BUGFIX: previously formatted self.mode, which is not assigned yet
            # on this path and would raise AttributeError instead of the
            # intended ValueError.  Report the offending config instead.
            raise ValueError('No such mode choice {}'.format(conf))

        self.ignore_index = ignore_index
        self.num_classes = num_classes
        self.sup_loss_w = conf['supervised_w']
        self.sup_loss = sup_loss
        self.downsample = conf['downsample']
        self.backbone = conf['backbone']
        self.layers = conf['layers']
        self.out_dim = conf['out_dim']
        self.proj_final_dim = conf['proj_final_dim']

        assert self.layers in [50, 101]
        if self.backbone == 'deeplab_v3+':
            self.encoder = DeepLab_v3p(backbone='resnet{}'.format(self.layers))
            self.classifier = nn.Sequential(
                nn.Dropout(0.1),
                nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
            # Re-initialize the classifier head only (encoder keeps its own init).
            for m in self.classifier.modules():
                if isinstance(m, nn.Conv2d):
                    torch.nn.init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.SyncBatchNorm):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
        elif self.backbone == 'psp':
            self.encoder = Encoder(pretrained=pretrained)
            self.classifier = nn.Conv2d(self.out_dim, num_classes,
                                        kernel_size=1, stride=1)
        else:
            raise ValueError("No such backbone {}".format(self.backbone))

        if self.mode == 'semi':
            # Projection head mapping encoder features to the contrastive
            # embedding space.
            self.project = nn.Sequential(
                nn.Conv2d(self.out_dim, self.out_dim, kernel_size=1, stride=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.out_dim, self.proj_final_dim,
                          kernel_size=1, stride=1))

            self.weight_unsup = conf['weight_unsup']
            self.temp = conf['temp']
            self.epoch_start_unsup = conf['epoch_start_unsup']
            self.selected_num = conf['selected_num']
            self.step_save = conf['step_save']       # memory-bank capacity (in steps)
            self.step_count = 0
            self.feature_bank = []                   # FIFO bank of gathered features
            self.pseudo_label_bank = []              # matching pseudo labels
            self.pos_thresh_value = conf['pos_thresh_value']
            self.stride = conf['stride']

    def forward(self, x_l=None, target_l=None, x_ul=None, target_ul=None,
                curr_iter=None, epoch=None, gpu=None, gt_l=None,
                ul1=None, br1=None, ul2=None, br2=None, flip=None):
        """Compute predictions (eval) or losses (train).

        Args:
            x_l, target_l: labeled batch and targets.
            x_ul: unlabeled batch of paired crops, [B, 2, 3, H, W].
            ul1/br1, ul2/br2: per-crop corner coordinates; names suggest
                upper-left / bottom-right of the overlap region in input
                pixels (divided by 8 below to map to feature stride) —
                TODO confirm against the caller.
            flip: per-view, per-sample horizontal-flip flags
                (presumably flip[view][idx] — verify against the dataloader).

        Returns:
            Eval: upsampled logits.  Train: (total_loss, curr_losses, outputs).
        """
        if not self.training:
            enc = self.encoder(x_l)
            enc = self.classifier(enc)
            return F.interpolate(enc, size=x_l.size()[2:], mode='bilinear',
                                 align_corners=True)

        if self.mode == 'supervised':
            enc = self.encoder(x_l)
            enc = self.classifier(enc)
            output_l = F.interpolate(enc, size=x_l.size()[2:],
                                     mode='bilinear', align_corners=True)
            loss_sup = self.sup_loss(output_l, target_l,
                                     ignore_index=self.ignore_index,
                                     temperature=1.0) * self.sup_loss_w
            curr_losses = {'loss_sup': loss_sup}
            outputs = {'sup_pred': output_l}
            total_loss = loss_sup
            return total_loss, curr_losses, outputs

        elif self.mode == 'semi':
            # ---- supervised part (same as 'supervised' mode) ----
            enc = self.encoder(x_l)
            enc = self.classifier(enc)
            output_l = F.interpolate(enc, size=x_l.size()[2:],
                                     mode='bilinear', align_corners=True)
            loss_sup = self.sup_loss(output_l, target_l,
                                     ignore_index=self.ignore_index,
                                     temperature=1.0) * self.sup_loss_w
            curr_losses = {'loss_sup': loss_sup}
            outputs = {'sup_pred': output_l}
            total_loss = loss_sup

            # Warm-up: skip the unsupervised term for the first epochs.
            if epoch < self.epoch_start_unsup:
                return total_loss, curr_losses, outputs

            # ---- unsupervised part ----
            # x_ul: [batch_size, 2, 3, H, W] -> two crops of the same image.
            x_ul1 = x_ul[:, 0, :, :, :]
            x_ul2 = x_ul[:, 1, :, :, :]

            enc_ul1 = self.encoder(x_ul1)
            if self.downsample:
                enc_ul1 = F.avg_pool2d(enc_ul1, kernel_size=2, stride=2)
            output_ul1 = self.project(enc_ul1)  # [b, c, h, w]
            output_ul1 = F.normalize(output_ul1, 2, 1)  # L2-normalize channels

            enc_ul2 = self.encoder(x_ul2)
            if self.downsample:
                enc_ul2 = F.avg_pool2d(enc_ul2, kernel_size=2, stride=2)
            output_ul2 = self.project(enc_ul2)  # [b, c, h, w]
            output_ul2 = F.normalize(output_ul2, 2, 1)

            # Pseudo labels + their confidences from the (shared) classifier.
            logits1 = self.classifier(enc_ul1)  # [batch_size, num_classes, h, w]
            logits2 = self.classifier(enc_ul2)
            pseudo_logits_1 = F.softmax(logits1, 1).max(1)[0].detach()  # [b, h, w]
            pseudo_logits_2 = F.softmax(logits2, 1).max(1)[0].detach()
            pseudo_label1 = logits1.max(1)[1].detach()  # [b, h, w]
            pseudo_label2 = logits2.max(1)[1].detach()

            # Extract the overlapping region of the two crops, undoing any
            # horizontal flip first so the two views are spatially aligned.
            output_feature_list1 = []
            output_feature_list2 = []
            pseudo_label_list1 = []
            pseudo_label_list2 = []
            pseudo_logits_list1 = []
            pseudo_logits_list2 = []
            for idx in range(x_ul1.size(0)):
                output_ul1_idx = output_ul1[idx]
                output_ul2_idx = output_ul2[idx]
                pseudo_label1_idx = pseudo_label1[idx]
                pseudo_label2_idx = pseudo_label2[idx]
                pseudo_logits_1_idx = pseudo_logits_1[idx]
                pseudo_logits_2_idx = pseudo_logits_2[idx]
                if flip[0][idx] == True:
                    output_ul1_idx = torch.flip(output_ul1_idx, dims=(2, ))
                    pseudo_label1_idx = torch.flip(pseudo_label1_idx, dims=(1, ))
                    pseudo_logits_1_idx = torch.flip(pseudo_logits_1_idx, dims=(1, ))
                if flip[1][idx] == True:
                    output_ul2_idx = torch.flip(output_ul2_idx, dims=(2, ))
                    pseudo_label2_idx = torch.flip(pseudo_label2_idx, dims=(1, ))
                    pseudo_logits_2_idx = torch.flip(pseudo_logits_2_idx, dims=(1, ))
                # // 8: input-pixel coordinates mapped to the stride-8 feature
                # grid — assumes output stride 8; TODO confirm for both backbones.
                output_feature_list1.append(
                    output_ul1_idx[:, ul1[0][idx] // 8:br1[0][idx] // 8,
                                   ul1[1][idx] // 8:br1[1][idx] // 8]
                    .permute(1, 2, 0).contiguous().view(-1, output_ul1.size(1)))
                output_feature_list2.append(
                    output_ul2_idx[:, ul2[0][idx] // 8:br2[0][idx] // 8,
                                   ul2[1][idx] // 8:br2[1][idx] // 8]
                    .permute(1, 2, 0).contiguous().view(-1, output_ul2.size(1)))
                pseudo_label_list1.append(
                    pseudo_label1_idx[ul1[0][idx] // 8:br1[0][idx] // 8,
                                      ul1[1][idx] // 8:br1[1][idx] // 8]
                    .contiguous().view(-1))
                pseudo_label_list2.append(
                    pseudo_label2_idx[ul2[0][idx] // 8:br2[0][idx] // 8,
                                      ul2[1][idx] // 8:br2[1][idx] // 8]
                    .contiguous().view(-1))
                pseudo_logits_list1.append(
                    pseudo_logits_1_idx[ul1[0][idx] // 8:br1[0][idx] // 8,
                                        ul1[1][idx] // 8:br1[1][idx] // 8]
                    .contiguous().view(-1))
                pseudo_logits_list2.append(
                    pseudo_logits_2_idx[ul2[0][idx] // 8:br2[0][idx] // 8,
                                        ul2[1][idx] // 8:br2[1][idx] // 8]
                    .contiguous().view(-1))
            output_feat1 = torch.cat(output_feature_list1, 0)  # [n, c]
            output_feat2 = torch.cat(output_feature_list2, 0)  # [n, c]
            pseudo_label1_overlap = torch.cat(pseudo_label_list1, 0)  # [n,]
            pseudo_label2_overlap = torch.cat(pseudo_label_list2, 0)  # [n,]
            pseudo_logits1_overlap = torch.cat(pseudo_logits_list1, 0)  # [n,]
            pseudo_logits2_overlap = torch.cat(pseudo_logits_list2, 0)  # [n,]
            assert output_feat1.size(0) == output_feat2.size(0)
            assert pseudo_label1_overlap.size(0) == pseudo_label2_overlap.size(0)
            assert output_feat1.size(0) == pseudo_label1_overlap.size(0)

            # Sample negatives from both views and gather them across GPUs.
            b, c, h, w = output_ul1.size()
            selected_num = self.selected_num
            output_ul1_flatten = output_ul1.permute(0, 2, 3, 1).contiguous().view(
                b * h * w, c)
            output_ul2_flatten = output_ul2.permute(0, 2, 3, 1).contiguous().view(
                b * h * w, c)
            selected_idx1 = np.random.choice(range(b * h * w), selected_num,
                                             replace=False)
            selected_idx2 = np.random.choice(range(b * h * w), selected_num,
                                             replace=False)
            output_ul1_flatten_selected = output_ul1_flatten[selected_idx1]
            output_ul2_flatten_selected = output_ul2_flatten[selected_idx2]
            output_ul_flatten_selected = torch.cat(
                [output_ul1_flatten_selected, output_ul2_flatten_selected],
                0)  # [2*kk, c]
            output_ul_all = self.concat_all_gather(
                output_ul_flatten_selected)  # [2*N, c]
            pseudo_label1_flatten_selected = pseudo_label1.view(-1)[selected_idx1]
            pseudo_label2_flatten_selected = pseudo_label2.view(-1)[selected_idx2]
            pseudo_label_flatten_selected = torch.cat(
                [pseudo_label1_flatten_selected,
                 pseudo_label2_flatten_selected], 0)  # [2*kk]
            pseudo_label_all = self.concat_all_gather(
                pseudo_label_flatten_selected)  # [2*N]

            # FIFO memory bank of the last `step_save` steps.
            self.feature_bank.append(output_ul_all)
            self.pseudo_label_bank.append(pseudo_label_all)
            if self.step_count > self.step_save:
                self.feature_bank = self.feature_bank[1:]
                self.pseudo_label_bank = self.pseudo_label_bank[1:]
            else:
                self.step_count += 1
            output_ul_all = torch.cat(self.feature_bank, 0)
            pseudo_label_all = torch.cat(self.pseudo_label_bank, 0)

            eps = 1e-8
            # Positive similarities; each side detaches the other so the two
            # directional losses only update one view at a time.
            pos1 = (output_feat1 * output_feat2.detach()).sum(
                -1, keepdim=True) / self.temp  # [n, 1]
            pos2 = (output_feat1.detach() * output_feat2).sum(
                -1, keepdim=True) / self.temp  # [n, 1]

            # ---- loss1 (view 1 pulled toward view 2) ----
            # Negatives are processed in chunks of `b` rows, inside
            # gradient checkpoints, to bound peak memory.
            b = 8000

            def run1(pos, output_feat1, output_ul_idx, pseudo_label_idx,
                     pseudo_label1_overlap, neg_max1):
                # Subsequent chunks: reuse the max from the first chunk for a
                # numerically stable exp.
                mask1_idx = (pseudo_label_idx.unsqueeze(0) !=
                             pseudo_label1_overlap.unsqueeze(-1)).float()  # [n, b]
                neg1_idx = (output_feat1 @ output_ul_idx.T) / self.temp  # [n, b]
                logits1_neg_idx = (torch.exp(neg1_idx - neg_max1) *
                                   mask1_idx).sum(-1)  # [n, ]
                return logits1_neg_idx

            def run1_0(pos, output_feat1, output_ul_idx, pseudo_label_idx,
                       pseudo_label1_overlap):
                # First chunk: include the positive term and establish the max.
                mask1_idx = (pseudo_label_idx.unsqueeze(0) !=
                             pseudo_label1_overlap.unsqueeze(-1)).float()  # [n, b]
                neg1_idx = (output_feat1 @ output_ul_idx.T) / self.temp  # [n, b]
                neg1_idx = torch.cat([pos, neg1_idx], 1)  # [n, 1+b]
                mask1_idx = torch.cat(
                    [torch.ones(mask1_idx.size(0), 1).float().cuda(),
                     mask1_idx], 1)  # [n, 1+b]
                neg_max1 = torch.max(neg1_idx, 1, keepdim=True)[0]  # [n, 1]
                logits1_neg_idx = (torch.exp(neg1_idx - neg_max1) *
                                   mask1_idx).sum(-1)  # [n, ]
                return logits1_neg_idx, neg_max1

            N = output_ul_all.size(0)
            logits1_down = torch.zeros(pos1.size(0)).float().cuda()
            for i in range((N - 1) // b + 1):
                pseudo_label_idx = pseudo_label_all[i * b:(i + 1) * b]
                output_ul_idx = output_ul_all[i * b:(i + 1) * b]
                if i == 0:
                    logits1_neg_idx, neg_max1 = torch.utils.checkpoint.checkpoint(
                        run1_0, pos1, output_feat1, output_ul_idx,
                        pseudo_label_idx, pseudo_label1_overlap)
                else:
                    logits1_neg_idx = torch.utils.checkpoint.checkpoint(
                        run1, pos1, output_feat1, output_ul_idx,
                        pseudo_label_idx, pseudo_label1_overlap, neg_max1)
                logits1_down += logits1_neg_idx

            logits1 = torch.exp(pos1 - neg_max1).squeeze(-1) / (logits1_down + eps)
            # Only supervise positions where view 2 is confident and more
            # confident than view 1 (directional consistency).
            pos_mask_1 = ((pseudo_logits2_overlap > self.pos_thresh_value) &
                          (pseudo_logits1_overlap < pseudo_logits2_overlap)).float()
            loss1 = -torch.log(logits1 + eps)
            loss1 = (loss1 * pos_mask_1).sum() / (pos_mask_1.sum() + 1e-12)

            # ---- loss2 (view 2 pulled toward view 1; mirror of loss1) ----
            def run2(pos, output_feat2, output_ul_idx, pseudo_label_idx,
                     pseudo_label2_overlap, neg_max2):
                mask2_idx = (pseudo_label_idx.unsqueeze(0) !=
                             pseudo_label2_overlap.unsqueeze(-1)).float()  # [n, b]
                neg2_idx = (output_feat2 @ output_ul_idx.T) / self.temp  # [n, b]
                logits2_neg_idx = (torch.exp(neg2_idx - neg_max2) *
                                   mask2_idx).sum(-1)  # [n, ]
                return logits2_neg_idx

            def run2_0(pos, output_feat2, output_ul_idx, pseudo_label_idx,
                       pseudo_label2_overlap):
                mask2_idx = (pseudo_label_idx.unsqueeze(0) !=
                             pseudo_label2_overlap.unsqueeze(-1)).float()  # [n, b]
                neg2_idx = (output_feat2 @ output_ul_idx.T) / self.temp  # [n, b]
                neg2_idx = torch.cat([pos, neg2_idx], 1)  # [n, 1+b]
                mask2_idx = torch.cat(
                    [torch.ones(mask2_idx.size(0), 1).float().cuda(),
                     mask2_idx], 1)  # [n, 1+b]
                neg_max2 = torch.max(neg2_idx, 1, keepdim=True)[0]  # [n, 1]
                logits2_neg_idx = (torch.exp(neg2_idx - neg_max2) *
                                   mask2_idx).sum(-1)  # [n, ]
                return logits2_neg_idx, neg_max2

            N = output_ul_all.size(0)
            logits2_down = torch.zeros(pos2.size(0)).float().cuda()
            for i in range((N - 1) // b + 1):
                pseudo_label_idx = pseudo_label_all[i * b:(i + 1) * b]
                output_ul_idx = output_ul_all[i * b:(i + 1) * b]
                if i == 0:
                    logits2_neg_idx, neg_max2 = torch.utils.checkpoint.checkpoint(
                        run2_0, pos2, output_feat2, output_ul_idx,
                        pseudo_label_idx, pseudo_label2_overlap)
                else:
                    logits2_neg_idx = torch.utils.checkpoint.checkpoint(
                        run2, pos2, output_feat2, output_ul_idx,
                        pseudo_label_idx, pseudo_label2_overlap, neg_max2)
                logits2_down += logits2_neg_idx

            logits2 = torch.exp(pos2 - neg_max2).squeeze(-1) / (logits2_down + eps)
            pos_mask_2 = ((pseudo_logits1_overlap > self.pos_thresh_value) &
                          (pseudo_logits2_overlap < pseudo_logits1_overlap)).float()
            loss2 = -torch.log(logits2 + eps)
            loss2 = (loss2 * pos_mask_2).sum() / (pos_mask_2.sum() + 1e-12)

            loss_unsup = self.weight_unsup * (loss1 + loss2)
            curr_losses['loss1'] = loss1
            curr_losses['loss2'] = loss2
            curr_losses['loss_unsup'] = loss_unsup
            total_loss = total_loss + loss_unsup
            return total_loss, curr_losses, outputs
        else:
            raise ValueError("No such mode {}".format(self.mode))

    def concat_all_gather(self, tensor):
        """
        Performs all_gather operation on the provided tensors.
        *** Warning ***: torch.distributed.all_gather has no gradient.
        """
        with torch.no_grad():
            tensors_gather = [
                torch.ones_like(tensor)
                for _ in range(torch.distributed.get_world_size())
            ]
            torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
            output = torch.cat(tensors_gather, dim=0)
        return output

    def get_backbone_params(self):
        # Backbone parameters (typically trained with a smaller LR).
        return self.encoder.get_backbone_params()

    def get_other_params(self):
        # Non-backbone parameters: classifier head (+ projection head in semi).
        if self.mode == 'supervised':
            return chain(self.encoder.get_module_params(),
                         self.classifier.parameters())
        elif self.mode == 'semi':
            return chain(self.encoder.get_module_params(),
                         self.classifier.parameters(),
                         self.project.parameters())
        else:
            raise ValueError("No such mode {}".format(self.mode))
class CCT(BaseModel):
    """Cross-consistency style semi-supervised segmentation model.

    A shared encoder feeds a main decoder (supervised path) and, in 'semi'
    mode, a set of perturbation-based auxiliary decoders whose predictions
    are regularized toward the main decoder's output.
    """

    def __init__(self, num_classes, conf, sup_loss=None, cons_w_unsup=None,
                 ignore_index=None, testing=False, pretrained=True,
                 use_weak_lables=False, weakly_loss_w=0.4):
        """
        Args:
            num_classes: number of segmentation classes.
            conf: dict-like config selecting mode, losses, and the number of
                each kind of auxiliary decoder.
            sup_loss: supervised loss callable.
            cons_w_unsup: ramp-up weighting callable for the unsupervised loss
                (called with epoch=..., curr_iter=...).
            ignore_index: label value ignored by the supervised loss.
            testing: when False, the loss/weight arguments are required.
            pretrained: forwarded to the Encoder.
            use_weak_lables: add a weak-label CE term on auxiliary outputs.
            weakly_loss_w: weight of the weak-label term.
        """
        if not testing:
            assert (ignore_index is not None) and (sup_loss is not None) \
                and (cons_w_unsup is not None)

        super(CCT, self).__init__()
        assert int(conf['supervised']) + int(conf['semi']) == 1, 'one mode only'
        if conf['supervised']:
            self.mode = 'supervised'
        else:
            self.mode = 'semi'

        # Supervised and unsupervised losses
        self.ignore_index = ignore_index
        if conf['un_loss'] == "KL":
            self.unsuper_loss = softmax_kl_loss
        elif conf['un_loss'] == "MSE":
            self.unsuper_loss = softmax_mse_loss
        elif conf['un_loss'] == "JS":
            self.unsuper_loss = softmax_js_loss
        else:
            # BUGFIX: message previously said "supervised" although this
            # validates the unsupervised consistency loss.
            raise ValueError(f"Invalid unsupervised loss {conf['un_loss']}")

        self.unsup_loss_w = cons_w_unsup
        self.sup_loss_w = conf['supervised_w']
        self.softmax_temp = conf['softmax_temp']
        self.sup_loss = sup_loss
        self.sup_type = conf['sup_loss']

        # Use weak labels
        self.use_weak_lables = use_weak_lables
        self.weakly_loss_w = weakly_loss_w
        # pair wise loss (sup mat)
        self.aux_constraint = conf['aux_constraint']
        self.aux_constraint_w = conf['aux_constraint_w']
        # confidence masking (sup mat)
        self.confidence_th = conf['confidence_th']
        self.confidence_masking = conf['confidence_masking']

        # Create the model
        self.encoder = Encoder(pretrained=pretrained)

        # The main encoder
        upscale = 8
        num_out_ch = 2048
        decoder_in_ch = num_out_ch // 4
        self.main_decoder = MainDecoder(upscale, decoder_in_ch,
                                        num_classes=num_classes)

        # The auxilary decoders
        if self.mode == 'semi' or self.mode == 'weakly_semi':
            vat_decoder = [
                VATDecoder(upscale, decoder_in_ch, num_classes,
                           xi=conf['xi'], eps=conf['eps'])
                for _ in range(conf['vat'])
            ]
            drop_decoder = [
                DropOutDecoder(upscale, decoder_in_ch, num_classes,
                               drop_rate=conf['drop_rate'],
                               spatial_dropout=conf['spatial'])
                for _ in range(conf['drop'])
            ]
            cut_decoder = [
                CutOutDecoder(upscale, decoder_in_ch, num_classes,
                              erase=conf['erase'])
                for _ in range(conf['cutout'])
            ]
            context_m_decoder = [
                ContextMaskingDecoder(upscale, decoder_in_ch, num_classes)
                for _ in range(conf['context_masking'])
            ]
            object_masking = [
                ObjectMaskingDecoder(upscale, decoder_in_ch, num_classes)
                for _ in range(conf['object_masking'])
            ]
            feature_drop = [
                FeatureDropDecoder(upscale, decoder_in_ch, num_classes)
                for _ in range(conf['feature_drop'])
            ]
            feature_noise = [
                FeatureNoiseDecoder(upscale, decoder_in_ch, num_classes,
                                    uniform_range=conf['uniform_range'])
                for _ in range(conf['feature_noise'])
            ]
            random_rotate = [
                RandomRotationDecoder(
                    upscale, decoder_in_ch, num_classes,
                    deg_range=conf['random_rotate_deg_range'])
                for _ in range(conf['rotate'])
            ]
            random_zoom = [
                RandomZoomDecoder(
                    upscale, decoder_in_ch, num_classes,
                    zoom_ratio_lower=conf['random_zoom_ratio_lower'],
                    zoom_ratio_upper=conf['random_zoom_ratio_upper'])
                for _ in range(conf['zoom'])
            ]
            random_blur = [
                RandomBlurDecoder(
                    upscale, decoder_in_ch, num_classes,
                    sigma_lower=conf['gaussian_blur_sigma_lower'],
                    sigma_upper=conf['gaussian_blur_sigma_upper'])
                for _ in range(conf['blur'])
            ]
            random_brightness = [
                RandomBrightnessDecoder(
                    upscale, decoder_in_ch, num_classes,
                    brightness_factor_lower=conf['brightness_factor_lower'],
                    brightness_factor_upper=conf['brightness_factor_upper'])
                for _ in range(conf['brightness'])
            ]
            self.aux_decoders = nn.ModuleList([
                *vat_decoder, *drop_decoder, *cut_decoder, *context_m_decoder,
                *object_masking, *feature_drop, *feature_noise,
                *random_rotate, *random_zoom, *random_blur, *random_brightness
            ])

    def forward(self, x_l=None, target_l=None, x_ul=None, target_ul=None,
                curr_iter=None, epoch=None):
        """Eval: main-decoder logits.  Train: (total_loss, losses, outputs)."""
        if not self.training:
            return self.main_decoder(self.encoder(x_l))

        # We compute the losses in the forward pass to avoid problems
        # encountered in muti-gpu

        # Forward pass the labels example
        input_size = (x_l.size(2), x_l.size(3))
        output_l = self.main_decoder(self.encoder(x_l))
        if output_l.shape != x_l.shape:
            output_l = F.interpolate(output_l, size=input_size,
                                     mode='bilinear', align_corners=True)

        # Supervised loss
        if self.sup_type == 'CE':
            loss_sup = self.sup_loss(
                output_l, target_l, ignore_index=self.ignore_index,
                temperature=self.softmax_temp) * self.sup_loss_w
        elif self.sup_type == 'FL':
            loss_sup = self.sup_loss(output_l, target_l) * self.sup_loss_w
        else:
            loss_sup = self.sup_loss(
                output_l, target_l, curr_iter=curr_iter, epoch=epoch,
                ignore_index=self.ignore_index) * self.sup_loss_w

        # If supervised mode only, return
        if self.mode == 'supervised':
            curr_losses = {'loss_sup': loss_sup}
            outputs = {'sup_pred': output_l}
            total_loss = loss_sup
            return total_loss, curr_losses, outputs

        # If semi supervised mode
        elif self.mode == 'semi':
            # Get main prediction
            x_ul = self.encoder(x_ul)
            output_ul = self.main_decoder(x_ul)

            # Get auxiliary predictions
            outputs_ul = [
                aux_decoder(x_ul, output_ul.detach())
                for aux_decoder in self.aux_decoders
            ]
            targets = F.softmax(output_ul.detach(), dim=1)

            # Compute unsupervised loss
            loss_unsup = sum([
                self.unsuper_loss(inputs=u, targets=targets,
                                  conf_mask=self.confidence_masking,
                                  threshold=self.confidence_th,
                                  use_softmax=False)
                for u in outputs_ul
            ])
            loss_unsup = (loss_unsup / len(outputs_ul))
            curr_losses = {'loss_sup': loss_sup}

            if output_ul.shape != x_l.shape:
                output_ul = F.interpolate(output_ul, size=input_size,
                                          mode='bilinear', align_corners=True)
            outputs = {'sup_pred': output_l, 'unsup_pred': output_ul}

            # Compute the unsupervised loss (ramp-up weighted)
            weight_u = self.unsup_loss_w(epoch=epoch, curr_iter=curr_iter)
            loss_unsup = loss_unsup * weight_u
            curr_losses['loss_unsup'] = loss_unsup
            total_loss = loss_unsup + loss_sup

            # If case we're using weak lables, add the weak loss term with a
            # weight (self.weakly_loss_w)
            if self.use_weak_lables:
                weight_w = (weight_u / self.unsup_loss_w.final_w) * self.weakly_loss_w
                loss_weakly = sum([
                    CE_loss(outp, target_ul, ignore_index=self.ignore_index)
                    for outp in outputs_ul
                ]) / len(outputs_ul)
                loss_weakly = loss_weakly * weight_w
                curr_losses['loss_weakly'] = loss_weakly
                total_loss += loss_weakly

            # Pair-wise loss
            if self.aux_constraint:
                pair_wise = pair_wise_loss(outputs_ul) * self.aux_constraint_w
                curr_losses['pair_wise'] = pair_wise
                # BUGFIX: this used to do `loss_unsup += pair_wise` *after*
                # total_loss was already computed, so the pair-wise term was
                # logged but never optimized.  Add it to the returned loss.
                total_loss += pair_wise

            return total_loss, curr_losses, outputs

    def get_backbone_params(self):
        # Backbone parameters (typically trained with a smaller LR).
        return self.encoder.get_backbone_params()

    def get_other_params(self):
        # Non-backbone parameters: decoders (+ auxiliary decoders in semi).
        if self.mode == 'semi':
            return chain(self.encoder.get_module_params(),
                         self.main_decoder.parameters(),
                         self.aux_decoders.parameters())
        return chain(self.encoder.get_module_params(),
                     self.main_decoder.parameters())
class CCT(BaseModel):
    # NOTE(review): this class redefines CCT declared earlier in this file;
    # at import time this later definition shadows the earlier one.
    #
    # Variant of CCT with an `unsupervised_mode` switch:
    #   'seq'        -> consistency over a sequence of images (seq_decoder)
    #   'pert'       -> classic perturbation-based auxiliary decoders
    #   'pertAndSeq' -> both terms combined
    def __init__(self, num_classes, conf, sup_loss=None, cons_w_unsup=None,
                 ignore_index=None, testing=False, pretrained=True,
                 use_weak_lables=False, unsupervised_mode=None,
                 weakly_loss_w=0.4):
        # When training (not testing), the loss callables and ignore index
        # are mandatory.
        if not testing:
            assert (ignore_index is not None) and (sup_loss is not None) and (cons_w_unsup is not None)
        super(CCT, self).__init__()
        assert int(conf['supervised']) + int(conf['semi']) == 1, 'one mode only'
        if conf['supervised']:
            self.mode = 'supervised'
        else:
            self.mode = 'semi'

        # Supervised and unsupervised losses
        self.ignore_index = ignore_index
        if conf['un_loss'] == "KL":
            self.unsuper_loss = softmax_kl_loss
        elif conf['un_loss'] == "MSE":
            self.unsuper_loss = softmax_mse_loss
        elif conf['un_loss'] == "JS":
            self.unsuper_loss = softmax_js_loss
        else:
            # NOTE(review): message says "supervised" but this validates the
            # unsupervised loss choice.
            raise ValueError(f"Invalid supervised loss {conf['un_loss']}")

        self.unsup_loss_w = cons_w_unsup
        self.sup_loss_w = conf['supervised_w']
        self.softmax_temp = conf['softmax_temp']
        self.sup_loss = sup_loss
        self.sup_type = conf['sup_loss']
        self.unsupervised_mode = unsupervised_mode

        # Use weak labels
        self.use_weak_lables = use_weak_lables
        self.weakly_loss_w = weakly_loss_w
        # pair wise loss (sup mat)
        self.aux_constraint = conf['aux_constraint']
        self.aux_constraint_w = conf['aux_constraint_w']
        # confidence masking (sup mat)
        self.confidence_th = conf['confidence_th']
        self.confidence_masking = conf['confidence_masking']

        # Create the model
        self.encoder = Encoder(pretrained=pretrained)

        # The main encoder
        upscale = 8
        num_out_ch = 2048
        decoder_in_ch = num_out_ch // 4
        self.main_decoder = MainDecoder(upscale, decoder_in_ch,
                                        num_classes=num_classes)

        # The auxilary decoders
        if self.mode == 'semi' or self.mode == 'weakly_semi':
            # NOTE(review): these substring tests are case-sensitive, so for
            # unsupervised_mode='pertAndSeq' only the 'pert' branch runs
            # ('seq' is not a substring of 'pertAndSeq'); self.seq_decoder is
            # then never created although forward() uses it in that mode —
            # verify intended spelling/casing of the mode strings.
            if 'seq' in unsupervised_mode:
                vat_decoder_seq = [
                    VATDecoder(upscale, decoder_in_ch, num_classes,
                               xi=conf['xi'], eps=conf['eps'])
                    for _ in range(conf['vat'])
                ]
                self.seq_decoder = nn.ModuleList([*vat_decoder_seq])
            elif 'pert' in unsupervised_mode:
                vat_decoder = [
                    VATDecoder(upscale, decoder_in_ch, num_classes,
                               xi=conf['xi'], eps=conf['eps'])
                    for _ in range(conf['vat'])
                ]
                drop_decoder = [
                    DropOutDecoder(upscale, decoder_in_ch, num_classes,
                                   drop_rate=conf['drop_rate'],
                                   spatial_dropout=conf['spatial'])
                    for _ in range(conf['drop'])
                ]
                cut_decoder = [
                    CutOutDecoder(upscale, decoder_in_ch, num_classes,
                                  erase=conf['erase'])
                    for _ in range(conf['cutout'])
                ]
                context_m_decoder = [
                    ContextMaskingDecoder(upscale, decoder_in_ch, num_classes)
                    for _ in range(conf['context_masking'])
                ]
                object_masking = [
                    ObjectMaskingDecoder(upscale, decoder_in_ch, num_classes)
                    for _ in range(conf['object_masking'])
                ]
                feature_drop = [
                    FeatureDropDecoder(upscale, decoder_in_ch, num_classes)
                    for _ in range(conf['feature_drop'])
                ]
                feature_noise = [
                    FeatureNoiseDecoder(upscale, decoder_in_ch, num_classes,
                                        uniform_range=conf['uniform_range'])
                    for _ in range(conf['feature_noise'])
                ]
                self.aux_decoders = nn.ModuleList([
                    *vat_decoder, *drop_decoder, *cut_decoder,
                    *context_m_decoder, *object_masking, *feature_drop,
                    *feature_noise
                ])
            '''
            temp = [DropOutDecoder(upscale, decoder_in_ch, num_classes, drop_rate=conf['drop_rate'], spatial_dropout=conf['spatial']) for _ in range(conf['drop'])]
            self.seq_decoder = MainDecoder(upscale, decoder_in_ch, num_classes=num_classes)
            '''

    #def forward(self, x_l=None, target_l=None, x_ul=None, target_ul=None, curr_iter=None, epoch=None, img_id_l=None, img_id_ul=None):
    def forward(self, x_l=None, target_l=None, x_ul=None, target_ul=None,
                curr_iter=None, epoch=None, x_seq_triplet=None,
                unsupervised_mode=None):
        '''x_seq_triplet is a list whose len equals the len of the seq;
        each element in the list is a tuple with the following format
        (imgs:torch.Size([5, 3, 320, 320]),
         annotations:torch.Size([5, 320, 320, 3]),
         [list of the img_names])
        5 in this case is the batch size'''
        if not self.training:
            return self.main_decoder(self.encoder(x_l))

        # Parse a date out of an image filename prefix.
        # NOTE(review): assumes a fixed filename layout where char 3 encodes
        # the year ('1'->'12', '0'->'11') and chars 4:8 are MMDD — confirm
        # against the dataset's naming convention.  Currently unused below.
        def get_date(x):
            date_ = x.split('.')[0]
            y = '00'
            if date_[3] == '1':
                y = '12'
            elif date_[3] == '0':
                y = '11'
            m = date_[4:6]
            d = date_[6:8]
            date_ = m + d + y
            return datetime.datetime.strptime(date_, '%m%d%y')

        # We compute the losses in the forward pass to avoid problems
        # encountered in muti-gpu

        # Forward pass the labels example
        input_size = (x_l.size(2), x_l.size(3))
        output_l = self.main_decoder(self.encoder(x_l))
        if output_l.shape != x_l.shape:
            output_l = F.interpolate(output_l, size=input_size,
                                     mode='bilinear', align_corners=True)

        # Supervised loss
        if self.sup_type == 'CE':
            loss_sup = self.sup_loss(
                output_l, target_l, ignore_index=self.ignore_index,
                temperature=self.softmax_temp) * self.sup_loss_w
        elif self.sup_type == 'FL':
            loss_sup = self.sup_loss(output_l, target_l) * self.sup_loss_w
        else:
            loss_sup = self.sup_loss(
                output_l, target_l, curr_iter=curr_iter, epoch=epoch,
                ignore_index=self.ignore_index) * self.sup_loss_w

        # If supervised mode only, return
        if self.mode == 'supervised':
            curr_losses = {'loss_sup': loss_sup}
            outputs = {'sup_pred': output_l}
            total_loss = loss_sup
            return total_loss, curr_losses, outputs

        # If semi supervised mode
        elif self.mode == 'semi':
            # Get sequence predictions
            if unsupervised_mode == 'seq' or unsupervised_mode == 'pertAndSeq':
                seqs = []
                for i in range(len(x_seq_triplet)):
                    '''a batch of ith imgs in the seq. If the batch_size ==5:
                    (imgs:torch.Size([5, 3, 320, 320]),
                     annotations:torch.Size([5, 320, 320, 3]),
                     [list of the img_names])'''
                    x_seq = self.encoder(x_seq_triplet[i][0])
                    output_seq = self.main_decoder(x_seq)
                    seqs.append((x_seq, output_seq))

                # Get auxiliary predictions
                seq_losses = []
                ''' seq_losses has as many elements as the number of imgs in
                each seq.  The ith element in seq_losses is the unsupervised
                loss calculated for the outputs of the maindecoder and the
                seqdecoder for the batch of ith imgs in seq'''
                all_outputs_ul = []
                '''the ith element in all_outputs_ul is the outputs of the
                seqdecoder for the batch of the ith imgs in the seq'''
                for pair in seqs:
                    ''' the number of pairs is the number of the images in
                    each seq; each pair holds (the output of the encoder for
                    the batch, the output of the maindecoder on pair[0]) and
                    outputs_ul is the output of seqdecoder for the batch'''
                    #outputs_ul = [aux_decoder(pair[0], pair[1].detach()) for aux_decoder in self.aux_decoders]
                    #outputs_ul = [aux_decoder(pair[0]) for aux_decoder in [self.seq_decoder]]
                    outputs_ul = [
                        aux_decoder(pair[0], pair[1].detach())
                        for aux_decoder in self.seq_decoder
                    ]
                    targets = F.softmax(pair[1].detach(), dim=1)
                    all_outputs_ul.append(outputs_ul)
                    # Compute unsupervised loss
                    # this loss is between the output of the main decoder and
                    # the seqdecoder's
                    loss_seq = sum([
                        self.unsuper_loss(inputs=u, targets=targets,
                                          conf_mask=self.confidence_masking,
                                          threshold=self.confidence_th,
                                          use_softmax=False)
                        for u in outputs_ul
                    ])
                    loss_seq = (loss_seq / len(outputs_ul))
                    seq_losses.append(loss_seq)
                # NOTE(review): np.sum over a list of torch tensors reduces
                # via tensor __add__; a plain sum() would be more idiomatic.
                loss_seq = np.sum(seq_losses) / len(seqs)

                # Cross-image consistency: compare the z-th seq-decoder output
                # for every pair (i, j) of images in the sequence.
                loss_seq_accross_decoder = []
                for i in range(len(all_outputs_ul)):
                    for j in range(i + 1, len(all_outputs_ul)):
                        pert_loss = 0
                        for z in range(
                                len(all_outputs_ul[0])
                        ):  #because we have 2 decoders in seqdecoder
                            pert_loss += self.unsuper_loss(
                                inputs=all_outputs_ul[i][z],
                                targets=all_outputs_ul[j][z],
                                conf_mask=self.confidence_masking,
                                threshold=self.confidence_th,
                                use_softmax=False)
                        pert_loss = pert_loss / len(all_outputs_ul[0])
                        loss_seq_accross_decoder.append(pert_loss)
                # NOTE(review): if the sequence has a single image,
                # loss_seq_accross_decoder is empty and this divides by zero —
                # confirm sequences always have >= 2 images.
                loss_seq += np.sum(loss_seq_accross_decoder) / len(
                    loss_seq_accross_decoder)

                # Upsample each sequence prediction back to the input size.
                output_seqs = []
                for i in range(len(seqs)):
                    temp_x_seq = seqs[i][0]
                    temp_output_seq = seqs[i][1]
                    if temp_output_seq.shape != temp_x_seq.shape:
                        temp_output_seq = F.interpolate(temp_output_seq,
                                                        size=input_size,
                                                        mode='bilinear',
                                                        align_corners=True)
                    output_seqs.append(temp_output_seq)
                # TODO: Figure out why we only return the first output from the main decoder.
                outputs = {'sup_pred': output_l, 'unsup_pred': output_seqs[0]}
                #loss_seq = (loss_seq / 3)
                weight_u = self.unsup_loss_w(epoch=epoch, curr_iter=curr_iter)
                loss_seq = loss_seq * weight_u

            if unsupervised_mode != 'seq':
                # Get main prediction
                x_ul = self.encoder(x_ul)
                output_ul = self.main_decoder(x_ul)
                # Get auxiliary predictions
                outputs_ul = [
                    aux_decoder(x_ul, output_ul.detach())
                    for aux_decoder in self.aux_decoders
                ]
                targets = F.softmax(output_ul.detach(), dim=1)
                # Compute unsupervised loss
                loss_unsup = sum([
                    self.unsuper_loss(inputs=u, targets=targets,
                                      conf_mask=self.confidence_masking,
                                      threshold=self.confidence_th,
                                      use_softmax=False)
                    for u in outputs_ul
                ])
                loss_unsup = (loss_unsup / len(outputs_ul))
                if output_ul.shape != x_l.shape:
                    output_ul = F.interpolate(output_ul, size=input_size,
                                              mode='bilinear',
                                              align_corners=True)
                # NOTE(review): in 'pertAndSeq' mode this overwrites the
                # 'unsup_pred' set in the seq branch above — confirm intended.
                outputs = {'sup_pred': output_l, 'unsup_pred': output_ul}
                weight_u = self.unsup_loss_w(epoch=epoch, curr_iter=curr_iter)
                loss_unsup = loss_unsup * weight_u

            # Combine the loss terms according to the unsupervised mode.
            curr_losses = {'loss_sup': loss_sup}
            if unsupervised_mode == 'seq':
                curr_losses['loss_unsup'] = loss_seq
                total_loss = loss_seq + loss_sup
            elif unsupervised_mode == 'pertAndSeq':
                curr_losses['loss_unsup'] = loss_seq + loss_unsup
                total_loss = loss_seq + loss_unsup + loss_sup
            else:
                curr_losses['loss_unsup'] = loss_unsup
                total_loss = loss_unsup + loss_sup

            # In case we're using weak lables, add the weak loss term with a
            #weight (self.weakly_loss_w)
            if self.use_weak_lables:
                # NOTE(review): in 'seq' mode, outputs_ul here is whatever was
                # left from the *last* iteration of the seq loop — confirm
                # weak labels are only used with the 'pert' branch.
                target_ul = target_ul[:, :, :, 0]
                weight_w = (weight_u /
                            self.unsup_loss_w.final_w) * self.weakly_loss_w
                loss_weakly = sum([
                    CE_loss(outp, target_ul, ignore_index=self.ignore_index)
                    for outp in outputs_ul
                ]) / len(outputs_ul)
                loss_weakly = loss_weakly * weight_w
                curr_losses['loss_weakly'] = loss_weakly
                total_loss += loss_weakly

            # Pair-wise loss
            if self.aux_constraint:
                # NOTE(review): pair_wise is added to loss_unsup *after*
                # total_loss is computed, so it is logged but never optimized;
                # additionally, in 'seq' mode loss_unsup is undefined here and
                # this would raise NameError — confirm aux_constraint is only
                # enabled with the 'pert' path.
                pair_wise = pair_wise_loss(outputs_ul) * self.aux_constraint_w
                curr_losses['pair_wise'] = pair_wise
                loss_unsup += pair_wise

            return total_loss, curr_losses, outputs

    def get_backbone_params(self):
        # Backbone parameters (typically trained with a smaller LR).
        return self.encoder.get_backbone_params()

    def get_other_params(self):
        # Non-backbone parameters, depending on which decoders exist for the
        # configured unsupervised mode.
        if self.mode == 'semi':
            if self.unsupervised_mode == 'pertAndSeq':
                return chain(self.encoder.get_module_params(),
                             self.main_decoder.parameters(),
                             self.aux_decoders.parameters(),
                             self.seq_decoder.parameters())
            elif self.unsupervised_mode == 'seq':
                return chain(self.encoder.get_module_params(),
                             self.main_decoder.parameters(),
                             self.seq_decoder.parameters())
            elif self.unsupervised_mode == 'pert':
                return chain(self.encoder.get_module_params(),
                             self.main_decoder.parameters(),
                             self.aux_decoders.parameters())
        return chain(self.encoder.get_module_params(),
                     self.main_decoder.parameters())