Example #1
    def forward(self, input_sequence, length):

        batch_size = input_sequence.size(0)
        sorted_lengths, sorted_idx = torch.sort(length, descending=True)
        input_sequence = input_sequence[sorted_idx]

        # ENCODER
        input_embedding = self.embedding(input_sequence)

        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        _, hidden = self.encoder_rnn(packed_input)

        if self.bidirectional or self.num_layers > 1:
            # flatten hidden state
            hidden = hidden.view(batch_size, self.hidden_size*self.hidden_factor)
        else:
            hidden = hidden.squeeze()

        # REPARAMETERIZATION
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)

        z = to_device(torch.randn([batch_size, self.latent_size]))
        z = z * std + mean

        # DECODER
        hidden = self.latent2hidden(z)

        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state
            hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)
        else:
            hidden = hidden.unsqueeze(0)

        # decoder input
        if self.word_dropout_rate > 0:
            # randomly replace decoder input with <unk>
            prob = torch.rand(input_sequence.size())
            if torch.cuda.is_available():
                prob = prob.cuda()
            prob[(input_sequence.data - self.sos_idx) * (input_sequence.data - self.pad_idx) == 0] = 1
            decoder_input_sequence = input_sequence.clone()
            decoder_input_sequence[prob < self.word_dropout_rate] = self.unk_idx
            input_embedding = self.embedding(decoder_input_sequence)
        input_embedding = self.embedding_dropout(input_embedding)
        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        # decoder forward pass
        outputs, _ = self.decoder_rnn(packed_input, hidden)

        # process outputs
        padded_outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
        padded_outputs = padded_outputs.contiguous()
        _, reversed_idx = torch.sort(sorted_idx)
        padded_outputs = padded_outputs[reversed_idx]
        b, s, _ = padded_outputs.size()

        # project outputs to vocab
        logp = nn.functional.log_softmax(self.outputs2vocab(padded_outputs.view(-1, padded_outputs.size(2))), dim=-1)
        logp = logp.view(b, s, self.embedding.num_embeddings)


        return logp, mean, logv, None, z, None
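
The length sorting at the top and the reindexing with reversed_idx at the bottom exist to satisfy pack_padded_sequence, which in the PyTorch versions this code targets expects the batch ordered by decreasing length. A minimal, self-contained sketch of that round trip (shapes are illustrative, not from the example):

import torch
import torch.nn.utils.rnn as rnn_utils

lengths = torch.tensor([3, 5, 2])
sorted_lengths, sorted_idx = torch.sort(lengths, descending=True)

batch = torch.randn(3, 5, 8)              # (batch, max_len, features), made-up sizes
batch = batch[sorted_idx]                 # reorder longest-first before packing

packed = rnn_utils.pack_padded_sequence(
    batch, sorted_lengths.tolist(), batch_first=True)
padded, _ = rnn_utils.pad_packed_sequence(packed, batch_first=True)

_, reversed_idx = torch.sort(sorted_idx)  # undo the sort, as the forward pass does
padded = padded[reversed_idx]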
Example #2
def train_and_eval_cvae(args):
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # output folder
    if args.pickle is not None:
        # note: rstrip('.pkl') strips characters, not a suffix; drop the extension instead
        pickle_path = Path(args.pickle).with_suffix('')
        pickle_name = pickle_path.stem
        run_dir = pickle_path
    else:
        output_dir = Path(args.output_folder)
        if not output_dir.exists():
            output_dir.mkdir()
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        run_dir = output_dir / current_time
    if not run_dir.exists():
        run_dir.mkdir()

    # data handling
    if args.cosine_threshold is not None and args.none_intents is not None:
        raise ValueError("None intents cannot be specified while using a "
                         "cosine similarity selection")
    data_folder = Path(args.data_folder)
    dataset_folder = data_folder / args.dataset_type
    none_folder = data_folder / args.none_type
    none_idx = NONE_COLUMN_MAPPING[args.none_type]

    dataset = create_dataset(
        dataset_type=args.dataset_type,
        dataset_folder=dataset_folder,
        dataset_size=args.dataset_size,
        restrict_intents=args.restrict_intents,
        none_folder=none_folder,
        none_size=args.none_size,
        none_intents=args.none_intents,
        none_idx=none_idx,
        infersent_selection=args.infersent_selection,
        cosine_threshold=args.cosine_threshold,
        input_type=args.input_type,
        tokenizer_type=args.tokenizer_type,
        preprocessing_type=args.preprocessing_type,
        max_sequence_length=args.max_sequence_length,
        embedding_type=args.embedding_type,
        embedding_dimension=args.embedding_dimension,
        max_vocab_size=args.max_vocab_size,
        slot_embedding=args.slot_embedding,
        run_dir=run_dir
    )

    if args.load_folder:
        original_vocab_size = dataset.update(args.load_folder)
        LOGGER.info('Loaded vocab from %s' % args.load_folder)

    # training
    if args.conditioning == NO_CONDITIONING:
        args.conditioning = None

    if not args.load_folder:
        model = CVAE(
            conditional=args.conditioning,
            compute_bow=args.bow_loss,
            vocab_size=dataset.vocab_size,
            embedding_size=args.embedding_dimension,
            rnn_type=args.rnn_type,
            hidden_size_encoder=args.hidden_size_encoder,
            hidden_size_decoder=args.hidden_size_decoder,
            word_dropout_rate=args.word_dropout_rate,
            embedding_dropout_rate=args.embedding_dropout_rate,
            z_size=args.latent_size,
            n_classes=dataset.n_classes,
            cat_size=dataset.n_classes if args.cat_size is None else args.cat_size,
            sos_idx=dataset.sos_idx,
            eos_idx=dataset.eos_idx,
            pad_idx=dataset.pad_idx,
            unk_idx=dataset.unk_idx,
            max_sequence_length=args.max_sequence_length,
            num_layers_encoder=args.num_layers_encoder,
            num_layers_decoder=args.num_layers_decoder,
            bidirectional=args.bidirectional,
            temperature=args.temperature,
            force_cpu=args.force_cpu
        )
    else:
        model = CVAE.from_folder(args.load_folder)
        LOGGER.info('Loaded model from %s' % args.load_folder)
        model.n_classes = dataset.n_classes
        model.update_embedding(dataset.vectors)
        model.update_outputs2vocab(original_vocab_size, dataset.vocab_size)

    model = to_device(model, args.force_cpu)
    parameters = filter(lambda p: p.requires_grad, model.parameters())

    optimizer = getattr(torch.optim, args.optimizer_type)(
        parameters,
        lr=args.learning_rate
    )

    trainer = Trainer(
        dataset,
        model,
        optimizer,
        batch_size=args.batch_size,
        annealing_strategy=args.annealing_strategy,
        kl_anneal_rate=args.kl_anneal_rate,
        kl_anneal_time=args.kl_anneal_time,
        kl_anneal_target=args.kl_anneal_target,
        label_anneal_rate=args.label_anneal_rate,
        label_anneal_time=args.label_anneal_time,
        label_anneal_target=args.label_anneal_target,
        add_bow_loss=args.bow_loss,
        force_cpu=args.force_cpu,
        run_dir=run_dir / "tensorboard",
        alpha=args.alpha
    )

    trainer.run(args.n_epochs, dev_step_every_n_epochs=1)

    if args.pickle is not None:
        model_path = run_dir / "{}_load".format(pickle_name)
    else:
        model_path = run_dir / "load"
    dataset.save(model_path)
    model.save(model_path)

    # evaluation
    run_dict = dict()

    # generate queries
    generated_sentences, logp = generate_vae_sentences(
        model=model,
        n_to_generate=args.n_generated,
        input_type=args.input_type,
        i2int=dataset.i2int,
        i2w=dataset.i2w,
        eos_idx=dataset.eos_idx,
        slotdic=dataset.slotdic if args.input_type == 'delexicalised' else None,
        verbose=True
    )
    run_dict['generated'] = generated_sentences
    run_dict['metrics'] = compute_generation_metrics(
        dataset,
        generated_sentences['utterances'],
        generated_sentences['intents'],
        logp
    )
    for k, v in run_dict['metrics'].items():
        LOGGER.info((k, v))

    if args.input_type == "delexicalised":
        run_dict['delexicalised_metrics'] = compute_generation_metrics(
            dataset,
            generated_sentences['delexicalised'],
            generated_sentences['intents'],
            logp,
            input_type='delexicalised'
        )
        for k, v in run_dict['delexicalised_metrics'].items():
            LOGGER.info((k, v))

    save_augmented_dataset(generated_sentences, args.n_generated,
                           dataset.train_path, run_dir)

    run_dict['args'] = vars(args)
    run_dict['logs'] = trainer.run_logs
    run_dict['latent_rep'] = trainer.latent_rep
    run_dict['i2w'] = dataset.i2w
    run_dict['w2i'] = dataset.w2i
    run_dict['i2int'] = dataset.i2int
    run_dict['int2i'] = dataset.int2i
    run_dict['vectors'] = {
        'before': dataset.vocab.vectors,
        'after': model.embedding.weight.data
    }

    if args.pickle is not None:
        run_dict_path = run_dir.parents[0] / "{}.pkl".format(pickle_name)
    else:
        run_dict_path = run_dir / "run.pkl"
    torch.save(run_dict, str(run_dict_path))
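
One detail worth isolating from this function: getattr(torch.optim, args.optimizer_type) resolves the optimizer class by name, so any class in torch.optim can be chosen from the command line. A minimal sketch of the same pattern with a stand-in model and a hypothetical optimizer name:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                   # stand-in model
optimizer_type = "Adam"                   # e.g. the value of args.optimizer_type
optimizer = getattr(torch.optim, optimizer_type)(model.parameters(), lr=1e-3)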
Example #3
    def inference(self, n=4, z=None):

        if z is None:
            batch_size = n
            z = to_device(torch.randn([batch_size, self.latent_size]))
        else:
            batch_size = z.size(0)

        hidden = self.latent2hidden(z)

        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state
            hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)
        else:
            hidden = hidden.unsqueeze(0)

        # required for dynamic stopping of sentence generation
        sequence_idx = torch.arange(0, batch_size, out=self.tensor()).long() # all idx of batch
        sequence_running = torch.arange(0, batch_size, out=self.tensor()).long() # all idx of batch which are still generating
        sequence_mask = torch.ones(batch_size, out=self.tensor()).byte()

        running_seqs = torch.arange(0, batch_size, out=self.tensor()).long() # idx of still generating sequences with respect to current loop

        generations = self.tensor(batch_size, self.max_sequence_length).fill_(self.pad_idx).long()

        t = 0
        while t < self.max_sequence_length and len(running_seqs) > 0:

            if t == 0:
                input_sequence = to_device(torch.Tensor(batch_size).fill_(self.sos_idx).long())

            input_sequence = input_sequence.unsqueeze(1)

            input_embedding = self.embedding(input_sequence)

            output, hidden = self.decoder_rnn(input_embedding, hidden)

            logits = self.outputs2vocab(output)

            input_sequence = self._sample(logits)

            # save next input
            generations = self._save_sample(generations, input_sequence, sequence_running, t)

            # update global running sequence
            sequence_mask[sequence_running] = (input_sequence != self.eos_idx).data
            sequence_running = sequence_idx.masked_select(sequence_mask)

            # update local running sequences
            running_mask = (input_sequence != self.eos_idx).data
            running_seqs = running_seqs.masked_select(running_mask)

            # prune input and hidden state according to local update
            if len(running_seqs) > 0:
                input_sequence = input_sequence[running_seqs]
                hidden = hidden[:, running_seqs]

                running_seqs = torch.arange(0, len(running_seqs), out=self.tensor()).long()

            t += 1

        return generations, z, None
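
The dynamic stopping in this loop works by masking out finished sequences: sequence_mask records, per batch element, whether <eos> has been produced, and masked_select keeps only the indices that are still generating. A tiny self-contained illustration of that bookkeeping (eos_idx and the sampled tokens are made up):

import torch

eos_idx = 2
batch_size = 4
sequence_idx = torch.arange(batch_size)
sequence_mask = torch.ones(batch_size, dtype=torch.bool)

sampled = torch.tensor([5, 2, 7, 2])      # hypothetical tokens sampled at one step
sequence_mask = sequence_mask & (sampled != eos_idx)
sequence_running = sequence_idx.masked_select(sequence_mask)
# sequence_running is now tensor([0, 2]): only those elements keep generating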
Example #4
def train(model, datasets, args):

    train_iter, val_iter = datasets.get_iterators(batch_size=args.batch_size)

    opt = getattr(torch.optim, args.optimizer)(model.parameters(),
                                               lr=args.learning_rate)
    # opt = torch.optim.Adam([
    #     {"params": model.encoder_rnn.parameters(), "lr": args.learning_rate},
    #     {"params": model.hidden2mean.parameters(), "lr": args.learning_rate},
    #     {"params": model.hidden2logv.parameters(), "lr": args.learning_rate},
    #     {"params": model.hidden2cat.parameters(),  "lr": args.learning_rate},
    #     {"params": model.latent2hidden.parameters(), "lr": args.learning_rate},
    #     {"params": model.latent2bow.parameters(), "lr": args.learning_rate},
    #     {"params": model.outputs2vocab.parameters(), "lr": args.learning_rate}])

    step = 0

    NLL_hist = []
    KL_hist = []
    BOW_hist = []
    NMI_hist = []
    acc_hist = []

    latent_rep = {i: [] for i in range(model.n_classes)}

    for epoch in range(1, args.epochs + 1):
        tr_loss = 0.0
        NLL_tr_loss = 0.0
        KL_tr_loss = 0.0
        BOW_tr_loss = 0.0
        NMI_tr = 0.0
        n_correct_tr = 0.0
        acc_tr = 0.0

        model.train()  # turn on training mode
        for iteration, batch in enumerate(tqdm(train_iter)):
            step += 1
            opt.zero_grad()
            # model.word_dropout_rate =  anneal_fn(args.anneal_function, step, args.k3, args.x3, args.m3)

            x, lengths = getattr(batch, args.input_type)
            input = x[:, :-1]  # remove <eos>
            target = x[:, 1:]  # remove <sos>
            lengths -= 1  # account for the removal
            input, target = to_device(input), to_device(target)
            if args.conditional != 'none':
                y = batch.intent.squeeze()
                y = to_device(y)
                sorted_lengths, sorted_idx = torch.sort(lengths,
                                                        descending=True)
                y = y[sorted_idx]

            logp, mean, logv, logc, z, bow = model(input, lengths)
            if epoch == args.epochs and args.conditional != 'none':
                for i, intent in enumerate(y):
                    latent_rep[int(intent)].append(z[i].cpu().detach().numpy())

            # loss calculation
            NLL_loss, KL_losses, KL_weight, BOW_loss = loss_fn(
                logp, bow, target, lengths, mean, logv, args.anneal_function,
                step, args.k1, args.x1, args.m1)
            KL_loss = torch.sum(KL_losses)
            NLL_hist.append(NLL_loss.detach().cpu().numpy() / args.batch_size)
            KL_hist.append(KL_losses.detach().cpu().numpy() / args.batch_size)
            BOW_hist.append(BOW_loss.detach().cpu().numpy() / args.batch_size)
            # the label loss is added below only when a conditional model is used
            loss = NLL_loss + KL_weight * KL_loss  # /args.batch_size

            if args.bow_loss:
                loss += BOW_loss

            if args.conditional == 'none':
                pred_labels = 0
                n_correct = 0
                NMI = 0
            else:
                if args.conditional == 'supervised':
                    label_loss, label_weight = loss_labels(
                        logc, y, args.anneal_function, step, args.k2, args.x2,
                        args.m2)
                    loss += label_weight * label_loss
                elif args.conditional == 'unsupervised':
                    entropy = torch.sum(
                        torch.exp(logc) *
                        torch.log(model.n_classes * torch.exp(logc)))
                    loss += entropy
                pred_labels = logc.data.max(1)[1].long()
                n_correct = pred_labels.eq(y.data).cpu().sum().float().item()
                acc_hist.append(n_correct / args.batch_size)
                NMI = normalized_mutual_info_score(
                    y.cpu().detach().numpy(),
                    torch.exp(logc).cpu().max(1)[1].numpy())
                NMI_hist.append(NMI)

            loss.backward()
            # CLIPPING
            # for p in model.parameters():
            #     p.register_hook(lambda grad: torch.clamp(grad, -1, 1))
            # torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1)
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            opt.step()

            tr_loss += loss.item()
            NLL_tr_loss += NLL_loss.item()
            KL_tr_loss += KL_loss.item()
            BOW_tr_loss += BOW_loss.item()
            NMI_tr += NMI
            n_correct_tr += n_correct

            # if iteration % 100 == 0:
            #     print("Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
            #               %(loss.data, NLL_loss.item()/args.batch_size, KL_loss.item()/args.batch_size, KL_weight))
            #     x_sentences = input[:3].cpu().numpy()
            #     print('\nInput sentences :')
            #     print(*idx2word(x_sentences, i2w=i2w, eos_idx=eos_idx), sep='\n')
            #     _, y_sentences = torch.topk(logp, 1, dim=-1)
            #     y_sentences = y_sentences[:3].squeeze().cpu().numpy()
            #     print('\nOutput sentences : ')
            #     print(*idx2word(y_sentences, i2w=i2w, eos_idx=eos_idx), sep='\n')
            #     print('\n')

        tr_loss = tr_loss / len(datasets.train)
        NLL_tr_loss = NLL_tr_loss / len(datasets.train)
        KL_tr_loss = KL_tr_loss / len(datasets.train)
        BOW_tr_loss = BOW_tr_loss / len(datasets.train)
        NMI_tr = NMI_tr / len(datasets.train)
        acc_tr = n_correct_tr / len(datasets.train)

        # calculate the validation loss for this epoch
        val_loss = 0.0
        NLL_val_loss = 0.0
        KL_val_loss = 0.0
        BOW_val_loss = 0.0
        NMI_val = 0.0
        n_correct_val = 0.0
        acc_val = 0.0

        model.eval()  # turn on evaluation mode
        for batch in tqdm(val_iter):
            x, lengths = getattr(batch, args.input_type)
            target = x[:, 1:]  # remove <sos>
            input = x[:, :-1]  # remove <eos>
            lengths -= 1  # account for the removal
            input, target = to_device(input), to_device(target)
            if args.conditional != 'none':
                y = batch.intent.squeeze()
                y = to_device(y)
                sorted_lengths, sorted_idx = torch.sort(lengths,
                                                        descending=True)
                y = y[sorted_idx]

            logp, mean, logv, logc, z, bow = model(input, lengths)
            # loss calculation
            NLL_loss, KL_losses, KL_weight, BOW_loss = loss_fn(
                logp, bow, target, lengths, mean, logv, args.anneal_function,
                step, args.k1, args.x1, args.m1)

            KL_loss = torch.sum(KL_losses)
            loss = (NLL_loss + KL_weight * KL_loss)  #/args.batch_size
            if args.bow_loss:
                loss += BOW_loss

            if args.conditional == 'none':
                pred_labels = 0
                n_correct = 0
                NMI = 0
            else:
                if args.conditional == 'supervised':
                    label_loss, label_weight = loss_labels(
                        logc, y, args.anneal_function, step, args.k2, args.x2,
                        args.m2)
                    loss += label_weight * label_loss
                elif args.conditional == 'unsupervised':
                    entropy = torch.sum(
                        torch.exp(logc) *
                        torch.log(model.n_classes * torch.exp(logc)))
                    loss += entropy
                pred_labels = logc.data.max(1)[1].long()
                n_correct = pred_labels.eq(y.data).cpu().sum().float().item()
                NMI = normalized_mutual_info_score(
                    y.cpu().detach().numpy(),
                    torch.exp(logc).cpu().max(1)[1].numpy())

            val_loss += loss.item()
            NLL_val_loss += NLL_loss.item()
            KL_val_loss += KL_loss.item()
            BOW_val_loss += BOW_loss.item()
            NMI_val += NMI
            n_correct_val += n_correct

        val_loss = val_loss / len(datasets.valid)
        NLL_val_loss = NLL_val_loss / len(datasets.valid)
        KL_val_loss = KL_val_loss / len(datasets.valid)
        BOW_val_loss = BOW_val_loss / len(datasets.valid)
        NMI_val = NMI_val / len(datasets.valid)
        acc_val = n_correct_val / len(datasets.valid)

        print('Epoch {} : train {:.6f} valid {:.6f}'.format(
            epoch, tr_loss, val_loss))
        print(
            'Training   :  NLL loss : {:.6f}, KL loss : {:.6f}, BOW : {:.6f}, acc : {:.6f}'
            .format(NLL_tr_loss, KL_tr_loss, BOW_tr_loss, acc_tr))
        print(
            'Validation :  NLL loss : {:.6f}, KL loss : {:.6f}, BOW : {:.6f}, acc : {:.6f}'
            .format(NLL_val_loss, KL_val_loss, BOW_val_loss, acc_val))

    run = dict()
    run['NLL_hist'] = NLL_hist
    run['KL_hist'] = KL_hist
    run['NLL_val'] = NLL_val_loss
    run['KL_val'] = KL_val_loss
    run['NMI_hist'] = NMI_hist
    run['acc_hist'] = acc_hist
    run['latent'] = latent_rep

    return run
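
loss_fn itself is not shown in this example, but for a diagonal Gaussian posterior the KL_losses it returns are normally the standard closed-form KL divergence against a unit Gaussian prior, computed from mean and logv. A sketch of that formula only, not of the project's actual loss_fn:

import torch

def gaussian_kl(mean, logv):
    # element-wise KL( N(mean, exp(logv)) || N(0, I) )
    return -0.5 * (1 + logv - mean.pow(2) - logv.exp())

mean = torch.zeros(8, 16)
logv = torch.zeros(8, 16)
kl = gaussian_kl(mean, logv).sum()        # zero when the posterior matches the prior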
Example #5
    def do_one_sweep(self, iter, is_last_epoch, train_or_dev):
        if train_or_dev not in ['train', 'dev']:
            raise TypeError("train_or_dev should be either train or dev")

        if train_or_dev == "train":
            self.model.train()
        else:
            self.model.eval()

        sweep_loss = 0
        sweep_recon_loss = 0
        sweep_kl_loss = 0
        sweep_accuracy = 0
        n_batches = 0
        for iteration, batch in enumerate(tqdm(iter)):
            # if len(batch) < self.batch_size and :
            #     continue
            if train_or_dev == "train":
                self.step += 1
                self.optimizer.zero_grad()

            # forward pass
            x, lengths = getattr(batch, self.dataset.input_type)
            input = x[:, :-1]  # remove <eos>
            target = x[:, 1:]  # remove <sos>
            lengths -= 1  # account for the removal
            input, target = to_device(input, self.force_cpu), to_device(
                target, self.force_cpu)

            y = None
            if self.model.conditional is not None:
                y = batch.intent.squeeze()
                y = to_device(y, self.force_cpu)
                sorted_lengths, sorted_idx = torch.sort(lengths,
                                                        descending=True)
                y = y[sorted_idx]

            logp, mean, logv, logc, z, bow = self.model(input, lengths)

            if is_last_epoch:
                _, reversed_idx = torch.sort(sorted_idx)
                y = y[reversed_idx]
                logc = logc[reversed_idx]
                real_labels = [self.i2int[label] for label in y]
                pred_labels = [
                    self.i2int[label] if label < len(self.i2int) else 'None'
                    for label in logc.max(1)[1]
                ]
                for real_label, pred_label in zip(real_labels, pred_labels):
                    self.run_logs[train_or_dev]['classifications'][real_label][
                        pred_label] += 1
                for real_label in real_labels:
                    self.run_logs[train_or_dev]['transfer'][
                        real_label] += logc.sum(dim=0).cpu().detach()

                # save latent representation
                if train_or_dev == "train" and self.model.conditional:
                    for i, intent in enumerate(y):
                        self.latent_rep[self.i2int[intent]].append(
                            z[i].cpu().detach().numpy())

            # loss calculation
            loss, recon_loss, kl_loss, accuracy = self.compute_loss(
                logp, bow, target, lengths, mean, logv, logc, y, train_or_dev)

            sweep_loss += loss
            sweep_recon_loss += recon_loss
            sweep_kl_loss += kl_loss
            sweep_accuracy += accuracy

            n_batches += 1
            if train_or_dev == "train":
                loss.backward()
                self.optimizer.step()

        if is_last_epoch:
            for intent1 in self.i2int:
                n_sentences = sum(self.run_logs[train_or_dev]
                                  ['classifications'][intent1].values())
                self.run_logs[train_or_dev]['transfer'][intent1] /= n_sentences
                for intent2 in self.i2int:
                    self.run_logs[train_or_dev]['classifications'][intent1][
                        intent2] /= n_sentences

        return sweep_loss / n_batches, sweep_recon_loss / n_batches, \
               sweep_kl_loss / n_batches, sweep_accuracy / n_batches
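
The run_logs[...]['classifications'] structure filled in above behaves like a confusion matrix of real intent vs. predicted intent, row-normalised on the last epoch. A self-contained sketch of that accumulation and normalisation, with hypothetical intent names:

from collections import defaultdict

classifications = defaultdict(lambda: defaultdict(int))
pairs = [("PlayMusic", "PlayMusic"),
         ("PlayMusic", "BookRestaurant"),
         ("GetWeather", "GetWeather")]    # hypothetical (real, predicted) label pairs
for real_label, pred_label in pairs:
    classifications[real_label][pred_label] += 1

for intent, preds in classifications.items():
    n_sentences = sum(preds.values())
    for pred in preds:
        preds[pred] /= n_sentences        # row-normalise, as in the last-epoch block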
Example #6
    def inference(self, n=10, z=None, y_onehot=None, temperature=0):

        if z is None:
            batch_size = n
            z = torch.randn(batch_size, self.z_size)
        else:
            batch_size = z.size(0)

        if self.conditional is not None:
            if y_onehot is None:
                y = torch.LongTensor(batch_size, 1).random_() % self.n_classes
                y_onehot = torch.FloatTensor(batch_size, self.cat_size)
                y_onehot.fill_(0)
                y_onehot.scatter_(dim=1, index=y, value=1)
            latent = to_device(torch.cat((z, y_onehot), dim=1), self.force_cpu)
        else:
            y_onehot = None
            latent = to_device(z, self.force_cpu)

        hidden = self.latent2hidden(latent)

        if self.bidirectional or self.num_layers_decoder > 1:
            # unflatten hidden state
            hidden = hidden.view(self.num_layers_decoder, batch_size,
                                 self.hidden_size)
        else:
            hidden = hidden.unsqueeze(0)

        # required for dynamic stopping of sentence generation
        sequence_idx = torch.arange(
            0, batch_size, out=self.tensor()).long()  # all idx of batch
        sequence_running = torch.arange(
            0, batch_size, out=self.tensor()).long()  # all idx of batch
        # which are still generating
        sequence_mask = torch.ones(batch_size, out=self.tensor()).byte()

        running_seqs = torch.arange(0, batch_size,
                                    out=self.tensor()).long()  # idx of still
        # generating sequences with respect to current loop

        generations = self.tensor(batch_size, self.max_sequence_length).fill_(
            self.pad_idx).long()

        t = 0
        while t < self.max_sequence_length and len(running_seqs) > 0:
            if t == 0:
                input_sequence = torch.Tensor(batch_size).fill_(
                    self.sos_idx).long()
                # input_sequence = torch.randint(0, self.vocab_size,
                # (batch_size,))

            input_sequence = to_device(input_sequence.unsqueeze(1),
                                       self.force_cpu)

            input_embedding = self.embedding(input_sequence)
            output, hidden = self.decoder_rnn(input_embedding, hidden)

            logits = self.outputs2vocab(output)
            logp = nn.functional.log_softmax(logits / self.temperature, dim=-1)

            input_sequence = self._sample(logits)

            # save next input
            generations = self._save_sample(generations, input_sequence,
                                            sequence_running, t)

            # update global running sequence
            sequence_mask[sequence_running] = (input_sequence !=
                                               self.eos_idx).data
            sequence_running = sequence_idx.masked_select(sequence_mask)

            # update local running sequences
            running_mask = (input_sequence != self.eos_idx).data
            running_seqs = running_seqs.masked_select(running_mask)

            # prune input and hidden state according to local update
            if len(running_seqs) > 0:
                try:
                    input_sequence = input_sequence[running_seqs]
                except Exception:
                    # defensive: stop generating if the pruning indices are stale
                    break
                hidden = hidden[:, running_seqs]

                running_seqs = torch.arange(0,
                                            len(running_seqs),
                                            out=self.tensor()).long()
            t += 1

        return generations, z, y_onehot, logp
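
When no class vector is supplied, the method draws a random class id per sample and expands it into a one-hot row with scatter_. That construction in isolation (sizes are illustrative):

import torch

batch_size, n_classes = 4, 7
y = torch.LongTensor(batch_size, 1).random_() % n_classes  # random class ids, shape (4, 1)
y_onehot = torch.zeros(batch_size, n_classes)
y_onehot.scatter_(dim=1, index=y, value=1)                  # one one-hot row per sample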
Example #7
    def forward(self, input_sequence, lengths):
        batch_size = input_sequence.size(0)
        sorted_lengths, sorted_idx = torch.sort(lengths, descending=True)
        input_sequence = input_sequence[sorted_idx]

        # ENCODER
        input_embedding = self.embedding(input_sequence)
        packed_input = rnn_utils.pack_padded_sequence(
            input_embedding, sorted_lengths.data.tolist(), batch_first=True)
        _, hidden = self.encoder_rnn(packed_input)

        if self.bidirectional or self.num_layers_encoder > 1:
            # flatten hidden state
            hidden = hidden.view(
                batch_size,
                self.hidden_size_encoder * self.hidden_factor_encoder)
        else:
            hidden = hidden.squeeze()

        # REPARAMETERIZATION
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)
        z = to_device(torch.randn(batch_size, self.z_size), self.force_cpu)
        z = z * std + mean

        if self.conditional is not None:
            logc = nn.functional.log_softmax(self.hidden2cat(hidden), dim=-1)
            y_onehot = nn.functional.gumbel_softmax(logc)
            latent = torch.cat((z, y_onehot), dim=-1)
        else:
            logc = None
            latent = z

        # DECODER
        hidden = self.latent2hidden(latent)

        if self.bidirectional or self.num_layers_decoder > 1:
            # unflatten hidden state
            hidden = hidden.view(self.num_layers_decoder, batch_size,
                                 self.hidden_size_decoder)
        else:
            hidden = hidden.unsqueeze(0)

        # decoder input
        if self.word_dropout_rate > 0:
            # randomly replace decoder input with <unk>
            prob = torch.rand(input_sequence.size())
            prob = to_device(prob)
            prob[(input_sequence.data - self.sos_idx) *
                 (input_sequence.data - self.pad_idx) == 0] = 1
            decoder_input_sequence = input_sequence.clone()
            decoder_input_sequence[
                prob < self.word_dropout_rate] = self.unk_idx
            input_embedding = self.embedding(decoder_input_sequence)
        input_embedding = self.embedding_dropout(input_embedding)

        packed_input = rnn_utils.pack_padded_sequence(
            input_embedding, sorted_lengths.data.tolist(), batch_first=True)
        outputs, _ = self.decoder_rnn(packed_input, hidden)

        # process outputs
        padded_outputs = \
            rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
        padded_outputs = padded_outputs.contiguous()
        _, reversed_idx = torch.sort(sorted_idx)
        padded_outputs = padded_outputs[reversed_idx]
        bs, seqlen, hs = padded_outputs.size()

        logits = self.outputs2vocab(padded_outputs.view(-1, hs))
        logp = nn.functional.log_softmax(logits / self.temperature, dim=-1)
        logp = logp.view(bs, seqlen, self.embedding.num_embeddings)

        if self.bow:
            bow = nn.functional.log_softmax(self.z2bow(z), dim=0)
            bow = bow[reversed_idx]
        else:
            bow = None

        return logp, mean, logv, logc, z, bow
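
The categorical branch of this forward pass relaxes the class choice with a Gumbel-softmax sample so the concatenated latent stays differentiable. A minimal sketch of that step on its own (sizes are illustrative):

import torch
import torch.nn as nn

logits = torch.randn(4, 7)                     # (batch, n_classes) class logits
logc = nn.functional.log_softmax(logits, dim=-1)
y_onehot = nn.functional.gumbel_softmax(logc)  # soft, differentiable one-hot sample
z = torch.randn(4, 16)
latent = torch.cat((z, y_onehot), dim=-1)      # (batch, 16 + 7) decoder input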