示例#1
0
文件: rank_aml.py 项目: xssstory/STAS
def main(args):
    # we should not do this!
    """Evaluate an extractive-summarization model on its validation split(s).

    Despite the training-style setup (trainer, epoch iterator, meters), no
    training happens in the visible code: the ``train(...)`` call inside the
    loop is commented out, so the loop only advances the epoch iterator and
    runs ``validate_metric`` on the master process for
    ``extractive_summarization*`` tasks.

    The upstream default ``args.max_tokens = 6000`` is intentionally NOT
    applied here (see the "we should not do this!" note above).

    Args:
        args: parsed command-line namespace. Among the fields read here are
            ``device_id``, ``seed``, ``valid_subset``, ``task``,
            ``ncpu_eval``, ``raw_valid``, ``raw_test``, ``topk_sent_eval``,
            ``trigram_block``, ``init_from_pretrained_doc_model``,
            ``pretrained_doc_model_path``, ``validate_interval``,
            ``max_epoch`` and ``max_update``.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    utils.xpprint(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    utils.xprintln('setup task done!')

    # Load dataset splits: 'train' (with load_dataset_splits' default shuffle
    # setting), then every comma-separated validation split with shuffle=False.
    load_dataset_splits(args, task, ['train'])
    valid_dataset = args.valid_subset.split(',')
    load_dataset_splits(args, task, valid_dataset, shuffle=False)
    utils.xprintln('load dataset done!')

    if args.task.startswith('extractive_summarization'):
        # Summary-evaluation resources are only created on the master process.
        if distributed_utils.is_master(args):
            from sum_eval import MultiProcSumEval
            # Evaluation pool sized by args.ncpu_eval (presumably a number of
            # worker processes -- confirm in sum_eval).
            sum_eval_pool = MultiProcSumEval(args.ncpu_eval)
            # Default evaluation parameters for the validation raw-text files
            # (<raw_valid>.article / <raw_valid>.summary).
            sum_valid_pool_params = dict(
                article_file=args.raw_valid + '.article',
                summary_file=args.raw_valid + '.summary',
                entity_map_file=None,
                length=-1,
                eval_type='predict',
                topk=args.topk_sent_eval,
                rerank=False,
                with_m=False,
                cmd='-a -c 95 -m -n 4 -w 1.2',
                trigram_block=args.trigram_block,
            )

            # Same defaults for the test raw-text files.
            sum_test_pool_params = dict(
                article_file=args.raw_test + '.article',
                summary_file=args.raw_test + '.summary',
                entity_map_file=None,
                length=-1,
                eval_type='predict',
                topk=args.topk_sent_eval,
                rerank=False,
                with_m=False,
                cmd='-a -c 95 -m -n 4 -w 1.2',
                trigram_block=args.trigram_block,
            )
            sum_pool_params = dict(valid=sum_valid_pool_params,
                                   test=sum_test_pool_params)

            def make_params(default_dict,
                            result_file,
                            out_rouge_file,
                            rerank=False,
                            with_m=False):
                # Return a shallow copy of default_dict with the per-run output
                # files and the rerank/with_m flags filled in; default_dict
                # itself is not mutated.
                para_dict = dict(default_dict)
                para_dict['result_file'] = result_file
                para_dict['out_rouge_file'] = out_rouge_file
                para_dict['rerank'] = rerank
                para_dict['with_m'] = with_m
                return para_dict

            # NOTE(review): sum_eval_pool, sum_pool_params and make_params are
            # not used anywhere in the visible code of this function, and they
            # only exist on the master process -- confirm whether the code that
            # consumed them was removed or truncated.

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))
    # print(model)
    import sys
    sys.stdout.flush()

    # if summarization try to load pretrained model
    # if args.task.startswith('extractive_summarization') or args.task == 'pretrain_document_modeling':
    #     # assume this is a single GPU program
    if args.init_from_pretrained_doc_model:
        task.load_pretrained_model(model, args.pretrained_doc_model_path)
    sys.stdout.flush()

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    # NOTE(review): max_positions is not used below in the visible code.
    max_positions = trainer.get_model().max_positions()
    epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=False)

    # Load the latest checkpoint if one is available
    # load_checkpoint(args, trainer, epoch_itr)
    # make sure training from a different checkpoint will use different random seed
    cur_dataset = task.dataset('train')
    if hasattr(cur_dataset, 'rng'):
        print('epoch ', epoch_itr.epoch)
        cur_dataset.rng = numpy.random.RandomState(args.seed + epoch_itr.epoch)

    # Train until the learning rate gets too small
    # NOTE(review): max_epoch, max_update, lr, train_meter and valid_losses are
    # set up like a training loop but are not used in the visible code.
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    # NOTE(review): range(10, 9, -1) yields exactly one value (10), so this
    # loop body runs once and `alpha` is never read in the visible code.
    for alpha in range(10, 9, -1):
        # train for one epoch
        # train(args, trainer, task, epoch_itr)

        # Advance the epoch iterator even though no training step is taken,
        # presumably so epoch_itr.epoch names the evaluated epoch -- confirm.
        epoch_itr.next_epoch_itr()

        if epoch_itr.epoch % args.validate_interval == 0:
            if args.task.startswith('extractive_summarization'):
                if distributed_utils.is_master(args):
                    validate_metric(args, trainer, task, epoch_itr,
                                    valid_subsets)
示例#2
0
def main(args):
    """Train a model until the LR, update or epoch budget is exhausted.

    Defaults ``args.max_tokens`` to 4096 when unset, builds the task, model,
    criterion and trainer, creates the sharded training batch iterator,
    resumes from the latest checkpoint (running one dummy training step when
    ``load_checkpoint`` returns a falsy value), then alternates training,
    periodic validation, LR scheduling and checkpointing. For the ``squad``
    task, ``eval_dataset`` is also run on the ``valid`` split after each save.

    Args:
        args: parsed command-line namespace.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    # NOTE(review): dummy_batch_size is computed here but never read below;
    # get_dummy_batch() is called with args.max_tokens (4096 when unset)
    # rather than dummy_batch_size (1024 when unset). Confirm which size was
    # intended for the dummy batch.
    dummy_batch_size = args.max_tokens
    if args.max_tokens is None:
        args.max_tokens = 4096
        dummy_batch_size = 1024
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits: 'train' plus every comma-separated validation split.
    load_dataset_splits(task, ['train'] + args.valid_subset.split(','))

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))
    # Make a dummy batch to (i) warm the caching allocator and (ii) as a
    # placeholder DistributedDataParallel when there's an uneven number of
    # batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader: one shard per distributed rank.
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    # Load the latest checkpoint if one is available; otherwise run a single
    # dummy step on the placeholder batch.
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    # valid_losses keeps its previous value ([None] initially) on epochs that
    # skip validation, so lr_step may receive None.
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    # Stop once lr <= args.min_lr, or the update or epoch budget is reached.
    while lr > args.min_lr and trainer.get_num_updates(
    ) < max_update and epoch_itr.epoch < max_epoch:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
            # SQuAD additionally gets a full evaluation of the current model
            # on the 'valid' split after every save.
            if args.task == 'squad':
                eval_dataset(task, trainer.get_model(), task.dataset('valid'),
                             args.data_file, args)

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
示例#3
0
文件: train.py 项目: sk210892/fairseq
def main(args):
    """Train a model until the LR, epoch or update budget is exhausted.

    Defaults ``args.max_tokens`` to 6000 when unset, builds the task, model,
    criterion and (FP16 or FP32) trainer, creates the sharded training batch
    iterator, restores the latest checkpoint if one is available, warms up
    with a dummy batch, then alternates training, periodic validation, LR
    scheduling and checkpointing.

    Args:
        args: parsed command-line namespace.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(args, task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        # Devices with compute-capability major version >= 7 get a hint
        # about --fp16.
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader: one shard per distributed rank.
    max_positions = trainer.get_model().max_positions()
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        # Training batches are capped by --max-sentences (the value reported
        # above); --max-sentences-valid is the validation-time limit.
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Send a dummy batch to warm the caching allocator
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    # valid_losses keeps its previous value ([None] initially) on epochs that
    # skip validation, so lr_step may receive None.
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    # train() advances epoch_itr.epoch once per call (presumably through
    # next_epoch_itr() -- confirm), so a strict `<` stops after exactly
    # max_epoch epochs; `<=` would run one extra epoch.
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
示例#4
0
def main(args):
    """Run the MLPerf-instrumented Transformer training loop.

    Sets up CUDA, MLPerf compliance logging and deterministic seeding. When
    enabled, it pre-creates and warms the process groups and CUDA streams
    used by the parallel backward all-reduce. It then logs the MLPerf
    hyper-parameters, builds the task/model/criterion/trainer and runs a
    dummy warm-up step. After the INIT_STOP / RUN_START synchronisation it
    loads the ``train`` and ``test`` splits and trains epoch by epoch. Training
    stops when the learning rate, the epoch/update budget or the target-BLEU
    criterion ends the loop.

    Args:
        args: parsed command-line namespace.

    Raises:
        NotImplementedError: if CUDA is not available.
        AssertionError: (when assertions are enabled) if ``args.lr`` has more
            than one value, if the source and target maximum positions differ,
            or if distributed training is requested but not initialised.
    """
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)

    mlperf_compliance.mlperf_log.LOGGER.propagate = False

    # framework = f'Pytorch NGC {os.environ["NVIDIA_PYTORCH_VERSION"]}'
    # mlperf_submission_log(
    #     benchmark=mlperf_compliance.constants.TRANSFORMER,
    #     framework=framework)

    mlperf_compliance.mlperf_log.setdefault(
        root_dir=os.path.dirname(os.path.abspath(__file__)),
        benchmark=mlperf_compliance.constants.TRANSFORMER,
        stack_offset=1,
        extra_print=False)

    mlperf_print(key=mlperf_compliance.constants.INIT_START,
                 log_all_ranks=True)

    # Seed every RNG in use and force deterministic cuDNN kernel selection.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # preinit and warmup streams/groups for allreduce communicators
    allreduce_communicators = None
    if args.distributed_world_size > 1 and args.enable_parallel_backward_allred_opt:
        allreduce_groups = [
            torch.distributed.new_group()
            for _ in range(args.parallel_backward_allred_cuda_nstreams)
        ]
        allreduce_streams = [
            torch.cuda.Stream()
            for _ in range(args.parallel_backward_allred_cuda_nstreams)
        ]
        # One tiny all-reduce per (group, stream) pair warms up each
        # communicator before training starts.
        for group, stream in zip(allreduce_groups, allreduce_streams):
            with torch.cuda.stream(stream):
                torch.distributed.all_reduce(torch.cuda.FloatTensor(1),
                                             group=group)
        allreduce_communicators = (allreduce_groups, allreduce_streams)

    if args.max_tokens is None:
        args.max_tokens = 6000

    print(args)

    mlperf_print(key=mlperf_compliance.constants.GLOBAL_BATCH_SIZE,
                 value=args.max_tokens * args.distributed_world_size)
    mlperf_print(key=mlperf_compliance.constants.OPT_NAME,
                 value=args.optimizer)
    assert (len(args.lr) == 1)
    mlperf_print(key=mlperf_compliance.constants.OPT_BASE_LR,
                 value=args.lr[0] if len(args.lr) == 1 else args.lr)
    mlperf_print(key=mlperf_compliance.constants.OPT_LR_WARMUP_STEPS,
                 value=args.warmup_updates)
    assert (args.max_source_positions == args.max_target_positions)
    mlperf_print(key=mlperf_compliance.constants.MAX_SEQUENCE_LENGTH,
                 value=args.max_target_positions)
    # NOTE(review): eval() executes arbitrary code taken from the command
    # line; if --adam-betas is always a tuple literal such as "(0.9, 0.98)",
    # ast.literal_eval would be a safer parser.
    adam_betas = eval(args.adam_betas)
    mlperf_print(key=mlperf_compliance.constants.OPT_ADAM_BETA_1,
                 value=adam_betas[0])
    mlperf_print(key=mlperf_compliance.constants.OPT_ADAM_BETA_2,
                 value=adam_betas[1])
    mlperf_print(key=mlperf_compliance.constants.OPT_ADAM_EPSILON,
                 value=args.adam_eps)

    # Limit 0x05 is cudaLimitMaxL2FetchGranularity in the CUDA runtime's
    # cudaLimit enum. cudaDeviceSetLimit takes the value as size_t and
    # cudaDeviceGetLimit writes a size_t through its pointer, so both use
    # c_size_t (a c_int buffer is too small for the 8-byte write on 64-bit
    # platforms). Return codes are not checked.
    pValue = ctypes.cast((ctypes.c_size_t * 1)(),
                         ctypes.POINTER(ctypes.c_size_t))
    torch.cuda.cudart().cudaDeviceSetLimit(ctypes.c_int(0x05),
                                           ctypes.c_size_t(128))
    torch.cuda.cudart().cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args,
                              task,
                              model,
                              criterion,
                              allreduce_communicators=allreduce_communicators)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )

        # NOTE(review): the FP32 trainer never receives the communicators
        # pre-created above -- confirm this is intentional.
        trainer = Trainer(args,
                          task,
                          model,
                          criterion,
                          allreduce_communicators=None)

    #if (args.online_eval or args.target_bleu) and not args.remove_bpe:
    #    args.remove_bpe='@@ '

    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()

    # Send a dummy batch to warm the caching allocator
    dummy_batch = language_pair_dataset.get_dummy_batch_isolated(
        args.max_tokens, max_positions, 8)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch if args.max_epoch >= 0 else math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    # mlperf compliance synchronization
    if args.distributed_world_size > 1:
        assert (torch.distributed.is_initialized())
        torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
        torch.cuda.synchronize()

    mlperf_print(key=mlperf_compliance.constants.INIT_STOP, sync=True)

    mlperf_print(key=mlperf_compliance.constants.RUN_START, sync=True)
    # second sync after RUN_START tag is printed.
    # this ensures no rank touches data until after RUN_START tag is printed.
    barrier()

    # Load dataset splits
    load_dataset_splits(task, ['train', 'test'])

    # Number of completed loop iterations; 0 means "first epoch".
    ctr = 0

    class DummyEpochBatchIterator:
        # Minimal stand-in exposing only `.epoch`, so the loop condition can
        # be evaluated before the first real iterator is built.
        def __init__(self, epoch=0):
            self.epoch = epoch

    epoch_itr = DummyEpochBatchIterator(0)

    # Main training loop.
    # NOTE(review): lr is never updated inside the loop (lr_step is commented
    # out below), so the `lr >= args.min_lr` test is effectively constant.
    while lr >= args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates(
    ) < max_update and current_bleu < tgt_bleu:
        first_epoch = epoch_itr.epoch + 1
        mlperf_print(key=mlperf_compliance.constants.BLOCK_START,
                     metadata={
                         'first_epoch_num': first_epoch,
                         'epoch_count': 1
                     },
                     sync=True)
        mlperf_print(key=mlperf_compliance.constants.EPOCH_START,
                     metadata={'epoch_num': first_epoch},
                     sync=True)
        start = time.time()

        # Automatic garbage collection is paused while the epoch runs and
        # re-enabled after scoring.
        gc.disable()

        # A fresh iterator is built every epoch, resuming from the previous
        # iterator's epoch counter (0 on the first iteration).
        epoch_itr = data.EpochBatchIterator(
            dataset=task.dataset(args.train_subset),
            dataloader_num_workers=args.dataloader_num_workers,
            dataloader_pin_memory=args.enable_dataloader_pin_memory,
            max_tokens=args.max_tokens,
            # Training batches are capped by --max-sentences (the value
            # reported above); --max-sentences-valid is the validation limit.
            max_sentences=args.max_sentences,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            epoch=epoch_itr.epoch if ctr != 0 else 0,
            bucket_growth_factor=args.bucket_growth_factor,
            seq_len_multiple=args.seq_len_multiple,
            batching_scheme=args.batching_scheme,
            batch_multiple_strategy=args.batch_multiple_strategy,
        )

        print("got epoch iterator", time.time() - start)

        # Load the latest checkpoint if one is available (first epoch only).
        if ctr == 0:
            load_checkpoint(args, trainer, epoch_itr)

        # train for one epoch
        start = time.time()
        train(args, trainer, task, epoch_itr)
        print("epoch time ", time.time() - start)

        start = time.time()
        mlperf_print(key=mlperf_compliance.constants.EPOCH_STOP,
                     metadata={'epoch_num': first_epoch},
                     sync=True)

        #if epoch_itr.epoch % args.validate_interval == 0:
        #    valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # Eval BLEU score when online evaluation is requested or a finite
        # target BLEU was given (value comparison, not object identity).
        if args.online_eval or tgt_bleu != math.inf:
            current_bleu = score(args, trainer, task, epoch_itr,
                                 args.gen_subset)
            mlperf_print(key=mlperf_compliance.tags.EVAL_ACCURACY,
                         value=str(current_bleu),
                         metadata={'epoch_num': first_epoch})

        gc.enable()

        # Only use first validation loss to update the learning rate
        #lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        #if epoch_itr.epoch % args.save_interval == 0:
        #    save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        ctr = ctr + 1
        print("validation and scoring ", time.time() - start)
        mlperf_print(key=mlperf_compliance.constants.BLOCK_STOP,
                     metadata={'first_epoch_num': first_epoch},
                     sync=True)

    train_meter.stop()
    status = 'success' if current_bleu >= tgt_bleu else 'aborted'
    mlperf_print(key=mlperf_compliance.constants.RUN_STOP,
                 metadata={'status': status})
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
示例#5
0
文件: train.py 项目: lhu17/translate
def setup_training(args):
    """Build the task, datasets, model and trainer, and restore training state.

    Loads the training split (with an optional weights file from
    ``args.train_weights_path``) and the validation split. It then builds the
    model, criterion and (FP16 or FP32) trainer, and picks a checkpoint path:

    1. ``args.restore_file`` under ``args.save_dir`` if it exists;
    2. otherwise ``args.pretrained_checkpoint_file`` if given and present, in
       which case ``restore_state`` is taken from
       ``args.load_pretrained_checkpoint_state``.

    If no file exists at the chosen path and
    ``args.multi_model_restore_files`` is set, individual models are imported
    from those files instead. Otherwise ``load_existing_checkpoint`` is
    called. When it reports a loaded checkpoint, BLEU is computed on the
    validation split. Finally a training batch iterator is created and
    positioned at the epoch / batch offset recorded in ``extra_state``.

    Args:
        args: parsed command-line namespace. ``args.path`` is overwritten
            with ``[checkpoint_path]`` when a checkpoint is loaded.

    Returns:
        Tuple ``(extra_state, trainer, task, epoch_itr)``.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise NotImplementedError("Training on CPU is not supported")
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task and load dataset
    task = tasks.setup_task(args)
    task.load_dataset(
        args.train_subset,
        args.train_source_binary_path,
        args.train_target_binary_path,
        weights_file=getattr(args, "train_weights_path", None),
    )
    task.load_dataset(args.valid_subset, args.eval_source_binary_path,
                      args.eval_target_binary_path)

    # Build model and criterion
    model = task.build_model(args)
    print("| building criterion")
    criterion = task.build_criterion(args)
    print(f"| model {args.arch}, criterion {criterion.__class__.__name__}")
    # The messages below avoid backslash continuations inside f-strings:
    # those embed the next source line's indentation in the printed text.
    num_params = sum(p.numel() for p in model.parameters())
    print(f"| num. model params: {num_params}")

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                "| NOTICE: your device may support faster training with --fp16"
            )
        trainer = Trainer(args, task, model, criterion)
    print(f"| training on {args.distributed_world_size} GPUs")
    print(
        f"| max tokens per GPU = {args.max_tokens} and "
        f"max sentences per GPU = {args.max_sentences}",
        flush=True,
    )

    os.makedirs(args.save_dir, exist_ok=True)

    # If --restore-file is already present under --save-dir, use that one
    # instead of --pretrained-checkpoint-file. --pretrained-checkpoint-file
    # lets the user restore from a different run's checkpoint (possibly with
    # different training params) without polluting that run's checkpoint
    # directory with new checkpoints. If training is then interrupted and
    # restarted, we want to resume from the checkpoints under --save-dir
    # instead of restarting from the old run's checkpoint.
    #
    # Note that if args.restore_file is an absolute path, os.path.join() will
    # ignore previous directory args and just use the absolute path as is.
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    restore_state = True
    if os.path.exists(checkpoint_path):
        print(
            f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}."
        )
    elif args.pretrained_checkpoint_file and os.path.exists(
            args.pretrained_checkpoint_file):
        checkpoint_path = args.pretrained_checkpoint_file
        restore_state = args.load_pretrained_checkpoint_state
        print(
            f"| Using --pretrained-checkpoint-file={args.pretrained_checkpoint_file}, "
            f"--load-pretrained-checkpoint-state={args.load_pretrained_checkpoint_state}."
        )

    extra_state = default_extra_state(args)
    if not os.path.isfile(checkpoint_path) and args.multi_model_restore_files:
        print(
            f"| Restoring individual models from {args.multi_model_restore_files}"
        )
        multi_model.import_individual_models(args.multi_model_restore_files,
                                             trainer)
    else:
        loaded, loaded_extra_state = load_existing_checkpoint(
            checkpoint_path=checkpoint_path,
            trainer=trainer,
            restore_state=restore_state,
        )
        # Values stored in the checkpoint override the defaults.
        if loaded_extra_state:
            extra_state.update(loaded_extra_state)
        if loaded:
            args.path = [checkpoint_path]
            calculate_bleu_on_subset(
                args=args,
                task=task,
                epoch_str="initial loaded checkpoint",
                offset=None,
                dataset_split=args.valid_subset,
            )
    print(f"| extra_state: {extra_state}")

    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=trainer.get_model().max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )
    epoch = extra_state["epoch"]
    if extra_state["batch_offset"] == 0:
        epoch -= 1  # this will be incremented when we call epoch_itr.next_epoch_itr()
    epoch_itr.load_state_dict({
        "epoch": epoch,
        "iterations_in_epoch": extra_state["batch_offset"],
    })

    return extra_state, trainer, task, epoch_itr
示例#6
0
def main(args, init_distributed=False):
    """Train a model, with optional NSML session handling and W&B logging.

    Sets up CUDA/seeding (and Horovod-style distributed init when requested),
    loads the validation splits, and builds the model and criterion,
    optionally freezing every non-decoder parameter. For ``bert``
    architectures it loads a pretrained checkpoint into both the model and a
    separate baseline model handed to the task. It optionally restores
    weights from an NSML session, including uniform averaging over an
    ``epochA-epochB`` checkpoint range. It then validates once and trains
    until the LR, epoch or update budget is exhausted, saving per-epoch
    weights through NSML or to ``/tmp`` (also uploaded to W&B when available).

    Args:
        args: parsed command-line namespace.
        init_distributed: when True, ``distributed_init_hvd(args)`` is called
            and its result stored in ``args.distributed_rank``.

    Raises:
        AssertionError: if neither ``--max-tokens`` nor ``--max-sentences``
            is given, or if decoder-wise training is requested without the
            ``c10d`` DDP backend.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_init_hvd(args)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    # Optionally freeze every parameter whose name does not contain "decoder".
    if args.train_decoder_only:
        for name, param in model.named_parameters():
            if "decoder" not in name:
                param.requires_grad_(False)

    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Setup session
    if HAS_WANDB and distributed_utils.is_master(args):
        wandb.init(project="cmlm", config=args)
        wandb.watch(model)

    # Load pre-trained model
    data_token = args.data[0].split("/")[-1]
    if "bert" in args.arch:
        pretrained_path = "{}/train/pretrained_models/maskPredict_{}/checkpoint_best.pt".format(
            DATASET_PATH,
            data_token.split(".")[-1].replace("-", "_"))
        if not HAS_NSML:
            pretrained_path = pretrained_path.replace("/train", "")
        print("| loading", pretrained_path)
        state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_path)
        model.load_state_dict(state["model"], strict=True)
        # A second, independent copy of the pretrained weights is handed to
        # the task as its baseline model.
        baseline_model = task.build_model(args)
        baseline_model.load_state_dict(state["model"], strict=True)
        if torch.cuda.is_available():
            baseline_model.cuda()
        task.set_baseline_model(baseline_model)

    if not args.masking and HAS_NSML:

        def nsml_bind(model):
            # Register save/load callbacks that (de)serialise the model
            # weights as <dir_path>/best.pt.
            def save(dir_path):
                state = {
                    'model': model.state_dict(),
                }
                torch.save(state, os.path.join(dir_path, 'best.pt'))

            def load(dir_path):
                state = torch.load(os.path.join(dir_path, 'best.pt'),
                                   map_location="cpu")
                model.load_state_dict(state['model'], strict=False)
                model.cuda()
                print('model loaded!')

            nsml.bind(save=save, load=load)

        nsml_bind(model)

    if args.load:
        print("loading model from session", args.load)
        # Accept both "nsml://<session>[/<checkpoint>.pt]" and the bare
        # "<session>[/<checkpoint>.pt]" form; `session` must be bound in
        # either case (it used to be bound only for the nsml:// form).
        session = args.load
        if session.startswith("nsml://"):
            session = session.replace("nsml://", "")
        if ".pt" in session:
            session = session.replace(".pt", "")
            session, checkpoint_name = session.rsplit("/", 1)
        else:
            checkpoint_name = "best"
        if "-" in checkpoint_name:
            # "epochA-epochB": uniformly average checkpoints epochA..epochB
            # (inclusive) by summing each parameter divided by their count.
            start, end = checkpoint_name.replace("epoch", "").split("-")
            checkpoints = [
                "epoch{}".format(i) for i in range(int(start),
                                                   int(end) + 1)
            ]
            print("| checkpoint average:", checkpoints)
            state_dict = None

            def load(dir_path):
                nonlocal state_dict, checkpoints
                state = torch.load(os.path.join(dir_path, 'best.pt'))
                model_state = state["model"]
                for k in model_state:
                    model_state[k] = model_state[k] / float(len(checkpoints))
                if state_dict is None:
                    state_dict = model_state
                else:
                    for k in state_dict:
                        state_dict[k] += model_state[k]
                print("checkpoint loaded")

            for checkpoint_name in checkpoints:
                nsml.load(checkpoint_name, load_fn=load, session=session)
            model.load_state_dict(state_dict)
        else:

            def load(dir_path):
                state = torch.load(os.path.join(dir_path, 'best.pt'))
                state_dict = state["model"]
                model.load_state_dict(state_dict)
                print("loaded")

            nsml.load(checkpoint_name, load_fn=load, session=session)

    # Prepare for decoder wise training
    if args.decoder_wise_training:
        print("| Decoder wise training, start refinement step 0")
        progressive_training_step = 0
        assert args.ddp_backend == "c10d"
    else:
        progressive_training_step = None

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    # With --progressive, validate once per refinement step before training
    # (one step fewer when args.pnet is set).
    if hasattr(args, "progressive") and args.progressive:
        for i in range(args.refinetot if not getattr(args, "pnet", False) else
                       args.refinetot - 1):
            print("validating for refine step", i)
            validate(args,
                     trainer,
                     task,
                     epoch_itr,
                     valid_subsets,
                     force_refine_step=i)
        print("---")
    validate(args, trainer, task, epoch_itr, valid_subsets)
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates(
    ) < max_update:
        # train for one epoch
        train(args,
              trainer,
              task,
              epoch_itr,
              force_refine_step=progressive_training_step)
        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(
                args,
                trainer,
                task,
                epoch_itr,
                valid_subsets,
                force_refine_step=progressive_training_step)
        else:
            valid_losses = [None]

        if args.decoder_wise_training:
            progressive_training_step = update_num_to_refine_step(
                trainer.get_num_updates())

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            if HAS_NSML:
                if distributed_utils.is_master(args):
                    print("nsml save for epoch", epoch_itr.epoch)
                    nsml.save("epoch{}".format(epoch_itr.epoch))
            else:
                torch.save({"model": trainer.get_model().state_dict()},
                           "/tmp/epoch{}.pt".format(epoch_itr.epoch))
                if HAS_WANDB:
                    wandb.save("/tmp/epoch{}.pt".format(epoch_itr.epoch))
                # checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # NOTE(review): args.data is indexed like a sequence above
        # (args.data[0]); if it is a list, this tests list membership of ':'
        # rather than a ':' inside the path string -- confirm the shard check.
        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
示例#7
0
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses.

    Args:
        cfg: full training config; the ``dataset``, ``common``,
            ``distributed_training`` and ``checkpoint`` sections are read.
        trainer: trainer whose ``valid_step`` is run on every batch.
        task: task object; its optional ``post_validate`` hook is called
            with the model, the computed stats and the metrics aggregator.
        epoch_itr: training epoch iterator; only ``epoch_itr.epoch`` is read
            (passed to ``begin_valid_epoch`` and the progress bar).
        subsets: names of the validation subsets to evaluate, in order.

    Returns:
        One entry per subset, in the order of ``subsets``:
        ``stats[cfg.checkpoint.best_checkpoint_metric]`` for that subset.
    """

    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset_idx, subset in enumerate(subsets):
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False,
            set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(cfg.common.tensorboard_logdir
                                if distributed_utils.is_master(
                                    cfg.distributed_training) else None),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else
                                "simple"),
            wandb_project=(cfg.common.wandb_project
                           if distributed_utils.is_master(
                               cfg.distributed_training) else None),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                # ``max_valid_steps`` caps the number of batches evaluated.
                # Batches are 0-indexed, so stop once ``i`` reaches the cap;
                # the previous ``i > max_valid_steps`` ran one extra batch.
                if (cfg.dataset.max_valid_steps is not None
                        and i >= cfg.dataset.max_valid_steps):
                    break
                trainer.valid_step(sample)

        # log validation stats
        # only tracking the best metric on the 1st validation subset
        tracking_best = subset_idx == 0
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values(),
                                tracking_best)

        if hasattr(task, "post_validate"):
            task.post_validate(trainer.get_model(), stats, agg)

        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses
示例#8
0
def main(args):
    """Train a translation model until a stopping criterion is met.

    Training stops when the learning rate falls below ``args.min_lr``, when
    ``args.max_epoch`` / ``args.max_update`` is reached, or when the BLEU
    returned by ``score`` reaches ``args.target_bleu``.  A checkpoint is
    saved whenever BLEU improves (if BLEU is evaluated) and every
    ``args.save_interval`` epochs.

    Args:
        args: parsed training arguments.  ``max_tokens`` (defaulted to 6000
            when unset) and ``remove_bpe`` (set to ``'@@ '`` when BLEU
            evaluation is requested and it is unset) are modified in place.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    if args.distributed_world_size > 1:
        # Collective broadcast + synchronize; presumably acts as a start-up
        # barrier across ranks.
        assert (torch.distributed.is_initialized())
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
        torch.cuda.synchronize()
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)
    # Set CUDA device limit 0x05 to 128.  0x05 is presumably
    # cudaLimitMaxL2FetchGranularity ("L2 sector promotion") -- confirm
    # against the CUDA runtime headers.  The value read back is not used.
    limit_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    torch.cuda.cudart().cudaDeviceSetLimit(ctypes.c_int(0x05),
                                           ctypes.c_int(128))
    torch.cuda.cudart().cudaDeviceGetLimit(limit_value, ctypes.c_int(0x05))
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer = Trainer(args, task, model, criterion)
    # When BLEU is evaluated, default remove_bpe to the '@@ ' marker
    # (presumably so BLEU is scored on de-BPE'd text).
    if (args.online_eval or args.target_bleu) and not args.remove_bpe:
        args.remove_bpe = '@@ '
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))
    max_positions = trainer.get_model().max_positions()
    # NOTE(review): the training iterator is capped by
    # ``max_sentences_valid`` while the log line above reports
    # ``max_sentences`` -- confirm this is intended.
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )
    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Send a dummy batch to warm the caching allocator
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    best_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    while (lr >= args.min_lr
           and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update
           and current_bleu < tgt_bleu):
        # train for one epoch
        train(args, trainer, task, epoch_itr)
        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # Eval BLEU score (only when requested or a target BLEU is set)
        if args.online_eval or tgt_bleu is not math.inf:
            # score() returns a pair in this variant; only the first value
            # is used.
            current_bleu, _ = score(args, trainer, task, epoch_itr,
                                    args.gen_subset)
            if current_bleu > best_bleu:
                best_bleu = current_bleu
                save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # Only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
示例#9
0
def train(
    cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses.

    After every step ``validate_and_save`` is called; the epoch ends early
    when it asks to stop.

    Args:
        cfg: full training config.
        trainer: trainer performing the optimisation steps.
        task: task object; its optional ``update_state_prior`` hook is called
            after every step.
        epoch_itr: epoch batch iterator; a fresh epoch is drawn from it.

    Returns:
        ``(valid_losses, should_stop)`` from the last ``validate_and_save``
        call.  If the epoch produced no batches, ``([None], False)``.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    # Per-epoch gradient accumulation factor; the last entry is reused once
    # the epoch number exceeds the configured list.
    update_freq = (
        cfg.optimization.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(cfg.optimization.update_freq)
        else cfg.optimization.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            cfg.common.tensorboard_logdir
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(
            cfg.common.wandb_project
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        ),
        azureml_logging=(
            cfg.common.azureml_logging
            if distributed_utils.is_master(cfg.distributed_training)
            else False
        ),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    if hasattr(trainer.criterion, "set_epoch"):
        trainer.criterion.set_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    # Bound before the loop so an epoch with no batches still returns a
    # value instead of raising UnboundLocalError.
    valid_losses = [None]
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        # update the state prior stored in the model for cross-entropy training of hybrid systems
        if hasattr(task, "update_state_prior"):
            task.update_state_prior(trainer.get_model())

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
示例#10
0
def main(args):
    """Train a translation model while emitting MLPerf compliance log tags.

    The batch iterator is rebuilt at the start of every epoch, and the latest
    checkpoint is loaded only on the first pass.  Training stops when the
    learning rate falls below ``args.min_lr``, when ``args.max_epoch`` /
    ``args.max_update`` is reached, or when the BLEU returned by ``score``
    reaches ``args.target_bleu``.

    Args:
        args: parsed training arguments.  ``max_tokens`` (defaulted to 6000
            when unset) and ``remove_bpe`` (set to ``'@@ '`` when BLEU
            evaluation is requested and it is unset) are modified in place.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    import time

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    from mlperf_compliance.mlperf_log import transformer_print
    # before this tag we should run clearing caches on the host
    transformer_print(key=mlperf_log.RUN_CLEAR_CACHES)
    # mlperf compliance synchronization
    if args.distributed_world_size > 1:
        assert (torch.distributed.is_initialized())
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
        torch.cuda.synchronize()
    transformer_print(key=mlperf_log.RUN_START)
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)
    transformer_print(key=mlperf_log.OPT_NAME, value=args.optimizer)
    transformer_print(key=mlperf_log.OPT_LR, value=args.lr)
    # SECURITY NOTE(review): eval() executes arbitrary code taken from the
    # command line.  args.adam_betas is presumably a tuple literal string,
    # for which ast.literal_eval would suffice -- consider switching.
    # Evaluated once and reused for both betas.
    adam_betas = eval(args.adam_betas)
    transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA1, value=adam_betas[0])
    transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA2, value=adam_betas[1])
    transformer_print(key=mlperf_log.OPT_HP_ADAM_EPSILON, value=args.adam_eps)
    # Set CUDA device limit 0x05 to 128.  0x05 is presumably
    # cudaLimitMaxL2FetchGranularity ("L2 sector promotion") -- confirm
    # against the CUDA runtime headers.  The value read back is not used.
    limit_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    torch.cuda.cudart().cudaDeviceSetLimit(ctypes.c_int(0x05),
                                           ctypes.c_int(128))
    torch.cuda.cudart().cudaDeviceGetLimit(limit_value, ctypes.c_int(0x05))
    torch.manual_seed(args.seed)
    transformer_print(key=mlperf_log.RUN_SET_RANDOM_SEED, value=args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)
    transformer_print(key=mlperf_log.MODEL_HP_SEQ_BEAM_SEARCH,
                      value={
                          'alpha': args.lenpen,
                          'beam_size': args.beam,
                          'extra_decode_length': args.max_len_b,
                          'vocab_size': len(task.target_dictionary)
                      })

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer = Trainer(args, task, model, criterion)
    # When BLEU is evaluated, default remove_bpe to the '@@ ' marker
    # (presumably so BLEU is scored on de-BPE'd text).
    if (args.online_eval or args.target_bleu) and not args.remove_bpe:
        args.remove_bpe = '@@ '
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))
    transformer_print(key=mlperf_log.INPUT_BATCH_SIZE, value=args.max_tokens)
    transformer_print(key=mlperf_log.INPUT_ORDER)
    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()

    # Send a dummy batch to warm the caching allocator
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    # Number of completed loop iterations; 0 means "first pass".
    ctr = 0

    class DummyEpochBatchIterator:
        """Placeholder exposing only ``epoch`` so the loop condition can be
        evaluated before the first real iterator is built."""

        def __init__(self, epoch=0):
            self.epoch = epoch

    epoch_itr = DummyEpochBatchIterator(0)
    transformer_print(key=mlperf_log.TRAIN_LOOP)
    while (lr >= args.min_lr
           and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update
           and current_bleu < tgt_bleu):
        transformer_print(key=mlperf_log.TRAIN_EPOCH, value=epoch_itr.epoch)
        start = time.time()
        # The iterator is rebuilt every pass.  The first pass starts from
        # epoch 0; later passes carry over the previous iterator's epoch.
        epoch_itr = data.EpochBatchIterator(
            dataset=task.dataset(args.train_subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            epoch=epoch_itr.epoch if ctr != 0 else 0)
        print("got epoch iterator", time.time() - start)

        # Load the latest checkpoint if one is available (first pass only)
        if ctr == 0:
            load_checkpoint(args, trainer, epoch_itr)

        # train for one epoch
        start = time.time()
        train(args, trainer, task, epoch_itr)
        print("epoch time ", time.time() - start)

        start = time.time()

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # Eval BLEU score
        transformer_print(key=mlperf_log.EVAL_START, value=epoch_itr.epoch)
        if args.online_eval or tgt_bleu is not math.inf:
            current_bleu = score(args, trainer, task, epoch_itr,
                                 args.gen_subset)
            transformer_print(key=mlperf_log.EVAL_ACCURACY,
                              value={
                                  'epoch': epoch_itr.epoch,
                                  'value': current_bleu
                              })
            transformer_print(key=mlperf_log.EVAL_TARGET, value=tgt_bleu)
        transformer_print(key=mlperf_log.EVAL_STOP, value=epoch_itr.epoch)

        # Only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        ctr = ctr + 1
        print("validation and scoring ", time.time() - start)

    train_meter.stop()
    transformer_print(key=mlperf_log.RUN_STOP)
    transformer_print(key=mlperf_log.RUN_FINAL)
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
def setup():
    """Prepare everything needed to run attacks against a trained model.

    Parses the training command line, forces fresh optimizer/meter/
    dataloader/LR-scheduler state, loads exactly one model from
    ``args.restore_file``, and installs gradient hooks on its embeddings.
    Unless ``args.interactive_attacks`` is set, the validation data is also
    loaded and a batch iterator over the first validation subset is built.

    Returns:
        ``(args, trainer, generator, embedding_weight, itr, bpe)``.  ``itr``
        is either the validation batch iterator or a 100000-element list of
        ``None`` used as a placeholder for interactive attacks.
    """
    args = options.parse_args_and_arch(options.get_training_parser())
    # Reset all training state before the model is loaded.
    for reset_flag in ('reset_optimizer', 'reset_meters',
                       'reset_dataloader', 'reset_lr_scheduler'):
        setattr(args, reset_flag, True)
    args.path = args.restore_file
    args.max_sentences_valid = 1  # attacks currently use batch size 1
    args.beam = 1  # beam size 1 for inference on the model, could use higher
    utils.import_user_module(args)

    torch.manual_seed(args.seed)

    # Build the task, then load the validation splits it will iterate over.
    task = tasks.setup_task(args)
    if not args.interactive_attacks:
        for split_name in args.valid_subset.split(','):
            task.load_dataset(split_name, combine=False, epoch=0)

    checkpoint_paths = args.path.split(':')
    loaded_models, _ = checkpoint_utils.load_model_ensemble(
        checkpoint_paths, arg_overrides={}, task=task)
    # An ensemble is not supported: exactly one checkpoint must be given.
    assert len(loaded_models) == 1
    model = loaded_models[0]

    use_gpu = torch.cuda.is_available() and not args.cpu
    if use_gpu:
        assert torch.cuda.device_count() == 1  # single-GPU only for now
        torch.cuda.set_device(0)
        model.cuda()
    args.beam = 1  # beam size 1 for now
    model.make_generation_fast_(beamable_mm_beam_size=args.beam,
                                need_attn=False)

    criterion = task.build_criterion(args)
    trainer = Trainer(args, task, model, criterion)
    generator = task.build_generator(args)

    attacked_model = trainer.get_model()
    bpe_vocab_size = attacked_model.encoder.embed_tokens.weight.shape[0]
    # Hook the embedding gradients, then keep a handle on the embedding matrix.
    add_hooks(attacked_model, bpe_vocab_size)
    embedding_weight = get_embedding_weight(attacked_model, bpe_vocab_size)

    if args.interactive_attacks:
        # Placeholder "dataset"; its entries are overwritten interactively.
        itr = [None] * 100000
    else:
        # Only the first validation subset is handled.
        first_subset = args.valid_subset.partition(',')[0]
        batch_iterator = trainer.task.get_batch_iterator(
            dataset=trainer.task.dataset(first_subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                trainer.task.max_positions(),
                attacked_model.max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        )
        itr = batch_iterator.next_epoch_itr(shuffle=False)

    # Handle BPE
    bpe = encoders.build_bpe(args)
    assert bpe is not None
    return args, trainer, generator, embedding_weight, itr, bpe
示例#12
0
def main(args):
    """Train a model, with extra pre-training stages for style transfer.

    For ``args.task == 'style_transfer'``:

    1. train with the 'classification' criterion until epoch
       ``args.pre_train_max_epoch``;
    2. train with 'style_transfer_pretrain' until epoch
       ``2 * args.pre_train_max_epoch``;
    3. load the 'plain' splits and train with 'style_transfer_train'.

    Any other task skips stages 1-2 and trains with
    ``task.build_criterion(args, None)``.  Validation always goes through a
    separate 'label_smoothed_cross_entropy' trainer that shares ``model``.

    Args:
        args: parsed training arguments; ``max_tokens`` defaults to 6000.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Validation
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    val_criterion = task.build_criterion(args, 'label_smoothed_cross_entropy')
    val_trainer = Trainer(args, task, model, val_criterion)

    # Only the style-transfer path builds these.  Neutral defaults keep the
    # shared training stage below usable for every other task (it previously
    # failed with NameError on them).
    epoch_itr = None
    src_plain_epoch_iter = None
    trg_plain_epoch_iter = None
    pre_train_max_epoch = 0

    # Pre-training on CNN discriminator and Seq2Seq reconstruction
    if args.task == 'style_transfer':
        # classification pretrain
        criterion = task.build_criterion(args, 'classification')
        trainer = Trainer(args, task, model, criterion)
        print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
        max_positions = trainer.get_model().max_positions()
        epoch_itr = data.EpochBatchIterator(
            dataset=task.dataset('train'),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        )

        # Load the latest checkpoint if one is available
        load_checkpoint(args, trainer, epoch_itr, load_optim=True, find_best=args.restore_best)

        max_epoch = args.pre_train_max_epoch
        while epoch_itr.epoch < max_epoch:
            # train for one epoch
            train(args, trainer, task, epoch_itr)
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

            # save to checkpoint
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        print("Done classification pretrain")

        # MT pretrain
        criterion = task.build_criterion(args, 'style_transfer_pretrain')
        trainer = Trainer(args, task, model, criterion)
        print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

        # Load the latest checkpoint if one is available.  Entering this stage
        # fresh: start from the best classification checkpoint without its
        # optimizer state, and reset the best-checkpoint tracker.
        if epoch_itr.epoch <= args.pre_train_max_epoch:
            load_checkpoint(args, trainer, epoch_itr, load_optim=False, find_best=True)
            epoch_itr.epoch = args.pre_train_max_epoch
            save_checkpoint.best = float("inf")
        else:
            load_checkpoint(args, trainer, epoch_itr, load_optim=True, find_best=args.restore_best)

        # Send a dummy batch to warm the caching allocator
        dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
        trainer.dummy_train_step(dummy_batch)

        max_epoch = 2 * args.pre_train_max_epoch
        while epoch_itr.epoch < max_epoch:
            # train for one epoch
            train(args, trainer, task, epoch_itr)
            valid_losses = validate(args, val_trainer, task, epoch_itr, valid_subsets)

            # save to checkpoint
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        print("Done MT pretrain")

    # Training
    if args.task == 'style_transfer':
        criterion_name = "style_transfer_train"
        print("Loading plain data")
        load_dataset_splits(task, ['plain'])
    else:
        criterion_name = None
    criterion = task.build_criterion(args, criterion_name)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

    trainer = Trainer(args, task, model, criterion)

    if epoch_itr is None:
        # No pre-training stage ran, so the training iterator was never built.
        epoch_itr = data.EpochBatchIterator(
            dataset=task.dataset('train'),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=trainer.get_model().max_positions(),
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        )

    # Load the latest checkpoint if one is available.  While still within the
    # pre-training epoch range, start from the best checkpoint without its
    # optimizer state; otherwise resume with the optimizer.
    if epoch_itr.epoch <= 2 * getattr(args, 'pre_train_max_epoch', 0):
        load_checkpoint(args, trainer, epoch_itr, load_optim=False,
                        fix_discriminator=True, find_best=True)
    else:
        load_checkpoint(args, trainer, epoch_itr, load_optim=True,
                        fix_discriminator=True, find_best=args.restore_best)
        print("# WARNING:  Loading checkpoint with optimizer")

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()
    if args.task == 'style_transfer':
        src_plain_epoch_iter = data.EpochBatchIterator(
            dataset=task.dataset('plain')[0],
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        )
        trg_plain_epoch_iter = data.EpochBatchIterator(
            dataset=task.dataset('plain')[1],
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        )
        # Epochs already consumed by pre-training; max_epoch counts on top.
        pre_train_max_epoch = 2 * args.pre_train_max_epoch

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()

    while lr > args.min_lr and epoch_itr.epoch < (max_epoch + pre_train_max_epoch) and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr,
              use_plain=(args.task == 'style_transfer'),
              src_plain_epoch_iter=src_plain_epoch_iter,
              trg_plain_epoch_iter=trg_plain_epoch_iter,
              )

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(
                args, val_trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
示例#13
0
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)

    mllog.config(filename=os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'transformer.log'))
    mllogger = mllog.get_mllogger()
    mllogger.logger.propagate = False

    log_start(key=constants.INIT_START, log_all_ranks=True)

    # preinit and warmup streams/groups for allreduce communicators
    allreduce_communicators = None
    if args.distributed_world_size > 1 and args.enable_parallel_backward_allred_opt:
        allreduce_groups = [
            torch.distributed.new_group()
            for _ in range(args.parallel_backward_allred_cuda_nstreams)
        ]
        allreduce_streams = [
            torch.cuda.Stream()
            for _ in range(args.parallel_backward_allred_cuda_nstreams)
        ]
        for group, stream in zip(allreduce_groups, allreduce_streams):
            with torch.cuda.stream(stream):
                torch.distributed.all_reduce(torch.cuda.FloatTensor(1),
                                             group=group)
        allreduce_communicators = (allreduce_groups, allreduce_streams)

    if args.max_tokens is None:
        args.max_tokens = 6000

    print(args)

    log_event(key=constants.GLOBAL_BATCH_SIZE,
              value=args.max_tokens * args.distributed_world_size)
    log_event(key=constants.OPT_NAME, value=args.optimizer)
    assert (len(args.lr) == 1)
    log_event(key=constants.OPT_BASE_LR,
              value=args.lr[0] if len(args.lr) == 1 else args.lr)
    log_event(key=constants.OPT_LR_WARMUP_STEPS, value=args.warmup_updates)
    assert (args.max_source_positions == args.max_target_positions)
    log_event(key=constants.MAX_SEQUENCE_LENGTH,
              value=args.max_target_positions,
              metadata={'method': 'discard'})
    log_event(key=constants.OPT_ADAM_BETA_1, value=eval(args.adam_betas)[0])
    log_event(key=constants.OPT_ADAM_BETA_2, value=eval(args.adam_betas)[1])
    log_event(key=constants.OPT_ADAM_EPSILON, value=args.adam_eps)
    log_event(key=constants.SEED, value=args.seed)

    # L2 Sector Promotion
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    result = ctypes.CDLL('libcudart.so').cudaDeviceSetLimit(
        ctypes.c_int(0x05), ctypes.c_int(128))
    result = ctypes.CDLL('libcudart.so').cudaDeviceGetLimit(
        pValue, ctypes.c_int(0x05))

    worker_seeds, shuffling_seeds = setup_seeds(
        args.seed,
        args.max_epoch + 1,
        torch.device('cuda'),
        args.distributed_rank,
        args.distributed_world_size,
    )
    worker_seed = worker_seeds[args.distributed_rank]
    print(
        f'Worker {args.distributed_rank} is using worker seed: {worker_seed}')
    torch.manual_seed(worker_seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        if args.distributed_weight_update != 0:
            from fairseq.fp16_trainer import DistributedFP16Trainer
            trainer = DistributedFP16Trainer(
                args,
                task,
                model,
                criterion,
                allreduce_communicators=allreduce_communicators)
        else:
            from fairseq.fp16_trainer import FP16Trainer
            trainer = FP16Trainer(
                args,
                task,
                model,
                criterion,
                allreduce_communicators=allreduce_communicators)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )

        trainer = Trainer(args,
                          task,
                          model,
                          criterion,
                          allreduce_communicators=None)

    #if (args.online_eval or args.target_bleu) and not args.remove_bpe:
    #    args.remove_bpe='@@ '

    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()

    # Send a dummy batch to warm the caching allocator
    dummy_batch = language_pair_dataset.get_dummy_batch_isolated(
        args.max_tokens, max_positions, 8)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch if args.max_epoch >= 0 else math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]

    # mlperf compliance synchronization
    if args.distributed_world_size > 1:
        assert (torch.distributed.is_initialized())
        torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
        torch.cuda.synchronize()

    log_end(key=constants.INIT_STOP, sync=False)

    log_start(key=constants.RUN_START, sync=True)
    # second sync after RUN_START tag is printed.
    # this ensures no rank touches data until after RUN_START tag is printed.
    barrier()

    # Load dataset splits
    load_dataset_splits(task, ['train', 'test'])

    log_event(key=constants.TRAIN_SAMPLES,
              value=len(task.dataset(args.train_subset)),
              sync=False)
    log_event(key=constants.EVAL_SAMPLES,
              value=len(task.dataset(args.gen_subset)),
              sync=False)

    ctr = 0

    start = time.time()
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_pin_memory=args.enable_dataloader_pin_memory,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seeds=shuffling_seeds,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        epoch=epoch_itr.epoch if ctr is not 0 else 0,
        bucket_growth_factor=args.bucket_growth_factor,
        seq_len_multiple=args.seq_len_multiple,
        batching_scheme=args.batching_scheme,
        batch_multiple_strategy=args.batch_multiple_strategy,
    )
    print("got epoch iterator", time.time() - start)

    # Main training loop
    while lr >= args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates(
    ) < max_update and current_bleu < tgt_bleu:
        first_epoch = epoch_itr.epoch + 1
        log_start(key=constants.BLOCK_START,
                  metadata={
                      'first_epoch_num': first_epoch,
                      'epoch_count': 1
                  },
                  sync=False)
        log_start(key=constants.EPOCH_START,
                  metadata={'epoch_num': first_epoch},
                  sync=False)

        gc.disable()

        # Load the latest checkpoint if one is available
        if ctr is 0:
            load_checkpoint(args, trainer, epoch_itr)

        # train for one epoch
        start = time.time()
        #exit(1)
        train(args, trainer, task, epoch_itr, shuffling_seeds)
        print("epoch time ", time.time() - start)

        start = time.time()
        log_end(key=constants.EPOCH_STOP,
                metadata={'epoch_num': first_epoch},
                sync=False)

        # Eval BLEU score
        if args.online_eval or (not tgt_bleu is math.inf):
            current_bleu = score(args, trainer, task, epoch_itr,
                                 args.gen_subset)
            log_event(key=constants.EVAL_ACCURACY,
                      value=float(current_bleu) / 100.0,
                      metadata={'epoch_num': first_epoch})

        gc.enable()

        # Only use first validation loss to update the learning rate
        #lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        #if epoch_itr.epoch % args.save_interval == 0:
        #    save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        ctr = ctr + 1
        print("validation and scoring ", time.time() - start)
        log_end(key=constants.BLOCK_STOP,
                metadata={'first_epoch_num': first_epoch},
                sync=False)

    train_meter.stop()
    status = 'success' if current_bleu >= tgt_bleu else 'aborted'
    log_end(key=constants.RUN_STOP, metadata={'status': status})
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
示例#14
0
def main(args):
    """Restore the latest checkpoint and export a low-rank approximation of the
    decoder's token-embedding matrix.

    Despite its name this entry point does not train. It builds the task,
    model, criterion, trainer and training iterator so that
    ``load_checkpoint`` can restore weights. It then passes the decoder
    embedding table to ``low_rank_approx(..., 2)`` and writes the first value
    returned to a space-delimited text file named ``svd`` in the current
    working directory.

    Args:
        args: parsed command-line namespace; ``args.max_tokens`` is set to
            6000 in place when it is ``None``.

    Raises:
        NotImplementedError: when CUDA is unavailable.
    """
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    # GPU-only script.
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Task first, then the splits it will serve.
    task = tasks.setup_task(args)
    load_dataset_splits(task, ['train', 'valid'])

    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(
        args.arch, criterion.__class__.__name__))
    param_count = sum(param.numel() for param in model.parameters())
    print('| num. model params: {}'.format(param_count))

    # Both trainer classes share the same constructor call; only choose one.
    if args.fp16:
        trainer_cls = FP16Trainer
    else:
        major_capability = torch.cuda.get_device_capability(0)[0]
        if major_capability >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer_cls = Trainer
    trainer = trainer_cls(args, task, model, criterion)

    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens, args.max_sentences))

    # The iterator is only consumed by load_checkpoint below.
    max_positions = trainer.get_model().max_positions()
    train_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        # NOTE(review): the *valid* sentence cap is used for the train
        # iterator (as in the upstream script); presumably harmless here since
        # no training happens, but confirm.
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )
    load_checkpoint(args, trainer, train_itr)

    # Pull the decoder embedding table to host memory as a NumPy array.
    embed_table = model.decoder.embed_tokens.weight.data.cpu().numpy()
    print(embed_table.shape)

    approx, _ = low_rank_approx(embed_table, 2)
    print(approx.shape)

    np.savetxt('svd', approx, delimiter=' ')
示例#15
0
def sari_validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask,
                  epoch_itr, subsets: List[str]) -> List[Optional[float]]:
    """Simplify the TurkCorpus validation set with the current model and score it.

    Each line of ``turkcorpus.valid.complex`` is decoded with the model held by
    ``trainer`` (beam size 2). The detokenized hypotheses are written to a
    temporary file and scored against the 8 ``simple.turk`` reference files
    with ``easse``'s ``get_all_scores``. The temporary file is always removed
    afterwards.

    SARI early-stopping state is stored on the trainer itself:
    ``trainer.best_sari`` and ``trainer.n_validations_since_best`` are created
    on first use. ``trainer.early_stopping`` is set to ``True`` once SARI has
    not improved for ``cfg.validations_before_sari_early_stopping``
    consecutive validations.

    Args:
        cfg: fairseq config. Reads ``common.cpu``, ``task``, ``tokenizer``,
            ``bpe``, ``generation`` (``replace_unk``, ``nbest``),
            ``common_eval.post_process`` and
            ``validations_before_sari_early_stopping``.
        trainer: trainer whose ``model`` is evaluated and which stores the
            early-stopping state.
        task: the caller's task. NOTE(review): not used; a fresh task is
            built from ``cfg.task`` below.
        epoch_itr: unused; kept for interface compatibility.
        subsets: unused; the TurkCorpus valid split is always evaluated.

    Returns:
        ``[-SARI]``, a single-element list. The negation presumably lets the
        caller treat it like a validation loss (lower is better).
    """
    from pathlib import Path
    from access.resources.paths import get_data_filepath
    from access.utils.helpers import read_lines
    from access.preprocessors import load_preprocessors, ComposedPreprocessor
    from easse.report import get_all_scores
    from fairseq.data import encoders
    from fairseq_cli.interactive import buffered_read, make_batches
    from fairseq_cli.generate import get_symbols_to_strip_from_output
    from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
    import os
    import tempfile

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Setup task, e.g., translation.
    # NOTE(review): this shadows the ``task`` argument and repeats the task
    # setup on every validation; confirm whether the argument can be reused.
    task = tasks.setup_task(cfg.task)

    # TODO: Choose parameters for the preprocessors?
    # Load the preprocessors from their pickle file:
    # preprocessors = load_preprocessors(Path(cfg.task.data).parent)
    # composed_preprocessor = ComposedPreprocessor(preprocessors)

    # Path to turkcorpus.valid.complex (the source sentences to simplify).
    complex_filepath = get_data_filepath('turkcorpus', 'valid', 'complex')
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        trainer.get_model().max_positions(),
    )
    parser = options.get_generation_parser(interactive=True)
    # TODO: Take args from fairseq_generate
    gen_args = options.parse_args_and_arch(
        parser, input_args=['/dummy_data', '--beam', '2'])
    generator = task.build_generator([trainer.model], gen_args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(cfg.tokenizer)
    bpe = encoders.build_bpe(cfg.bpe)

    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    def encode_fn(x):
        # Raw text -> tokenized -> BPE-encoded model input.
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        # Inverse of encode_fn: undo BPE first, then detokenize.
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    align_dict = utils.load_align_dict(cfg.generation.replace_unk)
    # Depends only on the generator, so compute once instead of per hypothesis.
    symbols_to_strip = get_symbols_to_strip_from_output(generator)

    # mkstemp returns an open OS-level descriptor; close it immediately since
    # the file is reopened by path below, and delete the file when done.
    pred_fd, pred_filepath = tempfile.mkstemp()
    os.close(pred_fd)
    try:
        with open(pred_filepath, 'w') as f:
            start_id = 0
            for inputs in buffered_read(complex_filepath, buffer_size=9999):
                results = []
                for batch in make_batches(inputs, cfg, task, max_positions,
                                          encode_fn):
                    src_tokens = batch.src_tokens
                    src_lengths = batch.src_lengths
                    constraints = batch.constraints
                    if use_cuda:
                        src_tokens = src_tokens.cuda()
                        src_lengths = src_lengths.cuda()
                        if constraints is not None:
                            constraints = constraints.cuda()
                    sample = {
                        "net_input": {
                            "src_tokens": src_tokens,
                            "src_lengths": src_lengths,
                        },
                    }
                    translations = task.inference_step(generator,
                                                       [trainer.model],
                                                       sample,
                                                       constraints=constraints)
                    for i, (id_, hypos) in enumerate(
                            zip(batch.ids.tolist(), translations)):
                        src_tokens_i = utils.strip_pad(src_tokens[i],
                                                       tgt_dict.pad())
                        results.append((start_id + id_, src_tokens_i, hypos))

                # Sort output to match input order.
                for _, src_tokens_i, hypos in sorted(results,
                                                     key=lambda r: r[0]):
                    # Must be bound even without a source dictionary;
                    # otherwise post_process_prediction raises
                    # UnboundLocalError.
                    src_str = ''
                    if src_dict is not None:
                        src_str = src_dict.string(
                            src_tokens_i, cfg.common_eval.post_process)

                    # Process top predictions.
                    # NOTE(review): with nbest > 1, several lines are written
                    # per input, misaligning predictions with the references.
                    for hypo in hypos[:cfg.generation.nbest]:
                        _, hypo_str, _ = utils.post_process_prediction(
                            hypo_tokens=hypo["tokens"].int().cpu(),
                            src_str=src_str,
                            alignment=hypo["alignment"],
                            align_dict=align_dict,
                            tgt_dict=tgt_dict,
                            remove_bpe=cfg.common_eval.post_process,
                            extra_symbols_to_ignore=symbols_to_strip,
                        )
                        # Detokenized hypothesis, one per line.
                        f.write(f'{decode_fn(hypo_str)}\n')

                # Update running id_ counter.
                start_id += len(inputs)

        # Score only after the with-block has closed (and flushed) the
        # predictions file; reading it while still open for writing could
        # miss buffered lines.
        ref_filepaths = [
            get_data_filepath('turkcorpus', 'valid', 'simple.turk', i)
            for i in range(8)
        ]
        scores = get_all_scores(
            read_lines(complex_filepath), read_lines(pred_filepath),
            [read_lines(ref_filepath) for ref_filepath in ref_filepaths])
    finally:
        os.remove(pred_filepath)

    print(f'num_updates={trainer.get_num_updates()}')
    print(f'ts_scores={scores}')
    sari = scores['SARI']

    # Early-stopping bookkeeping is stored on the trainer so it persists
    # across validation calls.
    if not hasattr(trainer, 'best_sari'):
        trainer.best_sari = 0
    if not hasattr(trainer, 'n_validations_since_best'):
        trainer.n_validations_since_best = 0
    if sari > trainer.best_sari:
        trainer.best_sari = sari
        trainer.n_validations_since_best = 0
    else:
        trainer.n_validations_since_best += 1
        print(
            f'SARI did not improve for {trainer.n_validations_since_best} validations'
        )
        # Lowering the LR here (e.g. trainer.optimizer.set_lr(0.75 * lr)) does
        # not work because the scheduler resets it to its own value every time.
        if trainer.n_validations_since_best >= cfg.validations_before_sari_early_stopping:
            print(
                f'Early stopping because SARI did not improve for {trainer.n_validations_since_best} validations'
            )
            trainer.early_stopping = True
        # TODO: abort runs whose best SARI stays too low (e.g. < 19 after 2
        # epochs, < 22 after 5, < 25 after 10) and remove the save dir.
    return [-sari]
示例#16
0
def main(args, init_distributed=False):
    """Train the model described by ``args``, or convert a model snapshot.

    Sets up CUDA/distributed state and builds the task, model, criterion and
    trainer. If ``args.snap_model_file`` is given, it loads those weights. With
    ``args.only_convert``, it re-saves them as a ``{'args', 'model'}``
    checkpoint under ``checkpoints/`` and exits. Otherwise it trains epoch by
    epoch until the learning rate drops to ``args.min_lr`` or
    ``max_epoch``/``max_update`` is reached. It validates and saves checkpoints
    at the configured intervals.

    Args:
        args: parsed command-line namespace.
        init_distributed: when True, initialise distributed training and
            store the resulting rank in ``args.distributed_rank``.

    Raises:
        AssertionError: if neither ``max_tokens`` nor ``max_sentences`` is set
            (note: asserts are stripped under ``python -O``).
        SystemExit: after a snapshot conversion with ``args.only_convert``.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load model from snapshot. Both a real None and the literal string
    # 'None' (the sentinel this option is compared against) mean "no snapshot".
    if args.snap_model_file not in (None, 'None'):
        print('load model file from {}'.format(args.snap_model_file))
        trainer.load_model_only(args.snap_model_file)
        if args.only_convert:
            state = {'args': args, 'model': trainer.get_model().state_dict()}
            path = os.path.join('checkpoints',
                                os.path.basename(args.snap_model_file))
            torch.save(state, path)
            # sys.exit instead of the site-module ``exit`` helper, which is
            # meant for the interactive shell and is missing under ``python -S``.
            import sys
            sys.exit()

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # A ':' in the data path indicates multiple shards, so the dataset
        # must be reloaded for the next epoch.
        reload_dataset = ':' in getattr(args, 'data', '')
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))