def randomwalk_optim(tensor_list,
                     train_data,
                     loss_fun,
                     val_data=None,
                     other_args=None,
                     max_iter=None):
    """
    Train a tensor network using a random walk over the TN ranks.

    The initial network is first trained as is. Then, at every discrete
    iteration, two distinct cores are drawn at random, the rank of the edge
    joining them is increased, the new slices are trained for a few epochs
    and the whole network is finally trained with the regular budget. The
    resulting network is the starting point of the next iteration.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network (target tensor)
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, forwarded to every
                     continuous optimization call
        other_args:  Dictionary of other arguments for the optimization.
                     When a dictionary is given it is modified in place
                     (keys 'hist', 'stop_condition', 'optimizer_state',
                     'save_optimizer_state', 'load_optimizer_state' and
                     'grad_masking_function' may be written). Keys that
                     are not listed below (e.g. optim, lr, bsize, reps,
                     cprint) are not read here; they are passed through
                     to `training` untouched.
                        epochs: Number of epochs for continuous
                                optimization (default=10)
                        max_iter: Maximum number of iterations for the
                                  random walk search (default=10)
                        dprint: Whether to print info from discrete
                                optimization (default=True)
                        search_epochs: Number of epochs used to train the
                                       new slices right after a rank
                                       increment (default=epochs)
                        loss_threshold: if the loss gets below this
                                        threshold, discrete optimization
                                        is stopped (default=1e-5)
                        initial_epochs: forwarded to `training` as its
                                        second positional argument
                                        (default=None)
                        rank_increment: by how much a rank is increased
                                        at each discrete iteration
                                        (default=1)
                        stop_on_plateau: dictionary of keyword arguments
                                         used to build a DetectPlateau
                                         object, which is then used as
                                         stop condition on the training
                                         loss (default=None)
                        is_reg: if true, the gradient mask restricting
                                the search phase to the new slices is
                                not installed (default=False)
                        gradient_hook: gradient masking function used
                                       while training the whole network;
                                       a hack for tensor completion
                                       (default=None)
        max_iter:    Fallback for the maximum number of iterations, used
                     only when other_args has no 'max_iter' entry

    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm (None if the search loop never ran)
        best_loss:   The loss reported by `training` for better_list
        loss_record: List of dictionaries with keys
                        iter, num_params, network, loss,
                        train_loss_hist, val_loss_hist
                     where
                        iter: iteration of the discrete optimization
                              (-1 for the untrained network)
                        num_params: number of parameters of the network
                        network: list of tensors of the network
                        loss: loss achieved by the network
                        train_loss_hist / val_loss_hist: histories of the
                              continuous optimization for this iteration,
                              truncated at the best epoch
    """
    # A fresh dictionary per call: this function writes into other_args, so
    # a shared mutable default would leak state from one call to the next.
    if other_args is None:
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    # An entry in other_args wins; the keyword argument is the fallback.
    max_iter = other_args.get('max_iter', 10 if max_iter is None else max_iter)
    dprint = other_args.get('dprint', True)
    search_epochs = other_args.get('search_epochs', epochs)
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs')
    rank_increment = other_args.get('rank_increment', 1)
    gradient_hook = other_args.get('gradient_hook')
    is_reg = other_args.get('is_reg', False)
    stop_on_plateau = other_args.get('stop_on_plateau')

    detect_plateau = None
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        other_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))
    other_args['hist'] = True  # we always track the history of losses

    def m_print(message):
        """Print `message` only when discrete-optimization output is on."""
        if dprint:
            print(message)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{
        'iter': -1,
        'network': tensor_list,
        'num_params': cc.num_params(tensor_list),
        'train_loss_hist': [0],
        'val_loss_hist': [0]
    }]

    # Iteratively increment ranks of tensor_list and train via `training`.
    # Stage 0 trains the initial network, every later stage is one step of
    # the random walk. np.inf is used rather than np.infty, which was
    # removed in NumPy 2.0.
    stage = -1
    best_loss, best_network = np.inf, None
    while not best_loss < loss_threshold and stage < max_iter:
        stage += 1
        if stage == 0:
            # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list,
                initial_epochs,
                train_data,
                loss_fun,
                epochs=epochs,
                val_data=val_data,
                other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")
            initialNetwork = cc.copy_network(tensor_list)
            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': hist[0][:best_epoch],
                'val_loss_hist': hist[1][:best_epoch]
            })
        else:
            m_print(
                f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n"
            )

            # Draw two distinct cores; the edge (i, j) gets its rank
            # increased.
            i, j = torch.randperm(len(tensor_list))[:2].tolist()
            currentNetwork = cc.copy_network(initialNetwork)
            currentNetwork = cc.increase_rank(currentNetwork, i, j,
                                              rank_increment, 1e-6)
            currentNetwork = cc.make_trainable(currentNetwork)
            print('\ntesting rank increment for i =', i, 'j = ', j)

            ### Search optimization phase
            n_modes = len(currentNetwork)

            def zero_old_slices(grad, axis):
                # Bring `axis` to the front and zero everything but the
                # trailing `rank_increment` slices. This assumes that
                # cc.increase_rank appends the new slices at the end of
                # the axis, as the original `[:-1]` mask already did for
                # an increment of 1 -- TODO confirm.
                order = ([axis] + list(range(0, axis)) +
                         list(range(axis + 1, n_modes)))
                grad.permute(order)[:-rank_increment, :, ...] *= 0

            def grad_masking_function(tensor_list):
                # Zero out the gradient of all entries except for the new
                # slices of cores i and j.
                for k, core in enumerate(tensor_list):
                    if k == i:
                        zero_old_slices(core.grad, j)
                    elif k == j:
                        zero_old_slices(core.grad, i)
                    else:
                        core.grad *= 0

            # we first optimize only the new slices for a few epochs
            print("optimize new slices for a few epochs")
            search_args = dict(other_args)
            current_network_optimizer_state = {}
            search_args["save_optimizer_state"] = True
            search_args["optimizer_state"] = current_network_optimizer_state
            if not is_reg:
                search_args["grad_masking_function"] = grad_masking_function
            if stop_on_plateau:
                detect_plateau._reset()
            currentNetwork, _, _, best_epoch, hist = training(
                currentNetwork,
                initial_epochs,
                train_data,
                loss_fun,
                val_data=val_data,
                epochs=search_epochs,
                other_args=search_args)

            # A random walk always accepts the move it has just tried.
            best_network = currentNetwork
            best_network_optimizer_state = deepcopy(
                current_network_optimizer_state)
            best_train_lost_hist = deepcopy(hist[0][:best_epoch])
            best_val_loss_hist = deepcopy(hist[1][:best_epoch])
            print('-> best rank update so far:', i, j)

            # train network to convergence
            print('\ntraining best network until max_epochs/convergence...')
            other_args["load_optimizer_state"] = best_network_optimizer_state
            other_args["save_optimizer_state"] = True
            other_args["optimizer_state"] = {}
            if gradient_hook:
                # A bit of a hack only used for tensor completion - ignore
                other_args["grad_masking_function"] = gradient_hook
            if stop_on_plateau:
                detect_plateau._reset()
            best_network, _, best_loss, best_epoch, hist = training(
                best_network,
                initial_epochs,
                train_data,
                loss_fun,
                val_data=val_data,
                epochs=epochs,
                other_args=other_args)
            if gradient_hook:
                other_args["grad_masking_function"] = None
            other_args["load_optimizer_state"] = None

            best_train_lost_hist += deepcopy(hist[0][:best_epoch])
            best_val_loss_hist += deepcopy(hist[1][:best_epoch])

            initialNetwork = cc.copy_network(best_network)

            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': best_train_lost_hist,
                'val_loss_hist': best_val_loss_hist
            })

            print('\nbest TN:')
            cc.print_ranks(best_network)
            print('number of params:', cc.num_params(best_network))
            print([(r['iter'], r['num_params'], float(r['loss']),
                    float(r['train_loss_hist'][0]),
                    float(r['train_loss_hist'][-1])) for r in loss_record])

    return best_network, best_loss, loss_record
def greedy_optim(tensor_list, train_data, loss_fun,
                 val_data=None, other_args=None, max_iter=None):
    """
    Train a tensor network using a greedy search over the TN ranks.

    The initial network is first trained as is. Then, at every stage, the
    rank of each edge (i, j) with i < j is tentatively increased by one,
    the candidate is trained, and the candidate with the lowest loss (if
    it beats the current best loss) becomes the starting point of the next
    stage.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network (target tensor)
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, forwarded to every
                     continuous optimization call
        other_args:  Dictionary of other arguments for the optimization.
                     When a dictionary is given and search_epochs is set,
                     it is modified in place (keys 'optimizer_state',
                     'save_optimizer_state' and 'load_optimizer_state').
                     Keys that are not listed below (e.g. optim, lr,
                     bsize, reps, cprint, dhist) are not read here; they
                     are passed through to the continuous optimizer.
                        epochs: Number of epochs for continuous
                                optimization (default=10)
                        max_iter: Maximum number of iterations for greedy
                                  search (default=10)
                        dprint: Whether to print info from discrete
                                optimization (default=True)
                        search_epochs: Number of epochs used to compare
                                       the candidate rank updates. If
                                       None (or 0), every candidate is
                                       fully trained with
                                       cc.continuous_optim instead
                                       (default=None)
                        loss_threshold: if the loss gets below this
                                        threshold, discrete optimization
                                        is stopped (default=1e-5)
                        initial_epochs: forwarded to `training` as its
                                        second positional argument
                                        (default=None)
        max_iter:    Fallback for the maximum number of iterations, used
                     only when other_args has no 'max_iter' entry

    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm (None if the search loop never ran)
        first_loss:  First loss reported by the very first continuous
                     optimization, i.e. the loss of the initial model
                     (None if the search loop never ran)
        best_loss:   The loss reported by the continuous optimizer for
                     better_list

    The per-stage record (stage, num_params, loss) is printed after each
    stage; it is not returned.
    """
    # A fresh dictionary per call: this function writes into other_args, so
    # a shared mutable default would leak state from one call to the next.
    if other_args is None:
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    # An entry in other_args wins; the keyword argument is the fallback.
    max_iter = other_args.get('max_iter', 10 if max_iter is None else max_iter)
    dprint = other_args.get('dprint', True)
    search_epochs = other_args.get('search_epochs')
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs')

    def m_print(message):
        """Print `message` only when discrete-optimization output is on."""
        if dprint:
            print(message)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)

    # Iteratively increment ranks of tensor_list and train, at each stage
    # testing every edge before choosing just one to increase.
    # np.inf is used rather than np.infty, which was removed in NumPy 2.0.
    stage = 0
    loss_record, best_loss, best_network = [], np.inf, None
    # Loss of the untrained model; kept apart so that later training calls
    # cannot overwrite the value that is returned.
    initial_loss = None
    # Stays None when no candidate has improved on the best loss yet.
    best_network_optimizer_state = None

    while not best_loss < loss_threshold and stage < max_iter:
        stage += 1
        if stage == 1:
            # first continuous optimization
            tensor_list, initial_loss, best_loss = training(
                tensor_list, initial_epochs, train_data, loss_fun,
                epochs=epochs, val_data=val_data, other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")
            initialNetwork = cc.copy_network(tensor_list)
            best_network = cc.copy_network(tensor_list)
            loss_record += [initial_loss, best_loss]

        m_print(f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n")
        best_search_loss = best_loss

        for i in range(len(initialNetwork)):
            for j in range(i + 1, len(initialNetwork)):
                currentNetwork = cc.copy_network(initialNetwork)
                # increase rank along a chosen dimension
                currentNetwork = cc.increase_rank(currentNetwork, i, j, 1, 1e-6)
                print('\ntesting rank increment for i =', i, 'j = ', j)

                # Defined for both search modes: it is deep-copied below
                # whenever the candidate becomes the new best one.
                current_network_optimizer_state = {}

                if search_epochs:
                    # we do only a few epochs to identify the most
                    # promising rank update
                    n_modes = len(currentNetwork)

                    def zero_old_slices(grad, axis):
                        # Bring `axis` to the front and zero every slice
                        # but the last one (the newly added slice).
                        order = ([axis] + list(range(0, axis)) +
                                 list(range(axis + 1, n_modes)))
                        grad.permute(order)[:-1, :, ...] *= 0

                    def grad_masking_function(tensor_list):
                        # Zero out the gradient of all entries except for
                        # the new slices of cores i and j.
                        for k, core in enumerate(tensor_list):
                            if k == i:
                                zero_old_slices(core.grad, j)
                            elif k == j:
                                zero_old_slices(core.grad, i)
                            else:
                                core.grad *= 0

                    # we first optimize only the new slices for a few epochs
                    print("optimize new slices for a few epochs")
                    search_args = dict(other_args)
                    search_args["hist"] = True
                    search_args["save_optimizer_state"] = True
                    search_args["optimizer_state"] = current_network_optimizer_state
                    search_args["grad_masking_function"] = grad_masking_function
                    currentNetwork, _, _, _ = training(
                        currentNetwork, initial_epochs, train_data, loss_fun,
                        val_data=val_data, epochs=search_epochs,
                        other_args=search_args)

                    # We then optimize all parameters for a few epochs
                    print("\noptimize all parameters for a few epochs")
                    search_args["grad_masking_function"] = None
                    search_args["load_optimizer_state"] = dict(
                        current_network_optimizer_state)
                    currentNetwork, _, current_loss, _ = training(
                        currentNetwork, initial_epochs, train_data, loss_fun,
                        val_data=val_data, epochs=search_epochs,
                        other_args=search_args)
                else:
                    # we fully optimize the network in the search phase
                    currentNetwork, _, current_loss = cc.continuous_optim(
                        currentNetwork, train_data, loss_fun,
                        val_data=val_data, epochs=epochs,
                        other_args=other_args)

                m_print(f"\nCurrent loss is {current_loss:.7f} Best loss from previous discrete optim is {best_loss}")
                if best_search_loss > current_loss:
                    best_search_loss = current_loss
                    best_network = currentNetwork
                    best_network_optimizer_state = deepcopy(
                        current_network_optimizer_state)
                    print('-> best rank update so far:', i, j)

        best_loss = best_search_loss

        # train network to convergence for the best rank increment (if
        # search_epochs is set, otherwise the best network is already
        # trained to convergence / max_epochs)
        if search_epochs:
            print('\ntraining best network until max_epochs/convergence...')
            other_args["load_optimizer_state"] = best_network_optimizer_state
            other_args["save_optimizer_state"] = True
            other_args["optimizer_state"] = {}
            best_network, _, best_loss = training(
                best_network, initial_epochs, train_data, loss_fun,
                val_data=val_data, epochs=epochs, other_args=other_args)
            other_args["load_optimizer_state"] = None

        initialNetwork = cc.copy_network(best_network)

        loss_record.append(
            (stage, cc.num_params(best_network), float(best_loss)))
        print('\nbest TN:')
        cc.print_ranks(best_network)
        print('number of params:', cc.num_params(best_network))
        print(loss_record)

    return best_network, initial_loss, best_loss
for fn in sys.argv[1:]: with open(fn, 'rb') as f: results = pickle.load(f) results["Greedy"] = results["greedy"] del results["greedy"] results["Random walk"] = results["randomwalk"] del results["randomwalk"] methods = ["Greedy", "CP", "Tucker", "TT", "Random walk"] plt.figure(1) plt.subplot(1, 3, subplot_counter) for method in methods: dloss = [r['loss'] for r in results[method]] num_params = [r['num_params'] for r in results[method]] plt.plot(num_params, dloss, '.-') print(print_ranks(results["Greedy"][-1]["network"])) plt.axvline(x=results["_xp-infos_"]["targt_TN_params"], ls='--', c='black', lw=0.8) plt.legend(methods) # + ["opt. params"]) plt.xlabel('parameters') plt.ylabel('reconstruction error') plt.tight_layout() plt.figure(2) plt.subplot(1, 3, subplot_counter) losses = moving_average([ loss for r in results['Greedy'][1:] for loss in r['train_loss_hist']
path.basename(target_file))[0] + ".pickle" if path.exists(result_file): print(f"output file already exists ({result_file})") sys.exit(-1) goal_tn = torch.load(target_file) target_full = cc.wire_network(goal_tn, give_dense=True).numpy() input_dims = target_full.shape target_TN_params = cc.num_params(goal_tn) target_full_parms = np.prod(input_dims) print('target tensor network number of params: ', target_TN_params) print('number of params for full target tensor:', target_full_parms) print('target tensor norm:', cc.l2_norm(goal_tn)) print('target tensor ranks:') cc.print_ranks(goal_tn) results = { "_xp-infos_": { 'targt_TN_params': target_TN_params, 'target_full_parms': target_full_parms, 'target_network': goal_tn } } from tensor_decomposition_models import incremental_tensor_decomposition from randomsearch import randomsearch_decomposition from randomwalk import randomwalk_decomposition from greedy import greedy_decomposition for decomp in ["greedy" ]: #"randomsearch randomwalk greedy CP TT Tucker".split():
def randomsearch_optim(tensor_list,
                       train_data,
                       loss_fun,
                       val_data=None,
                       other_args=None):
    """
    Train a tensor network using a random search over the TN ranks.

    The initial network is first trained as is. Then, at every discrete
    iteration, a new network with randomly drawn ranks is created and
    trained from scratch; the network with the lowest loss seen so far is
    kept.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, forwarded to every
                     continuous optimization call
        other_args:  Dictionary of other arguments for the optimization.
                     When a dictionary is given it is modified in place
                     (keys 'hist' and 'stop_condition' may be written).
                     Keys that are not listed below (e.g. optim, lr,
                     bsize, reps, cprint) are not read here; they are
                     passed through to `training` untouched.
                        epochs: Number of epochs for continuous
                                optimization (default=10)
                        max_iter: Maximum number of iterations for random
                                  search (default=10)
                        max_rank: Largest rank that can be drawn for an
                                  edge (default=7)
                        max_params: forwarded to _limit_random_tn as its
                                    max_params argument (default=10000)
                        dprint: Whether to print info from discrete
                                optimization (default=True)
                        loss_threshold: if the loss gets below this
                                        threshold, discrete optimization
                                        is stopped (default=1e-5)
                        initial_epochs: forwarded to `training` as its
                                        second positional argument
                                        (default=None)
                        stop_on_plateau: dictionary of keyword arguments
                                         used to build a DetectPlateau
                                         object, which is then used as
                                         stop condition on the training
                                         loss (default=None)

    Returns:
        better_list: List of tensors of the best network found (None if
                     the search loop never ran)
        best_loss:   The loss reported by `training` for better_list
        loss_record: List of dictionaries, one per trained network, with
                     keys iter, num_params, network, loss, train_loss_hist
                     where
                        iter: iteration of the discrete optimization
                              (-1 for the untrained network)
                        num_params: number of parameters of the network
                        network: list of tensors of the network
                        loss: loss achieved by the network
                        train_loss_hist: training loss history of the
                              continuous optimization for this iteration,
                              truncated at the best epoch
    """
    # A fresh dictionary per call: this function writes into other_args, so
    # a shared mutable default would leak state from one call to the next.
    if other_args is None:
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    max_iter = other_args.get('max_iter', 10)
    max_rank = other_args.get('max_rank', 7)
    max_params = other_args.get('max_params', 10000)
    dprint = other_args.get('dprint', True)
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs')
    stop_on_plateau = other_args.get('stop_on_plateau')

    # Keep track of continuous optimization history
    other_args['hist'] = True

    detect_plateau = None
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        other_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))

    def m_print(message):
        """Print `message` only when discrete-optimization output is on."""
        if dprint:
            print(message)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{
        'iter': -1,
        'network': tensor_list,
        'num_params': cc.num_params(tensor_list),
        'train_loss_hist': [0]
    }]

    input_dims = cc.get_indims(tensor_list)
    n_cores = len(input_dims)
    # one rank per unordered pair of cores
    n_edges = n_cores * (n_cores - 1) // 2

    # Stage 0 trains the initial network, every later stage trains one
    # randomly drawn network. np.inf is used rather than np.infty, which
    # was removed in NumPy 2.0.
    stage = -1
    best_loss, best_network = np.inf, None
    while not best_loss < loss_threshold and stage < max_iter:
        stage += 1
        if stage == 0:
            # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list,
                initial_epochs,
                train_data,
                loss_fun,
                epochs=epochs,
                val_data=val_data,
                other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")
            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': hist[0][:best_epoch]
            })
        else:
            m_print(
                f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n"
            )

            # Create new TN with random ranks drawn from 1..max_rank
            ranks = torch.randint(low=1,
                                  high=max_rank + 1,
                                  size=(n_edges, 1)).view(-1, ).tolist()
            rank_list = _make_ranks(ranks)
            currentNetwork = _limit_random_tn(input_dims,
                                              rank=rank_list,
                                              max_params=max_params)
            currentNetwork = cc.make_trainable(currentNetwork)

            if stop_on_plateau:
                detect_plateau._reset()
            currentNetwork, _, current_loss, best_epoch, hist = training(
                currentNetwork,
                initial_epochs,
                train_data,
                loss_fun,
                val_data=val_data,
                epochs=epochs,
                other_args=other_args)

            loss_record.append({
                'iter': stage,
                'network': currentNetwork,
                'num_params': cc.num_params(currentNetwork),
                'loss': current_loss,
                'train_loss_hist': hist[0][:best_epoch]
            })

            if best_loss > current_loss:
                best_network = currentNetwork
                best_loss = current_loss

            print('\nbest TN:')
            cc.print_ranks(best_network)
            print('number of params:', cc.num_params(best_network))
            print([(r['iter'], r['num_params'], float(r['loss']),
                    float(r['train_loss_hist'][0]),
                    float(r['train_loss_hist'][-1])) for r in loss_record])

    return best_network, best_loss, loss_record
def greedy_optim(tensor_list, train_data, loss_fun,
                 find_best_edge,
                 val_data=None,
                 other_args=None, max_iter=None):
    """
    Train a tensor network using a greedy search over the TN ranks.

    The initial network is first trained as is. Then, at every discrete
    iteration, `find_best_edge` selects the edge whose rank is incremented
    and the resulting network is trained with the regular budget. The
    resulting network is the starting point of the next iteration.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network (target tensor)
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        find_best_edge: function identifying the most promising edge and
                     returning a new network with incremented edge. It is
                     called with the keyword arguments
                        initialNetwork, train_data, val_data, loss_fun,
                        rank_increment, training_args, prev_best_loss,
                        search_epochs, padding_noise, is_reg
                     and must return the 5-tuple
                        (network, loss, optimizer_state,
                         train_loss_hist, val_loss_hist)
        val_data:    The data used for validation, forwarded to every
                     continuous optimization call
        other_args:  Dictionary of other arguments for the optimization.
                     When a dictionary is given it is modified in place
                     (keys 'hist', 'stop_condition', 'optimizer_state',
                     'save_optimizer_state' and 'load_optimizer_state'
                     may be written). Keys that are not listed below
                     (e.g. optim, lr, bsize, reps, cprint) are not read
                     here; they are passed through to `training` and
                     `find_best_edge` untouched.
                        epochs: Number of epochs for continuous
                                optimization (default=10)
                        max_iter: Maximum number of iterations for greedy
                                  search (default=10)
                        dprint: Whether to print info from discrete
                                optimization (default=True)
                        search_epochs: Number of epochs handed to
                                       find_best_edge (default=epochs)
                        loss_threshold: if the loss gets below this
                                        threshold, discrete optimization
                                        is stopped (default=1e-5)
                        initial_epochs: forwarded to `training` as its
                                        second positional argument; when
                                        set it must be smaller than
                                        search_epochs and epochs
                                        (default=None)
                        rank_increment: by how much a rank is increased
                                        at each discrete iteration
                                        (default=1)
                        stop_on_plateau: dictionary of keyword arguments
                                         used to build a DetectPlateau
                                         object, which is then used as
                                         stop condition on the training
                                         loss (default=None)
                        filename: pickle filename to save the loss_record
                                  after each continuous optimization
                                  (default=None)
        max_iter:    Fallback for the maximum number of iterations, used
                     only when other_args has no 'max_iter' entry

    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm (None if the search loop never ran)
        best_loss:   The loss reported by `training` for better_list
        loss_record: List of dictionaries with keys
                        iter, num_params, network, loss,
                        train_loss_hist, val_loss_hist
                     where
                        iter: iteration of the discrete optimization
                              (-1 for the untrained network)
                        num_params: number of parameters of the network
                        network: list of tensors of the network
                        loss: loss achieved by the network
                        train_loss_hist / val_loss_hist: histories of the
                              continuous optimization for this iteration,
                              truncated at the best epoch

    Raises:
        AssertionError: if initial_epochs is set and is not smaller than
                        search_epochs and epochs.
    """
    # A fresh dictionary per call: this function writes into other_args, so
    # a shared mutable default would leak state from one call to the next.
    if other_args is None:
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    # An entry in other_args wins; the keyword argument is the fallback.
    max_iter = other_args.get('max_iter', 10 if max_iter is None else max_iter)
    dprint = other_args.get('dprint', True)
    search_epochs = other_args.get('search_epochs', epochs)
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs')
    rank_increment = other_args.get('rank_increment', 1)
    filename = other_args.get('filename')
    stop_on_plateau = other_args.get('stop_on_plateau')

    # initial_epochs defaults to None, which cannot be ordered against an
    # int: only validate it when it has actually been set.
    if initial_epochs is not None:
        below_search = search_epochs is None or initial_epochs < search_epochs
        below_epochs = not epochs or initial_epochs < epochs
        assert below_search and below_epochs, "initial_epochs must be smaller than search_epochs and epochs"

    detect_plateau = None
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        other_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))
    other_args['hist'] = True  # we always track the history of losses

    def m_print(message):
        """Print `message` only when discrete-optimization output is on."""
        if dprint:
            print(message)

    def save_record():
        """Pickle the loss record to `filename`, when one was given."""
        if filename:
            with open(filename, "wb") as f:
                pickle.dump(loss_record, f)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{'iter': -1,
                    'network': tensor_list,
                    'num_params': cc.num_params(tensor_list),
                    'train_loss_hist': [0],
                    'val_loss_hist': [0]}]

    # Stage 0 trains the initial network, every later stage is one greedy
    # rank increment. np.inf is used rather than np.infty, which was
    # removed in NumPy 2.0.
    stage = -1
    best_loss, best_network = np.inf, None
    while not best_loss < loss_threshold and stage < max_iter:
        stage += 1
        if stage == 0:
            # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list, initial_epochs, train_data, loss_fun,
                epochs=epochs, val_data=val_data, other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")
            initialNetwork = cc.copy_network(tensor_list)
            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({'iter': stage,
                                'network': best_network,
                                'num_params': cc.num_params(best_network),
                                'loss': best_loss,
                                'train_loss_hist': hist[0][:best_epoch],
                                'val_loss_hist': hist[1][:best_epoch]})
            save_record()
        else:
            m_print(f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n")

            # NOTE(review): other_args['is_reg'] is not forwarded; is_reg
            # is hard-coded to False here, exactly as before -- confirm
            # that this is intended.
            (best_network, best_loss, best_network_optimizer_state,
             best_train_lost_hist, best_val_loss_hist) = find_best_edge(
                initialNetwork=initialNetwork,
                train_data=train_data,
                val_data=val_data,
                loss_fun=loss_fun,
                rank_increment=rank_increment,
                training_args=other_args,
                prev_best_loss=best_loss,
                search_epochs=search_epochs,
                padding_noise=1e-6,
                is_reg=False)

            # train network to convergence for the best rank increment
            print('\ntraining best network until max_epochs/convergence...')
            other_args["load_optimizer_state"] = best_network_optimizer_state
            other_args["save_optimizer_state"] = True
            other_args["optimizer_state"] = {}
            if stop_on_plateau:
                detect_plateau._reset()
            best_network, _, best_loss, best_epoch, hist = training(
                best_network, initial_epochs, train_data, loss_fun,
                val_data=val_data, epochs=epochs, other_args=other_args)
            other_args["load_optimizer_state"] = None

            best_train_lost_hist += deepcopy(hist[0][:best_epoch])
            best_val_loss_hist += deepcopy(hist[1][:best_epoch])

            initialNetwork = cc.copy_network(best_network)

            loss_record.append({'iter': stage,
                                'network': best_network,
                                'num_params': cc.num_params(best_network),
                                'loss': best_loss,
                                'train_loss_hist': best_train_lost_hist,
                                'val_loss_hist': best_val_loss_hist})

            print('\nbest TN:')
            cc.print_ranks(best_network)
            print('number of params:', cc.num_params(best_network))
            print([(r['iter'], r['num_params'], float(r['loss']),
                    float(r['train_loss_hist'][0]),
                    float(r['train_loss_hist'][-1])) for r in loss_record])
            save_record()

    return best_network, best_loss, loss_record
def discrete_optim_template(tensor_list, train_data, loss_fun,
                            val_data=None, other_args=None):
    """
    Template for training a tensor network using discrete optimization
    over TN ranks

    The rank-search step itself is left as a TODO; this function only sets
    up the bookkeeping (initial loss, stop condition, loss recording).

    NOTE(review): this file contains another ``def`` with the same name
    further down; if both live at module level the later one replaces this
    one at import time -- confirm which is intended to be used.

    Args:
        tensor_list: List of tensors encoding the network being trained.
                     It is copied with cc.copy_network, so the caller's
                     list is not the object passed to training
        train_data:  The data used to train the network
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within continuous optimization calls
                     within the discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization
                     (None is treated as an empty dictionary). It is not
                     modified by this function. Options read here:
                        epochs: Number of epochs for continuous
                                optimization (default=10)
                        dprint: Whether to print info from discrete
                                optimization (default=True)
                        dhist:  Whether to return losses from
                                intermediate TNs (default=False)
                     The whole dictionary is forwarded to
                     cc.continuous_optim, so options such as optim, lr,
                     bsize, reps are presumably consumed there.

    Returns:
        best_network: Network with the lowest loss passed to record_loss,
                      or None if no stage was completed
        first_loss:   Initial loss reported by a one-epoch
                      cc.continuous_optim call on the input model, or None
                      if the stop condition already held on entry
        best_loss:    Lowest recorded loss, or None if no stage completed
        loss_record:  Only returned when dhist=True. A tuple with one
                      torch.tensor per completed stage, built from the
                      loss recorded for that stage
    """
    # Check input and initialize local record variables
    other_args = {} if other_args is None else other_args
    epochs = other_args.get('epochs', 10)  # for the TODO search step below
    dprint = other_args.get('dprint', True)
    dhist = other_args.get('dhist', False)
    first_loss, best_loss, best_network = None, None, None
    loss_record = []

    # Function to maybe print, conditioned on `dprint`
    def m_print(s):
        if dprint:
            print(s)

    # Function to record loss information
    def record_loss(new_loss, new_network):
        nonlocal best_loss, best_network

        # Check for best loss
        if best_loss is None or new_loss < best_loss:
            best_loss, best_network = new_loss, new_network

        # Add new loss to our loss record
        if dhist:
            loss_record.append(new_loss)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)

    # Define a function giving the stop condition for the discrete
    # optimization procedure. I'm using a simple example here which could
    # work for greedy or random walk searches, but for other optimization
    # methods this could be trivial
    stop_cond = generate_stop_cond(cc.get_indims(tensor_list))

    # Iteratively increment ranks of tensor_list and train via
    # continuous_optim, at each stage using a search procedure to
    # test out different ranks before choosing just one to increase
    stage = 0
    better_network, better_loss = tensor_list, 1e10
    while not stop_cond(better_network):
        if first_loss is None:
            # Record initial loss of TN model. Use a private copy of the
            # options so the caller's dictionary is not silently changed.
            first_args = dict(other_args)
            first_args["print"] = first_args["hist"] = False
            _, first_loss, _ = cc.continuous_optim(tensor_list, train_data,
                                                   loss_fun, epochs=1,
                                                   val_data=val_data,
                                                   other_args=first_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {first_loss:.3f}")
            continue
        m_print(f"STAGE {stage}")

        # Try out training different network ranks and assign network
        # with best ranks to better_network
        #
        # TODO: This is your part to write! Use new variables for the loss
        #       and network being tried out, with the network being
        #       initialized from better_network. When you've found a better
        #       TN, make sure to update better_network and better_loss to
        #       equal the parameters and loss of that TN.
        #       Until this is written better_network never changes, so the
        #       loop only ends if stop_cond already holds.
        # At some point this code will call the continuous optimization
        # loop, in which case you can use the following command:
        #
        # trained_tn, init_loss, final_loss = cc.continuous_optim(my_tn,
        #                                         train_data, loss_fun,
        #                                         epochs=epochs,
        #                                         val_data=val_data,
        #                                         other_args=other_args)

        # Record the loss associated with the best network from this
        # discrete optimization loop
        record_loss(better_loss, better_network)
        stage += 1

    if dhist:
        loss_record = tuple(torch.tensor(fr) for fr in loss_record)
        return best_network, first_loss, best_loss, loss_record
    else:
        return best_network, first_loss, best_loss
def discrete_optim_template(tensor_list, train_data, loss_fun,
                            val_data=None, other_args=None):
    """
    Train a tensor network using greedy discrete optimization over TN ranks

    A copy of tensor_list is first trained with cc.continuous_optim. Then,
    for at most 50 stages, every index pair (i, j) with i < j of the tensor
    list is tried: the current network is copied, cc.increase_rank is
    applied to that pair, and the candidate is trained. The candidate with
    the lowest loss seen so far becomes the starting point of the next
    stage. The search ends early as soon as the stop condition built by
    generate_stop_cond holds for the current best network.

    Args:
        tensor_list: List of tensors encoding the network being trained.
                     It is copied with cc.copy_network before training
        train_data:  The data used to train the network (target Tensor)
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within continuous optimization calls
                     within the discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization
                     (None is treated as an empty dictionary). The
                     caller's dictionary is not modified; a shallow copy
                     with "print" forced to True and "hist" forced to
                     False is forwarded to every cc.continuous_optim call.
                     Options read here:
                        epochs: Number of epochs for continuous
                                optimization (default=10)
                        dprint: Whether to print info from discrete
                                optimization (default=True)
                        search_epochs: Number of epochs to use to identify
                                the best rank 1 update. If None (or any
                                falsy value), the epochs argument is used
                                and the winner is not retrained.
                                (default=None)
                     Other keys (optim, lr, bsize, reps, ...) are
                     presumably consumed by cc.continuous_optim.
                     NOTE: dhist is accepted but has no effect; the loss
                     history is not returned by this function.

    Returns:
        best_network: Network selected by the greedy search. This is the
                      copied input network when the stop condition holds
                      on entry
        first_loss:   Initial loss of the input model, before any training,
                      as reported by the first cc.continuous_optim call.
                      None when the stop condition holds on entry
        better_loss:  Loss associated with best_network (1e10 if no rank
                      update was ever evaluated)
    """
    # Check input and read the options used by the discrete search
    other_args = {} if other_args is None else other_args
    epochs = other_args.get('epochs', 10)
    dprint = other_args.get('dprint', True)
    search_epochs = other_args.get('search_epochs', None)

    # Upper bound on the number of greedy stages
    max_stages = 50

    # Function to maybe print, conditioned on `dprint`
    def m_print(s):
        if dprint:
            print(s)

    # Private copy of the options: every training call below unpacks three
    # return values, which relies on "hist" being False. Setting it on a
    # copy keeps that guarantee without altering the caller's dictionary.
    train_args = dict(other_args)
    train_args["print"] = True
    train_args["hist"] = False

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)

    # Stop condition for the discrete optimization procedure
    stop_cond = generate_stop_cond(cc.get_indims(tensor_list))

    best_network, better_loss = tensor_list, 1e10
    first_loss = None
    if stop_cond(best_network):
        # Nothing to search: the input network already meets the condition
        return best_network, first_loss, better_loss

    # Train the initial model and record its loss before training
    tensor_list, first_loss, trained_loss = cc.continuous_optim(
        tensor_list, train_data, loss_fun, epochs=epochs,
        val_data=val_data, other_args=train_args)
    m_print("Initial model has TN ranks")
    if dprint:
        cc.print_ranks(tensor_list)
    m_print(f"Initial loss is {trained_loss:.7f}")

    # Lowest candidate loss seen so far; starts at infinity so the first
    # evaluated rank update is always accepted
    best_search_loss = float('inf')
    initialNetwork = cc.copy_network(tensor_list)

    for k in range(max_stages):
        m_print(f"STAGE {k}")
        num_tensors = len(initialNetwork)
        for i in range(num_tensors):
            for j in range(i + 1, num_tensors):
                # Start every candidate from the same point and increase
                # the rank for the pair (i, j); the trailing arguments are
                # presumably the rank increment (1) and the padding noise
                # (1e-6) -- TODO confirm against cc.increase_rank
                currentNetwork = cc.copy_network(initialNetwork)
                currentNetwork = cc.increase_rank(currentNetwork, i, j,
                                                  1, 1e-6)
                print('k = ', k, 'i =', i, 'j = ', j)

                # Solve the continuous optimization part
                # (train_data = target_tensor)
                if search_epochs:
                    # Short run, only used to identify the most promising
                    # rank update; gets its own copy of the options
                    candidate_epochs = search_epochs
                    candidate_args = dict(train_args)
                else:
                    candidate_epochs = epochs
                    candidate_args = train_args
                currentNetwork, _, current_loss = cc.continuous_optim(
                    currentNetwork, train_data, loss_fun,
                    val_data=val_data, epochs=candidate_epochs,
                    other_args=candidate_args)
                m_print(f"\nCurrent loss is {current_loss:.7f}")

                if best_search_loss > current_loss:
                    best_search_loss = current_loss
                    best_network = currentNetwork
                    better_loss = current_loss
                    print('best rank update so far:', i, j)

        # The search runs were short: train the selected network fully
        if search_epochs:
            print('training best network until max_epochs/convergence..')
            best_network, _, better_loss = cc.continuous_optim(
                best_network, train_data, loss_fun, val_data=val_data,
                epochs=epochs, other_args=train_args)

        # Continue the greedy search from the new best point
        initialNetwork = cc.copy_network(best_network)
        print('best TN:')
        cc.print_ranks(best_network)
        print('number of params:', cc.num_params(best_network))

        # TODO (from original author): need to add param=-1
        stop_cond = generate_stop_cond(cc.get_indims(best_network))
        # stop_cond is a predicate on a network: it has to be called.
        # Comparing the function object itself to True never triggers.
        if stop_cond(best_network):
            break

    return best_network, first_loss, better_loss