def stylize_image(image_path, style, style_weight):
    """Stylize one image with a pre-trained transformer and save it as a PNG.

    Args:
        image_path: Path handed to ``load_img``.
        style: Style name; selects the checkpoint directory
            ``models/style/{style}_sw<weight>`` and is embedded in the
            output file name.
        style_weight: Key into the weight table below (1, 2 or 3), selecting
            the ``_sw10`` / ``_sw100`` / ``_sw1000`` model variant.

    Returns:
        The path of the written image,
        ``images/output/output_{style}_sw{style_weight}.png``.

    Raises:
        KeyError: If ``style_weight`` is not 1, 2 or 3.
        FileNotFoundError: If the model directory holds no checkpoint.
    """
    import os

    image = load_img(image_path)
    style_weights = {1: '10',
                     2: '100',
                     3: '1000'}
    print('stylise_image')
    transformer = TransformerNet()
    # Inference only needs the network weights, so no optimizer/step objects
    # are attached; expect_partial() silences the warnings about the extra
    # values stored in the checkpoint file.
    ckpt = tf.train.Checkpoint(transformer=transformer)
    model_path = f'models/style/{style}_sw{style_weights[style_weight]}'
    latest = tf.train.latest_checkpoint(model_path)
    if latest is None:
        # restore(None) would be a silent no-op and the image would be
        # "stylized" by a randomly initialised network.
        raise FileNotFoundError(f'no checkpoint found in {model_path!r}')
    ckpt.restore(latest).expect_partial()
    print('ckpt')

    transformed_image = transformer(image)
    print('transformed_image')
    # Clip before the cast: float values outside [0, 255] do not saturate
    # when converted to uint8.
    transformed_image = tf.cast(
        tf.clip_by_value(tf.squeeze(transformed_image), 0, 255), tf.uint8
    ).numpy()
    print('transformed_image cast')

    img = Image.fromarray(transformed_image, mode="RGB")
    output_path = f'images/output/output_{style}_sw{style_weight}.png'
    # Make sure the target directory exists so the save cannot fail after
    # all the work has been done.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    img.save(output_path)
    return output_path
예제 #2
0
import tensorflow as tf

from PIL import Image
from argparse import ArgumentParser

from networks import TransformerNet
from utils import load_img

if __name__ == "__main__":
    # Command-line entry point: restore the latest checkpoint from --log-dir,
    # run the transformer on --image-path and write the result to
    # --output-path.
    parser = ArgumentParser()
    # All three options are needed below, so a missing one is reported by
    # argparse instead of surfacing later as an error on None.
    parser.add_argument("--log-dir", required=True,
                        help="directory searched for the latest checkpoint")
    parser.add_argument("--image-path", required=True,
                        help="image to stylize")
    parser.add_argument("--output-path", required=True,
                        help="where the stylized image is saved")
    args = parser.parse_args()

    image = load_img(args.image_path)

    transformer = TransformerNet()
    ckpt = tf.train.Checkpoint(transformer=transformer)
    latest = tf.train.latest_checkpoint(args.log_dir)
    if latest is None:
        # restore(None) would silently leave the network randomly
        # initialised; stop with a usage error instead.
        parser.error(f"no checkpoint found in {args.log_dir!r}")
    ckpt.restore(latest).expect_partial()

    transformed_image = transformer(image)
    # Clip before the cast: float values outside [0, 255] do not saturate
    # when converted to uint8.
    transformed_image = tf.cast(
        tf.clip_by_value(tf.squeeze(transformed_image), 0, 255),
        tf.uint8).numpy()

    img = Image.fromarray(transformed_image, mode="RGB")
    img.save(args.output_path)
예제 #3
0
def train(training_input, training_labels, num_unique_tokens, sequence_length):
    """Train a ``TransformerNet`` with cross-entropy loss and Adam.

    Runs 10 epochs over mini-batches of 32 shuffled samples, printing the
    running loss every 100 batches and the per-sample average loss after
    each epoch. Whenever an epoch improves on the best average loss so far,
    ``net.save_checkpoint`` is called with the epoch number, that loss and
    the optimizer state.

    Args:
        training_input: Inputs passed to ``Dataset``.
        training_labels: Labels passed to ``Dataset``.
        num_unique_tokens: Forwarded to the network as ``num_mappings``.
        sequence_length: Forwarded to the network as ``seq_length``.
    """
    # Loader Parameters
    params = {'batch_size': 32, 'shuffle': True, 'num_workers': 12}
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create Batch Generator
    training_dataset = Dataset(training_input, training_labels)
    training_dataloader = torch.utils.data.DataLoader(training_dataset,
                                                      **params)

    # Define network, loss function and optimizer
    net = TransformerNet(embedding_dim=500,
                         num_mappings=num_unique_tokens,
                         num_heads=4,
                         num_blocks=3,
                         seq_length=sequence_length)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-3)

    # Define number of epochs
    num_epochs = 10
    print_batch_step = 100
    # Start at +inf so the first finite epoch loss always counts as an
    # improvement, however large it is.
    best_epoch_loss = float('inf')

    # Start training
    net = net.to(device)
    net.train(mode=True)
    for epoch in range(num_epochs):

        running_loss = 0
        epoch_loss = 0

        for i, data_ in enumerate(training_dataloader, 1):

            inputs, labels = data_[0].to(device), data_[1].to(device)

            optimizer.zero_grad()
            outputs = net(inputs)

            # flatten sequence dim to batch dim before Softmax + CE
            # (reshape rather than view so non-contiguous outputs also work)
            batch, seq, num_classes = outputs.size()
            outputs = outputs.reshape(batch * seq, num_classes)
            labels = labels.reshape(batch * seq)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # print statistics
            batch_loss = loss.item()
            running_loss += batch_loss
            # Weight by the actual batch size: the last batch of an epoch
            # can be smaller than params['batch_size'], and the sum is
            # divided by len(training_dataset) below.
            epoch_loss += batch_loss * batch
            if i % print_batch_step == 0:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i, running_loss / print_batch_step))
                running_loss = 0

        epoch_num = epoch + 1
        average_epoch_loss = epoch_loss / len(training_dataset)
        print(f"Epoch {epoch_num}: average loss is {average_epoch_loss}")

        # Keep a checkpoint only for the best epoch seen so far.
        if average_epoch_loss < best_epoch_loss:
            net.save_checkpoint(epoch_num, average_epoch_loss,
                                optimizer.state_dict())
            best_epoch_loss = average_epoch_loss

    print('Finished Training')