def load(self):
     """Load the model from the last checkpoint."""
     if not self.load_path.exists():
         log.info("[initializing new `{}`]".format(self.__class__.__name__))
         self.dict.init_from_scratch()
         self.embdict.init_from_scratch(self.dict.tok2int_vocab)
         if hasattr(self.dict, 'char2int_vocab'):
             chars_num = len(self.dict.char2int_vocab)
         else:
             chars_num = 0
         self._net = RankingNetwork(chars_num=chars_num,
                                    toks_num=len(self.dict.tok2int_vocab),
                                    emb_dict=self.embdict,
                                    **self.network_parameters)
         self._net.init_from_scratch(self.embdict.emb_matrix)
     else:
         log.info("[initializing `{}` from saved]".format(
             self.__class__.__name__))
         self.dict.load()
         self.embdict.load()
         if hasattr(self.dict, 'char2int_vocab'):
             chars_num = len(self.dict.char2int_vocab)
         else:
             chars_num = 0
         self._net = RankingNetwork(chars_num=chars_num,
                                    toks_num=len(self.dict.tok2int_vocab),
                                    emb_dict=self.embdict,
                                    **self.network_parameters)
         self._net.load(self.load_path)
Example #2
 def load(self):
     """Load the model from the last checkpoint."""
     if not self.load_path.exists():
         log.info("[initializing new `{}`]".format(self.__class__.__name__))
         self.dict.init_from_scratch()
         self.embdict.init_from_scratch(self.dict.tok2int_vocab)
         if hasattr(self.dict, 'char2int_vocab'):
             chars_num = len(self.dict.char2int_vocab)
         else:
             chars_num = 0
         self._net = RankingNetwork(chars_num=chars_num,
                                    toks_num=len(self.dict.tok2int_vocab),
                                    emb_dict=self.embdict,
                                    **self.network_parameters)
         self._net.init_from_scratch(self.embdict.emb_matrix)
     else:
         log.info("[initializing `{}` from saved]".format(self.__class__.__name__))
         self.dict.load()
         self.embdict.load()
         if hasattr(self.dict, 'char2int_vocab'):
             chars_num = len(self.dict.char2int_vocab)
         else:
             chars_num = 0
         self._net = RankingNetwork(chars_num=chars_num,
                                    toks_num=len(self.dict.tok2int_vocab),
                                    emb_dict=self.embdict,
                                    **self.network_parameters)
         self._net.load(self.load_path)
Example #3
 def load(self):
     if not self.load_path.exists():
         log.info("[initializing new `{}`]".format(self.__class__.__name__))
         self.dict.init_from_scratch()
         self._net = RankingNetwork(len(self.dict.tok2int_vocab), **self.network_parameters)
         embdict_parameter_names = list(inspect.signature(Embeddings.__init__).parameters)
         embdict_parameters = {par: self.opt[par] for par in embdict_parameter_names if par in self.opt}
         embdict = Embeddings(self.dict.tok2int_vocab, **embdict_parameters)
         self._net.set_emb_matrix(embdict.emb_matrix)
     else:
         log.info("[initializing `{}` from saved]".format(self.__class__.__name__))
         self.dict.load()
         self._net = RankingNetwork(len(self.dict.tok2int_vocab), **self.network_parameters)
         self._net.load(self.load_path)
Example #4
 def load(self):
     if not self.load_path.exists():
         log.info("[initializing new `{}`]".format(self.__class__.__name__))
         self.dict.init_from_scratch()
         self.embdict.init_from_scratch(self.dict.tok2int_vocab)
         self._net = RankingNetwork(toks_num=len(self.dict.tok2int_vocab),
                                    emb_dict=self.embdict,
                                    **self.network_parameters)
         self._net.init_from_scratch(self.embdict.emb_matrix)
     else:
         log.info("[initializing `{}` from saved]".format(self.__class__.__name__))
         self.dict.load()
         self.embdict.load()
         self._net = RankingNetwork(toks_num=len(self.dict.tok2int_vocab),
                                    emb_dict=self.embdict,
                                    **self.network_parameters)
         self._net.load(self.load_path)
Example #5
class RankingModel(NNModel):
    def __init__(self, **kwargs):
        """ Initialize the model and additional parent classes attributes

        Args:
            **kwargs: a dictionary containing parameters for model and parameters for training
                      it formed from json config file part that correspond to your model.

        """

        # Parameters for parent classes
        save_path = kwargs.get('save_path', None)
        load_path = kwargs.get('load_path', None)
        train_now = kwargs.get('train_now', None)
        mode = kwargs.get('mode', None)

        # Call parent constructors. This adds the attributes save_path, load_path,
        # train_now and mode to the current instance and creates the save folder
        # if it doesn't exist.
        super().__init__(save_path=save_path, load_path=load_path,
                         train_now=train_now, mode=mode)

        opt = deepcopy(kwargs)
        self.train_now = opt['train_now']
        self.opt = opt
        self.interact_pred_num = opt['interact_pred_num']
        self.vocabs = opt.get('vocabs', None)

        dict_parameter_names = list(inspect.signature(InsuranceDict.__init__).parameters)
        dict_parameters = {par: opt[par] for par in dict_parameter_names if par in opt}
        network_parameter_names = list(inspect.signature(RankingNetwork.__init__).parameters)
        self.network_parameters = {par: opt[par] for par in network_parameter_names if par in opt}

        self.dict = InsuranceDict(**dict_parameters)

        self.load()

        train_parameters_names = list(inspect.signature(self._net.train_on_batch).parameters)
        self.train_parameters = {par: opt[par] for par in train_parameters_names if par in opt}

    @overrides
    def load(self):
        if not self.load_path.exists():
            log.info("[initializing new `{}`]".format(self.__class__.__name__))
            self.dict.init_from_scratch()
            self._net = RankingNetwork(len(self.dict.tok2int_vocab), **self.network_parameters)
            embdict_parameter_names = list(inspect.signature(Embeddings.__init__).parameters)
            embdict_parameters = {par: self.opt[par] for par in embdict_parameter_names if par in self.opt}
            embdict = Embeddings(self.dict.tok2int_vocab, **embdict_parameters)
            self._net.set_emb_matrix(embdict.emb_matrix)
        else:
            log.info("[initializing `{}` from saved]".format(self.__class__.__name__))
            self.dict.load()
            self._net = RankingNetwork(len(self.dict.tok2int_vocab), **self.network_parameters)
            self._net.load(self.load_path)

    @overrides
    def save(self):
        """Save model to the save_path, provided in config. The directory is
        already created by super().__init__ part in called in __init__ of this class"""
        log.info('[saving model to {}]'.format(self.save_path.resolve()))
        self._net.save(self.save_path)
        self.set_embeddings()
        self.dict.save()

    @check_attr_true('train_now')
    def train_on_batch(self, x, y):
        self.reset_embeddings()
        context = [el[0] for el in x]
        response = [el[1] for el in x]
        negative_response = [el[2] for el in x]
        c = self.dict.make_toks(context, type="context")
        c = self.dict.make_ints(c)
        rp = self.dict.make_toks(response, type="response")
        rp = self.dict.make_ints(rp)
        rn = self.dict.make_toks(negative_response, type="response")
        rn = self.dict.make_ints(rn)
        b = [c, rp, rn], y
        self._net.train_on_batch(b)

    @overrides
    def __call__(self, batch):
        self.set_embeddings()
        if type(batch[0]) == list:
            context = [el[0] for el in batch]
            c = self.dict.make_toks(context, type="context")
            c = self.dict.make_ints(c)
            c_emb = self._net.predict_context_emb([c, c, c], bs=len(batch))
            response = [el[1] for el in batch]
            batch_size = len(response)
            ranking_length = len(response[0])
            response = reduce(operator.concat, response)
            response = [response[i:batch_size*ranking_length:ranking_length] for i in range(ranking_length)]
            y_pred = []
            for i in range(ranking_length):
                r_emb = [self.dict.response2emb_vocab[el] for el in response[i]]
                r_emb = np.vstack(r_emb)
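                # Cosine similarity between each context embedding and the i-th candidate response embedding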
                yp = np.sum(c_emb * r_emb, axis=1) / np.linalg.norm(c_emb, axis=1) / np.linalg.norm(r_emb, axis=1)
                y_pred.append(np.expand_dims(yp, axis=1))
            y_pred = np.hstack(y_pred)
            return y_pred

        elif type(batch[0]) == str:
            c_input = tokenize(batch)
            c_input = self.dict.make_ints(c_input)
            c_input_emb = self._net.predict_context_emb([c_input, c_input, c_input], bs=1)

            c_emb = [self.dict.context2emb_vocab[i] for i in range(len(self.dict.context2emb_vocab))]
            c_emb = np.vstack(c_emb)
            pred_cont = np.sum(c_input_emb * c_emb, axis=1)\
                     / np.linalg.norm(c_input_emb, axis=1) / np.linalg.norm(c_emb, axis=1)
            pred_cont = np.flip(np.argsort(pred_cont), 0)[:self.interact_pred_num]
            pred_cont = [' '.join(self.dict.context2toks_vocab[el]) for el in pred_cont]

            r_emb = [self.dict.response2emb_vocab[i] for i in range(len(self.dict.response2emb_vocab))]
            r_emb = np.vstack(r_emb)
            pred_resp = np.sum(c_input_emb * r_emb, axis=1)\
                     / np.linalg.norm(c_input_emb, axis=1) / np.linalg.norm(r_emb, axis=1)
            pred_resp = np.flip(np.argsort(pred_resp), 0)[:self.interact_pred_num]
            pred_resp = [' '.join(self.dict.response2toks_vocab[el]) for el in pred_resp]
            y_pred = [{"contexts": pred_cont, "responses": pred_resp}]
            return y_pred

    def set_embeddings(self):
        if self.dict.response2emb_vocab[0] is None:
            r = []
            for i in range(len(self.dict.response2toks_vocab)):
                r.append(self.dict.response2toks_vocab[i])
            r = self.dict.make_ints(r)
            response_embeddings = self._net.predict_response_emb([r, r, r], 512)
            for i in range(len(self.dict.response2toks_vocab)):
                self.dict.response2emb_vocab[i] = response_embeddings[i]
        if self.dict.context2emb_vocab[0] is None:
            contexts = []
            for i in range(len(self.dict.context2toks_vocab)):
                contexts.append(self.dict.context2toks_vocab[i])
            contexts = self.dict.make_ints(contexts)
            context_embeddings = self._net.predict_context_emb([contexts, contexts, contexts], 512)
            for i in range(len(self.dict.context2toks_vocab)):
                self.dict.context2emb_vocab[i] = context_embeddings[i]

    def reset_embeddings(self):
        if self.dict.response2emb_vocab[0] is not None:
            for i in range(len(self.dict.response2emb_vocab)):
                self.dict.response2emb_vocab[i] = None
        if self.dict.context2emb_vocab[0] is not None:
            for i in range(len(self.dict.context2emb_vocab)):
                self.dict.context2emb_vocab[i] = None

    def shutdown(self):
        pass

    def reset(self):
        pass
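
The constructors in these examples route one flat options dict to several components by filtering it against each callable's signature with inspect.signature. A minimal, self-contained sketch of that idiom (build_network and its parameters are made up for illustration, not part of DeepPavlov):

import inspect


def build_network(toks_num, hidden_dim=128, dropout=0.1):
    # Stand-in for a component constructor; only its signature matters here.
    return {"toks_num": toks_num, "hidden_dim": hidden_dim, "dropout": dropout}


opt = {"toks_num": 10000, "hidden_dim": 256, "learning_rate": 1e-3}  # flat config dict

# Keep only the keys that build_network actually accepts; extra keys are ignored.
parameter_names = list(inspect.signature(build_network).parameters)
network_parameters = {par: opt[par] for par in parameter_names if par in opt}

net = build_network(**network_parameters)  # 'learning_rate' is silently dropped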
class RankingModel(NNModel):
    """Class to perform ranking.

    Args:
        vocab_name: A keyword that indicates which subclass
            of the :class:`deeppavlov.models.ranking.ranking_dict.RankingDict` to use.
        hard_triplets_sampling: Whether to use hard triplets sampling to train the model,
            i.e. to choose negative samples close to positive ones.
        hardest_positives: Whether to use only the single hardest positive sample for each anchor sample.
        semi_hard_negatives: Whether hard negative samples should be farther away from anchor samples
            than positive samples or not.
        update_embeddings: Whether to store and update context and response embeddings or not.
        interact_pred_num: The number of the most relevant contexts and responses
            which the model returns in the `interact` regime.
        **kwargs: Other parameters.
    """
    def __init__(self,
                 vocab_name,
                 hard_triplets_sampling: bool = False,
                 hardest_positives: bool = False,
                 semi_hard_negatives: bool = False,
                 num_hardest_negatives: int = None,
                 update_embeddings: bool = False,
                 interact_pred_num: int = 3,
                 **kwargs):

        # Parameters for parent classes
        save_path = kwargs.get('save_path', None)
        load_path = kwargs.get('load_path', None)
        train_now = kwargs.get('train_now', None)
        mode = kwargs.get('mode', None)

        super().__init__(save_path=save_path,
                         load_path=load_path,
                         train_now=train_now,
                         mode=mode)

        self.hard_triplets_sampling = hard_triplets_sampling
        self.hardest_positives = hardest_positives
        self.semi_hard_negatives = semi_hard_negatives
        self.num_hardest_negatives = num_hardest_negatives
        self.upd_embs = update_embeddings
        self.interact_pred_num = interact_pred_num
        self.train_now = train_now
        self.vocab_name = vocab_name

        opt = deepcopy(kwargs)

        if self.vocab_name == "insurance":
            dict_parameter_names = list(
                inspect.signature(InsuranceDict.__init__).parameters)
            dict_parameters = {
                par: opt[par]
                for par in dict_parameter_names if par in opt
            }
            self.dict = InsuranceDict(**dict_parameters,
                                      update_embeddings=update_embeddings)

        embdict_parameter_names = list(
            inspect.signature(EmbDict.__init__).parameters)
        embdict_parameters = {
            par: opt[par]
            for par in embdict_parameter_names if par in opt
        }
        self.embdict = EmbDict(**embdict_parameters)

        network_parameter_names = list(
            inspect.signature(RankingNetwork.__init__).parameters)
        self.network_parameters = {
            par: opt[par]
            for par in network_parameter_names if par in opt
        }

        self.load()

        train_parameters_names = list(
            inspect.signature(self._net.train_on_batch).parameters)
        self.train_parameters = {
            par: opt[par]
            for par in train_parameters_names if par in opt
        }

    @overrides
    def load(self):
        """Load the model from the last checkpoint."""
        if not self.load_path.exists():
            log.info("[initializing new `{}`]".format(self.__class__.__name__))
            self.dict.init_from_scratch()
            self.embdict.init_from_scratch(self.dict.tok2int_vocab)
            if hasattr(self.dict, 'char2int_vocab'):
                chars_num = len(self.dict.char2int_vocab)
            else:
                chars_num = 0
            self._net = RankingNetwork(chars_num=chars_num,
                                       toks_num=len(self.dict.tok2int_vocab),
                                       emb_dict=self.embdict,
                                       **self.network_parameters)
            self._net.init_from_scratch(self.embdict.emb_matrix)
        else:
            log.info("[initializing `{}` from saved]".format(
                self.__class__.__name__))
            self.dict.load()
            self.embdict.load()
            if hasattr(self.dict, 'char2int_vocab'):
                chars_num = len(self.dict.char2int_vocab)
            else:
                chars_num = 0
            self._net = RankingNetwork(chars_num=chars_num,
                                       toks_num=len(self.dict.tok2int_vocab),
                                       emb_dict=self.embdict,
                                       **self.network_parameters)
            self._net.load(self.load_path)

    @overrides
    def save(self):
        """Save the model."""
        log.info('[saving model to {}]'.format(self.save_path.resolve()))
        self._net.save(self.save_path)
        if self.upd_embs:
            self.set_embeddings()
        self.dict.save()
        self.embdict.save()

    @check_attr_true('train_now')
    def train_on_batch(self, x: List[List[Tuple[int, int]]], y: List[int]):
        """Train the model on a batch."""
        if self.upd_embs:
            self.reset_embeddings()
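        # With hard triplet sampling, positive and negative responses are mined from the batch itself and all targets become 1.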
        if self.hard_triplets_sampling:
            b = self.make_hard_triplets(x, y, self._net)
            y = np.ones(len(b[0][0]))
        else:
            b = self.make_batch(x)
            for i in range(len(x[0])):
                c = self.dict.make_toks(b[i][0], type="context")
                c = self.dict.make_ints(c)
                b[i][0] = c
                r = self.dict.make_toks(b[i][1], type="response")
                r = self.dict.make_ints(r)
                b[i][1] = r
        self._net.train_on_batch(b, y)

    def make_batch(self, x):
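        # Transpose a batch of samples into per-position [contexts, responses] lists.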
        sample_len = len(x[0])
        b = []
        for i in range(sample_len):
            c = []
            r = []
            for el in x:
                c.append(el[i][0])
                r.append(el[i][1])
            b.append([c, r])
        return b

    def make_hard_triplets(self, x, y, net):
        samples = [[s[1] for s in el] for el in x]
        labels = y
        batch_size = len(samples)
        num_samples = len(samples[0])
        samp = [y for el in samples for y in el]
        s = self.dict.make_toks(samp, type="context")
        s = self.dict.make_ints(s)

        embeddings = net.predict_embedding([s, s], 512, type='context')
        embeddings = embeddings / np.expand_dims(
            np.linalg.norm(embeddings, axis=1), axis=1)
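        # Pairwise Euclidean distances from the Gram matrix:
        # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, clipped at zero before the square root.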
        dot_product = embeddings @ embeddings.T
        square_norm = np.diag(dot_product)
        distances = np.expand_dims(square_norm,
                                   0) - 2.0 * dot_product + np.expand_dims(
                                       square_norm, 1)
        distances = np.maximum(distances, 0.0)
        distances = np.sqrt(distances)

        mask_anchor_negative = np.expand_dims(np.repeat(labels, num_samples), 0)\
                               != np.expand_dims(np.repeat(labels, num_samples), 1)
        mask_anchor_negative = mask_anchor_negative.astype(float)
        max_anchor_negative_dist = np.max(distances, axis=1, keepdims=True)
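        # Adding the row maximum to same-label pairs ensures argmin/argsort below picks genuine (different-label) negatives.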
        anchor_negative_dist = distances + max_anchor_negative_dist * (
            1.0 - mask_anchor_negative)
        if self.num_hardest_negatives is not None:
            hard = np.argsort(anchor_negative_dist,
                              axis=1)[:, :self.num_hardest_negatives]
            ind = np.random.randint(self.num_hardest_negatives,
                                    size=batch_size * num_samples)
            hardest_negative_ind = hard[batch_size * num_samples * [True], ind]
        else:
            hardest_negative_ind = np.argmin(anchor_negative_dist, axis=1)

        mask_anchor_positive = np.expand_dims(np.repeat(labels, num_samples), 0) \
                               == np.expand_dims(np.repeat(labels, num_samples), 1)
        mask_anchor_positive = mask_anchor_positive.astype(float)
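        # Keep only same-label (positive) pair distances; entries for different labels are zeroed.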
        anchor_positive_dist = mask_anchor_positive * distances

        c = []
        rp = []
        rn = []
        hrds = []

        if self.hardest_positives:

            if self.semi_hard_negatives:
                hardest_positive_ind = []
                hardest_negative_ind = []
                for p, n in zip(anchor_positive_dist, anchor_negative_dist):
                    no_samples = True
                    p_li = list(
                        zip(p, np.arange(batch_size * num_samples),
                            batch_size * num_samples * [True]))
                    n_li = list(
                        zip(n, np.arange(batch_size * num_samples),
                            batch_size * num_samples * [False]))
                    pn_li = sorted(p_li + n_li, key=lambda el: el[0])
                    for i, x in enumerate(pn_li):
                        if not x[2]:
                            for y in pn_li[:i][::-1]:
                                if y[2] and y[0] > 0.0:
                                    assert (x[1] != y[1])
                                    hardest_negative_ind.append(x[1])
                                    hardest_positive_ind.append(y[1])
                                    no_samples = False
                                    break
                        if not no_samples:
                            break
                    if no_samples:
                        log.error(
                            "There are no negative examples with distances"
                            " greater than positive examples distances.")
                        exit(0)
            else:
                if self.num_hardest_negatives is not None:
                    hard = np.argsort(anchor_positive_dist,
                                      axis=1)[:, -self.num_hardest_negatives:]
                    ind = np.random.randint(self.num_hardest_negatives,
                                            size=batch_size * num_samples)
                    hardest_positive_ind = hard[batch_size * num_samples *
                                                [True], ind]
                else:
                    hardest_positive_ind = np.argmax(anchor_positive_dist,
                                                     axis=1)

            for i in range(batch_size):
                for j in range(num_samples):
                    c.append(s[i * num_samples + j])
                    rp.append(s[hardest_positive_ind[i * num_samples + j]])
                    rn.append(s[hardest_negative_ind[i * num_samples + j]])

        else:
            if self.semi_hard_negatives:
                for i in range(batch_size):
                    for j in range(num_samples):
                        for k in range(j + 1, num_samples):
                            c.append(s[i * num_samples + j])
                            c.append(s[i * num_samples + k])
                            rp.append(s[i * num_samples + k])
                            rp.append(s[i * num_samples + j])
                            n, hrd = self.get_semi_hard_negative_ind(
                                i, j, k, distances, anchor_negative_dist,
                                batch_size, num_samples)
                            assert (n != i * num_samples + k)
                            rn.append(s[n])
                            hrds.append(hrd)
                            n, hrd = self.get_semi_hard_negative_ind(
                                i, k, j, distances, anchor_negative_dist,
                                batch_size, num_samples)
                            assert (n != i * num_samples + j)
                            rn.append(s[n])
                            hrds.append(hrd)
            else:
                for i in range(batch_size):
                    for j in range(num_samples):
                        for k in range(j + 1, num_samples):
                            c.append(s[i * num_samples + j])
                            c.append(s[i * num_samples + k])
                            rp.append(s[i * num_samples + k])
                            rp.append(s[i * num_samples + j])
                            rn.append(s[hardest_negative_ind[i * num_samples +
                                                             j]])
                            rn.append(s[hardest_negative_ind[i * num_samples +
                                                             k]])

        triplets = list(zip(c, rp, rn))
        np.random.shuffle(triplets)
        c = [el[0] for el in triplets]
        rp = [el[1] for el in triplets]
        rn = [el[2] for el in triplets]
        if hrds:  # hrds is only populated when semi-hard negatives are mined
            ratio = sum(hrds) / len(hrds)
            log.info("Ratio of semi-hard negative samples is %f" % ratio)
        return [(c, rp), (c, rn)]

    def get_semi_hard_negative_ind(self, i, j, k, distances,
                                   anchor_negative_dist, batch_size,
                                   num_samples):
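        # Return the index of the closest negative outside the anchor's own group that is still
        # farther than the anchor-positive distance; fall back to a random negative (flag False) if none exists.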
        anc_pos_dist = distances[i * num_samples + j, i * num_samples + k]
        neg_dists = anchor_negative_dist[i * num_samples + j]
        n_li_pre = sorted(list(
            zip(neg_dists, np.arange(batch_size * num_samples))),
                          key=lambda el: el[0])
        n_li = list(filter(lambda x: x[1]<i*num_samples, n_li_pre)) + \
               list(filter(lambda x: x[1]>=(i+1)*num_samples, n_li_pre))
        for x in n_li:
            if x[0] > anc_pos_dist:
                return x[1], True
        return random.choice(n_li)[1], False

    def __call__(self, batch: Union[List[List[Tuple[int, int]]], List[str]]) ->\
            Union[np.ndarray, Dict[str, List[str]]]:
        """Make a prediction on a batch."""
        if type(batch[0]) == list:
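            # Ranking mode: each batch element is a list of (context, response) pairs; return one score per pair.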
            y_pred = []
            b = self.make_batch(batch)
            for el in b:
                c = self.dict.make_toks(el[0], type="context")
                c = self.dict.make_ints(c)
                r = self.dict.make_toks(el[1], type="response")
                r = self.dict.make_ints(r)
                yp = self._net.predict_score_on_batch([c, r])
                y_pred.append(yp)
            y_pred = np.hstack(y_pred)
            return y_pred

        elif type(batch[0]) == str:
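            # Interact mode: embed the raw string queries and rank them against all cached context and response embeddings.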
            c_input = tokenize(batch)
            c_input = self.dict.make_ints(c_input)
            c_input_emb = self._net.predict_embedding_on_batch(
                [c_input, c_input], type='context')

            c_emb = [
                self.dict.context2emb_vocab[i]
                for i in range(len(self.dict.context2emb_vocab))
            ]
            c_emb = np.vstack(c_emb)
            pred_cont = np.sum(c_input_emb * c_emb, axis=1)\
                     / np.linalg.norm(c_input_emb, axis=1) / np.linalg.norm(c_emb, axis=1)
            pred_cont = np.flip(np.argsort(pred_cont),
                                0)[:self.interact_pred_num]
            pred_cont = [
                ' '.join(self.dict.context2toks_vocab[el]) for el in pred_cont
            ]

            r_emb = [
                self.dict.response2emb_vocab[i]
                for i in range(len(self.dict.response2emb_vocab))
            ]
            r_emb = np.vstack(r_emb)
            pred_resp = np.sum(c_input_emb * r_emb, axis=1)\
                     / np.linalg.norm(c_input_emb, axis=1) / np.linalg.norm(r_emb, axis=1)
            pred_resp = np.flip(np.argsort(pred_resp),
                                0)[:self.interact_pred_num]
            pred_resp = [
                ' '.join(self.dict.response2toks_vocab[el]) for el in pred_resp
            ]
            y_pred = [{"contexts": pred_cont, "responses": pred_resp}]
            return y_pred

    def update_embeddings(self, batch):
        sample_len = len(batch[0])
        labels_cont = []
        labels_resp = []
        cont = []
        resp = []
        for i in range(sample_len):
            lc = []
            lr = []
            for el in batch:
                lc.append(el[i][0])
                lr.append(el[i][1])
            labels_cont.append(lc)
            labels_resp.append(lr)
        for i in range(sample_len):
            c = self.dict.make_toks(labels_cont[i], type="context")
            c = self.dict.make_ints(c)
            cont.append(c)
            r = self.dict.make_toks(labels_resp[i], type="response")
            r = self.dict.make_ints(r)
            resp.append(r)
        for el in zip(labels_cont, cont):
            c_emb = self._net.predict_embedding_on_batch([el[1], el[1]],
                                                         type='context')
            for i in range(len(el[0])):
                self.dict.context2emb_vocab[el[0][i]] = c_emb[i]
        for el in zip(labels_resp, resp):
            r_emb = self._net.predict_embedding_on_batch([el[1], el[1]],
                                                         type='response')
            for i in range(len(el[0])):
                self.dict.response2emb_vocab[el[0][i]] = r_emb[i]

    def set_embeddings(self):
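        # Compute and cache embeddings for every known response and context if they are not cached yet.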
        if self.dict.response2emb_vocab[0] is None:
            r = []
            for i in range(len(self.dict.response2toks_vocab)):
                r.append(self.dict.response2toks_vocab[i])
            r = self.dict.make_ints(r)
            response_embeddings = self._net.predict_embedding([r, r],
                                                              512,
                                                              type='response')
            for i in range(len(self.dict.response2toks_vocab)):
                self.dict.response2emb_vocab[i] = response_embeddings[i]
        if self.dict.context2emb_vocab[0] is None:
            c = []
            for i in range(len(self.dict.context2toks_vocab)):
                c.append(self.dict.context2toks_vocab[i])
            c = self.dict.make_ints(c)
            context_embeddings = self._net.predict_embedding([c, c],
                                                             512,
                                                             type='context')
            for i in range(len(self.dict.context2toks_vocab)):
                self.dict.context2emb_vocab[i] = context_embeddings[i]

    def reset_embeddings(self):
        if self.dict.response2emb_vocab[0] is not None:
            for i in range(len(self.dict.response2emb_vocab)):
                self.dict.response2emb_vocab[i] = None
        if self.dict.context2emb_vocab[0] is not None:
            for i in range(len(self.dict.context2emb_vocab)):
                self.dict.context2emb_vocab[i] = None

    def shutdown(self):
        pass

    def reset(self):
        pass
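
The make_hard_triplets method above derives a full pairwise distance matrix from L2-normalized embeddings and masks same-label pairs before selecting the hardest negatives. A self-contained NumPy sketch of that distance-and-masking step on made-up toy data (illustration only, not the model's code):

import numpy as np

rng = np.random.default_rng(0)
emb = rng.normal(size=(6, 4))                       # toy embeddings: 6 samples, dim 4
emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
labels = np.array([0, 0, 0, 1, 1, 1])               # two groups of three samples each

dot = emb @ emb.T
sq_norm = np.diag(dot)
# ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, clipped at zero before the square root
dist = np.sqrt(np.maximum(sq_norm[None, :] - 2.0 * dot + sq_norm[:, None], 0.0))

# Push same-label distances above the row maximum so that argmin returns,
# for every anchor, the closest sample with a *different* label.
neg_mask = (labels[None, :] != labels[:, None]).astype(float)
anchor_negative_dist = dist + dist.max(axis=1, keepdims=True) * (1.0 - neg_mask)
hardest_negative_ind = np.argmin(anchor_negative_dist, axis=1)
print(hardest_negative_ind)                         # one hardest negative per anchor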
Example #7
class RankingModel(NNModel):
    """Class to perform ranking.

    Args:
        vocab_name: A keyword that indicates which subclass
            of the :class:`deeppavlov.models.ranking.ranking_dict.RankingDict` to use.
        hard_triplets_sampling: Whether to use hard triplets sampling to train the model,
            i.e. to choose negative samples close to positive ones.
        hardest_positives: Whether to use only the single hardest positive sample for each anchor sample.
        semi_hard_negatives: Whether hard negative samples should be farther away from anchor samples
            than positive samples or not.
        update_embeddings: Whether to store and update context and response embeddings or not.
        interact_pred_num: The number of the most relevant contexts and responses
            which the model returns in the `interact` regime.
        **kwargs: Other parameters.
    """

    def __init__(self,
                 vocab_name,
                 hard_triplets_sampling: bool = False,
                 hardest_positives: bool = False,
                 semi_hard_negatives: bool = False,
                 num_hardest_negatives: int = None,
                 update_embeddings: bool = False,
                 interact_pred_num: int = 3,
                 **kwargs):

        # Parameters for parent classes
        save_path = kwargs.get('save_path', None)
        load_path = kwargs.get('load_path', None)
        train_now = kwargs.get('train_now', None)
        mode = kwargs.get('mode', None)

        super().__init__(save_path=save_path, load_path=load_path,
                         train_now=train_now, mode=mode)

        self.hard_triplets_sampling = hard_triplets_sampling
        self.hardest_positives = hardest_positives
        self.semi_hard_negatives = semi_hard_negatives
        self.num_hardest_negatives = num_hardest_negatives
        self.upd_embs = update_embeddings
        self.interact_pred_num = interact_pred_num
        self.train_now = train_now
        self.vocab_name = vocab_name

        opt = deepcopy(kwargs)

        if self.vocab_name == "insurance":
            dict_parameter_names = list(inspect.signature(InsuranceDict.__init__).parameters)
            dict_parameters = {par: opt[par] for par in dict_parameter_names if par in opt}
            self.dict = InsuranceDict(**dict_parameters, update_embeddings=update_embeddings)

        embdict_parameter_names = list(inspect.signature(EmbDict.__init__).parameters)
        embdict_parameters = {par: opt[par] for par in embdict_parameter_names if par in opt}
        self.embdict = EmbDict(**embdict_parameters)

        network_parameter_names = list(inspect.signature(RankingNetwork.__init__).parameters)
        self.network_parameters = {par: opt[par] for par in network_parameter_names if par in opt}

        self.load()

        train_parameters_names = list(inspect.signature(self._net.train_on_batch).parameters)
        self.train_parameters = {par: opt[par] for par in train_parameters_names if par in opt}

    @overrides
    def load(self):
        """Load the model from the last checkpoint."""
        if not self.load_path.exists():
            log.info("[initializing new `{}`]".format(self.__class__.__name__))
            self.dict.init_from_scratch()
            self.embdict.init_from_scratch(self.dict.tok2int_vocab)
            if hasattr(self.dict, 'char2int_vocab'):
                chars_num = len(self.dict.char2int_vocab)
            else:
                chars_num = 0
            self._net = RankingNetwork(chars_num=chars_num,
                                       toks_num=len(self.dict.tok2int_vocab),
                                       emb_dict=self.embdict,
                                       **self.network_parameters)
            self._net.init_from_scratch(self.embdict.emb_matrix)
        else:
            log.info("[initializing `{}` from saved]".format(self.__class__.__name__))
            self.dict.load()
            self.embdict.load()
            if hasattr(self.dict, 'char2int_vocab'):
                chars_num = len(self.dict.char2int_vocab)
            else:
                chars_num = 0
            self._net = RankingNetwork(chars_num=chars_num,
                                       toks_num=len(self.dict.tok2int_vocab),
                                       emb_dict=self.embdict,
                                       **self.network_parameters)
            self._net.load(self.load_path)

    @overrides
    def save(self):
        """Save the model."""
        log.info('[saving model to {}]'.format(self.save_path.resolve()))
        self._net.save(self.save_path)
        if self.upd_embs:
            self.set_embeddings()
        self.dict.save()
        self.embdict.save()

    @check_attr_true('train_now')
    def train_on_batch(self, x: List[List[Tuple[int, int]]], y: List[int]):
        """Train the model on a batch."""
        if self.upd_embs:
            self.reset_embeddings()
        if self.hard_triplets_sampling:
            b = self.make_hard_triplets(x, y, self._net)
            y = np.ones(len(b[0][0]))
        else:
            b = self.make_batch(x)
            for i in range(len(x[0])):
                c = self.dict.make_toks(b[i][0], type="context")
                c = self.dict.make_ints(c)
                b[i][0] = c
                r = self.dict.make_toks(b[i][1], type="response")
                r = self.dict.make_ints(r)
                b[i][1] = r
        self._net.train_on_batch(b, y)

    def make_batch(self, x):
        sample_len = len(x[0])
        b = []
        for i in range(sample_len):
            c = []
            r = []
            for el in x:
                c.append(el[i][0])
                r.append(el[i][1])
            b.append([c, r])
        return b

    def make_hard_triplets(self, x, y, net):
        samples = [[s[1] for s in el] for el in x]
        labels = y
        batch_size = len(samples)
        num_samples = len(samples[0])
        samp = [y for el in samples for y in el]
        s = self.dict.make_toks(samp, type="context")
        s = self.dict.make_ints(s)

        embeddings = net.predict_embedding([s, s], 512, type='context')
        embeddings = embeddings / np.expand_dims(np.linalg.norm(embeddings, axis=1), axis=1)
        dot_product = embeddings @ embeddings.T
        square_norm = np.diag(dot_product)
        distances = np.expand_dims(square_norm, 0) - 2.0 * dot_product + np.expand_dims(square_norm, 1)
        distances = np.maximum(distances, 0.0)
        distances = np.sqrt(distances)

        mask_anchor_negative = np.expand_dims(np.repeat(labels, num_samples), 0)\
                               != np.expand_dims(np.repeat(labels, num_samples), 1)
        mask_anchor_negative = mask_anchor_negative.astype(float)
        max_anchor_negative_dist = np.max(distances, axis=1, keepdims=True)
        anchor_negative_dist = distances + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
        if self.num_hardest_negatives is not None:
            hard = np.argsort(anchor_negative_dist, axis=1)[:, :self.num_hardest_negatives]
            ind = np.random.randint(self.num_hardest_negatives, size=batch_size * num_samples)
            hardest_negative_ind = hard[batch_size * num_samples * [True], ind]
        else:
            hardest_negative_ind = np.argmin(anchor_negative_dist, axis=1)

        mask_anchor_positive = np.expand_dims(np.repeat(labels, num_samples), 0) \
                               == np.expand_dims(np.repeat(labels, num_samples), 1)
        mask_anchor_positive = mask_anchor_positive.astype(float)
        anchor_positive_dist = mask_anchor_positive * distances

        c = []
        rp = []
        rn = []
        hrds = []

        if self.hardest_positives:

            if self.semi_hard_negatives:
                hardest_positive_ind = []
                hardest_negative_ind = []
                for p, n in zip(anchor_positive_dist, anchor_negative_dist):
                    no_samples = True
                    p_li = list(zip(p, np.arange(batch_size * num_samples), batch_size * num_samples * [True]))
                    n_li = list(zip(n, np.arange(batch_size * num_samples), batch_size * num_samples * [False]))
                    pn_li = sorted(p_li + n_li, key=lambda el: el[0])
                    for i, x in enumerate(pn_li):
                        if not x[2]:
                            for y in pn_li[:i][::-1]:
                                if y[2] and y[0] > 0.0:
                                    assert (x[1] != y[1])
                                    hardest_negative_ind.append(x[1])
                                    hardest_positive_ind.append(y[1])
                                    no_samples = False
                                    break
                        if not no_samples:
                            break
                    if no_samples:
                        log.error("There are no negative examples with distances"
                                  " greater than positive examples distances.")
                        exit(0)
            else:
                if self.num_hardest_negatives is not None:
                    hard = np.argsort(anchor_positive_dist, axis=1)[:, -self.num_hardest_negatives:]
                    ind = np.random.randint(self.num_hardest_negatives, size=batch_size * num_samples)
                    hardest_positive_ind = hard[batch_size * num_samples * [True], ind]
                else:
                    hardest_positive_ind = np.argmax(anchor_positive_dist, axis=1)

            for i in range(batch_size):
                for j in range(num_samples):
                    c.append(s[i*num_samples+j])
                    rp.append(s[hardest_positive_ind[i*num_samples+j]])
                    rn.append(s[hardest_negative_ind[i*num_samples+j]])

        else:
            if self.semi_hard_negatives:
                for i in range(batch_size):
                    for j in range(num_samples):
                        for k in range(j+1, num_samples):
                            c.append(s[i*num_samples+j])
                            c.append(s[i*num_samples+k])
                            rp.append(s[i*num_samples+k])
                            rp.append(s[i*num_samples+j])
                            n, hrd = self.get_semi_hard_negative_ind(i, j, k, distances,
                                                                anchor_negative_dist,
                                                                batch_size, num_samples)
                            assert(n != i*num_samples+k)
                            rn.append(s[n])
                            hrds.append(hrd)
                            n, hrd = self.get_semi_hard_negative_ind(i, k, j, distances,
                                                                anchor_negative_dist,
                                                                batch_size, num_samples)
                            assert(n != i*num_samples+j)
                            rn.append(s[n])
                            hrds.append(hrd)
            else:
                for i in range(batch_size):
                    for j in range(num_samples):
                        for k in range(j + 1, num_samples):
                            c.append(s[i * num_samples + j])
                            c.append(s[i * num_samples + k])
                            rp.append(s[i * num_samples + k])
                            rp.append(s[i * num_samples + j])
                            rn.append(s[hardest_negative_ind[i * num_samples + j]])
                            rn.append(s[hardest_negative_ind[i * num_samples + k]])

        triplets = list(zip(c, rp, rn))
        np.random.shuffle(triplets)
        c = [el[0] for el in triplets]
        rp = [el[1] for el in triplets]
        rn = [el[2] for el in triplets]
        if hrds:  # hrds is only populated when semi-hard negatives are mined
            ratio = sum(hrds) / len(hrds)
            log.info("Ratio of semi-hard negative samples is %f" % ratio)
        return [(c, rp), (c, rn)]

    def get_semi_hard_negative_ind(self, i, j, k, distances, anchor_negative_dist, batch_size, num_samples):
        anc_pos_dist = distances[i * num_samples + j, i * num_samples + k]
        neg_dists = anchor_negative_dist[i * num_samples + j]
        n_li_pre = sorted(list(zip(neg_dists, np.arange(batch_size * num_samples))), key=lambda el: el[0])
        n_li = list(filter(lambda x: x[1]<i*num_samples, n_li_pre)) + \
               list(filter(lambda x: x[1]>=(i+1)*num_samples, n_li_pre))
        for x in n_li:
            if x[0] > anc_pos_dist:
                return x[1], True
        return random.choice(n_li)[1], False

    def __call__(self, batch: Union[List[List[Tuple[int, int]]], List[str]]) ->\
            Union[np.ndarray, Dict[str, List[str]]]:
        """Make a prediction on a batch."""
        if type(batch[0]) == list:
            y_pred = []
            b = self.make_batch(batch)
            for el in b:
                c = self.dict.make_toks(el[0], type="context")
                c = self.dict.make_ints(c)
                r = self.dict.make_toks(el[1], type="response")
                r = self.dict.make_ints(r)
                yp = self._net.predict_score_on_batch([c, r])
                y_pred.append(yp)
            y_pred = np.hstack(y_pred)
            return y_pred

        elif type(batch[0]) == str:
            c_input = tokenize(batch)
            c_input = self.dict.make_ints(c_input)
            c_input_emb = self._net.predict_embedding_on_batch([c_input, c_input], type='context')

            c_emb = [self.dict.context2emb_vocab[i] for i in range(len(self.dict.context2emb_vocab))]
            c_emb = np.vstack(c_emb)
            pred_cont = np.sum(c_input_emb * c_emb, axis=1)\
                     / np.linalg.norm(c_input_emb, axis=1) / np.linalg.norm(c_emb, axis=1)
            pred_cont = np.flip(np.argsort(pred_cont), 0)[:self.interact_pred_num]
            pred_cont = [' '.join(self.dict.context2toks_vocab[el]) for el in pred_cont]

            r_emb = [self.dict.response2emb_vocab[i] for i in range(len(self.dict.response2emb_vocab))]
            r_emb = np.vstack(r_emb)
            pred_resp = np.sum(c_input_emb * r_emb, axis=1)\
                     / np.linalg.norm(c_input_emb, axis=1) / np.linalg.norm(r_emb, axis=1)
            pred_resp = np.flip(np.argsort(pred_resp), 0)[:self.interact_pred_num]
            pred_resp = [' '.join(self.dict.response2toks_vocab[el]) for el in pred_resp]
            y_pred = [{"contexts": pred_cont, "responses": pred_resp}]
            return y_pred

    def update_embeddings(self, batch):
        sample_len = len(batch[0])
        labels_cont = []
        labels_resp = []
        cont = []
        resp = []
        for i in range(sample_len):
            lc = []
            lr = []
            for el in batch:
                lc.append(el[i][0])
                lr.append(el[i][1])
            labels_cont.append(lc)
            labels_resp.append(lr)
        for i in range(sample_len):
            c = self.dict.make_toks(labels_cont[i], type="context")
            c = self.dict.make_ints(c)
            cont.append(c)
            r = self.dict.make_toks(labels_resp[i], type="response")
            r = self.dict.make_ints(r)
            resp.append(r)
        for el in zip(labels_cont, cont):
            c_emb = self._net.predict_embedding_on_batch([el[1], el[1]], type='context')
            for i in range(len(el[0])):
                self.dict.context2emb_vocab[el[0][i]] = c_emb[i]
        for el in zip(labels_resp, resp):
            r_emb = self._net.predict_embedding_on_batch([el[1], el[1]], type='response')
            for i in range(len(el[0])):
                self.dict.response2emb_vocab[el[0][i]] = r_emb[i]

    def set_embeddings(self):
        if self.dict.response2emb_vocab[0] is None:
            r = []
            for i in range(len(self.dict.response2toks_vocab)):
                r.append(self.dict.response2toks_vocab[i])
            r = self.dict.make_ints(r)
            response_embeddings = self._net.predict_embedding([r, r], 512, type='response')
            for i in range(len(self.dict.response2toks_vocab)):
                self.dict.response2emb_vocab[i] = response_embeddings[i]
        if self.dict.context2emb_vocab[0] is None:
            c = []
            for i in range(len(self.dict.context2toks_vocab)):
                c.append(self.dict.context2toks_vocab[i])
            c = self.dict.make_ints(c)
            context_embeddings = self._net.predict_embedding([c, c], 512, type='context')
            for i in range(len(self.dict.context2toks_vocab)):
                self.dict.context2emb_vocab[i] = context_embeddings[i]

    def reset_embeddings(self):
        if self.dict.response2emb_vocab[0] is not None:
            for i in range(len(self.dict.response2emb_vocab)):
                self.dict.response2emb_vocab[i] = None
        if self.dict.context2emb_vocab[0] is not None:
            for i in range(len(self.dict.context2emb_vocab)):
                self.dict.context2emb_vocab[i] = None

    def shutdown(self):
        pass

    def reset(self):
        pass
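
In the interact branch of __call__, the stored context and response embeddings are ranked against the query by cosine similarity and only the top interact_pred_num are kept. The same scoring step as a self-contained NumPy sketch with toy vectors (illustrative only):

import numpy as np

query_emb = np.array([[0.2, 0.9, 0.1]])             # one query embedding, shape (1, d)
candidates = np.array([[0.1, 1.0, 0.0],             # stacked candidate embeddings, shape (n, d)
                       [0.9, 0.1, 0.3],
                       [0.3, 0.8, 0.2]])
top_k = 2

# Cosine similarity: dot product divided by both norms, as in the model's __call__.
scores = np.sum(query_emb * candidates, axis=1) \
         / np.linalg.norm(query_emb, axis=1) / np.linalg.norm(candidates, axis=1)
best = np.flip(np.argsort(scores), 0)[:top_k]       # indices of the top_k most similar candidates
print(best, scores[best])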