Code example #1
    def train(self, loaders):
        if self.args.loss == 'arcface':
            BACKBONE_RESUME_ROOT = 'D:/face-recognition/stargan-v2-master/ms1m_ir50/backbone_ir50_ms1m_epoch120.pth'

            INPUT_SIZE = [112, 112]
            arcface = IR_50(INPUT_SIZE)

            if os.path.isfile(BACKBONE_RESUME_ROOT):
                arcface.load_state_dict(torch.load(BACKBONE_RESUME_ROOT))
                print("Loading Backbone Checkpoint '{}'".format(
                    BACKBONE_RESUME_ROOT))

            DEVICE = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
            criterion_id = arcface.to(DEVICE)
        elif self.args.loss == 'perceptual':
            criterion_id = network.LossEG(False, 0)

        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema

        optims = self.optims
        for net in nets.keys():
            # only the linear classifier gets a fresh optimizer here;
            # 'linear_classfier' (sic) matches the key used in self.nets
            if net == 'linear_classfier':
                optims[net] = torch.optim.Adam(params=nets[net].parameters(),
                                               lr=args.lr2,
                                               betas=[args.beta1, args.beta2],
                                               weight_decay=args.weight_decay)


        # fetch random validation images for debugging
        if args.dataset == 'mpie':
            fetcher = InputFetcher_mpie(loaders.src, args.latent_dim, 'train')
            fetcher_val = InputFetcher_mpie(loaders.val, args.latent_dim,
                                            'val')

        elif args.dataset == '300vw':
            fetcher = InputFetcher_300vw(loaders.src, args.latent_dim, 'train')
            fetcher_val = InputFetcher_300vw(loaders.val, args.latent_dim,
                                             'val')
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # decay the learning rate every args.decay_every iterations
            if (i + 1) % args.decay_every == 0:
                times = (i + 1) / args.decay_every
                optims = Munch()
                for net in nets.keys():
                    if net == 'fan':
                        continue
                    optims[net] = torch.optim.Adam(
                        params=nets[net].parameters(),
                        lr=args.lr * 0.1**int(times),
                        betas=[args.beta1, args.beta2],
                        weight_decay=args.weight_decay)

            # fetch images and labels
            inputs = next(fetcher)
            # x_label, x2_label, x3_label, x4_label, x1_one_hot, x3_one_hot, x1_id, x3_id

            x1_label = inputs.x_label
            x2_label = inputs.x2_label
            x3_label = inputs.x3_label
            if args.dataset == 'mpie':
                x4_label = inputs.x4_label

                param_x4 = x4_label[:, 0, :].unsqueeze(0)
                param_x4 = param_x4.view(-1, 136).float()

            x1_one_hot, x3_one_hot = inputs.x1_one_hot, inputs.x3_one_hot
            x1_id, x3_id = inputs.x1_id, inputs.x3_id

            param_x1 = x1_label[:, 0, :].unsqueeze(0)
            param_x1 = param_x1.view(-1, 136).float()

            param_x2 = x2_label[:, 0, :].unsqueeze(0)
            param_x2 = param_x2.view(-1, 136).float()

            param_x3 = x3_label[:, 0, :].unsqueeze(0)
            param_x3 = param_x3.view(-1, 136).float()

            one_hot_x1 = x1_one_hot[:, 0, :].unsqueeze(0)
            one_hot_x1 = one_hot_x1.view(-1, 150).float()

            one_hot_x3 = x3_one_hot[:, 0, :].unsqueeze(0)
            one_hot_x3 = one_hot_x3.view(-1, 150).float()

            if args.dataset == '300vw':
                # NOTE: the '300vw' training branch is not implemented in this
                # snippet; the loss logging below assumes the 'mpie' branch ran.
                print('300vw')

            elif args.dataset == 'mpie':
                # train the discriminator
                d_tran_loss, d_tran_losses = compute_d_tran_loss(
                    nets,
                    args,
                    param_x1,
                    param_x2,
                    param_x3,
                    param_x4,
                    one_hot_x1,
                    one_hot_x3,
                    x1_id,
                    x3_id,
                    masks=None,
                    loss_select=args.loss)
                self._reset_grad()
                d_tran_loss.backward()
                optims.linear_discriminator.step()
                moving_average(nets.linear_discriminator,
                               nets_ema.linear_discriminator,
                               beta=0.999)

                # train the classifier
                c_loss, c_losses = compute_c_loss(nets,
                                                  args,
                                                  param_x1,
                                                  param_x2,
                                                  param_x3,
                                                  param_x4,
                                                  one_hot_x1,
                                                  one_hot_x3,
                                                  x1_id,
                                                  x3_id,
                                                  masks=None,
                                                  loss_select=args.loss)
                self._reset_grad()
                c_loss.backward()
                optims.linear_classfier.step()
                moving_average(nets.linear_classfier,
                               nets_ema.linear_classfier,
                               beta=0.999)

                # train the transformer
                t_loss, t_losses = compute_t_loss(nets,
                                                  args,
                                                  param_x1,
                                                  param_x2,
                                                  param_x3,
                                                  param_x4,
                                                  one_hot_x1,
                                                  one_hot_x3,
                                                  x1_id,
                                                  x3_id,
                                                  masks=None,
                                                  loss_select=args.loss)
                self._reset_grad()
                t_loss.backward()
                optims.linear_decoder.step()
                optims.lm_linear_encoder.step()
                optims.id_linear_encoder.step()

                moving_average(nets.linear_decoder,
                               nets_ema.linear_decoder,
                               beta=0.999)
                moving_average(nets.lm_linear_encoder,
                               nets_ema.lm_linear_encoder,
                               beta=0.999)
                moving_average(nets.id_linear_encoder,
                               nets_ema.id_linear_encoder,
                               beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i + 1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([d_tran_losses, t_losses, c_losses],
                                        ['D/', 'G/', 'C/']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)
                for key, value in all_losses.items():
                    self.writer.add_scalar(key, value, i + 1)

            # generate images for debugging
            if (i + 1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema,
                                  args,
                                  inputs=inputs_val,
                                  step=i + 1)

            # save model checkpoints
            if (i + 1) % args.save_every == 0:
                self._save_checkpoint(step=i + 1)

            # compute FID and LPIPS if necessary
            if (i + 1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i + 1, mode='latent')
                calculate_metrics(nets_ema, args, i + 1, mode='reference')
        self.writer.close()
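
This example (and example #3 below) calls moving_average(net, net_ema, beta=0.999) after every optimizer step, but the helper itself is not shown. A minimal sketch of such an exponential-moving-average update, assuming both networks expose identical parameter layouts (this mirrors the usual StarGAN-v2-style helper, not necessarily this project's exact code):

import torch

def moving_average(model, model_ema, beta=0.999):
    # Exponential moving average: p_ema <- beta * p_ema + (1 - beta) * p.
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), model_ema.parameters()):
            p_ema.copy_(torch.lerp(p, p_ema, beta))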
Code example #2
def meta_train(gpu, dataset_path, continue_id):
    run_start = datetime.now()
    logging.info('===== META-TRAINING =====')
    logging.info(f'Running on {"GPU" if gpu else "CPU"}.')

    # region DATASET----------------------------------------------------------------------------------------------------
    logging.info(f'Training using dataset located in {dataset_path}')
    raw_dataset = VoxCelebDataset(
        root=dataset_path,
        extension='.vid',
        shuffle_frames=True,
        subset_size=config.SUBSET_SIZE,
        transform=transforms.Compose([
            transforms.Resize(config.IMAGE_SIZE),
            transforms.CenterCrop(config.IMAGE_SIZE),
            transforms.ToTensor(),
        ])
    )
    dataset = DataLoader(raw_dataset, batch_size=config.BATCH_SIZE, shuffle=True)

    # endregion

    # region NETWORK ---------------------------------------------------------------------------------------------------

    E = network.Embedder(GPU['Embedder'])
    G = network.Generator(GPU['Generator'])
    D = network.Discriminator(len(raw_dataset), GPU['Discriminator'])
    criterion_E_G = network.LossEG(config.FEED_FORWARD, GPU['LossEG'])
    criterion_D = network.LossD(GPU['LossD'])

    optimizer_E_G = Adam(
        params=list(E.parameters()) + list(G.parameters()),
        lr=config.LEARNING_RATE_E_G
    )
    optimizer_D = Adam(
        params=D.parameters(),
        lr=config.LEARNING_RATE_D
    )

    if continue_id is not None:
        E = load_model(E, continue_id)
        G = load_model(G, continue_id)
        D = load_model(D, continue_id)

    # endregion

    # region TRAINING LOOP ---------------------------------------------------------------------------------------------
    logging.info(f'Epochs: {config.EPOCHS} Batches: {len(dataset)} Batch Size: {config.BATCH_SIZE}')

    for epoch in range(config.EPOCHS):
        epoch_start = datetime.now()

        E.train()
        G.train()
        D.train()

        for batch_num, (i, video) in enumerate(dataset):

            # region PROCESS BATCH -------------------------------------------------------------------------------------
            batch_start = datetime.now()

            # video [B, K+1, 2, C, W, H]

            # Put one frame aside (frame t)
            t = video[:, -1, ...]  # [B, 2, C, W, H]
            video = video[:, :-1, ...]  # [B, K, 2, C, W, H]
            dims = video.shape

            # Calculate average encoding vector for video
            e_in = video.reshape(dims[0] * dims[1], dims[2], dims[3],
                                 dims[4], dims[5])  # [BxK, 2, C, W, H]
            x, y = e_in[:, 0, ...], e_in[:, 1, ...]
            e_vectors = E(x, y).reshape(dims[0], dims[1], -1)  # [B, K, len(e)]
            e_hat = e_vectors.mean(dim=1)
 
            # Generate frame using landmarks from frame t
            x_t, y_t = t[:, 0, ...], t[:, 1, ...]
            x_hat = G(y_t, e_hat)

            # Optimize E_G and D
            r_x_hat, _ = D(x_hat, y_t, i)
            r_x, _ = D(x_t, y_t, i)

            optimizer_E_G.zero_grad()
            optimizer_D.zero_grad()

            loss_E_G = criterion_E_G(x_t, x_hat, r_x_hat, e_hat, D.W[:, i].transpose(1, 0))
            loss_D = criterion_D(r_x, r_x_hat)
            loss = loss_E_G + loss_D
            loss.backward()

            optimizer_E_G.step()
            optimizer_D.step()

            # Optimize D again
            x_hat = G(y_t, e_hat).detach()
            r_x_hat, D_act_hat = D(x_hat, y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_D.zero_grad()
            loss_D = criterion_D(r_x, r_x_hat)
            loss_D.backward()
            optimizer_D.step()

            batch_end = datetime.now()

            # endregion

            # region SHOW PROGRESS -------------------------------------------------------------------------------------
            # Log progress after every batch.
            logging.info(f'Epoch {epoch + 1}: [{batch_num + 1}/{len(dataset)}] | '
                         f'Time: {batch_end - batch_start} | '
                         f'Loss_E_G = {loss_E_G.item():.4f} Loss_D = {loss_D.item():.4f}')
            logging.debug(f'D(x) = {r_x.mean().item():.4f} D(x_hat) = {r_x_hat.mean().item():.4f}')
            # endregion

            # region SAVE ----------------------------------------------------------------------------------------------
            save_image(os.path.join(config.GENERATED_DIR, 'last_result_x.png'), x_t[0])
            save_image(os.path.join(config.GENERATED_DIR, 'last_result_x_hat.png'), x_hat[0])

            if (batch_num + 1) % 100 == 0:
                save_image(os.path.join(config.GENERATED_DIR, f'{datetime.now():%Y%m%d_%H%M%S%f}_x.png'), x_t[0])
                save_image(os.path.join(config.GENERATED_DIR, f'{datetime.now():%Y%m%d_%H%M%S%f}_x_hat.png'), x_hat[0])

                save_model(E, gpu, run_start)
                save_model(G, gpu, run_start)
                save_model(D, gpu, run_start)

            # endregion

        # SAVE MODELS --------------------------------------------------------------------------------------------------

        save_model(E, gpu, run_start)
        save_model(G, gpu, run_start)
        save_model(D, gpu, run_start)
        epoch_end = datetime.now()
        logging.info(f'Epoch {epoch + 1} finished in {epoch_end - epoch_start}. ')
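
The video tensor above is laid out as [B, K+1, 2, C, W, H]: K+1 (frame, landmark-image) pairs per video. The loop holds out the last pair as frame t and averages the embedder outputs over the remaining K pairs. A self-contained shape check with dummy tensors; the random projection stands in for the real Embedder:

import torch

B, K, C, W, H = 2, 8, 3, 224, 224
video = torch.randn(B, K + 1, 2, C, W, H)   # (frame, landmarks) pairs

t = video[:, -1, ...]        # held-out frame t: [B, 2, C, W, H]
video = video[:, :-1, ...]   # remaining frames: [B, K, 2, C, W, H]
dims = video.shape

# Fold B and K together so the embedder sees one flat batch.
e_in = video.reshape(dims[0] * dims[1], *dims[2:])  # [B*K, 2, C, W, H]
x, y = e_in[:, 0, ...], e_in[:, 1, ...]

# Stand-in for E(x, y): any [B*K, len_e] tensor has the right shape.
e_vectors = torch.randn(x.shape[0], 512).reshape(dims[0], dims[1], -1)
e_hat = e_vectors.mean(dim=1)   # per-video average embedding: [B, 512]
print(e_hat.shape)              # torch.Size([2, 512])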
Code example #3
    def train(self, loaders):
        if self.args.loss == 'arcface':
            BACKBONE_RESUME_ROOT = 'D:/face-recognition/stargan-v2-master/ms1m_ir50/backbone_ir50_ms1m_epoch120.pth'

            INPUT_SIZE = [112, 112]
            arcface = IR_50(INPUT_SIZE)

            if os.path.isfile(BACKBONE_RESUME_ROOT):
                arcface.load_state_dict(torch.load(BACKBONE_RESUME_ROOT))
                print("Loading Backbone Checkpoint '{}'".format(
                    BACKBONE_RESUME_ROOT))

            DEVICE = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
            criterion_id = arcface.to(DEVICE)
        elif self.args.loss == 'perceptual':
            criterion_id = network.LossEG(False, 0)
        elif self.args.loss == 'lightcnn':
            BACKBONE_RESUME_ROOT = 'D:/face-reenactment/stargan-v2-master/FR_Pretrained_Test/Pretrained/LightCNN/LightCNN_29Layers_V2_checkpoint.pth.tar'

            Model = LightCNN_29Layers_v2()

            if os.path.isfile(BACKBONE_RESUME_ROOT):
                Model = WrappedModel(Model)
                checkpoint = torch.load(BACKBONE_RESUME_ROOT)
                Model.load_state_dict(checkpoint['state_dict'])
                print("Loading Backbone Checkpoint '{}'".format(
                    BACKBONE_RESUME_ROOT))

            DEVICE = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
            criterion_id = Model.to(DEVICE)
        if self.args.id_embed:
            id_embedder = network.vgg_feature(False, 0)
        else:
            id_embedder = None

        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, args.latent_dim, 'train',
                               args.multi)
        fetcher_val = InputFetcher(loaders.val, args.latent_dim, 'val',
                                   args.multi)
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)

            if args.multi:
                x1_1_source = inputs.x1
                x1_2_source = inputs.x2
                x1_3_source = inputs.x3
                x1_4_source = inputs.x4

                x2_target, x2_target_lm = inputs.x5, inputs.x5_lm

                d_loss, d_losses = compute_d_loss_multi(nets,
                                                        args,
                                                        x1_1_source,
                                                        x1_2_source,
                                                        x1_3_source,
                                                        x1_4_source,
                                                        x2_target,
                                                        x2_target_lm,
                                                        masks=None,
                                                        loss_select=args.loss)
                self._reset_grad()
                d_loss.backward()
                optims.discriminator.step()

                # train the generator
                g_loss, g_losses = compute_g_loss_multi(nets,
                                                        args,
                                                        x1_1_source,
                                                        x1_2_source,
                                                        x1_3_source,
                                                        x1_4_source,
                                                        x2_target,
                                                        x2_target_lm,
                                                        criterion_id,
                                                        masks=None,
                                                        loss_select=args.loss)
                self._reset_grad()
                g_loss.backward()
                optims.generator.step()
                if args.id_embed:
                    optims.mlp.step()

            else:
                x1_source_lm = None
                x1_source = inputs.x1
                x2_target, x2_target_lm = inputs.x2, inputs.x2_lm

                # train the discriminator
                d_loss, d_losses = compute_d_loss(nets,
                                                  args,
                                                  x1_source,
                                                  x1_source_lm,
                                                  x2_target,
                                                  x2_target_lm,
                                                  masks=None,
                                                  loss_select=args.loss,
                                                  embedder=id_embedder)
                self._reset_grad()
                d_loss.backward()
                optims.discriminator.step()

                # train the generator
                g_loss, g_losses = compute_g_loss(nets,
                                                  args,
                                                  x1_source,
                                                  x1_source_lm,
                                                  x2_target,
                                                  x2_target_lm,
                                                  criterion_id,
                                                  masks=None,
                                                  loss_select=args.loss,
                                                  embedder=id_embedder)
                self._reset_grad()
                g_loss.backward()
                optims.generator.step()
                if args.id_embed:
                    optims.mlp.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            if args.id_embed:
                moving_average(nets.mlp, nets_ema.mlp, beta=0.999)
            else:
                moving_average(nets.style_encoder,
                               nets_ema.style_encoder,
                               beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i + 1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([d_losses, g_losses], ['D/', 'G/']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)
                for key, value in all_losses.items():
                    self.writer.add_scalar(key, value, i + 1)

            # generate images for debugging
            if (i + 1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)

                utils.debug_image(nets_ema,
                                  args,
                                  inputs=inputs_val,
                                  step=i + 1,
                                  embedder=id_embedder)

            # save model checkpoints
            if (i + 1) % args.save_every == 0:
                self._save_checkpoint(step=i + 1)

            # compute FID and LPIPS if necessary
            if (i + 1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i + 1, mode='latent')
                # calculate_metrics(nets_ema, args, i+1, mode='reference')
        self.writer.close()
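
Examples #1 and #3 call self._reset_grad() before every backward pass, but the method is not reproduced. Assuming self.optims is the Munch of optimizers seen above, a minimal sketch inside the same solver class would be:

    def _reset_grad(self):
        # Zero the gradients of every registered optimizer
        # before computing a fresh backward pass.
        for optim in self.optims.values():
            optim.zero_grad()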
Code example #4
File: run.py  Project: yvonwin/talking-heads
def meta_train(device, dataset_path, continue_id):
    run_start = datetime.now()
    logging.info('===== META-TRAINING =====')
    # GPU / CPU --------------------------------------------------------------------------------------------------------
    if device is not None and device != 'cpu':
        dtype = torch.cuda.FloatTensor
        torch.cuda.set_device(device)
        logging.info(f'Running on GPU: {torch.cuda.current_device()}.')
    else:
        dtype = torch.FloatTensor
        logging.info('Running on CPU.')

    # DATASET-----------------------------------------------------------------------------------------------------------
    logging.info(f'Training using dataset located in {dataset_path}')
    dataset = VoxCelebDataset(root=dataset_path,
                              extension='.vid',
                              shuffle=False,
                              shuffle_frames=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.IMAGE_SIZE),
                                  transforms.CenterCrop(config.IMAGE_SIZE),
                                  transforms.ToTensor(),
                                  transforms.Normalize([0.485, 0.456, 0.406],
                                                       [0.229, 0.224, 0.225]),
                              ]))

    # NETWORK ----------------------------------------------------------------------------------------------------------

    E = network.Embedder().type(dtype)
    G = network.Generator().type(dtype)
    D = network.Discriminator(143000).type(dtype)

    if continue_id is not None:
        E = load_model(E, continue_id)
        G = load_model(G, continue_id)
        D = load_model(D, continue_id)

    optimizer_E_G = Adam(params=list(E.parameters()) + list(G.parameters()),
                         lr=config.LEARNING_RATE_E_G)
    optimizer_D = Adam(params=D.parameters(), lr=config.LEARNING_RATE_D)

    criterion_E_G = network.LossEG(device, feed_forward=True)
    criterion_D = network.LossD(device)

    # TRAINING LOOP ----------------------------------------------------------------------------------------------------
    logging.info(
        f'Starting training loop. Epochs: {config.EPOCHS} Dataset Size: {len(dataset)}'
    )

    for epoch in range(config.EPOCHS):
        epoch_start = datetime.now()
        batch_durations = []

        E.train()
        G.train()
        D.train()

        for batch_num, (i, video) in enumerate(dataset):
            batch_start = datetime.now()

            # Put one frame aside (frame t)
            t = video.pop()

            # Calculate average encoding vector for video
            e_vectors = []
            for s in video:
                x_s = s['frame'].type(dtype)
                y_s = s['landmarks'].type(dtype)
                e_vectors.append(E(x_s, y_s))
            e_hat = torch.stack(e_vectors).mean(dim=0)

            # Generate frame using landmarks from frame t
            x_t = t['frame'].type(dtype)
            y_t = t['landmarks'].type(dtype)
            x_hat = G(y_t, e_hat)

            # Optimize E_G and D
            r_x_hat, D_act_hat = D(x_hat, y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_E_G.zero_grad()
            optimizer_D.zero_grad()

            loss_E_G = criterion_E_G(x_t, x_hat, r_x_hat, e_hat, D.W[:, i],
                                     D_act, D_act_hat)
            loss_D = criterion_D(r_x, r_x_hat)
            loss = loss_E_G + loss_D
            loss.backward(retain_graph=True)

            optimizer_E_G.step()
            optimizer_D.step()

            # Optimize D again
            r_x_hat, D_act_hat = D(G(y_t, e_hat), y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_D.zero_grad()
            loss_D = criterion_D(r_x, r_x_hat)
            loss_D.backward()
            optimizer_D.step()

            batch_end = datetime.now()
            batch_durations.append(batch_end - batch_start)
            # SHOW PROGRESS --------------------------------------------------------------------------------------------
            if (batch_num + 1) % 100 == 0 or batch_num == 0:
                avg_time = sum(batch_durations,
                               timedelta(0)) / len(batch_durations)
                logging.info(
                    f'Epoch {epoch+1}: [{batch_num + 1}/{len(dataset)}] | '
                    f'Avg Time: {avg_time} | '
                    f'Loss_E_G = {loss_E_G.item():.4} Loss_D {loss_D.item():.4}'
                )
                logging.debug(
                    f'D(x) = {r_x.item():.4} D(x_hat) = {r_x_hat.item():.4}')

            # SAVE IMAGES ----------------------------------------------------------------------------------------------
            if (batch_num + 1) % 100 == 0:
                os.makedirs(config.GENERATED_DIR, exist_ok=True)

                save_image(
                    os.path.join(config.GENERATED_DIR,
                                 f'{datetime.now():%Y%m%d_%H%M}_x.png'), x_t)
                save_image(
                    os.path.join(config.GENERATED_DIR,
                                 f'{datetime.now():%Y%m%d_%H%M}_x_hat.png'),
                    x_hat)

            if (batch_num + 1) % 2000 == 0:
                save_model(E, device, run_start)
                save_model(G, device, run_start)
                save_model(D, device, run_start)

        # SAVE MODELS --------------------------------------------------------------------------------------------------

        save_model(E, device, run_start)
        save_model(G, device, run_start)
        save_model(D, device, run_start)
        epoch_end = datetime.now()
        logging.info(
            f'Epoch {epoch+1} finished in {epoch_end - epoch_start}. '
            f'Average batch time: {sum(batch_durations, timedelta(0)) / len(batch_durations)}'
        )
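
Unlike examples #2 and #5, this variant normalizes inputs with ImageNet statistics, so the tensors handed to save_image are no longer in [0, 1]. If the saved frames look washed out, a small helper (hypothetical, not part of the original code) can invert that exact Normalize first:

import torch

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def denormalize(img):
    # Invert transforms.Normalize(mean, std) on a [C, H, W] tensor
    # and clamp back into the displayable [0, 1] range.
    return (img * IMAGENET_STD + IMAGENET_MEAN).clamp(0, 1)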
Code example #5
def meta_train(gpu, dataset_path, continue_id):
    run_start = datetime.now()
    logging.info('===== META-TRAINING =====')
    # GPU / CPU --------------------------------------------------------------------------------------------------------
    if gpu:
        dtype = torch.cuda.FloatTensor
        torch.set_default_tensor_type(dtype)
        logging.info(f'Running on GPU: {torch.cuda.current_device()}.')
    else:
        dtype = torch.FloatTensor
        torch.set_default_tensor_type(dtype)
        logging.info('Running on CPU.')

    # DATASET-----------------------------------------------------------------------------------------------------------
    logging.info(f'Training using dataset located in {dataset_path}')
    raw_dataset = VoxCelebDataset(
        root=dataset_path,
        extension='.vid',
        shuffle_frames=True,
        # subset_size=1,
        transform=transforms.Compose([
            transforms.Resize(config.IMAGE_SIZE),
            transforms.CenterCrop(config.IMAGE_SIZE),
            transforms.ToTensor(),
        ]))
    dataset = DataLoader(raw_dataset,
                         batch_size=config.BATCH_SIZE,
                         shuffle=True)

    # NETWORK ----------------------------------------------------------------------------------------------------------

    E = network.Embedder().type(dtype)
    G = network.Generator().type(dtype)
    D = network.Discriminator(len(raw_dataset)).type(dtype)

    optimizer_E_G = Adam(params=list(E.parameters()) + list(G.parameters()),
                         lr=config.LEARNING_RATE_E_G)
    optimizer_D = Adam(params=D.parameters(), lr=config.LEARNING_RATE_D)

    criterion_E_G = network.LossEG(feed_forward=True)
    criterion_D = network.LossD()

    if gpu:
        E = DataParallel(E)
        G = DataParallel(G)
        D = ParallelDiscriminator(D)
        criterion_E_G = DataParallel(criterion_E_G)
        criterion_D = DataParallel(criterion_D)

    if continue_id is not None:
        E = load_model(E, 'Embedder', continue_id)
        G = load_model(G, 'Generator', continue_id)
        D = load_model(D, 'Discriminator', continue_id)

    # TRAINING LOOP ----------------------------------------------------------------------------------------------------
    logging.info(f'Starting training loop. '
                 f'Epochs: {config.EPOCHS} '
                 f'Batches: {len(dataset)} '
                 f'Batch Size: {config.BATCH_SIZE}')

    for epoch in range(config.EPOCHS):
        epoch_start = datetime.now()
        batch_durations = []

        E.train()
        G.train()
        D.train()

        for batch_num, (i, video) in enumerate(dataset):
            batch_start = datetime.now()
            video = video.type(dtype)  # [B, K+1, 2, C, W, H]

            # Put one frame aside (frame t)
            t = video[:, -1, ...]  # [B, 2, C, W, H]
            video = video[:, :-1, ...]  # [B, K, 2, C, W, H]
            dims = video.shape

            # Calculate average encoding vector for video
            e_in = video.reshape(dims[0] * dims[1], dims[2], dims[3], dims[4],
                                 dims[5])  # [BxK, 2, C, W, H]
            x, y = e_in[:, 0, ...], e_in[:, 1, ...]
            e_vectors = E(x, y).reshape(dims[0], dims[1], -1)  # B, K, len(e)
            e_hat = e_vectors.mean(dim=1)

            # Generate frame using landmarks from frame t
            x_t, y_t = t[:, 0, ...], t[:, 1, ...]
            x_hat = G(y_t, e_hat)

            # Optimize E_G and D
            r_x_hat, D_act_hat = D(x_hat, y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_E_G.zero_grad()
            optimizer_D.zero_grad()

            loss_E_G = criterion_E_G(x_t, x_hat, r_x_hat, e_hat,
                                     D.W[:, i].transpose(1, 0), D_act,
                                     D_act_hat).mean()
            loss_D = criterion_D(r_x, r_x_hat).mean()
            loss = loss_E_G + loss_D
            loss.backward()

            optimizer_E_G.step()
            optimizer_D.step()

            # Optimize D again
            x_hat = G(y_t, e_hat).detach()
            r_x_hat, D_act_hat = D(x_hat, y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_D.zero_grad()
            loss_D = criterion_D(r_x, r_x_hat).mean()
            loss_D.backward()
            optimizer_D.step()

            batch_end = datetime.now()
            batch_duration = batch_end - batch_start
            batch_durations.append(batch_duration)
            # SHOW PROGRESS --------------------------------------------------------------------------------------------
            # Log progress after every batch.
            logging.info(
                f'Epoch {epoch + 1}: [{batch_num + 1}/{len(dataset)}] | '
                f'Time: {batch_duration} | '
                f'Loss_E_G = {loss_E_G.item():.4} Loss_D = {loss_D.item():.4}'
            )
            logging.debug(
                f'D(x) = {r_x.mean().item():.4} D(x_hat) = {r_x_hat.mean().item():.4}'
            )

            # SAVE IMAGES ----------------------------------------------------------------------------------------------
            save_image(
                os.path.join(config.GENERATED_DIR, 'last_result_x.png'),
                x_t[0])
            save_image(
                os.path.join(config.GENERATED_DIR, 'last_result_x_hat.png'),
                x_hat[0])

            if (batch_num + 1) % 1000 == 0:
                save_image(
                    os.path.join(config.GENERATED_DIR,
                                 f'{datetime.now():%Y%m%d_%H%M%S%f}_x.png'),
                    x_t[0])
                save_image(
                    os.path.join(
                        config.GENERATED_DIR,
                        f'{datetime.now():%Y%m%d_%H%M%S%f}_x_hat.png'),
                    x_hat[0])

            # SAVE MODELS ----------------------------------------------------------------------------------------------
            if (batch_num + 1) % 100 == 0:
                save_model(E, 'Embedder', gpu, run_start)
                save_model(G, 'Generator', gpu, run_start)
                save_model(D, 'Discriminator', gpu, run_start)

        # SAVE MODELS --------------------------------------------------------------------------------------------------

        save_model(E, 'Embedder', gpu, run_start)
        save_model(G, 'Generator', gpu, run_start)
        save_model(D, 'Discriminator', gpu, run_start)
        epoch_end = datetime.now()
        logging.info(
            f'Epoch {epoch + 1} finished in {epoch_end - epoch_start}. '
            f'Average batch time: {sum(batch_durations, timedelta(0)) / len(batch_durations)}'
        )
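
In examples #2 and #5, D.W[:, i].transpose(1, 0) selects one learned column of the discriminator's per-video embedding matrix for each index in the batch i and feeds it to the embedding-matching term of LossEG. A toy shape check, assuming an embedding size of 512:

import torch

num_videos, len_e, B = 143000, 512, 4
W = torch.randn(len_e, num_videos)       # one learned column per training video
i = torch.randint(0, num_videos, (B,))   # video indices for the current batch

w_i = W[:, i].transpose(1, 0)            # [B, len_e]: one embedding per sample
print(w_i.shape)                         # torch.Size([4, 512])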