def top_x(self):
    """Return a random float input batch sized for the top network.

    Shape is ``(hps.batch_size, hps.bottom_output + 1)`` using a default
    ``HyperParams()``. The tensor has dtype ``torch.float``. It is placed on
    CUDA when available, otherwise on the CPU.
    """
    params = HyperParams()
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # presumably one extra column is appended to the bottom output — TODO confirm
    width = params.bottom_output + 1
    return torch.randn(params.batch_size, width, device=target, dtype=torch.float)
def flat_x():
    """Return a random flat float batch of shape ``(batch_size, n_input)``.

    Sizes come from a default ``HyperParams()``. The tensor is placed on CUDA
    when available, otherwise on the CPU.

    NOTE(review): unlike the sibling fixtures this takes no ``self``. If it is
    defined inside a test class, pytest would pass the instance and the call
    would fail. Confirm where it lives.
    """
    params = HyperParams()
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    shape = (params.batch_size, params.n_input)
    return torch.randn(*shape, device=target, dtype=torch.float)
def conv_x(self):
    """Return a random float batch shaped for 1-D convolution.

    Shape is ``(hps.batch_size, 4, hps.n_input // 4)`` using a default
    ``HyperParams()``. The tensor is placed on CUDA when available, otherwise
    on the CPU.
    """
    params = HyperParams()
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Floor division: if n_input is not a multiple of 4, the remainder is dropped.
    length = params.n_input // 4
    return torch.randn(params.batch_size, 4, length, device=target, dtype=torch.float)
def test_model_mlp_stack(self, sample_x):
    """Smoke-test a full MLP stack configured by ``hps_base_mlp.json``.

    The bottom and top network classes are looked up on ``models`` by the
    names in ``hps.bottom_net`` / ``hps.top_net`` and wrapped in
    ``FullModel``. A forward pass on ``sample_x`` must complete and produce
    an output.
    """
    path_to_json = os.path.join(get_project_root(), 'res_experiments', 'hps_base_mlp.json')
    hps = HyperParams.from_file(path_to_json=path_to_json)
    bottom = getattr(models, hps.bottom_net)
    top = getattr(models, hps.top_net)
    net = FullModel(hps, bottom, top)
    # NOTE(review): sample_x is sized from a default HyperParams(), not from
    # this JSON. The two presumably agree on batch_size / n_input — confirm.
    out = net(sample_x)
    # Require an actual result instead of a no-op `assert True`.
    assert out is not None
def sample_x(self):
    """Return a pair of random float tensors ``(x0, x1)``.

    ``x0`` has shape ``(batch_size, n_input)`` and ``x1`` has shape
    ``(batch_size, 1)``. Sizes come from a default ``HyperParams()``. Both
    tensors are placed on CUDA when available, otherwise on the CPU.
    """
    params = HyperParams()
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    rows = params.batch_size
    features = torch.randn(rows, params.n_input, device=target, dtype=torch.float)
    extra = torch.randn(rows, 1, device=target, dtype=torch.float)
    return features, extra
def flat(self):
    """Return a random flat float batch sized from ``hps_partly_independent_mlp.json``.

    Shape is ``(hps.batch_size, hps.n_input)``. The tensor is placed on CUDA
    when available, otherwise on the CPU.

    Side effect: the loaded hyper-parameters are stored on ``self.hps``.
    """
    json_path = os.path.join(get_project_root(), 'res_experiments', 'hps_partly_independent_mlp.json')
    params = HyperParams.from_file(path_to_json=json_path)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch = torch.randn(params.batch_size, params.n_input, device=target, dtype=torch.float)
    self.hps = params
    return batch
def test_model_conv_stack(self, sample_conv_x):
    """Smoke-test a full conv stack configured by ``hps_base_conv.json``.

    The first batch from ``Model(hps).make_loader()`` is taken. Its ``'X'``
    entry is fed through ``FullModel``, which must produce an output.
    """
    # NOTE(review): the ``sample_conv_x`` fixture is requested but never used.
    # The input comes from the data loader instead. Confirm whether it can be dropped.
    path_to_json = os.path.join(get_project_root(), 'res_experiments', 'hps_base_conv.json')
    hps = HyperParams.from_file(path_to_json=path_to_json)
    bottom = getattr(models, hps.bottom_net)
    top = getattr(models, hps.top_net)
    loader = Model(hps).make_loader()
    batch = next(iter(loader))
    net = FullModel(hps, bottom, top)
    out = net(batch['X'])
    # Require an actual result instead of a no-op `assert True`.
    assert out is not None
def load_pretrained_bottom(self, path_to_model, path_to_json):
    """Copy the bottom-network weights of a saved model into ``self.net``.

    A donor ``Model`` is built from the hyper-parameter file and restored
    from the checkpoint. Every state-dict entry whose key contains the
    substring ``'bottom'`` then overwrites the matching entry of
    ``self.net``. All other entries of ``self.net`` are kept.

    Args:
        path_to_model: checkpoint path passed to ``Model.load_model``.
        path_to_json: hyper-parameter JSON used to build the donor ``Model``.
    """
    donor = Model(HyperParams.from_file(path_to_json=path_to_json))
    donor.load_model(path_to_model)
    # Plain substring match on the key name, not a prefix match.
    bottom_weights = {
        name: tensor
        for name, tensor in donor.net.state_dict().items()
        if 'bottom' in name
    }
    merged = self.net.state_dict()
    merged.update(bottom_weights)
    # load_state_dict is strict by default, so a 'bottom' key that self.net
    # does not have raises instead of being ignored.
    self.net.load_state_dict(merged)
def test_full_mlp_stack(self):
    """Run a real data batch through the partly-independent MLP stack.

    The parameters are the primary-HDU data of
    ``data/small_parameters_base.fits``. They are turned into a loader by
    ``Model.make_loader``, and the first batch's ``'X'`` entry is passed
    through ``FullModel``, which must produce an output.
    """
    path_to_json = os.path.join(get_project_root(), 'res_experiments', 'hps_partly_independent_mlp.json')
    filename = get_project_root() / 'data' / 'small_parameters_base.fits'
    # Close the FITS file deterministically. memmap=False reads the array
    # into memory, so it stays valid after the file is closed.
    with fits.open(filename, memmap=False) as hdul:
        params = hdul[0].data
    hps = HyperParams.from_file(path_to_json=path_to_json)
    bottom = getattr(models, hps.bottom_net)
    top = getattr(models, hps.top_net)
    model = Model(hps)
    loader = model.make_loader(data_arr=params)
    batch = next(iter(loader))
    net = FullModel(hps, bottom, top)
    out = net(batch['X'])
    assert out is not None
def test_top_net_forward(self):
    """Construct ``TopNet`` and smoke-test a ``Conv1dModel`` forward pass.

    NOTE(review): despite the test name, the forward pass runs on
    ``Conv1dModel``; ``TopNet`` is only constructed. The input width 224 is
    hard-coded rather than derived from ``hps``. Confirm the intended model
    and size.
    """
    hps = HyperParams()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_features = 224  # fixed input width; not taken from hps
    x = torch.randn([hps.batch_size, n_features], device=device, dtype=torch.float)
    TopNet(hps)  # construction only; the instance is not used further
    # NOTE(review): the model is not moved to `device`. On a CUDA machine this
    # presumably relies on the model placing itself there — confirm.
    convnet = Conv1dModel(hps)
    out = convnet(x)
    assert out is not None
def test_bottom_simple_conv1_net(self, conv_x):
    """Smoke-test a ``BottomSimpleConv1d`` forward pass on the ``conv_x`` batch."""
    hps = HyperParams()
    # NOTE(review): conv_x may be on CUDA while the net is built without a
    # device move. Confirm the net handles placement itself.
    net = BottomSimpleConv1d(hps)
    out = net.forward(conv_x)
    # Require an actual result instead of a no-op `assert True`.
    assert out is not None
def test_bottom_simple_mlp_net(self, flat_x):
    """Smoke-test a ``BottomSimpleMLPNet`` forward pass on the ``flat_x`` batch."""
    hps = HyperParams()
    # NOTE(review): flat_x may be on CUDA while the net is built without a
    # device move. Confirm the net handles placement itself.
    net = BottomSimpleMLPNet(hps)
    out = net.forward(flat_x)
    # Require an actual result instead of a no-op `assert True`.
    assert out is not None
def test_base_net(self):
    """``BaseNet`` built from default hyper-parameters is expected to expose activation 'relu'."""
    net = BaseNet(HyperParams())
    expected = 'relu'
    assert net.hps.activation == expected
def test_hyper_params(self):
    """Loading ``hps_base_mlp.json`` is expected to yield ``activation == 'relu'``."""
    config = os.path.join(get_project_root(), 'res_experiments', 'hps_base_mlp.json')
    loaded = HyperParams.from_file(path_to_json=config)
    assert loaded.activation == 'relu'
def test_resnet(self, conv_x):
    """Smoke-test a ``BottomResNet`` forward pass on ``conv_x`` with an extra axis.

    ``unsqueeze(2)`` inserts a singleton dimension at position 2. Assuming
    ``conv_x`` is ``(N, 4, L)`` as the ``conv_x`` fixture builds it, the net
    receives ``(N, 4, 1, L)``.
    """
    hps = HyperParams()
    net = BottomResNet(hps)
    out = net.forward(conv_x.unsqueeze(2))
    # Require an actual result instead of a no-op `assert True`.
    assert out is not None
def test_top_net_init(self):
    """Constructing ``TopNet`` and ``Conv1dModel`` from default hyper-parameters must not raise."""
    params = HyperParams()
    TopNet(params)
    Conv1dModel(params)