示例#1
0
class Generator:
  """Generate code-switched (CS) corpora plus monolingual "control" corpora.

  For every CS type in ``config.csVariants``, every corpus size ``data`` in
  ``config.dataRanges[csType]`` and every ``(pure, cs)`` percentage pair in
  ``config.splits``, four files are written into the output directory::

    TrainCSType<type>CS<cs>Pure<pure>Total<2*data><suffix>               CS data
    TrainCSType<type>CS<cs>Pure<pure>Total<2*data>.uni<suffix>           CS data, mapLD2Uni'd
    TrainCSType<type>CS<cs>Pure<pure>Total<2*data>_Control<suffix>       control data
    TrainCSType<type>CS<cs>Pure<pure>Total<2*data>_Control.uni<suffix>   control data, mapLD2Uni'd

  Both the CS file and the control file start with the same randomly
  sampled monolingual sentences (``pr`` from ``pureL1`` and ``pr`` from
  ``pureL2``).  The CS file then holds two CS sentences per parallel
  sentence (``order`` 0 and 1 of CSHandler.updateHandler), and the control
  file holds the two monolingual sentences of that same parallel pair.
  Every line is produced by makeString: space-separated tokens whose fields
  are joined with "_#".
  """

  def __init__(self, outDir):
    """Create a generator writing into ``outDir``.

    ``outDir`` is prepended verbatim to every output file name, so it must
    end with a path separator.  Installs the default grid (prepareConfig).
    """
    self.config = None
    self.__csHash = set()  # keys of CS sentences already emitted (see _csPairAt)
    self.__outputDir = outDir
    self.__csInstance = CSHandler()
    self.__dataHandler = DataHandler()
    self.__utils = Utils()
    self.__Tree = Dependencytree()
    self.__fileSuffix = ""  # appended to every output file name
    self.prepareConfig()

  @staticmethod
  def _uniformRanges(csVariants, sizes):
    """Map every CS type in ``csVariants`` to its own list of corpus ``sizes``."""
    return dict((csType, list(sizes)) for csType in csVariants)

  def prepareConfig(self):
    """Install the default grid: CS types 0-4, sizes 50..1000 (step 50),
    five pure/CS splits and the two tag-set variants."""
    csVariants = [0, 1, 2, 3, 4]
    self.config = GeneratingConfig()
    self.config.setCSVariants(csVariants)
    self.config.setDataRanges(self._uniformRanges(csVariants, range(50, 1001, 50)))
    self.config.setSplits([(50, 50), (60, 40), (70, 30), (80, 20), (90, 10)])
    self.config.setTagsetVariants([".uniq", ".uni"])

  def prepareGenerator(self):
    """Hand the data handler's two language-ID tags to the CS handler.

    NOTE(review): presumably has to run after loadData so that LID is set.
    """
    self.__csInstance.updateLIDTags(self.__dataHandler.LID[0], self.__dataHandler.LID[1])

  def prepareRealTest(self, dataFile, outFile):
    """Rewrite the "_#"-tagged file ``dataFile`` through mapLD2Uni into ``outFile``.

    Each whitespace token is split on "_#" into its fields, the token list of
    a line is passed to DataHandler.mapLD2Uni, and the result is written back
    one sentence per line in the same "_#"-joined format.
    """
    # `with` closes both files; the input file used to be left open.
    with open(dataFile) as source, open(outFile, 'w') as target:
      for rawLine in source:
        tokens = [token.split('_#') for token in rawLine.strip().split()]
        target.write(self.makeString(self.__dataHandler.mapLD2Uni(tokens)))

  def generateTestData(self):
    """Generate the test grid: sizes 30, 80 and 130 for every CS type.

    Side effect: permanently replaces ``config.dataRanges``.
    """
    # NOTE(review): the files use the same "TrainCSType..." names as the
    # training data; confirm outDir differs from the training run, otherwise
    # training files are overwritten.
    self.config.setDataRanges(self._uniformRanges(self.config.csVariants, range(30, 151, 50)))
    self._generate(constrained=False, label="Testing")

  def generateDataForTest(self):
    """Run generateTrainDataForTest ten times with file suffixes ".0" .. ".9"."""
    previousSuffix = self.__fileSuffix
    try:
      for i in range(10):
        self.__fileSuffix = "." + str(i)
        self.generateTrainDataForTest()
    finally:
      # Restore, so later generate* calls do not inherit the ".9" suffix.
      self.__fileSuffix = previousSuffix

  def generateTrainDataForTest(self):
    """Generate training files for a single size (data=450, 900 lines per file).

    Side effect: permanently replaces ``config.dataRanges``.
    """
    self.config.setDataRanges(self._uniformRanges(self.config.csVariants, [450]))
    self._generate(constrained=False, label="Training")

  def generateTrainData(self):
    """Generate training files for the grid currently held in ``config``.

    Every CS pair whose two switch orders both succeed is accepted; repeated
    CS sentences are allowed.
    """
    self._generate(constrained=False, label="Training")

  def generateUCTrainData(self): # Unknown words constrained training data
    """Like generateTrainData, but with the strict CS-pair filter.

    A pair is accepted only if neither CS sentence was emitted before (same
    hash key) and the word set of the two CS sentences equals the word set of
    the two monolingual sentences it was built from.
    """
    self._generate(constrained=True, label="Training")

  def _generate(self, constrained, label):
    """Write every (csType, size, split) cell of the current config.

    ``constrained`` selects the strict filter of _csPairAt; ``label`` only
    tags the error raised when the parallel corpus cannot supply enough pairs.
    """
    statusCount = 0
    for csType in self.config.csVariants:
      sys.stdout.write("type" + str(csType) + "\n")
      for data in self.config.dataRanges[csType]:
        sys.stdout.write(" numSents:" + str(data * 2))
        before = statusCount
        statusCount += self._generateSplits(csType, data, constrained, label)
        sys.stdout.write("\n")
        if statusCount // 50 > before // 50:  # progress every 50 splits
          sys.stdout.write(str(statusCount) + "\n")
        sys.stdout.flush()
    sys.stdout.write(str(statusCount) + "\n")

  def _generateSplits(self, csType, data, constrained, label):
    """Write the four files of every configured split for one corpus size.

    CS pairs are generated once, for the first split.  Later splits
    re-sample their share from that pool, so the pool must be at least as
    large as any later split needs (``config.splits`` ordered by
    non-increasing CS share).  Returns the number of splits written.

    Raises:
      ValueError: a later split needs more CS pairs than the pool holds.
      RuntimeError: propagated from _generateCSPairs.
    """
    dh = self.__dataHandler
    pool = []
    for splitIndex, split in enumerate(self.config.splits):
      pr, tr = self._splitSizes(split, data)
      sys.stdout.write(" Pure:" + str(2 * pr) + " CS:" + str(2 * tr))
      random.seed()  # reseed from system entropy/time before every draw
      indicesL1 = random.sample(range(len(dh.pureL1)), pr)
      indicesL2 = random.sample(range(len(dh.pureL2)), pr)
      pureData = [self._tagLine(dh.pureL1[i], dh.LID[0]) for i in indicesL1]
      pureData += [self._tagLine(dh.pureL2[i], dh.LID[1]) for i in indicesL2]
      csData = list(pureData)  # both files start with the same pure sentences

      if splitIndex == 0:
        pool = self._generateCSPairs(csType, tr, constrained, label)
        chosen = pool
      else:
        if tr > len(pool):
          raise ValueError("split %r needs %d CS pairs but the first split only generated %d; "
                           "order config.splits by decreasing CS share" % (split, tr, len(pool)))
        random.seed()
        chosen = random.sample(pool, tr)

      for csLine0, csLine1, pureLine1, pureLine2 in chosen:
        csData += [csLine0, csLine1]
        pureData += [pureLine1, pureLine2]
      self._writeSplit(csType, split, data, csData, pureData)
    return len(self.config.splits)

  @staticmethod
  def _splitSizes(split, data):
    """Return ``(pr, tr)``: monolingual sentences per language, and CS pairs.

    ``pr = floor(data * pure / (pure + cs))`` is computed exactly in integer
    arithmetic; the float form could land one too low (0.29 * 100 -> 28.99...).
    """
    pure = int(split[0] * data // (split[0] + split[1]))
    return pure, data - pure

  def _generateCSPairs(self, csType, count, constrained, label="Training"):
    """Collect ``count`` CS pairs by cycling over the parallel corpus.

    Returns a list of ``(csLine0, csLine1, pureLine1, pureLine2)`` tuples (see
    _csPairAt) and resets the used-sentence hash first.  The corpus is
    revisited as long as a pass yields at least one pair; a whole pass that
    yields nothing is treated as exhaustion.

    Raises:
      RuntimeError: fewer than ``count`` pairs could be collected.
    """
    self.__csHash = set()
    numParallel = len(self.__dataHandler.parL1)
    pairs = []
    foundAtPassStart = 0
    index = 0
    while len(pairs) < count:
      if index == numParallel:
        if len(pairs) == foundAtPassStart:
          # Without this check the loop would spin forever.
          raise RuntimeError("%s Break!! CS type %s: only %d of %d CS pairs found; a full pass "
                             "over %d parallel sentences added none"
                             % (label, csType, len(pairs), count, numParallel))
        sys.stdout.write("Still: " + str(count - len(pairs)) + "  Looping..\n")
        foundAtPassStart = len(pairs)
        index = 0
      pair = self._csPairAt(index, csType, constrained)
      if pair is not None:
        pairs.append(pair)
      index += 1
    return pairs

  def _csPairAt(self, index, csType, constrained):
    """Try to build one CS pair from parallel sentence ``index``.

    Both orders (0 and 1) must yield a CS sentence.  Returns
    ``(csLine0, csLine1, pureLine1, pureLine2)`` as tuples, or None.  With
    ``constrained`` the pair is also rejected when either CS sentence's hash
    key was used before, or when the CS word set differs from the word set of
    the two monolingual sentences.  Otherwise every complete pair is accepted.
    """
    dh = self.__dataHandler
    csLines = []
    hashKeys = []
    for order in (0, 1):
      self.__csInstance.updateHandler(dh.parL1[index], dh.parL2[index], dh.align[index], order)
      csReturn = self.__csInstance.csSentence(csType)
      if csReturn[0] != -1:  # csSentence reports failure with -1
        # NOTE(review): csReturn[1] is presumably the switch sequence.
        hashKeys.append((index, order, tuple(csReturn[1])))
        csLines.append(csReturn[0])
    if len(csLines) != 2:
      return None
    if constrained and any(key in self.__csHash for key in hashKeys):
      return None

    pureLine1 = self._treeLine(dh.parL1[index], dh.LID[0])
    pureLine2 = self._treeLine(dh.parL2[index], dh.LID[1])
    if constrained:
      csWords = set(tok[0] for tok in csLines[0]) | set(tok[0] for tok in csLines[1])
      pureWords = set(tok[0] for tok in pureLine1) | set(tok[0] for tok in pureLine2)
      if csWords != pureWords:
        return None

    self.__csHash.update(hashKeys)
    return (tuple(csLines[0]), tuple(csLines[1]), pureLine1, pureLine2)

  def _treeLine(self, parse, lid):
    """Return the word/tag sequence of ``parse`` (via Dependencytree), LID-tagged."""
    self.__Tree.updateTree(parse)
    return self._tagLine(self.__Tree.wordTags(), lid)

  def _tagLine(self, line, lid):
    """Add language tag ``lid`` to ``line``, apply makeLD and return a tuple."""
    return tuple(self.__dataHandler.makeLD(self.__dataHandler.addLangTags(line, lid)))

  def _outputPaths(self, csType, split, data):
    """Return the (control, cs, controlUni, csUni) file paths of one split."""
    stem = (self.__outputDir + "TrainCSType" + str(csType) + "CS" + str(split[1])
            + "Pure" + str(split[0]) + "Total" + str(2 * data))
    suffix = self.__fileSuffix
    return (stem + "_Control" + suffix,
            stem + suffix,
            stem + "_Control" + ".uni" + suffix,
            stem + ".uni" + suffix)

  def _writeSplit(self, csType, split, data, csData, pureData):
    """Write ``csData`` and ``pureData``, plain and mapLD2Uni'd, for one split."""
    controlPath, csPath, controlUniPath, csUniPath = self._outputPaths(csType, split, data)
    toUni = self.__dataHandler.mapLD2Uni
    # `with` guarantees the files are closed even if mapLD2Uni or a write raises.
    with open(csPath, 'w') as csFile, open(csUniPath, 'w') as csUniFile:
      for line in csData:
        csUniFile.write(self.makeString(toUni(line)))
        csFile.write(self.makeString(line))
    with open(controlPath, 'w') as pureFile, open(controlUniPath, 'w') as pureUniFile:
      for line in pureData:
        pureFile.write(self.makeString(line))
        pureUniFile.write(self.makeString(toUni(line)))

  def makeString(self, wordsTagsLangs):
    """Serialise a token sequence: fields joined by "_#", tokens by spaces, plus newline."""
    return ' '.join(map(lambda x:"_#".join(x), wordsTagsLangs)) + '\n'

  def loadData(self, l1Data, l2Data, l1Aligns, l2Aligns, pureL1Data, pureL2Data):
    """Forward all corpus sources to DataHandler.loadData."""
    self.__dataHandler.loadData(l1Data, l2Data, l1Aligns, l2Aligns, pureL1Data, pureL2Data)