def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS.batch_size)
        if x % FLAGS.log_steps == 0:
            xm.add_step_closure(_train_update, args=(device, x, loss, tracker))
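The `_train_update` callback handed to `xm.add_step_closure` above is not defined in these snippets. A minimal sketch, modeled on the torch_xla example scripts (the exact `test_utils.print_training_update` signature varies between releases, and the optional `epoch`/`writer` arguments are assumptions):

import torch_xla.test.test_utils as test_utils


def _train_update(device, step, loss, tracker, epoch=None, writer=None):
    # Runs on the host once the step's tensors have materialized, so calling
    # loss.item() here does not force an extra graph execution mid-step.
    test_utils.print_training_update(
        device,
        step,
        loss.item(),
        tracker.rate(),
        tracker.global_rate(),
        epoch,
        summary_writer=writer)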
def test_loop_fn(loader, epoch):
    total_samples, correct = 0, 0
    model.eval()
    for step, (data, target) in enumerate(loader):
        output = model(data)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(target.view_as(pred)).sum()
        total_samples += data.size()[0]
        if step % FLAGS.log_steps == 0:
            xm.add_step_closure(
                test_utils.print_test_update, args=(device, None, epoch, step))
    accuracy = 100.0 * correct.item() / total_samples
    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
    return accuracy
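For context, loop functions like the two above are normally driven by a per-process entry point that wraps the loaders and is launched with `xmp.spawn`. A rough sketch, assuming `train_loader`, `test_loader`, and `FLAGS` exist in the caller (these names are illustrative, not taken from the snippets):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index, flags):
    device = xm.xla_device()
    # MpDeviceLoader moves batches to the XLA device and inserts mark_step
    # boundaries, which is when queued step closures actually run.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    for epoch in range(1, flags.num_epochs + 1):
        train_loop_fn(train_device_loader)
        accuracy = test_loop_fn(test_device_loader, epoch)
        xm.master_print('Epoch {} test accuracy: {:.2f}%'.format(epoch, accuracy))


if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)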
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(flags.batch_size)
        if step % flags.log_steps == 0:
            xm.add_step_closure(
                _train_update,
                args=(device, step, loss, tracker, writer),
                run_async=flags.async_closures)
def test_synchronous_exception(self):
    flag = Event()
    assert not flag.is_set()
    try:

        def closure():
            flag.set()
            raise RuntimeError("Simulating exception in closure")

        xm.add_step_closure(closure)
        xm.mark_step()
        assert False  # Should not reach here
    except RuntimeError as e:
        assert flag.is_set(), "Should have caught exception from closure"
def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()  # do not reduce gradients on sharded params
        tracker.add(FLAGS.batch_size)
        if lr_scheduler:
            lr_scheduler.step()
        if step % FLAGS.log_steps == 0:
            xm.add_step_closure(
                _train_update, args=(device, step, loss, tracker, epoch, writer))
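The plain `optimizer.step()` above (instead of `xm.optimizer_step`), together with the "do not reduce gradients on sharded params" comment, matches the XLA FSDP pattern, where gradient reduction already happens inside the sharded module. The wrapping is roughly as follows (a sketch; `MyModel` is a placeholder, not from the snippet):

import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
model = FSDP(MyModel().to(device))  # MyModel is a placeholder nn.Module
# With FSDP, call optimizer.step() directly; xm.optimizer_step would
# all-reduce the already-reduced sharded gradients a second time.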
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        tracker.add(flags.batch_size)
        if step % flags.log_steps == 0:
            xm.add_step_closure(
                _train_update, args=(device, step, loss, tracker, writer))
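The `autocast` and `scaler` objects used above are assumed to come from torch_xla's AMP support; the setup is roughly the following sketch (version-dependent, and newer releases may expect `autocast(xm.xla_device())` instead of a bare `autocast()`):

from torch_xla.amp import autocast, GradScaler

# Created once in the caller and shared by the training loop above.
scaler = GradScaler()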
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        with xp.StepTrace('train_mnist', step_num=step):
            with xp.Trace('build_graph'):
                optimizer.zero_grad()
                output = model(data)
                loss = loss_fn(output, target)
                loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, writer))
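The `xp.StepTrace`/`xp.Trace` annotations above only yield data when a profiler server is running in the training process; the usual setup is roughly the sketch below (the port number is illustrative):

import torch_xla.debug.profiler as xp

# Start once per process, before entering the training loop; traces can then
# be captured with xp.trace(...) or the TensorBoard profiler plugin.
server = xp.start_server(9012)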
def train_loop_fn(device, trainer, loader, last_batch_index):
    """This is the main training loop. It trains for 1 epoch."""

    def print_training_update(trainer, progress, args, i):
        stats = get_training_stats(trainer, args=args)
        stats['now'] = now()
        progress.log(stats, tag='train', step=trainer.get_num_updates())
        progress.print_mid_epoch(i + 1, force=True)

    stats, log_output, skip_stat_keys = None, None, {'clip'}
    max_update = args.max_update or math.inf
    for i, samples in enumerate(loader, start=epoch_itr.iterations_in_epoch):
        if i == last_batch_index:
            # last batches are incomplete
            break
        log_output = trainer.train_step(samples)
        reset_perf_training_meters(trainer, i, ignore_index=10)
        if (not (i % args.log_steps)) or (i == last_batch_index - 1):
            step_args = trainer, progress, args, i
            xm.add_step_closure(print_training_update, args=step_args)
        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0 and
                num_updates % args.save_interval_updates == 0 and num_updates > 0):
            vloss = validate_subset(
                args, device, trainer, task, epoch_itr, valid_subsets[0])
            checkpoint_utils.save_checkpoint(
                args, trainer, epoch_itr, vloss.item(),
                epoch=epoch, end_of_epoch=False)
        if num_updates >= max_update:
            break
def validate(args, e, mdl, opt, ls_f, vd_dl, tp_ix, tp_rix, ls_mtr, tb_sw):
    vd_ls = Integer(0)
    mtr = Metric()
    vd = pl.ParallelLoader(vd_dl, [args.dvc]).per_device_loader(
        args.dvc) if args.tpu else vd_dl
    mdl.eval()
    with torch.no_grad():
        for p, n, w, x, y in vd:
            p = p.view(-1, p.shape[-1]).to(args.aux_dvc)
            n = n.view(-1, n.shape[-1]).to(args.aux_dvc)
            x = x.view(-1, x.shape[-1]).to(args.aux_dvc)
            ls = _loss(args, p, n, w, mdl, ls_f)
            if args.tpu:
                xm.add_step_closure(_validate, args=(vd_ls, ls))
            else:
                _validate(vd_ls, ls)
            evaluate(args, x, y, mdl, tp_ix, tp_rix, mtr)
    vd_ls.val /= len(vd_dl)
    vd_ls_avg = vd_ls.val if args.tpu else _allreduce(
        vd_ls.val, 'validate.vd_ls_avg', hvd.Average)
    if not args.tpu:
        mtr.allreduce()
    if is_master(args):
        print(f'Epoch {e}/{args.epochs} validation loss: {vd_ls_avg}')
        print(mtr)
        tb_sw.add_scalar('loss/validation', vd_ls_avg, e)
        tb_sw.add_scalars('validation', dict(mtr))
    is_bst, bst_ls = ls_mtr.update(vd_ls_avg)
    _checkpoint(args, e, mdl, opt, bst_ls, is_bst)
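The `Integer` holder and `_validate` callback above are not shown; the pattern is to hand the step closure a mutable Python object it can accumulate into once the loss tensor has materialized. A minimal sketch of what they might look like (assumptions, not from the source):

class Integer:
    """Mutable scalar holder so a step closure can accumulate into it."""

    def __init__(self, val=0):
        self.val = val


def _validate(vd_ls, ls):
    # On TPU this runs as a step closure after mark_step, so ls has a value;
    # on CPU/GPU it is called inline.
    vd_ls.val += ls.item()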
def train(rank, args):
    print('enter train @ %s' % rank, flush=True)
    args.rank = rank
    args.split = ''
    torch.manual_seed(42)
    save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')
    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size() if not args.vocab_size else args.vocab_size
    train_dataset = get_dataset(args)
    batched_already = hasattr(train_dataset, '__getbatch__')
    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates
    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)
    train_sampler = None
    if args.gpus:
        dist.init_process_group('nccl', rank=rank, world_size=args.world_size)
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=args.gpus, rank=rank, shuffle=args.shuffle)
    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=xm.xrt_world_size(), rank=rank, shuffle=args.shuffle)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size if not batched_already else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)
    eval_loaders = []
    if args.eval_dir:
        for split in args.splits.split(','):
            split = split.strip()
            eval_sampler = None
            if args.gpus:
                if args.gpus > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        train_dataset, num_replicas=args.gpus, rank=rank, shuffle=False)
            else:
                rank = xm.get_ordinal()
                if xm.xrt_world_size() > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        train_dataset, num_replicas=xm.xrt_world_size(), rank=rank, shuffle=False)
            args.split = split
            eval_dataset = get_eval_dataset(args)
            eval_loader = torch.utils.data.DataLoader(
                eval_dataset,
                batch_size=args.batch_size if not batched_already else None,
                sampler=eval_sampler,
                pin_memory=True,
                shuffle=False,
                num_workers=args.num_workers)
            eval_loaders.append(eval_loader)
    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)
        ##########################
        ##
        ## Model Creation
        ##
        ##########################
        model = get_model(args, tokenizer)
        model.cuda(rank)
        device = torch.device('cuda:' + str(rank))
        ##########################
        ##
        ## Init Optimizer
        ##
        ##########################
        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay,
                                 special_layer_wise_lr=args.special_layer_wise_lr,
                                 log=rank == 0),
            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            betas=(0.9, 0.999),
            eps=1e-6,
            lr=args.lr,
            weight_decay=args.weight_decay)
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
        batches = train_loader
    else:
        assert tpu_enabled
        device = xm.xla_device()
        ##########################
        ##
        ## Model Creation
        ##
        ##########################
        model = get_model(args, tokenizer)
        ##########################
        ##
        ## For shared parameters, TPU requires modules to be tied after .to(device)
        ## So we first find the shared parameters first
        ##
        ##########################
        shared_parameters = {e[0]: e[1:] for e in _catalog_shared_params(model)}
        model.to(device)
        do_share_parameters_again(model, shared_parameters, log=rank == 0)
        ##########################
        ##
        ## Init Optimizer
        ##
        ##########################
        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),
            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            lr=args.lr,
            weight_decay=args.weight_decay)
        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)
        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()
        # tracker = xm.RateTracker()
    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
        if args.gpus:
            states = {"module.%s" % k: v for k, v in states.items()}
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            if rank == 0:
                traceback.print_exc()
            model.load_state_dict(states, strict=False)
    if rank == 0:
        if not os.path.exists(os.path.dirname(save_fn)):
            try:
                os.makedirs(os.path.dirname(save_fn))
            except OSError as exc:  # Guard against race condition
                if exc.errno != errno.EEXIST:
                    raise
        if args.gpus:
            torch.save(model.state_dict(), save_fn)
        else:
            xm.save(model.state_dict(), save_fn)
    model.train()
    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)
    ##########################
    ##
    ## Init LR Scheduler
    ##
    ##########################
    if not batched_already:
        args.total_num_updates = args.total_num_updates // args.batch_size
        args.warmup_updates = args.total_num_updates // args.batch_size
    args.total_num_updates = args.total_num_updates // args.world_size
    args.warmup_updates = args.total_num_updates // args.world_size
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_updates,
        num_training_steps=args.total_num_updates,
    )
    step_i = 0
    err = None
    tb = None
    # tb = SummaryWriter()
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates, file=sys.stdout)
        while step_i < args.total_num_updates:
            if not args.gpus:
                batches = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
            n_samples = len(batches)
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break
                report_step = step_i % args.log_interval == 0
                while True:  # the loop only for apex Gradient Overflow
                    optimizer.zero_grad()
                    total_loss, log = get_loss(
                        model, sample, args=args, device=device,
                        gpus=args.gpus, report=report_step)
                    if args.gpus:
                        default_optimizer_step = optimizer.step
                        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        # If Amp detects an overflow, it patches optimizer.step. In other
                        # words, if optimizer.step was left unpatched, there was no
                        # overflow, and we don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break
                        # If an overflow was detected, "optimizer.step" is the patched
                        # call, which does nothing but restore optimizer.step to
                        # default_optimizer_step.
                        optimizer.step()
                        if rank == 0:
                            print("Overflowed, reducing loss scale and replaying batch.",
                                  flush=True)
                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()
                        break
                scheduler.step()
                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss
                    # tb.add_scalar("Loss", total_loss, step_i)
                    for k, v in log.items():
                        try:
                            dist.all_reduce(v, op=dist.reduce_op.SUM)
                            log[k] = float(v)
                        except Exception as e:
                            print(v, e)
                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(log, log_formatter, tb, step_i))
                    else:
                        xm.add_step_closure(_train_update, args=(log, log_formatter, tb, step_i))
                        if args.report_metrics:
                            xm.master_print(met.metrics_report())
                if rank == 0:
                    pbar.update(1)
        if rank == 0:
            pbar.close()
        if eval_loaders:
            model.half()
            model.eval()
            model.cuda()
            for k, v in model.named_parameters():
                v.requires_grad = False
            for split, eval_loader in zip(args.splits.split(','), eval_loaders):
                batches = eval_loader
                if rank == 0:
                    eval_length = len(batches)
                    if not batched_already:
                        eval_length = eval_length // args.batch_size
                    eval_length = eval_length // args.world_size
                    pbar = tqdm(total=eval_length, file=sys.stdout)
                if not args.gpus:
                    batches = pl.ParallelLoader(eval_loader, [device]).per_device_loader(device)
                with torch.no_grad():
                    record = OrderedDict()
                    for sample in batches:
                        evaluate(model, sample, args=args, device=device,
                                 record=record, gpus=args.gpus, report=False)
                        if rank == 0:
                            pbar.update(1)
                    for k, v in record.items():
                        try:
                            def handle_reduce(v):
                                if len(v.shape) == 0:
                                    dist.all_reduce(v, op=dist.reduce_op.SUM)
                                else:
                                    L = [torch.ones_like(v) for _ in range(dist.get_world_size())]
                                    dist.all_gather(L, v)
                                    v = torch.cat(L, dim=0)
                                return v

                            if isinstance(v, list):
                                v = [handle_reduce(e) for e in v]
                            else:
                                v = handle_reduce(v)
                            record[k] = float(v)
                        except Exception as e:
                            pass
                    post_evaluate(record, args=args)
                import json
                if rank == 0:
                    print('', flush=True)
                    print('Test result for %s' % split, flush=True)
                    print(json.dumps(record, indent=2), flush=True)
                    print('', flush=True)
    except Exception as _err:
        err = _err
    finally:
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if rank == 0:
            print("Saving to %s" % save_fn)
        if args.gpus:
            torch.save(model.state_dict(), save_fn)
            if err:
                raise err
        else:
            xm.save(model.state_dict(), save_fn)
            if err:
                raise err
        print("Saved to %s" % save_fn)
def train(rank, args):
    print('enter train @ %s' % rank, flush=True)
    args.rank = rank
    torch.manual_seed(42)
    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size()
    train_dataset = get_dataset(args)
    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates
    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)
    train_sampler = None
    if args.gpus:
        dist.init_process_group('nccl', rank=rank, world_size=args.world_size)
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=args.gpus, rank=rank, shuffle=False)
    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=xm.xrt_world_size(), rank=rank, shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size if not hasattr(train_dataset, '__getbatch__') else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)
    eval_loader = None
    if args.eval_dir:
        eval_sampler = None
        if args.gpus:
            if args.gpus > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset, num_replicas=args.gpus, rank=rank, shuffle=False)
        else:
            rank = xm.get_ordinal()
            if xm.xrt_world_size() > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset, num_replicas=xm.xrt_world_size(), rank=rank, shuffle=False)
        eval_dataset = get_eval_dataset(args)
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=args.batch_size if not hasattr(train_dataset, '__getbatch__') else None,
            sampler=eval_sampler,
            pin_memory=True,
            shuffle=False,
            num_workers=args.num_workers)
    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)
        ##########################
        ##
        ## Model Creation
        ##
        ##########################
        model = get_model(args)
        model.cuda(rank)
        device = torch.device('cuda:' + str(rank))
        ##########################
        ##
        ## Init Optimizer
        ##
        ##########################
        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),
            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            betas=(0.9, 0.999),
            eps=1e-6,
            lr=args.lr,
            weight_decay=args.weight_decay)
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
        batches = train_loader
    else:
        assert tpu_enabled
        device = xm.xla_device()
        ##########################
        ##
        ## Model Creation
        ##
        ##########################
        model = get_model(args)
        ##########################
        ##
        ## For shared parameters, TPU requires modules to be tied after .to(device)
        ## So we first find the shared parameters first
        ##
        ##########################
        shared_parameters = {e[0]: e[1:] for e in _catalog_shared_params(model)}
        model.to(device)
        do_share_parameters_again(model, shared_parameters, log=rank == 0)
        ##########################
        ##
        ## Init Optimizer
        ##
        ##########################
        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),
            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            lr=args.lr,
            weight_decay=args.weight_decay)
        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)
        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()
        # tracker = xm.RateTracker()
    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            traceback.print_exc()
            model.load_state_dict(states, strict=False)
    model.train()
    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)
    ##########################
    ##
    ## Init LR Scheduler
    ##
    ##########################
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_updates,
        num_training_steps=args.total_num_updates,
    )
    step_i = 0
    err = None
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates)
        while step_i < args.total_num_updates:
            if not args.gpus:
                batches = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break
                report_step = step_i % args.log_interval == 0
                while True:  # the loop only for apex Gradient Overflow
                    optimizer.zero_grad()
                    total_loss, log = get_loss(
                        model, sample, args=args, device=device,
                        gpu=args.gpus, report=report_step)
                    if args.gpus:
                        default_optimizer_step = optimizer.step
                        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        # If Amp detects an overflow, it patches optimizer.step. In other
                        # words, if optimizer.step was left unpatched, there was no
                        # overflow, and we don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break
                        # If an overflow was detected, "optimizer.step" is the patched
                        # call, which does nothing but restore optimizer.step to
                        # default_optimizer_step.
                        optimizer.step()
                        if rank == 0:
                            print("Overflowed, reducing loss scale and replaying batch.",
                                  flush=True)
                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()
                        break
                scheduler.step()
                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss
                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(log, log_formatter))
                    else:
                        xm.add_step_closure(_train_update, args=(log, log_formatter))
                        if args.report_metrics:
                            xm.master_print(met.metrics_report())
                if rank == 0:
                    pbar.update(1)
        if eval_loader is not None:
            model.eval()
            if not args.gpus:
                batches = pl.ParallelLoader(eval_loader, [device]).per_device_loader(device)
            with torch.no_grad():
                record = OrderedDict()
                for sample in batches:
                    evaluate(model, sample, args=args, device=device,
                             record=record, gpu=args.gpus, report=report_step)
                post_evaluate(record, args=args)
            import json
            print('', flush=True)
            print(json.dumps(record), flush=True)
            print('', flush=True)
    except Exception as _err:
        err = _err
    finally:
        save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if rank == 0 and args.gpus:
            torch.save(model.state_dict(), save_fn)
            if err:
                raise err
        else:
            xm.save(model.state_dict(), save_fn)
            if err:
                raise err
def train(self):
    hps = self.state.hps
    ss = self.state
    current_stats = {}
    writer_stats = {}
    # for resuming the learning rate
    sorted_lr_steps = sorted(self.learning_rates.keys())
    lr_index = util.greatest_lower_bound(sorted_lr_steps, ss.data.global_step)
    ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]])
    if ss.model.bn_type != 'none':
        sorted_as_steps = sorted(self.anneal_schedule.keys())
        as_index = util.greatest_lower_bound(sorted_as_steps, ss.data.global_step)
        ss.model.objective.update_anneal_weight(
            self.anneal_schedule[sorted_as_steps[as_index]])
    if ss.model.bn_type in ('vqvae', 'vqvae-ema'):
        ss.model.init_codebook(self.data_iter, 10000)
    start_time = time.time()
    for batch_num, batch in enumerate(self.device_loader):
        wav, mel, voice, jitter, position = batch
        global_step = len(ss.data.dataset) * position[0] + position[1]
        # print(f'replica {self.replica_index}, batch {batch_num}', file=stderr)
        # stderr.flush()
        if batch_num % hps.save_interval == 0 and batch_num != 0:
            self.save_checkpoint(position)
        if hps.skip_loop_body:
            continue
        lr_index = util.greatest_lower_bound(sorted_lr_steps, global_step)
        ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]])
        # if ss.data.global_step in self.learning_rates:
        #     ss.update_learning_rate(self.learning_rates[ss.data.global_step])
        if ss.model.bn_type == 'vae' and ss.step in self.anneal_schedule:
            ss.model.objective.update_anneal_weight(
                self.anneal_schedule[ss.data.global_step])
        ss.optim.zero_grad()
        quant, self.target, loss = self.state.model.run(wav, mel, voice, jitter)
        self.probs = self.softmax(quant)
        self.mel_enc_input = mel
        # print(f'after model.run', file=stderr)
        # stderr.flush()
        loss.backward()
        # print(f'after loss.backward()', file=stderr)
        # stderr.flush()
        if batch_num % hps.progress_interval == 0:
            pars_copy = [p.data.clone() for p in ss.model.parameters()]
            # print(f'after pars_copy', file=stderr)
            # stderr.flush()
        if self.is_tpu:
            xm.optimizer_step(ss.optim)
        else:
            ss.optim.step()
        ss.optim_step += 1
        if ss.model.bn_type == 'vqvae-ema' and ss.data.global_step == 10000:
            ss.model.bottleneck.update_codebook()
        tprb_m = self.avg_prob_target()
        if batch_num % hps.progress_interval == 0:
            iterator = zip(pars_copy, ss.model.named_parameters())
            uw_ratio = {np[0]: t.norm(c - np[1].data) / c.norm() for c, np in iterator}
            writer_stats.update({'uwr': uw_ratio})
            if self.is_tpu:
                count = torch_xla._XLAC._xla_get_replication_devices_count()
                loss_red, tprb_red = xm.all_reduce(
                    'sum', [loss, tprb_m], scale=1.0 / count)
                # loss_red = xm.all_reduce('all_loss', loss, reduce_mean)
                # tprb_red = xm.all_reduce('all_tprb', tprb_m, reduce_mean)
            else:
                loss_red = loss
                tprb_red = tprb_m
            writer_stats.update({
                'loss_r': loss_red,
                'tprb_r': tprb_red,
                'optim_step': ss.optim_step
            })
            current_stats.update({
                'optim_step': ss.optim_step,
                'gstep': global_step,
                # 'gstep': ss.data.global_step,
                'epoch': position[0],
                'step': position[1],
                # 'loss': loss,
                'lrate': ss.optim.param_groups[0]['lr'],
                # 'tprb_m': tprb_m,
                # 'pk_d_m': avg_peak_dist
            })
            current_stats.update(ss.model.objective.metrics)
            if ss.model.bn_type in ('vae',):
                current_stats['free_nats'] = ss.model.objective.free_nats
                current_stats['anneal_weight'] = \
                    ss.model.objective.anneal_weight.item()
            if ss.model.bn_type in ('vqvae', 'vqvae-ema', 'ae', 'vae'):
                current_stats.update(ss.model.encoder.metrics)
            if self.is_tpu:
                xm.add_step_closure(self.train_update, args=(writer_stats, current_stats))
            else:
                self.train_update(writer_stats, current_stats)
            # if not self.is_tpu or xm.is_master_ordinal():
            #     if batch_num in range(25, 50) or batch_num in range(75, 100):
            stderr.flush()
    elapsed = time.time() - start_time