def create_model(num_classes, ema=False):
    """Construct a MobileNet classifier and move it to the GPU.

    Args:
        num_classes: Number of output classes, forwarded to ``MobileNet``.
        ema: When true, every parameter is detached in place
            (``Tensor.detach_``) so the returned network is cut off from
            autograd. Presumably this is the exponential-moving-average
            copy whose weights are updated manually elsewhere -- confirm
            against the caller.

    Returns:
        The same ``MobileNet`` instance, after ``.cuda()`` was called on it.
    """
    # Alternative backbone left commented out by the author:
    #   WideResNet(num_classes)
    net = MobileNet(num_classes)
    net.cuda()

    if not ema:
        return net

    # In-place detach: the parameters stop tracking gradients.
    for weight in net.parameters():
        weight.detach_()
    return net
Exemplo n.º 2
0
# import model
from dnn121 import DenseNet121, MobileNet
import wisenet as models
# import dataset
from mixmatch_dataset import train_val_split, NIH_CXR_BASE, CxrDataset, CXR_unlabeled
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig
from tensorboardX import SummaryWriter



# Evaluation data. Presumably this is the NIH chest X-ray test split
# (going by the CSV and NIH_CXR_BASE names). It is read in a fixed
# order (shuffle=False).
# NOTE(review): '~' in the CSV path is only expanded if CxrDataset (or
# the reader it uses) expands it -- confirm.
# NOTE(review): `torch` and `data` are not imported in the visible import
# block; presumably they are imported elsewhere in the file -- confirm.
test_set = CxrDataset(NIH_CXR_BASE, "~/cxr-jingyi/Age/NIH_test_2500.csv") 
test_loader = data.DataLoader(
    test_set,
    batch_size=32,
    shuffle=False,
    num_workers=32,
)

# 16-class MobileNet on the GPU, restored from a checkpoint dict that
# stores its weights under the 'net' key.
# Alternative the author left commented out: the checkpoint
# '/home/jingyi/cxr-jingyi/Age/result/supervised/model_best.pth.tar',
# paired with checkpoint['state_dict'] instead of checkpoint['net'].
model = MobileNet(16).cuda()
checkpoint = torch.load(
    '/home/jingyi/cxr-jingyi/Age/checkpoint/cifar10-semi/exp/ckpt.pth.tar'
)
model.load_state_dict(checkpoint['net'])

def validate(val_loader, model, mode = 'valid'):
    
    top1 = AverageMeter()
    top5 = AverageMeter()
    predict = []

    # switch to evaluate mode
    model.eval()
    
    with torch.no_grad():