def test_metrics_multilabel(self):
    """Check micro-averaged multilabel precision, pooled and per-sample averaged.

    ``y_pred`` and ``y_true`` are built as (2, 2, 2) tensors and transposed on
    their last two axes with ``permute(0, 2, 1)`` before reaching the metric.
    """
    import math

    n_classes = 2
    classes_labels = ["t1", "t2"]
    y_pred = torch.tensor(
        [[[1, 0], [1, 1]], [[0, 1], [0, 0]]], dtype=torch.float
    ).permute(0, 2, 1)
    y_true = torch.tensor(
        [[[1, 1], [1, 0]], [[1, 1], [0, 0]]], dtype=torch.long
    ).permute(0, 2, 1)

    # Micro precision pooled over the whole batch.
    params = {"type": "precision", "average": "micro", "multilabel": True}
    precision = Metric.from_params(
        Params(params), num_classes=n_classes, classes_labels=classes_labels
    )
    precision(y_true, y_pred)
    val = precision.get_metric_value()
    # Tolerance instead of ``==``: the metric may be computed in float32, and
    # exact equality against a float64 literal is fragile for non-dyadic values.
    assert math.isclose(float(val), 3 / 4, rel_tol=1e-6)

    # Micro precision computed per sample, then averaged over the batch.
    params = {
        "type": "precision",
        "average": "micro",
        "batch_average": True,
        "multilabel": True,
    }
    precision = Metric.from_params(
        Params(params), num_classes=n_classes, classes_labels=classes_labels
    )
    precision(y_true, y_pred)
    val = precision.get_metric_value()
    assert math.isclose(float(val), (2 / 3 + 1 / 1) / 2, rel_tol=1e-6)
def test_patches(self):
    """Exercise ``sample_to_patches`` with square and rectangular patch shapes.

    Also checks the errors raised for unsupported image ranks and for a
    ``(1, 1)`` patch shape.
    """
    image = np.random.random((64, 64, 3))
    # NOTE(review): the mask is 63x63 while the image is 64x64; confirm the
    # size mismatch is intentional.
    label_mask = np.random.randint(0, 20, (63, 63))

    def build(shape):
        return Transform.from_params(
            Params({"type": "sample_to_patches", "patch_shape": shape})
        )

    # An int patch shape means square patches.
    side = 32
    splitter = build(side)
    out = splitter(image=image, mask=label_mask)
    for key in ("image", "mask"):
        assert all(patch.shape[:2] == (side, side) for patch in out[key])

    # A tuple patch shape gives rectangular patches.
    rect = (32, 2)
    splitter = build(rect)
    out = splitter(image=image, mask=label_mask)
    for key in ("image", "mask"):
        assert all(patch.shape[:2] == rect for patch in out[key])

    with pytest.raises(ValueError):
        splitter(image=np.ones((2, 32, 32, 3)))
    with pytest.raises(IndexError):
        splitter(image=np.ones((2, )))

    with pytest.raises(ValueError):
        tiny = build((1, 1))
        out = tiny(image=image, mask=label_mask)
def test_warmup_lr(self):
    """The warmup scheduler hands off to its wrapped exponential scheduler.

    The LR starts at the warmup start value and reaches the warmup end value
    after ``warmup_duration`` steps. From then on it decays by ``gamma`` per step.
    """
    base_lr = 1
    optimizer = Optimizer.from_params(
        Params({"type": "sgd", "lr": base_lr}),
        model_params=[("x", torch.nn.Parameter())],
    )
    start, end, duration, decay = 0.1, base_lr, 10, 0.95
    scheduler = Scheduler.from_params(
        Params({
            "type": "warmup",
            "scheduler": {"type": "exponential", "gamma": decay},
            "warmup_start_value": start,
            "warmup_end_value": end,
            "warmup_duration": duration,
        }),
        optimizer=optimizer,
    )

    def current_lr():
        return scheduler.get_last_lr()[0]

    assert np.allclose(current_lr(), start)

    for _ in range(duration):
        optimizer.step()
        scheduler.step()
    assert np.allclose(current_lr(), end)

    for step in range(1, 51):
        optimizer.step()
        scheduler.step()
        assert np.allclose(current_lr(), end * decay**step)
def test_metrics_multiclass(self):
    """Check micro-averaged multiclass precision, pooled and per-sample averaged.

    ``logits`` is built as a (2, 1, 4, 3) tensor and permuted to (2, 3, 1, 4),
    moving the last axis to position 1. ``y_true`` is (2, 1, 4).
    """
    import math

    n_classes = 3
    classes_labels = ["b", "t1", "t2"]
    logits = (torch.tensor([
        [[[10, 1, 1], [10, 1, 1], [1, 1, 10], [1, 1, 10]]],
        [[[1, 10, 1], [1, 1, 10], [1, 1, 10], [1, 10, 1]]],
    ]).to(torch.float).permute(0, 3, 1, 2))
    y_true = torch.tensor([[[0, 1, 2, 2]], [[1, 2, 2, 0]]])

    params = {"type": "precision", "average": "micro"}
    precision = Metric.from_params(
        Params(params), num_classes=n_classes, classes_labels=classes_labels
    )
    # The metric is queried before any update. The value is not checked, so
    # this only exercises that the call itself succeeds.
    precision.get_metric_value()
    precision(y_true, logits)
    val = precision.get_metric_value()
    # Tolerance instead of ``==``: the metric may be computed in float32, and
    # exact equality against a float64 literal is fragile for non-dyadic values.
    assert math.isclose(float(val), 5 / 6, rel_tol=1e-6)

    # Micro precision computed per sample, then averaged over the batch.
    params = {
        "type": "precision",
        "average": "micro",
        "batch_average": True,
    }
    precision = Metric.from_params(
        Params(params), num_classes=n_classes, classes_labels=classes_labels
    )
    precision(y_true, logits)
    val = precision.get_metric_value()
    assert math.isclose(float(val), (2 / 2 + 3 / 4) / 2, rel_tol=1e-6)
def test_lr_concat(self):
    """The ``concat`` scheduler runs exponential decay, then reduce-on-plateau.

    The first scheduler runs for the 10 steps given in ``durations``. After
    that, the plateau scheduler (mode ``max``) takes over from the decayed LR.
    """
    named_parameters = [("x", torch.nn.Parameter())]
    lr = 1
    opt = Optimizer.from_params(
        Params({"type": "sgd", "lr": lr}), model_params=named_parameters
    )
    gamma_exp = 0.95
    gamma_plateau = 0.01
    patience = 2
    exponential_cfg = {"type": "exponential", "gamma": gamma_exp}
    plateau_cfg = {
        "type": "reduce_on_plateau",
        "mode": "max",
        "factor": gamma_plateau,
        "patience": patience,
        "min_lr": 1e-3,
    }
    sched = Scheduler.from_params(
        Params({
            "type": "concat",
            "schedulers": [exponential_cfg, plateau_cfg],
            "durations": [10, ],
        }),
        optimizer=opt,
    )
    assert np.allclose(sched.get_last_lr()[0], lr)

    # Phase 1: exponential decay, one factor of gamma_exp per step.
    for step in range(1, 11):
        opt.step()
        sched.step(0.1)
        assert np.allclose(sched.get_last_lr()[0], lr * gamma_exp**step)

    # Phase 2: plateau scheduler fed explicit metric values.
    start_lr = lr * gamma_exp**10
    reduced = start_lr * gamma_plateau
    plateau_trace = [
        (0.5, start_lr),
        (0.4, start_lr),
        (0.3, start_lr),
        (0.3, reduced),
        # test min_lr
        (0.3, reduced),
        (0.4, reduced),
        (0.3, max(1e-3, start_lr * gamma_plateau**2)),
    ]
    for metric, expected_lr in plateau_trace:
        sched.step(metric)
        assert np.allclose(sched.get_last_lr()[0], expected_lr)
def test_blur(self):
    """Check the blur transform with ``p=1`` and ``p=0``.

    With ``p=1`` the image mean must change. With ``p=0`` it must not.
    """
    for probability, expect_change in ((1, True), (0, False)):
        blur = Transform.from_params(Params({"type": "blur", "p": probability}))
        source = np.random.random((128, 128, 3))
        blurred = blur(image=source)["image"]
        if expect_change:
            assert source.mean() != blurred.mean()
        else:
            assert source.mean() == blurred.mean()
def test_read_jsonnet(self):
    """Check jsonnet loading for a bad and a good config.

    The malformed config raises ``RuntimeError``. The valid one yields the
    expected encoder and decoder settings.
    """
    config_dir = self.FIXTURES_ROOT / "configs"

    with pytest.raises(RuntimeError):
        Params.from_file(config_dir / "resnet50_unet_bad.jsonnet")

    params = Params.from_file(config_dir / "resnet50_unet.jsonnet")
    assert len(params) == 3
    encoder_cfg = params["encoder"]
    decoder_cfg = params["decoder"]
    assert encoder_cfg["type"] == "resnet50"
    assert decoder_cfg["type"] == "unet"
    assert decoder_cfg["decoder_channels"] == [512, 256, 128, 64, 32]
def test_normalizations_with_resnet(self):
    """Check that the ``normalization`` config controls ResNet-50's norm layers.

    ``layer4[0].bn1`` is inspected as a representative norm slot. Each
    successfully built encoder is run once on a small random batch.
    """

    def resnet50_config(normalization):
        return Params({"type": "resnet50", "normalization": normalization})

    encoder = Encoder.from_params(resnet50_config({"type": "identity"}))
    assert isinstance(encoder.layer4[0].bn1, Identity)
    x = torch.zeros((2, 3, 4, 4)).normal_()
    encoder.forward(x)

    for norm_type, expected_cls in (
        ("batch_norm_2d", BatchNorm2d),
        ("batch_renorm_2d", BatchRenorm2d),
    ):
        encoder = Encoder.from_params(resnet50_config({"type": norm_type}))
        assert isinstance(encoder.layer4[0].bn1, expected_cls)
        encoder.forward(x)

    # group_norm without num_groups is rejected as a configuration error.
    missing_groups = resnet50_config({"type": "group_norm"})
    with pytest.raises(ConfigurationError):
        Encoder.from_params(missing_groups)

    encoder = Encoder.from_params(
        resnet50_config({"type": "group_norm", "num_groups": 8})
    )
    assert isinstance(encoder.layer4[0].bn1, GroupNorm)
    encoder.forward(x)
def test_compose(self):
    """Check ``compose`` with targets aliased through ``additional_targets``.

    Aliased image/mask targets are expected to receive identical outputs.
    Every output is cropped to 32x32 by ``random_crop``.
    """
    config = {
        "type": "compose",
        "transforms": [
            "blur",
            {"type": "horizontal_flip", "always_apply": True},
            {"type": "random_crop", "height": 32, "width": 32},
        ],
        "additional_targets": {
            "image2": "image",
            "mask1": "mask",
            "label": "mask",
        },
    }
    pipeline = Transform.from_params(Params(config))
    picture = np.random.random((128, 128, 3))
    segmentation = np.random.randint(0, 20, (128, 128))
    out = pipeline(image2=picture, image=picture, mask1=segmentation, label=segmentation)

    assert np.allclose(out["image"], out["image2"])
    assert np.allclose(out["mask1"], out["label"])
    for key in ("image", "mask1"):
        assert out[key].shape[:2] == (32, 32)
def test_union_str_list_str(self):
    """A ``Union[str, List[str]]`` slot nested in a list of tuples accepts a string.

    Checks both that construction succeeds and that the plain string comes
    back as the string itself, not as a list of characters.
    """
    class A(FromParams):
        def __init__(self, x: List[Tuple[Union[str, List[str]], Dict[str, Any]]]):
            self.x = x

    a = A.from_params(Params({"x": [["test", {}]]}))
    # Verify the constructed value, not just that from_params did not raise.
    assert len(a.x) == 1
    assert a.x[0][0] == "test"
def test_make_param_groups(self):
    """Check that regexes in ``param_groups`` give per-group ``alpha`` values.

    Groups keep their config order (encoder, logits). The last group (index 2)
    carries the top-level ``alpha`` and presumably holds every parameter not
    matched by a regex.
    """
    model = Model.from_params(
        Params({
            "encoder": {"type": "resnet50"},
            "decoder": {
                "type": "unet",
                "decoder_channels": [512, 256, 128, 64, 32],
            },
            "num_classes": 4,
        }))
    named_parameters = list(model.named_parameters())

    alpha_decoder = 0.01
    alpha_logits = 0.1
    alpha_encoder = 0.0001
    encoder_group = {"regexes": "encoder", "params": {"alpha": alpha_encoder}}
    logits_group = {
        "regexes": ["decoder.logits.*.weight", "decoder.logits.*.bias"],
        "params": {"alpha": alpha_logits},
    }
    params = {
        "param_groups": [encoder_group, logits_group],
        "alpha": alpha_decoder,
    }
    reg = Regularizer.from_params(Params(params), model_params=named_parameters)

    expectations = [
        (0, alpha_encoder, 161),
        (2, alpha_decoder, 20),
        (1, alpha_logits, 2),
    ]
    for index, expected_alpha, expected_count in expectations:
        group = reg.param_groups[index]
        assert group["alpha"] == expected_alpha
        assert len(group["params"]) == expected_count
def test_assert_empty(self):
    """``assert_empty`` raises while keys remain and passes once they are popped."""
    params = Params({"test": "hello"})

    with pytest.raises(ConfigurationError):
        params.assert_empty("dummy")

    popped = params.pop("test")
    assert popped == "hello"
    params.assert_empty("dummy")
def test_trainer(self):
    """Smoke-test a short training run on the multiclass fixture dataset.

    Builds a ``Trainer`` from a full config (labels, train/val datasets, model,
    metrics, LR schedule, early stopping) and runs two epochs. Only successful
    completion is checked.
    """
    # The unused ``blocks`` / ``num_channels`` locals and a commented-out
    # loss config were removed.
    dataset_dir = self.FIXTURES_ROOT / "dataset" / "multiclass"
    params = Params({
        "color_labels": {
            "type": "txt",
            "label_text_file": dataset_dir / "classes.txt",
        },
        "train_dataset": {
            "type": "image_csv",
            "csv_filename": dataset_dir / "train.csv",
            "base_dir": dataset_dir,
        },
        "model": {
            "encoder": "resnet50",
            "decoder": {
                "decoder_channels": [512, 256, 128, 64, 32]
            },
        },
        "metrics": [
            "iou",
            ("iou_class", {
                "type": "iou",
                "average": None
            }),
            "precision",
        ],
        "val_dataset": {
            "type": "image_csv",
            "csv_filename": dataset_dir / "test.csv",
            "base_dir": dataset_dir,
        },
        "lr_scheduler": {
            "type": "exponential",
            "gamma": 0.95
        },
        "early_stopping": {
            "patience": 20
        },
        "model_out_dir": str(self.TEMPORARY_DIR / "model"),
        "num_epochs": 2,
        "evaluate_every_epoch": 1,
    })
    trainer = Trainer.from_params(params)
    trainer.train()
def test_build_optimizers(self):
    """Every registered optimizer can be built and reports the configured ``lr``."""
    model_config = {
        "encoder": {},
        "decoder": {"decoder_channels": [512, 256, 128, 64, 32],},
        "num_classes": 4,
    }
    model = Model.from_params(Params(model_config))
    named_parameters = list(model.named_parameters())

    lr = 10
    for optimizer_name in Optimizer.get_available():
        opt = Optimizer.from_params(
            Params({"type": optimizer_name, "lr": lr}),
            model_params=named_parameters,
        )
        first_group = opt.state_dict()["param_groups"][0]
        assert first_group["lr"] == lr
def test_build_regularizers(self):
    """Every registered regularizer can be built and reports the configured ``alpha``."""
    model_config = {
        "encoder": {},
        "decoder": {
            "decoder_channels": [512, 256, 128, 64, 32],
        },
        "num_classes": 4,
    }
    model = Model.from_params(Params(model_config))
    named_parameters = list(model.named_parameters())

    alpha = 10
    for regularizer_name in Regularizer.get_available():
        reg = Regularizer.from_params(
            Params({"type": regularizer_name, "alpha": alpha}),
            model_params=named_parameters,
        )
        assert reg.param_groups[0]["alpha"] == alpha
def test_basic_from_params(self):
    """Resolve ``Transform`` subclasses from one evolving config dict.

    Keys are re-assigned before each build. The same dict is wrapped by
    ``Params`` every time and may be consumed by it, so do not drop these
    seemingly redundant assignments.
    """
    config = {}

    def build():
        return Transform.from_params(Params(config))

    # With no "type" key a transform is still built.
    transform = build()
    assert transform.apply(1) == 7

    # An extra "x" is rejected when no type accepting it is selected.
    config["x"] = 4
    with pytest.raises(ConfigurationError):
        build()

    config["type"] = "mult_by_x"
    transform = build()
    assert transform.apply(4) == 16

    # mult_by_x_add_y with "y" missing.
    config["type"] = "mult_by_x_add_y"
    config["x"] = 4
    with pytest.raises(ConfigurationError):
        build()

    config["type"] = "mult_by_x_add_y"
    config["x"] = 4
    config["y"] = 2
    transform = build()
    assert transform.apply(4) == 18

    # A non-numeric "y" is rejected with TypeError.
    config["type"] = "mult_by_x_add_y"
    config["x"] = 4
    config["y"] = "test"
    with pytest.raises(TypeError):
        build()
def test_encoders_from_params(self):
    """Build ResNet encoders of several depths and settings and run them forward."""
    registered = Encoder.get_available()
    for name in ["resnet50", "resnet34", "resnet18"]:
        assert name in registered

    # ResNet-50 truncated to 3 blocks.
    encoder = Encoder.from_params(Params({"type": "resnet50", "blocks": 3}))
    assert encoder.blocks == 3
    assert len(encoder.output_dims) == 5
    assert encoder.layer4[-1].conv1.in_channels == 2048
    x = torch.zeros((2, 3, 128, 128)).normal_()
    encoder.forward(x)

    # ResNet-34 truncated to 2 blocks.
    encoder = Encoder.from_params(Params({"type": "resnet34", "blocks": 2}))
    assert encoder.blocks == 2
    assert len(encoder.output_dims) == 4
    assert len(encoder.layer4) == 3
    assert encoder.layer4[-1].conv1.in_channels == 512
    encoder.forward(x)

    # ResNet-18 with defaults.
    encoder = Encoder.from_params(Params({"type": "resnet18"}))
    assert len(encoder.layer4) == 2
    assert encoder.layer4[-1].conv1.in_channels == 512
    encoder.forward(x)

    # Dilated, non-pretrained ResNet-50 with identity normalization.
    dilated_config = {
        "type": "resnet50",
        "blocks": 3,
        "pretrained": False,
        "replace_stride_with_dilation": [False, True, True],
        "normalization": {"type": "identity"},
    }
    encoder = Encoder.from_params(Params(dilated_config))
    assert isinstance(encoder.layer4[0].bn1, torch.nn.Identity)
    assert encoder.layer4[1].conv2.dilation == (4, 4)
    encoder.forward(x)
def test_make_param_groups(self):
    """Check per-group learning rates for optimizer ``param_groups``.

    Groups follow config order (``encoder``, ``decoder_logits``). The
    ``encoder`` group has no explicit ``regexes`` and presumably matches by its
    name. The last group (index 2) uses the top-level ``lr``.
    """
    model_config = {
        "encoder": {"type": "resnet50"},
        "decoder": {
            "type": "unet",
            "decoder_channels": [512, 256, 128, 64, 32],
        },
        "num_classes": 4,
    }
    model = Model.from_params(Params(model_config))
    named_parameters = list(model.named_parameters())

    lr_decoder = 0.01
    lr_logits = 0.1
    lr_encoder = 0.0001
    params = {
        "param_groups": {
            "encoder": {"params": {"lr": lr_encoder}},
            "decoder_logits": {
                "regexes": ["decoder.logits.*.weight", "decoder.logits.*.bias"],
                "params": {"lr": lr_logits},
            },
        },
        "lr": lr_decoder,
    }
    opt = Optimizer.from_params(Params(params), model_params=named_parameters)

    groups = opt.state_dict()["param_groups"]
    for index, expected_lr, expected_count in (
        (0, lr_encoder, 161),
        (2, lr_decoder, 20),
        (1, lr_logits, 2),
    ):
        assert groups[index]["lr"] == expected_lr
        assert len(groups[index]["params"]) == expected_count
def test_encoders_from_params(self):
    """Build MobileNetV2 encoders truncated to 3 and 2 blocks and run them forward."""
    registered = Encoder.get_available()
    for name in ["mobilenetv2"]:
        assert name in registered

    encoder = Encoder.from_params(Params({"type": "mobilenetv2", "blocks": 3}))
    assert encoder.blocks == 3
    assert len(encoder.output_dims) == 5
    assert encoder.features[-1][0].out_channels == 1280
    x = torch.zeros((2, 3, 256, 256)).normal_()
    encoder.forward(x)

    shallow_config = {"type": "mobilenetv2", "blocks": 2, "pretrained": False}
    encoder = Encoder.from_params(Params(shallow_config))
    assert encoder.blocks == 2
    assert len(encoder.output_dims) == 4
    assert encoder.features[-1][0].out_channels == 1280
    encoder.forward(x)
def test_simple_nesting(self):
    """A nested ``transform`` config is built and exposed via ``dataset.transforms``."""
    nested_transform = {"type": "mult_by_x", "x": 4}
    config = {
        "type": "from_single_transform",
        "data_path": 10,
        "transform": nested_transform,
    }
    dataset = Dataset.from_params(Params(config))

    assert dataset.data_path == 10
    assert len(dataset.transforms) == 1
    assert dataset.transforms[0].x == 4
def test_reduce_lr_on_plateau(self):
    """Check reduce-on-plateau (mode ``max``) against a scripted metric trace.

    The LR drops by ``factor`` only after more than ``patience``
    non-improving steps, and it never goes below ``min_lr``.
    """
    lr = 1
    opt = Optimizer.from_params(
        Params({"type": "sgd", "lr": lr}),
        model_params=[("x", torch.nn.Parameter())],
    )
    gamma = 0.01
    patience = 2
    sched = Scheduler.from_params(
        Params({
            "type": "reduce_on_plateau",
            "mode": "max",
            "factor": gamma,
            "patience": patience,
            "min_lr": 1e-3,
        }),
        optimizer=opt,
    )
    assert np.allclose(sched.get_last_lr()[0], lr)

    trace = [
        (0.5, lr),
        (0.4, lr),
        (0.3, lr),
        (0.3, lr * gamma),
        # test min_lr
        (0.3, lr * gamma),
        (0.4, lr * gamma),
        (0.3, max(1e-3, lr * gamma**2)),
    ]
    for metric, expected in trace:
        sched.step(metric)
        assert np.allclose(sched.get_last_lr()[0], expected)
def test_encoders_from_params(self):
    """Build and run a full model: dilated ResNet-50, PAN decoder and dice loss.

    Runs one forward pass with random binary targets and checks the loss,
    the logits channel count and the component types.
    """
    blocks = 2
    num_channels = 32
    n_classes = 2
    encoder_cfg = {
        "type": "resnet50",
        "blocks": blocks,
        "pretrained": False,
        "replace_stride_with_dilation": [False, True, True],
        "normalization": {"type": "identity"},
    }
    decoder_cfg = {
        "type": "pan",
        "decoder_channels_size": num_channels,
        "normalization": {"type": "batch_renorm_2d"},
        "activation": {"type": "leaky_relu", "inplace": True},
        "gau_activation": {"type": "swish"},
        "upscale_mode": "nearest",
    }
    params = Params({
        "encoder": encoder_cfg,
        "decoder": decoder_cfg,
        "loss": {"type": "dice"},
        "num_classes": n_classes,
    })

    inputs = torch.zeros((2, 3, 128, 128)).normal_()
    targets = (torch.zeros((2, n_classes, 128, 128)).normal_() > 0.5).to(torch.float)
    model = Model.from_params(params)
    out = model.forward(inputs, targets)

    assert out["loss"] > 0
    assert out["logits"].shape[1] == n_classes
    assert isinstance(model.encoder, ResNetEncoder)
    assert isinstance(model.decoder, PanDecoder)
def test_lazy(self):
    """A ``Lazy`` argument can be constructed more than once with extra kwargs.

    Each construction combines the config-provided ``string`` with the
    ``extra`` supplied at construct time.
    """
    test_string = "this is a test"
    extra_string = "extra string"

    class ConstructedObject(FromParams):
        def __init__(self, string: str, extra: str):
            self.string = string
            self.extra = extra

    class Testing(FromParams):
        def __init__(self, lazy_object: Lazy[ConstructedObject]):
            built = [lazy_object.construct(extra=extra_string) for _ in range(2)]
            for obj in built:
                assert obj.string == test_string
                assert obj.extra == extra_string

    Testing.from_params(Params({"lazy_object": {"string": test_string}}))
def test_list_with_strings(self):
    """A transforms list may mix bare registered names with full config dicts."""
    params = {
        "type": "from_transforms_list",
        "data_path": "./",
        "transforms": [
            "mult_by_2_add_5",
            {"type": "mult_by_x", "x": 5},
            "mult_by_2",
        ],
    }
    dataset = Dataset.from_params(Params(params))

    expected = [
        (MultiplyByXAddY, {"x": 2, "y": 5}),
        (MultiplyByX, {"x": 5}),
        (MultiplyByX, {"x": 2}),
    ]
    assert len(dataset.transforms) == 3
    for index, (expected_cls, attributes) in enumerate(expected):
        transform = dataset.transforms[index]
        assert isinstance(transform, expected_cls)
        for attr_name, attr_value in attributes.items():
            assert getattr(transform, attr_name) == attr_value
def test_none_default_args(self):
    """A string ``min_val`` for ``hardtanh`` must be rejected with ``TypeError``."""
    bad_config = {"type": "hardtanh", "min_val": "a"}
    with pytest.raises(TypeError):
        Activation.from_params(Params(bad_config))
def test_complex_nesting(self):
    """Check how container-typed transform args interact with each dataset type.

    - The default dataset expects a mapping of transforms.
    - ``from_transforms_list`` and ``from_transforms_tuple`` expect sequences.
    - ``from_set_param`` takes a single ``transform`` plus a ``set_arg``
      collection.

    Mismatched containers must raise ``TypeError``. Setup runs outside each
    ``pytest.raises`` block, so only ``Dataset.from_params`` can satisfy it.
    """
    base_config_dict = {
        "data_path": "./",
        "transforms": {
            "t1": {"type": "mult_by_x", "x": 4},
            "t2": {"type": "mult_by_2_add_5"},
        },
    }

    def fresh():
        # Params may consume the dict it wraps, so each case gets a deep copy.
        return deepcopy(base_config_dict)

    # Mapping of transforms with the default dataset.
    config_dict = fresh()
    dataset = Dataset.from_params(Params(config_dict))
    assert dataset.data_path == "./"
    assert len(dataset.transforms) == 2
    assert dataset.transforms["t2"].y == 5

    # Lists: rejected by the default (mapping) dataset.
    config_dict = fresh()
    config_dict["transforms"] = list(config_dict["transforms"].values())
    with pytest.raises(TypeError):
        Dataset.from_params(Params(config_dict))

    config_dict = fresh()
    config_dict["transforms"] = list(config_dict["transforms"].values())
    config_dict["type"] = "from_transforms_list"
    dataset = Dataset.from_params(Params(config_dict))
    assert dataset.data_path == "./"
    assert len(dataset.transforms) == 2
    assert dataset.transforms[1].y == 5

    # A mapping is rejected by the list-based dataset.
    config_dict = fresh()
    config_dict["type"] = "from_transforms_list"
    with pytest.raises(TypeError):
        Dataset.from_params(Params(config_dict))

    # Tuples: rejected by the default (mapping) dataset.
    config_dict = fresh()
    config_dict["transforms"] = tuple(config_dict["transforms"].values())
    with pytest.raises(TypeError):
        Dataset.from_params(Params(config_dict))

    config_dict = fresh()
    config_dict["transforms"] = tuple(config_dict["transforms"].values())
    config_dict["type"] = "from_transforms_tuple"
    dataset = Dataset.from_params(Params(config_dict))
    assert dataset.data_path == "./"
    assert len(dataset.transforms) == 2
    assert dataset.transforms[1].y == 5

    # A mapping is rejected by the tuple-based dataset.
    config_dict = fresh()
    config_dict["type"] = "from_transforms_tuple"
    with pytest.raises(TypeError):
        Dataset.from_params(Params(config_dict))

    # Sets: rejected by the default (mapping) dataset. A set of the transform
    # dicts cannot be built at all: dicts are unhashable, so ``set()`` itself
    # would raise TypeError and make this check vacuous. Hashable string
    # members ensure any TypeError comes from Dataset.from_params.
    config_dict = fresh()
    config_dict["transforms"] = {"mult_by_x", "mult_by_2_add_5"}
    with pytest.raises(TypeError):
        Dataset.from_params(Params(config_dict))

    config_dict = fresh()
    config_dict["type"] = "from_set_param"
    config_dict["transform"] = list(config_dict["transforms"].values())[0]
    del config_dict["transforms"]
    config_dict["set_arg"] = [{"x": 1}, {"x": 2}]
    dataset = Dataset.from_params(Params(config_dict))
    assert dataset.data_path == "./"
    assert len(dataset.transforms) == 1
    assert dataset.transforms[0].x == 4
    assert len(dataset.set_arg) == 2

    # A mapping is rejected for ``set_arg``.
    config_dict = fresh()
    config_dict["type"] = "from_set_param"
    config_dict["transform"] = list(config_dict["transforms"].values())[0]
    del config_dict["transforms"]
    config_dict["set_arg"] = {"k1": {"x": 1}, "k2": {"x": 2}}
    with pytest.raises(TypeError):
        Dataset.from_params(Params(config_dict))
def test_write_read_from_file(self):
    """Round-trip ``Params`` through a JSON file, then check ``pop`` semantics."""
    source = {"test": "hello"}
    params = Params(source)
    assert params.as_dict() == source

    target_path = self.TEMPORARY_DIR / "dummy_config.json"
    params.to_file(str(target_path))
    reloaded = Params.from_file(str(target_path))
    assert params.as_dict() == reloaded.as_dict()

    assert params.pop("test") == "hello"
    # A missing key with a default returns the default.
    assert params.pop("test2", "none") == "none"
    # A missing key without a default is a configuration error.
    with pytest.raises(ConfigurationError):
        params.pop("test")
def test_encoders_from_params(self):
    """Build the ``pan`` decoder from params and check its layer wiring.

    Despite the name, this exercises ``Decoder``. It covers a default build,
    a missing ``num_classes`` (configuration error) and a build with custom
    normalization and activations.
    """
    assert "pan" in Decoder.get_available()

    channels = [3, 6, 12, 24, 48]
    num_channels = 16
    # One random feature map per entry in ``channels``; spatial size 100 // (i + 1).
    xs = [
        torch.zeros(4, c, 100 // (i + 1), 100 // (i + 1)).normal_()
        for i, c in enumerate(channels)
    ]
    n_classes = 8

    def pan_params(**extra):
        return Params({
            "type": "pan",
            "encoder_channels": channels,
            "decoder_channels_size": num_channels,
            **extra,
        })

    decoder = Decoder.from_params(pan_params(num_classes=n_classes))
    pooling_conv = decoder.fpa.pooling_branch[1][0]
    high_conv = decoder.gau3.process_high[1][0]
    low_conv = decoder.gau3.process_low[0]
    logits_conv = decoder.logits[0]
    assert pooling_conv.in_channels == channels[-1]
    assert pooling_conv.out_channels == num_channels
    assert high_conv.in_channels == num_channels
    assert high_conv.out_channels == num_channels
    assert low_conv.in_channels == channels[-4]
    assert low_conv.out_channels == num_channels
    assert logits_conv.in_channels == num_channels
    assert logits_conv.out_channels == n_classes
    decoder.forward(*xs)

    # num_classes is required.
    missing_classes = pan_params()
    with pytest.raises(ConfigurationError):
        Decoder.from_params(missing_classes)

    decoder = Decoder.from_params(
        pan_params(
            num_classes=n_classes,
            normalization={"type": "batch_renorm_2d"},
            activation={"type": "leaky_relu", "inplace": True},
            gau_activation={"type": "swish"},
            upscale_mode="nearest",
        )
    )
    assert isinstance(decoder.gau3.process_low[2], BatchRenorm2d)
    assert isinstance(decoder.gau3.process_low[1], torch.nn.LeakyReLU)
    assert decoder.gau3.process_high[1][1]._get_name() == "Swish"
    decoder.forward(*xs)
def test_steps_scheduler(self):
    """Check the ``step``, ``multi_step`` and ``exponential`` LR schedulers.

    Each scheduler gets a fresh SGD optimizer at ``lr = 10``. The test tracks
    ``get_last_lr()`` step by step against the expected decay points.
    """
    named_parameters = [("x", torch.nn.Parameter())]
    lr = 10
    gamma = 0.1

    def make_optimizer():
        return Optimizer.from_params(
            Params({"type": "sgd", "lr": lr}), model_params=named_parameters
        )

    def advance(opt, sched):
        # optimizer.step() must precede scheduler.step(). The reversed order,
        # previously used in one multi_step loop, makes PyTorch warn and is
        # inconsistent with the rest of this test.
        opt.step()
        sched.step()

    # StepLR: constant for step_size - 1 steps, then one decay.
    step_size = 10
    opt = make_optimizer()
    sched = Scheduler.from_params(
        Params({"type": "step", "step_size": step_size, "gamma": gamma}),
        optimizer=opt,
    )
    assert sched.get_last_lr()[0] == lr
    for _ in range(step_size - 1):
        advance(opt, sched)
        assert sched.get_last_lr()[0] == lr
    advance(opt, sched)
    assert np.allclose(sched.get_last_lr()[0], lr * gamma)

    # MultiStepLR: decays at milestones 5 and 20.
    opt = make_optimizer()
    sched = Scheduler.from_params(
        Params({"type": "multi_step", "milestones": [5, 20], "gamma": gamma}),
        optimizer=opt,
    )
    assert sched.get_last_lr()[0] == lr
    for _ in range(4):
        advance(opt, sched)
        assert sched.get_last_lr()[0] == lr
    advance(opt, sched)
    assert np.allclose(sched.get_last_lr()[0], lr * gamma)
    for _ in range(14):
        advance(opt, sched)
        assert np.allclose(sched.get_last_lr()[0], lr * gamma)
    advance(opt, sched)
    assert np.allclose(sched.get_last_lr()[0], lr * gamma**2)

    # ExponentialLR: one factor of gamma per step.
    opt = make_optimizer()
    sched = Scheduler.from_params(
        Params({"type": "exponential", "gamma": gamma}), optimizer=opt
    )
    assert sched.get_last_lr()[0] == lr
    for i in range(50):
        advance(opt, sched)
        assert np.allclose(sched.get_last_lr()[0], lr * gamma**(i + 1))
def test_pop_nested_param(self):
    """Popping a key whose value is a dict returns it wrapped as ``Params``."""
    params = Params({"model": {"type": "test", "other_param": 1}})
    nested = params.pop("model")
    assert isinstance(nested, Params)