Example #1
 def init_hidden(self, num_samples):
     """
     Initializes the hidden state by passing the :math:`state_0` embedding
     and a zero-initialized hidden state through the graph encoder.
     """
     states = []
     for cell in self.cells:
         # LSTM cells carry two state tensors (h, c); other cells only one.
         num_states = 2 if isinstance(cell, nn.LSTMCell) else 1
         zeros = wrap(torch.zeros(num_samples, cell.hidden_size),
                      self.device)
         states.append(tuple([zeros.clone()] * num_states))
     # Encode the start token (index 0) to produce the initial hidden state.
     input = self.embedding(wrap([0], self.device, dtype=torch.long))
     states = self(input, states)
     return states
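These examples lean on a `wrap` helper that is not shown. A minimal sketch of what it might look like, assuming it only converts array-likes to tensors and optionally moves them to a device (the real implementation may differ), consistent with the call sites above:

import torch

def wrap(data, device=None, dtype=None):
    # Hypothetical: torch.as_tensor accepts lists, numpy arrays, and tensors,
    # reusing storage where possible; dtype is forwarded as-is.
    tensor = torch.as_tensor(data, dtype=dtype)
    # Move to the target device only when one is given.
    return tensor.to(device) if device is not None else tensor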
Example #2
 def torch_embedding(self) -> torch.nn.Embedding:
     """
     Returns:
        torch.nn.Embedding with pretrained weights if ``fit_embeddings`` was called,
         randomly initialized otherwise.
     """
     embedding = torch.nn.Embedding(self.vocab_size,
                                    self.embedding_stats['dims'])
     if self.embedding is not None:
         embedding.weight = torch.nn.Parameter(wrap(self.embedding))
     return embedding
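A possible usage sketch, assuming a fitted `tokenizer` object exposing the method above (names here are hypothetical); freezing the pretrained vectors is optional:

emb = tokenizer.torch_embedding()       # nn.Embedding(vocab_size, dims)
emb.weight.requires_grad_(False)        # optionally freeze pretrained vectors
token_ids = torch.tensor([[1, 5, 7]])   # a (batch, seq_len) index tensor
vectors = emb(token_ids)                # -> (1, 3, dims) float tensor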
Example #3
    def make_torch_dataset(self,
                           corpus=None,
                           labels=None,
                           labels_dtype=torch.long):
        """
        Makes a torch.utils.data.Dataset from ``corpus`` or ``self.sequences`` if ``corpus`` is None.

        Args:
            corpus: corpus of text to make a dataset of
            labels (iterable): labels, if there are any
            labels_dtype: data type of the labels, if there are any

        Returns:
            An instance of torch.utils.data.Dataset with corpus indices and labels, if there are any, as data.
        """
        # (Re)tokenize when a new corpus is given or no cached sequences exist.
        if corpus is not None or self.sequences is None:
            self.tokenize(corpus)

        sequence_tensor = wrap(self.sequences, dtype=torch.long)
        if 'predict' in self.task_type:
            # Next-step prediction: each element's target is the following one.
            self.dataset = TensorDataset(sequence_tensor[:-1],
                                         sequence_tensor[1:])
        elif 'class' in self.task_type:
            labels = labels if labels is not None else self.labels
            labels_tensor = wrap(labels, dtype=labels_dtype)
            self.dataset = TensorDataset(sequence_tensor, labels_tensor)
        else:

            # No labels: wrap the index tensor alone in a minimal Dataset.
            class MonoDataset(Dataset):
                def __init__(self, data_tensor):
                    self.data_tensor = data_tensor

                def __getitem__(self, index):
                    return self.data_tensor[index]

                def __len__(self):
                    return self.data_tensor.size(0)

            self.dataset = MonoDataset(sequence_tensor)

        return self.dataset
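A hypothetical classification usage, assuming the same `tokenizer` object with a `task_type` containing 'class' and iterables `texts` and `targets` (all names illustrative); batching is handled by a standard DataLoader:

from torch.utils.data import DataLoader

dataset = tokenizer.make_torch_dataset(corpus=texts, labels=targets)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for sequences, labels in loader:
    pass  # feed `sequences` and `labels` to a model and loss here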
Example #4
    def act(self, inputs, hidden, explore, key, output=None, pick=None):
        """
        Get an action for the graph representation encoded so far.

        Args:
            inputs (torch.Tensor): encoded graph representation
            hidden (list): list of encoder cell(s) hidden states
            explore (bool): whether to explore or exploit only
            key (str): key corresponding to policy and value layers
            output (defaultdict(list), optional): if provided, chosen logprobs, chosen values and
                entropies will be appended to those lists instead of being returned
            pick (int, optional): index of action to pick instead of sampling or argmax.

        Returns:
            action, hidden state, chosen logprobs, chosen values, entropies
        """
        assert key is not None, '`key` must not be None for act. ' \
                                'If you want to update hidden only, call `forward` instead'
        logits, values, hidden = self(inputs, hidden, key)
        distribution = Categorical(logits=logits)

        if pick is not None:
            # Forced action, e.g. when replaying an existing description.
            assert pick < logits.size(1)
            action = wrap([pick], self.device, dtype=torch.long)
        elif explore:
            # Sample stochastically from the policy distribution.
            action = distribution.sample()
        else:
            # Exploit: take the most probable action.
            action = distribution.probs.max(1)[1]

        chosen_logs = distribution.log_prob(action).unsqueeze(1)
        chosen_values = values.gather(1, action.unsqueeze(1))
        entropies = distribution.entropy()

        if output is not None:
            assert isinstance(output, defaultdict) and output.default_factory == list, \
                '`output` must be a defaultdict with `list` default_factory.'

            output['logprob'].append(chosen_logs)
            output['values'].append(chosen_values)
            output['entropies'].append(entropies)

            return action, hidden

        return action, hidden, chosen_logs, chosen_values, entropies
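A hedged rollout sketch for `act`, assuming `controller` is an instance of this class, `inputs` is an embedding tensor from a previous step, and 'layer_4_ops' is a key present in its policy/value heads (all names hypothetical):

from collections import defaultdict

output = defaultdict(list)                      # buffers filled in place by act
hidden = controller.init_hidden(num_samples=1)
action, hidden = controller.act(inputs, hidden, explore=True,
                                key='layer_4_ops', output=output)
# output['logprob'], output['values'] and output['entropies'] each gained one
# entry, ready for a policy-gradient update once a reward is observed.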
Example #5
    def _subsample(self,
                   hidden,
                   explore,
                   search_space,
                   names,
                   representations,
                   output,
                   description,
                   outer_i=None):
        """
        The recursive workhorse of `sample` method.

        Args:
            hidden (list): previous encoder hidden state
            explore (bool): whether to explore the search/action space or exploit only
            search_space (SearchSpace): current level of search space
            names (list): names of search spaces up to current level
            output (defaultdict(list), optional): dict of lists to append outputs to
            description (dict): description being generated
            outer_i (int): outer iteration ordinal (used in recursion, `None` on depth `0`)

        Returns:
            list: next hidden state
        """
        name = search_space.name
        names = copy.deepcopy(names)
        names.append(name)
        # These checks let the method be reused to evaluate model output for an already existing description.
        if description.get(name) is None:
            description[name] = {}

        # region num inner prediction
        num_inner = self.search_space.eval_(search_space.num_inner, **locals())

        # region forcing facilitation
        if description.get(f'num_{name}') is None:
            forced_inner = self.search_space.eval_(
                search_space.forced_num_inner, **locals())
            max_available = max(num_inner) if isinstance(
                num_inner, (list, tuple)) else num_inner
            assert forced_inner is None or isinstance(
                forced_inner, int) and 0 < forced_inner <= max_available

            if forced_inner is not None:
                try:
                    forced_inner = num_inner.index(forced_inner)
                except ValueError:
                    raise ValueError(
                        f'Number of inner search spaces "{forced_inner}" '
                        'is not present in original search space.')
        else:
            forced_inner = num_inner.index(description[f'num_{name}'])
        # endregion

        index = self.embedding_index[f'{name}_start']
        index = wrap([index], self.device, dtype=torch.long)
        input = self.embedding(index)

        if len(num_inner) > 1:
            key = f'{"_".join(names[:-1])}_{len(num_inner)}_{name}s'
            action, hidden = self.act(input, hidden, explore, key, output,
                                      forced_inner)
            num_inner = num_inner[action.item()]
        else:
            hidden = self(input, hidden)
            num_inner = num_inner[
                forced_inner] if forced_inner is not None else num_inner[0]

        if description.get(f'num_{name}') is None:
            description[f'num_{name}'] = num_inner
        # endregion

        # region inner space prediction
        index = self.embedding_index[f'{num_inner}_{name}s']
        index = wrap([index], self.device, dtype=torch.long)
        input = self.embedding(index)

        # Encode the current input at most once before recursing into inner spaces.
        encoded_flag = False
        for i in range(int(num_inner)):
            if description[name].get(i) is None:
                description[name][i] = {}

            if isinstance(search_space.inner, dict):
                for k, v in search_space.inner.items():
                    v = self.search_space.eval_(v, **locals())
                    key = f'{"_".join(names[:-1])}_{len(v)}_{k}s'

                    if isinstance(v, (list, tuple)) and len(v) > 1:
                        pick = description[name][i].get(k)
                        if pick is not None:
                            try:
                                pick = v.index(pick)
                            except ValueError:
                                raise ValueError(
                                    f'Point "{pick}" is not present in '
                                    f'{k} dimension of the search space.')

                        action, hidden = self.act(input, hidden, explore, key,
                                                  output, pick)

                        choice = v[action.item()]

                        if pick is None:
                            description[name][i][k] = choice
                        else:
                            assert choice == description[name][i][k]

                        if k == 'id':
                            if choice in representations:
                                input = representations[choice]
                                continue

                        index = self.embedding_index[f'{k}_{choice}']
                        index = wrap([index], self.device, dtype=torch.long)
                        input = self.embedding(index)
                    else:
                        if description[name][i].get(k) is None:
                            description[name][i][k] = v[0]
                        else:
                            assert v[0] == description[name][i][k]

            else:
                assert isinstance(search_space.inner, (list, tuple, SearchSpace)), \
                    'Inner search space must be either dict, SearchSpace or list of SearchSpaces.'

                if not encoded_flag:
                    hidden = self(input, hidden)
                    encoded_flag = True

                spaces = [search_space.inner] if isinstance(
                    search_space.inner, SearchSpace) else search_space.inner
                for space in spaces:
                    input = self._subsample(hidden, explore, space, names,
                                            representations, output,
                                            description[name][i], i)
                    hidden = self(input[-1][0], hidden)

        index = self.embedding_index[f'{name}_inner_done']
        index = wrap([index], self.device, dtype=torch.long)
        input = self.embedding(index)
        # endregion

        # region outer keys prediction
        for k, v in search_space.outer.items():
            v = self.search_space.eval_(v, **locals())
            key = f'{"_".join(names[:-1])}_{len(v)}_{k}s'

            if isinstance(v, (list, tuple)) and len(v) > 1:
                pick = description.get(k)
                if pick is not None:
                    try:
                        pick = v.index(pick)
                    except ValueError:
                        raise ValueError(f'Point "{pick}" is not present in '
                                         f'{k} dimension of the search space.')

                action, hidden = self.act(input, hidden, explore, key, output,
                                          pick)

                choice = v[action.item()]

                if pick is None:
                    description[k] = choice
                else:
                    assert choice == description[k]

                if k == 'id':
                    if choice in representations:
                        input = representations[choice]
                        continue

                index = self.embedding_index[f'{k}_{choice}']
                index = wrap([index], self.device, dtype=torch.long)
                input = self.embedding(index)
            else:
                # Single option: record it at the top level of the description,
                # mirroring the multi-option branch above.
                if description.get(k) is None:
                    description[k] = v[0]
                else:
                    assert v[0] == description[k]
        # endregion

        index = self.embedding_index[f'{name}_end']
        index = wrap([index], self.device, dtype=torch.long)
        input = self.embedding(index)

        hidden = self(input, hidden)
        if len(names) > 2:
            # Cache this subspace's final representation so later `id` choices
            # can reuse it in place of a fresh embedding.
            repr_key = f'{names[-2]}' if outer_i is None else f'{names[-2]}_{outer_i}'
            representations[repr_key] = hidden[-1][0]
        return hidden
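Downstream, the buffers filled by `act` during sampling are typically consumed by an actor-critic style update. A minimal sketch, assuming a single scalar `reward` for the sampled description and the `output` defaultdict from the rollout above (coefficients are arbitrary choices, not values from this codebase):

import torch

logprobs = torch.cat(output['logprob'])     # (num_decisions, 1)
values = torch.cat(output['values'])        # (num_decisions, 1)
entropies = torch.cat(output['entropies'])  # (num_decisions,)

advantage = reward - values                 # scalar reward broadcast per decision
policy_loss = -(logprobs * advantage.detach()).mean()
value_loss = advantage.pow(2).mean()
# Entropy bonus encourages exploration of the search space.
loss = policy_loss + 0.5 * value_loss - 0.01 * entropies.mean()
loss.backward()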