Exemplo n.º 1
0
def compute_frechet_inception_distance(z, y_fake, x_fake, x, y, args, di=None):
    """Estimate the Frechet Inception Distance between fake and real images.

    For ``args.max_iter`` batches, a batch of images is generated from random
    latents and random class labels, and a batch of real images is drawn from
    ``di``.  Both are passed through ``preprocess`` and fed to the feature
    network (input ``x``, output ``y``).  The score is computed from the mean
    and covariance of the collected features.

    Args:
        z: Latent input variable of the generator; its ``.d`` is overwritten.
        y_fake: Class-label input variable of the generator.
        x_fake: Generator output variable.
        x: Input variable of the feature network.
        y: Output variable of the feature network.
        args: Object providing ``max_iter``, ``batch_size``, ``latent``,
            ``n_classes``, ``image_size`` and ``nnp_preprocess``.
        di: Data iterator whose ``next()`` returns ``(images, labels)``.
            Required despite the ``None`` default kept for compatibility.

    Returns:
        The FID score as a real-valued numpy scalar.

    Raises:
        ValueError: If ``di`` is not given.
    """
    # Fail before any forward pass instead of crashing on `None.next()`.
    if di is None:
        raise ValueError(
            "compute_frechet_inception_distance requires a data iterator `di`")

    def extract_features(images):
        # Run the feature network on one preprocessed batch.
        x.d = preprocess(
            images, (args.image_size, args.image_size), args.nnp_preprocess)
        y.forward(clear_buffer=True)
        h = y.d.copy()
        # Flatten per sample while keeping the batch axis.  A bare squeeze()
        # would also drop the batch axis when batch_size == 1 and corrupt the
        # concatenation below.
        return h.reshape(h.shape[0], -1)

    h_fakes = []
    h_reals = []
    for i in range(args.max_iter):
        logger.info("Compute at {}-th batch".format(i))
        # Generate
        z.d = np.random.randn(args.batch_size, args.latent)
        y_fake.d = generate_random_class(args.n_classes, args.batch_size)
        x_fake.forward(clear_buffer=True)
        # Predict for fake
        h_fakes.append(extract_features(x_fake.d.copy()))
        # Predict for real
        x_d, _ = di.next()
        h_reals.append(extract_features(x_d))
    h_fakes = np.concatenate(h_fakes)
    h_reals = np.concatenate(h_reals)

    # FID score
    ave_h_real = np.mean(h_reals, axis=0)
    ave_h_fake = np.mean(h_fakes, axis=0)
    cov_h_real = np.cov(h_reals, rowvar=False)
    cov_h_fake = np.cov(h_fakes, rowvar=False)
    cov_sqrt = sqrtm(np.dot(cov_h_real, cov_h_fake))
    # sqrtm may return a complex matrix whose imaginary part is pure
    # round-off; keep the real part so the score stays real-valued.
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real
    score = np.sum((ave_h_real - ave_h_fake) ** 2) \
        + np.trace(cov_h_real + cov_h_fake - 2.0 * cov_sqrt)
    return score
Exemplo n.º 2
0
def compute_inception_score(z, y_fake, x_fake, x, y, args):
    """Estimate the Inception Score of generated images.

    For ``args.max_iter`` batches, images are generated from random latents
    and random class labels, passed through ``preprocess`` and fed to the
    classifier (input ``x``, output ``y``).  ``y.d`` is treated as per-sample
    class probabilities p(y|x).  The score is
    ``exp(mean_x KL(p(y|x) || p(y)))``, with p(y) the mean of p(y|x) over all
    samples.

    Args:
        z: Latent input variable of the generator; its ``.d`` is overwritten.
        y_fake: Class-label input variable of the generator.
        x_fake: Generator output variable.
        x: Input variable of the classifier.
        y: Output variable of the classifier.
        args: Object providing ``max_iter``, ``batch_size``, ``latent``,
            ``n_classes``, ``image_size`` and ``nnp_preprocess``.

    Returns:
        The inception score as a numpy scalar.
    """
    preds = []
    for i in range(args.max_iter):
        logger.info("Compute at {}-th batch".format(i))
        # Generate
        z.d = np.random.randn(args.batch_size, args.latent)
        y_fake.d = generate_random_class(args.n_classes, args.batch_size)
        x_fake.forward(clear_buffer=True)
        # Predict
        x_fake_d = x_fake.d.copy()
        x_fake_d = preprocess(
            x_fake_d, (args.image_size, args.image_size), args.nnp_preprocess)
        x.d = x_fake_d
        y.forward(clear_buffer=True)
        preds.append(y.d.copy())
    p_yx = np.concatenate(preds)
    # Score
    p_y = np.mean(p_yx, axis=0)
    # Probabilities that underflow to exactly 0 would give 0 * log(0) = nan
    # and poison the whole score.  Clamping only inside the log keeps the
    # multiplier at 0, so those terms contribute exactly 0.
    tiny = np.finfo(np.float32).tiny
    log_ratio = np.log(np.maximum(p_yx, tiny)) - np.log(np.maximum(p_y, tiny))
    kld = np.sum(p_yx * log_ratio, axis=1)
    score = np.exp(np.mean(kld))
    return score
Exemplo n.º 3
0
def generate(args):
    """Generate images with a trained generator and hand them to monitors.

    Loads the parameters at ``args.model_load_path``, builds the generator in
    test mode and feeds it latents from ``resample`` (controlled by
    ``args.truncation_threshold``).  When ``args.generate_all`` is set, one
    image tile is added per class id in ``range(args.n_classes)``.  Otherwise
    a single batch is added both as a tile and as individual images, using
    class ``args.class_id``, or random classes when it is ``-1``.
    """
    # Communicator and Context
    ctx = get_extension_context("cudnn", type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent, maps = args.latent, args.maps
    batch_size, n_classes = args.batch_size, args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold

    # Model
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       test=True, sn=not_sn)
    x_fake = x_fake.apply(persistent=True)

    def synthesize(latents, labels):
        # Feed one batch of inputs and run the generator forward pass.
        z.d = latents
        y_fake.d = labels
        x_fake.forward(clear_buffer=True)

    # Options shared by every image monitor created below.
    monitor = Monitor(args.monitor_path)
    monitor_options = dict(interval=1,
                           num_images=args.batch_size,
                           normalize_method=normalize_method)

    # Generate All
    if args.generate_all:
        tile_all = MonitorImageTile("Generated Image Tile All", monitor,
                                    **monitor_options)
        # One tile per class
        for class_id in range(args.n_classes):
            latents = resample(batch_size, latent, threshold)
            labels = generate_one_class(class_id, batch_size)
            synthesize(latents, labels)
            tile_all.add(class_id, x_fake.d)
        return

    # Generate individually
    random_classes = args.class_id == -1
    if random_classes:
        tile_name = "Generated Image Tile"
        image_name = "Generated Image"
    else:
        tile_name = "Generated Image Tile {}".format(args.class_id)
        image_name = "Generated Image {}".format(args.class_id)
    monitor_image_tile = MonitorImageTile(tile_name, monitor, **monitor_options)
    monitor_image = MonitorImage(image_name, monitor, **monitor_options)

    latents = resample(batch_size, latent, threshold)
    if random_classes:
        labels = generate_random_class(n_classes, batch_size)
    else:
        labels = generate_one_class(args.class_id, batch_size)
    synthesize(latents, labels)
    monitor_image.add(0, x_fake.d)
    monitor_image_tile.add(0, x_fake.d)
Exemplo n.º 4
0
def train(args):
    """Run multi-process data-parallel GAN training.

    Builds the generator and discriminator graphs with their losses, then for
    ``args.max_iter`` iterations alternates one discriminator update and one
    generator update.  Each update accumulates gradients over
    ``args.accum_grad`` forward/backward passes and all-reduces them across
    devices before the solver step.  Rank 0 additionally records losses,
    elapsed time, image tiles and parameter snapshots under
    ``args.monitor_path``.

    Args:
        args: Object exposing the attributes read here: ``type_config``,
            ``latent``, ``maps``, ``batch_size``, ``image_size``,
            ``n_classes``, ``not_sn``, ``lrg``, ``lrd``, ``beta1``,
            ``beta2``, ``monitor_path``, ``train_dir``,
            ``dirname_to_label_path``, ``max_iter``, ``accum_grad``,
            ``sync_weight_every_itr`` and ``save_interval``.
    """
    # Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank  # not used below in this function
    device_id = comm.local_rank
    # Use the device whose id equals this process's local rank.
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Model
    # workaround to start with the same weights in the distributed system.
    np.random.seed(412)
    # generator loss
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    # x_fake.d is read later for the train image tile; presumably
    # persistent=True keeps its buffer alive across clear_buffer passes.
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       sn=not_sn).apply(persistent=True)
    p_fake = discriminator(x_fake, y_fake, maps=maps //
                           16, n_classes=n_classes, sn=not_sn)
    loss_gen = gan_loss(p_fake)
    # discriminator loss
    y_real = nn.Variable([batch_size])
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, y_real, maps=maps //
                           16, n_classes=n_classes, sn=not_sn)
    loss_dis = gan_loss(p_fake, p_real)
    # generator with fixed value for test
    # (latents and labels are drawn once here and reused at every snapshot)
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    y_test = nn.Variable.from_numpy_array(
        generate_random_class(n_classes, batch_size))
    x_test = generator(z_test, y_test, maps=maps,
                       n_classes=n_classes, test=True, sn=not_sn)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)
    # Each solver only receives the parameters under its own scope.
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    # Monitors exist only on rank 0; every later use is guarded the same way.
    if comm.rank == 0:
        monitor = Monitor(args.monitor_path)
        monitor_loss_gen = MonitorSeries(
            "Generator Loss", monitor, interval=10)
        monitor_loss_dis = MonitorSeries(
            "Discriminator Loss", monitor, interval=10)
        monitor_time = MonitorTimeElapsed(
            "Training Time", monitor, interval=10)
        monitor_image_tile_train = MonitorImageTile("Image Tile Train", monitor,
                                                    num_images=args.batch_size,
                                                    interval=1,
                                                    normalize_method=normalize_method)
        monitor_image_tile_test = MonitorImageTile("Image Tile Test", monitor,
                                                   num_images=args.batch_size,
                                                   interval=1,
                                                   normalize_method=normalize_method)
    # DataIterator
    # Seeded with the local device id, so each local rank gets its own RNG.
    rng = np.random.RandomState(device_id)
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                args.batch_size, n_classes=args.n_classes,
                                rng=rng)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need for discriminator backward
        solver_dis.zero_grad()
        for _ in range(args.accum_grad):
            # feed x_real and y_real
            x_data, y_data = di.next()
            x_real.d, y_real.d = x_data, y_data.flatten()
            # feed z and y_fake
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_dis.forward(clear_no_need_grad=True)
            # Scale by accumulation steps and device count; presumably the
            # all_reduce below sums over devices, giving an averaged gradient.
            loss_dis.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_dis.values()])
        solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need for generator backward
        solver_gen.zero_grad()
        for _ in range(args.accum_grad):
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_gen.forward(clear_no_need_grad=True)
            # Same scaling as in the discriminator step above.
            loss_gen.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_gen.values()])
        solver_gen.update()

        # Synchronize by averaging the weights over devices using allreduce
        if i % args.sync_weight_every_itr == 0:
            weights = [v.data for v in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        # Save model and image
        if i % args.save_interval == 0 and comm.rank == 0:
            x_test.forward(clear_buffer=True)
            nn.save_parameters(os.path.join(
                args.monitor_path, "params_{}.h5".format(i)))
            monitor_image_tile_train.add(i, x_fake.d)
            monitor_image_tile_test.add(i, x_test.d)

        # Monitor
        if comm.rank == 0:
            monitor_loss_gen.add(i, loss_gen.d.copy())
            monitor_loss_dis.add(i, loss_dis.d.copy())
            monitor_time.add(i)

    # Final snapshot on rank 0.
    # NOTE(review): `i` is the last loop index, so it is unbound here when
    # args.max_iter is 0.
    if comm.rank == 0:
        x_test.forward(clear_buffer=True)
        nn.save_parameters(os.path.join(
            args.monitor_path, "params_{}.h5".format(i)))
        monitor_image_tile_train.add(i, x_fake.d)
        monitor_image_tile_test.add(i, x_test.d)