Ejemplo n.º 1
0
def generate(args):
    """Write ``args.num_generation`` image tiles sampled from random latents.

    Generator weights are loaded from ``args.model_load_path`` and the
    generator graph is built with ``test=True``. Each iteration draws a new
    standard-normal latent batch, runs the generator and records the output
    as an image tile under ``args.monitor_path``.
    """
    # Backend / device / precision chosen on the command line.
    context = get_extension_context(args.context,
                                    device_id=args.device_id,
                                    type_config=args.type_config)
    nn.set_default_context(context)

    n_latent, n_maps, n_batch = args.latent, args.maps, args.batch_size

    # Load trained weights, then build the generator graph in test mode.
    nn.load_parameters(args.model_load_path)
    latents = nn.Variable([n_batch, n_latent])
    fakes = generator(latents, maps=n_maps, test=True, up=args.up)

    monitor = Monitor(args.monitor_path)
    tile_monitor = MonitorImageTile("Image Tile Generated",
                                    monitor,
                                    num_images=n_batch,
                                    interval=1,
                                    normalize_method=denormalize)

    for step in range(args.num_generation):
        latents.d = np.random.randn(n_batch, n_latent)
        fakes.forward(clear_buffer=True)
        tile_monitor.add(step, fakes)
Ejemplo n.º 2
0
def generate(args):
    """Save one image tile from the generator and one from its EMA copy.

    Parameters for the ``"Generator"`` and ``"Generator_EMA"`` scopes are
    loaded from the fixed 100000-iteration checkpoints inside
    ``args.model_load_path``. A single standard-normal latent batch is fed to
    both networks and each output is written as a tile.
    """
    nn.set_default_context(get_extension_context(args.context,
                                                 device_id=args.device_id,
                                                 type_config=args.type_config))

    scope_gen = "Generator"
    scope_gen_ema = "Generator_EMA"
    checkpoints = [
        (scope_gen, args.model_load_path + '/Gen_iter100000.h5'),
        (scope_gen_ema, args.model_load_path + '/GenEMA_iter100000.h5'),
    ]
    for scope, param_path in checkpoints:
        with nn.parameter_scope(scope):
            nn.load_parameters(param_path)

    def to_unit_range(x):
        # Affine map (x + 1) / 2, i.e. [-1, 1] -> [0, 1].
        return (x + 1.) / 2.

    monitor = Monitor(args.monitor_path)
    tile_plain = MonitorImageTile("Image Tile",
                                  monitor,
                                  num_images=args.batch_size,
                                  interval=1,
                                  normalize_method=to_unit_range)
    tile_ema = MonitorImageTile("Image Tile with EMA",
                                monitor,
                                num_images=args.batch_size,
                                interval=1,
                                normalize_method=to_unit_range)

    latent_shape = (args.batch_size, args.latent, 1, 1)
    z = nn.Variable(list(latent_shape))
    # train=True is passed deliberately even though this is inference.
    fake, fake_ema = [
        Generator(z, scope_name=scope, train=True,
                  img_size=args.image_size)[0]
        for scope in (scope_gen, scope_gen_ema)
    ]
    z.d = np.random.randn(*latent_shape)

    fake.forward(clear_buffer=True)
    fake_ema.forward(clear_buffer=True)
    tile_plain.add(0, fake)
    tile_ema.add(0, fake_ema)
Ejemplo n.º 3
0
def train(data_iterator, monitor, config, comm, args):
    """Train a VQ-VAE, or a Gated PixelCNN prior over it, or sample from it.

    Args:
        data_iterator: callable ``data_iterator(config, comm, train=...)``
            returning a data loader.
        monitor: monitor handed to the rank-0 ``MonitorSeries`` /
            ``MonitorImageTile`` loggers.
        config: nested configuration dict; this function reads the
            ``'train'``, ``'val'``, ``'monitor'``, ``'model'``, ``'dataset'``
            and ``'prior'`` sections.
        comm: communicator; only ``comm.rank == 0`` logs and saves.
        args: parsed flags ``sample_from_pixelcnn``, ``pixelcnn_prior``,
            ``load_checkpoint`` and ``sample_save_path``.

    When ``args.sample_from_pixelcnn`` is set, no solver or data loader is
    built. A checkpoint is loaded when present, samples are generated, and
    the function returns without training.
    """
    train_cfg = config['train']

    monitor_train_loss, monitor_train_recon = None, None
    monitor_val_loss, monitor_val_recon = None, None
    if comm.rank == 0:
        log_interval = train_cfg['logger_step_interval']
        monitor_train_loss = MonitorSeries(
            config['monitor']['train_loss'], monitor, interval=log_interval)
        monitor_train_recon = MonitorImageTile(
            config['monitor']['train_recon'], monitor, interval=log_interval,
            num_images=train_cfg['batch_size'])

        monitor_val_loss = MonitorSeries(
            config['monitor']['val_loss'], monitor, interval=log_interval)
        monitor_val_recon = MonitorImageTile(
            config['monitor']['val_recon'], monitor, interval=log_interval,
            num_images=train_cfg['batch_size'])

    model = VQVAE(config)

    if args.sample_from_pixelcnn:
        # Sampling needs neither an optimizer nor data.
        solver, train_loader, val_loader = None, None, None
    else:
        if train_cfg['solver'] == 'adam':
            solver = S.Adam()
        else:
            # nnabla.solvers exposes the capitalized ``Momentum`` factory;
            # ``S.momentum`` raised AttributeError.
            solver = S.Momentum()
        solver.set_learning_rate(train_cfg['learning_rate'])

        train_loader = data_iterator(config, comm, train=True)
        # No validation loader is built for the ImageNet dataset.
        if config['dataset']['name'] != 'imagenet':
            val_loader = data_iterator(config, comm, train=False)
        else:
            val_loader = None

    if args.pixelcnn_prior:
        pixelcnn_model = GatedPixelCNN(config['prior'])
        trainer = TrainerPrior(model, pixelcnn_model, solver, train_loader, val_loader, monitor_train_loss,
                               monitor_train_recon, monitor_val_loss, monitor_val_recon, config, comm,
                               eval=args.sample_from_pixelcnn)
        num_epochs = config['prior']['train']['num_epochs']
    else:
        trainer = VQVAEtrainer(model, solver, train_loader, val_loader, monitor_train_loss,
                               monitor_train_recon, monitor_val_loss, monitor_val_recon, config, comm)
        num_epochs = train_cfg['num_epochs']

    # NOTE(review): existence is checked on the VQ-VAE checkpoint even when
    # the prior checkpoint is the one being loaded; confirm this is intended.
    if os.path.exists(config['model']['checkpoint']) and (args.load_checkpoint or args.sample_from_pixelcnn):
        checkpoint_path = config['prior']['checkpoint'] if args.pixelcnn_prior else config['model']['checkpoint']
        trainer.load_checkpoint(checkpoint_path, msg='Parameters loaded from {}'.format(
            checkpoint_path), pixelcnn=args.pixelcnn_prior, load_solver=not args.sample_from_pixelcnn)

    if args.sample_from_pixelcnn:
        trainer.random_generate(
            args.sample_from_pixelcnn, args.sample_save_path)
        return

    # The last epoch depends on which model is trained: prior training runs
    # config['prior']['train']['num_epochs'], not config['train']['num_epochs'].
    last_epoch = num_epochs - 1
    for epoch in range(num_epochs):

        trainer.train(epoch)

        if epoch % config['val']['interval'] == 0 and val_loader is not None:
            trainer.validate(epoch)

        if comm.rank == 0:
            if epoch % train_cfg['save_param_step_interval'] == 0 or epoch == last_epoch:
                trainer.save_checkpoint(
                    config['model']['saved_models_dir'], epoch, pixelcnn=args.pixelcnn_prior)
Ejemplo n.º 4
0
def train(args):
    """Train a WGAN-GP generator/critic pair on CIFAR-10.

    Each iteration runs ``args.n_critic`` critic updates and then one
    generator update. The critic loss is ``gan_loss(p_fake, p_real)`` plus
    ``args.lambda_`` times a gradient penalty of ``(||grad|| - 1)^2`` on
    random real/fake interpolates. Losses and time are logged every 10
    iterations. Every ``args.save_interval`` iterations, and once after the
    loop, image tiles and parameters are written to ``args.monitor_path``.
    """
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    lambda_ = args.lambda_

    # Model
    # generator loss
    z = nn.Variable([batch_size, latent])
    x_fake = generator(z, maps=maps, up=args.up).apply(persistent=True)
    p_fake = discriminator(x_fake, maps=maps)
    loss_gen = gan_loss(p_fake).apply(persistent=True)
    # discriminator loss (a separate critic graph on the same x_fake)
    p_fake = discriminator(x_fake, maps=maps)
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, maps=maps)
    loss_dis = gan_loss(p_fake, p_real).apply(persistent=True)
    # gradient penalty on per-sample random mixes of real and fake images
    eps = F.rand(shape=[batch_size, 1, 1, 1])
    x_rmix = eps * x_real + (1.0 - eps) * x_fake
    p_rmix = discriminator(x_rmix, maps=maps)
    x_rmix.need_grad = True  # Enabling gradient computation for double backward
    grads = nn.grad([p_rmix], [x_rmix])
    l2norms = [F.sum(g**2.0, [1, 2, 3])**0.5 for g in grads]
    gp = sum([F.mean((l - 1.0)**2.0) for l in l2norms])
    loss_dis += lambda_ * gp
    # generator with fixed value for test
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    x_test = generator(z_test, maps=maps, test=True,
                       up=args.up).apply(persistent=True)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)

    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_cri = MonitorSeries("Negative Critic Loss",
                                     monitor,
                                     interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile("Image Tile Train",
                                                monitor,
                                                num_images=batch_size,
                                                interval=1,
                                                normalize_method=denormalize)
    monitor_image_tile_test = MonitorImageTile("Image Tile Test",
                                               monitor,
                                               num_images=batch_size,
                                               interval=1,
                                               normalize_method=denormalize)

    # Data Iterator
    di = data_iterator_cifar10(batch_size, True)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need backward to generator
        for _ in range(args.n_critic):
            solver_dis.zero_grad()
            # Affine map of pixel data: [0, 255] -> [-1, 1].
            x_real.d = di.next()[0] / 127.5 - 1.0
            z.d = np.random.randn(batch_size, latent)
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(clear_buffer=True)
            solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need backward to generator
        solver_gen.zero_grad()
        z.d = np.random.randn(batch_size, latent)
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()
        # Monitor
        monitor_loss_gen.add(i, loss_gen.d)
        monitor_loss_cri.add(i, -loss_dis.d)
        monitor_time.add(i)

        # Save
        if i % args.save_interval == 0:
            monitor_image_tile_train.add(i, x_fake)
            # Nothing else runs the test graph inside the loop; without this
            # forward the test tile would show stale data.
            x_test.forward(clear_buffer=True)
            monitor_image_tile_test.add(i, x_test)
            nn.save_parameters(
                os.path.join(args.monitor_path, "params_{}.h5".format(i)))

    # Last
    x_test.forward(clear_buffer=True)
    nn.save_parameters(
        os.path.join(args.monitor_path, "params_{}.h5".format(i)))
    monitor_image_tile_train.add(i, x_fake)
    monitor_image_tile_test.add(i, x_test)
Ejemplo n.º 5
0
def generate(args):
    """Generate class-conditional images from truncated latent samples.

    With ``args.generate_all`` set, one image tile is written for every class
    id in ``range(args.n_classes)`` and the function returns. Otherwise a
    single batch is generated, for ``args.class_id`` or (when it is -1) for
    randomly drawn classes, and saved both as individual images and as a tile.
    """
    # Communicator and Context
    nn.set_default_context(
        get_extension_context("cudnn", type_config=args.type_config))

    # Options (image_size is read but not used by this function).
    latent, maps, batch_size = args.latent, args.maps, args.batch_size
    image_size, n_classes = args.image_size, args.n_classes
    not_sn, threshold = args.not_sn, args.truncation_threshold

    # Generator graph on top of the loaded parameters.
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       test=True, sn=not_sn).apply(persistent=True)

    def sample(make_labels):
        # Latents are drawn before labels so random calls keep their order.
        z_data = resample(batch_size, latent, threshold)
        y_data = make_labels()
        z.d = z_data
        y_fake.d = y_data
        x_fake.forward(clear_buffer=True)
        return x_fake.d

    monitor = Monitor(args.monitor_path)

    if args.generate_all:
        tile_all = MonitorImageTile("Generated Image Tile All",
                                    monitor,
                                    interval=1,
                                    num_images=args.batch_size,
                                    normalize_method=normalize_method)
        for class_id in range(args.n_classes):
            def one_class_labels():
                return generate_one_class(class_id, batch_size)
            tile_all.add(class_id, sample(one_class_labels))
        return

    # Generate Individually
    if args.class_id == -1:
        tile_name = "Generated Image Tile"
        image_name = "Generated Image"

        def make_labels():
            return generate_random_class(n_classes, batch_size)
    else:
        tile_name = "Generated Image Tile {}".format(args.class_id)
        image_name = "Generated Image {}".format(args.class_id)

        def make_labels():
            return generate_one_class(args.class_id, batch_size)

    tile_monitor = MonitorImageTile(tile_name,
                                    monitor,
                                    interval=1,
                                    num_images=args.batch_size,
                                    normalize_method=normalize_method)
    image_monitor = MonitorImage(image_name,
                                 monitor,
                                 interval=1,
                                 num_images=args.batch_size,
                                 normalize_method=normalize_method)
    sample(make_labels)
    image_monitor.add(0, x_fake.d)
    tile_monitor.add(0, x_fake.d)
Ejemplo n.º 6
0
    def __init__(self, monitor, config, args, comm, few_shot_config):
        """Set up monitors, networks, solvers and the data loader, then train.

        NOTE: the constructor ends by calling ``self.train()``, so creating
        an instance runs the whole training procedure.

        Exits with status 1 when ``args.data`` is not ``'ffhq'``.
        """
        super(Train, self).__init__(monitor, config, args, comm,
                                    few_shot_config)

        # Initialize Monitor. Every monitor attribute is pre-set to None so
        # that processes that do not log (comm is None or rank != 0) still
        # have the attributes defined.
        self.monitor_train_loss, self.monitor_train_gen = None, None
        self.monitor_val_loss, self.monitor_val_gen = None, None
        self.monitor_train_gen_loss, self.monitor_train_disc_loss = None, None
        if comm is not None and comm.rank == 0:
            self.monitor_train_gen_loss = MonitorSeries(
                config['monitor']['train_loss'],
                monitor,
                interval=self.config['logger_step_interval'])
            self.monitor_train_gen = MonitorImageTile(
                config['monitor']['train_gen'],
                monitor,
                interval=self.config['logger_step_interval'],
                num_images=self.config['batch_size'])
            # NOTE(review): reuses the 'train_loss' name of the generator
            # loss series, so both presumably write to the same series;
            # confirm whether a separate config key is intended.
            self.monitor_train_disc_loss = MonitorSeries(
                config['monitor']['train_loss'],
                monitor,
                interval=self.config['logger_step_interval'])

        os.makedirs(self.config['saved_weights_dir'], exist_ok=True)
        self.results_dir = args.results_dir
        self.save_weights_dir = args.weights_path

        self.few_shot_config = few_shot_config

        # Initialize Discriminator
        self.discriminator = Discriminator(config['discriminator'],
                                           self.img_size)
        # Presumably the decay weight for the generator EMA below.
        self.gen_exp_weight = 0.5**(32 / (10 * 1000))
        self.generator_ema = Generator(config['generator'],
                                       self.img_size,
                                       config['train']['mix_after'],
                                       global_scope='GeneratorEMA')

        # Initialize Solver (skipped if a generator solver already exists).
        if not hasattr(self, 'gen_solver'):
            if self.config['solver'] == 'Adam':
                self.gen_solver = S.Adam(beta1=0, beta2=0.99)
                self.disc_solver = S.Adam(beta1=0, beta2=0.99)
            else:
                # Look the solver factory up by name instead of eval-ing a
                # configuration string.
                solver_factory = getattr(S, self.config['solver'])
                self.gen_solver = solver_factory()
                self.disc_solver = solver_factory()

        self.gen_solver.set_learning_rate(self.config['learning_rate'])
        self.disc_solver.set_learning_rate(self.config['learning_rate'])

        self.gen_mean_path_length = 0.0

        self.dali = args.dali
        self.args = args
        # Initialize Dataloader (only FFHQ is supported).
        if args.data == 'ffhq':
            if args.dali:
                self.train_loader = get_dali_iterator_ffhq(
                    args.dataset_path, config['data'], self.img_size,
                    self.batch_size, self.comm)
            else:
                self.train_loader = get_data_iterator_ffhq(
                    args.dataset_path, config['data'], self.batch_size,
                    self.img_size, self.comm)
        else:
            print('Dataset not recognized')
            # Same effect as exit(1), without relying on the site module.
            raise SystemExit(1)

        # Start training
        self.train()
Ejemplo n.º 7
0
def match(args):
    """Find the real images closest to one generated image in feature space.

    A single image of class ``args.class_id`` is generated with the SAGAN
    generator. That image and every image the class-filtered iterator yields
    are embedded with the network loaded from
    ``args.nnp_inception_model_load_path`` (output ``args.variable_name``).
    The real images are ranked by squared L2 feature distance to the
    generated one. The ``args.top_n`` nearest are saved individually, and as
    a tile that starts with the generated image.
    """
    # Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = 1  # one generated image is matched at a time
    image_shape = (args.image_size, args.image_size)
    n_classes = args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold

    # Model (SAGAN)
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True, sn=not_sn)\
        .apply(persistent=True)

    # Model (Inception model) from nnp file
    nnp = NnpLoader(args.nnp_inception_model_load_path)
    x, h = get_input_and_output(nnp, batch_size, args.variable_name)

    # DataIterator for a given class_id
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                batch_size=batch_size, n_classes=args.n_classes,
                                noise=False,
                                class_id=args.class_id)

    # Monitor
    def to_pixels(v):
        # Affine map [-1, 1] -> [0, 255].
        return (v + 1.) / 2. * 255.

    monitor = Monitor(args.monitor_path)
    monitor_image = MonitorImage("Matched Image {}".format(args.class_id),
                                 monitor, interval=1,
                                 num_images=batch_size,
                                 normalize_method=to_pixels)
    monitor_image_tile = MonitorImageTile("Matched Image Tile {}".format(args.class_id),
                                          monitor, interval=1,
                                          num_images=batch_size + args.top_n,
                                          normalize_method=to_pixels)

    # Generate one image and compute its features p(h|x).
    z_data = resample(batch_size, latent, threshold)
    y_data = generate_one_class(args.class_id, batch_size)
    z.d = z_data
    y_fake.d = y_data
    x_fake.forward(clear_buffer=True)
    fake_image = x_fake.d.copy()
    # preprocess gets its own copy so fake_image stays untouched.
    x.d = preprocess(fake_image.copy(), image_shape, args.nnp_preprocess)
    h.forward(clear_buffer=True)
    h_fake_d = h.d.copy()

    # Feature matching: squared L2 distance over all non-batch axes.
    axis = tuple(range(1, len(h.shape)))
    real_images = []
    norm2_list = []
    for _ in range(di.size):
        x_d = di.next()[0]
        real_images.append(x_d)
        x.d = preprocess(x_d, image_shape, args.nnp_preprocess)
        h.forward(clear_buffer=True)
        norm2_list.append(np.sum((h.d - h_fake_d) ** 2.0, axis=axis))

    # Rank real images by distance. Flatten the per-batch distance arrays
    # first: argsort on a list of shape-(1,) arrays would sort each row
    # separately and return all zeros.
    real_images = np.concatenate(real_images)
    order = np.argsort(np.concatenate(norm2_list))
    nearest = real_images[order[:args.top_n]]

    # Save top-n images
    for i in range(len(nearest)):
        monitor_image.add(i, nearest[i:i + 1])
    # Tile: the generated image followed by its nearest real matches.
    monitor_image_tile.add(0, np.concatenate([fake_image, nearest]))
Ejemplo n.º 8
0
def train(args):
    """Data-parallel training of a class-conditional (SAGAN-style) GAN.

    One process runs per device. The communicator assigns each process the
    GPU of its local rank. Each iteration does one discriminator update and
    one generator update, each accumulating gradients over
    ``args.accum_grad`` micro-batches and all-reducing them across devices.
    Every ``args.sync_weight_every_itr`` iterations the weights are averaged
    across devices. Rank 0 alone logs losses and time and saves parameters
    and image tiles, every ``args.save_interval`` iterations and once after
    the loop.
    """
    # Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = comm.local_rank
    # Bind this process to the device matching its local rank.
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Model
    # workaround to start with the same weights in the distributed system.
    # (The seed must precede graph construction below; do not reorder.)
    np.random.seed(412)
    # generator loss
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       sn=not_sn).apply(persistent=True)
    p_fake = discriminator(x_fake, y_fake, maps=maps //
                           16, n_classes=n_classes, sn=not_sn)
    loss_gen = gan_loss(p_fake)
    # discriminator loss
    y_real = nn.Variable([batch_size])
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, y_real, maps=maps //
                           16, n_classes=n_classes, sn=not_sn)
    loss_dis = gan_loss(p_fake, p_real)
    # generator with fixed value for test
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    y_test = nn.Variable.from_numpy_array(
        generate_random_class(n_classes, batch_size))
    x_test = generator(z_test, y_test, maps=maps,
                       n_classes=n_classes, test=True, sn=not_sn)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor (rank 0 only; the monitor names below exist only on rank 0
    # and are used only under `comm.rank == 0` checks).
    if comm.rank == 0:
        monitor = Monitor(args.monitor_path)
        monitor_loss_gen = MonitorSeries(
            "Generator Loss", monitor, interval=10)
        monitor_loss_dis = MonitorSeries(
            "Discriminator Loss", monitor, interval=10)
        monitor_time = MonitorTimeElapsed(
            "Training Time", monitor, interval=10)
        monitor_image_tile_train = MonitorImageTile("Image Tile Train", monitor,
                                                    num_images=args.batch_size,
                                                    interval=1,
                                                    normalize_method=normalize_method)
        monitor_image_tile_test = MonitorImageTile("Image Tile Test", monitor,
                                                   num_images=args.batch_size,
                                                   interval=1,
                                                   normalize_method=normalize_method)
    # DataIterator
    # RNG seeded by device id, presumably so each process draws different
    # data (confirm against data_iterator_imagenet).
    rng = np.random.RandomState(device_id)
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                args.batch_size, n_classes=args.n_classes,
                                rng=rng)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need for discriminator backward
        solver_dis.zero_grad()
        for _ in range(args.accum_grad):
            # feed x_real and y_real
            x_data, y_data = di.next()
            x_real.d, y_real.d = x_data, y_data.flatten()
            # feed z and y_fake
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_dis.forward(clear_no_need_grad=True)
            # Gradients are pre-scaled by 1 / (accum_grad * n_devices); if
            # all_reduce sums by default (TODO confirm), the reduced grad is
            # the mean over all micro-batches and devices.
            loss_dis.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_dis.values()])
        solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need for generator backward
        solver_gen.zero_grad()
        for _ in range(args.accum_grad):
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_gen.forward(clear_no_need_grad=True)
            loss_gen.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_gen.values()])
        solver_gen.update()

        # Synchronize by averaging the weights over devices using allreduce
        if i % args.sync_weight_every_itr == 0:
            weights = [v.data for v in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        # Save model and image
        if i % args.save_interval == 0 and comm.rank == 0:
            x_test.forward(clear_buffer=True)
            nn.save_parameters(os.path.join(
                args.monitor_path, "params_{}.h5".format(i)))
            monitor_image_tile_train.add(i, x_fake.d)
            monitor_image_tile_test.add(i, x_test.d)

        # Monitor
        if comm.rank == 0:
            monitor_loss_gen.add(i, loss_gen.d.copy())
            monitor_loss_dis.add(i, loss_dis.d.copy())
            monitor_time.add(i)

    # Final save; note `i` is undefined here if args.max_iter == 0.
    if comm.rank == 0:
        x_test.forward(clear_buffer=True)
        nn.save_parameters(os.path.join(
            args.monitor_path, "params_{}.h5".format(i)))
        monitor_image_tile_train.add(i, x_fake.d)
        monitor_image_tile_test.add(i, x_test.d)
Ejemplo n.º 9
0
def _interpolate_styles(di, x_real, z_style, x_fake, monitor_tile, rng,
                        num_repeats):
    """Translate every image of ``di`` while linearly sweeping its style code.

    Two random style codes ``z0`` and ``z1`` (shape ``z_style.shape``) are
    drawn per input image. At step ``i`` the decoder gets style
    ``(1 - r) * z0 + r * z1`` with ``r = i / num_repeats``, so ``r`` never
    reaches 1. Each input is concatenated with its translation along axis 3
    (width, for NCHW inputs). All pairs of one step are written as one tile.
    """
    z_data_0 = [rng.randn(*z_style.shape) for _ in range(di.size)]
    z_data_1 = [rng.randn(*z_style.shape) for _ in range(di.size)]
    for i in range(num_repeats):
        r = 1.0 * i / num_repeats
        images = []
        for z0, z1 in zip(z_data_0, z_data_1):
            x_data = di.next()[0]
            x_real.d = x_data
            z_style.d = z0 * (1.0 - r) + z1 * r
            x_fake.forward(clear_buffer=True)
            images.append(np.concatenate([x_data, x_fake.d.copy()], axis=3))
        monitor_tile.add(i, np.concatenate(images))


def interpolate(args):
    """Write style-interpolation tiles for A->B and B->A translation.

    Content is encoded from one domain and decoded with a style code that is
    interpolated between two random codes (or, with ``args.example_guided``,
    the graph's encoded style is used). Tiles are written under
    ``args.monitor_path``, one per step of ``args.num_repeats``.
    """
    # Load model
    nn.load_parameters(args.model_load_path)

    # Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = 1, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    maps = args.maps
    # content/style (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    # content/style (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    # generate over domains and reconstruction of content and style (domain A)
    z_style_a = nn.Variable(
        x_style_a.shape) if not args.example_guided else x_style_a
    z_style_a = z_style_a.apply(persistent=True)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    # generate over domains and reconstruction of content and style (domain B)
    z_style_b = nn.Variable(
        x_style_b.shape) if not args.example_guided else x_style_b
    z_style_b = z_style_b.apply(persistent=True)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")

    # Monitor
    def file_names(path):
        # Drop the directory and the exact "_AB.jpg" suffix. str.rstrip
        # strips a character *set*, so "bag_AB.jpg" would have become "ba".
        name = path.split("/")[-1]
        suffix = "_AB.jpg"
        return name[:-len(suffix)] if name.endswith(suffix) else name

    suffix = "Stochastic" if not args.example_guided else "Example-guided"
    monitor = Monitor(args.monitor_path)
    monitor_image_tile_a = MonitorImageTile(
        "Fake Image Tile {} B to A {} Interpolation".format(
            "-".join([file_names(path) for path in args.img_files_b]), suffix),
        monitor,
        interval=1,
        num_images=len(args.img_files_b))
    monitor_image_tile_b = MonitorImageTile(
        "Fake Image Tile {} A to B {} Interpolation".format(
            "-".join([file_names(path) for path in args.img_files_a]), suffix),
        monitor,
        interval=1,
        num_images=len(args.img_files_a))

    # DataIterator
    di_a = munit_data_iterator(args.img_files_a, b, shuffle=False)
    di_b = munit_data_iterator(args.img_files_b, b, shuffle=False)
    rng = np.random.RandomState(args.seed)

    # Interpolate (A -> B): domain-A content decoded with domain-B styles.
    _interpolate_styles(di_a, x_real_a, z_style_b, x_fake_b,
                        monitor_image_tile_b, rng, args.num_repeats)
    # Interpolate (B -> A): domain-B content decoded with domain-A styles.
    _interpolate_styles(di_b, x_real_b, z_style_a, x_fake_a,
                        monitor_image_tile_a, rng, args.num_repeats)
Ejemplo n.º 10
0
def train(args):
    """Train a GAN generator/discriminator pair and keep an EMA copy of the generator.

    Builds three parameter scopes: the generator (``"Generator"``), the
    discriminator (``"Discriminator"``) and an exponential-moving-average copy
    of the generator (``"Generator_EMA"``, decay 0.999). Each iteration does
    one discriminator update followed by one generator update, then updates
    the EMA parameters. Losses, elapsed time and image tiles go to nnabla
    monitors under ``args.monitor_path``. Parameters of all three scopes are
    saved there every ``args.save_interval`` iterations and once more after
    the last iteration.

    Args:
        args: parsed options. Reads ``context``, ``device_id``,
            ``type_config``, ``aug_list``, ``batch_size``, ``latent``,
            ``image_size``, ``lr``, ``monitor_path``, ``img_path``,
            ``train_samples``, ``max_iter``, ``save_interval`` and
            ``test_interval``.
    """
    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    aug_list = args.aug_list

    # Model
    scope_gen = "Generator"
    scope_dis = "Discriminator"
    scope_gen_ema = "Generator_EMA"

    # generator loss; x_fake is a sequence of generator outputs, all of which
    # are augmented and fed to the discriminator.
    z = nn.Variable([args.batch_size, args.latent, 1, 1])
    x_fake = Generator(z, scope_name=scope_gen, img_size=args.image_size)
    p_fake = Discriminator([augment(xf, aug_list)
                            for xf in x_fake], label="fake", scope_name=scope_dis)
    lossG = loss_gen(p_fake)

    # discriminator loss
    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    x_real_aug = augment(x_real, aug_list)
    p_real, rec_imgs, part = Discriminator(
        x_real_aug, label="real", scope_name=scope_dis)
    lossD_fake = loss_dis_fake(p_fake)
    lossD_real = loss_dis_real(p_real, rec_imgs, part, x_real_aug)
    lossD = lossD_fake + lossD_real

    # Generators with fixed latent values for test images.
    # Use train=True even in an inference phase (for both the plain and the
    # EMA generator).
    z_test = nn.Variable.from_numpy_array(
        np.random.randn(args.batch_size, args.latent, 1, 1))
    x_test = Generator(z_test, scope_name=scope_gen,
                       train=True, img_size=args.image_size)[0]

    # Exponential Moving Average (EMA) model: start as an exact copy of the
    # generator, then track it with decay 0.999.
    x_test_ema = Generator(z_test, scope_name=scope_gen_ema,
                           train=True, img_size=args.image_size)[0]
    copy_params(scope_gen, scope_gen_ema)
    update_ema_var = make_ema_updater(scope_gen_ema, scope_gen, 0.999)

    # Solver
    solver_gen = S.Adam(args.lr, beta1=0.5)
    solver_dis = S.Adam(args.lr, beta1=0.5)
    with nn.parameter_scope(scope_gen):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope(scope_dis):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries(
        "Generator Loss", monitor, interval=10)
    monitor_loss_dis_real = MonitorSeries(
        "Discriminator Loss Real", monitor, interval=10)
    monitor_loss_dis_fake = MonitorSeries(
        "Discriminator Loss Fake", monitor, interval=10)
    monitor_time = MonitorTimeElapsed(
        "Training Time", monitor, interval=10)

    def _denormalize(x):
        """Map images from [-1, 1] to [0, 1] (assumes generator outputs lie in [-1, 1])."""
        return (x + 1.) / 2.

    def _image_tile(name):
        """Create an image-tile monitor showing one full batch per call."""
        return MonitorImageTile(name, monitor,
                                num_images=args.batch_size,
                                interval=1,
                                normalize_method=_denormalize)

    monitor_image_tile_train = _image_tile("Image Tile Train")
    monitor_image_tile_test = _image_tile("Image Tile Test")
    monitor_image_tile_test_ema = _image_tile("Image Tile Test EMA")

    def _set_generator_grad(flag):
        """Enable/disable backprop from the discriminator into every generator output."""
        for xf in x_fake:
            xf.need_grad = flag

    def _save_checkpoint(iteration):
        """Save generator, EMA-generator and discriminator parameters for `iteration`."""
        for scope, prefix in ((scope_gen, "Gen"),
                              (scope_gen_ema, "GenEMA"),
                              (scope_dis, "Dis")):
            with nn.parameter_scope(scope):
                nn.save_parameters(os.path.join(
                    args.monitor_path, "{}_iter{}.h5".format(prefix, iteration)))

    def _monitor_images(iteration):
        """Render the fixed-latent test images and log train/test/EMA tiles."""
        x_test.forward(clear_buffer=True)
        x_test_ema.forward(clear_buffer=True)
        monitor_image_tile_train.add(iteration, x_fake[0])
        monitor_image_tile_test.add(iteration, x_test)
        monitor_image_tile_test_ema.add(iteration, x_test_ema)

    # Data Iterator
    rng = np.random.RandomState(141)
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.image_size, args.image_size),
                       num_samples=args.train_samples, rng=rng)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator; block gradients from flowing into the generator.
        _set_generator_grad(False)
        solver_dis.zero_grad()
        x_real.d = di.next()[0]
        z.d = np.random.randn(args.batch_size, args.latent, 1, 1)
        lossD.forward()
        lossD.backward()
        solver_dis.update()

        # Train generator (reuses the same z sampled for the D step).
        _set_generator_grad(True)
        solver_gen.zero_grad()
        lossG.forward()
        lossG.backward()
        solver_gen.update()

        # Update EMA model
        update_ema_var.forward()

        # Monitor
        monitor_loss_gen.add(i, lossG.d)
        monitor_loss_dis_real.add(i, lossD_real.d)
        monitor_loss_dis_fake.add(i, lossD_fake.d)
        monitor_time.add(i)

        # Save / visualize
        if (i + 1) % args.save_interval == 0:
            _save_checkpoint(i + 1)
        if (i + 1) % args.test_interval == 0:
            _monitor_images(i + 1)

    # Last
    _monitor_images(args.max_iter)
    _save_checkpoint(args.max_iter)
Ejemplo n.º 11
0
def main(args):
    """Train the depth CNN with an L1 loss, saving a checkpoint after every epoch.

    Args:
        args: parsed options. Reads ``batch_size``, ``img_height``,
            ``img_width``, ``log_dir``, ``optimizer`` (``"adam"`` or
            ``"sgd"``), ``learning_rate``, ``dataset_path`` and ``epochs``.

    Raises:
        ValueError: if ``args.optimizer`` is neither ``"adam"`` nor ``"sgd"``
            (otherwise ``solver`` would be left unbound).
    """
    from numpy.random import seed
    seed(46)

    # Get context.
    # NOTE(review): backend and device are hard-coded ('cudnn', device '0').
    from nnabla.ext_utils import get_extension_context
    ctx = get_extension_context('cudnn', device_id='0', type_config='float')
    nn.set_default_context(ctx)

    # Create CNN network
    # === TRAIN ===
    # Create input variables.
    image = nn.Variable([args.batch_size, 3, args.img_height, args.img_width])
    label = nn.Variable([args.batch_size, 1, args.img_height, args.img_width])
    # Create prediction graph. Keep `pred` alive through
    # `loss.backward(clear_buffer=True)`.
    pred = depth_cnn_model(image, test=False)
    pred.persistent = True
    # Create loss function.
    loss = l1_loss(pred, label)
    # === VAL === (currently disabled)
    #vimage = nn.Variable([args.batch_size, 3, args.img_height, args.img_width])
    #vlabel = nn.Variable([args.batch_size, 1, args.img_height, args.img_width])
    #vpred = depth_cnn_model(vimage, test=True)
    #vloss = l1_loss(vpred, vlabel)

    # Prepare monitors.
    monitor = Monitor(os.path.join(args.log_dir, 'nnmonitor'))
    monitors = {
        'train_epoch_loss':
        MonitorSeries('Train epoch loss', monitor, interval=1),
        'train_itr_loss':
        MonitorSeries('Train itr loss', monitor, interval=100),
        # 'val_epoch_loss': MonitorSeries('Val epoch loss', monitor, interval=1),
        'train_viz':
        MonitorImageTile('Train images', monitor, interval=1000, num_images=4)
    }

    # Create Solver.
    if args.optimizer == "adam":
        solver = S.Adam(alpha=args.learning_rate, beta1=0.9, beta2=0.999)
    elif args.optimizer == "sgd":
        solver = S.Momentum(lr=args.learning_rate, momentum=0.9)
    else:
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'adam' or 'sgd'.".format(
                args.optimizer))
    solver.set_parameters(nn.get_parameters())

    # Initialize DataIterator
    data_dic = prepare_dataloader(args.dataset_path,
                                  datatype_list=['train', 'val'],
                                  batch_size=args.batch_size,
                                  img_size=(args.img_height, args.img_width))
    train_itr = data_dic['train']['itr']
    # NOTE(review): one loop step consumes one batch, so this assumes
    # data_dic['train']['size'] counts batches, not samples — TODO confirm
    # against prepare_dataloader.
    train_size = data_dic['train']['size']

    # Training loop.
    logger.info("Start training!!!")
    total_itr_index = 0
    for epoch in range(1, args.epochs + 1):
        ## === training === ##
        total_train_loss = 0
        index = 0
        while index < train_size:
            # Preprocess
            image.d, label.d = train_itr.next()
            loss.forward(clear_no_need_grad=True)
            # Initialize gradients
            solver.zero_grad()
            # Backward execution
            loss.backward(clear_buffer=True)
            # Update parameters by computed gradients (weight decay only for SGD)
            if args.optimizer == 'sgd':
                solver.weight_decay(1e-4)
            solver.update()

            # Update log
            index += 1
            total_itr_index += 1
            total_train_loss += loss.d

            # Pass to monitor
            monitors['train_itr_loss'].add(total_itr_index, loss.d)

            # Visualization: re-run the prediction with the updated weights.
            # NOTE(review): this full forward pass runs every iteration even
            # though the image monitor only records every 1000 iterations.
            pred.forward(clear_buffer=True)
            train_viz = np.concatenate([
                image.d,
                convert_depth2colormap(label.d),
                convert_depth2colormap(pred.d),
            ], axis=3)
            monitors['train_viz'].add(total_itr_index, train_viz)

            # Logger: running epoch average and current batch loss.
            logger.info("[{}] {}/{} Train Loss {} ({})".format(
                epoch, index, train_size,
                total_train_loss / index, loss.d))

        # Pass training loss to a monitor.
        train_error = total_train_loss / train_size
        monitors['train_epoch_loss'].add(epoch, train_error)

        # Save Parameter
        out_param_file = os.path.join(args.log_dir,
                                      'checkpoint' + str(epoch) + '.h5')
        nn.save_parameters(out_param_file)