Example No. 1
def train(config, generator, mask_generator, checkpoint, log_dir, dataset,
          device_ids):
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr_generator'],
                                           betas=(0.5, 0.999))
    optimizer_mask_generator = torch.optim.Adam(
        mask_generator.parameters(),
        lr=train_params['lr_mask_generator'],
        betas=(0.5, 0.999))

    if checkpoint is not None:
        print('loading cpk')
        start_epoch = Logger.load_cpk(
            checkpoint, generator, mask_generator, optimizer_generator,
            None if train_params['lr_mask_generator'] == 0 else
            optimizer_mask_generator)
    else:
        start_epoch = 0

    print(f'start epoch: {start_epoch}')
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
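    # If the mask generator is frozen (lr == 0), its optimizer state was not
    # loaded above, so its scheduler starts fresh (last_epoch=-1); otherwise
    # it resumes from start_epoch - 1 like the generator's.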
    scheduler_mask_generator = MultiStepLR(
        optimizer_mask_generator,
        train_params['epoch_milestones'],
        gamma=0.1,
        last_epoch=-1 + start_epoch * (train_params['lr_mask_generator'] != 0))

    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=6,
                            drop_last=True)

    generator_full = GeneratorFullModel(mask_generator, generator,
                                        train_params)

    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for index, x in enumerate(dataloader):
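                # enable mask prediction only from the second epoch onward (one-epoch warm-up)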
                predict_mask = epoch >= 1
                losses_generator, generated = generator_full(x, predict_mask)

                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_mask_generator.step()
                optimizer_mask_generator.zero_grad()

                losses = {
                    key: value.mean().detach().cpu().numpy()
                    for key, value in losses_generator.items()
                }
                logger.log_iter(losses=losses)

            scheduler_generator.step()
            scheduler_mask_generator.step()

            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'mask_generator': mask_generator,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_mask_generator': optimizer_mask_generator
                },
                inp=x,
                out=generated,
                save_w=True)
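The DatasetRepeater used above is defined elsewhere in the repository (frames_dataset.py in FOMM-style codebases). A minimal sketch of what it amounts to, assuming only the behavior implied by its usage here:

from torch.utils.data import Dataset

class DatasetRepeater(Dataset):
    """Present the underlying dataset num_repeats times so that one
    "epoch" covers more iterations (a sketch, not the repo's exact class)."""

    def __init__(self, dataset, num_repeats=1):
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __len__(self):
        return self.num_repeats * len(self.dataset)

    def __getitem__(self, idx):
        # wrap around into the underlying dataset
        return self.dataset[idx % len(self.dataset)]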
Example No. 2
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, device_ids):
    # Refer to the "train_params" section of the *.yaml config.
    # This includes the number of epochs, learning rates, milestones, etc.
    train_params = config['train_params']

    # Define the optimizers for the three sub-networks.
    # Refer to the Adam() documentation for details.
    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))

    if checkpoint is not None:
        # Load pretrained models if a checkpoint is given.
        # The models passed in are freshly initialized; load_cpk fills them in below.
        start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector,
                                      optimizer_generator, optimizer_discriminator,
                                      None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
    else:
        start_epoch = 0

    # Learning-rate schedulers: decay each optimizer's lr by gamma at the milestone epochs.
    scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
                                        last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))

    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        # Repeat the dataset according to "num_repeats"
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    # Wrap the data in a form the network can consume.
    # Refer to the pytorch DataLoader docs for details.
    # "dataset" here is a FramesDataset, a subclass of torch Dataset, so it can be wrapped as follows.
    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=2, drop_last=True)

    # Initialize the two full models for training.
    # GeneratorFullModel wraps the key point detector together with the generator.
    generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params)
    # DiscriminatorFullModel also receives the generator; note that GeneratorFullModel
    # above holds a discriminator too (for the GAN loss), so mind the difference.
    discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)

    # Move the models to GPU if available
    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)

    with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                # Forward pass: the first return value holds the losses, the second the generated images.
                losses_generator, generated = generator_full(x)

                # Several kinds of loss are computed; take the mean of each and sum them.
                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                # Step each sub-network's optimizer separately.
                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()

                # Train adversarially only when the GAN loss weight is non-zero.
                if train_params['loss_weights']['generator_gan'] != 0:
                    # Bring in the discriminator.
                    optimizer_discriminator.zero_grad()
                    # Score the generated and source data with the discriminator.
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [val.mean() for val in losses_discriminator.values()]
                    loss = sum(loss_values)

                    # Update the discriminator.
                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

                # dict.update merges the discriminator losses into the generator losses for logging.
                losses_generator.update(losses_discriminator)
                losses = {key: value.mean().detach().cpu().numpy() for key, value in losses_generator.items()}
                logger.log_iter(losses=losses)

            # One epoch of work is done here.
            # Advance the learning-rate schedulers.
            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()
            
            logger.log_epoch(epoch, {'generator': generator,
                                     'discriminator': discriminator,
                                     'kp_detector': kp_detector,
                                     'optimizer_generator': optimizer_generator,
                                     'optimizer_discriminator': optimizer_discriminator,
                                     'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)
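Logger.load_cpk is also defined in the repository's logger.py. A hedged sketch of what such a helper typically does, assuming a single checkpoint dict keyed by module/optimizer name plus an 'epoch' entry (the names mirror the call sites above; this is not a verified API):

import torch

def load_cpk_sketch(checkpoint_path, generator, discriminator, kp_detector,
                    optimizer_generator=None, optimizer_discriminator=None,
                    optimizer_kp_detector=None):
    # Restore module and optimizer state from one checkpoint file and
    # return the epoch to resume from.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    generator.load_state_dict(checkpoint['generator'])
    discriminator.load_state_dict(checkpoint['discriminator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    if optimizer_generator is not None:
        optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
    if optimizer_discriminator is not None:
        optimizer_discriminator.load_state_dict(
            checkpoint['optimizer_discriminator'])
    if optimizer_kp_detector is not None:
        optimizer_kp_detector.load_state_dict(
            checkpoint['optimizer_kp_detector'])
    return checkpoint['epoch']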
Example No. 3
def train(config, generator, discriminator, kp_detector, save_dir, dataset):
    train_params = config['train_params']

    # learning_rate_scheduler
    gen_lr = MultiStepDecay(learning_rate=train_params['lr_generator'],
                            milestones=train_params['epoch_milestones'],
                            gamma=0.1)
    dis_lr = MultiStepDecay(learning_rate=train_params['lr_discriminator'],
                            milestones=train_params['epoch_milestones'],
                            gamma=0.1)
    kp_lr = MultiStepDecay(learning_rate=train_params['lr_kp_detector'],
                           milestones=train_params['epoch_milestones'],
                           gamma=0.1)
    # optimizer
    if TEST_MODE:
        logging.warning('TEST MODE: Optimizer is SGD, lr is 0.001. run.py: L50')
        optimizer_generator = paddle.optimizer.SGD(
            parameters=generator.parameters(), learning_rate=0.001)
        optimizer_discriminator = paddle.optimizer.SGD(
            parameters=discriminator.parameters(), learning_rate=0.001)
        optimizer_kp_detector = paddle.optimizer.SGD(
            parameters=kp_detector.parameters(), learning_rate=0.001)
    else:
        optimizer_generator = paddle.optimizer.Adam(
            parameters=generator.parameters(), learning_rate=gen_lr)
        optimizer_discriminator = paddle.optimizer.Adam(
            parameters=discriminator.parameters(), learning_rate=dis_lr)
        optimizer_kp_detector = paddle.optimizer.Adam(
            parameters=kp_detector.parameters(), learning_rate=kp_lr)

    # load start_epoch
    if isinstance(config['ckpt_model']['start_epoch'], int):
        start_epoch = config['ckpt_model']['start_epoch']
    else:
        start_epoch = 0
    logging.info('Start epoch is: %i' % start_epoch)

    # dataset
    dataloader = paddle.io.DataLoader(dataset,
                                      batch_size=train_params['batch_size'],
                                      shuffle=True,
                                      drop_last=False,
                                      num_workers=4,
                                      use_buffer_reader=True,
                                      use_shared_memory=False)

    # load checkpoint
    ckpt_config = config['ckpt_model']
    has_key = lambda key: ckpt_config.get(key) is not None
    load_ckpt(ckpt_config, generator, optimizer_generator, kp_detector,
              optimizer_kp_detector, discriminator, optimizer_discriminator)

    # create full model
    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    # load vgg19
    if has_key('vgg19_model'):
        vggVarList = list(generator_full.vgg.parameters())
        paramset = np.load(ckpt_config['vgg19_model'],
                           allow_pickle=True)['arr_0']
        for var, v in zip(vggVarList, paramset):
            if list(var.shape) == list(v.shape):
                var.set_value(v)
            else:
                logging.warning('VGG19 cannot be loaded')
        logging.info('Pre-trained VGG19 is loaded from *.npz')

    # train
    generator_full.train()
    discriminator_full.train()
    for epoch in trange(start_epoch, train_params['num_epochs']):
        for _step, _x in enumerate(dataloader()):

            # prepare data
            x = dict()
            x['driving'], x['source'] = _x
            x['name'] = ['NULL'] * _x[0].shape[0]
            if TEST_MODE:
                logging.warning('TEST MODE: Input is Fixed run.py: L207')
                x['driving'] = paddle.to_tensor(fake_input)
                x['source'] = paddle.to_tensor(fake_input)
                x['name'] = ['test1', 'test2']

            # train generator
            losses_generator, generated = generator_full(x.copy())
            loss_values = [val.sum() for val in losses_generator.values()]
            loss = paddle.add_n(loss_values)
            if TEST_MODE:
                print('Check Generator Loss')
                print('\n'.join([
                    '%s:%1.5f' % (k, v.numpy())
                    for k, v in zip(losses_generator.keys(), loss_values)
                ]))
                import pdb
                pdb.set_trace()
            loss.backward()
            optimizer_generator.step()
            optimizer_generator.clear_grad()
            optimizer_kp_detector.step()
            optimizer_kp_detector.clear_grad()

            # train discriminator
            if train_params['loss_weights']['generator_gan'] != 0:
                optimizer_discriminator.clear_grad()
                losses_discriminator = discriminator_full(x.copy(), generated)
                loss_values = [
                    val.mean() for val in losses_discriminator.values()
                ]
                loss = paddle.add_n(loss_values)
                if TEST_MODE:
                    print('Check Discriminator Loss')
                    print('\n'.join([
                        '%s:%1.5f' % (k, v.numpy()) for k, v in zip(
                            losses_discriminator.keys(), loss_values)
                    ]))
                    import pdb
                    pdb.set_trace()
                loss.backward()
                optimizer_discriminator.step()
                optimizer_discriminator.clear_grad()
            else:
                losses_discriminator = {}

            losses_generator.update(losses_discriminator)
            losses = {
                key: value.mean().detach().numpy()
                for key, value in losses_generator.items()
            }

            # print log
            if _step % 20 == 0:
                logging.info('Epoch:%i\tstep: %i\tLr:%1.7f' %
                             (epoch, _step, optimizer_generator.get_lr()))
                logging.info('\t'.join(
                    ['%s:%1.4f' % (k, v) for k, v in losses.items()]))

        # save
        if epoch % 3 == 0:
            paddle.fluid.save_dygraph(
                generator.state_dict(),
                os.path.join(save_dir, 'epoch%i/G' % epoch))
            paddle.fluid.save_dygraph(
                discriminator.state_dict(),
                os.path.join(save_dir, 'epoch%i/D' % epoch))
            paddle.fluid.save_dygraph(
                kp_detector.state_dict(),
                os.path.join(save_dir, 'epoch%i/KP' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_generator.state_dict(),
                os.path.join(save_dir, 'epoch%i/G' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_discriminator.state_dict(),
                os.path.join(save_dir, 'epoch%i/D' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_kp_detector.state_dict(),
                os.path.join(save_dir, 'epoch%i/KP' % epoch))
            logging.info('Model is saved to:%s' %
                         os.path.join(save_dir, 'epoch%i/' % epoch))
        gen_lr.step()
        dis_lr.step()
        kp_lr.step()
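Note that, unlike the PyTorch examples, these MultiStepDecay schedulers always start counting from zero even when start_epoch > 0, so a resumed run briefly trains at the initial learning rate. If that matters, one option (an assumption, not part of the original code) is to fast-forward the schedulers before the training loop, or to pass last_epoch=start_epoch - 1 when constructing them:

# Fast-forward the schedulers so the decayed learning rates match the
# epoch the run resumes from.
for _ in range(start_epoch):
    gen_lr.step()
    dis_lr.step()
    kp_lr.step()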
Example No. 4
def train(config, generator, discriminator, kp_detector, save_dir, dataset):
    train_params = config['train_params']

    # learning_rate_scheduler
    if paddle.version.full_version in ['1.8.4'] or paddle.version.major == '2':
        gen_lr = MultiStepDecay(learning_rate=train_params['lr_generator'],
                                milestones=train_params['epoch_milestones'],
                                decay_rate=0.1)
        dis_lr = MultiStepDecay(learning_rate=train_params['lr_discriminator'],
                                milestones=train_params['epoch_milestones'],
                                decay_rate=0.1)
        kp_lr = MultiStepDecay(learning_rate=train_params['lr_kp_detector'],
                               milestones=train_params['epoch_milestones'],
                               decay_rate=0.1)
    else:
        gen_lr = train_params['lr_generator']
        dis_lr = train_params['lr_discriminator']
        kp_lr = train_params['lr_kp_detector']

    # optimizer
    if TEST_MODE:
        logging.warning(
            'TEST MODE: Optimizer is SGD, lr is 0.001. train.py: L50')
        optimizer_generator = fluid.optimizer.SGDOptimizer(
            parameter_list=generator.parameters(), learning_rate=0.001)
        optimizer_discriminator = fluid.optimizer.SGDOptimizer(
            parameter_list=discriminator.parameters(), learning_rate=0.001)
        optimizer_kp_detector = fluid.optimizer.SGDOptimizer(
            parameter_list=kp_detector.parameters(), learning_rate=0.001)
    else:
        optimizer_generator = fluid.optimizer.AdamOptimizer(
            parameter_list=generator.parameters(), learning_rate=gen_lr)
        optimizer_discriminator = fluid.optimizer.AdamOptimizer(
            parameter_list=discriminator.parameters(), learning_rate=dis_lr)
        optimizer_kp_detector = fluid.optimizer.AdamOptimizer(
            parameter_list=kp_detector.parameters(), learning_rate=kp_lr)

    # load start_epoch
    if isinstance(config['ckpt_model']['start_epoch'], int):
        start_epoch = config['ckpt_model']['start_epoch']
    else:
        start_epoch = 0
    logging.info('Start epoch is: %i' % start_epoch)

    # dataset pipeline
    def index_generator():
        """Generate a randomly shuffled sequence of sample indices,
        repeated num_repeats times.
        """
        order = list(range(len(dataset)))
        order = order * train_params['num_repeats']
        random.shuffle(order)
        for i in order:
            yield i

    _dataset = fluid.io.xmap_readers(dataset.getSample,
                                     index_generator,
                                     process_num=4,
                                     buffer_size=128,
                                     order=False)
    _dataset = fluid.io.batch(_dataset,
                              batch_size=train_params['batch_size'],
                              drop_last=True)
    dataloader = fluid.io.buffered(_dataset, 1)

    ###### Restore Part ######
    ckpt_config = config['ckpt_model']
    has_key = lambda key: ckpt_config.get(key) is not None
    if has_key('generator'):
        if ckpt_config['generator'][-3:] == 'npz':
            G_param = np.load(ckpt_config['generator'],
                              allow_pickle=True)['arr_0'].item()
            G_param_clean = [(i, G_param[i]) for i in G_param
                             if 'num_batches_tracked' not in i]
            parameter_clean = generator.parameters()
            # The parameters of AntiAliasInterpolation2d are not in the checkpoint
            # dict and should be ignored.
            del parameter_clean[65]
            for v, b in zip(parameter_clean, G_param_clean):
                v.set_value(b[1])
            logging.info('Generator is loaded from *.npz')
        else:
            param, optim = fluid.load_dygraph(ckpt_config['generator'])
            generator.set_dict(param)
            if optim is not None:
                optimizer_generator.set_dict(optim)
            else:
                logging.info('Optimizer of G is not loaded')
            logging.info('Generator is loaded from *.pdparams')
    if has_key('kp'):
        if ckpt_config['kp'][-3:] == 'npz':
            KD_param = np.load(ckpt_config['kp'],
                               allow_pickle=True)['arr_0'].item()
            KD_param_clean = [(i, KD_param[i]) for i in KD_param
                              if 'num_batches_tracked' not in i]
            parameter_cleans = kp_detector.parameters()
            for v, b in zip(parameter_cleans, KD_param_clean):
                v.set_value(b[1])
            logging.info('KP is loaded from *.npz')
        else:
            param, optim = fluid.load_dygraph(ckpt_config['kp'])
            kp_detector.set_dict(param)
            if optim is not None:
                optimizer_kp_detector.set_dict(optim)
            else:
                logging.info('Optimizer of KP is not loaded')
            logging.info('KP is loaded from *.pdparams')
    if has_key('discriminator'):
        if ckpt_config['discriminator'][-3:] == 'npz':
            D_param = np.load(ckpt_config['discriminator'],
                              allow_pickle=True)['arr_0'].item()
            if 'NULL Place' in ckpt_config['discriminator']:
                # For the Fashion-dataset model trained without spectral_norm enabled.
                ## The default fashion config does not enable spectral_norm, yet the official
                ## ckpt contains spectral_norm-specific parameters, so the order must be rearranged.
                ## A related issue was filed; the author replied that adding sn makes little
                ## difference: https://github.com/AliaksandrSiarohin/first-order-model/issues/264
                ## With sn enabled in the config the ckpt loads via the regular path in the
                ## else branch, so sn is now enabled in the config.
                D_param_clean = [
                    (i, D_param[i]) for i in D_param
                    if 'num_batches_tracked' not in i and 'weight_v' not in i
                    and 'weight_u' not in i
                ]
                for idx in range(len(D_param_clean) // 2):
                    if 'conv.bias' in D_param_clean[idx * 2][0]:
                        D_param_clean[idx * 2], D_param_clean[
                            idx * 2 +
                            1] = D_param_clean[idx * 2 +
                                               1], D_param_clean[idx * 2]
                parameter_clean = discriminator.parameters()
                for v, b in zip(parameter_clean, D_param_clean):
                    v.set_value(b[1])
            else:
                D_param_clean = list(D_param.items())
                parameter_clean = discriminator.parameters()
                assert len(D_param_clean) == len(parameter_clean)
                # Swap the order:
                ## PaddlePaddle: [conv.weight,  conv.bias,         weight_u, weight_v]
                ## pytorch:      [conv.bias,    conv.weight_orig,  weight_u, weight_v]
                for idx in range(len(parameter_clean)):
                    if list(parameter_clean[idx].shape) == list(
                            D_param_clean[idx][1].shape):
                        parameter_clean[idx].set_value(D_param_clean[idx][1])
                    elif parameter_clean[idx].name.split(
                            '.')[-1] == 'w_0' and D_param_clean[
                                idx + 1][0].split('.')[-1] == 'weight_orig':
                        parameter_clean[idx].set_value(D_param_clean[idx +
                                                                     1][1])
                    elif parameter_clean[idx].name.split(
                            '.')[-1] == 'b_0' and D_param_clean[
                                idx - 1][0].split('.')[-1] == 'bias':
                        parameter_clean[idx].set_value(D_param_clean[idx -
                                                                     1][1])
                    else:
                        logging.warning('Parameter %i could not be matched', idx)
            logging.info('Discriminator is loaded from *.npz')
        else:
            param, optim = fluid.load_dygraph(ckpt_config['discriminator'])
            discriminator.set_dict(param)
            if optim is not None:
                optimizer_discriminator.set_dict(optim)
            else:
                logging.info('Optimizer of Discriminator is not loaded')
            logging.info('Discriminator is loaded from *.pdparams')
    ###### Restore Part END ######

    # create model
    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)
    if has_key('vgg19_model'):
        vggVarList = list(generator_full.vgg.parameters())[2:]
        paramset = np.load(ckpt_config['vgg19_model'],
                           allow_pickle=True)['arr_0']
        for var, v in zip(vggVarList, paramset):
            if list(var.shape) == list(v.shape):
                var.set_value(v)
            else:
                logging.warning('VGG19 cannot be loaded')
        logging.info('Pre-trained VGG19 is loaded from *.npz')
    generator_full.train()
    discriminator_full.train()
    for epoch in trange(start_epoch, train_params['num_epochs']):
        for _step, _x in enumerate(dataloader()):
            # prepare data
            x = dict()
            for _key in _x[0].keys():
                if str(_key) != 'name':
                    x[_key] = dygraph.to_variable(
                        np.stack([_v[_key] for _v in _x],
                                 axis=0).astype(np.float32))
                else:
                    x[_key] = np.stack([_v[_key] for _v in _x], axis=0)
            if TEST_MODE:
                logging.warning('TEST MODE: Input is Fixed train.py: L207')
                x['driving'] = dygraph.to_variable(fake_input)
                x['source'] = dygraph.to_variable(fake_input)
                x['name'] = ['test1', 'test2']
            # train generator
            losses_generator, generated = generator_full(x.copy())
            loss_values = [
                fluid.layers.reduce_sum(val)
                for val in losses_generator.values()
            ]
            loss = fluid.layers.sum(loss_values)
            if TEST_MODE:
                print('Check Generator Loss')
                print('\n'.join([
                    '%s:%1.5f' % (k, v.numpy())
                    for k, v in zip(losses_generator.keys(), loss_values)
                ]))
                import pdb
                pdb.set_trace()
            loss.backward()
            optimizer_generator.minimize(loss)
            optimizer_generator.clear_gradients()
            optimizer_kp_detector.minimize(loss)
            optimizer_kp_detector.clear_gradients()

            # train discriminator
            if train_params['loss_weights']['generator_gan'] != 0:
                optimizer_discriminator.clear_gradients()
                losses_discriminator = discriminator_full(x.copy(), generated)
                loss_values = [
                    fluid.layers.reduce_mean(val)
                    for val in losses_discriminator.values()
                ]
                loss = fluid.layers.sum(loss_values)
                if TEST_MODE:
                    print('Check Discriminator Loss')
                    print('\n'.join([
                        '%s:%1.5f' % (k, v.numpy()) for k, v in zip(
                            losses_discriminator.keys(), loss_values)
                    ]))
                    import pdb
                    pdb.set_trace()
                loss.backward()
                optimizer_discriminator.minimize(loss)
                optimizer_discriminator.clear_gradients()
            else:
                losses_discriminator = {}

            losses_generator.update(losses_discriminator)
            losses = {
                key: fluid.layers.reduce_mean(value).detach().numpy()
                for key, value in losses_generator.items()
            }

            # print log
            if _step % 20 == 0:
                logging.info(
                    'Epoch:%i\tstep: %i\tLr:%1.7f' %
                    (epoch, _step, optimizer_generator.current_step_lr()))
                logging.info('\t'.join(
                    ['%s:%1.4f' % (k, v) for k, v in losses.items()]))

        # save
        if epoch % 3 == 0:
            paddle.fluid.save_dygraph(
                generator.state_dict(),
                os.path.join(save_dir, 'epoch%i/G' % epoch))
            paddle.fluid.save_dygraph(
                discriminator.state_dict(),
                os.path.join(save_dir, 'epoch%i/D' % epoch))
            paddle.fluid.save_dygraph(
                kp_detector.state_dict(),
                os.path.join(save_dir, 'epoch%i/KP' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_generator.state_dict(),
                os.path.join(save_dir, 'epoch%i/G' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_discriminator.state_dict(),
                os.path.join(save_dir, 'epoch%i/D' % epoch))
            paddle.fluid.save_dygraph(
                optimizer_kp_detector.state_dict(),
                os.path.join(save_dir, 'epoch%i/KP' % epoch))
            logging.info('Model is saved to:%s' %
                         os.path.join(save_dir, 'epoch%i/' % epoch))
        if paddle.version.full_version in ['1.8.4'] or paddle.version.major == '2':
            gen_lr.epoch()
            dis_lr.epoch()
            kp_lr.epoch()
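For readers unfamiliar with the fluid reader combinators, the pipeline built from index_generator, fluid.io.xmap_readers, fluid.io.batch and fluid.io.buffered above is conceptually equivalent to the plain generator below (an illustration only; it ignores the multi-process mapping and read-ahead buffering):

import random

def conceptual_batches(dataset, num_repeats, batch_size):
    # Shuffle the repeated indices, map each index to a sample, and group
    # the samples into fixed-size batches.
    order = list(range(len(dataset))) * num_repeats
    random.shuffle(order)
    batch = []
    for i in order:
        batch.append(dataset.getSample(i))
        if len(batch) == batch_size:
            yield batch
            batch = []
    # drop_last=True: any incomplete final batch is discarded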
Example No. 5
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir,
          dataset, device_ids):

    worker_num = 16
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr_generator'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(
        discriminator.parameters(),
        lr=train_params['lr_discriminator'],
        betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(),
                                             lr=train_params['lr_kp_detector'],
                                             betas=(0.5, 0.999))

    if checkpoint is not None:
        start_epoch = Logger.load_cpk(
            checkpoint, generator, discriminator, kp_detector,
            optimizer_generator, optimizer_discriminator, None
            if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
    else:
        start_epoch = 0

    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector,
                                        train_params['epoch_milestones'],
                                        gamma=0.1,
                                        last_epoch=-1 + start_epoch *
                                        (train_params['lr_kp_detector'] != 0))

    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=worker_num,
                            drop_last=True)

    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in range(start_epoch, train_params['num_epochs']):
            for x in tqdm(dataloader, total=len(dataloader)):
                losses_generator, generated = generator_full(x)

                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()

                if train_params['loss_weights']['generator_gan'] != 0:
                    optimizer_discriminator.zero_grad()
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [
                        val.mean() for val in losses_discriminator.values()
                    ]
                    loss = sum(loss_values)

                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

                losses_generator.update(losses_discriminator)
                losses = {
                    key: value.mean().detach().cpu().numpy()
                    for key, value in losses_generator.items()
                }
                logger.log_iter(losses=losses)

            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()

            logger.log_epoch(epoch, {
                'generator': generator,
                'discriminator': discriminator,
                'kp_detector': kp_detector,
                'optimizer_generator': optimizer_generator,
                'optimizer_discriminator': optimizer_discriminator,
                'optimizer_kp_detector': optimizer_kp_detector
            },
                             inp=x,
                             out=generated)
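An aside on the lr == 0 idiom seen in these examples: with Adam and no weight decay, a zero learning rate makes every update a no-op, so setting lr_kp_detector to 0 effectively freezes the detector (and its optimizer state is deliberately not restored from the checkpoint). An equivalent and cheaper freeze, if that is all that is needed, is to disable gradients outright (a sketch, not in the original code):

for p in kp_detector.parameters():
    # skips gradient computation entirely instead of multiplying updates by 0
    p.requires_grad_(False)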
Example No. 6
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir,
          dataset, device_ids):
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr_generator'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(
        discriminator.parameters(),
        lr=train_params['lr_discriminator'],
        betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(),
                                             lr=train_params['lr_kp_detector'],
                                             betas=(0.5, 0.999))

    if checkpoint is not None:
        start_epoch = Logger.load_cpk(
            checkpoint, generator, discriminator, kp_detector,
            optimizer_generator, optimizer_discriminator, None
            if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
    else:
        start_epoch = 0

    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector,
                                        train_params['epoch_milestones'],
                                        gamma=0.1,
                                        last_epoch=-1 + start_epoch *
                                        (train_params['lr_kp_detector'] != 0))

    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            drop_last=True,
                            num_workers=4)
    print_fun(f'Full dataset length (with repeats): {len(dataset)}')

    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)

    writer = tensorboardX.SummaryWriter(log_dir, flush_secs=60)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch,
                            train_params['num_epochs'],
                            disable=None):
            for i, x in enumerate(dataloader):
                losses_generator, generated = generator_full(x)

                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()

                if train_params['loss_weights']['generator_gan'] != 0:
                    optimizer_discriminator.zero_grad()
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [
                        val.mean() for val in losses_discriminator.values()
                    ]
                    loss = sum(loss_values)

                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

                losses_generator.update(losses_discriminator)
                losses = {
                    key: value.mean().detach().cpu().numpy()
                    for key, value in losses_generator.items()
                }
                logger.log_iter(losses=losses)

                step = i + int(epoch * len(dataset) / dataloader.batch_size)
                if step % 20 == 0:
                    print_fun(
                        f'Epoch {epoch + 1}, global step {step}: {", ".join([f"{k}={v}" for k, v in losses.items()])}'
                    )

                if step != 0 and step % 50 == 0:
                    for k, loss in losses.items():
                        writer.add_scalar(k, float(loss), global_step=step)
                    # add images
                    source = x['source'][0].detach().cpu().numpy().transpose(
                        [1, 2, 0])
                    driving = x['driving'][0].detach().cpu().numpy().transpose(
                        [1, 2, 0])
                    kp_source = generated['kp_source']['value'][0].detach(
                    ).cpu().numpy()
                    kp_driving = generated['kp_driving']['value'][0].detach(
                    ).cpu().numpy()
                    pred = generated['prediction'][0].detach().cpu().numpy(
                    ).transpose([1, 2, 0])
                    kp_source = kp_source * 127.5 + 127.5
                    kp_driving = kp_driving * 127.5 + 127.5
                    source = cv2.UMat(
                        (source * 255.).clip(0, 255).astype(np.uint8)).get()
                    driving = cv2.UMat(
                        (driving * 255.).clip(0, 255).astype(np.uint8)).get()
                    pred = (pred * 255.).clip(0, 255).astype(np.uint8)
                    for x1, y1 in kp_source:
                        cv2.circle(source, (int(x1), int(y1)),
                                   2, (250, 250, 250),
                                   thickness=cv2.FILLED)
                    for x1, y1 in kp_driving:
                        cv2.circle(driving, (int(x1), int(y1)),
                                   2, (250, 250, 250),
                                   thickness=cv2.FILLED)

                    writer.add_image('SourceDrivingPred',
                                     np.hstack((source, driving, pred)),
                                     global_step=step,
                                     dataformats='HWC')
                    writer.flush()

            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()

            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'discriminator': discriminator,
                    'kp_detector': kp_detector,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_discriminator': optimizer_discriminator,
                    'optimizer_kp_detector': optimizer_kp_detector
                })
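The expression kp * 127.5 + 127.5 in the TensorBoard visualization above maps the model's normalized keypoint coordinates, which lie in [-1, 1], onto pixel positions of a 256x256 frame (127.5 = 255 / 2). A hypothetical helper generalizing this to any square size:

import numpy as np

def kp_to_pixels(kp, size):
    # kp: array of (x, y) pairs in [-1, 1]; returns pixel coordinates for a
    # size x size image (size=256 reproduces the 127.5 scaling above).
    return kp * (size - 1) / 2.0 + (size - 1) / 2.0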