import torch
import torch.nn.functional as F
from tqdm import tqdm

# LossChecker, parse_batch, and the losses module used below are project-local helpers
# (running-loss bookkeeping, batch unpacking, and custom loss terms).


def test(model, val_iter, vocab, reg_lambda, recon_lambda):
    # Validation pass for the caption decoder + reconstructor: accumulates the
    # cross-entropy, entropy-regularization, and reconstruction losses.
    model.eval()

    loss_checker = LossChecker(4)
    PAD_idx = vocab.word2idx['<PAD>']
    for b, batch in enumerate(val_iter, 1):
        _, feats, captions = parse_batch(batch)
        output, feats_recon = model(feats)

        # Drop the first (start-token) timestep and flatten (seq_len, batch, vocab) for NLL loss.
        cross_entropy_loss = F.nll_loss(output[1:].view(-1, vocab.n_vocabs),
                                        captions[1:].contiguous().view(-1),
                                        ignore_index=PAD_idx)
        entropy_loss = losses.entropy_loss(output[1:],
                                           ignore_mask=(captions[1:] == PAD_idx))
        if model.reconstructor is None:
            reconstruction_loss = torch.zeros(1)
        elif model.reconstructor._type == 'global':
            reconstruction_loss = losses.global_reconstruction_loss(
                feats, feats_recon, keep_mask=(captions != PAD_idx))
        else:
            reconstruction_loss = losses.local_reconstruction_loss(feats, feats_recon)
        loss = cross_entropy_loss + reg_lambda * entropy_loss + recon_lambda * reconstruction_loss
        loss_checker.update(loss.item(), cross_entropy_loss.item(),
                            entropy_loss.item(), reconstruction_loss.item())

    total_loss, cross_entropy_loss, entropy_loss, reconstruction_loss = loss_checker.mean()
    loss = {
        'total': total_loss,
        'cross_entropy': cross_entropy_loss,
        'entropy': entropy_loss,
        'reconstruction': reconstruction_loss,
    }
    return loss
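# --- Hedged sketch, not the project's actual helper: a minimal LossChecker with the
# behavior the functions in this file rely on (per-term loss histories, full or
# windowed running means). The real implementation may differ.
class LossChecker:
    def __init__(self, num_losses):
        # One history list per tracked loss term (e.g. total, CE, entropy, reconstruction).
        self.losses = [[] for _ in range(num_losses)]

    def update(self, *loss_vals):
        # Append the latest scalar value of each loss term.
        for i, loss_val in enumerate(loss_vals):
            self.losses[i].append(loss_val)

    def mean(self, last=0):
        # Mean over the full history, or over the most recent `last` updates when last > 0.
        mean_losses = []
        for loss in self.losses:
            window = loss[-last:] if last > 0 else loss
            mean_losses.append(sum(window) / max(len(window), 1))
        return mean_losses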
def train(e, model, optimizer, train_iter, vocab, teacher_forcing_ratio, reg_lambda, gradient_clip):
    model.train()

    loss_checker = LossChecker(3)
    PAD_idx = vocab.word2idx['<PAD>']
    for b, batch in enumerate(train_iter, 1):
        _, feats, captions = parse_batch(batch)
        optimizer.zero_grad()
        output = model(feats, captions, teacher_forcing_ratio)
        cross_entropy_loss = F.nll_loss(output[1:].view(-1, vocab.n_vocabs),
                                        captions[1:].contiguous().view(-1),
                                        ignore_index=PAD_idx)
        entropy_loss = reg_lambda * losses.entropy_loss(output[1:],
                                                        ignore_mask=(captions[1:] == PAD_idx))
        loss = cross_entropy_loss + entropy_loss
        loss.backward()
        if gradient_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()

        loss_checker.update(loss.item(), cross_entropy_loss.item(), entropy_loss.item())
        if len(train_iter) < 10 or b % (len(train_iter) // 10) == 0:
            inter_loss, inter_cross_entropy_loss, inter_entropy_loss = loss_checker.mean(last=10)
            print("\t[{:d}/{:d}] loss: {:.4f} (CE {:.4f} + E {:.4f})".format(
                b, len(train_iter), inter_loss, inter_cross_entropy_loss, inter_entropy_loss))

    total_loss, cross_entropy_loss, entropy_loss = loss_checker.mean()
    loss = {
        'total': total_loss,
        'cross_entropy': cross_entropy_loss,
        'entropy': entropy_loss,
    }
    return loss
def train(e, model, optimizer, train_iter, vocab, teacher_forcing_ratio, reg_lambda, gradient_clip):
    model.train()

    loss_checker = LossChecker(3)
    PAD_idx = vocab.word2idx['<PAD>']
    t = tqdm(train_iter)
    for batch in t:
        _, feats, captions = parse_batch(batch)
        optimizer.zero_grad()
        output = model(feats, captions, teacher_forcing_ratio)
        cross_entropy_loss = F.nll_loss(output[1:].view(-1, vocab.n_vocabs),
                                        captions[1:].contiguous().view(-1),
                                        ignore_index=PAD_idx)
        entropy_loss = losses.entropy_loss(output[1:],
                                           ignore_mask=(captions[1:] == PAD_idx))
        loss = cross_entropy_loss + reg_lambda * entropy_loss
        loss.backward()
        if gradient_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()

        loss_checker.update(loss.item(), cross_entropy_loss.item(), entropy_loss.item())
        t.set_description(
            "[Epoch #{0}] loss: {2:.3f} = (CE: {3:.3f}) + (Ent: {1} * {4:.3f})".format(
                e, reg_lambda, *loss_checker.mean(last=10)))

    total_loss, cross_entropy_loss, entropy_loss = loss_checker.mean()
    loss = {
        'total': total_loss,
        'cross_entropy': cross_entropy_loss,
        'entropy': entropy_loss,
    }
    return loss
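# --- Hedged sketch, not the project's actual helper: parse_batch as assumed by the
# loops above. The batch layout (video ids, features, captions) and the device
# handling are assumptions inferred from how the returned values are used; the
# project's own helper may unpack more fields or handle lists of feature tensors.
def parse_batch(batch, device='cuda' if torch.cuda.is_available() else 'cpu'):
    vids, feats, captions = batch
    feats = feats.to(device)        # video feature tensor -> GPU when available
    captions = captions.to(device)  # caption index tensor -> GPU when available
    return vids, feats, captions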
def train(e, model, optimizer, train_iter, vocab, teacher_forcing_ratio, reg_lambda, recon_lambda,
          gradient_clip):
    model.train()

    loss_checker = LossChecker(4)
    PAD_idx = vocab.word2idx['<PAD>']
    # tqdm progress bar over the training batches.
    t = tqdm(train_iter)
    for batch in t:
        # Move the batch tensors onto the GPU.
        _, feats, captions = parse_batch(batch)
        optimizer.zero_grad()
        # Decoder output plus the features rebuilt by the reconstructor.
        output, feats_recon = model(feats, captions, teacher_forcing_ratio)

        # NLLLoss expects log-probabilities and target indices; it does not take the log
        # itself, so the network's final layer should be log_softmax. Together they give
        # the cross-entropy loss over the caption tokens.
        cross_entropy_loss = F.nll_loss(output[1:].view(-1, vocab.n_vocabs),
                                        captions[1:].contiguous().view(-1),
                                        ignore_index=PAD_idx)
        entropy_loss = losses.entropy_loss(output[1:],
                                           ignore_mask=(captions[1:] == PAD_idx))
        # reg_lambda = 0 disables the entropy regularizer.
        loss = cross_entropy_loss + reg_lambda * entropy_loss
        if model.reconstructor is None:
            reconstruction_loss = torch.zeros(1)
        else:
            if model.reconstructor._type == 'global':
                reconstruction_loss = losses.global_reconstruction_loss(
                    feats, feats_recon, keep_mask=(captions != PAD_idx))
            else:
                reconstruction_loss = losses.local_reconstruction_loss(feats, feats_recon)
            loss += recon_lambda * reconstruction_loss
        loss.backward()
        if gradient_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()

        loss_checker.update(loss.item(), cross_entropy_loss.item(),
                            entropy_loss.item(), reconstruction_loss.item())
        # .3f keeps three decimal places in the progress-bar description.
        t.set_description(
            "[Epoch #{0}] loss: {3:.3f} = (CE: {4:.3f}) + (Ent: {1} * {5:.3f}) + (Rec: {2} * {6:.3f})"
            .format(e, reg_lambda, recon_lambda, *loss_checker.mean(last=10)))

    total_loss, cross_entropy_loss, entropy_loss, reconstruction_loss = loss_checker.mean()
    loss = {
        'total': total_loss,
        'cross_entropy': cross_entropy_loss,
        'entropy': entropy_loss,
        'reconstruction': reconstruction_loss,
    }
    return loss
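# --- Hedged sketch, not the project's losses module: one plausible reading of
# losses.entropy_loss as called above. It assumes `log_probs` holds log-probabilities
# of shape (seq_len, batch, vocab) from a log_softmax layer and that `ignore_mask`
# (seq_len, batch) marks <PAD> positions; the entropy of each remaining prediction
# is averaged. The actual helper may be defined differently.
def entropy_loss(log_probs, ignore_mask):
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1)       # per-position entropy
    entropy = entropy.masked_fill(ignore_mask, 0.)   # zero out padded positions
    n_valid = (~ignore_mask).sum().clamp(min=1)      # number of non-pad positions
    return entropy.sum() / n_valid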
def evaluate(model, val_iter, vocab, reg_lambda):
    model.eval()

    loss_checker = LossChecker(3)
    PAD_idx = vocab.word2idx['<PAD>']
    for b, batch in enumerate(val_iter, 1):
        _, feats, captions = parse_batch(batch)
        output = model(feats, captions)
        cross_entropy_loss = F.nll_loss(output[1:].view(-1, vocab.n_vocabs),
                                        captions[1:].contiguous().view(-1),
                                        ignore_index=PAD_idx)
        entropy_loss = reg_lambda * losses.entropy_loss(output[1:],
                                                        ignore_mask=(captions[1:] == PAD_idx))
        loss = cross_entropy_loss + entropy_loss
        loss_checker.update(loss.item(), cross_entropy_loss.item(), entropy_loss.item())

    total_loss, cross_entropy_loss, entropy_loss = loss_checker.mean()
    loss = {
        'total': total_loss,
        'cross_entropy': cross_entropy_loss,
        'entropy': entropy_loss,
    }
    return loss
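# --- Illustrative driver loop, not part of the original code: shows how the
# reconstructor variant of train() and test() above are typically wired together per
# epoch. The `args` fields (epochs, teacher_forcing_ratio, reg_lambda, recon_lambda,
# gradient_clip) and the checkpoint path are hypothetical names for this sketch.
def run(args, model, optimizer, train_iter, val_iter, vocab):
    best_val_loss = float('inf')
    for e in range(1, args.epochs + 1):
        train_loss = train(e, model, optimizer, train_iter, vocab,
                           args.teacher_forcing_ratio, args.reg_lambda,
                           args.recon_lambda, args.gradient_clip)
        with torch.no_grad():
            val_loss = test(model, val_iter, vocab, args.reg_lambda, args.recon_lambda)
        print("Epoch {}: train {:.4f} / val {:.4f}".format(
            e, train_loss['total'], val_loss['total']))
        # Keep the checkpoint with the lowest validation loss.
        if val_loss['total'] < best_val_loss:
            best_val_loss = val_loss['total']
            torch.save(model.state_dict(), "best_model.ckpt")  # hypothetical checkpoint path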