# NOTE(review): fragment of a larger training script — the enclosing
# function's header (and the definitions of `model` and `args`) lie outside
# this chunk, so this span is documented in place rather than rewritten.
print(model)

    # TensorBoard writer for training-time logging.
    tblogger = SummaryWriter(args.tblog_dir)

    # VOC12 classification dataset with train-time augmentation: random
    # long-side resize into [448, 768], horizontal flip, color jitter,
    # conversion to ndarray, the model's own normalization, random crop,
    # HWC->CHW layout change, and finally conversion to a torch tensor.
    train_dataset = voc12.data.VOC12ClsDataset(
        args.train_list,
        voc12_root=args.voc12_root,
        transform=transforms.Compose([
            imutils.RandomResizeLong(448, 768),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.3,
                                   contrast=0.3,
                                   saturation=0.3,
                                   hue=0.1), np.asarray, model.normalize,
            imutils.RandomCrop(args.crop_size), imutils.HWC_to_CHW,
            torch.from_numpy
        ]))

    def worker_init_fn(worker_id):
        # Give each DataLoader worker a distinct, fixed numpy seed so the
        # per-worker augmentation stream is reproducible across runs.
        np.random.seed(1 + worker_id)

    train_data_loader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True,
                                   drop_last=True,
                                   worker_init_fn=worker_init_fn)
    # Total number of optimizer steps over the whole run — presumably consumed
    # by an LR schedule defined outside this fragment; TODO confirm.
    max_step = len(train_dataset) // args.batch_size * args.max_epoches
Example #2
0
def main():
    """Distributed training entry point (PSPNet on VOC12 classification data).

    Parses CLI args, initializes the distributed backend, builds the model
    and an SGD optimizer with a 10x learning rate on the newly introduced
    head layers, optionally restores pretrained weights or resumes a previous
    run, then executes the epoch loop — logging the training loss to
    TensorBoard and writing a checkpoint every ``args.save_step`` epochs,
    both on rank 0 only.
    """
    global args, logger, writer
    args = get_parser().parse_args()
    import multiprocessing as mp
    # CUDA tensors cannot be shared across fork()ed workers; force 'spawn'.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    rank, world_size = dist_init(args.port)
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    if rank == 0:
        logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    # PSPNet-style networks require crop sizes of the form 8n + 1.
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.net_type in [0, 1, 2, 3]

    # Partition the ranks into sync-BN communication groups when requested.
    if args.bn_group == 1:
        args.bn_group_comm = None
    else:
        assert world_size % args.bn_group == 0
        args.bn_group_comm = simple_group_split(world_size, rank,
                                                world_size // args.bn_group)

    if rank == 0:
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone,
                   layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   syncbn=args.syncbn,
                   group_size=args.bn_group,
                   group=args.bn_group_comm,
                   use_softmax=False,
                   use_aux=False).cuda()

    logger.info(model)
    # Backbone layers (layer0..layer4) train at the base LR; the newly
    # introduced head layers (ppm / cls / result) train at 10x the base LR.
    optimizer = torch.optim.SGD(
        [{
            'params': model.layer0.parameters()
        }, {
            'params': model.layer1.parameters()
        }, {
            'params': model.layer2.parameters()
        }, {
            'params': model.layer3.parameters()
        }, {
            'params': model.layer4.parameters()
        }, {
            'params': model.ppm.parameters(),
            'lr': args.base_lr * 10
        }, {
            'params': model.cls.parameters(),
            'lr': args.base_lr * 10
        }, {
            'params': model.result.parameters(),
            'lr': args.base_lr * 10
        }],
        lr=args.base_lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # Wrap for gradient synchronization across ranks.
    model = DistModule(model)
    cudnn.enabled = True
    cudnn.benchmark = True  # let cuDNN autotune kernels for the fixed crop size
    criterion = nn.NLLLoss(ignore_index=args.ignore_label).cuda()

    if args.weight:

        def map_func(storage, location):
            # Materialize checkpoint tensors directly on the current GPU.
            return storage.cuda()

        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight, map_location=map_func)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        # Restores both model weights and optimizer state.
        load_state(args.resume, model, optimizer)

    # NOTE(review): the original also computed 255-scaled ImageNet mean/std
    # here, but they were never used in this function; removed as dead code.
    normalize = Normalize()
    train_data = voc12.data.VOC12ClsDataset(
        args.train_list,
        voc12_root=args.voc12_root,
        transform=transforms.Compose([
            imutils.RandomResizeLong(400, 512),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.3,
                                   contrast=0.3,
                                   saturation=0.3,
                                   hue=0.1), np.asarray, normalize,
            imutils.RandomCrop(args.crop_size), imutils.HWC_to_CHW,
            torch.from_numpy
        ]))

    # The sampler partitions (and shuffles) the dataset across ranks, so the
    # loader itself must not shuffle.
    train_sampler = DistributedSampler(train_data)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               sampler=train_sampler)

    for epoch in range(args.start_epoch, args.epochs + 1):
        loss_train = train(train_loader, model, criterion, optimizer, epoch,
                           args.zoom_factor, args.batch_size, args.aux_weight)
        if rank == 0:
            writer.add_scalar('loss_train', loss_train, epoch)
        # Writing parameter histograms costs lots of time, so it is skipped.

        # Checkpoint periodically, from rank 0 only, so ranks don't race on
        # the same file.
        if epoch % args.save_step == 0 and rank == 0:
            filename = args.save_path + '/train_epoch_' + str(epoch) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
Example #3
0
    # NOTE(review): fragment — the enclosing function's header and the tail
    # of the DataLoader call below are outside this chunk, so this span is
    # documented in place rather than rewritten.
    print(vars(args))

    # Instantiate the network class named 'Net' from the module given by
    # args.network, via dynamic import.
    model = getattr(importlib.import_module(args.network), 'Net')()

    print(model)

    # Affinity training dataset over VOC12; label_la_dir / label_ha_dir are
    # presumably low-/high-confidence CRF label directories — TODO confirm
    # against voc12.data.VOC12AffDataset.
    train_dataset = voc12.data.VOC12AffDataset(
        args.train_list,
        label_la_dir=args.la_crf_dir,
        label_ha_dir=args.ha_crf_dir,
        voc12_root=args.voc12_root,
        cropsize=args.crop_size,
        radius=5,
        # Joint transforms presumably apply to image and labels together so
        # crop/flip stay spatially aligned — verify in the dataset class.
        joint_transform_list=[
            None, None,
            imutils.RandomCrop(args.crop_size),
            imutils.RandomHorizontalFlip()
        ],
        # Image-only transforms: photometric jitter, ndarray conversion,
        # the model's own normalization, then HWC->CHW.
        img_transform_list=[
            transforms.ColorJitter(brightness=0.3,
                                   contrast=0.3,
                                   saturation=0.3,
                                   hue=0.1), np.asarray, model.normalize,
            imutils.HWC_to_CHW
        ],
        # Labels are average-pooled 8x — presumably to match the network's
        # output stride; TODO confirm.
        label_transform_list=[None, None, None,
                              imutils.AvgPool2d(8)])

    train_data_loader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
Example #4
0
File: eval.py  Project: dingmyu/psa
def main():
    """Evaluation entry point: load a PSPNet checkpoint and run validation.

    Parses CLI args, builds the inference dataset/loader, constructs a
    single-GPU PSPNet, loads the checkpoint named by ``args.model_path``
    (stripping any ``module.`` prefix left by distributed/DataParallel
    training), and calls ``validate``.

    Raises:
        RuntimeError: if ``args.model_path`` does not name an existing file.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    # PSPNet-style networks require crop sizes of the form 8n + 1.
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.split in ['train', 'val', 'test']
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # ImageNet mean/std rescaled to the [0, 255] pixel range; forwarded to
    # validate() below (normalization happens there, not in the transform).
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    # NOTE(review): the original created an unused Normalize() instance here
    # (its Compose entry was commented out); removed as dead code.
    infer_dataset = voc12.data.VOC12ClsDataset(
        args.train_list,
        voc12_root=args.voc12_root,
        transform=transforms.Compose([
            np.asarray,
            imutils.RandomCrop(441),
            imutils.HWC_to_CHW
        ]))

    val_loader1 = DataLoader(infer_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=True)

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone,
                   layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   use_softmax=False,
                   use_aux=False,
                   pretrained=False,
                   syncbn=False).cuda()

    logger.info(model)
    model = model.cuda()
    cudnn.enabled = True
    cudnn.benchmark = True
    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        # Checkpoints saved from DataParallel/DistModule prefix every key
        # with 'module.'; strip it so keys match the bare (unwrapped) model.
        pretrained_dict = {
            k.replace('module.', ''): v
            for k, v in checkpoint['state_dict'].items()
        }
        # strict=False tolerates keys missing from either side (e.g. aux
        # heads that exist only at train time).
        model.load_state_dict(pretrained_dict, strict=False)
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))
    # Keep OpenCV from spawning its own threads inside DataLoader workers.
    cv2.setNumThreads(0)
    validate(val_loader1, model, args.classes, mean, std, args.base_size1)