def advanced_one_step(self, *args, lr_list=None, **kwargs):
    """Train every branch one at a time, then fine-tune jointly.

    Each branch is selected via ``set_branch_index`` and trained with its
    own learning rate from ``lr_list``; afterwards a final ``self.train``
    call runs with ``branch_index == self.branches_num``.

    Args:
        *args: Positional arguments forwarded to ``Model.train``.
        lr_list: Per-branch learning rates; defaults to ``0.000088`` for
            every branch when ``None``.
        **kwargs: Keyword arguments forwarded to ``Model.train`` and to
            the final ``self.train`` call.
    """
    if lr_list is None:
        lr_list = [0.000088] * self.branches_num
    for i in range(self.branches_num):
        self.set_branch_index(i)
        if i > 0:
            # From the second branch on, stop overwriting the previous
            # checkpoint and keep only the best weights.
            # NOTE(review): mutates global FLAGS state and never restores
            # it — assumed intentional; confirm no other training run
            # shares FLAGS in the same process.
            FLAGS.overwrite = False
            FLAGS.save_best = True
        self._optimizer_lr_modify(lr_list[i])
        Model.train(self, *args, **kwargs)
    # Final joint pass at hard-coded learning rates.
    # NOTE(review): this silently discards the caller-supplied lr_list and
    # assumes exactly 3 rates (i.e. branches_num + 1 == 3) are consumed by
    # self.train — verify against the configured number of branches.
    lr_list = [0.000088, 0.00088, 0.000088]
    self.train(branch_index=self.branches_num, lr_list=lr_list, **kwargs)
def train(self, *args, branch_index=0, lr_list=None, **kwargs):
    """Train the model with branch ``branch_index`` selected.

    When ``freeze`` is falsy (read from ``kwargs``; default ``True``), one
    optimizer step is built per variable group ``0..branch_index``, each
    with its own learning rate, and stored in ``self._train_step`` before
    delegating to ``Model.train``.

    Args:
        *args: Positional arguments forwarded to ``Model.train``.
        branch_index: Index of the branch to activate; the optimizer loop
            covers groups ``0..branch_index`` inclusive.
        lr_list: Learning rates, one per variable group. Defaults to
            ``0.000088`` repeated.
        **kwargs: Keyword arguments forwarded to ``Model.train``;
            ``freeze`` is also inspected here (but left in ``kwargs``).
    """
    if lr_list is None:
        # The loop below consumes branch_index + 1 learning rates; the
        # previous default of exactly ``branches_num`` entries raised
        # IndexError whenever branch_index == branches_num (as in the
        # final call made by ``advanced_one_step``). Sizing with max()
        # keeps the old length when it was already sufficient.
        lr_list = [0.000088] * max(self.branches_num, branch_index + 1)
    self.set_branch_index(branch_index)
    freeze = kwargs.get('freeze', True)
    if not freeze:
        # Build one minimize op per branch, each restricted to that
        # branch's variables and using its own learning rate.
        train_step = []
        for i in range(branch_index + 1):
            self._optimizer_lr_modify(lr_list[i])
            train_step.append(
                self._optimizer.minimize(
                    loss=self._loss, var_list=self._var_list[i]))
        self._train_step = train_step
    Model.train(self, *args, **kwargs)