Example no. 1
def main(cfg, gpus):
    # Network Builders
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch=cfg.MODEL.arch_encoder.lower(),
                                        fc_dim=cfg.MODEL.fc_dim,
                                        weights=cfg.MODEL.weights_encoder)
    net_decoder = builder.build_decoder(arch=cfg.MODEL.arch_decoder.lower(),
                                        fc_dim=cfg.MODEL.fc_dim,
                                        num_class=cfg.DATASET.num_class,
                                        weights=cfg.MODEL.weights_decoder)

    crit = nn.NLLLoss(ignore_index=-1)

    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit,
                                                 cfg.TRAIN.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit)

    # Dataset and Loader
    dataset_train = TrainDataset(cfg.DATASET.root_dataset,
                                 cfg.DATASET.list_train,
                                 cfg.DATASET,
                                 batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)
    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    if len(gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        train(segmentation_module, iterator_train, optimizers, history,
              epoch + 1, cfg)

        # checkpointing
        checkpoint(nets, history, cfg, epoch + 1)

    print('Training Done!')
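
The batch_size=len(gpus) / user_scattered_collate pairing above only makes sense because TrainDataset already yields one pre-assembled batch per GPU, so the "collate" step must not merge anything. A minimal sketch of such a collate (an assumption based on how these scripts use it, not a verbatim copy of the repo's helper):

# Sketch: each dataset item is already a complete per-GPU batch, so collating
# just hands the list through unchanged; the patched data_parallel then
# scatters one list element to each device.
def user_scattered_collate(batch):
    return batch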
Example no. 2
def main(args):
    # Dataset
    dataset_train = Dataset(args,
                            split_name='train',
                            batch_per_gpu=args.batch_size_per_gpu)

    # Network Builders
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                        fc_dim=args.fc_dim,
                                        freeze_until=args.freeze_until,
                                        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(arch=args.arch_decoder,
                                        fc_dim=args.fc_dim,
                                        num_class=dataset_train.num_classes,
                                        weights=args.weights_decoder)

    net_encoder.train()

    crit = nn.NLLLoss(ignore_index=-1)

    if args.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, args.deep_sup_scale)

    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit)

    # loader
    loader_train = torchdata.DataLoader(
        dataset_train,
        batch_size=len(args.gpus),  # we have modified data_parallel
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    if len(args.gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=args.gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, args)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, iterator_train, optimizers, history, epoch,
              args)

        # checkpointing
        checkpoint(nets, history, args, epoch)

    print('Training Done!')
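
Every example hands (encoder, decoder, crit) to create_optimizers. A plausible sketch of that helper, consistent with the call sites here (the lr_encoder/lr_decoder/beta1/weight_decay names are assumptions; the real helper may differ):

import torch

def create_optimizers(nets, args):
    net_encoder, net_decoder, crit = nets
    optimizer_encoder = torch.optim.SGD(net_encoder.parameters(),
                                        lr=args.lr_encoder,
                                        momentum=args.beta1,
                                        weight_decay=args.weight_decay)
    optimizer_decoder = torch.optim.SGD(net_decoder.parameters(),
                                        lr=args.lr_decoder,
                                        momentum=args.beta1,
                                        weight_decay=args.weight_decay)
    # crit (NLLLoss) has no trainable parameters, so it gets no optimizer
    return optimizer_encoder, optimizer_decoder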
Example no. 3
def main(args):
    # Network Builders
    builder = ModelBuilder()

    unet = builder.build_unet(num_class=args.num_class,
                              arch=args.unet_arch,
                              weights=args.weights_unet)

    print("Froze the following layers: ")
    for name, p in unet.named_parameters():
        if not p.requires_grad:
            print(name)
    print()

    crit = DualLoss(mode="train")

    segmentation_module = SegmentationModule(crit, unet)

    train_augs = Compose([
        PaddingCenterCrop(256),
        RandomHorizontallyFlip(),
        RandomVerticallyFlip(),
        RandomRotate(180)
    ])
    test_augs = Compose([PaddingCenterCrop(256)])

    # Dataset and Loader
    # dataset_train = AC17( #Loads 3D volumes
    #         root=args.data_root,
    #         split='train',
    #         k_split=args.k_split,
    #         augmentations=train_augs)

    dataset_train = SideWalkData(  # Loads 3D volumes
        root=args.data_root,
        split='train',
        k_split=args.k_split,
        augmentations=train_augs)
    ac17_train = load2D(
        dataset_train, split='train',
        deform=True)  #Dataloader for 2D slices. Requires 3D loader.

    loader_train = data.DataLoader(ac17_train,
                                   batch_size=args.batch_size_per_gpu,
                                   shuffle=True,
                                   num_workers=int(args.workers),
                                   drop_last=True,
                                   pin_memory=True)

    dataset_val = SideWalkData(root=args.data_root,
                               split='val',
                               k_split=args.k_split,
                               augmentations=test_augs)

    ac17_val = load2D(dataset_val, split='val', deform=False)

    loader_val = data.DataLoader(ac17_val,
                                 batch_size=1,
                                 shuffle=False,
                                 collate_fn=user_scattered_collate,
                                 num_workers=5,
                                 drop_last=True)

    # load nets into gpu
    if len(args.gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=args.gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    # only the U-Net is built in this script, so it is always what we optimize
    # (the original conditional referenced undefined net_encoder/net_decoder)
    nets = (unet, crit)
    optimizers = create_optimizers(nets, args)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': [], 'jaccard': []}}
    best_val = {
        'epoch_1': 0,
        'mIoU_1': 0,
        'epoch_2': 0,
        'mIoU_2': 0,
        'epoch_3': 0,
        'mIoU_3': 0,
        'epoch': 0,
        'mIoU': 0
    }

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, loader_train, optimizers, history, epoch,
              args)
        iou, loss = eval(loader_val, segmentation_module, args, crit)
        # checkpointing
        ckpted = False
        if loss < 0.215:
            ckpted = True
        if iou[0] > best_val['mIoU_1']:
            best_val['epoch_1'] = epoch
            best_val['mIoU_1'] = iou[0]
            ckpted = True

        if iou[1] > best_val['mIoU_2']:
            best_val['epoch_2'] = epoch
            best_val['mIoU_2'] = iou[1]
            ckpted = True

        if iou[2] > best_val['mIoU_3']:
            best_val['epoch_3'] = epoch
            best_val['mIoU_3'] = iou[2]
            ckpted = True

        if (iou[0] + iou[1] + iou[2]) / 3 > best_val['mIoU']:
            best_val['epoch'] = epoch
            best_val['mIoU'] = (iou[0] + iou[1] + iou[2]) / 3
            ckpted = True

        if epoch % 50 == 0:
            checkpoint(nets, history, args, epoch)
            continue

        if epoch == args.num_epoch:
            checkpoint(nets, history, args, epoch)
            continue
        if epoch < 15:
            ckpted = False
        if ckpted:
            checkpoint(nets, history, args, epoch)

    print('Training Done!')
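
The checkpoint gating above mixes per-class bests, a fixed loss threshold, and a warm-up period. The same decision reads more clearly as one predicate; a sketch (the 0.215 threshold and 15-epoch warm-up come from the code above, the helper name is hypothetical):

def should_checkpoint(epoch, loss, iou, best_val, num_epoch):
    # always checkpoint on round epochs and at the very end
    if epoch % 50 == 0 or epoch == num_epoch:
        return True
    # warm-up: no metric-driven checkpoints before epoch 15
    if epoch < 15:
        return False
    mean_iou = (iou[0] + iou[1] + iou[2]) / 3
    return (loss < 0.215 or iou[0] > best_val['mIoU_1']
            or iou[1] > best_val['mIoU_2'] or iou[2] > best_val['mIoU_3']
            or mean_iou > best_val['mIoU'])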
Example no. 4
def main(cfg, gpus):
    torch.backends.cudnn.enabled = False
    # cudnn.deterministic = False
    # cudnn.enabled = True
    # Network Builders
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)

    if cfg.MODEL.arch_decoder == 'ocr':
        print('Using cross entropy loss')
        crit = CrossEntropy(ignore_label=-1)
    else:
        crit = nn.NLLLoss(ignore_index=-1)

    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit,
                                                 cfg.TRAIN.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit)

    # Dataset and Loader
    dataset_train = TrainDataset(cfg.DATASET.root_dataset,
                                 cfg.DATASET.list_train,
                                 cfg.DATASET,
                                 batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),
        shuffle=False,  # parameter is not used
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)
    # create loader iterator
    iterator_train = iter(loader_train)
    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    if cfg.TRAIN.eval:
        # Dataset and Loader for validation data
        dataset_val = ValDataset(cfg.DATASET.root_dataset,
                                 cfg.DATASET.list_val, cfg.DATASET)
        loader_val = torch.utils.data.DataLoader(
            dataset_val,
            batch_size=cfg.VAL.batch_size,
            shuffle=False,
            collate_fn=user_scattered_collate,
            num_workers=5,
            drop_last=True)
        iterator_val = iter(loader_val)

    # load nets into gpu
    if len(gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    # Main loop
    history = {
        'train': {
            'epoch': [],
            'loss': [],
            'acc': [],
            'last_score': 0,
            'best_score': cfg.TRAIN.best_score
        }
    }
    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        train(segmentation_module, iterator_train, optimizers, history,
              epoch + 1, cfg)
        # calculate segmentation score
        # note: range() takes its step positionally; step= would be a TypeError
        if cfg.TRAIN.eval and epoch in range(cfg.TRAIN.start_epoch,
                                             cfg.TRAIN.num_epoch,
                                             cfg.TRAIN.eval_step):
            iou, acc = evaluate(segmentation_module, iterator_val, cfg, gpus)
            history['train']['last_score'] = (iou + acc) / 2
            if history['train']['last_score'] > history['train']['best_score']:
                history['train']['best_score'] = history['train']['last_score']
                checkpoint(nets, history, cfg, 'best_score')
        # checkpointing
        checkpoint(nets, history, cfg, epoch + 1)
    print('Training Done!')
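
For the periodic-eval test in the loop above, an equivalent modulo form avoids building a range at all (a stylistic sketch; cfg and epoch as in the example, which guarantees epoch < num_epoch):

# True on start_epoch, start_epoch + eval_step, start_epoch + 2*eval_step, ...
is_eval_epoch = (cfg.TRAIN.eval
                 and (epoch - cfg.TRAIN.start_epoch) % cfg.TRAIN.eval_step == 0)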
Example no. 5
def main(args):
    # Network Builders
    builder = ModelBuilder()
    net_encoder = None
    net_decoder = None
    unet = None

    if not args.unet:
        net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                            fc_dim=args.fc_dim,
                                            weights=args.weights_encoder)
        net_decoder = builder.build_decoder(arch=args.arch_decoder,
                                            fc_dim=args.fc_dim,
                                            num_class=args.num_class,
                                            weights=args.weights_decoder)
    else:
        unet = builder.build_unet(num_class=args.num_class,
                                  arch=args.unet_arch,
                                  weights=args.weights_unet)

        print("Froze the following layers: ")
        for name, p in unet.named_parameters():
            if not p.requires_grad:
                print(name)

    crit = nn.NLLLoss()
    #crit = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(50))
    #crit = nn.CrossEntropyLoss().cuda()
    #crit = nn.BCELoss()

    if args.arch_decoder.endswith('deepsup') and not args.unet:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder,
                                                 net_decoder,
                                                 crit,
                                                 is_unet=args.unet,
                                                 unet=unet)

    train_augs = Compose([
        RandomSized(224),
        RandomHorizontallyFlip(),
        RandomVerticallyFlip(),
        RandomRotate(180),
        AdjustContrast(cf=0.25),
        AdjustBrightness(bf=0.25)
    ])  #, RandomErasing()])
    #train_augs = None
    # Dataset and Loader
    dataset_train = TrainDataset(args.list_train,
                                 args,
                                 batch_per_gpu=args.batch_size_per_gpu,
                                 augmentations=train_augs)

    loader_train = data.DataLoader(
        dataset_train,
        batch_size=len(args.gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=False)

    print('1 Epoch = {} iters'.format(args.epoch_iters))
    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    if len(args.gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=args.gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (unet, crit) if args.unet else (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, args)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, iterator_train, optimizers, history, epoch,
              args)
        # checkpointing
        checkpoint(nets, history, args, epoch)

    print('Training Done!')
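
Compose, RandomHorizontallyFlip, RandomRotate and friends here are joint transforms: image and mask must receive the identical geometric change or the labels drift off the pixels. A minimal sketch of one such transform (assumed interface, mirroring how train_augs is applied to (image, mask) pairs):

import random
from PIL import Image

class RandomHorizontallyFlip(object):
    # Sketch: flip image and mask together so pixels and labels stay aligned.
    def __call__(self, img, mask):
        if random.random() < 0.5:
            return (img.transpose(Image.FLIP_LEFT_RIGHT),
                    mask.transpose(Image.FLIP_LEFT_RIGHT))
        return img, mask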
Example no. 6
def main(args):
    # Network Builders
    builder = VGGModelBuilder()
    net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(arch=args.arch_decoder,
                                        fc_dim=1024,
                                        num_class=args.num_class,
                                        weights=args.weights_decoder)

    crit = nn.CrossEntropyLoss(ignore_index=255, reduction='sum')

    if args.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit)
    print(segmentation_module)

    # Dataset and Loader
    dataset_train = VOCTrainDataset(args,
                                    batch_per_gpu=args.batch_size_per_gpu)
    loader_train = torchdata.DataLoader(
        dataset_train,
        batch_size=1,  # data_parallel has been modified; this value is unused
        shuffle=False,  # do not use this param
        collate_fn=user_scattered_collate,
        num_workers=1,  # MUST be 1 or 0
        drop_last=True,
        pin_memory=True)

    # timestamp goes in the brackets, iteration count after the equals sign
    print('[{}] 1 training epoch = {} iters'.format(
        datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        args.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    if args.num_gpus > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=args.gpu_id)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda(device=args.gpu_id[0])

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, args)

    # Main loop
    history = {
        'train': {
            'epoch': [],
            'loss': [],
            'train_acc': [],
            'test_ious': [],
            'test_mean_iou': []
        }
    }

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        # test/validate
        # re-create the val dataset/loader on every eval so that
        # drop_last=False covers the full validation set
        dataset_val = VOCValDataset(args)
        loader_val = torchdata.DataLoader(
            dataset_val,
            # data_parallel has been modified: MUST use val batch size
            #   and collate_fn MUST be user_scattered_collate
            batch_size=args.val_batch_size,
            shuffle=False,
            collate_fn=user_scattered_collate,
            num_workers=1,  # MUST be 1 or 0
            drop_last=False)
        iterator_val = iter(loader_val)
        if epoch % args.test_epoch_interval == 0 or epoch == args.num_epoch:  # epoch != 1 and
            (cls_ious, cls_mean_iou) = evaluate(segmentation_module,
                                                iterator_val, args)
            history['train']['test_ious'].append(cls_ious)
            history['train']['test_mean_iou'].append(cls_mean_iou)
        else:
            history['train']['test_ious'].append(-1)  # empty data
            history['train']['test_mean_iou'].append(-1)

        # train
        train(segmentation_module, iterator_train, optimizers, history, epoch,
              args)

        # checkpointing
        checkpoint(nets, history, args, epoch)

    print('[{}] Training Done!'.format(
        datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
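
Note that every example builds iterator_train once, outside the epoch loop. That only works if train() pulls a fixed number of batches per call rather than looping over the DataLoader itself; a simplified sketch of that contract (assumed shape of train(), omitting logging and LR scheduling):

def train(segmentation_module, iterator, optimizers, history, epoch, args):
    segmentation_module.train()
    for _ in range(args.epoch_iters):
        # the shared iterator keeps its position across epochs
        batch_data = next(iterator)
        segmentation_module.zero_grad()
        loss, acc = segmentation_module(batch_data)
        loss = loss.mean()
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()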
Example no. 7
def main(args):
    # Network Builders
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                        fc_dim=args.fc_dim,
                                        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(arch=args.arch_decoder,
                                        fc_dim=args.fc_dim,
                                        num_class=150,
                                        weights=args.weights_decoder)

    crit = nn.NLLLoss(ignore_index=-1)

    if args.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit)

    ########
    for param in segmentation_module.encoder.parameters():
        param.requires_grad = False
    # for name, param in segmentation_module.decoder.named_parameters():
    #     print(name)
    #     if name in ("conv_last.weight", "conv_last.bias",
    #                 "conv_last_deepsup.weight", "conv_last_deepsup.bias"):
    #         param.requires_grad = True
    #     else:
    #         param.requires_grad = False
    #     print(param.requires_grad)
    segmentation_module.decoder.conv_last = nn.Conv2d(args.fc_dim // 4, 12, 1,
                                                      1, 0)
    segmentation_module.decoder.conv_last_deepsup = nn.Conv2d(
        args.fc_dim // 4, 12, 1, 1, 0)
    ########

    # Dataset and Loader
    dataset_train = TrainDataset(args.list_train,
                                 args,
                                 batch_per_gpu=args.batch_size_per_gpu)

    loader_train = torchdata.DataLoader(
        dataset_train,
        batch_size=len(args.gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)
    #######
    #torch.backends.cudnn.benchmark = True
    #CUDA_LAUNCH_BLOCKING=1
    #######
    # load nets into gpu
    if len(args.gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=args.gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, args)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, iterator_train, optimizers, history, epoch,
              args)

        # checkpointing
        checkpoint(nets, history, args, epoch)

    print('Training Done!')
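
Since this example freezes the entire encoder and swaps the decoder's final convolutions for fresh 12-class heads, it is worth sanity-checking what will actually train; a short sketch (hypothetical helper, not part of the script):

def count_trainable(module):
    # only parameters with requires_grad=True receive gradient updates
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

# after the freeze/replace above, this should count only decoder parameters
print('trainable params:', count_trainable(segmentation_module))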
Example no. 8
def main(args):
    # Network Builders
    builder = ModelBuilder()

    crit = nn.NLLLoss(ignore_index=-1)
    crit = crit.cuda()
    net_encoder = builder.build_encoder(
        weights="baseline-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth")
    gcu = GraphConv(
        batch=args.batch_size_per_gpu
    )  #, V=2), GCU(X=enc_out, V=4), GCU(X=enc_out, V=8),GCU(X=enc_out, V=32)]
    # gcu.load_state_dict(torch.load("ckpt/baseline-resnet50dilated-ngpus1-batchSize1-imgMaxSize1000-paddingConst8-segmDownsampleRate8-epoch20/decoder_epoch_20.pth"))
    segmentation_module = SegmentationModule(net_encoder, gcu, crit, tr=True)

    # Dataset and Loader
    dataset_train = TrainDataset(args.list_train,
                                 args,
                                 batch_per_gpu=args.batch_size_per_gpu)

    loader_train = torchdata.DataLoader(
        dataset_train,
        batch_size=len(args.gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    if len(args.gpus) > 4:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=args.gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)

    # segmentation_module.cuda()

    # Set up optimizers
    # print(gcu[0].parameters())
    nets = (net_encoder, gcu, crit)
    optimizers, par = create_optimizers(nets, args)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}
    vis = visdom.Visdom()
    win = vis.line(np.array([5.7]),
                   opts=dict(xlabel='epochs',
                             ylabel='Loss',
                             title='Training Loss V=16',
                             legend=['Loss']))

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        lss = train(segmentation_module, iterator_train, optimizers, history,
                    epoch, par, vis, win, args)

        # checkpointing
        checkpoint(nets, history, args, epoch)

    print('Training Done!')
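
The visdom window created above is presumably appended to from inside train() using the returned loss; a sketch of that update with the standard visdom API (the exact call inside train() is not shown in the source):

import numpy as np

# append this epoch's loss to the existing 'Training Loss V=16' window
vis.line(X=np.array([epoch]),
         Y=np.array([lss]),
         win=win,
         update='append')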
Example no. 9
def main(cfg, gpus):
    # Network Builders
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder,
        dilate_rate=cfg.DATASET.segm_downsampling_rate)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)
    if cfg.MODEL.foveation:
        net_foveater = ModelBuilder.build_foveater(
            in_channel=cfg.MODEL.in_dim,
            out_channel=len(cfg.MODEL.patch_bank),
            len_gpus=len(gpus),
            weights=cfg.MODEL.weights_foveater,
            cfg=cfg)

    # tensorboard writer
    writer = SummaryWriter('{}/tensorboard'.format(cfg.DIR))

    if cfg.DATASET.root_dataset == '/scratch0/chenjin/GLEASON2019_DATA/Data/':
        if cfg.TRAIN.loss_fun == 'DiceLoss':
            crit = DiceLoss()
        elif cfg.TRAIN.loss_fun == 'FocalLoss':
            crit = FocalLoss()
        elif cfg.TRAIN.loss_fun == 'DiceCoeff':
            crit = DiceCoeff()
        elif cfg.TRAIN.loss_fun == 'NLLLoss':
            crit = nn.NLLLoss(ignore_index=-2)
        else:
            crit = OhemCrossEntropy(ignore_label=-1,
                                    thres=0.9,
                                    min_kept=100000,
                                    weight=None)
    elif 'ADE20K' in cfg.DATASET.root_dataset:
        crit = nn.NLLLoss(ignore_index=-2)
    elif 'CITYSCAPES' in cfg.DATASET.root_dataset:
        if cfg.TRAIN.loss_fun == 'NLLLoss':
            crit = nn.NLLLoss(ignore_index=19)
        else:
            class_weights = torch.FloatTensor([
                0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
                0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
                1.0865, 1.1529, 1.0507
            ]).cuda()
            crit = OhemCrossEntropy(ignore_label=20,
                                    thres=0.9,
                                    min_kept=131072,
                                    weight=class_weights)
    elif 'DeepGlob' in cfg.DATASET.root_dataset and (
            cfg.TRAIN.loss_fun == 'FocalLoss'
            or cfg.TRAIN.loss_fun == 'OhemCrossEntropy'):
        if cfg.TRAIN.loss_fun == 'FocalLoss':
            crit = FocalLoss(gamma=6, ignore_label=cfg.DATASET.ignore_index)
        elif cfg.TRAIN.loss_fun == 'OhemCrossEntropy':
            crit = OhemCrossEntropy(ignore_label=cfg.DATASET.ignore_index,
                                    thres=0.9,
                                    min_kept=131072)
    else:
        if cfg.TRAIN.loss_fun == 'NLLLoss':
            if cfg.DATASET.ignore_index != -2:
                crit = nn.NLLLoss(ignore_index=cfg.DATASET.ignore_index)
            else:
                crit = nn.NLLLoss(ignore_index=-2)
        else:
            if cfg.DATASET.ignore_index != -2:
                crit = nn.CrossEntropyLoss(
                    ignore_index=cfg.DATASET.ignore_index)
            else:
                crit = nn.CrossEntropyLoss(ignore_index=-2)
    # crit = DiceCoeff()

    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, cfg,
                                                 cfg.TRAIN.deep_sup_scale)
    else:
        # the foveation and non-foveation cases construct the same module
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, cfg)

    if cfg.MODEL.foveation:
        foveation_module = FovSegmentationModule(net_foveater,
                                                 cfg,
                                                 len_gpus=len(gpus))
        total_fov = sum(
            [param.nelement() for param in foveation_module.parameters()])
        print('Number of FoveationModule params: %.2fM \n' % (total_fov / 1e6))

    total = sum(
        [param.nelement() for param in segmentation_module.parameters()])
    print('Number of SegmentationModule params: %.2fM \n' % (total / 1e6))

    # Dataset and Loader
    dataset_train = TrainDataset(cfg.DATASET.root_dataset,
                                 cfg.DATASET.list_train,
                                 cfg.DATASET,
                                 batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    if len(gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
        if cfg.MODEL.foveation:
            foveation_module = UserScatteredDataParallel(foveation_module,
                                                         device_ids=gpus)
            patch_replication_callback(foveation_module)

    segmentation_module.cuda()
    if cfg.MODEL.foveation:
        foveation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    if cfg.MODEL.foveation:
        nets = (net_encoder, net_decoder, crit, net_foveater)
    optimizers = create_optimizers(nets, cfg)

    # Main loop
    # the same history layout applies whether or not dice is evaluated
    history = {
        'train': {
            'epoch': [],
            'loss': [],
            'acc': []
        },
        'save': {
            'epoch': [],
            'train_loss': [],
            'train_acc': [],
            'val_iou': [],
            'val_dice': [],
            'val_acc': [],
            'print_grad': None
        }
    }

    if cfg.TRAIN.start_epoch > 0:
        history_previous_epoches = pd.read_csv(
            '{}/history_epoch_{}.csv'.format(cfg.DIR, cfg.TRAIN.start_epoch))
        history['save']['epoch'] = list(history_previous_epoches['epoch'])
        history['save']['train_loss'] = list(
            history_previous_epoches['train_loss'])
        history['save']['train_acc'] = list(
            history_previous_epoches['train_acc'])
        history['save']['val_iou'] = list(history_previous_epoches['val_iou'])
        history['save']['val_acc'] = list(history_previous_epoches['val_acc'])
        # if cfg.VAL.dice:
        #     history['save']['val_dice'] = history_previous_epoches['val_dice']

    if not os.path.isdir(os.path.join(cfg.DIR,
                                      "Fov_probability_distribution")):
        os.makedirs(os.path.join(cfg.DIR, "Fov_probability_distribution"))
    # truncate the per-patch probability-distribution logs
    for p in range(len(cfg.MODEL.patch_bank)):
        open(
            os.path.join(cfg.DIR, 'Fov_probability_distribution',
                         'patch_{}_distribution.txt'.format(p)), 'w').close()

    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        if cfg.MODEL.foveation:
            train(segmentation_module, iterator_train, optimizers, epoch + 1,
                  cfg, history, foveation_module)
            # identical gradient logging for each foveation layer
            print_grad = history['train']['print_grad']
            if print_grad is not None and type(print_grad) is not torch.Tensor:
                for layer in ('layer1', 'layer2', 'layer3'):
                    grad = print_grad['{}_grad'.format(layer)]
                    if grad is None or grad[grad > 0].numel() == 0:
                        continue
                    writer.add_histogram(
                        'Print non-zero gradient ({}) histogram'.format(layer),
                        grad[grad > 0], epoch + 1)
                    writer.add_histogram(
                        'Print gradient ({}) histogram'.format(layer), grad,
                        epoch + 1)
                    writer.add_scalar(
                        'Percentage none-zero gradients ({})'.format(layer),
                        grad[grad > 0].numel() / grad.numel(), epoch + 1)
                    writer.add_image(
                        'Print_grad_Fov_softmax_{}(normalized_b0_p0)'.format(
                            layer),
                        (grad[0][0] - grad[0][0].min()) /
                        (grad[0][0].max() - grad[0][0].min()),
                        epoch + 1,
                        dataformats='HW')

        else:
            train(segmentation_module, iterator_train, optimizers, epoch + 1,
                  cfg, history)
        # checkpointing

        if (epoch + 1) % cfg.TRAIN.checkpoint_per_epoch == 0:
            checkpoint(nets, cfg, epoch + 1)
        checkpoint_last(nets, cfg, epoch + 1)

        if (epoch + 1) % cfg.TRAIN.eval_per_epoch == 0:
            # eval during train
            if cfg.VAL.multipro:
                if cfg.MODEL.foveation:
                    if cfg.VAL.all_F_Xlr_time:
                        val_iou, val_acc, F_Xlr_all, F_Xlr_score_flat_all = eval_during_train_multipro(
                            cfg, gpus)
                    else:
                        val_iou, val_acc, F_Xlr, F_Xlr_score_flat = eval_during_train_multipro(
                            cfg, gpus)
                else:
                    val_iou, val_acc = eval_during_train_multipro(cfg, gpus)
            else:
                if cfg.VAL.dice:
                    if cfg.MODEL.foveation:
                        if cfg.VAL.all_F_Xlr_time:
                            val_iou, val_dice, val_acc, F_Xlr_all, F_Xlr_score_flat_all = eval_during_train(
                                cfg)
                        else:
                            val_iou, val_dice, val_acc, F_Xlr, F_Xlr_score_flat = eval_during_train(
                                cfg)
                    else:
                        val_iou, val_dice, val_acc = eval_during_train(cfg)
                else:
                    if cfg.MODEL.foveation:
                        if cfg.VAL.all_F_Xlr_time:
                            val_iou, val_acc, F_Xlr_all, F_Xlr_score_flat_all = eval_during_train(
                                cfg)
                        else:
                            val_iou, val_acc, F_Xlr, F_Xlr_score_flat = eval_during_train(
                                cfg)
                    else:
                        val_iou, val_acc = eval_during_train(cfg)

            history['save']['epoch'].append(epoch + 1)
            history['save']['train_loss'].append(history['train']['loss'][-1])
            history['save']['train_acc'].append(history['train']['acc'][-1] *
                                                100)
            history['save']['val_iou'].append(val_iou)
            if cfg.VAL.dice:
                history['save']['val_dice'].append(val_dice)
            history['save']['val_acc'].append(val_acc)
            # write to tensorboard
            writer.add_scalar('Loss/train', history['train']['loss'][-1],
                              epoch + 1)
            writer.add_scalar('Acc/train', history['train']['acc'][-1] * 100,
                              epoch + 1)
            writer.add_scalar('Acc/val', val_acc, epoch + 1)
            writer.add_scalar('mIoU/val', val_iou, epoch + 1)
            if cfg.VAL.dice:
                writer.add_scalar('mDice/val', val_dice, epoch + 1)
            if cfg.VAL.all_F_Xlr_time:
                print('=============F_Xlr_score_flat_all================\n',
                      F_Xlr_score_flat_all.shape)
                for p in range(F_Xlr_score_flat_all.shape[0]):
                    # add small artifact to modify range, because no range flag in add_histogram
                    F_Xlr_score_flat_all[p][0] = 0
                    F_Xlr_score_flat_all[p][-1] = 1
                    writer.add_histogram(
                        'Patch_{} probability histogram'.format(p),
                        F_Xlr_score_flat_all[p], epoch + 1)
                    f = open(
                        os.path.join(cfg.DIR, 'Fov_probability_distribution',
                                     'patch_{}_distribution.txt'.format(p)),
                        'a')
                    if epoch == 0:
                        f.write('epoch/ bins: {}\n'.format(
                            np.histogram(F_Xlr_score_flat_all[p],
                                         bins=10,
                                         range=(0, 1))[1]))
                    f.write('epoch {}: {}\n'.format(
                        epoch + 1,
                        np.histogram(F_Xlr_score_flat_all[p],
                                     bins=10,
                                     range=(0, 1))[0] / sum(
                                         np.histogram(F_Xlr_score_flat_all[p],
                                                      bins=10,
                                                      range=(0, 1))[0])))
                    f.close()
                writer.add_histogram('Patch_All probability histogram',
                                     F_Xlr_score_flat_all, epoch + 1)
            else:
                for p in range(F_Xlr_score_flat.shape[0]):
                    F_Xlr_score_flat[p][0] = 0
                    F_Xlr_score_flat[p][-1] = 1
                    writer.add_histogram(
                        'Patch_{} probability histogram'.format(p),
                        F_Xlr_score_flat[p], epoch + 1)
                writer.add_histogram('Patch_All probability histogram',
                                     F_Xlr_score_flat, epoch + 1)
        else:
            history['save']['epoch'].append(epoch + 1)
            history['save']['train_loss'].append(history['train']['loss'][-1])
            history['save']['train_acc'].append(history['train']['acc'][-1] *
                                                100)
            history['save']['val_iou'].append('')
            if cfg.VAL.dice:
                history['save']['val_dice'].append('')
            history['save']['val_acc'].append('')
            # write to tensorboard
            writer.add_scalar('Loss/train', history['train']['loss'][-1],
                              epoch + 1)
            writer.add_scalar('Acc/train', history['train']['acc'][-1] * 100,
                              epoch + 1)
            # writer.add_scalar('Acc/val', '', epoch+1)
            # writer.add_scalar('mIoU/val', '', epoch+1)

        # saving history
        checkpoint_history(history, cfg, epoch + 1)

        if (epoch + 1) % cfg.TRAIN.eval_per_epoch == 0:
            # output F_Xlr
            if cfg.MODEL.foveation:
                # save time series F_Xlr (t,b,d,w,h)
                if epoch == 0 or epoch == cfg.TRAIN.start_epoch:
                    if cfg.VAL.all_F_Xlr_time:
                        F_Xlr_time_all = []
                        for val_idx in range(len(F_Xlr_all)):
                            F_Xlr_time_all.append(F_Xlr_all[val_idx][0])
                    else:
                        F_Xlr_time = F_Xlr
                else:
                    if cfg.VAL.all_F_Xlr_time:
                        for val_idx in range(len(F_Xlr_all)):
                            F_Xlr_time_all[val_idx] = np.concatenate(
                                (F_Xlr_time_all[val_idx],
                                 F_Xlr_all[val_idx][0]),
                                axis=0)
                    else:
                        F_Xlr_time = np.concatenate((F_Xlr_time, F_Xlr),
                                                    axis=0)
                if cfg.VAL.all_F_Xlr_time:
                    for val_idx in range(len(F_Xlr_all)):
                        print('F_Xlr_time_{}'.format(F_Xlr_all[val_idx][1]),
                              F_Xlr_time_all[val_idx].shape)
                        if not os.path.isdir(
                                os.path.join(cfg.DIR, "F_Xlr_time_all_vals")):
                            os.makedirs(
                                os.path.join(cfg.DIR, "F_Xlr_time_all_vals"))
                        np.save(
                            '{}/F_Xlr_time_all_vals/F_Xlr_time_last_{}.npy'.
                            format(cfg.DIR, F_Xlr_all[val_idx][1]),
                            F_Xlr_time_all[val_idx])
                else:
                    print('F_Xlr_time', F_Xlr_time.shape)
                    np.save('{}/F_Xlr_time_last.npy'.format(cfg.DIR),
                            F_Xlr_time)

    if not cfg.TRAIN.save_checkpoint:
        os.remove('{}/encoder_epoch_last.pth'.format(cfg.DIR))
        os.remove('{}/decoder_epoch_last.pth'.format(cfg.DIR))
    print('Training Done!')
    writer.close()
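
The min-max normalization used for the gradient images above divides by (max - min), which is zero whenever a gradient map is constant. A guarded variant (a defensive sketch, not in the original):

def normalize_01(t, eps=1e-12):
    # map a tensor to [0, 1]; eps avoids NaN when the tensor is constant
    tmin, tmax = t.min(), t.max()
    return (t - tmin) / (tmax - tmin + eps)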
Example no. 10
def main(args):
    # Network Builders
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                        fc_dim=args.fc_dim,
                                        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(arch=args.arch_decoder,
                                        fc_dim=args.fc_dim,
                                        num_class=args.num_class,
                                        weights=args.weights_decoder)

    crit = nn.NLLLoss(ignore_index=255)

    if args.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit)

    # Dataset and Loader
    # dataset_train = TrainDataset(
    #     args.list_train, args, batch_per_gpu=args.batch_size_per_gpu)
    # dataset_train = voc.TrainDataset_VOC(dataset_root=config.img_root_folder, mode='train', transform=input_transform)
    dataset_train = VOC_TrainDataset(opt=args,
                                     batch_per_gpu=args.batch_size_per_gpu)

    loader_train = torchdata.DataLoader(
        dataset_train,
        batch_size=args.num_gpus,  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=int(args.workers),
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    if args.num_gpus > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=range(
                                                            args.num_gpus))
        # segmentation_module = UserScatteredDataParallel(
        #     segmentation_module,
        #     device_ids=[0,3,4,5])

        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, args)

    if args.resume:
        file_path_history = '{}/history_{}'.format(args.ckpt, 'epoch_last.pth')
        file_path_encoder = '{}/encoder_{}'.format(args.ckpt, 'epoch_last.pth')
        file_path_decoder = '{}/decoder_{}'.format(args.ckpt, 'epoch_last.pth')
        file_path_optimizers = '{}/optimizers_{}'.format(
            args.ckpt, 'epoch_last.pth')

        if os.path.isfile(file_path_history):
            print("=> loading checkpoint '{}'".format(file_path_history))
            checkpoint_history = torch.load(file_path_history)
            checkpoint_encoder = torch.load(file_path_encoder)
            checkpoint_decoder = torch.load(file_path_decoder)
            checkpoint_optimizers = torch.load(file_path_optimizers)

            args.start_epoch = int(checkpoint_history['train']['epoch'][0]) + 1

            nets[0].load_state_dict(checkpoint_encoder)
            nets[1].load_state_dict(checkpoint_decoder)

            optimizers[0].load_state_dict(
                checkpoint_optimizers['encoder_optimizer'])
            optimizers[1].load_state_dict(
                checkpoint_optimizers['decoder_optimizer'])

            print("=> start train epoch {}".format(args.start_epoch))
        else:
            print('=> no epoch-last checkpoint found, training from scratch')

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(args.start_epoch, args.num_epoch + 1):
        train(segmentation_module, iterator_train, optimizers, history, epoch,
              args)

        # checkpointing
        checkpoint(nets, optimizers, history, args, epoch)

    print('Training Done!')
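
Unlike the other examples, this variant's checkpoint() also receives the optimizers, matching the resume branch above that loads optimizers_epoch_last.pth. A sketch of the save side implied by those file names (hypothetical; the real checkpoint() is not shown):

import torch

def checkpoint(nets, optimizers, history, args, epoch):
    net_encoder, net_decoder, crit = nets
    torch.save(history, '{}/history_epoch_last.pth'.format(args.ckpt))
    torch.save(net_encoder.state_dict(),
               '{}/encoder_epoch_last.pth'.format(args.ckpt))
    torch.save(net_decoder.state_dict(),
               '{}/decoder_epoch_last.pth'.format(args.ckpt))
    torch.save(
        {
            'encoder_optimizer': optimizers[0].state_dict(),
            'decoder_optimizer': optimizers[1].state_dict()
        }, '{}/optimizers_epoch_last.pth'.format(args.ckpt))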