Code example #1
 def load(cls,
          model_path: Path,
          spm_path: Path,
          device: str = None) -> 'ModelWrapper':
     if device is None:
         device = 'cuda' if torch.cuda.is_available() else 'cpu'
     with open(model_path, 'rb') as f:
         state = torch.load(f, map_location='cpu')
     model = MemTransformerLM(**state['model_params'])
     model.load_state_dict(state['state_dict'])
     vocab_params = state['vocab_params']
     vocab = Vocab.from_symbols(state['vocab'])
     sp_processor = spm.SentencePieceProcessor()
     sp_processor.Load(str(spm_path))
     return cls(model, vocab, sp_processor, device)
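
A hypothetical call to the load() classmethod above; only ModelWrapper and the load() signature come from the snippet, the checkpoint and SentencePiece paths are placeholder values.

from pathlib import Path

# Assumed paths; ModelWrapper and load() are taken from the example above.
wrapper = ModelWrapper.load(
    model_path=Path("checkpoints/model.pt"),
    spm_path=Path("checkpoints/spm.model"),
)  # device defaults to 'cuda' when available, otherwise 'cpu'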
Code example #2
 def __init__(self, model: MemTransformerLM, vocab: Vocab,
              sp_processor: spm.SentencePieceProcessor, device: str):
     self.vocab = vocab
     self.sp_processor = sp_processor
     self.device = device
     self.model = model.to(device=self.device)
     self.model.eval()
Code example #3
def widen_model(strategy, ratio, model: MemTransformerLM, optimizers, va_iter, epoch, step):
    optimizer, _ = optimizers
    # pre-expansion validation
    logging(f"evaluating before widening")
    # debug
    val_loss = evaluate(va_iter, model)
    log_val(val_loss, compute=model.compute, step=step)
    # infer example logits for reverse distillation
    # expansion
    logging(f"adding {ratio} layers before starting epoch {epoch} with method {strategy}")
    model.add_heads(ratio, strategy=strategy, function=initialization_func)
    # optimizer update
    for param, states in optimizer.state.items():
        if isinstance(param, nn.Parameter):
            states["exp_avg"] = expand_state(param, states["exp_avg"])
            states["exp_avg_sq"] = expand_state(param, states["exp_avg_sq"])
    # training loop for reverse distillation
    # post-expansion validation
    logging(f"reevaluating")
    val_loss = evaluate(va_iter, model)
    log_val(val_loss, step=step, compute=model.compute)
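
The optimizer update above calls an expand_state helper that is not shown here. A minimal sketch of what such a helper might look like, assuming the widened parameter keeps the old weights in its leading slice and newly added coordinates start with zeroed Adam statistics; the actual helper in the source repository may differ.

import torch

def expand_state(param: torch.nn.Parameter, state_tensor: torch.Tensor) -> torch.Tensor:
    # Return an Adam moment buffer whose shape matches the widened parameter.
    if state_tensor.shape == param.shape:
        return state_tensor
    expanded = torch.zeros_like(param.data)
    # Copy the old statistics into the leading slice; the new entries stay
    # zero, as they would for a freshly created parameter.
    old_slices = tuple(slice(0, s) for s in state_tensor.shape)
    expanded[old_slices] = state_tensor
    return expanded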
Code example #4
def expand_model(strategy, integration, integration_length, n_add, model: MemTransformerLM, optimizers, schedulers,
                 tr_iter, va_iter, epoch, step):
    optimizer, _ = optimizers
    scheduler, _ = schedulers
    if integration:
        if not integration_length or integration_length <= 0:
            warnings.warn(f"integration {integration} passed but integration_length is {integration_length}")
        else:
            logging(f"applying integration strategy {integration} with integration length {integration_length}")
    # pre-expansion validation
    logging(f"evaluating before expanding")
    val_loss = evaluate(va_iter, model)
    log_val(val_loss, step=step, compute=model.compute)
    # infer example logits for reverse distillation
    if "reverse_distil" in integration:
        first_logits = get_original_batches(model, tr_iter, integration_length)
    # expansion
    logging(f"adding {n_add} layers before starting epoch {epoch} with method {strategy}")
    new_layers = model.expand_layers(n_add, strategy=strategy, function=initialization_func)
    # optimizer update
    optimizer.add_param_group({'params': new_layers.parameters(), 'lr': optimizer.param_groups[0]["lr"],
                               'initial_lr': optimizer.param_groups[0]["initial_lr"]})
    scheduler.base_lrs.append(optimizer.param_groups[-1]["initial_lr"])
    # training loop for reverse distillation
    if "reverse_distil" in integration:
        fit_to_previous_model(model, new_layers, tr_iter, first_logits, integration)
    # freezing parameters for frozen restart, we do this afterwards else the new layers get copied also without grads
    if "freeze" in integration and integration_length > 0:
        for param_group in optimizer.param_groups[:-1]:
            for parameter in param_group['params']:
                parameter.requires_grad = False
        model.freeze_countdown = integration_length
    # post-expansion validation
    logging(f"reevaluating")
    val_loss = evaluate(va_iter, model)
    log_val(val_loss, step=step, compute=model.compute)
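
A self-contained toy illustration (module names assumed) of the optimizer and scheduler bookkeeping done in expand_model above: the new layers get their own parameter group, and the scheduler's base_lrs list has to grow by one entry so annealing covers every group.

from torch import nn, optim

old_layers = nn.Linear(4, 4)      # stands in for the existing model
new_layers = nn.Linear(4, 4)      # stands in for the freshly added layers

optimizer = optim.Adam(old_layers.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

optimizer.add_param_group({'params': new_layers.parameters(),
                           'lr': optimizer.param_groups[0]['lr'],
                           'initial_lr': optimizer.param_groups[0]['initial_lr']})
scheduler.base_lrs.append(optimizer.param_groups[-1]['initial_lr'])

print(len(optimizer.param_groups), len(scheduler.base_lrs))  # 2 2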
Code example #5
    def __init__(
        self,
        model: Union[str, Dict, MemTransformerLM],
        tokenizer_config: TokenizerWrapperConfig,
        device: torch.device,
        verbose: bool = False,
    ):
        if isinstance(model, Dict):
            model = MemTransformerLM(**model)
        elif isinstance(model, str):
            model = self._load_model(model, device)
        elif isinstance(model, MemTransformerLM):
            model = model
        else:
            raise TypeError(
                f"model has type {type(model)}, but only str, Dict or MemTransformerLM allowed"
            )

        super().__init__(model, tokenizer_config, device)

        self._model.reset_length(1, 0, self._model.mem_len)
        self._batch_len = self._model.mem_len
        self._verbose = verbose
Code example #6
def do_eval(args):
    assert args.ext_len >= 0, 'Extended context length must be no less than 0'

    def _evaluate(loader):
        total_len, total_loss = 0, 0.

        eval_mems = tuple()
        for i, (src, target, seq_len) in enumerate(loader):
            if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                break
            ret = mem_transformer(src, target, *eval_mems)
            loss, eval_mems = ret[0], ret[1:]
            eval_cur_loss = seq_len * loss.numpy()
            total_loss += eval_cur_loss
            total_len += seq_len
        return total_loss / total_len

    def _logger(loss):
        if args.dataset in ['enwik8', 'text8']:
            logger_info = "loss: %f, bpc: %f" % \
                          (loss, loss / np.log(2))
        else:
            logger_info = "loss: %f, ppl: %.2f" % \
                          (loss, np.exp(loss))
        return logger_info

    if not args.use_gpu:
        paddle.set_device("cpu")

    vocab = get_lm_vocab(args)
    eval_loader = get_lm_data_loader(args, vocab, "valid")
    test_loader = get_lm_data_loader(args, vocab, "test")

    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(args.ntokens,
                                       args.n_layer,
                                       args.n_head,
                                       args.d_model,
                                       args.d_head,
                                       args.d_inner_hid,
                                       args.dropout,
                                       args.attn_dropout,
                                       tie_weight=args.tie_weight,
                                       d_embed=args.d_model,
                                       div_val=args.div_val,
                                       tie_projs=tie_projs,
                                       normalize_before=args.normalize_before,
                                       tgt_len=args.tgt_len,
                                       ext_len=args.ext_len,
                                       mem_len=args.mem_len,
                                       cutoffs=cutoffs,
                                       same_length=args.same_length,
                                       attn_type=args.attn_type,
                                       clamp_len=args.clamp_len,
                                       sample_softmax=args.sample_softmax)

    assert args.init_from_params, (
        "Please set init_from_params to load the infer model.")

    model_dict = paddle.load(
        os.path.join(args.init_from_params, "mem_transformer.pdparams"))
    mem_transformer.load_dict(model_dict)

    logger.info(
        "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".
        format(args.eval_batch_size, args.tgt_len, args.ext_len, args.mem_len,
               args.clamp_len))

    mem_transformer.reset_length(args.tgt_len, args.ext_len, args.mem_len)

    test_loss = None
    valid_loss = None
    if args.mode == 'all':
        test_loss = _evaluate(test_loader)
        valid_loss = _evaluate(eval_loader)
    elif args.mode == 'valid':
        valid_loss = _evaluate(eval_loader)
    elif args.mode == 'test':
        test_loss = _evaluate(test_loader)

    logger_info = ''
    if valid_loss is not None:
        logger_info = logger_info + "validation loss: " + _logger(
            valid_loss) + " | "
    if test_loss is not None:
        logger_info = logger_info + "test loss: " + _logger(test_loss) + " | "
    logger.info(logger_info)
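
The _logger helper above reports bits-per-character for the character-level corpora (enwik8, text8) and perplexity otherwise. A small worked example of the same conversions, with an assumed loss value.

import numpy as np

loss = 0.95                  # assumed average cross-entropy in nats
bpc = loss / np.log(2)       # ~1.37 bits per character (enwik8/text8)
ppl = np.exp(loss)           # ~2.59 perplexity (wt103/lm1b)
print("loss: %f, bpc: %f, ppl: %.2f" % (loss, bpc, ppl))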
Code example #7
File: inference.py  Project: zaemyung/transformer-xl
    def load(cls, model_path, sp_model_path, device, print_stats=True):

        paramspath = os.path.join(model_path, 'params.json')
        with open(paramspath, 'r') as paramsf:
            xl_params = json.loads(paramsf.read())

        print(repr(xl_params))

        model = MemTransformerLM(
            xl_params['ntokens'],  # 50000,
            xl_params['n_layer'],  # 16,
            xl_params['n_head'],  # 10,
            xl_params['d_model'],  # 410,
            xl_params['d_head'],  # 41,
            xl_params['d_inner'],  # 2100,
            0.0,  # no dropout, 
            0.0,  # no dropatt,
            tie_weight=xl_params['tie_weight'],  # True, 
            d_embed=xl_params['d_embed'],  # 410, 
            div_val=xl_params['div_val'],  # 1,
            tie_projs=xl_params['tie_projs'],  # [False, True, True, True] 
            pre_lnorm=xl_params['pre_lnorm'],  # False, 
            tgt_len=xl_params['tgt_len'],  # 150,
            ext_len=xl_params['ext_len'],  # 0, 
            mem_len=xl_params['mem_len'],  # 150, 
            cutoffs=xl_params['cutoffs'],  # [3500, 7500, 37500],
            same_length=xl_params['same_length'],  # False,
            attn_type=xl_params['attn_type'],  # 0,
            clamp_len=xl_params['clamp_len'],  # -1, 
            sample_softmax=xl_params['sample_softmax'])  # -1

        state_dict_path = os.path.join(model_path, 'valid_state_dict.pt')
        print("loading weights %s ..." % state_dict_path)
        tensor_dict = torch.load(state_dict_path,
                                 map_location=torch.device(device))
        model.load_state_dict(tensor_dict)
        print("loading weights %s ... done." % state_dict_path)

        if print_stats:
            tensor_list = list(tensor_dict.items())
            for layer_tensor_name, tensor in tensor_list:
                print("Layer %-42s: %9d elements" %
                      (layer_tensor_name, torch.numel(tensor)))

            pytorch_total_params = sum(p.numel() for p in model.parameters())
            print("Total # params: %d" % pytorch_total_params)

        # with open(os.path.join(MODEL_PATH, 'model.pt'), 'rb') as f:
        #     model = torch.load(f)
        # model.apply(update_dropout)
        # model.apply(update_dropatt)

        para_model = model.to(device)

        # print ("loading model %s ... done." % MODEL_PATH)

        print("loading sp model from %s ..." % sp_model_path)
        sp_model = spm.SentencePieceProcessor()
        sp_model.load(sp_model_path)
        print("loading sp model from %s ... done." % sp_model_path)

        return cls(para_model, sp_model, device)
Code example #8
File: train.py  Project: Vasyka/DeepGQuad
        model = torch.load(f)
        model = model.float()
    model.apply(update_dropout)
    model.apply(update_dropatt)
else:
    model = MemTransformerLM(n_token_in,
                             n_token_out,
                             args.n_layer,
                             args.n_head,
                             args.d_model,
                             args.d_head,
                             args.d_inner,
                             args.dropout,
                             args.dropatt,
                             args.conv_size,
                             args.conv_emb,
                             args.pre_conv,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             tie_projs=tie_projs,
                             tgt_len=args.tgt_len,
                             mem_len=args.mem_len,
                             ext_ds=args.ext_ds,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             clamp_len=args.clamp_len)
    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
Code example #9
def main():
    args = parse_args()
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(args.local_rank,
                                                   nproc_per_node,
                                                   args.affinity)
        print(f'{args.local_rank}: thread affinity: {affinity}')

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(
        args.work_dir,
        args.dataset,
        args.append_dataset,
        args.append_time,
    )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = args.txtlog_file
    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
    )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')
        if args.batch_size % args.batch_chunk != 0:
            raise RuntimeError('Batch size needs to be divisible by '
                               'batch chunk')

    if args.profile:
        try:
            pyprof.init(enable_function_stack=True)
        except NameError:
            warnings.warn('Called pyprof.init() but pyprof is not available')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    logging.info(f'world size: {utils.distributed.get_world_size()}')

    if not args.no_env:
        log_env_info()

    register_ignoring_timeout_handler()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len

    tr_iter = corpus.get_iterator('train',
                                  args.batch_size,
                                  args.tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  mem_len=eval_mem_len,
                                  ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  mem_len=eval_mem_len,
                                  ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
    }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum(
        [p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params,
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    scaler = None
    if args.fp16:
        if args.amp == 'pytorch':
            scaler = torch.cuda.amp.GradScaler()
        elif args.amp == 'apex':
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.apex_amp_opt_level,
            )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz //
                                              args.batch_chunk,
                                              model,
                                              dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         max_step -
                                                         args.warmup_step,
                                                         eta_min=args.eta_min)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse,
                max_step - args.warmup_step,
                eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(optimizer_sparse,
                                                           lr_lambda=lr_lambda)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min,
        )
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    start_epoch = 1
    last_batch = 0
    last_iter = 0
    best_val_loss = None

    if args.restart:
        try:
            checkpoint = load_checkpoint(args.restart)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['scheduler_state'])
            if args.fp16:
                if args.amp == 'pytorch':
                    scaler.load_state_dict(checkpoint['amp_state'])
                elif args.amp == 'apex':
                    amp.load_state_dict(checkpoint['amp_state'])
            train_step = checkpoint['train_step']
            start_epoch = checkpoint['epoch']
            last_batch = checkpoint['batch']
            last_iter = checkpoint['last_iter']
            best_val_loss = checkpoint['best_val_loss']

            if train_step >= args.max_step:
                logging.info(
                    f'Loaded checkpoint after {train_step} steps, but '
                    f'this run was scheduled for a total of '
                    f'{args.max_step} steps, exiting')
                sys.exit(1)

            model.apply(functools.partial(update_dropout, args=args))
            model.apply(functools.partial(update_dropatt, args=args))
        except FileNotFoundError:
            logging.info(f'Could not load checkpoint from {args.restart}, '
                         f'starting training from random init')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)
    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
        with TimeoutHandler() as timeout_handler:
            try:
                for epoch in itertools.count(start=start_epoch):
                    if args.roll:
                        tr_iter.roll(seed=args.seed + epoch)
                    train_step, best_val_loss = train(
                        tr_iter, va_iter, model, para_model, model_config,
                        optimizer, optimizer_sparse, scheduler,
                        scheduler_sparse, scaler, vocab, epoch, last_batch,
                        last_iter, train_step, best_val_loss, meters,
                        timeout_handler, device, args)

                    last_batch = 0
                    last_iter = 0

                    if train_step == args.max_step:
                        logging.info('-' * 100)
                        logging.info('End of training')
                        break
            except KeyboardInterrupt:
                logging.info('-' * 100)
                logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and not args.no_eval and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
            test_loss = evaluate(te_iter, model, args)
            test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'
                .format(test_elapsed, test_loss, test_loss / math.log(2)))
        else:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'
                .format(test_elapsed, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

        summary.update({
            'test_elapsed': test_elapsed,
            'test_loss': test_loss,
        })

        if args.dataset in ['enwik8', 'text8']:
            summary['test_bits_per_character'] = test_loss / math.log(2)
        else:
            summary['test_perplexity'] = math.exp(test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(
        f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'valid_loss': best_val_loss,
        'valid_perplexity': val_perplexity,
    })
    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=val_perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['train_throughput'].avg)
    if not passed:
        sys.exit(1)
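
The inv_sqrt scheduler used in this script (and in the variant below) ramps the learning-rate multiplier linearly during warmup and then decays it as 1/sqrt(step). A self-contained sketch of that behaviour, assuming warmup_step=4 and a base lr of 0.1; the toy module and the printed values are illustrative only.

from torch import nn, optim

warmup_step = 4  # assumed value

def lr_lambda(step):
    # Return a multiplier instead of a learning rate.
    if step == 0 and warmup_step == 0:
        return 1.0
    return 1.0 / (step ** 0.5) if step > warmup_step else step / (warmup_step ** 1.5)

optimizer = optim.SGD(nn.Linear(2, 2).parameters(), lr=0.1)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
for step in range(1, 9):
    optimizer.step()
    scheduler.step()
    print(step, round(optimizer.param_groups[0]['lr'], 5))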
Code example #10
def main():
    args = parse_args()

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(
        args.work_dir,
        args.dataset,
        args.append_dataset,
        args.append_time,
    )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = f'log.log'
    log_file = os.path.join(args.work_dir, log_file)

    if args.debug:
        log_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
    )
    logging.info(args)

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed + utils.distributed.get_rank())
    torch.manual_seed(args.seed + utils.distributed.get_rank())

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    tr_iter = corpus.get_iterator('train',
                                  args.batch_size,
                                  args.tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
    }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum(
        [p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params,
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    if args.fp16:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2',
        )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(
            model,
            delay_allreduce=True,
        )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz //
                                              args.batch_chunk,
                                              model,
                                              dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         max_step,
                                                         eta_min=args.eta_min)
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step, eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min,
        )
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    best_val_loss = None

    if args.restart:
        checkpoint = load_checkpoint(args.restart)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['scheduler_state'])
        if args.fp16:
            amp.load_state_dict(checkpoint['amp_state'])
        train_step = checkpoint['train_step']
        best_val_loss = checkpoint['best_val_loss']

        model.apply(functools.partial(update_dropout, args=args))
        model.apply(functools.partial(update_dropatt, args=args))

    meters = {}
    warmup = args.mem_len // args.tgt_len + 1
    meters['train_throughput'] = AverageMeter(warmup=warmup)
    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    try:
        for epoch in itertools.count(start=1):
            if args.roll:
                tr_iter.roll()
            train_step, best_val_loss = train(tr_iter, va_iter, model,
                                              para_model, model_config,
                                              optimizer, optimizer_sparse,
                                              scheduler, scheduler_sparse,
                                              vocab, epoch, train_step,
                                              best_val_loss, meters, args)

            if train_step == args.max_step:
                logging.info('-' * 100)
                logging.info('End of training')
                break
    except KeyboardInterrupt:
        logging.info('-' * 100)
        logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'
                .format(time.time() - test_start_time, test_loss,
                        test_loss / math.log(2)))
        else:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'
                .format(time.time() - test_start_time, test_loss,
                        math.exp(test_loss)))
        logging.info('=' * 100)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(
        f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=val_perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['train_throughput'].avg)
    if not passed:
        sys.exit(1)
Code example #11
File: train.py  Project: rsvp-ai/segatron_aaai
                                         same_length=args.same_length,
                                         attn_type=args.attn_type,
                                         clamp_len=args.clamp_len,
                                         sample_softmax=args.sample_softmax)
    else:
        model = MemTransformerLM(ntokens,
                                 args.n_layer,
                                 args.n_head,
                                 args.d_model,
                                 args.d_head,
                                 args.d_inner,
                                 args.dropout,
                                 args.dropatt,
                                 tie_weight=args.tied,
                                 d_embed=args.d_embed,
                                 div_val=args.div_val,
                                 tie_projs=tie_projs,
                                 pre_lnorm=args.pre_lnorm,
                                 tgt_len=args.tgt_len,
                                 ext_len=args.ext_len,
                                 mem_len=args.mem_len,
                                 cutoffs=cutoffs,
                                 same_length=args.same_length,
                                 attn_type=args.attn_type,
                                 clamp_len=args.clamp_len,
                                 sample_softmax=args.sample_softmax)
    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
Code example #12
n_token_in = (4 + 4 * args.methylation)**args.merge_size
n_token_out = 2

print(args.coords[1:4])

model = MemTransformerLM(n_token_in,
                         n_token_out,
                         args.n_layer,
                         args.n_head,
                         args.d_model,
                         args.d_head,
                         args.d_inner,
                         args.dropout,
                         args.dropatt,
                         tie_weight=args.tied,
                         d_embed=args.d_embed,
                         custom_emb=None,
                         tie_projs=[False],
                         pre_lnorm=args.pre_lnorm,
                         tgt_len=args.tgt_len,
                         ext_len=args.ext_len,
                         mem_len=args.mem_len,
                         cutoffs=[],
                         same_length=args.same_length,
                         clamp_len=args.clamp_len,
                         sample_softmax=args.sample_softmax)

model.load_state_dict(torch.load(os.path.join(args.restart_dir, 'model.pt')))
model = model.to(device)

Code example #13
    mem_len=args.mem_len,
    cutoffs=cutoffs,
    same_length=args.same_length,
    attn_type=args.attn_type,
    clamp_len=args.clamp_len,
    sample_softmax=args.sample_softmax,
)
if args.restart:
    print(f'Restarting training from {args.restart_dir}')
    with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
        state = torch.load(f)
    if isinstance(state, MemTransformerLM):  # old format
        model = state
    else:
        model_params = state['model_params']
        model = MemTransformerLM(**model_params)
        model.load_state_dict(state['state_dict'])
        del state
    if not args.fp16:
        model = model.float()
    model.apply(update_dropout)
    model.apply(update_dropatt)
else:
    model = MemTransformerLM(**model_params)
    model.apply(weights_init)
    # ensure embedding init is not overridden by out_layer
    # in case of weight sharing
    model.word_emb.apply(weights_init)
args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
Code example #14
def train_ts(args):
    def build_scheduler(optimizers, args):
        optimizer, optimizer_sparse = optimizers
        scheduler_sparse = None

        if args.scheduler == "cosine":
            # here we do not set eta_min to lr_min to be backward compatible
            # because in previous versions eta_min is default to 0
            # rather than the default value of lr_min 1e-6
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, args.max_step,
                eta_min=args.eta_min)  # should use eta_min arg

        elif args.scheduler == "inv_sqrt":
            # originally used for Transformer (in Attention is all you need)
            def lr_lambda(step):
                # return a multiplier instead of a learning rate
                if step == 0 and args.warmup_step == 0:
                    return 1.0
                else:
                    return (1.0 /
                            (step**0.5) if step > args.warmup_step else step /
                            (args.warmup_step**1.5))

            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                    lr_lambda=lr_lambda)

        elif args.scheduler == "dev_perf":
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )

        elif args.scheduler == "constant":
            pass

        else:
            raise ValueError(f"scheduler type {args.scheduler} not recognized")

        return scheduler, scheduler_sparse

    ###############################################################################
    # Training code
    ###############################################################################
    def evaluate(eval_iter, model):

        # Turn on evaluation mode which disables dropout.
        model.eval()

        # debug
        # If the model does not use memory at all, make the ext_len longer.
        # Otherwise, make the mem_len longer and keep the ext_len the same.
        # if default_args.mem_len == 0:
        #     model.reset_length(default_args.eval_tgt_len,
        #                        default_args.ext_len + default_args.tgt_len -
        #                        default_args.eval_tgt_len, default_args.mem_len)
        # else:
        #     model.reset_length(default_args.eval_tgt_len,
        #                        default_args.ext_len, default_args.mem_len +
        #                       default_args.tgt_len - default_args.eval_tgt_len)

        # Evaluation
        total_len, total_loss = 0, 0.0
        with torch.no_grad():
            mems = tuple()
            for i, (data, target, seq_len) in enumerate(eval_iter):
                if i >= args.max_eval_steps > 0:
                    break
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.mean()
                total_loss += seq_len * loss.float().item()
                total_len += seq_len

        # Switch back to the training mode
        # model.reset_length(default_args.tgt_len, default_args.ext_len,
        # default_args.mem_len)
        model.train()

        return total_loss / total_len

    # reverse distillation util
    def get_original_batches(model, tr_iter, integration_length):
        model.eval()
        if args.batch_chunk > 1:
            mems = [None for _ in range(args.batch_chunk)]
            first_logits = [[] for _ in range(args.batch_chunk)]
        else:
            mems = None
            first_logits = []
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        with torch.no_grad():
            for batch, (data, target, seq_len) in enumerate(train_iter):
                if batch == integration_length:
                    break
                if args.batch_chunk > 1:
                    data_chunks = torch.chunk(data, args.batch_chunk, 1)
                    for i in range(args.batch_chunk):
                        data_i = data_chunks[i].contiguous()
                        logits, mems[i] = model._forward(data_i, mems=mems[i])
                        first_logits[i].append(logits.cpu())
                else:
                    logits, mems = model._forward(data, mems=mems)
                    first_logits.append(logits.cpu())
        return first_logits

    def build_optimizer(model, args, reload=False):
        optimizer_sparse = None
        if args.optim.lower() == "sgd":
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.mom)
        elif args.optim.lower() == "adam":
            optimizer = optim.Adam(model.parameters(), lr=args.lr)
        elif args.optim.lower() == "adagrad":
            optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        else:
            raise ValueError(f"optimizer type {args.optim} not recognized")

        if reload:
            if args.restart_from is not None:
                optim_name = f"optimizer_{args.restart_from}.pt"
            else:
                optim_name = "optimizer.pt"
            optim_file_name = os.path.join(args.restart_dir, optim_name)
            logging(f"reloading {optim_file_name}")
            if os.path.exists(os.path.join(args.restart_dir, optim_name)):
                with open(os.path.join(args.restart_dir, optim_name),
                          "rb") as optim_file:
                    opt_state_dict = torch.load(optim_file)
                    try:
                        optimizer.load_state_dict(opt_state_dict)
                    # in case the optimizer param groups aren't the same shape,
                    # merge them
                    except:
                        logging("merging optimizer param groups")
                        opt_state_dict["param_groups"][0]["params"] = [
                            param
                            for param_group in opt_state_dict["param_groups"]
                            for param in param_group["params"]
                        ]
                        opt_state_dict["param_groups"] = [
                            opt_state_dict["param_groups"][0]
                        ]
                        optimizer.load_state_dict(opt_state_dict)
            else:
                logging("Optimizer was not saved. Start from scratch.")

        return optimizer, optimizer_sparse

    def log_val(val_loss, step, compute):
        logging("-" * 100)
        log_str = ("| Eval {:3d} at step {:>8d} | time: {:5.2f}s "
                   "| valid loss {:5.2f}".format(
                       step // args.eval_interval,
                       step,
                       (time.time() - eval_start_time),
                       val_loss,
                   ))
        log_str += " | bpc {:9.5f}".format(val_loss / math.log(2))
        logging(log_str)
        logging("-" * 100)

    def epoch_loop(
        epoch,
        model,
        optimizers,
        schedulers,
    ):
        nonlocal train_step

        # Turn on training mode which enables dropout.
        if isinstance(model, nn.DataParallel):
            parent_model = model.module
        else:
            parent_model = model
        optimizer, optimizer_sparse = optimizers
        scheduler, scheduler_sparse = schedulers

        # global train_step, best_val_loss, eval_start_time, log_start_time
        train_losses = []
        model.train()
        if args.batch_chunk > 1:
            mems = [tuple() for _ in range(args.batch_chunk)]
        else:
            mems = tuple()
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter

        log_start_time = time.time()
        best_val_loss = float("Infinity")
        for batch, (data, target, seq_len) in enumerate(train_iter):
            model.zero_grad()
            if args.batch_chunk > 1:
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                target_chunks = torch.chunk(target, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    target_i = target_chunks[i].contiguous()
                    ret = model(data_i, target_i, *mems[i])
                    loss, mems[i] = ret[0], ret[1:]
                    loss = loss.float().mean().type_as(loss) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    train_losses.append(loss.float().item())
            else:
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.float().mean().type_as(loss)
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                train_losses.append(loss.float().item())

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            parent_model.compute += openai_compute(
                non_emb_param_count(parent_model, nseries), data.numel(), 1)

            # step-wise learning rate annealing
            train_step += 1
            parent_model.training_steps += 1

            # check for yet-to-thaw parameters
            if getattr(parent_model, "freeze_countdown", 0) > 0:
                parent_model.freeze_countdown -= 1

                # if this is the last step
                if parent_model.freeze_countdown == 0:
                    for parameter in parent_model.parameters():
                        parameter.requires_grad = True
                    logging("thawing all parameters")

            if args.scheduler in ["cosine", "constant", "dev_perf"]:
                # linear warmup stage
                if train_step < args.warmup_step:
                    curr_lr = args.lr * train_step / args.warmup_step
                    optimizer.param_groups[0]["lr"] = curr_lr
                else:
                    if args.scheduler == "cosine":
                        scheduler.step(train_step)
            elif args.scheduler == "inv_sqrt":
                scheduler.step(train_step)

            if train_step % args.log_interval == 0:
                cur_loss = np.mean(train_losses)
                elapsed = time.time() - log_start_time
                log_str = ("| epoch {:3d} step {:>8d} "
                           "| {:>6d} batches "
                           "| lr {:.3g} "
                           "| ms/batch {:5.2f} "
                           "| loss {:5.2f}".format(
                               epoch,
                               train_step,
                               batch + 1,
                               optimizer.param_groups[0]["lr"],
                               elapsed * 1000 / args.log_interval,
                               cur_loss,
                           ))

                log_str += " | bpc {:9.5f}".format(cur_loss / math.log(2))
                logging(log_str)

                train_losses = []
                log_start_time = time.time()

            if train_step % args.eval_interval == 0:
                val_loss = evaluate(va_iter, model)
                log_val(val_loss,
                        step=train_step,
                        compute=parent_model.compute)
                # Save the model if the validation loss is the best we've seen so
                # far.
                if not best_val_loss or val_loss < best_val_loss:
                    best_val_loss = val_loss
                    if not args.debug:
                        if args.fp16:
                            with open(
                                    os.path.join(args.work_dir,
                                                 "amp_checkpoint.pt"),
                                    "wb",
                            ) as f:
                                checkpoint = {
                                    "model": model.state_dict(),
                                    "optimizer": optimizer.state_dict(),
                                    "amp": amp.state_dict(),
                                }
                                torch.save(checkpoint, f)
                        else:
                            with open(os.path.join(args.work_dir, "model.pt"),
                                      "wb") as f:
                                torch.save(parent_model, f)
                            with open(
                                    os.path.join(args.work_dir,
                                                 "optimizer.pt"),
                                    "wb",
                            ) as f:
                                torch.save(optimizer.state_dict(), f)

                # dev-performance based learning rate annealing
                if args.scheduler == "dev_perf":
                    scheduler.step(val_loss)

                eval_start_time = time.time()

            if train_step == args.max_step:
                break

    def expand_model(
        strategy,
        integration,
        integration_length,
        n_add,
        model: MemTransformerLM,
        optimizers,
        schedulers,
        tr_iter,
        va_iter,
        epoch,
        step,
    ):
        optimizer, _ = optimizers
        scheduler, _ = schedulers
        if integration:
            if not integration_length or integration_length <= 0:
                warnings.warn(
                    f"integration {integration} passed but integration_length is {integration_length}"
                )
            else:
                logging(
                    f"applying integration strategy {integration} with integration length {integration_length}"
                )

        # pre-expansion validation
        logging(f"evaluating before expanding")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

        # infer example logits for reverse distillation
        if "reverse_distil" in integration:
            first_logits = get_original_batches(model, tr_iter,
                                                integration_length)

        # expansion
        logging(
            f"adding {n_add} layers before starting epoch {epoch} with method {strategy}"
        )
        new_layers = model.expand_layers(n_add,
                                         strategy=strategy,
                                         function=initialization_func)

        # optimizer update
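        # the new layers get their own parameter group at the current learning rate, and
        # the scheduler's base_lrs list is extended so annealing also covers this group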
        optimizer.add_param_group({
            "params": new_layers.parameters(),
            "lr": optimizer.param_groups[0]["lr"],
            "initial_lr": optimizer.param_groups[0]["initial_lr"],
        })
        scheduler.base_lrs.append(optimizer.param_groups[-1]["initial_lr"])

        # training loop for reverse distillation
        if "reverse_distil" in integration:
            fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                                  integration)

        # freezing parameters for frozen restart, we do this afterwards else the
        # new layers get copied also without grads
        if "freeze" in integration and integration_length > 0:
            for param_group in optimizer.param_groups[:-1]:
                for parameter in param_group["params"]:
                    parameter.requires_grad = False
            model.freeze_countdown = integration_length

        # post-expansion validation
        logging(f"reevaluating")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    def expand_state(param, state):
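        # Tile an Adam moment buffer so it matches a parameter that was widened by an
        # integer factor along each dimension; buffers whose parameter kept its shape
        # are returned unchanged.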
        if param.shape != state.shape:
            ratios = [
                param.shape[i] // state.shape[i]
                for i in range(len(param.shape))
            ]
            return state.repeat(*ratios)
        else:
            return state

    def widen_model(
        strategy,
        ratio,
        model: MemTransformerLM,
        optimizers,
        va_iter,
        epoch,
        step,
    ):
        optimizer, _ = optimizers

        # pre-expansion validation
        logging(f"evaluating before widening")

        # debug
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, compute=model.compute, step=step)

        # expansion
        logging(
            f"widening by ratio {ratio} before starting epoch {epoch} with method {strategy}"
        )
        model.add_heads(ratio, strategy=strategy, function=initialization_func)

        # optimizer update
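        # widening changes parameter shapes in place, so the Adam moment buffers
        # (exp_avg / exp_avg_sq) are tiled to the new shapes; otherwise optimizer.step()
        # would fail on a shape mismatch between parameters and their state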
        for param, states in optimizer.state.items():
            if isinstance(param, nn.Parameter):
                states["exp_avg"] = expand_state(param, states["exp_avg"])
                states["exp_avg_sq"] = expand_state(param,
                                                    states["exp_avg_sq"])

        # training loop for reverse distillation
        # post-expansion validation
        logging(f"reevaluating")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    # reverse distillation trainer
    def fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                              integration):
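        # Reverse distillation: regress the expanded model's pre-softmax outputs
        # (model._forward) onto the logits recorded before expansion (first_logits)
        # with an MSE loss; for a "partial" integration strategy only the newly added
        # layers are optimized, otherwise the whole model is.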
        mse_loss = torch.nn.MSELoss()
        if "partial" in integration:
            distil_optimizer, distil_optimizer_sparse = build_optimizer(
                new_layers, reload=False)
        else:
            distil_optimizer, distil_optimizer_sparse = build_optimizer(
                model, reload=False)
        if args.cuda and args.fp16:
            model, distil_optimizer = amp.initialize(model,
                                                     distil_optimizer,
                                                     opt_level=args.fp16)

        model.train()
        if args.batch_chunk > 1:
            mems = [None for _ in range(args.batch_chunk)]
        else:
            mems = None
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        for batch, (data, _, _) in enumerate(train_iter):
            if batch == len(first_logits):
                break
            model.zero_grad()
            if args.batch_chunk > 1:
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    logits, mems[i] = model._forward(data_i, mems=mems[i])
                    target_logits = first_logits[i][batch].to(logits.device)
                    loss = mse_loss(logits, target_logits) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss,
                                            distil_optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
            else:
                logits, mems = model._forward(data, mems=mems)
                target_logits = first_logits[batch].to(logits.device)
                loss = mse_loss(logits, target_logits)
                if args.fp16:
                    with amp.scale_loss(loss, distil_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(distil_optimizer), args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            distil_optimizer.step()

    ###################################################################################
    #
    # main()
    #
    args.tied = not args.not_tied

    if args.d_embed < 0:
        args.d_embed = args.d_model

    # Validate `--fp16` option
    if args.fp16:
        if not args.cuda:
            print("WARNING: --fp16 requires --cuda, ignoring --fp16 option")
            args.fp16 = False
        else:
            try:
                from apex import amp

                if args.fp16 == "O1":
                    amp.register_half_function(torch, "einsum")
            except Exception:
                print("WARNING: apex not installed, ignoring --fp16 option")
                args.fp16 = False

    device = torch.device("cuda" if args.cuda else "cpu")

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print(
                "WARNING: You have a CUDA device, so you should probably run "
                "with --cuda ")
        else:
            torch.cuda.manual_seed_all(args.seed)

    ############################################################################
    # Logging
    ############################################################################

    assert args.ext_len >= 0, "extended context length must be non-negative"
    assert args.d_batch % args.batch_chunk == 0

    args.work_dir = "{}-{}".format(args.work_dir, args.dataset)
    args.work_dir = os.path.join(args.work_dir, time.strftime("%Y%m%d-%H%M%S"))
    logging = create_exp_dir(
        args.work_dir,
        scripts_to_save=["train_ts.py", "mem_transformer.py"],
        debug=args.debug,
    )

    ############################################################################
    # Load data
    ############################################################################
    time_series = get_time_series(args.datadir, args.dataset)
    nseries = len(time_series.vocab)
    args.n_token = nseries

    eval_batch_size = 20
    tr_iter = time_series.get_iterator(
        "train",
        args.d_batch,
        args.tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    va_iter = time_series.get_iterator(
        "valid",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    te_iter = time_series.get_iterator(
        "test",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )

    cutoffs, tie_projs = [], [False]

    ############################################################################
    # Define model
    ############################################################################

    initialization_func = partial(
        weights_init,
        init=args.init,
        init_range=args.init_range,
        init_std=args.init_std,
        proj_init_std=args.proj_init_std,
    )

    if args.restart and not args.fp16:
        if args.restart_from is not None:
            model_name = f"model_{args.restart_from}.pt"
        else:
            model_name = "model.pt"
        model_file_name = os.path.join(args.restart_dir, model_name)
        logging(f"reloading {model_file_name}")
        with open(model_file_name, "rb") as f:
            model = torch.load(f)
        # backwards compatibility with older saves
        if isinstance(model, nn.DataParallel):
            model = model.module
        model.backward_compatible(tie_weight=args.tied, tie_projs=tie_projs)
        if not args.fp16:
            model = model.float()
        model.apply(update_dropout)
        model.apply(update_dropatt)

    else:
        model = MemTransformerLM(
            nseries,
            args.n_layer,
            args.n_head,
            args.d_model,
            args.d_head,
            args.d_inner,
            args.dropout,
            args.dropatt,
            tie_weight=args.tied,
            d_embed=args.d_embed,
            div_val=args.div_val,
            tie_projs=tie_projs,
            pre_lnorm=args.pre_lnorm,
            tgt_len=args.tgt_len,
            ext_len=args.ext_len,
            mem_len=args.mem_len,
            cutoffs=cutoffs,
            same_length=args.same_length,
            clamp_len=args.clamp_len,
        )
        model.apply(initialization_func)

        # debug
        # model.word_emb.apply(initialization_func)
        # ensure embedding init is not overridden by out_layer in case of
        # weight sharing
    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = non_emb_param_count(model, nseries)

    logging("=" * 100)
    for k, v in args.__dict__.items():
        logging("    - {} : {}".format(k, v))
    logging("=" * 100)
    logging("#params = {}".format(args.n_all_param))
    logging("#non emb params = {}".format(args.n_nonemb_param))

    para_model = parallelize_model(model, args)
    optimizers = build_optimizer(para_model,
                                 args,
                                 reload=args.restart and not args.fp16)
    optimizer, optimizer_sparse = optimizers
    schedulers = build_scheduler(optimizers, args)
    scheduler, scheduler_sparse = schedulers

    if args.cuda and args.fp16:
        para_model, optimizer = amp.initialize(para_model,
                                               optimizer,
                                               opt_level=args.fp16)

        if args.restart:
            if args.restart_from is not None:
                checkpoint_name = f"amp_checkpoint_{args.restart_from}.pt"
            else:
                checkpoint_name = "amp_checkpoint.pt"
            with open(os.path.join(args.work_dir, checkpoint_name), "rb") as f:
                checkpoint = torch.load(f)
                model.load_state_dict(checkpoint["model"])
                optimizer.load_state_dict(checkpoint["optimizer"])
                amp.load_state_dict(checkpoint["amp"])

    ############################################################################
    # Training loop
    ############################################################################

    # Loop over epochs.
    if args.reset_lr:
        # start the LR schedule from scratch: train_step restarts at 0 and drives only
        # the new schedule, independently of the model's stored training_steps
        train_step = 0
        optimizer.defaults["lr"] = args.lr
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.lr
            param_group["initial_lr"] = args.lr
        scheduler.base_lrs = [args.lr] * len(scheduler.base_lrs)
    else:
        train_step = model.training_steps

    best_val_loss = None

    # Reload previous step number in case of args.restart
    if train_step > 0:
        logging(f"restarting from step {train_step}")

    log_start_time = time.time()
    eval_start_time = time.time()

    def run_training():
        nonlocal train_step

        for epoch in itertools.count(start=first_epoch):
            # we check before the training loop; expanding at epoch 0 means
            # before training (for debug purposes)
            if args.expand and str(epoch - 1) in args.expansion_dict:
                n_add = int(args.expansion_dict[str(epoch - 1)])
                expand_model(
                    args.expand,
                    args.integration,
                    args.integration_length,
                    n_add,
                    model,
                    optimizers,
                    schedulers,
                    tr_iter,
                    va_iter,
                    epoch,
                    train_step,
                )
            if args.widen and str(epoch - 1) in args.widen_dict:
                ratio = int(args.widen_dict[str(epoch - 1)])
                widen_model(
                    args.widen,
                    ratio,
                    model,
                    optimizers,
                    va_iter,
                    epoch,
                    train_step,
                )
            epoch_loop(epoch, para_model, optimizers, schedulers)
            if train_step >= args.max_step:
                logging("-" * 100)
                logging("End of training")
                break
            if not args.debug and args.log_first_epochs:
                if epoch <= args.log_first_epochs:
                    logging(f"saving model at the end of epoch {epoch}")
                    if args.fp16:
                        with open(
                                os.path.join(args.work_dir,
                                             f"amp_checkpoint_{epoch}.pt"),
                                "wb",
                        ) as f:
                            checkpoint = {
                                "model": model.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "amp": amp.state_dict(),
                            }
                            torch.save(checkpoint, f)
                    else:
                        with open(
                                os.path.join(args.work_dir,
                                             f"model_{epoch}.pt"),
                                "wb",
                        ) as f:
                            torch.save(model, f)
                        with open(
                                os.path.join(args.work_dir,
                                             f"optimizer_{epoch}.pt"),
                                "wb",
                        ) as f:
                            torch.save(optimizer.state_dict(), f)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        if args.restart_from:
            first_epoch = args.restart_from + 1
            print(f"restarting from epoch {first_epoch}")
        else:
            first_epoch = 1
        run_training()

    except KeyboardInterrupt:
        logging("-" * 100)
        logging("Exiting from training early")

    # Load the best model.
    if args.fp16:
        with open(os.path.join(args.work_dir, "amp_checkpoint.pt"), "rb") as f:
            checkpoint = torch.load(f)
            model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            amp.load_state_dict(checkpoint["amp"])
    else:
        with open(os.path.join(args.work_dir, "model.pt"), "rb") as f:
            model = torch.load(f)
    para_model = model.to(device)

    # Run on test data.
    test_loss = evaluate(te_iter, para_model)
    logging("=" * 100)
    logging("| End of training | test loss {:5.2f} | test bpc {:9.5f}".format(
        test_loss, test_loss / math.log(2)))
    logging("=" * 100)
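
For reference, here is a minimal standalone sketch (not part of the script above; it only uses torch, and the shapes are made up) of what the expand_state helper does to an Adam moment buffer once a parameter has been widened by an integer factor along each dimension:

import torch

old_exp_avg = torch.randn(4, 16)                     # moment buffer saved for the old (4, 16) weight
new_param = torch.nn.Parameter(torch.zeros(8, 16))   # the same weight after widening

# same logic as expand_state: tile the old statistics to the new shape
ratios = [new_param.shape[i] // old_exp_avg.shape[i] for i in range(old_exp_avg.dim())]
tiled = old_exp_avg.repeat(*ratios)                  # -> shape (8, 16)
assert tiled.shape == new_param.shape
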
Code Example #15
0
File: train.py  Project: wbj0110/models
def do_train(args):
    if args.use_gpu:
        rank = dist.get_rank()
        trainer_count = dist.get_world_size()
    else:
        rank = 0
        trainer_count = 1
        paddle.set_device("cpu")

    if trainer_count > 1:
        dist.init_parallel_env()

    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    vocab = get_lm_vocab(args)
    train_loader = get_lm_data_loader(args, vocab, "train")
    eval_loader = get_lm_data_loader(args, vocab, "valid")

    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(args.ntokens,
                                       args.n_layer,
                                       args.n_head,
                                       args.d_model,
                                       args.d_head,
                                       args.d_inner_hid,
                                       args.dropout,
                                       args.attn_dropout,
                                       tie_weight=args.tie_weight,
                                       d_embed=args.d_model,
                                       div_val=args.div_val,
                                       tie_projs=tie_projs,
                                       normalize_before=args.normalize_before,
                                       tgt_len=args.tgt_len,
                                       ext_len=args.ext_len,
                                       mem_len=args.mem_len,
                                       cutoffs=cutoffs,
                                       same_length=args.same_length,
                                       attn_type=args.attn_type,
                                       clamp_len=args.clamp_len,
                                       sample_softmax=args.sample_softmax)

    if args.scheduler == 'cosine':
        scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
            learning_rate=args.learning_rate,
            T_max=args.max_step,
            eta_min=args.eta_min)
    elif args.scheduler == 'noam':
        scheduler = paddle.optimizer.lr.NoamDecay(
            d_model=args.d_model,
            warmup_steps=args.warmup_steps,
            learning_rate=args.learning_rate)
    elif args.scheduler == 'dev_perf':
        # fluid api
        scheduler = paddle.fluid.dygraph.ReduceLROnPlateau(
            learning_rate=args.learning_rate,
            decay_rate=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min)
    elif args.scheduler == 'constant':
        scheduler = args.learning_rate

    clip = paddle.nn.ClipGradByGlobalNorm(args.clip)
    if args.optim.lower() == 'momentum':
        optimizer = paddle.optimizer.Momentum(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            momentum=args.mom,
            grad_clip=clip)
    elif args.optim.lower() == 'adam':
        optimizer = paddle.optimizer.Adam(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=eval(args.eps),
            grad_clip=clip)
    elif args.optim.lower() == 'adagrad':
        optimizer = paddle.optimizer.Adagrad(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            grad_clip=clip)

    # Init from a checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        model_dict = paddle.load(
            os.path.join(args.init_from_checkpoint,
                         "mem_transformer.pdparams"))
        opt_dict = paddle.load(
            os.path.join(args.init_from_checkpoint, "mem_transformer.pdopt"))
        mem_transformer.set_state_dict(model_dict)
        optimizer.set_state_dict(opt_dict)
        print("loaded from checkpoint.")
    # Init from a pretrained model, to better solve the current task
    if args.init_from_pretrain_model:
        model_dict = paddle.load(
            os.path.join(args.init_from_pretrain_model,
                         "mem_transformer.pdparams"))
        mem_transformer.set_state_dict(model_dict)
        print("loaded from pre-trained model.")

    if trainer_count > 1:
        mem_transformer = paddle.DataParallel(mem_transformer)

    step_idx = 0
    train_loss = 0.0

    log_start_time = time.time()

    for pass_id in range(args.epoch):
        batch_id = 0

        mems = tuple()
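        # Transformer-XL memory (cached hidden states from previous segments); it is
        # threaded through consecutive batches and reset at the start of each epoch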
        for input_data in train_loader:
            (src, target, seq_len) = input_data
            ret = mem_transformer(src, target, *mems)
            loss = ret[0]
            mems = ret[1:]
            train_loss += loss.numpy()

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            if step_idx > 0 and step_idx % args.print_step == 0 and rank == 0:
                cur_loss = train_loss / args.print_step
                elapsed = time.time() - log_start_time
                if args.scheduler == "constant":
                    lr = optimizer.get_lr()
                else:
                    lr = scheduler.get_lr()
                logger_info = "step_idx: %d, epoch: %d, batch: %d, learning rate: %.8f, " \
                              "speed: %f ms/batch, loss: %f" % \
                              (step_idx, pass_id, batch_id, lr,
                               elapsed * 1000.0 / args.print_step, cur_loss)
                if args.dataset in ["enwik8", "text8"]:
                    logger_info = logger_info + ", bpc: %f" % (cur_loss /
                                                               np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (
                        np.exp(cur_loss))

                logger.info(logger_info)
                train_loss = 0.0
                log_start_time = time.time()

            if step_idx % args.save_step == 0 and step_idx != 0:
                # Do validation.
                mem_transformer.eval()

                # TODO(FrostML): simplify this.
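                # evaluation uses a different target length, so ext_len (when no memory
                # is used) or mem_len is enlarged by tgt_len - eval_tgt_len to keep the
                # total attention context per position unchanged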
                if args.mem_len == 0:
                    if dist.get_world_size() == 1:
                        mem_transformer.reset_length(tgt_len=args.eval_tgt_len,
                                                     ext_len=args.ext_len +
                                                     args.tgt_len -
                                                     args.eval_tgt_len,
                                                     mem_len=args.mem_len)
                    else:
                        mem_transformer._layers.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len + args.tgt_len -
                            args.eval_tgt_len,
                            mem_len=args.mem_len)
                else:
                    if dist.get_world_size() == 1:
                        mem_transformer.reset_length(tgt_len=args.eval_tgt_len,
                                                     ext_len=args.ext_len,
                                                     mem_len=args.mem_len +
                                                     args.tgt_len -
                                                     args.eval_tgt_len)
                    else:
                        mem_transformer._layers.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len,
                            mem_len=args.mem_len + args.tgt_len -
                            args.eval_tgt_len)

                total_len, total_loss = 0, 0.

                eval_mems = tuple()
                with paddle.no_grad():
                    for i, (src, target, seq_len) in enumerate(eval_loader):
                        if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                            break
                        ret = mem_transformer(src, target, *eval_mems)
                        loss, eval_mems = ret[0], ret[1:]
                        seq_len = seq_len.numpy()
                        eval_cur_loss = seq_len * loss.numpy()
                        total_loss += eval_cur_loss
                        total_len += seq_len
                    eval_loss = total_loss / total_len

                logger_info = "Validation, step_idx: %d, validation loss: %f" % \
                            (step_idx, eval_loss)
                if args.dataset in ['enwik8', 'text8']:
                    logger_info = logger_info + ", bpc: %f" % (eval_loss /
                                                               np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (
                        np.exp(eval_loss))
                logger.info(logger_info)

                if args.save_model and rank == 0:
                    model_dir = os.path.join(
                        args.save_model,
                        "step_" + str(step_idx) + "_" + str(eval_loss))
                    if not os.path.exists(model_dir):
                        os.makedirs(model_dir)
                    paddle.save(
                        mem_transformer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdparams"))
                    paddle.save(
                        optimizer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdopt"))

                if args.scheduler == 'dev_perf':
                    scheduler.step(eval_loss)

                # TODO(FrostML): simplify this.
                if dist.get_world_size() == 1:
                    mem_transformer.reset_length(tgt_len=args.tgt_len,
                                                 ext_len=args.ext_len,
                                                 mem_len=args.mem_len)
                else:
                    mem_transformer._layers.reset_length(tgt_len=args.tgt_len,
                                                         ext_len=args.ext_len,
                                                         mem_len=args.mem_len)

                mem_transformer.train()

            step_idx += 1
            batch_id += 1
            if args.scheduler in ['cosine', 'dev_perf']:
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    scheduler.base_lr = curr_lr
                else:
                    if args.scheduler == 'cosine':
                        scheduler.step()
            elif args.scheduler == 'constant':
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    optimizer.set_lr(curr_lr)
            elif args.scheduler == 'noam':
                scheduler.step()
        if step_idx >= args.max_step:
            break

    if args.save_model and rank == 0:
        model_dir = os.path.join(args.save_model, "step_final")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        paddle.save(mem_transformer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdparams"))
        paddle.save(optimizer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdopt"))
Code Example #16
0
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(f"Distributed: success ({args.local_rank}/{dist.get_world_size()})")

    if args.load_state_fn:
        g.state = load_state(args.load_state_fn)
        g.logger.info(f"Restoring training from {args.load_state_fn}")
    else:
        g.logger.info("creating new model")
        g.state = TrainState(args)

        g.state.model = MemTransformerLM(g.ntokens, args.n_layer, args.n_head, args.d_model,
                                         args.d_head, args.d_inner, args.dropout, args.dropatt,
                                         tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
                                         tie_projs=g.tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
                                         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=g.cutoffs,
                                         same_length=args.same_length, attn_type=args.attn_type,
                                         clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
        if args.checkpoint:
            util.restore_from_checkpoint(g.state.model, checkpoint_fn=args.checkpoint)
        else:
            g.state.model.apply(weights_init)
            g.state.model.word_emb.apply(
                weights_init)  # ensure embedding init is not overridden by out_layer in case of weight sharing
        g.state.model.to(g.device)
        optimizer_setup(g.state)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if not g.args.load_state_fn:
        if args.scheduler == 'cosine':
            # Divide by 1e6 for numerical stability.
            g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_tokens // 1e6,
                                                                     eta_min=args.eta_min)
        elif args.scheduler == 'finder':
            g.state.scheduler: LRFinder = LRFinder(optimizer, args.max_tokens, init_value=args.lr / 1e3)
        else:
            assert args.scheduler == 'constant'
            g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)

    g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.
Code Example #17
0
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(
            backend=args.dist_backend,
            #init_method=args.dist_url,
            #world_size=util.get_world_size()
        )
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    g.logger.info("creating new model")
    g.state = TrainState(args)
    g.state.model = MemTransformerLM(g.ntokens,
                                     args.n_layer,
                                     args.n_head,
                                     args.d_model,
                                     args.d_head,
                                     args.d_inner,
                                     args.dropout,
                                     args.dropatt,
                                     tie_weight=args.tied,
                                     d_embed=args.d_embed,
                                     div_val=args.div_val,
                                     tie_projs=g.tie_projs,
                                     pre_lnorm=args.pre_lnorm,
                                     tgt_len=args.tgt_len,
                                     ext_len=args.ext_len,
                                     mem_len=args.mem_len,
                                     cutoffs=g.cutoffs,
                                     same_length=args.same_length,
                                     attn_type=args.attn_type,
                                     clamp_len=args.clamp_len,
                                     sample_softmax=args.sample_softmax,
                                     freeze_below=args.freeze_below)
    g.state.model.to(g.device)
    optimizer_setup(g.state)
    if args.checkpoint:
        if args.checkpoint_secondary:
            g.logger.info(f"restoring extra checkpoint")
            util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                         args.checkpoint_secondary,
                                         args.optim_state_dict)
        g.logger.info(f"Restoring model from {args.checkpoint}" +
                      f" and optimizer from {args.optim_state_dict}" if args.
                      optim_state_dict else "")
        util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                     args.checkpoint, args.optim_state_dict)

    else:
        g.state.model.apply(weights_init)
        # ensure embedding init is not overridden by out_layer in case of weight sharing
        g.state.model.word_emb.apply(weights_init)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    if g.state.args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=g.state.args.static_loss_scale,
            dynamic_loss_scale=g.state.args.dynamic_loss_scale,
            dynamic_loss_args={'init_scale': 2**16},
            verbose=False)

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        g.state.scheduler: LRFinder = LRFinder(optimizer,
                                               args.max_tokens,
                                               init_value=args.lr / 1e3)
    else:
        assert args.scheduler == 'constant'
        g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)

    g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
            model.train()

            log_tb('sizes/batch_size', args.batch_size)
            log_tb('sizes/seq_size', args.tgt_len)

            if g.state.partial_epoch:
                # reuse previously loaded tr_iter and states
                assert g.state.tr_iter is not None
                assert g.state.mems is not None
            else:
                g.state.tr_iter = g.corpus.get_dist_iterator(
                    'train',
                    rank=util.get_global_rank(),
                    max_rank=util.get_world_size(),
                    bsz=args.batch_size,
                    bptt=args.tgt_len,
                    device=g.device,
                    ext_len=args.ext_len,
                    skip_files=g.args.skip_files)
                g.state.mems = tuple()
            g.state.last_epoch = epoch

            log_start_time = time.time()
            tokens_per_epoch = 0
            for batch, (data, target, seq_len) in enumerate(g.state.tr_iter):
                # assert seq_len == data.shape[0]
                # for i in range(1, data.shape[0]):
                #     assert torch.all(torch.eq(data[i], target[i - 1]))
                #     break

                # print(g.state.token_count, data)

                if g.state.train_step % args.eval_interval == 0:
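                    # periodic validation; the 'val_short-mem-*' runs appear to reset the
                    # XL memory every 1-3 segments to probe how much the model relies on
                    # long-range context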
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-1',
                                     generate_text=False,
                                     reset_mems_interval=1)
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-2',
                                     generate_text=False,
                                     reset_mems_interval=2)
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-3',
                                     generate_text=False,
                                     reset_mems_interval=3)
                    evaluate_and_log(model, g.va_iter, 'val')
                    if g.va_custom_iter:
                        evaluate_and_log(g.state.model,
                                         g.va_custom_iter,
                                         g.args.valid_custom,
                                         generate_text=False)

                batch_total = torch.tensor(data.shape[1]).to(g.device)
                if args.local:  # TODO(y): factor out (need way to see if dist was inited)
                    batch_total = batch_total.sum()
                else:
                    batch_total = util.dist_sum_tensor(
                        batch_total)  # global batch size
                batch_total = util.toscalar(batch_total)

                should_log = (g.state.train_step < args.verbose_log_steps) or \
                             (g.state.train_step + 1) % args.log_interval == 0

                model.zero_grad()

                ret = model(data, target, *g.state.mems)
                loss, g.state.mems = ret[0], ret[1:]

                loss: torch.Tensor = loss.float().mean().type_as(loss)
                with timeit('backwards', noop=not should_log):
                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()
                loss0 = util.toscalar(loss)
                util.record('loss', loss0)

                util.record('params', torch.sum(util.flat_param(model)).item())
                losses.append(loss0)
                accumulated_loss += loss0

                if args.fp16:
                    optimizer.clip_master_grads(args.clip)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.clip)

                # step-wise learning rate annealing
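                # the learning rate ramps linearly from 0 to args.lr over the first
                # warmup_tokens tokens; after warmup the cosine schedule is stepped in
                # units of millions of tokens (matching the T_max passed to
                # CosineAnnealingLR above)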
                if hasattr(optimizer, 'overflow') and optimizer.overflow:
                    g.logger.info("skipped iteration")
                else:
                    if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                        # linear warmup stage
                        if g.state.token_count < args.warmup_tokens:
                            curr_lr = args.lr * float(
                                g.state.token_count) / args.warmup_tokens
                            optimizer.param_groups[0]['lr'] = curr_lr
                        elif args.scheduler == 'cosine':
                            # Divide by 1e6 for numerical stability.
                            g.state.scheduler.step(g.state.token_count //
                                                   1000 // 1000)
                    else:
                        g.state.scheduler.step(g.state.token_count)

                optimizer.step()
                g.state.train_step += 1

                consumed_tokens = data.shape[0] * data.shape[1]
                world_size = int(os.environ.get("WORLD_SIZE", "8"))
                if world_size > 8:  # correction factor for multiple machines
                    consumed_tokens = consumed_tokens * (world_size // 8)
                tokens_per_epoch += consumed_tokens
                g.state.token_count += consumed_tokens
                g.token_count = g.state.token_count
                if g.state.token_count >= args.max_tokens:
                    g.state.partial_epoch = True
                    raise StopIteration  # break out of parent train loop

                if should_log:
                    elapsed_time = time.time() - log_start_time
                    elapsed_steps = g.state.train_step - g.state.last_log_step

                    # compute average loss over last logging interval
                    cur_loss = accumulated_loss / elapsed_steps
                    cur_loss_mean = util.dist_mean(cur_loss)
                    log_str = f'| epoch {epoch:3d} step {g.state.train_step:>8d} ' \
                              f'| {batch:>6d} batches ' \
                              f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                              f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                              f'| loss {cur_loss:5.2f}'
                    if args.dataset in ['enwik8', 'text8']:
                        log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
                    else:
                        log_str += f' | ppl {math.exp(cur_loss):9.3f}'
                    g.logger.info(log_str)
                    log_tb('learning/epoch', epoch)
                    log_tb('_loss', cur_loss_mean)  # the most important thing
                    log_tb('learning/loss', cur_loss_mean)
                    log_tb('learning/ppl', math.exp(cur_loss_mean))

                    # currently step timings are not synchronized in multi-machine
                    # case (see #4). Can add torch.distributed.barrier() to get
                    # more accurate timings, but this may add slowness.
                    log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
                    current_lr = optimizer.param_groups[0]['lr']

                    log_tb('learning/lr', current_lr)

                    # 32 is the "canonical" batch size
                    linear_scaling_factor = batch_total / 32  # TODO(y): merge logic from master
                    log_tb('learning/base_lr',
                           current_lr / linear_scaling_factor)
                    if args.optim == 'lamb':
                        log_lamb_rs(optimizer, g.event_writer,
                                    g.state.token_count)

                    time_per_batch = elapsed_time / elapsed_steps
                    time_per_sample = time_per_batch / args.batch_size
                    time_per_token = time_per_sample / args.tgt_len

                    log_tb('times/batches_per_sec', 1 / time_per_batch)
                    log_tb('times/samples_per_sec', 1 / time_per_sample)
                    log_tb('times/tokens_per_sec', 1 / time_per_token)

                    if str(g.device) == 'cuda':
                        log_tb("memory/allocated_gb",
                               torch.cuda.memory_allocated() / 1e9)
                        log_tb("memory/max_allocated_gb",
                               torch.cuda.max_memory_allocated() / 1e9)
                        log_tb("memory/cached_gb",
                               torch.cuda.memory_cached() / 1e9)
                        log_tb("memory/max_cached_gb",
                               torch.cuda.max_memory_cached() / 1e9)

                    accumulated_loss = 0
                    log_start_time = time.time()
                    g.state.last_log_step = g.state.train_step

            if args.checkpoint_each_epoch:
                g.logger.info(f'Saving checkpoint for epoch {epoch}')
                util.dist_save_checkpoint(model,
                                          optimizer,
                                          args.logdir,
                                          suffix=f'{epoch}')
            if tokens_per_epoch == 0:
                g.logger.info("Zero tokens in last epoch, breaking")
                break

            g.state.partial_epoch = False

    except KeyboardInterrupt:
        g.logger.info('-' * 100)
        g.logger.info('Exiting from training early')
    except StopIteration:
        pass

    return losses
Code Example #18
0
def main():
    global global_token_count, event_writer, train_step, train_loss, last_log_step, \
        best_val_loss, epoch, model

    if args.local_rank > 0:
        pass  # skip shutdown when rank is explicitly set + not zero rank
    else:
        os.system('shutdown -c')

    if not args.local:
        logger.info(
            f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}'
        )
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    model = MemTransformerLM(ntokens,
                             args.n_layer,
                             args.n_head,
                             args.d_model,
                             args.d_head,
                             args.d_inner,
                             args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    # log model info
    n_all_param = sum([p.nelement() for p in model.parameters()])
    log_tb('sizes/params', n_all_param)
    n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    log_tb('sizes/non_emb_params', n_nonemb_param)
    logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # optimizer
    if args.optim.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.mom)
    elif args.optim.lower() == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd)
    else:
        assert args.optim.lower() == 'adam'
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         args.max_tokens //
                                                         1e6,
                                                         eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        scheduler = LRFinder(optimizer,
                             args.max_tokens,
                             init_value=args.lr / 1e3)
    elif args.scheduler == 'constant':
        pass

    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing

    if args.checkpoint:
        if global_rank == 0:
            util.restore_from_checkpoint(model=model,
                                         checkpoint_fn=args.checkpoint)

    model = model.to(device)
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2**16},
                                   verbose=False)

    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  #, find_unused_parameters=True)

    if global_rank == 0:
        event_writer = SummaryWriter(args.logdir)

    event_writer.add_text('args', str(args))

    # test checkpoint writing
    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{0}')

    # Loop over epochs.
    train_step = 0
    train_loss = 0
    last_log_step = 0
    best_val_loss = None
    va_iter, te_iter = [
        corpus.get_dist_iterator(split,
                                 global_rank,
                                 max_rank,
                                 args.batch_size * 2,
                                 args.tgt_len,
                                 device=device,
                                 ext_len=args.ext_len)
        for split in ('valid', 'test')
    ]

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=1):
            train(va_iter, optimizer, scheduler)
    except KeyboardInterrupt:
        logger.info('-' * 100)
        logger.info('Exiting from training early')
    except StopIteration:
        pass

    # Eval one more time.
    evaluate_and_log(optimizer, va_iter, 'val', train_step=-1)

    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.logdir, 'model-best.pt')
    if os.path.exists(model_file):
        with open(model_file, 'rb') as model_f:
            with timeit('load'):
                if args.local:
                    model = torch.load(model_f)
                else:
                    model = torch.load(model_f,
                                       map_location=lambda storage, loc:
                                       storage.cuda(args.local_rank))
                    model = DistributedDataParallel(
                        model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)
    else:
        logger.warning('no model file, using current model for loss')

    # Run on test data.
    evaluate_and_log(optimizer, te_iter, 'test', -1)
Code example #19
def main():
    args = parse_args()
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(args.local_rank,
                                                   nproc_per_node,
                                                   args.affinity)
        print(f'{args.local_rank}: thread affinity: {affinity}')

    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'eval_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'eval_log.log'

    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)
    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
        filemode='a',
    )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.profile:
        try:
            pyprof.init(enable_function_stack=True)
        except NameError:
            warnings.warn('Called pyprof.init() but pyprof is not available')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    if not args.no_env:
        log_env_info()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError(
            'Specify path to checkpoint using --model or --work_dir')

    if not args.manual_config:
        checkpoint = load_checkpoint(model_path)
        vocab_type = checkpoint['args'].vocab
    else:
        checkpoint = None
        vocab_type = args.manual_vocab

    if args.manual:
        vocab = checkpoint['vocab']

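        # Older pickled vocabs may lack unk_idx; recover it from sym2idx.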
        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        iter = data_utils.LMOrderedIterator(tensor,
                                            bsz=args.batch_size,
                                            bptt=args.tgt_len,
                                            device=device,
                                            ext_len=args.ext_len,
                                            warmup=False)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset, vocab_type)

        if args.split in ('valid', 'test'):
            iter = corpus.get_iterator(args.split,
                                       args.batch_size,
                                       args.tgt_len,
                                       device=device,
                                       mem_len=args.mem_len,
                                       ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    elif not args.manual_config:
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        if args.type == 'pytorch':
            model.load_state_dict(checkpoint['model_state'])
        elif args.type == 'torchscript':
            model.load_state_dict(checkpoint['model_state'], strict=False)
    elif args.manual_config:
        args.manual_config['tgt_len'] = args.tgt_len
        args.manual_config['ext_len'] = args.ext_len
        args.manual_config['mem_len'] = args.mem_len
        args.manual_config['clamp_len'] = args.clamp_len
        args.manual_config['same_length'] = args.same_length
        args.manual_config['dtype'] = dtype

        model = MemTransformerLM(**args.manual_config)

    model = model.eval()
    model = model.to(device)
    model = model.to(dtype)

    if args.type == 'torchscript' and not args.manual_config:
        state = checkpoint['model_state']

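        # The TorchScript path loaded the state dict with strict=False, so the
        # embedding/output projection tensors are copied over by hand below,
        # honoring div_val, tie_projs and tie_weight from the checkpoint config.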
        tie_projs = checkpoint['model_config']['tie_projs']
        tie_weight = checkpoint['model_config']['tie_weight']
        div_val = checkpoint['model_config']['div_val']
        d_model = checkpoint['model_config']['d_model']
        d_embed = checkpoint['model_config']['d_embed']

        if div_val != 1 or d_model != d_embed:
            for i in range(len(model.word_emb.emb_projs)):
                model.word_emb.emb_projs[i] = state[
                    f'word_emb.emb_projs.{i}'].to(dtype)

        for i in range(len(model.crit.out_projs)):
            if div_val == 1:
                src = 0
            else:
                src = i
            if model.crit.out_projs[i] is not None:
                if tie_projs[i]:
                    model.crit.out_projs[i] = state[
                        f'word_emb.emb_projs.{src}'].to(dtype)
                else:
                    model.crit.out_projs[i] = state[f'crit.out_projs.{i}'].to(
                        dtype)

        for i in range(len(model.crit.out_layers_biases)):
            model.crit.out_layers_biases[i] = state[
                f'crit.out_layers_biases.{i}'].to(dtype)

        if tie_weight:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[
                    f'word_emb.emb_layers.{i}.weight'].to(dtype)
        else:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[
                    f'crit.out_layers_weights.{i}'].to(dtype)

        model = torch.jit.script(model)

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
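    # Exclude the first iterations from the statistics so the recurrence memory
    # is fully populated before throughput/latency are measured.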
    warmup = args.mem_len // args.tgt_len + 2
    meters['eval_throughput'] = AverageMeter(warmup=warmup,
                                             keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

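    # emit_nvtx adds NVTX ranges so an external profiler (nvprof/Nsight with
    # pyprof) can attribute GPU kernels to the PyTorch ops that launched them.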
    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
        loss = evaluate(iter, model, meters, args.log_interval, args.max_size,
                        args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    summary = {
        'eval_loss': loss,
        'eval_ppl': perplexity,
    }

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
        }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)
        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(
                f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms'
            )

        logging.info('=' * 100)

        summary.update({
            'eval_throughput': throughput_data.mean(),
            'eval_avg_latency': 1000 * latency_data.mean(),
        })
        for p in args.percentiles:
            summary[f'eval_{p}%_latency'] = 1000 * np.percentile(
                latency_data, p)

    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['eval_throughput'].avg,
    )
    if not passed:
        sys.exit(1)
Code example #20
File: eval.py  Project: quanpn90/NMTBMajor
with open(os.path.join(args.work_dir, 'checkpoint.pt'), 'rb') as f:
    # model_state_dict = torch.load(f)
    checkpoint = torch.load(f)
    model_args = checkpoint['args']
    model = MemTransformerLM(ntokens,
                             model_args.n_layer,
                             model_args.n_head,
                             model_args.d_model,
                             model_args.d_head,
                             model_args.d_inner,
                             0.0,
                             0.0,
                             tie_weight=model_args.tied,
                             d_embed=model_args.d_embed,
                             div_val=1.0,
                             tie_projs=[False],
                             pre_lnorm=model_args.pre_lnorm,
                             tgt_len=model_args.tgt_len,
                             ext_len=model_args.ext_len,
                             mem_len=model_args.mem_len,
                             cutoffs=[],
                             same_length=model_args.same_length,
                             attn_type=model_args.attn_type,
                             clamp_len=model_args.clamp_len,
                             sample_softmax=False)
    model.load_state_dict(checkpoint['model'])
model.backward_compatible()
model = model.to(device)

logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
    args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
Code example #21
            tracker.launch_impact_monitor()

        n_head, d_head = head_repartition_rule(d_model)
        d_inner = d_model

        model = MemTransformerLM(ntokens,
                                 n_layer,
                                 n_head,
                                 d_model,
                                 d_head,
                                 d_inner,
                                 default_args.dropout,
                                 default_args.dropatt,
                                 tie_weight=default_args.tied,
                                 d_embed=d_model,
                                 div_val=default_args.div_val,
                                 tie_projs=tie_projs,
                                 pre_lnorm=default_args.pre_lnorm,
                                 tgt_len=default_args.tgt_len,
                                 ext_len=default_args.ext_len,
                                 mem_len=default_args.mem_len,
                                 cutoffs=cutoffs,
                                 same_length=default_args.same_length,
                                 attn_type=default_args.attn_type,
                                 clamp_len=default_args.clamp_len,
                                 sample_softmax=default_args.sample_softmax)
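        # Freeze the init hyper-parameters into a reusable initializer so the
        # same partial can be applied again when the model is expanded later.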
        initialization_func = partial(weights_init,
                                      init="normal",
                                      init_range=0.1,
                                      init_std=0.02,
                                      proj_init_std=0.01)
Code example #22
def main():
    args = parse_args()

    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_base_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'log.log'

    log_file = os.path.join(args.work_dir, log_file)
    if args.debug:
        log_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
        filemode='a',
    )
    logging.info(args)

    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError(
            'Specify path to checkpoint using --model or --work_dir')

    checkpoint = load_checkpoint(model_path)

    if args.manual:
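        # Manual text is scored as a single ordered stream, so force bsz = 1.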
        args.batch_size = 1
        vocab = checkpoint['vocab']

        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        iter = data_utils.LMOrderedIterator(tensor,
                                            bsz=args.batch_size,
                                            bptt=args.tgt_len,
                                            device=device,
                                            ext_len=args.ext_len)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset,
                               checkpoint['args'].vocab)

        if args.split == 'valid':
            iter = corpus.get_iterator('valid',
                                       args.batch_size,
                                       args.tgt_len,
                                       device=device,
                                       ext_len=args.ext_len)
        elif args.split == 'test':
            iter = corpus.get_iterator('test',
                                       args.batch_size,
                                       args.tgt_len,
                                       device=device,
                                       ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)

    else:
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        model.load_state_dict(checkpoint['model_state'])

    model = model.eval()
    model = model.to(device)

    model = model.float()
    if args.fp16:
        model = model.half()

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 1
    meters['eval_throughput'] = AverageMeter(warmup=warmup,
                                             keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    loss = evaluate(iter, model, meters, args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
        }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)
        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(
                f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms'
            )

        logging.info('=' * 100)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['eval_throughput'].avg,
    )
    if not passed:
        sys.exit(1)
Code example #23
    if not args.fp16:
        model = model.float()
    model.apply(update_dropout)
    model.apply(update_dropatt)
else:
    model = MemTransformerLM(ntokens,
                             args.n_layer,
                             args.n_head,
                             args.d_model,
                             args.d_head,
                             args.d_inner,
                             args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)
    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
Code example #24
        logging(f"reloading {model_file_name}")
        with open(model_file_name, 'rb') as f:
            model = torch.load(f)
        # backwards compatibility with older saves
        if isinstance(model, nn.DataParallel):
            model = model.module
        model.backward_compatible(tie_weight=args.tied, tie_projs=tie_projs)
        if not args.fp16:
            model = model.float()
        model.apply(update_dropout)
        model.apply(update_dropatt)
    else:
        model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
                                 args.d_head, args.d_inner, args.dropout, args.dropatt,
                                 tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
                                 tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
                                 ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
                                 same_length=args.same_length, attn_type=args.attn_type,
                                 clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
        model.apply(initialization_func)
        # debug
        # model.word_emb.apply(initialization_func)
        # ensure embedding init is not overridden by out_layer in case of weight sharing
    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = non_emb_param_count(model, ntokens)

    logging('=' * 100)
    for k, v in args.__dict__.items():
        logging('    - {} : {}'.format(k, v))
    logging('=' * 100)
    logging('#params = {}'.format(args.n_all_param))