示例#1
0
文件: eval.py 项目: dingmyu/psa
def main():
    """Evaluate a PSPNet checkpoint on the images listed in ``args.val_list1``.

    Parses the command line into the module-level ``args``, creates the
    module-level ``logger``, builds the evaluation loader and the model, loads
    the weights stored at ``args.model_path`` and passes everything to
    ``validate``.

    Raises:
        AssertionError: if the parsed arguments fail the sanity checks below.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.split in ['train', 'val', 'test']
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel statistics rescaled to the 0-255 value range. ``mean`` is
    # also the padding colour for the crop; both lists go to validate().
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    # No Normalize step is part of this transform; mean/std are passed to
    # validate() instead.
    val_transform = transforms.Compose([
        transforms.Crop([args.crop_h, args.crop_w], crop_type='center',
                        padding=mean, ignore_label=args.ignore_label),
        transforms.ToTensor()])
    # NOTE(review): the dataset is always built with split='train', whatever
    # args.split says -- confirm this is intended.
    val_data1 = datasets.SegData(split='train', data_root=args.data_root,
                                 data_list=args.val_list1,
                                 transform=val_transform)
    val_loader1 = torch.utils.data.DataLoader(val_data1, batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone, layers=args.layers,
                   classes=args.classes, zoom_factor=args.zoom_factor,
                   use_softmax=False, use_aux=False, pretrained=False,
                   syncbn=False).cuda()

    logger.info(model)
    cudnn.enabled = True
    cudnn.benchmark = True

    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))

    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    # The model is used unwrapped here, so the 'module.' marker is dropped
    # from the checkpoint keys. NOTE: str.replace removes every occurrence of
    # 'module.' in a key, not only a leading prefix.
    pretrained_dict = {k.replace('module.', ''): v
                       for k, v in checkpoint['state_dict'].items()}
    # strict=False: keys that do not match the model are ignored, not fatal.
    model.load_state_dict(pretrained_dict, strict=False)

    cv2.setNumThreads(0)
    validate(val_loader1, val_data1.data_list, model, args.classes, mean, std,
             args.base_size1, args.crop_h, args.crop_w, args.scales)
示例#2
0
def build_network(snapshot):
    """Create a DataParallel-wrapped PSPNet on the GPU, optionally restored.

    Args:
        snapshot: Path to a saved state dict, or ``None`` to start from a
            freshly constructed network. The file's base name must end in
            ``_<epoch>``, where ``<epoch>`` parses as an integer.

    Returns:
        A ``(net, epoch)`` tuple; ``epoch`` is 0 when no snapshot is given.

    Raises:
        ValueError: if the snapshot base name contains no underscore or the
            text after the last underscore is not an integer.
    """
    epoch = 0
    net = nn.DataParallel(PSPNet())
    if snapshot is not None:
        # Only the text after the LAST underscore is the epoch, so name
        # prefixes that themselves contain underscores are accepted.
        _, epoch = os.path.basename(snapshot).rsplit('_', 1)
        epoch = int(epoch)
        # The state dict is loaded into the wrapped network, so its keys are
        # expected in DataParallel form.
        net.load_state_dict(torch.load(snapshot))
        logging.info("Snapshot for epoch {} loaded from {}".format(
            epoch, snapshot))
    net = net.cuda()
    return net, epoch
示例#3
0
def build_network(snapshot, backend):
    """Create a DataParallel-wrapped PSPNet for ``backend``, optionally restored.

    Args:
        snapshot: Path to a saved state dict, or ``None`` to start from a
            freshly constructed network. The file's base name must end in
            ``_<epoch>``, where ``<epoch>`` parses as an integer.
        backend: Backbone name, compared case-insensitively. Only names
            starting with ``resnet`` are handled.

    Returns:
        A ``(net, epoch)`` tuple; ``epoch`` is 0 when no snapshot is given.

    Raises:
        ValueError: if ``backend`` is not a resnet variant, if the snapshot
            base name contains no underscore, or if the text after the last
            underscore is not an integer.
    """
    epoch = 0
    backend = backend.lower()
    if not backend.startswith('resnet'):
        # Fail here instead of wrapping ``None`` in DataParallel and failing
        # somewhere later with an unrelated-looking error.
        raise ValueError(
            "Unsupported backend '{}': only resnet* backends are "
            "handled".format(backend))
    net = PSPNet(sizes=(1, 2, 3, 6),
                 psp_size=2048,
                 deep_features_size=1024,
                 backend=backend)
    net = nn.DataParallel(net)
    if snapshot is not None:
        # Only the text after the LAST underscore is the epoch, so name
        # prefixes that themselves contain underscores are accepted.
        _, epoch = os.path.basename(snapshot).rsplit('_', 1)
        epoch = int(epoch)
        net.load_state_dict(torch.load(snapshot))
        logging.info("Snapshot for epoch {} loaded from {}".format(
            epoch, snapshot))
    net = net.cuda()
    return net, epoch
示例#4
0
def main():
    """Train a PSPNet on the PRCV data, validating and checkpointing each epoch.

    Parses the command line into the module-level ``args``, builds the train
    and validation loaders, restores weights from ``args.restore_from`` (and
    optionally from ``args.resume``), then runs the train/validate loop.

    The module-level ``best_record`` dict is updated in place; it is read and
    written under the keys ``'epoch'``, ``'val_loss'``, ``'acc'`` and
    ``'miou'``.
    """
    global args, best_record
    args = parser.parse_args()

    if args.augment:
        transform_train = joint_transforms.Compose([
            joint_transforms.FreeScale((512, 512)),
            joint_transforms.RandomHorizontallyFlip(),
            joint_transforms.RandomVerticallyFlip(),
            joint_transforms.Rotate(90),
        ])
        transform_val = joint_transforms.Compose(
            [joint_transforms.FreeScale((512, 512))])
    else:
        # Both names must exist on this path too: the validation dataset
        # below is built unconditionally.
        transform_train = None
        transform_val = None

    dataset_train = dataset.PRCVData('train', args.data_root,
                                     args.label_train_list, transform_train)
    dataloader_train = data.DataLoader(dataset_train,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=8)

    dataset_val = dataset.PRCVData('val', args.data_root, args.label_val_list,
                                   transform_val)
    dataloader_val = data.DataLoader(dataset_val,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=8)

    model = PSPNet(num_classes=args.num_class)

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # Start from the model's own parameters and overwrite them with the saved
    # ones, leaving out every entry whose first name component is 'fc'.
    # NOTE(review): with num_class == 21 nothing is copied from the saved
    # state dict at all -- confirm this is intended.
    saved_state_dict = torch.load(args.restore_from)
    new_params = model.state_dict().copy()
    if args.num_class != 21:
        for key in saved_state_dict:
            if key.split('.')[0] != 'fc':
                new_params[key] = saved_state_dict[key]
    model.load_state_dict(new_params)

    model = model.cuda()
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255).cuda()
    # Two parameter groups: the second one trains at ten times the base rate.
    optimizer = torch.optim.SGD([{
        'params': get_1x_lr_params(model),
        'lr': args.learning_rate
    }, {
        'params': get_10x_lr_params(model),
        'lr': 10 * args.learning_rate
    }],
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    separator = (
        '------------------------------------------------------------------------------------------------------'
    )
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(dataloader_train, model, criterion, optimizer, epoch)

        # evaluate on validation set
        acc, mean_iou, val_loss = validate(dataloader_val, model, criterion,
                                           args.result_pth, epoch)

        is_best = mean_iou > best_record['miou']
        if is_best:
            best_record['epoch'] = epoch
            best_record['val_loss'] = val_loss.avg
            best_record['acc'] = acc
            best_record['miou'] = mean_iou
        # Stored epoch is epoch + 1, i.e. the epoch a resumed run starts at.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'val_loss': val_loss.avg,
                'accuracy': acc,
                'miou': mean_iou,
                'state_dict': model.state_dict(),
            }, is_best)

        print(separator)
        print('[epoch: %d], [val_loss: %5f], [acc: %.5f], [miou: %.5f]' %
              (epoch, val_loss.avg, acc, mean_iou))
        print(
            'best record: [epoch: {epoch}], [val_loss: {val_loss:.5f}], [acc: {acc:.5f}], [miou: {miou:.5f}]'
            .format(**best_record))
        print(separator)
示例#5
0
def main():
    """Train a PSPNet, optionally distributed, with periodic checkpoints.

    Parses the command line into the module-level ``args`` and creates the
    module-level ``logger`` and ``writer``. Builds the model, optimizer and
    data loaders, optionally loads weights (``args.weight``) or resumes a run
    (``args.resume``), then trains from ``args.start_epoch`` through
    ``args.epochs`` inclusive. Scalars are written to ``writer`` every epoch;
    rank 0 additionally validates (when ``args.evaluate`` is set) and saves a
    checkpoint every ``args.save_step`` epochs.

    Raises:
        AssertionError: if the parsed arguments fail the sanity checks below.
    """
    global args, logger, writer
    args = get_parser().parse_args()
    logger = get_logger()
    writer = SummaryWriter(args.save_path)

    if args.dist:
        dist_init(args.port, backend=args.backend)

    # This only affects the value logged just below; args.syncbn is assigned
    # again from args.bn_group further down.
    if len(args.gpu) == 1:
        args.syncbn = False
    logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0

    world_size = 1
    rank = 0
    if args.dist:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    if rank == 0:
        logger.info('dist:{}'.format(args.dist))
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))

    if args.bn_group > 1:
        args.syncbn = True
        bn_sync_stats = True
        bn_group_comm = simple_group_split(world_size, rank,
                                           world_size // args.bn_group)
    else:
        args.syncbn = False
        bn_sync_stats = False
        bn_group_comm = None

    model = PSPNet(layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   syncbn=args.syncbn,
                   group_size=args.bn_group,
                   group=bn_group_comm,
                   sync_stats=bn_sync_stats)
    if rank == 0:
        logger.info(model)
    # layer0..layer4 train at the base learning rate; ppm, cls and aux use
    # ten times the base rate.
    optimizer = torch.optim.SGD([{
        'params': model.layer0.parameters()
    }, {
        'params': model.layer1.parameters()
    }, {
        'params': model.layer2.parameters()
    }, {
        'params': model.layer3.parameters()
    }, {
        'params': model.layer4.parameters()
    }, {
        'params': model.ppm.parameters(),
        'lr': args.base_lr * 10
    }, {
        'params': model.cls.parameters(),
        'lr': args.base_lr * 10
    }, {
        'params': model.aux.parameters(),
        'lr': args.base_lr * 10
    }],
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    model = model.cuda()

    cudnn.enabled = True
    cudnn.benchmark = True
    criterion = nn.NLLLoss(ignore_index=args.ignore_label)

    if args.weight:

        def map_func(storage, location):
            # Place every loaded storage on the current CUDA device.
            return storage.cuda()

        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight, map_location=map_func)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            model, optimizer, args.start_epoch = restore_from(
                model, optimizer, args.resume)

            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    if args.dist:
        broadcast_params(model)

    # Per-channel statistics rescaled to the 0-255 value range; ``mean`` is
    # also used as the padding colour for rotation and cropping.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transforms.Compose([
        transforms.RandScale([args.scale_min, args.scale_max]),
        transforms.RandRotate([args.rotate_min, args.rotate_max],
                              padding=mean,
                              ignore_label=args.ignore_label),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.crop_h, args.crop_w],
                        crop_type='rand',
                        padding=mean,
                        ignore_label=args.ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    train_data = datasets.SegData(split='train',
                                  data_root=args.data_root,
                                  data_list=args.train_list,
                                  transform=train_transform)
    train_sampler = None
    if args.dist:
        train_sampler = DistributedSampler(train_data)

    # Shuffle in the loader only when no sampler is supplied.
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=False if train_sampler else True,
        num_workers=args.workers,
        pin_memory=False,
        sampler=train_sampler)

    if args.evaluate:
        val_transform = transforms.Compose([
            transforms.Crop([args.crop_h, args.crop_w],
                            crop_type='center',
                            padding=mean,
                            ignore_label=args.ignore_label),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        val_data = datasets.SegData(split='val',
                                    data_root=args.data_root,
                                    data_list=args.val_list,
                                    transform=val_transform)
        val_sampler = None
        if args.dist:
            val_sampler = DistributedSampler(val_data)
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs + 1):
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, criterion, optimizer, epoch, args.zoom_factor,
            args.batch_size, args.aux_weight)
        writer.add_scalar('loss_train', loss_train.cpu().numpy(), epoch)
        writer.add_scalar('mIoU_train', mIoU_train, epoch)
        writer.add_scalar('mAcc_train', mAcc_train, epoch)
        writer.add_scalar('allAcc_train', allAcc_train, epoch)

        # val_loader exists only when args.evaluate is set, which is the
        # same condition guarding its use here.
        if args.evaluate and rank == 0:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion, args.classes, args.zoom_factor)
            writer.add_scalar('loss_val', loss_val.cpu().numpy(), epoch)
            writer.add_scalar('mIoU_val', mIoU_val, epoch)
            writer.add_scalar('mAcc_val', mAcc_val, epoch)
            writer.add_scalar('allAcc_val', allAcc_val, epoch)

        if epoch % args.save_step == 0 and (rank == 0):
            filename = args.save_path + '/train_epoch_' + str(epoch) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            # Copy the tensors to the CPU for serialisation. Calling
            # model.cpu() here would move the live model off the GPU in
            # place and leave it there for the following epochs.
            cpu_state = model.state_dict()
            for key in list(cpu_state):
                cpu_state[key] = cpu_state[key].cpu()
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': cpu_state,
                    'optimizer': optimizer.state_dict()
                }, filename)
示例#6
0
def train():
    """Train a single-output PSPNet with an MSE loss, scoring it every epoch.

    Runs epochs 1..99. Each epoch trains on the batches from
    ``prepared_train_data()``, saves the whole network object to
    ``Models/modol-<epoch>.pth``, then runs the network over
    ``prepared_test_data()``. For every test image it prints
    ``1 - intersection / union`` between the thresholded prediction and the
    thresholded mask, and writes the prediction to ``./test_image/<name>``.
    The mean of that score over the test images is printed per epoch.
    """
    print(torch.cuda.device_count())

    net = PSPNet(n_classes=1,
                 sizes=(1, 2, 3, 6),
                 psp_size=512,
                 deep_features_size=256,
                 backend='resnet34',
                 pretrained=False)
    net = nn.DataParallel(net)
    net = net.cuda()

    optimizer = Adam(net.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    loaders = prepared_train_data()
    test_loaders = prepared_test_data()

    for epoch in range(1, 100):
        print('Training................')
        # Per-batch loss values for this epoch; collected but not reported.
        epoch_loss = []

        net.train(mode=True)

        for sample_batch in loaders:
            images = sample_batch['image'].cuda()
            masks = sample_batch['mask'].cuda()

            inputs = Variable(images)
            targets = Variable(masks)

            # Outputs are limited to the 0-255 range before the loss.
            outputs = net(inputs)
            outputs = torch.clamp(outputs, 0., 255.)

            optimizer.zero_grad()

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # .item() reads the scalar value; indexing a 0-dim loss tensor
            # with [0] raises on current PyTorch.
            epoch_loss.append(loss.item())

        torch.save(net, 'Models/modol-' + str(epoch) + '.pth')
        print('Testing........................')
        net.train(mode=False)
        total_m1 = 0
        num_images = 0
        # Predictions above this value count as foreground.
        disc = 44
        for item in test_loaders:
            images = item['image'].cuda()
            masks = item['mask'].numpy()
            name = item['name']

            # Dropping axis 0 requires it to have size 1 (one image per batch).
            masks = np.squeeze(masks, axis=0)

            test_image = Variable(images).cuda()

            predict_label = net(test_image)
            predict_label = torch.clamp(predict_label, 0., 255.)

            results = predict_label.cpu().data.numpy()

            # The axes must be given as a tuple; a list is rejected by numpy.
            pred_map = np.squeeze(results, axis=(0, 1))

            # Binarise the mask (> 200) and the prediction (> disc).
            gt = np.zeros(shape=np.shape(masks))
            gt[masks > 200] = 1

            prediction = np.zeros(shape=np.shape(pred_map))
            prediction[pred_map > disc] = 1

            # 2 where both are foreground, 1 where exactly one is, else 0.
            overlap = gt + prediction

            image_inter = np.zeros(shape=np.shape(overlap))
            image_inter[overlap > 1] = 1
            num_inter = image_inter.sum()

            image_union = np.zeros(shape=np.shape(overlap))
            image_union[overlap > 0] = 1
            num_union = image_union.sum()

            # NOTE: an image with an empty union yields nan here.
            m_1 = (1 - num_inter / num_union)
            print('Image name is {}, and m1 is {}'.format(name[0], m_1))
            total_m1 = total_m1 + m_1
            num_images += 1

            pred_map[pred_map > disc] = 255

            # NOTE(review): ``misc`` is presumably scipy.misc, whose imsave
            # is not available in recent SciPy releases -- confirm the
            # pinned version.
            misc.imsave('./test_image/' + name[0], pred_map)
        # Average over the images actually evaluated instead of a fixed
        # count; max() keeps an empty loader from dividing by zero.
        print('m1 is {}'.format(total_m1 / max(num_images, 1)))
示例#7
0
文件: eval.py 项目: dingmyu/psa
def main():
    """Run a PSPNet checkpoint over the VOC12 images in ``args.train_list``.

    Parses the command line into the module-level ``args``, creates the
    module-level ``logger``, builds the data loader and the model, loads the
    weights stored at ``args.model_path`` and passes everything to
    ``validate``.

    Raises:
        AssertionError: if the parsed arguments fail the sanity checks below.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser().parse_args()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    assert args.zoom_factor in [1, 2, 4, 8]
    assert (args.crop_h - 1) % 8 == 0 and (args.crop_w - 1) % 8 == 0
    assert args.split in ['train', 'val', 'test']
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    # Per-channel statistics rescaled to the 0-255 value range; both lists
    # are passed to validate().
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    # No normalisation step is part of this transform pipeline.
    infer_dataset = voc12.data.VOC12ClsDataset(
        args.train_list,
        voc12_root=args.voc12_root,
        transform=transforms.Compose([
            np.asarray,
            imutils.RandomCrop(441),
            imutils.HWC_to_CHW
        ]))

    val_loader1 = DataLoader(infer_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=True)

    from pspnet import PSPNet
    model = PSPNet(backbone=args.backbone,
                   layers=args.layers,
                   classes=args.classes,
                   zoom_factor=args.zoom_factor,
                   use_softmax=False,
                   use_aux=False,
                   pretrained=False,
                   syncbn=False).cuda()

    logger.info(model)
    cudnn.enabled = True
    cudnn.benchmark = True

    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.model_path))

    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    # The model is used unwrapped here, so the 'module.' marker is dropped
    # from the checkpoint keys. NOTE: str.replace removes every occurrence of
    # 'module.' in a key, not only a leading prefix.
    pretrained_dict = {
        k.replace('module.', ''): v
        for k, v in checkpoint['state_dict'].items()
    }
    # strict=False: keys that do not match the model are ignored, not fatal.
    model.load_state_dict(pretrained_dict, strict=False)

    cv2.setNumThreads(0)
    validate(val_loader1, model, args.classes, mean, std, args.base_size1)
示例#8
0
def main():
    """Train a PSPNet, either from scratch or resumed from a checkpoint.

    Reads the module-level ``args`` (replacing it with the checkpoint's
    stored arguments when resuming), builds or restores the model and the
    optimizer, then trains for the remaining epochs. A checkpoint is saved
    after every epoch through ``utils.save_checkpoint`` and validation runs
    every tenth epoch.

    Raises:
        AssertionError: if ``args.resume`` is set but is not an existing file.
    """
    global args
    # Scale the batch size with the GPU count when several GPUs are present.
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        args.batch_size = args.batch_size * torch.cuda.device_count()
    else:
        print("Let's use", torch.cuda.current_device())

    if args.resume:
        # Keep the path in a local: ``args`` is replaced by the checkpoint's
        # own arguments below, and the stored ``resume`` value is not the
        # path being loaded now (it is set to True further down before
        # being saved again).
        resume_path = args.resume
        assert os.path.isfile(resume_path), \
            "=> no checkpoint found at '{}'".format(resume_path)
        print("=> loading checkpoint '{}'".format(resume_path))
        checkpoint = torch.load(resume_path)
        args = checkpoint['args']
        start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
        output_directory = os.path.dirname(os.path.abspath(resume_path))
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        train_loader, val_loader = create_data_loaders(args)
        args.resume = True
    else:
        train_loader, val_loader = create_data_loaders(args)
        print("=> creating Model ")
        model = PSPNet(n_classes=args.max_classes,
                       sizes=(1, 2, 3, 6),
                       psp_size=512,
                       deep_features_size=256,
                       backend='resnet34')
        print("=> model created.")
        start_epoch = 0
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.decay)
        best_result = np.inf

        # create results folder, if not already exists
        output_directory = os.path.join(args.saved_path)
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)

    # The log directory is recreated from scratch on every run.
    log_path = os.path.join(args.log_path, "{}".format('vocaug'))
    if os.path.isdir(log_path):
        shutil.rmtree(log_path)
    os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    # for multi-gpu training; a model restored from a checkpoint may already
    # be wrapped (the wrapped object is what gets saved), so it is not
    # wrapped a second time.
    if (torch.cuda.device_count() > 1
            and not isinstance(model, torch.nn.DataParallel)):
        model = torch.nn.DataParallel(model)
    model = model.cuda()

    # define loss function; class 0 gets zero weight in the loss
    weights = torch.ones(args.max_classes)
    weights[0] = 0
    seg_criterion = nn.NLLLoss2d(weight=weights.cuda()).cuda()

    for epoch in range(start_epoch, args.epochs):
        loss = train(train_loader, model, seg_criterion, optimizer, epoch,
                     logger)  # train for one epoch

        # Lower training loss is better.
        is_best = loss < best_result
        if is_best:
            best_result = loss

        # The model and optimizer objects themselves are stored, which is
        # what the resume branch above reads back.
        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)

        if (epoch + 1) % 10 == 0:
            validate(val_loader, model, epoch, logger)
示例#9
0
文件: infer_cls.py 项目: dingmyu/psa
    from pspnet import PSPNet
    model = PSPNet(backbone = 'resnet', layers=50, classes=20, zoom_factor=1, pretrained=False, syncbn=False).cuda()
    checkpoint = torch.load('exp/drivable/res101_psp_coarse/model/train_epoch_14.pth')

    pretrained_dict = {k.replace('module.',''): v for k, v in checkpoint['state_dict'].items()}
    
    dict1 = model.state_dict()
    print (dict1.keys(), pretrained_dict.keys())
    for item in dict1:
        if item not in pretrained_dict.keys():
            print(item,'nbnmbkjhiuguig~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~`')
    model.load_state_dict(pretrained_dict, strict=False)

    model.eval()
    model.cuda()
    print(model)
    normalize = Normalize()
    infer_dataset = voc12.data.VOC12ClsDatasetMSF(args.infer_list, voc12_root=args.voc12_root,
                                                   scales=(1, 0.5, 1.5, 2.0),
                                                   inter_transform=torchvision.transforms.Compose(
                                                       [np.asarray,
                                                        normalize,
                                                        imutils.HWC_to_CHW]))

    infer_data_loader = DataLoader(infer_dataset, shuffle=False, num_workers=args.num_workers, pin_memory=True)
    print('data ready')
    n_gpus = torch.cuda.device_count()
    model_replicas = torch.nn.parallel.replicate(model, list(range(n_gpus)))

    for iter, (img_name, img_list, label) in enumerate(infer_data_loader):