def batchop2(datapoints, VOCAB, GENDER, config, for_prediction=False, *args, **kwargs):
    """Build a (gender, in_sequence) batch, with out_sequence targets when training.

    BUGFIX: the `for_prediction` checks were inverted -- targets were built
    only *for* prediction (where `d.out_sequence` may be absent) and omitted
    during training (where the loss needs them). Now consistent with the
    other `batchop(..., for_prediction=...)` variants in this codebase.

    Args:
        datapoints: items with .id, .gender, .in_sequence and (when training)
            .out_sequence attributes.
        VOCAB, GENDER: token/gender -> id lookup tables.
        config: project config forwarded to LongVar.
        for_prediction: when True, no targets are produced.

    Returns:
        (indices, (gender, in_sequence), targets) where targets is the padded,
        time-major out_sequence tensor during training and () for prediction.
    """
    indices = [d.id for d in datapoints]
    in_sequence = []
    if not for_prediction:
        out_sequence = []
    gender = []
    for d in datapoints:
        gender.append(GENDER[d.gender])
        # sequences are wrapped in GO ... EOS sentinels before padding
        in_sequence.append([VOCAB['GO']] + [VOCAB[w] for w in d.in_sequence] + [VOCAB['EOS']])
        if not for_prediction:
            out_sequence.append([VOCAB['GO']] + [VOCAB[w] for w in d.out_sequence] + [VOCAB['EOS']])

    gender = LongVar(config, gender)
    # transpose -> time-major (seq_len, batch)
    in_sequence = LongVar(config, pad_seq(in_sequence)).transpose(0, 1)
    if not for_prediction:
        out_sequence = LongVar(config, pad_seq(out_sequence)).transpose(0, 1)
        batch = indices, (gender, in_sequence), (out_sequence)
    else:
        batch = indices, (gender, in_sequence), ()
    return batch
def batchop(datapoints, VOCAB, GENDER, config, for_prediction=False, *args, **kwargs):
    """Batch gender ids and GO/EOS-wrapped input/output sequences (time-major)."""
    indices = [d.id for d in datapoints]
    gender = [GENDER[d.gender] for d in datapoints]
    in_sequence = [
        [VOCAB['GO']] + [VOCAB[w] for w in d.in_sequence] + [VOCAB['EOS']]
        for d in datapoints
    ]
    out_sequence = [
        [VOCAB['GO']] + [VOCAB[w] for w in d.out_sequence] + [VOCAB['EOS']]
        for d in datapoints
    ]

    gender = LongVar(config, gender)
    # pad to a rectangle, then flip to (seq_len, batch)
    in_sequence = LongVar(config, pad_seq(in_sequence)).transpose(0, 1)
    out_sequence = LongVar(config, pad_seq(out_sequence)).transpose(0, 1)

    return indices, (gender, in_sequence), (out_sequence)
def _predict():
    """HTTP prediction endpoint (presumably a Flask route -- uses `request`/`jsonify`).

    Tokenizes the posted "text", runs the module-level `model`, and returns the
    predicted label plus per-token attention weights and class probabilities.
    On any failure it degrades to a {"result": "model failed"} payload.
    """
    print(' requests incoming..')
    sentence = []
    try:
        # lowercase then tokenize the incoming text payload
        input_string = word_tokenize(request.json["text"].lower())
        sentence.append([VOCAB[w] for w in input_string] + [VOCAB['EOS']])
        dummy_label = LongVar([0])  # NOTE(review): unused -- looks like dead code
        sentence = LongVar(sentence)
        # batch format expected by the model: indices, (inputs,), (targets,)
        input_ = [0], (sentence, ), (0, )
        output, attn = model(input_)
        #print(LABELS[output.max(1)[1]], attn)
        nwords = len(input_string)
        return jsonify({
            "result": {
                'sentence': input_string,
                # drop the last attention weight (EOS/sentinel position)
                'attn': ['{:0.4f}'.format(i) for i in attn.squeeze().data.cpu().numpy().tolist()[:-1]],
                # model output is log-probabilities; exp() recovers probs
                'probs': ['{:0.4f}'.format(i) for i in output.exp().squeeze().data.cpu().numpy().tolist()],
                'label': LABELS[output.max(1)[1].squeeze().data.cpu().numpy()]
            }
        })
    except Exception as e:
        # best-effort boundary: never let a bad request crash the server
        print(e)
        return jsonify({"result": "model failed"})
def batchop(datapoints, VOCAB, config, *args, **kwargs):
    """Language-model batch: inputs are tokens [:-1], targets are tokens [1:]."""
    indices = [d.id for d in datapoints]
    sequence = [[VOCAB[w] for w in d.sequence] for d in datapoints]
    # pad, then flip to time-major so slicing along dim 0 shifts in time
    sequence = LongVar(config, pad_seq(sequence)).transpose(1, 0)
    return indices, (sequence[:-1]), (sequence[1:])
def batchop(datapoints, VOCAB, GENDER, config, *args, **kwargs):
    """Batch GO/EOS-wrapped sequences with gender ids; no separate targets."""
    indices = [d.id for d in datapoints]
    sequence = [
        [VOCAB['GO']] + [VOCAB[w] for w in d.sequence] + [VOCAB['EOS']]
        for d in datapoints
    ]
    gender = [GENDER[d.gender] for d in datapoints]

    sequence = LongVar(config, pad_seq(sequence))
    gender = LongVar(config, gender)
    return indices, (gender, sequence), ()
def batchop(datapoints, VOCAB, config, *args, **kwargs):
    """Skip-gram-style batch: one (word, context) id pair per datapoint."""
    indices = [d.id for d in datapoints]
    word = LongVar(config, [VOCAB[d.word] for d in datapoints])
    context = LongVar(config, [VOCAB[d.context] for d in datapoints])
    return indices, word, context
def predict_batchop(datapoints, VOCAB, LABELS, config, *args, **kwargs):
    """Prediction-time QA batch: (story, question) inputs, no answer targets."""
    indices = [d.id for d in datapoints]
    story = [[VOCAB[w] for w in d.story] for d in datapoints]
    question = [[VOCAB[w] for w in d.q] for d in datapoints]

    story = LongVar(config, pad_seq(story))
    question = LongVar(config, pad_seq(question))
    return indices, (story, question), ()
def batchop(datapoints, VOCAB, LABELS, *args, **kwargs):
    """Classification batch: EOS-terminated sentences and their label ids."""
    indices = [d.id for d in datapoints]
    sentence, label = [], []
    for d in datapoints:
        sentence.append([VOCAB[w] for w in d.sentence] + [VOCAB['EOS']])
        label.append(LABELS[d.label])

    return indices, (LongVar(pad_seq(sentence)), ), (LongVar(label), )
def train_on_feed(feed):
    """Run one training pass over `feed` and return the mean batch loss.

    Relies on enclosing-scope state: `self` (the model/trainer) and
    `teacher_force_count` (a two-slot [forced, free] counter) are free
    variables -- presumably this is a closure inside a trainer method.
    """
    losses = []
    feed.reset_offset()  # start the feed from the beginning of the epoch
    for j in tqdm(range(feed.num_batch), desc='Trainer.{}'.format(self.name())):
        self.optimizer.zero_grad()
        input_ = feed.next_batch()
        idxs, (gender, sequence), targets = input_
        sequence = sequence.transpose(0,1)  # to time-major (seq_len, batch)
        seq_size, batch_size = sequence.size()
        state = self.initial_hidden(batch_size)
        loss = 0
        output = sequence[0]  # seed decoding with the first token (GO)
        # normalized [0, 1] position feature per timestep
        # NOTE(review): LongVar over np.linspace floats -- verify this isn't
        # truncating the positions to integers 0/1.
        positions = LongVar(self.config, np.linspace(0, 1, seq_size))
        for ti in range(1, sequence.size(0) - 1):
            output = self.forward(gender, positions[ti], output, state)
            # loss_function receives the raw (output, state) pair plus the batch
            loss += self.loss_function(ti, output, input_)
            output, state = output  # unpack forward()'s (logits, new_state)
            if random.random() > self.teacher_forcing_ratio:
                output = output.max(1)[1]  # feed back the model's own argmax
                teacher_force_count[0] += 1
            else:
                output = sequence[ti+1]  # teacher forcing: feed the gold token
                teacher_force_count[1] += 1

        losses.append(loss)
        loss.backward()
        self.optimizer.step()
    return torch.stack(losses).mean()
def batchop(datapoints, VOCAB, LABELS, *args, **kwargs):
    """QA batch: (story, question) id sequences with single-id answer targets."""
    indices = [d.id for d in datapoints]
    story, question, answer = [], [], []
    for d in datapoints:
        story.append([VOCAB[w] for w in d.story])
        question.append([VOCAB[w] for w in d.q])
        answer.append(LABELS[d.a])

    story = LongVar(pad_seq(story))
    question = LongVar(pad_seq(question))
    answer = LongVar(answer)
    return indices, (story, question), (answer)
def batchop(self, datapoints, for_prediction=False, *args, **kwargs):
    """Batch via the module-level `dataset` vocabs; labels only when training."""
    indices = [d.id for d in datapoints]
    sequence = [[dataset.input_vocab[w] for w in d.sequence] for d in datapoints]
    label = []
    if not for_prediction:
        label = LongVar(self.config, [dataset.output_vocab[d.label] for d in datapoints])

    sequence = LongVar(self.config, pad_seq(sequence))
    return indices, (sequence, ), (label)
def batchop(datapoints, VOCAB, LABELS, config, for_prediction=False, *args, **kwargs):
    """Classification batch; label targets are skipped in prediction mode."""
    indices = [d.id for d in datapoints]
    sequence = [[VOCAB[w] for w in d.sequence] for d in datapoints]
    label = []
    if not for_prediction:
        label = LongVar(config, [LABELS[d.label] for d in datapoints])

    sequence = LongVar(config, pad_seq(sequence))
    return indices, (sequence, ), (label)
def initial_input(self, input_, encoder_output):
    """Build the decoder's step-0 inputs from the encoder's hidden states.

    `self.__(...)` is the project's debug/shape-logging identity wrapper.
    Returns (decoder_input, hidden, context, coverage).
    """
    story_states = self.__( encoder_output, 'encoder_output')
    # encoder output is time-major: (seq_len, batch, hidden)
    seq_len, batch_size, hidden_size = story_states.size()
    # replicate the fixed start token across the batch
    decoder_input = self.__( LongVar([self.initial_decoder_input]).expand(batch_size), 'decoder_input')
    # seed the decoder hidden state with the encoder's final timestep
    hidden = self.__( story_states[-1], 'hidden')
    # context = element-wise max over time (max-pooled encoder states)
    context, _ = self.__( story_states.max(0), 'context')
    # coverage starts at zero attention mass for every source position
    coverage = self.__( Var(torch.zeros(batch_size, seq_len)), 'coverage')
    return decoder_input, hidden, context, coverage
def batchop(datapoints, VOCAB, config, *args, **kwargs):
    """Batch plain id sequences; no targets.

    BUGFIX: removed the dead local `s = []` that was rebuilt (and never used)
    on every loop iteration.

    Returns:
        (indices, (padded_sequence,), ()).
    """
    indices = [d.id for d in datapoints]
    sequence = []
    for d in datapoints:
        sequence.append([VOCAB[w] for w in d.sequence])

    sequence = LongVar(config, pad_seq(sequence))
    batch = indices, (sequence, ), ()
    return batch
def batchop(datapoints, WORD2INDEX, *args, **kwargs):
    """Pointer-network-style batch: base-vocab ids plus extended-vocab copies.

    Words unknown to WORD2INDEX get per-example ids past the end of the base
    vocab (len(WORD2INDEX) + oov index) in the extvocab_* tensors.
    """
    indices = [d.id for d in datapoints]
    UNK = WORD2INDEX['UNK']
    eos = WORD2INDEX['EOS']

    def build_oov(d, WORD2INDEX):
        # unique OOV words of this example (order follows set() iteration)
        oov = [w for w in d.story + d.q + d.a if WORD2INDEX[w] == UNK]
        return list(set(oov))

    def extended_ids(words, oov):
        # OOVs map past the base vocab; in-vocab words keep their usual id
        return [oov.index(w) + len(WORD2INDEX) if WORD2INDEX[w] == UNK
                else WORD2INDEX[w]
                for w in words] + [eos]

    story, question, answer = [], [], []
    extvocab_story, extvocab_answer = [], []
    extvocab_size = 0
    for d in datapoints:
        story.append([WORD2INDEX[w] for w in d.story] + [eos])
        question.append([WORD2INDEX[w] for w in d.q] + [eos])
        answer.append([WORD2INDEX[w] for w in d.a] + [eos])

        oov = build_oov(d, WORD2INDEX)
        extvocab_story.append(extended_ids(d.story, oov))
        extvocab_answer.append(extended_ids(d.a, oov))
        extvocab_size = max(extvocab_size, len(oov))

    story = LongVar(pad_seq(story))
    question = LongVar(pad_seq(question))
    answer = LongVar(pad_seq(answer))
    extvocab_answer = LongVar(pad_seq(extvocab_answer))
    extvocab_story = LongVar(pad_seq(extvocab_story))

    return indices, (story, question), (answer, extvocab_story, extvocab_answer, extvocab_size)
def batchop(datapoints, VOCAB, GENDER, config, *args, **kwargs):
    """Time-major (gender, in_sequence) batch with out_sequence targets."""
    indices = [d.id for d in datapoints]
    in_sequence, out_sequence, gender = [], [], []
    for d in datapoints:
        in_sequence.append([VOCAB['GO']] + [VOCAB[w] for w in d.in_sequence] + [VOCAB['EOS']])
        out_sequence.append([VOCAB['GO']] + [VOCAB[w] for w in d.out_sequence] + [VOCAB['EOS']])
        gender.append(GENDER[d.gender])

    # (batch, seq) -> (seq, batch) after padding
    in_sequence = LongVar(config, pad_seq(in_sequence)).transpose(0, 1)
    out_sequence = LongVar(config, pad_seq(out_sequence)).transpose(0, 1)
    gender = LongVar(config, gender)
    return indices, (gender, in_sequence), (out_sequence)
def batchop(datapoints, VOCAB, config, *args, **kwargs):
    """Batch character-id word pairs with an `existence` target per pair.

    BUGFIX: removed the unused `max_len` computation -- it scanned every
    datapoint for `max_token_len` and never used the result (pad_seq
    determines padding lengths itself).

    Returns:
        (indices, (word1, word2), existence).
    """
    indices = [d.id for d in datapoints]
    word1 = []
    word2 = []
    existence = []
    for d in datapoints:
        w1, w2 = d.pair
        word1.append([VOCAB[i] for i in w1])
        word2.append([VOCAB[i] for i in w2])
        existence.append(d.existence)

    word1 = LongVar(config, pad_seq(word1))
    word2 = LongVar(config, pad_seq(word2))
    existence = LongVar(config, existence)
    batch = indices, (word1, word2), existence
    return batch
def forward(self, context, question):
    """Co-attention encoder over (context, question) id matrices.

    Both inputs get a learned sentinel column appended before the affinity
    computation; the sentinel row is stripped from the returned representation.
    `self.__(...)` is the project's debug/shape-logging identity wrapper.
    """
    context = LongVar(context)
    question = LongVar(question)
    batch_size, context_size = context.size()
    _, question_size = question.size()
    context = self.__(self.embed(context), 'context_emb')
    question = self.__(self.embed(question), 'question_emb')

    # encode context: RNN wants time-major, attention wants batch-major
    context = context.transpose(1, 0)
    C, _ = self.__( self.encode(context, init_hidden(batch_size, self.encode)), 'C')
    C = self.__(C.transpose(1, 0), 'C')
    s = self.__(self.sentinel(batch_size), 's')
    C = self.__(torch.cat([C, s], dim=1), 'C')  # append sentinel -> len+1

    # encode question the same way, with its own sentinel
    question = question.transpose(1, 0)
    Q, _ = self.__( self.encode(question, init_hidden(batch_size, self.encode)), 'Q')
    Q = self.__(Q.transpose(1, 0), 'Q')
    s = self.__(self.sentinel(batch_size), 's')
    Q = self.__(torch.cat([Q, s], dim=1), 'Q')

    # NOTE(review): `squashedQ` and `transformedQ` are computed but never
    # used below -- dead code? (F.tanh is also deprecated in favor of
    # torch.tanh.) The view on the next line restores Q's original shape.
    squashedQ = self.__(Q.view(batch_size * (question_size + 1), -1), 'squashedQ')
    transformedQ = self.__(F.tanh(self.linear(Q)), 'transformedQ')
    Q = self.__(Q.view(batch_size, question_size + 1, -1), 'Q')

    # affinity[b, i, j] = <C_i, Q_j>; softmax over the question dimension
    affinity = self.__(torch.bmm(C, Q.transpose(1, 2)), 'affinity')
    affinity = F.softmax(affinity, dim=-1)
    context_attn = self.__(affinity.transpose(1, 2), 'context_attn')
    question_attn = self.__(affinity, 'question_attn')

    # question-aware context summaries, concatenated with Q, then re-attended
    context_question = self.__(torch.bmm(C.transpose(1, 2), question_attn), 'context_question')
    context_question = self.__( torch.cat([Q, context_question.transpose(1, 2)], -1), 'context_question')
    attn_cq = self.__( torch.bmm(context_question.transpose(1, 2), context_attn), 'attn_cq')
    attn_cq = self.__(attn_cq.transpose(1, 2).transpose(0, 1), 'attn_cq')  # to time-major

    # fuse with a final RNN pass, back to batch-major
    hidden = self.__(init_hidden(batch_size, self.attend), 'hidden')
    final_repr, _ = self.__(self.attend(attn_cq, hidden), 'final_repr')
    final_repr = self.__(final_repr.transpose(0, 1), 'final_repr')
    return final_repr[:, :-1] #exclude sentinel
def batchop(datapoints, WORD2INDEX, *args, **kwargs):
    """Span-prediction batch: context/question ids plus answer position targets.

    Each answer-position list is terminated with len(d.context), which points
    at the appended EOS token, and answer_lengths counts that terminator.
    """
    indices = [d.id for d in datapoints]
    context, question = [], []
    answer_positions, answer_lengths = [], []
    for d in datapoints:
        context.append([WORD2INDEX[w] for w in d.context] + [WORD2INDEX['EOS']])
        question.append([WORD2INDEX[w] for w in d.q])
        answer_positions.append([i for i in d.a_positions] + [len(d.context)])
        answer_lengths.append(len(d.a_positions) + 1)

    context = LongVar(pad_seq(context))
    question = LongVar(pad_seq(question))
    answer_positions = LongVar(pad_seq(answer_positions))
    answer_lengths = LongVar(answer_lengths)
    return indices, (context, question, answer_lengths), (answer_positions, )
def do_validate(self):
    """Run the test feed, record loss/accuracy, checkpoint on improvement.

    Returns self.loss_trend() (which may signal early stopping) when
    early_stopping is enabled; otherwise returns None.
    """
    self.eval()  # disable dropout etc. for evaluation
    if self.test_feed.num_batch > 0:
        for j in tqdm(range(self.test_feed.num_batch), desc='Tester.{}'.format(self.name())):
            input_ = self.test_feed.next_batch()
            idxs, (gender, sequence), targets = input_
            sequence = sequence.transpose(0,1)  # to time-major (seq_len, batch)
            seq_size, batch_size = sequence.size()
            state = self.initial_hidden(batch_size)
            loss, accuracy = Var(self.config, [0]), Var(self.config, [0])
            output = sequence[0]  # seed with the first (GO) token
            outputs = []
            ti = 0  # survives the loop; guards division below for empty loops
            # normalized [0, 1] position feature per timestep
            positions = LongVar(self.config, np.linspace(0, 1, seq_size))
            for ti in range(1, sequence.size(0) - 1):
                output = self.forward(gender, positions[ti], output, state)
                loss += self.loss_function(ti, output, input_)
                accuracy += self.accuracy_function(ti, output, input_)
                output, state = output  # unpack (logits, new_state)
                output = output.max(1)[1]  # greedy decode: feed back argmax
                outputs.append(output)
            self.test_loss.append(loss.item())
            if ti == 0:
                ti = 1  # avoid division by zero for degenerate sequences
            self.accuracy.append(accuracy.item()/ti)
        #print('====', self.test_loss, self.accuracy)
        self.log.info('= {} =loss:{}'.format(self.epoch, self.test_loss))
        self.log.info('- {} -accuracy:{}'.format(self.epoch, self.accuracy))
        # lower criteria value == better model; snapshot on improvement
        if len(self.best_model_criteria) > 1 and self.best_model[0] > self.best_model_criteria[-1]:
            self.log.info('beat best ..')
            self.best_model = (self.best_model_criteria[-1], self.cpu().state_dict())
            self.save_best_model()
            # cpu() above moved the model; move it back for further training
            if self.config.CONFIG.cuda:
                self.cuda()
    for m in self.metrics:
        m.write_to_file()
    if self.early_stopping:
        return self.loss_trend()
def do_predict(self, input_=None, length=10, beam_width=50):
    """Sample a name of `length` tokens, picking randomly among top-k outputs.

    When no input is given, a random single-example batch is drawn from the
    training feed. Prints the generated string (filtered to the Tamil Unicode
    block U+0B80..U+0BFF) and returns True; failures are logged and swallowed.
    """
    self.eval()
    if not input_:
        input_ = self.train_feed.nth_batch(
            random.randint(0, self.train_feed.size - 10), 1
        )
    try:
        idxs, (gender, sequence), targets = input_
        sequence = sequence.transpose(0,1)  # to time-major (seq_len, batch)
        seq_size, batch_size = sequence.size()
        state = self.initial_hidden(batch_size)
        loss = 0
        output = sequence[1]  # seed with the first real token (after GO)
        outputs = []
        # normalized [0, 1] position feature over the requested length
        positions = LongVar(self.config, np.linspace(0, 1, length))
        for ti in range(length - 1):
            outputs.append(output)
            output = self.forward(gender, positions[ti], output, state)
            output, state = output  # unpack (logits, new_state)
            # sample: take top-k token ids, then pick one uniformly at random
            output = output.topk(beam_width)[1]
            index = random.randint(0, beam_width-1)
            output = output[:, index]

        outputs = torch.stack(outputs).transpose(0,1)  # back to (batch, seq)
        for i in range(outputs.size(0)):
            s = [self.dataset.input_vocab[outputs[i][j]] for j in range(outputs.size(1))]
            name = ''.join(s)
            # keep only characters in the Tamil Unicode block
            name = ''.join([i for i in name if ord(i) >= 0x0B80 and ord(i) <= 0x0BFF])
            print(self.dataset.gender_vocab[gender.item()], name, length, beam_width)
        return True
    except:
        # best-effort: prediction is diagnostic, never fatal
        self.log.exception('PREDICTION')
        print(locals())
class Model(Base):
    """Key/value-memory seq2seq model.

    A ConvEncoder summarizes the input word; the summary attends over learned
    key/value banks (`self.keys`/`self.values`) to produce the decoder's
    initial state; a Decoder then emits the target sequence token by token.

    Fixes applied in this revision:
      * `prev_ouptut` NameError (typo) in do_validate/do_predict.
      * undefined `self.initial_token` replaced by the `sos_token` seeding
        already used in do_train.
      * do_validate/do_predict indexed the single-layer init_hidden state out
        of range and skipped the key/value attention read; both now follow
        do_train's decoding path (the original carried a TODO to this effect).
      * a throwaway SGD optimizer was constructed and immediately discarded
        by the Adam assignment; only Adam is built now.
      * do_train's inline copy of the key/value read is replaced by a call to
        the identical `embed()` method.
    """

    def __init__(
            self,
            config,
            name,
            input_vocab_size,
            output_vocab_size,
            kv_size,
            # sos_token
            sos_token,
            # feeds
            dataset,
            train_feed,
            test_feed,
            # loss function
            loss_function,
            f1score_function=None,
            save_model_weights=True,
            epochs=1000,
            checkpoint=1,
            early_stopping=True,
            # optimizer
            optimizer=None,
    ):
        super(Model, self).__init__(config, name)
        self.vocab_size = input_vocab_size
        self.kv_size = kv_size
        self.hidden_dim = config.HPCONFIG.hidden_dim
        self.embed_dim = config.HPCONFIG.embed_dim

        # fixed start-of-sequence token, expanded per batch at decode time
        self.sos_token = LongVar(self.config, torch.Tensor([sos_token]))
        # learned key/value memory banks
        self.keys = nn.Parameter(torch.rand([self.kv_size, self.hidden_dim]))
        self.values = nn.Parameter(torch.rand([self.kv_size, self.hidden_dim]))

        self.encode = ConvEncoder(self.config, name + '.encoder', input_vocab_size)
        self.decode = Decoder(self.config, name + '.decoder', input_vocab_size, output_vocab_size)

        self.loss_function = loss_function if loss_function else nn.NLLLoss()
        self.f1score_function = f1score_function

        self.epochs = epochs
        self.checkpoint = checkpoint
        self.early_stopping = early_stopping

        self.dataset = dataset
        self.train_feed = train_feed
        self.test_feed = test_feed
        self.save_model_weights = save_model_weights

        self.__build_stats__()
        self.best_model_criteria = self.train_loss
        self.best_model = (100000, self.cpu().state_dict())

        # BUGFIX: an SGD optimizer used to be built here and immediately
        # overwritten by this Adam assignment; construct only the one used.
        self.optimizer = optimizer if optimizer else optim.Adam(self.parameters())

        if config.CONFIG.cuda:
            self.cuda()

    def restore_and_save(self):
        """Reload the last checkpoint and re-persist it as the best model."""
        self.restore_checkpoint()
        self.save_best_model()

    def init_hidden(self, batch_size):
        """Zero GRU-style hidden state of shape (1, batch, hidden_dim)."""
        hidden_state = Var(self.config, torch.zeros(1, batch_size, self.hidden_dim))
        if self.config.CONFIG.cuda:
            hidden_state = hidden_state.cuda()
        return hidden_state

    def embed(self, word):
        """Encode `word` and read the key/value memory with soft attention.

        Returns the attention-weighted sum over `self.values`, one vector
        per batch element.
        """
        encoded_info = self.__(self.encode(word), 'encoded_info')
        keys = self.__(self.keys.transpose(0, 1), 'keys')
        keys = self.__(keys.expand([encoded_info.size(0), *keys.size()]), 'keys')
        inner_product = self.__(
            torch.bmm(
                encoded_info.unsqueeze(1),  # final state
                keys),
            'inner_product')

        values = self.__(self.values, 'values')
        values = self.__(values.expand([inner_product.size(0), *values.size()]), 'values')

        weighted_sum = self.__(torch.bmm(inner_product, values), 'weighted_sum')
        weighted_sum = self.__(weighted_sum.squeeze(1), 'weighted_sum')
        return weighted_sum

    @profile
    def do_train(self):
        """Train for self.epochs epochs, validating every `checkpoint` epochs."""
        for epoch in range(self.epochs):
            self.log.critical('memory consumed : {}'.format(memory_consumed()))
            self.epoch = epoch
            if epoch and epoch % max(1, (self.checkpoint - 1)) == 0:
                #self.do_predict()
                if self.do_validate() == FLAGS.STOP_TRAINING:
                    self.log.info('loss trend suggests to stop training')
                    return

            self.train()
            losses = []
            tracemalloc.start()  # track allocations to spot leaks over batches
            for j in tqdm(range(self.train_feed.num_batch),
                          desc='Trainer.{}'.format(self.name())):
                self.optimizer.zero_grad()
                input_ = self.train_feed.next_batch()
                idxs, word, targets = input_
                loss = 0

                # key/value memory read (was duplicated inline; see embed())
                weighted_sum = self.embed(word)

                tseq_len, batch_size = targets.size()
                state = self.__((weighted_sum, self.init_hidden(batch_size).squeeze(0)),
                                'decoder initial state')
                prev_output = self.__(self.sos_token.expand([batch_size]), 'sos_token')
                for i in range(targets.size(0)):
                    output = self.decode(prev_output, state)
                    loss += self.loss_function(output, targets[i])
                    prev_output = output.max(1)[1].long()  # greedy feedback

                losses.append(loss)
                loss.backward()
                self.optimizer.step()
                del input_  # drop the batch before the next allocation
                if j and not j % 100000:
                    malloc_snap = tracemalloc.take_snapshot()
                    display_tracemalloc_top(malloc_snap, limit=100)

            epoch_loss = torch.stack(losses).mean()
            self.train_loss.append(epoch_loss.data.item())
            self.log.info('-- {} -- loss: {}\n'.format(epoch, epoch_loss))
            for m in self.metrics:
                m.write_to_file()
        return True

    def do_validate(self):
        """Evaluate on the test feed and checkpoint when the loss improves."""
        self.eval()
        if self.test_feed.num_batch > 0:
            losses, accuracies = [], []
            for j in tqdm(range(self.test_feed.num_batch),
                          desc='Tester.{}'.format(self.name())):
                input_ = self.test_feed.next_batch()
                idxs, word, targets = input_
                loss = 0

                # BUGFIX: follow do_train's decoding path (key/value read +
                # sos_token seed); the old code indexed init_hidden()'s
                # single-layer state out of range and referenced the
                # undefined names `initial_token` / `prev_ouptut`.
                weighted_sum = self.embed(word)
                tseq_len, batch_size = targets.size()
                state = (weighted_sum, self.init_hidden(batch_size).squeeze(0))
                prev_output = self.sos_token.expand([batch_size])
                for i in range(targets.size(0)):
                    output = self.decode(prev_output, state)
                    loss += self.loss_function(output, targets[i])
                    prev_output = output.max(1)[1].long()
                losses.append(loss)

            epoch_loss = torch.stack(losses).mean()
            self.test_loss.append(epoch_loss.data.item())
            self.log.info('= {} =loss:{}'.format(self.epoch, epoch_loss))

            # lower criteria value == better model; snapshot on improvement
            if len(self.best_model_criteria) > 1:
                if self.best_model[0] > self.best_model_criteria[-1]:
                    self.log.info('beat best ..')
                    self.best_model = (self.best_model_criteria[-1],
                                       self.cpu().state_dict())
                    self.save_best_model()
                    """
                    dump_vocab_tsv(self.config,
                                   self.dataset.input_vocab,
                                   self.embed.weight.data.cpu().numpy(),
                                   self.config.ROOT_DIR + '/vocab.tsv')
                    """
                    # cpu() above moved the model; move it back
                    if self.config.CONFIG.cuda:
                        self.cuda()

        for m in self.metrics:
            m.write_to_file()
        if self.early_stopping:
            return self.loss_trend()

    def do_predict(self, input_=None, max_len=10):
        """Greedy-decode one (random, by default) batch and print the result."""
        if not input_:
            input_ = self.train_feed.nth_batch(
                random.randint(0, self.train_feed.size), 1)
        idxs, word, targets = input_
        loss = 0
        outputs = []

        # BUGFIX: same decoding path as do_train (see do_validate note)
        weighted_sum = self.embed(word)
        tseq_len, batch_size = targets.size()
        state = (weighted_sum, self.init_hidden(batch_size).squeeze(0))
        prev_output = self.sos_token.expand([batch_size])
        for i in range(targets.size(0)):
            output = self.decode(prev_output, state)
            loss += self.loss_function(output, targets[i])
            prev_output = output.max(1)[1]
            outputs.append(prev_output)

        output = output.max(1)[1].long()
        print(output.size())
        # NOTE(review): this unpacking assumes a different batch layout than
        # the `idxs, word, targets` unpack above -- confirm against the feed.
        ids, (sequence, ), (label) = input_
        print(' '.join([self.dataset.input_vocab[i.data[0]]
                        for i in sequence[0]]).replace('@@ ', ''))
        print(self.dataset.output_vocab[output.data[0]], ' ==? ',
              self.dataset.output_vocab[label.data[0]])
        return True
def __init__(self,
             # config and name
             config,
             name,
             # model parameters
             vocab_size,
             gender_vocab_size,
             # feeds
             dataset,
             pretrain_feed,
             train_feed,
             test_feed,
             # loss function
             loss_function,
             accuracy_function=None,
             f1score_function=None,
             save_model_weights=True,
             epochs=1000,
             checkpoint=1,
             early_stopping=True,
             # optimizer
             optimizer=None,
             ):
    """Build the gender-conditioned GRU language model and its training state.

    Fixes applied in this revision:
      * a throwaway SGD optimizer was constructed and immediately discarded
        by the Adam assignment; only Adam is built now.
      * the best-model-load except handler referenced a bare `log` name
        (undefined here), which would have masked the real error with a
        NameError; it now uses self.log. The bare `except:` is narrowed to
        `except Exception:` so it no longer swallows KeyboardInterrupt.
    """
    super().__init__(config, name)
    self.config = config

    self.embed_size = config.HPCONFIG.embed_size
    self.hidden_size = config.HPCONFIG.hidden_size
    self.vocab_size = vocab_size
    self.gender_vocab_size = gender_vocab_size
    self.loss_function = loss_function

    self.embed = nn.Embedding(self.vocab_size, self.embed_size)
    self.gender_embed = nn.Embedding(self.gender_vocab_size, self.embed_size)
    # learned projection of a scalar position into embedding space
    self.position = nn.Linear(1, self.embed_size)
    self.position_range = LongVar(self.config, np.arange(0, 100)).unsqueeze(1).float()
    # mixes (token, gender, position) embeddings into one input vector
    self.blend = nn.Linear(3 * self.embed_size, self.embed_size)
    self.lm = nn.GRUCell(self.embed.embedding_dim, self.hidden_size)
    self.dropout = nn.Dropout(0.1)
    self.answer = nn.Linear(self.hidden_size, self.vocab_size)

    self.loss_function = loss_function if loss_function else nn.NLLLoss()
    self.accuracy_function = (accuracy_function if accuracy_function
                              else lambda *x, **xx: 1 / loss_function(*x, **xx))

    # BUGFIX: an SGD optimizer used to be built here and immediately
    # overwritten by this Adam assignment; construct only the one used.
    self.optimizer = optimizer if optimizer else optim.Adam(self.parameters())

    self.f1score_function = f1score_function

    self.epochs = epochs
    self.checkpoint = checkpoint
    self.early_stopping = early_stopping

    self.dataset = dataset
    self.pretrain_feed = pretrain_feed
    self.train_feed = train_feed
    self.test_feed = test_feed

    ########################################################################
    # Saving model weights
    ########################################################################
    self.save_model_weights = save_model_weights
    self.best_model = (1e5, self.cpu().state_dict())
    try:
        # best-effort: resume the best-accuracy watermark from disk
        f = '{}/{}_best_model_accuracy.txt'.format(self.config.ROOT_DIR, self.name())
        if os.path.isfile(f):
            self.best_model = (float(open(f).read().strip()), self.cpu().state_dict())
            self.log.info('loaded last best accuracy: {}'.format(self.best_model[0]))
    except Exception:
        # BUGFIX: was bare `log.exception(...)` -- `log` is undefined here
        self.log.exception('no last best model')

    self.__build_stats__()
    self.best_model_criteria = self.train_loss

    if config.CONFIG.cuda:
        self.cuda()
ROOT_DIR, model, VOCAB, LABELS, datapoints=[train_set, train_set + test_set, train_set + test_set]) if 'predict' in sys.argv: print('=========== PREDICTION ==============') model.eval() count = 0 while True: count += 1 sentence = [] input_string = word_tokenize(input('?').lower()) sentence.append([VOCAB[w] for w in input_string] + [VOCAB['EOS']]) dummy_label = LongVar([0]) sentence = LongVar(sentence) input_ = [0], (sentence, ), (0, ) output, attn = model(input_) print(LABELS[output.max(1)[1]]) if 'show_plot' in sys.argv or 'save_plot' in sys.argv: nwords = len(input_string) from matplotlib import pyplot as plt plt.figure(figsize=(20, 10)) plt.bar(range(nwords + 1), attn.squeeze().data.cpu().numpy()) plt.title('{}\n{}'.format(output.exp().tolist(), LABELS[output.max(1)[1]])) plt.xticks(range(nwords), input_string, rotation='vertical')
def initial_input(self, batch_size):
    """Replicate the fixed first decoder token id across the batch."""
    start_token = LongVar([self.initial_decoder_input])
    return start_token.expand(batch_size)
def __init__(
        self,
        config,
        name,
        input_vocab_size,
        output_vocab_size,
        kv_size,
        # sos_token
        sos_token,
        # feeds
        dataset,
        train_feed,
        test_feed,
        # loss function
        loss_function,
        f1score_function=None,
        save_model_weights=True,
        epochs=1000,
        checkpoint=1,
        early_stopping=True,
        # optimizer
        optimizer=None,
):
    """Build the key/value-memory seq2seq model and its training state.

    BUGFIX: an SGD optimizer used to be constructed and immediately discarded
    by the subsequent Adam assignment; only the Adam optimizer (or the one
    passed in) is constructed now.
    """
    super(Model, self).__init__(config, name)
    self.vocab_size = input_vocab_size
    self.kv_size = kv_size
    self.hidden_dim = config.HPCONFIG.hidden_dim
    self.embed_dim = config.HPCONFIG.embed_dim

    # fixed start-of-sequence token, expanded per batch at decode time
    self.sos_token = LongVar(self.config, torch.Tensor([sos_token]))
    # learned key/value memory banks
    self.keys = nn.Parameter(torch.rand([self.kv_size, self.hidden_dim]))
    self.values = nn.Parameter(torch.rand([self.kv_size, self.hidden_dim]))

    self.encode = ConvEncoder(self.config, name + '.encoder', input_vocab_size)
    self.decode = Decoder(self.config, name + '.decoder', input_vocab_size, output_vocab_size)

    self.loss_function = loss_function if loss_function else nn.NLLLoss()
    self.f1score_function = f1score_function

    self.epochs = epochs
    self.checkpoint = checkpoint
    self.early_stopping = early_stopping

    self.dataset = dataset
    self.train_feed = train_feed
    self.test_feed = test_feed
    self.save_model_weights = save_model_weights

    self.__build_stats__()
    self.best_model_criteria = self.train_loss
    self.best_model = (100000, self.cpu().state_dict())

    self.optimizer = optimizer if optimizer else optim.Adam(self.parameters())

    if config.CONFIG.cuda:
        self.cuda()
def do_predict(self, input_=None, max_len=100, length=10, beam_width=4, teacher_force=False):
    """Decode one example, sampling among top-k tokens, and print gold/in/out.

    With teacher_force=True, timesteps whose input character is not '_' are
    forced to the given character instead of the sampled one.
    Returns True on completion.
    """
    if not input_:
        input_ = self.train_feed.nth_batch(
            random.randint(0, self.train_feed.size), 1)
    idxs, (gender, seq), target = input_
    #seq = seq[1:]
    #target = target[1:]
    seq_size, batch_size = seq.size()
    pad_mask = (seq > 0).float()  # NOTE(review): computed but unused below

    hidden_states, (hidden, cell_state) = self.__(self.encode_sequence(seq), 'encoded_outpus')
    outputs = []
    target_size, batch_size = seq.size()  # decode as many steps as the input has

    output = self.__(target[0], 'hidden')  # seed with the target's first token
    outputs.append(output)
    state = self.__((hidden, cell_state), 'init_hidden')
    gender_embedding = self.gender_embed(gender)
    # id of the placeholder character '_' used for the teacher-force test
    null_tensor = LongVar(self.config, [self.dataset.input_vocab['_']])
    for i in range(1, target_size):
        output, state = self.__(
            self.decode(hidden_states, output, state, gender_embedding),
            'output, state')
        # sample: take top-k token ids, then pick one uniformly at random
        output = output.topk(beam_width)[1]
        index = random.randint(0, beam_width - 1)
        output = output[:, index]
        if teacher_force:
            #teacher force only where non '_' characters are given
            if seq[i].eq(null_tensor.expand_as( seq[i])).sum().float() < 0.5:
                #print(seq[i].eq(null_tensor.expand_as(seq[i])).sum())
                output = seq[i]
                #print(self.dataset.input_vocab[seq[i][0]])
        outputs.append(output)

    outputs = torch.stack(outputs).long().t()  # back to (batch, seq)
    seq = seq.t()
    #print(output.size())
    #print(seq.size())
    #print(target.size())
    # print gold target, input sequence, and generated output side by side
    # ([1:-1] strips the GO/EOS sentinels)
    print(''.join(
        [self.dataset.input_vocab[i.item()] for i in target[1:-1]]), end='\t')
    print(''.join(
        [self.dataset.input_vocab[i.item()] for i in seq[0][1:-1]]), end='\t')
    print(''.join(
        [self.dataset.input_vocab[i.item()] for i in outputs[0][1:-1]]))
    return True