예제 #1
0
 def advanced_one_step(self, *args, lr_list=None, final_lr_list=None,
                       **kwargs):
     """Train each branch in turn, then run one joint ``self.train`` pass.

     For every ``i`` in ``range(self.branches_num)`` the branch is selected
     with ``self.set_branch_index(i)``, ``self._optimizer_lr_modify`` is
     called with ``lr_list[i]`` and ``Model.train`` is run (explicitly the
     ``Model`` implementation, not ``self.train``). Afterwards ``self.train``
     is called with ``branch_index=self.branches_num``.

     Args:
         *args: Positional arguments forwarded to every training call,
             including the final joint ``self.train`` call.
         lr_list: Per-branch learning rates; needs at least
             ``self.branches_num`` entries. Defaults to ``0.000088`` for
             every branch.
         final_lr_list: ``lr_list`` passed to the final ``self.train`` call.
             If ``self.train`` is the branch-aware override that reads
             ``lr_list[0..branch_index]`` (assumption - confirm), it needs
             ``self.branches_num + 1`` entries. Defaults to
             ``[0.000088, 0.00088, 0.000088]``, the previously hard-coded
             values (length 3, i.e. sized for ``branches_num == 2``).
         **kwargs: Keyword arguments forwarded to every training call.

     Side effects:
         From the second branch onward the global ``FLAGS.overwrite`` is set
         to False and ``FLAGS.save_best`` to True; they are not restored.
     """
     default_lr = 0.000088
     if lr_list is None:
         lr_list = [default_lr] * self.branches_num
     if final_lr_list is None:
         # NOTE(review): formerly hard-coded inline; its length only matches
         # branches_num + 1 when branches_num == 2 - confirm for other setups.
         final_lr_list = [default_lr, 0.00088, default_lr]

     for i in range(self.branches_num):
         self.set_branch_index(i)
         if i > 0:
             # Presumably keeps later branches from overwriting earlier
             # results (confirm); these globals are mutated, never restored.
             FLAGS.overwrite = False
             FLAGS.save_best = True
         self._optimizer_lr_modify(lr_list[i])
         Model.train(self, *args, **kwargs)

     # Joint pass at index branches_num; forward *args like the calls above.
     self.train(*args, branch_index=self.branches_num,
                lr_list=final_lr_list, **kwargs)
예제 #2
0
    def train(self, *args, branch_index=0, lr_list=None, **kwargs):
        """Select ``branch_index`` and run ``Model.train``.

        When ``kwargs['freeze']`` is falsy (it defaults to True), the method
        first builds a list with one ``self._optimizer.minimize`` result per
        branch ``0..branch_index``: for each branch ``i`` it calls
        ``self._optimizer_lr_modify(lr_list[i])`` (presumably setting the
        learning rate - confirm) and minimizes ``self._loss`` over
        ``self._var_list[i]``. It then stores the list in
        ``self._train_step``.

        Args:
            *args: Forwarded to ``Model.train``.
            branch_index: Passed to ``self.set_branch_index``; also the last
                branch for which a ``minimize`` result is built.
            lr_list: Per-branch learning rates, read at indices
                ``0..branch_index`` when not freezing. Defaults to
                ``0.000088`` per entry, sized to cover ``branch_index`` even
                when it equals or exceeds ``self.branches_num``.
            **kwargs: Forwarded unchanged to ``Model.train``. ``freeze`` is
                read here but left in ``kwargs``.
        """
        if lr_list is None:
            # Indices 0..branch_index are read below and branch_index may be
            # branches_num (the joint pass), so the old length of
            # branches_num raised IndexError; cover both bounds.
            lr_list = [0.000088] * max(self.branches_num, branch_index + 1)
        self.set_branch_index(branch_index)
        freeze = kwargs.get('freeze', True)
        if not freeze:
            train_step = []
            for i in range(branch_index + 1):
                # Order matters: the rate is modified before each minimize.
                self._optimizer_lr_modify(lr_list[i])
                train_step.append(
                    self._optimizer.minimize(loss=self._loss,
                                             var_list=self._var_list[i]))
            self._train_step = train_step

        Model.train(self, *args, **kwargs)