# Standard imports needed by the functions below; project-specific helpers
# (Encoder, Decoder, Seq2Seq, writePredict, visualizeAttn, writeLoss, crit,
# log_softmax, tokens, vocab_size, device and the hyper-parameter constants)
# are defined elsewhere in the repository.
import time

import torch
import torch.nn.functional as F
from torch.autograd import Variable


# Validation pass (legacy PyTorch style: Variable/volatile, loss.data[0]).
def valid(valid_loader, seq2seq, epoch):
    seq2seq.eval()
    total_loss_t = 0
    for num, (test_index, test_in, test_in_len, test_out) in enumerate(valid_loader):
        #test_in = test_in.unsqueeze(1)
        test_in, test_out = Variable(test_in, volatile=True).cuda(), Variable(
            test_out, volatile=True).cuda()
        output_t, attn_weights_t = seq2seq(test_in, test_out, test_in_len,
                                           teacher_rate=False, train=False)
        batch_count_n = writePredict(epoch, test_index, output_t, 'valid')
        test_label = test_out.permute(1, 0)[1:].contiguous().view(-1)  # drop <GO>
        #loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
        #                         test_label, ignore_index=tokens['PAD_TOKEN'])
        #loss_t = loss_label_smoothing(output_t.view(-1, vocab_size), test_label)
        if LABEL_SMOOTH:
            loss_t = crit(log_softmax(output_t.view(-1, vocab_size)), test_label)
        else:
            loss_t = F.cross_entropy(output_t.view(-1, vocab_size), test_label,
                                     ignore_index=tokens['PAD_TOKEN'])
        total_loss_t += loss_t.data[0]

        # dump the attention maps for one fixed validation sample for inspection
        if 'n04-015-00-01,171' in test_index:
            b = test_index.tolist().index('n04-015-00-01,171')
            visualizeAttn(test_in.data[b, 0], test_in_len[0],
                          [j[b] for j in attn_weights_t], epoch,
                          batch_count_n[b], 'valid_n04-015-00-01')

    total_loss_t /= (num + 1)
    return total_loss_t
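# The LABEL_SMOOTH branch above relies on two module-level helpers, `crit` and
# `log_softmax`, defined elsewhere in the repository.  The following is only a
# minimal sketch of how such a pair could look; the class name, smoothing value
# and reduction are assumptions, not the repository's actual definitions.
import torch.nn as nn


class SmoothedNLLLoss(nn.Module):  # hypothetical label-smoothing criterion
    def __init__(self, num_classes, smoothing=0.1, ignore_index=None):
        super().__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, log_probs, target):
        # soften the one-hot targets towards a uniform distribution
        with torch.no_grad():
            true_dist = torch.full_like(log_probs,
                                        self.smoothing / (self.num_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
            if self.ignore_index is not None:
                true_dist[target == self.ignore_index] = 0  # no loss on padding
        return torch.mean(torch.sum(-true_dist * log_probs, dim=-1))


# Possible wiring (commented out: illustrative only, not the repository's code):
# log_softmax = nn.LogSoftmax(dim=-1)
# crit = SmoothedNLLLoss(vocab_size, smoothing=0.1, ignore_index=tokens['PAD_TOKEN'])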
# Test pass for the writer/domain-adversarial variant of the model (legacy
# PyTorch style).  The seq2seq forward additionally takes the gradient-reversal
# factor `lambd` and returns domain logits.
def test(test_loader, modelID, showAttn=True):
    encoder = Encoder(HIDDEN_SIZE_ENC, HEIGHT, WIDTH, Bi_GRU, CON_STEP, FLIP).cuda()
    decoder = Decoder(HIDDEN_SIZE_DEC, EMBEDDING_SIZE, vocab_size, Attention,
                      TRADEOFF_CONTEXT_EMBED).cuda()
    seq2seq = Seq2Seq(encoder, decoder, output_max_len, vocab_size).cuda()

    # load the saved checkpoint, keeping only the weights whose names match
    # the freshly constructed model
    model_file = 'save_weights/seq2seq-' + str(modelID) + '.model'
    pretrain_dict = torch.load(model_file)
    seq2seq_dict = seq2seq.state_dict()
    pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in seq2seq_dict}
    seq2seq_dict.update(pretrain_dict)
    seq2seq.load_state_dict(seq2seq_dict)  # load
    print('Loading ' + model_file)

    seq2seq.eval()
    total_loss_t = 0
    start_t = time.time()
    for num, (test_index, test_in, test_in_len, test_out,
              test_domain) in enumerate(test_loader):
        lambd = LAMBD
        test_in, test_out = Variable(test_in, volatile=True).cuda(), Variable(
            test_out, volatile=True).cuda()
        test_domain = Variable(test_domain, volatile=True).cuda()
        output_t, attn_weights_t, out_domain_t = seq2seq(
            test_in, test_out, test_in_len, lambd,
            teacher_rate=False, train=False)
        batch_count_n = writePredict(modelID, test_index, output_t, 'test')
        test_label = test_out.permute(1, 0)[1:].contiguous().view(-1)
        if LABEL_SMOOTH:
            loss_t = crit(log_softmax(output_t.view(-1, vocab_size)), test_label)
        else:
            loss_t = F.cross_entropy(output_t.view(-1, vocab_size), test_label,
                                     ignore_index=tokens['PAD_TOKEN'])
        total_loss_t += loss_t.data[0]

        if showAttn:
            # write the attention maps for every sample in the batch
            global_index_t = 0
            for t_idx, t_in in zip(test_index, test_in):
                visualizeAttn(t_in.data[0], test_in_len[0],
                              [j[global_index_t] for j in attn_weights_t],
                              modelID, batch_count_n[global_index_t],
                              'test_' + t_idx.split(',')[0])
                global_index_t += 1

    total_loss_t /= (num + 1)
    writeLoss(total_loss_t, 'test')
    print(' TEST loss=%.3f, time=%.3f' % (total_loss_t, time.time() - start_t))
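# The `lambd` value handed to the seq2seq forward above is the scaling factor
# of a gradient-reversal layer (GRL) feeding the domain classifier inside the
# model; at test time it only affects the (unused) backward pass.  The GRL
# itself lives in the model code; a standard minimal construction looks like
# the sketch below (class and function names are illustrative, not the
# repository's).
class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)  # identity in the forward direction

    @staticmethod
    def backward(ctx, grad_output):
        # flip and scale the gradient so the encoder learns features that
        # confuse the domain classifier
        return -ctx.lambd * grad_output, None


def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)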
# Test pass for the plain (non-adversarial) model, written against the modern
# PyTorch API: .to(device), torch.no_grad(), loss.item().
def test(test_loader, modelID, showAttn=True):
    encoder = Encoder(HIDDEN_SIZE_ENC, HEIGHT, WIDTH, Bi_GRU, CON_STEP, FLIP).to(device)
    decoder = Decoder(HIDDEN_SIZE_DEC, EMBEDDING_SIZE, vocab_size, Attention,
                      TRADEOFF_CONTEXT_EMBED).to(device)
    seq2seq = Seq2Seq(encoder, decoder, output_max_len, vocab_size).to(device)
    model_file = 'save_weights/seq2seq-' + str(modelID) + '.model'
    print('Loading ' + model_file)
    seq2seq.load_state_dict(torch.load(model_file))  # load

    seq2seq.eval()
    total_loss_t = 0
    start_t = time.time()
    with torch.no_grad():
        for num, (test_index, test_in, test_in_len, test_out) in enumerate(test_loader):
            #test_in = test_in.unsqueeze(1)
            test_in, test_out = test_in.to(device), test_out.to(device)
            if test_in.requires_grad or test_out.requires_grad:
                print('ERROR! test_in, test_out should have requires_grad=False')
            output_t, attn_weights_t = seq2seq(test_in, test_out, test_in_len,
                                               teacher_rate=False, train=False)
            batch_count_n = writePredict(modelID, test_index, output_t, 'test')
            test_label = test_out.permute(1, 0)[1:].reshape(-1)
            #loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
            #                         test_label, ignore_index=tokens['PAD_TOKEN'])
            #loss_t = loss_label_smoothing(output_t.view(-1, vocab_size), test_label)
            if LABEL_SMOOTH:
                loss_t = crit(log_softmax(output_t.reshape(-1, vocab_size)), test_label)
            else:
                loss_t = F.cross_entropy(output_t.reshape(-1, vocab_size), test_label,
                                         ignore_index=tokens['PAD_TOKEN'])
            total_loss_t += loss_t.item()

            if showAttn:
                global_index_t = 0
                for t_idx, t_in in zip(test_index, test_in):
                    visualizeAttn(t_in.detach()[0], test_in_len[0],
                                  [j[global_index_t] for j in attn_weights_t],
                                  modelID, batch_count_n[global_index_t],
                                  'test_' + t_idx.split(',')[0])
                    global_index_t += 1

    total_loss_t /= (num + 1)
    writeLoss(total_loss_t, 'test')
    print(' TEST loss=%.3f, time=%.3f' % (total_loss_t, time.time() - start_t))
# Training pass for the writer/domain-adversarial variant: the recognition
# loss is combined with an ALPHA-weighted domain-classification loss computed
# on the gradient-reversed features.
def train(train_loader, seq2seq, opt, teacher_rate, epoch, lambd):
    seq2seq.train()
    total_loss = 0
    total_loss_d = 0
    for num, (train_index, train_in, train_in_len, train_out,
              train_domain) in enumerate(train_loader):
        train_in, train_out = Variable(train_in).cuda(), Variable(train_out).cuda()
        train_domain = Variable(train_domain).cuda()
        output, attn_weights, out_domain = seq2seq(
            train_in, train_out, train_in_len, lambd,
            teacher_rate=teacher_rate, train=True)  # (100-1, 32, 62+1)
        batch_count_n = writePredict(epoch, train_index, output, 'train')
        train_label = train_out.permute(1, 0)[1:].contiguous().view(-1)  # remove <GO>
        output_l = output.view(-1, vocab_size)  # remove last <EOS>

        if VISUALIZE_TRAIN:
            # dump the attention maps for one fixed training sample
            if 'e02-074-03-00,191' in train_index:
                b = train_index.tolist().index('e02-074-03-00,191')
                visualizeAttn(train_in.data[b, 0], train_in_len[0],
                              [j[b] for j in attn_weights], epoch,
                              batch_count_n[b], 'train_e02-074-03-00')

        if LABEL_SMOOTH:
            loss = crit(log_softmax(output_l.view(-1, vocab_size)), train_label)
        else:
            loss = F.cross_entropy(output_l.view(-1, vocab_size), train_label,
                                   ignore_index=tokens['PAD_TOKEN'])
        loss2 = F.cross_entropy(out_domain, train_domain)
        loss2 = ALPHA * loss2  # weight the domain loss
        loss_total = loss + loss2

        opt.zero_grad()
        loss_total.backward()
        opt.step()
        total_loss += loss.data[0]
        total_loss_d += loss2.data[0]

    total_loss /= (num + 1)
    total_loss_d /= (num + 1)
    return total_loss, total_loss_d
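# `lambd` is supplied by the caller of train().  In DANN-style adversarial
# training it is commonly ramped up from 0 towards 1 rather than held constant;
# whether this repository schedules it that way is not shown here.  A sketch of
# the schedule proposed by Ganin & Lempitsky (2015), as an illustrative helper:
import math


def dann_lambda(progress):
    """progress: fraction of training completed, in [0, 1] (illustrative only)."""
    return 2.0 / (1.0 + math.exp(-10.0 * progress)) - 1.0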
# Training pass for the plain model.
def train(train_loader, seq2seq, opt, teacher_rate, epoch):
    seq2seq.train()
    total_loss = 0
    for num, (train_index, train_in, train_in_len, train_out) in enumerate(train_loader):
        #train_in = train_in.unsqueeze(1)
        train_in, train_out = Variable(train_in).cuda(), Variable(train_out).cuda()
        output, attn_weights = seq2seq(train_in, train_out, train_in_len,
                                       teacher_rate=teacher_rate,
                                       train=True)  # (100-1, 32, 62+1)
        batch_count_n = writePredict(epoch, train_index, output, 'train')
        train_label = train_out.permute(1, 0)[1:].contiguous().view(-1)  # remove <GO>
        output_l = output.view(-1, vocab_size)  # remove last <EOS>

        if VISUALIZE_TRAIN:
            if 'e02-074-03-00,191' in train_index:
                b = train_index.tolist().index('e02-074-03-00,191')
                visualizeAttn(train_in.data[b, 0], train_in_len[0],
                              [j[b] for j in attn_weights], epoch,
                              batch_count_n[b], 'train_e02-074-03-00')

        #loss = F.cross_entropy(output_l.view(-1, vocab_size),
        #                       train_label, ignore_index=tokens['PAD_TOKEN'])
        #loss = loss_label_smoothing(output_l.view(-1, vocab_size), train_label)
        if LABEL_SMOOTH:
            loss = crit(log_softmax(output_l.view(-1, vocab_size)), train_label)
        else:
            loss = F.cross_entropy(output_l.view(-1, vocab_size), train_label,
                                   ignore_index=tokens['PAD_TOKEN'])

        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()
        print(f'Batch {num} loss: {loss.item()}')

    total_loss /= (num + 1)
    return total_loss
# Validation pass for the plain model, modern PyTorch API.
def valid(valid_loader, seq2seq, epoch):
    seq2seq.eval()
    total_loss_t = 0
    with torch.no_grad():
        for num, (test_index, test_in, test_in_len, test_out) in enumerate(valid_loader):
            #test_in = test_in.unsqueeze(1)
            test_in, test_out = test_in.to(device), test_out.to(device)
            if test_in.requires_grad or test_out.requires_grad:
                print('ERROR! test_in, test_out should have requires_grad=False')
            output_t, attn_weights_t = seq2seq(test_in, test_out, test_in_len,
                                               teacher_rate=False, train=False)
            batch_count_n = writePredict(epoch, test_index, output_t, 'valid')
            test_label = test_out.permute(1, 0)[1:].reshape(-1)
            #loss_t = F.cross_entropy(output_t.view(-1, vocab_size),
            #                         test_label, ignore_index=tokens['PAD_TOKEN'])
            #loss_t = loss_label_smoothing(output_t.view(-1, vocab_size), test_label)
            if LABEL_SMOOTH:
                loss_t = crit(log_softmax(output_t.reshape(-1, vocab_size)), test_label)
            else:
                loss_t = F.cross_entropy(output_t.reshape(-1, vocab_size), test_label,
                                         ignore_index=tokens['PAD_TOKEN'])
            total_loss_t += loss_t.item()

            if 'n04-015-00-01,171' in test_index:
                b = test_index.tolist().index('n04-015-00-01,171')
                visualizeAttn(test_in.detach()[b, 0], test_in_len[0],
                              [j[b] for j in attn_weights_t], epoch,
                              batch_count_n[b], 'valid_n04-015-00-01')

    total_loss_t /= (num + 1)
    return total_loss_t
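# A minimal sketch of how the plain train()/valid() pair above might be driven
# across epochs.  Everything here other than those two functions and the model
# constructor arguments (the optimizer, learning rate, epoch count, teacher-rate
# decay, writeLoss tags, and the train_loader/valid_loader names) is an
# assumption for illustration, not the repository's actual driver script.
if __name__ == '__main__':
    encoder = Encoder(HIDDEN_SIZE_ENC, HEIGHT, WIDTH, Bi_GRU, CON_STEP, FLIP).to(device)
    decoder = Decoder(HIDDEN_SIZE_DEC, EMBEDDING_SIZE, vocab_size, Attention,
                      TRADEOFF_CONTEXT_EMBED).to(device)
    seq2seq = Seq2Seq(encoder, decoder, output_max_len, vocab_size).to(device)

    opt = torch.optim.Adam(seq2seq.parameters(), lr=1e-3)  # assumed optimizer
    teacher_rate = 1.0
    for epoch in range(100):
        train_loss = train(train_loader, seq2seq, opt, teacher_rate, epoch)
        valid_loss = valid(valid_loader, seq2seq, epoch)
        writeLoss(train_loss, 'train')
        writeLoss(valid_loss, 'valid')
        teacher_rate *= 0.95  # gradually reduce teacher forcing (assumed schedule)
        # checkpoint path mirrors the one test() loads from
        torch.save(seq2seq.state_dict(), 'save_weights/seq2seq-%d.model' % epoch)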