class RSI(object):
    """Wilder-style Relative Strength Index (RSI).

    Feeds the magnitudes of upward (U) and downward (D) moves into two
    separate EMAs and derives RSI = 100 - 100 / (1 + RS) where
    RS = EMA(U) / EMA(D).

    Attributes:
        value: latest RSI in [0, 100], or None before the first update().
        last:  previous observation passed to update().
        ema_u: EMA of upward moves.
        ema_d: EMA of downward moves.
        tbl:   optional PyTables table for persisting results (see setupH5).
    """

    def __init__(self, period):
        # period: smoothing period forwarded to both EMAs.
        self.value = None
        self.last = None
        self.ema_u = EMA(period)
        self.ema_d = EMA(period)
        self.tbl = None

    def setupH5(self, h5file, h5where, h5name):
        """Attach a PyTables table so every update() appends a (date, value) row.

        All three arguments must be non-None; otherwise no table is created.
        """
        # Fixed: use identity comparison with None (`is not None`) instead
        # of `!= None`, per Python idiom (PEP 8).
        if h5file is not None and h5where is not None and h5name is not None:
            self.tbl = h5file.createTable(h5where, h5name, RSIData)

    def update(self, value, date=None):
        """Consume a new observation and return the current RSI.

        On the very first call both U and D are 0, so the down-EMA is 0 and
        the method returns 100.0.

        Args:
            value: new price/level observation.
            date: optional datetime-like; when set (and a table is attached)
                the (ordinal date, RSI) pair is appended to the HDF5 table.

        Returns:
            The updated RSI value (float in [0, 100]).
        """
        if self.last is None:
            self.last = value
        U = value - self.last
        D = self.last - value
        self.last = value
        # Exactly one of U/D can be a genuine (positive) move; the other is
        # its negation and is zeroed out.
        if U > 0:
            D = 0
        elif D > 0:
            U = 0
        self.ema_u.update(U)
        self.ema_d.update(D)
        if self.ema_d.value == 0:
            # No downward movement yet: avoid division by zero, RSI saturates
            # at 100.
            self.value = 100.0
        else:
            rs = self.ema_u.value / self.ema_d.value
            self.value = 100.0 - (100.0 / (1 + rs))
        if self.tbl is not None and date:
            self.tbl.row["date"] = date.date().toordinal()
            self.tbl.row["value"] = self.value
            self.tbl.row.append()
            self.tbl.flush()
        return self.value
def main():
    """Train a P-VAE with classifier on the MIMIC-III time-series dataset.

    Parses CLI options, builds train/val/test loaders, constructs the
    encoder/decoder/classifier, then runs the training loop with optional
    EMA of weights, periodic AUC evaluation and checkpointing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed. Randomly set if not specified.')
    # training options
    parser.add_argument('--nz', type=int, default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch', type=int, default=200,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=64, help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size', type=int, default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--train-k', type=int, default=8,
                        help='number of importance weights for training')
    parser.add_argument('--test-k', type=int, default=50,
                        help='number of importance weights for evaluation')
    parser.add_argument('--flow', type=int, default=2,
                        help='number of IAF layers')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='global learning rate')
    parser.add_argument('--enc-lr', type=float, default=1e-4,
                        help='encoder learning rate')
    parser.add_argument('--dec-lr', type=float, default=1e-4,
                        help='decoder learning rate')
    parser.add_argument('--min-lr', type=float, default=-1,
                        help='min learning rate for LR scheduler. '
                             '-1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-3, help='weight decay')
    parser.add_argument('--overlap', type=float, default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls', type=float, default=200,
                        help='classification weight')
    parser.add_argument('--clsdep', type=int, default=1,
                        help='number of layers for classifier')
    parser.add_argument('--ts', type=float, default=1,
                        help='log-likelihood weight for ELBO')
    parser.add_argument('--kl', type=float, default=.1,
                        help='KL weight for ELBO')
    parser.add_argument('--eval-interval', type=int, default=1,
                        help='AUC evaluation interval. '
                             '0 to disable evaluation.')
    parser.add_argument('--save-interval', type=int, default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix', default='pvae',
                        help='prefix of output directory')
    parser.add_argument('--comp', type=int, default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--sigma', type=float, default=.2,
                        help='standard deviation for Gaussian likelihood')
    parser.add_argument('--dec-ch', default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch', default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--rescale', dest='rescale', action='store_const',
                        const=True, default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale', dest='rescale', action='store_const',
                        const=False)
    parser.add_argument('--cconvnorm', dest='cconv_norm',
                        action='store_const', const=True, default=True,
                        help='if set, normalize continuous convolutional '
                             'layer using mean pooling')
    parser.add_argument('--no-cconvnorm', dest='cconv_norm',
                        action='store_const', const=False)
    parser.add_argument('--cconv-ref', type=int, default=98,
                        help='number of evenly-spaced reference locations '
                             'for continuous convolutional layer')
    parser.add_argument('--dec-ref', type=int, default=128,
                        help='number of evenly-spaced reference locations '
                             'for decoder')
    parser.add_argument('--ema', dest='ema', type=int, default=0,
                        help='start epoch of exponential moving average '
                             '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay', type=float, default=.9999,
                        help='EMA decay')
    args = parser.parse_args()

    nz = args.nz

    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    # Seed everything; when --seed is omitted, draw one at random and log it
    # (seed.txt below) so the run is reproducible after the fact.
    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
        rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Data: split into train/val/test and build loaders. drop_last keeps the
    # training batch size constant for the importance-weighted objective.
    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        drop_last=True, collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    val_loader = DataLoader(val_dataset, batch_size=args.test_batch_size,
                            shuffle=False, collate_fn=val_dataset.collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size,
                             shuffle=False, collate_fn=test_dataset.collate_fn)

    in_channels, seq_len = train_dataset.data.shape[1:]

    # Channel specs come in as dash-separated strings, e.g. '8-16-16'.
    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]
    out_channels = enc_channels[0]

    # tanh when values were rescaled to [-1, 1], sigmoid otherwise.
    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    # The decoder upsamples by 2x per layer (after the first), so the
    # reference grid length must be divisible by the total upsampling factor.
    dec_ch_up = 2**(len(dec_channels) - 2)
    assert args.dec_ref % dec_ch_up == 0, (
        f'--dec-ref={args.dec_ref} is not divided by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)
    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)
    cconv = ContinuousConv1D(in_channels, out_channels, max_time, cconv_ref,
                             overlap_rate=overlap, kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.flow).to(device)
    classifier = Classifier(nz, args.clsdep).to(device)
    pvae = PVAE(encoder, decoder, classifier, args.sigma, args.cls).to(device)

    # EMA of model weights, enabled from epoch args.ema onward (-1 disables).
    ema = None
    if args.ema >= 0:
        ema = EMA(pvae, args.ema_decay, args.ema)

    # Per-module learning rates: grid decoder/encoder get their own LRs,
    # everything else uses the global --lr.
    other_params = [
        param for name, param in pvae.named_parameters()
        if not (name.startswith('decoder.grid_decoder')
                or name.startswith('encoder.grid_encoder'))
    ]
    params = [
        {'params': decoder.grid_decoder.parameters(), 'lr': args.dec_lr},
        {'params': encoder.grid_encoder.parameters(), 'lr': args.enc_lr},
        {'params': other_params},
    ]
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)

    # Output directory: results/mimic3-pvae/<prefix>_<timestamp>/
    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))
    output_dir = Path('results') / 'mimic3-pvae' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    # Record seed, GPU count and full argument list for reproducibility.
    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    with (log_dir / 'params.txt').open('w') as f:
        def print_params_count(module, name):
            # Log the parameter count of a module (or a list of modules).
            try:   # sum counts if module is a list
                params_count = sum(count_parameters(m) for m in module)
            except TypeError:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)

        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pvae, 'pvae')
        # NOTE(review): 'total' logs the same count as 'pvae'; the P-BiGAN
        # variant of this script sums [model, critic] here — confirm intent.
        print_params_count(pvae, 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pvae, val_loader, test_loader, log_dir,
                          eval_args={'iw_samples': args.test_k})
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()
        for (val, idx, mask, y, _, cconv_graph) in train_loader:
            optimizer.zero_grad()
            loss, _, _, loss_info = pvae(val, idx, mask, y, cconv_graph,
                                         args.train_k, args.ts, args.kl)
            loss.backward()
            optimizer.step()
            if ema:
                ema.update()
            for loss_name, loss_val in loss_info.items():
                loss_breakdown[loss_name] += loss_val
        if scheduler:
            scheduler.step()
        cur_time = time.time()
        # Log per-epoch loss breakdown plus epoch/total wall-clock time.
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)
        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            # Evaluate with EMA weights applied, then restore the raw weights.
            if ema:
                ema.apply()
                evaluator.evaluate(epoch)
                ema.restore()
            else:
                evaluator.evaluate(epoch)
        # Always overwrite the rolling checkpoint; optionally keep a
        # per-epoch snapshot when --save-interval is set.
        model_dict = {
            'pvae': pvae.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))
    print(output_dir)
def train_val_model(pipeline_cfg, model_cfg, train_cfg):
    """Train a BiDAF QA model with EMA shadow weights; validate each epoch.

    Args:
        pipeline_cfg: kwargs for DataPipeline (data loading/iterators).
        model_cfg: kwargs for BiDAF; 'cxt_emb_pretrained' may name a
            torch-saved embedding file that is loaded in place.
        train_cfg: training hyperparameters — expects keys
            'exp_decay_rate', 'lr', 'num_epochs', 'batch_per_disp',
            'ckpoint_file'.

    Side effects: prints progress, saves the best (by F1) state dict to
    train_cfg['ckpoint_file'].
    """
    data_pipeline = DataPipeline(**pipeline_cfg)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if model_cfg['cxt_emb_pretrained'] is not None:
        # Replace the file path with the loaded tensor before model creation.
        model_cfg['cxt_emb_pretrained'] = torch.load(
            model_cfg['cxt_emb_pretrained'])
    bidaf = BiDAF(word_emb=data_pipeline.word_type.vocab.vectors, **model_cfg)

    # Register every trainable parameter with the EMA tracker; validation
    # runs on the EMA (shadow) weights, not the raw training weights.
    ema = EMA(train_cfg['exp_decay_rate'])
    for name, param in bidaf.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda p: p.requires_grad, bidaf.parameters())
    optimizer = optim.Adadelta(parameters, lr=train_cfg['lr'])
    criterion = nn.CrossEntropyLoss()

    result = {'best_f1': 0.0, 'best_model': None}
    num_epochs = train_cfg['num_epochs']
    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        for phase in ['train', 'val']:
            val_answers = dict()
            val_f1 = 0
            val_em = 0
            val_cnt = 0
            val_r = 0
            if phase == 'train':
                bidaf.train()
            else:
                bidaf.eval()
                # Back up the raw weights, then swap in the EMA weights for
                # evaluation; they are restored at the end of the val phase.
                backup_params = EMA(0)
                for name, param in bidaf.named_parameters():
                    if param.requires_grad:
                        backup_params.register(name, param.data)
                        param.data.copy_(ema.get(name))
            with torch.set_grad_enabled(phase == 'train'):
                for batch_num, batch in enumerate(
                        data_pipeline.data_iterators[phase]):
                    optimizer.zero_grad()
                    # p1/p2: start- and end-position scores over the context.
                    p1, p2 = bidaf(batch)
                    loss = criterion(p1, batch.s_idx) + criterion(
                        p2, batch.e_idx)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        # Update EMA after every optimizer step.
                        for name, param in bidaf.named_parameters():
                            if param.requires_grad:
                                ema.update(name, param.data)
                        if batch_num % train_cfg['batch_per_disp'] == 0:
                            batch_loss = loss.item()
                            print('batch %d: loss %.3f' %
                                  (batch_num, batch_loss))
                    if phase == 'val':
                        batch_size, c_len = p1.size()
                        val_cnt += batch_size
                        ls = nn.LogSoftmax(dim=1)
                        # Upper-triangular -inf mask forbids end < start when
                        # combining start/end log-probabilities.
                        mask = (torch.ones(c_len, c_len) *
                                float('-inf')).to(device).tril(-1). \
                            unsqueeze(0).expand(batch_size, -1, -1)
                        score = (ls(p1).unsqueeze(2) +
                                 ls(p2).unsqueeze(1)) + mask
                        # Two successive maxes pick the best (start, end) pair.
                        score, s_idx = score.max(dim=1)
                        score, e_idx = score.max(dim=1)
                        s_idx = torch.gather(
                            s_idx, 1, e_idx.view(-1, 1)).squeeze()
                        for i in range(batch_size):
                            answer = (s_idx[i], e_idx[i])
                            gt = (batch.s_idx[i], batch.e_idx[i])
                            val_f1 += f1_score(answer, gt)
                            val_em += exact_match_score(answer, gt)
                            val_r += r_score(answer, gt)
            if phase == 'val':
                # Convert accumulated scores to percentages over all examples.
                val_f1 = val_f1 * 100 / val_cnt
                val_em = val_em * 100 / val_cnt
                val_r = val_r * 100 / val_cnt
                print('Epoch %d: %s f1 %.3f | %s em %.3f | %s rouge %.3f'
                      % (epoch, phase, val_f1, phase, val_em, phase, val_r))
                if val_f1 > result['best_f1']:
                    result['best_f1'] = val_f1
                    result['best_em'] = val_em
                    result['best_model'] = copy.deepcopy(bidaf.state_dict())
                    torch.save(result, train_cfg['ckpoint_file'])
                # with open(train_cfg['val_answers'], 'w', encoding='utf-8') as f:
                #     print(json.dumps(val_answers), file=f)
                # Restore the raw (non-EMA) weights saved at the start of
                # the val phase before the next training epoch.
                for name, param in bidaf.named_parameters():
                    if param.requires_grad:
                        param.data.copy_(backup_params.get(name))
def main():
    """Train a P-BiGAN on the synthetic toy time-series dataset.

    Parses CLI options, loads (or generates) the toy dataset, builds the
    P-BiGAN and its critic, then alternates critic / generator-encoder
    updates per batch, with optional EMA of weights, periodic plotting and
    checkpointing.
    """
    parser = argparse.ArgumentParser()
    default_dataset = 'toy-data.npz'
    parser.add_argument('--data', default=default_dataset, help='data file')
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed. Randomly set if not specified.')
    # training options
    parser.add_argument('--nz', type=int, default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch', type=int, default=1000,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=8e-5,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr', type=float, default=1e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr', type=float, default=5e-5,
                        help='min encoder/decoder learning rate for LR '
                             'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr', type=float, default=7e-5,
                        help='min discriminator learning rate for LR '
                             'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=0, help='weight decay')
    parser.add_argument('--overlap', type=float, default=.5,
                        help='kernel overlap')
    parser.add_argument('--no-norm-trans', action='store_true',
                        help='if set, use Gaussian posterior without '
                             'transformation')
    parser.add_argument('--plot-interval', type=int, default=1,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval', type=int, default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix', default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp', type=int, default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae', type=float, default=.2,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss', default='smooth_l1',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--ema', dest='ema', type=int, default=-1,
                        help='start epoch of exponential moving average '
                             '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay', type=float, default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd', type=float, default=1,
                        help='MMD strength for latent variable')
    # squash is off when rescale is off
    parser.add_argument('--squash', dest='squash', action='store_const',
                        const=True, default=True,
                        help='bound the generated time series value '
                             'using tanh')
    parser.add_argument('--no-squash', dest='squash', action='store_const',
                        const=False)
    # rescale to [-1, 1]
    parser.add_argument('--rescale', dest='rescale', action='store_const',
                        const=True, default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale', dest='rescale', action='store_const',
                        const=False)
    args = parser.parse_args()

    batch_size = args.batch_size
    nz = args.nz

    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval

    # Load the dataset; if the default toy dataset is missing, generate it.
    # A missing user-specified file is re-raised unchanged.
    try:
        npz = np.load(args.data)
        train_data = npz['data']
        train_time = npz['time']
        train_mask = npz['mask']
    except FileNotFoundError:
        if args.data != default_dataset:
            raise
        # Generate the default toy dataset from scratch
        train_data, train_time, train_mask, _, _ = gen_data(
            n_samples=10000, seq_len=200, max_time=1, poisson_rate=50,
            obs_span_rate=.25, save_file=default_dataset)

    _, in_channels, seq_len = train_data.shape
    # Zero out timestamps at unobserved positions.
    train_time *= train_mask

    # Seed everything; when --seed is omitted, draw one at random and log it
    # (seed.txt below) so the run is reproducible after the fact.
    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
        rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Scale time
    max_time = 5
    train_time *= max_time

    squash = None
    rescaler = None
    if args.rescale:
        rescaler = Rescaler(train_data)
        train_data = rescaler.rescale(train_data)
    if args.squash:
        squash = torch.tanh

    out_channels = 64
    cconv_ref = 98

    train_dataset = TimeSeries(
        train_data, train_time, train_mask, label=None, max_time=max_time,
        cconv_ref=cconv_ref, overlap_rate=args.overlap, device=device)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True,
        collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    # Second, independently-shuffled loader supplying the time points at
    # which the generator is asked to produce samples.
    time_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True,
        collate_fn=train_dataset.collate_fn)

    test_loader = DataLoader(train_dataset, batch_size=batch_size,
                             collate_fn=train_dataset.collate_fn)

    grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
    decoder = Decoder(grid_decoder, max_time=max_time).to(device)

    cconv = ContinuousConv1D(in_channels, out_channels, max_time, cconv_ref,
                             overlap_rate=args.overlap, kernel_size=args.comp,
                             norm=True).to(device)
    encoder = Encoder(cconv, nz, not args.no_norm_trans).to(device)

    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    # The critic gets its own continuous-conv front end.
    critic_cconv = ContinuousConv1D(in_channels, out_channels, max_time,
                                    cconv_ref, overlap_rate=args.overlap,
                                    kernel_size=args.comp,
                                    norm=True).to(device)
    critic = ConvCritic(critic_cconv, nz).to(device)

    # EMA of model weights, enabled from epoch args.ema onward (-1 disables).
    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    optimizer = optim.Adam(
        pbigan.parameters(), lr=args.lr, weight_decay=args.wd)
    critic_optimizer = optim.Adam(
        critic.parameters(), lr=args.dis_lr, weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(
        critic_optimizer, args.dis_lr, args.min_dis_lr, epochs)

    # Output directory: results/toy-pbigan/<prefix>_<timestamp>/
    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))
    output_dir = Path('results') / 'toy-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    # Record seed, GPU count and full argument list for reproducibility.
    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    tracker = Tracker(log_dir, n_train_batch)
    visualizer = Visualizer(encoder, decoder, batch_size, max_time,
                            test_loader, rescaler, output_dir, device)
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        # Fix: reset the per-epoch timer each epoch. Previously epoch_start
        # was set only once before the loop, so the "epoch time" logged via
        # tracker.log grew cumulatively; the sibling training scripts reset
        # it per epoch.
        epoch_start = time.time()

        for ((val, idx, mask, _, cconv_graph),
             (_, idx_t, mask_t, index, _)) in zip(train_loader, time_loader):

            z_enc, x_recon, z_gen, x_gen, ae_loss = pbigan(
                val, idx, mask, cconv_graph, idx_t, mask_t)

            cconv_graph_gen = train_dataset.make_graph(
                x_gen, idx_t, mask_t, index)

            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)

            # Critic step: real labeled 1, fake labeled 0. retain_graph so
            # the same forward pass can back the generator loss below.
            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
            D_loss.backward(retain_graph=True)
            critic_optimizer.step()

            # Generator/encoder step with flipped labels, plus
            # autoencoding and latent MMD regularizers.
            G_loss = gan_loss(real, fake, 0, 1)

            mmd_loss = mmd(z_enc, z_gen)

            loss = G_loss + ae_loss * args.ae + mmd_loss * args.mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            # Plot with EMA weights applied, then restore the raw weights.
            if ema:
                ema.apply()
                visualizer.plot(epoch)
                ema.restore()
            else:
                visualizer.plot(epoch)

        # Always overwrite the rolling checkpoint; optionally keep a
        # per-epoch snapshot when --save-interval is set.
        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
def main():
    """Train a P-BiGAN with classifier on the MIMIC-III time-series dataset.

    Parses CLI options, builds train/val/test loaders, constructs the
    encoder/decoder/classifier and a separate critic, then alternates
    critic / generator-encoder updates per batch, with optional EMA of
    weights, periodic AUC evaluation and checkpointing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed. Randomly set if not specified.')
    # training options
    parser.add_argument('--nz', type=int, default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch', type=int, default=500,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size', type=int, default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr', type=float, default=3e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr', type=float, default=1e-4,
                        help='min encoder/decoder learning rate for LR '
                             'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr', type=float, default=1.5e-4,
                        help='min discriminator learning rate for LR '
                             'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--overlap', type=float, default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls', type=float, default=1,
                        help='classification weight')
    parser.add_argument('--clsdep', type=int, default=1,
                        help='number of layers for classifier')
    parser.add_argument('--eval-interval', type=int, default=1,
                        help='AUC evaluation interval. '
                             '0 to disable evaluation.')
    parser.add_argument('--save-interval', type=int, default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix', default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp', type=int, default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae', type=float, default=1,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss', default='mse',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--dec-ch', default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch', default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--dis-ch', default=None,
                        help='discriminator architecture. Use encoder '
                             'architecture if unspecified.')
    parser.add_argument('--rescale', dest='rescale', action='store_const',
                        const=True, default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale', dest='rescale', action='store_const',
                        const=False)
    parser.add_argument('--cconvnorm', dest='cconv_norm',
                        action='store_const', const=True, default=True,
                        help='if set, normalize continuous convolutional '
                             'layer using mean pooling')
    parser.add_argument('--no-cconvnorm', dest='cconv_norm',
                        action='store_const', const=False)
    parser.add_argument('--cconv-ref', type=int, default=98,
                        help='number of evenly-spaced reference locations '
                             'for continuous convolutional layer')
    parser.add_argument('--dec-ref', type=int, default=128,
                        help='number of evenly-spaced reference locations '
                             'for decoder')
    parser.add_argument('--trans', type=int, default=2,
                        help='number of encoder layers')
    parser.add_argument('--ema', dest='ema', type=int, default=0,
                        help='start epoch of exponential moving average '
                             '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay', type=float, default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd', type=float, default=1,
                        help='MMD strength for latent variable')
    args = parser.parse_args()

    nz = args.nz

    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    # Seed everything; when --seed is omitted, draw one at random and log it
    # (seed.txt below) so the run is reproducible after the fact.
    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
        rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        drop_last=True, collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    # Second, independently-shuffled loader supplying the time points at
    # which the generator is asked to produce samples.
    time_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        drop_last=True, collate_fn=train_dataset.collate_fn)

    val_loader = DataLoader(val_dataset, batch_size=args.test_batch_size,
                            shuffle=False, collate_fn=val_dataset.collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size,
                             shuffle=False, collate_fn=test_dataset.collate_fn)

    in_channels, seq_len = train_dataset.data.shape[1:]

    # Discriminator defaults to the encoder architecture.
    if args.dis_ch is None:
        args.dis_ch = args.enc_ch

    # Channel specs come in as dash-separated strings, e.g. '8-16-16'.
    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]
    dis_channels = [int(c) for c in args.dis_ch.split('-')]
    out_channels = enc_channels[0]

    # tanh when values were rescaled to [-1, 1], sigmoid otherwise.
    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    # The decoder upsamples by 2x per layer (after the first), so the
    # reference grid length must be divisible by the total upsampling factor.
    dec_ch_up = 2**(len(dec_channels) - 2)
    assert args.dec_ref % dec_ch_up == 0, (
        f'--dec-ref={args.dec_ref} is not divided by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)

    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)
    cconv = ContinuousConv1D(in_channels, out_channels, max_time, cconv_ref,
                             overlap_rate=overlap, kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.trans).to(device)
    classifier = Classifier(nz, args.clsdep).to(device)
    pbigan = PBiGAN(encoder, decoder, classifier,
                    ae_loss=args.aeloss).to(device)

    # EMA of model weights, enabled from epoch args.ema onward (-1 disables).
    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    # The critic gets its own continuous-conv front end.
    critic_cconv = ContinuousConv1D(in_channels, out_channels, max_time,
                                    cconv_ref, overlap_rate=overlap,
                                    kernel_size=args.comp,
                                    norm=args.cconv_norm).to(device)
    critic_embed = 32
    critic = ConvCritic(critic_cconv, nz, dis_channels,
                        critic_embed).to(device)

    optimizer = optim.Adam(pbigan.parameters(), lr=args.lr, betas=(0, .999),
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(), lr=args.dis_lr,
                                  betas=(0, .999), weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    # Output directory: results/mimic3-pbigan/<prefix>_<timestamp>/
    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))
    output_dir = Path('results') / 'mimic3-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    # Record seed, GPU count and full argument list for reproducibility.
    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    with (log_dir / 'params.txt').open('w') as f:
        def print_params_count(module, name):
            # Log the parameter count of a module (or a list of modules).
            try:   # sum counts if module is a list
                params_count = sum(count_parameters(m) for m in module)
            except TypeError:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)

        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pbigan, 'pbigan')
        print_params_count(critic, 'critic')
        print_params_count([pbigan, critic], 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pbigan, val_loader, test_loader, log_dir)
    start = time.time()
    epoch_start = start

    batch_size = args.batch_size

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()
        # NOTE(review): hard-coded schedule — the classification weight is
        # overridden to 200 from epoch 40 onward, mutating args (which is
        # also saved into every checkpoint). Confirm this is intentional.
        if epoch >= 40:
            args.cls = 200
        for ((val, idx, mask, y, _, cconv_graph),
             (_, idx_t, mask_t, _, index, _)) in zip(train_loader,
                                                     time_loader):
            z_enc, x_recon, z_gen, x_gen, ae_loss, cls_loss = pbigan(
                val, idx, mask, y, cconv_graph, idx_t, mask_t)
            cconv_graph_gen = train_dataset.make_graph(
                x_gen, idx_t, mask_t, index)
            # Critic step: real labeled 1, fake labeled 0.
            # Don't need pbigan.requires_grad_(False);
            # critic takes as input only the detached tensors.
            real = critic(cconv_graph, batch_size, z_enc.detach())
            detached_graph = [[cat_y.detach() for cat_y in x] if i == 2 else x
                              for i, x in enumerate(cconv_graph_gen)]
            fake = critic(detached_graph, batch_size, z_gen.detach())
            D_loss = gan_loss(real, fake, 1, 0)
            critic_optimizer.zero_grad()
            D_loss.backward()
            critic_optimizer.step()

            # Generator/encoder step: freeze the critic, re-run it on the
            # non-detached graph, and combine GAN, autoencoding,
            # classification and latent MMD losses.
            for p in critic.parameters():
                p.requires_grad_(False)
            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)
            G_loss = gan_loss(real, fake, 0, 1)
            mmd_loss = mmd(z_enc, z_gen)
            loss = (G_loss + ae_loss * args.ae + cls_loss * args.cls +
                    mmd_loss * args.mmd)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            for p in critic.parameters():
                p.requires_grad_(True)

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['CLS'] += cls_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        # Log per-epoch loss breakdown plus epoch/total wall-clock time.
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            # Evaluate with EMA weights applied, then restore the raw weights.
            if ema:
                ema.apply()
                evaluator.evaluate(epoch)
                ema.restore()
            else:
                evaluator.evaluate(epoch)

        # Always overwrite the rolling checkpoint; optionally keep a
        # per-epoch snapshot when --save-interval is set.
        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
def train_bidaf(args, data):
    """Train a BiDAF model with EMA weight tracking and periodic dev eval.

    Args:
        args: namespace with (at least) gpu, exp_decay_rate, learning_rate,
            model_time, epoch, print_freq.
        data: dataset object exposing WORD.vocab.vectors and train_iter.

    Returns:
        The model with the best dev F1 seen during training (a deep copy),
        or the in-training model if no dev evaluation ever ran.

    Side effects: writes TensorBoard scalars under runs/<model_time>.
    """
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)

    # Register every trainable parameter with the EMA tracker; dev
    # evaluation (inside test()) is expected to use the EMA weights.
    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter(logdir='runs/' + args.model_time)

    model.train()
    # `loss` accumulates training loss between dev evaluations.
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1
    # Fix: previously best_model was only assigned inside the periodic dev
    # evaluation, so returning it raised NameError when training finished
    # before the first evaluation. Fall back to the live model.
    best_model = model

    iterator = data.train_iter
    for i, batch in tqdm(enumerate(iterator)):
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
        last_epoch = present_epoch

        # p1/p2: start- and end-position scores over the context.
        p1, p2 = model(batch)

        optimizer.zero_grad()
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()

        # Update EMA after every optimizer step.
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.update(name, param.data)

        if (i + 1) % args.print_freq == 0:
            dev_loss, dev_exact, dev_f1 = test(model, ema, args, data)
            c = (i + 1) // args.print_freq
            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('exact_match/dev', dev_exact, c)
            writer.add_scalar('f1/dev', dev_f1, c)
            print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f}'
                  f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}')

            if dev_f1 > max_dev_f1:
                max_dev_f1 = dev_f1
                max_dev_exact = dev_exact
                best_model = copy.deepcopy(model)

            # Reset the running loss and return to training mode
            # (test() puts the model in eval mode).
            loss = 0
            model.train()

    writer.close()
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')

    return best_model
class SNTGRunLoop(object):
    """Train/eval/test loop for SNTG-style semi-supervised classification.

    The training objective combines:
      * a supervised binary-cross-entropy-with-logits loss on labeled
        samples (``mask == 0``),
      * a temporal-ensembling consistency (MSE) loss between current
        softmax predictions and the bias-corrected ensemble targets,
      * optionally an SNTG embedding loss on pairs formed by the two
        halves of each batch (enabled via ``params['embed']``).

    An EMA copy of the network (``self.ema``) is maintained and used as a
    second evaluator during the eval phase.
    """

    def __init__(self, net, dataloader=None, params=None, update_fn=None,
                 eval_loader=None, test_loader=None, has_cuda=True):
        """Set up device, model, data loaders and (if ``params``) training state.

        Args:
            net: torch module whose forward returns ``(logits, embedding)``.
            dataloader: training loader yielding dicts with keys
                ``image``, ``is_lens``, ``mask``, ``index``.
            params: dict of hyper-parameters (``n_data``, ``num_classes``,
                ``n_eval_data``, ``batch_size``, ``num_epochs``,
                ``polyak_decay``, ``pred_decay``, ``embed``,
                ``embed_coeff``). Optimizer/EMA state is created only when
                this is not None.
            update_fn: callable ``(optimizer, epoch) -> unsup_weight`` used
                to ramp the unsupervised loss weight each epoch.
            eval_loader / test_loader: loaders yielding dicts with
                ``image`` and ``is_lens``.
            has_cuda: place everything on ``cuda:0`` when True, else CPU.
        """
        if has_cuda:
            device = torch.device("cuda:0")
        else:
            device = torch.device('cpu')
        self.net = net.to(device)
        self.loader = dataloader
        self.eval_loader = eval_loader
        self.test_loader = test_loader
        self.params = params
        self.device = device

        # Training-only state: these attributes exist only when params is
        # given (test-only usage may construct the loop without them).
        if params is not None:
            n_data, num_classes = params['n_data'], params['num_classes']
            n_eval_data, batch_size = params['n_eval_data'], params[
                'batch_size']
            # Running ensemble of per-sample predictions and the
            # bias-corrected targets derived from it each epoch.
            self.ensemble_pred = torch.zeros((n_data, num_classes),
                                             device=device)
            self.target_pred = torch.zeros((n_data, num_classes),
                                           device=device)
            t_one = torch.ones(())
            # Per-epoch scratch buffers (new_empty: contents are garbage
            # until written).
            self.epoch_pred = t_one.new_empty((n_data, num_classes),
                                              dtype=torch.float32,
                                              device=device)
            self.epoch_mask = t_one.new_empty((n_data),
                                              dtype=torch.float32,
                                              device=device)
            # Columns: 0 labeled loss, 1 unlabeled loss, 2 SNTG loss,
            # 3 total loss.
            self.train_epoch_loss = \
                t_one.new_empty((n_data // batch_size, 4),
                                dtype=torch.float32, device=device)
            self.train_epoch_acc = \
                t_one.new_empty((n_data // batch_size),
                                dtype=torch.float32, device=device)
            # Column 0: EMA network, column 1: raw network.
            self.eval_epoch_loss = \
                t_one.new_empty((n_eval_data // batch_size, 2),
                                dtype=torch.float32, device=device)
            self.eval_epoch_acc = \
                t_one.new_empty((n_eval_data // batch_size, 2),
                                dtype=torch.float32, device=device)
            self.optimizer = opt.Adam(self.net.parameters())
            self.update_fn = update_fn
            self.ema = EMA(params['polyak_decay'], self.net, has_cuda)
            self.unsup_weight = 0.0

    def train(self):
        """Run the full training schedule.

        Returns:
            ``(train_losses, train_accs, eval_losses, eval_accs,
            ema_eval_losses, ema_eval_accs)`` — per-epoch means; the eval
            lists stay empty when no ``eval_loader`` was provided.
        """
        train_losses, train_accs = [], []
        eval_losses, eval_accs = [], []
        ema_eval_losses, ema_eval_accs = [], []

        for epoch in range(self.params['num_epochs']):
            # ---------------- training phase ----------------
            self.net.train()
            train_start = time.time()
            self.epoch_pred.zero_()
            self.epoch_mask.zero_()
            self.unsup_weight = self.update_fn(self.optimizer, epoch)

            for i, data_batched in enumerate(self.loader, 0):
                images, is_lens, mask, indices = \
                    data_batched['image'], data_batched['is_lens'], \
                    data_batched['mask'], data_batched['index']
                # Consistency targets for exactly this batch's samples.
                targets = torch.index_select(self.target_pred, 0, indices)

                self.optimizer.zero_grad()
                outputs, h_x = self.net(images)
                predicts = F.softmax(outputs, dim=1)

                # Record predictions for the temporal ensemble.
                for k, j in enumerate(indices):
                    self.epoch_pred[j] = predicts[k]
                    self.epoch_mask[j] = 1.0

                # Supervised loss on labeled samples (mask == 0), using a
                # one-hot encoding with BCE-with-logits.
                labeled_mask = mask.eq(0)
                one_hot = torch.zeros(
                    len(is_lens[labeled_mask]),
                    is_lens[labeled_mask].max() + 1,
                    device=self.device) \
                    .scatter_(1, is_lens[labeled_mask].unsqueeze(1), 1.)
                loss = F.binary_cross_entropy_with_logits(
                    outputs[labeled_mask], one_hot)

                self.train_epoch_acc[i] = \
                    torch.mean(torch.argmax(
                        outputs[labeled_mask], 1).eq(is_lens[labeled_mask])
                        .float()).item()
                self.train_epoch_loss[i, 0] = loss.item()

                # Unsupervised consistency loss against ensemble targets.
                unlabeled_loss = torch.mean((predicts - targets) ** 2)
                self.train_epoch_loss[i, 1] = unlabeled_loss.item()
                loss += unlabeled_loss * self.unsup_weight

                # Optional SNTG embedding loss over the two batch halves.
                # NOTE(review): when disabled, column 2 of
                # train_epoch_loss keeps its uninitialized contents and the
                # printed "SNTG loss" mean is meaningless — confirm intent.
                if self.params['embed']:
                    half = int(h_x.size()[0] // 2)
                    eucd2 = torch.mean((h_x[:half] - h_x[half:]) ** 2,
                                       dim=1)
                    eucd = torch.sqrt(eucd2)
                    target_hard = torch.argmax(targets, dim=1).int()
                    # Use ensemble targets where labeled, ground truth
                    # otherwise (as written; verify against SNTG paper).
                    merged_tar = torch.where(mask == 0, target_hard,
                                             is_lens.int())
                    neighbor_bool = torch.eq(merged_tar[:half],
                                             merged_tar[half:])
                    # Hinge at margin 1.0 for non-neighbor pairs.
                    eucd_y = torch.where(eucd < 1.0, (1.0 - eucd) ** 2,
                                         torch.zeros_like(eucd))
                    embed_losses = torch.where(neighbor_bool, eucd2, eucd_y)
                    embed_loss = torch.mean(embed_losses)
                    self.train_epoch_loss[i, 2] = embed_loss.item()
                    loss += embed_loss * \
                        self.unsup_weight * self.params['embed_coeff']

                self.train_epoch_loss[i, 3] = loss.item()
                loss.backward()
                self.optimizer.step()
                self.ema.update()

            # Update the temporal ensemble and its bias-corrected targets.
            self.ensemble_pred = \
                self.params['pred_decay'] * self.ensemble_pred + \
                (1 - self.params['pred_decay']) * self.epoch_pred
            # BUG FIX: this previously assigned to `self.targets_pred`
            # (note the extra "s"), so `self.target_pred` — the tensor the
            # consistency loss actually reads — stayed all-zero forever.
            self.target_pred = self.ensemble_pred / \
                (1.0 - self.params['pred_decay'] ** (epoch + 1))

            loss_mean = torch.mean(self.train_epoch_loss, 0)
            train_losses.append(loss_mean[3].item())
            acc_mean = torch.mean(self.train_epoch_acc)
            train_accs.append(acc_mean.item())
            # FIX: corrected "cosumed" -> "consumed" in the log line.
            print(f"epoch {epoch}, "
                  f"time consumed: {time.time() - train_start}, "
                  f"labeled loss: {loss_mean[0].item()}, "
                  f"unlabeled loss: {loss_mean[1].item()}, "
                  f"SNTG loss: {loss_mean[2].item()}, "
                  f"total loss: {loss_mean[3].item()}")

            # ---------------- eval phase ----------------
            if self.eval_loader is not None:
                self.net.eval()
                for i, data_batched in enumerate(self.eval_loader, 0):
                    images, is_lens = data_batched['image'], \
                        data_batched['is_lens']
                    # EMA network first (embedding output unused here).
                    eval_logits, _ = self.ema(images)
                    self.eval_epoch_acc[i, 0] = torch.mean(
                        torch.argmax(eval_logits, 1).eq(is_lens)
                        .float()).item()
                    eval_lens = torch.zeros(
                        len(is_lens), is_lens.max() + 1,
                        device=self.device) \
                        .scatter_(1, is_lens.unsqueeze(1), 1.)
                    self.eval_epoch_loss[i, 0] = \
                        F.binary_cross_entropy_with_logits(
                            eval_logits, eval_lens).item()
                    # Then the raw network on the same batch.
                    eval_logits, _ = self.net(images)
                    self.eval_epoch_acc[i, 1] = torch.mean(
                        torch.argmax(eval_logits, 1).eq(is_lens)
                        .float()).item()
                    self.eval_epoch_loss[i, 1] = \
                        F.binary_cross_entropy_with_logits(
                            eval_logits, eval_lens).item()

                loss_mean = torch.mean(self.eval_epoch_loss, 0)
                acc_mean = torch.mean(self.eval_epoch_acc, 0)
                ema_eval_accs.append(acc_mean[0].item())
                ema_eval_losses.append(loss_mean[0].item())
                eval_accs.append(acc_mean[1].item())
                eval_losses.append(loss_mean[1].item())
                # BUG FIX: the two f-strings were concatenated without a
                # separator, fusing both numbers into one token.
                print(f"ema accuracy: {acc_mean[0].item()}, "
                      f"normal accuracy: {acc_mean[1].item()}")

        return train_losses, train_accs, eval_losses, eval_accs, \
            ema_eval_losses, ema_eval_accs

    def test(self):
        """Evaluate the raw network on the FIRST test batch only.

        Returns:
            ``(predictions, labels, forward_time_seconds, accuracy)`` —
            predictions and labels as Python lists.
        """
        self.net.eval()
        with torch.no_grad():
            for i, data_batched in enumerate(self.test_loader, 0):
                images, is_lens = data_batched['image'], data_batched[
                    'is_lens']
                start = time.time()
                test_logits, _ = self.net(images)
                end = time.time()
                result = torch.argmax(F.softmax(test_logits, dim=1), dim=1)
                accuracy = torch.mean(result.eq(is_lens).float()).item()
                # Deliberately returns inside the loop: only the first
                # batch is scored.
                return result.tolist(), is_lens.tolist(), end - start, \
                    accuracy

    def test_origin(self):
        """Return raw logits and labels for the FIRST test batch only."""
        self.net.eval()
        with torch.no_grad():
            for i, data_batched in enumerate(self.test_loader, 0):
                images, is_lens = data_batched['image'], data_batched[
                    'is_lens']
                test_logits, _ = self.net(images)
                # Deliberately returns inside the loop (first batch only).
                return test_logits, is_lens