コード例 #1
0
ファイル: rnn.py プロジェクト: sleep-yearning/magenta
 def get_output(self, state):
     """Return the second half of `state` after splitting it in two along axis 1.

     The first half is discarded. Presumably `state` is a concatenated
     (c, h) recurrent state and the returned half is `h` -- TODO confirm
     against the cell that produces it.
     """
     _, hidden = tf.split(state, num_or_size_splits=2, axis=1)
     return hidden
コード例 #2
0
    def run(
        self,
        *in_arrays,
        return_as_list=False,  # True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
        print_progress=False,  # Print progress to the console? Useful for very large input arrays.
        minibatch_size=None,  # Maximum minibatch size to use, None = disable batching.
        num_gpus=1,  # Number of GPUs to use.
        out_mul=1.0,  # Multiplicative constant to apply to the output(s).
        out_add=0.0,  # Additive constant to apply to the output(s).
        out_shrink=1,  # Shrink the spatial dimensions of the output(s) by the given factor.
        out_dtype=None,  # Convert the output to the specified data type.
        **dynamic_kwargs  # Additional keyword arguments to pass into the network construction function.
    ):
        """Evaluate the network on NumPy inputs, minibatch by minibatch.

        The output expression is built once per distinct combination of
        ``dynamic_kwargs``, ``num_gpus`` and the ``out_*`` options and cached in
        ``self._run_cache``. The inputs are then sliced into minibatches along
        their leading axis, fed through ``self.input_templates`` and the results
        are copied into preallocated NumPy arrays.

        Args:
            *in_arrays: Input arrays; the leading axis of the first one defines
                the number of items to process.
            return_as_list: Return a list of arrays instead of a single array
                (one output) or a tuple (several outputs).
            print_progress: Print a ``begin / total`` progress line.
            minibatch_size: Maximum number of items per session call; ``None``
                processes everything in one call.
            num_gpus: Number of ``/gpu:N`` devices the inputs are split across.
            out_mul: Factor applied to every output (skipped when 1.0).
            out_add: Offset applied to every output (skipped when 0.0).
            out_shrink: Average-pooling factor applied to the last two axes of
                every output, which is treated as NCHW (skipped when <= 1).
            out_dtype: Target dtype; integer dtypes are rounded first and every
                conversion saturates.
            **dynamic_kwargs: Forwarded to ``self.get_output_for``.

        Returns:
            A list of NumPy arrays if ``return_as_list`` is true, otherwise the
            single output array, or a tuple when there are several outputs.
        """
        # The leading axis of the first input defines how many items there are.
        num_items = in_arrays[0].shape[0]
        print(num_items)
        if minibatch_size is None:
            minibatch_size = num_items

        # Everything that influences the constructed graph is part of the key.
        key = str([
            sorted(dynamic_kwargs.items()), num_gpus, out_mul, out_add,
            out_shrink, out_dtype
        ])

        # Build graph.
        if key not in self._run_cache:
            with absolute_name_scope(self.scope +
                                     '/Run'), tf.control_dependencies(None):
                # in_split[gpu] holds that device's slice of every input.
                in_split = list(
                    zip(*[tf.split(x, num_gpus)
                          for x in self.input_templates]))
                out_split = []
                for gpu in range(num_gpus):
                    with tf.device('/gpu:%d' % gpu):
                        out_expr = self.get_output_for(*in_split[gpu],
                                                       return_as_list=True,
                                                       **dynamic_kwargs)
                        if out_mul != 1.0:
                            out_expr = [x * out_mul for x in out_expr]
                        if out_add != 0.0:
                            out_expr = [x + out_add for x in out_expr]
                        if out_shrink > 1:
                            ksize = [1, 1, out_shrink, out_shrink]
                            out_expr = [
                                tf.nn.avg_pool(x,
                                               ksize=ksize,
                                               strides=ksize,
                                               padding='VALID',
                                               data_format='NCHW')
                                for x in out_expr
                            ]
                        if out_dtype is not None:
                            # Round before casting so integer targets are not
                            # simply truncated.
                            if tf.as_dtype(out_dtype).is_integer:
                                out_expr = [tf.round(x) for x in out_expr]
                            out_expr = [
                                tf.saturate_cast(x, out_dtype)
                                for x in out_expr
                            ]
                        out_split.append(out_expr)
                # Stitch the per-device results back together along the batch axis.
                self._run_cache[key] = [
                    tf.concat(outputs, axis=0) for outputs in zip(*out_split)
                ]

        # Run minibatches.
        out_expr = self._run_cache[key]
        out_arrays = [
            np.empty([num_items] + shape_to_list(expr.shape)[1:],
                     expr.dtype.name) for expr in out_expr
        ]

        # The device list does not change between minibatches, so report it
        # once instead of on every iteration.
        print("Num GPUs Available: ", tf.config.list_physical_devices())

        # Use one session for the whole call and close it afterwards; creating
        # an unclosed session per minibatch leaks its resources.
        # NOTE(review): this is a fresh session, not the default one, so no
        # variable state is carried over -- confirm that is intended.
        with tf.Session() as sess:
            for mb_begin in range(0, num_items, minibatch_size):
                if print_progress:
                    print('\r%d / %d' % (mb_begin, num_items), end='')
                mb_end = min(mb_begin + minibatch_size, num_items)
                mb_in = [src[mb_begin:mb_end] for src in in_arrays]
                mb_out = sess.run(out_expr,
                                  dict(zip(self.input_templates, mb_in)))
                for dst, src in zip(out_arrays, mb_out):
                    dst[mb_begin:mb_end] = src

        # Done.
        if print_progress:
            print('\r%d / %d' % (num_items, num_items))
        if not return_as_list:
            out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(
                out_arrays)
        return out_arrays
コード例 #3
0
    def build(self, rgb):
        """
        Construct the VGG layers on top of an RGB input batch.

        Every layer output is stored as an attribute of ``self``
        (``conv1_1`` ... ``pool5``, ``fc6`` ... ``fc8``, ``prob``) and
        ``self.data_dict`` is reset to ``None`` once the graph is built.
        Presumably ``conv_layer`` / ``fc_layer`` read their weights from
        ``self.data_dict`` -- confirm against their definitions.

        :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1]
        """
        began = time.time()
        print("build model started")

        # Undo the [0, 1] scaling, then reorder RGB -> BGR while subtracting
        # VGG_MEAN (indexed in B, G, R order).
        scaled = rgb * 255.0
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=scaled)
        for channel in (red, green, blue):
            assert channel.get_shape().as_list()[1:] == [224, 224, 1]
        centered = [
            blue - VGG_MEAN[0],
            green - VGG_MEAN[1],
            red - VGG_MEAN[2],
        ]
        bgr = tf.concat(axis=3, values=centered)
        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

        # Five convolutional stages, each followed by a max-pool; the tuple
        # gives the number of conv layers per stage.
        features = bgr
        for stage, num_convs in enumerate((2, 2, 3, 3, 3), start=1):
            for index in range(1, num_convs + 1):
                conv_name = "conv%d_%d" % (stage, index)
                features = self.conv_layer(features, conv_name)
                setattr(self, conv_name, features)
            pool_name = 'pool%d' % stage
            features = self.max_pool(features, pool_name)
            setattr(self, pool_name, features)

        # Fully connected head: fc6 -> relu -> fc7 -> relu -> fc8 -> softmax.
        self.fc6 = self.fc_layer(self.pool5, "fc6")
        assert self.fc6.get_shape().as_list()[1:] == [4096]
        self.relu6 = tf.nn.relu(self.fc6)

        self.fc7 = self.fc_layer(self.relu6, "fc7")
        self.relu7 = tf.nn.relu(self.fc7)

        self.fc8 = self.fc_layer(self.relu7, "fc8")
        self.prob = tf.nn.softmax(self.fc8, name="prob")

        # Drop the reference to the weight dictionary now that the graph exists.
        self.data_dict = None
        elapsed = time.time() - began
        print(("build model finished: %ds" % elapsed))
コード例 #4
0
 def objective(self, params, data=None, labels=None):
     """Evaluate (x + 2y - 7)^2 + (2x + y - 5)^2 (the Booth function).

     `params[0]` is split in two along axis 0 to obtain x and y; `data` and
     `labels` are ignored. The result is squeezed before it is returned.
     """
     x, y = tf.split(params[0], num_or_size_splits=2, axis=0)
     first_sq = (x + 2 * y - 7)**2
     second_sq = (2 * x + y - 5)**2
     return tf.squeeze(first_sq + second_sq)
コード例 #5
0
 def objective(self, params, data=None, labels=None):
     """Evaluate 0.26 * (x^2 + y^2) - 0.48 * x * y (the Matyas function).

     `params[0]` is split in two along axis 0 to obtain x and y; `data` and
     `labels` are ignored. The result is squeezed before it is returned.
     """
     x, y = tf.split(params[0], num_or_size_splits=2, axis=0)
     bowl = 0.26 * (x**2 + y**2)
     coupling = 0.48 * x * y
     return tf.squeeze(bowl - coupling)
コード例 #6
0
 def objective(self, params, data=None, labels=None):
     """Evaluate the saddle surface x^2 - y^2.

     `params[0]` is split in two along axis 0 to obtain x and y; `data` and
     `labels` are ignored. The result is squeezed before it is returned.
     """
     x, y = tf.split(params[0], num_or_size_splits=2, axis=0)
     x_sq = x**2
     y_sq = y**2
     return tf.squeeze(x_sq - y_sq)
コード例 #7
0
 def objective(self, params, data=None, labels=None):
     """Evaluate the two-dimensional Ackley function.

     Computes -20 * exp(-0.2 * sqrt(0.5 * (x^2 + y^2)))
     - exp(0.5 * (cos(2*pi*x) + cos(2*pi*y))) + e + 20, where x and y are the
     two halves of `params[0]` along axis 0. `data` and `labels` are ignored
     and the result is squeezed before it is returned.
     """
     x, y = tf.split(params[0], num_or_size_splits=2, axis=0)
     radial = -20 * tf.exp(-0.2 * tf.sqrt(0.5 * (x**2 + y**2)))
     two_pi = 2 * np.pi
     periodic = tf.exp(0.5 * (tf.cos(two_pi * x) + tf.cos(two_pi * y)))
     return tf.squeeze(radial - periodic + tf.exp(1.0) + 20.)
コード例 #8
0
def axial_mixture_unidir(x, config, is_training=True, causal=True):
    """Full attention matrix with axial pattern as local and mixture for global summary.

    The projected query/key/value tensors are viewed as a grid of shape
    [batch, num_seg, max_seg_len, num_heads, head_dim]. Five groups of logits
    are computed, concatenated along the last axis and normalized with a
    single softmax:

      1. column attention (along axis 1) under a strict causal mask,
      2. row attention (along axis 2) under a non-strict causal mask,
      3. attention to per-row summaries built along the column axis,
      4. attention to left-to-right row summaries,
      5. attention to right-to-left row summaries.

    The five weighted sums are added, flattened back to
    [batch, num_seg * max_seg_len, num_heads, head_dim] and projected with
    `ops.trail_dense`.

    Args:
      x: Input passed to `attention.get_qkv` as query, key and value source;
        its static leading dimension is used as the batch size.
      config: Object providing `model_size`, `num_heads`, `dense_use_bias`,
        `max_seq_len` and `max_seg_len`.
      is_training: Unused (deleted immediately).
      causal: Must be truthy; only the causal variant is implemented.

    Returns:
      The output of `ops.trail_dense` applied to the merged attention result.

    NOTE(review): the exact semantics of `get_causal_mask`, `approx_cummax`,
    `shift_right` and `shift_left` are defined elsewhere; comments below that
    describe them are assumptions to confirm.
    """
    del is_training
    assert causal
    bsize = x.shape[0]
    query, key, value = attention.get_qkv(x,
                                          x,
                                          x,
                                          hidden_size=config.model_size,
                                          num_heads=config.num_heads,
                                          bias=config.dense_use_bias)
    head_dim = config.model_size // config.num_heads
    assert config.max_seq_len % config.max_seg_len == 0
    num_seg = config.max_seq_len // config.max_seg_len
    # View the sequence as a [num_seg, max_seg_len] grid per batch item.
    cur_query = tf.reshape(
        query,
        [bsize, num_seg, config.max_seg_len, config.num_heads, head_dim])
    cur_key = tf.reshape(key, cur_query.shape)
    cur_val = tf.reshape(value, cur_query.shape)

    # Einsum letters: B = batch, N = heads, K = head_dim, S/T = query/key
    # position along the attended axis, U = position along the other axis.
    # Column expressions attend along axis 1 (num_seg); logits come out as
    # [B, max_seg_len, N, num_seg, num_seg].
    col_logit_expr = 'BSUNK,BTUNK->BUNST'
    col_attn_expr = 'BUNST,BTUNK->BSUNK'
    # Presumably an additive 2-D mask; it is broadcast over B, U and N.
    col_strict_mask = get_causal_mask(cur_query, axis=1,
                                      is_strict=True)[tf.newaxis, tf.newaxis,
                                                      tf.newaxis, :, :]
    # Row expressions attend along axis 2 (max_seg_len); logits come out as
    # [B, num_seg, N, max_seg_len, max_seg_len].
    row_logit_expr = 'BUSNK,BUTNK->BUNST'
    row_attn_expr = 'BUNST,BUTNK->BUSNK'
    row_mask = get_causal_mask(cur_query, axis=2,
                               is_strict=False)[tf.newaxis, tf.newaxis,
                                                tf.newaxis, :, :]
    col_logits = tf.einsum(col_logit_expr, cur_query,
                           cur_key) + col_strict_mask
    row_logits = tf.einsum(row_logit_expr, cur_query, cur_key) + row_mask

    ###################
    # Column-wise (axis 1) summaries: summary queries attend over the keys
    # along axis 1 under a non-strict mask, and the result is shifted by one
    # step along axis 1 (presumably so a row only sees summaries of earlier
    # rows -- TODO confirm shift_right semantics).
    col_up2down_query = approx_cummax(cur_query, axis=1)
    col_up2down_key = shift_right(approx_cummax(cur_key, axis=1), axis=1)
    col_mask = get_causal_mask(cur_query, axis=1,
                               is_strict=False)[tf.newaxis, tf.newaxis,
                                                tf.newaxis, :, :]
    col_up2down_logits = tf.einsum(col_logit_expr, col_up2down_query,
                                   cur_key) + col_mask
    col_up2down_attn_weights = attention.float32_softmax(col_up2down_logits,
                                                         axis=-1)
    col_up2down_summary = tf.einsum(col_attn_expr, col_up2down_attn_weights,
                                    cur_val)
    col_up2down_summary = shift_right(col_up2down_summary, axis=1)

    # Additive mask of shape [1, num_seg, 1, max_seg_len, max_seg_len]: the
    # first segment is masked out completely, every other segment only has its
    # diagonal (same position within the row) masked out.
    row_only_myself_mask = tf.eye(tf.shape(cur_query)[2],
                                  dtype=cur_query.dtype)[tf.newaxis,
                                                         tf.newaxis,
                                                         tf.newaxis, :, :]
    row_without_myself_mask = -1e9 * row_only_myself_mask
    all_maskout = tf.cast(tf.fill(row_without_myself_mask.shape, -1e9),
                          cur_query.dtype)
    row_without_myself_mask = tf.concat(
        [all_maskout] + [row_without_myself_mask] * (cur_query.shape[1] - 1),
        axis=1)
    # Queries attend along the row axis to the shifted column-summary keys.
    previous_row_logits = tf.einsum(row_logit_expr, cur_query,
                                    col_up2down_key) + row_without_myself_mask
    ###################

    # Left-to-right row (axis 2) summaries, built like the column summaries
    # above but along axis 2 and reusing the non-strict row mask.
    row_left2right_query = approx_cummax(cur_query, axis=2)
    row_left2right_key = shift_right(approx_cummax(cur_key, axis=2), axis=2)
    row_left2right_logits = tf.einsum(row_logit_expr, row_left2right_query,
                                      cur_key) + row_mask
    row_left2right_attn_weights = attention.float32_softmax(
        row_left2right_logits, axis=-1)
    row_left2right_summary = tf.einsum(row_attn_expr,
                                       row_left2right_attn_weights, cur_val)
    row_left2right_summary = shift_right(row_left2right_summary, axis=2)

    # `all_maskout` is rebound here with the shape of the column mask. The
    # resulting mask is [1, max_seg_len, 1, num_seg, num_seg]: the first
    # position within each segment is masked out completely, the remaining
    # positions use the strict column mask.
    all_maskout = tf.cast(tf.fill(col_strict_mask.shape, -1e9),
                          cur_query.dtype)
    col_strict_without_first_mask = tf.concat(
        [all_maskout] + [col_strict_mask] * (cur_query.shape[2] - 1), axis=1)
    top_left_col_logits = tf.einsum(
        col_logit_expr, cur_query,
        row_left2right_key) + col_strict_without_first_mask
    ###################

    # Right-to-left row summaries: mirror image of the block above, using the
    # reversed cummax, shift_left and the `upper=True` mask.
    row_right2left_query = approx_cummax(cur_query, axis=2, reverse=True)
    row_right2left_key = shift_left(approx_cummax(cur_key,
                                                  axis=2,
                                                  reverse=True),
                                    axis=2)
    row_upper_mask = get_causal_mask(cur_query,
                                     axis=2,
                                     is_strict=False,
                                     upper=True)[tf.newaxis, tf.newaxis,
                                                 tf.newaxis, :, :]
    row_right2left_logits = tf.einsum(row_logit_expr, row_right2left_query,
                                      cur_key) + row_upper_mask
    row_right2left_attn_weights = attention.float32_softmax(
        row_right2left_logits, axis=-1)
    row_right2left_summary = tf.einsum(row_attn_expr,
                                       row_right2left_attn_weights, cur_val)
    row_right2left_summary = shift_left(row_right2left_summary, axis=2)
    # Same construction as above, but here the last position within each
    # segment is the one that is masked out completely.
    col_strict_without_last_mask = tf.concat(
        [col_strict_mask] * (cur_query.shape[2] - 1) + [all_maskout], axis=1)
    top_right_col_logits = tf.einsum(
        col_logit_expr, cur_query,
        row_right2left_key) + col_strict_without_last_mask
    ###################

    # Column-style logits are transposed from [B, U, N, S, T] to
    # [B, S, N, U, T] so all five groups share the leading layout
    # [B, num_seg, N, max_seg_len, *] and can be concatenated on the last axis.
    joint_logits = tf.concat([
        tf.transpose(col_logits, perm=[0, 3, 2, 1, 4]), row_logits,
        previous_row_logits,
        tf.transpose(top_left_col_logits, perm=[0, 3, 2, 1, 4]),
        tf.transpose(top_right_col_logits, perm=[0, 3, 2, 1, 4])
    ],
                             axis=-1)
    # One softmax across all five groups, then split the weights back using
    # the same sizes and order as the concatenation above.
    attn_weights = attention.float32_softmax(joint_logits, axis=-1)
    col_att, row_att, previous_row_att, top_left_col_att, top_right_col_att = tf.split(
        attn_weights,
        [num_seg, config.max_seg_len, config.max_seg_len, num_seg, num_seg],
        axis=-1)
    # Undo the transpose for the column-style weights before applying them.
    col_att = tf.transpose(col_att, [0, 3, 2, 1, 4])
    top_left_col_att = tf.transpose(top_left_col_att, [0, 3, 2, 1, 4])
    top_right_col_att = tf.transpose(top_right_col_att, [0, 3, 2, 1, 4])
    col_merged = tf.einsum(col_attn_expr, col_att, cur_val)
    row_merged = tf.einsum(row_attn_expr, row_att, cur_val)
    previous_row_merged = tf.einsum(row_attn_expr, previous_row_att,
                                    col_up2down_summary)
    top_left_merged = tf.einsum(col_attn_expr, top_left_col_att,
                                row_left2right_summary)
    top_right_merged = tf.einsum(col_attn_expr, top_right_col_att,
                                 row_right2left_summary)

    # Sum the five contributions and flatten the grid back into a sequence.
    joint_merged = tf.reshape(
        col_merged + row_merged + previous_row_merged + top_left_merged +
        top_right_merged,
        [bsize, num_seg * config.max_seg_len, config.num_heads, head_dim])
    output = ops.trail_dense(joint_merged, config.model_size, begin_axis=-2)
    return output
コード例 #9
0
def sqrt_fixed_full(x, config, is_training=True, causal=True):
    """Full attention matrix with sqrt decomposition.

    The sequence is cut into `num_seg = max_seq_len // max_seg_len` segments.
    Every query position attends (a) to the keys inside its own segment and
    (b) to one pooled key/value per segment. Both groups of logits share a
    single softmax.

    Args:
      x: Input passed to `attention.get_qkv` as query, key and value source;
        its static leading dimension is used as the batch size in the later
        reshapes.
      config: Object providing `model_size`, `num_heads`, `dense_use_bias`,
        `max_seq_len`, `max_seg_len`, `local_summary` and `dropatt`.
      is_training: Forwarded to `attention.dot_product_attention`; also
        enables dropout on the local attention weights.
      causal: When truthy, causal masks are added to both groups of logits.

    Returns:
      The output of `ops.trail_dense` applied to the merged attention result.

    NOTE(review): `pooling_summary` and `get_causal_mask` are defined
    elsewhere; comments describing them below are assumptions to confirm.
    """
    bsize = x.shape[0]
    query, key, value = attention.get_qkv(x,
                                          x,
                                          x,
                                          hidden_size=config.model_size,
                                          num_heads=config.num_heads,
                                          bias=config.dense_use_bias)
    head_dim = config.model_size // config.num_heads
    assert config.max_seq_len % config.max_seg_len == 0
    num_seg = config.max_seq_len // config.max_seg_len
    # [batch, num_seg, max_seg_len, num_heads, head_dim]; the batch dimension
    # is inferred here, unlike the reshapes further down that use `bsize`.
    cur_query = tf.reshape(
        query, [-1, num_seg, config.max_seg_len, config.num_heads, head_dim])
    # Pool along the within-segment axis; keepdims=True presumably keeps a
    # singleton axis 2, which is squeezed away after the attention call below.
    with tf.variable_scope('pooling_query'):
        merged_query = pooling_summary(cur_query,
                                       axis=2,
                                       local_summary=config.local_summary,
                                       keepdims=True)
    cur_key = tf.reshape(key, cur_query.shape)
    cur_val = tf.reshape(value, cur_query.shape)
    # Per-segment summary values, obtained by attending with the pooled query.
    span_val = attention.dot_product_attention(merged_query,
                                               cur_key,
                                               cur_val,
                                               is_training=is_training,
                                               attn_axis=1,
                                               dropatt=config.dropatt)
    span_val = tf.squeeze(span_val, axis=2)
    # Per-segment summary keys (no kept axis, so no squeeze is needed).
    with tf.variable_scope('pooling_key'):
        span_key = pooling_summary(cur_key,
                                   axis=2,
                                   local_summary=config.local_summary,
                                   keepdims=False)
    # Local logits: query position q against key position k inside the same
    # segment s, per head h -> [b, s, q, h, k].
    local_logits = tf.einsum('bsqhd,bskhd->bsqhk', cur_query, cur_key)
    if causal:
        local_mask = get_causal_mask(cur_query, axis=2, is_strict=False)
        # Insert the head axis so the mask broadcasts over heads.
        local_mask = tf.expand_dims(local_mask, axis=-2)
        local_logits += local_mask
    # Global logits: every (unsegmented) query position against the pooled
    # key of each segment -> [b, q, h, k] with k indexing segments.
    prev_logits = tf.einsum('bqhd,bkhd->bqhk', query, span_key)
    if causal:
        prev_mask = get_causal_mask(cur_query, axis=1, is_strict=True)
        # Repeat each segment-level mask row once per position in the segment
        # so the mask lines up with the flattened query axis.
        prev_mask = tf.repeat(prev_mask, [config.max_seg_len] * num_seg,
                              axis=0)
        prev_logits += tf.expand_dims(prev_mask, axis=1)
    # Flatten the local logits to the same [b, seq, h, *] layout and normalize
    # both groups with one softmax.
    joint_logits = tf.concat([
        tf.reshape(local_logits,
                   [bsize, config.max_seq_len, config.num_heads, -1]),
        prev_logits
    ],
                             axis=-1)
    attn_weights = attention.float32_softmax(joint_logits, axis=-1)
    local_att, prev_att = tf.split(attn_weights, [config.max_seg_len, num_seg],
                                   axis=-1)
    # NOTE(review): dropout is applied to the local weights only; `prev_att`
    # is used as-is -- confirm this is intended.
    if is_training:
        local_att = tf.nn.dropout(local_att, rate=config.dropatt)
    local_att = tf.reshape(local_att, [
        bsize, num_seg, config.max_seg_len, config.num_heads,
        config.max_seg_len
    ])
    local_merged = tf.einsum('bsqhk,bskhd->bsqhd', local_att, cur_val)
    prev_merged = tf.einsum('bqhk,bkhd->bqhd', prev_att, span_val)
    # Bring the local result back to the flat sequence layout before summing.
    joint_merged = prev_merged + tf.reshape(local_merged, prev_merged.shape)
    output = ops.trail_dense(joint_merged, config.model_size, begin_axis=-2)
    return output
コード例 #10
0
def sampling(output):
  """Draw a reparameterized Gaussian sample from a packed (mean, scale) tensor.

  `output` is split in two along its last axis. The first half is the mean;
  the second half is passed through softplus to obtain a positive standard
  deviation. Returns the tuple `(sample, mean, stddev)` where
  `sample = mean + noise * stddev` and `noise` is standard normal.
  """
  mu, raw_scale = tf.split(output, num_or_size_splits=2, axis=-1)
  sigma = tf.nn.softplus(raw_scale)
  noise = tf.random_normal(tf.shape(mu))
  ws = mu + noise * sigma
  return ws, mu, sigma
コード例 #11
0
def main():
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    LABEL = None
    print("Loading data...")
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(num_channels=channel_num,
                                  num_filters=128,
                                  train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(num_channels=channel_num, num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_channels=channel_num, num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)

    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(num_channels=channel_num, num_filters=256)

    elif FLAGS.dataset == 'imagenetfull':
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(num_channels=channel_num, num_filters=64)

    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1
        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(num_channels=channel_num,
                         num_filters=FLAGS.num_filters)

    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(cond_shape=FLAGS.cond_shape,
                           cond_size=FLAGS.cond_size,
                           cond_pos=FLAGS.cond_pos,
                           cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(num_channels=channel_num,
                            num_filters=FLAGS.num_filters,
                            cond_size=FLAGS.cond_size,
                            cond_shape=FLAGS.cond_shape,
                            cond_pos=FLAGS.cond_pos,
                            cond_rot=FLAGS.cond_rot)

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full imagenet, use custom_tensorflow dataloader
        data_loader = TFImagenetLoader('train',
                                       FLAGS.batch_size,
                                       hvd.rank(),
                                       hvd.size(),
                                       rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 drop_last=True,
                                 shuffle=True)

    batch_size = FLAGS.batch_size

    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Varibles to run in training
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(tf.convert_to_tensor(np.reshape(
                np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                (FLAGS.batch_size * 10, 10)),
                                                            dtype=tf.float32),
                                       trainable=False,
                                       dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)),
                (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(x_split,
                                       weights[0],
                                       label=label_tensor,
                                       stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(energy_pos_full,
                                                       axis=1,
                                                       keepdims=True)
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) -
                                     energy_partition_est,
                                     axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(X_SPLIT[j],
                              weights[0],
                              label=LABEL_POS_SPLIT[j],
                              stop_at_grad=False)
            ]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        energy_negs.extend([
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True)
        ])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, x_mod):
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat([
                model.forward(x_mod,
                              weights[0],
                              label=LABEL_SPLIT[j],
                              reuse=True,
                              stop_at_grad=False,
                              stop_batch=True)
            ],
                                                    axis=0)

            x_grad, label_grad = tf.gradients(FLAGS.temperature * energy_noise,
                                              [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm,
                                              FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod,
                                    weights[0],
                                    label=LABEL_SPLIT[j],
                                    stop_at_grad=False,
                                    reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(x_mod,
                                    weights[0],
                                    reuse=True,
                                    label=LABEL,
                                    stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(
                label_prob * tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

            if FLAGS.objective == 'logsumexp':
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=30,
                                    keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        print("Training phase")
        train(target_vars, saver, sess, logger, data_loader, resume_itr,
              logdir)
    print("Testing phase")
    test(target_vars, saver, sess, logger, data_loader)
コード例 #12
0
    def _build_graph(self):
        """Builds one Q network per objective plus the action-selection op.

        `_build_input_placeholder()` is called first; the attributes
        `observations`, `head`, `state_t`, `state_tp1`, `done_mask` and
        `error_weight` used below are not assigned in this method and are
        presumably created by that call (verify against its definition).

        Placeholders created here:
          reward_t: shape = [batch_size, num_objectives].
            The reward at time step t, one column per objective.
          objective_weight_input: shape = [num_objectives, 1]
            (graph name 'objective_weight'). The weight used to scalarize
            the objective vector:
            reward = sum (objective_weight_i * objective_i)

        Instance attributes created:
          td_error: list with one TD error tensor per objective. Kept for
            summaries only.
          weighted_error: sum over objectives of
            weighted_error_i / num_objectives. This is the optimization
            goal.
          q_fn_vars: concatenation of the per-objective `q_fn_vars` returned
            by `_build_single_q_network`.
          q_tp1_vars: concatenation of the per-objective `q_tp1_vars`
            returned by `_build_single_q_network`.
          action: argmax over axis 0 of the objective-weighted Q values.
        """
        batch_size, _ = self.input_shape
        with tf.variable_scope(self.scope, reuse=self.reuse):
            self._build_input_placeholder()
            self.reward_t = tf.placeholder(
                tf.float32, (batch_size, self.num_objectives), name='reward_t')
            self.objective_weight_input = tf.placeholder(
                tf.float32, [self.num_objectives, 1], name='objective_weight')

            # One reward column per Q network.
            per_objective_rewards = tf.split(
                self.reward_t, self.num_objectives, axis=1)

            expanded_q_values = []
            self.td_error = []
            self.weighted_error = 0
            self.q_fn_vars = []
            self.q_tp1_vars = []

            for obj_idx, obj_reward in enumerate(per_objective_rewards):
                with tf.variable_scope('objective_%i' % obj_idx):
                    network = self._build_single_q_network(
                        self.observations, self.head, self.state_t,
                        self.state_tp1, self.done_mask, obj_reward,
                        self.error_weight)
                    (obj_q, obj_td_error, obj_weighted_error, obj_fn_vars,
                     obj_tp1_vars) = network

                    expanded_q_values.append(tf.expand_dims(obj_q, 1))
                    self.td_error.append(obj_td_error)
                    self.weighted_error += (
                        obj_weighted_error / self.num_objectives)
                    self.q_fn_vars.extend(obj_fn_vars)
                    self.q_tp1_vars.extend(obj_tp1_vars)

            all_q_values = tf.concat(expanded_q_values, axis=1)
            # The chosen action maximizes the weighted (scalarized) reward.
            scalarized_q = tf.matmul(all_q_values, self.objective_weight_input)
            self.action = tf.argmax(scalarized_q, axis=0)
コード例 #13
0
def affine_coupling(name,
                    x,
                    x_mask,
                    inverse,
                    split_dim,
                    identity_first,
                    init,
                    decoder_self_attention_bias=None,
                    **kwargs):
    """Applies an affine coupling transform to one part of `x`.

    `x` is split into an identity part and a transformed part. The identity
    part is fed to a transformer decoder block that predicts a location and
    an unconstrained scale; the scale is squashed with
    sigmoid(unconstrained_scale + 2.0). The forward pass computes
    (x_tr + loc) * scale and the inverse pass computes x_tr / scale - loc.

    Args:
      name: variable scope.
      x: 3-D Tensor, shape=[B, L, C].
      x_mask : 2-D Tensor, shape=[B, L].
      inverse: False for the forward pass, True for the inverse pass.
      split_dim: which dimension to split
        (time, channel_continuous, channel_alternate).
      identity_first: True means the first half remains constant. False for
        the second half.
      init: forwarded to the transformer decoder block.
      decoder_self_attention_bias: bias forwarded to the split helper.
      **kwargs: additional arguments. Must contain hparams; per the original
        documentation it also carries encoder_output and
        encoder_decoder_attention_bias.

    Returns:
      z: data transformed by the affine coupling layer. shape=[B, L, C]
      logabsdets: Log absolute determinant Jacobian (negated on the inverse
        pass). shape=[B]
    """
    hparams = kwargs["hparams"]
    batch_size, length, n_channels = common_layers.shape_list(x)
    assert 0.0 < hparams.scale_width < 1.0

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        coupling_parts = gops.split_coupling(
            x, x_mask, split_dim, identity_first, decoder_self_attention_bias)
        x_id, x_tr, _, n_transform, bias, mask = coupling_parts

        # Predict location and (unconstrained) scale from the identity part.
        theta = gops.transformer_decoder_block(
            "theta_tr",
            n_layers=hparams.n_layers_transform_params,
            x=x_id,
            x_mask=mask,
            output_size=n_transform * 2,
            init=init,
            decoder_self_attention_bias=bias,
            **kwargs)
        loc, raw_scale = tf.split(theta, 2, axis=-1)
        scale = tf.sigmoid(raw_scale + 2.0)

        if inverse:
            z_tr = x_tr / scale - loc
        else:
            z_tr = (x_tr + loc) * scale

        logabsdet = gops.reduce_sum_over_lc(tf.log(scale), mask)  # [B]
        if inverse:
            logabsdet = logabsdet * -1

        tf.summary.histogram("_loc", tf.boolean_mask(loc, mask))
        tf.summary.histogram("_scale", tf.boolean_mask(scale, mask))

        # The identity part passes through unchanged.
        joined = gops.join_coupling(x_id, z_tr, split_dim, identity_first)
        joined = tf.reshape(joined, [batch_size, length, n_channels])
        return joined, logabsdet
コード例 #14
0
def compress(args):
    """Compresses an image, or a batch of images of the same shape in npy format.

    Builds the compression graph, restores the latest checkpoint found under
    ``args.checkpoint_dir/args.runname`` and then, for every batch of the
    input, refines the latents ``y`` and ``z`` (initialised from the analysis
    transforms) for a fixed number of iterations using the gradients of the
    rate-distortion loss and the project's ``Adam`` helper. The refined
    latents are rounded, the reconstruction is evaluated, the per-image
    metrics are saved to an ``.npz`` file in ``args.results_dir`` and their
    averages are printed.

    Args:
        args: Namespace-like object. Attributes read here: ``input_file``,
            ``num_filters``, ``lmbda``, ``runname``, ``checkpoint_dir``,
            ``results_dir`` and ``verbose``. A negative ``lmbda`` is
            overwritten in place with the value parsed from ``runname``.

    Note:
        ``save_opt_record``, ``SCALES_MIN``, ``SCALES_MAX`` and
        ``SCALES_LEVELS`` are not defined in this function and must be
        available from the enclosing module.
    """
    from configs import get_eval_batch_size

    if args.input_file.endswith('.npy'):
        # .npy file should contain N images of the same shapes, in the form of an array of shape [N, H, W, 3]
        X = np.load(args.input_file)
    else:
        # Load input image and add batch dimension.
        from PIL import Image
        x = np.asarray(Image.open(args.input_file).convert('RGB'))
        X = x[None, ...]

    img_num_pixels = int(np.prod(X.shape[1:-1]))
    X = X.astype('float32')
    X /= 255.

    eval_batch_size = get_eval_batch_size(img_num_pixels)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.batch(batch_size=eval_batch_size)
    # https://www.tensorflow.org/api_docs/python/tf/compat/v1/data/Iterator
    # Importantly, each sess.run(op) call will consume a new batch, where op is any operation that depends on
    # x. Therefore the batch is fetched once per loop iteration below and then fed through a placeholder, so
    # that many ops can be evaluated on the same batch of data.
    x_next = dataset.make_one_shot_iterator().get_next()

    x_ph = x = tf.placeholder(
        'float32',
        (None, *X.shape[1:]))  # keep a reference around for feed_dict

    #### BEGIN build compression graph ####
    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters,
                                                        num_output_filters=2 *
                                                        args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Initial values for optimization
    y_init = analysis_transform(x)
    z_init = hyper_analysis_transform(y_init)

    y = tf.placeholder('float32', y_init.shape)
    y_tilde = y + tf.random.uniform(tf.shape(y), -0.5, 0.5)

    z = tf.placeholder('float32', z_init.shape)
    # sample z_tilde from q(z_tilde|x) = q(z_tilde|h_a(g_a(x))), and compute the pdf of z_tilde under the flexible prior
    # p(z_tilde) ("z_likelihoods")
    z_tilde, z_likelihoods = entropy_bottleneck(z, training=True)
    z_hat = entropy_bottleneck._quantize(
        z, 'dequantize')  # rounded (with median centering)
    mu, sigma = tf.split(hyper_synthesis_transform(z_tilde),
                         num_or_size_splits=2,
                         axis=-1)
    sigma = tf.exp(sigma)  # make positive
    # need to handle images with non-standard sizes during compression; mu/sigma must have the same shape as y
    y_shape = tf.shape(y_tilde)
    mu = mu[:, :y_shape[1], :y_shape[2], :]
    sigma = sigma[:, :y_shape[1], :y_shape[2], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     mean=mu)
    # compute the pdf of y_tilde under the conditional prior/entropy model p(y_tilde|z_tilde)
    # = N(y_tilde|mu, sigma^2) conv U(-0.5, 0.5)
    y_likelihoods = conditional_bottleneck._likelihood(
        y_tilde)  # p(\tilde y | \tilde z)
    if conditional_bottleneck.likelihood_bound > 0:
        likelihood_bound = conditional_bottleneck.likelihood_bound
        y_likelihoods = math_ops.lower_bound(y_likelihoods, likelihood_bound)
    y_hat = conditional_bottleneck._quantize(
        y, 'dequantize')  # rounded (with mean centering)

    x_tilde = synthesis_transform(y_tilde)
    x_shape = tf.shape(x)
    x_tilde = x_tilde[:, :x_shape[1], :x_shape[
        2], :]  # crop reconstruction to have the same shape as input

    # Total number of bits divided by number of pixels.
    # - log p(\tilde y | \tilde z) - log p(\tilde z) - - log q(\tilde z | \tilde y)
    axes_except_batch = list(range(1, len(x.shape)))  # should be [1,2,3]
    y_bpp = tf.reduce_sum(-tf.log(y_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    z_bpp = tf.reduce_sum(-tf.log(z_likelihoods), axis=axes_except_batch) / (
        np.log(2) * img_num_pixels)
    eval_bpp = y_bpp + z_bpp  # shape (N,)
    train_bpp = tf.reduce_mean(eval_bpp)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    train_mse *= 255**2

    # The rate-distortion cost.
    if args.lmbda < 0:
        args.lmbda = float(args.runname.split('lmbda=')[1].split('-')
                           [0])  # re-use the lmbda as used for training
        print(
            'Defaulting lmbda (mse coefficient) to %g as used in model training.'
            % args.lmbda)
    if args.lmbda > 0:
        rd_loss = args.lmbda * train_mse + train_bpp
    else:
        rd_loss = train_bpp
    rd_gradients = tf.gradients(rd_loss, [y, z])

    # Bring both images back to 0..255 range, for evaluation only.
    # `x` is rebound to a new tensor here; `x_ph` still names the placeholder.
    x *= 255
    x_tilde = tf.clip_by_value(x_tilde, 0, 1)
    x_tilde = tf.round(x_tilde * 255)

    mse = tf.reduce_mean(tf.squared_difference(x, x_tilde),
                         axis=axes_except_batch)  # shape (N,)
    psnr = tf.image.psnr(x_tilde, x, 255)  # shape (N,)
    msssim = tf.image.ssim_multiscale(x_tilde, x, 255)  # shape (N,)
    msssim_db = -10 * tf.log(1 - msssim) / np.log(10)  # shape (N,)

    with tf.Session() as sess:
        # Load the latest model checkpoint, get compression stats
        save_dir = os.path.join(args.checkpoint_dir, args.runname)
        latest = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
        tf.train.Saver().restore(sess, save_path=latest)
        eval_fields = [
            'mse', 'psnr', 'msssim', 'msssim_db', 'est_bpp', 'est_y_bpp',
            'est_z_bpp'
        ]
        eval_tensors = [mse, psnr, msssim, msssim_db, eval_bpp, y_bpp, z_bpp]
        # Per-image results, appended across all batches.
        all_results_arrs = {key: [] for key in eval_fields}

        # Log more often when the optimization record is going to be saved.
        log_itv = 10 if save_opt_record else 100
        rd_lr = 0.005
        rd_opt_its = 2000
        from adam import Adam

        # Rebuilt for every batch, so only the record of the last processed
        # batch is saved below. Stays None when the input yields no batches.
        opt_record = None
        while True:
            # Only the iterator fetch is expected to signal exhaustion; the
            # try body is kept minimal so that an OutOfRangeError raised by
            # any other step is not silently treated as "end of data".
            try:
                x_val = sess.run(x_next)
            except tf.errors.OutOfRangeError:
                break
            x_feed_dict = {x_ph: x_val}

            # 1. Perform R-D optimization conditioned on ground truth x
            print('----RD Optimization----')
            y_cur, z_cur = sess.run([y_init, z_init],
                                    feed_dict=x_feed_dict)  # np arrays
            adam_optimizer = Adam(lr=rd_lr)
            opt_record = {
                'its': [],
                'rd_loss': [],
                'rd_loss_after_rounding': []
            }
            for it in range(rd_opt_its):
                grads, obj, mse_, train_bpp_, psnr_ = sess.run(
                    [rd_gradients, rd_loss, train_mse, train_bpp, psnr],
                    feed_dict={
                        y: y_cur,
                        z: z_cur,
                        **x_feed_dict
                    })
                y_cur, z_cur = adam_optimizer.update([y_cur, z_cur], grads)
                if it % log_itv == 0 or it + 1 == rd_opt_its:
                    psnr_ = psnr_.mean()
                    if args.verbose:
                        # Also report the loss obtained with the rounded
                        # latents fed in place of the noisy ones.
                        y_hat_, z_hat_ = sess.run([y_hat, z_hat],
                                                  feed_dict={
                                                      y: y_cur,
                                                      z: z_cur
                                                  })
                        bpp_after_rounding, psnr_after_rounding, rd_loss_after_rounding = sess.run(
                            [train_bpp, psnr, rd_loss],
                            feed_dict={
                                y_tilde: y_hat_,
                                z_tilde: z_hat_,
                                **x_feed_dict
                            })
                        psnr_after_rounding = psnr_after_rounding.mean()
                        print(
                            'it=%d, rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f\t after rounding: rd_loss=%.4f, bpp=%.4f psnr=%.4f'
                            % (it, obj, mse_, train_bpp_, psnr_,
                               rd_loss_after_rounding, bpp_after_rounding,
                               psnr_after_rounding))
                        opt_record['rd_loss_after_rounding'].append(
                            rd_loss_after_rounding)
                    else:
                        print(
                            'it=%d, rd_loss=%.4f mse=%.3f bpp=%.4f psnr=%.4f'
                            % (it, obj, mse_, train_bpp_, psnr_))
                    opt_record['its'].append(it)
                    opt_record['rd_loss'].append(obj)

            print()

            # this is the latents we end up transmitting
            y_hat_, z_hat_ = sess.run([y_hat, z_hat],
                                      feed_dict={
                                          y: y_cur,
                                          z: z_cur
                                      })

            # Transform the quantized latents back and measure performance.
            eval_arrs = sess.run(eval_tensors,
                                 feed_dict={
                                     y_tilde: y_hat_,
                                     z_tilde: z_hat_,
                                     **x_feed_dict
                                 })
            for field, arr in zip(eval_fields, eval_arrs):
                all_results_arrs[field] += arr.tolist()

        for field in eval_fields:
            all_results_arrs[field] = np.asarray(all_results_arrs[field])

        input_file = os.path.basename(args.input_file)
        results_dict = all_results_arrs
        trained_script_name = args.runname.split('-')[0]
        script_name = os.path.splitext(os.path.basename(__file__))[
            0]  # current script name, without extension

        # save RD evaluation results
        prefix = 'rd'
        save_file = '%s-%s-input=%s.npz' % (prefix, args.runname, input_file)
        if script_name != trained_script_name:
            save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                prefix, script_name, args.lmbda, args.runname, input_file)
        np.savez(os.path.join(args.results_dir, save_file), **results_dict)

        if save_opt_record and opt_record is not None:
            # save optimization record
            prefix = 'opt'
            save_file = '%s-%s-input=%s.npz' % (prefix, args.runname,
                                                input_file)
            if script_name != trained_script_name:
                save_file = '%s-%s-lmbda=%g+%s-input=%s.npz' % (
                    prefix, script_name, args.lmbda, args.runname, input_file)
            np.savez(os.path.join(args.results_dir, save_file), **opt_record)

        for field in eval_fields:
            arr = all_results_arrs[field]
            print('Avg {}: {:0.4f}'.format(field, arr.mean()))
コード例 #15
0
ファイル: lib_graph.py プロジェクト: mishaluczkiw/music_embed
    def apply_convolution(self, x, layer, layer_idx):
        """Applies the convolution described by `layer`, then batch norm or a bias.

        A regular `tf.nn.conv2d` is used when `hparams.use_sep_conv` is off
        or when `layer_idx` is below
        `hparams.num_initial_regular_conv_layers`; otherwise a separable
        convolution is used, optionally followed by per-split dense
        (pointwise) layers.

        Args:
          x: Input tensor to convolve.
          layer: Dict describing the layer. Keys read here: 'filters' (the
            filter shape; the layer is skipped when absent), and optionally
            'dilation_rate', 'conv_stride', 'conv_pad' and
            'num_pointwise_splits'.
          layer_idx: Index of the layer; selects regular vs. separable
            convolution and is used in the names of the split dense layers.

        Returns:
          `x` unchanged when `layer` has no 'filters' entry; otherwise the
          convolution output after batch normalization (when
          `hparams.batch_norm` is set) or after adding a bias.
        """
        if 'filters' not in layer:
            return x

        filter_shape = layer['filters']
        # Instantiate or retrieve filter weights.
        fanin = tf.to_float(tf.reduce_prod(filter_shape[:-1]))
        stddev = tf.sqrt(tf.div(2.0, fanin))
        initializer = tf.random_normal_initializer(0.0, stddev)
        regular_convs = (
            not self.hparams.use_sep_conv
            or layer_idx < self.hparams.num_initial_regular_conv_layers)
        if regular_convs:
            dilation_rates = layer.get('dilation_rate', 1)
            if isinstance(dilation_rates, int):
                dilation_rates = [dilation_rates] * 2
            weights = tf.get_variable(
                'weights',
                filter_shape,
                initializer=initializer if self.is_training else None)
            stride = layer.get('conv_stride', 1)
            conv = tf.nn.conv2d(
                x,
                weights,
                strides=[1, stride, stride, 1],
                padding=layer.get('conv_pad', 'SAME'),
                # list() lets a tuple-valued 'dilation_rate' be concatenated
                # with the surrounding lists.
                dilations=[1] + list(dilation_rates) + [1])
        else:
            num_outputs = filter_shape[-1]
            num_splits = layer.get('num_pointwise_splits', 1)
            tf.logging.info('num_splits %d', num_splits)
            if num_splits > 1:
                # The pointwise step is done per split by the dense layers
                # below, so no output count is passed to the separable conv.
                num_outputs = None
            conv = tf_contrib.layers.separable_conv2d(
                x,
                num_outputs,
                filter_shape[:2],
                depth_multiplier=self.hparams.sep_conv_depth_multiplier,
                stride=layer.get('conv_stride', 1),
                padding=layer.get('conv_pad', 'SAME'),
                rate=layer.get('dilation_rate', 1),
                activation_fn=None,
                weights_initializer=initializer if self.is_training else None)
            if num_splits > 1:
                splits = tf.split(conv, num_splits, -1)
                print(len(splits), splits[0].shape)
                # TODO(annahuang): support non equal splits.
                # Integer division: the number of dense units must be an
                # int, and `/` yields a float under Python 3.
                units_per_split = filter_shape[3] // num_splits
                pointwise_splits = [
                    tf.layers.dense(split,
                                    units_per_split,
                                    name='split_%d_%d' % (layer_idx, i))
                    for i, split in enumerate(splits)
                ]
                conv = tf.concat((pointwise_splits), axis=-1)

        # Compute batch normalization or add biases.
        if self.hparams.batch_norm:
            y = self.apply_batchnorm(conv)
        else:
            biases = tf.get_variable('bias', [conv.get_shape()[-1]],
                                     initializer=tf.constant_initializer(0.0))
            y = tf.nn.bias_add(conv, biases)
        return y
コード例 #16
0
def axial_rowmajor(x, config, is_training=True, causal=True):
    """Full attention matrix with sqrt decomposition.

    The sequence is reshaped into `num_seg` segments of `config.max_seg_len`
    positions. Each query attends jointly (single softmax) over the keys of
    its own segment and over one summary key per segment; the two resulting
    parts are merged and projected back to `config.model_size`.

    Args:
      x: Input tensor; it is reshaped to
        [bsize, num_seg, max_seg_len, model_size], so its second dimension
        must be statically known and divisible by `config.max_seg_len`.
      config: Object providing model_size, num_heads, max_seg_len,
        dense_use_bias, dropatt and row_summary ('none', 'wsum', 'proj' or
        'wsum_proj').
      is_training: When true, dropout is applied to attention probabilities.
      causal: When true, causal masks are added to the local logits
        (non-strict) and to the global logits (strict).

    Returns:
      The output of the final dense projection.

    Raises:
      ValueError: If `config.row_summary` is not one of the known values.
    """
    bsize = x.shape[0]
    seq_len = x.shape.as_list()[1]
    head_dim = config.model_size // config.num_heads
    seg_len = config.max_seg_len
    assert seq_len % seg_len == 0
    num_seg = seq_len // seg_len

    x_sqr = tf.reshape(x, [bsize, num_seg, seg_len, config.model_size])
    q_local, k_local, v_local = attention.get_qkv(
        x_sqr,
        x_sqr,
        x_sqr,
        hidden_size=config.model_size,
        num_heads=config.num_heads,
        bias=config.dense_use_bias)

    # Unmasked within-segment attention, used to build the segment summaries.
    local_logits = tf.einsum('bsqhd,bskhd->bsqhk', q_local, k_local)
    row_probs = attention.float32_softmax(local_logits, axis=-1)
    if is_training:
        row_probs = tf.nn.dropout(row_probs, rate=config.dropatt)
    row_attn_out = tf.einsum('bsqhk,bskhd->bsqhd', row_probs, v_local)

    # Choose the keys that represent each segment for the global attention.
    summary_kind = config.row_summary
    if summary_kind == 'none':
        key_row = k_local
    elif summary_kind in ('wsum', 'proj', 'wsum_proj'):
        if 'wsum' in summary_kind:
            pre_summary = tf.einsum('bsqhk,bskhd->bsqhd', row_probs, k_local)
        else:
            pre_summary = row_attn_out
        if 'proj' not in summary_kind:
            key_row = pre_summary
        else:
            with tf.variable_scope('rowmajor_param_post'):
                key_row = ops.trail_dense(pre_summary,
                                          config.model_size,
                                          begin_axis=-2,
                                          bias=config.dense_use_bias)
                key_row = ops.postprocess(x_sqr, key_row, config, is_training)
                _, key_row = ops.preprocess(key_row, config)
                key_row = ops.trail_dense(key_row,
                                          [config.num_heads, head_dim],
                                          bias=config.dense_use_bias)
    else:
        raise ValueError('Unknown row summary %s' % summary_kind)

    if causal:
        local_mask = get_causal_mask(q_local, axis=2, is_strict=False)
        local_logits = local_logits + local_mask[:, tf.newaxis, :]

    global_logits = tf.einsum('bqlhd,bklhd->bqlhk', q_local, key_row)
    if causal:
        global_mask = get_causal_mask(q_local, axis=1, is_strict=True)
        global_logits = (
            global_logits + global_mask[:, tf.newaxis, tf.newaxis, :])

    # (bsize, num_seg, seg_len, n_head, seg_len + num_seg)
    joint_logits = tf.concat([local_logits, global_logits], axis=-1)
    joint_probs = attention.float32_softmax(joint_logits, axis=-1)
    local_att, global_att = tf.split(joint_probs, [seg_len, num_seg], axis=-1)
    if is_training:
        local_att = tf.nn.dropout(local_att, rate=config.dropatt)

    local_merged = tf.einsum('bsqhk,bskhd->bsqhd', local_att, v_local)
    global_merged = tf.einsum('bqlhv,bvlhd->bqlhd', global_att, row_attn_out)
    merged = tf.reshape(local_merged + global_merged,
                        [bsize, seq_len, config.num_heads, head_dim])
    return ops.trail_dense(merged,
                           config.model_size,
                           begin_axis=-2,
                           bias=config.dense_use_bias)
コード例 #17
0
 def objective(self, params, data=None, labels=None):
     """Evaluates (1 - x)**2 + 100 * (y - x**2)**2 on the halves of params[0].

     `params[0]` is split into two equal parts along axis 0, giving `x` and
     `y`. `data` and `labels` are accepted but not used. The result is
     returned with size-1 dimensions squeezed out.
     """
     x_part, y_part = tf.split(params[0], 2, axis=0)
     distance_term = (1 - x_part)**2
     curvature_term = 100 * (y_part - x_part**2)**2
     return tf.squeeze(distance_term + curvature_term)
コード例 #18
0
def axial_mixture_bidir(x, config, is_training=True, causal=False):
    """Full attention matrix with axial mixture decomposition.

    The sequence axis of `x` is folded into a `[num_seg, max_seg_len]` grid.
    Six groups of logits (two "local" groups plus four directional "global"
    groups built from `approx_cummax` keys) are concatenated, normalised with
    a single softmax, split back into their groups and used to mix the
    corresponding values.

    Args:
      x: Input tensor. It is reshaped to
        `[bsize, num_seg, config.max_seg_len, config.model_size]`, so it is
        presumably `[bsize, seq_len, model_size]` -- TODO confirm.
      config: Object providing `model_size`, `num_heads`, `max_seg_len`,
        `dense_use_bias` and `dropatt`.
      is_training: Forwarded to `attention.dot_product_attention`.
      causal: Must be falsy; this bidirectional variant asserts on it.

    Returns:
      The output of `ops.trail_dense` applied to the merged heads with output
      size `config.model_size`.

    Raises:
      AssertionError: If `causal` is truthy, or if the sequence length is not
        a multiple of `config.max_seg_len`.
    """
    assert not causal
    bsize = x.shape[0]
    seq_len = x.shape.as_list()[1]
    head_dim = config.model_size // config.num_heads
    assert seq_len % config.max_seg_len == 0
    num_seg = seq_len // config.max_seg_len
    # Fold the sequence axis into a (num_seg, max_seg_len) grid.
    x_sqr = tf.reshape(x,
                       [bsize, num_seg, config.max_seg_len, config.model_size])
    # The einsum specs below index query/key/value with five axes; presumably
    # they are [bsize, num_seg, max_seg_len, num_heads, head_dim] -- confirm
    # against attention.get_qkv.
    query, key, value = attention.get_qkv(x_sqr,
                                          x_sqr,
                                          x_sqr,
                                          hidden_size=config.model_size,
                                          num_heads=config.num_heads,
                                          bias=config.dense_use_bias)
    # "Row" logits pair positions that share axis 1 and differ along axis 2;
    # "col" logits pair positions that share axis 2 and differ along axis 1.
    local_row_logits = tf.einsum('bushd,buthd->bhust', query, key)
    local_col_logits = tf.einsum('bsuhd,btuhd->bhsut', query, key)
    # TODO: add self-mask for local_col_logits

    # Shared attention call; only the query, attention axis and bias vary
    # between the four directional summaries below.
    span_attn_fn = functools.partial(attention.dot_product_attention,
                                     key_heads=key,
                                     value_heads=value,
                                     is_training=is_training,
                                     dropatt=config.dropatt)

    # === top-down summary ===
    # Exclusive cumulative max along axis 1, paired with a strict causal mask
    # on the same axis.
    col_query_topdown = approx_cummax(query, 1, exclusive=True)
    col_key_topdown = approx_cummax(key, 1, exclusive=True)
    col_t2d_mask = get_causal_mask(x_sqr, axis=1, is_strict=True)
    col_t2d_val = span_attn_fn(query_heads=col_query_topdown,
                               attn_axis=0,
                               attn_bias=col_t2d_mask)

    # === bottom-up summary ===
    # Same as top-down but accumulated in reverse with an upper mask.
    col_query_bottomup = approx_cummax(query, 1, exclusive=True, reverse=True)
    col_key_bottomup = approx_cummax(key, 1, exclusive=True, reverse=True)
    col_b2t_mask = get_causal_mask(x_sqr, axis=1, is_strict=True, upper=True)
    col_b2t_val = span_attn_fn(query_heads=col_query_bottomup,
                               attn_axis=0,
                               attn_bias=col_b2t_mask)

    # === left2right summary ===
    # Exclusive cumulative max along axis 2 with a strict causal mask.
    row_query_left2right = approx_cummax(query, 2, exclusive=True)
    row_key_left2right = approx_cummax(key, 2, exclusive=True)
    row_l2r_mask = get_causal_mask(x_sqr, axis=2, is_strict=True)
    row_l2r_val = span_attn_fn(query_heads=row_query_left2right,
                               attn_axis=1,
                               attn_bias=row_l2r_mask)

    # === right2left summary ===
    # Same as left2right but accumulated in reverse with an upper mask.
    row_query_right2left = approx_cummax(query,
                                         2,
                                         exclusive=True,
                                         reverse=True)
    row_key_right2left = approx_cummax(key, 2, exclusive=True, reverse=True)
    row_r2l_mask = get_causal_mask(x_sqr, axis=2, is_strict=True, upper=True)
    row_r2l_val = span_attn_fn(query_heads=row_query_right2left,
                               attn_axis=1,
                               attn_bias=row_r2l_mask)

    # Logits of the raw queries against the four summarised key tensors.
    global_t2d_logits = tf.einsum('bushd,buthd->bhust', query, col_key_topdown)
    global_b2t_logits = tf.einsum('bushd,buthd->bhust', query,
                                  col_key_bottomup)
    global_l2r_logits = tf.einsum('bsuhd,btuhd->bhsut', query,
                                  row_key_left2right)
    global_r2l_logits = tf.einsum('bsuhd,btuhd->bhsut', query,
                                  row_key_right2left)
    # One softmax over all six groups, concatenated along the last axis.
    joint_logits = tf.concat([
        local_row_logits, local_col_logits, global_t2d_logits,
        global_b2t_logits, global_l2r_logits, global_r2l_logits
    ],
                             axis=-1)
    attn_probs = attention.float32_softmax(joint_logits, axis=-1)
    # Split sizes follow the concat order above.
    prow, pcol, pt2d, pb2t, pl2r, pr2l = tf.split(attn_probs, [
        config.max_seg_len, num_seg, config.max_seg_len, config.max_seg_len,
        num_seg, num_seg
    ],
                                                  axis=-1)
    # Mix values with the same index pattern that produced each logit group.
    mrow = tf.einsum('bhust,buthd->bushd', prow, value)
    mcol = tf.einsum('bhsut,btuhd->bsuhd', pcol, value)
    mt2d = tf.einsum('bhust,buthd->bushd', pt2d, col_t2d_val)
    mb2t = tf.einsum('bhust,buthd->bushd', pb2t, col_b2t_val)
    ml2r = tf.einsum('bhsut,btuhd->bsuhd', pl2r, row_l2r_val)
    mr2l = tf.einsum('bhsut,btuhd->bsuhd', pr2l, row_r2l_val)
    joint_merged = mrow + mcol + mt2d + mb2t + ml2r + mr2l
    # Unfold the grid back into a flat sequence axis before the projection.
    joint_merged = tf.reshape(joint_merged,
                              [bsize, seq_len, config.num_heads, head_dim])
    output = ops.trail_dense(joint_merged,
                             config.model_size,
                             begin_axis=-2,
                             bias=config.dense_use_bias)
    return output
コード例 #19
0
 def objective(self, params, data=None, labels=None):
     """Evaluate a log-of-sum-of-exponentials objective on `params[0]`.

     `params[0]` is split into two equal parts `x` and `y` along axis 0 and
     the result is
     `log(exp(x + 3y - 0.1) + exp(x - 3y - 0.1) + exp(-x - 0.1) + 1)`.
     `data` and `labels` are not used.

     Returns:
       The objective tensor with size-1 dimensions squeezed away.
     """
     u, v = tf.split(params[0], 2, axis=0)
     exp_up = tf.exp(u + 3. * v - 0.1)
     exp_down = tf.exp(u - 3. * v - 0.1)
     exp_back = tf.exp(-u - 0.1)
     total = exp_up + exp_down + exp_back + 1.0
     return tf.squeeze(tf.log(total))
コード例 #20
0
    def build(self):
        """Build the sequence-tagging graph and its training op.

        Steps, in order: create the input/target placeholders, embed the
        inputs, run a static bidirectional (or unidirectional) RNN built from
        `self.cell`, optionally apply a self-attention layer, project to
        `self.num_classes` logits, and define either a masked softmax
        cross-entropy loss or a CRF loss, followed by summaries and the
        optimizer step.

        Reads configuration from `self` (`embedding`, `cell`,
        `max_time_steps`, `emb_dim`, `hidden_dim`, `attention_dim`,
        `batch_size`, `num_classes`, `num_layers`, `dropout_rate`,
        `initializer`, `optimizer`, `global_step` and the flags
        `biderectional`, `is_training`, `is_attention`, `is_crf`).

        Sets on `self`: `inputs`, `targets`, `inputs_emb`, `length`,
        `outputs`, `softmax_w`, `softmax_b`, `logits`, `batch_pred_sequence`,
        `loss`, `train_summary`, `dev_summary`, `opt_op`, plus `losses` in
        softmax mode or `log_likelihood`, `transition_params` and
        `batch_viterbi_score` in CRF mode.
        """
        self.inputs = tf.placeholder(tf.int32, [None, self.max_time_steps])
        self.targets = tf.placeholder(tf.int32, [None, self.max_time_steps])

        # Embed the ids, then turn [batch, time, emb] into a list of
        # `max_time_steps` tensors of shape [batch, emb] for the static RNN.
        self.inputs_emb = tf.nn.embedding_lookup(self.embedding, self.inputs)
        self.inputs_emb = tf.transpose(self.inputs_emb, [1, 0, 2])
        self.inputs_emb = tf.reshape(self.inputs_emb, [-1, self.emb_dim])
        self.inputs_emb = tf.split(self.inputs_emb, self.max_time_steps, 0)

        # lstm cell
        # NOTE(review): the same `self.cell` object is reused for both
        # directions and for every layer of the MultiRNNCell; confirm that
        # this sharing is intended.
        if self.biderectional:
            lstm_cell_fw = self.cell
            lstm_cell_bw = self.cell

            # dropout
            if self.is_training:
                lstm_cell_fw = tf.nn.rnn_cell.DropoutWrapper(
                    lstm_cell_fw, output_keep_prob=(1 - self.dropout_rate))
                lstm_cell_bw = tf.nn.rnn_cell.DropoutWrapper(
                    lstm_cell_bw, output_keep_prob=(1 - self.dropout_rate))

            lstm_cell_fw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_fw] *
                                                       self.num_layers)
            lstm_cell_bw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_bw] *
                                                       self.num_layers)

            # Length of each sample: sum of sign(id) over time, i.e. the
            # number of positive ids (presumably id 0 is padding -- confirm).
            self.length = tf.reduce_sum(tf.sign(self.inputs),
                                        reduction_indices=1)
            self.length = tf.cast(self.length, tf.int32)

            # forward and backward
            outputs, _, _ = tf2.compat.v1.nn.static_bidirectional_rnn(
                lstm_cell_fw,
                lstm_cell_bw,
                self.inputs_emb,
                dtype=tf.float32,
                sequence_length=self.length)

        else:
            lstm_cell = self.cell
            if self.is_training:
                lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
                    lstm_cell, output_keep_prob=(1 - self.dropout_rate))
            lstm_cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] *
                                                    self.num_layers)
            self.length = tf.reduce_sum(tf.sign(self.inputs),
                                        reduction_indices=1)
            self.length = tf.cast(self.length, tf.int32)

            outputs, _ = tf.contrib.rnn.static_rnn(lstm_cell,
                                                   self.inputs_emb,
                                                   dtype=tf.float32,
                                                   sequence_length=self.length)
        # outputs: list_steps[batch, 2*dim]
        # NOTE(review): everything below assumes each step output is
        # `2 * hidden_dim` wide, which matches the bidirectional branch;
        # confirm the cell size used with the unidirectional branch.
        outputs = tf.concat(outputs, 1)
        outputs = tf.reshape(
            outputs,
            [self.batch_size, self.max_time_steps, self.hidden_dim * 2])

        # self attention module
        if self.is_attention:
            H1 = tf.reshape(outputs, [-1, self.hidden_dim * 2])
            W_a1 = tf.get_variable(
                "W_a1",
                shape=[self.hidden_dim * 2, self.attention_dim],
                initializer=self.initializer,
                trainable=True)
            u1 = tf.matmul(H1, W_a1)

            H2 = tf.reshape(tf.identity(outputs), [-1, self.hidden_dim * 2])
            W_a2 = tf.get_variable(
                "W_a2",
                shape=[self.hidden_dim * 2, self.attention_dim],
                initializer=self.initializer,
                trainable=True)
            u2 = tf.matmul(H2, W_a2)

            # The projections are `attention_dim` wide, so restore the
            # [batch, time, attention_dim] layout (the previous reshape to
            # `hidden_dim * 2` only worked when the two sizes were equal).
            u1 = tf.reshape(
                u1,
                [self.batch_size, self.max_time_steps, self.attention_dim])
            u2 = tf.reshape(
                u2,
                [self.batch_size, self.max_time_steps, self.attention_dim])
            # Pairwise scores between time steps: [batch, time, time].
            u = tf.matmul(u1, u2, transpose_b=True)

            # Array of weights for each time step
            A = tf.nn.softmax(u, name="attention")
            outputs = tf.matmul(
                A,
                tf.reshape(tf.identity(outputs), [
                    self.batch_size, self.max_time_steps, self.hidden_dim * 2
                ]))

        # linear
        self.outputs = tf.reshape(outputs, [-1, self.hidden_dim * 2])
        self.softmax_w = tf.get_variable(
            "softmax_w", [self.hidden_dim * 2, self.num_classes],
            initializer=self.initializer)
        self.softmax_b = tf.get_variable("softmax_b", [self.num_classes],
                                         initializer=self.initializer)
        self.logits = tf.matmul(self.outputs, self.softmax_w) + self.softmax_b

        self.logits = tf.reshape(
            self.logits,
            [self.batch_size, self.max_time_steps, self.num_classes])
        print(self.logits.get_shape().as_list())
        if not self.is_crf:
            # softmax
            softmax_out = tf.nn.softmax(self.logits, axis=-1)

            self.batch_pred_sequence = tf.cast(tf.argmax(softmax_out, -1),
                                               tf.int32)
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.logits, labels=self.targets)
            # Pin the mask width to the time axis of `losses`; without
            # `maxlen` the mask is only as wide as the longest sample in the
            # batch and would no longer line up with `losses`.
            mask = tf.sequence_mask(self.length, maxlen=self.max_time_steps)

            self.losses = tf.boolean_mask(losses, mask)

            # Average over real (unpadded) positions only.
            self.loss = tf.reduce_mean(self.losses)
        else:
            # crf
            self.log_likelihood, self.transition_params = (
                tfa.text.crf_log_likelihood(self.logits, self.targets,
                                            self.length))

            # Viterbi-decode with the learned transition matrix. This used to
            # call crf_log_likelihood again with the transition matrix in the
            # tag-indices slot, which does not produce predictions.
            self.batch_pred_sequence, self.batch_viterbi_score = (
                tfa.text.crf_decode(self.logits, self.transition_params,
                                    self.length))

            self.loss = tf.reduce_mean(-self.log_likelihood)

        self.train_summary = tf.summary.scalar("loss", self.loss)
        self.dev_summary = tf.summary.scalar("loss", self.loss)

        self.opt_op = self.optimizer.minimize(self.loss,
                                              global_step=self.global_step)
コード例 #21
0
 def objective(self, params, data=None, labels=None):
     """Evaluate the two-variable Beale function on `params[0]`.

     `params[0]` is split into two equal parts `x` and `y` along axis 0 and
     the result is the sum of three squared residuals
     `c_k - x + x * y^k` for `(c_k, k)` in `(1.5, 1), (2.25, 2), (2.625, 3)`.
     `data` and `labels` are not used.

     Returns:
       The objective tensor with size-1 dimensions squeezed away.
     """
     u, v = tf.split(params[0], 2, axis=0)
     residual_1 = 1.5 - u + u * v
     residual_2 = 2.25 - u + u * v**2
     residual_3 = 2.625 - u + u * v**3
     total = residual_1**2 + residual_2**2 + residual_3**2
     return tf.squeeze(total)
コード例 #22
0
ファイル: model_utils.py プロジェクト: yyht/language
 def split_input(inp, out):
   """Split `inp` along its last axis into two pieces.

   The first piece is as wide as the last dimension of `out`; the second
   piece holds the remaining channels of `inp`.

   Returns:
     The list of two tensors produced by `tf.split`.
   """
   head_width = out.get_shape().as_list()[-1]
   tail_width = inp.get_shape().as_list()[-1] - head_width
   return tf.split(inp, [head_width, tail_width], axis=-1)
コード例 #23
0
 def objective(self, params, data=None, labels=None):
     """Evaluate a shifted quartic polynomial objective on `params[0]`.

     `params[0]` is split into two equal parts along axis 0; each part `c`
     contributes `c^4 - 16 * c^2 + 5 * c`. The contributions are summed,
     halved and shifted by `+80`. `data` and `labels` are not used.

     Returns:
       The objective tensor with size-1 dimensions squeezed away.
     """
     halves = tf.split(params[0], 2, axis=0)
     contributions = []
     for coord in halves:
         contributions.append(coord**4 - 16 * coord**2 + 5 * coord)
     summed = tf.reduce_sum(contributions, 0)
     return tf.squeeze(0.5 * summed + 80.)
コード例 #24
0
def position_sensitive_crop_regions(image,
                                    boxes,
                                    crop_size,
                                    num_spatial_bins,
                                    global_pool):
  """Position-sensitive crop and pool rectangular regions from a feature grid.

  The output crops are split into `spatial_bins_y` vertical bins
  and `spatial_bins_x` horizontal bins. For each intersection of a vertical
  and a horizontal bin the output values are gathered by performing
  `tf.image.crop_and_resize` (bilinear resampling) on a separate subset of
  channels of the image. This reduces `depth` by a factor of
  `(spatial_bins_y * spatial_bins_x)`.

  When global_pool is True, this function implements a differentiable version
  of position-sensitive RoI pooling used in
  [R-FCN detection system](https://arxiv.org/abs/1605.06409).

  When global_pool is False, this function implements a differentiable version
  of position-sensitive assembling operation used in
  [instance FCN](https://arxiv.org/abs/1603.08678).

  Args:
    image: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
      `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
      A 3-D tensor of shape `[image_height, image_width, depth]`.
      Both `image_height` and `image_width` need to be positive.
    boxes: A `Tensor` of type `float32`.
      A 2-D tensor of shape `[num_boxes, 4]`. Each box is specified in
      normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value
      of `y` is mapped to the image coordinate at `y * (image_height - 1)`, so
      as the `[0, 1]` interval of normalized image height is mapped to
      `[0, image_height - 1]` in image height coordinates. We do allow y1 > y2,
      in which case the sampled crop is an up-down flipped version of the
      original image. The width dimension is treated similarly.
    crop_size: A list of two integers `[crop_height, crop_width]`. All
      cropped image patches are resized to this size. The aspect ratio of the
      image content is not preserved. Both `crop_height` and `crop_width` need
      to be positive.
    num_spatial_bins: A list (or tuple) of two integers
      `[spatial_bins_y, spatial_bins_x]`.
      Represents the number of position-sensitive bins in y and x directions.
      Both values should be >= 1. `crop_height` should be divisible by
      `spatial_bins_y`, and similarly for width.
      The number of image channels should be divisible by
      (spatial_bins_y * spatial_bins_x).
      Suggested value from R-FCN paper: [3, 3].
    global_pool: A boolean variable.
      If True, we perform average global pooling on the features assembled from
        the position-sensitive score maps.
      If False, we keep the position-pooled features without global pooling
        over the spatial coordinates.
      Note that using global_pool=True is equivalent to but more efficient than
        running the function with global_pool=False and then performing global
        average pooling.

  Returns:
    position_sensitive_features: A 4-D tensor of shape
      `[num_boxes, K, K, crop_channels]`,
      where `crop_channels = depth / (spatial_bins_y * spatial_bins_x)`,
      where K = 1 when global_pool is True (Average-pooled cropped regions),
      and K = crop_size when global_pool is False.
  Raises:
    ValueError: Raised in four situations:
      `num_spatial_bins` is not >= 1;
      `num_spatial_bins` does not divide `crop_size`;
      `(spatial_bins_y*spatial_bins_x)` does not divide `depth`;
      `bin_crop_size` is not square when global_pool=False due to the
        constraint in function space_to_depth.
  """
  total_bins = 1
  bin_crop_size = []

  # Validate the bin layout before any graph ops are created.
  for (num_bins, crop_dim) in zip(num_spatial_bins, crop_size):
    if num_bins < 1:
      raise ValueError('num_spatial_bins should be >= 1')

    if crop_dim % num_bins != 0:
      raise ValueError('crop_size should be divisible by num_spatial_bins')

    total_bins *= num_bins
    bin_crop_size.append(crop_dim // num_bins)

  if not global_pool and bin_crop_size[0] != bin_crop_size[1]:
    raise ValueError('Only support square bin crop size for now.')

  ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1)
  spatial_bins_y, spatial_bins_x = num_spatial_bins

  # The extent of one bin is the same for every bin, so build these tensors
  # once instead of once per loop iteration.
  step_y = (ymax - ymin) / spatial_bins_y
  step_x = (xmax - xmin) / spatial_bins_x

  # Split each box into spatial_bins_y * spatial_bins_x bins.
  position_sensitive_boxes = []
  for bin_y in range(spatial_bins_y):
    for bin_x in range(spatial_bins_x):
      box_coordinates = [ymin + bin_y * step_y,
                         xmin + bin_x * step_x,
                         ymin + (bin_y + 1) * step_y,
                         xmin + (bin_x + 1) * step_x,
                        ]
      position_sensitive_boxes.append(tf.stack(box_coordinates, axis=1))

  # One channel group per bin.
  image_splits = tf.split(value=image, num_or_size_splits=total_bins, axis=2)

  image_crops = []
  for (split, box) in zip(image_splits, position_sensitive_boxes):
    # Use the matmul-based crop when all shapes are static; otherwise fall
    # back to tf.image.crop_and_resize with every box mapped to image 0.
    if split.shape.is_fully_defined() and box.shape.is_fully_defined():
      crop = tf.squeeze(
          matmul_crop_and_resize(
              tf.expand_dims(split, axis=0), tf.expand_dims(box, axis=0),
              bin_crop_size),
          axis=0)
    else:
      crop = tf.image.crop_and_resize(
          tf.expand_dims(split, 0), box,
          tf.zeros(tf.shape(boxes)[0], dtype=tf.int32), bin_crop_size)
    image_crops.append(crop)

  if global_pool:
    # Average over all bins.
    position_sensitive_features = tf.add_n(image_crops) / len(image_crops)
    # Then average over spatial positions within the bins.
    position_sensitive_features = tf.reduce_mean(
        position_sensitive_features, [1, 2], keepdims=True)
  else:
    # Reorder height/width to depth channel.
    block_size = bin_crop_size[0]
    if block_size >= 2:
      image_crops = [tf.space_to_depth(
          crop, block_size=block_size) for crop in image_crops]

    # Pack image_crops so that first dimension is for position-sensitive boxes.
    position_sensitive_features = tf.stack(image_crops, axis=0)

    # Unroll the position-sensitive boxes to spatial positions.
    # list(...) lets callers pass num_spatial_bins as a tuple as well.
    position_sensitive_features = tf.squeeze(
        tf.batch_to_space_nd(position_sensitive_features,
                             block_shape=[1] + list(num_spatial_bins),
                             crops=tf.zeros((3, 2), dtype=tf.int32)),
        axis=[0])

    # Reorder back the depth channel.
    if block_size >= 2:
      position_sensitive_features = tf.depth_to_space(
          position_sensitive_features, block_size=block_size)

  return position_sensitive_features
コード例 #25
0
 def objective(self, params, data=None, labels=None):
     """Evaluate a two-variable sinusoidal ridge objective on `params[0]`.

     `params[0]` is split into two equal parts `x` and `y` along axis 0 and
     the result is
     `2 - (sin(x) * sin(x^2 / pi)^(2m) + sin(y) * sin(2 * y^2 / pi)^(2m))`
     with `m = 5`. `data` and `labels` are not used.

     Returns:
       The objective tensor with size-1 dimensions squeezed away.
     """
     u, v = tf.split(params[0], 2, axis=0)
     steepness = 5  # Larger values give steeper ridges.
     power = 2 * steepness
     ridge_u = tf.sin(u) * tf.sin(u**2 / np.pi)**power
     ridge_v = tf.sin(v) * tf.sin(2 * v**2 / np.pi)**power
     return tf.squeeze(2. - (ridge_u + ridge_v))
コード例 #26
0
    def build_model(self, hps):
        """Define model architecture.

        Builds, in order: the decoder (and, when `hps.conditional`, encoder)
        RNN cells, optional input/output dropout wrappers, the input
        placeholders, the latent vector `batch_z` with its KL cost (or zeros
        in unconditional mode), the decoder RNN, the mixture-density output
        parameters, the reconstruction cost, and, when training, the
        gradient-clipped Adam update.

        Args:
          hps: Hyper-parameter object. This method reads `is_training`,
            `dec_model`, `enc_model`, `dec_rnn_size`, `enc_rnn_size`,
            `conditional`, `batch_size` and `num_mixture` from it; all other
            settings are read from `self.hps`.

        Raises:
          AssertionError: If `hps.dec_model` or `hps.enc_model` is not one of
            'lstm', 'layer_norm' or 'hyper'.
        """
        if hps.is_training:
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)

        if hps.dec_model == 'lstm':
            cell_fn = rnn.LSTMCell
        elif hps.dec_model == 'layer_norm':
            cell_fn = rnn.LayerNormLSTMCell
        elif hps.dec_model == 'hyper':
            cell_fn = rnn.HyperLSTMCell
        else:
            assert False, 'please choose a respectable cell'

        if hps.enc_model == 'lstm':
            enc_cell_fn = rnn.LSTMCell
        elif hps.enc_model == 'layer_norm':
            enc_cell_fn = rnn.LayerNormLSTMCell
        elif hps.enc_model == 'hyper':
            enc_cell_fn = rnn.HyperLSTMCell
        else:
            assert False, 'please choose a respectable cell'

        use_recurrent_dropout = self.hps.use_recurrent_dropout
        use_input_dropout = self.hps.use_input_dropout
        use_output_dropout = self.hps.use_output_dropout

        cell = cell_fn(hps.dec_rnn_size,
                       use_recurrent_dropout=use_recurrent_dropout,
                       dropout_keep_prob=self.hps.recurrent_dropout_prob)

        if hps.conditional:  # vae mode:
            # The forward and backward encoder cells are configured the same
            # way for every encoder type (the former 'hyper' / non-'hyper'
            # branches were identical).
            self.enc_cell_fw = enc_cell_fn(
                hps.enc_rnn_size,
                use_recurrent_dropout=use_recurrent_dropout,
                dropout_keep_prob=self.hps.recurrent_dropout_prob)
            self.enc_cell_bw = enc_cell_fn(
                hps.enc_rnn_size,
                use_recurrent_dropout=use_recurrent_dropout,
                dropout_keep_prob=self.hps.recurrent_dropout_prob)

        # dropout:
        tf.logging.info('Input dropout mode = %s.', use_input_dropout)
        tf.logging.info('Output dropout mode = %s.', use_output_dropout)
        tf.logging.info('Recurrent dropout mode = %s.', use_recurrent_dropout)
        if use_input_dropout:
            tf.logging.info('Dropout to input w/ keep_prob = %4.4f.',
                            self.hps.input_dropout_prob)
            cell = contrib_rnn.DropoutWrapper(
                cell, input_keep_prob=self.hps.input_dropout_prob)
        if use_output_dropout:
            tf.logging.info('Dropout to output w/ keep_prob = %4.4f.',
                            self.hps.output_dropout_prob)
            cell = contrib_rnn.DropoutWrapper(
                cell, output_keep_prob=self.hps.output_dropout_prob)
        self.cell = cell

        self.sequence_lengths = tf.placeholder(dtype=tf.int32,
                                               shape=[self.hps.batch_size])
        self.input_data = tf.placeholder(
            dtype=tf.float32,
            shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5])

        # The target/expected vectors of strokes
        self.output_x = self.input_data[:, 1:self.hps.max_seq_len + 1, :]
        # vectors of strokes to be fed to decoder (same as above, but lagged behind
        # one step to include initial dummy value of (0, 0, 1, 0, 0))
        self.input_x = self.input_data[:, :self.hps.max_seq_len, :]

        # either do vae-bit and get z, or do unconditional, decoder-only
        if hps.conditional:  # vae mode:
            self.mean, self.presig = self.encoder(self.output_x,
                                                  self.sequence_lengths)
            self.sigma = tf.exp(self.presig /
                                2.0)  # sigma > 0. div 2.0 -> sqrt.
            eps = tf.random_normal((self.hps.batch_size, self.hps.z_size),
                                   0.0,
                                   1.0,
                                   dtype=tf.float32)
            # Reparameterised sample: z = mean + sigma * eps.
            self.batch_z = self.mean + tf.multiply(self.sigma, eps)
            # KL cost
            self.kl_cost = -0.5 * tf.reduce_mean(
                (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig)))
            # Floor the KL term at kl_tolerance.
            self.kl_cost = tf.maximum(self.kl_cost, self.hps.kl_tolerance)
            # Append z to the decoder input at every time step.
            pre_tile_y = tf.reshape(self.batch_z,
                                    [self.hps.batch_size, 1, self.hps.z_size])
            overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1])
            actual_input_x = tf.concat([self.input_x, overlay_x], 2)
            # The decoder's initial state is a tanh-squashed projection of z.
            self.initial_state = tf.nn.tanh(
                rnn.super_linear(self.batch_z,
                                 cell.state_size,
                                 init_w='gaussian',
                                 weight_start=0.001,
                                 input_size=self.hps.z_size))
        else:  # unconditional, decoder-only generation
            self.batch_z = tf.zeros((self.hps.batch_size, self.hps.z_size),
                                    dtype=tf.float32)
            self.kl_cost = tf.zeros([], dtype=tf.float32)
            actual_input_x = self.input_x
            self.initial_state = cell.zero_state(batch_size=hps.batch_size,
                                                 dtype=tf.float32)

        self.num_mixture = hps.num_mixture

        # TODO(deck): Better understand this comment.
        # Number of outputs is 3 (one logit per pen state) plus 6 per mixture
        # component: mean_x, stdev_x, mean_y, stdev_y, correlation_xy, and the
        # mixture weight/probability (Pi_k)
        n_out = (3 + self.num_mixture * 6)

        with tf.variable_scope('RNN'):
            output_w = tf.get_variable('output_w',
                                       [self.hps.dec_rnn_size, n_out])
            output_b = tf.get_variable('output_b', [n_out])

        # decoder module of sketch-rnn is below
        output, last_state = tf.nn.dynamic_rnn(
            cell,
            actual_input_x,
            initial_state=self.initial_state,
            time_major=False,
            swap_memory=True,
            dtype=tf.float32,
            scope='RNN')

        # Flatten batch and time, then project every step to n_out values.
        output = tf.reshape(output, [-1, hps.dec_rnn_size])
        output = tf.nn.xw_plus_b(output, output_w, output_b)
        self.final_state = last_state

        # NB: the below are inner functions, not methods of Model
        def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
            """Returns result of eq # 24 of http://arxiv.org/abs/1308.0850."""
            norm1 = tf.subtract(x1, mu1)
            norm2 = tf.subtract(x2, mu2)
            s1s2 = tf.multiply(s1, s2)
            # eq 25
            z = (tf.square(tf.div(norm1, s1)) + tf.square(tf.div(norm2, s2)) -
                 2 * tf.div(tf.multiply(rho, tf.multiply(norm1, norm2)), s1s2))
            neg_rho = 1 - tf.square(rho)
            result = tf.exp(tf.div(-z, 2 * neg_rho))
            denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(neg_rho))
            result = tf.div(result, denom)
            return result

        def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr,
                         z_pen_logits, x1_data, x2_data, pen_data):
            """Returns a loss fn based on eq #26 of http://arxiv.org/abs/1308.0850."""
            # This represents the L_R only (i.e. does not include the KL loss term).

            result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1,
                                   z_sigma2, z_corr)
            epsilon = 1e-6
            # result1 is the loss wrt pen offset (L_s in equation 9 of
            # https://arxiv.org/pdf/1704.03477.pdf)
            result1 = tf.multiply(result0, z_pi)
            result1 = tf.reduce_sum(result1, 1, keep_dims=True)
            result1 = -tf.log(result1 + epsilon)  # avoid log(0)

            fs = 1.0 - pen_data[:, 2]  # use training data for this
            fs = tf.reshape(fs, [-1, 1])
            # Zero out loss terms beyond N_s, the last actual stroke
            result1 = tf.multiply(result1, fs)

            # result2: loss wrt pen state, (L_p in equation 9)
            result2 = tf.nn.softmax_cross_entropy_with_logits(
                labels=pen_data, logits=z_pen_logits)
            result2 = tf.reshape(result2, [-1, 1])
            if not self.hps.is_training:  # eval mode, mask eos columns
                result2 = tf.multiply(result2, fs)

            result = result1 + result2
            return result

        # below is where we need to do MDN (Mixture Density Network) splitting of
        # distribution params
        def get_mixture_coef(output):
            """Returns the tf slices containing mdn dist params."""
            # This uses eqns 18 -> 23 of http://arxiv.org/abs/1308.0850.
            z = output
            z_pen_logits = z[:, 0:3]  # pen states
            z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(
                z[:, 3:], 6, 1)

            # process output z's into MDN parameters

            # softmax all the pi's and pen states:
            z_pi = tf.nn.softmax(z_pi)
            z_pen = tf.nn.softmax(z_pen_logits)

            # exponentiate the sigmas and also make corr between -1 and 1.
            z_sigma1 = tf.exp(z_sigma1)
            z_sigma2 = tf.exp(z_sigma2)
            z_corr = tf.tanh(z_corr)

            r = [
                z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen,
                z_pen_logits
            ]
            return r

        out = get_mixture_coef(output)
        [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen,
         o_pen_logits] = out

        self.pi = o_pi
        self.mu1 = o_mu1
        self.mu2 = o_mu2
        self.sigma1 = o_sigma1
        self.sigma2 = o_sigma2
        self.corr = o_corr
        self.pen_logits = o_pen_logits
        # pen state probabilities (result of applying softmax to self.pen_logits)
        self.pen = o_pen

        # reshape target data so that it is compatible with prediction shape
        target = tf.reshape(self.output_x, [-1, 5])
        [x1_data, x2_data, eos_data, eoc_data,
         cont_data] = tf.split(target, 5, 1)
        pen_data = tf.concat([eos_data, eoc_data, cont_data], 1)

        lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
                                o_pen_logits, x1_data, x2_data, pen_data)

        self.r_cost = tf.reduce_mean(lossfunc)

        if self.hps.is_training:
            self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
            optimizer = tf.train.AdamOptimizer(self.lr)

            self.kl_weight = tf.Variable(self.hps.kl_weight_start,
                                         trainable=False)
            # Total cost: reconstruction plus weighted KL.
            self.cost = self.r_cost + self.kl_cost * self.kl_weight

            # Clip every gradient element-wise to [-grad_clip, grad_clip].
            gvs = optimizer.compute_gradients(self.cost)
            g = self.hps.grad_clip
            capped_gvs = [(tf.clip_by_value(grad, -g, g), var)
                          for grad, var in gvs]
            self.train_op = optimizer.apply_gradients(
                capped_gvs, global_step=self.global_step, name='train_step')
コード例 #27
0
ファイル: tfprocess.py プロジェクト: syys96/AlphaGomoku
    def init_net(self, planes, probs, winner, gpus_num):
        """Build the multi-GPU training graph and initialize all variables.

        The inputs are split into `gpus_num` equal parts (tf.split along the
        default axis 0), one tower per GPU is built with shared variables, and
        the per-tower gradients are averaged. On top of that this method
        creates ops for accumulating gradients (`self.grad_op`), applying the
        accumulated gradients (`self.train_op`), clearing the accumulators
        (`self.clear_op`), incrementing the global step (`self.step_op`),
        accuracy, optional SWA bookkeeping, summary writers and a saver, and
        finally runs the global variable initializer in `self.session`.

        Args:
          planes: Network input tensor; split across the towers.
          probs: Target policy tensor; also stored as `self.y_` and compared
            against the concatenated tower outputs to compute accuracy.
          winner: Third target tensor handed to `self.tower_loss`
            (presumably the game outcome -- TODO confirm against tower_loss).
          gpus_num: Number of GPU towers to build.
        """
        self.y_ = probs  # (tf.float32, [None, 362])
        self.sx = tf.split(planes, gpus_num)
        self.sy_ = tf.split(probs, gpus_num)
        self.sz_ = tf.split(winner, gpus_num)
        self.batch_norm_count = 0
        self.reuse_var = None

        # You need to change the learning rate here if you are training
        # from a self-play training set, for example start with 0.005 instead.
        opt = tf.train.MomentumOptimizer(learning_rate=0.05,
                                         momentum=0.9,
                                         use_nesterov=True)

        opt = LossScalingOptimizer(opt, scale=self.loss_scale)

        # Construct net here.
        tower_grads = []
        tower_loss = []
        tower_policy_loss = []
        tower_mse_loss = []
        tower_reg_term = []
        tower_y_conv = []
        with tf.variable_scope(
                "fp32_storage",
                # this forces trainable variables to be stored as fp32
                custom_getter=float32_variable_storage_getter):
            for i in range(gpus_num):
                with tf.device("/gpu:%d" % i):
                    with tf.name_scope("tower_%d" % i):
                        loss, policy_loss, mse_loss, reg_term, y_conv = self.tower_loss(
                            self.sx[i], self.sy_[i], self.sz_[i])

                        # Reset batchnorm key to 0.
                        self.reset_batchnorm_key()

                        # Every tower after the first reuses the variables
                        # created by the first one.
                        tf.get_variable_scope().reuse_variables()
                        with tf.control_dependencies(
                                tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                            grads = opt.compute_gradients(loss)

                        tower_grads.append(grads)
                        tower_loss.append(loss)
                        tower_policy_loss.append(policy_loss)
                        tower_mse_loss.append(mse_loss)
                        tower_reg_term.append(reg_term)
                        tower_y_conv.append(y_conv)

        # Average losses and gradients from different GPUs
        self.loss = tf.reduce_mean(tower_loss)
        self.policy_loss = tf.reduce_mean(tower_policy_loss)
        self.mse_loss = tf.reduce_mean(tower_mse_loss)
        self.reg_term = tf.reduce_mean(tower_reg_term)
        self.y_conv = tf.concat(tower_y_conv, axis=0)
        self.mean_grads = self.average_gradients(tower_grads)

        # Do swa after we construct the net
        if self.swa_enabled is True:
            # Count of networks accumulated into SWA
            self.swa_count = tf.Variable(0., name='swa_count', trainable=False)
            # Count of networks to skip
            self.swa_skip = tf.Variable(self.swa_c,
                                        name='swa_skip',
                                        trainable=False)
            # Build the SWA variables and accumulators
            accum = []
            load = []
            n = self.swa_count
            for w in self.weights:
                name = w.name.split(':')[0]
                var = tf.Variable(tf.zeros(shape=w.shape),
                                  name='swa/' + name,
                                  trainable=False)
                # Running mean: var <- var * n/(n+1) + w * 1/(n+1)
                accum.append(
                    tf.assign(var,
                              var * (n / (n + 1.)) + w * (1. / (n + 1.))))
                load.append(tf.assign(w, var))
            # The counter is only bumped after all accumulators were updated.
            with tf.control_dependencies(accum):
                self.swa_accum_op = tf.assign_add(n, 1.)
            self.swa_load_op = tf.group(*load)

        # Accumulate gradients
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        total_grad = []
        grad_ops = []
        clear_var = []
        self.grad_op_real = self.mean_grads
        for (g, v) in self.grad_op_real:
            if g is None:
                # No gradient for this variable: pass the pair through
                # unchanged and skip the accumulator, which would otherwise
                # fail on `g.shape`.
                total_grad.append((g, v))
                continue
            name = v.name.split(':')[0]
            gsum = tf.get_variable(name='gsum/' + name,
                                   shape=g.shape,
                                   trainable=False,
                                   initializer=tf.zeros_initializer)
            total_grad.append((gsum, v))
            grad_ops.append(tf.assign_add(gsum, g))
            clear_var.append(gsum)
        # Op to compute gradients and add to running total in 'gsum/'
        self.grad_op = tf.group(*grad_ops)

        # Op to apply accumulated gradients
        self.train_op = opt.apply_gradients(total_grad)

        zero_ops = []
        for g in clear_var:
            zero_ops.append(
                tf.assign(g, tf.zeros(shape=g.shape, dtype=g.dtype)))
        # Op to clear accumulated gradients
        self.clear_op = tf.group(*zero_ops)

        # Op to increment global step counter
        self.step_op = tf.assign_add(self.global_step, 1)

        correct_prediction = \
            tf.equal(tf.argmax(self.y_conv, 1), tf.argmax(self.y_, 1))
        correct_prediction = tf.cast(correct_prediction, tf.float32)
        self.accuracy = tf.reduce_mean(correct_prediction)

        # Summary part
        self.test_writer = tf.summary.FileWriter(
            os.path.join(os.getcwd(), self.logbase + "/test"),
            self.session.graph)
        self.train_writer = tf.summary.FileWriter(
            os.path.join(os.getcwd(), self.logbase + "/train"),
            self.session.graph)

        # Build checkpoint saver
        self.saver = tf.train.Saver()

        # Initialize all variables
        self.session.run(tf.global_variables_initializer())
コード例 #28
0
def _get_mdn_coef(output):
    """Split `output` into mixture-density coefficients.

    `output` is cut into three equal parts along its last axis. The first
    part is shifted by its log-sum-exp over the last axis (i.e. a
    log-softmax); the other two parts are returned untouched.

    Returns:
      Tuple `(logmix, mean, logstd)`.
    """
    raw_logmix, mean, logstd = tf.split(output, 3, -1)
    log_normalizer = tf.reduce_logsumexp(raw_logmix, -1, keepdims=True)
    return raw_logmix - log_normalizer, mean, logstd
コード例 #29
0
    def _build(self, features, parent_transform=None, parent_presence=None):
        """Builds the module.

        Args:
          features: Tensor of encodings of shape [B, n_enc_dims]. When
            `self._n_caps_params` is None the tensor is used directly as the
            capsule parameters and its first two dimensions must equal
            [B, n_caps]. The batch dimension must be statically known, since
            it is read from the static shape.
          parent_transform: Optional. When given, it is used as the capsule
            transform instead of the one predicted from `features`
            (documented upstream as a tuple of (matrix, vector) -- TODO
            confirm).
          parent_presence: Optional. When given, it is used as the
            per-capsule presence instead of the sigmoid of the predicted
            per-capsule logits.

        Returns:
          AttrDict with fields `vote`, `scale`, `vote_presence`,
          `pres_logit_per_caps`, `pres_logit_per_vote`, `dynamic_weights_l2`,
          `raw_caps_params` and `raw_caps_features`. `raw_caps_params` is
          None when `self._n_caps_params` is None.

        Raises:
          ValueError: if `self._noise_type` is neither falsy, 'uniform' nor
            'logistic'.
        """
        batch_size = features.shape.as_list()[0]
        batch_shape = [batch_size, self._n_caps]

        # Only produced when capsule params are predicted by an MLP below;
        # initialised here so that the returned AttrDict is well-defined on
        # both paths.
        raw_caps_params = None

        # Predict capsule and additional params from the input encoding.
        # [B, n_caps, n_caps_dims]
        if self._n_caps_params is not None:

            # Use separate parameters to do predictions for different capsules.
            mlp = BatchMLP(self._n_hiddens + [self._n_caps_params])
            raw_caps_params = mlp(features)

            caps_params = tf.reshape(raw_caps_params,
                                     batch_shape + [self._n_caps_params])

        else:
            assert features.shape[:2].as_list() == batch_shape
            caps_params = features

        # Capsule dropout: a 0/1 indicator per capsule that is appended to
        # the capsule parameters.
        if self._caps_dropout_rate == 0.0:
            caps_exist = tf.ones(batch_shape + [1], dtype=tf.float32)
        else:
            pmf = tfd.Bernoulli(1. - self._caps_dropout_rate, dtype=tf.float32)
            caps_exist = pmf.sample(batch_shape + [1])

        caps_params = tf.concat([caps_params, caps_exist], -1)

        output_shapes = (
            [self._n_votes, self._n_transform_params],  # CPR_dynamic
            [1, self._n_transform_params],  # CCR
            [1],  # per-capsule presence
            [self._n_votes],  # per-vote-presence
            [self._n_votes],  # per-vote scale
        )

        # Flattened size of every output; used to split the MLP output.
        splits = [np.prod(i).astype(np.int32) for i in output_shapes]
        n_outputs = sum(splits)

        # we don't use bias in the output layer in order to separate the static
        # and dynamic parts of the CPR
        caps_mlp = BatchMLP([self._n_hiddens, n_outputs], use_bias=False)
        all_params = caps_mlp(caps_params)
        all_params = tf.split(all_params, splits, -1)
        res = [
            tf.reshape(i, batch_shape + s)
            for (i, s) in zip(all_params, output_shapes)
        ]

        cpr_dynamic = res[0]

        # add bias to all remaining outputs
        res = [snt.AddBias()(i) for i in res[1:]]
        ccr, pres_logit_per_caps, pres_logit_per_vote, scale_per_vote = res

        if self._caps_dropout_rate != 0.0:
            # Adds safe_log of the 0/1 indicator to the logits of each capsule.
            pres_logit_per_caps += math_ops.safe_log(caps_exist)

        cpr_static = tf.get_variable(
            'cpr_static',
            shape=[1, self._n_caps, self._n_votes, self._n_transform_params])

        def add_noise(tensor):
            """Adds noise to tensors."""
            if self._noise_type == 'uniform':
                noise = tf.random.uniform(tensor.shape, minval=-.5,
                                          maxval=.5) * self._noise_scale

            elif self._noise_type == 'logistic':
                pdf = tfd.Logistic(0., self._noise_scale)
                noise = pdf.sample(tensor.shape)

            elif not self._noise_type:
                noise = 0.

            else:
                raise ValueError('Invalid noise type: "{}".'.format(
                    self._noise_type))

            return tensor + noise

        pres_logit_per_caps = add_noise(pres_logit_per_caps)
        pres_logit_per_vote = add_noise(pres_logit_per_vote)

        # this is for hierarchical
        if parent_transform is None:
            ccr = self._make_transform(ccr)
        else:
            ccr = parent_transform

        # With deformations disabled only the static part of the CPR is used.
        if not self._deformations:
            cpr_dynamic = tf.zeros_like(cpr_dynamic)

        cpr = self._make_transform(cpr_dynamic + cpr_static)

        ccr_per_vote = snt.TileByDim([2], [self._n_votes])(ccr)
        votes = tf.matmul(ccr_per_vote, cpr)

        if parent_presence is not None:
            pres_per_caps = parent_presence
        else:
            pres_per_caps = tf.nn.sigmoid(pres_logit_per_caps)

        pres_per_vote = pres_per_caps * tf.nn.sigmoid(pres_logit_per_vote)

        if self._learn_vote_scale:
            # for numerical stability
            scale_per_vote = tf.nn.softplus(scale_per_vote + .5) + 1e-2
        else:
            scale_per_vote = tf.zeros_like(scale_per_vote) + 1.

        return AttrDict(
            vote=votes,
            scale=scale_per_vote,
            vote_presence=pres_per_vote,
            pres_logit_per_caps=pres_logit_per_caps,
            pres_logit_per_vote=pres_logit_per_vote,
            dynamic_weights_l2=tf.nn.l2_loss(cpr_dynamic) / batch_size,
            raw_caps_params=raw_caps_params,
            raw_caps_features=features,
        )
コード例 #30
0
ファイル: rnn.py プロジェクト: sleep-yearning/magenta
    def __call__(self, x, state, timestep=0, scope=None):
        """Run one step of the cell and return the output and the new state.

        Args:
          x: Input tensor; its static shape is read as [batch_size, x_size].
          state: Tensor that is split in two halves along axis 1 (total h,
            total c). The first `num_units` columns of each half belong to
            this cell, the remaining columns form the hyper cell state.
          timestep: Not used by this method.
          scope: Optional variable scope; defaults to the class name.

        Returns:
          Tuple `(new_h, new_total_state)`, where `new_total_state` packs the
          new h/c of this cell and of the hyper cell in the same layout as
          `state`.

        Side effects: updates `self.hyper_state`, `self.hyper_output` and
        `self._input_size`.
        """
        with tf.variable_scope(scope or type(self).__name__):
            units = self.num_units

            # Unpack the combined state into this cell's part and the hyper
            # cell's part.
            total_h, total_c = tf.split(state, 2, 1)
            h = total_h[:, 0:units]
            c = total_c[:, 0:units]
            self.hyper_state = tf.concat(
                [total_h[:, units:], total_c[:, units:]], 1)

            in_shape = x.get_shape().as_list()
            batch_size = in_shape[0]
            x_size = in_shape[1]
            self._input_size = x_size

            # None lets get_variable pick its default initializer
            # (originally annotated as "uniform").
            input_kernel_init = None
            recurrent_kernel_init = lstm_ortho_initializer(1.0)

            w_xh = tf.get_variable(
                'W_xh', [x_size, 4 * units], initializer=input_kernel_init)
            w_hh = tf.get_variable(
                'W_hh', [units, 4 * units], initializer=recurrent_kernel_init)
            bias = tf.get_variable(
                'bias', [4 * units],
                initializer=tf.constant_initializer(0.0))

            # The hyper cell sees the input concatenated with the hidden state.
            hyper_in = tf.concat([x, h], 1)
            hyper_out, hyper_next_state = self.hyper_cell(hyper_in,
                                                          self.hyper_state)
            self.hyper_output = hyper_out
            self.hyper_state = hyper_next_state

            xh = tf.matmul(x, w_xh)
            hh = tf.matmul(h, w_hh)

            # Input (Wxh) contributions, one per gate, without bias.
            x_scopes = ('hyper_ix', 'hyper_jx', 'hyper_fx', 'hyper_ox')
            x_terms = [
                self.hyper_norm(part, name, use_bias=False)
                for part, name in zip(tf.split(xh, 4, 1), x_scopes)
            ]

            # Recurrent (Whh) contributions, one per gate, with bias.
            h_scopes = ('hyper_ih', 'hyper_jh', 'hyper_fh', 'hyper_oh')
            h_terms = [
                self.hyper_norm(part, name, use_bias=True)
                for part, name in zip(tf.split(hh, 4, 1), h_scopes)
            ]

            # Per-gate bias slices; broadcast over the batch when added.
            b_terms = tf.split(bias, 4, 0)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = [
                x_term + h_term + b_term
                for x_term, h_term, b_term in zip(x_terms, h_terms, b_terms)
            ]

            if self.use_layer_norm:
                stacked = tf.concat([i, j, f, o], 1)
                stacked = layer_norm_all(stacked, batch_size, 4, units,
                                         'ln_all')
                i, j, f, o = tf.split(stacked, 4, 1)

            g = tf.tanh(j)
            if self.use_recurrent_dropout:
                g = tf.nn.dropout(g, self.dropout_keep_prob)

            kept = c * tf.sigmoid(f + self.forget_bias)
            written = tf.sigmoid(i) * g
            new_c = kept + written

            normed_c = layer_norm(new_c, units, 'ln_c')
            new_h = tf.tanh(normed_c) * tf.sigmoid(o)

            # Re-pack this cell's and the hyper cell's h/c into one tensor.
            hyper_h, hyper_c = tf.split(hyper_next_state, 2, 1)
            new_total_state = tf.concat(
                [
                    tf.concat([new_h, hyper_h], 1),
                    tf.concat([new_c, hyper_c], 1),
                ],
                1)
        return new_h, new_total_state