コード例 #1
0
 def train(self, d_model_A, d_model_B, g_model_AtoB, g_model_BtoA,
           c_model_AtoB, c_model_BtoA, dataset, n_epochs=50, n_batch=1):
     '''Training loop for the CycleGAN models.

     Every iteration draws one real batch per domain, generates fakes with
     both generators, passes the fakes through the image pools and then
     updates, in this order: generator B->A (via its composite model),
     discriminator A, generator A->B (via its composite model) and
     discriminator B.

     Args:
         d_model_A, d_model_B: discriminators for domain A and domain B.
         g_model_AtoB, g_model_BtoA: generators for each direction.
         c_model_AtoB, c_model_BtoA: composite models used to update the
             generators; their train_on_batch must return five values.
         dataset: pair (trainA, trainB) of training image collections.
         n_epochs: number of passes over trainA (default 50).
         n_batch: number of samples per batch (default 1).
     '''
     # determine the output square shape of the discriminator
     n_patch = d_model_A.output_shape[1]
     # unpack dataset
     trainA, trainB = dataset
     # prepare image pool for fakes
     poolA, poolB = list(), list()
     # calculate the number of batches per training epoch
     bat_per_epo = int(len(trainA) / n_batch)
     # calculate the number of training iterations
     n_steps = bat_per_epo * n_epochs
     # manually enumerate epochs
     for i in range(n_steps):
         # select a batch of real samples
         X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch)
         X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)
         # generate a batch of fake samples
         X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB,
                                                  n_patch)
         X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA,
                                                  n_patch)
         # update fakes from pool
         X_fakeA = update_image_pool(poolA, X_fakeA)
         X_fakeB = update_image_pool(poolB, X_fakeB)
         # update generator B->A via adversarial and cycle loss
         g_loss2, _, _, _, _ = c_model_BtoA.train_on_batch(
             [X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])
         # update discriminator for A -> [real/fake]
         dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
         dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)
         # update generator A->B via adversarial and cycle loss
         g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch(
             [X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])
         # update discriminator for B -> [real/fake]
         dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
         dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)
         # summarize performance
         print('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' %
               (i + 1, dA_loss1, dA_loss2, dB_loss1, dB_loss2, g_loss1,
                g_loss2))
         # evaluate the model performance at the end of every epoch
         if (i + 1) % (bat_per_epo * 1) == 0:
             # plot A->B translation
             summarize_performance(i, g_model_AtoB, trainA, 'AtoB')
             # plot B->A translation
             summarize_performance(i, g_model_BtoA, trainB, 'BtoA')
         # checkpoint the generators every five epochs
         if (i + 1) % (bat_per_epo * 5) == 0:
             # save the models
             save_models(i, g_model_AtoB, g_model_BtoA)
コード例 #2
0
def train(g_model,
          d_model,
          gan_model,
          dataset,
          latent_dim,
          cat,
          epochs=100,
          batch=64):
    """Run the adversarial training loop with an auxiliary categorical code.

    Each step trains the discriminator on half a batch of real samples and
    half a batch of generated samples, then trains the composite model on a
    full batch of latent points against all-ones labels and the sampled
    category codes. Losses are printed every step and
    ``summarize_performance`` is called every ten epochs.

    Args:
        g_model: generator model.
        d_model: discriminator model.
        gan_model: composite model; ``train_on_batch`` must return three
            values.
        dataset: training data; only ``dataset.shape[0]`` is read here.
        latent_dim: forwarded to the sampling helpers.
        cat: forwarded to the sampling helpers.
        epochs: number of passes over the dataset.
        batch: full batch size; the discriminator sees half of it per call.
    """
    steps_per_epoch = int(dataset.shape[0] / batch)
    total_steps = steps_per_epoch * epochs
    n_half = int(batch / 2)
    # Reporting happens once every ten epochs.
    report_interval = steps_per_epoch * 10

    for step in range(1, total_steps + 1):
        # Discriminator update on real samples.
        real_x, real_y = generate_real_samples(dataset, n_half)
        loss_real = d_model.train_on_batch(real_x, real_y)

        # Discriminator update on generated samples.
        fake_x, fake_y = generate_fake_samples(g_model, latent_dim, cat,
                                               n_half)
        loss_fake = d_model.train_on_batch(fake_x, fake_y)

        # Generator update through the composite model.
        latent, codes = generate_latent_points(latent_dim, cat, batch)
        ones = np.ones((batch, 1))
        _, loss_g, loss_q = gan_model.train_on_batch(latent, [ones, codes])

        print("[INFO] {:d}, d[{:.3f}, {:.3f}], g[{:.3f}], q[{:.3f}]".format(
            step, loss_real, loss_fake, loss_g, loss_q))

        if step % report_interval == 0:
            summarize_performance(step - 1, g_model, gan_model, latent_dim,
                                  cat)
コード例 #3
0
ファイル: train.py プロジェクト: rahhul/GANs
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=10,
          n_batch=64):
    """Train the GAN, log progress every epoch and plot the loss history.

    Per step: draw half a batch of real and half a batch of generated
    samples, update the discriminator on each, then update the generator
    through the composite model on a full batch of latent points labelled
    as real. All three losses are recorded and passed to ``plot_history``
    after the last step.

    Args:
        g_model: generator model.
        d_model: discriminator model.
        gan_model: composite generator-plus-discriminator model.
        dataset: training data; only ``dataset.shape[0]`` is read here.
        latent_dim: forwarded to the sampling helpers.
        n_epochs: number of passes over the dataset.
        n_batch: full batch size.
    """
    steps_per_epoch = int(dataset.shape[0] / n_batch)
    total_steps = steps_per_epoch * n_epochs
    n_half = n_batch // 2
    # loss histories: discriminator-on-real, discriminator-on-fake, generator
    real_losses = []
    fake_losses = []
    gen_losses = []
    for step in range(1, total_steps + 1):
        # sample real data first, then generated data
        real_x, real_y = generate_mnist_samples(dataset, n_half)
        fake_x, fake_y = generate_fake_samples(g_model, latent_dim, n_half)
        # discriminator updates
        loss_real = d_model.train_on_batch(real_x, real_y)
        loss_fake = d_model.train_on_batch(fake_x, fake_y)
        # generator update via the composite model
        latent = generate_latent_points(latent_dim, n_batch)
        ones = np.ones((n_batch, 1))
        loss_gen = gan_model.train_on_batch(latent, ones)
        print(">%d, d1=%.3f, d2=%.3f, g=%.3f" % (step, loss_real, loss_fake, loss_gen))
        real_losses.append(loss_real)
        fake_losses.append(loss_fake)
        gen_losses.append(loss_gen)
        # end of an epoch: log generator output
        if step % (steps_per_epoch * 1) == 0:
            log_performance(step - 1, g_model, latent_dim)
    plot_history(real_losses, fake_losses, gen_losses)
コード例 #4
0
ファイル: train.py プロジェクト: karthiks1995/gan-opt
def train(dataset, generator, discriminator, gan_model, latent_dim=100, n_epochs=20, n_batch=25):
    """Train the GAN epoch by epoch, timing and printing every batch.

    For each batch the discriminator is updated on half a batch of real
    samples and half a batch of generated samples, then the generator is
    updated through ``gan_model`` on a full batch of latent points with
    all-ones labels. Every tenth epoch ``summarize_performance`` is called.

    Args:
        dataset: training data; only ``dataset.shape[0]`` is read here.
        generator: generator model.
        discriminator: discriminator model; ``train_on_batch`` must return
            two values, the first of which is reported as the loss.
        gan_model: composite model used for the generator update.
        latent_dim: forwarded to the sampling helpers.
        n_epochs: number of epochs.
        n_batch: full batch size.
    """
    batches_per_epoch = int(dataset.shape[0] / n_batch)
    n_half = int(n_batch / 2)
    for epoch in range(1, n_epochs + 1):
        for batch_no in range(1, batches_per_epoch + 1):
            tick = time.time()
            # discriminator step on real samples
            real_x, real_y = utils.generate_real_samples(dataset, n_half)
            loss_real, _ = discriminator.train_on_batch(real_x, real_y)
            # discriminator step on generated samples
            fake_x, fake_y = utils.generate_fake_samples(generator, latent_dim, n_half)
            loss_fake, _ = discriminator.train_on_batch(fake_x, fake_y)
            # generator step: latent points labelled as real
            latent = utils.generate_latent_points(latent_dim, n_batch)
            ones = tf.ones((n_batch, 1))
            loss_gen = gan_model.train_on_batch(latent, ones)
            elapsed = time.time() - tick
            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f Time Taken:%.2f seconds' %
                  (epoch, batch_no, batches_per_epoch, loss_real, loss_fake, loss_gen, elapsed))
        # periodic evaluation, once every ten epochs
        if epoch % 10 == 0:
            summarize_performance(epoch - 1, generator, discriminator, dataset, latent_dim)
コード例 #5
0
def train(generator,
          discriminator,
          gan,
          dataset,
          latent_dim,
          epochs=100,
          batch=128):
    """Train the GAN and save the generator to ``models/generator.h5``.

    Each batch updates the discriminator on half a batch of real samples
    and half a batch of generated samples, then updates the generator via
    ``gan`` on a full batch of latent points with all-ones labels.

    Args:
        generator: generator model; saved after the final epoch.
        discriminator: discriminator model; ``train_on_batch`` must return
            two values, the first of which is reported as the loss.
        gan: composite model used for the generator update.
        dataset: training data; only ``dataset.shape[0]`` is read here.
        latent_dim: forwarded to the sampling helpers.
        epochs: number of epochs.
        batch: full batch size.
    """
    batches_per_epoch = int(dataset.shape[0] / batch)
    n_half = int(batch / 2)

    for epoch in range(1, epochs + 1):
        for batch_no in range(1, batches_per_epoch + 1):
            # discriminator step on real samples
            real_x, real_y = generate_real_samples(dataset, n_half)
            loss_real, _ = discriminator.train_on_batch(real_x, real_y)

            # discriminator step on generated samples
            fake_x, fake_y = generate_fake_samples(generator, latent_dim,
                                                   n_half)
            loss_fake, _ = discriminator.train_on_batch(fake_x, fake_y)

            # generator step through the composite model
            latent = generate_latent_points(latent_dim, batch)
            ones = np.ones((batch, 1))
            loss_gen = gan.train_on_batch(latent, ones)

            print("{:d}, {:d}/{:d}, d1={:.3f}, d2={:.3f}, g={:.3f}".format(
                epoch, batch_no, batches_per_epoch, loss_real, loss_fake,
                loss_gen))

    generator.save("models/generator.h5")
コード例 #6
0
ファイル: train.py プロジェクト: karthiks1995/gan-opt
def train_discriminator(model, dataset, n_iter=20, n_batch=25, latent_dim=100):
    """Train ``model`` alone for ``n_iter`` rounds and print its accuracy.

    Each round runs one ``train_on_batch`` call on half a batch of real
    samples and one on half a batch of fake samples, then prints the second
    value returned by each call as a percentage.

    Args:
        model: model being trained; ``train_on_batch`` must return two values.
        dataset: source of real samples.
        n_iter: number of rounds.
        n_batch: full batch size; each call receives half of it.
        latent_dim: forwarded to ``utils.generate_fake_samples``.
    """
    n_half = int(n_batch / 2)
    for round_no in range(1, n_iter + 1):
        real_x, real_y = utils.generate_real_samples(dataset=dataset, num_samples=n_half)
        _, acc_on_real = model.train_on_batch(real_x, real_y)
        # NOTE(review): `model` is passed as the first argument of
        # generate_fake_samples; confirm whether that helper expects a
        # generator there rather than the model being trained.
        fake_x, fake_y = utils.generate_fake_samples(model, latent_dim=latent_dim, num_samples=n_half)
        _, acc_on_fake = model.train_on_batch(fake_x, fake_y)
        print('>%d real=%.0f%% fake=%.0f%%' % (round_no, acc_on_real * 100, acc_on_fake * 100))
コード例 #7
0
ファイル: train.py プロジェクト: karthiks1995/gan-opt
def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples=150):
    """Report discriminator accuracy, save a sample plot and the generator.

    Evaluates ``d_model`` on ``n_samples`` real and ``n_samples`` generated
    samples, prints both accuracies as percentages, passes the generated
    samples to ``save_plot`` and saves ``g_model`` under a file name that
    contains ``epoch + 1``.

    Args:
        epoch: zero-based epoch index.
        g_model: generator model.
        d_model: discriminator model; ``evaluate`` must return two values.
        dataset: source of real samples.
        latent_dim: forwarded to ``utils.generate_fake_samples``.
        n_samples: number of samples drawn for each evaluation.
    """
    # accuracy on real samples
    real_x, real_y = utils.generate_real_samples(dataset, n_samples)
    _, real_accuracy = d_model.evaluate(real_x, real_y, verbose=0)
    # accuracy on generated samples
    fake_x, fake_y = utils.generate_fake_samples(g_model, latent_dim, n_samples)
    _, fake_accuracy = d_model.evaluate(fake_x, fake_y, verbose=0)
    real_pct = real_accuracy * 100
    fake_pct = fake_accuracy * 100
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (real_pct, fake_pct))
    # persist a plot of the generated samples and the generator itself
    save_plot(fake_x, epoch)
    g_model.save('generator_model_%03d.h5' % (epoch + 1))
コード例 #8
0
ファイル: train.py プロジェクト: rahhul/GANs
def log_performance(step, g_model, latent_dim, n_samples=100):
    """Save a grid of generated images and a generator checkpoint.

    Generates ``n_samples`` fake images, rescales them with ``(X + 1) / 2``
    and draws up to 100 of them on a 10x10 grid. The figure is written to
    ``plot_at_<step+1>.png`` and the generator to ``model_<step+1>.h5``.

    Args:
        step: zero-based training step; ``step + 1`` is used in file names.
        g_model: generator model to sample from and to save.
        latent_dim: forwarded to ``generate_fake_samples``.
        n_samples: number of images to generate. Fewer than 100 leaves the
            remaining grid cells empty instead of raising IndexError.
    """
    grid_rows, grid_cols = 10, 10
    # generate fake samples for evaluation
    X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
    # rescale pixel values; presumably from [-1, 1] to [0, 1] -- confirm
    # against the generator's output activation
    X = (X + 1) / 2.0
    # never index past the images that were actually generated
    n_plots = min(grid_rows * grid_cols, len(X))
    for i in range(n_plots):
        plt.subplot(grid_rows, grid_cols, i + 1)
        plt.axis('off')
        # show the first channel only
        plt.imshow(X[i, :, :, 0])
    # save plot to disk
    filename = "plot_at_%06d.png" % (step + 1)
    plt.savefig(filename)
    plt.close()
    # save model
    model_name = "model_%06d.h5" % (step + 1)
    g_model.save(model_name)
    print(f"Saved {filename} and {model_name}")
コード例 #9
0
def train(g_model,
          d_model,
          gan_model,
          dataset,
          latent_dim,
          epochs=20,
          batch=64):
    """Train the GAN, summarize every epoch and plot the loss history.

    Per step: sample half a batch of real and half a batch of generated
    data, update the discriminator on each, then update the generator
    through ``gan_model`` on a full batch of latent points with all-ones
    labels. Losses are printed, recorded and finally handed to
    ``plot_history``.

    Args:
        g_model: generator model.
        d_model: discriminator model.
        gan_model: composite model used for the generator update.
        dataset: training data; only ``dataset.shape[0]`` is read here.
        latent_dim: forwarded to the sampling helpers.
        epochs: number of passes over the dataset.
        batch: full batch size.
    """
    steps_per_epoch = int(dataset.shape[0] / batch)
    total_steps = steps_per_epoch * epochs
    n_half = int(batch / 2)

    # loss histories: discriminator-on-real, discriminator-on-fake, generator
    real_losses = []
    fake_losses = []
    gen_losses = []

    for step in range(1, total_steps + 1):
        real_x, real_y = generate_real_samples(dataset, n_half)
        fake_x, fake_y = generate_fake_samples(g_model, latent_dim, n_half)

        # discriminator updates
        loss_real = d_model.train_on_batch(real_x, real_y)
        loss_fake = d_model.train_on_batch(fake_x, fake_y)

        # generator update through the composite model
        latent = generate_latent_points(latent_dim, batch)
        ones = np.ones((batch, 1))
        loss_gen = gan_model.train_on_batch(latent, ones)

        print("{:d}, d1={:.3f}, d2={:.3f}, g={:.3f}".format(
            step, loss_real, loss_fake, loss_gen))

        real_losses.append(loss_real)
        fake_losses.append(loss_fake)
        gen_losses.append(loss_gen)

        # end of an epoch
        if step % (steps_per_epoch * 1) == 0:
            summarize_performance(step - 1, g_model, latent_dim)

    plot_history(real_losses, fake_losses, gen_losses)
コード例 #10
0
def train(args):
    """Train a conditional time-series GAN and periodically save results.

    Reads one column of a CSV file, pre-processes it, slices it into moving
    windows and trains a Generator/Discriminator pair with either a
    least-squares ('l2') or a WGAN-GP ('wgangp') objective. Every
    ``args.save_every`` epochs the generator weights, a sample plot and a
    CSV of generated series are written.

    Args:
        args: namespace providing gpu_num, seed, samples_path, plots_path,
            weights_path, csv_path, data_path, column, delta, ts_dim,
            latent_dim, conditional_dim, criterion, lr, num_epochs,
            batch_size, lambda_gp and save_every (further attributes may be
            read by the helpers that receive ``args``).

    Raises:
        NotImplementedError: if ``args.criterion`` is neither 'l2' nor
            'wgangp'.
    """

    # Device Configuration #
    device = torch.device(
        f'cuda:{args.gpu_num}' if torch.cuda.is_available() else 'cpu')

    # Fix Seed for Reproducibility #
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Samples, Plots, Weights and CSV Path #
    paths = [
        args.samples_path, args.plots_path, args.weights_path, args.csv_path
    ]
    for path in paths:
        make_dirs(path)

    # Prepare Data #
    data = pd.read_csv(args.data_path)[args.column]

    # Pre-processing #
    scaler_1 = StandardScaler()
    scaler_2 = StandardScaler()
    preprocessed_data = pre_processing(data, scaler_1, scaler_2, args.delta)

    X = moving_windows(preprocessed_data, args.ts_dim)
    label = moving_windows(data.to_numpy(), args.ts_dim)

    # Prepare Networks #
    D = Discriminator(args.ts_dim).to(device)
    G = Generator(args.latent_dim, args.ts_dim,
                  args.conditional_dim).to(device)

    # Loss Function and Number of Critic Updates per Generator Update #
    if args.criterion == 'l2':
        criterion = nn.MSELoss()
        n_critics = 1
    elif args.criterion == 'wgangp':
        # WGAN-GP losses are computed inline; no criterion object is needed.
        criterion = None
        n_critics = 5
    else:
        raise NotImplementedError

    # Optimizers #
    D_optim = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(0.5, 0.9))
    G_optim = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(0.5, 0.9))

    D_optim_scheduler = get_lr_scheduler(D_optim, args)
    G_optim_scheduler = get_lr_scheduler(G_optim, args)

    # Lists #
    D_losses, G_losses = list(), list()

    # Train #
    print("Training Time Series GAN started with total epoch of {}.".format(
        args.num_epochs))

    for epoch in range(args.num_epochs):

        #######################
        # Train Discriminator #
        #######################

        for _ in range(n_critics):
            # Reset critic gradients on every iteration; otherwise the
            # gradients of earlier critic steps (and of the previous
            # generator step) would be added to this update.
            D_optim.zero_grad()

            series, start_dates = get_samples(X, label, args.batch_size)

            # Data Preparation #
            series = series.to(device)
            noise = torch.randn(args.batch_size, 1, args.latent_dim).to(device)

            # Adversarial Loss using Real Image #
            prob_real = D(series.float())

            if args.criterion == 'l2':
                real_labels = torch.ones(prob_real.size()).to(device)
                D_real_loss = criterion(prob_real, real_labels)

            elif args.criterion == 'wgangp':
                D_real_loss = -torch.mean(prob_real)

            # Adversarial Loss using Fake Image #
            # The generated part is appended to the first conditional_dim
            # values of the real series along dim 2.
            fake_series = G(noise)
            fake_series = torch.cat(
                (series[:, :, :args.conditional_dim].float(),
                 fake_series.float()),
                dim=2)

            # detach() keeps the critic loss from back-propagating into G.
            prob_fake = D(fake_series.detach())

            if args.criterion == 'l2':
                fake_labels = torch.zeros(prob_fake.size()).to(device)
                D_fake_loss = criterion(prob_fake, fake_labels)

            elif args.criterion == 'wgangp':
                D_fake_loss = torch.mean(prob_fake)
                # lambda_gp is applied exactly once, here.
                D_gp_loss = args.lambda_gp * get_gradient_penalty(
                    D, series.float(), fake_series.float(), device)

            # Calculate Total Discriminator Loss #
            D_loss = D_fake_loss + D_real_loss

            if args.criterion == 'wgangp':
                D_loss += D_gp_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

        ###################
        # Train Generator #
        ###################

        # Clear anything that reached G's parameters during the critic phase.
        G_optim.zero_grad()

        # Adversarial Loss #
        # Reuses the noise and the conditioning series of the last critic
        # iteration.
        fake_series = G(noise)
        fake_series = torch.cat(
            (series[:, :, :args.conditional_dim].float(), fake_series.float()),
            dim=2)
        prob_fake = D(fake_series)

        # Calculate Total Generator Loss #
        if args.criterion == 'l2':
            real_labels = torch.ones(prob_fake.size()).to(device)
            G_loss = criterion(prob_fake, real_labels)

        elif args.criterion == 'wgangp':
            G_loss = -torch.mean(prob_fake)

        # Back Propagation and Update #
        G_loss.backward()
        G_optim.step()

        # Add items to Lists #
        D_losses.append(D_loss.item())
        G_losses.append(G_loss.item())

        ####################
        # Print Statistics #
        ####################

        # The printed values are running averages over all epochs so far.
        print("Epochs [{}/{}] | D Loss {:.4f} | G Loss {:.4f}".format(
            epoch + 1, args.num_epochs, np.average(D_losses),
            np.average(G_losses)))

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights and Series #
        if (epoch + 1) % args.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(
                    args.weights_path,
                    'TimeSeries_Generator_using{}_Epoch_{}.pkl'.format(
                        args.criterion.upper(), epoch + 1)))

            series, fake_series = generate_fake_samples(
                X, label, G, scaler_1, scaler_2, args, device)
            plot_sample(series, fake_series, epoch, args)
            make_csv(series, fake_series, epoch, args)

    print("Training finished.")
コード例 #11
0
def main(args):
    """Train a time-series GAN or run inference with a trained generator.

    Reads one column of a CSV file, pre-processes it, splits it into train
    and test parts and slices both into moving windows. With
    ``args.mode == 'train'`` a generator/discriminator pair is trained with
    a least-squares ('l2') or WGAN-GP ('wgangp') objective, and weights,
    plots and CSVs are written every ``args.log_every`` epochs. With
    ``args.mode == 'test'`` the generator weights of epoch
    ``args.num_epochs`` are loaded, series are generated window by window,
    and plots, a CSV and metrics are produced.

    Args:
        args: namespace providing gpu_num, seed, samples_path, weights_path,
            csv_path, inference_path, data_path, column, constant, delta,
            ts_dim, latent_dim, model, mode, criterion, optim, lr,
            num_epochs, batch_size, val_batch_size, lambda_gp and log_every
            (further attributes may be read by the helpers that receive
            ``args``).

    Raises:
        NotImplementedError: if ``args.model``, ``args.mode``,
            ``args.criterion`` or ``args.optim`` has an unsupported value.
    """

    # Device Configuration #
    device = torch.device(
        f'cuda:{args.gpu_num}' if torch.cuda.is_available() else 'cpu')

    # Fix Seed for Reproducibility #
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Samples, Plots, Weights and CSV Path #
    paths = [
        args.samples_path, args.weights_path, args.csv_path,
        args.inference_path
    ]
    for path in paths:
        make_dirs(path)

    # Prepare Data #
    data = pd.read_csv(args.data_path)[args.column]

    # Pre-processing #
    scaler_1 = StandardScaler()
    scaler_2 = StandardScaler()
    preprocessed_data = pre_processing(data, scaler_1, scaler_2, args.constant,
                                       args.delta)

    train_X, train_Y, test_X, test_Y = prepare_data(data, preprocessed_data,
                                                    args)

    train_X = moving_windows(train_X, args.ts_dim)
    train_Y = moving_windows(train_Y, args.ts_dim)

    test_X = moving_windows(test_X, args.ts_dim)
    test_Y = moving_windows(test_Y, args.ts_dim)

    # Prepare Networks #
    if args.model == 'conv':
        D = ConvDiscriminator(args.ts_dim).to(device)
        G = ConvGenerator(args.latent_dim, args.ts_dim).to(device)

    elif args.model == 'lstm':
        D = LSTMDiscriminator(args.ts_dim).to(device)
        G = LSTMGenerator(args.latent_dim, args.ts_dim).to(device)

    else:
        raise NotImplementedError

    #########
    # Train #
    #########

    if args.mode == 'train':

        # Loss Function and Number of Critic Updates per Generator Update #
        if args.criterion == 'l2':
            criterion = nn.MSELoss()
            n_critics = 1

        elif args.criterion == 'wgangp':
            # WGAN-GP losses are computed inline; no criterion object needed.
            criterion = None
            n_critics = 5

        else:
            raise NotImplementedError

        # Optimizers #
        if args.optim == 'sgd':
            D_optim = torch.optim.SGD(D.parameters(), lr=args.lr, momentum=0.9)
            G_optim = torch.optim.SGD(G.parameters(), lr=args.lr, momentum=0.9)

        elif args.optim == 'adam':
            D_optim = torch.optim.Adam(D.parameters(),
                                       lr=args.lr,
                                       betas=(0., 0.9))
            G_optim = torch.optim.Adam(G.parameters(),
                                       lr=args.lr,
                                       betas=(0., 0.9))

        else:
            raise NotImplementedError

        D_optim_scheduler = get_lr_scheduler(D_optim, args)
        G_optim_scheduler = get_lr_scheduler(G_optim, args)

        # Lists #
        D_losses, G_losses = list(), list()

        # Train #
        print(
            "Training Time Series GAN started with total epoch of {}.".format(
                args.num_epochs))

        for epoch in range(args.num_epochs):

            #######################
            # Train Discriminator #
            #######################

            for _ in range(n_critics):
                # Reset critic gradients on every iteration; otherwise the
                # gradients of earlier critic steps (and of the previous
                # generator step) would be added to this update.
                D_optim.zero_grad()

                series, start_dates = get_samples(train_X, train_Y,
                                                  args.batch_size)

                # Data Preparation #
                series = series.to(device)
                noise = torch.randn(args.batch_size, 1,
                                    args.latent_dim).to(device)

                # Adversarial Loss using Real Image #
                prob_real = D(series.float())

                if args.criterion == 'l2':
                    real_labels = torch.ones(prob_real.size()).to(device)
                    D_real_loss = criterion(prob_real, real_labels)

                elif args.criterion == 'wgangp':
                    D_real_loss = -torch.mean(prob_real)

                # Adversarial Loss using Fake Image #
                fake_series = G(noise)
                # detach() keeps the critic loss from back-propagating
                # into G.
                prob_fake = D(fake_series.detach())

                if args.criterion == 'l2':
                    fake_labels = torch.zeros(prob_fake.size()).to(device)
                    D_fake_loss = criterion(prob_fake, fake_labels)

                elif args.criterion == 'wgangp':
                    D_fake_loss = torch.mean(prob_fake)
                    # lambda_gp is applied exactly once, here.
                    D_gp_loss = args.lambda_gp * get_gradient_penalty(
                        D, series.float(), fake_series.float(), device)

                # Calculate Total Discriminator Loss #
                D_loss = D_fake_loss + D_real_loss

                if args.criterion == 'wgangp':
                    D_loss += D_gp_loss

                # Back Propagation and Update #
                D_loss.backward()
                D_optim.step()

            ###################
            # Train Generator #
            ###################

            # Clear anything that reached G's parameters during the critic
            # phase.
            G_optim.zero_grad()

            # Adversarial Loss #
            # Reuses the noise of the last critic iteration.
            fake_series = G(noise)
            prob_fake = D(fake_series)

            # Calculate Total Generator Loss #
            if args.criterion == 'l2':
                real_labels = torch.ones(prob_fake.size()).to(device)
                G_loss = criterion(prob_fake, real_labels)

            elif args.criterion == 'wgangp':
                G_loss = -torch.mean(prob_fake)

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            # Adjust Learning Rate #
            D_optim_scheduler.step()
            G_optim_scheduler.step()

            # Print Statistics, Save Model Weights and Series #
            if (epoch + 1) % args.log_every == 0:

                # Print Statistics and Save Model #
                # The printed values are running averages over all epochs
                # so far.
                print("Epochs [{}/{}] | D Loss {:.4f} | G Loss {:.4f}".format(
                    epoch + 1, args.num_epochs, np.average(D_losses),
                    np.average(G_losses)))
                torch.save(
                    G.state_dict(),
                    os.path.join(
                        args.weights_path,
                        'TS_using{}_and_{}_Epoch_{}.pkl'.format(
                            G.__class__.__name__, args.criterion.upper(),
                            epoch + 1)))

                # Generate Samples and Save Plots and CSVs #
                series, fake_series = generate_fake_samples(
                    test_X, test_Y, G, scaler_1, scaler_2, args, device)
                plot_series(series, fake_series, G, epoch, args,
                            args.samples_path)
                make_csv(series, fake_series, G, epoch, args, args.csv_path)

    ########
    # Test #
    ########

    elif args.mode == 'test':

        # Load Model Weights #
        # map_location lets weights saved on a GPU be loaded on whichever
        # device was selected above (for example a CPU-only machine).
        G.load_state_dict(
            torch.load(
                os.path.join(
                    args.weights_path, 'TS_using{}_and_{}_Epoch_{}.pkl'.format(
                        G.__class__.__name__, args.criterion.upper(),
                        args.num_epochs)),
                map_location=device))

        # Lists #
        real, fake = list(), list()

        # Inference #
        # Step through the test windows ts_dim rows at a time.
        for idx in range(0, test_X.shape[0], args.ts_dim):

            # Do not plot if the remaining data is less than time dimension #
            end_ix = idx + args.ts_dim

            if end_ix > len(test_X) - 1:
                break

            # Prepare Data #
            test_data = test_X[idx, :]
            test_data = np.expand_dims(test_data, axis=0)
            test_data = np.expand_dims(test_data, axis=1)
            test_data = torch.from_numpy(test_data).to(device)

            # First value of the matching label window; passed to
            # post_processing as the starting point.
            start = test_Y[idx, 0]

            noise = torch.randn(args.val_batch_size, 1,
                                args.latent_dim).to(device)

            # Generate Fake Data #
            with torch.no_grad():
                fake_series = G(noise)

            # Convert to Numpy format for Saving #
            test_data = np.squeeze(test_data.cpu().data.numpy())
            fake_series = np.squeeze(fake_series.cpu().data.numpy())

            test_data = post_processing(test_data, start, scaler_1, scaler_2,
                                        args.delta)
            fake_series = post_processing(fake_series, start, scaler_1,
                                          scaler_2, args.delta)

            real += test_data.tolist()
            fake += fake_series.tolist()

        # Plot, Save to CSV file and Derive Metrics #
        plot_series(real, fake, G, args.num_epochs - 1, args,
                    args.inference_path)
        make_csv(real, fake, G, args.num_epochs - 1, args, args.inference_path)
        derive_metrics(real, fake, args)

    else:
        raise NotImplementedError