示例#1
0
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

import losses
from networks import base as basenet
from util import logging as logutil

# Module-level logger for models/base. NOTE(review): `debug_mode=False`
# presumably suppresses debug-level output — confirm in util/logging.
logger = logutil.Logger(loggee="models/base", debug_mode=False)


class Model(tf.keras.Model):
    """Uses only the parent's trackability and nothing else.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.net = {
            'main': basenet.Network()
        }  # NOTE: insert trainable networks
        # of your model into this dictionary, values of which will be registered
        # as trainable
        self.trainable_registered = False  # NOTE: before training, call
        # register_trainable() to register trainable parameters (which lie in
示例#2
0
from os.path import join, exists
from itertools import product
import re
import numpy as np
from PIL import Image

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

import xiuminglib as xm

from util import logging as logutil, io as ioutil
from .base import Dataset as BaseDataset

# Module-level logger; `loggee` is presumably the label attached to this
# module's log output (see util/logging).
logger = logutil.Logger(loggee="datasets/nlt")


class Dataset(BaseDataset):
    def __init__(self, config, mode, **kwargs):
        self.data_root = config.get('DEFAULT', 'data_root')
        data_status_path = self.data_root.rstrip('/') + '.json'
        if not exists(data_status_path):
            raise FileNotFoundError(
                ("Data status JSON not found at \n\t%s\nRun "
                 "$REPO/data_gen/postproc.py to generate it") %
                data_status_path)
        self.data_paths = ioutil.read_json(data_status_path)
        # Because paths in JSON are relative, prepend data root directory
        for _, paths in self.data_paths.items():
            for k, v in paths.items():
import tensorflow as tf
tf.compat.v1.enable_eager_execution()

import datasets
import models
from util import io as ioutil, logging as logutil

# Command-line flags for this script.
# NOTE(review): `flags` is not imported in the visible lines — presumably
# `from absl import flags` earlier in the file; confirm.
# `ckpt` is a checkpoint *prefix*; the default looks like a placeholder that
# must be overridden. The .ini config path is also derived from it in this file.
flags.DEFINE_string('ckpt', '/path/to/ckpt-100',
                    "path to checkpoint (prefix only)")
# Number of observation batches fed to the observation path.
flags.DEFINE_integer(
    'n_obs_batches', 1,
    "number of observation batches used for the observation path")
# Frame rate used when writing the result video.
flags.DEFINE_integer('fps', 24, "frames per second for the result video")
FLAGS = flags.FLAGS

# Module-level logger; `loggee` is presumably the label on this script's output.
logger = logutil.Logger(loggee="nlt_test")


def get_config_ini(ckpt=None):
    """Return the path of the ``.ini`` config file that goes with a checkpoint.

    The path is built by dropping the last two '/'-separated components of the
    checkpoint prefix and appending ``.ini``. For example,
    ``<root>/<exp>/<subdir>/ckpt-100`` maps to ``<root>/<exp>.ini``.

    Args:
        ckpt: Checkpoint path prefix. If omitted, ``FLAGS.ckpt`` is read at call
            time. That matches the original zero-argument behavior and lets
            callers and tests pass an explicit path instead of relying on the
            global flags.

    Returns:
        str: Path to the ``.ini`` config file.
    """
    if ckpt is None:
        ckpt = FLAGS.ckpt
    # Strip the checkpoint file prefix and its parent directory.
    return '/'.join(ckpt.split('/')[:-2]) + '.ini'


def make_datapipe(mode, config):
    """Create the dataset named in ``config`` for ``mode`` and return its pipeline.

    The dataset class is looked up from the ``DEFAULT.dataset`` option and
    instantiated with ``(config, mode)``. The ``DEFAULT.no_batch`` boolean is
    then forwarded to the dataset's ``build_pipeline`` as ``no_batch``.

    Args:
        mode: Passed through unchanged to the dataset constructor.
        config: Config object with a ``DEFAULT`` section (read via ``get`` and
            ``getboolean``).

    Returns:
        Whatever the dataset's ``build_pipeline`` returns.
    """
    dataset_cls = datasets.get_dataset_class(config.get('DEFAULT', 'dataset'))
    dataset = dataset_cls(config, mode)
    # Read after constructing the dataset, to keep the original evaluation order.
    skip_batching = config.getboolean('DEFAULT', 'no_batch')
    return dataset.build_pipeline(no_batch=skip_batching)
示例#4
0
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=relative-beyond-top-level

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

from util import logging as logutil

# Module-level logger; `loggee` is presumably the label attached to this
# module's log output (see util/logging).
logger = logutil.Logger(loggee="networks/base")


class Network:
    def __init__(self):
        """Start with an empty, ordered list of layers."""
        # Subclasses populate this list; e.g. the sequential network builds and
        # calls these layers in list order.
        self.layers = []

    def __call__(self, x):
        """Run the network on ``x``. Abstract: subclasses must override this."""
        raise NotImplementedError

    @staticmethod
    def str2none(str_):
        """Mostly to overcome there being no `config.getnone()` method.
        """
        assert isinstance(str_, str), "Call this only on strings"
        if str_.lower() == 'none':
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=relative-beyond-top-level

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

from util import logging as logutil

# Module-level logger; `loggee` is presumably the label attached to this
# module's log output (see util/logging).
logger = logutil.Logger(loggee="networks/elements")


def conv(kernel_size, n_ch_out, stride=1):
    """Create a 2-D convolution layer with ``'same'`` padding.

    Args:
        kernel_size: Convolution kernel size (Keras ``kernel_size``).
        n_ch_out: Number of output channels (Keras ``filters``).
        stride: Convolution stride (Keras ``strides``).

    Returns:
        A new ``tf.keras.layers.Conv2D`` instance.
    """
    return tf.keras.layers.Conv2D(
        filters=n_ch_out,
        kernel_size=kernel_size,
        strides=stride,
        padding='same',
    )


def deconv(kernel_size, n_ch_out, stride=1):
    """Create a 2-D transposed-convolution layer with ``'same'`` padding.

    Args:
        kernel_size: Kernel size (Keras ``kernel_size``).
        n_ch_out: Number of output channels (Keras ``filters``).
        stride: Stride (Keras ``strides``).

    Returns:
        A new ``tf.keras.layers.Conv2DTranspose`` instance.
    """
    return tf.keras.layers.Conv2DTranspose(
        filters=n_ch_out,
        kernel_size=kernel_size,
        strides=stride,
        padding='same',
    )
示例#6
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=relative-beyond-top-level

import tensorflow as tf

tf.compat.v1.enable_eager_execution()

from util import logging as logutil
from .base import Network as BaseNetwork

# Module-level logger; `loggee` is presumably the label attached to this
# module's log output (see util/logging).
logger = logutil.Logger(loggee="networks/seq")


class Network(BaseNetwork):
    """Assuming simple sequential flow.
    """
    def build(self, input_shape):
        """Build every layer in ``self.layers`` for inputs of ``input_shape``.

        The layers are wrapped in a temporary ``tf.keras.Sequential`` only so
        that Keras builds them in order. The wrapper is discarded afterwards,
        and the network keeps just ``self.layers``.
        """
        tf.keras.Sequential(self.layers).build(input_shape)
        # Sanity check: every layer should now report itself as built.
        assert all(lyr.built for lyr in self.layers), "Some layers not built"

    def __call__(self, tensor):
        x = tensor
        for layer in self.layers:
            y = layer(x)
示例#7
0
import numpy as np
from tqdm import tqdm

import tensorflow as tf
tf.compat.v1.enable_eager_execution()
import tensorflow_addons as tfa

import xiuminglib as xm

import losses
from networks import convnet
from util import logging as logutil, io as ioutil, img as imgutil, \
    tensor as tutil
from .base import Model as BaseModel

# Module-level logger for models/nlt. NOTE(review): `debug_mode=False`
# presumably suppresses debug-level output — confirm in util/logging.
logger = logutil.Logger(loggee="models/nlt", debug_mode=False)


class Model(BaseModel):
    def __init__(self, config):
        # Needed by Barron loss
        self.imh = config.getint('DEFAULT', 'imh')
        self.imw = config.getint('DEFAULT', 'imw')
        super().__init__(config)
        # Networks
        depth0 = config.getint('DEFAULT', 'depth0')
        depth = config.getint('DEFAULT', 'depth')
        kernel = config.getint('DEFAULT', 'kernel')
        stride = config.getint('DEFAULT', 'stride')
        norm = config.get('DEFAULT', 'norm')
        act = config.get('DEFAULT', 'act')
示例#8
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=relative-beyond-top-level

import numpy as np

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

from util import logging as logutil, net as netutil
from .seq import Network as BaseNetwork
from .elements import conv, norm, act, pool, iden, deconv, upconv

# Module-level logger; `loggee` is presumably the label attached to this
# module's log output (see util/logging).
logger = logutil.Logger(loggee="networks/convnet")


class Network(BaseNetwork):
    def __init__(self,
                 depth0,
                 depth,
                 kernel,
                 stride,
                 norm_type=None,
                 act_type='relu',
                 pool_type=None):
        super().__init__()
        norm_type = self.str2none(norm_type)
        pool_type = self.str2none(pool_type)
        min_n_ch = depth0