def train(train_loader, model):
    running_loss = 0.0
    data_size = len(train_data)

    model.train()
    # for inputs, masks, labels in progress_bar(train_loader, parent=mb):
    for inputs, masks, labels in train_loader:
        inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)
        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            if args.is_pseudo:
                logit, logit_pixel, logit_image = model(inputs)
                loss1 = lovasz_hinge(logit.squeeze(1), masks.squeeze(1))
                loss2 = nn.BCELoss()(logit_image, labels)
                loss3 = lovasz_hinge2(logit_pixel.squeeze(1), masks.squeeze(1))
                loss = loss1 + loss2 + loss3
            else:
                logit = model(inputs)
                loss = lovasz_hinge(logit.squeeze(1), masks.squeeze(1))

            loss.backward()
            optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        # mb.child.comment = 'loss: {}'.format(loss.item())
    epoch_loss = running_loss / data_size
    return epoch_loss
Example #2
File: train.py Project: chicm/ship
def criterion(args, output, target, epoch=0):
    mask_output, ship_output = output
    mask_target, ship_target = target

    #dice_loss = mixed_dice_bce_loss(mask_output, mask_target)
    focal_loss = focal_loss2d(mask_output, mask_target)
    #lovasz_loss = lovasz_hinge(mask_output, mask_target)

    lovasz_loss = (lovasz_hinge(mask_output, mask_target) +
                   lovasz_hinge(-mask_output, 1 - mask_target)) / 2

    bce_loss = F.binary_cross_entropy_with_logits(mask_output, mask_target)
    cls_loss = F.binary_cross_entropy_with_logits(ship_output, ship_target)

    if args.train_cls:
        #cls_loss = F.binary_cross_entropy_with_logits(ship_output, ship_target)
        return (lovasz_loss + bce_loss + cls_loss, focal_loss.item(),
                lovasz_loss.item(), bce_loss.item(), cls_loss.item())

    # four losses for: 1. grad, 2. display, 3. display, 4. measurement
    #if epoch < 10:
    #    return bce_loss, focal_loss.item(), lovasz_loss.item(), 0., lovasz_loss.item() + focal_loss.item()*focal_weight
    #else:
    #return lovasz_loss+focal_loss*focal_weight, focal_loss.item(), lovasz_loss.item(), 0., lovasz_loss.item() + focal_loss.item()*focal_weight
    return (lovasz_loss + bce_loss * 0.1, focal_loss.item(),
            lovasz_loss.item(), bce_loss.item(), cls_loss.item())
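
For reference, reading the two return branches directly off the code above (the first returned element drives the gradient; the remaining .item() values are detached scalars used only for logging):

\[
L_{\mathrm{lovasz}} = \tfrac{1}{2}\bigl(\ell(z, y) + \ell(-z, 1 - y)\bigr), \qquad
L_{\mathrm{total}} =
\begin{cases}
L_{\mathrm{lovasz}} + L_{\mathrm{bce}} + L_{\mathrm{cls}}, & \text{if args.train\_cls} \\
L_{\mathrm{lovasz}} + 0.1\, L_{\mathrm{bce}}, & \text{otherwise}
\end{cases}
\]
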
def test(test_loader, model):
    running_loss = 0.0
    predicts = []
    truths = []

    model.eval()
    for inputs, masks in test_loader:
        inputs, masks = inputs.to(device), masks.to(device)
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            outputs = outputs[:, :,
                              args.pad_left:args.pad_left + args.fine_size,
                              args.pad_left:args.pad_left + args.fine_size].contiguous()
            loss = lovasz_hinge(outputs.squeeze(1), masks.squeeze(1))

        predicts.append(torch.sigmoid(outputs).detach().cpu().numpy())
        truths.append(masks.detach().cpu().numpy())
        running_loss += loss.item() * inputs.size(0)

    predicts = np.concatenate(predicts).squeeze()
    truths = np.concatenate(truths).squeeze()
    precision, _, _ = do_kaggle_metric(predicts, truths, 0.52)
    precision = precision.mean()
    epoch_loss = running_loss / len(val_data)
    return epoch_loss, precision
Example #4
File: train.py Project: chicm/salt
def weighted_loss(output, target, epoch=0):
    mask_output, _ = output
    mask_target, _ = target

    lovasz_loss = lovasz_hinge(mask_output, mask_target)
    dice_loss = mixed_dice_bce_loss(mask_output, mask_target)
    #print(bce_loss, lovasz_loss)
    if epoch < 5:
        return dice_loss
    else:
        return lovasz_loss  #, lovasz_loss.item(), bce_loss.item()
Example #5
    def criterion(self, logit_clf, truth, logit_mask=None, mask=None):
        """Define the (customized) loss function here."""
        ## 1. classification loss
        Loss_FUNC = FocalLoss()
        #Loss_FUNC = nn.BCEWithLogitsLoss()#nn.MultiLabelSoftMarginLoss()
        loss_clf = Loss_FUNC(logit_clf, truth)
        if logit_mask is not None:
            ## 2. segmentation mask loss
            loss_mask = L.lovasz_hinge(logit_mask, mask, ignore=255)
            return loss_clf, loss_mask
        else:
            return loss_clf
    def forward(self, input, target):

        pred = input.view(-1)
        truth = target.view(-1)

        bce_loss = nn.BCEWithLogitsLoss()(pred, truth).double()

        # lovasz loss
        lovasz_loss = L.lovasz_hinge(input, target, per_image=False)

        loss = bce_loss + lovasz_loss.double()

        return loss, bce_loss, lovasz_loss.double()
Example #7
    def __call__(self, logits, labels):
        loss = ((1 - self.jaccard_weight) *
                lovasz_hinge(logits, labels, per_image=True, ignore=None) +
                self.focal_loss(logits, labels) * self.focal_weight)

        if self.jaccard_weight:
            eps = 1e-15
            jaccard_target = (labels == 1).float()
            jaccard_output = torch.sigmoid(logits)

            intersection = (jaccard_output * jaccard_target).sum()
            union = jaccard_output.sum() + jaccard_target.sum()

            loss -= self.jaccard_weight * torch.log(
                (intersection + eps) / (union - intersection + eps))
        return loss
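
In other words, with \(w_j\) = jaccard_weight, \(w_f\) = focal_weight and \(\sigma\) the sigmoid, the loss computed above is (a direct transcription of the code, with the sums running over every element in the batch):

\[
\mathcal{L} = (1 - w_j)\,\ell_{\mathrm{lovasz}}(z, y) + w_f\,\ell_{\mathrm{focal}}(z, y)
- w_j \log\frac{\sum_i \sigma(z_i)\,y_i + \varepsilon}{\sum_i \sigma(z_i) + \sum_i y_i - \sum_i \sigma(z_i)\,y_i + \varepsilon},
\qquad \varepsilon = 10^{-15},
\]

i.e. the soft-Jaccard term enters as the negative log of a soft IoU.
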
Example #8
def weighted_loss(args, output, target, epoch=0):
    mask_output, salt_output = output
    mask_target, salt_target = target

    lovasz_loss = lovasz_hinge(mask_output, mask_target)
    focal_loss = focal_loss2d(mask_output, mask_target)

    focal_weight = 0.2

    if salt_output is not None and args.train_cls:
        salt_loss = F.binary_cross_entropy_with_logits(salt_output,
                                                       salt_target)
        return (salt_loss, focal_loss.item(), lovasz_loss.item(),
                salt_loss.item(),
                lovasz_loss.item() + focal_loss.item() * focal_weight)

    return (lovasz_loss + focal_loss * focal_weight, focal_loss.item(),
            lovasz_loss.item(), 0.,
            lovasz_loss.item() + focal_loss.item() * focal_weight)
def train(train_loader, model):
    running_loss = 0.0

    model.train()
    # for inputs, masks, labels in progress_bar(train_loader, parent=mb):
    for inputs, masks, labels in train_loader:
        inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)
        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            logit = model(inputs)
            loss = lovasz_hinge(logit.squeeze(1), masks.squeeze(1))
            loss.backward()
            optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        # mb.child.comment = 'loss: {}'.format(loss.item())
    epoch_loss = running_loss / len(train_data)
    return epoch_loss
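
train() above relies on module-level globals (device, optimizer, train_data) plus a lovasz_hinge implementation imported elsewhere in the project. A minimal smoke-test wiring might look like the sketch below; the toy Conv2d model, the random dataset, and the BCE-with-logits stand-in for lovasz_hinge are illustrative assumptions, not the project's actual code.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def lovasz_hinge(logits, labels):
    # stand-in for the real Lovasz hinge, used here only so the loop runs
    return F.binary_cross_entropy_with_logits(logits, labels)

# toy dataset yielding (image, mask, image-level label) triples
train_data = TensorDataset(
    torch.randn(8, 1, 32, 32),
    torch.randint(0, 2, (8, 1, 32, 32)).float(),
    torch.randint(0, 2, (8, 1)).float())
train_loader = DataLoader(train_data, batch_size=4, shuffle=True)

model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1).to(device)  # toy "segmentation" model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

print('epoch loss:', train(train_loader, model))
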
Example #10
    def __call__(self, logits, labels):
        return (lovasz_hinge(logits, labels, self.per_image, self.ignore) +
                self.focal_loss(logits, labels) * self.focal_weight)
Example #11
def test(cfg):
    num_classes = int(cfg['n_classes'])
    sample_size = int(cfg['fixed_size'])
    cfg['loss'] = cfg['loss'].split(' ')
    batch_size = 1
    cfg['batch_size'] = batch_size
    epoch = eval(cfg['n_epochs'])
    #n_gf = int(cfg['num_gf'])
    input_size = int(cfg['data_dim'])

    trans_val = []
    if cfg['rnd_sampling']:
        trans_val.append(ds.TestSampling(sample_size))
    #if cfg['standardization']:
    #    trans_val.append(ds.SampleStandardization())

    if cfg['dataset'] == 'left_ifof_ss_sl':
        dataset = ds.LeftIFOFSupersetDataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_val),
            uniform_size=True,
            train=False,
            split_obj=True,
            with_gt=cfg['with_gt'])
    elif cfg['dataset'] == 'hcp20_graph':
        dataset = ds.HCP20Dataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_val),
            with_gt=cfg['with_gt'],
            #distance=T.Distance(norm=True,cat=False),
            return_edges=True,
            split_obj=True,
            train=False,
            load_one_full_subj=False,
            standardize=cfg['standardization'])
    elif cfg['dataset'] == 'left_ifof_ss_sl_graph':
        dataset = ds.LeftIFOFSupersetGraphDataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_val),
            train=False,
            split_obj=True,
            with_gt=cfg['with_gt'])
    elif cfg['dataset'] == 'left_ifof_emb':
        dataset = ds.EmbDataset(cfg['sub_list_test'],
                                cfg['emb_dataset_dir'],
                                cfg['gt_dataset_dir'],
                                transform=transforms.Compose(trans_val),
                                load_all=cfg['load_all_once'],
                                precompute_graph=cfg['precompute_graph'],
                                k_graph=int(cfg['knngraph']))
    elif cfg['dataset'] == 'psb_airplane':
        dataset = ds.PsbAirplaneDataset(cfg['dataset_dir'], train=False)
    elif cfg['dataset'] == 'shapes':
        dataset = ds.ShapesDataset(cfg['dataset_dir'],
                                   train=False,
                                   multi_cat=cfg['multi_category'])
    elif cfg['dataset'] == 'shapenet':
        dataset = ds.ShapeNetCore(cfg['dataset_dir'],
                                  train=False,
                                  multi_cat=cfg['multi_category'])
    elif cfg['dataset'] == 'modelnet':
        dataset = ds.ModelNetDataset(cfg['dataset_dir'],
                                     split=cfg['mn40_split'],
                                     fold_size=int(cfg['mn40_fold_size']),
                                     load_all=cfg['load_all_once'])
    elif cfg['dataset'] == 'scanobj':
        dataset = ds.ScanObjNNDataset(cfg['dataset_dir'],
                                      run='test',
                                      variant=cfg['scanobj_variant'],
                                      background=cfg['scanobj_bg'],
                                      load_all=cfg['load_all_once'])
    else:
        dataset = ds.DRLeftIFOFSupersetDataset(
            cfg['sub_list_test'],
            cfg['val_dataset_dir'],
            transform=transforms.Compose(trans_val),
            with_gt=cfg['with_gt'])

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0)
    print("Validation dataset loaded, found %d samples" % (len(dataset)))

    for ext in range(100):
        logdir = '%s/test_%d' % (cfg['exp_path'], ext)
        if not os.path.exists(logdir):
            break
    writer = SummaryWriter(logdir)
    if cfg['weights_path'] == '':
        cfg['weights_path'] = glob.glob(cfg['exp_path'] + '/models/best*')[0]
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])
    elif 'ep-' in cfg['weights_path']:
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])

    tb_log_name = glob.glob('%s/events*' % writer.logdir)[0].rsplit('/', 1)[1]
    tb_log_dir = 'tb_logs/%s' % logdir.split('/', 1)[1]
    os.system('mkdir -p %s' % tb_log_dir)
    os.system('ln -sr %s/%s %s/%s ' %
              (writer.logdir, tb_log_name, tb_log_dir, tb_log_name))

    #### BUILD THE MODEL
    classifier = get_model(cfg)

    classifier.cuda()
    classifier.load_state_dict(torch.load(cfg['weights_path']))
    classifier.eval()

    with torch.no_grad():
        pred_buffer = {}
        sm_buffer = {}
        sm2_buffer = {}
        gf_buffer = {}
        emb_buffer = {}
        print('\n\n')
        mean_val_acc = torch.tensor([])
        mean_val_iou = torch.tensor([])
        mean_val_prec = torch.tensor([])
        mean_val_recall = torch.tensor([])

        if 'split_obj' in dir(dataset) and dataset.split_obj:
            split_obj = True
        else:
            split_obj = False
            dataset.transform = []

        if split_obj:
            consumed = False
        else:
            consumed = True
        j = 0
        visualized = 0
        new_obj_read = True
        sls_count = 1
        while j < len(dataset):
            #while sls_count <= len(dataset):
            data = dataset[j]

            if split_obj:
                if new_obj_read:
                    obj_pred_choice = torch.zeros(data['obj_full_size'],
                                                  dtype=torch.int).cuda()
                    obj_target = torch.zeros(data['obj_full_size'],
                                             dtype=torch.int).cuda()
                    new_obj_read = False
                    #if cfg['save_embedding']:
                    #obj_embedding = torch.empty((data['obj_full_size'], int(cfg['embedding_size']))).cuda()

                if len(dataset.remaining[j]) == 0:
                    consumed = True

            sample_name = data['name'] if isinstance(data['name'], str) else data['name'][0]

            #print(points)
            #if len(points.shape()) == 2:
            #points = points.unsqueeze(0)
            if 'graph' not in cfg['dataset']:
                points = data['points']
                if cfg['with_gt']:
                    target = data['gt']
                    target = target.to('cuda')
                    target = target.view(-1, 1)[:, 0]
            else:
                #print(data)
                points = gBatch().from_data_list([data['points']])
                #points = data['points']
                if 'bvec' in points.keys:
                    points.batch = points.bvec.clone()
                    del points.bvec
                # points.ori_batch = torch.zeros(points.x.size(0)).long()
                if cfg['with_gt']:
                    target = points['y']
                    target = target.to('cuda')
                    target = target.view(-1, 1)[:, 0]
                if cfg['same_size']:
                    points['lengths'] = points['lengths'][0].item()
            #if cfg['model'] == 'pointnet_cls':
            #points = points.view(len(data['obj_idxs']), -1, input_size)
            points = points.to('cuda')
            #print('streamline number:',sls_count)
            #sls_count+=1
            #print('lengths:',points['lengths'].item())
            ### add one-hot labels if multi-category task
            #new_k = points['lengths'].item()*(5/16)
            #print('new k:',new_k,'rounded k:',int(round(new_k)))
            #classifier.conv2.k = int(round(new_k))

            #if cfg['multi_loss']:
            #logits, gf = classifier(points)
            #else:
            logits = classifier(points)
            logits = logits.view(-1, num_classes)

            if len(cfg['loss']) == 2:
                if epoch <= int(cfg['switch_loss_epoch']):
                    loss_type = cfg['loss'][0]
                else:
                    loss_type = cfg['loss'][1]
            else:
                loss_type = cfg['loss'][0]

            if loss_type == 'nll':
                pred = F.log_softmax(logits, dim=-1)
                pred = pred.view(-1, num_classes)
                probas = torch.exp(pred.data)
                pred_choice = pred.data.max(1)[1].int()
                if cfg['with_gt']:
                    loss_seg = F.nll_loss(pred, target.long())
            elif loss_type == 'LLh':
                pred_choice = (logits.data > 0).int()
                if cfg['with_gt']:
                    loss_seg = L.lovasz_hinge(
                        logits.view(batch_size, sample_size, 1),
                        target.view(batch_size, sample_size, 1),
                        per_image=False)
                #loss = L.lovasz_hinge_flat(pred.view(-1), target.view(-1))
            elif loss_type == 'LLm':
                pred = F.softmax(logits, dim=-1)
                probas = pred.data
                pred_choice = pred.data.max(1)[1].int()
                if cfg['with_gt']:
                    loss = L.lovasz_softmax_flat(
                        pred,
                        target,
                        op=cfg['llm_op'],
                        only_present=cfg['multi_category'])
            #print('pred:',pred)
            #print('pred shape:',pred.shape)
            #print('pred choice:',pred_choice)
            #print('pred choice shape:',pred_choice.shape)
            #if visualized < int(cfg['viz_clusters']):
            #    visualized += 1
            #    colors = torch.from_numpy(get_spaced_colors(n_gf))
            #    sm_out = classifier.feat.mf.softmax_out[0,:,:].max(1)[1].squeeze().int()
            #    writer.add_mesh('latent clustering', points, colors[sm_out.tolist()].unsqueeze(0))
            #    if 'bg' in data.keys():
            #        bg_msk = data['bg']*-1
            #        writer.add_mesh('bg_mask', points, colors[bg_msk.tolist()].unsqueeze(0))

            if split_obj:
                obj_pred_choice[data['obj_idxs']] = pred_choice
                obj_target[data['obj_idxs']] = target.int()
                #if cfg['save_embedding']:
                #    obj_embedding[data['obj_idxs']] = classifier.embedding.squeeze()
            else:
                obj_data = points
                obj_pred_choice = pred_choice
                obj_target = target
                if cfg['save_embedding']:
                    obj_embedding = classifier.embedding.squeeze()

            if cfg['with_gt'] and consumed:
                #if cfg['multi_loss']:
                #    loss_cluster = cluster_loss_fn(gf.squeeze(3))
                #    loss = loss_seg + alfa * loss_cluster

                #pred_choice = torch.sigmoid(pred.view(-1,1)).data.round().type_as(target.data)
                #print('points:',points['streamlines'])
                #print('points shape:',points['streamlines'].shape)
                #print('streamlines:',
                data_dir = cfg['dataset_dir']
                #streamlines, head, leng, idxs = load_streamlines(data['dir']+'/'+data['name']+'.trk')
                #print('tract:',len(streamlines))
                #print('pred:',obj_pred_choice)
                #print('taget:',obj_target)
                #print('pred shape:',obj_pred_choice.shape)
                #print('target shape:',obj_target.shape)
                print('val max class pred ', obj_pred_choice.max().item())
                print('val min class pred ', obj_pred_choice.min().item())
                y_pred = obj_pred_choice.cpu().numpy()
                np.save(data['dir'] + '/y_pred_sDEC_k5_16pts_nodropout',
                        y_pred)
                y_test = obj_target.cpu().numpy()
                np.save(data['dir'] + '/y_test_sDEC_k5_16pts_nodropout',
                        y_test)
                #np.save(data['dir']+'/streamlines_lstm_GIN',streamlines)
                correct = obj_pred_choice.eq(obj_target.data.int()).cpu().sum()
                acc = correct.item() / float(obj_target.size(0))

                if num_classes > 2:
                    iou, prec, recall = L.iou_multi(
                        obj_pred_choice.data.int().cpu().numpy(),
                        obj_target.data.int().cpu().numpy(),
                        num_classes,
                        multi_cat=cfg['multi_category'])
                    assert (np.isnan(iou).sum() == 0)
                    if cfg['multi_category']:
                        s, n_parts = data['gt_offset']
                        e = s + n_parts
                        iou[0, :s], prec[0, :s], recall[0, :s] = 0., 0., 0.
                        iou[0, e:], prec[0, e:], recall[0, e:] = 0., 0., 0.
                        iou = torch.from_numpy(iou).float()
                        prec = torch.from_numpy(prec).float()
                        recall = torch.from_numpy(recall).float()
                        category = data['category'].squeeze().nonzero().float()
                        iou = torch.cat([iou, category], 1)
                    else:
                        iou = torch.tensor([iou.mean()])
                        prec = torch.tensor([prec.mean()])
                        recall = torch.tensor([recall.mean()])
                    assert (torch.isnan(iou).sum() == 0)

                else:
                    tp = torch.mul(
                        obj_pred_choice.data,
                        obj_target.data.int()).cpu().sum().item() + 0.00001
                    fp = obj_pred_choice.gt(
                        obj_target.data.int()).cpu().sum().item()
                    fn = obj_pred_choice.lt(
                        obj_target.data.int()).cpu().sum().item()
                    tn = correct.item() - tp
                    iou = torch.tensor([float(tp) / (tp + fp + fn)])
                    prec = torch.tensor([float(tp) / (tp + fp)])
                    recall = torch.tensor([float(tp) / (tp + fn)])

                mean_val_prec = torch.cat((mean_val_prec, prec), 0)
                mean_val_recall = torch.cat((mean_val_recall, recall), 0)
                mean_val_iou = torch.cat((mean_val_iou, iou), 0)
                mean_val_acc = torch.cat((mean_val_acc, torch.tensor([acc])),
                                         0)
                print('VALIDATION [%d: %d/%d] val accuracy: %f' \
                        % (epoch, j, len(dataset), acc))

            if cfg['save_pred'] and consumed:
                print('buffering prediction %s' % sample_name)
                if num_classes > 2:
                    for cl in range(num_classes):
                        sl_idx = np.where(
                            obj_pred_choice.data.cpu().view(-1).numpy() == cl)[0]
                        if cl == 0:
                            pred_buffer[sample_name] = []
                        pred_buffer[sample_name].append(sl_idx.tolist())
                else:
                    sl_idx = np.where(
                        obj_pred_choice.data.cpu().view(-1).numpy() == 1)[0]
                    pred_buffer[sample_name] = sl_idx.tolist()
            #if cfg['save_softmax_out']:
            #    if cfg['model'] in 'pointnet_mgfml':
            #        if sample_name not in sm_buffer.keys():
            #            sm_buffer[sample_name] = []
            #        if classifier.feat.multi_feat > 1:
            #            sm_buffer[sample_name].append(
            #                classifier.feat.mf.softmax_out.cpu().numpy())
            #    if cfg['model'] == 'pointnet_mgfml':
            #        for l in classifier.layers:
            #            sm_buffer[sample_name].append(
            #                    l.mf.softmax_out.cpu().numpy())
            #    sm2_buffer[sample_name] = probas.cpu().numpy()
            #if cfg['save_gf']:
            #   gf_buffer[sample_name] = np.unique(
            #           classifier.feat.globalfeat.data.cpu().squeeze().numpy(), axis = 0)
            #gf_buffer[sample_name] = classifier.globalfeat
            #if cfg['save_embedding'] and consumed:
            #    emb_buffer[sample_name] = obj_embedding

            if consumed:
                print(j)
                j += 1
                if split_obj:
                    consumed = False
                    new_obj_read = True

        macro_iou = torch.mean(mean_val_iou)
        macro_prec = torch.mean(mean_val_prec)
        macro_recall = torch.mean(mean_val_recall)

        epoch_iou = macro_iou.item()

    if cfg['save_pred']:
        #os.system('rm -r %s/predictions_test*' % writer.logdir)
        pred_dir = writer.logdir + '/predictions_test_%d' % epoch
        if not os.path.exists(pred_dir):
            os.makedirs(pred_dir)
        print('saving files')
        for filename, value in pred_buffer.items():
            with open(os.path.join(pred_dir, filename) + '.pkl', 'wb') as f:
                pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)

    #if cfg['save_softmax_out']:
    #    os.system('rm -r %s/sm_out_test*' % writer.logdir)
    #    sm_dir = writer.logdir + '/sm_out_test_%d' % epoch
    #    if not os.path.exists(sm_dir):
    #        os.makedirs(sm_dir)
    #    for filename, value in sm_buffer.iteritems():
    #        with open(os.path.join(sm_dir, filename) + '_sm_1.pkl', 'wb') as f:
    #            pickle.dump(
    #                value, f, protocol=pickle.HIGHEST_PROTOCOL)
    #    for filename, value in sm2_buffer.iteritems():
    #        with open(os.path.join(sm_dir, filename) + '_sm_2.pkl', 'wb') as f:
    #            pickle.dump(
    #                value, f, protocol=pickle.HIGHEST_PROTOCOL)

    #if cfg['save_gf']:
    #os.system('rm -r %s/gf_test*' % writer.logdir)
    #    gf_dir = writer.logdir + '/gf_test_%d' % epoch
    #    if not os.path.exists(gf_dir):
    #        os.makedirs(gf_dir)
    #    i = 0
    #    for filename, value in gf_buffer.items():
    #        if i == 3:
    #            break
    #        with open(os.path.join(gf_dir, filename) + '.pkl', 'wb') as f:
    #            pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)

    #if cfg['save_embedding']:
    #    print('saving embedding')
    #    emb_dir = writer.logdir + '/embedding_test_%d' % epoch
    #    if not os.path.exists(emb_dir):
    #        os.makedirs(emb_dir)
    #    for filename, value in emb_buffer.iteritems():
    #        np.save(os.path.join(emb_dir, filename), value.cpu().numpy())

    if cfg['with_gt']:
        print('TEST ACCURACY: %f' % torch.mean(mean_val_acc).item())
        print('TEST PRECISION: %f' % macro_prec.item())
        print('TEST RECALL: %f' % macro_recall.item())
        print('TEST IOU: %f' % macro_iou.item())
        mean_val_dsc = mean_val_prec * mean_val_recall * 2 / (mean_val_prec +
                                                              mean_val_recall)
        final_scores_file = writer.logdir + '/final_scores_test_%d.txt' % epoch
        scores_file = writer.logdir + '/scores_test_%d.txt' % epoch
        print('saving scores')
        with open(scores_file, 'w') as f:
            f.write('acc\n')
            f.writelines('%f\n' % v for v in mean_val_acc.tolist())
            f.write('prec\n')
            f.writelines('%f\n' % v for v in mean_val_prec.tolist())
            f.write('recall\n')
            f.writelines('%f\n' % v for v in mean_val_recall.tolist())
            f.write('dsc\n')
            f.writelines('%f\n' % v for v in mean_val_dsc.tolist())
            f.write('iou\n')
            f.writelines('%f\n' % v for v in mean_val_iou.tolist())
        with open(final_scores_file, 'w') as f:
            f.write('acc\n')
            f.write('%f\n' % mean_val_acc.mean())
            f.write('%f\n' % mean_val_acc.std())
            f.write('prec\n')
            f.write('%f\n' % mean_val_prec.mean())
            f.write('%f\n' % mean_val_prec.std())
            f.write('recall\n')
            f.write('%f\n' % mean_val_recall.mean())
            f.write('%f\n' % mean_val_recall.std())
            f.write('dsc\n')
            f.write('%f\n' % mean_val_dsc.mean())
            f.write('%f\n' % mean_val_dsc.std())
            f.write('iou\n')
            f.write('%f\n' % mean_val_iou.mean())
            f.write('%f\n' % mean_val_iou.std())

    print('\n\n')
Example #12
    def criterion(self, logit, truth):
        logit = logit.squeeze(1)
        truth = truth.squeeze(1)
        loss = L.lovasz_hinge(logit, truth, per_image=True, ignore=None)
        return loss
        train_loss = []
        train_iou = []

        model.train()
        with tqdm(train_loader) as pbar:
            for images, masks in pbar: 
                masks = masks.cuda()
                y_pred = model(images.cuda())

                prob = torch.sigmoid(y_pred).cpu().data.numpy()
                truth = masks.cpu().data.numpy()

                iou = do_kaggle_metric(prob, truth, threshold=0.5)
                train_iou.append(iou)

                loss = L.lovasz_hinge(y_pred.squeeze(), masks.squeeze().cuda(), per_image=True, ignore=None)
                train_loss.append(loss.item())

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                pbar.set_description("Loss: %.3f, IoU: %.3f, Progress" % (loss, iou))
            
        val_loss = []
        val_iou = []
        model.eval()
        with tqdm(val_loader) as pbar:
            for images, masks in pbar:
                if len(images) == 2:
                    image_ori, image_rev = images
Example #14
def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
    if not ModelPhase.is_valid_phase(phase):
        raise ValueError("ModelPhase {} is not valid!".format(phase))
    if ModelPhase.is_train(phase):
        width = cfg.TRAIN_CROP_SIZE[0]
        height = cfg.TRAIN_CROP_SIZE[1]
    else:
        width = cfg.EVAL_CROP_SIZE[0]
        height = cfg.EVAL_CROP_SIZE[1]

    image_shape = [-1, cfg.DATASET.DATA_DIM, height, width]
    grt_shape = [-1, 1, height, width]
    class_num = cfg.DATASET.NUM_CLASSES

    with fluid.program_guard(main_prog, start_prog):
        with fluid.unique_name.guard():
            # When exporting the model, add image normalization as a preprocessing step
            # so that at deployment only a batch_size dimension needs to be added to the input image
            if ModelPhase.is_predict(phase):
                if cfg.SLIM.PREPROCESS:
                    image = fluid.data(
                        name='image', shape=image_shape, dtype='float32')
                else:
                    origin_image = fluid.data(
                        name='image',
                        shape=[-1, -1, -1, cfg.DATASET.DATA_DIM],
                        dtype='float32')
                    image, valid_shape, origin_shape = export_preprocess(
                        origin_image)

            else:
                image = fluid.data(
                    name='image', shape=image_shape, dtype='float32')
            label = fluid.data(name='label', shape=grt_shape, dtype='int32')
            mask = fluid.data(name='mask', shape=grt_shape, dtype='int32')

            # use DataLoader when doing training and evaluation
            if ModelPhase.is_train(phase) or ModelPhase.is_eval(phase):
                data_loader = fluid.io.DataLoader.from_generator(
                    feed_list=[image, label, mask],
                    capacity=cfg.DATALOADER.BUF_SIZE,
                    iterable=False,
                    use_double_buffer=True)

            loss_type = cfg.SOLVER.LOSS
            if not isinstance(loss_type, list):
                loss_type = list(loss_type)

            # lovasz_hinge_loss, dice_loss and bce_loss only apply to binary segmentation
            if class_num > 2 and (("lovasz_hinge_loss" in loss_type) or
                                  ("dice_loss" in loss_type) or
                                  ("bce_loss" in loss_type)):
                raise Exception(
                    "lovasz hinge loss, dice loss and bce loss are only applicable to binary classification."
                )

            # For binary segmentation, when the loss is lovasz_hinge_loss, dice_loss or bce_loss, the final logit output is set to a single channel
            if ("dice_loss" in loss_type) or ("bce_loss" in loss_type) or (
                    "lovasz_hinge_loss" in loss_type):
                class_num = 1
                if ("softmax_loss" in loss_type) or (
                        "lovasz_softmax_loss" in loss_type):
                    raise Exception(
                        "softmax loss or lovasz softmax loss can not combine with bce loss or dice loss or lovasz hinge loss."
                    )
            logits = seg_model(image, class_num)

            # compute the loss according to the selected loss function(s)
            if ModelPhase.is_train(phase) or ModelPhase.is_eval(phase):
                loss_valid = False
                avg_loss_list = []
                valid_loss = []
                if "softmax_loss" in loss_type:
                    weight = cfg.SOLVER.CROSS_ENTROPY_WEIGHT
                    avg_loss_list.append(
                        multi_softmax_with_loss(logits, label, mask, class_num,
                                                weight))
                    loss_valid = True
                    valid_loss.append("softmax_loss")
                if "dice_loss" in loss_type:
                    avg_loss_list.append(multi_dice_loss(logits, label, mask))
                    loss_valid = True
                    valid_loss.append("dice_loss")
                if "bce_loss" in loss_type:
                    avg_loss_list.append(multi_bce_loss(logits, label, mask))
                    loss_valid = True
                    valid_loss.append("bce_loss")
                if "lovasz_hinge_loss" in loss_type:
                    avg_loss_list.append(
                        lovasz_hinge(logits, label, ignore=mask))
                    loss_valid = True
                    valid_loss.append("lovasz_hinge_loss")
                if "lovasz_softmax_loss" in loss_type:
                    probas = fluid.layers.softmax(logits, axis=1)
                    avg_loss_list.append(
                        lovasz_softmax(probas, label, ignore=mask))
                    loss_valid = True
                    valid_loss.append("lovasz_softmax_loss")
                if not loss_valid:
                    raise Exception(
                        "SOLVER.LOSS: {} is set incorrectly. It should include "
                        "at least one of (softmax_loss, bce_loss, dice_loss, lovasz_hinge_loss, lovasz_softmax_loss),"
                        " e.g. ['softmax_loss'], ['dice_loss'], ['bce_loss', 'dice_loss'], ['lovasz_hinge_loss','bce_loss'], ['lovasz_softmax_loss','softmax_loss']"
                        .format(cfg.SOLVER.LOSS))

                invalid_loss = [x for x in loss_type if x not in valid_loss]
                if len(invalid_loss) > 0:
                    print(
                        "Warning: the loss {} you set is invalid. It will not be included in the computed loss."
                        .format(invalid_loss))

                avg_loss = 0
                for i in range(0, len(avg_loss_list)):
                    loss_name = valid_loss[i].upper()
                    loss_weight = eval('cfg.SOLVER.LOSS_WEIGHT.' + loss_name)
                    avg_loss += loss_weight * avg_loss_list[i]

            #get pred result in original size
            if isinstance(logits, tuple):
                logit = logits[0]
            else:
                logit = logits

            if logit.shape[2:] != label.shape[2:]:
                logit = fluid.layers.resize_bilinear(logit, label.shape[2:])

            # return image input and logit output for inference graph prune
            if ModelPhase.is_predict(phase):
                # In binary segmentation, lovasz_hinge_loss, dice_loss and bce_loss produce a single-channel logit; convert it to two channels
                if class_num == 1:
                    logit = sigmoid_to_softmax(logit)
                else:
                    logit = softmax(logit)

                # crop out the valid region
                if cfg.SLIM.PREPROCESS:
                    return image, logit

                else:
                    logit = fluid.layers.slice(
                        logit, axes=[2, 3], starts=[0, 0], ends=valid_shape)

                    logit = fluid.layers.resize_bilinear(
                        logit,
                        out_shape=origin_shape,
                        align_corners=False,
                        align_mode=0)
                    logit = fluid.layers.argmax(logit, axis=1)
                return origin_image, logit

            if class_num == 1:
                out = sigmoid_to_softmax(logit)
                out = fluid.layers.transpose(out, [0, 2, 3, 1])
            else:
                out = fluid.layers.transpose(logit, [0, 2, 3, 1])

            pred = fluid.layers.argmax(out, axis=3)
            pred = fluid.layers.unsqueeze(pred, axes=[3])
            if ModelPhase.is_visual(phase):
                if class_num == 1:
                    logit = sigmoid_to_softmax(logit)
                else:
                    logit = softmax(logit)
                return pred, logit

            if ModelPhase.is_eval(phase):
                return data_loader, avg_loss, pred, label, mask

            if ModelPhase.is_train(phase):
                optimizer = solver.Solver(main_prog, start_prog)
                decayed_lr = optimizer.optimise(avg_loss)
                return data_loader, avg_loss, decayed_lr, pred, label, mask
    def forward(self, input, target):
        loss = L.lovasz_hinge(input, target)
        return loss
Example #16
        # train
        net.train()
        for batch_image, batch_mask in tqdm(dataloader['train']):
            optimizer.zero_grad()
            # pdb.set_trace()

            batch_image = batch_image.cuda()
            batch_mask = batch_mask.cuda()
            with torch.set_grad_enabled(True):
                outputs = net(batch_image).squeeze(dim=1)

                if epoch < 120:
                    loss = criterion(outputs, batch_mask)
                else:
                    loss = lovasz_hinge(outputs, batch_mask)

                loss.backward()
                optimizer.step()
            train_running_corrects += torch.sum(
                (outputs > 0.5) == (batch_mask > 0.5)).item()
            train_running_loss += loss.item() * batch_image.size(0)
            train_running_dice_loss += dice_loss(
                outputs, batch_mask).item() * batch_image.size(0)
            train_iou += iou(outputs, batch_mask) * batch_image.size(0)

        # val
        net.eval()
        for batch_image, batch_mask in tqdm(dataloader['val']):
            batch_image = batch_image.cuda()
            batch_mask = batch_mask.cuda()
Example #17
    def forward(self, input, target):
        input = input.squeeze(1)
        target = target.squeeze(1)
        loss = lovasz_hinge(input, target, per_image=True)

        return loss
Example #18
def symmetric_lovasz(outputs, targets):
    return (lovasz_hinge(outputs, targets) +
            lovasz_hinge(-outputs, 1 - targets)) / 2
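
Spelled out, the symmetric variant above averages the Lovász hinge on the logits with the hinge applied to the negated logits against the inverted targets,

\[
\mathcal{L}_{\mathrm{sym}}(z, y) = \tfrac{1}{2}\bigl(\ell_{\mathrm{lovasz}}(z, y) + \ell_{\mathrm{lovasz}}(-z, 1 - y)\bigr),
\]

so the surrogate is optimised for the background class as well as the foreground; the chicm/ship criterion earlier in this listing builds its lovasz_loss the same way.
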
    def forward(self, input, target):
        #         loss = self.alpha*self.focal(input, target) - torch.log(dice_loss(input, target))
        loss = L.lovasz_hinge(input, target)
        return loss.mean()
def unet_train():

    batch_size = 1
    num_epochs = [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
    num_workers = 2
    lr = 0.0001

    losslist = ['dice']  # ['focal', 'bce', 'dice', 'lovasz']
    optimlist = ['adam']  # ['adam', 'sgd']
    iflog = True

    SC_root_dir = '../dataset-EdmSealedCrack-512'
    train_files, val_files, test_files = myutils.organize_SC_files(SC_root_dir)

    train_RC_dataset = DatasetRealCrack('../dataset-EdmCrack600-512/A/train',
                                        transform=transform)
    train_SC_dataset = DatasetSealedCrack(files=train_files,
                                          root_dir=SC_root_dir,
                                          transform=data_Train_transforms)
    val_RC_dataset = DatasetRealCrack('../dataset-EdmCrack600-512/A/val',
                                      transform=transform)
    val_SC_dataset = DatasetSealedCrack(files=val_files,
                                        root_dir=SC_root_dir,
                                        transform=data_Test_transforms)

    train_loader = torch.utils.data.DataLoader(ConcatDataset(
        train_RC_dataset, train_SC_dataset),
                                               batch_size=2,
                                               shuffle=True,
                                               num_workers=2)

    criterion = nn.BCELoss()
    focallos = FocalLoss(gamma=2)
    doubleFocalloss = focalloss.FocalLoss_2_datasets(gamma=2)

    epoidx = -1
    for los in losslist:
        for opt in optimlist:
            start = time.time()
            print(los, opt)
            torch.manual_seed(77)
            torch.cuda.manual_seed(77)
            #################
            #unet = Unet_SpatialPyramidPooling(3).cuda()
            #################
            unet = Unet(3).cuda()
            SC_classifier = classifier(64, 2).cuda()
            RC_classifier = classifier(64, 2).cuda()

            ##################
            #unet = smp.Unet('resnet34', encoder_weights='imagenet').cuda()
            #unet.segmentation_head = torch.nn.Sequential().cuda()
            #SC_classifier = classifier(16, 2).cuda()
            #RC_classifier = classifier(16, 2).cuda()

            #UNCOMMENT TO KEEP TRAINING THE BEST MODEL
            prev_epoch = 0  # if loading model 58, change to prev_epoch = 58. When saving the model, it is going to be named as 59, 60, 61...
            #unet.load_state_dict(torch.load('trained_models/unet_adam_dice_58.pkl'))
            #SC_classifier.load_state_dict(torch.load('trained_models/SC_classifier_adam_dice_58.pkl'))
            #RC_classifier.load_state_dict(torch.load('trained_models/RC_classifier_adam_dice_58.pkl'))

            history = []
            if 'adam' in opt:
                optimizer = torch.optim.Adam(unet.parameters(), lr=lr)
            elif 'sgd' in opt:
                optimizer = torch.optim.SGD(unet.parameters(),
                                            lr=10 * lr,
                                            momentum=0.9)

            logging.basicConfig(filename='./logs/logger_unet.log',
                                level=logging.INFO)

            total_step = len(train_loader)
            epoidx += 1
            for epoch in range(num_epochs[epoidx]):
                totalloss = 0
                for i, (realCrack_batch,
                        sealedCrack_batch) in enumerate(train_loader):
                    SC_images = sealedCrack_batch[0].cuda()
                    SC_masks = sealedCrack_batch[1].cuda()
                    RC_images = realCrack_batch[0].cuda()
                    RC_masks = realCrack_batch[1].cuda()
                    SC_encoder = unet(SC_images)
                    RC_encoder = unet(RC_images)
                    #############
                    SC_outputs = SC_classifier(SC_encoder)
                    RC_outputs = RC_classifier(RC_encoder)
                    #############
                    #Deep lab v3
                    #SC_outputs = SC_classifier(SC_encoder['out'])
                    #RC_outputs = RC_classifier(RC_encoder['out'])
                    ##############
                    if 'bce' in los:
                        masks = onehot(masks)
                        loss = criterion(outputs, masks)
                    elif 'dice' in los:
                        branch_RC = {'outputs': RC_outputs, 'masks': RC_masks}
                        branch_SC = {'outputs': SC_outputs, 'masks': SC_masks}
                        loss = dice_loss_2_datasets(branch_RC, branch_SC)
                        #masks = onehot(masks)
                        #loss = dice_loss(outputs, masks)
                    elif 'lovasz' in los:
                        masks = onehot(masks)
                        loss = L.lovasz_hinge(outputs, masks)
                    elif 'focal' in los:
                        #loss = focallos(outputs, masks.long())
                        branch_RC = {
                            'outputs': RC_outputs,
                            'masks': RC_masks.long()
                        }
                        branch_SC = {
                            'outputs': SC_outputs,
                            'masks': SC_masks.long()
                        }
                        loss = doubleFocalloss(branch_RC, branch_SC)
                    totalloss += loss.item() * RC_images.size(0)  #*2?
                    #print(RC_images.size(0))

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    if i % 10 == 0:
                        print(epoch, i)
                        print("total loss: ", totalloss)
                    if i % 1000 == 0:
                        print("Epoch:%d;     Iteration:%d;      Loss:%f" %
                              (epoch, i, loss))

                    if i + 1 == total_step:  # and epoch%1==0: #and val_miou>0.85:
                        torch.save(
                            unet.state_dict(),
                            './trained_models/unet_' + opt + '_' + los + '_' +
                            str(epoch + 1 + prev_epoch) + '.pkl')
                        torch.save(
                            RC_classifier.state_dict(),
                            './trained_models/RC_classifier_' + opt + '_' +
                            los + '_' + str(epoch + 1 + prev_epoch) + '.pkl')
                        torch.save(
                            SC_classifier.state_dict(),
                            './trained_models/SC_classifier_' + opt + '_' +
                            los + '_' + str(epoch + 1 + prev_epoch) + '.pkl')
                history_np = np.array(history)
                np.save('./logs/unet_' + opt + '_' + los + '.npy', history_np)
            end = time.time()
            print((end - start) / 60)