class ReadMnistCsv(object):
    """
    Class to read MNIST data from CSV files into a LearningExamples container.

    Every row is split on ','; the first field is used as the digit label and
    every remaining field is parsed as a float and divided by 255.0.

    see:
    http://pjreddie.com/projects/mnist-in-csv/
    http://makeyourownneuralnetwork.blogspot.com/2015/03/the-mnist-dataset-of-handwitten-digits.html
    """

    # One-hot target vector for each digit label: position i is 1 for digit i.
    CLASSES = {
        0: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        1: [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        2: [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        3: [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        4: [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        5: [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        6: [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        7: [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        8: [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        9: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    }

    def __init__(self, input_file):
        """
        Parse ``input_file`` and store one example per non-blank row in
        ``self.examples``.

        :param input_file: path to the MNIST CSV file to read
        :raises KeyError: if a row's label is not one of the keys of CLASSES
        :raises ValueError: if a label or pixel field is not numeric
        """
        self.examples = LearningExamples()

        cnt = 0
        # The context manager closes the file even if parsing raises, and
        # iterating the handle avoids holding the whole CSV in memory.
        with open(input_file, 'r') as mnist:
            for line in mnist:
                # Blank lines (e.g. a trailing empty line) carry no example.
                if not line.strip():
                    continue

                fields = line.split(',')
                # NOTE(review): the same list object from CLASSES is handed to
                # add_data for every example of a given digit; presumably
                # add_data does not mutate it -- confirm.
                classes = self.CLASSES[int(fields[0])]
                # float() tolerates the newline left on the last field.
                pixels = [float(value) / 255.0 for value in fields[1:]]
                self.examples.add_data(pixels, classes)

                cnt += 1
                if cnt % 10000 == 0:
                    # Periodic progress report for large input files.
                    log.info('line %d: len pixels = %d, len fields = %d',
                             cnt, len(pixels), len(fields))

    def write_file(self, file_path):
        """
        Write the loaded examples out by delegating to
        ``self.examples.write_to``.

        :param file_path: path to desired output file
        :return: None
        """
        self.examples.write_to(file_path)
ann_input = ann_input.resize((scaled_width, scaled_height), Image.ANTIALIAS) ann_input_data = list(ann_input.getdata()) ann_output_data = list(ann_output.getdata()) flattened_ann_input_data = [] flattened_ann_output_data = [] for pixel in ann_output_data: flattened_ann_output_data.append(flatten_and_rescale_pixel(pixel)) for pixel in ann_input_data: flattened_ann_input_data.append(flatten_and_rescale_pixel(pixel)) if len(flattened_ann_input_data) != ann_num_inputs: log.error('len(flattened_ann_input_data) [%d] != ann_num_inputs [%d]', len(flattened_ann_input_data), ann_num_inputs) if len(flattened_ann_output_data) != ann_num_outputs: log.error('len(flattened_ann_output_data) [%d] != ann_num_outputs [%d]', len(flattened_ann_output_data), ann_num_outputs) learn_examples.add_data(flattened_ann_input_data, flattened_ann_output_data) # temp_string = '' # for pixel in flattened_ann_input_data: # temp_string += '%7.6f ' % pixel # for pixel in flattened_ann_output_data: # temp_string += '%7.6f ' % pixel # output_path.write('%s\n' % temp_string) except Exception, exception_obj: log.error("error getting subimage from %s, skipping: %s", options.image_dir + '/' + input_filename, str(exception_obj)) cnt -= 1 continue learn_examples.write_to(options.output_path) except Exception, except_obj: log.exception('image file %s: %s', options.image_dir + '/' + input_filename, str(except_obj))