Code example #1
def main():
    G = u.create_generator()
    D = u.create_discriminator()
    E = u.create_encoder()
    R = u.create_regressor()

    if settings.EQUALIZE_WEIGHTS:
        ws.scale_network(D, 0.2)
        ws.scale_network(G, 0.2)

    if settings.SPECTRAL_NORM:
        sn.normalize_network(D, 0.2)

    #G.apply(u.near_identity_weight_init)
    #D.apply(u.near_identity_weight_init)

    #opt_G = torch.optim.Adamax(G.parameters(), settings.LEARNING_RATE, betas=settings.BETAS)
    #opt_D = torch.optim.Adamax(D.parameters(), settings.LEARNING_RATE, betas=settings.BETAS)

    visualizer = Visualizer()
    visualizer.initiate_windows()

    torch.save(G.state_dict(), "working_model/G.params")
    torch.save(D.state_dict(), "working_model/D.params")
    torch.save(E.state_dict(), "working_model/E.params")
    torch.save(R.state_dict(), "working_model/R.params")

    #torch.save(opt_G.state_dict(), "working_model/optG.state")
    #torch.save(opt_D.state_dict(), "working_model/optD.state")

    # Per-stage RGB adapter layers that map features to and from the 2-channel image space
    for i in settings.PROGRESSION:
        c, d = settings.PROGRESSION[i]
        to_rgb = nn.Conv2d(c, 2, 1)
        from_rgb = nn.Conv2d(2, c, 1)

        torch.save(to_rgb.state_dict(),
                   "working_model/toRGB{}.params".format(i))
        torch.save(from_rgb.state_dict(),
                   "working_model/fromRGB{}.params".format(i))

    # Initialize state
    state = {
        "point": 0,
        "pred_real": 0,
        "pred_fake": 0,
        "history_real": [],
        "history_fake": []
    }
    json.dump(state, open("working_model/state.json", "w"))
    # -----------------
    print("Saved networks and RGB layers in ./working_model")

    # Set row to zero for progressive training
    json.dump("0", open("/tmp/DeepGenerationConfigRow", "w"))
Code example #2
File: train_stage.py  Project: netrome/DeepGeneration
def main():
    print("\nInitiating training with the following setting ----")
    print(json.dumps(vars(settings.args), sort_keys=True, indent=4))
    print("---------------------------------------------------")
    # Get utilities ---------------------------------------------------
    dataset = u.get_data_set()
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=settings.BATCH_SIZE,
                                              shuffle=True,
                                              pin_memory=True,
                                              drop_last=True)
    visualizer = Visualizer()
    state = json.load(open("working_model/state.json", "r"))
    visualizer.point = state["point"]

    # Define networks -------------------------------------------------
    G = u.create_generator()
    D = u.create_discriminator()

    if settings.EQUALIZE_WEIGHTS:
        ws.scale_network(D, 0.2)
        ws.scale_network(G, 0.2)

    if settings.SPECTRAL_NORM:
        sn.normalize_network(D, 0.2)

    if settings.WORKING_MODEL:
        print("Using model parameters in ./working_model")
        G.load_state_dict(torch.load("working_model/G.params"))
        D.load_state_dict(torch.load("working_model/D.params"))

    # Train with StageTrainer or FadeInTrainer
    s, (c, d) = settings.STAGE, settings.PROGRESSION[settings.STAGE]
    if settings.FADE_IN:
        print("Freezing in next layer")
        c = settings.PROGRESSION[settings.STAGE + 1][0]
        d = int(d / 2)
        G.freeze_until(s)
        #D.freeze_until(s)
        s += 1

    # Freeze idle layers - did not stop vlad
    #G.freeze_idle(s)
    #D.freeze_idle(s)

    stage = trainer.StageTrainer(G,
                                 D,
                                 data_loader,
                                 stage=s,
                                 conversion_depth=c,
                                 downscale_factor=d)
    stage.pred_real += state["pred_real"]
    stage.pred_fake += state["pred_fake"]

    if settings.WORKING_MODEL:
        stage.toRGB.load_state_dict(
            torch.load("working_model/toRGB{}.params".format(s)))
        stage.fromRGB.load_state_dict(
            torch.load("working_model/fromRGB{}.params".format(s)))
        print("Loaded RGB layers too")

    stage.visualize(visualizer)
    for i in range(settings.CHUNKS):
        print("Chunk {}, stage {}, fade in: {}, GPU memory {}               ".
              format(i, settings.STAGE, settings.FADE_IN, 1337))
        stage.steps(settings.STEPS)
        gc.collect()  # Prevent memory leaks (?)
        #torch.cuda.empty_cache()  - Made no difference
        state["history_real"].append(float(stage.pred_real))
        state["history_fake"].append(float(stage.pred_fake))
        if settings.WORKING_MODEL:
            print("Saved timelapse visualization")
            stage.save_fake_reference_batch(visualizer.point)
        stage.visualize(visualizer)

    # Save networks
    """
    if settings.FADE_IN:
        to_rgb, from_rgb, next_to_rgb, next_from_rgb = stage.get_rgb_layers()
        print("Saving extra rgb layers, {}".format(time.ctime()))
        torch.save(next_to_rgb.state_dict(), "working_model/toRGB{}.params".format(s + 1))
        torch.save(next_from_rgb.state_dict(), "working_model/fromRGB{}.params".format(s + 1))
    else:
        to_rgb, from_rgb = stage.get_rgb_layers()
    """
    to_rgb, from_rgb = stage.get_rgb_layers()
    print("Saving rgb layers, {}".format(time.ctime()))

    torch.save(to_rgb.state_dict(), "working_model/toRGB{}.params".format(s))
    torch.save(from_rgb.state_dict(),
               "working_model/fromRGB{}.params".format(s))
    print("Saving networks, {}".format(time.ctime()))
    G.unfreeze_all()
    D.unfreeze_all()
    torch.save(G.state_dict(), "working_model/G.params")
    torch.save(D.state_dict(), "working_model/D.params")

    # Save state
    state["point"] = visualizer.point
    state["pred_real"] = float(stage.pred_real)
    state["pred_fake"] = float(stage.pred_fake)
    print("Saving state, {}".format(time.ctime()))
    json.dump(state, open("working_model/state.json", "w"))

    # Save optimizer state
    #opt_G = stage.opt_G
    #opt_D = stage.opt_D

    #print("Saving optimizer state, {}".format(time.ctime()))
    #torch.save(opt_G.state_dict(), "working_model/optG.state")
    #torch.save(opt_D.state_dict(), "working_model/optD.state")
    print("Finished with main")
Code example #3
import os

import torch
import torch.nn as nn
import torchvision
from torch.autograd import Variable

import settings

import utils.utils as u


out_dir = "./output/"
num_batches = 150

if settings.TEST_DATA:
    num_batches = 15

G = u.create_generator()
toRGB = nn.Conv2d(16, 2, 1)
G.load_state_dict(torch.load(os.path.join(settings.MODEL_PATH, "G.params")))
toRGB.load_state_dict(torch.load(os.path.join(settings.MODEL_PATH, "toRGB6.params")))
latent = Variable(torch.FloatTensor(settings.BATCH_SIZE, 128, 1, 1))

if settings.CUDA:
    toRGB.cuda()
    latent = latent.cuda()

num = 0
for i in range(num_batches):
    print("Generating batch {}/{}   ".format(i + 1, num_batches), end="\r")
    latent.data.normal_()
    batch = toRGB(G(latent))
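    # Hypothetical continuation (not in the original listing): write each generated
    # sample in the batch to out_dir. Channel 0 is assumed to hold the image channel.
    os.makedirs(out_dir, exist_ok=True)  # no-op once the directory exists
    for img in batch.data:
        torchvision.utils.save_image(img[0:1],
                                     os.path.join(out_dir, "img_{:05d}.png".format(num)))
        num += 1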
Code example #4
import json
import time

import settings
import torch
import torch.nn as nn
from torch.autograd import Variable

import utils.datasets as datasets
import utils.visualizer as vis
import utils.utils as u

from utils.utils import cyclic_data_iterator

encoder = u.create_encoder()
decoder = u.create_generator()
toRGB = nn.Conv2d(16, 2, 1)
fromRGB = nn.Conv2d(2, 16, 1)

if settings.CUDA:
    toRGB.cuda()
    fromRGB.cuda()

optimizer = torch.optim.Adamax([
    {
        "params": encoder.parameters()
    },
    {
        "params": decoder.parameters()
    },
    {