Example #1
    def greedy_decode(self, latent, max_len, start_id):
        '''
        latent: (batch_size, max_src_seq, d_model)
        returns: token ids, (batch_size, max_len - 1)
        '''
        batch_size = latent.size(0)

        # Start every sequence with the BOS token.
        ys = get_cuda(torch.ones(batch_size,
                                 1).fill_(start_id).long())  # (batch_size, 1)
        for i in range(max_len - 1):
            # Decode the prefix generated so far under a causal mask.
            out = self.decode(latent.unsqueeze(1), to_var(ys),
                              to_var(subsequent_mask(ys.size(1)).long()))
            prob = self.generator(out[:, -1])  # (batch_size, vocab_size)
            _, next_word = torch.max(prob, dim=1)  # (batch_size,)
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        return ys[:, 1:]  # drop the BOS column
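For intuition, here is a minimal, self-contained sketch of the same greedy loop with a dummy linear projection standing in for the real decode/generator pair; every name and value below (proj, d_model, vocab_size) is illustrative, not part of the original model.

import torch

batch_size, d_model, vocab_size, max_len, start_id = 2, 8, 11, 5, 1
latent = torch.randn(batch_size, d_model)
proj = torch.nn.Linear(d_model, vocab_size)  # toy stand-in for decode+generator

ys = torch.full((batch_size, 1), start_id, dtype=torch.long)
for _ in range(max_len - 1):
    prob = proj(latent)                 # (batch_size, vocab_size)
    next_word = prob.argmax(dim=1)      # (batch_size,)
    ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
print(ys[:, 1:])                        # generated ids, BOS column dropped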
Example #2
def fgim_attack(model, origin_data, target, ae_model, max_sequence_length, id_bos,
                id2text_sentence, id_to_word, gold_ans, output_file):
    """Fast Gradient Iterative Methods"""

    dis_criterion = nn.BCELoss(reduction='mean')

    gold_text = id2text_sentence(gold_ans, id_to_word)
    print("gold:", gold_text)
    for epsilon in [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]:
        it = 0
        data = origin_data
        output_text = str(epsilon)
        while True:
            print("epsilon:", epsilon)

            data = to_var(data.clone())  # (batch_size, seq_length, latent_size)
            # requires_grad must be set so the latent receives a gradient.
            data.requires_grad = True
            output = model.forward(data)
            loss = dis_criterion(output, target)
            model.zero_grad()
            loss.backward()
            data_grad = data.grad.data
            # Move the latent against the classifier gradient.
            data = data - epsilon * data_grad
            it += 1
            epsilon = epsilon * 0.9  # decay the step size

            # ae_model may be wrapped in nn.DataParallel, hence the .module fallback.
            try:
                generator_id = ae_model.greedy_decode(data,
                                                      max_len=max_sequence_length,
                                                      start_id=id_bos)
            except AttributeError:
                generator_id = ae_model.module.greedy_decode(data,
                                                             max_len=max_sequence_length,
                                                             start_id=id_bos)
            generator_text = id2text_sentence(generator_id[0], id_to_word)
            print("| It {:2d} | dis model pred {:5.4f} |".format(it, output[0].item()))
            if it >= 5:
                print(generator_text)
                break
        add_output(output_file, ": ".join([output_text, generator_text]))  # save sentence
    return generator_text
Example #3

def fgim_step(model, origin_data, target, ae_model, max_sequence_length,
              id_bos, id2text_sentence, id_to_word, gold_ans, epsilon, step):
    """Fast Gradient Iterative Methods: run `step` update iterations."""

    dis_criterion = nn.BCELoss(reduction='mean')

    it = 0
    data = origin_data
    while it < step:
        data = to_var(data.clone())  # (batch_size, seq_length, latent_size)
        # requires_grad must be set so the latent receives a gradient.
        data.requires_grad = True
        output = model.forward(data)
        loss = dis_criterion(output, target)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.data
        data = data - epsilon * data_grad
        it += 1
        epsilon = epsilon * 0.9  # decay the step size

        generator_id = ae_model.greedy_decode(data,
                                              max_len=max_sequence_length,
                                              start_id=id_bos)
        generator_text = id2text_sentence(generator_id[0], id_to_word)
        print("| It {:2d} | dis model pred {:5.4f} |".format(
            it, output[0].item()))
        print(generator_text)

    return data, generator_id
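The two functions above share the same core update: one gradient step of the classifier loss, taken in latent space. A self-contained sketch of a single such step on a toy classifier (all shapes, names, and values here are illustrative):

import torch
import torch.nn as nn

latent = torch.randn(1, 4)                          # toy latent
clf = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())  # stand-in for the discriminator
target = torch.ones(1, 1)                           # desired attribute label
criterion = nn.BCELoss(reduction='mean')

epsilon = 2.0
data = latent.clone().requires_grad_(True)
output = clf(data)
loss = criterion(output, target)
clf.zero_grad()
loss.backward()
data = data - epsilon * data.grad.data  # move the latent against the gradient
epsilon = epsilon * 0.9                 # decay, as in the loops above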
Example #4
    def greedy_decode(self, latent, max_len, start_id):
        '''
        latent: (batch_size, max_src_seq, d_model)
        returns: token ids, (batch_size, max_len - 1)
        '''
        if self.ae_vs_ar:
            # Decode autoregressively, conditioned on the latent.
            batch_size = latent.size(0)
            ys = torch.ones(batch_size, 1).fill_(start_id).long().to(latent.device)  # (batch_size, 1)
            for i in range(max_len - 1):
                out = self.decode(latent.unsqueeze(1), to_var(ys),
                                  to_var(subsequent_mask(ys.size(1)).long()))
                prob = self.generator(out[:, -1])  # (batch_size, vocab_size)
                _, next_word = torch.max(prob, dim=1)  # (batch_size,)
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            return ys[:, 1:]
        else:
            # Decode by re-encoding the generated prefix at each step
            # (no latent conditioning).
            pad_id = 0
            batch_size = latent.size(0)
            ys = torch.ones(batch_size, 1).fill_(start_id).long().to(latent.device)  # (batch_size, 1)
            for i in range(max_len - 1):
                tensor = self.src_embed(to_var(ys))
                src_attn_mask = to_var((ys != pad_id).long())
                tgt_mask = to_var(subsequent_mask(ys.size(1)).long())
                tensor *= src_attn_mask.unsqueeze(-1).to(tensor.dtype)
                tensor = self.encoder(tensor, tgt_mask)
                tensor *= src_attn_mask.unsqueeze(-1).to(tensor.dtype)
                prob = self.generator(tensor[:, -1])
                _, next_word = torch.max(prob, dim=1)
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            return ys[:, 1:]
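Both branches rely on subsequent_mask to keep decoding causal. A minimal sketch of the usual construction (assuming the standard lower-triangular definition; the project's own helper may differ in dtype or shape):

import torch

def subsequent_mask_sketch(size):
    # Position i may attend to positions 0..i only.
    return torch.tril(torch.ones(1, size, size, dtype=torch.uint8))

print(subsequent_mask_sketch(4)[0])  # lower-triangular 4x4 of ones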
Example #5
def fgim_attack(model, origin_data, target, ae_model, max_sequence_length,
                id_bos, id2text_sentence, id_to_word):
    """Fast Gradient Iterative Methods"""

    dis_criterion = nn.BCELoss(reduction='mean')

    generator_text = list()
    # This variant keeps only the largest step size of the usual
    # [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] sweep.
    for epsilon in [8.0]:
        it = 0
        data = origin_data
        while True:
            data = to_var(
                data.clone(),
                model.getGpu())  # (batch_size, seq_length, latent_size)
            data.requires_grad = True

            output = model.forward(data)
            # Gradient of the classifier loss w.r.t. the latent.
            loss = dis_criterion(output, target.float())
            model.zero_grad()
            loss.backward()
            data_grad = data.grad.data
            data = data - epsilon * data_grad

            it += 1
            epsilon = epsilon * 0.9  # decay the step size

            generator_id = ae_model.greedy_decode(data,
                                                  target,
                                                  max_len=max_sequence_length,
                                                  start_id=id_bos)
            if it >= 5:
                break
    for i in range(len(generator_id)):
        generator_text.append(id2text_sentence(generator_id[i], id_to_word))
    print("| It {:2d} | dis model pred {:5.4f} |".format(it, output[0].item()))
    return generator_text
Example #6
def predict(args, ae_model, dis_model, batch, epsilon):
    (batch_sentences, 
     tensor_labels, 
     tensor_src,
     tensor_src_mask,
     tensor_tgt,
     tensor_tgt_y,
     tensor_tgt_mask,
     tensor_ntokens) = batch
    
    ae_model.eval()
    dis_model.eval()
    latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask, tensor_tgt_mask)
    generator_text = ae_model.greedy_decode(latent,
                                            max_len=args.max_sequence_length,
                                            start_id=args.id_bos)
    print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
    print(id2text_sentence(generator_text[0], args.id_to_word))
    # Target the attribute opposite to the input sentence's label.
    target = get_cuda(torch.tensor([[1.0]], dtype=torch.float))
    if tensor_labels[0].item() > 0.5:
        target = get_cuda(torch.tensor([[0.0]], dtype=torch.float))

    dis_criterion = nn.BCELoss(reduction='mean')

    data = to_var(latent.clone())  # (batch_size, seq_length, latent_size)
    data.requires_grad = True
    output = dis_model.forward(data)
    loss = dis_criterion(output, target)
    dis_model.zero_grad()
    loss.backward()
    data_grad = data.grad.data
    data = data - epsilon * data_grad

    generator_id = ae_model.greedy_decode(data,
                                          max_len=args.max_sequence_length,
                                          start_id=args.id_bos)
    return id2text_sentence(generator_id[0], args.id_to_word)
Example #7

def train_iters(ae_model, dis_model):
    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    train_data_loader.create_batches(args.train_file_list,
                                     args.train_label_list,
                                     if_shuffle=True)
    add_log("Start train process.")
    ae_model.train()
    dis_model.train()

    ae_optimizer = NoamOpt(
        ae_model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1))
    dis_criterion = nn.BCELoss(reduction='mean')

    for epoch in range(200):
        print('-' * 94)
        epoch_start_time = time.time()
        for it in range(train_data_loader.num_batch):
            batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = train_data_loader.next_batch()

            # Forward pass
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)
            # latent: (batch_size, max_src_seq, d_model)
            # out: (batch_size, max_tgt_seq, vocab_size)

            # Reconstruction loss, normalized by the number of target tokens
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            # Classifier
            dis_lop = dis_model.forward(to_var(latent.clone()))

            loss_dis = dis_criterion(dis_lop, tensor_labels)

            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            if it % 200 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, train_data_loader.num_batch, loss_rec,
                            loss_dis))

                print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                generator_text = ae_model.greedy_decode(
                    latent,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                print(id2text_sentence(generator_text[0], args.id_to_word))

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save model
        torch.save(ae_model.state_dict(),
                   args.current_save_path + 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path + 'dis_model_params.pkl')
    return
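A side note on the optimizer above: NoamOpt is the transformer warm-up schedule, which in its standard form sets lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5), so with factor 1 and warmup 2000 (as configured here) the rate ramps up, peaks near step 2000, then decays. A quick sketch, with d_model=256 as a purely illustrative value:

def noam_lr(step, d_model, factor=1, warmup=2000):
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

for step in [1, 500, 2000, 10000]:
    print(step, noam_lr(step, d_model=256))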
Example #8
def fgim(data_loader, args, ae_model, c_theta, gold_ans=None):
    """
    Input:
        Original latent representation z : (n_batch, batch_size, seq_length, latent_size)
        Well-trained attribute classifier C_θ
        Target attribute y
        A set of weights w = {w_i}
        Decay coefficient λ
        Threshold t

    Output: An optimal modified latent representation z'
    """
    w = args.w
    lambda_ = args.lambda_
    t = args.threshold
    max_iter_per_epsilon = args.max_iter_per_epsilon
    max_sequence_length = args.max_sequence_length
    id_bos = args.id_bos
    id_to_word = args.id_to_word
    limit_batches = args.limit_batches

    text_z_prime = {"source": [], "origin_labels": [], "before": [],
                    "after": [], "change": [], "pred_label": []}
    if gold_ans is not None:
        text_z_prime["gold_ans"] = []
    z_prime = []

    dis_criterion = nn.BCELoss(reduction='mean')
    n_batches = 0
    for it in tqdm.tqdm(list(range(data_loader.num_batch)), desc="FGIM"):
        if gold_ans is not None:
            text_z_prime["gold_ans"].append(gold_ans[it])

        _, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, _ = data_loader.next_batch()
        # Keep only the negative examples.
        negative_examples = ~(tensor_labels.squeeze() == args.positive_label)
        tensor_labels = tensor_labels[negative_examples].squeeze(0)
        tensor_src = tensor_src[negative_examples].squeeze(0)
        tensor_src_attn_mask = tensor_src_attn_mask[negative_examples].squeeze(0)
        tensor_src_mask = tensor_src_mask[negative_examples].squeeze(0)
        tensor_tgt_y = tensor_tgt_y[negative_examples].squeeze(0)
        tensor_tgt = tensor_tgt[negative_examples].squeeze(0)
        tensor_tgt_mask = tensor_tgt_mask[negative_examples].squeeze(0)

        if negative_examples.any():
            text_z_prime["source"].append([id2text_sentence(s, args.id_to_word) for s in tensor_tgt_y])
            text_z_prime["origin_labels"].append(tensor_labels.cpu().numpy())

            origin_data, _ = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                              tensor_src_attn_mask, tensor_tgt_mask)

            # Target label: flip the original attribute.
            y_prime = 1.0 - (tensor_labels > 0.5).float()

            ############################### FGIM ######################################################
            generator_id = ae_model.greedy_decode(origin_data, max_len=max_sequence_length, start_id=id_bos)
            generator_text = [id2text_sentence(gid, id_to_word) for gid in generator_id]
            text_z_prime["before"].append(generator_text)

            flag = False
            for w_i in w:
                # First gradient step with weight w_i.
                data = to_var(origin_data.clone())  # (batch_size, seq_length, latent_size)
                data.requires_grad = True
                output = c_theta.forward(data)
                loss = dis_criterion(output, y_prime)
                c_theta.zero_grad()
                loss.backward()
                data = data - w_i * data.grad.data

                n_iter = 0
                while True:
                    # Stop once the mean L1 distance to the target drops below t.
                    if torch.sum((output - y_prime).abs(), dim=1).mean() < t:
                        flag = True
                        break

                    data = to_var(data.clone())  # (batch_size, seq_length, latent_size)
                    # requires_grad must be set so the latent receives a gradient.
                    data.requires_grad = True
                    output = c_theta.forward(data)
                    loss = dis_criterion(output, y_prime)
                    c_theta.zero_grad()
                    loss.backward()
                    data = data - w_i * data.grad.data
                    n_iter += 1
                    w_i = lambda_ * w_i  # decay the step size
                    if n_iter > max_iter_per_epsilon:
                        break

                if flag:
                    z_prime.append(data)
                    generator_id = ae_model.greedy_decode(data, max_len=max_sequence_length, start_id=id_bos)
                    generator_text = [id2text_sentence(gid, id_to_word) for gid in generator_id]
                    text_z_prime["after"].append(generator_text)
                    text_z_prime["change"].append([True] * len(output))
                    text_z_prime["pred_label"].append([o.item() for o in output])
                    break

            if not flag:  # could not debias this batch
                z_prime.append(origin_data)
                text_z_prime["after"].append(text_z_prime["before"][-1])
                text_z_prime["change"].append([False] * len(y_prime))
                text_z_prime["pred_label"].append([o.item() for o in y_prime])

            n_batches += 1
            if n_batches > limit_batches:
                break
    return z_prime, text_z_prime
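The stopping test above is the mean L1 distance between the classifier output and the target label. A quick numeric check with made-up values:

import torch

output = torch.tensor([[0.2], [0.3]])   # classifier predictions (illustrative)
y_prime = torch.zeros(2, 1)             # target attribute labels
t = 0.1
dist = torch.sum((output - y_prime).abs(), dim=1).mean()
print(dist.item() < t)                  # False: mean distance 0.25 is above t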
Example #9
def fgim_attack(model, origin_data, target, ae_model, max_sequence_length, id_bos,
                id2text_sentence, id_to_word, gold_ans=None):
    """Fast Gradient Iterative Methods"""

    dis_criterion = nn.BCELoss(reduction='mean')
    t = 0.001  # threshold
    lambda_ = 0.9  # decay coefficient
    max_iter_per_epsilon = 20

    if gold_ans is not None:
        gold_text = id2text_sentence(gold_ans, id_to_word)
        print("gold:", gold_text)

    flag = False
    for epsilon in [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]:
        print("---------- epsilon:", epsilon)
        generator_id = ae_model.greedy_decode(origin_data, max_len=max_sequence_length, start_id=id_bos)
        generator_text = id2text_sentence(generator_id[0], id_to_word)
        print("z:", generator_text)

        # First gradient step with step size epsilon.
        data = to_var(origin_data.clone())  # (batch_size, seq_length, latent_size)
        data.requires_grad = True
        output = model.forward(data)
        loss = dis_criterion(output, target)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.data
        data = data - epsilon * data_grad

        it = 0
        while True:
            # Stop once the classifier output is within t of the target
            # (torch.cdist here assumes a batch size of 1).
            if torch.cdist(output, target) < t:
                flag = True
                break

            data = to_var(data.clone())  # (batch_size, seq_length, latent_size)
            # requires_grad must be set so the latent receives a gradient.
            data.requires_grad = True
            output = model.forward(data)
            loss = dis_criterion(output, target)
            model.zero_grad()
            loss.backward()
            data_grad = data.grad.data
            data = data - epsilon * data_grad
            it += 1
            epsilon = epsilon * lambda_  # decay the step size
            if it > max_iter_per_epsilon:
                break

        generator_id = ae_model.greedy_decode(data, max_len=max_sequence_length, start_id=id_bos)
        generator_text = id2text_sentence(generator_id[0], id_to_word)
        print("| dis model pred {:5.4f} |".format(output[0].item()))
        print("z*:", generator_text)
        print()
        if flag:
            return generator_text
Example #10
    def beam_decode(self, latent, beam_size, max_len, start_id):
        '''
        latent: (batch_size, max_src_seq, d_model)
        returns: beam_size candidate sequences, sorted by score
        '''
        memory_beam = latent.detach().repeat(beam_size, 1, 1)  # (not used below)
        beam = Beam(beam_size=beam_size,
                    min_length=0,
                    n_top=beam_size,
                    ranker=None)  # (not used below)
        batch_size = latent.size(0)
        candidate = get_cuda(torch.zeros(beam_size, batch_size, max_len),
                             self.gpu)
        global_scores = get_cuda(torch.zeros(beam_size), self.gpu)

        tmp_cand = get_cuda(torch.zeros(beam_size * beam_size), self.gpu)
        tmp_scores = get_cuda(torch.zeros(beam_size * beam_size), self.gpu)

        ys = get_cuda(
            torch.ones(batch_size, 1).fill_(start_id).long(),
            self.gpu)  # (batch_size, 1)
        candidate[:, :, 0] = ys.clone()
        # First step: expand the BOS token into the top beam_size continuations.
        out = self.decode(latent.unsqueeze(1), to_var(ys, self.gpu),
                          to_var(subsequent_mask(ys.size(1)).long(), self.gpu))
        prob = self.generator(out[:, -1])
        scores, ids = prob.topk(k=beam_size, dim=1)  # shape: (1, beam_size)
        global_scores = scores.view(-1)
        candidate[:, :, 1] = ids.transpose(0, 1)
        for i in range(1, max_len - 1):
            # Expand each hypothesis by its beam_size best continuations.
            for j in range(beam_size):
                tmp = candidate[j, :, :i + 1].view(1, -1)
                tp, tc = self.recursive_beam(
                    beam_size, latent.unsqueeze(1),
                    to_var(tmp.long(), self.gpu),
                    to_var(subsequent_mask(tmp.size(1)).long(), self.gpu))
                tmp_cand[beam_size * j:beam_size * (j + 1)] = tc.view(-1)
                tmp_scores[beam_size * j:beam_size *
                           (j + 1)] = tp.view(-1) + global_scores[j]
            # Keep the best beam_size of the beam_size * beam_size expansions.
            beam_head_scores, beam_head_ids = tmp_scores.topk(k=beam_size,
                                                              dim=0)
            global_scores = beam_head_scores
            can_list = []
            for bb in range(beam_size):
                parent = beam_head_ids[bb].item() // beam_size
                can_list.append(
                    torch.cat([
                        candidate[parent, :, :i + 1].long(),
                        tmp_cand[beam_head_ids[bb]].long().unsqueeze(0).unsqueeze(0)
                    ], dim=1))
            for bb in range(beam_size):
                candidate[bb, :, :i + 2] = can_list[bb]
        # Order the finished candidates by their accumulated scores.
        top_s, top_i = global_scores.sort()
        candidate = candidate.view(beam_size, -1)
        candidate = candidate[:, 1:]  # drop the BOS column
        sorted_candidate = candidate.clone()
        for bb in range(beam_size):
            sorted_candidate[bb] = candidate[top_i[bb]]
        return sorted_candidate.long().view(beam_size, -1)
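The selection step flattens the beam_size × beam_size expansions into one score vector, takes a single top-k, and recovers each winner's parent hypothesis with an integer division. A toy illustration with made-up scores:

import torch

beam_size = 3
tmp_scores = torch.tensor([0.1, 0.9, 0.2,   # expansions of hypothesis 0
                           0.8, 0.3, 0.7,   # expansions of hypothesis 1
                           0.4, 0.6, 0.5])  # expansions of hypothesis 2
scores, ids = tmp_scores.topk(k=beam_size, dim=0)
parents = ids // beam_size  # which hypothesis each winner extends
print(scores.tolist(), parents.tolist())  # ~[0.9, 0.8, 0.7], parents [0, 1, 1]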
Example #11
def train_step(args, data_loader, ae_model, dis_model, ae_optimizer,
               dis_optimizer, ae_criterion, dis_criterion, epoch):
    ae_model.train()
    dis_model.train()

    loss_ae, n_words_ae, xe_loss_ae, n_valid_ae = 0, 0, 0, 0
    loss_clf, total_clf, n_valid_clf = 0, 0, 0
    epoch_start_time = time.time()
    for it in range(data_loader.num_batch):
        flag_rec = True
        batch_sentences, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, tensor_ntokens = data_loader.next_batch()

        # Forward pass
        latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                       tensor_src_attn_mask, tensor_tgt_mask)

        # Loss calculation
        if not args.sedat:
            loss_rec = ae_criterion(out.contiguous().view(-1, out.size(-1)),
                                    tensor_tgt_y.contiguous().view(-1)) / (
                                        tensor_ntokens.data + eps)

        else:
            # only on positive example
            positive_examples = tensor_labels.squeeze() == args.positive_label
            out = out[positive_examples]  # or out[positive_examples,:,:]
            tensor_tgt_y = tensor_tgt_y[
                positive_examples]  # or tensor_tgt_y[positive_examples,:]
            tensor_ntokens = (tensor_tgt_y != 0).data.sum().float()
            loss_rec = ae_criterion(out.contiguous().view(-1, out.size(-1)),
                                    tensor_tgt_y.contiguous().view(-1)) / (
                                        tensor_ntokens.data + eps)
            flag_rec = positive_examples.any()
            out = out.squeeze(0)
            tensor_tgt_y = tensor_tgt_y.squeeze(0)

        if flag_rec:
            n_v, n_w = get_n_v_w(tensor_tgt_y, out)
        else:
            n_w = float("nan")
            n_v = float("nan")

        x_e = loss_rec.item() * n_w
        loss_ae += loss_rec.item()
        n_words_ae += n_w
        xe_loss_ae += x_e
        n_valid_ae += n_v
        ae_acc = 100. * n_v / (n_w + eps)
        avg_ae_acc = 100. * n_valid_ae / (n_words_ae + eps)
        avg_ae_loss = loss_ae / (it + 1)
        ae_ppl = np.exp(x_e / (n_w + eps))
        avg_ae_ppl = np.exp(xe_loss_ae / (n_words_ae + eps))

        ae_optimizer.zero_grad()
        loss_rec.backward(retain_graph=not args.detach_classif)
        ae_optimizer.step()

        # Classifier
        if args.detach_classif:
            dis_lop = dis_model.forward(to_var(latent.clone()))
        else:
            dis_lop = dis_model.forward(latent)

        loss_dis = dis_criterion(dis_lop, tensor_labels)

        dis_optimizer.zero_grad()
        loss_dis.backward()
        dis_optimizer.step()

        t_c = tensor_labels.view(-1).size(0)
        n_v = (dis_lop.round().int() == tensor_labels).sum().item()
        loss_clf += loss_dis.item()
        total_clf += t_c
        n_valid_clf += n_v
        clf_acc = 100. * n_v / (t_c + eps)
        avg_clf_acc = 100. * n_valid_clf / (total_clf + eps)
        avg_clf_loss = loss_clf / (it + 1)
        if it % args.log_interval == 0:
            add_log(
                args, 'epoch {:3d} | {:5d}/{:5d} batches |'.format(
                    epoch, it, data_loader.num_batch))
            add_log(
                args,
                'Train : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                .format(ae_acc, loss_rec.item(), ae_ppl, clf_acc,
                        loss_dis.item()))
            add_log(
                args,
                'Train, avg : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} |  dis acc {:5.4f} | dis loss {:5.4f} |'
                .format(avg_ae_acc, avg_ae_loss, avg_ae_ppl, avg_clf_acc,
                        avg_clf_loss))
            if flag_rec:
                i = random.randint(0, len(tensor_tgt_y) - 1)
                reference = id2text_sentence(tensor_tgt_y[i], args.id_to_word)
                add_log(args, "input : %s" % reference)
                generator_text = ae_model.greedy_decode(
                    latent,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                # batch_sentences
                hypothesis = id2text_sentence(generator_text[i],
                                              args.id_to_word)
                add_log(args, "gen : %s" % hypothesis)
                add_log(
                    args, "bleu : %s" %
                    calc_bleu(reference.split(" "), hypothesis.split(" ")))

    s = {}
    L = data_loader.num_batch + eps
    s["train_ae_loss"] = loss_ae / L
    s["train_ae_acc"] = 100. * n_valid_ae / (n_words_ae + eps)
    s["train_ae_ppl"] = np.exp(xe_loss_ae / (n_words_ae + eps))
    s["train_clf_loss"] = loss_clf / L
    s["train_clf_acc"] = 100. * n_valid_clf / (total_clf + eps)

    add_log(
        args, '| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))

    add_log(
        args,
        '| rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
        .format(s["train_ae_acc"], s["train_ae_loss"], s["train_ae_ppl"],
                s["train_clf_acc"], s["train_clf_loss"]))
    return s
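The ppl numbers logged above are the exponential of the accumulated per-token cross-entropy, i.e. np.exp(xe_loss_ae / n_words_ae). A quick numeric illustration:

import numpy as np

xe_loss_ae, n_words_ae = 4156.2, 2000.0  # illustrative accumulator values
print(np.exp(xe_loss_ae / n_words_ae))   # ~7.99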
Example #12
def eval_step(args, data_loader, ae_model, dis_model, ae_criterion,
              dis_criterion):
    ae_model.eval()
    dis_model.eval()

    loss_ae, n_words_ae, xe_loss_ae, n_valid_ae = 0, 0, 0, 0
    loss_clf, total_clf, n_valid_clf = 0, 0, 0
    for it in range(data_loader.num_batch):
        flag_rec = True
        batch_sentences, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, tensor_ntokens = data_loader.next_batch()
        # Forward pass
        latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                       tensor_src_attn_mask, tensor_tgt_mask)
        # Loss calculation
        if not args.sedat:
            loss_rec = ae_criterion(out.contiguous().view(-1, out.size(-1)),
                                    tensor_tgt_y.contiguous().view(-1)) / (
                                        tensor_ntokens.data + eps)
        else:
            # only on positive example
            positive_examples = tensor_labels.squeeze() == args.positive_label
            out = out[positive_examples]  # or out[positive_examples,:,:]
            tensor_tgt_y = tensor_tgt_y[
                positive_examples]  # or tensor_tgt_y[positive_examples,:]
            tensor_ntokens = (tensor_tgt_y != 0).data.sum().float()
            loss_rec = ae_criterion(out.contiguous().view(-1, out.size(-1)),
                                    tensor_tgt_y.contiguous().view(-1)) / (
                                        tensor_ntokens.data + eps)
            flag_rec = positive_examples.any()
            out = out.squeeze(0)
            tensor_tgt_y = tensor_tgt_y.squeeze(0)

        if flag_rec:
            n_v, n_w = get_n_v_w(tensor_tgt_y, out)
        else:
            n_w = float("nan")
            n_v = float("nan")

        x_e = loss_rec.item() * n_w
        loss_ae += loss_rec.item()
        n_words_ae += n_w
        xe_loss_ae += x_e
        n_valid_ae += n_v

        # Classifier
        dis_lop = dis_model.forward(to_var(latent.clone()))
        loss_dis = dis_criterion(dis_lop, tensor_labels)

        t_c = tensor_labels.view(-1).size(0)
        n_v = (dis_lop.round().int() == tensor_labels).sum().item()
        loss_clf += loss_dis.item()
        total_clf += t_c
        n_valid_clf += n_v

    s = {}
    L = data_loader.num_batch + eps
    s["eval_ae_loss"] = loss_ae / L
    s["eval_ae_acc"] = 100. * n_valid_ae / (n_words_ae + eps)
    s["eval_ae_ppl"] = np.exp(xe_loss_ae / (n_words_ae + eps))
    s["eval_clf_loss"] = loss_clf / L
    s["eval_clf_acc"] = 100. * n_valid_clf / (total_clf + eps)

    add_log(
        args,
        'Val : rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
        .format(s["eval_ae_acc"], s["eval_ae_loss"], s["eval_ae_ppl"],
                s["eval_clf_acc"], s["eval_clf_loss"]))
    return s