Example #1
    def test_opt_mismatch(self):
        # test detection of an optimizer mismatch between save and load
        with tempfile.TemporaryDirectory() as ckpt_dir:
            param_1 = nn.Parameter(torch.Tensor([1]))
            optimizer_1 = alf.optimizers.Adam(lr=0.2)
            alg_1_no_op = SimpleAlg(params=[param_1], name="alg_1_no_op")
            alg_1 = SimpleAlg(
                params=[param_1], optimizer=optimizer_1, name="alg_1")

            param_2 = nn.Parameter(torch.Tensor([2]))
            optimizer_2 = alf.optimizers.Adam(lr=0.2)
            alg_2 = SimpleAlg(
                params=[param_2], optimizer=optimizer_2, name="alg_2")

            optimizer_root = alf.optimizers.Adam(lr=0.1)
            param_root = nn.Parameter(torch.Tensor([0]))
            alg_root_1_no_op = ComposedAlg(
                params=[param_root],
                optimizer=optimizer_root,
                sub_alg1=alg_1_no_op,
                sub_alg2=alg_2,
                name="root")

            alg_root_1 = ComposedAlg(
                params=[param_root],
                optimizer=optimizer_root,
                sub_alg1=alg_1,
                sub_alg2=alg_2,
                name="root")

            # case 1: save using alg_root_1_no_op and load using alg_root_1
            step_num = 0
            ckpt_mngr = ckpt_utils.Checkpointer(ckpt_dir, alg=alg_root_1_no_op)
            ckpt_mngr.save(step_num)
            ckpt_mngr = ckpt_utils.Checkpointer(ckpt_dir, alg=alg_root_1)
            self.assertRaises(RuntimeError, ckpt_mngr.load, step_num)

            # case 2: save using alg_root_1 and load using alg_root_1_no_op
            step_num = 0
            ckpt_mngr = ckpt_utils.Checkpointer(ckpt_dir, alg=alg_root_1)
            ckpt_mngr.save(step_num)
            ckpt_mngr = ckpt_utils.Checkpointer(ckpt_dir, alg=alg_root_1_no_op)
            self.assertRaises(RuntimeError, ckpt_mngr.load, step_num)
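The SimpleAlg and ComposedAlg fixtures used throughout these examples are not shown. Below is a minimal sketch of what they could look like, inferred only from the constructor calls above and the state_dict keys (such as '_param_list.0' and '_sub_alg1') used in the later examples; the real definitions live in ALF's checkpoint_utils_test.py and may differ:

import torch
import torch.nn as nn

from alf.algorithms.algorithm import Algorithm


class SimpleAlg(Algorithm):
    """A leaf algorithm holding a few parameters and optional sub-algorithms."""

    def __init__(self, params=(), sub_algs=(), optimizer=None, name="SimpleAlg"):
        super().__init__(optimizer=optimizer, name=name)
        self._param_list = nn.ParameterList(list(params))
        self._sub_algs = nn.ModuleList(list(sub_algs))


class ComposedAlg(Algorithm):
    """An algorithm composed of two named sub-algorithms."""

    def __init__(self, params=(), optimizer=None, sub_alg1=None, sub_alg2=None,
                 name="ComposedAlg"):
        super().__init__(optimizer=optimizer, name=name)
        self._param_list = nn.ParameterList(list(params))
        self._sub_alg1 = sub_alg1
        self._sub_alg2 = sub_alg2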
Example #2
    def test_multi_algo_single_opt(self):
        with tempfile.TemporaryDirectory() as ckpt_dir:
            # construct algorithms
            param_1 = nn.Parameter(torch.Tensor([1.0]))
            alg_1 = SimpleAlg(params=[param_1], name="alg_1")

            param_2_1 = nn.Parameter(torch.Tensor([2.1]))
            alg_2_1 = SimpleAlg(params=[param_2_1], name="alg_2_1")

            param_2 = nn.Parameter(torch.Tensor([2]))
            alg_2 = SimpleAlg(
                params=[param_2], sub_algs=[alg_2_1], name="alg_2")

            optimizer_root = alf.optimizers.Adam(lr=0.1)
            param_root = nn.Parameter(torch.Tensor([0]))
            alg_root = SimpleAlg(
                params=[param_root],
                optimizer=optimizer_root,
                sub_algs=[alg_1, alg_2],
                name="root")

            ckpt_mngr = ckpt_utils.Checkpointer(ckpt_dir, alg=alg_root)

            all_optimizers = alg_root.optimizers()
            # simulate a few training steps, saving a checkpoint at each step
            step_num = 0
            ckpt_mngr.save(step_num)
            step_num = 1
            set_learning_rate(all_optimizers, 0.01)
            alg_root.apply(weights_init_ones)
            ckpt_mngr.save(step_num)

            self.assertTrue(get_learning_rate(all_optimizers) == [0.01])

            # load checkpoints
            ckpt_mngr.load(0)

            # check the recovered optimizers
            self.assertTrue(get_learning_rate(all_optimizers) == [0.1])

            # check the recovered parameter values for all modules
            sd = alg_root.state_dict()
            self.assertTrue(list(sd.values())[0:4] == [
                torch.tensor([1.0]),
                torch.tensor([2.1]),
                torch.tensor([2.0]),
                torch.tensor([0.0])
            ])
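The weight-init and learning-rate helpers (weights_init_ones, weights_init_zeros, set_learning_rate, get_learning_rate) are also assumed by these tests. A plausible sketch, based only on how they are called here (the _as_list helper is hypothetical):

import torch


def weights_init_ones(m):
    # used with Module.apply(); sets every parameter owned directly by m to 1
    with torch.no_grad():
        for p in m.parameters(recurse=False):
            p.fill_(1.)


def weights_init_zeros(m):
    with torch.no_grad():
        for p in m.parameters(recurse=False):
            p.fill_(0.)


def _as_list(optimizers):
    return optimizers if isinstance(optimizers, list) else [optimizers]


def set_learning_rate(optimizers, lr):
    # accepts a single optimizer or a list of optimizers
    for opt in _as_list(optimizers):
        for group in opt.param_groups:
            group['lr'] = lr


def get_learning_rate(optimizers):
    # returns one lr per optimizer (read from its first param group)
    return [opt.param_groups[0]['lr'] for opt in _as_list(optimizers)]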
Example #3
    def test_net_and_optimizer(self):
        net = Net()
        optimizer = torch.optim.Adam(net.parameters(), lr=0.1)

        with tempfile.TemporaryDirectory() as ckpt_dir:
            ckpt_mngr = ckpt_utils.Checkpointer(ckpt_dir,
                                                net=net,
                                                optimizer=optimizer)

            # test loading from 'latest' when no checkpoint exists yet
            self.assertWarns(UserWarning, ckpt_mngr.load, 'latest')

            # training-step-0, all parameters are zeros
            step_num = 0
            net.apply(weights_init_zeros)
            set_learning_rate(optimizer, 0.1)
            ckpt_mngr.save(step_num)

            # training-step-1, all parameters are ones
            step_num = 1
            net.apply(weights_init_ones)
            set_learning_rate(optimizer, 0.01)
            ckpt_mngr.save(step_num)

            # load ckpt-1
            ckpt_mngr.load(global_step=1)
            self.assertTrue(get_learning_rate(optimizer)[0] == 0.01)
            for para in list(net.parameters()):
                self.assertTrue((para == 1).all())

            # load ckpt-0
            ckpt_mngr.load(global_step=0)
            self.assertTrue(get_learning_rate(optimizer)[0] == 0.1)
            for para in list(net.parameters()):
                self.assertTrue((para == 0).all())

            # load 'latest'
            step_num_from_ckpt = ckpt_mngr.load(global_step='latest')
            self.assertTrue(step_num_from_ckpt == step_num)
            self.assertTrue(get_learning_rate(optimizer)[0] == 0.01)
            for para in list(net.parameters()):
                self.assertTrue((para == 1).all())

            # loading a non-existing ckpt won't change current values,
            # but will trigger a UserWarning
            self.assertWarns(UserWarning, ckpt_mngr.load, 2)
            self.assertTrue(get_learning_rate(optimizer)[0] == 0.01)
            for para in list(net.parameters()):
                self.assertTrue((para == 1).all())
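Net is not defined in this listing; any small nn.Module works, since the test only zeroes/ones its parameters and round-trips them through the checkpointer. A hypothetical stand-in:

import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # the layer choice is arbitrary; only the parameters matter here
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)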
Example #4
    def test_with_param_sharing(self):
        with tempfile.TemporaryDirectory() as ckpt_dir:
            # construct algorithms
            param_1 = nn.Parameter(torch.Tensor([1]))
            alg_1 = SimpleAlg(params=[param_1], name="alg_1")

            param_2 = nn.Parameter(torch.Tensor([2]))
            optimizer_2 = alf.optimizers.Adam(lr=0.2)
            alg_2 = SimpleAlg(
                params=[param_2], optimizer=optimizer_2, name="alg_2")
            alg_2.ignored_param = param_1

            optimizer_root = alf.optimizers.Adam(lr=0.1)
            param_root = nn.Parameter(torch.Tensor([0]))
            alg_root = ComposedAlg(
                params=[param_root],
                optimizer=optimizer_root,
                sub_alg1=alg_1,
                sub_alg2=alg_2,
                name="root")

            ckpt_mngr = ckpt_utils.Checkpointer(ckpt_dir, alg=alg_root)

            # only one copy of the shared param is returned from state_dict
            self.assertTrue(
                '_sub_alg2.ignored_param' not in alg_root.state_dict())

            # save a checkpoint at step 0
            step_num = 0
            ckpt_mngr.save(step_num)

            # modify the shared param after saving
            with torch.no_grad():
                alg_root._sub_alg2.state_dict()['ignored_param'].copy_(
                    torch.Tensor([-1]))

            self.assertTrue((alg_root._sub_alg2.state_dict()['ignored_param']
                             == torch.Tensor([-1])))
            self.assertTrue((alg_root.state_dict()['_sub_alg1._param_list.0']
                             == torch.Tensor([-1])))

            # the value of the shared parameter is restored to the saved value
            ckpt_mngr.load(0)
            self.assertTrue((alg_root._sub_alg2.state_dict()['ignored_param']
                             == torch.Tensor([1])))
            self.assertTrue((alg_root.state_dict()['_sub_alg1._param_list.0']
                             == torch.Tensor([1])))
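The assertions above rest on plain parameter aliasing: alg_2.ignored_param and alg_1's parameter are the same nn.Parameter object, so a write through one name is visible through the other, and restoring the single saved copy restores both views. A self-contained illustration of the aliasing itself (module names here are stand-ins, not ALF API):

import torch
import torch.nn as nn

shared = nn.Parameter(torch.Tensor([1.]))
view_a = nn.ParameterList([shared])   # plays the role of alg_1._param_list
view_b = nn.ParameterList([shared])   # plays the role of alg_2.ignored_param

with torch.no_grad():
    view_a[0].copy_(torch.Tensor([-1.]))

assert view_b[0].item() == -1.  # one storage, two names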
Example #5
    def test_multi_alg_multi_opt(self):
        with tempfile.TemporaryDirectory() as ckpt_dir:
            # construct algorithms
            param_1 = nn.Parameter(torch.Tensor([1]))
            alg_1 = SimpleAlg(params=[param_1], name="alg_1")

            param_2 = nn.Parameter(torch.Tensor([2]))
            optimizer_2 = alf.optimizers.Adam(lr=0.2)
            alg_2 = SimpleAlg(
                params=[param_2], optimizer=optimizer_2, name="alg_2")

            optimizer_root = alf.optimizers.Adam(lr=0.1)
            param_root = nn.Parameter(torch.Tensor([0]))
            alg_root = ComposedAlg(
                params=[param_root],
                optimizer=optimizer_root,
                sub_alg1=alg_1,
                sub_alg2=alg_2,
                name="root")

            ckpt_mngr = ckpt_utils.Checkpointer(ckpt_dir, alg=alg_root)

            all_optimizers = alg_root.optimizers()
            # simulate a few training steps, saving a checkpoint at each step
            step_num = 0
            ckpt_mngr.save(step_num)
            step_num = 1
            set_learning_rate(all_optimizers, 0.01)

            alg_root.apply(weights_init_ones)
            ckpt_mngr.save(step_num)

            # load checkpoints
            ckpt_mngr.load(0)

            # check the recovered optimizers
            expected = [0.1, 0.2]
            np.testing.assert_array_almost_equal(
                get_learning_rate(all_optimizers), expected)
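The expected learning rates [0.1, 0.2] come from each optimizer carrying its own lr in its param_groups, so saving and restoring recovers them independently. The same round-trip behavior in plain PyTorch (a stand-in illustration, not the ALF Checkpointer):

import torch

p_root = torch.nn.Parameter(torch.zeros(1))
p_2 = torch.nn.Parameter(torch.zeros(1))
opts = [torch.optim.Adam([p_root], lr=0.1),
        torch.optim.Adam([p_2], lr=0.2)]

saved = [opt.state_dict() for opt in opts]       # "save"
for opt in opts:
    opt.param_groups[0]['lr'] = 0.01             # mutate, as the test does
for opt, state in zip(opts, saved):
    opt.load_state_dict(state)                   # "load"

print([opt.param_groups[0]['lr'] for opt in opts])  # [0.1, 0.2]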
Example #6
    def test_with_cycle(self):
        # checkpointer should work regardless of cycles
        with tempfile.TemporaryDirectory() as ckpt_dir:
            # construct algorithms
            param_1 = nn.Parameter(torch.Tensor([1]))
            alg_1 = SimpleAlg(params=[param_1], name="alg_1")

            param_2 = nn.Parameter(torch.Tensor([2]))
            optimizer_2 = alf.optimizers.Adam(lr=0.2)
            alg_2 = SimpleAlg(
                params=[param_2], optimizer=optimizer_2, name="alg_2")

            optimizer_root = alf.optimizers.Adam(lr=0.1)
            param_root = nn.Parameter(torch.Tensor([0]))

            # case 1: cycle without ignore
            alg_root = ComposedAlg(
                params=[param_root],
                optimizer=optimizer_root,
                sub_alg1=alg_1,
                sub_alg2=alg_2,
                name="root")

            alg_2.root = alg_root

            expected_state_dict = OrderedDict(
                [('_sub_alg1._param_list.0', torch.tensor([1.])),
                 ('_sub_alg2._param_list.0', torch.tensor([2.])),
                 ('_sub_alg2._optimizers.0', {
                     'state': {},
                     'param_groups': [{
                         'lr': 0.2,
                         'betas': (0.9, 0.999),
                         'eps': 1e-08,
                         'weight_decay': 0,
                         'amsgrad': False,
                         'params': []
                     }, {
                         'lr': 0.2,
                         'betas': (0.9, 0.999),
                         'eps': 1e-08,
                         'weight_decay': 0,
                         'amsgrad': False,
                         'params': [id(param_2)]
                     }]
                 }),
                 ('_param_list.0', torch.tensor([0.])),
                 ('_optimizers.0', {
                     'state': {},
                     'param_groups': [{
                         'lr': 0.1,
                         'betas': (0.9, 0.999),
                         'eps': 1e-08,
                         'weight_decay': 0,
                         'amsgrad': False,
                         'params': []
                     }, {
                         'lr': 0.1,
                         'betas': (0.9, 0.999),
                         'eps': 1e-08,
                         'weight_decay': 0,
                         'amsgrad': False,
                         'params': [id(param_1), id(param_root)]
                     }]
                 })])

            # cycles are not allowed without explicit ignoring
            self.assertRaises(AssertionError, alg_root.state_dict)

            # case 2: cycle with ignore, which also resembles the case where a
            # self-training module (alg_2) is involved
            alg_root2 = ComposedAlgWithIgnore(
                params=[param_root],
                optimizer=optimizer_root,
                sub_alg1=alg_1,
                sub_alg2=alg_2,
                name="root")

            alg_2.root = alg_root2

            ckpt_mngr = ckpt_utils.Checkpointer(ckpt_dir, alg=alg_root2)
            ckpt_mngr.save(0)

            # modify some parameter values after saving; alg_1 is shared
            # between alg_root and alg_root2, so this also changes the
            # parameter seen by alg_root2
            with torch.no_grad():
                alg_root._sub_alg1._param_list[0].copy_(torch.Tensor([-1]))

            self.assertTrue((alg_root2.state_dict() != expected_state_dict))

            # recover the expected values after loading
            ckpt_mngr.load(0)
            self.assertTrue((alg_root2.state_dict() == expected_state_dict))
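For contrast, vanilla nn.Module.state_dict() has no cycle detection at all: a back-reference like alg_2.root sends its recursive traversal around the loop until the interpreter's recursion limit is hit. ALF's assertion in case 1 turns that failure mode into an explicit error, and the ignore mechanism used by ComposedAlgWithIgnore in case 2 breaks the cycle deliberately. A minimal demonstration of the underlying problem in plain PyTorch:

import torch.nn as nn

parent = nn.Module()
child = nn.Module()
parent.child = child
child.parent = parent          # back-reference creates a module cycle

try:
    parent.state_dict()        # recurses parent -> child -> parent -> ...
except RecursionError:
    print("plain PyTorch overflows on module cycles")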