예제 #1
0
def load_mnist(path=".", normalize=True):
    """
    Fetch the MNIST dataset and load it into memory.

    The gzipped pickle described by ``dataset_meta['mnist']`` is downloaded
    into ``path`` the first time it is needed and read from there afterwards.

    Args:
        path (str, optional): Local directory in which to cache the raw
                              dataset.  Defaults to current directory.
        normalize (bool, optional): whether to scale values between 0 and 1
                                    by dividing by 255.  Defaults to True.

    Returns:
        tuple: ``((X_train, y_train), (X_test, y_test), nclass)`` where each
        ``X`` array is reshaped to one flattened row of 784 values per image
        and ``nclass`` is the number of classes (always 10).
    """
    mnist_meta = dataset_meta['mnist']

    filepath = os.path.join(path, mnist_meta['file'])
    if not os.path.exists(filepath):
        fetch_dataset(mnist_meta['url'], mnist_meta['file'], filepath,
                      mnist_meta['size'])

    # NOTE(review): unpickling can execute arbitrary code; this file should
    # only ever come from a trusted source.
    with gzip.open(filepath, 'rb') as mnist_file:
        (X_train, y_train), (X_test, y_test) = cPickle.load(mnist_file)

    # Flatten every image into a single feature row (28 * 28 = 784 values).
    X_train = X_train.reshape(-1, 784)
    X_test = X_test.reshape(-1, 784)

    if normalize:
        # Assumes raw pixel values in [0, 255] -- TODO confirm.  The float
        # divisor also yields floating-point arrays from integer input.
        X_train = X_train / 255.
        X_test = X_test / 255.

    return (X_train, y_train), (X_test, y_test), 10
예제 #2
0
    def bleu_score(self, sents, targets):
        """
        Compute the BLEU score from a list of predicted sentences and reference sentences

        Predictions are written one per line to ``<self.path>/output``, and the
        i-th reference of every target is written to ``<self.path>/reference<i>``.
        ``multi-bleu.perl`` is downloaded into ``self.path`` if it is missing and is
        then run from inside that directory.  Its report goes to stdout; nothing is
        returned.  The caller's working directory is always restored.

        Args:
            sents (list): list of predicted sentences
            targets (list): list of reference sentences where each element is a list of
                            multiple references.  Every element is expected to hold
                            at least ``len(targets[0])`` references.
        """

        def strip_end_token(sent):
            # str.strip(token) would remove any *characters* of the token from
            # the ends (e.g. the trailing "s" of "is" for "</s>"), so remove the
            # whole token instead, repeatedly, from either end.
            token = self.end_token
            sent = sent.strip()
            if token:
                while sent.endswith(token):
                    sent = sent[:-len(token)].strip()
                while sent.startswith(token):
                    sent = sent[len(token):].strip()
            return sent

        num_ref = len(targets[0])
        output_file = os.path.join(self.path, 'output')
        reference_files = [
            os.path.join(self.path, 'reference%d' % i) for i in range(num_ref)
        ]
        bleu_script_url = 'https://raw.githubusercontent.com/karpathy/neuraltalk/master/eval/'
        bleu_script = 'multi-bleu.perl'

        print("Writing output and reference sents to dir %s" % self.path)

        with open(output_file, 'w') as output_f:
            for sent in sents:
                # split/join also collapses runs of whitespace to single spaces.
                output_f.write(" ".join(strip_end_token(sent).split()) + '\n')

        # One file per reference index; multi-bleu.perl picks up reference0,
        # reference1, ... from the "reference" stem passed on the command line.
        for i, reference_file in enumerate(reference_files):
            with open(reference_file, 'w') as reference_f:
                for target_sents in targets:
                    reference_f.write(target_sents[i] + '\n')

        bleu_command = 'perl %s reference < output' % bleu_script
        owd = os.getcwd()
        os.chdir(self.path)
        try:
            if not os.path.exists(bleu_script):
                fetch_dataset(bleu_script_url, bleu_script, bleu_script, 6e6)
            print("Executing bleu eval script:  %s" % bleu_command)
            os.system(bleu_command)
        finally:
            # Restore the caller's cwd even if the download or script fails.
            os.chdir(owd)
예제 #3
0
    def bleu_score(self, sents, targets):
        """
        Compute the BLEU score from a list of predicted sentences and reference sentences

        Predictions are written one per line to ``<self.path>/output``, and the
        i-th reference of every target is written to ``<self.path>/reference<i>``.
        ``multi-bleu.perl`` is downloaded into ``self.path`` if it is missing and is
        then run from inside that directory.  Its report goes to stdout; nothing is
        returned.  The caller's working directory is always restored.

        Args:
            sents (list): list of predicted sentences
            targets (list): list of reference sentences where each element is a list of
                            multiple references.  Every element is expected to hold
                            at least ``len(targets[0])`` references.
        """

        def strip_end_token(sent):
            # str.strip(token) would remove any *characters* of the token from
            # the ends (e.g. the trailing "s" of "is" for "</s>"), so remove the
            # whole token instead, repeatedly, from either end.
            token = self.end_token
            sent = sent.strip()
            if token:
                while sent.endswith(token):
                    sent = sent[:-len(token)].strip()
                while sent.startswith(token):
                    sent = sent[len(token):].strip()
            return sent

        num_ref = len(targets[0])
        output_file = os.path.join(self.path, "output")
        reference_files = [os.path.join(self.path, "reference%d" % i) for i in range(num_ref)]
        bleu_script_url = "https://raw.githubusercontent.com/karpathy/neuraltalk/master/eval/"
        bleu_script = "multi-bleu.perl"

        print("Writing output and reference sents to dir %s" % self.path)

        with open(output_file, "w") as output_f:
            for sent in sents:
                # split/join also collapses runs of whitespace to single spaces.
                output_f.write(" ".join(strip_end_token(sent).split()) + "\n")

        # One file per reference index; multi-bleu.perl picks up reference0,
        # reference1, ... from the "reference" stem passed on the command line.
        for i, reference_file in enumerate(reference_files):
            with open(reference_file, "w") as reference_f:
                for target_sents in targets:
                    reference_f.write(target_sents[i] + "\n")

        bleu_command = "perl %s reference < output" % bleu_script
        owd = os.getcwd()
        os.chdir(self.path)
        try:
            if not os.path.exists(bleu_script):
                fetch_dataset(bleu_script_url, bleu_script, bleu_script, 6e6)
            print("Executing bleu eval script:  %s" % bleu_command)
            os.system(bleu_command)
        finally:
            # Restore the caller's cwd even if the download or script fails.
            os.chdir(owd)