def training(tensor_list, initial_epochs, train_data, loss_fun, val_data, epochs, other_args):
    """
    Run continuous optimization, restarting with a halved learning rate
    when the first epochs make the training loss worse.

    The network is first trained for ``initial_epochs`` epochs. While the
    last recorded training loss of that warm-up is strictly larger than the
    first one (``hist[0][0] < hist[0][-1]``), the learning rate is divided
    by 2 and the warm-up is restarted from a fresh copy of ``tensor_list``.
    Once the warm-up is accepted, training continues for the remaining
    ``epochs - initial_epochs`` epochs.

    Args:
        tensor_list:    Network to train (copied before the warm-up phase).
        initial_epochs: Length of the warm-up phase. If None, this function
                        is a plain pass-through to ``cc.continuous_optim``.
        train_data, loss_fun, val_data, epochs:
                        Forwarded unchanged to ``cc.continuous_optim``.
        other_args:     Options dict for ``cc.continuous_optim``. It is
                        deep-copied when ``initial_epochs`` is set; the
                        original dict is only read in that case.

    Returns:
        If ``initial_epochs`` is None: whatever ``cc.continuous_optim``
        returns. Otherwise ``[network, first_loss, current_loss, hist]``
        when ``"hist"`` is a key of ``other_args``, else
        ``[network, first_loss, current_loss]``.
    """
    if initial_epochs is None:
        # No warm-up requested. The original code referenced the locals
        # `currentNetwork` and `args` here, before they were assigned,
        # which raised UnboundLocalError.
        return cc.continuous_optim(tensor_list,
                                   train_data,
                                   loss_fun,
                                   val_data=val_data,
                                   epochs=epochs,
                                   other_args=other_args)

    args = deepcopy(other_args)
    args["hist"] = True  # the restart test below needs the loss history

    # Reuse the caller's optimizer-state container when one was supplied
    # (the caller's object, not the deep copy held in `args`).
    current_network_optimizer_state = other_args.get("optimizer_state", {})
    args["save_optimizer_state"] = True
    args["optimizer_state"] = current_network_optimizer_state

    currentNetwork = cc.copy_network(tensor_list)
    remaining_epochs = epochs - initial_epochs if epochs else None

    hist = None
    while hist is None or (hist[0][0] < hist[0][-1]):
        if hist:
            # The previous warm-up made the loss worse: halve the learning
            # rate, either in the optimizer state to be loaded or in the
            # plain "lr" option.
            if "load_optimizer_state" in args and args["load_optimizer_state"]:
                param_group = args["load_optimizer_state"]["optimizer_state"]['param_groups'][0]
                param_group['lr'] /= 2
                lr = param_group['lr']
            else:
                args["lr"] /= 2
                lr = args["lr"]
            print(f"\n[!] No progress in first {initial_epochs} epochs, starting again with smaller learning rate ({lr})")
            currentNetwork = cc.copy_network(tensor_list)

        [currentNetwork, first_loss, current_loss, hist] = cc.continuous_optim(
            currentNetwork,
            train_data,
            loss_fun,
            val_data=val_data,
            epochs=initial_epochs,
            other_args=args)

    # Continue from the optimizer state saved during the warm-up.
    args["load_optimizer_state"] = current_network_optimizer_state
    [currentNetwork, first_loss, current_loss, hist] = cc.continuous_optim(
        currentNetwork,
        train_data,
        loss_fun,
        val_data=val_data,
        epochs=remaining_epochs,
        other_args=args)

    if "hist" in other_args:
        return [currentNetwork, first_loss, current_loss, hist]
    return [currentNetwork, first_loss, current_loss]
def plot_image(tensor_list):
    """
    Show the dense tensor encoded by ``tensor_list`` as an image.

    The network is copied, contracted to a dense tensor with
    ``cc.wire_network(..., give_dense=True)``, reshaped to ``im_size`` using
    memory order ``order`` and divided by 255 before being handed to
    ``plt.imshow``.

    NOTE(review): ``im_size`` and ``order`` are not defined in this function;
    presumably they are module-level settings -- confirm they exist wherever
    this is called.
    """
    network_copy = cc.copy_network(tensor_list)
    dense_tensor = cc.wire_network(network_copy, give_dense=True)
    pixels = dense_tensor.detach().numpy().reshape(im_size, order=order)
    plt.imshow(pixels / 255)
def randomwalk_optim(tensor_list,
                     train_data,
                     loss_fun,
                     val_data=None,
                     other_args=None,
                     max_iter=None):
    """
    Train a tensor network using a random-walk discrete optimization over
    TN ranks.

    The first iteration trains the given network as is. Every later
    iteration picks two distinct cores at random, increases the rank of the
    edge between them by ``rank_increment``, trains for ``search_epochs``
    epochs (only the new slices unless ``is_reg`` is set), then trains the
    resulting network for ``epochs`` epochs. The rank update is always
    accepted.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network (target tensor)
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within continuous optimization calls
        other_args:  Dictionary of other arguments for the optimization.
                     A dict passed by the caller is modified in place
                     (keys such as 'hist', 'stop_condition',
                     'load_optimizer_state', 'save_optimizer_state',
                     'optimizer_state' and 'grad_masking_function' are
                     written). Options read here:
                        epochs:          Number of epochs for continuous
                                         optimization (default=10)
                        max_iter:        Maximum number of iterations for
                                         random walk search (default=10)
                        dprint:          Whether to print info from discrete
                                         optimization (default=True)
                        search_epochs:   Number of epochs used for the search
                                         phase (default=epochs)
                        loss_threshold:  if loss gets below this threshold,
                                         discrete optimization is stopped
                                         (default=1e-5)
                        initial_epochs:  Number of epochs after which the
                                         learning rate is reduced and
                                         optimization is restarted if the
                                         loss got worse (default=None)
                        rank_increment:  how much a rank is increased at
                                         each discrete iteration (default=1)
                        stop_on_plateau: a dictionary of keyword arguments
                                         for DetectPlateau (mode, patience,
                                         threshold), used to stop continuous
                                         optimization when a plateau is
                                         detected (default=None)
                        gradient_hook:   this is a hack, please ignore...
                                         (default=None)
                        is_reg:          hack flag; when true the gradient
                                         mask of the search phase is not
                                         installed (default=False)
                     Options such as optim, lr, bsize, reps and cprint are
                     not read here; they are forwarded with other_args to
                     the continuous optimizer.
        max_iter:    Fallback for the maximum number of iterations, used
                     only when other_args has no 'max_iter' key.

    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm.
        best_loss:   The loss returned by the last continuous optimization
        loss_record: History of the optimization: a list of dictionaries
                     with keys iter, num_params, network, loss,
                     train_loss_hist, val_loss_hist where
                        iter:            iteration of the discrete
                                         optimization (-1 for the input)
                        num_params:      number of parameters of the network
                        network:         list of tensors of the network
                        loss:            loss achieved in this iteration
                        train_loss_hist: training losses of the continuous
                                         optimization up to best_epoch
                        val_loss_hist:   validation losses, same range
    """
    if other_args is None:
        # A fresh dict per call: this function writes into other_args, so a
        # shared default dict would leak state from one call to the next.
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    max_iter = other_args.get('max_iter', 10 if max_iter is None else max_iter)
    dprint = other_args.get('dprint', True)
    search_epochs = other_args.get('search_epochs', epochs)
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs', None)
    rank_increment = other_args.get('rank_increment', 1)
    gradient_hook = other_args.get('gradient_hook', None)
    is_reg = other_args.get('is_reg', False)
    stop_on_plateau = other_args.get('stop_on_plateau', None)

    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        other_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))

    other_args['hist'] = True  # we always track the history of losses

    def m_print(s):
        """Print ``s`` only when discrete-optimization printing is on."""
        if dprint:
            print(s)

    def stop_cond(loss):
        """Stop the discrete search once the loss is below the threshold."""
        return loss < loss_threshold

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{
        'iter': -1,
        'network': tensor_list,
        'num_params': cc.num_params(tensor_list),
        'train_loss_hist': [0],
        'val_loss_hist': [0]
    }]

    # Iteratively increment ranks of tensor_list and train via
    # continuous optimization
    stage = -1
    best_loss, best_network = np.inf, None
    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if stage == 0:
            # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list,
                initial_epochs,
                train_data,
                loss_fun,
                epochs=epochs,
                val_data=val_data,
                other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")
            initialNetwork = cc.copy_network(tensor_list)
            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': hist[0][:best_epoch],
                'val_loss_hist': hist[1][:best_epoch]
            })
        else:
            m_print(
                f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n"
            )

            # Pick two distinct cores at random and grow the edge between
            # them.
            [i, j] = torch.randperm(len(tensor_list))[:2].tolist()
            currentNetwork = cc.copy_network(initialNetwork)
            currentNetwork = cc.increase_rank(currentNetwork, i, j,
                                              rank_increment, 1e-6)
            currentNetwork = cc.make_trainable(currentNetwork)
            print('\ntesting rank increment for i =', i, 'j = ', j)

            ### Search optimization phase
            n_dims = len(currentNetwork)

            def grad_masking_function(tensor_list):
                """Zero every gradient entry except the last slice of the
                (i, j) edge in cores i and j."""
                # NOTE(review): only the LAST slice is left trainable; if
                # cc.increase_rank appends `rank_increment` slices this
                # freezes some new slices when rank_increment > 1 -- confirm.
                for k in range(len(tensor_list)):
                    if k == i:
                        tensor_list[i].grad.permute(
                            [j] + list(range(0, j)) +
                            list(range(j + 1, n_dims)))[:-1, :, ...] *= 0
                    elif k == j:
                        tensor_list[j].grad.permute(
                            [i] + list(range(0, i)) +
                            list(range(i + 1, n_dims)))[:-1, :, ...] *= 0
                    else:
                        tensor_list[k].grad *= 0

            # we first optimize only the new slices for a few epochs
            print("optimize new slices for a few epochs")
            search_args = dict(other_args)
            current_network_optimizer_state = {}
            search_args["save_optimizer_state"] = True
            search_args["optimizer_state"] = current_network_optimizer_state
            if not is_reg:
                search_args["grad_masking_function"] = grad_masking_function
            if stop_on_plateau:
                detect_plateau._reset()
            [currentNetwork, _, current_loss, best_epoch, hist] = training(
                currentNetwork,
                initial_epochs,
                train_data,
                loss_fun,
                val_data=val_data,
                epochs=search_epochs,
                other_args=search_args)
            best_train_lost_hist = deepcopy(hist[0][:best_epoch])
            best_val_loss_hist = deepcopy(hist[1][:best_epoch])

            # NOTE: an earlier version ran a second search phase here that
            # optimized all parameters for a few epochs; it is disabled.

            # A random walk always accepts the sampled rank update.
            best_network = currentNetwork
            best_network_optimizer_state = deepcopy(
                current_network_optimizer_state)
            print('-> best rank update so far:', i, j)

            # train network to convergence
            print('\ntraining best network until max_epochs/convergence...')
            other_args["load_optimizer_state"] = best_network_optimizer_state
            current_network_optimizer_state = {}
            other_args["save_optimizer_state"] = True
            other_args["optimizer_state"] = current_network_optimizer_state
            if gradient_hook:
                # A bit of a hack only used for tensor completion - ignore
                other_args["grad_masking_function"] = gradient_hook
            if stop_on_plateau:
                detect_plateau._reset()
            [best_network, _, best_loss, best_epoch, hist] = training(
                best_network,
                initial_epochs,
                train_data,
                loss_fun,
                val_data=val_data,
                epochs=epochs,
                other_args=other_args)
            if gradient_hook:
                other_args["grad_masking_function"] = None
            other_args["load_optimizer_state"] = None

            best_train_lost_hist += deepcopy(hist[0][:best_epoch])
            best_val_loss_hist += deepcopy(hist[1][:best_epoch])
            initialNetwork = cc.copy_network(best_network)
            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': best_train_lost_hist,
                'val_loss_hist': best_val_loss_hist
            })

            print('\nbest TN:')
            cc.print_ranks(best_network)
            print('number of params:', cc.num_params(best_network))
            print([(r['iter'], r['num_params'], float(r['loss']),
                    float(r['train_loss_hist'][0]),
                    float(r['train_loss_hist'][-1])) for r in loss_record])

    return best_network, best_loss, loss_record
def discrete_optim_template(tensor_list,
                            train_data,
                            loss_fun,
                            val_data=None,
                            other_args=None):
    """
    Train a tensor network using discrete optimization over TN ranks

    Each stage tries a rank-1 increment on every pair of cores (i, j) with
    i < j, trains each candidate with cc.continuous_optim, and continues
    from the candidate with the lowest loss.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network (target tensor)
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within continuous optimization calls
        other_args:  Dictionary of other arguments for the optimization.
                     Options read here:
                        epochs: Number of epochs for continuous optimization
                                (default=10)
                        dprint: Whether to print info from discrete
                                optimization (default=True)
                     Options such as optim, lr, bsize, reps and cprint are
                     forwarded with other_args to cc.continuous_optim.

    Returns:
        better_list: The best network found by the search (the copy of the
                     input network if no candidate was ever trained)
        first_loss:  Initial loss of the model as returned by a one-epoch
                     call to cc.continuous_optim before the search
        best_loss:   The loss of the network returned as better_list
                     (1e10 if no candidate was ever trained)
    """
    if other_args is None:
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    dprint = other_args.get('dprint', True)
    first_loss = None
    loss_record = []

    def m_print(s):
        """Print ``s`` only when discrete-optimization printing is on."""
        if dprint:
            print(s)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)

    # Stop condition for the discrete optimization procedure. Note that it
    # is rebuilt for every candidate network inside the search loop below.
    stop_cond = generate_stop_cond(cc.get_indims(tensor_list))

    # Iteratively increment ranks of tensor_list and train via
    # continuous_optim, at each stage using a search procedure to
    # test out different ranks before choosing just one to increase
    stage = 0
    better_network, better_loss = tensor_list, 1e10
    initialNetwork = cc.copy_network(tensor_list)
    while not stop_cond(better_network):
        if first_loss is None:
            # Record initial loss of TN model
            first_args = other_args.copy()
            first_args["print"] = first_args["hist"] = False
            _, first_loss, _ = cc.continuous_optim(tensor_list,
                                                   train_data,
                                                   loss_fun,
                                                   epochs=1,
                                                   val_data=val_data,
                                                   other_args=first_args)
            m_print("Initial model has TN ranks")
            m_print(f"Initial loss is {first_loss:.3f}")

            m_print("Performing initial optimization...")
            first_args["print"] = True
            initialNetwork, _, _ = cc.continuous_optim(tensor_list,
                                                       train_data,
                                                       loss_fun,
                                                       epochs=2000,
                                                       val_data=val_data,
                                                       other_args=first_args)
            continue

        m_print(f"STAGE {stage}")
        stage += 1

        # Greedy search: try a rank-1 increment on every edge (i, j) and
        # keep the candidate with the lowest loss in better_network.
        prev_loss = float('inf')
        num_cores = len(tensor_list)
        for i in range(num_cores):
            for j in range(i + 1, num_cores):
                m_print(f"Testing i={i}, j={j}")
                # reset currentNetwork to continue with greedy at another
                # point
                currentNetwork = cc.copy_network(initialNetwork)
                # increase rank along a chosen dimension
                currentNetwork = cc.increase_rank(currentNetwork, i, j, 1,
                                                  1e-6)
                # TODO (from the original author): need to add param=-1
                stop_cond = generate_stop_cond(
                    cc.get_indims(currentNetwork))
                if stop_cond(currentNetwork) == True:
                    # i.e. break if the number of parameters in the trained
                    # tensor exceeds the number of parameters in the target
                    # tensor (only leaves the inner loop)
                    break

                # solve continuous optimization part,
                # train_data = target_tensor. The start loss of this
                # candidate is kept separate so that first_loss keeps the
                # initial loss of the model.
                currentNetwork, start_loss, current_loss = cc.continuous_optim(
                    currentNetwork,
                    train_data,
                    loss_fun,
                    epochs=epochs,
                    val_data=val_data,
                    other_args=other_args)
                if prev_loss > current_loss:
                    prev_loss = current_loss
                    better_network = currentNetwork
                    better_loss = current_loss
                    m_print("BEST CONFIG")
                else:
                    m_print("Not best config")
                m_print(f"{start_loss:.3f} -> {current_loss:.3f}\n")

        # update current point to the new point (i.e. better_network) that
        # gave lower loss
        initialNetwork = cc.copy_network(better_network)
        loss_record.append(better_loss)

    # The original returned `best_network`, which was never assigned and
    # therefore always None.
    return better_network, first_loss, better_loss  #, loss_record
def greedy_optim(tensor_list,
                 train_data,
                 loss_fun,
                 val_data=None,
                 other_args=None,
                 max_iter=None):
    """
    Train a tensor network using greedy discrete optimization over TN ranks

    The first stage trains the given network; every stage then tries a
    rank-1 increment on each pair of cores (i, j) with i < j and keeps the
    candidate with the lowest loss, provided it beats the current best.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network (target tensor)
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within continuous optimization calls
        other_args:  Dictionary of other arguments for the optimization.
                     When ``search_epochs`` is set, a dict passed by the
                     caller is modified in place (keys
                     'load_optimizer_state', 'save_optimizer_state' and
                     'optimizer_state' are written). Options read here:
                        epochs:         Number of epochs for continuous
                                        optimization (default=10)
                        max_iter:       Maximum number of iterations for
                                        greedy search (default=10)
                        dprint:         Whether to print info from discrete
                                        optimization (default=True)
                        search_epochs:  Number of epochs used to identify
                                        the best rank 1 update. If None,
                                        each candidate is trained with
                                        cc.continuous_optim for `epochs`
                                        epochs (default=None)
                        loss_threshold: if loss gets below this threshold,
                                        discrete optimization is stopped
                                        (default=1e-5)
                        initial_epochs: Number of epochs after which the
                                        learning rate is reduced and
                                        optimization is restarted if the
                                        loss got worse (default=None)
                     Options such as optim, lr, bsize, reps and cprint are
                     forwarded with other_args to the continuous optimizer.
        max_iter:    Fallback for the maximum number of iterations, used
                     only when other_args has no 'max_iter' key.

    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm.
        first_loss:  The first loss reported by the initial continuous
                     optimization (None if no stage was run)
        best_loss:   The loss of the model output as better_list
    """
    if other_args is None:
        # A fresh dict per call: this function writes into other_args, so a
        # shared default dict would leak state from one call to the next.
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    max_iter = other_args.get('max_iter', 10 if max_iter is None else max_iter)
    dprint = other_args.get('dprint', True)
    search_epochs = other_args.get('search_epochs', None)
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs', None)

    def stop_cond(loss):
        """Stop the discrete search once the loss is below the threshold."""
        return loss < loss_threshold

    def m_print(s):
        """Print ``s`` only when discrete-optimization printing is on."""
        if dprint:
            print(s)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)

    # Iteratively increment ranks of tensor_list and train via
    # continuous_optim, at each stage using a search procedure to
    # test out different ranks before choosing just one to increase
    stage = 0
    loss_record, best_loss, best_network = [], np.inf, None
    first_loss = None
    # Bound up front so the final training below cannot hit an unbound
    # name when no candidate improves during the first stage.
    best_network_optimizer_state = None
    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if stage == 1:
            # first continuous optimization
            tensor_list, first_loss, best_loss = training(
                tensor_list,
                initial_epochs,
                train_data,
                loss_fun,
                epochs=epochs,
                val_data=val_data,
                other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")
            initialNetwork = cc.copy_network(tensor_list)
            best_network = cc.copy_network(tensor_list)
            loss_record += [first_loss, best_loss]

        m_print(f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n")
        best_search_loss = best_loss
        n_cores = len(initialNetwork)
        for i in range(n_cores):
            for j in range(i + 1, n_cores):
                currentNetwork = cc.copy_network(initialNetwork)
                # increase rank along a chosen dimension
                currentNetwork = cc.increase_rank(currentNetwork, i, j, 1,
                                                  1e-6)
                print('\ntesting rank increment for i =', i, 'j = ', j)

                # Bound for both branches: the original only created this
                # dict in the search_epochs branch, so the full-optimization
                # branch crashed on the first improving candidate.
                current_network_optimizer_state = {}
                if search_epochs:
                    # we do only a few epochs to identify the most
                    # promising rank update

                    def grad_masking_function(tensor_list, i=i, j=j):
                        """Zero every gradient entry except the new slice
                        of the (i, j) edge in cores i and j."""
                        n_dims = len(tensor_list)
                        for k in range(n_dims):
                            if k == i:
                                tensor_list[i].grad.permute(
                                    [j] + list(range(0, j)) +
                                    list(range(j + 1, n_dims)))[:-1, :, ...] *= 0
                            elif k == j:
                                tensor_list[j].grad.permute(
                                    [i] + list(range(0, i)) +
                                    list(range(i + 1, n_dims)))[:-1, :, ...] *= 0
                            else:
                                tensor_list[k].grad *= 0

                    # we first optimize only the new slices for a few epochs
                    print("optimize new slices for a few epochs")
                    search_args = dict(other_args)
                    search_args["hist"] = True
                    search_args["save_optimizer_state"] = True
                    search_args["optimizer_state"] = current_network_optimizer_state
                    search_args["grad_masking_function"] = grad_masking_function
                    [currentNetwork, _, current_loss, _] = training(
                        currentNetwork,
                        initial_epochs,
                        train_data,
                        loss_fun,
                        val_data=val_data,
                        epochs=search_epochs,
                        other_args=search_args)

                    # We then optimize all parameters for a few epochs
                    print("\noptimize all parameters for a few epochs")
                    search_args["grad_masking_function"] = None
                    search_args["load_optimizer_state"] = dict(
                        current_network_optimizer_state)
                    [currentNetwork, _, current_loss, _] = training(
                        currentNetwork,
                        initial_epochs,
                        train_data,
                        loss_fun,
                        val_data=val_data,
                        epochs=search_epochs,
                        other_args=search_args)
                    search_args["load_optimizer_state"] = None
                else:
                    # we fully optimize the network in the search phase
                    [currentNetwork, _, current_loss] = cc.continuous_optim(
                        currentNetwork,
                        train_data,
                        loss_fun,
                        val_data=val_data,
                        epochs=epochs,
                        other_args=other_args)

                m_print(f"\nCurrent loss is {current_loss:.7f} Best loss from previous discrete optim is {best_loss}")
                if best_search_loss > current_loss:
                    best_search_loss = current_loss
                    best_network = currentNetwork
                    best_network_optimizer_state = deepcopy(
                        current_network_optimizer_state)
                    print('-> best rank update so far:', i, j)

        best_loss = best_search_loss

        # train network to convergence for the best rank increment (if
        # search_epochs is set, otherwise the best network is already
        # trained to convergence / max_epochs)
        if search_epochs:
            print('\ntraining best network until max_epochs/convergence...')
            other_args["load_optimizer_state"] = best_network_optimizer_state
            current_network_optimizer_state = {}
            other_args["save_optimizer_state"] = True
            other_args["optimizer_state"] = current_network_optimizer_state
            # The start loss of this run is discarded so that first_loss
            # keeps the loss of the initial model.
            [best_network, _, best_loss] = training(
                best_network,
                initial_epochs,
                train_data,
                loss_fun,
                val_data=val_data,
                epochs=epochs,
                other_args=other_args)
            other_args["load_optimizer_state"] = None

        initialNetwork = cc.copy_network(best_network)
        loss_record.append(
            (stage, cc.num_params(best_network), float(best_loss)))

        print('\nbest TN:')
        cc.print_ranks(best_network)
        print('number of params:', cc.num_params(best_network))
        print(loss_record)

    return best_network, first_loss, best_loss
def randomsearch_optim(tensor_list,
                       train_data,
                       loss_fun,
                       val_data=None,
                       other_args=None):
    """
    Train a tensor network using random search over TN ranks

    The first iteration trains the given network. Every later iteration
    draws random ranks in [1, max_rank] for all edges, builds a new random
    network from them (limited to max_params parameters), trains it and
    keeps it if its loss is lower than the best one so far.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within continuous optimization calls
        other_args:  Dictionary of other arguments for the optimization.
                     A dict passed by the caller is modified in place
                     (keys 'hist' and, with stop_on_plateau,
                     'stop_condition' are written). Options read here:
                        epochs:          Number of epochs for continuous
                                         optimization (default=10)
                        max_iter:        Maximum number of iterations for
                                         random search (default=10)
                        max_rank:        Maximum rank to search for
                                         (default=7)
                        max_params:      Parameter limit passed to
                                         _limit_random_tn (default=10000)
                        dprint:          Whether to print info from discrete
                                         optimization (default=True)
                        loss_threshold:  if loss gets below this threshold,
                                         discrete optimization is stopped
                                         (default=1e-5)
                        initial_epochs:  Number of epochs after which the
                                         learning rate is reduced and
                                         optimization is restarted if the
                                         loss got worse (default=None)
                        stop_on_plateau: a dictionary of keyword arguments
                                         for DetectPlateau (mode, patience,
                                         threshold), used to stop continuous
                                         optimization when a plateau is
                                         detected (default=None)
                     Options such as optim, lr, bsize, reps and cprint are
                     forwarded with other_args to the continuous optimizer.

    Returns:
        better_list: The network with the lowest loss found by the search
        best_loss:   The loss of the model output as better_list
        loss_record: History of the search: a list of dictionaries with
                     keys iter, num_params, network, loss, train_loss_hist
                     where
                        iter:            iteration of the discrete
                                         optimization (-1 for the input)
                        num_params:      number of parameters of the network
                                         trained in this iteration
                        network:         list of tensors of that network
                        loss:            loss achieved by that network
                        train_loss_hist: training losses of the continuous
                                         optimization up to best_epoch
                     TODO: add val_loss_hist to the list of keys
    """
    if other_args is None:
        # A fresh dict per call: this function writes into other_args, so a
        # shared default dict would leak state from one call to the next.
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    max_iter = other_args.get('max_iter', 10)
    max_rank = other_args.get('max_rank', 7)
    max_params = other_args.get('max_params', 10000)
    dprint = other_args.get('dprint', True)
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs', None)

    # Keep track of continuous optimization history
    other_args['hist'] = True

    stop_on_plateau = other_args.get('stop_on_plateau', None)
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        other_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))

    def m_print(s):
        """Print ``s`` only when discrete-optimization printing is on."""
        if dprint:
            print(s)

    def stop_cond(loss):
        """Stop the discrete search once the loss is below the threshold."""
        return loss < loss_threshold

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{
        'iter': -1,
        'network': tensor_list,
        'num_params': cc.num_params(tensor_list),
        'train_loss_hist': [0]
    }]

    input_dims = cc.get_indims(tensor_list)
    n_cores = len(input_dims)
    n_edges = n_cores * (n_cores - 1) // 2  # one rank per pair of cores

    stage = -1
    best_loss, best_network = np.inf, None
    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if stage == 0:
            # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list,
                initial_epochs,
                train_data,
                loss_fun,
                epochs=epochs,
                val_data=val_data,
                other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")
            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': hist[0][:best_epoch]
            })
        else:
            m_print(
                f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n"
            )

            # Create new TN random ranks
            ranks = torch.randint(low=1, high=max_rank + 1,
                                  size=(n_edges, 1)).view(-1, ).tolist()
            rank_list = _make_ranks(ranks)
            currentNetwork = _limit_random_tn(input_dims,
                                              rank=rank_list,
                                              max_params=max_params)
            currentNetwork = cc.make_trainable(currentNetwork)

            if stop_on_plateau:
                detect_plateau._reset()
            currentNetwork, _, current_loss, best_epoch, hist = training(
                currentNetwork,
                initial_epochs,
                train_data,
                loss_fun,
                val_data=val_data,
                epochs=epochs,
                other_args=other_args)
            loss_record.append({
                'iter': stage,
                'network': currentNetwork,
                'num_params': cc.num_params(currentNetwork),
                'loss': current_loss,
                'train_loss_hist': hist[0][:best_epoch]
            })

            if best_loss > current_loss:
                best_network = currentNetwork
                best_loss = current_loss
                print('\nbest TN:')
                cc.print_ranks(best_network)
                print('number of params:', cc.num_params(best_network))
                print([(r['iter'], r['num_params'], float(r['loss']),
                        float(r['train_loss_hist'][0]),
                        float(r['train_loss_hist'][-1]))
                       for r in loss_record])

    return best_network, best_loss, loss_record
def greedy_optim(tensor_list, train_data, loss_fun, find_best_edge,
                 val_data=None, other_args=None, max_iter=None):
    """
    Train a tensor network using greedy discrete optimization over TN ranks.

    The initial network is first trained with `training`. Then, at every
    discrete iteration, `find_best_edge` proposes a network with one
    incremented edge, and that network is trained further with `training`.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network (target tensor)
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        find_best_edge: Function identifying the most promising edge and
                     returning a new network with incremented edge. It is
                     called with the keyword arguments
                        initialNetwork, train_data, val_data, loss_fun,
                        rank_increment, training_args, prev_best_loss,
                        search_epochs, padding_noise, is_reg
                     and must return the 5-tuple
                        (network, loss, optimizer_state,
                         train_loss_hist, val_loss_hist)
        val_data:    The data used for validation, which can be used for
                     early stopping within continuous optimization calls
                     within the discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization.
                     It is copied (shallowly) and never modified. Options:
                        epochs: Number of epochs for
                                continuous optimization     (default=10)
                        max_iter: Maximum number of iterations
                                  for greedy search         (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        reps:   Number of times to repeat
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        search_epochs: Number of epochs used to identify
                                the best rank update. If None, `epochs`
                                is used.                    (default=None)
                        loss_threshold: If the loss gets below this
                                threshold, discrete optimization is
                                stopped                     (default=1e-5)
                        initial_epochs: Number of epochs after which the
                                learning rate is reduced and optimization
                                is restarted if there is no improvement
                                in the loss.                (default=None)
                        rank_increment: How much a rank is increased at
                                each discrete iteration     (default=1)
                        stop_on_plateau: Dictionary of keyword arguments
                                for DetectPlateau (keys mode (min/max),
                                patience, threshold), used to stop
                                continuous optimization when a plateau is
                                detected                    (default=None)
                        filename: Pickle filename where loss_record is
                                saved after each continuous optimization
                                                            (default=None)
        max_iter:    Maximum number of greedy iterations. Only used when
                     other_args has no 'max_iter' entry (which keeps
                     precedence for backward compatibility); if neither
                     is given, 10 is used.

    Returns:
        best_network: List of tensors for the last trained network.
        best_loss:    The loss returned by `training` for best_network.
        loss_record:  List of dictionaries with keys
                        iter, num_params, network, loss,
                        train_loss_hist, val_loss_hist
                      where
                        iter: iteration of the discrete optimization
                              (-1 is the untrained input network)
                        num_params: number of parameters of the network
                        network: list of tensors for the network
                        loss: loss achieved by the network
                        train_loss_hist / val_loss_hist: loss histories of
                              the continuous optimization for this
                              iteration, up to the best epoch

    Raises:
        AssertionError: if initial_epochs is given and is not smaller
                        than search_epochs and epochs.
    """
    # Work on a private shallow copy: this function stores bookkeeping
    # entries (hist, stop_condition, optimizer state) in the dictionary and
    # must not leak them into the caller's dict or into a shared default.
    args = dict(other_args) if other_args else {}

    epochs = args.get('epochs', 10)
    if 'max_iter' in args:
        max_iter = args['max_iter']
    elif max_iter is None:
        max_iter = 10
    dprint = args.get('dprint', True)
    search_epochs = args.get('search_epochs')
    if search_epochs is None:
        search_epochs = epochs
    loss_threshold = args.get('loss_threshold', 1e-5)
    initial_epochs = args.get('initial_epochs')
    rank_increment = args.get('rank_increment', 1)
    filename = args.get('filename')

    # initial_epochs is optional; only validate it when it is provided.
    # AssertionError is raised explicitly so that the check survives
    # `python -O` while keeping the exception type of the original assert.
    if initial_epochs is not None:
        fits_search = not search_epochs or initial_epochs < search_epochs
        fits_epochs = not epochs or initial_epochs < epochs
        if not (fits_search and fits_epochs):
            raise AssertionError(
                "initial_epochs must be smaller than search_epochs and epochs")

    stop_on_plateau = args.get('stop_on_plateau')
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))

    args['hist'] = True  # we always track the history of losses

    def m_print(msg):
        """Print `msg` only when discrete-optimization printing is on."""
        if dprint:
            print(msg)

    def as_list(values):
        """Return a fresh list from a list or an array-like history.

        `+=` on array-like histories would add elementwise instead of
        concatenating, so histories are normalised to lists first.
        """
        if hasattr(values, "tolist"):
            return values.tolist()
        return deepcopy(list(values))

    def save_record():
        """Pickle loss_record to `filename` when one was requested."""
        if filename:
            with open(filename, "wb") as f:
                pickle.dump(loss_record, f)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{'iter': -1,
                    'network': tensor_list,
                    'num_params': cc.num_params(tensor_list),
                    'train_loss_hist': [0],
                    'val_loss_hist': [0]}]

    # Iteratively increment ranks of tensor_list and train via `training`,
    # at each stage using a search procedure to test out different ranks
    # before choosing just one to increase. The loop stops once the loss
    # drops below loss_threshold or stage reaches max_iter.
    stage = -1
    best_loss, best_network = np.inf, None
    initialNetwork = None
    while not best_loss < loss_threshold and stage < max_iter:
        stage += 1
        if stage == 0:
            # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list, initial_epochs, train_data, loss_fun,
                epochs=epochs, val_data=val_data, other_args=args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")

            initialNetwork = cc.copy_network(tensor_list)
            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({'iter': stage,
                                'network': best_network,
                                'num_params': cc.num_params(best_network),
                                'loss': best_loss,
                                'train_loss_hist': hist[0][:best_epoch],
                                'val_loss_hist': hist[1][:best_epoch]})
            save_record()
        else:
            m_print(
                f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n")

            # find_best_edge gets its own copy of the arguments so that it
            # cannot replace our stop_condition (and hence the plateau
            # detector that is reset below).
            # NOTE(review): is_reg is always passed as False, as in the
            # original code; an 'is_reg' entry in other_args is not
            # forwarded. Confirm that this is intended.
            (best_network, best_loss, best_network_optimizer_state,
             best_train_lost_hist, best_val_loss_hist) = find_best_edge(
                initialNetwork=initialNetwork,
                train_data=train_data,
                val_data=val_data,
                loss_fun=loss_fun,
                rank_increment=rank_increment,
                training_args=dict(args),
                prev_best_loss=best_loss,
                search_epochs=search_epochs,
                padding_noise=1e-6,
                is_reg=False)

            # train network to convergence for the best rank increment
            print('\ntraining best network until max_epochs/convergence...')
            args["load_optimizer_state"] = best_network_optimizer_state
            args["save_optimizer_state"] = True
            args["optimizer_state"] = {}
            if stop_on_plateau:
                detect_plateau._reset()
            best_network, first_loss, best_loss, best_epoch, hist = training(
                best_network, initial_epochs, train_data, loss_fun,
                val_data=val_data, epochs=epochs, other_args=args)
            args["load_optimizer_state"] = None

            # Append the final training history to the search history.
            best_train_lost_hist = (as_list(best_train_lost_hist)
                                    + as_list(hist[0][:best_epoch]))
            best_val_loss_hist = (as_list(best_val_loss_hist)
                                  + as_list(hist[1][:best_epoch]))

            initialNetwork = cc.copy_network(best_network)
            loss_record.append({'iter': stage,
                                'network': best_network,
                                'num_params': cc.num_params(best_network),
                                'loss': best_loss,
                                'train_loss_hist': best_train_lost_hist,
                                'val_loss_hist': best_val_loss_hist})

            print('\nbest TN:')
            cc.print_ranks(best_network)
            print('number of params:', cc.num_params(best_network))
            print([(r['iter'], r['num_params'], float(r['loss']),
                    float(r['train_loss_hist'][0]),
                    float(r['train_loss_hist'][-1]))
                   for r in loss_record])
            save_record()

    return best_network, best_loss, loss_record
def greedy_find_best_edge(initialNetwork, train_data, val_data, loss_fun,
                          rank_increment, training_args, prev_best_loss,
                          allowed_edges=None, search_epochs=10,
                          padding_noise=1e-6, is_reg=False):
    """
    Try a rank increment on every allowed edge and return the best result.

    For each edge (i, j) a copy of initialNetwork gets its rank increased
    with cc.increase_rank and is then trained in two short phases with
    `training`: first with 'only_new_slice' set to (i, j) (skipped when
    is_reg is true), then with 'only_new_slice' set to False. The candidate
    with the lowest loss after the second phase wins.

    Args:
        initialNetwork: list of tensors representing the initial network
        train_data, val_data, loss_fun: forwarded to `training`
        rank_increment: how much the rank of the tested edge is increased
        training_args: dictionary of arguments passed to the training
                       function. It is copied (shallowly) and not
                       modified. The entries 'stop_on_plateau' (keyword
                       arguments for DetectPlateau) and 'initial_epochs'
                       are read here; both default to None.
        prev_best_loss: best loss from the previous discrete iteration;
                       only used in the progress message
        allowed_edges: a list of allowed edges (i, j) in the tensor
                       network. If None or empty, all edges with i < j are
                       considered. (default=None)
        search_epochs: number of epochs for each of the two search phases
                       (default=10)
        padding_noise: forwarded to cc.increase_rank; standard deviation
                       of the normal distribution used to initialize the
                       new slices of the core tensors receiving the new
                       edge (default=1e-6)
        is_reg:        when true, the first phase does not set
                       'only_new_slice' (default=False)

    Returns:
        (best_network, best_search_loss, best_network_optimizer_state,
         best_train_lost_hist, best_val_loss_hist). When no candidate was
        accepted this is (None, inf, None, [], []).
    """
    if not allowed_edges:
        ndims = len(initialNetwork)
        allowed_edges = [(i, j)
                         for i in range(ndims)
                         for j in range(i + 1, ndims)]

    # Private copy: the stop condition installed below must not replace
    # the one in the caller's dictionary.
    base_args = dict(training_args)

    stop_on_plateau = base_args.get('stop_on_plateau')
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        base_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))

    initial_epochs = base_args.get('initial_epochs')

    def as_list(values):
        """Return a fresh list from a list or an array-like history.

        `+=` on array-like histories would add elementwise instead of
        concatenating, so histories are normalised to lists first.
        """
        if hasattr(values, "tolist"):
            return values.tolist()
        return deepcopy(list(values))

    best_network, best_network_optimizer_state = None, None
    best_search_loss = np.inf
    best_train_lost_hist, best_val_loss_hist = [], []

    for (i, j) in allowed_edges:
        # increase rank along the chosen edge, starting from a fresh copy
        currentNetwork = cc.copy_network(initialNetwork)
        currentNetwork = cc.increase_rank(currentNetwork, i, j,
                                          rank_increment, padding_noise)
        currentNetwork = cc.make_trainable(currentNetwork)
        print('\ntesting rank increment for i =', i, 'j = ', j)

        ### Search optimization phase
        # we do only a few epochs to identify the most promising rank update

        # we first optimize only the new slices for a few epochs
        print("optimize new slices for a few epochs")
        search_args = dict(base_args)
        current_network_optimizer_state = {}
        search_args["save_optimizer_state"] = True
        search_args["optimizer_state"] = current_network_optimizer_state
        if not is_reg:
            search_args["only_new_slice"] = (i, j)
        if stop_on_plateau:
            detect_plateau._reset()
        # Use the defaulted local rather than training_args["initial_epochs"],
        # which raises KeyError when the option was not supplied.
        [currentNetwork, first_loss, current_loss, best_epoch, hist] = training(
            currentNetwork, initial_epochs, train_data, loss_fun,
            val_data=val_data, epochs=search_epochs, other_args=search_args)
        train_lost_hist = as_list(hist[0][:best_epoch])
        val_loss_hist = as_list(hist[1][:best_epoch])

        # We then optimize all parameters for a few epochs, resuming from
        # the optimizer state saved by the first phase
        print("\noptimize all parameters for a few epochs")
        search_args["load_optimizer_state"] = dict(
            current_network_optimizer_state)
        search_args["only_new_slice"] = False
        if stop_on_plateau:
            detect_plateau._reset()
        [currentNetwork, first_loss, current_loss, best_epoch, hist] = training(
            currentNetwork, initial_epochs, train_data, loss_fun,
            val_data=val_data, epochs=search_epochs, other_args=search_args)
        search_args["load_optimizer_state"] = None
        train_lost_hist += as_list(hist[0][:best_epoch])
        val_loss_hist += as_list(hist[1][:best_epoch])

        print(f"\nCurrent loss is {current_loss:.7f} Best loss from previous discrete optim is {prev_best_loss}")
        if best_search_loss > current_loss:
            best_search_loss = current_loss
            best_network = currentNetwork
            best_network_optimizer_state = deepcopy(
                current_network_optimizer_state)
            best_train_lost_hist = train_lost_hist
            best_val_loss_hist = val_loss_hist
            print('-> best rank update so far:', i, j)

    return (best_network, best_search_loss, best_network_optimizer_state,
            best_train_lost_hist, best_val_loss_hist)
def training(tensor_list, initial_epochs, train_data, loss_fun, val_data,
             epochs, other_args):
    """
    Run the continuous optimization routine with a small tweak: if there
    is no improvement of the loss in the first [initial_epochs] epochs, the
    learning rate is reduced by a factor of 0.5 and optimization is
    restarted from the beginning.

    If initial_epochs is None, cc.continuous_optim is called directly and
    its return value is returned unchanged.

    All arguments are the same as cc.continuous_optim except for
    [initial_epochs] and the following additional optional entry of
    other_args:
        only_new_slice: either a false value, in which case optimization
                        is performed on all parameters, or a tuple of
                        indices (i, j) (corresponding to a newly
                        incremented edge), in which case only the new
                        slices of the corresponding core tensors are
                        optimized. (default=False)

    other_args itself is deep-copied and not modified; the dictionary
    stored under other_args["optimizer_state"] (if any) is the one handed
    to cc.continuous_optim for saving the optimizer state.

    Returns:
        When initial_epochs is None: whatever cc.continuous_optim returns.
        Otherwise [network, first_loss, current_loss, best_epoch, hist] if
        other_args has a "hist" key, else
        [network, first_loss, current_loss]. first_loss is the one reported
        by the accepted initial phase, best_epoch counts epochs over both
        phases, and hist holds, per history, the initial-phase entries
        followed by the second-phase entries (each cut at its best epoch).
    """
    only_new_slice = other_args.get('only_new_slice', False)
    if only_new_slice:
        i, j = only_new_slice

        def grad_masking_function(tensor_list):
            """Zero every gradient entry outside the new slices of edge (i, j)."""
            for k, core in enumerate(tensor_list):
                if k == i or k == j:
                    # Mode of this core that carries the incremented edge.
                    mode = j if k == i else i
                    # The number of modes is taken from the gradient itself;
                    # the enclosing function's network variable is not yet
                    # bound when initial_epochs is None.
                    n_modes = core.grad.dim()
                    order = [mode] + [d for d in range(n_modes) if d != mode]
                    # permute returns a view, so this zeroes all but the
                    # last slice along `mode` in place.
                    # NOTE(review): only the last index is treated as new,
                    # which matches a rank increment of 1 -- confirm for
                    # larger increments.
                    core.grad.permute(order)[:-1, :, ...] *= 0
                else:
                    core.grad *= 0
    else:
        grad_masking_function = None

    args = deepcopy(other_args)
    args["grad_masking_function"] = grad_masking_function

    if initial_epochs is None:
        return cc.continuous_optim(tensor_list, train_data, loss_fun,
                                   val_data=val_data, epochs=epochs,
                                   other_args=args)

    args["hist"] = True
    # Deliberately the caller's dictionary (not the deep copy), so that the
    # saved optimizer state is visible to the caller.
    current_network_optimizer_state = other_args.get("optimizer_state", {})
    args["save_optimizer_state"] = True
    args["optimizer_state"] = current_network_optimizer_state

    currentNetwork = cc.copy_network(tensor_list)
    remaining_epochs = epochs - initial_epochs if epochs else None

    hist = None
    # Repeat the initial phase until the last recorded training loss is no
    # larger than the first one.
    while hist is None or (hist[0][0] < hist[0][-1]):
        if hist is not None:
            loaded_state = args.get("load_optimizer_state")
            if loaded_state:
                param_group = loaded_state["optimizer_state"]['param_groups'][0]
                param_group['lr'] /= 2
                lr = param_group['lr']
            else:
                # Falls back to the documented default learning rate of 1e-3
                # when 'lr' was not supplied -- TODO confirm it matches the
                # default used by cc.continuous_optim.
                args["lr"] = args.get("lr", 1e-3) / 2
                lr = args["lr"]
            print(f"\n[!] No progress in first {initial_epochs} epochs, starting again with smaller learning rate ({lr})")
            currentNetwork = cc.copy_network(tensor_list)

        [currentNetwork, first_loss, current_loss, best_epoch, hist] = \
            cc.continuous_optim(currentNetwork, train_data, loss_fun,
                                val_data=val_data, epochs=initial_epochs,
                                other_args=args)
        hist_initial = [h[:best_epoch].tolist() for h in hist]
        best_epoch_initial = best_epoch
        # Loss at the very start of the accepted run; the second phase
        # below reports its own first loss, which must not replace it.
        initial_first_loss = first_loss

    # Continue from the optimizer state saved by the initial phase.
    args["load_optimizer_state"] = current_network_optimizer_state
    [currentNetwork, _, current_loss, best_epoch, hist] = \
        cc.continuous_optim(currentNetwork, train_data, loss_fun,
                            val_data=val_data, epochs=remaining_epochs,
                            other_args=args)
    first_loss = initial_first_loss
    hist = [h_init + h[:best_epoch].tolist()
            for h_init, h in zip(hist_initial, hist)]
    best_epoch += best_epoch_initial

    if "hist" in other_args:
        return [currentNetwork, first_loss, current_loss, best_epoch, hist]
    return [currentNetwork, first_loss, current_loss]
def discrete_optim_template(tensor_list, train_data, loss_fun,
                            val_data=None, other_args=None):
    """
    Train a tensor network using discrete optimization over TN ranks

    This is a template: the search step inside the loop is left as a TODO.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within continuous optimization calls
                     within the discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization
                     (not modified by this function), with some options
                     below (feel free to add more)

                        epochs: Number of epochs for
                                continuous optimization     (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        reps:   Number of times to repeat
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        dhist:  Whether to return losses
                                from intermediate TNs       (default=False)

    Returns:
        best_network: The network with the lowest recorded loss, or None
                      if no loss was recorded.
        first_loss:   The first loss reported by a one-epoch call to
                      cc.continuous_optim on the input network.
        best_loss:    The lowest loss recorded, or None if none was.
        loss_record:  Only if dhist=True in other_args: a tuple with one
                      torch tensor per recorded loss, in recording order.
    """
    if other_args is None:
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)  # for the search step (TODO below)
    dprint = other_args.get('dprint', True)
    dhist = other_args.get('dhist', False)
    first_loss, best_loss, best_network = None, None, None
    recorded_losses = []

    def m_print(msg):
        """Print `msg` only when discrete-optimization printing is on."""
        if dprint:
            print(msg)

    def record_loss(new_loss, new_network):
        """Track the best (loss, network) pair and, if dhist, log the loss."""
        nonlocal best_loss, best_network

        # Check for best loss
        if best_loss is None or new_loss < best_loss:
            best_loss, best_network = new_loss, new_network

        # Add new loss to our loss record
        if dhist:
            recorded_losses.append(new_loss)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)

    # Define a function giving the stop condition for the discrete
    # optimization procedure. I'm using a simple example here which could
    # work for greedy or random walk searches, but for other optimization
    # methods this could be trivial
    stop_cond = generate_stop_cond(cc.get_indims(tensor_list))

    # Iteratively increment ranks of tensor_list and train via
    # continuous_optim, at each stage using a search procedure to
    # test out different ranks before choosing just one to increase
    stage = 0
    better_network, better_loss = tensor_list, 1e10
    while not stop_cond(better_network):
        if first_loss is None:
            # Record initial loss of TN model. A copy of other_args is
            # used so that turning off printing/history here does not
            # affect later calls or the caller's dictionary.
            first_args = dict(other_args)
            first_args["print"] = first_args["hist"] = False
            _, first_loss, _ = cc.continuous_optim(tensor_list, train_data,
                                                   loss_fun, epochs=1,
                                                   val_data=val_data,
                                                   other_args=first_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {first_loss:.3f}")
            continue
        m_print(f"STAGE {stage}")

        # Try out training different network ranks and assign network
        # with best ranks to better_network
        #
        # TODO: This is your part to write! Use new variables for the loss
        #       and network being tried out, with the network being
        #       initialized from better_network. When you've found a better
        #       TN, make sure to update better_network and better_loss to
        #       equal the parameters and loss of that TN.
        # At some point this code will call the continuous optimization
        # loop, in which case you can use the following command:
        #
        # trained_tn, init_loss, final_loss = cc.continuous_optim(my_tn,
        #                                         train_data, loss_fun,
        #                                         epochs=epochs,
        #                                         val_data=val_data,
        #                                         other_args=other_args)

        # Record the loss associated with the best network from this
        # discrete optimization loop
        record_loss(better_loss, better_network)
        stage += 1

    if dhist:
        loss_record = tuple(torch.tensor(fr) for fr in recorded_losses)
        return best_network, first_loss, best_loss, loss_record
    else:
        return best_network, first_loss, best_loss