예제 #1
0
def train(args):
    """Train a sketch classifier on k-fold splits with an adaptive loss.

    Builds train/val loaders for ``args.dataset``, trains the model
    selected by ``args.arch`` with Adam, using either cosine annealing
    with warm restarts (``args.num_cycles > 0``) or a multi-step LR
    schedule, with gradient accumulation over ``args.iter_size``
    mini-batches. Optionally resumes model/optimizer state from
    ``args.resume``.
    """
    if not os.path.exists('checkpoints'):
        os.mkdir('checkpoints')

    # Setup Augmentations
    data_aug = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAffine(degrees=10,
                                translate=(0.05, 0.05),
                                scale=(0.95, 1.05)),
    ])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split='train',
                           version='simplified',
                           img_size=(args.img_rows, args.img_cols),
                           augmentations=data_aug,
                           train_fold_num=args.train_fold_num,
                           num_train_folds=args.num_train_folds,
                           seed=args.seed)
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           version='simplified',
                           img_size=(args.img_rows, args.img_cols),
                           num_val=args.num_val,
                           seed=args.seed)

    # Seed every RNG source for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=2,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=2,
                                pin_memory=True)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup Model
    v_dimension = 300  # embedding dimension handed to the model factory
    model = get_model(args.arch, v_dimension, use_cbam=args.use_cbam)
    model.cuda()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=args.l_rate,
                                 weight_decay=args.weight_decay)
    if args.num_cycles > 0:
        # Per-iteration cosine schedule; the loader length is approximated
        # by a constant because the real length (4960414) is roughly 5e6.
        len_trainloader = int(5e6)  # 4960414
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=args.num_train_folds * len_trainloader // args.num_cycles,
            eta_min=args.l_rate * 1e-1)
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[2, 4, 6, 8], gamma=0.5)

    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)

            model_dict = model.state_dict()
            # Checkpoints are either a raw state dict or a wrapper dict
            # with a 'model_state' key plus bookkeeping metadata.
            if checkpoint.get('model_state', -1) == -1:
                model_dict.update(
                    convert_state_dict(checkpoint,
                                       load_classifier=args.load_classifier))
            else:
                model_dict.update(
                    convert_state_dict(checkpoint['model_state'],
                                       load_classifier=args.load_classifier))

                print(
                    "Loaded checkpoint '{}' (epoch {}, mapk {:.5f}, top1_acc {:7.3f}, top2_acc {:7.3f} top3_acc {:7.3f})"
                    .format(args.resume, checkpoint['epoch'],
                            checkpoint['mapk'], checkpoint['top1_acc'],
                            checkpoint['top2_acc'], checkpoint['top3_acc']))
            model.load_state_dict(model_dict)

            if checkpoint.get('optimizer_state', None) is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state'])
                start_epoch = checkpoint['epoch']
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    # Construct the loss module once; building it inside the batch loop
    # allocated a new module (and CUDA transfer) on every iteration.
    a_loss = Adptive_loss().cuda()

    loss_sum = 0.0
    for epoch in range(start_epoch, args.n_epoch):
        start_train_time = timeit.default_timer()

        if args.num_cycles == 0:
            scheduler.step(epoch)

        model.train()
        optimizer.zero_grad()
        for i, (images, labels, recognized, _) in enumerate(trainloader):
            if args.num_cycles > 0:
                iter_num = i + epoch * len_trainloader
                scheduler.step(
                    iter_num %
                    (args.num_train_folds * len_trainloader //
                     args.num_cycles))  # Cosine Annealing with Restarts

            images = images.cuda()
            labels = labels.cuda()

            outputs = model(images)
            loss = a_loss(outputs, labels)
            loss = loss / float(args.iter_size)  # Accumulated gradients
            # Accumulate a Python float: summing the tensor itself would
            # retain autograd history / GPU memory across iterations.
            loss_sum = loss_sum + loss.item()

            loss.backward()

            if (i + 1) % args.print_train_freq == 0:
                print("Epoch [%d/%d] Iter [%6d/%6d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, i + 1, len(trainloader),
                       loss_sum))

            # Step once every iter_size batches (and at epoch end) to
            # realize the accumulated gradients.
            if (i + 1) % args.iter_size == 0 or i == len(trainloader) - 1:
                optimizer.step()
                optimizer.zero_grad()
                loss_sum = 0.0

        elapsed_train_time = timeit.default_timer() - start_train_time
        print('Training time (epoch {0:5d}): {1:10.5f} seconds'.format(
            epoch + 1, elapsed_train_time))
예제 #2
0
def train(args):
    """Train an ImageNet-pretrained MobileNetV2 classifier.

    Replaces the torchvision MobileNetV2 head with a fresh
    ``Dropout + Linear`` classifier for ``n_classes`` outputs, trains with
    Adam and gradient accumulation over ``args.iter_size`` mini-batches,
    reduces the LR on plateaus of the mean validation loss, and saves a
    checkpoint with validation metrics (MAP@3, top-1/2/3 accuracy) after
    every epoch. Supports resuming from ``args.resume``.
    """
    if not os.path.exists('checkpoints'):
        os.mkdir('checkpoints')

    # Setup Augmentations
    data_aug = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05)),
                ])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path, is_transform=True, split='train', version='simplified', img_size=(args.img_rows, args.img_cols), augmentations=data_aug, train_fold_num=args.train_fold_num, num_train_folds=args.num_train_folds, seed=args.seed)
    v_loader = data_loader(data_path, is_transform=True, split='val', version='simplified', img_size=(args.img_rows, args.img_cols), num_val=args.num_val, seed=args.seed)

    # Seed every RNG source for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=2, shuffle=True, pin_memory=True, drop_last=True)
    valloader = data.DataLoader(v_loader, batch_size=args.batch_size, num_workers=2, pin_memory=True)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup Model: pretrained MobileNetV2 backbone, new classifier head.
    model = torchvision.models.mobilenet_v2(pretrained=True)
    num_ftrs = model.last_channel
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(num_ftrs, n_classes),
    )
    model.cuda()

    # Check if model has custom optimizer / loss
    if hasattr(model, 'optimizer'):
        optimizer = model.optimizer
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, cooldown=5, min_lr=1e-7)

    if hasattr(model, 'loss'):
        print('Using custom loss')
        loss_fn = model.loss
    else:
        loss_fn = F.cross_entropy

    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            model_dict = model.state_dict()
            # Checkpoints are either a raw state dict or a wrapper dict
            # with a 'model_state' key plus bookkeeping metadata.
            if checkpoint.get('model_state', -1) == -1:
                model_dict.update(convert_state_dict(checkpoint, load_classifier=args.load_classifier))
            else:
                model_dict.update(convert_state_dict(checkpoint['model_state'], load_classifier=args.load_classifier))

                print("Loaded checkpoint '{}' (epoch {}, mapk {:.5f}, top1_acc {:7.3f}, top2_acc {:7.3f} top3_acc {:7.3f})"
                      .format(args.resume, checkpoint['epoch'], checkpoint['mapk'], checkpoint['top1_acc'], checkpoint['top2_acc'], checkpoint['top3_acc']))
            model.load_state_dict(model_dict)

            if checkpoint.get('optimizer_state', None) is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state'])
                start_epoch = checkpoint['epoch']
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    loss_sum = 0.0
    for epoch in range(start_epoch, args.n_epoch):
        start_train_time = timeit.default_timer()

        model.train()
        optimizer.zero_grad()
        for i, (images, labels, recognized, _) in enumerate(trainloader):
            images = images.cuda()
            labels = labels.cuda()
            recognized = recognized.cuda()

            outputs = model(images)

            # Per-sample CE masked by the 'recognized' flag, then averaged.
            loss = (loss_fn(outputs, labels.view(-1), ignore_index=t_loader.ignore_index, reduction='none') * recognized.view(-1)).mean()
            # Accumulate a Python float: summing the tensor itself would
            # retain autograd history / GPU memory across iterations.
            loss_sum = loss_sum + loss.item()

            loss.backward()

            if (i+1) % args.print_train_freq == 0:
                print("Epoch [%d/%d] Iter [%6d/%6d] Loss: %.4f" % (epoch+1, args.n_epoch, i+1, len(trainloader), loss_sum))

            # Step once every iter_size batches (and at epoch end) to
            # realize the accumulated gradients.
            if (i+1) % args.iter_size == 0 or i == len(trainloader) - 1:
                optimizer.step()
                optimizer.zero_grad()
                loss_sum = 0.0

        # ---- Validation pass ----
        mapk_val = AverageMeter()
        top1_acc_val = AverageMeter()
        top2_acc_val = AverageMeter()
        top3_acc_val = AverageMeter()
        mean_loss_val = AverageMeter()
        model.eval()
        with torch.no_grad():
            for i_val, (images_val, labels_val, recognized_val, _) in tqdm(enumerate(valloader)):
                images_val = images_val.cuda()
                labels_val = labels_val.cuda()
                recognized_val = recognized_val.cuda()

                outputs_val = model(images_val)

                loss_val = (loss_fn(outputs_val, labels_val.view(-1), ignore_index=v_loader.ignore_index, reduction='none') * recognized_val.view(-1)).mean()
                # Store a float, not a CUDA tensor, in the meter.
                mean_loss_val.update(loss_val.item(), n=images_val.size(0))

                _, pred = outputs_val.topk(k=3, dim=1, largest=True, sorted=True)
                running_metrics.update(labels_val, pred[:, 0])

                acc1, acc2, acc3 = accuracy(outputs_val, labels_val, topk=(1, 2, 3))
                top1_acc_val.update(acc1, n=images_val.size(0))
                top2_acc_val.update(acc2, n=images_val.size(0))
                top3_acc_val.update(acc3, n=images_val.size(0))

                mapk_v = mapk(labels_val, pred, k=3)
                mapk_val.update(mapk_v, n=images_val.size(0))

        print('Mean Average Precision (MAP) @ 3: {:.5f}'.format(mapk_val.avg))
        print('Top 3 accuracy: {:7.3f} / {:7.3f} / {:7.3f}'.format(top1_acc_val.avg, top2_acc_val.avg, top3_acc_val.avg))
        print('Mean val loss: {:.4f}'.format(mean_loss_val.avg))

        score, class_iou = running_metrics.get_scores()

        for k, v in score.items():
            print(k, v)

        # Reduce LR when mean validation loss plateaus.
        scheduler.step(mean_loss_val.avg)
        state = {'epoch': epoch+1,
                 'model_state': model.state_dict(),
                 'optimizer_state': optimizer.state_dict(),
                 'mapk': mapk_val.avg,
                 'top1_acc': top1_acc_val.avg,
                 'top2_acc': top2_acc_val.avg,
                 'top3_acc': top3_acc_val.avg,}
        torch.save(state, "checkpoints/{}_{}_{}_{}x{}_{}-{}-{}_model.pth".format(args.arch, args.dataset, epoch+1, args.img_rows, args.img_cols, args.train_fold_num, args.num_train_folds, args.num_val))

        running_metrics.reset()
        mapk_val.reset()
        top1_acc_val.reset()
        top2_acc_val.reset()
        top3_acc_val.reset()
        mean_loss_val.reset()

        elapsed_train_time = timeit.default_timer() - start_train_time
        print('Training time (epoch {0:5d}): {1:10.5f} seconds'.format(epoch+1, elapsed_train_time))
예제 #3
0
def merge(args):
    """Average per-fold probability maps and build a merged RLE submission.

    Loads every ``*.npy`` probability file in ``args.root_results``,
    averages them, thresholds at ``args.thresh`` into binary masks,
    empties masks whose pixel sum falls below a threshold derived from
    ``args.non_empty_ratio``, optionally scores against ground truth, and
    writes the merged predictions as an RLE-encoded CSV.
    """
    if not os.path.exists(args.root_results):
        os.makedirs(args.root_results)

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path, split=args.split, transforms=None, fold_num=0, num_folds=1, no_gt=args.no_gt, seed=args.seed, no_load_images=True)

    n_classes = loader.n_classes
    testloader = data.DataLoader(loader, batch_size=args.batch_size)#, num_workers=2, pin_memory=True)

    num_samples = len(loader)
    avg_y_prob = np.zeros((num_samples, 1, 1024, 1024), dtype=np.float32)
    fold_list = []
    for prob_file_name in glob.glob(os.path.join(args.root_results, '*.npy')):
        # mmap keeps only the row being accumulated in memory; the arrays
        # are ~num_samples x 1024 x 1024 floats.
        prob = np.load(prob_file_name, mmap_mode='r')
        for i in range(num_samples):
            avg_y_prob[i, :, :, :] += prob[i, :, :, :]
        fold_list.append(prob_file_name)
        print(prob_file_name)
    avg_y_prob = avg_y_prob / len(fold_list)

    # `np.int` was removed in NumPy 1.24; the builtin `int` is equivalent.
    avg_y_pred = (avg_y_prob > args.thresh).astype(int)
    avg_y_pred_sum = avg_y_pred.sum(3).sum(2).sum(1)

    # Keep the top `non_empty_ratio` fraction of masks by pixel count;
    # everything at or below the resulting pixel-sum threshold is emptied.
    avg_y_pred_sum_argsorted = np.argsort(avg_y_pred_sum)[::-1]
    pruned_idx = int(avg_y_pred_sum_argsorted.shape[0]*args.non_empty_ratio)
    mask_sum_thresh = int(avg_y_pred_sum[avg_y_pred_sum_argsorted[pruned_idx]]) if pruned_idx < avg_y_pred_sum_argsorted.shape[0] else 0

    running_metrics = runningScore(n_classes=2, weight_acc_non_empty=args.weight_acc_non_empty)

    pred_dict = collections.OrderedDict()
    num_non_empty_masks = 0
    for i, (_, labels, names) in tqdm(enumerate(testloader)):
        labels = labels.cuda()

        prob = avg_y_prob[i*args.batch_size:i*args.batch_size+labels.size(0), :, :, :]
        pred = (prob > args.thresh).astype(int)
        pred = torch.from_numpy(pred).long().cuda()

        pred_sum = pred.sum(3).sum(2).sum(1)
        for k in range(labels.size(0)):
            if pred_sum[k] > mask_sum_thresh:
                num_non_empty_masks += 1
            else:
                pred[k, :, :, :] = torch.zeros_like(pred[k, :, :, :])
                if args.only_non_empty:
                    # Single marker pixel keeps the RLE non-empty.
                    pred[k, :, 0, 0] = 1

        if not args.no_gt:
            running_metrics.update(labels.long(), pred.long())

        for k in range(labels.size(0)):
            name = names[0][k]
            if pred_dict.get(name, None) is None:
                mask = pred[k, 0, :, :].cpu().numpy()
                rle = loader.mask2rle(mask)
                pred_dict[name] = rle

    print('# non-empty masks: {:5d} (non_empty_ratio: {:.5f} / mask_sum_thresh: {:6d})'.format(num_non_empty_masks, args.non_empty_ratio, mask_sum_thresh))
    if not args.no_gt:
        dice, dice_empty, dice_non_empty, miou, acc, acc_empty, acc_non_empty = running_metrics.get_scores()
        print('Dice (per image): {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.format(dice, dice_empty, dice_non_empty))
        print('Classification accuracy: {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.format(acc, acc_empty, acc_non_empty))
        print('Overall mIoU: {:.5f}'.format(miou))
    running_metrics.reset()

    # Create submission
    csv_file_name = 'merged_{}_{}_{}_{}'.format(args.split, len(fold_list), args.thresh, args.non_empty_ratio)
    sub = pd.DataFrame.from_dict(pred_dict, orient='index')
    sub.index.names = ['ImageId']
    sub.columns = ['EncodedPixels']
    sub.to_csv(os.path.join(args.root_results, '{}.csv'.format(csv_file_name)))
def test(args):
    """Run inference (optionally with flip TTA), score, and write a CSV.

    Loads a checkpoint into a MobileNetV2 classifier, predicts top-3
    classes per image, optionally averages softmax probabilities with a
    horizontally-flipped pass (``args.tta``), reports MAP@3 when ground
    truth is available, and writes a key_id/word submission CSV.
    """
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path,
                         split=args.split,
                         is_transform=True,
                         img_size=(args.img_rows, args.img_cols),
                         no_gt=args.no_gt,
                         seed=args.seed)

    # Seed every RNG source for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    n_classes = loader.n_classes
    testloader = data.DataLoader(loader,
                                 batch_size=args.batch_size,
                                 num_workers=4,
                                 pin_memory=True)

    # Setup Model: must mirror the training-time architecture.
    model = torchvision.models.mobilenet_v2(pretrained=True)
    num_ftrs = model.last_channel
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(num_ftrs, n_classes),
    )
    model.cuda()

    checkpoint = torch.load(args.model_path)
    state = convert_state_dict(checkpoint['model_state'])
    model_dict = model.state_dict()
    model_dict.update(state)
    model.load_state_dict(model_dict)

    print(
        "Loaded checkpoint '{}' (epoch {}, mapk {:.5f}, top1_acc {:7.3f}, top2_acc {:7.3f} top3_acc {:7.3f})"
        .format(args.model_path, checkpoint['epoch'], checkpoint['mapk'],
                checkpoint['top1_acc'], checkpoint['top2_acc'],
                checkpoint['top3_acc']))

    running_metrics = runningScore(n_classes)

    pred_dict = collections.OrderedDict()
    # Named `mapk_meter` (not `mapk`) so the meter does not shadow the
    # mapk() metric function called below.
    mapk_meter = AverageMeter()
    model.eval()
    with torch.no_grad():
        for i, (images, labels, _, names) in tqdm(enumerate(testloader)):
            # NOTE(review): debug visualization — plt.show() blocks on
            # every batch until the window is closed; remove for batch runs.
            plt.imshow((images[0].numpy().transpose(1, 2, 0) -
                        np.min(images[0].numpy().transpose(1, 2, 0))) /
                       (np.max(images[0].numpy().transpose(1, 2, 0) -
                               np.min(images[0].numpy().transpose(1, 2, 0)))))

            plt.show()
            images = images.cuda()
            if args.tta:
                images_flip = flip(images, dim=3)

            outputs = model(images)
            if args.tta:
                outputs_flip = model(images_flip)

            prob = F.softmax(outputs, dim=1)
            if args.tta:
                # Average probabilities of the original and flipped pass.
                prob_flip = F.softmax(outputs_flip, dim=1)
                prob = (prob + prob_flip) / 2.0

            _, pred = prob.topk(k=3, dim=1, largest=True, sorted=True)
            for k in range(images.size(0)):
                pred_dict[int(names[0][k])] = loader.encode_pred_name(
                    pred[k, :])

            if not args.no_gt:
                running_metrics.update(labels, pred)

                mapk_val = mapk(labels, pred, k=3)
                mapk_meter.update(mapk_val, n=images.size(0))

        print('Mean Average Precision (MAP) @ 3: {:.5f}'.format(mapk_meter.avg))
    if not args.no_gt:
        print('Mean Average Precision (MAP) @ 3: {:.5f}'.format(mapk_meter.avg))

        score, class_iou = running_metrics.get_scores()

        for k, v in score.items():
            print(k, v)

        running_metrics.reset()
        mapk_meter.reset()

    # Create submission
    sub = pd.DataFrame.from_dict(pred_dict, orient='index')
    sub.index.names = ['key_id']
    sub.columns = ['word']
    sub.to_csv('{}_{}x{}.csv'.format(args.split, args.img_rows, args.img_cols))
예제 #5
0
def train(args):
    """Train a segmentation model with LR warmup and periodic evaluation.

    Iteration-based loop (``args.n_iter`` iterations) with gradient
    accumulation over ``args.iter_size`` batches, gradual LR warmup
    followed by cosine annealing (``args.num_cycles > 0``) or a
    multi-step schedule, multi-scale segmentation loss plus an optional
    classification head loss (``args.lambda_cls``), periodic validation
    with best-dice / best-wacc checkpointing, and an optional wall-clock
    cutoff (``args.saving_last_time``) that saves a resume checkpoint
    and returns early.
    """
    if not os.path.exists('checkpoints'):
        os.mkdir('checkpoints')

    # Setup Augmentations & Transforms
    # Caffe-style pixel-scale stats for GN backbones with pretrained
    # weights; standard torchvision ImageNet stats otherwise.
    rgb_mean = [122.7717 / 255., 115.9465 / 255., 102.9801 /
                255.] if args.norm_type == 'gn' and args.load_pretrained else [
                    0.485, 0.456, 0.406
                ]
    rgb_std = [1. / 255., 1. / 255., 1. /
               255.] if args.norm_type == 'gn' and args.load_pretrained else [
                   0.229, 0.224, 0.225
               ]
    data_trans = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(size=(args.img_rows, args.img_cols)),
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           transforms=data_trans,
                           in_channels=args.in_channels,
                           split='train',
                           augmentations=True,
                           fold_num=args.fold_num,
                           num_folds=args.num_folds,
                           only_non_empty=args.only_non_empty,
                           seed=args.seed,
                           mask_dilation_size=args.mask_dilation_size)
    v_loader = data_loader(data_path,
                           transforms=data_trans,
                           in_channels=args.in_channels,
                           split='val',
                           fold_num=args.fold_num,
                           num_folds=args.num_folds,
                           only_non_empty=args.only_non_empty,
                           seed=args.seed)

    # Seed every RNG source for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Shuffle/drop_last only when restricted to non-empty masks; otherwise
    # the loader's own __gen_batchs__ controls batch composition.
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=2,
                                  pin_memory=True,
                                  shuffle=args.only_non_empty,
                                  drop_last=args.only_non_empty)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=2,
                                pin_memory=True)

    # Setup Model
    model = get_model(args.arch,
                      n_classes=1,
                      in_channels=args.in_channels,
                      norm_type=args.norm_type,
                      load_pretrained=args.load_pretrained,
                      use_cbam=args.use_cbam)
    model.to(torch.device(args.device))

    running_metrics = runningScore(
        n_classes=2,
        weight_acc_non_empty=args.weight_acc_non_empty,
        device=args.device)

    # Check if model has custom optimizer / loss
    if hasattr(model, 'optimizer'):
        optimizer = model.optimizer
        # No external schedule when the model supplies its own optimizer.
        scheduler_warmup = None
    else:
        # 5% warmup, then milestones at 30/60/90% of the remaining budget.
        warmup_iter = int(args.n_iter * 5. / 100.)
        milestones = [
            int(args.n_iter * 30. / 100.) - warmup_iter,
            int(args.n_iter * 60. / 100.) - warmup_iter,
            int(args.n_iter * 90. / 100.) - warmup_iter
        ]  # [30, 60, 90]
        gamma = 0.5  #0.1

        if args.optimizer_type == 'sgd':
            optimizer = torch.optim.SGD(group_weight(model),
                                        lr=args.l_rate,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        elif args.optimizer_type == 'adam':
            optimizer = torch.optim.Adam(group_weight(model),
                                         lr=args.l_rate,
                                         weight_decay=args.weight_decay)
        else:  #if args.optimizer_type == 'radam':
            optimizer = RAdam(group_weight(model),
                              lr=args.l_rate,
                              weight_decay=args.weight_decay)

        if args.num_cycles > 0:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=(args.n_iter - warmup_iter) // args.num_cycles,
                eta_min=args.l_rate * 0.1)
        else:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=milestones, gamma=gamma)
        scheduler_warmup = GradualWarmupScheduler(optimizer,
                                                  total_epoch=warmup_iter,
                                                  min_lr_mul=0.1,
                                                  after_scheduler=scheduler)

    start_iter = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device(
                                        args.device))  #, encoding="latin1")

            model_dict = model.state_dict()
            # Checkpoints are either a raw state dict or a wrapper dict
            # with a 'model_state' key plus bookkeeping metadata.
            if checkpoint.get('model_state', None) is not None:
                model_dict.update(convert_state_dict(
                    checkpoint['model_state']))
            else:
                model_dict.update(convert_state_dict(checkpoint))

            start_iter = checkpoint.get('iter', -1)
            dice_val = checkpoint.get('dice', -1)
            wacc_val = checkpoint.get('wacc', -1)
            print("Loaded checkpoint '{}' (iter {}, dice {:.5f}, wAcc {:.5f})".
                  format(args.resume, start_iter, dice_val, wacc_val))

            model.load_state_dict(model_dict)

            if checkpoint.get('optimizer_state', None) is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state'])

            # Free the checkpoint copies before training allocations.
            del model_dict
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("No checkpoint found at '{}'".format(args.resume))
    start_iter = args.start_iter if args.start_iter >= 0 else start_iter

    # Deep-supervision weights: full weight on the finest scale.
    scale_weight = torch.tensor([1.0, 0.4, 0.4,
                                 0.4]).to(torch.device(args.device))
    dice_weight = [args.dice_weight0, args.dice_weight1]
    lv_margin = [args.lv_margin0, args.lv_margin1]
    total_loss_sum = 0.0
    ms_loss_sum = 0.0
    cls_loss_sum = 0.0
    t_loader.__gen_batchs__(args.batch_size, ratio=args.ratio)
    trainloader_iter = iter(trainloader)
    optimizer.zero_grad()
    start_train_time = timeit.default_timer()
    elapsed_train_time = 0.0
    best_dice = -100.0
    best_wacc = -100.0
    for i in range(start_iter, args.n_iter):
        model.train()

        # Advance the schedule once per optimizer step, not per batch.
        if i % args.iter_size == 0 and scheduler_warmup is not None:
            if args.num_cycles == 0:
                scheduler_warmup.step(i)
            else:
                scheduler_warmup.step(i // args.num_cycles)

        try:
            images, labels, _ = next(trainloader_iter)
        except StopIteration:
            # Epoch exhausted: regenerate the batch composition and restart.
            t_loader.__gen_batchs__(args.batch_size, ratio=args.ratio)
            trainloader_iter = iter(trainloader)
            images, labels, _ = next(trainloader_iter)

        images = images.to(torch.device(args.device))
        labels = labels.to(torch.device(args.device))

        outputs, outputs_gap = model(images)

        # Per-image emptiness label (1 iff the mask has any positive pixel)
        # for the auxiliary classification head.
        labels_gap = torch.where(
            labels.sum(3, keepdim=True).sum(2, keepdim=True) > 0,
            torch.ones(labels.size(0), 1, 1, 1).to(torch.device(args.device)),
            torch.zeros(labels.size(0), 1, 1, 1).to(torch.device(args.device)))
        cls_loss = F.binary_cross_entropy_with_logits(
            outputs_gap,
            labels_gap) if args.lambda_cls > 0 else torch.tensor(0.0).to(
                labels.device)
        ms_loss = multi_scale_loss(outputs,
                                   labels,
                                   scale_weight=scale_weight,
                                   reduction='mean',
                                   alpha=args.alpha,
                                   gamma=args.gamma,
                                   dice_weight=dice_weight,
                                   lv_margin=lv_margin,
                                   lambda_fl=args.lambda_fl,
                                   lambda_dc=args.lambda_dc,
                                   lambda_lv=args.lambda_lv)
        total_loss = ms_loss + args.lambda_cls * cls_loss
        total_loss = total_loss / float(args.iter_size)
        total_loss.backward()
        total_loss_sum = total_loss_sum + total_loss.item()
        ms_loss_sum = ms_loss_sum + ms_loss.item()
        cls_loss_sum = cls_loss_sum + cls_loss.item()

        if (i + 1) % args.print_train_freq == 0:
            print("Iter [%7d/%7d] Loss: %7.4f (MS: %7.4f / CLS: %7.4f)" %
                  (i + 1, args.n_iter, total_loss_sum, ms_loss_sum,
                   cls_loss_sum))

        if (i + 1) % args.iter_size == 0:
            optimizer.step()
            optimizer.zero_grad()
            total_loss_sum = 0.0
            ms_loss_sum = 0.0
            cls_loss_sum = 0.0

        if args.eval_freq > 0 and (i + 1) % args.eval_freq == 0:
            state = {
                'iter': i + 1,
                'model_state': model.state_dict(),
            }
            if (i + 1) % int(args.eval_freq / args.save_freq) == 0:
                torch.save(
                    state, "checkpoints/{}_{}_{}_{}x{}_{}-{}_model.pth".format(
                        args.arch, args.dataset, i + 1, args.img_rows,
                        args.img_cols, args.fold_num, args.num_folds))

            dice_val = 0.0
            thresh = 0.5
            mask_sum_thresh = 0
            mean_loss_val = AverageMeter()
            model.eval()
            with torch.no_grad():
                for i_val, (images_val, labels_val, _) in enumerate(valloader):
                    images_val = images_val.to(torch.device(args.device))
                    labels_val = labels_val.to(torch.device(args.device))

                    outputs_val, outputs_gap_val = model(images_val)
                    # F.sigmoid is deprecated; torch.sigmoid is identical.
                    pred_val = (torch.sigmoid(outputs_val if not isinstance(
                        outputs_val, tuple) else outputs_val[0]) >
                                thresh).long()  #outputs_val.max(1)[1]

                    pred_val_sum = pred_val.sum(3).sum(2).sum(1)
                    for k in range(labels_val.size(0)):
                        if pred_val_sum[k] < mask_sum_thresh:
                            pred_val[k, :, :, :] = torch.zeros_like(
                                pred_val[k, :, :, :])

                    labels_gap_val = torch.where(
                        labels_val.sum(3, keepdim=True).sum(2, keepdim=True) >
                        0,
                        torch.ones(labels_val.size(0), 1, 1,
                                   1).to(torch.device(args.device)),
                        torch.zeros(labels_val.size(0), 1, 1,
                                    1).to(torch.device(args.device)))
                    cls_loss_val = F.binary_cross_entropy_with_logits(
                        outputs_gap_val, labels_gap_val
                    ) if args.lambda_cls > 0 else torch.tensor(0.0).to(
                        labels_val.device)
                    ms_loss_val = multi_scale_loss(outputs_val,
                                                   labels_val,
                                                   scale_weight=scale_weight,
                                                   reduction='mean',
                                                   alpha=args.alpha,
                                                   gamma=args.gamma,
                                                   dice_weight=dice_weight,
                                                   lv_margin=lv_margin,
                                                   lambda_fl=args.lambda_fl,
                                                   lambda_dc=args.lambda_dc,
                                                   lambda_lv=args.lambda_lv)
                    loss_val = ms_loss_val + args.lambda_cls * cls_loss_val
                    mean_loss_val.update(loss_val.item(), n=labels_val.size(0))

                    running_metrics.update(labels_val.long(), pred_val.long())

            dice_val, dice_empty_val, dice_non_empty_val, miou_val, wacc_val, acc_empty_val, acc_non_empty_val = running_metrics.get_scores(
            )
            print(
                'Dice (per image): {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.
                format(dice_val, dice_empty_val, dice_non_empty_val))
            print('wAcc: {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.format(
                wacc_val, acc_empty_val, acc_non_empty_val))
            print('Overall mIoU: {:.5f}'.format(miou_val))
            print('Mean val loss: {:.4f}'.format(mean_loss_val.avg))
            state['dice'] = dice_val
            state['wacc'] = wacc_val
            state['miou'] = miou_val
            running_metrics.reset()
            mean_loss_val.reset()

            # Re-save the periodic checkpoint with metrics attached, plus
            # rolling best-dice / best-wacc checkpoints.
            if (i + 1) % int(args.eval_freq / args.save_freq) == 0:
                torch.save(
                    state, "checkpoints/{}_{}_{}_{}x{}_{}-{}_model.pth".format(
                        args.arch, args.dataset, i + 1, args.img_rows,
                        args.img_cols, args.fold_num, args.num_folds))
            if best_dice <= dice_val:
                best_dice = dice_val
                torch.save(
                    state, "checkpoints/{}_{}_{}_{}x{}_{}-{}_model.pth".format(
                        args.arch, args.dataset, 'best-dice', args.img_rows,
                        args.img_cols, args.fold_num, args.num_folds))
            if best_wacc <= wacc_val:
                best_wacc = wacc_val
                torch.save(
                    state, "checkpoints/{}_{}_{}_{}x{}_{}-{}_model.pth".format(
                        args.arch, args.dataset, 'best-wacc', args.img_rows,
                        args.img_cols, args.fold_num, args.num_folds))

            elapsed_train_time = timeit.default_timer() - start_train_time
            print('Training time (iter {0:5d}): {1:10.5f} seconds'.format(
                i + 1, elapsed_train_time))

        # Wall-clock cutoff: save a resumable checkpoint and stop.
        if args.saving_last_time > 0 and (i + 1) % args.iter_size == 0 and (
                timeit.default_timer() -
                start_train_time) > args.saving_last_time:
            state = {
                'iter': i + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(
                state, "checkpoints/{}_{}_{}_{}x{}_{}-{}_model.pth".format(
                    args.arch, args.dataset, i + 1, args.img_rows,
                    args.img_cols, args.fold_num, args.num_folds))
            return

    print('best_dice: {:.5f}; best_wacc: {:.5f}'.format(best_dice, best_wacc))
예제 #6
0
def test(args):
    if not os.path.exists(args.root_results):
        os.makedirs(args.root_results)

    model_file_name = os.path.split(args.model_path)[1]
    model_name = model_file_name[:model_file_name.find('_')]

    # Setup Transforms
    rgb_mean = [122.7717 / 255., 115.9465 / 255., 102.9801 /
                255.] if args.norm_type == 'gn' and args.load_pretrained else [
                    0.485, 0.456, 0.406
                ]
    rgb_std = [1. / 255., 1. / 255., 1. /
               255.] if args.norm_type == 'gn' and args.load_pretrained else [
                   0.229, 0.224, 0.225
               ]
    data_trans = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(size=(args.img_rows, args.img_cols)),
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path,
                         split=args.split,
                         in_channels=args.in_channels,
                         transforms=data_trans,
                         fold_num=args.fold_num,
                         num_folds=args.num_folds,
                         no_gt=args.no_gt,
                         seed=args.seed)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    n_classes = loader.n_classes
    testloader = data.DataLoader(
        loader, batch_size=args.batch_size)  #, num_workers=2, pin_memory=True)

    # Setup Model
    model = get_model(model_name,
                      n_classes=1,
                      in_channels=args.in_channels,
                      norm_type=args.norm_type,
                      use_cbam=args.use_cbam)
    model.cuda()

    checkpoint = torch.load(args.model_path)  #, encoding="latin1")
    state = convert_state_dict(checkpoint['model_state'])
    model_dict = model.state_dict()
    model_dict.update(state)
    model.load_state_dict(model_dict)

    saved_iter = checkpoint.get('iter', -1)
    dice_val = checkpoint.get('dice', -1)
    wacc_val = checkpoint.get('wacc', -1)
    print("Loaded checkpoint '{}' (iter {}, dice {:.5f}, wAcc {:.5f})".format(
        args.model_path, saved_iter, dice_val, wacc_val))

    running_metrics = runningScore(
        n_classes=2, weight_acc_non_empty=args.weight_acc_non_empty)

    y_prob = np.zeros((loader.__len__(), 1, 1024, 1024), dtype=np.float32)
    y_pred_sum = np.zeros((loader.__len__(), ), dtype=np.int32)
    pred_dict = collections.OrderedDict()
    num_non_empty_masks = 0
    model.eval()
    with torch.no_grad():
        for i, (images, labels, _) in tqdm(enumerate(testloader)):
            images = images.cuda()
            labels = labels.cuda()
            if args.tta:
                bs, c, h, w = images.size()
                images = torch.cat(
                    [images, torch.flip(images, dims=[3])], dim=0)  # hflip

            outputs = model(images, return_aux=False)
            prob = F.sigmoid(outputs)
            if args.tta:
                prob = prob.view(-1, bs, 1, h, w)
                prob[1, :, :, :, :] = torch.flip(prob[1, :, :, :, :], dims=[3])
                prob = prob.mean(0)
            pred = (prob > args.thresh).long()
            pred_sum = pred.sum(3).sum(2).sum(1)
            y_prob[i * args.batch_size:i * args.batch_size +
                   labels.size(0), :, :, :] = prob.cpu().numpy()
            y_pred_sum[i * args.batch_size:i * args.batch_size +
                       labels.size(0)] = pred_sum.cpu().numpy()

        y_pred_sum_argsorted = np.argsort(y_pred_sum)[::-1]
        pruned_idx = int(y_pred_sum_argsorted.shape[0] * args.non_empty_ratio)
        mask_sum_thresh = int(
            y_pred_sum[y_pred_sum_argsorted[pruned_idx]]
        ) if pruned_idx < y_pred_sum_argsorted.shape[0] else 0

        for i, (_, labels, names) in tqdm(enumerate(testloader)):
            labels = labels.cuda()

            prob = torch.from_numpy(
                y_prob[i * args.batch_size:i * args.batch_size +
                       labels.size(0), :, :, :]).float().cuda()
            pred = (prob > args.thresh).long()
            pred_sum = pred.sum(3).sum(2).sum(1)
            for k in range(labels.size(0)):
                if pred_sum[k] > mask_sum_thresh:
                    num_non_empty_masks += 1
                else:
                    pred[k, :, :, :] = torch.zeros_like(pred[k, :, :, :])
                    if args.only_non_empty:
                        pred[k, :, 0, 0] = 1

            if not args.no_gt:
                running_metrics.update(labels.long(), pred.long())
            """
            if args.split == 'test':
                for k in range(labels.size(0)):
                    name = names[0][k]
                    if pred_dict.get(name, None) is None:
                        mask = pred[k, 0, :, :].cpu().numpy()
                        rle = loader.mask2rle(mask)
                        pred_dict[name] = rle
            #"""

    print(
        '# non-empty masks: {:5d} (non_empty_ratio: {:.5f} / mask_sum_thresh: {:6d})'
        .format(num_non_empty_masks, args.non_empty_ratio, mask_sum_thresh))
    if not args.no_gt:
        dice, dice_empty, dice_non_empty, miou, wacc, acc_empty, acc_non_empty = running_metrics.get_scores(
        )
        print('Dice (per image): {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.
              format(dice, dice_empty, dice_non_empty))
        print('wAcc: {:.5f} (empty: {:.5f} / non-empty: {:.5f})'.format(
            wacc, acc_empty, acc_non_empty))
        print('Overall mIoU: {:.5f}'.format(miou))
    running_metrics.reset()

    if args.split == 'test':
        fold_num, num_folds = model_file_name.split('_')[4].split('-')
        prob_file_name = 'prob-{}_{}x{}_{}_{}_{}-{}'.format(
            args.split, args.img_rows, args.img_cols, model_name, saved_iter,
            fold_num, num_folds)
        np.save(
            os.path.join(args.root_results, '{}.npy'.format(prob_file_name)),
            y_prob)
        """