Beispiel #1
0
def main(class_num, pre_trained, train_data, batch_size, momentum, lr, cate_weight, epoch, weights):
    """Train a PSPNet segmentation model with SGD and save its weights.

    Args:
        class_num: Number of classes, forwarded to ``PSPNet(num_classes=...)``.
        pre_trained: Currently unused; the model is always built with
            ``pretrained=True``.  NOTE(review): confirm whether this flag was
            meant to be forwarded to ``PSPNet``.
        train_data: Dataset handed to the ``DataLoader``; it must yield
            ``(image, label)`` pairs.
        batch_size: Mini-batch size.
        momentum: SGD momentum.
        lr: SGD learning rate.
        cate_weight: Per-class weights for the cross-entropy loss.
        epoch: Number of passes over ``train_data``.
        weights: Prefix prepended to ``"PSPNet_weights.pth"`` to form the
            output path.  Plain string concatenation is used, so include any
            trailing path separator yourself.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = PSPNet(num_classes=class_num, downsample_factor=16, pretrained=True, aux_branch=False)
    model = model.to(device)

    train_loader = Data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # Keep the class weights on the same device as the model.  An unconditional
    # .cuda() here would crash on CPU-only machines even though `device`
    # already falls back to the CPU.
    class_weights = torch.from_numpy(np.array(cate_weight)).float()
    loss_func = nn.CrossEntropyLoss(weight=class_weights).to(device)

    model.train()
    for i in range(epoch):
        for step, (b_x, b_y) in enumerate(train_loader):
            b_x = b_x.to(device)
            # Labels are flattened to (N, 473, 473); presumably 473 is the
            # training crop size -- TODO confirm against the dataset.
            b_y = b_y.to(device).view(-1, 473, 473)
            output = model(b_x)
            # The loss is already on `device`, so no extra transfer is needed.
            loss = loss_func(output, b_y.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("Epoch:{0} || Step:{1} || Loss:{2}".format(i, step, format(loss.item(), ".4f")))

    torch.save(model.state_dict(), weights + "PSPNet_weights" + ".pth")
Beispiel #2
0
def main():
    """Evaluate a PSPNet checkpoint on the list given by ``args.val_list1``.

    Parses the command line into the module globals ``args`` and ``logger``,
    builds a centre-crop loader with batch size 1, restores the weights stored
    at ``args.model_path`` and hands everything to ``validate``.

    Raises:
        AssertionError: If a command-line option fails the sanity checks.
        RuntimeError: If ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.split in ['train', 'val', 'test']
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Channel statistics rescaled to the 0-255 value range; `mean` doubles as
    # the padding colour for the crop below.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    val_transform = transforms.Compose([
        transforms.Crop([args.crop_h, args.crop_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
        transforms.ToTensor()])
    # NOTE(review): the dataset is always opened with split='train' even though
    # args.split is validated above -- confirm this is intentional.
    val_data1 = datasets.SegData(split='train', data_root=args.data_root, data_list=args.val_list1, transform=val_transform)
    val_loader1 = torch.utils.data.DataLoader(val_data1, batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone, layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, use_softmax=False, use_aux=False, pretrained=False, syncbn=False).cuda()

    logger.info(model)
    cudnn.enabled = True
    cudnn.benchmark = True
    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        # Drop the 'module.' prefix from the checkpoint keys so they match this
        # unwrapped model; strict=False tolerates keys that still differ.
        pretrained_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
        model.load_state_dict(pretrained_dict, strict=False)
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
    cv2.setNumThreads(0)
    validate(val_loader1, val_data1.data_list, model, args.classes, mean, std, args.base_size1, args.crop_h, args.crop_w, args.scales)
Beispiel #3
0
def main():
    """Train the network selected by ``args.net`` in Paddle dygraph mode.

    Reads the module-level ``args`` for batch size, learning rate, epoch
    count, save frequency and checkpoint folder.  After each epoch the value
    returned by ``train`` is printed; model and optimizer states are saved
    every ``args.save_freq`` epochs and after the final epoch.

    Raises:
        NotImplementedError: If ``args.net`` is anything other than "basic".
    """
    # Step 0: preparation
    with fluid.dygraph.guard():
        # Step 1: Define training dataloader
        dataset = BasicDataLoader("work/dummy_data",
                                  "work/dummy_data/list.txt",
                                  transform=TrainAugmentation(224))
        loader = fluid.io.DataLoader.from_generator(capacity=2, return_list=True)
        loader.set_sample_generator(dataset, args.batch_size)
        batches_per_epoch = len(dataset) // args.batch_size

        # Step 2: Create model (only the "basic" network is available)
        if args.net != "basic":
            raise NotImplementedError(f"args.net: {args.net} is not Supported!")
        model = PSPNet()

        # Step 3: Define criterion and optimizer
        criterion = Basic_SegLoss
        optimizer = AdamOptimizer(learning_rate=args.lr,
                                  parameter_list=model.parameters())

        # Step 4: Training
        for epoch in range(1, args.num_epochs + 1):
            train_loss = train(loader, model, criterion, optimizer, epoch, batches_per_epoch)
            print(f"----- Epoch[{epoch}/{args.num_epochs}] Train Loss: {train_loss:.4f}")

            should_save = epoch % args.save_freq == 0 or epoch == args.num_epochs
            if not should_save:
                continue

            # Both states share one path prefix; the messages below name the
            # resulting .pdparams / .pdopt files.
            model_path = os.path.join(args.checkpoint_folder, f"{args.net}-Epoch-{epoch}")
            fluid.save_dygraph(model.state_dict(), model_path)
            fluid.save_dygraph(optimizer.state_dict(), model_path)

            print(f'----- Save model: {model_path}.pdparams')
            print(f'----- Save optimizer: {model_path}.pdopt')
Beispiel #4
0
def main():
    """Run distributed training of a PSPNet backbone plus a separate PPM head.

    Parses the command line into the module globals ``args``, ``logger`` and
    ``writer``, initialises the process group, builds both networks and a
    shared SGD optimizer, optionally restores weights or a full training
    state, and then trains for ``args.start_epoch .. args.epochs`` inclusive.
    Rank 0 writes TensorBoard scalars and saves a checkpoint for each network
    every ``args.save_step`` epochs.

    Raises:
        AssertionError: If ``world_size`` is not divisible by
            ``args.bn_group`` when grouped batch-norm is requested.
    """
    global args, logger, writer
    args = get_parser().parse_args()
    import multiprocessing as mp
    # Force the 'spawn' start method before any worker process is created.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    rank, world_size = dist_init(args.port)
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    if rank == 0:
        logger.info(args)

    if args.bn_group == 1:
        args.bn_group_comm = None
    else:
        assert world_size % args.bn_group == 0
        args.bn_group_comm = simple_group_split(world_size, rank, world_size // args.bn_group)

    if rank == 0:
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone, layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, syncbn=args.syncbn, group_size=args.bn_group, group=args.bn_group_comm).cuda()
    logger.info(model)
    model_ppm = PPM().cuda()
    # Backbone stages train at the base learning rate; the two PPM output
    # heads train at 10x the base learning rate.
    optimizer = torch.optim.SGD(
        [{'params': model.layer0.parameters()},
         {'params': model.layer1.parameters()},
         {'params': model.layer2.parameters()},
         {'params': model.layer3.parameters()},
         {'params': model.layer4_ICR.parameters()},
         {'params': model.layer4_PFR.parameters()},
         {'params': model.layer4_PRP.parameters()},
         {'params': model_ppm.cls_trans.parameters(), 'lr': args.base_lr * 10},
         {'params': model_ppm.cls_quat.parameters(), 'lr': args.base_lr * 10}
        ],
        lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # Wrap only after the optimizer has been built from the bare modules.
    model = DistModule(model)
    model_ppm = DistModule(model_ppm)
    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.L1Loss().cuda()

    if args.weight:
        def map_func(storage, location):
            # Load every tensor onto the current CUDA device.
            return storage.cuda()
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight, map_location=map_func)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        load_state(args.resume, model, model_ppm, optimizer)

    # Channel statistics rescaled to the 0-255 value range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.Crop([args.crop_h, args.crop_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
        transforms.ColorJitter([0.4, 0.4, 0.4]),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])

    train_data = datasets.SegData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    train_sampler = DistributedSampler(train_data)
    # shuffle stays False because the sampler decides the sample order.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    if args.evaluate:
        val_transform = transforms.Compose([
            transforms.Resize(size=(256, 256)),
            transforms.Crop([args.crop_h, args.crop_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])
        val_data = datasets.SegData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
        val_sampler = DistributedSampler(val_data)
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs + 1):
        # presumably t_* / r_* are translation and rotation losses -- TODO confirm
        t_loss_train, r_loss_train = train(train_loader, model, model_ppm, criterion, optimizer, epoch, args.zoom_factor, args.batch_size, args.aux_weight)
        if rank == 0:
            writer.add_scalar('t_loss_train', t_loss_train, epoch)
            writer.add_scalar('r_loss_train', r_loss_train, epoch)
        # write parameters histogram costs lots of time
        # for name, param in model.named_parameters():
        #     writer.add_histogram(name, param, epoch)

        if epoch % args.save_step == 0 and rank == 0:
            filename = args.save_path + '/train_epoch_' + str(epoch) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            filename_ppm = args.save_path + '/train_epoch_' + str(epoch) + '_ppm.pth'
            torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)
            torch.save({'epoch': epoch, 'state_dict': model_ppm.state_dict(), 'optimizer': optimizer.state_dict()}, filename_ppm)
        if args.evaluate:
            # Validation still runs on every rank; only rank 0 records the
            # scalars, matching how the training scalars are logged above.
            t_loss_val, r_loss_val = validate(val_loader, model, model_ppm, criterion)
            if rank == 0:
                writer.add_scalar('t_loss_val', t_loss_val, epoch)
                writer.add_scalar('r_loss_val', r_loss_val, epoch)
    writer.close()
def _build_param_groups(model, net_type, base_lr):
    """Return the SGD parameter groups of ``model`` for the given ``net_type``.

    Sub-modules named first in each layout entry carry no explicit learning
    rate, so they use the optimizer default; the ones named second get
    ``base_lr * 10``.  The group order is fixed per ``net_type`` because saved
    optimizer states are matched to parameter groups by position.
    """
    backbone = ('layer0', 'layer1', 'layer2', 'layer3', 'layer4')
    branch_backbone = backbone + ('layer4_p',)
    layout = {
        0: (backbone, ('ppm', 'conv6', 'conv1_1x1', 'cls', 'aux')),
        1: (branch_backbone, ('ppm', 'ppm_p', 'cls', 'cls_p', 'aux')),
        2: (branch_backbone, ('ppm', 'ppm_p', 'cls', 'cls_p', 'att', 'att_p', 'aux')),
        3: (branch_backbone, ('ppm', 'ppm_p', 'cls', 'cls_p', 'att', 'aux')),
    }
    base_names, boosted_names = layout[net_type]
    groups = [{'params': getattr(model, name).parameters()} for name in base_names]
    groups.extend({'params': getattr(model, name).parameters(), 'lr': base_lr * 10}
                  for name in boosted_names)
    return groups


def main():
    """Run distributed segmentation training for one of four network variants.

    Parses the command line into the module globals ``args``, ``logger`` and
    ``writer``, initialises the process group, builds the PSPNet variant chosen
    by ``args.net_type`` together with its SGD optimizer, loads the auxiliary
    ``V11RFCN`` network from ``checkpoint_e8.pth``, optionally restores weights
    or a full training state, and trains for
    ``args.start_epoch .. args.epochs`` inclusive.  Rank 0 writes TensorBoard
    scalars and saves a checkpoint every ``args.save_step`` epochs.

    Raises:
        AssertionError: If a command-line option fails the sanity checks.
    """
    global args, logger, writer
    args = get_parser().parse_args()
    import multiprocessing as mp
    # Force the 'spawn' start method before any worker process is created.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    rank, world_size = dist_init(args.port)
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    if rank == 0:
        logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.net_type in [0, 1, 2, 3]

    if args.bn_group == 1:
        args.bn_group_comm = None
    else:
        assert world_size % args.bn_group == 0
        args.bn_group_comm = simple_group_split(world_size, rank, world_size // args.bn_group)

    if rank == 0:
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))

    if args.net_type == 0:
        from pspnet import PSPNet
        model = PSPNet(backbone=args.backbone, layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, syncbn=args.syncbn, group_size=args.bn_group, group=args.bn_group_comm).cuda()
    elif args.net_type in [1, 2, 3]:
        from pspnet_div4 import PSPNet
        model = PSPNet(backbone=args.backbone, layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, syncbn=args.syncbn, group_size=args.bn_group, group=args.bn_group_comm, net_type=args.net_type).cuda()
    logger.info(model)

    # Backbone stages use the base learning rate; newly introduced layers use
    # 10x.  The per-variant layout lives in _build_param_groups.
    optimizer = torch.optim.SGD(
        _build_param_groups(model, args.net_type, args.base_lr),
        lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # Auxiliary network passed to train(); only checkpoint entries whose keys
    # exist in the fresh model are copied over.
    # NOTE(review): the checkpoint path is hard-coded relative to the working
    # directory -- confirm this is intended.
    fcw = V11RFCN()
    fcw_model = torch.load('checkpoint_e8.pth')['state_dict']
    fcw_dict = fcw.state_dict()
    pretrained_fcw = {k: v for k, v in fcw_model.items() if k in fcw_dict}
    fcw_dict.update(pretrained_fcw)
    fcw.load_state_dict(fcw_dict)
    fcw = fcw.cuda()

    # Wrap only after the optimizer has been built from the bare module.
    model = DistModule(model)

    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.NLLLoss(ignore_index=args.ignore_label).cuda()

    if args.weight:
        def map_func(storage, location):
            # Load every tensor onto the current CUDA device.
            return storage.cuda()
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight, map_location=map_func)['state_dict']
            # Entries whose key contains 'ppm' are skipped, so those layers
            # keep their freshly initialised values.
            checkpoint = {k: v for k, v in checkpoint.items() if 'ppm' not in k}
            model_dict = model.state_dict()
            model_dict.update(checkpoint)
            model.load_state_dict(model_dict)
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        load_state(args.resume, model, optimizer)

    # Channel statistics rescaled to the 0-255 value range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transforms.Compose([
        transforms.RandScale([args.scale_min, args.scale_max]),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.crop_h, args.crop_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])
    train_data = datasets.SegData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    train_sampler = DistributedSampler(train_data)
    # shuffle stays False because the sampler decides the sample order.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    if args.evaluate:
        val_transform = transforms.Compose([
            transforms.Crop([args.crop_h, args.crop_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])
        val_data = datasets.SegData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs + 1):
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, criterion, optimizer, epoch, args.zoom_factor, args.batch_size, args.aux_weight, fcw)
        if rank == 0:
            writer.add_scalar('loss_train', loss_train, epoch)
            writer.add_scalar('mIoU_train', mIoU_train, epoch)
            writer.add_scalar('mAcc_train', mAcc_train, epoch)
            writer.add_scalar('allAcc_train', allAcc_train, epoch)
        # write parameters histogram costs lots of time
        # for name, param in model.named_parameters():
        #     writer.add_histogram(name, param, epoch)

        if epoch % args.save_step == 0 and rank == 0:
            filename = args.save_path + '/train_epoch_' + str(epoch) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)
        if args.evaluate:
            # Validation still runs on every rank; only rank 0 records the
            # scalars, matching how the training scalars are logged above.
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion, args.classes, args.zoom_factor)
            if rank == 0:
                writer.add_scalar('loss_val', loss_val, epoch)
                writer.add_scalar('mIoU_val', mIoU_val, epoch)
                writer.add_scalar('mAcc_val', mAcc_val, epoch)
                writer.add_scalar('allAcc_val', allAcc_val, epoch)
Beispiel #6
0
                seg_preds_np = seg_preds.detach().cpu().numpy()
                seg_gts_np = seg_gts.cpu().numpy()

                confusion_matrix = get_confusion_matrix_for_3d(seg_gts_np, seg_preds_np, class_num=6)
                pos = confusion_matrix.sum(1)
                res = confusion_matrix.sum(0)
                tp = np.diag(confusion_matrix)
                IU_array = (tp / np.maximum(1.0, pos + res - tp))
                mean_IU = IU_array.mean()

                log_str = "[E{}/{} - {}] ".format(i, epoch, j)
                log_str += "loss[seg]: {:0.4f}, miou: {:0.4f}, ".format(loss_seg.item(), mean_IU)
                print (log_str)

                images_np = np.transpose((images.cpu().numpy()+1)*127.5, (0, 2, 3, 1))
                n, h, w, c = images_np.shape
                images_np = images_np.reshape(n*h, w, -1)[:, :, 0]
                seg_preds_np = seg_preds_np.reshape(n*h, w)
                visual_np = np.concatenate([images_np, seg_preds_np*40], axis=1)       # NH * W
                cv2.imwrite('visual.png', visual_np)
                epoch_iou.append(mean_IU)

        epoch_iou = np.mean(epoch_iou)
        epoch_end = time.time()
        epoch_time = round(epoch_end-epoch_start, 2)
        print ("=> This epoch costs {}s...".format(epoch_time))
        if i % 10 == 0 or i ==  epoch-1:
            print ("=> saving to {}".format("{}/epoch_{}_iou{:0.2f}.pth".format(ckpt_dir, i, epoch_iou)))
            torch.save(model.state_dict(), "{}/epoch_{}_iou{:0.2f}.pth".format(ckpt_dir, i, epoch_iou))

Beispiel #7
0
def main():
    """Train and validate a PSPNet on the PRCV data, tracking the best mIoU.

    Parses the command line into the module global ``args``, builds the
    train/validation loaders, initialises the model from
    ``args.restore_from``, optionally resumes from ``args.resume``, and then
    alternates one training epoch with one validation pass.  A checkpoint is
    saved after every epoch and the module global ``best_record`` is updated
    whenever the validation mIoU improves.
    """
    global args, best_record
    args = parser.parse_args()

    if args.augment:
        transform_train = joint_transforms.Compose([
            joint_transforms.FreeScale((512, 512)),
            joint_transforms.RandomHorizontallyFlip(),
            joint_transforms.RandomVerticallyFlip(),
            joint_transforms.Rotate(90),
        ])
        transform_val = joint_transforms.Compose(
            [joint_transforms.FreeScale((512, 512))])
    else:
        transform_train = None
        # Must be defined on this path too; otherwise building the validation
        # dataset below fails with NameError when augmentation is disabled.
        transform_val = None

    dataset_train = dataset.PRCVData('train', args.data_root,
                                     args.label_train_list, transform_train)
    dataloader_train = data.DataLoader(dataset_train,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=8)

    dataset_val = dataset.PRCVData('val', args.data_root, args.label_val_list,
                                   transform_val)
    dataloader_val = data.DataLoader(dataset_val,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=8)

    model = PSPNet(num_classes=args.num_class)

    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    saved_state_dict = torch.load(args.restore_from)
    new_params = model.state_dict().copy()
    # NOTE(review): when args.num_class == 21 nothing is copied from the saved
    # state dict, so the model keeps its initial weights -- confirm intended.
    if args.num_class != 21:
        for key in saved_state_dict:
            # Copy every entry except those of the top-level 'fc' module.
            if key.split('.')[0] != 'fc':
                new_params[key] = saved_state_dict[key]
    model.load_state_dict(new_params)

    model = model.cuda()
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255).cuda()
    optimizer = torch.optim.SGD(
        [
            {'params': get_1x_lr_params(model), 'lr': args.learning_rate},
            {'params': get_10x_lr_params(model), 'lr': 10 * args.learning_rate},
        ],
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(dataloader_train, model, criterion, optimizer, epoch)

        # evaluate on validation set
        acc, mean_iou, val_loss = validate(dataloader_val, model, criterion,
                                           args.result_pth, epoch)

        is_best = mean_iou > best_record['miou']
        if is_best:
            best_record['epoch'] = epoch
            best_record['val_loss'] = val_loss.avg
            best_record['acc'] = acc
            best_record['miou'] = mean_iou
        # The stored epoch is epoch + 1, which is what the resume branch above
        # reads back into args.start_epoch.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'val_loss': val_loss.avg,
                'accuracy': acc,
                'miou': mean_iou,
                'state_dict': model.state_dict(),
            }, is_best)

        print(
            '------------------------------------------------------------------------------------------------------'
        )
        print('[epoch: %d], [val_loss: %5f], [acc: %.5f], [miou: %.5f]' %
              (epoch, val_loss.avg, acc, mean_iou))
        print(
            'best record: [epoch: {epoch}], [val_loss: {val_loss:.5f}], [acc: {acc:.5f}], [miou: {miou:.5f}]'
            .format(**best_record))
        print(
            '------------------------------------------------------------------------------------------------------'
        )
Beispiel #8
0
def main():
    """Run distributed PSPNet training on the VOC12 classification dataset.

    Parses the command line into the module globals ``args``, ``logger`` and
    ``writer``, initialises the process group, builds the model and its SGD
    optimizer, optionally restores weights or a full training state, and
    trains for ``args.start_epoch .. args.epochs`` inclusive.  Rank 0 logs the
    training loss to TensorBoard and saves a checkpoint every
    ``args.save_step`` epochs.

    Raises:
        AssertionError: If a command-line option fails the sanity checks.
    """
    global args, logger, writer
    args = get_parser().parse_args()

    import multiprocessing as mp
    # Force the 'spawn' start method before any worker process is created.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)

    rank, world_size = dist_init(args.port)
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    is_master = rank == 0

    if is_master:
        logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in (1, 2, 4, 8)
    assert (args.crop_h - 1) % 8 == 0
    assert (args.crop_w - 1) % 8 == 0
    assert args.net_type in (0, 1, 2, 3)

    if args.bn_group != 1:
        assert world_size % args.bn_group == 0
        args.bn_group_comm = simple_group_split(world_size, rank,
                                                world_size // args.bn_group)
    else:
        args.bn_group_comm = None

    if is_master:
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone,
                   layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   syncbn=args.syncbn,
                   group_size=args.bn_group,
                   group=args.bn_group_comm,
                   use_softmax=False,
                   use_aux=False)
    model = model.cuda()
    logger.info(model)

    # Backbone stages use the base learning rate; the newly introduced layers
    # (ppm, cls, result) use 10x.
    backbone_stages = (model.layer0, model.layer1, model.layer2,
                       model.layer3, model.layer4)
    new_layers = (model.ppm, model.cls, model.result)
    param_groups = [{'params': stage.parameters()} for stage in backbone_stages]
    for layer in new_layers:
        param_groups.append({'params': layer.parameters(),
                             'lr': args.base_lr * 10})
    optimizer = torch.optim.SGD(param_groups,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Wrap only after the optimizer has been built from the bare module.
    model = DistModule(model)
    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.NLLLoss(ignore_index=args.ignore_label).cuda()

    if args.weight:

        def _to_cuda(storage, location):
            # Load every tensor onto the current CUDA device.
            return storage.cuda()

        if not os.path.isfile(args.weight):
            logger.info("=> no weight found at '{}'".format(args.weight))
        else:
            logger.info("=> loading weight '{}'".format(args.weight))
            state = torch.load(args.weight, map_location=_to_cuda)
            model.load_state_dict(state['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))

    if args.resume:
        load_state(args.resume, model, optimizer)

    # Channel statistics rescaled to the 0-255 value range (not referenced
    # again inside this function).
    value_scale = 255
    mean = [channel * value_scale for channel in (0.485, 0.456, 0.406)]
    std = [channel * value_scale for channel in (0.229, 0.224, 0.225)]

    normalize = Normalize()
    augmentations = transforms.Compose([
        imutils.RandomResizeLong(400, 512),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.3,
                               contrast=0.3,
                               saturation=0.3,
                               hue=0.1),
        np.asarray,
        normalize,
        imutils.RandomCrop(args.crop_size),
        imutils.HWC_to_CHW,
        torch.from_numpy,
    ])
    train_data = voc12.data.VOC12ClsDataset(args.train_list,
                                            voc12_root=args.voc12_root,
                                            transform=augmentations)

    train_sampler = DistributedSampler(train_data)
    # shuffle stays False because the sampler decides the sample order.
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               sampler=train_sampler)

    for epoch in range(args.start_epoch, args.epochs + 1):
        loss_train = train(train_loader, model, criterion, optimizer, epoch,
                           args.zoom_factor, args.batch_size, args.aux_weight)
        if is_master:
            writer.add_scalar('loss_train', loss_train, epoch)

        if not (epoch % args.save_step == 0 and is_master):
            continue

        ckpt_path = args.save_path + '/train_epoch_' + str(epoch) + '.pth'
        logger.info('Saving checkpoint to: ' + ckpt_path)
        snapshot = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(snapshot, ckpt_path)
Beispiel #9
0
def main():
    """Restore a PSPNet checkpoint and run ``validate`` on ``args.train_list``.

    Parses the command-line options into the module-level ``args``, builds a
    batch-size-1 loader over ``voc12.data.VOC12ClsDataset``, creates the model
    on the GPU, loads the weights stored at ``args.model_path`` and finally
    calls ``validate``.

    Raises:
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.split in ['train', 'val', 'test']
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel statistics multiplied by value_scale; handed to validate().
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    # NOTE(review): `normalize` is only referenced by the disabled transform
    # step below; it is kept so that step can be switched back on.
    normalize = Normalize()
    infer_dataset = voc12.data.VOC12ClsDataset(
        args.train_list,
        voc12_root=args.voc12_root,
        transform=transforms.Compose([
            np.asarray,
            imutils.RandomCrop(441),
            #      normalize,
            imutils.HWC_to_CHW
        ]))

    val_loader1 = DataLoader(infer_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=True)

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone,
                   layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   use_softmax=False,
                   use_aux=False,
                   pretrained=False,
                   syncbn=False).cuda()

    logger.info(model)
    cudnn.enabled = True
    cudnn.benchmark = True

    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))

    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)

    # Checkpoints saved from a DataParallel-wrapped model prefix every key
    # with "module.". Strip that prefix only at the start of the key: a blanket
    # str.replace would also mangle keys that merely contain "module."
    # (e.g. "submodule.weight"), and strict=False would then drop them silently.
    prefix = 'module.'
    pretrained_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }
    # strict=False: keys absent from either side are ignored rather than raising.
    model.load_state_dict(pretrained_dict, strict=False)

    cv2.setNumThreads(0)
    validate(val_loader1, model, args.classes, mean, std, args.base_size1)
Beispiel #10
0
    parser.add_argument("--low_alpha", default=4, type=int)
    parser.add_argument("--high_alpha", default=32, type=int)
    parser.add_argument("--out_cam", default=None, type=str)
    parser.add_argument("--out_la_crf", default=None, type=str)
    parser.add_argument("--out_ha_crf", default=None, type=str)
    parser.add_argument("--out_cam_pred", default=None, type=str)

    args = parser.parse_args()

    from pspnet import PSPNet
    model = PSPNet(backbone = 'resnet', layers=50, classes=20, zoom_factor=1, pretrained=False, syncbn=False).cuda()
    checkpoint = torch.load('exp/drivable/res101_psp_coarse/model/train_epoch_14.pth')

    pretrained_dict = {k.replace('module.',''): v for k, v in checkpoint['state_dict'].items()}
    
    dict1 = model.state_dict()
    print (dict1.keys(), pretrained_dict.keys())
    for item in dict1:
        if item not in pretrained_dict.keys():
            print(item,'nbnmbkjhiuguig~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~`')
    model.load_state_dict(pretrained_dict, strict=False)

    model.eval()
    model.cuda()
    print(model)
    normalize = Normalize()
    infer_dataset = voc12.data.VOC12ClsDatasetMSF(args.infer_list, voc12_root=args.voc12_root,
                                                   scales=(1, 0.5, 1.5, 2.0),
                                                   inter_transform=torchvision.transforms.Compose(
                                                       [np.asarray,
                                                        normalize,
Beispiel #11
0
def main():
    """Train PSPNet, checkpointing periodically and optionally validating.

    Parses the command-line options into the module-level ``args``, builds the
    model, optimizer, loss and data loaders, optionally resumes from
    ``args.resume`` and then runs the epoch loop. Per-epoch metrics are sent
    to the module-level ``writer``; every ``args.save_step`` epochs a
    checkpoint with the epoch number, model state and optimizer state is
    written under ``args.save_path``.
    """
    global args, logger, writer
    args = get_parser().parse_args()
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    # os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpu)
    if len(args.gpu) == 1:
        # Synchronized batch norm is switched off when only one GPU is listed.
        args.syncbn = False
    logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    model = PSPNet(layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   syncbn=args.syncbn).cuda()
    logger.info(model)

    # layer0..layer4 train at the base learning rate; the newly introduced
    # ppm/cls/aux modules get 10x that rate. The group order is kept stable
    # because optimizer.load_state_dict() matches parameter groups by position.
    base_lr_modules = [model.layer0, model.layer1, model.layer2,
                       model.layer3, model.layer4]
    boosted_lr_modules = [model.ppm, model.cls, model.aux]
    param_groups = [{'params': module.parameters()}
                    for module in base_lr_modules]
    param_groups += [{'params': module.parameters(), 'lr': args.base_lr * 10}
                     for module in boosted_lr_modules]
    optimizer = torch.optim.SGD(param_groups,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    model = torch.nn.DataParallel(model).cuda()
    if args.syncbn:
        from lib.syncbn import patch_replication_callback
        patch_replication_callback(model)
    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.NLLLoss(ignore_index=args.ignore_label).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # The stored epoch is the one that had just finished when the
            # checkpoint was written (see the save call in the loop below),
            # so continue with the following epoch instead of repeating it.
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            # Best effort: a missing file is only logged and training starts
            # from args.start_epoch with freshly initialised state.
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # Per-channel statistics multiplied by value_scale; also used as the
    # padding value for rotation and cropping.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transforms.Compose([
        transforms.RandScale([args.scale_min, args.scale_max]),
        transforms.RandRotate([args.rotate_min, args.rotate_max],
                              padding=mean,
                              ignore_label=args.ignore_label),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.crop_h, args.crop_w],
                        crop_type='rand',
                        padding=mean,
                        ignore_label=args.ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])
    train_data = datasets.SegData(split='train',
                                  data_root=args.data_root,
                                  data_list=args.train_list,
                                  transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    if args.evaluate:
        # The validation loader is only built (and only used) when requested.
        val_transform = transforms.Compose([
            transforms.Crop([args.crop_h, args.crop_w],
                            crop_type='center',
                            padding=mean,
                            ignore_label=args.ignore_label),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])
        val_data = datasets.SegData(split='val',
                                    data_root=args.data_root,
                                    data_list=args.val_list,
                                    transform=val_transform)
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs + 1):
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, criterion, optimizer, epoch,
            args.zoom_factor, args.batch_size, args.aux_weight)
        writer.add_scalar('loss_train', loss_train.cpu().numpy(), epoch)
        writer.add_scalar('mIoU_train', mIoU_train, epoch)
        writer.add_scalar('mAcc_train', mAcc_train, epoch)
        writer.add_scalar('allAcc_train', allAcc_train, epoch)
        # write parameters histogram costs lots of time
        # for name, param in model.named_parameters():
        #     writer.add_histogram(name, param, epoch)

        if epoch % args.save_step == 0:
            filename = args.save_path + '/train_epoch_' + str(epoch) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion, args.classes, args.zoom_factor)
            writer.add_scalar('loss_val', loss_val.cpu().numpy(), epoch)
            writer.add_scalar('mIoU_val', mIoU_val, epoch)
            writer.add_scalar('mAcc_val', mAcc_val, epoch)
            writer.add_scalar('allAcc_val', allAcc_val, epoch)