Пример #1
0
class Solver(object):
    """Train a progressively grown generator/discriminator pair.

    The solver builds both networks with their Adam optimizers and, in
    ``train``, walks through every block of the generator.  Each block runs a
    'fade' phase (``alpha`` ramps from 0 towards 1) followed by a 'stabilize'
    phase (``alpha`` fixed at 1.0); block 0 has no 'fade' phase.  Checkpoints
    are tracked in ``models.json`` inside ``models_dir``.
    """

    def __init__(self, configuration):
        """Copy the settings from ``configuration`` and build the networks.

        Args:
            configuration: object exposing the attributes read below
                (``data_path``, ``crop_size``, ``final_size``, ...).
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # retrieve configuration variables
        self.data_path = configuration.data_path
        self.crop_size = configuration.crop_size
        self.final_size = configuration.final_size
        self.batch_size = configuration.batch_size
        self.alternating_step = configuration.alternating_step
        self.ncritic = configuration.ncritic
        self.lambda_gp = configuration.lambda_gp
        self.debug_step = configuration.debug_step
        self.save_step = configuration.save_step
        self.max_checkpoints = configuration.max_checkpoints
        self.log_step = configuration.log_step
        # Summary logging is disabled by default.  Assign an object exposing
        # ``scalar_summary(name, value, step)`` to enable it; ``train`` skips
        # summary logging while this is None.
        self.tflogger = None
        ## directories
        self.train_dir = configuration.train_dir
        self.img_dir = configuration.img_dir
        self.models_dir = configuration.models_dir
        ## variables
        # weight of the squared-loss drift term added to the critic loss
        self.eps_drift = 0.001

        self.resume_training = configuration.resume_training

        self._initialise_networks()

    def _initialise_networks(self):
        """Create generator, discriminator, optimizers and debug upsamplers."""
        self.generator = Generator(final_size=self.final_size)
        self.generator.generate_network()
        # Every input tensor is moved to self.device, so the weights must
        # live there too (no-op if generate_network() already did this).
        self.generator.to(self.device)
        self.g_optimizer = Adam(self.generator.parameters(), lr=0.001, betas=(0, 0.99))

        self.discriminator = Discriminator(final_size=self.final_size)
        self.discriminator.generate_network()
        self.discriminator.to(self.device)
        self.d_optimizer = Adam(self.discriminator.parameters(), lr=0.001, betas=(0, 0.99))

        # size of the channel dimension of the latent vectors
        self.num_channels = min(self.generator.num_channels,
                                self.generator.max_channels)
        # upsample[index] has scale factor 2**(num_blocks - 1 - index), so the
        # debug images of every block are enlarged by a matching factor.
        self.upsample = [Upsample(scale_factor=2**i)
                for i in reversed(range(self.generator.num_blocks))]

    def print_debugging_images(self, generator, latent_vectors, shape, index,
                               alpha, iteration):
        """Save a grid of generated images to ``img_dir``.

        Args:
            generator: callable invoked as ``generator(latent, index, alpha)``.
            latent_vectors: tensor holding at least ``shape[0] * shape[1]``
                latent vectors along its first dimension.
            shape: ``(rows, columns)`` of the image grid.
            index: current block index; also selects the upsampler.
            alpha: fade-in factor forwarded to the generator.
            iteration: step number used in the output file name.
        """
        rows, cols = shape[0], shape[1]
        with torch.no_grad():
            grid_rows = []
            for r in range(rows):
                images = []
                for c in range(cols):
                    # out-of-place unsqueeze: never touch the caller's tensor
                    latent = latent_vectors[r * cols + c].unsqueeze(0)
                    image = generator(latent, index, alpha)
                    images.append(self.upsample[index](image))
                grid_rows.append(torch.cat(images, dim=3))
            debugging_image = torch.cat(grid_rows, dim=2)
        # denorm -- presumably the generator output lies in [-1, 1]
        debugging_image = (debugging_image + 1) / 2
        debugging_image.clamp_(0, 1)
        save_image(debugging_image.data,
                   os.path.join(self.img_dir, "debug_{}_{}.png".format(index,
                                                                      iteration)))

    def save_trained_networks(self, block_index, phase, step):
        """Save both networks and register the checkpoint in ``models.json``.

        The oldest checkpoints are dropped (entry and files) until at most
        ``max_checkpoints`` entries remain.

        Args:
            block_index: block index stored in the entry and file names.
            phase: phase name ('fade' or 'stabilize') stored likewise.
            step: step within the phase that has just been completed.
        """
        models_file = os.path.join(self.models_dir, "models.json")
        if os.path.isfile(models_file):
            with open(models_file, 'r') as file:
                models_config = json.load(file)
        else:
            models_config = {"checkpoints": []}

        generator_save_name = "generator_{}_{}_{}.pth".format(
            block_index, phase, step)
        torch.save(self.generator.state_dict(),
                   os.path.join(self.models_dir, generator_save_name))

        discriminator_save_name = "discriminator_{}_{}_{}.pth".format(
            block_index, phase, step)
        torch.save(self.discriminator.state_dict(),
                   os.path.join(self.models_dir, discriminator_save_name))

        checkpoints = models_config["checkpoints"]
        checkpoints.append(OrderedDict([
            ("block_index", block_index),
            ("phase", phase),
            ("step", step),
            ("generator", generator_save_name),
            ("discriminator", discriminator_save_name),
        ]))

        # Prune in a loop: the list may exceed the limit by more than one
        # entry, e.g. when max_checkpoints was lowered between runs.
        file_keys = ("generator", "discriminator")
        while len(checkpoints) > self.max_checkpoints:
            stale = checkpoints.pop(0)
            # A later entry may reuse the same file name (same block, phase
            # and step saved twice); such a file must survive.
            still_used = {entry[key] for entry in checkpoints
                          for key in file_keys}
            for key in file_keys:
                if stale[key] in still_used:
                    continue
                try:
                    os.remove(os.path.join(self.models_dir, stale[key]))
                except FileNotFoundError:
                    pass  # already deleted; nothing left to clean up

        with open(models_file, 'w') as file:
            json.dump(models_config, file, indent=4)

    def load_trained_networks(self):
        """Load the most recent checkpoint listed in ``models.json``.

        Returns:
            Tuple ``(block_index, phase, step)`` of the loaded checkpoint.

        Raises:
            FileNotFoundError: ``models.json`` is missing from ``models_dir``.
            IndexError: ``models.json`` lists no checkpoints.
        """
        models_file = os.path.join(self.models_dir, "models.json")
        if not os.path.isfile(models_file):
            raise FileNotFoundError("File 'models.json' not found in {"
                                    "}".format(self.models_dir))
        with open(models_file, 'r') as file:
            models_config = json.load(file)

        if not models_config["checkpoints"]:
            raise IndexError("No checkpoints listed in {}".format(models_file))
        last_checkpoint = models_config["checkpoints"][-1]

        generator_save_name = os.path.join(
            self.models_dir, last_checkpoint["generator"])
        discriminator_save_name = os.path.join(
            self.models_dir, last_checkpoint["discriminator"])

        # map_location lets a checkpoint written on a GPU load on a CPU host
        self.generator.load_state_dict(
            torch.load(generator_save_name, map_location=self.device))
        self.discriminator.load_state_dict(
            torch.load(discriminator_save_name, map_location=self.device))

        return (last_checkpoint["block_index"], last_checkpoint["phase"],
                last_checkpoint["step"])

    def _critic_step(self, batch, index, alpha):
        """Update the discriminator on one real and one generated batch.

        Returns:
            Tuple ``(fake_batch, real_loss, fake_loss)``: the detached
            generated batch and the two loss terms as Python floats.
        """
        d_loss_real = -torch.mean(self.discriminator(batch, index, alpha))

        latent = torch.randn(
            batch.size(0), self.num_channels, 1, 1).to(self.device)
        # detached: this step must not back-propagate into the generator
        fake_batch = self.generator(latent, index, alpha).detach()
        d_loss_fake = torch.mean(self.discriminator(fake_batch, index, alpha))

        # drift factor
        drift = d_loss_real.pow(2) + d_loss_fake.pow(2)

        d_loss = d_loss_real + d_loss_fake + self.eps_drift * drift
        self.d_optimizer.zero_grad()
        d_loss.backward()
        self.d_optimizer.step()
        return fake_batch, d_loss_real.item(), d_loss_fake.item()

    def _gradient_penalty_step(self, batch, fake_batch, index, alpha):
        """Update the discriminator with the gradient penalty alone.

        The penalty is ``lambda_gp * mean((||grad||_2 - 1) ** 2)``, with the
        gradient taken at random interpolations of real and generated images.
        """
        alpha_gp = torch.rand(batch.size(0), 1, 1, 1).to(self.device)
        # x_hat is built from tensors that carry no graph (loader batch and
        # detached fakes), so it is a leaf whose gradient can be requested.
        x_hat = (alpha_gp * batch +
                 (1 - alpha_gp) * fake_batch).requires_grad_(True)
        out = self.discriminator(x_hat, index, alpha)
        grad = torch.autograd.grad(
            outputs=out,
            inputs=x_hat,
            grad_outputs=torch.ones_like(out),
            retain_graph=True,
            create_graph=True  # the penalty itself is differentiated below
        )[0]
        # flatten per sample; view() only reshapes, it does not detach
        grad = grad.view(grad.size(0), -1)
        l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        d_loss_gp = self.lambda_gp * torch.mean((l2norm - 1) ** 2)

        self.d_optimizer.zero_grad()
        d_loss_gp.backward()
        self.d_optimizer.step()

    def _generator_step(self, index, alpha):
        """Update the generator once and return its loss as a Python float."""
        latent = torch.randn(
            self.batch_size, self.num_channels, 1, 1).to(self.device)
        fake_batch = self.generator(latent, index, alpha)
        g_loss = -torch.mean(self.discriminator(fake_batch, index, alpha))
        self.g_optimizer.zero_grad()
        g_loss.backward()
        self.g_optimizer.step()
        return g_loss.item()

    def train(self):
        """Run the training loop, resuming from the last checkpoint if asked.

        Every step performs a critic update and a gradient-penalty update;
        the generator is updated on every ``ncritic``-th step.  Progress
        lines, debug image grids and checkpoints are emitted every
        ``log_step`` / ``debug_step`` / ``save_step`` steps respectively.
        """
        # get debugging vectors
        N = (5, 10)
        debug_vectors = torch.randn(N[0] * N[1], self.num_channels, 1,
                                    1).to(self.device)

        # get loader
        loader = get_loader(self.data_path, self.crop_size, self.batch_size)

        # most recent value of every loss; None until first computed
        losses = {
            "d_loss_real": None,
            "d_loss_fake": None,
            "g_loss": None
        }

        # resume training if needed
        if self.resume_training:
            start_index, start_phase, start_step = self.load_trained_networks()
        else:
            start_index, start_phase, start_step = (0, "fade", 0)

        # training loop
        start_time = time.time()
        absolute_step = -1
        for index in range(start_index, self.generator.num_blocks):
            loader.dataset.set_transform_by_index(index)
            data_iterator = iter(loader)
            resuming_here = self.resume_training and index == start_index
            for phase in ('fade', 'stabilize'):
                if index == 0 and phase == 'fade':
                    continue
                # When resuming inside 'stabilize', the 'fade' phase of that
                # block is already done.  Compare by value: the phase read
                # back from JSON is never the same object as the literal.
                if resuming_here and phase == 'fade' and \
                        start_phase == 'stabilize':
                    continue
                print("index: {}, size: {}x{}, phase: {}".format(
                    index, 2 ** (index + 2), 2 ** (index + 2), phase))
                if resuming_here and phase == start_phase:
                    # The checkpoint was written after start_step completed,
                    # so continue with the following step.
                    step_range = range(start_step + 1, self.alternating_step)
                else:
                    step_range = range(self.alternating_step)
                for i in step_range:
                    absolute_step += 1
                    try:
                        batch = next(data_iterator)
                    except StopIteration:
                        # epoch exhausted: start a new pass over the data
                        data_iterator = iter(loader)
                        batch = next(data_iterator)

                    alpha = i / self.alternating_step if phase == "fade" else 1.0

                    batch = batch.to(self.device)

                    (fake_batch, losses["d_loss_real"],
                     losses["d_loss_fake"]) = self._critic_step(
                        batch, index, alpha)

                    # Compute gradient penalty
                    self._gradient_penalty_step(batch, fake_batch, index, alpha)

                    # train generator
                    if (i + 1) % self.ncritic == 0:
                        losses["g_loss"] = self._generator_step(index, alpha)

                    # progress / tensorboard logging
                    if (i + 1) % self.log_step == 0:
                        elapsed = time.time() - start_time
                        elapsed = str(datetime.timedelta(seconds=elapsed))
                        # Read from `losses`: g_loss may not exist yet when a
                        # log step precedes the first generator update.
                        print("{}:{}:{}/{} time {}, d_loss_real {}, "
                              "d_loss_fake {}, "
                              "g_loss {}, alpha {}".format(
                                  index, phase, i, self.alternating_step,
                                  elapsed, losses["d_loss_real"],
                                  losses["d_loss_fake"], losses["g_loss"],
                                  alpha))
                        if self.tflogger is not None:
                            for name, value in losses.items():
                                if value is not None:
                                    self.tflogger.scalar_summary(
                                        name, value, absolute_step)

                    # print debugging images
                    if (i + 1) % self.debug_step == 0:
                        self.print_debugging_images(
                            self.generator, debug_vectors, N, index, alpha, i)

                    # save trained networks
                    if (i + 1) % self.save_step == 0:
                        self.save_trained_networks(index, phase, i)
Пример #2
0
def train(data_path,
          crop_size=128,
          final_size=64,
          batch_size=16,
          alternating_step=10000,
          ncritic=1,
          lambda_gp=0.1,
          debug_step=100):
    """Train a progressively grown generator/discriminator pair.

    Every block of the generator runs a 'fade' phase (``alpha`` ramps from 0
    towards 1) followed by a 'stabilize' phase (``alpha`` fixed at 1.0);
    block 0 has no 'fade' phase.  Relies on the module-level ``device`` and
    ``print_debugging_images`` and publishes the module-level ``upsample``
    list.

    Args:
        data_path: forwarded to ``get_loader``.
        crop_size: forwarded to ``get_loader``.
        final_size: forwarded to the ``Generator`` / ``Discriminator``
            constructors.
        batch_size: loader batch size; also the size of the generator batch.
        alternating_step: number of steps in every phase.
        ncritic: the generator is updated on every ``ncritic``-th step.
        lambda_gp: weight of the gradient penalty.
        debug_step: a debug image grid is written every ``debug_step`` steps.
    """
    # define networks
    generator = Generator(final_size=final_size)
    generator.generate_network()
    # Inputs are moved to `device`, so the weights must live there too
    # (no-op if generate_network() already did this).
    generator.to(device)
    g_optimizer = Adam(generator.parameters())

    discriminator = Discriminator(final_size=final_size)
    discriminator.generate_network()
    discriminator.to(device)
    d_optimizer = Adam(discriminator.parameters())

    # size of the channel dimension of the latent vectors
    num_channels = min(generator.num_channels, generator.max_channels)

    # get debugging vectors
    N = (5, 10)
    debug_vectors = torch.randn(N[0] * N[1], num_channels, 1, 1).to(device)
    # Published as a global; presumably read by print_debugging_images --
    # TODO confirm against that function.
    global upsample
    upsample = [
        Upsample(scale_factor=2**i)
        for i in reversed(range(generator.num_blocks))
    ]

    # get loader
    loader = get_loader(data_path, crop_size, batch_size)

    # training loop
    for index in range(generator.num_blocks):
        loader.dataset.set_transform_by_index(index)
        data_iterator = iter(loader)
        for phase in ('fade', 'stabilize'):
            if index == 0 and phase == 'fade':
                continue
            print("index: {}, size: {}x{}, phase: {}".format(
                index, 2**(index + 2), 2**(index + 2), phase))
            for i in range(alternating_step):
                print(i)  # per-step progress trace
                try:
                    batch = next(data_iterator)
                except StopIteration:
                    # epoch exhausted: start a new pass over the data
                    data_iterator = iter(loader)
                    batch = next(data_iterator)

                alpha = i / alternating_step if phase == "fade" else 1.0

                batch = batch.to(device)

                d_loss_real = -torch.mean(discriminator(batch, index, alpha))

                # Match the real batch size: the loader's last batch may be
                # shorter than batch_size, and the interpolation below mixes
                # real and fake samples element-wise.
                latent = torch.randn(batch.size(0), num_channels, 1,
                                     1).to(device)
                # detached: this step must not back-propagate into the generator
                fake_batch = generator(latent, index, alpha).detach()
                d_loss_fake = torch.mean(
                    discriminator(fake_batch, index, alpha))

                d_loss = d_loss_real + d_loss_fake
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

                # Compute gradient penalty
                alpha_gp = torch.rand(batch.size(0), 1, 1, 1).to(device)
                # x_hat is built from tensors that carry no graph (loader
                # batch and detached fakes), so it is a leaf whose gradient
                # can be requested.
                x_hat = (alpha_gp * batch +
                         (1 - alpha_gp) * fake_batch).requires_grad_(True)
                out = discriminator(x_hat, index, alpha)
                grad = torch.autograd.grad(
                    outputs=out,
                    inputs=x_hat,
                    grad_outputs=torch.ones_like(out),
                    retain_graph=True,
                    create_graph=True)[0]  # penalty is differentiated below
                # flatten per sample; view() only reshapes, it does not detach
                grad = grad.view(grad.size(0), -1)
                l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = lambda_gp * torch.mean((l2norm - 1)**2)

                d_optimizer.zero_grad()
                d_loss_gp.backward()
                d_optimizer.step()

                # train generator
                if (i + 1) % ncritic == 0:
                    latent = torch.randn(batch_size, num_channels, 1,
                                         1).to(device)
                    fake_batch = generator(latent, index, alpha)
                    g_loss = -torch.mean(
                        discriminator(fake_batch, index, alpha))
                    g_optimizer.zero_grad()
                    g_loss.backward()
                    g_optimizer.step()

                # print debugging images
                if (i + 1) % debug_step == 0:
                    print_debugging_images(generator, debug_vectors, N, index,
                                           alpha, i)