# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

import losses
from networks import base as basenet
from util import logging as logutil


logger = logutil.Logger(loggee="models/base", debug_mode=False)


class Model(tf.keras.Model):
    """Uses only the parent's trackability and nothing else.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.net = {
            'main': basenet.Network()
        }  # NOTE: insert trainable networks of your model into this
        # dictionary, values of which will be registered as trainable
        self.trainable_registered = False  # NOTE: before training, call
        # register_trainable() to register trainable parameters (which lie in
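# Usage sketch (hypothetical subclass; register_trainable() is only described
# by the NOTE above, so its exact behavior is assumed): trainable networks go
# into self.net, and register_trainable() is called once before training so
# their variables are tracked by this tf.keras.Model.
#
#   class MyModel(Model):
#       def __init__(self, config):
#           super().__init__(config)
#           self.net['refine'] = basenet.Network()  # extra trainable network
#
#   model = MyModel(config)
#   model.register_trainable()
#   assert model.trainable_registered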
from os.path import join, exists
from itertools import product
import re
import numpy as np
from PIL import Image
import tensorflow as tf
tf.compat.v1.enable_eager_execution()

import xiuminglib as xm

from util import logging as logutil, io as ioutil
from .base import Dataset as BaseDataset


logger = logutil.Logger(loggee="datasets/nlt")


class Dataset(BaseDataset):
    def __init__(self, config, mode, **kwargs):
        self.data_root = config.get('DEFAULT', 'data_root')
        data_status_path = self.data_root.rstrip('/') + '.json'
        if not exists(data_status_path):
            raise FileNotFoundError((
                "Data status JSON not found at \n\t%s\nRun "
                "$REPO/data_gen/postproc.py to generate it") % data_status_path)
        self.data_paths = ioutil.read_json(data_status_path)
        # Because paths in JSON are relative, prepend data root directory
        for _, paths in self.data_paths.items():
            for k, v in paths.items():
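# Path convention sketch (directory names below are made up): the data status
# JSON is expected to sit next to the data root directory, sharing its name.
#
#   data_root = '/data/nlt/dragon_specular/'  # hypothetical value from the .ini
#   data_root.rstrip('/') + '.json'  # -> '/data/nlt/dragon_specular.json'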
from absl import flags
import tensorflow as tf
tf.compat.v1.enable_eager_execution()

import datasets
import models
from util import io as ioutil, logging as logutil


flags.DEFINE_string(
    'ckpt', '/path/to/ckpt-100', "path to checkpoint (prefix only)")
flags.DEFINE_integer(
    'n_obs_batches', 1,
    "number of observation batches used for the observation path")
flags.DEFINE_integer('fps', 24, "frames per second for the result video")
FLAGS = flags.FLAGS

logger = logutil.Logger(loggee="nlt_test")


def get_config_ini():
    return '/'.join(FLAGS.ckpt.split('/')[:-2]) + '.ini'


def make_datapipe(mode, config):
    dataset_name = config.get('DEFAULT', 'dataset')
    Dataset = datasets.get_dataset_class(dataset_name)
    dataset = Dataset(config, mode)
    no_batch = config.getboolean('DEFAULT', 'no_batch')
    datapipe = dataset.build_pipeline(no_batch=no_batch)
    return datapipe
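# Worked example for get_config_ini() (paths are hypothetical): dropping the
# last two path components of the checkpoint prefix and appending '.ini'
# points back at the experiment's config file.
#
#   ckpt = '/output/lr1e-3/checkpoints/ckpt-100'
#   '/'.join(ckpt.split('/')[:-2]) + '.ini'  # -> '/output/lr1e-3.ini'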
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=relative-beyond-top-level

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

from util import logging as logutil


logger = logutil.Logger(loggee="networks/base")


class Network:
    def __init__(self):
        self.layers = []

    def __call__(self, x):
        raise NotImplementedError

    @staticmethod
    def str2none(str_):
        """Mostly to overcome there being no `config.getnone()` method.
        """
        assert isinstance(str_, str), "Call this only on strings"
        if str_.lower() == 'none':
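# Usage sketch (hypothetical .ini values), assuming str2none() maps the
# case-insensitive string 'none' to None and passes other strings through:
#
#   norm_type = Network.str2none(config.get('DEFAULT', 'norm'))  # 'none' -> None
#   act_type = Network.str2none(config.get('DEFAULT', 'act'))    # 'relu' -> 'relu'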
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=relative-beyond-top-level

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

from util import logging as logutil


logger = logutil.Logger(loggee="networks/elements")


def conv(kernel_size, n_ch_out, stride=1):
    return tf.keras.layers.Conv2D(
        n_ch_out, kernel_size, strides=stride, padding='same')


def deconv(kernel_size, n_ch_out, stride=1):
    return tf.keras.layers.Conv2DTranspose(
        n_ch_out, kernel_size, strides=stride, padding='same')
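# Shape sketch (input tensor is arbitrary and only illustrates sizes): with
# 'same' padding, conv() keeps the spatial size at stride 1 and halves it at
# stride 2, while deconv() at stride 2 doubles it back.
#
#   x = tf.zeros((1, 64, 64, 3))
#   conv(3, 32)(x).shape            # (1, 64, 64, 32)
#   conv(3, 32, stride=2)(x).shape  # (1, 32, 32, 32)
#   deconv(4, 16, stride=2)(conv(3, 32, stride=2)(x)).shape  # (1, 64, 64, 16)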
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=relative-beyond-top-level

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

from util import logging as logutil
from .base import Network as BaseNetwork


logger = logutil.Logger(loggee="networks/seq")


class Network(BaseNetwork):
    """Assuming simple sequential flow.
    """
    def build(self, input_shape):
        seq = tf.keras.Sequential(self.layers)
        seq.build(input_shape)
        for layer in self.layers:
            assert layer.built, "Some layers not built"

    def __call__(self, tensor):
        x = tensor
        for layer in self.layers:
            y = layer(x)
import numpy as np
from tqdm import tqdm
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
import tensorflow_addons as tfa
import xiuminglib as xm

import losses
from networks import convnet
from util import logging as logutil, io as ioutil, img as imgutil, \
    tensor as tutil
from .base import Model as BaseModel


logger = logutil.Logger(loggee="models/nlt", debug_mode=False)


class Model(BaseModel):
    def __init__(self, config):
        # Needed by Barron loss
        self.imh = config.getint('DEFAULT', 'imh')
        self.imw = config.getint('DEFAULT', 'imw')
        super().__init__(config)
        # Networks
        depth0 = config.getint('DEFAULT', 'depth0')
        depth = config.getint('DEFAULT', 'depth')
        kernel = config.getint('DEFAULT', 'kernel')
        stride = config.getint('DEFAULT', 'stride')
        norm = config.get('DEFAULT', 'norm')
        act = config.get('DEFAULT', 'act')
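# Config sketch (values are made up; only the keys read above are listed):
# the config object is assumed to be a configparser.ConfigParser loaded from
# the experiment's .ini file, with these options under [DEFAULT], e.g.
#
#   [DEFAULT]
#   imh = 256
#   imw = 256
#   depth0 = 32
#   depth = 4
#   kernel = 4
#   stride = 2
#   norm = none
#   act = relu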
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=relative-beyond-top-level

import numpy as np
import tensorflow as tf
tf.compat.v1.enable_eager_execution()

from util import logging as logutil, net as netutil
from .seq import Network as BaseNetwork
from .elements import conv, norm, act, pool, iden, deconv, upconv


logger = logutil.Logger(loggee="networks/convnet")


class Network(BaseNetwork):
    def __init__(
            self, depth0, depth, kernel, stride,
            norm_type=None, act_type='relu', pool_type=None):
        super().__init__()
        norm_type = self.str2none(norm_type)
        pool_type = self.str2none(pool_type)
        min_n_ch = depth0