Example #1
    def forward(self, input_):
        idxs, inputs, targets = input_
        story, question = inputs
        story = self.__(story, 'story')
        question = self.__(question, 'question')

        batch_size, story_size = story.size()
        batch_size, question_size = question.size()

        story = self.__(self.embed(story), 'story_emb')
        question = self.__(self.embed(question), 'question_emb')

        # (batch, seq, emb) -> (seq, batch, emb) for the recurrent encoders
        story = story.transpose(1, 0)
        story, _ = self.__(self.encode_story(
            story,
            init_hidden(batch_size, self.encode_story)), 'C'
        )

        question = question.transpose(1, 0)
        question, _ = self.__(self.encode_question(
            question,
            init_hidden(batch_size, self.encode_question)), 'Q'
        )

        # reasoning state: histories of control (c), memory (m), and read (r)
        c, m, r = [], [], []
        c.append(Var(np.zeros((batch_size, 2 * self.hidden_size))))
        m.append(Var(np.zeros((batch_size, 2 * self.hidden_size))))
        qi = self.dropout(self.produce_qi(torch.cat([question[-1], m[-1]], dim=-1)))

        qattns, sattns, mattns = [], [], []
        
        for i in range(config.HPCONFIG.reasoning_steps):
            # control step: attend over the question
            ci, qattn = self.control(c[-1], qi, question, m[-1])
            ci = self.dropout(ci)

            # read step: attend over the story, conditioned on the control state
            ri, sattn = self.read(m[-1], ci, story)
            ri = self.dropout(ri)

            # write step: fold the read vector into memory
            mi, mattn = self.write(m[-1], ri, ci, c, m)
            mi = self.dropout(mi)

            qi = self.dropout(self.produce_qi(torch.cat([qi, m[-1]], dim=-1)))

            c.append(ci)
            r.append(ri)
            m.append(mi)

            qattns.append(qattn)
            sattns.append(sattn)
            mattns.append(mattn)
            
        #projected_output = self.__( F.relu(self.project(torch.cat([qi, mi], dim=-1))), 'projected_output')
        return (self.__( F.log_softmax(self.answer(mi), dim=-1), 'return val'),
                (
                    torch.stack(sattns),
                    torch.stack(qattns),
                    torch.stack(mattns)
                )
        )
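Example #1 runs a fixed number of MAC-style reasoning steps: a control cell attends over the question, a read cell attends over the story, and a write cell updates memory. The sketch below is a hypothetical reduction of that recurrence to stand-in linear cells; ToyReasoner and all of its names are invented for illustration and are not the module's real API.

import torch
import torch.nn as nn

class ToyReasoner(nn.Module):
    # Hypothetical reduction of the control/read/write loop above.
    def __init__(self, hidden, steps=3):
        super().__init__()
        self.steps = steps
        self.control = nn.Linear(2 * hidden, hidden)
        self.read = nn.Linear(2 * hidden, hidden)
        self.write = nn.Linear(2 * hidden, hidden)

    def forward(self, question_vec, story_vec):
        c = torch.zeros_like(question_vec)  # control state
        m = torch.zeros_like(story_vec)     # memory state
        for _ in range(self.steps):
            c = torch.tanh(self.control(torch.cat([c, question_vec], dim=-1)))
            r = torch.tanh(self.read(torch.cat([c, story_vec], dim=-1)))
            m = torch.tanh(self.write(torch.cat([m, r], dim=-1)))
        return m

m = ToyReasoner(hidden=8)(torch.randn(2, 8), torch.randn(2, 8))  # (batch=2, hidden=8)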
Example #2
    def init_hidden(self, batch_size):
        hidden_state = Var(self.config,
                           torch.zeros(1, batch_size, self.hidden_dim))

        if self.config.CONFIG.cuda:
            hidden_state = hidden_state.cuda()

        return hidden_state
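Every example here wraps tensors in a Var helper whose definition is not shown (and whose signature varies between snippets: some call Var(data), others Var(config, data)). A minimal sketch, assuming the two-argument form simply builds a float tensor and honours the cuda flag:

import torch

def Var(config, data):
    # hypothetical stand-in for the examples' real Var helper
    tensor = torch.as_tensor(data, dtype=torch.float)
    if config.CONFIG.cuda:
        tensor = tensor.cuda()
    return tensor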
Example #3
    def do_validate(self):
        self.eval()
        for j in tqdm(range(self.test_feed.num_batch),
                      desc='Tester.{}'.format(self.name())):
            input_ = self.test_feed.next_batch()
            idxs, inputs, targets = input_
            sequence = inputs[0].transpose(0, 1)
            _, batch_size = sequence.size()

            state = self.initial_hidden(batch_size)
            loss, accuracy = Var(self.config, [0]), Var(self.config, [0])
            output = sequence[0]  # seed decoding with the first token
            outputs = []
            ti = 0
            for ti in range(1, sequence.size(0) - 1):
                # forward returns a (logits, state) tuple; the loss and
                # accuracy callbacks receive it as-is and unpack it themselves
                output = self.forward(output, state)
                loss += self.loss_function(ti, output, input_)
                accuracy += self.accuracy_function(ti, output, input_)
                output, state = output
                output = output.max(1)[1]  # greedy prediction feeds the next step
                outputs.append(output)

            self.test_loss.cache(loss.item())
            if ti == 0: ti = 1
            self.accuracy.cache(accuracy.item() / ti)
            #print('====', self.test_loss, self.accuracy)

        self.log.info('= {} =loss:{}'.format(self.epoch,
                                             self.test_loss.epoch_cache))
        self.log.info('- {} -accuracy:{}'.format(self.epoch,
                                                 self.accuracy.epoch_cache))

        if self.best_model[0] < self.accuracy.epoch_cache.avg:
            self.log.info('beat best ..')
            self.best_model = (self.accuracy.epoch_cache.avg,
                               self.state_dict())

            self.save_best_model()

            if self.config.CONFIG.cuda:
                self.cuda()

        self.test_loss.clear_cache()
        self.accuracy.clear_cache()

        for m in self.metrics:
            m.write_to_file()

        if self.early_stopping:
            return self.loss_trend()
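do_validate accumulates loss and accuracy through callbacks invoked as f(ti, output, input_), where output is the raw (logits, state) tuple from forward. A hedged sketch of a compatible loss callback; the target alignment (step ti + 1) and the assumption that forward emits log-probabilities are guesses, not taken from the source:

import torch.nn as nn

criterion = nn.NLLLoss()

def loss_function(ti, output, input_):
    # hypothetical callback matching the call site above
    idxs, inputs, targets = input_
    logits, _state = output
    target_seq = inputs[0].transpose(0, 1)  # (seq_len, batch), as in do_validate
    return criterion(logits, target_seq[ti + 1])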
Example #4
    def forward(self, seq):
        seq = Var(seq)
        if seq.dim() == 1: seq = seq.unsqueeze(0)

        batch_size, seq_size = seq.size()
        seq_emb = torch.tanh(self.embed(seq))  # torch.tanh; F.tanh is deprecated
        seq_emb = seq_emb.transpose(1, 0)
        pad_mask = (seq > 0).float()  # 1.0 at non-pad positions (unused below)

        states, cell_state = self.encode(seq_emb)

        logits = self.classify(states[-1])

        return F.log_softmax(logits, dim=-1)
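pad_mask in Example #4 marks non-padding positions (token index 0 is assumed to be the pad symbol); note it is computed but never applied in this forward. What it produces on a toy batch:

import torch

seq = torch.tensor([[4, 7, 2, 0, 0]])  # 0 is the pad index
pad_mask = (seq > 0).float()
print(pad_mask)  # tensor([[1., 1., 1., 0., 0.]])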
Example #5
    def do_validate(self):
        self.eval()
        if self.test_feed.num_batch > 0:
            for j in tqdm(range(self.test_feed.num_batch), desc='Tester.{}'.format(self.name())):
                input_ = self.test_feed.next_batch()
                idxs, (gender, sequence), targets = input_
                sequence = sequence.transpose(0,1)
                seq_size, batch_size = sequence.size()

                state = self.initial_hidden(batch_size)
                loss, accuracy = Var(self.config, [0]), Var(self.config, [0])
                output = sequence[0]  # seed decoding with the first token
                outputs = []
                ti = 0
                # NOTE: np.linspace yields floats in [0, 1]; LongVar truncates
                # every entry but the last to 0
                positions = LongVar(self.config, np.linspace(0, 1, seq_size))
                for ti in range(1, sequence.size(0) - 1):
                    # forward returns a (logits, state) tuple; the loss and
                    # accuracy callbacks receive it as-is and unpack it themselves
                    output = self.forward(gender, positions[ti], output, state)
                    loss += self.loss_function(ti, output, input_)
                    accuracy += self.accuracy_function(ti, output, input_)
                    output, state = output
                    output = output.max(1)[1]  # greedy prediction feeds the next step
                    outputs.append(output)

                self.test_loss.append(loss.item())
                if ti == 0: ti = 1
                self.accuracy.append(accuracy.item()/ti)
                #print('====', self.test_loss, self.accuracy)

            self.log.info('= {} =loss:{}'.format(self.epoch, self.test_loss))
            self.log.info('- {} -accuracy:{}'.format(self.epoch, self.accuracy))

            
        if len(self.best_model_criteria) > 1 and self.best_model[0] > self.best_model_criteria[-1]:
            self.log.info('beat best ..')
            self.best_model = (self.best_model_criteria[-1],
                               self.cpu().state_dict())

            self.save_best_model()
            
            if self.config.CONFIG.cuda:
                self.cuda()

        
        for m in self.metrics:
            m.write_to_file()
            
        if self.early_stopping:
            return self.loss_trend()
Example #6
    def initial_input(self, input_, encoder_output):
        story_states = self.__(encoder_output, 'encoder_output')
        seq_len, batch_size, hidden_size = story_states.size()
        decoder_input = self.__(LongVar([self.initial_decoder_input]).expand(batch_size), 'decoder_input')
        hidden = self.__(story_states[-1], 'hidden')
        context, _ = self.__(story_states.max(0), 'context')
        coverage = self.__(Var(torch.zeros(batch_size, seq_len)), 'coverage')

        return decoder_input, hidden, context, coverage
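Example #6 seeds a decoder from the encoder's output: the last state becomes the initial hidden, a max-pool over time becomes the context, and coverage starts at zero with one slot per source position. A shape check with random tensors:

import torch

seq_len, batch_size, hidden_size = 5, 2, 8
story_states = torch.randn(seq_len, batch_size, hidden_size)

hidden = story_states[-1]                    # (batch, hidden)
context, _ = story_states.max(0)             # max over time, (batch, hidden)
coverage = torch.zeros(batch_size, seq_len)  # (batch, seq_len)
print(hidden.shape, context.shape, coverage.shape)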
Example #7
def waccuracy(output, batch, config, *args, **kwargs):
    indices, (sequence, ), label = batch  # label arrives as a bare tensor of class ids

    index = label
    # per-class counts via scatter_add_; the denominator starts at ones so
    # classes absent from the batch do not divide by zero
    acc_nomin = Var(config, torch.zeros(output.size(1)))
    acc_denom = Var(config, torch.ones(output.size(1)))

    acc_denom.scatter_add_(0, index, (label == label).float())
    acc_nomin.scatter_add_(0, index, (label == output.max(1)[1]).float())

    accuracy = acc_nomin / acc_denom

    #pdb.set_trace()
    return accuracy.mean()
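waccuracy is a class-weighted (macro) accuracy: scatter_add_ buckets correct predictions per class, and the ones-initialised denominator adds +1 smoothing to every class count. A toy run:

import torch

output = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])  # 3 samples, 2 classes
label = torch.tensor([0, 1, 1])
pred = output.max(1)[1]  # tensor([0, 1, 0])

nomin = torch.zeros(2).scatter_add_(0, label, (label == pred).float())
denom = torch.ones(2).scatter_add_(0, label, torch.ones(3))
print(nomin / denom)           # tensor([0.5000, 0.3333])
print((nomin / denom).mean())  # tensor(0.4167)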
Example #8
def process_output(decoding_index, output, batch,  *args, **kwargs):
    indices, (story, question), (answer, extvocab_story, target, extvocab_size) = batch
    pgen, vocab_dist, hidden, context, attn_dist, coverage = output

    vocab_dist, attn_dist = pgen * vocab_dist, (1-pgen) * attn_dist
    batch_size, vocab_size = vocab_dist.size()
    output  = vocab_dist
    if extvocab_size:
        zeros      = Var( torch.zeros(batch_size, extvocab_size) )
        vocab_dist = torch.cat( [vocab_dist, zeros], dim=-1 )
        output     = vocab_dist.scatter_add_(1, extvocab_story, attn_dist)

    return output
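process_output is the pointer-generator mixture: pgen scales the fixed-vocabulary distribution, (1 - pgen) scales the copy (attention) distribution, and scatter_add_ folds the copy mass onto the story's extended-vocabulary ids. With made-up numbers:

import torch

pgen = torch.tensor([[0.7]])             # generation probability
vocab_dist = torch.tensor([[0.5, 0.5]])  # fixed vocabulary of 2 tokens
attn_dist = torch.tensor([[0.2, 0.8]])   # attention over a 2-token story
extvocab_story = torch.tensor([[1, 2]])  # story ids in the extended vocab
extvocab_size = 1                        # one out-of-vocabulary slot

vocab_dist = pgen * vocab_dist
attn_dist = (1 - pgen) * attn_dist
extended = torch.cat([vocab_dist, torch.zeros(1, extvocab_size)], dim=-1)
print(extended.scatter_add_(1, extvocab_story, attn_dist))
# tensor([[0.3500, 0.4100, 0.2400]]) -- still sums to 1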
Example #9
def f1score(output, input_, *args, **kwargs):

    indices, (seq, ), (target, ) = input_
    output, attn = output
    batch_size, class_size = output.size()

    tp, tn, fp, fn = Var([0]), Var([0]), Var([0]), Var([0])
    p, r, f1 = Var([0]), Var([0]), Var([0])

    i = output
    t = target
    i = i.max(dim=1)[1]
    log.debug('output:{}'.format(pformat(i)))
    log.debug('target:{}'.format(pformat(t)))
    i_ = i
    t_ = t
    tp_ = (i_ * t_).sum().float()
    fp_ = (i_ > t_).sum().float()
    fn_ = (i_ < t_).sum().float()

    i_ = i == 0
    t_ = t == 0
    tn_ = (i_ * t_).sum().float()

    tp += tp_
    tn += tn_
    fp += fp_
    fn += fn_

    log.debug('tp_: {}\n fp_:{} \n fn_: {}\n tn_: {}'.format(
        tp_, fp_, fn_, tn_))

    if tp_.data.item() > 0:
        p_ = tp_ / (tp_ + fp_)
        r_ = tp_ / (tp_ + fn_)
        f1 += 2 * p_ * r_ / (p_ + r_)
        p += p_
        r += r_

    return (tp, fn, fp, tn), p, r, f1
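f1score assumes binary labels in {0, 1}: the element-wise product counts true positives, and the > / < comparisons count false positives and negatives. The core arithmetic on a toy prediction/target pair:

import torch

pred = torch.tensor([1, 1, 0, 0])
target = torch.tensor([1, 0, 1, 0])

tp = (pred * target).sum().float()                # predicted 1, was 1
fp = (pred > target).sum().float()                # predicted 1, was 0
fn = (pred < target).sum().float()                # predicted 0, was 1
tn = ((pred == 0) & (target == 0)).sum().float()  # predicted 0, was 0

precision, recall = tp / (tp + fp), tp / (tp + fn)
print(2 * precision * recall / (precision + recall))  # tensor(0.5000)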
Example #10
    def forward(self, encoder_output, decoder_input, input_):
        context_states, question_states = self.__(encoder_output,
                                                  'encoder_output')
        seq_len, batch_size, hidden_size = context_states.size()
        dropout = self.dropout
        decoder_input, hidden = decoder_input

        decoder_input = self.__(self.embed(decoder_input), 'decoder_input')
        if not isinstance(hidden, torch.Tensor):
            hidden = context_states[-1]

        hidden = self.__(self.decode(decoder_input, hidden), 'decoder_output')
        hidden = torch.tanh(hidden)  # torch.tanh; F.tanh is deprecated
        # combine question and current hidden state and project
        query = self.__(torch.cat([question_states[-1], hidden], dim=-1),
                        'query')
        query = self.__(self.project_query(query), 'projected query')
        query = torch.tanh(query)

        sentinel = self.__(
            Var(torch.zeros(batch_size, context_states.size(2))), 'sentinel')
        pointer_predistribution = self.__(
            torch.cat([
                context_states,
                sentinel.unsqueeze(0),
            ], dim=0), 'pointer_predistribution').transpose(0, 1)

        attn = self.__(self.attn.unsqueeze(0), 'attn')
        attn = self.__(
            torch.bmm(query.unsqueeze(1),
                      attn.expand(batch_size, *self.attn.size())), 'attn')
        attn = self.__(
            torch.bmm(attn, pointer_predistribution.transpose(1, 2)),
            'attn').squeeze(1)

        ret = self.__(F.log_softmax(attn, dim=-1), 'ret')
        return ret, hidden
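The pointer attention above is a bilinear score between the projected query and each encoder state, with a zero sentinel appended so the model can point at "nothing". A shape-level sketch with random tensors (all sizes are made up):

import torch

batch_size, seq_len, hidden = 2, 5, 8
query = torch.randn(batch_size, hidden)     # projected query
attn_weights = torch.randn(hidden, hidden)  # learned bilinear matrix
context_states = torch.randn(seq_len, batch_size, hidden)
sentinel = torch.zeros(1, batch_size, hidden)

candidates = torch.cat([context_states, sentinel], dim=0).transpose(0, 1)  # (B, L+1, H)
scores = torch.bmm(query.unsqueeze(1),
                   attn_weights.unsqueeze(0).expand(batch_size, hidden, hidden))
scores = torch.bmm(scores, candidates.transpose(1, 2)).squeeze(1)  # (B, L+1)
print(scores.shape)  # torch.Size([2, 6])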
def experiment(config,
               ROOT_DIR,
               model,
               VOCAB,
               LABELS,
               datapoints=[[], [], []],
               eons=1000,
               epochs=20,
               checkpoint=1):
    try:
        name = SELF_NAME
        _batchop = partial(batchop, VOCAB=VOCAB, LABELS=LABELS)
        train_feed = DataFeed(name,
                              datapoints[0],
                              batchop=_batchop,
                              batch_size=config.HPCONFIG.batch_size)
        test_feed = DataFeed(name,
                             datapoints[1],
                             batchop=_batchop,
                             batch_size=config.HPCONFIG.batch_size)
        predictor_feed = DataFeed(name,
                                  datapoints[2],
                                  batchop=_batchop,
                                  batch_size=1)

        max_freq = max(LABELS.freq_dict[i] for i in LABELS.index2word)
        loss_weight = [
            1 / (LABELS.freq_dict[i] / max_freq) for i in LABELS.index2word
        ]
        print(list((l, w) for l, w in zip(LABELS.index2word, loss_weight)))
        loss_weight = Var(loss_weight)

        loss_ = partial(loss, loss_function=nn.NLLLoss(loss_weight))
        trainer = Trainer(name=name,
                          model=model,
                          optimizer=optim.SGD(
                              model.parameters(),
                              lr=config.HPCONFIG.OPTIM.lr,
                              momentum=config.HPCONFIG.OPTIM.momentum),
                          loss_function=loss_,
                          accuracy_function=waccuracy,
                          f1score_function=f1score,
                          checkpoint=checkpoint,
                          epochs=epochs,
                          directory=ROOT_DIR,
                          feeder=Feeder(train_feed, test_feed))

        predictor = Predictor(model=model.clone(),
                              feed=predictor_feed,
                              repr_function=partial(test_repr_function,
                                                    VOCAB=VOCAB,
                                                    LABELS=LABELS))

        for e in range(eons):

            if not trainer.train():
                raise Exception

            predictor.model.load_state_dict(trainer.best_model[1])

            with open('{}/results/eon_{}.csv'.format(ROOT_DIR, e), 'w') as dump:
                log.info('on {}th eon'.format(e))
                results = ListTable()
                for ri in tqdm(range(predictor_feed.num_batch)):
                    output, _results = predictor.predict(ri)
                    results.extend(_results)
                dump.write(repr(results))

    except KeyboardInterrupt:
        return locals()
    except Exception:
        log.exception('####################')
        return locals()
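experiment weights the NLL loss inversely to label frequency so that rare labels are not drowned out. The same formula on hypothetical counts:

freq_dict = {'O': 90, 'B': 8, 'I': 2}  # made-up label frequencies
max_freq = max(freq_dict.values())
loss_weight = [1 / (freq_dict[label] / max_freq) for label in freq_dict]
print(loss_weight)  # [1.0, 11.25, 45.0]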