# for b_samples in samples:
#     b_seq_len = sq.find_first_min_zero(b_samples)
#     for dec, dec_len in zip(b_samples.T, b_seq_len):
#         dec_text = ' '.join(vocabs[1].i2w(dec[:dec_len]))
#         ofp.write(f'{dec_text}\n')
# decode(opt, model_opt, decode_opt, decode_batch, logger,
#        data_fn, sq.SeqModel)
else:
    vocab = sq.Vocabulary.from_vocab_file(
        os.path.join(opt['data_dir'], 'vocab.txt'))
    with open('tmp.txt', mode='w') as ofp:

        def collect_fn(batch, collect):
            # Map the label ids of the first sequence in the batch back to
            # words and write each word with its collected NLL.
            labels = vocab.i2w(batch.labels.label[:, 0])
            nlls = collect[0][:, 0]
            for label, nll in zip(labels, nlls):
                ofp.write(f'{label}\t{nll}\n')

        eval_run_fn = partial(sq.run_collecting_epoch,
                              collect_keys=['nll'],
                              collect_fn=collect_fn)
        mle(opt, model_opt, train_opt, logger, data_fn, sq.SeqModel,
            eval_run_fn=eval_run_fn)
    # mle(opt, model_opt, train_opt, logger, data_fn, sq.SeqModel)
logger.info(f'Total time: {sq.time_span_str(time.time() - start_time)}')
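# --- Illustrative sketch, not part of the original script ---
# Assumed contract, inferred from the usage above rather than from a
# documented API: sq.run_collecting_epoch calls collect_fn(batch, collect)
# once per batch, where collect holds one time-major [seq_len, batch]
# array per entry in collect_keys. Under that assumption, the same hook
# could accumulate a corpus-level perplexity instead of dumping per-token
# NLLs:
#
# import numpy as np
#
# all_nlls = []
#
# def ppl_collect_fn(batch, collect):
#     # collect[0] matches collect_keys[0] ('nll'); column 0 indexes the
#     # first sequence in the batch, mirroring collect_fn above.
#     all_nlls.extend(collect[0][:, 0].tolist())
#
# # After the epoch: perplexity = exp(mean NLL).
# ppl = np.exp(np.mean(all_nlls))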
        decode(opt, model_opt, decode_opt, decode_batch, logger,
               data_fn, sq.Word2DefModel)
else:
    if pg_opt['pg:enable']:
        # Fine-tune with policy gradient, using a task-specific reward.
        reward_fn = get_reward_fn(opt, pg_opt)
        policy_gradient(opt, model_opt, train_opt, pg_opt, logger, data_fn,
                        sq.Word2DefModel, reward_fn=reward_fn,
                        pack_data_fn=pack_data)
    else:
        mle(opt, model_opt, train_opt, logger, data_fn, sq.Word2DefModel)
        # with open('tmp.txt', 'w') as ofp:
        #     def write_score(batch, collect):
        #         enc = batch.features.enc_inputs
        #         dec = batch.features.dec_inputs
        #         score = collect[0]
        #         for i in range(len(score)):
        #             _e = enc[0, i]
        #             _d = ' '.join([str(_x) for _x in dec[:, i]])
        #             ofp.write(f'{_e}\t{_d}\t{score[i]}\n')
        #     eval_fn = partial(sq.run_collecting_epoch,
        #                       collect_keys=['dec.batch_loss'],
        #                       collect_fn=write_score)
        #     mle(opt, model_opt, train_opt, logger, data_fn,
        #         sq.Word2DefModel, eval_run_fn=eval_fn)
logger.info(f'Total time: {sq.time_span_str(time.time() - start_time)}')
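# --- Illustrative sketch, not part of the original script ---
# get_reward_fn is defined elsewhere in this repo; the (batch, samples)
# signature and shapes below are assumptions for illustration only. A
# reward function for the policy-gradient branch plausibly maps sampled
# token ids to one scalar reward per sequence, e.g. a toy length-based
# reward (a real reward would score definition quality):
#
# import numpy as np
#
# def toy_reward_fn(batch, samples):
#     # Assumes samples is time-major [seq_len, batch] and that
#     # sq.find_first_min_zero returns, per column, the index of the
#     # first 0 (end-of-sequence / padding), as its use in decode_batch
#     # suggests.
#     seq_len = sq.find_first_min_zero(samples)
#     return np.asarray(seq_len, dtype=np.float32)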
    batch_iter = partial(sq.seq2seq_batch_iter, batch_size=opt['batch_size'])
    return data, batch_iter, (enc_vocab, dec_vocab)


if opt['command'] == 'decode':
    with open(decode_opt['decode:outpath'], 'w') as ofp:

        def decode_batch(batch, samples, vocabs):
            # Map time-major encoder input ids back to words, one row per
            # sequence in the batch.
            b_enc = vocabs[0].i2w(batch.features.enc_inputs.T)
            b_enc_len = batch.features.enc_seq_len
            for b_samples in samples:
                b_seq_len = sq.find_first_min_zero(b_samples)
                for enc, enc_len, dec, dec_len in zip(
                        b_enc, b_enc_len, b_samples.T, b_seq_len):
                    # Skip padding rows, which start with `</s>`.
                    if enc[0] == '</s>':
                        continue
                    enc_text = ' '.join(enc[:enc_len - 1])
                    dec_text = ' '.join(vocabs[1].i2w(dec[:dec_len]))
                    ofp.write(f'{enc_text}\t{dec_text}\n')

        decode(opt, model_opt, decode_opt, decode_batch, logger,
               data_fn, sq.Seq2SeqModel)
else:
    if pg_opt['pg:enable']:
        policy_gradient(opt, model_opt, train_opt, pg_opt, logger, data_fn,
                        sq.Seq2SeqModel)
    else:
        mle(opt, model_opt, train_opt, logger, data_fn, sq.Seq2SeqModel)
logger.info(f'Total time: {sq.time_span_str(time.time() - start_time)}')
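# --- Illustrative sketch, not part of the original script ---
# Demonstrates the array layout decode_batch assumes: each element of
# `samples` is a time-major int array of shape [max_len, batch], so .T
# yields one row of token ids per sequence, and sq.find_first_min_zero is
# assumed to return the index of the first 0 in each column:
#
# import numpy as np
#
# toy_samples = [np.array([[4, 7],
#                          [5, 0],    # 0 marks end-of-sequence / padding
#                          [0, 0]])]  # shape [max_len=3, batch=2]
# for b in toy_samples:
#     lens = sq.find_first_min_zero(b)  # [2, 1] under the assumption above
#     for dec, dec_len in zip(b.T, lens):
#         print(dec[:dec_len])          # token ids before the first 0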