Exemplo n.º 1
0
def main():
    """Train the U-Net speech-enhancement network and save final weights.

    Reads hyper-parameters from module-level ``args`` (model_dir,
    batch_size, num_epochs) and project helpers (Unet, AudioDataset,
    stft, istft, wSDRLoss).  Requires a CUDA device.  Writes debug wavs
    every step and './final.pth.tar' when training finishes.
    """
    json_path = os.path.join(args.model_dir)
    params = utils.Params(json_path)

    net = Unet(params.model).cuda()
    # TODO - check exists
    #checkpoint = torch.load('./final.pth.tar')
    #net.load_state_dict(checkpoint)

    train_dataset = AudioDataset(data_type='train')
    test_dataset = AudioDataset(data_type='val')
    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_size=args.batch_size,
                                   collate_fn=train_dataset.collate,
                                   shuffle=True,
                                   num_workers=4)
    # NOTE(review): test_data_loader is constructed but never consumed below.
    test_data_loader = DataLoader(dataset=test_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=test_dataset.collate,
                                  shuffle=False,
                                  num_workers=4)

    torch.set_printoptions(precision=10, profile="full")

    # Optimizer
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    # Learning rate scheduler
    scheduler = ExponentialLR(optimizer, 0.95)

    for epoch in range(args.num_epochs):
        train_bar = tqdm(train_data_loader)
        # renamed from `input` so the builtin is not shadowed
        for batch in train_bar:
            train_mixed, train_clean, seq_len = map(lambda x: x.cuda(), batch)
            # last dim of stft output is split into real/imag below
            mixed = stft(train_mixed).unsqueeze(dim=1)
            real, imag = mixed[..., 0], mixed[..., 1]
            out_real, out_imag = net(real, imag)
            out_real, out_imag = torch.squeeze(out_real,
                                               1), torch.squeeze(out_imag, 1)
            out_audio = istft(out_real, out_imag, train_mixed.size(1))
            out_audio = torch.squeeze(out_audio, dim=1)
            # zero the tail past each utterance's true length
            for i, l in enumerate(seq_len):
                out_audio[i, l:] = 0
            # Debug dumps of the first batch item (written every step; slow).
            librosa.output.write_wav(
                'mixed.wav', train_mixed[0].cpu().data.numpy()
                [:seq_len[0].cpu().data.numpy()], 16000)
            librosa.output.write_wav(
                'clean.wav', train_clean[0].cpu().data.numpy()
                [:seq_len[0].cpu().data.numpy()], 16000)
            librosa.output.write_wav(
                'out.wav', out_audio[0].cpu().data.numpy()
                [:seq_len[0].cpu().data.numpy()], 16000)
            loss = wSDRLoss(train_mixed, train_clean, out_audio)
            print(epoch, loss)
            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
        scheduler.step()
    torch.save(net.state_dict(), './final.pth.tar')
Exemplo n.º 2
0
def main():
    """Train the U-Net enhancement network, checkpointing every 20 epochs.

    Reads the config path from ``args.conf``; requires a CUDA device.
    Checkpoints are written to './ckpt/step%05d.pth.tar'.
    """
    json_path = os.path.join(args.conf)
    params = utils.Params(json_path)

    net = Unet(params.model).cuda()
    # TODO - check exists
    # if os.path.exists('./ckpt/final.pth.tar'):
    #     checkpoint = torch.load('./ckpt/final.pth.tar')
    #     net.load_state_dict(checkpoint)

    train_dataset = AudioDataset(data_type='train')
    # test_dataset = AudioDataset(data_type='val')
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                                   collate_fn=train_dataset.collate, shuffle=True, num_workers=0)
    # test_data_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size,
    #                               collate_fn=test_dataset.collate, shuffle=False, num_workers=4)

    # torch.set_printoptions(precision=10, profile="full")

    # Optimizer
    optimizer = optim.Adam(net.parameters(), lr=1e-2)
    # Learning rate scheduler
    scheduler = ExponentialLR(optimizer, 0.996)

    if not os.path.exists('ckpt'):  # model save dir
        os.mkdir('ckpt')

    for epoch in range(1, args.num_epochs + 1):
        train_bar = tqdm(train_data_loader, ncols=60)
        loss_sum = 0.0
        step_cnt = 0
        for input_ in train_bar:
            train_mixed, train_clean, seq_len = map(lambda x: x.cuda(), input_)
            # last dim of stft output is split into real/imag below
            mixed = stft(train_mixed).unsqueeze(dim=1)
            real, imag = mixed[..., 0], mixed[..., 1]
            out_real, out_imag = net(real, imag)
            out_real, out_imag = torch.squeeze(out_real, 1), torch.squeeze(out_imag, 1)
            out_audio = istft(out_real, out_imag, train_mixed.size(1))
            out_audio = torch.squeeze(out_audio, dim=1)
            # zero the tail past each utterance's true length
            for i, l in enumerate(seq_len):
                out_audio[i, l:] = 0
            # librosa.output.write_wav('mixed.wav', train_mixed[0].cpu().data.numpy()[:seq_len[0].cpu().data.numpy()], 16000)
            # librosa.output.write_wav('clean.wav', train_clean[0].cpu().data.numpy()[:seq_len[0].cpu().data.numpy()], 16000)
            # librosa.output.write_wav('out.wav', out_audio[0].cpu().data.numpy()[:seq_len[0].cpu().data.numpy()], 16000)
            loss = wSDRLoss(train_mixed, train_clean, out_audio)
            # print(epoch, loss.item(), end='', flush=True)
            loss_sum += loss.item()
            step_cnt += 1
            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

        # guard against an empty loader (would otherwise divide by zero)
        avg_loss = loss_sum / max(step_cnt, 1)
        print('epoch %d> Avg_loss: %.6f.\n' % (epoch, avg_loss))
        scheduler.step()
        if epoch % 20 == 0:
            torch.save(net.state_dict(), './ckpt/step%05d.pth.tar' % epoch)
Exemplo n.º 3
0
def test(args):
    """Predict masks for the test split and write an RLE submission CSV.

    Loads the model named by ``args.model`` from config.json, runs it
    over the test set and writes './results/<model>_submission.csv'.
    """
    # Setup Dataloader
    data_json = json.load(open('config.json'))
    data_path = data_json[args.dataset]['data_path']

    t_loader = SaltLoader(data_path, split="test")
    test_df = t_loader.test_df
    test_loader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=8)

    # load Model
    if args.arch == 'unet':
        model = Unet(start_fm=16)
    else:
        model = Unet_upsample(start_fm=16)
    model_path = data_json[args.model]['model_path']
    model.load_state_dict(torch.load(model_path)['model_state'])
    model.cuda()
    model.eval()  # inference mode (dropout / batchnorm use running stats)
    total = sum(param.nelement() for param in model.parameters())
    print('Number of params: %.2fM' % (total / 1e6))

    # test
    pred_list = []
    with torch.no_grad():  # no autograd bookkeeping needed at inference
        for images in test_loader:
            images = Variable(images.cuda())
            y_preds = model(images)
            y_preds_shaped = y_preds.reshape(-1, args.img_size_target, args.img_size_target)
            # iterate the actual batch size: the final batch may be smaller
            # than args.batch_size, which previously raised IndexError
            for idx in range(y_preds_shaped.size(0)):
                y_pred = y_preds_shaped[idx]
                pred = torch.sigmoid(y_pred)
                pred = pred.cpu().data.numpy()
                pred_ori = resize(pred, (args.img_size_ori, args.img_size_ori), mode='constant', preserve_range=True)
                pred_list.append(pred_ori)

    # submit the test image predictions.
    threshold_best = args.threshold
    pred_dict = {idx: RLenc(np.round(pred_list[i] > threshold_best)) for i, idx in
                 enumerate(tqdm_notebook(test_df.index.values))}
    sub = pd.DataFrame.from_dict(pred_dict, orient='index')
    sub.index.names = ['id']
    sub.columns = ['rle_mask']
    sub.to_csv('./results/{}_submission.csv'.format(args.model))
    print("The submission.csv saved in ./results")
Exemplo n.º 4
0
def test(args):
    """Feed a random batch through the selected model and render its graph."""
    # Setup Data
    data_json = json.load(open('config.json'))
    x = Variable(torch.randn(32, 1, 128, 128)).cuda()

    # load Model
    if args.arch == 'unet':
        model = Unet(start_fm=16)
    else:
        model = Unet_upsample(start_fm=16)
    weights = torch.load(data_json[args.model]['model_path'])
    model.load_state_dict(weights['model_state'])
    model.cuda()
    n_params = sum(p.nelement() for p in model.parameters())
    print('Number of params: %.2fM' % (n_params / 1e6))

    # visualize
    g = make_dot(model(x))
    g.render('k')
Exemplo n.º 5
0
    weights_dict = torch.load(rootDir +
                              '/linear_factor=1_validation=6_test=7_unet3d.pt')
    unet3d.load_state_dict(weights_dict)

    resnet = ResBlock(
        input_dim=2,
        filter_dim=32,
        output_dim=1,
    )
    resnet.to(device1)
    weights_dict = torch.load(rootDir +
                              '/linear_factor=1_validation=6_test=7_resnet.pt')
    resnet.load_state_dict(weights_dict)

    # optimizer
    optimizer0 = optim.Adam(unet3d.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer1 = optim.Adam(resnet.parameters(), lr=lr, betas=(0.5, 0.999))

    # dataloaders
    # dataLoader_train_MS = Patient_data_loader_all(patientType='MS_old', flag_RDF_input=1)
    # trainLoader_MS = data.DataLoader(dataLoader_train_MS, batch_size=batch_size, shuffle=True)

    dataLoader_train_ICH = Patient_data_loader_all(patientType='ICH',
                                                   flag_RDF_input=1)
    trainLoader_ICH = data.DataLoader(dataLoader_train_ICH,
                                      batch_size=batch_size,
                                      shuffle=True)

    epoch = 0
    gen_iterations = 1
    display_iters = 5
Exemplo n.º 6
0
                   'w': 'he_normal',
                   'b': 'normal'
               })
    criterion = nn.BCELoss(size_average=False)
else:
    print("Choose a model.\nAborting....")
    sys.exit()

if not os.path.exists('logs'):
    os.mkdir('logs')

logFile = 'logs/' + fName + '.txt'
makeLogFile(logFile)

net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-5)
nTrain = len(train_loader)
nValid = len(valid_loader)
nTest = len(test_loader)

minLoss = 1e8

convIter = 0
convCheck = 20

for epoch in range(args.epochs):
    trLoss = []
    vlLoss = []
    vlGed = [0]
    klEpoch = [0]
    recEpoch = [0]
Exemplo n.º 7
0
        output_channels=1,
        num_filters=[2**i for i in range(5, 10)],  # or range(3, 8)
        use_deconv=1,
        flag_rsa=0)

    resnet = ResBlock(
        input_dim=2,
        filter_dim=32,
        output_dim=1,
    )

    unet3d.to(device)
    resnet.to(device)

    # optimizer
    optimizer = optim.Adam(list(unet3d.parameters()) +
                           list(resnet.parameters()),
                           lr=lr,
                           betas=(0.5, 0.999))
    ms = [0.3, 0.5, 0.7, 0.9]
    ms = [np.floor(m * niter).astype(int) for m in ms]
    scheduler = MultiStepLR(optimizer, milestones=ms, gamma=0.5)

    # logger
    logger = Logger('logs', rootDir, opt['linear_factor'],
                    opt['case_validation'], opt['case_test'])

    # dataloader
    dataLoader_train = COSMOS_data_loader(
        split='Train',
        patchSize=patchSize,
Exemplo n.º 8
0
def validate(args):
    """Run the model on the validation split and search for the best
    binarisation threshold under two IoU-based scoring schemes.

    Prints the per-threshold metric and the best threshold found by
    each scheme; returns nothing.
    """
    # Setup Dataloader
    data_json = json.load(open('config.json'))
    data_path = data_json[args.dataset]['data_path']

    v_loader = SaltLoader(data_path, split='val')
    train_df = v_loader.train_df

    val_loader = data.DataLoader(v_loader,
                                 batch_size=args.batch_size,
                                 num_workers=8)

    # load Model
    if args.arch == 'unet':
        model = Unet(start_fm=16)
    else:
        model = Unet_upsample(start_fm=16)
    model_path = data_json[args.model]['model_path']
    model.load_state_dict(torch.load(model_path)['model_state'])
    model.cuda()
    model.eval()  # inference mode (dropout / batchnorm use running stats)
    total = sum(param.nelement() for param in model.parameters())
    print('Number of params: %.2fM' % (total / 1e6))

    # validate
    pred_list = []
    with torch.no_grad():  # no gradients needed for evaluation
        for images, masks in val_loader:
            images = Variable(images.cuda())
            y_preds = model(images)
            # print(y_preds.shape)
            y_preds_shaped = y_preds.reshape(-1, args.img_size_target,
                                             args.img_size_target)
            # iterate the actual batch size: the final batch may be smaller
            # than args.batch_size, which previously raised IndexError
            for idx in range(y_preds_shaped.size(0)):
                y_pred = y_preds_shaped[idx]
                pred = torch.sigmoid(y_pred)
                pred = pred.cpu().data.numpy()
                pred_ori = resize(pred, (args.img_size_ori, args.img_size_ori),
                                  mode='constant',
                                  preserve_range=True)
                pred_list.append(pred_ori)

    preds_valid = np.array(pred_list)
    y_valid_ori = np.array(
        [train_df.loc[idx].masks for idx in v_loader.ids_valid])

    # jaccard score
    accuracies_best = 0.0
    threshold_best = 0.0  # was unbound when no threshold beat 0.0
    for threshold in np.linspace(0, 1, 11):
        ious = []
        for y_pred, mask in zip(preds_valid, y_valid_ori):
            prediction = (y_pred > threshold).astype(int)
            iou = jaccard_similarity_score(mask.flatten(),
                                           prediction.flatten())
            ious.append(iou)

        # compare as an ndarray: `list > float` raises TypeError in Python 3
        ious = np.asarray(ious)
        accuracies = [
            np.mean(ious > iou_threshold)
            for iou_threshold in np.linspace(0.5, 0.95, 10)
        ]
        if accuracies_best < np.mean(accuracies):
            accuracies_best = np.mean(accuracies)
            threshold_best = threshold
        print('Threshold: %.1f, Metric: %.3f' %
              (threshold, np.mean(accuracies)))
    print("jaccard score gets threshold_best=", threshold_best)

    # other score way
    thresholds = np.linspace(0, 1, 50)
    ious = np.array([
        iou_metric_batch(y_valid_ori, np.int32(preds_valid > threshold))
        for threshold in tqdm_notebook(thresholds)
    ])
    # restrict the argmax to the interior thresholds, skipping the
    # degenerate extremes near 0 and 1
    threshold_best_index = np.argmax(ious[9:-10]) + 9
    iou_best = ious[threshold_best_index]
    threshold_best = thresholds[threshold_best_index]
    print("other way gets iou_best=", iou_best, "threshold_best=",
          threshold_best)
Exemplo n.º 9
0
def train(args):
    """Train the segmentation model, tracking the best validation loss.

    Saves the best checkpoint (lowest mean validation loss) and the
    final checkpoint under ./saved_models.
    """
    import os  # local: ensure the save directory can be created below

    # Setup Dataloader
    data_json = json.load(open('config.json'))
    data_path = data_json[args.dataset]['data_path']
    t_loader = SaltLoader(data_path, img_size_ori=args.img_size_ori, img_size_target=args.img_size_target)
    v_loader = SaltLoader(data_path, split='val', img_size_ori=args.img_size_ori, img_size_target=args.img_size_target)

    train_loader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=8, shuffle=True)
    val_loader = data.DataLoader(v_loader, batch_size=args.batch_size, num_workers=8)

    # Setup Model
    if args.arch == 'unet':
        model = Unet(start_fm=16)
    else:
        model = Unet_upsample(start_fm=16)
    print(model)
    total = sum(param.nelement() for param in model.parameters())
    print('Number of params: %.2fM' % (total / 1e6))

    model.cuda()

    # Check if model has custom optimizer / loss
    optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate)
    loss_fn = nn.BCEWithLogitsLoss()

    # create the checkpoint directory before the first torch.save
    os.makedirs('./saved_models', exist_ok=True)

    best_loss = 100
    mean_train_losses = []
    mean_val_losses = []
    for epoch in range(args.n_epoch):
        train_losses = []
        val_losses = []
        model.train()
        for images, masks in train_loader:
            images = Variable(images.cuda())
            masks = Variable(masks.cuda())

            outputs = model(images)

            loss = loss_fn(outputs, masks)
            # .item() yields a Python float; np.mean over the previous
            # loss.data (CUDA tensors) fails at the epoch summary
            train_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()  # validation in inference mode
        with torch.no_grad():  # no gradients needed for validation
            for images, masks in val_loader:
                images = Variable(images.cuda())
                masks = Variable(masks.cuda())

                outputs = model(images)
                loss = loss_fn(outputs, masks)
                val_losses.append(loss.item())

        mean_train_losses.append(np.mean(train_losses))
        mean_val_losses.append(np.mean(val_losses))
        if np.mean(val_losses) < best_loss:
            best_loss = np.mean(val_losses)
            state = {'epoch': epoch + 1,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict(), }
            torch.save(state, "./saved_models/{}_{}_best_model.pkl".format(args.arch, args.dataset))

        # Print Loss
        print('Epoch: {}. Train Loss: {}. Val Loss: {}'.format(epoch + 1, np.mean(train_losses), np.mean(val_losses)))

    state = {'model_state': model.state_dict(),
             'optimizer_state': optimizer.state_dict(), }
    torch.save(state, "./saved_models/{}_{}_final_model.pkl".format(args.arch, args.dataset))

    print("saved two models in ./saved_models")
Exemplo n.º 10
0
    D_smv = np.real(S * D)

    # network
    model = Unet(input_channels=1,
                 output_channels=2,
                 num_filters=[2**i for i in range(5, 10)],
                 bilateral=1,
                 use_deconv=1,
                 use_deconv2=1,
                 renorm=0,
                 flag_r_train=0,
                 flag_UTFI=1)
    model.to(device)

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))

    Validation_loss = []

    loss_total_list = []
    while epoch < niter:
        # training phase
        model.train()
        for idx, (ifreqs, masks, data_weights, wGs) in enumerate(trainLoader):
            ifreqs = ifreqs.to(device, dtype=torch.float)
            masks = masks.to(device, dtype=torch.float)
            data_weights = data_weights.to(device, dtype=torch.float)
            wGs = wGs.to(device, dtype=torch.float)

            loss_PDF, loss_fidelity, loss_tv = utfi_train(model,
                                                          optimizer,