Example #1
0
    def meta_optimize(self,
                      meta_optimizer,
                      data_cls,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      should_train=True):
        """Optimize a freshly built optimizee with this learned optimizer.

        Each iteration samples a batch and computes the optimizee loss. It then
        takes the gradient of a detached copy of the current parameters and
        feeds that gradient through this optimizer (batch-wise via
        ``OptimizerBatchManager``). Parameters move to
        ``params + updates * out_mul``. When ``should_train`` is true, the
        optimizee losses summed over the last ``unroll`` iterations are
        back-propagated into this optimizer and ``meta_optimizer`` steps.

        Args:
          meta_optimizer: optimizer over this module's parameters; only used
            when ``should_train`` is true.
          data_cls: callable ``data_cls(training=...)`` whose result has
            ``sample()``.
          model_cls: optimizee class; ``model_cls()`` builds fresh parameters,
            ``model_cls(params=...)`` rebuilds around given ones.
          optim_it: number of inner optimization iterations.
          unroll: truncation length for meta back-propagation.
          out_mul: scale applied to the optimizer's output updates.
          should_train: train (``self.train()``) or evaluate (``self.eval()``).

        Returns:
          ``ResultDict`` with the per-iteration loss, plus walltime and
          ``params_tracker`` statistics when not training.
        """
        data = data_cls(training=should_train)
        model = C(model_cls())
        optimizer_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))

        if should_train:
            self.train()
        else:
            self.eval()

        result_dict = ResultDict()
        unroll_losses = 0
        # ``size().flat`` is a method (it is called as ``flat()`` everywhere
        # else this expression appears); passing the bound method itself to
        # ``torch.zeros`` would raise a TypeError.
        update = C(torch.zeros(model.params.size().flat()))
        iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
        iter_watch = utils.StopWatch('optim_iteration')

        batch_arg = BatchManagerArgument(
            params=model.params,
            states=StatesSlicingWrapper(optimizer_states),
            updates=update,
        )

        for iteration in iter_pbar:
            data_ = data.sample()
            # Graph-connected loss of the current optimizee; summed over the
            # unroll window and back-propagated into this optimizer.
            loss = model(*data_)
            unroll_losses += loss
            # Detached copy: its parameter gradient becomes the optimizer's
            # input without being part of the meta-graph.
            model_dt = C(model_cls(params=batch_arg.params.detach()))
            model_dt_loss = model_dt(*data_)
            assert loss == model_dt_loss
            model_dt_loss.backward()
            assert model_dt.params.flat.grad is not None
            grad = model_dt.params.flat.grad
            batch_arg.update(
                BatchManagerArgument(grad=grad).immutable().volatile())
            iter_pbar.set_description(f'optim_iteration[loss:{loss}]')

            batch_manager = OptimizerBatchManager(batch_arg=batch_arg)
            for get, set in batch_manager.iterator():
                updates, new_states = self(get.grad.detach(), get.states)
                set.states(new_states)
                set.params(get.params + updates * out_mul)
                if not should_train:
                    # Updates are only recorded for the evaluation tracker.
                    set.updates(updates)
            batch_arg = batch_manager.batch_arg_to_set

            if should_train and iteration % unroll == 0:
                meta_optimizer.zero_grad()
                unroll_losses.backward()
                meta_optimizer.step()

            if not should_train or iteration % unroll == 0:
                # Truncate back-propagation: start a new unroll window.
                unroll_losses = 0
                batch_arg.detach_()
            # Rebuild the optimizee around the newly updated parameters.
            model = C(model_cls(params=batch_arg.params))
            result_dict.append(loss=loss)
            if not should_train:
                result_dict.append(walltime=iter_watch.touch(),
                                   **self.params_tracker(
                                       grad=grad,
                                       update=batch_arg.updates,
                                   ))
        return result_dict
    def meta_optimize(self,
                      meta_optimizer,
                      data_cls,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      should_train=True):
        """Run a gradient-predicting optimizer on a fresh optimizee.

        ``self`` takes three inputs: the previous gradient (real under teacher
        forcing, otherwise its own previous prediction), the previous update
        and the current parameters. It outputs a predicted gradient, and
        parameters move by ``-1.0 * grad_pred``. The prediction is trained
        against the real gradient with two terms, each scaled by 100: an MSE
        term and a cosine term ``-log((cos + 1) / 2)``. While training, an L2
        weight-decay term on this module is added at each unroll boundary.

        Args:
          meta_optimizer: optimizer over this module's parameters.
          data_cls: callable ``data_cls(training=...)`` whose result has
            ``sample()``.
          model_cls: optimizee class.
          optim_it: number of inner iterations.
          unroll: meta back-propagation truncation length.
          out_mul: unused in this variant.
          should_train: train (``self.train()``) or evaluate (``self.eval()``).

        Returns:
          ``ResultDict`` with step number, optimizee loss and gradient loss per
          iteration; when evaluating, also walltime and tracked tensors.
        """
        data = data_cls(training=should_train)
        model = C(model_cls())
        g_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))

        if should_train:
            self.train()
        else:
            self.eval()

        result_dict = ResultDict()
        grad_losses = 0
        # Lazy teacher forcing: while training, feed the real gradient only
        # every ``self.k`` steps instead of on every step.
        train_with_lazy_tf = True
        # Also compute gradient losses while evaluating.
        eval_test_grad_loss = True
        grad_prev = C(torch.zeros(model.params.size().flat()))
        # NOTE(review): ``update_prev`` is never written back in this variant,
        # so the RNN always receives zeros for it -- confirm this is intended.
        update_prev = C(torch.zeros(model.params.size().flat()))
        # grad_real_prev: teacher-forcing input of the RNN; required on every
        #   training step and on every k-th test step.
        # grad_real_cur: MSE target for the RNN's current output; required on
        #   every training step.
        grad_real_cur = None
        grad_real_prev = C(torch.zeros(model.params.size().flat()))
        iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
        iter_watch = utils.StopWatch('optim_iteration')
        batch_arg = BatchManagerArgument(
            params=model.params,
            grad_prev=grad_prev,
            update_prev=update_prev,
            g_states=StatesSlicingWrapper(g_states),
            mse_loss=C(torch.zeros(1)),
            sigma=C(torch.zeros(model.params.size().flat())),
        )

        for iteration in iter_pbar:
            data_ = data.sample()
            if (iteration +
                    1) % self.k == 0 or should_train or eval_test_grad_loss:
                # Get the real gradient every step while training, and one
                # step before each teacher-forcing step while testing.
                # It is taken at the *current* parameters held by
                # ``batch_arg``: ``model.params`` is the initial point and is
                # never updated inside this loop.
                model_dt = C(model_cls(params=batch_arg.params.detach()))
                model_loss_dt = model_dt(*data_)
                model_loss_dt.backward()
                grad_real_cur = model_dt.params.flat.grad
                assert grad_real_cur is not None

            if should_train or eval_test_grad_loss:
                # MSE loss target for the RNN's current output.
                assert grad_real_cur is not None
                batch_arg.update(
                    BatchManagerArgument(
                        grad_real_cur=grad_real_cur).immutable())

            if iteration % self.k == 0 or (should_train
                                           and not train_with_lazy_tf):
                # Teacher forcing: expose the real previous gradient.
                assert grad_real_prev is not None
                batch_arg.update(
                    BatchManagerArgument(
                        grad_real_prev=grad_real_prev).immutable().volatile())
                grad_real_prev = None

            batch_manager = OptimizerBatchManager(batch_arg=batch_arg)
            for get, set in batch_manager.iterator():
                if iteration % self.k == 0 or (should_train
                                               and not train_with_lazy_tf):
                    grad_prev = get.grad_real_prev  # teacher-forcing
                else:
                    grad_prev = get.grad_prev  # free-running
                grad_cur, g_states = self(grad_prev.detach(),
                                          get.update_prev.detach(),
                                          get.params.detach(), get.g_states)
                set.grad_prev(grad_cur)
                set.g_states(g_states)
                # Plain gradient step using the *predicted* gradient.
                set.params(get.params - 1.0 * grad_cur.detach())
                set.sigma(self.sigma)
                if should_train or eval_test_grad_loss:
                    set.mse_loss(
                        self.mse_loss(grad_cur, get.grad_real_cur.detach()))

            batch_arg = batch_manager.batch_arg_to_set
            grad_real = batch_manager.batch_arg_to_get.grad_real_cur.detach()
            cosim_loss = nn.functional.cosine_similarity(batch_arg.grad_prev,
                                                         grad_real,
                                                         dim=0)
            grad_loss = batch_arg.mse_loss * 100
            # Maps cosine similarity in [-1, 1] to a loss in [0, inf): zero
            # when aligned, growing without bound as it approaches -1.
            cosim_loss = -torch.log((cosim_loss + 1) * 0.5) * 100
            grad_losses += (grad_loss + cosim_loss)
            sigma = batch_arg.sigma.detach().mean()
            iter_pbar.set_description(
                'optim_iteration[loss(model):{:10.6f}/(mse):{:10.6f}/(cosim):{:10.6f}/(sigma):{:10.6f}]'
                ''.format(model_loss_dt.tolist(),
                          grad_loss.tolist()[0],
                          cosim_loss.tolist()[0], sigma.tolist()))

            if should_train and iteration % unroll == 0:
                # L2 weight decay (coefficient 1e-4) on this optimizer.
                w_decay_loss = 0
                for p in self.parameters():
                    w_decay_loss += torch.sum(p * p)
                w_decay_loss = (w_decay_loss * 0.5) * 1e-4
                total_loss = grad_losses + w_decay_loss
                meta_optimizer.zero_grad()
                total_loss.backward()
                meta_optimizer.step()

            if not should_train or iteration % unroll == 0:
                # Truncate back-propagation at the unroll boundary.
                grad_losses = 0
                batch_arg.detach_()

            result_dict.append(
                step_num=iteration,
                loss=model_loss_dt,
                grad_loss=grad_loss,
            )

            if not should_train:
                result_dict.append(walltime=iter_watch.touch(),
                                   **self.params_tracker(
                                       grad_real=grad_real_cur,
                                       grad_pred=batch_arg.grad_prev,
                                       update=batch_arg.update_prev,
                                       sigma=batch_arg.sigma,
                                   ))

            if (iteration +
                    1) % self.k == 0 or should_train or eval_test_grad_loss:
                # The current real gradient becomes the next teacher-forcing
                # input.
                grad_real_prev = grad_real_cur
                grad_real_cur = None

        return result_dict
    def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll,
                      out_mul, mode):
        """Compare masked ("sparse") SGD update candidates at every step.

        Each iteration computes the full gradient at the current parameters.
        It then builds ``n_sample`` masked SGD candidates. The first
        ``n_sample - 1`` masks come from ``MaskGenerator.randk`` and the last
        from ``MaskGenerator.topk``; both are applied to the activation
        gradients. Two losses are evaluated per candidate:

        * the "sparse" loss of the pruned network after the masked step;
        * the "dense" loss of the full network after
          ``params.sparse_update``.

        The dense candidate with the lowest sparse loss becomes the next
        parameters.

        Args:
          meta_optimizer, unroll, out_mul: unused by this experiment.
          data: object with a ``loaders`` mapping and ``pseudo_sample()``.
          model_cls: optimizee class, called with ``sb_mode=`` or ``params=``.
          optim_it: number of iterations.
          mode: one of ``'train'``, ``'valid'``, ``'test'``.

        Returns:
          ``ResultDict`` with the full-gradient loss per iteration and, when
          ``mode != 'train'``, the accumulated walltime.
        """
        assert mode in ['train', 'valid', 'test']
        if mode == 'train':
            self.train()
            inner_data = data.loaders['inner_train']
        elif mode == 'valid':
            self.eval()
            inner_data = data.loaders['valid']
        elif mode == 'test':
            self.eval()
            inner_data = data.loaders['test']

        model = C(model_cls(sb_mode=self.sb_mode))
        # One forward pass on a pseudo batch before the loop (kept as-is;
        # presumably initializes lazily built model state -- TODO confirm).
        model(*data.pseudo_sample())
        params = model.params

        # How gradients for a pruned candidate are obtained:
        #   1 -> recompute them on the pruned network (lr 1.0);
        #   2 -> prune the full-network gradient with the same mask (lr 0.2).
        # Named ``grad_mode`` so it does not shadow the ``mode`` argument.
        grad_mode = 2
        n_sample = 10
        topk_id = n_sample - 1  # the last candidate uses the top-k mask
        if grad_mode == 1:
            lr = 1.0
        elif grad_mode == 2:
            lr = 0.2
        else:
            raise ValueError(f'unknown grad_mode: {grad_mode}')

        result_dict = ResultDict()
        walltime = 0
        iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
        iter_watch = utils.StopWatch('optim_iteration')

        for iteration in iter_pbar:
            iter_watch.touch()

            model_detached = C(model_cls(params=params.detach()))
            loss_detached = model_detached(*inner_data.load())
            loss_detached.backward()
            act_grad = model_detached.activations.grad.detach()

            sparse_loss = []
            dense_loss = []
            candidates = []
            for i in range(n_sample):
                if i == topk_id:
                    mask = MaskGenerator.topk(act_grad)
                else:
                    mask = MaskGenerator.randk(act_grad)

                sparse_params = model_detached.params.prune(mask)
                if grad_mode == 1:
                    # Recompute gradients on the pruned network.
                    model_pruned = C(model_cls(params=sparse_params.detach()))
                    model_pruned(*inner_data.load()).backward()
                    grads = model_pruned.params.grad
                else:
                    # Full-network gradients, sparsified by the same mask.
                    grads = model_detached.params.grad.prune(mask)
                step = grads * (-lr)

                # Sparse candidate: pruned network after the masked step.
                model_pruned = C(
                    model_cls(params=(sparse_params + step).detach()))
                loss_pruned = model_pruned(*inner_data.load())

                # Dense candidate: full parameters with the masked step.
                candidate = model_detached.params.sparse_update(mask, step)
                candidates.append(candidate)
                model_dense = C(model_cls(params=candidate))
                loss_dense = model_dense(*inner_data.load())

                sparse_loss.append(loss_pruned)
                dense_loss.append(loss_dense)

            # argmin yields a single index even when several losses tie
            # (``(min == x).nonzero()`` would yield a list and break indexing).
            best = int(torch.stack(sparse_loss, 0).argmin())
            params = candidates[best]

            iter_pbar.set_description(
                f'optim_iteration[dense_loss:{loss_dense.tolist():5.5}/sparse_loss:{loss_pruned.tolist():5.5}]'
            )

            # ``touch()`` at the top of the loop starts the interval; close it
            # here so the reported walltime actually accumulates.
            walltime += iter_watch.touch('interval')
            result_dict.append(loss=loss_detached)
            if mode != 'train':
                result_dict.append(walltime=walltime)

        return result_dict
Example #4
0
    def meta_optimize(self,
                      meta_optimizer,
                      data_cls,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      should_train=True):
        """Optimize a fresh optimizee while predicting its gradients.

        Each iteration, ``self`` takes two inputs: the previous gradient (real
        under teacher forcing, otherwise its own previous prediction) and the
        previous update. It returns a predicted current gradient ``grad_cur``
        and an update ``update_cur``. Parameters move by
        ``update_cur * out_mul``. ``grad_cur`` is regressed onto the real
        gradient with ``self.mse_loss`` (scaled by 100). When
        ``should_train`` is true, the summed optimizee losses plus gradient
        losses are back-propagated every ``unroll`` iterations and
        ``meta_optimizer`` steps.

        Args:
          meta_optimizer: optimizer over this module's parameters.
          data_cls: callable ``data_cls(training=...)`` whose result has
            ``sample()`` returning ``(inp, out)``.
          model_cls: optimizee class.
          optim_it: number of inner iterations.
          unroll: meta back-propagation truncation length.
          out_mul: scale applied to the predicted update.
          should_train: train (``self.train()``) or evaluate (``self.eval()``).

        Returns:
          ``ResultDict`` with step number, optimizee loss and gradient loss per
          iteration; when evaluating, also walltime and tracked
          gradient/update entries.
        """
        data = data_cls(training=should_train)
        model = C(model_cls())
        g_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))
        u_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))

        if should_train:
            self.train()
        else:
            self.eval()
            # Trackers exist (and are used) only when evaluating.
            grad_pred_tracker = ParamsIndexTracker(name='grad_pred',
                                                   n_tracks=10,
                                                   n_params=len(model.params))
            grad_real_tracker = grad_pred_tracker.clone_new_name('grad_real')
            update_tracker = grad_pred_tracker.clone_new_name('update')

        result_dict = ResultDict()
        model_dt = None
        model_losses = 0
        grad_losses = 0
        # False: while training, teacher-force on every step rather than only
        # on every ``self.k``-th step.
        train_with_lazy_tf = False
        # Also compute the gradient MSE loss while evaluating.
        eval_test_grad_loss = True
        grad_prev = C(torch.zeros(model.params.size().flat()))
        update_prev = C(torch.zeros(model.params.size().flat()))
        #grad_real_prev = None
        grad_real_cur = None
        grad_real_prev = C(torch.zeros(model.params.size().flat()))
        iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
        iter_watch = utils.StopWatch('optim_iteration')
        """grad_real_prev:
        teacher-forcing current input of RNN.
        required during all of training steps + every 'k' test steps.
       grad_real_cur:
        MSE loss target for current output of RNN.
        required during all the training steps only.
    """
        batch_arg = BatchManagerArgument(
            params=model.params,
            grad_prev=grad_prev,
            update_prev=update_prev,
            g_states=StatesSlicingWrapper(g_states),
            u_states=StatesSlicingWrapper(u_states),
            mse_loss=C(torch.zeros(1)),
            #updates=OptimizerBatchManager.placeholder(model.params.flat.size()),
        )

        for iteration in iter_pbar:
            inp, out = data.sample()
            # Graph-connected optimizee loss; part of the meta objective.
            model_loss = model(inp, out)
            model_losses += model_loss
            # if grad_real_cur is not None: # if iteration % self.k == 0
            #   grad_cur_prev = grad_real_cur # back up previous grad
            if (iteration +
                    1) % self.k == 0 or should_train or eval_test_grad_loss:
                # get grad every step while training
                # get grad one step earlier than teacher-forcing steps while testing
                model_dt = C(model_cls(params=model.params.detach()))
                model_loss_dt = model_dt(inp, out)
                # model.params.register_grad_cut_hook_()
                model_loss_dt.backward()
                # model.params.remove_grad_cut_hook_()
                grad_real_cur = model_dt.params.flat.grad
                assert grad_real_cur is not None
                # Drop references to the probe model and its loss.
                del model_dt, model_loss_dt

            if should_train or eval_test_grad_loss:
                # mse loss target
                # grad_real_cur is not used as a mse target while testing
                assert grad_real_cur is not None
                batch_arg.update(
                    BatchManagerArgument(
                        grad_real_cur=grad_real_cur).immutable().volatile())

            if iteration % self.k == 0 or (should_train
                                           and not train_with_lazy_tf):
                # teacher-forcing
                assert grad_real_prev is not None
                batch_arg.update(
                    BatchManagerArgument(
                        grad_real_prev=grad_real_prev).immutable().volatile())
                grad_real_prev = None

            ########################################################################
            batch_manager = OptimizerBatchManager(batch_arg=batch_arg)

            for get, set in batch_manager.iterator():
                if iteration % self.k == 0 or (should_train
                                               and not train_with_lazy_tf):
                    grad_prev = get.grad_real_prev.detach()  # teacher-forcing
                else:
                    grad_prev = get.grad_prev.detach()  # free-running
                update_prev = get.update_prev.detach()
                grad_cur, update_cur, g_states, u_states = self(
                    grad_prev, update_prev, get.g_states, get.u_states)
                set.grad_prev(grad_cur)
                set.update_prev(update_cur)
                set.g_states(g_states)
                set.u_states(u_states)
                set.params(get.params + update_cur * out_mul)
                if should_train or eval_test_grad_loss:
                    set.mse_loss(
                        self.mse_loss(grad_cur, get.grad_real_cur.detach()))
            ##########################################################################

            batch_arg = batch_manager.batch_arg_to_set
            grad_loss = batch_arg.mse_loss * 100
            grad_losses += grad_loss
            #to_float = lambda x: float(x.data.cpu().numpy())
            iter_pbar.set_description(
                'optim_iteration[loss(model):{:10.6f}/(grad):{:10.6f}]'.format(
                    model_loss.tolist(),
                    grad_loss.tolist()[0]))

            if should_train and iteration % unroll == 0:
                # Meta objective: optimizee losses + scaled gradient losses.
                meta_optimizer.zero_grad()
                losses = model_losses + grad_losses
                losses.backward()
                meta_optimizer.step()

            if not should_train or iteration % unroll == 0:
                # Truncate back-propagation at the unroll boundary.
                model_losses = 0
                grad_losses = 0
                batch_arg.detach_()
                # batch_arg.params.detach_()
                # batch_arg.g_states.detach_()
                # batch_arg.u_states.detach_()

            # Rebuild the optimizee around the updated parameters.
            model = C(model_cls(params=batch_arg.params))
            result_dict.append(
                step_num=iteration,
                loss=model_loss,
                grad_loss=grad_loss,
            )

            if not should_train:
                result_dict.append(
                    walltime=iter_watch.touch(),
                    **grad_pred_tracker.track_num,
                    **grad_pred_tracker(batch_arg.grad_prev),
                    **grad_real_tracker(grad_real_cur),
                    **update_tracker(batch_arg.update_prev),
                )

            if (iteration +
                    1) % self.k == 0 or should_train or eval_test_grad_loss:
                # The current real gradient becomes the next teacher-forcing
                # input.
                grad_real_prev = grad_real_cur
                grad_real_cur = None

        return result_dict
Example #5
0
    def meta_optimize(self,
                      meta_optim,
                      data,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      k_obsrv=None,
                      no_mask=None,
                      writer=None,
                      mode='train'):
        """Optimize a fresh optimizee with this learned optimizer.

        Each iteration has two phases:

        1. One optimizer step on a batch from ``data['in_train']``.
        2. Evaluation of the updated parameters on a batch from
           ``data['in_test']``.

        In ``'train'`` mode the test NLLs are summed and, every ``unroll``
        iterations, back-propagated into this optimizer. Gradients are
        clipped by value to 0.01 before ``meta_optim.step()``.

        Args:
          meta_optim: optimizer over this module's parameters ('train' mode).
          data: mapping with ``'in_train'`` / ``'in_test'`` loaders exposing
            ``load()``.
          model_cls: optimizee class.
          optim_it: number of inner iterations.
          unroll: meta back-propagation truncation length.
          out_mul: scale applied to the optimizer's output updates.
          k_obsrv, no_mask, writer: dummy arguments, unused here
            (TODO: remove or implement).
          mode: one of ``'train'``, ``'valid'``, ``'test'``.

        Returns:
          ``(result_dict, params)``: per-iteration train/test NLL, accuracy and
          walltime, and the final optimizee parameters.
        """
        assert mode in ['train', 'valid', 'test']
        self.set_mode(mode)
        use_indexer = False

        params = C(model_cls()).params
        states = C(
            OptimizerStates.initial_zeros(size=(self.n_layers, len(params),
                                                self.hidden_sz),
                                          rnn_cell=self.rnn_cell))

        result_dict = ResultDict()
        unroll_losses = 0
        walltime = Walltime()

        if use_indexer:
            # The zero ``updates`` buffer is only consumed by the batch manager.
            update = C(torch.zeros(params.size().flat()))
            batch_arg = BatchManagerArgument(
                params=params,
                states=StatesSlicingWrapper(states),
                updates=update,
            )

        iter_pbar = tqdm(range(1, optim_it + 1), 'inner_train')
        # ``iteration`` (not ``iter``) so the builtin is not shadowed.
        for iteration in iter_pbar:

            with WalltimeChecker(walltime):
                model_train = C(model_cls(params=params.detach()))
                train_nll, train_acc = model_train(*data['in_train'].load())
                train_nll.backward()
                # Fall back to ``_grad2params`` when the flat gradient is not
                # populated.
                if model_train.params.flat.grad is not None:
                    g = model_train.params.flat.grad
                else:
                    g = model_train._grad2params(model_train.params)
                assert g is not None

                if use_indexer:
                    # This indexer was originally for outer-level batch split
                    # in case the whole parameter set does not fit in VRAM.
                    # Avoid it unless unavoidable: serial computation adds
                    # time cost.
                    batch_arg.update(
                        BatchManagerArgument(grad=g).immutable().volatile())
                    indexer = DefaultIndexer(input=g, shuffle=False)
                    batch_manager = OptimizerBatchManager(batch_arg=batch_arg,
                                                          indexer=indexer)
                    for getter, setter in batch_manager.iterator():
                        updates, new_states = self(getter.grad().detach(),
                                                   getter.states())
                        setter.states(new_states)
                        setter.params(getter.params() + updates * out_mul)
                        if mode != 'train':
                            setter.updates(updates)
                    batch_arg = batch_manager.batch_arg_to_set
                    params = batch_arg.params
                else:
                    updates, states = self(g.detach(), states)
                    updates = params.new_from_flat(updates)
                    params = params + updates * out_mul

            # Presumably a ``None`` checker does not count time, so the test
            # pass is only timed in 'train' mode -- TODO confirm.
            with WalltimeChecker(walltime if mode == 'train' else None):
                model_test = C(model_cls(params=params))
                test_nll, test_acc = model_test(*data['in_test'].load())
                if mode == 'train':
                    unroll_losses += test_nll
                    if iteration % unroll == 0:
                        meta_optim.zero_grad()
                        unroll_losses.backward()
                        nn.utils.clip_grad_value_(self.parameters(), 0.01)
                        meta_optim.step()
                        unroll_losses = 0

            with WalltimeChecker(walltime):
                # Truncate back-propagation at the unroll boundary (and on
                # every step when not training).
                if mode != 'train' or iteration % unroll == 0:
                    if use_indexer:
                        batch_arg.detach_()
                    else:
                        params.detach_()
                        states.detach_()
            result = dict(
                train_nll=train_nll.tolist(),
                test_nll=test_nll.tolist(),
                train_acc=train_acc.tolist(),
                test_acc=test_acc.tolist(),
                walltime=walltime.time,
            )
            result_dict.append(result)
            log_pbar(result, iter_pbar)

        return result_dict, params
Example #6
0
  def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll,
                    out_mul, should_train=True):
    """Drive a fresh optimizee with this optimizer over ``optim_it`` steps.

    Per step: sample a batch and take the gradient of a detached copy of the
    current parameters. This optimizer then updates the parameters through an
    ``OptimizerBatchManager`` indexed by ``TopKIndexer(k_ratio=0.1)``
    (presumably touching only the selected coordinates -- TODO confirm).
    While training, the graph-connected losses are summed. Every ``unroll``
    steps they are back-propagated into this optimizer (grad-norm clipped to
    10) before ``meta_optimizer.step()``.

    Returns a ``ResultDict`` of per-step losses; when not training it also
    records walltime and ``params_tracker`` statistics.
    """
    data = data_cls(training=should_train)
    model = C(model_cls())
    rnn_states = C(OptimizerStates.initial_zeros(
        size=(self.n_layers, len(model.params), self.hidden_sz),
        rnn_cell=self.rnn_cell))

    (self.train if should_train else self.eval)()

    results = ResultDict()
    elapsed = 0
    unrolled_loss = 0
    zero_update = C(torch.zeros(model.params.size().flat()))

    arg = BatchManagerArgument(
      params=model.params,
      states=StatesSlicingWrapper(rnn_states),
      updates=zero_update,
    )

    pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    watch = utils.StopWatch('optim_iteration')
    for step_idx in pbar:
      watch.touch()
      batch = data.sample()

      if should_train:
        # Graph-connected loss: its gradient reaches this optimizer.
        live_loss = C(model_cls(params=arg.params))(*batch)
        unrolled_loss += live_loss

      # Detached probe: its parameter gradient is the optimizer's input.
      probe = C(model_cls(params=arg.params.detach()))
      probe_loss = probe(*batch)
      if should_train:
        assert live_loss == probe_loss
      probe_loss.backward()
      assert probe.params.flat.grad is not None
      flat_grad = probe.params.flat.grad
      arg.update(BatchManagerArgument(grad=flat_grad).immutable().volatile())
      pbar.set_description(f'optim_iteration[loss:{probe_loss}]')

      manager = OptimizerBatchManager(
        batch_arg=arg,
        indexer=TopKIndexer(input=flat_grad, k_ratio=0.1, random_k_mode=False))
      for read, write in manager.iterator():
        delta, next_states = self(read.grad().detach(), read.states())
        delta = delta * out_mul
        write.states(next_states)
        write.params(read.params() + delta)
        if not should_train:
          write.updates(delta)
      arg = manager.batch_arg_to_set

      if should_train:
        # Meta-update and truncate only at unroll boundaries.
        if step_idx % unroll == 0:
          meta_optimizer.zero_grad()
          unrolled_loss.backward()
          nn.utils.clip_grad_norm_(self.parameters(), 10)
          meta_optimizer.step()
          unrolled_loss = 0
          arg.detach_()
      else:
        # Evaluation never back-propagates: detach on every step.
        arg.detach_()

      elapsed += watch.touch('interval')
      results.append(loss=probe_loss)
      if not should_train:
        results.append(
          walltime=watch.touch(),
          **self.params_tracker(grad=flat_grad, update=arg.updates),
        )
    return results
    def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll,
                      out_mul, mode):
        """Run the mask-gated learned optimizer on a fresh optimizee.

        Each iteration:
          1. computes the inner gradient on a copy of the model built from
             detached parameters;
          2. asks ``self.mask_generator`` for a backprop mask from the
             activation gradients;
          3. applies ``self.updater`` to the selected coordinates:
             - 'soft_drop' (train): the mask multiplies both the updates and
               the state blend;
             - 'hard_drop' (valid/test): only coordinates whose mask value
               exceeds 0.5 are visited.

        In 'train' mode the meta-loss is back-propagated into the optimizer
        every ``unroll`` iterations. The meta-loss is the outer loss on
        ``inner_valid`` plus ``0.01 * loss_decay * sparsity_loss``.

        Args:
            meta_optimizer: optimizer stepped on the meta-loss (train mode
                only); presumably over ``self.parameters()``, which are the
                ones gradient-clipped here.
            data: object with ``pseudo_sample()`` and a ``loaders`` dict
                holding 'inner_train' and 'inner_valid' (train), 'valid',
                and 'test'; each loader exposes ``load()``.
            model_cls: optimizee factory accepting ``sb_mode`` and an
                optional ``params`` keyword.
            optim_it: number of inner optimization iterations.
            unroll: truncation length (in iterations) of the meta-graph.
            out_mul: scale applied to the updater's output.
            mode: one of 'train', 'valid', 'test'.

        Returns:
            ResultDict with the inner ``loss`` of every iteration, plus the
            cumulative ``walltime`` outside of train mode.

        Raises:
            ValueError: if ``mode`` is not recognised (reached only when
                asserts are disabled with ``-O``).
        """
        assert mode in ['train', 'valid', 'test']
        if mode == 'train':
            self.train()
            inner_data = data.loaders['inner_train']
            outer_data = data.loaders['inner_valid']
            drop_mode = 'soft_drop'
        elif mode == 'valid':
            self.eval()
            inner_data = data.loaders['valid']
            outer_data = None
            drop_mode = 'hard_drop'
        elif mode == 'test':
            # Previously spelled 'eval': that branch was unreachable and
            # mode='test' left inner_data/drop_mode unbound.
            self.eval()
            inner_data = data.loaders['test']
            outer_data = None
            drop_mode = 'hard_drop'
        else:
            raise ValueError(f'Unknown mode: {mode!r}')

        model = C(model_cls(sb_mode=self.sb_mode))
        # Forward pass on a placeholder batch; the mask-state shape below reads
        # model.activations, which presumably this call populates.
        model(*data.pseudo_sample())
        mask_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.activations.mean(0)),
                 self.hidden_sz)))
        update_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))

        result_dict = ResultDict()
        unroll_losses = 0
        walltime = 0

        batch_arg = BatchManagerArgument(
            params=model.params,
            states=StatesSlicingWrapper(update_states),
            updates=model.params.new_zeros(model.params.size()),
        )

        iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
        iter_watch = utils.StopWatch('optim_iteration')
        loss_decay = 1.0

        for iteration in iter_pbar:
            iter_watch.touch()
            # Inner gradient on detached parameters so this backward pass does
            # not reach the meta-graph hanging off batch_arg.params.
            model_detached = C(
                model_cls(params=batch_arg.params.detach(),
                          sb_mode=self.sb_mode))
            loss_detached = model_detached(*inner_data.load())
            loss_detached.backward()
            iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]')

            # Verbose mask-generator output on the first and every 10th
            # iteration, only while the debug signal is on.
            if debug_sigint.signal_on:
                debug = iteration == 1 or iteration % 10 == 0
            else:
                debug = False

            backprop_mask, mask_states, sparsity_loss = self.mask_generator(
                activations=model_detached.activations.grad.detach(),
                debug=debug,
                states=mask_states)
            assert model_detached.params.grad is not None

            if mode == 'train':
                # unroll_losses is only back-propagated in train mode.
                # Accumulating it elsewhere may keep every iteration's
                # mask-generator graph alive for no benefit.
                unroll_losses += sparsity_loss * 0.01 * loss_decay
                model = C(
                    model_cls(params=batch_arg.params, sb_mode=self.sb_mode))
                loss = model(*outer_data.load())
                if iteration == 1:
                    initial_loss = loss
                    loss_decay = 1.0
                else:
                    # Scale the sparsity term by relative progress of the
                    # outer loss (no gradient through the ratio).
                    loss_decay = (loss / initial_loss).detach()
                unroll_losses += loss

            coord_mask = backprop_mask.expand_as_params(model_detached.params)
            if drop_mode == 'soft_drop':
                index = None  # visit every coordinate; the mask gates below
            elif drop_mode == 'hard_drop':
                index = (coord_mask.flat.squeeze() > 0.5).nonzero().squeeze()
            grad = model_detached.params.grad.flat
            batch_arg.update(BatchManagerArgument(grad=grad, mask=coord_mask))

            indexer = DefaultIndexer(input=grad, index=index, shuffle=False)
            batch_manager = OptimizerBatchManager(batch_arg=batch_arg,
                                                  indexer=indexer)
            for getter, setter in batch_manager.iterator():
                inp = getter.grad().detach()
                if drop_mode == 'soft_drop':
                    updates, new_states = self.updater(inp, getter.states())
                    mask = getter.mask()
                    updates = updates * out_mul * mask
                    # Masked-out coordinates keep their previous state.
                    setter.states(getter.states() * (1 - mask) +
                                  new_states * mask)
                elif drop_mode == 'hard_drop':
                    updates, new_states = self.updater(inp, getter.states())
                    updates = updates * out_mul
                    setter.states(new_states)
                else:
                    raise RuntimeError('Unknown drop mode!')
                setter.params(getter.params() + updates)
            batch_arg = batch_manager.batch_arg_to_set

            if mode == 'train' and iteration % unroll == 0:
                meta_optimizer.zero_grad()
                unroll_losses.backward()
                nn.utils.clip_grad_norm_(self.parameters(), 5)
                meta_optimizer.step()
                unroll_losses = 0

            # Truncate the meta-graph: every iteration outside training, and
            # at each unroll boundary during training.
            if mode != 'train' or iteration % unroll == 0:
                mask_states.detach_()
                batch_arg.detach_()

            walltime += iter_watch.touch('interval')
            result_dict.append(loss=loss_detached)
            if mode != 'train':
                result_dict.append(walltime=walltime)
        return result_dict
Example #8
0
    def meta_optimize(self,
                      meta_optimizer,
                      data_cls,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      should_train=True):
        """Run the sampling-based learned optimizer on a fresh optimizee.

        Each iteration does the following:
          1. ``self.point_sampler`` proposes candidate update points from the
             (detached) current gradient and the previous update.
          2. ``self.loss_observer`` evaluates the candidates.
          3. ``self.point_selector`` turns them into the update that is
             applied to the parameters.

        When ``should_train`` is true, the optimizee losses summed over the
        last ``unroll`` iterations are back-propagated and ``meta_optimizer``
        is stepped.

        Args:
            meta_optimizer: optimizer stepped on the summed optimizee loss
                (training only).
            data_cls: dataset factory called as
                ``data_cls(training=should_train)``; instances expose
                ``sample()``.
            model_cls: optimizee factory; called with no arguments or with
                ``params=``.
            optim_it: number of inner optimization iterations.
            unroll: truncation length (in iterations) of the meta-graph.
            out_mul: scale applied to the sampled relative points.
            should_train: train mode (meta-updates) vs. evaluation mode.

        Returns:
            ResultDict with ``step_num`` and ``loss`` per iteration. When not
            training, it also holds ``walltime`` and the statistics returned
            by ``self.params_tracker``.
        """
        data = data_cls(training=should_train)
        model = C(model_cls())
        optimizer_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))

        if should_train:
            self.train()
        else:
            self.eval()

        result_dict = ResultDict()
        model_losses = 0
        iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
        iter_watch = utils.StopWatch('optim_iteration')

        batch_arg = BatchManagerArgument(
            params=model.params,
            states=StatesSlicingWrapper(optimizer_states),
            update_prev=C(torch.zeros(model.params.size().flat())),
            mu=C(torch.zeros(model.params.size().flat())),
            sigma=C(torch.zeros(model.params.size().flat())),
        )

        for iteration in iter_pbar:
            data_ = data.sample()
            model_loss = model(*data_)
            model_losses += model_loss

            # Gradient at the current parameters, computed on a detached copy
            # so this backward pass does not reach the meta-graph.
            model_dt = C(model_cls(params=model.params.detach()))
            model_loss_dt = model_dt(*data_)
            model_loss_dt.backward()
            assert model_dt.params.flat.grad is not None
            grad_cur = model_dt.params.flat.grad
            batch_arg.update(BatchManagerArgument(grad_cur=grad_cur))

            batch_manager = OptimizerBatchManager(batch_arg=batch_arg)
            for getter, setter in batch_manager.iterator():
                # Per-batch locals get their own names. Previously the slice
                # overwrote grad_cur, so params_tracker below received only the
                # last slice whenever the manager yields more than one batch.
                grad_slice = getter.grad_cur.detach()  # free-running
                update_prev = getter.update_prev.detach()
                points_rel, new_states = self.point_sampler(
                    grad_slice, update_prev, getter.states, self.n_points)
                points_rel = points_rel * out_mul
                # NOTE(review): candidates use detached params; confirm the
                # meta-gradient is meant to flow only through points_rel.
                points_abs = getter.params.detach() + points_rel
                losses = self.loss_observer(points_abs, model_cls, data_)
                update = self.point_selector(points_rel, losses)
                setter.params(getter.params + update)
                setter.states(new_states)
                setter.update_prev(update)
                setter.mu(self.point_sampler.mu)
                setter.sigma(self.point_sampler.sigma)
            batch_arg = batch_manager.batch_arg_to_set

            iter_pbar.set_description('optim_iteration[loss:{:10.6f}]'.format(
                model_loss.tolist()))

            # Errors here now propagate instead of dropping into pdb (the old
            # bare `except:` also trapped KeyboardInterrupt).
            if should_train and iteration % unroll == 0:
                meta_optimizer.zero_grad()
                model_losses.backward()
                meta_optimizer.step()

            # Truncate the meta-graph: every iteration outside training, and
            # at each unroll boundary during training.
            if not should_train or iteration % unroll == 0:
                model_losses = 0
                batch_arg.detach_()

            model = C(model_cls(params=batch_arg.params))
            # Store a detached loss. When not training, no backward is run on
            # model_loss, so keeping it attached would retain each
            # iteration's forward graph.
            result_dict.append(
                step_num=iteration,
                loss=model_loss.detach(),
            )

            if not should_train:
                result_dict.append(walltime=iter_watch.touch(),
                                   **self.params_tracker(
                                       grad=grad_cur,
                                       update=batch_arg.update_prev,
                                       mu=batch_arg.mu,
                                       sigma=batch_arg.sigma,
                                   ))
        return result_dict