Ejemplo n.º 1
0
class DataGenerator:
  def __init__(self, outDir, l1MapFile=None, l2MapFile=None):
    """Set up containers, generation settings and the tag/phrase lookup tables.

    outDir    -- prefix for every generated file; it is concatenated directly
                 with the file name, so it should end with a path separator.
    l1MapFile -- optional path to the first language's tagset mapping file;
                 defaults to the original French (fr-paris) mapping.
    l2MapFile -- optional path to the second language's tagset mapping file;
                 defaults to the original English (en-ptb) mapping.

    The constructor reads the mapping files and the phrase map immediately,
    so those files must exist when the object is created.
    """
    sys.stderr.write("DataGenerator: Constructor\n")
    ## Languages and Order
    self.__LID = ["FR","EN"]
    # The map files used to be hard-coded; they stay as defaults so existing
    # callers are unaffected, but can now be overridden per instance.
    if l2MapFile is None:
      l2MapFile = "/usr0/home/pgadde/Work/CodeSwitching/FrenchEnglish/NewsCommentary/E0/UniversalMapping/en-ptb.map"
    if l1MapFile is None:
      l1MapFile = "/usr0/home/pgadde/Work/CodeSwitching/FrenchEnglish/NewsCommentary/E0/UniversalMapping/fr-paris.map"
    self.__l2MapFile = l2MapFile
    self.__l1MapFile = l1MapFile
    ## Data containers (filled by loadData)
    self.__L1 = []
    self.__L2 = []
    self.__align = []
    self.__outputDir = outDir
    self.__posMap = {}              # source tag -> universal tag
    self.__phraseMap = dd(list)     # first field of a phrase-map line -> list of mapped values
    self.__csInstance = CSHandler()
    self.__utils = Utils()
    ## Generation Variants
    self.__csVariants = [0,1,2]     # values passed to CSHandler.csSentence
    self.__tagsetVariants = ["",".uni"]
    self.__dataRange = range(52,1000,52)  # training-set sizes iterated by the train generator
    ##LID stuff
    self.__L1Tags = set()
    self.__L2Tags = set()
    self.__commonTags = set()
    ## Pre processing: must run after the attributes above are in place
    self.__genPosMap()
    self.__genPhraseMap()
    self.__csInstance.updatePhraseMap(self.__phraseMap)
 
  def loadData(self, l1Data, l2Data, aligns):
    self.__L1 = [l.strip() for l in open(l1Data)]
    self.__L2 = [l.strip() for l in open(l2Data)]
    self.__align = [l.strip() for l in open(aligns)]
  
  def __genTestData(self, testIndices):
    for csType in self.__csVariants:
      #for tag in self.__tagsetVariants:
      dataFile = open(self.__outputDir+"TestCSType"+str(csType),'w')
      dataFileUni = open(self.__outputDir+"TestCSType"+str(csType)+".uni",'w')
      stopLength = 5129
      for index in testIndices:
        order = stopLength%2
        self.__csInstance.updateHandler(self.__L1[index], self.__L2[index], self.__align[index], order)
        csReturn = self.__csInstance.csSentence(csType)
        csLine = csReturn[0]
        #csSequence = csReturn[1]
        if csLine != -1:
          stopLength -= 1
        else:
          continue
        self.__addLangTags(csLine)
        csLineUni = self.__map2Uni(csLine)
        dataFile.write(' '.join(map(lambda x:'_#'.join(x), csLine))+'\n')
        dataFileUni.write(' '.join(map(lambda x:'_#'.join(x), csLineUni))+'\n')
        if stopLength == 0:
          break
      dataFile.close()
      dataFileUni.close()
      if stopLength != 0:
        print "Test Break!!", 5129, stopLength
        dummy = raw_input()
  
  def __genTrainData(self, pureIndices, csIndices):
    statusCount = 0
    for data in self.__dataRange:
      pr = 0
      while 1:
      #for pr in range(3):
        if pr == 3:
          break
        pr= int(pr*1.0/2 * data)
        tr = data - pr
        pr = pr/2
        random.seed()
        pIndices = random.sample(pureIndices, pr)
        cIndices = random.sample(csIndices, tr*5)
        for csType in self.__csVariants:
          print csType
          # Debugging !!
          #switch = ""
          #############
          #for tag in self.__tagsetVariants:
            # Debugging !!
            #if switch == "yes":
            #    break
            ###################
            #sys.stderr.write(outputDir+"Train"+cs+str(len(trainVariants[tr]))+"Pure"+str(len(pureVariants[pr]))+tag+"\n")
          dataFile = open(self.__outputDir+"TrainCSType"+str(csType)+"CS"+str(tr)+"Pure"+str(pr*2),'w')
          dataFileUni = open(self.__outputDir+"TrainCSType"+str(csType)+"CS"+str(tr)+"Pure"+str(pr*2)+".uni",'w')
          for index in pIndices:
            l1Line = self.__utils.wordTags(self.__L1[index])
            l2Line = self.__utils.wordTags(self.__L2[index])
            self.__addLangTags(l1Line)
            self.__addLangTags(l2Line)
            l1LineUni = self.__map2Uni(l1Line)
            l2LineUni = self.__map2Uni(l2Line)
            dataFile.write(' '.join(map(lambda x:'_#'.join(x), l1Line))+'\n')
            dataFile.write(' '.join(map(lambda x:'_#'.join(x), l2Line))+'\n')
            dataFileUni.write(' '.join(map(lambda x:'_#'.join(x), l1LineUni))+'\n')
            dataFileUni.write(' '.join(map(lambda x:'_#'.join(x), l2LineUni))+'\n')
          stopLength = tr
          for index in cIndices:
            csLine = ""
            order = stopLength%2
            #print order
            self.__csInstance.updateHandler(self.__L1[index], self.__L2[index], self.__align[index], order)
            csReturn = self.__csInstance.csSentence(csType)
            # Debugging !!                         
            #sys.stderr.write("Switch to another CS variant?? ")
            #switch = raw_input()
            #if switch == "yes":
            #    break
            ###############
            csLine = csReturn[0]
            #csSequence = csReturn[1]
            if csLine != -1:
              stopLength -= 1
            else:
              continue
            self.__addLangTags(csLine)
            csLineUni = self.__map2Uni(csLine)
            dataFile.write(' '.join(map(lambda x:'_#'.join(x), csLine))+'\n')
            dataFileUni.write(' '.join(map(lambda x:'_#'.join(x), csLineUni))+'\n')
            if stopLength == 0:
              break
          dataFile.close()
          dataFileUni.close()
          if stopLength != 0:
            print tr, stopLength, "Training Break!!"
            pr -= 1
            #dummy = raw_input()
          statusCount += 1
          if statusCount%50 == 0:
            print statusCount,
            sys.stdout.flush()
        pr += 1
    print statusCount
    
  def __addLangTags(self, wordTags):
    #print self.__L1Tags
    #print self.__L2Tags
    #print wordTags
    for index in range(len(wordTags)):
      tag = wordTags[index][1]
      lang = ""
      if tag in self.__commonTags:
        lang = "C"
      elif tag in self.__L1Tags:
        lang = self.__LID[0]
      elif tag in self.__L2Tags:
        lang = self.__LID[1]
      if lang == "":
        print "Something wrong with the tagsets in the function add_lang"
        dummy = raw_input()
      wordTags[index].append(lang)
  
  def __genPosMap(self):
    for i in open(self.__l1MapFile):
      i = i.strip()
      srcTag = [i.split()[0]]
      uniTag = i.split()[1]
      if srcTag[0].find('|') >= 0:
        srcTag = srcTag[0].split('|')
      for tag in srcTag:
        self.__posMap[tag] = uniTag
    for i in open(self.__l2MapFile):
      i = i.strip()
      srcTag = [i.split()[0]]
      uniTag = i.split()[1]
      if srcTag[0].find('|') >= 0:
        srcTag = srcTag[0].split('|')
      for tag in srcTag:
        self.__posMap[tag] = uniTag
        
    
    self.__L1Tags = set()
    for line in open(self.__l1MapFile):
      tags = line.split()[0].split('|')
      for tag in tags:
        self.__L1Tags.add(tag)
    for line in open(self.__l2MapFile):
      tags = line.split()[0].split('|')
      for tag in tags:
        self.__L2Tags.add(tag)
    self.__commonTags = set([c for c in self.__L1Tags if c in self.__L2Tags])
  
  def __map2Uni(self, wordTagsLangs):
    newLine = []
    for index in range(len(wordTagsLangs)):
      newLine.append(wordTagsLangs[index])
      tag = wordTagsLangs[index][1]
      try:
        newLine[index][1] = self.__posMap[tag]
      except:
        dummy = raw_input("Something wrong.. Couldn't find Uni Map\n")
    return newLine
  
  def __genPhraseMap(self):
    phraseMapFile = open("/usr0/home/pgadde/Work/CodeSwitching/FrenchEnglish/NewsCommentary/E17/mapping")
    for i in phraseMapFile:
      i = i.strip()
      self.__phraseMap[i.split()[0]].extend(i.split()[1].split(","))
    
  def __randomSample(self):
    print "Random Sample Train"
    totalLines = 95129
    testLines = 15129
    testIndices = random.sample(range(totalLines),testLines)
    #print testIndices
    trainIndices = []
    for i in range(totalLines):
      if i not in testIndices:
        trainIndices.append(i)
    csIndices = random.sample(trainIndices, 6000)
    remaining = []
    for i in trainIndices:
      if i not in csIndices:
        remaining.append(i)
    pureIndices = random.sample(remaining,6000)
    return trainIndices, testIndices, pureIndices, csIndices

  def __getRanges(self, dataRanges):
    dataRanges = open(dataRanges)
    trainIndices = []
    testIndices = []
    pureIndices = []
    csIndices = []
    for i in dataRanges:
        if i.split(":")[0] == "trainIndices":
            indices = i.split(":")[1].strip("[]\n").replace(" ","").split(",")
            for j in  indices:
                j = int(j)
                trainIndices.append(j)
        elif i.split(":")[0] == "testIndices":
            indices = i.split(":")[1].strip("[]\n").replace(" ","").split(",")
            for j in  indices:
                j = int(j)
                testIndices.append(j)
        elif i.split(":")[0] == "pureIndices":
            indices = i.split(":")[1].strip("[]\n").replace(" ","").split(",")
            for j in  indices:
                j = int(j)
                pureIndices.append(j)
        elif i.split(":")[0] == "csIndices":
            indices = i.split(":")[1].strip("[]\n").replace(" ","").split(",")
            for j in  indices:
                j = int(j)
                csIndices.append(j)
    #print len(testIndices), len(trainIndices), len(pureIndices), len(csIndices)
    #************************************
    # Un-comment this when sampling again
    #************************************
    '''rangesFile = open(dataRanges,'w')
    rangesFile.write("trainIndices:"+str(trainIndices)+"\n")
    rangesFile.write("testIndices:"+str(testIndices)+"\n")
    rangesFile.write("pureIndices:"+str(pureIndices)+"\n")
    rangesFile.write("csIndices:"+str(csIndices)+"\n")
    rangesFile.close()
    return self.__randomSample()'''
    return trainIndices, testIndices, pureIndices, csIndices
    
  def generateData(self, dataRanges):    
    ranges = self.__getRanges(dataRanges)
    #unmappedEnglish = open("/usr0/home/pgadde/Work/CodeSwitching/FrenchEnglish/NewsCommentary/E17/unmapped")
    #unmappedEnPhrases = [i.strip() for i in unmappedEnglish]
    #trainIndices = ranges[0]
    testIndices = ranges[1]
    pureIndices = ranges[2]
    csIndices = ranges[3]
    self.__genTrainData(pureIndices, csIndices)
    self.__genTestData(testIndices)