def train(config, generator, mask_generator, checkpoint, log_dir, dataset,
          device_ids):
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr_generator'],
                                           betas=(0.5, 0.999))
    optimizer_mask_generator = torch.optim.Adam(
        mask_generator.parameters(),
        lr=train_params['lr_mask_generator'],
        betas=(0.5, 0.999))

    if checkpoint is not None:
        print('loading cpk')
        start_epoch = Logger.load_cpk(
            checkpoint, generator, mask_generator, optimizer_generator,
            None if train_params['lr_mask_generator'] == 0 else
            optimizer_mask_generator)
    else:
        start_epoch = 0
    print(start_epoch)

    # Decay each lr by 10x at the configured epoch milestones; last_epoch is
    # chosen so a resumed run continues the schedule. The mask_generator
    # scheduler restarts from -1 when its lr is 0 (i.e. the net is frozen).
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_mask_generator = MultiStepLR(
        optimizer_mask_generator,
        train_params['epoch_milestones'],
        gamma=0.1,
        last_epoch=-1 + start_epoch * (train_params['lr_mask_generator'] != 0))

    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=6,
                            drop_last=True)

    generator_full = GeneratorFullModel(mask_generator, generator,
                                        train_params)

    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for index, x in enumerate(dataloader):
                # Masks are only predicted from the second epoch onwards.
                predict_mask = epoch >= 1
                losses_generator, generated = generator_full(x, predict_mask)
                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_mask_generator.step()
                optimizer_mask_generator.zero_grad()

                losses = {
                    key: value.mean().detach().data.cpu().numpy()
                    for key, value in losses_generator.items()
                }
                logger.log_iter(losses=losses)

            scheduler_generator.step()
            scheduler_mask_generator.step()

            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'mask_generator': mask_generator,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_mask_generator': optimizer_mask_generator
                },
                inp=x,
                out=generated,
                save_w=True)
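
# The DatasetRepeater used above is not shown in this file. Below is a
# minimal sketch reconstructed from its call site, assuming a standard
# map-style torch Dataset; the real class may differ in detail.
from torch.utils.data import Dataset


class DatasetRepeater(Dataset):
    """Presents `num_repeats` passes over `dataset` as one long epoch."""

    def __init__(self, dataset, num_repeats=100):
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __len__(self):
        # One "epoch" iterates the underlying dataset num_repeats times.
        return self.num_repeats * len(self.dataset)

    def __getitem__(self, idx):
        # Wrap around into the underlying dataset.
        return self.dataset[idx % len(self.dataset)]
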
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir,
          dataset, device_ids):
    # Refer to the "train_params" section of the *.yaml config.
    # It holds the number of epochs, learning rates, milestones, etc.
    train_params = config['train_params']

    # Define the optimizers for the three sub-networks.
    # Refer to the Adam() documentation for details.
    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr_generator'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(
        discriminator.parameters(),
        lr=train_params['lr_discriminator'],
        betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(
        kp_detector.parameters(),
        lr=train_params['lr_kp_detector'],
        betas=(0.5, 0.999))

    if checkpoint is not None:
        # Load pretrained weights if a checkpoint is given. The models passed
        # in are freshly initialized and get filled in by load_cpk().
        start_epoch = Logger.load_cpk(
            checkpoint, generator, discriminator, kp_detector,
            optimizer_generator, optimizer_discriminator,
            None if train_params['lr_kp_detector'] == 0 else
            optimizer_kp_detector)
    else:
        start_epoch = 0

    # Learning-rate schedulers: decay each lr by 10x at the configured epoch
    # milestones. last_epoch is chosen so a resumed run continues the
    # schedule; the kp_detector scheduler restarts from -1 when its lr is 0
    # (i.e. the detector is frozen).
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(
        optimizer_kp_detector,
        train_params['epoch_milestones'],
        gamma=0.1,
        last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))

    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        # Lengthen an epoch by repeating the dataset "num_repeats" times.
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    # `dataset` here is a FramesDataset, a subclass of torch Dataset, so it
    # can be fed straight to DataLoader. Refer to the DataLoader docs.
    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=2,
                            drop_last=True)

    # Build the two full models used for training.
    # TODO: read the constructors of generator and discriminator; the
    # keypoint detector is wired into the generator's full model.
    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    # TODO: read the discriminator; note the generator full model above also
    # holds a discriminator reference -- figure out how the two uses differ.
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    # Move the models to GPU if available.
    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                # Forward pass: the first return value is a dict of losses,
                # the second is the generated output images.
                losses_generator, generated = generator_full(x)
                # Several loss terms are computed; take the mean of each and
                # sum them into a single scalar.
                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                # Step each sub-network's optimizer separately.
                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()

                # Only run adversarial training when the GAN loss is enabled.
                if train_params['loss_weights']['generator_gan'] != 0:
                    optimizer_discriminator.zero_grad()
                    # Score generated vs. real data with the discriminator.
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [
                        val.mean() for val in losses_discriminator.values()
                    ]
                    loss = sum(loss_values)
                    # Update the discriminator.
                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

                # Plain dict.update(): merge the discriminator losses in.
                losses_generator.update(losses_discriminator)
                losses = {
                    key: value.mean().detach().data.cpu().numpy()
                    for key, value in losses_generator.items()
                }
                logger.log_iter(losses=losses)

            # One epoch done: advance the learning-rate schedulers.
            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()

            logger.log_epoch(epoch, {
                'generator': generator,
                'discriminator': discriminator,
                'kp_detector': kp_detector,
                'optimizer_generator': optimizer_generator,
                'optimizer_discriminator': optimizer_discriminator,
                'optimizer_kp_detector': optimizer_kp_detector
            },
                             inp=x,
                             out=generated)
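
# For reference, the imports the PyTorch variants above assume. This is a
# sketch following the first-order-model repository layout; module paths
# may differ in other forks.
import torch
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from tqdm import trange

from logger import Logger
from modules.model import GeneratorFullModel, DiscriminatorFullModel
from frames_dataset import DatasetRepeater
from sync_batchnorm import DataParallelWithCallback
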
def train(config, generator, discriminator, kp_detector, save_dir, dataset):
    train_params = config['train_params']

    # learning-rate schedulers
    gen_lr = MultiStepDecay(learning_rate=train_params['lr_generator'],
                            milestones=train_params['epoch_milestones'],
                            gamma=0.1)
    dis_lr = MultiStepDecay(learning_rate=train_params['lr_discriminator'],
                            milestones=train_params['epoch_milestones'],
                            gamma=0.1)
    kp_lr = MultiStepDecay(learning_rate=train_params['lr_kp_detector'],
                           milestones=train_params['epoch_milestones'],
                           gamma=0.1)

    # optimizers
    if TEST_MODE:
        logging.warning(
            'TEST MODE: Optimizer is SGD, lr is 0.001. run.py: L50')
        optimizer_generator = paddle.optimizer.SGD(
            parameters=generator.parameters(), learning_rate=0.001)
        optimizer_discriminator = paddle.optimizer.SGD(
            parameters=discriminator.parameters(), learning_rate=0.001)
        optimizer_kp_detector = paddle.optimizer.SGD(
            parameters=kp_detector.parameters(), learning_rate=0.001)
    else:
        optimizer_generator = paddle.optimizer.Adam(
            parameters=generator.parameters(), learning_rate=gen_lr)
        optimizer_discriminator = paddle.optimizer.Adam(
            parameters=discriminator.parameters(), learning_rate=dis_lr)
        optimizer_kp_detector = paddle.optimizer.Adam(
            parameters=kp_detector.parameters(), learning_rate=kp_lr)

    # resolve start epoch
    if isinstance(config['ckpt_model']['start_epoch'], int):
        start_epoch = config['ckpt_model']['start_epoch']
    else:
        start_epoch = 0
    logging.info('Start Epoch is: %i' % start_epoch)

    # dataset
    dataloader = paddle.io.DataLoader(dataset,
                                      batch_size=train_params['batch_size'],
                                      shuffle=True,
                                      drop_last=False,
                                      num_workers=4,
                                      use_buffer_reader=True,
                                      use_shared_memory=False)

    # load checkpoint
    ckpt_config = config['ckpt_model']
    has_key = lambda key: key in ckpt_config.keys() and ckpt_config[
        key] is not None
    load_ckpt(ckpt_config, generator, optimizer_generator, kp_detector,
              optimizer_kp_detector, discriminator, optimizer_discriminator)

    # create full models
    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    # load pretrained VGG19 weights for the perceptual loss
    if has_key('vgg19_model'):
        vggVarList = [i for i in generator_full.vgg.parameters()]
        paramset = np.load(ckpt_config['vgg19_model'],
                           allow_pickle=True)['arr_0']
        for var, v in zip(vggVarList, paramset):
            if list(var.shape) == list(v.shape):
                var.set_value(v)
            else:
                logging.warning('VGG19 cannot be loaded')
        logging.info('Pre-trained VGG19 is loaded from *.npz')

    # train
    generator_full.train()
    discriminator_full.train()
    for epoch in trange(start_epoch, train_params['num_epochs']):
        for _step, _x in enumerate(dataloader()):
            # prepare data
            x = dict()
            x['driving'], x['source'] = _x
            x['name'] = ['NULL'] * _x[0].shape[0]
            if TEST_MODE:
                logging.warning('TEST MODE: Input is Fixed run.py: L207')
                x['driving'] = paddle.to_tensor(fake_input)
                x['source'] = paddle.to_tensor(fake_input)
                x['name'] = ['test1', 'test2']

            # train the generator
            losses_generator, generated = generator_full(x.copy())
            loss_values = [val.sum() for val in losses_generator.values()]
            loss = paddle.add_n(loss_values)
            if TEST_MODE:
                print('Check Generator Loss')
                print('\n'.join([
                    '%s:%1.5f' % (k, v.numpy())
                    for k, v in zip(losses_generator.keys(), loss_values)
                ]))
                import pdb
                pdb.set_trace()
            loss.backward()
            optimizer_generator.step()
            optimizer_generator.clear_grad()
            optimizer_kp_detector.step()
            optimizer_kp_detector.clear_grad()

            # train the discriminator
            if train_params['loss_weights']['generator_gan'] != 0:
                optimizer_discriminator.clear_grad()
                losses_discriminator = discriminator_full(x.copy(), generated)
                loss_values = [
                    val.mean() for val in losses_discriminator.values()
                ]
                loss = paddle.add_n(loss_values)
                if TEST_MODE:
                    print('Check Discriminator Loss')
                    print('\n'.join([
                        '%s:%1.5f' % (k, v.numpy()) for k, v in zip(
                            losses_discriminator.keys(), loss_values)
                    ]))
                    import pdb
                    pdb.set_trace()
                loss.backward()
                optimizer_discriminator.step()
                optimizer_discriminator.clear_grad()
            else:
                losses_discriminator = {}

            losses_generator.update(losses_discriminator)
            losses = {
                key: value.mean().detach().numpy()
                for key, value in losses_generator.items()
            }

            # print log
            if _step % 20 == 0:
                logging.info('Epoch:%i\tstep: %i\tLr:%1.7f' %
                             (epoch, _step, optimizer_generator.get_lr()))
                logging.info('\t'.join(
                    ['%s:%1.4f' % (k, v) for k, v in losses.items()]))

        # save checkpoints every 3 epochs
        if epoch % 3 == 0:
            paddle.fluid.save_dygraph(
                generator.state_dict(),
                os.path.join(save_dir, 'epoch%i/G' % epoch))
            paddle.fluid.save_dygraph(
                discriminator.state_dict(),
                os.path.join(save_dir, 'epoch%i/D' % epoch))
            paddle.fluid.save_dygraph(
                kp_detector.state_dict(),
                os.path.join(save_dir, 'epoch%i/KP' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_generator.state_dict(),
                os.path.join(save_dir, 'epoch%i/G' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_discriminator.state_dict(),
                os.path.join(save_dir, 'epoch%i/D' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_kp_detector.state_dict(),
                os.path.join(save_dir, 'epoch%i/KP' % epoch))
            logging.info('Model is saved to:%s' %
                         os.path.join(save_dir, 'epoch%i/' % epoch))

        gen_lr.step()
        dis_lr.step()
        kp_lr.step()
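
# A quick standalone check of the MultiStepDecay schedule used above,
# assuming the paddle 2.x import path; the milestone values here are
# illustrative, not taken from any config. The learning rate is multiplied
# by gamma at each milestone epoch, and step() advances one epoch.
import paddle
from paddle.optimizer.lr import MultiStepDecay

lr = MultiStepDecay(learning_rate=2e-4, milestones=[60, 90], gamma=0.1)
for epoch in range(100):
    # ... one epoch of training would go here ...
    lr.step()
print(lr.last_lr)  # 2e-06: decayed once at epoch 60 and again at epoch 90
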
def train(config, generator, discriminator, kp_detector, save_dir, dataset):
    train_params = config['train_params']

    # learning-rate schedulers
    if paddle.version.full_version in ['1.8.4'] or paddle.version.major == '2':
        gen_lr = MultiStepDecay(learning_rate=train_params['lr_generator'],
                                milestones=train_params['epoch_milestones'],
                                decay_rate=0.1)
        dis_lr = MultiStepDecay(learning_rate=train_params['lr_discriminator'],
                                milestones=train_params['epoch_milestones'],
                                decay_rate=0.1)
        kp_lr = MultiStepDecay(learning_rate=train_params['lr_kp_detector'],
                               milestones=train_params['epoch_milestones'],
                               decay_rate=0.1)
    else:
        gen_lr = train_params['lr_generator']
        dis_lr = train_params['lr_discriminator']
        kp_lr = train_params['lr_kp_detector']

    # optimizers
    if TEST_MODE:
        logging.warning(
            'TEST MODE: Optimizer is SGD, lr is 0.001. train.py: L50')
        optimizer_generator = fluid.optimizer.SGDOptimizer(
            parameter_list=generator.parameters(), learning_rate=0.001)
        optimizer_discriminator = fluid.optimizer.SGDOptimizer(
            parameter_list=discriminator.parameters(), learning_rate=0.001)
        optimizer_kp_detector = fluid.optimizer.SGDOptimizer(
            parameter_list=kp_detector.parameters(), learning_rate=0.001)
    else:
        optimizer_generator = fluid.optimizer.AdamOptimizer(
            parameter_list=generator.parameters(), learning_rate=gen_lr)
        optimizer_discriminator = fluid.optimizer.AdamOptimizer(
            parameter_list=discriminator.parameters(), learning_rate=dis_lr)
        optimizer_kp_detector = fluid.optimizer.AdamOptimizer(
            parameter_list=kp_detector.parameters(), learning_rate=kp_lr)

    # resolve start epoch
    if isinstance(config['ckpt_model']['start_epoch'], int):
        start_epoch = config['ckpt_model']['start_epoch']
    else:
        start_epoch = 0
    logging.info('Start Epoch is: %i' % start_epoch)

    # dataset pipeline
    def index_generator():
        """Yield a shuffled index sequence that covers the dataset
        num_repeats times per epoch.
        """
        order = list(range(len(dataset)))
        order = order * train_params['num_repeats']
        random.shuffle(order)
        for i in order:
            yield i

    _dataset = fluid.io.xmap_readers(dataset.getSample,
                                     index_generator,
                                     process_num=4,
                                     buffer_size=128,
                                     order=False)
    _dataset = fluid.io.batch(_dataset,
                              batch_size=train_params['batch_size'],
                              drop_last=True)
    dataloader = fluid.io.buffered(_dataset, 1)

    ###### Restore Part ######
    ckpt_config = config['ckpt_model']
    has_key = lambda key: key in ckpt_config.keys() and ckpt_config[
        key] is not None
    if has_key('generator'):
        if ckpt_config['generator'][-3:] == 'npz':
            G_param = np.load(ckpt_config['generator'],
                              allow_pickle=True)['arr_0'].item()
            G_param_clean = [(i, G_param[i]) for i in G_param
                             if 'num_batches_tracked' not in i]
            parameter_clean = generator.parameters()
            # The AntiAliasInterpolation2d parameter has no counterpart in
            # the source state dict and must be skipped.
            del (parameter_clean[65])
            for v, b in zip(parameter_clean, G_param_clean):
                v.set_value(b[1])
            logging.info('Generator is loaded from *.npz')
        else:
            param, optim = fluid.load_dygraph(ckpt_config['generator'])
            generator.set_dict(param)
            if optim is not None:
                optimizer_generator.set_dict(optim)
            else:
                logging.info('Optimizer of G is not loaded')
            logging.info('Generator is loaded from *.pdparams')
    if has_key('kp'):
        if ckpt_config['kp'][-3:] == 'npz':
            KD_param = np.load(ckpt_config['kp'],
                               allow_pickle=True)['arr_0'].item()
            KD_param_clean = [(i, KD_param[i]) for i in KD_param
                              if 'num_batches_tracked' not in i]
            parameter_cleans = kp_detector.parameters()
            for v, b in zip(parameter_cleans, KD_param_clean):
                v.set_value(b[1])
            logging.info('KP is loaded from *.npz')
        else:
            param, optim = fluid.load_dygraph(ckpt_config['kp'])
            kp_detector.set_dict(param)
            if optim is not None:
                optimizer_kp_detector.set_dict(optim)
            else:
                logging.info('Optimizer of KP is not loaded')
            logging.info('KP is loaded from *.pdparams')
    if has_key('discriminator'):
        if ckpt_config['discriminator'][-3:] == 'npz':
            D_param = np.load(ckpt_config['discriminator'],
                              allow_pickle=True)['arr_0'].item()
            if 'NULL Place' in ckpt_config['discriminator']:
                # For the Fashion-dataset model trained without spectral_norm.
                ## The default Fashion config does not enable spectral_norm,
                ## yet the official checkpoint contains spectral_norm-specific
                ## parameters, so the order has to be rearranged. A related
                ## issue was filed; the author replied that adding sn makes
                ## little difference:
                ## https://github.com/AliaksandrSiarohin/first-order-model/issues/264
                ## With sn enabled in the config, the weights load via the
                ## regular path in the else branch, so sn is now enabled in
                ## the config and this branch is effectively disabled.
                D_param_clean = [
                    (i, D_param[i]) for i in D_param
                    if 'num_batches_tracked' not in i and 'weight_v' not in i
                    and 'weight_u' not in i
                ]
                for idx in range(len(D_param_clean) // 2):
                    if 'conv.bias' in D_param_clean[idx * 2][0]:
                        D_param_clean[idx * 2], D_param_clean[
                            idx * 2 + 1] = D_param_clean[
                                idx * 2 + 1], D_param_clean[idx * 2]
                parameter_clean = discriminator.parameters()
                for v, b in zip(parameter_clean, D_param_clean):
                    v.set_value(b[1])
            else:
                D_param_clean = list(D_param.items())
                parameter_clean = discriminator.parameters()
                assert len(D_param_clean) == len(parameter_clean)
                # Reorder the parameters:
                ## Paddle:  [conv.weight, conv.bias, weight_u, weight_v]
                ## PyTorch: [conv.bias, conv.weight_orig, weight_u, weight_v]
                for idx in range(len(parameter_clean)):
                    if list(parameter_clean[idx].shape) == list(
                            D_param_clean[idx][1].shape):
                        parameter_clean[idx].set_value(D_param_clean[idx][1])
                    elif parameter_clean[idx].name.split(
                            '.')[-1] == 'w_0' and D_param_clean[
                                idx + 1][0].split('.')[-1] == 'weight_orig':
                        parameter_clean[idx].set_value(D_param_clean[idx +
                                                                     1][1])
                    elif parameter_clean[idx].name.split(
                            '.')[-1] == 'b_0' and D_param_clean[
                                idx - 1][0].split('.')[-1] == 'bias':
                        parameter_clean[idx].set_value(D_param_clean[idx -
                                                                     1][1])
                    else:
                        print('Error', idx)
            logging.info('Discriminator is loaded from *.npz')
        else:
            param, optim = fluid.load_dygraph(ckpt_config['discriminator'])
            discriminator.set_dict(param)
            if optim is not None:
                optimizer_discriminator.set_dict(optim)
            else:
                logging.info('Optimizer of Discriminator is not loaded')
            logging.info('Discriminator is loaded from *.pdparams')
    ###### Restore Part END ######

    # create full models
    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    # load pretrained VGG19 weights for the perceptual loss
    if has_key('vgg19_model'):
        vggVarList = [i for i in generator_full.vgg.parameters()][2:]
        paramset = np.load(ckpt_config['vgg19_model'],
                           allow_pickle=True)['arr_0']
        for var, v in zip(vggVarList, paramset):
            if list(var.shape) == list(v.shape):
                var.set_value(v)
            else:
                logging.warning('VGG19 cannot be loaded')
        logging.info('Pre-trained VGG19 is loaded from *.npz')

    # train
    generator_full.train()
    discriminator_full.train()
    for epoch in trange(start_epoch, train_params['num_epochs']):
        for _step, _x in enumerate(dataloader()):
            # prepare data
            x = dict()
            for _key in _x[0].keys():
                if str(_key) != 'name':
                    x[_key] = dygraph.to_variable(
                        np.stack([_v[_key] for _v in _x],
                                 axis=0).astype(np.float32))
                else:
                    x[_key] = np.stack([_v[_key] for _v in _x], axis=0)
            if TEST_MODE:
                logging.warning('TEST MODE: Input is Fixed train.py: L207')
                x['driving'] = dygraph.to_variable(fake_input)
                x['source'] = dygraph.to_variable(fake_input)
                x['name'] = ['test1', 'test2']

            # train the generator
            losses_generator, generated = generator_full(x.copy())
            loss_values = [
                fluid.layers.reduce_sum(val)
                for val in losses_generator.values()
            ]
            loss = fluid.layers.sum(loss_values)
            if TEST_MODE:
                print('Check Generator Loss')
                print('\n'.join([
                    '%s:%1.5f' % (k, v.numpy())
                    for k, v in zip(losses_generator.keys(), loss_values)
                ]))
                import pdb
                pdb.set_trace()
            loss.backward()
            optimizer_generator.minimize(loss)
            optimizer_generator.clear_gradients()
            optimizer_kp_detector.minimize(loss)
            optimizer_kp_detector.clear_gradients()

            # train the discriminator
            if train_params['loss_weights']['generator_gan'] != 0:
                optimizer_discriminator.clear_gradients()
                losses_discriminator = discriminator_full(x.copy(), generated)
                loss_values = [
                    fluid.layers.reduce_mean(val)
                    for val in losses_discriminator.values()
                ]
                loss = fluid.layers.sum(loss_values)
                if TEST_MODE:
                    print('Check Discriminator Loss')
                    print('\n'.join([
                        '%s:%1.5f' % (k, v.numpy()) for k, v in zip(
                            losses_discriminator.keys(), loss_values)
                    ]))
                    import pdb
                    pdb.set_trace()
                loss.backward()
                optimizer_discriminator.minimize(loss)
                optimizer_discriminator.clear_gradients()
            else:
                losses_discriminator = {}

            losses_generator.update(losses_discriminator)
            losses = {
                key: fluid.layers.reduce_mean(value).detach().numpy()
                for key, value in losses_generator.items()
            }

            # print log
            if _step % 20 == 0:
                logging.info(
                    'Epoch:%i\tstep: %i\tLr:%1.7f' %
                    (epoch, _step, optimizer_generator.current_step_lr()))
                logging.info('\t'.join(
                    ['%s:%1.4f' % (k, v) for k, v in losses.items()]))

        # save checkpoints every 3 epochs
        if epoch % 3 == 0:
            paddle.fluid.save_dygraph(
                generator.state_dict(),
                os.path.join(save_dir, 'epoch%i/G' % epoch))
            paddle.fluid.save_dygraph(
                discriminator.state_dict(),
                os.path.join(save_dir, 'epoch%i/D' % epoch))
            paddle.fluid.save_dygraph(
                kp_detector.state_dict(),
                os.path.join(save_dir, 'epoch%i/KP' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_generator.state_dict(),
                os.path.join(save_dir, 'epoch%i/G' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_discriminator.state_dict(),
                os.path.join(save_dir, 'epoch%i/D' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_kp_detector.state_dict(),
                os.path.join(save_dir, 'epoch%i/KP' % epoch))
            logging.info('Model is saved to:%s' %
                         os.path.join(save_dir, 'epoch%i/' % epoch))

        if paddle.version.full_version in ['1.8.4'
                                           ] or paddle.version.major == '2':
            gen_lr.epoch()
            dis_lr.epoch()
            kp_lr.epoch()
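
# The *.npz checkpoints consumed by the restore code above pack a
# {name: ndarray} dict under numpy's default 'arr_0' key. Below is a
# hypothetical PyTorch-side exporter producing such a file; export_to_npz
# is an assumed helper name, not part of either repository.
import numpy as np
import torch


def export_to_npz(model: torch.nn.Module, path: str) -> None:
    # state_dict tensors -> plain ndarrays, saved as a single pickled object
    # array; read back via np.load(path, allow_pickle=True)['arr_0'].item().
    state = {k: v.detach().cpu().numpy()
             for k, v in model.state_dict().items()}
    np.savez(path, state)
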
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir,
          dataset, device_ids):
    worker_num = 16
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr_generator'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(
        discriminator.parameters(),
        lr=train_params['lr_discriminator'],
        betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(
        kp_detector.parameters(),
        lr=train_params['lr_kp_detector'],
        betas=(0.5, 0.999))

    if checkpoint is not None:
        start_epoch = Logger.load_cpk(
            checkpoint, generator, discriminator, kp_detector,
            optimizer_generator, optimizer_discriminator,
            None if train_params['lr_kp_detector'] == 0 else
            optimizer_kp_detector)
    else:
        start_epoch = 0

    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(
        optimizer_kp_detector,
        train_params['epoch_milestones'],
        gamma=0.1,
        last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))

    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=worker_num,
                            drop_last=True)

    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in range(start_epoch, train_params['num_epochs']):
            for x in tqdm(dataloader, total=len(dataloader)):
                losses_generator, generated = generator_full(x)
                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()

                if train_params['loss_weights']['generator_gan'] != 0:
                    optimizer_discriminator.zero_grad()
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [
                        val.mean() for val in losses_discriminator.values()
                    ]
                    loss = sum(loss_values)
                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

                losses_generator.update(losses_discriminator)
                losses = {
                    key: value.mean().detach().data.cpu().numpy()
                    for key, value in losses_generator.items()
                }
                logger.log_iter(losses=losses)

            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()

            logger.log_epoch(epoch, {
                'generator': generator,
                'discriminator': discriminator,
                'kp_detector': kp_detector,
                'optimizer_generator': optimizer_generator,
                'optimizer_discriminator': optimizer_discriminator,
                'optimizer_kp_detector': optimizer_kp_detector
            },
                             inp=x,
                             out=generated)
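
# Logger.load_cpk is shared by the PyTorch variants above but not shown in
# this file. A minimal sketch of its body, reconstructed from the call sites
# (on the real Logger it is a @staticmethod, and it may restore additional
# state; the mask_generator variant at the top uses a different signature):
def load_cpk(checkpoint_path, generator=None, discriminator=None,
             kp_detector=None, optimizer_generator=None,
             optimizer_discriminator=None, optimizer_kp_detector=None):
    checkpoint = torch.load(checkpoint_path)
    # Every component is optional; restore whatever the caller passed in.
    if generator is not None:
        generator.load_state_dict(checkpoint['generator'])
    if kp_detector is not None:
        kp_detector.load_state_dict(checkpoint['kp_detector'])
    if discriminator is not None:
        discriminator.load_state_dict(checkpoint['discriminator'])
    if optimizer_generator is not None:
        optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
    if optimizer_discriminator is not None:
        optimizer_discriminator.load_state_dict(
            checkpoint['optimizer_discriminator'])
    if optimizer_kp_detector is not None:
        optimizer_kp_detector.load_state_dict(
            checkpoint['optimizer_kp_detector'])
    # Training resumes from the stored epoch.
    return checkpoint['epoch']
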
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir,
          dataset, device_ids):
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr_generator'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(
        discriminator.parameters(),
        lr=train_params['lr_discriminator'],
        betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(
        kp_detector.parameters(),
        lr=train_params['lr_kp_detector'],
        betas=(0.5, 0.999))

    if checkpoint is not None:
        start_epoch = Logger.load_cpk(
            checkpoint, generator, discriminator, kp_detector,
            optimizer_generator, optimizer_discriminator,
            None if train_params['lr_kp_detector'] == 0 else
            optimizer_kp_detector)
    else:
        start_epoch = 0

    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(
        optimizer_kp_detector,
        train_params['epoch_milestones'],
        gamma=0.1,
        last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))

    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            drop_last=True,
                            num_workers=4)
    print_fun(f'Full dataset length (with repeats): {len(dataset)}')

    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)

    writer = tensorboardX.SummaryWriter(log_dir, flush_secs=60)
    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs'],
                            disable=None):
            for i, x in enumerate(dataloader):
                losses_generator, generated = generator_full(x)
                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()

                if train_params['loss_weights']['generator_gan'] != 0:
                    optimizer_discriminator.zero_grad()
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [
                        val.mean() for val in losses_discriminator.values()
                    ]
                    loss = sum(loss_values)
                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

                losses_generator.update(losses_discriminator)
                losses = {
                    key: value.mean().detach().data.cpu().numpy()
                    for key, value in losses_generator.items()
                }
                logger.log_iter(losses=losses)

                step = i + int(epoch * len(dataset) / dataloader.batch_size)
                if step % 20 == 0:
                    print_fun(
                        f'Epoch {epoch + 1}, global step {step}: '
                        f'{", ".join(f"{k}={v}" for k, v in losses.items())}')
                if step != 0 and step % 50 == 0:
                    for k, loss in losses.items():
                        writer.add_scalar(k, float(loss), global_step=step)

                    # Add a source / driving / prediction image strip.
                    source = x['source'][0].detach().cpu().numpy().transpose(
                        [1, 2, 0])
                    driving = x['driving'][0].detach().cpu().numpy().transpose(
                        [1, 2, 0])
                    kp_source = generated['kp_source']['value'][0].detach(
                    ).cpu().numpy()
                    kp_driving = generated['kp_driving']['value'][0].detach(
                    ).cpu().numpy()
                    pred = generated['prediction'][0].detach().cpu().numpy(
                    ).transpose([1, 2, 0])
                    # Keypoints are in [-1, 1]; map them to pixel coordinates.
                    kp_source = kp_source * 127.5 + 127.5
                    kp_driving = kp_driving * 127.5 + 127.5
                    source = cv2.UMat(
                        (source * 255.).clip(0, 255).astype(np.uint8)).get()
                    driving = cv2.UMat(
                        (driving * 255.).clip(0, 255).astype(np.uint8)).get()
                    pred = (pred * 255.).clip(0, 255).astype(np.uint8)
                    for x1, y1 in kp_source:
                        cv2.circle(source, (int(x1), int(y1)), 2,
                                   (250, 250, 250), thickness=cv2.FILLED)
                    for x1, y1 in kp_driving:
                        cv2.circle(driving, (int(x1), int(y1)), 2,
                                   (250, 250, 250), thickness=cv2.FILLED)
                    writer.add_image('SourceDrivingPred',
                                     np.hstack((source, driving, pred)),
                                     global_step=step,
                                     dataformats='HWC')
                    writer.flush()

            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()

            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'discriminator': discriminator,
                    'kp_detector': kp_detector,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_discriminator': optimizer_discriminator,
                    'optimizer_kp_detector': optimizer_kp_detector
                })
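
# A minimal sketch of a driver that would call one of the PyTorch train()
# variants above, in the style of the first-order-model run.py entry point.
# The config path and the build_* constructors are illustrative assumptions;
# the real repo builds each module from config['model_params'] directly.
import yaml
import torch

if __name__ == '__main__':
    with open('config/vox-256.yaml') as f:
        config = yaml.safe_load(f)

    generator = build_generator(config)          # hypothetical helper
    discriminator = build_discriminator(config)  # hypothetical helper
    kp_detector = build_kp_detector(config)      # hypothetical helper
    dataset = FramesDataset(**config['dataset_params'])

    device_ids = list(range(torch.cuda.device_count()))
    train(config, generator, discriminator, kp_detector,
          checkpoint=None, log_dir='log', dataset=dataset,
          device_ids=device_ids)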