Example #1
0
    def critic_loss_on_batch(self, batch):
        """Compute the WGAN-GP critic objective for both translation directions.

        Critic X compares encodings of real domain-X samples with Y->X translations;
        critic Y mirrors this for domain Y. Each critic's loss is its
        `self.critic_criterion` term plus `config.hp.gp_lambda` times the gradient
        penalty returned by `wgan_gp`.

        Returns:
            (loss, info): `loss` is the mean of the two critics' total losses, `info`
            maps logging keys to python floats.
        """
        batch.domain_x = cudable(batch.domain_x)
        batch.domain_y = cudable(batch.domain_y)

        real_x = self.encoder(batch.domain_x)
        real_y = self.encoder(batch.domain_y)
        fake_y = self.gen_x2y(real_x)  # domain-x encodings translated to y
        fake_x = self.gen_y2x(real_y)  # domain-y encodings translated to x

        # Critic outputs; evaluated in the same order as before so any dropout
        # randomness inside the critics is consumed identically.
        scores_real_x = self.critic_x(real_x)
        scores_fake_x = self.critic_x(fake_x)
        scores_real_y = self.critic_y(real_y)
        scores_fake_y = self.critic_y(fake_y)

        adv_x = self.critic_criterion(scores_real_x, scores_fake_x)
        adv_y = self.critic_criterion(scores_real_y, scores_fake_y)
        penalty_x = wgan_gp(self.critic_x, real_x, fake_x)
        penalty_y = wgan_gp(self.critic_y, real_y, fake_y)

        lam = self.config.hp.gp_lambda
        total_x = adv_x + lam * penalty_x
        total_y = adv_y + lam * penalty_y

        info = {
            'critic_loss/domain_x': adv_x.item(),
            'critic_loss/domain_y': adv_y.item(),
            'critic_loss/gp_x': penalty_x.item(),
            'critic_loss/gp_y': penalty_y.item(),
        }

        return (total_x + total_y) / 2, info
def wgan_gp(critic, real_data, fake_data, *critic_args, **critic_kwargs):
    """Computes gradient penalty according to WGAN-GP paper.

    Random points are taken on the straight lines between paired real and fake
    examples, the critic is evaluated there, and the squared deviation of the
    per-example gradient L2 norm from 1 is averaged:

        gp = mean_i (||d critic(x_i) / d x_i||_2 - 1) ** 2

    Args:
        critic: callable mapping a batch of inputs to predictions.
        real_data: real examples; dimension 0 is the batch dimension.
        fake_data: generated examples with the same size as `real_data`.
        *critic_args, **critic_kwargs: extra arguments forwarded to `critic`.

    Returns:
        Scalar tensor. The autograd graph is kept (create_graph=True) so the
        penalty can itself be back-propagated.
    """
    assert real_data.size() == fake_data.size()

    batch_size, ndim = real_data.size(0), real_data.dim()

    # One mixing coefficient per example, broadcast over every remaining dim.
    # Created on real_data's device/dtype so the interpolation never mixes devices.
    eps_shape = (batch_size,) + (1,) * (ndim - 1)
    eps = torch.rand(eps_shape, device=real_data.device, dtype=real_data.dtype)
    eps = eps.expand_as(real_data)

    interpolations = eps * real_data + (1 - eps) * fake_data
    if not interpolations.requires_grad:
        # Both inputs were detached: make the interpolation a differentiable leaf,
        # otherwise autograd.grad below would refuse to differentiate w.r.t. it.
        interpolations.requires_grad_(True)
    preds = critic(interpolations, *critic_args, **critic_kwargs)

    grads = torch.autograd.grad(
        outputs=preds,
        inputs=interpolations,
        grad_outputs=torch.ones_like(preds),
        retain_graph=True,
        create_graph=True,
    )[0]

    # The penalty uses the norm of each example's WHOLE gradient. Flattening all
    # non-batch dims makes this correct for inputs of any rank (a plain
    # norm(dim=1) is wrong for rank > 2 and fails for rank-1 inputs).
    grad_norms = grads.reshape(batch_size, -1).norm(2, dim=1)
    gp = ((grad_norms - 1) ** 2).mean()

    return gp
Example #3
0
    def gen_loss_on_batch(self, batch):
        """Compute the generator objective for both translation directions.

        Adversarial part: each translation is scored by the critic of its target
        domain and passed through `self.generator_criterion`. The two terms are
        averaged.

        Cycle part: the p-norm (p = config.hp.p_norm) of the difference between
        each encoding and its round-trip reconstruction. The norm is taken over the
        whole batch tensor, with no per-example averaging. The two terms are
        averaged and weighted by config.hp.lp_loss_coef.

        Returns:
            (loss, info): the combined generator loss and per-domain floats for logging.
        """
        batch.domain_x = cudable(batch.domain_x)
        batch.domain_y = cudable(batch.domain_y)

        enc_x = self.encoder(batch.domain_x)
        enc_y = self.encoder(batch.domain_y)
        x_as_y = self.gen_x2y(enc_x)
        y_as_x = self.gen_y2x(enc_y)
        x_round_trip = self.gen_y2x(x_as_y)  # x -> y -> x
        y_round_trip = self.gen_x2y(y_as_x)  # y -> x -> y

        # Adversarial terms: make the target-domain critic's job harder
        scores_y_as_x = self.critic_x(y_as_x)
        scores_x_as_y = self.critic_y(x_as_y)
        adv_x_loss = self.generator_criterion(scores_x_as_y)
        adv_y_loss = self.generator_criterion(scores_y_as_x)
        adv_loss = (adv_x_loss + adv_y_loss) / 2

        # Cycle-consistency terms
        p = self.config.hp.p_norm
        cycle_x_loss = torch.norm(enc_x - x_round_trip, p=p)
        cycle_y_loss = torch.norm(enc_y - y_round_trip, p=p)
        lp_loss = (cycle_x_loss + cycle_y_loss) / 2

        info = {
            'gen_adv_loss/domain_x': adv_x_loss.item(),
            'gen_adv_loss/domain_y': adv_y_loss.item(),
            'gen_lp_loss/domain_x': cycle_x_loss.item(),
            'gen_lp_loss/domain_y': cycle_y_loss.item(),
        }

        return adv_loss + self.config.hp.lp_loss_coef * lp_loss, info
Example #4
0
    def validate(self):
        """Log validation reconstruction loss, BLEU and a few sample translations.

        For every validation batch this records the cross-entropy from
        `loss_on_batch` and the BLEU of `InferenceState` outputs (max_len=50)
        against the gold targets. The means are written to the writer, and the
        mean BLEU is appended to self.losses['val_bleu'].

        Up to 10 sample pairs are logged, taken from the LAST validation batch only.
        This assumes val_dataloader yields at least one batch: the sample section
        reads that batch's predictions.
        """
        batch_rec_losses = []
        batch_bleus = []

        for batch in self.val_dataloader:
            batch.src, batch.trg = cudable(batch.src), cudable(batch.trg)

            # Cross-entropy reconstruction loss
            batch_rec_losses.append(self.loss_on_batch(batch).item())

            # BLEU of generated translations against the gold targets
            encs, enc_mask = self.transformer.encoder(batch.src)
            inference_params = {
                'model': self.transformer.decoder,
                'inputs': encs,
                'enc_mask': enc_mask,
                'vocab': self.vocab_trg,
                'max_len': 50
            }
            generated = InferenceState(inference_params).inference()
            preds = itos_many(generated, self.vocab_trg)
            gold = itos_many(batch.trg, self.vocab_trg)
            batch_bleus.append(compute_bleu_for_sents(preds, gold))

        step = self.num_iters_done
        mean_bleu = np.mean(batch_bleus)
        self.writer.add_scalar('val/rec_loss', np.mean(batch_rec_losses), step)
        self.writer.add_scalar('val/bleu', mean_bleu, step)
        self.losses['val_bleu'].append(mean_bleu)

        samples = ['Translation: {}\n\n Gold: {}'.format(t, g) for t, g in zip(preds, gold)]
        separator = '\n\n ================== \n\n'
        self.writer.add_text('Samples', separator.join(samples[:10]), step)
Example #5
0
    def loss_on_batch(self, batch):
        """Return `self.criterion` of the transformer outputs on *batch*.

        The transformer receives both batch.src and batch.trg. Its outputs are
        flattened to (-1, len(vocab_trg)) and compared with the target tokens
        after the first position (batch.trg[:, 1:]).
        """
        batch.src = cudable(batch.src)
        batch.trg = cudable(batch.trg)

        outputs = self.transformer(batch.src, batch.trg)
        vocab_size = len(self.vocab_trg)
        # Every output position is scored against the next target token
        shifted_targets = batch.trg[:, 1:].contiguous().view(-1)

        return self.criterion(outputs.view(-1, vocab_size), shifted_targets)
def predict(lines):
    """Transfer every word of every line and re-assemble the lines.

    Relies on module-level objects: `fields`, `field`, `batch_size`,
    `n_first_chars`, the `encoder` / `decoder` / `merge_z` models and helpers
    such as `generate_dataset_with_middle_chars` and `mix_transfered`.

    Each word's first `n_first_chars` target characters are kept as a prefix.
    The rest is produced by `simple_inference` (max_len=30) from the encoder
    state merged with morphological features.
    """
    lines = tokenize(lines)

    # One example per word; all words are grouped into batches
    src, trg = generate_dataset_with_middle_chars(lines)
    dataset = Dataset([Example.fromlist([s, t], fields) for s, t in zip(src, trg)], fields)
    # NOTE(review): group_by_lens below presumably relies on the iterator keeping
    # the dataset order (shuffle=False) — confirm if iterator settings change.
    iterator = data.BucketIterator(dataset, batch_size, repeat=False, shuffle=False)

    translated_words = []

    for batch in iterator:
        batch.src = cudable(batch.src)
        batch.trg = cudable(batch.trg)
        morph_features = cudable(torch.from_numpy(morph_chars_idx(batch.trg, field.vocab)).float())
        prefix = batch.trg[:, :n_first_chars]
        prefix_embs = decoder.embed(prefix)

        # Encoder state + morph features -> decoder state primed with the prefix
        state = encoder(batch.src)
        state = merge_z(torch.cat([state, morph_features], dim=1))
        state = decoder.gru(prefix_embs, state.unsqueeze(0))[1].squeeze(0)
        continuations = simple_inference(decoder, state, field.vocab, max_len=30)

        prefix_ids = prefix.cpu().numpy().tolist()
        full_ids = [head + tail for head, tail in zip(prefix_ids, continuations)]
        translated_words.extend(itos_many(full_ids, field.vocab, sep=''))

    words_per_line = [len(line.split()) for line in lines]
    grouped = group_by_lens(translated_words, words_per_line)

    return [mix_transfered(original, new) for original, new in zip(lines, grouped)]
    def pad_to_max(self, seq, max_len):
        """Append padding to *seq* along dim 0 until it has *max_len* entries.

        If `self.gumbel` is truthy, every pad row has length len(self.vocab). It
        holds 1 at `self.pad_idx` and a tiny eps everywhere else. Otherwise every
        pad entry is the scalar `self.pad_idx`.

        Pads are created with seq's dtype and device so the concatenation keeps
        seq's type. Token-id pads used to be float tensors: appending them to a
        LongTensor either failed or silently turned the ids into floats.

        Returns:
            `seq` itself when it already has `max_len` entries, otherwise a new tensor.
        """
        n_pads = max_len - len(seq)
        if n_pads == 0:
            return seq

        if self.gumbel:
            # TODO: we fill with eps to prevent numerical issues in loss computation later
            # But loss on pads is always zero, why are we doing this?
            eps = 1e-8
            pads = T.full((n_pads, len(self.vocab)), eps, dtype=seq.dtype, device=seq.device)
            pads[:, self.pad_idx] = 1.
        else:
            pads = T.full((n_pads,), self.pad_idx, dtype=seq.dtype, device=seq.device)

        return T.cat((seq, pads), dim=0)
    def init_models(self):
        """Build the DissoNet sub-models and the combined (optionally parallel) model.

        Hyper-parameters come from self.config.hp (size, dropout, dropword,
        style_vec_size). `self.ae_params` collects the parameters of the
        auto-encoding parts: encoder, decoder, split_nn and merge_nn. The model is
        wrapped in nn.DataParallel when more than one CUDA device is visible.
        """
        hp = self.config.hp
        size = hp.size
        style_size = hp.style_vec_size

        # Construction order is kept fixed: each constructor may draw random init values
        self.encoder = RNNEncoder(size, size, self.vocab, hp.dropword)
        self.decoder = RNNDecoder(size, size, self.vocab, hp.dropword)
        self.split_nn = SplitNN(size, style_size)
        self.motivator = FFN([style_size, 1], dropout=hp.dropout)
        self.critic = FFN([size, size, 1], dropout=hp.dropout)
        self.merge_nn = MergeNN(size, style_size)

        # All auto-encoder parameters in one flat list, for later use
        ae_modules = (self.encoder, self.decoder, self.split_nn, self.merge_nn)
        self.ae_params = list(chain.from_iterable(m.parameters() for m in ae_modules))

        self.dissonet = cudable(DissoNet(
            self.encoder, self.decoder, self.split_nn,
            self.motivator, self.critic, self.merge_nn))

        n_gpus = torch.cuda.device_count()
        if n_gpus > 1:
            print('Going to parallelize on {} GPUs'.format(n_gpus))
            self.dissonet = nn.DataParallel(self.dissonet)
    def transfer_style_on_batch(self, batch):
        """Decode every content/style combination of the two domains in *batch*.

        Both domains are encoded and split into (content, style) by the DissoNet
        split network. Each content is re-merged with its own or the other
        domain's style and decoded with `simple_inference` (eos_token='|').

        Returns:
            (x2y, y2x, x2x, y2y): x-content with y-style, y-content with x-style,
            and the two same-style reconstructions.
        """
        batch = cudable(batch)

        # self.dissonet may be an nn.DataParallel wrapper (init_models wraps it when
        # several GPUs are visible). DataParallel does not forward attribute access,
        # so `.split_nn` must be read from the wrapped module.
        dissonet = getattr(self.dissonet, 'module', self.dissonet)

        state_x = self.encoder(batch.domain_x)
        state_y = self.encoder(batch.domain_y)

        content_x, style_x = dissonet.split_nn(state_x)
        content_y, style_y = dissonet.split_nn(state_y)

        merged_states = (
            self.merge_nn(content_x, style_y),  # x2y
            self.merge_nn(content_y, style_x),  # y2x
            self.merge_nn(content_x, style_x),  # x2x
            self.merge_nn(content_y, style_y),  # y2y
        )

        return tuple(
            simple_inference(self.decoder, state, self.vocab, eos_token='|')
            for state in merged_states)
Example #10
0
def subsequent_mask(size):
    "Mask out subsequent positions."
    # Shape (1, size, size): entry [0, i, j] is True exactly when j <= i, i.e. the
    # lower triangle including the diagonal marks positions that may be attended to.
    allowed = np.tril(np.ones((1, size, size), dtype='uint8'))
    return cudable(torch.from_numpy(allowed) == 1)
    def stack_finished(self):
        "Pad every finished sequence to the longest one and stack them into a single tensor"
        longest = max(map(len, self.finished))
        # Slice assignment rewrites the existing list object in place before it is replaced
        self.finished[:] = [self.pad_to_max(seq, longest) for seq in self.finished]
        self.finished = cudable(T.stack(self.finished))
Example #12
0
    def transfer_style_on_batch(self, batch):
        """Translate both domains of *batch* and decode the results.

        Returns:
            (x2y, y2x, x2x, y2y): one-way translations plus the round trips
            (x -> y -> x and y -> x -> y), each decoded with `simple_inference`
            (eos_token='|').
        """
        batch.domain_x = cudable(batch.domain_x)
        batch.domain_y = cudable(batch.domain_y)

        z_x = self.encoder(batch.domain_x)
        z_y = self.encoder(batch.domain_y)

        z_x2y = self.gen_x2y(z_x)
        z_y2x = self.gen_y2x(z_y)
        # Round trips back into the source domain
        z_x2x = self.gen_y2x(z_x2y)
        z_y2y = self.gen_x2y(z_y2x)

        latents = (z_x2y, z_y2x, z_x2x, z_y2y)
        return tuple(simple_inference(self.decoder, z, self.vocab, eos_token='|') for z in latents)
Example #13
0
    def pad_to(self, seq, n):
        """Pads sequence with zero entries along dim 0 to the desired length.

        Works for tensors of any rank >= 1. The padding has the same trailing
        shape, dtype and device as `seq`. The old version assumed a 2-D seq and
        built the pads on the default device before casting.

        Args:
            seq: tensor whose first dimension is the sequence length.
            n: target length; must be >= len(seq).

        Returns:
            `seq` itself if it already has length n, otherwise a new tensor of
            shape (n, *seq.shape[1:]).
        """
        assert len(seq) <= n

        if len(seq) == n:
            return seq

        # new_zeros inherits dtype and device from seq: no separate cast or transfer
        pads = seq.new_zeros((n - len(seq),) + tuple(seq.shape[1:]))

        return torch.cat([seq, pads])
Example #14
0
    def forward(self, encs, encs_mask):
        """Produce `config.n_vecs` vectors per batch element from the encodings.

        The ids 0..n_vecs-1 are embedded for every batch element and scaled by
        sqrt(d_model). The same `self.layer` is then applied `config.n_steps`
        times, always attending to `encs` with `encs_mask`. The final `self.norm`
        output is returned.
        """
        n_vecs = self.config.n_vecs
        n_rows = encs.size(0)

        # One row of ids [0, 1, ..., n_vecs - 1] per batch element
        ids = torch.arange(n_vecs).long().unsqueeze(0).repeat(n_rows, 1)
        scale = np.sqrt(self.config.d_model)
        hidden = self.embed(cudable(ids)) * scale

        step = 0
        while step < self.config.n_steps:
            hidden = self.layer(encs, hidden, encs_mask, None)
            step += 1

        return self.norm(hidden)
Example #15
0
    def init_models(self):
        """Create the auto-encoder, both generators and both critics, plus the metrics.

        Hyper-parameters come from self.config.hp (model_size, dropout, dropword,
        noiseness, gen_n_rec_steps). Every network is passed through `cudable`.
        """
        hp = self.config.hp
        size = hp.model_size

        self.encoder = cudable(
            RNNEncoder(size, size, self.vocab, hp.dropword, noise=hp.noiseness))
        self.decoder = cudable(RNNDecoder(size, size, self.vocab, hp.dropword))

        def make_generator():
            return cudable(Generator(size, hp.gen_n_rec_steps))

        def make_critic():
            return cudable(FFN([size, size, 1], hp.dropout))

        # X -> Y translator and the critic judging domain Y
        self.gen_x2y = make_generator()
        self.critic_y = make_critic()

        # Y -> X translator and the critic judging domain X
        self.gen_y2x = make_generator()
        self.critic_x = make_critic()

        self.transfer_strength = TransferStrength(self.config)
        self.content_preservation = ContentPreservation(self.config)
Example #16
0
    def validate(self):
        """Log mean validation losses of all three objectives, then run inference checks.

        For each validation batch the info dicts of the auto-encoder, generator and
        critic losses are merged. On duplicate keys the later dict wins, in the
        order critic < gen < rec. Every key is averaged over batches and written
        as 'VAL/<key>'. Assumes val_dataloader yields at least one batch.
        """
        batch_infos = []

        for batch in self.val_dataloader:
            batch.domain_x = cudable(batch.domain_x)
            batch.domain_y = cudable(batch.domain_y)

            *_, rec_info = self.ae_loss_on_batch(batch)
            *_, gen_info = self.gen_loss_on_batch(batch)
            # The critic loss differentiates through the critic (wgan_gp calls
            # autograd.grad), so gradients must be enabled even if validation is
            # presumably run under no_grad.
            with torch.enable_grad():
                *_, critic_info = self.critic_loss_on_batch(batch)
            batch_infos.append({**critic_info, **gen_info, **rec_info})

        for key in batch_infos[0]:
            mean_value = np.mean([info[key] for info in batch_infos])
            self.writer.add_scalar('VAL/' + key, mean_value, self.num_iters_done)

        # Ok, let's now validate style transfer and auto-encoding
        self.validate_inference()
def init_lm(config_path, state_path, model_cls_name: str):
    """Build the requested language model, load its trained weights and its vocab.

    Args:
        config_path: config file whose 'hp' section provides model_size / n_layers.
        state_path: location of the saved states, turned into a path factory by
            `create_get_path_fn`.
        model_cls_name: key into MODEL_CLASSES.

    Returns:
        (lm, field): the model in eval mode and a char-level Field whose vocab is
        unpickled from the state location.

    Raises:
        KeyError: `model_cls_name` is not in MODEL_CLASSES.
        NotImplementedError: the class is known but not supported here.
    """
    model_cls = MODEL_CLASSES[model_cls_name]
    hp = load_config(config_path).get('hp')
    get_path = create_get_path_fn(state_path)

    # Loading vocab
    field = Field(eos_token=EOS_TOKEN,
                  batch_first=True,
                  tokenize=char_tokenize,
                  pad_first=True)
    # NOTE(review): pickle.load can execute arbitrary code; only load vocab files you trust.
    with open(get_path('vocab', 'pickle'), 'rb') as vocab_file:
        field.vocab = pickle.load(vocab_file)

    print('Loading models..')
    # None keeps the devices stored in the checkpoint; 'cpu' lets GPU-saved
    # weights load on a machine without CUDA.
    device = None if torch.cuda.is_available() else 'cpu'

    def load_state(name):
        return torch.load(get_path(name), map_location=device)

    if model_cls is RNNLM:
        lm = cudable(RNNLM(hp.model_size, field.vocab,
                           n_layers=hp.n_layers)).eval()
        lm.load_state_dict(load_state('lm'))
    elif model_cls is ConditionalLM:
        lm = cudable(ConditionalLM(hp.model_size, field.vocab)).eval()
        lm.load_state_dict(load_state('lm'))
    elif model_cls is CharLMFromEmbs:
        rnn_lm = cudable(
            RNNLM(hp.model_size, field.vocab, n_layers=hp.n_layers))
        style_embed = cudable(nn.Embedding(2, hp.model_size))

        rnn_lm.load_state_dict(load_state('lm'))
        style_embed.load_state_dict(load_state('style_embed'))

        lm = cudable(CharLMFromEmbs(rnn_lm, style_embed,
                                    n_layers=hp.n_layers)).eval()
    else:
        raise NotImplementedError(
            'init_lm does not support model class {!r}'.format(model_cls_name))

    return lm, field
Example #18
0
    def ae_loss_on_batch(self, batch):
        """Reconstruction loss of the auto-encoder on both domains.

        The decoder receives the encoding and the tokens without their last
        position. Its outputs are scored by `self.rec_criterion` against the tokens
        without their first position.

        Returns:
            (loss, info): the mean of the two domain losses and per-domain floats.
        """
        batch.domain_x = cudable(batch.domain_x)
        batch.domain_y = cudable(batch.domain_y)

        tokens_x, tokens_y = batch.domain_x, batch.domain_y
        vocab_size = len(self.vocab)

        enc_x = self.encoder(tokens_x)
        enc_y = self.encoder(tokens_y)

        # Reconstruction loss
        out_x = self.decoder(enc_x, tokens_x[:, :-1])
        out_y = self.decoder(enc_y, tokens_y[:, :-1])

        def shifted_loss(outputs, tokens):
            # Each output position is compared with the following token
            flat_targets = tokens[:, 1:].contiguous().view(-1)
            return self.rec_criterion(outputs.view(-1, vocab_size), flat_targets)

        loss_x = shifted_loss(out_x, tokens_x)
        loss_y = shifted_loss(out_y, tokens_y)

        info = {
            'rec_loss/domain_x': loss_x.item(),
            'rec_loss/domain_y': loss_y.item(),
        }

        return (loss_x + loss_y) / 2, info
Example #19
0
def onehot_gumbel_softmax(logits, temperature):
    """
    Straight-through Gumbel-softmax: one-hot values forward, soft gradients backward.

    input: [*, n_class]
    return: [*, n_class] a one-hot vector

    The returned tensor equals the one-hot arg-max of the Gumbel-softmax sample
    `y`, but gradients flow through `y` ((y_hard - y).detach() + y).
    """
    y = gumbel_softmax_sample(logits, temperature)

    # One-hot of the arg-max along the class dim. zeros_like already gives y's
    # device and dtype. The old extra cudable() could move y_hard to the GPU while
    # y stayed on the CPU, making `y_hard - y` fail.
    class_dim = y.dim() - 1
    _, idx = y.max(dim=class_dim, keepdim=True)
    y_hard = torch.zeros_like(y).scatter_(class_dim, idx, 1)

    return (y_hard - y).detach() + y
    def validate(self):
        """Log mean validation losses, record the mean reconstruction loss, run inference checks.

        `loss_on_batch` is evaluated with gradients enabled. Its first returned
        value is the reconstruction loss and its last is an info dict. Every info
        key is averaged over batches and written as 'VAL/<key>'. The mean
        reconstruction loss is appended to self.losses['val_rec_loss'].
        Assumes val_dataloader yields at least one batch.
        """
        batch_infos = []
        rec_values = []

        for batch in self.val_dataloader:
            batch = cudable(batch)
            with torch.enable_grad():
                rec_loss, *_, info = self.loss_on_batch(batch)
            rec_values.append(rec_loss.item())
            batch_infos.append(info)

        for key in batch_infos[0]:
            mean_value = np.mean([info[key] for info in batch_infos])
            self.writer.add_scalar('VAL/' + key, mean_value, self.num_iters_done)

        self.losses['val_rec_loss'].append(np.mean(rec_values))

        # Ok, let's now validate style transfer and auto-encoding
        self.validate_inference()
Example #21
0
 def init_models(self):
     """Create the transformer from config.hp.transformer and both vocabularies."""
     transformer_config = self.config.hp.transformer
     model = Transformer(transformer_config, self.vocab_src, self.vocab_trg)
     self.transformer = cudable(model)
Example #22
0
def sample_gumbel(shape, eps=1e-20):
    """Draw Gumbel(0, 1) noise of `shape` as -log(-log(U)), with U from torch.rand.

    `eps` is added inside both logarithms to keep them finite.
    """
    uniform = cudable(torch.rand(shape))
    inner = -torch.log(uniform + eps)
    return -torch.log(inner + eps)
    def init_z(self, batch_size, style):
        """Embed `style` for every layer and batch element.

        Builds an (n_layers, batch_size) tensor filled with `style`, converts it
        to long ids and returns `self.style_embed` of it.
        """
        style_ids = torch.ones(self.n_layers, batch_size).fill_(style)
        style_ids = cudable(style_ids).long()
        return self.style_embed(style_ids)
 def generate_active_seqs(self):
     """Return a LongTensor with one row [bos_idx] per batch element."""
     bos_rows = [[self.bos_idx]] * self.batch_size
     return cudable(T.tensor(bos_rows).long())
    def predict(sentences: List[str],
                n_lines: int,
                temperature: float = None,
                max_len: int = None):
        """For each sentence generates `n_lines` lines sequentially to form a dialog.

        At every step, each dialog so far is fed to the language model as one
        string whose lines are separated by EOS_TOKEN, truncated to its last
        MAX_CONTEXT_SIZE tokens. One more line is sampled for each dialog and
        appended.

        Args:
            sentences: opening line of every dialog; the list is not modified.
            n_lines: number of lines to generate after the opening one.
            temperature: sampling temperature; falsy values use DEFAULT_TEMPERATURE.
            max_len: maximum generated line length; falsy values use DEFAULT_MAX_LINE_LEN.

        Returns:
            One entry per sentence: its non-empty lines passed through `assign_speakers`.
        """
        dialogs = list(sentences)  # copy, so the caller's list is never mutated
        if not dialogs:
            return []

        batch_size = len(dialogs)
        temperature = temperature or DEFAULT_TEMPERATURE
        max_len = max_len or DEFAULT_MAX_LINE_LEN
        text_fields = [('text', field)]

        for _ in range(n_lines):
            # Each dialog is already a single string with its lines separated by
            # EOS_TOKEN, so it is used as-is. EOS_TOKEN.join(dialog) would iterate
            # over the string's characters and insert EOS_TOKEN between every one.
            examples = [Example.fromlist([d], text_fields) for d in dialogs]
            dataset = Dataset(examples, text_fields)
            dataloader = data.BucketIterator(dataset,
                                             batch_size,
                                             shuffle=False,
                                             repeat=False)
            batch = next(iter(dataloader))  # We have a single batch
            # As we made pad_first we are not afraid of losing information
            text = cudable(batch.text[:, -MAX_CONTEXT_SIZE:])

            # NOTE(review): the 2048 / 4096 state sizes below are hard-coded and
            # must match the trained models.
            if model_cls_name == 'CharLMFromEmbs':
                z = lm.init_z(text.size(0), 1)
                z = lm(z, text, return_z=True)[1]
            elif model_cls_name == 'ConditionalLM':
                z = cudable(torch.zeros(2, len(text), 2048))
                z = lm(z, text, style=1, return_z=True)[1]
            elif model_cls_name == 'WeightedLMEnsemble':
                z = cudable(torch.zeros(2, 1, len(text), 4096))
                z = lm(z, text, return_z=True)[1]
            else:
                embs = lm.embed(text)
                z = lm.gru(embs)[1]

            next_lines = InferenceState({
                'model': lm,
                'inputs': z,
                'vocab': field.vocab,
                'max_len': max_len,
                'bos_token': EOS_TOKEN,  # We start infering a new reply when we see EOS
                'eos_token': EOS_TOKEN,
                'temperature': temperature,
                'sample_type': 'sample',
                'inputs_batch_dim': 1 if model_cls_name != 'WeightedLMEnsemble' else 2,
                'substitute_inputs': True,
                'kwargs': inference_kwargs
            }).inference()

            next_lines = itos_many(next_lines, field.vocab, sep='')
            next_lines = [slice_unfinished_sentence(l) for l in next_lines]
            dialogs = [d + EOS_TOKEN + l for d, l in zip(dialogs, next_lines)]

        dialogs = [d.split(EOS_TOKEN) for d in dialogs]
        dialogs = [[s for s in d if len(s) != 0] for d in dialogs]
        dialogs = [assign_speakers(d) for d in dialogs]

        return dialogs
Example #26
0
 def init_models(self):
     """Create the RNN classifier sized by config.hp.model_size and pass it through cudable."""
     classifier = RNNClassifier(self.config.hp.model_size, self.vocab)
     self.classifier = cudable(classifier)
# Checkpoints are named '<name>-<iteration>.<ext>'; the newest iteration found
# in models_dir is loaded below.
# models_dir = 'experiments/char-wm/checkpoints'
models_dir = 'models/style-model'


def _checkpoint_versions(directory):
    """Return the integer iterations parsed from '<name>-<iteration>.<ext>' file names.

    Names containing '-' whose iteration part is not a plain integer are skipped
    rather than aborting the scan with ValueError.
    """
    found = set()
    for fname in os.listdir(directory):
        if '-' not in fname:
            continue
        version = fname.split('-')[1].split('.')[0]
        if version.isdigit():
            found.add(int(version))
    return found


versions = _checkpoint_versions(models_dir)
if not versions:
    raise FileNotFoundError(
        'No checkpoints named <name>-<iteration>.<ext> found in {}'.format(models_dir))
latest_iter = max(versions)


def get_path(m, ext='pth'):
    """Path of checkpoint *m* (e.g. 'encoder') saved at the latest iteration."""
    return os.path.join(models_dir, '{}-{}.{}'.format(m, latest_iter, ext))


print('Latest iter (version) found: {}. Loading from it.'.format(latest_iter))

# Loading vocab
field = Field(init_token='<bos>', eos_token='<eos>',
              batch_first=True, tokenize=char_tokenize)
# NOTE(review): pickle.load can execute arbitrary code; only load vocab files you trust.
with open(get_path('vocab', 'pickle'), 'rb') as vocab_file:
    field.vocab = pickle.load(vocab_file)
fields = [('src', field), ('trg', field)]

print('Loading models..')
encoder = cudable(RNNEncoder(512, 512, field.vocab)).eval()
decoder = cudable(RNNDecoder(512, 512, field.vocab)).eval()
merge_z = cudable(FFN([512 + MORPHS_SIZE, 512])).eval()

# None keeps the devices stored in the checkpoint; 'cpu' lets GPU-saved weights
# load on a machine without CUDA.
location = None if torch.cuda.is_available() else 'cpu'
encoder.load_state_dict(torch.load(get_path('encoder'), map_location=location))
decoder.load_state_dict(torch.load(get_path('decoder'), map_location=location))
merge_z.load_state_dict(torch.load(get_path('merge_z'), map_location=location))


def predict(lines):
    lines = tokenize(lines)
    # Grouping all lines into batches
    src, trg = generate_dataset_with_middle_chars(lines)
    examples = [Example.fromlist([m,o], fields) for m,o in zip(src, trg)]
    ds = Dataset(examples, fields)