Example #1
def build(train_config, dataset_train_configs, dataset_test_configs):
    """Build PyTorch train and test data loaders.

    :param train_config: Train configurations.
    :param dataset_train_configs: Configurations for the train dataset.
    :param dataset_test_configs: Configurations for the test dataset.
    :return: Tuple of (train data loader, test data loader).
    """
    add_from_gan = dataset_train_configs.add_data_from_gan
    batch_size = train_config.batch_size
    composed = transforms.Compose([ToTensor()])

    train_dataset = ecg_dataset_pytorch.EcgHearBeatsDatasetPytorch(
        configs=dataset_train_configs, transform=composed)
    test_dataset = ecg_dataset_pytorch.EcgHearBeatsDatasetPytorch(
        configs=dataset_test_configs, transform=composed)

    test_data_loader = torch.utils.data.DataLoader(test_dataset,
                                                   batch_size=300,
                                                   shuffle=True,
                                                   num_workers=1)

    #
    # Log class counts, depending on whether data from the GAN was added:
    #
    if add_from_gan:
        logging.info(
            "Size of training data after adding data from GAN: {}".format(
                len(train_dataset)))
        logging.info("#N: {}\t #S: {}\t #V: {}\t #F: {}\t".format(
            train_dataset.len_beat('N'), train_dataset.len_beat('S'),
            train_dataset.len_beat('V'), train_dataset.len_beat('F')))
    else:
        logging.info("No data is added. Train set size: ")
        logging.info("#N: {}\t #S: {}\t #V: {}\t #F: {}\t".format(
            train_dataset.len_beat('N'), train_dataset.len_beat('S'),
            train_dataset.len_beat('V'), train_dataset.len_beat('F')))
        logging.info("test set size: ")
        logging.info("#N: {}\t #S: {}\t #V: {}\t #F: {}\t".format(
            test_dataset.len_beat('N'), test_dataset.len_beat('S'),
            test_dataset.len_beat('V'), test_dataset.len_beat('F')))

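    #
    # Weighted sampling draws each example with probability proportional to
    # its weight (here, the inverse frequency of its class), so minority beat
    # types are sampled about as often as the majority class:
    #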
    if train_config.weighted_sampling:
        weights_for_balance = train_dataset.make_weights_for_balanced_classes()
        weights_for_balance = torch.DoubleTensor(weights_for_balance)
        sampler = torch.utils.data.sampler.WeightedRandomSampler(
            weights=weights_for_balance,
            num_samples=len(weights_for_balance),
            replacement=True)
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=batch_size,
                                                        num_workers=1,
                                                        sampler=sampler)
    else:
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=batch_size,
                                                        num_workers=1,
                                                        shuffle=True)

    return train_data_loader, test_data_loader
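
A standalone sketch (not part of this repository) of what the weighted_sampling
branch above does: WeightedRandomSampler resamples a toy imbalanced dataset so
that each class appears roughly equally often. All names here are illustrative.

import torch
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

# Toy imbalanced dataset: 90 samples of class 0, 10 samples of class 1.
labels = torch.tensor([0] * 90 + [1] * 10)
dataset = TensorDataset(torch.randn(100, 4), labels)

# Weight each sample by the inverse frequency of its class, analogous to
# make_weights_for_balanced_classes() above.
class_counts = torch.bincount(labels).double()
weights = (1.0 / class_counts)[labels]
sampler = WeightedRandomSampler(weights,
                                num_samples=len(weights),
                                replacement=True)

loader = DataLoader(dataset, batch_size=20, sampler=sampler)
_, batch_labels = next(iter(loader))
print(batch_labels.float().mean())  # ~0.5 on average: classes roughly balanced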
Example #2
def train(batch_size, num_train_steps, model_dir, beat_type, generator_net,
          discriminator_net):
    """Train a GAN on ECG heartbeats of the given beat type, adding an
    ODE-based (Euler) loss to the generator objective."""

    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    ode_params = ODEParams(device)

    #
    # Support for tensorboard:
    #
    writer = SummaryWriter(model_dir)

    #
    # 1. create the ECG dataset:
    #
    composed = transforms.Compose(
        [ecg_dataset_pytorch.Scale(),
         ecg_dataset_pytorch.ToTensor()])

    positive_configs = dataset_configs.DatasetConfigs(
        'train',
        beat_type,
        one_vs_all=True,
        lstm_setting=False,
        over_sample_minority_class=False,
        under_sample_majority_class=False,
        only_take_heartbeat_of_type=beat_type,
        add_data_from_gan=False,
        gan_configs=None)

    dataset = ecg_dataset_pytorch.EcgHearBeatsDatasetPytorch(
        positive_configs, transform=composed)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1)
    print("Size of real dataset is {}".format(len(dataset)))

    #
    # 2. Create the models:
    #
    netG = generator_net.to(device)
    netD = discriminator_net.to(device)

    #
    # Define loss functions:
    #
    cross_entropy_loss = nn.BCELoss()
    mse_loss = nn.MSELoss()

    #
    # Optimizers:
    #
    lr = 0.0002
    beta1 = 0.5
    writer.add_scalar('Learning_Rate', lr)
    optimizer_d = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizer_g = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    #
    # Noise for validation:
    #
    val_noise = torch.Tensor(np.random.normal(0, 1, (4, 100))).to(device)

    #
    # Training loop:
    #
    epoch = 0
    iters = 0
    while True:
        num_of_beats_seen = 0
        if iters == num_train_steps:
            break
        for i, data in enumerate(dataloader):
            if iters == num_train_steps:
                break

            netD.zero_grad()

            #
            # Discriminator from real beats:
            #
            ecg_batch = data['cardiac_cycle'].float().to(device)
            b_size = ecg_batch.shape[0]

            num_of_beats_seen += ecg_batch.shape[0]
            output = netD(ecg_batch)
            labels = torch.full((b_size,), 1.0, device=device)

            ce_loss_d_real = cross_entropy_loss(output, labels)
            writer.add_scalar('Discriminator/cross_entropy_on_real_batch',
                              ce_loss_d_real.item(),
                              global_step=iters)
            writer.add_scalars(
                'Merged/losses',
                {'d_cross_entropy_on_real_batch': ce_loss_d_real.item()},
                global_step=iters)
            ce_loss_d_real.backward()
            mean_d_real_output = output.mean().item()

            #
            # Discriminator from fake beats:
            #
            noise_input = torch.Tensor(np.random.normal(
                0, 1, (b_size, 100))).to(device)

            output_g_fake = netG(noise_input)
            output = netD(output_g_fake.detach())
            labels.fill_(0)

            ce_loss_d_fake = cross_entropy_loss(output, labels)
            writer.add_scalar('Discriminator/cross_entropy_on_fake_batch',
                              ce_loss_d_fake.item(), iters)
            writer.add_scalars(
                'Merged/losses',
                {'d_cross_entropy_on_fake_batch': ce_loss_d_fake.item()},
                global_step=iters)
            ce_loss_d_fake.backward()

            mean_d_fake_output = output.mean().item()
            total_loss_d = ce_loss_d_fake + ce_loss_d_real
            writer.add_scalar(tag='Discriminator/total_loss',
                              scalar_value=total_loss_d.item(),
                              global_step=iters)
            optimizer_d.step()

            netG.zero_grad()
            labels.fill_(1)
            output = netD(output_g_fake)

            #
            # Euler (ODE) loss: penalize deviation of the generated beat's
            # finite differences from the ECG ODE dynamics:
            #
            delta_hb_signal, f_ode_z_signal = ode_loss(output_g_fake,
                                                       ode_params, device,
                                                       beat_type)
            mse_loss_euler = mse_loss(delta_hb_signal, f_ode_z_signal)
            logging.info("MSE ODE loss: {}".format(mse_loss_euler.item()))
            ce_loss_g_fake = cross_entropy_loss(output, labels)
            total_g_loss = mse_loss_euler + ce_loss_g_fake
            # total_g_loss = mse_loss_euler
            total_g_loss.backward()

            writer.add_scalar(tag='Generator/mse_ode',
                              scalar_value=mse_loss_euler.item(),
                              global_step=iters)
            writer.add_scalar(tag='Generator/cross_entropy_on_fake_batch',
                              scalar_value=ce_loss_g_fake.item(),
                              global_step=iters)
            writer.add_scalars(
                'Merged/losses',
                {'g_cross_entropy_on_fake_batch': ce_loss_g_fake.item()},
                global_step=iters)
            mean_d_fake_output_2 = output.mean().item()

            optimizer_g.step()

            if iters % 50 == 0:
                print(
                    "{}/{}: Epoch #{}: Iteration #{}: Mean D(real_hb_batch) = {}, mean D(G(z)) = {}."
                    .format(num_of_beats_seen, len(dataset), epoch, iters,
                            mean_d_real_output, mean_d_fake_output),
                    end=" ")
                print("mean D(G(z)) = {} After backprop of D".format(
                    mean_d_fake_output_2))

                print(
                    "Loss D from real beats = {}. Loss D from Fake beats = {}. Total Loss D = {}"
                    .format(ce_loss_d_real, ce_loss_d_fake, total_loss_d),
                    end=" ")
                print("Loss G = {}".format(ce_loss_g_fake))

            #
            # Norm of gradients:
            #
            gNormGrad = get_gradient_norm_l2(netG)
            dNormGrad = get_gradient_norm_l2(netD)
            writer.add_scalar('Generator/gradients_norm', gNormGrad, iters)
            writer.add_scalar('Discriminator/gradients_norm', dNormGrad, iters)
            print(
                "Generator Norm of gradients = {}. Discriminator Norm of gradients = {}."
                .format(gNormGrad, dNormGrad))

            if iters % 25 == 0:
                with torch.no_grad():
                    netG.eval()
                    output_g = netG(val_noise)
                    netG.train()
                    fig = plt.figure()
                    plt.title(
                        "Fake beats from Generator. Iteration {}".format(iters))
                    for p in range(4):
                        plt.subplot(2, 2, p + 1)
                        plt.plot(output_g[p].cpu().detach().numpy(),
                                 label="fake beat")
                        plt.plot(ecg_batch[p].cpu().detach().numpy(),
                                 label="real beat")
                        plt.legend()
                    writer.add_figure('Generator/output_example', fig, iters)
                    plt.close()
            if iters % 50 == 0:
                torch.save(
                    {
                        'epoch': epoch,
                        'generator_state_dict': netG.state_dict(),
                    }, model_dir +
                    '/checkpoint_epoch_{}_iters_{}'.format(epoch, iters))
            iters += 1
        epoch += 1
    torch.save({
        'epoch': epoch,
        'generator_state_dict': netG.state_dict(),
    }, model_dir + '/checkpoint_epoch_{}_iters_{}'.format(epoch, iters))
    writer.close()
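
A minimal, self-contained sketch (not from this repository) of restoring a
generator from the checkpoint dictionary saved above. The nn.Linear stand-in
is an assumption; the real generator architecture differs. A real run would
pass the saved checkpoint path to torch.load instead of the in-memory buffer.

import io

import torch
import torch.nn as nn

netG = nn.Sequential(nn.Linear(100, 216))  # stand-in generator

buffer = io.BytesIO()
torch.save({'epoch': 0, 'generator_state_dict': netG.state_dict()}, buffer)
buffer.seek(0)

checkpoint = torch.load(buffer)
netG.load_state_dict(checkpoint['generator_state_dict'])

netG.eval()
with torch.no_grad():
    fake_beats = netG(torch.randn(4, 100))  # same noise shape as val_noise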
Example #3
def train(batch_size, num_train_steps, generator, discriminator, model_dir,
          beat_type, device):
    """Train a WGAN (weight-clipping variant) on ECG heartbeats of the given
    beat type, adding an ODE-based (Euler) loss to the generator objective."""

    ode_params = ODEParams(device)
    #
    # Support for tensorboard:
    #
    writer = SummaryWriter(model_dir)

    #
    # 1. create the ECG dataset:
    #
    positive_configs = dataset_configs.DatasetConfigs(
        'train',
        beat_type,
        one_vs_all=True,
        lstm_setting=False,
        over_sample_minority_class=False,
        under_sample_majority_class=False,
        only_take_heartbeat_of_type=beat_type,
        add_data_from_gan=False,
        gan_configs=None)

    composed = transforms.Compose(
        [ecg_dataset_pytorch.Scale(),
         ecg_dataset_pytorch.ToTensor()])
    dataset = ecg_dataset_pytorch.EcgHearBeatsDatasetPytorch(
        positive_configs, transform=composed)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1)
    print("Size of real dataset is {}".format(len(dataset)))

    #
    # 2. Create the Networks:
    #
    netG = generator.float().to(device)
    netD = discriminator.float().to(device)

    num_d_iters = 5
    weight_clipping_limit = 0.01
    #
    # Loss functions for WGAN and ode:
    #
    mse_loss = nn.MSELoss()

    # Optimizers:
    # WGAN values from paper
    lr = 0.00005

    writer.add_scalar('Learning_Rate', lr)
    # WGAN with weight clipping uses RMSprop instead of Adam
    optimizer_d = torch.optim.RMSprop(netD.parameters(), lr=lr)
    optimizer_g = torch.optim.RMSprop(netG.parameters(), lr=lr)

    # Noise for validation:
    val_noise = torch.from_numpy(np.random.uniform(
        0, 1, (4, 100))).float().to(device)
    loss_d_real_hist = []
    loss_d_fake_hist = []
    loss_g_fake_hist = []
    norma_grad_g = []
    norm_grad_d = []
    d_real_pred_hist = []
    d_fake_pred_hist = []
    epoch = 0
    iters = 0
    while True:
        num_of_beats_seen = 0
        if iters == num_train_steps:
            break
        for i, data in enumerate(dataloader):
            if iters == num_train_steps:
                break

            # Train the discriminator (forward - loss - backward - update)
            # num_d_iters times for each single generator update:
            for p in netD.parameters():
                p.requires_grad = True
            for d_iter in range(num_d_iters):

                netD.zero_grad()

                # Clamp parameters to the range [-c, c], c = weight_clipping_limit:
                for p in netD.parameters():
                    p.data.clamp_(-weight_clipping_limit, weight_clipping_limit)

                ecg_batch = data['cardiac_cycle'].float().to(device)
                b_size = ecg_batch.shape[0]

                # Skip batches that are not full-sized:
                if b_size != batch_size:
                    continue
                num_of_beats_seen += ecg_batch.shape[0]

                output = netD(ecg_batch)

                # Wasserstein critic loss on the real batch (maximize D(real)):
                loss_d_real = -torch.mean(output)

                writer.add_scalar('Discriminator/cross_entropy_on_real_batch',
                                  loss_d_real.item(),
                                  global_step=iters)
                writer.add_scalars(
                    'Merged/losses',
                    {'d_cross_entropy_on_real_batch': loss_d_real.item()},
                    global_step=iters)
                loss_d_real.backward()
                loss_d_real_hist.append(loss_d_real.item())

                mean_d_real_output = output.mean().item()
                d_real_pred_hist.append(mean_d_real_output)

                #
                # D loss from fake:
                #
                noise_input = torch.from_numpy(
                    np.random.uniform(0, 1, (b_size, 100))).float().to(device)

                output_g_fake = netG(noise_input)
                output = netD(output_g_fake.detach())

                loss_d_fake = torch.mean(output)
                writer.add_scalar('Discriminator/cross_entropy_on_fake_batch',
                                  loss_d_fake.item(), iters)
                writer.add_scalars(
                    'Merged/losses',
                    {'d_cross_entropy_on_fake_batch': loss_d_fake.item()},
                    global_step=iters)
                loss_d_fake.backward()

                loss_d_fake_hist.append(loss_d_fake.item())

                mean_d_fake_output = output.mean().item()
                d_fake_pred_hist.append(mean_d_fake_output)
                total_loss_d = loss_d_fake + loss_d_real
                writer.add_scalar(tag='Discriminator/total_loss',
                                  scalar_value=total_loss_d.item(),
                                  global_step=iters)
                optimizer_d.step()

            #
            # Generator updates:
            #
            for p in netD.parameters():
                p.requires_grad = False  # to avoid computation

            netG.zero_grad()

            noise_input = torch.from_numpy(
                np.random.uniform(0, 1, (batch_size, 100))).float().to(device)

            output_g_fake = netG(noise_input)

            output = netD(output_g_fake)

            # Wasserstein generator loss (maximize D(G(z))):
            loss_g_fake = -torch.mean(output)

            # Euler loss:
            delta_hb_signal, f_ode_z_signal = ode_loss(output_g_fake,
                                                       ode_params, device,
                                                       beat_type)
            mse_loss_euler = mse_loss(delta_hb_signal, f_ode_z_signal)
            logging.info("MSE ODE loss: {}".format(mse_loss_euler.item()))

            total_g_loss = mse_loss_euler + loss_g_fake

            total_g_loss.backward()
            loss_g_fake_hist.append(loss_g_fake.item())
            writer.add_scalar(tag='Generator/mse_ode',
                              scalar_value=mse_loss_euler.item(),
                              global_step=iters)
            writer.add_scalar(tag='Generator/cross_entropy_on_fake_batch',
                              scalar_value=loss_g_fake.item(),
                              global_step=iters)
            writer.add_scalars(
                'Merged/losses',
                {'g_cross_entropy_on_fake_batch': loss_g_fake.item()},
                global_step=iters)
            mean_d_fake_output_2 = output.mean().item()

            optimizer_g.step()

            print(
                "{}/{}: Epoch #{}: Iteration #{}: Mean D(real_hb_batch) = {}, mean D(G(z)) = {}."
                .format(num_of_beats_seen, len(dataset), epoch, iters,
                        mean_d_real_output, mean_d_fake_output),
                end=" ")
            print("mean D(G(z)) = {} After backprop of D".format(
                mean_d_fake_output_2))

            print(
                "Loss D from real beats = {}. Loss D from Fake beats = {}. Total Loss D = {}"
                .format(loss_d_real, loss_d_fake, total_loss_d),
                end=" ")
            print("Loss G = {}".format(loss_g_fake))

            # Norm of gradients:
            gNormGrad = get_gradient_norm_l2(netG)
            dNormGrad = get_gradient_norm_l2(netD)
            writer.add_scalar('Generator/gradients_norm', gNormGrad, iters)
            writer.add_scalar('Discriminator/gradients_norm', dNormGrad, iters)
            norm_grad_d.append(dNormGrad)
            norma_grad_g.append(gNormGrad)
            print(
                "Generator Norm of gradients = {}. Discriminator Norm of gradients = {}."
                .format(gNormGrad, dNormGrad))

            if iters % 25 == 0:
                with torch.no_grad():
                    output_g = netG(val_noise)
                    fig = plt.figure()
                    plt.title(
                        "Fake beats from Generator. Iteration {}".format(iters))
                    for p in range(4):
                        plt.subplot(2, 2, p + 1)
                        plt.plot(output_g[p].cpu().detach().numpy(),
                                 label="fake beat")
                        plt.plot(ecg_batch[p].cpu().detach().numpy(),
                                 label="real beat")
                        plt.legend()
                    writer.add_figure('Generator/output_example', fig, iters)
                    plt.close()
            iters += 1
        epoch += 1

    torch.save(
        {
            'epoch': epoch,
            'generator_state_dict': netG.state_dict(),
            'discriminator_state_dict': netD.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
        }, model_dir + '/checkpoint_epoch_{}_iters_{}'.format(epoch, iters))
    writer.close()
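
A standalone sketch (not from this repository) of one weight-clipped WGAN
critic step, the core of the inner num_d_iters loop above, on toy 1-D signals.
The critic network and signal length are placeholders.

import torch
import torch.nn as nn

critic = nn.Sequential(nn.Linear(216, 64), nn.ReLU(), nn.Linear(64, 1))
opt_d = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
clip = 0.01  # weight_clipping_limit above

real = torch.randn(32, 216)  # stands in for a real heartbeat batch
fake = torch.randn(32, 216)  # stands in for netG(noise).detach()

# Clamp weights first, as the loop above does, to keep the critic Lipschitz:
for p in critic.parameters():
    p.data.clamp_(-clip, clip)

opt_d.zero_grad()
# The critic maximizes E[D(real)] - E[D(fake)], i.e. minimizes the negation:
loss_d = -critic(real).mean() + critic(fake).mean()
loss_d.backward()
opt_d.step()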
Example #4
def train_ecg_gan(batch_size, num_train_steps, generator, discriminator,
                  model_dir, beat_type):
    """Train a vanilla GAN (BCE losses, Adam) on ECG heartbeats of the given
    beat type. Runs on CPU."""
    # Support for tensorboard:
    writer = SummaryWriter(model_dir)
    # 1. create the ECG dataset:
    # composed = transforms.Compose([ecg_dataset.Scale(), ecg_dataset.Smooth(), ecg_dataset.ToTensor()])

    positive_configs = dataset_configs.DatasetConfigs(
        'train',
        beat_type,
        one_vs_all=True,
        lstm_setting=False,
        over_sample_minority_class=False,
        under_sample_majority_class=False,
        only_take_heartbeat_of_type=beat_type,
        add_data_from_gan=False,
        gan_configs=None)

    dataset = ecg_dataset_pytorch.EcgHearBeatsDatasetPytorch(
        positive_configs, transform=ecg_dataset_pytorch.ToTensor())

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1)
    print("Size of real dataset is {}".format(len(dataset)))

    # 2. Create the models:
    netG = generator
    netD = discriminator

    # Loss functions:
    cross_entropy_loss = nn.BCELoss()

    # Optimizers:
    lr = 0.0002
    beta1 = 0.5
    writer.add_scalar('Learning_Rate', lr)
    optimizer_d = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizer_g = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    # Noise for validation:
    val_noise = torch.Tensor(np.random.normal(0, 1, (4, 100)))
    loss_d_real_hist = []
    loss_d_fake_hist = []
    loss_g_fake_hist = []
    norma_grad_g = []
    norm_grad_d = []
    d_real_pred_hist = []
    d_fake_pred_hist = []
    epoch = 0
    iters = 0
    while True:
        num_of_beats_seen = 0
        if iters == num_train_steps:
            break
        for i, data in enumerate(dataloader):
            if iters == num_train_steps:
                break

            netD.zero_grad()
            ecg_batch = data['cardiac_cycle'].float()
            b_size = ecg_batch.shape[0]
            num_of_beats_seen += ecg_batch.shape[0]
            output = netD(ecg_batch)
            labels = torch.full((b_size,), 1.0, device='cpu')
            ce_loss_d_real = cross_entropy_loss(output, labels)
            writer.add_scalar('Discriminator/cross_entropy_on_real_batch',
                              ce_loss_d_real.item(),
                              global_step=iters)
            writer.add_scalars(
                'Merged/losses',
                {'d_cross_entropy_on_real_batch': ce_loss_d_real.item()},
                global_step=iters)
            ce_loss_d_real.backward()
            loss_d_real_hist.append(ce_loss_d_real.item())

            mean_d_real_output = output.mean().item()
            d_real_pred_hist.append(mean_d_real_output)

            noise_input = torch.Tensor(np.random.normal(0, 1, (b_size, 100)))

            output_g_fake = netG(noise_input)
            output = netD(output_g_fake.detach())
            labels.fill_(0)

            ce_loss_d_fake = cross_entropy_loss(output, labels)
            writer.add_scalar('Discriminator/cross_entropy_on_fake_batch',
                              ce_loss_d_fake.item(), iters)
            writer.add_scalars(
                'Merged/losses',
                {'d_cross_entropy_on_fake_batch': ce_loss_d_fake.item()},
                global_step=iters)
            ce_loss_d_fake.backward()

            loss_d_fake_hist.append(ce_loss_d_fake.item())

            mean_d_fake_output = output.mean().item()
            d_fake_pred_hist.append(mean_d_fake_output)
            total_loss_d = ce_loss_d_fake + ce_loss_d_real
            writer.add_scalar(tag='Discriminator/total_loss',
                              scalar_value=total_loss_d.item(),
                              global_step=iters)
            optimizer_d.step()

            netG.zero_grad()
            labels.fill_(1)

            output = netD(output_g_fake)
            ce_loss_g_fake = cross_entropy_loss(output, labels)
            ce_loss_g_fake.backward()
            loss_g_fake_hist.append(ce_loss_g_fake.item())
            writer.add_scalar(tag='Generator/cross_entropy_on_fake_batch',
                              scalar_value=ce_loss_g_fake.item(),
                              global_step=iters)
            writer.add_scalars(
                'Merged/losses',
                {'g_cross_entropy_on_fake_batch': ce_loss_g_fake.item()},
                global_step=iters)
            mean_d_fake_output_2 = output.mean().item()

            optimizer_g.step()

            print(
                "{}/{}: Epoch #{}: Iteration #{}: Mean D(real_hb_batch) = {}, mean D(G(z)) = {}."
                .format(num_of_beats_seen, len(dataset), epoch, iters,
                        mean_d_real_output, mean_d_fake_output),
                end=" ")
            print("mean D(G(z)) = {} After backprop of D".format(
                mean_d_fake_output_2))

            print(
                "Loss D from real beats = {}. Loss D from Fake beats = {}. Total Loss D = {}"
                .format(ce_loss_d_real, ce_loss_d_fake, total_loss_d),
                end=" ")
            print("Loss G = {}".format(ce_loss_g_fake))

            # Norm of gradients:
            gNormGrad = get_gradient_norm_l2(netG)
            dNormGrad = get_gradient_norm_l2(netD)
            writer.add_scalar('Generator/gradients_norm', gNormGrad, iters)
            writer.add_scalar('Discriminator/gradients_norm', dNormGrad, iters)
            norm_grad_d.append(dNormGrad)
            norma_grad_g.append(gNormGrad)
            print(
                "Generator Norm of gradients = {}. Discriminator Norm of gradients = {}."
                .format(gNormGrad, dNormGrad))

            if iters % 25 == 0:
                with torch.no_grad():
                    output_g = netG(val_noise)
                    fig = plt.figure()
                    plt.title(
                        "Fake beats from Generator. Iteration {}".format(iters))
                    for p in range(4):
                        plt.subplot(2, 2, p + 1)
                        plt.plot(output_g[p].detach().numpy(),
                                 label="fake beat")
                        plt.plot(ecg_batch[p].detach().numpy(),
                                 label="real beat")
                        plt.legend()
                    writer.add_figure('Generator/output_example', fig, iters)
                    plt.close()
            iters += 1
        epoch += 1

    torch.save(
        {
            'epoch': epoch,
            'generator_state_dict': netG.state_dict(),
            'discriminator_state_dict': netD.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
            'loss': cross_entropy_loss,
        }, model_dir + '/checkpoint_epoch_{}_iters_{}'.format(epoch, iters))
    writer.close()
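
A standalone sketch (not from this repository) of the BCE label convention the
loop above uses: the discriminator is pushed toward 1 on real beats and 0 on
fakes, and the generator is then trained against labels of 1 (the
labels.fill_(1) trick). The discriminator outputs here are made up.

import torch
import torch.nn as nn

bce = nn.BCELoss()
d_out_real = torch.full((8,), 0.9)  # pretend D outputs on a real batch
d_out_fake = torch.full((8,), 0.2)  # pretend D outputs on a fake batch

loss_d = (bce(d_out_real, torch.ones(8)) +
          bce(d_out_fake, torch.zeros(8)))  # total_loss_d above
loss_g = bce(d_out_fake, torch.ones(8))     # generator's loss, labels all 1
print(loss_d.item(), loss_g.item())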