def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    # Task-specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    if hasattr(trainer.criterion, 'set_epoch'):
        trainer.criterion.set_epoch(epoch_itr.epoch)

    for samples in progress:
        if hasattr(trainer.criterion, 'set_num_updates'):
            trainer.criterion.set_num_updates(trainer.get_num_updates())

        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        if log_output is None:
            continue

        # Log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # Log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # Reset epoch-level meters
    metrics.reset_meters('train')
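# The `update_freq` logic above implements gradient accumulation:
# `GroupedIterator(itr, update_freq)` yields lists of `update_freq` batches,
# and `trainer.train_step` accumulates gradients over each list before taking
# a single optimizer step. A minimal sketch of the chunking behavior in plain
# Python (illustration only, not fairseq's implementation):
def _chunk(iterable, size):
    """Yield consecutive lists of up to `size` items from `iterable`."""
    buf = []
    for item in iterable:
        buf.append(item)
        if len(buf) == size:
            yield buf
            buf = []
    if buf:
        # Trailing partial chunk, analogous to the last, smaller batch group.
        yield buf

assert list(_chunk(range(5), 2)) == [[0, 1], [2, 3], [4]]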
def downstream_train_pytorch(args, trainer, task, epoch_itr, train_prefix):
    """Fine-tune a PyTorch classifier on the downstream training set for one epoch."""
    task.split = 'train'
    num_updates = trainer.get_num_updates()

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )
    progress = maybe_wrap_neptune_logging(progress, args)

    # Task-specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    max_update = args.max_update or math.inf
    with metrics.aggregate() as agg:
        for samples in progress:
            # Train for one step
            log_output = trainer.train_step(samples)
            num_updates = trainer.get_num_updates()
            if log_output is None:
                continue

            # Log mid-epoch stats
            stats = get_ft_train_stats(agg.get_smoothed_values())
            progress.log(stats, tag=train_prefix, step=num_updates)

            if num_updates >= max_update:
                break

    # Log end-of-epoch stats; some progress-bar wrappers do not accept the
    # `log` keyword, which raises a TypeError we fall back from.
    stats = get_ft_train_stats(agg.get_smoothed_values())
    try:
        progress.print(stats, tag=train_prefix, step=num_updates, log=False)
    except TypeError:
        progress.print(stats, tag=train_prefix, step=num_updates)

    # Reset epoch-level meters
    metrics.reset_meters(train_prefix)
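# `maybe_wrap_neptune_logging` is used throughout this file but not defined in
# this section. A minimal sketch, assuming a `neptune_logging` flag on `args`
# and a `NeptuneProgressBarWrapper` class; both names are assumptions, not
# part of the original code.
def maybe_wrap_neptune_logging(progress_bar, args):
    """Wrap the progress bar with a Neptune logger when enabled."""
    if getattr(args, 'neptune_logging', False):
        # Hypothetical wrapper that forwards log()/print() calls to Neptune.
        return NeptuneProgressBarWrapper(progress_bar)
    return progress_bar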
def test_nested_duplicate_names(self):
    name = str(uuid.uuid4())
    metrics.reset_meters(name)

    with metrics.aggregate(name):
        metrics.log_scalar('loss', 1)
        with metrics.aggregate() as other:
            with metrics.aggregate(name):
                metrics.log_scalar('loss', 2)
        metrics.log_scalar('loss', 6)

    # The named aggregator saw 1, 2 and 6 (mean 3); the anonymous aggregator
    # only saw the value logged inside its own context.
    self.assertEqual(metrics.get_smoothed_values(name)['loss'], 3)
    self.assertEqual(other.get_smoothed_values()['loss'], 2)
def test_named(self):
    name = str(uuid.uuid4())
    metrics.reset_meters(name)

    with metrics.aggregate(name):
        metrics.log_scalar('loss', 1)

    # Logged outside the named context, so not attributed to `name`.
    metrics.log_scalar('loss', 3)

    with metrics.aggregate(name):
        metrics.log_scalar('loss', 2)

    self.assertEqual(metrics.get_smoothed_values(name)['loss'], 1.5)
def downstream_validate_sklearn(args, task, model, epoch_for_logging, task_name,
                                num_updates, classifier, scaler):
    """Evaluate the sklearn classifier on the downstream validation set."""
    if args.fixed_validation_seed is not None:
        # Set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    # Initialize data iterator
    itr = task.get_batch_iterator(
        dataset=task.dataset('valid'),
        max_tokens=args.max_tokens_sklearn,
        max_sentences=args.max_sentences_sklearn,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            model.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=1,
        shard_id=0,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_for_logging,
        prefix='sklearn valid on \'{}\''.format(task_name),
        no_progress_bar='simple')
    progress = maybe_wrap_neptune_logging(progress, args)

    # Reset validation meters
    metrics.reset_meters(task_name)

    # Load downstream validation data
    with torch.no_grad():
        model.eval()
        features, targets = load_downstream_data(args, progress, model, 'valid',
                                                 scaler, None)

    # Compute class predictions and probabilities
    class_predictions = classifier.predict(features)
    class_probabilities = classifier.predict_proba(features)

    # Compute and log downstream validation stats
    stats = compute_sklearn_stats(targets, class_predictions, class_probabilities,
                                  args.num_classes, args.eval_metric)
    stats = get_sklearn_stats(stats, num_updates)
    progress.print(stats, tag=task_name + '_sk_valid', step=num_updates)
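# `load_downstream_data` is called above and in `downstream_train_sklearn` but
# is not defined in this section. A minimal sketch, assuming the model's
# forward pass returns encoder output that can be mean-pooled into fixed-size
# features and that `scaler_type == 'standard'` maps to sklearn's
# StandardScaler; the feature-extraction call itself is an assumption.
import numpy as np
from sklearn.preprocessing import StandardScaler

def load_downstream_data(args, progress, model, split, scaler, scaler_type):
    """Extract pooled features and targets for the sklearn classifier."""
    feature_list, target_list = [], []
    for sample in progress:
        if torch.cuda.is_available():
            sample = utils.move_to_cuda(sample)
        # Hypothetical pooling: average the encoder output over time.
        net_output = model(**sample['net_input'])
        features = net_output[0].mean(dim=1)
        feature_list.append(features.cpu().numpy())
        target_list.append(sample['target'].cpu().numpy())

    features = np.concatenate(feature_list)
    targets = np.concatenate(target_list).ravel()

    if split == 'train':
        # Fit a new scaler on the training features when requested.
        if scaler_type == 'standard':
            scaler = StandardScaler().fit(features)
        if scaler is not None:
            features = scaler.transform(features)
        return features, targets, scaler

    # Validation reuses the scaler fit on the training set.
    if scaler is not None:
        features = scaler.transform(features)
    return features, targets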
def validate(args, trainer, task, epoch_for_logging, valid_name, ckpt_idx):
    """Evaluate the model on the validation set(s) and return the losses."""
    task.split = 'valid'

    if args.fixed_validation_seed is not None:
        # Set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    # Initialize data iterator
    itr = task.get_batch_iterator(
        dataset=task.dataset('valid'),
        max_tokens=args.max_tokens_valid,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
        epoch=epoch_for_logging,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_for_logging,
        prefix='valid on \'{}\' subset'.format(valid_name),
        no_progress_bar='simple')
    progress = maybe_wrap_neptune_logging(progress, args)

    # Reset validation meters
    metrics.reset_meters(valid_name)

    with metrics.aggregate(valid_name) as agg:
        for sample in progress:
            trainer.valid_step(sample)

    # Log validation stats
    stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
    if args.log_valid_progress:
        valid_progress_prefix = '{}_ckpt{}'.format(valid_name, ckpt_idx)
        progress.print({args.eval_metric: stats[args.eval_metric]},
                       tag=valid_progress_prefix, step=epoch_for_logging)

    # Return validation scores
    return stats[args.best_checkpoint_metric], stats[args.eval_metric], progress
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    task.split = 'train'

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = maybe_wrap_neptune_logging(
        progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch, no_progress_bar='simple',
        ),
        args=args,
    )

    # Task-specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    with metrics.aggregate() as agg:
        for samples in progress:
            log_output = trainer.train_step(samples)
            num_updates = trainer.get_num_updates()
            if log_output is None:
                continue

            # Log mid-epoch stats
            stats = get_training_stats(agg.get_smoothed_values())
            progress.log(stats, tag='train', step=num_updates)

            if num_updates >= max_update:
                break

    # Log end-of-epoch stats; fall back when the progress-bar wrapper does not
    # accept the `log` keyword.
    stats = get_training_stats(agg.get_smoothed_values())
    try:
        progress.print(stats, tag='train', step=num_updates, log=False)
    except TypeError:
        progress.print(stats, tag='train', step=num_updates)

    # Reset epoch-level meters
    metrics.reset_meters('train')
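# `get_training_stats` is referenced by the training loops above but not
# defined in this section. A minimal sketch modeled on upstream fairseq's
# train.py; treat the perplexity and wall-clock fields as assumptions about
# this fork.
def get_training_stats(stats):
    """Post-process smoothed meter values for progress-bar logging."""
    if 'nll_loss' in stats and 'ppl' not in stats:
        stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
    wall = metrics.get_meter('default', 'wall')
    if wall is not None:
        stats['wall'] = round(wall.elapsed_time, 0)
    return stats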
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args.fixed_validation_seed is not None:
        # Set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        # Reset validation meters
        metrics.reset_meters('valid')

        with metrics.aggregate() as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # Log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])

    return valid_losses
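# `get_valid_stats` is likewise undefined in this section. A sketch following
# upstream fairseq's convention of tracking the best value of the checkpoint
# metric on an attribute of `save_checkpoint`; the exact fields this fork adds
# are an assumption.
def get_valid_stats(args, trainer, stats):
    """Attach update count and best-metric bookkeeping to validation stats."""
    if 'nll_loss' in stats and 'ppl' not in stats:
        stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
    stats['num_updates'] = trainer.get_num_updates()
    if hasattr(checkpoint_utils.save_checkpoint, 'best'):
        key = 'best_{}'.format(args.best_checkpoint_metric)
        best_function = max if args.maximize_best_checkpoint_metric else min
        stats[key] = best_function(
            checkpoint_utils.save_checkpoint.best,
            stats[args.best_checkpoint_metric],
        )
    return stats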
def downstream_validate_pytorch(args, task, model, criterion, epoch_for_logging,
                                subsets, valid_name, num_updates,
                                global_epoch=None):
    """Evaluate the model on the validation set(s) and return the losses."""
    task.split = 'valid'
    valid_name_ = valid_name if valid_name is not None else 'valid'

    if args.fixed_validation_seed is not None:
        # Set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                model.max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
            epoch=1,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_for_logging,
            prefix='valid on \'{}\' subset'.format(valid_name_),
            no_progress_bar='simple')

        # Prepend the global epoch to the progress-bar description; the tqdm
        # handle lives in a different attribute depending on whether the bar
        # is wrapped, so fall back on AttributeError.
        if global_epoch is not None:
            try:
                progress.wrapped_bar.tqdm.set_description(
                    desc='epoch {:03d} | \'{}\' {}'.format(
                        global_epoch, valid_name_, progress.wrapped_bar.prefix),
                    refresh=True)
            except AttributeError:
                progress.tqdm.set_description(
                    desc='epoch {:03d} | \'{}\' {}'.format(
                        global_epoch, valid_name_, progress.tqdm.desc),
                    refresh=True)
        progress = maybe_wrap_neptune_logging(progress, args)

        # Reset validation meters
        metrics.reset_meters(valid_name_)

        with metrics.aggregate(valid_name_) as agg:
            # Keep the first batch around as a dummy fallback for valid_step.
            dummy_batch = "DUMMY"
            for sample in progress:
                dummy_batch = sample if dummy_batch == "DUMMY" else dummy_batch
                valid_step(args, sample, task, model, criterion, dummy_batch,
                           logger)

        # Log validation stats
        stats = get_ft_valid_stats(args, agg.get_smoothed_values(), num_updates)
        progress.print(stats, tag=valid_name_, step=num_updates)
        valid_losses.append(stats[args.best_checkpoint_metric])

    return valid_losses
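# `get_ft_valid_stats` is not shown in this section. A minimal sketch that
# only tags the smoothed values with the current update count; any further
# fields the real helper computes are unknown here and omitted.
def get_ft_valid_stats(args, stats, num_updates):
    """Attach the update count to fine-tuning validation stats."""
    stats['num_updates'] = num_updates
    return stats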
def downstream_train_sklearn(args, task, model, epoch_for_logging, task_name,
                             num_updates):
    """Fine-tune an sklearn classifier on the downstream training set."""
    if args.fixed_validation_seed is not None:
        # Set fixed seed (the fixed validation seed) for reproducible features
        utils.set_torch_seed(args.fixed_validation_seed)

    # Initialize data iterator
    itr = task.get_batch_iterator(
        dataset=task.dataset('train'),
        max_tokens=args.max_tokens_sklearn,
        max_sentences=args.max_sentences_sklearn,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            model.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=1,
        shard_id=0,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_for_logging,
        prefix='sklearn fine-tune on \'{}\''.format(task_name),
        no_progress_bar='simple')
    progress = maybe_wrap_neptune_logging(progress, args)

    # Reset meters
    metrics.reset_meters(task_name)

    # Load downstream train data
    with torch.no_grad():
        model.eval()
        features, targets, scaler = load_downstream_data(
            args, progress, model, 'train', None, args.scaler_type)

    # Train classifier
    logger.info('fine-tuning LogisticRegression classifier on \'{}\''.format(
        task_name))
    timer_start = timer()
    best_C = LogRegCV(args, features, targets)
    classifier = LogisticRegression(
        multi_class=args.multi_class,
        solver=args.solver,
        C=best_C,
        # liblinear is single-threaded; n_jobs only applies to other solvers
        n_jobs=(min(os.cpu_count(), args.num_classes, args.num_workers_sklearn)
                if args.solver != 'liblinear' else None),
        tol=args.tol,
        random_state=args.seed,
        max_iter=args.max_iter,
        verbose=args.verbose,
    ).fit(features, targets)
    timer_end = timer()
    logger.info('finished sklearn fine-tuning in {:.2f} seconds'.format(
        timer_end - timer_start))

    # Compute class predictions and probabilities
    class_predictions = classifier.predict(features)
    class_probabilities = classifier.predict_proba(features)

    # Compute and log downstream training stats
    stats = compute_sklearn_stats(targets, class_predictions, class_probabilities,
                                  args.num_classes, args.eval_metric)
    stats = get_sklearn_stats(stats, num_updates)
    progress.print(stats, tag=task_name + '_sk_train', step=num_updates)

    return classifier, scaler
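# `compute_sklearn_stats` is referenced above but undefined in this section.
# A sketch using standard sklearn metrics; which statistics the original
# computes beyond `args.eval_metric` is an assumption.
def compute_sklearn_stats(targets, predictions, probabilities, num_classes,
                          eval_metric):
    """Compute accuracy and AUC stats for the downstream classifier."""
    from sklearn.metrics import accuracy_score, roc_auc_score

    stats = {'accuracy': accuracy_score(targets, predictions)}
    if num_classes == 2:
        # Binary AUC uses the positive-class probability column.
        stats['auc'] = roc_auc_score(targets, probabilities[:, 1])
    else:
        # One-vs-rest AUC over the full probability matrix.
        stats['auc'] = roc_auc_score(targets, probabilities, multi_class='ovr')
    if eval_metric not in stats:
        stats[eval_metric] = stats['accuracy']  # fallback assumption
    return stats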
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    # Task-specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    start_time = time.time()
    step = 0
    for samples in progress:
        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        step += 1
        if log_output is None:
            continue

        # Log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            print("validate and save_checkpoint")
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    train_epoch_cost = time.time() - start_time

    # Log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)
    print("epoch_cost: %.5f s, avg_speed: %.5f steps/s"
          % (train_epoch_cost, float(step) / train_epoch_cost))

    # Reset epoch-level meters
    metrics.reset_meters('train')