def _train_epoch(epoch: int):
    """Run one training pass over ``train_data``.

    Shuffles ``train_data`` in place and groups it into batches with
    ``data.iterator.pool``. For every batch it then:
      - zeroes gradients and pads/converts the batch onto ``device``;
      - runs the model in training mode and backpropagates the loss;
      - steps the optimizer and then the LR scheduler.

    Every ``config_data.display_steps`` steps the loss and current learning
    rate are logged and printed. Every ``config_data.eval_steps`` steps
    ``_eval_epoch`` is run.

    Args:
        epoch: Index of the current epoch; only forwarded to ``_eval_epoch``.
    """
    torch.cuda.empty_cache()
    random.shuffle(train_data)
    train_iter = data.iterator.pool(
        train_data,
        config_data.batch_size,
        # key is not used if sort_within_batch is False by default
        key=lambda x: (len(x[0]), len(x[1])),
        batch_size_fn=utils.batch_size_fn,
        random_shuffler=data.iterator.RandomShuffler())

    # The batch index was never used, so iterate the batches directly.
    for train_batch in tqdm(train_iter):
        optim.zero_grad()
        in_arrays = data_utils.seq2seq_pad_concat_convert(
            train_batch, device=device)
        loss = model(
            encoder_input=in_arrays[0],
            is_train_mode=True,
            decoder_input=in_arrays[1],
            labels=in_arrays[2],
        )
        loss.backward()

        optim.step()
        scheduler.step()

        # The scheduler's step counter doubles as the global step. With a
        # standard torch LR scheduler this is >= 1 after scheduler.step()
        # (assumption about the scheduler type -- confirm).
        step = scheduler.last_epoch
        if step % config_data.display_steps == 0:
            # Convert the loss once, so the log line and the printed line show
            # the same value with the same fixed-point format. Previously the
            # print used `.4` (significant digits) while the log used `.4f`.
            loss_value = loss.item()
            logger.info('step: %d, loss: %.4f', step, loss_value)
            # NOTE(review): read after scheduler.step(), so this is presumably
            # the LR for the *next* step rather than the one just applied.
            lr = optim.param_groups[0]['lr']
            print(f"lr: {lr} step: {step}, loss: {loss_value:.4f}")
        if step and step % config_data.eval_steps == 0:
            _eval_epoch(epoch, mode='eval')
Example #2
0
    def _train_epoch(sess, epoch, step, smry_writer):
        """Train over one shuffled pass of ``train_data`` in a session.

        Each batch is padded/converted and fed to ``sess.run`` together with a
        learning rate from ``utils.get_lr``. The run fetches the global step,
        the train op, the merged summary and the MLE loss.

        Once the fetched step is non-zero:
          - every ``config_data.display_steps`` steps the loss is logged and
            printed, and the summary is written;
          - every ``config_data.eval_steps`` steps ``_eval_epoch`` is run.

        Args:
            sess: Session used to run the training fetches.
            epoch: Current epoch index, forwarded to ``_eval_epoch``.
            step: Global step on entry; used for the first learning rate.
            smry_writer: Writer that receives the fetched summaries.

        Returns:
            The last fetched global step, or the incoming ``step`` when no
            batch was produced.
        """
        random.shuffle(train_data)
        batches = data.iterator.pool(
            train_data,
            config_data.batch_size,
            key=lambda pair: (len(pair[0]), len(pair[1])),
            batch_size_fn=utils.batch_size_fn,
            random_shuffler=data.iterator.RandomShuffler())

        for batch in batches:
            arrays = data_utils.seq2seq_pad_concat_convert(batch)
            feeds = {
                encoder_input: arrays[0],
                decoder_input: arrays[1],
                labels: arrays[2],
                # `step` is the value fetched on the previous iteration (or the
                # caller's value on the first one).
                learning_rate: utils.get_lr(step, config_model.lr),
            }
            wanted = dict(
                step=global_step,
                train_op=train_op,
                smry=summary_merged,
                loss=mle_loss,
            )
            results = sess.run(wanted, feed_dict=feeds)

            step = results['step']
            loss = results['loss']
            # Both periodic actions are skipped while the step is still zero.
            if not step:
                continue
            if step % config_data.display_steps == 0:
                logger.info('step: %d, loss: %.4f', step, loss)
                print('step: %d, loss: %.4f' % (step, loss))
                smry_writer.add_summary(results['smry'], global_step=step)
            if step % config_data.eval_steps == 0:
                _eval_epoch(sess, epoch, mode='eval')
        return step