Example #1
def save_checkpoint(iteration,
                    model,
                    optimizer,
                    lr_scheduler,
                    args,
                    tag=None,
                    barrier=True):
    """Save a model checkpoint."""
    if tag is None:
        tag = str(iteration)
    if args.deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag)
    else:
        # Only rank zero of the data parallel group writes to the disk.
        if isinstance(model, torchDDP):
            model = model.module

        if mpu.get_data_parallel_rank() == 0:
            checkpoint_name = get_checkpoint_name(args.save, tag)
            print(
                'global rank {} is saving checkpoint at iteration {:7d} to {}'.
                format(torch.distributed.get_rank(), iteration,
                       checkpoint_name))

            sd = {}
            sd['iteration'] = iteration
            sd['module'] = model.state_dict()

            # Optimizer stuff.
            if not args.no_save_optim:
                if optimizer is not None:
                    sd['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    sd['lr_scheduler'] = lr_scheduler.state_dict()

            # rng states.
            if not args.no_save_rng:
                sd['random_rng_state'] = random.getstate()
                sd['np_rng_state'] = np.random.get_state()
                sd['torch_rng_state'] = torch.get_rng_state()
                sd['cuda_rng_state'] = torch.cuda.get_rng_state()
                sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker(
                ).get_states()

            ensure_directory_exists(checkpoint_name)
            torch.save(sd, checkpoint_name)
            print('  successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary)
    if barrier:
        torch.distributed.barrier()
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(tag)
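Example #1 relies on a few small helpers that are not shown here (get_checkpoint_name, ensure_directory_exists, get_checkpoint_tracker_filename). A minimal sketch of what they typically look like in Megatron-style checkpointing code follows; the exact file names, the mp_rank_XX layout, and the mpu.get_model_parallel_rank() call are assumptions, not taken from these examples.

import os

def get_checkpoint_tracker_filename(checkpoints_path):
    # Single text file recording the latest saved tag/iteration.
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


def get_checkpoint_name(checkpoints_path, tag, release=False):
    # One sub-directory per tag, one file per model-parallel rank.
    directory = 'release' if release else str(tag)
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank()),
                        'model_optim_rng.pt')


def ensure_directory_exists(filename):
    # Create the parent directory of the checkpoint file if needed.
    os.makedirs(os.path.dirname(filename), exist_ok=True)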
Example #2
def save_ds_checkpoint(iteration, model, args):
    """Save a model checkpoint."""

    sd = {}
    sd['iteration'] = iteration
    # rng states.
    if not args.no_save_rng:
        sd['random_rng_state'] = random.getstate()
        sd['np_rng_state'] = np.random.get_state()
        sd['torch_rng_state'] = torch.get_rng_state()
        sd['cuda_rng_state'] = torch.cuda.get_rng_state()
        sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

    model.save_checkpoint(args.save, iteration, client_state=sd)
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    mpu.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with mpu.get_cuda_rng_tracker().fork():
        assert torch.cuda.initial_seed() == (
            12345 + 2718 + mpu.get_tensor_model_parallel_rank())

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
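The assertions in the test above pin down the seeding scheme: the default CUDA generator is seeded with the value passed in, while the tracker's forked state is seeded with an offset of 2718 plus the tensor-model-parallel rank. Below is a simplified sketch of a seeding function consistent with those assertions; the tracker name and the overall structure are assumptions, not the original mpu implementation.

def model_parallel_cuda_manual_seed_sketch(seed):
    # Same seed on every rank for regions that must agree across
    # tensor-model-parallel ranks (matches torch.cuda.initial_seed() == seed).
    torch.cuda.manual_seed(seed)
    # Distinct per-rank seed kept on a separate generator state inside the
    # tracker (matches seed + 2718 + rank observed inside fork()).
    tracker = mpu.get_cuda_rng_tracker()
    tracker.reset()
    tracker.add('model-parallel-rng',
                seed + 2718 + mpu.get_tensor_model_parallel_rank())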
Example #4
def save_ds_checkpoint(iteration, model, lr_scheduler, args, tag):
    """Save a model checkpoint."""

    sd = {}
    sd['iteration'] = iteration
    if lr_scheduler is not None:
        sd['client_lr_scheduler'] = lr_scheduler.state_dict()
    # rng states.
    if not args.no_save_rng:
        sd['random_rng_state'] = random.getstate()
        sd['np_rng_state'] = np.random.get_state()
        sd['torch_rng_state'] = torch.get_rng_state()
        sd['cuda_rng_state'] = torch.cuda.get_rng_state()
        sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
    model.save_checkpoint(args.save, tag, client_state=sd)
Example #5
def load_checkpoint(model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""

    iteration, release, success = get_checkpoint_iteration(args)

    if not success:
        return 0

    if args.deepspeed:
        raise NotImplementedError("DeepSpeed is not installed")

    else:

        if args.load_openai:
            from utils import move_weights
            from model import DistributedDataParallel as DDP
            from fp16 import FP16_Module
            model_path = args.load
            from transformers import GPT2LMHeadModel
            print('global rank {} is loading openai weights {}'.format(
                torch.distributed.get_rank(), model_path))
            model.cpu()
            gpt2model = GPT2LMHeadModel.from_pretrained(
                model_path, cache_dir='gpt2_weights')
            model2fill = model
            while isinstance(model2fill, (DDP, FP16_Module)):
                model2fill = model2fill.module
            move_weights(model2fill, gpt2model)
            model.cuda(torch.cuda.current_device())
            sd = {}
        else:
            # Checkpoint.
            checkpoint_name = get_checkpoint_name(args.load, iteration,
                                                  release)

            if mpu.get_data_parallel_rank() == 0:
                print('global rank {} is loading checkpoint {}'.format(
                    torch.distributed.get_rank(), checkpoint_name))
            sd = torch.load(checkpoint_name, map_location='cpu')

            if isinstance(model, torchDDP):
                model = model.module

            # Model.
            try:
                model.load_state_dict(sd['model'])
            except KeyError:
                print_rank_0(
                    'A metadata file exists but unable to load model '
                    'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

            # Optimizer.
            if not release and not args.finetune and not args.no_load_optim:
                try:
                    if optimizer is not None:
                        optimizer.load_state_dict(sd['optimizer'])
                    if lr_scheduler is not None:
                        lr_scheduler.load_state_dict(sd['lr_scheduler'])
                except KeyError:
                    print_rank_0(
                        'Unable to load optimizer from checkpoint {}, exiting. '
                        'Specify --no-load-optim or --finetune to prevent '
                        'attempting to load the optimizer '
                        'state.'.format(checkpoint_name))
                    exit()

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0(
                    'A metadata file exists but unable to load iteration '
                    'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0(
                'Unable to load random state from checkpoint {}, exiting. '
                'Specify --no-load-rng or --finetune to prevent '
                'attempting to load the random '
                'state.'.format(checkpoint_name))
            exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
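get_checkpoint_iteration is not included in these examples, but its behaviour can be inferred from the tracker-file parsing that a later load_checkpoint variant inlines. Below is a minimal sketch of the three-value form used here, assuming the same tracker-file convention as save_checkpoint; it is an illustration, not the original implementation.

import os

def get_checkpoint_iteration(args):
    # Read the tracker file written by save_checkpoint and decide which
    # iteration (or the 'release' checkpoint) should be loaded.
    tracker_filename = get_checkpoint_tracker_filename(args.load)
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0('    will not load any checkpoints and will start from '
                     'random')
        return 0, False, False
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                exit()
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)
    return iteration, release, True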
Example #6
def load_checkpoint(load_path, model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""

    iteration, release, success = get_checkpoint_iteration(load_path)

    if not success:
        return 0

    if args.deepspeed:

        checkpoint_name, sd = model.load_checkpoint(
            load_path,
            iteration,
            load_module_strict=False,
            load_optimizer_states=False,
            load_lr_scheduler_states=False)

        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return iteration

    else:

        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_path, iteration, release)

        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        sd = torch.load(checkpoint_name, map_location='cpu')

        if isinstance(model, torchDDP):
            model = model.module

        # Model.
        try:
            model.load_state_dict(sd['model'])
        except KeyError:
            print_rank_0('A metadata file exists but unable to load model '
                         'from checkpoint {}, exiting'.format(checkpoint_name))
            exit()

        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                print_rank_0(
                    'Unable to load optimizer from checkpoint {}, exiting. '
                    'Specify --no-load-optim or --finetune to prevent '
                    'attempting to load the optimizer '
                    'state.'.format(checkpoint_name))
                exit()

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0(
                    'A metadata file exists but unable to load iteration '
                    'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0(
                'Unable to load random state from checkpoint {}, exiting. '
                'Specify --no-load-rng or --finetune to prevent '
                'attempting to load the random '
                'state.'.format(checkpoint_name))
            exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
def test_cuda_rng_tracker(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.format(
            tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    mpu.get_cuda_rng_tracker().add('test', seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print('   max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6
    error = max(
        result_11.sub(target_11).abs().max(),
        result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print('   max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
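The tracker exercised by the test above keeps one CUDA generator state per name and swaps it in for the duration of fork(), so the default generator's sequence is left untouched. A simplified, self-contained sketch of that mechanism follows; it is not the actual mpu implementation, only the behaviour the test depends on.

import contextlib
import torch

class SimpleCudaRNGTracker:
    """Keep named CUDA RNG states and swap them in inside fork()."""

    def __init__(self):
        self.states = {}

    def reset(self):
        self.states = {}

    def get_states(self):
        return dict(self.states)

    def set_states(self, states):
        self.states = dict(states)

    def add(self, name, seed):
        # Seed a fresh generator state under `name`, then restore the
        # original default state so add() has no visible side effects.
        original = torch.cuda.get_rng_state()
        torch.cuda.manual_seed(seed)
        self.states[name] = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(original)

    @contextlib.contextmanager
    def fork(self, name='model-parallel-rng'):
        # Swap in the named state, run the body, remember where that
        # state ended up, and put the default state back.
        original = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.states[name])
        try:
            yield
        finally:
            self.states[name] = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(original)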
def load_checkpoint(model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""
    if isinstance(model, torchDDP):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(args.load)
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0('    will not load any checkpoints and will start from '
                     'random')
        return 0
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                exit()

    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    # Load the checkpoint.
    sd = torch.load(checkpoint_name, map_location='cpu')

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try: # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load iteration '
                             'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # Model.
    try:
        model.load_state_dict(sd['model'])
    except KeyError:
        print_rank_0('A metadata file exists but unable to load model '
                     'from checkpoint {}, exiting'.format(checkpoint_name))
        exit()

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(sd['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(sd['lr_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer '
                         'state.'.format(checkpoint_name))
            exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load random state from checkpoint {}, exiting. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the random '
                         'state.'.format(checkpoint_name))
            exit()

    #torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
Example #9
def load_checkpoint(model,
                    optimizer,
                    lr_scheduler,
                    args,
                    no_deepspeed=False,
                    no_load_optim=False):
    """Load a model checkpoint."""

    load_dir, tag, release, success = get_checkpoint_iteration(args.load)

    if not success:
        return 0

    if args.deepspeed and not no_deepspeed:

        checkpoint_name, sd = model.load_checkpoint(
            load_dir,
            tag,
            load_optimizer_states=not args.no_load_optim and not no_load_optim,
            load_lr_scheduler_states=not args.no_load_lr_scheduler)
        if not args.no_load_lr_scheduler and "client_lr_scheduler" in sd:
            lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
            print_rank_0("Load lr scheduler state")
        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return tag

    else:

        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_dir, tag, release)

        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        sd = torch.load(checkpoint_name, map_location='cpu')

        # Model.
        if args.deepspeed:
            model = model.module
        missing_keys, unexpected_keys = model.load_state_dict(sd['module'],
                                                              strict=False)
        if missing_keys or unexpected_keys:
            print_rank_0(
                f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}"
            )

        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim and not no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                print_rank_0(
                    'Unable to load optimizer from checkpoint {}, exiting. '
                    'Specify --no-load-optim or --finetune to prevent '
                    'attempting to load the optimizer '
                    'state.'.format(checkpoint_name))

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0(
                    'A metadata file exists but unable to load iteration '
                    'from checkpoint {}, starting from iteration 0'.format(
                        checkpoint_name))
                iteration = 0

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0(
                'Unable to load random state from checkpoint {}, exiting. '
                'Specify --no-load-rng or --finetune to prevent '
                'attempting to load the random '
                'state.'.format(checkpoint_name))

    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
Example #10
def save_checkpoint(iteration,
                    model,
                    optimizer,
                    lr_scheduler,
                    args,
                    tag=None,
                    barrier=True,
                    only_changed_parameters=False,
                    no_deepspeed=False,
                    no_save_optim=False):
    """Save a model checkpoint."""
    if tag is None:
        tag = str(iteration)
    if args.deepspeed and not no_deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag)
    else:
        # Only rank zero of the data parallel group writes to the disk.

        if mpu.get_data_parallel_rank() == 0:
            checkpoint_name = get_checkpoint_name(args.save, tag)
            print(
                'global rank {} is saving checkpoint at iteration {:7d} to {}'.
                format(torch.distributed.get_rank(), iteration,
                       checkpoint_name))
            sd = {'iteration': iteration}
            if args.deepspeed:
                model = model.module
            state_dict = model.state_dict()
            if only_changed_parameters:
                requires_grad_dict = {}
                for name, parameter in model.named_parameters():
                    requires_grad_dict[name] = parameter.requires_grad
                state_dict = {
                    key: value
                    for key, value in state_dict.items()
                    if requires_grad_dict[key]
                }
            sd['module'] = state_dict

            # Optimizer stuff.
            if not args.no_save_optim and not no_save_optim:
                if optimizer is not None:
                    sd['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    sd['lr_scheduler'] = lr_scheduler.state_dict()

            # rng states.
            if not args.no_save_rng:
                sd['random_rng_state'] = random.getstate()
                sd['np_rng_state'] = np.random.get_state()
                sd['torch_rng_state'] = torch.get_rng_state()
                sd['cuda_rng_state'] = torch.cuda.get_rng_state()
                sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker(
                ).get_states()

            ensure_directory_exists(checkpoint_name)
            torch.save(sd, checkpoint_name)
            print('  successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary)
    if barrier:
        torch.distributed.barrier()
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(tag)
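A typical driver for the save/load pair in these examples might look like the sketch below; train_step, args.train_iters, and args.save_interval are hypothetical names used for illustration only.

def train(model, optimizer, lr_scheduler, train_data_iterator, args):
    # Resume from the latest checkpoint recorded in the tracker file, if any.
    iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
    while iteration < args.train_iters:
        train_step(model, optimizer, lr_scheduler, train_data_iterator, args)
        iteration += 1
        # Periodically persist module, optimizer, scheduler and RNG state.
        if args.save and iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
    # Final checkpoint at the end of training.
    if args.save:
        save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
    return iteration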
Example #11
def load_checkpoint(model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""

    load_dir, tag, release, success = get_checkpoint_iteration(args)

    if not success:
        return 0

    if args.deepspeed:

        checkpoint_name, sd = model.load_checkpoint(
            load_dir,
            tag,
            load_optimizer_states=not args.no_load_optim,
            load_lr_scheduler_states=not args.no_load_optim)
        if "client_lr_scheduler" in sd:
            lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
            print_rank_0("Load lr scheduler state")
        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return tag

    else:

        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_dir, tag, release)

        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        sd = torch.load(checkpoint_name, map_location='cpu')

        if isinstance(model, torchDDP):
            model = model.module

        # Model.
        try:

            def extend_embedding_weights(state_weights, model_weights):
                original_length = state_weights.shape[0]
                assert original_length <= args.max_position_embeddings + 1
                new_weights = model_weights.clone()
                new_weights[:original_length] = state_weights
                return new_weights

            if args.block_lm:
                if "transformer.block_position_embeddings.weight" in sd[
                        "module"]:
                    position_weights = sd['module'][
                        "transformer.position_embeddings.weight"]
                    if args.max_position_embeddings + 1 > position_weights.shape[
                            0]:
                        sd['module'][
                            "transformer.position_embeddings.weight"] = extend_embedding_weights(
                                position_weights,
                                model.state_dict()
                                ["transformer.position_embeddings.weight"].data
                            )
                        print_rank_0(
                            f"Extend position embedding to {args.max_position_embeddings + 1}"
                        )
                if "transformer.block_position_embeddings.weight" in sd[
                        "module"]:
                    block_position_weights = sd['module'][
                        "transformer.block_position_embeddings.weight"]
                    if args.max_position_embeddings + 1 > block_position_weights.shape[
                            0]:
                        sd['module'][
                            "transformer.block_position_embeddings.weight"] = extend_embedding_weights(
                                block_position_weights,
                                model.state_dict()
                                ["transformer.block_position_embeddings.weight"]
                                .data)
                        print_rank_0(
                            f"Extend block position embedding to {args.max_position_embeddings + 1}"
                        )

            model.load_state_dict(sd['module'], strict=False)
        except KeyError:
            print_rank_0('A metadata file exists but unable to load model '
                         'from checkpoint {}, exiting'.format(checkpoint_name))
            exit()

        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                print_rank_0(
                    'Unable to load optimizer from checkpoint {}, exiting. '
                    'Specify --no-load-optim or --finetune to prevent '
                    'attempting to load the optimizer '
                    'state.'.format(checkpoint_name))
                exit()

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0(
                    'A metadata file exists but unable to load iteration '
                    'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0(
                'Unable to load random state from checkpoint {}, exiting. '
                'Specify --no-load-rng or --finetune to prevent '
                'attempting to load the random '
                'state.'.format(checkpoint_name))
            exit()

    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration