Exemplo n.º 1
0
def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          extension, train_size, valid_size, data_path, sketch_path, digi_path,
          learning_rate, beta1, weight_decay):
    """Train a pix2pix-style sketch-to-digital generator with an adversarial loss.

    Alternates discriminator and generator (U-Net) updates each minibatch,
    and once per epoch serializes the generator and writes a visualization
    of a fixed validation batch.

    Args:
        epochs (int): number of training epochs.
        iterations (int): samples scheduled per epoch; the inner loop steps
            through this range in strides of ``batchsize``.
        batchsize (int): minibatch size per update.
        validsize (int): number of validation pairs drawn once up front.
        outdir: directory the visualizer writes snapshot images into.
        modeldir: directory where generator weights are saved each epoch.
        extension, train_size, valid_size, data_path, sketch_path, digi_path:
            forwarded verbatim to ``DatasetLoader``.
        learning_rate, beta1, weight_decay: optimizer hyper-parameters
            forwarded to ``set_optimizer`` for both networks.
    """

    # Dataset definition
    dataset = DatasetLoader(data_path, sketch_path, digi_path, extension,
                            train_size, valid_size)
    print(dataset)
    # Fixed validation batch, reused every epoch so snapshots are comparable.
    x_val, t_val = dataset.valid(validsize)

    # Model & Optimizer definition
    unet = UNet()
    unet.to_gpu()
    unet_opt = set_optimizer(unet, learning_rate, beta1, weight_decay)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, weight_decay)

    # Loss function definition
    lossfunc = Pix2pixLossCalculator()

    # Visualization definition
    visualizer = Visualizer()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x, t = dataset.train(batchsize)

            # Discriminator update: detach the generator output so the
            # discriminator loss does not backprop into the U-Net.
            y = unet(x)
            y.unchain_backward()

            dis_loss = lossfunc.dis_loss(discriminator, y, t)

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()

            sum_dis_loss += dis_loss.data

            # Generator update: fresh forward pass so the graph reaches back
            # into the U-Net this time.
            y = unet(x)

            gen_loss = lossfunc.gen_loss(discriminator, y)
            gen_loss += lossfunc.content_loss(y, t)

            unet.cleargrads()
            gen_loss.backward()
            unet_opt.update()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                # Once per epoch: snapshot weights, then render validation
                # outputs.
                serializers.save_npz(f"{modeldir}/unet_{epoch}.model", unet)

                # FIX: also disable backprop during the validation forward
                # pass — the original only switched off train mode, so an
                # unused computational graph was built, wasting GPU memory.
                # Outputs are unchanged.
                with chainer.using_config("train", False), \
                        chainer.using_config("enable_backprop", False):
                    y = unet(x_val)

                x = x_val.data.get()
                t = t_val.data.get()
                y = y.data.get()

                visualizer(x, t, y, outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss/iterations} gen loss: {sum_gen_loss/iterations}"
        )
Exemplo n.º 2
0
        # NOTE(review): fragment — the enclosing epoch/batch loop, the models
        # (image_encoder, making_optical_flow, generator, key_point_detector,
        # the discriminators) and the inputs (x, t, key_diff, opt) are all
        # defined above this excerpt.

        # Encode the real optical flow, synthesize a fake flow from the input
        # plus keypoint difference, then generate the output frame from it.
        opt_enc=image_encoder(opt)
        opt_fake = making_optical_flow(F.concat([x, key_diff]), opt_enc)
        opt_fake_enc = image_encoder(opt_fake)
        y = generator(x, key_diff, opt_fake_enc)

        # Adversarial scores for the fake sample: per-image and temporal.
        dis_opt_fake = discriminator_image(F.concat([x, opt_fake]))
        temp_fake = preapre_smoothing(opt_fake, y)
        dis_temp_fake = discriminator_temporal(temp_fake)

        # Generator loss: non-saturating GAN terms (softplus of negated
        # logits) plus L1 reconstruction of both the generated frame and the
        # synthesized flow.
        gen_loss = F.mean(F.softplus(-dis_temp_fake)) + F.mean(F.softplus(-dis_opt_fake))
        gen_loss += F.mean_absolute_error(y,t) + F.mean_absolute_error(opt_fake, opt)

        # Clear grads on every trainable component, run one joint backward,
        # then step each component's optimizer.
        key_point_detector.cleargrads()
        image_encoder.cleargrads()
        making_optical_flow.cleargrads()
        generator.cleargrads()
        gen_loss.backward()
        enc_opt.update()
        ref_opt.update()
        gen_opt.update()
        key_opt.update()

        # Release the computational graph to free GPU memory.
        gen_loss.unchain_backward()
        
        # .get() copies the (presumably cupy) loss scalar back to the host.
        sum_gen_loss+=gen_loss.data.get()
        sum_dis_loss += dis_loss.data.get()

        # Periodic checkpointing; interval is defined outside this excerpt.
        if epoch%interval==0 and batch==0:
            serializers.save_npz("image_encoder.model",image_encoder)
            serializers.save_npz("making_opticalflow.model",making_optical_flow)
Exemplo n.º 3
0
                y_x_serial[0].reshape(1, 3, size, size),
                y_x_serial[1].reshape(1, 3, size, size)
            ])
            # NOTE(review): fragment — the first three lines above continue a
            # call started before this excerpt; the enclosing loop and the
            # other loss terms are defined outside this view.

            # Recycle consistency: predict the next Y-domain frame from the
            # recent series and compare against the ground-truth frame.
            y_next = generator_xy(predictor_x(y_x_series))
            recycle_loss_y = F.mean_squared_error(
                y_next, y[index].reshape(1, 3, size, size))

            # Total generator loss per direction: adversarial term plus
            # weighted cycle / recurrent / recycle consistency terms.
            gen_loss_x = gen_loss_xy + weight * (
                cycle_loss_x + recurrent_loss_x + recycle_loss_x)
            gen_loss_y = gen_loss_yx + weight * (
                cycle_loss_y + recurrent_loss_y + recycle_loss_y)
            gen_loss = gen_loss_x + gen_loss_y

            # Clear all four networks' grads before one joint backward pass.
            generator_xy.cleargrads()
            generator_yx.cleargrads()
            predictor_x.cleargrads()
            predictor_y.cleargrads()

            gen_loss.backward()

            gen_opt_xy.update()
            pre_opt_x.update()
            gen_opt_yx.update()
            pre_opt_y.update()

            # Release the computational graph to free memory.
            gen_loss.unchain_backward()

        # Accumulators and the discriminator losses are defined outside
        # this excerpt.
        sum_gen_loss += (gen_loss)
        sum_dis_loss += (dis_loss_y + dis_loss_x)

        if epoch % interval == 0 and batch == 0: