Example 1
def test(args):

    # Setup Dataloader
    data_json = json.load(open('config.json'))
    data_path = data_json[args.dataset]['data_path']

    t_loader = SaltLoader(data_path, split="test")
    test_df = t_loader.test_df
    test_loader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=8)

    # load Model
    if args.arch == 'unet':
        model = Unet(start_fm=16)
    else:
        model = Unet_upsample(start_fm=16)
    model_path = data_json[args.model]['model_path']
    model.load_state_dict(torch.load(model_path)['model_state'])
    model.cuda()
    total = sum([param.nelement() for param in model.parameters()])
    print('Number of params: %.2fM' % (total / 1e6))

    # test
    model.eval()  # disable dropout/batchnorm updates for inference
    pred_list = []
    with torch.no_grad():
        for images in test_loader:
            images = Variable(images.cuda())
            y_preds = model(images)
            y_preds_shaped = y_preds.reshape(-1, args.img_size_target, args.img_size_target)
            # iterate over the actual batch size so a smaller final batch
            # does not raise an IndexError
            for idx in range(y_preds_shaped.shape[0]):
                y_pred = y_preds_shaped[idx]
                pred = torch.sigmoid(y_pred)
                pred = pred.cpu().data.numpy()
                pred_ori = resize(pred, (args.img_size_ori, args.img_size_ori), mode='constant', preserve_range=True)
                pred_list.append(pred_ori)

    # submit the test image predictions
    threshold_best = args.threshold
    pred_dict = {idx: RLenc(np.round(pred_list[i] > threshold_best)) for i, idx in
                 enumerate(tqdm_notebook(test_df.index.values))}
    sub = pd.DataFrame.from_dict(pred_dict, orient='index')
    sub.index.names = ['id']
    sub.columns = ['rle_mask']
    sub.to_csv('./results/{}_submission.csv'.format(args.model))
    print("The submission.csv saved in ./results")
Example 2
def test(args):

    # Setup Data
    data_json = json.load(open('config.json'))
    x = Variable(torch.randn(32, 1, 128, 128))
    x = x.cuda()

    # load Model
    if args.arch == 'unet':
        model = Unet(start_fm=16)
    else:
        model = Unet_upsample(start_fm=16)
    model_path = data_json[args.model]['model_path']
    model.load_state_dict(torch.load(model_path)['model_state'])
    model.cuda()
    total = sum([param.nelement() for param in model.parameters()])
    print('Number of params: %.2fM' % (total / 1e6))

    # visualize: trace a forward pass and build the autograd graph with make_dot (torchviz)
    y = model(x)
    g = make_dot(y)
    g.render('k')  # writes the Graphviz source and a rendered file named 'k'
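For reference, torchviz's make_dot also accepts a params mapping so the graph nodes are labeled with parameter names; a hedged variant of the last two lines (the output filename below is just a placeholder) could be:

    g = make_dot(y, params=dict(model.named_parameters()))  # label nodes with parameter names
    g.render('unet_graph')  # Graphviz writes the source plus a rendered file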
Example 3
def main():
    json_path = os.path.join(args.model_dir)
    params = utils.Params(json_path)

    net = Unet(params.model).cuda()
    # TODO - check exists
    checkpoint = torch.load(args.ckpt)
    net.load_state_dict(checkpoint)

    # train_dataset = AudioDataset(data_type='train')
    test_dataset = AudioDataset(data_type='val')
    # train_data_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
    #                                collate_fn=train_dataset.collate, shuffle=True, num_workers=0)
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=1,
                                  collate_fn=test_dataset.collate, shuffle=False, num_workers=0)

    torch.set_printoptions(precision=10, profile="full")

    # note: despite the name, this progress bar iterates the test/validation loader
    train_bar = tqdm(test_data_loader, ncols=60)
    cnt = 1
    with torch.no_grad():
        for input_ in train_bar:
            train_mixed, train_clean, seq_len = map(lambda x: x.cuda(), input_)
            mixed = stft(train_mixed).unsqueeze(dim=1)
            real, imag = mixed[..., 0], mixed[..., 1]
            out_real, out_imag = net(real, imag)
            out_real, out_imag = torch.squeeze(out_real, 1), torch.squeeze(out_imag, 1)
            out_audio = istft(out_real, out_imag, train_mixed.size(1))
            out_audio = torch.squeeze(out_audio, dim=1)
            for i, l in enumerate(seq_len):
                out_audio[i, l:] = 0
            # librosa.output.write_wav('mixed.wav', train_mixed[0].cpu().data.numpy()[:seq_len[0].cpu().data.numpy()], 16000)
            # librosa.output.write_wav('clean.wav', train_clean[0].cpu().data.numpy()[:seq_len[0].cpu().data.numpy()], 16000)
            enhanced = out_audio[0].cpu().data.numpy()[:seq_len[0].cpu().data.numpy()]
            sf.write('enhanced_testset/enhanced_%03d.wav' % cnt,
                     np.array(enhanced, dtype=np.float32), 16000)
            cnt += 1
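Example 3 relies on stft and istft helpers that are not shown and that work with separate real/imaginary tensors. A minimal sketch of what such helpers could look like, built on torch.stft / torch.istft with assumed placeholder frame parameters (not the repository's values), is:

import torch

N_FFT, HOP = 512, 128  # placeholder analysis parameters, not the repository's values

def stft(wave):
    # (batch, samples) -> (batch, freq, frames, 2) with real/imag stacked on the last dim
    window = torch.hann_window(N_FFT, device=wave.device)
    spec = torch.stft(wave, N_FFT, hop_length=HOP, window=window, return_complex=True)
    return torch.view_as_real(spec)

def istft(real, imag, length):
    # inverse transform back to a waveform of the requested length
    window = torch.hann_window(N_FFT, device=real.device)
    return torch.istft(torch.complex(real, imag), N_FFT, hop_length=HOP,
                       window=window, length=length)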
Example 4
    device0 = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    device1 = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    rootDir = '/data/Jinwei/Bayesian_QSM/' + opt['weight_dir']

    # network
    unet3d = Unet(
        input_channels=1,
        output_channels=1,
        num_filters=[2**i for i in range(5, 10)],  # or range(3, 8)
        use_deconv=1,
        flag_rsa=0)
    unet3d.to(device0)
    weights_dict = torch.load(rootDir +
                              '/linear_factor=1_validation=6_test=7_unet3d.pt')
    unet3d.load_state_dict(weights_dict)

    resnet = ResBlock(
        input_dim=2,
        filter_dim=32,
        output_dim=1,
    )
    resnet.to(device1)
    weights_dict = torch.load(rootDir +
                              '/linear_factor=1_validation=6_test=7_resnet.pt')
    resnet.load_state_dict(weights_dict)

    # optimizer
    optimizer0 = optim.Adam(unet3d.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer1 = optim.Adam(resnet.parameters(), lr=lr, betas=(0.5, 0.999))
    volume_size = dataLoader_train.volume_size

    trainLoader = data.DataLoader(dataLoader_train,
                                  batch_size=batch_size,
                                  shuffle=True)

    # network of HOBIT
    unet3d = Unet(input_channels=1,
                  output_channels=1,
                  num_filters=[2**i for i in range(5, 10)],
                  use_deconv=1,
                  flag_rsa=0)
    unet3d.to(device0)
    weights_dict = torch.load(rootDir + '/weight_2nets/unet3d_fine.pt')
    # weights_dict = torch.load(rootDir+'/weight_2nets/linear_factor=1_validation=6_test=7_unet3d.pt')
    unet3d.load_state_dict(weights_dict)

    # QSMnet
    unet3d_ = Unet(input_channels=1,
                   output_channels=1,
                   num_filters=[2**i for i in range(5, 10)],
                   use_deconv=1,
                   flag_rsa=0)
    unet3d_.to(device0)
    weights_dict = torch.load(rootDir +
                              '/weight_2nets/rsa=0_validation=6_test=7_.pt'
                              )  # used to compute ich metric
    # weights_dict = torch.load(rootDir+'/weight_cv/rsa=0_validation=6_test=7.pt')
    # weights_dict = torch.load(rootDir+'/weight_2nets/linear_factor=1_validation=6_test=7_unet3d.pt')  # used on Unet simu
    unet3d_.load_state_dict(weights_dict)
Example 6
        unet3d = unetVggBNNAR1CLFRes(
            input_channels=1,
            output_channels=1,
            num_filters=[2**i for i in range(5, 10)],  # or range(3, 8)
            use_deconv=1)
    elif opt['flag_cfl'] == 3:
        unet3d = unetVggBNNAR1CLFEnc(
            input_channels=1,
            output_channels=1,
            num_filters=[2**i for i in range(5, 10)],  # or range(3, 8)
            use_deconv=1)

    print('{0} trainable parameters in total'.format(count_parameters(unet3d)))
    unet3d.to(device)
    unet3d.load_state_dict(
        torch.load(rootDir + opt['weight_dir'] +
                   '/cfl={0}_validation={1}_test={2}'.format(cfl, val, test) +
                   '.pt'))
    # unet3d.load_state_dict(torch.load(rootDir+opt['weight_dir']+'/weights_vi_cosmos.pt'))
    unet3d.eval()

    QSMs, STDs, RDFs = [], [], []
    RMSEs, Fidelities = [], []
    for test_dir in range(0, 5):
        dataLoader = COSMOS_data_loader(split='Test',
                                        case_validation=val,
                                        case_test=test,
                                        test_dir=test_dir,
                                        patchSize=patchSize,
                                        extraction_step=extraction_step,
                                        voxel_size=voxel_size,
                                        flag_smv=flag_smv,
Example 7
def validate(args):

    # Setup Dataloader
    data_json = json.load(open('config.json'))
    data_path = data_json[args.dataset]['data_path']

    v_loader = SaltLoader(data_path, split='val')
    train_df = v_loader.train_df

    val_loader = data.DataLoader(v_loader,
                                 batch_size=args.batch_size,
                                 num_workers=8)

    # load Model
    if args.arch == 'unet':
        model = Unet(start_fm=16)
    else:
        model = Unet_upsample(start_fm=16)
    model_path = data_json[args.model]['model_path']
    model.load_state_dict(torch.load(model_path)['model_state'])
    model.cuda()
    total = sum([param.nelement() for param in model.parameters()])
    print('Number of params: %.2fM' % (total / 1e6))

    # validate
    model.eval()  # disable dropout/batchnorm updates for inference
    pred_list = []
    with torch.no_grad():
        for images, masks in val_loader:
            images = Variable(images.cuda())
            y_preds = model(images)
            # print(y_preds.shape)
            y_preds_shaped = y_preds.reshape(-1, args.img_size_target,
                                             args.img_size_target)
            # iterate over the actual batch size so a smaller final batch
            # does not raise an IndexError
            for idx in range(y_preds_shaped.shape[0]):
                y_pred = y_preds_shaped[idx]
                pred = torch.sigmoid(y_pred)
                pred = pred.cpu().data.numpy()
                pred_ori = resize(pred, (args.img_size_ori, args.img_size_ori),
                                  mode='constant',
                                  preserve_range=True)
                pred_list.append(pred_ori)

    preds_valid = np.array(pred_list)
    y_valid_ori = np.array(
        [train_df.loc[idx].masks for idx in v_loader.ids_valid])

    # jaccard score
    # (jaccard_similarity_score was renamed jaccard_score in newer scikit-learn)
    accuracies_best = 0.0
    for threshold in np.linspace(0, 1, 11):
        ious = []
        for y_pred, mask in zip(preds_valid, y_valid_ori):
            prediction = (y_pred > threshold).astype(int)
            iou = jaccard_similarity_score(mask.flatten(),
                                           prediction.flatten())
            ious.append(iou)
        ious = np.array(ious)  # as an array so the comparison below broadcasts

        accuracies = [
            np.mean(ious > iou_threshold)
            for iou_threshold in np.linspace(0.5, 0.95, 10)
        ]
        if accuracies_best < np.mean(accuracies):
            accuracies_best = np.mean(accuracies)
            threshold_best = threshold
        print('Threshold: %.1f, Metric: %.3f' %
              (threshold, np.mean(accuracies)))
    print("jaccard score gets threshold_best=", threshold_best)

    # alternative scoring approach
    thresholds = np.linspace(0, 1, 50)
    ious = np.array([
        iou_metric_batch(y_valid_ori, np.int32(preds_valid > threshold))
        for threshold in tqdm_notebook(thresholds)
    ])
    # restrict the argmax to mid-range thresholds (indices 9..len-11) so the
    # degenerate near-0 and near-1 thresholds are skipped
    threshold_best_index = np.argmax(ious[9:-10]) + 9
    iou_best = ious[threshold_best_index]
    threshold_best = thresholds[threshold_best_index]
    print("other way gets iou_best=", iou_best, "threshold_best=",
          threshold_best)
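iou_metric_batch is not defined in this excerpt. As a rough, simplified stand-in (the actual competition metric additionally handles empty masks and per-object matching, which this sketch ignores), it can be read as the mean hit rate over IoU thresholds 0.5 to 0.95:

import numpy as np

def iou_metric_batch(y_true, y_pred):
    # simplified, hypothetical stand-in: average over masks of the fraction of
    # IoU thresholds (0.5, 0.55, ..., 0.95) that each prediction clears
    scores = []
    for t, p in zip(y_true, y_pred):
        inter = np.logical_and(t > 0, p > 0).sum()
        union = np.logical_or(t > 0, p > 0).sum()
        iou = inter / union if union > 0 else 1.0
        scores.append(np.mean([iou > th for th in np.arange(0.5, 1.0, 0.05)]))
    return np.mean(scores)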
Example 8
        voxel_size = dataLoader_val.voxel_size
        volume_size = dataLoader_val.volume_size
        D_val = dipole_kernel(volume_size, voxel_size, B0_dir)

        if flag_VI:
            weights_dict = torch.load(rootDir +
                                      '/weights_VI//weights_lambda_tv=20.pt')
            weights_dict['r'] = (torch.ones(1) * r).to(device)
        else:
            weights_dict = torch.load(
                rootDir +
                '/weight/weights_sigma={0}_smv={1}_mv8'.format(sigma, 1) +
                '.pt')
            weights_dict['r'] = (torch.ones(1) * r).to(device)
        unet3d.load_state_dict(weights_dict)

        # optimizer
        optimizer = optim.Adam(unet3d.parameters(), lr=lr, betas=(0.5, 0.999))

        epoch = 0
        loss_iters = np.zeros(niter)
        while epoch < niter:
            epoch += 1

            unet3d.train()
            for idx, (qsms, rdfs_input, rdfs, masks, weights, wGs,
                      D) in enumerate(trainLoader):

                qsms = (qsms.to(device, dtype=torch.float) + trans) * scale
                rdfs = (rdfs.to(device, dtype=torch.float) + trans) * scale
Example 9
        flag_train=0)
    testLoader = data.DataLoader(dataLoader_test, batch_size=1, shuffle=False)

    # network
    model = Unet(input_channels=1,
                 output_channels=2,
                 num_filters=[2**i for i in range(5, 10)],
                 bilateral=1,
                 use_deconv=1,
                 use_deconv2=1,
                 renorm=0,
                 flag_r_train=0,
                 flag_UTFI=1)
    model.to(device)
    model.load_state_dict(
        torch.load(
            rootDir +
            '/weights/weight_pdf={0}_tv={1}.pt'.format(lambda_pdf, lambda_tv)))
    model.eval()

    chi_bs, chi_ls = [], []
    with torch.no_grad():
        for idx, (ifreqs, masks, data_weights, wGs) in enumerate(testLoader):
            ifreqs = ifreqs.to(device, dtype=torch.float)
            masks = masks.to(device, dtype=torch.float)
            outputs = model(ifreqs)
            chi_b, chi_l = outputs[:, 0:1, ...] * (
                1 - masks), outputs[:, 1:2, ...] * masks
            chi_bs.append(chi_b[0, ...].cpu().detach())
            chi_ls.append(chi_l[0, ...].cpu().detach())

        chi_bs = np.concatenate(chi_bs, axis=0)
Example 10
                           'images',
                           None,
                           False,
                           0,
                           transform=transform_test)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=8)

    print('initializing model')
    net = Unet(1, 1, 16).cuda()
    net = nn.DataParallel(net)
    print('loading model weights')
    net.load_state_dict(torch.load('./ckpt/unet.pth'))
    net.eval()

    with open('output/result.csv', 'w') as f:
        f.write('id,rle_mask\n')
        for batch_image, batch_name in tqdm(dataloader):
            outputs = net(batch_image)
            outputs = F.softmax(outputs, dim=1)[:, 1, :, :]
            outputs = outputs > 0.50
            # pdb.set_trace()
            for k, v in zip(batch_name, outputs):
                # move the mask to host memory before converting to a numpy array
                run = rle_encode(v.cpu().numpy())
                f.write('{},{}\n'.format(k[:-4], run))
Example 11
        valLoader = data.DataLoader(dataLoader_val,
                                    batch_size=batch_size,
                                    shuffle=True)

        voxel_size = dataLoader_val.voxel_size
        volume_size = dataLoader_val.volume_size
        S = SMV_kernel(volume_size, voxel_size, radius=5)
        D = dipole_kernel(volume_size, voxel_size, B0_dir)
        D_val = np.real(S * D)

        weights_dict = torch.load(
            rootDir +
            '/weight/weights_sigma={0}_smv={1}_mv6'.format(sigma, 1) +
            '.pt')  # mv6 for plotting kl loss of validation
        weights_dict['r'] = (torch.ones(1) * r).to(device)
        unet3d.load_state_dict(weights_dict)

        # optimizer
        optimizer = optim.Adam(unet3d.parameters(), lr=lr, betas=(0.5, 0.999))

        epoch = 0
        loss_iters = np.zeros(niter)
        while epoch < niter:
            epoch += 1

            unet3d.train()
            # for idx, (rdfs, masks, weights, wGs, D) in enumerate(trainLoader):
            for idx, (rdfs_input, rdfs, masks, weights, wGs,
                      D) in enumerate(trainLoader):

                rdfs_input = (rdfs_input.to(device, dtype=torch.float) +