Code example #1
    def generate_name(self,
                      lstm: Decoder,
                      address: str,
                      batch_size: int,
                      hidd_cell_states: tuple = None,
                      sample: bool = True):
        """
        lstm: Decoder associated with name being generated
        address: The address to correlate pyro distribution with latent variables
        hidd_cell_states: Previous LSTM hidden state or empty hidden state
        max_name_length: The max name length allowed
        """
        # If no hidden state is provided, initialize it with all 0s
        if hidd_cell_states is None:
            hidd_cell_states = lstm.init_hidden(batch_size=batch_size)

        input_tensor = strings_to_tensor([SOS] * batch_size, 1,
                                         letter_to_index)
        names = [''] * batch_size

        for index in range(MAX_NAME_LENGTH):
            char_dist, hidd_cell_states = lstm.forward(input_tensor,
                                                       hidd_cell_states)

            if sample:
                # Next LSTM input is the sampled one-hot character
                input_tensor = pyro.sample(f"unsup_{address}_{index}",
                                           dist.OneHotCategorical(char_dist))
                chars_at_indexes = list(
                    map(lambda idx: MODEL_CHARS[int(idx.item())],
                        torch.argmax(input_tensor, dim=2).squeeze(0)))
            else:
                # Still register the sample site, but feed the character with
                # the highest probability to the next LSTM step
                pyro.sample(f"unsup_{address}_{index}",
                            dist.OneHotCategorical(char_dist))
                chars_at_indexes = list(
                    map(lambda idx: MODEL_CHARS[int(idx.item())],
                        torch.argmax(char_dist, dim=2).squeeze(0)))
                input_tensor = strings_to_tensor(chars_at_indexes, 1,
                                                 letter_to_index)

            # Add sampled characters to names
            for i, char in enumerate(chars_at_indexes):
                names[i] += char

        # Discard everything after EOS character
        # names = list(map(lambda name: name[:name.find(EOS)] if name.find(EOS) > -1 else name, names))
        return hidd_cell_states, names
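
Every example here relies on a strings_to_tensor helper whose definition is not shown. Judging from the call sites, it one-hot encodes a batch of strings into a tensor indexed as (timestep, batch, vocabulary); the following is a minimal sketch under that assumption, with the name, signature, and padding behavior being illustrative rather than the original implementation.

import torch

# Hypothetical sketch of the strings_to_tensor helper assumed by these
# examples: one-hot encode strings into (max_length, batch, vocab_size)
def strings_to_tensor_sketch(strings, max_length, char_to_index, vocab_size):
    tensor = torch.zeros(max_length, len(strings), vocab_size)
    for b, string in enumerate(strings):
        for t, char in enumerate(string[:max_length]):
            tensor[t, b, char_to_index(char)] = 1.0
    return tensor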
Code example #2
    def guide(self, X_u: list, X_s: list, Z_s: dict, observations=None):
        """
        Guide for approximation of the posterior q(z|x)
        x: Training data (name string)
        z: Optionally supervised latent values (dictionary of name/format values)
        """

        pyro.module("guide_fn_lstm", self.guide_fn_lstm)
        pyro.module("encoder_lstm", self.encoder_lstm)

        # Encode either the raw unsupervised names or the provided observations
        if observations is None:
            formatted_X_u = strings_to_tensor(X_u, MAX_NAME_LENGTH,
                                              printable_to_index)
        else:
            formatted_X_u = observations['unsup_output'].transpose(0, 1)

        hidd_cell_states = self.encoder_lstm.init_hidden(batch_size=len(X_u))
        # Run the encoder LSTM over the names one timestep at a time
        for i in range(formatted_X_u.shape[0]):
            _, hidd_cell_states = self.encoder_lstm.forward(
                formatted_X_u[i].unsqueeze(0), hidd_cell_states)

        with pyro.plate("unsup_batch", len(X_u)):
            _, first_names = self.generate_name(
                self.guide_fn_lstm,
                FIRST_NAME_ADD,
                len(X_u),
                hidd_cell_states=hidd_cell_states,
                sample=False)

        return first_names
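
The observations keyword in the model/guide signature suggests this pair is trained with Pyro's CSIS (inference compilation), as in code example #9 below. A minimal wiring sketch, where name_model is a hypothetical instance of the class these methods belong to and the learning rate is illustrative:

import pyro
import pyro.infer
from pyro.optim import Adam

# Hypothetical wiring: name_model exposes the model/guide defined above
csis = pyro.infer.CSIS(name_model.model, name_model.guide,
                       Adam({'lr': 1e-3}),
                       num_inference_samples=10)
# Each step trains the guide on data sampled from the model's prior
loss = csis.step(X_u, X_s, Z_s)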
Code example #3
    def model(self, X_u: list, X_s: list, Z_s: dict, observations=None):
        """
        Model for generating names representing p(x,z)
        x: Training data (name string)
        z: Optionally supervised latent values (dictionary of name/format values)
        """
        pyro.module("model_fn_lstm", self.model_fn_lstm)

        formatted_X_u = strings_to_tensor(X_u, MAX_NAME_LENGTH,
                                          printable_to_index)
        formatted_X_s = strings_to_tensor(X_s, MAX_NAME_LENGTH,
                                          printable_to_index)

        with pyro.plate("sup_batch", len(X_s)):
            _, first_names = self.generate_name_supervised(
                self.model_fn_lstm,
                FIRST_NAME_ADD,
                len(X_s),
                observed=Z_s[FIRST_NAME_ADD])
            full_names = list(
                map(lambda name: pad_string(name, MAX_NAME_LENGTH),
                    first_names))
            probs = strings_to_probs(full_names,
                                     MAX_NAME_LENGTH,
                                     printable_to_index,
                                     true_index_prob=self.peak_prob)
            pyro.sample("sup_output",
                        dist.OneHotCategorical(probs.transpose(0,
                                                               1)).to_event(1),
                        obs=formatted_X_s.transpose(0, 1))

        with pyro.plate("unsup_batch", len(X_u)):
            _, first_names = self.generate_name(self.model_fn_lstm,
                                                FIRST_NAME_ADD, len(X_u))
            full_names = list(
                map(lambda name: pad_string(name, MAX_NAME_LENGTH),
                    first_names))
            probs = strings_to_probs(full_names,
                                     MAX_NAME_LENGTH,
                                     printable_to_index,
                                     true_index_prob=self.peak_prob)
            pyro.sample("unsup_output",
                        dist.OneHotCategorical(probs.transpose(0,
                                                               1)).to_event(1),
                        obs=formatted_X_u.transpose(0, 1))

        return full_names
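
strings_to_probs is another helper that is not shown. Its use with true_index_prob=self.peak_prob and dist.OneHotCategorical suggests it builds per-character distributions peaked at the true character, with the leftover mass spread over the rest of the vocabulary; a hedged sketch under that assumption:

import torch

# Hypothetical sketch of strings_to_probs: put true_index_prob on each true
# character and distribute the remainder uniformly over the other characters
def strings_to_probs_sketch(strings, max_length, char_to_index, vocab_size,
                            true_index_prob=0.99):
    off_prob = (1.0 - true_index_prob) / (vocab_size - 1)
    probs = torch.full((max_length, len(strings), vocab_size), off_prob)
    for b, string in enumerate(strings):
        for t, char in enumerate(string[:max_length]):
            probs[t, b, char_to_index(char)] = true_index_prob
    return probs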
Code example #4
    def infer(self, X_u: list):
        formatted_X_u = strings_to_tensor(X_u, MAX_NAME_LENGTH,
                                          printable_to_index)
        hidd_cell_states = self.encoder_lstm.init_hidden(batch_size=len(X_u))
        # Encode the observed names, then decode first names from the guide LSTM
        for i in range(formatted_X_u.shape[0]):
            _, hidd_cell_states = self.encoder_lstm.forward(
                formatted_X_u[i].unsqueeze(0), hidd_cell_states)
        _, first_names = self.generate_name(self.guide_fn_lstm,
                                            FIRST_NAME_ADD,
                                            len(X_u),
                                            hidd_cell_states=hidd_cell_states,
                                            sample=False)
        return first_names
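
Once the guide is trained, parsing a batch of names is a single call; a usage sketch where the instance name, input strings, and output are illustrative:

# Hypothetical usage: name_model is a trained instance of the class above
first_names = name_model.infer(["john robert smith", "jane doe"])
print(first_names)  # ideally something like ["john", "jane"]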
Code example #5
    def generate_name_supervised(self,
                                 lstm: Decoder,
                                 address: str,
                                 batch_size: int,
                                 observed: list = None):
        """
        lstm: Decoder associated with name being generated
        address: The address to correlate pyro distribution with latent variables
        observed: Dictionary of name/format values
        """
        hidd_cell_states = lstm.init_hidden(batch_size=batch_size)
        observed_tensor = strings_to_tensor(observed, MAX_NAME_LENGTH,
                                            letter_to_index)
        input_tensor = strings_to_tensor([SOS] * batch_size, 1,
                                         letter_to_index)
        names = [''] * batch_size

        for index in range(MAX_NAME_LENGTH):
            char_dist, hidd_cell_states = lstm.forward(input_tensor,
                                                       hidd_cell_states)
            input_tensor = pyro.sample(f"sup_{address}_{index}",
                                       dist.OneHotCategorical(char_dist),
                                       obs=observed_tensor[index].unsqueeze(0))
            # Convert the observed one-hot characters back to string characters
            chars_at_indexes = list(
                map(lambda idx: MODEL_CHARS[int(idx.item())],
                    torch.argmax(input_tensor, dim=2).squeeze(0)))
            # Add sampled characters to names
            for i, char in enumerate(chars_at_indexes):
                names[i] += char

        # Discard everything after EOS character
        names = list(
            map(
                lambda name: name[:name.find(EOS)]
                if name.find(EOS) > -1 else name, names))
        return hidd_cell_states, names
Code example #6
def step_rnn(rnn, address, address_type, length=15, hidden_layer=None, custom_input_size=None):
    """
    Hacky function for sampling from RNNs whose input/output dimension is not N_DIGIT or N_LETTER
    """
    name = ""
    next_char = '0'
    if hidden_layer is None:
        hidden_layer = rnn.init_hidden()
    for _ in range(length):
        lstm_input = strings_to_tensor([next_char], 1, custom_input_size=custom_input_size)
        next_char_probs, hidden_layer = rnn(lstm_input, hidden_layer)
        next_char_index = pyro.sample(f"{address_type}_{address}", dist.Categorical(next_char_probs)).item()
        next_char = str(next_char_index)
        name += next_char
        address += 1
    return name, hidden_layer, address
Code example #7
def generate_string(rnn, address, address_type, length=15, hidden_layer=None):
    """
    Given a character RNN, generate a digit string by sampling from the RNN's output distribution at each timestep
    """
    name = ""
    next_char = '0'  # TODO: is there a better choice of start character?
    if hidden_layer is None:
        hidden_layer = rnn.init_hidden()
    for _ in range(length):
        lstm_input = strings_to_tensor([next_char], 1, number_only=True)
        next_char_probs, hidden_layer = rnn(lstm_input, hidden_layer)
        next_char_index = pyro.sample(f"{address_type}_{address}", dist.Categorical(next_char_probs)).item()
        next_char = index_to_digit(next_char_index)
        name += next_char
        address += 1
    return name, hidden_layer, address
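
The index_to_digit helper used above is assumed rather than shown; for a plain 10-symbol digit vocabulary it could be as small as the following sketch (an assumption, not the original helper):

DIGITS = "0123456789"

# Hypothetical sketch: map a sampled category index back to its digit character
def index_to_digit_sketch(index: int) -> str:
    return DIGITS[index]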
Code example #8

import matplotlib.pyplot as plt
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

svae = PhoneVAE(batch_size=1)
optimizer = Adam(ADAM_CONFIG)
svi = SVI(svae.model, svae.guide, optimizer, loss=Trace_ELBO())


"""
Train the model
"""
train_elbo = []
for e in range(NUM_EPOCHS):
    epoch_loss = 0.
    for string in TEST_STRINGS:
        # Pad the input string differently than the observed string so the
        # program isn't rewarded for making the string short
        one_hot_string = strings_to_tensor([string], MAX_STRING_LEN)
        if CUDA:
            one_hot_string = one_hot_string.cuda()
        epoch_loss += svi.step(one_hot_string)
    if e % RECORD_EVERY == 0:
        avg_epoch_loss = epoch_loss / len(TEST_STRINGS)
        print(f"Epoch #{e} Average Loss: {avg_epoch_loss}")
        train_elbo.append(avg_epoch_loss)



plt.plot(train_elbo)
plt.title("Training loss (negative ELBO)")
plt.xlabel("recorded epoch")
plt.ylabel("average loss")
Code example #9
from phone_infcomp import EXT

TEST_DATASET = ["+1 (604) 250 1363", "1-778-855-5941"]

phone_csis = PhoneCSIS(hidden_size=HIDDEN_SIZE)
phone_csis.load_checkpoint(filename=f"infcomp_{SESSION_NAME}.pth.tar")

csis = pyro.infer.CSIS(phone_csis.model,
                       phone_csis.guide,
                       Adam({'lr': 0.001}),
                       num_inference_samples=NUM_INFERENCE_SAMPLES)

for phone_number in TEST_DATASET:
    print("=============================")
    print(f"Test Phone Number: {phone_number}")
    test_dataset = strings_to_tensor([phone_number],
                                     max_string_len=MAX_STRING_LEN,
                                     index_function=letter_to_index)
    posterior = csis.run(observations={'x': test_dataset})

    # Draw samples from the posterior
    # marginal = pyro.infer.EmpiricalMarginal(posterior, sites=static_sample_sites)
    # print(marginal._categorical.probs.T)
    print("Posterior samples:")
    csis_samples = [posterior() for _ in range(NUM_POSTERIOR_SAMPLES)]
    for sample in csis_samples:
        rv_names = sample.stochastic_nodes

        ext_format = sample.nodes['ext_format']['value'].item()
        ext_index = sample.nodes['ext_index']['value'].item()
        prefix_format = sample.nodes['prefix_format']['value'].item()
        prefix_len = sample.nodes['prefix_len']['value'].item()