예제 #1
0
def Load(loc, wordEmb, emb_size=50):
    """Convert a text file into an array of word embeddings.

    Every line of the file has ``'d`` replaced by ``ed``, all characters in
    ``string.punctuation`` removed, is lower-cased, stripped of trailing
    whitespace and split on single spaces.  Each non-empty token is looked up
    in ``wordEmb``; a token that cannot be found contributes a zero vector of
    length ``emb_size`` instead.

    Args:
        loc: Path of the text file to read.
        wordEmb: Lookup table indexed as ``wordEmb[token]``.
        emb_size: Length of the zero vector used for unknown tokens.

    Returns:
        ``np.array`` built from the collected vectors, in reading order.
    """
    # Translation table that deletes every ASCII punctuation character.
    table = str.maketrans(dict.fromkeys(string.punctuation))
    res = []
    exceptions = 0
    # The context manager closes the file even if a lookup raises.
    with open(loc, 'r') as f:
        for line in f:
            # Remove punctuation ('d is expanded to "ed" before stripping).
            line = line.replace("'d", "ed").translate(table).lower().rstrip()
            for elem in line.split(" "):
                if not elem:
                    continue
                try:
                    res.append(wordEmb[elem])
                except LookupError:
                    # Unknown word: substitute a zero vector and count it.
                    # Only lookup failures are handled so that unrelated
                    # errors (e.g. KeyboardInterrupt) are not swallowed.
                    res.append(np.zeros(emb_size))
                    exceptions += 1
    r = np.array(res)
    logger.PrintDebug("Excepted word count : %d" % exceptions, col="r")
    logger.PrintDebug("Loaded word size : %s" % (str(r.shape)), col="g")
    return r
예제 #2
0
def _glove6B(sz=50, logDiv=100000):
    """Parse a GloVe 6B text file into a ``{word: vector}`` dictionary.

    Args:
        sz: Embedding dimension; must be one of 50, 100, 200 or 300.
        logDiv: A progress message is logged whenever the number of loaded
            words is a multiple of this value.  Values <= 0 disable it.

    Returns:
        Dict mapping each word to a ``np.ndarray`` of ``sz`` floats.  An
        empty dict is returned for an unsupported ``sz``.  Parsing stops at
        the first line that does not have exactly ``sz + 1`` space-separated
        fields.
    """
    d = dict()
    if sz not in (50, 100, 200, 300):
        logger.PrintDebug(
            "Not supported word embedding size, returning empty dict")
        return d
    path = "data/glove6B/glove.6B." + str(sz) + "d.txt"
    logger.PrintDebug("Reading embedded word file from " + path)
    # The context manager guarantees the file is closed even when a value
    # fails to parse as float.
    with open(path, 'r', encoding="utf-8") as f:
        for line in f:
            fields = line.replace("\n", "").split(" ")
            if len(fields) != sz + 1:
                # Malformed or trailing line: stop reading.
                break
            # First field is the word, the remaining ``sz`` are the vector.
            d[fields[0]] = np.array([float(v) for v in fields[1:]])
            # Print log
            if logDiv > 0 and len(d) % logDiv == 0:
                logger.PrintDebug("  Data Count = " + str(len(d)))
    logger.PrintDebug("Loaded " + str(len(d)) + " embedded word data")
    return d
예제 #3
0
def LoadRawData(loc):
    """Decode a gzip-compressed MNIST image or label file.

    The first four bytes (big-endian) select the format:

    * 2051 - image file.  Bytes 4-16 hold the image count, row count and
      column count; the pixels follow as one unsigned byte each.
    * 2049 - label file.  Bytes 4-8 hold the label count; the labels follow
      as one unsigned byte each.

    Args:
        loc: Path of the ``.gz`` file.

    Returns:
        For images, a float64 array of shape ``(nImages, 1, rows, cols)``
        with pixel values divided by 255.  For labels, an int32 array of
        shape ``(nImages,)``.  ``-1`` if the magic number is neither value.

    Raises:
        ValueError: If the payload is shorter than the header declares.
    """
    logger.PrintDebug("Loading data from " + loc)
    with gzip.open(loc, 'rb') as f:
        raw = f.read()

    magic = int.from_bytes(raw[:4], byteorder='big')
    # memoryview slices do not copy the (potentially large) payload.
    payload = memoryview(raw)

    if magic == 2051:
        # This is image file
        nImages = int.from_bytes(raw[4:8], byteorder='big')
        rows = int.from_bytes(raw[8:12], byteorder='big')
        cols = int.from_bytes(raw[12:16], byteorder='big')
        count = nImages * rows * cols
        pixels = np.frombuffer(payload[16:16 + count], dtype=np.uint8)
        if pixels.size != count:
            raise ValueError("Image file is truncated: " + loc)
        # One vectorised pass replaces the per-pixel Python loop; dividing a
        # uint8 array by a float yields a new float64 array.
        res = pixels.reshape((nImages, 1, rows, cols)) / 255.0
        logger.PrintDebug("Total " + str(nImages) + " Images loaded")
    elif magic == 2049:
        # This is label file
        nImages = int.from_bytes(raw[4:8], byteorder='big')
        labels = np.frombuffer(payload[8:8 + nImages], dtype=np.uint8)
        if labels.size != nImages:
            raise ValueError("Label file is truncated: " + loc)
        res = labels.astype(np.int32)
        logger.PrintDebug("Total " + str(nImages) + " Labels loaded")
    else:
        logger.PrintDebug("File is not a MNIST image or label file.",
                          col="r")
        return -1
    return res
예제 #4
0
def DownloadGlove6B():
    """Download the GloVe 6B archive and unpack it into ``data/glove6B/``.

    The archive is saved as ``./data/glove6B.zip``, extracted, and then
    deleted.  ``./data/`` is created first if it does not exist.
    """
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs("./data/", exist_ok=True)

    logger.PrintDebug("Downloading glove6B...")
    p = "./data/glove6B.zip"
    # The URL previously started with a stray space character.
    request.urlretrieve("http://nlp.stanford.edu/data/glove.6B.zip", p)
    logger.PrintDebug("Unpacking glove6B...")
    with zipfile.ZipFile(p, "r") as zr:
        zr.extractall("data/glove6B/")
    # The archive is no longer needed once extracted.
    os.remove(p)
예제 #5
0
def LoadData(loc, rawloc):
    """Return the array for ``rawloc``, cached on disk as ``loc + ".npy"``.

    Args:
        loc: Cache path without the ``.npy`` suffix.
        rawloc: Path of the raw file handed to ``LoadRawData`` on a cache
            miss.

    Returns:
        The cached array if the cache file exists, otherwise whatever
        ``LoadRawData(rawloc)`` returns.
    """
    cache = loc + ".npy"
    if os.path.exists(cache):
        res = np.load(cache)
        logger.PrintDebug("Loaded " + cache)
        return res

    res = LoadRawData(rawloc)
    # LoadRawData returns -1 for an unrecognised file.  Only real arrays are
    # cached so that a failed parse is not served from disk on later runs.
    if isinstance(res, np.ndarray):
        # Saving to the exact path that is checked above keeps the cache
        # consistent even when ``loc`` already ends in ".npy".
        np.save(cache, res)
    return res
예제 #6
0
def glove6B(sz=50, logDiv=100000):
    """Return the GloVe 6B embedding dict, using a pickle cache on disk.

    On the first call for a given ``sz`` the text file is parsed with
    ``_glove6B`` and the result is pickled to
    ``data/glove6B/glove6B<sz>.dump``; later calls load that file.

    Args:
        sz: Embedding dimension, forwarded to ``_glove6B``.
        logDiv: Progress-log interval, forwarded to ``_glove6B``.

    Returns:
        The dictionary produced by ``_glove6B`` (or its cached copy).
    """
    loc = "data/glove6B/glove6B%s" % str(sz) + ".dump"
    if os.path.exists(loc):
        # NOTE(review): unpickling runs arbitrary code from the file; this is
        # only safe while the dump is produced locally by this function.
        with open(loc, 'rb') as f:
            res = pickle.load(f)
        logger.PrintDebug("Loaded " + loc)
        return res

    res = _glove6B(sz, logDiv)
    # Context managers make sure the handles are flushed and closed.
    with open(loc, 'wb') as f:
        pickle.dump(res, f)
    return res
예제 #7
0
 def Forward(self, input_image, input_word):
     """Run the image and word branches and feed both into the linear stack.

     The image is passed through every layer of ``self.cnn``; the output of
     the layer at index 10 is kept and paired with the final CNN output as
     the input of ``self.concat``.  The word input is passed through
     ``self.emb``.  The pair ``(image result, word result)`` is then passed
     through ``self.linear`` and the final value is returned.

     Note: if ``self.cnn`` has fewer than 11 layers the kept tensor is never
     bound and building the pair raises ``UnboundLocalError``.
     """
     # Image branch: convolutional stack.
     for idx, layer in enumerate(self.cnn):
         input_image = layer.Forward(input_image)
         if idx == 10:
             input_image_c4 = input_image
         logger.PrintDebug(
             " ".join(("CNN", str(idx), str(input_image.shape))))

     # Merge the kept intermediate output with the final CNN output.
     input_image = (input_image_c4, input_image)
     for idx, layer in enumerate(self.concat):
         input_image = layer.Forward(input_image)
         logger.PrintDebug(
             " ".join(("Concat", str(idx), str(input_image.shape))))

     # Word branch: embedding stack.
     for idx, layer in enumerate(self.emb):
         input_word = layer.Forward(input_word)
         logger.PrintDebug(
             " ".join(("Embedding", str(idx), str(input_word.shape))))

     # Joint head over both branches.
     input_linear = (input_image, input_word)
     for idx, layer in enumerate(self.linear):
         input_linear = layer.Forward(input_linear)
         logger.PrintDebug(
             " ".join(("Linear", str(idx), str(input_linear.shape))))
     return input_linear
예제 #8
0
def Download():
    """Wipe ``./data/`` and download the Flickr8k and GloVe 6B datasets.

    Everything inside ``./data/`` is deleted first.  The Flickr8k text and
    image archives are then downloaded into ``./data/``, extracted into
    ``data/Flickr8k/`` and removed, after which ``DownloadGlove6B`` is
    called.
    """
    # Function-scope import: only this function needs shutil.
    import shutil

    top = "./data/"

    # Delete existing file
    logger.PrintDebug("Clearing Folder...")
    if os.path.isdir(top):
        # Snapshot the entries before deleting so the directory is not
        # modified while it is being scanned.
        with os.scandir(top) as scan:
            entries = list(scan)
        for entry in entries:
            if entry.is_dir(follow_symlinks=False):
                shutil.rmtree(entry.path)
            else:
                # Plain files and symlinks (including links to directories).
                os.remove(entry.path)

    # Create directory if not exists
    os.makedirs(top, exist_ok=True)

    # Download Flickr8k: both archives follow the same fetch/unpack/remove
    # sequence and share a release URL prefix.
    base = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/"
    archives = (
        ("Text", "Flickr8k_text.zip"),
        ("Image", "Flickr8k_Dataset.zip"),
    )
    for label, archive in archives:
        logger.PrintDebug("Downloading Flickr8k " + label + "...")
        p = top + archive
        request.urlretrieve(base + archive, p)
        logger.PrintDebug("Unpacking Flickr8k " + label + "...")
        with zipfile.ZipFile(p, "r") as zr:
            zr.extractall("data/Flickr8k/")
        os.remove(p)

    DownloadGlove6B()
예제 #9
0
def Flickr8k(mode="test", logDiv=2000, imageSize=[64, 64]):
    """Load one split of Flickr8k as resized images plus their captions.

    Args:
        mode: Split name; the image list is read from
            ``data/Flickr8k/Flickr_8k.<mode>Images.txt``.
        logDiv: A progress message is logged every ``logDiv`` images.
            Values <= 0 disable it.
        imageSize: Two-element sequence passed to ``Image.resize`` as
            ``(imageSize[0], imageSize[1])``.  It is only read, never
            modified.

    Returns:
        ``(data, dataL)``: ``data`` is a list with one array per image and
        ``dataL`` is a list with, for each image, the list of captions found
        for it in ``Flickr8k.token.txt``.

    Raises:
        KeyError: If a listed image has no entry in the token file.
    """
    data = list()
    dataL = list()

    # Load image list (reading stops at the first empty line).
    imageList = list()
    path = "data/Flickr8k/Flickr_8k." + mode + "Images.txt"
    logger.PrintDebug("Loading " + mode + " image list from " + path)
    with open(path, 'r', encoding="utf-8") as f:
        for raw in f:
            line = raw.replace("\n", "")
            if not line:
                break
            imageList.append(line)
    logger.PrintDebug("Total of " + str(len(imageList)) + " " + mode +
                      " image list loaded")

    # Load token
    tokens = dict()
    path = "data/Flickr8k/Flickr8k.token.txt"
    logger.PrintDebug("Loading tokens from " + path)
    with open(path, 'r', encoding="utf-8") as f:
        for raw in f:
            line = raw.replace("\n", "")
            if not line:
                break
            fields = line.split("\t")
            # The first field ends with a two-character suffix that is
            # dropped to obtain the key used in the image list.
            tokens.setdefault(fields[0][:-2], []).append(fields[1])
    logger.PrintDebug(str(len(tokens)) + " image tokens are loaded")

    # Load images and tokens
    imageDir = 'data/Flickr8k/Flicker8k_Dataset/'
    # The message previously printed the token-file path here.
    logger.PrintDebug("Loading " + mode + " images from " + imageDir)
    for ind, val in enumerate(imageList):
        # Load image; the context manager releases the file handle as soon
        # as the resized copy has been converted.
        with Image.open(imageDir + val) as img:
            data.append(
                np.asarray(img.resize((imageSize[0], imageSize[1]))))
        # Load tag
        dataL.append(tokens[val])
        # Print log
        if logDiv > 0 and (ind + 1) % logDiv == 0:
            logger.PrintDebug("  Data Count = " + str(ind + 1))
    logger.PrintDebug("Total of " + str(len(data)) + " " + mode +
                      " images loaded")

    return data, dataL
예제 #10
0
                continue
            try:
                res.append(wordEmb[elem])
            except:
                # logger.PrintDebug("Word not found! %s"%elem,col="r")
                res.append(np.zeros(emb_size))
                exceptions += 1
    f.close()
    r = np.array(res)
    logger.PrintDebug("Excepted word count : %d" % exceptions, col="r")
    logger.PrintDebug("Loaded word size : %s" % (str(r.shape)), col="g")
    return r


if __name__ == "__main__":
    logger.PrintDebug("Simple RNN Network Training", col='b')
    # Parse args
    args = args.parse()

    # Numpy options
    np.set_printoptions(precision=3)
    np.set_printoptions(suppress=True)
    # np.set_printoptions(threshold=sys.maxsize)

    # Download data if not exists
    locData = "./data_tinyshakespeare/"
    url = "http://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    locFile = "input.txt"
    if not os.path.exists(locData):
        os.mkdir(locData)
    if not os.path.exists(locData + locFile):
예제 #11
0
            return -1
    return res


def LoadData(loc, rawloc):
    """Return the array for ``rawloc``, cached on disk as ``loc + ".npy"``.

    Args:
        loc: Cache path without the ``.npy`` suffix.
        rawloc: Path of the raw file handed to ``LoadRawData`` on a cache
            miss.

    Returns:
        The cached array if the cache file exists, otherwise whatever
        ``LoadRawData(rawloc)`` returns.
    """
    cache = loc + ".npy"
    if os.path.exists(cache):
        res = np.load(cache)
        logger.PrintDebug("Loaded " + cache)
        return res

    res = LoadRawData(rawloc)
    # LoadRawData returns -1 for an unrecognised file.  Only real arrays are
    # cached so that a failed parse is not served from disk on later runs.
    if isinstance(res, np.ndarray):
        # Saving to the exact path that is checked above keeps the cache
        # consistent even when ``loc`` already ends in ".npy".
        np.save(cache, res)
    return res


if __name__ == "__main__":
    # Script entry point (this excerpt ends before the download/training
    # code that uses the names defined below).
    logger.PrintDebug("Simple CNN Network Training", col='b')
    # Parse args
    # NOTE(review): this rebinds the name `args` to the parse result, so the
    # object `parse` was called on is no longer reachable under that name.
    args = args.parse()

    # Numpy options: print floats with 3 decimals and without scientific
    # notation.
    np.set_printoptions(precision=3)
    np.set_printoptions(suppress=True)
    # np.set_printoptions(threshold=sys.maxsize)

    # Download data if not exists
    # Local data folder, remote base URL, and the four gzip archive names
    # (train/test labels and images).  Presumably joined with `url` and
    # `locData` by the code that follows this excerpt - TODO confirm.
    locData = "./data_mnist/"
    url = "http://yann.lecun.com/exdb/mnist/"
    locTrainLabel = "train-labels-idx1-ubyte.gz"
    locTrainData = "train-images-idx3-ubyte.gz"
    locTestLabel = "t10k-labels-idx1-ubyte.gz"
    locTestData = "t10k-images-idx3-ubyte.gz"
예제 #12
0
import args
import model
import load
import image
from log import logger
import numpy as np
from model import Model

if __name__ == "__main__":
    logger.PrintDebug("Initiated")

    # Parse the command line; the result replaces the imported ``args``
    # module under the same name.
    args = args.parse()

    # Print NumPy floats with three digits of precision.
    np.set_printoptions(precision=3)

    # Download when explicitly requested or when verification fails.
    if args.download or not load.Verify():
        load.Download()

    # Only the test split is loaded into memory; the train containers stay
    # empty (loading the "train" split is currently disabled).
    dataTrain, dataTrainLabel = [], []
    dataTest, dataTestLabel = load.Flickr8k(mode="test")

    # Normalize images in place, test split first and then train split.
    for batch in (dataTest, dataTrain):
        for i, picture in enumerate(batch):
            batch[i] = image.Normalize(picture)
    logger.PrintDebug("Image Normalized")
예제 #13
0
def Verify():
    """Log the start and end of data verification and report success.

    No check is performed on the data: two messages are logged and ``True``
    is always returned.
    """
    for message in ("Verifing Data...", "Data Verified!"):
        logger.PrintDebug(message)
    return True