Ejemplo n.º 1
0
def _sample_conditioning_pars(params, batch, device):
    """Draw one batch of normalized conditioning parameters for the generator.

    The sampling box depends on the dataset:
      * ``'2parB'`` in ``params.data_path``: 2 columns,
        U[-0.4, 1.0) and U[-1.0, 0.3791).
      * ``'3parA'`` in ``params.tag``: 3 columns,
        U[-0.2, 0.2), U[-1.0, 0.8) and U[-0.75275, -0.5055).
      * otherwise: ``params.num_params`` columns, each U[-1, 1).

    Args:
        params: Config object; reads ``data_path``, ``tag`` and ``num_params``.
        batch: Number of parameter rows to draw.
        device: Torch device the returned tensor is allocated on.

    Returns:
        Float tensor of shape ``(batch, n_pars)`` on ``device``.
    """
    def column(scale, offset):
        # One (batch, 1) column drawn uniformly from [offset, offset + scale).
        return scale * torch.rand((batch, 1), device=device) + offset

    # NOTE(review): the first check reads data_path but the second reads tag;
    # kept as in the original, confirm this asymmetry is intended.
    if '2parB' in params.data_path:
        # Different normalized range for this dataset. List elements are
        # evaluated left to right, so the RNG draw order is R0 then TriggerDay.
        cols = [column(1.4, -0.4), column(1.3791, -1.)]
    elif '3parA' in params.tag:
        # R0, WFHcomp, WFHdays (in this draw order).
        cols = [column(0.4, -0.2), column(1.8, -1.), column(0.24725, -0.75275)]
    else:
        return 2. * torch.rand((batch, params.num_params), device=device) - 1.
    return torch.cat(cols, dim=-1)


def generate(params, checkpt, outname, num=64, batch=64):
    """Sample ``num`` fields from a trained generator and save them as .npy files.

    Writes two files:
      * ``<outname>_dat.npy``: uint16 samples of shape ``(num, 5, 124, 365)``.
      * ``<outname>_par.npy``: float64 normalized conditioning parameters,
        shape ``(num, params.num_params)``.

    Args:
        params: Config object. Reads ``ngpu``, ``num_params``, ``z_dim``,
            ``norm``, ``data_path`` and ``tag``, plus whatever
            ``dcgan.Generator`` reads.
        checkpt: Path to a checkpoint holding a ``'G_state'`` entry.
        outname: Output path prefix.
        num: Total number of samples. Need not be a multiple of ``batch``;
            the final generator call is simply smaller.
        batch: Maximum number of samples per generator call.

    Raises:
        ValueError: if ``batch`` is not positive.
    """
    if batch <= 0:
        raise ValueError("batch must be positive, got %r" % (batch,))

    device = torch.device(
        "cuda" if (torch.cuda.is_available() and params.ngpu > 0) else "cpu")

    netG = dcgan.Generator(params).to(device)

    num_GPUs = torch.cuda.device_count()
    print('Using %d GPUs' % num_GPUs)
    # DataParallel requires the module to live on a GPU; only wrap when the
    # model was actually placed there.
    if device.type == 'cuda' and num_GPUs > 1:
        netG = nn.DataParallel(netG, list(range(num_GPUs)))

    print("Loading checkpoint %s" % checkpt)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    checkpoint = torch.load(checkpt, map_location=device)
    netG.load_state_dict(checkpoint['G_state'])
    netG.eval()

    print("Starting generation...")
    ushort_max = np.iinfo(np.ushort).max
    out = np.zeros((num, 5, 124, 365), dtype=np.ushort)
    pars_norm = np.zeros((num, params.num_params), dtype=np.float64)
    with torch.no_grad():
        for start in range(0, num, batch):
            stop = min(start + batch, num)
            bs = stop - start
            pars = _sample_conditioning_pars(params, bs, device)
            noise = torch.randn((bs, 1, params.z_dim), device=device)
            fake = netG(noise, pars).cpu().numpy()
            if params.norm == 'log':
                fake = np.exp(fake) - 1.
            # Clip before casting: a negative (or >65535) rounded value would
            # otherwise wrap around in the unsigned 16-bit conversion.
            out[start:stop] = np.clip(np.round(fake), 0, ushort_max).astype(np.ushort)
            pars_norm[start:stop] = pars.cpu().numpy()

    print("Output: shape %s, type %s, size %f MB" % (str(out.shape), str(out.dtype), out.nbytes / 1e6))
    np.save(outname + '_dat.npy', out)
    np.save(outname + '_par.npy', pars_norm)
    print("Saved output to %s" % (outname + '_dat.npy'))
    print("Saved parameters to %s" % (outname + '_par.npy'))
Ejemplo n.º 2
0
                    choices=["channels_first", "channels_last"],
                    default="channels_last",
                    help="data_format")
# Remaining command-line switches, then parse everything registered above.
parser.add_argument("--train", help="training mode", action="store_true")
parser.add_argument("--gpu", help="gpu id", type=str, default="0")
args = parser.parse_args()

# Show INFO-level TensorFlow (TF1 API) log messages.
tf.logging.set_verbosity(tf.logging.INFO)

gan_models = [
    gan.Model(dataset=celeba.Dataset(image_size=[64, 64],
                                     data_format=args.data_format),
              generator=dcgan.Generator(
                  min_resolution=4,
                  max_resolution=64,
                  min_filters=32,
                  max_filters=512,
                  data_format=args.data_format,
              ),
              discriminator=dcgan.Discriminator(min_resolution=4,
                                                max_resolution=64,
                                                min_filters=32,
                                                max_filters=512,
                                                data_format=args.data_format),
              loss_function=gan.Model.LossFunction.NS_GAN,
              gradient_penalty=gan.Model.GradientPenalty.ONE_CENTERED,
              hyper_params=attr_dict.AttrDict(latent_size=128,
                                              gradient_coefficient=1.0,
                                              learning_rate=0.0002,
                                              beta1=0.5,
                                              beta2=0.999),
Ejemplo n.º 3
0
def generate(params, checkpt, outname, num=64, batch=64):
    """Sample ``num`` simulations from a trained generator and write them to HDF5.

    The file ``outname`` receives three datasets:
      * ``'symptomatic3D'``: uint16 generated samples, shape ``(num, 5, 124, 365)``.
      * ``'parBio'``: the conditioning parameters mapped back through
        ``unnormalized``, shape ``(num, params.num_params)``.
      * ``'uniBio'``: the normalized parameters actually fed to the generator,
        same shape.

    Conditioning parameters R0, WFHcomp and WFHdays are drawn via
    ``normalized_params`` with bounds (2., 3.), (0., 0.9) and (45, 90).
    Both ``normalized_params`` and ``unnormalized`` are defined elsewhere.
    NOTE(review): the output arrays assume ``params.num_params == 3``.

    Args:
        params: Config object. Reads ``ngpu``, ``num_params`` and ``z_dim``,
            plus whatever ``dcgan.Generator`` reads.
        checkpt: Path to a checkpoint holding a ``'G_state'`` entry.
        outname: Path of the HDF5 file to create (overwritten if present).
        num: Total number of samples. Need not be a multiple of ``batch``.
        batch: Maximum number of samples per generator call.

    Raises:
        ValueError: if ``batch`` is not positive.
    """
    if batch <= 0:
        raise ValueError("batch must be positive, got %r" % (batch,))

    device = torch.device("cuda" if (
        torch.cuda.is_available() and params.ngpu > 0) else "cpu")

    netG = dcgan.Generator(params).to(device)

    num_GPUs = torch.cuda.device_count()
    print('Using %d GPUs' % num_GPUs)
    # DataParallel requires the module to live on a GPU; only wrap when the
    # model was actually placed there.
    if device.type == 'cuda' and num_GPUs > 1:
        netG = nn.DataParallel(netG, list(range(num_GPUs)))

    print("Loading checkpoint %s" % checkpt)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    checkpoint = torch.load(checkpt, map_location=device)
    netG.load_state_dict(checkpoint['G_state'])
    netG.eval()

    # (name, low, high) for each conditioning parameter, in the column order
    # fed to netG and stored in the output file.
    par_specs = (('R0', 2., 3.), ('WFHcomp', 0., 0.9), ('WFHdays', 45, 90))

    print("Starting generation...")
    ushort_max = np.iinfo(np.ushort).max
    out = np.zeros((num, 5, 124, 365), dtype=np.ushort)
    pars_norm = np.zeros((num, params.num_params), dtype=np.float64)
    pars_orig = np.zeros((num, params.num_params), dtype=np.float64)
    with torch.no_grad():
        for start in range(0, num, batch):
            stop = min(start + batch, num)
            bs = stop - start
            # The comprehension draws the columns in par_specs order.
            pars = torch.cat([
                normalized_params(low, high, param=name, batchsize=bs, device=device)
                for name, low, high in par_specs
            ], dim=-1)

            noise = torch.randn((bs, 1, params.z_dim), device=device)
            fake = netG(noise, pars).cpu().numpy()
            # Clip before casting: a negative (or >65535) rounded value would
            # otherwise wrap around in the unsigned 16-bit conversion.
            out[start:stop] = np.clip(np.round(fake), 0, ushort_max).astype(np.ushort)

            upars = pars.cpu().numpy()
            pars_norm[start:stop] = upars
            pars_orig[start:stop] = np.stack(
                [unnormalized(upars[:, k], name)
                 for k, (name, _, _) in enumerate(par_specs)],
                axis=-1)

    print("Output: shape %s, type %s, size %f MB" %
          (str(out.shape), str(out.dtype), out.nbytes / 1e6))
    with h5py.File(outname, 'w') as f:
        f.create_dataset('symptomatic3D', data=out)
        f.create_dataset('parBio', data=pars_orig)
        f.create_dataset('uniBio', data=pars_norm)
    print("Saved output to %s" % (outname))
Ejemplo n.º 4
0
def _load_training_checkpoint(params, device):
    """Return the checkpoint dict to resume from, or None if there is none.

    An explicit ``params.load_checkpoint`` path takes precedence. Otherwise
    ``params.checkpoint_file`` is used, but only if it exists on disk.
    """
    if params.load_checkpoint:
        path = params.load_checkpoint
    elif os.path.exists(params.checkpoint_file):
        path = params.checkpoint_file
    else:
        return None
    logging.info("Loading checkpoint %s" % path)
    # map_location lets a checkpoint saved on GPU resume on a CPU-only host.
    return torch.load(path, map_location=device)


def _log_validation(netG, vald_loader, params, device, writer, iters):
    """Run ``netG`` over the validation set and log comparison figures at step ``iters``.

    Each validation batch is expected as ``(real, normalized_pars, bio_pars)``.
    Logs the ``'samples'``, ``'peak_curve'`` and ``'MAE_dist'`` figures and the
    ``'peak_curve_MAE'`` scalar to ``writer``.
    """
    netG.eval()
    reals, fakes, vpars = [], [], []
    with torch.no_grad():
        for data in vald_loader:
            pars = data[1].to(device)
            # Size the noise from this validation batch, not from the last
            # training batch, so the two inputs to netG always agree.
            noise = torch.randn((pars.size(0), 1, params.z_dim), device=device)
            fakes.append(netG(noise, pars).cpu().numpy())
            vpars.append(data[2].detach().numpy())
            reals.append(data[0].numpy())
    reals = np.concatenate(reals, axis=0)
    fakes = np.concatenate(fakes, axis=0)
    vpars = np.concatenate(vpars, axis=0)

    fig = plot_samples(fakes[:1], vpars[:1], tag=params.tag)
    writer.add_figure('samples', fig, iters, close=True)
    fig, score, fig2 = compare_curves(reals,
                                      fakes,
                                      vpars,
                                      tag=params.tag,
                                      norm=params.norm)
    writer.add_figure('peak_curve', fig, iters, close=True)
    writer.add_figure('MAE_dist', fig2, iters, close=True)
    writer.add_scalar('peak_curve_MAE', score, iters)


def train(params, writer):
    """Train the conditional DCGAN generator/discriminator pair.

    Each epoch runs one Adam step for D and one for G per training batch. It
    then logs losses, logs validation figures via ``writer`` and overwrites
    ``params.checkpoint_file`` with the model and optimizer state.

    Args:
        params: Config object. Reads, among others, ``optim``, ``ngpu``, ``lr``,
            ``adam_beta1``, ``adam_beta2``, ``z_dim``, ``num_params``,
            ``num_epochs``, ``resample_pars``, label-flipping settings and
            checkpoint paths.
        writer: Logger exposing ``add_scalar`` / ``add_figure``; presumably a
            TensorBoard ``SummaryWriter`` (confirm with caller).

    Raises:
        ValueError: if ``params.optim`` is not ``'Adam'`` (the only optimizer
            implemented) or if the training loader yields no batches.
    """
    if params.optim != 'Adam':
        raise ValueError(
            "Unsupported optimizer %r; only 'Adam' is implemented" % (params.optim,))

    data_loader = get_data_loader(params)
    vald_loader = get_vald_loader(params)
    if len(data_loader) == 0:
        raise ValueError("training data loader yields no batches")

    device = torch.device("cuda" if (
        torch.cuda.is_available() and params.ngpu > 0) else "cpu")

    netG = dcgan.Generator(params).to(device)
    logging.info(netG)
    netD = dcgan.Discriminator(params).to(device)
    logging.info(netD)

    num_GPUs = torch.cuda.device_count()
    logging.info('Using %d GPUs' % num_GPUs)
    # DataParallel requires the modules to live on a GPU; only wrap when they
    # were actually placed there.
    if device.type == 'cuda' and num_GPUs > 1:
        device_list = list(range(num_GPUs))
        netG = nn.DataParallel(netG, device_list)
        netD = nn.DataParallel(netD, device_list)

    netG.apply(dcgan.get_weights_function(params.gen_init))
    netD.apply(dcgan.get_weights_function(params.disc_init))

    betas = (params.adam_beta1, params.adam_beta2)
    optimizerD = optim.Adam(netD.parameters(), lr=params.lr, betas=betas)
    optimizerG = optim.Adam(netG.parameters(), lr=params.lr, betas=betas)
    criterion = nn.BCEWithLogitsLoss()

    iters = 0
    startEpoch = 0
    checkpoint = _load_training_checkpoint(params, device)
    if checkpoint:
        netG.load_state_dict(checkpoint['G_state'])
        netD.load_state_dict(checkpoint['D_state'])
        if params.load_optimizers_from_checkpoint:
            iters = checkpoint['iters']
            # NOTE(review): the saved 'epoch' is the epoch that just finished,
            # so resuming re-runs that epoch index; confirm this is intended.
            startEpoch = checkpoint['epoch']
            optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
            optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])

    logging.info("Starting Training Loop...")
    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        netG.train()
        netD.train()
        for i, data in enumerate(data_loader, 0):
            iters += 1
            real = data[0].to(device)
            pars = data[1].to(device)
            b_size = real.size(0)
            real_label, fake_label = get_labels(
                params.flip_labels_alpha, params.decay_label_flipping, epoch)
            # Force a float target: BCEWithLogitsLoss rejects integer targets,
            # which torch.full would infer if get_labels returned ints.
            label = torch.full((b_size, ), real_label, dtype=torch.float,
                               device=device)

            # Discriminator step: real batch, then generated batch.
            netD.zero_grad()
            output = netD(real, pars).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            if params.resample_pars:
                pars = 2. * torch.rand(
                    (b_size, params.num_params), device=device) - 1.
            noise = torch.randn((b_size, 1, params.z_dim), device=device)
            fake = netG(noise, pars)
            label.fill_(fake_label)
            # detach(): the discriminator update must not backprop into netG.
            output = netD(fake.detach(), pars).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            # Generator step: push netD to label the generated batch as real.
            netG.zero_grad()
            label.fill_(real_label)
            output = netD(fake, pars).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

        logging.info(
            '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
            % (epoch, startEpoch + params.num_epochs, i, len(data_loader),
               errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Scalars (values from the last training batch of the epoch).
        writer.add_scalar('Loss/Disc', errD.item(), iters)
        writer.add_scalar('Loss/Gen', errG.item(), iters)
        writer.add_scalar('Disc_output/real', D_x, iters)
        writer.add_scalar('Disc_output/fake', D_G_z1, iters)

        # Plots
        _log_validation(netG, vald_loader, params, device, writer, iters)

        # Save checkpoint (overwritten every epoch).
        torch.save(
            {
                'epoch': epoch,
                'iters': iters,
                'G_state': netG.state_dict(),
                'D_state': netD.state_dict(),
                'optimizerG_state_dict': optimizerG.state_dict(),
                'optimizerD_state_dict': optimizerD.state_dict()
            }, params.checkpoint_file)