Example #1
  def _build(batch):
    """Builds the Sonnet module: conv trunk, global pool, softmax loss.

    `cfg`, `act_fn`, `init`, and `hidden_units` are captured from the
    enclosing task-definition scope.
    """
    # Optionally center the input images, per the task config.
    image = utils.maybe_center(cfg["center_data"], batch["image"])

    # Convolutional trunk; the single (3, 3) kernel shape is broadcast
    # to every layer, and the final layer is activated as well.
    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=cfg["strides"],
        paddings=cfg["padding"],
        activation=act_fn,
        use_bias=cfg["use_bias"],
        initializers=init,
        activate_final=True)(image)

    # Global spatial pooling over the height and width axes.
    if cfg["pool_type"] == "mean":
      net = tf.reduce_mean(net, axis=[1, 2])
    elif cfg["pool_type"] == "max":
      net = tf.reduce_max(net, axis=[1, 2])
    elif cfg["pool_type"] == "squared_mean":
      net = tf.reduce_mean(net**2, axis=[1, 2])

    # Linear readout producing one logit per class.
    logits = snt.Linear(batch["label_onehot"].shape[1], initializers=init)(net)

    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)

    return tf.reduce_mean(loss_vec)
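
The function above closes over `cfg`, `act_fn`, `init`, and `hidden_units` from an enclosing task-definition scope, and `utils` is a helper module from the same package. A minimal sketch of that context, with purely illustrative values (the concrete settings and the `maybe_center` stand-in are assumptions, not the original definitions):

import sonnet as snt
import tensorflow as tf  # TF1-style graph API, matching the snippet

class utils(object):  # stand-in for the package's helper module
  @staticmethod
  def maybe_center(center, image):
    # Assumed behavior: shift [0, 1] images to [-1, 1] when enabled.
    return image * 2.0 - 1.0 if center else image

cfg = {
    "center_data": True,
    "strides": [(2, 2)],      # length-1 lists broadcast to every layer
    "padding": ["SAME"],
    "use_bias": True,
    "pool_type": "mean",      # one of "mean", "max", "squared_mean"
}
act_fn = tf.nn.relu
init = {"w": tf.variance_scaling_initializer()}
hidden_units = [32, 32]      # two conv layers of 32 channels each

With those names bound, the loss can be wired straight into a graph:

batch = {
    "image": tf.placeholder(tf.float32, [None, 28, 28, 1]),
    "label_onehot": tf.placeholder(tf.float32, [None, 10]),
}
loss = _build(batch)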
Example #2
  def _fn(batch):
    """Builds an MLP classifier; returns mean softmax cross-entropy."""
    # Optionally center the input images, per the task config.
    image = utils.maybe_center(cfg["center_data"], batch["image"])
    # Hidden sizes from the config plus an output layer of one unit per
    # class; `cfg`, `act_fn`, and `init` come from the enclosing scope.
    hidden_units = cfg["layer_sizes"] + [batch["label_onehot"].shape[1]]
    # Flatten [batch, H, W, C] images into [batch, H*W*C] vectors.
    net = snt.BatchFlatten()(image)
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    logits = mod(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)
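
This variant reads only `center_data` and `layer_sizes` from the config. Continuing the imports and `utils` stand-in from the sketch above, a hedged sketch of plugging the function into a training step (shapes, layer sizes, and the optimizer are illustrative assumptions):

cfg = {"center_data": False, "layer_sizes": [128, 64]}
act_fn = tf.nn.relu
init = {"w": tf.glorot_uniform_initializer()}

batch = {
    "image": tf.placeholder(tf.float32, [None, 28, 28, 1]),
    "label_onehot": tf.placeholder(tf.float32, [None, 10]),
}
loss = _fn(batch)  # builds a 128-64-10 MLP on flattened images
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # then sess.run([train_op, loss], feed_dict=...) per minibatch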
Example #3
    def _build(batch):
        """Builds a conv trunk plus MLP head; returns mean cross-entropy.

        `cfg`, `act_fn`, `init`, and `hidden_units` are captured from
        the enclosing task-definition scope.
        """
        # Optionally center the input images, per the task config.
        image = utils.maybe_center(cfg["center_data"], batch["image"])

        # Convolutional trunk with the final layer activated as well.
        net = snt.nets.ConvNet2D(hidden_units,
                                 kernel_shapes=[(3, 3)],
                                 strides=cfg["strides"],
                                 paddings=cfg["padding"],
                                 activation=act_fn,
                                 use_bias=cfg["use_bias"],
                                 initializers=init,
                                 activate_final=True)(image)

        # Fully connected head: flatten the conv features and append an
        # output layer of one unit per class.
        num_classes = batch["label_onehot"].shape[1]
        fc_hidden = cfg["fc_hidden_units"] + [num_classes]
        net = snt.BatchFlatten()(net)
        logits = snt.nets.MLP(fc_hidden, initializers=init,
                              activation=act_fn)(net)

        loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=batch["label_onehot"], logits=logits)

        return tf.reduce_mean(loss_vec)
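
Compared with Example #1, this variant swaps the global-pooling head for a flatten-plus-MLP head, so the config gains an `fc_hidden_units` entry and drops `pool_type`. A sketch of the keys this version reads, again with illustrative values only (not the original task settings):

cfg = {
    "center_data": True,
    "strides": [(2, 2)],       # length-1 lists broadcast to every layer
    "padding": ["VALID"],
    "use_bias": True,
    "fc_hidden_units": [64],   # hidden sizes of the MLP head; the class
                               # count is appended inside _build
}
act_fn = tf.nn.relu
init = {"w": tf.variance_scaling_initializer()}
hidden_units = [16, 32]        # conv channels per layer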