示例#1
0
文件: main.py 项目: ryoma-jp/samples
def main():
	"""Train and evaluate one model on one dataset, both chosen on the command line.

	Reads ``data_type``, ``dataset_dir`` and ``model_type`` from ``ArgParser()``
	and echoes them.
	Loads MNIST or CIFAR-10, scales the train/test images by 1/255, trains the
	requested trainer (MLP, CNN or ResNet) with output written to ``./output``,
	and prints the shape of the predictions on the test images.

	Exits via ``sys.exit`` with status 1 (message on stderr) when ``data_type``
	or ``model_type`` is not recognised.
	"""
	import sys

	# --- Argument handling ---
	args = ArgParser()
	print(args.data_type)
	print(args.dataset_dir)
	print(args.model_type)

	# --- Dataset loading ---
	# Both loaders are used through the same attributes below, so only the class differs.
	if (args.data_type == "MNIST"):
		dataset = DataLoaderMNIST(args.dataset_dir)
	elif (args.data_type == "CIFAR-10"):
		dataset = DataLoaderCIFAR10(args.dataset_dir)
	else:
		# sys.exit(msg) prints msg to stderr and exits with status 1; quit() is an
		# interactive-shell helper and, called bare, would exit with status 0.
		sys.exit('[ERROR] Unknown data_type: {}'.format(args.data_type))

	print(dataset.train_images.shape)
	print(dataset.train_labels.shape)
	print(dataset.test_images.shape)
	print(dataset.test_labels.shape)

	# Scale images by 1/255 (presumably 8-bit pixel values -- confirm against the loaders).
	x_train = dataset.train_images / 255
	y_train = dataset.train_labels
	x_test = dataset.test_images / 255
	y_test = dataset.test_labels
	output_dims = dataset.output_dims

	# --- Model selection ---
	# shape[1:] drops the leading axis (presumably the sample axis).
	if (args.model_type == 'MLP'):
		trainer = TrainerMLP(dataset.train_images.shape[1:], output_dir='./output')
	elif (args.model_type == 'CNN'):
		trainer = TrainerCNN(dataset.train_images.shape[1:], output_dir='./output')
	elif (args.model_type == 'ResNet'):
		# Only the ResNet trainer is given output_dims explicitly.
		trainer = TrainerResNet(dataset.train_images.shape[1:], output_dims, output_dir='./output')
	else:
		sys.exit('[ERROR] Unknown model_type: {}'.format(args.model_type))

	# --- Training and inference (identical for every model type) ---
	trainer.fit(x_train, y_train, x_test=x_test, y_test=y_test)

	predictions = trainer.predict(x_test)
	print('\nPredictions(shape): {}'.format(predictions.shape))

	return
示例#2
0
文件: main.py 项目: ryoma-jp/samples
def main():
    """Train, save and evaluate a model configured from command-line arguments.

    Steps:
      1. Parse arguments with ``ArgParser()`` and echo them.
      2. Turn ``args.data_augmentation`` (one comma-separated CSV row, or None)
         into a dict keyed by augmentation parameter name, or None.
      3. Load MNIST or CIFAR-10 with ``validation_split=0.2`` and scale each
         available image array by 1/255 (arrays that are None stay None).
      4. Build the requested trainer, fit it, save the model and print the
         shape of the predictions on the test images.

    Exits via ``sys.exit`` with status 1 (message on stderr) when ``data_type``
    or ``model_type`` is not recognised.
    """
    import sys

    # --- Print an ndarray's shape; None is silently skipped ---
    def print_ndarray_shape(ndarr):
        if (ndarr is not None):
            print(ndarr.shape)
        return

    # --- Scale an image array by 1/255; None passes through unchanged ---
    def scale_images(images):
        if (images is None):
            return None
        return images / 255

    # --- Argument handling ---
    args = ArgParser()
    print('[INFO] Arguments')
    print('  * args.data_type = {}'.format(args.data_type))
    print('  * args.dataset_dir = {}'.format(args.dataset_dir))
    print('  * args.model_type = {}'.format(args.model_type))
    print('  * args.data_augmentation = {}'.format(args.data_augmentation))
    print('  * args.optimizer = {}'.format(args.optimizer))
    print('  * args.batch_size = {}'.format(args.batch_size))
    print('  * args.initializer = {}'.format(args.initializer))
    print('  * args.result_dir = {}'.format(args.result_dir))

    # --- Convert the data-augmentation parameters to a dict ---
    if (args.data_augmentation is not None):
        dict_keys = [
            'rotation_range', 'width_shift_range', 'height_shift_range',
            'horizontal_flip'
        ]
        # Parse the argument as a headerless CSV and keep only its first row;
        # values are paired with dict_keys by position (zip stops at the shorter).
        df_da_params = pd.read_csv(io.StringIO(args.data_augmentation),
                                   header=None,
                                   skipinitialspace=True).values[0]
        data_augmentation = dict(zip(dict_keys, df_da_params))
    else:
        # Must be bound on this path too: it is passed to trainer.fit() below.
        data_augmentation = None

    # --- Dataset loading ---
    # Both loaders are used through the same attributes below, so only the class differs.
    if (args.data_type == "MNIST"):
        dataset = DataLoaderMNIST(args.dataset_dir, validation_split=0.2)
    elif (args.data_type == "CIFAR-10"):
        dataset = DataLoaderCIFAR10(args.dataset_dir, validation_split=0.2)
    else:
        # sys.exit(msg) prints msg to stderr and exits with status 1; quit() is an
        # interactive-shell helper and, called bare, would exit with status 0.
        sys.exit('[ERROR] Unknown data_type: {}'.format(args.data_type))

    print_ndarray_shape(dataset.train_images)
    print_ndarray_shape(dataset.train_labels)
    print_ndarray_shape(dataset.validation_images)
    print_ndarray_shape(dataset.validation_labels)
    print_ndarray_shape(dataset.test_images)
    print_ndarray_shape(dataset.test_labels)

    x_train = scale_images(dataset.train_images)
    y_train = dataset.train_labels
    x_val = scale_images(dataset.validation_images)
    y_val = dataset.validation_labels
    x_test = scale_images(dataset.test_images)
    y_test = dataset.test_labels
    output_dims = dataset.output_dims

    # --- Model selection ---
    # shape[1:] drops the leading axis (presumably the sample axis).
    if (args.model_type == 'MLP'):
        trainer = TrainerMLP(dataset.train_images.shape[1:],
                             output_dir=args.result_dir,
                             optimizer=args.optimizer,
                             initializer=args.initializer)
    elif (args.model_type == 'SimpleCNN'):
        trainer = TrainerCNN(dataset.train_images.shape[1:],
                             output_dir=args.result_dir,
                             optimizer=args.optimizer,
                             initializer=args.initializer)
    elif (args.model_type == 'SimpleResNet'):
        # Only the ResNet trainer is given output_dims explicitly.
        trainer = TrainerResNet(dataset.train_images.shape[1:],
                                output_dims,
                                output_dir=args.result_dir,
                                optimizer=args.optimizer,
                                initializer=args.initializer)
    else:
        sys.exit('[ERROR] Unknown model_type: {}'.format(args.model_type))

    # --- Training, saving and inference (identical for every model type) ---
    trainer.fit(x_train,
                y_train,
                x_val=x_val,
                y_val=y_val,
                x_test=x_test,
                y_test=y_test,
                batch_size=args.batch_size,
                da_params=data_augmentation)
    trainer.save_model()

    predictions = trainer.predict(x_test)
    print('\nPredictions(shape): {}'.format(predictions.shape))

    return
示例#3
0
文件: main.py 项目: ryoma-jp/samples
def main():
    """Train, save and evaluate a model configured from command-line arguments.

    Parses and echoes the arguments from ``ArgParser()``.
    Loads MNIST or CIFAR-10 with ``validation_split=0.2`` and scales each
    available image array by 1/255 (arrays that are None stay None).
    Builds the requested trainer and fits it with
    ``da_enable=args.data_augmentation``. Finally saves the model and prints
    the shape of the predictions on the test images.

    Exits via ``sys.exit`` with status 1 (message on stderr) when ``data_type``
    or ``model_type`` is not recognised.
    """
    import sys

    # --- Print an ndarray's shape; None is silently skipped ---
    def print_ndarray_shape(ndarr):
        if (ndarr is not None):
            print(ndarr.shape)
        return

    # --- Scale an image array by 1/255; None passes through unchanged ---
    def scale_images(images):
        if (images is None):
            return None
        return images / 255

    # --- Argument handling ---
    args = ArgParser()
    print('[INFO] Arguments')
    print('  * args.data_type = {}'.format(args.data_type))
    print('  * args.dataset_dir = {}'.format(args.dataset_dir))
    print('  * args.model_type = {}'.format(args.model_type))
    print('  * args.data_augmentation = {}'.format(args.data_augmentation))
    print('  * args.optimizer = {}'.format(args.optimizer))
    print('  * args.result_dir = {}'.format(args.result_dir))

    # --- Dataset loading ---
    # Both loaders are used through the same attributes below, so only the class differs.
    if (args.data_type == "MNIST"):
        dataset = DataLoaderMNIST(args.dataset_dir, validation_split=0.2)
    elif (args.data_type == "CIFAR-10"):
        dataset = DataLoaderCIFAR10(args.dataset_dir, validation_split=0.2)
    else:
        # sys.exit(msg) prints msg to stderr and exits with status 1; quit() is an
        # interactive-shell helper and, called bare, would exit with status 0.
        sys.exit('[ERROR] Unknown data_type: {}'.format(args.data_type))

    print_ndarray_shape(dataset.train_images)
    print_ndarray_shape(dataset.train_labels)
    print_ndarray_shape(dataset.validation_images)
    print_ndarray_shape(dataset.validation_labels)
    print_ndarray_shape(dataset.test_images)
    print_ndarray_shape(dataset.test_labels)

    x_train = scale_images(dataset.train_images)
    y_train = dataset.train_labels
    x_val = scale_images(dataset.validation_images)
    y_val = dataset.validation_labels
    x_test = scale_images(dataset.test_images)
    y_test = dataset.test_labels
    output_dims = dataset.output_dims

    # --- Model selection ---
    # shape[1:] drops the leading axis (presumably the sample axis).
    if (args.model_type == 'MLP'):
        trainer = TrainerMLP(dataset.train_images.shape[1:],
                             output_dir=args.result_dir,
                             optimizer=args.optimizer)
    elif (args.model_type == 'SimpleCNN'):
        trainer = TrainerCNN(dataset.train_images.shape[1:],
                             output_dir=args.result_dir,
                             optimizer=args.optimizer)
    elif (args.model_type == 'SimpleResNet'):
        # Only the ResNet trainer is given output_dims explicitly.
        trainer = TrainerResNet(dataset.train_images.shape[1:],
                                output_dims,
                                output_dir=args.result_dir,
                                optimizer=args.optimizer)
    else:
        sys.exit('[ERROR] Unknown model_type: {}'.format(args.model_type))

    # --- Training, saving and inference (identical for every model type) ---
    trainer.fit(x_train,
                y_train,
                x_val=x_val,
                y_val=y_val,
                x_test=x_test,
                y_test=y_test,
                da_enable=args.data_augmentation)
    trainer.save_model()

    predictions = trainer.predict(x_test)
    print('\nPredictions(shape): {}'.format(predictions.shape))

    return
示例#4
0
def main():
	"""Train, save and evaluate a model configured from command-line arguments.

	Steps:
	  1. Parse arguments with ``ArgParser()`` and echo them.
	  2. Turn ``args.data_augmentation`` (one comma-separated CSV row, or None)
	     into a dict keyed by augmentation parameter name, or None.
	  3. Choose one-hot labels unless ``loss_func`` is
	     ``sparse_categorical_crossentropy``.
	  4. Load MNIST or CIFAR-10 with ``validation_split=0.2``.
	  5. Normalise the images via ``dataset.normalization(args.data_norm)``.
	  6. Build one of MLP / SimpleCNN / DeepCNN / SimpleResNet / DeepResNet.
	  7. Fit the model, save it and print the shape of the test predictions.

	Exits via ``sys.exit`` with status 1 (message on stderr) when ``data_type``
	or ``model_type`` is not recognised.
	"""
	import sys

	# --- Print an ndarray's shape; None is silently skipped ---
	def print_ndarray_shape(ndarr):
		if (ndarr is not None):
			print(ndarr.shape)
		return

	# --- Argument handling ---
	args = ArgParser()
	print('[INFO] Arguments')
	print('  * args.fifo = {}'.format(args.fifo))
	print('  * args.data_type = {}'.format(args.data_type))
	print('  * args.dataset_dir = {}'.format(args.dataset_dir))
	print('  * args.model_type = {}'.format(args.model_type))
	print('  * args.data_augmentation = {}'.format(args.data_augmentation))
	print('  * args.optimizer = {}'.format(args.optimizer))
	print('  * args.batch_size = {}'.format(args.batch_size))
	print('  * args.initializer = {}'.format(args.initializer))
	print('  * args.data_norm = {}'.format(args.data_norm))
	print('  * args.dropout_rate = {}'.format(args.dropout_rate))
	print('  * args.loss_func = {}'.format(args.loss_func))
	print('  * args.epochs = {}'.format(args.epochs))
	print('  * args.result_dir = {}'.format(args.result_dir))

	# --- Convert the data-augmentation parameters to a dict ---
	if (args.data_augmentation is not None):
		dict_keys = ['rotation_range', 'width_shift_range', 'height_shift_range', 'zoom_range', 'channel_shift_range', 'horizontal_flip']
		# Parse the argument as a headerless CSV and keep only its first row;
		# values are paired with dict_keys by position (zip stops at the shorter).
		df_da_params = pd.read_csv(io.StringIO(args.data_augmentation), header=None, skipinitialspace=True).values[0]
		data_augmentation = dict(zip(dict_keys, df_da_params))
	else:
		# Must be bound on this path too: it is passed to trainer.fit() below.
		data_augmentation = None

	# Labels are one-hot encoded unless the loss is sparse_categorical_crossentropy.
	one_hot = (args.loss_func != "sparse_categorical_crossentropy")

	# --- Dataset loading ---
	if (args.data_type == "MNIST"):
		dataset = DataLoaderMNIST(args.dataset_dir, validation_split=0.2, one_hot=one_hot)
	elif (args.data_type == "CIFAR-10"):
		dataset = DataLoaderCIFAR10(args.dataset_dir, validation_split=0.2, one_hot=one_hot)
	else:
		# sys.exit(msg) prints msg to stderr and exits with status 1; quit() is an
		# interactive-shell helper and, called bare, would exit with status 0.
		sys.exit('[ERROR] Unknown data_type: {}'.format(args.data_type))

	print_ndarray_shape(dataset.train_images)
	print_ndarray_shape(dataset.train_labels)
	print_ndarray_shape(dataset.validation_images)
	print_ndarray_shape(dataset.validation_labels)
	print_ndarray_shape(dataset.test_images)
	print_ndarray_shape(dataset.test_labels)

	# The loader returns the (train, validation, test) image arrays normalised per args.data_norm.
	x_train, x_val, x_test = dataset.normalization(args.data_norm)
	y_train = dataset.train_labels
	y_val = dataset.validation_labels
	y_test = dataset.test_labels
	output_dims = dataset.output_dims

	# --- Model selection ---
	# shape[1:] drops the leading axis (presumably the sample axis).
	if (args.model_type == 'MLP'):
		trainer = TrainerMLP(dataset.train_images.shape[1:], output_dir=args.result_dir,
			optimizer=args.optimizer, initializer=args.initializer)
	elif (args.model_type == 'SimpleCNN'):
		trainer = TrainerCNN(dataset.train_images.shape[1:], output_dir=args.result_dir,
			optimizer=args.optimizer, loss=args.loss_func, initializer=args.initializer)
	elif (args.model_type == 'DeepCNN'):
		trainer = TrainerCNN(dataset.train_images.shape[1:], output_dir=args.result_dir,
			optimizer=args.optimizer, loss=args.loss_func, initializer=args.initializer, model_type='deep_model')
	elif (args.model_type == 'SimpleResNet'):
		trainer = TrainerResNet(dataset.train_images.shape[1:], output_dims, output_dir=args.result_dir,
			model_type='custom', 
			optimizer=args.optimizer, loss=args.loss_func, initializer=args.initializer, dropout_rate=args.dropout_rate)
	elif (args.model_type == 'DeepResNet'):
		trainer = TrainerResNet(dataset.train_images.shape[1:], output_dims, output_dir=args.result_dir,
			model_type='custom_deep', 
			optimizer=args.optimizer, loss=args.loss_func, initializer=args.initializer, dropout_rate=args.dropout_rate)
	else:
		sys.exit('[ERROR] Unknown model_type: {}'.format(args.model_type))

	# --- Training, saving and inference (identical for every model type) ---
	# args.fifo is forwarded as fit()'s first positional argument; its meaning is defined by the trainer.
	trainer.fit(args.fifo, x_train, y_train, x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,
		batch_size=args.batch_size, da_params=data_augmentation, epochs=args.epochs)
	trainer.save_model()

	predictions = trainer.predict(x_test)
	print('\nPredictions(shape): {}'.format(predictions.shape))

	return