Ejemplo n.º 1
0
def main():
    """Train the Unet speech denoiser on complex STFT inputs and save weights.

    Relies on module-level names defined elsewhere in the file: args, utils,
    Unet, AudioDataset, DataLoader, stft, istft, wSDRLoss, optim,
    ExponentialLR, tqdm, librosa, torch, os.

    Side effects: writes mixed/clean/out debug wavs every training step and
    the final weights to ./final.pth.tar.
    """
    json_path = os.path.join(args.model_dir)
    params = utils.Params(json_path)

    net = Unet(params.model).cuda()
    # TODO - check exists
    #checkpoint = torch.load('./final.pth.tar')
    #net.load_state_dict(checkpoint)

    train_dataset = AudioDataset(data_type='train')
    test_dataset = AudioDataset(data_type='val')
    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_size=args.batch_size,
                                   collate_fn=train_dataset.collate,
                                   shuffle=True,
                                   num_workers=4)
    # NOTE(review): built but never iterated in this function.
    test_data_loader = DataLoader(dataset=test_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=test_dataset.collate,
                                  shuffle=False,
                                  num_workers=4)

    torch.set_printoptions(precision=10, profile="full")

    # Optimizer
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    # Learning rate scheduler
    scheduler = ExponentialLR(optimizer, 0.95)

    for epoch in range(args.num_epochs):
        train_bar = tqdm(train_data_loader)
        # Renamed from `input`, which shadowed the builtin.
        for batch in train_bar:
            train_mixed, train_clean, seq_len = map(lambda x: x.cuda(), batch)
            # STFT -> (real, imag) channel pair fed to the network.
            mixed = stft(train_mixed).unsqueeze(dim=1)
            real, imag = mixed[..., 0], mixed[..., 1]
            out_real, out_imag = net(real, imag)
            out_real = torch.squeeze(out_real, 1)
            out_imag = torch.squeeze(out_imag, 1)
            out_audio = istft(out_real, out_imag, train_mixed.size(1))
            out_audio = torch.squeeze(out_audio, dim=1)
            # Zero the padded tail of each utterance beyond its true length.
            for i, l in enumerate(seq_len):
                out_audio[i, l:] = 0
            # Hoisted: first sample's true length, previously recomputed
            # three times per step.
            first_len = seq_len[0].cpu().data.numpy()
            librosa.output.write_wav(
                'mixed.wav', train_mixed[0].cpu().data.numpy()[:first_len],
                16000)
            librosa.output.write_wav(
                'clean.wav', train_clean[0].cpu().data.numpy()[:first_len],
                16000)
            librosa.output.write_wav(
                'out.wav', out_audio[0].cpu().data.numpy()[:first_len],
                16000)
            loss = wSDRLoss(train_mixed, train_clean, out_audio)
            print(epoch, loss)
            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
        scheduler.step()
    torch.save(net.state_dict(), './final.pth.tar')
Ejemplo n.º 2
0
def main():
    """Train the Unet enhancer, logging average loss per epoch.

    Uses module-level names defined elsewhere in the file: args, utils, Unet,
    AudioDataset, DataLoader, stft, istft, wSDRLoss, optim, ExponentialLR,
    tqdm, torch, os. Saves a checkpoint to ./ckpt every 20 epochs.
    """
    config_path = os.path.join(args.conf)
    params = utils.Params(config_path)

    net = Unet(params.model).cuda()
    # TODO - check exists
    # if os.path.exists('./ckpt/final.pth.tar'):
    #     checkpoint = torch.load('./ckpt/final.pth.tar')
    #     net.load_state_dict(checkpoint)

    dataset = AudioDataset(data_type='train')
    loader = DataLoader(dataset=dataset,
                        batch_size=args.batch_size,
                        collate_fn=dataset.collate,
                        shuffle=True,
                        num_workers=0)

    # Optimizer with an exponentially decaying learning-rate schedule.
    optimizer = optim.Adam(net.parameters(), lr=1e-2)
    scheduler = ExponentialLR(optimizer, 0.996)

    if not os.path.exists('ckpt'):  # model save dir
        os.mkdir('ckpt')

    for epoch in range(1, args.num_epochs + 1):
        progress = tqdm(loader, ncols=60)
        running_loss, n_batches = 0.0, 0
        for batch in progress:
            noisy, clean, seq_len = (t.cuda() for t in batch)
            # STFT -> separate real/imag channels for the network.
            spec = stft(noisy).unsqueeze(dim=1)
            real, imag = spec[..., 0], spec[..., 1]
            est_real, est_imag = net(real, imag)
            est_real = torch.squeeze(est_real, 1)
            est_imag = torch.squeeze(est_imag, 1)
            # Back to the time domain, then silence each clip's padded tail.
            enhanced = istft(est_real, est_imag, noisy.size(1))
            enhanced = torch.squeeze(enhanced, dim=1)
            for i, valid_len in enumerate(seq_len):
                enhanced[i, valid_len:] = 0
            loss = wSDRLoss(noisy, clean, enhanced)
            running_loss += loss.item()
            n_batches += 1
            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

        avg_loss = running_loss / n_batches
        print('epoch %d> Avg_loss: %.6f.\n' % (epoch, avg_loss))
        scheduler.step()
        if epoch % 20 == 0:
            torch.save(net.state_dict(), './ckpt/step%05d.pth.tar' % epoch)
Ejemplo n.º 3
0
def unet_tf_pth(checkpoint_path, pth_output_path):
    """Compare a TF checkpoint's variable names to a fresh Unet state dict.

    NOTE(review): despite the conversion-style name, the visible body only
    prints both key lists; `pth_output_path` is never used here, so the
    actual TF->PyTorch weight copy appears to be truncated or missing —
    TODO confirm against the full file.
    """
    model = Unet().eval()
    state_dict = model.state_dict()

    # Read the TensorFlow checkpoint's variable-name -> shape map.
    reader = tf.train.NewCheckpointReader(checkpoint_path)

    pth_keys = state_dict.keys()
    keys = sorted(reader.get_variable_to_shape_map().keys())
    print(keys)
    print(pth_keys)
Ejemplo n.º 4
0
            # training of resnet
            # NOTE(review): rdf_inputs, qsm_inputs1, rdfs, masks, weights,
            # wGs, D, device1, epoch, niter, t0, optimizer1, resnet, unet3d
            # and rootDir are defined above this excerpt — roles assumed
            # from names; TODO confirm against the full file.
            rdf_inputs1 = rdf_inputs.to(device1, dtype=torch.float)
            # Concatenate RDF and QSM inputs along the channel dimension.
            inputs_cat = torch.cat((rdf_inputs1, qsm_inputs1), dim=1)
            rdfs1 = rdfs.to(device1, dtype=torch.float)
            masks1 = masks.to(device1, dtype=torch.float)
            weights1 = weights.to(device1, dtype=torch.float)
            wGs1 = wGs.to(device1, dtype=torch.float)

            # Delegates the training step to the project helper (an optimizer
            # is passed in); the returned value is used only for logging —
            # presumably the fidelity loss. TODO confirm.
            loss_fidelity = BayesianQSM_train(model=resnet,
                                              input_RDFs=inputs_cat,
                                              in_loss_RDFs=rdfs1,
                                              QSMs=0,
                                              Masks=masks1,
                                              fidelity_Ws=weights1,
                                              gradient_Ws=wGs1,
                                              D=np.asarray(D[0, ...]),
                                              flag_COSMOS=0,
                                              optimizer=optimizer1,
                                              sigma_sq=0,
                                              Lambda_tv=0,
                                              voxel_size=(1, 1, 3),
                                              K=1,
                                              flag_l1=2)
            print('epochs: [%d/%d], time: %ds, Fidelity loss of resnet: %f' %
                  (epoch, niter, time.time() - t0, loss_fidelity))
            print(' ')

        # After all epochs: persist both networks' weights.
        torch.save(unet3d.state_dict(), rootDir + '/unet3d_fine.pt')
        torch.save(resnet.state_dict(), rootDir + '/resnet_fine.pt')
Ejemplo n.º 5
0
                                l2_regularisation(net.prior)
                        # ELBO objective plus a tiny L2 regularisation term.
                        loss = -elbo + 1e-5 * reg_loss
                        klEpoch.append(kl.item())
                        recEpoch.append(recLoss.item())
                    else:
                        # Deterministic path: plain forward pass + sigmoid.
                        pred = torch.sigmoid(net.forward(patch, False))
                        loss = criterion(target=mask, input=pred)
                    vlLoss.append(loss.item())
                    # NOTE(review): breaks after the first validation batch —
                    # presumably a deliberate quick-validation shortcut;
                    # TODO confirm.
                    break
                print ('Epoch [{}/{}], Step [{}/{}], TrLoss: {:.4f}, VlLoss: {:.4f}, RecLoss: {:.4f}, kl: {:.4f}, GED: {:.4f}'
                    .format(epoch+1, args.epochs, step+1, nTrain, trLoss[-1], vlLoss[-1], recLoss.item(),\
                            kl.item(), vlGed[-1]))
    epValidLoss = np.mean(vlLoss)
    # Checkpoint whenever validation loss reaches a new positive minimum.
    if (epoch + 1) % 1 == 0 and epValidLoss > 0 and epValidLoss < minLoss:
        convIter = 0
        minLoss = epValidLoss
        print("New min: %.2f\nSaving model..." % (minLoss))
        torch.save(net.state_dict(), '../models/' + fName + '.pt')
    else:
        convIter += 1
    writeLog(logFile, epoch, np.mean(trLoss), epValidLoss, np.mean(recEpoch),
             np.mean(klEpoch),
             time.time() - t)

    # Early stopping: no improvement for convCheck consecutive epochs.
    if convIter == convCheck:
        print("Converged at epoch %d" % (epoch + 1 - convCheck))
        break
    elif np.isnan(epValidLoss):
        print("Nan error!")
        break
Ejemplo n.º 6
0
            for idx, (rdfs, masks, qsms) in enumerate(valLoader):
                # Make idx 1-based so the average below divides by #batches.
                idx += 1
                # Normalise with trans/scale defined above this excerpt.
                rdfs = (rdfs.to(device, dtype=torch.float) + trans) * scale
                qsms = (qsms.to(device, dtype=torch.float) + trans) * scale
                masks = masks.to(device, dtype=torch.float)

                # Two-stage prediction: resnet refines on the concatenation
                # of the input and the first-stage unet3d output.
                outputs1 = unet3d(rdfs)
                outputs2 = resnet(torch.cat((rdfs, outputs1), 1))
                loss1 = loss_QSMnet(outputs1, qsms, masks, D)
                loss2 = loss_QSMnet(outputs2, qsms, masks, D)
                loss = loss1 + loss2
                # NOTE(review): accumulates a tensor (no .item()), unlike the
                # analogous loop further down — works, but keeps graph/refs
                # alive. TODO confirm intended.
                loss_total += loss

            print('\n Validation loss: %f \n' % (loss_total / idx))
            Validation_loss.append(loss_total / idx)

        logger.print_and_save('Epoch: [%d/%d], Loss in Validation: %f' %
                              (epoch, niter, Validation_loss[-1]))

        # Save both networks when the newest validation loss is the minimum.
        if Validation_loss[-1] == min(Validation_loss):
            torch.save(
                unet3d.state_dict(), rootDir +
                '/linear_factor={0}_validation={1}_test={2}_unet3d'.format(
                    opt['linear_factor'], opt['case_validation'],
                    opt['case_test']) + '.pt')
            torch.save(
                resnet.state_dict(), rootDir +
                '/linear_factor={0}_validation={1}_test={2}_resnet'.format(
                    opt['linear_factor'], opt['case_validation'],
                    opt['case_test']) + '.pt')
            # NOTE(review): a second validation loop nested under the save
            # branch — unusual placement; loss_total is presumably reset
            # somewhere above this excerpt. TODO confirm.
            for idx, (rdfs, masks, weights, qsms) in enumerate(valLoader):
                idx += 1
                rdfs = (rdfs.to(device, dtype=torch.float) + trans) * scale
                qsms = (qsms.to(device, dtype=torch.float) + trans) * scale
                masks = masks.to(device, dtype=torch.float)
                outputs = unet3d(rdfs)

                # Probabilistic head: channel 0 = mean, channel 1 = variance.
                mean_Maps = outputs[:, 0:1, ...]
                var_Maps = outputs[:, 1:2, ...]

                # loss1 = 1/2*torch.sum((mean_Maps - qsms)**2 / torch.exp(var_Maps))
                # loss2 = 1/2*torch.sum(var_Maps)
                # Masked squared error weighted by 1/variance, plus a
                # log-variance penalty.
                loss1 = 1 / 2 * torch.sum(
                    (mean_Maps * masks - qsms * masks)**2 / var_Maps)
                loss2 = 1 / 2 * torch.sum(torch.log(var_Maps) * masks)
                loss = loss1 + loss2
                loss_total += loss.item()

            print('\n Validation loss: %f \n' % (loss_total / idx))
            Validation_loss.append(loss_total / idx)

        logger.print_and_save('Epoch: [%d/%d], loss in Validation: %f' %
                              (epoch, niter, Validation_loss[-1]))

        if Validation_loss[-1] == min(Validation_loss):
            torch.save(
                unet3d.state_dict(),
                rootDir + '/weights_rsa={0}_validation={1}_test={2}'.format(
                    opt['flag_rsa'], opt['case_validation'],
                    opt['case_test']) + '.pt')
Ejemplo n.º 8
0
            #     flag_COSMOS=flag_COSMOS,
            #     optimizer=optimizer,
            #     sigma_sq=sigma_sq,
            #     Lambda_tv=Lambda_tv,
            #     voxel_size=voxel_size
            # )

            # One training step delegated to the project helper; errl1 is
            # presumably the L1/fidelity error it returns, used here only
            # for logging — TODO confirm against BayesianQSM_train.
            errl1 = BayesianQSM_train(
                unet3d=unet3d,
                input_RDFs=input_RDFs,
                in_loss_RDFs=in_loss_RDFs,
                QSMs=QSMs,
                Masks=Masks,
                fidelity_Ws=fidelity_Ws,
                gradient_Ws=gradient_Ws,
                D=D,
                flag_COSMOS=flag_COSMOS,
                optimizer=optimizer,
                sigma_sq=sigma_sq,
                Lambda_tv=Lambda_tv,
                voxel_size=voxel_size
            )

            print(errl1)
            print('\n')

        print('Finish current epoch')
        # Persist weights at the end of every epoch (overwrites one file).
        torch.save(unet3d.state_dict(), rootDir+'/weights.pt')


Ejemplo n.º 9
0
                    else:
                        # Probabilistic head: channel 0 = mean, channel 1 = variance.
                        outputs = unet3d(rdfs_input)
                        mean_Maps = outputs[:, 0:1, ...]
                        var_Maps = outputs[:, 1:2, ...]

                        # Masked squared error weighted by 1/variance, plus a
                        # log-variance penalty.
                        loss1 = 1 / 2 * torch.sum(
                            (mean_Maps * masks - qsms * masks)**2 / var_Maps)
                        loss2 = 1 / 2 * torch.sum(torch.log(var_Maps) * masks)
                        loss = loss1 + loss2
                        loss_total += loss.item()

            val_loss.append(loss_total)
            # Checkpoint when this epoch sets a new minimum validation loss.
            if val_loss[-1] == min(val_loss):
                if flag_VI:
                    # NOTE(review): the filename string has no '{}' placeholder,
                    # so .format(Lambda_tv) is a no-op — likely a leftover.
                    torch.save(
                        unet3d.state_dict(), rootDir + folder_weights +
                        '/weights_PDI_VI0.pt'.format(Lambda_tv))
                else:
                    torch.save(unet3d.state_dict(),
                               rootDir + folder_weights + '/weights_PDI.pt')

    # test phase
    else:
        # dataloader
        dataLoader_test = Simulation_ICH_loader(split='test')
        # NOTE(review): shuffle=True on a test loader is unusual — confirm.
        testLoader = data.DataLoader(dataLoader_test,
                                     batch_size=batch_size,
                                     shuffle=True)

        if flag_VI:
            # unet3d.load_state_dict(torch.load(rootDir+folder_weights+'/weights_PDI_VI0.pt')) # used for mean prediction
Ejemplo n.º 10
0
                wGs = wGs.to(device, dtype=torch.float)

                # calculate KLD
                outputs = unet3d(rdfs)
                loss_kl = loss_KL(outputs=outputs,
                                  QSMs=0,
                                  flag_COSMOS=0,
                                  sigma_sq=0)
                # Expectation and total-variation terms from the project helper.
                loss_expectation, loss_tv = loss_Expectation(
                    outputs=outputs,
                    QSMs=0,
                    in_loss_RDFs=rdfs - trans * scale,  # presumably undoes the normalisation shift — TODO confirm
                    fidelity_Ws=weights,
                    gradient_Ws=wGs,
                    D=np.asarray(D[0, ...]),
                    flag_COSMOS=0,
                    Lambda_tv=Lambda_tv,
                    voxel_size=voxel_size,
                    K=K)
                loss_total += (loss_kl + loss_expectation + loss_tv).item()
            print('KL Divergence on validation set = {0}'.format(loss_total))

        val_loss.append(loss_total)
        # "Best" weights only on a new minimum (and only when Lambda_tv is
        # truthy); the "last" checkpoint below is written every epoch.
        if val_loss[-1] == min(val_loss):
            if Lambda_tv:
                torch.save(
                    unet3d.state_dict(), rootDir + folder_weights_VI +
                    '/weights_vi_cosmos_{}_9.pt'.format(Lambda_tv))
        torch.save(
            unet3d.state_dict(), rootDir + folder_weights_VI +
            '/weights_vi_cosmos_{}_last_9.pt'.format(Lambda_tv))
Ejemplo n.º 11
0
def train(args):
    """Train a Unet (or upsampling Unet) on the salt-identification dataset.

    Saves the best model (lowest mean validation loss) and the final model
    under ./saved_models. Relies on module-level names: json, SaltLoader,
    data, Unet, Unet_upsample, nn, np, torch.
    """
    # Setup Dataloader
    # `with` ensures the config file handle is closed (was json.load(open(...))).
    with open('config.json') as f:
        data_json = json.load(f)
    data_path = data_json[args.dataset]['data_path']
    t_loader = SaltLoader(data_path, img_size_ori=args.img_size_ori,
                          img_size_target=args.img_size_target)
    v_loader = SaltLoader(data_path, split='val',
                          img_size_ori=args.img_size_ori,
                          img_size_target=args.img_size_target)

    train_loader = data.DataLoader(t_loader, batch_size=args.batch_size,
                                   num_workers=8, shuffle=True)
    val_loader = data.DataLoader(v_loader, batch_size=args.batch_size,
                                 num_workers=8)

    # Setup Model
    if args.arch == 'unet':
        model = Unet(start_fm=16)
    else:
        model = Unet_upsample(start_fm=16)
    print(model)
    total = sum(param.nelement() for param in model.parameters())
    print('Number of params: %.2fM' % (total / 1e6))

    model.cuda()

    # Check if model has custom optimizer / loss
    optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate)
    loss_fn = nn.BCEWithLogitsLoss()

    # inf sentinel: guarantees the first epoch always saves a "best" model
    # (the old hard-coded 100 could silently skip saving).
    best_loss = float('inf')
    mean_train_losses = []
    mean_val_losses = []
    for epoch in range(args.n_epoch):
        train_losses = []
        val_losses = []
        for images, masks in train_loader:
            # Variable() is a deprecated no-op wrapper; tensors autograd directly.
            images = images.cuda()
            masks = masks.cuda()

            outputs = model(images)

            loss = loss_fn(outputs, masks)
            # .item() detaches to a plain float (.data kept a live tensor).
            train_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation: no autograd graph needed; loss values are unchanged.
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.cuda()
                masks = masks.cuda()

                outputs = model(images)
                loss = loss_fn(outputs, masks)
                val_losses.append(loss.item())

        # Hoisted: each mean was previously recomputed up to three times.
        mean_train = float(np.mean(train_losses))
        mean_val = float(np.mean(val_losses))
        mean_train_losses.append(mean_train)
        mean_val_losses.append(mean_val)
        if mean_val < best_loss:
            best_loss = mean_val
            state = {'epoch': epoch + 1,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict(), }
            torch.save(state, "./saved_models/{}_{}_best_model.pkl".format(args.arch, args.dataset))

        # Print Loss
        print('Epoch: {}. Train Loss: {}. Val Loss: {}'.format(epoch + 1, mean_train, mean_val))

    state = {'model_state': model.state_dict(),
             'optimizer_state': optimizer.state_dict(), }
    torch.save(state, "./saved_models/{}_{}_final_model.pkl".format(args.arch, args.dataset))

    print("saved two models in ./saved_models")
Ejemplo n.º 12
0
            for idx, (ifreqs, masks, data_weights,
                      wGs) in enumerate(valLoader):
                ifreqs = ifreqs.to(device, dtype=torch.float)
                masks = masks.to(device, dtype=torch.float)
                data_weights = data_weights.to(device, dtype=torch.float)
                wGs = wGs.to(device, dtype=torch.float)
                # Evaluation pass: flag_train=0, so presumably utfi_train
                # skips the weight update here — TODO confirm.
                loss_PDF, loss_fidelity, loss_tv = utfi_train(model,
                                                              optimizer,
                                                              ifreqs,
                                                              masks,
                                                              data_weights,
                                                              wGs,
                                                              D,
                                                              D_smv,
                                                              lambda_pdf,
                                                              lambda_tv,
                                                              voxel_size,
                                                              flag_train=0)
                loss_total = loss_PDF + loss_fidelity + loss_tv
                loss_total_list.append(np.asarray(loss_total))
            # Mean of the three-term validation loss over all batches.
            Validation_loss.append(
                sum(loss_total_list) / float(len(loss_total_list)))

        # Checkpoint whenever this epoch achieves the minimum so far.
        if Validation_loss[-1] == min(Validation_loss):
            torch.save(
                model.state_dict(),
                rootDir + '/weights/weight_pdf={0}_tv={1}.pt'.format(
                    lambda_pdf, lambda_tv))

        epoch += 1