Example #1
import numpy as np
import tensorflow as tf

# Import paths assumed to follow the coltran package layout.
from coltran.models import colorizer
from coltran.models import upsampler
from coltran.utils import train_utils


def build(config, batch_size, is_train=False):
    """Builds model, optimizer and EMA tracker from the given config."""
    optimizer = train_utils.build_optimizer(config)

    downsample = config.get('downsample', False)
    downsample_res = config.get('downsample_res', 64)
    h, w = config.resolution

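    # Build the configured model and run one forward pass on a dummy batch of
    # zeros so that all variables are created before the EMA tracker is built.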
    if config.model.name == 'coltran_core':
        if downsample:
            h, w = downsample_res, downsample_res
        zero = tf.zeros((batch_size, h, w, 3), dtype=tf.int32)
        model = colorizer.ColTranCore(config.model)
        model(zero, training=is_train)

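    # The upsamplers condition on a channel slice: one channel at train time,
    # all three channels otherwise (e.g. when sampling).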
    c = 1 if is_train else 3
    if config.model.name == 'color_upsampler':
        if downsample:
            h, w = downsample_res, downsample_res
        zero_slice = tf.zeros((batch_size, h, w, c), dtype=tf.int32)
        zero = tf.zeros((batch_size, h, w, 3), dtype=tf.int32)
        model = upsampler.ColorUpsampler(config.model)
        model(zero, inputs_slice=zero_slice, training=is_train)
    elif config.model.name == 'spatial_upsampler':
        zero_slice = tf.zeros((batch_size, h, w, c), dtype=tf.int32)
        zero = tf.zeros((batch_size, h, w, 3), dtype=tf.int32)
        model = upsampler.SpatialUpsampler(config.model)
        model(zero, inputs_slice=zero_slice, training=is_train)

    ema_vars = model.trainable_variables
    ema = train_utils.build_ema(config, ema_vars)
    return model, optimizer, ema
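

# A minimal usage sketch (the config module below is a placeholder; the real
# configs are ml_collections.ConfigDict objects under coltran/configs):
#
#   config = colorizer_config.get_config()
#   model, optimizer, ema = build(config, batch_size=8, is_train=True)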


# The two test methods below belong to a tf.test.TestCase subclass; the class
# name and its get_config() helper are assumed from the surrounding test module.
class ColTranCoreTest(tf.test.TestCase):

    def test_transformer_attention_encoder(self):
        config = self.get_config(encoder_net='attention')
        config.stage = 'encoder_decoder'
        transformer = colorizer.ColTranCore(config=config)
        images = tf.random.uniform(shape=(2, 2, 2, 3),
                                   minval=0,
                                   maxval=256,
                                   dtype=tf.int32)
        logits = transformer(inputs=images, training=True)[0]
        self.assertEqual(logits.shape, (2, 2, 2, 1, 512))

        # batch-size=2
        gray = tf.image.rgb_to_grayscale(images)
        output = transformer.sample(gray, mode='argmax')
        output_np = output['auto_argmax'].numpy()
        proba_np = output['proba'].numpy()
        self.assertEqual(output_np.shape, (2, 2, 2, 3))
        self.assertEqual(proba_np.shape, (2, 2, 2, 512))
        # logging.info(output_np[0, ..., 0])

        # batch-size=1: sampling each example separately should reproduce the
        # batched results above.
        output_np_bs_1, proba_np_bs_1 = [], []
        for batch_ind in [0, 1]:
            curr_gray = tf.expand_dims(gray[batch_ind], axis=0)
            curr_out = transformer.sample(curr_gray, mode='argmax')
            curr_out_np = curr_out['auto_argmax'].numpy()
            curr_proba_np = curr_out['proba'].numpy()
            output_np_bs_1.append(curr_out_np)
            proba_np_bs_1.append(curr_proba_np)
        output_np_bs_1 = np.concatenate(output_np_bs_1, axis=0)
        proba_np_bs_1 = np.concatenate(proba_np_bs_1, axis=0)
        self.assertTrue(np.allclose(output_np, output_np_bs_1))
        self.assertTrue(np.allclose(proba_np, proba_np_bs_1))

    def test_transformer_encoder_decoder(self):
        config = self.get_config()
        config.stage = 'encoder_decoder'

        transformer = colorizer.ColTranCore(config=config)
        images = tf.random.uniform(shape=(1, 64, 64, 3),
                                   minval=0,
                                   maxval=256,
                                   dtype=tf.int32)
        logits, enc_logits = transformer(inputs=images, training=True)
        enc_logits = enc_logits['encoder_logits']
        self.assertEqual(enc_logits.shape, (1, 64, 64, 1, 512))
        self.assertEqual(logits.shape, (1, 64, 64, 1, 512))


def build_model(config):
    """Builds model."""
    name = config.model.name
    optimizer = train_utils.build_optimizer(config)

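    # Dummy inputs: the core colorizer and the color upsampler run at 64x64,
    # while the spatial upsampler runs at the full 256x256 resolution.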
    zero_64 = tf.zeros((1, 64, 64, 3), dtype=tf.int32)
    zero_64_slice = tf.zeros((1, 64, 64, 1), dtype=tf.int32)
    zero = tf.zeros((1, 256, 256, 3), dtype=tf.int32)
    zero_slice = tf.zeros((1, 256, 256, 1), dtype=tf.int32)

    if name == 'coltran_core':
        model = colorizer.ColTranCore(config.model)
        model(zero_64, training=False)
    elif name == 'color_upsampler':
        model = upsampler.ColorUpsampler(config.model)
        model(inputs=zero_64, inputs_slice=zero_64_slice, training=False)
    elif name == 'spatial_upsampler':
        model = upsampler.SpatialUpsampler(config.model)
        model(inputs=zero, inputs_slice=zero_slice, training=False)

    ema_vars = model.trainable_variables
    ema = train_utils.build_ema(config, ema_vars)
    return model, optimizer, ema
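

# A minimal usage sketch, assuming build_ema returns a
# tf.train.ExponentialMovingAverage; its shadow averages are updated after each
# optimizer step and can be swapped in for evaluation:
#
#   model, optimizer, ema = build_model(config)
#   ...training step...
#   ema.apply(model.trainable_variables)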