Code example #1
0
def main(args):
    """Build the segmentation network and run the training loop,
    checkpointing after every epoch."""
    # Build encoder/decoder sub-networks from the configured architectures.
    model_builder = ModelBuilder()
    encoder = model_builder.build_encoder(
        arch=args.arch_encoder,
        fc_dim=args.fc_dim,
        weights=args.weights_encoder)

    # Per-task class counts; the 'part' count is the total number of parts
    # over all objects in the broden object->parts mapping.
    nr_classes = broden_dataset.nr.copy()
    nr_classes['part'] = sum(len(parts)
                             for parts in broden_dataset.object_part.values())

    decoder = model_builder.build_decoder(
        arch=args.arch_decoder,
        fc_dim=args.fc_dim,
        nr_classes=nr_classes,
        weights=args.weights_decoder)

    # TODO(LYC):: move criterion outside model.
    # crit = nn.NLLLoss(ignore_index=-1)

    # A 'deepsup' decoder takes the deep-supervision loss scale as well.
    if args.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(
            encoder, decoder, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(encoder, decoder)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # create loader iterator
    iterator_train = create_multi_source_train_data_loader(args=args)

    # Move onto GPU(s); sync-BN needs the replication callback.
    if args.num_gpus > 1:
        segmentation_module = UserScatteredDataParallel(
            segmentation_module, device_ids=range(args.num_gpus))
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # One optimizer per sub-network.
    nets = (encoder, decoder)
    optimizers = create_optimizers(nets, args)

    # Training history filled in by train() and persisted by checkpoint().
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, iterator_train, optimizers, history,
              epoch, args)
        # checkpointing
        checkpoint(nets, history, args, epoch)

    print('Training Done!')
Code example #2
0
File: train.py — Project: CSAILVision/unifiedparsing
def main(args):
    """Assemble the segmentation model and train it, saving a checkpoint
    at the end of every epoch."""
    # Instantiate the encoder and decoder sub-networks.
    model_builder = ModelBuilder()
    encoder = model_builder.build_encoder(
        arch=args.arch_encoder,
        fc_dim=args.fc_dim,
        weights=args.weights_encoder)
    decoder = model_builder.build_decoder(
        arch=args.arch_decoder,
        fc_dim=args.fc_dim,
        nr_classes=args.nr_classes,
        weights=args.weights_decoder)

    # TODO(LYC):: move criterion outside model.
    # crit = nn.NLLLoss(ignore_index=-1)

    # A 'deepsup' decoder needs the deep-supervision loss scale.
    if args.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(
            encoder, decoder, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(encoder, decoder)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Multi-source training data iterator.
    iterator_train = create_multi_source_train_data_loader(args=args)

    # Move onto GPU(s); sync-BN requires the replication callback.
    if args.num_gpus > 1:
        segmentation_module = UserScatteredDataParallel(
            segmentation_module, device_ids=range(args.num_gpus))
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # One optimizer per sub-network.
    nets = (encoder, decoder)
    optimizers = create_optimizers(nets, args)

    # History appended to by train() and persisted by checkpoint().
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, iterator_train, optimizers, history,
              epoch, args)
        # checkpointing
        checkpoint(nets, history, args, epoch)

    print('Training Done!')
Code example #3
0
File: train.py — Project: kunlqt/shape-attentive-unet
def main(args):
    """Train the shape-attentive U-Net on the AC17 dataset.

    Builds the U-Net and dual loss, wraps them in a SegmentationModule,
    trains epoch by epoch, evaluates per-class IoU after each epoch, and
    checkpoints on improvement (after a 15-epoch warm-up), every 50 epochs,
    and at the final epoch.
    """
    # Network Builders
    builder = ModelBuilder()

    unet = builder.build_unet(num_class=args.num_class,
                              arch=args.unet_arch,
                              weights=args.weights_unet)

    # Report parameters that were frozen by the builder.
    print("Froze the following layers: ")
    for name, p in unet.named_parameters():
        if not p.requires_grad:
            print(name)
    print()

    crit = DualLoss(mode="train")

    segmentation_module = SegmentationModule(crit, unet)

    # Augmentation: crop + flips/rotation for training, crop only for eval.
    train_augs = Compose([PaddingCenterCrop(256), RandomHorizontallyFlip(),
                          RandomVerticallyFlip(), RandomRotate(180)])
    test_augs = Compose([PaddingCenterCrop(256)])

    # Dataset and Loader. AC17 loads 3D volumes; load2D slices them into
    # 2D samples (and requires the 3D loader).
    dataset_train = AC17(
        root=args.data_root,
        split='train',
        k_split=args.k_split,
        augmentations=train_augs,
        img_norm=args.img_norm)
    ac17_train = load2D(dataset_train, split='train', deform=True)

    loader_train = data.DataLoader(
        ac17_train,
        batch_size=args.batch_size_per_gpu,
        shuffle=True,
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=True)

    dataset_val = AC17(
        root=args.data_root,
        split='val',
        k_split=args.k_split,
        augmentations=test_augs,
        img_norm=args.img_norm)

    ac17_val = load2D(dataset_val, split='val', deform=False)

    loader_val = data.DataLoader(
        ac17_val,
        batch_size=1,
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=5,
        drop_last=True)

    # load nets into gpu
    if len(args.gpus) > 1:
        segmentation_module = UserScatteredDataParallel(
            segmentation_module,
            device_ids=args.gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers.
    # BUG FIX: the original selected (net_encoder, net_decoder, crit) when
    # args.unet was falsy, but those names are never defined in this
    # function (only the U-Net is built above), which raised NameError.
    # Only the U-Net path is supported here.
    nets = (unet, crit)
    optimizers = create_optimizers(nets, args)

    # Main loop. best_val tracks the best epoch/mIoU per class and overall.
    history = {'train': {'epoch': [], 'loss': [], 'acc': [], 'jaccard': []}}
    best_val = {'epoch_1': 0, 'mIoU_1': 0,
                'epoch_2': 0, 'mIoU_2': 0,
                'epoch_3': 0, 'mIoU_3': 0,
                'epoch': 0, 'mIoU': 0}

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, loader_train, optimizers, history, epoch, args)
        iou, loss = eval(loader_val, segmentation_module, args, crit)

        # Decide whether this epoch deserves a checkpoint: low loss or any
        # new best per-class / mean IoU.
        ckpted = loss < 0.215
        for c in range(3):
            if iou[c] > best_val['mIoU_%d' % (c + 1)]:
                best_val['epoch_%d' % (c + 1)] = epoch
                best_val['mIoU_%d' % (c + 1)] = iou[c]
                ckpted = True

        mean_iou = (iou[0] + iou[1] + iou[2]) / 3
        if mean_iou > best_val['mIoU']:
            best_val['epoch'] = epoch
            best_val['mIoU'] = mean_iou
            ckpted = True

        # Checkpoint unconditionally every 50 epochs and at the final
        # epoch; otherwise only on improvement, and never during the
        # warm-up (epoch < 15). (The original's trailing print() was
        # unreachable — every branch ended in `continue` — and is removed.)
        if epoch % 50 == 0 or epoch == args.num_epoch:
            checkpoint(nets, history, args, epoch)
        elif epoch >= 15 and ckpted:
            checkpoint(nets, history, args, epoch)

    print('Training Done!')
Code example #4
0
def main(args):
    """Train either an encoder/decoder pair or a U-Net on the AC17 dataset.

    Builds the requested architecture, trains epoch by epoch with an
    active-contour loss, evaluates per-class IoU after each epoch, and
    checkpoints on improvement (after a 15-epoch warm-up), every 50 epochs,
    and at the final epoch.
    """
    # Network Builders
    builder = ModelBuilder()
    net_encoder = None
    net_decoder = None
    unet = None

    if not args.unet:
        net_encoder = builder.build_encoder(
            arch=args.arch_encoder,
            fc_dim=args.fc_dim,
            weights=args.weights_encoder)
        net_decoder = builder.build_decoder(
            arch=args.arch_decoder,
            fc_dim=args.fc_dim,
            num_class=args.num_class,
            weights=args.weights_decoder)
    else:
        unet = builder.build_unet(num_class=args.num_class,
                                  arch=args.unet_arch,
                                  weights=args.weights_unet)

        # Report parameters that were frozen by the builder.
        print("Froze the following layers: ")
        for name, p in unet.named_parameters():
            if not p.requires_grad:
                print(name)
        print()

    crit = ACLoss(mode="train")

    # Deep-supervision decoders take an extra loss scale; the U-Net path
    # always goes through the generic constructor.
    if args.arch_decoder.endswith('deepsup') and not args.unet:
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, crit, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, crit, is_unet=args.unet, unet=unet)

    # Augmentation: crop + flips/rotation for training, crop only for eval.
    train_augs = Compose([PaddingCenterCrop(256), RandomHorizontallyFlip(),
                          RandomVerticallyFlip(), RandomRotate(180)])
    test_augs = Compose([PaddingCenterCrop(256)])

    # Dataset and Loader. AC17 loads 3D volumes; load2D slices them into 2D.
    dataset_train = AC17(
        root=args.data_root,
        split='train',
        k_split=args.k_split,
        augmentations=train_augs,
        img_norm=args.img_norm)
    ac17_train = load2D(dataset_train, split='train', deform=True)

    loader_train = data.DataLoader(
        ac17_train,
        batch_size=args.batch_size_per_gpu,  # we have modified data_parallel
        shuffle=True,
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=True)
    dataset_val = AC17(
        root=args.data_root,
        split='val',
        k_split=args.k_split,
        augmentations=test_augs,
        img_norm=args.img_norm)
    ac17_val = load2D(dataset_val, split='val', deform=False)
    loader_val = data.DataLoader(
        ac17_val,
        batch_size=1,
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=5,
        drop_last=True)

    # load nets into gpu
    if len(args.gpus) > 1:
        segmentation_module = UserScatteredDataParallel(
            segmentation_module,
            device_ids=args.gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # One optimizer per trainable sub-network (plus the criterion, which
    # may hold learnable state in this project).
    nets = (net_encoder, net_decoder, crit) if not args.unet else (unet, crit)
    optimizers = create_optimizers(nets, args)

    # Main loop. best_val tracks the best epoch/mIoU per class and overall.
    history = {'train': {'epoch': [], 'loss': [], 'acc': [], 'jaccard': []}}
    best_val = {'epoch_1': 0, 'mIoU_1': 0,
                'epoch_2': 0, 'mIoU_2': 0,
                'epoch_3': 0, 'mIoU_3': 0,
                'epoch': 0, 'mIoU': 0}

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, loader_train, optimizers, history, epoch, args)
        iou, loss = eval(loader_val, segmentation_module, args, crit)

        # Decide whether this epoch deserves a checkpoint: low loss or any
        # new best per-class / mean IoU.
        ckpted = loss < 0.215
        for c in range(3):
            if iou[c] > best_val['mIoU_%d' % (c + 1)]:
                best_val['epoch_%d' % (c + 1)] = epoch
                best_val['mIoU_%d' % (c + 1)] = iou[c]
                ckpted = True

        mean_iou = (iou[0] + iou[1] + iou[2]) / 3
        if mean_iou > best_val['mIoU']:
            best_val['epoch'] = epoch
            best_val['mIoU'] = mean_iou
            ckpted = True

        # Checkpoint unconditionally every 50 epochs and at the final
        # epoch; otherwise only on improvement, and never during the
        # warm-up (epoch < 15). (The original's trailing print() was
        # unreachable — every branch ended in `continue` — and is removed.)
        if epoch % 50 == 0 or epoch == args.num_epoch:
            checkpoint(nets, history, args, epoch)
        elif epoch >= 15 and ckpted:
            checkpoint(nets, history, args, epoch)

    print('Training Done!')
Code example #5
0
def main():
    """Create the model and start adversarial domain-adaptation training.

    Reads hyper-parameters from the YAML config referenced by the global
    ``args``, builds the encoder/decoder segmentation model, then
    alternates supervised source (GTA5) steps with no-grad target
    (Cityscapes) forward passes, periodically evaluating mIoU on the
    validation split and checkpointing encoder/decoder.
    """
    # Merge the YAML config's 'common' section into the global args.
    # SECURITY: yaml.load without a Loader is unsafe and deprecated;
    # SafeLoader is sufficient for a plain config file.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    for k, v in config['common'].items():
        setattr(args, k, v)
    mkdirs(osp.join("logs/" + args.exp_name))

    logger = create_logger('global_logger', "logs/" + args.exp_name + '/log.txt')
    logger.info('{}'.format(args))
##############################

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))
    logger.info("random_scale {}".format(args.random_scale))
    logger.info("is_training {}".format(args.is_training))

    # All size options are 'h,w' strings.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)
    print(type(input_size_target[1]))
    cudnn.enabled = True
    args.snapshot_dir = args.snapshot_dir + args.exp_name
    tb_logger = SummaryWriter("logs/" + args.exp_name)
##############################

# validation data
    h, w = map(int, args.input_size_test.split(','))
    input_size_test = (h, w)
    h, w = map(int, args.com_size.split(','))
    com_size = (h, w)
    h, w = map(int, args.input_size_crop.split(','))
    input_size_crop = (h, w)
    h, w = map(int, args.input_size_target_crop.split(','))
    input_size_target_crop = (h, w)

    # ImageNet normalization for evaluation images.
    test_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
    test_transform = transforms.Compose([
        transforms.Resize((input_size_test[1], input_size_test[0])),
        transforms.ToTensor(),
        test_normalize])

    valloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target,
                          args.data_list_target_val,
                          crop_size=input_size_test,
                          set='train',
                          transform=test_transform),
        num_workers=args.num_workers,
        batch_size=1, shuffle=False, pin_memory=True)
    with open('./dataset/cityscapes_list/info.json', 'r') as fp:
        info = json.load(fp)
    # np.int / np.str were removed in NumPy 1.24; use the builtins.
    mapping = np.array(info['label2train'], dtype=int)
    label_path_list_val = args.label_path_list_val
    # NOTE(review): the test label list is hard-coded, overriding
    # args.label_path_list_test — confirm this is intentional.
    label_path_list_test = './dataset/cityscapes_list/label.txt'
    # Use context managers so the list files are closed promptly
    # (the original leaked the file handles).
    with open(label_path_list_val, 'r') as fh:
        gt_imgs_val = fh.read().splitlines()
    gt_imgs_val = [osp.join(args.data_dir_target_val, x) for x in gt_imgs_val]
    # NOTE(review): testloader and gt_imgs_test are built but never used in
    # this function — kept for parity with the original.
    testloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target,
                          args.data_list_target_test,
                          crop_size=input_size_test,
                          set='val',
                          transform=test_transform),
        num_workers=args.num_workers,
        batch_size=1,
        shuffle=False, pin_memory=True)

    with open(label_path_list_test, 'r') as fh:
        gt_imgs_test = fh.read().splitlines()
    gt_imgs_test = [osp.join(args.data_dir_target_test, x) for x in gt_imgs_test]

    name_classes = np.array(info['label'], dtype=str)
    interp_val = nn.Upsample(size=(com_size[1], com_size[0]),
                             mode='bilinear', align_corners=True)

    ####
    # build model
    ####
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(
        arch=args.arch_encoder,
        fc_dim=args.fc_dim,
        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(
        arch=args.arch_decoder,
        fc_dim=args.fc_dim,
        num_class=args.num_classes,
        weights=args.weights_decoder,
        use_aux=True)

    model = SegmentationModule(
        net_encoder, net_decoder, args.use_aux)

    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model)
        patch_replication_callback(model)
    model.cuda()

    # Four optimizer slots; disc/reconst are unused in this script.
    nets = (net_encoder, net_decoder, None, None)
    optimizers = create_optimizer(nets, args)
    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    source_normalize = transforms_seg.Normalize(mean=mean,
                                                std=std)

    # Padding value for crops, expressed in 0-255 pixel space.
    mean_mapping = [0.485, 0.456, 0.406]
    mean_mapping = [item * 255 for item in mean_mapping]

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    source_transform = transforms_seg.Compose([
        transforms_seg.Resize([input_size[1], input_size[0]]),
        segtransforms.RandScale((args.scale_min, args.scale_max)),
        segtransforms.RandomHorizontalFlip(),
        segtransforms.Crop([input_size_crop[1], input_size_crop[0]],
                           crop_type='rand', padding=mean_mapping,
                           ignore_label=args.ignore_label),
        transforms_seg.ToTensor(),
        source_normalize])
    target_normalize = transforms_seg.Normalize(mean=mean,
                                                std=std)
    target_transform = transforms_seg.Compose([
        transforms_seg.Resize([input_size_target[1], input_size_target[0]]),
        segtransforms.RandScale((args.scale_min, args.scale_max)),
        segtransforms.RandomHorizontalFlip(),
        segtransforms.Crop([input_size_target_crop[1], input_size_target_crop[0]],
                           crop_type='rand', padding=mean_mapping,
                           ignore_label=args.ignore_label),
        transforms_seg.ToTensor(),
        target_normalize])
    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list,
                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size, transform=source_transform),
        batch_size=args.batch_size, shuffle=True, num_workers=1, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        fake_cityscapesDataSet(args.data_dir_target, args.data_list_target,
                               max_iters=args.num_steps * args.iter_size * args.batch_size,
                               crop_size=input_size_target,
                               set=args.set,
                               transform=target_transform),
        batch_size=args.batch_size, shuffle=True, num_workers=1,
        pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    # implement model.optim_parameters(args) to handle different models' lr setting

    # 'reduce=False' is deprecated; reduction='none' is the equivalent.
    criterion_seg = torch.nn.CrossEntropyLoss(ignore_index=255, reduction='none')
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                align_corners=True, mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    optimizer_encoder, optimizer_decoder, optimizer_disc, optimizer_reconst = optimizers
    batch_time = AverageMeter(10)
    loss_seg_value1 = AverageMeter(10)
    is_best_test = True
    best_mIoUs = 0
    loss_seg_value2 = AverageMeter(10)
    loss_balance_value = AverageMeter(10)
    loss_pseudo_value = AverageMeter(10)
    bounding_num = AverageMeter(10)
    pseudo_num = AverageMeter(10)

    for i_iter in range(args.num_steps):
        end = time.time()

        # --- supervised step on source data ---
        _, batch = trainloader_iter.__next__()
        images, labels, _ = batch
        # BUG FIX: `.cuda(async=True)` is a SyntaxError on Python 3.7+
        # ('async' became a keyword); PyTorch's replacement is non_blocking.
        images = Variable(images).cuda(non_blocking=True)
        labels = Variable(labels).cuda(non_blocking=True)
        seg, aux_seg, loss_seg2, loss_seg1 = model(images, labels)

        loss_seg2 = torch.mean(loss_seg2)
        loss_seg1 = torch.mean(loss_seg1)
        loss = loss_seg2 + args.lambda_seg * loss_seg1
        loss_seg_value2.update(loss_seg2.data.cpu().numpy())

        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        del seg, loss_seg2

        # --- gradient-free forward pass on target data ---
        _, batch = targetloader_iter.__next__()
        with torch.no_grad():
            images, labels, _ = batch
            images = Variable(images).cuda(non_blocking=True)
            result = model(images, None)
            del result

        batch_time.update(time.time() - end)

        # ETA estimate from the rolling average step time.
        remain_iter = args.num_steps - i_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        adjust_learning_rate(optimizer_encoder, i_iter, args.lr_encoder, args)
        adjust_learning_rate(optimizer_decoder, i_iter, args.lr_decoder, args)
        if i_iter % args.print_freq == 0:
            lr_encoder = optimizer_encoder.param_groups[0]['lr']
            lr_decoder = optimizer_decoder.param_groups[0]['lr']
            logger.info('exp = {}'.format(args.snapshot_dir))
            logger.info('Iter = [{0}/{1}]\t'
                        'Time = {batch_time.avg:.3f}\t'
                        'loss_seg1 = {loss_seg1.avg:4f}\t'
                        'loss_seg2 = {loss_seg2.avg:.4f}\t'
                        'lr_encoder = {lr_encoder:.8f} lr_decoder = {lr_decoder:.8f}'.format(
                         i_iter, args.num_steps, batch_time=batch_time,
                         loss_seg1=loss_seg_value1, loss_seg2=loss_seg_value2,
                         lr_encoder=lr_encoder,
                         lr_decoder=lr_decoder))

            logger.info("remain_time: {}".format(remain_time))
            if tb_logger is not None:
                tb_logger.add_scalar('loss_seg_value1', loss_seg_value1.avg, i_iter)
                tb_logger.add_scalar('loss_seg_value2', loss_seg_value2.avg, i_iter)
                tb_logger.add_scalar('lr', lr_encoder, i_iter)

            # --- periodic validation and checkpointing ---
            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                logger.info('taking snapshot ...')
                model.eval()

                val_time = time.time()
                hist = np.zeros((19, 19))
                for index, batch in tqdm(enumerate(valloader)):
                    with torch.no_grad():
                        image, name = batch
                        output2, _ = model(Variable(image).cuda(), None)
                        pred = interp_val(output2)
                        del output2
                        pred = pred.cpu().data[0].numpy()
                        pred = pred.transpose(1, 2, 0)
                        pred = np.asarray(np.argmax(pred, axis=2), dtype=np.uint8)
                        label = np.array(Image.open(gt_imgs_val[index]))
                        label = label_mapping(label, mapping)
                        hist += fast_hist(label.flatten(), pred.flatten(), 19)
                mIoUs = per_class_iu(hist)
                for ind_class in range(args.num_classes):
                    logger.info('===>' + name_classes[ind_class] + ':\t' + str(round(mIoUs[ind_class] * 100, 2)))
                    tb_logger.add_scalar(name_classes[ind_class] + '_mIoU', mIoUs[ind_class], i_iter)

                mIoUs = round(np.nanmean(mIoUs) * 100, 2)
                if mIoUs >= best_mIoUs:
                    is_best_test = True
                    best_mIoUs = mIoUs
                else:
                    is_best_test = False

                logger.info("current mIoU {}".format(mIoUs))
                logger.info("best mIoU {}".format(best_mIoUs))
                # The original logged 'val mIoU' twice; once is enough.
                tb_logger.add_scalar('val mIoU', mIoUs, i_iter)
                net_encoder, net_decoder, net_disc, net_reconst = nets
                save_checkpoint(net_encoder, 'encoder', i_iter, args, is_best_test)
                save_checkpoint(net_decoder, 'decoder', i_iter, args, is_best_test)
            model.train()
Code example #6
0
def main(cfg, gpus):
    """Train the encoder/decoder segmentation model described by ``cfg``.

    Builds the two sub-networks, the data loader, and one optimizer per
    sub-network, then runs the epoch loop, logging per-iteration loss and
    accuracy to the module-level TensorBoard ``writer`` and checkpointing
    after every epoch.
    """
    # Network Builders
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)

    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)

    # Per-pixel NLL over log-probabilities; -1 marks unlabeled pixels.
    # (Removed the unused `ct` counter and dead commented-out code for
    # loading pretrained epoch-20 weights / freezing encoder layers.)
    crit = nn.NLLLoss(ignore_index=-1)

    # Deep-supervision decoders take the extra loss scale.
    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit,
                                                 cfg.TRAIN.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit)

    # Dataset and Loader
    dataset_train = TrainDataset(cfg.DATASET.root_dataset,
                                 cfg.DATASET.list_train,
                                 cfg.DATASET,
                                 batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)
    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    if len(gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        train(segmentation_module, iterator_train, optimizers, history,
              epoch + 1, cfg)
        # Log this epoch's tail of the history. The offset assumes train()
        # appends one entry every 20 iterations — TODO confirm the stride.
        n = ((int(cfg.TRAIN.epoch_iters) // 20) * epoch) + 1
        for loss_val, acc_val in zip(history['train']['loss'][n:],
                                     history['train']['acc'][n:]):
            writer.add_scalar('Loss/Train', loss_val, epoch)
            writer.add_scalar('Accuracy/Train', acc_val, epoch)
        # checkpointing
        checkpoint(nets, history, cfg, epoch + 1)
    writer.close()
    print('Training Done!')
Code example #7
0
def _build_criterion(cfg):
    """Select the training loss from cfg.TRAIN.loss_fun and the dataset root.

    Keeps the per-dataset conventions of the original code: the Gleason data
    root supports several losses, ADE20K uses NLL, Cityscapes uses weighted
    OHEM, DeepGlob supports Focal/OHEM, and everything else falls back to
    NLL / cross-entropy with the configured ignore index.
    """
    root = cfg.DATASET.root_dataset
    loss_fun = cfg.TRAIN.loss_fun
    if root == '/scratch0/chenjin/GLEASON2019_DATA/Data/':
        if loss_fun == 'DiceLoss':
            return DiceLoss()
        if loss_fun == 'FocalLoss':
            return FocalLoss()
        if loss_fun == 'DiceCoeff':
            return DiceCoeff()
        if loss_fun == 'NLLLoss':
            return nn.NLLLoss(ignore_index=-2)
        return OhemCrossEntropy(ignore_label=-1,
                                thres=0.9,
                                min_kept=100000,
                                weight=None)
    if 'ADE20K' in root:
        return nn.NLLLoss(ignore_index=-2)
    if 'CITYSCAPES' in root:
        if loss_fun == 'NLLLoss':
            return nn.NLLLoss(ignore_index=19)
        # Per-class weights for the 19 Cityscapes classes.
        class_weights = torch.FloatTensor([
            0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
            0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
            1.0865, 1.1529, 1.0507
        ]).cuda()
        return OhemCrossEntropy(ignore_label=20,
                                thres=0.9,
                                min_kept=131072,
                                weight=class_weights)
    if 'DeepGlob' in root and loss_fun in ('FocalLoss', 'OhemCrossEntropy'):
        if loss_fun == 'FocalLoss':
            return FocalLoss(gamma=6, ignore_label=cfg.DATASET.ignore_index)
        return OhemCrossEntropy(ignore_label=cfg.DATASET.ignore_index,
                                thres=0.9,
                                min_kept=131072)
    # Default: honour cfg.DATASET.ignore_index.  The original special-cased
    # "!= -2" but used -2 in the else branch, so both arms collapse to one call.
    if loss_fun == 'NLLLoss':
        return nn.NLLLoss(ignore_index=cfg.DATASET.ignore_index)
    return nn.CrossEntropyLoss(ignore_index=cfg.DATASET.ignore_index)


def _log_fov_gradients(writer, history, epoch):
    """Write per-layer foveation gradient statistics to tensorboard.

    Replaces the three near-identical copy/paste blocks for layer1/2/3 with
    one loop; the tensorboard tag strings are byte-identical to the originals.
    """
    print_grad = history['train']['print_grad']
    # Only log when train() stored a dict of per-layer gradients.
    if print_grad is None or type(print_grad) is torch.Tensor:
        return
    for layer in ('layer1', 'layer2', 'layer3'):
        grad = print_grad['{}_grad'.format(layer)]
        if grad is None:
            continue
        nonzero = grad[grad > 0]
        if nonzero.numel() == 0:
            continue
        writer.add_histogram(
            'Print non-zero gradient ({}) histogram'.format(layer),
            nonzero, epoch + 1)
        writer.add_histogram(
            'Print gradient ({}) histogram'.format(layer),
            grad, epoch + 1)
        writer.add_scalar(
            'Percentage none-zero gradients ({})'.format(layer),
            nonzero.numel() / grad.numel(), epoch + 1)
        # Min-max normalized gradient map of batch 0 / patch 0 as an image.
        g00 = grad[0][0]
        writer.add_image(
            'Print_grad_Fov_softmax_{}(normalized_b0_p0)'.format(layer),
            (g00 - g00.min()) / (g00.max() - g00.min()),
            epoch + 1,
            dataformats='HW')


def main(cfg, gpus):
    """Train a segmentation model, optionally with a foveation module.

    Builds the encoder/decoder (and foveater when ``cfg.MODEL.foveation``),
    selects the loss, then trains for ``cfg.TRAIN.num_epoch`` epochs with
    periodic checkpointing and validation, logging metrics and foveation
    statistics to tensorboard under ``cfg.DIR``.

    Args:
        cfg: project configuration node (yacs-style); read-only here.
        gpus: list of GPU ids used for data parallelism.
    """
    # Network builders
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder,
        dilate_rate=cfg.DATASET.segm_downsampling_rate)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)
    if cfg.MODEL.foveation:
        net_foveater = ModelBuilder.build_foveater(
            in_channel=cfg.MODEL.in_dim,
            out_channel=len(cfg.MODEL.patch_bank),
            len_gpus=len(gpus),
            weights=cfg.MODEL.weights_foveater,
            cfg=cfg)

    # Tensorboard writer
    writer = SummaryWriter('{}/tensorboard'.format(cfg.DIR))

    crit = _build_criterion(cfg)

    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, cfg,
                                                 cfg.TRAIN.deep_sup_scale)
    else:
        # The original foveation and plain branches built the exact same
        # module, so they are merged here.
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, cfg)

    if cfg.MODEL.foveation:
        foveation_module = FovSegmentationModule(net_foveater,
                                                 cfg,
                                                 len_gpus=len(gpus))
        total_fov = sum(
            param.nelement() for param in foveation_module.parameters())
        print('Number of FoveationModule params: %.2fM \n' % (total_fov / 1e6))

    total = sum(
        param.nelement() for param in segmentation_module.parameters())
    print('Number of SegmentationModule params: %.2fM \n' % (total / 1e6))

    # Dataset and Loader
    dataset_train = TrainDataset(cfg.DATASET.root_dataset,
                                 cfg.DATASET.list_train,
                                 cfg.DATASET,
                                 batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    if len(gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
        if cfg.MODEL.foveation:
            foveation_module = UserScatteredDataParallel(foveation_module,
                                                         device_ids=gpus)
            patch_replication_callback(foveation_module)

    segmentation_module.cuda()
    if cfg.MODEL.foveation:
        foveation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    if cfg.MODEL.foveation:
        nets = (net_encoder, net_decoder, crit, net_foveater)
    optimizers = create_optimizers(nets, cfg)

    # Main loop.  The original had identical dicts in both arms of an
    # `if cfg.VAL.dice:` — a single initialization suffices.
    history = {
        'train': {
            'epoch': [],
            'loss': [],
            'acc': []
        },
        'save': {
            'epoch': [],
            'train_loss': [],
            'train_acc': [],
            'val_iou': [],
            'val_dice': [],
            'val_acc': [],
            'print_grad': None
        }
    }

    if cfg.TRAIN.start_epoch > 0:
        # Resume saved per-epoch metrics so the CSV history stays contiguous.
        history_previous_epoches = pd.read_csv(
            '{}/history_epoch_{}.csv'.format(cfg.DIR, cfg.TRAIN.start_epoch))
        history['save']['epoch'] = list(history_previous_epoches['epoch'])
        history['save']['train_loss'] = list(
            history_previous_epoches['train_loss'])
        history['save']['train_acc'] = list(
            history_previous_epoches['train_acc'])
        history['save']['val_iou'] = list(history_previous_epoches['val_iou'])
        history['save']['val_acc'] = list(history_previous_epoches['val_acc'])

    prob_dir = os.path.join(cfg.DIR, "Fov_probability_distribution")
    if not os.path.isdir(prob_dir):
        os.makedirs(prob_dir)
    for p in range(len(cfg.MODEL.patch_bank)):
        # Truncate (create) the per-patch probability log for this run.
        with open(
                os.path.join(prob_dir,
                             'patch_{}_distribution.txt'.format(p)), 'w'):
            pass

    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        if cfg.MODEL.foveation:
            train(segmentation_module, iterator_train, optimizers, epoch + 1,
                  cfg, history, foveation_module)
            _log_fov_gradients(writer, history, epoch)
        else:
            train(segmentation_module, iterator_train, optimizers, epoch + 1,
                  cfg, history)

        # checkpointing (checkpoint_last runs every epoch)
        if (epoch + 1) % cfg.TRAIN.checkpoint_per_epoch == 0:
            checkpoint(nets, cfg, epoch + 1)
        checkpoint_last(nets, cfg, epoch + 1)

        if (epoch + 1) % cfg.TRAIN.eval_per_epoch == 0:
            # eval during train.  val_dice is only produced on the
            # non-multipro dice path; default it so the history append below
            # cannot hit an unbound name (bug fix).
            val_dice = None
            if cfg.VAL.multipro:
                if cfg.MODEL.foveation:
                    if cfg.VAL.all_F_Xlr_time:
                        val_iou, val_acc, F_Xlr_all, F_Xlr_score_flat_all = eval_during_train_multipro(
                            cfg, gpus)
                    else:
                        val_iou, val_acc, F_Xlr, F_Xlr_score_flat = eval_during_train_multipro(
                            cfg, gpus)
                else:
                    val_iou, val_acc = eval_during_train_multipro(cfg, gpus)
            else:
                if cfg.VAL.dice:
                    if cfg.MODEL.foveation:
                        if cfg.VAL.all_F_Xlr_time:
                            val_iou, val_dice, val_acc, F_Xlr_all, F_Xlr_score_flat_all = eval_during_train(
                                cfg)
                        else:
                            val_iou, val_dice, val_acc, F_Xlr, F_Xlr_score_flat = eval_during_train(
                                cfg)
                    else:
                        val_iou, val_dice, val_acc = eval_during_train(cfg)
                else:
                    if cfg.MODEL.foveation:
                        if cfg.VAL.all_F_Xlr_time:
                            val_iou, val_acc, F_Xlr_all, F_Xlr_score_flat_all = eval_during_train(
                                cfg)
                        else:
                            val_iou, val_acc, F_Xlr, F_Xlr_score_flat = eval_during_train(
                                cfg)
                    else:
                        val_iou, val_acc = eval_during_train(cfg)

            history['save']['epoch'].append(epoch + 1)
            history['save']['train_loss'].append(history['train']['loss'][-1])
            history['save']['train_acc'].append(history['train']['acc'][-1] *
                                                100)
            history['save']['val_iou'].append(val_iou)
            if cfg.VAL.dice:
                history['save']['val_dice'].append(val_dice)
            history['save']['val_acc'].append(val_acc)
            # write to tensorboard
            writer.add_scalar('Loss/train', history['train']['loss'][-1],
                              epoch + 1)
            writer.add_scalar('Acc/train', history['train']['acc'][-1] * 100,
                              epoch + 1)
            writer.add_scalar('Acc/val', val_acc, epoch + 1)
            writer.add_scalar('mIoU/val', val_iou, epoch + 1)
            if cfg.VAL.dice:
                # BUG FIX: the original logged val_acc under 'mDice/val'.
                writer.add_scalar('mDice/val', val_dice, epoch + 1)
            # BUG FIX: F_Xlr_score_flat* only exist on the foveation paths;
            # the original ran this block unconditionally (NameError without
            # foveation).
            if cfg.MODEL.foveation:
                if cfg.VAL.all_F_Xlr_time:
                    print('=============F_Xlr_score_flat_all================\n',
                          F_Xlr_score_flat_all.shape)
                    for p in range(F_Xlr_score_flat_all.shape[0]):
                        # add small artifact to modify range, because no range flag in add_histogram
                        F_Xlr_score_flat_all[p][0] = 0
                        F_Xlr_score_flat_all[p][-1] = 1
                        writer.add_histogram(
                            'Patch_{} probability histogram'.format(p),
                            F_Xlr_score_flat_all[p], epoch + 1)
                        # Compute the histogram once (the original recomputed
                        # it up to three times per patch).
                        counts, bin_edges = np.histogram(
                            F_Xlr_score_flat_all[p], bins=10, range=(0, 1))
                        with open(
                                os.path.join(
                                    cfg.DIR, 'Fov_probability_distribution',
                                    'patch_{}_distribution.txt'.format(p)),
                                'a') as f:
                            if epoch == 0:
                                f.write('epoch/ bins: {}\n'.format(bin_edges))
                            f.write('epoch {}: {}\n'.format(
                                epoch + 1, counts / sum(counts)))
                    writer.add_histogram('Patch_All probability histogram',
                                         F_Xlr_score_flat_all, epoch + 1)
                else:
                    # BUG FIX: the original iterated
                    # range(F_Xlr_score_flat_all.shape[0]) here, but only
                    # F_Xlr_score_flat exists on this path.
                    for p in range(F_Xlr_score_flat.shape[0]):
                        F_Xlr_score_flat[p][0] = 0
                        F_Xlr_score_flat[p][-1] = 1
                        writer.add_histogram(
                            'Patch_{} probability histogram'.format(p),
                            F_Xlr_score_flat[p], epoch + 1)
                    writer.add_histogram('Patch_All probability histogram',
                                         F_Xlr_score_flat, epoch + 1)
        else:
            # No eval this epoch: keep the CSV columns aligned with blanks.
            history['save']['epoch'].append(epoch + 1)
            history['save']['train_loss'].append(history['train']['loss'][-1])
            history['save']['train_acc'].append(history['train']['acc'][-1] *
                                                100)
            history['save']['val_iou'].append('')
            if cfg.VAL.dice:
                history['save']['val_dice'].append('')
            history['save']['val_acc'].append('')
            # write to tensorboard
            writer.add_scalar('Loss/train', history['train']['loss'][-1],
                              epoch + 1)
            writer.add_scalar('Acc/train', history['train']['acc'][-1] * 100,
                              epoch + 1)

        # saving history
        checkpoint_history(history, cfg, epoch + 1)

        if (epoch + 1) % cfg.TRAIN.eval_per_epoch == 0 and cfg.MODEL.foveation:
            # save time series F_Xlr (t,b,d,w,h)
            if epoch == 0 or epoch == cfg.TRAIN.start_epoch:
                if cfg.VAL.all_F_Xlr_time:
                    F_Xlr_time_all = [
                        F_Xlr_all[val_idx][0]
                        for val_idx in range(len(F_Xlr_all))
                    ]
                else:
                    F_Xlr_time = F_Xlr
            else:
                if cfg.VAL.all_F_Xlr_time:
                    for val_idx in range(len(F_Xlr_all)):
                        F_Xlr_time_all[val_idx] = np.concatenate(
                            (F_Xlr_time_all[val_idx],
                             F_Xlr_all[val_idx][0]),
                            axis=0)
                else:
                    F_Xlr_time = np.concatenate((F_Xlr_time, F_Xlr),
                                                axis=0)
            if cfg.VAL.all_F_Xlr_time:
                save_dir = os.path.join(cfg.DIR, "F_Xlr_time_all_vals")
                if not os.path.isdir(save_dir):
                    os.makedirs(save_dir)
                for val_idx in range(len(F_Xlr_all)):
                    print('F_Xlr_time_{}'.format(F_Xlr_all[val_idx][1]),
                          F_Xlr_time_all[val_idx].shape)
                    np.save(
                        '{}/F_Xlr_time_all_vals/F_Xlr_time_last_{}.npy'.
                        format(cfg.DIR, F_Xlr_all[val_idx][1]),
                        F_Xlr_time_all[val_idx])
            else:
                print('F_Xlr_time', F_Xlr_time.shape)
                np.save('{}/F_Xlr_time_last.npy'.format(cfg.DIR),
                        F_Xlr_time)

    if not cfg.TRAIN.save_checkpoint:
        os.remove('{}/encoder_epoch_last.pth'.format(cfg.DIR))
        os.remove('{}/decoder_epoch_last.pth'.format(cfg.DIR))
    print('Training Done!')
    writer.close()
コード例 #8
0
def main(cfg, gpus):
    """Build the attention/memory segmentation networks and train one epoch.

    Constructs a query encoder, a separately-built memory encoder, attention
    heads for both, and a decoder; wraps them in a
    SegmentationAttentionSeparateModule and runs a single training epoch.

    Args:
        cfg: project configuration node (yacs-style); read-only here.
        gpus: list of GPU ids used for data parallelism.
    """
    # Network Builders
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query)
    if cfg.MODEL.memory_encoder_noBN:
        # Memory-encoder variant without batch norm.
        net_enc_memory = ModelBuilder.build_encoder_memory_separate(
            arch=cfg.MODEL.arch_encoder.lower() + '_nobn',
            fc_dim=cfg.MODEL.fc_dim,
            weights=cfg.MODEL.weights_enc_memory,
            num_class=cfg.DATASET.num_class)
    else:
        net_enc_memory = ModelBuilder.build_encoder_memory_separate(
            arch=cfg.MODEL.arch_encoder.lower(),
            fc_dim=cfg.MODEL.fc_dim,
            weights=cfg.MODEL.weights_enc_memory,
            num_class=cfg.DATASET.num_class,
            # NOTE(review): read from cfg root rather than cfg.MODEL —
            # confirm this is the intended config location.
            pretrained=cfg.memory_enc_pretrained)
    net_att_query = ModelBuilder.build_encoder(
        arch='attention',
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_encoder(
        arch='attention',
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)

    crit = nn.NLLLoss(ignore_index=-1)

    # Keyword arguments shared by both decoder variants (the original
    # duplicated this long list in each branch).
    module_kwargs = dict(
        zero_memory=cfg.MODEL.zero_memory,
        random_memory_bias=cfg.MODEL.random_memory_bias,
        random_memory_nobias=cfg.MODEL.random_memory_nobias,
        random_scale=cfg.MODEL.random_scale,
        qval_qread_BN=cfg.MODEL.qval_qread_BN,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        debug=True)
    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationAttentionSeparateModule(
            net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, crit, cfg.TRAIN.deep_sup_scale, **module_kwargs)
    else:
        segmentation_module = SegmentationAttentionSeparateModule(
            net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, crit, **module_kwargs)

    # Dataset and Loader
    dataset_train = TrainDataset(
        cfg.DATASET.root_dataset,
        cfg.DATASET.list_train,
        cfg.DATASET,
        cfg.DATASET.ref_path,
        cfg.DATASET.ref_start,
        cfg.DATASET.ref_end,
        batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)
    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # Load nets into GPU.  The module is always wrapped in data-parallel,
    # even for a single GPU; the previously commented-out conditional
    # wrapping was dead code and has been removed.
    segmentation_module = UserScatteredDataParallel(
        segmentation_module,
        device_ids=gpus)
    # For sync bn
    patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    # Train for exactly one epoch (train() is 1-indexed on epoch).
    for epoch in range(0, 1):
        train(segmentation_module, iterator_train, optimizers, None,
              epoch + 1, cfg)

    print('Training Done!')
コード例 #9
0
File: train.py  Project: rexxxx1234/SAUNet-demo
        # Report any frozen (non-trainable) U-Net parameters for inspection.
        for name, p in unet.named_parameters():
            if p.requires_grad == False:
                print(name)
        print()

        # Loss criterion: DualLoss configured for training mode.
        crit = DualLoss(mode="train")

        segmentation_module = SegmentationModule(crit, unet)

        # load nets into gpu
        if len(args.gpus) > 1:
            segmentation_module = UserScatteredDataParallel(
                segmentation_module,
                device_ids=args.gpus)
            # For sync bn
            patch_replication_callback(segmentation_module)
        segmentation_module.cuda()

        print("ready to load data")

        # NOTE(review): imports placed mid-function; consider moving them to
        # the top of the file.  Hard-coded demo paths below — presumably
        # machine-specific; verify before deployment.
        import train
        import streamlit as st
        from os import listdir
        from os.path import isfile, join
        from PIL import Image
        test_path = '/home/rexma/demo/LCTSC/demo/'
        temp_path = '/home/rexma/demo/LCTSC/temp/'
        result_path = '/home/rexma/demo/LCTSC/result/'
        st.sidebar.title("About")
        st.sidebar.info(
            "This is a demo application written to help you understand Streamlit. The application identifies the animal in the picture. It was built using a Convolution Neural Network (CNN).")
0
def main(args):
    """Train a VGG-based segmentation network on Pascal VOC.

    Builds the encoder/decoder, evaluates on the validation set at the
    configured interval (and on the final epoch), trains one epoch at a
    time, and checkpoints after every epoch.

    Args:
        args: parsed command-line namespace with model, data, and GPU options.
    """
    # Network Builders
    builder = VGGModelBuilder()
    net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(arch=args.arch_decoder,
                                        fc_dim=1024,
                                        num_class=args.num_class,
                                        weights=args.weights_decoder)

    # 255 marks "ignore" pixels in VOC annotations.
    crit = nn.CrossEntropyLoss(ignore_index=255, reduction='sum')

    if args.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit)
    print(segmentation_module)

    # Dataset and Loader
    dataset_train = VOCTrainDataset(args,
                                    batch_per_gpu=args.batch_size_per_gpu)
    loader_train = torchdata.DataLoader(
        dataset_train,
        batch_size=1,  # data_parallel have been modified, not useful
        shuffle=False,  # do not use this param
        collate_fn=user_scattered_collate,
        num_workers=1,  # MUST be 1 or 0
        drop_last=True,
        pin_memory=True)

    # BUG FIX: the timestamp and the iteration count were swapped in the
    # original format() call, printing "[<iters>] ... = <timestamp> iters".
    print('[{}] 1 training epoch = {} iters'.format(
        datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        args.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    if args.num_gpus > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=args.gpu_id)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda(device=args.gpu_id[0])

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, args)

    # Main loop
    history = {
        'train': {
            'epoch': [],
            'loss': [],
            'train_acc': [],
            'test_ious': [],
            'test_mean_iou': []
        }
    }

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        # Re-create the val loader every epoch so drop_last=False is honoured.
        dataset_val = VOCValDataset(args)
        loader_val = torchdata.DataLoader(
            dataset_val,
            # data_parallel have been modified, MUST use val batch size
            #   and collate_fn MUST be user_scattered_collate
            batch_size=args.val_batch_size,
            shuffle=False,
            collate_fn=user_scattered_collate,
            num_workers=1,  # MUST be 1 or 0
            drop_last=False)
        iterator_val = iter(loader_val)
        if epoch % args.test_epoch_interval == 0 or epoch == args.num_epoch:
            cls_ious, cls_mean_iou = evaluate(segmentation_module,
                                              iterator_val, args)
            history['train']['test_ious'].append(cls_ious)
            history['train']['test_mean_iou'].append(cls_mean_iou)
        else:
            # Placeholder entries keep the history aligned on skipped epochs.
            history['train']['test_ious'].append(-1)
            history['train']['test_mean_iou'].append(-1)

        # train one epoch
        train(segmentation_module, iterator_train, optimizers, history, epoch,
              args)

        # checkpointing
        checkpoint(nets, history, args, epoch)

    print('[{}] Training Done!'.format(
        datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
コード例 #11
0
def main():
    """Create the model and start the training.

    Domain-adaptation training: supervised segmentation on the source
    (GTA5) set, plus pseudo-label / class-balance losses on the target
    (Cityscapes) set, with periodic validation and checkpointing.
    Reads configuration from the module-level ``args``.
    """
    with open(args.config) as f:
        # PyYAML >= 5.1 requires an explicit Loader; FullLoader reproduces
        # the historical default behaviour of yaml.load(f).
        config = yaml.load(f, Loader=yaml.FullLoader)
    for k, v in config['common'].items():
        setattr(args, k, v)
    mkdirs(osp.join("logs/"+args.exp_name))

    logger = create_logger('global_logger', "logs/" + args.exp_name + '/log.txt')
    logger.info('{}'.format(args))
##############################

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))
    logger.info("random_scale {}".format(args.random_scale))
    logger.info("is_training {}".format(args.is_training))

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)
    print(type(input_size_target[1]))
    cudnn.enabled = True
    args.snapshot_dir = args.snapshot_dir + args.exp_name
    tb_logger = SummaryWriter("logs/"+args.exp_name)
##############################

#validation data
    # Spatial class-prior over the 19 Cityscapes classes, normalised to a
    # per-pixel probability distribution of shape (1, 19, 512, 1024).
    local_array = np.load("local.npy")
    local_array = local_array[:,:,:19]
    local_array = local_array / local_array.sum(2).reshape(512, 1024, 1)
    local_array = local_array.transpose(2,0,1)
    local_array = torch.from_numpy(local_array)
    local_array = local_array.view(1, 19, 512, 1024)
    h, w = map(int, args.input_size_test.split(','))
    input_size_test = (h,w)
    h, w = map(int, args.com_size.split(','))
    com_size = (h, w)
    h, w = map(int, args.input_size_crop.split(','))
    input_size_crop = h,w
    h,w = map(int, args.input_size_target_crop.split(','))
    input_size_target_crop = h,w

    # ImageNet channel statistics used by the pretrained encoder.
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]


    normalize_module = transforms_seg.Normalize(mean=mean,
                                                std=std)

    test_normalize = transforms.Normalize(mean=mean,
                                          std=std)

    test_transform = transforms.Compose([
                         transforms.Resize((input_size_test[1], input_size_test[0])),
                         transforms.ToTensor(),
                         test_normalize])

    valloader = data.DataLoader(cityscapesDataSet(
                                       args.data_dir_target,
                                       args.data_list_target_val,
                                       crop_size=input_size_test,
                                       set='train',
                                       transform=test_transform),num_workers=args.num_workers,
                                 batch_size=1, shuffle=False, pin_memory=True)
    with open('./dataset/cityscapes_list/info.json', 'r') as fp:
        info = json.load(fp)
    # np.int / np.str were deprecated aliases for the builtins and were
    # removed in NumPy 1.24; use the builtin types directly.
    mapping = np.array(info['label2train'], dtype=int)
    label_path_list_val = args.label_path_list_val
    # Close the list file deterministically instead of leaking the handle.
    with open(label_path_list_val, 'r') as f_list:
        gt_imgs_val = f_list.read().splitlines()
    gt_imgs_val = [osp.join(args.data_dir_target_val, x) for x in gt_imgs_val]


    name_classes = np.array(info['label'], dtype=str)
    interp_val = nn.Upsample(size=(com_size[1], com_size[0]),mode='bilinear', align_corners=True)

    ####
    #build model
    ####
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(
        arch=args.arch_encoder,
        fc_dim=args.fc_dim,
        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(
        arch=args.arch_decoder,
        fc_dim=args.fc_dim,
        num_class=args.num_classes,
        weights=args.weights_decoder,
        use_aux=True)

    # Per-class weights for the class-balance loss, normalised to sum to 1.
    weighted_softmax = pd.read_csv("weighted_loss.txt", header=None)
    weighted_softmax = weighted_softmax.values
    weighted_softmax = torch.from_numpy(weighted_softmax)
    weighted_softmax = weighted_softmax / torch.sum(weighted_softmax)
    weighted_softmax = weighted_softmax.cuda().float()
    model = SegmentationModule(
        net_encoder, net_decoder, args.use_aux)

    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model)
        # Required for synchronized batch-norm across replicas.
        patch_replication_callback(model)
    model.cuda()

    nets = (net_encoder, net_decoder, None, None)
    optimizers = create_optimizer(nets, args)
    cudnn.enabled=True
    cudnn.benchmark=True
    model.train()



    mean_mapping = [0.485, 0.456, 0.406]
    mean_mapping = [item * 255 for item in mean_mapping]

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    source_transform = transforms_seg.Compose([
                             transforms_seg.Resize([input_size[1], input_size[0]]),
                             transforms_seg.ToTensor(),
                             normalize_module])


    target_transform = transforms_seg.Compose([
                             transforms_seg.Resize([input_size_target[1], input_size_target[0]]),
                             transforms_seg.ToTensor(),
                             normalize_module])
    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size, transform = source_transform),
        batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(fake_cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                     max_iters=args.num_steps * args.iter_size * args.batch_size,
                                                     crop_size=input_size_target,
                                                     set=args.set,
                                                     transform=target_transform),
                                   batch_size=args.batch_size, shuffle=True, num_workers=5,
                                   pin_memory=True)


    targetloader_iter = enumerate(targetloader)
    # implement model.optim_parameters(args) to handle different models' lr setting


    # Per-pixel losses: reduction='none' is the supported spelling of the
    # long-deprecated reduce=False keyword and yields identical tensors.
    criterion_seg = torch.nn.CrossEntropyLoss(ignore_index=255, reduction='none')
    criterion_pseudo = torch.nn.BCEWithLogitsLoss(reduction='none').cuda()
    bce_loss = torch.nn.BCEWithLogitsLoss().cuda()
    criterion_reconst = torch.nn.L1Loss().cuda()
    criterion_soft_pseudo = torch.nn.MSELoss(reduction='none').cuda()
    criterion_box = torch.nn.CrossEntropyLoss(ignore_index=255, reduction='none')
    interp = nn.Upsample(size=(input_size[1], input_size[0]),align_corners=True, mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), align_corners=True, mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1


    optimizer_encoder, optimizer_decoder, optimizer_disc, optimizer_reconst = optimizers
    batch_time = AverageMeter(10)
    loss_seg_value1 = AverageMeter(10)
    best_mIoUs = 0
    best_test_mIoUs = 0
    loss_seg_value2 = AverageMeter(10)
    loss_reconst_source_value = AverageMeter(10)
    loss_reconst_target_value = AverageMeter(10)
    loss_balance_value = AverageMeter(10)
    loss_eq_att_value = AverageMeter(10)
    loss_pseudo_value = AverageMeter(10)
    bounding_num = AverageMeter(10)
    pseudo_num = AverageMeter(10)
    loss_bbx_att_value = AverageMeter(10)

    for i_iter in range(args.num_steps):
        end = time.time()

        # ---- supervised step on a source (GTA5) batch ----
        _, batch = trainloader_iter.__next__()
        images, labels, _ = batch
        # `async` became a reserved keyword in Python 3.7, so .cuda(async=True)
        # is a SyntaxError; non_blocking is the supported name for the same
        # asynchronous host-to-device copy.
        images = Variable(images).cuda(non_blocking=True)
        labels = Variable(labels).cuda(non_blocking=True)
        results  = model(images, labels)
        loss_seg2 = results[-2]
        loss_seg1 = results[-1]

        loss_seg2 = torch.mean(loss_seg2)
        loss_seg1 = torch.mean(loss_seg1)
        loss = args.lambda_trade_off*(loss_seg2+args.lambda_seg * loss_seg1)
        # proper normalization
        loss_seg_value2.update(loss_seg2.data.cpu().numpy())
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        # ---- self-training step on a target (Cityscapes) batch ----
        _, batch = targetloader_iter.__next__()
        images, fake_labels, _ = batch
        images = Variable(images).cuda(non_blocking=True)
        fake_labels = Variable(fake_labels, requires_grad=False).cuda()
        results = model(images, None)
        target_seg = results[0]
        # dim=1 (the class axis) matches the implicit default softmax dim
        # for 4-D inputs and silences the deprecation warning.
        conf_tea, pseudo_label = torch.max(nn.functional.softmax(target_seg, dim=1), dim=1)
        pseudo_label = pseudo_label.detach()
        # Hard pseudo-label loss, masked to confident, non-ignored pixels.
        loss_pseudo = criterion_seg(target_seg, pseudo_label)
        fake_mask = (fake_labels!=255).float().detach()
        conf_mask = torch.gt(conf_tea, args.conf_threshold).float().detach()
        loss_pseudo = loss_pseudo * conf_mask.detach() * fake_mask.detach()
        loss_pseudo = loss_pseudo.view(-1)
        loss_pseudo = loss_pseudo[loss_pseudo!=0]
        # Class-balance loss against the precomputed class prior.
        predict_class_mean = torch.mean(nn.functional.softmax(target_seg, dim=1), dim=0).mean(1).mean(1)
        equalise_cls_loss = robust_binary_crossentropy(predict_class_mean, weighted_softmax)
        equalise_cls_loss = torch.mean(equalise_cls_loss)

        # Bounding-box (pooled) attention losses at several window sizes.
        loss_bbx_att = []
        loss_eq_att = []
        for box_idx, box_size in enumerate(args.box_size):
            pooling = torch.nn.AvgPool2d(box_size)
            pooling_result_i = pooling(target_seg)
            local_i = pooling(local_array).float().cuda()
            pooling_conf_mask, pooling_pseudo = torch.max(nn.functional.softmax(pooling_result_i, dim=1), dim=1)
            pooling_conf_mask = torch.gt(pooling_conf_mask, args.conf_threshold).float().detach()
            fake_mask_i = pooling(fake_labels.unsqueeze(1).float())
            fake_mask_i = fake_mask_i.squeeze(1)
            fake_mask_i = (fake_mask_i!=255).float().detach()
            loss_bbx_att_i = criterion_seg(pooling_result_i, pooling_pseudo)
            loss_bbx_att_i = loss_bbx_att_i * pooling_conf_mask * fake_mask_i
            loss_bbx_att_i = loss_bbx_att_i.view(-1)
            loss_bbx_att_i = loss_bbx_att_i[loss_bbx_att_i!=0]
            loss_bbx_att.append(loss_bbx_att_i)
            pooling_result_i = pooling_result_i.mean(0).unsqueeze(0)

            equalise_cls_loss_i = robust_binary_crossentropy(nn.functional.softmax(pooling_result_i, dim=1), local_i)
            equalise_cls_loss_i = equalise_cls_loss_i.mean(1)
            equalise_cls_loss_i = equalise_cls_loss_i * pooling_conf_mask * fake_mask_i
            equalise_cls_loss_i = equalise_cls_loss_i.view(-1)
            equalise_cls_loss_i = equalise_cls_loss_i[equalise_cls_loss_i!=0]
            loss_eq_att.append(equalise_cls_loss_i)


        if len(args.box_size) > 0:
            if args.merge_1x1:
                loss_bbx_att.append(loss_pseudo)
            loss_bbx_att = torch.cat(loss_bbx_att, dim=0)
            bounding_num.update(loss_bbx_att.size(0) / float(560*480*args.batch_size))
            loss_bbx_att = torch.mean(loss_bbx_att)

            loss_eq_att = torch.cat(loss_eq_att, dim=0)
            loss_eq_att = torch.mean(loss_eq_att)

            loss_eq_att_value.update(loss_eq_att.item())
        else:
            loss_bbx_att = torch.mean(loss_pseudo)
            loss_eq_att = 0


        pseudo_num.update(loss_pseudo.size(0) / float(560*480*args.batch_size))
        loss_pseudo = torch.mean(loss_pseudo)
        # BUG FIX: start the target-domain loss from the balance term, THEN
        # add the pseudo-label terms.  The original code added
        # lambda_pseudo * loss_pseudo to the already-back-propagated source
        # loss and immediately overwrote the sum on the next line, so the
        # stand-alone pseudo loss was silently discarded whenever
        # args.merge_1x1 was disabled.
        loss = args.lambda_balance * equalise_cls_loss
        if not args.merge_1x1:
            loss += args.lambda_pseudo * loss_pseudo
        if not isinstance(loss_bbx_att, list):
            loss += args.lambda_pseudo * loss_bbx_att
        loss += args.lambda_eq * loss_eq_att
        loss_pseudo_value.update(loss_pseudo.item())
        loss_balance_value.update(equalise_cls_loss.item())


        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()


        batch_time.update(time.time() - end)

        # Rough ETA from the moving-average iteration time.
        remain_iter = args.num_steps - i_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))



        if i_iter == args.decrease_lr:
            adjust_learning_rate(optimizer_encoder, i_iter, args.lr_encoder, args)
            adjust_learning_rate(optimizer_decoder, i_iter, args.lr_decoder, args)
        if i_iter % args.print_freq == 0:
            lr_encoder = optimizer_encoder.param_groups[0]['lr']
            lr_decoder = optimizer_decoder.param_groups[0]['lr']
            logger.info('exp = {}'.format(args.snapshot_dir))
            logger.info('Iter = [{0}/{1}]\t'
                        'Time = {batch_time.avg:.3f}\t'
                        'loss_seg1 = {loss_seg1.avg:4f}\t'
                        'loss_seg2 = {loss_seg2.avg:.4f}\t'
                        'loss_reconst_source = {loss_reconst_source.avg:.4f}\t'
                        'loss_bbx_att = {loss_bbx_att.avg:.4f}\t'
                        'loss_reconst_target = {loss_reconst_target.avg:.4f}\t'
                        'loss_pseudo = {loss_pseudo.avg:.4f}\t'
                        'loss_eq_att = {loss_eq_att.avg:.4f}\t'
                        'loss_balance = {loss_balance.avg:.4f}\t'
                        'bounding_num = {bounding_num.avg:.4f}\t'
                        'pseudo_num = {pseudo_num.avg:4f}\t'
                        'lr_encoder = {lr_encoder:.8f} lr_decoder = {lr_decoder:.8f}'.format(
                         i_iter, args.num_steps, batch_time=batch_time,
                         loss_seg1=loss_seg_value1, loss_seg2=loss_seg_value2,
                         loss_pseudo=loss_pseudo_value,
                         loss_bbx_att = loss_bbx_att_value,
                         bounding_num = bounding_num,
                         loss_eq_att = loss_eq_att_value,
                         pseudo_num = pseudo_num,
                         loss_reconst_source=loss_reconst_source_value,
                         loss_balance=loss_balance_value,
                         loss_reconst_target=loss_reconst_target_value,
                         lr_encoder=lr_encoder,
                         lr_decoder=lr_decoder))


            logger.info("remain_time: {}".format(remain_time))
            if not tb_logger is None:
                tb_logger.add_scalar('loss_seg_value1', loss_seg_value1.avg, i_iter)
                tb_logger.add_scalar('loss_seg_value2', loss_seg_value2.avg, i_iter)
                tb_logger.add_scalar('bounding_num', bounding_num.avg, i_iter)
                tb_logger.add_scalar('pseudo_num', pseudo_num.avg, i_iter)
                tb_logger.add_scalar('loss_pseudo', loss_pseudo_value.avg, i_iter)
                tb_logger.add_scalar('lr', lr_encoder, i_iter)
                tb_logger.add_scalar('loss_balance', loss_balance_value.avg, i_iter)
            #####
            #save image result
            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                logger.info('taking snapshot ...')
                # ---- validation on the Cityscapes split ----
                model.eval()
                hist = np.zeros((19,19))
                for index, batch in tqdm(enumerate(valloader)):
                    with torch.no_grad():
                        image, name = batch
                        results = model(Variable(image).cuda(), None)
                        output2 = results[0]
                        pred = interp_val(output2)
                        del output2
                        pred = pred.cpu().data[0].numpy()
                        pred = pred.transpose(1, 2, 0)
                        pred = np.asarray(np.argmax(pred, axis=2), dtype=np.uint8)
                        label = np.array(Image.open(gt_imgs_val[index]))
                        label = label_mapping(label, mapping)
                        hist += fast_hist(label.flatten(), pred.flatten(), 19)
                mIoUs = per_class_iu(hist)
                for ind_class in range(args.num_classes):
                    logger.info('===>' + name_classes[ind_class] + ':\t' + str(round(mIoUs[ind_class] * 100, 2)))
                    tb_logger.add_scalar(name_classes[ind_class] + '_mIoU', mIoUs[ind_class], i_iter)

                mIoUs = round(np.nanmean(mIoUs) *100, 2)
                is_best_test = False
                logger.info(mIoUs)
                tb_logger.add_scalar('test mIoU', mIoUs, i_iter)
                if mIoUs > best_test_mIoUs:
                    best_test_mIoUs = mIoUs
                    is_best_test = True
                logger.info("best test mIoU {}".format(best_test_mIoUs))
                net_encoder, net_decoder, net_disc, net_reconst = nets
                save_checkpoint(net_encoder, 'encoder', i_iter, args, is_best_test)
                save_checkpoint(net_decoder, 'decoder', i_iter, args, is_best_test)
                is_best_test = False
            model.train()
コード例 #12
0
def main(args):
    """Fine-tune a pretrained scene-parsing network for a 12-class task.

    The pretrained encoder is frozen; the decoder's final classifier
    layers are replaced by fresh 12-way 1x1 convolutions and trained.
    """
    # --- build encoder / decoder and the training criterion ---
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                        fc_dim=args.fc_dim,
                                        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(arch=args.arch_decoder,
                                        fc_dim=args.fc_dim,
                                        num_class=150,
                                        weights=args.weights_decoder)

    crit = nn.NLLLoss(ignore_index=-1)

    if args.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, crit, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)

    # Freeze every encoder weight: only the decoder is fine-tuned.
    for param in segmentation_module.encoder.parameters():
        param.requires_grad = False
    # Swap the 150-way classifier heads for new 12-way 1x1 convolutions.
    segmentation_module.decoder.conv_last = nn.Conv2d(
        args.fc_dim // 4, 12, 1, 1, 0)
    segmentation_module.decoder.conv_last_deepsup = nn.Conv2d(
        args.fc_dim // 4, 12, 1, 1, 0)

    # --- dataset and loader ---
    dataset_train = TrainDataset(args.list_train,
                                 args,
                                 batch_per_gpu=args.batch_size_per_gpu)
    loader_train = torchdata.DataLoader(
        dataset_train,
        batch_size=len(args.gpus),  # one scattered record per GPU (custom data_parallel)
        shuffle=False,              # shuffling is handled inside the dataset
        collate_fn=user_scattered_collate,
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    iterator_train = iter(loader_train)

    # --- move the model onto the GPU(s) ---
    if len(args.gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=args.gpus)
        patch_replication_callback(segmentation_module)  # synchronized BN
    segmentation_module.cuda()

    # --- optimizers and training loop ---
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, args)

    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, iterator_train, optimizers, history, epoch,
              args)
        checkpoint(nets, history, args, epoch)

    print('Training Done!')
コード例 #13
0
def main(cfg, args, gpus):
    """Run inference with a trained image-quality/segmentation network.

    Builds the encoder/decoder pair described by ``cfg``, loads the test
    sessions, runs a forward pass over every batch and writes the result
    images to ``cfg.TEST.result``.  ``gpus`` is a list of GPU ids; CPU is
    used when ``cfg.TEST.use_gpu`` is false or CUDA is unavailable.
    """
    USE_MULTI_GPU = True if len(gpus) > 1 else False
    NUM_WORKERS = cfg.TEST.workers
    # On GPU the effective batch is scaled by the number of devices.
    if cfg.TEST.use_gpu:
        BATCH_SIZE = cfg.TEST.batch_size_per_gpu * len(gpus)
    else:
        BATCH_SIZE = cfg.TEST.batch_size

    RESULT_SAVE_DIR = cfg.TEST.result + '/' + cfg.MODEL.name + '/'
    if not os.path.exists(RESULT_SAVE_DIR):
        os.makedirs(RESULT_SAVE_DIR)
    if not os.path.exists(RESULT_SAVE_DIR):
        print("Error: Could not create result directory ", RESULT_SAVE_DIR)
        exit()

    # Known test splits, keyed by dataset name; values are session numbers.
    test_set_dict = {
        "test_airsim_city_0": [
            1007, 1011, 1012, 1020, 1022, 1030, 1032, 1034, 1036, 2007, 2011,
            2012, 2018, 2022, 2032, 2034, 2036, 3007, 3011, 3012, 3018, 3020,
            3022, 3032, 3034, 3036, 4007, 4011, 4012, 4018, 4020, 4030, 4032,
            4034, 4036
        ],
        "test_jackal_0":
        [5, 6, 7, 11, 15, 18, 19, 23, 24, 29, 30, 33, 37, 40, 43],
        "test_jackal_t": [37],
        "test_ahg_husky_t": [1]
    }

    # Default sessions, overridden below when a named test set is given.
    session_list_test = [7, 9, 10]

    if cfg.DATASET.test_set is not None:
        session_list_test = test_set_dict[cfg.DATASET.test_set]

    if cfg.TEST.use_gpu and torch.cuda.is_available():
        device = torch.device("cuda:" + str(gpus[0]))
        used_gpu_count = 1
        total_mem = (
            float(torch.cuda.get_device_properties(device).total_memory) /
            1000000.0)
        gpu_name = torch.cuda.get_device_name(device)
        print("Using ", gpu_name, " with ", total_mem, " MB of memory.")
    else:
        device = torch.device("cpu")
        used_gpu_count = 0

    print("Output Model: ", cfg.MODEL.name)
    print("Test data: ", session_list_test)
    print("Network base model is ",
          cfg.MODEL.arch_encoder + '+' + cfg.MODEL.arch_decoder)
    print("Batch size: ", BATCH_SIZE)
    print("Workers num: ", NUM_WORKERS)
    print("device: ", device)
    phases = ['test']
    #data_set_portion_to_sample = {'train': 0.8, 'val': 0.2}
    data_set_portion_to_sample = {'train': 1.0, 'val': 1.0}

    # ImageNet channel statistics matching the pretrained encoder.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    input_img_width = int(cfg.DATASET.img_width)
    input_img_height = int(cfg.DATASET.img_height)
    target_img_width = int(cfg.DATASET.img_width)
    target_img_height = int(cfg.DATASET.img_height)

    # Transform loaded images. If not using color images, it will copy the single
    # channel 3 times to keep the size of an RGB image.
    if cfg.DATASET.use_color_images:
        if cfg.DATASET.normalize_input:
            data_transform_input = transforms.Compose([
                transforms.Resize((input_img_height, input_img_width)),
                transforms.ToTensor(), normalize
            ])
        else:
            data_transform_input = transforms.Compose([
                transforms.Resize((input_img_height, input_img_width)),
                transforms.ToTensor(),
            ])
    else:
        if cfg.DATASET.normalize_input:
            data_transform_input = transforms.Compose([
                transforms.Resize((input_img_height, input_img_width)),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: torch.cat([x, x, x], 0)), normalize
            ])
        else:
            data_transform_input = transforms.Compose([
                transforms.Resize((input_img_height, input_img_width)),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: torch.cat([x, x, x], 0))
            ])

    data_transform_target = transforms.Compose([
        transforms.Resize((target_img_height, target_img_width)),
        transforms.ToTensor()
    ])

    # load_mask = cfg.TRAIN.use_masked_loss or cfg.MODEL.predict_conf_mask
    load_mask = False
    print("raw image root: " + cfg.DATASET.raw_img_root)
    print("raw image folder: " + cfg.DATASET.raw_img_folder)
    test_dataset = ImageQualityDataset(
        cfg.DATASET.root,
        cfg.DATASET.raw_img_root,
        session_list_test,
        loaded_image_color=cfg.DATASET.is_dataset_color,
        output_image_color=cfg.DATASET.use_color_images,
        session_prefix_length=cfg.DATASET.session_prefix_len,
        raw_img_folder=cfg.DATASET.raw_img_folder,
        no_meta_data_available=True,
        transform_input=data_transform_input,
        transform_target=data_transform_target,
        load_masks=load_mask,
        regression_mode=cfg.MODEL.is_regression_mode,
        binarize_target=cfg.DATASET.binarize_target)
    datasets = {phases[0]: test_dataset}

    data_loaders = {
        x: torch.utils.data.DataLoader(datasets[x],
                                       batch_size=BATCH_SIZE,
                                       num_workers=NUM_WORKERS)
        for x in phases
    }

    # Build the network from selected modules
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder,
        regression_mode=cfg.MODEL.is_regression_mode,
        inference_mode=True)

    # The desired size of the output image. The model interpolates the output
    # to this size
    desired_size = (cfg.DATASET.img_height, cfg.DATASET.img_width)

    # The desired size for the output of the network when saving it to file
    # as an image
    raw_output_img_size = (cfg.TEST.output_img_width,
                           cfg.TEST.output_img_height)

    # The criterion is unused at inference time but is required by the
    # SegmentationModule constructor.
    if cfg.MODEL.is_regression_mode:
        if cfg.TRAIN.use_masked_loss:
            criterion = MaskedMSELoss()
            print("Regression Mode with Masked Loss")
        else:
            criterion = nn.MSELoss(reduction='mean')
            print("Regression Mode")
    else:
        criterion = nn.NLLLoss(ignore_index=-1)
        print("Segmentation Mode")

    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        net = SegmentationModule(net_encoder,
                                 net_decoder,
                                 criterion,
                                 cfg.TRAIN.deep_sup_scale,
                                 segSize=desired_size)
    else:
        net = SegmentationModule(net_encoder,
                                 net_decoder,
                                 criterion,
                                 segSize=desired_size)

    if cfg.TEST.use_gpu and USE_MULTI_GPU:
        if torch.cuda.device_count() >= len(gpus):
            available_gpu_count = torch.cuda.device_count()
            print("Using ", len(gpus), " GPUs out of available ",
                  available_gpu_count)
            print("Used GPUs: ", gpus)
            net = nn.DataParallel(net, device_ids=gpus)
            # For synchronized batch normalization:
            patch_replication_callback(net)
        else:
            print("Requested GPUs not available: ", gpus)
            exit()

    net = net.to(device)

    print("Starting Inference...")
    start_time = time.time()

    # Runs inference on all data
    for cur_phase in phases:
        # Set model to evaluate mode
        net.eval()

        # Iterate over data
        for i, data in enumerate(tqdm(data_loaders[cur_phase]), 0):
            # get the inputs
            # NOTE(review): `input` shadows the builtin; presumably a
            # (batch, C, H, W) image tensor — confirm against the dataset.
            input = data['img']
            img_names = data['img_name']
            session_nums = data['session']
            feed_dict = dict()
            feed_dict['input'] = input.to(device)

            # Do not track history since we are in eval mode
            with torch.set_grad_enabled(False):
                # forward pass
                output = net(feed_dict)

            # output = torch.sigmoid(20 * (output - 0.5))

            input_np = input.to(torch.device("cpu")).numpy()
            output_np = output.to(torch.device("cpu")).numpy()

            # Write input/output visualisations; directories are prepared on
            # the first batch only.
            save_result_images(input_np,
                               target_imgs=None,
                               output_imgs=output_np,
                               img_names=img_names,
                               session_nums=session_nums,
                               save_dir=RESULT_SAVE_DIR,
                               raw_output_size=raw_output_img_size,
                               gt_available=False,
                               save_raw_output=cfg.TEST.save_raw_output,
                               initial_directory_prep=(i == 0))

            if cur_phase == 'test':
                if i % 100 == 99:  # print every 100 mini-batches
                    print('[%5d]' % (i + 1))

                    if used_gpu_count:
                        print("Total GPU usage (MB): ",
                              calculate_gpu_usage(gpus), " / ",
                              used_gpu_count * total_mem)

    print('All data was processed.')
    time_elapsed = time.time() - start_time
    print('Completed in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, time_elapsed % 60))
コード例 #14
0
ファイル: train.py プロジェクト: vidit98/graphconv
def main(args):
    """Train the graph-convolution (GCU) segmentation head on top of a
    pretrained dilated-ResNet50 encoder, streaming the loss curve to visdom
    and checkpointing after every epoch."""
    model_builder = ModelBuilder()

    # NLL loss over log-probabilities; label -1 marks ignored pixels.
    criterion = nn.NLLLoss(ignore_index=-1).cuda()

    encoder = model_builder.build_encoder(
        weights="baseline-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth")
    graph_unit = GraphConv(batch=args.batch_size_per_gpu)
    seg_module = SegmentationModule(encoder, graph_unit, criterion, tr=True)

    # Training data: the dataset emits one pre-batched bundle per GPU.
    train_set = TrainDataset(args.list_train,
                             args,
                             batch_per_gpu=args.batch_size_per_gpu)

    train_loader = torchdata.DataLoader(
        train_set,
        batch_size=len(args.gpus),  # custom data_parallel: one bundle per GPU
        shuffle=False,  # shuffling is handled inside the dataset
        collate_fn=user_scattered_collate,
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    train_iter = iter(train_loader)

    # Scatter-style data parallelism is only enabled for > 4 GPUs here.
    if len(args.gpus) > 4:
        seg_module = UserScatteredDataParallel(seg_module,
                                               device_ids=args.gpus)
        patch_replication_callback(seg_module)  # synchronized batch-norm

    trainables = (encoder, graph_unit, criterion)
    optimizers, par = create_optimizers(trainables, args)

    # Training history plus a live visdom loss plot.
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}
    vis = visdom.Visdom()
    win = vis.line(np.array([5.7]),
                   opts=dict(xlabel='epochs',
                             ylabel='Loss',
                             title='Training Loss V=16',
                             legend=['Loss']))

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        lss = train(seg_module, train_iter, optimizers, history, epoch, par,
                    vis, win, args)

        checkpoint(trainables, history, args, epoch)

    print('Training Done!')
コード例 #15
0
ファイル: train_modular.py プロジェクト: ut-amrl/IV_SLAM
def main(cfg, args, gpus):
    """Train and validate the image-quality prediction network.

    Builds an encoder/decoder pair from ``cfg.MODEL``, wraps it in a
    ``SegmentationModule`` with either a regression (MSE / masked MSE) or a
    segmentation (NLL) loss, runs the train/val loop for
    ``cfg.TRAIN.num_epoch`` epochs, tracks the best validation loss, and
    periodically saves snapshots (and a final result) via ``save_results``.

    Args:
        cfg: experiment configuration tree (DATASET / MODEL / TRAIN / DIR).
        args: command-line namespace; not read here, kept for caller
            compatibility.
        gpus: list of GPU ids; gpus[0] hosts the model, all of them are used
            when multi-GPU DataParallel is enabled.
    """
    USE_MULTI_GPU = len(gpus) > 1  # idiom fix: was `True if ... else False`
    NUM_WORKERS = cfg.TRAIN.workers
    if cfg.TRAIN.use_gpu:
        BATCH_SIZE = cfg.TRAIN.batch_size_per_gpu * len(gpus)
    else:
        BATCH_SIZE = cfg.TRAIN.batch_size

    # Hard-coded session-id lists for the named train/validation splits.
    train_set_dict = {
        "train_airsim_city_0": [
            1005, 1006, 1009, 1010, 1013, 1014, 1015, 1016, 1021, 1023, 1025,
            1027, 1028, 1029, 1031, 1035, 2005, 2006, 2009, 2010, 2013, 2014,
            2015, 2016, 2021, 2023, 2025, 2027, 2028, 2029, 2031, 2033, 2035,
            3005, 3006, 3009, 3010, 3013, 3014, 3015, 3016, 3021, 3023, 3025,
            3027, 3028, 3029, 3031, 3033, 3035, 4005, 4006, 4009, 4010, 4013,
            4014, 4015, 4021, 4023, 4025, 4027, 4028, 4029, 4031, 4033, 4035
        ],
        "train_jackal_0":
        [1, 2, 3, 4, 8, 10, 13, 16, 17, 20, 22, 25, 27, 28, 31, 36, 42],
        "train_jackal_t": [37],
        "train_stereo_2020_12_21_run1_t": [1],
        "train_stereo_2021_01_13_run2_t": [1]
    }

    valid_set_dict = {
        "valid_airsim_city_0": [1008, 1016, 1017, 1024],
        "valid_jackal_0": [14],
        "valid_jackal_t": [37],
        "valid_stereo_2020_12_21_run1_t": [1],
        "valid_stereo_2021_01_13_run2_t": [1]
    }

    # Default splits, overridden below when cfg names explicit sets.
    session_list_train = [0, 1, 2, 3, 4, 5]
    session_list_val = [6, 8]
    session_list_test = [7, 9, 10]

    if (cfg.DATASET.train_set is not None) and (cfg.DATASET.validation_set
                                                is not None):
        session_list_train = train_set_dict[cfg.DATASET.train_set]
        session_list_val = valid_set_dict[cfg.DATASET.validation_set]

    # Device selection; used_gpu_count gates the GPU-memory reporting below.
    if cfg.TRAIN.use_gpu and torch.cuda.is_available():
        device = torch.device("cuda:" + str(gpus[0]))
        used_gpu_count = 1
        total_mem = (
            float(torch.cuda.get_device_properties(device).total_memory) /
            1000000.0)
        gpu_name = torch.cuda.get_device_name(device)
        print("Using ", gpu_name, " with ", total_mem, " MB of memory.")
    else:
        device = torch.device("cpu")
        used_gpu_count = 0

    print("Output Model: ", cfg.MODEL.name)
    print("Train data: ", session_list_train)
    print("Validation data: ", session_list_val)
    print("Network base model is ",
          cfg.MODEL.arch_encoder + '+' + cfg.MODEL.arch_decoder)
    print("Batch size: ", BATCH_SIZE)
    print("Workers num: ", NUM_WORKERS)
    print("device: ", device)
    phases = ['train', 'val']
    # NOTE(review): not consumed below — presumably used by an older sampler.
    data_set_portion_to_sample = {'train': 1.0, 'val': 1.0}

    # Standard ImageNet normalization constants.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    input_img_width = int(cfg.DATASET.img_width)
    input_img_height = int(cfg.DATASET.img_height)
    target_img_width = int(cfg.DATASET.img_width /
                           cfg.DATASET.target_downsampling_rate)
    target_img_height = int(cfg.DATASET.img_height /
                            cfg.DATASET.target_downsampling_rate)

    # Transform loaded images. If not using color images, it will copy the
    # single channel 3 times to keep the size of an RGB image.
    if cfg.DATASET.use_color_images:
        if cfg.DATASET.normalize_input:
            data_transform_input = transforms.Compose([
                transforms.Resize((input_img_height, input_img_width)),
                transforms.ToTensor(), normalize
            ])
        else:
            data_transform_input = transforms.Compose([
                transforms.Resize((input_img_height, input_img_width)),
                transforms.ToTensor(),
            ])
    else:
        if cfg.DATASET.normalize_input:
            data_transform_input = transforms.Compose([
                transforms.Resize((input_img_height, input_img_width)),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: torch.cat([x, x, x], 0)), normalize
            ])
        else:
            data_transform_input = transforms.Compose([
                transforms.Resize((input_img_height, input_img_width)),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: torch.cat([x, x, x], 0))
            ])

    data_transform_target = transforms.Compose([
        transforms.Resize((target_img_height, target_img_width)),
        transforms.ToTensor()
    ])

    # Masks are needed either for masked loss or for mask prediction.
    load_mask = cfg.TRAIN.use_masked_loss or cfg.MODEL.predict_conf_mask
    train_dataset = ImageQualityDataset(
        cfg.DATASET.root,
        cfg.DATASET.raw_img_root,
        session_list_train,
        loaded_image_color=cfg.DATASET.is_dataset_color,
        output_image_color=cfg.DATASET.use_color_images,
        session_prefix_length=cfg.DATASET.session_prefix_len,
        raw_img_folder=cfg.DATASET.raw_img_folder,
        no_meta_data_available=True,
        load_only_with_labels=True,
        transform_input=data_transform_input,
        transform_target=data_transform_target,
        load_masks=load_mask,
        regression_mode=cfg.MODEL.is_regression_mode,
        binarize_target=cfg.DATASET.binarize_target)
    val_dataset = ImageQualityDataset(
        cfg.DATASET.root,
        cfg.DATASET.raw_img_root,
        session_list_val,
        loaded_image_color=cfg.DATASET.is_dataset_color,
        output_image_color=cfg.DATASET.use_color_images,
        session_prefix_length=cfg.DATASET.session_prefix_len,
        raw_img_folder=cfg.DATASET.raw_img_folder,
        no_meta_data_available=True,
        load_only_with_labels=True,
        transform_input=data_transform_input,
        transform_target=data_transform_target,
        load_masks=load_mask,
        regression_mode=cfg.MODEL.is_regression_mode,
        binarize_target=cfg.DATASET.binarize_target)
    datasets = {phases[0]: train_dataset, phases[1]: val_dataset}

    data_loaders = {
        x: torch.utils.data.DataLoader(datasets[x],
                                       batch_size=BATCH_SIZE,
                                       num_workers=NUM_WORKERS)
        for x in phases
    }

    # Build the network from selected modules.
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder,
        regression_mode=cfg.MODEL.is_regression_mode,
        inference_mode=False)

    # Pick the loss function to match the output mode.
    if cfg.MODEL.is_regression_mode:
        if cfg.TRAIN.use_masked_loss:
            criterion = MaskedMSELoss()
            print("Regression Mode with Masked Loss")
        else:
            criterion = nn.MSELoss(reduction='mean')
            print("Regression Mode")
    else:
        criterion = nn.NLLLoss(ignore_index=-1)
        print("Segmentation Mode")

    use_mask = cfg.TRAIN.use_masked_loss and cfg.MODEL.is_regression_mode
    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        net = SegmentationModule(net_encoder,
                                 net_decoder,
                                 criterion,
                                 cfg.TRAIN.deep_sup_scale,
                                 use_mask=use_mask)
    else:
        net = SegmentationModule(net_encoder,
                                 net_decoder,
                                 criterion,
                                 use_mask=use_mask)

    if cfg.TRAIN.use_gpu and USE_MULTI_GPU:
        if torch.cuda.device_count() >= len(gpus):
            available_gpu_count = torch.cuda.device_count()
            print("Using ", len(gpus), " GPUs out of available ",
                  available_gpu_count)
            print("Used GPUs: ", gpus)
            net = nn.DataParallel(net, device_ids=gpus)
            # For synchronized batch normalization:
            patch_replication_callback(net)
        else:
            print("Requested GPUs not available: ", gpus)
            exit()

    net = net.to(device)

    # Set up optimizers (loss modules may carry learnable parameters).
    modules = (net_encoder, net_decoder, criterion)
    optimizers = create_optimizers(modules, cfg)

    best_loss = 1000000000.0
    best_model_enc = copy.deepcopy(net_encoder.state_dict())
    best_model_dec = copy.deepcopy(net_decoder.state_dict())

    training_history = {x: {'loss': [], 'acc': []} for x in phases}

    print("Starting Training...")
    start_time = time.time()

    # Runs training and validation.
    for epoch in range(cfg.TRAIN.num_epoch):
        for cur_phase in phases:
            if cur_phase == 'train':
                net.train()  # Set model to training mode
            else:
                net.eval()  # Set model to evaluate mode

            epoch_loss = 0.0
            running_loss = 0.0
            # Iterate over data; i is also used to count batches below, so
            # initialize it in case the loader is empty.
            i = 0
            for i, data in enumerate(data_loaders[cur_phase], 0):
                # Get the inputs. (Renamed from `input`, which shadowed the
                # builtin.)
                img_batch = data['img']
                feed_dict = dict()
                feed_dict['input'] = img_batch.to(device)
                if not cfg.MODEL.is_regression_mode:
                    target = data['labels']
                    feed_dict['target'] = target.to(device)
                else:
                    target = data['score_img']
                    feed_dict['target'] = target.to(device)
                    if cfg.TRAIN.use_masked_loss:
                        mask = data['mask_img']
                        feed_dict['mask'] = mask.to(device)

                # Zero the parameter gradients.
                net.zero_grad()

                # Track history only in train.
                with torch.set_grad_enabled(cur_phase == 'train'):
                    # forward + backward + optimize
                    loss, acc = net(feed_dict)

                    # For multi gpu case, loss will be a vector the same length
                    # as the number of gpus. This is a side effect of loss
                    # function being inside the module.
                    loss = loss.mean()

                    # Backward and optimize only if in training phase.
                    if cur_phase == 'train':
                        loss.backward()
                        for optimizer in optimizers:
                            optimizer.step()

                # Print statistics.
                loss = loss.item()
                running_loss += loss
                epoch_loss += loss

                if cur_phase == 'train':
                    if i % 100 == 99:  # print every 100 mini-batches
                        print('[%d, %5d] Loss: %.6f' %
                              (epoch + 1, i + 1, running_loss / 100))
                        running_loss = 0.0

                        if used_gpu_count:
                            print("Total GPU usage (MB): ",
                                  calculate_gpu_usage(gpus), " / ",
                                  used_gpu_count * total_mem)

            # BUG FIX: the original divided by `i` (the last batch *index*),
            # which is off by one and raises ZeroDivisionError for a
            # single-batch epoch; the batch count is i + 1.
            epoch_loss = epoch_loss / (i + 1)
            print('%s: Loss: %.6f' % (cur_phase, epoch_loss))

            # Keep the best model (lowest validation loss) so far.
            if cur_phase == 'val' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_enc = copy.deepcopy(net_encoder.state_dict())
                best_model_dec = copy.deepcopy(net_decoder.state_dict())

            training_history[cur_phase]['loss'].append(epoch_loss)

        print('Epoch #%d finished. *******************' % (epoch + 1))
        if (epoch + 1) % cfg.TRAIN.snapshot_interval == 0:
            last_model_state = (net_encoder.state_dict(),
                                net_decoder.state_dict())
            best_model_state = (best_model_enc, best_model_dec)
            save_results(cfg.DIR,
                         cfg.MODEL.name,
                         last_model_state,
                         best_model_state,
                         training_history,
                         snapshot_idx=epoch)
            print('Snapshot saved.')

    print('Finished Training')
    time_elapsed = time.time() - start_time
    print('Completed in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, time_elapsed % 60))

    # Final save with snapshot_idx=-1 marking the end-of-training state.
    last_model_state = (net_encoder.state_dict(), net_decoder.state_dict())
    best_model_state = (best_model_enc, best_model_dec)
    save_results(cfg.DIR,
                 cfg.MODEL.name,
                 last_model_state,
                 best_model_state,
                 training_history,
                 snapshot_idx=-1)
コード例 #16
0
def main(cfg, gpus):
    """Train a segmentation network with optional periodic evaluation.

    Builds the model from ``cfg.MODEL``, trains for ``cfg.TRAIN.num_epoch``
    epochs, evaluates every ``cfg.TRAIN.eval_step`` epochs when
    ``cfg.TRAIN.eval`` is set, and keeps a 'best_score' checkpoint in
    addition to the per-epoch ones.
    """
    torch.backends.cudnn.enabled = False

    # Network builders
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)

    # The OCR decoder trains on plain cross entropy; others expect log-probs.
    if cfg.MODEL.arch_decoder == 'ocr':
        print('Using cross entropy loss')
        crit = CrossEntropy(ignore_label=-1)
    else:
        crit = nn.NLLLoss(ignore_index=-1)

    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit,
                                                 cfg.TRAIN.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit)

    # Dataset and loader; the dataset yields one pre-batched bundle per GPU.
    dataset_train = TrainDataset(cfg.DATASET.root_dataset,
                                 cfg.DATASET.list_train,
                                 cfg.DATASET,
                                 batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),
        shuffle=False,  # parameter is not used
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)
    # create loader iterator
    iterator_train = iter(loader_train)
    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    if cfg.TRAIN.eval:
        # Dataset and loader for validation data
        dataset_val = ValDataset(cfg.DATASET.root_dataset,
                                 cfg.DATASET.list_val, cfg.DATASET)
        loader_val = torch.utils.data.DataLoader(
            dataset_val,
            batch_size=cfg.VAL.batch_size,
            shuffle=False,
            collate_fn=user_scattered_collate,
            num_workers=5,
            drop_last=True)
        iterator_val = iter(loader_val)

    # load nets into gpu
    if len(gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    # Epochs at which to run evaluation (hoisted out of the loop).
    # BUG FIX: range() accepts no keyword arguments -- the original
    # `range(start, stop, step=cfg.TRAIN.eval_step)` raised TypeError.
    eval_epochs = range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch,
                        cfg.TRAIN.eval_step)

    # Main loop
    history = {
        'train': {
            'epoch': [],
            'loss': [],
            'acc': [],
            'last_score': 0,
            'best_score': cfg.TRAIN.best_score
        }
    }
    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        train(segmentation_module, iterator_train, optimizers, history,
              epoch + 1, cfg)
        # Calculate segmentation score on the validation split.
        if cfg.TRAIN.eval and epoch in eval_epochs:
            iou, acc = evaluate(segmentation_module, iterator_val, cfg, gpus)
            history['train']['last_score'] = (iou + acc) / 2
            if history['train']['last_score'] > history['train']['best_score']:
                history['train']['best_score'] = history['train']['last_score']
                checkpoint(nets, history, cfg, 'best_score')
        # checkpointing
        checkpoint(nets, history, cfg, epoch + 1)
    print('Training Done!')
コード例 #17
0
def main(cfg, gpus):
    """Build the encoder/decoder segmentation model, stream training batches
    through it (optionally data-parallel across GPUs), and checkpoint after
    every epoch."""
    # --- model ----------------------------------------------------------
    encoder = models.ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    decoder = models.ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)

    # NLL loss over log-probabilities; label -1 marks ignored pixels.
    criterion = nn.NLLLoss(ignore_index=-1)

    # Deep-supervision decoders take an extra scale argument (None here).
    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        seg_module = models.SegmentationModule(encoder, decoder, criterion,
                                               None)
    else:
        seg_module = models.SegmentationModule(encoder, decoder, criterion)

    # --- data -----------------------------------------------------------
    train_set = TrainDataset(cfg.DATASET.root_dataset,
                             cfg.DATASET.list_train,
                             cfg.DATASET,
                             batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=len(gpus),  # custom data_parallel: one bundle per GPU
        shuffle=False,  # shuffling is handled inside the dataset
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)
    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    train_iter = iter(train_loader)

    # --- device placement ------------------------------------------------
    if len(gpus) > 1:
        seg_module = UserScatteredDataParallel(seg_module, device_ids=gpus)
        patch_replication_callback(seg_module)  # synchronized batch-norm

    # seg_module.cuda()  # left disabled, matching the original flow

    # --- optimization ----------------------------------------------------
    trainables = (encoder, decoder, criterion)
    optimizers = create_optimizers(trainables, cfg)

    train_history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        train(seg_module, train_iter, optimizers, train_history, epoch + 1,
              cfg)

        checkpoint(trainables, train_history, cfg, epoch + 1)

    print('Training Done!')
コード例 #18
0
def main(args):
    """Train either an encoder+decoder segmentation network or a U-Net.

    ``args.unet`` selects the architecture; both paths share the NLL loss,
    the augmented training loader, and per-epoch checkpointing.
    """
    # Network builders
    builder = ModelBuilder()
    net_encoder = None
    net_decoder = None
    unet = None

    # Idiom fix: was `args.unet == False` (PEP 8: don't compare to booleans).
    if not args.unet:
        net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                            fc_dim=args.fc_dim,
                                            weights=args.weights_encoder)
        net_decoder = builder.build_decoder(arch=args.arch_decoder,
                                            fc_dim=args.fc_dim,
                                            num_class=args.num_class,
                                            weights=args.weights_decoder)
    else:
        unet = builder.build_unet(num_class=args.num_class,
                                  arch=args.unet_arch,
                                  weights=args.weights_unet)

        # Report which (pretrained) layers stay frozen.
        print("Froze the following layers: ")
        for name, p in unet.named_parameters():
            if not p.requires_grad:
                print(name)

    crit = nn.NLLLoss()

    # The deep-supervision variant only applies to the encoder/decoder path.
    if args.arch_decoder.endswith('deepsup') and not args.unet:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder,
                                                 net_decoder,
                                                 crit,
                                                 is_unet=args.unet,
                                                 unet=unet)

    # Geometric + photometric training augmentations.
    train_augs = Compose([
        RandomSized(224),
        RandomHorizontallyFlip(),
        RandomVerticallyFlip(),
        RandomRotate(180),
        AdjustContrast(cf=0.25),
        AdjustBrightness(bf=0.25)
    ])

    # Dataset and loader; the dataset emits one pre-batched bundle per GPU.
    dataset_train = TrainDataset(args.list_train,
                                 args,
                                 batch_per_gpu=args.batch_size_per_gpu,
                                 augmentations=train_augs)

    loader_train = data.DataLoader(
        dataset_train,
        batch_size=len(args.gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=False)

    print('1 Epoch = {} iters'.format(args.epoch_iters))
    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    if len(args.gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=args.gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers over whichever modules exist for the chosen path.
    nets = (net_encoder, net_decoder, crit) if not args.unet else (unet, crit)
    optimizers = create_optimizers(nets, args)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, iterator_train, optimizers, history, epoch,
              args)
        # checkpointing
        checkpoint(nets, history, args, epoch)

    print('Training Done!')
コード例 #19
0
ファイル: train.py プロジェクト: rexxxx1234/SAUNet-demo
def main(args):
    """Evaluate a SAUNet-style U-Net on the lung test split.

    Builds the U-Net with a dual loss, loads the test fold, wraps the model
    for (optional) scatter-style multi-GPU execution, and reports IoU/loss
    from the project's ``eval`` helper.
    """
    # Network builders
    builder = ModelBuilder()

    unet = builder.build_unet(num_class=args.num_class,
                              arch=args.unet_arch,
                              weights=args.weights_unet)

    # Report which (pretrained) layers stay frozen.
    print("Froze the following layers: ")
    for name, p in unet.named_parameters():
        if not p.requires_grad:
            print(name)
    print()

    crit = DualLoss(mode="train")

    segmentation_module = SegmentationModule(crit, unet)

    test_augs = Compose([PaddingCenterCrop(256)])

    print("ready to load data")

    dataset_val = LungData(root=args.data_root,
                           split='test',
                           k_split=args.k_split,
                           augmentations=test_augs)

    loader_val = data.DataLoader(dataset_val,
                                 batch_size=1,
                                 shuffle=False,
                                 collate_fn=user_scattered_collate,
                                 num_workers=5,
                                 drop_last=True)

    print(len(loader_val))

    # load nets into gpu
    if len(args.gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=args.gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers.
    # BUG FIX: the original chose
    #   (net_encoder, net_decoder, crit) if args.unet == False else (unet, crit)
    # but net_encoder/net_decoder are never defined in this function, so the
    # False branch raised NameError. Only the U-Net exists here.
    nets = (unet, crit)
    optimizers = create_optimizers(nets, args)

    # NOTE(review): `eval` here is presumably a project-level evaluation
    # helper shadowing the builtin — confirm the import.
    iou, loss = eval(loader_val, segmentation_module, args, crit)
    print('Evaluation Done!')
コード例 #20
0
def main():
    """Train (or evaluate) an hl-backbone image matting network.

    Reads all hyper-parameters from the command line, optionally restores a
    checkpoint, then alternates train/validate per epoch, saving the latest
    checkpoint every epoch and the best one (by the 'grad' measure).
    """
    args = get_arguments()

    # Seed CUDA, CPU, and NumPy RNGs for reproducibility.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    # The dataset class is fixed here; args.dataset only names directories.
    dataset = AdobeImageMattingDataset

    # Snapshots live under <snapshot_dir>/<dataset>/<exp>.
    snapshot_dir = os.path.join(args.snapshot_dir, args.dataset.lower(),
                                args.exp)
    if not os.path.exists(snapshot_dir):
        os.makedirs(snapshot_dir)

    args.result_dir = os.path.join(args.result_dir, args.exp)
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)

    # The checkpoint path is resolved relative to the snapshot directory.
    args.restore_from = os.path.join(args.snapshot_dir, args.dataset.lower(),
                                     args.exp, args.restore_from)

    # Echo the full configuration.
    arguments = vars(args)
    for item in arguments:
        print(item, ':\t', arguments[item])

    # Instantiate the network selected by --backbone.
    hlnet = hlbackbone[args.backbone]
    net = hlnet(pretrained=True,
                freeze_bn=True,
                output_stride=args.output_stride,
                input_size=args.crop_size,
                apply_aspp=args.apply_aspp,
                conv_operator=args.conv_operator,
                decoder=args.decoder,
                decoder_kernel_size=args.decoder_kernel_size,
                indexnet=args.indexnet,
                index_mode=args.index_mode,
                use_nonlinear=args.use_nonlinear,
                use_context=args.use_context,
                sync_bn=args.sync_bn)

    # NOTE(review): DataParallel only for mobilenetv2 — presumably other
    # backbones wrap themselves; confirm before changing.
    if args.backbone == 'mobilenetv2':
        net = nn.DataParallel(net)
    if args.sync_bn:
        patch_replication_callback(net)
    net.cuda()

    # Split parameters: 'dconv'/'pred'/'index' blocks are freshly learned,
    # everything else is pretrained and trained at a reduced rate below.
    pretrained_params = []
    learning_params = []
    for p in net.named_parameters():
        if 'dconv' in p[0] or 'pred' in p[0] or 'index' in p[0]:
            learning_params.append(p[1])
        else:
            pretrained_params.append(p[1])

    # Two parameter groups: new layers at the base LR, pretrained at LR/mult.
    optimizer = torch.optim.Adam([
        {
            'params': learning_params
        },
        {
            'params': pretrained_params,
            'lr': args.learning_rate / args.mult
        },
    ],
                                 lr=args.learning_rate)

    # Restore parameters; loss/measure histories are stored on the net itself.
    start_epoch = 0
    net.train_loss = {'running_loss': [], 'epoch_loss': []}
    net.val_loss = {'running_loss': [], 'epoch_loss': []}
    net.measure = {'sad': [], 'mse': [], 'grad': [], 'conn': []}
    if args.restore_from is not None:
        if os.path.isfile(args.restore_from):
            checkpoint = torch.load(args.restore_from)
            net.load_state_dict(checkpoint['state_dict'])
            # Each checkpoint field is optional; restore what is present.
            if 'epoch' in checkpoint:
                start_epoch = checkpoint['epoch']
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
            if 'train_loss' in checkpoint:
                net.train_loss = checkpoint['train_loss']
            if 'val_loss' in checkpoint:
                net.val_loss = checkpoint['val_loss']
            if 'measure' in checkpoint:
                net.measure = checkpoint['measure']
            print("==> load checkpoint '{}' (epoch {})".format(
                args.restore_from, start_epoch))
        else:
            # Fresh run: dump the configuration to a text file for the record.
            with open(os.path.join(args.result_dir, args.exp + '.txt'),
                      'a') as f:
                for item in arguments:
                    print(item, ':\t', arguments[item], file=f)
            print("==> no checkpoint found at '{}'".format(args.restore_from))

    # Define transforms: crop/flip are train-only; normalize/tensorize apply
    # to both train and validation.
    transform_train_val = [
        RandomCrop(args.crop_size, args.scales),
        RandomFlip()
    ]
    transform_all = [
        Normalize(args.image_scale, args.image_mean, args.image_std),
        ToTensor()
    ]
    composed_transform_train = transforms.Compose(transform_train_val +
                                                  transform_all)
    composed_transform_val = transforms.Compose(transform_all)

    # Define dataset loaders.
    trainset = dataset(data_file=args.data_list,
                       data_dir=args.data_dir,
                       train=True,
                       transform=composed_transform_train)
    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True,
                              drop_last=True)
    valset = dataset(data_file=args.data_val_list,
                     data_dir=args.data_dir,
                     train=False,
                     transform=composed_transform_val)
    val_loader = DataLoader(valset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=0,
                            pin_memory=True)

    print('alchemy start...')
    # Evaluation-only mode: validate once and exit.
    if args.evaluate_only:
        validate(net, val_loader, start_epoch + 1, args)
        return

    # last_epoch=-1 starts the schedule fresh; otherwise resume where we left.
    resume_epoch = -1 if start_epoch == 0 else start_epoch
    scheduler = MultiStepLR(optimizer,
                            milestones=[20, 26],
                            gamma=0.1,
                            last_epoch=resume_epoch)
    for epoch in range(start_epoch, args.num_epochs):
        # NOTE(review): scheduler.step() before train() is the pre-PyTorch-1.1
        # ordering; presumably intentional with last_epoch resume — confirm.
        scheduler.step()
        # train
        train(net, train_loader, optimizer, epoch + 1, scheduler, args)
        # val
        validate(net, val_loader, epoch + 1, args)
        # Save the rolling checkpoint every epoch.
        state = {
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'train_loss': net.train_loss,
            'val_loss': net.val_loss,
            'measure': net.measure
        }
        save_checkpoint(state, snapshot_dir, filename='model_ckpt.pth.tar')
        print(args.exp + ' epoch {} finished!'.format(epoch + 1))
        # Keep a separate "best" checkpoint when the gradient metric improves.
        if len(net.measure['grad']) > 1 and net.measure['grad'][-1] <= min(
                net.measure['grad'][:-1]):
            save_checkpoint(state, snapshot_dir, filename='model_best.pth.tar')
    print('Experiments with ' + args.exp + ' done!')
コード例 #21
0
def main(args):
    """Build the encoder/decoder segmentation network, wrap it for
    (multi-)GPU training, and run the train/checkpoint loop from
    ``args.start_epoch`` through ``args.num_epoch``."""
    # Assemble the two network halves from the model factory.
    model_builder = ModelBuilder()
    encoder = model_builder.build_encoder(arch=args.arch_encoder,
                                          fc_dim=args.fc_dim,
                                          weights=args.weights_encoder)
    decoder = model_builder.build_decoder(arch=args.arch_decoder,
                                          fc_dim=args.fc_dim,
                                          num_class=args.num_class,
                                          weights=args.weights_decoder)

    # Pixels labeled -1 are excluded from the loss.
    criterion = nn.NLLLoss(ignore_index=-1)

    # Decoders with deep supervision take an extra loss-scale argument.
    if args.arch_decoder.endswith('deepsup'):
        seg_module = SegmentationModule(encoder, decoder, criterion,
                                        args.deep_sup_scale)
    else:
        seg_module = SegmentationModule(encoder, decoder, criterion)

    # Dataset and loader. Per-GPU batching happens inside the dataset,
    # so the DataLoader batch size equals the GPU count.
    dataset_train = TrainDataset(args.list_train,
                                 args,
                                 batch_per_gpu=args.batch_size_per_gpu)
    loader_train = torchdata.DataLoader(
        dataset_train,
        batch_size=args.num_gpus,  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # One persistent iterator is shared across every epoch of training.
    iterator_train = iter(loader_train)

    # Move the model onto the GPU(s); multi-GPU uses the scatter-aware
    # DataParallel wrapper plus the sync-BN replication patch.
    if args.num_gpus > 1:
        seg_module = UserScatteredDataParallel(
            seg_module, device_ids=range(args.num_gpus))
        # For sync bn
        patch_replication_callback(seg_module)
    seg_module.cuda()

    # One optimizer per trainable component.
    nets = (encoder, decoder, criterion)
    optimizers = create_optimizers(nets, args)

    # Main loop: train one epoch, then snapshot weights and history.
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}
    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(seg_module, iterator_train, optimizers, history, epoch, args)
        checkpoint(nets, history, args, epoch)

    print('Training Done!')
コード例 #22
0
def main(args):
    """Build the encoder/decoder segmentation network for Pascal VOC,
    set up the data loader and optimizers, optionally resume from the
    'epoch_last' checkpoint files, and run the train/checkpoint loop.

    Args:
        args: parsed command-line namespace providing model settings
            (arch_encoder, arch_decoder, fc_dim, num_class, weights_*),
            data settings (batch_size_per_gpu, workers), schedule
            (start_epoch, num_epoch, epoch_iters, deep_sup_scale), GPU
            count (num_gpus), and checkpoint settings (ckpt, resume).
    """
    # Network Builders
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                        fc_dim=args.fc_dim,
                                        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(arch=args.arch_decoder,
                                        fc_dim=args.fc_dim,
                                        num_class=args.num_class,
                                        weights=args.weights_decoder)

    # Label value 255 marks ignored pixels and is excluded from the loss.
    crit = nn.NLLLoss(ignore_index=255)

    # Decoders with deep supervision take an extra loss-scale argument.
    if args.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit)

    # Dataset and Loader
    # dataset_train = TrainDataset(
    #     args.list_train, args, batch_per_gpu=args.batch_size_per_gpu)
    # dataset_train = voc.TrainDataset_VOC(dataset_root=config.img_root_folder, mode='train', transform=input_transform)
    dataset_train = VOC_TrainDataset(opt=args,
                                     batch_per_gpu=args.batch_size_per_gpu)

    # Per-GPU batching happens inside the dataset, so the DataLoader
    # batch size is the number of GPUs, not the sample batch size.
    loader_train = torchdata.DataLoader(
        dataset_train,
        batch_size=args.num_gpus,  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # create loader iterator (one persistent iterator shared by all epochs)
    iterator_train = iter(loader_train)

    # load nets into gpu
    if args.num_gpus > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=range(
                                                            args.num_gpus))
        # segmentation_module = UserScatteredDataParallel(
        #     segmentation_module,
        #     device_ids=[0,3,4,5])

        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers (one per trainable component, matching `nets` order)
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, args)

    # Optionally restore network weights, optimizer state, and the next
    # epoch index from the 'epoch_last' checkpoint files in args.ckpt.
    if args.resume:
        file_path_history = '{}/history_{}'.format(args.ckpt, 'epoch_last.pth')
        file_path_encoder = '{}/encoder_{}'.format(args.ckpt, 'epoch_last.pth')
        file_path_decoder = '{}/decoder_{}'.format(args.ckpt, 'epoch_last.pth')
        file_path_optimizers = '{}/optimizers_{}'.format(
            args.ckpt, 'epoch_last.pth')

        if os.path.isfile(file_path_history):
            print("=> loading checkpoint '{}'".format(file_path_history))
            checkpoint_history = torch.load(file_path_history)
            checkpoint_encoder = torch.load(file_path_encoder)
            checkpoint_decoder = torch.load(file_path_decoder)
            checkpoint_optimizers = torch.load(file_path_optimizers)

            # NOTE(review): this reads the FIRST recorded epoch ([0]); if the
            # saved history spans several epochs, the last entry ([-1]) may be
            # intended — verify against how `checkpoint` writes history.
            args.start_epoch = int(checkpoint_history['train']['epoch'][0]) + 1

            nets[0].load_state_dict(checkpoint_encoder)
            nets[1].load_state_dict(checkpoint_decoder)

            optimizers[0].load_state_dict(
                checkpoint_optimizers['encoder_optimizer'])
            optimizers[1].load_state_dict(
                checkpoint_optimizers['decoder_optimizer'])

            print("=> start train epoch {}".format(args.start_epoch))
        else:
            print('resume not find epoch-last checkpoint')

    # Main loop
    # NOTE(review): `history` is re-created here even after a resume, so the
    # loaded `checkpoint_history` contents are discarded and resumed runs
    # restart their logged curves — confirm this is intended.
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, iterator_train, optimizers, history, epoch,
              args)

        # checkpointing
        checkpoint(nets, optimizers, history, args, epoch)

    print('Training Done!')