def test_memcached_cache_tracing_with_a_wrong_connection(self):
    """A ``get`` on a misconfigured memcached backend still yields one span.

    Any exception raised by the cache call is ignored on purpose; the test
    only inspects the span recorded by the dummy writer.
    """
    # tracer backed by an in-memory writer so the spans can be inspected
    tracer = Tracer()
    sink = DummyWriter()
    tracer.writer = sink

    # traced cache bound to a fresh Flask app
    traced_cache_cls = get_traced_cache(tracer, service=self.SERVICE)
    flask_app = Flask(__name__)
    cache = traced_cache_cls(
        flask_app,
        config={
            'CACHE_TYPE': 'memcached',
            'CACHE_MEMCACHED_SERVERS': ['localhost:2230'],
        },
    )

    # use a wrong memcached connection; whatever it raises is not under test
    try:
        cache.get(u'á_complex_operation')
    except Exception:
        pass

    # ensure that the error is not caused by our tracer
    recorded = sink.pop()
    assert len(recorded) == 1
    cache_span = recorded[0]
    assert cache_span.service == self.SERVICE
    assert cache_span.resource == 'get'
    assert cache_span.name == 'flask_cache.cmd'
    assert cache_span.span_type == 'cache'
    assert cache_span.meta[CACHE_BACKEND] == 'memcached'
    assert cache_span.meta[net.TARGET_HOST] == 'localhost'
    assert cache_span.metrics[net.TARGET_PORT] == 2230
def test_redis_cache_tracing_with_a_wrong_connection(self):
    """A ``get`` on a misconfigured redis backend raises and records an error span.

    The call must raise ``ConnectionError`` mentioning the configured
    host/port, and exactly one span flagged as an error must be written.
    """
    # tracer backed by an in-memory writer so the spans can be inspected
    tracer = Tracer()
    sink = DummyWriter()
    tracer.writer = sink

    # traced cache bound to a fresh Flask app
    traced_cache_cls = get_traced_cache(tracer, service=self.SERVICE)
    flask_app = Flask(__name__)
    redis_config = {
        'CACHE_TYPE': 'redis',
        'CACHE_REDIS_PORT': 2230,
        'CACHE_REDIS_HOST': '127.0.0.1'
    }
    cache = traced_cache_cls(flask_app, config=redis_config)

    # use a wrong redis connection
    with pytest.raises(ConnectionError) as excinfo:
        cache.get(u'á_complex_operation')

    # ensure that the error is not caused by our tracer
    message = excinfo.value.args[0]
    assert '127.0.0.1:2230. Connection refused.' in message

    # an error trace must be sent
    recorded = sink.pop()
    assert len(recorded) == 1
    cache_span = recorded[0]
    assert cache_span.service == self.SERVICE
    assert cache_span.resource == 'get'
    assert cache_span.name == 'flask_cache.cmd'
    assert cache_span.span_type == 'cache'
    assert cache_span.meta[CACHE_BACKEND] == 'redis'
    assert cache_span.meta[net.TARGET_HOST] == '127.0.0.1'
    assert cache_span.metrics[net.TARGET_PORT] == 2230
    assert cache_span.error == 1
def test_cache_add_without_arguments(self):
    """Calling ``cache.add()`` with no arguments raises and records an error span."""
    # tracer backed by an in-memory writer so the spans can be inspected
    tracer = Tracer()
    sink = DummyWriter()
    tracer.writer = sink

    # traced cache bound to a fresh Flask app
    traced_cache_cls = get_traced_cache(tracer, service=self.SERVICE)
    flask_app = Flask(__name__)
    cache = traced_cache_cls(flask_app, config={'CACHE_TYPE': 'simple'})

    # make a wrong call
    with pytest.raises(TypeError) as excinfo:
        cache.add()

    # ensure that the error is not caused by our tracer
    message = excinfo.value.args[0]
    assert 'add()' in message
    assert 'argument' in message

    # an error trace must be sent
    recorded = sink.pop()
    assert len(recorded) == 1
    cache_span = recorded[0]
    assert cache_span.service == self.SERVICE
    assert cache_span.resource == 'add'
    assert cache_span.name == 'flask_cache.cmd'
    assert cache_span.span_type == 'cache'
    assert cache_span.error == 1
def test_filters():
    """Exercise the tracer ``FILTERS`` setting with dropping and mutating filters."""
    tracer = ddtrace.Tracer()

    class FilterAll(object):
        """Filter that discards every trace."""

        def process_trace(self, trace):
            return None

    class FilterMutate(object):
        """Filter that tags every span of a trace with ``key=value``."""

        def __init__(self, key, value):
            self.key = key
            self.value = value

        def process_trace(self, trace):
            for span in trace:
                span.set_tag(self.key, self.value)
            return trace

    def collect_spans(filters):
        # reconfigure, swap in a fresh writer, emit root+child, return what was written
        tracer.configure(settings={"FILTERS": filters})
        tracer.writer = DummyWriter()
        with tracer.trace("root"):
            with tracer.trace("child"):
                pass
        return tracer.writer.pop()

    # A filter returning None drops the whole trace
    assert len(collect_spans([FilterAll()])) == 0

    # A single mutating filter tags both spans
    tagged = collect_spans([FilterMutate("boop", "beep")])
    assert len(tagged) == 2
    first, second = tagged
    assert first.get_tag("boop") == "beep"
    assert second.get_tag("boop") == "beep"

    # Test multiple filters
    tagged = collect_spans(
        [FilterMutate("boop", "beep"), FilterMutate("mats", "sundin")]
    )
    assert len(tagged) == 2
    for span in tagged:
        assert span.get_tag("boop") == "beep"
        assert span.get_tag("mats") == "sundin"
def test_VIB_memorize_minibatch(self):
    """Train VIB on a 25-sample minibatch and check the final training loss.

    Runs once per covariance type; the loss reported on the writer must end
    up strictly between 0.0 and 0.1.
    """
    for cov_type in ('diag', 'full'):
        writer = DummyWriter()
        net_kwargs = {
            'out_dim': 10,
            'nonlinearity': 'elu',
            'batch_norm': True,
            'dropout': False,
        }
        model_kwargs = {
            'net_name': 'SmallMLP',
            'net_kwargs': net_kwargs,
            'covariance_type': cov_type,
            'beta': 0.001,
            'n_mixture_components': 2,
            'train_var_dist_samples': 9,
            'test_var_dist_samples': 12,
        }
        train_kwargs = {
            'seed': 632,
            'dataset_name': 'CIFAR10',
            'model_class_name': 'VIB',
            'model_kwargs': model_kwargs,
            'normalize_inputs': True,
            'batch_size': 25,
            'train_size': 25,
            'val_size': 0,
            'epochs': 100,
            'total_batches': None,
            'optimizer_class_name': 'Adam',
            'optimizer_kwargs': {'lr': 0.001},
            'lr_scheduler_class_name': 'ExponentialLR',
            'lr_scheduler_kwargs': {'gamma': 1.0},
            'device_id': 'cuda:0',
        }
        train_model(writer, **train_kwargs)
        self.assertLess(writer.train_loss, 0.1)
        self.assertGreater(writer.train_loss, 0.0)
def test_MASS_var_dist_SGD_update_memorize_minibatch(self):
    """Train MASSCE on a 25-sample minibatch and check the loss drops below 0.1."""
    writer = DummyWriter()
    net_kwargs = {
        'out_dim': 10,
        'nonlinearity': 'elu',
        'batch_norm': True,
        'dropout': False,
    }
    model_kwargs = {
        'net_name': 'SmallMLP',
        'net_kwargs': net_kwargs,
        'var_dist_init_strategy': 'standard_basis',
        'beta': 0.001,
        'n_mixture_components': 2,
    }
    # the variational distribution gets its own optimizer settings
    optimizer_kwargs = {
        'lr': 0.001,
        'var_dist_optimizer_kwargs': {'lr': 0.001},
    }
    train_kwargs = {
        'seed': 632,
        'dataset_name': 'CIFAR10',
        'model_class_name': 'MASSCE',
        'model_kwargs': model_kwargs,
        'normalize_inputs': True,
        'batch_size': 25,
        'train_size': 25,
        'val_size': 0,
        'epochs': 50,
        'total_batches': None,
        'optimizer_class_name': 'Adam',
        'optimizer_kwargs': optimizer_kwargs,
        'lr_scheduler_class_name': 'ExponentialLR',
        'lr_scheduler_kwargs': {'gamma': 1.0},
        'device_id': 'cuda:0',
    }
    train_model(writer, **train_kwargs)
    self.assertLess(writer.train_loss, 0.1)
def test_MASS_var_dist_standard_basis_init(self):
    """Check the ``standard_basis`` initialization of the MASSCE variational distribution.

    All variational parameters must start at zero and require grad; after
    ``initialize`` the per-component locations, logits and scale factors
    must match the values asserted below.
    """
    writer = DummyWriter()
    net_kwargs = {
        'out_dim': 8,
        'nonlinearity': 'elu',
        'batch_norm': True,
        'dropout': False,
        'in_shape': (1, 28, 28),
    }
    model = MASSCE(
        writer,
        net_name='SmallMLP',
        net_kwargs=net_kwargs,
        var_dist_init_strategy='standard_basis',
        beta=0.9,
        n_mixture_components=2,
        n_classes=10,
    )

    # Make sure everything started as zeros
    for param in model.var_dist.parameters():
        self.assertEqual((param != 0).sum(), 0)
        self.assertTrue(param.requires_grad)

    # Make sure initialization worked
    model.initialize(None)
    for idx, mixture in enumerate(model.var_dist.q):
        self.assertEqual((mixture.mixture_logits != 0).sum(), 0)
        locations = mixture.loc
        if idx < 8:
            self.assertEqual((locations > 5).sum(), 1)
            self.assertEqual((locations < -5).sum(), 0)
            self.assertTrue(locations[0, idx] == 10)
        else:
            self.assertEqual((locations < -10).sum(), 0)
            self.assertEqual((locations > 10).sum(), 0)
        self.assertEqual((mixture.scale_tril != 0).sum(), 8 * 2)
        self.assertTrue(mixture.scale_tril[0, 0, 0] == 1)
def test_ctx_distributed():
    """Activate an empty and then a populated ``Context`` and check the spans started under each."""
    tracer = ddtrace.Tracer()
    tracer.writer = DummyWriter()

    # Test activating an invalid context.
    empty_ctx = Context(span_id=None, trace_id=None)
    tracer.context_provider.activate(empty_ctx)
    assert tracer.current_span() is None

    with tracer.trace("test") as local_root:
        call_ctx = tracer.get_call_context()
        assert tracer.current_span() == local_root
        assert tracer.current_root_span() == local_root
        assert call_ctx.trace_id == local_root.trace_id
        assert tracer.get_call_context().span_id == local_root.span_id
        assert local_root.parent_id is None

    written = tracer.writer.pop_traces()
    assert len(written) == 1

    # Test activating a valid context.
    remote_ctx = Context(
        span_id=1234,
        trace_id=4321,
        sampling_priority=2,
        dd_origin="somewhere",
    )
    tracer.context_provider.activate(remote_ctx)
    assert tracer.current_span() is None

    with tracer.trace("test2") as continued:
        assert tracer.current_span() == continued
        assert tracer.current_root_span() == continued
        assert tracer.get_call_context().trace_id == continued.trace_id == 4321
        assert tracer.get_call_context().span_id == continued.span_id
        assert continued.parent_id == 1234

    written = tracer.writer.pop_traces()
    assert len(written) == 1
    assert continued.metrics[SAMPLING_PRIORITY_KEY] == 2
    assert continued.meta[ORIGIN_KEY] == "somewhere"
def test_tracer_trace_across_fork():
    """
    When a trace is started in a parent process and a child process is spawned
        The trace should be continued in the child process
    """
    tracer = ddtrace.Tracer()
    tracer.writer = DummyWriter()

    def task(tracer, q):
        # Runs in the child process: record one span and report its ids
        tracer.writer = DummyWriter()
        with tracer.trace("child"):
            pass
        spans = tracer.writer.pop()
        q.put([dict(trace_id=s.trace_id, parent_id=s.parent_id) for s in spans])

    # Assert tracer in a new process correctly recreates the writer
    q = multiprocessing.Queue()
    with tracer.trace("parent") as parent:
        p = multiprocessing.Process(target=task, args=(tracer, q))
        p.start()
        # Drain the queue before joining: a process that has put items on a
        # Queue may not exit until they are consumed, so join-then-get can
        # deadlock (see "Joining processes that use queues").
        children = q.get()
        p.join()

    assert len(children) == 1
    (child,) = children
    assert parent.trace_id == child["trace_id"]
    assert child["parent_id"] == parent.span_id
def main(kwargs):
    """Select the requested CUDA device (if any) and start the evaluation.

    Args:
        kwargs: keyword arguments forwarded unchanged to ``start_evaluating``.
            Must contain a ``'device_id'`` string such as ``'cuda:0'``.

    Returns:
        Whatever ``start_evaluating`` returns.
    """
    device_id = kwargs['device_id']
    # Workaround for pytorch bug where multiple gpu processes all like to use gpu0
    if 'cuda' in device_id and torch.cuda.is_available():
        # Use every trailing digit, not just the last character, so that
        # multi-digit indices such as 'cuda:10' select GPU 10 rather than 0.
        index_digits = device_id[len(device_id.rstrip('0123456789')):]
        # A bare 'cuda' carries no index; leave the current device untouched
        # instead of failing on int() of a non-digit character.
        if index_digits:
            torch.cuda.set_device(int(index_digits))
    writer = DummyWriter()
    return start_evaluating(writer, **kwargs)
def task(tracer, q):
    """Record a ``child1`` span, fork a grandchild that records ``child2``, and report both.

    Args:
        tracer: tracer whose writer is replaced by a fresh ``DummyWriter``.
        q: queue receiving a list of dicts: this process's spans
            (``trace_id``, ``parent_id``, ``span_id``) followed by the
            grandchild's spans (``trace_id``, ``parent_id``).
    """
    tracer.writer = DummyWriter()

    def task2(tracer, q):
        # Runs in the grandchild process: record one span and report its ids
        tracer.writer = DummyWriter()
        with tracer.trace("child2"):
            pass
        spans = tracer.writer.pop()
        q.put([
            dict(trace_id=s.trace_id, parent_id=s.parent_id) for s in spans
        ])

    with tracer.trace("child1"):
        q2 = multiprocessing.Queue()
        p = multiprocessing.Process(target=task2, args=(tracer, q2))
        p.start()
        # Drain the queue before joining: a process that has put items on a
        # Queue may not exit until they are consumed, so join-then-get can
        # deadlock.
        task2_spans = q2.get()
        p.join()

    spans = tracer.writer.pop()
    q.put([
        dict(trace_id=s.trace_id, parent_id=s.parent_id, span_id=s.span_id)
        for s in spans
    ] + task2_spans)
def test_MASS_var_dist_random_init(self):
    """Check the ``random`` initialization of the MASSCE variational distribution.

    Parameters must start at zero and require grad; after ``initialize``
    every component's logits, locations and scale factors must stay inside
    the bounds asserted below.
    """
    writer = DummyWriter()
    net_kwargs = {
        'out_dim': 10,
        'nonlinearity': 'elu',
        'batch_norm': True,
        'dropout': False,
        'in_shape': (1, 28, 28),
    }
    model = MASSCE(
        writer,
        net_name='SmallMLP',
        net_kwargs=net_kwargs,
        var_dist_init_strategy='random',
        beta=0.9,
        n_mixture_components=10,
        n_classes=10,
    )

    # Make sure everything started as zeros
    for param in model.var_dist.parameters():
        self.assertEqual((param != 0).sum(), 0)
        self.assertTrue(param.requires_grad)

    # Make sure initialization worked
    model.initialize(None)
    for mixture in model.var_dist.q:
        logits = mixture.mixture_logits
        self.assertEqual((logits > 0.1).sum(), 0)
        self.assertEqual((logits < -0.1).sum(), 0)
        locations = mixture.loc
        self.assertEqual((locations < -10).sum(), 0)
        self.assertEqual((locations > 10).sum(), 0)
        scales = mixture.scale_tril
        self.assertEqual((scales < -1).sum(), 0)
        self.assertEqual((scales > 1).sum(), 0)
def task(tracer, q):
    """Record a single ``child`` span and put its trace/parent ids on ``q``."""
    tracer.writer = DummyWriter()
    with tracer.trace("child"):
        pass

    # one dict per written span
    payload = []
    for span in tracer.writer.pop():
        payload.append(dict(trace_id=span.trace_id, parent_id=span.parent_id))
    q.put(payload)
def test_tracer_is_properly_configured(self):
    """Check the tracer's preconfigured settings, then that "keep" is written and "drop" is not."""
    tracer = self.tracer

    # the tracer must be properly configured
    assert tracer.tags == {"env": "production", "debug": "false"}
    assert tracer.enabled is False
    assert tracer.writer.agent_url == "http://dd-agent.service.consul:8126"

    sink = DummyWriter()
    tracer.configure(enabled=True, writer=sink)

    def spans_written_for(name):
        # open and close one span, then return whatever reached the writer
        with tracer.trace(name):
            pass
        return sink.pop()

    assert len(spans_written_for("keep")) == 1
    # NOTE(review): presumably a filter configured outside this block removes "drop" - confirm
    assert len(spans_written_for("drop")) == 0
def test_ctx():
    """Check current/root span tracking and parent ids across nested spans.

    Span layout: ``test`` (s1) contains ``test2`` (s2) and ``test4`` (s4);
    ``test2`` contains ``test3`` (s3).
    """
    tracer = ddtrace.Tracer()
    tracer.writer = DummyWriter()

    with tracer.trace("test") as s1:
        assert tracer.current_span() == s1
        assert tracer.current_root_span() == s1
        assert tracer.get_call_context().trace_id == s1.trace_id
        assert tracer.get_call_context().span_id == s1.span_id

        with tracer.trace("test2") as s2:
            assert tracer.current_span() == s2
            assert tracer.current_root_span() == s1
            assert tracer.get_call_context().trace_id == s1.trace_id
            assert tracer.get_call_context().span_id == s2.span_id

            with tracer.trace("test3") as s3:
                assert tracer.current_span() == s3
                assert tracer.current_root_span() == s1
                assert tracer.get_call_context().trace_id == s1.trace_id
                assert tracer.get_call_context().span_id == s3.span_id

            # s3 is closed: the call context points back at s2
            assert tracer.get_call_context().trace_id == s1.trace_id
            assert tracer.get_call_context().span_id == s2.span_id

        # sibling of s2, direct child of s1
        with tracer.trace("test4") as s4:
            assert tracer.current_span() == s4
            assert tracer.current_root_span() == s1
            assert tracer.get_call_context().trace_id == s1.trace_id
            assert tracer.get_call_context().span_id == s4.span_id

        assert tracer.current_span() == s1
        assert tracer.current_root_span() == s1

    # everything is closed: no active span remains
    assert tracer.current_span() is None
    assert tracer.current_root_span() is None
    assert s1.parent_id is None
    assert s2.parent_id == s1.span_id
    assert s3.parent_id == s2.span_id
    assert s4.parent_id == s1.span_id
    assert s1.trace_id == s2.trace_id == s3.trace_id == s4.trace_id
    # only the root span carries the sampling priority metric
    assert s1.metrics[SAMPLING_PRIORITY_KEY] == 1
    assert SAMPLING_PRIORITY_KEY not in s2.metrics
    assert ORIGIN_KEY not in s1.meta

    # the four spans are written as one trace, in start order
    t = tracer.writer.pop_traces()
    assert len(t) == 1
    assert len(t[0]) == 4
    _s1, _s2, _s3, _s4 = t[0]
    assert s1 == _s1
    assert s2 == _s2
    assert s3 == _s3
    assert s4 == _s4

    # a span started afterwards begins a new trace
    with tracer.trace("s") as s:
        assert s.parent_id is None
        assert s.trace_id != s1.trace_id
def test_MASS_var_dist_MLE_from_training_data_init(self):
    """
    Make a dummy training dataset where the first half of the inputs are all
    ones and have label 0, and the second half are all zeros and have label 1.
    Then overwrite the model weights so it spits out normally-distributed
    outputs. The test makes sure the variational distribution of outputs the
    model learns is correct.
    """
    for out_dim in [8, 10, 12]:
        # first 128000 samples: all-ones input, label 0 (target defaults to zero)
        data = torch.zeros(256000, 1, 28, 28)
        data[:128000, :, :, :] = 1
        # last 128000 samples: all-zeros input, label 1
        target = torch.zeros(256000, dtype=torch.int64)
        target[128000:] = 1
        writer = DummyWriter()
        model_kwargs = dict(
            net_name='SmallMLP',
            net_kwargs=dict(
                out_dim=out_dim,
                nonlinearity='tanh',
                batch_norm=False,
                dropout=False,
                in_shape=(1, 28, 28),
            ),
            var_dist_init_strategy='MLE_from_training_data',
            beta=0.9,
            n_mixture_components=1,
            n_classes=2,
        )
        model = MASSCE(writer, **model_kwargs)
        # overwrite fcout so the outputs are normally distributed
        # (each output is sampled around the sum of the 200 features fed to fcout)
        model.net.fcout.forward = lambda x: torch.normal(
            torch.matmul(x, torch.ones(200, out_dim)))
        # force every weight to 1 and every bias to 0
        for name, p in model.net.named_parameters():
            if 'weight' in name:
                p.data[:] = 1
            elif 'bias' in name:
                p.data[:] = 0
        # Make sure everything started as zeros
        for p in model.var_dist.parameters():
            self.assertEqual((p != 0).sum(), 0)
            self.assertTrue(p.requires_grad)
        # Make sure initialization worked
        model.initialize([(data, target)])
        for i, mog in enumerate(model.var_dist.q):
            # NOTE(review): 200 presumably comes from 200 saturated tanh
            # features for the all-ones inputs of class 0 - confirm against SmallMLP
            if i == 0:
                mean = 200
            else:
                mean = 0
            self.assertEqual(mog.mixture_logits, 0)
            np.testing.assert_array_almost_equal(mean * np.ones(
                (1, out_dim)), mog.loc.detach().numpy(), decimal=2)
            np.testing.assert_array_almost_equal(
                np.eye(out_dim),
                mog.scale_tril.squeeze(0).detach().numpy(),
                decimal=2)
def test_manual_keep_then_drop():
    """A manual drop set on the root after a manual keep on the child ends in ``USER_REJECT``."""
    tracer = Tracer()
    tracer.writer = DummyWriter()

    # Test changing the value before finish.
    with tracer.trace("asdf") as root:
        with tracer.trace("child") as child:
            child.set_tag(MANUAL_KEEP_KEY)
        root.set_tag(MANUAL_DROP_KEY)

    spans = tracer.writer.pop()
    # Compare by value: identity (`is`) on numbers only works by interning accident
    assert spans[0].metrics[SAMPLING_PRIORITY_KEY] == priority.USER_REJECT
def test_tracer_trace_across_multiple_forks():
    """
    When a trace is started and crosses multiple process boundaries
        The trace should be continued in the child processes
    """
    tracer = ddtrace.Tracer()
    tracer.writer = DummyWriter()

    # Start a span in this process then start a child process which itself
    # starts a span and spawns another child process which starts a span.
    def task(tracer, q):
        tracer.writer = DummyWriter()

        def task2(tracer, q):
            # Runs in the grandchild process: record one span and report its ids
            tracer.writer = DummyWriter()
            with tracer.trace("child2"):
                pass
            spans = tracer.writer.pop()
            q.put([
                dict(trace_id=s.trace_id, parent_id=s.parent_id) for s in spans
            ])

        with tracer.trace("child1"):
            q2 = multiprocessing.Queue()
            p = multiprocessing.Process(target=task2, args=(tracer, q2))
            p.start()
            # Drain the queue before joining: a process that has put items
            # on a Queue may not exit until they are consumed, so
            # join-then-get can deadlock.
            task2_spans = q2.get()
            p.join()

        spans = tracer.writer.pop()
        q.put([
            dict(trace_id=s.trace_id, parent_id=s.parent_id, span_id=s.span_id)
            for s in spans
        ] + task2_spans)

    # Assert tracer in a new process correctly recreates the writer
    q = multiprocessing.Queue()
    with tracer.trace("parent") as parent:
        p = multiprocessing.Process(target=task, args=(tracer, q))
        p.start()
        # Same ordering as above: consume the result first, then join
        children = q.get()
        p.join()

    assert len(children) == 2
    child1, child2 = children
    assert parent.trace_id == child1["trace_id"] == child2["trace_id"]
    assert child1["parent_id"] == parent.span_id
    assert child2["parent_id"] == child1["span_id"]
def test_get_report_hostname_default(get_hostname):
    """With ``report_hostname=False`` neither the root nor the child span has a hostname tag."""
    get_hostname.return_value = "test-hostname"
    tracer = Tracer()
    tracer.writer = DummyWriter()

    # config override, root span and child span entered in that order
    with override_global_config(dict(report_hostname=False)), \
            tracer.trace("span"), \
            tracer.trace("child"):
        pass

    written = tracer.writer.pop()
    root_span, child_span = written[0], written[1]
    assert root_span.get_tag(HOSTNAME_KEY) is None
    assert child_span.get_tag(HOSTNAME_KEY) is None
def test_manual_drop():
    """Setting the manual-drop tag on a root or a child span yields ``USER_REJECT`` on the first span."""
    tracer = Tracer()
    tracer.writer = DummyWriter()

    # On a root span
    with tracer.trace("asdf") as s:
        s.set_tag(MANUAL_DROP_KEY)
    spans = tracer.writer.pop()
    # Compare by value: identity (`is`) on numbers only works by interning accident
    assert spans[0].metrics[SAMPLING_PRIORITY_KEY] == priority.USER_REJECT

    # On a child span
    with tracer.trace("asdf"):
        with tracer.trace("child") as s:
            s.set_tag(MANUAL_DROP_KEY)
    spans = tracer.writer.pop()
    assert spans[0].metrics[SAMPLING_PRIORITY_KEY] == priority.USER_REJECT
def test_diffable_option(self, input):
    """
    The jacobian may only take part in a loss that is backpropagated when it
    was built with diffable=True; without it, backward must raise.
    """
    input = torch.tensor(input).requires_grad_()
    mlp = SmallMLP(DummyWriter(), True, (784, ), 10, 'elu', False, False)
    mlp.eval()
    output = mlp(input)

    # default jacobian: backpropagating through it is expected to fail
    with self.assertRaises(RuntimeError):
        plain_jac = jacobian(input, output)
        plain_loss = plain_jac.sum()
        plain_loss.backward()

    # diffable jacobian: the same sequence must succeed
    diffable_jac = jacobian(input, output, diffable=True)
    diffable_loss = diffable_jac.sum()
    diffable_loss.backward()
def test_gradients_computed_at_right_time(self, input):
    """
    Make sure the jacobian is having gradients computed only after calling
    backward on it
    """
    input = torch.tensor(input).requires_grad_()
    net = SmallMLP(DummyWriter(), True, (784, ), 10, 'elu', False, False)
    net.eval()
    output = net(input)
    J = jacobian(input, output, diffable=True)
    loss = J.sum()

    # Nothing may have a gradient before backward() is called
    self.assertIsNone(input.grad)
    # Identity check: `== None` would be delegated to the tensor's __eq__
    # whenever a grad is present
    self.assertTrue(all(p.grad is None for p in net.parameters()),
                    'parameter had grad before backward')

    loss.backward()
    self.assertIsNotNone(input.grad)
    # The last parameter is deliberately left out of the check
    # NOTE(review): presumably the final bias, which the jacobian does not depend on - confirm
    self.assertTrue(
        all(list(p.grad is not None for p in net.parameters())[:-1]),
        'parameters missing grads after backward')
def test_early_exit():
    """Finishing a parent span before its child still yields one two-span trace."""
    t = ddtrace.Tracer()
    t.writer = DummyWriter()
    s1 = t.trace("1")
    s2 = t.trace("2")
    # the parent is finished before its child
    s1.finish()
    s2.finish()

    assert s1.parent_id is None
    # Compare ids by value; `is` on integers relies on object identity
    assert s2.parent_id == s1.span_id

    traces = t.writer.pop_traces()
    assert len(traces) == 1
    assert len(traces[0]) == 2

    # spans started afterwards are roots again
    s1 = t.trace("1-1")
    s1.finish()
    assert s1.parent_id is None

    s1 = t.trace("1-2")
    s1.finish()
    assert s1.parent_id is None
def test_VIB_init(self):
    """For both covariance types, the marginal's ``scale_tril`` has a zero strict upper triangle."""
    for cov_type in ('diag', 'full'):
        writer = DummyWriter()
        net_kwargs = {
            'out_dim': 10,
            'nonlinearity': 'elu',
            'batch_norm': True,
            'dropout': False,
            'in_shape': (1, 28, 28),
        }
        model = VIB(
            writer,
            net_name='SmallMLP',
            net_kwargs=net_kwargs,
            covariance_type=cov_type,
            beta=0.001,
            n_mixture_components=2,
            train_var_dist_samples=8,
            test_var_dist_samples=12,
            n_classes=10,
        )
        scale_factors = model.marginal.scale_tril.detach()
        for factor in scale_factors:
            strict_upper = torch.triu(factor, diagonal=1)
            np.testing.assert_array_equal(0, strict_upper)
def test_jacobian_util_gradcheck_SmallMLP(self, input):
    """ Finite differences gradient check on jacobian """
    input = torch.DoubleTensor(input).requires_grad_()
    mlp = SmallMLP(DummyWriter(), True, (784, ), 10, 'elu', False, False)
    mlp = mlp.to(torch.float64)
    mlp.eval()
    output = mlp(input)
    jac = jacobian(input, output)

    class TestFunction(torch.autograd.Function):
        # forward runs the network; backward uses the precomputed jacobian
        # so that gradcheck compares it against finite differences

        @staticmethod
        def forward(ctx, i):
            return mlp(i)

        @staticmethod
        def backward(ctx, grad_output):
            row_grads = grad_output.unsqueeze(1)
            return torch.bmm(row_grads, jac).squeeze(1)

    self.assertTrue(gradcheck(TestFunction.apply, [input]))
def test_VIB_encode(self):
    """Check the shapes of the network output and of ``encode``'s mean/std.

    For ``'diag'`` the network emits ``2 * rep_dim`` values per sample; for
    ``'full'`` it emits ``rep_dim`` plus the ``rep_dim * (rep_dim + 1) / 2``
    entries of a lower-triangular factor.
    """
    n_samples, n_batch, rep_dim = (9, 256, 10)
    in_shape = (3, 32, 32)
    for covariance_type in ['diag', 'full']:
        writer = DummyWriter()
        model_kwargs = dict(
            net_name='SmallMLP',
            net_kwargs=dict(
                out_dim=rep_dim,
                nonlinearity='elu',
                batch_norm=True,
                dropout=False,
                in_shape=in_shape,
            ),
            covariance_type=covariance_type,
            beta=0.001,
            n_mixture_components=2,
            train_var_dist_samples=n_samples,
            test_var_dist_samples=12,
            n_classes=10,
        )
        model = VIB(writer, **model_kwargs)
        data = torch.ones(n_batch, *in_shape)
        output = model.net.forward(data)
        mean, std = model.encode(output)
        self.assertEqual(tuple(mean.shape), (n_batch, rep_dim))
        if covariance_type == 'diag':
            self.assertEqual(tuple(output.shape), (n_batch, 2 * rep_dim))
            self.assertEqual(tuple(std.shape), (n_batch, rep_dim))
        elif covariance_type == 'full':
            # Floor division keeps the expected dimension an int;
            # rep_dim * (rep_dim + 1) is always even, so nothing is lost.
            n_tril_entries = rep_dim * (rep_dim + 1) // 2
            self.assertEqual(tuple(output.shape),
                             (n_batch, rep_dim + n_tril_entries))
            self.assertEqual(tuple(std.shape), (n_batch, rep_dim, rep_dim))
def test_multithreaded():
    """Ten threads tracing at once must each produce their own three-span trace."""
    tracer = ddtrace.Tracer()
    tracer.writer = DummyWriter()

    def target():
        # one root span with two sequential children
        with tracer.trace("s1"):
            with tracer.trace("s2"):
                pass
            with tracer.trace("s3"):
                pass

    for _ in range(1000):
        workers = [threading.Thread(target=target) for _ in range(10)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        written = tracer.writer.pop_traces()
        assert len(written) == 10
        for spans in written:
            assert len(spans) == 3
"""
An integration test that uses a real Redis client
that we expect to be implicitly traced via `ddtrace-run`
"""

import redis

from ddtrace import Pin
from tests import DummyWriter
from tests.contrib.config import REDIS_CONFIG

if __name__ == "__main__":
    r = redis.Redis(port=REDIS_CONFIG["port"])
    # The client must already carry a Pin; this script never patches redis itself
    pin = Pin.get_from(r)
    assert pin

    # Capture spans in memory instead of sending them anywhere
    pin.tracer.writer = DummyWriter()
    r.flushall()
    spans = pin.tracer.writer.pop()

    assert len(spans) == 1
    assert spans[0].service == "redis"
    assert spans[0].resource == "FLUSHALL"

    # A very long command must still produce exactly one span
    long_cmd = "mget %s" % " ".join(map(str, range(1000)))
    r.execute_command(long_cmd)

    spans = pin.tracer.writer.pop()
    assert len(spans) == 1
    span = spans[0]
    assert span.service == "redis"
    assert span.name == "redis.command"
def test_var_dist_SGD_update(self):
    """
    Make sure SGD is actually updating the parameters of the variational
    distribution according to the var_dist_lr
    """
    # lr 0 must leave the variational parameters untouched; 1e-3 must move them
    for var_dist_lr in [0, 1e-3]:
        data = torch.zeros(256, 1, 28, 28).uniform_(0, 1)
        target = torch.zeros(256, dtype=torch.int64).random_(0, 10)
        writer = DummyWriter()
        optimizer_class_name = 'SGD'
        optimizer_kwargs = dict(
            lr=1e-3,
            momentum=0.9,
            var_dist_optimizer_kwargs=dict(lr=var_dist_lr))
        model_kwargs = dict(
            net_name='SmallMLP',
            net_kwargs=dict(
                out_dim=10,
                nonlinearity='elu',
                batch_norm=True,
                dropout=False,
            ),
            var_dist_init_strategy='standard_basis',
            beta=0.1,
            n_mixture_components=3,
        )
        model_kwargs['n_classes'] = 10
        model_kwargs['net_kwargs']['in_shape'] = (1, 28, 28)
        model = MASSCE(writer, **model_kwargs)
        model.initialize(None)
        # snapshot parameter values and the identity of each component object
        orig_var_dist_params = copy.deepcopy(
            list(model.var_dist.named_parameters()))
        orig_var_dist_q_ids = [id(q_i) for q_i in model.var_dist.q]
        optimizer = model.get_optimizer(optimizer_class_name,
                                        optimizer_kwargs)
        model.train()
        optimizer.zero_grad()
        output = model.net.forward(data)
        loss = model.loss(data, output, target)
        # Make sure gradients start not initialized
        for nparam in model.var_dist.named_parameters():
            self.assertIsNone(nparam[1].grad)
        loss.backward()
        for nparam in model.var_dist.named_parameters():
            self.assertIsNotNone(nparam[1].grad)
        # Make sure none of the parameters changed values
        # (backward alone must not modify them, only step() may)
        for orig_nparam, nparam in zip(orig_var_dist_params,
                                       model.var_dist.named_parameters()):
            self.assertEqual(orig_nparam[0], nparam[0])
            np.testing.assert_array_equal(orig_nparam[1].detach(),
                                          nparam[1].detach())
        self.assertEqual(orig_var_dist_q_ids,
                         [id(q_i) for q_i in model.var_dist.q])
        optimizer.step()
        # Make sure var_dist parameters changed values
        for orig_nparam, nparam in zip(orig_var_dist_params,
                                       model.var_dist.named_parameters()):
            self.assertEqual(orig_nparam[0], nparam[0])
            if var_dist_lr == 0:
                np.testing.assert_array_equal(orig_nparam[1].detach(),
                                              nparam[1].detach())
            elif var_dist_lr == 1e-3:
                # equality must FAIL here, i.e. the values must differ
                with self.assertRaises(AssertionError):
                    np.testing.assert_array_equal(orig_nparam[1].detach(),
                                                  nparam[1].detach())
        # the component objects themselves must not have been replaced
        self.assertEqual(orig_var_dist_q_ids,
                         [id(q_i) for q_i in model.var_dist.q])
def test_VIB_sample_representation(self):
    """Check shapes and statistics of ``VIB.sample_representation`` for both covariance types.

    With a tiny scale (1e-4) the samples must sit within 1e-3 of the
    requested means; for ``'full'`` covariance a factor whose first column
    is 1 must give samples that agree within each representation while
    differing between samples and batch entries.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    n_samples, n_batch, rep_dim = (9, 256, 10)
    for covariance_type in ['diag', 'full']:
        writer = DummyWriter()
        model_kwargs = dict(
            net_name='SmallMLP',
            net_kwargs=dict(
                out_dim=rep_dim,
                nonlinearity='elu',
                batch_norm=True,
                dropout=False,
                in_shape=(1, 28, 28),
            ),
            covariance_type=covariance_type,
            beta=0.001,
            n_mixture_components=2,
            train_var_dist_samples=n_samples,
            test_var_dist_samples=12,
            n_classes=10,
        )
        model = VIB(writer, **model_kwargs)

        # Here every row should be close to an ascending integer
        mean = torch.Tensor(np.arange(0, n_batch))
        mean = mean.unsqueeze(1).repeat(1, rep_dim)
        if covariance_type == 'diag':
            std = torch.ones(rep_dim).repeat(n_batch, 1)
        elif covariance_type == 'full':
            std = torch.eye(rep_dim, rep_dim).repeat(n_batch, 1, 1)
        # shrink the noise so sample means stay within atol of the true means
        std *= 1e-4
        samples = model.sample_representation(mean, std,
                                              model.train_var_dist_samples)
        self.assertEqual(tuple(samples.shape),
                         (n_samples, n_batch, rep_dim))
        np.testing.assert_allclose(np.arange(0, n_batch),
                                   torch.mean(samples, dim=[0, 2]),
                                   atol=1e-3)

        # And here every column should be close to an ascending integer
        mean = torch.Tensor(np.arange(0, rep_dim))
        mean = mean.repeat(n_batch, 1)
        if covariance_type == 'diag':
            std = torch.ones(rep_dim).repeat(n_batch, 1)
        elif covariance_type == 'full':
            std = torch.eye(rep_dim, rep_dim).repeat(n_batch, 1, 1)
        std *= 1e-4
        samples = model.sample_representation(mean, std,
                                              model.train_var_dist_samples)
        self.assertEqual(tuple(samples.shape),
                         (n_samples, n_batch, rep_dim))
        np.testing.assert_allclose(np.arange(0, rep_dim),
                                   torch.mean(samples, dim=[0, 1]),
                                   atol=1e-3)

        # And here all the elements in the representation should be tightly correlated, but differ between samples
        if covariance_type == 'full':
            mean = torch.zeros(n_batch, rep_dim)
            std = torch.eye(rep_dim, rep_dim).repeat(n_batch, 1, 1) * 1e-4
            # first column of every factor set to 1
            std[:, :, 0] = 1
            samples = model.sample_representation(
                mean, std, model.train_var_dist_samples)
            # spread across the representation dimension must be tiny
            intra_rep_differences = torch.max(
                samples, dim=2)[0] - torch.min(samples, dim=2)[0]
            self.assertEqual(tuple(intra_rep_differences.shape),
                             (n_samples, n_batch))
            np.testing.assert_array_less(intra_rep_differences, 1e-3)
            # spread across samples must be large
            inter_sample_differences = torch.max(
                samples, dim=0)[0] - torch.min(samples, dim=0)[0]
            self.assertEqual(tuple(inter_sample_differences.shape),
                             (n_batch, rep_dim))
            np.testing.assert_array_less(1, inter_sample_differences)
            # spread across batch entries must be large
            inter_batch_differences = torch.max(
                samples, dim=1)[0] - torch.min(samples, dim=1)[0]
            self.assertEqual(tuple(inter_batch_differences.shape),
                             (n_samples, rep_dim))
            np.testing.assert_array_less(1, inter_batch_differences)