def init_hidden(self, num_samples):
    """
    Initializes hidden by passing :math:`state_0` embedding and
    zero-initialized hidden state to graph encoder.

    Args:
        num_samples (int): number of rows of every zero-initialized state tensor

    Returns:
        whatever calling the module on the embedding of index ``0`` and the
        zero states returns (the initial hidden state).
    """
    states = []
    for cell in self.cells:
        # LSTM cells carry two state tensors, every other cell type a single one.
        num_states = 2 if isinstance(cell, nn.LSTMCell) else 1
        zeros = wrap(torch.zeros(num_samples, cell.hidden_size), self.device)
        # Clone once per slot: `[zeros.clone()] * n` would repeat a reference to
        # one and the same tensor, so any in-place change to one state would
        # silently change the other as well.
        states.append(tuple(zeros.clone() for _ in range(num_states)))
    # Index 0 of the embedding table is fed as the very first encoder input.
    start_embedding = self.embedding(wrap([0], self.device, dtype=torch.long))
    return self(start_embedding, states)
def torch_embedding(self) -> torch.nn.Embedding:
    """
    Build a ``torch.nn.Embedding`` of shape ``(vocab_size, embedding_stats['dims'])``.

    Returns:
        torch.nn.Embedding whose weight is replaced by ``self.embedding`` when
        that attribute is set, randomly initialized otherwise.
    """
    vocab_size = self.vocab_size
    num_dims = self.embedding_stats['dims']
    layer = torch.nn.Embedding(vocab_size, num_dims)

    # Nothing stored on the instance: keep the random initialization.
    if self.embedding is None:
        return layer

    pretrained = wrap(self.embedding)
    layer.weight = torch.nn.Parameter(pretrained)
    return layer
def make_torch_dataset(self, corpus=None, labels=None, labels_dtype=torch.long):
    """
    Build a torch dataset from ``corpus`` (or from ``self.sequences`` when
    ``corpus`` is None) and store it in ``self.dataset``.

    Args:
        corpus: corpus of text to make a dataset of
        labels (iterable): labels, if there are any; ``self.labels`` is used
            when omitted and the task type contains ``'class'``
        labels_dtype: data type of the labels, if there are any

    Returns:
        An instance of torch.utils.data.Dataset holding the corpus indices
        and, if there are any, the labels.
    """
    # Re-tokenize unless no corpus was given and sequences already exist.
    if not (corpus is None and self.sequences is not None):
        self.tokenize(corpus)

    indices = wrap(self.sequences, dtype=torch.long)

    if 'predict' in self.task_type:
        # Each item is paired with the item that follows it.
        dataset = TensorDataset(indices[:-1], indices[1:])
    elif 'class' in self.task_type:
        if labels is None:
            labels = self.labels
        dataset = TensorDataset(indices, wrap(labels, dtype=labels_dtype))
    else:
        class MonoDataset(Dataset):
            """Dataset over a single tensor; items are rows, not 1-tuples."""

            def __init__(self, data_tensor):
                self.data_tensor = data_tensor

            def __len__(self):
                return self.data_tensor.size(0)

            def __getitem__(self, index):
                return self.data_tensor[index]

        dataset = MonoDataset(indices)

    self.dataset = dataset
    return self.dataset
def act(self, inputs, hidden, explore, key, output=None, pick=None):
    """
    Choose an action for the graph encoded so far.

    Args:
        inputs (torch.Tensor): encoded graph representation
        hidden (list): list of encoder cell(s) hidden states
        explore (bool): sample from the policy when true, take the most
            probable action otherwise
        key (str): key corresponding to policy and value layers
        output (defaultdict(list), optional): if provided, chosen logprobs,
            chosen values and entropies are appended to its ``'logprob'``,
            ``'values'`` and ``'entropies'`` lists instead of being returned
        pick (int, optional): index of action to pick instead of sampling
            or argmax.

    Returns:
        ``(action, hidden)`` when ``output`` is given, otherwise
        ``(action, hidden, chosen logprobs, chosen values, entropies)``.
    """
    assert key is not None, (
        '`key` must not be None for act. '
        'If you want to update hidden only, call `forward` instead'
    )

    logits, values, hidden = self(inputs, hidden, key)
    policy = Categorical(logits=logits)

    if pick is None:
        if explore:
            action = policy.sample()
        else:
            # Exploit: index of the most probable action in every row.
            _, action = policy.probs.max(1)
    else:
        # A forced action must address an existing logit column.
        assert pick < logits.size(1)
        action = wrap([pick], self.device, dtype=torch.long)

    chosen_logs = policy.log_prob(action).unsqueeze(1)
    chosen_values = values.gather(1, action.unsqueeze(1))
    entropies = policy.entropy()

    if output is None:
        return action, hidden, chosen_logs, chosen_values, entropies

    assert isinstance(output, defaultdict) and output.default_factory == list, \
        '`output` must be a defaultdict with `list` default_factory.'
    collected = (
        ('logprob', chosen_logs),
        ('values', chosen_values),
        ('entropies', entropies),
    )
    for slot, tensor in collected:
        output[slot].append(tensor)
    return action, hidden
def _subsample(self, hidden, explore, search_space, names, representations,
               output, description, outer_i=None):
    """
    The recursive workhorse of `sample` method.

    Processes one level of the search space in three steps: decide the number
    of inner spaces, fill in every inner space (recursing into nested search
    spaces), then decide the outer keys. Choices are written into
    ``description`` in place; values already present in ``description`` are
    forced through ``act`` instead of being sampled.

    Args:
        hidden (list): previous encoder hidden state
        explore (bool): whether to explore the search/action space or exploit only
        search_space (SearchSpace): current level of search space
        names (list): names of search spaces up to current level; it is
            deep-copied, so the caller's list is left untouched
        representations: mapping from a representation key to the encoder
            state ``hidden[-1][0]`` stored at the end of a call; looked up
            when an ``id`` key is chosen and updated in place
        output (defaultdict(list), optional): dict of lists to append outputs to
        description (dict): description being generated
        outer_i (int): outer iteration ordinal (used in recursion, `None` on depth `0`)

    Returns:
        list: next hidden state

    Raises:
        ValueError: if a forced number of inner spaces, or a value already
            present in ``description``, is not one of the options of the
            search space.
    """
    # NOTE(review): `**locals()` is forwarded to `eval_` below, so the local
    # variable names of this method are presumably visible to search-space
    # expressions. Rename or add locals with care.
    name = search_space.name
    names = copy.deepcopy(names)
    names.append(name)

    # Those checks were introduced to reuse this method to evaluate model output for already existing description.
    if description.get(name) is None:
        description[name] = {}

    # region num inner prediction
    num_inner = self.search_space.eval_(search_space.num_inner, **locals())

    # region forcing facilitation
    if description.get(f'num_{name}') is None:
        forced_inner = self.search_space.eval_(
            search_space.forced_num_inner, **locals())
        max_available = max(num_inner) if isinstance(
            num_inner, (list, tuple)) else num_inner
        assert forced_inner is None or isinstance(
            forced_inner, int) and 0 < forced_inner <= max_available
        if forced_inner is not None:
            # From here on `forced_inner` is an index into `num_inner`, not a count.
            try:
                forced_inner = num_inner.index(forced_inner)
            except ValueError:
                raise ValueError(
                    f'Number of inner search spaces "{forced_inner}" '
                    'is not present in original search space.')
    else:
        forced_inner = num_inner.index(description[f'num_{name}'])
    # endregion

    index = self.embedding_index[f'{name}_start']
    index = wrap([index], self.device, dtype=torch.long)
    input = self.embedding(index)
    if len(num_inner) > 1:
        key = f'{"_".join(names[:-1])}_{len(num_inner)}_{name}s'
        action, hidden = self.act(input, hidden, explore, key, output, forced_inner)
        num_inner = num_inner[action.item()]
    else:
        # Single option: nothing to choose, only advance the encoder.
        hidden = self(input, hidden)
        num_inner = num_inner[
            forced_inner] if forced_inner is not None else num_inner[0]
    if description.get(f'num_{name}') is None:
        description[f'num_{name}'] = num_inner
    # endregion

    # region inner space prediction
    index = self.embedding_index[f'{num_inner}_{name}s']
    index = wrap([index], self.device, dtype=torch.long)
    input = self.embedding(index)
    encoded_flag = False
    for i in range(int(num_inner)):
        if description[name].get(i) is None:
            description[name][i] = {}
        if isinstance(search_space.inner, dict):
            for k, v in search_space.inner.items():
                v = self.search_space.eval_(v, **locals())
                key = f'{"_".join(names[:-1])}_{len(v)}_{k}s'
                if isinstance(v, (list, tuple)) and len(v) > 1:
                    pick = description[name][i].get(k)
                    if pick is not None:
                        try:
                            pick = v.index(pick)
                        except ValueError:
                            raise ValueError(
                                f'Point "{pick}" is not present in '
                                f'{k} dimension of the search space.')
                    action, hidden = self.act(input, hidden, explore, key, output, pick)
                    choice = v[action.item()]
                    if pick is None:
                        description[name][i][k] = choice
                    else:
                        assert choice == description[name][i][k]
                    if k == 'id':
                        # A chosen id with a stored representation feeds that
                        # representation as the next input, not an embedding.
                        if choice in representations:
                            input = representations[choice]
                            continue
                    index = self.embedding_index[f'{k}_{choice}']
                    index = wrap([index], self.device, dtype=torch.long)
                    input = self.embedding(index)
                else:
                    # Single option: record it (or check it) without acting.
                    if description[name][i].get(k) is None:
                        description[name][i][k] = v[0]
                    else:
                        assert v[0] == description[name][i][k]
        else:
            assert isinstance(search_space.inner, (list, tuple, SearchSpace)), \
                'Inner search space must be either dict, SearchSpace or list of SearchSpaces.'
            # The pending input is encoded only once, before the first recursion.
            if not encoded_flag:
                hidden = self(input, hidden)
                encoded_flag = True
            spaces = [search_space.inner] if isinstance(
                search_space.inner, SearchSpace) else search_space.inner
            for space in spaces:
                # The recursive call returns a hidden state; its `[-1][0]`
                # entry is fed back into the encoder as the next input.
                input = self._subsample(hidden, explore, space, names, representations,
                                        output, description[name][i], i)
                hidden = self(input[-1][0], hidden)
    index = self.embedding_index[f'{name}_inner_done']
    index = wrap([index], self.device, dtype=torch.long)
    input = self.embedding(index)
    # endregion

    # region outer keys prediction
    for k, v in search_space.outer.items():
        v = self.search_space.eval_(v, **locals())
        key = f'{"_".join(names[:-1])}_{len(v)}_{k}s'
        if isinstance(v, (list, tuple)) and len(v) > 1:
            pick = description.get(k)
            if pick is not None:
                try:
                    pick = v.index(pick)
                except ValueError:
                    raise ValueError(f'Point "{pick}" is not present in '
                                     f'{k} dimension of the search space.')
            action, hidden = self.act(input, hidden, explore, key, output, pick)
            choice = v[action.item()]
            if pick is None:
                description[k] = choice
            else:
                assert choice == description[k]
            if k == 'id':
                if choice in representations:
                    input = representations[choice]
                    continue
            index = self.embedding_index[f'{k}_{choice}']
            index = wrap([index], self.device, dtype=torch.long)
            input = self.embedding(index)
        else:
            # Outer keys belong to `description` itself, as in the multi-option
            # branch above. Indexing `description[name][i]` here would use the
            # leftover inner-loop index (unbound when there are no inner spaces).
            if description.get(k) is None:
                description[k] = v[0]
            else:
                assert v[0] == description[k]
    # endregion

    index = self.embedding_index[f'{name}_end']
    index = wrap([index], self.device, dtype=torch.long)
    input = self.embedding(index)
    hidden = self(input, hidden)
    if len(names) > 2:
        # Store the final state under the enclosing space's name, suffixed
        # with the outer iteration ordinal when there is one.
        repr_key = f'{names[-2]}' if outer_i is None else f'{names[-2]}_{outer_i}'
        representations[repr_key] = hidden[-1][0]
    return hidden