Example #1
0
    def __init__(self):
        """Create the path index and register its lookup tensors as buffers.

        The ``PathIndex`` has radius 10 and a default size of
        ``irn_crop_size // 4`` on each side. ``irn_crop_size`` is looked up
        as a global name. NOTE(review): it is not defined in this method;
        confirm where it comes from.

        Side effects on ``self``:
          * ``path_index`` and ``n_path_lengths``;
          * one buffer per entry of ``path_index.path_indices``, named
            ``AffinityDisplacementLoss.path_indices_prefix + str(i)``;
          * a float ``disp_target`` buffer built from
            ``path_index.search_dst``;
          * ``params``: a two-item list holding the parameter tuples of
            ``edge_layers`` and ``dp_layers``.
        """
        super(AffinityDisplacementLoss, self).__init__()

        side = irn_crop_size // 4
        self.path_index = indexing.PathIndex(radius=10,
                                             default_size=(side, side))
        index = self.path_index
        self.n_path_lengths = len(index.path_indices)

        for n, indices in enumerate(index.path_indices):
            name = AffinityDisplacementLoss.path_indices_prefix + str(n)
            self.register_buffer(name, torch.from_numpy(indices))

        # Swap search_dst's first two axes, then add a leading and a trailing
        # singleton axis before casting to float.
        target = torch.from_numpy(index.search_dst).transpose(1, 0)
        target = target.unsqueeze(0).unsqueeze(-1).float()
        self.register_buffer('disp_target', target)

        # NOTE(review): edge_layers / dp_layers are not created in this
        # method; presumably the parent __init__ builds them -- confirm.
        self.params = [
            tuple(self.edge_layers.parameters()),
            tuple(self.dp_layers.parameters()),
        ]
Example #2
0
def run(args):
    """Train the IRN affinity/displacement model on VOC12, then estimate its mean displacement.

    Steps:
      1. Build a radius-10 ``PathIndex`` on an ``irn_crop_size // 4`` grid
         and instantiate ``AffinityDisplacementLoss`` from the module named
         by ``args.irn_network``.
      2. Train for ``args.irn_num_epoches`` epochs on
         ``VOC12AffinityDataset`` with a ``PolyOptimizer``. The second
         parameter group uses 10x the base learning rate.
      3. In eval mode, average the per-batch channel means of the
         displacement output over ``VOC12ImageDataset``, and store the
         result in ``model.module.mean_shift.running_mean``.
      4. Save the unwrapped model's ``state_dict`` to
         ``args.irn_weights_name``.

    Args:
        args: configuration namespace providing the ``irn_*`` settings,
            ``train_list``, ``infer_list``, ``ir_label_out_dir``,
            ``voc12_root`` and ``num_workers`` read below.
    """
    feat_side = args.irn_crop_size // 4
    path_index = indexing.PathIndex(radius=10,
                                    default_size=(feat_side, feat_side))
    net_module = importlib.import_module(args.irn_network)
    model = getattr(net_module, 'AffinityDisplacementLoss')(path_index)

    train_dataset = voc12.dataloader.VOC12AffinityDataset(
        args.train_list,
        label_dir=args.ir_label_out_dir,
        voc12_root=args.voc12_root,
        indices_from=path_index.src_indices,
        indices_to=path_index.dst_indices,
        hor_flip=True,
        crop_size=args.irn_crop_size,
        crop_method="random",
        rescale=(0.5, 1.5))
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=args.irn_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)

    # drop_last=True above, so each epoch yields exactly this many batches.
    batches_per_epoch = len(train_dataset) // args.irn_batch_size
    max_step = batches_per_epoch * args.irn_num_epoches

    groups = model.trainable_parameters()
    base_lr = args.irn_learning_rate
    weight_decay = args.irn_weight_decay
    optimizer = torchutils.PolyOptimizer(
        [
            {'params': groups[0], 'lr': 1 * base_lr,
             'weight_decay': weight_decay},
            # The second group trains at ten times the base rate.
            {'params': groups[1], 'lr': 10 * base_lr,
             'weight_decay': weight_decay},
        ],
        lr=base_lr,
        weight_decay=weight_decay,
        max_step=max_step)

    model = torch.nn.DataParallel(model).cuda()
    model.train()

    avg_meter = pyutils.AverageMeter()
    timer = pyutils.Timer()
    for epoch in range(args.irn_num_epoches):
        print('Epoch %d/%d' % (epoch + 1, args.irn_num_epoches))
        for batch_idx, pack in enumerate(train_data_loader):
            img = pack['img'].cuda(non_blocking=True)
            bg_pos_label = pack['aff_bg_pos_label'].cuda(non_blocking=True)
            fg_pos_label = pack['aff_fg_pos_label'].cuda(non_blocking=True)
            neg_label = pack['aff_neg_label'].cuda(non_blocking=True)

            # Un-reduced loss maps returned by the model; reduced below.
            pos_map, neg_map, dp_fg_map, dp_bg_map = model(img, True)

            # Mask-weighted means; the 1e-5 keeps each division finite when
            # a mask is empty.
            bg_pos_aff_loss = (torch.sum(bg_pos_label * pos_map) /
                               (torch.sum(bg_pos_label) + 1e-5))
            fg_pos_aff_loss = (torch.sum(fg_pos_label * pos_map) /
                               (torch.sum(fg_pos_label) + 1e-5))
            pos_aff_loss = bg_pos_aff_loss / 2 + fg_pos_aff_loss / 2
            neg_aff_loss = (torch.sum(neg_label * neg_map) /
                            (torch.sum(neg_label) + 1e-5))

            # Labels get a singleton axis 1 to broadcast over the displacement
            # maps, and the denominator is doubled.
            dp_fg_loss = (
                torch.sum(dp_fg_map * torch.unsqueeze(fg_pos_label, 1)) /
                (2 * torch.sum(fg_pos_label) + 1e-5))
            dp_bg_loss = (
                torch.sum(dp_bg_map * torch.unsqueeze(bg_pos_label, 1)) /
                (2 * torch.sum(bg_pos_label) + 1e-5))

            avg_meter.add({
                'loss1': pos_aff_loss.item(),
                'loss2': neg_aff_loss.item(),
                'loss3': dp_fg_loss.item(),
                'loss4': dp_bg_loss.item()
            })
            total_loss = ((pos_aff_loss + neg_aff_loss) / 2 +
                          (dp_fg_loss + dp_bg_loss) / 2)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Report on the first step and then every 50 steps.
            finished = optimizer.global_step - 1
            if finished % 50 != 0:
                continue

            timer.update_progress(optimizer.global_step / max_step)
            mean_losses = tuple(
                avg_meter.pop(key)
                for key in ('loss1', 'loss2', 'loss3', 'loss4'))
            images_per_sec = ((batch_idx + 1) * args.irn_batch_size /
                              timer.get_stage_elapsed())
            print('step:%5d/%5d' % (finished, max_step),
                  'loss:%.4f %.4f %.4f %.4f' % mean_losses,
                  'imps:%.1f' % images_per_sec,
                  'lr: %.4f' % (optimizer.param_groups[0]['lr']),
                  'etc:%s' % (timer.str_estimated_complete()),
                  flush=True)
        # Runs once after each completed epoch.
        timer.reset_stage()

    infer_dataset = voc12.dataloader.VOC12ImageDataset(
        args.infer_list,
        voc12_root=args.voc12_root,
        crop_size=args.irn_crop_size,
        crop_method="top_left")
    infer_data_loader = DataLoader(
        infer_dataset,
        batch_size=args.irn_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)

    model.eval()
    print('Analyzing displacements mean ... ', end='')

    with torch.no_grad():
        batch_means = []
        for pack in infer_data_loader:
            _, dp = model(pack['img'].cuda(non_blocking=True), False)
            # Mean over axes 0, 2 and 3, leaving one value per channel.
            batch_means.append(torch.mean(dp, dim=(0, 2, 3)).cpu())
        model.module.mean_shift.running_mean = torch.mean(
            torch.stack(batch_means), dim=0)
    print('done.')

    torch.save(model.module.state_dict(), args.irn_weights_name)
    torch.cuda.empty_cache()
Example #3
0
def run(args):
    """Train an IRN affinity/displacement model and estimate its mean displacement.

    Builds dataset-specific training (affinity) and inference (image)
    datasets from ``args.dataset``. It then trains the
    ``AffinityDisplacementLoss`` class loaded from the ``args.irn_network``
    module with a ``torchutils.PolyOptimizer``. Next it runs the trained
    model over the inference set to set
    ``model.module.mean_shift.running_mean``. Finally it saves the unwrapped
    model's ``state_dict`` to ``args.irn_weights_name``.

    Args:
        args: configuration namespace; every ``args.<attr>`` read below must
            exist (dataset selection, list files, ``dev_root``, crop/rescale
            options, and the ``irn_*`` training hyper-parameters).

    Raises:
        KeyError: if ``args.dataset`` is not ``'voc12'``, ``'adp_morph'``,
            ``'adp_func'``, ``'deepglobe'`` or ``'deepglobe_balanced'``.
    """

    # Radius-10 path index on a grid a quarter of the crop size per side
    # (presumably matching the network's output stride of 4 -- confirm).
    path_index = indexing.PathIndex(radius=10,
                                    default_size=(args.irn_crop_size // 4,
                                                  args.irn_crop_size // 4))

    model = getattr(importlib.import_module(args.irn_network),
                    'AffinityDisplacementLoss')(path_index, args.model_dir,
                                                args.dataset, args.tag,
                                                args.num_classes, args.use_cls)
    # Select the training (affinity) and inference (image) datasets.
    if args.dataset == 'voc12':
        train_dataset = voc12.dataloader.VOC12AffinityDataset(
            args.train_list,
            label_dir=args.ir_label_out_dir,
            dev_root=args.dev_root,
            indices_from=path_index.src_indices,
            indices_to=path_index.dst_indices,
            hor_flip=True,
            crop_size=args.irn_crop_size,
            crop_method=args.crop_method,
            rescale=args.rescale_range,
            outsize=args.outsize,
            norm_mode=args.norm_mode)
        infer_dataset = voc12.dataloader.VOC12ImageDataset(
            args.infer_list,
            dev_root=args.dev_root,
            crop_size=args.irn_crop_size,
            crop_method="top_left")
    elif args.dataset in ['adp_morph', 'adp_func']:
        # htt_type is the text after the last '_' ('morph' or 'func').
        train_dataset = adp.dataloader.ADPAffinityDataset(
            args.train_list,
            # NOTE(review): this branch only runs for 'adp_morph'/'adp_func',
            # so this comparison is always False -- confirm the intended flag.
            is_eval=args.dataset == 'evaluation',
            label_dir=args.ir_label_out_dir,
            dev_root=args.dev_root,
            htt_type=args.dataset.split('_')[-1],
            indices_from=path_index.src_indices,
            indices_to=path_index.dst_indices,
            hor_flip=True,
            crop_size=args.irn_crop_size,
            crop_method=args.crop_method,
            rescale=args.rescale_range,
            outsize=args.outsize,
            norm_mode=args.norm_mode)
        infer_dataset = adp.dataloader.ADPImageDataset(
            args.infer_list,
            dev_root=args.dev_root,
            htt_type=args.dataset.split('_')[-1],
            # NOTE(review): always False in this branch (see above).
            is_eval=args.dataset == 'evaluation',
            crop_size=args.irn_crop_size,
            crop_method="top_left")
    elif args.dataset in ['deepglobe', 'deepglobe_balanced']:
        train_dataset = deepglobe.dataloader.DeepGlobeAffinityDataset(
            args.train_list,
            is_balanced=args.dataset == 'deepglobe_balanced',
            label_dir=args.ir_label_out_dir,
            dev_root=args.dev_root,
            indices_from=path_index.src_indices,
            indices_to=path_index.dst_indices,
            hor_flip=True,
            crop_size=args.irn_crop_size,
            crop_method=args.crop_method,
            rescale=args.rescale_range,
            outsize=args.outsize,
            norm_mode=args.norm_mode)
        infer_dataset = deepglobe.dataloader.DeepGlobeImageDataset(
            args.infer_list,
            dev_root=args.dev_root,
            is_balanced=args.dataset == 'deepglobe_balanced',
            crop_size=args.irn_crop_size,
            crop_method="top_left")
    else:
        raise KeyError('Dataset %s not yet implemented' % args.dataset)

    train_data_loader = DataLoader(train_dataset,
                                   batch_size=args.irn_batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True,
                                   drop_last=True)

    # Total optimizer steps: full batches per epoch (drop_last=True) times
    # the number of epochs; also the schedule length for PolyOptimizer.
    max_step = (len(train_dataset) //
                args.irn_batch_size) * args.irn_num_epoches

    # The second parameter group trains at 10x the base learning rate.
    param_groups = model.trainable_parameters()
    optimizer = torchutils.PolyOptimizer([{
        'params': param_groups[0],
        'lr': 1 * args.irn_learning_rate,
        'weight_decay': args.irn_weight_decay
    }, {
        'params': param_groups[1],
        'lr': 10 * args.irn_learning_rate,
        'weight_decay': args.irn_weight_decay
    }],
                                         lr=args.irn_learning_rate,
                                         weight_decay=args.irn_weight_decay,
                                         max_step=max_step)

    model = torch.nn.DataParallel(model).cuda()
    model.train()

    # writer = SummaryWriter('log_tb/' + args.run_name)

    avg_meter = pyutils.AverageMeter()

    timer = pyutils.Timer()

    for ep in range(args.irn_num_epoches):

        print('Epoch %d/%d' % (ep + 1, args.irn_num_epoches))

        for iter, pack in enumerate(train_data_loader):

            img = pack['img'].cuda(non_blocking=True)
            bg_pos_label = pack['aff_bg_pos_label'].cuda(non_blocking=True)
            fg_pos_label = pack['aff_fg_pos_label'].cuda(non_blocking=True)
            neg_label = pack['aff_neg_label'].cuda(non_blocking=True)

            # With its second argument True, the model returns four loss maps
            # that are reduced with the label masks below.
            pos_aff_loss, neg_aff_loss, dp_fg_loss, dp_bg_loss = model(
                img, True)

            # Mask-weighted means; 1e-5 keeps the division finite for an
            # empty mask.
            bg_pos_aff_loss = torch.sum(
                bg_pos_label * pos_aff_loss) / (torch.sum(bg_pos_label) + 1e-5)
            fg_pos_aff_loss = torch.sum(
                fg_pos_label * pos_aff_loss) / (torch.sum(fg_pos_label) + 1e-5)
            pos_aff_loss = bg_pos_aff_loss / 2 + fg_pos_aff_loss / 2
            neg_aff_loss = torch.sum(
                neg_label * neg_aff_loss) / (torch.sum(neg_label) + 1e-5)

            # Labels are unsqueezed on axis 1 to broadcast over the
            # displacement maps; the doubled denominator presumably accounts
            # for two displacement channels -- confirm.
            dp_fg_loss = torch.sum(dp_fg_loss * torch.unsqueeze(
                fg_pos_label, 1)) / (2 * torch.sum(fg_pos_label) + 1e-5)
            dp_bg_loss = torch.sum(dp_bg_loss * torch.unsqueeze(
                bg_pos_label, 1)) / (2 * torch.sum(bg_pos_label) + 1e-5)

            avg_meter.add({
                'loss1': pos_aff_loss.item(),
                'loss2': neg_aff_loss.item(),
                'loss3': dp_fg_loss.item(),
                'loss4': dp_bg_loss.item()
            })

            total_loss = (pos_aff_loss + neg_aff_loss) / 2 + (dp_fg_loss +
                                                              dp_bg_loss) / 2

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Log on the first step and every 50 steps afterwards.
            if (optimizer.global_step - 1) % 50 == 0:
                timer.update_progress(optimizer.global_step / max_step)
                losses = {}
                for i in range(1, 5):
                    losses[str(i)] = avg_meter.pop('loss' + str(i))

                print('step:%5d/%5d' % (optimizer.global_step - 1, max_step),
                      'loss:%.4f %.4f %.4f %.4f' %
                      (losses['1'], losses['2'], losses['3'], losses['4']),
                      'imps:%.1f' % ((iter + 1) * args.irn_batch_size /
                                     timer.get_stage_elapsed()),
                      'lr: %.4f' % (optimizer.param_groups[0]['lr']),
                      'etc:%s' % (timer.str_estimated_complete()),
                      flush=True)
                # writer.add_scalar('step', optimizer.global_step, ep * len(train_data_loader) + iter)
                # writer.add_scalar('loss', losses['1']+losses['2']+losses['3']+losses['4'],
                #                   ep * len(train_data_loader) + iter)
                # writer.add_scalar('lr', optimizer.param_groups[0]['lr'], ep * len(train_data_loader) + iter)
        else:
            # for/else: the batch loop has no `break`, so this runs after
            # every epoch.
            timer.reset_stage()
    # drop_last=True: every batch has the same size, so the mean of the
    # per-batch means below equals the mean over the images kept. If there
    # are fewer images than one batch, dp_mean_list stays empty and
    # torch.stack would fail.
    infer_data_loader = DataLoader(infer_dataset,
                                   batch_size=args.irn_batch_size,
                                   shuffle=False,
                                   num_workers=args.num_workers,
                                   pin_memory=True,
                                   drop_last=True)

    model.eval()
    print('Analyzing displacements mean ... ', end='')

    dp_mean_list = []

    with torch.no_grad():
        for iter, pack in enumerate(infer_data_loader):
            img = pack['img'].cuda(non_blocking=True)

            # With its second argument False, the model returns (aff, dp).
            aff, dp = model(img, False)

            # Mean over axes 0, 2 and 3, leaving one value per channel.
            dp_mean_list.append(torch.mean(dp, dim=(0, 2, 3)).cpu())

        model.module.mean_shift.running_mean = torch.mean(
            torch.stack(dp_mean_list), dim=0)
    print('done.')

    torch.save(model.module.state_dict(), args.irn_weights_name)
    torch.cuda.empty_cache()
Example #4
0
def run(args):
    """Train the IRN affinity/displacement model on the Pneumothorax CSV data.

    Training and the displacement-mean pass both read the same hard-coded
    CSV and resize images to 512 via ``get_transforms``. The model runs
    under ``DataParallel`` on ``cuda:1`` and ``cuda:2``. After training, the
    mean of the per-batch channel means of the displacement output is
    stored in ``model.module.mean_shift.running_mean``. The unwrapped
    model's ``state_dict`` is then saved to ``args.irn_weights_name``.

    Args:
        args: configuration namespace providing the ``irn_*`` settings,
            ``irn_network`` and ``num_workers`` read below.
    """
    csv_path = '/datasets/LID/Pneumothorax/train/train_all_positive.csv'

    feat_side = args.irn_crop_size // 4
    path_index = indexing.PathIndex(radius=10,
                                    default_size=(feat_side, feat_side))
    loss_cls = getattr(importlib.import_module(args.irn_network),
                       'AffinityDisplacementLoss')
    model = loss_cls(path_index)

    train_transform = get_transforms(dict(
        augmentation_scope='horizontal_flip',
        images_normalization='default',
        images_output_format_type='float',
        masks_normalization='none',
        masks_output_format_type='byte',
        size=512,
        size_transform='resize',
    ))
    train_dataset = voc12.dataloader.PneumothoraxAffinityDataset(
        csv_path,
        transform=train_transform,
        indices_from=path_index.src_indices,
        indices_to=path_index.dst_indices,
    )
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=args.irn_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)

    # drop_last=True above, so each epoch yields exactly this many batches.
    batches_per_epoch = len(train_dataset) // args.irn_batch_size
    max_step = batches_per_epoch * args.irn_num_epoches

    groups = model.trainable_parameters()
    base_lr = args.irn_learning_rate
    weight_decay = args.irn_weight_decay
    optimizer = torchutils.PolyOptimizer(
        [
            {'params': groups[0], 'lr': 1 * base_lr,
             'weight_decay': weight_decay},
            # The second group trains at ten times the base rate.
            {'params': groups[1], 'lr': 10 * base_lr,
             'weight_decay': weight_decay},
        ],
        lr=base_lr,
        weight_decay=weight_decay,
        max_step=max_step)

    model = torch.nn.DataParallel(model.cuda(1),
                                  device_ids=['cuda:1', 'cuda:2'])
    model.train()

    avg_meter = pyutils.AverageMeter()
    timer = pyutils.Timer()
    for epoch in range(args.irn_num_epoches):
        print('Epoch %d/%d' % (epoch + 1, args.irn_num_epoches))
        for batch_idx, pack in enumerate(train_data_loader):
            # Images are passed on as loaded (presumably DataParallel scatters
            # them); labels go to cuda:1 for the loss reduction below.
            img = pack['img']
            bg_pos_label = pack['aff_bg_pos_label'].cuda(1, non_blocking=True)
            fg_pos_label = pack['aff_fg_pos_label'].cuda(1, non_blocking=True)
            neg_label = pack['aff_neg_label'].cuda(1, non_blocking=True)

            # Un-reduced loss maps returned by the model; reduced below.
            pos_map, neg_map, dp_fg_map, dp_bg_map = model(img, True)

            # Mask-weighted means; the 1e-5 keeps each division finite when
            # a mask is empty.
            bg_pos_aff_loss = (torch.sum(bg_pos_label * pos_map) /
                               (torch.sum(bg_pos_label) + 1e-5))
            fg_pos_aff_loss = (torch.sum(fg_pos_label * pos_map) /
                               (torch.sum(fg_pos_label) + 1e-5))
            pos_aff_loss = bg_pos_aff_loss / 2 + fg_pos_aff_loss / 2
            neg_aff_loss = (torch.sum(neg_label * neg_map) /
                            (torch.sum(neg_label) + 1e-5))

            # Labels get a singleton axis 1 to broadcast over the displacement
            # maps, and the denominator is doubled.
            dp_fg_loss = (
                torch.sum(dp_fg_map * torch.unsqueeze(fg_pos_label, 1)) /
                (2 * torch.sum(fg_pos_label) + 1e-5))
            dp_bg_loss = (
                torch.sum(dp_bg_map * torch.unsqueeze(bg_pos_label, 1)) /
                (2 * torch.sum(bg_pos_label) + 1e-5))

            avg_meter.add({
                'loss1': pos_aff_loss.item(),
                'loss2': neg_aff_loss.item(),
                'loss3': dp_fg_loss.item(),
                'loss4': dp_bg_loss.item()
            })
            total_loss = ((pos_aff_loss + neg_aff_loss) / 2 +
                          (dp_fg_loss + dp_bg_loss) / 2)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Report on the first step and then every 50 steps.
            finished = optimizer.global_step - 1
            if finished % 50 != 0:
                continue

            timer.update_progress(optimizer.global_step / max_step)
            mean_losses = tuple(
                avg_meter.pop(key)
                for key in ('loss1', 'loss2', 'loss3', 'loss4'))
            images_per_sec = ((batch_idx + 1) * args.irn_batch_size /
                              timer.get_stage_elapsed())
            print('step:%5d/%5d' % (finished, max_step),
                  'loss:%.4f %.4f %.4f %.4f' % mean_losses,
                  'imps:%.1f' % images_per_sec,
                  'lr: %.4f' % (optimizer.param_groups[0]['lr']),
                  'etc:%s' % (timer.str_estimated_complete()),
                  flush=True)
        # Runs once after each completed epoch.
        timer.reset_stage()

    infer_transform = get_transforms(dict(
        augmentation_scope='none',
        images_normalization='default',
        images_output_format_type='float',
        size=512,
        size_transform='resize',
    ))
    infer_dataset = voc12.dataloader.PneumothoraxImageDataset(
        csv_path, transform=infer_transform)
    infer_data_loader = DataLoader(
        infer_dataset,
        batch_size=args.irn_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)

    model.eval()
    print('Analyzing displacements mean ... ', end='')

    with torch.no_grad():
        batch_means = []
        for pack in infer_data_loader:
            _, dp = model(pack['img'], False)
            # Mean over axes 0, 2 and 3, leaving one value per channel.
            batch_means.append(torch.mean(dp, dim=(0, 2, 3)).cpu())
        model.module.mean_shift.running_mean = torch.mean(
            torch.stack(batch_means), dim=0)
    print('done.')

    torch.save(model.module.state_dict(), args.irn_weights_name)
    torch.cuda.empty_cache()
Example #5
0
def run(args):
    """Train the ``AffinityDisplacement`` network, computing its losses in this loop.

    The model is built from the path index's default path/src/dst indices.
    It is trained for ``args.irn_num_epoches`` epochs on
    ``VOC12AffinityDataset`` with a ``PolyOptimizer``; the second parameter
    group uses 10x the base learning rate. The model's ``state_dict`` is
    then saved to ``args.irn_weights_name``.

    Losses, each reduced as a mask-weighted mean with a 1e-5 guard:
      * positive affinity: ``-log(aff + 1e-5)``, averaged separately over the
        background and foreground masks, then combined with equal weight;
      * negative affinity: ``-log(1 + 1e-5 - aff)`` over the negative mask;
      * foreground displacement: ``path_index.to_displacement_loss(dp)``;
      * background displacement: ``|dp|``.

    Args:
        args: configuration namespace providing the ``irn_*`` settings,
            ``train_list``, ``ir_label_out_dir``, ``voc12_root`` and
            ``num_workers`` read below.
    """
    path_index = indexing.PathIndex(radius=10,
                                    default_size=(args.irn_crop_size // 4,
                                                  args.irn_crop_size // 4))

    model = getattr(importlib.import_module(args.irn_network),
                    'AffinityDisplacement')(
                        path_index.default_path_indices,
                        torch.from_numpy(path_index.default_src_indices),
                        torch.from_numpy(path_index.default_dst_indices))

    train_dataset = voc12.dataloader.VOC12AffinityDataset(
        args.train_list,
        label_dir=args.ir_label_out_dir,
        voc12_root=args.voc12_root,
        indices_from=path_index.default_src_indices,
        indices_to=path_index.default_dst_indices,
        hor_flip=True,
        crop_size=args.irn_crop_size,
        crop_method="random",
        rescale=(0.5, 1.5))
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=args.irn_batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True,
                                   drop_last=True)

    # drop_last=True, so every epoch contributes exactly this many steps.
    max_step = (len(train_dataset) //
                args.irn_batch_size) * args.irn_num_epoches

    param_groups = model.trainable_parameters()
    optimizer = torchutils.PolyOptimizer([{
        'params': param_groups[0],
        'lr': 1 * args.irn_learning_rate,
        'weight_decay': args.irn_weight_decay
    }, {
        'params': param_groups[1],
        'lr': 10 * args.irn_learning_rate,
        'weight_decay': args.irn_weight_decay
    }],
                                         lr=args.irn_learning_rate,
                                         weight_decay=args.irn_weight_decay,
                                         max_step=max_step)

    model = model.cuda()
    model.train()

    avg_meter = pyutils.AverageMeter()
    timer = pyutils.Timer()

    for ep in range(args.irn_num_epoches):
        print('Epoch %d/%d' % (ep + 1, args.irn_num_epoches))

        for batch_idx, pack in enumerate(train_data_loader):
            img = pack['img'].cuda(non_blocking=True)
            bg_pos_label = pack['aff_bg_pos_label'].cuda(non_blocking=True)
            fg_pos_label = pack['aff_fg_pos_label'].cuda(non_blocking=True)
            neg_label = pack['aff_neg_label'].cuda(non_blocking=True)

            aff, dp = model(img)
            dp = path_index.to_displacement(dp)

            bg_pos_aff_loss = torch.sum(
                -bg_pos_label *
                torch.log(aff + 1e-5)) / (torch.sum(bg_pos_label) + 1e-5)
            fg_pos_aff_loss = torch.sum(
                -fg_pos_label *
                torch.log(aff + 1e-5)) / (torch.sum(fg_pos_label) + 1e-5)
            pos_aff_loss = bg_pos_aff_loss / 2 + fg_pos_aff_loss / 2

            neg_aff_loss = torch.sum(
                -neg_label *
                torch.log(1. + 1e-5 - aff)) / (torch.sum(neg_label) + 1e-5)

            # Labels are unsqueezed on axis 1 to broadcast over the
            # displacement tensors; the denominator is doubled.
            dp_fg_loss = torch.sum(
                path_index.to_displacement_loss(dp) * torch.unsqueeze(
                    fg_pos_label, 1)) / (2 * torch.sum(fg_pos_label) + 1e-5)

            dp_bg_loss = torch.sum(
                torch.abs(dp) * torch.unsqueeze(bg_pos_label, 1)) / (
                    2 * torch.sum(bg_pos_label) + 1e-5)

            # Log plain Python floats. Handing the loss tensors themselves to
            # the meter would keep their autograd graphs (and device memory)
            # referenced until the meter releases them between log lines.
            avg_meter.add({
                'loss1': pos_aff_loss.item(),
                'loss2': neg_aff_loss.item(),
                'loss3': dp_fg_loss.item(),
                'loss4': dp_bg_loss.item()
            })

            total_loss = (pos_aff_loss + neg_aff_loss) / 2 + (dp_fg_loss +
                                                              dp_bg_loss) / 2

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Log on the first step and every 100 steps afterwards.
            if (optimizer.global_step - 1) % 100 == 0:
                timer.update_progress(optimizer.global_step / max_step)

                print('step:%5d/%5d' % (optimizer.global_step - 1, max_step),
                      'loss:%.4f %.4f %.4f %.4f' %
                      (avg_meter.pop('loss1'), avg_meter.pop('loss2'),
                       avg_meter.pop('loss3'), avg_meter.pop('loss4')),
                      'imps:%.1f' % ((batch_idx + 1) * args.irn_batch_size /
                                     timer.get_stage_elapsed()),
                      'lr: %.4f' % (optimizer.param_groups[0]['lr']),
                      'etc:%s' % (timer.str_estimated_complete()),
                      flush=True)

        # The batch loop never breaks; reset the timer stage after every epoch.
        timer.reset_stage()

    torch.save(model.state_dict(), args.irn_weights_name)
    torch.cuda.empty_cache()