Example #1
import chainer
from chainer import serializers

# DatasetLoader, Model, Generator, VGG, set_optimizer, RAMLossFunction and
# Evaluation are project-local modules assumed to be imported elsewhere.
def train(epochs, iterations, outdir, path, batchsize, validsize, model_type):
    # Dataset Definition
    dataloader = DatasetLoader(path)
    print(dataloader)
    t_valid, x_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    if model_type == 'ram':
        model = Model()
    elif model_type == 'gan':
        model = Generator()
    else:
        raise ValueError(f"unknown model_type: {model_type}")
    model.to_gpu()
    optimizer = set_optimizer(model)

    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = RAMLossFunction()
    print(lossfunc)

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            t_train, x_train = dataloader(batchsize, mode="train")

            y_train = model(x_train)
            y_feat = vgg(y_train)
            t_feat = vgg(t_train)
            loss = lossfunc.content_loss(y_train, t_train)
            loss += lossfunc.perceptual_loss(y_feat, t_feat)

            model.cleargrads()
            vgg.cleargrads()
            loss.backward()
            optimizer.update()
            vgg_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{outdir}/model_{epoch}.model", model)

                with chainer.using_config('train', False):
                    y_valid = model(x_valid)
                x = x_valid.data.get()
                y = y_valid.data.get()
                t = t_valid.data.get()

                evaluator(x, y, t, epoch, outdir)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
Example #2
        t = adain(content_feat, style_feat4)
        t = alpha * t + (1 - alpha) * content_feat

        g_t = decoder(t)
        g_t_feats1, g_t_feats2, g_t_feats3, g_t_feats4 = vgg(g_t)

        loss_content = F.mean_squared_error(g_t_feats4, t)
        loss_style = style_loss(style_feat1, g_t_feats1)
        loss_style += style_loss(style_feat2, g_t_feats2)
        loss_style += style_loss(style_feat3, g_t_feats3)
        loss_style += style_loss(style_feat4, g_t_feats4)

        loss = content_weight * loss_content + style_weight * loss_style

        decoder.cleargrads()
        vgg.cleargrads()

        loss.backward()
        loss.unchain_backward()

        dec_opt.update()
        vgg_opt.update()

        sum_loss += loss.data.get()

        if epoch % interval == 0 and batch == 0:
            serializers.save_npz("decoder.model", decoder)
            with chainer.using_config("train", False):
                style_feat1, style_feat2, style_feat3, style_feat4 = vgg(
                    style_test)
                content_feat = vgg(content_test, last_only=True)
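
adain and style_loss are not defined in this snippet; in the AdaIN style-transfer formulation, adain re-normalizes the content features to the channel-wise statistics of the style features, and style_loss penalizes mismatched statistics between generated and style features. A minimal sketch under that assumption (the multi-output vgg call and the last_only flag are specific to this repository and not reproduced here):

import chainer.functions as F

def calc_mean_std(feat, eps=1e-5):
    # Channel-wise mean and std over the spatial axes of an NCHW feature map.
    mean = F.mean(feat, axis=(2, 3), keepdims=True)
    var = F.mean(F.square(feat - F.broadcast_to(mean, feat.shape)),
                 axis=(2, 3), keepdims=True)
    return mean, F.sqrt(var + eps)

def adain(content_feat, style_feat):
    # Shift/scale content features to match the style feature statistics.
    c_mean, c_std = calc_mean_std(content_feat)
    s_mean, s_std = calc_mean_std(style_feat)
    normalized = (content_feat - F.broadcast_to(c_mean, content_feat.shape)) \
        / F.broadcast_to(c_std, content_feat.shape)
    return normalized * F.broadcast_to(s_std, content_feat.shape) \
        + F.broadcast_to(s_mean, content_feat.shape)

def style_loss(style_feat, generated_feat):
    # Penalize differences in channel-wise mean/std between the two feature maps.
    s_mean, s_std = calc_mean_std(style_feat)
    g_mean, g_std = calc_mean_std(generated_feat)
    return F.mean_squared_error(g_mean, s_mean) + F.mean_squared_error(g_std, s_std)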