def validate(args, trainer, task, epoch_itr, subsets):
    """Run evaluation on every validation subset and return the tracked losses.

    Args:
        args: nested configuration dict; the ``dataset``, ``common``,
            ``distributed_training`` and ``checkpoint`` sections are read here.
        trainer: trainer object exposing the model, ``valid_step()`` and
            ``get_num_updates()``.
        task: task object providing datasets and batch iterators.
        epoch_itr: training epoch iterator (only its epoch number is used).
        subsets: iterable of validation split names.

    Returns:
        A list with one entry per subset: the smoothed value of the
        configured ``best_checkpoint_metric`` for that subset.
    """
    fixed_seed = args['dataset']['fixed_validation_seed']
    if fixed_seed is not None:
        # Re-seed before every validation pass so runs are comparable.
        set_seed.set_torch_seed(fixed_seed)

    valid_losses = []
    for subset in subsets:
        # Build a fresh, unshuffled iterator over this validation split.
        batch_itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)

        is_master = distributed_utils.is_master(args)
        progress = progress_bar.progress_bar(
            batch_itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args['common']['tensorboard_logdir'] if is_master else None),
            default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
        )

        # Aggregate under a brand-new root so validation metrics cannot
        # leak into other aggregators (e.g. the training meters).
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # Log the smoothed validation stats and remember the checkpoint metric.
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args['checkpoint']['best_checkpoint_metric']])
    return valid_losses
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate retrieval quality (ACC / MRR / MAP / NDCG) on the validation set(s).

    NOTE(review): this redefines ``validate`` from earlier in the file, so
    only this retrieval-metric version is effective at import time — confirm
    the duplicate is intentional.

    Args:
        args: nested configuration dict; the ``dataset``, ``common``,
            ``distributed_training`` and ``checkpoint`` sections are read here.
        trainer: trainer object exposing the model, ``_prepare_sample()`` and
            ``get_num_updates()``.
        task: task object providing datasets and batch iterators.
        epoch_itr: training epoch iterator (only its epoch number is used).
        subsets: iterable of validation split names.

    Returns:
        A list with one entry per subset: the value of the configured
        ``best_checkpoint_metric`` taken from the computed stats dict
        (keys: 'acc', 'mrr', 'map', 'ndcg').
    """
    if args['dataset']['fixed_validation_seed'] is not None:
        # Re-seed before every validation pass so runs are comparable.
        set_seed.set_torch_seed(args['dataset']['fixed_validation_seed'])

    valid_losses = []
    for subset in subsets:
        # Build a fresh, unshuffled iterator over this validation split.
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)

        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args['common']['tensorboard_logdir']
                                if distributed_utils.is_master(args) else None),
            default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
        )

        # Per-batch score accumulators (fix: do not shadow the builtin `map`,
        # and keep list/scalar values under distinct names).
        acc_scores, mrr_scores, map_scores, ndcg_scores = [], [], [], []

        # Disable dropout etc. for deterministic forward passes.
        # NOTE(review): train mode is not restored afterwards — presumably the
        # training loop re-enables it before the next update; confirm.
        trainer.model.eval()
        trainer.criterion.eval()
        with torch.no_grad():
            for sample in progress:
                sample = trainer._prepare_sample(sample)
                # net_input layout: the first 6 tensors feed the code encoder,
                # the next 2 feed the description encoder — TODO confirm
                # against the dataset's collater.
                inputs = list(sample['net_input'].values())
                code_repr = trainer.model.code_forward(*inputs[:6])
                desc_repr = trainer.model.desc_forward(*inputs[6:8])
                # L2-normalize so the dot product below is cosine similarity.
                code_repr = code_repr / code_repr.norm(dim=-1, keepdim=True)
                desc_repr = desc_repr / desc_repr.norm(dim=-1, keepdim=True)
                similarity = code_repr @ desc_repr.t()
                acc, mrr, map_score, ndcg = inference(similarity)
                acc_scores.append(acc.mean().item())
                mrr_scores.append(mrr.mean().item())
                map_scores.append(map_score.mean().item())
                ndcg_scores.append(ndcg.mean().item())

        # NOTE(review): this averages per-batch means, so unequal batch sizes
        # are weighted equally rather than per-example.
        stats = {
            'acc': round(float(np.mean(acc_scores)), 6),
            'mrr': round(float(np.mean(mrr_scores)), 6),
            'map': round(float(np.mean(map_scores)), 6),
            'ndcg': round(float(np.mean(ndcg_scores)), 6),
        }
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args['checkpoint']['best_checkpoint_metric']])
    return valid_losses