# Example no. 1
# 0
def train(cfg, logger, local_rank, world_size, distributed):
    """Train the encoder / decoder / discriminator model on one process.

    Builds the trainable ``Model`` and, on rank 0 only, a frozen second copy
    ``model_s`` that is pulled towards the trained weights with
    ``model_s.lerp`` after every batch (presumably a running average of the
    weights -- confirm against ``Model.lerp``). Three ``LREQAdam`` optimizers
    (decoder + mapping, encoder, discriminator), a ``ComboMultiStepLR``
    scheduler and a ``Checkpointer`` are created, an existing checkpoint is
    restored, and then every batch goes through three updates in this order:
    discriminator, generator, autoencoder.

    Args:
        cfg: Configuration object; the ``MODEL``, ``TRAIN``, ``DATASET`` and
            ``OUTPUT_DIR`` entries are read.
        logger: Logger used for progress messages (only ``info`` is called).
        local_rank: CUDA device index of this process. Rank 0 additionally
            owns ``model_s``, checkpoint saving and sample reports.
        world_size: Number of participating processes; forwarded to the
            dataset and to the LOD driver.
        distributed: When true the model is wrapped in
            ``nn.parallel.DistributedDataParallel``.
    """
    torch.cuda.set_device(local_rank)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        # Frozen copy used for checkpoints and sample reports; never optimized
        # directly, only updated through lerp() in the batch loop.
        model_s = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                        layer_count=cfg.MODEL.LAYER_COUNT,
                        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                        channels=cfg.MODEL.CHANNELS,
                        generator=cfg.MODEL.GENERATOR,
                        encoder=cfg.MODEL.ENCODER)
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True)
        model.device_ids = None

    # Under DDP the sub-modules live on the wrapped ``module`` attribute.
    core = model.module if distributed else model
    decoder = core.decoder
    encoder = core.encoder
    discriminator = core.discriminator
    mapping_fl = core.mapping_fl
    dlatent_avg = core.dlatent_avg

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    # The original logged the encoder under the "discriminator" label and
    # never reported the separate discriminator module.
    logger.info("Trainable parameters encoder:")
    count_parameters(encoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(discriminator)

    arguments = dict()
    arguments["iteration"] = 0

    decoder_optimizer = LREQAdam([{
        'params': decoder.parameters()
    }, {
        'params': mapping_fl.parameters()
    }],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    encoder_optimizer = LREQAdam([
        {
            'params': encoder.parameters()
        },
    ],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    discriminator_optimizer = LREQAdam([
        {
            'params': discriminator.parameters()
        },
    ],
                                       lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                       betas=(cfg.TRAIN.ADAM_BETA_0,
                                              cfg.TRAIN.ADAM_BETA_1),
                                       weight_decay=0)

    scheduler = ComboMultiStepLR(optimizers={
        'encoder_optimizer': encoder_optimizer,
        'discriminator_optimizer': discriminator_optimizer,
        'decoder_optimizer': decoder_optimizer
    },
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
                                 reference_batch_size=32,
                                 base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'discriminator': discriminator,
        'encoder': encoder,
        'generator': decoder,
        'mapping_fl': mapping_fl,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['discriminator_s'] = model_s.discriminator
        model_dict['encoder_s'] = model_s.encoder
        model_dict['generator_s'] = model_s.decoder
        model_dict['mapping_fl_s'] = model_s.mapping_fl

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(
        cfg,
        model_dict, {
            'encoder_optimizer': encoder_optimizer,
            'discriminator_optimizer': discriminator_optimizer,
            'decoder_optimizer': decoder_optimizer,
            'scheduler': scheduler,
            'tracker': tracker
        },
        logger=logger,
        save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = decoder.layer_to_resolution

    dataset = TFRecordsDataset(cfg,
                               logger,
                               rank=local_rank,
                               world_size=world_size,
                               buffer_size_mb=1024,
                               channels=cfg.MODEL.CHANNELS)

    # Fixed seed so the latent vectors used for sample reports are the same
    # on every run.
    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg,
                                     logger,
                                     world_size,
                                     dataset_size=len(dataset) * world_size)

    if cfg.DATASET.SAMPLES_PATH:
        path = cfg.DATASET.SAMPLES_PATH
        src = []
        with torch.no_grad():
            # os.listdir() returns entries in arbitrary order; sort so the
            # same (up to) 32 reference images are picked on every run.
            for filename in sorted(os.listdir(path))[:32]:
                with Image.open(os.path.join(path, filename)) as handle:
                    img = np.asarray(handle)
                if img.ndim == 2:
                    # Greyscale image: add the channel axis that the
                    # transpose below expects (the original raised
                    # IndexError on ``img.shape[2]``).
                    img = img[:, :, np.newaxis]
                elif img.shape[2] == 4:
                    # Drop the alpha channel.
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))
                # No gradient is tracked inside no_grad(), so the tensor is
                # created without requires_grad.
                image_tensor = torch.tensor(
                    np.asarray(im, dtype=np.float32)).cuda() / 127.5 - 1.
                src.append(image_tensor)
            sample = torch.stack(src)
    else:
        dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32)
        sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank))
        sample = (sample / 127.5 - 1.)

    # NOTE(review): discriminator_optimizer is not passed to set_epoch here
    # or in the loop below -- confirm this is intentional.
    lod2batch.set_epoch(scheduler.start_epoch(),
                        [encoder_optimizer, decoder_optimizer])

    # Last processed batch; stays None until a full-size batch has been seen,
    # so the end-of-epoch report can be skipped instead of raising NameError.
    x = None

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [encoder_optimizer, decoder_optimizer])

        # Default for the end-of-epoch checkpoint name in case the epoch
        # yields no full-size batch; overwritten on every processed batch.
        lod_for_saving_model = lod2batch.lod

        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d"
            % (lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(),
               lod2batch.lod, 2**lod2batch.get_lod_power2(), 2**
               lod2batch.get_lod_power2(), lod2batch.get_blend_factor(),
               len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(),
                      lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset,
                                  lod2batch.get_per_GPU_batch_size(),
                                  local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        need_permute = False

        # Profiler is kept disabled; flip ``enabled`` to collect timings.
        with torch.autograd.profiler.profile(use_cuda=True, enabled=False):
            for x_orig in tqdm(batches):
                with torch.no_grad():
                    # Skip incomplete batches (typically the last one).
                    if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                        continue
                    if need_permute:
                        x_orig = x_orig.permute(0, 3, 1, 2)
                    x_orig = (x_orig / 127.5 - 1.)

                    blend_factor = lod2batch.get_blend_factor()

                    needed_resolution = layer_to_resolution[lod2batch.lod]
                    x = x_orig

                    if lod2batch.in_transition:
                        # Blend the full-resolution batch with a copy that
                        # was downsampled by 2 and resized back.
                        x_prev = F.avg_pool2d(x_orig, 2, 2)
                        x_prev_2x = F.interpolate(x_prev, needed_resolution)
                        x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

                x.requires_grad = True

                # ---- discriminator update ----
                loss_d = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=True,
                               ae=False)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()
                # Encoder gradients are not cleared here; they are only
                # zeroed after the autoencoder update below.
                decoder_optimizer.zero_grad()
                discriminator_optimizer.zero_grad()

                # ---- generator update ----
                loss_g = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=False,
                               ae=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                decoder_optimizer.step()
                decoder_optimizer.zero_grad()
                discriminator_optimizer.zero_grad()

                # ---- autoencoder update (encoder and decoder) ----
                lae = model(x,
                            lod2batch.lod,
                            blend_factor,
                            d_train=True,
                            ae=True)
                tracker.update(dict(lae=lae))
                (lae).backward()
                encoder_optimizer.step()
                decoder_optimizer.step()
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()

                if local_rank == 0:
                    betta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, betta)

                # Remember the LOD before step() may advance it, so the
                # checkpoint is named after the LOD it was trained at.
                lod_for_saving_model = lod2batch.lod
                lod2batch.step()
                if local_rank == 0:
                    if lod2batch.is_time_to_save():
                        checkpointer.save("model_tmp_intermediate_lod%d" %
                                          lod_for_saving_model)
                    if lod2batch.is_time_to_report():
                        save_sample(lod2batch, tracker, sample, samplez, x,
                                    logger, model_s, cfg, encoder_optimizer,
                                    decoder_optimizer)

        scheduler.step()

        if local_rank == 0:
            checkpointer.save("model_tmp_lod%d" % lod_for_saving_model)
            if x is None:
                logger.info(
                    "No full-size batch processed yet; skipping sample report")
            else:
                save_sample(lod2batch, tracker, sample, samplez, x, logger,
                            model_s, cfg, encoder_optimizer,
                            decoder_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final").wait()
# Example no. 2
# 0
def train(cfg, logger, local_rank, world_size, distributed):
    """Train the single-channel generator / discriminator model on MNIST.

    Creates the trainable ``Model`` (and, on rank 0, a frozen copy
    ``model_s`` without discriminator that is updated via ``lerp``), two
    ``LREQAdam`` optimizers, a ``ComboMultiStepLR`` scheduler and a
    ``Checkpointer``; restores any checkpoint; then alternates one
    discriminator and one generator update per batch.

    Args:
        cfg: Configuration object; ``MODEL``, ``TRAIN`` and ``OUTPUT_DIR``
            entries are read.
        logger: Logger used for progress output (only ``info`` is called).
        local_rank: CUDA device index of this process; rank 0 also owns
            ``model_s``, sample reports and checkpoint saving.
        world_size: Number of processes, forwarded to the LOD driver.
        distributed: Wrap the model in ``DistributedDataParallel`` when true.
    """
    torch.cuda.set_device(local_rank)

    # Constructor arguments common to the trained model and its rank-0 copy.
    shared_model_kwargs = dict(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=1,
    )

    model = Model(dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  **shared_model_kwargs)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        model_s = Model(truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                        **shared_model_kwargs)
        del model_s.discriminator
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True,
        )
        model.device_ids = None

    # With DDP the real network sits behind ``.module``.
    net = model.module if distributed else model
    generator = net.generator
    discriminator = net.discriminator
    mapping = net.mapping
    dlatent_avg = net.dlatent_avg

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(generator)

    logger.info("Trainable parameters discriminator:")
    count_parameters(discriminator)

    arguments = {"iteration": 0}

    # Both optimizers share the same hyper-parameters.
    adam_kwargs = dict(
        lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1),
        weight_decay=0,
    )
    generator_optimizer = LREQAdam(
        [{'params': generator.parameters()},
         {'params': mapping.parameters()}],
        **adam_kwargs)
    discriminator_optimizer = LREQAdam(
        [{'params': discriminator.parameters()}],
        **adam_kwargs)

    scheduler = ComboMultiStepLR(
        optimizers={
            'generator': generator_optimizer,
            'discriminator': discriminator_optimizer,
        },
        milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
        gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
        reference_batch_size=32,
    )

    model_dict = {
        'discriminator': discriminator,
        'generator': generator,
        'mapping': mapping,
        'dlatent_avg': dlatent_avg,
    }
    if local_rank == 0:
        model_dict.update({
            'generator_s': model_s.generator,
            'mapping_s': model_s.mapping,
        })

    tracker = LossTracker(cfg.OUTPUT_DIR)

    auxiliary = {
        'generator_optimizer': generator_optimizer,
        'discriminator_optimizer': discriminator_optimizer,
        'tracker': tracker,
        'scheduler': scheduler,
    }
    checkpointer = Checkpointer(cfg,
                                model_dict,
                                auxiliary,
                                logger=logger,
                                save=local_rank == 0)

    arguments.update(checkpointer.load())

    layer_to_resolution = generator.layer_to_resolution

    dlutils.download.mnist()
    mnist = dlutils.reader.Mnist('mnist').items
    # Keep only element [1] of every item as one float32 array.
    mnist = np.asarray([item[1] for item in mnist], np.float32)

    # Fixed seed: the latents used for sample reports are the same each run.
    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg,
                                     logger,
                                     world_size,
                                     dataset_size=len(mnist))

    class BatchCollator(object):
        """Turn a raw batch into a float32 tensor on the target device."""

        def __init__(self, device=torch.device("cpu")):
            self.device = device

        def __call__(self, batch):
            with torch.no_grad():
                return torch.tensor(batch,
                                    requires_grad=True,
                                    device=torch.device(self.device),
                                    dtype=torch.float32)

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch,
                            [generator_optimizer, discriminator_optimizer])

        batch_total = lod2batch.get_batch_size()
        batch_per_gpu = lod2batch.get_per_GPU_batch_size()
        resolution = 2**lod2batch.get_lod_power2()
        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, dataset size: %d"
            % (batch_total, batch_per_gpu, lod2batch.lod, resolution,
               resolution, len(mnist)))

        dlutils.shuffle.shuffle_ndarray(mnist)

        # Resize the 28x28 digits to the resolution of the current LOD.
        as_images = torch.tensor(mnist).view(mnist.shape[0], 1, 28, 28)
        resized = F.interpolate(as_images,
                                resolution).detach().cpu().numpy()

        scheduler.set_batch_size(32)

        model.train()

        need_permute = False

        batches = dlutils.batch_provider(resized,
                                         lod2batch.get_per_GPU_batch_size(),
                                         BatchCollator(local_rank),
                                         report_progress=False)

        # Profiler stays disabled; set ``enabled=True`` and inspect ``prof``
        # to collect timings.
        with torch.autograd.profiler.profile(use_cuda=True,
                                             enabled=False) as prof:
            for x_orig in tqdm(batches):
                # Skip incomplete batches.
                if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                    continue
                if need_permute:
                    x_orig = x_orig.permute(0, 3, 1, 2)
                x_orig = x_orig / 127.5 - 1.

                blend_factor = lod2batch.get_blend_factor()

                needed_resolution = layer_to_resolution[lod2batch.lod]
                x = x_orig

                if lod2batch.in_transition:
                    needed_resolution_prev = layer_to_resolution[
                        lod2batch.lod - 1]
                    # Blend with a copy downsampled by 2 and resized back.
                    downsampled = F.avg_pool2d(x_orig, 2, 2)
                    upsampled = F.interpolate(downsampled, needed_resolution)
                    x = x * blend_factor + upsampled * (1.0 - blend_factor)

                # ---- discriminator update ----
                discriminator_optimizer.zero_grad()
                loss_d = model(x, lod2batch.lod, blend_factor, d_train=True)
                tracker.update({'loss_d': loss_d})
                loss_d.backward()
                discriminator_optimizer.step()

                if local_rank == 0:
                    betta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, betta)

                # ---- generator update ----
                generator_optimizer.zero_grad()
                loss_g = model(x, lod2batch.lod, blend_factor, d_train=False)
                tracker.update({'loss_g': loss_g})
                loss_g.backward()
                generator_optimizer.step()

                lod2batch.step()
                if local_rank == 0 and lod2batch.is_time_to_report():
                    save_sample(lod2batch, tracker, sample, x, logger,
                                model_s, cfg, discriminator_optimizer,
                                generator_optimizer)

        if local_rank == 0:
            save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg,
                        discriminator_optimizer, generator_optimizer)
            # Temporary checkpoints are written from the fourth epoch on.
            if epoch > 2:
                checkpointer.save("model_tmp")

        scheduler.step()

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final")
def train(cfg, logger, local_rank, world_size, distributed):
    """Train a ``SoftIntroVAEModelTL`` encoder/decoder pair on one process.

    Builds the trainable model and, on rank 0, a frozen copy ``model_s`` that
    is updated through ``model_s.lerp`` after every batch. For the first
    ``cfg.TRAIN.NUM_VAE`` epochs a single loss updates encoder and decoder
    together; afterwards the encoder and the decoder are updated one after
    the other with separate losses. On rank 0, FID is evaluated every 10th
    epoch once ``epoch > EPOCHS_PER_LOD * (LAYER_COUNT - 1)``.

    Args:
        cfg: Configuration object; ``MODEL``, ``TRAIN``, ``DATASET`` and
            ``OUTPUT_DIR`` entries are read.
        logger: Logger used for progress messages (only ``info`` is called);
            several messages also go to stdout via ``print``.
        local_rank: CUDA device index of this process; rank 0 also owns
            ``model_s``, FID evaluation, sample reports and checkpoints.
        world_size: Number of processes; forwarded to the dataset and the
            LOD driver.
        distributed: Wrap the model in ``DistributedDataParallel`` when true.
    """
    torch.cuda.set_device(local_rank)
    # beta_neg is initialised from the entry for the last layer; it is
    # re-selected per LOD at the start of every epoch below.
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        # Frozen copy kept only on rank 0; it is never optimized directly,
        # only updated through lerp() in the batch loop.
        model_s = SoftIntroVAEModelTL(
            startf=cfg.MODEL.START_CHANNEL_COUNT,
            layer_count=cfg.MODEL.LAYER_COUNT,
            maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
            latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
            truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
            truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
            mapping_layers=cfg.MODEL.MAPPING_LAYERS,
            channels=cfg.MODEL.CHANNELS,
            generator=cfg.MODEL.GENERATOR,
            encoder=cfg.MODEL.ENCODER,
            beta_kl=cfg.MODEL.BETA_KL,
            beta_rec=cfg.MODEL.BETA_REC,
            beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
            scale=cfg.MODEL.SCALE)
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True)
        model.device_ids = None

        # Under DDP the sub-modules are reached through ``model.module``.
        decoder = model.module.decoder
        encoder = model.module.encoder
        mapping_tl = model.module.mapping_tl
        mapping_fl = model.module.mapping_fl
        dlatent_avg = model.module.dlatent_avg
    else:
        decoder = model.decoder
        encoder = model.encoder
        mapping_tl = model.mapping_tl
        mapping_fl = model.mapping_fl
        dlatent_avg = model.dlatent_avg

    # NOTE(review): this sets a ``print`` attribute on ``count_parameters``
    # itself and the return values below are printed to stdout -- confirm
    # this matches the count_parameters helper this file imports.
    count_parameters.print = lambda a: logger.info(a)

    # Number of initial epochs trained with the joint (single-loss) update.
    num_vae_epochs = cfg.TRAIN.NUM_VAE

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))

    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    arguments = dict()
    arguments["iteration"] = 0

    # The decoder optimizer also owns the mapping_fl parameters.
    decoder_optimizer = LREQAdam([{
        'params': decoder.parameters()
    }, {
        'params': mapping_fl.parameters()
    }],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    # The encoder optimizer also owns the mapping_tl parameters.
    encoder_optimizer = LREQAdam([{
        'params': encoder.parameters()
    }, {
        'params': mapping_tl.parameters()
    }],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    scheduler = ComboMultiStepLR(optimizers={
        'encoder_optimizer': encoder_optimizer,
        'decoder_optimizer': decoder_optimizer
    },
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
                                 reference_batch_size=32,
                                 base_lr=cfg.TRAIN.LEARNING_RATES)

    # Checkpoint layout: the encoder is stored under the 'discriminator' key
    # and the decoder under 'generator'.
    model_dict = {
        'discriminator': encoder,
        'generator': decoder,
        'mapping_fl': mapping_fl,
        'mapping_tl': mapping_tl,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['discriminator_s'] = model_s.encoder
        model_dict['generator_s'] = model_s.decoder
        model_dict['mapping_fl_s'] = model_s.mapping_fl
        model_dict['mapping_tl_s'] = model_s.mapping_tl

    tracker = LossTracker(cfg.OUTPUT_DIR)

    # Only rank 0 is allowed to write checkpoints (``save=local_rank == 0``).
    checkpointer = Checkpointer(cfg,
                                model_dict, {
                                    'encoder_optimizer': encoder_optimizer,
                                    'decoder_optimizer': decoder_optimizer,
                                    'scheduler': scheduler,
                                    'tracker': tracker
                                },
                                logger=logger,
                                save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = decoder.layer_to_resolution

    dataset = TFRecordsDataset(cfg,
                               logger,
                               rank=local_rank,
                               world_size=world_size,
                               buffer_size_mb=1024,
                               channels=cfg.MODEL.CHANNELS)

    # Fixed seed so the 32 latent vectors used for sample reports are the
    # same on every run.
    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg,
                                     logger,
                                     world_size,
                                     dataset_size=len(dataset) * world_size)

    # Reference images for sample reports: either up to 32 files from
    # SAMPLES_PATH or one batch of 32 from the dataset.
    if cfg.DATASET.SAMPLES_PATH:
        path = cfg.DATASET.SAMPLES_PATH
        src = []
        with torch.no_grad():
            # NOTE(review): os.listdir order is arbitrary, so the chosen
            # files may differ between runs; 2-D (greyscale) arrays would
            # fail on ``img.shape[2]``.
            for filename in list(os.listdir(path))[:32]:
                img = np.asarray(Image.open(os.path.join(path, filename)))
                # Drop the alpha channel of 4-channel images.
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                # Move the channel axis first and rescale by /127.5 - 1.
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32),
                                 requires_grad=True).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                src.append(x)
            sample = torch.stack(src)
    else:
        dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32)
        sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank))
        sample = (sample / 127.5 - 1.)

    lod2batch.set_epoch(scheduler.start_epoch(),
                        [encoder_optimizer, decoder_optimizer])

    # Per-epoch means collected after the joint-training phase; they are only
    # appended to in this function, never read.
    kls_real = []
    kls_fake = []
    rec_errs = []

    # Best FID seen so far; None until the first evaluation.
    best_fid = None
    # best_fid = 20.0

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        # Select beta_neg for the current LOD. Note this reads lod2batch.lod
        # before set_epoch(epoch, ...) is called further down.
        new_beta_neg = cfg.MODEL.BETA_NEG[lod2batch.lod]
        if distributed:
            if model.module.beta_neg != new_beta_neg:
                model.module.beta_neg = new_beta_neg
                print("beta negative changed to:", new_beta_neg)
        else:
            if model.beta_neg != new_beta_neg:
                model.beta_neg = new_beta_neg
                print("beta negative changed to:", new_beta_neg)
        # FID evaluation on rank 0 only, every 10th epoch past the threshold.
        if (epoch > cfg.TRAIN.EPOCHS_PER_LOD *
            (cfg.MODEL.LAYER_COUNT - 1)) and (epoch % 10 == 0) and (local_rank
                                                                    == 0):
            print("calculating fid...")
            fid = calc_fid_from_dataset_generate(
                cfg,
                dataset,
                model_s,
                batch_size=50,
                cuda=1,
                dims=2048,
                device=torch.device(
                    "cuda:0" if torch.cuda.is_available() else "cpu"),
                num_images=50000)
            print("epoch: {}, fid: {}".format(epoch, fid))
            # The first measurement only sets the baseline; a checkpoint is
            # written only when a later measurement improves on it.
            if best_fid is None:
                best_fid = fid
            elif fid < best_fid:
                print("best fid updated: {} -> {}".format(best_fid, fid))
                best_fid = fid
                checkpointer.save("model_tmp_lod{}_fid_{}".format(
                    lod2batch.lod, fid))
        lod2batch.set_epoch(epoch, [encoder_optimizer, decoder_optimizer])

        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d"
            % (lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(),
               lod2batch.lod, 2**lod2batch.get_lod_power2(), 2**
               lod2batch.get_lod_power2(), lod2batch.get_blend_factor(),
               len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(),
                      lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset,
                                  lod2batch.get_per_GPU_batch_size(),
                                  local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()

        # Constant switch: the permute branch below is never taken.
        need_permute = False
        epoch_start_time = time.time()

        i = 0

        # Per-batch statistics for this epoch.
        diff_kls = []
        batch_kls_real = []
        batch_kls_fake = []
        batch_rec_errs = []

        for x_orig in tqdm(batches):
            i += 1
            with torch.no_grad():
                # Skip incomplete batches.
                if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                    continue
                if need_permute:
                    x_orig = x_orig.permute(0, 3, 1, 2)
                x_orig = (x_orig / 127.5 - 1.)

                blend_factor = lod2batch.get_blend_factor()

                needed_resolution = layer_to_resolution[lod2batch.lod]
                x = x_orig

                if lod2batch.in_transition:
                    # needed_resolution_prev is computed but not used.
                    needed_resolution_prev = layer_to_resolution[lod2batch.lod
                                                                 - 1]
                    # Blend the batch with a copy that was downsampled by 2
                    # and resized back to the current resolution.
                    x_prev = F.avg_pool2d(x_orig, 2, 2)
                    x_prev_2x = F.interpolate(x_prev, needed_resolution)
                    x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

            x.requires_grad = True
            if epoch < num_vae_epochs:
                # Joint phase: one loss, both optimizers step; the same value
                # is tracked as loss_e and loss_d.
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()
                loss = model(x,
                             lod2batch.lod,
                             blend_factor,
                             d_train=False,
                             e_train=False)
                tracker.update(dict(loss_e=loss))
                tracker.update(dict(loss_d=loss))
                loss.backward()
                encoder_optimizer.step()
                decoder_optimizer.step()
            else:
                # ------------- Update Encoder ------------- #
                encoder_optimizer.zero_grad()
                loss_e = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=False,
                               e_train=True)
                tracker.update(dict(loss_e=loss_e))
                loss_e.backward()
                encoder_optimizer.step()

                # ------------- Update Decoder ------------- #
                decoder_optimizer.zero_grad()
                loss_d = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=True,
                               e_train=False)
                loss_d.backward()
                tracker.update(dict(loss_d=loss_d))
                decoder_optimizer.step()

            # ------------- Update Statistics ------------- #
            # The ``last_*`` attributes are read from the underlying model;
            # presumably they are set during the forward calls above -- TODO
            # confirm in SoftIntroVAEModelTL.
            if distributed:
                tracker.update(dict(rec_loss=model.module.last_rec_loss))
                tracker.update(dict(real_kl=model.module.last_real_kl))
                tracker.update(dict(fake_kl=model.module.last_fake_kl))
                tracker.update(dict(kl_diff=model.module.last_kl_diff))
                tracker.update(dict(expelbo_f=model.module.last_expelbo_fake))
                tracker.update(dict(expelbo_r=model.module.last_expelbo_rec))

                diff_kls.append(model.module.last_kl_diff.data.cpu())
                batch_kls_real.append(model.module.last_real_kl)
                batch_kls_fake.append(model.module.last_fake_kl)
                batch_rec_errs.append(model.module.last_rec_loss)
            else:
                tracker.update(dict(rec_loss=model.last_rec_loss))
                tracker.update(dict(real_kl=model.last_real_kl))
                tracker.update(dict(fake_kl=model.last_fake_kl))
                tracker.update(dict(kl_diff=model.last_kl_diff))
                tracker.update(dict(expelbo_f=model.last_expelbo_fake))
                tracker.update(dict(expelbo_r=model.last_expelbo_rec))

                diff_kls.append(model.last_kl_diff.data.cpu())
                batch_kls_real.append(model.last_real_kl)
                batch_kls_fake.append(model.last_fake_kl)
                batch_rec_errs.append(model.last_rec_loss)

            if local_rank == 0:
                # Pull model_s towards the trained weights; the factor
                # depends on the current total batch size.
                betta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                model_s.lerp(model, betta)

            # Elapsed time is computed but not used afterwards.
            epoch_end_time = time.time()
            per_epoch_ptime = epoch_end_time - epoch_start_time

            # Capture the LOD before step() so checkpoint names refer to the
            # LOD this batch was trained at.
            lod_for_saving_model = lod2batch.lod
            lod2batch.step()
            if local_rank == 0:
                if lod2batch.is_time_to_save():
                    checkpointer.save("model_tmp_intermediate_lod%d" %
                                      lod_for_saving_model)
                if lod2batch.is_time_to_report():
                    save_sample(lod2batch, tracker, sample, samplez, x, logger,
                                model_s, cfg, encoder_optimizer,
                                decoder_optimizer)

        scheduler.step()
        # NOTE(review): if no full-size batch was processed this epoch,
        # diff_kls is empty here and, on the first epoch,
        # lod_for_saving_model is unbound below.
        mean_diff_kl = np.mean(diff_kls)
        print("mean diff kl: ", mean_diff_kl)

        # Record epoch means only once the joint phase is over.
        if epoch > num_vae_epochs - 1:
            kls_real.append(np.mean(batch_kls_real))
            kls_fake.append(np.mean(batch_kls_fake))
            rec_errs.append(np.mean(batch_rec_errs))

        if local_rank == 0:
            checkpointer.save("model_tmp_lod%d" % lod_for_saving_model)
            save_sample(lod2batch, tracker, sample, samplez, x, logger,
                        model_s, cfg, encoder_optimizer, decoder_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        # Block on the result of the final save via its wait() method.
        checkpointer.save("model_final").wait()
# Example no. 4
# 0
def train(cfg, logger, local_rank, world_size, distributed):
    """Run the adversarial training loop for the generator/discriminator.

    Builds the model (plus, on rank 0, a gradient-free copy ``model_s`` that
    is updated through ``lerp``), two ``LREQAdam`` optimizers, a
    ``ComboMultiStepLR`` scheduler and a ``Checkpointer``; calls
    ``checkpointer.load()`` and then iterates from ``scheduler.start_epoch()``
    up to ``cfg.TRAIN.TRAIN_EPOCHS``, alternating one discriminator step and
    one generator step per batch.

    Args:
        cfg: Configuration object; ``cfg.MODEL.*``, ``cfg.TRAIN.*`` and
            ``cfg.OUTPUT_DIR`` are read here and ``cfg`` is forwarded to the
            dataset, LOD driver, checkpointer and ``save_sample``.
        logger: Logger used for progress messages and forwarded to helpers.
        local_rank: Rank of this process; used as the CUDA device index.
            Only rank 0 owns ``model_s``, saves checkpoints and samples.
        world_size: Number of processes; forwarded to the dataset and the
            LOD driver and used to report the total dataset size.
        distributed: When true, the model is wrapped in
            ``DistributedDataParallel``.
    """
    torch.cuda.set_device(local_rank)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=3)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        # Rank 0 keeps a second, frozen model that follows the trained one
        # via lerp(); its generator/mapping are checkpointed and used for
        # sample generation. It never needs a discriminator.
        model_s = Model(
            startf=cfg.MODEL.START_CHANNEL_COUNT,
            layer_count=cfg.MODEL.LAYER_COUNT,
            maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
            latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
            truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
            truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
            mapping_layers=cfg.MODEL.MAPPING_LAYERS,
            channels=3)
        del model_s.discriminator
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True)
        # NOTE(review): device_ids is cleared right after wrapping; the
        # reason is not visible in this file -- confirm before changing.
        model.device_ids = None

        # Sub-modules live on the wrapped module when DDP is in use.
        generator = model.module.generator
        discriminator = model.module.discriminator
        mapping = model.module.mapping
        dlatent_avg = model.module.dlatent_avg
    else:
        generator = model.generator
        discriminator = model.discriminator
        mapping = model.mapping
        dlatent_avg = model.dlatent_avg

    # Route the parameter-count report through the logger.
    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(generator)

    logger.info("Trainable parameters discriminator:")
    count_parameters(discriminator)

    adam_betas = (cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1)

    # The mapping network is optimized together with the generator.
    generator_optimizer = LREQAdam(
        [
            {'params': generator.parameters()},
            {'params': mapping.parameters()},
        ],
        lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=adam_betas,
        weight_decay=0)

    discriminator_optimizer = LREQAdam(
        [
            {'params': discriminator.parameters()},
        ],
        lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=adam_betas,
        weight_decay=0)

    scheduler = ComboMultiStepLR(
        optimizers={
            'generator': generator_optimizer,
            'discriminator': discriminator_optimizer
        },
        milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
        gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
        reference_batch_size=32,
        base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'discriminator': discriminator,
        'generator': generator,
        'mapping': mapping,
        'dlatent_avg': dlatent_avg
    }

    if local_rank == 0:
        model_dict['generator_s'] = model_s.generator
        model_dict['mapping_s'] = model_s.mapping

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(
        cfg,
        model_dict,
        {
            'generator_optimizer': generator_optimizer,
            'discriminator_optimizer': discriminator_optimizer,
            'scheduler': scheduler,
            'tracker': tracker
        },
        logger=logger,
        save=local_rank == 0)

    checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    layer_to_resolution = generator.layer_to_resolution

    dataset = TFRecordsDataset(cfg,
                               logger,
                               rank=local_rank,
                               world_size=world_size,
                               buffer_size_mb=1024)

    # Fixed-seed latents so that saved samples use the same inputs each time.
    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg,
                                     logger,
                                     world_size,
                                     dataset_size=len(dataset) * world_size)

    # barrier() raises when no default process group exists (for example a
    # single-process run), so only synchronize when one has been initialized.
    sync_ranks = (torch.distributed.is_available()
                  and torch.distributed.is_initialized())

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch,
                            [generator_optimizer, discriminator_optimizer])

        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d"
            % (lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(),
               lod2batch.lod, 2**lod2batch.get_lod_power2(), 2**
               lod2batch.get_lod_power2(), lod2batch.get_blend_factor(),
               len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(),
                      lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset,
                                  lod2batch.get_per_GPU_batch_size(),
                                  local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        # Set to True if batches arrive channels-last (N, H, W, C).
        need_permute = False

        # Profiling hook, disabled by default: flip `enabled` and print
        # prof.key_averages().table(sort_by="self_cpu_time_total") to use it.
        with torch.autograd.profiler.profile(use_cuda=True,
                                             enabled=False) as prof:
            for x_orig in tqdm(batches):
                if sync_ranks:
                    torch.distributed.barrier()

                with torch.no_grad():
                    # Skip batches that do not have the expected size.
                    if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                        continue
                    if need_permute:
                        x_orig = x_orig.permute(0, 3, 1, 2)
                    # Assumes pixel values in [0, 255]; maps them to [-1, 1].
                    x_orig = (x_orig / 127.5 - 1.)

                    blend_factor = lod2batch.get_blend_factor()

                    needed_resolution = layer_to_resolution[lod2batch.lod]
                    x = x_orig

                    if lod2batch.in_transition:
                        # While transitioning between LODs, mix the batch
                        # with a copy that was downsampled by 2 and resized
                        # back to the current resolution.
                        x_prev = F.avg_pool2d(x_orig, 2, 2)
                        x_prev_2x = F.interpolate(x_prev, needed_resolution)
                        x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)
                x.requires_grad = True

                # Discriminator step.
                discriminator_optimizer.zero_grad()
                loss_d = model(x, lod2batch.lod, blend_factor, d_train=True)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()

                if local_rank == 0:
                    # Interpolation weight passed to lerp():
                    # 0.5 ** (batch_size / 10000).
                    betta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, betta)

                # Generator step.
                generator_optimizer.zero_grad()
                loss_g = model(x, lod2batch.lod, blend_factor, d_train=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                generator_optimizer.step()

                lod2batch.step()
                if local_rank == 0:
                    if lod2batch.is_time_to_save():
                        checkpointer.save("model_tmp_intermediate")
                    if lod2batch.is_time_to_report():
                        save_sample(lod2batch, tracker, sample, x, logger,
                                    model_s, cfg, discriminator_optimizer,
                                    generator_optimizer)

        scheduler.step()

        if local_rank == 0:
            checkpointer.save("model_tmp")
            save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg,
                        discriminator_optimizer, generator_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final").wait()