Example #1
import collections
import os

from fairseq import distributed_utils, utils


def save_checkpoint(args, trainer, epoch_itr, val_loss):
    if args.no_save:
        return
    # Unlike stock fairseq there is no is_master() gate here: every rank
    # saves its own checkpoint, hence the _r{rank}_n{world_size} suffixes.
    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}_r{}_n{}.pt'.format(
        epoch, distributed_utils.get_rank(),
        distributed_utils.get_world_size())] = (
            end_of_epoch and not args.no_epoch_checkpoints
            and epoch % args.save_interval == 0)
    checkpoint_conds['checkpoint_{}_{}_r{}_n{}.pt'.format(
        epoch, updates, distributed_utils.get_rank(),
        distributed_utils.get_world_size())] = (
            not end_of_epoch and args.save_interval_updates > 0
            and updates % args.save_interval_updates == 0)
    checkpoint_conds['checkpoint_best_r{}_n{}.pt'.format(
        distributed_utils.get_rank(), distributed_utils.get_world_size())] = (
            val_loss is not None and (not hasattr(save_checkpoint, 'best')
                                      or val_loss < save_checkpoint.best))
    checkpoint_conds['checkpoint_last_r{}_n{}.pt'.format(
        distributed_utils.get_rank(), distributed_utils.get_world_size()
    )] = True  # keep this last so that it's a symlink

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        'best': save_checkpoint.best,
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    for cp in checkpoints:
        trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(
            args.save_dir,
            pattern=r'checkpoint_\d+_(\d+)_r{}_n{}\.pt'.format(
                distributed_utils.get_rank(),
                distributed_utils.get_world_size()))
        for old_chk in checkpoints[args.keep_interval_updates:]:
            os.remove(old_chk)
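
For context, `utils.checkpoint_paths` scans a directory and returns the files matching `pattern`, sorted in descending order by the captured integer, which is what makes the slice-and-delete above drop the oldest interval checkpoints. A minimal sketch of equivalent logic (not fairseq's exact implementation):

import os
import re


def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Files under `path` matching `pattern`, newest first by the
    captured integer (here: the update number)."""
    pt_regexp = re.compile(pattern)
    entries = []
    for fname in os.listdir(path):
        m = pt_regexp.fullmatch(fname)
        if m is not None:
            entries.append((int(m.group(1)), os.path.join(path, fname)))
    return [p for _, p in sorted(entries, reverse=True)]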
Example #2
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000,
             main_process=False, assistant=None, truncate=None):
    super().__init__(iterable, epoch, prefix, main_process=main_process)
    self.log_interval = log_interval
    self.stats = None
    self.assistant = assistant
    # one bookkeeping slot per worker; this rank updates slot `my_id`
    self.world_size = get_world_size()
    self.successes = [0] * self.world_size
    self.samples = [0] * self.world_size
    self.confidence = [0] * self.world_size
    self.my_id = get_rank()
    self.truncate = truncate
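
The one-slot-per-rank layout implies the counters are merged across workers at some later point, which this excerpt does not show. A hedged sketch of such a merge with `torch.distributed.all_reduce`, assuming each rank only ever writes its own slot:

import torch
import torch.distributed as dist


def sync_slots(slots, group=None):
    """Sum per-rank counter slots across workers; since each rank fills
    only its own slot (the rest stay 0), the sum recovers every slot."""
    buf = torch.tensor(slots, dtype=torch.float64)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
    return buf.tolist()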
Example #3
def _split(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""
    group = get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    world_size = get_world_size(group=group)
    if world_size == 1:
        return input_

    # Split along the last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_rank(group=group)
    output = input_list[rank].contiguous()

    return output
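
`split_tensor_along_last_dim` is the helper doing the actual chunking. A minimal sketch consistent with how `_split` uses it, assuming the last dimension divides evenly by the number of partitions:

import torch


def split_tensor_along_last_dim(tensor, num_partitions):
    """Split `tensor` into `num_partitions` equal chunks along its
    last dimension."""
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size(last_dim) // num_partitions
    # torch.split returns views, which is why _split calls .contiguous()
    return torch.split(tensor, last_dim_size, dim=last_dim)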
Example #4
def _gather(input_):
    """Gather tensors and concatinate along the last dimension."""
    group = get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if get_world_size(group=group) == 1:
        return input_

    # Concatenate along the last dimension.
    last_dim = input_.dim() - 1

    # `all_gather` returns the per-rank tensor list when the output list
    # argument is passed as None.
    tensor_list = all_gather(None, input_, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output
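
The `all_gather` used here is not `torch.distributed.all_gather`, whose first argument must be a pre-allocated list. A plausible wrapper matching this call shape (an assumption, not necessarily the project's actual helper):

import torch
import torch.distributed as dist


def all_gather(tensor_list, tensor, group=None):
    """Gather `tensor` from every rank and return the resulting list,
    allocating one buffer per rank when `tensor_list` is None."""
    if tensor_list is None:
        world_size = dist.get_world_size(group=group)
        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor, group=group)
    return tensor_list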
Example #5
def __init__(self, args, task):
    super().__init__(args, task)
    self.eps = args.label_smoothing
    from fairseq.sequence_generator import SequenceGenerator
    self.gen = SequenceGenerator(task.target_dictionary,
                                 beam_size=args.beam_size)
    if args.reward == "bleurt":
        import contextlib
        from fairseq.distributed_utils import get_rank
        # TensorFlow (via absl) parses sys.argv on import; strip fairseq's
        # command-line flags first so they don't confuse it.
        sys.argv = sys.argv[:1]
        my_rank = 0 if torch.cuda.device_count() <= 1 else get_rank()
        # Pin each worker to one of (up to) four GPUs, round-robin by rank.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(my_rank % 4)
        from bleurt import score
        from transformers import cached_path
        import tensorflow as tf
        # Fallback no-op context in case TensorFlow sees no GPU.
        self.logical_device = contextlib.nullcontext()
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            # CUDA_VISIBLE_DEVICES above already limits TF to one device,
            # so the pinned GPU is always at index 0.
            this_gpu = gpus[0]
            tf.config.set_visible_devices([this_gpu], 'GPU')
            try:
                # Cap BLEURT's TF memory so it coexists with PyTorch.
                tf.config.experimental.set_memory_growth(this_gpu, True)
                tf.config.experimental.set_virtual_device_configuration(
                    this_gpu, [
                        tf.config.experimental.VirtualDeviceConfiguration(
                            memory_limit=2048)
                    ])
                logical_devices = tf.config.list_logical_devices('GPU')
                self.logical_device = tf.device(logical_devices[0].name)
                print("num of logical gpus", len(logical_devices))
            except RuntimeError as e:
                print(e)
        with self.logical_device:
            self.bleurt_scorer = score.BleurtScorer(
                os.path.join(
                    cached_path(
                        "https://storage.googleapis.com/bleurt-oss/bleurt-base-128.zip",
                        extract_compressed_file=True), "bleurt-base-128"))
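
For reference, scoring with the resulting scorer looks like the following; the keyword arguments match recent BLEURT releases (older versions took them positionally):

candidates = ['a cat was sitting on the mat']   # model outputs
references = ['the cat sat on the mat']         # gold targets
scores = self.bleurt_scorer.score(references=references,
                                  candidates=candidates)
# one float per (reference, candidate) pair; higher is better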
Example #6
def get_rank():
    """Distributed rank, or 0 when torch.distributed is not initialized."""
    try:
        return du.get_rank()
    except AssertionError:
        # raised before init_process_group() has been called
        return 0
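
This wrapper lets rank-dependent code run unchanged in single-process jobs, where the underlying call would raise. For example:

# safe whether or not init_process_group() has run
if get_rank() == 0:
    print('only the master process logs this')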
Example #7
import math

import torch

from fairseq import distributed_utils, tasks, utils
from fairseq.meters import StopwatchMeter
from fairseq.trainer import Trainer

# Note: train(), validate(), load_checkpoint(), load_dataset_splits() and
# save_checkpoint() (Example #1) are sibling helpers in the same train.py.


def main(args):
    if args.max_tokens is None:
        args.max_tokens = 6000
    args.restore_file = 'checkpoint_last_r{}_n{}.pt'.format(
        distributed_utils.get_rank(), distributed_utils.get_world_size())
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Make a dummy batch to (i) warm up the caching allocator and (ii) serve
    # as a placeholder batch for DistributedDataParallel when workers have an
    # uneven number of batches.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch)
    print('| training on {} nodes'.format(args.distributed_world_size))
    print('| max tokens per node = {} and max sentences per node = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
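
`load_dataset_splits` is one of the small train.py helpers referenced above; a minimal sketch of its core behavior (the real fairseq helper additionally probes numbered shard splits such as valid1, valid2, ...):

def load_dataset_splits(task, splits):
    # ask the task to materialize each requested split
    for split in splits:
        task.load_dataset(split)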