Example #1
    def _get_set_state(self, grad_values):
        config = self._get_config()

        opt1 = build_optimizer(config)
        opt1.set_param_groups(self._parameters(), lr=1, momentum=0.9)
        self.assertIsInstance(opt1, self._instance_to_test())

        self._set_gradient(self._parameters(), grad_values)
        opt1.step(where=0)

        if config["name"] == "zero":
            opt1.consolidate_state_dict()

        state = opt1.get_classy_state()

        opt2 = build_optimizer(config)
        opt2.set_param_groups(self._parameters(), lr=2)

        self.assertNotEqual(opt1.options_view.lr, opt2.options_view.lr)
        opt2.set_classy_state(state)
        self.assertEqual(opt1.options_view.lr, opt2.options_view.lr)

        for i in range(len(opt1.optimizer.param_groups[0]["params"])):
            self.assertTrue(
                torch.allclose(
                    opt1.optimizer.param_groups[0]["params"][i],
                    opt2.optimizer.param_groups[0]["params"][i],
                ))

        if config["name"] == "zero":
            opt2.consolidate_state_dict()

        self._compare_momentum_values(opt1.get_classy_state()["optim"],
                                      opt2.get_classy_state()["optim"])

        # check that the two optimizers behave the same on a parameter update
        mock_classy_vision_model1 = self._parameters()
        mock_classy_vision_model2 = self._parameters()
        self._set_gradient(mock_classy_vision_model1, grad_values)
        self._set_gradient(mock_classy_vision_model2, grad_values)
        opt1 = build_optimizer(config)
        opt1.set_param_groups(mock_classy_vision_model1)
        opt2 = build_optimizer(config)
        opt2.set_param_groups(mock_classy_vision_model2)
        opt1.step(where=0)
        opt2.step(where=0)
        for i in range(len(opt1.optimizer.param_groups[0]["params"])):
            print(opt1.optimizer.param_groups[0]["params"][i])
            self.assertTrue(
                torch.allclose(
                    opt1.optimizer.param_groups[0]["params"][i],
                    opt2.optimizer.param_groups[0]["params"][i],
                ))

        if config["name"] == "zero":
            opt1.consolidate_state_dict()
            opt2.consolidate_state_dict()

        self._compare_momentum_values(opt1.get_classy_state()["optim"],
                                      opt2.get_classy_state()["optim"])
Example #2
    def test_get_lr(self):
        opt = build_optimizer(self._get_config())
        param = torch.tensor([1.0], requires_grad=True)
        opt.set_param_groups([{"params": [param], "lr": 1}])

        self.assertEqual(opt.options_view.lr, 1)

        # Case two: verify LR changes
        opt = build_optimizer(self._get_config())
        param = torch.tensor([1.0], requires_grad=True)
        opt.set_param_groups([{"params": [param], "lr": LinearParamScheduler(1, 2)}])

        self.assertAlmostEqual(opt.options_view.lr, 1)
        opt.step(where=0.5)
        self.assertAlmostEqual(opt.options_view.lr, 1.5)
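
The tests above cover the whole optimizer surface these examples rely on:
build_optimizer, set_param_groups, options_view, and step(where=...). A minimal
standalone sketch of the same flow, assuming a plain "sgd" config and the usual
Classy Vision import paths (the tests hide their actual config in _get_config()):

import torch
from classy_vision.optim import build_optimizer
from classy_vision.optim.param_scheduler import LinearParamScheduler

# The "sgd" config below is an assumption, not taken from the tests above
opt = build_optimizer({"name": "sgd", "momentum": 0.9})

param = torch.tensor([1.0], requires_grad=True)
opt.set_param_groups([{"params": [param], "lr": LinearParamScheduler(1, 2)}])

param.grad = torch.tensor([0.1])
opt.step(where=0.5)         # `where` in [0, 1] drives the param schedulers
print(opt.options_view.lr)  # 1.5, halfway through LinearParamScheduler(1, 2)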
Example #3
    def test_build_sgd(self):
        config = self._get_config()
        mock_classy_vision_model = self._get_mock_classy_vision_model(
            trainable_params=True
        )
        opt = build_optimizer(config)
        opt.init_pytorch_optimizer(mock_classy_vision_model)
        self.assertTrue(isinstance(opt, self._instance_to_test()))
Example #4
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
        """Instantiates a ClassificationTask from a configuration.

        Args:
            config: A configuration for a ClassificationTask.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            A ClassificationTask instance.
        """
        optimizer_config = config["optimizer"]
        optimizer_config["num_epochs"] = config["num_epochs"]

        datasets = {}
        phase_types = ["train", "test"]
        for phase_type in phase_types:
            datasets[phase_type] = build_dataset(config["dataset"][phase_type])
        loss = build_loss(config["loss"])
        test_only = config.get("test_only", False)
        meters = build_meters(config.get("meters", {}))
        model = build_model(config["model"])
        # Put the model in eval mode in case any hooks modify model state;
        # it will be reset to train mode before training
        model.eval()
        optimizer = build_optimizer(optimizer_config)

        task = (
            cls()
            .set_num_epochs(config["num_epochs"])
            .set_loss(loss)
            .set_test_only(test_only)
            .set_model(model)
            .set_optimizer(optimizer)
            .set_meters(meters)
            .set_distributed_options(
                BroadcastBuffersMode[config.get("broadcast_buffers", "DISABLED")]
            )
        )
        for phase_type in phase_types:
            task.set_dataset(datasets[phase_type], phase_type)

        return task
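
For reference, the keys read by this from_config imply a task config shaped
roughly as follows. This is a sketch inferred from the code above, not a
documented schema; every leaf value is a placeholder, and each sub-config
follows its own builder's conventions:

# Hypothetical shape only; inferred from the keys accessed above
config = {
    "num_epochs": 90,                 # also copied into the optimizer config
    "test_only": False,               # optional, defaults to False
    "broadcast_buffers": "DISABLED",  # BroadcastBuffersMode member name
    "optimizer": {"name": "sgd"},
    "loss": {"name": "CrossEntropyLoss"},
    "model": {"name": "my_model"},
    "meters": {},                     # optional, defaults to {}
    "dataset": {"train": {"name": "my_dataset"}, "test": {"name": "my_dataset"}},
}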
Example #5
    def test_raise_error_on_non_trainable_params(self):
        # A ValueError should be raised if there are no trainable params
        # in the model.
        config = self._get_config()
        with self.assertRaises(ValueError):
            opt = build_optimizer(config)
            opt.init_pytorch_optimizer(
                self._get_mock_classy_vision_model(trainable_params=False)
            )
Example #6
    def test_set_invalid_state(self):
        config = self._get_config()
        opt = build_optimizer(config)
        opt.set_param_groups(self._parameters())
        self.assertTrue(isinstance(opt, self._instance_to_test()))

        with self.assertRaises(KeyError):
            opt.set_classy_state({})
Example #7
    def _get_set_state(self, grad_values):
        config = self._get_config()

        mock_classy_vision_model = self._get_mock_classy_vision_model()
        opt1 = build_optimizer(config)
        opt1.init_pytorch_optimizer(mock_classy_vision_model)

        self._set_model_gradient(mock_classy_vision_model, grad_values)
        opt1.step()
        state = opt1.get_classy_state()

        config["lr"] += 0.1
        opt2 = build_optimizer(config)
        opt2.init_pytorch_optimizer(mock_classy_vision_model)
        self.assertTrue(isinstance(opt1, self._instance_to_test()))
        opt2.set_classy_state(state)
        self.assertEqual(opt1.parameters, opt2.parameters)
        for i in range(len(opt1.optimizer.param_groups[0]["params"])):
            self.assertTrue(
                torch.allclose(
                    opt1.optimizer.param_groups[0]["params"][i],
                    opt2.optimizer.param_groups[0]["params"][i],
                ))
        self._compare_momentum_values(opt1.get_classy_state()["optim"],
                                      opt2.get_classy_state()["optim"])

        # check that the two optimizers behave the same on a parameter update
        mock_classy_vision_model1 = self._get_mock_classy_vision_model()
        mock_classy_vision_model2 = self._get_mock_classy_vision_model()
        self._set_model_gradient(mock_classy_vision_model1, grad_values)
        self._set_model_gradient(mock_classy_vision_model2, grad_values)
        opt1 = build_optimizer(config)
        opt1.init_pytorch_optimizer(mock_classy_vision_model1)
        opt2 = build_optimizer(config)
        opt2.init_pytorch_optimizer(mock_classy_vision_model2)
        opt1.step()
        opt2.step()
        for i in range(len(opt1.optimizer.param_groups[0]["params"])):
            print(opt1.optimizer.param_groups[0]["params"][i])
            self.assertTrue(
                torch.allclose(
                    opt1.optimizer.param_groups[0]["params"][i],
                    opt2.optimizer.param_groups[0]["params"][i],
                ))
        self._compare_momentum_values(opt1.get_classy_state()["optim"],
                                      opt2.get_classy_state()["optim"])
Example #8
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
        """Instantiates a ClassificationTask from a configuration.

        Args:
            config: A configuration for a ClassificationTask.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            A ClassificationTask instance.
        """
        optimizer_config = config["optimizer"]

        # TODO Make distinction between epochs and phases in optimizer clear
        train_phases_per_epoch = config["dataset"]["train"].get("phases_per_epoch", 1)
        optimizer_config["num_epochs"] = config["num_epochs"] * train_phases_per_epoch

        datasets = {}
        phase_types = ["train", "test"]
        for phase_type in phase_types:
            datasets[phase_type] = build_dataset(config["dataset"][phase_type])
        loss = build_loss(config["loss"])
        test_only = config.get("test_only", False)
        amp_args = config.get("amp_args")
        meters = build_meters(config.get("meters", {}))
        model = build_model(config["model"])

        # hooks config is optional
        hooks_config = config.get("hooks")
        hooks = []
        if hooks_config is not None:
            hooks = build_hooks(hooks_config)

        optimizer = build_optimizer(optimizer_config)

        task = (
            cls()
            .set_num_epochs(config["num_epochs"])
            .set_test_phase_period(config.get("test_phase_period", 1))
            .set_loss(loss)
            .set_test_only(test_only)
            .set_model(model)
            .set_optimizer(optimizer)
            .set_meters(meters)
            .set_amp_args(amp_args)
            .set_distributed_options(
                broadcast_buffers_mode=BroadcastBuffersMode[
                    config.get("broadcast_buffers", "disabled").upper()
                ],
                batch_norm_sync_mode=BatchNormSyncMode[
                    config.get("batch_norm_sync_mode", "disabled").upper()
                ],
                find_unused_parameters=config.get("find_unused_parameters", True),
            )
            .set_hooks(hooks)
        )

        use_gpu = config.get("use_gpu")
        if use_gpu is not None:
            task.set_use_gpu(use_gpu)

        for phase_type in phase_types:
            task.set_dataset(datasets[phase_type], phase_type)

        return task
Example #9
    def test_set_invalid_state(self):
        config = self._get_config()
        mock_classy_vision_model = self._get_mock_classy_vision_model()
        opt = build_optimizer(config)
        opt.init_pytorch_optimizer(mock_classy_vision_model)
        self.assertTrue(isinstance(opt, self._instance_to_test()))

        with self.assertRaises(KeyError):
            opt.set_classy_state({})
Example #10
    def test_step_args(self):
        opt = build_optimizer(self._get_config())
        opt.set_param_groups([torch.tensor([1.0], requires_grad=True)])

        # where argument must be named explicitly
        with self.assertRaises(RuntimeError):
            opt.step(0)

        # this shouldn't crash
        opt.step(where=0)
Example #11
    def _build_optimizer(self):
        """
        Build the optimizer using the optimizer settings specified by the user.
        For SGD, LARC is supported as well; in order to use LARC, Apex must
        be installed.
        """
        optimizer_config = self.config["OPTIMIZER"]
        if optimizer_config.use_larc and optimizer_config.name != "sgd_fsdp":
            assert is_apex_available(), "Apex must be available to use LARC"
        optim = build_optimizer(optimizer_config)
        return optim
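
The attribute access above (optimizer_config.use_larc) suggests an
AttrDict-style config. A hedged sketch of the relevant section, listing only
the fields this snippet actually checks (the real VISSL schema has many more):

# Inferred from the checks above; anything beyond these fields is an assumption
config = {
    "OPTIMIZER": {
        "name": "sgd",     # with "sgd_fsdp" the Apex/LARC assertion is skipped
        "use_larc": True,  # requires Apex: is_apex_available() must be True
    }
}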
Example #12
    def _build_task(self, num_epochs, skip_param_schedulers=False):
        config = self._get_config(skip_param_schedulers)
        config["optimizer"]["num_epochs"] = num_epochs
        task = (
            ClassificationTask()
            .set_num_epochs(num_epochs)
            .set_loss(build_loss(config["loss"]))
            .set_model(build_model(config["model"]))
            .set_optimizer(build_optimizer(config["optimizer"]))
        )
        for phase_type in ["train", "test"]:
            dataset = build_dataset(config["dataset"][phase_type])
            task.set_dataset(dataset, phase_type)

        self.assertTrue(task is not None)
        return task
Example #13
    def _pretraining_worker(
        gpu_id: int,
        with_fsdp: bool,
        with_activation_checkpointing: bool,
        with_larc: bool,
        sync_file: str,
        result_file: str,
    ):
        init_distributed_on_file(world_size=2,
                                 gpu_id=gpu_id,
                                 sync_file=sync_file)
        torch.manual_seed(0)
        torch.backends.cudnn.deterministic = True

        # Create the inputs
        batch = torch.randn(size=(8, 3, 224, 224)).cuda()
        target = torch.tensor(0.0).cuda()

        # Create a fake model based on SwAV blocks
        config = TestRegnetFSDP._create_pretraining_config(
            with_fsdp, with_activation_checkpointing, with_larc=with_larc)
        model = build_model(config["MODEL"], config["OPTIMIZER"])
        model = model.cuda()
        if with_fsdp:
            model = fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
        else:
            model = DistributedDataParallel(model, device_ids=[gpu_id])
        criterion = SwAVLoss(loss_config=config["LOSS"]["swav_loss"])
        optimizer = build_optimizer(config["OPTIMIZER"])
        optimizer.set_param_groups(model.parameters())

        # Run a few iterations and collect the losses
        losses = []
        num_iterations = 5
        for iteration in range(num_iterations):
            out = model(batch)
            loss = criterion(out[0], target)
            if gpu_id == 0:
                losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            if iteration <= 2:
                # Freeze the SwAV prototypes during the first iterations by
                # dropping their gradients before the optimizer step
                for name, param in model.named_parameters():
                    if "prototypes" in name:
                        param.grad = None
            optimizer.step(where=float(iteration / num_iterations))

        # Store the losses in a file to compare several methods
        if gpu_id == 0:
            with open(result_file, "wb") as f:
                pickle.dump(losses, f)
Example #14
    def test_batchnorm_weight_decay(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(2, 3)
                self.relu = nn.ReLU()
                self.bn = nn.BatchNorm1d(3)

            def forward(self, x):
                return self.bn(self.relu(self.lin(x)))

        torch.manual_seed(1)
        model = MyModel()

        opt = build_optimizer(self._get_config())
        bn_params, lin_params = split_batchnorm_params(model)

        lin_param_before = model.lin.weight.detach().clone()
        bn_param_before = model.bn.weight.detach().clone()

        with torch.enable_grad():
            x = torch.tensor([[1.0, 1.0], [1.0, 2.0]])
            out = model(x).pow(2).sum()
            out.backward()

        opt.set_param_groups([
            {
                "params": lin_params,
                "lr": LinearParamScheduler(1, 2),
                "weight_decay": 0.5,
            },
            {
                "params": bn_params,
                "lr": 0,
                "weight_decay": 0
            },
        ])

        opt.step(where=0.5)

        # Make sure the linear parameters are trained but not the batch norm
        self.assertFalse(torch.allclose(model.lin.weight, lin_param_before))
        self.assertTrue(torch.allclose(model.bn.weight, bn_param_before))

        opt.step(where=0.5)

        # Same, but after another step and triggering the lr scheduler
        self.assertFalse(torch.allclose(model.lin.weight, lin_param_before))
        self.assertTrue(torch.allclose(model.bn.weight, bn_param_before))
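
The test above relies on split_batchnorm_params to separate batch-norm weights
(which conventionally get no weight decay) from the remaining parameters. A
plausible implementation of such a helper, shown purely as a sketch; the
library's actual code may differ:

import torch.nn as nn

def split_batchnorm_params(model: nn.Module):
    """Return (batch-norm params, all other trainable params)."""
    bn_params, other_params = [], []
    for module in model.modules():
        if next(module.children(), None) is not None:
            continue  # only inspect leaf modules
        params = [p for p in module.parameters() if p.requires_grad]
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            bn_params.extend(params)
        else:
            other_params.extend(params)
    return bn_params, other_params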
Example #15
    def test_get_state(self):
        config = get_test_task_config()
        loss = build_loss(config["loss"])
        task = (
            ClassificationTask()
            .set_num_epochs(1)
            .set_loss(loss)
            .set_model(build_model(config["model"]))
            .set_optimizer(build_optimizer(config["optimizer"]))
        )
        for phase_type in ["train", "test"]:
            dataset = build_dataset(config["dataset"][phase_type])
            task.set_dataset(dataset, phase_type)

        task.prepare()

        task = build_task(config)
        task.prepare()
Example #16
    def test_training(self):
        """Checks we can train a small MLP model."""
        config = get_test_mlp_task_config()
        task = (
            ClassificationTask()
            .set_num_epochs(10)
            .set_loss(build_loss(config["loss"]))
            .set_model(build_model(config["model"]))
            .set_optimizer(build_optimizer(config["optimizer"]))
            .set_meters([AccuracyMeter(topk=[1])])
            .set_hooks([LossLrMeterLoggingHook()])
        )
        for split in ["train", "test"]:
            dataset = build_dataset(config["dataset"][split])
            task.set_dataset(dataset, split)

        self.assertTrue(task is not None)

        trainer = LocalTrainer()
        trainer.train(task)
        accuracy = task.meters[0].value["top_1"]
        self.assertAlmostEqual(accuracy, 1.0)
Example #17
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
        """Instantiates a ClassificationTask from a configuration.

        Args:
            config: A configuration for a ClassificationTask.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            A ClassificationTask instance.
        """
        optimizer_config = config["optimizer"]
        optimizer_config["num_epochs"] = config["num_epochs"]

        datasets = {}
        phase_types = ["train", "test"]
        for phase_type in phase_types:
            datasets[phase_type] = build_dataset(config["dataset"][phase_type])
        loss = build_loss(config["loss"])
        test_only = config.get("test_only", False)
        amp_args = config.get("amp_args")
        meters = build_meters(config.get("meters", {}))
        model = build_model(config["model"])
        optimizer = build_optimizer(optimizer_config)

        task = (
            cls()
            .set_num_epochs(config["num_epochs"])
            .set_test_phase_period(config.get("test_phase_period", 1))
            .set_loss(loss)
            .set_test_only(test_only)
            .set_model(model)
            .set_optimizer(optimizer)
            .set_meters(meters)
            .set_amp_args(amp_args)
            .set_distributed_options(
                broadcast_buffers_mode=BroadcastBuffersMode[
                    config.get("broadcast_buffers", "disabled").upper()
                ],
                batch_norm_sync_mode=BatchNormSyncMode[
                    config.get("batch_norm_sync_mode", "disabled").upper()
                ],
            )
        )
        for phase_type in phase_types:
            task.set_dataset(datasets[phase_type], phase_type)

        return task
Example #18
    @classmethod
    def from_config(cls, config):
        return cls(base_optimizer=build_optimizer(config["base_optimizer"]))
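
A wrapper optimizer like this builds its inner optimizer from a nested config.
A hedged sketch of such a config, assuming the "zero" wrapper exercised in
Example #1 (note its consolidate_state_dict calls) follows this pattern:

# Assumed shape only, based on the config["name"] == "zero" checks in Example #1
config = {
    "name": "zero",
    "base_optimizer": {"name": "sgd", "momentum": 0.9},
}
opt = build_optimizer(config)  # the wrapper builds its own base optimizer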
Example #19
    def test_build_sgd(self):
        config = self._get_config()
        opt = build_optimizer(config)
        opt.set_param_groups(self._parameters())
        self.assertTrue(isinstance(opt, self._instance_to_test()))
Example #20
    def test_lr_schedule(self):
        config = self._get_config()

        mock_classy_vision_model = self._get_mock_classy_vision_model()
        opt = build_optimizer(config)
        opt.init_pytorch_optimizer(mock_classy_vision_model)

        # Test initial learning rate
        for group in opt.optimizer.param_groups:
            self.assertEqual(group["lr"], 0.1)

        def _test_lr_schedule(optimizer, num_epochs, epochs, targets):
            for epoch, target in zip(epochs, targets):
                # Copy each group dict so the pre-update values survive the
                # schedule update (a plain list.copy() would alias the group
                # dicts and make the comparison below vacuous)
                param_groups = [
                    dict(group) for group in optimizer.optimizer.param_groups
                ]
                optimizer.update_schedule_on_epoch(epoch / num_epochs)
                for idx, group in enumerate(optimizer.optimizer.param_groups):
                    self.assertEqual(group["lr"], target)
                    # Make sure everything but the LR stayed the same
                    param_groups[idx]["lr"] = target
                    self.assertEqual(param_groups[idx], group)

        # Test constant learning schedule
        num_epochs = 90
        epochs = [
            0, 0.025, 0.05, 0.1, 0.5, 1, 15, 29, 30, 31, 59, 60, 61, 88, 89
        ]
        targets = [0.1] * 15
        _test_lr_schedule(opt, num_epochs, epochs, targets)

        # Test step learning schedule
        config["lr"] = {"name": "step", "values": [0.1, 0.01, 0.001]}
        opt = build_optimizer(config)
        opt.init_pytorch_optimizer(mock_classy_vision_model)
        targets = [0.1] * 8 + [0.01] * 3 + [0.001] * 4
        _test_lr_schedule(opt, num_epochs, epochs, targets)

        # Test step learning schedule with warmup
        init_lr = 0.01
        warmup_epochs = 0.1
        config["lr"] = {
            "name":
            "composite",
            "schedulers": [
                {
                    "name": "linear",
                    "start_lr": init_lr,
                    "end_lr": 0.1
                },
                {
                    "name": "step",
                    "values": [0.1, 0.01, 0.001]
                },
            ],
            "update_interval":
            "epoch",
            "interval_scaling": ["rescaled", "fixed"],
            "lengths":
            [warmup_epochs / num_epochs, 1 - warmup_epochs / num_epochs],
        }

        opt = build_optimizer(config)
        opt.init_pytorch_optimizer(mock_classy_vision_model)
        targets = [0.01, 0.0325, 0.055] + [0.1] * 5 + [0.01] * 3 + [0.001] * 4
        _test_lr_schedule(opt, num_epochs, epochs, targets)
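
The first three warmup targets follow directly from the linear scheduler: the
warmup covers warmup_epochs = 0.1 of an epoch with "rescaled" interval scaling,
so at epoch e the LR is init_lr + (e / warmup_epochs) * (0.1 - init_lr). A
quick check of the values asserted above:

init_lr, end_lr, warmup_epochs = 0.01, 0.1, 0.1
for e in (0, 0.025, 0.05):
    print(init_lr + (e / warmup_epochs) * (end_lr - init_lr))
# -> 0.01, 0.0325, 0.055, matching the first three targets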
Example #21
    def test_set_param_groups(self):
        opt = build_optimizer(self._get_config())
        # This must crash since we're missing the .set_param_groups call
        with self.assertRaises(RuntimeError):
            opt.step(where=0)
Example #22
def main():
    args = parser.parse_args()
    print(args)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    if args.cuda and args.mkldnn:
        raise RuntimeError(
            "Cannot use both the GPU and MKLDNN backends; please pick one."
        )

    if args.cuda:
        print("Using the GPU backend.\n")
    elif args.mkldnn:
        print("Using the MKLDNN backend.\n")
    else:
        print("Using the native CPU backend.\n")

    # set it to the folder where video files are saved
    video_dir = args.video_dir + "/UCF-101"
    # set it to the folder where dataset splitting files are saved
    splits_dir = args.video_dir + "/ucfTrainTestlist"
    # set it to the file path for saving the metadata
    metadata_file = args.video_dir + "/metadata.pth"

    resnext3d_configs = model_config.ResNeXt3D_Config(
        video_dir, splits_dir, metadata_file, args.num_epochs
    )
    resnext3d_configs.setUp()

    datasets = {}
    dataset_train_configs = resnext3d_configs.dataset_configs["train"]
    dataset_test_configs = resnext3d_configs.dataset_configs["test"]
    dataset_train_configs["batchsize_per_replica"] = args.batch_size_train
    # For testing, batchsize per replica should be equal to clips_per_video
    dataset_test_configs["batchsize_per_replica"] = args.batch_size_eval
    dataset_test_configs["clips_per_video"] = args.batch_size_eval

    datasets["train"] = build_dataset(dataset_train_configs)
    datasets["test"] = build_dataset(dataset_test_configs)

    model = build_model(resnext3d_configs.model_configs)
    meters = build_meters(resnext3d_configs.meters_configs)
    loss = build_loss({"name": "CrossEntropyLoss"})
    optimizer = build_optimizer(resnext3d_configs.optimizer_configs)

    # some ops are not supported by MKLDNN, so swap in an MKLDNN-friendly head
    if args.mkldnn:
        heads_configs = resnext3d_configs.model_configs['heads'][0]
        in_plane = heads_configs['in_plane']
        num_classes = heads_configs['num_classes']
        act_func = heads_configs['activation_func']
        mkldnn_head_fcl = MkldnnFullyConvolutionalLinear(in_plane, num_classes, act_func)

        if args.evaluate:
            model = model.eval()
            model = mkldnn_utils.to_mkldnn(model)
            model._heads['pathway0-stage4-block2']['default_head'].head_fcl = mkldnn_head_fcl.eval()
        else:
            model._heads['pathway0-stage4-block2']['default_head'].head_fcl = mkldnn_head_fcl

    if args.evaluate:
        validata(datasets, model, loss, meters, args)
        return

    train(datasets, model, loss, optimizer, meters, args)
Example #23
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
        """Instantiates a ClassificationTask from a configuration.

        Args:
            config: A configuration for a ClassificationTask.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            A ClassificationTask instance.
        """
        test_only = config.get("test_only", False)
        if not test_only:
            # TODO Make distinction between epochs and phases in optimizer clear
            train_phases_per_epoch = config["dataset"]["train"].get(
                "phases_per_epoch", 1)

            optimizer_config = config["optimizer"]
            optimizer_config["num_epochs"] = (config["num_epochs"] *
                                              train_phases_per_epoch)
            optimizer = build_optimizer(optimizer_config)
            param_schedulers = build_optimizer_schedulers(optimizer_config)

        datasets = {}
        phase_types = ["train", "test"]
        for phase_type in phase_types:
            if phase_type in config["dataset"]:
                datasets[phase_type] = build_dataset(
                    config["dataset"][phase_type])
        loss = build_loss(config["loss"])
        amp_args = config.get("amp_args")
        meters = build_meters(config.get("meters", {}))
        model = build_model(config["model"])

        mixup_transform = None
        if config.get("mixup") is not None:
            assert "alpha" in config["mixup"], "key alpha is missing in mixup dict"
            mixup_transform = MixupTransform(
                config["mixup"]["alpha"], config["mixup"].get("num_classes")
            )

        # hooks config is optional
        hooks_config = config.get("hooks")
        hooks = []
        if hooks_config is not None:
            hooks = build_hooks(hooks_config)

        distributed_config = config.get("distributed", {})
        distributed_options = {
            "broadcast_buffers_mode": BroadcastBuffersMode[
                distributed_config.get("broadcast_buffers", "before_eval").upper()
            ],
            "batch_norm_sync_mode": BatchNormSyncMode[
                distributed_config.get("batch_norm_sync_mode", "disabled").upper()
            ],
            "batch_norm_sync_group_size": distributed_config.get(
                "batch_norm_sync_group_size", 0
            ),
            "find_unused_parameters": distributed_config.get(
                "find_unused_parameters", True
            ),
        }

        task = (
            cls()
            .set_num_epochs(config["num_epochs"])
            .set_test_phase_period(config.get("test_phase_period", 1))
            .set_loss(loss)
            .set_test_only(test_only)
            .set_model(model)
            .set_meters(meters)
            .set_amp_args(amp_args)
            .set_mixup_transform(mixup_transform)
            .set_distributed_options(**distributed_options)
            .set_hooks(hooks)
            .set_bn_weight_decay(config.get("bn_weight_decay", False))
        )

        if not test_only:
            task.set_optimizer(optimizer)
            task.set_optimizer_schedulers(param_schedulers)

        use_gpu = config.get("use_gpu")
        if use_gpu is not None:
            task.set_use_gpu(use_gpu)

        for phase_type in datasets:
            task.set_dataset(datasets[phase_type], phase_type)

        # NOTE: this is a private member and only meant to be used for
        # logging/debugging purposes. See __repr__ implementation
        task._config = config

        return task