예제 #1
0
def main(args):
    """Generate synthetic QA pairs by sampling the VAE prior over contexts.

    Each context batch is sampled ``args.k`` times; every decoded
    (context, question, answer) triple is written to
    ``<output_dir>/synthetic_qa.jsonl`` as one JSON object per line.
    """
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    data = CustomDatset(tokenizer, args.data_file, args.max_length)
    data_loader = DataLoader(data, shuffle=False, batch_size=args.batch_size)

    device = torch.cuda.current_device()
    # Load on CPU first so the checkpoint restores regardless of the device
    # it was saved from, then move the model to the current GPU.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")

    vae = DiscreteVAE(checkpoint["args"])
    vae.load_state_dict(checkpoint["state_dict"])
    vae.eval()
    vae = vae.to(device)

    if args.vocabulary:
        token_list = processing_vocabulary(args.vocabulary)
        print(len(tokenizer))
        tokenizer.add_tokens(token_list)
        print(len(tokenizer))
        # NOTE(review): the VAE embeddings are deliberately NOT resized to
        # the extended vocabulary (the resize code was previously disabled);
        # confirm the model never embeds the newly added token ids.

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    output_file = os.path.join(args.output_dir, "synthetic_qa.jsonl")

    # Context manager guarantees the file is closed even if decoding fails.
    with open(output_file, "w") as fw:
        for batch in tqdm(data_loader, total=len(data_loader)):
            c_ids = batch
            # Non-zero ids are real tokens, so the sign-sum is the true
            # length of each context; trim batch-level padding.
            c_len = torch.sum(torch.sign(c_ids), 1)
            max_c_len = torch.max(c_len)
            c_ids = c_ids[:, :max_c_len].to(device)

            # Sample the latent variables K times per context.
            for _ in range(args.k):
                with torch.no_grad():
                    _, _, zq, _, za = vae.prior_encoder(c_ids)
                    batch_q_ids, batch_start, batch_end = vae.generate(
                        zq, za, c_ids)

                for i in range(c_ids.size(0)):
                    _c_ids = c_ids[i].cpu().tolist()
                    q_ids = batch_q_ids[i].cpu().tolist()
                    start_pos = batch_start[i].item()
                    end_pos = batch_end[i].item()

                    # Answer span is inclusive of the end position.
                    a_ids = _c_ids[start_pos:end_pos + 1]
                    # Bug fix: "replace_speical_tokens" was misspelled for
                    # the question decode, so the kwarg was silently ignored
                    # there while applied to context/answer.
                    c_text = tokenizer.decode(_c_ids,
                                              replace_special_tokens=True)
                    q_text = tokenizer.decode(q_ids,
                                              replace_special_tokens=True)
                    a_text = tokenizer.decode(a_ids,
                                              replace_special_tokens=True)
                    json_dict = {
                        "context": c_text,
                        "question": q_text,
                        "answer": a_text
                    }
                    fw.write(json.dumps(json_dict) + "\n")
                    fw.flush()
예제 #2
0
class VAETrainer(object):
    """Wraps a DiscreteVAE together with its optimizer and exposes
    training, generation and full-state checkpointing helpers."""

    def __init__(self, args):
        self.args = args
        self.clip = args.clip
        self.device = args.device

        self.vae = DiscreteVAE(args).to(self.device)
        params = filter(lambda p: p.requires_grad, self.vae.parameters())
        self.optimizer = torch.optim.Adam(params, lr=args.lr)

        # Scalar snapshots of the last training step's loss components.
        self.loss_q_rec = 0
        self.loss_a_rec = 0
        self.loss_zq_kl = 0
        self.loss_za_kl = 0
        self.loss_info = 0

    def train(self, c_ids, q_ids, a_ids, start_positions, end_positions):
        """Run one forward/backward/step cycle and cache the loss terms."""
        self.vae = self.vae.train()

        # Forward
        loss, \
        loss_q_rec, loss_a_rec, \
        loss_zq_kl, loss_za_kl, \
        loss_info \
        = self.vae(c_ids, q_ids, a_ids, start_positions, end_positions)

        # Backward
        self.optimizer.zero_grad()
        loss.backward()

        # Step
        # NOTE(review): self.clip is stored in __init__ but gradients are
        # never clipped here — confirm whether clip_grad_norm_ was intended.
        self.optimizer.step()

        self.loss_q_rec = loss_q_rec.item()
        self.loss_a_rec = loss_a_rec.item()
        self.loss_zq_kl = loss_zq_kl.item()
        self.loss_za_kl = loss_za_kl.item()
        self.loss_info = loss_info.item()

    def generate_posterior(self, c_ids, q_ids, a_ids):
        """Sample posterior latents and decode a question + answer span."""
        self.vae = self.vae.eval()
        with torch.no_grad():
            _, _, zq, _, za = self.vae.posterior_encoder(c_ids, q_ids, a_ids)
            q_ids, start_positions, end_positions = self.vae.generate(
                zq, za, c_ids)
        return q_ids, start_positions, end_positions, zq

    def generate_answer_logits(self, c_ids, q_ids, a_ids):
        """Return start/end answer logits under the posterior latents."""
        self.vae = self.vae.eval()
        with torch.no_grad():
            _, _, zq, _, za = self.vae.posterior_encoder(c_ids, q_ids, a_ids)
            start_logits, end_logits = self.vae.return_answer_logits(
                zq, za, c_ids)
        return start_logits, end_logits

    def generate_prior(self, c_ids):
        """Sample prior latents and decode a question + answer span."""
        self.vae = self.vae.eval()
        with torch.no_grad():
            _, _, zq, _, za = self.vae.prior_encoder(c_ids)
            q_ids, start_positions, end_positions = self.vae.generate(
                zq, za, c_ids)
        return q_ids, start_positions, end_positions, zq

    # Saves the complete trainer state (model, optimizer, cached losses and
    # evaluation metrics), not just the model weights.
    def save(self, filename, epoch, f1, bleu, em):
        """Serialize the full trainer state to *filename*."""
        params = {
            'state_dict': self.vae.state_dict(),
            'args': self.args,
            'optimizer': self.optimizer.state_dict(),
            'loss_q_rec': self.loss_q_rec,
            'loss_a_rec': self.loss_a_rec,
            'loss_zq_kl': self.loss_zq_kl,
            'loss_za_kl': self.loss_za_kl,
            'loss_info': self.loss_info,
            'epoch': epoch,
            'f1': f1,
            'bleu': bleu,
            'em': em
        }
        torch.save(params, filename)

    def load(self, foldername):
        """Restore the trainer state saved by :meth:`save` from
        ``<foldername>/checkpoint.pt`` and return the stored epoch."""
        checkpoint = torch.load(f"{foldername}/checkpoint.pt")

        self.vae.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.loss_q_rec = checkpoint['loss_q_rec']
        self.loss_a_rec = checkpoint['loss_a_rec']
        self.loss_zq_kl = checkpoint['loss_zq_kl']
        self.loss_za_kl = checkpoint['loss_za_kl']
        self.loss_info = checkpoint['loss_info']

        print(f"Loading model trained in {checkpoint['epoch']} epochs.")

        return checkpoint['epoch']

    # Backward-compatible alias for the original (misspelled) method name.
    loadd = load

    @staticmethod
    def get_best_f1(foldername):
        """Read the F1 score stored in the best-F1 checkpoint."""
        return torch.load(f"{foldername}/best_f1_model.pt")['f1']

    @staticmethod
    def get_best_bleu(foldername):
        """Read the BLEU score stored in the best-BLEU checkpoint."""
        return torch.load(f"{foldername}/best_bleu_model.pt")['bleu']

    @staticmethod
    def get_best_em(foldername):
        """Read the EM score stored in the best-F1 checkpoint."""
        # NOTE(review): this reads best_f1_model.pt, not best_em_model.pt —
        # possibly a copy-paste slip; confirm which checkpoint should
        # provide the EM metric before changing it.
        return torch.load(f"{foldername}/best_f1_model.pt")['em']

    @staticmethod
    def load_measures(fn):
        """Return the (f1, bleu, em) tuple of the stored best checkpoints."""
        return VAETrainer.get_best_f1(fn), VAETrainer.get_best_bleu(
            fn), VAETrainer.get_best_em(fn)
예제 #3
0
class VAETrainer(object):
    """Trainer for a DiscreteVAE with a single categorical latent:
    optimization step, generation helpers and QA-span loss utilities."""

    def __init__(self, args):
        self.args = args
        self.clip = args.clip
        self.device = args.device

        self.vae = DiscreteVAE(args).to(self.device)
        # Bug fix: filter() returns a one-shot iterator. The Adam
        # constructor consumed it, so clip_grad_norm_(self.params, ...) in
        # train() iterated an already-exhausted iterator and clipped
        # nothing. Materialize the trainable parameters once.
        self.params = [p for p in self.vae.parameters() if p.requires_grad]
        self.optimizer = torch.optim.Adam(self.params, lr=args.lr)

        # Scalar snapshots of the last training step's loss components.
        self.loss_q_rec = 0
        self.loss_a_rec = 0
        self.loss_kl = 0
        self.loss_info = 0

    def train(self, c_ids, q_ids, a_ids, start_positions, end_positions):
        """Run one forward/backward/clip/step cycle and cache the losses."""
        self.vae = self.vae.train()

        # Forward
        loss, loss_q_rec, loss_a_rec, \
        loss_kl, loss_info = \
        self.vae(c_ids, q_ids, a_ids, start_positions, end_positions)

        # Backward
        self.optimizer.zero_grad()
        loss.backward()

        # Step (gradients are clipped to self.clip before the update)
        clip_grad_norm_(self.params, self.clip)
        self.optimizer.step()

        self.loss_q_rec = loss_q_rec.item()
        self.loss_a_rec = loss_a_rec.item()
        self.loss_kl = loss_kl.item()
        self.loss_info = loss_info.item()

    def generate_posterior(self, c_ids, q_ids, a_ids):
        """Sample the posterior latent and decode a question + answer span."""
        self.vae = self.vae.eval()
        with torch.no_grad():
            posterior_z_prob, posterior_z = self.vae.posterior_encoder(
                c_ids, q_ids, a_ids)
            q_ids, start_positions, end_positions = self.vae.generate(
                posterior_z, c_ids)
        return q_ids, start_positions, end_positions, posterior_z_prob

    def generate_prior(self, c_ids):
        """Sample the prior latent and decode a question + answer span."""
        self.vae = self.vae.eval()
        with torch.no_grad():
            prior_z_prob, prior_z = self.vae.prior_encoder(c_ids)
            q_ids, start_positions, end_positions = self.vae.generate(
                prior_z, c_ids)
        return q_ids, start_positions, end_positions, prior_z_prob

    def save(self, filename):
        """Serialize the model weights and construction args to *filename*."""
        params = {'state_dict': self.vae.state_dict(), 'args': self.args}
        torch.save(params, filename)

    def reduce_lr(self):
        """Halve the learning rate of the (single) optimizer param group."""
        self.optimizer.param_groups[0]['lr'] *= 0.5

    @staticmethod
    def post_process(q_ids, start_positions, end_positions, c_ids):
        """Build QA-model inputs by concatenating question + context.

        Returns padded ``input_ids``, segment ids (0 for question tokens,
        1 for context tokens), an attention mask, and the answer span
        positions shifted to index into the concatenated sequence.
        """
        batch_size = q_ids.size(0)
        # Exclude the CLS token in c_ids and shift the spans accordingly.
        c_ids = c_ids[:, 1:]
        start_positions = start_positions - 1
        end_positions = end_positions - 1

        q_mask, q_lengths = return_mask_lengths(q_ids)
        c_mask, c_lengths = return_mask_lengths(c_ids)

        total_max_len = torch.max(q_lengths + c_lengths)

        all_input_ids = []
        all_seg_ids = []
        for i in range(batch_size):
            q_length = q_lengths[i]
            c_length = c_lengths[i]
            q = q_ids[i, :q_length]  # exclude pad tokens
            c = c_ids[i, :c_length]  # exclude pad tokens

            # input ids: [question, context, padding]
            pads = torch.zeros((total_max_len - q_length - c_length),
                               device=q_ids.device,
                               dtype=torch.long)
            input_ids = torch.cat([q, c, pads], dim=0)
            all_input_ids.append(input_ids)

            # segment ids: question tokens -> 0, context tokens -> 1
            zeros = torch.zeros_like(q)
            ones = torch.ones_like(c)
            seg_ids = torch.cat([zeros, ones, pads], dim=0)
            all_seg_ids.append(seg_ids)

            # Span indices now point into the concatenated sequence.
            start_positions[i] = start_positions[i] + q_length
            end_positions[i] = end_positions[i] + q_length

        all_input_ids = torch.stack(all_input_ids, dim=0)
        all_seg_ids = torch.stack(all_seg_ids, dim=0)
        all_input_mask = (all_input_ids != 0).byte()

        return all_input_ids, all_seg_ids, all_input_mask, start_positions, end_positions

    @staticmethod
    def get_loss(start_logits, end_logits, start_positions, end_positions):
        """Average the start/end cross-entropy for an answer span.

        Positions outside the model input are clamped onto
        ``ignored_index`` and contribute zero loss (``ignore_index``).
        Returns the unreduced per-example loss tensor.
        """
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # Sometimes the start/end positions are outside our model inputs;
        # those terms are ignored.
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index,
                                       reduction="none")
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) * 0.5
        return total_loss
예제 #4
0
def main(args):
    """Build synthetic QA training features from raw contexts.

    Contexts are read from ``args.data_file`` (SQuAD format when
    ``args.squad`` is set), question/answer spans are sampled ``args.k``
    times from the VAE prior, and the resulting ``InputFeatures`` are
    pickled to ``args.output_file``.
    """
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    args.tokenizer = tokenizer

    device = torch.cuda.current_device()
    # Load on CPU first, then move the restored model to the current GPU.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    vae = DiscreteVAE(checkpoint["args"])
    vae.load_state_dict(checkpoint["state_dict"])
    vae.eval()
    vae = vae.to(device)

    # The two input formats differ only in the reader; feature conversion
    # is identical (previously duplicated in both branches).
    read_fn = read_squad_examples if args.squad else read_examples
    examples = read_fn(args.data_file, is_training=True, debug=args.debug)
    features = convert_examples_to_harv_features(
        examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_c_len,
        max_query_length=args.max_q_len,
        doc_stride=128,
        is_training=True)

    # Keep only the leading args.ratio fraction of the features.
    features = features[:int(len(features) * args.ratio)]
    all_c_ids = torch.tensor([f.c_ids for f in features], dtype=torch.long)
    data = TensorDataset(all_c_ids)
    data_loader = DataLoader(data, shuffle=False, batch_size=args.batch_size)

    new_features = []

    for batch in tqdm(data_loader, total=len(data_loader)):
        c_ids = batch[0]
        _, c_len = return_mask_lengths(c_ids)
        max_c_len = torch.max(c_len)
        # Trim batch-level padding before moving to the GPU.
        c_ids = c_ids[:, :max_c_len].to(device)

        # Sample the latent variables K times per context.
        for _ in range(args.k):
            with torch.no_grad():
                _, _, zq, _, za = vae.prior_encoder(c_ids)
                batch_q_ids, batch_start, batch_end = vae.generate(
                    zq, za, c_ids)

                all_input_ids, all_seg_ids, \
                all_input_mask, all_start, all_end = post_process(batch_q_ids, batch_start, batch_end, c_ids)

            for i in range(c_ids.size(0)):
                # Only the fields needed for QA training are filled in;
                # everything else is None by design.
                new_features.append(
                    InputFeatures(unique_id=None,
                                  example_index=None,
                                  doc_span_index=None,
                                  tokens=None,
                                  token_to_orig_map=None,
                                  token_is_max_context=None,
                                  input_ids=all_input_ids[i].cpu().tolist(),
                                  input_mask=all_input_mask[i].cpu().tolist(),
                                  c_ids=None,
                                  context_tokens=None,
                                  q_ids=None,
                                  q_tokens=None,
                                  answer_text=None,
                                  tag_ids=None,
                                  segment_ids=all_seg_ids[i].cpu().tolist(),
                                  noq_start_position=None,
                                  noq_end_position=None,
                                  start_position=all_start[i].cpu().tolist(),
                                  end_position=all_end[i].cpu().tolist(),
                                  is_impossible=None))

    # Robustness fix: os.path.dirname returns "" for a bare filename, and
    # os.makedirs("") raises — only create the directory when one is given.
    dir_name = os.path.dirname(args.output_file)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    with open(args.output_file, "wb") as f:
        pickle.dump(new_features, f)
0
파일: trainer.py 프로젝트: yyht/Info-HCVAE
class VAETrainer(object):
    """Thin training/inference wrapper around a DiscreteVAE model."""

    def __init__(self, args):
        self.args = args
        self.clip = args.clip
        self.device = args.device

        self.vae = DiscreteVAE(args).to(self.device)
        trainable = filter(lambda p: p.requires_grad, self.vae.parameters())
        self.optimizer = torch.optim.Adam(trainable, lr=args.lr)

        # Most recent loss components, updated by train().
        self.loss_q_rec = 0
        self.loss_a_rec = 0
        self.loss_zq_kl = 0
        self.loss_za_kl = 0
        self.loss_info = 0

    def train(self, c_ids, q_ids, a_ids, start_positions, end_positions):
        """Perform one optimization step and record the loss terms."""
        self.vae = self.vae.train()

        # Forward pass: the model returns the total loss followed by its
        # individual components.
        (loss, loss_q_rec, loss_a_rec, loss_zq_kl, loss_za_kl,
         loss_info) = self.vae(c_ids, q_ids, a_ids, start_positions,
                               end_positions)

        # Backward pass and parameter update.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Cache the scalar values for logging.
        for attr, value in (("loss_q_rec", loss_q_rec),
                            ("loss_a_rec", loss_a_rec),
                            ("loss_zq_kl", loss_zq_kl),
                            ("loss_za_kl", loss_za_kl),
                            ("loss_info", loss_info)):
            setattr(self, attr, value.item())

    def generate_posterior(self, c_ids, q_ids, a_ids):
        """Decode a question and answer span from the posterior latents."""
        self.vae = self.vae.eval()
        with torch.no_grad():
            _, _, z_q, _, z_a = self.vae.posterior_encoder(c_ids, q_ids,
                                                           a_ids)
            gen_q_ids, start_positions, end_positions = self.vae.generate(
                z_q, z_a, c_ids)
        return gen_q_ids, start_positions, end_positions, z_q

    def generate_answer_logits(self, c_ids, q_ids, a_ids):
        """Return start/end answer logits under the posterior latents."""
        self.vae = self.vae.eval()
        with torch.no_grad():
            _, _, z_q, _, z_a = self.vae.posterior_encoder(c_ids, q_ids,
                                                           a_ids)
            start_logits, end_logits = self.vae.return_answer_logits(
                z_q, z_a, c_ids)
        return start_logits, end_logits

    def generate_prior(self, c_ids):
        """Decode a question and answer span from the prior latents."""
        self.vae = self.vae.eval()
        with torch.no_grad():
            _, _, z_q, _, z_a = self.vae.prior_encoder(c_ids)
            gen_q_ids, start_positions, end_positions = self.vae.generate(
                z_q, z_a, c_ids)
        return gen_q_ids, start_positions, end_positions, z_q

    def save(self, filename):
        """Serialize the model weights and construction args to *filename*."""
        torch.save({'state_dict': self.vae.state_dict(),
                    'args': self.args}, filename)