def test(self):
        """Evaluate on the test loader, print the summed confusion matrix,
        and return ``self.model_eval`` computed over all logits and labels.
        """
        self.net.eval()
        labels = np.arange(self.opt.DATASET.NUM_CLASSES)

        preds = []
        gts = []
        # Running sum of per-batch confusion matrices.
        cm = None

        for sample in iter(self.test_data['loader']):
            data, gt = to_cuda(sample['Img']), to_cuda(sample['Label'])
            logits = self.net(data)['logits']
            pred = torch.max(logits, dim=1).indices
            preds += [logits]
            gts += [gt]
            # `labels` must be passed by keyword: it is keyword-only in
            # recent scikit-learn releases.
            batch_cm = confusion_matrix(gt.cpu(), pred.cpu(), labels=labels)
            # Explicit first-iteration handling replaces the old bare
            # `except:` dance, which list-extended the initial `cm = []`
            # instead of summing the matrices.
            cm = batch_cm if cm is None else cm + batch_cm

        print('==================')
        print('Confusion matrix:')
        print(cm)
        print('==================')

        preds = torch.cat(preds, dim=0)
        gts = torch.cat(gts, dim=0)

        res = self.model_eval(preds, gts)
        return res
# --- Example 2 (scraped-snippet separator) ---
    def update_network(self):
        """Run one training loop (``self.iters_per_loop`` iterations) of
        supervised cross-entropy training on the labeled source domain.

        Side effects: advances ``self.iters``, steps the optimizer and LR
        schedule, and optionally logs, tests and checkpoints at the
        configured intervals.
        """
        # initial configuration
        stop = False
        update_iters = 0
        # Fresh iterator so every loop restarts the source loader.
        self.train_data[self.source_name]['iterator'] = iter(
            self.train_data[self.source_name]['loader'])
        while not stop:
            loss = 0
            # update learning rate
            self.update_lr()

            # set the status of network
            self.net.train()
            self.net.zero_grad()

            # conventional sampling for training on labeled source data
            source_sample = self.get_samples(self.train_domain)
            source_data, source_gt = source_sample['Img'],\
                          source_sample['Label']

            source_data = to_cuda(source_data)
            source_gt = to_cuda(source_gt)
            # NOTE(review): called with no domain id — presumably falls back
            # to a default BN domain; confirm set_bn_domain's default value.
            self.net.module.set_bn_domain()
            source_preds = self.net(source_data)['logits']

            # compute the cross-entropy loss
            ce_loss = self.CELoss(source_preds, source_gt)
            ce_loss.backward()
            loss += ce_loss

            # update the network
            self.optimizer.step()

            # Log roughly 10 times per loop.
            if self.opt.TRAIN.LOGGING and (update_iters+1) % \
                      (max(1, self.iters_per_loop // 10)) == 0:
                accu = self.model_eval(source_preds, source_gt)
                cur_loss = {'ce_loss': ce_loss}
                self.logging(cur_loss, accu)

            if self.opt.TRAIN.TEST_INTERVAL > 0 and \
  (self.iters+1) % int(self.opt.TRAIN.TEST_INTERVAL * self.iters_per_loop) == 0:
                with torch.no_grad():
                    self.net.module.set_bn_domain()
                    accu = self.test()
                print('Test at (loop %d, iters %d) with %s: %.4f.' %
                      (self.loop, self.iters, self.opt.EVAL_METRIC, accu))

            if self.opt.TRAIN.SAVE_CKPT_INTERVAL > 0 and \
  (self.iters+1) % int(self.opt.TRAIN.SAVE_CKPT_INTERVAL * self.iters_per_loop) == 0:
                self.save_ckpt()

            update_iters += 1
            self.iters += 1

            # update stop condition
            if update_iters >= self.iters_per_loop:
                stop = True
            else:
                stop = False
# --- Example 3 (scraped-snippet separator) ---
    def D_step(self, x_S, x_T):
        """Discriminator loss: source features labeled 1, target features 0.

        Inputs are detached so only the discriminator receives gradients.
        """
        self.set_domain_id(0)
        out_src = self.net_D(x_S.detach())
        self.set_domain_id(1)
        out_tgt = self.net_D(x_T.detach())

        all_preds = torch.cat((out_src, out_tgt), dim=0)

        ones = to_cuda(torch.FloatTensor(out_src.size()).fill_(1.0))
        zeros = to_cuda(torch.FloatTensor(out_tgt.size()).fill_(0.0))
        all_targets = torch.cat((ones, zeros), dim=0)

        return self.BCELoss(all_preds, all_targets)
# --- Example 4 (scraped-snippet separator) ---
    def test(self):
        """Collect logits over the whole test loader and score them with
        ``self.model_eval``."""
        self.net.eval()
        all_logits = []
        all_labels = []
        for sample in iter(self.test_data['loader']):
            imgs = to_cuda(sample['Img'])
            labels = to_cuda(sample['Label'])
            all_logits.append(self.net(imgs)['logits'])
            all_labels.append(labels)

        return self.model_eval(torch.cat(all_logits, dim=0),
                               torch.cat(all_labels, dim=0))
# --- Example 5 (scraped-snippet separator) ---
    def test(self):
        """Run the net over the test loader and return (iou, iou_pos, accu)."""
        self.net.eval()
        outputs, masks = [], []
        for sample in iter(self.test_data['loader']):
            _, _, clip, mask = sample
            clip = to_cuda(clip)
            mask = to_cuda(mask)
            outputs.append(self.net(clip))
            masks.append(mask)

        outputs = torch.cat(outputs, dim=0)
        masks = torch.cat(masks, dim=0)
        return self.model_eval(outputs, masks)
    def patch_mean(self, nums_row, nums_col, dist, domain_probs_expand):
        """Probability-weighted mean of ``dist`` over each class-pair patch.

        Args:
            nums_row: per-class sample counts along dim 0.
            nums_col: per-class sample counts along dim 1.
            dist: (N_row, N_col, num_domains) distance tensor.
            domain_probs_expand: matching (N_row, N_col, num_domains) weights.

        Returns:
            (num_classes, num_classes, num_domains) tensor where entry
            (r, c) is sum(dist patch) / sum(weight patch) over the patch.
        """
        assert (len(nums_row) == len(nums_col))
        num_classes = len(nums_row)
        num_domains = dist.size()[2]

        # Allocate on dist's device instead of forcing CUDA via to_cuda(),
        # so the routine also works on CPU-only machines.
        mean_tensor = torch.zeros([num_classes, num_classes, num_domains],
                                  device=dist.device, dtype=dist.dtype)
        row_start = row_end = 0
        for row in range(num_classes):
            row_start = row_end
            row_end = row_start + nums_row[row]

            col_start = col_end = 0
            for col in range(num_classes):
                col_start = col_end
                col_end = col_start + nums_col[col]
                num = torch.sum(
                    dist.narrow(0, row_start,
                                nums_row[row]).narrow(1, col_start,
                                                      nums_col[col]), [0, 1])
                den = torch.sum(
                    domain_probs_expand.narrow(0, row_start,
                                               nums_row[row]).narrow(
                                                   1, col_start,
                                                   nums_col[col]), [0, 1])
                mean_tensor[row, col] = num / den
        return mean_tensor
    def forward(self, source, target, nums_S, nums_T):
        """Contrastive Domain Discrepancy (CDD) between source and target.

        Args:
            source: list (one entry per layer) of source feature tensors,
                ordered so samples of the same class are contiguous.
            target: list of target feature tensors, same layout.
            nums_S: per-class sample counts in the source batch.
            nums_T: per-class sample counts in the target batch.

        Returns:
            dict with 'cdd' (the loss), plus its 'intra' and 'inter' parts
            ('inter' is None when ``self.intra_only``).
        """
        assert(len(nums_S) == len(nums_T)), \
             "The number of classes for source (%d) and target (%d) should be the same." \
             % (len(nums_S), len(nums_T))

        num_classes = len(nums_S)

        # compute the dist: per-layer pairwise distances, with the same-domain
        # ones ('ss', 'tt') split into per-class blocks.
        dist_layers = []
        gamma_layers = []

        for i in range(self.num_layers):

            cur_source = source[i]
            cur_target = target[i]

            dist = {}
            dist['ss'] = self.compute_paired_dist(cur_source, cur_source)
            dist['tt'] = self.compute_paired_dist(cur_target, cur_target)
            dist['st'] = self.compute_paired_dist(cur_source, cur_target)

            dist['ss'] = self.split_classwise(dist['ss'], nums_S)
            dist['tt'] = self.split_classwise(dist['tt'], nums_T)
            dist_layers += [dist]

            # One RBF bandwidth per class pair and layer.
            gamma_layers += [self.patch_gamma_estimation(nums_S, nums_T, dist)]

        # compute the kernel dist: reshape gammas so they broadcast over the
        # per-class distance blocks.
        for i in range(self.num_layers):
            for c in range(num_classes):
                gamma_layers[i]['ss'][c] = gamma_layers[i]['ss'][c].view(num_classes, 1, 1)
                gamma_layers[i]['tt'][c] = gamma_layers[i]['tt'][c].view(num_classes, 1, 1)

        kernel_dist_st = self.kernel_layer_aggregation(dist_layers, gamma_layers, 'st')
        kernel_dist_st = self.patch_mean(nums_S, nums_T, kernel_dist_st)

        kernel_dist_ss = []
        kernel_dist_tt = []
        for c in range(num_classes):
            kernel_dist_ss += [torch.mean(self.kernel_layer_aggregation(dist_layers, 
                             gamma_layers, 'ss', c).view(num_classes, -1), dim=1)]
            kernel_dist_tt += [torch.mean(self.kernel_layer_aggregation(dist_layers, 
                             gamma_layers, 'tt', c).view(num_classes, -1), dim=1)]

        kernel_dist_ss = torch.stack(kernel_dist_ss, dim=0)
        kernel_dist_tt = torch.stack(kernel_dist_tt, dim=0).transpose(1, 0)

        # Classwise MMD matrix: diagonal = same-class (intra) discrepancy,
        # off-diagonal = cross-class (inter) discrepancy.
        mmds = kernel_dist_ss + kernel_dist_tt - 2 * kernel_dist_st
        intra_mmds = torch.diag(mmds, 0)
        intra = torch.sum(intra_mmds) / self.num_classes

        inter = None
        if not self.intra_only:
            inter_mask = to_cuda((torch.ones([num_classes, num_classes]) \
                    - torch.eye(num_classes)).type(torch.bool))
            inter_mmds = torch.masked_select(mmds, inter_mask)
            inter = torch.sum(inter_mmds) / (self.num_classes * (self.num_classes - 1))

        # Minimizing cdd pulls same-class domains together and (when enabled)
        # pushes different-class pairs apart.
        cdd = intra if inter is None else intra - inter
        return {'cdd': cdd, 'intra': intra, 'inter': inter}
# --- Example 8 (scraped-snippet separator) ---
    def __init__(self, net, dataloader, resume=None, **kwargs):
        """Set up losses, training counters and the optimizer; optionally
        restore counters and optimizer state from a resume dict."""
        self.opt = cfg
        self.net = net
        self.init_data(dataloader)

        # Class-weighted cross-entropy (positive class weight cfg.TRAIN.WPOS)
        # plus a plain BCE loss; both moved to GPU when available.
        self.CEWeight = to_cuda(torch.tensor([1.0 - cfg.TRAIN.WPOS, cfg.TRAIN.WPOS]))
        self.CELoss = nn.CrossEntropyLoss(weight=self.CEWeight)
        self.BCELoss = nn.BCELoss()
        if torch.cuda.is_available():
            self.CELoss.cuda()
            self.BCELoss.cuda()

        self.iters = 0
        self.epochs = 0
        self.iters_per_epoch = None

        self.base_lr = self.opt.TRAIN.BASE_LR
        self.momentum = self.opt.TRAIN.MOMENTUM
        self.optim_state_dict = None

        self.resume = resume is not None
        if self.resume:
            self.epochs = resume['epochs']
            self.iters = resume['iters']
            self.optim_state_dict = resume['optimizer_state_dict']
            print('Resume Training from iters %d, %d.' % \
                     (self.epochs, self.iters))

        self.build_optimizer()
# --- Example 9 (scraped-snippet separator) ---
    def G_step(self, x_S, x_T):
        """Generator loss: push target discriminator outputs toward 1.

        ``x_S`` is unused here; only target predictions drive the update.
        """
        self.set_domain_id(1)
        out_tgt = self.net_D(x_T)

        real_labels = to_cuda(torch.FloatTensor(out_tgt.size()).fill_(1.0))
        return self.BCELoss(out_tgt, real_labels)
    def patch_gamma_estimation(self, nums_S, nums_T, dist):
        """Estimate an RBF bandwidth (gamma) for every source/target class pair.

        Args:
            nums_S: per-class sample counts of the source batch.
            nums_T: per-class sample counts of the target batch.
            dist: dict of pairwise distances — 'ss' and 'tt' are per-class
                lists, 'st' is the full source-by-target matrix.

        Returns:
            dict of gammas: 'ss'/'tt' are lists (one per class) of
            ``num_classes``-sized tensors, 'st' mirrors ``dist['st']`` with
            the pairwise gamma broadcast over each patch.
        """
        assert (len(nums_S) == len(nums_T))
        num_classes = len(nums_S)

        patch = {}
        gammas = {}
        gammas['st'] = to_cuda(
            torch.zeros_like(dist['st'], requires_grad=False))
        gammas['ss'] = []
        gammas['tt'] = []
        for c in range(num_classes):
            gammas['ss'] += [
                to_cuda(torch.zeros([num_classes], requires_grad=False))
            ]
            gammas['tt'] += [
                to_cuda(torch.zeros([num_classes], requires_grad=False))
            ]

        # Walk the (source class, target class) grid; [source_start,
        # source_end) are the rows of dist['st'] belonging to class `ns`.
        source_start = source_end = 0
        for ns in range(num_classes):
            source_start = source_end
            source_end = source_start + nums_S[ns]
            patch['ss'] = dist['ss'][ns]

            target_start = target_end = 0
            for nt in range(num_classes):
                target_start = target_end
                target_end = target_start + nums_T[nt]
                patch['tt'] = dist['tt'][nt]

                patch['st'] = dist['st'].narrow(0, source_start,
                                                nums_S[ns]).narrow(
                                                    1, target_start,
                                                    nums_T[nt])

                # One scalar gamma per class pair, stored symmetrically.
                gamma = self.gamma_estimation(patch)

                gammas['ss'][ns][nt] = gamma
                gammas['tt'][nt][ns] = gamma
                gammas['st'][source_start:source_end, \
                     target_start:target_end] = gamma

        return gammas
# --- Example 11 (scraped-snippet separator) ---
    def test(self):
        """Classify the test set with the two-stage extractor + classifier
        and return the ``model_eval`` result."""
        self.feature_extractor.eval()
        self.classifier.eval()
        logits_all, labels_all = [], []
        for sample in iter(self.test_data['loader']):
            imgs = to_cuda(sample['Img'])
            labels = to_cuda(sample['Label'])
            # Only the first feature map feeds the classifier.
            feat, _ = self.feature_extractor(imgs)
            logits_all.append(self.classifier(feat))
            labels_all.append(labels)

        return self.model_eval(torch.cat(logits_all, dim=0),
                               torch.cat(labels_all, dim=0))
# --- Example 12 (scraped-snippet separator) ---
    def test(self):
        """Segmentation evaluation: return (pixel accuracy %, mean IoU %)."""
        self.set_domain_id(1)
        self.net.eval()

        conmat = gen_utils.ConfusionMatrix(cfg.DATASET.NUM_CLASSES)

        for sample in iter(self.test_data['loader']):
            img = gen_utils.to_cuda(sample['Img'])
            gt = gen_utils.to_cuda(sample['Label'])
            out = self.net(img)['out']
            # Upsample logits to label resolution before taking the argmax.
            out = F.interpolate(out,
                                size=gt.shape[-2:],
                                mode='bilinear',
                                align_corners=False)
            pred = torch.max(out, dim=1).indices
            conmat.update(gt.flatten(), pred.flatten())

        conmat.reduce_from_all_processes()
        accu, _, iou = conmat.compute()
        return accu.item() * 100.0, iou.mean().item() * 100.0
# --- Example 13 (scraped-snippet separator) ---
def get_centers(feature_extractor, dataloader, num_classes, key='feat'):
    """Sum per-class pooled feature vectors over a dataloader.

    Row c of the returned (num_classes, feat_dim) tensor is the **sum** (not
    the mean) of pooled features of samples labeled c; callers are expected
    to normalize by class counts themselves.

    ``key`` is kept for backward compatibility but is currently unused.
    """
    centers = 0
    # Column vector [0..num_classes-1] used to build per-class masks.
    refs = to_cuda(torch.LongTensor(range(num_classes)).unsqueeze(1))
    for sample in iter(dataloader):
        data = to_cuda(sample['Img'])
        gt = to_cuda(sample['Label'])

        feature, _ = feature_extractor(data)
        # Global average pooling -> (N, C); infer C instead of hard-coding 2048.
        feature = nn.AdaptiveAvgPool2d((1, 1))(feature).view(feature.size(0), -1)
        feature = feature.data

        # mask[c, n, 0] == 1 iff sample n belongs to class c. `.float()`
        # keeps the tensor on its current device (the old
        # torch.cuda.FloatTensor cast required CUDA).
        gt = gt.unsqueeze(0).expand(num_classes, -1)
        mask = (gt == refs).unsqueeze(2).float()
        feature = feature.unsqueeze(0)
        # Accumulate the classwise feature sums.
        centers += torch.sum(feature * mask, dim=1)

    return centers
# --- Example 14 (scraped-snippet separator) ---
    def collect_samples(self, net, loader):
        """Run `loader` through `net`, caching paths, labels (if present)
        and features into ``self.samples``."""
        feats, labels, paths = [], [], []
        for sample in iter(loader):
            imgs = sample['Img'].cuda()
            paths += sample['Path']
            if 'Label' in sample.keys():
                labels.append(to_cuda(sample['Label']))

            out = net.forward(imgs)
            feats.append(out[self.feat_key].data)

        self.samples['data'] = paths
        self.samples['gt'] = torch.cat(labels, dim=0) if len(labels) > 0 else None
        self.samples['feature'] = torch.cat(feats, dim=0)
# --- Example 15 (from mmd.py, project dongzhi0312/can-dta) ---
    def compute_kernel_dist(self, dist, gamma, kernel_num, kernel_mul):
        """Multi-kernel RBF response for a pairwise distance tensor.

        Builds ``kernel_num`` bandwidths centered on ``gamma`` and spaced by
        powers of ``kernel_mul``, then sums exp(-dist / gamma_i) over kernels.

        Args:
            dist: pairwise distance tensor.
            gamma: base bandwidth (float or scalar tensor).
            kernel_num: number of RBF kernels.
            kernel_mul: multiplicative spacing between bandwidths.

        Returns:
            Tensor shaped like ``dist`` with the summed kernel values.
        """
        base_gamma = gamma / (kernel_mul**(kernel_num // 2))
        gamma_list = [base_gamma * (kernel_mul**i) for i in range(kernel_num)]
        # Allocate on dist's device (the legacy to_cuda/torch.cuda.FloatTensor
        # path required a GPU; this also runs on CPU).
        gamma_tensor = torch.stack([
            torch.as_tensor(g, dtype=torch.float32) for g in gamma_list
        ]).to(dist.device)

        # Clamp tiny bandwidths to eps to avoid division blow-ups; detach so
        # no gradient flows into the bandwidths.
        eps = 1e-5
        gamma_tensor = torch.clamp(gamma_tensor, min=eps).detach()

        dist = dist.unsqueeze(0) / gamma_tensor.view(-1, 1, 1)
        # Clamp the scaled distances into [1e-5, 1e5] for numerical stability
        # (equivalent to the old detached upper/lower mask arithmetic).
        dist = torch.clamp(dist, min=1e-5, max=1e5)
        kernel_val = torch.sum(torch.exp(-1.0 * dist), dim=0)
        return kernel_val
    def patch_mean(self, nums_row, nums_col, dist):
        """Mean of ``dist`` over each class-pair patch.

        ``nums_row``/``nums_col`` give the per-class sample counts along each
        axis; entry (r, c) of the result is the mean of the corresponding
        rectangular patch of ``dist``.
        """
        assert (len(nums_row) == len(nums_col))
        num_classes = len(nums_row)

        # Allocate on dist's device instead of forcing CUDA via to_cuda(),
        # so the routine also works on CPU-only machines.
        mean_tensor = torch.zeros([num_classes, num_classes],
                                  device=dist.device, dtype=dist.dtype)
        row_start = row_end = 0
        for row in range(num_classes):
            row_start = row_end
            row_end = row_start + nums_row[row]

            col_start = col_end = 0
            for col in range(num_classes):
                col_start = col_end
                col_end = col_start + nums_col[col]
                patch = dist.narrow(0, row_start, nums_row[row]) \
                            .narrow(1, col_start, nums_col[col])
                mean_tensor[row, col] = torch.mean(patch)
        return mean_tensor
# --- Example 17 (scraped-snippet separator) ---
    def collect_samples(self, feature_extractor, loader):
        """Cache paths, labels (if present) and globally-pooled features for
        every sample in `loader` into ``self.samples``."""
        feats, labels, paths = [], [], []
        for sample in iter(loader):
            imgs = sample['Img'].cuda()
            paths += sample['Path']
            if 'Label' in sample.keys():
                labels.append(to_cuda(sample['Label']))

            feat, _ = feature_extractor(imgs)
            # Global average pooling to a (N, 2048) matrix.
            feat = nn.AdaptiveAvgPool2d((1, 1))(feat).view(-1, 2048)
            feats.append(feat.data)

        self.samples['data'] = paths
        self.samples['gt'] = torch.cat(labels, dim=0) if len(labels) > 0 else None
        self.samples['feature'] = torch.cat(feats, dim=0)
def step(opt, data_loader, model, to_train=True, optimizer=None):
    """Run one pass over `data_loader` as a training or validation step.

    When ``to_train`` is True, `optimizer` must be provided and the model is
    updated per batch. Returns the running average loss.
    """
    num_batches = len(data_loader)
    loss_meter = AverageMeter()
    with tqdm(total=num_batches) as bar:
        for batch in data_loader:
            # ---- forward ----
            if opt.toCuda:
                batch = to_cuda(batch, device())
            image = batch.pop('image')
            loss = model(image, batch)['loss']
            # ---- backward ----
            if to_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss_meter.update(loss.detach().cpu().item(), image.size(0))
            bar.set_postfix(loss='{:10.8f}'.format(loss_meter.avg))
            bar.update()

    return loss_meter.avg
# --- Example 19 (scraped-snippet separator) ---
    def update_network(self):
        """One training epoch (``self.iters_per_epoch`` iterations) of the
        segmentation network with a mixed Dice + boundary-F (BF) loss.

        Side effects: advances ``self.iters``, steps the optimizer (and the
        LR schedule for non-Adam optimizers), and optionally logs, tests and
        checkpoints at the configured intervals.
        """
        stop = False
        update_iters = 0

        while not stop:
            self.net.train()
            self.net.zero_grad()

            # Adam adapts its own step sizes; only schedule LR otherwise.
            if self.opt.TRAIN.OPTIMIZER != "Adam":
                self.update_lr()

            # get the video clip and corresponding mask
            #start = time()
            _, _, vclip, vlabel = self.get_samples()
            #end = time()
            #print('Time: %f' % (end-start))
            vclip = to_cuda(vclip)
            vlabel = to_cuda(vlabel)
            # forward and get the predictions
            # N x C x D x H x W
            #vout, vout_aux = self.net(vclip)
            vout = self.net(vclip)
            vprobs = F.softmax(vout, dim=1)

            #vout_aux = F.interpolate(vout_aux, scale_factor=(2, 4, 4))
            #vprobs_aux = F.softmax(vout_aux, dim=1)
            ## reshape and compute the cross-entropy loss
            #ch = vout.size(1)
            #vpreds = vout.transpose(0, 1).reshape(ch, -1).transpose(0, 1).squeeze(-1)
            ##vout0 = F.interpolate(vout0, scale_factor=(1, 2, 2))
            ##small_vpreds = vout0.transpose(0, 1).reshape(ch, -1).transpose(0, 1)

            #vgt = vlabel.view(-1)
            #loss = self.CELoss(vpreds, vgt) #self.BCELoss(vpreds, vgt)
            # Loss mix: (1 - alpha) * Dice + alpha * boundary-F.
            alpha = 0.3
            loss = (1.0 - alpha) * solver_utils.dice_loss(vprobs, vlabel)
            loss += alpha * solver_utils.BF_loss(vprobs, vlabel)

            #loss_aux = (1.0 - alpha) * solver_utils.dice_loss(vprobs_aux, vlabel)
            #loss_aux += alpha * solver_utils.BF_loss(vprobs_aux, vlabel)

            #beta = 0.5
            #loss = beta * loss + (1.0 - beta) * loss_aux

            # downsample the mask by scale 2
            #small_mask = F.interpolate(vlabel.type(torch.cuda.FloatTensor), scale_factor=(0.5, 0.5))
            #small_vgt = small_mask.view(-1).type(torch.cuda.LongTensor)
            #loss += self.CELoss(small_vpreds, vgt)

            loss.backward()
            self.optimizer.step()

            # Log NUM_LOGGING_PER_EPOCH times per epoch.
            if self.opt.TRAIN.LOGGING and (update_iters+1) % \
                      (max(1, self.iters_per_epoch // 
                      self.opt.TRAIN.NUM_LOGGING_PER_EPOCH)) == 0:

                iou, iou_pos, accuracy = self.model_eval(vout, vlabel)
                print('Training at (epoch %d, iters: %d) with loss, iou, iou_pos, accuracy: %.4f, %.4f, %.4f, %.4f.' % (
                      self.epochs, self.iters, loss, iou, iou_pos, accuracy))

            if self.opt.TRAIN.TEST_INTERVAL > 0 and \
                (self.iters+1) % int(self.opt.TRAIN.TEST_INTERVAL * 
                self.iters_per_epoch) == 0:

                with torch.no_grad():
                    iou, iou_pos, accuracy = self.test()
                    print('Test at (epoch %d, iters: %d): %.4f, %.4f, %.4f.' % (self.epochs,
                              self.iters, iou, iou_pos, accuracy))

            if self.opt.TRAIN.SAVE_CKPT_INTERVAL > 0 and \
                (self.iters+1) % int(self.opt.TRAIN.SAVE_CKPT_INTERVAL * 
                self.iters_per_epoch) == 0:

                self.save_ckpt()

            update_iters += 1
            self.iters += 1

            # update stop condition
            if update_iters >= self.iters_per_epoch:
                stop = True
            else:
                stop = False
# --- Example 20 (scraped-snippet separator) ---
    def feature_clustering(self, feature_extractor, loader):
        """K-means-style clustering of the collected features.

        Iteratively re-assigns samples to the nearest centers and recomputes
        the centers until ``clustering_stop`` sets ``self.stop``. Populates
        ``self.samples['label']``/``['dist2center']``, re-orders
        ``self.centers`` to align clusters with class labels, records the
        center drift in ``self.center_change`` and fills ``self.path2label``.
        """
        centers = None
        self.stop = False

        # self.collect_samples(net, loader)
        self.collect_samples(feature_extractor, loader)
        feature = self.samples['feature']

        refs = to_cuda(torch.LongTensor(range(self.num_classes)).unsqueeze(1))
        num_samples = feature.size(0)
        # Process features in chunks of at most self.max_len samples to
        # bound peak memory.
        num_split = ceil(1.0 * num_samples / self.max_len)

        while True:
            self.clustering_stop(centers)
            if centers is not None:
                self.centers = centers
            if self.stop:
                break

            centers = 0
            count = 0

            start = 0
            for N in range(num_split):
                cur_len = min(self.max_len, num_samples - start)
                cur_feature = feature.narrow(0, start, cur_len)
                dist2center, labels = self.assign_labels(cur_feature)
                labels_onehot = to_onehot(labels, self.num_classes)
                count += torch.sum(labels_onehot, dim=0)
                labels = labels.unsqueeze(0)
                # mask[c, n, 0] == 1 iff sample n is assigned to cluster c.
                mask = (labels == refs).unsqueeze(2).type(
                    torch.cuda.FloatTensor)
                reshaped_feature = cur_feature.unsqueeze(0)
                # update centers
                centers += torch.sum(reshaped_feature * mask, dim=1)
                start += cur_len

            # Empty clusters fall back to their initial centers.
            mask = (count.unsqueeze(1) > 0).type(torch.cuda.FloatTensor)
            centers = mask * centers + (1 - mask) * self.init_centers

        # Final assignment pass with the converged centers.
        dist2center, labels = [], []
        start = 0
        count = 0
        for N in range(num_split):
            cur_len = min(self.max_len, num_samples - start)
            cur_feature = feature.narrow(0, start, cur_len)
            cur_dist2center, cur_labels = self.assign_labels(cur_feature)

            labels_onehot = to_onehot(cur_labels, self.num_classes)
            count += torch.sum(labels_onehot, dim=0)

            dist2center += [cur_dist2center]
            labels += [cur_labels]
            start += cur_len

        self.samples['label'] = torch.cat(labels, dim=0)
        self.samples['dist2center'] = torch.cat(dist2center, dim=0)

        cluster2label = self.align_centers()
        # reorder the centers
        self.centers = self.centers[cluster2label, :]
        # re-label the data according to the index
        num_samples = len(self.samples['feature'])
        for k in range(num_samples):
            self.samples['label'][k] = cluster2label[self.samples['label']
                                                     [k]].item()

        self.center_change = torch.mean(self.Dist.get_dist(self.centers, \
                                                           self.init_centers))

        for i in range(num_samples):
            self.path2label[self.samples['data']
                            [i]] = self.samples['label'][i].item()

        # Features are no longer needed once labels are assigned.
        del self.samples['feature']
    def forward(self, source, target, nums_S, nums_T, domain_probs,
                source_sample_labels):
        """Domain-probability-weighted ("soft") CDD across multiple source
        sub-domains.

        Args:
            source: list (per layer) of class-ordered source feature tensors.
            target: list (per layer) of class-ordered target feature tensors.
            nums_S: per-class sample counts of the source batch.
            nums_T: per-class sample counts of the target batch.
            domain_probs: per-sample sub-domain probabilities; its last dim
                is the number of sub-domains K.
            source_sample_labels: class label of each source sample, used to
                pick the matching row of ``domain_probs``.

        Returns:
            dict with 'cdd' (the loss) and its 'intra'/'inter' components.
        """
        assert(len(nums_S) == len(nums_T)), \
             "The number of classes for source (%d) and target (%d) should be the same." \
             % (len(nums_S), len(nums_T))

        num_classes = len(nums_S)
        num_domains = domain_probs.size(2)
        # assert num_classes == domain_probs.size(1)

        # proper_labels = self.get_proper_labels(nums_S)
        proper_labels = source_sample_labels
        # Per-sample sub-domain probabilities of each sample's own class.
        domain_probs_simple = domain_probs[torch.arange(
            domain_probs.size()[0]), proper_labels]  # Ns x K
        paired_domain_probs = self.compute_paired_domain_prob(
            domain_probs_simple)  # Ns x Ns x K

        paired_domain_probs_ss_classwise = self.split_paired_dp_classwise(
            paired_domain_probs, nums_S)

        # compute the dist
        dist_layers = []
        gamma_layers = []

        for i in range(self.num_layers):

            cur_source = source[i]
            cur_target = target[i]

            dist = {}
            dist['ss'] = self.compute_paired_dist(cur_source, cur_source)
            dist['tt'] = self.compute_paired_dist(cur_target, cur_target)
            dist['st'] = self.compute_paired_dist(cur_source, cur_target)
            dist['ss'] = self.split_classwise(dist['ss'], nums_S)
            dist['tt'] = self.split_classwise(dist['tt'], nums_T)

            # soft_dist = {}
            # soft_dist['ss'] = self.compute_soft_paired_dist(cur_source, cur_source, 'ss', domain_probs, paired_domain_probs)
            # soft_dist['tt'] = self.compute_soft_paired_dist(cur_target, cur_target, 'tt', domain_probs, paired_domain_probs)
            # soft_dist['st'] = self.compute_soft_paired_dist(cur_source, cur_target, 'st', domain_probs, paired_domain_probs)
            # soft_dist['ss'] = self.split_classwise(soft_dist['ss'], nums_S)
            # soft_dist['tt'] = self.split_classwise(soft_dist['tt'], nums_T)

            dist_layers += [dist]

            gamma_layers += [self.patch_gamma_estimation(nums_S, nums_T, dist)]

        # compute the kernel dist: reshape gammas so they broadcast over the
        # per-class distance blocks.
        for i in range(self.num_layers):
            for c in range(num_classes):
                gamma_layers[i]['ss'][c] = gamma_layers[i]['ss'][c].view(
                    num_classes, 1, 1)
                gamma_layers[i]['tt'][c] = gamma_layers[i]['tt'][c].view(
                    num_classes, 1, 1)

        kernel_dist_st = self.kernel_layer_aggregation(dist_layers,
                                                       gamma_layers,
                                                       'st')  # Ns x Nt
        assert kernel_dist_st.size()[0] == domain_probs_simple.size()[0]

        # Broadcast the source-target kernel values and the sub-domain
        # probabilities to a common Ns x Nt x K shape.
        kernel_dist_st_expand = kernel_dist_st.unsqueeze(2).expand(
            kernel_dist_st.size()[0],
            kernel_dist_st.size()[1], num_domains)  # Ns x Nt x K
        domain_probs_simple_expand = domain_probs_simple.unsqueeze(1).expand(
            kernel_dist_st.size()[0],
            kernel_dist_st.size()[1], num_domains)  # Ns x Nt x K

        kernel_dist_st_soft = kernel_dist_st_expand * domain_probs_simple_expand
        kernel_dist_st_soft = self.patch_mean(
            nums_S, nums_T, kernel_dist_st_soft,
            domain_probs_simple_expand)  # num_classes x num_classes x K

        kernel_dist_ss_soft = []
        kernel_dist_tt_soft = []
        for c in range(num_classes):
            kernel_dist_ss = self.kernel_layer_aggregation(
                dist_layers, gamma_layers, 'ss', c)  # num_classes x N_c x N_c
            paired_dp_ss_c = paired_domain_probs_ss_classwise[
                c]  # num_classes x N_c x N_c x K

            kernel_dist_ss_expand = kernel_dist_ss.unsqueeze(3).expand(
                kernel_dist_ss.size()[0],
                kernel_dist_ss.size()[1],
                kernel_dist_ss.size()[2],
                num_domains)  # num_classes x N_c x N_c x K

            # Probability-weighted mean over each source-source patch.
            temp_mult = kernel_dist_ss_expand * paired_dp_ss_c  # num_classes x N_c x N_c x K
            kernel_dist_ss_soft += [
                torch.sum(temp_mult.view(num_classes, -1, num_domains), dim=1)
                / torch.sum(paired_dp_ss_c.view(num_classes, -1, num_domains),
                            dim=1)
            ]  # list of num_classes x K

            # Target-target terms have no sub-domain weighting; replicate the
            # plain mean across the K sub-domains.
            temp_tt = torch.mean(self.kernel_layer_aggregation(
                dist_layers, gamma_layers, 'tt', c).view(num_classes, -1),
                                 dim=1)

            kernel_dist_tt_soft += [
                temp_tt.unsqueeze(1).expand(num_classes, num_domains)
            ]  # list of num_classes x K

        kernel_dist_ss_soft = torch.stack(kernel_dist_ss_soft, dim=0)
        kernel_dist_tt_soft = torch.stack(kernel_dist_tt_soft,
                                          dim=0).transpose(1, 0)

        mmds = kernel_dist_ss_soft + kernel_dist_tt_soft - 2 * kernel_dist_st_soft  # num_classes x num_classes x K

        # nc1: source-vs-target MMD terms; nc2: source-vs-source MMD terms.
        nc2_intra = to_cuda(torch.zeros(1))
        nc2_inter = to_cuda(torch.zeros(1))
        nc1_intra = to_cuda(torch.zeros(1))
        nc1_inter = to_cuda(torch.zeros(1))

        for i in range(num_classes):
            for j in range(num_classes):
                if i == j:
                    nc1_intra += torch.mean(kernel_dist_ss_soft[i, i] +
                                            kernel_dist_tt_soft[j, j] -
                                            2 * kernel_dist_st_soft[i, j])
                else:
                    nc1_inter += torch.mean(kernel_dist_ss_soft[i, i] +
                                            kernel_dist_tt_soft[j, j] -
                                            2 * kernel_dist_st_soft[i, j])

        for i in range(num_classes):
            for j in range(num_classes):
                if i == j:
                    nc2_intra += torch.mean(kernel_dist_ss_soft[i, i] +
                                            kernel_dist_ss_soft[j, j] -
                                            2 * kernel_dist_ss_soft[i, j])
                else:
                    nc2_inter += torch.mean(kernel_dist_ss_soft[i, i] +
                                            kernel_dist_ss_soft[j, j] -
                                            2 * kernel_dist_ss_soft[i, j])

        nc1_intra = nc1_intra[0] / (self.num_classes)
        nc1_inter = nc1_inter[0] / (self.num_classes * (self.num_classes - 1))

        nc2_intra = nc2_intra[0] / (self.num_classes)
        nc2_inter = nc2_inter[0] / (self.num_classes * (self.num_classes - 1))

        # Minimize intra-class discrepancy; optionally also maximize
        # inter-class discrepancy (subtracted terms).
        cdd = nc1_intra + nc2_intra if self.intra_only else nc1_intra + nc2_intra - nc1_inter - nc2_inter
        return {
            'cdd': cdd,
            'intra': nc1_intra + nc2_intra,
            'inter': nc1_inter + nc2_inter
        }
# --- Example 22 (scraped-snippet separator) ---
    def update_network(self, filtered_classes):
        """One CAN training loop: supervised cross-entropy on the source
        domain plus, once classes have been filtered in, the class-aware
        CDD alignment loss.

        Args:
            filtered_classes: classes whose target clusters are reliable
                enough to include in the CDD term; when empty only the CE
                loss is used.

        Side effects: advances ``self.iters``, steps the optimizer/LR
        schedule, and optionally logs, tests and checkpoints.
        """
        # initial configuration
        stop = False
        update_iters = 0

        # Fresh iterators so every loop restarts both loaders.
        self.train_data[self.source_name]['iterator'] = \
            iter(self.train_data[self.source_name]['loader'])
        self.train_data['categorical']['iterator'] = \
            iter(self.train_data['categorical']['loader'])

        while not stop:
            # update learning rate
            self.update_lr()

            # set the status of network
            self.net.train()
            self.net.zero_grad()

            loss = 0
            ce_loss_iter = 0
            cdd_loss_iter = 0

            # conventional sampling for training on labeled source data
            source_sample = self.get_samples(self.source_name)
            source_data, source_gt = source_sample['Img'], \
                                     source_sample['Label']

            source_data = to_cuda(source_data)
            source_gt = to_cuda(source_gt)
            self.net.module.set_bn_domain(self.bn_domain_map[self.source_name])
            source_preds = self.net(source_data)['logits']

            # compute the cross-entropy loss
            ce_loss = self.CELoss(source_preds, source_gt)
            ce_loss.backward()

            ce_loss_iter += ce_loss
            loss += ce_loss

            if len(filtered_classes) > 0:
                # update the network parameters
                # 1) class-aware sampling
                source_samples_cls, source_nums_cls, \
                target_samples_cls, target_nums_cls = self.CAS()

                # 2) forward and compute the loss
                source_cls_concat = torch.cat(
                    [to_cuda(samples) for samples in source_samples_cls],
                    dim=0)
                target_cls_concat = torch.cat(
                    [to_cuda(samples) for samples in target_samples_cls],
                    dim=0)

                # Each domain runs under its own BN statistics.
                self.net.module.set_bn_domain(
                    self.bn_domain_map[self.source_name])
                feats_source = self.net(source_cls_concat)
                self.net.module.set_bn_domain(
                    self.bn_domain_map[self.target_name])
                feats_target = self.net(target_cls_concat)

                # prepare the features
                feats_toalign_S = self.prepare_feats(feats_source)
                feats_toalign_T = self.prepare_feats(feats_target)

                cdd_loss = self.cdd.forward(
                    feats_toalign_S, feats_toalign_T, source_nums_cls,
                    target_nums_cls)[self.discrepancy_key]

                cdd_loss *= self.opt.CDD.LOSS_WEIGHT
                cdd_loss.backward()

                cdd_loss_iter += cdd_loss
                loss += cdd_loss

            # update the network
            self.optimizer.step()

            if self.opt.TRAIN.LOGGING and (update_iters + 1) % \
                    (max(1, self.iters_per_loop // self.opt.TRAIN.NUM_LOGGING_PER_LOOP)) == 0:
                accu = self.model_eval(source_preds, source_gt)
                cur_loss = {
                    'ce_loss': ce_loss_iter,
                    'cdd_loss': cdd_loss_iter,
                    'total_loss': loss
                }
                self.logging(cur_loss, accu)

            # Clamp intervals to at most once per loop.
            self.opt.TRAIN.TEST_INTERVAL = min(1.0,
                                               self.opt.TRAIN.TEST_INTERVAL)
            self.opt.TRAIN.SAVE_CKPT_INTERVAL = min(
                1.0, self.opt.TRAIN.SAVE_CKPT_INTERVAL)

            if self.opt.TRAIN.TEST_INTERVAL > 0 and \
                    (update_iters + 1) % int(self.opt.TRAIN.TEST_INTERVAL * self.iters_per_loop) == 0:
                with torch.no_grad():
                    self.net.module.set_bn_domain(
                        self.bn_domain_map[self.target_name])
                    accu = self.test()
                    print('Test at (loop %d, iters: %d) with %s: %.4f.' %
                          (self.loop, self.iters, self.opt.EVAL_METRIC, accu))

            if self.opt.TRAIN.SAVE_CKPT_INTERVAL > 0 and \
                    (update_iters + 1) % int(self.opt.TRAIN.SAVE_CKPT_INTERVAL * self.iters_per_loop) == 0:
                self.save_ckpt()

            update_iters += 1
            self.iters += 1

            # update stop condition
            if update_iters >= self.iters_per_loop:
                stop = True
            else:
                stop = False
# --- Example 23 (scraped-snippet separator) ---
def test(args):
    """Evaluate a (possibly domain-adapted) DA-Net classifier on the test set.

    Loads optional checkpoint weights from ``cfg.WEIGHTS``, builds the model
    wrapped in ``DataParallel``, runs inference over the test loader (with
    optional five-crop averaging), saves the predictions to ``cfg.SAVE_DIR``,
    and — when ground-truth labels are present in the samples — reports the
    metric selected by ``cfg.EVAL_METRIC``.
    """
    # prepare data
    dataloader = prepare_data()

    # initialize model
    model_state_dict = None
    fx_pretrained = True
    bn_domain_map = {}

    if cfg.WEIGHTS != '':
        ckpt = torch.load(cfg.WEIGHTS)
        model_state_dict = ckpt['weights']
        bn_domain_map = ckpt['bn_domain_map']
        # checkpoint already holds trained features; skip ImageNet init
        fx_pretrained = False

    # adapted models carry one BatchNorm branch per domain
    num_domains_bn = 2 if args.adapted_model else 1

    net = model.danet(num_classes=cfg.DATASET.NUM_CLASSES,
                      state_dict=model_state_dict,
                      feature_extractor=cfg.MODEL.FEATURE_EXTRACTOR,
                      fx_pretrained=fx_pretrained,
                      dropout_ratio=cfg.TRAIN.DROPOUT_RATIO,
                      fc_hidden_dims=cfg.MODEL.FC_HIDDEN_DIMS,
                      num_domains_bn=num_domains_bn)

    net = torch.nn.DataParallel(net)
    if torch.cuda.is_available():
        net.cuda()

    # accumulators for per-sample paths, predictions, labels and probabilities
    res = {'path': [], 'preds': [], 'gt': [], 'probs': []}
    net.eval()

    # fall back to BN domain 0 when the requested test domain is unknown
    domain_id = bn_domain_map.get(cfg.TEST.DOMAIN, 0)

    with torch.no_grad():
        net.module.set_bn_domain(domain_id)
        for sample in iter(dataloader):
            res['path'] += sample['Path']

            imgs = sample['Img']
            if cfg.DATA_TRANSFORM.WITH_FIVE_CROP:
                # fold the crop dim into the batch dim, then average the
                # per-crop probabilities back per image
                n, ncrop, c, h, w = imgs.size()
                sample['Img'] = imgs.view(-1, c, h, w)
                probs = net(to_cuda(sample['Img']))['probs']
                probs = probs.view(n, ncrop, -1).mean(dim=1)
            else:
                probs = net(to_cuda(imgs))['probs']

            res['preds'].append(torch.max(probs, dim=1)[1])
            res['probs'].append(probs)

            if 'Label' in sample:
                res['gt'].append(to_cuda(sample['Label']))
            print('Processed %d samples.' % len(res['path']))

        preds = torch.cat(res['preds'], dim=0)
        save_preds(res['path'], preds, cfg.SAVE_DIR)

        # only score when at least one batch carried ground-truth labels
        if 'gt' in res and len(res['gt']) > 0:
            gts = torch.cat(res['gt'], dim=0)
            probs = torch.cat(res['probs'], dim=0)

            assert (cfg.EVAL_METRIC == 'mean_accu'
                    or cfg.EVAL_METRIC == 'accuracy')
            if cfg.EVAL_METRIC == "mean_accu":
                print('Test mean_accu: %.4f' % (mean_accuracy(probs, gts)))
            elif cfg.EVAL_METRIC == "accuracy":
                print('Test accuracy: %.4f' % (accuracy(probs, gts)))

    print('Finished!')
예제 #24
0
    def update_network(self, **kwargs):
        """Run one epoch of McDalNet domain-adaptation training.

        Each iteration draws one labeled source batch and one unlabeled
        target batch, computes the source cross-entropy (task) loss plus a
        domain-alignment loss chosen by ``self.opt.MCDALNET.DISTANCE_TYPE``
        ('SourceOnly' = no alignment, 'DANN', 'MDD', or a generic
        discrepancy), and takes one optimizer step. Stops after one pass
        over the target loader and appends epoch averages to
        ``SAVE_DIR/log.txt``.
        """
        stop = False
        self.train_data['source']['iterator'] = iter(
            self.train_data['source']['loader'])
        self.train_data['target']['iterator'] = iter(
            self.train_data['target']['loader'])
        # epoch length is defined by the target loader
        self.iters_per_epoch = len(self.train_data['target']['loader'])
        iters_counter_within_epoch = 0
        data_time = AverageMeter()
        batch_time = AverageMeter()
        total_loss = AverageMeter()
        ce_loss = AverageMeter()
        da_loss = AverageMeter()
        prec1_task = AverageMeter()
        prec1_aux1 = AverageMeter()
        prec1_aux2 = AverageMeter()
        self.net.train()
        end = time.time()
        if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
            # lam ramps 0 -> 1 over training (DANN-style annealing of the
            # gradient-reversal / alignment strength); updated once per epoch
            lam = 2 / (1 + math.exp(
                -1 * 10 * self.epoch / self.opt.TRAIN.MAX_EPOCH)) - 1
            self.update_lr()
            print('value of lam is: %3f' % (lam))
        while not stop:
            if self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
                # same annealing schedule, but re-evaluated every iteration
                lam = 2 / (1 + math.exp(
                    -1 * 10 * self.iters /
                    (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))) - 1
                print('value of lam is: %3f' % (lam))
                self.update_lr()
            source_data, source_gt = self.get_samples('source')
            target_data, _ = self.get_samples('target')
            source_data = to_cuda(source_data)
            source_gt = to_cuda(source_gt)
            target_data = to_cuda(target_data)
            data_time.update(time.time() - end)

            # forward: task output, two auxiliary classifier outputs, a
            # domain-classifier output, and truncated auxiliary outputs
            feature_source, output_source, output_source1, output_source2, output_source_dc, output_source1_trunc, output_source2_trunc = self.net(
                source_data, lam)
            loss_task_auxiliary_1 = self.CELoss(output_source1_trunc,
                                                source_gt)
            loss_task_auxiliary_2 = self.CELoss(output_source2_trunc,
                                                source_gt)
            loss_task = self.CELoss(output_source, source_gt)
            if self.opt.MCDALNET.DISTANCE_TYPE != 'SourceOnly':
                feature_target, output_target, output_target1, output_target2, output_target_dc, output_target1_trunc, output_target2_trunc = self.net(
                    target_data, lam)
                if self.opt.MCDALNET.DISTANCE_TYPE == 'DANN':
                    # binary domain discrimination: source labeled 0, target 1
                    num_source = source_data.size()[0]
                    num_target = target_data.size()[0]
                    dlabel_source = to_cuda(torch.zeros(num_source, 1))
                    dlabel_target = to_cuda(torch.ones(num_target, 1))
                    loss_domain_all = self.BCELoss(
                        output_source_dc, dlabel_source) + self.BCELoss(
                            output_target_dc, dlabel_target)
                    loss_all = loss_task + loss_domain_all
                elif self.opt.MCDALNET.DISTANCE_TYPE == 'MDD':
                    # margin disparity discrepancy: classifier 1 is scored
                    # against pseudo-labels produced by classifier 2
                    prob_target1 = F.softmax(output_target1, dim=1)
                    _, target_pseudo_label = torch.topk(output_target2, 1)
                    batch_index = torch.arange(output_target.size()[0]).long()
                    pred_gt_prob = prob_target1[
                        batch_index,
                        target_pseudo_label]  ## the prob values of the predicted gt
                    # clamp probabilities equal to 1 so log(1 - p) stays finite
                    pred_gt_prob = process_one_values(pred_gt_prob)
                    loss_domain_target = (1 - pred_gt_prob).log().mean()

                    _, source_pseudo_label = torch.topk(output_source2, 1)
                    loss_domain_source = self.CELoss(output_source1,
                                                     source_pseudo_label[:, 0])
                    loss_domain_all = loss_domain_source - loss_domain_target
                    loss_all = loss_task + loss_domain_all + loss_task_auxiliary_1 + loss_task_auxiliary_2
                else:
                    # generic discrepancy between the two auxiliary
                    # classifiers, minimized on source, maximized on target
                    loss_domain_source = self.McDalNetLoss(
                        output_source1, output_source2,
                        self.opt.MCDALNET.DISTANCE_TYPE)
                    loss_domain_target = self.McDalNetLoss(
                        output_target1, output_target2,
                        self.opt.MCDALNET.DISTANCE_TYPE)
                    loss_domain_all = loss_domain_source - loss_domain_target
                    loss_all = loss_task + loss_domain_all + loss_task_auxiliary_1 + loss_task_auxiliary_2
                da_loss.update(loss_domain_all, source_data.size()[0])
            else:
                # 'SourceOnly': plain supervised training, no alignment term
                loss_all = loss_task
            ce_loss.update(loss_task, source_data.size()[0])
            total_loss.update(loss_all, source_data.size()[0])
            prec1_task.update(accuracy(output_source, source_gt),
                              source_data.size()[0])
            prec1_aux1.update(accuracy(output_source1, source_gt),
                              source_data.size()[0])
            prec1_aux2.update(accuracy(output_source2, source_gt),
                              source_data.size()[0])

            self.optimizer.zero_grad()
            loss_all.backward()
            self.optimizer.step()

            print("  Train:epoch: %d:[%d/%d], LossCE: %3f, LossDA: %3f, LossAll: %3f, Auxi1: %3f, Auxi2: %3f, Task: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, ce_loss.avg, da_loss.avg, total_loss.avg, prec1_aux1.avg, prec1_aux2.avg, prec1_task.avg))

            batch_time.update(time.time() - end)
            end = time.time()
            self.iters += 1
            iters_counter_within_epoch += 1
            if iters_counter_within_epoch >= self.iters_per_epoch:
                # end of epoch: persist the averaged metrics and stop
                log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
                log.write("\n")
                log.write("  Train:epoch: %d:[%d/%d], LossCE: %3f, LossDA: %3f, LossAll: %3f, Auxi1: %3f, Auxi2: %3f, Task: %3f" % \
                    (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, ce_loss.avg, da_loss.avg, total_loss.avg, prec1_aux1.avg, prec1_aux2.avg, prec1_task.avg))
                log.close()
                stop = True
예제 #25
0
def test(args):
    """Evaluate a semantic-segmentation network and report a confusion matrix.

    Loads optional weights (mapped to CPU first), optionally converts the
    model to SyncBatchNorm (distributed) and/or domain-specific BatchNorm,
    runs inference over the test loader with bilinear upsampling to the
    ground-truth resolution, optionally writes colorized prediction/GT masks,
    and prints the accumulated confusion matrix.
    """
    # initialize model
    model_state_dict = None

    if cfg.WEIGHTS != '':
        # map to CPU so loading works on machines without the saving GPU
        param_dict = torch.load(cfg.WEIGHTS, torch.device('cpu'))
        model_state_dict = param_dict['weights']

    # look the network constructor up by name in the SegNet module
    net = SegNet.__dict__[cfg.MODEL.NETWORK_NAME](
        pretrained=False,
        pretrained_backbone=False,
        num_classes=cfg.DATASET.NUM_CLASSES,
        aux_loss=cfg.MODEL.USE_AUX_CLASSIFIER)

    if args.distributed:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)

    if cfg.MODEL.DOMAIN_BN:
        net = DomainBN.convert_domain_batchnorm(net, num_domains=2)

    if model_state_dict is not None:
        try:
            net.load_state_dict(model_state_dict)
        except:
            # NOTE(review): bare except — assumes any failure means the
            # checkpoint was saved with domain BN; retry after converting
            # and select the BN branch matching the requested test domain
            net = DomainBN.convert_domain_batchnorm(net, num_domains=2)
            net.load_state_dict(model_state_dict)
            if cfg.TEST.DOMAIN == 'source':
                DomainBN.set_domain_id(net, 0)
            if cfg.TEST.DOMAIN == 'target':
                DomainBN.set_domain_id(net, 1)

    if torch.cuda.is_available():
        net.cuda()

    if args.distributed:
        net = DistributedDataParallel(net, device_ids=[args.gpu])
    else:
        net = torch.nn.DataParallel(net)

    test_dataset, test_dataloader = prepare_data(args)

    net.eval()
    # NOTE(review): the four accumulators below are never updated or read;
    # the confusion matrix is the only metric actually produced
    corrects = 0
    total_num_pixels = 0
    total_intersection = 0
    total_union = 0
    num_classes = cfg.DATASET.NUM_CLASSES

    with torch.no_grad():
        conmat = gen_utils.ConfusionMatrix(
            cfg.DATASET.NUM_CLASSES,
            list(LABEL_TASK['%s2%s' %
                            (cfg.DATASET.SOURCE, cfg.DATASET.TARGET)].keys()))
        for sample in iter(test_dataloader):
            data, gt = gen_utils.to_cuda(sample['Img']), gen_utils.to_cuda(
                sample['Label'])
            names = sample['Name']
            res = net(data)

            if cfg.TEST.WITH_AGGREGATION:
                # blend raw features/probabilities with their spatially
                # aggregated versions (alpha controls the mixing weight)
                feats = res['feat']
                alpha = 0.5
                feats = (1.0 - alpha) * feats + alpha * AssociationLoss(
                ).spatial_agg(feats)[-1]
                preds = F.softmax(net.module.classifier(feats), dim=1)
                preds = (1.0 - alpha) * preds + alpha * AssociationLoss(
                ).spatial_agg(preds, metric='kl')[-1]
            else:
                preds = res['out']

            # upsample logits to the label resolution before taking argmax
            preds = F.interpolate(preds,
                                  size=gt.shape[-2:],
                                  mode='bilinear',
                                  align_corners=False)
            preds = torch.max(preds, dim=1).indices

            if cfg.TEST.VISUALIZE:
                for i in range(preds.size(0)):
                    cur_pred = preds[i, :, :].cpu().numpy()
                    cur_gt = gt[i, :, :].cpu().numpy()
                    cur_pred_cp = cur_pred.copy()
                    cur_gt_cp = cur_gt.copy()
                    # map train IDs back to dataset label IDs for display
                    label_map = label_map_gtav if cfg.DATASET.SOURCE == 'GTAV' else label_map_syn
                    for n in range(cfg.DATASET.NUM_CLASSES):
                        cur_pred[cur_pred_cp == n] = label_map[n]
                        cur_gt[cur_gt_cp == n] = label_map[n]

                    # carry the ignore label (255) over from GT to prediction
                    cur_pred = np.where(cur_gt == 255, cur_gt, cur_pred)

                    cur_pred = np.asarray(cur_pred, dtype=np.uint8)
                    cur_gt = np.asarray(cur_gt, dtype=np.uint8)

                    vis_res = colorize_mask(cur_pred)
                    vis_gt = colorize_mask(cur_gt)

                    vis_name = 'vis_%s.png' % (names[i])
                    vis_res.save(os.path.join(cfg.SAVE_DIR, vis_name))

                    vis_name = 'vis_gt_%s.png' % (names[i])
                    vis_gt.save(os.path.join(cfg.SAVE_DIR, vis_name))

            conmat.update(gt.flatten(), preds.flatten())

        # aggregate partial matrices across distributed workers, if any
        conmat.reduce_from_all_processes()
        print('Test with %d samples: ' % len(test_dataset))
        print(conmat)

    print('Finished!')
예제 #26
0
    def test(self):
        """Evaluate the two-headed classifier on the test set.

        The classifier output concatenates a source head (first
        ``self.num_classes`` logits, "Fs") and a target head (remaining
        logits, "Ft"). Depending on ``self.opt.EVAL_METRIC`` this reports
        either overall accuracy ('accu') or class-mean accuracy
        ('accu_mean'), logs to ``SAVE_DIR/log.txt``, and returns the better
        of the two heads' scores.
        """
        self.feature_extractor.eval()
        self.classifier.eval()
        prec1_fs = AverageMeter()
        prec1_ft = AverageMeter()
        # per-class sample counts and correct counts for each head
        counter_all_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_all_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)

        for i, (input, target) in enumerate(self.test_data['loader']):
            input, target = to_cuda(input), to_cuda(target)
            with torch.no_grad():
                feature_test = self.feature_extractor(input)
                output_test = self.classifier(feature_test)


            if self.opt.EVAL_METRIC == 'accu':
                # split the concatenated logits into the two heads
                prec1_fs_iter = accuracy(output_test[:, :self.num_classes], target)
                prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target)
                prec1_fs.update(prec1_fs_iter, input.size(0))
                prec1_ft.update(prec1_ft_iter, input.size(0))
                if i % self.opt.PRINT_STEP == 0:
                    print("  Test:epoch: %d:[%d/%d], AccFs: %3f, AccFt: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_fs.avg, prec1_ft.avg))
            elif self.opt.EVAL_METRIC == 'accu_mean':
                prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target)
                prec1_ft.update(prec1_ft_iter, input.size(0))
                counter_all_fs, counter_acc_fs = accuracy_for_each_class(output_test[:, :self.num_classes], target, counter_all_fs, counter_acc_fs)
                counter_all_ft, counter_acc_ft = accuracy_for_each_class(output_test[:, self.num_classes:], target, counter_all_ft, counter_acc_ft)
                if i % self.opt.PRINT_STEP == 0:
                    print("  Test:epoch: %d:[%d/%d], Task: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_ft.avg))
            else:
                raise NotImplementedError
        # NOTE(review): divides by zero (-> NaN) for classes absent from the
        # test set; also NaN in the 'accu' branch where counters stay 0,
        # though the result is unused there
        acc_for_each_class_fs = counter_acc_fs / counter_all_fs
        acc_for_each_class_ft = counter_acc_ft / counter_all_ft
        log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
        log.write("\n")
        if self.opt.EVAL_METRIC == 'accu':
            log.write(
                "                                                          Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
                (self.epoch, prec1_fs.avg, prec1_ft.avg))
            log.close()
            return max(prec1_fs.avg, prec1_ft.avg)
        elif self.opt.EVAL_METRIC == 'accu_mean':
            log.write(
                "                                            Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
                (self.epoch,acc_for_each_class_fs.mean(), acc_for_each_class_ft.mean()))
            log.write("\nClass-wise Acc of Ft:")  ## based on the task classifier.
            # NOTE(review): ordinal suffixes are only correct for 1-3
            # ("21th" etc. for later classes)
            for i in range(self.opt.DATASET.NUM_CLASSES):
                if i == 0:
                    log.write("%dst: %3f" % (i + 1, acc_for_each_class_ft[i]))
                elif i == 1:
                    log.write(",  %dnd: %3f" % (i + 1, acc_for_each_class_ft[i]))
                elif i == 2:
                    log.write(", %drd: %3f" % (i + 1, acc_for_each_class_ft[i]))
                else:
                    log.write(", %dth: %3f" % (i + 1, acc_for_each_class_ft[i]))
            log.close()
            return max(acc_for_each_class_ft.mean(), acc_for_each_class_fs.mean())
예제 #27
0
    def test(self):
        """Evaluate the McDalNet model (task head plus two auxiliary heads).

        Runs the network over the test loader (the ``lam`` argument is fixed
        to 1 since it does not affect inference), computes either overall
        accuracy ('accu') or class-mean accuracy ('accu_mean') per
        ``self.opt.EVAL_METRIC``, appends results to ``SAVE_DIR/log.txt``,
        and returns the best score among the three heads.
        """
        self.net.eval()
        prec1_task = AverageMeter()
        prec1_auxi1 = AverageMeter()
        prec1_auxi2 = AverageMeter()
        # per-class sample counts and correct counts for each of the 3 heads
        counter_all = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_all_auxi1 = torch.FloatTensor(
            self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_all_auxi2 = torch.FloatTensor(
            self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_auxi1 = torch.FloatTensor(
            self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_auxi2 = torch.FloatTensor(
            self.opt.DATASET.NUM_CLASSES).fill_(0)

        for i, (input, target) in enumerate(self.test_data['loader']):
            input, target = to_cuda(input), to_cuda(target)
            with torch.no_grad():
                _, output_test, output_test1, output_test2, _, _, _ = self.net(
                    input,
                    1)  ## the value of lam do not affect the test process

            if self.opt.EVAL_METRIC == 'accu':
                prec1_task_iter = accuracy(output_test, target)
                prec1_auxi1_iter = accuracy(output_test1, target)
                prec1_auxi2_iter = accuracy(output_test2, target)
                prec1_task.update(prec1_task_iter, input.size(0))
                prec1_auxi1.update(prec1_auxi1_iter, input.size(0))
                prec1_auxi2.update(prec1_auxi2_iter, input.size(0))
                if i % self.opt.PRINT_STEP == 0:
                    print("  Test:epoch: %d:[%d/%d], Auxi1: %3f, Auxi2: %3f, Task: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg))
            elif self.opt.EVAL_METRIC == 'accu_mean':
                prec1_task_iter = accuracy(output_test, target)
                prec1_task.update(prec1_task_iter, input.size(0))
                counter_all, counter_acc = accuracy_for_each_class(
                    output_test, target, counter_all, counter_acc)
                counter_all_auxi1, counter_acc_auxi1 = accuracy_for_each_class(
                    output_test1, target, counter_all_auxi1, counter_acc_auxi1)
                counter_all_auxi2, counter_acc_auxi2 = accuracy_for_each_class(
                    output_test2, target, counter_all_auxi2, counter_acc_auxi2)
                if i % self.opt.PRINT_STEP == 0:
                    print("  Test:epoch: %d:[%d/%d], Task: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_task.avg))
            else:
                raise NotImplementedError
        # NOTE(review): divides by zero (-> NaN) for classes absent from the
        # test set, and whenever EVAL_METRIC == 'accu' (result unused there)
        acc_for_each_class = counter_acc / counter_all
        acc_for_each_class_auxi1 = counter_acc_auxi1 / counter_all_auxi1
        acc_for_each_class_auxi2 = counter_acc_auxi2 / counter_all_auxi2
        log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
        log.write("\n")
        if self.opt.EVAL_METRIC == 'accu':
            log.write(
                "                                    Test:epoch: %d, Top1_auxi1: %3f, Top1_auxi2: %3f, Top1: %3f" % \
                (self.epoch, prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg))
            log.close()
            return max(prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg)
        elif self.opt.EVAL_METRIC == 'accu_mean':
            log.write(
                "                                    Test:epoch: %d, Top1_auxi1: %3f, Top1_auxi2: %3f, Top1: %3f" % \
                (self.epoch, acc_for_each_class_auxi1.mean(), acc_for_each_class_auxi2.mean(), acc_for_each_class.mean()))
            log.write("\nClass-wise Acc:")  ## based on the task classifier.
            # NOTE(review): ordinal suffixes are only correct for 1-3
            # ("21th" etc. for later classes)
            for i in range(self.opt.DATASET.NUM_CLASSES):
                if i == 0:
                    log.write("%dst: %3f" % (i + 1, acc_for_each_class[i]))
                elif i == 1:
                    log.write(",  %dnd: %3f" % (i + 1, acc_for_each_class[i]))
                elif i == 2:
                    log.write(", %drd: %3f" % (i + 1, acc_for_each_class[i]))
                else:
                    log.write(", %dth: %3f" % (i + 1, acc_for_each_class[i]))
            log.close()
            return max(acc_for_each_class_auxi1.mean(),
                       acc_for_each_class_auxi2.mean(),
                       acc_for_each_class.mean())
예제 #28
0
    def update_network(self, **kwargs):
        """Run one epoch of adversarial training with separate optimizers.

        Each iteration alternates two updates on a shared forward pass:
        first the classifier (task losses + domain discrimination), then the
        feature extractor (domain-confusion losses, target term scaled by
        the annealed ``lam``). The classifier output concatenates a source
        head ("fs", first ``self.num_classes`` logits) and a target head
        ("ft", the rest). Stops after one pass over the target loader and
        appends epoch averages to ``SAVE_DIR/log.txt``.
        """
        stop = False
        self.train_data['source']['iterator'] = iter(self.train_data['source']['loader'])
        self.train_data['target']['iterator'] = iter(self.train_data['target']['loader'])
        # epoch length is defined by the target loader
        self.iters_per_epoch = len(self.train_data['target']['loader'])
        iters_counter_within_epoch = 0
        data_time = AverageMeter()
        batch_time = AverageMeter()
        classifier_loss = AverageMeter()
        feature_extractor_loss = AverageMeter()
        prec1_fs = AverageMeter()
        prec1_ft = AverageMeter()
        self.feature_extractor.train()
        self.classifier.train()
        end = time.time()
        if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
            # lam ramps 0 -> 1 over training (DANN-style annealing of the
            # target confusion term); updated once per epoch
            lam = 2 / (1 + math.exp(-1 * 10 * self.epoch / self.opt.TRAIN.MAX_EPOCH)) - 1
            self.update_lr()
            print('value of lam is: %3f' % (lam))
        while not stop:
            if self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
                # same annealing schedule, re-evaluated every iteration
                lam = 2 / (1 + math.exp(-1 * 10 * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))) - 1
                print('value of lam is: %3f' % (lam))
                self.update_lr()
            source_data, source_gt = self.get_samples('source')
            target_data, _ = self.get_samples('target')
            source_data = to_cuda(source_data)
            source_gt = to_cuda(source_gt)
            target_data = to_cuda(target_data)
            data_time.update(time.time() - end)

            feature_source = self.feature_extractor(source_data)
            output_source = self.classifier(feature_source)
            feature_target = self.feature_extractor(target_data)
            output_target = self.classifier(feature_target)

            # classifier objective: both heads classify source correctly and
            # the combined output discriminates source vs. target
            loss_task_fs = self.CELoss(output_source[:,:self.num_classes], source_gt)
            loss_task_ft = self.CELoss(output_source[:,self.num_classes:], source_gt)
            loss_discrim_source = self.CELoss(output_source, source_gt)
            loss_discrim_target = self.TargetDiscrimLoss(output_target)
            loss_summary_classifier = loss_task_fs + loss_task_ft + loss_discrim_source + loss_discrim_target

            # feature-extractor objective: confuse the domain discrimination
            # by also scoring source samples under the shifted (ft) labels
            source_gt_for_ft_in_fst = source_gt + self.num_classes
            loss_confusion_source = 0.5 * self.CELoss(output_source, source_gt) + 0.5 * self.CELoss(output_source, source_gt_for_ft_in_fst)
            loss_confusion_target = self.ConcatenatedCELoss(output_target)
            loss_summary_feature_extractor = loss_confusion_source + lam * loss_confusion_target

            # retain_graph=True: the same forward graph is reused by the
            # feature-extractor backward pass right below
            self.optimizer_classifier.zero_grad()
            loss_summary_classifier.backward(retain_graph=True)
            self.optimizer_classifier.step()

            self.optimizer_feature_extractor.zero_grad()
            loss_summary_feature_extractor.backward()
            self.optimizer_feature_extractor.step()

            classifier_loss.update(loss_summary_classifier, source_data.size()[0])
            feature_extractor_loss.update(loss_summary_feature_extractor, source_data.size()[0])
            prec1_fs.update(accuracy(output_source[:, :self.num_classes], source_gt), source_data.size()[0])
            prec1_ft.update(accuracy(output_source[:, self.num_classes:], source_gt), source_data.size()[0])

            print("  Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))

            batch_time.update(time.time() - end)
            end = time.time()
            self.iters += 1
            iters_counter_within_epoch += 1
            if iters_counter_within_epoch >= self.iters_per_epoch:
                # end of epoch: persist the averaged metrics and stop
                log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
                log.write("\n")
                log.write("  Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))
                log.close()
                stop = True
예제 #29
0
def test_and_save_mask(net, test_dataloader):
    """Generate and save spatio-temporal object proposals for test videos.

    For each clip, predicts per-frame foreground probability maps
    (optionally refined with dense CRF), smooths them, grows and filters
    connected regions into bounding boxes, links boxes across frames into
    track proposals, rescales them back to the original video resolution,
    and saves the surviving proposals per clip.
    """
    clip_len = cfg.DATASET.CLIP_LEN
    clip_stride = cfg.DATASET.CLIP_STRIDE

    for sample in iter(test_dataloader):
        if cfg.TEST.WITH_MASK:
            vid, start_f, clips, vmask = sample
        else:
            vid, start_f, clips = sample
        # skip clips that are too short (fewer than 8 frames)
        if clips.size(2) < 8: continue
        # forward and get the prediction result
        vpred = net(to_cuda(clips))
        # N x D x H x W
        probs = F.softmax(vpred, dim=1)
        # channel 1 is assumed to be the foreground class
        pos_probs = probs[:, 1, :, :, :]

        start_f = start_f.numpy()
        N = len(vid)
        assert(N == len(start_f))
        for i in range(N):
            cur_vid = vid[i]
            cur_video_path = os.path.join(cfg.DATASET.DATAROOT, 'videos', '%s.%s'%(cur_vid, cfg.DATASET.VIDEO_FORMAT))
            print('video: %s, start_f: %d' % (cur_video_path, start_f[i]))
            # TODO: for debugging
            #if start_f[i] < 5344:
            #    continue
            # opened only to read the original height/width further below
            cur_video = VideoReader(cur_video_path)
            #cur_video.seek(int(start_f[i]))

            frame_count = 0
            proposals = []
            clip_imgs = []
            #masks = []
            #for frame in cur_video.get_iter(clip_len * clip_stride):
            #    if frame_count % clip_stride:
            #        frame_count += 1
            #        continue

            #    # read the image
            #    img = frame.numpy()
            #    assert(len(img.shape) > 1)
            #    clip_imgs += [img]
            #    #img = img[:, :, [2, 1, 0]]
            #    #img = (img / 255.) * 2 - 1

            for fid in range(clip_len * clip_stride):
                # only process every clip_stride-th frame
                if frame_count % clip_stride:
                    frame_count += 1
                    continue

                # index of this frame within the subsampled clip
                count = frame_count // clip_stride
                if not cfg.TEST.WITH_DENSE_CRF:
                    cur_pos_probs = pos_probs[i, count, :, :].cpu().numpy()
                else:
                    # refine the raw probabilities with a dense CRF over the
                    # (de-normalized) input frame
                    cur_probs = probs[i, :, count, :, :].cpu().numpy()
                    # TODO: need normalize or not?
                    resized_img = clips[i, :, count, :, :].cpu().numpy()
                    resized_img = np.uint8(255 * (resized_img + 1.) / 2.0)
                    cur_pos_probs = 1.0 * dense_crf(cur_probs, resized_img)

                # smooth, segment into regions, and keep boxes of regions
                # passing the thresholds
                smoothing(cur_pos_probs)
                labels, num_regions = regional_growing(cur_pos_probs, pixel_val_thres=0.3)
                cur_pos_probs, bboxes = filtering(cur_pos_probs, labels, num_regions, 5)
                #masks.append(cur_pos_probs)
                if len(proposals) == 0:
                    # first processed frame: every box starts a new track
                    proposals = [[(count, bbox)] for bbox in bboxes]
                else:
                    # link this frame's boxes to existing tracks
                    associate_bboxes(count, bboxes, proposals)

                frame_count += 1

            #heatmaps = draw_heatmaps(clip_imgs, masks)
            #save_visualizations(heatmaps, 'heatmaps', cur_vid, start_f[i])

            #h, w, _ = clip_imgs[0].shape
            h, w = cur_video.height, cur_video.width
            # drop short tracks (fewer than 7 linked frames)
            new_proposals = [prop for prop in proposals if len(prop) >= 7]
            print('Number of proposals before and after filtering: %d, %d' % (len(proposals), len(new_proposals)))
            if len(new_proposals) == 0:
                continue
            #for prop in new_proposals:
            #    print(prop)

            # rescale boxes from network input resolution back to the
            # original video resolution
            stride_x = 1.0 * w / cfg.DATA_TRANSFORM.FINESIZE
            stride_y = 1.0 * h / cfg.DATA_TRANSFORM.FINESIZE
            new_proposals = resize_proposals(new_proposals, stride_x, stride_y, w, h)
            save_proposals(new_proposals, 'proposals', cur_vid, start_f[i])
예제 #30
0
    def update_network(self, filtered_classes):
        # initial configuration
        stop = False
        update_iters = 0

        self.train_data[self.source_name]['iterator'] = \
                     iter(self.train_data[self.source_name]['loader'])
        self.train_data['categorical']['iterator'] = \
                     iter(self.train_data['categorical']['loader'])

        while not stop:
            # update learning rate
            self.update_lr()

            # set the status of network
            self.net.train()
            self.dpn.train()
            self.net.zero_grad()
            self.dpn.zero_grad()

            loss = 0
            ce_loss_iter = 0
            cdd_loss_iter = 0

            # coventional sampling for training on labeled source data
            source_sample = self.get_samples(self.source_name)
            source_data, source_gt = source_sample['Img'],\
                          source_sample['Label']

            source_data = to_cuda(source_data)
            source_gt = to_cuda(source_gt)
            self.net.module.set_bn_domain(self.bn_domain_map[self.source_name])
            source_preds = self.net(source_data)['logits']

            # compute the cross-entropy loss
            # Supervised cross-entropy on the labeled source batch; backward
            # immediately so its gradients accumulate before the CDD pass below.
            ce_loss = self.CELoss(source_preds, source_gt)
            ce_loss.backward()

            # NOTE(review): these accumulate the loss *tensors*, not .item()
            # scalars — this keeps each iteration's autograd graph alive and
            # grows memory; confirm whether detached scalars were intended.
            ce_loss_iter += ce_loss
            loss += ce_loss

            # Alignment losses are only computed once clustering has produced
            # at least one usable (filtered) class.
            if len(filtered_classes) > 0:
                # update the network parameters
                # 1) class-aware sampling
                source_samples_cls, source_nums_cls, \
                       target_samples_cls, target_nums_cls, source_sample_labels = self.CAS()

                # Flatten the per-class label chunks into one tensor on GPU.
                source_sample_labels = torch.cat(source_sample_labels,
                                                 dim=0).cuda()
                # 2) forward and compute the loss
                # Concatenate the per-class sample batches so each domain is
                # forwarded through the network in a single pass.
                source_cls_concat = torch.cat(
                    [to_cuda(samples) for samples in source_samples_cls],
                    dim=0)
                target_cls_concat = torch.cat(
                    [to_cuda(samples) for samples in target_samples_cls],
                    dim=0)

                # Domain-specific batch-norm: switch BN statistics to the
                # matching domain before each forward pass.
                self.net.module.set_bn_domain(
                    self.bn_domain_map[self.source_name])
                feats_source = self.net(source_cls_concat)
                self.net.module.set_bn_domain(
                    self.bn_domain_map[self.target_name])
                feats_target = self.net(target_cls_concat)

                # prepare the features
                feats_toalign_S = self.prepare_feats(feats_source)
                feats_toalign_T = self.prepare_feats(feats_target)

                # Domain-probability network: reshape its output to
                # (batch, num_classes, num_domains-ish) so row i of class c is
                # that sample's domain logits under class c.
                domain_logits = self.dpn(source_cls_concat)
                domain_logits = domain_logits.reshape(
                    domain_logits.shape[0], self.opt.DATASET.NUM_CLASSES, -1)
                # Per-sample, per-class domain probabilities, filled in below
                # only at each sample's own class index (rows for other
                # classes stay zero).
                domain_prob_s = torch.zeros(domain_logits.shape,
                                            dtype=torch.float32).cuda()

                # Per-class entropy / KL terms, averaged over the classes that
                # actually have samples in this batch.
                kl_loss = 0
                entropy_loss = 0
                num_active_classes = 0
                for i in range(self.opt.DATASET.NUM_CLASSES):
                    # Boolean mask of the samples belonging to class i.
                    indexes = source_sample_labels == i
                    if indexes.sum() == 0:
                        continue
                    entropy_loss_cl, domain_prob_s_cl = self.entropy_loss(
                        domain_logits[indexes, i])
                    # Negated domain entropy: encourages a spread-out
                    # (high-entropy) marginal over domains for this class.
                    kl_loss += -self.get_domain_entropy(domain_prob_s_cl)
                    entropy_loss += entropy_loss_cl
                    domain_prob_s[indexes, i] = domain_prob_s_cl
                    num_active_classes += 1

                # NOTE(review): if no class had samples, num_active_classes is 0
                # and these divisions raise ZeroDivisionError — presumably CAS
                # guarantees at least one active class; confirm.
                entropy_loss = entropy_loss * self.clustering_wt / num_active_classes
                kl_loss = kl_loss * self.clustering_wt / num_active_classes

                # Contrastive/class-aware domain discrepancy between the
                # aligned source and target features.
                cdd_loss = self.cdd.forward(
                    feats_toalign_S, feats_toalign_T, source_nums_cls,
                    target_nums_cls, domain_prob_s,
                    source_sample_labels)[self.discrepancy_key]

                # Second backward: gradients add onto those from ce_loss above.
                total_loss = cdd_loss * self.opt.CDD.LOSS_WEIGHT + entropy_loss + kl_loss
                total_loss.backward()

                print("Entropy loss:", entropy_loss, "KL_loss:", kl_loss)

                # NOTE(review): the logging key below is 'cdd_loss' but this
                # accumulates total_loss (cdd + entropy + kl) — confirm whether
                # `cdd_loss_iter += cdd_loss` was intended.
                cdd_loss_iter += total_loss
                loss += total_loss

            # update the network
            # Single optimizer step applies the summed gradients of ce_loss
            # and (when computed) total_loss.
            self.optimizer.step()

            # Periodic logging: NUM_LOGGING_PER_LOOP times per loop.
            if self.opt.TRAIN.LOGGING and (update_iters+1) % \
                      (max(1, self.iters_per_loop // self.opt.TRAIN.NUM_LOGGING_PER_LOOP)) == 0:
                accu = self.model_eval(source_preds, source_gt)
                cur_loss = {
                    'ce_loss': ce_loss_iter,
                    'cdd_loss': cdd_loss_iter,
                    'total_loss': loss
                }
                self.logging(cur_loss, accu)

            # Clamp the intervals (fractions of a loop) to at most 1.0.
            # NOTE(review): this mutates self.opt every iteration — harmless
            # but could be hoisted out of the loop.
            self.opt.TRAIN.TEST_INTERVAL = min(1.0,
                                               self.opt.TRAIN.TEST_INTERVAL)
            self.opt.TRAIN.SAVE_CKPT_INTERVAL = min(
                1.0, self.opt.TRAIN.SAVE_CKPT_INTERVAL)

            # Periodic evaluation on the target domain (gradient-free, with
            # target BN statistics selected).
            if self.opt.TRAIN.TEST_INTERVAL > 0 and \
  (update_iters+1) % int(self.opt.TRAIN.TEST_INTERVAL * self.iters_per_loop) == 0:
                with torch.no_grad():
                    self.net.module.set_bn_domain(
                        self.bn_domain_map[self.target_name])
                    accu = self.test()
                    print('Test at (loop %d, iters: %d) with %s: %.4f.' %
                          (self.loop, self.iters, self.opt.EVAL_METRIC, accu))

            # Periodic checkpointing on the same fractional-interval scheme.
            if self.opt.TRAIN.SAVE_CKPT_INTERVAL > 0 and \
  (update_iters+1) % int(self.opt.TRAIN.SAVE_CKPT_INTERVAL * self.iters_per_loop) == 0:
                self.save_ckpt()

            # update_iters counts within this loop; self.iters is global.
            update_iters += 1
            self.iters += 1

            # update stop condition
            if update_iters >= self.iters_per_loop:
                stop = True
            else:
                stop = False