def test_inference(self):
    """Check that `infer` returns a flow field of shape (256, 256, 2).

    Two random float32 images of shape (256, 256, 3) are passed to a
    default-constructed `SMURFNet`; only the shape of the result is checked.
    """
    height, width = 256, 256
    expected_shape = [height, width, 2]

    # Draw the two frames in the same order on every run of the test.
    first_frame = np.random.randn(height, width, 3).astype(np.float32)
    second_frame = np.random.randn(height, width, 3).astype(np.float32)

    model = smurf_net.SMURFNet()
    predicted_flow = model.infer(first_frame, second_frame)

    # Element-wise comparison of the shape against the expected dimensions.
    shape_matches = np.all(np.equal(predicted_flow.shape, expected_shape))
    self.assertTrue(shape_matches)
def test_train_step(self):
    """Test a single training step with a default-constructed model.

    Feeds one batch of all-zero image pairs through `train_step` and checks
    that a non-empty collection of losses comes back.
    """
    # Slicing along the first axis yields elements of shape
    # [2, 256, 256, 3]; `batch(1)` restores a leading batch dimension.
    ds = tf.data.Dataset.from_tensor_slices({
        'images': tf.zeros([1, 2, 256, 256, 3], dtype=tf.float32),
    }).repeat().batch(1)
    it = iter(ds)
    smurf_model = smurf_net.SMURFNet()
    weights = {'smooth2': 2.0, 'edge_constant': 100.0, 'census': 1.}
    # Use the builtin `next()` rather than the Python 2 style `it.next()`,
    # which is not part of the Python 3 iterator protocol.
    losses = train_step(smurf_model, next(it), weights=weights)
    self.assertNotEmpty(losses)
def build_network(batch_size):
    """Construct a `SMURFNet` configured from command-line flags.

    Args:
      batch_size: First entry of the `size` tuple handed to the model; the
        remaining two entries come from `FLAGS.height` and `FLAGS.width`.

    Returns:
      The `smurf_net.SMURFNet` instance built from the flag values, with
      dropout fixed at 0 and `use_float16` enabled.
    """
    input_size = (batch_size, FLAGS.height, FLAGS.width)
    model_kwargs = {
        'checkpoint_dir': FLAGS.checkpoint_dir,
        'optimizer': FLAGS.optimizer,
        'dropout_rate': 0.,
        'feature_architecture': FLAGS.feature_architecture,
        'flow_architecture': FLAGS.flow_architecture,
        'size': input_size,
        'occlusion_estimation': FLAGS.occlusion_estimation,
        'smoothness_at_level': FLAGS.smoothness_at_level,
        'use_float16': True,
    }
    return smurf_net.SMURFNet(**model_kwargs)
def test_sequence_unsupervised_train_step(self):
    """Test a single sequence-unsupervised training step.

    Builds the model with `train_mode='sequence-unsupervised'` and RAFT
    feature/flow architectures, runs one `train_step` on an all-zero image
    pair, and checks that losses are returned and that the 'census-loss'
    entry is strictly positive.
    """
    # Slicing along the first axis yields elements of shape
    # [2, 256, 256, 3]; `batch(1)` restores a leading batch dimension.
    ds = tf.data.Dataset.from_tensor_slices({
        'images': tf.zeros([1, 2, 256, 256, 3], dtype=tf.float32),
    }).repeat().batch(1)
    it = iter(ds)
    smurf_model = smurf_net.SMURFNet(
        train_mode='sequence-unsupervised',
        flow_architecture='raft',
        feature_architecture='raft',
    )
    weights = {'smooth2': 2.0, 'edge_constant': 100.0, 'census': 1.}
    # Use the builtin `next()` rather than the Python 2 style `it.next()`,
    # which is not part of the Python 3 iterator protocol.
    losses = train_step(smurf_model, next(it), weights=weights)
    self.assertNotEmpty(losses)
    self.assertGreater(losses['census-loss'], 0)
def test_supervised_train_step(self):
    """Test a single supervised training step.

    Supplies an all-zero image pair together with an all-zero ground-truth
    flow and an all-ones validity mask, runs one `train_step` on a model
    built with `train_mode='supervised'`, and checks that losses are
    returned and that the 'supervision-loss' entry is strictly positive.
    """
    # Each tensor is sliced along its first axis and then re-batched with
    # `batch(1)`, so every element keeps a leading batch dimension of 1.
    ds = tf.data.Dataset.from_tensor_slices({
        'images': tf.zeros([1, 2, 256, 256, 3], dtype=tf.float32),
        'flow': tf.zeros([1, 256, 256, 2], dtype=tf.float32),
        'flow_valid': tf.ones([1, 256, 256, 1], dtype=tf.float32),
    }).repeat().batch(1)
    it = iter(ds)
    smurf_model = smurf_net.SMURFNet(train_mode='supervised')
    weights = {'supervision': 1.0}
    # Use the builtin `next()` rather than the Python 2 style `it.next()`,
    # which is not part of the Python 3 iterator protocol.
    losses = train_step(smurf_model, next(it), weights=weights)
    self.assertNotEmpty(losses)
    self.assertGreater(losses['supervision-loss'], 0)