Пример #1
0
def run(args):
    """Evaluate the saved best checkpoint of a classifier on a test CSV.

    Reads ``cfg.json`` and ``best.ckpt`` from ``args.model_path``, rebuilds the
    ``Classifier`` from the config, restores its weights and passes it to
    ``test_epoch`` together with a non-shuffled loader over
    ``args.in_csv_path``.

    Args:
        args: Parsed command-line namespace. Attributes read here are
            ``model_path``, ``device_ids`` (comma-separated GPU indices),
            ``fl`` (the string ``'True'`` selects the single-device path),
            ``in_csv_path``, ``out_csv_path`` and ``num_workers``.

    Raises:
        Exception: If fewer CUDA devices are visible than were requested.
    """
    # Join rather than concatenate so a model_path without a trailing
    # separator resolves the same way as the checkpoint path below.
    with open(os.path.join(args.model_path, 'cfg.json')) as f:
        cfg = edict(json.load(f))

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    # The first requested device hosts the checkpoint and the model.
    device = torch.device('cuda:{}'.format(device_ids[0]))

    ckpt_path = os.path.join(args.model_path, 'best.ckpt')
    ckpt = torch.load(ckpt_path, map_location=device)

    if args.fl == 'True':
        # Single-device path: the state dict is loaded into the bare model.
        model = Classifier(cfg).to(device).eval()
        model.load_state_dict(ckpt['state_dict'])
    else:
        # Multi-device path: the weights go into the module wrapped by
        # DataParallel, so the checkpoint keys carry no 'module.' prefix.
        model = Classifier(cfg)
        model = DataParallel(model, device_ids=device_ids).to(device).eval()
        model.module.load_state_dict(ckpt['state_dict'])

    dataloader_test = DataLoader(ImageDataset(args.in_csv_path,
                                              cfg,
                                              mode='test'),
                                 batch_size=cfg.dev_batch_size,
                                 num_workers=args.num_workers,
                                 drop_last=False,
                                 shuffle=False)

    test_epoch(cfg, args, model, dataloader_test, args.out_csv_path)

    print('Save best is step :', ckpt['step'], 'AUC :', ckpt['auc_dev_best'])
Пример #2
0
def build_model(cfg, paramsfile):
    """Create a CPU ``Classifier`` and restore its weights from ``paramsfile``.

    The checkpoint may be a bare state dict or a dict holding the weights
    under ``'state_dict'``. When it also records ``'step'`` and
    ``'auc_dev_best'``, both are printed. The result of ``model.eval()`` is
    returned.
    """
    net = Classifier(cfg).to('cpu')
    checkpoint = torch.load(paramsfile, map_location='cpu')

    # Accept both wrapped ({'state_dict': ...}) and bare checkpoints.
    if 'state_dict' in checkpoint:
        weights = checkpoint['state_dict']
    else:
        weights = checkpoint
    net.load_state_dict(weights)

    has_metrics = all(key in checkpoint for key in ('step', 'auc_dev_best'))
    if has_metrics:
        print(f"Using model '{paramsfile}' at step: {checkpoint['step']} "
              f"with AUC: {checkpoint['auc_dev_best']}")
    return net.eval()
class MDAIModel:
    """Deployment wrapper that runs a six-output ``Classifier`` on DICOM files.

    ``predict`` returns one ``NONE`` output for an image with no positive
    class, or one ``ANNOTATION`` output (with Grad-CAM and Integrated
    Gradients PNGs attached) per positive class.
    """

    def __init__(self):
        """Build the classifier and load ``model_best.pt`` from ``../model``."""
        root_path = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                                 "model")

        with open(os.path.join(root_path, "config/example.json")) as f:
            cfg = edict(json.load(f))

        self.model = Classifier(cfg)
        # Override the configured heads with six single-logit outputs and
        # rebuild the layers that depend on the number of classes.
        self.model.cfg.num_classes = [1, 1, 1, 1, 1, 1]
        self.model._init_classifier()
        self.model._init_attention_map()
        self.model._init_bn()

        if torch.cuda.is_available():
            self.model = self.model.eval().cuda()
        else:
            self.model = self.model.eval().cpu()

        chkpt_path = os.path.join(root_path, "model_best.pt")
        # The identity map_location keeps the loaded tensors on the CPU;
        # load_state_dict then copies them into the model's parameters.
        self.model.load_state_dict(
            torch.load(chkpt_path, map_location=lambda storage, loc: storage))

    @staticmethod
    def _preprocess(pixels):
        """Convert a raw pixel array into the float tensor fed to the model.

        The image is resized to 1024x1024, scaled by its maximum, contrast
        equalized, resized to 512x512, mapped with ``x * 2 - 1`` and
        replicated into three channels with a leading batch axis.
        """
        x = cv2.resize(pixels, (1024, 1024))
        peak = x.max()
        # An all-zero frame would otherwise divide by zero and feed NaNs
        # into the equalization step and the network.
        scale = peak if peak != 0 else 1
        x = equalize_adapthist(x.astype(float) / scale, clip_limit=0.01)
        x = cv2.resize(x, (512, 512))
        x = x * 2 - 1
        # Replicate the single channel three times (grayscale -> RGB) and
        # add the batch axis.
        x = np.array([[x, x, x]])
        return torch.from_numpy(x).float()

    @staticmethod
    def _instance_ids(ds):
        """Return the identifying fields shared by every output for ``ds``."""
        return {
            "study_uid": str(ds.StudyInstanceUID),
            "series_uid": str(ds.SeriesInstanceUID),
            "instance_uid": str(ds.SOPInstanceUID),
            "frame_number": None,
        }

    def predict(self, data):
        """
        See https://github.com/mdai/model-deploy/blob/master/mdai/server.py for details on the
        schema of `data` and the required schema of the outputs returned by this function.

        Only ``data["files"]`` is read here; entries whose ``content_type`` is
        not ``application/dicom`` are skipped.
        """
        outputs = []

        for entry in data["files"]:
            if entry["content_type"] != "application/dicom":
                continue

            ds = pydicom.dcmread(BytesIO(entry["content"]))
            x_orig = ds.pixel_array

            x = self._preprocess(x_orig)
            if torch.cuda.is_available():
                x = x.cuda()
            else:
                x = x.cpu()

            with torch.no_grad():
                logits, _ = self.model(x)
                logits = torch.cat(logits, dim=1).detach().cpu()
                # Shifting each logit by its threshold before the sigmoid
                # turns the 0.5 cut-off below into a per-class cut-off.
                y_prob = torch.sigmoid(
                    logits - torch.from_numpy(threshs).reshape((1, 6)))
                y_prob = y_prob.numpy()

            # Enabled after the no-grad forward pass; presumably the
            # explanation methods below need gradients w.r.t. the input.
            x.requires_grad = True

            class_indices = np.where(y_prob >= 0.5)[1]

            if len(class_indices) == 0:
                # no outputs, return 'NONE' output type
                outputs.append({"type": "NONE", **self._instance_ids(ds)})
                continue

            for class_index in class_indices:
                probability = y_prob[0][class_index]

                gradcam = GradCam(self.model)
                gradcam_output = gradcam.generate_cam(x, x_orig, class_index)
                gradcam_output_buffer = BytesIO()
                gradcam_output.save(gradcam_output_buffer, format="PNG")

                intgrad = IntegratedGradients(self.model)
                intgrad_output = intgrad.generate_integrated_gradients(
                    x, class_index, 5)
                intgrad_output_buffer = BytesIO()
                intgrad_output.save(intgrad_output_buffer, format="PNG")

                outputs.append({
                    "type": "ANNOTATION",
                    **self._instance_ids(ds),
                    "class_index": int(class_index),
                    "data": None,
                    "probability": float(probability),
                    "explanations": [
                        {
                            "name": "Grad-CAM",
                            "description":
                            "Visualize how parts of the image affects neural network’s output by looking into the activation maps. From _Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization_ (https://arxiv.org/abs/1610.02391)",
                            "content": gradcam_output_buffer.getvalue(),
                            "content_type": "image/png",
                        },
                        {
                            "name": "Integrated Gradients",
                            "description":
                            "Visualize an average of the gradients along the construction of the input towards the decision. From _Axiomatic Attribution for Deep Networks_ (https://arxiv.org/abs/1703.01365)",
                            "content": intgrad_output_buffer.getvalue(),
                            "content_type": "image/png",
                        },
                    ],
                })

        return outputs
Пример #4
0
# Branch names double as the labels handed to save_checkpoint(); the order is
# the order in which schedulers are stepped and checkpoints are written.
_BRANCHES = ('dense', 'res', 'fusion')

# One entry per value of args.step: the tag passed to train()/valid(), the
# branches that are trainable (and therefore scheduled and checkpointed), and
# the messages printed at the start and end of training.
_STEP_PLANS = {
    1: {
        'tag': 'step1',
        'active': ('dense', 'fusion'),
        'start_msg': "==> Start training of densenet branch",
        'label': "Dense",
    },
    2: {
        'tag': 'step2',
        'active': ('res', 'fusion'),
        'start_msg': "==> Start training of local branch",
        'label': "Resnet",
    },
    3: {
        'tag': 'step3',
        'active': ('dense', 'res', 'fusion'),
        'start_msg': "==> Start training of fusion branch",
        'label': "Fusion",
    },
}


def _load_cfg(path):
    """Read a JSON config file into an attribute-accessible dict."""
    with open(path) as f:
        return edict(json.load(f))


def _print_model_size(model):
    """Print the parameter count of ``model`` in millions."""
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))


def _resume_branch(model, ckpt_path, label):
    """Restore ``model`` weights from ``ckpt_path`` and report the epoch."""
    # Map to CPU so a checkpoint written on a GPU can be resumed on a
    # CPU-only run; the model is moved to the GPU afterwards when one is used.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    print("Resuming {} from epoch {}".format(label, checkpoint['epoch'] + 1))


def _train_branches(args, plan, models, optimizers, schedulers, train_loader,
                    valid_loader, criterion, classes, cfg, data_transforms,
                    use_gpu):
    """Run the train/validate/checkpoint loop for one training step."""
    step_tag = plan['tag']
    active = plan['active']

    start_time = time.time()
    train_time = 0
    best_epoch = 0
    best_loss = np.inf
    print(plan['start_msg'])

    # Freeze every branch that is not part of this step.
    for name in _BRANCHES:
        trainable = name in active
        for p in models[name].parameters():
            p.requires_grad = trainable

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(step_tag, models['dense'], models['res'], models['fusion'],
              train_loader, optimizers['dense'], optimizers['res'],
              optimizers['fusion'], criterion, args.print_freq, epoch,
              args.max_epoch, cfg, data_transforms['train'])
        train_time += round(time.time() - start_train_time)

        # Validate every eval_step epochs and always after the last epoch.
        is_eval_epoch = (args.eval_step > 0
                         and (epoch + 1) % args.eval_step == 0)
        is_last_epoch = (epoch + 1) == args.max_epoch
        if not (is_eval_epoch or is_last_epoch):
            continue

        print("==> Validation")
        loss_val = valid(step_tag, models['dense'], models['res'],
                         models['fusion'], valid_loader, criterion,
                         args.print_freq, classes, cfg,
                         data_transforms['valid'])

        if args.stepsize > 0:
            for name in active:
                if args.scheduler == 'ReduceLROnPlateau':
                    schedulers[name].step(loss_val)
                else:
                    schedulers[name].step()

        is_best = loss_val < best_loss
        if is_best:
            best_loss = loss_val
            best_epoch = epoch + 1

        # Under DataParallel the real model lives in .module; saving from
        # there keeps the checkpoint keys free of the 'module.' prefix.
        state_dicts = {
            name: (models[name].module if use_gpu else models[name]).state_dict()
            for name in active
        }
        checkpoint_name = 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'
        for name in active:
            save_checkpoint(
                {
                    'state_dict': state_dicts[name],
                    'loss': best_loss,
                    'epoch': epoch,
                }, is_best, args.save_dir, checkpoint_name, name)

    print("==> Best Validation Loss {:.4%}, achieved at epoch {}".format(
        best_loss, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "{} branch finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}."
        .format(plan['label'], elapsed, train_time))


def main():
    """Train or evaluate the densenet / resnet / fusion branches.

    Reads the module-level ``args`` namespace. ``args.step`` selects which
    branches are trained (1: densenet + fusion, 2: resnet + fusion, 3: all
    three). With ``args.evaluate`` set, only one validation pass is run for
    that step. Output printed through ``sys.stdout`` goes to a ``Logger``
    writing ``log_train.txt`` or ``log_test.txt`` under ``args.save_dir``.
    """
    if args.dataset == 'ChestXray-NIHCC':
        classes = [
            'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
            'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation',
            'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening',
            'Hernia'
        ]
        if args.no_fiding:
            classes.append('No Fiding')
    elif args.dataset == 'CheXpert-v1.0-small':
        classes = [
            'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
            'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation',
            'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion',
            'Pleural Other', 'Fracture', 'Support Devices'
        ]
    else:
        print('--dataset incorrect')
        return

    torch.manual_seed(args.seed)
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    log_name = 'log_test.txt' if args.evaluate else 'log_train.txt'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    pin_memory = True if use_gpu else False

    print("Initializing dataset: {}".format(args.dataset))

    data_transforms = {
        'train':
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(556),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        'valid':
        transforms.Compose([
            transforms.Resize(556),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
    }

    datasetTrain = DatasetGenerator(path_base=args.base_dir,
                                    dataset_file='train',
                                    transform=data_transforms['train'],
                                    dataset_=args.dataset,
                                    no_fiding=args.no_fiding)

    datasetVal = DatasetGenerator(path_base=args.base_dir,
                                  dataset_file='valid',
                                  transform=data_transforms['valid'],
                                  dataset_=args.dataset,
                                  no_fiding=args.no_fiding)

    train_loader = DataLoader(dataset=datasetTrain,
                              batch_size=args.train_batch,
                              shuffle=args.train_shuffle,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    valid_loader = DataLoader(dataset=datasetVal,
                              batch_size=args.valid_batch,
                              shuffle=args.valid_shuffle,
                              num_workers=args.workers,
                              pin_memory=pin_memory)

    cfg_dense = _load_cfg(args.infos_densenet)
    print('Initializing densenet branch')
    model_dense = Classifier(cfg_dense)
    _print_model_size(model_dense)

    # NOTE: this resnet config is also the ``cfg`` handed to train()/valid()
    # for every step, matching the original behaviour.
    cfg = _load_cfg(args.infos_resnet)
    print('Initializing resnet branch')
    model_res = Classifier(cfg)
    _print_model_size(model_res)

    print('Initializing fusion branch')
    model_fusion = Fusion(input_size=7424, output_size=len(classes))
    _print_model_size(model_fusion)

    print("Initializing optimizers")
    optimizers = {
        'dense': init_optim(args.optim, model_dense.parameters(),
                            args.learning_rate, args.weight_decay,
                            args.momentum),
        'res': init_optim(args.optim, model_res.parameters(),
                          args.learning_rate, args.weight_decay,
                          args.momentum),
        'fusion': init_optim(args.optim, model_fusion.parameters(),
                             args.learning_rate, args.weight_decay,
                             args.momentum),
    }

    criterion = nn.BCELoss()

    print("Initializing scheduler: {}".format(args.scheduler))
    # Schedulers exist only when a positive step size was requested.
    schedulers = {}
    if args.stepsize > 0:
        for name in _BRANCHES:
            schedulers[name] = init_scheduler(args.scheduler,
                                              optimizers[name],
                                              args.stepsize, args.gamma)

    resume_targets = (
        ('densenet', model_dense, args.resume_densenet),
        ('resnet', model_res, args.resume_resnet),
        ('fusion', model_fusion, args.resume_fusion),
    )
    for label, model, ckpt_path in resume_targets:
        if ckpt_path:
            _resume_branch(model, ckpt_path, label)

    if use_gpu:
        model_dense = nn.DataParallel(model_dense).cuda()
        model_res = nn.DataParallel(model_res).cuda()
        model_fusion = nn.DataParallel(model_fusion).cuda()

    models = {'dense': model_dense, 'res': model_res, 'fusion': model_fusion}
    plan = _STEP_PLANS.get(args.step)

    if args.evaluate:
        print("Evaluate only")
        if plan is None:
            print('args.step not found')
        else:
            valid(plan['tag'], model_dense, model_res, model_fusion,
                  valid_loader, criterion, args.print_freq, classes, cfg,
                  data_transforms['valid'])
        return

    if plan is None:
        print('args.step not found')
        return

    _train_branches(args, plan, models, optimizers, schedulers, train_loader,
                    valid_loader, criterion, classes, cfg, data_transforms,
                    use_gpu)
Пример #5
0
def run_fl(args):
    """Simulate federated training of ``Classifier`` over several client CSVs.

    Each entry of ``cfg.train_csv`` becomes one client, named ``A``, ``B``,
    ... (so at most 26 are used). In every communication round each client
    starts from the current global weights and trains for ``cfg.local_epoch``
    epochs; the client weights are then aggregated according to
    ``cfg.fl_technique``. The aggregated model is evaluated on
    ``cfg.dev_csv``, logged to TensorBoard and checkpointed under
    ``args.save_path``.

    Args:
        args: Parsed command-line namespace. Attributes read here are
            ``cfg_path``, ``save_path``, ``verbose``, ``logtofile``,
            ``resume``, ``device_ids``, ``pre_train`` and ``num_workers``.

    Raises:
        Exception: If fewer CUDA devices are visible than were requested, or
            if the ``du`` / ``cp`` shell commands fail.
    """
    with open(args.cfg_path) as f:
        cfg = edict(json.load(f))
        if args.verbose is True:
            print(json.dumps(cfg, indent=4))

    # makedirs also creates missing parent directories and tolerates an
    # existing directory, which plain mkdir does not.
    os.makedirs(args.save_path, exist_ok=True)
    if args.logtofile is True:
        logging.basicConfig(filename=os.path.join(args.save_path, 'log.txt'),
                            filemode="w",
                            level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.resume:
        with open(os.path.join(args.save_path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=1)

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    device = torch.device('cuda:{}'.format(device_ids[0]))

    # initialise global model
    model = Classifier(cfg).to(device).train()

    if args.verbose is True:
        from torchsummary import summary
        if cfg.fix_ratio:
            h, w = cfg.long_side, cfg.long_side
        else:
            h, w = cfg.height, cfg.width
        summary(model.to(device), (3, h, w))

    if args.pre_train is not None:
        if os.path.exists(args.pre_train):
            ckpt = torch.load(args.pre_train, map_location=device)
            model.load_state_dict(ckpt)

    # Snapshot the source tree next to the results.
    src_folder = os.path.dirname(os.path.abspath(__file__)) + '/../'
    dst_folder = os.path.join(args.save_path, 'classification')
    # The size reported by du is not used; only the exit status is checked.
    rc, size = subprocess.getstatusoutput('du --max-depth=0 %s | cut -f1' %
                                          src_folder)

    if rc != 0:
        raise Exception('Copy folder error : {}'.format(rc))
    else:
        print('Successfully determined size of directory')

    rc, err_msg = subprocess.getstatusoutput('cp -R %s %s' %
                                             (src_folder, dst_folder))
    if rc != 0:
        raise Exception('copy folder error : {}'.format(err_msg))
    else:
        print('Successfully copied folder')

    # copy train files
    train_files = cfg.train_csv
    # One client per training CSV, labelled with consecutive capital letters.
    clients = {}
    for i, c in enumerate(string.ascii_uppercase):
        if i < len(train_files):
            clients[c] = {}
        else:
            break

    # initialise clients
    for i, client in enumerate(clients):
        copyfile(train_files[i],
                 os.path.join(args.save_path, f'train_{client}.csv'))
        clients[client]['dataloader_train'] =\
            DataLoader(
                ImageDataset(train_files[i], cfg, mode='train'),
                batch_size=cfg.train_batch_size,
                num_workers=args.num_workers,drop_last=True,
                shuffle=True
            )
        clients[client]['bytes_uploaded'] = 0.0
        clients[client]['epoch'] = 0
    copyfile(cfg.dev_csv, os.path.join(args.save_path, 'dev.csv'))

    dataloader_dev = DataLoader(ImageDataset(cfg.dev_csv, cfg, mode='dev'),
                                batch_size=cfg.dev_batch_size,
                                num_workers=args.num_workers,
                                drop_last=False,
                                shuffle=False)
    dev_header = dataloader_dev.dataset._label_header

    w_global = model.state_dict()

    summary_train = {'epoch': 0, 'step': 0}
    summary_dev = {'loss': float('inf'), 'acc': 0.0}
    summary_writer = SummaryWriter(args.save_path)
    comm_rounds = cfg.epoch
    best_dict = {
        "acc_dev_best": 0.0,
        "auc_dev_best": 0.0,
        "loss_dev_best": float('inf'),
        "fused_dev_best": 0.0,
        "best_idx": 1
    }

    # Communication rounds loop
    for cr in range(comm_rounds):
        logging.info('{}, Start communication round {} of FL - {} ...'.format(
            time.strftime("%Y-%m-%d %H:%M:%S"), cr + 1, cfg.fl_technique))

        w_locals = []

        for client in clients:

            logging.info(
                '{}, Start local training process for client {}, communication round: {} ...'
                .format(time.strftime("%Y-%m-%d %H:%M:%S"), client, cr + 1))

            # Start every client from a fresh model holding the current
            # global weights.
            model = Classifier(cfg).to(device).train()

            model.load_state_dict(w_global)

            if cfg.fl_technique == "FedProx":
                global_weight_collector = get_global_weights(model, device)
            else:
                global_weight_collector = None

            optimizer = get_optimizer(model.parameters(), cfg)

            # local training loops
            for epoch in range(cfg.local_epoch):
                lr = lr_schedule(cfg.lr, cfg.lr_factor, epoch, cfg.lr_epochs)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                summary_train, best_dict = train_epoch_fl(
                    summary_train, summary_dev, cfg, args, model,
                    clients[client]['dataloader_train'], dataloader_dev,
                    optimizer, summary_writer, best_dict, dev_header, epoch,
                    global_weight_collector)

                summary_train['step'] += 1

            local_state = model.state_dict()
            # sys.getsizeof would only measure the dict container; sum the
            # tensor payloads to get the size of what is actually uploaded.
            bytes_to_upload = sum(t.numel() * t.element_size()
                                  for t in local_state.values())
            clients[client]['bytes_uploaded'] += bytes_to_upload
            logging.info(
                '{}, Completed local rounds for client {} in communication round {}. '
                'Uploading {} bytes to server, {} bytes in total sent from client'
                .format(time.strftime("%Y-%m-%d %H:%M:%S"), client, cr + 1,
                        bytes_to_upload, clients[client]['bytes_uploaded']))

            w_locals.append(local_state)

        # NOTE(review): any other fl_technique value leaves w_global
        # unchanged for this round.
        if cfg.fl_technique == "FedAvg":
            w_global = fed_avg(w_locals)
        elif cfg.fl_technique == 'WFedAvg':
            w_global = weighted_fed_avg(w_locals, cfg.train_proportions)
        elif cfg.fl_technique == 'FedProx':
            # Use weighted FedAvg when using FedProx
            w_global = weighted_fed_avg(w_locals, cfg.train_proportions)

        # Test the performance of the averaged model
        # NOTE(review): the model is not switched to eval mode here; confirm
        # that test_epoch does so itself.
        avged_model = Classifier(cfg).to(device)
        avged_model.load_state_dict(w_global)

        time_now = time.time()
        summary_dev, predlist, true_list = test_epoch(summary_dev, cfg, args,
                                                      avged_model,
                                                      dataloader_dev)
        time_spent = time.time() - time_now

        # One ROC AUC per output head.
        auclist = []
        for i in range(len(cfg.num_classes)):
            y_pred = predlist[i]
            y_true = true_list[i]
            fpr, tpr, thresholds = metrics.roc_curve(y_true,
                                                     y_pred,
                                                     pos_label=1)
            auc = metrics.auc(fpr, tpr)
            auclist.append(auc)
        auc_summary = np.array(auclist)

        loss_dev_str = ' '.join(
            '{:.5f}'.format(x) for x in summary_dev['loss'])
        acc_dev_str = ' '.join(
            '{:.3f}'.format(x) for x in summary_dev['acc'])
        auc_dev_str = ' '.join('{:.3f}'.format(x) for x in auc_summary)

        logging.info(
            '{}, Averaged Model -> Dev, Step : {}, Loss : {}, Acc : {}, Auc : {},'
            'Mean auc: {:.3f} '
            'Run Time : {:.2f} sec'.format(time.strftime("%Y-%m-%d %H:%M:%S"),
                                           summary_train['step'], loss_dev_str,
                                           acc_dev_str, auc_dev_str,
                                           auc_summary.mean(), time_spent))

        for t in range(len(cfg.num_classes)):
            summary_writer.add_scalar('dev/loss_{}'.format(dev_header[t]),
                                      summary_dev['loss'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/acc_{}'.format(dev_header[t]),
                                      summary_dev['acc'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/auc_{}'.format(dev_header[t]),
                                      auc_summary[t], summary_train['step'])

        # All three running bests are updated every round; only the metric
        # named by cfg.best_target triggers saving a "best" checkpoint.
        save_best = False

        mean_acc = summary_dev['acc'][cfg.save_index].mean()
        if mean_acc >= best_dict['acc_dev_best']:
            best_dict['acc_dev_best'] = mean_acc
            if cfg.best_target == 'acc':
                save_best = True

        mean_auc = auc_summary[cfg.save_index].mean()
        if mean_auc >= best_dict['auc_dev_best']:
            best_dict['auc_dev_best'] = mean_auc
            if cfg.best_target == 'auc':
                save_best = True

        mean_loss = summary_dev['loss'][cfg.save_index].mean()
        if mean_loss <= best_dict['loss_dev_best']:
            best_dict['loss_dev_best'] = mean_loss
            if cfg.best_target == 'loss':
                save_best = True

        if save_best:
            torch.save(
                {
                    'epoch': summary_train['epoch'],
                    'step': summary_train['step'],
                    'acc_dev_best': best_dict['acc_dev_best'],
                    'auc_dev_best': best_dict['auc_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': avged_model.state_dict()
                },
                os.path.join(args.save_path,
                             'best{}.ckpt'.format(best_dict['best_idx'])))

            # Rotate through best1..best{save_top_k}, overwriting the oldest.
            best_dict['best_idx'] += 1
            if best_dict['best_idx'] > cfg.save_top_k:
                best_dict['best_idx'] = 1
            logging.info('{}, Best, Step : {}, Loss : {}, Acc : {},'
                         'Auc :{},Best Auc : {:.3f}'.format(
                             time.strftime("%Y-%m-%d %H:%M:%S"),
                             summary_train['step'], loss_dev_str, acc_dev_str,
                             auc_dev_str, best_dict['auc_dev_best']))
        # The latest aggregated model is saved every round, best or not.
        torch.save(
            {
                'epoch': cr,
                'step': summary_train['step'],
                'acc_dev_best': best_dict['acc_dev_best'],
                'auc_dev_best': best_dict['auc_dev_best'],
                'loss_dev_best': best_dict['loss_dev_best'],
                'state_dict': avged_model.state_dict()
            }, os.path.join(args.save_path, 'train.ckpt'))
class MDAIModel:
    """Wrap a ``Classifier`` behind a ``predict(data)`` method for DICOM input.

    The constructor builds the network from ``config/example.json`` and loads
    the weights stored in ``model_best.pt``; both paths are resolved relative
    to the directory containing this file.
    """

    def __init__(self):
        """Build the network, put it in eval mode on CUDA/CPU and load weights."""
        root_path = os.path.dirname(__file__)

        with open(os.path.join(root_path, "config/example.json")) as f:
            cfg = edict(json.load(f))

        self.model = Classifier(cfg)
        # Override the configured class layout with six single-output heads,
        # then re-run the initializers (presumably so the layers are rebuilt
        # for the new layout -- confirm against Classifier).
        self.model.cfg.num_classes = [1, 1, 1, 1, 1, 1]
        self.model._init_classifier()
        self.model._init_attention_map()
        self.model._init_bn()

        if torch.cuda.is_available():
            self.model = self.model.eval().cuda()
        else:
            self.model = self.model.eval().cpu()

        chkpt_path = os.path.join(root_path, "model_best.pt")
        # The identity map_location keeps the deserialized tensors on the CPU;
        # load_state_dict then copies them into the model's own parameters.
        self.model.load_state_dict(
            torch.load(chkpt_path, map_location=lambda storage, loc: storage)
        )

    @staticmethod
    def _preprocess(pixels):
        """Convert a DICOM pixel array into the network's input tensor.

        Steps: resize to 1024x1024, scale by the maximum value, apply
        ``equalize_adapthist``, resize to 512x512, map with ``x * 2 - 1`` and
        replicate the plane three times.  For a 2-D input the result is a
        float32 tensor of shape (1, 3, 512, 512).
        """
        x = cv2.resize(pixels, (1024, 1024)).astype(float)
        peak = x.max()
        # An all-zero image would otherwise yield 0/0 = NaN everywhere.
        if peak != 0:
            x = x / peak
        x = equalize_adapthist(x, clip_limit=0.01)
        x = cv2.resize(x, (512, 512))
        x = x * 2 - 1
        # convert grayscale to RGB by stacking the same plane three times
        x = np.array([[x, x, x]])
        return torch.from_numpy(x).float()

    @staticmethod
    def _result_header(tags, result_type):
        """Return the fields shared by every result: type, UIDs, frame number."""
        return {
            "type": result_type,
            "study_uid": tags["StudyInstanceUID"],
            "series_uid": tags["SeriesInstanceUID"],
            "instance_uid": tags["SOPInstanceUID"],
            "frame_number": None,
        }

    @staticmethod
    def _png_bytes(image):
        """Serialize an object exposing ``save(buffer, format=...)`` to PNG bytes."""
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        return buffer.getvalue()

    def predict(self, data):
        """
        The input data has the following schema:

        {
            "instances": [
                {
                    "file": "bytes",
                    "tags": {
                        "StudyInstanceUID": "str",
                        "SeriesInstanceUID": "str",
                        "SOPInstanceUID": "str",
                        ...
                    }
                },
                ...
            ],
            "args": {
                "arg1": "str",
                "arg2": "str",
                ...
            }
        }

        Model scope specifies whether an entire study, series, or instance is given to the model.
        If the model scope is 'INSTANCE', then `instances` will be a single instance (list length of 1).
        If the model scope is 'SERIES', then `instances` will be a list of all instances in a series.
        If the model scope is 'STUDY', then `instances` will be a list of all instances in a study.

        The additional `args` dict supplies values that may be used in a given run.
        This implementation does not read `args`, so the key may be omitted.

        For a single instance dict, `file` is the raw binary data representing a DICOM file, and
        can be loaded using: `ds = pydicom.dcmread(BytesIO(instance["file"]))`.

        The results returned by this function should have the following schema:

        [
            {
                "type": "str", // 'NONE', 'ANNOTATION', 'IMAGE', 'DICOM', 'TEXT'
                "study_uid": "str",
                "series_uid": "str",
                "instance_uid": "str",
                "frame_number": "int",
                "class_index": "int",
                "data": {},
                "probability": "float",
                "explanations": [
                    {
                        "name": "str",
                        "description": "str",
                        "content": "bytes",
                        "content_type": "str",
                    },
                    ...
                ],
            },
            ...
        ]

        The DICOM UIDs must be supplied based on the scope of the label attached to `class_index`.

        For each instance this method emits one 'ANNOTATION' result per class
        whose thresholded probability is at least 0.5, or a single 'NONE'
        result when no class qualifies.
        """
        results = []
        # Device availability does not change between instances; check once.
        use_cuda = torch.cuda.is_available()

        for instance in data["instances"]:
            tags = instance["tags"]
            ds = pydicom.dcmread(BytesIO(instance["file"]))
            x_orig = ds.pixel_array

            x = self._preprocess(x_orig)
            if use_cuda:
                x = x.cuda()

            with torch.no_grad():
                logits, _ = self.model(x)
                logits = torch.cat(logits, dim=1).detach().cpu()
                # Subtracting the per-class thresholds before the sigmoid puts
                # each class's decision boundary at probability 0.5.
                y_prob = torch.sigmoid(logits - torch.from_numpy(threshs).reshape((1, 6)))
                y_prob = y_prob.numpy()

            # Enabled before the explanation calls below (presumably they
            # differentiate with respect to the input -- confirm).
            x.requires_grad = True

            # Column indices of the classes that pass the 0.5 cut-off.
            class_indices = np.where(y_prob >= 0.5)[1]

            if class_indices.size == 0:
                # no outputs, return 'NONE' output type
                results.append(self._result_header(tags, "NONE"))
                continue

            for class_index in class_indices:
                probability = y_prob[0][class_index]

                gradcam = GradCam(self.model)
                gradcam_png = self._png_bytes(
                    gradcam.generate_cam(x, x_orig, class_index)
                )

                intgrad = IntegratedGradients(self.model)
                intgrad_png = self._png_bytes(
                    intgrad.generate_integrated_gradients(x, class_index, 5)
                )

                result = self._result_header(tags, "ANNOTATION")
                result.update({
                    "class_index": int(class_index),
                    "data": None,
                    "probability": float(probability),
                    "explanations": [
                        {
                            "name": "Grad-CAM",
                            "description": "Visualize how parts of the image affects neural network’s output by looking into the activation maps. From _Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization_ (https://arxiv.org/abs/1610.02391)",
                            "content": gradcam_png,
                            "content_type": "image/png",
                        },
                        {
                            "name": "Integrated Gradients",
                            "description": "Visualize an average of the gradients along the construction of the input towards the decision. From _Axiomatic Attribution for Deep Networks_ (https://arxiv.org/abs/1703.01365)",
                            "content": intgrad_png,
                            "content_type": "image/png",
                        },
                    ],
                })
                results.append(result)

        return results
Пример #7
0
from converter.pytorch.pytorch_parser import PytorchParser  # noqa
from model.classifier import Classifier  # noqa

# Command-line interface: one positional argument naming the directory that
# holds the trained model's cfg.json.
parser = argparse.ArgumentParser(description='test converter')
parser.add_argument('model_path', default=None, metavar='MODEL_PATH', type=str,
                    help="Path to the trained models")
args = parser.parse_args()

# NOTE: plain string concatenation -- MODEL_PATH must end with a path separator.
with open(args.model_path+'cfg.json') as f:
    cfg = edict(json.load(f))

model_file = "model/best.pth"
device = torch.device('cpu')  # PyTorch v0.4.0
net = Classifier(cfg)
# map_location lets a checkpoint that was saved from a GPU be deserialized
# here even when that device is not present; the weights end up on the CPU
# either way because `net` has not been moved yet.
ckpt = torch.load("model/best.ckpt", map_location=device)
net.load_state_dict(ckpt['state_dict'], strict=False)
# Pickle the whole module (not only the state dict); PytorchParser below is
# given this file.
torch.save(net, model_file)

net.eval()

dummy_input = torch.ones([1, 3, 1024, 1024])

net.to(device)
# Smoke-test forward pass; no gradients are needed, so skip graph construction.
with torch.no_grad():
    output = net(dummy_input)

device = torch.device("cuda")  # PyTorch v0.4.0
summary(net.to(device), (3, 1024, 1024))

# Run the converter on the pickled model with a 3x1024x1024 input shape.
pytorch_parser = PytorchParser(model_file, [3, 1024, 1024])
pytorch_parser.run(model_file)