def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll,
                    out_mul, should_train=True):
  """Run the learned optimizer on a freshly built optimizee for `optim_it` steps.

  Each iteration evaluates the optimizee on `target`, backpropagates the loss
  to obtain parameter gradients, feeds those gradients (detached) through
  `self` in batches of up to 30000 coordinates, and writes back
  `params + updates * out_mul` together with the new RNN states.

  Args:
    meta_optimizer: optimizer for the learned optimizer's own weights; it is
      zeroed/stepped every `unroll` iterations when `should_train` is True.
    data_cls: called as `data_cls(training=should_train)` to build the target
      passed to the optimizee.
    model_cls: called with no arguments for the initial optimizee and with the
      current params to rebuild it after every step.
    optim_it: number of inner optimization iterations.
    unroll: number of iterations whose losses are summed before each
      meta-gradient step. Forced to 1 in eval mode.
    out_mul: scale applied to the RNN output before adding it to the params.
    should_train: put `self` in train mode and update it (True), or run in
      eval mode without meta-updates (False).

  Returns:
    Tuple `(all_losses_ever, updates_tracks, timestamps)` of per-iteration
    lists: the loss value (as `loss.data.cpu().numpy()`, detached from the
    graph), the coordinate tracker's output for that iteration's updates, and
    the stopwatch reading taken at the end of the iteration.
  """
  if should_train:
    self.train()
  else:
    self.eval()
    # No meta-gradient is taken in eval mode, so cut the graph every step.
    unroll = 1

  target = data_cls(training=should_train)
  optimizee = C(model_cls())
  optimizer_states = OptimizerStates.zero_states(
      self.n_layers, len(optimizee.params), self.hidden_sz)
  result_dict = ResultDict()
  # Tracks 30 coordinates of each step's updates (selection is up to
  # ParamsIndexTracker).
  coordinate_tracker = ParamsIndexTracker(
      n_params=len(optimizee.params), n_tracks=30)

  all_losses_ever = []
  updates_tracks = []
  timestamps = []

  if should_train:
    meta_optimizer.zero_grad()
  unroll_losses = None

  iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
  iter_watch = utils.StopWatch('optim_iteration')
  for iteration in iter_pbar:
    loss = optimizee(target)
    iter_pbar.set_description(f'optim_iteration[loss:{loss}]')

    if unroll_losses is None:
      unroll_losses = loss
    else:
      unroll_losses += loss
    result_dict.add('losses', loss)
    # Keep a graph-free copy so the list does not hold old graphs alive.
    all_losses_ever.append(loss.data.cpu().numpy())
    # Populate .grad on the optimizee params. While training, the graph is
    # retained because unroll_losses.backward() walks through it again.
    loss.backward(retain_graph=should_train)

    batch_manager = OptimizerBatchManager(
        params=optimizee.params, states=optimizer_states,
        getters=['params', 'states'],
        setters=['params', 'states', 'updates'],
        batch_size=30000)

    for getter, setter in batch_manager.iterator():
      # The gradient is detached so it enters the RNN as a plain input,
      # while the params/states path stays differentiable for the meta-loss.
      updates, new_states = self(
          getter.params.grad.detach(), getter.states.clone())
      setter.states(new_states.clone())
      setter.params(getter.params.clone() + updates * out_mul)
      setter.updates(updates)

    if iteration % unroll == 0:
      if should_train:
        meta_optimizer.zero_grad()
        unroll_losses.backward()
        meta_optimizer.step()
      unroll_losses = None

      # Truncate backprop-through-time at the unroll boundary.
      optimizee.params.detach_()
      optimizer_states = optimizer_states.detach()

    optimizee = C(model_cls(optimizee.params))
    tracked_updates = coordinate_tracker(batch_manager.updates)
    timestamp = iter_watch.touch()
    result_dict.add('updates', tracked_updates)
    result_dict.add('timestamps', timestamp)
    updates_tracks.append(tracked_updates)
    timestamps.append(timestamp)
  return all_losses_ever, updates_tracks, timestamps
Example #2
0
    def meta_optimize(self,
                      meta_optimizer,
                      data_cls,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      should_train=True):
        """Run the learned optimizer on a freshly built model for `optim_it` steps.

        Each iteration samples a batch, computes the model gradient on a
        detached copy of the params, feeds that gradient through `self` (the
        RNN optimizer) batch by batch via `OptimizerBatchManager`, and moves the
        params by `updates * out_mul`.

        Args:
            meta_optimizer: optimizer for the RNN's own weights (presumably
                built over `self.parameters()`); stepped every `unroll`
                iterations when `should_train` is True.
            data_cls: called as `data_cls(training=should_train)`; the result's
                `.sample()` must return the model's positional inputs.
            model_cls: called with no arguments for the initial model and with
                `params=...` to rebuild the model around updated parameters.
            optim_it: number of inner optimization iterations.
            unroll: number of iterations whose losses are summed before each
                meta-gradient step.
            out_mul: scale applied to the RNN output before adding it to params.
            should_train: train mode with meta-updates (True) or eval mode with
                extra tracking (False).

        Returns:
            ResultDict with one `loss` entry per iteration; in eval mode it also
            receives `walltime` and the output of `self.params_tracker`.
        """
        data = data_cls(training=should_train)
        model = C(model_cls())
        optimizer_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))

        if should_train:
            self.train()
        else:
            self.eval()

        result_dict = ResultDict()
        unroll_losses = 0
        # NOTE(review): the other variants call `.size().flat()`; here `.flat`
        # is not called — confirm the size object exposes it as a property.
        update = C(torch.zeros(model.params.size().flat))
        iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
        iter_watch = utils.StopWatch('optim_iteration')

        # Bundles what the batch manager slices per batch: params, RNN states
        # and a buffer for the most recent updates.
        batch_arg = BatchManagerArgument(
            params=model.params,
            states=StatesSlicingWrapper(optimizer_states),
            updates=update,
        )

        for iteration in iter_pbar:
            data_ = data.sample()
            # Differentiable loss: accumulated into the meta-objective.
            loss = model(*data_)
            unroll_losses += loss
            # Same loss on a detached copy of the params, so the gradient fed to
            # the RNN is a plain input rather than part of the meta-graph.
            model_dt = C(model_cls(params=batch_arg.params.detach()))
            model_dt_loss = model_dt(*data_)
            # Sanity check: the detached copy must reproduce the same loss.
            assert loss == model_dt_loss
            model_dt_loss.backward()
            assert model_dt.params.flat.grad is not None
            grad = model_dt.params.flat.grad
            batch_arg.update(
                BatchManagerArgument(grad=grad).immutable().volatile())
            iter_pbar.set_description(f'optim_iteration[loss:{loss}]')

            ##########################################################################
            batch_manager = OptimizerBatchManager(batch_arg=batch_arg)

            for get, set in batch_manager.iterator():

                updates, new_states = self(get.grad.detach(), get.states)
                set.states(new_states)
                set.params(get.params + updates * out_mul)
                # Updates are only recorded for the eval-mode tracker below.
                if not should_train:
                    set.updates(updates)
            ##########################################################################
            # Presumably the argument bundle holding the values written through
            # the setters above.
            batch_arg = batch_manager.batch_arg_to_set
            # Meta-gradient step every `unroll` iterations while training.
            if should_train and iteration % unroll == 0:
                meta_optimizer.zero_grad()
                unroll_losses.backward()
                meta_optimizer.step()

            # Reset the accumulated loss and cut the graph (detach_ presumably
            # detaches every tensor in batch_arg in place): every iteration in
            # eval mode, every `unroll` iterations in training.
            if not should_train or iteration % unroll == 0:
                unroll_losses = 0
                batch_arg.detach_()
                # model.params = model.params.detach()
                # optimizer_states = optimizer_states.detach()
            model = C(model_cls(params=batch_arg.params))
            result_dict.append(loss=loss)
            if not should_train:
                result_dict.append(walltime=iter_watch.touch(),
                                   **self.params_tracker(
                                       grad=grad,
                                       update=batch_arg.updates,
                                   ))
        return result_dict
Example #3
0
    def meta_optimize(self,
                      meta_optimizer,
                      data_cls,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      should_train=True):
        """Run the gradient-predicting learned optimizer for `optim_it` steps.

        Two RNN state sets are kept (`g_states`, `u_states`). Each step the RNN
        (`self`) receives the previous gradient — either the real one (teacher
        forcing) or its own previous prediction (free running) — plus the
        previous update. It returns a predicted gradient `grad_cur`, an update
        `update_cur` and new states. Params move by `update_cur * out_mul`.
        The predicted gradient is scored against the real gradient of the
        current step with `self.mse_loss`.

        Args:
            meta_optimizer: optimizer for the RNN's own weights; stepped every
                `unroll` iterations when `should_train` is True, on
                `model_losses + grad_losses`.
            data_cls: called as `data_cls(training=should_train)`; its
                `.sample()` must return an `(inp, out)` pair.
            model_cls: builds the optimizee; called with no arguments first and
                with `params=...` afterwards.
            optim_it: number of inner iterations.
            unroll: iterations per meta-gradient step.
            out_mul: scale applied to the RNN's update output.
            should_train: train mode with meta-updates (True) or eval mode
                with extra tracking (False).

        Returns:
            ResultDict with `step_num`, `loss` (model loss) and `grad_loss`
            (100 * gradient MSE) per iteration; in eval mode also `walltime`
            and the grad_pred / grad_real / update tracker outputs.
        """
        data = data_cls(training=should_train)
        model = C(model_cls())
        g_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))
        u_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))

        if should_train:
            self.train()
        else:
            self.eval()
            # Trackers exist only in eval mode and are only used there.
            grad_pred_tracker = ParamsIndexTracker(name='grad_pred',
                                                   n_tracks=10,
                                                   n_params=len(model.params))
            grad_real_tracker = grad_pred_tracker.clone_new_name('grad_real')
            update_tracker = grad_pred_tracker.clone_new_name('update')

        result_dict = ResultDict()
        # Reassigned (and deleted) inside the loop; this initial value is unused.
        model_dt = None
        model_losses = 0
        grad_losses = 0
        # False: teacher-force on every training step. True would teacher-force
        # only every `self.k` steps, as in eval.
        train_with_lazy_tf = False
        # True: compute the real gradient every eval step too, so grad_loss is
        # reported in eval. Because it is hard-coded True, the
        # `(iteration + 1) % self.k == 0 or ...` conditions below always hold.
        eval_test_grad_loss = True
        grad_prev = C(torch.zeros(model.params.size().flat()))
        update_prev = C(torch.zeros(model.params.size().flat()))
        #grad_real_prev = None
        grad_real_cur = None
        grad_real_prev = C(torch.zeros(model.params.size().flat()))
        iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
        iter_watch = utils.StopWatch('optim_iteration')
        """grad_real_prev:
        teacher-forcing current input of RNN.
        required during all of training steps + every 'k' test steps.
       grad_real_cur:
        MSE loss target for current output of RNN.
        required during all the training steps only.
    """
        batch_arg = BatchManagerArgument(
            params=model.params,
            grad_prev=grad_prev,
            update_prev=update_prev,
            g_states=StatesSlicingWrapper(g_states),
            u_states=StatesSlicingWrapper(u_states),
            mse_loss=C(torch.zeros(1)),
            #updates=OptimizerBatchManager.placeholder(model.params.flat.size()),
        )

        for iteration in iter_pbar:
            inp, out = data.sample()
            model_loss = model(inp, out)
            model_losses += model_loss
            # if grad_real_cur is not None: # if iteration % self.k == 0
            #   grad_cur_prev = grad_real_cur # back up previous grad
            if (iteration +
                    1) % self.k == 0 or should_train or eval_test_grad_loss:
                # get grad every step while training
                # get grad one step earlier than teacher-forcing steps while testing
                # The real gradient comes from a detached copy so it does not
                # enter the meta-graph.
                model_dt = C(model_cls(params=model.params.detach()))
                model_loss_dt = model_dt(inp, out)
                # model.params.register_grad_cut_hook_()
                model_loss_dt.backward()
                # model.params.remove_grad_cut_hook_()
                grad_real_cur = model_dt.params.flat.grad
                assert grad_real_cur is not None
                del model_dt, model_loss_dt

            if should_train or eval_test_grad_loss:
                # mse loss target
                # grad_real_cur is not used as a mse target while testing
                assert grad_real_cur is not None
                batch_arg.update(
                    BatchManagerArgument(
                        grad_real_cur=grad_real_cur).immutable().volatile())

            if iteration % self.k == 0 or (should_train
                                           and not train_with_lazy_tf):
                # teacher-forcing
                assert grad_real_prev is not None
                batch_arg.update(
                    BatchManagerArgument(
                        grad_real_prev=grad_real_prev).immutable().volatile())
                # Consumed; refilled from grad_real_cur at the end of the step.
                grad_real_prev = None

            ########################################################################
            batch_manager = OptimizerBatchManager(batch_arg=batch_arg)

            for get, set in batch_manager.iterator():
                # Same condition as the teacher-forcing update above.
                if iteration % self.k == 0 or (should_train
                                               and not train_with_lazy_tf):
                    grad_prev = get.grad_real_prev.detach()  # teacher-forcing
                else:
                    grad_prev = get.grad_prev.detach()  # free-running
                update_prev = get.update_prev.detach()
                grad_cur, update_cur, g_states, u_states = self(
                    grad_prev, update_prev, get.g_states, get.u_states)
                set.grad_prev(grad_cur)
                set.update_prev(update_cur)
                set.g_states(g_states)
                set.u_states(u_states)
                set.params(get.params + update_cur * out_mul)
                if should_train or eval_test_grad_loss:
                    set.mse_loss(
                        self.mse_loss(grad_cur, get.grad_real_cur.detach()))
            ##########################################################################

            batch_arg = batch_manager.batch_arg_to_set
            # mse_loss was initialised as a 1-element tensor, hence the
            # `tolist()[0]` below.
            grad_loss = batch_arg.mse_loss * 100
            grad_losses += grad_loss
            #to_float = lambda x: float(x.data.cpu().numpy())
            iter_pbar.set_description(
                'optim_iteration[loss(model):{:10.6f}/(grad):{:10.6f}]'.format(
                    model_loss.tolist(),
                    grad_loss.tolist()[0]))

            # Meta-objective: model losses plus scaled gradient-prediction MSE.
            if should_train and iteration % unroll == 0:
                meta_optimizer.zero_grad()
                losses = model_losses + grad_losses
                losses.backward()
                meta_optimizer.step()

            # Reset accumulators and cut the graph: every step in eval, every
            # `unroll` steps in training.
            if not should_train or iteration % unroll == 0:
                model_losses = 0
                grad_losses = 0
                batch_arg.detach_()
                # batch_arg.params.detach_()
                # batch_arg.g_states.detach_()
                # batch_arg.u_states.detach_()

            model = C(model_cls(params=batch_arg.params))
            result_dict.append(
                step_num=iteration,
                loss=model_loss,
                grad_loss=grad_loss,
            )

            if not should_train:
                # NOTE(review): `track_num` is unpacked as a mapping here —
                # confirm ParamsIndexTracker exposes it as a dict.
                result_dict.append(
                    walltime=iter_watch.touch(),
                    **grad_pred_tracker.track_num,
                    **grad_pred_tracker(batch_arg.grad_prev),
                    **grad_real_tracker(grad_real_cur),
                    **update_tracker(batch_arg.update_prev),
                )

            # This step's real gradient becomes next step's teacher-forcing input.
            if (iteration +
                    1) % self.k == 0 or should_train or eval_test_grad_loss:
                grad_real_prev = grad_real_cur
                grad_real_cur = None

        return result_dict
    def meta_optimize(self,
                      meta_optimizer,
                      data_cls,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      should_train=True):
        """Train/evaluate an RNN that predicts gradients, stepping with them.

        Each step the RNN (`self`) receives the previous gradient (real one
        when teacher-forcing, its own previous prediction otherwise), the
        previous update, the detached params and its states. It returns a
        predicted gradient `grad_cur` and new states. Params then move by
        `-1.0 * grad_cur.detach()`. Note that `out_mul` is not used in this
        variant, and `update_prev` is never written by the setters, so it
        appears to stay at its initial zeros.

        The meta-loss summed over `unroll` steps is
        `100 * MSE(grad_cur, grad_real) + 100 * -log((cos + 1) / 2)`, where
        `cos` is the cosine similarity of predicted and real gradients. An L2
        weight-decay term `0.5e-4 * sum(p * p)` over `self.parameters()` is
        added at each meta-step.

        Args:
            meta_optimizer: optimizer for the RNN's own weights; stepped every
                `unroll` iterations when `should_train` is True.
            data_cls: called as `data_cls(training=should_train)`; its
                `.sample()` returns the model's positional inputs.
            model_cls: builds the optimizee; called with no arguments first and
                with `params=...` for the detached copy.
            optim_it: number of inner iterations.
            unroll: iterations per meta-gradient step.
            out_mul: unused in this variant.
            should_train: train mode with meta-updates (True) or eval mode
                with tracking (False).

        Returns:
            ResultDict with `step_num`, `loss` (loss of the detached model) and
            `grad_loss` (100 * MSE only; the cosine term is not recorded) per
            step; in eval mode also `walltime` and `self.params_tracker` output.
        """
        data = data_cls(training=should_train)
        model = C(model_cls())
        g_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))

        if should_train:
            self.train()
        else:
            self.eval()

        result_dict = ResultDict()
        # Reassigned inside the loop; this initial value is unused.
        model_dt = None
        grad_losses = 0
        # True: during training, teacher-force only every `self.k` steps.
        train_with_lazy_tf = True
        # True: compute the real gradient on every step, in eval too. Because
        # it is hard-coded True, `model_loss_dt` is always defined below.
        eval_test_grad_loss = True
        grad_prev = C(torch.zeros(model.params.size().flat()))
        update_prev = C(torch.zeros(model.params.size().flat()))
        #grad_real_prev = None
        grad_real_cur = None
        grad_real_prev = C(torch.zeros(model.params.size().flat()))
        iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
        iter_watch = utils.StopWatch('optim_iteration')
        """grad_real_prev:
        teacher-forcing current input of RNN.
        required during all of training steps + every 'k' test steps.
       grad_real_cur:
        MSE loss target for current output of RNN.
        required during all the training steps only.
    """
        batch_arg = BatchManagerArgument(
            params=model.params,
            grad_prev=grad_prev,
            update_prev=update_prev,
            g_states=StatesSlicingWrapper(g_states),
            mse_loss=C(torch.zeros(1)),
            sigma=C(torch.zeros(model.params.size().flat())),
            #updates=OptimizerBatchManager.placeholder(model.params.flat.size()),
        )

        for iteration in iter_pbar:
            data_ = data.sample()
            # if grad_real_cur is not None: # if iteration % self.k == 0
            #   grad_cur_prev = grad_real_cur # back up previous grad
            if (iteration +
                    1) % self.k == 0 or should_train or eval_test_grad_loss:
                # get grad every step while training
                # get grad one step earlier than teacher-forcing steps while testing
                model_dt = C(model_cls(params=model.params.detach()))
                model_loss_dt = model_dt(*data_)
                # model.params.register_grad_cut_hook_()
                model_loss_dt.backward()
                # model.params.remove_grad_cut_hook_()
                grad_real_cur = model_dt.params.flat.grad
                assert grad_real_cur is not None

            if should_train or eval_test_grad_loss:
                # mse loss target
                # grad_real_cur is not used as a mse target while testing
                assert grad_real_cur is not None
                batch_arg.update(
                    BatchManagerArgument(
                        grad_real_cur=grad_real_cur).immutable())

            if iteration % self.k == 0 or (should_train
                                           and not train_with_lazy_tf):
                # teacher-forcing
                assert grad_real_prev is not None
                batch_arg.update(
                    BatchManagerArgument(
                        grad_real_prev=grad_real_prev).immutable().volatile())
                # Consumed; refilled from grad_real_cur at the end of the step.
                grad_real_prev = None

            ########################################################################
            batch_manager = OptimizerBatchManager(batch_arg=batch_arg)

            for get, set in batch_manager.iterator():
                if iteration % self.k == 0 or (should_train
                                               and not train_with_lazy_tf):
                    grad_prev = get.grad_real_prev  # teacher-forcing
                else:
                    grad_prev = get.grad_prev  # free-running
                grad_cur, g_states = self(grad_prev.detach(),
                                          get.update_prev.detach(),
                                          get.params.detach(), get.g_states)
                set.grad_prev(grad_cur)
                set.g_states(g_states)
                # Plain gradient step with the *predicted* gradient (lr 1.0).
                set.params(get.params - 1.0 * grad_cur.detach())
                set.sigma(self.sigma)
                if should_train or eval_test_grad_loss:
                    set.mse_loss(
                        self.mse_loss(grad_cur, get.grad_real_cur.detach()))
            ##########################################################################

            batch_arg = batch_manager.batch_arg_to_set
            grad_real = batch_manager.batch_arg_to_get.grad_real_cur.detach()
            # NOTE(review): `.tolist()[0]` below assumes the flat vectors are
            # (N, 1) so cosine over dim=0 yields a 1-element tensor — confirm.
            cosim_loss = nn.functional.cosine_similarity(batch_arg.grad_prev,
                                                         grad_real,
                                                         dim=0)
            grad_loss = batch_arg.mse_loss * 100
            # Maps cosine in [-1, 1] to [0, inf): 0 when the directions match.
            cosim_loss = -torch.log((cosim_loss + 1) * 0.5) * 100
            grad_losses += (grad_loss + cosim_loss)
            #to_float = lambda x: float(x.data.cpu().numpy())
            sigma = batch_arg.sigma.detach().mean()
            iter_pbar.set_description(
                'optim_iteration[loss(model):{:10.6f}/(mse):{:10.6f}/(cosim):{:10.6f}/(sigma):{:10.6f}]'
                ''.format(model_loss_dt.tolist(),
                          grad_loss.tolist()[0],
                          cosim_loss.tolist()[0], sigma.tolist()))

            if should_train and iteration % unroll == 0:
                # L2 weight decay on the RNN's own parameters.
                w_decay_loss = 0
                for p in self.parameters():
                    w_decay_loss += torch.sum(p * p)
                w_decay_loss = (w_decay_loss * 0.5) * 1e-4
                total_loss = grad_losses + w_decay_loss
                meta_optimizer.zero_grad()
                total_loss.backward()
                meta_optimizer.step()

            if not should_train or iteration % unroll == 0:
                # NOTE(review): model_losses is never read in this variant.
                model_losses = 0
                grad_losses = 0
                batch_arg.detach_()
                # batch_arg.params.detach_()
                # batch_arg.g_states.detach_()
                # batch_arg.u_states.detach_()

            result_dict.append(
                step_num=iteration,
                loss=model_loss_dt,
                grad_loss=grad_loss,
            )

            if not should_train:
                result_dict.append(walltime=iter_watch.touch(),
                                   **self.params_tracker(
                                       grad_real=grad_real_cur,
                                       grad_pred=batch_arg.grad_prev,
                                       update=batch_arg.update_prev,
                                       sigma=batch_arg.sigma,
                                   ))

            # This step's real gradient becomes next step's teacher-forcing input.
            if (iteration +
                    1) % self.k == 0 or should_train or eval_test_grad_loss:
                grad_real_prev = grad_real_cur
                grad_real_cur = None

        return result_dict
Example #5
0
    def meta_optimize(self,
                      meta_optim,
                      data,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      k_obsrv=None,
                      no_mask=None,
                      writer=None,
                      mode='train'):
        """Unroll the learned optimizer on a freshly built model.

        Each of the `optim_it` steps:
          1. computes the gradient of the training NLL (`data['in_train']`)
             on a detached copy of the current params;
          2. passes that gradient through `self` to get an update and new RNN
             states, and moves the params by `update * out_mul`;
          3. evaluates the moved params on `data['in_test']`; in 'train' mode
             the test NLLs are summed and, every `unroll` steps,
             backpropagated into `self` through `meta_optim` (gradient values
             clipped to 0.01);
          4. detaches params/states every step outside 'train' mode, and every
             `unroll` steps in 'train' mode.

        `k_obsrv`, `no_mask` and `writer` are placeholders and are ignored
        (FIX LATER: do something about the dummy arguments).

        Returns:
            `(result_dict, params)`: per-step train/test NLL, accuracy and
            walltime, plus the final parameters.
        """
        assert mode in ['train', 'valid', 'test']
        self.set_mode(mode)
        is_training = mode == 'train'
        # The batched (indexer-based) path is kept but disabled.
        use_indexer = False

        params = C(model_cls()).params
        states = C(
            OptimizerStates.initial_zeros(
                size=(self.n_layers, len(params), self.hidden_sz),
                rnn_cell=self.rnn_cell))

        result_dict = ResultDict()
        unroll_losses = 0
        walltime = Walltime()
        update = C(torch.zeros(params.size().flat()))

        if use_indexer:
            batch_arg = BatchManagerArgument(
                params=params,
                states=StatesSlicingWrapper(states),
                updates=update,
            )

        iter_pbar = tqdm(range(1, optim_it + 1), 'inner_train')
        for step in iter_pbar:

            # Inner update (always timed).
            with WalltimeChecker(walltime):
                model_train = C(model_cls(params=params.detach()))
                train_nll, train_acc = model_train(*data['in_train'].load())
                train_nll.backward()
                # Fall back to the model's own gradient collection when the
                # flat view carries no .grad.
                g = (model_train.params.flat.grad
                     if model_train.params.flat.grad is not None
                     else model_train._grad2params(model_train.params))
                assert g is not None

                if not use_indexer:
                    updates, states = self(g.detach(), states)
                    step_delta = params.new_from_flat(updates) * out_mul
                    params = params + step_delta
                else:
                    # Outer-level batch split, meant for parameter vectors that
                    # do not fit in VRAM. Avoid unless unavoidable: the batches
                    # run serially and add time.
                    batch_arg.update(
                        BatchManagerArgument(grad=g).immutable().volatile())
                    indexer = DefaultIndexer(input=g, shuffle=False)
                    batch_manager = OptimizerBatchManager(batch_arg=batch_arg,
                                                          indexer=indexer)
                    for getter, setter in batch_manager.iterator():
                        updates, new_states = self(getter.grad().detach(),
                                                   getter.states())
                        setter.states(new_states)
                        setter.params(getter.params() + updates * out_mul)
                        if not is_training:
                            setter.updates(updates)
                    batch_arg = batch_manager.batch_arg_to_set
                    params = batch_arg.params

            # Outer (meta) objective: only timed in 'train' mode.
            with WalltimeChecker(walltime if is_training else None):
                model_test = C(model_cls(params=params))
                test_nll, test_acc = model_test(*data['in_test'].load())
                if is_training:
                    unroll_losses += test_nll
                    if step % unroll == 0:
                        meta_optim.zero_grad()
                        unroll_losses.backward()
                        nn.utils.clip_grad_value_(self.parameters(), 0.01)
                        meta_optim.step()
                        unroll_losses = 0

            # Truncate the graph at unroll boundaries (every step outside train).
            with WalltimeChecker(walltime):
                if not is_training or step % unroll == 0:
                    if use_indexer:
                        batch_arg.detach_()
                    else:
                        params.detach_()
                        states.detach_()

            result = {
                'train_nll': train_nll.tolist(),
                'test_nll': test_nll.tolist(),
                'train_acc': train_acc.tolist(),
                'test_acc': test_acc.tolist(),
                'walltime': walltime.time,
            }
            result_dict.append(result)
            log_pbar(result, iter_pbar)

        return result_dict, params
  def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll,
                    out_mul, should_train=True):
    """Run the learned optimizer for `optim_it` steps using a top-k indexer.

    Each iteration computes the model gradient on a detached copy of the
    params. A `TopKIndexer(k_ratio=0.1)` built from that gradient presumably
    picks about 10% of coordinates to update (confirm against TopKIndexer).
    The chosen coordinates are fed through `self`, and the params move by
    `updates * out_mul`. While training, the differentiable loss is summed
    over `unroll` steps and backpropagated into `self` via `meta_optimizer`,
    with gradient-norm clipping at 10.

    Args:
      meta_optimizer: optimizer for the RNN's own weights.
      data_cls: called as `data_cls(training=should_train)`; its `.sample()`
        returns the model's positional inputs.
      model_cls: builds the optimizee; called with `params=...` per step.
      optim_it: number of inner iterations.
      unroll: iterations per meta-gradient step.
      out_mul: scale applied to the RNN output.
      should_train: train mode with meta-updates (True) or eval mode (False).

    Returns:
      ResultDict with one `loss` (detached-model loss) per iteration; in eval
      mode also `walltime` and `self.params_tracker` output.
    """
    data = data_cls(training=should_train)
    model = C(model_cls())
    optimizer_states = C(OptimizerStates.initial_zeros(
        size=(self.n_layers, len(model.params), self.hidden_sz),
        rnn_cell=self.rnn_cell))

    if should_train:
      self.train()
    else:
      self.eval()

    result_dict = ResultDict()
    # NOTE(review): accumulated below but never reported — the eval record
    # uses `iter_watch.touch()` instead. Confirm which value is intended.
    walltime = 0
    unroll_losses = 0
    update = C(torch.zeros(model.params.size().flat()))

    batch_arg = BatchManagerArgument(
      params=model.params,
      states=StatesSlicingWrapper(optimizer_states),
      updates=update,
    )

    iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    iter_watch = utils.StopWatch('optim_iteration')
    for iteration in iter_pbar:
      iter_watch.touch()
      # entire_loop.touch()
      # interval.touch()
      data_ = data.sample()
      # The differentiable loss is only needed when a meta-step will follow.
      if should_train:
        model = C(model_cls(params=batch_arg.params))
        loss = model(*data_)
        unroll_losses += loss
      # print("\n[1] : "+str(interval.touch('interval', False)))
      # model_ff.touch()
      # Detached copy: its gradient is the RNN input and stays off the meta-graph.
      model_detached = C(model_cls(params=batch_arg.params.detach()))
      loss_detached = model_detached(*data_)
      # model_ff.touch('interval', True)
      if should_train:
        assert loss == loss_detached
      # model_bp.touch()
      loss_detached.backward()
      # model_bp.touch('interval', True)
      # interval.touch()
      assert model_detached.params.flat.grad is not None
      grad = model_detached.params.flat.grad
      batch_arg.update(BatchManagerArgument(grad=grad).immutable().volatile())
      iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]')
      # print("\n[2-1] : "+str(interval.touch('interval', False)))
      ##########################################################################
      # indexer = DefaultIndexer(input=grad, shuffle=False)
      indexer = TopKIndexer(input=grad, k_ratio=0.1, random_k_mode=False)# bound=model_dt.params.size().bound_layerwise())
      #indexer = DefaultIndexer(input=grad)
      batch_manager = OptimizerBatchManager(
        batch_arg=batch_arg, indexer=indexer)
      # print("\n[2-2] : "+str(interval.touch('interval', False)))
      # optim_ff.touch()
      # optim_ff_1.touch()
      for get, set in batch_manager.iterator():

        # optim_ff_1.touch('interval', True)
        # optim_ff_2.touch()

        # optim_ff_2.touch('interval', True)
        # optim_ff_3.touch()
        updates, new_states = self(get.grad().detach(),  get.states())
        # optim_ff_3.touch('interval', True)
        # optim_ff_4.touch()
        updates = updates * out_mul
        set.states(new_states)
        set.params(get.params() + updates)
        # Scaled updates are recorded only for the eval-mode tracker.
        if not should_train:
          set.updates(updates)
        # optim_ff_4.touch('interval', True)
      # optim_ff.touch('interval', True)
      ##########################################################################
      # interval.touch()
      batch_arg = batch_manager.batch_arg_to_set
      # print("\n[3] : "+str(interval.touch('interval', False)))
      if should_train and iteration % unroll == 0:

        # optim_bp.touch()
        meta_optimizer.zero_grad()
        unroll_losses.backward()
        nn.utils.clip_grad_norm_(self.parameters(), 10)
        meta_optimizer.step()
        unroll_losses = 0
        # optim_bp.touch('interval', True)
      # interval.touch()
      # Cut the graph: every step in eval, every `unroll` steps in training.
      if not should_train or iteration % unroll == 0:
        batch_arg.detach_()
      # print("\n[4] : "+str(interval.touch('interval', False)))

      # interval.touch()
        # model.params = model.params.detach()
        # optimizer_states = optimizer_states.detach()
      walltime += iter_watch.touch('interval')
      result_dict.append(loss=loss_detached)
      # print("\n[5] : "+str(interval.touch('interval', False)))
      if not should_train:
        result_dict.append(
          walltime=iter_watch.touch(),
          **self.params_tracker(
            grad=grad,
            update=batch_arg.updates,
          )
        )
      # entire_loop.touch('interval', True)
    return result_dict
    def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll,
                      out_mul, mode):
        """Optimize a fresh model with the masked learned optimizer.

        Each iteration computes gradients on a detached copy of the current
        parameters. A mask is generated from the activation gradients, and the
        coordinate-wise updater is applied to the parameters. In 'train' mode
        the updater/mask generator are meta-trained every ``unroll`` steps.

        Args:
            meta_optimizer: optimizer over this module's own parameters; only
                stepped when ``mode == 'train'``.
            data: dataset wrapper exposing ``loaders`` (keys 'inner_train',
                'inner_valid', 'valid', 'test') and ``pseudo_sample()``.
            model_cls: optimizee model class, called with ``sb_mode`` and
                optionally ``params``.
            optim_it: number of inner optimization iterations.
            unroll: iterations between meta-gradient steps (train mode).
            out_mul: scale factor applied to the updater's outputs.
            mode: one of 'train', 'valid', 'test'.

        Returns:
            ResultDict with per-iteration ``loss``; outside training also the
            accumulated ``walltime``.
        """
        assert mode in ['train', 'valid', 'test']
        if mode == 'train':
            self.train()
            inner_data = data.loaders['inner_train']
            outer_data = data.loaders['inner_valid']
            drop_mode = 'soft_drop'
        elif mode == 'valid':
            self.eval()
            inner_data = data.loaders['valid']
            outer_data = None
            drop_mode = 'hard_drop'
        else:  # mode == 'test' (guaranteed by the assert above)
            self.eval()
            inner_data = data.loaders['test']
            outer_data = None
            drop_mode = 'hard_drop'

        model = C(model_cls(sb_mode=self.sb_mode))
        # Forward pass on a dummy batch; presumably this populates
        # model.activations, whose size is read just below — TODO confirm.
        model(*data.pseudo_sample())
        mask_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.activations.mean(0)),
                 self.hidden_sz)))
        update_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))

        result_dict = ResultDict()
        unroll_losses = 0
        walltime = 0

        batch_arg = BatchManagerArgument(
            params=model.params,
            states=StatesSlicingWrapper(update_states),
            updates=model.params.new_zeros(model.params.size()),
        )

        iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
        iter_watch = utils.StopWatch('optim_iteration')
        loss_decay = 1.0

        for iteration in iter_pbar:
            iter_watch.touch()
            # Detached copy of the current params: its gradients drive the
            # mask generator and updater without linking back into batch_arg.
            model_detached = C(
                model_cls(params=batch_arg.params.detach(),
                          sb_mode=self.sb_mode))
            loss_detached = model_detached(*inner_data.load())
            loss_detached.backward()
            iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]')
            if debug_sigint.signal_on:
                debug = iteration == 1 or iteration % 10 == 0
            else:
                debug = False

            backprop_mask, mask_states, sparsity_loss = self.mask_generator(
                activations=model_detached.activations.grad.detach(),
                debug=debug,
                states=mask_states)

            # Sparsity regularizer, scaled by loss_decay (outer loss relative
            # to the first iteration's in train mode; 1.0 otherwise).
            unroll_losses += sparsity_loss * 0.01 * loss_decay
            assert model_detached.params.grad is not None

            if mode == 'train':
                # Non-detached model: the outer loss backprops through the
                # chain of updates into the meta-optimizer's parameters.
                model = C(
                    model_cls(params=batch_arg.params, sb_mode=self.sb_mode))
                loss = model(*outer_data.load())
                if iteration == 1:
                    initial_loss = loss
                    loss_decay = 1.0
                else:
                    loss_decay = (loss / initial_loss).detach()
                unroll_losses += loss

            coord_mask = backprop_mask.expand_as_params(model_detached.params)

            if drop_mode == 'soft_drop':
                # Every coordinate is visited; the mask is applied per batch.
                index = None
            elif drop_mode == 'hard_drop':
                # Only coordinates whose mask exceeds 0.5 are visited.
                index = (coord_mask.flat.squeeze() > 0.5).nonzero().squeeze()
            grad = model_detached.params.grad.flat
            batch_arg.update(BatchManagerArgument(grad=grad, mask=coord_mask))
            ##################################################################
            indexer = DefaultIndexer(input=grad, index=index, shuffle=False)
            batch_manager = OptimizerBatchManager(batch_arg=batch_arg,
                                                  indexer=indexer)

            for get, set in batch_manager.iterator():
                inp = get.grad().detach()
                if drop_mode == 'soft_drop':
                    updates, new_states = self.updater(inp, get.states())
                    updates = updates * out_mul * get.mask()
                    # Masked-out coordinates keep their previous state.
                    set.states(get.states() * (1 - get.mask()) +
                               new_states * get.mask())
                elif drop_mode == 'hard_drop':
                    updates, new_states = self.updater(inp, get.states())
                    updates = updates * out_mul
                    set.states(new_states)
                else:
                    raise RuntimeError('Unknown drop mode!')
                set.params(get.params() + updates)
            ##################################################################
            batch_arg = batch_manager.batch_arg_to_set

            if mode == 'train' and iteration % unroll == 0:
                meta_optimizer.zero_grad()
                unroll_losses.backward()
                nn.utils.clip_grad_norm_(self.parameters(), 5)
                meta_optimizer.step()
                unroll_losses = 0

            if not mode == 'train' or iteration % unroll == 0:
                mask_states.detach_()
                batch_arg.detach_()
                # Outside training nothing ever backprops through
                # unroll_losses; reset it so earlier iterations' sparsity-loss
                # graphs are released instead of accumulating.
                unroll_losses = 0
            walltime += iter_watch.touch('interval')
            result_dict.append(loss=loss_detached)
            if not mode == 'train':
                result_dict.append(walltime=walltime)
        return result_dict
  def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll,
                    out_mul, should_train=True):
    """Optimize a fresh optimizee with the grid/landscape learned optimizer.

    Per iteration:
    - The unit maker turns each coordinate's gradient into a step unit.
    - The grid maker expands the units into absolute and relative grids.
    - The loss observer evaluates the absolute grid.
    - The step maker turns the relative grid plus the observed landscape into
      a flat update, which is added to the optimizee's parameters.

    Args:
      meta_optimizer: optimizer over this module's parameters; stepped every
        `unroll` iterations when `should_train`.
      data_cls: data class, constructed with `training=should_train`.
      model_cls: optimizee class; called with no arguments initially and then
        with the current parameters each iteration.
      optim_it: number of inner iterations.
      unroll: iterations between meta-updates / graph truncation (training).
      out_mul: scale applied to the unit maker's output.
      should_train: whether to meta-train this module.

    Returns:
      ResultDict with per-iteration `loss`; outside training also `grid`,
      `update` (for the tracked coordinates) and `walltime`.
    """
    target = data_cls(training=should_train)
    optimizee = C(model_cls())
    optimizer_states = OptimizerStates.zero_states(
        self.n_layers, len(optimizee.params), self.hidden_sz)

    if not should_train:
      self.eval()
      # During evaluation a fixed subset of 30 coordinates is recorded.
      coordinate_tracker = ParamsIndexTracker(
        n_params=len(optimizee.params), n_tracks=30)
    else:
      self.train()

    result_dict = ResultDict()
    unroll_losses = 0
    iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    iter_watch = utils.StopWatch('optim_iteration')

    for iteration in iter_pbar:
      loss = optimizee(target)
      unroll_losses += loss
      loss.backward(retain_graph=should_train)
      iter_pbar.set_description(f'optim_iteration[loss:{loss}]')

      n_params = len(optimizee.params)
      unit_shape = torch.Size([n_params, 1])
      grid_shape = torch.Size([n_params, self.n_coord])
      placeholder = OptimizerBatchManager.placeholder
      batch_manager = OptimizerBatchManager(
        params=optimizee.params,
        states=StatesSlicingWrapper(optimizer_states),
        units=placeholder(unit_shape),
        grid_abs=placeholder(grid_shape),
        grid_rel=placeholder(grid_shape),
        batch_size=30000,
      )
      ##########################################################################
      for reader, writer in batch_manager.iterator():
        units, new_states = self.unit_maker(
          C.detach(reader.params.grad), reader.states.clone())
        units = units * out_mul / self.n_coord / 2
        grid_abs, grid_rel = self.grid_maker(units, reader.params)
        writer.states(new_states)
        writer.units(units)
        writer.grid_abs(grid_abs)
        writer.grid_rel(grid_rel)
      ##########################################################################

      landscape = self.loss_observer(batch_manager.grid_abs, model_cls, target)
      updates = self.step_maker(batch_manager.grid_rel, landscape)
      optimizee.params.add_flat_(updates)

      # Every evaluation iteration is a truncation point; the modulo is only
      # evaluated while training.
      end_of_unroll = not should_train or iteration % unroll == 0
      if should_train and end_of_unroll:
        meta_optimizer.zero_grad()
        unroll_losses.backward()
        meta_optimizer.step()
      if end_of_unroll:
        unroll_losses = 0
        optimizee.params = optimizee.params.detach()
        optimizer_states = optimizer_states.detach()

      optimizee = C(model_cls(optimizee.params))
      result_dict.append(loss=loss)
      if should_train:
        continue
      result_dict.append(
          grid=coordinate_tracker(batch_manager.units),
          update=coordinate_tracker(updates),
          walltime=iter_watch.touch(),
      )
    return result_dict
    def meta_optimize(self,
                      meta_optimizer,
                      data_cls,
                      model_cls,
                      optim_it,
                      unroll,
                      out_mul,
                      should_train=True):
        """Optimize a fresh model with the sampling-based learned optimizer.

        Each iteration, the point sampler proposes candidate points around the
        current parameters, the loss observer evaluates them, and the point
        selector turns the candidates and their losses into an update.

        Args:
            meta_optimizer: optimizer over this module's parameters; stepped
                every ``unroll`` iterations when ``should_train``.
            data_cls: dataset class, constructed with ``training=should_train``.
            model_cls: optimizee model class (optionally given ``params``).
            optim_it: number of inner optimization iterations.
            unroll: iterations between meta-gradient steps / truncations.
            out_mul: scale applied to the sampled relative points.
            should_train: meta-train this module (True) or only evaluate.

        Returns:
            ResultDict with per-iteration ``step_num`` and ``loss``; when not
            training also ``walltime`` and the ``params_tracker`` summaries.
        """
        data = data_cls(training=should_train)
        model = C(model_cls())
        init_states = C(
            OptimizerStates.initial_zeros(
                (self.n_layers, len(model.params), self.hidden_sz)))

        if should_train:
            self.train()
        else:
            self.eval()

        result_dict = ResultDict()
        model_losses = 0
        iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
        iter_watch = utils.StopWatch('optim_iteration')

        batch_arg = BatchManagerArgument(
            params=model.params,
            states=StatesSlicingWrapper(init_states),
            update_prev=C(torch.zeros(model.params.size().flat())),
            mu=C(torch.zeros(model.params.size().flat())),
            sigma=C(torch.zeros(model.params.size().flat())),
        )

        for iteration in iter_pbar:
            data_ = data.sample()
            model_loss = model(*data_)
            model_losses += model_loss

            # Gradient at the current point, taken on a detached copy so it
            # is not part of the meta-training graph.
            model_dt = C(model_cls(params=model.params.detach()))
            model_loss_dt = model_dt(*data_)
            model_loss_dt.backward()
            assert model_dt.params.flat.grad is not None
            grad_cur = model_dt.params.flat.grad
            batch_arg.update(BatchManagerArgument(grad_cur=grad_cur))

            ##################################################################
            batch_manager = OptimizerBatchManager(batch_arg=batch_arg)

            for get, set in batch_manager.iterator():
                grad_cur = get.grad_cur.detach()  # free-running
                update_prev = get.update_prev.detach()
                points_rel, new_states = self.point_sampler(
                    grad_cur, update_prev, get.states, self.n_points)
                points_rel = points_rel * out_mul
                # NOTE(review): confirm the params should be detached here.
                points_abs = get.params.detach() + points_rel
                losses = self.loss_observer(points_abs, model_cls, data_)
                update = self.point_selector(points_rel, losses)
                set.params(get.params + update)
                set.states(new_states)
                set.update_prev(update)
                set.mu(self.point_sampler.mu)
                set.sigma(self.point_sampler.sigma)
            ##################################################################

            batch_arg = batch_manager.batch_arg_to_set
            iter_pbar.set_description('optim_iteration[loss:{:10.6f}]'.format(
                model_loss.tolist()))

            if should_train and iteration % unroll == 0:
                meta_optimizer.zero_grad()
                # Let failures propagate; no interactive debugger trap here.
                model_losses.backward()
                meta_optimizer.step()

            if not should_train or iteration % unroll == 0:
                # Truncate the unrolled graph.
                model_losses = 0
                batch_arg.detach_()

            model = C(model_cls(params=batch_arg.params))
            result_dict.append(
                step_num=iteration,
                loss=model_loss,
            )

            if not should_train:
                result_dict.append(walltime=iter_watch.touch(),
                                   **self.params_tracker(
                                       grad=grad_cur,
                                       update=batch_arg.update_prev,
                                       mu=batch_arg.mu,
                                       sigma=batch_arg.sigma,
                                   ))
        return result_dict