def generate_name(self, lstm: Decoder, address: str, batch_size: int,
                  hidd_cell_states: tuple = None, sample: bool = True):
    """
    lstm: Decoder associated with the name being generated
    address: The address to correlate Pyro sample sites with latent variables
    batch_size: Number of names to generate in parallel
    hidd_cell_states: Previous LSTM hidden/cell state, or None for a fresh state
    sample: If True, feed the sampled character back into the LSTM; otherwise
            feed back the most probable character (greedy decoding)
    """
    # If no hidden state is provided, initialize it with all 0s
    if hidd_cell_states is None:
        hidd_cell_states = lstm.init_hidden(batch_size=batch_size)

    input_tensor = strings_to_tensor([SOS] * batch_size, 1, letter_to_index)
    names = [''] * batch_size

    for index in range(MAX_NAME_LENGTH):
        char_dist, hidd_cell_states = lstm.forward(input_tensor, hidd_cell_states)
        if sample:
            # Next LSTM input is the sampled character
            input_tensor = pyro.sample(f"unsup_{address}_{index}",
                                       dist.OneHotCategorical(char_dist))
            chars_at_indexes = list(
                map(lambda idx: MODEL_CHARS[int(idx.item())],
                    torch.argmax(input_tensor, dim=2).squeeze(0)))
        else:
            # Register the sample site so the guide can score it, but feed the
            # character with the highest probability of occurring back into
            # the LSTM
            pyro.sample(f"unsup_{address}_{index}",
                        dist.OneHotCategorical(char_dist))
            chars_at_indexes = list(
                map(lambda idx: MODEL_CHARS[int(idx.item())],
                    torch.argmax(char_dist, dim=2).squeeze(0)))
            input_tensor = strings_to_tensor(chars_at_indexes, 1, letter_to_index)

        # Add sampled characters to names
        for i, char in enumerate(chars_at_indexes):
            names[i] += char

    # Discard everything after the EOS character
    # names = list(map(lambda name: name[:name.find(EOS)] if name.find(EOS) > -1 else name, names))
    return hidd_cell_states, names
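# `generate_name` relies on the `strings_to_tensor` helper defined elsewhere in
# the repository. A minimal sketch of its assumed behavior is below: encode a
# batch of strings as a (max_len, batch, vocab) one-hot tensor. The PAD
# character, the vocabulary size, and treating `index_fn` as a callable are
# assumptions for illustration, not the repository's actual implementation.
def _strings_to_tensor_sketch(strings, max_len, index_fn):
    tensor = torch.zeros(max_len, len(strings), len(MODEL_CHARS))
    for b, string in enumerate(strings):
        for t in range(max_len):
            # Assumption: strings shorter than max_len are padded with PAD
            char = string[t] if t < len(string) else PAD
            tensor[t, b, index_fn(char)] = 1.0
    return tensor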
def guide(self, X_u: list, X_s: list, Z_s: dict, observations=None):
    """
    Guide for approximation of the posterior q(z|x)
    X_u: Unsupervised training data (name strings)
    X_s: Supervised training data (name strings)
    Z_s: Supervised latent values (dictionary of name/format values)
    """
    pyro.module("guide_fn_lstm", self.guide_fn_lstm)
    pyro.module("encoder_lstm", self.encoder_lstm)

    if observations is None:
        formatted_X_u = strings_to_tensor(X_u, MAX_NAME_LENGTH, printable_to_index)
    else:
        formatted_X_u = observations['unsup_output'].transpose(0, 1)

    # Encode the observed names one timestep at a time
    hidd_cell_states = self.encoder_lstm.init_hidden(batch_size=len(X_u))
    for i in range(formatted_X_u.shape[0]):
        _, hidd_cell_states = self.encoder_lstm.forward(
            formatted_X_u[i].unsqueeze(0), hidd_cell_states)

    with pyro.plate("unsup_batch", len(X_u)):
        _, first_names = self.generate_name(
            self.guide_fn_lstm, FIRST_NAME_ADD, len(X_u),
            hidd_cell_states=hidd_cell_states, sample=False)
    return first_names
def model(self, X_u: list, X_s: list, Z_s: dict, observations=None):
    """
    Model for generating names, representing p(x, z)
    X_u: Unsupervised training data (name strings)
    X_s: Supervised training data (name strings)
    Z_s: Supervised latent values (dictionary of name/format values)
    """
    pyro.module("model_fn_lstm", self.model_fn_lstm)

    formatted_X_u = strings_to_tensor(X_u, MAX_NAME_LENGTH, printable_to_index)
    formatted_X_s = strings_to_tensor(X_s, MAX_NAME_LENGTH, printable_to_index)

    # Supervised branch: generate names conditioned on the observed latents
    with pyro.plate("sup_batch", len(X_s)):
        _, first_names = self.generate_name_supervised(
            self.model_fn_lstm, FIRST_NAME_ADD, len(X_s),
            observed=Z_s[FIRST_NAME_ADD])
        full_names = list(
            map(lambda name: pad_string(name, MAX_NAME_LENGTH), first_names))
        probs = strings_to_probs(full_names, MAX_NAME_LENGTH, printable_to_index,
                                 true_index_prob=self.peak_prob)
        pyro.sample("sup_output",
                    dist.OneHotCategorical(probs.transpose(0, 1)).to_event(1),
                    obs=formatted_X_s.transpose(0, 1))

    # Unsupervised branch: sample the latent name, then observe the full string
    with pyro.plate("unsup_batch", len(X_u)):
        _, first_names = self.generate_name(self.model_fn_lstm, FIRST_NAME_ADD,
                                            len(X_u))
        full_names = list(
            map(lambda name: pad_string(name, MAX_NAME_LENGTH), first_names))
        probs = strings_to_probs(full_names, MAX_NAME_LENGTH, printable_to_index,
                                 true_index_prob=self.peak_prob)
        pyro.sample("unsup_output",
                    dist.OneHotCategorical(probs.transpose(0, 1)).to_event(1),
                    obs=formatted_X_u.transpose(0, 1))
    return full_names
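# For orientation, a minimal sketch of how `model` and `guide` might be wired
# into Pyro's SVI loop. `NameVAE`, its constructor arguments, and the batch
# variables X_u, X_s, Z_s are placeholders, not names taken from this
# repository.
name_vae = NameVAE()  # hypothetical class holding the model/guide above
svi = SVI(name_vae.model, name_vae.guide, Adam({'lr': 1e-3}), loss=Trace_ELBO())
for epoch in range(NUM_EPOCHS):
    # X_u: unsupervised names, X_s: supervised names, Z_s: their latent values
    loss = svi.step(X_u, X_s, Z_s)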
def infer(self, X_u: list):
    """Greedily decode first names for a batch of unsupervised name strings."""
    formatted_X_u = strings_to_tensor(X_u, MAX_NAME_LENGTH, printable_to_index)

    # Encode the observed names one timestep at a time
    hidd_cell_states = self.encoder_lstm.init_hidden(batch_size=len(X_u))
    for i in range(formatted_X_u.shape[0]):
        _, hidd_cell_states = self.encoder_lstm.forward(
            formatted_X_u[i].unsqueeze(0), hidd_cell_states)

    _, first_names = self.generate_name(self.guide_fn_lstm, FIRST_NAME_ADD,
                                        len(X_u),
                                        hidd_cell_states=hidd_cell_states,
                                        sample=False)
    return first_names
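# Hypothetical usage, assuming a trained instance `name_vae` of this class:
#   first_names = name_vae.infer(["ada lovelace", "alan turing"])
#   print(first_names)  # e.g. ['ada', 'alan'] once the guide has learned
#                       # the name format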
def generate_name_supervised(self, lstm: Decoder, address: str, batch_size: int,
                             observed: list = None):
    """
    lstm: Decoder associated with the name being generated
    address: The address to correlate Pyro sample sites with latent variables
    observed: List of observed name strings to condition each timestep on
    """
    hidd_cell_states = lstm.init_hidden(batch_size=batch_size)
    observed_tensor = strings_to_tensor(observed, MAX_NAME_LENGTH, letter_to_index)
    input_tensor = strings_to_tensor([SOS] * batch_size, 1, letter_to_index)
    names = [''] * batch_size

    for index in range(MAX_NAME_LENGTH):
        char_dist, hidd_cell_states = lstm.forward(input_tensor, hidd_cell_states)
        # Condition on the observed one-hot character; it also becomes the
        # next LSTM input
        input_tensor = pyro.sample(f"sup_{address}_{index}",
                                   dist.OneHotCategorical(char_dist),
                                   obs=observed_tensor[index].unsqueeze(0))
        # Convert the one-hot characters back to indexes, then to characters
        chars_at_indexes = list(
            map(lambda idx: MODEL_CHARS[int(idx.item())],
                torch.argmax(input_tensor, dim=2).squeeze(0)))
        # Add observed characters to names
        for i, char in enumerate(chars_at_indexes):
            names[i] += char

    # Discard everything after the EOS character
    names = list(
        map(lambda name: name[:name.find(EOS)] if name.find(EOS) > -1 else name,
            names))
    return hidd_cell_states, names
def step_rnn(rnn, address, address_type, length=15, hidden_layer=None,
             custom_input_size=None):
    """
    Hacky helper for sampling from RNNs whose input/output dimension is not
    N_DIGIT or N_LETTER
    """
    name = ""
    next_char = '0'
    if hidden_layer is None:
        hidden_layer = rnn.init_hidden()
    for _ in range(length):
        lstm_input = strings_to_tensor([next_char], 1,
                                       custom_input_size=custom_input_size)
        next_char_probs, hidden_layer = rnn(lstm_input, hidden_layer)
        next_char_index = pyro.sample(f"{address_type}_{address}",
                                      dist.Categorical(next_char_probs)).item()
        next_char = str(next_char_index)
        name += next_char
        address += 1
    return name, hidden_layer, address
def generate_string(rnn, address, address_type, length=15, hidden_layer=None):
    """
    Given a character RNN, generate a digit string by sampling from the
    RNN's output distribution at each timestep
    """
    name = ""
    next_char = '0'  # TODO: is there a better alternative to seeding with '0'?
    if hidden_layer is None:
        hidden_layer = rnn.init_hidden()
    for _ in range(length):
        lstm_input = strings_to_tensor([next_char], 1, number_only=True)
        next_char_probs, hidden_layer = rnn(lstm_input, hidden_layer)
        next_char_index = pyro.sample(f"{address_type}_{address}",
                                      dist.Categorical(next_char_probs)).item()
        next_char = index_to_digit(next_char_index)
        name += next_char
        address += 1
    return name, hidden_layer, address
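# Both helpers return the incremented `address` so that consecutive segments
# get unique Pyro sample-site names. An illustrative sketch of the intended
# calling pattern; the segment lengths and the "digit" site prefix are
# assumptions, not the repository's actual usage:
address = 0
area_code, hidden, address = generate_string(rnn, address, "digit", length=3)
line_number, hidden, address = generate_string(rnn, address, "digit", length=4,
                                               hidden_layer=hidden)
phone_number = area_code + line_number  # site names ran "digit_0".."digit_6"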
svae = PhoneVAE(batch_size=1)
optimizer = Adam(ADAM_CONFIG)
svi = SVI(svae.model, svae.guide, optimizer, loss=Trace_ELBO())

"""
Train the model
"""
train_elbo = []
for e in range(NUM_EPOCHS):
    epoch_loss = 0.
    for string in TEST_STRINGS:
        # Pad the input string differently from the observed string so the
        # program isn't rewarded for making the string short
        one_hot_string = strings_to_tensor([string], MAX_STRING_LEN)
        if CUDA:
            one_hot_string = one_hot_string.cuda()
        # Take a single gradient step and accumulate its loss
        epoch_loss += svi.step(one_hot_string)
    if e % RECORD_EVERY == 0:
        avg_epoch_loss = epoch_loss / len(TEST_STRINGS)
        print(f"Epoch #{e} Average Loss: {avg_epoch_loss}")
        train_elbo.append(avg_epoch_loss)

plt.plot(train_elbo)
plt.title("ELBO")
plt.xlabel("epoch")
plt.ylabel("loss")
TEST_DATASET = ["+1 (604) 250 1363", "1-778-855-5941"]

phone_csis = PhoneCSIS(hidden_size=HIDDEN_SIZE)
phone_csis.load_checkpoint(filename=f"infcomp_{SESSION_NAME}.pth.tar")
csis = pyro.infer.CSIS(phone_csis.model, phone_csis.guide, Adam({'lr': 0.001}),
                       num_inference_samples=NUM_INFERENCE_SAMPLES)

from phone_infcomp import EXT

for phone_number in TEST_DATASET:
    print("=============================")
    print(f"Test Phone Number: {phone_number}")
    test_dataset = strings_to_tensor([phone_number],
                                     max_string_len=MAX_STRING_LEN,
                                     index_function=letter_to_index)
    posterior = csis.run(observations={'x': test_dataset})

    # Draw samples from the posterior
    # marginal = pyro.infer.EmpiricalMarginal(posterior, sites=static_sample_sites)
    # print(marginal._categorical.probs.T)
    print("Posterior Samples")
    csis_samples = [posterior() for _ in range(NUM_POSTERIOR_SAMPLES)]
    for sample in csis_samples:
        rv_names = sample.stochastic_nodes
        ext_format = sample.nodes['ext_format']['value'].item()
        ext_index = sample.nodes['ext_index']['value'].item()
        prefix_format = sample.nodes['prefix_format']['value'].item()
        prefix_len = sample.nodes['prefix_len']['value'].item()
        print(f"ext_format={ext_format}, ext_index={ext_index}, "
              f"prefix_format={prefix_format}, prefix_len={prefix_len}")
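# The snippet above loads a pre-trained inference network from a checkpoint.
# For completeness, a minimal sketch of how such a network could be trained
# with Pyro's CSIS API; NUM_TRAINING_STEPS and the logging interval are
# illustrative, not values from this repository.
for step in range(NUM_TRAINING_STEPS):
    # csis.step() draws synthetic data from the model and takes one gradient
    # step on the guide (inference network), returning the loss
    loss = csis.step()
    if step % 100 == 0:
        print(f"Step {step} inference network loss: {loss}")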