Example #1
import os
import sys
from random import shuffle

import numpy as np


def main():
    seed = 42
    np.random.seed(seed)

    current_dir = os.path.dirname(__file__)
    sys.path.append(os.path.join(current_dir, '..'))
    current_dir = current_dir if current_dir != '' else '.'

    img_dir_path = 'jpg'
    txt_dir_path = 'flowers/text_c10'
    model_dir_path = current_dir + '/models'

    img_width = 32
    img_height = 32

    from dcgan import DCGan
    from image_utils import img_from_normalized_img
    from img_cap_loader import load_normalized_img_and_its_text

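    # load normalized flower images paired with their caption text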
    image_label_pairs = load_normalized_img_and_its_text(img_dir_path,
                                                         txt_dir_path,
                                                         img_width=img_width,
                                                         img_height=img_height)

    shuffle(image_label_pairs)

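    # restore a previously trained DCGan from the model directory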
    gan = DCGan()
    gan.load_model(model_dir_path)

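    # for the first 10 shuffled pairs, save the original image and three images generated from its caption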
    for i in range(10):
        image_label_pair = image_label_pairs[i]
        normalized_image = image_label_pair[0]
        text = image_label_pair[1]
        print(text)

        image = img_from_normalized_img(normalized_image)
        image.save(current_dir + '/data/outputs/' + DCGan.model_name +
                   '-generated-' + str(i) + '-0.png')
        # index generated images from 1 so they do not overwrite the original saved as ...-0.png
        for j in range(3):
            generated_image = gan.generate_image_from_text(text)
            generated_image.save(current_dir + '/data/outputs/' +
                                 DCGan.model_name + '-generated-' + str(i) +
                                 '-' + str(j + 1) + '.png')
Example #2
import os
import sys
from random import shuffle

import numpy as np


def main():
    seed = 42
    np.random.seed(seed)

    current_dir = os.path.dirname(__file__)
    sys.path.append(os.path.join(current_dir, '..'))
    current_dir = current_dir if current_dir != '' else '.'

    img_dir_path = 'jpg'
    txt_dir_path = 'flowers/text_c10'
    model_dir_path = current_dir + '/models'

    img_width = 32
    img_height = 32
    img_channels = 3

    from dcgan import DCGan
    from img_cap_loader import load_normalized_img_and_its_text

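    # load normalized flower images paired with their caption text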
    image_label_pairs = load_normalized_img_and_its_text(img_dir_path,
                                                         txt_dir_path,
                                                         img_width=img_width,
                                                         img_height=img_height)

    shuffle(image_label_pairs)

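    # configure image dimensions, the random input (noise) dimension, and the GloVe data directory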
    gan = DCGan()
    gan.img_width = img_width
    gan.img_height = img_height
    gan.img_channels = img_channels
    gan.random_input_dim = 200
    gan.glove_source_dir_path = './very_large_data'

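    # train the DCGan, writing model weights to model_dir_path and snapshot images at the configured interval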
    batch_size = 16
    epochs = 300
    gan.fit(model_dir_path=model_dir_path,
            image_label_pairs=image_label_pairs,
            snapshot_dir_path=current_dir + '/data/snapshots',
            snapshot_interval=100,
            batch_size=batch_size,
            epochs=epochs)
Example #3
import argparse
from datetime import datetime
from pathlib import Path

import yaml

# DCGan and load_celeba_tfdataset are assumed to come from the surrounding project


def main(_=None):
    parser = argparse.ArgumentParser(description='Train GAN')
    parser.add_argument('--config', required=True, help="config path")
    parser.add_argument('--name', required=True, help="model name")
    parser.add_argument('--model-dir', required=True, help="model directory")
    parser.add_argument('--data-dir', required=True, help="data directory")
    parser.add_argument('--epochs',
                        type=int,
                        default=1000,
                        help="number of training epochs")

    args = parser.parse_args()

    CONFIG_PATH = args.config
    MODEL_NAME = args.name
    MODEL_DIR = Path(args.model_dir)
    DATA_DIR = Path(args.data_dir)
    NB_EPOCHS = args.epochs

    # load model config
    with open(CONFIG_PATH, 'r') as f:
        config = yaml.safe_load(f)
    IMG_SHAPE = config['data']['input_shape']

    # load data
    train_ds = load_celeba_tfdataset(DATA_DIR, config, zipped=False)
    test_ds = load_celeba_tfdataset(DATA_DIR, config, zipped=False)
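    # note: train_ds and test_ds are loaded from the same directory with the same config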

    # instantiate GAN
    gan = DCGan(IMG_SHAPE, config)

    # setup model directory for checkpoint and tensorboard logs
    model_dir = MODEL_DIR / MODEL_NAME
    model_dir.mkdir(exist_ok=True, parents=True)
    log_dir = model_dir / "logs" / datetime.now().strftime("%Y%m%d-%H%M%S")

    # run train
    gan._train(train_ds=gan.setup_dataset(train_ds),
               validation_ds=gan.setup_dataset(test_ds),
               nb_epochs=NB_EPOCHS,
               log_dir=log_dir,
               checkpoint_dir=None,
               is_tfdataset=True)
Example #4
import time
from random import shuffle

# img_dir_path, txt_dir_path, model_dir_path and load_normalized_img_and_text
# are assumed to be defined earlier in the original script
mode = 'test'

img_width = 64
img_height = 64
img_channels = 3

from dcgan import DCGan

image_label_pairs = load_normalized_img_and_text(img_dir_path,
                                                 txt_dir_path,
                                                 img_width=img_width,
                                                 img_height=img_height)

shuffle(image_label_pairs)

gan = DCGan()
gan.img_width = img_width
gan.img_height = img_height
gan.img_channels = img_channels
gan.random_input_dim = 200
gan.glove_source_dir_path = './very_large_data'

batch_size = 5
epochs = 2000

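# training only runs when mode is set to 'train'; with mode = 'test' this block is skipped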
if mode == 'train':
    #training
    start_time = time.time()

    logs = gan.fit(model_dir_path=model_dir_path,
                   image_label_pairs=image_label_pairs,