Example #1
    def metrics_func(model, epoch, output_predictions=False):
        print_rank_last('calculating metrics ...')
        correct = 0
        total = 0
        if output_predictions:
            assert mpu.get_data_parallel_world_size() == 1
            named_predictions = []
            names = 'predictions'
        for name, dataloader in dataloaders:
            output = calculate_correct_answers(name, model, dataloader,
                                               epoch, output_predictions)
            if not output_predictions:
                correct_ans, total_count = output
            else:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            correct += correct_ans
            total += total_count
        if is_last_rank():
            percent = float(correct) * 100.0 / float(total)
            print(' >> |epoch: {}| overall: correct / total = {} / {} = '
                  '{:.4f} %'.format(epoch, correct, total, percent))

        if output_predictions and is_last_rank():
            assert args.load is not None
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)
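`metrics_func` is a closure: `dataloaders` and `args` come from the enclosing provider, which builds one (name, dataloader) pair per validation set. A minimal sketch of that wiring, with `build_data_loader` and `single_dataset_provider` as assumptions about the surrounding code:

    def accuracy_func_provider(single_dataset_provider):
        args = get_args()
        # One (name, dataloader) pair per validation file; drop_last=False
        # so that partial final batches are still evaluated.
        dataloaders = []
        for path in args.valid_data:
            dataset = single_dataset_provider(path)  # hypothetical helper
            dataloaders.append(
                (dataset.dataset_name,
                 build_data_loader(dataset, args.micro_batch_size,
                                   args.num_workers, drop_last=False)))

        def metrics_func(model, epoch, output_predictions=False):
            ...  # body as in the example above

        return metrics_func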
Example #2
    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            if self._multichoice_head_key in state_dict:
                self.multichoice_head.load_state_dict(
                    state_dict[self._multichoice_head_key], strict=strict)
            else:
                print_rank_last(
                    '***WARNING*** could not find {} in the checkpoint, '
                    'initializing to random'.format(
                        self._multichoice_head_key))
Example #3
    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if mpu.is_pipeline_last_stage():
            if self._classification_head_key in state_dict:
                self.classification_head.load_state_dict(
                    state_dict[self._classification_head_key], strict=strict)
            else:
                print_rank_last(
                    '***WARNING*** could not find {} in the checkpoint, '
                    'initializing to random'.format(
                        self._classification_head_key))
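Examples #2 and #3 implement the same fallback: the language-model weights are mandatory, while the task head is optional, so a pretraining checkpoint (which contains no head) can seed finetuning. A hedged usage sketch (the path and the 'model' key are assumptions):

    # Loading a pretraining checkpoint into a classification model: the
    # language model loads strictly, while the absent classification head
    # keeps its random initialization and only triggers the warning above.
    state_dict = torch.load('/path/to/pretrained.pt', map_location='cpu')
    model.load_state_dict(state_dict['model'], strict=True)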
Example #4
    def load_state_dict(checkpoint_name):
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler

            # For backward compatibility.
            print_rank_last(
                " > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        return state_dict
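The `sys.modules` aliasing works because pickle resolves classes by their dotted module path: registering the relocated module under its old name lets `torch.load` unpickle checkpoints written before the refactor. The same pattern in isolation (module names are hypothetical):

    import sys
    import torch

    import new_pkg.scaler  # the class lives here now

    # Checkpoints pickled when the class lived at `old_pkg.scaler`:
    sys.modules["old_pkg.scaler"] = new_pkg.scaler
    state_dict = torch.load("checkpoint.pt", map_location="cpu")
    sys.modules.pop("old_pkg.scaler", None)  # undo the alias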
Example #5
def evaluate_and_print_results(prefix,
                               forward_step_func,
                               data_iterator,
                               model,
                               iteration,
                               verbose=False):
    """Helper function to evaluate and dump results on screen."""
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model,
                               verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        string += '{} value: {:.6E} | '.format(key,
                                               total_loss_dict[key].item())
        ppl = math.exp(min(20, total_loss_dict[key].item()))
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and is_last_rank():
            writer.add_scalar('{} value-validation'.format(key),
                              total_loss_dict[key].item(), iteration)
            writer.add_scalar('{} ppl-validation'.format(key), ppl, iteration)
            writer.add_scalar('{} value-validation vs samples'.format(key),
                              total_loss_dict[key].item(),
                              args.consumed_train_samples)
            writer.add_scalar('{} ppl-validation vs samples'.format(key), ppl,
                              args.consumed_train_samples)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
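The `min(20, ...)` clamp caps the reported perplexity at e**20 (about 4.85e8), so a diverging validation loss cannot overflow the formatted log line:

    import math

    print(math.exp(min(20, 3.2)))   # ~24.53, a typical validation PPL
    print(math.exp(min(20, 1e9)))   # capped at e**20 ~ 4.85e8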
Example #6
def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
    """Load a model checkpoint and return the iteration."""

    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last
    from megatron.checkpointing import get_checkpoint_tracker_filename
    from megatron.checkpointing import set_checkpoint_version
    from megatron.checkpointing import check_checkpoint_args
    from megatron.checkpointing import update_num_microbatches

    if mpu.get_data_parallel_rank() == 0:
        # On data-parallel rank 0, fall back to Megatron's native load_checkpoint.
        from megatron.checkpointing import load_checkpoint as load_checkpoint_native

        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)

    args = get_args()
    load_dir = getattr(args, load_arg)

    if isinstance(model, DistributedDataParallel):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If there is no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_last("WARNING: could not find the metadata file {} ".format(
            tracker_filename))
        print_rank_last(
            "    will not load any checkpoints and will start from "
            "random")
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, "r") as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == "release"
            if not release:
                print_rank_last(
                    "ERROR: Invalid metadata file {}. Exiting".format(
                        tracker_filename))
                sys.exit()

    assert iteration > 0 or release, "error parsing metadata file {}".format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration,
                                                     release, 0)
    checkpoint_name_local = get_fmoe_checkpoint_name(
        load_dir, iteration, release, mpu.get_data_parallel_rank())
    print_rank_last(
        " loading checkpoint at rank 0 from {} and rank {} from {} at iteration {}, will merge them later"
        .format(
            checkpoint_name_rank0,
            mpu.get_data_parallel_rank(),
            checkpoint_name_local,
            iteration,
        ))

    # Load the checkpoint.
    def load_state_dict(checkpoint_name):
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler

            # For backward compatibility.
            print_rank_last(
                " > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        except BaseException:
            print_rank_last("could not load the checkpoint")
            sys.exit()
        return state_dict

    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
    state_dict_local = load_state_dict(checkpoint_name_local)

    state_dict = merge_state_dict(state_dict_rank0, state_dict_local,
                                  args.fp16)

    # set checkpoint version
    set_checkpoint_version(state_dict.get("checkpoint_version", 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict["iteration"]
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict["total_iters"]
            except KeyError:
                print_rank_last("A metadata file exists but unable to load "
                                "iteration from checkpoint {}, exiting".format(
                                    checkpoint_name_local))
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              "consumed_train_samples", 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              "consumed_valid_samples", 0)
    else:
        print_rank_last("could not find arguments in the checkpoint ...")

    # Model.
    model.load_state_dict(state_dict["model"])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict["optimizer"])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        except KeyError:
            print_rank_last("Unable to load optimizer from checkpoint {}. "
                            "Specify --no-load-optim or --finetune to prevent "
                            "attempting to load the optimizer state, "
                            "exiting ...".format(checkpoint_name_local))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(
                state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_last("Unable to load optimizer from checkpoint {}. "
                            "Specify --no-load-rng or --finetune to prevent "
                            "attempting to load the optimizer state, "
                            "exiting ...".format(checkpoint_name_local))
            sys.exit()

    torch.distributed.barrier()
    print_rank_last(
        "  successfully loaded checkpoint (with expert parameters updated) from {} at iteration {}"
        .format(args.load, iteration))

    return iteration
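`merge_state_dict` (from the FastMoE Megatron patch) overlays the locally saved expert tensors onto the rank-0 checkpoint, so each data-parallel rank ends up with the shared weights from rank 0 plus its own experts. Conceptually (a simplified sketch, not FastMoE's exact implementation):

    def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
        # The local dict holds only expert entries; recursively overwrite
        # the matching entries of the rank-0 dict with them.
        for key, value in state_dict_local.items():
            if isinstance(value, dict):
                merge_state_dict(state_dict_rank0[key], value, fp16)
            else:
                state_dict_rank0[key] = value
        return state_dict_rank0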
Example #7
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert parallel """
    # TODO: update patch
    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last

    expert_dp_comm = "none"

    if mpu.get_data_parallel_rank() == 0:
        # On data-parallel rank 0, fall back to Megatron's native save_checkpoint.
        from megatron.checkpointing import save_checkpoint as save_checkpoint_native

        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
        return

    args = get_args()

    # Non-zero data-parallel ranks reach this point; they save only their expert parameters.
    if isinstance(model, DistributedDataParallel):
        model = model.module

    print_rank_last("saving checkpoint at iteration {:7d} to {}".format(
        iteration, args.save))

    # Arguments, iteration, and model.
    state_dict = {}
    state_dict["model"] = model.state_dict_for_save_checkpoint(
        keep_vars=(mpu.get_data_parallel_rank() > 0))

    def extract_expert_param(state_dict, expert_dp_comm="none"):
        state_dict_new = state_dict.__class__()
        for k, v in state_dict.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                v_new = extract_expert_param(v, expert_dp_comm)
                if len(v_new) > 0:
                    state_dict_new[k] = v_new
            elif hasattr(v, "dp_comm") and v.dp_comm == expert_dp_comm:
                state_dict_new[k] = v.detach()
        return state_dict_new

    state_dict["model"] = extract_expert_param(state_dict["model"],
                                               expert_dp_comm)

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict["optimizer"] = optimizer.state_dict()
            param_global_idx = 0
            for param_group in optimizer.optimizer.param_groups:
                for param in param_group["params"]:
                    if not (hasattr(param, "dp_comm")
                            and param.dp_comm == expert_dp_comm):
                        # this parameter is not an expert parameter
                        # thus there is no need to save its state in current rank
                        # since it has been saved by data parallel rank 0
                        if args.fp16:
                            # fp16 optimizer may have empty state due to overflow
                            state_dict["optimizer"]["optimizer"]["state"].pop(
                                param_global_idx, None)
                        else:
                            state_dict["optimizer"]["state"].pop(
                                param_global_idx)
                    param_global_idx += 1
            if args.fp16:
                state_dict["optimizer"]["optimizer"].pop("param_groups")
                # fp32_from_fp16_params in state_dict is not a copy
                # but a reference to optimizer.fp32_from_fp16_params,
                # changing it in state_dict will change
                # optimizer.fp32_from_fp16_params as well
                # thus we create an empty fp32_from_fp16_params in state_dict
                # and only insert expert parameters.
                fp32_from_fp16_params = state_dict["optimizer"][
                    "fp32_from_fp16_params"]
                state_dict["optimizer"]["fp32_from_fp16_params"] = []
                for param_group in fp32_from_fp16_params:
                    param_group_copy = []
                    for param in param_group:
                        param_copy = (param if hasattr(param, "dp_comm")
                                      and param.dp_comm == expert_dp_comm else
                                      None)
                        param_group_copy.append(param_copy)
                    state_dict["optimizer"]["fp32_from_fp16_params"].append(
                        param_group_copy)
            else:
                state_dict["optimizer"].pop("param_groups")

    # Save.
    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename

    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            "  successfully saved checkpoint at iteration {:7d} to {}".format(
                iteration, args.save),
            flush=True,
        )
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
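The `dp_comm` attribute is how FastMoE marks expert parameters: tensors tagged `dp_comm == "none"` are excluded from the data-parallel all-reduce, so each rank must persist its own copy. The predicate this function applies in several places, factored out as a sketch:

    def is_expert_param(param, expert_dp_comm="none"):
        # Expert weights carry a `dp_comm` tag; anything else is replicated
        # across data-parallel ranks and already saved by rank 0.
        return hasattr(param, "dp_comm") and param.dp_comm == expert_dp_comm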
Example #8
def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true."""
    args = get_args()
    start_time = time.time()
    model.eval()
    saved_batch_size = args.micro_batch_size
    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # Run the model forward.
            tokens, types, labels_, attention_mask = process_batch(batch)

            # In evaluation-only mode we use drop_last = False to see all
            # samples, so the final batch may be partial; adjust batch_size
            # here to the actual batch size of the data
            actual_batch_size = len(labels_)
            # ... applying sample_multiplier if necessary
            ds = dataloader.dataset
            if hasattr(ds, 'sample_multiplier'):
                actual_batch_size *= ds.sample_multiplier
            args.micro_batch_size = actual_batch_size

            if not mpu.is_pipeline_first_stage():
                input_tensor, _ = communicate(
                    tensor_send_next=None,
                    tensor_send_prev=None,
                    recv_forward=True,
                    recv_backward=False)
            else:
                input_tensor = None

            # Forward model.
            if mpu.is_pipeline_first_stage():
                assert input_tensor is None
                output_tensor = model(tokens, attention_mask, tokentype_ids=types)
            else:
                assert input_tensor is not None
                output_tensor = model(input_tensor, attention_mask)

            if mpu.is_pipeline_last_stage():
                logits = output_tensor

                # Add output predictions.
                if output_predictions:
                    softmaxes.extend(torch.nn.Softmax(dim=-1)(
                        logits.float()).data.cpu().numpy().tolist())
                    labels.extend(labels_.data.cpu().numpy().tolist())
                    ids.extend(batch['uid'].cpu().numpy().tolist())
                # Compute the correct answers.
                predicted = torch.argmax(logits, dim=-1)
                corrects = (predicted == labels_)
                # Add to the counters.
                total += labels_.size(0)
                correct += corrects.sum().item()
            else:
                communicate(
                    tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)

    model.train()
    args.micro_batch_size = saved_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.

        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
                        '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                            epoch, name, correct_ans, total_count,
                            percent, elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0
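Only the last pipeline stage holds real counts; every other stage returns zeros. Callers should therefore gate any reporting on the rank, along these lines (a usage sketch):

    correct, total = calculate_correct_answers(name, model, dataloader,
                                               epoch, output_predictions=False)
    if is_last_rank():
        print('accuracy = {:.4f} %'.format(100.0 * correct / max(1, total)))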
Example #9
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and NaN iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key,
                                                         0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)

    add_to_logging('forward-compute')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-send-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('optimizer-copy-to-main-grad')
    add_to_logging('optimizer-unscale-and-check-inf')
    add_to_logging('optimizer-clip-main-grad')
    add_to_logging('optimizer-copy-main-to-model-params')
    add_to_logging('optimizer')
    add_to_logging('batch-generator')

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and is_last_rank():
        writer.add_scalar('learning-rate', learning_rate, iteration)
        writer.add_scalar('learning-rate vs samples', learning_rate,
                          args.consumed_train_samples)
        writer.add_scalar('batch-size', batch_size, iteration)
        writer.add_scalar('batch-size vs samples', batch_size,
                          args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        writer.add_scalar('loss-scale', loss_scale, iteration)
        writer.add_scalar('loss-scale vs samples', loss_scale,
                          args.consumed_train_samples)
        timers.write(timers_to_log,
                     writer,
                     iteration,
                     normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('iteration-time', elapsed_time_per_iteration,
                              iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
                                                       args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [
                    advanced_iters_key, skipped_iters_key, nan_iters_key
            ]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag
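The logged global batch size follows directly from the schedule: each of the `data_parallel_size` replicas runs `get_num_microbatches()` micro-batches of `micro_batch_size` samples per iteration. A worked example:

    # micro_batch_size=4, data_parallel_size=8, num_microbatches=16
    batch_size = 4 * 8 * 16  # = 512 samples per iteration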
Example #10
def calculate_correct_answers(name, model, dataloader, epoch,
                              output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true."""
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()
    for m in model:
        m.eval()
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # If the dataset has a sample_multiplier attribute, each "sample"
        # actually contains multiple samples that collapse into the batch
        # dimension (for example the RACE dataset, which has several
        # options per question), so we account for that when setting the
        # micro batch size.
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
    num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel

    def loss_func(output_predictions, labels, output_tensor):
        logits = output_tensor

        loss_dict = {}
        # Add output predictions.
        if output_predictions:
            # Output predictions are currently disabled in this code path.
            assert False
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels)
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    # defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
        # `batch` may arrive as an iterator (from the pipeline schedule) or
        # as an already-materialized batch; handle both cases.
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)

        # Forward model.
        args = get_args()
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)

        return output_tensor, partial(loss_func, output_predictions, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # In evaluation-only mode we use drop_last = False to see all
            # samples, so the final batch may be partial; adjust the batch
            # sizes here to the actual batch size of the data
            actual_batch_size = len(batch['label'])
            # ... applying sample_multiplier if necessary
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches

            loss_dicts = forward_backward_func(correct_answers_forward_step,
                                               batch,
                                               model,
                                               optimizer=None,
                                               timers=None,
                                               forward_only=True)

            for loss_dict in loss_dicts:
                if output_predictions:
                    softmaxes.extend(loss_dict['softmaxes'])
                    labels.extend(loss_dict['labels'])
                    ids.extend(loss_dict['ids'])
                total += loss_dict['total']
                correct += loss_dict['correct']

    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.

        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        print_rank_last(
            ' > |epoch: {}| metrics for {}: correct / total '
            '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                epoch, name, correct_ans, total_count, percent, elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0
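Temporarily patching `args.micro_batch_size` and `args.global_batch_size`, then restoring them, is what lets the pipeline schedule process the partial final batch. The same idea wrapped in a context manager (a sketch, not part of Megatron):

    from contextlib import contextmanager

    @contextmanager
    def patched_batch_sizes(args, micro, global_):
        saved = (args.micro_batch_size, args.global_batch_size)
        args.micro_batch_size, args.global_batch_size = micro, global_
        try:
            yield
        finally:
            args.micro_batch_size, args.global_batch_size = saved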
Example #11
    def metrics_func(model, epoch):
        print_rank_0("calculating metrics ...")
        correct, total = calculate_correct_answers(model, dataloader, epoch)
        percent = float(correct) * 100.0 / float(total)
        print_rank_last(" >> |epoch: {}| overall: correct / total = {} / {} = "
                        "{:.4f} %".format(epoch, correct, total, percent))