def train(my_lang, criterion, teacher_forcing_ratio,
          training_data, encoder, decoder,
          encoder_optimizer, decoder_optimizer, max_length):
    """Run one pass over ``training_data`` and update the seq2seq model.

    ``training_data`` is a list of indexed sentences (each a sequence of
    word-index Variables); sentence ``i`` is encoded and the model is
    trained to reproduce sentence ``i + 1``.

    Returns the average loss per predicted token (a plain float).
    """
    total_loss = 0
    predict_num = 0
    # Training mode (enables dropout etc.)
    encoder.train()
    decoder.train()
    for index, sentence in enumerate(training_data):
        # The last sentence has no follow-up target, so it is skipped.
        if index == len(training_data) - 1:
            break
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        loss = 0
        encoder_hidden = encoder.init_hidden()
        # Fixed-size buffer of encoder outputs for the attention decoder.
        # NOTE(review): assumes len(sentence) <= max_length — confirm upstream.
        encoder_outputs = Variable(torch.zeros(max_length, encoder.hidden_size))
        decoder_input = Variable(torch.LongTensor([[my_lang.word2index["SOS"]]]))
        encoder_outputs = check_cuda_for_var(encoder_outputs)
        decoder_input = check_cuda_for_var(decoder_input)
        for ei in range(len(sentence)):
            encoder_output, encoder_hidden = encoder(sentence[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0][0]
        decoder_hidden = encoder_hidden
        next_sentence = training_data[index + 1]
        if random.random() < teacher_forcing_ratio:
            # Teacher forcing: feed the gold token as the next decoder input.
            for di in range(len(next_sentence)):
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                loss += criterion(decoder_output[0], next_sentence[di])
                predict_num += 1
                decoder_input = next_sentence[di]
        else:
            # Free running: feed the model's own greedy prediction back in.
            for di in range(len(next_sentence)):
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                loss += criterion(decoder_output[0], next_sentence[di])
                predict_num += 1
                topv, topi = decoder_output.data.topk(1)
                ni = topi[0][0]
                decoder_input = check_cuda_for_var(
                    Variable(torch.LongTensor([[ni]])))
        if isinstance(loss, int):
            # Empty target sentence: no prediction was made, so there is
            # nothing to backprop (the old code crashed on 0.backward()).
            continue
        # BUGFIX: accumulate the detached scalar instead of the Variable.
        # Summing loss Variables across iterations kept every iteration's
        # autograd graph alive for the whole epoch (unbounded memory growth).
        total_loss += loss.data[0]
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()
    return total_loss / predict_num
def sample(my_lang, dialog, encoder, decoder, max_length):
    """Print a dialog's gold sentences, then the model's greedy continuation.

    On the first turn the gold sentence is encoded; on every later turn
    the previously generated reply is encoded instead, so errors
    propagate through the simulated conversation.  The decoder free-runs
    for as many steps as the gold next sentence has tokens.
    """
    # Eval mode (disables dropout etc.)
    encoder.eval()
    decoder.eval()
    print("Golden ->")
    for sentence in dialog:
        print(' '.join(my_lang.index2word[word.data[0]] for word in sentence))
    print("Predict ->")
    gen_sentence = []
    last_turn = len(dialog) - 1
    for index, sentence in enumerate(dialog):
        if index == last_turn:
            break
        encoder_hidden = encoder.init_hidden()
        encoder_outputs = check_cuda_for_var(
            Variable(torch.zeros(max_length, encoder.hidden_size)))
        decoder_input = check_cuda_for_var(
            Variable(torch.LongTensor([[my_lang.word2index["SOS"]]])))
        # Encode the previous model reply when there is one; otherwise
        # (first turn only) fall back to the gold sentence.
        source = gen_sentence if len(gen_sentence) > 0 else sentence
        for ei in range(len(source)):
            encoder_output, encoder_hidden = encoder(source[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0][0]
        gen_sentence = []
        decoder_hidden = encoder_hidden
        next_sentence = dialog[index + 1]
        for _ in range(len(next_sentence)):
            # Record the token being fed in, then decode one step greedily.
            gen_sentence.append(decoder_input.data[0][0])
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.data.topk(1)
            decoder_input = check_cuda_for_var(
                Variable(torch.LongTensor([[topi[0][0]]])))
        # Terminate with EOS and wrap as a Variable for the next encoding pass.
        gen_sentence.append(my_lang.word2index["EOS"])
        gen_sentence = check_cuda_for_var(Variable(torch.LongTensor(gen_sentence)))
        print(' '.join(my_lang.index2word[word.data[0]] for word in gen_sentence))
def validate(my_lang, criterion, validation_data, encoder, decoder, max_length):
    """Compute the average per-token loss of the seq2seq model.

    ``validation_data`` is a list of dialogs, each a list of indexed
    sentences; decoding is greedy (no teacher forcing).  Also prints a
    qualitative sample generation for the last dialog.
    """
    total_loss = 0
    predict_num = 0
    # Eval mode
    encoder.eval()
    decoder.eval()
    for counter, dialog in enumerate(validation_data):
        if counter == len(validation_data) - 1:
            # Show a qualitative sample on the last validation dialog.
            sample(my_lang, dialog, encoder, decoder, max_length)
        for index, sentence in enumerate(dialog):
            # The last sentence of a dialog has no target; skip it.
            if index == len(dialog) - 1:
                break
            loss = 0
            encoder_hidden = encoder.init_hidden()
            encoder_outputs = Variable(
                torch.zeros(max_length, encoder.hidden_size))
            decoder_input = Variable(
                torch.LongTensor([[my_lang.word2index["SOS"]]]))
            encoder_outputs = check_cuda_for_var(encoder_outputs)
            decoder_input = check_cuda_for_var(decoder_input)
            for ei in range(len(sentence)):
                encoder_output, encoder_hidden = encoder(
                    sentence[ei], encoder_hidden)
                encoder_outputs[ei] = encoder_output[0][0]
            decoder_hidden = encoder_hidden
            next_sentence = dialog[index + 1]
            for di in range(len(next_sentence)):
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                loss += criterion(decoder_output[0], next_sentence[di])
                predict_num += 1
                topv, topi = decoder_output.data.topk(1)
                ni = topi[0][0]
                decoder_input = check_cuda_for_var(
                    Variable(torch.LongTensor([[ni]])))
            # BUGFIX: ``loss`` starts as the *int* 0, so the old
            # ``isinstance(loss, float)`` guard missed the empty-target
            # case and crashed on ``0 .data``.  Accept both plain numbers.
            if isinstance(loss, (int, float)):
                total_loss += loss
            else:
                total_loss += loss.data[0]
    return total_loss / predict_num
def gen(sentence):
    """Greedily generate up to 10 conversation turns.

    The first turn is seeded from ``sentence``; each later turn re-encodes
    the previous reply.  Relies on module-level globals ``encoder``,
    ``decoder``, ``my_lang``, ``DEBUG`` and ``args``.  Returns the list of
    generated sentence strings.
    """
    max_length = 20
    encoder.eval()
    decoder.eval()
    talking_history = []
    gen_sentence = []
    eos_index = my_lang.word2index["EOS"]
    for _ in range(10):
        encoder_hidden = encoder.init_hidden()
        encoder_outputs = check_cuda_for_var(
            Variable(torch.zeros(max_length, encoder.hidden_size)))
        decoder_input = check_cuda_for_var(
            Variable(torch.LongTensor([[my_lang.word2index["SOS"]]])))
        # Encode the last generated reply if there is one, otherwise the seed.
        source = gen_sentence if len(gen_sentence) > 0 else sentence
        for ei in range(len(source)):
            encoder_output, encoder_hidden = encoder(source[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0][0]
        gen_sentence = []
        decoder_hidden = encoder_hidden
        # Greedy decode until EOS or the length cap.
        while True:
            if DEBUG:
                print("[Debug] ", decoder_input.data)
            gen_sentence.append(decoder_input.data[0][0])
            done = (gen_sentence[-1] == eos_index
                    or len(gen_sentence) >= max_length - 1)
            if done:
                break
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            _, topi = decoder_output.data.topk(1)
            decoder_input = check_cuda_for_var(
                Variable(torch.LongTensor([[topi[0][0]]])))
        gen_sentence = check_cuda_for_var(Variable(torch.LongTensor(gen_sentence)))
        string = ' '.join(my_lang.index2word[word.data[0]]
                          for word in gen_sentence)
        print(string)
        talking_history.append(string)
        # Stop at end-of-dialog, or after one turn in step-by-step mode.
        if "EOD" in string or args.sbs:
            break
    return talking_history
def sample(my_lang, dialog, encoder, context, decoder, print_golden=True):
    """Print (optionally) the gold dialog, then a greedy HRED continuation.

    The hierarchical model encodes each turn, feeds the final encoder
    hidden state through the context RNN, and decodes the next turn
    conditioned on the context hidden state.
    """
    # Eval mode
    encoder.eval()
    context.eval()
    decoder.eval()
    if print_golden:
        print("Golden ->")
        for sentence in dialog:
            string = ' '.join(
                [my_lang.index2word[word.data[0]] for word in sentence])
            print(string)
    print("Predict ->")
    gen_sentence = []
    context_hidden = context.init_hidden()
    for index, sentence in enumerate(dialog):
        # The last sentence has no follow-up to predict; skip it.
        if index == len(dialog) - 1:
            break
        decoder_input = Variable(
            torch.LongTensor([[my_lang.word2index["SOS"]]]))
        decoder_input = check_cuda_for_var(decoder_input)
        encoder_hidden = encoder.init_hidden()
        decoder_hidden = decoder.init_hidden()
        if len(gen_sentence) > 0:
            # Encode the model's previous reply so errors propagate
            # through the conversation instead of resetting each turn.
            for ei in range(len(gen_sentence)):
                _, encoder_hidden = encoder(gen_sentence[ei], encoder_hidden)
            # Clean generated sentence list
            gen_sentence = []
        else:
            for ei in range(len(sentence)):
                _, encoder_hidden = encoder(sentence[ei], encoder_hidden)
        context_output, context_hidden = context(encoder_hidden, context_hidden)
        next_sentence = dialog[index + 1]
        for di in range(len(next_sentence)):
            gen_sentence.append(decoder_input.data[0][0])
            decoder_output, decoder_hidden = decoder(
                context_hidden, decoder_input, decoder_hidden)
            _, topi = decoder_output.data.topk(1)
            ni = topi[0][0]
            # CONSISTENCY: use the project-wide check_cuda_for_var helper
            # (already used above) instead of an inline cuda() check.
            decoder_input = check_cuda_for_var(
                Variable(torch.LongTensor([[ni]])))
        # Make gen_sentence concated with a EOS and make it torch Variable
        gen_sentence.append(my_lang.word2index["EOS"])
        gen_sentence = check_cuda_for_var(
            Variable(torch.LongTensor(gen_sentence)))
        string = ' '.join(
            [my_lang.index2word[word.data[0]] for word in gen_sentence])
        print(string)
def train(my_lang, criterion, teacher_forcing_ratio,
          training_data, encoder, context, decoder,
          encoder_optimizer, context_optimizer, decoder_optimizer):
    """One training step of the HRED model over a single dialog.

    ``training_data`` is a list of indexed sentences; each sentence is
    encoded, summarized by the context RNN, and the decoder is trained to
    produce the following sentence.  A single backward pass covers the
    whole dialog.  Returns the average loss per predicted token.
    """
    # Training mode
    encoder.train()
    context.train()
    decoder.train()
    # Zero gradients
    encoder_optimizer.zero_grad()
    context_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    # IMPROVED: torch.zeros(1) replaces the uninitialized FloatTensor(1)
    # plus the deprecated nn.init.constant call — same zero-initialized
    # accumulator, one step.
    loss = check_cuda_for_var(Variable(torch.zeros(1)))
    context_hidden = context.init_hidden()
    predict_count = 0
    model_predict = []
    for index, sentence in enumerate(training_data):
        # The last sentence has no target; skip it.
        if index == len(training_data) - 1:
            break
        decoder_input = check_cuda_for_var(
            Variable(torch.LongTensor([[my_lang.word2index["SOS"]]])))
        encoder_hidden = encoder.init_hidden()
        decoder_hidden = decoder.init_hidden()
        # Scheduled sampling on the *encoder* input: model_predict holds the
        # model's predictions for this sentence from the previous iteration;
        # with prob. teacher_forcing_ratio (or when no prediction exists)
        # the gold token is encoded instead.
        for ei in range(len(sentence)):
            if ei > len(model_predict) - 1 or random.random(
            ) < teacher_forcing_ratio:
                _, encoder_hidden = encoder(sentence[ei], encoder_hidden)
            else:
                _, encoder_hidden = encoder(model_predict[ei], encoder_hidden)
        context_output, context_hidden = context(encoder_hidden, context_hidden)
        next_sentence = training_data[index + 1]
        model_predict = []
        # One teacher-forcing decision per target sentence.
        teacher_forcing = random.random() < teacher_forcing_ratio
        for di in range(len(next_sentence)):
            predict_count += 1
            decoder_output, decoder_hidden = decoder(
                context_hidden, decoder_input, decoder_hidden)
            loss += criterion(decoder_output[0], next_sentence[di])
            # Record the greedy prediction for next iteration's sampling.
            _, topi = decoder_output.data.topk(1)
            ni = topi[0][0]
            # CONSISTENCY: check_cuda_for_var instead of inline .cuda().
            ni_var = check_cuda_for_var(Variable(torch.LongTensor([[ni]])))
            model_predict.append(ni_var)
            if teacher_forcing:
                decoder_input = next_sentence[di].unsqueeze(1)
            else:
                decoder_input = ni_var
    # NOTE(review): a dialog with fewer than two sentences leaves
    # predict_count == 0 and divides by zero here — confirm callers
    # filter such dialogs.
    loss.backward()
    encoder_optimizer.step()
    context_optimizer.step()
    decoder_optimizer.step()
    return loss.data[0] / predict_count
def validate(my_lang, criterion, teacher_forcing_ratio,
             validation_data, encoder, context, decoder,
             encoder_optimizer, context_optimizer, decoder_optimizer):
    """Average per-dialog validation loss of the HRED model.

    Decodes greedily (no teacher forcing); ``teacher_forcing_ratio`` and
    the optimizer arguments are unused but kept for signature parity with
    ``train``.  Returns the mean of each dialog's per-token loss.
    """
    validation_loss = 0
    # Eval mode
    encoder.eval()
    context.eval()
    decoder.eval()
    for dialog in validation_data:
        context_hidden = context.init_hidden()
        predict_count = 0
        loss = 0
        gen_sentence = []
        for index, sentence in enumerate(dialog):
            if index == len(dialog) - 1:
                break
            decoder_input = check_cuda_for_var(
                Variable(torch.LongTensor([[my_lang.word2index["SOS"]]])))
            encoder_hidden = encoder.init_hidden()
            decoder_hidden = decoder.init_hidden()
            if len(gen_sentence) > 0:
                # Encode the model's previous reply, mirroring inference.
                for ei in range(len(gen_sentence)):
                    _, encoder_hidden = encoder(gen_sentence[ei], encoder_hidden)
                # Clean generated sentence list
                gen_sentence = []
            else:
                for ei in range(len(sentence)):
                    _, encoder_hidden = encoder(sentence[ei], encoder_hidden)
            context_output, context_hidden = context(
                encoder_hidden, context_hidden)
            next_sentence = dialog[index + 1]
            for di in range(len(next_sentence)):
                predict_count += 1
                gen_sentence.append(decoder_input.data[0][0])
                decoder_output, decoder_hidden = decoder(
                    context_hidden, decoder_input, decoder_hidden)
                loss += criterion(decoder_output[0], next_sentence[di])
                # TODO Greedy alg. now, maybe use beam search when inferencing in the future
                _, topi = decoder_output.data.topk(1)
                ni = topi[0][0]
                # CONSISTENCY: check_cuda_for_var instead of inline .cuda().
                decoder_input = check_cuda_for_var(
                    Variable(torch.LongTensor([[ni]])))
            # Make gen_sentence concated with a EOS and make it torch Variable
            gen_sentence.append(my_lang.word2index["EOS"])
            gen_sentence = check_cuda_for_var(
                Variable(torch.LongTensor(gen_sentence)))
        # ROBUSTNESS: a dialog with fewer than two sentences makes no
        # predictions; the old code crashed on ``0 .data`` and divided
        # by zero.  Such dialogs contribute nothing.
        if predict_count > 0:
            validation_loss += (loss.data[0] / predict_count)
    return validation_loss / len(validation_data)
def gen(sentence):
    """Generate a dialog continuation with the HRED model using beam search.

    Uses module-level globals: ``encoder``, ``context``, ``decoder``,
    ``my_lang``, ``args``, ``check_cuda_for_var`` and ``exp``.  Each outer
    iteration produces one sentence; generation stops after
    ``max_dialog_len`` turns or once a sentence containing "EOD" is
    produced.  Returns ``talking_history``, a list of generated sentences
    as lists of word indices.
    """
    encoder.eval()
    context.eval()
    decoder.eval()
    # Inference
    gen_sentence = []
    talking_history = []
    context_hidden = context.init_hidden()
    max_dialog_len = 20
    max_sentence_len = 15
    beam_size = args.beam
    for _ in range(max_dialog_len):
        decoder_input = Variable(torch.LongTensor([[my_lang.word2index["SOS"]]]))
        decoder_input = check_cuda_for_var(decoder_input)
        encoder_hidden = encoder.init_hidden()
        decoder_hidden = decoder.init_hidden()
        if len(gen_sentence) > 0:
            # Encode the previously generated reply as the new input turn.
            for ei in range(len(gen_sentence)):
                _, encoder_hidden = encoder(gen_sentence[ei], encoder_hidden)
            # Clean generated sentence list
            gen_sentence = []
        else:
            # First turn: encode the caller-provided seed sentence.
            for ei in range(len(sentence)):
                _, encoder_hidden = encoder(sentence[ei], encoder_hidden)
        context_output, context_hidden = context(encoder_hidden, context_hidden)
        # Beam search.  Each beam state is a 4-item list:
        #   [0] next decoder input (1x1 LongTensor Variable)
        #   [1] decoder hidden state for that hypothesis
        #   [2] token indices generated so far
        #   [3] accumulated path score (summed topk scores)
        index2state = {}
        for index in range(beam_size):
            index2state[index] = [decoder_input, decoder_hidden,
                                  [decoder_input.data[0][0]], 0.0]
        # One step to get beam_size candidates
        decoder_output, decoder_hidden = decoder(context_hidden,
                                                 decoder_input, decoder_hidden)
        scores, topi = decoder_output.data.topk(beam_size)
        for index in range(beam_size):
            ni = topi[0][index]
            index2state[index][0] = check_cuda_for_var(Variable(torch.LongTensor([[ni]])))
            index2state[index][1] = decoder_hidden
            index2state[index][2].append(ni)
            index2state[index][3] = scores[0][index]
        for sentence_pointer in range(max_sentence_len):
            current_scores = []
            current2state = {}
            # Init current2state: beam_size expansions per current beam.
            for index in range(beam_size):
                for jndex in range(beam_size):
                    current2state[index * beam_size + jndex] = [0, 0, 0, 0]
            for index in range(beam_size):
                # Advance each beam one decoder step and fan out its top-k.
                output, hidden = decoder(context_hidden,
                                         index2state[index][0], index2state[index][1])
                tops, topi = output.data.topk(beam_size)
                for jndex in range(beam_size):
                    ni = topi[0][jndex]
                    current_map = current2state[index * beam_size + jndex]
                    current_map[0] = check_cuda_for_var(Variable(torch.LongTensor([[ni]])))
                    current_map[1] = hidden
                    # Copy the parent beam's token history before extending it.
                    current_map[2] = index2state[index][2][:]
                    current_map[2].append(ni)
                    current_map[3] = tops[0][jndex] + index2state[index][3]
                    # Length bias once EOD appears (enabled by --eodlong).
                    # NOTE(review): ``exp`` is presumably math.exp — confirm
                    # the import at the top of the file.
                    if args.eodlong == 1 and my_lang.word2index["EOD"] in current_map[2]:
                        current_map[3] *= exp(max_sentence_len - 12 - sentence_pointer)
                    current_scores.append(current_map[3])
            _, top_of_beamsize2 = torch.FloatTensor(current_scores).topk(beam_size)
            # Top beam's output is eos, break and output the top beam
            if current2state[top_of_beamsize2[0]][2][-1] == my_lang.word2index["EOS"]:
                if args.nosr == 1 and current2state[top_of_beamsize2[0]][2] in talking_history:
                    # Don't repeat itself: adjust the score of a sentence the
                    # model already said, so another beam can win.
                    # Soft version
                    current2state[top_of_beamsize2[0]][3] *= 2
                    # Hard version
                    #current2state[top_of_beamsize2[0][3]] *= 100000.0
                else:
                    # Emit the top beam, truncated just after its first EOS.
                    first_eos = current2state[top_of_beamsize2[0]][2].index(my_lang.word2index["EOS"])
                    gen_sentence = current2state[top_of_beamsize2[0]][2][:first_eos+1]
                    break
            # Keep only the best beam_size expansions for the next step.
            after_beam_dict = {}
            for index, candidate in enumerate(top_of_beamsize2):
                after_beam_dict[index] = current2state[candidate]
            index2state = after_beam_dict
        # Beam Search a good sentence and assign to gen_sentence
        # NOTE(review): if the loop exhausts max_sentence_len without ever
        # taking the EOS branch, gen_sentence keeps its previous value
        # ([] on the first turn) — the except below then ends generation.
        talking_history.append(gen_sentence)
        gen_sentence = Variable(torch.LongTensor(gen_sentence))
        gen_sentence = check_cuda_for_var(gen_sentence)
        try:
            string = ' '.join([my_lang.index2word[word.data[0]] for word in gen_sentence])
            print(string)
            if "EOD" in string:
                break
        except RuntimeError:
            break
    return talking_history
break decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, \ encoder_outputs) _, topi = decoder_output.data.topk(1) ni = topi[0][0] decoder_input = Variable(torch.LongTensor([[ni]])) decoder_input = check_cuda_for_var(decoder_input) gen_sentence = Variable(torch.LongTensor(gen_sentence)) gen_sentence = check_cuda_for_var(gen_sentence) string = ' '.join([my_lang.index2word[word.data[0]] for word in gen_sentence]) print(string) talking_history.append(string) if "EOD" in string or args.sbs: break counter += 1 return talking_history # Generating string try: if args.sbs == 0 or args.type == 'seq2seq': while True: start = input("[%s] >>> " % (args.type.upper())) clean_sentence = clean(start) clean_sentence_idx = my_lang.sentence2index(clean_sentence) clean_sentence_idx = Variable(torch.LongTensor(clean_sentence_idx)) clean_sentence_idx = check_cuda_for_var(clean_sentence_idx) gen(clean_sentence_idx) else: genSbyS() except KeyboardInterrupt: print()