示例#1
0
def load_marcua_dataset():
    """Flatten every experiment from ``data_iterator()`` into answer rows.

    ``data_iterator()`` is expected to yield ``(prop, answers, truth)``
    triples.  For each answer a list is yielded with the layout::

        [expid, prop, <answer fields...>, rowid, truth]

    where ``expid`` is the position of the experiment in the iterator,
    ``rowid`` is the position of the answer inside its experiment, and
    answer field 1 (the worker id) is replaced by a dense integer that
    starts at 0 for every experiment, in order of first appearance.
    """
    for experiment_index, (prop, answers, truth) in enumerate(data_iterator()):
        # Worker ids are renumbered independently for each experiment.
        dense_worker_ids = {}
        for position, answer in enumerate(answers):
            raw_worker = answer[1]
            if raw_worker not in dense_worker_ids:
                dense_worker_ids[raw_worker] = len(dense_worker_ids)

            fields = list(answer)
            fields[1] = dense_worker_ids[raw_worker]

            yield [experiment_index, prop] + fields + [position, truth]
示例#2
0
def load_marcua_dataset():
    """Generate one list per answer for all records of ``data_iterator()``.

    Every record is unpacked as ``(prop, answers, truth)``.  The yielded
    list starts with the record index and ``prop``, continues with a copy
    of the answer whose element 1 (worker id) has been normalized to a
    per-record counter (0, 1, 2, ... by first occurrence), and ends with
    the answer's index inside the record followed by ``truth``.
    """
    for expid, record in enumerate(data_iterator()):
        prop, answers, truth = record
        normalized = {}  # original worker id -> compact id for this record
        for rowid, ans in enumerate(answers):
            # normalize worker ids: reuse the known id or hand out the next one
            compact_id = normalized.setdefault(ans[1], len(normalized))
            middle = list(ans)
            middle[1] = compact_id
            head = [expid, prop]
            tail = [rowid, truth]
            yield head + middle + tail
示例#3
0
def main():
    """Set up models, solvers and monitors, then launch ``Trainer.train``.

    Settings come from ``get_args()`` and are persisted with ``save_args``.
    The resolution and channel schedules handed to the trainer are
    hard-coded in this function.
    """
    # Arguments
    args = get_args()
    save_args(args)

    # Computation context (auto-forward mode is switched on)
    context = extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(context)
    nn.set_auto_forward(True)

    # Data iterator
    image_shape = (args.imsize, args.imsize)
    di = data_iterator(args.img_path,
                       args.batch_size,
                       imsize=image_shape,
                       num_samples=args.train_samples,
                       dataset_name=args.dataset_name)

    # Models: both networks share the weight-scaling / He-backward switches
    shared_model_options = dict(use_wscale=args.not_use_wscale,
                                use_he_backward=args.use_he_backward)
    generator = Generator(use_bn=args.use_bn,
                          last_act=args.last_act,
                          **shared_model_options)
    discriminator = Discriminator(use_ln=args.use_ln,
                                  alpha=args.leaky_alpha,
                                  **shared_model_options)

    # Solvers: two separate Adam instances with identical hyper-parameters
    adam_options = dict(alpha=args.learning_rate,
                        beta1=args.beta1,
                        beta2=args.beta2)
    solver_gen = S.Adam(**adam_options)
    solver_dis = S.Adam(**adam_options)

    # Monitors
    monitor = Monitor(args.monitor_path)
    series_titles = ("Generator Loss",
                     "Discriminator Loss",
                     "Fake Probability",
                     "Real Probability")
    (monitor_loss_gen,
     monitor_loss_dis,
     monitor_p_fake,
     monitor_p_real) = [MonitorSeries(title, monitor, interval=10)
                        for title in series_titles]
    monitor_time = MonitorTimeElapsed("Training Time per Resolution",
                                      monitor,
                                      interval=1)

    def shift_and_halve(x):
        # Applied to images before they are written as tiles.
        return (x + 1.) / 2.

    monitor_image_tile = MonitorImageTileWithName(
        "Image Tile",
        monitor,
        num_images=4,
        normalize_method=shift_and_halve)

    # TODO: take these schedules from the arguments
    resolution_list = [4, 8, 16, 32, 64, 128]
    channel_list = [512, 512, 256, 128, 64, 32]

    trainer = Trainer(
        di,
        generator,
        discriminator,
        solver_gen,
        solver_dis,
        args.monitor_path,
        monitor_loss_gen,
        monitor_loss_dis,
        monitor_p_fake,
        monitor_p_real,
        monitor_time,
        monitor_image_tile,
        resolution_list,
        channel_list,
        n_latent=args.latent,
        n_critic=args.critic,
        save_image_interval=args.save_image_interval,
        hyper_sphere=args.hyper_sphere,
        l2_fake_weight=args.l2_fake_weight,
    )

    # TODO: use images per resolution?
    trainer.train(args.epoch_per_resolution)
示例#4
0
def main():
    """Score a saved generator with the metric named by the arguments.

    Reads settings from ``get_args()``, sets the default nnabla context,
    builds a data iterator and loads the generator from
    ``args.model_load_path``.  Depending on ``args.validation_metric`` it
    then computes either the multi-scale SSIM (``"ms-ssim"``) or the sliced
    Wasserstein distance (``"swd"``), records the score at index 0 of the
    matching monitor series and logs it.

    For any other metric name a hint is logged and the function returns
    without reporting a score.
    """
    # Args
    args = get_args()

    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    logger.info(ctx)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Monitor
    monitor = Monitor(args.monitor_path)

    # Validation
    logger.info("Start validation")

    num_images = args.valid_samples
    # Integer division: a trailing partial batch is not counted.
    num_batches = num_images // args.batch_size

    # DataIterator
    di = data_iterator(args.img_path,
                       args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.valid_samples,
                       dataset_name=args.dataset_name)
    # generator
    gen = load_gen(args.model_load_path,
                   use_bn=args.use_bn,
                   last_act=args.last_act,
                   use_wscale=args.not_use_wscale,
                   use_he_backward=args.use_he_backward)

    # compute metric
    if args.validation_metric == "ms-ssim":
        logger.info("Multi Scale SSIM")
        monitor_time = MonitorTimeElapsed("MS-SSIM-ValidationTime",
                                          monitor,
                                          interval=1)
        monitor_metric = MonitorSeries("MS-SSIM", monitor, interval=1)
        # Imported here because both metric modules export `compute_metric`.
        from ms_ssim import compute_metric
        score = compute_metric(gen, args.batch_size, num_images, args.latent,
                               args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)
    elif args.validation_metric == "swd":
        logger.info("Sliced Wasserstein Distance")
        monitor_time = MonitorTimeElapsed("SWD-ValidationTime",
                                          monitor,
                                          interval=1)
        monitor_metric = MonitorSeries("SWD", monitor, interval=1)
        nhoods_per_image = 128
        nhood_size = 7
        level_list = [128, 64, 32, 16]  # TODO: use argument
        dir_repeats = 4
        dirs_per_repeat = 128
        # Imported here because both metric modules export `compute_metric`.
        from sliced_wasserstein import compute_metric
        score = compute_metric(di, gen, args.latent, num_batches,
                               nhoods_per_image, nhood_size, level_list,
                               dir_repeats, dirs_per_repeat, args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)  # averaged in the log
    else:
        logger.info("Set `validation-metric` as either `ms-ssim` or `swd`.")
        # No metric was computed, so `score` is unbound; stop here instead of
        # failing with UnboundLocalError when trying to log it.
        return
    logger.info(score)
    logger.info("End validation")
示例#5
0
def train(args):
    """Build the GAN graphs and run ``args.max_iter`` training iterations.

    The function wires up a generator loss and a discriminator loss
    (fake + real parts), a test generator fed with fixed latent values, and
    an exponential-moving-average (EMA) copy of the generator.  Each
    iteration updates the discriminator, then the generator, then the EMA
    parameters.  Losses and elapsed time are reported to monitors under
    ``args.monitor_path``; generator, EMA-generator and discriminator
    parameters are saved there every ``args.save_interval`` iterations and
    once more after the loop; image tiles are written every
    ``args.test_interval`` iterations and once more after the loop.

    Args:
        args: Parsed options.  Attributes read here: ``context``,
            ``device_id``, ``type_config``, ``aug_list``, ``batch_size``,
            ``latent``, ``image_size``, ``lr``, ``monitor_path``,
            ``img_path``, ``train_samples``, ``max_iter``,
            ``save_interval`` and ``test_interval``.
    """
    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    aug_list = args.aug_list
    batch_size = args.batch_size
    latent_shape = [batch_size, args.latent, 1, 1]
    image_size = args.image_size

    # Parameter scopes
    scope_gen = "Generator"
    scope_dis = "Discriminator"
    scope_gen_ema = "Generator_EMA"

    # Generator loss graph
    z = nn.Variable(latent_shape)
    x_fake = Generator(z, scope_name=scope_gen, img_size=image_size)
    augmented_fakes = [augment(fake, aug_list) for fake in x_fake]
    p_fake = Discriminator(augmented_fakes, label="fake",
                           scope_name=scope_dis)
    lossG = loss_gen(p_fake)

    # Discriminator loss graph
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    x_real_aug = augment(x_real, aug_list)
    p_real, rec_imgs, part = Discriminator(
        x_real_aug, label="real", scope_name=scope_dis)
    lossD_fake = loss_dis_fake(p_fake)
    lossD_real = loss_dis_real(p_real, rec_imgs, part, x_real_aug)
    lossD = lossD_fake + lossD_real

    # Test generator driven by fixed latent values.
    # train=True is used even though this is an inference path.
    z_test = nn.Variable.from_numpy_array(np.random.randn(*latent_shape))
    x_test = Generator(z_test, scope_name=scope_gen,
                       train=True, img_size=image_size)[0]

    # EMA generator (train=True here as well).  Its parameters start as a
    # copy of the generator's and are refreshed by `update_ema_var`.
    x_test_ema = Generator(z_test, scope_name=scope_gen_ema,
                           train=True, img_size=image_size)[0]
    copy_params(scope_gen, scope_gen_ema)
    update_ema_var = make_ema_updater(scope_gen_ema, scope_gen, 0.999)

    # Solvers, each bound to the parameters of its own scope
    solver_gen = S.Adam(args.lr, beta1=0.5)
    solver_dis = S.Adam(args.lr, beta1=0.5)
    for solver, scope in ((solver_gen, scope_gen), (solver_dis, scope_dis)):
        with nn.parameter_scope(scope):
            solver.set_parameters(nn.get_parameters())

    # Monitors
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_dis_real = MonitorSeries(
        "Discriminator Loss Real", monitor, interval=10)
    monitor_loss_dis_fake = MonitorSeries(
        "Discriminator Loss Fake", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)

    def shift_and_halve(x):
        # Applied to images before they are written as tiles.
        return (x + 1.) / 2.

    def make_tile_monitor(title):
        return MonitorImageTile(title, monitor,
                                num_images=batch_size,
                                interval=1,
                                normalize_method=shift_and_halve)

    monitor_image_tile_train = make_tile_monitor("Image Tile Train")
    monitor_image_tile_test = make_tile_monitor("Image Tile Test")
    monitor_image_tile_test_ema = make_tile_monitor("Image Tile Test EMA")

    # Data iterator
    rng = np.random.RandomState(141)
    di = data_iterator(args.img_path, batch_size,
                       imsize=(image_size, image_size),
                       num_samples=args.train_samples, rng=rng)

    checkpoint_templates = (
        (scope_gen, "Gen_iter{}.h5"),
        (scope_gen_ema, "GenEMA_iter{}.h5"),
        (scope_dis, "Dis_iter{}.h5"),
    )

    def set_fake_need_grad(flag):
        # Controls whether backward reaches the generator through the
        # first two generator outputs.
        for index in (0, 1):
            x_fake[index].need_grad = flag

    def save_checkpoint(step):
        # One parameter file per scope, tagged with the iteration count.
        for scope, template in checkpoint_templates:
            with nn.parameter_scope(scope):
                nn.save_parameters(os.path.join(
                    args.monitor_path, template.format(step)))

    def write_image_tiles(step):
        # Refresh both test generators, then dump all three tile sets.
        x_test.forward(clear_buffer=True)
        x_test_ema.forward(clear_buffer=True)
        monitor_image_tile_train.add(step, x_fake[0])
        monitor_image_tile_test.add(step, x_test)
        monitor_image_tile_test_ema.add(step, x_test_ema)

    # Train loop
    for i in range(args.max_iter):
        step = i + 1

        # Discriminator update (no backward into the generator)
        set_fake_need_grad(False)
        solver_dis.zero_grad()
        x_real.d = di.next()[0]
        z.d = np.random.randn(*latent_shape)
        lossD.forward()
        lossD.backward()
        solver_dis.update()

        # Generator update (backward into the generator re-enabled)
        set_fake_need_grad(True)
        solver_gen.zero_grad()
        lossG.forward()
        lossG.backward()
        solver_gen.update()

        # EMA update
        update_ema_var.forward()

        # Monitoring
        monitor_loss_gen.add(i, lossG.d)
        monitor_loss_dis_real.add(i, lossD_real.d)
        monitor_loss_dis_fake.add(i, lossD_fake.d)
        monitor_time.add(i)

        # Periodic checkpoint and image dump
        if step % args.save_interval == 0:
            save_checkpoint(step)
        if step % args.test_interval == 0:
            write_image_tiles(step)

    # Final image dump and checkpoint
    write_image_tiles(args.max_iter)
    save_checkpoint(args.max_iter)