Пример #1
0
def main():
    """Build the GAN graph and dispatch to training or inference.

    Reads model hyperparameters from a JSON file given by --parameters,
    creates the generator/discriminator placeholders and outputs, then calls
    train() when --mode is true or infer() when it is false.
    """
    def str2bool(value):
        # BUG FIX: argparse's type=bool is broken for CLI use — bool('False')
        # is True because any non-empty string is truthy, so '--mode False'
        # could never select inference. Parse common spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ('true', '1', 'yes', 'y'):
            return True
        if lowered in ('false', '0', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(
            'Boolean value expected, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--parameters', type=str, default='/home/mdo2/Documents/artgan/parameters.json', help='model parameters file')
    parser.add_argument('--mode', type=str2bool, default=True, help=' True for training or False for inference')
    args = parser.parse_args()

    with open(args.parameters, 'r') as f:
        parameters = json.load(f)
    parameters['mode'] = args.mode

    # Placeholders: noise vector for the generator, 64x64 RGB images and
    # per-sample real/fake labels for the discriminator.
    gen_feed = tf.placeholder(dtype=tf.float32, shape=(None, parameters['noise_length']), name='gen_feed')
    dis_feed = tf.placeholder(dtype=tf.float32, shape=(None, 64, 64, 3), name='dis_feed')
    dis_labels_real = tf.placeholder(dtype=tf.float32, shape=(None, 1), name='dis_labels_real')
    dis_labels_fake = tf.placeholder(dtype=tf.float32, shape=(None, 1), name='dis_labels_fake')
    dis_feed_cond = tf.placeholder(dtype=tf.bool, name='dis_feed_cond')

    model = architecture.Model(parameters)
    gen_output = model.generator(gen_feed)
    dis_out_real = model.discriminator(dis_feed)
    # reuse=True: score generated images with the same discriminator weights.
    dis_out_fake = model.discriminator(gen_output, reuse=True)

    if args.mode:  # Training
        train(gen_feed, dis_feed, gen_output, dis_out_real, dis_out_fake, dis_labels_real, dis_labels_fake, parameters)
    else:  # Inference
        infer(model, gen_feed)
Пример #2
0
def evaluate(config):
    """Evaluate a trained checkpoint on every evaluation datastream.

    Loads the latest checkpoint from model/checkpoints onto the configured
    device, runs one evaluation pass per datastream, and logs metrics to
    TensorBoard under log dir 'tb'.

    Args:
        config: dict with at least 'use_cuda', 'eval_batch_size',
            and 'n_workers'.
    """
    device = torch.device('cuda' if config['use_cuda'] else 'cpu')

    model = architecture.Model().to(device)

    train_state = dict(model=model)

    print('Loading model checkpoint')
    workflow.ignite.handlers.ModelCheckpoint.load(
        train_state, 'model/checkpoints', device
    )

    @workflow.ignite.decorators.evaluate(model)
    def evaluate_batch(engine, examples):
        # Forward pass; loss is returned alongside predictions so the
        # attached metrics can aggregate it per batch.
        predictions = model.predictions(
            architecture.FeatureBatch.from_examples(examples)
        )
        loss = predictions.loss(examples)
        return dict(
            examples=examples,
            predictions=predictions.cpu().detach(),
            loss=loss,
        )

    # NOTE: the comprehension deliberately rebinds `datastream` (the module
    # name) to each datastream object returned by evaluate_datastreams().
    evaluate_data_loaders = {
        f'evaluate_{name}': datastream.data_loader(
            batch_size=config['eval_batch_size'],
            num_workers=config['n_workers'],
            collate_fn=tuple,
        )
        for name, datastream in datastream.evaluate_datastreams().items()
    }

    tensorboard_logger = TensorboardLogger(log_dir='tb')

    # Fixed misspelled local variable: 'desciption' -> 'description'.
    for description, data_loader in evaluate_data_loaders.items():
        engine = evaluator(
            evaluate_batch, description,
            metrics.evaluate_metrics(),
            tensorboard_logger,
        )
        engine.run(data=data_loader)
Пример #3
0
def evaluate(config):
    """Run a full evaluation pass over every non-mini evaluation datastream.

    Restores weights from model/model.pt when a 'model' directory exists,
    computes loss and accuracy per datastream, logs both to TensorBoard,
    and prints a metric table for each datastream.

    Args:
        config: dict with at least 'use_cuda', 'eval_batch_size',
            and 'n_workers'.
    """
    torch.set_grad_enabled(False)
    device = torch.device("cuda" if config["use_cuda"] else "cpu")

    model = architecture.Model().to(device)

    if Path("model").exists():
        print("Loading model checkpoint")
        model.load_state_dict(torch.load("model/model.pt"))

    # Build one data loader per evaluation datastream, skipping the "mini"
    # streams (those are used during training, not full evaluation).
    evaluate_data_loaders = {}
    for name, stream in datastream.evaluate_datastreams().items():
        if "mini" in name:
            continue
        evaluate_data_loaders[f"evaluate_{name}"] = (
            stream.map(architecture.StandardizedImage.from_example).data_loader(
                batch_size=config["eval_batch_size"],
                collate_fn=tools.unzip,
                num_workers=config["n_workers"],
            )
        )

    tensorboard_logger = torch.utils.tensorboard.SummaryWriter()
    evaluate_metrics = {
        loader_name: metrics.evaluate_metrics()
        for loader_name in evaluate_data_loaders
    }

    for loader_name, loader in evaluate_data_loaders.items():
        current_metrics = evaluate_metrics[loader_name]

        for examples, standardized_images in tqdm(loader, desc=loader_name, leave=False):
            # Forward pass only; gradients are globally disabled above.
            with lantern.module_eval(model):
                predictions = model.predictions(standardized_images)
                loss = predictions.loss(examples)

            current_metrics["loss"].update_(loss)
            current_metrics["accuracy"].update_(examples, predictions)

        for metric_name, metric in current_metrics.items():
            metric.log(tensorboard_logger, loader_name, metric_name)

        print(lantern.MetricTable(loader_name, current_metrics))

    tensorboard_logger.close()
Пример #4
0
def main(unused_arg):
    """Neural style transfer with a VGG16 backbone (Gatys et al. style).

    Optimizes a trainable image (initialized from the content image) so its
    VGG16 activations match the content image's content features and the
    style image's Gram-matrix statistics, then writes the result to
    --output_path.

    Args:
        unused_arg: ignored (kept for the framework's app-runner signature).

    Raises:
        ValueError: if the VGG weights file or a required image path is
            missing from the arguments.
        FileNotFoundError: if a given image path does not exist on disk.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-wp',
                        '--weights_path',
                        type=str,
                        default="vgg16.npy",
                        help="path to the vgg16.npy file")
    parser.add_argument('-ci',
                        '--content_image',
                        type=str,
                        default=None,
                        help="path to content image")
    parser.add_argument('-si',
                        '--style_image',
                        type=str,
                        default=None,
                        help="path to style image")
    parser.add_argument('-op',
                        '--output_path',
                        type=str,
                        default="art_image.png",
                        help="path to save output")
    parser.add_argument('-lr',
                        '--learning_rate',
                        type=float,
                        default=2.0,
                        help="learning rate")
    # BUG FIX: help text said "content loss weight" (copy-paste error).
    parser.add_argument('-i',
                        '--iterations',
                        type=int,
                        default=2000,
                        help="number of optimization iterations")
    parser.add_argument('-a',
                        '--alpha',
                        type=float,
                        default=100,
                        help="content loss weight")
    parser.add_argument('-b',
                        '--beta',
                        type=float,
                        default=8,
                        help="style loss weight")
    parser.add_argument('-lw',
                        '--layer_loss_weights',
                        nargs='+',
                        type=float,
                        default=[0.5, 1, 0.5, 0.5, 0.5],
                        help="layer wise style loss weights")
    arg = parser.parse_args()

    if not os.path.exists(arg.weights_path):
        raise ValueError("vgg16.npy not found at {}".format(arg.weights_path))

    if arg.content_image is None:
        raise ValueError("Path to content image not specified")
    else:
        if not os.path.exists(arg.content_image):
            raise FileNotFoundError(
                "No image exists at the location: {}".format(
                    arg.content_image))
        else:
            con_img = cv2.imread(arg.content_image)

    if arg.style_image is None:
        # BUG FIX: this message previously said "content image".
        raise ValueError("Path to style image not specified")
    else:
        if not os.path.exists(arg.style_image):
            raise FileNotFoundError(
                "No image exists at the location: {}".format(arg.style_image))
        else:
            sty_img = cv2.imread(arg.style_image)

    # Per-channel BGR means of the VGG training set (cv2 loads BGR).
    vgg_mean = [103.939, 116.779, 123.68]

    # resize content and style image to same size
    con_img = np.asarray(con_img, dtype=np.float32)
    shape = con_img.shape
    sty_img = cv2.resize(sty_img, (shape[1], shape[0]))
    sty_img = np.asarray(sty_img, dtype=np.float32)
    assert con_img.shape == sty_img.shape, "content and style images have different shape"

    # subtract mean values from each channel and reshape (required by vgg network)
    for i in range(3):
        con_img[:, :, i] = con_img[:, :, i] - vgg_mean[i]
        sty_img[:, :, i] = sty_img[:, :, i] - vgg_mean[i]
    con_img = con_img.reshape(1, shape[0], shape[1], shape[2])
    sty_img = sty_img.reshape(1, shape[0], shape[1], shape[2])

    content_image = tf.placeholder(dtype=tf.float32,
                                   shape=(1, shape[0], shape[1], shape[2]),
                                   name="content_image")
    style_image = tf.placeholder(dtype=tf.float32,
                                 shape=(1, shape[0], shape[1], shape[2]),
                                 name="style_image")
    # The optimized image starts as a copy of the content image; only this
    # variable is trainable — the network weights stay frozen.
    initialize = tf.constant_initializer(value=con_img, dtype=tf.float32)
    random_image = tf.get_variable(initializer=initialize,
                                   dtype=tf.float32,
                                   shape=(1, shape[0], shape[1], shape[2]),
                                   trainable=True,
                                   name="rnd_img")
    # Batch of 3: index 0 = optimized image, 1 = content, 2 = style.
    input_image = tf.concat([random_image, content_image, style_image], axis=0)

    model = architecture.Model(arg.weights_path, False)
    model.build(input_image)
    layer_1 = model.conv1_2
    layer_2 = model.conv2_2
    layer_3 = model.conv3_3
    layer_4 = model.conv4_3
    layer_5 = model.conv5_3

    # Content loss on a mid-level layer; style loss on all five layers.
    con_loss = content_loss(layer_3[1, :, :, :], layer_3[0, :, :, :])
    sty_loss1 = style_loss(layer_1[2, :, :, :], layer_1[0, :, :, :])
    sty_loss2 = style_loss(layer_2[2, :, :, :], layer_2[0, :, :, :])
    sty_loss3 = style_loss(layer_3[2, :, :, :], layer_3[0, :, :, :])
    sty_loss4 = style_loss(layer_4[2, :, :, :], layer_4[0, :, :, :])
    sty_loss5 = style_loss(layer_5[2, :, :, :], layer_5[0, :, :, :])

    w = arg.layer_loss_weights
    sty_loss = sty_loss1 * w[0] + sty_loss2 * w[1] + sty_loss3 * w[
        2] + sty_loss4 * w[3] + sty_loss5 * w[4]
    loss = (arg.alpha * con_loss) + (arg.beta * sty_loss)

    optimizer = tf.train.AdamOptimizer(learning_rate=arg.learning_rate)
    train_op = optimizer.minimize(loss)

    saver = tf.train.Saver({"rnd_img": random_image})

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for i in range(arg.iterations):
            _, loss_out = sess.run([train_op, loss],
                                   feed_dict={
                                       content_image: con_img,
                                       style_image: sty_img
                                   })
            print("iteration:{}/{} loss:{}".format(i + 1, arg.iterations,
                                                   loss_out))

        # Undo the mean subtraction before writing the image to disk.
        artistic_image, = sess.run([random_image])
        artistic_image = np.squeeze(artistic_image)
        for i in range(3):
            artistic_image[:, :, i] = artistic_image[:, :, i] + vgg_mean[i]
        cv2.imwrite(arg.output_path, artistic_image)
Пример #5
0
def train(config):
    """Train the model with ignite-based workflow utilities.

    Restores an existing checkpoint (if a 'model' directory exists), wires up
    training/evaluation engines, metric logging, model scoring, and a
    learning-rate schedule, then runs training for config['max_epochs'].

    Args:
        config: dict with at least 'seed', 'use_cuda', 'learning_rate',
            'eval_batch_size', 'batch_size', 'n_workers', 'n_batches_per_epoch',
            'max_epochs', and optionally 'search_learning_rate',
            'minimum_learning_rate', 'n_batches'.
    """

    set_seeds(config['seed'])

    device = torch.device('cuda' if config['use_cuda'] else 'cpu')

    model = architecture.Model().to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config['learning_rate']
    )

    train_state = dict(model=model, optimizer=optimizer)

    if Path('model').exists():
        print('Loading model checkpoint')
        workflow.ignite.handlers.ModelCheckpoint.load(
            train_state, 'model/checkpoints', device
        )

        # Checkpoint restore also restores the optimizer's old learning rate;
        # override it with the value from the current config.
        workflow.torch.set_learning_rate(optimizer, config['learning_rate'])

    # Count trainable parameters for a quick sanity printout.
    n_parameters = sum([
        p.shape.numel() for p in model.parameters() if p.requires_grad
    ])
    print(f'n_parameters: {n_parameters:,}')

    def process_batch(examples):
        # Shared forward pass used by both the train and evaluate engines.
        predictions = model.predictions(
            architecture.FeatureBatch.from_examples(examples)
        )
        loss = predictions.loss(examples)
        return predictions, loss

    @workflow.ignite.decorators.train(model, optimizer)
    def train_batch(engine, examples):
        # backward() here; the decorator presumably handles optimizer
        # step/zero_grad — TODO confirm against workflow.ignite internals.
        predictions, loss = process_batch(examples)
        loss.backward()
        return dict(
            examples=examples,
            predictions=predictions.cpu().detach(),
            loss=loss,
        )

    @workflow.ignite.decorators.evaluate(model)
    def evaluate_batch(engine, examples):
        predictions, loss = process_batch(examples)
        return dict(
            examples=examples,
            predictions=predictions.cpu().detach(),
            loss=loss,
        )

    # NOTE: the comprehension deliberately rebinds `datastream` (the module
    # name) to each datastream object from evaluate_datastreams().
    evaluate_data_loaders = {
        f'evaluate_{name}': datastream.data_loader(
            batch_size=config['eval_batch_size'],
            num_workers=config['n_workers'],
            collate_fn=tuple,
        )
        for name, datastream in datastream.evaluate_datastreams().items()
    }

    trainer, evaluators, tensorboard_logger = workflow.ignite.trainer(
        train_batch,
        evaluate_batch,
        evaluate_data_loaders,
        metrics=dict(
            progress=metrics.progress_metrics(),
            train=metrics.train_metrics(),
            **{
                name: metrics.evaluate_metrics()
                for name in evaluate_data_loaders.keys()
            }
        ),
        optimizers=optimizer,
    )

    # Model score: negated early-stopping loss, so higher is better.
    workflow.ignite.handlers.ModelScore(
        lambda: -evaluators['evaluate_early_stopping'].state.metrics['loss'],
        train_state,
        {
            name: metrics.evaluate_metrics()
            for name in evaluate_data_loaders.keys()
        },
        tensorboard_logger,
        config,
    ).attach(trainer, evaluators)

    tensorboard_logger.attach(
        trainer,
        log_examples('train', trainer),
        ignite.engine.Events.EPOCH_COMPLETED,
    )
    tensorboard_logger.attach(
        evaluators['evaluate_compare'],
        log_examples('evaluate_compare', trainer),
        ignite.engine.Events.EPOCH_COMPLETED,
    )

    if config.get('search_learning_rate', False):

        # Learning-rate range test: exponentially increase the LR multiplier
        # over n_batches to find a good learning rate.
        def search(config):
            def search_(step, multiplier):
                # NOTE(review): `multiplier` is ignored and the returned tuple
                # is (step, factor) — verify LearningRateScheduler's expected
                # callback signature/return shape.
                return (
                    step,
                    (1 / config['minimum_learning_rate'])
                    ** (step / config['n_batches'])
                )
            return search_

        LearningRateScheduler(
            optimizer,
            search(config),
        ).attach(trainer)

    else:
        # Normal training: linear warmup for 150 steps, then a cyclical
        # schedule with period 500.
        LearningRateScheduler(
            optimizer,
            starcompose(
                warmup(150),
                cyclical(length=500),
            ),
        ).attach(trainer)

    trainer.run(
        data=(
            datastream.GradientDatastream()
            .data_loader(
                batch_size=config['batch_size'],
                num_workers=config['n_workers'],
                n_batches_per_epoch=config['n_batches_per_epoch'],
                worker_init_fn=partial(worker_init, config['seed'], trainer),
                collate_fn=tuple,
            )
        ),
        max_epochs=config['max_epochs'],
    )
Пример #6
0
def train(config):
    """Train the model with early stopping and TensorBoard logging.

    Restores model/optimizer state from the 'model' directory when present,
    trains for up to config['max_epochs'] epochs, evaluates on the "mini"
    datastreams each epoch, saves a checkpoint whenever the early-stopping
    score improves, and stops after config['patience'] epochs without
    improvement.

    Args:
        config: dict with at least 'use_cuda', 'seed', 'learning_rate',
            'batch_size', 'n_batches_per_epoch', 'n_workers',
            'eval_batch_size', 'max_epochs', and 'patience'.
    """
    # Gradients are disabled globally and re-enabled only inside the
    # training forward/backward block below.
    torch.set_grad_enabled(False)
    device = torch.device("cuda" if config["use_cuda"] else "cpu")
    set_seeds(config["seed"])

    model = architecture.Model().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

    if Path("model").exists():
        print("Loading model checkpoint")
        model.load_state_dict(torch.load("model/model.pt"))
        optimizer.load_state_dict(torch.load("model/optimizer.pt"))
        # Restoring the optimizer also restores its old LR; override it.
        lantern.set_learning_rate(optimizer, config["learning_rate"])

    train_data_loader = (
        datastream.TrainDatastream()
        .map(architecture.StandardizedImage.from_example)
        .data_loader(
            batch_size=config["batch_size"],
            n_batches_per_epoch=config["n_batches_per_epoch"],
            collate_fn=tools.unzip,
            num_workers=config["n_workers"],
            worker_init_fn=worker_init_fn(config["seed"]),
            persistent_workers=(config["n_workers"] >= 1),
        )
    )

    # Only the small "mini" datastreams are evaluated during training; the
    # full evaluation sets are handled by evaluate().  NOTE: the comprehension
    # deliberately rebinds `datastream` (the module name) to each stream.
    evaluate_data_loaders = {
        f"evaluate_{name}": (
            datastream.map(architecture.StandardizedImage.from_example).data_loader(
                batch_size=config["eval_batch_size"],
                collate_fn=tools.unzip,
                num_workers=config["n_workers"],
            )
        )
        for name, datastream in datastream.evaluate_datastreams().items()
        if "mini" in name
    }

    tensorboard_logger = torch.utils.tensorboard.SummaryWriter(log_dir="tb")
    early_stopping = lantern.EarlyStopping(tensorboard_logger=tensorboard_logger)
    train_metrics = metrics.train_metrics()

    for epoch in lantern.Epochs(config["max_epochs"]):

        for examples, standardized_images in lantern.ProgressBar(
            train_data_loader, "train", train_metrics
        ):
            with lantern.module_train(model), torch.enable_grad():
                predictions = model.predictions(standardized_images)
                loss = predictions.loss(examples)
                loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            train_metrics["loss"].update_(loss)
            train_metrics["accuracy"].update_(examples, predictions)

            for name, metric in train_metrics.items():
                metric.log(tensorboard_logger, "train", name, epoch)

        print(lantern.MetricTable("train", train_metrics))
        log_examples(tensorboard_logger, "train", epoch, examples, predictions)

        evaluate_metrics = {
            name: metrics.evaluate_metrics() for name in evaluate_data_loaders
        }

        for name, data_loader in evaluate_data_loaders.items():
            for examples, standardized_images in lantern.ProgressBar(data_loader, name):
                with lantern.module_eval(model):
                    predictions = model.predictions(standardized_images)
                    loss = predictions.loss(examples)

                evaluate_metrics[name]["loss"].update_(loss)
                evaluate_metrics[name]["accuracy"].update_(examples, predictions)

            for metric_name, metric in evaluate_metrics[name].items():
                metric.log(tensorboard_logger, name, metric_name, epoch)

            print(lantern.MetricTable(name, evaluate_metrics[name]))
            log_examples(tensorboard_logger, name, epoch, examples, predictions)

        early_stopping = early_stopping.score(
            evaluate_metrics["evaluate_mini_early_stopping"]["accuracy"].compute()
        )
        if early_stopping.scores_since_improvement == 0:
            # BUG FIX: checkpoints were saved to CWD ('model.pt') but loaded
            # from 'model/model.pt' above, so saved weights were never
            # reloaded. Save to the path the loader expects.
            Path("model").mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), "model/model.pt")
            torch.save(optimizer.state_dict(), "model/optimizer.pt")
        elif early_stopping.scores_since_improvement > config["patience"]:
            break
        early_stopping.log(epoch).print()

    # BUG FIX: close() was inside the epoch loop, so the writer was closed
    # after the first epoch while later epochs still tried to log to it.
    tensorboard_logger.close()