示例#1
0
def meta_train(device, dataset_path, continue_id):
    """Meta-train the Embedder (E), Generator (G) and Discriminator (D).

    Every item yielded by the dataset is an ``(i, video)`` pair. One frame is
    popped from ``video`` and used as the target frame ``t``; the remaining
    frames are embedded with E and averaged into ``e_hat``; G then produces
    ``x_hat`` from the target landmarks and ``e_hat``. E and G are optimised
    jointly with D, after which D receives one additional update.

    Args:
        device: CUDA device handed to ``torch.cuda.set_device``. ``None`` or
            ``'cpu'`` selects CPU tensors instead.
        dataset_path: Root directory passed to ``VoxCelebDataset``.
        continue_id: When not ``None``, forwarded to ``load_model`` to restore
            E, G and D before training starts.
    """
    run_start = datetime.now()
    logging.info('===== META-TRAINING =====')
    # GPU / CPU --------------------------------------------------------------------------------------------------------
    if device is not None and device != 'cpu':
        dtype = torch.cuda.FloatTensor
        torch.cuda.set_device(device)
        logging.info(f'Running on GPU: {torch.cuda.current_device()}.')
    else:
        dtype = torch.FloatTensor
        logging.info('Running on CPU.')

    # DATASET-----------------------------------------------------------------------------------------------------------
    logging.info(f'Training using dataset located in {dataset_path}')
    dataset = VoxCelebDataset(root=dataset_path,
                              extension='.vid',
                              shuffle=False,
                              shuffle_frames=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.IMAGE_SIZE),
                                  transforms.CenterCrop(config.IMAGE_SIZE),
                                  transforms.ToTensor(),
                                  transforms.Normalize([0.485, 0.456, 0.406],
                                                       [0.229, 0.224, 0.225]),
                              ]))

    # NETWORK ----------------------------------------------------------------------------------------------------------

    E = network.Embedder().type(dtype)
    G = network.Generator().type(dtype)
    # NOTE(review): 143000 is hard-coded; presumably the number of training
    # videos -- confirm against the dataset before changing it.
    D = network.Discriminator(143000).type(dtype)

    # Checkpoints are restored before the optimizers are built, so the
    # optimizers are bound to the parameters of whatever load_model returns.
    if continue_id is not None:
        E = load_model(E, continue_id)
        G = load_model(G, continue_id)
        D = load_model(D, continue_id)

    optimizer_E_G = Adam(params=list(E.parameters()) + list(G.parameters()),
                         lr=config.LEARNING_RATE_E_G)
    optimizer_D = Adam(params=D.parameters(), lr=config.LEARNING_RATE_D)

    criterion_E_G = network.LossEG(device, feed_forward=True)
    criterion_D = network.LossD(device)

    # TRAINING LOOP ----------------------------------------------------------------------------------------------------
    logging.info(
        f'Starting training loop. Epochs: {config.EPOCHS} Dataset Size: {len(dataset)}'
    )

    for epoch in range(config.EPOCHS):
        epoch_start = datetime.now()
        batch_durations = []

        E.train()
        G.train()
        D.train()

        for batch_num, (i, video) in enumerate(dataset):
            batch_start = datetime.now()

            # Put one frame aside (frame t)
            t = video.pop()

            # Calculate average encoding vector for video
            e_vectors = [
                E(s['frame'].type(dtype), s['landmarks'].type(dtype))
                for s in video
            ]
            e_hat = torch.stack(e_vectors).mean(dim=0)

            # Generate frame using landmarks from frame t
            x_t = t['frame'].type(dtype)
            y_t = t['landmarks'].type(dtype)
            x_hat = G(y_t, e_hat)

            # Optimize E_G and D
            r_x_hat, D_act_hat = D(x_hat, y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_E_G.zero_grad()
            optimizer_D.zero_grad()

            loss_E_G = criterion_E_G(x_t, x_hat, r_x_hat, e_hat, D.W[:, i],
                                     D_act, D_act_hat)
            loss_D = criterion_D(r_x, r_x_hat)
            loss = loss_E_G + loss_D
            # The graph is not retained: the second discriminator pass below
            # works on a detached frame and never back-propagates into E/G.
            loss.backward()

            optimizer_E_G.step()
            optimizer_D.step()

            # Optimize D again. The regenerated frame is detached so that the
            # backward pass stops at D instead of reaching E/G, whose weights
            # were just modified in place by optimizer_E_G.step().
            x_hat_d = G(y_t, e_hat).detach()
            r_x_hat, _ = D(x_hat_d, y_t, i)
            r_x, _ = D(x_t, y_t, i)

            optimizer_D.zero_grad()
            loss_D = criterion_D(r_x, r_x_hat)
            loss_D.backward()
            optimizer_D.step()

            batch_end = datetime.now()
            batch_durations.append(batch_end - batch_start)
            # SHOW PROGRESS --------------------------------------------------------------------------------------------
            if (batch_num + 1) % 100 == 0 or batch_num == 0:
                avg_time = sum(batch_durations,
                               timedelta(0)) / len(batch_durations)
                logging.info(
                    f'Epoch {epoch+1}: [{batch_num + 1}/{len(dataset)}] | '
                    f'Avg Time: {avg_time} | '
                    f'Loss_E_G = {loss_E_G.item():.4} Loss_D {loss_D.item():.4}'
                )
                logging.debug(
                    f'D(x) = {r_x.item():.4} D(x_hat) = {r_x_hat.item():.4}')

            # SAVE IMAGES ----------------------------------------------------------------------------------------------
            if (batch_num + 1) % 100 == 0:
                os.makedirs(config.GENERATED_DIR, exist_ok=True)

                # One timestamp for both files so the real/generated pair
                # always shares the same prefix.
                stamp = f'{datetime.now():%Y%m%d_%H%M}'
                save_image(
                    os.path.join(config.GENERATED_DIR, f'{stamp}_x.png'), x_t)
                save_image(
                    os.path.join(config.GENERATED_DIR, f'{stamp}_x_hat.png'),
                    x_hat)

            # Intermediate checkpoint. NOTE(review): called without run_start,
            # unlike the end-of-epoch save below -- confirm this is intended.
            if (batch_num + 1) % 2000 == 0:
                save_model(E, device)
                save_model(G, device)
                save_model(D, device)

        # SAVE MODELS --------------------------------------------------------------------------------------------------

        save_model(E, device, run_start)
        save_model(G, device, run_start)
        save_model(D, device, run_start)
        epoch_end = datetime.now()
        # Guard against an epoch that processed no batches.
        if batch_durations:
            avg_batch_time = sum(batch_durations,
                                 timedelta(0)) / len(batch_durations)
        else:
            avg_batch_time = timedelta(0)
        logging.info(
            f'Epoch {epoch+1} finished in {epoch_end - epoch_start}. '
            f'Average batch time: {avg_batch_time}'
        )
示例#2
0
def meta_train(gpu, dataset_path, continue_id):
    """Meta-train the Embedder (E), Generator (G) and Discriminator (D) in batches.

    Each batch from the ``DataLoader`` is an ``(i, video)`` pair. The last
    frame along dimension 1 of ``video`` is the target frame ``t``; the
    remaining frames are embedded with E and averaged per video into
    ``e_hat``; G then produces ``x_hat`` from the target landmarks and
    ``e_hat``. E and G are optimised jointly with D, after which D receives
    one additional update on a detached generated frame.

    Args:
        gpu: Truthy when training on GPU. Used for the start-up log message
            and forwarded to ``save_model``.
        dataset_path: Root directory passed to ``VoxCelebDataset``.
        continue_id: When not ``None``, forwarded to ``load_model`` to restore
            E, G and D before training starts.

    Note:
        ``GPU`` (upper case) is resolved outside this function; it is indexed
        by component name -- presumably a module-level device mapping, confirm
        against the rest of the file.
    """
    run_start = datetime.now()
    logging.info('===== META-TRAINING =====')
    logging.info(f'Running on {"GPU" if gpu else "CPU"}.')

    # region DATASET----------------------------------------------------------------------------------------------------
    logging.info(f'Training using dataset located in {dataset_path}')
    raw_dataset = VoxCelebDataset(
        root=dataset_path,
        extension='.vid',
        shuffle_frames=True,
        subset_size=config.SUBSET_SIZE,
        transform=transforms.Compose([
            transforms.Resize(config.IMAGE_SIZE),
            transforms.CenterCrop(config.IMAGE_SIZE),
            transforms.ToTensor(),
        ])
    )
    dataset = DataLoader(raw_dataset, batch_size=config.BATCH_SIZE, shuffle=True)

    # endregion

    # region NETWORK ---------------------------------------------------------------------------------------------------

    E = network.Embedder(GPU['Embedder'])
    G = network.Generator(GPU['Generator'])
    D = network.Discriminator(len(raw_dataset), GPU['Discriminator'])
    criterion_E_G = network.LossEG(config.FEED_FORWARD, GPU['LossEG'])
    criterion_D = network.LossD(GPU['LossD'])

    # Restore checkpoints before building the optimizers so the optimizers
    # are bound to the parameters of whatever load_model returns.
    if continue_id is not None:
        E = load_model(E, continue_id)
        G = load_model(G, continue_id)
        D = load_model(D, continue_id)

    optimizer_E_G = Adam(
        params=list(E.parameters()) + list(G.parameters()),
        lr=config.LEARNING_RATE_E_G
    )
    optimizer_D = Adam(
        params=D.parameters(),
        lr=config.LEARNING_RATE_D
    )

    # endregion

    # region TRAINING LOOP ---------------------------------------------------------------------------------------------
    logging.info(f'Epochs: {config.EPOCHS} Batches: {len(dataset)} Batch Size: {config.BATCH_SIZE}')

    # Images are written on every batch, so make sure the target exists once.
    os.makedirs(config.GENERATED_DIR, exist_ok=True)

    for epoch in range(config.EPOCHS):
        epoch_start = datetime.now()

        E.train()
        G.train()
        D.train()

        for batch_num, (i, video) in enumerate(dataset):

            # region PROCESS BATCH -------------------------------------------------------------------------------------
            batch_start = datetime.now()

            # video [B, K+1, 2, C, W, H]

            # Put one frame aside (frame t)
            t = video[:, -1, ...]  # [B, 2, C, W, H]
            video = video[:, :-1, ...]  # [B, K, 2, C, W, H]
            dims = video.shape

            # Calculate average encoding vector for video
            # Fold the batch and frame dimensions together so E sees one long batch.
            e_in = video.reshape(dims[0] * dims[1], dims[2], dims[3], dims[4], dims[5])  # [BxK, 2, C, W, H]
            x, y = e_in[:, 0, ...], e_in[:, 1, ...]
            e_vectors = E(x, y).reshape(dims[0], dims[1], -1)  # B, K, len(e)
            e_hat = e_vectors.mean(dim=1)

            # Generate frame using landmarks from frame t
            x_t, y_t = t[:, 0, ...], t[:, 1, ...]
            x_hat = G(y_t, e_hat)

            # Optimize E_G and D
            r_x_hat, _ = D(x_hat, y_t, i)
            r_x, _ = D(x_t, y_t, i)

            optimizer_E_G.zero_grad()
            optimizer_D.zero_grad()

            loss_E_G = criterion_E_G(x_t, x_hat, r_x_hat, e_hat, D.W[:, i].transpose(1, 0))
            loss_D = criterion_D(r_x, r_x_hat)
            loss = loss_E_G + loss_D
            loss.backward()

            optimizer_E_G.step()
            optimizer_D.step()

            # Optimize D again
            # Detached so this backward pass stops at D and never reaches E/G.
            x_hat = G(y_t, e_hat).detach()
            r_x_hat, _ = D(x_hat, y_t, i)
            r_x, _ = D(x_t, y_t, i)

            optimizer_D.zero_grad()
            loss_D = criterion_D(r_x, r_x_hat)
            loss_D.backward()
            optimizer_D.step()

            batch_end = datetime.now()

            # endregion

            # region SHOW PROGRESS -------------------------------------------------------------------------------------
            # With an interval of 1 this logs on every batch.
            if (batch_num + 1) % 1 == 0 or batch_num == 0:
                logging.info(f'Epoch {epoch + 1}: [{batch_num + 1}/{len(dataset)}] | '
                             f'Time: {batch_end - batch_start} | '
                             f'Loss_E_G = {loss_E_G.item():.4f} Loss_D = {loss_D.item():.4f}')
                logging.debug(f'D(x) = {r_x.mean().item():.4f} D(x_hat) = {r_x_hat.mean().item():.4f}')
            # endregion

            # region SAVE ----------------------------------------------------------------------------------------------
            # Always overwrite the "last result" pair with the first sample of the batch.
            save_image(os.path.join(config.GENERATED_DIR, 'last_result_x.png'), x_t[0])
            save_image(os.path.join(config.GENERATED_DIR, 'last_result_x_hat.png'), x_hat[0])

            # Every 100 batches keep a timestamped snapshot and checkpoint the models.
            if (batch_num + 1) % 100 == 0:
                save_image(os.path.join(config.GENERATED_DIR, f'{datetime.now():%Y%m%d_%H%M%S%f}_x.png'), x_t[0])
                save_image(os.path.join(config.GENERATED_DIR, f'{datetime.now():%Y%m%d_%H%M%S%f}_x_hat.png'), x_hat[0])

                save_model(E, gpu, run_start)
                save_model(G, gpu, run_start)
                save_model(D, gpu, run_start)

            # endregion

        # SAVE MODELS --------------------------------------------------------------------------------------------------

        save_model(E, gpu, run_start)
        save_model(G, gpu, run_start)
        save_model(D, gpu, run_start)
        epoch_end = datetime.now()
        logging.info(f'Epoch {epoch + 1} finished in {epoch_end - epoch_start}. ')
示例#3
0
def meta_train(gpu, dataset_path, continue_id):
    """Meta-train the Embedder (E), Generator (G) and Discriminator (D), optionally data-parallel.

    Each batch from the ``DataLoader`` is an ``(i, video)`` pair. The last
    frame along dimension 1 of ``video`` is the target frame ``t``; the
    remaining frames are embedded with E and averaged per video into
    ``e_hat``; G then produces ``x_hat`` from the target landmarks and
    ``e_hat``. E and G are optimised jointly with D, after which D receives
    one additional update on a detached generated frame.

    Args:
        gpu: Truthy to use CUDA float tensors and wrap the networks and
            losses for data-parallel execution; falsy to run on CPU. Also
            forwarded to ``save_model``.
        dataset_path: Root directory passed to ``VoxCelebDataset``.
        continue_id: When not ``None``, forwarded to ``load_model`` to restore
            E, G and D before training starts.

    Note:
        Calls ``torch.set_default_tensor_type``, which changes the
        process-wide default tensor type.
    """
    run_start = datetime.now()
    logging.info('===== META-TRAINING =====')
    # GPU / CPU --------------------------------------------------------------------------------------------------------
    if gpu:
        dtype = torch.cuda.FloatTensor
        torch.set_default_tensor_type(dtype)
        logging.info(f'Running on GPU: {torch.cuda.current_device()}.')
    else:
        dtype = torch.FloatTensor
        torch.set_default_tensor_type(dtype)
        logging.info('Running on CPU.')

    # DATASET-----------------------------------------------------------------------------------------------------------
    logging.info(f'Training using dataset located in {dataset_path}')
    raw_dataset = VoxCelebDataset(
        root=dataset_path,
        extension='.vid',
        shuffle_frames=True,
        # subset_size=1,
        transform=transforms.Compose([
            transforms.Resize(config.IMAGE_SIZE),
            transforms.CenterCrop(config.IMAGE_SIZE),
            transforms.ToTensor(),
        ]))
    dataset = DataLoader(raw_dataset,
                         batch_size=config.BATCH_SIZE,
                         shuffle=True)

    # NETWORK ----------------------------------------------------------------------------------------------------------

    E = network.Embedder().type(dtype)
    G = network.Generator().type(dtype)
    D = network.Discriminator(len(raw_dataset)).type(dtype)

    optimizer_E_G = Adam(params=list(E.parameters()) + list(G.parameters()),
                         lr=config.LEARNING_RATE_E_G)
    optimizer_D = Adam(params=D.parameters(), lr=config.LEARNING_RATE_D)

    criterion_E_G = network.LossEG(feed_forward=True)
    criterion_D = network.LossD()

    if gpu:
        E = DataParallel(E)
        G = DataParallel(G)
        D = ParallelDiscriminator(D)
        criterion_E_G = DataParallel(criterion_E_G)
        criterion_D = DataParallel(criterion_D)

    # NOTE(review): checkpoints are loaded after the optimizers were built and
    # after the parallel wrappers were applied; this relies on load_model
    # updating the given module rather than returning a new one -- confirm.
    if continue_id is not None:
        E = load_model(E, 'Embedder', continue_id)
        G = load_model(G, 'Generator', continue_id)
        D = load_model(D, 'Discriminator', continue_id)

    # TRAINING LOOP ----------------------------------------------------------------------------------------------------
    logging.info(f'Starting training loop. '
                 f'Epochs: {config.EPOCHS} '
                 f'Batches: {len(dataset)} '
                 f'Batch Size: {config.BATCH_SIZE}')

    # Images are written on every batch, so make sure the target exists once.
    os.makedirs(config.GENERATED_DIR, exist_ok=True)

    for epoch in range(config.EPOCHS):
        epoch_start = datetime.now()
        batch_durations = []

        E.train()
        G.train()
        D.train()

        for batch_num, (i, video) in enumerate(dataset):
            batch_start = datetime.now()
            video = video.type(dtype)  # [B, K+1, 2, C, W, H]

            # Put one frame aside (frame t)
            t = video[:, -1, ...]  # [B, 2, C, W, H]
            video = video[:, :-1, ...]  # [B, K, 2, C, W, H]
            dims = video.shape

            # Calculate average encoding vector for video
            # Fold the batch and frame dimensions together so E sees one long batch.
            e_in = video.reshape(dims[0] * dims[1], dims[2], dims[3], dims[4],
                                 dims[5])  # [BxK, 2, C, W, H]
            x, y = e_in[:, 0, ...], e_in[:, 1, ...]
            e_vectors = E(x, y).reshape(dims[0], dims[1], -1)  # B, K, len(e)
            e_hat = e_vectors.mean(dim=1)

            # Generate frame using landmarks from frame t
            x_t, y_t = t[:, 0, ...], t[:, 1, ...]
            x_hat = G(y_t, e_hat)

            # Optimize E_G and D
            r_x_hat, D_act_hat = D(x_hat, y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_E_G.zero_grad()
            optimizer_D.zero_grad()

            # .mean() reduces the loss to a scalar before the backward pass.
            loss_E_G = criterion_E_G(x_t, x_hat, r_x_hat, e_hat,
                                     D.W[:, i].transpose(1, 0), D_act,
                                     D_act_hat).mean()
            loss_D = criterion_D(r_x, r_x_hat).mean()
            loss = loss_E_G + loss_D
            loss.backward()

            optimizer_E_G.step()
            optimizer_D.step()

            # Optimize D again
            # Detached so this backward pass stops at D and never reaches E/G.
            x_hat = G(y_t, e_hat).detach()
            r_x_hat, _ = D(x_hat, y_t, i)
            r_x, _ = D(x_t, y_t, i)

            optimizer_D.zero_grad()
            loss_D = criterion_D(r_x, r_x_hat).mean()
            loss_D.backward()
            optimizer_D.step()

            batch_end = datetime.now()
            batch_duration = batch_end - batch_start
            batch_durations.append(batch_duration)
            # SHOW PROGRESS --------------------------------------------------------------------------------------------
            # With an interval of 1 this logs on every batch.
            if (batch_num + 1) % 1 == 0 or batch_num == 0:
                logging.info(
                    f'Epoch {epoch + 1}: [{batch_num + 1}/{len(dataset)}] | '
                    f'Time: {batch_duration} | '
                    f'Loss_E_G = {loss_E_G.item():.4} Loss_D {loss_D.item():.4}'
                )
                logging.debug(
                    f'D(x) = {r_x.mean().item():.4} D(x_hat) = {r_x_hat.mean().item():.4}'
                )

            # SAVE IMAGES ----------------------------------------------------------------------------------------------
            # Always overwrite the "last result" pair with the first sample of the batch.
            save_image(
                os.path.join(config.GENERATED_DIR, 'last_result_x.png'),
                x_t[0])
            save_image(
                os.path.join(config.GENERATED_DIR, 'last_result_x_hat.png'),
                x_hat[0])

            if (batch_num + 1) % 1000 == 0:
                save_image(
                    os.path.join(config.GENERATED_DIR,
                                 f'{datetime.now():%Y%m%d_%H%M%S%f}_x.png'),
                    x_t[0])
                save_image(
                    os.path.join(
                        config.GENERATED_DIR,
                        f'{datetime.now():%Y%m%d_%H%M%S%f}_x_hat.png'),
                    x_hat[0])

            # SAVE MODELS ----------------------------------------------------------------------------------------------
            if (batch_num + 1) % 100 == 0:
                save_model(E, 'Embedder', gpu, run_start)
                save_model(G, 'Generator', gpu, run_start)
                save_model(D, 'Discriminator', gpu, run_start)

        # SAVE MODELS --------------------------------------------------------------------------------------------------

        save_model(E, 'Embedder', gpu, run_start)
        save_model(G, 'Generator', gpu, run_start)
        save_model(D, 'Discriminator', gpu, run_start)
        epoch_end = datetime.now()
        # Guard against an epoch that processed no batches.
        if batch_durations:
            avg_batch_time = sum(batch_durations,
                                 timedelta(0)) / len(batch_durations)
        else:
            avg_batch_time = timedelta(0)
        logging.info(
            f'Epoch {epoch + 1} finished in {epoch_end - epoch_start}. '
            f'Average batch time: {avg_batch_time}'
        )