Example #1
def main():
    """Main program."""
    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    if args.task == 'LAMBADA':
        eval_metric = 'accuracy'
    elif args.task == 'WIKITEXT103':
        eval_metric = 'loss'
    else:
        raise NotImplementedError('{} task is not implemented.'.format(
            args.task))

    # Set up model and load checkpoint.
    model = get_model(get_model_provider(eval_metric), wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

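    # get_model returns a list of model chunks (one per virtual pipeline
    # stage); interleaving was ruled out above, so there is exactly one.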
    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # Data stuff.
    dataset = build_dataset(args.task)
    dataloader = build_data_loader(dataset, args.micro_batch_size,
                                   args.num_workers, drop_last=False)

    # Run evaluation.
    evaluate_and_print_results(args.task, dataloader, model, eval_metric)

    print_rank_0('done :-)')
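
Example #1 delegates batching to build_data_loader, which is defined elsewhere in the module and not shown on this page. The sketch below is a minimal, assumed implementation consistent with the call sites here (dataset, batch size, worker count, drop_last, and an optional task_collate_fn); the megatron import path and the exact DataLoader options are assumptions, not the repository's code:

import torch
from megatron import mpu  # assumed import path for the model-parallel utilities

def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
                      task_collate_fn=None):
    """Sketch: shard the dataset across data-parallel ranks and batch it."""
    # Each data-parallel rank evaluates a disjoint slice of the dataset.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False)
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=micro_batch_size,
                                       sampler=sampler,
                                       num_workers=num_workers,
                                       drop_last=drop_last,
                                       pin_memory=True,
                                       collate_fn=task_collate_fn)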
Example #2
def main():
    """Main program."""
    args = get_args()

    if args.task == 'LAMBADA':
        eval_metric = 'accuracy'
    elif args.task == 'WIKITEXT103':
        eval_metric = 'loss'
    else:
        raise NotImplementedError('{} task is not implemented.'.format(
            args.task))

    # Set up model and load checkpoint.
    model = get_model(get_model_provider(eval_metric))
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    # Data stuff.
    dataset = build_dataset(args.task)
    dataloader = build_data_loader(dataset,
                                   args.batch_size,
                                   args.num_workers,
                                   drop_last=False)

    # Run evaluation.
    evaluate_and_print_results(args.task, dataloader, model, eval_metric)

    print_rank_0('done :-)')
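
get_model_provider(eval_metric), called in both examples, returns the model-construction callback that get_model expects. A plausible sketch, assuming Megatron's GPTModel class and the convention that 'accuracy' (an argmax over the full vocabulary) needs gathered rather than vocab-parallel logits:

from megatron.model import GPTModel  # assumed import path

def get_model_provider(eval_metric):
    """Sketch: build the callback that constructs the evaluation model."""
    def model_provider(pre_process=True, post_process=True):
        # 'loss' can be computed on vocab-parallel logits; 'accuracy'
        # needs logits gathered across tensor-parallel ranks.
        parallel_output = (eval_metric == 'loss')
        return GPTModel(num_tokentypes=0, parallel_output=parallel_output,
                        pre_process=pre_process, post_process=post_process)
    return model_provider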
Example #3
import os

import torch

def accuracy_func_provider(single_dataset_provider):
    """Provide function that calculates accuracies."""
    args = get_args()

    # Build dataloaders.
    datapaths = args.valid_data
    dataloaders = []
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath)
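        # With more than one data-parallel rank, drop the incomplete last
        # batch so every rank runs the same number of evaluation steps.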
        dataloader = build_data_loader(
            dataset,
            args.orig_micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_last('calculating metrics ...')
        correct = 0
        total = 0
        if output_predictions:
            assert mpu.get_data_parallel_world_size() == 1
            named_predictions = []
            names = 'predictions'
        for name, dataloader in dataloaders:
            output = calculate_correct_answers(name, model, dataloader, epoch,
                                               output_predictions)
            if not output_predictions:
                correct_ans, total_count = output
            else:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            correct += correct_ans
            total += total_count
        if is_last_rank():
            percent = float(correct) * 100.0 / float(total)
            print(' >> |epoch: {}| overall: correct / total = {} / {} = '
                  '{:.4f} %'.format(epoch, correct, total, percent))

        if output_predictions and is_last_rank():
            assert args.load is not None
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)

    return metrics_func
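
For context, this provider is typically handed to a finetuning driver, which calls the returned metrics_func after each training epoch. A hypothetical invocation; every name below other than accuracy_func_provider itself is a placeholder:

# 'my_dataset_provider' stands in for a callable that maps a validation data
# path to a dataset exposing a .dataset_name attribute, as the provider expects.
metrics_func = accuracy_func_provider(my_dataset_provider)
for epoch in range(num_epochs):   # num_epochs: placeholder
    train_one_epoch(model)        # placeholder for one epoch of finetuning
    metrics_func(model, epoch, output_predictions=False)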
Example #4
import time

def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
    """Provide function that calculates accuracies."""
    args = get_args()

    print_rank_0("accuracy_func_provider is CALLED")

    # Build dataloaders
    datapath = args.valid_data
    dataset = single_dataset_provider(datapath)

    drop_last = False
    if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
        drop_last = True

    print_rank_0('valid data path(s): {}'.format(datapath))
    print_rank_0('rank0sampler: {}'.format(rank0sampler))

    dataloader = build_data_loader(dataset,
                                   args.eval_micro_batch_size,
                                   num_workers=args.num_workers,
                                   drop_last=drop_last,
                                   task_collate_fn=task_collate_fn)
    dataloaders = (dataset.dataset_name, dataloader)

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_0('calculating metrics by accuracy func in ORQA...')

        if output_predictions:
            assert rank0sampler
            names = 'predictions'
        name, dataloader = dataloaders
        if args.task == "RET-FINETUNE-NQ":
            start_time = time.time()
            output = retrieval_loss(model, dataloader)
            stats_dict, total = output
            format_string = ""
            for k, v in stats_dict.items():
                format_string += "|{} = {:.2f}".format(k, v / total)
            print_rank_0("epoch:{}{}".format(epoch, format_string))
            print_rank_0("taken time to calcuate metrics {:.3f}".format(\
                time.time() - start_time))
        else:
            raise AssertionError("{} Task not supported".format(args.task))

    return metrics_func
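
As a quick illustration of the formatting loop in metrics_func above, with made-up accumulated statistics:

stats_dict = {'top1_acc': 812.0, 'top5_acc': 1376.0}  # hypothetical accumulated counts
total = 1600                                          # hypothetical number of samples
format_string = ""
for k, v in stats_dict.items():
    format_string += "|{} = {:.2f}".format(k, v / total)
print("epoch:{}{}".format(0, format_string))
# -> epoch:0|top1_acc = 0.51|top5_acc = 0.86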