Example #1
def main():
    args = get_args()
    torch.manual_seed(args.seed)

    shape = (224, 224, 3)
    """ define dataloader """
    train_loader, valid_loader, test_loader = make_dataloader(args)
    """ define model architecture """
    model = get_model(args, shape, args.num_classes)

    if torch.cuda.device_count() >= 1:
        print('Model pushed to {} GPU(s), type {}.'.format(
            torch.cuda.device_count(), torch.cuda.get_device_name(0)))
        model = model.cuda()
    else:
        raise ValueError('CPU training is not supported')
    """ define loss criterion """
    criterion = nn.CrossEntropyLoss().cuda()
    """ define optimizer """
    optimizer = make_optimizer(args, model)
    """ define learning rate scheduler """
    scheduler = make_scheduler(args, optimizer)
    """ define trainer, evaluator, result_dictionary """
    result_dict = {
        'args': vars(args),
        'epoch': [],
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'test_acc': []
    }
    trainer = Trainer(model, criterion, optimizer, scheduler)
    evaluator = Evaluator(model, criterion)

    if args.evaluate:
        """ load model checkpoint """
        model.load()
        result_dict = evaluator.test(test_loader, args, result_dict)
    else:
        evaluator.save(result_dict)
        """ define training loop """
        for epoch in range(args.epochs):
            result_dict['epoch'] = epoch
            result_dict = trainer.train(train_loader, epoch, args, result_dict)
            result_dict = evaluator.evaluate(valid_loader, epoch, args,
                                             result_dict)
            evaluator.save(result_dict)
            plot_learning_curves(result_dict, epoch, args)

        result_dict = evaluator.test(test_loader, args, result_dict)
        evaluator.save(result_dict)
        """ save model checkpoint """
        model.save()

    print(result_dict)
    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.discriminator = discriminator.Discriminator(args, gan_type)
        if gan_type != 'WGAN_GP':
            self.optimizer = utils.make_optimizer(args, self.discriminator)
        else:
            self.optimizer = optim.Adam(self.discriminator.parameters(),
                                        betas=(0, 0.9),
                                        eps=1e-8,
                                        lr=1e-5)
        self.scheduler = utils.make_scheduler(args, self.optimizer)
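
The snippets above call make_optimizer and make_scheduler without showing them. A minimal sketch of what such helpers might look like, assuming hypothetical args fields (optimizer, lr, momentum, weight_decay, step_size) that are not part of the original examples:

import torch.optim as optim
from torch.optim import lr_scheduler


def make_optimizer(args, model):
    # Dispatch on a hypothetical args.optimizer field; SGD and Adam cover
    # the cases the examples on this page appear to rely on.
    if args.optimizer == 'sgd':
        return optim.SGD(model.parameters(), lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=args.weight_decay)
    return optim.Adam(model.parameters(), lr=args.lr,
                      weight_decay=args.weight_decay)


def make_scheduler(args, optimizer):
    # A step decay is the simplest plausible choice; the real helpers
    # in these repositories may differ.
    return lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)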
def runExperiment():
    seed = int(cfg['model_tag'].split('_')[0])
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
    process_dataset(dataset)
    model = eval('models.{}(model_rate=cfg["global_model_rate"]).to(cfg["device"])'.format(cfg['model_name']))
    optimizer = make_optimizer(model, cfg['lr'])
    scheduler = make_scheduler(optimizer)
    if cfg['resume_mode'] == 1:
        last_epoch, data_split, label_split, model, optimizer, scheduler, logger = resume(model, cfg['model_tag'],
                                                                                          optimizer, scheduler)
    elif cfg['resume_mode'] == 2:
        last_epoch = 1
        _, data_split, label_split, model, _, _, _ = resume(model, cfg['model_tag'])
        current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
        logger_path = 'output/runs/{}_{}'.format(cfg['model_tag'], current_time)
        logger = Logger(logger_path)
    else:
        last_epoch = 1
        data_split, label_split = split_dataset(dataset, cfg['num_users'], cfg['data_split_mode'])
        current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
        logger_path = 'output/runs/train_{}_{}'.format(cfg['model_tag'], current_time)
        logger = Logger(logger_path)
    if data_split is None:
        data_split, label_split = split_dataset(dataset, cfg['num_users'], cfg['data_split_mode'])
    global_parameters = model.state_dict()
    federation = Federation(global_parameters, cfg['model_rate'], label_split)
    for epoch in range(last_epoch, cfg['num_epochs']['global'] + 1):
        logger.safe(True)
        train(dataset['train'], data_split['train'], label_split, federation, model, optimizer, logger, epoch)
        test_model = stats(dataset['train'], model)
        test(dataset['test'], data_split['test'], label_split, test_model, logger, epoch)
        if cfg['scheduler_name'] == 'ReduceLROnPlateau':
            scheduler.step(metrics=logger.mean['train/{}'.format(cfg['pivot_metric'])])
        else:
            scheduler.step()
        logger.safe(False)
        model_state_dict = model.state_dict()
        save_result = {
            'cfg': cfg, 'epoch': epoch + 1, 'data_split': data_split, 'label_split': label_split,
            'model_dict': model_state_dict, 'optimizer_dict': optimizer.state_dict(),
            'scheduler_dict': scheduler.state_dict(), 'logger': logger}
        save(save_result, './output/model/{}_checkpoint.pt'.format(cfg['model_tag']))
        if cfg['pivot'] < logger.mean['test/{}'.format(cfg['pivot_metric'])]:
            cfg['pivot'] = logger.mean['test/{}'.format(cfg['pivot_metric'])]
            shutil.copy('./output/model/{}_checkpoint.pt'.format(cfg['model_tag']),
                        './output/model/{}_best.pt'.format(cfg['model_tag']))
        logger.reset()
    logger.safe(False)
    return
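
runExperiment steps its scheduler differently depending on the scheduler name because PyTorch's ReduceLROnPlateau expects the monitored metric, while other schedulers step without arguments. A standalone illustration of that distinction, with a hypothetical validation loss:

import torch
import torch.optim as optim
from torch.optim import lr_scheduler

model = torch.nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                           factor=0.1, patience=5)

val_loss = 0.42          # hypothetical monitored metric
scheduler.step(val_loss)  # ReduceLROnPlateau needs the metric it monitors;
                          # StepLR, CosineAnnealingLR, etc. use .step() alone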
Example #4
    def __init__(self, args, train_loader, val_loader, model, loss):
        self.args = args
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        if args.pre_train_model != "...":
            self.model.load_state_dict(torch.load(args.pre_train_model))
        self.loss = loss

        # Actually I am not sure what would happen if I specify the optimizer and scheduler in main.py
        self.optimizer = utils.make_optimizer(self.args, self.model)
        if args.pre_train_optimizer != "...":
            self.optimizer.load_state_dict(torch.load(
                args.pre_train_optimizer))
        self.scheduler = utils.make_scheduler(self.args, self.optimizer)
        self.iter = 0
Example #5
        net = get_model(args).to(device)
        #net = ConvNet().to(device)

        # unpack args
        epochs = args.n_epochs
        lr = args.lr
        lmbda = args.lmbda
        first_regularized_epoch = args.first_regularized_epoch
        if first_regularized_epoch < 0:
            first_regularized_epoch = 0

        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=lr,
                                    momentum=0.9,
                                    weight_decay=5e-4)
        scheduler = make_scheduler(args, optimizer)

        criterion = make_criterion(args)
        regularizer = make_regularizer(args)
        regularizer_name = 'None'
        if regularizer:
            regularizer_name = regularizer.name_attribute

        if args.watch:  #((not isinstance(regularizer, LPGrad)) and args.watch):
            wandb.watch(
                net, criterion, log="all",
                log_freq=args.watch_freq)  # track parameters and gradients

        best_test_acc = float('-inf')

        # train_loss_tracker, train_acc_tracker = [], []
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=200,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--train-k',
                        type=int,
                        default=8,
                        help='number of importance weights for training')
    parser.add_argument('--test-k',
                        type=int,
                        default=50,
                        help='number of importance weights for evaluation')
    parser.add_argument('--flow',
                        type=int,
                        default=2,
                        help='number of IAF layers')
    parser.add_argument('--lr',
                        type=float,
                        default=2e-4,
                        help='global learning rate')
    parser.add_argument('--enc-lr',
                        type=float,
                        default=1e-4,
                        help='encoder learning rate')
    parser.add_argument('--dec-lr',
                        type=float,
                        default=1e-4,
                        help='decoder learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=-1,
                        help='min learning rate for LR scheduler. '
                        '-1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-3, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls',
                        type=float,
                        default=200,
                        help='classification weight')
    parser.add_argument('--clsdep',
                        type=int,
                        default=1,
                        help='number of layers for classifier')
    parser.add_argument('--ts',
                        type=float,
                        default=1,
                        help='log-likelihood weight for ELBO')
    parser.add_argument('--kl',
                        type=float,
                        default=.1,
                        help='KL weight for ELBO')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=1,
                        help='AUC evaluation interval. '
                        '0 to disable evaluation.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pvae',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--sigma',
                        type=float,
                        default=.2,
                        help='standard deviation for Gaussian likelihood')
    parser.add_argument('--dec-ch',
                        default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch',
                        default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, normalize continuous convolutional '
                        'layer using mean pooling')
    parser.add_argument('--no-cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconv-ref',
                        type=int,
                        default=98,
                        help='number of evenly-spaced reference locations '
                        'for continuous convolutional layer')
    parser.add_argument('--dec-ref',
                        type=int,
                        default=128,
                        help='number of evenly-spaced reference locations '
                        'for decoder')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=0,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')

    args = parser.parse_args()

    nz = args.nz

    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch_size,
                            shuffle=False,
                            collate_fn=val_dataset.collate_fn)

    test_loader = DataLoader(test_dataset,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             collate_fn=test_dataset.collate_fn)

    in_channels, seq_len = train_dataset.data.shape[1:]

    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]

    out_channels = enc_channels[0]

    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    dec_ch_up = 2**(len(dec_channels) - 2)
    assert args.dec_ref % dec_ch_up == 0, (
        f'--dec-ref={args.dec_ref} is not divisible by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)

    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)

    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=overlap,
                             kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.flow).to(device)

    classifier = Classifier(nz, args.clsdep).to(device)

    pvae = PVAE(encoder, decoder, classifier, args.sigma, args.cls).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pvae, args.ema_decay, args.ema)

    other_params = [
        param for name, param in pvae.named_parameters()
        if not (name.startswith('decoder.grid_decoder')
                or name.startswith('encoder.grid_encoder'))
    ]
    params = [
        {
            'params': decoder.grid_decoder.parameters(),
            'lr': args.dec_lr
        },
        {
            'params': encoder.grid_encoder.parameters(),
            'lr': args.enc_lr
        },
        {
            'params': other_params
        },
    ]

    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'mimic3-pvae' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)
    with (log_dir / 'params.txt').open('w') as f:

        def print_params_count(module, name):
            try:  # sum counts if module is a list
                params_count = sum(count_parameters(m) for m in module)
            except TypeError:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)

        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pvae, 'pvae')
        print_params_count(pvae, 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pvae,
                          val_loader,
                          test_loader,
                          log_dir,
                          eval_args={'iw_samples': args.test_k})
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()
        for (val, idx, mask, y, _, cconv_graph) in train_loader:
            optimizer.zero_grad()
            loss, _, _, loss_info = pvae(val, idx, mask, y, cconv_graph,
                                         args.train_k, args.ts, args.kl)
            loss.backward()
            optimizer.step()

            if ema:
                ema.update()

            for loss_name, loss_val in loss_info.items():
                loss_breakdown[loss_name] += loss_val

        if scheduler:
            scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            if ema:
                ema.apply()
                evaluator.evaluate(epoch)
                ema.restore()
            else:
                evaluator.evaluate(epoch)

        model_dict = {
            'pvae': pvae.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
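
Several of these examples wrap the model in an EMA helper and call ema.update() after each optimizer step, then ema.apply()/ema.restore() around evaluation. The class itself is not shown; a minimal sketch of the assumed interface (shadow parameters blended in each step, temporarily swapped in for evaluation) follows. Names and behaviour are assumptions, not the original implementation:

import torch


class EMA:
    """Sketch of an exponential-moving-average helper with the
    update/apply/restore interface used above (assumed, not original)."""

    def __init__(self, model, decay, start_epoch=0):
        self.model = model
        self.decay = decay
        self.start_epoch = start_epoch
        self.shadow = {n: p.detach().clone()
                       for n, p in model.named_parameters()}
        self.backup = {}

    def update(self):
        # Blend current parameters into the shadow copy.
        with torch.no_grad():
            for n, p in self.model.named_parameters():
                self.shadow[n].mul_(self.decay).add_(p, alpha=1 - self.decay)

    def apply(self):
        # Swap shadow weights in for evaluation.
        self.backup = {n: p.detach().clone()
                       for n, p in self.model.named_parameters()}
        with torch.no_grad():
            for n, p in self.model.named_parameters():
                p.copy_(self.shadow[n])

    def restore(self):
        # Put the training weights back.
        with torch.no_grad():
            for n, p in self.model.named_parameters():
                p.copy_(self.backup[n])

    def state_dict(self):
        return {'decay': self.decay, 'shadow': self.shadow}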
def train_pbigan(args):
    torch.manual_seed(args.seed)

    if args.mask == 'indep':
        data = IndepMaskedMNIST(obs_prob=args.obs_prob,
                                obs_prob_max=args.obs_prob_max)
        mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}'
    elif args.mask == 'block':
        data = BlockMaskedMNIST(block_len=args.block_len,
                                block_len_max=args.block_len_max)
        mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}'

    data_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)
    mask_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)

    # Evaluate the training progress using 2000 examples from the training data
    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=False)
    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic = ConvCritic(args.latent).to(device)

    lrate = 1e-4
    optimizer = optim.Adam(pbigan.parameters(), lr=lrate, betas=(.5, .9))

    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=lrate,
                                  betas=(.5, .9))

    grad_penalty = GradientPenalty(critic, args.batch_size)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    path = '{}_{}_{}'.format(args.prefix,
                             datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'mnist-pbigan' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[
            logging.FileHandler(output_dir / 'log.txt'),
            logging.StreamHandler(sys.stdout),
        ],
    )

    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir)

    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

    n_critic = 5
    critic_updates = 0
    ae_weight = 0
    ae_flat = 100

    for epoch in range(args.epoch):
        loss_breakdown = defaultdict(float)

        if epoch > ae_flat:
            ae_weight = args.ae * (epoch - ae_flat) / (args.epoch - ae_flat)

        for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader):
            x = x.to(device)
            mask = mask.to(device).float()
            mask_gen = mask_gen.to(device).float()

            z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False)

            real_score = critic((x * mask, z_enc)).mean()
            fake_score = critic((x_gen * mask_gen, z_gen)).mean()

            w_dist = real_score - fake_score
            D_loss = -w_dist + grad_penalty((x * mask, z_enc),
                                            (x_gen * mask_gen, z_gen))

            critic_optimizer.zero_grad()
            D_loss.backward()
            critic_optimizer.step()

            loss_breakdown['D'] += D_loss.item()

            critic_updates += 1

            if critic_updates == n_critic:
                critic_updates = 0

                # Update generators' parameters
                for p in critic.parameters():
                    p.requires_grad_(False)

                z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(x, mask)

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()

                G_loss = real_score - fake_score

                ae_loss = ae_loss * ae_weight
                loss = G_loss + ae_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_breakdown['G'] += G_loss.item()
                loss_breakdown['AE'] += ae_loss.item()
                loss_breakdown['total'] += loss.item()

                for p in critic.parameters():
                    p.requires_grad_(True)

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            with torch.no_grad():
                pbigan.eval()
                z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask)
                pbigan.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
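
This training loop penalizes the critic with a GradientPenalty object that is not defined in the snippet. A simplified single-tensor sketch of the standard WGAN-GP penalty it presumably implements (the real class also handles the latent code in each (x, z) pair, which is omitted here):

import torch
from torch import autograd


def gradient_penalty(critic, real, fake, weight=10.0):
    # Standard WGAN-GP: penalize the critic's gradient norm on random
    # interpolates between real and fake samples (simplified sketch).
    batch_size = real.size(0)
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    score = critic(interp)
    grad = autograd.grad(outputs=score.sum(), inputs=interp,
                         create_graph=True)[0]
    grad_norm = grad.reshape(batch_size, -1).norm(2, dim=1)
    return weight * ((grad_norm - 1) ** 2).mean()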
Example #8
def main():
    parser = argparse.ArgumentParser()

    default_dataset = 'toy-data.npz'
    parser.add_argument('--data', default=default_dataset, help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=1000,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        help='batch size')
    parser.add_argument('--lr',
                        type=float,
                        default=8e-5,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr',
                        type=float,
                        default=1e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=5e-5,
                        help='min encoder/decoder learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr',
                        type=float,
                        default=7e-5,
                        help='min discriminator learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=0, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--no-norm-trans',
                        action='store_true',
                        help='if set, use Gaussian posterior without '
                        'transformation')
    parser.add_argument('--plot-interval',
                        type=int,
                        default=1,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae',
                        type=float,
                        default=.2,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss',
                        default='smooth_l1',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=-1,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd',
                        type=float,
                        default=1,
                        help='MMD strength for latent variable')

    # squash is off when rescale is off
    parser.add_argument('--squash',
                        dest='squash',
                        action='store_const',
                        const=True,
                        default=True,
                        help='bound the generated time series value '
                        'using tanh')
    parser.add_argument('--no-squash',
                        dest='squash',
                        action='store_const',
                        const=False)

    # rescale to [-1, 1]
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)

    args = parser.parse_args()

    batch_size = args.batch_size
    nz = args.nz

    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval

    try:
        npz = np.load(args.data)
        train_data = npz['data']
        train_time = npz['time']
        train_mask = npz['mask']
    except FileNotFoundError:
        if args.data != default_dataset:
            raise
        # Generate the default toy dataset from scratch
        train_data, train_time, train_mask, _, _ = gen_data(
            n_samples=10000,
            seq_len=200,
            max_time=1,
            poisson_rate=50,
            obs_span_rate=.25,
            save_file=default_dataset)

    _, in_channels, seq_len = train_data.shape
    train_time *= train_mask

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Scale time
    max_time = 5
    train_time *= max_time

    squash = None
    rescaler = None
    if args.rescale:
        rescaler = Rescaler(train_data)
        train_data = rescaler.rescale(train_data)
        if args.squash:
            squash = torch.tanh

    out_channels = 64
    cconv_ref = 98

    train_dataset = TimeSeries(train_data,
                               train_time,
                               train_mask,
                               label=None,
                               max_time=max_time,
                               cconv_ref=cconv_ref,
                               overlap_rate=args.overlap,
                               device=device)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    time_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True,
                             collate_fn=train_dataset.collate_fn)

    test_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             collate_fn=train_dataset.collate_fn)

    grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
    decoder = Decoder(grid_decoder, max_time=max_time).to(device)

    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=args.overlap,
                             kernel_size=args.comp,
                             norm=True).to(device)
    encoder = Encoder(cconv, nz, not args.no_norm_trans).to(device)

    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic_cconv = ContinuousConv1D(in_channels,
                                    out_channels,
                                    max_time,
                                    cconv_ref,
                                    overlap_rate=args.overlap,
                                    kernel_size=args.comp,
                                    norm=True).to(device)
    critic = ConvCritic(critic_cconv, nz).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    optimizer = optim.Adam(pbigan.parameters(),
                           lr=args.lr,
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.dis_lr,
                                  weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'toy-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    tracker = Tracker(log_dir, n_train_batch)
    visualizer = Visualizer(encoder, decoder, batch_size, max_time,
                            test_loader, rescaler, output_dir, device)
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)

        for ((val, idx, mask, _, cconv_graph),
             (_, idx_t, mask_t, index, _)) in zip(train_loader, time_loader):

            z_enc, x_recon, z_gen, x_gen, ae_loss = pbigan(
                val, idx, mask, cconv_graph, idx_t, mask_t)

            cconv_graph_gen = train_dataset.make_graph(x_gen, idx_t, mask_t,
                                                       index)

            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)

            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
            D_loss.backward(retain_graph=True)
            critic_optimizer.step()

            G_loss = gan_loss(real, fake, 0, 1)

            mmd_loss = mmd(z_enc, z_gen)

            loss = G_loss + ae_loss * args.ae + mmd_loss * args.mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            if ema:
                ema.apply()
                visualizer.plot(epoch)
                ema.restore()
            else:
                visualizer.plot(epoch)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
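
The toy example relies on a gan_loss helper called as gan_loss(real, fake, 1, 0) for the discriminator update and gan_loss(real, fake, 0, 1) for the generator update. A plausible sketch, assuming the two trailing arguments are target labels for the real and fake critic scores (treated as logits); the actual helper may differ:

import torch
import torch.nn.functional as F


def gan_loss(real_score, fake_score, real_target, fake_target):
    # Binary cross-entropy GAN loss with explicit targets, so the same
    # helper serves both the discriminator and the generator updates.
    real_labels = torch.full_like(real_score, float(real_target))
    fake_labels = torch.full_like(fake_score, float(fake_target))
    return (F.binary_cross_entropy_with_logits(real_score, real_labels) +
            F.binary_cross_entropy_with_logits(fake_score, fake_labels))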
def train_pbigan(args):
    torch.manual_seed(args.seed)

    if args.mask == 'indep':
        data = IndepMaskedCelebA(obs_prob=args.obs_prob)
        mask_str = f'{args.mask}_{args.obs_prob}'
    elif args.mask == 'block':
        data = BlockMaskedCelebA(block_len=args.block_len)
        mask_str = f'{args.mask}_{args.block_len}'

    data_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)
    mask_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)

    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=False)
    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic = ConvCritic(args.latent).to(device)

    optimizer = optim.Adam(pbigan.parameters(), lr=args.lr, betas=(.5, .9))

    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.lr,
                                  betas=(.5, .9))

    grad_penalty = GradientPenalty(critic, args.batch_size)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    path = '{}_{}_{}'.format(args.prefix,
                             datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'celeba-pbigan' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir, loss_xlim=(0, args.epoch))

    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

    n_critic = 5
    critic_updates = 0
    ae_weight = 0

    for epoch in range(args.epoch):
        loss_breakdown = defaultdict(float)

        if epoch >= args.ae_start:
            ae_weight = args.ae

        for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader):
            x = x.to(device)
            mask = mask.to(device).float()
            mask_gen = mask_gen.to(device).float()

            if critic_updates < n_critic:
                z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False)

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()

                w_dist = real_score - fake_score
                D_loss = -w_dist + grad_penalty((x * mask, z_enc),
                                                (x_gen * mask_gen, z_gen))

                critic_optimizer.zero_grad()
                D_loss.backward()
                critic_optimizer.step()

                loss_breakdown['D'] += D_loss.item()

                critic_updates += 1
            else:
                critic_updates = 0

                # Update generators' parameters
                for p in critic.parameters():
                    p.requires_grad_(False)

                z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(x,
                                                             mask,
                                                             ae=(args.ae > 0))

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()

                G_loss = real_score - fake_score

                ae_loss = ae_loss * ae_weight
                loss = G_loss + ae_loss

                mmd_loss = 0
                if args.mmd > 0:
                    mmd_loss = mmd(z_enc, z_gen)
                    loss += mmd_loss * args.mmd

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_breakdown['G'] += G_loss.item()
                if torch.is_tensor(ae_loss):
                    loss_breakdown['AE'] += ae_loss.item()
                if torch.is_tensor(mmd_loss):
                    loss_breakdown['MMD'] += mmd_loss.item()
                loss_breakdown['total'] += loss.item()

                for p in critic.parameters():
                    p.requires_grad_(True)

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            with torch.no_grad():
                pbigan.eval()
                z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask)
                pbigan.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
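
The CelebA training loop mixes in an mmd(z_enc, z_gen) term to keep the encoded and sampled latent distributions close. A minimal RBF-kernel MMD sketch of what such a helper might compute; the kernel and bandwidth choice are assumptions, since the original function is not shown:

import torch


def mmd(x, y, bandwidth=1.0):
    # Biased RBF-kernel MMD^2 estimate between two batches of latent codes.
    def rbf(a, b):
        dist = torch.cdist(a, b).pow(2)
        return torch.exp(-dist / (2 * bandwidth ** 2))

    return rbf(x, x).mean() + rbf(y, y).mean() - 2 * rbf(x, y).mean()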
def train_pvae(args):
    torch.manual_seed(args.seed)

    if args.mask == 'indep':
        data = IndepMaskedMNIST(obs_prob=args.obs_prob,
                                obs_prob_max=args.obs_prob_max)
        mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}'
    elif args.mask == 'block':
        data = BlockMaskedMNIST(block_len=args.block_len,
                                block_len_max=args.block_len_max)
        mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}'

    data_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)

    # Evaluate the training progress using 2000 examples from the training data
    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=True)
    pvae = PVAE(encoder, decoder).to(device)

    optimizer = optim.Adam(pvae.parameters(), lr=args.lr)
    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    rand_z = torch.empty(args.batch_size, args.latent, device=device)

    path = '{}_{}_{}'.format(args.prefix,
                             datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'mnist-pvae' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[
            logging.FileHandler(output_dir / 'log.txt'),
            logging.StreamHandler(sys.stdout),
        ],
    )

    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir)

    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

    kl_center = (args.kl_on + args.kl_off) / 2
    kl_scale = 12 / min(args.kl_on - args.kl_off, 1)

    for epoch in range(args.epoch):
        if epoch >= args.kl_on:
            kl_weight = 1
        elif epoch < args.kl_off:
            kl_weight = 0
        else:
            kl_weight = 1 / (1 + math.exp(-(epoch - kl_center) * kl_scale))
        loss_breakdown = defaultdict(float)
        for x, mask, _ in data_loader:
            x = x.to(device)
            mask = mask.to(device).float()

            optimizer.zero_grad()
            loss, _, _, _, loss_info = pvae(x,
                                            mask,
                                            args.k,
                                            kl_weight=kl_weight)
            loss.backward()
            optimizer.step()
            for name, val in loss_info.items():
                loss_breakdown[name] += val

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            x_recon = pvae.impute(test_x, test_mask, args.k)
            with torch.no_grad():
                pvae.eval()
                rand_z.normal_()
                _, x_gen = decoder(rand_z)
                pvae.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_recon, x_gen)

        model_dict = {
            'pvae': pvae.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
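
Most of the *-pbigan and *-pvae examples build their scheduler as make_scheduler(optimizer, lr, min_lr, epochs) and skip stepping it when the helper returns None. A sketch consistent with that usage, assuming min_lr < 0 disables annealing and the learning rate otherwise decays toward min_lr over the given number of epochs (the actual schedule is not shown in these snippets):

from torch.optim import lr_scheduler


def make_scheduler(optimizer, lr, min_lr, epochs):
    # Return None to disable annealing, matching the `if scheduler:` guards
    # in the training loops above (assumed behaviour, not the original code).
    if min_lr < 0:
        return None
    return lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs,
                                          eta_min=min_lr)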
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=500,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--lr',
                        type=float,
                        default=2e-4,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr',
                        type=float,
                        default=3e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=1e-4,
                        help='min encoder/decoder learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr',
                        type=float,
                        default=1.5e-4,
                        help='min discriminator learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls',
                        type=float,
                        default=1,
                        help='classification weight')
    parser.add_argument('--clsdep',
                        type=int,
                        default=1,
                        help='number of layers for classifier')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=1,
                        help='AUC evaluation interval. '
                        '0 to disable evaluation.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae',
                        type=float,
                        default=1,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss',
                        default='mse',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--dec-ch',
                        default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch',
                        default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--dis-ch',
                        default=None,
                        help='discriminator architecture. Use encoder '
                        'architecture if unspecified.')
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, normalize continuous convolutional '
                        'layer using mean pooling')
    parser.add_argument('--no-cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconv-ref',
                        type=int,
                        default=98,
                        help='number of evenly-spaced reference locations '
                        'for continuous convolutional layer')
    parser.add_argument('--dec-ref',
                        type=int,
                        default=128,
                        help='number of evenly-spaced reference locations '
                        'for decoder')
    parser.add_argument('--trans',
                        type=int,
                        default=2,
                        help='number of encoder layers')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=0,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd',
                        type=float,
                        default=1,
                        help='MMD strength for latent variable')

    args = parser.parse_args()

    nz = args.nz

    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    time_loader = DataLoader(train_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True,
                             collate_fn=train_dataset.collate_fn)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch_size,
                            shuffle=False,
                            collate_fn=val_dataset.collate_fn)

    test_loader = DataLoader(test_dataset,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             collate_fn=test_dataset.collate_fn)

    in_channels, seq_len = train_dataset.data.shape[1:]

    if args.dis_ch is None:
        args.dis_ch = args.enc_ch

    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]
    dis_channels = [int(c) for c in args.dis_ch.split('-')]

    out_channels = enc_channels[0]

    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    dec_ch_up = 2**(len(dec_channels) - 2)
    assert args.dec_ref % dec_ch_up == 0, (
        f'--dec-ref={args.dec_ref} is not divisible by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)

    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)
    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=overlap,
                             kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.trans).to(device)

    classifier = Classifier(nz, args.clsdep).to(device)

    pbigan = PBiGAN(encoder, decoder, classifier,
                    ae_loss=args.aeloss).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    critic_cconv = ContinuousConv1D(in_channels,
                                    out_channels,
                                    max_time,
                                    cconv_ref,
                                    overlap_rate=overlap,
                                    kernel_size=args.comp,
                                    norm=args.cconv_norm).to(device)
    critic_embed = 32
    critic = ConvCritic(critic_cconv, nz, dis_channels,
                        critic_embed).to(device)

    optimizer = optim.Adam(pbigan.parameters(),
                           lr=args.lr,
                           betas=(0, .999),
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.dis_lr,
                                  betas=(0, .999),
                                  weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'mimic3-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)
    with (log_dir / 'params.txt').open('w') as f:

        def print_params_count(module, name):
            try:  # sum counts if module is a list
                params_count = sum(count_parameters(m) for m in module)
            except TypeError:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)

        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pbigan, 'pbigan')
        print_params_count(critic, 'critic')
        print_params_count([pbigan, critic], 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pbigan, val_loader, test_loader, log_dir)
    start = time.time()
    epoch_start = start

    batch_size = args.batch_size

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()

        # After 40 epochs, raise the classification loss weight to 200.
        if epoch >= 40:
            args.cls = 200

        # Each training batch is paired with an independently shuffled batch from
        # time_loader; its time indices and masks (idx_t, mask_t) provide the grids
        # on which generated series are placed.
        for ((val, idx, mask, y, _, cconv_graph),
             (_, idx_t, mask_t, _, index, _)) in zip(train_loader,
                                                     time_loader):

            z_enc, x_recon, z_gen, x_gen, ae_loss, cls_loss = pbigan(
                val, idx, mask, y, cconv_graph, idx_t, mask_t)

            cconv_graph_gen = train_dataset.make_graph(x_gen, idx_t, mask_t,
                                                       index)

            # No need for pbigan.requires_grad_(False): the critic only receives
            # detached tensors in this step.
            real = critic(cconv_graph, batch_size, z_enc.detach())
            # Element 2 of the generated graph holds tensors derived from x_gen; detach
            # them so the critic update does not backpropagate into the generator.
            detached_graph = [[cat_y.detach() for cat_y in x] if i == 2 else x
                              for i, x in enumerate(cconv_graph_gen)]
            fake = critic(detached_graph, batch_size, z_gen.detach())

            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
            D_loss.backward()
            critic_optimizer.step()

            # Freeze the critic so the generator/encoder update only touches PBiGAN
            # parameters.
            for p in critic.parameters():
                p.requires_grad_(False)
            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)

            G_loss = gan_loss(real, fake, 0, 1)

            # MMD between encoded and prior-sampled latent codes regularizes the
            # aggregate latent distribution.
            mmd_loss = mmd(z_enc, z_gen)

            # Total generator objective: adversarial loss plus weighted reconstruction,
            # classification, and MMD terms.
            loss = (G_loss + ae_loss * args.ae + cls_loss * args.cls +
                    mmd_loss * args.mmd)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            for p in critic.parameters():
                p.requires_grad_(True)

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['CLS'] += cls_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            if ema:
                # Evaluate with the averaged weights swapped in, then restore the
                # live weights before training continues.
                ema.apply()
                evaluator.evaluate(epoch)
                ema.restore()
            else:
                evaluator.evaluate(epoch)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        # Overwrite the latest checkpoint every epoch; keep periodic snapshots in
        # model_dir when save_interval is set.
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
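# NOTE: make_scheduler(optimizer, lr, min_lr, epochs) used above is imported from a
# helper module that is not part of this listing. A minimal sketch, assuming it anneals
# the learning rate from lr down to min_lr over the training run and that a negative
# min_lr disables annealing (matching the --min-lr help text further below):
def make_scheduler(optimizer, lr, min_lr, epochs):
    # A negative minimum learning rate disables annealing.
    if min_lr < 0:
        return None
    # lr is already configured on the optimizer; cosine-anneal it down to min_lr.
    return torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs, eta_min=min_lr)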
Example #12
0
def main():
    args = get_args()
    torch.manual_seed(args.seed)

    shape = (224, 224, 3)
    """ define dataloader """
    train_loader, valid_loader, test_loader = make_dataloader(args)
    """ define model architecture """
    model = get_model(args, shape, args.num_classes)

    if torch.cuda.device_count() >= 1:
        print('Model pushed to {} GPU(s), type {}.'.format(
            torch.cuda.device_count(), torch.cuda.get_device_name(0)))
        model = model.cuda()
    else:
        raise ValueError('CPU training is not supported')
    """ define loss criterion """
    criterion = nn.CrossEntropyLoss().cuda()
    """ define optimizer """
    optimizer = make_optimizer(args, model)
    """ define learning rate scheduler """
    scheduler = make_scheduler(args, optimizer)
    """ define loss scaler for automatic mixed precision """
    scaler = torch.cuda.amp.GradScaler()
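    # The scaler is handed to the Trainer below. A typical AMP step (an assumption; the
    # Trainer implementation is not shown in this listing) runs the forward pass under
    # torch.cuda.amp.autocast(), then calls scaler.scale(loss).backward(),
    # scaler.step(optimizer) and scaler.update().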
    """ define trainer, evaluator, result_dictionary """
    result_dict = {
        'args': vars(args),
        'epoch': [],
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'test_acc': []
    }
    trainer = Trainer(model, criterion, optimizer, scheduler, scaler)
    evaluator = Evaluator(model, criterion)

    train_time_list = []
    valid_time_list = []

    if args.evaluate:
        """ load model checkpoint """
        model.load()
        result_dict = evaluator.test(test_loader, args, result_dict)
    else:
        evaluator.save(result_dict)

        best_val_acc = 0.0
        """ define training loop """
        for epoch in range(args.epochs):
            result_dict['epoch'] = epoch

            torch.cuda.synchronize()
            tic1 = time.time()

            result_dict = trainer.train(train_loader, epoch, args, result_dict)

            torch.cuda.synchronize()
            tic2 = time.time()
            train_time_list.append(tic2 - tic1)

            torch.cuda.synchronize()
            tic3 = time.time()

            result_dict = evaluator.evaluate(valid_loader, epoch, args,
                                             result_dict)

            torch.cuda.synchronize()
            tic4 = time.time()
            valid_time_list.append(tic4 - tic3)

            if result_dict['val_acc'][-1] > best_val_acc:
                print("{} epoch, best epoch was updated! {}%".format(
                    epoch, result_dict['val_acc'][-1]))
                best_val_acc = result_dict['val_acc'][-1]
                model.save(checkpoint_name='best_model')

            evaluator.save(result_dict)
            plot_learning_curves(result_dict, epoch, args)

        result_dict = evaluator.test(test_loader, args, result_dict)
        evaluator.save(result_dict)
        """ calculate test accuracy using best model """
        model.load(checkpoint_name='best_model')
        result_dict = evaluator.test(test_loader, args, result_dict)
        evaluator.save(result_dict)

    print(result_dict)

    np.savetxt(os.path.join(model.checkpoint_dir, model.checkpoint_name,
                            'train_time_amp.csv'),
               train_time_list,
               delimiter=',',
               fmt='%s')
    np.savetxt(os.path.join(model.checkpoint_dir, model.checkpoint_name,
                            'valid_time_amp.csv'),
               valid_time_list,
               delimiter=',',
               fmt='%s')
def main():
    args = get_args()
    torch.manual_seed(args.seed)

    shape = (224, 224, 3)
    """ define dataloader """
    train_loader, valid_loader, test_loader = make_dataloader(args)
    """ define model architecture """
    model = get_model(args, shape, args.num_classes)

    if torch.cuda.device_count() >= 1:
        print('Model pushed to {} GPU(s), type {}.'.format(
            torch.cuda.device_count(), torch.cuda.get_device_name(0)))
        model = model.cuda()
    else:
        raise ValueError('CPU training is not supported')
    """ define loss criterion """
    criterion = nn.CrossEntropyLoss().cuda()
    """ define optimizer """
    optimizer = make_optimizer(args, model)
    """ define learning rate scheduler """
    scheduler = make_scheduler(args, optimizer)
    """ define trainer, evaluator, result_dictionary """
    result_dict = {
        'args': vars(args),
        'epoch': [],
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'test_acc': []
    }
    trainer = Trainer(model, criterion, optimizer, scheduler)
    evaluator = Evaluator(model, criterion)

    if args.evaluate:
        """ load model checkpoint """
        model.load("best_model")
        result_dict = evaluator.test(test_loader, args, result_dict, True)

        model.load("last_model")
        result_dict = evaluator.test(test_loader, args, result_dict, False)

    else:
        evaluator.save(result_dict)

        best_val_acc = 0.0
        """ define training loop """
        tolerance = 0
        for epoch in range(args.epochs):
            result_dict['epoch'] = epoch
            result_dict = trainer.train(train_loader, epoch, args, result_dict)
            result_dict = evaluator.evaluate(valid_loader, epoch, args,
                                             result_dict)

            # Early-stopping counter: epochs since the last validation improvement.
            tolerance += 1
            print("tolerance: ", tolerance)

            if result_dict['val_acc'][-1] > best_val_acc:
                tolerance = 0
                print("{} epoch, best epoch was updated! {}%".format(
                    epoch, result_dict['val_acc'][-1]))
                best_val_acc = result_dict['val_acc'][-1]
                model.save(checkpoint_name='best_model')

            evaluator.save(result_dict)
            plot_learning_curves(result_dict, epoch, args)

            # Stop if validation accuracy has not improved for more than 20 epochs.
            if tolerance > 20:
                break

        result_dict = evaluator.test(test_loader, args, result_dict, False)
        evaluator.save(result_dict)
        """ save model checkpoint """
        model.save(checkpoint_name='last_model')
        """ calculate test accuracy using best model """
        model.load(checkpoint_name='best_model')
        result_dict = evaluator.test(test_loader, args, result_dict, True)
        evaluator.save(result_dict)

    print(result_dict)
def main():
    parser = argparse.ArgumentParser()

    default_dataset = 'toy-data.npz'
    parser.add_argument('--data', default=default_dataset,
                        help='data file')
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz', type=int, default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch', type=int, default=1000,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--min-lr', type=float, default=5e-5,
                        help='min learning rate for LR scheduler. '
                             '-1 to disable annealing')
    parser.add_argument('--plot-interval', type=int, default=10,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval', type=int, default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix', default='pvae',
                        help='prefix of output directory')
    parser.add_argument('--comp', type=int, default=5,
                        help='continuous convolution kernel size')
    parser.add_argument('--sigma', type=float, default=.2,
                        help='standard deviation for Gaussian likelihood')
    parser.add_argument('--overlap', type=float, default=.5,
                        help='kernel overlap')
    # squash is off when rescale is off
    parser.add_argument('--squash', dest='squash', action='store_const',
                        const=True, default=True,
                        help='bound the generated time series value '
                             'using tanh')
    parser.add_argument('--no-squash', dest='squash', action='store_const',
                        const=False)

    # rescale to [-1, 1]
    parser.add_argument('--rescale', dest='rescale', action='store_const',
                        const=True, default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale', dest='rescale', action='store_const',
                        const=False)

    args = parser.parse_args()

    batch_size = args.batch_size
    nz = args.nz

    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval

    try:
        npz = np.load(args.data)
        train_data = npz['data']
        train_time = npz['time']
        train_mask = npz['mask']
    except FileNotFoundError:
        if args.data != default_dataset:
            raise
        # Generate the default toy dataset from scratch
        train_data, train_time, train_mask, _, _ = gen_data(
            n_samples=10000, seq_len=200, max_time=1, poisson_rate=50,
            obs_span_rate=.25, save_file=default_dataset)

    _, in_channels, seq_len = train_data.shape
    train_time *= train_mask

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Scale time
    max_time = 5
    train_time *= max_time

    # Rescale the data to [-1, 1]; when squashing is also enabled, bound the decoder
    # output with tanh to match that range.
    squash = None
    rescaler = None
    if args.rescale:
        rescaler = Rescaler(train_data)
        train_data = rescaler.rescale(train_data)
        if args.squash:
            squash = torch.tanh

    out_channels = 64
    cconv_ref = 98

    train_dataset = TimeSeries(
        train_data, train_time, train_mask, label=None, max_time=max_time,
        cconv_ref=cconv_ref, overlap_rate=args.overlap, device=device)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        drop_last=True, collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    test_batch_size = 64
    test_loader = DataLoader(train_dataset, batch_size=test_batch_size,
                             collate_fn=train_dataset.collate_fn)

    grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
    decoder = Decoder(grid_decoder, max_time=max_time).to(device)

    cconv = ContinuousConv1D(
        in_channels, out_channels, max_time, cconv_ref,
        overlap_rate=args.overlap, kernel_size=args.comp, norm=True).to(device)

    encoder = Encoder(nz, cconv).to(device)

    pvae = PVAE(encoder, decoder, sigma=args.sigma).to(device)

    optimizer = optim.Adam(pvae.parameters(), lr=args.lr)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)

    path = '{}_{}_{}'.format(
        args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
        '_'.join([f'lr_{args.lr:g}']))

    output_dir = Path('results') / 'toy-pvae' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    tracker = Tracker(log_dir, n_train_batch)
    visualizer = Visualizer(encoder, decoder, test_batch_size, max_time,
                            test_loader, rescaler, output_dir, device)
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        for val, idx, mask, _, cconv_graph in train_loader:
            optimizer.zero_grad()
            loss = pvae(val, idx, mask, cconv_graph)
            loss.backward()
            optimizer.step()
            loss_breakdown['loss'] += loss.item()

        if scheduler:
            scheduler.step()

        cur_time = time.time()
        tracker.log(
            epoch, loss_breakdown, cur_time - epoch_start, cur_time - start)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            visualizer.plot(epoch)

        model_dict = {
            'pvae': pvae.state_dict(),
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
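# NOTE: mkdir(...) and count_parameters(...) used throughout these listings come from a
# helper module that is not shown. Minimal sketches under that assumption:
def mkdir(path):
    # Create the directory (and any parents) if needed and return it, so calls like
    # mkdir(output_dir / 'log') can be used inline.
    path.mkdir(parents=True, exist_ok=True)
    return path


def count_parameters(module):
    # Number of trainable parameters in a torch.nn.Module.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)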