def train(save_pth, use_mixup, mixup_alpha):
    """Train the model for `n_epochs` epochs and save the final weights.

    Relies on module-level configuration (`ema_alpha`, `batchsize`,
    `n_workers`, `ds_name`, `n_epochs`) and module-level helpers
    (`set_model`, `set_optimizer`, `get_train_loader`, `train_one_epoch`,
    `evaluate`, `save_model`).

    Args:
        save_pth: destination path passed to ``save_model`` after training.
        use_mixup: forwarded to ``train_one_epoch``.
        mixup_alpha: forwarded to ``train_one_epoch``.

    Returns:
        The trained model.
    """
    model, criteria = set_model()
    ema = EMA(model, ema_alpha)
    optimizer, scheduler = set_optimizer(model)
    train_loader = get_train_loader(batch_size=batchsize,
                                    num_workers=n_workers,
                                    dataset=ds_name,
                                    pin_memory=False)
    for epoch in range(n_epochs):
        t_start = time.time()
        avg_loss = train_one_epoch(model, criteria, train_loader, optimizer,
                                   ema, use_mixup, mixup_alpha)
        scheduler.step()
        # Evaluate the raw weights, then the EMA shadow weights; restore the
        # raw weights afterwards so training continues from them.
        accuracy = evaluate(model, verbose=False)
        ema.apply_shadow()
        accuracy_ema = evaluate(model, verbose=False)
        ema.restore()
        elapsed = time.time() - t_start
        current_lr = list(optimizer.param_groups)[0]['lr']
        print('epoch: {}, loss: {:.4f}, lr: {:.4f}, acc: {:.4f}, '
              'acc_ema: {:.4f}, time: {:.2f}'.format(
                  epoch, avg_loss, current_lr, accuracy, accuracy_ema,
                  elapsed))
    save_model(model, save_pth)
    print('done')
    return model
def main():
    """Train a P-VAE with an auxiliary classifier on partially observed
    time series and periodically evaluate/checkpoint it.

    All configuration comes from the command line; results are written under
    ``results/mimic3-pvae/<prefix>_<timestamp>/`` (default data file is
    ``mimic3.npz``). Returns None.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed. Randomly set if not specified.')
    # training options
    parser.add_argument('--nz', type=int, default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch', type=int, default=200,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size', type=int, default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--train-k', type=int, default=8,
                        help='number of importance weights for training')
    parser.add_argument('--test-k', type=int, default=50,
                        help='number of importance weights for evaluation')
    parser.add_argument('--flow', type=int, default=2,
                        help='number of IAF layers')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='global learning rate')
    parser.add_argument('--enc-lr', type=float, default=1e-4,
                        help='encoder learning rate')
    parser.add_argument('--dec-lr', type=float, default=1e-4,
                        help='decoder learning rate')
    parser.add_argument('--min-lr', type=float, default=-1,
                        help='min learning rate for LR scheduler. '
                             '-1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-3,
                        help='weight decay')
    parser.add_argument('--overlap', type=float, default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls', type=float, default=200,
                        help='classification weight')
    parser.add_argument('--clsdep', type=int, default=1,
                        help='number of layers for classifier')
    parser.add_argument('--ts', type=float, default=1,
                        help='log-likelihood weight for ELBO')
    parser.add_argument('--kl', type=float, default=.1,
                        help='KL weight for ELBO')
    parser.add_argument('--eval-interval', type=int, default=1,
                        help='AUC evaluation interval. '
                             '0 to disable evaluation.')
    parser.add_argument('--save-interval', type=int, default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix', default='pvae',
                        help='prefix of output directory')
    parser.add_argument('--comp', type=int, default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--sigma', type=float, default=.2,
                        help='standard deviation for Gaussian likelihood')
    parser.add_argument('--dec-ch', default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch', default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--rescale', dest='rescale', action='store_const',
                        const=True, default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale', dest='rescale', action='store_const',
                        const=False)
    parser.add_argument('--cconvnorm', dest='cconv_norm',
                        action='store_const', const=True, default=True,
                        help='if set, normalize continuous convolutional '
                             'layer using mean pooling')
    parser.add_argument('--no-cconvnorm', dest='cconv_norm',
                        action='store_const', const=False)
    parser.add_argument('--cconv-ref', type=int, default=98,
                        help='number of evenly-spaced reference locations '
                             'for continuous convolutional layer')
    parser.add_argument('--dec-ref', type=int, default=128,
                        help='number of evenly-spaced reference locations '
                             'for decoder')
    parser.add_argument('--ema', dest='ema', type=int, default=0,
                        help='start epoch of exponential moving average '
                             '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay', type=float, default=.9999,
                        help='EMA decay')
    args = parser.parse_args()

    nz = args.nz
    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    # Seed everything. When --seed is omitted, draw a fresh seed so it can
    # still be recorded in seed.txt below for reproducibility.
    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
        rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        drop_last=True, collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)
    val_loader = DataLoader(
        val_dataset, batch_size=args.test_batch_size, shuffle=False,
        collate_fn=val_dataset.collate_fn)
    test_loader = DataLoader(
        test_dataset, batch_size=args.test_batch_size, shuffle=False,
        collate_fn=test_dataset.collate_fn)

    # seq_len is unpacked but not used below in this function.
    in_channels, seq_len = train_dataset.data.shape[1:]

    # Channel specs are dash-separated, e.g. '8-16-16'; the decoder's last
    # layer is forced to match the data's channel count.
    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]
    out_channels = enc_channels[0]

    # Output squashing: sigmoid for data in [0, 1], tanh when rescaled
    # to [-1, 1].
    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    # The grid decoder upsamples by 2 per layer (except the last), so the
    # reference grid size must be divisible by the total upsampling factor.
    dec_ch_up = 2**(len(dec_channels) - 2)
    assert args.dec_ref % dec_ch_up == 0, (
        f'--dec-ref={args.dec_ref} is not divided by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)
    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)

    cconv = ContinuousConv1D(in_channels, out_channels, max_time, cconv_ref,
                             overlap_rate=overlap, kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.flow).to(device)
    classifier = Classifier(nz, args.clsdep).to(device)
    pvae = PVAE(encoder, decoder, classifier, args.sigma,
                args.cls).to(device)

    # EMA is optional; args.ema is the epoch at which averaging starts,
    # -1 disables it.
    ema = None
    if args.ema >= 0:
        ema = EMA(pvae, args.ema_decay, args.ema)

    # Three parameter groups: the grid decoder and grid encoder get their
    # own learning rates; everything else uses the global --lr.
    other_params = [
        param for name, param in pvae.named_parameters()
        if not (name.startswith('decoder.grid_decoder')
                or name.startswith('encoder.grid_encoder'))
    ]
    params = [
        {'params': decoder.grid_decoder.parameters(), 'lr': args.dec_lr},
        {'params': encoder.grid_encoder.parameters(), 'lr': args.enc_lr},
        {'params': other_params},
    ]
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))
    output_dir = Path('results') / 'mimic3-pvae' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    # Record run metadata: seed, visible GPU count, and all arguments.
    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    # Record per-module parameter counts.
    with (log_dir / 'params.txt').open('w') as f:
        def print_params_count(module, name):
            try:
                # sum counts if module is a list
                params_count = sum(count_parameters(m) for m in module)
            except TypeError:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)
        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pvae, 'pvae')
        # NOTE(review): 'total' repeats the pvae count — there is no
        # separate critic in this model, so total == pvae.
        print_params_count(pvae, 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pvae, val_loader, test_loader, log_dir,
                          eval_args={'iw_samples': args.test_k})
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()
        for (val, idx, mask, y, _, cconv_graph) in train_loader:
            optimizer.zero_grad()
            loss, _, _, loss_info = pvae(
                val, idx, mask, y, cconv_graph, args.train_k, args.ts,
                args.kl)
            loss.backward()
            optimizer.step()
            if ema:
                ema.update()
            for loss_name, loss_val in loss_info.items():
                loss_breakdown[loss_name] += loss_val
        if scheduler:
            scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        # Evaluate with EMA shadow weights when EMA is enabled; restore the
        # live weights afterwards so training continues from them.
        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            if ema:
                ema.apply()
                evaluator.evaluate(epoch)
                ema.restore()
            else:
                evaluator.evaluate(epoch)

        # Always overwrite the latest checkpoint; additionally keep
        # per-epoch snapshots at --save-interval.
        model_dict = {
            'pvae': pvae.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
def main():
    """Train a P-BiGAN on the (possibly auto-generated) toy dataset.

    All configuration comes from the command line; results and plots are
    written under ``results/toy-pbigan/<prefix>_<timestamp>/``. If the
    default data file is missing, the toy dataset is generated from scratch.
    Returns None.

    Fix vs. original: ``epoch_start`` is now reset at the top of each epoch
    (as the other training scripts in this project do), so the per-epoch
    duration passed to ``tracker.log`` is the epoch's own wall time rather
    than the time since training started.
    """
    parser = argparse.ArgumentParser()
    default_dataset = 'toy-data.npz'
    parser.add_argument('--data', default=default_dataset, help='data file')
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed. Randomly set if not specified.')
    # training options
    parser.add_argument('--nz', type=int, default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch', type=int, default=1000,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=8e-5,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr', type=float, default=1e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr', type=float, default=5e-5,
                        help='min encoder/decoder learning rate for LR '
                             'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr', type=float, default=7e-5,
                        help='min discriminator learning rate for LR '
                             'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=0, help='weight decay')
    parser.add_argument('--overlap', type=float, default=.5,
                        help='kernel overlap')
    parser.add_argument('--no-norm-trans', action='store_true',
                        help='if set, use Gaussian posterior without '
                             'transformation')
    parser.add_argument('--plot-interval', type=int, default=1,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval', type=int, default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix', default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp', type=int, default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae', type=float, default=.2,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss', default='smooth_l1',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--ema', dest='ema', type=int, default=-1,
                        help='start epoch of exponential moving average '
                             '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay', type=float, default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd', type=float, default=1,
                        help='MMD strength for latent variable')
    # squash is off when rescale is off
    parser.add_argument('--squash', dest='squash', action='store_const',
                        const=True, default=True,
                        help='bound the generated time series value '
                             'using tanh')
    parser.add_argument('--no-squash', dest='squash', action='store_const',
                        const=False)
    # rescale to [-1, 1]
    parser.add_argument('--rescale', dest='rescale', action='store_const',
                        const=True, default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale', dest='rescale', action='store_const',
                        const=False)
    args = parser.parse_args()

    batch_size = args.batch_size
    nz = args.nz
    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval

    # Load the dataset; if the default data file is missing, generate the
    # toy dataset from scratch (and save it for next time).
    try:
        npz = np.load(args.data)
        train_data = npz['data']
        train_time = npz['time']
        train_mask = npz['mask']
    except FileNotFoundError:
        if args.data != default_dataset:
            raise
        # Generate the default toy dataset from scratch
        train_data, train_time, train_mask, _, _ = gen_data(
            n_samples=10000, seq_len=200, max_time=1, poisson_rate=50,
            obs_span_rate=.25, save_file=default_dataset)

    _, in_channels, seq_len = train_data.shape
    # Zero out timestamps of unobserved entries.
    train_time *= train_mask

    # Seed everything. When --seed is omitted, draw a fresh seed so it can
    # still be recorded in seed.txt below for reproducibility.
    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
        rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Scale time
    max_time = 5
    train_time *= max_time

    # Tanh squashing only makes sense on rescaled data, so it is gated on
    # --rescale ("squash is off when rescale is off").
    squash = None
    rescaler = None
    if args.rescale:
        rescaler = Rescaler(train_data)
        train_data = rescaler.rescale(train_data)
        if args.squash:
            squash = torch.tanh

    out_channels = 64
    cconv_ref = 98

    train_dataset = TimeSeries(
        train_data, train_time, train_mask, label=None, max_time=max_time,
        cconv_ref=cconv_ref, overlap_rate=args.overlap, device=device)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True,
        collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    # Second, independently shuffled pass over the same dataset; zipped with
    # train_loader below to pair each data batch with an independent batch
    # of observation times for generation.
    time_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True,
        collate_fn=train_dataset.collate_fn)

    test_loader = DataLoader(train_dataset, batch_size=batch_size,
                             collate_fn=train_dataset.collate_fn)

    grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
    decoder = Decoder(grid_decoder, max_time=max_time).to(device)

    cconv = ContinuousConv1D(in_channels, out_channels, max_time, cconv_ref,
                             overlap_rate=args.overlap, kernel_size=args.comp,
                             norm=True).to(device)
    encoder = Encoder(cconv, nz, not args.no_norm_trans).to(device)

    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    # The critic gets its own continuous-convolution front end.
    critic_cconv = ContinuousConv1D(in_channels, out_channels, max_time,
                                    cconv_ref, overlap_rate=args.overlap,
                                    kernel_size=args.comp,
                                    norm=True).to(device)
    critic = ConvCritic(critic_cconv, nz).to(device)

    # EMA is optional; args.ema is the epoch at which averaging starts,
    # -1 disables it.
    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    optimizer = optim.Adam(pbigan.parameters(), lr=args.lr,
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(), lr=args.dis_lr,
                                  weight_decay=args.wd)
    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))
    output_dir = Path('results') / 'toy-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    # Record run metadata: seed, visible GPU count, and all arguments.
    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    tracker = Tracker(log_dir, n_train_batch)
    visualizer = Visualizer(encoder, decoder, batch_size, max_time,
                            test_loader, rescaler, output_dir, device)
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        # FIX: reset the per-epoch timer so tracker.log reports this epoch's
        # duration, not the time since training started (matches the other
        # training scripts in this project).
        epoch_start = time.time()
        for ((val, idx, mask, _, cconv_graph),
             (_, idx_t, mask_t, index, _)) in zip(train_loader, time_loader):
            z_enc, x_recon, z_gen, x_gen, ae_loss = pbigan(
                val, idx, mask, cconv_graph, idx_t, mask_t)
            cconv_graph_gen = train_dataset.make_graph(
                x_gen, idx_t, mask_t, index)

            # Critic step. retain_graph=True keeps the graph alive because
            # the same real/fake scores are reused for the generator loss.
            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)
            D_loss = gan_loss(real, fake, 1, 0)
            critic_optimizer.zero_grad()
            D_loss.backward(retain_graph=True)
            critic_optimizer.step()

            # Encoder/decoder step: GAN loss + autoencoding + latent MMD.
            G_loss = gan_loss(real, fake, 0, 1)
            mmd_loss = mmd(z_enc, z_gen)
            loss = G_loss + ae_loss * args.ae + mmd_loss * args.mmd
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        # Plot with EMA shadow weights when EMA is enabled; restore the live
        # weights afterwards so training continues from them.
        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            if ema:
                ema.apply()
                visualizer.plot(epoch)
                ema.restore()
            else:
                visualizer.plot(epoch)

        # Always overwrite the latest checkpoint; additionally keep
        # per-epoch snapshots at --save-interval.
        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
def main():
    """Train a P-BiGAN with an auxiliary classifier on partially observed
    time series and periodically evaluate/checkpoint it.

    All configuration comes from the command line; results are written under
    ``results/mimic3-pbigan/<prefix>_<timestamp>/`` (default data file is
    ``mimic3.npz``). Returns None.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed. Randomly set if not specified.')
    # training options
    parser.add_argument('--nz', type=int, default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch', type=int, default=500,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size', type=int, default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr', type=float, default=3e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr', type=float, default=1e-4,
                        help='min encoder/decoder learning rate for LR '
                             'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr', type=float, default=1.5e-4,
                        help='min discriminator learning rate for LR '
                             'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--overlap', type=float, default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls', type=float, default=1,
                        help='classification weight')
    parser.add_argument('--clsdep', type=int, default=1,
                        help='number of layers for classifier')
    parser.add_argument('--eval-interval', type=int, default=1,
                        help='AUC evaluation interval. '
                             '0 to disable evaluation.')
    parser.add_argument('--save-interval', type=int, default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix', default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp', type=int, default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae', type=float, default=1,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss', default='mse',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--dec-ch', default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch', default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--dis-ch', default=None,
                        help='discriminator architecture. Use encoder '
                             'architecture if unspecified.')
    parser.add_argument('--rescale', dest='rescale', action='store_const',
                        const=True, default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale', dest='rescale', action='store_const',
                        const=False)
    parser.add_argument('--cconvnorm', dest='cconv_norm',
                        action='store_const', const=True, default=True,
                        help='if set, normalize continuous convolutional '
                             'layer using mean pooling')
    parser.add_argument('--no-cconvnorm', dest='cconv_norm',
                        action='store_const', const=False)
    parser.add_argument('--cconv-ref', type=int, default=98,
                        help='number of evenly-spaced reference locations '
                             'for continuous convolutional layer')
    parser.add_argument('--dec-ref', type=int, default=128,
                        help='number of evenly-spaced reference locations '
                             'for decoder')
    parser.add_argument('--trans', type=int, default=2,
                        help='number of encoder layers')
    parser.add_argument('--ema', dest='ema', type=int, default=0,
                        help='start epoch of exponential moving average '
                             '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay', type=float, default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd', type=float, default=1,
                        help='MMD strength for latent variable')
    args = parser.parse_args()

    nz = args.nz
    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    # Seed everything. When --seed is omitted, draw a fresh seed so it can
    # still be recorded in seed.txt below for reproducibility.
    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
        rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        drop_last=True, collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)
    # Second, independently shuffled pass over the training set; zipped with
    # train_loader below to pair each data batch with an independent batch
    # of observation times for generation.
    time_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        drop_last=True, collate_fn=train_dataset.collate_fn)
    val_loader = DataLoader(
        val_dataset, batch_size=args.test_batch_size, shuffle=False,
        collate_fn=val_dataset.collate_fn)
    test_loader = DataLoader(
        test_dataset, batch_size=args.test_batch_size, shuffle=False,
        collate_fn=test_dataset.collate_fn)

    # seq_len is unpacked but not used below in this function.
    in_channels, seq_len = train_dataset.data.shape[1:]

    # Discriminator architecture defaults to the encoder architecture.
    if args.dis_ch is None:
        args.dis_ch = args.enc_ch

    # Channel specs are dash-separated, e.g. '8-16-16'; the decoder's last
    # layer is forced to match the data's channel count.
    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]
    dis_channels = [int(c) for c in args.dis_ch.split('-')]
    out_channels = enc_channels[0]

    # Output squashing: sigmoid for data in [0, 1], tanh when rescaled
    # to [-1, 1].
    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    # The grid decoder upsamples by 2 per layer (except the last), so the
    # reference grid size must be divisible by the total upsampling factor.
    dec_ch_up = 2**(len(dec_channels) - 2)
    assert args.dec_ref % dec_ch_up == 0, (
        f'--dec-ref={args.dec_ref} is not divided by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)
    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)

    cconv = ContinuousConv1D(in_channels, out_channels, max_time, cconv_ref,
                             overlap_rate=overlap, kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.trans).to(device)
    classifier = Classifier(nz, args.clsdep).to(device)
    pbigan = PBiGAN(encoder, decoder, classifier,
                    ae_loss=args.aeloss).to(device)

    # EMA is optional; args.ema is the epoch at which averaging starts,
    # -1 disables it.
    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    # The critic gets its own continuous-convolution front end.
    critic_cconv = ContinuousConv1D(in_channels, out_channels, max_time,
                                    cconv_ref, overlap_rate=overlap,
                                    kernel_size=args.comp,
                                    norm=args.cconv_norm).to(device)
    critic_embed = 32
    critic = ConvCritic(critic_cconv, nz, dis_channels,
                        critic_embed).to(device)

    optimizer = optim.Adam(
        pbigan.parameters(), lr=args.lr, betas=(0, .999),
        weight_decay=args.wd)
    critic_optimizer = optim.Adam(
        critic.parameters(), lr=args.dis_lr, betas=(0, .999),
        weight_decay=args.wd)
    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))
    output_dir = Path('results') / 'mimic3-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    # Record run metadata: seed, visible GPU count, and all arguments.
    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    # Record per-module parameter counts.
    with (log_dir / 'params.txt').open('w') as f:
        def print_params_count(module, name):
            try:
                # sum counts if module is a list
                params_count = sum(count_parameters(m) for m in module)
            except TypeError:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)
        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pbigan, 'pbigan')
        print_params_count(critic, 'critic')
        print_params_count([pbigan, critic], 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pbigan, val_loader, test_loader, log_dir)
    start = time.time()
    epoch_start = start
    batch_size = args.batch_size

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()
        # NOTE(review): hardcoded schedule — after epoch 40 the
        # classification weight jumps to 200, silently overriding --cls.
        if epoch >= 40:
            args.cls = 200
        for ((val, idx, mask, y, _, cconv_graph),
             (_, idx_t, mask_t, _, index, _)) in zip(train_loader,
                                                     time_loader):
            z_enc, x_recon, z_gen, x_gen, ae_loss, cls_loss = pbigan(
                val, idx, mask, y, cconv_graph, idx_t, mask_t)
            cconv_graph_gen = train_dataset.make_graph(
                x_gen, idx_t, mask_t, index)

            # Critic step on detached inputs so no gradient reaches pbigan.
            # Don't need pbigan.requires_grad_(False);
            # critic takes as input only the detached tensors.
            real = critic(cconv_graph, batch_size, z_enc.detach())
            # Only element 2 of the generated graph holds tensors that
            # require grad; detach those and pass the rest through.
            detached_graph = [[cat_y.detach() for cat_y in x]
                              if i == 2 else x
                              for i, x in enumerate(cconv_graph_gen)]
            fake = critic(detached_graph, batch_size, z_gen.detach())
            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
            D_loss.backward()
            critic_optimizer.step()

            # Generator/encoder step: freeze the critic's parameters so the
            # backward pass only updates pbigan, then unfreeze afterwards.
            for p in critic.parameters():
                p.requires_grad_(False)
            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)
            G_loss = gan_loss(real, fake, 0, 1)
            mmd_loss = mmd(z_enc, z_gen)
            loss = (G_loss + ae_loss * args.ae + cls_loss * args.cls
                    + mmd_loss * args.mmd)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            for p in critic.parameters():
                p.requires_grad_(True)

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['CLS'] += cls_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        # Evaluate with EMA shadow weights when EMA is enabled; restore the
        # live weights afterwards so training continues from them.
        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            if ema:
                ema.apply()
                evaluator.evaluate(epoch)
                ema.restore()
            else:
                evaluator.evaluate(epoch)

        # Always overwrite the latest checkpoint; additionally keep
        # per-epoch snapshots at --save-interval.
        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)