Example #1
def main(args):
    conf = config.CONFIG[args.data_name]
    data_pth = "data/%s" % args.data_name
    train_data_pth = os.path.join(data_pth, "train_data.txt")
    train_data = MonoTextData(train_data_pth, True)

    vocab = train_data.vocab
    print('Vocabulary size: %d' % len(vocab))

    dev_data_pth = os.path.join(data_pth, "dev_data.txt")
    dev_data = MonoTextData(dev_data_pth, True, vocab=vocab)
    test_data_pth = os.path.join(data_pth, "test_data.txt")
    test_data = MonoTextData(test_data_pth, True, vocab=vocab)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    save_path = '{}-{}'.format(args.save, args.data_name)
    save_path = os.path.join(save_path, time.strftime("%Y%m%d-%H%M%S"))
    scripts_to_save = [
        'run.py', 'models/aggressive_vae.py', 'models/vae.py',
        'models/base_network.py', 'config.py'
    ]
    logging = create_exp_dir(save_path,
                             scripts_to_save=scripts_to_save,
                             debug=args.debug)

    train = train_data.create_data_batch(args.bsz, device)
    dev = dev_data.create_data_batch(args.bsz, device)
    test = test_data.create_data_batch(args.bsz, device)

    kwargs = {
        "train": train,
        "valid": dev,
        "test": test,
        "bsz": args.bsz,
        "save_path": save_path,
        "logging": logging,
    }
    params = conf["params"]
    params["vae_params"]["vocab"] = vocab
    params["vae_params"]["device"] = device
    kwargs = dict(kwargs, **params)

    model = AgressiveVAE(**kwargs)
    try:
        valid_loss = model.fit()
        logging("val loss : {}".format(valid_loss))
    except KeyboardInterrupt:
        logging("Exiting from training early")

    model.load(save_path)
    test_loss = model.evaluate(model.test_data)
    logging("test loss: {}".format(test_loss[0]))
    logging("test recon: {}".format(test_loss[1]))
    logging("test kl: {}".format(test_loss[2]))
    logging("test mi: {}".format(test_loss[3]))
Example #2
    def __init__(self, args):
        super(Tester, self).__init__()
        self.args = args

        self.args.cuda = not args.no_cuda and torch.cuda.is_available()
        self.device = torch.device("cuda") if self.args.cuda else torch.device("cpu")
        self.args.factor = not args.no_factor
        self.exp_dir = args.exp_dir
        self.logging = create_exp_dir("exp/test_res")

        seed = args.seed
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
Example #3
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.args = args

        self.args.cuda = not args.no_cuda and torch.cuda.is_available()
        self.device = torch.device("cuda") if self.args.cuda else torch.device("cpu")
        self.args.factor = not args.no_factor
        self.exp_dir = args.exp_dir

        seed = args.seed
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        self.logging = create_exp_dir(args.exp_dir)

        meta_file_name = osp.join(args.exp_dir, "meta.txt")
        meta_file = open(meta_file_name, "w")
        meta_file.write(str(args))
        meta_file.close()
Example #4
args.pretrain_steps += args.start_train_steps
print(f"Experiment name: {args.name}")
assert args.seq_len > 0, "For now you must set seq_len > 0 when using deq"
args.work_dir += "deq"
args.cuda = torch.cuda.is_available()

if args.d_embed < 0:
    args.d_embed = args.nout

assert args.batch_size % args.batch_chunk == 0

args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
logging = create_exp_dir(args.work_dir,
                         scripts_to_save=[
                             'train_trellisnet.py',
                             'models/trellisnets/deq_trellisnet.py'
                         ],
                         debug=args.debug)

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print(
            'WARNING: You have a CUDA device, so you should probably run with --cuda'
        )
    else:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(args.seed)
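
Examples #2, #3, and #4 (and most of the later ones) repeat the same reproducibility boilerplate: seed NumPy, seed PyTorch, and seed CUDA when a device is available (some examples seed only the current device via torch.cuda.manual_seed, others all devices). A small hypothetical helper capturing that shared pattern; the name set_seed is not from the source:

import numpy as np
import torch

def set_seed(seed, cuda=True):
    # Mirrors the seeding block repeated across the examples.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda and torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
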
Example #5
def main():
    args = parse_args()
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(args.local_rank,
                                                   nproc_per_node,
                                                   args.affinity)
        print(f'{args.local_rank}: thread affinity: {affinity}')

    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'eval_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'eval_log.log'

    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)
    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
        filemode='a',
    )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.profile:
        try:
            pyprof.init(enable_function_stack=True)
        except NameError:
            warnings.warn('Called pyprof.init() but pyprof is not available')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    if not args.no_env:
        log_env_info()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError(
            'Specify path to checkpoint using --model or --work_dir')

    if not args.manual_config:
        checkpoint = load_checkpoint(model_path)
        vocab_type = checkpoint['args'].vocab
    else:
        checkpoint = None
        vocab_type = args.manual_vocab

    if args.manual:
        vocab = checkpoint['vocab']

        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        iter = data_utils.LMOrderedIterator(tensor,
                                            bsz=args.batch_size,
                                            bptt=args.tgt_len,
                                            device=device,
                                            ext_len=args.ext_len,
                                            warmup=False)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset, vocab_type)

        if args.split == 'valid' or args.split == 'test':
            iter = corpus.get_iterator(args.split,
                                       args.batch_size,
                                       args.tgt_len,
                                       device=device,
                                       mem_len=args.mem_len,
                                       ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    elif not args.manual_config:
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        if args.type == 'pytorch':
            model.load_state_dict(checkpoint['model_state'])
        elif args.type == 'torchscript':
            model.load_state_dict(checkpoint['model_state'], strict=False)
    elif args.manual_config:
        args.manual_config['tgt_len'] = args.tgt_len
        args.manual_config['ext_len'] = args.ext_len
        args.manual_config['mem_len'] = args.mem_len
        args.manual_config['clamp_len'] = args.clamp_len
        args.manual_config['same_length'] = args.same_length
        args.manual_config['dtype'] = dtype

        model = MemTransformerLM(**args.manual_config)

    model = model.eval()
    model = model.to(device)
    model = model.to(dtype)

    if args.type == 'torchscript' and not args.manual_config:
        state = checkpoint['model_state']

        tie_projs = checkpoint['model_config']['tie_projs']
        tie_weight = checkpoint['model_config']['tie_weight']
        div_val = checkpoint['model_config']['div_val']
        d_model = checkpoint['model_config']['d_model']
        d_embed = checkpoint['model_config']['d_embed']

        if div_val != 1 or d_model != d_embed:
            for i in range(len(model.word_emb.emb_projs)):
                model.word_emb.emb_projs[i] = state[
                    f'word_emb.emb_projs.{i}'].to(dtype)

        for i in range(len(model.crit.out_projs)):
            if div_val == 1:
                src = 0
            else:
                src = i
            if model.crit.out_projs[i] is not None:
                if tie_projs[i]:
                    model.crit.out_projs[i] = state[
                        f'word_emb.emb_projs.{src}'].to(dtype)
                else:
                    model.crit.out_projs[i] = state[f'crit.out_projs.{i}'].to(
                        dtype)

        for i in range(len(model.crit.out_layers_biases)):
            model.crit.out_layers_biases[i] = state[
                f'crit.out_layers_biases.{i}'].to(dtype)

        if tie_weight:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[
                    f'word_emb.emb_layers.{i}.weight'].to(dtype)
        else:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[
                    f'crit.out_layers_weights.{i}'].to(dtype)

        model = torch.jit.script(model)

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['eval_throughput'] = AverageMeter(warmup=warmup,
                                             keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
        loss = evaluate(iter, model, meters, args.log_interval, args.max_size,
                        args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    summary = {
        'eval_loss': loss,
        'eval_ppl': perplexity,
    }

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
        }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)
        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(
                f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms'
            )

        logging.info('=' * 100)

        summary.update({
            'eval_throughput': throughput_data.mean(),
            'eval_avg_latency': 1000 * latency_data.mean(),
        })
        for p in args.percentiles:
            summary[f'eval_{p}%_latency'] = 1000 * np.percentile(
                latency_data, p)

    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['eval_throughput'].avg,
    )
    if not passed:
        sys.exit(1)
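
Example #5 also depends on an AverageMeter(warmup=..., keep=...) object whose .avg and .vals attributes feed the throughput/latency summary. Its implementation is not shown; the sketch below is inferred from those call sites, so the details (in particular how warmup samples are discarded) are assumptions.

class AverageMeter:
    """Running average that skips the first `warmup` updates and, when
    `keep` is set, retains raw values for percentile reporting (sketch)."""

    def __init__(self, warmup=0, keep=False):
        self.warmup = warmup
        self.keep = keep
        self.reset()

    def reset(self):
        self.iters = 0
        self.vals = []
        self.avg = 0.0
        self._sum = 0.0
        self._count = 0

    def update(self, val, n=1):
        self.iters += 1
        if self.iters <= self.warmup:
            return
        if self.keep:
            self.vals.append(val)
        self._sum += val * n
        self._count += n
        self.avg = self._sum / self._count
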
Example #6
def main():
    args = parse_args()

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(
        args.work_dir,
        args.dataset,
        args.append_dataset,
        args.append_time,
    )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'log.log'
    log_file = os.path.join(args.work_dir, log_file)

    if args.debug:
        log_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
    )
    logging.info(args)

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed + utils.distributed.get_rank())
    torch.manual_seed(args.seed + utils.distributed.get_rank())

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    tr_iter = corpus.get_iterator('train',
                                  args.batch_size,
                                  args.tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
    }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum(
        [p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params,
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    if args.fp16:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2',
        )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(
            model,
            delay_allreduce=True,
        )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz //
                                              args.batch_chunk,
                                              model,
                                              dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         max_step,
                                                         eta_min=args.eta_min)
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step, eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min,
        )
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    best_val_loss = None

    if args.restart:
        checkpoint = load_checkpoint(args.restart)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['scheduler_state'])
        if args.fp16:
            amp.load_state_dict(checkpoint['amp_state'])
        train_step = checkpoint['train_step']
        best_val_loss = checkpoint['best_val_loss']

        model.apply(functools.partial(update_dropout, args=args))
        model.apply(functools.partial(update_dropatt, args=args))

    meters = {}
    warmup = args.mem_len // args.tgt_len + 1
    meters['train_throughput'] = AverageMeter(warmup=warmup)
    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    try:
        for epoch in itertools.count(start=1):
            if args.roll:
                tr_iter.roll()
            train_step, best_val_loss = train(tr_iter, va_iter, model,
                                              para_model, model_config,
                                              optimizer, optimizer_sparse,
                                              scheduler, scheduler_sparse,
                                              vocab, epoch, train_step,
                                              best_val_loss, meters, args)

            if train_step == args.max_step:
                logging.info('-' * 100)
                logging.info('End of training')
                break
    except KeyboardInterrupt:
        logging.info('-' * 100)
        logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'
                .format(time.time() - test_start_time, test_loss,
                        test_loss / math.log(2)))
        else:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'
                .format(time.time() - test_start_time, test_loss,
                        math.exp(test_loss)))
        logging.info('=' * 100)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(
        f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=val_perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['train_throughput'].avg)
    if not passed:
        sys.exit(1)
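
The inv_sqrt scheduler in Example #6 (and again in Example #12) returns a multiplier on the base learning rate rather than a learning rate itself: it ramps up as step / warmup_step**1.5 during warmup and then decays as step**-0.5, peaking at 1 / sqrt(warmup_step). A standalone check of that lambda; the warmup_step value here is illustrative, not taken from the source:

warmup_step = 4000

def lr_lambda(step):
    # Same expression as in the examples: linear-in-step warmup, then inverse-sqrt decay.
    if step == 0 and warmup_step == 0:
        return 1.0
    return 1.0 / (step ** 0.5) if step > warmup_step else step / (warmup_step ** 1.5)

for s in (0, 1000, 4000, 16000):
    print(s, lr_lambda(s))
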
Example #7
                    default=1.2e-6,
                    help='weight decay applied to all weights')

args = parser.parse_args()
print(args)
args.tied = not args.not_tied

if args.d_embed < 0:
    args.d_embed = args.d_model

assert args.ext_len >= 0, 'extended context length must be non-negative'
assert args.batch_size % args.batch_chunk == 0

args.work_dir = '{}-{}-{}'.format(args.work_dir, args.dataset, args.seed)
#args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
logging = create_exp_dir(args.work_dir, debug=args.debug)

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print(
            'WARNING: You have a CUDA device, so you should probably run with --cuda'
        )
    else:
        torch.cuda.manual_seed_all(args.seed)

# Validate `--fp16` option
if args.fp16:
    if not args.cuda:
Example #8
args.pretrain_steps += args.start_train_steps
print(f"Experiment name: {args.name}")
assert args.mem_len > 0, "For now you must set mem_len > 0 when using deq"
args.work_dir += "deq"
args.cuda = torch.cuda.is_available()

if args.d_embed < 0:
    args.d_embed = args.d_model

assert args.batch_size % args.batch_chunk == 0

args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
logging = create_exp_dir(args.work_dir,
                         scripts_to_save=[
                             'train_transformer.py',
                             'models/transformers/deq_transformer.py'
                         ],
                         debug=args.debug)

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print(
            'WARNING: You have a CUDA device, so you should probably run with --cuda'
        )
    else:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(args.seed)
Example #9
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

parser = create_parser()

args = parser.parse_args()
args.tied = not args.not_tied

if args.d_embed < 0:
    args.d_embed = args.d_model

assert args.ext_len >= 0, 'extended context length must be non-negative'
assert args.batch_size % args.batch_chunk == 0

# args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
# args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
logging = create_exp_dir(args.work_dir, scripts_to_save=[], debug=args.debug)

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print(
            'WARNING: You have a CUDA device, so you should probably run with --cuda'
        )
    else:
        torch.cuda.manual_seed_all(args.seed)

# Validate `--fp16` option
if args.fp16:
    if not args.cuda:
Example #10
def main(args):
    conf = config.CONFIG[args.data_name]
    data_pth = "data/%s" % args.data_name

    train_sentiment_data_pth = os.path.join(data_pth,
                                            "train_sentiment_data.txt")
    train_sentiment_feat_pth = os.path.join(
        data_pth, "train_sentiment_%s.npy" % args.feat)
    train_sentiment_data = MonoTextData(train_sentiment_data_pth, True)
    train_sentiment_feat = np.load(train_sentiment_feat_pth)

    train_tense_data_pth = os.path.join(data_pth, "train_tense_data.txt")
    train_tense_feat_pth = os.path.join(data_pth,
                                        "train_tense_%s.npy" % args.feat)
    train_tense_data = MonoTextData(train_tense_data_pth, True)
    train_tense_feat = np.load(train_tense_feat_pth)

    sentiment_vocab = train_sentiment_data.vocab
    print('Sentiment Vocabulary size: %d' % len(sentiment_vocab))

    tense_vocab = train_tense_data.vocab
    print('Tense Vocabulary size: %d' % len(tense_vocab))

    dev_sentiment_data_pth = os.path.join(data_pth, "dev_sentiment_data.txt")
    dev_sentiment_feat_pth = os.path.join(data_pth,
                                          "dev_sentiment_%s.npy" % args.feat)
    dev_sentiment_data = MonoTextData(dev_sentiment_data_pth,
                                      True,
                                      vocab=sentiment_vocab)
    dev_sentiment_feat = np.load(dev_sentiment_feat_pth)

    dev_tense_data_pth = os.path.join(data_pth, "dev_tense_data.txt")
    dev_tense_feat_pth = os.path.join(data_pth, "dev_tense_%s.npy" % args.feat)
    dev_tense_data = MonoTextData(dev_tense_data_pth, True, vocab=tense_vocab)
    dev_tense_feat = np.load(dev_tense_feat_pth)

    test_sentiment_data_pth = os.path.join(data_pth, "test_sentiment_data.txt")
    test_sentiment_feat_pth = os.path.join(data_pth,
                                           "test_sentiment_%s.npy" % args.feat)
    test_sentiment_data = MonoTextData(test_sentiment_data_pth,
                                       True,
                                       vocab=sentiment_vocab)
    test_sentiment_feat = np.load(test_sentiment_feat_pth)

    test_tense_data_pth = os.path.join(data_pth, "test_tense_data.txt")
    test_tense_feat_pth = os.path.join(data_pth,
                                       "test_tense_%s.npy" % args.feat)
    test_tense_data = MonoTextData(test_tense_data_pth,
                                   True,
                                   vocab=tense_vocab)
    test_tense_feat = np.load(test_tense_feat_pth)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    save_path0 = 'sentiment-{}-{}-{}'.format(args.save, args.data_name,
                                             args.feat)
    save_path0 = os.path.join(save_path0, time.strftime("%Y%m%d-%H%M%S"))
    save_path1 = 'tense-{}-{}-{}'.format(args.save, args.data_name, args.feat)
    save_path1 = os.path.join(save_path1, time.strftime("%Y%m%d-%H%M%S"))

    scripts_to_save = [
        'run.py', 'models/decomposed_vae.py', 'models/vae.py',
        'models/base_network.py', 'config.py'
    ]
    logging0 = create_exp_dir(save_path0,
                              scripts_to_save=scripts_to_save,
                              debug=args.debug)
    logging1 = create_exp_dir(save_path1,
                              scripts_to_save=scripts_to_save,
                              debug=args.debug)

    if args.text_only:
        train_sentiment = train_sentiment_data.create_data_batch(
            args.bsz, device)
        dev_sentiment = dev_sentiment_data.create_data_batch(args.bsz, device)
        test_sentiment = test_sentiment_data.create_data_batch(
            args.bsz, device)
        feat_sentiment = train_sentiment

        train_tense = train_tense_data.create_data_batch(args.bsz, device)
        test_tense = test_tense_data.create_data_batch(args.bsz, device)
        feat_tense = train_tense
    else:
        train_sentiment = train_sentiment_data.create_data_batch_feats(
            args.bsz, train_sentiment_feat, device)
        dev_sentiment = dev_sentiment_data.create_data_batch_feats(
            args.bsz, dev_sentiment_feat, device)
        test_sentiment = test_sentiment_data.create_data_batch_feats(
            args.bsz, test_sentiment_feat, device)
        feat_sentiment = train_sentiment_feat
        train_tense = train_tense_data.create_data_batch_feats(
            args.bsz, train_tense_feat, device)
        test_tense = test_tense_data.create_data_batch_feats(
            args.bsz, test_tense_feat, device)
        feat_tense = train_tense_feat

    #VAE training on sentiment data
    # kwargs0 = {
    #     "train": train_sentiment,
    #     "valid": dev_sentiment,
    #     "test": test_sentiment,
    #     "feat": feat_sentiment,
    #     "bsz": args.bsz,
    #     "save_path": save_path0,
    #     "logging": logging0,
    #     "text_only": args.text_only,
    # }
    # params = conf["params"]
    # params["vae_params"]["vocab"] = sentiment_vocab
    # params["vae_params"]["device"] = device
    # params["vae_params"]["text_only"] = args.text_only
    # params["vae_params"]["mlp_ni"] = train_sentiment_feat.shape[1]
    # kwargs0 = dict(kwargs0, **params)

    # sentiment_model = DecomposedVAE(**kwargs0)
    # try:
    #     valid_loss = sentiment_model.fit()
    #     logging("sentiment val loss : {}".format(valid_loss))
    # except KeyboardInterrupt:
    #     logging("Exiting from training early")

    # sentiment_model.load(save_path0)
    # test_loss = model.evaluate(sentiment_model.test_data, sentiment_model.test_feat)
    # logging("sentiment test loss: {}".format(test_loss[0]))
    # logging("sentiment test recon: {}".format(test_loss[1]))
    # logging("sentiment test kl1: {}".format(test_loss[2]))
    # logging("sentiment test kl2: {}".format(test_loss[3]))
    # logging("sentiment test mi1: {}".format(test_loss[4]))
    # logging("sentiment test mi2: {}".format(test_loss[5]))

    #VAE training on tense data
    kwargs1 = {
        "train": train_tense,
        "valid": test_tense,
        "test": test_tense,
        "feat": feat_tense,
        "bsz": args.bsz,
        "save_path": save_path1,
        "logging": logging1,
        "text_only": args.text_only,
    }
    params = conf["params"]
    params["vae_params"]["vocab"] = tense_vocab
    params["vae_params"]["device"] = device
    params["vae_params"]["text_only"] = args.text_only
    params["vae_params"]["mlp_ni"] = train_tense_feat.shape[1]
    kwargs1 = dict(kwargs1, **params)

    tense_model = DecomposedVAE(**kwargs1)
    try:
        valid_loss = tense_model.fit()
        logging("tense val loss : {}".format(valid_loss))
    except KeyboardInterrupt:
        logging("Exiting from training early")

    tense_model.load(save_path1)
    test_loss = tense_model.evaluate(tense_model.test_data, tense_model.test_feat)
    logging1("tense test loss: {}".format(test_loss[0]))
    logging1("tense test recon: {}".format(test_loss[1]))
    logging1("tense test kl1: {}".format(test_loss[2]))
    logging1("tense test kl2: {}".format(test_loss[3]))
    logging1("tense test mi1: {}".format(test_loss[4]))
    logging1("tense test mi2: {}".format(test_loss[5]))
Example #11
def main(args):
    conf = config.CONFIG[args.data_name]
    data_pth = "data/%s" % args.data_name
    train_data_pth = os.path.join(data_pth, "train_input_data.csv")
    train_feat_pth = os.path.join(data_pth, "train_%s.npy" % args.feat)
    train_data = MonoTextData(train_data_pth, True)
    train_feat = np.load(train_feat_pth)

    vocab = train_data.vocab
    print('Vocabulary size: %d' % len(vocab))

    dev_data_pth = os.path.join(data_pth, "dev_input_data.csv")
    dev_feat_pth = os.path.join(data_pth, "dev_%s.npy" % args.feat)
    dev_data = MonoTextData(dev_data_pth, True, vocab=vocab)
    dev_feat = np.load(dev_feat_pth)
    test_data_pth = os.path.join(data_pth, "test_input_data.csv")
    test_feat_pth = os.path.join(data_pth, "test_%s.npy" % args.feat)
    test_data = MonoTextData(test_data_pth, True, vocab=vocab)
    test_feat = np.load(test_feat_pth)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    save_path = '{}-{}-{}'.format(args.save, args.data_name, args.feat)
    save_path = os.path.join(save_path, time.strftime("%Y%m%d-%H%M%S"))
    scripts_to_save = [
        'run.py', 'models/decomposed_vae.py', 'models/vae.py',
        'models/base_network.py', 'config.py'
    ]
    logging = create_exp_dir(save_path,
                             scripts_to_save=scripts_to_save,
                             debug=args.debug)

    if args.text_only:
        train, train_sentiments, train_tenses = train_data.create_data_batch_labels(
            args.bsz, device)
        dev, dev_sentiments, dev_tenses = dev_data.create_data_batch_labels(
            args.bsz, device)
        test, test_sentiments, test_tenses = test_data.create_data_batch_labels(
            args.bsz, device)
        feat = train
    else:
        train = train_data.create_data_batch_feats(args.bsz, train_feat,
                                                   device)
        dev = dev_data.create_data_batch_feats(args.bsz, dev_feat, device)
        test = test_data.create_data_batch_feats(args.bsz, test_feat, device)
        feat = train_feat

    print("data done.")

    kwargs = {
        "train": train,
        "valid": dev,
        "test": test,
        "train_sentiments": train_sentiments,
        "train_tenses": train_tenses,
        "dev_sentiments": dev_sentiments,
        "dev_tenses": dev_tenses,
        "test_sentiments": test_sentiments,
        "test_tenses": test_tenses,
        "feat": feat,
        "bsz": args.bsz,
        "save_path": save_path,
        "logging": logging,
        "text_only": args.text_only,
    }
    params = conf["params"]
    params["vae_params"]["vocab"] = vocab
    params["vae_params"]["device"] = device
    params["vae_params"]["text_only"] = args.text_only
    params["vae_params"]["mlp_ni"] = train_feat.shape[1]
    kwargs = dict(kwargs, **params)

    model = DecomposedVAE(**kwargs)
    try:
        valid_loss = model.fit()
        logging("val loss : {}".format(valid_loss))
    except KeyboardInterrupt:
        logging("Exiting from training early")

    model.load(save_path)
    test_loss = model.evaluate(model.test_data, model.test_feat)
    logging("test loss: {}".format(test_loss[0]))
    logging("test recon: {}".format(test_loss[1]))
    logging("test kl1: {}".format(test_loss[2]))
    logging("test kl2: {}".format(test_loss[3]))
    logging("test mi1: {}".format(test_loss[4]))
    logging("test mi2: {}".format(test_loss[5]))
Example #12
def train_ts(args):
    def build_scheduler(optimizers, args):
        optimizer, optimizer_sparse = optimizers
        scheduler_sparse = None

        if args.scheduler == "cosine":
            # here we do not set eta_min to lr_min to be backward compatible
            # because in previous versions eta_min is default to 0
            # rather than the default value of lr_min 1e-6
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, args.max_step,
                eta_min=args.eta_min)  # should use eta_min arg

        elif args.scheduler == "inv_sqrt":
            # originally used for Transformer (in Attention is all you need)
            def lr_lambda(step):
                # return a multiplier instead of a learning rate
                if step == 0 and args.warmup_step == 0:
                    return 1.0
                else:
                    return (1.0 /
                            (step**0.5) if step > args.warmup_step else step /
                            (args.warmup_step**1.5))

            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                    lr_lambda=lr_lambda)

        elif args.scheduler == "dev_perf":
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )

        elif args.scheduler == "constant":
            pass

        else:
            raise ValueError(f"scheduler type {args.scheduler} not recognized")

        return scheduler, scheduler_sparse

    ###############################################################################
    # Training code
    ###############################################################################
    def evaluate(eval_iter, model):

        # Turn on evaluation mode which disables dropout.
        model.eval()

        # debug
        # If the model does not use memory at all, make the ext_len longer.
        # Otherwise, make the mem_len longer and keep the ext_len the same.
        # if default_args.mem_len == 0:
        #     model.reset_length(default_args.eval_tgt_len,
        #                        default_args.ext_len + default_args.tgt_len -
        #                        default_args.eval_tgt_len, default_args.mem_len)
        # else:
        #     model.reset_length(default_args.eval_tgt_len,
        #                        default_args.ext_len, default_args.mem_len +
        #                       default_args.tgt_len - default_args.eval_tgt_len)

        # Evaluation
        total_len, total_loss = 0, 0.0
        with torch.no_grad():
            mems = tuple()
            for i, (data, target, seq_len) in enumerate(eval_iter):
                if i >= args.max_eval_steps > 0:
                    break
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.mean()
                total_loss += seq_len * loss.float().item()
                total_len += seq_len

        # Switch back to the training mode
        # model.reset_length(default_args.tgt_len, default_args.ext_len,
        # default_args.mem_len)
        model.train()

        return total_loss / total_len

    # reverse distillation util
    def get_original_batches(model, tr_iter, integration_length):
        model.eval()
        if args.batch_chunk > 1:
            mems = [None for _ in range(args.batch_chunk)]
            first_logits = [[] for _ in range(args.batch_chunk)]
        else:
            mems = None
            first_logits = []
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        with torch.no_grad():
            for batch, (data, target, seq_len) in enumerate(train_iter):
                if batch == integration_length:
                    break
                if args.batch_chunk > 1:
                    data_chunks = torch.chunk(data, args.batch_chunk, 1)
                    for i in range(args.batch_chunk):
                        data_i = data_chunks[i].contiguous()
                        logits, mems[i] = model._forward(data_i, mems=mems[i])
                        first_logits[i].append(logits.cpu())
                else:
                    logits, mems = model._forward(data, mems=mems)
                    first_logits.append(logits.cpu())
        return first_logits

    def build_optimizer(model, args, reload=False):
        optimizer_sparse = None
        if args.optim.lower() == "sgd":
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.mom)
        elif args.optim.lower() == "adam":
            optimizer = optim.Adam(model.parameters(), lr=args.lr)
        elif args.optim.lower() == "adagrad":
            optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        else:
            raise ValueError(f"optimizer type {args.optim} not recognized")

        if reload:
            if args.restart_from is not None:
                optim_name = f"optimizer_{args.restart_from}.pt"
            else:
                optim_name = "optimizer.pt"
            optim_file_name = os.path.join(args.restart_dir, optim_name)
            logging(f"reloading {optim_file_name}")
            if os.path.exists(os.path.join(args.restart_dir, optim_name)):
                with open(os.path.join(args.restart_dir, optim_name),
                          "rb") as optim_file:
                    opt_state_dict = torch.load(optim_file)
                    try:
                        optimizer.load_state_dict(opt_state_dict)
                    # in case the optimizer param groups aren't the same shape,
                    # merge them
                    except:
                        logging("merging optimizer param groups")
                        opt_state_dict["param_groups"][0]["params"] = [
                            param
                            for param_group in opt_state_dict["param_groups"]
                            for param in param_group["params"]
                        ]
                        opt_state_dict["param_groups"] = [
                            opt_state_dict["param_groups"][0]
                        ]
                        optimizer.load_state_dict(opt_state_dict)
            else:
                logging("Optimizer was not saved. Start from scratch.")

        return optimizer, optimizer_sparse

    def log_val(val_loss, step, compute):
        logging("-" * 100)
        log_str = ("| Eval {:3d} at step {:>8d} | time: {:5.2f}s "
                   "| valid loss {:5.2f}".format(
                       step // args.eval_interval,
                       step,
                       (time.time() - eval_start_time),
                       val_loss,
                   ))
        log_str += " | bpc {:9.5f}".format(val_loss / math.log(2))
        logging(log_str)
        logging("-" * 100)

    def epoch_loop(
        epoch,
        model,
        optimizers,
        schedulers,
    ):
        nonlocal train_step

        # Turn on training mode which enables dropout.
        if isinstance(model, nn.DataParallel):
            parent_model = model.module
        else:
            parent_model = model
        optimizer, optimizer_sparse = optimizers
        scheduler, scheduler_sparse = schedulers

        # global train_step, best_val_loss, eval_start_time, log_start_time
        train_losses = []
        model.train()
        if args.batch_chunk > 1:
            mems = [tuple() for _ in range(args.batch_chunk)]
        else:
            mems = tuple()
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter

        log_start_time = time.time()
        best_val_loss = float("Infinity")
        for batch, (data, target, seq_len) in enumerate(train_iter):
            model.zero_grad()
            if args.batch_chunk > 1:
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                target_chunks = torch.chunk(target, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    target_i = target_chunks[i].contiguous()
                    ret = model(data_i, target_i, *mems[i])
                    loss, mems[i] = ret[0], ret[1:]
                    loss = loss.float().mean().type_as(loss) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    train_losses.append(loss.float().item())
            else:
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.float().mean().type_as(loss)
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                train_losses.append(loss.float().item())

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            parent_model.compute += openai_compute(
                non_emb_param_count(parent_model, nseries), data.numel(), 1)

            # step-wise learning rate annealing
            train_step += 1
            parent_model.training_steps += 1

            # check for yet-to-thaw parameters
            if getattr(parent_model, "freeze_countdown", 0) > 0:
                parent_model.freeze_countdown -= 1

                # if this is the last step
                if parent_model.freeze_countdown == 0:
                    for parameter in parent_model.parameters():
                        parameter.requires_grad = True
                    logging("thawing all parameters")

            if args.scheduler in ["cosine", "constant", "dev_perf"]:
                # linear warmup stage
                if train_step < args.warmup_step:
                    curr_lr = args.lr * train_step / args.warmup_step
                    optimizer.param_groups[0]["lr"] = curr_lr
                else:
                    if args.scheduler == "cosine":
                        scheduler.step(train_step)
            elif args.scheduler == "inv_sqrt":
                scheduler.step(train_step)

            if train_step % args.log_interval == 0:
                cur_loss = np.mean(train_losses)
                elapsed = time.time() - log_start_time
                log_str = ("| epoch {:3d} step {:>8d} "
                           "| {:>6d} batches "
                           "| lr {:.3g} "
                           "| ms/batch {:5.2f} "
                           "| loss {:5.2f}".format(
                               epoch,
                               train_step,
                               batch + 1,
                               optimizer.param_groups[0]["lr"],
                               elapsed * 1000 / args.log_interval,
                               cur_loss,
                           ))

                log_str += " | bpc {:9.5f}".format(cur_loss / math.log(2))
                logging(log_str)

                train_losses = []
                log_start_time = time.time()

            if train_step % args.eval_interval == 0:
                val_loss = evaluate(va_iter, model)
                log_val(val_loss,
                        step=train_step,
                        compute=parent_model.compute)
                # Save the model if the validation loss is the best we've seen so
                # far.
                if not best_val_loss or val_loss < best_val_loss:
                    best_val_loss = val_loss
                    if not args.debug:
                        if args.fp16:
                            with open(
                                    os.path.join(args.work_dir,
                                                 "amp_checkpoint.pt"),
                                    "wb",
                            ) as f:
                                checkpoint = {
                                    "model": model.state_dict(),
                                    "optimizer": optimizer.state_dict(),
                                    "amp": amp.state_dict(),
                                }
                                torch.save(checkpoint, f)
                        else:
                            with open(os.path.join(args.work_dir, "model.pt"),
                                      "wb") as f:
                                torch.save(parent_model, f)
                            with open(
                                    os.path.join(args.work_dir,
                                                 "optimizer.pt"),
                                    "wb",
                            ) as f:
                                torch.save(optimizer.state_dict(), f)

                # dev-performance based learning rate annealing
                if args.scheduler == "dev_perf":
                    scheduler.step(val_loss)

                eval_start_time = time.time()

            if train_step == args.max_step:
                break

    def expand_model(
        strategy,
        integration,
        integration_length,
        n_add,
        model: MemTransformerLM,
        optimizers,
        schedulers,
        tr_iter,
        va_iter,
        epoch,
        step,
    ):
        optimizer, _ = optimizers
        scheduler, _ = schedulers
        if integration:
            if not integration_length or integration_length <= 0:
                warnings.warn(
                    f"integration {integration} passed but integration_length is {integration_length}"
                )
            else:
                logging(
                    f"applying integration strategy {integration} with integration length {integration_length}"
                )

        # pre-expansion validation
        logging(f"evaluating before expanding")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

        # infer example logits for reverse distillation
        if "reverse_distil" in integration:
            first_logits = get_original_batches(model, tr_iter,
                                                integration_length)

        # expansion
        logging(
            f"adding {n_add} layers before starting epoch {epoch} with method {strategy}"
        )
        new_layers = model.expand_layers(n_add,
                                         strategy=strategy,
                                         function=initialization_func)

        # optimizer update
        optimizer.add_param_group({
            "params":
            new_layers.parameters(),
            "lr":
            optimizer.param_groups[0]["lr"],
            "initial_lr":
            optimizer.param_groups[0]["initial_lr"],
        })
        scheduler.base_lrs.append(optimizer.param_groups[-1]["initial_lr"])

        # training loop for reverse distillation
        if "reverse_distil" in integration:
            fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                                  integration)

        # freezing parameters for frozen restart, we do this afterwards else the
        # new layers get copied also without grads
        if "freeze" in integration and integration_length > 0:
            for param_group in optimizer.param_groups[:-1]:
                for parameter in param_group["params"]:
                    parameter.requires_grad = False
            model.freeze_countdown = integration_length

        # post-expansion validation
        logging(f"reevaluating")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    def expand_state(param, state):
        if param.shape != state.shape:
            ratios = [
                param.shape[i] // state.shape[i]
                for i in range(len(param.shape))
            ]
            return state.repeat(*ratios)
        else:
            return state

    def widen_model(
        strategy,
        ratio,
        model: MemTransformerLM,
        optimizers,
        va_iter,
        epoch,
        step,
    ):
        optimizer, _ = optimizers

        # pre-widening validation
        logging("evaluating before widening")

        val_loss = evaluate(va_iter, model)
        log_val(val_loss, compute=model.compute, step=step)

        # widening
        logging(
            f"widening attention heads by ratio {ratio} before starting epoch {epoch} with method {strategy}"
        )
        model.add_heads(ratio, strategy=strategy, function=initialization_func)

        # optimizer update
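        # tile Adam's first/second moment buffers (exp_avg, exp_avg_sq) so
        # their shapes follow the widened parameters (this assumes an
        # Adam-style optimizer that keeps these buffers)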
        for param, states in optimizer.state.items():
            if isinstance(param, nn.Parameter):
                states["exp_avg"] = expand_state(param, states["exp_avg"])
                states["exp_avg_sq"] = expand_state(param,
                                                    states["exp_avg_sq"])

        # post-widening validation
        logging("re-evaluating after widening")
        val_loss = evaluate(va_iter, model)
        log_val(val_loss, step=step, compute=model.compute)

    # reverse distillation trainer
    def fit_to_previous_model(model, new_layers, tr_iter, first_logits,
                              integration):
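        # Reverse distillation: train the expanded model (or only its new
        # layers when "partial" integration is used) to reproduce the logits
        # the pre-expansion model produced on the same batches, via MSE loss.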
        mse_loss = torch.nn.MSELoss()
        if "partial" in integration:
            distil_optimizer, distil_optimizer_sparse = build_optimizer(
                new_layers, reload=False)
        else:
            distil_optimizer, distil_optimizer_sparse = build_optimizer(
                model, reload=False)
        if args.cuda and args.fp16:
            model, distil_optimizer = amp.initialize(model,
                                                     distil_optimizer,
                                                     opt_level=args.fp16)

        model.train()
        if args.batch_chunk > 1:
            mems = [None for _ in range(args.batch_chunk)]
        else:
            mems = None
        train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
        for batch, (data, _, _) in enumerate(train_iter):
            if batch == len(first_logits):
                break
            model.zero_grad()
            if args.batch_chunk > 1:
                data_chunks = torch.chunk(data, args.batch_chunk, 1)
                for i in range(args.batch_chunk):
                    data_i = data_chunks[i].contiguous()
                    logits, mems[i] = model._forward(data_i, mems=mems[i])
                    target_logits = first_logits[i][batch].to(logits.device)
                    loss = mse_loss(logits, target_logits) / args.batch_chunk
                    if args.fp16:
                        with amp.scale_loss(loss,
                                            distil_optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
            else:
                logits, mems = model._forward(data, mems=mems)
                target_logits = first_logits[batch].to(logits.device)
                loss = mse_loss(logits, target_logits)
                if args.fp16:
                    with amp.scale_loss(loss, distil_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(distil_optimizer), args.clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            distil_optimizer.step()

    ###################################################################################
    #
    # main()
    #
    args.tied = not args.not_tied

    if args.d_embed < 0:
        args.d_embed = args.d_model

    # Validate `--fp16` option
    if args.fp16:
        if not args.cuda:
            print("WARNING: --fp16 requires --cuda, ignoring --fp16 option")
            args.fp16 = False
        else:
            try:
                from apex import amp

                if args.fp16 == "O1":
                    amp.register_half_function(torch, "einsum")
            except ImportError:
                print("WARNING: apex not installed, ignoring --fp16 option")
                args.fp16 = False

    device = torch.device("cuda" if args.cuda else "cpu")

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print(
                "WARNING: You have a CUDA device, so you should probably run "
                "with --cuda ")
        else:
            torch.cuda.manual_seed_all(args.seed)

    ############################################################################
    # Logging
    ############################################################################

    assert args.ext_len >= 0, "extended context length must be non-negative"
    assert args.d_batch % args.batch_chunk == 0

    args.work_dir = "{}-{}".format(args.work_dir, args.dataset)
    args.work_dir = os.path.join(args.work_dir, time.strftime("%Y%m%d-%H%M%S"))
    logging = create_exp_dir(
        args.work_dir,
        scripts_to_save=["train_ts.py", "mem_transformer.py"],
        debug=args.debug,
    )

    ############################################################################
    # Load data
    ############################################################################
    time_series = get_time_series(args.datadir, args.dataset)
    nseries = len(time_series.vocab)
    args.n_token = nseries

    eval_batch_size = 20
    tr_iter = time_series.get_iterator(
        "train",
        args.d_batch,
        args.tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    va_iter = time_series.get_iterator(
        "valid",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )
    te_iter = time_series.get_iterator(
        "test",
        eval_batch_size,
        args.eval_tgt_len,
        device=device,
        ext_len=args.ext_len,
    )

    cutoffs, tie_projs = [], [False]
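    # no adaptive-softmax cutoffs in the time-series setting: the empty cutoff
    # list leaves a single softmax cluster, with an untied head projection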

    ############################################################################
    # Define model
    ############################################################################

    initialization_func = partial(
        weights_init,
        init=args.init,
        init_range=args.init_range,
        init_std=args.init_std,
        proj_init_std=args.proj_init_std,
    )

    if args.restart and not args.fp16:
        if args.restart_from is not None:
            model_name = f"model_{args.restart_from}.pt"
        else:
            model_name = "model.pt"
        model_file_name = os.path.join(args.restart_dir, model_name)
        logging(f"reloading {model_file_name}")
        with open(model_file_name, "rb") as f:
            model = torch.load(f)
        # backwards compatibility with older saves
        if isinstance(model, nn.DataParallel):
            model = model.module
        model.backward_compatible(tie_weight=args.tied, tie_projs=tie_projs)
        if not args.fp16:
            model = model.float()
        model.apply(update_dropout)
        model.apply(update_dropatt)

    else:
        model = MemTransformerLM(
            nseries,
            args.n_layer,
            args.n_head,
            args.d_model,
            args.d_head,
            args.d_inner,
            args.dropout,
            args.dropatt,
            tie_weight=args.tied,
            d_embed=args.d_embed,
            div_val=args.div_val,
            tie_projs=tie_projs,
            pre_lnorm=args.pre_lnorm,
            tgt_len=args.tgt_len,
            ext_len=args.ext_len,
            mem_len=args.mem_len,
            cutoffs=cutoffs,
            same_length=args.same_length,
            clamp_len=args.clamp_len,
        )
        model.apply(initialization_func)

        # optionally re-apply the init to the embedding so it is not
        # overridden by the output layer when weights are tied:
        # model.word_emb.apply(initialization_func)
    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = non_emb_param_count(model, nseries)

    logging("=" * 100)
    for k, v in args.__dict__.items():
        logging("    - {} : {}".format(k, v))
    logging("=" * 100)
    logging("#params = {}".format(args.n_all_param))
    logging("#non emb params = {}".format(args.n_nonemb_param))

    para_model = parallelize_model(model, args)
    optimizers = build_optimizer(para_model,
                                 args,
                                 reload=args.restart and not args.fp16)
    optimizer, optimizer_sparse = optimizers
    schedulers = build_scheduler(optimizers, args)
    scheduler, scheduler_sparse = schedulers

    if args.cuda and args.fp16:
        para_model, optimizer = amp.initialize(para_model,
                                               optimizer,
                                               opt_level=args.fp16)

        if args.restart:
            if args.restart_from is not None:
                checkpoint_name = f"amp_checkpoint_{args.restart_from}.pt"
            else:
                checkpoint_name = "amp_checkpoint.pt"
            with open(os.path.join(args.work_dir, checkpoint_name), "rb") as f:
                checkpoint = torch.load(f)
                model.load_state_dict(checkpoint["model"])
                optimizer.load_state_dict(checkpoint["optimizer"])
                amp.load_state_dict(checkpoint["amp"])

    ############################################################################
    # Training loop
    ############################################################################

    # Loop over epochs.
    if args.reset_lr:
        # reset the step counter and the learning rates so the new LR schedule
        # starts from scratch
        train_step = 0
        optimizer.defaults["lr"] = args.lr
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.lr
            param_group["initial_lr"] = args.lr
        scheduler.base_lrs = [args.lr] * len(scheduler.base_lrs)
    else:
        train_step = model.training_steps

    best_val_loss = None

    # train_step was reloaded from the checkpointed model; log it on restart
    if train_step > 0:
        logging(f"restarting from step {train_step}")

    log_start_time = time.time()
    eval_start_time = time.time()

    def run_training():
        nonlocal train_step

        for epoch in itertools.count(start=first_epoch):
            # check before the epoch's training loop: an expansion scheduled
            # for epoch 0 therefore happens before any training (useful for
            # debugging)
            if args.expand and str(epoch - 1) in args.expansion_dict:
                n_add = int(args.expansion_dict[str(epoch - 1)])
                expand_model(
                    args.expand,
                    args.integration,
                    args.integration_length,
                    n_add,
                    model,
                    optimizers,
                    schedulers,
                    tr_iter,
                    va_iter,
                    epoch,
                    train_step,
                )
            if args.widen and str(epoch - 1) in args.widen_dict:
                ratio = int(args.widen_dict[str(epoch - 1)])
                widen_model(
                    args.widen,
                    ratio,
                    model,
                    optimizers,
                    va_iter,
                    epoch,
                    train_step,
                )
            epoch_loop(epoch, para_model, optimizers, schedulers)
            if train_step >= args.max_step:
                logging("-" * 100)
                logging("End of training")
                break
            if not args.debug and args.log_first_epochs:
                if epoch <= args.log_first_epochs:
                    logging(f"saving model at the end of epoch {epoch}")
                    if args.fp16:
                        with open(
                                os.path.join(args.work_dir,
                                             f"amp_checkpoint_{epoch}.pt"),
                                "wb",
                        ) as f:
                            checkpoint = {
                                "model": model.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "amp": amp.state_dict(),
                            }
                            torch.save(checkpoint, f)
                    else:
                        with open(
                                os.path.join(args.work_dir,
                                             f"model_{epoch}.pt"),
                                "wb",
                        ) as f:
                            torch.save(model, f)
                        with open(
                                os.path.join(args.work_dir,
                                             f"optimizer_{epoch}.pt"),
                                "wb",
                        ) as f:
                            torch.save(optimizer.state_dict(), f)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        if args.restart_from:
            first_epoch = args.restart_from + 1
            print(f"restarting from epoch {first_epoch}")
        else:
            first_epoch = 1
        run_training()

    except KeyboardInterrupt:
        logging("-" * 100)
        logging("Exiting from training early")

    # Load the best model.
    if args.fp16:
        with open(os.path.join(args.work_dir, "amp_checkpoint.pt"), "rb") as f:
            checkpoint = torch.load(f)
            model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            amp.load_state_dict(checkpoint["amp"])
    else:
        with open(os.path.join(args.work_dir, "model.pt"), "rb") as f:
            model = torch.load(f)
    para_model = model.to(device)

    # Run on test data.
    test_loss = evaluate(te_iter, para_model)
    logging("=" * 100)
    logging("| End of training | test loss {:5.2f} | test bpc {:9.5f}".format(
        test_loss, test_loss / math.log(2)))
    logging("=" * 100)
Example #13
0
    action='store_true',
    help='Use dynamic loss scaling.  If supplied, this argument'
    ' supersedes --static-loss-scale.')
args = parser.parse_args()
args.tied = not args.not_tied

if args.d_embed < 0:
    args.d_embed = args.d_model

assert args.ext_len >= 0, 'extended context length must be non-negative'
assert args.batch_size % args.batch_chunk == 0

args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
logging = create_exp_dir(args.work_dir,
                         scripts_to_save=['train.py', 'model.py'],
                         debug=args.debug)

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print(
            'WARNING: You have a CUDA device, so you should probably run with --cuda'
        )
    else:
        torch.cuda.manual_seed_all(args.seed)

# Validate `--fp16` option
if args.fp16:
Example #14
0
def main():
    args = parse_args()
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(args.local_rank,
                                                   nproc_per_node,
                                                   args.affinity)
        print(f'{args.local_rank}: thread affinity: {affinity}')

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(
        args.work_dir,
        args.dataset,
        args.append_dataset,
        args.append_time,
    )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = args.txtlog_file
    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
    )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')
        if args.batch_size % args.batch_chunk != 0:
            raise RuntimeError('Batch size needs to be divisible by '
                               'batch chunk')

    if args.profile:
        try:
            pyprof.init(enable_function_stack=True)
        except NameError:
            warnings.warn('Called pyprof.init() but pyprof is not available')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    logging.info(f'world size: {utils.distributed.get_world_size()}')

    if not args.no_env:
        log_env_info()

    register_ignoring_timeout_handler()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len
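        # i.e. keep the total attention context (mem_len + tgt_len) constant
        # at evaluation time when eval_tgt_len differs from tgt_len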

    tr_iter = corpus.get_iterator('train',
                                  args.batch_size,
                                  args.tgt_len,
                                  device=device,
                                  ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  mem_len=eval_mem_len,
                                  ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test',
                                  args.eval_batch_size,
                                  args.eval_tgt_len,
                                  device=device,
                                  mem_len=eval_mem_len,
                                  ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
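    # cutoffs split the vocabulary into frequency-ordered clusters for the
    # adaptive softmax; tie_projs marks which cluster projections are tied to
    # the corresponding embedding projections (the head cluster stays untied)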
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
    }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum(
        [p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params,
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    scaler = None
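    # mixed precision: native PyTorch AMP uses a GradScaler for loss scaling,
    # while the apex backend wraps the model and optimizer via amp.initialize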
    if args.fp16:
        if args.amp == 'pytorch':
            scaler = torch.cuda.amp.GradScaler()
        elif args.amp == 'apex':
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.apex_amp_opt_level,
            )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz //
                                              args.batch_chunk,
                                              model,
                                              dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         max_step -
                                                         args.warmup_step,
                                                         eta_min=args.eta_min)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse,
                max_step - args.warmup_step,
                eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(optimizer_sparse,
                                                           lr_lambda=lr_lambda)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min,
        )
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse,
                factor=args.decay_rate,
                patience=args.patience,
                min_lr=args.lr_min,
            )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    start_epoch = 1
    last_batch = 0
    last_iter = 0
    best_val_loss = None

    if args.restart:
        try:
            checkpoint = load_checkpoint(args.restart)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['scheduler_state'])
            if args.fp16:
                if args.amp == 'pytorch':
                    scaler.load_state_dict(checkpoint['amp_state'])
                elif args.amp == 'apex':
                    amp.load_state_dict(checkpoint['amp_state'])
            train_step = checkpoint['train_step']
            start_epoch = checkpoint['epoch']
            last_batch = checkpoint['batch']
            last_iter = checkpoint['last_iter']
            best_val_loss = checkpoint['best_val_loss']

            if train_step >= args.max_step:
                logging.info(
                    f'Loaded checkpoint after {train_step} steps, but '
                    f'this run was scheduled for a total of '
                    f'{args.max_step} steps, exiting')
                sys.exit(1)

            model.apply(functools.partial(update_dropout, args=args))
            model.apply(functools.partial(update_dropatt, args=args))
        except FileNotFoundError:
            logging.info(f'Could not load checkpoint from {args.restart}, '
                         f'starting training from random init')

    meters = {}
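    # the AverageMeter skips the first `warmup` measurements (roughly the
    # steps needed to fill the recurrence memory) so the reported throughput
    # reflects steady-state behaviour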
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)
    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
        with TimeoutHandler() as timeout_handler:
            try:
                for epoch in itertools.count(start=start_epoch):
                    if args.roll:
                        tr_iter.roll(seed=args.seed + epoch)
                    train_step, best_val_loss = train(
                        tr_iter, va_iter, model, para_model, model_config,
                        optimizer, optimizer_sparse, scheduler,
                        scheduler_sparse, scaler, vocab, epoch, last_batch,
                        last_iter, train_step, best_val_loss, meters,
                        timeout_handler, device, args)

                    last_batch = 0
                    last_iter = 0

                    if train_step == args.max_step:
                        logging.info('-' * 100)
                        logging.info('End of training')
                        break
            except KeyboardInterrupt:
                logging.info('-' * 100)
                logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and not args.no_eval and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
            test_loss = evaluate(te_iter, model, args)
            test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'
                .format(test_elapsed, test_loss, test_loss / math.log(2)))
        else:
            logging.info(
                '| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'
                .format(test_elapsed, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

        summary.update({
            'test_elapsed': test_elapsed,
            'test_loss': test_loss,
        })

        if args.dataset in ['enwik8', 'text8']:
            summary['test_bits_per_character'] = test_loss / math.log(2)
        else:
            summary['test_perplexity'] = math.exp(test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(
        f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'valid_loss': best_val_loss,
        'valid_perplexity': val_perplexity,
    })
    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=val_perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['train_throughput'].avg)
    if not passed:
        sys.exit(1)
Example #15
0
def main():
    args = parse_args()

    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_base_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'log.log'

    log_file = os.path.join(args.work_dir, log_file)
    if args.debug:
        log_file = os.devnull

    utils.exp_utils.setup_logging(
        log_all_ranks=args.log_all_ranks,
        filename=log_file,
        filemode='a',
    )
    logging.info(args)

    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError(
            'Specify path to checkpoint using --model or --work_dir')

    checkpoint = load_checkpoint(model_path)

    if args.manual:
        args.batch_size = 1
        vocab = checkpoint['vocab']

        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
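            # older checkpoints may store a vocab without unk_idx; recover it
            # from the symbol table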
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        iter = data_utils.LMOrderedIterator(tensor,
                                            bsz=args.batch_size,
                                            bptt=args.tgt_len,
                                            device=device,
                                            ext_len=args.ext_len)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset,
                               checkpoint['args'].vocab)

        if args.split == 'valid':
            iter = corpus.get_iterator('valid',
                                       args.batch_size,
                                       args.tgt_len,
                                       device=device,
                                       ext_len=args.ext_len)
        elif args.split == 'test':
            iter = corpus.get_iterator('test',
                                       args.batch_size,
                                       args.tgt_len,
                                       device=device,
                                       ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)

    else:
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        model.load_state_dict(checkpoint['model_state'])

    model = model.eval()
    model = model.to(device)

    model = model.float()
    if args.fp16:
        model = model.half()

    if args.type != 'pytorch':
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 1
    meters['eval_throughput'] = AverageMeter(warmup=warmup,
                                             keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    loss = evaluate(iter, model, meters, args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
        }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)
        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(
                f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms'
            )

        logging.info('=' * 100)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['eval_throughput'].avg,
    )
    if not passed:
        sys.exit(1)
Example #16
0
    action='store_true',
    help='Use dynamic loss scaling.  If supplied, this argument'
    ' supersedes --static-loss-scale.')
args = parser.parse_args()
args.tied = not args.not_tied

if args.d_embed < 0:
    args.d_embed = args.d_model

assert args.ext_len >= 0, 'extended context length must be non-negative'
assert args.batch_size % args.batch_chunk == 0

args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
logging = create_exp_dir(args.work_dir,
                         scripts_to_save=['train.py', 'mem_transformer.py'],
                         debug=args.debug)

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print(
            'WARNING: You have a CUDA device, so you should probably run with --cuda'
        )
    else:
        torch.cuda.manual_seed_all(args.seed)

# Validate `--fp16` option
if args.fp16:
Example #17
0
args = parser.parse_args()
args.tied = not args.not_tied

if args.d_embed < 0:
    args.d_embed = args.d_model

assert args.ext_len >= 0, 'extended context length must be non-negative'
assert args.batch_size % args.batch_chunk == 0

args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
src_script_path = args.src_script_path
logging = create_exp_dir(args.work_dir,
                         scripts_to_save=[
                             os.path.join(src_script_path, 'train.py'),
                             os.path.join(src_script_path,
                                          'mem_transformer.py')
                         ],
                         debug=args.debug)

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print(
            'WARNING: You have a CUDA device, so you should probably run with --cuda'
        )
    else:
        torch.cuda.manual_seed_all(args.seed)