def forward(self, input_):
    idxs, inputs, targets = input_
    story, question = inputs
    story = self.__(story, 'story')
    question = self.__(question, 'question')

    batch_size, story_size = story.size()
    batch_size, question_size = question.size()

    story = self.__(self.embed(story), 'story_emb')
    question = self.__(self.embed(question), 'question_emb')

    # encode story and question (sequence-first for the recurrent encoders)
    story = story.transpose(1, 0)
    story, _ = self.__(self.encode_story(
        story, init_hidden(batch_size, self.encode_story)), 'C')

    question = question.transpose(1, 0)
    question, _ = self.__(self.encode_question(
        question, init_hidden(batch_size, self.encode_question)), 'Q')

    # initial control and memory states
    c, m, r = [], [], []
    c.append(Var(np.zeros((batch_size, 2 * self.hidden_size))))
    m.append(Var(np.zeros((batch_size, 2 * self.hidden_size))))

    qi = self.dropout(self.produce_qi(torch.cat([question[-1], m[-1]], dim=-1)))

    qattns, sattns, mattns = [], [], []
    # iterated control -> read -> write reasoning steps
    for i in range(config.HPCONFIG.reasoning_steps):
        ci, qattn = self.control(c[-1], qi, question, m[-1])
        ci = self.dropout(ci)
        ri, sattn = self.read(m[-1], ci, story)
        ri = self.dropout(ri)
        mi, mattn = self.write(m[-1], ri, ci, c, m)
        mi = self.dropout(mi)

        qi = self.dropout(self.produce_qi(torch.cat([qi, m[-1]], dim=-1)))

        c.append(ci)
        r.append(ri)
        m.append(mi)
        qattns.append(qattn)
        sattns.append(sattn)
        mattns.append(mattn)

    #projected_output = self.__( F.relu(self.project(torch.cat([qi, mi], dim=-1))), 'projected_output')
    return (self.__(F.log_softmax(self.answer(mi), dim=-1), 'return val'),
            (torch.stack(sattns), torch.stack(qattns), torch.stack(mattns)))
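# A minimal sketch of the control -> read -> write reasoning loop used in the
# forward above, with the three units abstracted as plain callables. Everything
# here (names, shapes, the callables themselves) is illustrative scaffolding,
# not the project's actual cells.
import torch

def reasoning_loop_sketch(question_summary, story, control_unit, read_unit, write_unit,
                          steps, batch_size, hidden_size):
    c = [torch.zeros(batch_size, hidden_size)]   # control states
    m = [torch.zeros(batch_size, hidden_size)]   # memory states
    for _ in range(steps):
        ci = control_unit(c[-1], question_summary)   # decide what to focus on in the question
        ri = read_unit(m[-1], ci, story)             # retrieve evidence from the story
        mi = write_unit(m[-1], ri, ci)               # fold the evidence into memory
        c.append(ci)
        m.append(mi)
    return m[-1]                                     # final memory feeds the answer layer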
def init_hidden(self, batch_size):
    hidden_state = Var(self.config, torch.zeros(1, batch_size, self.hidden_dim))
    if self.config.CONFIG.cuda:
        hidden_state = hidden_state.cuda()
    return hidden_state
def do_validate(self):
    self.eval()
    for j in tqdm(range(self.test_feed.num_batch), desc='Tester.{}'.format(self.name())):
        input_ = self.test_feed.next_batch()
        idxs, inputs, targets = input_
        sequence = inputs[0].transpose(0, 1)
        _, batch_size = sequence.size()

        state = self.initial_hidden(batch_size)
        loss, accuracy = Var(self.config, [0]), Var(self.config, [0])
        output = sequence[0]
        outputs = []
        ti = 0
        # free-running decode over the sequence; loss and accuracy accumulate per step
        for ti in range(1, sequence.size(0) - 1):
            output = self.forward(output, state)
            loss += self.loss_function(ti, output, input_)
            accuracy += self.accuracy_function(ti, output, input_)
            output, state = output
            output = output.max(1)[1]
            outputs.append(output)

        self.test_loss.cache(loss.item())
        if ti == 0:
            ti = 1  # guard against division by zero on very short sequences
        self.accuracy.cache(accuracy.item() / ti)
        #print('====', self.test_loss, self.accuracy)

    self.log.info('= {} =loss:{}'.format(self.epoch, self.test_loss.epoch_cache))
    self.log.info('- {} -accuracy:{}'.format(self.epoch, self.accuracy.epoch_cache))

    if self.best_model[0] < self.accuracy.epoch_cache.avg:
        self.log.info('beat best ..')
        last_acc = self.best_model[0]
        self.best_model = (self.accuracy.epoch_cache.avg, self.state_dict())
        self.save_best_model()
        if self.config.CONFIG.cuda:
            self.cuda()

    self.test_loss.clear_cache()
    self.accuracy.clear_cache()
    for m in self.metrics:
        m.write_to_file()

    if self.early_stopping:
        return self.loss_trend()
def forward(self, seq):
    seq = Var(seq)
    if seq.dim() == 1:
        seq = seq.unsqueeze(0)

    batch_size, seq_size = seq.size()
    seq_emb = torch.tanh(self.embed(seq))
    seq_emb = seq_emb.transpose(1, 0)
    pad_mask = (seq > 0).float()   # note: computed but not used below

    states, cell_state = self.encode(seq_emb)
    logits = self.classify(states[-1])
    return F.log_softmax(logits, dim=-1)
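# A self-contained sketch of the embed -> recurrent encode -> classify pattern
# in the forward above. Layer names (embed, encode, classify) mirror the
# snippet, but this module and its sizes are hypothetical, not the project's model.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SequenceClassifierSketch(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, num_classes):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.encode = nn.LSTM(emb_dim, hidden_dim)       # sequence-first by default
        self.classify = nn.Linear(hidden_dim, num_classes)

    def forward(self, seq):
        if seq.dim() == 1:
            seq = seq.unsqueeze(0)                            # (1, seq_len)
        emb = torch.tanh(self.embed(seq)).transpose(1, 0)     # (seq_len, batch, emb)
        states, _ = self.encode(emb)
        return F.log_softmax(self.classify(states[-1]), dim=-1)

# e.g. SequenceClassifierSketch(1000, 64, 128, 5)(torch.randint(1, 1000, (4, 12)))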
def do_validate(self):
    self.eval()
    if self.test_feed.num_batch > 0:
        for j in tqdm(range(self.test_feed.num_batch), desc='Tester.{}'.format(self.name())):
            input_ = self.test_feed.next_batch()
            idxs, (gender, sequence), targets = input_
            sequence = sequence.transpose(0, 1)
            seq_size, batch_size = sequence.size()

            state = self.initial_hidden(batch_size)
            loss, accuracy = Var(self.config, [0]), Var(self.config, [0])
            output = sequence[0]
            outputs = []
            ti = 0
            positions = LongVar(self.config, np.linspace(0, 1, seq_size))
            for ti in range(1, sequence.size(0) - 1):
                output = self.forward(gender, positions[ti], output, state)
                loss += self.loss_function(ti, output, input_)
                accuracy += self.accuracy_function(ti, output, input_)
                output, state = output
                output = output.max(1)[1]
                outputs.append(output)

            self.test_loss.append(loss.item())
            if ti == 0:
                ti = 1
            self.accuracy.append(accuracy.item() / ti)
            #print('====', self.test_loss, self.accuracy)

    self.log.info('= {} =loss:{}'.format(self.epoch, self.test_loss))
    self.log.info('- {} -accuracy:{}'.format(self.epoch, self.accuracy))

    if len(self.best_model_criteria) > 1 and self.best_model[0] > self.best_model_criteria[-1]:
        self.log.info('beat best ..')
        self.best_model = (self.best_model_criteria[-1], self.cpu().state_dict())
        self.save_best_model()
        if self.config.CONFIG.cuda:
            self.cuda()

    for m in self.metrics:
        m.write_to_file()

    if self.early_stopping:
        return self.loss_trend()
def initial_input(self, input_, encoder_output):
    story_states = self.__(encoder_output, 'encoder_output')
    seq_len, batch_size, hidden_size = story_states.size()

    decoder_input = self.__(LongVar([self.initial_decoder_input]).expand(batch_size), 'decoder_input')
    hidden = self.__(story_states[-1], 'hidden')
    context, _ = self.__(story_states.max(0), 'context')
    coverage = self.__(Var(torch.zeros(batch_size, seq_len)), 'coverage')

    return decoder_input, hidden, context, coverage
def waccuracy(output, batch, config, *args, **kwargs):
    indices, (sequence, ), (label, ) = batch

    index = label
    src = Var(config, torch.ones(label.size()))

    acc_nomin = Var(config, torch.zeros(output.size(1)))
    acc_denom = Var(config, torch.ones(output.size(1)))

    # per-class bookkeeping: denominator counts gold labels,
    # numerator counts correct predictions for each class
    acc_denom.scatter_add_(0, index, (label == label).float())
    acc_nomin.scatter_add_(0, index, (label == output.max(1)[1]).float())

    accuracy = acc_nomin / acc_denom
    #pdb.set_trace()
    return accuracy.mean()
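# A minimal, self-contained sketch of the per-class accuracy idea behind
# waccuracy, written against plain torch tensors (the project's Var/config
# wrappers are omitted). `logits`, `labels`, and `num_classes` are illustrative
# names, not part of the codebase.
import torch

def per_class_accuracy_sketch(logits, labels, num_classes):
    # labels: LongTensor of gold class ids, shape (batch,)
    preds = logits.max(1)[1]
    correct = torch.zeros(num_classes).scatter_add_(
        0, labels, (preds == labels).float())
    totals = torch.zeros(num_classes).scatter_add_(
        0, labels, torch.ones_like(labels, dtype=torch.float))
    # average of per-class accuracies, so frequent classes do not dominate
    return (correct / totals.clamp(min=1)).mean()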
def process_output(decoding_index, output, batch, *args, **kwargs):
    indices, (story, question), (answer, extvocab_story, target, extvocab_size) = batch
    pgen, vocab_dist, hidden, context, attn_dist, coverage = output

    # pointer-generator mixing: pgen weights the generator distribution,
    # (1 - pgen) weights the copy (attention) distribution
    vocab_dist, attn_dist = pgen * vocab_dist, (1 - pgen) * attn_dist

    batch_size, vocab_size = vocab_dist.size()
    output = vocab_dist
    if extvocab_size:
        # widen the vocabulary axis to cover out-of-vocabulary copy targets,
        # then add the copy mass at the source tokens' extended-vocab ids
        zeros = Var(torch.zeros(batch_size, extvocab_size))
        vocab_dist = torch.cat([vocab_dist, zeros], dim=-1)
        output = vocab_dist.scatter_add_(1, extvocab_story, attn_dist)

    return output
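# Hedged sketch of the same pointer-generator mixing step on plain tensors.
# All tensor names here (src_ext_ids, ext_size, ...) are hypothetical.
import torch

def mix_copy_distribution_sketch(p_gen, vocab_dist, attn_dist, src_ext_ids, ext_size):
    vocab_dist = p_gen * vocab_dist            # (batch, vocab)
    attn_dist = (1.0 - p_gen) * attn_dist      # (batch, src_len)
    if ext_size:
        # zeros for tokens that only exist in the extended (copy) vocabulary
        zeros = vocab_dist.new_zeros(vocab_dist.size(0), ext_size)
        vocab_dist = torch.cat([vocab_dist, zeros], dim=-1)
        # scatter copy probability onto the source tokens' extended ids
        vocab_dist = vocab_dist.scatter_add_(1, src_ext_ids, attn_dist)
    return vocab_dist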
def f1score(output, input_, *args, **kwargs):
    indices, (seq, ), (target, ) = input_
    output, attn = output
    batch_size, class_size = output.size()

    tp, tn, fp, fn = Var([0]), Var([0]), Var([0]), Var([0])
    p, r, f1 = Var([0]), Var([0]), Var([0])

    i = output
    t = target
    i = i.max(dim=1)[1]

    log.debug('output:{}'.format(pformat(i)))
    log.debug('target:{}'.format(pformat(t)))

    # assumes binary 0/1 labels: prediction > target is a false positive,
    # prediction < target is a false negative
    i_ = i
    t_ = t
    tp_ = (i_ * t_).sum().float()
    fp_ = (i_ > t_).sum().float()
    fn_ = (i_ < t_).sum().float()

    i_ = i == 0
    t_ = t == 0
    tn_ = (i_ * t_).sum().float()

    tp += tp_
    tn += tn_
    fp += fp_
    fn += fn_

    log.debug('tp_: {}\n fp_:{} \n fn_: {}\n tn_: {}'.format(tp_, fp_, fn_, tn_))

    if tp_.data.item() > 0:
        p_ = tp_ / (tp_ + fp_)
        r_ = tp_ / (tp_ + fn_)
        f1 += 2 * p_ * r_ / (p_ + r_)
        p += p_
        r += r_

    return (tp, fn, fp, tn), p, r, f1
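# Small sketch of the binary F1 bookkeeping done in f1score above, assuming
# 0/1 predictions and targets; tensor names are illustrative only.
import torch

def binary_f1_sketch(preds, targets):
    preds, targets = preds.float(), targets.float()
    tp = (preds * targets).sum()
    fp = (preds * (1 - targets)).sum()
    fn = ((1 - preds) * targets).sum()
    tn = ((1 - preds) * (1 - targets)).sum()
    precision = tp / (tp + fp).clamp(min=1e-8)
    recall = tp / (tp + fn).clamp(min=1e-8)
    f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-8)
    return (tp, fn, fp, tn), precision, recall, f1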
def forward(self, encoder_output, decoder_input, input_):
    context_states, question_states = self.__(encoder_output, 'encoder_output')
    seq_len, batch_size, hidden_size = context_states.size()
    dropout = self.dropout

    decoder_input, hidden = decoder_input
    decoder_input = self.__(self.embed(decoder_input), 'decoder_input')
    if not isinstance(hidden, torch.Tensor):
        hidden = context_states[-1]

    hidden = self.__(self.decode(decoder_input, hidden), 'decoder_output')
    hidden = torch.tanh(hidden)

    # combine question and current hidden state and project
    query = self.__(torch.cat([question_states[-1], hidden], dim=-1), 'query')
    query = self.__(self.project_query(query), 'projected query')
    query = torch.tanh(query)

    # candidate states to point at: the context states plus a zero sentinel position
    sentinel = self.__(Var(torch.zeros(batch_size, context_states.size(2))), 'sentinel')
    pointer_predistribution = self.__(
        torch.cat([context_states, sentinel.unsqueeze(0)], dim=0),
        'pointer_predistribution').transpose(0, 1)

    # bilinear scoring: query x attention matrix x candidate states
    attn = self.__(self.attn.unsqueeze(0), 'attn')
    attn = self.__(torch.bmm(query.unsqueeze(1),
                             attn.expand(batch_size, *self.attn.size())), 'attn')
    attn = self.__(torch.bmm(attn, pointer_predistribution.transpose(1, 2)), 'attn').squeeze(1)

    ret = self.__(F.log_softmax(attn, dim=-1), 'ret')
    return ret, hidden
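# Compact sketch of the bilinear pointer scoring used above: a query vector is
# multiplied by a learned matrix and then by the candidate states, giving one
# logit per position. Shapes and names here are illustrative, not the project's.
import torch
import torch.nn.functional as F

def bilinear_pointer_scores_sketch(query, candidates, W):
    # query: (batch, q_dim); candidates: (batch, num_positions, c_dim); W: (q_dim, c_dim)
    scores = torch.bmm(query.unsqueeze(1), W.expand(query.size(0), *W.size()))
    scores = torch.bmm(scores, candidates.transpose(1, 2)).squeeze(1)
    return F.log_softmax(scores, dim=-1)    # (batch, num_positions)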
def experiment(config, ROOT_DIR, model, VOCAB, LABELS,
               datapoints=[[], [], []], eons=1000, epochs=20, checkpoint=1):
    try:
        name = SELF_NAME
        _batchop = partial(batchop, VOCAB=VOCAB, LABELS=LABELS)
        train_feed = DataFeed(name, datapoints[0], batchop=_batchop,
                              batch_size=config.HPCONFIG.batch_size)
        test_feed = DataFeed(name, datapoints[1], batchop=_batchop,
                             batch_size=config.HPCONFIG.batch_size)
        predictor_feed = DataFeed(name, datapoints[2], batchop=_batchop, batch_size=1)

        # weight each label inversely to its relative frequency
        max_freq = max(LABELS.freq_dict[i] for i in LABELS.index2word)
        loss_weight = [1 / (LABELS.freq_dict[i] / max_freq) for i in LABELS.index2word]
        print(list((l, w) for l, w in zip(LABELS.index2word, loss_weight)))

        loss_weight = Var(loss_weight)
        loss_ = partial(loss, loss_function=nn.NLLLoss(loss_weight))

        trainer = Trainer(name=name,
                          model=model,
                          optimizer=optim.SGD(model.parameters(),
                                              lr=config.HPCONFIG.OPTIM.lr,
                                              momentum=config.HPCONFIG.OPTIM.momentum),
                          loss_function=loss_,
                          accuracy_function=waccuracy,
                          f1score_function=f1score,
                          checkpoint=checkpoint,
                          epochs=epochs,
                          directory=ROOT_DIR,
                          feeder=Feeder(train_feed, test_feed))

        predictor = Predictor(model=model.clone(),
                              feed=predictor_feed,
                              repr_function=partial(test_repr_function,
                                                    VOCAB=VOCAB, LABELS=LABELS))

        for e in range(eons):
            if not trainer.train():
                raise Exception

            predictor.model.load_state_dict(trainer.best_model[1])

            dump = open('{}/results/eon_{}.csv'.format(ROOT_DIR, e), 'w')
            log.info('on {}th eon'.format(e))
            results = ListTable()
            for ri in tqdm(range(predictor_feed.num_batch)):
                output, _results = predictor.predict(ri)
                results.extend(_results)
            dump.write(repr(results))
            dump.close()

    except KeyboardInterrupt:
        return locals()
    except Exception:
        log.exception('####################')
        return locals()
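# Sketch of the inverse-frequency loss-weighting scheme used in experiment():
# each label's weight is the reciprocal of its frequency relative to the most
# common label. `freq_dict` and `labels` are hypothetical stand-ins for the
# project's LABELS vocabulary object.
import torch
import torch.nn as nn

def inverse_frequency_weights_sketch(freq_dict, labels):
    max_freq = max(freq_dict[l] for l in labels)
    weights = torch.tensor([max_freq / freq_dict[l] for l in labels], dtype=torch.float)
    # rare labels get proportionally larger weights in the loss
    return nn.NLLLoss(weight=weights)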