Пример #1
0

class OverfitGenerator(object):
    """Batch source that hands back one fixed, pre-generated batch.

    Offers the same ``generate_batch(batch_size)`` call shape as a regular
    generator, but every call returns the module-level ``overfit_inputs`` and
    ``overfit_targets`` pair.
    """

    @staticmethod
    def generate_batch(batch_size):
        """Return the fixed ``(overfit_inputs, overfit_targets)`` tuple.

        ``batch_size`` is accepted only to keep the call signature; it is not
        used. Both globals are resolved at call time, so they merely have to
        be defined before the first call.
        """
        fixed_batch = (overfit_inputs, overfit_targets)
        return fixed_batch


# Generator settings common to the training and validation batches; only the
# file pattern differs between the two calls below.
_generator_options = dict(
    image_size=64,
    max_flow=5,
    max_scale=5,
    noise_level=5,
    interp='bicubic',
)

# The single batch (batch_size=100, drawn from data/train) that
# OverfitGenerator.generate_batch returns on every call.
overfit_inputs, overfit_targets = DataGenerator(
    pattern='data/train/*', **_generator_options
).generate_batch(batch_size=100)

# Held-out batch built the same way from data/test.
validation_data = DataGenerator(
    pattern='data/test/*', **_generator_options
).generate_batch(batch_size=100)

cnn = CNN(
    split=False,
    normalize=True,
    fully_connected=None,
    learning_rate=1e-4,
)

# The class itself is passed as the batch source: generate_batch is a
# staticmethod, so no instance is required.
# NOTE(review): the third argument is presumably the number of training
# steps -- confirm against CNN.train.
cnn.train(OverfitGenerator, validation_data, 1000)
Пример #2
0
import datetime

from models.generator import DataGenerator
from models.cnn import CNN

# Generator settings common to the training generator and the validation
# batch; only the file pattern differs between the two.
_generator_options = dict(
    image_size=64,
    max_flow=5,
    max_scale=5,
    noise_level=5,
    interp='bicubic',
)

# Kept as a generator object so cnn.train can request batches from it.
train_generator = DataGenerator(pattern='data/train/*', **_generator_options)

# Validation data is generated once, up front, as a single batch.
validation_data = DataGenerator(
    pattern='data/test/*', **_generator_options
).generate_batch(batch_size=100)

# Timestamped log directory name, e.g. '.logs/20240131-235959/'.
log_path = datetime.datetime.now().strftime('.logs/%Y%m%d-%H%M%S/')

cnn = CNN(
    split=True,
    fully_connected=500,
    normalize=True,
    learning_rate=2e-4,
)

# NOTE(review): the third argument is presumably the number of training
# steps -- confirm against CNN.train.
cnn.train(train_generator, validation_data, 5000, log_path=log_path)
Пример #3
0
            except IndexError:
                print(f"Supply a pretrained model with -p. Exiting.")
                exit(0)

        # Validate the requested model name against the known `models` mapping
        # before doing any work.
        if selected_model not in models.keys():
            print(f"Model unknown: {selected_model}. Exiting.")
            # NOTE(review): exit status 0 signals success to the caller even
            # though this is an error path -- confirm whether that is intended.
            exit(0)

        else:
            print(f"Running model: {selected_model}")

            # `models` maps a model name to a kind string; pick the wrapper
            # class that matches the kind ("CNN" or "RNN").
            model = None
            if models.get(selected_model) == "CNN":
                model = CNN(selected_model, save_name=save_name, pretrained_model=pretrained_model)

            elif models.get(selected_model) == "RNN":
                model = RNN(selected_model, save_name=save_name, pretrained_model=pretrained_model)

            # Proceed only if a wrapper was built; any other kind string leaves
            # `model` as None and nothing further happens.
            if model is not None:

                # Train only when no pretrained model was supplied; evaluation
                # runs in both cases.
                if pretrained_model is None:
                    model.train()
                model.evaluate()

                # Export the model for importing later (saves training from
                # scratch); done only when a save name was given.
                if save_name is not None:
                    model.export_model()
Пример #4
0
from models.generator import DataGenerator
from models.cnn import CNN

# Generator settings common to the training generator and the validation
# batch; the two differ only in which image they read.
_generator_options = dict(
    image_size=64,
    max_flow=5,
    max_scale=1,
    noise_level=0,
)

# Training batches come from a single image.
train_generator = DataGenerator(pattern='data/train/garden.png',
                                **_generator_options)

# Validation uses a different image, generated once as one batch of 1000.
validation_data = DataGenerator(
    pattern='data/train/carpet.png', **_generator_options
).generate_batch(batch_size=1000)

# One row per convolution layer: [filters, kernel size, stride].
convolution_layers = [
    [32, 3, 2],
    [64, 3, 2],
    [128, 3, 1],
    [256, 3, 2],
]

cnn = CNN(
    config=convolution_layers,
    split=False,
    normalize=False,
    fully_connected=None,
    learning_rate=1e-4,
)

# NOTE(review): the third argument is presumably the number of training
# steps -- confirm against CNN.train.
cnn.train(train_generator, validation_data, 25000)
Пример #5
0
import os
from processing import get_train_val_split, get_jpeg_files

from models.cnn import CNN

# Labels CSV is resolved relative to the current working directory, so the
# script has to be launched from the project root.
labels_file = os.path.join(os.getcwd(), "data", "train.csv")

# NOTE(review): split_ratio=0.1 is presumably the fraction held out for
# validation -- confirm against get_train_val_split.
train_X, val_X, train_Y, val_Y = get_train_val_split(labels_file,
                                                     split_ratio=0.1)

model = CNN()
model.init()
model.train(train_X, train_Y, val_X, val_Y, epochs=50)
# Call form of print: the original `print "..."` statement is a SyntaxError on
# Python 3, while this single-argument call prints the same text on Python 2.
print("model training complete!!")

# Predict top k labels for test data images, and store in submission format
k = 5
test_images_dir = os.path.join(os.getcwd(), "data", "test")
test_images = get_jpeg_files(test_images_dir)

model.write_test_output(test_images, k)