def _train_epoch(epoch: int):
    """Runs one training epoch (PyTorch version)."""
    torch.cuda.empty_cache()
    random.shuffle(train_data)
    # Bucket examples by (source, target) length so each batch needs
    # minimal padding; `key` orders examples within each pooled chunk.
    train_iter = data.iterator.pool(
        train_data,
        config_data.batch_size,
        key=lambda x: (len(x[0]), len(x[1])),
        batch_size_fn=utils.batch_size_fn,
        random_shuffler=data.iterator.RandomShuffler())

    for _, train_batch in tqdm(enumerate(train_iter)):
        optim.zero_grad()
        # Pad and concatenate the raw batch into (encoder input,
        # decoder input, label) tensors on the target device.
        in_arrays = data_utils.seq2seq_pad_concat_convert(
            train_batch, device=device)
        loss = model(
            encoder_input=in_arrays[0],
            is_train_mode=True,
            decoder_input=in_arrays[1],
            labels=in_arrays[2],
        )
        loss.backward()
        optim.step()
        scheduler.step()
        step = scheduler.last_epoch

        if step % config_data.display_steps == 0:
            logger.info('step: %d, loss: %.4f', step, loss)
            lr = optim.param_groups[0]['lr']
            print(f'lr: {lr}, step: {step}, loss: {loss:.4f}')
        if step and step % config_data.eval_steps == 0:
            _eval_epoch(epoch, mode='eval')
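
# `utils.batch_size_fn`, passed to `data.iterator.pool` above, is not shown
# in this excerpt. Below is a minimal sketch of a token-based variant; the
# body is an assumption, following the torchtext convention
# batch_size_fn(new, count, size_so_far) -> new effective size. It makes
# `config_data.batch_size` count tokens rather than sentences, so batches
# of long sentence pairs contain fewer examples.
def batch_size_fn(new, count, size_so_far):
    # `new` is one (source, target) pair of token-id lists; `count` is the
    # number of examples in the batch after adding `new`; `size_so_far` is
    # the value this function returned for the previous example (0 at the
    # start of a new batch).
    longest_prev = size_so_far // max(count - 1, 1)
    longest = max(len(new[0]), len(new[1]), longest_prev)
    # Effective batch size = number of examples * longest padded length.
    return count * longest
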
def _train_epoch(sess, epoch, step, smry_writer):
    """Runs one training epoch (TensorFlow session version).

    Returns the global step after the last update so the caller can
    resume the step count in the next epoch.
    """
    random.shuffle(train_data)
    train_iter = data.iterator.pool(
        train_data,
        config_data.batch_size,
        key=lambda x: (len(x[0]), len(x[1])),
        batch_size_fn=utils.batch_size_fn,
        random_shuffler=data.iterator.RandomShuffler())

    for _, train_batch in enumerate(train_iter):
        in_arrays = data_utils.seq2seq_pad_concat_convert(train_batch)
        # The learning rate is computed host-side each step and fed in.
        feed_dict = {
            encoder_input: in_arrays[0],
            decoder_input: in_arrays[1],
            labels: in_arrays[2],
            learning_rate: utils.get_lr(step, config_model.lr),
        }
        fetches = {
            'step': global_step,
            'train_op': train_op,
            'smry': summary_merged,
            'loss': mle_loss,
        }
        fetches_ = sess.run(fetches, feed_dict=feed_dict)
        step, loss = fetches_['step'], fetches_['loss']

        if step and step % config_data.display_steps == 0:
            logger.info('step: %d, loss: %.4f', step, loss)
            print('step: %d, loss: %.4f' % (step, loss))
            smry_writer.add_summary(fetches_['smry'], global_step=step)
        if step and step % config_data.eval_steps == 0:
            _eval_epoch(sess, epoch, mode='eval')
    return step
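
# `utils.get_lr` above is likewise not shown in this excerpt. A minimal
# sketch, assuming it implements the inverse-square-root ("Noam") schedule
# from "Attention Is All You Need": linear warmup for `warmup_steps`, then
# decay as 1/sqrt(step). The config keys `lr_constant` and `warmup_steps`
# are assumptions, not confirmed by this excerpt.
def get_lr(step, opt_config):
    step = max(step, 1)  # guard against 0 ** -0.5 on the first call
    warmup = opt_config['warmup_steps']
    # lr = c * min(step^-0.5, step * warmup^-1.5)
    return opt_config['lr_constant'] * min(
        step ** -0.5, step * warmup ** -1.5)
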