Example #1
    def prepare_task(args, xla_device):
        # Setup task, e.g., translation, language modeling, etc.
        task = tasks.setup_task(args)

        # Load valid dataset (we load training data below, based on the latest checkpoint)
        for valid_sub_split in args.valid_subset.split(','):
            task.load_dataset(valid_sub_split, combine=True, epoch=0)

        # Build models and criteria to print some metadata
        torch.manual_seed(args.seed)
        model, criterion = task.build_model(args), task.build_criterion(args)
        xm.master_print(model)
        xm.master_print('| model {}, criterion {}'.format(
            args.arch, criterion.__class__.__name__))
        xm.master_print('| num. model params: {} (num. trained: {})'.format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad)))
        model = model.to(xla_device)
        trainer = Trainer(args, task, model, criterion, xla_device=xla_device)
        lr = trainer.get_lr()

        # Load the latest checkpoint if one is available and restore the
        # corresponding train iterator
        # we overwrite distributed args here to shard data using torch_xla's
        # distributed training.
        trainer.args.distributed_rank = xm.get_ordinal()
        trainer.args.distributed_world_size = xm.xrt_world_size()
        extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
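        # checkpoint loading above was sharded via torch_xla ordinals; restore
        # single-process defaults for the rest of setup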
        trainer.args.distributed_rank = 0
        trainer.args.distributed_world_size = 1
        trainer.meters_to_device(xla_device)
        valid_subsets = args.valid_subset.split(',')
        ordinal = xm.get_ordinal(defval=-1)
        device_str = (
            str(xla_device) if ordinal < 0 else
            '{}/{}'.format(xla_device, ordinal)
        )
        return task, trainer, model, epoch_itr, lr, valid_subsets, device_str
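
A minimal launcher sketch (not part of the example) showing how prepare_task would typically be driven with torch_xla multiprocessing; _mp_fn and the commented spawn call are illustrative:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(rank, args):
    # each spawned process owns one XLA device (one TPU core)
    xla_device = xm.xla_device()
    task, trainer, model, epoch_itr, lr, valid_subsets, device_str = \
        prepare_task(args, xla_device)
    # ... the per-device training loop would follow here ...

# xmp.spawn(_mp_fn, args=(args,), nprocs=8)  # one process per TPU core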
Example #2
def main(args):
    # we should not do this!
    '''
    if args.max_tokens is None:
        args.max_tokens = 6000
    '''
    utils.xpprint(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    utils.xprintln('setup task done!')

    # Load dataset splits
    load_dataset_splits(args, task, ['train'])
    valid_dataset = args.valid_subset.split(',')
    load_dataset_splits(args, task, valid_dataset, shuffle=False)
    utils.xprintln('load dataset done!')

    if args.task.startswith('extractive_summarization'):
        if distributed_utils.is_master(args):
            from sum_eval import MultiProcSumEval
            sum_eval_pool = MultiProcSumEval(args.ncpu_eval)
            sum_valid_pool_params = dict(
                article_file=args.raw_valid + '.article',
                summary_file=args.raw_valid + '.summary',
                entity_map_file=None,
                length=-1,
                eval_type='predict',
                topk=args.topk_sent_eval,
                rerank=False,
                with_m=False,
                cmd='-a -c 95 -m -n 4 -w 1.2',
                trigram_block=args.trigram_block,
            )

            sum_test_pool_params = dict(
                article_file=args.raw_test + '.article',
                summary_file=args.raw_test + '.summary',
                entity_map_file=None,
                length=-1,
                eval_type='predict',
                topk=args.topk_sent_eval,
                rerank=False,
                with_m=False,
                cmd='-a -c 95 -m -n 4 -w 1.2',
                trigram_block=args.trigram_block,
            )
            sum_pool_params = dict(valid=sum_valid_pool_params,
                                   test=sum_test_pool_params)

            def make_params(default_dict,
                            result_file,
                            out_rouge_file,
                            rerank=False,
                            with_m=False):
                para_dict = dict(default_dict)
                para_dict['result_file'] = result_file
                para_dict['out_rouge_file'] = out_rouge_file
                para_dict['rerank'] = rerank
                para_dict['with_m'] = with_m
                return para_dict

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))
    # print(model)
    import sys
    sys.stdout.flush()

    # if summarization try to load pretrained model
    # if args.task.startswith('extractive_summarization') or args.task == 'pretrain_document_modeling':
    #     # assume this is a single GPU program
    if args.init_from_pretrained_doc_model:
        task.load_pretrained_model(model, args.pretrained_doc_model_path)
    sys.stdout.flush()

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()
    epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=False)

    # Load the latest checkpoint if one is available
    # load_checkpoint(args, trainer, epoch_itr)
    # make sure training from a different checkpoint will use different random seed
    cur_dataset = task.dataset('train')
    if hasattr(cur_dataset, 'rng'):
        print('epoch ', epoch_itr.epoch)
        cur_dataset.rng = numpy.random.RandomState(args.seed + epoch_itr.epoch)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
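    # note: range(10, 9, -1) yields only [10], so the loop below runs exactly
    # once; inside it, next_epoch_itr() advances the epoch but the returned
    # iterator is unused because the train() call is commented out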
    for alpha in range(10, 9, -1):
        # train for one epoch
        # train(args, trainer, task, epoch_itr)

        epoch_itr.next_epoch_itr()

        if epoch_itr.epoch % args.validate_interval == 0:
            if args.task.startswith('extractive_summarization'):
                if distributed_utils.is_master(args):
                    validate_metric(args, trainer, task, epoch_itr,
                                    valid_subsets)
Example #3
def main(cfg: FairseqConfig) -> None:
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None
        or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            logging.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`")
            return

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)
    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in cfg.dataset.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info("num. model params: {:,} (num. trained: {:,})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        cfg.distributed_training.distributed_world_size))
    logger.info("max tokens per GPU = {} and batch size per GPU = {}".format(
        cfg.dataset.max_tokens,
        cfg.dataset.batch_size,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})")
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish.")
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")
Example #4
def main(args):
    import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Make a dummy batch to (i) warm the caching allocator and (ii) as a
    # placeholder DistributedDataParallel when there's an uneven number of
    # batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    oom_batch = task.dataset('train').get_dummy_batch(1, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )

    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
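        # no checkpoint was loaded: run one dummy step to warm the caching allocator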
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #5
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info('max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    print(args.multi_views)

    while (
        lr > args.min_lr
        and (
            epoch_itr.epoch < max_epoch
            # allow resuming training from the final checkpoint
            or epoch_itr._next_epoch_itr is not None
        )
        and trainer.get_num_updates() < max_update
    ):

        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        bart = BARTHubInterface(args, task, trainer.model).cuda()
        bart.eval()
        count = 1
        bsz = 8

        print("Test on val set: ")

        with open('../data/val_sent_trans_cons_label.source') as source, open('../data/val_sent_c99_label.source') as source2, open('./val_best_multi_attn_'+str(args.lr_weight)+'_.hypo', 'wt', encoding='utf-8') as fout:
            s1 = source.readlines()
            s2 = source2.readlines()
            
            slines = [s1[0].strip()]
            slines2 = [s2[0].strip()]
            
            for i in tqdm(range(1, len(s1))):
                if count % bsz == 0:
                    with torch.no_grad():
                        if args.multi_views:
                            hypotheses_batch = bart.sample(slines, sentences2 = slines2, balance = True, beam=4, lenpen=2.0, max_len_b=100, min_len=5, no_repeat_ngram_size=3)
                        else:
                            hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=100, min_len=5, no_repeat_ngram_size=3)
                    for hypothesis in hypotheses_batch:
                        fout.write(hypothesis + '\n')
                        fout.flush()
                    slines = []
                    slines2 = []
                
                slines.append(s1[i].strip())
                slines2.append(s2[i].strip())
            
                count += 1
                
            if slines != []:
                if args.multi_views:
                    hypotheses_batch = bart.sample(slines, sentences2 = slines2, balance = True, beam=4, lenpen=2.0, max_len_b=100, min_len=5, no_repeat_ngram_size=3)
                else:
                    hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=100, min_len=5, no_repeat_ngram_size=3)
                #hypotheses_batch = bart.sample(slines, sentences2 = slines2, balance = True, beam=4, lenpen=2.0, max_len_b=100, min_len=5, no_repeat_ngram_size=3)
                for hypothesis in hypotheses_batch:
                    fout.write(hypothesis + '\n')
                    fout.flush()
        hyp_path = './val_best_multi_attn_'+str(args.lr_weight)+'_.hypo'
        ref_path = '../data/val_sent_trans_cons_label.target'
        hypothesis = []
        with open(hyp_path, 'r') as f:
            lines = f.readlines()
            for l in lines:
                hypothesis.append(l[:-1])
        
        reference = []
        with open(ref_path, 'r') as f:
            lines = f.readlines()
            for l in lines:
                reference.append(l[:-1])

        rouge = Rouge()
        print("Val", rouge.get_scores(hypothesis, reference, avg=True))

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        print("Test on testing set: ")

        count = 1
        bsz = 8
        with open('../data/test_sent_trans_cons_label.source') as source, open('../data/test_sent_c99_label.source') as source2, open('./test_best_multi_attn_'+str(args.lr_weight)+'_.hypo', 'wt', encoding='utf-8') as fout:
            s1 = source.readlines()
            s2 = source2.readlines()
            
            slines = [s1[0].strip()]
            slines2 = [s2[0].strip()]
            
            for i in tqdm(range(1, len(s1))):
                if count % bsz == 0:
                    with torch.no_grad():
                        if args.multi_views:
                            hypotheses_batch = bart.sample(slines, sentences2 = slines2, balance = True, beam=4, lenpen=2.0, max_len_b=100, min_len=5, no_repeat_ngram_size=3)
                        else:
                            hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=100, min_len=5, no_repeat_ngram_size=3)
                    for hypothesis in hypotheses_batch:
                        fout.write(hypothesis + '\n')
                        fout.flush()
                    slines = []
                    slines2 = []
                
                slines.append(s1[i].strip())
                slines2.append(s2[i].strip())
            
                count += 1
                
            if slines != []:
                if args.multi_views:
                    hypotheses_batch = bart.sample(slines, sentences2 = slines2, balance = True, beam=4, lenpen=2.0, max_len_b=100, min_len=5, no_repeat_ngram_size=3)
                else:
                    hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=100, min_len=5, no_repeat_ngram_size=3)
                
                for hypothesis in hypotheses_batch:
                    fout.write(hypothesis + '\n')
                    fout.flush()
        hyp_path = './test_best_multi_attn_'+str(args.lr_weight)+'_.hypo'
        ref_path = '../data/test_sent_trans_cons_label.target'
        hypothesis = []
        with open(hyp_path, 'r') as f:
            lines = f.readlines()
            for l in lines:
                hypothesis.append(l[:-1])
        
        reference = []
        with open(ref_path, 'r') as f:
            lines = f.readlines()
            for l in lines:
                reference.append(l[:-1])

        rouge = Rouge()
        print('Test', rouge.get_scores(hypothesis, reference, avg=True))

        # early stop
        if should_stop_early(args, valid_losses[0]):
            logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience))
            break

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.epoch,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
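
The validation and test generation blocks in this example are near-duplicates; below is a sketch of a helper that factors out the batched generation loop (the function name and signature are illustrative, assuming the BARTHubInterface object from the example):

import torch

def generate_hypotheses(bart, src_path, src2_path, out_path, args, bsz=8):
    # batched generation over paired source files, mirroring the loops above
    with open(src_path) as f1, open(src2_path) as f2, \
            open(out_path, 'wt', encoding='utf-8') as fout:
        s1 = [line.strip() for line in f1]
        s2 = [line.strip() for line in f2]
        for start in range(0, len(s1), bsz):
            batch1, batch2 = s1[start:start + bsz], s2[start:start + bsz]
            with torch.no_grad():
                if args.multi_views:
                    hypos = bart.sample(batch1, sentences2=batch2, balance=True,
                                        beam=4, lenpen=2.0, max_len_b=100,
                                        min_len=5, no_repeat_ngram_size=3)
                else:
                    hypos = bart.sample(batch1, beam=4, lenpen=2.0,
                                        max_len_b=100, min_len=5,
                                        no_repeat_ngram_size=3)
            for h in hypos:
                fout.write(h + '\n')
            fout.flush()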
Example #6
def main(args, init_distributed=False):
    import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    #args.distributed_world_size = 1

    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(args, task)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Make a dummy batch to (i) warm the caching allocator and (ii) as a
    # placeholder DistributedDataParallel when there's an uneven number of
    # batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset(args.train_subset).get_dummy_batch(
        args.max_tokens, max_positions)
    oom_batch = task.dataset(args.train_subset).get_dummy_batch(
        1, max_positions)

    # Build trainer
    print("Building trainer...")
    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    print("Initialize dataloader...")
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )

    # Initialize distributed training (after data loading)
    print("Initialize distributed training (after data loading)...")
    if init_distributed:
        import socket
        args.distributed_rank = distributed_utils.distributed_init(args)
        print('| initialized host {} as rank {}'.format(
            socket.gethostname(), args.distributed_rank))

    # Load the latest checkpoint if one is available
    print("Load the latest checkpoint if one is available...")
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])
    if args.reset_target_embedding:
        trainer.init_meters(args)
        print("reset trainer.meters")
    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    if args.distributed_rank == 0:
        if os.path.basename(args.save_dir) != "":
            log_file = os.path.join(
                args.save_dir,
                "({0})-params.log".format(os.path.basename(args.save_dir)))
        else:
            log_file = os.path.join(
                args.save_dir,
                "({0})-params.log".format(args.save_dir.split('/')[-2]))
        # create log file
        args.log_file = log_file
        # append if the log file already exists, otherwise create it
        mode = "a+" if os.path.exists(log_file) else "w"
        w = open(log_file, mode, encoding="utf-8")
        w.write(str(args).replace(", ", ",\n") + "\n")
        w.write(str(model) + "\n")
        w.flush()
        w.close()
        print("saving params file into{}...".format(log_file))
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #7
def main(cfg: DictConfig) -> None:
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    assert (
        cfg.dataset.max_tokens is not None
        or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)
    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in cfg.dataset.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {})".format(criterion.__class__.__name__))
    logger.info("num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        cfg.distributed_training.distributed_world_size))
    logger.info("max tokens per GPU = {} and batch size per GPU = {}".format(
        cfg.dataset.max_tokens,
        cfg.dataset.batch_size,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while lr > cfg.optimization.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
Example #8
def main(args):
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Make a dummy batch to (i) warm the caching allocator and (ii) as a
    # placeholder DistributedDataParallel when there's an uneven number of
    # batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    summary_writer = SummaryWriter(log_dir=args.save_dir,
                                   enable=args.distributed_rank == 0)

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    first_train = True
    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])
        first_train = False

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    if not hasattr(save_checkpoint, 'not_best'):
        save_checkpoint.not_best = 0

    if not args.no_first_valid and first_train:
        valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets,
                                True, summary_writer)

    if args.finetune_params != '':
        print("| train parameters.")
        for name, param in trainer.model.named_parameters():
            if trainer.should_train(name):
                print(name)

        print("| fixed parameters.")
        for name, param in trainer.model.named_parameters():
            if not trainer.should_train(name):
                print(name)

    if args.start_ckpt != '':
        save_checkpoint.not_best = 0
        save_checkpoint.best = 9999

    print("| train begin.")
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr, summary_writer)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(
                args, trainer, task, epoch_itr, valid_subsets,
                epoch_itr.epoch % args.test_bleu_interval == 0, summary_writer)
            if args.early_stop > 0:
                if hasattr(save_checkpoint,
                           'best') and valid_losses[0] > save_checkpoint.best:
                    save_checkpoint.not_best += 1
                    print("| Not the best ckpt... not best:",
                          save_checkpoint.not_best)
                    if save_checkpoint.not_best > args.early_stop:
                        print("| Early stop...")
                        break
                else:
                    save_checkpoint.not_best = 0

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
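    # blunt cleanup: kill any redis-server processes left over from this run
    # (assumes no unrelated redis instances are running on the machine)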
    os.system("ps aux | grep redis-server | awk '{print $2}' | xargs kill")

    if args.save_output:
        save_expert_outputs(args, task, trainer)
Example #9
def main(args, init_distributed=False):
    utils.import_user_module(args)

    try:
        from fairseq.fb_pathmgr import fb_pathmgr
        global fb_pathmgr_registerd
        if not fb_pathmgr_registerd:
            fb_pathmgr.register()
            fb_pathmgr_registerd = True
    except (ModuleNotFoundError, ImportError):
        pass

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))
    # Freeze params that are unused for fine-tuning (ad-hoc for fine-tuning;
    # turn this off for BERT pre-training).
    for n, p in model.named_parameters():
        if "lm_head" in n:
            p.requires_grad = False

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    if not hasattr(checkpoint_utils.save_checkpoint, 'not_best'):
        checkpoint_utils.save_checkpoint.not_best = 0

    while (epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        print('Start training')
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            if args.early_stop > 0:
                if hasattr(
                        checkpoint_utils.save_checkpoint, 'best'
                ) and valid_losses[0] > checkpoint_utils.save_checkpoint.best:
                    checkpoint_utils.save_checkpoint.not_best += 1
                    print("| Not the best ckpt... not best:",
                          checkpoint_utils.save_checkpoint.not_best)
                    if checkpoint_utils.save_checkpoint.not_best > args.early_stop:
                        print("| Early stop...")
                        break
                else:
                    checkpoint_utils.save_checkpoint.not_best = 0
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        reload_dataset = ':' in getattr(args, 'data', '')
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch,
                                               load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #10
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    if args.distributed_world_size > 1:
        assert (torch.distributed.is_initialized())
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
        torch.cuda.synchronize()
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)
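    # 0x05 is cudaLimitMaxL2FetchGranularity: the calls below cap the L2 fetch
    # granularity at 128 bytes and then read the limit back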
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    result = torch.cuda.cudart().cudaDeviceSetLimit(ctypes.c_int(0x05),
                                                    ctypes.c_int(128))
    result = torch.cuda.cudart().cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer = Trainer(args, task, model, criterion)
    if (args.online_eval or args.target_bleu) and not args.remove_bpe:
        args.remove_bpe = '@@ '
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))
    max_positions = trainer.get_model().max_positions()
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )
    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Send a dummy batch to warm the caching allocator
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    best_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    while (lr >= args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update
           and current_bleu < tgt_bleu):
        # train for one epoch
        train(args, trainer, task, epoch_itr)
        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # Eval BLEU score
        if args.online_eval or (tgt_bleu is not math.inf):
            current_bleu, current_sc_bleu = score(args, trainer, task,
                                                  epoch_itr, args.gen_subset)
            if current_bleu > best_bleu:
                best_bleu = current_bleu
                save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # Only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #11
def main(args, init_distributed=False):
    utils.import_user_module(args)
    utils.handle_save_path(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    #if torch.cuda.is_available() and not args.cpu:
    #    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(f"| Configs: {args}")

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(
        f"| Model: {args.arch} \n| Criterion: {criterion.__class__.__name__}")

    # Log architecture
    if args.train_subtransformer:
        print(" \n\n\t\tWARNING!!! Training one single SubTransformer\n\n")
        print(
            f"| SubTransformer Arch: {utils.get_subtransformer_config(args)} \n"
        )
    else:
        print(" \n\n\t\tWARNING!!! Training SuperTransformer\n\n")
        print(f"| SuperTransformer Arch: {model} \n")

    # Log model size
    if args.train_subtransformer:
        print(
            f"| SubTransformer size (without embedding weights): {model.get_sampled_params_numel(utils.get_subtransformer_config(args))}"
        )
        embed_size = args.decoder_embed_dim_subtransformer * len(task.tgt_dict)
        print(f"| Embedding layer size: {embed_size} \n")

    else:
        model_s = 0
        # model.state_dict() would add two extra parameters (encoder.version
        # and decoder.version), so they should not be counted here
        for name, param in model.named_parameters():
            if 'embed' not in name:
                model_s += param.numel()
        print(
            f"| SuperTransofmer model size (without embedding weights): {model_s}"
        )

        print(
            f"| Embedding layer size: {sum(p.numel() for p in model.parameters() if p.requires_grad) - model_s} \n"
        )

    # specify the length of the dummy input for profile
    # for iwslt, the average length is 23, for wmt, that is 30
    dummy_sentence_length_dict = {'iwslt': 23, 'wmt': 30}
    if 'iwslt' in args.arch:
        dummy_sentence_length = dummy_sentence_length_dict['iwslt']
    elif 'wmt' in args.arch:
        dummy_sentence_length = dummy_sentence_length_dict['wmt']
    else:
        raise NotImplementedError

    dummy_src_tokens = [2] + [7] * (dummy_sentence_length - 1)
    dummy_prev = [7] * (dummy_sentence_length - 1) + [2]
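    # note: token id 2 is fairseq's default EOS, and 7 stands in for an
    # arbitrary ordinary token in these dummy sequences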

    # profile the overall FLOPs number
    if args.profile_flops:
        import torchprofile
        config_subtransformer = utils.get_subtransformer_config(args)
        model.set_sample_config(config_subtransformer)
        model.profile(mode=True)
        macs = torchprofile.profile_macs(model,
                                         args=(torch.tensor([dummy_src_tokens],
                                                            dtype=torch.long),
                                               torch.tensor([30]),
                                               torch.tensor([dummy_prev],
                                                            dtype=torch.long)))
        model.profile(mode=False)

        last_layer_macs = config_subtransformer['decoder'][
            'decoder_embed_dim'] * dummy_sentence_length * len(task.tgt_dict)

        print(f"| Total FLOPs: {macs * 2}")
        print(f"| Last layer FLOPs: {last_layer_macs * 2}")
        print(
            f"| Total FLOPs without last layer: {(macs - last_layer_macs) * 2} \n"
        )
        exit(0)
    with torch.autograd.set_detect_anomaly(True):
        # Build trainer
        trainer = Trainer(args, task, model, criterion)
    print(f"| Training on {args.distributed_world_size} GPUs")
    # print(f"| Max tokens per GPU = {args.max_tokens} and max sentences per GPU = {args.max_sentences} \n")
    print(
        f"| Max tokens per GPU = {args.max_tokens} and max sentences per GPU = {None} \n"
    )

    # Measure model latency, the program will exit after profiling latency
    if args.latcpu or args.latgpu:
        utils.measure_latency(args, model, dummy_src_tokens, dummy_prev)
        exit(0)

    # Load the latest checkpoint if one is available and restore the corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Evaluate the SubTransformer
    if args.validate_subtransformer:
        config = utils.get_subtransformer_config(args)
        trainer.set_sample_config(config)
        valid_loss = validate(args, trainer, task, epoch_itr, ['valid'],
                              'SubTransformer')
        print(f"| SubTransformer validation loss:{valid_loss}")

    # Loop boundaries
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()

    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    represent_configs = utils.get_represent_configs(args)

    # Main training loop
    while (lr > args.stop_min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            for k, v in represent_configs.items():
                trainer.set_sample_config(config=v)
                valid_losses = validate(args,
                                        trainer,
                                        task,
                                        epoch_itr,
                                        valid_subsets,
                                        sampled_arch_name=k)
        else:
            valid_losses = [None]

        # update the best loss and get current lr; the real lr scheduling is done in trainer.train_step()
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint epoch level
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

    train_meter.stop()
    print('| Done training in {:.1f} seconds'.format(train_meter.sum))
Example #12
def baseline_with_meta_evaluation(model, meta_learning_task,
                                  meta_learning_args, meta_learning_criterion,
                                  fine_tune_args):
    meta_epoch_itr, meta_trainer, max_meta_epoch, max_meta_update, valid_subsets = prepare_meta_task(
        model=model,
        meta_learning_task=meta_learning_task,
        meta_learning_args=meta_learning_args,
        meta_learning_criterion=meta_learning_criterion)
    # Combine and do fine-tuning on combined data
    meta_train = meta_learning_task.dataset(meta_learning_args.train_subset)
    combined_fairseq_task = combine_data(meta_train=meta_train,
                                         fine_tune_args=fine_tune_args)
    # Fine-tune using the combined task
    criterion = combined_fairseq_task.build_criterion(fine_tune_args)
    import math
    from fairseq.trainer import Trainer
    combined_fairseq_task.load_dataset(fine_tune_args.train_subset)
    train_dataset = combined_fairseq_task.dataset(fine_tune_args.train_subset)
    # Make a dummy batch to (i) warm the caching allocator and (ii) as a  placeholder DistributedDataParallel when
    # there's an uneven number of batches per worker.
    max_positions = utils.resolve_max_positions(
        combined_fairseq_task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = train_dataset.get_dummy_batch(
        num_tokens=fine_tune_args.max_tokens, max_positions=max_positions)
    oom_batch = combined_fairseq_task.dataset(
        fine_tune_args.train_subset).get_dummy_batch(1, max_positions)
    # Create a trainer for training the model
    trainer = Trainer(fine_tune_args, combined_fairseq_task, model, criterion,
                      dummy_batch, oom_batch)
    epoch_itr = utils.create_epoch_iterator(task=combined_fairseq_task,
                                            dataset=train_dataset,
                                            args=fine_tune_args,
                                            max_positions=max_positions)
    max_epoch = fine_tune_args.max_epoch or math.inf
    max_update = fine_tune_args.max_update or math.inf
    # Do SGD on this task
    valid_subsets = fine_tune_args.valid_subset.split(',')
    lr = trainer.get_lr()
    batch_info = []
    # Always validate once before training
    valid_losses, _ = utils.validate(fine_tune_args, trainer,
                                     combined_fairseq_task, epoch_itr,
                                     valid_subsets)
    while (lr > fine_tune_args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # Train the model for one epoch (inlined from a fairseq train() helper)
        import collections
        from fairseq.data import iterators
        from fairseq import progress_bar
        from fairseq.meters import AverageMeter, ConcatentateMeter, BleuMeter
        # Update parameters every N batches
        update_freq = fine_tune_args.update_freq[epoch_itr.epoch - 1] \
            if epoch_itr.epoch <= len(fine_tune_args.update_freq) else fine_tune_args.update_freq[-1]

        # Initialize data iterator
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=fine_tune_args.fix_batches_to_gpus,
            shuffle=(epoch_itr.epoch >= fine_tune_args.curriculum),
        )
        itr = iterators.GroupedIterator(itr, update_freq)
        progress = progress_bar.build_progress_bar(
            fine_tune_args,
            itr,
            epoch_itr.epoch,
            no_progress_bar='simple',
        )

        extra_meters = collections.defaultdict(lambda: AverageMeter())
        extra_meters['strings'] = ConcatentateMeter()
        extra_meters['bleu_stats'] = BleuMeter()

        for i, samples in enumerate(progress,
                                    start=epoch_itr.iterations_in_epoch):
            log_output = trainer.train_step(samples)
            if log_output is None:
                continue

            # log mid-epoch stats
            stats = utils.get_training_stats(trainer)
            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size'
                ]:
                    continue  # these are already logged above
                if 'loss' in k:
                    extra_meters[k].update(v, log_output['sample_size'])
                else:
                    extra_meters[k].update(v)
                stats[k] = extra_meters[k].avg
            progress.log(stats,
                         tag=fine_tune_args.train_subset,
                         step=stats['num_updates'])

            # ignore the first mini-batch in words-per-second calculation
            if i == 0:
                trainer.get_meter('wps').reset()

            num_updates = trainer.get_num_updates()
            if fine_tune_args.save_interval_updates > 0 and num_updates % fine_tune_args.save_interval_updates == 0 and num_updates > 0:
                valid_losses, _ = utils.validate(fine_tune_args,
                                                 trainer,
                                                 combined_fairseq_task,
                                                 epoch_itr,
                                                 valid_subsets,
                                                 train_progress=progress)
                utils.save_checkpoint(fine_tune_args, trainer, epoch_itr,
                                      valid_losses[0])

            if num_updates >= max_update:
                break

        # log end-of-epoch stats
        stats = utils.get_training_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
            stats[k + '_std'] = meter.std
        progress.print(stats,
                       tag=fine_tune_args.train_subset,
                       step=stats['num_updates'])

        # reset training meters
        for k in [
                'train_loss',
                'train_nll_loss',
                'wps',
                'ups',
                'wpb',
                'bsz',
                'gnorm',
                'clip',
        ]:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        # Evaluate on validation split
        if epoch_itr.epoch % fine_tune_args.validate_interval == 0:
            valid_losses, _ = utils.validate(fine_tune_args, trainer,
                                             combined_fairseq_task, epoch_itr,
                                             valid_subsets)
        # save checkpoint
        if epoch_itr.epoch % fine_tune_args.save_interval == 0:
            utils.save_checkpoint(fine_tune_args, trainer, epoch_itr,
                                  valid_losses[0])
        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
    if batch_info is None:
        # Handle the original train function (batch_info is initialized to a
        # list above, so this guard only matters if a caller passes None)
        batch_info = []
    # Evaluate on validation split
    maybe_validate(meta_epoch_itr=meta_epoch_itr,
                   meta_learning_args=meta_learning_args,
                   meta_trainer=meta_trainer,
                   meta_learning_task=meta_learning_task,
                   valid_subsets=valid_subsets)
Example no. 13
def sub_main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)
    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch,
                                                criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens,
            args.max_sentences,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])
            if args.distributed_rank == 0:
                logger.info('Saving checkpoint to MLflow...')
                start_time = time()
                mlflow.log_artifact(args.save_dir + '/checkpoint_best.pt')
                mlflow.log_artifact(args.save_dir + '/checkpoint_last.pt')
                logger.info('Took {:.1f} seconds.'.format(time() - start_time))

        # early stop
        if should_stop_early(args, valid_losses[0]):
            logger.info(
                'early stop since valid performance hasn\'t improved for last {} runs'
                .format(args.patience))
            break

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
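
# `should_stop_early` above is defined elsewhere in this file; below is a
# minimal sketch of a patience-based helper in the same spirit (a hand-written
# approximation, not fairseq's actual implementation; assumes `args.patience`
# counts validations without improvement and that lower loss is better):
def should_stop_early_sketch(args, valid_loss):
    if valid_loss is None or getattr(args, 'patience', 0) <= 0:
        return False
    best = getattr(should_stop_early_sketch, 'best', None)
    if best is None or valid_loss < best:
        # new best loss: remember it and reset the counter
        should_stop_early_sketch.best = valid_loss
        should_stop_early_sketch.num_runs = 0
        return False
    should_stop_early_sketch.num_runs += 1
    return should_stop_early_sketch.num_runs >= args.patience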
Example no. 14
def main(args):
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.source_lang,
                                    args.target_lang)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits,
                                             args.source_lang,
                                             args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # Build model and criterion
    model = models.build_model(args, dataset.src_dict, dataset.dst_dict)
    if 0.0 < args.rank_scale < 1.0:
        patch_transformer(args, model)
        if args.wd2fd:
            no_decay, skiplist = [
                'fc1', 'fc2', 'embed_tokens', 'embed_positions', 'out_embed'
            ], []
        else:
            no_decay, skiplist = [], [
                'fc1', 'fc2', 'embed_tokens', 'embed_positions', 'out_embed'
            ]
    else:
        no_decay, skiplist = [], []
    spectral_init(args, model)

    criterion = criterions.build_criterion(args, dataset.src_dict,
                                           dataset.dst_dict)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.data.numel() for p in model.parameters())))

    # Build trainer; keep the no_decay/skiplist settings chosen above
    # (re-initializing them here would silently discard the --wd2fd choices)
    if args.wd2fd_quekey:
        no_decay.extend(['_query.weight', '_key.weight'])
    else:
        skiplist.append('quekey')
    if args.wd2fd_outval:
        no_decay.extend(['_value.weight', 'output_perform.weight'])
    else:
        skiplist.append('outval')

    trainer = Trainer(args,
                      model,
                      criterion,
                      skiplist=skiplist,
                      no_decay=no_decay)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available
    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    extra_state = trainer.load_checkpoint(checkpoint_path)
    if extra_state is not None:
        epoch = extra_state['epoch']
        batch_offset = extra_state['batch_offset']
        print('| loaded checkpoint {} (epoch {})'.format(
            checkpoint_path, epoch))
        if batch_offset == 0:
            trainer.lr_step(epoch)
            epoch += 1
    else:
        epoch, batch_offset = 1, 0

    if args.distributed_rank <= 0:
        writer = SummaryWriter(args.save_dir)
        with open(os.path.join(args.save_dir, 'args.json'), 'w') as f:
            json.dump(vars(args), f, indent=4)
    else:
        writer = SummaryWriter(
            os.path.join(args.save_dir, str(args.distributed_rank)))

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:

        if args.distributed_rank <= 0:
            writer.add_scalar('hyper/lr', lr, epoch)
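            # Log diagnostics of the factorized attention projections: the
            # Frobenius and nuclear norms of U @ VT, the
            # (||U||_F^2 + ||VT||_F^2) / 2 upper bound on the nuclear norm,
            # and a non-orthogonality measure of the factors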
            for form in ['QueKey', 'OutVal']:
                frobnorm, nucnorm, bound, nonorth = [], [], [], []
                for module in model.modules():
                    if hasattr(module, form.lower()):
                        U, VT = getattr(module, form.lower()).get_UVT()
                        for u, vt in zip(U, VT):
                            frobnorm.append(frobenius_norm(u, vt))
                            nucnorm.append(
                                torch.norm(torch.matmul(u, vt), 'nuc'))
                            bound.append(
                                (u.pow(2).sum() + vt.pow(2).sum()) / 2.)
                            nonorth.append(sum(non_orthogonality(u, vt)) / 2.)
                writer.add_scalar('FrobNorm/' + form,
                                  sum(frobnorm) / len(frobnorm), epoch)
                writer.add_scalar('NucNorm/' + form,
                                  sum(nucnorm) / len(nucnorm), epoch)
                writer.add_scalar('NucNorm/' + form + '-Bound',
                                  sum(bound) / len(bound), epoch)
                writer.add_scalar('NonOrth/' + form,
                                  sum(nonorth) / len(nonorth), epoch)
            frobnorm, nucnorm, bound, nonorth = [], [], [], []
            for name, module in model.named_modules():
                skip = ['embed', '_query', '_key', '_value', 'output_perform']
                if not any(block in name for block in skip):
                    if hasattr(module, 'frobgrad') and not hasattr(module, 'get_UVT'):
                        U, VT = module.U.data, module.VT.data
                        frobnorm.append(frobenius_norm(U, VT))
                        nucnorm.append(torch.norm(torch.matmul(U, VT), 'nuc'))
                        nonorth.append(sum(non_orthogonality(U, VT)) / 2.)
                        bound.append((U.pow(2).sum() + VT.pow(2).sum()) / 2.)
                    elif hasattr(module, 'weight'):
                        frobnorm.append(torch.norm(module.weight.data))
                        nucnorm.append(torch.norm(module.weight.data, 'nuc'))
            writer.add_scalar('FrobNorm/Linear',
                              sum(frobnorm) / len(frobnorm), epoch)
            writer.add_scalar('NucNorm/Linear',
                              sum(nucnorm) / len(nucnorm), epoch)
            if nonorth:
                writer.add_scalar('NucNorm/Linear-Bound',
                                  sum(bound) / len(bound), epoch)
                writer.add_scalar('NonOrth/Linear',
                                  sum(nonorth) / len(nonorth), epoch)

        # train for one epoch
        train(args, trainer, dataset, epoch, batch_offset)

        # evaluate on validate set
        if epoch % args.validate_interval == 0:
            for k, subset in enumerate(args.valid_subset.split(',')):
                val_loss = validate(args, trainer, dataset, subset, epoch)
                if k == 0:
                    # only use first validation loss to update the learning schedule
                    lr = trainer.lr_step(epoch, val_loss)

                    # save checkpoint
                    if not args.no_save:
                        save_checkpoint(trainer, args, epoch, 0, val_loss)
            for k in ['loss', 'nll_loss']:
                writer.add_scalar('valid/' + k,
                                  trainer.meters['valid_' + k].avg, epoch)
                writer.add_scalar('train/' + k,
                                  trainer.meters['train_' + k].avg, epoch)
        else:
            lr = trainer.lr_step(epoch)

        epoch += 1
        batch_offset = 0

        if trainer.get_num_updates() >= max_update:
            break
    train_meter.stop()

    print('| done training in {:.1f} seconds'.format(train_meter.sum))
    writer.flush()
    newpar = sum(p.numel() for p in model.parameters())
    if 0.0 < args.rank_scale < 1.0:
        args.rank_scale = 1.0
        origpar = sum(p.numel() for p in models.build_model(
            args, dataset.src_dict, dataset.dst_dict).parameters())
    else:
        origpar = newpar
    if args.distributed_rank <= 0:
        with open(os.path.join(args.save_dir, 'results.json'), 'w') as f:
            json.dump(
                {
                    'final validation loss':
                    trainer.meters['valid_nll_loss'].avg,
                    'original parameter count': origpar,
                    'compressed parameter count': newpar,
                    'compression ratio': newpar / origpar
                },
                f,
                indent=4)
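
# `frobenius_norm` and `non_orthogonality` are helpers imported elsewhere; a
# minimal sketch consistent with how they are used above (hypothetical
# approximations, not the authors' definitions): the Frobenius norm of the
# factorized product U @ VT, and each factor's Gram-matrix distance from a
# scaled identity.
import torch

def frobenius_norm_sketch(u, vt):
    return torch.norm(torch.matmul(u, vt))

def non_orthogonality_sketch(u, vt):
    def dist(m):
        # use the smaller Gram matrix of the factor
        gram = m @ m.t() if m.shape[0] <= m.shape[1] else m.t() @ m
        scale = gram.diagonal().mean()
        eye = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
        return torch.norm(gram - scale * eye)
    return dist(u), dist(vt)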
Example no. 15
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    from mlperf_compliance.mlperf_log import transformer_print
    # before this tag we should run cache clearing on the host
    transformer_print(key=mlperf_log.RUN_CLEAR_CACHES)
    # mlperf compliance synchronization
    if args.distributed_world_size > 1:
        assert (torch.distributed.is_initialized())
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
        torch.cuda.synchronize()
    transformer_print(key=mlperf_log.RUN_START)
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)
    transformer_print(key=mlperf_log.OPT_NAME, value=args.optimizer)
    transformer_print(key=mlperf_log.OPT_LR, value=args.lr)
    transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA1,
                      value=eval(args.adam_betas)[0])
    transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA2,
                      value=eval(args.adam_betas)[1])
    transformer_print(key=mlperf_log.OPT_HP_ADAM_EPSILON, value=args.adam_eps)
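    # Set CUDA limit 0x05 (cudaLimitMaxL2FetchGranularity) to 128 bytes via
    # cudart and read it back -- an MLPerf performance tuning knob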
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    result = torch.cuda.cudart().cudaDeviceSetLimit(ctypes.c_int(0x05),
                                                    ctypes.c_int(128))
    result = torch.cuda.cudart().cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    torch.manual_seed(args.seed)
    transformer_print(key=mlperf_log.RUN_SET_RANDOM_SEED, value=args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)
    transformer_print(key=mlperf_log.MODEL_HP_SEQ_BEAM_SEARCH,
                      value={
                          'alpha': args.lenpen,
                          'beam_size': args.beam,
                          'extra_decode_length': args.max_len_b,
                          'vocab_size': task.target_dictionary.__len__()
                      })

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer = Trainer(args, task, model, criterion)
    if (args.online_eval or args.target_bleu) and not args.remove_bpe:
        args.remove_bpe = '@@ '
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))
    transformer_print(key=mlperf_log.INPUT_BATCH_SIZE, value=args.max_tokens)
    transformer_print(key=mlperf_log.INPUT_ORDER)
    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()

    # Send a dummy batch to warm the caching allocator
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    ctr = 0

    class DummyEpochBatchIterator:
        """Minimal stand-in exposing only `.epoch`, so the first loop
        iteration can build the real EpochBatchIterator."""
        def __init__(self, epoch=0):
            self.epoch = epoch

    epoch_itr = DummyEpochBatchIterator(0)
    import time
    transformer_print(key=mlperf_log.TRAIN_LOOP)
    while (lr >= args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update
           and current_bleu < tgt_bleu):
        transformer_print(key=mlperf_log.TRAIN_EPOCH, value=epoch_itr.epoch)
        start = time.time()
        epoch_itr = data.EpochBatchIterator(
            dataset=task.dataset(args.train_subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            epoch=epoch_itr.epoch if ctr != 0 else 0)
        print("got epoch iterator", time.time() - start)

        # Load the latest checkpoint if one is available
        if ctr == 0:
            load_checkpoint(args, trainer, epoch_itr)

        # train for one epoch
        start = time.time()
        train(args, trainer, task, epoch_itr)
        print("epoch time ", time.time() - start)

        start = time.time()

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # Eval BLEU score
        transformer_print(key=mlperf_log.EVAL_START, value=epoch_itr.epoch)
        if args.online_eval or tgt_bleu != math.inf:
            current_bleu = score(args, trainer, task, epoch_itr,
                                 args.gen_subset)
            transformer_print(key=mlperf_log.EVAL_ACCURACY,
                              value={
                                  'epoch': epoch_itr.epoch,
                                  'value': current_bleu
                              })
            transformer_print(key=mlperf_log.EVAL_TARGET, value=tgt_bleu)
        transformer_print(key=mlperf_log.EVAL_STOP, value=epoch_itr.epoch)

        # Only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        ctr = ctr + 1
        print("validation and scoring ", time.time() - start)

    train_meter.stop()
    transformer_print(key=mlperf_log.RUN_STOP)
    transformer_print(key=mlperf_log.RUN_FINAL)
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example no. 16
def main(args):
    import_user_module(args)

    assert (
        args.max_tokens is not None or args.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    metrics.reset()

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info("criterion: {} ({})".format(args.criterion,
                                            criterion.__class__.__name__))
    logger.info("num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # ========== initialize the model with pretrained BART parameters ==========
    # for shared embeddings and subtoken split for amr nodes
    if 'bartsv' in args.arch:

        if args.initialize_with_bart:
            logger.info(
                '-' * 10 +
                ' initializing model parameters with pretrained BART model ' +
                '-' * 10)

            new_state_dict = copy.deepcopy(task.bart.model.state_dict())
            # treat the embedding initialization separately later, as the sizes differ
            logger.info(
                '-' * 10 +
                ' delay encoder embeddings, decoder input and output embeddings initialization '
                + '-' * 10)
            ignore_keys = set([
                'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight',
                'decoder.output_projection.weight'
            ])
            for k in ignore_keys:
                del new_state_dict[k]

            if not args.initialize_with_bart_enc:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART encoder parameters ' +
                    '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('encoder'):
                        del new_state_dict[k]

            if not args.initialize_with_bart_dec:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART decoder parameters ' +
                    '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('decoder'):
                        del new_state_dict[k]

            model.load_state_dict(new_state_dict, strict=False, args=args)

            # initialize the BART part of the embeddings
            bart_vocab_size = task.target_dictionary.bart_vocab_size
            # NOTE we need to prune the pretrained BART embeddings, especially for bart.base
            bart_embed_weight = (
                task.bart.model.encoder.embed_tokens.weight.data[:bart_vocab_size])
            assert len(bart_embed_weight) == bart_vocab_size

            with torch.no_grad():
                model.encoder.embed_tokens.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)
                model.decoder.embed_tokens.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)
                model.decoder.output_projection.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)

        if args.bart_emb_init_composition:
            logger.info(
                '-' * 10 +
                ' initialize extended target embeddings with compositional embeddings '
                'from BART vocabulary ' + '-' * 10)

            # guard: bart_vocab_size is needed here even when
            # --initialize-with-bart is off
            bart_vocab_size = task.target_dictionary.bart_vocab_size
            symbols = [
                task.target_dictionary[idx]
                for idx in range(bart_vocab_size, len(task.target_dictionary))
            ]
            mapper = MapAvgEmbeddingBART(task.bart,
                                         task.bart.model.decoder.embed_tokens)
            comp_embed_weight, map_all = mapper.map_avg_embeddings(
                symbols, transform=transform_action_symbol, add_noise=False)
            assert len(comp_embed_weight) == len(symbols)

            with torch.no_grad():
                model.encoder.embed_tokens.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)
                model.decoder.embed_tokens.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)
                model.decoder.output_projection.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)

    elif 'bart' in args.arch:

        if args.initialize_with_bart:
            logger.info(
                '-' * 10 +
                ' initializing model parameters with pretrained BART model ' +
                '-' * 10)

            new_state_dict = copy.deepcopy(task.bart.model.state_dict())
            if not args.bart_emb_decoder:
                logger.info('-' * 10 +
                            ' build a separate decoder dictionary embedding ' +
                            '-' * 10)
                if not args.bart_emb_decoder_input:
                    ignore_keys = set([
                        'decoder.embed_tokens.weight',
                        'decoder.output_projection.weight'
                    ])
                else:
                    logger.info(
                        '-' * 10 +
                        ' use BART dictionary embedding for target input ' +
                        '-' * 10)
                    ignore_keys = set(['decoder.output_projection.weight'])
                for k in ignore_keys:
                    del new_state_dict[k]

            if not args.initialize_with_bart_enc:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART encoder parameters ' +
                    '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('encoder'):
                        del new_state_dict[k]

            if not args.initialize_with_bart_dec:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART decoder parameters ' +
                    '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('decoder'):
                        del new_state_dict[k]

            model.load_state_dict(new_state_dict, strict=False, args=args)

        # initialize the target embeddings with average of subtoken embeddings in BART vocabulary
        if args.bart_emb_init_composition:
            assert not args.bart_emb_decoder, 'should not use the compositional embeddings on top of BART vocabulary here'
            logger.info(
                '-' * 10 +
                ' initialize target embeddings with compositional embeddings from BART vocabulary '
                + '-' * 10)
            composite_embed = CompositeEmbeddingBART(
                task.bart, task.bart.model.decoder.embed_tokens,
                task.target_dictionary)
            if args.bart_emb_decoder_input:
                # only initialize the decoder output embeddings
                with torch.no_grad():
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)
            else:
                # initialize both the decoder input and output embeddings
                with torch.no_grad():
                    model.decoder.embed_tokens.weight.copy_(
                        composite_embed.embedding_weight)
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)

    elif 'roberta' in args.arch:
        # initialize the target embeddings with average of subtoken embeddings in BART vocabulary
        if args.bart_emb_init_composition:
            assert not args.bart_emb_decoder, 'should not use the compositional embeddings on top of RoBERTa vocabulary here'
            logger.info(
                '-' * 10 +
                ' initialize target embeddings with compositional embeddings from RoBERTa vocabulary '
                + '-' * 10)
            composite_embed = CompositeEmbeddingBART(
                task.bart,  # NOTE here "bart" means roberta
                task.bart.model.encoder.sentence_encoder.embed_tokens,
                task.target_dictionary)

            if args.bart_emb_decoder_input:
                # only initialize the decoder output embeddings
                with torch.no_grad():
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)
            else:
                # initialize both the decoder input and output embeddings
                with torch.no_grad():
                    model.decoder.embed_tokens.weight.copy_(
                        composite_embed.embedding_weight)
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)

    else:
        raise ValueError('unsupported --arch for this training script: '
                         '{}'.format(args.arch))
    # ==========================================================================

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        args.distributed_world_size))
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.batch_size))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        args,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
Example no. 17
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        raise ValueError("Distibuted training not supported by multiobj "
                         "training")

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest
    # checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    if args.restore_file is not None:
        # Load from checkpoint
        print('| loading model from {}'.format(args.restore_file))
        [model], _model_args = checkpoint_utils.load_model_ensemble(
            [args.restore_file],
            arg_overrides=eval(args.model_overrides),
            task=task,
        )
        # Overwrite architecture arguments
        # (this is very hacky but I don't know a better way)
        for k, v in _model_args.__dict__.items():
            is_model_argument = k == "arch"
            is_model_argument |= k.startswith("encoder_")
            is_model_argument |= k.startswith("decoder_")
            is_model_argument |= k.startswith("share_")
            is_model_argument |= k.startswith("adaptive_")
            if hasattr(args, k) and is_model_argument:
                setattr(args, k, v)
    else:
        # Or build model from scratch
        model = task.build_model(args)

    # Training criterion
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Load auxiliary data
    epoch_aux_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset, idx=1),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            trainer.model.max_positions(),
        ),
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
        epoch=0,
    )

    # Estimate fisher if needed
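    # The diagonal Fisher approximates per-parameter importance: it is used to
    # precondition updates when --inverse-fisher is set and as the penalty
    # weights for the EWC regularizer below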
    if args.inverse_fisher or args.ewc > 0:
        fisher_itr = task.get_batch_iterator(
            dataset=task.dataset(args.train_subset, idx=1),
            max_tokens=args.max_tokens,
            max_sentences=1,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.model.max_positions(),
            ),
            ignore_invalid_inputs=True,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
            epoch=0,
        )
        fim = estimate_diagonal_fisher(args,
                                       trainer,
                                       fisher_itr,
                                       args.n_fisher_samples,
                                       precomputed=args.precomputed_fisher)
        trainer.fim = fim
    # EWC
    if args.ewc > 0.0:
        trainer.prepare_ewc(args.ewc)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr, epoch_aux_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, None)

        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
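
# `estimate_diagonal_fisher` is defined elsewhere; below is a minimal sketch
# of the standard recipe it presumably follows (a hypothetical helper, not
# this file's actual code): average the squared gradients of the loss over a
# number of sampled batches.
import torch

def estimate_diagonal_fisher_sketch(model, loss_fn, batches):
    # `loss_fn(model, batch)` (hypothetical) returns a scalar training loss
    fim = {n: torch.zeros_like(p) for n, p in model.named_parameters()
           if p.requires_grad}
    for batch in batches:
        model.zero_grad()
        loss_fn(model, batch).backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fim[n] += p.grad.detach() ** 2
    return {n: v / max(len(batches), 1) for n, v in fim.items()}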
Example no. 18
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'
    metrics.reset()

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch,
                                                criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens,
            args.max_sentences,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while (lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch):
        # train for one epoch
        valid_losses = train(args, trainer, task, epoch_itr, max_update)
        if (should_stop_early(args, valid_losses[0])
                or trainer.get_num_updates() >= max_update):
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
Example no. 19
def main(
    args,
    init_distributed=False,
    after_distributed_init_fn: Optional[Callable[[argparse.Namespace],
                                                 argparse.Namespace]] = None,
):
    utils.import_user_module(args)

    assert (
        args.max_tokens is not None or args.max_sentences is not None
    ), "Must specify batch size either with --max-tokens or --max-sentences"
    metrics.reset()

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu and not getattr(
            args, "tpu", False):
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)
        if after_distributed_init_fn:
            args = after_distributed_init_fn(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("model {}, criterion {}".format(args.arch,
                                                criterion.__class__.__name__))
    logger.info("num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        args.distributed_world_size))
    logger.info("training on {} devices (GPUs/TPUs)".format(
        args.distributed_world_size))
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.max_sentences))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
    if args.tpu:
        import torch_xla.core.xla_model as xm

        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    experiment_path = args.mhr_experiment  # path for experiment configuration
    total_samples = 0
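    # Pruning/restoration state per attention type for the MHR head-pruning
    # schedule (assumption: `restore` holds state to restore and
    # `last_epoch_num` the epoch each attention type last changed)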
    restore = {
        'enc_self_attn': None,
        'dec_self_attn': None,
        'dec_enc_attn': None
    }
    last_epoch_num = {
        'enc_self_attn': 0,
        'dec_self_attn': 0,
        'dec_enc_attn': 0
    }
    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop, total_samples_temp, restore, last_epoch_num = train(
            args,
            trainer,
            task,
            epoch_itr,
            model,
            experiment_path,
            total_samples=total_samples,
            restore=restore,
            last_epoch_num=last_epoch_num)
        total_samples = total_samples_temp

        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, "data", "")),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
Example no. 20
def main(args):
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Make a dummy batch to (i) warm the caching allocator and (ii) as a
    # placeholder DistributedDataParallel when there's an uneven number of
    # batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))
    ## "Prune" heads (actually mask but shh...)
    #if len(args.transformer_mask_heads) > 0:
    #    # Determine which head to prune
    #    to_prune = parse_head_pruning_descriptors(
    #        args.transformer_mask_heads,
    #        reverse_descriptors=args.transformer_mask_all_but_one_head,
    #        n_heads=model.encoder.layers[0].self_attn.num_heads
    #    )
    #    print(to_prune)
    #    # Apply pruning
    #    mask_heads(model, to_prune, args.transformer_mask_rescale)

    # Save initial model
    initial = os.path.join(args.save_dir, "checkpoint_initial.pt")
    trainer.save_checkpoint(initial, {})

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint -- changed: prune and save every five epochs instead
        # of every args.save_interval epochs
        #if epoch_itr.epoch % args.save_interval == 0:
        #    save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
        save_interval = 5
        if epoch_itr.epoch % save_interval == 0:
            # do pruning before saving
            prune2(args, task, model, trainer, epoch_itr)
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    # do one last pruning pass on the last checkpoint saved
    prune2(args, task, model, trainer, epoch_itr)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example no. 21
def main(args):
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))

    # Make a dummy batch to (i) warm the caching allocator and (ii) as a
    # placeholder DistributedDataParallel when there's an uneven number of
    # batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example no. 22
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.train_subset, combine=True, epoch=0)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=True, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )

    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr, max_positions, task)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        epoch_itr = reload_train(args, epoch_itr, max_positions, task)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
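The 'num. model params' lines printed by each example are one-liners over model.parameters(); they work with any torch.nn.Module, e.g.:

import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 4), nn.Linear(4, 2))
model[1].weight.requires_grad_(False)    # freeze one weight matrix

total = sum(p.numel() for p in model.parameters())
trained = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('| num. model params: {} (num. trained: {})'.format(total, trained))
# | num. model params: 54 (num. trained: 46)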
Example #23
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_init_hvd(args)

    # Print args
    print(args)

    # if not HAS_NSML:
    #     args.data[0] = args.data[0].replace("/train", "")

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    if args.train_decoder_only:
        for name, param in model.named_parameters():
            if "decoder" not in name:
                param.requires_grad_(False)

    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Setup session
    if HAS_WANDB and distributed_utils.is_master(args):
        wandb.init(project="cmlm", config=args)
        wandb.watch(model)

    # Load pre-trained model
    data_token = args.data[0].split("/")[-1]
    if "bert" in args.arch:
        pretrained_path = "{}/train/pretrained_models/maskPredict_{}/checkpoint_best.pt".format(
            DATASET_PATH,
            data_token.split(".")[-1].replace("-", "_"))
        if not HAS_NSML:
            pretrained_path = pretrained_path.replace("/train", "")
        print("| loading", pretrained_path)
        state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_path)
        model.load_state_dict(state["model"], strict=True)
        baseline_model = task.build_model(args)
        baseline_model.load_state_dict(state["model"], strict=True)
        if torch.cuda.is_available():
            baseline_model.cuda()
        task.set_baseline_model(baseline_model)

    if not args.masking and HAS_NSML:

        def nsml_bind(model):
            def save(dir_path):
                state = {
                    'model': model.state_dict(),
                }
                torch.save(state, os.path.join(dir_path, 'best.pt'))

            def load(dir_path):
                state = torch.load(os.path.join(dir_path, 'best.pt'),
                                   map_location="cpu")
                model.load_state_dict(state['model'], strict=False)
                model.cuda()
                print('model loaded!')

            nsml.bind(save=save, load=load)

        nsml_bind(model)

    if args.load:
        print("loading model from session", args.load)
        # strip the optional scheme prefix; otherwise `session` would be
        # undefined below when args.load is a bare session name
        session = args.load.replace("nsml://", "")
        if ".pt" in session:
            session = session.replace(".pt", "")
            session, checkpoint_name = session.rsplit("/", 1)
        else:
            checkpoint_name = "best"
        if "-" in checkpoint_name:
            start, end = checkpoint_name.replace("epoch", "").split("-")
            checkpoints = [
                "epoch{}".format(i) for i in range(int(start),
                                                   int(end) + 1)
            ]
            print("| checkpoint average:", checkpoints)
            state_dict = None

            def load(dir_path):
                nonlocal state_dict, checkpoints
                state = torch.load(os.path.join(dir_path, 'best.pt'))
                model_state = state["model"]
                for k in model_state:
                    model_state[k] = model_state[k] / float(len(checkpoints))
                if state_dict is None:
                    state_dict = model_state
                else:
                    for k in state_dict:
                        state_dict[k] += model_state[k]
                print("checkpoint loaded")

            for checkpoint_name in checkpoints:
                nsml.load(checkpoint_name, load_fn=load, session=session)
            model.load_state_dict(state_dict)
        else:

            def load(dir_path):
                state = torch.load(os.path.join(dir_path, 'best.pt'))
                state_dict = state["model"]
                model.load_state_dict(state_dict)
                print("loaded")

            nsml.load(checkpoint_name, load_fn=load, session=session)

    # Prepare for decoder wise training
    if args.decoder_wise_training:
        print("| Decoder wise training, start refinement step 0")
        progressive_training_step = 0
        assert args.ddp_backend == "c10d"
    else:
        progressive_training_step = None

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    if hasattr(args, "progressive") and args.progressive:
        for i in range(args.refinetot if not getattr(args, "pnet", False) else
                       args.refinetot - 1):
            print("validating for refine step", i)
            validate(args,
                     trainer,
                     task,
                     epoch_itr,
                     valid_subsets,
                     force_refine_step=i)
        print("---")
    validate(args, trainer, task, epoch_itr, valid_subsets)
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args,
              trainer,
              task,
              epoch_itr,
              force_refine_step=progressive_training_step)
        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(
                args,
                trainer,
                task,
                epoch_itr,
                valid_subsets,
                force_refine_step=progressive_training_step)
        else:
            valid_losses = [None]

        if args.decoder_wise_training:
            progressive_training_step = update_num_to_refine_step(
                trainer.get_num_updates())

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            if HAS_NSML:
                if distributed_utils.is_master(args):
                    print("nsml save for epoch", epoch_itr.epoch)
                    nsml.save("epoch{}".format(epoch_itr.epoch))
            else:
                torch.save({"model": trainer.get_model().state_dict()},
                           "/tmp/epoch{}.pt".format(epoch_itr.epoch))
                if HAS_WANDB:
                    wandb.save("/tmp/epoch{}.pt".format(epoch_itr.epoch))
                # checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
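The epoch-range branch of args.load above averages several checkpoints by scaling each state dict by 1/N and summing. The same arithmetic over plain files, as a sketch (paths are hypothetical; assumes each file stores its weights under a 'model' key, as in the code above):

import torch

def average_checkpoints(paths):
    avg = None
    for path in paths:
        state = torch.load(path, map_location='cpu')['model']
        if avg is None:
            avg = {k: v.float() / len(paths) for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float() / len(paths)
    return avg

# model.load_state_dict(average_checkpoints(['epoch3.pt', 'epoch4.pt', 'epoch5.pt']))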
Example #24
def main(args):
    utils.import_user_module(args)

    assert (
        args.max_tokens is not None or args.max_sentences is not None
    ), "Must specify batch size either with --max-tokens or --max-sentences"

    metrics.reset()

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info("criterion: {} ({})".format(args.criterion,
                                            criterion.__class__.__name__))
    logger.info("num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        args.distributed_world_size))
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.max_sentences))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
Example #25
def main(args, config=None, init_distributed=False):
    utils.import_user_module(args)

    experiment = None
    if config:
        experiment = ExistingExperiment(
            api_key=config["api_key"],
            previous_experiment=config["experiment_key"],
            auto_output_logging=None,
        )

    assert (
        args.max_tokens is not None or args.max_sentences is not None
    ), "Must specify batch size either with --max-tokens or --max-sentences"

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    print(args)
    if experiment:
        experiment.log_parameters(vars(args),
                                  prefix="Device {} :: ".format(
                                      args.device_id))

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print("| model {}, criterion {}".format(args.arch,
                                            criterion.__class__.__name__))
    print("| num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    if experiment:
        experiment.log_parameters(
            {
                "criterion":
                criterion.__class__.__name__,
                "num. model params":
                sum(p.numel() for p in model.parameters()),
                "num. trained params":
                sum(p.numel() for p in model.parameters() if p.requires_grad),
            },
            prefix="Device {} :: ".format(args.device_id),
        )

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print("| training on {} GPUs".format(args.distributed_world_size))
    print("| max tokens per GPU = {} and max sentences per GPU = {}".format(
        args.max_tokens, args.max_sentences))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(",")
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr, experiment)

        if (not args.disable_validation
                and epoch_itr.epoch % args.validate_interval == 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets, experiment)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        reload_dataset = ":" in getattr(args, "data", "")
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch,
                                               load_dataset=reload_dataset)
    train_meter.stop()
    print("| done training in {:.1f} seconds".format(train_meter.sum))

    if experiment:
        experiment.log_metrics(
            {
                "valid_loss": valid_losses[0],
                "lr": lr
            },
            prefix="Device {} ".format(args.device_id),
        )
Example #26
def main(args):
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(
            args.data, splits, args.source_lang, args.target_lang)
    else:
        dataset = data.load_raw_text_dataset(
            args.data, splits, args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    # Build model and criterion
    model = models.build_model(args, dataset.src_dict, dataset.dst_dict)
    criterion = criterions.build_criterion(args, dataset.src_dict, dataset.dst_dict)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))

    # Build trainer
    trainer = Trainer(args, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available
    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    extra_state = trainer.load_checkpoint(checkpoint_path)
    if extra_state is not None:
        epoch = extra_state['epoch']
        batch_offset = extra_state['batch_offset']
        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
        if batch_offset == 0:
            trainer.lr_step(epoch)
            epoch += 1
    else:
        epoch, batch_offset = 1, 0

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, trainer, dataset, epoch, batch_offset)

        # evaluate on validate set
        for k, subset in enumerate(args.valid_subset.split(',')):
            val_loss = validate(args, trainer, dataset, subset, epoch)
            if k == 0:
                # only use first validation loss to update the learning schedule
                lr = trainer.lr_step(epoch, val_loss)

                # save checkpoint
                if not args.no_save:
                    save_checkpoint(trainer, args, epoch, 0, val_loss)

        epoch += 1
        batch_offset = 0
    train_meter.stop()

    print('| done training in {:.1f} seconds'.format(train_meter.sum))
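Example #26 restores both model weights and bookkeeping (epoch, batch_offset) from one file. A compact mirror of that save/load pair, as a sketch (the file path is hypothetical):

import torch
import torch.nn as nn

def save_checkpoint(path, model, epoch, batch_offset):
    torch.save({'model': model.state_dict(),
                'extra_state': {'epoch': epoch, 'batch_offset': batch_offset}},
               path)

def load_checkpoint(path, model):
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['model'])
    return state['extra_state']

model = nn.Linear(4, 4)
save_checkpoint('/tmp/demo.pt', model, epoch=3, batch_offset=0)
print(load_checkpoint('/tmp/demo.pt', model))   # {'epoch': 3, 'batch_offset': 0}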
Example #27
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr, filtered_maxpos_indices = checkpoint_utils.load_checkpoint(
        args, trainer)

    # pretrain data actor
    # only the language actor model can be pretrained

    if args.pretrain_laser and args.pretrain_data_actor and args.data_actor == 'ave':
        # pretrain the agent with LASER score
        # epoch_itr, indices = trainer.get_train_iterator(1)
        path = '/home/wtan12/multiDDS/'
        trainer.pretrain_LASER('en-ps.laser-score', epoch_itr)

    if args.compare_laser:
        epoch_itr, indices = trainer.get_train_iterator(1)
        print('Number of Indices: ', len(indices))
        scores = collections.defaultdict(float)
        # compare against the LASER labels via squared error; only used after the model is trained
        # itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=False, shuffle=False)
        data_actor = trainer.data_actor
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=False,
            offset=0,
            datasize=-1,
        )
        for i, sample in enumerate(itr):
            sample = trainer._prepare_sample(sample)
            sample = list(sample.values())[0]
            score = data_actor(sample).cpu().detach().numpy().tolist()
            indices = sample['id'].data.cpu().numpy().ravel().tolist()
            for k, v in zip(indices, score):
                scores[k] = float(v[0])

        scores = sorted(scores.items(), key=lambda x: x[0])
        print('Number of Indices in Scoring file: ', len(scores))
        path = '/home/wtan12/multiDDS/'
        with open(path + 'en-ps.laser-score', 'r') as r:
            data = r.read()
        laser_score = []
        for i, item in enumerate(data.split('\n')):
            laser_score.append(item)
        laser_score.pop()
        sq_err = 0.0
        with open(path + 'en-ps.dds_score', 'w') as f:
            for k, v in scores:
                f.write(str(v) + '\n')
                truth = float(laser_score[k])
                sq_err += (truth - v)**2
        print('Sum of squared errors vs. LASER file: ', sq_err)
        return

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    if args.eval_bleu:
        generator = task.build_generator(args)
        args.maximize_best_checkpoint_metric = True
    else:
        generator = None
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        epoch_itr = train(args, trainer, task, epoch_itr, generator,
                          filtered_maxpos_indices)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets, generator)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)[0]
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
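The --compare-laser branch above reports a raw sum of squared errors; the full coefficient of determination (R²) additionally normalizes by the variance of the reference scores:

def r_squared(y_true, y_pred):
    """1 - SSE/SST; the loop above computes only the SSE numerator."""
    mean = sum(y_true) / len(y_true)
    sse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    sst = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - sse / sst

print(r_squared([1.0, 2.0, 3.0], [0.9, 2.1, 3.2]))   # ≈ 0.97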
Example #28
def main(args):
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(args, task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, task, model, criterion)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )
        trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Send a dummy batch to warm the caching allocator
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr and epoch_itr.epoch <= max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
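The get_device_capability check above gates the --fp16 hint on compute capability 7.0 (Volta), the first NVIDIA architecture with tensor cores. The check in isolation:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print('compute capability {}.{}'.format(major, minor))
    if major >= 7:
        print('| NOTICE: your device may support faster training with --fp16')
else:
    print('no CUDA device visible')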
Example #29
def main(args):
    utils.import_user_module(args)

    assert (
        args.max_tokens is not None or args.max_sentences is not None
    ), "Must specify batch size either with --max-tokens or --max-sentences"

    metrics.reset()

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)
        checkpoint_utils.verify_checkpoint_directory(args.jason_log_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info(
        "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
    )
    logger.info(
        "num. model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
        )
    )

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info(
        "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
    )
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.max_sentences
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        args,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    ##### begin jason #####
    updates_list = []
    train_ppl_list = []
    train_loss_list = []
    val_ppl_list = []
    val_loss_list = []
    train_uid_loss_list = []
    val_uid_loss_list = []
    log_writer = open(os.path.join(args.save_dir, 'train_logs.csv'), 'w')
    log_writer.write('updates,train_loss,train_ppl,val_loss,val_ppl,train_uid_loss,val_uid_loss\n')
    backup_writefile = os.path.join(args.jason_log_dir, 'train_logs_backup.csv')
    os.system(f'touch {backup_writefile}')
    os.system(f'echo "updates,train_loss,train_ppl,val_loss,val_ppl,train_uid_loss,val_uid_loss" >> {backup_writefile}')
    ##### end jason #####

    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop, train_stats, valid_stats = train(args, trainer, task, epoch_itr)

        ##### begin jason #####
        if train_stats and valid_stats: 
            updates_list.append(train_stats['num_updates'])
            train_loss_list.append(train_stats['loss'])
            train_ppl_list.append(train_stats['ppl'])
            val_loss_list.append(valid_stats['loss'])
            val_ppl_list.append(valid_stats['ppl'])
            if 'uid_loss' not in train_stats:
                train_stats['uid_loss'] = -1
                valid_stats['uid_loss'] = -1
            train_uid_loss_list.append(train_stats['uid_loss'])
            val_uid_loss_list.append(valid_stats['uid_loss'])
            log_line = f"{train_stats['num_updates']},{train_stats['loss']},{train_stats['ppl']},{valid_stats['loss']},{valid_stats['ppl']},{train_stats['uid_loss']},{valid_stats['uid_loss']}"
            log_writer.write(f"{log_line}\n")
            os.system(f'echo "{log_line}" >> {backup_writefile}')

            best_val_loss = min(val_loss_list)
            best_val_loss_idx = val_loss_list.index(best_val_loss)
            updates_to_best_val_loss = updates_list[best_val_loss_idx]
            train_loss_at_best_val_loss = train_loss_list[best_val_loss_idx]

            jasons_vis.plot_jasons_lineplot(
                x_list = updates_list,
                y_list_list = [train_loss_list, val_loss_list, train_uid_loss_list, val_uid_loss_list],
                y_labels_list = ['train', 'dev', 'train uid', 'dev uid'], 
                x_ax_label = "Updates",
                y_ax_label = "Loss",
                title = f"dev_l={best_val_loss} updates={updates_to_best_val_loss} train_l={train_loss_at_best_val_loss}",
                output_png_path = os.path.join(args.jason_log_dir, f"{args.jason_log_dir.split('/')[-1]}_loss.png"),
            )
            jasons_vis.plot_jasons_lineplot(
                x_list = updates_list,
                y_list_list = [train_ppl_list, val_ppl_list],
                y_labels_list = ['train', 'dev'], 
                x_ax_label = "Updates",
                y_ax_label = "Perplexity",
                title = f" best_val_ppl={best_val_loss} " + args.jason_log_dir[:20],
                output_png_path = os.path.join(args.jason_log_dir, f"{args.jason_log_dir.split('/')[-1]}_perplexity.png"),
            )
        ##### end jason #####

        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
Example #30
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        reload_dataset = ':' in getattr(args, 'data', '')
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch,
                                               load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
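The reload_dataset check (':' in args.data) detects fairseq's sharded-data convention: a colon-separated list of data directories, cycled one shard per epoch. The indexing idea in isolation (paths are placeholders):

data = '/data/shard0:/data/shard1:/data/shard2'
paths = data.split(':')
for epoch in range(1, 7):
    shard = paths[(epoch - 1) % len(paths)]   # round-robin over the shards
    print('epoch {}: training on {}'.format(epoch, shard))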
Example #31
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)

    mlperf_compliance.mlperf_log.LOGGER.propagate = False

    # framework = f'Pytorch NGC {os.environ["NVIDIA_PYTORCH_VERSION"]}'
    # mlperf_submission_log(
    #     benchmark=mlperf_compliance.constants.TRANSFORMER,
    #     framework=framework)

    mlperf_compliance.mlperf_log.setdefault(
        root_dir=os.path.dirname(os.path.abspath(__file__)),
        benchmark=mlperf_compliance.constants.TRANSFORMER,
        stack_offset=1,
        extra_print=False)

    mlperf_print(key=mlperf_compliance.constants.INIT_START,
                 log_all_ranks=True)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # preinit and warmup streams/groups for allreduce communicators
    allreduce_communicators = None
    if args.distributed_world_size > 1 and args.enable_parallel_backward_allred_opt:
        allreduce_groups = [
            torch.distributed.new_group()
            for _ in range(args.parallel_backward_allred_cuda_nstreams)
        ]
        allreduce_streams = [
            torch.cuda.Stream()
            for _ in range(args.parallel_backward_allred_cuda_nstreams)
        ]
        for group, stream in zip(allreduce_groups, allreduce_streams):
            with torch.cuda.stream(stream):
                torch.distributed.all_reduce(torch.cuda.FloatTensor(1),
                                             group=group)
        allreduce_communicators = (allreduce_groups, allreduce_streams)

    if args.max_tokens is None:
        args.max_tokens = 6000

    print(args)

    mlperf_print(key=mlperf_compliance.constants.GLOBAL_BATCH_SIZE,
                 value=args.max_tokens * args.distributed_world_size)
    mlperf_print(key=mlperf_compliance.constants.OPT_NAME,
                 value=args.optimizer)
    assert (len(args.lr) == 1)
    mlperf_print(key=mlperf_compliance.constants.OPT_BASE_LR,
                 value=args.lr[0] if len(args.lr) == 1 else args.lr)
    mlperf_print(key=mlperf_compliance.constants.OPT_LR_WARMUP_STEPS,
                 value=args.warmup_updates)
    assert (args.max_source_positions == args.max_target_positions)
    mlperf_print(key=mlperf_compliance.constants.MAX_SEQUENCE_LENGTH,
                 value=args.max_target_positions)
    mlperf_print(key=mlperf_compliance.constants.OPT_ADAM_BETA_1,
                 value=eval(args.adam_betas)[0])
    mlperf_print(key=mlperf_compliance.constants.OPT_ADAM_BETA_2,
                 value=eval(args.adam_betas)[1])
    mlperf_print(key=mlperf_compliance.constants.OPT_ADAM_EPSILON,
                 value=args.adam_eps)

    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    result = torch.cuda.cudart().cudaDeviceSetLimit(ctypes.c_int(0x05),
                                                    ctypes.c_int(128))
    result = torch.cuda.cudart().cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))

    #    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args,
                              task,
                              model,
                              criterion,
                              allreduce_communicators=allreduce_communicators)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )

        trainer = Trainer(args,
                          task,
                          model,
                          criterion,
                          allreduce_communicators=None)

    #if (args.online_eval or args.target_bleu) and not args.remove_bpe:
    #    args.remove_bpe='@@ '

    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()

    # Send a dummy batch to warm the caching allocator
    dummy_batch = language_pair_dataset.get_dummy_batch_isolated(
        args.max_tokens, max_positions, 8)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch if args.max_epoch >= 0 else math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    # mlperf compliance synchronization
    if args.distributed_world_size > 1:
        assert (torch.distributed.is_initialized())
        torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
        torch.cuda.synchronize()

    mlperf_print(key=mlperf_compliance.constants.INIT_STOP, sync=True)

    mlperf_print(key=mlperf_compliance.constants.RUN_START, sync=True)
    # second sync after RUN_START tag is printed.
    # this ensures no rank touches data until after RUN_START tag is printed.
    barrier()

    # Load dataset splits
    load_dataset_splits(task, ['train', 'test'])

    ctr = 0

    class DummyEpochBatchIterator:
        def __init__(self, epoch=0):
            self.epoch = epoch

    epoch_itr = DummyEpochBatchIterator(0)

    # Main training loop
    while (lr >= args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update
           and current_bleu < tgt_bleu):
        first_epoch = epoch_itr.epoch + 1
        mlperf_print(key=mlperf_compliance.constants.BLOCK_START,
                     metadata={
                         'first_epoch_num': first_epoch,
                         'epoch_count': 1
                     },
                     sync=True)
        mlperf_print(key=mlperf_compliance.constants.EPOCH_START,
                     metadata={'epoch_num': first_epoch},
                     sync=True)
        start = time.time()

        gc.disable()

        epoch_itr = data.EpochBatchIterator(
            dataset=task.dataset(args.train_subset),
            dataloader_num_workers=args.dataloader_num_workers,
            dataloader_pin_memory=args.enable_dataloader_pin_memory,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            epoch=epoch_itr.epoch if ctr != 0 else 0,
            bucket_growth_factor=args.bucket_growth_factor,
            seq_len_multiple=args.seq_len_multiple,
            batching_scheme=args.batching_scheme,
            batch_multiple_strategy=args.batch_multiple_strategy,
        )

        print("got epoch iterator", time.time() - start)

        # Load the latest checkpoint if one is available
        if ctr == 0:
            load_checkpoint(args, trainer, epoch_itr)

        # train for one epoch
        start = time.time()
        #exit(1)
        train(args, trainer, task, epoch_itr)
        print("epoch time ", time.time() - start)

        start = time.time()
        mlperf_print(key=mlperf_compliance.constants.EPOCH_STOP,
                     metadata={'epoch_num': first_epoch},
                     sync=True)

        #if epoch_itr.epoch % args.validate_interval == 0:
        #    valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # Eval BLEU score
        if args.online_eval or tgt_bleu is not math.inf:
            current_bleu = score(args, trainer, task, epoch_itr,
                                 args.gen_subset)
            mlperf_print(key=mlperf_compliance.tags.EVAL_ACCURACY,
                         value=str(current_bleu),
                         metadata={'epoch_num': first_epoch})

        gc.enable()

        # Only use first validation loss to update the learning rate
        #lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        #if epoch_itr.epoch % args.save_interval == 0:
        #    save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        ctr = ctr + 1
        print("validation and scoring ", time.time() - start)
        mlperf_print(key=mlperf_compliance.constants.BLOCK_STOP,
                     metadata={'first_epoch_num': first_epoch},
                     sync=True)

    train_meter.stop()
    status = 'success' if current_bleu >= tgt_bleu else 'aborted'
    mlperf_print(key=mlperf_compliance.constants.RUN_STOP,
                 metadata={'status': status})
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
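One detail worth noting in this last example: the epoch budget is normalized with an explicit conditional (args.max_epoch if args.max_epoch >= 0 else math.inf), whereas the earlier examples use args.max_update or math.inf. The two idioms treat sentinel values differently, since `or` only replaces falsy values:

import math

for sentinel in (0, -1):
    via_or = sentinel or math.inf
    via_cond = sentinel if sentinel >= 0 else math.inf
    print(sentinel, '->', via_or, via_cond)
# 0 -> inf 0     (`or` treats a zero budget as "unlimited")
# -1 -> -1 inf   (`or` lets a -1 sentinel through unchanged)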