Example #1
def export_onnx(path, batch_size, seq_len):
    print('The model is also exported in ONNX format at {}'.format(
        os.path.realpath(path)))
    model.eval()
    dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(
        -1, batch_size).to(device)
    hidden = model.init_hidden(batch_size)
    torch.onnx.export(model, (dummy_input, hidden), path)
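
A minimal usage sketch for export_onnx() above. The call site and the verification step are assumptions, not part of the original example; they presume the onnx package is installed and that args.onnx_export, args.batch_size, and args.bptt are defined as in the surrounding script.

import onnx

export_onnx(args.onnx_export, batch_size=args.batch_size, seq_len=args.bptt)
onnx_model = onnx.load(args.onnx_export)  # reload the exported graph
onnx.checker.check_model(onnx_model)      # basic structural validation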
Example #2
def evaluate(data_source):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0.
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model(data, hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
            hidden = repackage_hidden(hidden)
    return total_loss / (len(data_source) - 1)
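
Both evaluate() above and the training loops below call a repackage_hidden() helper to cut the autograd history of the recurrent state. A common implementation looks like the sketch below (the version used by these scripts may differ slightly); it detaches tensors recursively so it works both for plain tensors (RNN/GRU) and for tuples of tensors (LSTM).

import torch

def repackage_hidden(h):
    """Wrap hidden states in new Tensors, detaching them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)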
Example #3
def train():
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0.
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        model.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        for p in model.parameters():
            p.data.add_(p.grad.data, alpha=-lr)

        total_loss += loss.item()

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss / args.log_interval
            elapsed = time.time() - start_time
            print(
                '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch,
                    len(train_data) // args.bptt, lr,
                    elapsed * 1000 / args.log_interval, cur_loss,
                    math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
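
The get_batch() helper used in Examples #2 and #3 typically slices a (seq_len, batch_size) chunk of inputs plus the corresponding next-token targets out of the batchified data. A sketch of the usual implementation, assuming args.bptt as in the examples above:

def get_batch(source, i):
    # source has shape (num_batches, batch_size); targets are the inputs shifted by one position.
    seq_len = min(args.bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].view(-1)
    return data, target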
Example #4
                      debug=False)

total_params = sum([np.prod(x[1].shape) for x in model.params.items()])
print('RNN type: ' + args.rnn + " Grammar: " + args.data + " Seed: " +
      str(args.seed))
print('Model total parameters: {}'.format(total_params))

try:
    if args.continue_train:
        model.params, h_init = load_params(params_file, hinit_file,
                                           model.params)
        model.reload_hidden(h_init, args.batch)
        model.init_tparams()
        model.build_model()
    else:
        model.init_hidden(args.batch, args.seed)
        save_hinit(model.h_init[0], hinit_file)
        model.build_model()

    model_test.reload_hidden(model.h_init[0], args.test_batch)
    model_test.build_model()

    # train the model
    if args.curriculum:
        model = curriculum_train(model=model,
                                 x=train_x,
                                 m=train_m,
                                 y=train_y,
                                 x_v=val_x,
                                 m_v=val_m,
                                 y_v=val_y,
Example #5
class WordLanguageModelTrial(PyTorchTrial):
    def __init__(self, context: PyTorchTrialContext):
        self.context = context
        data_config = self.context.get_data_config()
        hparams = self.context.get_hparams()
        using_bind_mount = data_config["use_bind_mount"]
        use_cache = data_config["use_cache"]
        self.eval_batch_size = hparams["eval_batch_size"]

        download_directory = (
            Path(data_config["bind_mount_path"]) if using_bind_mount else
            Path("/data")) / f"data-rank{self.context.distributed.get_rank()}"

        self.corpus = data.load_and_cache_dataset(download_directory,
                                                  use_cache)
        self.model_cls = hparams["model_cls"]
        emsize = hparams["word_embeddings_size"]
        num_hidden = hparams["num_hidden"]
        num_layers = hparams["num_layers"]
        dropout = hparams["dropout"]
        self.bptt = hparams["bptt"]

        if self.model_cls.lower() == "transformer":
            num_heads = hparams["num_heads"]
            self.model = TransformerModel(self.corpus.ntokens, emsize,
                                          num_heads, num_hidden, num_layers,
                                          dropout)
        else:
            tied = hparams["tied"]
            self.model = RNNModel(
                self.model_cls,
                self.corpus.ntokens,
                emsize,
                num_hidden,
                num_layers,
                dropout,
                tied,
            )

        self.model = self.context.wrap_model(self.model)
        self.criterion = nn.NLLLoss()

        lr = hparams["lr"]
        optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        self.optimizer = self.context.wrap_optimizer(optimizer)

        self.lr_scheduler = self.context.wrap_lr_scheduler(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                factor=0.25,
                patience=0,
                threshold=0.001,
                threshold_mode="abs",
                verbose=True,
            ),
            LRScheduler.StepMode.MANUAL_STEP,
        )

    def build_training_data_loader(self) -> DataLoader:
        train_dataset = data.WikiTextDataset(
            self.corpus,
            batch_size=self.context.get_per_slot_batch_size(),
        )
        batch_samp = data.BatchSamp(train_dataset, self.bptt)
        return DataLoader(train_dataset, batch_sampler=batch_samp)

    def build_validation_data_loader(self) -> DataLoader:
        val_dataset = data.WikiTextDataset(
            self.corpus,
            batch_size=self.eval_batch_size,
            valid=True,
        )
        self.val_data_len = len(val_dataset) - 1
        batch_samp = data.BatchSamp(val_dataset, self.bptt)
        return DataLoader(val_dataset, batch_sampler=batch_samp)

    def train_batch(self, batch: TorchData, epoch_idx: int,
                    batch_idx: int) -> Dict[str, Union[torch.Tensor, float]]:
        if batch_idx == 0 and self.model_cls.lower() != "transformer":
            self.hidden = self.model.init_hidden(
                self.context.get_per_slot_batch_size())
        inputs = batch[:-1]
        labels = batch[1:].view(-1)
        if self.model_cls.lower() == "transformer":
            output = self.model(inputs)
            output = output.view(-1, self.corpus.ntokens)
        else:
            self.hidden = self.model.repackage_hidden(self.hidden)
            output, self.hidden = self.model(inputs, self.hidden)
        loss = self.criterion(output, labels)

        self.context.backward(loss)
        self.context.step_optimizer(
            self.optimizer,
            clip_grads=lambda params: torch.nn.utils.clip_grad_norm_(
                params, self.context.get_hparam("max_grad_norm")),
        )
        return {
            "loss": loss,
            "lr": float(self.optimizer.param_groups[0]["lr"])
        }

    def evaluate_full_dataset(
            self, data_loader: DataLoader) -> Dict[str, torch.Tensor]:
        validation_loss = 0.0
        if self.model_cls.lower() != "transformer":
            self.hidden = self.model.init_hidden(self.eval_batch_size)
        for batch in data_loader:
            batch = self.context.to_device(batch)
            if self.model_cls.lower() == "transformer":
                output = self.model(batch[:-1])
                output = output.view(-1, self.corpus.ntokens)
            else:
                output, self.hidden = self.model(batch[:-1], self.hidden)
                self.hidden = self.model.repackage_hidden(self.hidden)
            validation_loss += (
                len(batch[:-1]) *
                self.criterion(output, batch[1:].view(-1)).item())

        validation_loss /= len(data_loader.dataset) - 1
        self.lr_scheduler.step(validation_loss)
        if self.model_cls.lower() != "transformer":
            self.hidden = self.model.init_hidden(
                self.context.get_per_slot_batch_size())
        return {"validation_loss": validation_loss}
Example #6
class DartsTrainer():
    def __init__(self, arm):
        # Default params for eval network
        args = {
            'emsize': 850,
            'nhid': 850,
            'nhidlast': 850,
            'dropoute': 0.1,
            'wdecay': 8e-7
        }

        args['data'] = '/home/liamli4465/darts/data/penn'
        args['lr'] = 20
        args['clip'] = 0.25
        args['batch_size'] = 64
        args['search_batch_size'] = 256 * 4
        args['small_batch_size'] = 64
        args['bptt'] = 35
        args['dropout'] = 0.75
        args['dropouth'] = 0.25
        args['dropoutx'] = 0.75
        args['dropouti'] = 0.2
        args['seed'] = arm['seed']
        args['nonmono'] = 5
        args['log_interval'] = 50
        args['save'] = arm['dir']
        args['alpha'] = 0
        args['beta'] = 1e-3
        args['max_seq_length_delta'] = 20
        args['unrolled'] = True
        args['gpu'] = 0
        args['cuda'] = True
        args['genotype'] = arm['genotype']
        args = AttrDict(args)
        self.args = args
        self.epoch = 0

        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        torch.cuda.manual_seed_all(args.seed)

        corpus = data.Corpus(args.data)
        self.corpus = corpus

        self.eval_batch_size = 10
        self.test_batch_size = 1

        self.train_data = batchify(corpus.train, args.batch_size, args)
        self.search_data = batchify(corpus.valid, args.search_batch_size, args)
        self.val_data = batchify(corpus.valid, self.eval_batch_size, args)
        self.test_data = batchify(corpus.test, self.test_batch_size, args)

        self.ntokens = len(corpus.dictionary)

    def model_save(self, fn, to_save):
        if self.epoch % 150 == 0:
            with open(
                    os.path.join(self.args.save,
                                 "checkpoint-incumbent-%d" % self.epoch),
                    'wb') as f:
                torch.save(to_save, f)

        with open(fn, 'wb') as f:
            torch.save(to_save, f)

    def model_load(self, fn):
        with open(fn, 'rb') as f:
            self.model, self.optimizer, rng_state, cuda_state = torch.load(f)
            torch.set_rng_state(rng_state)
            torch.cuda.set_rng_state(cuda_state)

    def model_resume(self, filename):
        logging.info('Resuming model from %s' % filename)
        self.model_load(filename)
        self.optimizer.param_groups[0]['lr'] = self.args.lr
        for rnn in self.model.rnns:
            rnn.genotype = self.args.genotype

    def train_epochs(self, epochs):
        args = self.args
        resume_filename = os.path.join(self.args.save, "checkpoint.incumbent")
        if os.path.exists(resume_filename):
            self.model_resume(resume_filename)
            logging.info('Loaded model from checkpoint')
        else:
            self.model = RNNModel(self.ntokens,
                                  args.emsize,
                                  args.nhid,
                                  args.nhidlast,
                                  args.dropout,
                                  args.dropouth,
                                  args.dropoutx,
                                  args.dropouti,
                                  args.dropoute,
                                  genotype=args.genotype)
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=args.lr,
                                             weight_decay=args.wdecay)

        size = 0
        for p in self.model.parameters():
            size += p.nelement()
        logging.info('param size: {}'.format(size))
        logging.info('initial genotype:')
        logging.info(self.model.rnns[0].genotype)

        total_params = sum(x.data.nelement() for x in self.model.parameters())
        logging.info('Args: {}'.format(args))
        logging.info('Model total parameters: {}'.format(total_params))

        self.model = self.model.cuda()
        # Loop over epochs.
        lr = args.lr
        best_val_loss = []
        stored_loss = 100000000

        # At any point you can hit Ctrl + C to break out of training early.
        try:
            for epoch in range(epochs):
                epoch_start_time = time.time()
                self.train()
                if 't0' in self.optimizer.param_groups[0]:
                    tmp = {}
                    for prm in self.model.parameters():
                        tmp[prm] = prm.data.clone()
                        prm.data = self.optimizer.state[prm]['ax'].clone()

                    val_loss2 = self.evaluate(self.val_data)
                    logging.info('-' * 89)
                    logging.info(
                        '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                        'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                            self.epoch,
                            (time.time() - epoch_start_time), val_loss2,
                            math.exp(val_loss2), val_loss2 / math.log(2)))
                    logging.info('-' * 89)

                    if val_loss2 < stored_loss:
                        self.model_save(
                            os.path.join(args.save, 'checkpoint.incumbent'), [
                                self.model, self.optimizer,
                                torch.get_rng_state(),
                                torch.cuda.get_rng_state()
                            ])
                        logging.info('Saving Averaged!')
                        stored_loss = val_loss2

                    for prm in self.model.parameters():
                        prm.data = tmp[prm].clone()

                else:
                    val_loss = self.evaluate(self.val_data,
                                             self.eval_batch_size)
                    logging.info('-' * 89)
                    logging.info(
                        '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                        'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                            self.epoch,
                            (time.time() - epoch_start_time), val_loss,
                            math.exp(val_loss), val_loss / math.log(2)))
                    logging.info('-' * 89)

                    if val_loss < stored_loss:
                        self.model_save(
                            os.path.join(args.save, 'checkpoint.incumbent'), [
                                self.model, self.optimizer,
                                torch.get_rng_state(),
                                torch.cuda.get_rng_state()
                            ])
                        logging.info('Saving model (new best validation)')
                        stored_loss = val_loss

                    if (self.epoch > 75
                            and 't0' not in self.optimizer.param_groups[0] and
                        (len(best_val_loss) > args.nonmono
                         and val_loss > min(best_val_loss[:-args.nonmono]))):
                        logging.info('Switching to ASGD')
                        self.optimizer = torch.optim.ASGD(
                            self.model.parameters(),
                            lr=args.lr,
                            t0=0,
                            lambd=0.,
                            weight_decay=args.wdecay)

                    best_val_loss.append(val_loss)

        except Exception as e:
            logging.info('-' * 89)
            logging.info(e)
            logging.info('Exiting from training early')
            return 0, 10000, 10000

        # Load the best saved model.
        self.model_load(os.path.join(args.save, 'checkpoint.incumbent'))

        # Run on test data.
        val_loss = self.evaluate(self.val_data, self.eval_batch_size)
        logging.info(math.exp(val_loss))
        test_loss = self.evaluate(self.test_data, self.test_batch_size)
        logging.info('=' * 89)
        logging.info(
            '| End of training | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'
            .format(test_loss, math.exp(test_loss), test_loss / math.log(2)))
        logging.info('=' * 89)

        return 0, math.exp(val_loss), math.exp(test_loss)

    def train(self):
        args = self.args
        corpus = self.corpus
        total_loss = 0
        start_time = time.time()
        hidden = [
            self.model.init_hidden(args.small_batch_size)
            for _ in range(args.batch_size // args.small_batch_size)
        ]
        batch, i = 0, 0

        while i < self.train_data.size(0) - 1 - 1:
            bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
            # Prevent excessively small or negative sequence lengths
            seq_len = max(5, int(np.random.normal(bptt, 5)))
            # There's a very small chance that it could select a very long sequence length resulting in OOM
            seq_len = min(seq_len, args.bptt + args.max_seq_length_delta)

            lr2 = self.optimizer.param_groups[0]['lr']
            self.optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
            self.model.train()
            data, targets = get_batch(self.train_data,
                                      i,
                                      args,
                                      seq_len=seq_len)

            self.optimizer.zero_grad()

            start, end, s_id = 0, args.small_batch_size, 0
            while start < args.batch_size:
                cur_data = data[:, start:end]
                cur_targets = targets[:, start:end].contiguous().view(-1)

                # Starting each batch, we detach the hidden state from how it was previously produced.
                # If we didn't, the model would try backpropagating all the way to start of the dataset.
                hidden[s_id] = repackage_hidden(hidden[s_id])

                log_prob, hidden[s_id], rnn_hs, dropped_rnn_hs = self.model(
                    cur_data, hidden[s_id], return_h=True)
                raw_loss = nn.functional.nll_loss(
                    log_prob.view(-1, log_prob.size(2)), cur_targets)

                loss = raw_loss
                # Activation Regularization
                if args.alpha > 0:
                    loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean()
                                      for dropped_rnn_h in dropped_rnn_hs[-1:])
                # Temporal Activation Regularization (slowness)
                loss = loss + sum(args.beta *
                                  (rnn_h[1:] - rnn_h[:-1]).pow(2).mean()
                                  for rnn_h in rnn_hs[-1:])
                loss *= args.small_batch_size / args.batch_size
                total_loss += raw_loss.data * args.small_batch_size / args.batch_size
                loss.backward()

                s_id += 1
                start = end
                end = start + args.small_batch_size

                gc.collect()

            # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.clip)
            self.optimizer.step()

            # total_loss += raw_loss.data
            self.optimizer.param_groups[0]['lr'] = lr2

            if np.isnan(total_loss.item()):
                raise ValueError('training loss is NaN')

            #if batch % args.log_interval == 0 and batch > 0:
            #    cur_loss = total_loss[0] / args.log_interval
            #    elapsed = time.time() - start_time
            #    logging.info('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
            #            'loss {:5.2f} | ppl {:8.2f}'.format(
            #        self.epoch, batch, len(self.train_data) // args.bptt, self.optimizer.param_groups[0]['lr'],
            #        elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            #    total_loss = 0
            #    start_time = time.time()
            batch += 1
            i += seq_len
        self.epoch += 1

    def evaluate(self, data_source, batch_size=10):
        # Turn on evaluation mode which disables dropout.
        self.model.eval()
        total_loss = 0
        hidden = self.model.init_hidden(batch_size)
        for i in range(0, data_source.size(0) - 1, self.args.bptt):
            data, targets = get_batch(data_source,
                                      i,
                                      self.args,
                                      evaluation=True)
            targets = targets.view(-1)

            log_prob, hidden = self.model(data, hidden)
            loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)),
                                          targets).data

            total_loss += loss * len(data)

            hidden = repackage_hidden(hidden)
        return total_loss.item() / len(data_source)
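
Examples #6 and #7 rely on a batchify() helper that reshapes the flat token stream into (num_batches, batch_size) columns. A sketch of the common implementation; the extra args parameter is assumed to carry the cuda flag, matching the calls above:

def batchify(data, bsz, args):
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)      # trim off tokens that do not fit evenly
    data = data.view(bsz, -1).t().contiguous()  # shape (num_batches, bsz)
    if args.cuda:
        data = data.cuda()
    return data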
Example #7
def get_df(text):
    fn = 'corpus.{}.data'.format(hashlib.md5(args.data.encode()).hexdigest())
    if args.philly:
        fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn)
    if os.path.exists(fn):
        print('Loading cached dataset...')
        corpus = torch.load(fn)
    else:
        print('Producing dataset...')
        corpus = data.Corpus(data_path, mode=mode)
        torch.save(corpus, fn)

    ntokens = len(corpus.dictionary)

    # initialize the model (these weights are immediately replaced by the checkpoint loaded below)
    model = RNNModel(args.model, ntokens, args.emsize, args.nhid,
                     args.chunk_size, args.nlayers, args.dropout,
                     args.dropouth, args.dropouti, args.dropoute, args.wdrop,
                     args.tied)

    with open(model_path, "rb") as f:
        model, criterion, optimizer = torch.load(f)

    # prepare data
    eval_batch_size = 10
    test_batch_size = 1
    train_data = batchify(corpus.train, args.batch_size, args)
    val_data = batchify(corpus.valid, eval_batch_size, args)
    test_data = batchify(corpus.test, test_batch_size, args)

    def idx2text(index):
        # `corpus` is captured from the enclosing scope; `global` would bypass it.
        text = [corpus.dictionary.idx2word[idx] for idx in index]
        text = " ".join(text)
        return text

    def text2idx(text, mode="chinese"):
        if mode == "chinese":
            idx = [
                corpus.dictionary.word2idx.get(word,
                                               corpus.dictionary.word2idx['K'])
                for word in text
            ]
        else:
            idx = [
                corpus.dictionary.word2idx.get(
                    word, corpus.dictionary.word2idx['<unk>'])
                for word in text.split()
            ]
        return idx

    idx = torch.tensor(text2idx(text, mode=mode)).unsqueeze(dim=-1).cuda()
    # seq_len = idx.size(0)
    hidden = model.init_hidden(args.batch_size)
    hidden = repackage_hidden(hidden)
    output, hidden, distances = model(idx, hidden, return_d=True)

    target_layer = 2
    target_idx = 0

    df = distances[0].cpu().data.numpy()

    target_text = [word for word in text]
    df = df[target_layer, :, target_idx]

    return df
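
One possible way to use get_df() above: pass in a sentence and plot the per-token distances for the chosen layer. The example sentence and the matplotlib plotting step are assumptions; the globals that get_df() relies on (args, data_path, mode, model_path) must already be set.

import matplotlib.pyplot as plt

sentence = "the quick brown fox jumps over the lazy dog"
df = get_df(sentence)
plt.bar(range(len(df)), df)
plt.xticks(range(len(df)), sentence.split(), rotation=90)
plt.ylabel('distance')
plt.tight_layout()
plt.show()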
Example #8
def train():
    best_val_loss = 100

    ntokens = len(corpus.dictionary)
    train_data = batchify(corpus.train, args.batch_size)  # num_batches, batch_size
    val_data = batchify(corpus.valid, args.batch_size)
    model = RNNModel(rnn_type=args.model,
                     ntoken=ntokens,
                     ninp=args.emsize,
                     nfeat=args.nfeat,
                     nhid=args.nhid,
                     nlayers=args.nlayers,
                     font_path=args.font_path,
                     font_size=args.font_size,
                     dropout=args.dropout,
                     tie_weights=args.tied,
                     ).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    print('start training...')
    hidden = model.init_hidden(args.batch_size)
    epoch_start_time = time.time()

    for epoch in range(args.epochs):

        model.eval()  # evaluate on the validation set
        total_loss = 0.
        with torch.no_grad():
            for idx in range(0, val_data.size(0) - 1, args.bptt):
                data, targets = get_batch(val_data, idx)
                output, hidden = model(data, hidden)
                output_flat = output.view(-1, ntokens)  # (seq_len, batch, ntokens) -> (seq_len*batch, ntokens)
                total_loss += len(data) * criterion(output_flat, targets.view(-1)).item()
                hidden = repackage_hidden(hidden)
        val_loss = total_loss / len(val_data)
        best_val_loss = min(best_val_loss, val_loss)
        print('-' * 100)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f} | best valid ppl {:8.2f}'
              .format(epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss), math.exp(best_val_loss)))
        print('-' * 100)
        epoch_start_time = time.time()
        if val_loss == best_val_loss:  # Save the model if the validation loss is best so far.
            torch.save(model, os.path.join(args.save, 'model.pkl'))
        else:
            args.lr /= 4.0  # anneal the learning rate when validation loss does not improve
            for group in optimizer.param_groups:
                group['lr'] = args.lr

        model.train()  # train on the training set
        total_loss = 0.
        start_time = time.time()
        for i, idx in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
            data, targets = get_batch(train_data, idx)
            hidden = repackage_hidden(hidden)
            model.zero_grad()  # clear old gradients before computing the loss and new gradients
            output, hidden = model(data, hidden)
            loss = criterion(output.view(-1, ntokens), targets.view(-1))
            loss.backward()
            total_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)  # clip gradients to avoid explosion
            optimizer.step()  # update the parameters with the gradients
            # for p in model.parameters():
            #     p.data.add_(-args.lr, p.grad.data)

            if i % args.log_interval == 0 and i > 0:
                cur_loss = total_loss / args.log_interval
                elapsed = time.time() - start_time
                print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} |loss {:5.2f} | ppl {:8.2f}'
                      .format(epoch + 1, i, len(train_data) // args.bptt, args.lr, elapsed * 1000 / args.log_interval,
                              cur_loss, math.exp(cur_loss)))
                total_loss = 0
                start_time = time.time()
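
A minimal sketch of how the checkpoint saved by the loop above could be reloaded for inference. The path mirrors the torch.save() call; map_location, device, and having the RNNModel class importable are assumptions.

model = torch.load(os.path.join(args.save, 'model.pkl'), map_location=device)
model.eval()  # disable dropout for evaluation / generation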
Example #9
def train():
    # load the data and configure the model
    print("Loading data...")
    corpus = Corpus(train_dir)
    print(corpus)

    config = Config()
    config.vocab_size = len(corpus.dictionary)
    train_data = batchify(corpus.train, config.batch_size)
    train_len = train_data.size(0)
    seq_len = config.seq_len

    print("Configuring model...")
    model = RNNModel(config)
    if use_cuda:
        model.cuda()
    print(model)

    criterion = nn.CrossEntropyLoss()
    lr = config.learning_rate  # initial learning rate
    start_time = time.time()

    print("Training and generating...")
    for epoch in range(1, config.num_epochs + 1):  # train over multiple epochs
        total_loss = 0.0
        model.train()  # dropout is only active in training mode
        hidden = model.init_hidden(config.batch_size)  # initialize the hidden state

        for ibatch, i in enumerate(range(0, train_len - 1, seq_len)):
            data, targets = get_batch(train_data, i, seq_len)  # fetch one batch of data
            # Before each batch, detach the hidden state from how it was previously produced.
            # Otherwise the model would try to backpropagate all the way to the start of the dataset.
            hidden = repackage_hidden(hidden)
            model.zero_grad()

            output, hidden = model(data, hidden)
            loss = criterion(output.view(-1, config.vocab_size), targets)
            loss.backward()  # backpropagate

            # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
            for p in model.parameters():  # manual SGD update of the parameters
                p.data.add_(p.grad.data, alpha=-lr)

            total_loss += loss.item()  # accumulate the loss

            if ibatch % config.log_interval == 0 and ibatch > 0:  # report status every log_interval batches
                cur_loss = total_loss / config.log_interval
                elapsed = get_time_dif(start_time)
                print(
                    "Epoch {:3d}, {:5d}/{:5d} batches, lr {:2.3f}, loss {:5.2f}, ppl {:8.2f}, time {}"
                    .format(epoch, ibatch, train_len // seq_len, lr, cur_loss,
                            math.exp(cur_loss), elapsed))
                total_loss = 0.0
        lr /= 4.0  # anneal the learning rate after each epoch

        # save the model parameters every save_interval epochs
        if epoch % config.save_interval == 0:
            torch.save(model.state_dict(),
                       os.path.join(save_dir, model_name.format(epoch)))

        print(''.join(generate(model, corpus.dictionary.idx2word)))
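
Example #9 ends by calling a generate() helper that is not shown. A possible shape for it, modeled on the standard word-language-model sampling loop; the signature, the temperature parameter, and the use_cuda flag are assumptions.

def generate(model, idx2word, num_words=100, temperature=1.0):
    model.eval()
    hidden = model.init_hidden(1)
    inp = torch.randint(len(idx2word), (1, 1), dtype=torch.long)
    if use_cuda:
        inp = inp.cuda()
    words = []
    with torch.no_grad():
        for _ in range(num_words):
            output, hidden = model(inp, hidden)
            # Sample the next token from the temperature-scaled output distribution.
            word_weights = output.squeeze().div(temperature).exp()
            word_idx = torch.multinomial(word_weights, 1).item()
            inp.fill_(word_idx)
            words.append(idx2word[word_idx])
    return words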