Example #1
import argparse
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

# The model, data, and utility definitions (Encoder, Decoder, PBiGAN, PVAE,
# ConvCritic, ContinuousConv1D, GridDecoder, SeqGeneratorDiscrete, Classifier,
# EMA, Tracker, Evaluator, Visualizer, Rescaler, TimeSeries, time_series,
# gen_data, gan_loss, mmd, make_scheduler, mkdir, count_parameters, and a
# global `device`) are assumed to come from the accompanying project modules.

def main():
    parser = argparse.ArgumentParser()

    default_dataset = 'toy-data.npz'
    parser.add_argument('--data', default=default_dataset, help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=1000,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        help='batch size')
    parser.add_argument('--lr',
                        type=float,
                        default=8e-5,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr',
                        type=float,
                        default=1e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=5e-5,
                        help='min encoder/decoder learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr',
                        type=float,
                        default=7e-5,
                        help='min discriminator learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=0, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--no-norm-trans',
                        action='store_true',
                        help='if set, use Gaussian posterior without '
                        'transformation')
    parser.add_argument('--plot-interval',
                        type=int,
                        default=1,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae',
                        type=float,
                        default=.2,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss',
                        default='smooth_l1',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=-1,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd',
                        type=float,
                        default=1,
                        help='MMD strength for latent variable')

    # --squash/--no-squash form an on/off pair (default: on). Note that
    # squashing takes effect only when rescaling is on (see below).
    parser.add_argument('--squash',
                        dest='squash',
                        action='store_const',
                        const=True,
                        default=True,
                        help='bound the generated time series value '
                        'using tanh')
    parser.add_argument('--no-squash',
                        dest='squash',
                        action='store_const',
                        const=False)

    # rescale to [-1, 1]
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)

    args = parser.parse_args()

    batch_size = args.batch_size
    nz = args.nz

    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval

    try:
        npz = np.load(args.data)
        train_data = npz['data']
        train_time = npz['time']
        train_mask = npz['mask']
    except FileNotFoundError:
        if args.data != default_dataset:
            raise
        # Generate the default toy dataset from scratch
        train_data, train_time, train_mask, _, _ = gen_data(
            n_samples=10000,
            seq_len=200,
            max_time=1,
            poisson_rate=50,
            obs_span_rate=.25,
            save_file=default_dataset)

    _, in_channels, seq_len = train_data.shape
    train_time *= train_mask

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Scale time
    max_time = 5
    train_time *= max_time

    squash = None
    rescaler = None
    if args.rescale:
        rescaler = Rescaler(train_data)
        train_data = rescaler.rescale(train_data)
        if args.squash:
            squash = torch.tanh

    out_channels = 64
    cconv_ref = 98

    train_dataset = TimeSeries(train_data,
                               train_time,
                               train_mask,
                               label=None,
                               max_time=max_time,
                               cconv_ref=cconv_ref,
                               overlap_rate=args.overlap,
                               device=device)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    time_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True,
                             collate_fn=train_dataset.collate_fn)

    test_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             collate_fn=train_dataset.collate_fn)

    grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
    decoder = Decoder(grid_decoder, max_time=max_time).to(device)

    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=args.overlap,
                             kernel_size=args.comp,
                             norm=True).to(device)
    encoder = Encoder(cconv, nz, not args.no_norm_trans).to(device)

    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic_cconv = ContinuousConv1D(in_channels,
                                    out_channels,
                                    max_time,
                                    cconv_ref,
                                    overlap_rate=args.overlap,
                                    kernel_size=args.comp,
                                    norm=True).to(device)
    critic = ConvCritic(critic_cconv, nz).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    optimizer = optim.Adam(pbigan.parameters(),
                           lr=args.lr,
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.dis_lr,
                                  weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'toy-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    tracker = Tracker(log_dir, n_train_batch)
    visualizer = Visualizer(encoder, decoder, batch_size, max_time,
                            test_loader, rescaler, output_dir, device)
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()

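        # Each step pairs two independently shuffled batches: train_loader
        # supplies the observed series, while time_loader provides a second
        # set of time points (idx_t, mask_t) at which the decoder generates.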
        for ((val, idx, mask, _, cconv_graph),
             (_, idx_t, mask_t, index, _)) in zip(train_loader, time_loader):

            z_enc, x_recon, z_gen, x_gen, ae_loss = pbigan(
                val, idx, mask, cconv_graph, idx_t, mask_t)

            cconv_graph_gen = train_dataset.make_graph(x_gen, idx_t, mask_t,
                                                       index)

            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)

            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
            # retain_graph: the same real/fake logits feed G_loss below.
            D_loss.backward(retain_graph=True)
            critic_optimizer.step()

            G_loss = gan_loss(real, fake, 0, 1)

            mmd_loss = mmd(z_enc, z_gen)

            loss = G_loss + ae_loss * args.ae + mmd_loss * args.mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            if ema:
                ema.apply()
                visualizer.plot(epoch)
                ema.restore()
            else:
                visualizer.plot(epoch)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
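
make_scheduler is shared by all three scripts but defined elsewhere. The
'-1 to disable annealing' help strings and the `if scheduler:` guards imply
that it returns None when min_lr < 0 and otherwise anneals the learning rate
from lr toward min_lr over the run. A minimal sketch under those assumptions
(the linear schedule itself is a guess, not necessarily the original
implementation):

import torch.optim as optim


def make_scheduler(optimizer, lr, min_lr, epochs):
    """Anneal the learning rate linearly from lr to min_lr over `epochs`.

    Returns None when min_lr < 0, which the training loops treat as
    "annealing disabled" via their `if scheduler:` guards.
    """
    if min_lr < 0:
        return None

    def factor(epoch):
        # Fraction of the run completed, clamped to [0, 1].
        t = min(epoch / max(epochs - 1, 1), 1.0)
        return 1.0 - (1.0 - min_lr / lr) * t

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=factor)

With min_lr=-1 (the default in the mimic3-pvae script below) this returns
None and the learning rate stays constant.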
Example #2

def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=200,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--train-k',
                        type=int,
                        default=8,
                        help='number of importance weights for training')
    parser.add_argument('--test-k',
                        type=int,
                        default=50,
                        help='number of importance weights for evaluation')
    parser.add_argument('--flow',
                        type=int,
                        default=2,
                        help='number of IAF layers')
    parser.add_argument('--lr',
                        type=float,
                        default=2e-4,
                        help='global learning rate')
    parser.add_argument('--enc-lr',
                        type=float,
                        default=1e-4,
                        help='encoder learning rate')
    parser.add_argument('--dec-lr',
                        type=float,
                        default=1e-4,
                        help='decoder learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=-1,
                        help='min learning rate for LR scheduler. '
                        '-1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-3, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls',
                        type=float,
                        default=200,
                        help='classification weight')
    parser.add_argument('--clsdep',
                        type=int,
                        default=1,
                        help='number of layers for classifier')
    parser.add_argument('--ts',
                        type=float,
                        default=1,
                        help='log-likelihood weight for ELBO')
    parser.add_argument('--kl',
                        type=float,
                        default=.1,
                        help='KL weight for ELBO')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=1,
                        help='AUC evaluation interval. '
                        '0 to disable evaluation.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pvae',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--sigma',
                        type=float,
                        default=.2,
                        help='standard deviation for Gaussian likelihood')
    parser.add_argument('--dec-ch',
                        default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch',
                        default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, normalize continuous convolutional '
                        'layer using mean pooling')
    parser.add_argument('--no-cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconv-ref',
                        type=int,
                        default=98,
                        help='number of evenly-spaced reference locations '
                        'for continuous convolutional layer')
    parser.add_argument('--dec-ref',
                        type=int,
                        default=128,
                        help='number of evenly-spaced reference locations '
                        'for decoder')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=0,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')

    args = parser.parse_args()

    nz = args.nz

    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch_size,
                            shuffle=False,
                            collate_fn=val_dataset.collate_fn)

    test_loader = DataLoader(test_dataset,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             collate_fn=test_dataset.collate_fn)

    in_channels, seq_len = train_dataset.data.shape[1:]

    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]

    out_channels = enc_channels[0]

    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    # dec_len0 is the initial grid length, presumably doubled by each of the
    # len(dec_channels) - 2 upsampling stages to reach args.dec_ref.
    dec_ch_up = 2**(len(dec_channels) - 2)
    assert args.dec_ref % dec_ch_up == 0, (
        f'--dec-ref={args.dec_ref} is not divisible by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)

    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)

    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=overlap,
                             kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.flow).to(device)

    classifier = Classifier(nz, args.clsdep).to(device)

    pvae = PVAE(encoder, decoder, classifier, args.sigma, args.cls).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pvae, args.ema_decay, args.ema)

    other_params = [
        param for name, param in pvae.named_parameters()
        if not (name.startswith('decoder.grid_decoder')
                or name.startswith('encoder.grid_encoder'))
    ]
    params = [
        {
            'params': decoder.grid_decoder.parameters(),
            'lr': args.dec_lr
        },
        {
            'params': encoder.grid_encoder.parameters(),
            'lr': args.enc_lr
        },
        {
            'params': other_params
        },
    ]

    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'mimic3-pvae' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)
    with (log_dir / 'params.txt').open('w') as f:

        def print_params_count(module, name):
            try:  # sum counts if module is a list
                params_count = sum(count_parameters(m) for m in module)
            except TypeError:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)

        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pvae, 'pvae')
        # With a single model the total equals the pvae count; compare the
        # pbigan script, where 'total' sums the model and the critic.
        print_params_count(pvae, 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pvae,
                          val_loader,
                          test_loader,
                          log_dir,
                          eval_args={'iw_samples': args.test_k})
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()
        for (val, idx, mask, y, _, cconv_graph) in train_loader:
            optimizer.zero_grad()
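            # args.train_k sets the number of importance samples; args.ts
            # and args.kl weight the ELBO's log-likelihood and KL terms.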
            loss, _, _, loss_info = pvae(val, idx, mask, y, cconv_graph,
                                         args.train_k, args.ts, args.kl)
            loss.backward()
            optimizer.step()

            if ema:
                ema.update()

            for loss_name, loss_val in loss_info.items():
                loss_breakdown[loss_name] += loss_val

        if scheduler:
            scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            if ema:
                ema.apply()
                evaluator.evaluate(epoch)
                ema.restore()
            else:
                evaluator.evaluate(epoch)

        model_dict = {
            'pvae': pvae.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
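
The EMA, mkdir, and count_parameters helpers are likewise external to these
scripts. Below are minimal sketches matching their usage above: EMA(model,
decay, start) with update() after each optimizer step, apply()/restore()
around evaluation, and state_dict() for checkpointing. Note that the scripts
pass a start *epoch*, which this sketch approximates by counting update()
calls; the rest of the bookkeeping is an assumption.

import torch


def mkdir(path):
    """Create the directory (and parents) if needed and return its path."""
    path.mkdir(parents=True, exist_ok=True)
    return path


def count_parameters(module):
    """Count the trainable parameters of a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


class EMA:
    """Exponential moving average of model parameters (minimal sketch)."""

    def __init__(self, model, decay, start=0):
        self.model = model
        self.decay = decay
        self.start = start
        self.updates = 0
        self.shadow = {n: p.detach().clone()
                       for n, p in model.named_parameters()}
        self.backup = {}

    @torch.no_grad()
    def update(self):
        self.updates += 1
        if self.updates < self.start:
            return
        for n, p in self.model.named_parameters():
            self.shadow[n].mul_(self.decay).add_(p, alpha=1 - self.decay)

    def apply(self):
        # Swap the averaged weights in for evaluation; restore() undoes this.
        self.backup = {n: p.detach().clone()
                       for n, p in self.model.named_parameters()}
        with torch.no_grad():
            for n, p in self.model.named_parameters():
                p.copy_(self.shadow[n])

    def restore(self):
        with torch.no_grad():
            for n, p in self.model.named_parameters():
                p.copy_(self.backup[n])

    def state_dict(self):
        return {'decay': self.decay, 'updates': self.updates,
                'shadow': self.shadow}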
Example #3

def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=500,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--lr',
                        type=float,
                        default=2e-4,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr',
                        type=float,
                        default=3e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=1e-4,
                        help='min encoder/decoder learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr',
                        type=float,
                        default=1.5e-4,
                        help='min discriminator learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls',
                        type=float,
                        default=1,
                        help='classification weight')
    parser.add_argument('--clsdep',
                        type=int,
                        default=1,
                        help='number of layers for classifier')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=1,
                        help='AUC evaluation interval. '
                        '0 to disable evaluation.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae',
                        type=float,
                        default=1,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss',
                        default='mse',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--dec-ch',
                        default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch',
                        default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--dis-ch',
                        default=None,
                        help='discriminator architecture. Use encoder '
                        'architecture if unspecified.')
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, normalize continuous convolutional '
                        'layer using mean pooling')
    parser.add_argument('--no-cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconv-ref',
                        type=int,
                        default=98,
                        help='number of evenly-spaced reference locations '
                        'for continuous convolutional layer')
    parser.add_argument('--dec-ref',
                        type=int,
                        default=128,
                        help='number of evenly-spaced reference locations '
                        'for decoder')
    parser.add_argument('--trans',
                        type=int,
                        default=2,
                        help='number of encoder layers')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=0,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd',
                        type=float,
                        default=1,
                        help='MMD strength for latent variable')

    args = parser.parse_args()

    nz = args.nz

    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    time_loader = DataLoader(train_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True,
                             collate_fn=train_dataset.collate_fn)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch_size,
                            shuffle=False,
                            collate_fn=val_dataset.collate_fn)

    test_loader = DataLoader(test_dataset,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             collate_fn=test_dataset.collate_fn)

    in_channels, seq_len = train_dataset.data.shape[1:]

    if args.dis_ch is None:
        args.dis_ch = args.enc_ch

    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]
    dis_channels = [int(c) for c in args.dis_ch.split('-')]

    out_channels = enc_channels[0]

    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    dec_ch_up = 2**(len(dec_channels) - 2)
    assert args.dec_ref % dec_ch_up == 0, (
        f'--dec-ref={args.dec_ref} is not divisible by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)

    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)
    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=overlap,
                             kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.trans).to(device)

    classifier = Classifier(nz, args.clsdep).to(device)

    pbigan = PBiGAN(encoder, decoder, classifier,
                    ae_loss=args.aeloss).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    critic_cconv = ContinuousConv1D(in_channels,
                                    out_channels,
                                    max_time,
                                    cconv_ref,
                                    overlap_rate=overlap,
                                    kernel_size=args.comp,
                                    norm=args.cconv_norm).to(device)
    critic_embed = 32
    critic = ConvCritic(critic_cconv, nz, dis_channels,
                        critic_embed).to(device)

    optimizer = optim.Adam(pbigan.parameters(),
                           lr=args.lr,
                           betas=(0, .999),
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.dis_lr,
                                  betas=(0, .999),
                                  weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'mimic3-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)
    with (log_dir / 'params.txt').open('w') as f:

        def print_params_count(module, name):
            try:  # sum counts if module is a list
                params_count = sum(count_parameters(m) for m in module)
            except TypeError:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)

        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pbigan, 'pbigan')
        print_params_count(critic, 'critic')
        print_params_count([pbigan, critic], 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pbigan, val_loader, test_loader, log_dir)
    start = time.time()
    epoch_start = start

    batch_size = args.batch_size

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()

        # Presumably a warm-up schedule: raise the classification weight
        # once the adversarial training has stabilized.
        if epoch >= 40:
            args.cls = 200

        for ((val, idx, mask, y, _, cconv_graph),
             (_, idx_t, mask_t, _, index, _)) in zip(train_loader,
                                                     time_loader):

            z_enc, x_recon, z_gen, x_gen, ae_loss, cls_loss = pbigan(
                val, idx, mask, y, cconv_graph, idx_t, mask_t)

            cconv_graph_gen = train_dataset.make_graph(x_gen, idx_t, mask_t,
                                                       index)

            # Don't need pbigan.requires_grad_(False);
            # critic takes as input only the detached tensors.
            real = critic(cconv_graph, batch_size, z_enc.detach())
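            # cconv_graph_gen presumably carries the generated values at
            # position 2; detach them so the critic update cannot reach
            # the decoder.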
            detached_graph = [[cat_y.detach() for cat_y in x] if i == 2 else x
                              for i, x in enumerate(cconv_graph_gen)]
            fake = critic(detached_graph, batch_size, z_gen.detach())

            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
            D_loss.backward()
            critic_optimizer.step()

            # Freeze the critic for the generator/encoder update below.
            for p in critic.parameters():
                p.requires_grad_(False)
            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)

            G_loss = gan_loss(real, fake, 0, 1)

            mmd_loss = mmd(z_enc, z_gen)

            loss = (G_loss + ae_loss * args.ae + cls_loss * args.cls +
                    mmd_loss * args.mmd)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Unfreeze the critic for its next update.
            for p in critic.parameters():
                p.requires_grad_(True)

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['CLS'] += cls_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            if ema:
                ema.apply()
                evaluator.evaluate(epoch)
                ema.restore()
            else:
                evaluator.evaluate(epoch)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
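
Finally, gan_loss and mmd are imported from the project's utility code.
Here are sketches consistent with their call sites, assuming the critic
returns raw logits and that mmd is a kernel maximum mean discrepancy between
the encoded and prior latent samples; the Gaussian-kernel mixture and its
bandwidths are assumptions.

import torch
import torch.nn.functional as F


def gan_loss(real, fake, real_target, fake_target):
    """Binary cross-entropy GAN loss on critic logits.

    Called as gan_loss(real, fake, 1, 0) for the critic update and
    gan_loss(real, fake, 0, 1) for the encoder/decoder update.
    """
    real_loss = F.binary_cross_entropy_with_logits(
        real, torch.full_like(real, float(real_target)))
    fake_loss = F.binary_cross_entropy_with_logits(
        fake, torch.full_like(fake, float(fake_target)))
    return real_loss + fake_loss


def mmd(x, y, sigmas=(1., 2., 4., 8.)):
    """Biased MMD estimate between samples x and y with a mixture of
    Gaussian kernels; the bandwidths here are illustrative."""
    def kernel(a, b):
        d2 = torch.cdist(a, b).pow(2)
        return sum(torch.exp(-d2 / (2 * s ** 2)) for s in sigmas)
    return (kernel(x, x).mean() + kernel(y, y).mean()
            - 2 * kernel(x, y).mean())

The reversed targets in the second gan_loss call give the symmetric BiGAN
objective: the critic pushes encoded pairs toward 1 and generated pairs
toward 0, while the encoder/decoder update does the opposite.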