def Load(loc, wordEmb, emb_size=50):
    """Read a text file and map each word to its embedding vector.

    Each line is lower-cased, has "'d" rewritten to "ed", and has all ASCII
    punctuation stripped before being split on spaces. Words found in
    `wordEmb` contribute their vector; unknown words contribute a zero
    vector of length `emb_size` and are counted as exceptions.

    Args:
        loc: path of the text file to read.
        wordEmb: dict-like mapping word -> embedding vector.
        emb_size: length of the zero vector used for unknown words.

    Returns:
        np.ndarray of shape (num_words, emb_size).
    """
    res = []
    # Translation table that deletes every ASCII punctuation character.
    table = str.maketrans(dict.fromkeys(string.punctuation))
    exceptions = 0
    # Context manager guarantees the file is closed even on error.
    with open(loc, 'r') as f:
        for line in f:
            if len(line) < 1:
                continue
            # Remove punctuation. Should I handle with 'd -> ed???
            line = line.replace("'d", "ed").translate(table).lower().rstrip()
            for elem in line.split(" "):
                if elem == "":
                    continue
                try:
                    res.append(wordEmb[elem])
                except KeyError:
                    # Word not in the embedding table: fall back to zeros.
                    # logger.PrintDebug("Word not found! %s"%elem,col="r")
                    res.append(np.zeros(emb_size))
                    exceptions += 1
    r = np.array(res)
    logger.PrintDebug("Excepted word count : %d" % exceptions, col="r")
    logger.PrintDebug("Loaded word size : %s" % (str(r.shape)), col="g")
    return r
def _glove6B(sz=50, logDiv=100000):
    """Parse a GloVe-6B text file into a word -> vector dict.

    Args:
        sz: embedding dimension; must be one of 50, 100, 200, 300.
        logDiv: emit a progress log every `logDiv` words (<= 0 disables).

    Returns:
        dict mapping word -> np.ndarray of shape (sz,); empty dict when
        `sz` is not a supported size.
    """
    d = dict()
    if sz not in (50, 100, 200, 300):
        logger.PrintDebug(
            "Not supported word embedding size, returning empty dict")
        return d
    path = "data/glove6B/glove.6B." + str(sz) + "d.txt"
    logger.PrintDebug("Reading embedded word file from " + path)
    # Context manager closes the file even if parsing stops early.
    with open(path, 'r', encoding="utf-8") as f:
        for line in f:
            parts = line.replace("\n", "").split(" ")
            # Stop at the first malformed row (wrong field count),
            # matching the original early-exit behavior.
            if len(parts) != sz + 1:
                break
            # Vectorized float parse instead of a per-element Python loop.
            d[parts[0]] = np.array(parts[1:], dtype=np.float64)
            # Print log
            if logDiv > 0 and len(d) % logDiv == 0:
                logger.PrintDebug(" Data Count = " + str(len(d)))
    logger.PrintDebug("Loaded " + str(len(d)) + " embedded word data")
    return d
def LoadRawData(loc):
    """Load a gzipped MNIST IDX file (images or labels).

    The IDX magic number in the first 4 bytes selects the format:
    2051 = images, 2049 = labels.

    Args:
        loc: path to a .gz IDX file.

    Returns:
        Images: np.ndarray (nImages, 1, rows, cols), float64 in [0, 1].
        Labels: np.ndarray (nImages,), int32.
        -1 when the magic number matches neither format.
    """
    logger.PrintDebug("Loading data from " + loc)
    with gzip.open(loc, 'rb') as f:
        raw = f.read()
    magic = int.from_bytes(raw[:4], byteorder='big')
    if magic == 2051:
        # This is image file
        nImages = int.from_bytes(raw[4:8], byteorder='big')
        rows = int.from_bytes(raw[8:12], byteorder='big')
        cols = int.from_bytes(raw[12:16], byteorder='big')
        # Vectorized: one C-level pass over the pixel payload instead of a
        # per-pixel Python triple loop (same float64 values in [0, 1]).
        pixels = np.frombuffer(
            raw, dtype=np.uint8, count=nImages * rows * cols, offset=16)
        res = pixels.reshape(nImages, 1, rows, cols).astype(np.float64) / 255.0
        logger.PrintDebug("Total " + str(nImages) + " Images loaded")
    elif magic == 2049:
        # This is label file
        nImages = int.from_bytes(raw[4:8], byteorder='big')
        res = np.frombuffer(
            raw, dtype=np.uint8, count=nImages, offset=8).astype(np.int32)
        logger.PrintDebug("Total " + str(nImages) + " Labels loaded")
    else:
        logger.PrintDebug("File is not a MNIST image or label file.", col="r")
        return -1
    return res
def DownloadGlove6B():
    """Download the GloVe-6B archive and unpack it into data/glove6B/.

    Side effects: creates ./data/ if missing, downloads ~800MB, extracts,
    then deletes the zip archive.
    """
    # Create directory if not exists
    if not os.path.exists("./data/"):
        os.mkdir("./data/")
    logger.PrintDebug("Downloading glove6B...")
    p = "./data/glove6B.zip"
    # Bug fix: the URL literal had a stray leading space, which makes
    # urlretrieve raise ValueError("unknown url type").
    request.urlretrieve("http://nlp.stanford.edu/data/glove.6B.zip", p)
    logger.PrintDebug("Unpacking glove6B...")
    with zipfile.ZipFile(p, "r") as zr:
        zr.extractall("data/glove6B/")
    os.remove(p)
def LoadData(loc, rawloc):
    """Return data cached at `loc`.npy, building it from `rawloc` on a miss.

    On a cache hit the .npy file is loaded and a log line is printed; on a
    miss the raw file is parsed via LoadRawData and the result is saved.
    """
    cache = loc + ".npy"
    if os.path.exists(cache):
        res = np.load(cache)
        logger.PrintDebug("Loaded " + cache)
    else:
        res = LoadRawData(rawloc)
        np.save(loc, res)
    return res
def glove6B(sz=50, logDiv=100000):
    """Return GloVe-6B embeddings of dimension `sz`, cached via pickle.

    First call parses the text file with _glove6B and pickles the dict to
    data/glove6B/glove6B<sz>.dump; later calls load the pickle directly.

    Args:
        sz: embedding dimension, forwarded to _glove6B.
        logDiv: progress-log interval, forwarded to _glove6B.

    Returns:
        dict mapping word -> embedding vector.
    """
    loc = "data/glove6B/glove6B%s" % str(sz) + ".dump"
    if not os.path.exists(loc):
        res = _glove6B(sz, logDiv)
        # `with` closes the handle deterministically instead of leaking it
        # (the original passed a bare open() into pickle.dump/load).
        with open(loc, 'wb') as f:
            pickle.dump(res, f)
    else:
        with open(loc, 'rb') as f:
            res = pickle.load(f)
        logger.PrintDebug("Loaded " + loc)
    return res
def Forward(self, input_image, input_word):
    """Forward pass: CNN stack, concat stack, word-embedding stack, then
    linear stack; logs each layer's output shape.

    The image tensor is run through every layer in self.cnn; the output of
    layer index 10 is saved and later paired with the final CNN output as
    a (tap, final) tuple fed to the concat layers — presumably a skip
    connection (TODO confirm against the layer definitions).

    NOTE(review): assumes len(self.cnn) > 10, otherwise input_image_c4 is
    never bound and the tuple construction raises NameError — verify.
    """
    for i in range(len(self.cnn)):
        input_image = self.cnn[i].Forward(input_image)
        if i == 10:
            # Save the activation after CNN layer 10 for later reuse.
            input_image_c4 = input_image
        logger.PrintDebug("CNN " + str(i) + " " + str(input_image.shape))
    # Pair the tapped activation with the final CNN output.
    input_image = (input_image_c4, input_image)
    for i in range(len(self.concat)):
        input_image = self.concat[i].Forward(input_image)
        logger.PrintDebug("Concat " + str(i) + " " + str(input_image.shape))
    for i in range(len(self.emb)):
        input_word = self.emb[i].Forward(input_word)
        logger.PrintDebug("Embedding " + str(i) + " " + str(input_word.shape))
    # Linear stage consumes both modalities as an (image, word) tuple.
    input_linear = (input_image, input_word)
    for i in range(len(self.linear)):
        input_linear = self.linear[i].Forward(input_linear)
        logger.PrintDebug("Linear " + str(i) + " " + str(input_linear.shape))
    return input_linear
def Download():
    """Clear ./data/, then download and unpack Flickr8k (text + images)
    and the GloVe-6B embeddings.
    """
    # Delete existing file
    logger.PrintDebug("Clearing Folder...")
    dataRoot = "./data/"
    # Bottom-up walk so directories are empty by the time they are removed.
    for root, dirNames, fileNames in os.walk(dataRoot, topdown=False):
        for fileName in fileNames:
            os.remove(os.path.join(root, fileName))
        for dirName in dirNames:
            os.rmdir(os.path.join(root, dirName))
    # Create directory if not exists
    if not os.path.exists(dataRoot):
        os.mkdir(dataRoot)

    def _fetch(label, url, archive):
        # Download one zip archive, unpack it into data/Flickr8k/, delete it.
        logger.PrintDebug("Downloading " + label + "...")
        request.urlretrieve(url, archive)
        logger.PrintDebug("Unpacking " + label + "...")
        with zipfile.ZipFile(archive, "r") as zr:
            zr.extractall("data/Flickr8k/")
        os.remove(archive)

    _fetch("Flickr8k Text",
           "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip",
           "./data/Flickr8k_text.zip")
    _fetch("Flickr8k Image",
           "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip",
           "./data/Flickr8k_Dataset.zip")
    DownloadGlove6B()
def Flickr8k(mode="test", logDiv=2000, imageSize=(64, 64)):
    """Load Flickr8k images and their caption tokens for one split.

    Args:
        mode: split name used in the list-file name ("train", "test", ...).
        logDiv: progress-log interval in images (<= 0 disables).
        imageSize: (width, height) each image is resized to. Default is now
            an immutable tuple (was a mutable list default); indexing
            behavior is unchanged.

    Returns:
        (data, dataL): list of np.ndarray images and, per image, the list
        of its caption strings.
    """
    data = list()
    dataL = list()

    # Load image list
    imageList = list()
    path = "data/Flickr8k/Flickr_8k." + mode + "Images.txt"
    logger.PrintDebug("Loading " + mode + " image list from " + path)
    # `with` closes the handle (the original leaked both file objects).
    with open(path, 'r', encoding="utf-8") as f:
        for line in f:
            line = line.replace("\n", "")
            if not line:
                break
            imageList.append(line)
    logger.PrintDebug("Total of " + str(len(imageList)) + " " + mode +
                      " image list loaded")

    # Load token
    tokens = dict()
    path = "data/Flickr8k/Flickr8k.token.txt"
    logger.PrintDebug("Loading tokens from " + path)
    with open(path, 'r', encoding="utf-8") as f:
        for line in f:
            line = line.replace("\n", "")
            if not line:
                break
            parts = line.split("\t")
            # Keys look like "image.jpg#0"; strip the "#N" caption index.
            tokens.setdefault(parts[0][:-2], []).append(parts[1])
    logger.PrintDebug(str(len(tokens)) + " image tokens are loaded")

    # Load images and tokens
    imgDir = 'data/Flickr8k/Flicker8k_Dataset/'
    # Bug fix: the original logged the token-file path here, not the
    # directory the images are actually read from.
    logger.PrintDebug("Loading " + mode + " images from " + imgDir)
    for ind, val in enumerate(imageList):
        # Load image
        data.append(
            np.asarray(
                Image.open(imgDir + val).resize(
                    (imageSize[0], imageSize[1]))))
        # Load tag
        dataL.append(tokens[val])
        # Print log
        if logDiv > 0 and (ind + 1) % logDiv == 0:
            logger.PrintDebug(" Data Count = " + str(ind + 1))
    logger.PrintDebug("Total of " + str(len(data)) + " " + mode +
                      " images loaded")
    return data, dataL
continue try: res.append(wordEmb[elem]) except: # logger.PrintDebug("Word not found! %s"%elem,col="r") res.append(np.zeros(emb_size)) exceptions += 1 f.close() r = np.array(res) logger.PrintDebug("Excepted word count : %d" % exceptions, col="r") logger.PrintDebug("Loaded word size : %s" % (str(r.shape)), col="g") return r if __name__ == "__main__": logger.PrintDebug("Simple RNN Network Training", col='b') # Parse args args = args.parse() # Numpy options np.set_printoptions(precision=3) np.set_printoptions(suppress=True) # np.set_printoptions(threshold=sys.maxsize) # Download data if not exists locData = "./data_tinyshakespeare/" url = "http://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" locFile = "input.txt" if not os.path.exists(locData): os.mkdir(locData) if not os.path.exists(locData + locFile):
return -1 return res def LoadData(loc, rawloc): if not os.path.exists(loc + ".npy"): res = LoadRawData(rawloc) np.save(loc, res) else: res = np.load(loc + ".npy") logger.PrintDebug("Loaded " + loc + ".npy") return res if __name__ == "__main__": logger.PrintDebug("Simple CNN Network Training", col='b') # Parse args args = args.parse() # Numpy options np.set_printoptions(precision=3) np.set_printoptions(suppress=True) # np.set_printoptions(threshold=sys.maxsize) # Download data if not exists locData = "./data_mnist/" url = "http://yann.lecun.com/exdb/mnist/" locTrainLabel = "train-labels-idx1-ubyte.gz" locTrainData = "train-images-idx3-ubyte.gz" locTestLabel = "t10k-labels-idx1-ubyte.gz" locTestData = "t10k-images-idx3-ubyte.gz"
import args import model import load import image from log import logger import numpy as np from model import Model if __name__ == "__main__": logger.PrintDebug("Initiated") # parse args args = args.parse() # Numpy set precision np.set_printoptions(precision=3) # Download Data if argument is checked or falied to verify if args.download or not load.Verify(): load.Download() dataTrain = []; dataTrainLabel = [] # load data to memory # dataTrain, dataTrainLabel = load.Flickr8k(mode = "train") dataTest, dataTestLabel = load.Flickr8k(mode = "test") # Normalize images for i in range(len(dataTest)): dataTest[i] = image.Normalize(dataTest[i]) for i in range(len(dataTrain)): dataTrain[i] = image.Normalize(dataTrain[i]) logger.PrintDebug("Image Normalized")
def Verify():
    """Report whether the downloaded data looks valid.

    NOTE(review): currently a stub — it logs two messages and always
    returns True without inspecting any files.
    """
    for message in ("Verifing Data...", "Data Verified!"):
        logger.PrintDebug(message)
    return True