def train(args, trainer, task, epoch_itr): """Train the model for one epoch and return validation losses.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(epoch_itr.next_epoch_idx > args.curriculum), ) update_freq = ( args.update_freq[epoch_itr.epoch - 1] if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) if getattr(args, "tpu", False): itr = tpu_data_loader(args, itr) progress = progress_bar.progress_bar( itr, log_format=args.log_format, log_interval=args.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=( args.tensorboard_logdir if distributed_utils.is_master(args) else None ), default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) trainer.begin_epoch(epoch_itr.epoch) valid_subsets = args.valid_subset.split(",") should_stop = False for i, samples in enumerate(progress): with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function("train_step-%d" % i): log_output = trainer.train_step(samples) if log_output is None: # OOM, overflow, ... continue # log mid-epoch stats num_updates = trainer.get_num_updates() if num_updates % args.log_interval == 0: stats = get_training_stats(metrics.get_smoothed_values("train_inner")) progress.log(stats, tag="train_inner", step=num_updates) # reset mid-epoch stats after each log interval # the end-of-epoch stats will still be preserved metrics.reset_meters("train_inner") end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save( args, trainer, task, epoch_itr, valid_subsets, end_of_epoch ) if should_stop: break # log end-of-epoch stats stats = get_training_stats(metrics.get_smoothed_values("train")) progress.print(stats, tag="train", step=num_updates) # reset epoch-level meters metrics.reset_meters("train") return valid_losses, should_stop
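# Note: the variants in this collection call a get_training_stats() helper that
# is not defined here. The sketch below is modeled on upstream fairseq's
# train.py and is an assumption; the exact meter plumbing may differ per fork.
from typing import Any, Dict

from fairseq.logging import metrics


def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
    # attach wall-clock time from the default stopwatch meter, if it exists
    wall_meter = metrics.get_meter("default", "wall")
    if wall_meter is not None:
        stats["wall"] = round(wall_meter.elapsed_time, 0)
    return stats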
def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(epoch_itr.next_epoch_idx > args.curriculum), ) update_freq = (args.update_freq[epoch_itr.epoch - 1] if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]) itr = iterators.GroupedIterator(itr, update_freq) progress = progress_bar.progress_bar( itr, log_format=args.log_format, log_interval=args.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=(args.tensorboard_logdir if distributed_utils.is_master(args) else None), default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), ) # task specific setup per epoch task.begin_epoch(epoch_itr.epoch, trainer.get_model()) valid_subsets = args.valid_subset.split(',') max_update = args.max_update or math.inf for samples in progress: with metrics.aggregate('train_inner'): log_output = trainer.train_step(samples) if log_output is None: # OOM, overflow, ... continue # log mid-epoch stats num_updates = trainer.get_num_updates() if num_updates % args.log_interval == 0: stats = get_training_stats( metrics.get_smoothed_values('train_inner')) progress.log(stats, tag='train_inner', step=num_updates) # reset mid-epoch stats after each log interval # the end-of-epoch stats will still be preserved metrics.reset_meters('train_inner') if (not args.disable_validation and args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0): valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) if num_updates >= max_update: break # log end-of-epoch stats stats = get_training_stats(metrics.get_smoothed_values('train')) progress.print(stats, tag='train', step=num_updates) # reset epoch-level meters metrics.reset_meters('train')
def train(args, trainer, task, epoch_itr, max_update=math.inf):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, 'tpu', False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )
    progress.log_args(args, tag='train')

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(',')
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        end_of_epoch = not itr.has_next()
        valid_losses = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
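# should_stop_early(), referenced above and in several variants below, is also
# not defined in this collection. A minimal sketch of patience-based early
# stopping, modeled on upstream fairseq; the argument names are assumptions.
def should_stop_early(args, valid_loss):
    # skip the check if no validation loss is available for this round
    if valid_loss is None:
        return False
    if getattr(args, 'patience', 0) <= 0:
        return False

    def is_better(a, b):
        return a > b if args.maximize_best_checkpoint_metric else a < b

    prev_best = getattr(should_stop_early, 'best', None)
    if prev_best is None or is_better(valid_loss, prev_best):
        # new best value: remember it and reset the patience counter
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    else:
        should_stop_early.num_runs += 1
        # stop once the metric has not improved for `patience` validations
        return should_stop_early.num_runs >= args.patience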
def tmp():
    fs_args, ds_config = gen_ds_fairseq_arg()
    set_seed(fs_args.seed)
    task = tasks.setup_task(fs_args)
    trainer = DsFairseqTrainer(fs_args, ds_config, task)
    batch_itr = BatchIterator(fs_args, task)
    for epoch in batch_itr.train_epoch():
        train(batch_itr, trainer)
        log_dist(
            f'Finish epoch {epoch}, \
            {view_log(metrics.get_smoothed_values("train"))}',
            [0],
        )
        metrics.reset_meters("train")
def train_step(self, sample, is_dummy_batch):
    self.model.train()
    self.model.zero_grad()

    loss, sample_size, logging_output = self.model(sample)

    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.0
        loss *= 0.0

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)

    logging_outputs, (sample_size, ) = torch_reduce_sum(
        self.model.device, [logging_output], sample_size, ignore=is_dummy_batch
    )

    final_loss = loss * (dist.get_world_size() / sample_size)
    self.model.backward(final_loss)
    self.model.step()

    logging_output = self.reduce_log(logging_outputs, sample_size)

    if self.model.global_steps % self.model.steps_per_print() != 0:
        return

    log_dist(
        f'Step: {self.model.global_steps}, \
        {view_log(metrics.get_smoothed_values("train_inner"))}',
        [0],
    )
    metrics.reset_meters("train_inner")
def train(
    cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    update_freq = (
        cfg.optimization.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(cfg.optimization.update_freq)
        else cfg.optimization.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            cfg.common.tensorboard_logdir
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(
            cfg.common.wandb_project
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        ),
        azureml_logging=(
            cfg.common.azureml_logging
            if distributed_utils.is_master(cfg.distributed_training)
            else False
        ),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
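# _flatten_config(), used by the DictConfig variant above to seed the progress
# bar's config logging, is not included in this collection. A rough sketch of
# what it does in upstream fairseq (flatten the OmegaConf config and fold any
# legacy argparse.Namespace into an "args" sub-dict); details may vary by fork.
import argparse

from omegaconf import DictConfig, OmegaConf


def _flatten_config(cfg: DictConfig):
    config = OmegaConf.to_container(cfg)
    # remove any legacy argparse.Namespace entries and expose them under "args"
    namespace = None
    for k, v in list(config.items()):
        if isinstance(v, argparse.Namespace):
            namespace = v
            del config[k]
    if namespace is not None:
        config["args"] = vars(namespace)
    return config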
def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(epoch_itr.next_epoch_idx > args.curriculum), ) update_freq = (args.update_freq[epoch_itr.epoch - 1] if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]) itr = iterators.GroupedIterator(itr, update_freq) progress = progress_bar.progress_bar( itr, log_format=args.log_format, log_interval=args.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=(args.tensorboard_logdir if distributed_utils.is_master(args) else None), default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), ) # task specific setup per epoch task.begin_epoch(epoch_itr.epoch, trainer.get_model()) valid_subsets = args.valid_subset.split(',') max_update = args.max_update or math.inf should_end_training = False for samples in progress: with metrics.aggregate('train_inner'): try: log_output = trainer.train_step(samples) except ResetTrainerException: trainer._wrapped_criterion = None trainer._wrapped_model = None trainer._optimizer = None logger.info("reset the trainer at {}".format( trainer.get_num_updates())) log_output = trainer.train_step(samples) if log_output is None: # OOM, overflow, ... continue # log mid-epoch stats num_updates = trainer.get_num_updates() if num_updates % args.log_interval == 0: stats = get_training_stats( metrics.get_smoothed_values('train_inner')) progress.log(stats, tag='train_inner', step=num_updates) # reset mid-epoch stats after each log interval # the end-of-epoch stats will still be preserved metrics.reset_meters('train_inner') valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets) if should_stop_early(args, valid_losses[0]) or num_updates >= max_update: should_end_training = True break # log end-of-epoch stats stats = get_training_stats(metrics.get_smoothed_values('train')) progress.print(stats, tag='train', step=num_updates) # reset epoch-level meters metrics.reset_meters('train') return should_end_training
def train(args, trainer, task, epoch_itr, max_update=math.inf, model=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    for i, samples in enumerate(progress):
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if i == 0:
            print('epoch: ', epoch_itr.epoch)
            endeattn_norm = []
            selfattn_norm = []
            for m in model.modules():
                if hasattr(m, 'selfattn_norm') and m.selfattn_norm is not None:
                    selfattn_norm.append(m.selfattn_norm)
                if hasattr(m, 'endeattn_norm') and m.endeattn_norm is not None:
                    endeattn_norm.append(m.endeattn_norm)
            print('self attention norms: ', selfattn_norm)
            print('en/decoder attn norms:', endeattn_norm)

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
def train(args, trainer, task, epoch_itr, model, experiment_path,
          total_samples=None, last_epoch_num=0, restore=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    num_heads = args.decoder_attention_heads
    head_dim = args.decoder_embed_dim // num_heads
    if experiment_path is not None:
        with open(experiment_path, 'r') as f:
            swaps = json.load(f)
        mhr(model, swaps, head_dim, num_heads, epoch_itr.epoch)

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    conf = {
        "encoder": [{"self_attn": []} for i in range(args.encoder_layers)],
        "decoder": [
            {"self_attn": [], "enc_attn": []} for i in range(args.decoder_layers)
        ],
    }
    attentions = {
        "decoder": [{"self_attn": []} for i in range(args.decoder_layers)]
    }
    batch_regression = 1.0 - (total_samples / (160239 * 50))
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples, batch_num=batch_regression)
            if log_output is None:  # OOM, overflow, ...
                continue

        total_samples += model.decoder.layers[0].self_attn.bsz
        batch_regression = 1.0 - (
            total_samples / (160239 * 40)
        )  # need to find more generic way to find total samples and epoch num.

        # Get Confidence for each Head.
        if args.head_confidence_method is not None:
            conf = get_batch_confs(model, conf, args)

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop, val_conf = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )
        if should_stop:
            break

    if args.head_confidence_method is not None:
        conf = convert_confs(conf, args)
        path = args.save_dir.replace("checkpoints", "confs") + "-method={0}".format(
            args.head_confidence_method
        )
        try:
            os.mkdir(path, 0o775)
        except:
            pass
        with open(
            args.save_dir.replace("checkpoints", "confs")
            + "-method={0}".format(args.head_confidence_method)
            + "/epoch-{0}.pkl".format(epoch_itr.epoch),
            'wb',
        ) as fd:
            pickle.dump(conf, fd, protocol=3)

    if args.dynamic_type is not None and args.head_confidence_method is not None:
        conf = val_conf
        restore['enc_self_attn'], last_epoch_num['enc_self_attn'] = dynamic_mhr(
            model,
            int(args.start_dynamic_mhr[0]),
            "encoder",
            "self_attn",
            restore['enc_self_attn'],
            int(args.dynamic_swap_frequency[0]),
            last_epoch_num['enc_self_attn'],
            epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[0]),
            conf[0],
            num_heads,
            head_dim,
            args.encoder_layers,
            local_only=False,
            d_type=args.dynamic_type[0],
            rest=int(args.dynamic_rest[0]),
            end_epoch=int(args.dynamic_end_epoch[0]),
        )
        restore['dec_self_attn'], last_epoch_num['dec_self_attn'] = dynamic_mhr(
            model,
            int(args.start_dynamic_mhr[1]),
            "decoder",
            "self_attn",
            restore['dec_self_attn'],
            int(args.dynamic_swap_frequency[1]),
            last_epoch_num['dec_self_attn'],
            epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[1]),
            conf[1],
            num_heads,
            head_dim,
            args.encoder_layers,
            local_only=False,
            d_type=args.dynamic_type[1],
            rest=int(args.dynamic_rest[1]),
            end_epoch=int(args.dynamic_end_epoch[1]),
        )
        restore['dec_enc_attn'], last_epoch_num['dec_enc_attn'] = dynamic_mhr(
            model,
            int(args.start_dynamic_mhr[2]),
            "decoder",
            "encoder_attn",
            restore['dec_enc_attn'],
            int(args.dynamic_swap_frequency[2]),
            last_epoch_num['dec_enc_attn'],
            epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[2]),
            conf[2],
            num_heads,
            head_dim,
            args.encoder_layers,
            local_only=False,
            d_type=args.dynamic_type[2],
            rest=int(args.dynamic_rest[2]),
            end_epoch=int(args.dynamic_end_epoch[2]),
        )

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop, total_samples, restore, last_epoch_num
def train(args, trainer, task, epoch_itr, m_mle=None):
    global model_old
    global model_mle
    model_old = copy.deepcopy(trainer.model)
    if m_mle is None:
        model_mle = model_old
    else:
        model_mle = m_mle

    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False
    for samples in progress:
        if True:  # warning
            valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)

        with metrics.aggregate('train_inner'):
            # Debug: training goes here
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # Log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # Reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if num_updates > 2 and num_updates % (args.policy_update_per_k_epoch) == 0:
            # warning
            del model_old
            torch.cuda.empty_cache()
            model_old = copy.deepcopy(trainer.model)

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # Log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # Reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
def train(args, trainer, task, epoch_itr): """Train the model for one epoch and return validation losses.""" logger.info("begin training epoch {}".format(epoch_itr.epoch)) # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(epoch_itr.next_epoch_idx > args.curriculum), ) update_freq = (args.update_freq[epoch_itr.epoch - 1] if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]) itr = iterators.GroupedIterator(itr, update_freq) if getattr(args, "tpu", False): itr = tpu_data_loader(args, itr) progress = progress_bar.progress_bar( itr, log_format=args.log_format, log_interval=args.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=(args.tensorboard_logdir if distributed_utils.is_master(args) else None), default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) trainer.begin_epoch(epoch_itr.epoch) valid_subsets = args.valid_subset.split(",") should_stop = False time_cost = 0 for i, samples in enumerate(progress): ##### statistic program if args.validate_training_performance: performance_end_its = args.performance_begin_its + args.performance_its_count - 1 if args.validate_training_performance and i == args.performance_begin_its: processed_tokens = 0 with metrics.aggregate( "train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i): time_begin = time.time() log_output = trainer.train_step(samples) time_end = time.time() if args.validate_training_performance and i >= args.performance_begin_its and i <= performance_end_its: time_cost = time_cost + (time_end - time_begin) if log_output is None: # OOM, overflow, ... continue # log mid-epoch stats num_updates = trainer.get_num_updates() if num_updates % args.log_interval == 0: stats = get_training_stats( metrics.get_smoothed_values("train_inner")) progress.log(stats, tag="train_inner", step=num_updates) # reset mid-epoch stats after each log interval # the end-of-epoch stats will still be preserved metrics.reset_meters("train_inner") end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch) if args.validate_training_performance and i >= args.performance_begin_its: for sample in samples: net_input = sample['net_input'] bs, src_lens = net_input['src_tokens'].shape processed_tokens += bs * src_lens if args.validate_training_performance and i == performance_end_its: logger.info("Performance info:") logger.info("Begin iteration:{}".format( args.performance_begin_its)) logger.info("End iteration: {}".format(performance_end_its)) logger.info("Processed_tokens: {}".format(processed_tokens)) logger.info("Time cost: {} s".format(time_cost)) logger.info("Throughput:{} tokens/s".format(processed_tokens / (time_cost))) should_stop = True if should_stop: break # log end-of-epoch stats logger.info("end of epoch {} (average epoch stats below)".format( epoch_itr.epoch)) stats = get_training_stats(metrics.get_smoothed_values("train")) progress.print(stats, tag="train", step=num_updates) # reset epoch-level meters metrics.reset_meters("train") return valid_losses, should_stop
def train(args, trainer, task, epoch_itr): """Train the model for one epoch and return validation losses.""" if isinstance(epoch_itr, list): itrs = [] for itr in epoch_itr: # Initialize data iterators itrs.append( itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(itr.next_epoch_idx > args.curriculum), )) update_freq = (args.update_freq[epoch_itr[0].epoch - 1] if epoch_itr[0].epoch <= len(args.update_freq) else args.update_freq[-1]) grouped_itrs = [] for itr in itrs: grouped_itrs.append(iterators.GroupedIterator(itr, update_freq)) # not supported # if getattr(args, "tpu", False): # itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( grouped_itrs, log_format=args.log_format, log_interval=args.log_interval, epoch=epoch_itr[0].epoch, tensorboard_logdir=(args.tensorboard_logdir if distributed_utils.is_master(args) else None), default_log_format=("simplecluster"), ) trainer.begin_epoch(epoch_itr[0].epoch) else: # Initialize data iterators itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(epoch_itr.next_epoch_idx > args.curriculum), ) update_freq = (args.update_freq[epoch_itr.epoch - 1] if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]) itr = iterators.GroupedIterator(itr, update_freq) # not supported # if getattr(args, "tpu", False): # itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, log_format=args.log_format, log_interval=args.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=(args.tensorboard_logdir if distributed_utils.is_master(args) else None), default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) trainer.begin_epoch(epoch_itr.epoch) valid_subsets = args.valid_subset.split(",") should_stop = False num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): if 'cluster_ids' not in samples[0]['net_input']: samples[0]['net_input']['cluster_ids'] = numpy.full( (1), 0, dtype=numpy.single) with metrics.aggregate( "train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i): log_output = trainer.train_step(samples) if log_output is not None: # not OOM, overflow, ... # log mid-epoch stats num_updates = trainer.get_num_updates() if num_updates % args.log_interval == 0: stats = get_training_stats( metrics.get_smoothed_values("train_inner")) progress.log(stats, tag="train_inner", step=num_updates) # reset mid-epoch stats after each log interval # the end-of-epoch stats will still be preserved metrics.reset_meters("train_inner") if isinstance(itr, list): end_of_epoch = not itr[0].has_next() valid_losses, should_stop = validate_and_save( args, trainer, task, epoch_itr[0], valid_subsets, end_of_epoch) if should_stop: break else: end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save( args, trainer, task, epoch_itr, valid_subsets, end_of_epoch) if should_stop: break # log end-of-epoch stats logger.info("end of epoch {} (average epoch stats below)".format( epoch_itr[0].epoch)) stats = get_training_stats(metrics.get_smoothed_values("train")) progress.print(stats, tag="train", step=num_updates) # reset epoch-level meters metrics.reset_meters("train") return valid_losses, should_stop