Example #1
    def meta_optimize(self,
                      meta_optimizer,
                      data,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      tf_writer=None,
                      mode='train'):
        assert mode in ['train', 'valid', 'test']
        if mode == 'train':
            self.train()
        else:
            self.eval()

        result_dict = ResultDict()
        unroll_losses = 0
        walltime = 0

        params = C(model_cls()).params

        # lr_sgd = 0.2
        # lr_adam = 0.001
        # sgd = torch.optim.SGD(model.parameters(), lr=lr_sgd)
        # adam = torch.optim.Adam(model.parameters())
        # adagrad = torch.optim.Adagrad(model.parameters())

        iter_pbar = tqdm(range(1, optim_it + 1), 'inner_train')
        iter_watch = utils.StopWatch('inner_train')

        lr = 0.1
        topk = False
        n_sample = 10
        topk_decay = True
        recompute_grad = False

        self.topk_ratio = 0.5
        # self.topk_ratio = 1.0
        set_size = {'layer_0': 500, 'layer_1': 10}  # NOTE: do it smarter

        for iteration in iter_pbar:
            # # conventional optimizers with API
            # model.zero_grad()
            # loss = model(*inner_data.load())
            # loss.backward()
            # adam.step()
            # loss_detached = loss

            iter_watch.touch()
            model_train = C(model_cls(params=params.detach()))
            train_nll = model_train(*data['in_train'].load())
            train_nll.backward()

            params = model_train.params.detach()
            grad = model_train.params.grad.detach()
            act = model_train.activations.detach()

            # conventional optimizers with manual SGD
            params_good = params + grad * (-lr)
            grad_good = grad

            best_loss_sparse = float('inf')
            for i in range(n_sample):
                mask_gen = MaskGenerator.topk if topk else MaskGenerator.randk
                mask = mask_gen(grad=grad,
                                set_size=set_size,
                                topk=self.topk_ratio)

                if recompute_grad:
                    # recompute gradients on the pruned network
                    params_sparse = params.prune(mask)
                    model_sparse = C(model_cls(params=params_sparse.detach()))
                    loss_sparse = model_sparse(*data['in_train'].load())
                    loss_sparse.backward()
                    grad_sparse = model_sparse.params.grad
                else:
                    # use full gradients but sparsified
                    params_sparse = params.prune(mask)
                    grad_sparse = grad.prune(mask)

                params_sparse_ = params_sparse + grad_sparse * (-lr)  # SGD
                model_sparse = C(model_cls(params=params_sparse_.detach()))
                loss_sparse = model_sparse(*data['in_train'].load())
                if loss_sparse < best_loss_sparse:
                    best_loss_sparse = loss_sparse
                    best_params_sparse = params_sparse_
                    best_grads = grad_sparse
                    best_mask = mask

            # the lines below are for evaluation only (excluded from the time cost)
            walltime += iter_watch.touch('interval')
            # decay topk_ratio
            if topk_decay:
                self.topk_ratio *= 0.999

            # update!
            params = params.sparse_update(best_mask, best_grads * (-lr))

            # generalization performance
            model_test = C(model_cls(params=params.detach()))
            test_nll = model_test(*data['in_test'].load())

            # result dict
            result = dict(
                train_nll=train_nll.tolist(),
                test_nll=test_nll.tolist(),
                sparse_0=self.topk_ratio,
                sparse_1=self.topk_ratio,
            )
            result_dict.append(result)
            log_pbar(result, iter_pbar)
            # desc = [f'{k}: {v:5.5}' for k, v in result.items()]
            # iter_pbar.set_description(f"inner_train [ {' / '.join(desc)} ]")

        return result_dict
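
The sampling loop above depends on a MaskGenerator whose topk / randk methods return a per-layer unit mask, either ranked by gradient magnitude or drawn at random. The real class is not shown here; the snippet below is only a minimal sketch of that interface, assuming gradients arrive as a dict of per-layer tensors whose leading dimension matches the unit counts in set_size.

import torch


class MaskGenerator:
    """Minimal sketch of the assumed interface: per-layer boolean unit masks."""

    @classmethod
    def topk(cls, grad, set_size, topk):
        # keep the `topk` fraction of units with the largest mean |grad|
        masks = {}
        for name, n_units in set_size.items():
            k = max(1, int(n_units * topk))
            score = grad[name].abs().reshape(n_units, -1).mean(dim=1)
            mask = torch.zeros(n_units, dtype=torch.bool)
            mask[score.topk(k).indices] = True
            masks[name] = mask
        return masks

    @classmethod
    def randk(cls, grad, set_size, topk):
        # keep a random `topk` fraction of units in each layer
        masks = {}
        for name, n_units in set_size.items():
            k = max(1, int(n_units * topk))
            mask = torch.zeros(n_units, dtype=torch.bool)
            mask[torch.randperm(n_units)[:k]] = True
            masks[name] = mask
        return masks

Under this assumption, params.prune(mask) and grad.prune(mask) in the loop keep only the selected units of each layer before the candidate SGD step is evaluated.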
Example #2
    def meta_optimize(self,
                      cfg,
                      meta_optim,
                      data,
                      model_cls,
                      writer=None,
                      mode='train'):
        assert mode in ['train', 'valid', 'test']
        self.set_mode(mode)

        # mask_mode = ['no_mask', 'structured', 'unstructured'][2]
        # data_mode = ['in_train', 'in_test'][0]

        ############################################################################
        analyze_model = False
        analyze_surface = False
        ############################################################################

        result_dict = ResultDict()
        unroll_losses = 0
        walltime = Walltime()
        test_kld = torch.tensor(0.)

        params = C(model_cls()).params
        self.feature_gen.new()
        self.step_gen.new()
        sparse_r = {}  # sparsity
        iter_pbar = tqdm(range(1, cfg['iter_' + mode] + 1), 'Inner_loop')

        for iter in iter_pbar:
            debug_1 = sigint.is_active(iter == 1 or iter % 10 == 0)
            debug_2 = sigstp.is_active()

            with WalltimeChecker(walltime):
                model_train = C(model_cls(params.detach()))
                data_train = data['in_train'].load()
                train_nll, train_acc = model_train(*data_train)
                train_nll.backward()

                grad = model_train.params.grad.detach()
                # g = model_train.params.grad.flat.detach()
                # w = model_train.params.flat.detach()

                # step & mask generation
                feature, v_sqrt = self.feature_gen(grad.flat.detach())
                # step = self.step_gen(feature, v_sqrt, debug=debug_1)
                # step = params.new_from_flat(step[0])
                size = params.size().unflat()

                if cfg.mask_mode == 'structured':
                    mask = self.mask_gen(feature, size, debug=debug_1)
                    mask = ParamsFlattener(mask)
                    mask_layout = mask.expand_as(params)
                    params = params + grad.detach() * mask_layout * (
                        -cfg.inner_lr)
                elif cfg.mask_mode == 'unstructured':
                    mask_flat = self.mask_gen.unstructured(feature, size)
                    mask = params.new_from_flat(mask_flat)
                    params = params + grad.detach() * mask * (-cfg.inner_lr)
                    # update = params.new_from_flat(params.flat + grad.flat.detach() * mask * (-cfg.lr))
                    # params = params + update
                elif cfg.mask_mode == 'no_mask':
                    params = params + grad.detach() * (-cfg.inner_lr)
                else:
                    raise Exception(f'Unknown mask_mode: {cfg.mask_mode}')

                # import pdb; pdb.set_trace()
                # step_masked = step * mask_layout
                # params = params + step_masked

            with WalltimeChecker(walltime if mode == 'train' else None):
                model_test = C(model_cls(params))
                if cfg.data_mode == 'in_train':
                    data_test = data_train
                elif cfg.data_mode == 'in_test':
                    data_test = data['in_test'].load()
                test_nll, test_acc = utils.isnan(*model_test(*data_test))

                if debug_2: pdb.set_trace()

                if mode == 'train':
                    unroll_losses += test_nll  # + test_kld
                    if iter % cfg.unroll == 0:
                        meta_optim.zero_grad()
                        unroll_losses.backward()
                        nn.utils.clip_grad_value_(self.parameters(), 0.01)
                        meta_optim.step()
                        unroll_losses = 0

            with WalltimeChecker(walltime):
                if mode != 'train' or iter % cfg.unroll == 0:
                    params = params.detach_()

            ##########################################################################
            if analyze_model:
                analyzers.model_analyzer(self,
                                         mode,
                                         model_train,
                                         params,
                                         model_cls,
                                         mask.tsize(0),
                                         data,
                                         iter,
                                         cfg['iter_' + mode],
                                         analyze_mask=True,
                                         sample_mask=True,
                                         draw_loss=False)
            if analyze_surface:
                analyzers.surface_analyzer(params, best_mask, step, writer,
                                           iter)
            ##########################################################################

            result = dict(
                train_nll=train_nll.tolist(),
                test_nll=test_nll.tolist(),
                train_acc=train_acc.tolist(),
                test_acc=test_acc.tolist(),
                test_kld=test_kld.tolist(),
                walltime=walltime.time,
            )
            if cfg.mask_mode != 'no_mask':
                result.update(
                    **mask.sparsity(overall=True),
                    **mask.sparsity(overall=False),
                )
            result_dict.append(result)
            log_pbar(result, iter_pbar)

        return result_dict, params
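
Examples #2 through #5 time the inner loop with a Walltime accumulator and a WalltimeChecker context manager, passing None when a block should not be counted (e.g. the evaluation pass outside of training mode). Their implementations are not shown; the following is a minimal sketch consistent with that usage, not the project's actual code.

import time


class Walltime:
    """Sketch: simple accumulator exposing the .time attribute read into the results."""

    def __init__(self):
        self.time = 0.0


class WalltimeChecker:
    """Sketch: adds the block's elapsed wall-clock time to a Walltime,
    or does nothing when constructed with None."""

    def __init__(self, walltime=None):
        self.walltime = walltime

    def __enter__(self):
        if self.walltime is not None:
            self._start = time.time()
        return self

    def __exit__(self, *exc):
        if self.walltime is not None:
            self.walltime.time += time.time() - self._start
        return False

This matches the pattern WalltimeChecker(walltime if mode == 'train' else None) used above, where evaluation-only work is excluded from the reported walltime.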
Example #3
    def meta_optimize(self,
                      meta_optim,
                      data,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      k_obsrv=1,
                      no_mask=False,
                      writer=None,
                      mode='train'):
        assert mode in ['train', 'valid', 'test']
        if no_mask is True:
            raise Exception(
                "this module currently does NOT support the no_mask option")
        self.set_mode(mode)

        ############################################################################
        n_samples = k_obsrv
        """MSG: better postfix?"""
        analyze_model = False
        analyze_surface = False
        ############################################################################

        if analyze_surface:
            writer.new_subdirs('best', 'any', 'rand', 'inv', 'dense')

        result_dict = ResultDict()
        unroll_losses = 0
        walltime = Walltime()
        test_kld = torch.tensor(0.)

        params = C(model_cls()).params
        self.feature_gen.new()
        self.step_gen.new()
        iter_pbar = tqdm(range(1, optim_it + 1), 'Inner_loop')
        set_size = {'layer_0': 500, 'layer_1': 10}  # NOTE: make it smarter

        for iter in iter_pbar:
            debug_1 = sigint.is_active(iter == 1 or iter % 10 == 0)
            debug_2 = sigstp.is_active()

            best_loss = float('inf')
            best_params = None

            with WalltimeChecker(walltime):
                model_train = C(model_cls(params.detach()))
                train_nll, train_acc = model_train(*data['in_train'].load())
                train_nll.backward()

                g = model_train.params.grad.flat.detach()
                w = model_train.params.flat.detach()

                feature, v_sqrt = self.feature_gen(g)

                size = params.size().unflat()
                kld = self.mask_gen(feature, size)

                losses = []
                lips = []
                valid_mask_patience = 100
                assert n_samples > 0
                """FIX LATER:
        when n_samples == 0 it can behave like no_mask flag is on."""
                for i in range(n_samples):
                    # step & mask generation
                    for j in range(valid_mask_patience):
                        mask = self.mask_gen.sample_mask()
                        mask = ParamsFlattener(mask)
                        if mask.is_valid_sparsity():
                            if j > 0:
                                print(
                                    f'\n\n[!] Resampled {j + 1} times to get a valid mask!'
                                )
                            break
                        if j == valid_mask_patience - 1:
                            raise Exception(
                                "[!]Could not sample valid mask for "
                                f"{j+1} trials.")

                    step_out = self.step_gen(feature, v_sqrt, debug=debug_1)
                    step = params.new_from_flat(step_out[0])

                    mask_layout = mask.expand_as(params)
                    step_sparse = step * mask_layout
                    params_sparse = params + step_sparse
                    params_pruned = params_sparse.prune(mask > 0.5)

                    # skip masks that prune the first weight matrix down to zero width
                    if params_pruned.size().unflat()['mat_0'][1] == 0:
                        continue

                    # cand_loss = model(*outer_data_s)
                    sparse_model = C(model_cls(params_pruned.detach()))
                    loss, _ = sparse_model(*data['in_train'].load())

                    if (loss < best_loss) or i == 0:
                        best_loss = loss
                        best_params = params_sparse
                        best_pruned = params_pruned
                        best_mask = mask

                if best_params is not None:
                    params = best_params

            with WalltimeChecker(walltime if mode == 'train' else None):
                model_test = C(model_cls(params))
                test_nll, test_acc = utils.isnan(*model_test(
                    *data['in_test'].load()))
                test_kld = kld / data['in_test'].full_size / unroll
                # KL annealing schedule: 'linear' / 'logistic' / None
                test_kld2 = test_kld * kl_anneal_function(
                    anneal_function=None, step=iter, k=0.0025, x0=optim_it)
                total_test = test_nll + test_kld2

                if mode == 'train':
                    unroll_losses += total_test
                    if iter % unroll == 0:
                        meta_optim.zero_grad()
                        unroll_losses.backward()
                        nn.utils.clip_grad_value_(self.parameters(), 0.01)
                        meta_optim.step()
                        unroll_losses = 0

            with WalltimeChecker(walltime):
                if mode != 'train' or iter % unroll == 0:
                    params = params.detach_()

            ##########################################################################
            """Analyzers"""
            if analyze_model:
                analyzers.model_analyzer(self,
                                         mode,
                                         model_train,
                                         params,
                                         model_cls,
                                         set_size,
                                         data,
                                         iter,
                                         optim_it,
                                         analyze_mask=True,
                                         sample_mask=True,
                                         draw_loss=False)
            if analyze_surface:
                analyzers.surface_analyzer(params, best_mask, step, writer,
                                           iter)
            ##########################################################################

            result = dict(
                train_nll=train_nll.tolist(),
                test_nll=test_nll.tolist(),
                train_acc=train_acc.tolist(),
                test_acc=test_acc.tolist(),
                test_kld=test_kld.tolist(),
                walltime=walltime.time,
                **best_mask.sparsity(overall=True),
                **best_mask.sparsity(overall=False),
            )
            result_dict.append(result)
            log_pbar(result, iter_pbar)

        return result_dict, params
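
kl_anneal_function above is called with anneal_function in {'logistic', 'linear', None}, the current step, a slope k, and a midpoint x0. Its implementation is not included; the sketch below follows the common KL-annealing schedule that those arguments suggest and may differ from the actual code.

import math


def kl_anneal_function(anneal_function, step, k, x0):
    """Sketch: annealing weight in [0, 1] applied to the KL term."""
    if anneal_function == 'logistic':
        # sigmoid ramp centered at x0 with slope k
        return 1.0 / (1.0 + math.exp(-k * (step - x0)))
    elif anneal_function == 'linear':
        # linear ramp that reaches 1 at step x0
        return min(1.0, step / x0)
    else:
        # None: no annealing, constant full weight
        return 1.0

With anneal_function=None, as in the calls above, test_kld2 is simply test_kld.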
Example #4
    def meta_optimize(self,
                      meta_optim,
                      data,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      k_obsrv=0,
                      no_mask=False,
                      writer=None,
                      mode='train'):
        assert mode in ['train', 'valid', 'test']
        self.set_mode(mode)

        ############################################################################
        analyze_model = False
        analyze_surface = False
        ############################################################################

        result_dict = ResultDict()
        unroll_losses = 0
        walltime = Walltime()
        test_kld = torch.tensor(0.)

        params = C(model_cls()).params
        self.feature_gen.new()
        self.step_gen.new()
        sparse_r = {}  # sparsity
        iter_pbar = tqdm(range(1, optim_it + 1), 'Inner_loop')

        for iter in iter_pbar:
            debug_1 = sigint.is_active(iter == 1 or iter % 10 == 0)
            debug_2 = sigstp.is_active()

            with WalltimeChecker(walltime):
                model_train = C(model_cls(params.detach()))
                data_ = data['in_train'].load()
                train_nll, train_acc = model_train(*data_)
                train_nll.backward()

                g = model_train.params.grad.flat.detach()
                w = model_train.params.flat.detach()

                # step & mask generation
                feature, v_sqrt = self.feature_gen(g)
                step = self.step_gen(feature, v_sqrt, debug=debug_1)
                step = params.new_from_flat(step[0])
                size = params.size().unflat()

                if no_mask:
                    params = params + step
                else:
                    kld = self.mask_gen(feature, size, debug=debug_1)
                    test_kld = kld / data['in_test'].full_size / unroll
                    # KL annealing schedule: 'linear' / 'logistic' / None
                    test_kld2 = test_kld * kl_anneal_function(
                        anneal_function=None, step=iter, k=0.0025, x0=optim_it)
                    mask = self.mask_gen.sample_mask()
                    mask = ParamsFlattener(mask)
                    mask_layout = mask.expand_as(params)
                    if debug_2:
                        pdb.set_trace()
                    step_masked = step * mask_layout
                    params = params + step_masked

            with WalltimeChecker(walltime if mode == 'train' else None):
                model_test = C(model_cls(params))
                test_nll, test_acc = utils.isnan(*model_test(
                    *data['in_test'].load()))

                if debug_2: pdb.set_trace()

                if mode == 'train':
                    unroll_losses += test_nll  # + test_kld
                    if iter % unroll == 0:
                        meta_optim.zero_grad()
                        unroll_losses.backward()
                        nn.utils.clip_grad_value_(self.parameters(), 0.01)
                        meta_optim.step()
                        unroll_losses = 0

            with WalltimeChecker(walltime):
                if mode != 'train' or iter % unroll == 0:
                    params = params.detach_()

            ##########################################################################
            if analyze_model:
                analyzers.model_analyzer(self,
                                         mode,
                                         model_train,
                                         params,
                                         model_cls,
                                         mask.tsize(0),
                                         data,
                                         iter,
                                         optim_it,
                                         analyze_mask=True,
                                         sample_mask=True,
                                         draw_loss=False)
            if analyze_surface:
                analyzers.surface_analyzer(params, best_mask, step, writer,
                                           iter)
            ##########################################################################

            result = dict(
                train_nll=train_nll.tolist(),
                test_nll=test_nll.tolist(),
                train_acc=train_acc.tolist(),
                test_acc=test_acc.tolist(),
                test_kld=test_kld.tolist(),
                walltime=walltime.time,
            )
            if no_mask is False:
                result.update(
                    **mask.sparsity(overall=True),
                    **mask.sparsity(overall=False),
                )
            result_dict.append(result)
            log_pbar(result, iter_pbar)

        return result_dict, params
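
All of the examples share the same truncated-unrolling scheme for the meta-update: the outer (test) loss is accumulated over `unroll` inner steps, backpropagated through that window, the meta-gradients are value-clipped, the meta-optimizer steps, and the inner parameters are detached so the next window starts a fresh graph. The stripped-down sketch below isolates just that pattern; inner_step and outer_loss are hypothetical placeholders, and params is treated as a plain tensor rather than the ParamsFlattener used above.

import torch.nn as nn


def unrolled_meta_update(meta_net, meta_optim, inner_step, outer_loss,
                         params, n_iters, unroll):
    """Sketch: truncated backprop through time over windows of `unroll` steps.

    inner_step(meta_net, params) must produce the next params via meta_net so
    that gradients from the outer loss can reach the meta-optimizer's weights.
    """
    unroll_losses = 0
    for it in range(1, n_iters + 1):
        params = inner_step(meta_net, params)       # differentiable inner update
        unroll_losses = unroll_losses + outer_loss(params)
        if it % unroll == 0:
            meta_optim.zero_grad()
            unroll_losses.backward()                # backprop through the window
            nn.utils.clip_grad_value_(meta_net.parameters(), 0.01)
            meta_optim.step()
            unroll_losses = 0
            params = params.detach()                # cut the graph between windows
    return params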
Example #5
    def meta_optimize(self,
                      meta_optim,
                      data,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      k_obsrv=None,
                      no_mask=None,
                      writer=None,
                      mode='train'):
        """FIX LATER: do something about the dummy arguments."""
        assert mode in ['train', 'valid', 'test']
        self.set_mode(mode)
        use_indexer = False

        params = C(model_cls()).params
        states = C(
            OptimizerStates.initial_zeros(size=(self.n_layers, len(params),
                                                self.hidden_sz),
                                          rnn_cell=self.rnn_cell))

        result_dict = ResultDict()
        unroll_losses = 0
        walltime = Walltime()
        update = C(torch.zeros(params.size().flat()))

        if use_indexer:
            batch_arg = BatchManagerArgument(
                params=params,
                states=StatesSlicingWrapper(states),
                updates=update,
            )

        iter_pbar = tqdm(range(1, optim_it + 1), 'inner_train')
        for iter in iter_pbar:

            with WalltimeChecker(walltime):
                model_train = C(model_cls(params=params.detach()))
                train_nll, train_acc = model_train(*data['in_train'].load())
                train_nll.backward()
                if model_train.params.flat.grad is not None:
                    g = model_train.params.flat.grad
                else:
                    g = model_train._grad2params(model_train.params)
                assert g is not None

                if use_indexer:
                    # This indexer was originally meant for outer-level batch splitting,
                    #  in case the whole parameter set does not fit in VRAM.
                    # Avoid it unless necessary, since the serial computation
                    #  adds extra time cost.
                    batch_arg.update(
                        BatchManagerArgument(grad=g).immutable().volatile())
                    indexer = DefaultIndexer(input=g, shuffle=False)
                    batch_manager = OptimizerBatchManager(batch_arg=batch_arg,
                                                          indexer=indexer)
                    ######################################################################
                    for get, set in batch_manager.iterator():
                        updates, new_states = self(get.grad().detach(),
                                                   get.states())
                        set.states(new_states)
                        set.params(get.params() + updates * out_mul)
                        if mode != 'train':
                            set.updates(updates)
                    ######################################################################
                    batch_arg = batch_manager.batch_arg_to_set
                    params = batch_arg.params
                else:
                    updates, states = self(g.detach(), states)
                    updates = params.new_from_flat(updates)
                    params = params + updates * out_mul
                    #params = params - params.new_from_flat(g) * 0.1

            with WalltimeChecker(walltime if mode == 'train' else None):
                model_test = C(model_cls(params=params))
                test_nll, test_acc = model_test(*data['in_test'].load())
                if mode == 'train':
                    unroll_losses += test_nll
                    if iter % unroll == 0:
                        meta_optim.zero_grad()
                        unroll_losses.backward()
                        nn.utils.clip_grad_value_(self.parameters(), 0.01)
                        meta_optim.step()
                        unroll_losses = 0

            with WalltimeChecker(walltime):
                if mode != 'train' or iter % unroll == 0:
                    if use_indexer:
                        batch_arg.detach_()
                    else:
                        params.detach_()
                        states.detach_()
                # model.params = model.params.detach()
                # states = states.detach()
            result = dict(
                train_nll=train_nll.tolist(),
                test_nll=test_nll.tolist(),
                train_acc=train_acc.tolist(),
                test_acc=test_acc.tolist(),
                walltime=walltime.time,
            )
            result_dict.append(result)
            log_pbar(result, iter_pbar)

        return result_dict, params
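
Example #5 treats `self` as a coordinate-wise learned optimizer: given the flat gradient and one recurrent state per parameter coordinate, it returns a per-coordinate update (scaled by out_mul in the caller) plus the new states. The sketch below shows that kind of forward pass with a single LSTM cell; it is one common learning-to-learn formulation and not necessarily this module's exact architecture (which stacks n_layers and wraps the states in OptimizerStates).

import torch.nn as nn


class CoordwiseOptimizer(nn.Module):
    """Sketch: per-coordinate LSTM mapping a gradient to a parameter update."""

    def __init__(self, hidden_sz=20):
        super().__init__()
        self.rnn_cell = nn.LSTMCell(1, hidden_sz)
        self.out = nn.Linear(hidden_sz, 1)

    def forward(self, grad_flat, states):
        # grad_flat: [n_params]; states: (h, c), each of shape [n_params, hidden_sz]
        h, c = self.rnn_cell(grad_flat.unsqueeze(-1), states)
        update = self.out(h).squeeze(-1)  # one update per coordinate
        return update, (h, c)

In the loop above this corresponds to `updates, states = self(g.detach(), states)` followed by `params = params + params.new_from_flat(updates) * out_mul`.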