Example no. 1
def train(cfg, logger, local_rank, world_size, distributed):
    torch.cuda.set_device(local_rank)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR,
                  encoder=cfg.MODEL.ENCODER)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        model_s = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                        layer_count=cfg.MODEL.LAYER_COUNT,
                        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                        channels=cfg.MODEL.CHANNELS,
                        generator=cfg.MODEL.GENERATOR,
                        encoder=cfg.MODEL.ENCODER)
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True)
        model.device_ids = None

        decoder = model.module.decoder
        encoder = model.module.encoder
        discriminator = model.module.discriminator
        mapping_fl = model.module.mapping_fl
        dlatent_avg = model.module.dlatent_avg
    else:
        decoder = model.decoder
        encoder = model.encoder
        discriminator = model.discriminator
        mapping_fl = model.mapping_fl
        dlatent_avg = model.dlatent_avg

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    decoder_optimizer = LREQAdam([{
        'params': decoder.parameters()
    }, {
        'params': mapping_fl.parameters()
    }],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    encoder_optimizer = LREQAdam([
        {
            'params': encoder.parameters()
        },
    ],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    discriminator_optimizer = LREQAdam([
        {
            'params': discriminator.parameters()
        },
    ],
                                       lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                       betas=(cfg.TRAIN.ADAM_BETA_0,
                                              cfg.TRAIN.ADAM_BETA_1),
                                       weight_decay=0)

    scheduler = ComboMultiStepLR(optimizers={
        'encoder_optimizer': encoder_optimizer,
        'discriminator_optimizer': discriminator_optimizer,
        'decoder_optimizer': decoder_optimizer
    },
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
                                 reference_batch_size=32,
                                 base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'discriminator': discriminator,
        'encoder': encoder,
        'generator': decoder,
        'mapping_fl': mapping_fl,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['discriminator_s'] = model_s.discriminator
        model_dict['encoder_s'] = model_s.encoder
        model_dict['generator_s'] = model_s.decoder
        model_dict['mapping_fl_s'] = model_s.mapping_fl

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(
        cfg,
        model_dict, {
            'encoder_optimizer': encoder_optimizer,
            'discriminator_optimizer': discriminator_optimizer,
            'decoder_optimizer': decoder_optimizer,
            'scheduler': scheduler,
            'tracker': tracker
        },
        logger=logger,
        save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = decoder.layer_to_resolution

    dataset = TFRecordsDataset(cfg,
                               logger,
                               rank=local_rank,
                               world_size=world_size,
                               buffer_size_mb=1024,
                               channels=cfg.MODEL.CHANNELS)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg,
                                     logger,
                                     world_size,
                                     dataset_size=len(dataset) * world_size)

    if cfg.DATASET.SAMPLES_PATH:
        path = cfg.DATASET.SAMPLES_PATH
        src = []
        with torch.no_grad():
            for filename in list(os.listdir(path))[:32]:
                img = np.asarray(Image.open(os.path.join(path, filename)))
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32),
                                 requires_grad=True).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                src.append(x)
            sample = torch.stack(src)
    else:
        dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32)
        sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank))
        sample = (sample / 127.5 - 1.)

    lod2batch.set_epoch(scheduler.start_epoch(),
                        [encoder_optimizer, decoder_optimizer])

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [encoder_optimizer, decoder_optimizer])

        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d"
            % (lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(),
               lod2batch.lod, 2**lod2batch.get_lod_power2(), 2**
               lod2batch.get_lod_power2(), lod2batch.get_blend_factor(),
               len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(),
                      lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset,
                                  lod2batch.get_per_GPU_batch_size(),
                                  local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()

        need_permute = False
        epoch_start_time = time.time()

        i = 0
        with torch.autograd.profiler.profile(use_cuda=True,
                                             enabled=False) as prof:
            for x_orig in tqdm(batches):
                i += 1
                with torch.no_grad():
                    if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                        continue
                    if need_permute:
                        x_orig = x_orig.permute(0, 3, 1, 2)
                    x_orig = (x_orig / 127.5 - 1.)

                    blend_factor = lod2batch.get_blend_factor()

                    needed_resolution = layer_to_resolution[lod2batch.lod]
                    x = x_orig

                    if lod2batch.in_transition:
                        # fade the new resolution in: blend the batch with a
                        # 2x-downsampled-and-reupsampled copy of itself
                        x_prev = F.avg_pool2d(x_orig, 2, 2)
                        x_prev_2x = F.interpolate(x_prev, needed_resolution)
                        x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

                # real samples need gradients for the gradient-penalty term of the discriminator loss
                x.requires_grad = True

                loss_d = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=True,
                               ae=False)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()
                decoder_optimizer.zero_grad()
                discriminator_optimizer.zero_grad()

                loss_g = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=False,
                               ae=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                decoder_optimizer.step()
                decoder_optimizer.zero_grad()
                discriminator_optimizer.zero_grad()

                lae = model(x,
                            lod2batch.lod,
                            blend_factor,
                            d_train=True,
                            ae=True)
                tracker.update(dict(lae=lae))
                lae.backward()
                encoder_optimizer.step()
                decoder_optimizer.step()
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()

                if local_rank == 0:
                    # keep model_s as an exponential moving average of the online weights
                    beta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, beta)

                epoch_end_time = time.time()
                per_epoch_ptime = epoch_end_time - epoch_start_time

                lod_for_saving_model = lod2batch.lod
                lod2batch.step()
                if local_rank == 0:
                    if lod2batch.is_time_to_save():
                        checkpointer.save("model_tmp_intermediate_lod%d" %
                                          lod_for_saving_model)
                    if lod2batch.is_time_to_report():
                        save_sample(lod2batch, tracker, sample, samplez, x,
                                    logger, model_s, cfg, encoder_optimizer,
                                    decoder_optimizer)

        scheduler.step()

        if local_rank == 0:
            checkpointer.save("model_tmp_lod%d" % lod_for_saving_model)
            save_sample(lod2batch, tracker, sample, samplez, x, logger,
                        model_s, cfg, encoder_optimizer, decoder_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final").wait()
Example no. 2
def train(cfg, logger, local_rank, world_size, distributed):
    torch.cuda.set_device(local_rank)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL,
        beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
        scale=cfg.MODEL.SCALE)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        model_s = SoftIntroVAEModelTL(
            startf=cfg.MODEL.START_CHANNEL_COUNT,
            layer_count=cfg.MODEL.LAYER_COUNT,
            maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
            latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
            truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
            truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
            mapping_layers=cfg.MODEL.MAPPING_LAYERS,
            channels=cfg.MODEL.CHANNELS,
            generator=cfg.MODEL.GENERATOR,
            encoder=cfg.MODEL.ENCODER,
            beta_kl=cfg.MODEL.BETA_KL,
            beta_rec=cfg.MODEL.BETA_REC,
            beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1],
            scale=cfg.MODEL.SCALE)
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True)
        model.device_ids = None

        decoder = model.module.decoder
        encoder = model.module.encoder
        mapping_tl = model.module.mapping_tl
        mapping_fl = model.module.mapping_fl
        dlatent_avg = model.module.dlatent_avg
    else:
        decoder = model.decoder
        encoder = model.encoder
        mapping_tl = model.mapping_tl
        mapping_fl = model.mapping_fl
        dlatent_avg = model.dlatent_avg

    count_parameters.print = lambda a: logger.info(a)

    num_vae_epochs = cfg.TRAIN.NUM_VAE

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))

    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    arguments = dict()
    arguments["iteration"] = 0

    decoder_optimizer = LREQAdam([{
        'params': decoder.parameters()
    }, {
        'params': mapping_fl.parameters()
    }],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    encoder_optimizer = LREQAdam([{
        'params': encoder.parameters()
    }, {
        'params': mapping_tl.parameters()
    }],
                                 lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                 betas=(cfg.TRAIN.ADAM_BETA_0,
                                        cfg.TRAIN.ADAM_BETA_1),
                                 weight_decay=0)

    scheduler = ComboMultiStepLR(optimizers={
        'encoder_optimizer': encoder_optimizer,
        'decoder_optimizer': decoder_optimizer
    },
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
                                 reference_batch_size=32,
                                 base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'discriminator': encoder,
        'generator': decoder,
        'mapping_fl': mapping_fl,
        'mapping_tl': mapping_tl,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['discriminator_s'] = model_s.encoder
        model_dict['generator_s'] = model_s.decoder
        model_dict['mapping_fl_s'] = model_s.mapping_fl
        model_dict['mapping_tl_s'] = model_s.mapping_tl

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(cfg,
                                model_dict, {
                                    'encoder_optimizer': encoder_optimizer,
                                    'decoder_optimizer': decoder_optimizer,
                                    'scheduler': scheduler,
                                    'tracker': tracker
                                },
                                logger=logger,
                                save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = decoder.layer_to_resolution

    dataset = TFRecordsDataset(cfg,
                               logger,
                               rank=local_rank,
                               world_size=world_size,
                               buffer_size_mb=1024,
                               channels=cfg.MODEL.CHANNELS)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg,
                                     logger,
                                     world_size,
                                     dataset_size=len(dataset) * world_size)

    if cfg.DATASET.SAMPLES_PATH:
        path = cfg.DATASET.SAMPLES_PATH
        src = []
        with torch.no_grad():
            for filename in list(os.listdir(path))[:32]:
                img = np.asarray(Image.open(os.path.join(path, filename)))
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32),
                                 requires_grad=True).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                src.append(x)
            sample = torch.stack(src)
    else:
        dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32)
        sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank))
        sample = (sample / 127.5 - 1.)

    lod2batch.set_epoch(scheduler.start_epoch(),
                        [encoder_optimizer, decoder_optimizer])

    kls_real = []
    kls_fake = []
    rec_errs = []

    best_fid = None
    # best_fid = 20.0

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        new_beta_neg = cfg.MODEL.BETA_NEG[lod2batch.lod]
        m = model.module if distributed else model
        if m.beta_neg != new_beta_neg:
            m.beta_neg = new_beta_neg
            print("beta negative changed to:", new_beta_neg)
        # once progressive growing reaches the final LOD, evaluate FID every 10 epochs on rank 0
        if (epoch > cfg.TRAIN.EPOCHS_PER_LOD * (cfg.MODEL.LAYER_COUNT - 1)
                and epoch % 10 == 0 and local_rank == 0):
            print("calculating fid...")
            fid = calc_fid_from_dataset_generate(
                cfg,
                dataset,
                model_s,
                batch_size=50,
                cuda=1,
                dims=2048,
                device=torch.device(
                    "cuda:0" if torch.cuda.is_available() else "cpu"),
                num_images=50000)
            print("epoch: {}, fid: {}".format(epoch, fid))
            if best_fid is None:
                best_fid = fid
            elif fid < best_fid:
                print("best fid updated: {} -> {}".format(best_fid, fid))
                best_fid = fid
                checkpointer.save("model_tmp_lod{}_fid_{}".format(
                    lod2batch.lod, fid))
        lod2batch.set_epoch(epoch, [encoder_optimizer, decoder_optimizer])

        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d"
            % (lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(),
               lod2batch.lod, 2**lod2batch.get_lod_power2(), 2**
               lod2batch.get_lod_power2(), lod2batch.get_blend_factor(),
               len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(),
                      lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset,
                                  lod2batch.get_per_GPU_batch_size(),
                                  local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()

        need_permute = False
        epoch_start_time = time.time()

        i = 0

        diff_kls = []
        batch_kls_real = []
        batch_kls_fake = []
        batch_rec_errs = []

        for x_orig in tqdm(batches):
            i += 1
            with torch.no_grad():
                if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                    continue
                if need_permute:
                    x_orig = x_orig.permute(0, 3, 1, 2)
                x_orig = (x_orig / 127.5 - 1.)

                blend_factor = lod2batch.get_blend_factor()

                needed_resolution = layer_to_resolution[lod2batch.lod]
                x = x_orig

                if lod2batch.in_transition:
                    x_prev = F.avg_pool2d(x_orig, 2, 2)
                    x_prev_2x = F.interpolate(x_prev, needed_resolution)
                    x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

            x.requires_grad = True
            # warm-up: train encoder and decoder jointly as a plain VAE first
            if epoch < num_vae_epochs:
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()
                loss = model(x,
                             lod2batch.lod,
                             blend_factor,
                             d_train=False,
                             e_train=False)
                tracker.update(dict(loss_e=loss))
                tracker.update(dict(loss_d=loss))
                loss.backward()
                encoder_optimizer.step()
                decoder_optimizer.step()
            else:
                # ------------- Update Encoder ------------- #
                encoder_optimizer.zero_grad()
                loss_e = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=False,
                               e_train=True)
                tracker.update(dict(loss_e=loss_e))
                loss_e.backward()
                encoder_optimizer.step()

                # ------------- Update Decoder ------------- #
                decoder_optimizer.zero_grad()
                loss_d = model(x,
                               lod2batch.lod,
                               blend_factor,
                               d_train=True,
                               e_train=False)
                loss_d.backward()
                tracker.update(dict(loss_d=loss_d))
                decoder_optimizer.step()

            # ------------- Update Statistics ------------- #
            m = model.module if distributed else model
            tracker.update(dict(rec_loss=m.last_rec_loss))
            tracker.update(dict(real_kl=m.last_real_kl))
            tracker.update(dict(fake_kl=m.last_fake_kl))
            tracker.update(dict(kl_diff=m.last_kl_diff))
            tracker.update(dict(expelbo_f=m.last_expelbo_fake))
            tracker.update(dict(expelbo_r=m.last_expelbo_rec))

            diff_kls.append(m.last_kl_diff.data.cpu())
            batch_kls_real.append(m.last_real_kl)
            batch_kls_fake.append(m.last_fake_kl)
            batch_rec_errs.append(m.last_rec_loss)

            if local_rank == 0:
                beta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                model_s.lerp(model, beta)

            epoch_end_time = time.time()
            per_epoch_ptime = epoch_end_time - epoch_start_time

            lod_for_saving_model = lod2batch.lod
            lod2batch.step()
            if local_rank == 0:
                if lod2batch.is_time_to_save():
                    checkpointer.save("model_tmp_intermediate_lod%d" %
                                      lod_for_saving_model)
                if lod2batch.is_time_to_report():
                    save_sample(lod2batch, tracker, sample, samplez, x, logger,
                                model_s, cfg, encoder_optimizer,
                                decoder_optimizer)

        scheduler.step()
        mean_diff_kl = np.mean(diff_kls)
        print("mean diff kl: ", mean_diff_kl)

        if epoch > num_vae_epochs - 1:
            kls_real.append(np.mean(batch_kls_real))
            kls_fake.append(np.mean(batch_kls_fake))
            rec_errs.append(np.mean(batch_rec_errs))

        if local_rank == 0:
            checkpointer.save("model_tmp_lod%d" % lod_for_saving_model)
            save_sample(lod2batch, tracker, sample, samplez, x, logger,
                        model_s, cfg, encoder_optimizer, decoder_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final").wait()
Example no. 3
def train(cfg, logger, local_rank, world_size, distributed):
    torch.cuda.set_device(local_rank)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=1)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        model_s = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                        layer_count=cfg.MODEL.LAYER_COUNT,
                        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                        channels=1)
        del model_s.discriminator
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True)
        model.device_ids = None

        generator = model.module.generator
        discriminator = model.module.discriminator
        mapping = model.module.mapping
        dlatent_avg = model.module.dlatent_avg
    else:
        generator = model.generator
        discriminator = model.discriminator
        mapping = model.mapping
        dlatent_avg = model.dlatent_avg

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(generator)

    logger.info("Trainable parameters discriminator:")
    count_parameters(discriminator)

    arguments = dict()
    arguments["iteration"] = 0

    generator_optimizer = LREQAdam([{
        'params': generator.parameters()
    }, {
        'params': mapping.parameters()
    }],
                                   lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                   betas=(cfg.TRAIN.ADAM_BETA_0,
                                          cfg.TRAIN.ADAM_BETA_1),
                                   weight_decay=0)

    discriminator_optimizer = LREQAdam([
        {
            'params': discriminator.parameters()
        },
    ],
                                       lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                       betas=(cfg.TRAIN.ADAM_BETA_0,
                                              cfg.TRAIN.ADAM_BETA_1),
                                       weight_decay=0)

    scheduler = ComboMultiStepLR(optimizers={
        'generator': generator_optimizer,
        'discriminator': discriminator_optimizer
    },
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
                                 reference_batch_size=32)

    model_dict = {
        'discriminator': discriminator,
        'generator': generator,
        'mapping': mapping,
        'dlatent_avg': dlatent_avg
    }

    if local_rank == 0:
        model_dict['generator_s'] = model_s.generator
        model_dict['mapping_s'] = model_s.mapping

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(
        cfg,
        model_dict, {
            'generator_optimizer': generator_optimizer,
            'discriminator_optimizer': discriminator_optimizer,
            'tracker': tracker,
            'scheduler': scheduler,
        },
        logger=logger,
        save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = generator.layer_to_resolution

    dlutils.download.mnist()
    mnist = dlutils.reader.Mnist('mnist').items
    mnist = np.asarray([x[1] for x in mnist], np.float32)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg,
                                     logger,
                                     world_size,
                                     dataset_size=len(mnist))

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch,
                            [generator_optimizer, discriminator_optimizer])

        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, dataset size: %d"
            % (lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(),
               lod2batch.lod, 2**lod2batch.get_lod_power2(), 2**
               lod2batch.get_lod_power2(), len(mnist)))

        dlutils.shuffle.shuffle_ndarray(mnist)

        # resample the 28x28 MNIST digits to the current LOD resolution
        r = 2**lod2batch.get_lod_power2()
        mnist_ = F.interpolate(
            torch.tensor(mnist).view(mnist.shape[0], 1, 28, 28),
            r).detach().cpu().numpy()

        scheduler.set_batch_size(32)

        model.train()

        need_permute = False

        class BatchCollator(object):
            def __init__(self, device=torch.device("cpu")):
                self.device = device

            def __call__(self, batch):
                with torch.no_grad():
                    x = batch
                    x = torch.tensor(x,
                                     requires_grad=True,
                                     device=torch.device(self.device),
                                     dtype=torch.float32)
                    return x

        batches = dlutils.batch_provider(mnist_,
                                         lod2batch.get_per_GPU_batch_size(),
                                         BatchCollator(local_rank),
                                         report_progress=False)

        with torch.autograd.profiler.profile(use_cuda=True,
                                             enabled=False) as prof:
            for x_orig in tqdm(batches):
                if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                    continue
                if need_permute:
                    x_orig = x_orig.permute(0, 3, 1, 2)
                x_orig = (x_orig / 127.5 - 1.)

                blend_factor = lod2batch.get_blend_factor()

                needed_resolution = layer_to_resolution[lod2batch.lod]
                x = x_orig

                if lod2batch.in_transition:
                    x_prev = F.avg_pool2d(x_orig, 2, 2)
                    x_prev_2x = F.interpolate(x_prev, needed_resolution)
                    x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

                discriminator_optimizer.zero_grad()
                loss_d = model(x, lod2batch.lod, blend_factor, d_train=True)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()

                if local_rank == 0:
                    beta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, beta)

                generator_optimizer.zero_grad()
                loss_g = model(x, lod2batch.lod, blend_factor, d_train=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                generator_optimizer.step()

                lod2batch.step()
                if local_rank == 0 and lod2batch.is_time_to_report():
                    save_sample(lod2batch, tracker, sample, x, logger, model_s,
                                cfg, discriminator_optimizer,
                                generator_optimizer)

        #print(prof.key_averages().table(sort_by="self_cpu_time_total"))

        if local_rank == 0:
            save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg,
                        discriminator_optimizer, generator_optimizer)
            if epoch > 2:
                checkpointer.save("model_tmp")

        scheduler.step()

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final")
Example no. 4
def train(cfg, logger, local_rank, world_size, folding_id=0, inliner_classes=None):
    torch.cuda.set_device(local_rank)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        channels=cfg.MODEL.INPUT_IMAGE_CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
    )
    model.cuda(local_rank)
    model.train()

    model_s = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        channels=cfg.MODEL.INPUT_IMAGE_CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER,
    )
    model_s.cuda(local_rank)
    model_s.eval()
    model_s.requires_grad_(False)

    generator = model.generator
    encoder = model.encoder
    discriminator = model.discriminator
    z_discriminator = model.z_discriminator

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(generator)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    generator_optimizer = LREQAdam([
        {'params': generator.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    z_discriminator_optimizer = LREQAdam([
        {'params': z_discriminator.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    encoder_optimizer = LREQAdam([
        {'params': encoder.parameters()},
        {'params': discriminator.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE, betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    scheduler = ComboMultiStepLR(optimizers=
                                 {
                                    'encoder_optimizer': encoder_optimizer,
                                    'generator_optimizer': generator_optimizer,
                                    'z_discriminator_optimizer': z_discriminator_optimizer
                                 },
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
                                 reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'encoder': encoder,
        'generator': generator,
        'discriminator': discriminator,
        'z_discriminator': z_discriminator,
        'encoder_s': model_s.encoder,
        'generator_s': model_s.generator,
        'discriminator_s': model_s.discriminator,
        'z_discriminator_s': model_s.z_discriminator,
    }

    output_folder = 'results_' + str(folding_id) + "_" + "_".join(str(x) for x in inliner_classes)
    output_folder = os.path.join(cfg.OUTPUT_DIR, output_folder)
    os.makedirs(output_folder, exist_ok=True)

    tracker = LossTracker(output_folder)

    checkpointer = Checkpointer(output_folder,
                                model_dict,
                                {
                                    'encoder_optimizer': encoder_optimizer,
                                    'decoder_optimizer': generator_optimizer,
                                    'scheduler': scheduler,
                                    'tracker': tracker
                                },
                                logger=logger,
                                save=True)

    extra_checkpoint_data = checkpointer.load()
    save_file = os.path.join(checkpointer.folder, "last_checkpoint")
    try:
        with open(save_file, "r") as last_checkpoint:
            f = last_checkpoint.read().strip()
            f = os.path.basename(f)
            checkpointer.tag_last_checkpoint(f)
        extra_checkpoint_data = checkpointer.load()
    except FileNotFoundError:
        pass

    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = generator.layer_to_resolution

    train_set, _, _ = make_datasets(cfg, logger, folding_id, inliner_classes)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = driver.Driver(cfg, logger, world_size, dataset_size=len(train_set))

    sample = next(make_dataloader(train_set, cfg.TRAIN.BATCH_1GPU, torch.cuda.current_device()))
    sample = sample[1]  # the loader yields (label, image) pairs; keep the images
    sample = sample.view(-1, cfg.MODEL.INPUT_IMAGE_CHANNELS, cfg.MODEL.INPUT_IMAGE_SIZE, cfg.MODEL.INPUT_IMAGE_SIZE)
    # sample = (sample / 127.5 - 1.)

    lod2batch.set_epoch(scheduler.start_epoch(), [encoder_optimizer, generator_optimizer])

    scores_list = []

    try:
        with open(os.path.join(output_folder, "scores.txt"), "r") as f:
            lines = f.readlines()
            lines = [l.strip() for l in lines]
            lines = [l.split(' ') for l in lines]
            lines = [l for l in lines if len(l) == 2]
            scores_list = [(x[0], float(x[1])) for x in lines]
            # for l in scores_list:
            #     print("%s: %f" % l)
    except FileNotFoundError:
        pass

    def save(epoch):
        score = eval_model_on_valid(cfg, logger, model_s, folding_id, inliner_classes)
        filename = "model_%d" % epoch
        checkpointer.save(filename).wait()
        scores_list.append((filename, score))
        with open(os.path.join(output_folder, "scores.txt"), "w") as f:
            f.writelines([x[0] + " " + str(x[1]) + "\n" for x in scores_list])

    def last_score():
        return 0 if len(scores_list) == 0 else scores_list[-1][1]

    epoch = None
    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [encoder_optimizer, generator_optimizer])

        logger.info("Batch size: %d, Batch size per GPU: %d, dataset size: %d" % (
                                                                lod2batch.get_batch_size(),
                                                                lod2batch.get_per_GPU_batch_size(),
                                                                len(train_set) * world_size))

        data_loader = make_dataloader(train_set, lod2batch.get_per_GPU_batch_size(), torch.cuda.current_device())
        train_set.shuffle()

        scheduler.set_batch_size(lod2batch.get_batch_size())

        model.train()

        epoch_start_time = time.time()

        i = 0
        for y, x in data_loader:
            x = x.view(x.shape[0], cfg.MODEL.INPUT_IMAGE_CHANNELS, cfg.MODEL.INPUT_IMAGE_SIZE, cfg.MODEL.INPUT_IMAGE_SIZE)
            i += 1
            with torch.no_grad():
                if x.shape[0] != lod2batch.get_per_GPU_batch_size():
                    continue
            encoder_optimizer.zero_grad()
            loss_d, loss_zg = model(x, d_train=True, ae=False)
            tracker.update(dict(loss_d=loss_d, loss_zg=loss_zg))
            (loss_zg + loss_d).backward()
            encoder_optimizer.step()

            generator_optimizer.zero_grad()
            z_discriminator_optimizer.zero_grad()
            loss_g, loss_zd = model(x, d_train=False, ae=False)
            tracker.update(dict(loss_g=loss_g, loss_zd=loss_zd))
            (loss_g + loss_zd).backward()
            generator_optimizer.step()
            z_discriminator_optimizer.step()

            encoder_optimizer.zero_grad()
            generator_optimizer.zero_grad()
            lae = model(x, d_train=True, ae=True)
            tracker.update(dict(lae=lae))
            lae.backward()
            encoder_optimizer.step()
            generator_optimizer.step()

            beta = 0.5 ** (lod2batch.get_batch_size() / (1000.0))
            model_s.lerp(model, beta)

            epoch_end_time = time.time()
            per_epoch_ptime = epoch_end_time - epoch_start_time

            # tracker.update(dict(score_a=score_a, score_b=score_b, score_c=score_c))
            tracker.update(dict(score=last_score()))

            lod2batch.step()
            # if lod2batch.is_time_to_save():
            #     checkpointer.save("model_tmp_intermediate_lod%d" % lod_for_saving_model)
            if lod2batch.is_time_to_report():
                save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg, encoder_optimizer,
                            generator_optimizer, output_folder)

        scheduler.step()

        if epoch % 20 == 0:
            save(epoch)

        save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg, encoder_optimizer, generator_optimizer, output_folder)

    logger.info("Training finish!... save training results")
    if epoch is not None:
        save(epoch)

    best_model_name, best_model_score = scores_list[0]
    for model_name, model_score in scores_list:
        if model_score >= best_model_score:
            best_model_name, best_model_score = model_name, model_score

    checkpointer.tag_best_checkpoint(best_model_name)
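
Example no. 4 appends a (checkpoint name, validation score) pair to scores.txt every time it saves, and finally tags the highest-scoring checkpoint as best. The format is plain "name score" lines; a self-contained sketch of the same bookkeeping (the file name and score values here are illustrative):

def load_scores(path):
    """Parse 'name score' lines, skipping anything malformed."""
    try:
        with open(path) as f:
            rows = [line.split() for line in f]
        return [(r[0], float(r[1])) for r in rows if len(r) == 2]
    except FileNotFoundError:
        return []

def append_score(path, name, score, scores):
    scores.append((name, score))
    with open(path, "w") as f:
        f.writelines("%s %s\n" % (n, s) for n, s in scores)

scores = load_scores("scores.txt")
append_score("scores.txt", "model_0", 0.91, scores)
best_name, best_score = max(scores, key=lambda t: t[1])   # ties resolved arbitrarily
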
Example no. 5
def train(cfg, logger, local_rank, world_size, distributed):
    torch.cuda.set_device(local_rank)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                  layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                  latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                  channels=3)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        model_s = Model(startf=cfg.MODEL.START_CHANNEL_COUNT,
                        layer_count=cfg.MODEL.LAYER_COUNT,
                        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
                        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
                        channels=3)
        del model_s.discriminator
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False,
            bucket_cap_mb=25,
            find_unused_parameters=True)
        model.device_ids = None

        generator = model.module.generator
        discriminator = model.module.discriminator
        mapping = model.module.mapping
        dlatent_avg = model.module.dlatent_avg
    else:
        generator = model.generator
        discriminator = model.discriminator
        mapping = model.mapping
        dlatent_avg = model.dlatent_avg

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(generator)

    logger.info("Trainable parameters discriminator:")
    count_parameters(discriminator)

    generator_optimizer = LREQAdam([{
        'params': generator.parameters()
    }, {
        'params': mapping.parameters()
    }],
                                   lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                   betas=(cfg.TRAIN.ADAM_BETA_0,
                                          cfg.TRAIN.ADAM_BETA_1),
                                   weight_decay=0)

    discriminator_optimizer = LREQAdam([
        {
            'params': discriminator.parameters()
        },
    ],
                                       lr=cfg.TRAIN.BASE_LEARNING_RATE,
                                       betas=(cfg.TRAIN.ADAM_BETA_0,
                                              cfg.TRAIN.ADAM_BETA_1),
                                       weight_decay=0)

    scheduler = ComboMultiStepLR(optimizers={
        'generator': generator_optimizer,
        'discriminator': discriminator_optimizer
    },
                                 milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
                                 gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
                                 reference_batch_size=32,
                                 base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'discriminator': discriminator,
        'generator': generator,
        'mapping': mapping,
        'dlatent_avg': dlatent_avg
    }

    if local_rank == 0:
        model_dict['generator_s'] = model_s.generator
        model_dict['mapping_s'] = model_s.mapping

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(
        cfg,
        model_dict, {
            'generator_optimizer': generator_optimizer,
            'discriminator_optimizer': discriminator_optimizer,
            'scheduler': scheduler,
            'tracker': tracker
        },
        logger=logger,
        save=local_rank == 0)

    checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    layer_to_resolution = generator.layer_to_resolution

    dataset = TFRecordsDataset(cfg,
                               logger,
                               rank=local_rank,
                               world_size=world_size,
                               buffer_size_mb=1024)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg,
                                     logger,
                                     world_size,
                                     dataset_size=len(dataset) * world_size)

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch,
                            [generator_optimizer, discriminator_optimizer])

        logger.info(
            "Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d"
            % (lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(),
               lod2batch.lod, 2**lod2batch.get_lod_power2(), 2**
               lod2batch.get_lod_power2(), lod2batch.get_blend_factor(),
               len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(),
                      lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset,
                                  lod2batch.get_per_GPU_batch_size(),
                                  local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()

        need_permute = False

        with torch.autograd.profiler.profile(use_cuda=True,
                                             enabled=False) as prof:
            for x_orig in tqdm(batches):
                if distributed:
                    # keep all workers in lockstep at the top of each batch
                    torch.distributed.barrier()

                with torch.no_grad():
                    if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                        continue
                    if need_permute:
                        x_orig = x_orig.permute(0, 3, 1, 2)
                    x_orig = (x_orig / 127.5 - 1.)

                    blend_factor = lod2batch.get_blend_factor()

                    needed_resolution = layer_to_resolution[lod2batch.lod]
                    x = x_orig

                    if lod2batch.in_transition:
                        x_prev = F.avg_pool2d(x_orig, 2, 2)
                        x_prev_2x = F.interpolate(x_prev, needed_resolution)
                        x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)
                x.requires_grad = True

                discriminator_optimizer.zero_grad()
                loss_d = model(x, lod2batch.lod, blend_factor, d_train=True)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()

                if local_rank == 0:
                    beta = 0.5**(lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, beta)

                generator_optimizer.zero_grad()
                loss_g = model(x, lod2batch.lod, blend_factor, d_train=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                generator_optimizer.step()

                lod2batch.step()
                if local_rank == 0:
                    if lod2batch.is_time_to_save():
                        checkpointer.save("model_tmp_intermediate")
                    if lod2batch.is_time_to_report():
                        save_sample(lod2batch, tracker, sample, x, logger,
                                    model_s, cfg, discriminator_optimizer,
                                    generator_optimizer)

        #print(prof.key_averages().table(sort_by="self_cpu_time_total"))

        scheduler.step()

        if local_rank == 0:
            checkpointer.save("model_tmp")
            save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg,
                        discriminator_optimizer, generator_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final").wait()