Example #1
def stateRepresentationLearningCall(exp_config):
    """
    :param exp_config: (dict)
    :return: (bool) True if no error occurred
    """
    printGreen("\nLearning a state representation...")

    args = ['--no-display-plots']

    if exp_config.get('multi-view', False):
        args.extend(['--multi-view'])

    for arg in ['learning-rate', 'l1-reg', 'batch-size',
                'state-dim', 'epochs', 'seed', 'model-type',
                'log-folder', 'data-folder', 'training-set-size']:
        args.extend(['--{}'.format(arg), str(exp_config[arg])])

    ok = subprocess.call(['python', 'train.py'] + args)
    if ok == 0:
        print("End of state representation learning.\n")
        return True
    else:
        printRed("An error occured, error code: {}".format(ok))
        pprint(exp_config)
        if ok == NO_PAIRS_ERROR:
            printRed("No Pairs found, consider increasing the batch_size or using a different seed")
            return False
        elif ok == NAN_ERROR:
            printRed("NaN Loss, consider increasing NOISE_STD in the gaussian noise layer")
            return False
        elif ok != MATPLOTLIB_WARNING_CODE:
            raise RuntimeError("Error during state representation learning (config file above)")
        else:
            return False
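# Illustrative only: a hypothetical exp_config covering the keys read above.
# The key names come from the function; the values are placeholders, not from the source.
example_exp_config = {
    'learning-rate': 0.001,
    'l1-reg': 0.0,
    'batch-size': 256,
    'state-dim': 3,
    'epochs': 30,
    'seed': 0,
    'model-type': 'custom_cnn',
    'log-folder': 'logs/example_experiment',
    'data-folder': 'data/example_dataset',
    'training-set-size': 10000,
    'multi-view': False,
}
# stateRepresentationLearningCall(example_exp_config)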
Example #2
def knnCall(exp_config):
    """
    Evaluate the learned state representation with k-nearest neighbors
    and compute the knn-mse metric on a set of images.
    :param exp_config: (dict)
    """
    folder_path = '{}/NearestNeighbors/'.format(exp_config['log-folder'])
    createFolder(folder_path, "NearestNeighbors folder already exists")

    printGreen("\nEvaluating the state representation with KNN")

    args = ['--seed', str(exp_config['knn-seed']), '--n-samples', str(exp_config['knn-samples'])]

    if exp_config.get('ground-truth', False):
        args.extend(['--ground-truth'])

    if exp_config.get('multi-view', False):
        args.extend(['--multi-view'])

    if exp_config.get('relative-pos', False):
        args.extend(['--relative-pos'])

    for arg in ['log-folder', 'n-neighbors', 'n-to-plot']:
        args.extend(['--{}'.format(arg), str(exp_config[arg])])

    ok = subprocess.call(['python', '-m', 'evaluation.knn_images'] + args)
    printConfigOnError(ok, exp_config, "knnCall")
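# printConfigOnError is used by the calls above and below but not shown on this page;
# a minimal sketch of what it presumably does (an assumption, not the repository's actual implementation):
def printConfigOnError(return_code, exp_config, name):
    """
    :param return_code: (int) exit code of the subprocess
    :param exp_config: (dict)
    :param name: (str) name of the calling function
    """
    if return_code != 0:
        printRed("An error occurred in {}, error code: {}".format(name, return_code))
        pprint(exp_config)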
Example #3
def baselineCall(exp_config, baseline="supervised"):
    """
    :param exp_config: (dict)
    :param baseline: (str) one of "supervised", "autoencoder" or "vae"
    """
    printGreen("\n Baseline {}...".format(baseline))
    ok = False
    args = ['--no-display-plots']

    config_args = ['epochs', 'seed', 'model-type',
                   'data-folder', 'training-set-size', 'batch-size']

    if 'log-folder' in exp_config.keys():
        config_args += ['log-folder']

    if baseline in ["supervised", "autoencoder", "vae"]:

        if baseline == "supervised":
            if exp_config['relative-pos']:
                args += ['--relative-pos']
        else:
            config_args += ['state-dim']
            # because ae & vae use the script train.py with loss argument
            args += ['--losses', baseline]
        exp_config['losses'] = [baseline]

        for arg in config_args:
            args.extend(['--{}'.format(arg), str(exp_config[arg])])

        if baseline == "supervised":
            ok = subprocess.call(['python', '-m', 'srl_baselines.{}'.format(baseline)] + args)
        else:
            ok = subprocess.call(['python', 'train.py'] + args)

    printConfigOnError(ok, exp_config, "baselineCall")
Example #4
def pcaCall(exp_config):
    """
    :param exp_config: (dict)
    """
    printGreen("\n Baseline PCA...")

    args = ['--no-display-plots']
    config_args = ['data-folder', 'training-set-size', 'state-dim']

    for arg in config_args:
        args.extend(['--{}'.format(arg), str(exp_config[arg])])

    ok = subprocess.call(['python', '-m', 'srl_baselines.pca'] + args)
    printConfigOnError(ok, exp_config, "pcaCall")
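# Illustrative usage only, chaining the helpers above with a hypothetical exp_config
# (see the sketch after Example #1); the pipeline in Example #8 drives these calls
# from an exp_config.json file.
# ok = stateRepresentationLearningCall(exp_config)
# if ok:
#     knnCall(exp_config)  # also needs 'knn-seed', 'knn-samples', 'n-neighbors', 'n-to-plot'
# baselineCall(exp_config, baseline="vae")
# pcaCall(exp_config)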
Example #5
def main():
    parser = argparse.ArgumentParser(description="Train the model")
    parser.add_argument('-trainf', "--train-filepath", type=str, default=None, required=True,
                        help="training dataset filepath.")
    parser.add_argument('-validf', "--val-filepath", type=str, default=None,
                        help="validation dataset filepath.")
    parser.add_argument("--shuffle", action="store_true", default=False,
                        help="Shuffle the dataset")
    parser.add_argument("--load-weights", type=str, default=None,
                        help="load pretrained weights")
    parser.add_argument("--load-model", type=str, default=None,
                        help="load pretrained model, entire model (filepath, default: None)")

    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument('--epochs', type=int, default=30,
                        help='number of epochs to train (default: 30)')
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size")

    parser.add_argument('--img-shape', type=str, default="(1,512,512)",
                        help='Image shape (default: "(1,512,512)")')

    parser.add_argument("--num-cpu", type=int, default=10,
                        help="Number of CPUs to use in parallel for dataloader.")
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA visible device (use CPU if -1, default: 0)')
    parser.add_argument('--cuda-non-deterministic', action='store_true', default=False,
                        help="sets flags for non-determinism when using CUDA (potentially fast)")

    parser.add_argument('-lr', type=float, default=0.0005,
                        help='Learning rate')
    parser.add_argument('--seed', type=int, default=0,
                        help='Seed (numpy, and cuda if a GPU is used).')

    parser.add_argument('--log-dir', type=str, default=None,
                        help='Save the results/model weights/logs under the directory.')

    args = parser.parse_args()

    # TODO: support image reshape
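    # e.g. the default string "(1,512,512)" is parsed into the tuple (1, 512, 512)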
    img_shape = tuple(map(int, args.img_shape.strip()[1:-1].split(",")))

    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
        best_model_path = os.path.join(args.log_dir, "model_weights.pth")
    else:
        best_model_path = None

    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if args.cuda >= 0:
            if args.cuda_non_deterministic:
                printBlue("Warning: using CUDA non-deterministc. Could be faster but results might not be reproducible.")
            else:
                printBlue("Using CUDA deterministc. Use --cuda-non-deterministic might accelerate the training a bit.")
            # Make CuDNN Determinist
            torch.backends.cudnn.deterministic = not args.cuda_non_deterministic

            # torch.cuda.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)

    # TODO [OPT] enable multi-GPUs ?
    # https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
    device = torch.device("cuda:{}".format(args.cuda) if torch.cuda.is_available()
                          and (args.cuda >= 0) else "cpu")

    # ================= Build dataloader =================
    # DataLoader
    # transform_normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
    #                                            std=[0.5, 0.5, 0.5])
    transform_normalize = transforms.Normalize(mean=[0.5],
                                               std=[0.5])

    # Warning: DO NOT use geometry transform (do it in the dataloader instead)
    data_transform = transforms.Compose([
        # transforms.ToPILImage(mode='F'), # mode='F' for one-channel image
        # transforms.Resize((256, 256)) # NO
        # transforms.RandomResizedCrop(256), # NO
        # transforms.RandomHorizontalFlip(p=0.5), # NO
        # WARNING, ISSUE: transforms.ColorJitter doesn't work with ToPILImage(mode='F').
        # Need custom data augmentation functions: TODO: DONE.
        # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),

        # Use OpenCVRotation, OpenCVXXX, ... (our implementation)
        # OpenCVRotation((-10, 10)), # angles (in degree)
        transforms.ToTensor(),  # already done in the dataloader
        transform_normalize
    ])

    geo_transform = GeoCompose([
        OpenCVRotation(angles=(-10, 10),
                       scales=(0.9, 1.1),
                       centers=(-0.05, 0.05)),

        # TODO add more data augmentation here
    ])

    def worker_init_fn(worker_id):
        # WARNING spawn start method is used,
        # worker_init_fn cannot be an unpicklable object, e.g., a lambda function.
        # A work-around for issue #5059: https://github.com/pytorch/pytorch/issues/5059
        np.random.seed()

    data_loader_train = {'batch_size': args.batch_size,
                         'shuffle': args.shuffle,
                         'num_workers': args.num_cpu,
                         #   'sampler': balanced_sampler,
                         'drop_last': True,  # for GAN-like
                         'pin_memory': False,
                         'worker_init_fn': worker_init_fn,
                         }

    data_loader_valid = {'batch_size': args.batch_size,
                         'shuffle': False,
                         'num_workers': args.num_cpu,
                         'drop_last': False,
                         'pin_memory': False,
                         }

    train_set = LiTSDataset(args.train_filepath,
                            dtype=np.float32,
                            geometry_transform=geo_transform,  # TODO enable data augmentation
                            pixelwise_transform=data_transform,
                            )
    valid_set = LiTSDataset(args.val_filepath,
                            dtype=np.float32,
                            pixelwise_transform=data_transform,
                            )

    dataloader_train = torch.utils.data.DataLoader(train_set, **data_loader_train)
    dataloader_valid = torch.utils.data.DataLoader(valid_set, **data_loader_valid)
    # =================== Build model ===================
    # TODO: control the model by bash command

    if args.load_weights:
        model = UNet(in_ch=1,
                     out_ch=3,  # there are 3 classes: 0: background, 1: liver, 2: tumor
                     depth=4,
                     start_ch=32, # 64
                     inc_rate=2,
                     kernel_size=5, # 3 
                     padding=True,
                     batch_norm=True,
                     spec_norm=False,
                     dropout=0.5,
                     up_mode='upconv',
                     include_top=True,
                     include_last_act=False,
                     )
        printYellow(f"Loading pretrained weights from: {args.load_weights}...")
        model.load_state_dict(torch.load(args.load_weights))
        printYellow("+ Done.")
    elif args.load_model:
        # load entire model
        model = torch.load(args.load_model)
        printYellow("Successfully loaded pretrained model.")

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.95))  # TODO
    best_valid_loss = float('inf')
    # TODO TODO: add learning decay
    
    for epoch in range(args.epochs):
        for valid_mode, dataloader in enumerate([dataloader_train, dataloader_valid]):
            n_batch_per_epoch = len(dataloader)
            if args.debug:
                n_batch_per_epoch = 1

            # infinite dataloader allows several updates per iteration (for special models, e.g. GANs)
            dataloader = infinite_dataloader(dataloader)
            if valid_mode:
                printYellow("Switch to validation mode.")
                model.eval()
                prev_grad_mode = torch.is_grad_enabled()
                torch.set_grad_enabled(False)
            else:
                model.train()

            st = time.time()
            cum_loss = 0
            for iter_ind in range(n_batch_per_epoch):
                supplement_logs = ""
                # reset cumulated losses at the beginning of each batch
                # loss_manager.reset_losses() # TODO: use torch.utils.tensorboard !!
                optimizer.zero_grad()

                img, msk = next(dataloader)
                img, msk = img.to(device), msk.to(device)

                # TODO this is ugly: convert dtype and convert the shape from (N, 1, 512, 512) to (N, 512, 512)
                msk = msk.to(torch.long).squeeze(1)

                msk_pred = model(img)  # shape (N, 3, 512, 512)

                # label_weights is determined according the liver_ratio & tumor_ratio
                # loss = CrossEntropyLoss(msk_pred, msk, label_weights=[1., 10., 100.], device=device)
                loss = DiceLoss(msk_pred, msk, label_weights=[1., 20., 50.], device=device)
                # loss = DiceLoss(msk_pred, msk, label_weights=[1., 20., 500.], device=device)

                if valid_mode:
                    pass
                else:
                    loss.backward()
                    optimizer.step()

                loss = loss.item()  # release
                cum_loss += loss
                if valid_mode:
                    print("\r--------(valid) {:.2%} Loss: {:.3f} (time: {:.1f}s) |supp: {}".format(
                        (iter_ind+1)/n_batch_per_epoch, cum_loss/(iter_ind+1), time.time()-st, supplement_logs), end="")
                else:
                    print("\rEpoch: {:3}/{} {:.2%} Loss: {:.3f} (time: {:.1f}s) |supp: {}".format(
                        (epoch+1), args.epochs, (iter_ind+1)/n_batch_per_epoch, cum_loss/(iter_ind+1), time.time()-st, supplement_logs), end="")
            print()
            if valid_mode:
                torch.set_grad_enabled(prev_grad_mode)

        valid_mean_loss = cum_loss/(iter_ind+1)  # validation (mean) loss of the current epoch

        if best_model_path and (valid_mean_loss < best_valid_loss):
            printGreen("Valid loss decreases from {:.5f} to {:.5f}, saving best model.".format(
                best_valid_loss, valid_mean_loss))
            best_valid_loss = valid_mean_loss
            # Only need to save the weights
            # torch.save(model.state_dict(), best_model_path)
            # save the entire model
            torch.save(model, best_model_path)

    return best_valid_loss
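
# infinite_dataloader is called in the training loop above but not defined on this page.
# A minimal sketch of what such a helper typically looks like (an assumption):
def infinite_dataloader(dataloader):
    """Cycle through a DataLoader forever, yielding one batch at a time."""
    while True:
        for batch in dataloader:
            yield batch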
Example #6
def inference():
    """Support two mode: evaluation (on valid set) or inference mode (on test-set for submission)

    """
    parser = argparse.ArgumentParser(description="Inference mode")
    parser.add_argument('-testf', "--test-filepath", type=str, default=None, required=True,
                        help="testing dataset filepath.")
    parser.add_argument("-eval", "--evaluate", action="store_true", default=False,
                        help="Evaluation mode")
    parser.add_argument("--load-weights", type=str, default=None,
                        help="Load pretrained weights, torch state_dict() (filepath, default: None)")
    parser.add_argument("--load-model", type=str, default=None,
                        help="Load pretrained model, entire model (filepath, default: None)")

    parser.add_argument("--save2dir", type=str, default=None,
                        help="save the prediction labels to the directory (default: None)")
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size")

    parser.add_argument("--num-cpu", type=int, default=10,
                        help="Number of CPUs to use in parallel for dataloader.")
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA visible device (use CPU if -1, default: 0)')
    args = parser.parse_args()

    printYellow("="*10 + " Inference mode. "+"="*10)
    if args.save2dir:
        os.makedirs(args.save2dir, exist_ok=True)

    device = torch.device("cuda:{}".format(args.cuda) if torch.cuda.is_available()
                          and (args.cuda >= 0) else "cpu")

    transform_normalize = transforms.Normalize(mean=[0.5],
                                               std=[0.5])

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transform_normalize
    ])

    data_loader_params = {'batch_size': args.batch_size,
                          'shuffle': False,
                          'num_workers': args.num_cpu,
                          'drop_last': False,
                          'pin_memory': False
                          }

    test_set = LiTSDataset(args.test_filepath,
                           dtype=np.float32,
                           pixelwise_transform=data_transform,
                           inference_mode=(not args.evaluate),
                           )
    dataloader_test = torch.utils.data.DataLoader(test_set, **data_loader_params)
    # =================== Build model ===================
    if args.load_weights:
        model = UNet(in_ch=1,
                     out_ch=3,  # there are 3 classes: 0: background, 1: liver, 2: tumor
                     depth=4,
                     start_ch=64,
                     inc_rate=2,
                     kernel_size=3,
                     padding=True,
                     batch_norm=True,
                     spec_norm=False,
                     dropout=0.5,
                     up_mode='upconv',
                     include_top=True,
                     include_last_act=False,
                     )
        model.load_state_dict(torch.load(args.load_weights))
        printYellow("Successfully loaded pretrained weights.")
    elif args.load_model:
        # load entire model
        model = torch.load(args.load_model)
        printYellow("Successfully loaded pretrained model.")
    model.eval()
    model.to(device)

    # n_batch_per_epoch = len(dataloader_test)

    sigmoid_act = torch.nn.Sigmoid()
    st = time.time()

    volume_start_index = test_set.volume_start_index
    spacing = test_set.spacing
    direction = test_set.direction  # use it for the submission
    offset = test_set.offset

    msk_pred_buffer = []
    if args.evaluate:
        msk_gt_buffer = []

    for data_batch in tqdm(dataloader_test):
        # import ipdb
        # ipdb.set_trace()
        if args.evaluate:
            img, msk_gt = data_batch
            msk_gt_buffer.append(msk_gt.cpu().detach().numpy())
        else:
            img = data_batch
        img = img.to(device)
        with torch.no_grad():
            msk_pred = model(img)  # shape (N, 3, H, W)
            msk_pred = sigmoid_act(msk_pred)
        msk_pred_buffer.append(msk_pred.cpu().detach().numpy())

    msk_pred_buffer = np.vstack(msk_pred_buffer)  # shape (N, 3, H, W)
    if args.evaluate:
        msk_gt_buffer = np.vstack(msk_gt_buffer)

    results = []
    for vol_ind, vol_start_ind in enumerate(volume_start_index):
        if vol_ind == len(volume_start_index) - 1:
            volume_msk = msk_pred_buffer[vol_start_ind:]  # shape (N, 3, H, W)
            if args.evaluate:
                volume_msk_gt = msk_gt_buffer[vol_start_ind:]
        else:
            vol_end_ind = volume_start_index[vol_ind+1]
            volume_msk = msk_pred_buffer[vol_start_ind:vol_end_ind]  # shape (N, 3, H, W)
            if args.evaluate:
                volume_msk_gt = msk_gt_buffer[vol_start_ind:vol_end_ind]
        if args.evaluate:
            # liver
            liver_scores = get_scores(volume_msk[:, 1] >= 0.5, volume_msk_gt >= 1, spacing[vol_ind])
            # tumor
            lesion_scores = get_scores(volume_msk[:, 2] >= 0.5, volume_msk_gt == 2, spacing[vol_ind])
            print("Liver dice", liver_scores['dice'], "Lesion dice", lesion_scores['dice'])
            results.append([vol_ind, liver_scores, lesion_scores])
            # ===========================
        else:
            # import ipdb; ipdb.set_trace()
            if args.save2dir:
                # reverse the channel order: prioritize tumor, then liver, then background
                msk_pred = (volume_msk >= 0.5)[:, ::-1, ...]  # shape (N, 3, H, W)
                msk_pred = np.argmax(msk_pred, axis=1)  # shape (N, H, W) = (z, x, y)
                msk_pred = np.transpose(msk_pred, axes=(1, 2, 0))  # shape (x, y, z)
                # remember to correct 'direction' and np.transpose before the submission !!!
                if direction[vol_ind][0] == -1:
                    # x-axis
                    msk_pred = msk_pred[::-1, ...]
                if direction[vol_ind][1] == -1:
                    # y-axis
                    msk_pred = msk_pred[:, ::-1, :]
                if direction[vol_ind][2] == -1:
                    # z-axis
                    msk_pred = msk_pred[..., ::-1]
                # save medical image header as well
                # see: http://loli.github.io/medpy/generated/medpy.io.header.Header.html
                file_header = med_header(spacing=tuple(spacing[vol_ind]),
                                         offset=tuple(offset[vol_ind]),
                                         direction=np.diag(direction[vol_ind]))
                # submission guide:
                # see: https://github.com/PatrickChrist/LITS-CHALLENGE/blob/master/submission-guide.md
                # test-segmentation-X.nii
                filepath = os.path.join(args.save2dir, f"test-segmentation-{vol_ind}.nii")
                med_save(msk_pred, filepath, hdr=file_header)
    if args.save2dir:
        # outpath = os.path.join(args.save2dir, "results.csv")
        outpath = os.path.join(args.save2dir, "results.pkl")
        with open(outpath, "wb") as file:
            final_result = {}
            final_result['liver'] = defaultdict(list)
            final_result['tumor'] = defaultdict(list)
            for vol_ind, liver_scores, lesion_scores in results:
                # [OTC] assuming vol_ind is continuous
                for key in liver_scores:
                    final_result['liver'][key].append(liver_scores[key])
                for key in lesion_scores:
                    final_result['tumor'][key].append(lesion_scores[key])
            pickle.dump(final_result, file, protocol=3)
        # ======== code from official metric ========
        # create line for csv file
        # outstr = str(vol_ind) + ','
        # for l in [liver_scores, lesion_scores]:
        #     for k, v in l.items():
        #         outstr += str(v) + ','
        #         outstr += '\n'
        # # create header for csv file if necessary
        # if not os.path.isfile(outpath):
        #     headerstr = 'Volume,'
        #     for k, v in liver_scores.items():
        #         headerstr += 'Liver_' + k + ','
        #     for k, v in liver_scores.items():
        #         headerstr += 'Lesion_' + k + ','
        #     headerstr += '\n'
        #     outstr = headerstr + outstr
        # # write to file
        # f = open(outpath, 'a+')
        # f.write(outstr)
        # f.close()
        # ===========================
    printGreen(f"Total elapsed time: {time.time()-st}")
    return results
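
# Illustrative only (an assumption, not from the source): reading back the results.pkl
# written above and averaging the per-volume Dice scores.
def summarize_results(pkl_path):
    import pickle
    from statistics import mean
    with open(pkl_path, "rb") as f:
        res = pickle.load(f)
    print("mean liver dice:", mean(res['liver']['dice']))
    print("mean lesion dice:", mean(res['tumor']['dice']))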
Example #7
import mido
import sys
from utils import printRed, printGreen

printGreen("Initializing")
input_devices = mido.get_input_names()
output_devices = mido.get_output_names()
output_devices = list(filter(lambda x: 'dtx' in x.lower(), output_devices))
if len(output_devices) == 0:
    printRed("Failed to found the dtx module")
    sys.exit(1) 

input_devices = list(filter(lambda x: 'deluge' in x.lower(), input_devices))
if len(input_devices) == 0:
    printRed("Failed to found the deluge module")
    sys.exit(1) 
try:
    with mido.open_output(output_devices[0]) as dtx:
        with mido.open_input(input_devices[0]) as deluge:
            printGreen("Ready to forward! In: {0} out: {1}".format(deluge.name, dtx.name))
            for msg in deluge:
                if "channel" in vars(msg) and msg.channel == 9 and "type" in vars(msg) and "note_" in msg.type:
                    printGreen("Message {} matches".format(msg))
                    dtx.send(msg)
                elif 'type' in vars(msg) and msg.type in ("program_change", "control_change"):
                    printGreen("Sending system command {}".format(msg))
                    dtx.send(msg)
except Exception as e:
    printRed("Exception! {}".format(e))
Example #8
    # Reproduce a previous experiment using "exp_config.json"
    elif args.exp_config != "":
        with open(args.exp_config, 'r') as f:
            exp_config = json.load(f)

        print("\n Pipeline using json config file: {} \n".format(args.exp_config))
        exp_config = {k.replace('_', '-'): v for k, v in exp_config.items()}

        baseline = None
        for name in ['vae', 'autoencoder', 'supervised']:
            if name in exp_config['log-folder']:
                baseline = name
                break

        data_folder = exp_config['data-folder']
        printGreen("\nDataset folder: {}".format(data_folder))
        # Update and save config
        log_folder, experiment_name = getLogFolderName(exp_config)
        exp_config['log-folder'] = log_folder
        exp_config['experiment-name'] = experiment_name
        exp_config['relative-pos'] = useRelativePosition(data_folder)
        # Save config in log folder
        saveConfig(exp_config)
        # Check that the dataset is already preprocessed
        preprocessingCheck(exp_config)

        if baseline is None:
            # Learn a state representation and plot it
            ok = stateRepresentationLearningCall(exp_config)
            if ok:
                # Evaluate the representation with kNN