def __init__(self, expt_dir='experiment', loss=NLLLoss(), batch_size=64,
             random_seed=None, checkpoint_every=1000, print_every=100,
             tensorboard=True, batch_adv_loss=NLLLoss()):
    self._trainer = "Adversarial Trainer"
    self.random_seed = random_seed
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)
    self.loss = loss
    self.evaluator = Evaluator(loss=self.loss, batch_size=batch_size)
    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every
    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    if not os.path.exists(self.expt_dir):
        os.makedirs(self.expt_dir)
    self.batch_size = batch_size
    self.logger = logging.getLogger(__name__)
    self.writer = SummaryWriter(log_dir=expt_dir) if tensorboard else None
    self.batch_adv_loss = batch_adv_loss
def __init__(self, expt_dir='experiment', loss=[NLLLoss()], loss_weights=None,
             metrics=[], batch_size=64, eval_batch_size=128, random_seed=None,
             checkpoint_every=100, print_every=100):
    self._trainer = "Simple Trainer"
    self.random_seed = random_seed
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)
    self.loss = loss
    self.metrics = metrics
    self.loss_weights = loss_weights or len(loss) * [1.]
    self.evaluator = Evaluator(loss=self.loss, metrics=self.metrics,
                               batch_size=eval_batch_size)
    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every
    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    if not os.path.exists(self.expt_dir):
        os.makedirs(self.expt_dir)
    self.batch_size = batch_size
    self.logger = logging.getLogger(__name__)
def __init__(self, expt_dir='experiment', loss=NLLLoss(), batch_size=64,
             random_seed=None, state_loss=NLLLoss(), checkpoint_every=100,
             print_every=100):
    self._trainer = "Simple Trainer"
    self.random_seed = random_seed
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)
    self.loss = loss
    self.state_loss = state_loss
    self.evaluator = Evaluator(loss=self.loss, batch_size=batch_size)
    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every
    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    if not os.path.exists(self.expt_dir):
        os.makedirs(self.expt_dir)
    self.batch_size = batch_size
    self.logger = logging.getLogger(__name__)
def __init__(self, vocab_size: int, embedding_size: int, n_hidden: int,
             sos_token: int = 0, eos_token: int = 1, mask_token: int = 2,
             max_output_length: int = 100, rnn_cell: str = 'lstm') -> None:
    self.decoder = DecoderRNN(vocab_size, max_output_length, embedding_size,
                              n_layers=n_hidden, rnn_cell=rnn_cell,
                              use_attention=False, bidirectional=False,
                              eos_id=eos_token, sos_id=sos_token)
    if torch.cuda.is_available():
        self.decoder.cuda()
    self.rnn_cell = rnn_cell
    self.n_hidden = n_hidden
    self.embedding_size = embedding_size
    self.SOS_token = sos_token
    self.EOS_token = eos_token
    self.mask_token = mask_token
    self.max_output_length = max_output_length
    token_weights = torch.ones(vocab_size)
    if torch.cuda.is_available():
        token_weights = token_weights.cuda()
    self.loss = NLLLoss(weight=token_weights, mask=mask_token)
    self.optimizer = None
def __init__(self, expt_dir='experiment_sc', loss=PositiveLoss(), batch_size=64,
             random_seed=None, checkpoint_every=100, print_every=100,
             output_vocab=None):
    self._trainer = "Self Critical Trainer"
    self.random_seed = random_seed
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)
    self.loss = loss
    self.evaluator = Evaluator(loss=NLLLoss(), batch_size=batch_size)
    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every
    self.output_vocab = output_vocab
    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    if not os.path.exists(self.expt_dir):
        os.makedirs(self.expt_dir)
    self.batch_size = batch_size
    self.logger = logging.getLogger(__name__)
def __init__(self, expt_dir='experiments', loss=NLLLoss(), batch_size=64,
             random_seed=None, checkpoint_every=100, patience=5):
    self._trainer = "Simple Trainer"
    self.random_seed = random_seed
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)
    self.loss = loss
    # set by a subclass
    self.evaluator = None
    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.early_stopping_teacher = EarlyStopping_NoImprovement(patience=patience)
    self.early_stopping_student = EarlyStopping_NoImprovement(patience=patience)
    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    if not os.path.exists(self.expt_dir):
        os.makedirs(self.expt_dir)
    self.batch_size = batch_size
def __init__(self, expt_dir='experiment', loss=NLLLoss(), batch_size=64,
             random_seed=None, checkpoint_every=100, print_every=100,
             optimizer=Optimizer(optim.Adam, max_grad_norm=5)):
    self._trainer = "Simple Trainer"
    self.random_seed = random_seed
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)
    self.loss = loss
    self.evaluator = Evaluator(loss=self.loss, batch_size=batch_size)
    self.optimizer = optimizer
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every
    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    if not os.path.exists(self.expt_dir):
        os.makedirs(self.expt_dir)
    self.batch_size = batch_size
    self.input_vocab_file = os.path.join(self.expt_dir, 'input_vocab')
    self.output_vocab_file = os.path.join(self.expt_dir, 'output_vocab')
    self.logger = logging.getLogger(__name__)
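# Hypothetical usage sketch for the trainer variants above, assuming they are
# forks of pytorch-seq2seq's SupervisedTrainer and keep its train() signature.
# `SupervisedTrainer` is the assumed class name; `seq2seq_model`, `train_data`,
# and `dev_data` are placeholders, not defined in this file.
trainer = SupervisedTrainer(expt_dir='./experiment', loss=NLLLoss(),
                            batch_size=32, checkpoint_every=50, print_every=10)
model = trainer.train(seq2seq_model, train_data, num_epochs=6,
                      dev_data=dev_data, teacher_forcing_ratio=0.5)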
def __init__(self, loss=NLLLoss(), batch_size=64): """Class to initialize an evaluator Args: loss (seq2seq.loss, optional): loss for evaluator (default: seq2seq.loss.NLLLoss) batch_size (int, optional): batch size for evaluator (default: 64) """ self.loss = loss self.batch_size = batch_size
def __init__(self, loss=NLLLoss(), explosion_rate=120, batch_size=1024,
             polyglot=False):
    self.loss = loss
    self.batch_size = batch_size
    self.polyglot = polyglot
    self.explosion_rate = explosion_rate
def __init__(self, expt_dir='experiment', loss=[NLLLoss()], loss_weights=None,
             metrics=[], batch_size=64, eval_batch_size=128, random_seed=None,
             checkpoint_every=100, print_every=100, early_stopper=None,
             anneal_middropout=0, min_middropout=0.01):
    self._trainer = "Simple Trainer"
    self.random_seed = random_seed
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)
    self.loss = loss
    self.metrics = metrics
    self.loss_weights = loss_weights or len(loss) * [1.]
    self.evaluator = Evaluator(loss=self.loss, metrics=self.metrics,
                               batch_size=eval_batch_size)
    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every
    self.anneal_middropout = anneal_middropout
    self.min_middropout = 0 if self.anneal_middropout == 0 else min_middropout
    self.early_stopper = early_stopper
    if early_stopper is not None:
        assert self.early_stopper.mode == "min", \
            "Can currently only be used with the loss, please use mode='min'"
    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    if not os.path.exists(self.expt_dir):
        os.makedirs(self.expt_dir)
    self.batch_size = batch_size
    self.logger = logging.getLogger(__name__)
def test_perplexity(self):
    nll = NLLLoss()
    ppl = Perplexity()
    nll.eval_batch(self.outputs, self.batch)
    ppl.eval_batch(self.outputs, self.batch)
    nll_loss = nll.get_loss()
    ppl_loss = ppl.get_loss()
    self.assertAlmostEqual(ppl_loss, math.exp(nll_loss))
def test_nllloss_WITH_OUT_SIZE_AVERAGE(self):
    loss = NLLLoss(size_average=False)
    pytorch_loss = 0
    pytorch_criterion = torch.nn.NLLLoss(size_average=False)
    for output, target in zip(self.outputs, self.targets):
        loss.eval_batch(output, target)
        pytorch_loss += pytorch_criterion(output, target)
    loss_val = loss.get_loss()
    self.assertAlmostEqual(loss_val, pytorch_loss.data[0])
def test_perplexity(self):
    nll = NLLLoss()
    ppl = Perplexity()
    for output, target in zip(self.outputs, self.targets):
        nll.eval_batch(output, target)
        ppl.eval_batch(output, target)
    nll_loss = nll.get_loss()
    ppl_loss = ppl.get_loss()
    self.assertAlmostEqual(ppl_loss, math.exp(nll_loss))
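# The perplexity tests above rely on the identity PPL = exp(NLL): Perplexity
# accumulates the same per-token negative log-likelihood as NLLLoss and
# exponentiates its average. A standalone sketch of that relationship using
# plain torch (tensor shapes are illustrative only):
import math
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                               # batch of 4, vocab of 10
targets = torch.randint(0, 10, (4,))
nll = F.nll_loss(F.log_softmax(logits, dim=-1), targets)  # mean NLL per token
ppl = math.exp(nll.item())                                # perplexity = exp(mean NLL)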
def test_nllloss_WITH_OUT_SIZE_AVERAGE(self):
    loss = NLLLoss(reduction='sum')
    pytorch_loss = 0
    pytorch_criterion = torch.nn.NLLLoss(reduction='sum')
    for output, target in zip(self.outputs, self.targets):
        loss.eval_batch(output, target)
        pytorch_loss += pytorch_criterion(output, target)
    loss_val = loss.get_loss()
    self.assertAlmostEqual(loss_val, pytorch_loss.item())
def __init__(self, expt_dir='experiment', loss=NLLLoss(), batch_size=64,
             random_seed=None, checkpoint_every=100, pretraining=False,
             polyglot=False, explosion_train=10, explosion_eval=120):
    super(MirrorTrainer, self).__init__(
        expt_dir=expt_dir, loss=loss, batch_size=batch_size,
        random_seed=random_seed, checkpoint_every=checkpoint_every)
    self._trainer = "Mirror Trainer"
    self.pretraining = pretraining
    self.polyglot = polyglot
    self.evaluator = PolyEvaluator(explosion_rate=explosion_eval, loss=self.loss,
                                   batch_size=512, polyglot=self.polyglot)
    self.explosion_train = explosion_train
def test_nllloss(self):
    loss = NLLLoss()
    pytorch_loss = 0
    pytorch_criterion = torch.nn.NLLLoss()
    for output, target in zip(self.outputs, self.targets):
        loss.eval_batch(output, target)
        pytorch_loss += pytorch_criterion(output, target)
    loss_val = loss.get_loss()
    pytorch_loss /= self.num_batch
    self.assertAlmostEqual(loss_val, pytorch_loss.data[0])
def test_nllloss_WITH_OUT_SIZE_AVERAGE(self):
    num_repeat = 10
    loss = NLLLoss(reduction='sum')
    pytorch_loss = 0
    pytorch_criterion = torch.nn.NLLLoss(reduction='sum')
    for _ in range(num_repeat):
        for step, output in enumerate(self.outputs):
            pytorch_loss += pytorch_criterion(output, self.targets[:, step + 1])
        loss.eval_batch(self.outputs, self.batch)
    loss_val = loss.get_loss()
    self.assertAlmostEqual(loss_val, pytorch_loss.item())
def test_nllloss(self):
    num_batch = 10
    loss = NLLLoss()
    pytorch_loss = 0
    pytorch_criterion = torch.nn.NLLLoss()
    for _ in range(num_batch):
        for step, output in enumerate(self.outputs):
            pytorch_loss += pytorch_criterion(output, self.targets[:, step + 1])
        loss.eval_batch(self.outputs, self.batch)
    loss_val = loss.get_loss()
    pytorch_loss /= (num_batch * len(self.outputs))
    self.assertAlmostEqual(loss_val, pytorch_loss.item())
def __init__(self, model_dir='experiment', best_model_dir='experiment/best',
             loss=NLLLoss(), batch_size=64, random_seed=None,
             checkpoint_every=100, print_every=100, max_epochs=5,
             max_steps=10000, max_checkpoints_num=5, best_ppl=100000.0,
             device=None):
    self._trainer = "Simple Trainer"
    self.random_seed = random_seed
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)
    self.loss = loss
    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every
    self.max_steps = max_steps
    self.max_epochs = max_epochs
    self.batch_size = batch_size
    self.best_ppl = best_ppl
    self.max_checkpoints_num = max_checkpoints_num
    self.device = device
    self.evaluator = Evaluator(loss=self.loss, batch_size=batch_size, device=device)
    if not os.path.isabs(model_dir):
        model_dir = os.path.join(os.getcwd(), model_dir)
    self.model_dir = model_dir
    if not os.path.exists(self.model_dir):
        os.makedirs(self.model_dir)
    if not os.path.isabs(best_model_dir):
        best_model_dir = os.path.join(os.getcwd(), best_model_dir)
    self.best_model_dir = best_model_dir
    if not os.path.exists(self.best_model_dir):
        os.makedirs(self.best_model_dir)
    self.model_checkpoints = []
    self.best_model_checkpoints = []
    self.logger = logging.getLogger(__name__)
def __init__(self, experiment_directory='./experiment', loss=None, batch_size=64,
             random_seed=None, checkpoint_every=100, print_every=100):
    if loss is None:
        loss = NLLLoss()
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)
    self.loss = loss
    self.evaluator = Evaluator(loss=self.loss, batch_size=batch_size)
    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every
    self.batch_size = batch_size
    self.experiment_directory = experiment_directory
    if not os.path.exists(self.experiment_directory):
        os.makedirs(self.experiment_directory)
def test_perplexity(self):
    num_class = 5
    num_batch = 10
    batch_size = 5
    outputs = [F.softmax(Variable(torch.randn(batch_size, num_class)))
               for _ in range(num_batch)]
    targets = [Variable(torch.LongTensor(
                   [random.randint(0, num_class - 1) for _ in range(batch_size)]))
               for _ in range(num_batch)]
    nll = NLLLoss()
    ppl = Perplexity()
    for output, target in zip(outputs, targets):
        nll.eval_batch(output, target)
        ppl.eval_batch(output, target)
    nll_loss = nll.get_loss()
    ppl_loss = ppl.get_loss()
    self.assertAlmostEqual(ppl_loss, math.exp(nll_loss))
def test_nllloss(self):
    num_class = 5
    num_batch = 10
    batch_size = 5
    outputs = [F.softmax(Variable(torch.randn(batch_size, num_class)))
               for _ in range(num_batch)]
    targets = [Variable(torch.LongTensor(
                   [random.randint(0, num_class - 1) for _ in range(batch_size)]))
               for _ in range(num_batch)]
    loss = NLLLoss()
    pytorch_loss = 0
    pytorch_criterion = torch.nn.NLLLoss()
    for output, target in zip(outputs, targets):
        loss.eval_batch(output, target)
        pytorch_loss += pytorch_criterion(output, target)
    loss_val = loss.get_loss()
    pytorch_loss /= num_batch
    self.assertAlmostEqual(loss_val, pytorch_loss.data[0])
def prepare_losses_and_metrics(
        opt, pad, unk, sos, eos, input_vocab, output_vocab):
    use_output_eos = not opt.ignore_output_eos

    # Prepare loss and metrics
    losses = [NLLLoss(ignore_index=pad)]
    loss_weights = [1.]
    for loss in losses:
        loss.to(device)

    metrics = []
    if 'word_acc' in opt.metrics:
        metrics.append(WordAccuracy(ignore_index=pad))
    if 'seq_acc' in opt.metrics:
        metrics.append(SequenceAccuracy(ignore_index=pad))
    if 'target_acc' in opt.metrics:
        metrics.append(FinalTargetAccuracy(ignore_index=pad, eos_id=eos))
    if 'sym_rwr_acc' in opt.metrics:
        metrics.append(SymbolRewritingAccuracy(
            input_vocab=input_vocab,
            output_vocab=output_vocab,
            use_output_eos=use_output_eos,
            output_sos_symbol=sos,
            output_pad_symbol=pad,
            output_eos_symbol=eos,
            output_unk_symbol=unk))
    if 'bleu' in opt.metrics:
        metrics.append(BLEU(
            input_vocab=input_vocab,
            output_vocab=output_vocab,
            use_output_eos=use_output_eos,
            output_sos_symbol=sos,
            output_pad_symbol=pad,
            output_eos_symbol=eos,
            output_unk_symbol=unk))

    return losses, loss_weights, metrics
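# Hypothetical call site for prepare_losses_and_metrics above. `opt` would be
# the parsed argparse namespace (only opt.ignore_output_eos and opt.metrics are
# read), and the special-token ids and vocabularies would come from the
# torchtext fields; every name below is a placeholder.
losses, loss_weights, metrics = prepare_losses_and_metrics(
    opt, pad=pad_id, unk=unk_id, sos=sos_id, eos=eos_id,
    input_vocab=src_field.vocab, output_vocab=tgt_field.vocab)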
def __init__(self, loss=NLLLoss(), loss_plan=None, loss_reconstruct=None,
             batch_size=64, valid_dep=None):
    self.loss = loss
    self.loss_plan = loss_plan
    self.loss_reconstruct = loss_reconstruct
    self.batch_size = batch_size
    self.valid_dep = valid_dep
# inputs.build_vocab(src.vocab)
# src.vocab.load_vectors(wv_type='glove.6B', wv_dim=opt.word_dim)
src.vocab.load_vectors(wv_type=opt.pre_emb, wv_dim=opt.word_dim)
# src.vocab.load_vectors(wv_type='fasttext.en.300d', wv_dim=300)
# src.vocab.load_vectors(wv_type='charngram.100d', wv_dim=100)

# NOTE: If the source field name and the target field name
# are different from 'src' and 'tgt' respectively, they have
# to be set explicitly before any training or inference
# seq2seq.src_field_name = 'src'
# seq2seq.tgt_field_name = 'tgt'

# Prepare loss
weight = torch.ones(len(tgt.vocab))
pad = tgt.vocab.stoi[tgt.pad_token]
# loss = Perplexity(weight, pad)
loss = NLLLoss(weight=weight, mask=pad, size_average=True)
if torch.cuda.is_available():
    loss.cuda()

seq2seq = None
optimizer = None
if not opt.resume:
    # Initialize model
    hidden_size = opt.word_lstm_dim
    bidirectional = opt.word_bidirect
    encoder = EncoderRNN(vocab_size=len(src.vocab), max_len=max_len,
                         word_dim=opt.word_dim, hidden_size=hidden_size,
                         input_dropout_p=opt.input_dropout,
                         bidirectional=bidirectional,
        print('Loaded A1 as submodel')
    else:
        A1 = m
    A1.to(device)
else:
    A1 = get_seq2seq()

if args.init_A1_from_A2:
    with open(args.init_A1_from_A2, "rb") as fin:
        A1 = pickle.load(fin).A2.to(device)
    print('Loaded A1 as an A2 submodel')

A1.flatten_parameters()

weight = torch.ones(len(field.vocab.stoi), device=device)
pad = field.vocab.stoi['<pad>']
loss = NLLLoss(weight, pad)

train_dataset = teacher_train
dev_dataset = teacher_dev
test_dataset = teacher_test

if args.eval is not None:
    evaluator = PolyEvaluator(loss=loss, explosion_rate=args.explosion_eval,
                              batch_size=2048, polyglot=polyglot)
    with open(args.eval, "rb") as fin:
        model = pickle.load(fin)
    eval_results = evaluator.evaluate(model, dev_dataset)
    dev_loss, teacher_accuracy, student_accuracy = eval_results
    log_msg = "Dev %s: %.4f, Accuracy Teacher: %.4f, Accuracy Student: %.4f" % (
        loss.name, dev_loss, teacher_accuracy, student_accuracy)
    print(log_msg, flush=True)
def __init__(self, loss=NLLLoss(), batch_size=64):
    self.loss = loss
    self.batch_size = batch_size
except AttributeError:
    D = BinaryClassifierCNN(len(EN.vocab), embed_dim=opt.embed_dim,
                            num_kernel=opt.num_kernel,
                            kernel_sizes=opt.kernel_sizes,
                            dropout_p=opt.dropout_p)

# optim_G = ScheduledOptim(optim.Adam(
#     G.get_trainable_parameters(),
#     betas=(0.9, 0.98), eps=1e-09),
#     opt.d_model, opt.n_warmup_steps)
optim_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-09)
optim_D = torch.optim.Adam(D.parameters(), lr=1e-4)

crit_G = NLLLoss(size_average=False)
crit_D = nn.BCELoss()
if opt.cuda:
    G.cuda()
    D.cuda()
    crit_G.cuda()
    crit_D.cuda()

# ---------- train ----------
trainer_D = trainers.DiscriminatorTrainer()

# pre-train D
if not hasattr(opt, 'load_D_from'):
    pool = helper.DiscriminatorDataPool(opt.max_len, D.min_len, Constants.PAD)
src.build_vocab(train, max_size=50000)
tgt.build_vocab(train, max_size=50000)
input_vocab = src.vocab
output_vocab = tgt.vocab

# NOTE: If the source field name and the target field name
# are different from 'src' and 'tgt' respectively, they have
# to be set explicitly before any training or inference
# seq2seq.src_field_name = 'src'
# seq2seq.tgt_field_name = 'tgt'

# Prepare loss
weight = torch.ones(len(tgt.vocab))
pad = tgt.vocab.stoi[tgt.pad_token]
# loss = Perplexity(weight, pad)
loss = NLLLoss(size_average=False)
if torch.cuda.is_available():
    loss.cuda()

seq2seq = None
optimizer = None
if not opt.resume:
    # Initialize model
    hidden_size = 512
    bidirectional = True
    encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                         n_layers=1, bidirectional=bidirectional,
                         variable_lengths=True)
def main():
    parser = argparse.ArgumentParser()
    opt = options.train_options(parser)
    opt = parser.parse_args()

    opt.cuda = torch.cuda.is_available()
    opt.device = None if opt.cuda else -1

    # Quick configuration overrides
    opt.exp_dir = './experiment/transformer-reinforce/use_billion'
    opt.load_vocab_from = './experiment/transformer/lang8-cor2err/vocab.pt'
    opt.build_vocab_from = './data/billion/billion.30m.model.vocab'
    opt.load_D_from = opt.exp_dir
    # opt.load_D_from = None

    # dataset params
    opt.max_len = 20

    # G params
    # opt.load_G_a_from = './experiment/transformer/lang8-err2cor/'
    # opt.load_G_b_from = './experiment/transformer/lang8-cor2err/'
    opt.d_word_vec = 300
    opt.d_model = 300
    opt.d_inner_hid = 600
    opt.n_head = 6
    opt.n_layers = 3
    opt.embs_share_weight = False
    opt.beam_size = 1
    opt.max_token_seq_len = opt.max_len + 2  # includes <BOS>, <EOS>
    opt.n_warmup_steps = 4000

    # D params
    opt.embed_dim = opt.d_model
    opt.num_kernel = 100
    opt.kernel_sizes = [3, 4, 5, 6, 7]
    opt.dropout_p = 0.25

    # train params
    opt.batch_size = 1
    opt.n_epoch = 10

    if not os.path.exists(opt.exp_dir):
        os.makedirs(opt.exp_dir)
    logging.basicConfig(filename=opt.exp_dir + '/.log',
                        format=LOG_FORMAT, level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info('Use CUDA? ' + str(opt.cuda))
    logging.info(opt)

    # ---------- prepare dataset ----------

    def len_filter(example):
        return len(example.src) <= opt.max_len and len(example.tgt) <= opt.max_len

    EN = SentencePieceField(init_token=Constants.BOS_WORD,
                            eos_token=Constants.EOS_WORD,
                            batch_first=True,
                            include_lengths=True)
    train = datasets.TranslationDataset(
        path='./data/dualgan/train',
        exts=('.billion.sp', '.use.sp'),
        fields=[('src', EN), ('tgt', EN)],
        filter_pred=len_filter)
    val = datasets.TranslationDataset(
        path='./data/dualgan/val',
        exts=('.billion.sp', '.use.sp'),
        fields=[('src', EN), ('tgt', EN)],
        filter_pred=len_filter)
    train_lang8, val_lang8 = Lang8.splits(
        exts=('.err.sp', '.cor.sp'),
        fields=[('src', EN), ('tgt', EN)],
        train='test', validation='test', test=None,
        filter_pred=len_filter)

    # Load the vocabulary (keep it consistent across runs)
    try:
        logging.info('Load vocab from %s' % opt.load_vocab_from)
        EN.load_vocab(opt.load_vocab_from)
    except FileNotFoundError:
        EN.build_vocab_from(opt.build_vocab_from)
        EN.save_vocab(opt.load_vocab_from)
    logging.info('Vocab len: %d' % len(EN.vocab))

    # Sanity-check the special-token constants
    assert EN.vocab.stoi[Constants.BOS_WORD] == Constants.BOS
    assert EN.vocab.stoi[Constants.EOS_WORD] == Constants.EOS
    assert EN.vocab.stoi[Constants.PAD_WORD] == Constants.PAD
    assert EN.vocab.stoi[Constants.UNK_WORD] == Constants.UNK

    # ---------- init model ----------

    # G = build_G(opt, EN, EN)
    hidden_size = 512
    bidirectional = True
    encoder = EncoderRNN(len(EN.vocab), opt.max_len, hidden_size,
                         n_layers=1, bidirectional=bidirectional)
    decoder = DecoderRNN(len(EN.vocab), opt.max_len,
                         hidden_size * 2 if bidirectional else 1,
                         n_layers=1, dropout_p=0.2, use_attention=True,
                         bidirectional=bidirectional,
                         eos_id=Constants.EOS, sos_id=Constants.BOS)
    G = Seq2seq(encoder, decoder)
    for param in G.parameters():
        param.data.uniform_(-0.08, 0.08)

    # optim_G = ScheduledOptim(optim.Adam(
    #     G.get_trainable_parameters(),
    #     betas=(0.9, 0.98), eps=1e-09),
    #     opt.d_model, opt.n_warmup_steps)
    optim_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-09)
    loss_G = NLLLoss(size_average=False)
    if torch.cuda.is_available():
        loss_G.cuda()

    # Pre-train (or load) D first
    if opt.load_D_from:
        D = load_model(opt.load_D_from)
    else:
        D = build_D(opt, EN)
    optim_D = torch.optim.Adam(D.parameters(), lr=1e-4)

    def get_criterion(vocab_size):
        ''' With PAD token zero weight '''
        weight = torch.ones(vocab_size)
        weight[Constants.PAD] = 0
        return nn.CrossEntropyLoss(weight, size_average=False)

    crit_G = get_criterion(len(EN.vocab))
    crit_D = nn.BCELoss()
    if opt.cuda:
        G.cuda()
        D.cuda()
        crit_G.cuda()
        crit_D.cuda()

    # ---------- train ----------
    trainer_D = trainers.DiscriminatorTrainer()

    if not opt.load_D_from:
        for epoch in range(1):
            logging.info('[Pretrain D Epoch %d]' % epoch)
            pool = helper.DiscriminatorDataPool(opt.max_len, D.min_len, Constants.PAD)
            # Fill the pool with training data
            train_iter = data.BucketIterator(dataset=train, batch_size=opt.batch_size,
                                             device=opt.device,
                                             sort_key=lambda x: len(x.src),
                                             repeat=False)
            pool.fill(train_iter)
            # train D
            trainer_D.train(D, train_iter=pool.batch_gen(),
                            crit=crit_D, optimizer=optim_D)
            pool.reset()
        Checkpoint(model=D, optimizer=optim_D, epoch=0, step=0,
                   input_vocab=EN.vocab, output_vocab=EN.vocab).save(opt.exp_dir)

    def eval_D():
        pool = helper.DiscriminatorDataPool(opt.max_len, D.min_len, Constants.PAD)
        val_iter = data.BucketIterator(dataset=val, batch_size=opt.batch_size,
                                       device=opt.device,
                                       sort_key=lambda x: len(x.src),
                                       repeat=False)
        pool.fill(val_iter)
        trainer_D.evaluate(D, val_iter=pool.batch_gen(), crit=crit_D)
    # eval_D()

    # Train G
    ALPHA = 0
    for epoch in range(100):
        logging.info('[Epoch %d]' % epoch)

        train_iter = data.BucketIterator(dataset=train, batch_size=1,
                                         device=opt.device,
                                         sort_within_batch=True,
                                         sort_key=lambda x: len(x.src),
                                         repeat=False)

        for step, batch in enumerate(train_iter):
            src_seq = batch.src[0]
            src_length = batch.src[1]
            tgt_seq = src_seq[0].clone()
            # gold = tgt_seq[:, 1:]

            optim_G.zero_grad()
            loss_G.reset()

            decoder_outputs, decoder_hidden, other = G.rollout(
                src_seq, None, None, n_rollout=1)
            for i, step_output in enumerate(decoder_outputs):
                batch_size = tgt_seq.size(0)
                # print(step_output)
                # loss_G.eval_batch(step_output.contiguous().view(batch_size, -1),
                #                   tgt_seq[:, i + 1])

            softmax_output = torch.exp(
                torch.cat([x for x in decoder_outputs], dim=0)).unsqueeze(0)
            softmax_output = helper.stack(softmax_output, 8)
            print(softmax_output)
            rollout = softmax_output.multinomial(1)
            print(rollout)

            tgt_seq = helper.pad_seq(tgt_seq.data,
                                     max_len=len(decoder_outputs) + 1,
                                     pad_value=Constants.PAD)
            tgt_seq = autograd.Variable(tgt_seq)
            for i, step_output in enumerate(decoder_outputs):
                batch_size = tgt_seq.size(0)
                loss_G.eval_batch(
                    step_output.contiguous().view(batch_size, -1),
                    tgt_seq[:, i + 1])

            G.zero_grad()
            loss_G.backward()
            optim_G.step()

            if step % 100 == 0:
                pred = torch.cat([x for x in other['sequence']], dim=1)
                print('[step %d] loss_rest %.4f' % (
                    epoch * len(train_iter) + step, loss_G.get_loss()))
                print('%s -> %s' % (EN.reverse(tgt_seq.data)[0],
                                    EN.reverse(pred.data)[0]))

    # Reinforce Train G
    for p in D.parameters():
        p.requires_grad = False