Ejemplo n.º 1
0
    def get_line(self, cut_lines=True, line_break="\n"):
        """
        Yield decoded text read from the child process stdout.

        :param cut_lines: if True, split the decompressed buffer into lines
        :type cut_lines: bool
        :param line_break: line break of the file, like '\\n' or '\\r'
        :type line_break: string

        :return: one line or a buffer of bytes
        :rtype: string
        """
        leftover = ""
        while True:
            raw = self.process.stdout.read(self.bufsize)
            if not raw:
                # EOF on the subprocess pipe: stop the generator.
                break
            if self.file_type == "gzip":
                text = cpt.to_text(self.dec.decompress(raw))
            elif self.file_type == "plain":
                text = cpt.to_text(raw)
            else:
                raise TypeError("file_type %s is not allowed" %
                                self.file_type)
            if not cut_lines:
                yield text
                continue
            # Prepend the partial line left over from the previous chunk.
            lines, leftover = _buf2lines(leftover + text, line_break)
            for one_line in lines:
                yield one_line
Ejemplo n.º 2
0
def __load_dict(tar_file, dict_size, lang, reverse=False):
    """
    Load the `lang` dictionary of `dict_size` entries, building it from
    `tar_file` first when the cached file is missing or has the wrong
    number of entries.

    Returns a dict mapping word -> index, or index -> word when `reverse`
    is True.
    """
    dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
                             "wmt16/%s_%d.dict" % (lang, dict_size))
    # Check the cached dict with a context manager; the previous
    # inline `open(...).readlines()` leaked the file handle.
    dict_found = False
    if os.path.exists(dict_path):
        with open(dict_path, "rb") as existing:
            dict_found = len(existing.readlines()) == dict_size
    if not dict_found:
        __build_dict(tar_file, dict_size, dict_path, lang)

    word_dict = {}
    with open(dict_path, "rb") as fdict:
        for idx, line in enumerate(fdict):
            if reverse:
                word_dict[idx] = cpt.to_text(line.strip())
            else:
                word_dict[cpt.to_text(line.strip())] = idx
    return word_dict
Ejemplo n.º 3
0
    def reader():
        """Yield (src_ids, trg_ids, trg_ids_next) id lists parsed from the
        tab-separated sentence pairs inside the tar archive."""
        src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
        with tarfile.open(tar_file, mode='r') as f:
            names = [m.name for m in f if m.name.endswith(file_name)]
            for name in names:
                for raw_line in f.extractfile(name):
                    fields = cpt.to_text(raw_line).strip().split('\t')
                    if len(fields) != 2:
                        continue
                    src_words = fields[0].split()  # one source sequence
                    trg_words = fields[1].split()  # one target sequence
                    src_ids = [src_dict.get(w, UNK_IDX)
                               for w in [START] + src_words + [END]]
                    trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]
                    # remove sequence whose length > 80 in training mode
                    if len(src_ids) > 80 or len(trg_ids) > 80:
                        continue
                    trg_ids_next = trg_ids + [trg_dict[END]]
                    yield src_ids, [trg_dict[START]] + trg_ids, trg_ids_next
Ejemplo n.º 4
0
    def _load_data(self):
        """Populate self.src_dict/self.trg_dict from the archived dict files
        and self.src_ids/self.trg_ids/self.trg_ids_next from the sentence
        pairs of the current mode's data file."""

        def __to_dict(fd, size):
            # Map each of the first `size` stripped lines to its line index.
            vocab = dict()
            for idx, raw in enumerate(fd):
                if idx >= size:
                    break
                vocab[cpt.to_text(raw.strip())] = idx
            return vocab

        self.src_ids = []
        self.trg_ids = []
        self.trg_ids_next = []
        with tarfile.open(self.data_file, mode='r') as tar:
            src_names = [m.name for m in tar if m.name.endswith("src.dict")]
            assert len(src_names) == 1
            self.src_dict = __to_dict(tar.extractfile(src_names[0]),
                                      self.dict_size)

            trg_names = [m.name for m in tar if m.name.endswith("trg.dict")]
            assert len(trg_names) == 1
            self.trg_dict = __to_dict(tar.extractfile(trg_names[0]),
                                      self.dict_size)

            data_suffix = "{}/{}".format(self.mode, self.mode)
            data_names = [m.name for m in tar
                          if m.name.endswith(data_suffix)]
            for name in data_names:
                for raw in tar.extractfile(name):
                    fields = cpt.to_text(raw).strip().split('\t')
                    if len(fields) != 2:
                        continue
                    src_words = fields[0].split()  # one source sequence
                    src_ids = [self.src_dict.get(w, UNK_IDX)
                               for w in [START] + src_words + [END]]

                    trg_words = fields[1].split()  # one target sequence
                    trg_ids = [self.trg_dict.get(w, UNK_IDX)
                               for w in trg_words]

                    # remove sequence whose length > 80 in training mode
                    if len(src_ids) > 80 or len(trg_ids) > 80:
                        continue

                    self.src_ids.append(src_ids)
                    self.trg_ids.append([self.trg_dict[START]] + trg_ids)
                    self.trg_ids_next.append(trg_ids + [self.trg_dict[END]])
Ejemplo n.º 5
0
def save_persistable_nodes(executor, dirname, graph):
    """
    Save persistable nodes to the given directory by the executor.

    Args:
        executor(Executor): The executor to run for saving node values.
        dirname(str): The directory path.
        graph(IrGraph): All the required persistable nodes in the graph will be saved.
    """
    # Deduplicate the graph's persistable nodes by (text) name, keeping
    # the first occurrence of each.
    seen_names = set()
    unique_nodes = []
    for node in graph.all_persistable_nodes():
        node_name = cpt.to_text(node.name())
        if node_name in seen_names:
            continue
        seen_names.add(node_name)
        unique_nodes.append(node)

    program = Program()
    var_list = []
    for node in unique_nodes:
        desc = node.var()
        # RAW/READER vars carry no savable tensor data.
        if desc.type() in (core.VarDesc.VarType.RAW,
                           core.VarDesc.VarType.READER):
            continue
        var_list.append(program.global_block().create_var(
            name=desc.name(),
            shape=desc.shape(),
            dtype=desc.dtype(),
            type=desc.type(),
            lod_level=desc.lod_level(),
            persistable=desc.persistable()))
    fluid.io.save_vars(executor=executor, dirname=dirname, vars=var_list)
Ejemplo n.º 6
0
    def _load_data(self):
        """Parse the wmt16 data file of the current mode into
        self.src_ids / self.trg_ids / self.trg_ids_next."""
        # The index for start mark, end mark, and unk are the same in the
        # source and target dictionaries, so look them up once in the
        # source dictionary.
        start_id = self.src_dict[START_MARK]
        end_id = self.src_dict[END_MARK]
        unk_id = self.src_dict[UNK_MARK]

        # Column layout in the tab-separated file depends on the language.
        src_col, trg_col = (0, 1) if self.lang == "en" else (1, 0)

        self.src_ids = []
        self.trg_ids = []
        self.trg_ids_next = []
        with tarfile.open(self.data_file, mode="r") as f:
            for raw in f.extractfile("wmt16/{}".format(self.mode)):
                fields = cpt.to_text(raw).strip().split("\t")
                if len(fields) != 2:
                    continue
                src_ids = ([start_id] +
                           [self.src_dict.get(w, unk_id)
                            for w in fields[src_col].split()] +
                           [end_id])
                trg_ids = [self.trg_dict.get(w, unk_id)
                           for w in fields[trg_col].split()]

                self.src_ids.append(src_ids)
                self.trg_ids.append([start_id] + trg_ids)
                self.trg_ids_next.append(trg_ids + [end_id])
Ejemplo n.º 7
0
    def reader():
        """Yield (src_ids, trg_ids, trg_ids_next) id lists from the wmt16
        tar file, loading/building the vocabularies on first use."""
        src_dict = __load_dict(tar_file, src_dict_size, src_lang)
        trg_lang = "de" if src_lang == "en" else "en"
        trg_dict = __load_dict(tar_file, trg_dict_size, trg_lang)

        # The index for start mark, end mark, and unk are the same in the
        # source and target dictionaries, so look them up once in the
        # source dictionary.
        start_id = src_dict[START_MARK]
        end_id = src_dict[END_MARK]
        unk_id = src_dict[UNK_MARK]

        # Column layout in the tab-separated file depends on the language.
        src_col, trg_col = (0, 1) if src_lang == "en" else (1, 0)

        with tarfile.open(tar_file, mode="r") as f:
            for raw in f.extractfile(file_name):
                fields = cpt.to_text(raw).strip().split("\t")
                if len(fields) != 2:
                    continue
                src_ids = ([start_id] +
                           [src_dict.get(w, unk_id)
                            for w in fields[src_col].split()] +
                           [end_id])
                trg_ids = [trg_dict.get(w, unk_id)
                           for w in fields[trg_col].split()]

                yield src_ids, [start_id] + trg_ids, trg_ids + [end_id]
Ejemplo n.º 8
0
 def __to_dict(fd, size):
     """Map the first `size` stripped lines of `fd` to their line index."""
     out = dict()
     for idx, raw in enumerate(fd):
         if idx >= size:
             break
         out[cpt.to_text(raw.strip())] = idx
     return out
Ejemplo n.º 9
0
def __initialize_meta_info__():
    """
    Download the MovieLens-1M archive (if not cached) and populate the
    module-level MOVIE_INFO, MOVIE_TITLE_DICT, CATEGORIES_DICT and
    USER_INFO tables on first call.

    Returns:
        str: local path of the downloaded archive.
    """
    fn = paddle.dataset.common.download(URL, "movielens", MD5)
    global MOVIE_INFO
    if MOVIE_INFO is None:
        pattern = re.compile(r'^(.*)\((\d+)\)$')
        with zipfile.ZipFile(file=fn) as package:
            # Sanity-check the archive entries up front.  This check used
            # to wrap the whole parse below, re-reading both data files
            # once per archive member with identical results.
            for info in package.infolist():
                assert isinstance(info, zipfile.ZipInfo)

            MOVIE_INFO = dict()
            title_word_set = set()
            categories_set = set()
            with package.open('ml-1m/movies.dat') as movie_file:
                for line in movie_file:
                    line = cpt.to_text(line, encoding='latin')
                    movie_id, title, categories = line.strip().split('::')
                    categories = categories.split('|')
                    for c in categories:
                        categories_set.add(c)
                    # Strip the trailing "(year)" suffix from the title.
                    title = pattern.match(title).group(1)
                    MOVIE_INFO[int(movie_id)] = MovieInfo(
                        index=movie_id, categories=categories, title=title)
                    for w in title.split():
                        title_word_set.add(w.lower())

            global MOVIE_TITLE_DICT
            MOVIE_TITLE_DICT = dict()
            for i, w in enumerate(title_word_set):
                MOVIE_TITLE_DICT[w] = i

            global CATEGORIES_DICT
            CATEGORIES_DICT = dict()
            for i, c in enumerate(categories_set):
                CATEGORIES_DICT[c] = i

            global USER_INFO
            USER_INFO = dict()
            with package.open('ml-1m/users.dat') as user_file:
                for line in user_file:
                    line = cpt.to_text(line, encoding='latin')
                    uid, gender, age, job, _ = line.strip().split("::")
                    USER_INFO[int(uid)] = UserInfo(index=uid,
                                                   gender=gender,
                                                   age=age,
                                                   job_id=job)
    return fn
Ejemplo n.º 10
0
    def _load_dict(self, lang, dict_size, reverse=False):
        """Load the `lang` dictionary of `dict_size` entries, rebuilding it
        first if the cached file is missing or has the wrong entry count.

        Returns word -> index, or index -> word when `reverse` is True."""
        dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
                                 "wmt16/%s_%d.dict" % (lang, dict_size))
        have_dict = False
        if os.path.exists(dict_path):
            with open(dict_path, "rb") as existing:
                have_dict = len(existing.readlines()) == dict_size
        if not have_dict:
            self._build_dict(dict_path, dict_size, lang)

        word_dict = {}
        with open(dict_path, "rb") as fdict:
            for idx, raw in enumerate(fdict):
                token = cpt.to_text(raw.strip())
                if reverse:
                    word_dict[idx] = token
                else:
                    word_dict[token] = idx
        return word_dict
Ejemplo n.º 11
0
def _append_loaded_suffix(name):
    """
    Append loaded suffix to the given variable name
    e.g. x ==> x.load_0, x.load_0 ==> x.load_0.load_0
    """
    base = cpt.to_text(name)
    return unique_name.generate_with_ignorable_key(
        '.'.join((base, LOADED_VAR_SUFFIX)))
Ejemplo n.º 12
0
    def _load_meta_info(self):
        """
        Parse movies.dat and users.dat from the ml-1m archive and populate
        self.movie_info, self.movie_title_dict, self.categories_dict and
        self.user_info.
        """
        pattern = re.compile(r'^(.*)\((\d+)\)$')
        self.movie_info = dict()
        self.movie_title_dict = dict()
        self.categories_dict = dict()
        self.user_info = dict()
        with zipfile.ZipFile(self.data_file) as package:
            # Sanity-check the archive entries up front.  This check used
            # to wrap the whole parse below, re-reading both data files
            # once per archive member with identical results.
            for info in package.infolist():
                assert isinstance(info, zipfile.ZipInfo)

            title_word_set = set()
            categories_set = set()
            with package.open('ml-1m/movies.dat') as movie_file:
                for line in movie_file:
                    line = cpt.to_text(line, encoding='latin')
                    movie_id, title, categories = line.strip().split('::')
                    categories = categories.split('|')
                    for c in categories:
                        categories_set.add(c)
                    # Strip the trailing "(year)" suffix from the title.
                    title = pattern.match(title).group(1)
                    self.movie_info[int(movie_id)] = MovieInfo(
                        index=movie_id, categories=categories, title=title)
                    for w in title.split():
                        title_word_set.add(w.lower())

            for i, w in enumerate(title_word_set):
                self.movie_title_dict[w] = i

            for i, c in enumerate(categories_set):
                self.categories_dict[c] = i

            with package.open('ml-1m/users.dat') as user_file:
                for line in user_file:
                    line = cpt.to_text(line, encoding='latin')
                    uid, gender, age, job, _ = line.strip().split("::")
                    self.user_info[int(uid)] = UserInfo(index=uid,
                                                        gender=gender,
                                                        age=age,
                                                        job_id=job)
Ejemplo n.º 13
0
def __reader__(rand_seed=0, test_ratio=0.1, is_test=False):
    """Yield user-features + movie-features + [[scaled rating]] samples
    from ml-1m ratings, randomly split into train/test by test_ratio."""
    fn = __initialize_meta_info__()
    np.random.seed(rand_seed)
    with zipfile.ZipFile(file=fn) as package:
        with package.open('ml-1m/ratings.dat') as ratings_file:
            for raw in ratings_file:
                text = cpt.to_text(raw, encoding='latin')
                # One draw per line; keep the sample only when the draw
                # falls on the requested side of the split.
                if (np.random.random() < test_ratio) != is_test:
                    continue
                uid, mov_id, score, _ = text.strip().split("::")
                mov = MOVIE_INFO[int(mov_id)]
                usr = USER_INFO[int(uid)]
                # Rescale the 1..5 rating to the [-3, 5] range.
                yield usr.value() + mov.value() + [[float(score) * 2 - 5.0]]
Ejemplo n.º 14
0
def load_persistable_nodes(executor, dirname, graph):
    """
    Load persistable node values from the given directory by the executor.

    Args:
        executor(Executor): The executor to run for loading node values.
        dirname(str): The directory path.
        graph(IrGraph): All the required persistable nodes in the graph will be loaded.
    """
    # Deduplicate the graph's persistable nodes by (text) name, keeping
    # the first occurrence of each.
    persistable_node_names = set()
    persistable_nodes = []
    all_persistable_nodes = graph.all_persistable_nodes()
    for node in all_persistable_nodes:
        name = cpt.to_text(node.name())
        if name not in persistable_node_names:
            persistable_node_names.add(name)
            persistable_nodes.append(node)
    program = Program()
    var_list = []

    def _exist(var):
        # A var can only be loaded when its file exists under `dirname`.
        return os.path.exists(os.path.join(dirname, var.name))

    # NOTE: two nested helpers (_load_var/_store_var) were defined here
    # but never called; they have been removed as dead code.

    for node in persistable_nodes:
        var_desc = node.var()
        # RAW/READER vars carry no loadable tensor data.
        if var_desc.type() == core.VarDesc.VarType.RAW or \
                var_desc.type() == core.VarDesc.VarType.READER:
            continue
        var = program.global_block().create_var(
            name=var_desc.name(),
            shape=var_desc.shape(),
            dtype=var_desc.dtype(),
            type=var_desc.type(),
            lod_level=var_desc.lod_level(),
            persistable=var_desc.persistable())
        if _exist(var):
            var_list.append(var)
        else:
            # Lazy %-args: the message is only formatted if emitted.
            _logger.info("Cannot find the var %s!!!", node.name())
    fluid.io.load_vars(executor=executor, dirname=dirname, vars=var_list)
Ejemplo n.º 15
0
    def _load_data(self):
        """Read ml-1m ratings and collect user+movie feature samples for
        the current mode (train/test split driven by self.test_ratio)."""
        self.data = []
        is_test = self.mode == 'test'
        with zipfile.ZipFile(self.data_file) as package:
            # Renamed the file handle so it is no longer shadowed by the
            # per-line rating value below.
            with package.open('ml-1m/ratings.dat') as ratings_file:
                for raw in ratings_file:
                    text = cpt.to_text(raw, encoding='latin')
                    # One draw per line; keep the sample only when the
                    # draw falls on the requested side of the split.
                    if (np.random.random() < self.test_ratio) != is_test:
                        continue
                    uid, mov_id, score, _ = text.strip().split("::")
                    mov = self.movie_info[int(mov_id)]
                    usr = self.user_info[int(uid)]
                    # Rescale the 1..5 rating to the [-3, 5] range.
                    sample = (usr.value() +
                              mov.value(self.categories_dict,
                                        self.movie_title_dict) +
                              [[float(score) * 2 - 5.0]])
                    self.data.append(sample)
Ejemplo n.º 16
0
 def reader():
     """
     Yield (sample, label) pairs read from the pickled batch files listed
     in `file_list`; labels are shifted from 1-based to 0-based.

     Loops over the file list forever when `cycle` is true.
     """
     while True:
         # Read the list file under a context manager (it used to leak
         # a handle per pass); `path` also avoids shadowing the `file`
         # builtin.
         with open(file_list) as list_f:
             batch_paths = [p.strip() for p in list_f]
         for path in batch_paths:
             # NOTE: pickle.load must only be used on trusted files.
             with open(path, 'rb') as f:
                 if six.PY2:
                     batch = pickle.load(f)
                 else:
                     batch = pickle.load(f, encoding='bytes')
             if six.PY3:
                 batch = cpt.to_text(batch)
             data = batch['data']
             labels = batch['label']
             # `labels` was previously computed but unused.
             for sample, label in six.moves.zip(data, labels):
                 yield sample, int(label) - 1
         if not cycle:
             break
Ejemplo n.º 17
0
def __build_dict(tar_file, dict_size, save_path, lang):
    """
    Build a word-frequency dictionary for `lang` from the wmt16 train
    split and save the three special marks plus the most frequent words
    (up to `dict_size` total) to `save_path`, one word per line.
    """
    word_dict = defaultdict(int)
    with tarfile.open(tar_file, mode="r") as f:
        for line in f.extractfile("wmt16/train"):
            line = cpt.to_text(line)
            line_split = line.strip().split("\t")
            if len(line_split) != 2: continue
            # Column 0 is English, column 1 the other language.
            sen = line_split[0] if lang == "en" else line_split[1]
            for w in sen.split():
                word_dict[w] += 1

    with open(save_path, "w") as fout:
        fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
        for idx, word in enumerate(
                sorted(six.iteritems(word_dict),
                       key=lambda x: x[1],
                       reverse=True)):
            if idx + 3 == dict_size: break
            # BUG FIX: the file is opened in text mode, but the old code
            # interpolated cpt.to_bytes(word[0]) into the format string,
            # which under Python 3 writes the bytes repr ("b'word'")
            # instead of the word.  Write the text word directly.
            fout.write("%s\n" % word[0])
Ejemplo n.º 18
0
 def to_string(self, throw_on_error, with_details=False):
     """
     To debug string.
     :param throw_on_error:
     :param with_details:
     :return:
     """
     assert isinstance(throw_on_error, bool) and isinstance(
         with_details, bool)
     # with_details is a bool, so it can be forwarded directly instead of
     # branching on it twice.
     res_str = MpcVariable.to_string(self, throw_on_error, with_details)
     if with_details:
         for attr_name in ("trainable", "optimize_attr", "regularizer",
                           "gradient_clip_attr", "do_model_average"):
             res_str += "%s: %s\n" % (attr_name,
                                      cpt.to_text(getattr(self, attr_name)))
     return res_str
Ejemplo n.º 19
0
    def test_to_text(self):
        """Exercise cpt.to_text across scalar and container inputs.

        Covers None passthrough; str/bytes/unicode scalars; and list, set
        and dict containers in both copy (inplace=False, a new object is
        returned) and in-place (inplace=True, the same object is mutated
        and returned) modes.
        """
        # None is returned unchanged.
        self.assertIsNone(cpt.to_text(None))

        # Scalar conversions always produce `str`.
        self.assertTrue(isinstance(cpt.to_text(str("")), str))
        self.assertTrue(isinstance(cpt.to_text(str("123")), str))
        self.assertTrue(isinstance(cpt.to_text(b""), str))
        self.assertTrue(isinstance(cpt.to_text(b""), str))
        self.assertTrue(isinstance(cpt.to_text(u""), str))
        self.assertTrue(isinstance(cpt.to_text(u""), str))

        # Scalar conversions preserve the textual value.
        self.assertEqual("", cpt.to_text(str("")))
        self.assertEqual("123", cpt.to_text(str("123")))
        self.assertEqual("", cpt.to_text(b""))
        self.assertEqual("123", cpt.to_text(b"123"))
        self.assertEqual("", cpt.to_text(u""))
        self.assertEqual("123", cpt.to_text(u"123"))

        # check list types, not inplace
        l = [""]
        l2 = cpt.to_text(l)
        self.assertTrue(isinstance(l2, list))
        self.assertFalse(l is l2)
        self.assertEqual(l, l2)
        self.assertEqual([""], l2)
        l = ["", "123"]
        l2 = cpt.to_text(l)
        self.assertTrue(isinstance(l2, list))
        self.assertFalse(l is l2)
        self.assertEqual(l, l2)
        self.assertEqual(["", "123"], l2)
        l = ["", b"123", u"321"]
        l2 = cpt.to_text(l)
        self.assertTrue(isinstance(l2, list))
        self.assertFalse(l is l2)
        self.assertNotEqual(l, l2)
        self.assertEqual(["", "123", "321"], l2)

        # check list types, inplace
        l = [""]
        l2 = cpt.to_text(l, inplace=True)
        self.assertTrue(isinstance(l2, list))
        self.assertTrue(l is l2)
        self.assertEqual(l, l2)
        self.assertEqual([""], l2)
        l = ["", b"123"]
        l2 = cpt.to_text(l, inplace=True)
        self.assertTrue(isinstance(l2, list))
        self.assertTrue(l is l2)
        self.assertEqual(l, l2)
        self.assertEqual(["", "123"], l2)
        l = ["", b"123", u"321"]
        l2 = cpt.to_text(l, inplace=True)
        self.assertTrue(isinstance(l2, list))
        self.assertTrue(l is l2)
        self.assertEqual(l, l2)
        self.assertEqual(["", "123", "321"], l2)
        for i in l2:
            self.assertTrue(isinstance(i, str))

        # check set types, not inplace
        l = set("")
        l2 = cpt.to_text(l, inplace=False)
        self.assertTrue(isinstance(l2, set))
        self.assertFalse(l is l2)
        self.assertEqual(l, l2)
        self.assertEqual(set(""), l2)
        l = set([b"", b"123"])
        l2 = cpt.to_text(l, inplace=False)
        self.assertTrue(isinstance(l2, set))
        self.assertFalse(l is l2)
        self.assertNotEqual(l, l2)
        self.assertEqual(set(["", "123"]), l2)
        l = set(["", b"123", u"321"])
        l2 = cpt.to_text(l, inplace=False)
        self.assertTrue(isinstance(l2, set))
        self.assertFalse(l is l2)
        self.assertNotEqual(l, l2)
        self.assertEqual(set(["", "123", "321"]), l2)

        # check set types, inplace
        l = set("")
        l2 = cpt.to_text(l, inplace=True)
        self.assertTrue(isinstance(l2, set))
        self.assertTrue(l is l2)
        self.assertEqual(l, l2)
        self.assertEqual(set(""), l2)
        l = set([b"", b"123"])
        l2 = cpt.to_text(l, inplace=True)
        self.assertTrue(isinstance(l2, set))
        self.assertTrue(l is l2)
        self.assertEqual(l, l2)
        self.assertEqual(set(["", "123"]), l2)
        l = set(["", b"123", u"321"])
        l2 = cpt.to_text(l, inplace=True)
        self.assertTrue(isinstance(l2, set))
        self.assertTrue(l is l2)
        self.assertEqual(l, l2)
        self.assertEqual(set(["", "123", "321"]), l2)
        for i in l2:
            self.assertTrue(isinstance(i, str))

        # check dict types, not inplace
        l = {"": ""}
        l2 = cpt.to_text(l, inplace=False)
        self.assertTrue(isinstance(l2, dict))
        self.assertFalse(l is l2)
        self.assertEqual(l, l2)
        self.assertEqual({"": ""}, l2)

        # check dict types, inplace
        l = {"": ""}
        l2 = cpt.to_text(l, inplace=True)
        self.assertTrue(isinstance(l2, dict))
        self.assertTrue(l is l2)
        self.assertEqual(l, l2)
        self.assertEqual({"": ""}, l2)
Ejemplo n.º 20
0
    def reader():
        """Yield (sentence, verb, label_sequence) samples parsed from the
        gzipped words/props files inside the tar archive.

        Sentences are delimited by empty label lines.  For each sentence,
        every proposition column after the verb column yields one sample
        whose bracketed span tags are converted to B-/I-/O labels.
        """
        tf = tarfile.open(data_path)
        wf = tf.extractfile(words_name)
        pf = tf.extractfile(props_name)
        with gzip.GzipFile(fileobj=wf) as words_file, gzip.GzipFile(
                fileobj=pf) as props_file:
            sentences = []
            labels = []
            one_seg = []
            # words_file and props_file are read in lockstep: one word and
            # its row of per-proposition tags per line.
            for word, label in zip(words_file, props_file):
                word = cpt.to_text(word.strip())
                label = cpt.to_text(label.strip().split())

                if len(label) == 0:  # end of sentence
                    # Transpose per-word tag rows into per-column tag
                    # sequences (one column per proposition).
                    for i in range(len(one_seg[0])):
                        a_kind_lable = [x[i] for x in one_seg]
                        labels.append(a_kind_lable)

                    if len(labels) >= 1:
                        # Column 0 marks the predicates; '-' means the
                        # word is not a predicate.
                        verb_list = []
                        for x in labels[0]:
                            if x != '-':
                                verb_list.append(x)

                        # Each remaining column holds the bracketed span
                        # tags of one proposition.
                        for i, lbl in enumerate(labels[1:]):
                            cur_tag = 'O'
                            is_in_bracket = False
                            lbl_seq = []
                            verb_word = ''
                            for l in lbl:
                                # '(' opens an argument span, ')' closes
                                # it, '*' continues the current state.
                                if l == '*' and is_in_bracket == False:
                                    lbl_seq.append('O')
                                elif l == '*' and is_in_bracket == True:
                                    lbl_seq.append('I-' + cur_tag)
                                elif l == '*)':
                                    lbl_seq.append('I-' + cur_tag)
                                    is_in_bracket = False
                                elif l.find('(') != -1 and l.find(')') != -1:
                                    # Single-word span such as "(A0*)".
                                    cur_tag = l[1:l.find('*')]
                                    lbl_seq.append('B-' + cur_tag)
                                    is_in_bracket = False
                                elif l.find('(') != -1 and l.find(')') == -1:
                                    cur_tag = l[1:l.find('*')]
                                    lbl_seq.append('B-' + cur_tag)
                                    is_in_bracket = True
                                else:
                                    raise RuntimeError('Unexpected label: %s' %
                                                       l)

                            yield sentences, verb_list[i], lbl_seq

                    # Reset accumulators for the next sentence.
                    sentences = []
                    labels = []
                    one_seg = []
                else:
                    sentences.append(word)
                    one_seg.append(label)

        pf.close()
        wf.close()
        tf.close()
Ejemplo n.º 21
0
def append_backward(loss,
                    parameter_list=None,
                    no_grad_set=None,
                    callbacks=None,
                    checkpoints=None):
    """
    This function appends backward part to main_program.

    A complete neural network training is made up of forward and backward
    propagation. However, when we configure a network, we only need to
    specify its forward part. This function uses the chain rule to
    automatically generate the backward part according to the forward part.

    In most cases, users do not need to invoke this function manually.
    It will be automatically invoked by the optimizer's `minimize` function.

    Parameters:
        loss( :ref:`api_guide_Variable_en` ): The loss variable of the network.
        parameter_list(list of str, optional): Names of parameters that need
                                           to be updated by optimizers.
                                           If it is None, all trainable mpc
                                           parameters will be updated.
                                           Default: None.
        no_grad_set(set of str, optional): Variable names in the :ref:`api_guide_Block_en` 0 whose gradients
                               should be ignored. All variables with
                               `stop_gradient=True` from all blocks will
                               be automatically added into this set.
                               If this parameter is not None, the names in this set will be added to the default set.
                               Default: None.
        callbacks(list of callable object, optional): List of callback functions.
                                               The callbacks are used for
                                               doing some custom jobs during
                                               backward part building. All
                                               callable objects in it will
                                               be invoked once each time a
                                               new gradient operator is added
                                               into the program. The callable
                                               object must have two input
                                               parameters: 'block' and 'context'.
                                               The 'block' is the :ref:`api_guide_Block_en` which
                                               the new gradient operator will
                                               be added to. The 'context' is a
                                               map, whose keys are gradient
                                               variable names and values are
                                               corresponding original :ref:`api_guide_Variable_en` .
                                               In addition to this, the 'context'
                                               has another special key-value pair:
                                               the key is string '__current_op_desc__'
                                               and the value is the op_desc of the
                                               gradient operator who has just
                                               triggered the callable object.
                                               Default: None.
        checkpoints(list, optional): Reserved (currently unused by this
                                     implementation). Default: None.

    Returns:
        list of tuple ( :ref:`api_guide_Variable_en` , :ref:`api_guide_Variable_en` ): Pairs of parameter and its corresponding gradients.
        The key is the parameter and the value is gradient variable.

    Raises:
        AssertionError: If `loss` is not an instance of Variable, or if
            `callbacks` is given but is not a list.

    Examples:
        .. code-block:: python
            import paddle.fluid as fluid
            x = fluid.data(name='x', shape=[None, 13], dtype='float32')
            y = fluid.data(name='y', shape=[None, 1], dtype='float32')
            y_predict = fluid.layers.fc(input=x, size=1, act=None)
            loss = fluid.layers.square_error_cost(input=y_predict, label=y)
            avg_loss = fluid.layers.mean(loss)
            param_grad_list = fluid.backward.append_backward(loss=avg_loss)
    """

    assert isinstance(loss, framework.Variable)

    if loss.op is None:
        # The loss is from a cloned program. Find the loss op manually.
        backward._find_loss_op_(loss)

    # Mark the loss-producing op with both Forward and Loss roles so later
    # passes (optimizers, parallel executors) can locate it.
    loss.op._set_attr(
        core.op_proto_and_checker_maker.kOpRoleAttrName(),
        int(core.op_proto_and_checker_maker.OpRole.Forward)
        | int(core.op_proto_and_checker_maker.OpRole.Loss))

    if callbacks is not None:
        # Bug fix: the original code evaluated isinstance() and discarded the
        # result, so an invalid `callbacks` argument was never rejected.
        assert isinstance(callbacks, list), "callbacks must be a list"

    program = loss.block.program
    program._appending_grad_times += 1

    if no_grad_set is None:
        no_grad_set = set()
    # Copy so the caller's set is never mutated below.
    no_grad_set = copy.copy(no_grad_set)
    no_grad_dict = backward._get_stop_gradients_(program)
    # set.update accepts any iterable; no need to materialize a list.
    no_grad_dict[0].update(map(backward._append_grad_suffix_, no_grad_set))

    grad_info_map = dict()
    root_block = program.block(0)

    # Remember where the forward part ends so appended gradient ops/vars can
    # be renamed and collected later.
    fwd_op_num = root_block.desc.op_size()
    current_block_idx = program.current_block_idx
    grad_to_var = dict()

    # Append the op that initializes the gradient of the loss itself.
    op_desc = _create_loss_op_desc_(loss)
    root_block.desc.append_op().copy_from(op_desc)

    block_no_grad_set = set(map(backward._strip_grad_suffix_, no_grad_dict[0]))
    op_path = backward._find_op_path_(root_block, [loss], [],
                                      block_no_grad_set)
    # Variables on the op path that still need no gradient.
    no_grad_vars = backward._find_no_grad_vars(root_block, op_path, [loss],
                                               block_no_grad_set)
    block_no_grad_set.update(no_grad_vars)
    no_grad_dict[0].update(
        map(backward._append_grad_suffix_, block_no_grad_set))

    input_grad_names_set = None
    # For double backward, input_grad_names is used for filtering
    # some non-used gradient ops.
    if program._appending_grad_times > 1:
        input_grad_names_set = {backward._append_grad_suffix_(loss.name)}

    backward._append_backward_ops_(root_block,
                                   op_path,
                                   root_block,
                                   no_grad_dict,
                                   grad_to_var,
                                   callbacks,
                                   input_grad_names_set=input_grad_names_set)

    # Because calc_gradient may be called multiple times,
    # we need rename the internal gradient variables so that they have
    # different names.
    backward._rename_grad_(root_block, fwd_op_num, grad_to_var, {})

    backward._append_backward_vars_(root_block, fwd_op_num, grad_to_var,
                                    grad_info_map)

    program.current_block_idx = current_block_idx
    program._sync_with_cpp()

    if parameter_list is not None:
        parameters = parameter_list
    else:
        # Default to every trainable mpc parameter in the program.
        params = list(filter(is_mpc_parameter, program.list_vars()))
        parameters = [param.name for param in params if param.trainable]

    params_and_grads = []
    for param in parameters:
        if cpt.to_text(param) not in grad_info_map:
            continue
        grad_info = grad_info_map[param]
        grad_block = grad_info[1]
        if not grad_block.has_var(grad_info[0]):
            raise ValueError(
                "grad block[{0}] did not have grad var {1}".format(
                    grad_info[1], grad_info[0]))
        # Get the param var from the global block.
        param_var = program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            # Gradient lives in a sub-block only; expose it as None.
            params_and_grads.append((param_var, None))

    op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName(
    )
    for p, g in params_and_grads:
        if g is None:
            continue
        # Bind each gradient variable to the last op that outputs it.
        for op in reversed(program.global_block().ops):
            assert isinstance(op, framework.Operator)
            if g.name in op.output_arg_names:
                g.op = op
                break

        if g.op is None:
            raise ValueError("Unexpected branch")
        attr_val = [p.name, g.name]
        if g.op.has_attr(op_role_var_attr_name):
            attr_val.extend(g.op.attr(op_role_var_attr_name))
        g.op._set_attr(op_role_var_attr_name, attr_val)

    return params_and_grads
Ejemplo n.º 22
0
    def test_to_text(self):
        """Exhaustively check ``cpt.to_text`` on Python 2 and Python 3.

        Covers: ``None`` passthrough, scalar str/bytes/unicode inputs, and
        list/set containers both with and without ``inplace=True``.  On PY2
        the expected text type is ``unicode``; on PY3 it is ``str``.
        """
        # Only support python2.x and python3.x now
        self.assertTrue(six.PY2 | six.PY3)

        if six.PY2:
            # check None
            self.assertIsNone(cpt.to_text(None))

            # check all string related types
            self.assertTrue(isinstance(cpt.to_text(str("")), unicode))
            self.assertTrue(isinstance(cpt.to_text(str("123")), unicode))
            self.assertTrue(isinstance(cpt.to_text(b""), unicode))
            self.assertTrue(isinstance(cpt.to_text(b""), unicode))
            self.assertTrue(isinstance(cpt.to_text(u""), unicode))
            self.assertTrue(isinstance(cpt.to_text(u""), unicode))

            self.assertEqual(u"", cpt.to_text(str("")))
            self.assertEqual(u"123", cpt.to_text(str("123")))
            self.assertEqual(u"", cpt.to_text(b""))
            self.assertEqual(u"123", cpt.to_text(b"123"))
            self.assertEqual(u"", cpt.to_text(u""))
            self.assertEqual(u"123", cpt.to_text(u"123"))

            # check list types, not inplace
            l = [""]
            l2 = cpt.to_text(l)
            self.assertTrue(isinstance(l2, list))
            self.assertFalse(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual([u""], l2)
            l = ["", "123"]
            l2 = cpt.to_text(l)
            self.assertTrue(isinstance(l2, list))
            self.assertFalse(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual([u"", u"123"], l2)
            l = ["", b'123', u"321"]
            l2 = cpt.to_text(l)
            self.assertTrue(isinstance(l2, list))
            self.assertFalse(l is l2)
            # NOTE: on PY2, b'123' == u'123', so mixed input still compares
            # equal to its converted output.
            self.assertEqual(l, l2)
            self.assertEqual([u"", u"123", u"321"], l2)
            for i in l2:
                self.assertTrue(isinstance(i, unicode))

            # check list types, inplace
            l = [""]
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, list))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual([u""], l2)
            l = ["", "123"]
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, list))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual([u"", u"123"], l2)
            l = ["", b"123", u"321"]
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, list))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual([u"", u"123", u"321"], l2)

            # check set types, not inplace
            l = set("")
            l2 = cpt.to_text(l, inplace=False)
            self.assertTrue(isinstance(l2, set))
            self.assertFalse(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(set(u""), l2)
            l = set([b"", b"123"])
            l2 = cpt.to_text(l, inplace=False)
            self.assertTrue(isinstance(l2, set))
            self.assertFalse(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(set([u"", u"123"]), l2)
            l = set(["", b"123", u"321"])
            l2 = cpt.to_text(l, inplace=False)
            self.assertTrue(isinstance(l2, set))
            self.assertFalse(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(set([u"", u"123", u"321"]), l2)
            for i in l2:
                self.assertTrue(isinstance(i, unicode))

            # check set types, inplace
            l = set("")
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, set))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(set(u""), l2)
            l = set([b"", b"123"])
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, set))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(set([u"", u"123"]), l2)
            l = set(["", b"123", u"321"])
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, set))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(set([u"", u"123", u"321"]), l2)

        elif six.PY3:
            self.assertIsNone(cpt.to_text(None))

            self.assertTrue(isinstance(cpt.to_text(str("")), str))
            self.assertTrue(isinstance(cpt.to_text(str("123")), str))
            self.assertTrue(isinstance(cpt.to_text(b""), str))
            self.assertTrue(isinstance(cpt.to_text(b""), str))
            self.assertTrue(isinstance(cpt.to_text(u""), str))
            self.assertTrue(isinstance(cpt.to_text(u""), str))

            self.assertEqual("", cpt.to_text(str("")))
            self.assertEqual("123", cpt.to_text(str("123")))
            self.assertEqual("", cpt.to_text(b""))
            self.assertEqual("123", cpt.to_text(b"123"))
            self.assertEqual("", cpt.to_text(u""))
            self.assertEqual("123", cpt.to_text(u"123"))

            # check list types, not inplace
            l = [""]
            l2 = cpt.to_text(l)
            self.assertTrue(isinstance(l2, list))
            self.assertFalse(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual([""], l2)
            l = ["", "123"]
            l2 = cpt.to_text(l)
            self.assertTrue(isinstance(l2, list))
            self.assertFalse(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(["", "123"], l2)
            l = ["", b"123", u"321"]
            l2 = cpt.to_text(l)
            self.assertTrue(isinstance(l2, list))
            self.assertFalse(l is l2)
            # NOTE: on PY3, b"123" != "123", so a mixed bytes/str input is
            # NOT equal to its fully-decoded output.
            self.assertNotEqual(l, l2)
            self.assertEqual(["", "123", "321"], l2)

            # check list types, inplace
            l = [""]
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, list))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual([""], l2)
            l = ["", b"123"]
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, list))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(["", "123"], l2)
            l = ["", b"123", u"321"]
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, list))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(["", "123", "321"], l2)
            for i in l2:
                self.assertTrue(isinstance(i, str))

            # check set types, not inplace
            l = set("")
            l2 = cpt.to_text(l, inplace=False)
            self.assertTrue(isinstance(l2, set))
            self.assertFalse(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(set(""), l2)
            l = set([b"", b"123"])
            l2 = cpt.to_text(l, inplace=False)
            self.assertTrue(isinstance(l2, set))
            self.assertFalse(l is l2)
            self.assertNotEqual(l, l2)
            self.assertEqual(set(["", "123"]), l2)
            l = set(["", b"123", u"321"])
            l2 = cpt.to_text(l, inplace=False)
            self.assertTrue(isinstance(l2, set))
            self.assertFalse(l is l2)
            self.assertNotEqual(l, l2)
            self.assertEqual(set(["", "123", "321"]), l2)

            # check set types, inplace
            l = set("")
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, set))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(set(""), l2)
            l = set([b"", b"123"])
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, set))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(set(["", "123"]), l2)
            l = set(["", b"123", u"321"])
            l2 = cpt.to_text(l, inplace=True)
            self.assertTrue(isinstance(l2, set))
            self.assertTrue(l is l2)
            self.assertEqual(l, l2)
            self.assertEqual(set(["", "123", "321"]), l2)
            for i in l2:
                self.assertTrue(isinstance(i, str))
Ejemplo n.º 23
0
    def __init__(self,
                 block,
                 type=core.VarDesc.VarType.LOD_TENSOR,
                 name=None,
                 shape=None,
                 dtype=None,
                 lod_level=None,
                 capacity=None,
                 persistable=None,
                 error_clip=None,
                 stop_gradient=False,
                 is_data=False,
                 need_check_feed=False,
                 belong_to_optimizer=False,
                 **kwargs):
        """Create an MpcVariable in ``block``, or re-bind to an existing one.

        Looks ``name`` up in the block's underlying VarDesc table.  If it
        does not exist, a new desc is created and the supplied attributes
        are written into it; if it already exists, each attribute that was
        passed (type, shape, dtype, lod_level, persistable) is checked for
        consistency with the stored desc and a ``ValueError`` is raised on
        any mismatch.

        Note: for a newly created variable, the requested ``shape`` is
        prefixed with a leading dimension of 2 before being stored
        (presumably the two secret shares of the MPC protocol — see the
        shape-handling branch below).

        Raises:
            ValueError: If the variable already exists and any supplied
                attribute conflicts with the existing desc.
        """
        self.block = block
        if name is None:
            # No explicit name: generate a unique placeholder name.
            name = unique_name.generate('_generated_var')

        if dtype is not None:
            if not isinstance(dtype, core.VarDesc.VarType):
                # Normalize numpy dtypes / strings to the core enum.
                dtype = convert_np_dtype_to_dtype_(dtype)

        self.belong_to_optimizer = belong_to_optimizer

        self.error_clip = error_clip

        is_new_var = False
        name = cpt.to_text(name)
        # The C++ desc table is keyed by bytes, hence the to_bytes round-trip.
        self.desc = self.block.desc.find_var(cpt.to_bytes(name))

        if self.desc is None:
            self.desc = self.block.desc.var(cpt.to_bytes(name))
            is_new_var = True

        if is_new_var:
            self.desc.set_type(type)
        elif self.desc.type() != type:
            raise ValueError("MpcVariable {0} has been created before. The "
                             "previous type is {1}; the new type is {2}. They"
                             " are not matched".format(self.name,
                                                       self.desc.type(), type))
        if shape is not None:
            if is_new_var:
                # resize the shape for MpcVariable
                mpc_shape = list(shape)
                mpc_shape.insert(0, 2)
                self.desc.set_shape(mpc_shape)
            else:
                # Existing var: the caller-supplied shape must match exactly.
                old_shape = self.shape
                shape = tuple(shape)
                if shape != old_shape:
                    raise ValueError(
                        "MpcVariable {0} has been created before. the previous "
                        "shape is {1}; the new shape is {2}. They are not "
                        "matched.".format(self.name, old_shape, shape))
        if dtype is not None:
            if is_new_var:
                self.desc.set_dtype(dtype)
            else:
                old_dtype = self.dtype
                if dtype != old_dtype:
                    raise ValueError(
                        "MpcVariable {0} has been created before. "
                        "The previous data type is {1}; the new "
                        "data type is {2}. They are not "
                        "matched.".format(self.name, old_dtype, dtype))

        if lod_level is not None:
            if is_new_var:
                self.desc.set_lod_level(lod_level)
            else:
                if lod_level != self.lod_level:
                    raise ValueError(
                        "MpcVariable {0} has been created before. "
                        "The previous lod_level is {1}; the new "
                        "lod_level is {2}. They are not "
                        "matched".format(self.name, self.lod_level, lod_level))
        if persistable is not None:
            if is_new_var:
                self.desc.set_persistable(persistable)
            else:
                if persistable != self.persistable:
                    raise ValueError(
                        "MpcVariable {0} has been created before."
                        "The previous persistable is {1}; the new "
                        "persistable is {2}. They are not matched".format(
                            self.name, self.persistable, persistable))

        if need_check_feed and is_new_var:
            self.desc.set_need_check_feed(need_check_feed)

        if capacity is not None:
            if is_new_var:
                self.desc.set_capacity(capacity)
            else:
                # TODO(abhinavarora) by Paddle 1.7: Compare with set capacity once,
                # get_capacity is implemented
                pass

        # Register this Python wrapper in the block's var map.
        self.block.vars[name] = self
        self.op = None
        self._stop_gradient = stop_gradient
        self.is_data = is_data
Ejemplo n.º 24
0
    def _load_anno(self):
        """Parse the CoNLL-2005 test split out of ``self.data_file``.

        Reads the words and props files of the tarball in lockstep and fills
        ``self.sentences`` / ``self.predicates`` / ``self.labels`` with one
        entry per (sentence, predicate) pair, converting the CoNLL bracketed
        label format (e.g. ``(A0*`` ... ``*)``) into BIO tag sequences.

        Raises:
            RuntimeError: On a props token that does not match the expected
                bracketed format.
        """
        tf = tarfile.open(self.data_file)
        wf = tf.extractfile(
            "conll05st-release/test.wsj/words/test.wsj.words.gz")
        pf = tf.extractfile(
            "conll05st-release/test.wsj/props/test.wsj.props.gz")
        self.sentences = []
        self.predicates = []
        self.labels = []
        with gzip.GzipFile(fileobj=wf) as words_file, gzip.GzipFile(
                fileobj=pf) as props_file:
            sentences = []
            labels = []
            one_seg = []
            for word, label in zip(words_file, props_file):
                word = cpt.to_text(word.strip())
                label = cpt.to_text(label.strip().split())

                if len(label) == 0:  # end of sentence
                    # Transpose per-word label columns into per-predicate
                    # rows: one_seg[word][col] -> labels[col][word].
                    # Guard against an empty segment (e.g. consecutive blank
                    # lines), which would otherwise raise IndexError.
                    if one_seg:
                        for col in range(len(one_seg[0])):
                            labels.append([row[col] for row in one_seg])

                    if len(labels) >= 1:
                        # Column 0 lists the predicates; '-' marks
                        # non-predicate positions.
                        verb_list = [tok for tok in labels[0] if tok != '-']

                        # Remaining columns are one label sequence per
                        # predicate; convert each to BIO tags.
                        for i, lbl in enumerate(labels[1:]):
                            cur_tag = 'O'
                            is_in_bracket = False
                            lbl_seq = []
                            for l in lbl:
                                if l == '*' and not is_in_bracket:
                                    lbl_seq.append('O')
                                elif l == '*' and is_in_bracket:
                                    lbl_seq.append('I-' + cur_tag)
                                elif l == '*)':
                                    lbl_seq.append('I-' + cur_tag)
                                    is_in_bracket = False
                                elif l.find('(') != -1 and l.find(')') != -1:
                                    # Single-token span, e.g. "(A0*)".
                                    cur_tag = l[1:l.find('*')]
                                    lbl_seq.append('B-' + cur_tag)
                                    is_in_bracket = False
                                elif l.find('(') != -1 and l.find(')') == -1:
                                    # Span opener, e.g. "(A0*".
                                    cur_tag = l[1:l.find('*')]
                                    lbl_seq.append('B-' + cur_tag)
                                    is_in_bracket = True
                                else:
                                    raise RuntimeError('Unexpected label: %s' %
                                                       l)

                            self.sentences.append(sentences)
                            self.predicates.append(verb_list[i])
                            self.labels.append(lbl_seq)

                    sentences = []
                    labels = []
                    one_seg = []
                else:
                    sentences.append(word)
                    one_seg.append(label)

        pf.close()
        wf.close()
        tf.close()