def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)
            # reset mid-epoch stats after each log interval;
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )
        if should_stop:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
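# `validate_and_save` is referenced above but not shown. A simplified sketch of
# the fairseq-style helper, assuming the argparse-era options this variant uses
# (`save_interval`, `save_interval_updates`, `validate_interval`,
# `disable_validation`, `max_update`); the exact stopping conditions vary
# across the forks below:
def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
    num_updates = trainer.get_num_updates()
    do_save = (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) or (
        args.save_interval_updates > 0
        and num_updates > 0
        and num_updates % args.save_interval_updates == 0
    )
    do_validate = (
        (not end_of_epoch and do_save)  # validate during mid-epoch saves
        or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0)
    ) and not args.disable_validation

    valid_losses = [None]
    if do_validate:
        valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
    should_stop = should_stop_early(args, valid_losses[0]) or (
        args.max_update > 0 and num_updates >= args.max_update
    )
    if do_save or should_stop:
        checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    return valid_losses, should_stop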
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task-specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)
            # reset mid-epoch stats after each log interval;
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
def valid_step(self, batch_itr):
    if self.model.global_steps % self.fs_args.validate_interval_updates != 0:
        return
    with torch.no_grad():
        self.model.eval()
        for subset in batch_itr.valid_dataset():
            with metrics.aggregate(new_root=True) as agg:
                for batch, is_dummy_batch in batch_itr.valid_batch():
                    _, sample_size, logging_output = self.task.valid_step(
                        batch, self.model.module.model, self.model.module.criterion
                    )
                    logging_outputs = [logging_output]
                    if is_dummy_batch:
                        if torch.is_tensor(sample_size):
                            sample_size.zero_()
                        else:
                            sample_size *= 0.0
                    logging_outputs, (sample_size,) = torch_reduce_sum(
                        self.model.device,
                        logging_outputs,
                        sample_size,
                        ignore=is_dummy_batch,
                    )
                    logging_output = self.reduce_log(logging_outputs, sample_size)
            log_dist(
                "Valid on step: {}, dataset: {}. {}".format(
                    self.model.global_steps,
                    subset,
                    view_log(agg.get_smoothed_values()),
                ),
                ranks=[0],
            )
def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
    if grad_norm is not None:
        metrics.log_speed("ups", 1.0, priority=100, round=2)
        metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
        if self.args.clip_norm > 0:
            metrics.log_scalar(
                "clip",
                torch.where(
                    grad_norm > self.args.clip_norm,
                    grad_norm.new_tensor(100),
                    grad_norm.new_tensor(0),
                ),
                priority=500,
                round=1,
            )

    with metrics.aggregate() as agg:
        if logging_outputs is not None:
            self.task.reduce_metrics(logging_outputs, self.get_criterion())

        # support legacy interface
        logging_output = agg.get_smoothed_values()
        logging_output["sample_size"] = sample_size
        for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
            if key_to_delete in logging_output:
                del logging_output[key_to_delete]
        return logging_output
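# The helper above relies on fairseq's scoped metrics aggregation: every
# metrics.log_scalar() call made inside a `with metrics.aggregate()` block is
# collected by that aggregator. A minimal, self-contained illustration (the
# metric name "demo_loss" is made up):
from fairseq import metrics

def demo_aggregate():
    with metrics.aggregate() as agg:
        metrics.log_scalar("demo_loss", 1.5, round=3)
        metrics.log_scalar("demo_loss", 2.5, round=3)
        # smoothed values average the logged scalars (default weight 1 each)
        return agg.get_smoothed_values()  # e.g. {"demo_loss": 2.0}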
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses."""
    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False, set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                cfg.common.tensorboard_logdir
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                if (
                    cfg.dataset.max_valid_steps is not None
                    and i > cfg.dataset.max_valid_steps
                ):
                    break
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses
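# `get_valid_stats` is called above but not shown. A sketch of the
# fairseq-style helper in its config-era form, assuming (as in upstream
# fairseq) that checkpoint_utils.save_checkpoint caches the best metric seen
# so far on a `best` attribute; some variants below pass extra arguments:
def get_valid_stats(cfg, trainer, stats):
    stats["num_updates"] = trainer.get_num_updates()
    if hasattr(checkpoint_utils.save_checkpoint, "best"):
        key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
        best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
        stats[key] = best_function(
            checkpoint_utils.save_checkpoint.best,
            stats[cfg.checkpoint.best_checkpoint_metric],
        )
    return stats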
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    model = trainer.model
    val_conf = {
        "encoder": [{"self_attn": []} for i in range(args.encoder_layers)],
        "decoder": [
            {"self_attn": [], "enc_attn": []} for i in range(args.decoder_layers)
        ],
    }
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = tpu_data_loader(args, itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)
                # get confidence for each head
                if args.head_confidence_method is not None:
                    val_conf = get_batch_confs(model, val_conf, args)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])

    if args.head_confidence_method is not None:
        val_conf = convert_confs(val_conf, args)
        val_conf = calc_conf_per_epoch(val_conf, args)
    return valid_losses, val_conf
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    # reset dummy batch only for validation
    trainer._dummy_batch = "DUMMY"

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for step, sample in enumerate(progress):
                trainer.valid_step(sample)
                stats = get_training_stats(agg.get_smoothed_values())
                plog = progress.log
                if hasattr(progress, "wrapped_bar"):
                    plog = progress.wrapped_bar.log
                plog(stats, tag='valid', step=step)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])

    # reset dummy batch again before continuing training
    trainer._dummy_batch = "DUMMY"
    return valid_losses
def train(args, trainer, task, epoch_itr, max_update=math.inf):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, 'tpu', False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )
    progress.log_args(args, tag='train')

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(',')
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)
            # reset mid-epoch stats after each log interval;
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        end_of_epoch = not itr.has_next()
        valid_losses = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
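# Several variants here gate on `should_stop_early`. A sketch of the
# fairseq-style helper, which tracks its state on function attributes and
# signals a stop after `args.patience` consecutive validations without
# improvement:
def should_stop_early(args, valid_loss):
    # skip the check if no validation was done in the current epoch
    if valid_loss is None:
        return False
    if args.patience <= 0:
        return False

    def is_better(a, b):
        return a > b if args.maximize_best_checkpoint_metric else a < b

    prev_best = getattr(should_stop_early, "best", None)
    if prev_best is None or is_better(valid_loss, prev_best):
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    else:
        should_stop_early.num_runs += 1
        if should_stop_early.num_runs >= args.patience:
            logger.info(
                "early stop since valid performance hasn't improved for "
                "last {} runs".format(args.patience)
            )
        return should_stop_early.num_runs >= args.patience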
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        all_preds, all_labels = [], []
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                logging_outputs, preds, labels = trainer.valid_step(sample)
                if preds is not None:
                    all_preds.extend(preds)
                    all_labels.extend(labels)
                else:
                    all_preds, all_labels = None, None

        if all_preds is not None:
            all_preds = torch.cat(all_preds).cpu().numpy()
            all_labels = torch.cat(all_labels).cpu().numpy()

        # log validation stats
        stats = get_valid_stats(
            args, trainer, agg.get_smoothed_values(), all_preds, all_labels
        )
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    assert args.max_sentences_valid == 1, 'Val only supports batch size 1!'
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    shuffle = False
    if args.validation_max_size > 0:
        logging.info(f'Validation set truncated to {args.validation_max_size}.')
        shuffle = True
        assert args.seed == 1234, args.seed
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=1,  # instead of args.required_batch_size_multiple
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=shuffle)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress, start=1):
                if (
                    args.validation_max_size > 0
                    and i > args.validation_max_size / args.distributed_world_size
                ):
                    continue
                trainer.valid_step(
                    sample,
                    validation_topk=args.validation_topk,
                    validation_D=args.validation_D,
                    validation_rounds=args.validation_rounds,
                )

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
def reduce_log(self, logging_outputs, sample_size):
    with metrics.aggregate() as agg:
        if logging_outputs is not None:
            self.task.reduce_metrics(logging_outputs, self.criterion)
            del logging_outputs
        logging_output = agg.get_smoothed_values()
        logging_output["sample_size"] = sample_size
        return logging_output
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    cur_step,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses."""
    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                cfg.common.tensorboard_logdir
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters);
        # values must be logged via metrics.log_scalar("key", val) for them
        # to show up in agg's smoothed values
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample, cur_step=cur_step)

        # log validation stats: stats starts from the OrderedDict returned by
        # agg.get_smoothed_values(); get_valid_stats then adds a few extra entries
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses
def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
    metrics.log_scalar("kl_loss", round(logging_outputs[0]["kl_loss"].item(), 3))
    metrics.log_scalar("kld", round(logging_outputs[0]["kld"].item(), 3))
    metrics.log_scalar("bow_loss", round(logging_outputs[0]["bow_loss"].item(), 3))
    if grad_norm is not None and (
        not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm)
    ):
        metrics.log_speed("ups", 1.0, priority=100, round=2)
        metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
        if self.cfg.optimization.clip_norm > 0:
            metrics.log_scalar(
                "clip",
                torch.where(
                    grad_norm > self.cfg.optimization.clip_norm,
                    grad_norm.new_tensor(100),
                    grad_norm.new_tensor(0),
                ),
                priority=500,
                round=1,
            )

    with metrics.aggregate() as agg:
        if logging_outputs is not None:
            self.task.reduce_metrics(logging_outputs, self.get_criterion())
            del logging_outputs

        # extra warning for criterions that don't properly log a loss value
        if "loss" not in agg:
            if "loss" not in self._warn_once:
                self._warn_once.add("loss")
                logger.warning(
                    "Criterion.reduce_metrics did not log a 'loss' value, "
                    "which may break some functionality"
                )
            metrics.log_scalar("loss", -1)

        # support legacy interface
        if self.tpu:
            logging_output = {}
        else:
            logging_output = agg.get_smoothed_values()
            logging_output["sample_size"] = sample_size
            for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                if key_to_delete in logging_output:
                    del logging_output[key_to_delete]
        return logging_output
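# The warning above fires when a criterion's reduce_metrics never logs a
# "loss" value. For reference, a minimal criterion-side reduce_metrics that
# satisfies the contract, following the pattern of fairseq's cross-entropy
# criterion (in fairseq this lives as a @staticmethod on the criterion class;
# the log keys are illustrative):
def reduce_metrics(logging_outputs) -> None:
    loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
    sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
    # report loss in base-2 units per sample, as fairseq conventionally does
    metrics.log_scalar(
        "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
    )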
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        if "fg_gloss0" in stats:
            criterion = trainer.get_criterion()
            ngroups = criterion.n_groups
            baselines = torch.zeros(ngroups, device='cuda')
            for ii in range(ngroups):
                key = "fg_gloss{}".format(ii)
                baselines[ii] = stats[key]
                stats.pop(key, None)
            criterion.set_valid_baselines(baselines)
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        count = 0
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)
                count += 1
                if count % 50 == 0:
                    logger.info("Processed {} batches!".format(count))

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
def _reduce_and_log_stats(self, logging_outputs, sample_size):
    if logging_outputs is None or len(logging_outputs) == 0:
        return {"sample_size": sample_size}
    with metrics.aggregate() as agg:
        # move logging_outputs to CPU (as double precision) up front so that
        # reduce_metrics does not trigger a separate device-to-host transfer
        # for every individual value
        logging_outputs = utils.apply_to_sample(
            lambda t: t.to(device='cpu', non_blocking=True, dtype=torch.double),
            logging_outputs,
        )
        self.task.reduce_metrics(logging_outputs, self.get_criterion())

        # support legacy interface
        logging_output = agg.get_smoothed_values()
        logging_output["sample_size"] = sample_size
        for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
            if key_to_delete in logging_output:
                del logging_output[key_to_delete]
        return logging_output
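# `utils.apply_to_sample` (used above) recursively applies a function to every
# tensor inside a nested sample container, leaving non-tensor values alone.
# A small self-contained illustration (the sample dict here is made up):
import torch
from fairseq import utils

sample = {"net_input": {"src_tokens": torch.ones(2, 3)}, "ntokens": 6}
double_sample = utils.apply_to_sample(lambda t: t.to(dtype=torch.double), sample)
# double_sample["net_input"]["src_tokens"].dtype is now torch.float64,
# while the plain int "ntokens" passes through unchanged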
def train(
    cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    update_freq = (
        cfg.optimization.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(cfg.optimization.update_freq)
        else cfg.optimization.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            cfg.common.tensorboard_logdir
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(
            cfg.common.wandb_project
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        ),
        azureml_logging=(
            cfg.common.azureml_logging
            if distributed_utils.is_master(cfg.distributed_training)
            else False
        ),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)
                # reset mid-epoch stats after each log interval;
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )
        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
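# `_flatten_config` (passed to progress.update_config above) is a small
# helper; a sketch of the fairseq-style implementation, assuming
# `from omegaconf import OmegaConf` and `from argparse import Namespace`:
def _flatten_config(cfg: DictConfig):
    config = OmegaConf.to_container(cfg)
    # remove any legacy Namespaces and replace them with a single "args" entry
    namespace = None
    for k, v in list(config.items()):
        if isinstance(v, Namespace):
            namespace = v
            del config[k]
    if namespace is not None:
        config["args"] = vars(namespace)
    return config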
def main(cfg: DictConfig, override_args=None):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if cfg.distributed_training.distributed_world_size > 1:
        data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
        data_parallel_rank = distributed_utils.get_data_parallel_rank()
    else:
        data_parallel_world_size = 1
        data_parallel_rank = 0

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(saved_cfg)

    # Build criterion
    criterion = task.build_criterion(saved_cfg.criterion)
    criterion.eval()

    for subset in cfg.dataset.valid_subset.split(","):
        try:
            task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception("Cannot find dataset: " + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=cfg.dataset.max_tokens,
            max_sentences=cfg.dataset.batch_size,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
            seed=cfg.common.seed,
            num_shards=data_parallel_world_size,
            shard_id=data_parallel_rank,
            num_workers=cfg.dataset.num_workers,
            data_buffer_size=cfg.dataset.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if data_parallel_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=cfg.common.all_gather_list_size,
                group=distributed_utils.get_data_parallel_group(),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
def main(args, override_args=None):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu
    if use_cuda:
        torch.cuda.set_device(args.device_id)

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception('Cannot find dataset: ' + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
            data_buffer_size=args.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if args.distributed_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=getattr(args, 'all_gather_list_size', 16384),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    is_multi_itr = isinstance(epoch_itr, list)
    if is_multi_itr:
        # Initialize one data iterator per epoch iterator
        itrs = []
        for e_itr in epoch_itr:
            itrs.append(
                e_itr.next_epoch_itr(
                    fix_batches_to_gpus=args.fix_batches_to_gpus,
                    shuffle=(e_itr.next_epoch_idx > args.curriculum),
                )
            )
        update_freq = (
            args.update_freq[epoch_itr[0].epoch - 1]
            if epoch_itr[0].epoch <= len(args.update_freq)
            else args.update_freq[-1]
        )
        itr = [iterators.GroupedIterator(it, update_freq) for it in itrs]
        # TPU data loading is not supported for multiple iterators
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr[0].epoch,
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format="simplecluster",
        )
        trainer.begin_epoch(epoch_itr[0].epoch)
    else:
        # Initialize data iterator
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
        )
        update_freq = (
            args.update_freq[epoch_itr.epoch - 1]
            if epoch_itr.epoch <= len(args.update_freq)
            else args.update_freq[-1]
        )
        itr = iterators.GroupedIterator(itr, update_freq)
        # TPU data loading is not supported here
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
        )
        trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):
        if 'cluster_ids' not in samples[0]['net_input']:
            samples[0]['net_input']['cluster_ids'] = numpy.full(
                (1), 0, dtype=numpy.single
            )
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % args.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)
                # reset mid-epoch stats after each log interval;
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        if is_multi_itr:
            end_of_epoch = not itr[0].has_next()
            valid_losses, should_stop = validate_and_save(
                args, trainer, task, epoch_itr[0], valid_subsets, end_of_epoch
            )
        else:
            end_of_epoch = not itr.has_next()
            valid_losses, should_stop = validate_and_save(
                args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
            )
        if should_stop:
            break

    # log end-of-epoch stats
    epoch_num = epoch_itr[0].epoch if is_multi_itr else epoch_itr.epoch
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_num))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
def train(args, trainer, task, epoch_itr, m_mle=None):
    """Train the model for one epoch."""
    global model_old
    global model_mle
    model_old = copy.deepcopy(trainer.model)
    if m_mle is None:
        model_mle = model_old
    else:
        model_mle = m_mle

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task-specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False
    for samples in progress:
        # warning: validates (and possibly saves) before every training step
        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)

        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)
            # reset mid-epoch stats after each log interval;
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if num_updates > 2 and num_updates % args.policy_update_per_k_epoch == 0:
            # warning: refresh the frozen "old" policy model
            del model_old
            torch.cuda.empty_cache()
            model_old = copy.deepcopy(trainer.model)

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    logger.info("begin training epoch {}".format(epoch_itr.epoch))
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    time_cost = 0
    for i, samples in enumerate(progress):
        # throughput-measurement bookkeeping
        if args.validate_training_performance:
            performance_end_its = (
                args.performance_begin_its + args.performance_its_count - 1
            )
            if i == args.performance_begin_its:
                processed_tokens = 0

        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            time_begin = time.time()
            log_output = trainer.train_step(samples)
            time_end = time.time()
            if (
                args.validate_training_performance
                and i >= args.performance_begin_its
                and i <= performance_end_its
            ):
                time_cost = time_cost + (time_end - time_begin)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)
            # reset mid-epoch stats after each log interval;
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if args.validate_training_performance and i >= args.performance_begin_its:
            for sample in samples:
                net_input = sample['net_input']
                bs, src_lens = net_input['src_tokens'].shape
                processed_tokens += bs * src_lens
        if args.validate_training_performance and i == performance_end_its:
            logger.info("Performance info:")
            logger.info("Begin iteration: {}".format(args.performance_begin_its))
            logger.info("End iteration: {}".format(performance_end_its))
            logger.info("Processed tokens: {}".format(processed_tokens))
            logger.info("Time cost: {} s".format(time_cost))
            logger.info("Throughput: {} tokens/s".format(processed_tokens / time_cost))
            should_stop = True

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task-specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False
    for samples in progress:
        with metrics.aggregate('train_inner'):
            try:
                log_output = trainer.train_step(samples)
            except ResetTrainerException:
                # drop the wrapped criterion/model/optimizer so they are
                # rebuilt lazily, then retry the step once
                trainer._wrapped_criterion = None
                trainer._wrapped_model = None
                trainer._optimizer = None
                logger.info("reset the trainer at {}".format(trainer.get_num_updates()))
                log_output = trainer.train_step(samples)

            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)
            # reset mid-epoch stats after each log interval;
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
def _reduce_and_log_stats(
    self,
    logging_outputs,
    sample_size,
    grad_norm=None,
    gvar=None,
    adam_mom2=None,
    gvar_diff=None,
    xstd=None,
    ams_mom=None,
    acc_ratio=None,
    real_var=None,
    real_var_diff=None,
    ad_beta=None,
    lr_min=None,
    lr_max=None,
    lr_median=None,
    update_min=None,
    update_max=None,
    update_median=None,
    valid_ratio=None,
    var_adapt=None,
):
    if grad_norm is not None:
        metrics.log_speed("ups", 1.0, priority=100, round=2)
        metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
        if self.args.clip_norm > 0:
            metrics.log_scalar(
                "clip",
                torch.where(
                    grad_norm > self.args.clip_norm,
                    grad_norm.new_tensor(100),
                    grad_norm.new_tensor(0),
                ),
                priority=500,
                round=1,
            )

    # optional optimizer diagnostics, logged only when provided
    optional_scalars = [
        ("gvar", gvar, 100),
        ("adam_mom2", adam_mom2, 100),
        ("gvar_diff", gvar_diff, 100),
        ("xstd", xstd, 100),
        ("ams_mom", ams_mom, 100),
        ("acc_ratio", acc_ratio, 50),
        ("real_var", real_var, 50),
        ("real_var_diff", real_var_diff, 50),
        ("ad_beta", ad_beta, 50),
        ("lr_min", lr_min, 50),
        ("lr_max", lr_max, 50),
        ("lr_median", lr_median, 50),
        ("update_min", update_min, 50),
        ("update_median", update_median, 50),
        ("update_max", update_max, 50),
        ("valid_ratio", valid_ratio, 49),
        ("var_adapt", var_adapt, 1),
    ]
    for name, value, priority in optional_scalars:
        if value is not None:
            metrics.log_scalar(name, value, priority=priority)

    with metrics.aggregate() as agg:
        if logging_outputs is not None:
            self.task.reduce_metrics(logging_outputs, self.get_criterion())
            preds, labels = [], []
            for log_output in logging_outputs:
                if 'preds' in log_output:
                    preds.append(log_output['preds'])
                    labels.append(log_output['labels'])
        else:
            preds = None
            labels = None

        # support legacy interface
        logging_output = agg.get_smoothed_values()
        logging_output["sample_size"] = sample_size
        for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
            if key_to_delete in logging_output:
                del logging_output[key_to_delete]
        return logging_output, preds, labels
def train(args, trainer, task, epoch_itr, model, experiment_path,
          total_samples=None, last_epoch_num=0, restore=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    num_heads = args.decoder_attention_heads
    head_dim = args.decoder_embed_dim // num_heads
    if experiment_path is not None:
        with open(experiment_path, 'r') as f:
            swaps = json.load(f)
        mhr(model, swaps, head_dim, num_heads, epoch_itr.epoch)

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    conf = {
        "encoder": [{"self_attn": []} for i in range(args.encoder_layers)],
        "decoder": [
            {"self_attn": [], "enc_attn": []} for i in range(args.decoder_layers)
        ],
    }
    attentions = {"decoder": [{"self_attn": []} for i in range(args.decoder_layers)]}
    batch_regression = 1.0 - (total_samples / (160239 * 50))
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples, batch_num=batch_regression)
            if log_output is None:  # OOM, overflow, ...
                continue

        total_samples += model.decoder.layers[0].self_attn.bsz
        # need a more generic way to find the total sample and epoch counts
        batch_regression = 1.0 - (total_samples / (160239 * 40))

        # get confidence for each head
        if args.head_confidence_method is not None:
            conf = get_batch_confs(model, conf, args)

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)
            # reset mid-epoch stats after each log interval;
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop, val_conf = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )
        if should_stop:
            break

    if args.head_confidence_method is not None:
        conf = convert_confs(conf, args)
        path = args.save_dir.replace("checkpoints", "confs") + "-method={0}".format(
            args.head_confidence_method
        )
        try:
            os.mkdir(path, 0o775)
        except OSError:
            pass
        with open(path + "/epoch-{0}.pkl".format(epoch_itr.epoch), 'wb') as fd:
            pickle.dump(conf, fd, protocol=3)

    if args.dynamic_type is not None and args.head_confidence_method is not None:
        conf = val_conf
        restore['enc_self_attn'], last_epoch_num['enc_self_attn'] = dynamic_mhr(
            model,
            int(args.start_dynamic_mhr[0]),
            "encoder",
            "self_attn",
            restore['enc_self_attn'],
            int(args.dynamic_swap_frequency[0]),
            last_epoch_num['enc_self_attn'],
            epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[0]),
            conf[0],
            num_heads,
            head_dim,
            args.encoder_layers,
            local_only=False,
            d_type=args.dynamic_type[0],
            rest=int(args.dynamic_rest[0]),
            end_epoch=int(args.dynamic_end_epoch[0]),
        )
        restore['dec_self_attn'], last_epoch_num['dec_self_attn'] = dynamic_mhr(
            model,
            int(args.start_dynamic_mhr[1]),
            "decoder",
            "self_attn",
            restore['dec_self_attn'],
            int(args.dynamic_swap_frequency[1]),
            last_epoch_num['dec_self_attn'],
            epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[1]),
            conf[1],
            num_heads,
            head_dim,
            args.encoder_layers,
            local_only=False,
            d_type=args.dynamic_type[1],
            rest=int(args.dynamic_rest[1]),
            end_epoch=int(args.dynamic_end_epoch[1]),
        )
        restore['dec_enc_attn'], last_epoch_num['dec_enc_attn'] = dynamic_mhr(
            model,
            int(args.start_dynamic_mhr[2]),
            "decoder",
            "encoder_attn",
            restore['dec_enc_attn'],
            int(args.dynamic_swap_frequency[2]),
            last_epoch_num['dec_enc_attn'],
            epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[2]),
            conf[2],
            num_heads,
            head_dim,
            args.encoder_layers,
            local_only=False,
            d_type=args.dynamic_type[2],
            rest=int(args.dynamic_rest[2]),
            end_epoch=int(args.dynamic_end_epoch[2]),
        )

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop, total_samples, restore, last_epoch_num
def train(args, trainer, task, epoch_itr, max_update=math.inf, model=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task-specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    for i, samples in enumerate(progress):
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)
            # reset mid-epoch stats after each log interval;
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        # print attention norms once at the start of each epoch
        if i == 0:
            print('epoch: ', epoch_itr.epoch)
            endeattn_norm = []
            selfattn_norm = []
            for m in model.modules():
                if hasattr(m, 'selfattn_norm') and m.selfattn_norm is not None:
                    selfattn_norm.append(m.selfattn_norm)
                if hasattr(m, 'endeattn_norm') and m.endeattn_norm is not None:
                    endeattn_norm.append(m.endeattn_norm)
            print('self attention norms: ', selfattn_norm)
            print('en/decoder attn norms:', endeattn_norm)

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
def validate_iw(args, trainer, task, epoch_itr, subsets, prune=-1, mode='iw'):
    """Evaluate the model on the validation set(s) and return the losses."""
    if mode == 'none' or mode == 'time' or args.criterior == 'lm_baseline':
        return [0]

    # use top-k instead of sampling to approximate the sum over prototypes
    # during evaluation
    for subset in subsets:
        task.dataset(subset).set_sampling(False)

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=1,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        if prune > 0:
            index_map = trainer.get_model().set_prune_index(prune)
            task.set_index_map(index_map)

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_iw_step(sample, mode=mode)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag='valid_iw', step=trainer.get_num_updates())
        # valid_losses.append(stats[args.best_checkpoint_metric])

        if prune > 0:
            trainer.get_model().reset_prune_index()
            task.reset_index_map()
    return valid_losses
def validate(args, trainer, task, epoch_itr, subsets, prune=-1):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        # added by Junxian
        if prune > 0:
            index_map = trainer.get_model().set_prune_index(prune)
            task.set_index_map(index_map)

        # do not write templates when profiling time
        write_template_flag = False if args.eval_mode == 'time' else True

        # only one worker deals with the template file in DDP
        if args.distributed_rank == 0 and write_template_flag:
            print('write template files')
            if args.eval_mode == 'none':
                fout = open(
                    os.path.join(
                        args.save_dir,
                        'templates_{}_{}.txt'.format(
                            epoch_itr.epoch, trainer.get_num_updates()
                        ),
                    ),
                    'w',
                )
            else:
                fout = open(
                    os.path.join(args.save_dir, 'templates_eval_{}.txt'.format(subset)),
                    'w',
                )
            if prune <= 0:
                task.write_lambda(fout, trainer.get_model())
        else:
            fout = None

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample, split=subset)
                # added by Junxian
                if args.distributed_rank == 0:
                    task.write_template(sample, trainer.get_model(), fout)

        if fout is not None:
            fout.close()

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
def main(args, override_args=None):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if override_args is not None:
        try:
            override_args = override_args['override_args']
        except TypeError:
            pass  # already a Namespace
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    if use_fp16:
        criterion.half()
    if use_cuda:
        criterion.cuda()
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception('Cannot find dataset: ' + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_workers=args.num_workers,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            sample = utils.apply_to_sample(
                lambda t: t.half() if t.dtype is torch.float32 else t, sample
            ) if use_fp16 else sample
            try:
                with torch.no_grad():  # do not keep backward graphs
                    max_num_rays = 900 * 900
                    if sample['uv'].shape[3] > max_num_rays:
                        sample['ray_split'] = sample['uv'].shape[3] // max_num_rays
                    _loss, _sample_size, log_output = task.valid_step(
                        sample, model, criterion
                    )
                progress.log(log_output, step=i)
                log_outputs.append(log_output)
            except TypeError:
                break

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        # summarize across all workers
        if args.distributed_world_size > 1:
            all_log_output = list(
                zip(*distributed_utils.all_gather_list([log_output]))
            )[0]
            log_output = {
                key: np.mean([log[key] for log in all_log_output])
                for key in all_log_output[0]
            }

        progress.print(log_output, tag=subset, step=i)