Example #1
    def sample(self, batch_size, seq_len, data=None):
        # sample one batch of sequences
        """
        data is an already-generated prefix sequence
        """
        # if no data is given, sample from scratch
        if data is None:
            sample_batch = torch.zeros(batch_size, seq_len).type(torch.LongTensor).cuda()
            inp = Variable(torch.zeros(batch_size, 1).type(torch.LongTensor)).cuda()
            h = self.init_hidden(batch_size)
            for i in range(seq_len):
                output, h = self.forward(inp, h)
                output = torch.multinomial(output.exp().squeeze(), 1)
                sample_batch[:, i] = output.data
                inp = output
            return sample_batch
        # otherwise, continue from the given prefix
        else:
            sample_batch = torch.zeros(batch_size, seq_len).type(torch.LongTensor).cuda()
            inp = Variable(torch.zeros(batch_size, 1).type(torch.LongTensor)).cuda()
            h = self.init_hidden(batch_size)
            for i in range(seq_len):
                if i < data.size(1):
                    inp = data[:, i].unsqueeze(1)
                else:
                    inp = sample_batch[:, i - 1].unsqueeze(1)
                output, h = self.forward(inp, h)
                output = torch.multinomial(output.exp().squeeze(), 1)
                sample_batch[:, i] = output.data
            return sample_batch
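The pattern above, exponentiating the model's log-probabilities and drawing one index per row with torch.multinomial, can be tried in isolation. A minimal sketch, assuming the model emits log-probabilities of shape (batch, vocab); the sizes are illustrative:

import torch

batch_size, vocab_size = 4, 10  # illustrative sizes
log_probs = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=-1)
next_tokens = torch.multinomial(log_probs.exp(), num_samples=1)  # (batch_size, 1) indices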
Example #2
def predict_fn(input_data, model):
    logger.info('Generating text based on input parameters.')
    corpus = model['corpus']
    model = model['model']

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info('Current device: {}'.format(device))
    torch.manual_seed(input_data['seed'])
    ntokens = len(corpus.dictionary)
    input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
    hidden = model.init_hidden(1)

    logger.info('Generating {} words.'.format(input_data['words']))
    result = []
    with torch.no_grad():  # no tracking history
        for i in range(input_data['words']):
            output, hidden = model(input, hidden)
            word_weights = output.squeeze().div(input_data['temperature']).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            input.fill_(word_idx)
            word = corpus.dictionary.idx2word[word_idx]
            word = word if type(word) == str else word.decode()
            if word == '<eos>':
                word = '\n'
            elif i % 12 == 11:
                word = word + '\n'
            else:
                word = word + ' '
            result.append(word)
    return ''.join(result)
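Example #2 divides the output by a temperature before exponentiating: temperatures below 1 sharpen the distribution, above 1 flatten it. A standalone sketch of that step (the logits are random placeholders):

import torch

logits = torch.randn(10)  # one decoder step over a 10-word vocabulary (placeholder)
temperature = 0.8
word_weights = logits.div(temperature).exp()
word_idx = torch.multinomial(word_weights, 1)[0]  # weights need not sum to 1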
Example #3
    def sample(self, sample_shape=torch.Size()):
        sample_shape = self._extended_shape(sample_shape)
        param_shape = sample_shape + torch.Size((self._num_events,))
        probs = self.probs.expand(param_shape)
        probs_2d = probs.contiguous().view(-1, self._num_events)
        sample_2d = torch.multinomial(probs_2d, 1, True)
        return sample_2d.contiguous().view(sample_shape)
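The method above is essentially the implementation behind torch.distributions.Categorical; sampling through the distribution object or calling torch.multinomial directly draws from the same distribution:

import torch
from torch.distributions import Categorical

probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.5, 0.25, 0.25]])
via_dist = Categorical(probs).sample()                           # shape (2,)
via_multinomial = torch.multinomial(probs, 1, True).squeeze(-1)  # same distribution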
Example #4
    def forward(self, fc_feats, att_feats, seq):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        outputs = []

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
                    sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                    sample_mask = sample_prob < self.ss_prob
                    if sample_mask.sum() == 0:
                        it = seq[:, i-1].clone()
                    else:
                        sample_ind = sample_mask.nonzero().view(-1)
                        it = seq[:, i-1].data.clone()
                        #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                        #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                        prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
                        it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                        it = Variable(it, requires_grad=False)
                else:
                    it = seq[:, i-1].clone()
                # break if all the sequences end
                if i >= 2 and seq[:, i-1].data.sum() == 0:
                    break
                xt = self.embed(it)

            output, state = self.core(xt, state)
            output = F.log_softmax(self.logit(output))
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
Example #5
def translate(enc_input='thisissungkim.iloveyou.', predict_len=100, temperature=0.9):
    input_var = str2tensor(enc_input)
    encoder_hidden = encoder.init_hidden()
    encoder_outputs, encoder_hidden = encoder(input_var, encoder_hidden)

    hidden = encoder_hidden

    predicted = ''
    dec_input = str2tensor(SOS_token)
    for c in range(predict_len):
        output, hidden = decoder(dec_input, hidden)

        # Sample from the network as a multinomial distribution
        output_dist = output.data.view(-1).div(temperature).exp()
        top_i = torch.multinomial(output_dist, 1)[0]

        # Stop at the EOS
        if top_i == EOS_token:
            break

        predicted_char = chr(top_i)
        predicted += predicted_char

        dec_input = str2tensor(predicted_char)

    return enc_input, predicted
Example #6
def blue_eval(output,corpus):
    # after sampling, the result is 20*19
    sent_idx=torch.multinomial(output.exp().cpu(), 1).view(-1,19)
    sent_idx=sent_idx.cpu().data.numpy()
    sent_str=[]

    # process each sentence in the generated batch
    for i in range(sent_idx.shape[0]):
        str_=[str(int(x)) for x in sent_idx[i,:-1]]
        sent_str.append(str_)

    eval_data=[]
    for sent in corpus.valid.numpy():
        eval_data.append([str(int(x)) for x in sent[1:-1]])

    weight = tuple((1. / 4 for _ in range(4)))
    BLEUscores=[]

    for gen_sent in sent_str:
        ref_sent_info=[]
        for ref_sent in eval_data:
            # find the reference sentences most similar to the generated one
            common_tokens = Counter(gen_sent) & Counter(ref_sent)
            correct_preds = sum(common_tokens.values())
            recall_wrt = float(correct_preds) / len(gen_sent)
            ref_sent_info.append((ref_sent,recall_wrt))

        ref_sent_info.sort(key=lambda x: -x[1])
        top_refs=[x[0] for x in ref_sent_info[:50]]

        BLEUscore = nltk.translate.bleu_score.sentence_bleu(top_refs, gen_sent, weight)
        BLEUscores.append(BLEUscore)

    score=(np.mean(BLEUscores))
    return score
Example #7
    def sample(self, input, temperature=1., hidden=None):
        hidden = self.module_.init_hidden(1) if hidden is None else hidden
        output, hidden = self.module_(input, hidden)
        probas = output.squeeze().data.div(temperature).exp()
        sample = torch.multinomial(probas, 1)[-1]
        if probas.dim() > 1:
            sample = sample[0]
        return sample, self.repackage_hidden(hidden)
Example #8
def torch_multinomial(input, num_samples, replacement=False):
    """
    Like `torch.multinomial()` but works with cuda tensors.
    Does not support keyword argument `out`.
    """
    if input.is_cuda:
        return torch_multinomial(input.cpu(), num_samples, replacement).cuda()
    else:
        return torch.multinomial(input, num_samples, replacement)
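The CPU round-trip above works around older PyTorch releases in which torch.multinomial had no CUDA kernel; on current versions the call runs directly on GPU tensors, so a modern equivalent is simply (a sketch, assuming a CUDA device is available):

import torch

if torch.cuda.is_available():
    weights = torch.rand(8, 100, device="cuda")
    samples = torch.multinomial(weights, num_samples=5)  # computed and returned on the GPU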
Example #9
    def sample(self, fc_feats, att_feats, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        if beam_size > 1:
            return self.sample_beam(fc_feats, att_feats, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)

        # embed fc and att feats
        fc_feats = self.fc_embed(fc_feats)
        _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size))
        att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,)))

        # Project the attention feats first to reduce memory and computation consumption.
        p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size))
        p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,)))

        seq = []
        seqLogprobs = []
        for t in range(self.seq_length + 1):
            if t == 0: # input <bos>
                it = fc_feats.data.new(batch_size).long().zero_()
            elif sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
                it = torch.multinomial(prob_prev, 1).cuda()
                sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions
                it = it.view(-1).long() # and flatten indices for downstream processing

            xt = self.embed(Variable(it, requires_grad=False))

            if t >= 1:
                # stop when all finished
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it)
                seq.append(it) #seq[t] the input of t+2 time step

                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state)
            logprobs = F.log_softmax(self.logit(output))

        return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
Example #10
    def sample(self, Q):
        if self.use_actor_critic:
            pi = F.softmax(Q, dim=-1)
            a = torch.multinomial(pi, 1).squeeze()
            return a.data.cpu().numpy()
        else:
            sample = random.random()
            if sample > self.eps_threshold:
                return Q.data.max(1)[1].cpu().numpy()
            else:
                return np.random.randint(0, self.num_actions, self.nenv)
Example #11
    def forward(self, x, hiddens):
        batchsize = x["s"].size(0)
        if not hasattr(self, "prob"):
            self.prob = x["res"].clone().resize_(2)
            self.prob[0] = 1 - self.args.ratio_skip_observation
            self.prob[1] = self.args.ratio_skip_observation

        skip_mat = self._var(torch.multinomial(self.prob, batchsize, replacement=True).float().view(-1, 1))

        output = self._merge(x, hiddens, skip_mat)
        return self.decision(output)
Example #12
def sample_K(probs, K, mode='test'):
    probs = 1e-6 + probs*(1 - 2e-6) # to avoid log(0)
    probs = probs.view(-1, 2**K)
    if mode == 'train':
        bin_sample = torch.multinomial(probs, 1).detach()
    else:
        bin_sample = probs.max(1)[1].detach().unsqueeze(1)
    sample = bin_sample.clone().type(dtype)
    log_probs_samples = torch.log(probs).gather(1, bin_sample).squeeze()
    log_probs_samples = log_probs_samples.view(batch_size, N).sum(1)
    return bin_sample.data.view(batch_size, N), log_probs_samples
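Gathering log(probs) at the sampled indices is the standard way to obtain per-sample log-likelihoods for a score-function (REINFORCE) estimator. A reduced sketch with the batch reshaping stripped out:

import torch

probs = torch.softmax(torch.randn(6, 4), dim=-1)          # 6 rows over 4 bins (illustrative)
sample = torch.multinomial(probs, 1)                      # one sampled bin per row
log_prob = torch.log(probs).gather(1, sample).squeeze(1)  # log-likelihood of each sample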
Example #13
def main():
    path='my_data1'
    sec_text="I wanna go out tonight"
    n_bins=4

    # load the data (default: data1)
    corpus = data.Corpus(path=os.path.join("data", path))
    gen=torch.load("models/gen_"+path+".pt").cuda()
    print(gen)

    bin_stream=string2bins(sec_text,n_bins)

    ntokens = len(corpus.dictionary)
    tokens = list(range(ntokens)) # * args.replication_factor
    np.random.shuffle(tokens)
    words_in_bin = int(len(tokens) /n_bins)
    bins = [tokens[i:i + words_in_bin] for i in range(0, len(tokens), words_in_bin)]
    zero = [list(set(tokens) - set(bin_)) for bin_ in bins]



    # generate each word in a loop

    for _ in range(10):

        input = Variable(torch.Tensor([corpus.dictionary.word2idx['<start>']]), volatile=True).view(-1,1).type(torch.LongTensor).cuda()
        h=gen.init_hidden(1)
        gen_words=[]
        for i in range(len(bin_stream[:16])):
            output,h=gen(input,h)

            zero_index = zero[int(bin_stream[i],2)]
            zero_index = torch.LongTensor(zero_index).cuda()

            output = output.squeeze().data.div(0.8).exp()
            output.index_fill_(0, zero_index, 0)

            word_idx = torch.multinomial(output, 1)[0]
            gen_words.append(word_idx)
            input.data.fill_(word_idx)

        print(len(gen_words))
        str_=" ".join([corpus.dictionary.idx2word[x] for x in gen_words])
        print(str_)
Example #14
    def sample(self, max_time_step=200):
        """generate one sample"""
        sample_words = [self.vocab['<s>']]
        h_tm1 = None
        for t in xrange(max_time_step):
            x_tm1_embed = self.embed(Variable(torch.LongTensor([sample_words[-1]])))
            x_tm1_embed = x_tm1_embed.unsqueeze(0)
            h_t, (last_state, last_cell) = self.lstm(x_tm1_embed, h_tm1)
            h_t = self.dropout(h_t.view(-1))
            p_t = F.softmax(self.read_out(h_t), dim=-1)
            x_t_wid = torch.multinomial(p_t, 1).data[0]
            x_t = self.vocab.id2word[x_t_wid]

            if x_t == '</s>':
                return [self.vocab.id2word[wid] for wid in sample_words[1:]]
            else:
                sample_words.append(x_t_wid)

            h_tm1 = last_state, last_cell
Example #15
def sample_from_model(model, vectorizer, nationalities, sample_size=20,
                      temperature=1.0):
    num_samples = len(nationalities)
    begin_seq_index = [vectorizer.char_vocab.begin_seq_index
                       for _ in range(num_samples)]
    begin_seq_index = torch.tensor(begin_seq_index, dtype=torch.int64).unsqueeze(dim=1)
    indices = [begin_seq_index]
    nationality_indices = torch.tensor(nationalities, dtype=torch.int64).unsqueeze(dim=0)
    h_t = model.nation_emb(nationality_indices)

    for time_step in range(sample_size):
        x_t = indices[time_step]
        x_emb_t = model.char_emb(x_t)
        rnn_out_t, h_t = model.rnn(x_emb_t, h_t)
        prediction_vector = model.fc(rnn_out_t.squeeze(dim=1))
        probability_vector = F.softmax(prediction_vector / temperature, dim=1)
        indices.append(torch.multinomial(probability_vector, num_samples=1))
    indices = torch.stack(indices).squeeze().permute(1, 0)
    return indices
Example #16
    def sample(self, netW, input, state, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        seq_length = opt.get('seq_length', 9)
        self.seq_length = seq_length

        if beam_size > 1:
            return self.sample_beam(netW, input, state, opt)

        batch_size = input.size(1)
        seq = []
        seqLogprobs = []
        for t in range(self.seq_length + 1):
            if t == 0: # input <bos>
                it = input.data
            elif sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
                it = torch.multinomial(prob_prev, 1).cuda()
                sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions
                it = it.view(-1).long() # and flatten indices for downstream processing

            xt = netW(Variable(it.view(1,-1), requires_grad=False))

            if t >= 1:
                seq.append(it) #seq[t] the input of t+2 time step
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.rnn(xt, state)

            output = F.dropout(output, self.d, training=self.training)
            decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
            logprobs = F.log_softmax(self.beta * decoded)

        return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
Example #17
    def forward(self, fc_feats, att_feats, seq):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)

        outputs = []

        # embed fc and att feats
        fc_feats = self.fc_embed(fc_feats)
        _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size))
        att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,)))

        # Project the attention feats first to reduce memory and computation consumption.
        p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size))
        p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,)))

        for i in range(seq.size(1) - 1):
            if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                    #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                    prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()          
            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break

            xt = self.embed(it)

            output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state)
            output = F.log_softmax(self.logit(output))
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
Example #18
def initialize(data):
    pyro.clear_param_store()

    optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
    elbo = TraceEnum_ELBO(max_iarange_nesting=1)
    svi = SVI(model, full_guide, optim, loss=elbo)

    # Initialize weights to uniform.
    pyro.param('auto_weights', 0.5 * torch.ones(K), constraint=constraints.simplex)

    # Assume half of the data variance is due to intra-component noise.
    var = (data.var() / 2).sqrt()
    pyro.param('auto_scale', torch.tensor([var]*4), constraint=constraints.positive)

    # Initialize means from a subsample of data.
    pyro.param('auto_locs', data[torch.multinomial(torch.ones(len(data)) / len(data), K)])

    loss = svi.loss(model, full_guide, data)

    return loss, svi
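The 'auto_locs' initialization draws K distinct indices under uniform weights, i.e. a random subsample of the data without replacement (torch.multinomial samples without replacement by default). The same effect in isolation:

import torch

data = torch.randn(100)  # illustrative data
K = 4
idx = torch.multinomial(torch.ones(len(data)) / len(data), K)  # K distinct indices
init_locs = data[idx]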
Example #19
    def __init__(self, input_dim, hidden_dim, output_dim_multiplier=1,
                 mask_encoding=None, permutation=None):
        super(AutoRegressiveNN, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim_multiplier = output_dim_multiplier

        if mask_encoding is None:
            # the dependency structure is chosen at random
            self.mask_encoding = 1 + torch.multinomial(torch.ones(input_dim - 1) / (input_dim - 1),
                                                       num_samples=hidden_dim, replacement=True)
        else:
            # the dependency structure is given by the user
            self.mask_encoding = mask_encoding

        if permutation is None:
            # a permutation is chosen at random
            self.permutation = torch.randperm(input_dim, device=torch.device('cpu'))
        else:
            # the permutation is chosen by the user
            self.permutation = permutation

        # these masks control the autoregressive structure
        self.mask1 = torch.zeros(hidden_dim, input_dim)
        self.mask2 = torch.zeros(input_dim * self.output_dim_multiplier, hidden_dim)

        for k in range(hidden_dim):
            # fill in mask1
            m_k = self.mask_encoding[k].item()
            slice_k = torch.cat([torch.ones(m_k), torch.zeros(input_dim - m_k)])
            for j in range(input_dim):
                self.mask1[k, self.permutation[j]] = slice_k[j]
            # fill in mask2
            slice_k = torch.cat([torch.zeros(m_k), torch.ones(input_dim - m_k)])
            for r in range(self.output_dim_multiplier):
                for j in range(input_dim):
                    self.mask2[r * input_dim + self.permutation[j], k] = slice_k[j]

        self.lin1 = MaskedLinear(input_dim, hidden_dim, self.mask1)
        self.lin2 = MaskedLinear(hidden_dim, input_dim * output_dim_multiplier, self.mask2)
        self.relu = nn.ReLU()
Example #20
    def forward(self, fc_feats, att_feats, seq):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        #print ("batch size:",batch_size)  ## 50 (10 images 50 captions(sentence))
        #print ("seq size:",seq.size())  ##  (50L, 18L)
        #print("seq : ",seq[29].data)  ## seq bug, data loader 
        outputs = []

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
                    sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                    sample_mask = sample_prob < self.ss_prob
                    if sample_mask.sum() == 0:
                        it = seq[:, i-1].clone()
                    else:
                        sample_ind = sample_mask.nonzero().view(-1)
                        it = seq[:, i-1].data.clone()
                        #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                        #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                        prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
                        it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                        it = Variable(it, requires_grad=False)
                else:
                    it = seq[:, i-1].clone()                
                # break if all the sequences end
                if i >= 2 and seq[:, i-1].data.sum() == 0:
                    break
                xt = self.embed(it)
                
            output, state = self.core(xt.unsqueeze(0), state) ## call lstm
            output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))))
            outputs.append(output)
        #print ("output: length",len(outputs))  # length 18
        t = torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() # 1-index

        #print("output size:",t.size())
        #exit()
        return t  #(50L, 17L, 9488L)
Example #21
    def forward(self, words_i, words_j):
        batch_size = words_i.size()[0]

        for p in itertools.chain(self.log_sigma.parameters(),
                                 self.log_sigma_c.parameters()):
            p.data.clamp_(math.log(self.sigma_min), math.log(self.sigma_max))

        for p in itertools.chain(self.mu.parameters(),
                                 self.mu_c.parameters()):
            p.data.clamp_(-math.sqrt(self.C), math.sqrt(self.C))

        words_n = torch.multinomial(self.dset.weights, batch_size, replacement=True)
        if torch.cuda.is_available():
            words_n = Variable(words_n).cuda()

        mu_i, mu_j, mu_n = self.mu(words_i), self.mu_c(words_j), self.mu_c(words_n)
        sigma_i, sigma_j, sigma_n = torch.exp(self.log_sigma(words_i)), \
                                    torch.exp(self.log_sigma_c(words_j)), \
                                    torch.exp(self.log_sigma_c(words_n))

        return torch.mean(F.relu(self.ob - self.kl_energy(mu_i, mu_j, sigma_i, sigma_j) + self.kl_energy(mu_i, mu_n, sigma_i, sigma_n)), dim=0)
Example #22
    def forward(self, prob, targets, infos, wt=None):
        if wt is None:
            wt = torch.ones_like(prob)
        prob = prob.clamp(min=1e-7, max=1-1e-7)
        with torch.no_grad():
            prob_diff_wt = torch.abs((prob - targets) * wt) ** config.TRAIN.RHEM_POWER
            idx = torch.multinomial(prob_diff_wt.view(-1), config.TRAIN.RHEM_BATCH_SIZE, replacement=True)
            # hist = np.histogram(idx.cpu().numpy(), np.arange(torch.numel(prob)+1))[0]
            # hist = np.reshape(hist, prob.shape)
            # pos = np.where(hist == np.max(hist))
            # row = pos[0][0]
            # col = pos[1][0]
            # print np.max(hist), prob[row, col].item(), targets[row, col].item(), \
            #     default.term_list[col], int(self.pos_wt[col].item()), infos[row][0]#, prob_diff_wt.mean(0)[col].item()

        targets = targets.view(-1)[idx]
        prob = prob.view(-1)[idx]
        loss_per_smp = - (torch.log(prob) * targets + torch.log(1-prob) * (1-targets))
        loss = loss_per_smp.mean()

        return loss
Example #23
    def generate(self, hidden, maxlen, sample=True, temp=1.0):
        """Generate through decoder; no backprop"""

        batch_size = hidden.size(0)

        if self.hidden_init:
            # initialize decoder hidden state to encoder output
            state = (hidden.unsqueeze(0), self.init_state(batch_size))
        else:
            state = self.init_hidden(batch_size)

        # <sos>
        self.start_symbols.data.resize_(batch_size, 1)
        self.start_symbols.data.fill_(1)

        embedding = self.embedding_decoder(self.start_symbols)
        inputs = torch.cat([embedding, hidden.unsqueeze(1)], 2)

        # unroll
        all_indices = []
        for i in range(maxlen):
            output, state = self.decoder(inputs, state)
            overvocab = self.linear(output.squeeze(1))

            if not sample:
                vals, indices = torch.max(overvocab, 1)
            else:
                # sampling
                probs = F.softmax(overvocab/temp)
                indices = torch.multinomial(probs, 1)

            all_indices.append(indices)

            embedding = self.embedding_decoder(indices)
            inputs = torch.cat([embedding, hidden.unsqueeze(1)], 2)

        max_indices = torch.cat(all_indices, 1)

        return max_indices
Example #24
def generate_one(category, start_char='A', temperature=0.5):
    category_input = make_category_input(category)
    chars_input = make_chars_input(start_char)
    hidden = rnn.init_hidden()

    output_str = start_char
    
    for i in range(max_length):
        output, hidden = rnn(category_input, chars_input[0], hidden)
        
        # Sample as a multinomial distribution
        output_dist = output.data.view(-1).div(temperature).exp()
        top_i = torch.multinomial(output_dist, 1)[0]
        
        # Stop at EOS, or add to output_str
        if top_i == EOS:
            break
        else:    
            char = all_letters[top_i]
            output_str += char
            chars_input = make_chars_input(char)

    return output_str
Example #25
def generate(decoder, prime_str='A', predict_len=100, temperature=0.8):
    hidden = decoder.init_hidden()
    prime_input = str2tensor(prime_str)
    predicted = prime_str

    # Use priming string to "build up" hidden state
    for p in range(len(prime_str) - 1):
        _, hidden = decoder(prime_input[p], hidden)

    inp = prime_input[-1]

    for p in range(predict_len):
        output, hidden = decoder(inp, hidden)

        # Sample from the network as a multinomial distribution
        output_dist = output.data.view(-1).div(temperature).exp()
        top_i = torch.multinomial(output_dist, 1)[0]

        # Add predicted character to string and use as next input
        predicted_char = chr(top_i)
        predicted += predicted_char
        inp = str2tensor(predicted_char)

    return predicted
Example #26
    def beam_search(self, enc_contexts=[], return_beams=False):
        with torch.no_grad():
            if len(enc_contexts) == 0:
                return []

            batch_size = enc_contexts[0][0].shape[0]
            device = next(self.parameters()).device

            prevs = torch.full((batch_size * self.beam_size, 1),
                               fill_value=self.bos_id,
                               dtype=torch.long,
                               device=device)

            beam_scores = torch.zeros(batch_size,
                                      self.beam_size,
                                      device=device)
            beam_lens = torch.ones(batch_size,
                                   self.beam_size,
                                   dtype=torch.long,
                                   device=device)
            is_end = torch.zeros(batch_size,
                                 self.beam_size,
                                 dtype=torch.uint8,
                                 device=device)

            beam_enc_contexts = []
            for c, p in enc_contexts:
                c = c.unsqueeze(1).repeat(1, self.beam_size, 1, 1)
                c = c.view(-1, c.shape[2], c.shape[3])
                p = p.unsqueeze(1).repeat(1, self.beam_size, 1)
                p = p.view(-1, p.shape[2])
                beam_enc_contexts.append((c, p))

            current_sample_prob = 1
            group_size = self.beam_size // self.diversity_groups
            diversity_penalty = torch.zeros((batch_size, self.n_embeddings),
                                            device=device)

            for i in range(self.max_seq_len):
                outputs, _ = self.transformer_module(prevs, beam_enc_contexts)

                logits = self.generate(outputs[:, -1, :])
                log_probs = F.log_softmax(logits, dim=-1)
                log_probs = log_probs.view(batch_size, self.beam_size, -1)

                beam_scores = beam_scores.unsqueeze(
                    -1) + log_probs * (1 - is_end.float().unsqueeze(-1))
                penalty = self._length_penalty(beam_lens.float() + 1 -
                                               is_end.float())
                penalty = penalty.unsqueeze(-1).repeat(1, 1, self.n_embeddings)
                beam_scores = beam_scores / penalty

                if i == 0:
                    penalty = penalty[:, 0, :]
                    beam_scores = beam_scores[:, 0, :]

                    beam_scores, idxs = beam_scores.topk(self.beam_size,
                                                         dim=-1)
                    beam_idxs = torch.zeros((batch_size, self.beam_size),
                                            dtype=torch.long,
                                            device=device)
                else:
                    penalty = penalty.view(batch_size, self.diversity_groups,
                                           group_size, -1)
                    beam_scores = beam_scores.view(batch_size,
                                                   self.diversity_groups,
                                                   group_size, -1)

                    all_scores, all_idxs = [], []
                    for g in range(self.diversity_groups):
                        g_beam_scores = beam_scores[:, g, :, :]
                        g_penalty = penalty[:, g, :, :]
                        g_beam_scores -= self.diversity_coef * diversity_penalty.unsqueeze(
                            1) / g_penalty
                        g_beam_scores = g_beam_scores.view(batch_size, -1)

                        if random.random() < current_sample_prob:
                            beam_probas = F.softmax(g_beam_scores, dim=-1)
                            if self.annealing_topk is not None:
                                beam_probas, sample_idxs = beam_probas.topk(
                                    self.annealing_topk, dim=-1)
                                g_idxs = torch.multinomial(
                                    beam_probas, group_size)
                                g_idxs = torch.gather(sample_idxs, 1, g_idxs)
                            else:
                                g_idxs = torch.multinomial(
                                    beam_probas, group_size)
                        else:
                            _, g_idxs = g_beam_scores.topk(group_size, dim=-1)

                        g_scores = torch.gather(
                            beam_scores[:, g, :, :].view(batch_size, -1), 1,
                            g_idxs)
                        g_idxs += g * group_size * self.n_embeddings

                        all_scores.append(g_scores)
                        all_idxs.append(g_idxs)

                        diversity_penalty.scatter_add_(
                            1, torch.fmod(g_idxs, self.n_embeddings),
                            torch.ones((batch_size, group_size),
                                       device=device))

                    diversity_penalty.fill_(0)
                    penalty = penalty.view(batch_size, -1)
                    beam_scores = torch.cat(all_scores, dim=-1)
                    idxs = torch.cat(all_idxs, dim=-1)

                    beam_idxs = (idxs.float() / self.n_embeddings).long()

                penalty = torch.gather(penalty, 1, idxs)
                sym_idxs = torch.fmod(idxs, log_probs.shape[-1])
                is_end = torch.gather(is_end, 1, beam_idxs)
                beam_lens = torch.gather(beam_lens, 1, beam_idxs)

                sym_idxs[is_end] = self.padding_idx
                beam_lens[~is_end] += 1
                is_end[sym_idxs == self.eos_id] = 1

                sym_idxs = sym_idxs.view(batch_size * self.beam_size, 1)
                prevs = prevs.view(batch_size, self.beam_size, -1)
                prevs = torch.gather(
                    prevs, 1,
                    beam_idxs.unsqueeze(-1).repeat(1, 1, prevs.shape[-1]))
                prevs = prevs.view(batch_size * self.beam_size, -1)
                prevs = torch.cat([prevs, sym_idxs], dim=1)

                if all(is_end.view(-1)):
                    break

                beam_scores *= penalty
                current_sample_prob *= self.annealing

            predicts = []
            result = prevs.view(batch_size, self.beam_size, -1)

            if return_beams:
                return result, beam_lens

            if self.sample:
                probs = F.softmax(beam_scores, dim=-1)
                bests = torch.multinomial(probs, 1).view(-1)
            else:
                bests = beam_scores.argmax(dim=-1)

            for i in range(batch_size):
                best_len = beam_lens[i, bests[i]]
                best_seq = result[i, bests[i], 1:best_len - 1]
                predicts.append(best_seq.tolist())

        return predicts
Example #27
    def prepare_batch_mlm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.
        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1),
                                 dtype=torch.long,
                                 device=lengths.device) < lengths[:, None]

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(),
                                    n_tgt,
                                    replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0

        # mask a number of words == 0 [8] (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1 - n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(
            self.params.special_tok_ids["mask_token"])
        probs = torch.multinomial(self.pred_probs,
                                  len(_token_ids_real),
                                  replacement=True)
        _token_ids = (_token_ids_mask * (probs == 0).long() + _token_ids_real *
                      (probs == 1).long() + _token_ids_rand *
                      (probs == 2).long())
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[
            ~pred_mask] = -100  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels
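The final multinomial over self.pred_probs implements BERT-style corruption: each selected position is replaced by the mask token, kept as-is, or randomized, according to fixed probabilities. A reduced sketch, assuming the conventional 80/10/10 split:

import torch

pred_probs = torch.tensor([0.8, 0.1, 0.1])  # mask / keep / random (assumed split)
n_selected = 16
choices = torch.multinomial(pred_probs, n_selected, replacement=True)
# choices == 0 -> mask token, 1 -> keep the real token, 2 -> random token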
Example #28
if args.temperature < 1e-3:
    parser.error("--temperature has to be greater or equal 1e-3")

with open(args.checkpoint, 'rb') as f:
    model = torch.load(f)
model.eval()

if args.cuda:
    model.cuda()
else:
    model.cpu()

corpus = data.Corpus(args.data)
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(1)
input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True)
if args.cuda:
    input.data = input.data.cuda()

with open(args.outf, 'w') as outf:
    for i in range(args.words):
        output, hidden = model(input, hidden)
        word_weights = output.squeeze().data.div(args.temperature).exp().cpu()
        word_idx = torch.multinomial(word_weights, 1)[0]
        input.data.fill_(word_idx)
        word = corpus.dictionary.idx2word[word_idx]

        outf.write(word + ('\n' if i % 20 == 19 else ' '))

        if i % args.log_interval == 0:
            print('| Generated {}/{} words'.format(i, args.words))
Example #29
    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
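This one-liner is the heart of torch.utils.data.WeightedRandomSampler: it yields num_samples indices drawn in proportion to per-example weights. A short usage sketch (the weights are illustrative):

import torch
from torch.utils.data import WeightedRandomSampler

weights = torch.tensor([0.1, 0.1, 0.8])  # oversample the third example
sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)
print(list(sampler))  # ten indices, mostly 2s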
Example #30
    def dfs_assemble(self, tree_mess, mol_vec, all_nodes, cur_mol, global_amap,
                     fa_amap, cur_node, fa_node, prob_decode):
        fa_nid = fa_node.nid if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
        neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors,
                           key=lambda x: x.mol.GetNumAtoms(),
                           reverse=True)
        singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap
                    if nid == cur_node.nid]
        cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None
        cand_smiles, cand_mols, cand_amap = zip(*cands)

        cands = [(candmol, all_nodes, cur_node) for candmol in cand_mols]

        cand_vecs = self.jtmpn(cands, tree_mess)
        cand_vecs = self.G_mean(cand_vecs)
        mol_vec = mol_vec.squeeze()
        scores = torch.mv(cand_vecs, mol_vec) * 20

        if prob_decode:
            probs = nn.Softmax()(scores.view(
                1, -1)).squeeze() + 1e-5  #prevent prob = 0
            cand_idx = torch.multinomial(probs, probs.numel())
        else:
            _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        for i in xrange(cand_idx.numel()):
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[
                    cur_node.nid][ctr_atom]

            cur_mol = attach_mols(cur_mol, children, [],
                                  new_global_amap)  #father is already attached
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None: continue

            result = True
            for nei_node in children:
                if nei_node.is_leaf: continue
                cur_mol = self.dfs_assemble(tree_mess, mol_vec, all_nodes,
                                            cur_mol, new_global_amap,
                                            pred_amap, nei_node, cur_node,
                                            prob_decode)
                if cur_mol is None:
                    result = False
                    break
            if result: return cur_mol

        return None
Example #31
    def step(self, step: int, lprobs, scores):
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()

        if self.sampling_topp > 0:
            # only sample from the smallest set of words whose cumulative probability mass exceeds p
            probs, top_indices = self._sample_topp(lprobs)
        elif self.sampling_topk > 0:
            # only sample from top-k candidates
            lprobs, top_indices = lprobs.topk(self.sampling_topk)
            probs = lprobs.exp_()
        else:
            probs = lprobs.exp_()

            # dummy data to be consistent with true branch for type check
            top_indices = torch.empty(0).to(probs)
        # sample
        if step == 0:
            indices_buf = torch.multinomial(
                probs.view(bsz, -1),
                beam_size,
                replacement=True,
            ).view(bsz, beam_size)
        else:
            indices_buf = torch.multinomial(
                probs.view(bsz * beam_size, -1),
                1,
                replacement=True,
            ).view(bsz, beam_size)

        if step == 0:
            # expand to beam size
            probs = probs.expand(bsz, beam_size, -1)

        # gather scores
        scores_buf = torch.gather(probs,
                                  dim=2,
                                  index=indices_buf.unsqueeze(-1))
        scores_buf = scores_buf.log_().view(bsz, -1)

        # remap indices if using top-k or top-P sampling
        if self.sampling_topk > 0 or self.sampling_topp > 0:
            indices_buf = torch.gather(
                top_indices.expand(bsz, beam_size, -1),
                dim=2,
                index=indices_buf.unsqueeze(-1),
            ).squeeze(2)

        if step == 0:
            beams_buf = indices_buf.new_zeros(bsz, beam_size)
        else:
            beams_buf = torch.arange(0,
                                     beam_size).to(indices_buf).repeat(bsz, 1)
            # make scores cumulative
            scores_buf.add_(
                torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf))

        return scores_buf, indices_buf, beams_buf
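The top-k branch above keeps only the k most likely tokens, samples among them, and then maps the sampled positions back to vocabulary ids with gather. The same idea in isolation (k and the shapes are illustrative):

import torch

lprobs = torch.log_softmax(torch.randn(2, 50), dim=-1)  # (batch, vocab)
k = 5
topk_lprobs, topk_indices = lprobs.topk(k)
choice = torch.multinomial(topk_lprobs.exp(), 1)  # position within the top-k set
token = topk_indices.gather(1, choice)            # back to vocabulary ids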
Example #32
    def __iter__(self) -> Iterator[int]:
        rand_tensor = torch.multinomial(self.weights,
                                        self.num_samples,
                                        self.replacement,
                                        generator=self.generator)
        return iter(rand_tensor.tolist())
Example #33
    def nucleus_sampling(self, encoder_output, beam_size, top_p, print_beam=False):
        """Generate and return the top k sequences using nucleus sampling."""

        current_beam_width = beam_size

        encoder_dim = encoder_output.size()[-1]

        # Flatten encoding
        encoder_output = encoder_output.view(1, -1, encoder_dim)

        # We'll treat the problem as having a batch size of k
        encoder_output = encoder_output.expand(
            beam_size, encoder_output.size(1), encoder_dim
        )

        # Tensor to store top k sequences; now they're just <start>
        top_k_sequences = torch.full(
            (beam_size, 1), self.word_map[TOKEN_START], dtype=torch.int64, device=device
        )

        # Tensor to store top k sequences' scores; now they're just 0
        top_k_scores = torch.zeros(beam_size, device=device)

        # Lists to store completed sequences, scores, and alphas and the full decoding beam
        complete_seqs = []
        complete_seqs_scores = []

        # Initialize hidden states
        states = self.init_hidden_states(encoder_output)

        # Start decoding
        for step in range(0, self.params["max_caption_len"] - 1):
            prev_words = top_k_sequences[:, step]

            prev_word_embeddings = self.word_embedding(prev_words)
            predictions, states, alpha = self.forward_step(
                encoder_output, prev_word_embeddings, states
            )
            scores = F.log_softmax(predictions, dim=1)

            sorted_logits, sorted_indices = torch.sort(scores, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                ..., :-1
            ].clone()
            sorted_indices_to_remove[..., 0] = 0

            top_k_scores = torch.zeros(
                current_beam_width, dtype=torch.float, device=device
            )
            top_k_words = torch.zeros(
                current_beam_width, dtype=torch.long, device=device
            )

            for i in range(0, current_beam_width):
                scores[i][sorted_indices[i][sorted_indices_to_remove[i]]] = -float(
                    "inf"
                )

                # Sample from the scores
                top_k_words[i] = torch.multinomial(torch.softmax(scores[i], -1), 1)
                top_k_scores[i] = scores[i][top_k_words[i]]

            # Add new words to sequences
            top_k_sequences = torch.cat(
                (top_k_sequences, top_k_words.unsqueeze(1)), dim=1
            )

            if print_beam:
                print_current_beam(top_k_sequences, top_k_scores, self.word_map)

            # Check for complete and incomplete sequences (based on the <end> token)
            incomplete_inds = (
                torch.nonzero(top_k_words != self.word_map[TOKEN_END]).view(-1).tolist()
            )
            complete_inds = (
                torch.nonzero(top_k_words == self.word_map[TOKEN_END]).view(-1).tolist()
            )

            # Set aside complete sequences and reduce beam size accordingly
            if len(complete_inds) > 0:
                complete_seqs.extend(top_k_sequences[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds])

            # Stop if k captions have been completely generated
            current_beam_width = len(incomplete_inds)
            if current_beam_width == 0:
                break

            # Proceed with incomplete sequences
            top_k_sequences = top_k_sequences[incomplete_inds]
            for i in range(len(states)):
                states[i] = states[i][incomplete_inds]
            encoder_output = encoder_output[incomplete_inds]
            top_k_scores = top_k_scores[incomplete_inds]

        if len(complete_seqs) < beam_size:
            complete_seqs.extend(top_k_sequences.tolist())
            complete_seqs_scores.extend(top_k_scores)

        sorted_sequences = [
            sequence
            for _, sequence in sorted(
                zip(complete_seqs_scores, complete_seqs), reverse=True
            )
        ]
        return sorted_sequences, None, None
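The nucleus (top-p) filter used above reduces to a few lines for a single step: sort the scores, take the cumulative probability, mask everything past the threshold (keeping the first token that crosses it), and sample from what remains:

import torch
import torch.nn.functional as F

scores = torch.randn(50)  # one step of logits (illustrative)
top_p = 0.9
sorted_logits, sorted_idx = torch.sort(scores, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[1:] = remove[:-1].clone()  # shift right: keep the first token above the threshold
remove[0] = False
scores[sorted_idx[remove]] = -float("inf")
token = torch.multinomial(torch.softmax(scores, -1), 1)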
Example #34
    def get_arm_loss_two_layer_fast(self, fc_feats, att_feats, att_masks, opt,
                                    data, loader):
        sample_max = 1
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        seq = fc_feats.new_zeros(batch_size, self.seq_length,
                                 dtype=torch.long).cuda()
        unfinished = fc_feats.new_ones(batch_size, dtype=torch.uint8).cuda()
        loss = torch.zeros([]).float().cuda()
        mask_sum = 0
        n_cluster = len(self.cluster_size)
        for t in range(self.seq_length + 2):
            if t == 0:
                xt = self.img_embed(fc_feats)
            else:
                if t == 1:  # input <bos>
                    it = fc_feats.data.new(batch_size).long().zero_()
                xt = self.embed(it)

            output, state = self.core(xt, state)
            phi = self.logit(output)  # phi: batch, vocab-1
            # sample the next_word
            if t == self.seq_length + 1:  # stop once we reach the maximum length
                break
            if t >= 1:
                mask_depth = unfinished.clone()  # batch,
                # things to concat across depths:
                seqs_arm_list = []  # length 2
                state_arm_list = []  # length 2
                it_arm_list = []
                unfinished_arm_list = []
                pi_list = []
                phi_list = []

                arm_pseudo_action_set_list = []
                arm_index_list = []
                arm_index_2_list = []
                arm_pseudo_counts_list = []
                arm_pseudo_index_list = []
                counts_per_sample_list_list = []
                batch_index_list = []

                ### first depth:
                unfinished_size = unfinished.sum()
                pi_step_1 = np.random.uniform(
                    size=[unfinished_size, n_cluster])
                phi_pad_step_1 = torch.cat([
                    phi[unfinished, :(n_cluster - 1)].clone(),
                    torch.zeros(unfinished_size, 1).float().cuda()
                ], 1)
                pseudo_action_step_1 = pseudo_action_batch(
                    pi_step_1,
                    phi_pad_step_1.data.cpu().numpy()
                )  #batch, n_cluster, n_cluster
                pseudo_action_step_1 = np.reshape(pseudo_action_step_1,
                                                  [unfinished_size, -1])
                ## concate unique pseudo actions
                arm_pseudo_action_set, arm_index, arm_index_2, arm_pseudo_counts, arm_pseudo_index, \
                counts_per_sample_list = unique_function(pseudo_action_step_1)
                ## complete words
                if np.sum(arm_pseudo_counts) != 0:  #TODO: what if it ==0
                    arm_pseudo_action_set_list.append(arm_pseudo_action_set)
                    pi_list.append(pi_step_1)
                    phi_list.append(phi_pad_step_1)
                    arm_index_list.append(arm_index)
                    arm_index_2_list.append(arm_index_2)
                    arm_pseudo_counts_list.append(arm_pseudo_counts)
                    arm_pseudo_index_list.append(arm_pseudo_index)
                    counts_per_sample_list_list.append(counts_per_sample_list)
                    seqs_arm_step_1 = seq[unfinished, :][
                        arm_index_2, :].clone()
                    unfinished_arm_step_1 = unfinished[unfinished][arm_index_2]
                    it_step_1 = torch.from_numpy(
                        arm_pseudo_action_set).long().cuda()
                    phi_arm_step_1 = phi[unfinished, :][arm_index_2, :].clone()
                    start = n_cluster - 1
                    it_step_2 = torch.zeros_like(it_step_1).cuda()
                    for i in range(n_cluster):
                        index = it_step_1 == i
                        if index.sum() != 0:
                            probs_step_2 = F.softmax(
                                torch.cat([
                                    phi_arm_step_1[index, start:(
                                        start + self.cluster_size[i] - 1)],
                                    torch.zeros(index.sum(), 1).float().cuda()
                                ], 1), 1)
                            if sample_max:
                                it_step_2[index] = torch.max(
                                    probs_step_2.data,
                                    1)[1].view(-1).cuda().long()
                            else:
                                it_step_2[index] = torch.multinomial(
                                    probs_step_2.data, 1).cuda().squeeze(1)
                        start = start + self.cluster_size[i] - 1
                    code_sum = it_step_1 * (self.vocab_size + 1) + it_step_2
                    it_step = torch.from_numpy(
                        code2vocab_fun(code_sum.cpu().numpy(),
                                       self.code2vocab)).cuda().long()
                    seqs_arm_step_1[:, t -
                                    1] = it_step * unfinished_arm_step_1.type_as(
                                        it)
                    unfinished_arm_step_1 = unfinished_arm_step_1 * (it_step >
                                                                     0)
                    state_h, state_c = state
                    state_h_arm_step_1 = state_h[:,
                                                 unfinished, :][:,
                                                                arm_index_2, :]
                    state_c_arm_step_1 = state_c[:,
                                                 unfinished, :][:,
                                                                arm_index_2, :]
                    state_arm_step_1 = (state_h_arm_step_1, state_c_arm_step_1)
                    seqs_arm_list.append(seqs_arm_step_1)
                    state_arm_list.append(state_arm_step_1)
                    it_arm_list.append(it_step)
                    unfinished_arm_list.append(unfinished_arm_step_1)
                    batch_index_list.append(
                        torch.arange(batch_size)[unfinished])

                ### second depth:
                probs_step_1 = F.softmax(
                    torch.cat([
                        phi[:, :(n_cluster - 1)],
                        torch.zeros(batch_size, 1).float().cuda()
                    ], 1), 1)
                if sample_max:
                    it_1 = torch.max(probs_step_1.data,
                                     1)[1].view(-1).cuda().long()
                else:
                    it_1 = torch.multinomial(probs_step_1.data,
                                             1).cuda().squeeze(1)
                it_2 = torch.zeros_like(it_1).cuda()
                start = n_cluster - 1
                for i in range(n_cluster):
                    index = it_1[unfinished] == i
                    if index.sum() != 0:
                        # pseudo actions
                        effect_batch = index.sum()
                        cluster_size = self.cluster_size[i]
                        pi_step_2 = np.random.uniform(
                            size=[effect_batch, cluster_size])
                        phi_pad_step_2 = torch.cat([
                            phi[unfinished, :][index,
                                               start:(start + cluster_size -
                                                      1)].clone(),
                            torch.zeros(effect_batch, 1).float().cuda()
                        ], 1)
                        pseudo_action_step_2 = pseudo_action_batch(
                            pi_step_2,
                            phi_pad_step_2.data.cpu().numpy()
                        )  # batch, n_cluster, n_cluster
                        pseudo_action_step_2 = np.reshape(
                            pseudo_action_step_2, [effect_batch, -1])
                        arm_pseudo_action_set, arm_index, arm_index_2, arm_pseudo_counts, arm_pseudo_index, \
                        counts_per_sample_list = unique_function(pseudo_action_step_2)
                        arm_pseudo_action_set_list.append(
                            arm_pseudo_action_set)
                        arm_index_list.append(arm_index)
                        arm_index_2_list.append(arm_index_2)
                        arm_pseudo_counts_list.append(arm_pseudo_counts)
                        arm_pseudo_index_list.append(arm_pseudo_index)
                        counts_per_sample_list_list.append(
                            counts_per_sample_list)
                        pi_list.append(pi_step_2)
                        phi_list.append(phi_pad_step_2)

                        code_sum = it_1[unfinished][index][arm_index_2].clone(
                        ) * (self.vocab_size + 1) + torch.from_numpy(
                            np.array(arm_pseudo_action_set)).long().cuda()
                        it_step = torch.from_numpy(
                            code2vocab_fun(code_sum.cpu().numpy(),
                                           self.code2vocab)).cuda().long()

                        seqs_arm_step_2 = seq[unfinished, :][index, :][
                            arm_index_2, :].clone()
                        unfinished_arm_step_2 = unfinished[unfinished][index][
                            arm_index_2]

                        seqs_arm_step_2[:, t -
                                        1] = it_step * unfinished_arm_step_2.type_as(
                                            it)
                        unfinished_arm_step_2 = unfinished_arm_step_2 * (
                            it_step > 0)
                        state_h_arm_step_2 = state_h[:,
                                                     unfinished, :][:,
                                                                    index, :][:,
                                                                              arm_index_2, :]
                        state_c_arm_step_2 = state_c[:,
                                                     unfinished, :][:,
                                                                    index, :][:,
                                                                              arm_index_2, :]
                        state_arm_step_2 = (state_h_arm_step_2,
                                            state_c_arm_step_2)
                        seqs_arm_list.append(seqs_arm_step_2)
                        state_arm_list.append(state_arm_step_2)
                        it_arm_list.append(it_step)
                        unfinished_arm_list.append(unfinished_arm_step_2)
                        batch_index_list.append(
                            torch.arange(batch_size)[unfinished][index])
                    start = start + self.cluster_size[i] - 1
                start = n_cluster - 1
                for i in range(n_cluster):
                    if self.cluster_size[i] != 1:
                        index = it_1 == i
                        if index.sum() != 0:
                            probs_step_2 = F.softmax(
                                torch.cat([
                                    phi[index,
                                        start:(start + self.cluster_size[i] -
                                               1)],
                                    torch.zeros(index.sum(), 1).float().cuda()
                                ], 1), 1)
                            if sample_max:
                                it_2[index] = torch.max(
                                    probs_step_2.data,
                                    1)[1].view(-1).cuda().long()
                            else:
                                it_2[index] = torch.multinomial(
                                    probs_step_2.data, 1).cuda().squeeze(1)
                    start = start + self.cluster_size[i] - 1
                if len(unfinished_arm_list) > 0:
                    unfinished_arm_straight = straight_fun(unfinished_arm_list)
                    it_arm_straight = straight_fun(it_arm_list)
                    seqs_arm_straight = straight_fun(seqs_arm_list)
                    for i, item in enumerate(state_arm_list):
                        if i == 0:
                            state_h_arm_straight = item[0]
                            state_c_arm_straight = item[1]
                        else:
                            state_h_arm_straight = torch.cat(
                                [state_h_arm_straight, item[0]], 1)
                            state_c_arm_straight = torch.cat(
                                [state_c_arm_straight, item[1]], 1)
                    state_arm_straight = (state_h_arm_straight,
                                          state_c_arm_straight)
                    seqs_arm_completed = self.sentence_completion(
                        t, unfinished_arm_straight, it_arm_straight,
                        state_arm_straight, seqs_arm_straight, sample_max)
                    gts = OrderedDict()
                    for i in range(len(data['gts'])):
                        gts[i] = [
                            array_to_str(data['gts'][i][j])
                            for j in range(len(data['gts'][i]))
                        ]
                    start_index = 0
                    for i in range(len(unfinished_arm_list)):
                        arm_pseudo_index = np.array(
                            arm_pseudo_index_list[i]
                        )  # TODO: only run non-1 pseudo
                        batch_index = np.array(batch_index_list[i])
                        effect_batch = np.sum(
                            arm_pseudo_index[arm_pseudo_index > 1])
                        if effect_batch > 1:
                            arm_metric_value = reward_function(
                                data, batch_size,
                                seqs_arm_completed[start_index:(start_index +
                                                                effect_batch)],
                                batch_index[np.array(
                                    arm_index_2_list[i]).astype(int)], gts)
                            start_index = start_index + effect_batch
                            arm_index = arm_index_list[i]
                            arm_pseudo_counts = arm_pseudo_counts_list[i]
                            vocab_size = phi_list[i].size(1)
                            arm_index += np.repeat(
                                np.expand_dims(
                                    np.concatenate(
                                        [[0],
                                         np.cumsum(arm_pseudo_counts)[0:-1]]),
                                    1), vocab_size * vocab_size, 1)
                            arm_index = np.reshape(arm_index, [-1])
                            #print(i, batch_index[np.array(arm_index_2_list[i]).astype(int)])
                            arm_metric_matrix = np.reshape(
                                arm_metric_value[arm_index],
                                [-1, vocab_size, vocab_size])
                            arm_metric_matrix_cuda = torch.from_numpy(
                                arm_metric_matrix).float().cuda()
                            f_delta = (arm_metric_matrix_cuda -
                                       arm_metric_matrix_cuda.mean(1).
                                       unsqueeze(1).repeat(1, vocab_size, 1))
                            f_delta = (
                                f_delta *
                                (1 / vocab_size - torch.from_numpy(
                                    pi_list[i][arm_pseudo_index > 1]).float().
                                 cuda().unsqueeze(1).repeat(1, vocab_size, 1))
                            ).sum(2)
                            f_delta = f_delta - f_delta[:, -1].unsqueeze(
                                1).repeat(
                                    1, vocab_size)  #TODO: verify formulation
                            loss = loss - (f_delta.detach() * phi_list[i][
                                torch.from_numpy(arm_pseudo_index).cuda() > 1]
                                           ).sum()
                            if np.random.randint(200) == 1:
                                print('step', t, 'vocab', i, 'average reward',
                                      np.mean(arm_metric_value),
                                      'ave pseudo num',
                                      np.mean(arm_pseudo_index))
                    assert start_index == seqs_arm_completed.size(0)
                code_sum = it_1 * (self.vocab_size + 1) + it_2
                it = torch.from_numpy(
                    code2vocab_fun(code_sum.cpu().numpy(),
                                   self.code2vocab)).cuda().long()
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                it = it * unfinished.type_as(it)
                seq[:, t - 1] = it
                mask_sum += unfinished.sum()
                if unfinished.sum() == 0:
                    break
        # reward = reward_function(data, batch_size, seq, torch.arange(batch_size), gts)
        # print('ave reward', np.mean(reward))
        return loss / mask_sum
Ejemplo n.º 35
0
    def test_multinomial(self):
        """
        Confirm that torch.multinomial does not sample elements which have
        zero probability.
        """
        freqs = torch.cuda.FloatTensor(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.03178183361887932,
                0.027680952101945877,
                0.033176131546497345,
                0.046052902936935425,
                0.07742464542388916,
                0.11543981730937958,
                0.14148041605949402,
                0.15784293413162231,
                0.13180233538150787,
                0.08271478116512299,
                0.049702685326337814,
                0.027557924389839172,
                0.018125897273421288,
                0.011851548217236996,
                0.010252203792333603,
                0.007422595750540495,
                0.005372154992073774,
                0.0045109698548913,
                0.0036087757907807827,
                0.0035267581697553396,
                0.0018864056328311563,
                0.0024605290964245796,
                0.0022964938543736935,
                0.0018453967059031129,
                0.0010662291897460818,
                0.0009842115687206388,
                0.00045109697384759784,
                0.0007791675161570311,
                0.00020504408166743815,
                0.00020504408166743815,
                0.00020504408166743815,
                0.00012302644609007984,
                0.0,
                0.00012302644609007984,
                4.100881778867915e-05,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
            ]
        )

        sample = []
        for _ in range(1000):
            torch.cuda.get_rng_state()  # reads the CUDA RNG state but discards it; has no effect on the sampling below
            sample = torch.multinomial(freqs, 1000, True)
            if freqs[sample].min() == 0:
                sample_idx = (freqs[sample] == 0).nonzero()[0][0]
                sampled = sample[sample_idx]
                print(
                    "%s th element of last sample was %s, which has probability %s"
                    % (sample_idx, sampled, freqs[sampled])
                )
                return False
        return True
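
The check above can also be run on CPU without the test harness; a minimal standalone sketch (the helper name is hypothetical):

import torch

def multinomial_respects_zero_probs(n_trials=100, n_draws=1000):
    # weights with explicit zeros; multinomial should never return indices 0, 1 or 5
    freqs = torch.tensor([0.0, 0.0, 0.5, 0.3, 0.2, 0.0])
    for _ in range(n_trials):
        sample = torch.multinomial(freqs, n_draws, replacement=True)
        if (freqs[sample] == 0).any():
            return False
    return True

print(multinomial_respects_zero_probs())  # expected: True
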
Ejemplo n.º 36
0
    def train_one_batch(self, batch, alpha, beta):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        nll_list = []
        gen_summary = torch.LongTensor(
            config.batch_size * [config.sample_size * [[2]]])  # B x S x 1
        if use_cuda: gen_summary = gen_summary.cuda()
        preds_y = gen_summary.squeeze(2)  # B x S
        for di in range(min(config.max_dec_steps, dec_batch.size(1))):
            # Select the current input word
            p1 = np.random.uniform()
            if p1 < alpha:  # use ground truth word
                y_t_1 = dec_batch[:, di]
            else:  # use decoded word
                y_t_1 = preds_y[:, 0]

            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)

            # Select the current output word
            p2 = np.random.uniform()
            if p2 < beta:  # sample the ground truth word
                target = target_batch[:, di]
                sampled_batch = torch.stack(config.sample_size * [target],
                                            1)  # B x S
            else:  # randomly sample a word with given probabilities
                sampled_batch = torch.multinomial(final_dist,
                                                  config.sample_size,
                                                  replacement=True)  # B x S

            # Compute the NLL
            probs = torch.gather(final_dist, 1, sampled_batch).squeeze()
            step_nll = -torch.log(probs + config.eps)

            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_nll = step_nll + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage
            nll_list.append(step_nll)

            # Store the decoded words in preds_y
            preds_y = gen_preds(sampled_batch, use_cuda)
            # Add the decoded words into gen_summary (mixed with ground truth and decoded words)
            gen_summary = torch.cat((gen_summary, preds_y.unsqueeze(2)),
                                    2)  # B x S x L

        # compute the REINFORCE score
        nll = torch.sum(torch.stack(nll_list, 2), 2)  # B x S
        all_rewards, avg_reward = compute_reward(batch, gen_summary,
                                                 self.vocab, config.mode,
                                                 use_cuda)  # B x S, 1
        batch_loss = torch.sum(nll * all_rewards, dim=1)  # B
        loss = torch.mean(batch_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()
        return loss.item(), avg_reward.item()
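
The sampling and NLL step above, distilled into a few lines; a sketch with hypothetical sizes, not the trainer's actual tensors:

import torch

B, V, S = 3, 50, 4                                   # batch, vocab, sample size
final_dist = torch.softmax(torch.randn(B, V), dim=1)

sampled_batch = torch.multinomial(final_dist, S, replacement=True)  # B x S
probs = torch.gather(final_dist, 1, sampled_batch)                  # B x S
step_nll = -torch.log(probs + 1e-12)                 # eps guards against log(0)
print(step_nll.shape)                                # torch.Size([3, 4])
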
Ejemplo n.º 37
0
 def _get_negative_sample(self):
     """   임의의 negative sample(index)를 return합니다. 10000개씩 미리 sampling 해놓고 pop하여 사용합니다.   """
     if len(self.neg_candidates) == 0:
         self.neg_candidates = torch.multinomial(
             self.negative_sampling_prob, 10000, replacement=True).tolist()
     return self.neg_candidates.pop()
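
How `negative_sampling_prob` is built is not shown here; a common choice (an assumption, not necessarily this repository's) is the word2vec unigram distribution raised to the 3/4 power:

import torch

word_counts = torch.tensor([120., 45., 3., 7., 88.])  # hypothetical corpus counts
negative_sampling_prob = word_counts.pow(0.75)        # flattens the distribution
# torch.multinomial accepts unnormalized non-negative weights, so no division is needed
neg_candidates = torch.multinomial(
    negative_sampling_prob, 10000, replacement=True).tolist()
print(neg_candidates[:5])
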
Ejemplo n.º 38
0
 def __iter__(self):
     return (self.indices[i] for i in torch.multinomial(
         self.weights, self.num_samples, replacement=True))
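
For context, a self-contained version of such a sampler might look like the following sketch (class and argument names are assumptions):

import torch

class WeightedSampler:
    def __init__(self, indices, weights, num_samples):
        self.indices = indices
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples

    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(
            self.weights, self.num_samples, replacement=True))

sampler = WeightedSampler(list(range(4)), [0.1, 0.1, 0.1, 0.7], num_samples=10)
print(list(sampler))  # index 3 should dominate
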
Ejemplo n.º 39
0
    def observe(self, x, t, y, pretrained=None):
        # update memory
        # we don't really use it in the greedy variant

        # Update ring buffer storing examples from current task, equals to batch size
        bsz = y.data.size(0)

        endcnt = min(self.mem_cnt + bsz, self.n_memories)
        effbsz = endcnt - self.mem_cnt
        self.memory_data[self.mem_cnt:endcnt].copy_(x.data[:effbsz])
        if bsz == 1:
            self.memory_labs[self.mem_cnt] = y.data[0]
        else:
            self.memory_labs[self.mem_cnt:endcnt].copy_(y.data[:effbsz])
        self.mem_cnt += effbsz

        if self.sampled_memory_data is not None:
            #shuffle buffer, determine batch size of buffer sampled memories
            shuffled_inds = torch.randperm(self.sampled_memory_labs.size(0))
            effective_batch_size = min(self.n_constraints,
                                       self.sampled_memory_labs.size(0))
            b_index = 0

        #gradients of used buffer samples
        self.mem_grads = None
        this_sim = 0

        for iter_i in range(self.n_iter):

            self.zero_grad()
            loss = self.ce(self.forward(x), y)
            loss.backward()
            this_grad = get_grad_vector(self.parameters,
                                        self.grad_dims).unsqueeze(0)
            self.opt.step()

            if self.sampled_memory_data is not None:

                random_batch_inds = shuffled_inds[
                    b_index *
                    effective_batch_size:b_index * effective_batch_size +
                    effective_batch_size]
                batch_x = self.sampled_memory_data[random_batch_inds]
                batch_y = self.sampled_memory_labs[random_batch_inds]
                self.zero_grad()

                loss = self.ce(self.forward(batch_x), batch_y)
                loss.backward()

                self.opt.step()
                b_index += 1
                if b_index * effective_batch_size >= self.sampled_memory_labs.size(
                        0):
                    b_index = 0

        ##HERE MEMORY IS EQUAL TO THE BATCH SIZE, this procedure is performed for every received batch
        if self.mem_cnt == self.n_memories:
            self.eval()

            if self.sampled_memory_data is not None and self.n_sampled_memories <= self.sampled_memory_data.size(
                    0):  #buffer is full

                batch_sim = self.get_batch_sim(
                    effective_batch_size
                )  #estimate similarity score for the received samples to randomly drawn samples from the buffer
                # for efficiency we estimate the similarity for the whole batch

                if (batch_sim) < self.sim_th:

                    mem_data = x.clone()
                    mem_lab = y.clone()

                    buffer_sim = (self.sampled_memory_cos - torch.min(
                        self.sampled_memory_cos)) / (
                            (torch.max(self.sampled_memory_cos) -
                             torch.min(self.sampled_memory_cos)) + 0.01)

                    index = torch.multinomial(
                        buffer_sim, mem_data.size(0), replacement=False
                    )  #draw candidates for replacement from the buffer

                    # self.memory_data.size(0)
                    batch_item_sim = self.get_each_batch_sample_sim(
                    )  # estimate the similarity of each sample in the received batch to the randomly drawn buffer samples.

                    scaled_batch_item_sim = ((batch_item_sim + 1) /
                                             2).unsqueeze(1).clone()

                    buffer_repl_batch_sim = (
                        (self.sampled_memory_cos[index] + 1) /
                        2).unsqueeze(1).clone()
                    #draw an event to decide on replacement decision

                    # (100,1)
                    outcome = torch.multinomial(torch.cat(
                        (scaled_batch_item_sim, buffer_repl_batch_sim), dim=1),
                                                1,
                                                replacement=False)  #

                    #replace samples with outcome =1
                    added_indx = torch.arange(end=batch_item_sim.size(0))
                    sub_index = outcome.squeeze(1).byte()
                    self.sampled_memory_data[index[sub_index]] = mem_data[
                        added_indx[sub_index]].clone()
                    self.sampled_memory_labs[index[sub_index]] = mem_lab[
                        added_indx[sub_index]].clone()
                    self.sampled_memory_cos[index[sub_index]] = batch_item_sim[
                        added_indx[sub_index]].clone()
                    self.sampled_memory_taskids[index[sub_index]] = t

            else:
                #add new samples to the buffer
                added_inds = torch.arange(0, self.memory_data.size(0))

                new_task_ids = torch.zeros(added_inds.size(0)) + t
                #first buffer insertion
                if self.sampled_memory_data is None:

                    self.sampled_memory_data = self.memory_data[
                        added_inds].clone()
                    self.sampled_memory_labs = self.memory_labs[
                        added_inds].clone()
                    self.sampled_memory_taskids = new_task_ids.clone()

                    self.sampled_memory_cos = torch.zeros(
                        added_inds.size(0)) + 0.1
                else:
                    self.get_batch_sim(
                        effective_batch_size)  #draw random samples from buffer
                    this_sampled_memory_cos = self.get_each_batch_sample_sim(
                    ).clone()  #estimate a score for each added sample
                    self.sampled_memory_cos = torch.cat(
                        (self.sampled_memory_cos,
                         this_sampled_memory_cos.clone()),
                        dim=0)
                    self.sampled_memory_data = torch.cat(
                        (self.sampled_memory_data,
                         self.memory_data[added_inds].clone()),
                        dim=0)
                    self.sampled_memory_labs = torch.cat(
                        (self.sampled_memory_labs,
                         self.memory_labs[added_inds].clone()),
                        dim=0)
                    self.sampled_memory_taskids = torch.cat(
                        (self.sampled_memory_taskids, new_task_ids),
                        dim=0).clone()

            #self.print_taskids_stats()
            self.mem_cnt = 0
            self.train()
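
The replacement decision near the end draws, for each row, one of two columns in proportion to its similarity score; distilled with hypothetical numbers:

import torch

scaled_batch_item_sim = torch.tensor([[0.9], [0.2], [0.5]])  # scores of incoming samples
buffer_repl_batch_sim = torch.tensor([[0.1], [0.8], [0.5]])  # scores of buffer candidates
scores = torch.cat((scaled_batch_item_sim, buffer_repl_batch_sim), dim=1)

# one draw per row: outcome == 1 picks the buffer candidate's column, which the
# code above interprets as "replace that buffer entry with the new sample"
outcome = torch.multinomial(scores, 1, replacement=False)
print(outcome.squeeze(1))
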
Ejemplo n.º 40
0
def main():
    nb_epochs = 30
    #nb_epochs = 1
    batch_size = 64
    hidden_size = 256
    embedding_dim = 300
    max_len = 20
    teacher_forcing = 0.6
    min_count = 2
    max_grad_norm = 5
    val_len = 5000
    weight_decay = 0.00001

    eng_fr_filename = './data/eng-fra.txt'
    dataset = TSVSentencePairDataset(eng_fr_filename, max_len, min_count)
    print('Dataset: {}'.format(len(dataset)))

    train_len = len(dataset) - val_len
    dataset_train, dataset_val = torch.utils.data.dataset.random_split(
        dataset, [train_len, val_len])
    print('Train {}, val: {}'.format(len(dataset_train), len(dataset_val)))

    data_loader_train = torch.utils.data.DataLoader(dataset_train,
                                                    batch_size,
                                                    shuffle=True)
    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  batch_size,
                                                  shuffle=False)

    vocab_size = len(dataset.vocab)
    padding_idx = dataset.vocab[TSVSentencePairDataset.PAD_TOKEN]
    init_idx = dataset.vocab[TSVSentencePairDataset.INIT_TOKEN]
    model = Seq2SeqModel(vocab_size, embedding_dim, hidden_size, padding_idx,
                         init_idx, max_len, teacher_forcing)
    model = cuda(model)

    parameters = list(model.parameters())
    optimizer = torch.optim.Adam(parameters, weight_decay=weight_decay)
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=dataset.vocab[TSVSentencePairDataset.PAD_TOKEN])

    phases = [
        'train',
        'val',
    ]
    data_loaders = [
        data_loader_train,
        data_loader_val,
    ]

    for epoch in range(nb_epochs):
        for phase, data_loader in zip(phases, data_loaders):
            if phase == 'train':
                model.train()
            else:
                model.eval()

            epoch_loss = []
            for i, (inputs, targets) in enumerate(data_loader):
                optimizer.zero_grad()

                inputs = variable(inputs)
                targets = variable(targets)

                outputs = model(inputs, targets)

                targets = targets.view(-1)
                outputs = outputs.view(targets.size(0), -1)

                loss = criterion(outputs, targets)

                if phase == 'train':
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(parameters, max_grad_norm)
                    optimizer.step()

                epoch_loss.append(float(loss))

            epoch_loss = np.mean(epoch_loss)
            if phase == 'train':
                print('Epoch {:03d} | {} loss: {:.3f}'.format(
                    epoch, phase, epoch_loss),
                      end='')
            else:
                print(', {} loss: {:.3f}'.format(phase, epoch_loss), end='\n')

            # print random sentence
            if phase == 'val':
                random_idx = np.random.randint(len(dataset_val))
                inputs, targets = dataset_val[random_idx]
                inputs_var = variable(inputs)
                print("Test")

                outputs_var = model(inputs_var.unsqueeze(
                    0))  # unsqueeze to get the batch dimension

                #outputs = argmax(outputs_var).squeeze(0).data.cpu().numpy()

                softmax = torch.nn.Softmax(dim=1)
                outputs = softmax(outputs_var)
                outputs = torch.multinomial(outputs, 1).data.view(-1)

                print(u'> {}'.format(
                    get_sentence_from_indices(
                        inputs, dataset.vocab,
                        TSVSentencePairDataset.EOS_TOKEN)))
                print(u'= {}'.format(
                    get_sentence_from_indices(
                        targets, dataset.vocab,
                        TSVSentencePairDataset.EOS_TOKEN)))
                print(u'< {}'.format(
                    get_sentence_from_indices(
                        outputs, dataset.vocab,
                        TSVSentencePairDataset.EOS_TOKEN)))

                print()
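
The validation block decodes by sampling from the softmax rather than taking the argmax; a side-by-side sketch with hypothetical logits:

import torch

outputs_var = torch.randn(6, 20)                   # (seq_len, vocab) logits
probs = torch.nn.Softmax(dim=1)(outputs_var)

greedy = probs.argmax(dim=1)                       # deterministic decoding
sampled = torch.multinomial(probs, 1).view(-1)     # stochastic decoding
print(greedy, sampled)
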
Ejemplo n.º 41
0
def train(pcnn,
          optimizer,
          datasets,
          max_epoch,
          batch_size,
          max_patience,
          beta,
          ims,
          directory,
          generation=True):
    """
    pcnn (autoregressive.pixelcnn.PixelCNN): Model to train
    optimizer (torch.optim.Optimizer): Optimizer
    datasets (list of torch.utils.data.Dataset): Trainset and testset
    max_epoch (int): Maximum number of training epochs
    batch_size (int): Mini-batch size
    max_patience (int): Early stopping algorithm's maximum patience
    beta (float): Entropy regularization coefficient (https://arxiv.org/abs/1701.06548)
    ims (list of int): Images' dimensions
    directory (str): Path to a directory to store results
    generation (bool): Whether or not to generate random images
    """

    phase = ('train', 'test')
    trainset, testset = datasets
    sets = {'train': trainset, 'test': testset}

    writer = SummaryWriter(os.path.join(directory, 'logs'))

    best_auc = 0.0
    best_model = copy.deepcopy(pcnn)

    patience = 0
    epoch = 0

    while patience < max_patience and epoch < max_epoch:
        # for e in range(epoch):

        likelihood = []
        groundtruth = []
        name = []

        for p in phase:
            running_loss = 0
            running_xentropy = 0
            running_entropy = 0
            running_non_fixed_entropy = 0

            pcnn.train(p == 'train')

            dataloader = DataLoader(sets[p],
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4)

            for i_batch, sample in enumerate(tqdm(dataloader)):
                optimizer.zero_grad()
                img = Variable(sample['img'],
                               volatile=(p == 'test')).float().cuda()
                lbl = Variable(img.data[:, 0] * 255,
                               volatile=(p == 'test')).long().cuda()
                name += sample['name']

                noise = Variable(torch.randn(img.size()) * beta).cuda()
                img = img + noise
                logits = pcnn(img)[0]

                cross_entropy = torch.nn.functional.cross_entropy(logits, lbl)
                mean_entropy, non_fixed_mean_entropy = compute_entropy(logits)
                loss = cross_entropy  #- beta * mean_entropy
                if p == 'train':
                    loss.backward()
                    optimizer.step()
                else:
                    lbl = torch.unsqueeze(lbl, 1)
                    groundtruth += sample['lbl'].numpy().tolist()
                    onehot_lbl = torch.FloatTensor(img.size(0), 256, ims[0],
                                                   ims[1]).zero_().cuda()
                    onehot_lbl = Variable(onehot_lbl.scatter_(1, lbl.data, 1))

                    probs = torch.nn.functional.softmax(logits, dim=1)
                    probs = probs * onehot_lbl
                    probs = torch.sum(probs, 1)
                    probs = torch.log(probs)  #* -1
                    probs = probs.view((-1, ims[0] * ims[1]))
                    probs = torch.sum(probs, dim=1)
                    probs = probs.data.cpu().numpy().tolist()
                    likelihood += probs

                running_loss += loss.data[0]
                running_xentropy += cross_entropy.data[0]
                running_entropy += mean_entropy.data[0]
                running_non_fixed_entropy += non_fixed_mean_entropy.data[0]

            if p == 'test':
                likelihood = np.array(likelihood)
                infidx = np.argwhere(np.isinf(likelihood))
                # for infx in infidx:
                #     print(name[infx[0]])
                try:
                    likelihood[likelihood == -np.inf] = likelihood[
                        likelihood != -np.inf].min()  #Remove -inf
                except ValueError:
                    likelihood[likelihood == -np.inf] = -20000.0
                if (likelihood.dtype.char in np.typecodes['AllFloat']
                        and not np.isfinite(likelihood.sum())
                        and not np.isfinite(likelihood).all()):
                    import pudb
                    pudb.set_trace()
                fpr, tpr, thresholds = metrics.roc_curve(
                    groundtruth, likelihood)
                auc = metrics.auc(fpr, tpr)
            else:
                auc = 0

            epoch_loss = running_loss / (i_batch + 1)
            epoch_xentropy = running_xentropy / (i_batch + 1)
            epoch_entropy = running_entropy / (i_batch + 1)
            epoch_non_fixed_entropy = running_non_fixed_entropy / (i_batch + 1)
            writer.add_scalar('learning_curve/{}/loss'.format(p), epoch_loss,
                              epoch)
            writer.add_scalar('learning_curve/{}/cross_entropy'.format(p),
                              epoch_xentropy, epoch)
            writer.add_scalar('learning_curve/{}/entropy'.format(p),
                              epoch_entropy, epoch)
            writer.add_scalar('learning_curve/{}/non_fixed_entropy'.format(p),
                              epoch_non_fixed_entropy, epoch)
            writer.add_scalar('auc/{}'.format(p), auc, epoch)
            print(
                'Epoch {} ({}): loss = {} (xentropy = {}, entropy = {}), AUC = {}'
                .format(epoch, p, epoch_loss, epoch_xentropy, epoch_entropy,
                        auc))

            if p == 'test':
                if epoch % 10 == 0 and generation:
                    synthetic = torch.zeros(16, 1, ims[0], ims[1]).cuda()
                    for i in tqdm(range(ims[0])):
                        for j in range(ims[1]):
                            probs = pcnn(Variable(synthetic, volatile=True))[0]
                            probs = torch.nn.functional.softmax(probs[:, :, i,
                                                                      j]).data
                            synthetic[:, :, i, j] = torch.multinomial(
                                probs, 1).float() / 255.

                    synthetic = synthetic.cpu().numpy()
                    synthetic = np.reshape(synthetic, (4, 4, ims[0], ims[1]))
                    synthetic = np.swapaxes(synthetic, 1, 2)
                    synthetic = np.reshape(synthetic, (ims[0] * 4, ims[1] * 4))
                    plt.clf()
                    plt.imshow(synthetic)
                    plt.savefig(os.path.join(directory, 'generation',
                                             '{}.svg'.format(epoch)),
                                format='svg',
                                bbox_inches='tight')

                if auc > best_auc:
                    best_model = copy.deepcopy(pcnn)
                    torch.save(pcnn.state_dict(),
                               os.path.join(directory, 'serial', 'best_model'))
                    print('Best model saved.')
                    best_auc = auc
                    patience = 0
                else:
                    patience += 1
                    print('Patience {}/{}'.format(patience, max_patience))

            #Plot reconstructions
            logits = logits.permute(0, 2, 3, 1)
            probs = torch.nn.functional.softmax(logits, dim=3)
            argmax = torch.max(probs, 3)[1]
            argmax = argmax.data.cpu().numpy()
            lbl = lbl.data.cpu().numpy()
            nb_img = min(argmax.shape[0], 4)
            lbl = np.reshape(lbl, (-1, ims[0], ims[1]))[0:nb_img]
            argmax = np.reshape(argmax, (-1, ims[0], ims[1]))[0:nb_img]

            plt.clf()
            lbl = np.reshape(lbl, (1, nb_img, ims[0], ims[1]))
            lbl = np.swapaxes(lbl, 1, 2)
            lbl = np.reshape(lbl, (ims[0], nb_img * ims[1]))
            ax = plt.subplot2grid((2, 1), (0, 0), rowspan=1, colspan=1)
            ax.imshow(lbl)
            argmax = np.reshape(argmax, (1, nb_img, ims[0], ims[1]))
            argmax = np.swapaxes(argmax, 1, 2)
            argmax = np.reshape(argmax, (ims[0], nb_img * ims[1]))
            ax = plt.subplot2grid((2, 1), (1, 0), rowspan=1, colspan=1)
            ax.imshow(argmax)
            plt.savefig(os.path.join(directory, 'reconstruction_{}'.format(p),
                                     '{}.svg'.format(epoch)),
                        format='svg',
                        bbox_inches='tight')

            if p == 'test':
                epoch += 1

    writer.export_scalars_to_json(
        os.path.join(directory, 'logs', 'scalars.json'))
    writer.close()

    return best_model
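
The generation block inside the test phase, distilled into a standalone loop; a sketch using the modern torch.no_grad idiom instead of volatile Variables, with `pcnn` and `ims` as above:

import torch
import torch.nn.functional as F

def generate_images(pcnn, ims, n_images=16):
    synthetic = torch.zeros(n_images, 1, ims[0], ims[1]).cuda()
    with torch.no_grad():
        for i in range(ims[0]):  # pixels are sampled in raster order
            for j in range(ims[1]):
                logits = pcnn(synthetic)[0]                   # B x 256 x H x W
                probs = F.softmax(logits[:, :, i, j], dim=1)  # B x 256
                synthetic[:, :, i, j] = torch.multinomial(probs, 1).float() / 255.
    return synthetic
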
Ejemplo n.º 42
0
    def forward(self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens,
                **kwargs):

        assert tgt_tokens is not None, "forward function only supports training."

        # encoding
        encoder_out = self.encoder(src_tokens,
                                   src_lengths=src_lengths,
                                   **kwargs)

        # generate training labels for insertion
        masked_tgt_masks, masked_tgt_tokens, mask_ins_targets = _get_ins_targets(
            prev_output_tokens, tgt_tokens, self.pad, self.unk)
        mask_ins_targets = mask_ins_targets.clamp(
            min=0, max=255)  # for safe prediction
        mask_ins_masks = prev_output_tokens[:, 1:].ne(self.pad)

        mask_ins_out, _ = self.decoder.forward_mask_ins(
            normalize=False,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out)
        word_ins_out, _ = self.decoder.forward_word_ins(
            normalize=False,
            prev_output_tokens=masked_tgt_tokens,
            encoder_out=encoder_out)

        # make online prediction
        if self.decoder.sampling_for_deletion:
            word_predictions = torch.multinomial(
                F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)),
                1).view(word_ins_out.size(0), -1)
        else:
            word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1]

        word_predictions.masked_scatter_(~masked_tgt_masks,
                                         tgt_tokens[~masked_tgt_masks])

        # generate training labels for deletion
        word_del_targets = _get_del_targets(word_predictions, tgt_tokens,
                                            self.pad)
        word_del_out, _ = self.decoder.forward_word_del(
            normalize=False,
            prev_output_tokens=word_predictions,
            encoder_out=encoder_out)
        word_del_masks = word_predictions.ne(self.pad)

        return {
            "mask_ins": {
                "out": mask_ins_out,
                "tgt": mask_ins_targets,
                "mask": mask_ins_masks,
                "ls": 0.01,
            },
            "word_ins": {
                "out": word_ins_out,
                "tgt": tgt_tokens,
                "mask": masked_tgt_masks,
                "ls": self.args.label_smoothing,
                "nll_loss": True
            },
            "word_del": {
                "out": word_del_out,
                "tgt": word_del_targets,
                "mask": word_del_masks
            }
        }
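
torch.multinomial only accepts 1-D or 2-D weights, which is why the code above flattens the (batch, time, vocab) softmax into rows before sampling and reshapes afterwards; a minimal sketch:

import torch
import torch.nn.functional as F

B, T, V = 2, 5, 100
word_ins_out = torch.randn(B, T, V)                  # hypothetical logits

word_predictions = torch.multinomial(
    F.softmax(word_ins_out, -1).view(-1, V), 1).view(B, T)
print(word_predictions.shape)                        # torch.Size([2, 5])
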
Ejemplo n.º 43
0
    def _sample(self,
                fc_feats,
                att_feats,
                att_masks=None,
                opt={},
                total_probs=False):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        seq = fc_feats.new_zeros(batch_size, self.seq_length,
                                 dtype=torch.long).cuda()
        output_logit = fc_feats.new_zeros(batch_size, self.seq_length,
                                          self.depth).cuda()
        unfinished = fc_feats.new_ones(batch_size, dtype=torch.uint8).cuda()
        n_cluster = len(self.cluster_size)
        for t in range(self.seq_length + 2):
            if t == 0:
                xt = self.img_embed(fc_feats)
            else:
                if t == 1:  # input <bos>
                    it = fc_feats.data.new(batch_size).long().zero_()
                xt = self.embed(it)

            output, state = self.core(xt, state)
            phi = self.logit(output)  # phi: batch, vocab-1
            # sample the next_word
            if t == self.seq_length + 1:  # skip if we achieve maximum length
                break
            if t >= 1:
                mask_depth = unfinished.clone()
                code_sum = torch.zeros(batch_size, 1).float().cuda()
                probs_step_1 = F.softmax(
                    torch.cat([
                        phi[:, :(n_cluster - 1)],
                        torch.zeros(batch_size, 1).float().cuda()
                    ], 1), 1)
                if sample_max:
                    it_1 = torch.max(probs_step_1.data,
                                     1)[1].view(-1).cuda().long()
                else:
                    it_1 = torch.multinomial(probs_step_1.data,
                                             1).cuda().squeeze(1)
                it_2 = torch.zeros_like(it_1).cuda()
                start = n_cluster - 1
                for i in range(n_cluster):
                    if self.cluster_size[i] != 1:
                        index = it_1 == i
                        if index.sum() != 0:
                            probs_step_2 = F.softmax(
                                torch.cat([
                                    phi[index,
                                        start:(start + self.cluster_size[i] -
                                               1)],
                                    torch.zeros(index.sum(), 1).float().cuda()
                                ], 1), 1)
                            if sample_max:
                                it_2[index] = torch.max(
                                    probs_step_2.data,
                                    1)[1].view(-1).cuda().long()
                            else:
                                it_2[index] = torch.multinomial(
                                    probs_step_2.data, 1).cuda().squeeze(1)
                    start = start + self.cluster_size[i] - 1
                code_sum = it_1 * (self.vocab_size + 1) + it_2
                it = torch.from_numpy(
                    code2vocab_fun(code_sum.cpu().numpy(),
                                   self.code2vocab)).cuda().long()
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                it = it * unfinished.type_as(it)
                seq[:, t - 1] = it
                if unfinished.sum() == 0:
                    break
        return seq, output_logit
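
The two-step decoding in this example first samples a cluster id and then a word id within that cluster, each time appending a zero logit as the implicit reference class; a distilled sketch with hypothetical sizes:

import torch
import torch.nn.functional as F

batch_size, n_cluster = 4, 3
cluster_size = [5, 2, 4]
phi = torch.randn(batch_size, (n_cluster - 1) + sum(s - 1 for s in cluster_size))

# step 1: sample the cluster (n_cluster - 1 logits plus an implicit zero)
probs_1 = F.softmax(torch.cat(
    [phi[:, :n_cluster - 1], torch.zeros(batch_size, 1)], 1), 1)
it_1 = torch.multinomial(probs_1, 1).squeeze(1)

# step 2: sample within the chosen cluster from its slice of the logits
it_2 = torch.zeros_like(it_1)
start = n_cluster - 1
for i in range(n_cluster):
    index = it_1 == i
    if index.sum() != 0:
        probs_2 = F.softmax(torch.cat(
            [phi[index, start:start + cluster_size[i] - 1],
             torch.zeros(int(index.sum()), 1)], 1), 1)
        it_2[index] = torch.multinomial(probs_2, 1).squeeze(1)
    start += cluster_size[i] - 1
print(it_1, it_2)
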
Ejemplo n.º 44
0
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

device = torch.device("cuda" if args.cuda else "cpu")

if args.temperature < 1e-3:
    parser.error("--temperature has to be greater or equal 1e-3")

with open(args.checkpoint, 'rb') as f:
    model = torch.load(f).to(device)
model.eval()

corpus = data.Corpus(args.data)
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(1)
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)

with open(args.outf, 'w') as outf:
    with torch.no_grad():  # no tracking history
        for i in range(args.words):
            output, hidden = model(input, hidden)
            word_weights = output.squeeze().div(args.temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            input.fill_(word_idx)
            word = corpus.dictionary.idx2word[word_idx]

            outf.write(word + ('\n' if i % 20 == 19 else ' '))

            if i % args.log_interval == 0:
                print('| Generated {}/{} words'.format(i, args.words))
Ejemplo n.º 45
0
    def sample(self, fc_feats, att_feats, init_index, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)

        if beam_size > 1:
            return self.sample_beam(fc_feats, att_feats, init_index, opt)

        batch_size = fc_feats.size(0)
        seq = []
        seqLogprobs = []
        logprobs_all = []

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)

        for t in range(self.seq_length):
            if t == 0:
                it = fc_feats.data.new(batch_size).long().fill_(init_index)
            elif sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu(
                    )  # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data,
                                                    temperature)).cpu()

                it = torch.multinomial(prob_prev, 1).cuda()
                sampleLogprobs = logprobs.gather(
                    1,
                    Variable(it, requires_grad=False).cuda(
                    ))  # gather the logprobs at sampled positions
                it = it.view(
                    -1).long()  # and flatten indices for downstream processing

            xt = self.embed(Variable(it, requires_grad=False).cuda())

            if t >= 1:
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished *= (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it)
                seq.append(it)
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.core.forward(xt, att_feats, state)

            logprobs = F.log_softmax(self.logit(output), dim=1)
            logprobs_all.append(logprobs)

        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1)
        greedy_seqLogprobs = torch.cat([_.unsqueeze(1) for _ in seqLogprobs],
                                       1)
        greedy_logprobs_all = torch.cat([_.unsqueeze(1) for _ in logprobs_all],
                                        1).contiguous()

        return greedy_seq, greedy_seqLogprobs, greedy_logprobs_all
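
The non-greedy branch above samples from the temperature-scaled distribution and gathers the log-probability of each drawn token; distilled with hypothetical logits and without the CUDA/Variable plumbing:

import torch
import torch.nn.functional as F

temperature = 2.0
logprobs = F.log_softmax(torch.randn(3, 10), dim=1)
prob_prev = torch.exp(logprobs / temperature)  # unnormalized is fine for multinomial
it = torch.multinomial(prob_prev, 1)
sampleLogprobs = logprobs.gather(1, it)        # log-prob of each sampled token
print(it.view(-1), sampleLogprobs.view(-1))
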
Ejemplo n.º 46
0
def generate_text_pplm(
    model,
    tokenizer,
    context=None,
    past=None,
    device="cuda",
    perturb=True,
    bow_indices=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    repetition_penalty=1.0,
):
    output_so_far = None
    if context:
        context_t = torch.tensor(context, device=device, dtype=torch.long)
        while len(context_t.shape) < 2:
            context_t = context_t.unsqueeze(0)
        output_so_far = context_t

    # collect one hot vectors for bags of words
    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)

    grad_norms = None
    last = None
    unpert_discrim_loss = 0
    loss_in_time = []
    for i in trange(length, ascii=True):

        # Get past/probs for current output, except for last word
        # Note that GPT takes 2 inputs: past + current_token

        # run model forward to obtain unperturbed
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]
            if output_so_far.shape[1] > 1:
                _, past, _ = model(output_so_far[:, :-1])

        unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
        unpert_last_hidden = unpert_all_hidden[-1]

        # check if we are above grad max length
        if i >= grad_length:
            current_stepsize = stepsize * 0
        else:
            current_stepsize = stepsize

        # modify the past if necessary
        if not perturb or num_iterations == 0:
            pert_past = past

        else:
            accumulated_hidden = unpert_last_hidden[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            if past is not None:
                pert_past, _, grad_norms, loss_this_iter = perturb_past(
                    past,
                    model,
                    last,
                    unpert_past=unpert_past,
                    unpert_logits=unpert_logits,
                    accumulated_hidden=accumulated_hidden,
                    grad_norms=grad_norms,
                    stepsize=current_stepsize,
                    one_hot_bows_vectors=one_hot_bows_vectors,
                    classifier=classifier,
                    class_label=class_label,
                    loss_type=loss_type,
                    num_iterations=num_iterations,
                    horizon_length=horizon_length,
                    window_length=window_length,
                    decay=decay,
                    gamma=gamma,
                    kl_scale=kl_scale,
                    device=device,
                )
                loss_in_time.append(loss_this_iter)
            else:
                pert_past = past

        pert_logits, past, pert_all_hidden = model(last, past=pert_past)
        pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST

        for token_idx in set(output_so_far[0].tolist()):
            if pert_logits[0, token_idx] < 0:
                pert_logits[0, token_idx] *= repetition_penalty
            else:
                pert_logits[0, token_idx] /= repetition_penalty

        pert_probs = F.softmax(pert_logits, dim=-1)

        if classifier is not None:
            ce_loss = torch.nn.CrossEntropyLoss()
            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
            label = torch.tensor([class_label], device=device, dtype=torch.long)
            unpert_discrim_loss = ce_loss(prediction, label)
            print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
        else:
            unpert_discrim_loss = 0

        # Fuse the modified model and original model
        if perturb:

            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

            pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))  # + SMALL_CONST
            pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)  # + SMALL_CONST

            # rescale
            if torch.sum(pert_probs) <= 1:
                pert_probs = pert_probs / torch.sum(pert_probs)

        else:
            pert_logits = top_k_filter(pert_logits, k=top_k)  # + SMALL_CONST
            pert_probs = F.softmax(pert_logits, dim=-1)

        # sample or greedy
        if sample:
            last = torch.multinomial(pert_probs, num_samples=1)

        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)

        # update context/output_so_far appending the new token
        output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1)

        print(tokenizer.decode(output_so_far.tolist()[0]))

    return output_so_far, unpert_discrim_loss, loss_in_time
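
The fusion step combines the perturbed and unperturbed distributions as a weighted geometric mean and renormalizes before sampling; a distilled sketch with hypothetical distributions, top-k filtering omitted:

import torch
import torch.nn.functional as F

gm_scale = 0.9
pert_probs = F.softmax(torch.randn(1, 50), dim=-1)
unpert_probs = F.softmax(torch.randn(1, 50), dim=-1)

fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))
fused = fused / torch.sum(fused)               # renormalize after the product
last = torch.multinomial(fused, num_samples=1)
print(last)
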
Ejemplo n.º 47
0
outputs = torch.stack(outputs)
outputs = outputs.permute(1, 2, 0)
output = outputs[:, :, -1]

temperature = 1.0
length_of_review = 150

review = []
####
for j in range(length_of_review):

    output = output / temperature
    probs = torch.exp(output)
    probs[:, 0] = 0.0
    probs = probs / (torch.sum(probs, dim=1).unsqueeze(1))
    x = torch.multinomial(probs, 1)
    review.append(x.cpu().data.numpy()[:, 0])

    embed = model.embedding(x)

    h = model.lstm1(embed[:, 0, :])
    h = model.bn_lstm1(h)
    h = model.dropout1(h, dropout=0.3, train=False)

    h = model.lstm2(h)
    h = model.bn_lstm2(h)
    h = model.dropout2(h, dropout=0.3, train=False)

    h = model.lstm3(h)
    h = model.bn_lstm3(h)
    h = model.dropout3(h, dropout=0.3, train=False)
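
Zeroing a column before renormalizing, as above, guarantees that token is never sampled (here index 0, presumably a padding or unknown token); a quick sketch:

import torch

output = torch.randn(2, 6)                     # hypothetical logits
probs = torch.exp(output)
probs[:, 0] = 0.0                              # ban token 0
probs = probs / torch.sum(probs, dim=1).unsqueeze(1)
x = torch.multinomial(probs, 1)
assert (x != 0).all()                          # index 0 can never be drawn
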
Ejemplo n.º 48
0
def _log_prob_proposal_posterior_atomic(self,
        theta, x, masks, _num_atoms=10, _use_combined_loss=False
):
    """
    Return log probability of the proposal posterior for atomic proposals.

    We have two main options when evaluating the proposal posterior.
        (1) Generate atoms from the proposal prior.
        (2) Generate atoms from a more targeted distribution, such as the most
            recent posterior.
    If we choose the latter, it is likely beneficial not to do this in the first
    round, since we would be sampling from a randomly-initialized neural density
    estimator.

    Args:
        theta: Batch of parameters θ.
        x: Batch of data.
        masks: Mask that is True for prior samples in the batch in order to train
            them with prior loss.

    Returns:
        Log-probability of the proposal posterior.
    """

    batch_size = theta.shape[0]

    num_atoms = clamp_and_warn(
        "num_atoms", _num_atoms, min_val=2, max_val=batch_size
    )

    # Each set of parameter atoms is evaluated using the same x,
    # so we repeat rows of the data x, e.g. [1, 2] -> [1, 1, 2, 2]
    repeated_x = repeat_rows(x, num_atoms)

    # To generate the full set of atoms for a given item in the batch,
    # we sample without replacement num_atoms - 1 times from the rest
    # of the theta in the batch.
    probs = torch.ones(batch_size, batch_size) * (1 - torch.eye(batch_size)) / (batch_size - 1)

    choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
    contrasting_theta = theta[choices]
    # We can now create our sets of atoms from the contrasting parameter sets
    # we have generated.
    atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape(
        batch_size * num_atoms, -1
    )
    # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals.
    #print("flow model forward calculation ...")
    log_prob_posterior = self.netPosterior.log_prob(atomic_theta, repeated_x)
    log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms)
    #print("log prob posterior : ", log_prob_posterior.mean())
    # Get (batch_size * num_atoms) log prob prior evals.
    log_prob_prior = torch.Tensor(self.prior.log_prob(atomic_theta)).to(self.args.device)
    log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms)

    # Compute unnormalized proposal posterior.
    unnormalized_log_prob = log_prob_posterior - log_prob_prior

    # Normalize proposal posterior across discrete set of atoms.
    log_prob_proposal_posterior = unnormalized_log_prob[:, 0] - torch.logsumexp(
        unnormalized_log_prob, dim=-1
    )

    # XXX This evaluates the posterior on _all_ prior samples
    if _use_combined_loss:
        log_prob_posterior_non_atomic = self.netPosterior.log_prob(theta, x)
        masks = masks.reshape(-1)
        log_prob_proposal_posterior = (
                masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior
        )

    return log_prob_proposal_posterior
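
The `probs` matrix zeroes the diagonal, so sampling without replacement draws num_atoms - 1 contrasting indices per row that never include the row's own parameter set; a tiny sketch:

import torch

batch_size, num_atoms = 4, 3
probs = torch.ones(batch_size, batch_size) * (1 - torch.eye(batch_size)) / (batch_size - 1)
choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)

# no row ever draws its own index, since its diagonal weight is zero
assert not (choices == torch.arange(batch_size).unsqueeze(1)).any()
print(choices)
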
Ejemplo n.º 49
0
    def forward(self, encoder_state, initial_hidden_state, target_sequence, sample_probability=0.0):
        """

        :param encoder_state:  (bs, max_len, 256)
        :param initial_hidden_state: (bs, 256)
        :param target_sequence:  (bs, 25) target
        :param sample_probability:
        :return:
        """
        if target_sequence is None:
            sample_probability = 1.
        else:
            target_sequence = target_sequence.permute(1, 0)  # (25,bs)

        h_t = self.hidden_map(initial_hidden_state)  # (bs, 256)

        batch_size = encoder_state.size(0)  # bs

        context_vectors = self._init_context_vectors(batch_size)  # (bs, 256)

        y_t_index = self._init_indices(batch_size)  # (bs, ) [2] * bs

        device = encoder_state.device
        h_t = h_t.to(device)
        y_t_index = y_t_index.to(device)
        context_vectors = context_vectors.to(device)

        output_vectors = []
        self._cached_p_attn = []
        self._cached_ht = []
        self._cached_decoder_state = encoder_state.cpu().detach().numpy()  # (bs ,10, 256)

        output_sequence_size = target_sequence.size(0)  # 25

        for i in range(output_sequence_size):
            use_sample = np.random.random() < sample_probability
            if not use_sample:
                y_t_index = target_sequence[i]

            y_input_vector = self.target_embedding(y_t_index)  # (bs, 64)

            rnn_input = torch.cat([y_input_vector, context_vectors], dim=1)  # (bs, 64 + 256)

            h_t = self.gru_cell(rnn_input, h_t)  # (bs, 256)

            self._cached_ht.append(h_t.cpu().data.numpy())

            # (bs, max_len, 256)
            # (bs, 256)

            # outputs:
            # (bs, 256)
            # (bs, max_len)
            context_vectors, p_attn, _ = verbose_attention(
                encoder_state_vectors=encoder_state,
                query_vector=h_t,
            )

            self._cached_p_attn.append(p_attn.cpu().detach().numpy())

            prediction_vector = torch.cat((context_vectors, h_t), dim=1)
            score_for_y_t_index = self.classifier(F.dropout(prediction_vector, 0.3, training=self.training))  # (bs, 4911)

            if use_sample:
                p_y_t_index = F.softmax(score_for_y_t_index * self._sampling_temperature, dim=1)
                y_t_index = torch.multinomial(p_y_t_index, 1).squeeze()

            output_vectors.append(score_for_y_t_index)

        # stacked scores: (25, bs, 4911) -> (bs, 25, 4911)
        output_vectors = torch.stack(output_vectors).permute(1, 0, 2)

        return output_vectors
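The scheduled-sampling core of this decoder fits in a few lines. A standalone sketch, with illustrative sizes (V, bs) and random tensors standing in for one decode step:

import numpy as np
import torch
import torch.nn.functional as F

V, bs = 10, 4                               # illustrative vocab and batch sizes
scores = torch.randn(bs, V)                 # logits for one decode step
teacher_tokens = torch.randint(V, (bs,))    # ground-truth tokens for this step

sample_probability = 0.25                   # fraction of steps fed with model samples
if np.random.random() < sample_probability:
    # feed the decoder its own sample back in
    p = F.softmax(scores, dim=1)
    next_input = torch.multinomial(p, 1).squeeze(1)
else:
    # teacher forcing: feed the ground-truth token
    next_input = teacher_tokens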
Example No. 50
def fit_reconstructeur(reconstructeur, train, epochs, sample_size=None, mini_batch_size=1,
                       lr=1e-4, lr_decay=.05, grid_scale=1, loss_factor=1.0,
                       ind_cloud_saved=range(0), test=None, list_epoch_loss=range(0)):
    """
    grid_scale:
        les points 3D générés sont calculés à partir d'un échantillonage d'un carré 1*1,
        prendre un carré 100*100 revient à augmenter les coefficients du premier Linear sans pénaliser le modèle
        
    ind_predicted:
        indices dans train ou test pour lesquels sauvegarder l'évolution des prédictions
        (les indices de test suivent ceux de train)
    
    ind_plotted:
        indices pour lesquels calculer la loss en test
    """
    if type(loss_factor) == float:
        loss_factor = [loss_factor]*epochs
    assert len(loss_factor) == epochs
    
    if sample_size is None:
        # use every point at each iteration
        sub_sampling = [None] * epochs
    else:
        ones = torch.ones(len(train[1][0].points))
        sub_sampling = [torch.multinomial(ones, sample_size) for _ in range(epochs)]
    
    list_predicted = {i:[] for i in ind_cloud_saved}
    list_loss_train = []
    list_loss_test = []
    list_loss_detailled = [[] for _ in range(len(train[0]))]
    time_tot = time.time()
    time_loss = 0
    
    reconstructeur.grid *= grid_scale
    optimizer = torch.optim.Adagrad(reconstructeur.parameters(), lr=lr, lr_decay=lr_decay)
    for epoch in range(epochs):
        loss_train = [0,0]
        
        with torch.no_grad():
            reconstructeur.eval()
            # compute the test loss
            if epoch in list_epoch_loss:
                # uses a different loss_factor than training does
                if test is not None and len(test[0]):
                    loss_test = [reconstructeur.loss(Nuage(reconstructeur.forward(x), eps=0), y, sub_sampling=sub_sampling[epoch], k=1)
                              for (x, y) in zip(test[0], test[1])]
                    s0 = sum([s[0].item() for s in loss_test]) / len(test[1])
                    s1 = sum([s[1].item() for s in loss_test]) / len(test[1])
                    list_loss_test.append((s0, s1))
            
            # save predictions for the animation
            for i in ind_cloud_saved:
                # test indices are treated as coming right after the train indices
                if i-len(train[0]) >= 0:
                    list_predicted[i].append(reconstructeur.forward(test[0][i-len(train[0])]).detach().numpy())
                
            reconstructeur.train()
        
        # train
        for i in range(0, len(train[0]), mini_batch_size):
            mini_batch_loss = None
            for j in range(i, min(i+mini_batch_size, len(train[0]))):
                x = train[0][j]
                y = train[1][j]
                assert not x.requires_grad
                
                y_pred = reconstructeur.forward(x)
                
                if j in ind_cloud_saved:
                    list_predicted[j].append(y_pred.detach().numpy().copy())
                
                time0 = time.time()
                y_pred = Nuage(y_pred, eps=0)
                loss = reconstructeur.loss(y_pred, y, k=loss_factor[epoch], sub_sampling=sub_sampling[epoch])
                time_loss += time.time() - time0
                
                print("loss", loss[0].item(), loss[1].item())
                loss_train[0] += loss[0].item()
                loss_train[1] += loss[1].item()
                
                list_loss_detailled[j].append((loss[0].item(), loss[1].item()))
                
                loss = loss[0] + loss[1]
                mini_batch_loss = loss if mini_batch_loss is None else mini_batch_loss + loss
                
            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            mini_batch_loss.backward()
            optimizer.step()

        list_loss_train.append((loss_train[0] / len(train[1]), loss_train[1] / len(train[1])))
        
        if epoch in list_epoch_loss:
            print("time", epoch,
                  "loss train %.3e" % (list_loss_train[-1][0]+list_loss_train[-1][1]),
                  "loss test %.3e" % (list_loss_test[-1][0]+list_loss_test[-1][1] if len(test[0]) else 0),
                  "\n")
        
    # training finished, switch to evaluation mode
    reconstructeur.eval()
    time_tot = time.time() - time_tot
    print("temps loss", time_loss," tot", time_tot, "ratio", time_loss/time_tot)
    return {"loss_train": list_loss_train,
            "loss_test":  list_loss_test,
            "loss_detailled":  list_loss_detailled,
            "predicted":  list_predicted,
            "time":       (time_tot, time_loss, time_loss/time_tot),
            }
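Note the subsampling idiom above: torch.multinomial over a vector of ones draws a uniform subset of point indices without replacement, one independent draw per epoch. The same idiom in isolation (sizes are illustrative):

import torch

n_points, sample_size, epochs = 1000, 64, 3
ones = torch.ones(n_points)
# one independent uniform subsample (without replacement) per epoch
sub_sampling = [torch.multinomial(ones, sample_size) for _ in range(epochs)]
assert all(len(torch.unique(s)) == sample_size for s in sub_sampling)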
 def __iter__(self):
     # draw num_samples indices according to self.weights, optionally with
     # replacement, as in torch.utils.data.WeightedRandomSampler
     return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
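This __iter__ mirrors torch.utils.data.WeightedRandomSampler. A usage sketch with the built-in sampler, here weighting samples inversely to class frequency so the rare class is oversampled:

import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0, 0, 0, 1])
class_weights = 1.0 / torch.bincount(labels).float()
weights = class_weights[labels]            # per-sample weight

sampler = WeightedRandomSampler(weights, num_samples=8, replacement=True)
print(list(sampler))                       # indices drawn in proportion to their weights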
def sample_sequence(model,
                    length,
                    context,
                    num_samples=1,
                    temperature=1,
                    top_k=0,
                    top_p=0.0,
                    repetition_penalty=1.0,
                    is_xlnet=False,
                    is_xlm_mlm=False,
                    xlm_mask_token=None,
                    xlm_lang=None,
                    device='cpu'):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    generated = context
    with torch.no_grad():
        for _ in range(length):

            inputs = {'input_ids': generated}
            if is_xlnet:
                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
                input_ids = torch.cat(
                    (generated,
                     torch.zeros((1, 1), dtype=torch.long, device=device)),
                    dim=1)
                perm_mask = torch.zeros(
                    (1, input_ids.shape[1], input_ids.shape[1]),
                    dtype=torch.float,
                    device=device)
                perm_mask[:, :,
                          -1] = 1.0  # Previous tokens don't see last token
                target_mapping = torch.zeros((1, 1, input_ids.shape[1]),
                                             dtype=torch.float,
                                             device=device)
                target_mapping[0, 0, -1] = 1.0  # predict last token
                inputs = {
                    'input_ids': input_ids,
                    'perm_mask': perm_mask,
                    'target_mapping': target_mapping
                }

            if is_xlm_mlm and xlm_mask_token:
                # XLM MLM models are direct models (predict same token, not next token)
                # => need one additional dummy token in the input (will be masked and guessed)
                input_ids = torch.cat((generated,
                                       torch.full((1, 1),
                                                  xlm_mask_token,
                                                  dtype=torch.long,
                                                  device=device)),
                                      dim=1)
                inputs = {'input_ids': input_ids}

            if xlm_lang is not None:
                inputs["langs"] = torch.tensor([xlm_lang] *
                                               inputs["input_ids"].shape[1],
                                               device=device).view(1, -1)

            outputs = model(
                **inputs
            )  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            next_token_logits = outputs[0][0, -1, :] / (
                temperature if temperature > 0 else 1.)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for token_id in set(generated.view(-1).tolist()):
                next_token_logits[token_id] /= repetition_penalty

            filtered_logits = top_k_top_p_filtering(next_token_logits,
                                                    top_k=top_k,
                                                    top_p=top_p)
            if temperature == 0:  # greedy sampling
                next_token = torch.argmax(filtered_logits).unsqueeze(0)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits,
                                                         dim=-1),
                                               num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated
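sample_sequence calls top_k_top_p_filtering without defining it. A minimal sketch for the 1-D logits used above, along the lines of the Hugging Face example implementation this snippet appears to be based on:

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    # Keep only the top_k largest logits, then the smallest set of logits whose
    # cumulative probability exceeds top_p; everything else is masked to -inf.
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_value] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # shift right so the first token above the threshold is still kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        logits[sorted_indices[sorted_indices_to_remove]] = filter_value
    return logits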
    def sample(self, src_input, src_len, src_oov, oov_list, word2id, k, is_greedy=False):
        """
        Sample k sequeces for each src in src_input

        Args:
            k: number of sequences to sample
            is_greedy: if True, pick up the most probable word after the 1st time step

        """
        # self.model.eval()  # have to be in training mode, to backprop
        batch_size = len(src_input)

        src_mask = self.get_mask(src_input)  # same size as input_src
        src_context, (src_h, src_c) = self.model.encode(src_input, src_len)

        # prepare the init hidden vector, (batch_size, trg_seq_len, dec_hidden_dim)
        dec_hiddens = self.model.init_decoder_state(src_h, src_c)

        # each dec_hidden is (trg_seq_len, dec_hidden_dim)
        initial_input = [word2id[pykp.io.BOS_WORD]] * batch_size
        if isinstance(dec_hiddens, tuple):
            dec_hiddens = (dec_hiddens[0].squeeze(0), dec_hiddens[1].squeeze(0))
            dec_hiddens = [(dec_hiddens[0][i], dec_hiddens[1][i]) for i in range(batch_size)]
        elif isinstance(dec_hiddens, list):
            pass  # already a per-example list, nothing to do

        sampled_sequences = [TopN_heap(self.beam_size) for _ in range(batch_size)]

        for batch_i in range(batch_size):
            seq = Sequence(
                batch_id=batch_i,
                sentence=[initial_input[batch_i]],
                dec_hidden=dec_hiddens[batch_i],
                context=src_context[batch_i],
                ctx_mask=src_mask[batch_i],
                src_oov=src_oov[batch_i],
                oov_list=oov_list[batch_i],
                logprobs=None,
                score=0.0,
                attention=[])
            sampled_sequences[batch_i].push(seq)

        for current_len in range(1, self.max_sequence_length + 1):
            # the total number of partial sequences of all the batches
            num_partial_sequences = sum([len(batch_seqs) for batch_seqs in sampled_sequences])

            # flatten 2d sequences (batch_size, beam_size) into 1d batches (batch_size * beam_size) to feed model
            seq_id2batch_id, flattened_id_map, inputs, dec_hiddens, contexts, ctx_mask, src_oovs, oov_lists = self.sequence_to_batch(sampled_sequences)

            # Run one-step generation. log_probs=(batch_size, 1, K), dec_hidden=tuple of (1, batch_size, trg_hidden_dim)
            log_probs, new_dec_hiddens, attn_weights = self.model.generate(
                trg_input=inputs,
                dec_hidden=dec_hiddens,
                enc_context=contexts,
                ctx_mask=ctx_mask,
                src_map=src_oovs,
                oov_list=oov_lists,
                max_len=1,
                return_attention=self.return_attention
            )

            # squeeze these outputs, (hyp_seq_size, trg_len=1, K+1) -> (hyp_seq_size, K+1)
            log_probs = log_probs.view(num_partial_sequences, -1)
            exp_log_probs = torch.exp(log_probs)  # convert the log_prob back to prob
            # m = Categorical(exp_log_probs)

            # probs, words are [batch_size, k] at time 0, and [batch_size * k, 1] later on
            if current_len == 1:
                if is_greedy:
                    probs, words = log_probs.data.topk(k, dim=-1)
                else:
                    # m.sample_n(k)
                    words = torch.multinomial(exp_log_probs, k, replacement=False)
                    probs = torch.gather(log_probs, 1, words)
                    words = words.data
            else:
                if is_greedy:
                    probs, words = log_probs.data.topk(1, dim=-1)
                else:
                    # words = m.sample_n(1)
                    words = torch.multinomial(exp_log_probs, 1, replacement=False)
                    probs = torch.gather(log_probs, 1, words)
                    words = words.data

            # (hyp_seq_size, trg_len=1, src_len) -> (hyp_seq_size, src_len)
            if isinstance(attn_weights, tuple):  # if it's (attn, copy_attn)
                attn_weights = (attn_weights[0].squeeze(1), attn_weights[1].squeeze(1))
            else:
                attn_weights = attn_weights.squeeze(1)

            # tuple of (num_layers * num_directions, batch_size, trg_hidden_dim)=(1, hyp_seq_size, trg_hidden_dim), squeeze the first dim
            if isinstance(new_dec_hiddens, tuple):
                new_dec_hiddens1 = new_dec_hiddens[0].squeeze(0)
                new_dec_hiddens2 = new_dec_hiddens[1].squeeze(0)
                new_dec_hiddens = [(new_dec_hiddens1[i], new_dec_hiddens2[i]) for i in range(num_partial_sequences)]

            # For every partial_sequence (num_partial_sequences in total), find and trim to the best hypotheses (beam_size in total)
            for batch_i in range(batch_size):
                new_partial_sequences = TopN_heap(self.beam_size)

                for partial_id, partial_seq in enumerate(sampled_sequences[batch_i].extract()):
                    flattened_seq_id = flattened_id_map[batch_i][partial_id]

                    seq_number = 1 if current_len > 1 else k

                    # check each new beam and decide to add to hypotheses or completed list
                    for seq_i in range(seq_number):
                        w = words[flattened_seq_id][seq_i]
                        # if w has appeared before, ignore the current hypothesis
                        # if w in partial_seq.vocab:
                        #     continue

                        # at the first step the sequence holds only <BOS>, so start a fresh sentence
                        if current_len > 1:
                            new_sent = copy.copy(partial_seq.sentence) + [w]
                            new_logprobs = partial_seq.logprobs + [probs[flattened_seq_id][seq_i]]
                            new_score = partial_seq.score + probs[flattened_seq_id][seq_i]
                        else:
                            new_sent = [w]
                            new_logprobs = [probs[flattened_seq_id][seq_i]]
                            new_score = probs[flattened_seq_id][seq_i]

                        # dec_hidden and attention of this partial_seq are shared by its descendant beams
                        new_dec_hidden = new_dec_hiddens[flattened_seq_id]

                        if self.return_attention:
                            new_attention = copy.copy(partial_seq.attention)
                            # attn_weights were already squeezed above; just index this sequence
                            if isinstance(attn_weights, tuple):  # if it's (attn, copy_attn)
                                new_attention.append((attn_weights[0][flattened_seq_id], attn_weights[1][flattened_seq_id]))
                            else:
                                new_attention.append(attn_weights[flattened_seq_id])
                        else:
                            new_attention = None

                        new_partial_seq = Sequence(
                            batch_id=partial_seq.batch_id,
                            sentence=new_sent,
                            dec_hidden=new_dec_hidden,
                            context=partial_seq.context,
                            ctx_mask=partial_seq.ctx_mask,
                            src_oov=partial_seq.src_oov,
                            oov_list=partial_seq.oov_list,
                            logprobs=new_logprobs,
                            score=new_score,
                            attention=new_attention
                        )

                        new_partial_sequences.push(new_partial_seq)


                sampled_sequences[batch_i] = new_partial_sequences



        for batch_i in range(batch_size):
            sampled_sequences[batch_i] = sampled_sequences[batch_i].extract(sort=True)

        return sampled_sequences
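The sampling step above boils down to one idiom: exponentiate per-hypothesis log-probabilities, draw without replacement, then gather the drawn words' log-probs as sequence scores. In isolation (sizes are illustrative):

import torch

num_hyp, V, k = 3, 20, 4
log_probs = torch.log_softmax(torch.randn(num_hyp, V), dim=-1)

# draw k distinct words per hypothesis, then gather their log-probs as scores
words = torch.multinomial(log_probs.exp(), k, replacement=False)   # (num_hyp, k)
probs = torch.gather(log_probs, 1, words)                          # (num_hyp, k)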
Example No. 54
past = None
flag = True

sep = tokenizer.encode("\n\n\n")

while flag:
    "Sampling based method"
    sent = []
    with torch.no_grad():
        for i in range(200):
            logits, past = model_A(prev_input, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_filtering(logits, top_k=200, top_p=0.9)
            # prev_input = logits.argmax(-1).unsqueeze(1)
            probs = F.softmax(logits, -1)
            prev_input = torch.multinomial(probs, num_samples=1)
            prev_word = prev_input.item()

            if prev_word == 628:  # presumably the tokenizer's "\n\n" id (end of turn)
                break
            elif prev_word == tokenizer.encoder["[EOS]"]:
                flag = False
                break
            else:
                sent.append(prev_word)


    if not flag:
        break
Example No. 55
def sample_hard_z(gates_prob):
    # gates_prob holds per-gate log-probabilities: exponentiate, then draw one
    # hard gate index per row
    return torch.multinomial(gates_prob.exp(), num_samples=1)[:, 0]
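A quick check of the shape contract (illustrative sizes; assumes gates_prob holds log-probabilities, as the .exp() suggests):

import torch

gates_log_prob = torch.log_softmax(torch.randn(5, 3), dim=-1)  # (batch, n_gates)
z = sample_hard_z(gates_log_prob)
assert z.shape == (5,)   # one hard gate index per row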
Example No. 56
def generate_text(n, state, words, net, w2i, ntokens, device='cuda'):
    # Extract last word
    word = state.split()[-1]
    # Handle the situation where the seed is not contained in the dictionary
    if word in words:
        input = torch.tensor(np.reshape(w2i(word), (1, -1))).long().to(device)
    else:
        input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)

    # Generate next word
    with torch.no_grad():  # no tracking history
        for i in range(n):
            # Get output
            output = net(input, False)
            word_weights = output[-1].squeeze().exp().cpu()

            # Sample word from output distribution
            word_idx = torch.multinomial(word_weights, 1)[0]
            word_tensor = torch.Tensor([[word_idx]]).long().to(device)

            # Concatenate the word predicted with the current state
            input = torch.cat([input, word_tensor], 0)
            word = w2i.decoder[word_idx.item()]
            state = '{} {}'.format(state, word)

    # Set punctuations signs and upper case signs
    punc = ['!', '?', '.', ';', ':', ',', "'"]
    upcase = ['?', '!', '.']

    # Set initial params
    after_point = False
    new_line_counter = 0
    previous = '_'

    # Print initial state
    print('TEXT:')
    print('{}'.format(state.split()[0]), end = '')

    # Print next word following some given rules
    for i in state.split()[1:]:
        # Avoid loops
        if i == previous:
            continue

        # Update
        previous = i

        # Increment
        new_line_counter += 1

        # Flag: next word capitalized
        if i in upcase:
            after_point = True

        # Flag: start newline after a full point
        if i == '.' and new_line_counter > 10:
            new_line_counter = 0
            print('.')

        # Flag: do not add whitespace, there is punctuation
        elif i in punc:
            print(i, end='')
            new_line_counter -= 1

        # Print new word following flags
        else:
            if after_point:
                if new_line_counter > 1:
                    print(' {}'.format(i.capitalize()), end='')
                # After a newline, no whitespace is added
                else:
                    print('{}'.format(i.capitalize()), end='')
                after_point = False
            else:
                print(' {}'.format(i), end='')
Example No. 57
File: model.py Project: frank/dl
 def sample_char(self, c_prob, temp):
     # Sample from the softmax distribution using the temperature parameter
     c_prob = torch.softmax(c_prob / temp, 0)
     return torch.multinomial(c_prob, 1)
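Temperature divides the logits before the softmax, so temp < 1 sharpens the distribution and temp > 1 flattens it. A quick standalone check:

import torch

logits = torch.tensor([2.0, 1.0, 0.0])
for temp in (0.5, 1.0, 2.0):
    p = torch.softmax(logits / temp, 0)
    print(temp, p)   # lower temperature puts more mass on the top logit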
Example No. 58
def return_answer_index(probs_numpy,
                        probs_torch,
                        sample_method="sample",
                        max_num_of_ans=10,
                        method="herke"):
    """
    :param probs: numpy array of the probablities for all answers/docs for a question/query
    :param sample_method: greedy or sample
    :param max_num_of_ans: max num of answers to be selected
    :return: a list of index for the selected ans
    """
    assert isinstance(sample_method, str)
    if max_num_of_ans <= 0:
        if sample_method == "sample":
            l = np.random.binomial(1, probs_numpy)
        elif sample_method == "greedy":
            l = [1 if prob >= 0.5 else 0 for prob in probs_numpy]
        answer_index = np.nonzero(l)[0]
    else:
        if sample_method == "sample":
            probs_torch = probs_torch.squeeze()
            assert len(probs_torch.size()) == 1

            if method == 'original':
                # original method
                probs_clip = probs_numpy * 0.8 + 0.1
                # print("sampling the index for the answer")
                index = range(len(probs_clip))
                probs_clip_norm = probs_clip / sum(probs_clip)
                answer_index = np.random.choice(index,
                                                max_num_of_ans,
                                                replace=False,
                                                p=np.reshape(
                                                    probs_clip_norm,
                                                    len(probs_clip_norm)))
                p_answer_index = probs_numpy[answer_index]
                sorted_idx = np.argsort(p_answer_index)[::-1]
                answer_index = answer_index[sorted_idx]
                loss = 0.
                for idx in index:
                    if idx in answer_index:
                        loss += probs_torch[idx].log()
                    else:
                        loss += (1 - probs_torch[idx]).log()
            elif method == 'herke':
                # herke's method
                answer_index = []
                epsilon = 0.1
                mask = Variable(torch.ones(probs_torch.size()).cuda(),
                                requires_grad=False)
                #mask = Variable(torch.ones(probs_torch.size()), requires_grad=False)
                loss_list = []
                for i in range(max_num_of_ans):
                    p_masked = probs_torch * mask
                    if random.uniform(0, 1) <= epsilon:  # explore
                        selected_idx = torch.multinomial(mask, 1)
                    else:
                        selected_idx = torch.multinomial(p_masked, 1)
                    loss_i = (epsilon / mask.sum() + (1 - epsilon) *
                              p_masked[selected_idx] / p_masked.sum()).log()
                    loss_list.append(loss_i)
                    mask = mask.clone()
                    mask[selected_idx] = 0
                    answer_index.append(selected_idx)

                answer_index = torch.cat(answer_index, dim=0)
                answer_index = answer_index.data.cpu().numpy()

                loss = sum(loss_list)
        elif sample_method == "greedy":
            loss = 0
            answer_index = np.argsort(np.reshape(
                probs_numpy, len(probs_numpy)))[-max_num_of_ans:]
            answer_index = answer_index[::-1]

    # answer_index.sort()
    return answer_index, loss
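The 'herke' branch samples sequentially without replacement by zeroing the mask at each chosen index; torch.multinomial never draws zero-weight entries. The same idiom in isolation:

import torch

probs = torch.tensor([0.1, 0.5, 0.2, 0.2])
mask = torch.ones_like(probs)
chosen = []
for _ in range(3):
    p_masked = probs * mask
    idx = torch.multinomial(p_masked, 1)
    mask = mask.clone()
    mask[idx] = 0            # zeroed entries can never be drawn again
    chosen.append(idx.item())
print(chosen)                 # three distinct indices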
Example No. 59
        step = (i+1) // seq_length
        if step % 100 == 0:
            print ('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                   .format(epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))

# Test the model
with torch.no_grad():
    with open('sample.txt', 'w') as f:
        # Set initial hidden and cell states
        state = (torch.zeros(num_layers, 1, hidden_size).to(device),
                 torch.zeros(num_layers, 1, hidden_size).to(device))

        # Select one word id randomly
        prob = torch.ones(vocab_size)
        input = torch.multinomial(prob, num_samples=1).unsqueeze(1).to(device)

        for i in range(num_samples):
            # Forward propagate RNN 
            output, state = model(input, state)

            # Sample a word id
            prob = output.exp()
            word_id = torch.multinomial(prob, num_samples=1).item()

            # Fill input with sampled word id for the next time step
            input.fill_(word_id)

            # File write
            word = corpus.dictionary.idx2word[word_id]
            word = '\n' if word == '<eos>' else word + ' '
            f.write(word)
Example No. 60
 def sample_n(self, n):
     # n draws with replacement from each row of self.probs; transpose so the
     # sample dimension comes first
     return torch.multinomial(self.probs, n, True).t()