Ejemplo n.º 1
0
            img = img.convert('RGB') if args.channels == 3 else img.convert(
                'L')
            img = img.resize(
                (args.resize_image_width, args.resize_image_height))
            return img


if __name__ == '__main__':
    # Pick the input normalization for the configured channel count: the
    # widely used ImageNet mean/std for RGB, and a symmetric 0.5/0.5
    # mapping for single-channel input.
    if args.channels == 3:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    elif args.channels == 1:
        normalize = transforms.Normalize(mean=[0.5], std=[0.5])
    else:
        # Fail fast: otherwise `normalize` stays unbound and the script
        # dies further down with an unhelpful NameError.
        raise ValueError(
            "unsupported number of channels: {} (expected 1 or 3)".format(
                args.channels))

    # create model
    avg_pool_size = (args.avg_pooling_height, args.avg_pooling_width)
    model = DenseNet(num_init_features=64,
                     growth_rate=32,
                     block_config=(6, 12, 24, 16),
                     num_classes=args.num_classes,
                     channels=args.channels,
                     avg_pooling_size=avg_pool_size)

    # create optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # start main loop; pil_loader is passed for both loader slots
    # (presumably train and validation loaders — confirm against main()).
    main(args, model, pil_loader, pil_loader, normalize, optimizer)
Ejemplo n.º 2
0
# File extensions (compared lower-cased) that main() treats as images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``"module."`` removed
    from each key.

    ``torch.nn.DataParallel`` prefixes parameter names with ``"module."``;
    presumably the checkpoint was saved from such a wrapper while the model
    built here is bare. Only a *leading* prefix is removed, so keys that
    merely contain ``"module."`` elsewhere (e.g. ``"submodule.weight"``)
    are kept intact.
    """
    prefix = "module."
    renamed = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        renamed[key] = value
    return renamed


def _normalize_cam(cam, size):
    """Clamp, resize and min-max scale a raw activation map into [0, 1].

    Args:
        cam: activation map returned by ``GradCam`` (assumed to be a 2-D
            float array accepted by ``cv2.resize`` — TODO confirm).
        size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        The resized map scaled to [0, 1]. A constant map (zero dynamic
        range) yields all zeros instead of NaNs from a 0/0 division.
    """
    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, size)
    cam = cam - np.min(cam)
    peak = np.max(cam)
    if peak > 0:
        return cam / peak
    return np.zeros_like(cam)


def main():
    """Render Grad-CAM overlays for every image under ``args.image_path``.

    Loads the checkpoint at ``args.resume``, rebuilds a DenseNet (when
    ``args.densenet`` is truthy) or a ResNet together with the matching
    feature extractor, and restores the weights. It then walks
    ``args.image_path`` recursively. For each ``.png``/``.jpg``/``.jpeg``
    file (extension matched case-insensitively), it computes a class
    activation map for ``args.target_index``, blends the map onto the image,
    and writes ``<stem>@pred_<pred>.png`` into ``args.result``.

    Reads the module-level ``args`` and ``use_cuda``.

    Raises:
        ValueError: if ``args.resume`` is not an existing file.
    """
    if not os.path.isfile(args.resume):
        raise ValueError("=> no checkpoint found at '{}'".format(args.resume))

    print("=> loading checkpoint '{}'".format(args.resume))
    if use_cuda:
        checkpoint = torch.load(args.resume)
    else:
        # Map every stored tensor onto CPU storage when CUDA is not used.
        checkpoint = torch.load(
            args.resume, map_location=lambda storage, loc: storage)

    avg_pool_size = (args.avg_pooling_height, args.avg_pooling_width)
    if args.densenet:
        # create Model
        model = DenseNet(num_init_features=32,
                         growth_rate=16,
                         block_config=(6, 12, 24, 16),
                         channels=args.channels,
                         avg_pooling_size=avg_pool_size,
                         num_classes=args.num_classes)
        model = model.cuda() if use_cuda else model

        # create extractor
        extractor = DenseNetExtractor(model)
    else:
        # create Model
        # NOTE(review): unlike the DenseNet branch, num_classes is not
        # passed here; confirm ResNet's default matches the checkpoint.
        model = ResNet(layers=[2, 2, 2, 2],
                       channels=args.channels,
                       global_pooling_size=avg_pool_size)
        model = model.cuda() if use_cuda else model

        # create extractor
        extractor = ResNetExtractor(model)

    # load Model
    model.load_state_dict(_strip_module_prefix(checkpoint['state_dict']))
    print("=> loaded checkpoint '{}' (epoch {})".format(
        args.resume, checkpoint['epoch']))

    # set Model for evaluation
    model.eval()

    # create gradient class activation map
    grad_cam = GradCam(extractor, use_cuda=use_cuda)
    cam_size = (args.image_width, args.image_height)

    for dirpath, _, files in os.walk(args.image_path):
        for filename in files:
            stem, ext = os.path.splitext(filename)
            if ext.lower() not in _IMAGE_EXTENSIONS:
                continue
            image_path = os.path.join(dirpath, filename)

            # read image and scale pixel values to [0, 1]
            img = pil_loader(image_path)
            img = np.float32(img) / 255.0
            input_tensor = preprocess_image(img, args.channels)

            # get class activation map
            cam, pred = grad_cam(input_tensor, args.target_index)
            mask = _normalize_cam(cam, cam_size)

            # make class activation map; grayscale images are expanded to
            # three channels before blending
            if args.channels == 1:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            result_img = make_cam_with_image(
                img,
                mask,
                transparency=args.transparency,
                blur_times=args.blur_times)

            # save cam image
            out_name = stem + '@pred_' + str(pred)
            out_path = os.path.join(args.result, out_name + '.png')
            Image.fromarray(result_img).save(out_path)
            print('Saved', image_path, '->', out_name)