def main(args, init_distributed=False):
    """Set up task, model and trainer, then train epoch by epoch.

    Training continues while the learning rate stays above ``args.min_lr``
    and neither the epoch budget (``--max-epoch``) nor the update budget
    (``--max-update``) has been exhausted.

    Args:
        args: parsed command-line namespace.
        init_distributed: when true, initialise distributed training and
            store the returned rank in ``args.distributed_rank``.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # CUDA device, RNG seed and (optionally) the distributed process group.
    use_cuda = torch.cuda.is_available() and not args.cpu
    if use_cuda:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    print(args)

    task = tasks.setup_task(args)
    # Only the validation splits are loaded here; the training split is
    # loaded together with the latest checkpoint further down.
    for split_name in args.valid_subset.split(','):
        task.load_dataset(split_name, combine=True, epoch=0)

    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    params = list(model.parameters())
    n_params = sum(p.numel() for p in params)
    n_trainable = sum(p.numel() for p in params if p.requires_grad)
    print('| num. model params: {} (num. trained: {})'.format(n_params, n_trainable))

    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens, args.max_sentences))

    # Resume from the latest checkpoint (if any); this also provides the
    # epoch iterator to continue from.
    _, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    while (lr > args.min_lr
           and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        train(args, trainer, task, epoch_itr)

        run_validation = (not args.disable_validation
                          and epoch_itr.epoch % args.validate_interval == 0)
        if run_validation:
            losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            losses = [None]
        first_loss = losses[0]

        # Only the first validation subset drives the LR schedule.
        lr = trainer.lr_step(epoch_itr.epoch, first_loss)

        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, first_loss)

        # A ':' in --data marks sharded data, which needs a fresh train
        # iterator for the next epoch.
        if ':' in getattr(args, 'data', ''):
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
def main(args, init_distributed=False):
    """Set up the task, model and trainer, then train until a stop criterion fires.

    Training continues while all of the following hold:

    * the learning rate is at least ``args.min_lr``, or the update count is
      still within ``args.warmup_updates`` (default 0);
    * the current epoch is below ``max_epoch``, or equal to it while
      ``epoch_itr._next_epoch_itr`` is still pending (presumably a resumed,
      partially finished epoch);
    * fewer than ``max_update`` updates have been made.

    Args:
        args: parsed command-line namespace.
        init_distributed: when true, initialise distributed training and
            store the returned rank in ``args.distributed_rank``.
    """
    utils.import_user_module(args)

    # Best-effort registration of an optional path manager; silently skipped
    # when the module is not installed.
    try:
        from fairseq.fb_pathmgr import fb_pathmgr
        global fb_pathmgr_registerd
        if not fb_pathmgr_registerd:
            fb_pathmgr.register()
            fb_pathmgr_registerd = True
    except (ModuleNotFoundError, ImportError):
        pass

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max input frames per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    _, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
    # Some criteria need the training targets (only called if supported).
    if callable(getattr(trainer.criterion, 'set_train_tgt_dataset', None)):
        trainer.criterion.set_train_tgt_dataset(task.dataset(args.train_subset).tgt)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while (
        (lr >= args.min_lr or trainer.get_num_updates() <= getattr(args, 'warmup_updates', 0))
        and (
            epoch_itr.epoch < max_epoch or (
                epoch_itr.epoch == max_epoch
                and epoch_itr._next_epoch_itr is not None
            )
        )
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            # Without this, valid_losses is unbound on the first skipped epoch
            # (NameError below) or stale on later ones.
            valid_losses = [None]

        # only use first validation wer to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # Rebuild the train iterator for the next epoch; the dataset itself is
        # reloaded only when several training feature files are configured
        # (presumably sharded features).
        reload_dataset = len(args.train_feat_files) > 1
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
def train(args, trainer, task, epoch_itr):
    """Run a single training epoch.

    Logs running statistics after every successful step. Optionally
    validates and checkpoints every ``--save-interval-updates`` updates, and
    stops early once ``--max-update`` is reached. At the end it prints the
    end-of-epoch statistics and resets the per-epoch trainer meters.
    """
    n_epoch = epoch_itr.epoch
    # Gradient-accumulation factor: per-epoch schedule whose last entry repeats.
    update_freq = (args.update_freq[n_epoch - 1]
                   if n_epoch <= len(args.update_freq)
                   else args.update_freq[-1])

    batches = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=n_epoch >= args.curriculum,
    )
    grouped = iterators.GroupedIterator(batches, update_freq)
    # epoch_itr.epoch is read again here, after next_epoch_itr() has run.
    progress = progress_bar.build_progress_bar(
        args, grouped, epoch_itr.epoch, no_progress_bar='simple',
    )

    # Keys already reported by get_training_stats().
    already_logged = ('loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size')
    extra_meters = collections.defaultdict(AverageMeter)
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    for step, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        stats = get_training_stats(trainer)
        for key, value in log_output.items():
            if key in already_logged:
                continue
            meter = extra_meters[key]
            # Loss-like values are averaged weighted by sample size.
            if 'loss' in key:
                meter.update(value, log_output['sample_size'])
            else:
                meter.update(value)
            stats[key] = meter.avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # The first mini-batch is excluded from the words-per-second figure.
        if step == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        checkpoint_now = (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        )
        if checkpoint_now:
            losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, losses[0])

        if num_updates >= max_update:
            break

    # End-of-epoch summary.
    stats = get_training_stats(trainer)
    for name, meter in extra_meters.items():
        stats[name] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # Reset the per-epoch trainer meters; meters that do not exist are skipped.
    for meter_name in ('train_loss', 'train_nll_loss', 'wps', 'ups',
                       'wpb', 'bsz', 'gnorm', 'clip'):
        meter = trainer.get_meter(meter_name)
        if meter is not None:
            meter.reset()
def main(args, init_distributed=False):
    """Set up an adversarial training run and train until a stop criterion fires.

    Builds the task, model, criterion, adversarial criterion and adversary,
    and wraps them in an ``AdversarialTrainer``. It then restores the latest
    checkpoint and trains while the learning rate stays above
    ``args.min_lr`` and neither ``--max-epoch`` nor ``--max-update`` is
    reached. After training, it prints the contents of
    ``trainer.idx_to_dev_grad_dotprod``.

    Args:
        args: parsed command-line namespace. NOTE: mutated here
            (``distributed_rank``; ``maximize_best_checkpoint_metric`` when
            ``--eval-bleu`` is set).
        init_distributed: when true, initialise distributed training.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    # Seed both numpy and torch RNGs.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion, plus the adversarial criterion and the
    # adversary (the adversary is built from the model).
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    adv_criterion = task.build_adversarial_criterion(args)
    adv = task.build_adversary(args, model)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = AdversarialTrainer(args, task, model, criterion, adv_criterion, adv)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator. This project's load_checkpoint returns a
    # third value, which is passed through to train() unchanged.
    extra_state, epoch_itr, filtered_maxpos_indices = checkpoint_utils.load_checkpoint(
        args, trainer)

    # pretrain data actor (only for the 'lan' actor with step-level updates)
    if args.pretrain_data_actor and args.data_actor == 'lan' and args.data_actor_step_update:
        trainer.pretrain_data_actor()

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    # With --eval-bleu, build a generator from a copy of args with fixed
    # decoding settings (no sampling, beam 5, batch 32). It is handed to
    # train() and validate(). The best-checkpoint metric is then treated as
    # higher-is-better.
    if args.eval_bleu:
        gen_args = copy.deepcopy(args)
        gen_args.sample = False
        gen_args.beam = 5
        gen_args.batch_size = 32
        generator = task.build_generator(gen_args)
        args.maximize_best_checkpoint_metric = True
    else:
        generator = None
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch; train() may return a different (filtered)
        # epoch iterator, which replaces the current one.
        epoch_itr = train(args, trainer, task, epoch_itr, generator, filtered_maxpos_indices)
        #trainer.update_language_sampler(args)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets, generator)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch. get_train_iterator
            # returns a tuple here; only its first element is kept.
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)[0]
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
    # Dump trainer.idx_to_dev_grad_dotprod: for each key (sorted), print the
    # key on one line and its values space-separated on the next.
    for idx in sorted(trainer.idx_to_dev_grad_dotprod.keys()):
        print(idx)
        str_dotprod = [str(i) for i in trainer.idx_to_dev_grad_dotprod[idx]]
        print(" ".join(str_dotprod))
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Mid-epoch statistics are aggregated under the ``'train_inner'`` metrics
    key and logged every ``args.log_interval`` updates. End-of-epoch
    statistics come from the ``'train'`` key. When
    ``--save-interval-updates`` is positive, the loop also validates and
    saves a checkpoint every that many updates. The epoch stops early once
    ``--max-update`` is reached.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    # Update parameters every N batches (per-epoch schedule, last entry repeats)
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        # Only the master process writes tensorboard logs.
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    if hasattr(trainer.criterion, 'set_epoch'):
        trainer.criterion.set_epoch(epoch_itr.epoch)

    # Bind before the loop. Otherwise the end-of-epoch log below raises
    # NameError when the epoch yields no batches or every step is skipped.
    num_updates = trainer.get_num_updates()
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
def main_tpu(args):
    """Train on TPU cores via torch_xla.

    Builds the task/model/trainer on the XLA device, then alternates one
    training epoch with (optional) validation, an LR step and a checkpoint
    save until ``keep_training`` returns False. At the end it calls
    ``assert_on_losses`` with the last train/valid losses.

    NOTE(review): the nested helpers close over names assigned later in this
    function (``progress``, ``epoch``, ``epoch_itr``, ``task``,
    ``valid_subsets``, ``xla_device``). They must only be called after those
    names are bound.
    """

    def prepare_task(args, xla_device):
        """Build task, model, criterion and trainer on ``xla_device``, then load the latest checkpoint.

        Returns:
            ``(task, trainer, model, epoch_itr, lr, valid_subsets, device_str)``,
            where ``device_str`` is ``str(xla_device)``, suffixed with
            ``/<ordinal>`` when an XLA ordinal is available.
        """
        # Setup task, e.g., translation, language modeling, etc.
        task = tasks.setup_task(args)

        # Load valid dataset (we load training data below, based on the latest checkpoint)
        for valid_sub_split in args.valid_subset.split(','):
            task.load_dataset(valid_sub_split, combine=True, epoch=0)

        # Build models and criteria to print some metadata
        torch.manual_seed(args.seed)
        model, criterion = task.build_model(args), task.build_criterion(args)
        xm.master_print(model)
        xm.master_print('| model {}, criterion {}'.format(
            args.arch, criterion.__class__.__name__))
        xm.master_print('| num. model params: {} (num. trained: {})'.format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad)))
        model = model.to(xla_device)
        trainer = Trainer(args, task, model, criterion, xla_device=xla_device)
        lr = trainer.get_lr()

        # Load the latest checkpoint if one is available and restore the
        # corresponding train iterator
        # we overwrite distributed args here to shard data using torch_xla's
        # distributed training.
        # The rank/world size are set only around load_checkpoint and then
        # reset to a single-process view (rank 0, world size 1).
        trainer.args.distributed_rank = xm.get_ordinal()
        trainer.args.distributed_world_size = xm.xrt_world_size()
        extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
        trainer.args.distributed_rank = 0
        trainer.args.distributed_world_size = 1
        trainer.meters_to_device(xla_device)
        valid_subsets = args.valid_subset.split(',')
        ordinal = xm.get_ordinal(defval=-1)
        device_str = (
            str(xla_device) if ordinal < 0 else
            '{}/{}'.format(xla_device, ordinal)
        )
        return task, trainer, model, epoch_itr, lr, valid_subsets, device_str

    def train_loop_fn(device, trainer, loader, last_batch_index):
        """
        This is the main training loop. It trains for 1 epoch.

        The batch at ``last_batch_index`` is skipped (incomplete). Stats are
        printed via an XLA step closure every ``args.log_steps`` batches and
        on the second-to-last batch. Optional mid-epoch validation and
        checkpointing happen every ``args.save_interval_updates`` updates,
        and the loop stops once ``max_update`` is reached.
        """

        def print_training_update(trainer, progress, args, i):
            # Runs as an XLA step closure; logs training stats with a timestamp.
            stats = get_training_stats(trainer, args=args)
            stats['now'] = now()
            progress.log(stats, tag='train', step=trainer.get_num_updates())
            progress.print_mid_epoch(i+1, force=True)

        # NOTE(review): none of these three locals is read afterwards
        # (log_output is reassigned inside the loop).
        stats, log_output, skip_stat_keys = None, None, {'clip'}
        max_update = args.max_update or math.inf
        for i, samples in enumerate(loader, start=epoch_itr.iterations_in_epoch):
            if i == last_batch_index:
                # last batches are incomplete
                break
            log_output = trainer.train_step(samples)
            reset_perf_training_meters(trainer, i, ignore_index=10)
            if (not (i % args.log_steps)) or (i == last_batch_index-1):
                # `progress` is not a parameter: it resolves to the variable
                # assigned in main_tpu's epoch loop below.
                step_args = trainer, progress, args, i
                xm.add_step_closure(print_training_update, args=step_args)
            num_updates = trainer.get_num_updates()
            if (
                not args.disable_validation
                and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0
            ):
                vloss = validate_subset(
                    args, device, trainer, task, epoch_itr, valid_subsets[0]
                )
                # `epoch` also comes from main_tpu's epoch loop.
                checkpoint_utils.save_checkpoint(
                    args, trainer, epoch_itr, vloss.item(),
                    epoch=epoch, end_of_epoch=False,
                )
            if num_updates >= max_update:
                break

    def valid_loop_fn(
        args, device, trainer, progress, loader, last_batch_index
    ):
        """Run validation steps over ``loader`` and return the validation stats dict.

        Values in each step's log output other than the standard keys are
        averaged in extra meters and merged into the returned stats.
        """
        extra_meters = collections.defaultdict(lambda: AverageMeter())
        for i, sample in enumerate(loader):
            if i == last_batch_index:
                # last batches are of different size, will cause recompilations
                break
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)
        stats = get_valid_stats(trainer, args)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        return stats

    def validate_subset(args, device, trainer, task, epoch_itr, subset):
        """Validate on one subset and return ``stats['loss'].avg``.

        The value is used with ``.item()`` by callers, so it is presumably a
        tensor — TODO confirm.
        """
        xm.master_print('Validating the subset "{}", {}'.format(subset, now()))
        # Initialize data iterator
        # we're not sharding the validation set
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_workers=args.num_workers
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on {} \'{}\' subset'.format(device, subset),
            no_progress_bar='simple'
        )
        para_loader = pl.ParallelLoader(progress, [xla_device])
        reset_validation_loss_meters(trainer)
        stats = valid_loop_fn(
            args, device, trainer, progress,
            para_loader.per_device_loader(xla_device), len(progress) - 1
        )
        progress_bar.progress_bar_print(
            progress, stats, step=trainer.get_num_updates(), force=True,
            tag='validate-{}'.format(subset), flush_writer=True,
        )
        xm.master_print('Validated the subset "{}", {}'.format(subset, now()))
        return stats['loss'].avg

    def validate_subsets(args, device, trainer, task, epoch_itr, subsets):
        """Return a ``{subset: loss}`` dict by validating every subset in turn."""
        valid_losses = {
            subset: validate_subset(
                args, device, trainer, task, epoch_itr, subset
            )
            for subset in subsets
        }
        return valid_losses

    def keep_training(lr, epoch_itr, trainer):
        """Return True while lr > min_lr and the epoch/update budgets are not exhausted.

        NOTE(review): the ``lr`` argument is ignored; it is immediately
        replaced by ``trainer.get_lr()``.
        """
        # Train until the learning rate gets too small
        max_epoch = args.max_epoch or math.inf
        max_update = args.max_update or math.inf
        lr, n_updates = trainer.get_lr(), trainer.get_num_updates()
        return ((lr > args.min_lr) and (epoch_itr.epoch < max_epoch) and
            (n_updates < max_update))

    if xu.getenv_as('XLA_USE_BF16', bool, False):
        xm.master_print(
            'WARNING: bfloat16 is enabled. Note that fairseq meters such as '
            'loss will accumulate the numerator, and increment the denominator.'
            ' Due to lack of precision in higher numbers in bfloat16, these '
            'meters will report invalid values after a while.',
            fd=sys.stderr
        )

    xm.master_print('Args', fd=sys.stderr)
    for key, val in args.__dict__.items():
        xm.master_print('\t{} {}'.format(key, val), fd=sys.stderr)
    # `xla_device` is `torch.device` and `device` is `str`
    xla_device = xm.xla_device()
    task, trainer, model, epoch_itr, lr, valid_subsets, device = prepare_task(
        args, xla_device)

    train_meter = StopwatchMeter()
    train_meter.start()
    while keep_training(lr, epoch_itr, trainer):
        # TRAINING
        epoch = epoch_itr.epoch + 1
        xm.master_print('Epoch {} begin {}'.format(epoch, now()))
        progress = initialize_loader_for_epoch(
            args, epoch_itr, prefix='training on {}'.format(device),
        )
        skip_stat_keys = {'clip'}
        if args.suppress_loss_report:
            skip_stat_keys.update({'loss', 'nll_loss', 'gnorm'})
        progress.set_keys_to_skip_mid_epoch(skip_stat_keys)
        para_loader = pl.ParallelLoader(progress, [xla_device])
        train_loop_fn(
            device, trainer, para_loader.per_device_loader(xla_device),
            len(progress) - 1
        )
        training_stats = get_training_stats(trainer, args=args)
        tloss = training_stats['loss'].avg.item()
        progress_bar.progress_bar_print(
            progress, training_stats, tag='train', force=True,
            step=trainer.get_num_updates(), log_xla_metrics=True,
            flush_writer=True,
        )
        xm.master_print('Epoch {} end {}'.format(epoch_itr.epoch, now()))
        if args.metrics_debug:
            xm.master_print(met.metrics_report())
        reset_training_meters(trainer)

        # VALIDATION
        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate_subsets(
                args, device, trainer, task, epoch_itr, valid_subsets
            )

            # only use average first validation loss to update learning rate
            vloss = valid_losses[valid_subsets[0]].item()
            xm.master_print('old learning rate: {}'.format(lr))
            lr = trainer.lr_step(epoch_itr.epoch, vloss)
            xm.master_print('new learning rate: {}'.format(lr))
            if args.metrics_debug:
                xm.master_print(met.metrics_report())
        else:
            vloss = None

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(
                args, trainer, epoch_itr, vloss,
                epoch=epoch, end_of_epoch=True,
            )

    train_meter.stop()
    xm.master_print('| done training in {:.1f} seconds'.format(train_meter.sum))
    # NOTE(review): tloss/vloss are only assigned inside the loop. If
    # keep_training() is False on entry (e.g. resuming a finished run), this
    # line raises NameError.
    assert_on_losses(args, train_loss=tloss, valid_loss=vloss)
def train(args, trainer, task, epoch_itr, generator=None, filtered_maxpos_indices=None):
    """Train the model for one epoch, optionally driving a data actor.

    The epoch may be split into chunks (resets) of
    ``update_language_sampling * update_freq[0] + 1`` batches. This happens
    when language-sampling updates are enabled and neither DDS data
    selection (``select_by_dds_epoch``) nor step-level data-actor updates
    are active. The data iterator is re-created for each chunk with the
    matching ``offset``/``datasize``. Training stops (across all remaining
    chunks) once ``--max-update`` is reached.

    Args:
        generator: sequence generator forwarded to ``validate`` (may be None).
        filtered_maxpos_indices: forwarded to
            ``trainer.get_filtered_train_iterator`` on the DDS selection epoch.

    Returns:
        The epoch iterator used for training. On the DDS selection epoch it
        is the filtered replacement of the one passed in.
    """
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    extra_meters = collections.defaultdict(AverageMeter)
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    # data selection: reset epoch iter to filter out unselected data
    if epoch_itr.epoch == args.select_by_dds_epoch and args.select_by_dds_epoch > 0:
        epoch_itr, _ = trainer.get_filtered_train_iterator(
            epoch_itr.epoch, filtered_maxpos_indices=filtered_maxpos_indices)

    # Number of batches per chunk between language-sampler updates.
    chunk_size = args.update_language_sampling * args.update_freq[0] + 1
    if args.update_language_sampling > 0 and args.select_by_dds_epoch < 0 and (
            not args.data_actor_step_update):
        num_batches = len(epoch_itr.frozen_batches)
        datasize = chunk_size
        # ceil(num_batches / chunk_size) chunks
        num_reset = num_batches // chunk_size
        if num_reset * datasize < num_batches:
            num_reset += 1
    else:
        num_reset = 1
        datasize = -1

    reached_max_update = False
    for reset_idx in range(num_reset):
        print("resetting at step", reset_idx)
        # Initialize data iterator
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=(epoch_itr.epoch >= args.curriculum),
            offset=reset_idx * chunk_size,
            datasize=datasize,
        )
        itr = iterators.GroupedIterator(itr, update_freq)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch, no_progress_bar='simple',
        )

        for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
            # Decide whether the data actor is updated on this step.
            if args.extra_data_actor == 'ave_emb':
                update_actor = (i % args.extra_update_language_sampling == 0)
            elif args.data_actor_step_update:
                # (also covers data_actor == 'lan' with step updates)
                update_actor = (i % args.update_language_sampling == 0)
            else:
                update_actor = False
            # After the DDS selection epoch the actor is frozen.
            if epoch_itr.epoch > args.select_by_dds_epoch and args.select_by_dds_epoch > 0:
                update_actor = False
            log_output = trainer.train_step(samples, update_actor=update_actor)
            if log_output is None:
                continue

            # update sampling distribution
            if (args.update_language_sampling > 0
                    and i % args.update_language_sampling == 0
                    and args.data_actor != 'ave_emb'
                    and not args.data_actor_step_update):
                if args.data_actor_multilin:
                    trainer.update_language_sampler_multilin(args, epoch=epoch_itr.epoch)
                else:
                    trainer.update_language_sampler(args)

            # log mid-epoch stats
            stats = get_training_stats(trainer)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue  # these are already logged above
                if 'loss' in k or k == 'accuracy':
                    extra_meters[k].update(v, log_output['sample_size'])
                else:
                    extra_meters[k].update(v)
                stats[k] = extra_meters[k].avg
            progress.log(stats, tag='train', step=stats['num_updates'])

            # ignore the first mini-batch in words-per-second calculation
            if i == 0:
                trainer.get_meter('wps').reset()

            num_updates = trainer.get_num_updates()
            if (
                not args.disable_validation
                and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0
            ):
                valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets, generator)
                checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

            if num_updates >= max_update:
                reached_max_update = True
                break
        # Stop the remaining chunks too; otherwise each one would run one
        # extra train_step beyond --max-update.
        if reached_max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
    return epoch_itr
def main(args, init_distributed=False):
    """Train (or profile/validate) a SuperTransformer or a single SubTransformer.

    Depending on flags, this function does one of the following:

    * profiles FLOPs (``--profile-flops``) and exits;
    * measures latency (``--latcpu``/``--latgpu``) and exits;
    * otherwise optionally validates one SubTransformer, then trains. After
      each epoch it validates every representative config from
      ``utils.get_represent_configs``.

    Args:
        args: parsed command-line namespace.
        init_distributed: when true, initialise distributed training.

    Raises:
        NotImplementedError: if ``args.arch`` contains neither 'iwslt' nor 'wmt'.
        SystemExit: after FLOPs or latency profiling.
    """
    utils.import_user_module(args)
    utils.handle_save_path(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(f"| Configs: {args}")

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(f"| Model: {args.arch} \n| Criterion: {criterion.__class__.__name__}")

    # Log architecture
    if args.train_subtransformer:
        print(" \n\n\t\tWARNING!!! Training one single SubTransformer\n\n")
        print(f"| SubTransformer Arch: {utils.get_subtransformer_config(args)} \n")
    else:
        print(" \n\n\t\tWARNING!!! Training SuperTransformer\n\n")
        print(f"| SuperTransformer Arch: {model} \n")

    # Log model size
    if args.train_subtransformer:
        print(f"| SubTransformer size (without embedding weights): {model.get_sampled_params_numel(utils.get_subtransformer_config(args))}")
        embed_size = args.decoder_embed_dim_subtransformer * len(task.tgt_dict)
        print(f"| Embedding layer size: {embed_size} \n")
    else:
        # Count parameters whose name does not contain 'embed'. named_parameters()
        # is used rather than state_dict(), which would also include the
        # encoder.version and decoder.version buffers.
        model_s = sum(
            param.numel() for name, param in model.named_parameters()
            if 'embed' not in name
        )
        print(f"| SuperTransformer model size (without embedding weights): {model_s}")
        print(f"| Embedding layer size: {sum(p.numel() for p in model.parameters() if p.requires_grad) - model_s} \n")

    # Length of the dummy input used for profiling.
    # for iwslt, the average length is 23, for wmt, that is 30
    dummy_sentence_length_dict = {'iwslt': 23, 'wmt': 30}
    if 'iwslt' in args.arch:
        dummy_sentence_length = dummy_sentence_length_dict['iwslt']
    elif 'wmt' in args.arch:
        dummy_sentence_length = dummy_sentence_length_dict['wmt']
    else:
        raise NotImplementedError

    dummy_src_tokens = [2] + [7] * (dummy_sentence_length - 1)
    dummy_prev = [7] * (dummy_sentence_length - 1) + [2]

    # profile the overall FLOPs number
    if args.profile_flops:
        import torchprofile
        config_subtransformer = utils.get_subtransformer_config(args)
        model.set_sample_config(config_subtransformer)
        model.profile(mode=True)
        # src_lengths must match the dummy source length (was hard-coded to 30,
        # which is wrong for the 23-token iwslt dummy input).
        macs = torchprofile.profile_macs(model, args=(
            torch.tensor([dummy_src_tokens], dtype=torch.long),
            torch.tensor([dummy_sentence_length]),
            torch.tensor([dummy_prev], dtype=torch.long),
        ))
        model.profile(mode=False)

        last_layer_macs = config_subtransformer['decoder']['decoder_embed_dim'] * dummy_sentence_length * len(task.tgt_dict)

        # FLOPs are reported as 2 * MACs.
        print(f"| Total FLOPs: {macs * 2}")
        print(f"| Last layer FLOPs: {last_layer_macs * 2}")
        print(f"| Total FLOPs without last layer: {(macs - last_layer_macs) * 2} \n")
        raise SystemExit(0)

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print(f"| Training on {args.distributed_world_size} GPUs")
    print(f"| Max tokens per GPU = {args.max_tokens} and max sentences per GPU = {args.max_sentences} \n")

    # Measure model latency, the program will exit after profiling latency
    if args.latcpu or args.latgpu:
        utils.measure_latency(args, model, dummy_src_tokens, dummy_prev)
        raise SystemExit(0)

    # Load the latest checkpoint if one is available and restore the corresponding train iterator
    _, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Evaluate the SubTransformer
    if args.validate_subtransformer:
        config = utils.get_subtransformer_config(args)
        trainer.set_sample_config(config)
        valid_loss = validate(args, trainer, task, epoch_itr, ['valid'], 'SubTransformer')
        print(f"| SubTransformer validation loss:{valid_loss}")

    # Loop boundaries
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()

    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    represent_configs = utils.get_represent_configs(args)

    # Main training loop
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        # Default so valid_losses is always bound, even when there are no
        # representative configs to validate.
        valid_losses = [None]
        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            # Validate every representative config; the losses of the last
            # one are kept for the LR step and the checkpoint.
            for arch_name, arch_config in represent_configs.items():
                trainer.set_sample_config(config=arch_config)
                valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets, sampled_arch_name=arch_name)

        # update the best loss and get current lr; the real lr scheduling is done in trainer.train_step()
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint epoch level
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print('| Done training in {:.1f} seconds'.format(train_meter.sum))
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    For every grouped mini-batch either a single fixed SubTransformer is
    trained (``args.train_subtransformer``) or a SubTransformer configuration
    is sampled from ``utils.get_all_choices(args)`` with the current update
    count as random seed. Every ``args.save_interval_updates`` updates each
    representative configuration from ``utils.get_represent_configs(args)`` is
    validated and a checkpoint is written.

    Args:
        args: parsed command-line namespace.
        trainer: the Trainer that runs the optimisation steps.
        task: task object passed through to ``validate``.
        epoch_itr: epoch batch iterator; the next epoch is drawn from it.
    """
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch,
    )

    # AverageMeter is already a zero-argument factory; no lambda wrapper needed.
    extra_meters = collections.defaultdict(AverageMeter)
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    represent_configs = utils.get_represent_configs(args)

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if args.train_subtransformer:
            # training one SubTransformer only
            configs = [utils.get_subtransformer_config(args)]
        else:
            # training SuperTransformer by randomly sampling SubTransformers;
            # seeding with the update count makes the sample reproducible.
            configs = [utils.sample_configs(
                utils.get_all_choices(args),
                reset_rand_seed=True,
                rand_seed=trainer.get_num_updates(),
                super_decoder_num_layer=args.decoder_layers,
            )]

        log_output = trainer.train_step(samples, configs=configs)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = utils.get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        utils.log_arch_info(stats, configs[0])
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        # (i counts from iterations_in_epoch, so this only fires when the
        # epoch starts from its first batch)
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            # Validate every representative SubTransformer; the checkpoint
            # records the loss of the last one validated. Starting from
            # [None] keeps an empty represent_configs from raising NameError.
            valid_losses = [None]
            for k, v in represent_configs.items():
                trainer.set_sample_config(config=v)
                valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets, sampled_arch_name=k)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = utils.get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def validate_and_save(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    valid_subsets: List[str],
    end_of_epoch: bool,
) -> Tuple[List[Optional[float]], bool]:
    """Check the stopping criteria, then validate and/or checkpoint as configured.

    Args:
        cfg: full training config; its ``optimization``, ``checkpoint`` and
            ``dataset`` sections are consulted.
        trainer: the active Trainer.
        task: task handed through to ``validate``.
        epoch_itr: current epoch iterator; its ``epoch`` drives the
            per-epoch save/validate intervals.
        valid_subsets: names of the validation splits.
        end_of_epoch: True when called after the last batch of an epoch.

    Returns:
        ``(valid_losses, should_stop)``. ``valid_losses`` is ``[None]`` when
        validation was skipped.
    """
    opt_cfg = cfg.optimization
    ckpt_cfg = cfg.checkpoint
    data_cfg = cfg.dataset

    num_updates = trainer.get_num_updates()
    max_update = opt_cfg.max_update or math.inf

    # Hard stopping conditions (update budget, wall-clock budget). A further,
    # loss-based early-stopping check is applied after validation below.
    should_stop = False
    if num_updates >= max_update:
        should_stop = True
        logger.info(
            f"Stopping training due to "
            f"num_updates: {num_updates} >= max_update: {max_update}"
        )

    training_time_hours = trainer.cumulative_training_time() / (60 * 60)
    if opt_cfg.stop_time_hours > 0 and training_time_hours > opt_cfg.stop_time_hours:
        should_stop = True
        logger.info(
            f"Stopping training due to "
            f"cumulative_training_time: {training_time_hours} > "
            f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
        )

    # Checkpoint at matching epoch boundaries, whenever we are stopping, or
    # every N updates once past validate_after_updates.
    if end_of_epoch and epoch_itr.epoch % ckpt_cfg.save_interval == 0:
        save_now = True
    elif should_stop:
        save_now = True
    else:
        save_now = (
            ckpt_cfg.save_interval_updates > 0
            and num_updates > 0
            and num_updates % ckpt_cfg.save_interval_updates == 0
            and num_updates >= data_cfg.validate_after_updates
        )

    mid_epoch_save = not end_of_epoch and save_now  # validate during mid-epoch saves
    wants_validation = (
        mid_epoch_save
        or (end_of_epoch and epoch_itr.epoch % data_cfg.validate_interval == 0)
        or should_stop
        or (
            data_cfg.validate_interval_updates > 0
            and num_updates > 0
            and num_updates % data_cfg.validate_interval_updates == 0
        )
    )
    validate_now = wants_validation and not data_cfg.disable_validation

    valid_losses = (
        validate(cfg, trainer, task, epoch_itr, valid_subsets)
        if validate_now
        else [None]
    )
    # The early-stopping rule is always consulted, even when validation was
    # skipped (it then receives None).
    should_stop |= should_stop_early(cfg, valid_losses[0])

    if save_now or should_stop:
        checkpoint_utils.save_checkpoint(
            ckpt_cfg, trainer, epoch_itr, valid_losses[0]
        )
    return valid_losses, should_stop
def main(args, init_distributed=False):
    """Training entry point.

    Seeds the RNGs, optionally initialises distributed training, builds the
    task, model, criterion and Trainer, restores the latest checkpoint, and
    then trains epoch by epoch. The loop stops when any of these holds:
    - the learning rate drops to ``args.min_lr``;
    - the epoch limit is reached and no pending iterator is left to resume;
    - ``args.max_update`` updates have been made;
    - ``should_stop_early`` fires.

    Args:
        args: parsed command-line namespace; mutated in place
            (``args.distributed_rank``) when ``init_distributed`` is True.
        init_distributed: when True, call ``distributed_utils.distributed_init``
            and store the returned rank on ``args``.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    # Seed both the NumPy and the torch RNGs with the same value.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    # Only the master process checks the checkpoint directory.
    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    # NOTE(review): args.max_tokens is reported as "input frames" here,
    # presumably because batches are made of feature frames (speech-style
    # input) -- confirm.
    logger.info(
        'max input frames per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens,
            args.max_sentences,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator (extra_state is not used below)
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    # NOTE: the condition reads the iterator's private _next_epoch_itr
    # attribute so training can continue from a checkpoint saved mid-way
    # through the final epoch.
    while (lr > args.min_lr
           and (epoch_itr.epoch < max_epoch
                # allow resuming training from the final checkpoint
                or epoch_itr._next_epoch_itr is not None)
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # early stop
        if should_stop_early(args, valid_losses[0]):
            logger.info(
                'early stop since valid performance hasn\'t improved for last {} runs'
                .format(args.patience))
            break

        # Reload training data for the next epoch only when more than one
        # training feature file is configured (presumably one shard per
        # epoch -- verify against the task's load_dataset).
        reload_dataset = len(args.train_feat_files) > 1
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Runs ``trainer.train_step`` on every grouped mini-batch of the next epoch.
    Smoothed ``train_inner`` stats are logged every ``args.log_interval``
    updates. When ``args.save_interval_updates`` is set, the function
    validates and checkpoints every that many updates. The epoch ends early
    once ``args.max_update`` is reached.

    Args:
        args: parsed command-line namespace.
        trainer: the Trainer that runs the optimisation steps.
        task: task object; ``begin_epoch`` is called on it and it is passed
            to ``validate``.
        epoch_itr: epoch batch iterator; the next epoch is drawn from it.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        # NOTE(original author): shuffling leads to errors for multitask
        # learning with cls_indices!
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    # Update parameters every N batches
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    # Bind before the loop. If the epoch yields no batches, or every
    # train_step returns None, the end-of-epoch log still needs a step value
    # (previously this raised UnboundLocalError).
    num_updates = trainer.get_num_updates()
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')