Example #1
import os

import torch
import torch.utils.data
import torchvision.transforms as transforms

import datautil  # local module providing SceneDataset


def main():
    # `args` is the global argparse namespace parsed elsewhere in the script
    print(args)

    print("=> creating model '{}'".format(args.arch))
    model = Ensemble()  # ensemble of MobileNetV2 and NASNetAMobile, defined elsewhere
    model = torch.nn.DataParallel(model).cuda()
    print(model)
    
    # standard ImageNet normalization statistics
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    test_data = datautil.SceneDataset(args.data,
                                      img_transform=transforms.Compose([
                                          transforms.Resize((args.img_size, args.img_size)),
                                          transforms.ToTensor(),
                                          normalize]))
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size,
                                              shuffle=False, num_workers=4, pin_memory=True)
    checkpoint = torch.load(args.test_model)
    model.load_state_dict(checkpoint['state_dict'])
    # model.load_state_dict(checkpoint)  # use instead if the file stores a bare state_dict
    if os.path.isdir(args.data):
        ret = test(test_loader, model)  # top-3 category ids per image
        # strip the 4-character extension (e.g. '.jpg'); assumes os.listdir
        # enumerates files in the same order the dataset iterated them
        imgs = [i[:-4] for i in os.listdir(args.data)]
        with open('result3_.csv', 'w') as f:
            f.write(','.join(['FILE_ID', 'CATEGORY_ID0', 'CATEGORY_ID1', 'CATEGORY_ID2']) + '\n')
            f.write('\n'.join([','.join([str(a)] + [str(int(i)) for i in b])
                               for a, b in zip(imgs, ret)]))
    else:
        test_labeled(test_loader, model)
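
Example #1 assumes an argparse namespace args plus an Ensemble model class and the test()/test_labeled() helpers defined elsewhere in the script. Because the CSV writes three CATEGORY_ID columns, test() evidently returns the top-3 predicted class indices per image; the sketch below is a minimal assumed implementation under that reading (only the name and call signature come from the example, the body is guesswork):

def test(test_loader, model):
    # Assumed helper: run the model over the unlabeled loader and collect
    # the top-3 class indices for every image, in loader order.
    model.eval()
    results = []
    with torch.no_grad():
        for images in test_loader:
            output = model(images.cuda())
            _, top3 = output.topk(3, dim=1)  # indices of the 3 highest logits
            results.extend(top3.cpu().tolist())
    return results
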
Example #2
import os

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms
from tensorboard_logger import configure  # assumed: the source of configure() used below

import datautil  # local module providing SceneDataset


def main():
    global best_prec1  # best top-1 accuracy so far, initialized at module level
    if args.tensorboard:
        configure('log/' + args.arch.lower() + '_bs' + str(args.batch_size) +
                  '_ep' + str(args.epochs) + '_loglr' + str(args.lr) +
                  '_size' + str(args.img_size) + '_wd' + str(args.weight_decay))
    print(args)
    # create model
    print("=> creating model '{}'".format(args.arch))
    
    # For a single-backbone variant one could instead adapt a torchvision model:
    # if args.arch.lower().startswith('resnet'):
    #     model.avgpool = nn.AvgPool2d(args.img_size // 32, 1)
    # model.fc = nn.Linear(model.fc.in_features, args.num_classes)

    # MobileNetV2 defaults: n_class=1000, input_size=224, width_mult=1.
    model = Ensemble()
    if not args.resume:
        # load the pretrained backbone weights before wrapping in DataParallel
        model.MobileNetV2.load_state_dict(torch.load('mobilenet_pretrained.pth'))
        model.NASNetAMobile.load_state_dict(torch.load('nasnet_pretrained.pth'))

    model = torch.nn.DataParallel(model).cuda()
    print(model)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True  # let cuDNN auto-tune convolution algorithms for the fixed input size

    # Data loading code
    # Dataset-specific statistics (mean [0.4333, 0.4429, 0.4313],
    # std [0.2295, 0.2385, 0.2479]) were also tried here; the standard
    # ImageNet normalization below is what is actually used.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_data = datautil.SceneDataset(args.data,
                                       img_transform=transforms.Compose([
                                           transforms.RandomResizedCrop(args.img_size),
                                           transforms.RandomHorizontalFlip(),
                                           transforms.ToTensor(),
                                           normalize]))

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                                               shuffle=True, num_workers=args.workers,
                                               pin_memory=True)
    if args.val:
        val_data = datautil.SceneDataset(args.val,
                                         img_transform=transforms.Compose([
                                             transforms.Resize((args.img_size, args.img_size)),
                                             transforms.ToTensor(),
                                             normalize]))
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size // 2,
                                                 shuffle=False, num_workers=args.workers,
                                                 pin_memory=True)
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # Adam and RMSprop were also tried:
    # optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
    # optimizer = torch.optim.RMSprop(model.parameters(), args.lr, momentum=args.momentum,
    #                                 weight_decay=args.weight_decay, eps=1)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)  # requires --val, otherwise val_loader is undefined
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set and remember the best prec@1
        if args.val:
            prec1 = validate(val_loader, model, criterion, epoch)
            is_best = prec1 > best_prec1  # computed but not used by save_checkpoint below
            best_prec1 = max(prec1, best_prec1)

        # save a checkpoint every `interval` epochs
        if epoch % args.interval == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
            })
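
The training loop above also leans on train(), validate(), adjust_learning_rate() and save_checkpoint(), all defined elsewhere in the script. Below are hedged sketches of the two small helpers; the step-decay schedule and the filename are assumptions, only the call signatures come from the example:

def adjust_learning_rate(optimizer, epoch):
    # Assumed step schedule: divide the initial learning rate by 10 every 30 epochs.
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    # `state` holds epoch, arch and the DataParallel-wrapped state_dict,
    # which is what Example #1 later reads back via checkpoint['state_dict'].
    torch.save(state, filename)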