コード例 #1
0
def proc_build(tarbz2_file, work_dir, save_dir, config):
    """Build a (caller + callee disassembly, arg_two type) dataset from one archive.

    Untars ``tarbz2_file`` into ``work_dir`` and loads the pickle it contains.
    For every row whose disassembly (column 4) contains a call, the callee is
    looked up among rows of the same binary (column 7) by function name
    (column 2). When the callee's signature (column 0) has at least three
    arguments and its second argument's type is known
    (``common_stuff_lib.is_type_known``), the tokenized caller and callee
    disassembly are concatenated and stored together with that type.

    Args:
        tarbz2_file: path to the ``.tar.bz2`` archive.
        work_dir: prefix the archive basename is appended to (plain string
            concatenation, so it presumably needs a trailing separator).
        save_dir: not used; output goes to ``config['save_dir']``.
        config: reads ``'save_file_type'`` (``'pickle'``, anything else means
            TFRecord) and ``'save_dir'``.

    Returns:
        The number of samples collected.
    """
    tarbz2_lib.untar_file_to_path(tarbz2_file, work_dir)

    pickle_file = work_dir + os.path.basename(tarbz2_file).replace(
        '.tar.bz2', '')
    pickle_file_content = pickle_lib.get_pickle_file_content(pickle_file)

    # Index rows by (binary, function name) once, keeping file order, so the
    # lookups below are dict hits instead of full rescans of the content.
    rows_by_key = {}
    binaries = set()
    functions = set()
    for elem in pickle_file_content:
        binaries.add(elem[7])
        functions.add(elem[2])
        rows_by_key.setdefault((elem[7], elem[2]), []).append(elem)

    print(f'binaries >{binaries}<')

    # A callee needs at least this many arguments to have an arg_two.
    arg_nr_we_want = 3
    # Longer token sequences blew memory (>160GB RAM) in model.fit().
    max_tokens = 100000

    dataset_list = []

    for binary in binaries:
        for func in functions:
            for elem in rows_by_key.get((binary, func), ()):
                att_dis = elem[4]
                for item in att_dis:
                    if not disassembly_lib.find_call_in_disassembly_line(item):
                        continue
                    callee_name = disassembly_lib.get_callee_name_from_disassembly_line(
                        item)

                    # Rows of the same binary describing the callee.
                    for elem2 in rows_by_key.get((binary, callee_name), ()):
                        nr_of_args = return_type_lib.get_nr_of_args_from_function_signature(
                            elem2[0])
                        if nr_of_args < arg_nr_we_want:
                            break

                        arg_two = return_type_lib.get_arg_two_name_from_function_signature(
                            elem2[0])
                        if not common_stuff_lib.is_type_known(arg_two):
                            # Unknown type: keep looking at further rows for
                            # the same callee.
                            continue

                        caller_dis = disassembly_lib.clean_att_disassembly_from_comment(
                            att_dis)
                        callee_dis = disassembly_lib.clean_att_disassembly_from_comment(
                            elem2[4])

                        dis1_str = disassembly_lib.split_disassembly(
                            ' '.join(caller_dis))
                        dis2_str = disassembly_lib.split_disassembly(
                            ' '.join(callee_dis))

                        if (not 1 <= len(dis1_str) <= max_tokens
                                or not 1 <= len(dis2_str) <= max_tokens):
                            print(f'dis1_str >{len(dis1_str)}<')
                            print(f'dis2_str >{len(dis2_str)}<')
                        else:
                            dataset_list.append((dis1_str + dis2_str, arg_two))

                        break

    if dataset_list:
        if config['save_file_type'] == 'pickle':
            out_path = config['save_dir'] + os.path.basename(
                pickle_file).replace('.tar.bz2', '')
            # `with` guarantees the handle is closed even if dump() fails.
            with open(out_path, 'wb+') as ret_file:
                pickle.dump(dataset_list, ret_file)
        else:
            ## save as tfrecord
            dis_list = [dis for dis, _ in dataset_list]
            ret_list = [label for _, label in dataset_list]

            raw_dataset = tf.data.Dataset.from_tensor_slices(
                (dis_list, ret_list))

            serialized_features_dataset = raw_dataset.map(tf_serialize_example)

            filename = config['save_dir'] + os.path.basename(
                tarbz2_file).replace('.pickle.tar.bz2', '') + '.tfrecord'
            writer = tf.data.experimental.TFRecordWriter(filename)
            writer.write(serialized_features_dataset)

    return len(dataset_list)
コード例 #2
0
def proc_build(tarbz2_file, work_dir, save_dir, config):
    """Build a (caller + callee disassembly, return type) dataset from one archive.

    Untars ``tarbz2_file`` into ``work_dir`` and loads the pickle it contains.
    For every row whose disassembly (column 4) contains a call, the callee is
    looked up among rows of the same binary (column 7) by function name
    (column 2). The callee's return type is derived from its gdb ptype
    (column 1). Results ``'unknown'``, ``'delete'`` and ``'process_further'``
    are skipped. Any other return type is stored together with the tokenized
    caller and callee disassembly, concatenated.

    Args:
        tarbz2_file: path to the ``.tar.bz2`` archive.
        work_dir: prefix the archive basename is appended to (plain string
            concatenation, so it presumably needs a trailing separator).
        save_dir: not used; output goes to ``config['save_dir']``.
        config: reads ``'save_file_type'`` (``'pickle'``, anything else means
            TFRecord) and ``'save_dir'``.

    Returns:
        The number of samples collected.
    """
    tarbz2_lib.untar_file_to_path(tarbz2_file, work_dir)

    pickle_file = work_dir + os.path.basename(tarbz2_file).replace(
        '.tar.bz2', '')
    pickle_file_content = pickle_lib.get_pickle_file_content(pickle_file)

    # Index rows by (binary, function name) once, keeping file order, so the
    # lookups below are dict hits instead of full rescans of the content.
    rows_by_key = {}
    binaries = set()
    functions = set()
    for elem in pickle_file_content:
        binaries.add(elem[7])
        functions.add(elem[2])
        rows_by_key.setdefault((elem[7], elem[2]), []).append(elem)

    print(f'binaries >{binaries}<')

    dataset_list = []

    for binary in binaries:
        for func in functions:
            for elem in rows_by_key.get((binary, func), ()):
                att_dis = elem[4]
                for item in att_dis:
                    if not disassembly_lib.find_call_in_disassembly_line(item):
                        continue
                    callee_name = disassembly_lib.get_callee_name_from_disassembly_line(
                        item)

                    # Rows of the same binary describing the callee.
                    for elem2 in rows_by_key.get((binary, callee_name), ()):
                        return_type_func_sign = return_type_lib.get_return_type_from_function_signature(
                            elem2[0])
                        return_type = return_type_lib.get_return_type_from_gdb_ptype(
                            elem2[1])

                        if return_type == 'unknown':
                            # Debug aid: shows which signatures are still not
                            # mapped (should print nothing). This previously
                            # referenced an undefined name and raised NameError.
                            print(
                                f'string_before_func_name: {return_type_func_sign}'
                            )
                            continue
                        if return_type == 'delete':
                            # No usable return type: skip this row.
                            continue
                        if return_type == 'process_further':
                            print('ERRROOOORRRR---------------')
                            continue

                        dis1_str = disassembly_lib.split_disassembly(
                            ' '.join(att_dis))
                        dis2_str = disassembly_lib.split_disassembly(
                            ' '.join(elem2[4]))

                        dataset_list.append((dis1_str + dis2_str, return_type))
                        break

    if dataset_list:
        if config['save_file_type'] == 'pickle':
            out_path = config['save_dir'] + os.path.basename(
                pickle_file).replace('.tar.bz2', '')
            # `with` guarantees the handle is closed even if dump() fails.
            with open(out_path, 'wb+') as ret_file:
                pickle.dump(dataset_list, ret_file)
        else:
            ## save as tfrecord
            dis_list = [dis for dis, _ in dataset_list]
            ret_list = [label for _, label in dataset_list]

            raw_dataset = tf.data.Dataset.from_tensor_slices(
                (dis_list, ret_list))

            serialized_features_dataset = raw_dataset.map(tf_serialize_example)

            filename = config['save_dir'] + os.path.basename(
                tarbz2_file).replace('.pickle.tar.bz2', '') + '.tfrecord'
            writer = tf.data.experimental.TFRecordWriter(filename)
            writer.write(serialized_features_dataset)

    return len(dataset_list)
コード例 #3
0
    def run(self):
        """Predict the signature of the function at the current seek and emit it.

        Reads the current function and its caller from Cutter, tokenizes both
        disassemblies and feeds them to the return-type, nr_of_args and
        arg_one/arg_two/arg_three models under ``~/git/func_sign_prob/``.
        Progress is reported via ``updateProgressBar``, model details via
        ``summaryReady`` and the predicted signature (HTML) via
        ``resultReady``.

        Once ``set_new_radare2_e`` has been applied, ``set_stored_radare2_e``
        is called on every exit path, including early returns and exceptions.
        """
        self.updateProgressBar.emit(1)

        address = cutter.cmd('s').strip()
        if address == '0x0':
            print('runInference not from addr 0x0')
            # Previously returned silently, leaving the progress bar at 1%.
            self.updateProgressBar.emit(100)
            self.resultReady.emit('')
            self.summaryReady.emit('Current address is 0x0, choose other function')
            return

        self.set_new_radare2_e()
        try:
            self._run_inference(address)
        finally:
            # Restore the user's radare2 settings on every path; the
            # "no disassembly" early return used to skip this.
            self.set_stored_radare2_e()

    def _predict_and_snapshot(self, model_name, predict_str, git_path):
        """Run one model via ``get_prediction`` and snapshot the state it sets.

        ``get_prediction`` leaves ``model_summary_str``, ``biggest_prob`` and
        ``biggest_prob_type`` on ``self``; the next prediction overwrites them.

        Returns:
            (prediction_summary_str, model_summary_str, biggest_prob,
            biggest_prob_type)
        """
        prediction_summary_str = self.get_prediction(model_name,
                                                     predict_str,
                                                     git_path)
        return (prediction_summary_str, self.model_summary_str,
                self.biggest_prob, self.biggest_prob_type)

    def _run_inference(self, address):
        """Run the model cascade for the function at ``address`` and emit results.

        Called by ``run`` while the temporary radare2 settings are active;
        ``run`` restores them afterwards, so this method may return anywhere.
        """
        ### get name of current function
        current_func_name = cutter.cmdj("afdj $F").get('name')

        ## find data/code references to this address with $F; the 'from' of
        ## the last listed reference is used as caller address (0 if none)
        caller_addr = 0
        for item_dicts in cutter.cmdj("axtj $F") or []:
            print(f'item_dicts >{item_dicts}<')
            caller_addr = item_dicts.get('from', caller_addr)

        ## get disassembly of current/callee and caller function
        disasm_callee_str = self.get_disassembly_of(address)
        disasm_caller_str = self.get_disassembly_of(caller_addr)

        ### split disas for the tf-model
        disasm_caller_str = disassembly_lib.split_disassembly(disasm_caller_str)
        disasm_callee_str = disassembly_lib.split_disassembly(disasm_callee_str)

        ## check if we got caller and callee disassembly
        if not disasm_caller_str or not disasm_callee_str:
            self.updateProgressBar.emit(100)
            self.resultReady.emit('')
            self.summaryReady.emit('No caller disassembly found, choose other function')
            return

        ### the path were we cloned git repo to
        self._userHomePath = os.path.expanduser('~')
        func_sign_prob_git_path = self._userHomePath + "/git/func_sign_prob/"

        self.updateProgressBar.emit(9)

        # Model input: caller tokens, separator, callee address with every
        # '0' spelled as 'null', then callee tokens (presumably mirroring the
        # training data format -- confirm against the dataset builder).
        callee_addr = ' '.join('null' if char == '0' else char for char in address)
        disasm_caller_callee_predict_str = (disasm_caller_str +
                                            ' caller_callee_separator ' +
                                            callee_addr + ' ' +
                                            disasm_callee_str)

        ### predict now ret-type
        (ret_type_prediction_summary_str, ret_type_model_summary_str,
         ret_type_biggest_prob, ret_type_biggest_prob_type) = self._predict_and_snapshot(
             'return_type/words_100000', disasm_caller_callee_predict_str,
             func_sign_prob_git_path)
        ret_type_biggest_prob_percent = 100 * ret_type_biggest_prob
        self.updateProgressBar.emit(10)

        ### predict now nr_of_args
        (nr_of_args_prediction_summary_str, nr_of_args_model_summary_str,
         _, nr_of_args_biggest_prob_type) = self._predict_and_snapshot(
             'nr_of_args/words_100000', disasm_caller_callee_predict_str,
             func_sign_prob_git_path)
        self.updateProgressBar.emit(20)

        ### predict now arg_one
        (arg_one_prediction_summary_str, arg_one_model_summary_str,
         arg_one_biggest_prob, arg_one_biggest_prob_type) = self._predict_and_snapshot(
             'arg_one', disasm_caller_callee_predict_str,
             func_sign_prob_git_path)
        arg_one_biggest_prob_percent = 100 * arg_one_biggest_prob

        if nr_of_args_biggest_prob_type == 1:
            self.updateProgressBar.emit(100)

            # self.model_summary_str still holds the arg_one model summary here.
            self.summaryReady.emit(f"tf return type model summary:\n \
                                        {ret_type_model_summary_str}\n \
                                        {ret_type_prediction_summary_str}\n \
                                        tf nr_of_args model summary:\n \
                                         {nr_of_args_model_summary_str}\n \
                                         {nr_of_args_prediction_summary_str}\n \
                                         tf arg_one model summary:\n \
                                         {self.model_summary_str}\n \
                                         {arg_one_prediction_summary_str}")

            self.resultReady.emit(f'{ret_type_biggest_prob_type} \
                 <span style=\"background-color:red;\">({ret_type_biggest_prob_percent:3.1f}%)</span> \
                 {current_func_name} ( \
                 {arg_one_biggest_prob_type} \
                 <span style=\"background-color:red;\">({arg_one_biggest_prob_percent:3.1f}%)</span> \
                 )')
            return

        self.updateProgressBar.emit(30)

        ### more than one arg: predict now arg_two
        (arg_two_prediction_summary_str, arg_two_model_summary_str,
         arg_two_biggest_prob, arg_two_biggest_prob_type) = self._predict_and_snapshot(
             'arg_two', disasm_caller_callee_predict_str,
             func_sign_prob_git_path)
        arg_two_biggest_prob_percent = 100 * arg_two_biggest_prob

        if nr_of_args_biggest_prob_type == 2:
            self.updateProgressBar.emit(100)

            self.summaryReady.emit(f"tf return type model summary:\n \
                                        {ret_type_model_summary_str}\n \
                                        {ret_type_prediction_summary_str}\n \
                                        tf nr_of_args model summary:\n \
                                        {nr_of_args_model_summary_str}\n \
                                        {nr_of_args_prediction_summary_str}\n \
                                        tf arg_one model summary:\n \
                                        {arg_one_model_summary_str}\n \
                                        {arg_one_prediction_summary_str}\n \
                                        tf arg_two model summary:\n \
                                        {arg_two_model_summary_str}\n \
                                        {arg_two_prediction_summary_str}")

            self.resultReady.emit(f'{ret_type_biggest_prob_type} \
                <span style=\"background-color:red;\">({ret_type_biggest_prob_percent:3.1f}%)</span> \
                {current_func_name} ( \
                {arg_one_biggest_prob_type} \
                <span style=\"background-color:red;\">({arg_one_biggest_prob_percent:3.1f}%)</span> , \
                {arg_two_biggest_prob_type} \
                <span style=\"background-color:red;\">({arg_two_biggest_prob_percent:3.1f}%)</span> \
                )')
            return

        self.updateProgressBar.emit(40)

        ### more than two args: predict now arg_three
        (arg_three_prediction_summary_str, arg_three_model_summary_str,
         arg_three_biggest_prob, arg_three_biggest_prob_type) = self._predict_and_snapshot(
             'arg_three', disasm_caller_callee_predict_str,
             func_sign_prob_git_path)
        arg_three_biggest_prob_percent = 100 * arg_three_biggest_prob

        # HACK: more than three args are shown as three; more models needed.
        if nr_of_args_biggest_prob_type >= 3:
            self.updateProgressBar.emit(100)

            self.summaryReady.emit(f"tf return type model summary:\n \
                                        {ret_type_model_summary_str}\n \
                                        {ret_type_prediction_summary_str}\n \
                                        tf nr_of_args model summary:\n \
                                        {nr_of_args_model_summary_str}\n \
                                        {nr_of_args_prediction_summary_str}\n \
                                        tf arg_one model summary:\n \
                                        {arg_one_model_summary_str}\n \
                                        {arg_one_prediction_summary_str}\n \
                                        tf arg_two model summary:\n \
                                        {arg_two_model_summary_str}\n \
                                        {arg_two_prediction_summary_str}\n \
                                        tf arg_three model summary:\n \
                                        {arg_three_model_summary_str}\n \
                                        {arg_three_prediction_summary_str}")

            self.resultReady.emit(f'{ret_type_biggest_prob_type} \
                <span style=\"background-color:red;\">({ret_type_biggest_prob_percent:3.1f}%)</span> \
                {current_func_name} ( \
                {arg_one_biggest_prob_type} \
                <span style=\"background-color:red;\">({arg_one_biggest_prob_percent:3.1f}%)</span> , \
                {arg_two_biggest_prob_type} \
                <span style=\"background-color:red;\">({arg_two_biggest_prob_percent:3.1f}%)</span> , \
                {arg_three_biggest_prob_type} \
                <span style=\"background-color:red;\">({arg_three_biggest_prob_percent:3.1f}%)</span> \
                )')
            return

        self.updateProgressBar.emit(50)

        # NOTE(review): reached when nr_of_args is predicted as anything other
        # than >= 1 (e.g. 0); no signature is shown in that case.
        print('over')
        self.resultReady.emit('')
コード例 #4
0
def proc_build(tarbz2_file, work_dir, save_dir, config):
    """Build a (caller + callee disassembly, nr_of_args) dataset from one archive.

    Untars ``tarbz2_file`` into ``work_dir`` and loads the pickle it contains.
    For every row whose disassembly (column 4) contains a call, the callee is
    looked up among rows of the same binary (column 7) by function name
    (column 2). The number of arguments is read from the callee signature
    (column 0); ``-1`` is reported as an error and skipped. Caller and callee
    must each fit in half of ``config['tokenized_disassembly_length']``,
    both as raw line lists and after tokenization.

    Args:
        tarbz2_file: path to the ``.tar.bz2`` archive.
        work_dir: prefix the archive basename is appended to (plain string
            concatenation, so it presumably needs a trailing separator).
        save_dir: not used; output goes to ``config['save_dir']``.
        config: reads ``'tokenized_disassembly_length'``, ``'save_file_type'``
            (``'pickle'``, anything else means TFRecord) and ``'save_dir'``.

    Returns:
        The number of samples collected.

    Raises:
        KeyError: if ``config`` lacks ``'tokenized_disassembly_length'``
            (checked up front, before the archive is unpacked).
    """
    # Caller and callee may each use at most half of the model's tokenized
    # input; longer sequences blew memory (>160GB RAM) in model.fit().
    half_len = int(config['tokenized_disassembly_length']) / 2

    tarbz2_lib.untar_file_to_path(tarbz2_file, work_dir)

    pickle_file = work_dir + os.path.basename(tarbz2_file).replace(
        '.tar.bz2', '')
    pickle_file_content = pickle_lib.get_pickle_file_content(pickle_file)

    # Index rows by (binary, function name) once, keeping file order, so the
    # lookups below are dict hits instead of full rescans of the content.
    rows_by_key = {}
    binaries = set()
    functions = set()
    for elem in pickle_file_content:
        binaries.add(elem[7])
        functions.add(elem[2])
        rows_by_key.setdefault((elem[7], elem[2]), []).append(elem)

    print(f'binaries >{binaries}<')

    dataset_list = []

    for binary in binaries:
        for func in functions:
            for elem in rows_by_key.get((binary, func), ()):
                att_dis = elem[4]
                # The raw caller length is the same for every callee it calls,
                # so an out-of-range caller can be skipped outright.
                if not 1 <= len(att_dis) <= half_len:
                    continue
                for item in att_dis:
                    if not disassembly_lib.find_call_in_disassembly_line(item):
                        continue
                    callee_name = disassembly_lib.get_callee_name_from_disassembly_line(
                        item)

                    # Rows of the same binary describing the callee.
                    for elem2 in rows_by_key.get((binary, callee_name), ()):
                        if not 1 <= len(elem2[4]) <= half_len:
                            continue

                        nr_of_args = return_type_lib.get_nr_of_args_from_function_signature(
                            elem2[0])
                        if nr_of_args == -1:
                            print('Error nr_of_args')
                            continue

                        print(f'nr_of_args >{nr_of_args}<', end='\r')

                        caller_dis = disassembly_lib.clean_att_disassembly_from_comment(
                            att_dis)
                        callee_dis = disassembly_lib.clean_att_disassembly_from_comment(
                            elem2[4])

                        dis1_str = disassembly_lib.split_disassembly(
                            ' '.join(caller_dis))
                        dis2_str = disassembly_lib.split_disassembly(
                            ' '.join(callee_dis))

                        if (not 1 <= len(dis1_str) <= half_len
                                or not 1 <= len(dis2_str) <= half_len):
                            print(
                                f'tokenized_disassembly_length caller >{len(dis1_str)}<'
                            )
                            print(
                                f'tokenized_disassembly_length callee >{len(dis2_str)}<'
                            )
                        else:
                            dataset_list.append(
                                (dis1_str + dis2_str, nr_of_args))

                        break

    if dataset_list:
        if config['save_file_type'] == 'pickle':
            out_path = config['save_dir'] + os.path.basename(
                pickle_file).replace('.tar.bz2', '')
            # `with` guarantees the handle is closed even if dump() fails.
            with open(out_path, 'wb+') as ret_file:
                pickle.dump(dataset_list, ret_file)
        else:
            ## save as tfrecord
            dis_list = [dis for dis, _ in dataset_list]
            ret_list = [label for _, label in dataset_list]

            raw_dataset = tf.data.Dataset.from_tensor_slices(
                (dis_list, ret_list))

            serialized_features_dataset = raw_dataset.map(tf_serialize_example)

            filename = config['save_dir'] + os.path.basename(
                tarbz2_file).replace('.pickle.tar.bz2', '') + '.tfrecord'
            writer = tf.data.experimental.TFRecordWriter(filename)
            writer.write(serialized_features_dataset)

    return len(dataset_list)