Example #1
0
File: eval.py Project: dingmyu/psa
def main():
    """Evaluate a stored PSPNet checkpoint on the list ``args.val_list1``.

    Parses the command line, builds a centre-crop evaluation loader, creates
    the network, restores the weights found at ``args.model_path`` and hands
    everything to ``validate``.

    Side effects:
        Sets the module-level globals ``args`` and ``logger``.

    Raises:
        AssertionError: if the parsed arguments fail the sanity checks.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.split in ['train', 'val', 'test']
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel statistics rescaled from [0, 1] to the 0-255 value range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    # No Normalize step here: mean/std are forwarded to validate() instead
    # (presumably normalisation happens there -- confirm against validate()).
    val_transform = transforms.Compose([
        transforms.Crop([args.crop_h, args.crop_w], crop_type='center',
                        padding=mean, ignore_label=args.ignore_label),
        transforms.ToTensor()])
    # NOTE(review): split is hard-coded to 'train' even though args.split is
    # validated above -- confirm this is intended.
    val_data1 = datasets.SegData(split='train', data_root=args.data_root,
                                 data_list=args.val_list1,
                                 transform=val_transform)
    val_loader1 = torch.utils.data.DataLoader(
        val_data1, batch_size=1, shuffle=False, num_workers=args.workers,
        pin_memory=True)

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone, layers=args.layers,
                   classes=args.classes, zoom_factor=args.zoom_factor,
                   use_softmax=False, use_aux=False, pretrained=False,
                   syncbn=False).cuda()
    logger.info(model)

    cudnn.enabled = True
    cudnn.benchmark = True

    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    # The model is not wrapped here, so drop the 'module.' fragment from the
    # stored keys (presumably left by a parallel wrapper at training time).
    pretrained_dict = {k.replace('module.', ''): v
                       for k, v in checkpoint['state_dict'].items()}
    # strict=False: keys that do not match the model are silently ignored.
    model.load_state_dict(pretrained_dict, strict=False)

    cv2.setNumThreads(0)
    validate(val_loader1, val_data1.data_list, model, args.classes, mean, std,
             args.base_size1, args.crop_h, args.crop_w, args.scales)
Example #2
0
def main():
    """Train a PSPNet backbone together with a ``PPM`` head across processes.

    Parses the command line, initialises the distributed environment, builds
    both networks and a single SGD optimizer covering them, optionally loads
    weights / resumes, then runs the epoch loop. Rank 0 writes the scalars
    and the periodic checkpoints.

    Side effects:
        Sets the module-level globals ``args``, ``logger`` and ``writer``;
        writes checkpoint files and summary events under ``args.save_path``.

    Raises:
        AssertionError: if ``world_size`` is not a multiple of
            ``args.bn_group`` when ``args.bn_group != 1``.
    """
    global args, logger, writer
    args = get_parser().parse_args()
    import multiprocessing as mp
    # Switch to the 'spawn' start method before the data loaders are built.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    rank, world_size = dist_init(args.port)
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    if rank == 0:
        logger.info(args)

    if args.bn_group == 1:
        args.bn_group_comm = None
    else:
        assert world_size % args.bn_group == 0
        args.bn_group_comm = simple_group_split(world_size, rank, world_size // args.bn_group)

    if rank == 0:
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone, layers=args.layers, classes=args.classes,
                   zoom_factor=args.zoom_factor, syncbn=args.syncbn,
                   group_size=args.bn_group, group=args.bn_group_comm).cuda()
    logger.info(model)
    model_ppm = PPM().cuda()

    # Backbone stages use the base learning rate; the two PPM output
    # branches are trained with a 10x learning rate. Group order is kept
    # fixed so that saved optimizer state stays loadable.
    backbone_stages = [model.layer0, model.layer1, model.layer2, model.layer3,
                       model.layer4_ICR, model.layer4_PFR, model.layer4_PRP]
    param_groups = [{'params': stage.parameters()} for stage in backbone_stages]
    param_groups.append({'params': model_ppm.cls_trans.parameters(), 'lr': args.base_lr * 10})
    param_groups.append({'params': model_ppm.cls_quat.parameters(), 'lr': args.base_lr * 10})
    optimizer = torch.optim.SGD(param_groups, lr=args.base_lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)

    model = DistModule(model)
    model_ppm = DistModule(model_ppm)
    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.L1Loss().cuda()

    if args.weight:
        def map_func(storage, location):
            # Place every loaded storage on the current CUDA device.
            return storage.cuda()
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight, map_location=map_func)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        load_state(args.resume, model, model_ppm, optimizer)

    # Per-channel statistics rescaled from [0, 1] to the 0-255 value range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transforms.Compose([
        transforms.Resize(size=(256,256)),
        transforms.Crop([args.crop_h, args.crop_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
        transforms.ColorJitter([0.4,0.4,0.4]),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])

    train_data = datasets.SegData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    train_sampler = DistributedSampler(train_data)
    # shuffle stays False because the sampler decides the iteration order.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    if args.evaluate:
        val_transform = transforms.Compose([
            transforms.Resize(size=(256,256)),
            transforms.Crop([args.crop_h, args.crop_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])
        val_data = datasets.SegData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
        val_sampler = DistributedSampler(val_data)
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs + 1):
        t_loss_train, r_loss_train = train(train_loader, model, model_ppm, criterion, optimizer, epoch, args.zoom_factor, args.batch_size, args.aux_weight)
        if rank == 0:
            writer.add_scalar('t_loss_train', t_loss_train, epoch)
            writer.add_scalar('r_loss_train', r_loss_train, epoch)

        if epoch % args.save_step == 0 and rank == 0:
            stem = args.save_path + '/train_epoch_' + str(epoch)
            filename = stem + '.pth'
            filename_ppm = stem + '_ppm.pth'
            logger.info('Saving checkpoint to: ' + filename)
            # Both files carry the same optimizer state; only the weights differ.
            torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)
            torch.save({'epoch': epoch, 'state_dict': model_ppm.state_dict(), 'optimizer': optimizer.state_dict()}, filename_ppm)

        if args.evaluate:
            # Every rank runs validation, but only rank 0 records the result,
            # matching how the training scalars are written above.
            t_loss_val, r_loss_val = validate(val_loader, model, model_ppm, criterion)
            if rank == 0:
                writer.add_scalar('t_loss_val', t_loss_val, epoch)
                writer.add_scalar('r_loss_val', r_loss_val, epoch)
    writer.close()
def main():
    """Train one of the PSPNet variants selected by ``args.net_type``.

    Parses the command line, initialises the distributed environment, builds
    the network (``pspnet`` for type 0, ``pspnet_div4`` for types 1-3) and its
    SGD optimizer, loads the auxiliary ``V11RFCN`` network, optionally loads
    weights / resumes, then runs the epoch loop. Rank 0 writes the scalars
    and the periodic checkpoints.

    Side effects:
        Sets the module-level globals ``args``, ``logger`` and ``writer``;
        reads ``checkpoint_e8.pth`` from the working directory; writes
        checkpoint files and summary events under ``args.save_path``.

    Raises:
        AssertionError: if the parsed arguments fail the sanity checks.
    """
    global args, logger, writer
    args = get_parser().parse_args()
    import multiprocessing as mp
    # Switch to the 'spawn' start method before the data loaders are built.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    rank, world_size = dist_init(args.port)
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    if rank == 0:
        logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h-1) % 8 == 0 and (args.crop_w-1) % 8 == 0
    assert args.net_type in [0, 1, 2, 3]

    if args.bn_group == 1:
        args.bn_group_comm = None
    else:
        assert world_size % args.bn_group == 0
        args.bn_group_comm = simple_group_split(world_size, rank, world_size // args.bn_group)

    if rank == 0:
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))

    if args.net_type == 0:
        from pspnet import PSPNet
        model = PSPNet(backbone=args.backbone, layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, syncbn=args.syncbn,
                       group_size=args.bn_group, group=args.bn_group_comm).cuda()
    elif args.net_type in [1, 2, 3]:
        from pspnet_div4 import PSPNet
        model = PSPNet(backbone=args.backbone, layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, syncbn=args.syncbn,
                       group_size=args.bn_group, group=args.bn_group_comm,
                       net_type=args.net_type).cuda()
    logger.info(model)

    # Sub-modules trained at the base learning rate, followed by the
    # sub-modules trained at 10x. The group order is significant (saved
    # optimizer state is positional), so it is spelled out per net type.
    base_lr_modules = ['layer0', 'layer1', 'layer2', 'layer3', 'layer4']
    if args.net_type == 0:
        boosted_lr_modules = ['ppm', 'conv6', 'conv1_1x1', 'cls', 'aux']
    else:
        base_lr_modules.append('layer4_p')
        boosted_lr_modules = {
            1: ['ppm', 'ppm_p', 'cls', 'cls_p', 'aux'],
            2: ['ppm', 'ppm_p', 'cls', 'cls_p', 'att', 'att_p', 'aux'],
            3: ['ppm', 'ppm_p', 'cls', 'cls_p', 'att', 'aux'],
        }[args.net_type]
    param_groups = [{'params': getattr(model, name).parameters()}
                    for name in base_lr_modules]
    param_groups += [{'params': getattr(model, name).parameters(), 'lr': args.base_lr * 10}
                     for name in boosted_lr_modules]
    optimizer = torch.optim.SGD(param_groups, lr=args.base_lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)

    # Auxiliary network handed to train(); only the checkpoint entries whose
    # names exist in the freshly built network are copied in.
    # NOTE(review): the checkpoint path is hard-coded relative to the CWD.
    fcw = V11RFCN()
    fcw_model = torch.load('checkpoint_e8.pth')['state_dict']
    fcw_dict = fcw.state_dict()
    pretrained_fcw = {k: v for k, v in fcw_model.items() if k in fcw_dict}
    fcw_dict.update(pretrained_fcw)
    fcw.load_state_dict(fcw_dict)
    fcw = fcw.cuda()

    model = DistModule(model)

    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.NLLLoss(ignore_index=args.ignore_label).cuda()

    if args.weight:
        def map_func(storage, location):
            # Place every loaded storage on the current CUDA device.
            return storage.cuda()
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight, map_location=map_func)['state_dict']
            # Entries whose key contains 'ppm' are not taken from the file;
            # the model keeps its own values for them.
            checkpoint = {k: v for k, v in checkpoint.items() if 'ppm' not in k}
            model_dict = model.state_dict()
            model_dict.update(checkpoint)
            model.load_state_dict(model_dict)
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        load_state(args.resume, model, optimizer)

    # Per-channel statistics rescaled from [0, 1] to the 0-255 value range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transforms.Compose([
        transforms.RandScale([args.scale_min, args.scale_max]),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.crop_h, args.crop_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])
    train_data = datasets.SegData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    train_sampler = DistributedSampler(train_data)
    # shuffle stays False because the sampler decides the iteration order.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    if args.evaluate:
        val_transform = transforms.Compose([
            transforms.Crop([args.crop_h, args.crop_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])
        val_data = datasets.SegData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
        # No sampler: every rank iterates the complete validation list.
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs + 1):
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, criterion, optimizer, epoch, args.zoom_factor, args.batch_size, args.aux_weight, fcw)
        if rank == 0:
            writer.add_scalar('loss_train', loss_train, epoch)
            writer.add_scalar('mIoU_train', mIoU_train, epoch)
            writer.add_scalar('mAcc_train', mAcc_train, epoch)
            writer.add_scalar('allAcc_train', allAcc_train, epoch)

        if epoch % args.save_step == 0 and rank == 0:
            filename = args.save_path + '/train_epoch_' + str(epoch) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)

        if args.evaluate:
            # Every rank runs validation, but only rank 0 records the result,
            # matching how the training scalars are written above.
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion, args.classes, args.zoom_factor)
            if rank == 0:
                writer.add_scalar('loss_val', loss_val, epoch)
                writer.add_scalar('mIoU_val', mIoU_val, epoch)
                writer.add_scalar('mAcc_val', mAcc_val, epoch)
                writer.add_scalar('allAcc_val', allAcc_val, epoch)
    # Flush pending summary events before the process exits.
    writer.close()
Example #4
0
def main():
    """Train PSPNet, optionally across several processes (``args.dist``).

    Parses the command line, initialises the process group when requested,
    builds the network and its SGD optimizer, optionally loads weights /
    resumes, then runs the epoch loop. Scalars, validation and checkpoints
    are handled by rank 0 (the only rank when ``args.dist`` is false).

    Side effects:
        Sets the module-level globals ``args``, ``logger`` and ``writer``;
        writes checkpoint files and summary events under ``args.save_path``.

    Raises:
        AssertionError: if the parsed arguments fail the sanity checks.
    """
    global args, logger, writer
    args = get_parser().parse_args()
    logger = get_logger()
    writer = SummaryWriter(args.save_path)

    if args.dist:
        dist_init(args.port, backend=args.backend)

    # This value is what gets logged just below; args.syncbn is overwritten
    # further down from args.bn_group.
    if len(args.gpu) == 1:
        args.syncbn = False
    logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0

    # Single-process defaults, replaced by the real values when distributed.
    world_size = 1
    rank = 0
    if args.dist:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    if rank == 0:
        logger.info('dist:{}'.format(args.dist))
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))

    if args.bn_group > 1:
        args.syncbn = True
        bn_sync_stats = True
        bn_group_comm = simple_group_split(world_size, rank,
                                           world_size // args.bn_group)
    else:
        args.syncbn = False
        bn_sync_stats = False
        bn_group_comm = None

    model = PSPNet(layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   syncbn=args.syncbn,
                   group_size=args.bn_group,
                   group=bn_group_comm,
                   sync_stats=bn_sync_stats)
    if rank == 0:
        logger.info(model)

    # Backbone stages use the base learning rate; ppm / cls / aux are
    # trained with a 10x learning rate. Group order is kept fixed so that
    # saved optimizer state stays loadable.
    backbone_stages = [model.layer0, model.layer1, model.layer2,
                       model.layer3, model.layer4]
    boosted_stages = [model.ppm, model.cls, model.aux]
    param_groups = [{'params': stage.parameters()} for stage in backbone_stages]
    param_groups += [{'params': stage.parameters(), 'lr': args.base_lr * 10}
                     for stage in boosted_stages]
    optimizer = torch.optim.SGD(param_groups,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    model = model.cuda()

    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.NLLLoss(ignore_index=args.ignore_label)

    if args.weight:

        def map_func(storage, location):
            # Place every loaded storage on the current CUDA device.
            return storage.cuda()

        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight, map_location=map_func)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            model, optimizer, args.start_epoch = restore_from(
                model, optimizer, args.resume)

            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    if args.dist:
        broadcast_params(model)

    # Per-channel statistics rescaled from [0, 1] to the 0-255 value range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transforms.Compose([
        transforms.RandScale([args.scale_min, args.scale_max]),
        transforms.RandRotate([args.rotate_min, args.rotate_max],
                              padding=mean,
                              ignore_label=args.ignore_label),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.crop_h, args.crop_w],
                        crop_type='rand',
                        padding=mean,
                        ignore_label=args.ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    train_data = datasets.SegData(split='train',
                                  data_root=args.data_root,
                                  data_list=args.train_list,
                                  transform=train_transform)
    train_sampler = None
    if args.dist:
        train_sampler = DistributedSampler(train_data)

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        # DataLoader rejects shuffle=True together with a sampler, so the
        # decision is tied to the sampler's presence, not its truthiness
        # (a sampler that defines __len__ is falsy when it is empty).
        shuffle=train_sampler is None,
        num_workers=args.workers,
        pin_memory=False,
        sampler=train_sampler)

    if args.evaluate:
        val_transform = transforms.Compose([
            transforms.Crop([args.crop_h, args.crop_w],
                            crop_type='center',
                            padding=mean,
                            ignore_label=args.ignore_label),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        val_data = datasets.SegData(split='val',
                                    data_root=args.data_root,
                                    data_list=args.val_list,
                                    transform=val_transform)
        val_sampler = None
        if args.dist:
            val_sampler = DistributedSampler(val_data)
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs + 1):
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, criterion, optimizer, epoch, args.zoom_factor,
            args.batch_size, args.aux_weight)
        # Only rank 0 records scalars, like validation and checkpointing below.
        if rank == 0:
            writer.add_scalar('loss_train', loss_train.cpu().numpy(), epoch)
            writer.add_scalar('mIoU_train', mIoU_train, epoch)
            writer.add_scalar('mAcc_train', mAcc_train, epoch)
            writer.add_scalar('allAcc_train', allAcc_train, epoch)

        # NOTE(review): in distributed mode val_loader uses a
        # DistributedSampler but only rank 0 validates -- presumably that
        # covers just rank 0's shard of the validation list; confirm.
        if args.evaluate and rank == 0:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion, args.classes, args.zoom_factor)
            writer.add_scalar('loss_val', loss_val.cpu().numpy(), epoch)
            writer.add_scalar('mIoU_val', mIoU_val, epoch)
            writer.add_scalar('mAcc_val', mAcc_val, epoch)
            writer.add_scalar('allAcc_val', allAcc_val, epoch)

        if epoch % args.save_step == 0 and (rank == 0):
            filename = args.save_path + '/train_epoch_' + str(epoch) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            # Copy the tensors to the CPU for saving instead of calling
            # model.cpu(): Module.cpu() moves the live model in place, which
            # would leave it on the CPU for every following epoch. Values
            # are replaced inside the returned mapping so any attributes
            # attached to it are kept.
            cpu_state = model.state_dict()
            for key, tensor in list(cpu_state.items()):
                cpu_state[key] = tensor.cpu()
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': cpu_state,
                    'optimizer': optimizer.state_dict()
                }, filename)
    # Flush pending summary events before the process exits.
    writer.close()
Example #5
0
def main():
    """Evaluate a PSPNet checkpoint together with its PFR and PRP companions.

    Parses the command line, builds the evaluation loader for
    ``args.split`` / ``args.val_list1``, creates the three DataParallel
    networks, restores their weights and calls ``validate``. The companion
    checkpoints are looked up next to ``args.model_path`` by swapping
    ``.pth`` for ``_pfr.pth`` and ``_prp.pth``.

    Side effects:
        Sets the module-level globals ``args`` and ``logger``.

    Raises:
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    logger.info("=> creating model ...")

    # Per-channel statistics rescaled from [0, 1] to the 0-255 value range.
    scale = 255
    mean = [channel * scale for channel in [0.485, 0.456, 0.406]]
    std = [channel * scale for channel in [0.229, 0.224, 0.225]]

    pipeline = [
        transforms.Resize(size=(256, 256)),
        transforms.Crop([args.crop_h, args.crop_w], crop_type='center',
                        padding=mean, ignore_label=args.ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    val_transform = transforms.Compose(pipeline)

    val_data1 = datasets.SegData(
        split=args.split, data_root=args.data_root,
        data_list=args.val_list1, transform=val_transform)
    val_loader1 = torch.utils.data.DataLoader(
        val_data1, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # Companion networks, each moved to the GPU and then wrapped.
    model_pfr = torch.nn.DataParallel(PFR().cuda())
    model_prp = torch.nn.DataParallel(PRP().cuda())

    from pspnet import PSPNet
    net_kwargs = dict(backbone=args.backbone, layers=args.layers,
                      classes=args.classes, zoom_factor=args.zoom_factor,
                      use_softmax=True, pretrained=False, syncbn=False)
    model = PSPNet(**net_kwargs).cuda()
    logger.info(model)
    model = torch.nn.DataParallel(model).cuda()

    cudnn.enabled = True
    cudnn.benchmark = True

    # Fail early when the main checkpoint is missing.
    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    main_state = torch.load(args.model_path)
    # strict=False: keys that do not match the model are silently ignored.
    model.load_state_dict(main_state['state_dict'], strict=False)
    logger.info("=> loaded checkpoint '{}'".format(args.model_path))

    # Read both companion files first, then apply them in the same order.
    companion_states = [
        torch.load(args.model_path.replace('.pth', suffix))
        for suffix in ('_pfr.pth', '_prp.pth')
    ]
    for companion, state in zip((model_pfr, model_prp), companion_states):
        companion.load_state_dict(state['state_dict'], strict=False)

    cv2.setNumThreads(0)

    validate(val_loader1, val_data1.data_list, model, model_pfr, model_prp)
Example #6
0
def main():
    """Create the model and start the training.

    Reads the module-level ``args`` namespace, overlays it with the
    ``common`` section of the YAML file named by ``args.config``, and then:

    * builds the validation loader and loads the ground-truth label list,
    * builds encoder/decoder, wraps them in ``SegmentationModule`` and
      creates the optimizers,
    * builds the augmented source (GTA5) and target loaders,
    * runs ``args.num_steps`` iterations, each consisting of one supervised
      optimisation step on a source batch followed by a gradient-free
      forward pass on a target batch,
    * logs every ``args.print_freq`` iterations and, on logging iterations
      that are also a non-zero multiple of ``args.save_pred_every``,
      validates and saves encoder/decoder checkpoints.

    Returns nothing; all results are written through the logger, the
    TensorBoard writer and ``save_checkpoint``.
    """

    def parse_hw(text):
        # Size options are "a,b" strings; exactly two integers are required.
        first, second = map(int, text.split(','))
        return (first, second)

    with open(args.config) as f:
        # safe_load: the config is only used as a plain mapping, and calling
        # yaml.load without a Loader is unsafe / rejected by recent PyYAML.
        config = yaml.safe_load(f)
    for k, v in config['common'].items():
        setattr(args, k, v)

    log_dir = "logs/" + args.exp_name
    mkdirs(log_dir)

    logger = create_logger('global_logger', log_dir + '/log.txt')
    logger.info('{}'.format(args))

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))
    logger.info("random_scale {}".format(args.random_scale))
    logger.info("is_training {}".format(args.is_training))

    input_size = parse_hw(args.input_size)
    input_size_target = parse_hw(args.input_size_target)
    cudnn.enabled = True
    args.snapshot_dir = args.snapshot_dir + args.exp_name
    tb_logger = SummaryWriter(log_dir)

    # Validation-related sizes.
    input_size_test = parse_hw(args.input_size_test)
    com_size = parse_hw(args.com_size)
    input_size_crop = parse_hw(args.input_size_crop)
    input_size_target_crop = parse_hw(args.input_size_target_crop)

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    ####
    # validation data
    ####
    test_transform = transforms.Compose([
        transforms.Resize((input_size_test[1], input_size_test[0])),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    valloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target,
                          args.data_list_target_val,
                          crop_size=input_size_test,
                          set='train',
                          transform=test_transform),
        num_workers=args.num_workers,
        batch_size=1,
        shuffle=False,
        pin_memory=True)

    with open('./dataset/cityscapes_list/info.json', 'r') as fp:
        info = json.load(fp)
    # np.int / np.str were removed from NumPy; the builtins are equivalent.
    mapping = np.array(info['label2train'], dtype=int)
    name_classes = np.array(info['label'], dtype=str)

    with open(args.label_path_list_val, 'r') as fp:
        gt_imgs_val = fp.read().splitlines()
    gt_imgs_val = [osp.join(args.data_dir_target_val, x) for x in gt_imgs_val]

    # Upsamples predictions to the comparison size before the argmax.
    interp_val = nn.Upsample(size=(com_size[1], com_size[0]),
                             mode='bilinear',
                             align_corners=True)

    ####
    # build model
    ####
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(
        arch=args.arch_encoder,
        fc_dim=args.fc_dim,
        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(
        arch=args.arch_decoder,
        fc_dim=args.fc_dim,
        num_class=args.num_classes,
        weights=args.weights_decoder,
        use_aux=True)

    model = SegmentationModule(net_encoder, net_decoder, args.use_aux)

    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model)
        patch_replication_callback(model)
    model.cuda()

    nets = (net_encoder, net_decoder, None, None)
    # Only the encoder/decoder optimizers are used in this script.
    optimizer_encoder, optimizer_decoder, _, _ = create_optimizer(nets, args)
    cudnn.benchmark = True
    model.train()

    ####
    # training data
    ####
    normalize = transforms_seg.Normalize(mean=mean, std=std)
    # Crop padding is given in 0-255 pixel units.
    mean_mapping = [item * 255 for item in mean]

    os.makedirs(args.snapshot_dir, exist_ok=True)

    source_transform = transforms_seg.Compose([
        transforms_seg.Resize([input_size[1], input_size[0]]),
        segtransforms.RandScale((args.scale_min, args.scale_max)),
        segtransforms.RandomHorizontalFlip(),
        segtransforms.Crop([input_size_crop[1], input_size_crop[0]],
                           crop_type='rand',
                           padding=mean_mapping,
                           ignore_label=args.ignore_label),
        transforms_seg.ToTensor(),
        normalize,
    ])
    target_transform = transforms_seg.Compose([
        transforms_seg.Resize([input_size_target[1], input_size_target[0]]),
        segtransforms.RandScale((args.scale_min, args.scale_max)),
        segtransforms.RandomHorizontalFlip(),
        segtransforms.Crop(
            [input_size_target_crop[1], input_size_target_crop[0]],
            crop_type='rand',
            padding=mean_mapping,
            ignore_label=args.ignore_label),
        transforms_seg.ToTensor(),
        normalize,
    ])

    max_iters = args.num_steps * args.iter_size * args.batch_size
    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir,
                    args.data_list,
                    max_iters=max_iters,
                    crop_size=input_size,
                    transform=source_transform),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=1,
        pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        fake_cityscapesDataSet(args.data_dir_target,
                               args.data_list_target,
                               max_iters=max_iters,
                               crop_size=input_size_target,
                               set=args.set,
                               transform=target_transform),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=1,
        pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    batch_time = AverageMeter(10)
    loss_seg_value1 = AverageMeter(10)
    loss_seg_value2 = AverageMeter(10)
    best_mIoUs = 0

    for i_iter in range(args.num_steps):
        end = time.time()

        # ---- supervised step on a source batch ----
        _, batch = next(trainloader_iter)
        images, labels, _ = batch
        # `async=` is a syntax error on Python >= 3.7; `non_blocking=` is the
        # replacement keyword for the same behaviour.
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        seg, aux_seg, loss_seg2, loss_seg1 = model(images, labels)

        loss_seg2 = torch.mean(loss_seg2)
        loss_seg1 = torch.mean(loss_seg1)
        loss = loss_seg2 + args.lambda_seg * loss_seg1
        # Track both terms; loss_seg1 is logged below and was previously
        # never fed into its meter.
        loss_seg_value1.update(loss_seg1.data.cpu().numpy())
        loss_seg_value2.update(loss_seg2.data.cpu().numpy())

        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        del seg, aux_seg, loss_seg2

        # ---- forward-only pass on a target batch ----
        # The output is discarded and no gradients are recorded.
        # NOTE(review): presumably this exists to refresh normalisation
        # statistics on target data -- confirm.
        _, batch = next(targetloader_iter)
        with torch.no_grad():
            images, _, _ = batch
            images = images.cuda(non_blocking=True)
            model(images, None)

        batch_time.update(time.time() - end)

        remain_iter = args.num_steps - i_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        adjust_learning_rate(optimizer_encoder, i_iter, args.lr_encoder, args)
        adjust_learning_rate(optimizer_decoder, i_iter, args.lr_decoder, args)

        if i_iter % args.print_freq == 0:
            lr_encoder = optimizer_encoder.param_groups[0]['lr']
            lr_decoder = optimizer_decoder.param_groups[0]['lr']
            logger.info('exp = {}'.format(args.snapshot_dir))
            logger.info('Iter = [{0}/{1}]\t'
                        'Time = {batch_time.avg:.3f}\t'
                        'loss_seg1 = {loss_seg1.avg:.4f}\t'
                        'loss_seg2 = {loss_seg2.avg:.4f}\t'
                        'lr_encoder = {lr_encoder:.8f} lr_decoder = {lr_decoder:.8f}'.format(
                            i_iter, args.num_steps, batch_time=batch_time,
                            loss_seg1=loss_seg_value1, loss_seg2=loss_seg_value2,
                            lr_encoder=lr_encoder,
                            lr_decoder=lr_decoder))

            logger.info("remain_time: {}".format(remain_time))
            if tb_logger is not None:
                tb_logger.add_scalar('loss_seg_value1', loss_seg_value1.avg, i_iter)
                tb_logger.add_scalar('loss_seg_value2', loss_seg_value2.avg, i_iter)
                tb_logger.add_scalar('lr', lr_encoder, i_iter)

            # Validation is nested under the logging branch, so it only runs
            # on iterations that are multiples of both frequencies.
            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                logger.info('taking snapshot ...')
                model.eval()

                # NOTE(review): 19 is hard-coded here while the per-class
                # loop below uses args.num_classes -- presumably equal;
                # confirm.
                hist = np.zeros((19, 19))
                for index, batch in tqdm(enumerate(valloader)):
                    with torch.no_grad():
                        image, name = batch
                        output2, _ = model(image.cuda(), None)
                        pred = interp_val(output2)
                        del output2
                        # batch_size is 1: take the single sample, move the
                        # class axis last and pick the arg-max class.
                        pred = pred.cpu().data[0].numpy()
                        pred = pred.transpose(1, 2, 0)
                        pred = np.asarray(np.argmax(pred, axis=2), dtype=np.uint8)
                        # Ground truth is matched to predictions by position
                        # in the (unshuffled) validation list.
                        label = np.array(Image.open(gt_imgs_val[index]))
                        label = label_mapping(label, mapping)
                        hist += fast_hist(label.flatten(), pred.flatten(), 19)

                class_ious = per_class_iu(hist)
                for ind_class in range(args.num_classes):
                    logger.info('===>' + name_classes[ind_class] + ':\t'
                                + str(round(class_ious[ind_class] * 100, 2)))
                    tb_logger.add_scalar(name_classes[ind_class] + '_mIoU',
                                         class_ious[ind_class], i_iter)

                mean_iou = round(np.nanmean(class_ious) * 100, 2)
                is_best_test = mean_iou >= best_mIoUs
                if is_best_test:
                    best_mIoUs = mean_iou

                logger.info("current mIoU {}".format(mean_iou))
                logger.info("best mIoU {}".format(best_mIoUs))
                tb_logger.add_scalar('val mIoU', mean_iou, i_iter)
                save_checkpoint(net_encoder, 'encoder', i_iter, args, is_best_test)
                save_checkpoint(net_decoder, 'decoder', i_iter, args, is_best_test)

                # Back to training mode after validation.
                model.train()
Example #7
0
def main():
    """Create the model and start the training.

    Reads the module-level ``args`` namespace, overlays it with the
    ``common`` section of the YAML file named by ``args.config``, and then:

    * builds the validation loader and loads the ground-truth label list,
    * builds encoder/decoder, wraps them in ``SegmentationModule`` and
      creates the optimizers,
    * loads class weights from ``weighted_loss.txt`` and normalises them to
      sum to one,
    * builds the source (GTA5) and target loaders,
    * runs ``args.num_steps`` iterations, each with two optimisation steps:
      a supervised step on a source batch, then a self-training step on a
      target batch combining a confidence-masked pseudo-label loss, a
      class-balance loss, and pseudo-label losses on average-pooled
      predictions (one per entry of ``args.box_size``),
    * logs every ``args.print_freq`` iterations and, on logging iterations
      that are also a non-zero multiple of ``args.save_pred_every``,
      validates and saves encoder/decoder checkpoints.

    Returns nothing; all results are written through the logger, the
    TensorBoard writer and ``save_checkpoint``.
    """

    def parse_hw(text):
        # Size options are "a,b" strings; exactly two integers are required.
        first, second = map(int, text.split(','))
        return (first, second)

    with open(args.config) as f:
        # safe_load: the config is only used as a plain mapping, and calling
        # yaml.load without a Loader is unsafe / rejected by recent PyYAML.
        config = yaml.safe_load(f)
    for k, v in config['common'].items():
        setattr(args, k, v)

    log_dir = "logs/" + args.exp_name
    mkdirs(log_dir)

    logger = create_logger('global_logger', log_dir + '/log.txt')
    logger.info('{}'.format(args))

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))
    logger.info("random_scale {}".format(args.random_scale))
    logger.info("is_training {}".format(args.is_training))

    input_size = parse_hw(args.input_size)
    input_size_target = parse_hw(args.input_size_target)
    cudnn.enabled = True
    args.snapshot_dir = args.snapshot_dir + args.exp_name
    tb_logger = SummaryWriter(log_dir)

    # Validation-related sizes.
    input_size_test = parse_hw(args.input_size_test)
    com_size = parse_hw(args.com_size)
    input_size_crop = parse_hw(args.input_size_crop)
    input_size_target_crop = parse_hw(args.input_size_target_crop)

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    normalize_module = transforms_seg.Normalize(mean=mean, std=std)

    ####
    # validation data
    ####
    test_transform = transforms.Compose([
        transforms.Resize((input_size_test[1], input_size_test[0])),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    valloader = data.DataLoader(cityscapesDataSet(args.data_dir_target,
                                                  args.data_list_target_val,
                                                  crop_size=input_size_test,
                                                  set='train',
                                                  transform=test_transform),
                                num_workers=args.num_workers,
                                batch_size=1,
                                shuffle=False,
                                pin_memory=True)

    with open('./dataset/cityscapes_list/info.json', 'r') as fp:
        info = json.load(fp)
    # np.int / np.str were removed from NumPy; the builtins are equivalent.
    mapping = np.array(info['label2train'], dtype=int)
    name_classes = np.array(info['label'], dtype=str)

    with open(args.label_path_list_val, 'r') as fp:
        gt_imgs_val = fp.read().splitlines()
    gt_imgs_val = [osp.join(args.data_dir_target_val, x) for x in gt_imgs_val]

    # Upsamples predictions to the comparison size before the argmax.
    interp_val = nn.Upsample(size=(com_size[1], com_size[0]),
                             mode='bilinear',
                             align_corners=True)

    ####
    # build model
    ####
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                        fc_dim=args.fc_dim,
                                        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(arch=args.arch_decoder,
                                        fc_dim=args.fc_dim,
                                        num_class=args.num_classes,
                                        weights=args.weights_decoder,
                                        use_aux=True)

    # Target class distribution for the balance loss, normalised to sum to 1.
    weighted_softmax = pd.read_csv("weighted_loss.txt", header=None)
    weighted_softmax = torch.from_numpy(weighted_softmax.values)
    weighted_softmax = weighted_softmax / torch.sum(weighted_softmax)
    weighted_softmax = weighted_softmax.cuda().float()

    model = SegmentationModule(net_encoder, net_decoder, args.use_aux)

    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model)
        patch_replication_callback(model)
    model.cuda()

    nets = (net_encoder, net_decoder, None, None)
    # Only the encoder/decoder optimizers are used in this script.
    optimizer_encoder, optimizer_decoder, _, _ = create_optimizer(nets, args)
    cudnn.benchmark = True
    model.train()

    ####
    # training data
    ####
    # Crop padding is given in 0-255 pixel units.
    mean_mapping = [item * 255 for item in mean]

    os.makedirs(args.snapshot_dir, exist_ok=True)

    source_transform = transforms_seg.Compose([
        transforms_seg.Resize([input_size[1], input_size[0]]),
        segtransforms.Crop([input_size_crop[1], input_size_crop[0]],
                           crop_type='rand',
                           padding=mean_mapping,
                           ignore_label=args.ignore_label),
        transforms_seg.ToTensor(),
        normalize_module,
    ])

    target_transform = transforms_seg.Compose([
        transforms_seg.Resize([input_size_target[1], input_size_target[0]]),
        segtransforms.Crop(
            [input_size_target_crop[1], input_size_target_crop[0]],
            crop_type='rand',
            padding=mean_mapping,
            ignore_label=args.ignore_label),
        transforms_seg.ToTensor(),
        normalize_module,
    ])

    max_iters = args.num_steps * args.iter_size * args.batch_size
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=max_iters,
                                              crop_size=input_size,
                                              transform=source_transform),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=5,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        fake_cityscapesDataSet(args.data_dir_target,
                               args.data_list_target,
                               max_iters=max_iters,
                               crop_size=input_size_target,
                               set=args.set,
                               transform=target_transform),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=5,
        pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # Per-pixel loss (no reduction) so it can be masked before averaging.
    # `reduction='none'` is the non-deprecated spelling of `reduce=False`.
    criterion_seg = torch.nn.CrossEntropyLoss(ignore_index=255,
                                              reduction='none')

    batch_time = AverageMeter(10)
    loss_seg_value1 = AverageMeter(10)
    loss_seg_value2 = AverageMeter(10)
    # The two reconstruction meters are never updated in this function; they
    # are kept only because the log line below still prints them.
    loss_reconst_source_value = AverageMeter(10)
    loss_reconst_target_value = AverageMeter(10)
    loss_balance_value = AverageMeter(10)
    loss_pseudo_value = AverageMeter(10)
    bounding_num = AverageMeter(10)
    pseudo_num = AverageMeter(10)
    loss_bbx_att_value = AverageMeter(10)
    best_test_mIoUs = 0

    # Denominator for the "fraction of pixels kept" statistics.
    # NOTE(review): 560 * 480 is hard-coded; presumably the target crop
    # area -- confirm against input_size_target_crop.
    pixels_per_batch = float(560 * 480 * args.batch_size)

    for i_iter in range(args.num_steps):
        end = time.time()

        # ---- supervised step on a source batch ----
        _, batch = next(trainloader_iter)
        images, labels, _ = batch
        # `async=` is a syntax error on Python >= 3.7; `non_blocking=` is the
        # replacement keyword for the same behaviour.
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        results = model(images, labels)
        # The last two outputs are the two segmentation loss terms.
        loss_seg2 = torch.mean(results[-2])
        loss_seg1 = torch.mean(results[-1])
        loss = args.lambda_trade_off * (loss_seg2 +
                                        args.lambda_seg * loss_seg1)

        # Track both terms; loss_seg1 is logged below and was previously
        # never fed into its meter.
        loss_seg_value1.update(loss_seg1.data.cpu().numpy())
        loss_seg_value2.update(loss_seg2.data.cpu().numpy())

        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()
        del loss
        del results, loss_seg2, loss_seg1

        # ---- self-training step on a target batch ----
        _, batch = next(targetloader_iter)
        images, fake_labels, _ = batch
        images = images.cuda()
        fake_labels = fake_labels.cuda()
        results = model(images, None)
        target_seg = results[0]

        # dim=1 is what the deprecated implicit-dim softmax picked for this
        # 4-D input; it is the axis torch.max reduces over just below.
        target_prob = nn.functional.softmax(target_seg, dim=1)
        conf_tea, pseudo_label = torch.max(target_prob, dim=1)
        pseudo_label = pseudo_label.detach()

        # Hard pseudo-label loss, kept only where the prediction is
        # confident and the provided label is not the ignore value 255.
        loss_pseudo = criterion_seg(target_seg, pseudo_label)
        fake_mask = (fake_labels != 255).float().detach()
        conf_mask = torch.gt(conf_tea, args.conf_threshold).float().detach()
        loss_pseudo = (loss_pseudo * conf_mask * fake_mask).view(-1)
        # Drop masked-out (zero) entries so the mean is over kept pixels.
        loss_pseudo = loss_pseudo[loss_pseudo != 0]

        # Class-balance loss: mean predicted class distribution over the
        # batch and both spatial axes versus the loaded class weights.
        predict_class_mean = torch.mean(target_prob, dim=0).mean(1).mean(1)
        equalise_cls_loss = robust_binary_crossentropy(predict_class_mean,
                                                       weighted_softmax)
        equalise_cls_loss = torch.mean(equalise_cls_loss)

        # Pseudo-label loss on average-pooled predictions, one per box size.
        box_losses = []
        for box_size in args.box_size:
            pooling = torch.nn.AvgPool2d(box_size)
            pooling_result_i = pooling(target_seg)
            pooling_conf_mask, pooling_pseudo = torch.max(
                nn.functional.softmax(pooling_result_i, dim=1), dim=1)
            pooling_conf_mask = torch.gt(pooling_conf_mask,
                                         args.conf_threshold).float().detach()
            fake_mask_i = pooling(fake_labels.unsqueeze(1).float())
            fake_mask_i = fake_mask_i.squeeze(1)
            fake_mask_i = (fake_mask_i != 255).float().detach()
            loss_bbx_att_i = criterion_seg(pooling_result_i, pooling_pseudo)
            loss_bbx_att_i = loss_bbx_att_i * pooling_conf_mask * fake_mask_i
            loss_bbx_att_i = loss_bbx_att_i.view(-1)
            loss_bbx_att_i = loss_bbx_att_i[loss_bbx_att_i != 0]
            box_losses.append(loss_bbx_att_i)
            # Released inside the loop: a `del` after the loop raised
            # NameError whenever args.box_size was empty.
            del pooling_result_i

        # None means "no box loss this iteration" (args.box_size is empty).
        loss_bbx_att = None
        if len(args.box_size) > 0:
            if args.merge_1x1:
                # Fold the per-pixel pseudo loss into the box loss instead
                # of adding it as a separate term below.
                box_losses.append(loss_pseudo)
            loss_bbx_att = torch.cat(box_losses, dim=0)
            bounding_num.update(loss_bbx_att.size(0) / pixels_per_batch)
            loss_bbx_att = torch.mean(loss_bbx_att)

        pseudo_num.update(loss_pseudo.size(0) / pixels_per_batch)
        loss_pseudo = torch.mean(loss_pseudo)

        loss = args.lambda_balance * equalise_cls_loss
        # NOTE(review): with merge_1x1 set and an empty args.box_size the
        # pseudo loss ends up in neither term (behaviour kept as-is) --
        # confirm whether that is intended.
        if not args.merge_1x1:
            loss = loss + args.lambda_pseudo * loss_pseudo
        if loss_bbx_att is not None:
            loss = loss + args.lambda_pseudo * loss_bbx_att
            # Logged below; previously this meter was never updated.
            loss_bbx_att_value.update(loss_bbx_att.item())
        loss_pseudo_value.update(loss_pseudo.item())
        loss_balance_value.update(equalise_cls_loss.item())

        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        batch_time.update(time.time() - end)

        remain_iter = args.num_steps - i_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        adjust_learning_rate(optimizer_encoder, i_iter, args.lr_encoder, args)
        adjust_learning_rate(optimizer_decoder, i_iter, args.lr_decoder, args)

        if i_iter % args.print_freq == 0:
            lr_encoder = optimizer_encoder.param_groups[0]['lr']
            lr_decoder = optimizer_decoder.param_groups[0]['lr']
            logger.info('exp = {}'.format(args.snapshot_dir))
            logger.info(
                'Iter = [{0}/{1}]\t'
                'Time = {batch_time.avg:.3f}\t'
                'loss_seg1 = {loss_seg1.avg:.4f}\t'
                'loss_seg2 = {loss_seg2.avg:.4f}\t'
                'loss_reconst_source = {loss_reconst_source.avg:.4f}\t'
                'loss_bbx_att = {loss_bbx_att.avg:.4f}\t'
                'loss_reconst_target = {loss_reconst_target.avg:.4f}\t'
                'loss_pseudo = {loss_pseudo.avg:.4f}\t'
                'loss_balance = {loss_balance.avg:.4f}\t'
                'bounding_num = {bounding_num.avg:.4f}\t'
                'pseudo_num = {pseudo_num.avg:.4f}\t'
                'lr_encoder = {lr_encoder:.8f} lr_decoder = {lr_decoder:.8f}'.
                format(i_iter,
                       args.num_steps,
                       batch_time=batch_time,
                       loss_seg1=loss_seg_value1,
                       loss_seg2=loss_seg_value2,
                       loss_pseudo=loss_pseudo_value,
                       loss_bbx_att=loss_bbx_att_value,
                       bounding_num=bounding_num,
                       pseudo_num=pseudo_num,
                       loss_reconst_source=loss_reconst_source_value,
                       loss_balance=loss_balance_value,
                       loss_reconst_target=loss_reconst_target_value,
                       lr_encoder=lr_encoder,
                       lr_decoder=lr_decoder))

            logger.info("remain_time: {}".format(remain_time))
            if tb_logger is not None:
                tb_logger.add_scalar('loss_seg_value1', loss_seg_value1.avg,
                                     i_iter)
                tb_logger.add_scalar('loss_seg_value2', loss_seg_value2.avg,
                                     i_iter)
                tb_logger.add_scalar('bounding_num', bounding_num.avg, i_iter)
                tb_logger.add_scalar('pseudo_num', pseudo_num.avg, i_iter)
                tb_logger.add_scalar('loss_pseudo', loss_pseudo_value.avg,
                                     i_iter)
                tb_logger.add_scalar('lr', lr_encoder, i_iter)
                tb_logger.add_scalar('loss_balance', loss_balance_value.avg,
                                     i_iter)

            # Validation is nested under the logging branch, so it only runs
            # on iterations that are multiples of both frequencies.
            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                logger.info('taking snapshot ...')
                model.eval()

                # NOTE(review): 19 is hard-coded here while the per-class
                # loop below uses args.num_classes -- presumably equal;
                # confirm.
                hist = np.zeros((19, 19))
                for index, batch in tqdm(enumerate(valloader)):
                    with torch.no_grad():
                        image, name = batch
                        results = model(image.cuda(), None)
                        output2 = results[0]
                        pred = interp_val(output2)
                        del output2
                        # batch_size is 1: take the single sample, move the
                        # class axis last and pick the arg-max class.
                        pred = pred.cpu().data[0].numpy()
                        pred = pred.transpose(1, 2, 0)
                        pred = np.asarray(np.argmax(pred, axis=2),
                                          dtype=np.uint8)
                        # Ground truth is matched to predictions by position
                        # in the (unshuffled) validation list.
                        label = np.array(Image.open(gt_imgs_val[index]))
                        label = label_mapping(label, mapping)
                        hist += fast_hist(label.flatten(), pred.flatten(), 19)

                class_ious = per_class_iu(hist)
                for ind_class in range(args.num_classes):
                    logger.info('===>' + name_classes[ind_class] + ':\t' +
                                str(round(class_ious[ind_class] * 100, 2)))
                    tb_logger.add_scalar(name_classes[ind_class] + '_mIoU',
                                         class_ious[ind_class], i_iter)

                mean_iou = round(np.nanmean(class_ious) * 100, 2)
                logger.info(mean_iou)
                tb_logger.add_scalar('test mIoU', mean_iou, i_iter)
                # Strict comparison: a tie does not count as a new best.
                is_best_test = mean_iou > best_test_mIoUs
                if is_best_test:
                    best_test_mIoUs = mean_iou
                logger.info("best test mIoU {}".format(best_test_mIoUs))
                save_checkpoint(net_encoder, 'encoder', i_iter, args,
                                is_best_test)
                save_checkpoint(net_decoder, 'decoder', i_iter, args,
                                is_best_test)

                # Back to training mode after validation.
                model.train()
def main():
    """Train the two-branch (RGB + flow) PSP2S model in a distributed setup.

    Steps, in order:
      1. Parse CLI arguments, force the 'spawn' start method and initialise
         the distributed environment via ``dist_init``.
      2. Build the PSP2S model and an SGD optimizer whose ``ppm`` and
         ``cls_ins`` layers use 10x the base learning rate.
      3. Optionally load pretrained weights for both branches, wrap the model
         in ``DistModule`` and optionally resume from a checkpoint.
      4. Build the patch dataset selected by ``args.dataset_name`` and, if
         ``args.evaluate`` is set, the validation loader.
      5. Run the epoch loop: train, log scalars, save checkpoints every
         ``args.save_step`` epochs (rank 0 only) and optionally validate.

    Sets the module-level globals ``args``, ``logger`` and ``writer``.

    Raises:
        AssertionError: if the basic argument sanity checks fail.
        NotImplementedError: if ``args.net_type`` selects a network that this
            function does not build (only ``0`` is implemented here).
        ValueError: if ``args.dataset_name`` is not one of the supported names.
    """
    global args, logger, writer
    args = get_parser().parse_args()
    import multiprocessing as mp
    # NOTE(review): 'spawn' is presumably needed so worker processes can use
    # CUDA safely - confirm against the data-loading setup.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    rank, world_size = dist_init(args.port)
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    if rank == 0:
        logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.net_type in [0, 1, 2, 3]

    # A bn_group of 1 means no cross-process BN group is created.
    if args.bn_group == 1:
        args.bn_group_comm = None
    else:
        assert world_size % args.bn_group == 0
        args.bn_group_comm = simple_group_split(world_size, rank, world_size // args.bn_group)

    if rank == 0:
        logger.info("=> creating two branch model ...")
        logger.info("Classes: {}".format(args.classes))

    if args.net_type == 0:
        from pspnet2S import PSP2S
        model = PSP2S(backbone=args.backbone, layers=args.layers, classes=args.classes,
                      zoom_factor=args.zoom_factor, syncbn=args.syncbn,
                      group_size=args.bn_group, group=args.bn_group_comm).cuda()
    else:
        # Previously this fell through and crashed with a NameError on
        # `model`; fail with an explicit message instead.
        raise NotImplementedError(
            "net_type {} is not implemented; only net_type 0 (PSP2S) is supported".format(args.net_type))
    logger.info(model)

    def branch_param_groups(branch, stem_name):
        """Return the optimizer parameter groups for one branch, in fixed order."""
        base_lr_layers = [stem_name, 'layer0_ins', 'layer0_conv1', 'layer1', 'layer2',
                          'layer3', 'layer4', 'conv6', 'conv1_1x1']
        groups = [{'params': getattr(branch, layer).parameters()} for layer in base_lr_layers]
        # newly introduced layers are trained with lr x10
        groups.extend({'params': getattr(branch, layer).parameters(), 'lr': args.base_lr * 10}
                      for layer in ('ppm', 'cls_ins'))
        return groups

    # Group order (rgb branch first, then flow branch) is kept stable so that
    # optimizer state saved by earlier runs still lines up on resume.
    optimizer = torch.optim.SGD(
        branch_param_groups(model.rgb_branch, 'layer0_img')
        + branch_param_groups(model.flow_branch, 'layer0_flo'),
        lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)

    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.NLLLoss(ignore_index=args.ignore_label).cuda()

    if args.weight:
        def map_func(storage, location):
            # Move every loaded storage onto the current CUDA device.
            return storage.cuda()

        if os.path.isfile(args.weight):
            # load rgb branch params.
            logger.info("=> loading rgb weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight, map_location=map_func)
            new_dict = remove_prefix(checkpoint['state_dict'], 'module.')
            model.rgb_branch.load_state_dict(new_dict, strict=True)
            logger.info("=> loaded rgb weight '{}'".format(args.weight))

            # load flow branch params.
            logger.info("=> loading flow weight '{}'".format(args.weight_flow))
            checkpoint_flow = torch.load(args.weight_flow, map_location=map_func)
            new_dict_flow = remove_prefix(checkpoint_flow['state_dict'], 'module.')
            model.flow_branch.load_state_dict(new_dict_flow, strict=True)
            logger.info("=> loaded flow weight '{}'".format(args.weight_flow))
        else:
            # Reported only when a weight path was given but does not exist.
            logger.info("=> no weight found at '{}'".format(args.weight))

    model = DistModule(model)

    if args.resume:
        load_state(args.resume, model, optimizer)

    # ImageNet-style mean/std rescaled to the 0-255 pixel value range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    if args.dataset_name == 'coco':
        logger.info("=> load coco patch datasets....")
        train_data = datasets.PatchDataSet(dataset_path='/mnt/lustre/share/sunpeng/voc_coco_label/coco/',
                                           data_list='train2017', norm=[mean, std])
    elif args.dataset_name == 'voc':
        logger.info("=> load voc patch datasets....")
        train_data = datasets.PatchDataSet(dataset_path='/mnt/lustre/share/sunpeng/voc_coco_label/voc/',
                                           data_list='train_val_ins', norm=[mean, std])
    elif args.dataset_name == 'youtube':
        logger.info("=> load youtube shuffle three pairs patch datasets....")
        train_data = datasets.YoutubePatchDataSet(dataset_path='', gt_path='',
                                                  data_list='/mnt/lustre/sunpeng/Research/video-seg-workshop/models/deeplabv3_models/final_2s_seg_youtube_online_two_finetune_first_frame/add_first_frame_final_0823.txt',
                                                  norm=[mean, std])
    elif args.dataset_name == 'youtube_and_davis':
        logger.info("=> load youtube and davis shuffle three pairs patch datasets....")
        train_data = datasets.YoutubePatchDataSet(dataset_path='', gt_path='',
                                                  data_list='/mnt/lustre/share/sunpeng/video-seg-workshop/davis/DAVIS/train_shuffle_davis.txt',
                                                  norm=[mean, std])
    else:
        logger.info("=> no right datset name found, please input coco or voc.")
        # An explicit raise still fires under `python -O`, unlike `assert 0 == 1`.
        raise ValueError("unsupported dataset_name: {!r}".format(args.dataset_name))

    # Shuffling is delegated to the sampler, hence shuffle=False on the loader.
    train_sampler = DistributedSampler(train_data)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False,
                                               num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    if args.evaluate:
        val_transform = transforms.Compose([
            transforms.Crop([args.crop_h, args.crop_w], crop_type='center', padding=mean,
                            ignore_label=args.ignore_label),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])
        val_data = datasets.SegData(split='val', data_root=args.data_root, data_list=args.val_list,
                                    transform=val_transform)
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs + 1):
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, criterion, optimizer, epoch,
                                                                 args.zoom_factor, args.batch_size, args.ins_weight)
        if rank == 0:
            writer.add_scalar('loss_train', loss_train, epoch)
            writer.add_scalar('mIoU_train', mIoU_train, epoch)
            writer.add_scalar('mAcc_train', mAcc_train, epoch)
            writer.add_scalar('allAcc_train', allAcc_train, epoch)
        # Parameter histograms are deliberately not written: too slow.
        if epoch % args.save_step == 0 and rank == 0:
            filename = args.save_path + '/train_epoch_' + str(epoch) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
                       filename)
        if args.evaluate:
            # validate() runs on every rank; only rank 0 writes the scalars,
            # matching how the training scalars are logged above.
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion, args.classes,
                                                                args.zoom_factor)
            if rank == 0:
                writer.add_scalar('loss_val', loss_val, epoch)
                writer.add_scalar('mIoU_val', mIoU_val, epoch)
                writer.add_scalar('mAcc_val', mAcc_val, epoch)
                writer.add_scalar('allAcc_val', allAcc_val, epoch)