Example #1
def reload_checkpoint(trainer: Trainer,
                      checkpoints_dir: str,
                      save_path=None) -> tf2lib.Checkpoint:
    checkpoint_dict = {
        'encoder_a': trainer.model.encoder_a,
        'encoder_b': trainer.model.encoder_b,
        'encoder_shared': trainer.model.encoder_shared,
        'decoder_shared': trainer.model.decoder_shared,
        'decoder_a': trainer.model.decoder_a,
        'decoder_b': trainer.model.decoder_b,
        'downstreamer': trainer.model.downstreamer,
        'dis_a': trainer.model.dis_a,
        'dis_b': trainer.model.dis_b,
        'controller': trainer.controller.model,
        'gen_opt': trainer.gen_opt,
        'dis_opt': trainer.dis_opt,
        'control_opt': trainer.control_opt,
    }
    checkpoint = tf2lib.Checkpoint(checkpoint_dict,
                                   checkpoints_dir,
                                   max_to_keep=1)
    try:  # Restore checkpoint
        checkpoint.restore(save_path).assert_existing_objects_matched()
    except Exception as e:
        logging.error(e)
    return checkpoint
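`tf2lib.Checkpoint` is not a core TensorFlow class; it behaves like a thin wrapper around `tf.train.Checkpoint` and `tf.train.CheckpointManager`. A minimal sketch of the assumed interface (the real library may differ) explains the pattern above: `restore()` without a path falls back to the newest managed checkpoint, and on a fresh directory `assert_existing_objects_matched()` raises, which the `try`/`except` converts into a start-from-scratch path.

import tensorflow as tf

class CheckpointSketch:
    """Assumed shape of tf2lib.Checkpoint / tl.Checkpoint (a sketch)."""

    def __init__(self, checkpoint_kwargs, directory, max_to_keep=5):
        self.checkpoint = tf.train.Checkpoint(**checkpoint_kwargs)
        self.manager = tf.train.CheckpointManager(self.checkpoint,
                                                  directory,
                                                  max_to_keep=max_to_keep)

    def restore(self, save_path=None):
        # fall back to the newest managed checkpoint; if that is None,
        # assert_existing_objects_matched() on the returned status raises
        save_path = save_path or self.manager.latest_checkpoint
        return self.checkpoint.restore(save_path)

    def save(self, checkpoint_number=None):
        return self.manager.save(checkpoint_number=checkpoint_number)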
Example #2
 def set_checkpoints(self):
     self.ep_cnt = tf.Variable(initial_value=0,
                               trainable=False,
                               dtype=tf.int64)
     # checkpoint
     self.checkpoint = tl.Checkpoint(dict(G_A2B=self.G_A2B,
                                          G_B2A=self.G_B2A,
                                          D_A=self.D_A,
                                          D_B=self.D_B,
                                          G_optimizer=self.G_optimizer,
                                          D_optimizer=self.D_optimizer,
                                          ep_cnt=self.ep_cnt),
                                     py.join(self.output_dataset_dir,
                                             'checkpoints'),
                                     max_to_keep=5)
     try:  # restore checkpoint including the epoch counter
         self.checkpoint.restore().assert_existing_objects_matched()
     except Exception as e:
         self.logger.warning(e)
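Making the epoch counter a `tf.Variable` rather than a plain Python int is what lets it travel with the checkpoint. A self-contained round trip using core TF APIs (the directory is a stand-in):

import tensorflow as tf

ep_cnt = tf.Variable(0, trainable=False, dtype=tf.int64)
ckpt = tf.train.Checkpoint(ep_cnt=ep_cnt)
manager = tf.train.CheckpointManager(ckpt, '/tmp/demo_ckpt', max_to_keep=1)

ep_cnt.assign_add(3)                     # simulate three finished epochs
manager.save()

ep_cnt.assign(0)                         # a fresh process would start at zero
ckpt.restore(manager.latest_checkpoint)
assert int(ep_cnt) == 3                  # the counter is restored with the rest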
Example #3
def sample(z):
    return G(z, training=False)


# ==============================================================================
# =                                    run                                     =
# ==============================================================================

# epoch counter
ep_cnt = tf.Variable(initial_value=0, trainable=False, dtype=tf.int64)

# checkpoint
checkpoint = tl.Checkpoint(dict(G=G,
                                D=D,
                                G_optimizer=G_optimizer,
                                D_optimizer=D_optimizer,
                                ep_cnt=ep_cnt),
                           py.join(output_dir, 'checkpoints'),
                           max_to_keep=5)
try:  # restore checkpoint including the epoch counter
    checkpoint.restore().assert_existing_objects_matched()
except Exception as e:
    print(e)

# summary
train_summary_writer = tf.summary.create_file_writer(py.join(output_dir, 'summaries', 'train'))

# sample
sample_dir = py.join(output_dir, 'samples_training')
py.mkdir(sample_dir)
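The file writer above is passive; summaries are written by `tf.summary` calls made while it is the default writer. A standard usage sketch (the loop and values are illustrative):

with train_summary_writer.as_default():
    for step in range(100):              # hypothetical training loop
        g_loss = 1.0 / (step + 1)        # placeholder scalar
        tf.summary.scalar('G_loss', g_loss, step=step)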
Example #4
    return A2B, B2A, A2B2A, B2A2B


# ==============================================================================
# =                                    run                                     =
# ==============================================================================

# epoch counter
ep_cnt = tf.Variable(initial_value=0, trainable=False, dtype=tf.int64)

# checkpoint
checkpoint = tl.Checkpoint(dict(G_A2B=G_A2B,
                                G_B2A=G_B2A,
                                D_A=D_A,
                                D_B=D_B,
                                G_optimizer=G_optimizer,
                                D_optimizer=D_optimizer,
                                ep_cnt=ep_cnt),
                           py.join(output_dir, 'checkpoints'),
                           max_to_keep=5)
try:  # restore checkpoint including the epoch counter
    checkpoint.restore().assert_existing_objects_matched()
except Exception as e:
    print(e)

# summary
train_summary_writer = tf.summary.create_file_writer(
    py.join(output_dir, 'summaries', 'train'))

# sample
test_iter = iter(A_B_dataset_test)
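Because `A_B_dataset_test` is presumably built with `repeat=True` (as in Example #7 below), `test_iter` never exhausts, so an evaluation batch can be drawn at any point during training:

A, B = next(test_iter)                   # safe at any time with repeat=True
A2B, B2A, A2B2A, B2A2B = sample(A, B)    # `sample` assumed to be the truncated
                                         # function whose return opens this example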
Example #5
                                   drop_remainder=False,
                                   shuffle=False,
                                   repeat=1)

# model
G_A2B = module.ResnetGenerator(input_shape=(args.crop_size, args.crop_size, 3),
                               dim=args.dim,
                               n_downsamplings=args.n_downsamplings,
                               n_blocks=args.n_blocks)
G_B2A = module.ResnetGenerator(input_shape=(args.crop_size, args.crop_size, 3),
                               dim=args.dim,
                               n_downsamplings=args.n_downsamplings,
                               n_blocks=args.n_blocks)

# restore
tl.Checkpoint(dict(G_A2B=G_A2B, G_B2A=G_B2A),
              py.join(args.experiment_dir, 'checkpoints')).restore()


@tf.function
def sample_A2B(A):
    A2B = G_A2B(A, training=False)
    A2B2A = G_B2A(A2B, training=False)
    return A2B, A2B2A


@tf.function
def sample_B2A(B):
    B2A = G_B2A(B, training=False)
    B2A2B = G_A2B(B2A, training=False)
    return B2A, B2A2B
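The test-time restore above loads only the two generators; optimizer state left in the checkpoint files is ignored. With core TF APIs, the equivalent inference-only restore is usually paired with `expect_partial()` to silence warnings about those unused values (a sketch with a stand-in model and path):

import tensorflow as tf

net = tf.keras.Sequential([tf.keras.layers.Dense(4)])   # stand-in model
net.build(input_shape=(None, 8))

ckpt = tf.train.Checkpoint(net=net)
latest = tf.train.latest_checkpoint('/tmp/experiment/checkpoints')
if latest is not None:
    # expect_partial(): values present in the file but absent from this
    # object graph (e.g. optimizer slots) should not trigger warnings
    ckpt.restore(latest).expect_partial()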
Example #6
        'G_GAN': loss_G_GAN,
        'G_GAN_Feat': loss_G_GAN_Feat,
        'loss_G_VGG': loss_G_VGG
    }
    return loss_D_dict, loss_G_dict, fake_img


# graph = train_G_D.get_concrete_function().graph
# graph_def = graph.as_graph_def()
# tf.io.write_graph(graph_def, './', 'graph.pbtxt', True)

# checkpoint
checkpoint_dir = os.path.join(opt.checkpoints_dir,
                              opt.name, 'train_ckpt')
checkpoint = tl.Checkpoint({'generator': model.netG,
                            'discriminator': model.netD,
                            'optimizer_G': optimizer_G,
                            'optimizer_D': optimizer_D}, checkpoint_dir)
try:  # restore checkpoint including the epoch counter
    checkpoint.restore().assert_existing_objects_matched()
    print('restore')
except Exception as e:
    print(e)

# main loop
with train_summary_writer.as_default():
    for ep in range(start_epoch, opt.niter + opt.niter_decay + 1):
        print('Epoch: ', ep)
        for step, (label, real_img) in enumerate(dataset):
            input_label, inst_map, real_image, feat_map = model.encode_input(label)
            # inputs = (input_label, inst_map, real_image, feat_map)
            loss_D_dict, loss_G_dict, fake_img = train_step(input_label, real_img)
Example #7
def train():
    # ===================================== Args =====================================
    args = parse_args()
    output_dir = os.path.join('output', args.dataset)
    os.makedirs(output_dir, exist_ok=True)
    settings_path = os.path.join(output_dir, 'settings.json')
    pylib.args_to_json(settings_path, args)

    # ===================================== Data =====================================
    A_img_paths = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'trainA'), '*.png')
    B_img_paths = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'trainB'), '*.png')
    print(f'len(A_img_paths) = {len(A_img_paths)}')
    print(f'len(B_img_paths) = {len(B_img_paths)}')
    load_size = [args.load_size_height, args.load_size_width]
    crop_size = [args.crop_size_height, args.crop_size_width]
    A_B_dataset, len_dataset = data.make_zip_dataset(A_img_paths,
                                                     B_img_paths,
                                                     args.batch_size,
                                                     load_size,
                                                     crop_size,
                                                     training=True,
                                                     repeat=False)

    A2B_pool = data.ItemPool(args.pool_size)
    B2A_pool = data.ItemPool(args.pool_size)

    A_img_paths_test = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'testA'), '*.png')
    B_img_paths_test = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'testB'), '*.png')
    A_B_dataset_test, _ = data.make_zip_dataset(A_img_paths_test,
                                                B_img_paths_test,
                                                args.batch_size,
                                                load_size,
                                                crop_size,
                                                training=False,
                                                repeat=True)

    # ===================================== Models =====================================
    # [args.crop_size_height, args.crop_size_width, 3]
    model_input_shape = crop_size + [3]

    G_A2B = module.ResnetGenerator(input_shape=model_input_shape, n_blocks=6)
    G_B2A = module.ResnetGenerator(input_shape=model_input_shape, n_blocks=6)

    D_A = module.ConvDiscriminator(input_shape=model_input_shape)
    D_B = module.ConvDiscriminator(input_shape=model_input_shape)

    d_loss_fn, g_loss_fn = tf2gan.get_adversarial_losses_fn(
        args.adversarial_loss_mode)
    cycle_loss_fn = tf.losses.MeanAbsoluteError()
    identity_loss_fn = tf.losses.MeanAbsoluteError()

    G_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset,
                                        args.epoch_decay * len_dataset)
    D_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset,
                                        args.epoch_decay * len_dataset)
    G_optimizer = tf.keras.optimizers.Adam(learning_rate=G_lr_scheduler,
                                           beta_1=args.beta_1)
    D_optimizer = tf.keras.optimizers.Adam(learning_rate=D_lr_scheduler,
                                           beta_1=args.beta_1)
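
    # A plausible sketch of module.LinearDecay (an assumption -- the real
    # class may differ): hold the base rate until `step_decay`, then decay
    # linearly to zero at `total_steps`. A `current_learning_rate` variable
    # is kept up to date for the summary code further below.
    class LinearDecaySketch(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, initial_learning_rate, total_steps, step_decay):
            super().__init__()
            self._initial_learning_rate = initial_learning_rate
            self._total_steps = total_steps
            self._step_decay = step_decay
            self.current_learning_rate = tf.Variable(initial_learning_rate,
                                                     trainable=False,
                                                     dtype=tf.float32)

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            lr = tf.where(
                step >= self._step_decay,
                self._initial_learning_rate *
                (1 - (step - self._step_decay) /
                 (self._total_steps - self._step_decay)),
                self._initial_learning_rate)
            self.current_learning_rate.assign(lr)
            return lr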

    # ===================================== Training steps =====================================
    @tf.function
    def train_generators(A, B):
        with tf.GradientTape() as t:
            A2B = G_A2B(A, training=True)
            B2A = G_B2A(B, training=True)
            A2B2A = G_B2A(A2B, training=True)
            B2A2B = G_A2B(B2A, training=True)
            A2A = G_B2A(A, training=True)
            B2B = G_A2B(B, training=True)

            A2B_d_logits = D_B(A2B, training=True)
            B2A_d_logits = D_A(B2A, training=True)

            A2B_g_loss = g_loss_fn(A2B_d_logits)
            B2A_g_loss = g_loss_fn(B2A_d_logits)
            A2B2A_cycle_loss = cycle_loss_fn(A, A2B2A)
            B2A2B_cycle_loss = cycle_loss_fn(B, B2A2B)
            A2A_id_loss = identity_loss_fn(A, A2A)
            B2B_id_loss = identity_loss_fn(B, B2B)

            G_loss = (A2B_g_loss + B2A_g_loss) + (
                A2B2A_cycle_loss +
                B2A2B_cycle_loss) * args.cycle_loss_weight + (
                    A2A_id_loss + B2B_id_loss) * args.identity_loss_weight

        G_grad = t.gradient(
            G_loss, G_A2B.trainable_variables + G_B2A.trainable_variables)
        G_optimizer.apply_gradients(
            zip(G_grad, G_A2B.trainable_variables + G_B2A.trainable_variables))

        return A2B, B2A, {
            'A2B_g_loss': A2B_g_loss,
            'B2A_g_loss': B2A_g_loss,
            'A2B2A_cycle_loss': A2B2A_cycle_loss,
            'B2A2B_cycle_loss': B2A2B_cycle_loss,
            'A2A_id_loss': A2A_id_loss,
            'B2B_id_loss': B2B_id_loss
        }

    @tf.function
    def train_discriminators(A, B, A2B, B2A):
        with tf.GradientTape() as t:
            A_d_logits = D_A(A, training=True)
            B2A_d_logits = D_A(B2A, training=True)
            B_d_logits = D_B(B, training=True)
            A2B_d_logits = D_B(A2B, training=True)

            A_d_loss, B2A_d_loss = d_loss_fn(A_d_logits, B2A_d_logits)
            B_d_loss, A2B_d_loss = d_loss_fn(B_d_logits, A2B_d_logits)
            D_A_gp = tf2gan.gradient_penalty(functools.partial(D_A,
                                                               training=True),
                                             A,
                                             B2A,
                                             mode=args.gradient_penalty_mode)
            D_B_gp = tf2gan.gradient_penalty(functools.partial(D_B,
                                                               training=True),
                                             B,
                                             A2B,
                                             mode=args.gradient_penalty_mode)

            D_loss = (A_d_loss + B2A_d_loss) + (B_d_loss + A2B_d_loss) + (
                D_A_gp + D_B_gp) * args.gradient_penalty_weight

        D_grad = t.gradient(D_loss,
                            D_A.trainable_variables + D_B.trainable_variables)
        D_optimizer.apply_gradients(
            zip(D_grad, D_A.trainable_variables + D_B.trainable_variables))

        return {
            'A_d_loss': A_d_loss + B2A_d_loss,
            'B_d_loss': B_d_loss + A2B_d_loss,
            'D_A_gp': D_A_gp,
            'D_B_gp': D_B_gp
        }
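
    # Sketch of the WGAN-GP mode of tf2gan.gradient_penalty (an assumption;
    # the library selects among several modes via `mode`): penalize deviation
    # of D's gradient norm from 1 at points interpolated between the real and
    # fake batches.
    def gradient_penalty_sketch(f, real, fake):
        alpha = tf.random.uniform([tf.shape(real)[0], 1, 1, 1], 0., 1.)
        inter = alpha * real + (1 - alpha) * fake
        with tf.GradientTape() as tape:
            tape.watch(inter)
            pred = f(inter)
        grad = tape.gradient(pred, inter)
        norm = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)
        return tf.reduce_mean((norm - 1.) ** 2)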

    def train_step(A, B):
        A2B, B2A, G_loss_dict = train_generators(A, B)

        # cannot autograph `A2B_pool`, so train_step stays a plain Python
        # function; A2B_pool(A2B.numpy()) would also work, but it is much
        # slower because of the CPU-GPU communication
        A2B = A2B_pool(A2B)
        B2A = B2A_pool(B2A)

        D_loss_dict = train_discriminators(A, B, A2B, B2A)

        return G_loss_dict, D_loss_dict
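
    # A plausible sketch of data.ItemPool (an assumption): the image-history
    # buffer of Shrivastava et al., which sometimes swaps a freshly generated
    # image for an older one so the discriminator does not overfit to the
    # latest generator. It holds Python-side state, hence no @tf.function on
    # train_step above.
    class ItemPoolSketch:
        def __init__(self, pool_size=50):
            self.pool_size = pool_size
            self.items = []

        def __call__(self, in_items):
            if self.pool_size == 0:
                return in_items
            out_items = []
            for item in in_items:                     # eager iteration over the batch
                if len(self.items) < self.pool_size:  # pool not yet full: pass through
                    self.items.append(item)
                    out_items.append(item)
                elif np.random.rand() > 0.5:          # swap with a stored image
                    idx = np.random.randint(len(self.items))
                    out_items.append(self.items[idx])
                    self.items[idx] = item
                else:                                 # pass through unchanged
                    out_items.append(item)
            return tf.stack(out_items)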

    @tf.function
    def sample(A, B):
        A2B = G_A2B(A, training=False)
        B2A = G_B2A(B, training=False)
        A2B2A = G_B2A(A2B, training=False)
        B2A2B = G_A2B(B2A, training=False)
        return A2B, B2A, A2B2A, B2A2B

    # ===================================== Runner code =====================================
    # epoch counter
    ep_cnt = tf.Variable(initial_value=0, trainable=False, dtype=tf.int64)

    # checkpoint
    checkpoint = tf2lib.Checkpoint(dict(G_A2B=G_A2B,
                                        G_B2A=G_B2A,
                                        D_A=D_A,
                                        D_B=D_B,
                                        G_optimizer=G_optimizer,
                                        D_optimizer=D_optimizer,
                                        ep_cnt=ep_cnt),
                                   os.path.join(output_dir, 'checkpoints'),
                                   max_to_keep=5)
    try:  # restore checkpoint including the epoch counter
        checkpoint.restore().assert_existing_objects_matched()
    except Exception as e:
        print(e)

    # summary
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(output_dir, 'summaries', 'train'))

    # sample
    test_iter = iter(A_B_dataset_test)
    sample_dir = os.path.join(output_dir, 'samples_training')
    os.makedirs(sample_dir, exist_ok=True)

    # main loop
    with train_summary_writer.as_default():
        for ep in tqdm.trange(args.epochs, desc='Epoch Loop'):
            if ep < ep_cnt:
                continue

            # update epoch counter
            ep_cnt.assign_add(1)

            # train for an epoch
            for A, B in tqdm.tqdm(A_B_dataset,
                                  desc='Inner Epoch Loop',
                                  total=len_dataset):
                G_loss_dict, D_loss_dict = train_step(A, B)

                # summary
                tf2lib.summary(G_loss_dict,
                               step=G_optimizer.iterations,
                               name='G_losses')
                tf2lib.summary(D_loss_dict,
                               step=G_optimizer.iterations,
                               name='D_losses')
                tf2lib.summary(
                    {'learning rate': G_lr_scheduler.current_learning_rate},
                    step=G_optimizer.iterations,
                    name='learning rate')

                # sample
                if G_optimizer.iterations.numpy() % 100 == 0:
                    A, B = next(test_iter)
                    A2B, B2A, A2B2A, B2A2B = sample(A, B)
                    img = imlib.immerge(np.concatenate(
                        [A, A2B, A2B2A, B, B2A, B2A2B], axis=0),
                                        n_rows=6)
                    imlib.imwrite(
                        img,
                        os.path.join(
                            sample_dir,
                            'iter-%09d.jpg' % G_optimizer.iterations.numpy()))

            # save checkpoint
            checkpoint.save(ep)

Example #8
# ==============================================================================
# =                                    run                                     =
# ==============================================================================

# epoch counter
ep_cnt = tf.Variable(initial_value=0, trainable=False, dtype=tf.int64)

# checkpoint
checkpoint_dir = os.path.join(output_dir, 'checkpoints')
checkpoint = tl.Checkpoint(dict(G_A2B=G_A2B,
                                G_B2A=G_B2A,
                                D_A=D_A,
                                D_B=D_B,
                                G_optimizer=G_optimizer,
                                D_optimizer=D_optimizer,
                                ep_cnt=ep_cnt),
                           checkpoint_dir,
                           max_to_keep=5)
try:  # restore checkpoint including the epoch counter
    latest_checkpt = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpt is not None:
        print('Latest checkpoint found: {}. Attempting to restore...'.format(
            latest_checkpt))
    else:
        print('No checkpoint found. Beginning from scratch.')
    checkpoint.restore(latest_checkpt).assert_existing_objects_matched()
except Exception as e:
    print(e)