# NOTE: this module assumes the project's core module is imported as `cc`
# (providing random_tn, make_trainable, copy_network, increase_rank,
# print_ranks, num_params, get_indims, the loss functions, and
# generate_regression_data), along with the project-level helpers
# `training`, `greedy_optim`, `discrete_optim_template`, `DetectPlateau`,
# `_make_ranks`, `_limit_random_tn`, and the modules `rw`/`rs` used in
# main(), all of which are defined elsewhere in the repository.
from copy import deepcopy

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau


def randomwalk_decomposition(goal_tn):
    loss_fun = cc.tensor_recovery_loss
    input_dims = [t.shape[i] for i, t in enumerate(goal_tn)]
    base_tn = cc.random_tn(input_dims, rank=1)
    # Initialize the first tensor network close to zero
    for i in range(len(base_tn)):
        base_tn[i] /= 10
    base_tn = cc.make_trainable(base_tn)
    trained_tn, best_loss, loss_record = randomwalk_optim(
        base_tn, goal_tn, loss_fun,
        other_args={'cprint': True,
                    'epochs': 10000,
                    'max_iter': 20,
                    'lr': 0.01,
                    'optim': 'RMSprop',
                    'search_epochs': 80,
                    'cvg_threshold': 1e-10,
                    'stop_on_plateau': {'mode': 'min',
                                        'patience': 100,
                                        'threshold': 1e-7},
                    'dyn_print': True,
                    'initial_epochs': 10})
    return loss_record
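
# A minimal usage sketch for the wrapper above. The target network here is
# synthetic; its dimensions and rank are illustrative assumptions, not values
# taken from the experiments in this repository.
def _demo_randomwalk_decomposition():
    torch.manual_seed(0)
    goal_tn = cc.random_tn([4, 4, 4, 4], rank=2)  # random target to recover
    loss_record = randomwalk_decomposition(goal_tn)
    print('final loss:', float(loss_record[-1]['loss']))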
def randomwalk_regression(train_data, val_data=None):
    loss_fun = cc.regression_loss
    input_dims = [t.shape[1] for t in train_data[0]]
    base_tn = cc.random_tn(input_dims, rank=1)
    # Initialize the first tensor network close to zero
    for i in range(len(base_tn)):
        base_tn[i] /= 10
    base_tn = cc.make_trainable(base_tn)
    trained_tn, best_loss, loss_record = randomwalk_optim(
        base_tn, train_data, loss_fun, val_data=val_data,
        other_args={'cprint': True,
                    'epochs': None,
                    'max_iter': 20,
                    'lr': 0.01,
                    'optim': 'RMSprop',
                    'search_epochs': 80,
                    'cvg_threshold': 1e-10,
                    'bsize': 100,
                    'is_reg': True,
                    'stop_on_plateau': {'mode': 'min',
                                        'patience': 50,
                                        'threshold': 1e-7},
                    'dyn_print': True,
                    'initial_epochs': 10})
    return loss_record
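
# Usage sketch for the regression wrapper, assuming the same data helper used
# in main() below (cc.generate_regression_data). The target ranks and sample
# sizes are illustrative assumptions.
def _demo_randomwalk_regression():
    torch.manual_seed(0)
    goal_tn = cc.random_tn([7, 7, 7], rank=2)
    train_data = cc.generate_regression_data(goal_tn, 1000, noise=1e-6)
    val_data = cc.generate_regression_data(goal_tn, 200, noise=1e-6)
    loss_record = randomwalk_regression(train_data, val_data=val_data)
    print('final loss:', float(loss_record[-1]['loss']))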
def greedy_completion(dataset, input_dims, initial_network=None, filename=None):
    loss_fun = cc.completion_loss
    base_tn = cc.random_tn(input_dims, rank=1)
    # Optionally scale the initialization toward zero (a divisor of 1 leaves
    # it unchanged; increase it to start closer to zero, as in the wrappers above)
    for i in range(len(base_tn)):
        base_tn[i] /= 1
    base_tn = cc.make_trainable(base_tn)
    if initial_network:
        base_tn = initial_network

    # Create the list of all edges allowed in a TR (tensor-ring) decomposition
    # ndims = len(base_tn)
    # tr_edges = ([(i, j) for i in range(ndims) for j in range(i + 1, ndims)
    #              if i + 1 == j] + [(0, ndims - 1)])

    lr_scheduler = lambda optimizer: ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=50,
        verbose=True, threshold=1e-7)

    trained_tn, best_loss, loss_record = greedy_optim(
        base_tn, dataset, loss_fun,
        find_best_edge=greedy_find_best_edge,
        other_args={'cprint': True,
                    'epochs': 20000,
                    'max_iter': 100,
                    'lr': 0.01,
                    'optim': 'RMSprop',
                    'search_epochs': 20,
                    'cvg_threshold': 1e-10,
                    # 'stop_on_plateau': {'mode': 'min', 'patience': 50,
                    #                     'threshold': 1e-7},
                    'dyn_print': True,
                    'initial_epochs': 10,
                    'bsize': -1,
                    'rank_increment': 2,
                    # 'allowed_edges': tr_edges,
                    'lr_scheduler': lr_scheduler,
                    'filename': filename})
    return loss_record
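
# For reference, the commented-out `tr_edges` construction above enumerates
# exactly the edges of a tensor-ring topology: consecutive cores plus the
# wrap-around edge. A small standalone equivalent (pure Python):
def _tensor_ring_edges(ndims):
    """Edges (i, j) allowed in a tensor-ring decomposition over `ndims` cores."""
    return [(i, i + 1) for i in range(ndims - 1)] + [(0, ndims - 1)]

# _tensor_ring_edges(4) -> [(0, 1), (1, 2), (2, 3), (0, 3)]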
def randomwalk_optim(tensor_list,
                     train_data,
                     loss_fun,
                     val_data=None,
                     other_args=dict(),
                     max_iter=None):
    """
    Train a tensor network using discrete optimization over TN ranks

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network (target tensor)
        loss_fun:    Scalar-valued loss function of the type
                         tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within the continuous optimization calls
                     inside the discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization, with
                     some options below (feel free to add more)

            epochs: Number of epochs for continuous optimization (default=10)
            max_iter: Maximum number of iterations for random walk search
                      (default=10)
            optim: Choice of Pytorch optimizer (default='SGD')
            lr: Learning rate for optimizer (default=1e-3)
            bsize: Minibatch size for training (default=100)
            reps: Number of times to repeat training data per epoch (default=1)
            cprint: Whether to print info from continuous optimization
                    (default=True)
            dprint: Whether to print info from discrete optimization
                    (default=True)
            search_epochs: Number of epochs used to identify the best rank-1
                           update. If None, the epochs argument is used.
                           (default=None)
            loss_threshold: If the loss gets below this threshold, discrete
                            optimization is stopped (default=1e-5)
            initial_epochs: Number of epochs after which the learning rate is
                            reduced and optimization is restarted if there is
                            no improvement in the loss (default=None)
            rank_increment: How much a rank is increased at each discrete
                            iteration (default=1)
            stop_on_plateau: A dictionary with keys mode ('min'/'max'),
                             patience, and threshold, used to stop continuous
                             optimization when a plateau is detected
                             (default=None)
            gradient_hook: this is a hack, please ignore...

    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm. The TN ranks of better_list will be larger
                     than those of tensor_list.
        best_loss:   The value of the validation/training loss for the model
                     output as better_list
        loss_record: If dhist=True in other_args, this records the history of
                     all losses for discrete and continuous optimization.
                     This is a list of dictionaries with keys
                     iter, num_params, network, loss, train_loss_hist,
                     val_loss_hist, where

                     iter: iteration of the discrete optimization
                     num_params: number of parameters of the best network for
                                 this iteration
                     network: list of tensors for the best network for this
                              iteration
                     loss: loss achieved by the best network in this iteration
                     train_loss_hist: history of losses for the continuous
                                      optimization for this iteration, starting
                                      from the previous best network up to the
                                      epoch where the new best network was found
                     val_loss_hist: same as train_loss_hist, for the
                                    validation loss
    """
    # Check input and initialize local record variables
    epochs = other_args['epochs'] if 'epochs' in other_args else 10
    max_iter = other_args['max_iter'] if 'max_iter' in other_args else 10
    dprint = other_args['dprint'] if 'dprint' in other_args else True
    cprint = other_args['cprint'] if 'cprint' in other_args else True
    search_epochs = (other_args['search_epochs']
                     if 'search_epochs' in other_args else epochs)
    loss_threshold = (other_args['loss_threshold']
                      if 'loss_threshold' in other_args else 1e-5)
    initial_epochs = (other_args['initial_epochs']
                      if 'initial_epochs' in other_args else None)
    rank_increment = (other_args['rank_increment']
                      if 'rank_increment' in other_args else 1)
    gradient_hook = (other_args['gradient_hook']
                     if 'gradient_hook' in other_args else None)
    is_reg = other_args['is_reg'] if 'is_reg' in other_args else False
    stop_on_plateau = (other_args['stop_on_plateau']
                       if 'stop_on_plateau' in other_args else None)
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        other_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))

    first_loss, best_loss, best_network = ([], []), None, None
    d_loss_hist = ([], [])
    other_args['hist'] = True  # we always track the history of losses

    # Function to maybe print, conditioned on `dprint`
    m_print = lambda s: print(s) if dprint else None

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{'iter': -1,
                    'network': tensor_list,
                    'num_params': cc.num_params(tensor_list),
                    'train_loss_hist': [0],
                    'val_loss_hist': [0]}]

    # Stop condition for the discrete optimization procedure. This simple
    # example works for greedy or random walk searches, but for other
    # optimization methods it could be trivial.
    stop_cond = lambda loss: loss < loss_threshold

    # Iteratively increment ranks of tensor_list and train via
    # continuous_optim, at each stage using a search procedure to
    # test out different ranks before choosing just one to increase
    stage = -1
    best_loss, best_network, best_network_optimizer_state = np.inf, None, None
    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if best_loss == np.inf:  # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list, initial_epochs, train_data, loss_fun,
                epochs=epochs, val_data=val_data, other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")
            initialNetwork = cc.copy_network(tensor_list)
            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({'iter': stage,
                                'network': best_network,
                                'num_params': cc.num_params(best_network),
                                'loss': best_loss,
                                'train_loss_hist': hist[0][:best_epoch],
                                'val_loss_hist': hist[1][:best_epoch]})
        else:
            m_print(f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n")
            best_search_loss = best_loss
            best_train_loss_hist = []
            best_val_loss_hist = []

            # Pick a random edge (i, j) and increase its rank
            [i, j] = torch.randperm(len(tensor_list))[:2].tolist()
            currentNetwork = cc.copy_network(initialNetwork)
            currentNetwork = cc.increase_rank(currentNetwork, i, j,
                                              rank_increment, 1e-6)
            currentNetwork = cc.make_trainable(currentNetwork)
            print('\ntesting rank increment for i =', i, 'j = ', j)

            ### Search optimization phase
            # We do only a few epochs to identify the most promising rank
            # update. This function zeroes out the gradient of all entries
            # except for the new slices.
            def grad_masking_function(tensor_list):
                nonlocal i, j
                for k in range(len(tensor_list)):
                    if k == i:
                        tensor_list[i].grad.permute(
                            [j] + list(range(0, j)) +
                            list(range(j + 1, len(currentNetwork))))[:-1, :, ...] *= 0
                    elif k == j:
                        tensor_list[j].grad.permute(
                            [i] + list(range(0, i)) +
                            list(range(i + 1, len(currentNetwork))))[:-1, :, ...] *= 0
                    else:
                        tensor_list[k].grad *= 0

            # We first optimize only the new slices for a few epochs
            print("optimize new slices for a few epochs")
            search_args = dict(other_args)
            current_network_optimizer_state = {}
            search_args["save_optimizer_state"] = True
            search_args["optimizer_state"] = current_network_optimizer_state
            if not is_reg:
                search_args["grad_masking_function"] = grad_masking_function
            if stop_on_plateau:
                detect_plateau._reset()
            [currentNetwork, first_loss, current_loss, best_epoch,
             hist] = training(currentNetwork, initial_epochs, train_data,
                              loss_fun, val_data=val_data,
                              epochs=search_epochs, other_args=search_args)
            first_loss = hist[0][0]
            train_loss_hist = deepcopy(hist[0][:best_epoch])
            val_loss_hist = deepcopy(hist[1][:best_epoch])

            # # We then optimize all parameters for a few epochs
            # print("\noptimize all parameters for a few epochs")
            # search_args["grad_masking_function"] = None
            # search_args["load_optimizer_state"] = dict(current_network_optimizer_state)
            # if stop_on_plateau:
            #     detect_plateau._reset()
            # [currentNetwork, first_loss, current_loss, best_epoch,
            #  hist] = training(currentNetwork, initial_epochs, train_data,
            #                   loss_fun, val_data=val_data,
            #                   epochs=search_epochs, other_args=search_args)
            # search_args["load_optimizer_state"] = None
            # train_loss_hist += deepcopy(hist[0][:best_epoch])

            best_search_loss = current_loss
            best_network = currentNetwork
            best_network_optimizer_state = deepcopy(current_network_optimizer_state)
            best_train_loss_hist = train_loss_hist
            best_val_loss_hist = val_loss_hist
            print('-> best rank update so far:', i, j)
            best_loss = best_search_loss

            # Train the selected network to convergence
            print('\ntraining best network until max_epochs/convergence...')
            other_args["load_optimizer_state"] = best_network_optimizer_state
            current_network_optimizer_state = {}
            other_args["save_optimizer_state"] = True
            other_args["optimizer_state"] = current_network_optimizer_state
            if gradient_hook:
                # A bit of a hack only used for tensor completion - ignore
                other_args["grad_masking_function"] = gradient_hook
            if stop_on_plateau:
                detect_plateau._reset()
            [best_network, first_loss, best_loss, best_epoch,
             hist] = training(best_network, initial_epochs, train_data,
                              loss_fun, val_data=val_data, epochs=epochs,
                              other_args=other_args)
            if gradient_hook:
                other_args["grad_masking_function"] = None
            other_args["load_optimizer_state"] = None
            best_train_loss_hist += deepcopy(hist[0][:best_epoch])
            best_val_loss_hist += deepcopy(hist[1][:best_epoch])
            initialNetwork = cc.copy_network(best_network)
            loss_record.append({'iter': stage,
                                'network': best_network,
                                'num_params': cc.num_params(best_network),
                                'loss': best_loss,
                                'train_loss_hist': best_train_loss_hist,
                                'val_loss_hist': best_val_loss_hist})

        print('\nbest TN:')
        cc.print_ranks(best_network)
        print('number of params:', cc.num_params(best_network))
        print([(r['iter'], r['num_params'], float(r['loss']),
                float(r['train_loss_hist'][0]), float(r['train_loss_hist'][-1]))
               for r in loss_record])
    return best_network, best_loss, loss_record
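
# `DetectPlateau` is defined elsewhere in the repository; the class below is a
# hypothetical sketch consistent with how it is used above: it is constructed
# from the `stop_on_plateau` dict, called with the current loss, returns True
# once `patience` consecutive calls fail to improve the best value by more
# than `threshold`, and exposes `_reset()` between training runs.
class _DetectPlateauSketch:
    def __init__(self, mode='min', patience=50, threshold=1e-7):
        assert mode in ('min', 'max')
        self.mode, self.patience, self.threshold = mode, patience, threshold
        self._reset()

    def _reset(self):
        # Best value seen so far and count of non-improving calls
        self.best = np.inf if self.mode == 'min' else -np.inf
        self.num_bad = 0

    def __call__(self, value):
        value = float(value)
        improved = ((self.best - value) > self.threshold if self.mode == 'min'
                    else (value - self.best) > self.threshold)
        if improved:
            self.best, self.num_bad = value, 0
        else:
            self.num_bad += 1
        return self.num_bad >= self.patience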
def main(args):
    num_train = args.ntrain
    num_val = args.nval
    input_dims = [7, 7, 7, 7, 7]

    goal_tn = torch.load(args.path)
    base_tn = cc.random_tn(input_dims, rank=1)
    base_tn = cc.make_trainable(base_tn)
    loss_fun = cc.regression_loss
    train_data = cc.generate_regression_data(goal_tn, num_train, noise=1e-6)
    val_data = cc.generate_regression_data(goal_tn, num_val, noise=1e-6)

    ### Random walk
    best_network, best_loss, loss_record = rw.randomwalk_optim(
        base_tn, train_data, loss_fun, val_data=val_data,
        other_args={'dhist': True,
                    'optim': 'RMSprop',
                    'max_iter': args.maxiter,
                    'epochs': None,  # early stopping
                    'lr': 0.001})

    # Rebuild the plotted curves from loss_record, skipping the iter=-1 entry.
    # r['loss'] is taken as the validation loss here (training() is assumed
    # to report the validation loss when val_data is provided).
    records = loss_record[1:]
    param_count = [r['num_params'] for r in records]
    train_loss = [float(r['train_loss_hist'][-1]) for r in records]
    val_loss = [float(r['loss']) for r in records]
    train_loss_hist = sum((list(r['train_loss_hist']) for r in records), [])

    plt.figure(figsize=(4, 3))
    plt.plot(train_loss_hist)
    plt.xlabel('Epoch')
    plt.ylabel('Training loss')
    plt.savefig('./figures/' + args.path + '_randomwalk' + '_trainloss' + '_.pdf',
                bbox_inches='tight')

    plt.figure(figsize=(4, 3))
    plt.plot(param_count, train_loss)
    plt.xlabel('Number of parameters')
    plt.ylabel('Training loss')
    plt.savefig('./figures/' + args.path + '_randomwalk' + '_trainloss_numparam' + '_.pdf',
                bbox_inches='tight')

    plt.figure(figsize=(4, 3))
    plt.plot(param_count, val_loss)
    plt.xlabel('Number of parameters')
    plt.ylabel('Validation loss')
    plt.savefig('./figures/' + args.path + '_randomwalk' + '_valloss_numparam' + '_.pdf',
                bbox_inches='tight')

    ### TODO: Add greedy

    ### Random search
    best_network, best_loss, loss_record = rs.randomsearch_optim(
        base_tn, train_data, loss_fun, val_data=val_data,
        other_args={'dhist': True,
                    'optim': 'RMSprop',
                    'max_iter': args.maxiter,
                    'epochs': None,  # early stopping
                    'lr': 0.001})

    records = loss_record[1:]
    param_count = [r['num_params'] for r in records]
    train_loss = [float(r['train_loss_hist'][-1]) for r in records]
    val_loss = [float(r['loss']) for r in records]

    plt.figure(figsize=(4, 3))
    plt.plot(param_count, train_loss)
    plt.xlabel('Number of parameters')
    plt.ylabel('Training loss')
    plt.savefig('./figures/' + args.path + '_randomsearch' + '_trainloss_numparam' + '_.pdf',
                bbox_inches='tight')

    plt.figure(figsize=(4, 3))
    plt.plot(param_count, val_loss)
    plt.xlabel('Number of parameters')
    plt.ylabel('Validation loss')
    plt.savefig('./figures/' + args.path + '_randomsearch' + '_valloss_numparam' + '_.pdf',
                bbox_inches='tight')
torch.manual_seed(0)

# Target tensor is a chain (tensor decomposition experiment)
d0 = 4
d1 = 4
d2 = 4
d3 = 4
d4 = 4
d5 = 4
r12 = 2
r23 = 3
r34 = 6
r45 = 5
r56 = 4
input_dims = [d0, d1, d2, d3, d4, d5]
# rank_list encodes the upper triangle of the core adjacency matrix: row i
# holds the ranks of edges (i, i+1), (i, i+2), ..., so a chain has
# non-trivial ranks only on the first entry of each row.
rank_list = [[r12, 1, 1, 1, 1],
             [r23, 1, 1, 1],
             [r34, 1, 1],
             [r45, 1],
             [r56]]

# Parameters to control the experimental behavior
exp_params = {'print': False, 'epochs': 200}

loss_fun = cc.tensor_recovery_loss
base_tn = cc.random_tn(input_dims, rank=1)
goal_tn = cc.random_tn(input_dims, rank=rank_list)
base_tn = cc.make_trainable(base_tn)
_, _, better_loss = discrete_optim_template(base_tn, goal_tn, loss_fun,
                                            other_args=exp_params)
print('better loss = ', better_loss)
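
# The helper below (illustrative, not part of the original code) flattens the
# nested rank_list format into explicit (i, j, rank) triples for inspection,
# making the upper-triangular encoding described above concrete.
def _rank_list_to_edges(rank_list):
    edges = []
    for i, row in enumerate(rank_list):
        for offset, rank in enumerate(row):
            edges.append((i, i + 1 + offset, rank))
    return edges

# _rank_list_to_edges(rank_list) starts with (0, 1, r12), (0, 2, 1), ...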
def randomsearch_optim(tensor_list,
                       train_data,
                       loss_fun,
                       val_data=None,
                       other_args=dict()):
    """
    Train a tensor network using discrete optimization over TN ranks

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network
        loss_fun:    Scalar-valued loss function of the type
                         tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within the continuous optimization calls
                     inside the discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization, with
                     some options below (feel free to add more)

            epochs: Number of epochs for continuous optimization (default=10)
            max_iter: Maximum number of iterations for random search
                      (default=10)
            optim: Choice of Pytorch optimizer (default='SGD')
            lr: Learning rate for optimizer (default=1e-3)
            bsize: Minibatch size for training (default=100)
            max_rank: Maximum rank to search for (default=7)
            max_params: Maximum number of parameters allowed in a sampled
                        network (default=10000)
            reps: Number of times to repeat training data per epoch (default=1)
            cprint: Whether to print info from continuous optimization
                    (default=True)
            dprint: Whether to print info from discrete optimization
                    (default=True)
            loss_threshold: If the loss gets below this threshold, discrete
                            optimization is stopped (default=1e-5)
            initial_epochs: Number of epochs after which the learning rate is
                            reduced and optimization is restarted if there is
                            no improvement in the loss (default=None)
            stop_on_plateau: A dictionary with keys mode ('min'/'max'),
                             patience, and threshold, used to stop continuous
                             optimization when a plateau is detected
                             (default=None)

    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm. The TN ranks of better_list will be larger
                     than those of tensor_list.
        best_loss:   The value of the validation/training loss for the model
                     output as better_list
        loss_record: If dhist=True in other_args, this records the history of
                     all losses for discrete and continuous optimization.
                     This is a list of dictionaries with keys
                     iter, num_params, network, loss, train_loss_hist, where

                     iter: iteration of the discrete optimization
                     num_params: number of parameters of the best network for
                                 this iteration
                     network: list of tensors for the best network for this
                              iteration
                     loss: loss achieved by the best network in this iteration
                     train_loss_hist: history of losses for the continuous
                                      optimization for this iteration, starting
                                      from the previous best network up to the
                                      epoch where the new best network was found

                     TODO: add val_loss_hist to the list of keys
    """
    # Check input and initialize local record variables
    epochs = other_args['epochs'] if 'epochs' in other_args else 10
    max_iter = other_args['max_iter'] if 'max_iter' in other_args else 10
    max_rank = other_args['max_rank'] if 'max_rank' in other_args else 7
    max_params = (other_args['max_params']
                  if 'max_params' in other_args else 10000)
    dprint = other_args['dprint'] if 'dprint' in other_args else True
    cprint = other_args['cprint'] if 'cprint' in other_args else True
    loss_threshold = (other_args['loss_threshold']
                      if 'loss_threshold' in other_args else 1e-5)
    initial_epochs = (other_args['initial_epochs']
                      if 'initial_epochs' in other_args else None)
    # Keep track of continuous optimization history
    other_args['hist'] = True
    stop_on_plateau = (other_args['stop_on_plateau']
                       if 'stop_on_plateau' in other_args else None)
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        other_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))

    first_loss, best_loss, best_network = ([], []), None, None
    d_loss_hist = ([], [])

    # Function to maybe print, conditioned on `dprint`
    m_print = lambda s: print(s) if dprint else None

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{'iter': -1,
                    'network': tensor_list,
                    'num_params': cc.num_params(tensor_list),
                    'train_loss_hist': [0]}]

    # Stop condition for the discrete optimization procedure. This simple
    # example works for greedy or random walk searches, but for other
    # optimization methods it could be trivial.
    stop_cond = lambda loss: loss < loss_threshold

    input_dims = cc.get_indims(tensor_list)
    n_cores = len(input_dims)
    n_edges = n_cores * (n_cores - 1) // 2

    # Iteratively sample new ranks for tensor_list and train via
    # continuous_optim, keeping the best network found so far
    stage = -1
    best_loss, best_network, best_network_optimizer_state = np.inf, None, None
    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if best_loss == np.inf:  # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list, initial_epochs, train_data, loss_fun,
                epochs=epochs, val_data=val_data, other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")
            initialNetwork = cc.copy_network(tensor_list)
            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({'iter': stage,
                                'network': best_network,
                                'num_params': cc.num_params(best_network),
                                'loss': best_loss,
                                'train_loss_hist': hist[0][:best_epoch]})
        else:
            m_print(f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n")
            # Create a new TN with random ranks
            ranks = torch.randint(low=1, high=max_rank + 1,
                                  size=(n_edges, 1)).view(-1).tolist()
            rank_list = _make_ranks(ranks)
            currentNetwork = _limit_random_tn(input_dims, rank=rank_list,
                                              max_params=max_params)
            currentNetwork = cc.make_trainable(currentNetwork)
            if stop_on_plateau:
                detect_plateau._reset()
            currentNetwork, first_loss, current_loss, best_epoch, hist = training(
                currentNetwork, initial_epochs, train_data, loss_fun,
                val_data=val_data, epochs=epochs, other_args=other_args)
            train_loss_hist = hist[0][:best_epoch]
            loss_record.append({'iter': stage,
                                'network': currentNetwork,
                                'num_params': cc.num_params(currentNetwork),
                                'loss': current_loss,
                                'train_loss_hist': train_loss_hist})
            if best_loss > current_loss:
                best_network = currentNetwork
                best_loss = current_loss

        print('\nbest TN:')
        cc.print_ranks(best_network)
        print('number of params:', cc.num_params(best_network))
        print([(r['iter'], r['num_params'], float(r['loss']),
                float(r['train_loss_hist'][0]), float(r['train_loss_hist'][-1]))
               for r in loss_record])
    return best_network, best_loss, loss_record
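
# `_make_ranks` and `_limit_random_tn` are project helpers not shown in this
# file. The sketches below are assumptions consistent with their call sites
# above: `_make_ranks` reshapes a flat list of n_edges sampled ranks into the
# nested upper-triangular format expected by cc.random_tn, and
# `_limit_random_tn` shrinks the sampled ranks until the network fits the
# parameter budget.
def _make_ranks_sketch(ranks):
    # Row i of the output holds the ranks of edges (i, i+1), (i, i+2), ...
    # n_edges = n_cores * (n_cores - 1) / 2, so invert for n_cores.
    n_cores = int((1 + np.sqrt(1 + 8 * len(ranks))) / 2)
    rank_list, k = [], 0
    for i in range(n_cores - 1):
        row_len = n_cores - 1 - i
        rank_list.append(ranks[k:k + row_len])
        k += row_len
    return rank_list

def _limit_random_tn_sketch(input_dims, rank, max_params):
    # Uniformly decrement the ranks (one hypothetical policy) until the
    # sampled network respects the parameter budget.
    tn = cc.random_tn(input_dims, rank=rank)
    while cc.num_params(tn) > max_params:
        rank = [[max(1, r - 1) for r in row] for row in rank]
        tn = cc.random_tn(input_dims, rank=rank)
    return tn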
def greedy_find_best_edge(initialNetwork,
                          train_data,
                          val_data,
                          loss_fun,
                          rank_increment,
                          training_args,
                          prev_best_loss,
                          allowed_edges=None,
                          search_epochs=10,
                          padding_noise=1e-6,
                          is_reg=False):
    """
    initialNetwork: list of tensors representing the initial network
    allowed_edges: a list of allowed edges of rank more than one in the
        tensor network. If None, all edges are considered. (default=None)
    rank_increment: how much the rank of the new edge should be increased
    padding_noise: standard deviation of the normal distribution used to
        initialize the new slices of the core tensors receiving the new edge
    training_args: dictionary of arguments passed to the training function
    """
    if not allowed_edges:
        ndims = len(initialNetwork)
        allowed_edges = [(i, j) for i in range(ndims)
                         for j in range(i + 1, ndims)]

    stop_on_plateau = (training_args['stop_on_plateau']
                       if 'stop_on_plateau' in training_args else None)
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        training_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))
    initial_epochs = (training_args['initial_epochs']
                      if 'initial_epochs' in training_args else None)

    best_loss, best_network, best_network_optimizer_state = np.inf, None, None
    best_search_loss = best_loss
    best_train_loss_hist = []
    best_val_loss_hist = []

    for (i, j) in allowed_edges:
        currentNetwork = cc.copy_network(initialNetwork)
        # Increase rank along the chosen edge
        currentNetwork = cc.increase_rank(currentNetwork, i, j,
                                          rank_increment, padding_noise)
        currentNetwork = cc.make_trainable(currentNetwork)
        print('\ntesting rank increment for i =', i, 'j = ', j)

        ### Search optimization phase
        # We do only a few epochs to identify the most promising rank update.
        # We first optimize only the new slices for a few epochs.
        print("optimize new slices for a few epochs")
        search_args = dict(training_args)
        current_network_optimizer_state = {}
        search_args["save_optimizer_state"] = True
        search_args["optimizer_state"] = current_network_optimizer_state
        if not is_reg:
            search_args["only_new_slice"] = (i, j)
        if stop_on_plateau:
            detect_plateau._reset()
        [currentNetwork, first_loss, current_loss, best_epoch,
         hist] = training(currentNetwork, initial_epochs, train_data,
                          loss_fun, val_data=val_data, epochs=search_epochs,
                          other_args=search_args)
        first_loss = hist[0][0]
        train_loss_hist = deepcopy(hist[0][:best_epoch])
        val_loss_hist = deepcopy(hist[1][:best_epoch])

        # We then optimize all parameters for a few epochs
        print("\noptimize all parameters for a few epochs")
        search_args["load_optimizer_state"] = dict(current_network_optimizer_state)
        search_args["only_new_slice"] = False
        if stop_on_plateau:
            detect_plateau._reset()
        [currentNetwork, first_loss, current_loss, best_epoch,
         hist] = training(currentNetwork, initial_epochs, train_data,
                          loss_fun, val_data=val_data, epochs=search_epochs,
                          other_args=search_args)
        search_args["load_optimizer_state"] = None
        train_loss_hist += deepcopy(hist[0][:best_epoch])
        val_loss_hist += deepcopy(hist[1][:best_epoch])

        print(f"\nCurrent loss is {current_loss:.7f} "
              f"Best loss from previous discrete optim is {prev_best_loss}")
        if best_search_loss > current_loss:
            best_search_loss = current_loss
            best_network = currentNetwork
            best_network_optimizer_state = deepcopy(current_network_optimizer_state)
            best_train_loss_hist = train_loss_hist
            best_val_loss_hist = val_loss_hist
            print('-> best rank update so far:', i, j)

    return (best_network, best_search_loss, best_network_optimizer_state,
            best_train_loss_hist, best_val_loss_hist)
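
# A condensed sketch of how `greedy_find_best_edge` plugs into one step of
# the discrete loop. The real driver is `greedy_optim` (defined elsewhere in
# the repository); the argument values below are illustrative assumptions.
def _demo_greedy_step(network, train_data, val_data, loss_fun, prev_best_loss):
    (best_network, best_loss, optimizer_state,
     train_hist, val_hist) = greedy_find_best_edge(
         network, train_data, val_data, loss_fun,
         rank_increment=1,
         training_args={'initial_epochs': 10, 'lr': 0.01, 'optim': 'RMSprop'},
         prev_best_loss=prev_best_loss,
         search_epochs=10)
    # Accept the move if the best tested edge improved on the previous loss
    if best_loss < prev_best_loss:
        return best_network, best_loss
    return network, prev_best_loss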