def main(cfg, gpus):
    """Build the networks, optimizers and data loader, then run training.

    Every epoch runs one ``train`` pass, writes the "last" checkpoint and
    evaluates the model.

    Args:
        cfg: Experiment configuration; the DATASET, MODEL and TRAIN sections
            are read here.
        gpus: Sequence of GPU ids. Only its length is used (loader batch size
            and the ``len_gpus`` argument of the foveation builders).

    NOTE(review): another ``main`` is defined later in this file and rebinds
    the name -- confirm which one is meant to be the entry point.
    """
    # Label value excluded from the loss: 19 when the training list mentions
    # CITYSCAPE, -2 for everything else.
    ignored_label = 19 if 'CITYSCAPE' in cfg.DATASET.list_train else -2
    crit = nn.NLLLoss(ignore_index=ignored_label)

    # --- segmentation network ---
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder,
        dilate_rate=cfg.DATASET.segm_downsampling_rate,
    )
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder,
    )
    segmentation_module = SegmentationModule(
        net_encoder, net_decoder, crit, cfg)
    segmentation_module.cuda()
    nets = (net_encoder, net_decoder, crit)

    # --- optional foveation network ---
    if cfg.MODEL.foveation:
        gpu_count = len(gpus)
        net_foveater = ModelBuilder.build_foveater(
            in_channel=cfg.MODEL.in_dim,
            out_channel=len(cfg.MODEL.patch_bank),
            len_gpus=gpu_count,
            weights=cfg.MODEL.weights_foveater,
            cfg=cfg,
        )
        foveation_module = FovSegmentationModule(
            net_foveater, cfg, len_gpus=gpu_count)
        foveation_module.cuda()
        # The foveater is optimised and checkpointed with the other nets.
        nets = nets + (net_foveater,)

    optimizers = create_optimizers(nets, cfg)

    # --- data ---
    dataset_train = TrainDataset(
        cfg.DATASET.root_dataset, cfg.DATASET.list_train, cfg.DATASET)
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),  # the dataset is already pre-batched
        pin_memory=True,
    )
    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))
    # A single iterator is shared by all epochs.
    iterator_train = iter(loader_train)

    # --- main loop ---
    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        epoch_number = epoch + 1  # epochs are reported 1-based

        extra_train_args = {}
        if cfg.MODEL.foveation:
            extra_train_args = {
                'history': None,
                'foveation_module': foveation_module,
            }
        train(segmentation_module, iterator_train, optimizers, epoch_number,
              cfg, **extra_train_args)

        checkpoint_last(nets, cfg, epoch_number)

        # Evaluate during training; the results are unpacked (which checks
        # their arity) but not used afterwards.
        if cfg.MODEL.foveation:
            mean_iou, pixel_acc, fov_map, fov_scores = eval_during_train(cfg)
        else:
            mean_iou, pixel_acc = eval_during_train(cfg)

    print('Training Done!')
def _build_criterion(cfg):
    """Pick the training loss from the dataset root and ``cfg.TRAIN.loss_fun``.

    The dataset is recognised from ``cfg.DATASET.root_dataset``; within a
    dataset ``cfg.TRAIN.loss_fun`` selects the concrete loss.
    """
    root = cfg.DATASET.root_dataset

    # ADE20K always trains with NLL and never consults cfg.TRAIN.loss_fun.
    if 'ADE20K' in root:
        return nn.NLLLoss(ignore_index=-2)

    loss_name = cfg.TRAIN.loss_fun

    # NOTE(review): exact match on one machine's absolute path -- confirm this
    # is still how the Gleason2019 data is identified.
    if root == '/scratch0/chenjin/GLEASON2019_DATA/Data/':
        if loss_name == 'DiceLoss':
            return DiceLoss()
        if loss_name == 'FocalLoss':
            return FocalLoss()
        if loss_name == 'DiceCoeff':
            return DiceCoeff()
        if loss_name == 'NLLLoss':
            return nn.NLLLoss(ignore_index=-2)
        return OhemCrossEntropy(ignore_label=-1,
                                thres=0.9,
                                min_kept=100000,
                                weight=None)

    if 'CITYSCAPES' in root:
        if loss_name == 'NLLLoss':
            return nn.NLLLoss(ignore_index=19)
        # One weight per class (19 values), moved to the GPU.
        class_weights = torch.FloatTensor([
            0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
            0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
            1.0865, 1.1529, 1.0507
        ]).cuda()
        return OhemCrossEntropy(ignore_label=20,
                                thres=0.9,
                                min_kept=131072,
                                weight=class_weights)

    if 'DeepGlob' in root:
        if loss_name == 'FocalLoss':
            return FocalLoss(gamma=6, ignore_label=cfg.DATASET.ignore_index)
        if loss_name == 'OhemCrossEntropy':
            return OhemCrossEntropy(ignore_label=cfg.DATASET.ignore_index,
                                    thres=0.9,
                                    min_kept=131072)
        # Any other loss name falls through to the generic choice below.

    if loss_name == 'NLLLoss':
        return nn.NLLLoss(ignore_index=cfg.DATASET.ignore_index)
    return nn.CrossEntropyLoss(ignore_index=cfg.DATASET.ignore_index)


def _log_fov_gradients(writer, print_grad, step):
    """Write the gradient statistics found in ``print_grad`` to TensorBoard.

    ``print_grad`` is expected to map ``'layer{1,2,3}_grad'`` to a tensor or
    ``None``. Nothing is logged when it is ``None`` or itself a tensor, and a
    layer is skipped when it has no strictly positive entry.
    """
    if print_grad is None or isinstance(print_grad, torch.Tensor):
        return
    for layer in ('layer1', 'layer2', 'layer3'):
        grad = print_grad['{}_grad'.format(layer)]
        if grad is None:
            continue
        positive = grad[grad > 0]
        if positive.numel() == 0:
            continue
        writer.add_histogram(
            'Print non-zero gradient ({}) histogram'.format(layer), positive,
            step)
        writer.add_histogram('Print gradient ({}) histogram'.format(layer),
                             grad, step)
        writer.add_scalar(
            'Percentage none-zero gradients ({})'.format(layer),
            positive.numel() / grad.numel(), step)
        # Min-max normalise the [0][0] slice so it can be shown as an image.
        first_map = grad[0][0]
        lowest = first_map.min()
        writer.add_image(
            'Print_grad_Fov_softmax_{}(normalized_b0_p0)'.format(layer),
            (first_map - lowest) / (first_map.max() - lowest),
            step,
            dataformats='HW')


def _run_validation(cfg, gpus):
    """Evaluate and return ``(val_iou, val_dice, val_acc, fov_maps, fov_scores)``.

    ``val_dice`` is ``None`` when the evaluator does not report it (dice
    disabled, or the multi-process evaluator is used); ``fov_maps`` and
    ``fov_scores`` are ``None`` when foveation is disabled.
    """
    if cfg.VAL.multipro:
        outputs = eval_during_train_multipro(cfg, gpus)
        with_dice = False  # the multi-process evaluator is never unpacked with dice
    else:
        outputs = eval_during_train(cfg)
        with_dice = bool(cfg.VAL.dice)

    val_dice = fov_maps = fov_scores = None
    if cfg.MODEL.foveation:
        if with_dice:
            val_iou, val_dice, val_acc, fov_maps, fov_scores = outputs
        else:
            val_iou, val_acc, fov_maps, fov_scores = outputs
    elif with_dice:
        val_iou, val_dice, val_acc = outputs
    else:
        val_iou, val_acc = outputs
    return val_iou, val_dice, val_acc, fov_maps, fov_scores


def _log_patch_probabilities(writer, cfg, scores, epoch):
    """Log per-patch probability histograms (and text dumps) for one epoch.

    ``scores`` is indexed by patch along its first axis and is modified in
    place. The text files under ``Fov_probability_distribution`` are only
    written when ``cfg.VAL.all_F_Xlr_time`` is set.
    """
    step = epoch + 1
    dump_to_text = cfg.VAL.all_F_Xlr_time
    if dump_to_text:
        print('=============F_Xlr_score_flat_all================\n',
              scores.shape)
    for p in range(scores.shape[0]):
        # add_histogram has no range argument, so pin the first and last
        # entries to 0 and 1 to force the histogram to span [0, 1].
        scores[p][0] = 0
        scores[p][-1] = 1
        writer.add_histogram('Patch_{} probability histogram'.format(p),
                             scores[p], step)
        if not dump_to_text:
            continue
        counts, bin_edges = np.histogram(scores[p], bins=10, range=(0, 1))
        dump_path = os.path.join(cfg.DIR, 'Fov_probability_distribution',
                                 'patch_{}_distribution.txt'.format(p))
        with open(dump_path, 'a') as f:
            if epoch == 0:
                f.write('epoch/ bins: {}\n'.format(bin_edges))
            f.write('epoch {}: {}\n'.format(step, counts / sum(counts)))
    writer.add_histogram('Patch_All probability histogram', scores, step)


def _update_fov_time_series(cfg, fov_maps, series):
    """Append this evaluation's F_Xlr maps to ``series`` and save it to disk.

    ``series`` is ``None`` on the first evaluation, which starts the series;
    later evaluations are concatenated along axis 0. Returns the new series.
    """
    if cfg.VAL.all_F_Xlr_time:
        # fov_maps holds one (array, name) pair per validation item.
        if series is None:
            series = [entry[0] for entry in fov_maps]
        else:
            for idx, entry in enumerate(fov_maps):
                series[idx] = np.concatenate((series[idx], entry[0]), axis=0)
        os.makedirs(os.path.join(cfg.DIR, "F_Xlr_time_all_vals"),
                    exist_ok=True)
        for idx, entry in enumerate(fov_maps):
            print('F_Xlr_time_{}'.format(entry[1]), series[idx].shape)
            np.save(
                '{}/F_Xlr_time_all_vals/F_Xlr_time_last_{}.npy'.format(
                    cfg.DIR, entry[1]), series[idx])
    else:
        if series is None:
            series = fov_maps
        else:
            series = np.concatenate((series, fov_maps), axis=0)
        print('F_Xlr_time', series.shape)
        np.save('{}/F_Xlr_time_last.npy'.format(cfg.DIR), series)
    return series


def main(cfg, gpus):
    """Train the segmentation network, optionally with a foveation module.

    Builds the networks, loss, data loader and optimizers, then for every
    epoch trains, checkpoints, evaluates every ``cfg.TRAIN.eval_per_epoch``
    epochs, and records the metrics to TensorBoard and the history file.

    Args:
        cfg: Experiment configuration (DIR, DATASET, MODEL, TRAIN and VAL
            sections are read).
        gpus: Sequence of GPU ids; more than one enables data parallelism.
    """
    # Network Builders
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder,
        dilate_rate=cfg.DATASET.segm_downsampling_rate)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)
    if cfg.MODEL.foveation:
        net_foveater = ModelBuilder.build_foveater(
            in_channel=cfg.MODEL.in_dim,
            out_channel=len(cfg.MODEL.patch_bank),
            len_gpus=len(gpus),
            weights=cfg.MODEL.weights_foveater,
            cfg=cfg)

    # TensorBoard writer
    writer = SummaryWriter('{}/tensorboard'.format(cfg.DIR))

    crit = _build_criterion(cfg)

    # Deep-supervision decoders take an extra loss scale.
    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, cfg,
                                                 cfg.TRAIN.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, cfg)

    if cfg.MODEL.foveation:
        foveation_module = FovSegmentationModule(net_foveater,
                                                 cfg,
                                                 len_gpus=len(gpus))
        total_fov = sum(
            param.nelement() for param in foveation_module.parameters())
        print('Number of FoveationModule params: %.2fM \n' % (total_fov / 1e6))

    total = sum(
        param.nelement() for param in segmentation_module.parameters())
    print('Number of SegmentationModule params: %.2fM \n' % (total / 1e6))

    # Dataset and Loader
    dataset_train = TrainDataset(cfg.DATASET.root_dataset,
                                 cfg.DATASET.list_train,
                                 cfg.DATASET,
                                 batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    # create loader iterator (shared by all epochs)
    iterator_train = iter(loader_train)

    # load nets into gpu
    if len(gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
        if cfg.MODEL.foveation:
            foveation_module = UserScatteredDataParallel(foveation_module,
                                                         device_ids=gpus)
            patch_replication_callback(foveation_module)

    segmentation_module.cuda()
    if cfg.MODEL.foveation:
        foveation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    if cfg.MODEL.foveation:
        nets = (net_encoder, net_decoder, crit, net_foveater)
    optimizers = create_optimizers(nets, cfg)

    # 'train' is handed to train(); 'save' holds one row per epoch.
    history = {
        'train': {
            'epoch': [],
            'loss': [],
            'acc': []
        },
        'save': {
            'epoch': [],
            'train_loss': [],
            'train_acc': [],
            'val_iou': [],
            'val_dice': [],
            'val_acc': [],
            'print_grad': None
        }
    }

    # Resuming: reload the per-epoch rows written by the previous run.
    if cfg.TRAIN.start_epoch > 0:
        history_previous_epoches = pd.read_csv(
            '{}/history_epoch_{}.csv'.format(cfg.DIR, cfg.TRAIN.start_epoch))
        # NOTE(review): 'val_dice' is not restored, so with cfg.VAL.dice that
        # column ends up shorter than the others after a resume -- confirm
        # checkpoint_history copes with that.
        for column in ('epoch', 'train_loss', 'train_acc', 'val_iou',
                       'val_acc'):
            history['save'][column] = list(history_previous_epoches[column])

    # Start every run with empty per-patch distribution files.
    os.makedirs(os.path.join(cfg.DIR, "Fov_probability_distribution"),
                exist_ok=True)
    for p in range(len(cfg.MODEL.patch_bank)):
        with open(
                os.path.join(cfg.DIR, 'Fov_probability_distribution',
                             'patch_{}_distribution.txt'.format(p)), 'w'):
            pass

    # F_Xlr maps accumulated over evaluations; None until the first one, so
    # the series starts correctly even when the first epoch is not evaluated.
    fov_series = None

    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        step = epoch + 1  # epochs are reported 1-based

        if cfg.MODEL.foveation:
            train(segmentation_module, iterator_train, optimizers, step, cfg,
                  history, foveation_module)
            # 'print_grad' is only present if train() stored it.
            _log_fov_gradients(writer, history['train'].get('print_grad'),
                               step)
        else:
            train(segmentation_module, iterator_train, optimizers, step, cfg,
                  history)

        # checkpointing: periodic snapshot plus the always-updated "last" one
        if step % cfg.TRAIN.checkpoint_per_epoch == 0:
            checkpoint(nets, cfg, step)
        checkpoint_last(nets, cfg, step)

        evaluated = step % cfg.TRAIN.eval_per_epoch == 0
        if evaluated:
            val_iou, val_dice, val_acc, fov_maps, fov_scores = _run_validation(
                cfg, gpus)
        else:
            # Epochs without evaluation get empty placeholders in the history.
            val_iou = val_acc = ''
            val_dice = fov_maps = fov_scores = None

        saved = history['save']
        saved['epoch'].append(step)
        saved['train_loss'].append(history['train']['loss'][-1])
        saved['train_acc'].append(history['train']['acc'][-1] * 100)
        saved['val_iou'].append(val_iou)
        if cfg.VAL.dice:
            saved['val_dice'].append('' if val_dice is None else val_dice)
        saved['val_acc'].append(val_acc)

        # write to tensorboard
        writer.add_scalar('Loss/train', history['train']['loss'][-1], step)
        writer.add_scalar('Acc/train', history['train']['acc'][-1] * 100, step)
        if evaluated:
            writer.add_scalar('Acc/val', val_acc, step)
            writer.add_scalar('mIoU/val', val_iou, step)
            if cfg.VAL.dice and val_dice is not None:
                writer.add_scalar('mDice/val', val_dice, step)
            if cfg.MODEL.foveation:
                _log_patch_probabilities(writer, cfg, fov_scores, epoch)

        # saving history
        checkpoint_history(history, cfg, step)

        # save time series F_Xlr (t,b,d,w,h)
        if evaluated and cfg.MODEL.foveation:
            fov_series = _update_fov_time_series(cfg, fov_maps, fov_series)

    if not cfg.TRAIN.save_checkpoint:
        os.remove('{}/encoder_epoch_last.pth'.format(cfg.DIR))
        os.remove('{}/decoder_epoch_last.pth'.format(cfg.DIR))
    print('Training Done!')
    writer.close()