Example #1
0
class WeightEMA(object):
    """Maintain an exponential moving average (EMA) of a model's weights.

    A regular :meth:`step` blends the live ``model`` parameters into
    ``ema_model`` and then applies a multiplicative weight decay to ``model``
    itself. ``step(bn=True)`` instead copies the live model's buffers (the
    BatchNorm running statistics, among others) into ``ema_model`` while
    leaving the averaged parameters untouched.

    Args:
        model: the network being trained.
        ema_model: a network with the same structure; its parameters are
            overwritten with ``model``'s on construction and then receive the
            running average.
        alpha: EMA decay, used as ``ema = alpha * ema + (1 - alpha) * param``.
        wd: factor for the decay applied to ``model`` on every non-BN step
            (``param *= 1 - wd``). When omitted, it falls back to the previous
            behaviour of ``0.02 * args.lr``, read from the module-level
            ``args`` namespace.
    """

    def __init__(self, model, ema_model, alpha=0.999, wd=None):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        # Keep the historical default, but let callers avoid the global `args`.
        self.wd = 0.02 * args.lr if wd is None else wd

        # Start the average from the model's current weights.
        for param, ema_param in zip(self.model.parameters(),
                                    self.ema_model.parameters()):
            ema_param.data.copy_(param.data)

    def step(self, bn=False):
        """Advance the EMA by one update, or sync buffers when ``bn`` is true.

        Args:
            bn: if true, copy ``model``'s buffers into ``ema_model`` by name
                instead of updating the parameter average.

        Raises:
            KeyError: (``bn=True`` only) if ``ema_model`` has a buffer that
                ``model`` does not, i.e. the two networks differ in structure.
        """
        if bn:
            # Copy batchnorm stats (and any other buffers) to the EMA model by
            # name. The previous approach (stash EMA params in a hard-coded
            # 16-class MobileNet, load_state_dict, restore) had the same net
            # effect. It also cost a second GPU model and silently mis-copied
            # whenever ema_model was not that exact architecture.
            model_buffers = dict(self.model.named_buffers())
            for name, ema_buf in self.ema_model.named_buffers():
                ema_buf.data.copy_(model_buffers[name].data)
        else:
            one_minus_alpha = 1.0 - self.alpha
            for param, ema_param in zip(self.model.parameters(),
                                        self.ema_model.parameters()):
                ema_param.data.mul_(self.alpha)
                ema_param.data.add_(param.data.detach() * one_minus_alpha)
                # Customized weight decay, applied to the live model after its
                # current value has been folded into the average.
                param.data.mul_(1 - self.wd)
Example #2
0
    def create_model(num_classes, ema=False):
        """Return ``MobileNet(num_classes)`` wrapped in ``DataParallel`` on CUDA.

        Args:
            num_classes: number of output classes passed to ``MobileNet``.
            ema: if true, every parameter of the returned network is detached
                in place so it is excluded from autograd.
        """
        wrapped = torch.nn.DataParallel(MobileNet(num_classes)).cuda()
        if ema:
            for weight in wrapped.parameters():
                weight.detach_()
        return wrapped
    def create_model(num_classes, ema=False):
        """Build a single-device ``MobileNet(num_classes)`` and move it to CUDA.

        Args:
            num_classes: number of output classes passed to ``MobileNet``.
            ema: if true, detach every parameter in place so the network does
                not participate in gradient computation.

        Returns:
            The (unwrapped) network.
        """
        # A WideResNet(num_classes) backbone was the commented-out alternative.
        net = MobileNet(num_classes)
        net.cuda()  # nn.Module.cuda() moves the parameters in place

        if not ema:
            return net
        for tensor in net.parameters():
            tensor.detach_()
        return net
Example #4
0
    def __init__(self, model, ema_model, alpha=0.999):
        """Store the two networks and copy ``model``'s weights into ``ema_model``.

        Args:
            model: network whose parameters are copied from.
            ema_model: network whose parameters are overwritten here.
            alpha: smoothing factor stored for later updates.
        """
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        # A second 16-class MobileNet on the GPU; presumably scratch storage
        # for a step() method not shown in this snippet -- confirm.
        self.tmp_model = MobileNet(num_classes=16).cuda()
        # Decay factor derived from the global `args.lr`.
        self.wd = 0.02 * args.lr

        # Initialise the EMA weights to the model's current weights.
        for src, dst in zip(model.parameters(), ema_model.parameters()):
            dst.data.copy_(src.data)
Example #5
0
    def create_model(num_classes, ema=False):
        """Create a ``DataParallel``-wrapped ``MobileNet(num_classes)`` on CUDA.

        Args:
            num_classes: number of output classes passed to ``MobileNet``.
            ema: if true, the parameters are detached in place (rather than only
                having ``requires_grad`` cleared), removing them from autograd.
        """
        # WideResNet(num_classes) was a commented-out alternative backbone.
        backbone = MobileNet(num_classes)
        model = torch.nn.DataParallel(backbone).cuda()

        if not ema:
            return model
        for p in model.parameters():
            p.detach_()
        return model
Example #6
0
    parser.add_argument('--target-layer', type=int, default=13,
                        help='Target layer')                   

    args = parser.parse_args()

    return args
    


if __name__ == '__main__':
    # Parse CLI options and fix RNG seeds / cuDNN determinism for repeatable runs.
    args = get_args()
    random.seed(1)
    torch.manual_seed(1)
    torch.backends.cudnn.deterministic = True

    # The model stays on the CPU in this script, so deserialize the checkpoint
    # onto the CPU too. Without map_location, a GPU-saved checkpoint requires
    # CUDA just to load (and is then copied to the CPU model anyway).
    model = MobileNet(16)
    checkpoint = torch.load('/home/jingyi/cxr-jingyi/Age/result/supervised/model_best.pth.tar',
                            map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    grad_cam = GradCam(model=model, target_layer=args.target_layer)

    img = imageio.imread(args.image_path)
    # Test-time preprocessing: resize to 512x512 (interpolation=3 is the PIL
    # BICUBIC code), center-crop to 256x256, convert to a tensor.
    cxr_test_transforms = tfms.Compose([
        tfms.ToPILImage(),
        tfms.Resize((512, 512), interpolation=3),
        tfms.CenterCrop(256),
        tfms.ToTensor(),
    ])
    # Compute a mask with get_mask and crop with segment (both defined
    # elsewhere) before applying the transforms.
    img_mask = get_mask(img)
    cropped_img = segment(img, img_mask)
    # transformation
    preprocessed_img = cxr_test_transforms(cropped_img)
Example #7
0
import torch.nn.functional as F

# import model
from dnn121 import DenseNet121, MobileNet
import wisenet as models
# import dataset
from mixmatch_dataset import train_val_split, NIH_CXR_BASE, CxrDataset, CXR_unlabeled
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig
from tensorboardX import SummaryWriter



# Held-out test split and a non-shuffling loader over it.
test_set = CxrDataset(NIH_CXR_BASE, "~/cxr-jingyi/Age/NIH_test_2500.csv")
test_loader = data.DataLoader(
    test_set,
    batch_size=32,
    shuffle=False,
    num_workers=32,
)

# 16-class MobileNet on the GPU, restored from the semi-supervised run's
# checkpoint (weights stored under 'net'). The supervised checkpoint at
# /home/jingyi/cxr-jingyi/Age/result/supervised/model_best.pth.tar, stored
# under 'state_dict', was the commented-out alternative.
model = MobileNet(16).cuda()
checkpoint = torch.load('/home/jingyi/cxr-jingyi/Age/checkpoint/cifar10-semi/exp/ckpt.pth.tar')
model.load_state_dict(checkpoint['net'])

def validate(val_loader, model, mode = 'valid'):
    
    top1 = AverageMeter()
    top5 = AverageMeter()
    predict = []

    # switch to evaluate mode
    model.eval()