def train_net(args):
    torch.cuda.set_device(0)

    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    logger = logging.getLogger("logger")
    logger.setLevel(logging.DEBUG)

    output_dir = cfg.OUTPUT_DIR
    os.makedirs(output_dir, exist_ok=True)

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    logger.info(args)
    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
    logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    model = load_from('karras2019stylegan-ffhq-1024x1024.pkl', cfg)
    # model, Gs = load_from('karras2019stylegan-ffhq-1024x1024.pkl', cfg)

    # Generate image.
    # fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    # images = Gs.run(sample.cpu().detach().numpy(), None, truncation_psi=0.7,
    #                 randomize_noise=True, output_transform=None)
    rnd = np.random.RandomState(5)
    latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    save_sample(model, sample)

    # png_filename = os.path.join('example.png')
    # PIL.Image.fromarray(images[0], 'RGB').save(png_filename)

    model_dict = {
        'generator_s': model.generator,
        'mapping_s': model.mapping,
        'dlatent_avg': model.dlatent_avg,
    }
    checkpointer = Checkpointer(cfg, model_dict, logger=logger, save=True)
    checkpointer.save('karras2019stylegan-ffhq')
def convert(args):
    torch.cuda.set_device(0)

    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    logger = logging.getLogger("logger")
    logger.setLevel(logging.DEBUG)

    output_dir = cfg.OUTPUT_DIR
    os.makedirs(output_dir, exist_ok=True)

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    logger.info(args)
    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
    logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # model, Gs = load_from('karras2019stylegan-ffhq-1024x1024.pkl', cfg)
    model, Gs = load_from('karras2019stylegan-cats-256x256.pkl', cfg)

    model_dict = {
        'generator_s': model.generator,
        'mapping_fl_s': model.mapping,
        'dlatent_avg': model.dlatent_avg,
    }
    checkpointer = Checkpointer(cfg, model_dict, logger=logger, save=True)
    # checkpointer.save('karras2019stylegan-ffhq')
    checkpointer.save('karras2019stylegan-cats')
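# --- Usage sketch (not from the original repo): a minimal CLI wrapper for the two conversion
# --- entry points above. The flag names and defaults here are illustrative; both functions only
# --- need an `args` object exposing `config_file` and `opts`.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert a StyleGAN .pkl into repo checkpoints")
    parser.add_argument("--config-file", dest="config_file", default="configs/experiment.yaml",
                        help="path to a yacs config file")
    parser.add_argument("opts", nargs=argparse.REMAINDER, default=[],
                        help="config overrides as KEY VALUE pairs")
    convert(parser.parse_args())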
def train(cfg, logger, local_rank, world_size, distributed):
    torch.cuda.set_device(local_rank)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT, layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT, latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS, channels=cfg.MODEL.CHANNELS,
                  generator=cfg.MODEL.GENERATOR, encoder=cfg.MODEL.ENCODER)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        model_s = Model(startf=cfg.MODEL.START_CHANNEL_COUNT, layer_count=cfg.MODEL.LAYER_COUNT,
                        maxf=cfg.MODEL.MAX_CHANNEL_COUNT, latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                        mapping_layers=cfg.MODEL.MAPPING_LAYERS, channels=cfg.MODEL.CHANNELS,
                        generator=cfg.MODEL.GENERATOR, encoder=cfg.MODEL.ENCODER)
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], broadcast_buffers=False,
            bucket_cap_mb=25, find_unused_parameters=True)
        model.device_ids = None
        decoder = model.module.decoder
        encoder = model.module.encoder
        discriminator = model.module.discriminator
        mapping_fl = model.module.mapping_fl
        dlatent_avg = model.module.dlatent_avg
    else:
        decoder = model.decoder
        encoder = model.encoder
        discriminator = model.discriminator
        mapping_fl = model.mapping_fl
        dlatent_avg = model.dlatent_avg

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    arguments = dict()
    arguments["iteration"] = 0

    decoder_optimizer = LREQAdam([
        {'params': decoder.parameters()},
        {'params': mapping_fl.parameters()}
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    encoder_optimizer = LREQAdam([
        {'params': encoder.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    discriminator_optimizer = LREQAdam([
        {'params': discriminator.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    scheduler = ComboMultiStepLR(
        optimizers={'encoder_optimizer': encoder_optimizer,
                    'discriminator_optimizer': discriminator_optimizer,
                    'decoder_optimizer': decoder_optimizer},
        milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
        gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
        reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'discriminator': discriminator,
        'encoder': encoder,
        'generator': decoder,
        'mapping_fl': mapping_fl,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['discriminator_s'] = model_s.discriminator
        model_dict['encoder_s'] = model_s.encoder
        model_dict['generator_s'] = model_s.decoder
        model_dict['mapping_fl_s'] = model_s.mapping_fl

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(cfg, model_dict,
                                {'encoder_optimizer': encoder_optimizer,
                                 'discriminator_optimizer': discriminator_optimizer,
                                 'decoder_optimizer': decoder_optimizer,
                                 'scheduler': scheduler,
                                 'tracker': tracker},
                                logger=logger, save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = decoder.layer_to_resolution

    dataset = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size,
                               buffer_size_mb=1024, channels=cfg.MODEL.CHANNELS)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg, logger, world_size,
                                     dataset_size=len(dataset) * world_size)

    if cfg.DATASET.SAMPLES_PATH:
        path = cfg.DATASET.SAMPLES_PATH
        src = []
        with torch.no_grad():
            for filename in list(os.listdir(path))[:32]:
                img = np.asarray(Image.open(path + '/' + filename))
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32),
                                 requires_grad=True).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                src.append(x)
            sample = torch.stack(src)
    else:
        dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32)
        sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank))
        sample = (sample / 127.5 - 1.)

    lod2batch.set_epoch(scheduler.start_epoch(), [encoder_optimizer, decoder_optimizer])

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [encoder_optimizer, decoder_optimizer])

        logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d" % (
            lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(), lod2batch.lod,
            2 ** lod2batch.get_lod_power2(), 2 ** lod2batch.get_lod_power2(),
            lod2batch.get_blend_factor(), len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()
        need_permute = False
        epoch_start_time = time.time()

        i = 0
        with torch.autograd.profiler.profile(use_cuda=True, enabled=False) as prof:
            for x_orig in tqdm(batches):
                i += 1
                with torch.no_grad():
                    if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                        continue
                    if need_permute:
                        x_orig = x_orig.permute(0, 3, 1, 2)
                    x_orig = (x_orig / 127.5 - 1.)

                    blend_factor = lod2batch.get_blend_factor()

                    needed_resolution = layer_to_resolution[lod2batch.lod]
                    x = x_orig

                    if lod2batch.in_transition:
                        needed_resolution_prev = layer_to_resolution[lod2batch.lod - 1]
                        x_prev = F.avg_pool2d(x_orig, 2, 2)
                        x_prev_2x = F.interpolate(x_prev, needed_resolution)
                        x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

                x.requires_grad = True

                loss_d = model(x, lod2batch.lod, blend_factor, d_train=True, ae=False)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()
                decoder_optimizer.zero_grad()
                discriminator_optimizer.zero_grad()

                loss_g = model(x, lod2batch.lod, blend_factor, d_train=False, ae=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                decoder_optimizer.step()
                decoder_optimizer.zero_grad()
                discriminator_optimizer.zero_grad()

                lae = model(x, lod2batch.lod, blend_factor, d_train=True, ae=True)
                tracker.update(dict(lae=lae))
                (lae).backward()
                encoder_optimizer.step()
                decoder_optimizer.step()
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()

                if local_rank == 0:
                    betta = 0.5 ** (lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, betta)

                epoch_end_time = time.time()
                per_epoch_ptime = epoch_end_time - epoch_start_time

                lod_for_saving_model = lod2batch.lod
                lod2batch.step()
                if local_rank == 0:
                    if lod2batch.is_time_to_save():
                        checkpointer.save("model_tmp_intermediate_lod%d" % lod_for_saving_model)
                    if lod2batch.is_time_to_report():
                        save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg,
                                    encoder_optimizer, decoder_optimizer)

        scheduler.step()

        if local_rank == 0:
            checkpointer.save("model_tmp_lod%d" % lod_for_saving_model)
            save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg,
                        encoder_optimizer, decoder_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final").wait()
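# --- Illustration (not part of the training code): the transition branch above fades a new
# --- resolution level in by blending the real batch with a 2x-downsampled-then-upsampled copy
# --- of itself. Shapes and the blend factor below are made up; F.interpolate defaults to
# --- nearest-neighbor, as in the loop.
import torch
import torch.nn.functional as F

def blend_with_previous_lod(x_orig, blend_factor):
    needed_resolution = x_orig.shape[-1]
    x_prev = F.avg_pool2d(x_orig, 2, 2)                    # batch as the previous LOD saw it
    x_prev_2x = F.interpolate(x_prev, needed_resolution)   # back up to the current resolution
    return x_orig * blend_factor + x_prev_2x * (1.0 - blend_factor)

# blend_factor ramps from 0 to 1 over the transition epochs of a LOD.
print(blend_with_previous_lod(torch.randn(4, 3, 64, 64), 0.25).shape)  # torch.Size([4, 3, 64, 64])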
def train(cfg, logger, local_rank, world_size, distributed):
    torch.cuda.set_device(local_rank)
    model = SoftIntroVAEModelTL(
        startf=cfg.MODEL.START_CHANNEL_COUNT, layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT, latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS, channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR, encoder=cfg.MODEL.ENCODER,
        beta_kl=cfg.MODEL.BETA_KL, beta_rec=cfg.MODEL.BETA_REC,
        beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1], scale=cfg.MODEL.SCALE)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        model_s = SoftIntroVAEModelTL(
            startf=cfg.MODEL.START_CHANNEL_COUNT, layer_count=cfg.MODEL.LAYER_COUNT,
            maxf=cfg.MODEL.MAX_CHANNEL_COUNT, latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
            truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
            truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
            mapping_layers=cfg.MODEL.MAPPING_LAYERS, channels=cfg.MODEL.CHANNELS,
            generator=cfg.MODEL.GENERATOR, encoder=cfg.MODEL.ENCODER,
            beta_kl=cfg.MODEL.BETA_KL, beta_rec=cfg.MODEL.BETA_REC,
            beta_neg=cfg.MODEL.BETA_NEG[cfg.MODEL.LAYER_COUNT - 1], scale=cfg.MODEL.SCALE)
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], broadcast_buffers=False,
            bucket_cap_mb=25, find_unused_parameters=True)
        model.device_ids = None
        decoder = model.module.decoder
        encoder = model.module.encoder
        mapping_tl = model.module.mapping_tl
        mapping_fl = model.module.mapping_fl
        dlatent_avg = model.module.dlatent_avg
    else:
        decoder = model.decoder
        encoder = model.encoder
        mapping_tl = model.mapping_tl
        mapping_fl = model.mapping_fl
        dlatent_avg = model.dlatent_avg

    count_parameters.print = lambda a: logger.info(a)

    num_vae_epochs = cfg.TRAIN.NUM_VAE

    logger.info("Trainable parameters decoder:")
    print(count_parameters(decoder))

    logger.info("Trainable parameters encoder:")
    print(count_parameters(encoder))

    arguments = dict()
    arguments["iteration"] = 0

    decoder_optimizer = LREQAdam([
        {'params': decoder.parameters()},
        {'params': mapping_fl.parameters()}
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    encoder_optimizer = LREQAdam([
        {'params': encoder.parameters()},
        {'params': mapping_tl.parameters()}
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    scheduler = ComboMultiStepLR(
        optimizers={'encoder_optimizer': encoder_optimizer,
                    'decoder_optimizer': decoder_optimizer},
        milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
        gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
        reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'discriminator': encoder,
        'generator': decoder,
        'mapping_fl': mapping_fl,
        'mapping_tl': mapping_tl,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['discriminator_s'] = model_s.encoder
        model_dict['generator_s'] = model_s.decoder
        model_dict['mapping_fl_s'] = model_s.mapping_fl
        model_dict['mapping_tl_s'] = model_s.mapping_tl

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(cfg, model_dict,
                                {'encoder_optimizer': encoder_optimizer,
                                 'decoder_optimizer': decoder_optimizer,
                                 'scheduler': scheduler,
                                 'tracker': tracker},
                                logger=logger, save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = decoder.layer_to_resolution

    dataset = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size,
                               buffer_size_mb=1024, channels=cfg.MODEL.CHANNELS)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    samplez = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg, logger, world_size,
                                     dataset_size=len(dataset) * world_size)

    if cfg.DATASET.SAMPLES_PATH:
        path = cfg.DATASET.SAMPLES_PATH
        src = []
        with torch.no_grad():
            for filename in list(os.listdir(path))[:32]:
                img = np.asarray(Image.open(os.path.join(path, filename)))
                if img.shape[2] == 4:
                    img = img[:, :, :3]
                im = img.transpose((2, 0, 1))
                x = torch.tensor(np.asarray(im, dtype=np.float32),
                                 requires_grad=True).cuda() / 127.5 - 1.
                if x.shape[0] == 4:
                    x = x[:3]
                src.append(x)
            sample = torch.stack(src)
    else:
        dataset.reset(cfg.DATASET.MAX_RESOLUTION_LEVEL, 32)
        sample = next(make_dataloader(cfg, logger, dataset, 32, local_rank))
        sample = (sample / 127.5 - 1.)

    lod2batch.set_epoch(scheduler.start_epoch(), [encoder_optimizer, decoder_optimizer])

    kls_real = []
    kls_fake = []
    rec_errs = []
    best_fid = None
    # best_fid = 20.0

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        new_beta_neg = cfg.MODEL.BETA_NEG[lod2batch.lod]
        if distributed:
            if model.module.beta_neg != new_beta_neg:
                model.module.beta_neg = new_beta_neg
                print("beta negative changed to:", new_beta_neg)
        else:
            if model.beta_neg != new_beta_neg:
                model.beta_neg = new_beta_neg
                print("beta negative changed to:", new_beta_neg)

        if (epoch > cfg.TRAIN.EPOCHS_PER_LOD * (cfg.MODEL.LAYER_COUNT - 1)) and (epoch % 10 == 0) and (local_rank == 0):
            print("calculating fid...")
            fid = calc_fid_from_dataset_generate(
                cfg, dataset, model_s, batch_size=50, cuda=1, dims=2048,
                device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
                num_images=50000)
            print("epoch: {}, fid: {}".format(epoch, fid))
            if best_fid is None:
                best_fid = fid
            elif fid < best_fid:
                print("best fid updated: {} -> {}".format(best_fid, fid))
                best_fid = fid
                checkpointer.save("model_tmp_lod{}_fid_{}".format(lod2batch.lod, fid))

        lod2batch.set_epoch(epoch, [encoder_optimizer, decoder_optimizer])

        logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d" % (
            lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(), lod2batch.lod,
            2 ** lod2batch.get_lod_power2(), 2 ** lod2batch.get_lod_power2(),
            lod2batch.get_blend_factor(), len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()
        need_permute = False
        epoch_start_time = time.time()

        i = 0
        diff_kls = []
        batch_kls_real = []
        batch_kls_fake = []
        batch_rec_errs = []

        for x_orig in tqdm(batches):
            i += 1
            with torch.no_grad():
                if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                    continue
                if need_permute:
                    x_orig = x_orig.permute(0, 3, 1, 2)
                x_orig = (x_orig / 127.5 - 1.)

                blend_factor = lod2batch.get_blend_factor()

                needed_resolution = layer_to_resolution[lod2batch.lod]
                x = x_orig

                if lod2batch.in_transition:
                    needed_resolution_prev = layer_to_resolution[lod2batch.lod - 1]
                    x_prev = F.avg_pool2d(x_orig, 2, 2)
                    x_prev_2x = F.interpolate(x_prev, needed_resolution)
                    x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

            x.requires_grad = True

            if epoch < num_vae_epochs:
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()
                loss = model(x, lod2batch.lod, blend_factor, d_train=False, e_train=False)
                tracker.update(dict(loss_e=loss))
                tracker.update(dict(loss_d=loss))
                loss.backward()
                encoder_optimizer.step()
                decoder_optimizer.step()
            else:
                # ------------- Update Encoder ------------- #
                encoder_optimizer.zero_grad()
                loss_e = model(x, lod2batch.lod, blend_factor, d_train=False, e_train=True)
                tracker.update(dict(loss_e=loss_e))
                loss_e.backward()
                encoder_optimizer.step()

                # ------------- Update Decoder ------------- #
                decoder_optimizer.zero_grad()
                loss_d = model(x, lod2batch.lod, blend_factor, d_train=True, e_train=False)
                loss_d.backward()
                tracker.update(dict(loss_d=loss_d))
                decoder_optimizer.step()

                # ------------- Update Statistics ------------- #
                if distributed:
                    tracker.update(dict(rec_loss=model.module.last_rec_loss))
                    tracker.update(dict(real_kl=model.module.last_real_kl))
                    tracker.update(dict(fake_kl=model.module.last_fake_kl))
                    tracker.update(dict(kl_diff=model.module.last_kl_diff))
                    tracker.update(dict(expelbo_f=model.module.last_expelbo_fake))
                    tracker.update(dict(expelbo_r=model.module.last_expelbo_rec))
                    diff_kls.append(model.module.last_kl_diff.data.cpu())
                    batch_kls_real.append(model.module.last_real_kl)
                    batch_kls_fake.append(model.module.last_fake_kl)
                    batch_rec_errs.append(model.module.last_rec_loss)
                else:
                    tracker.update(dict(rec_loss=model.last_rec_loss))
                    tracker.update(dict(real_kl=model.last_real_kl))
                    tracker.update(dict(fake_kl=model.last_fake_kl))
                    tracker.update(dict(kl_diff=model.last_kl_diff))
                    tracker.update(dict(expelbo_f=model.last_expelbo_fake))
                    tracker.update(dict(expelbo_r=model.last_expelbo_rec))
                    diff_kls.append(model.last_kl_diff.data.cpu())
                    batch_kls_real.append(model.last_real_kl)
                    batch_kls_fake.append(model.last_fake_kl)
                    batch_rec_errs.append(model.last_rec_loss)

            if local_rank == 0:
                betta = 0.5 ** (lod2batch.get_batch_size() / (10 * 1000.0))
                model_s.lerp(model, betta)

            epoch_end_time = time.time()
            per_epoch_ptime = epoch_end_time - epoch_start_time

            lod_for_saving_model = lod2batch.lod
            lod2batch.step()
            if local_rank == 0:
                if lod2batch.is_time_to_save():
                    checkpointer.save("model_tmp_intermediate_lod%d" % lod_for_saving_model)
                if lod2batch.is_time_to_report():
                    save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg,
                                encoder_optimizer, decoder_optimizer)

        scheduler.step()

        mean_diff_kl = np.mean(diff_kls)
        print("mean diff kl: ", mean_diff_kl)
        if epoch > num_vae_epochs - 1:
            kls_real.append(np.mean(batch_kls_real))
            kls_fake.append(np.mean(batch_kls_fake))
            rec_errs.append(np.mean(batch_rec_errs))

        if local_rank == 0:
            checkpointer.save("model_tmp_lod%d" % lod_for_saving_model)
            save_sample(lod2batch, tracker, sample, samplez, x, logger, model_s, cfg,
                        encoder_optimizer, decoder_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final").wait()
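# --- Illustration (not part of the training code): both loops keep model_s as an exponential
# --- moving average of the training weights with betta = 0.5 ** (batch_size / 10000), i.e. the
# --- old average's weight halves every ~10,000 images regardless of batch size. The lerp below
# --- is a scalar stand-in for the model's own lerp(), assumed to do p_s = p_s*betta + p*(1-betta).
import torch

def ema_update(p_s, p, batch_size):
    betta = 0.5 ** (batch_size / (10 * 1000.0))
    return torch.lerp(p_s, p, 1.0 - betta)  # p_s * betta + p * (1 - betta)

for batch_size in (1, 16):
    w_avg, w = torch.tensor(1.0), torch.tensor(0.0)
    for _ in range(10000 // batch_size):
        w_avg = ema_update(w_avg, w, batch_size)
    print(batch_size, float(w_avg))  # ~0.5 in both cases: the decay is per image, not per step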
def train(cfg, logger, local_rank, world_size, distributed):
    torch.cuda.set_device(local_rank)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT, layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT, latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS, channels=1)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        model_s = Model(startf=cfg.MODEL.START_CHANNEL_COUNT, layer_count=cfg.MODEL.LAYER_COUNT,
                        maxf=cfg.MODEL.MAX_CHANNEL_COUNT, latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                        mapping_layers=cfg.MODEL.MAPPING_LAYERS, channels=1)
        del model_s.discriminator
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], broadcast_buffers=False,
            bucket_cap_mb=25, find_unused_parameters=True)
        model.device_ids = None
        generator = model.module.generator
        discriminator = model.module.discriminator
        mapping = model.module.mapping
        dlatent_avg = model.module.dlatent_avg
    else:
        generator = model.generator
        discriminator = model.discriminator
        mapping = model.mapping
        dlatent_avg = model.dlatent_avg

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(generator)

    logger.info("Trainable parameters discriminator:")
    count_parameters(discriminator)

    arguments = dict()
    arguments["iteration"] = 0

    generator_optimizer = LREQAdam([
        {'params': generator.parameters()},
        {'params': mapping.parameters()}
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    discriminator_optimizer = LREQAdam([
        {'params': discriminator.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    scheduler = ComboMultiStepLR(
        optimizers={'generator': generator_optimizer,
                    'discriminator': discriminator_optimizer},
        milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
        gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
        reference_batch_size=32)

    model_dict = {
        'discriminator': discriminator,
        'generator': generator,
        'mapping': mapping,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['generator_s'] = model_s.generator
        model_dict['mapping_s'] = model_s.mapping

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(cfg, model_dict,
                                {'generator_optimizer': generator_optimizer,
                                 'discriminator_optimizer': discriminator_optimizer,
                                 'tracker': tracker,
                                 'scheduler': scheduler},
                                logger=logger, save=local_rank == 0)

    extra_checkpoint_data = checkpointer.load()

    arguments.update(extra_checkpoint_data)

    layer_to_resolution = generator.layer_to_resolution

    dlutils.download.mnist()
    mnist = dlutils.reader.Mnist('mnist').items
    mnist = np.asarray([x[1] for x in mnist], np.float32)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg, logger, world_size, dataset_size=len(mnist))

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [generator_optimizer, discriminator_optimizer])

        logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, dataset size: %d" % (
            lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(), lod2batch.lod,
            2 ** lod2batch.get_lod_power2(), 2 ** lod2batch.get_lod_power2(), len(mnist)))

        dlutils.shuffle.shuffle_ndarray(mnist)

        r = 2 ** lod2batch.get_lod_power2()
        mnist_ = F.interpolate(torch.tensor(mnist).view(mnist.shape[0], 1, 28, 28), r).detach().cpu().numpy()

        scheduler.set_batch_size(32)

        model.train()
        need_permute = False

        class BatchCollator(object):
            def __init__(self, device=torch.device("cpu")):
                self.device = device

            def __call__(self, batch):
                with torch.no_grad():
                    x = batch
                    x = torch.tensor(x, requires_grad=True,
                                     device=torch.device(self.device), dtype=torch.float32)
                    return x

        batches = dlutils.batch_provider(mnist_, lod2batch.get_per_GPU_batch_size(),
                                         BatchCollator(local_rank), report_progress=False)

        with torch.autograd.profiler.profile(use_cuda=True, enabled=False) as prof:
            for x_orig in tqdm(batches):
                if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                    continue
                if need_permute:
                    x_orig = x_orig.permute(0, 3, 1, 2)
                x_orig = (x_orig / 127.5 - 1.)

                blend_factor = lod2batch.get_blend_factor()

                needed_resolution = layer_to_resolution[lod2batch.lod]
                x = x_orig

                if lod2batch.in_transition:
                    needed_resolution_prev = layer_to_resolution[lod2batch.lod - 1]
                    x_prev = F.avg_pool2d(x_orig, 2, 2)
                    x_prev_2x = F.interpolate(x_prev, needed_resolution)
                    x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

                discriminator_optimizer.zero_grad()
                loss_d = model(x, lod2batch.lod, blend_factor, d_train=True)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()

                if local_rank == 0:
                    betta = 0.5 ** (lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, betta)

                generator_optimizer.zero_grad()
                loss_g = model(x, lod2batch.lod, blend_factor, d_train=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                generator_optimizer.step()

                lod2batch.step()
                if local_rank == 0 and lod2batch.is_time_to_report():
                    save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg,
                                discriminator_optimizer, generator_optimizer)

        # print(prof.key_averages().table(sort_by="self_cpu_time_total"))

        if local_rank == 0:
            save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg,
                        discriminator_optimizer, generator_optimizer)
            if epoch > 2:
                checkpointer.save("model_tmp")

        scheduler.step()

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final")
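# --- Illustration (not part of the training code): the MNIST loop resizes the 28x28 digits to the
# --- current LOD's power-of-two resolution before batching. Standalone version with fake data;
# --- the array shape and LOD value below are made up.
import numpy as np
import torch
import torch.nn.functional as F

mnist = np.random.randint(0, 256, size=(64, 28, 28)).astype(np.float32)  # stand-in for the real array

r = 2 ** 5   # e.g. LOD power 5 -> 32x32
mnist_ = F.interpolate(torch.tensor(mnist).view(mnist.shape[0], 1, 28, 28), r).detach().cpu().numpy()
batch = torch.tensor(mnist_[:16]) / 127.5 - 1.   # normalized to [-1, 1] exactly as in the loop
print(batch.shape)  # torch.Size([16, 1, 32, 32])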
class Trainer():
    def __init__(self, config, model, optimizer, scheduler, train_iter, valid_iter, fields):
        self._config = config
        self._epochs = config.epochs
        self._serialization_dir = Path(config.serialization_dir)
        self._keep_all_serialized_models = config.keep_all_serialized_models
        self._log_file = config.log_file
        self._model = model
        self._optimizer = optimizer
        self._max_grad_norm = config.grad_norm
        self._scheduler = scheduler
        self._start_decay_epoch = config.start_decay_epoch
        self._checkpointer = Checkpointer(self._serialization_dir,
                                          self._keep_all_serialized_models)
        self._train_iter = train_iter
        self._valid_iter = valid_iter
        self._fields = fields

        if config.gpu < 0:
            self._device = torch.device('cpu')
        else:
            self._device = torch.device('cuda:0')
        self._model.to(self._device)
        print(self._model)

    def run(self):
        best_loss = np.inf
        is_best = False
        for epoch in range(1, self._epochs + 1):
            if epoch >= self._start_decay_epoch:
                self._scheduler.step()
            print("epoch:", epoch)
            train_loss = self._train()
            valid_loss = self._valid()
            if valid_loss < best_loss:
                best_loss = valid_loss
                is_best = True
            else:
                is_best = False
            self._save(epoch, is_best)
            scores = {
                'train/loss': train_loss,
                'valid/loss': valid_loss,
            }
            self._report(epoch, scores)
        return

    def _train(self):
        print('train model')
        model = self._model
        optimizer = self._optimizer
        total_loss = 0
        total_norm = 0
        model.train()
        for batch in tqdm(self._train_iter):
            optimizer.zero_grad()
            output_dict = model(batch.source_tokens, batch.target_tokens)
            loss = output_dict["loss"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), self._max_grad_norm)
            optimizer.step()
            total_loss += loss.item() * len(batch)
            total_norm += len(batch)
        return total_loss / total_norm

    def _valid(self):
        print('valid model')
        model = self._model
        total_loss = 0
        total_norm = 0
        model.eval()
        for batch in tqdm(self._valid_iter):
            with torch.no_grad():
                output_dict = model(batch.source_tokens, batch.target_tokens)
                loss = output_dict["loss"]
            total_loss += loss.item() * len(batch)
            total_norm += len(batch)
        return total_loss / total_norm

    def _evaluate(self):
        pass

    def _save(self, epoch, is_best):
        print('save model')
        model_state = {
            "epoch": epoch,
            "model": self._model.state_dict(),
            "optim": self._optimizer.state_dict(),
            "sched": self._scheduler.state_dict(),
            "fields": self._fields,
            "config": self._config
        }
        self._checkpointer.save(epoch, model_state, is_best)
        return

    def get_best_model_path(self):
        return self._checkpointer.get_best_model_path()

    def _report(self, epoch, scores):
        # lr = self._optimizer.lr
        print("epoch: {}".format(epoch))
        # print("lr: {}".format(lr))
        print("scores:")
        for k, v in scores.items():
            print("\t{} {}".format(k, v))
        with open(self._serialization_dir / self._log_file, "a") as f:
            print("\t".join(map(str, [epoch, *scores.values()])), file=f)
        return
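# --- Usage sketch (illustrative only): everything below except the Trainer's config field names is
# --- made up. The constructor reads config.epochs, serialization_dir, keep_all_serialized_models,
# --- log_file, grad_norm, start_decay_epoch and gpu; batches must expose source_tokens/target_tokens
# --- and the model must return a dict with a "loss" entry, as the training loop above expects.
from argparse import Namespace
from collections import namedtuple
import torch
import torch.nn as nn

Batch = namedtuple("Batch", ["source_tokens", "target_tokens"])

class TinyModel(nn.Module):
    """Stand-in model returning {"loss": ...} from (source_tokens, target_tokens)."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, source_tokens, target_tokens):
        return {"loss": nn.functional.mse_loss(self.proj(source_tokens), target_tokens)}

config = Namespace(epochs=2, serialization_dir="runs/exp1", keep_all_serialized_models=False,
                   log_file="train.log", grad_norm=5.0, start_decay_epoch=2, gpu=-1)
model = TinyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)
data = [Batch(torch.randn(4, 8), torch.randn(4, 8)) for _ in range(3)]

trainer = Trainer(config, model, optimizer, scheduler, train_iter=data, valid_iter=data, fields=None)
trainer.run()
print(trainer.get_best_model_path())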
def train(cfg, logger, local_rank, world_size, distributed):
    torch.cuda.set_device(local_rank)
    model = Model(startf=cfg.MODEL.START_CHANNEL_COUNT, layer_count=cfg.MODEL.LAYER_COUNT,
                  maxf=cfg.MODEL.MAX_CHANNEL_COUNT, latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                  dlatent_avg_beta=cfg.MODEL.DLATENT_AVG_BETA,
                  style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
                  mapping_layers=cfg.MODEL.MAPPING_LAYERS, channels=3)
    model.cuda(local_rank)
    model.train()

    if local_rank == 0:
        model_s = Model(startf=cfg.MODEL.START_CHANNEL_COUNT, layer_count=cfg.MODEL.LAYER_COUNT,
                        maxf=cfg.MODEL.MAX_CHANNEL_COUNT, latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
                        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
                        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
                        mapping_layers=cfg.MODEL.MAPPING_LAYERS, channels=3)
        del model_s.discriminator
        model_s.cuda(local_rank)
        model_s.eval()
        model_s.requires_grad_(False)

    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], broadcast_buffers=False,
            bucket_cap_mb=25, find_unused_parameters=True)
        model.device_ids = None
        generator = model.module.generator
        discriminator = model.module.discriminator
        mapping = model.module.mapping
        dlatent_avg = model.module.dlatent_avg
    else:
        generator = model.generator
        discriminator = model.discriminator
        mapping = model.mapping
        dlatent_avg = model.dlatent_avg

    count_param_override.print = lambda a: logger.info(a)

    logger.info("Trainable parameters generator:")
    count_parameters(generator)

    logger.info("Trainable parameters discriminator:")
    count_parameters(discriminator)

    generator_optimizer = LREQAdam([
        {'params': generator.parameters()},
        {'params': mapping.parameters()}
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    discriminator_optimizer = LREQAdam([
        {'params': discriminator.parameters()},
    ], lr=cfg.TRAIN.BASE_LEARNING_RATE,
        betas=(cfg.TRAIN.ADAM_BETA_0, cfg.TRAIN.ADAM_BETA_1), weight_decay=0)

    scheduler = ComboMultiStepLR(
        optimizers={'generator': generator_optimizer,
                    'discriminator': discriminator_optimizer},
        milestones=cfg.TRAIN.LEARNING_DECAY_STEPS,
        gamma=cfg.TRAIN.LEARNING_DECAY_RATE,
        reference_batch_size=32, base_lr=cfg.TRAIN.LEARNING_RATES)

    model_dict = {
        'discriminator': discriminator,
        'generator': generator,
        'mapping': mapping,
        'dlatent_avg': dlatent_avg
    }
    if local_rank == 0:
        model_dict['generator_s'] = model_s.generator
        model_dict['mapping_s'] = model_s.mapping

    tracker = LossTracker(cfg.OUTPUT_DIR)

    checkpointer = Checkpointer(cfg, model_dict,
                                {'generator_optimizer': generator_optimizer,
                                 'discriminator_optimizer': discriminator_optimizer,
                                 'scheduler': scheduler,
                                 'tracker': tracker},
                                logger=logger, save=local_rank == 0)

    checkpointer.load()
    logger.info("Starting from epoch: %d" % (scheduler.start_epoch()))

    layer_to_resolution = generator.layer_to_resolution

    dataset = TFRecordsDataset(cfg, logger, rank=local_rank, world_size=world_size,
                               buffer_size_mb=1024)

    rnd = np.random.RandomState(3456)
    latents = rnd.randn(32, cfg.MODEL.LATENT_SPACE_SIZE)
    sample = torch.tensor(latents).float().cuda()

    lod2batch = lod_driver.LODDriver(cfg, logger, world_size,
                                     dataset_size=len(dataset) * world_size)

    for epoch in range(scheduler.start_epoch(), cfg.TRAIN.TRAIN_EPOCHS):
        model.train()
        lod2batch.set_epoch(epoch, [generator_optimizer, discriminator_optimizer])

        logger.info("Batch size: %d, Batch size per GPU: %d, LOD: %d - %dx%d, blend: %.3f, dataset size: %d" % (
            lod2batch.get_batch_size(), lod2batch.get_per_GPU_batch_size(), lod2batch.lod,
            2 ** lod2batch.get_lod_power2(), 2 ** lod2batch.get_lod_power2(),
            lod2batch.get_blend_factor(), len(dataset) * world_size))

        dataset.reset(lod2batch.get_lod_power2(), lod2batch.get_per_GPU_batch_size())
        batches = make_dataloader(cfg, logger, dataset, lod2batch.get_per_GPU_batch_size(), local_rank)

        scheduler.set_batch_size(lod2batch.get_batch_size(), lod2batch.lod)

        model.train()
        need_permute = False

        with torch.autograd.profiler.profile(use_cuda=True, enabled=False) as prof:
            for x_orig in tqdm(batches):
                torch.distributed.barrier()
                with torch.no_grad():
                    if x_orig.shape[0] != lod2batch.get_per_GPU_batch_size():
                        continue
                    if need_permute:
                        x_orig = x_orig.permute(0, 3, 1, 2)
                    x_orig = (x_orig / 127.5 - 1.)

                    blend_factor = lod2batch.get_blend_factor()

                    needed_resolution = layer_to_resolution[lod2batch.lod]
                    x = x_orig

                    if lod2batch.in_transition:
                        needed_resolution_prev = layer_to_resolution[lod2batch.lod - 1]
                        x_prev = F.avg_pool2d(x_orig, 2, 2)
                        x_prev_2x = F.interpolate(x_prev, needed_resolution)
                        x = x * blend_factor + x_prev_2x * (1.0 - blend_factor)

                x.requires_grad = True

                discriminator_optimizer.zero_grad()
                loss_d = model(x, lod2batch.lod, blend_factor, d_train=True)
                tracker.update(dict(loss_d=loss_d))
                loss_d.backward()
                discriminator_optimizer.step()

                if local_rank == 0:
                    betta = 0.5 ** (lod2batch.get_batch_size() / (10 * 1000.0))
                    model_s.lerp(model, betta)

                generator_optimizer.zero_grad()
                loss_g = model(x, lod2batch.lod, blend_factor, d_train=False)
                tracker.update(dict(loss_g=loss_g))
                loss_g.backward()
                generator_optimizer.step()

                lod2batch.step()
                if local_rank == 0:
                    if lod2batch.is_time_to_save():
                        checkpointer.save("model_tmp_intermediate")
                    if lod2batch.is_time_to_report():
                        save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg,
                                    discriminator_optimizer, generator_optimizer)

        # print(prof.key_averages().table(sort_by="self_cpu_time_total"))

        scheduler.step()

        if local_rank == 0:
            checkpointer.save("model_tmp")
            save_sample(lod2batch, tracker, sample, x, logger, model_s, cfg,
                        discriminator_optimizer, generator_optimizer)

    logger.info("Training finish!... save training results")
    if local_rank == 0:
        checkpointer.save("model_final").wait()
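# --- Launch sketch (not from the original repo): train() expects a per-GPU local_rank and calls
# --- torch.distributed.barrier(), so a process group must exist before it runs. The backend,
# --- addresses, and the cfg/logger setup below are assumptions, mirroring the config handling in
# --- the entry points further up.
import logging
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(local_rank, world_size, cfg, logger_name):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=local_rank, world_size=world_size)
    try:
        train(cfg, logging.getLogger(logger_name), local_rank, world_size,
              distributed=world_size > 1)
    finally:
        dist.destroy_process_group()

if __name__ == "__main__":
    cfg = get_cfg_defaults()            # same yacs helpers as in the entry points above
    cfg.freeze()
    world_size = torch.cuda.device_count()
    mp.spawn(_worker, args=(world_size, cfg, "logger"), nprocs=world_size)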