def training(tensor_list, initial_epochs, train_data,
            loss_fun, val_data, epochs, other_args):
    """
    Run cc.continuous_optim with an optional "did we make progress?" warm-up.

    If [initial_epochs] is None this is a plain call to cc.continuous_optim
    and its result is returned unchanged. Otherwise a copy of [tensor_list]
    is first trained for [initial_epochs] epochs; as long as the first loss
    recorded in hist[0] is smaller than the last one (i.e. the loss did not
    improve), the learning rate is divided by 2 and the warm-up is restarted
    from a fresh copy of [tensor_list]. Training then continues for the
    remaining epochs - initial_epochs epochs.

    Args:
        tensor_list:    Network to train (copied before the warm-up phase)
        initial_epochs: Length of the warm-up phase, or None to disable it
        train_data, loss_fun, val_data, epochs, other_args:
                        Forwarded to cc.continuous_optim. other_args itself
                        is deep-copied, except for the dict stored under
                        "optimizer_state", which is shared with the caller.

    Returns:
        The return value of cc.continuous_optim when initial_epochs is None.
        Otherwise [network, first_loss, current_loss, hist] if "hist" is a
        key of other_args, else [network, first_loss, current_loss]; the
        losses and hist come from the final (post warm-up) call.
    """
    if initial_epochs is None:
        # No warm-up requested: forward the caller's arguments untouched.
        # (The working copies `currentNetwork` / `args` do not exist yet.)
        return cc.continuous_optim(tensor_list, train_data,
            loss_fun, val_data=val_data, epochs=epochs, other_args=other_args)

    # Fallback used when halving a learning rate the caller never set.
    # Mirrors the default documented for other_args["lr"] in this file --
    # TODO confirm it matches cc.continuous_optim's own default.
    default_lr = 1e-3

    args = deepcopy(other_args)
    args["hist"] = True  # the loss history is needed to detect progress
    # Deliberately taken from other_args (not the deep copy) so the caller's
    # dict object is the one handed to cc.continuous_optim.
    current_network_optimizer_state = other_args.get("optimizer_state", {})
    args["save_optimizer_state"] = True
    args["optimizer_state"] = current_network_optimizer_state
    currentNetwork = cc.copy_network(tensor_list)
    remaining_epochs = epochs - initial_epochs if epochs else None

    hist = None
    # hist[0][0] < hist[0][-1]: the loss ended the warm-up higher than it
    # started, so retry from scratch with a smaller step size.
    while hist is None or (hist[0][0] < hist[0][-1]):
        if hist is not None:
            loaded_state = args.get("load_optimizer_state")
            if loaded_state:
                # The learning rate lives inside the optimizer state that
                # will be loaded, so it has to be patched there.
                param_group = loaded_state["optimizer_state"]['param_groups'][0]
                param_group['lr'] /= 2
                lr = param_group['lr']
            else:
                args["lr"] = args.get("lr", default_lr) / 2
                lr = args["lr"]
            print(f"\n[!] No progress in first {initial_epochs} epochs, starting again with smaller learning rate ({lr})")
            currentNetwork = cc.copy_network(tensor_list)
        [currentNetwork, first_loss, current_loss, hist] = cc.continuous_optim(currentNetwork, train_data,
            loss_fun, val_data=val_data, epochs=initial_epochs, other_args=args)

    # Continue from the optimizer state registered for the warm-up phase.
    args["load_optimizer_state"] = current_network_optimizer_state
    [currentNetwork, first_loss, current_loss, hist] = cc.continuous_optim(currentNetwork, train_data,
            loss_fun, val_data=val_data, epochs=remaining_epochs, other_args=args)

    if "hist" in other_args:
        return [currentNetwork, first_loss, current_loss, hist]
    return [currentNetwork, first_loss, current_loss]
Example #2
0
 def plot_image(tensor_list):
     """
     Display the dense tensor encoded by a tensor network as an image.

     The network is copied, contracted with cc.wire_network(...,
     give_dense=True), converted to a NumPy array, reshaped to `im_size`
     (using the memory layout `order`) and divided by 255 before being
     passed to plt.imshow. `im_size` and `order` are read from the
     enclosing scope.
     """
     # Contract a copy so the caller's tensors are never handed to
     # wire_network directly.
     network_copy = cc.copy_network(tensor_list)
     dense = cc.wire_network(network_copy, give_dense=True)
     pixels = dense.detach().numpy().reshape(im_size, order=order)
     # presumably rescales 0-255 intensities to [0, 1] -- TODO confirm
     plt.imshow(pixels / 255)
Example #3
0
def randomwalk_optim(tensor_list,
                     train_data,
                     loss_fun,
                     val_data=None,
                     other_args=None,
                     max_iter=None):
    """
    Train a tensor network using a random-walk search over TN ranks.

    After an initial continuous optimization, each iteration picks two
    distinct cores at random, increases the rank between them by
    rank_increment, runs a short "search" training (restricted to the newly
    added slices unless is_reg is set) and then trains the whole network.
    The update is always accepted; no comparison with the previous loss is
    made.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network(target Tensor)
        loss_fun:    Scalar-valued loss function of the type 
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used to
                     for early stopping within continuous optimization
                     calls within discrete optimization loop
        max_iter:    Maximum number of random-walk iterations. Only used
                     when other_args has no 'max_iter' entry (default=10)
        other_args:  Dictionary of other arguments for the optimization.
                     NOTE: a dictionary passed by the caller is modified in
                     place (keys 'hist', 'stop_condition',
                     'load_optimizer_state', 'save_optimizer_state',
                     'optimizer_state', 'grad_masking_function').
                     Some options below (feel free to add more)
                        epochs: Number of epochs for 
                                continuous optimization     (default=10)
                        max_iter: Maximum number of iterations
                                  for random walk search    (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        reps:   Number of times to repeat 
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        search_epochs: Number of epochs to use to identify the
                                best rank 1 update. If None, the epochs argument
                                is used.                    (default=None)
                        loss_threshold: if loss gets below this threshold, 
                            discrete optimization is stopped
                                                            (default=1e-5)
                        initial_epochs: Number of epochs after which the 
                            learning rate is reduced and optimization is restarted
                            if there is no improvement in the loss.
                                                            (default=None)
                        rank_increment: how much a rank should be increased at
                            each discrete iteration         (default=1)
                        stop_on_plateau: a dictionary containing keys
                                mode  (min/max)
                                patience
                                threshold
                            used to stop continuous optimization when a
                            plateau is detected             (default=None)
                        is_reg: if true, gradients are not masked during
                            the search phase                (default=False)
                        gradient_hook: this is a hack, please ignore...

    Returns:
        best_network: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm (None if the loop never ran).
        best_loss:   The loss reported by the last training call for
                     best_network
        loss_record: History of the search, a list of dictionaries with keys
                           iter,num_params,network,loss,train_loss_hist,val_loss_hist
                     where
                     iter: iteration of the discrete optimization
                     num_params: number of parameters of the network for this iteration
                     network: list of tensors for the network for this iteration
                     loss: loss achieved by the network in this iteration
                     train_loss_hist / val_loss_hist: loss histories of the
                        continuous optimization for this iteration (search
                        phase followed by full training), each truncated
                        at the reported best epoch
    """
    # A shared `dict()` default would carry the keys written below over to
    # later calls, so create a fresh dictionary per call instead.
    if other_args is None:
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    # The other_args entry wins (historical behaviour); the keyword argument
    # is the fallback, then the documented default of 10.
    if 'max_iter' in other_args:
        max_iter = other_args['max_iter']
    elif max_iter is None:
        max_iter = 10
    dprint = other_args.get('dprint', True)
    search_epochs = other_args.get('search_epochs')
    if search_epochs is None:  # documented: None means "use epochs"
        search_epochs = epochs
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs')
    rank_increment = other_args.get('rank_increment', 1)
    gradient_hook = other_args.get('gradient_hook')
    is_reg = other_args.get('is_reg', False)

    stop_on_plateau = other_args.get('stop_on_plateau')
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        # Only the training loss is fed to the plateau detector.
        other_args[
            "stop_condition"] = lambda train_loss, val_loss: detect_plateau(
                train_loss)

    other_args['hist'] = True  # we always track the history of losses

    # Function to maybe print, conditioned on `dprint`
    def m_print(s):
        if dprint:
            print(s)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{
        'iter': -1,
        'network': tensor_list,
        'num_params': cc.num_params(tensor_list),
        'train_loss_hist': [0],
        'val_loss_hist': [0]
    }]

    # Stop condition for the discrete optimization procedure
    def stop_cond(loss):
        return loss < loss_threshold

    # Stage 0 is the initial continuous optimization; every later stage
    # performs one random rank increment followed by training.
    stage = -1
    best_loss, best_network, best_network_optimizer_state = float(
        "inf"), None, None

    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if stage == 0:  # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list,
                initial_epochs,
                train_data,
                loss_fun,
                epochs=epochs,
                val_data=val_data,
                other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint: cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")

            initialNetwork = cc.copy_network(tensor_list)
            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': hist[0][:best_epoch],
                'val_loss_hist': hist[1][:best_epoch]
            })

        else:
            m_print(
                f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n"
            )

            # Two distinct core indices chosen uniformly at random
            [i, j] = torch.randperm(len(tensor_list))[:2].tolist()
            currentNetwork = cc.copy_network(initialNetwork)
            #increase rank along a chosen dimension
            currentNetwork = cc.increase_rank(currentNetwork, i, j,
                                              rank_increment, 1e-6)
            currentNetwork = cc.make_trainable(currentNetwork)
            print('\ntesting rank increment for i =', i, 'j = ', j)

            ### Search optimization phase

            # function to zero out the gradient of all entries except for the new slices
            def grad_masking_function(tensor_list):
                # i, j and currentNetwork are read from the enclosing scope.
                # The permutation moves the increased mode to the front so
                # that [:-1] selects every slice except the last one.
                for k, core in enumerate(tensor_list):
                    if k == i:
                        core.grad.permute(
                            [j] + list(range(0, j)) +
                            list(range(j + 1, len(currentNetwork))))[:-1, :,
                                                                     ...] *= 0
                    elif k == j:
                        core.grad.permute(
                            [i] + list(range(0, i)) +
                            list(range(i + 1, len(currentNetwork))))[:-1, :,
                                                                     ...] *= 0
                    else:
                        core.grad *= 0

            # we first optimize only the new slices for a few epochs
            print("optimize new slices for a few epochs")
            search_args = dict(other_args)
            current_network_optimizer_state = {}
            search_args["save_optimizer_state"] = True
            search_args["optimizer_state"] = current_network_optimizer_state
            if not is_reg:
                search_args["grad_masking_function"] = grad_masking_function
            if stop_on_plateau:
                detect_plateau._reset()
            [currentNetwork, first_loss, current_loss, best_epoch,
             hist] = training(currentNetwork,
                              initial_epochs,
                              train_data,
                              loss_fun,
                              val_data=val_data,
                              epochs=search_epochs,
                              other_args=search_args)
            best_train_lost_hist = deepcopy(hist[0][:best_epoch])
            best_val_loss_hist = deepcopy(hist[1][:best_epoch])

            # NOTE: a second search phase that optimized all parameters for
            # a few epochs used to follow here; it is currently disabled.

            # A random walk always accepts the candidate.
            best_network = currentNetwork
            best_loss = current_loss
            best_network_optimizer_state = deepcopy(
                current_network_optimizer_state)
            print('-> best rank update so far:', i, j)

            # train network to convergence
            print('\ntraining best network until max_epochs/convergence...')
            other_args["load_optimizer_state"] = best_network_optimizer_state
            current_network_optimizer_state = {}
            other_args["save_optimizer_state"] = True
            other_args["optimizer_state"] = current_network_optimizer_state

            if gradient_hook:  # A bit of a hack only used for tensor completion - ignore
                other_args["grad_masking_function"] = gradient_hook
            if stop_on_plateau:
                detect_plateau._reset()

            [best_network, first_loss, best_loss, best_epoch,
             hist] = training(best_network,
                              initial_epochs,
                              train_data,
                              loss_fun,
                              val_data=val_data,
                              epochs=epochs,
                              other_args=other_args)

            if gradient_hook:
                other_args["grad_masking_function"] = None

            other_args["load_optimizer_state"] = None
            best_train_lost_hist += deepcopy(hist[0][:best_epoch])
            best_val_loss_hist += deepcopy(hist[1][:best_epoch])

            # The next iteration starts from the network just trained.
            initialNetwork = cc.copy_network(best_network)

            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': best_train_lost_hist,
                'val_loss_hist': best_val_loss_hist
            })

            print('\nbest TN:')
            cc.print_ranks(best_network)
            print('number of params:', cc.num_params(best_network))
            print([(r['iter'], r['num_params'], float(r['loss']),
                    float(r['train_loss_hist'][0]),
                    float(r['train_loss_hist'][-1])) for r in loss_record])

    return best_network, best_loss, loss_record
def discrete_optim_template(tensor_list,
                            train_data,
                            loss_fun,
                            val_data=None,
                            other_args=None):
    """
    Train a tensor network using a greedy discrete optimization over TN ranks.

    After an initial continuous optimization, each stage tries a rank
    increment of 1 for every pair of cores (i, j) with i < j, trains each
    candidate with cc.continuous_optim, and continues from the candidate
    with the lowest loss of that stage. The loop ends when the predicate
    built by generate_stop_cond accepts the current network.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network(target Tensor)
        loss_fun:    Scalar-valued loss function of the type 
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used to
                     for early stopping within continuous optimization
                     calls within discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization, 
                     with some options below (feel free to add more)
                        epochs: Number of epochs for 
                                continuous optimization     (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        reps:   Number of times to repeat 
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        dhist:  Whether to record losses
                                from intermediate TNs       (default=False)

    Returns:
        better_network: List of tensors with same length as tensor_list:
                     the best candidate found in the last stage (the copy
                     of the input network if no stage produced a candidate).
        first_loss:  Initial loss reported by cc.continuous_optim for the
                     input model, before the discrete search
        better_loss: The loss of better_network (1e10 if no candidate was
                     ever trained)

        The per-stage losses are collected in a local loss_record when
        dhist is set, but are not returned yet.
    """
    if other_args is None:
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    dprint = other_args.get('dprint', True)
    dhist = other_args.get('dhist', False)
    first_loss = None
    loss_record = []  # one entry per stage; kept across stages

    # Function to maybe print, conditioned on `dprint`
    def m_print(s):
        if dprint:
            print(s)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)

    # Stop condition for the discrete optimization procedure
    stop_cond = generate_stop_cond(cc.get_indims(tensor_list))

    # Iteratively increment ranks of tensor_list and train via
    # continuous_optim, at each stage using a search procedure to
    # test out different ranks before choosing just one to increase
    stage = 0
    better_network, better_loss = tensor_list, 1e10
    initialNetwork = cc.copy_network(tensor_list)
    num_cores = len(tensor_list)
    while not stop_cond(better_network):
        if first_loss is None:
            # Record initial loss of TN model
            first_args = other_args.copy()
            first_args["print"] = first_args["hist"] = False
            _, first_loss, _ = cc.continuous_optim(tensor_list,
                                                   train_data,
                                                   loss_fun,
                                                   epochs=1,
                                                   val_data=val_data,
                                                   other_args=first_args)
            m_print("Initial model has TN ranks")
            if dprint: cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {first_loss:.3f}")
            m_print("Performing initial optimization...")
            first_args["print"] = True
            initialNetwork, _, _ = cc.continuous_optim(tensor_list,
                                                       train_data,
                                                       loss_fun,
                                                       epochs=2000,
                                                       val_data=val_data,
                                                       other_args=first_args)
            continue
        m_print(f"STAGE {stage}")
        stage += 1

        # Greedy search: try every (i, j) rank increment starting from
        # initialNetwork and remember the candidate with the lowest loss.
        prev_loss = float('inf')
        for i in range(num_cores):
            for j in range(i + 1, num_cores):
                m_print(f"Testing i={i}, j={j}")

                # reset currentNetwork to continue the greedy search from
                # the same starting point
                currentNetwork = cc.copy_network(initialNetwork)
                #increase rank along a chosen dimension
                currentNetwork = cc.increase_rank(currentNetwork, i, j, 1,
                                                  1e-6)

                # NOTE(review): this rebinds the stop condition that also
                # guards the outer while loop -- confirm this is intended.
                stop_cond = generate_stop_cond(
                    cc.get_indims(currentNetwork))
                if stop_cond(currentNetwork):
                    # Candidate already satisfies the stop condition.
                    # NOTE: this only leaves the inner (j) loop.
                    break

                # solve continuous optimization part
                # (train_data = target tensor). The candidate's starting
                # loss gets its own name so first_loss keeps the initial
                # loss of the input model.
                [currentNetwork, init_loss,
                 current_loss] = cc.continuous_optim(currentNetwork,
                                                     train_data,
                                                     loss_fun,
                                                     epochs=epochs,
                                                     val_data=val_data,
                                                     other_args=other_args)
                if prev_loss > current_loss:
                    prev_loss = current_loss
                    better_network = currentNetwork
                    better_loss = current_loss
                    m_print("BEST CONFIG")
                else:
                    m_print("Not best config")
                m_print(f"{init_loss:.3f} -> {current_loss:.3f}\n")

        #update current point to the new point (i.e. better_network) that gave lower loss
        initialNetwork = cc.copy_network(better_network)
        if dhist:
            loss_record.append(better_loss)

    # better_network / better_loss are the pair maintained by the search.
    return better_network, first_loss, better_loss  #, loss_record
def greedy_optim(tensor_list, train_data, loss_fun, 
                            val_data=None, other_args=dict(),max_iter=None):
    """
    Train a tensor network using discrete optimization over TN ranks
    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network(target Tensor)
        loss_fun:    Scalar-valued loss function of the type 
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used to
                     for early stopping within continuous optimization
                     calls within discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization, 
                     with some options below (feel free to add more)
                        epochs: Number of epochs for 
                                continuous optimization     (default=10)
                        max_iter: Maximum number of iterations
                                  for greedy search           (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        reps:   Number of times to repeat 
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        dhist:  Whether to return losses
                                from intermediate TNs       (default=False)
                        search_epochs: Number of epochs to use to identify the
                                best rank 1 update. If None, the epochs argument
                                is used.                    (default=None)
                        loss_threshold: if loss gets below this threshold, 
                        discrete optimization is stopped
                                                            (default=1e-5)
                        initial_epochs: Number of epochs after which the 
                        learning rate is reduced and optimization is restarted
                        if there is no improvement in the loss.
                                                            (default=None)
    
    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm. The TN ranks of better_list will be larger
                     than those of tensor_list.
        first_loss:  Initial loss of the model on the validation set, 
                     before any training. If no val set is provided, the
                     first training loss is instead returned
        best_loss:   The value of the validation/training loss for the
                     model output as better_list
        loss_record: If dhist=True in other_args, all values of best_loss
                     associated with intermediate optimized TNs will be
                     returned as a PyTorch vector, with loss_record[0]
                     giving the initial loss of the model, and
                     loss_record[-1] equal to best_loss value returned
                     by discrete_optim.
    """
    # Check input and initialize local record variables
    epochs  = other_args['epochs'] if 'epochs' in other_args else 10
    max_iter  = other_args['max_iter'] if 'max_iter' in other_args else 10
    dprint  = other_args['dprint'] if 'dprint' in other_args else True
    cprint  = other_args['cprint'] if 'cprint' in other_args else True
    dhist  = other_args['dhist']  if 'dhist'  in other_args else False
    search_epochs  = other_args['search_epochs']  if 'search_epochs'  in other_args else None
    loss_threshold  = other_args['loss_threshold']  if 'loss_threshold'  in other_args else 1e-5
    initial_epochs  = other_args['initial_epochs'] if 'initial_epochs' in other_args else None
    stop_cond = lambda loss: loss < loss_threshold


    if dhist: loss_record = []    # (train_record, val_record)

    # Function to maybe print, conditioned on `dprint`
    m_print = lambda s: print(s) if dprint else None


    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)


    # Iteratively increment ranks of tensor_list and train via
    # continuous_optim, at each stage using a search procedure to 
    # test out different ranks before choosing just one to increase
    stage = 0
    loss_record, best_loss, best_network = [], np.infty, None
    


    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if best_loss is np.infty: # first continuous optimization
            tensor_list, first_loss, best_loss = training(
                tensor_list, initial_epochs, train_data, loss_fun, epochs=epochs, 
                val_data=val_data,other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint: cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")

            initialNetwork = cc.copy_network(tensor_list) 
            best_network = cc.copy_network(tensor_list) 
            loss_record += [first_loss,best_loss]

            


        m_print(f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n")  
        best_search_loss = best_loss

        for i in range(len(initialNetwork)):
            for j in range(i+1, len(initialNetwork)):
                currentNetwork = cc.copy_network(initialNetwork)
                #increase rank along a chosen dimension
                currentNetwork = cc.increase_rank(currentNetwork,i, j, 1, 1e-6)



                print('\ntesting rank increment for i =', i, 'j = ', j)
                if search_epochs: # we d only a few epochs to identify the most promising rank update

                    # function to zero out the gradient of all entries except for the new slices
                    def grad_masking_function(tensor_list):
                        nonlocal i,j
                        for k in range(len(tensor_list)):
                            if k == i:
                                tensor_list[i].grad.permute([j]+list(range(0,j))+list(range(j+1,len(currentNetwork))))[:-1,:,...] *= 0
                            elif k == j:
                                tensor_list[j].grad.permute([i]+list(range(0,i))+list(range(i+1,len(currentNetwork))))[:-1,:,...] *= 0
                            else:
                                tensor_list[k].grad *= 0

                    # we first optimize only the new slices for a few epochs
                    print("optimize new slices for a few epochs")
                    search_args = dict(other_args)
                    search_args["hist"] = True
                    current_network_optimizer_state = {}
                    search_args["save_optimizer_state"] = True
                    search_args["optimizer_state"] = current_network_optimizer_state
                    search_args["grad_masking_function"] = grad_masking_function
                    [currentNetwork, first_loss, current_loss, hist] = training(currentNetwork, initial_epochs, train_data, 
                        loss_fun, val_data=val_data, epochs=search_epochs, other_args=search_args)
                    first_loss = hist[0][0]

                    # We then optimize all parameters for a few epochs
                    print("\noptimize all parameters for a few epochs")
                    search_args["grad_masking_function"] = None
                    search_args["load_optimizer_state"] = dict(current_network_optimizer_state)
                    [currentNetwork, first_loss, current_loss, hist] = training(currentNetwork, initial_epochs, train_data, 
                        loss_fun, val_data=val_data, epochs=search_epochs , 
                        other_args=search_args)
                    search_args["load_optimizer_state"] = None

                else: # we fully optimize the network in the search phase
                    [currentNetwork, first_loss, current_loss] = cc.continuous_optim(currentNetwork, train_data, 
                        loss_fun, val_data=val_data, epochs=epochs, 
                        other_args=other_args)
                

                m_print(f"\nCurrent loss is {current_loss:.7f}    Best loss from previous discrete optim is {best_loss}")
                if best_search_loss > current_loss:
                    best_search_loss = current_loss
                    best_network = currentNetwork
                    best_network_optimizer_state = deepcopy(current_network_optimizer_state)
                    print('-> best rank update so far:', i,j)
        

        best_loss = best_search_loss
        # train network to convergence for the best rank increment (if search_epochs is set, 
        # otherwise the best network is already trained to convergence / max_epochs)
        if search_epochs:
            print('\ntraining best network until max_epochs/convergence...')

            other_args["load_optimizer_state"] = best_network_optimizer_state
            current_network_optimizer_state = {}
            other_args["save_optimizer_state"] = True
            other_args["optimizer_state"] = current_network_optimizer_state
            [best_network, first_loss, best_loss] = training(best_network, initial_epochs, train_data, 
                    loss_fun, val_data=val_data, epochs=epochs, 
                    other_args=other_args)
            other_args["load_optimizer_state"] = None

        initialNetwork  = cc.copy_network(best_network)
        loss_record.append((stage, cc.num_params(best_network), float(best_loss)))
        print('\nbest TN:')
        cc.print_ranks(best_network)
        print('number of params:',cc.num_params(best_network))
        print(loss_record)
    return best_network, first_loss, best_loss 
Example #6
0
def randomsearch_optim(tensor_list,
                       train_data,
                       loss_fun,
                       val_data=None,
                       other_args=None):
    """
    Train a tensor network using random search over TN ranks

    The input network is first trained as is. Then, at every following
    iteration, a fresh network with randomly drawn ranks is created and
    trained from scratch; the network reaching the lowest loss is kept.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network
        loss_fun:    Scalar-valued loss function of the type 
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used to
                     for early stopping within continuous optimization
                     calls within discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization, 
                     with some options below (feel free to add more).
                     Note that this dictionary is modified in place (the
                     'hist' and 'stop_condition' entries are written).
                     If None, an empty dictionary is used.
                        epochs: Number of epochs for 
                                continuous optimization     (default=10)
                        max_iter: Maximum number of iterations
                                  for random search         (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        max_rank: Maximum rank to search for  (default=7)
                        max_params: Passed to the random network
                                generator as `max_params` (default=10000)
                        reps:   Number of times to repeat 
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        loss_threshold: if loss gets below this threshold, 
                            discrete optimization is stopped
                                                            (default=1e-5)
                        initial_epochs: Number of epochs after which the 
                            learning rate is reduced and optimization is restarted
                            if there is no improvement in the loss.
                                                            (default=None)
                        stop_on_plateau: a dictionary containing keys
                                mode  (min/max)
                                patience
                                threshold
                            used to stop continuous optimization when plateau
                            is detected                     (default=None)
    
    Returns:
        best_network: List of tensors for the network with the lowest loss 
                     among the trained input network and all the randomly
                     drawn networks.
        best_loss:   The loss reported by the training routine for
                     best_network
        loss_record: History of the search (always returned). This is a 
                     list of dictionaries with keys
                           iter,num_params,network,loss,train_loss_hist
                     where
                     iter: iteration of the discrete optimization (-1 for the
                        untrained input network, 0 for the trained input network)
                     num_params: number of parameters of the network trained 
                        at this iteration
                     network: list of tensors for the network trained at 
                        this iteration
                     loss: loss achieved by that network
                     train_loss_hist: history of training losses of the 
                        continuous optimization for this iteration, truncated 
                        at the best epoch
                     TODO: add val_lost_hist to the list of keys
    """
    # A shared mutable default would leak 'hist'/'stop_condition' from one
    # call to the next, so a fresh dictionary is created per call instead.
    if other_args is None:
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    max_iter = other_args.get('max_iter', 10)
    max_rank = other_args.get('max_rank', 7)
    max_params = other_args.get('max_params', 10000)
    dprint = other_args.get('dprint', True)
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs')
    # Keep track of continuous optimization history
    other_args['hist'] = True

    stop_on_plateau = other_args.get('stop_on_plateau')
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        other_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))

    # Function to maybe print, conditioned on `dprint`
    m_print = lambda s: print(s) if dprint else None

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{
        'iter': -1,
        'network': tensor_list,
        'num_params': cc.num_params(tensor_list),
        'train_loss_hist': [0]
    }]

    # Stop condition for the discrete optimization procedure
    stop_cond = lambda loss: loss < loss_threshold

    input_dims = cc.get_indims(tensor_list)
    n_cores = len(input_dims)
    # One rank per (unordered) pair of cores
    n_edges = n_cores * (n_cores - 1) // 2

    # `stage` starts at -1 so that stage 0 is the training of the input
    # network; the random search then runs for stages 1..max_iter.
    stage = -1
    # np.inf rather than np.infty: the latter alias was removed in NumPy 2.0
    best_loss, best_network = np.inf, None

    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if stage == 0:  # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list,
                initial_epochs,
                train_data,
                loss_fun,
                epochs=epochs,
                val_data=val_data,
                other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint: cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")

            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': hist[0][:best_epoch]
            })

        else:
            m_print(
                f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n"
            )
            # Create new TN random ranks, drawn uniformly in [1, max_rank]
            ranks = torch.randint(low=1, high=max_rank + 1,
                                  size=(n_edges, 1)).view(-1, ).tolist()
            rank_list = _make_ranks(ranks)
            currentNetwork = _limit_random_tn(input_dims,
                                              rank=rank_list,
                                              max_params=max_params)
            currentNetwork = cc.make_trainable(currentNetwork)
            # The plateau detector is shared by all trainings: clear its
            # state before training a new network
            if stop_on_plateau:
                detect_plateau._reset()
            currentNetwork, _, current_loss, best_epoch, hist = training(
                currentNetwork,
                initial_epochs,
                train_data,
                loss_fun,
                val_data=val_data,
                epochs=epochs,
                other_args=other_args)
            loss_record.append({
                'iter': stage,
                'network': currentNetwork,
                'num_params': cc.num_params(currentNetwork),
                'loss': current_loss,
                'train_loss_hist': hist[0][:best_epoch]
            })
            if best_loss > current_loss:
                best_network = currentNetwork
                best_loss = current_loss

            print('\nbest TN:')
            cc.print_ranks(best_network)
            print('number of params:', cc.num_params(best_network))
            print([(r['iter'], r['num_params'], float(r['loss']),
                    float(r['train_loss_hist'][0]),
                    float(r['train_loss_hist'][-1])) for r in loss_record])

    return best_network, best_loss, loss_record
Example #7
0
def greedy_optim(tensor_list, train_data, loss_fun, find_best_edge,
                val_data=None, other_args=None, max_iter=None):
    """
    Train a tensor network using discrete optimization over TN ranks

    The input network is first trained as is. Then, at each iteration,
    `find_best_edge` selects the edge whose rank is incremented and the
    resulting network is trained further before the next iteration.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network(target Tensor)
        loss_fun:    Scalar-valued loss function of the type 
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used to
                     for early stopping within continuous optimization
                     calls within discrete optimization loop
        find_best_edge: function identifying the most promising edge and 
            returning a new network with incremented edge. This function
            must accept the following keyword arguments:
                initialNetwork
                train_data
                val_data
                loss_fun
                rank_increment
                training_args
                prev_best_loss
                search_epochs (specific to current searching method)
                padding_noise
                is_reg        (hack, ignore)
            and must return the 5-tuple
                (network, loss, optimizer_state, train_loss_hist, val_loss_hist)

        other_args:  Dictionary of other arguments for the optimization, 
                     with some options below (feel free to add more).
                     Note that this dictionary is modified in place (e.g. the
                     'hist', 'stop_condition' and optimizer state entries 
                     are written). If None, an empty dictionary is used.
                        epochs: Number of epochs for 
                                continuous optimization     (default=10)
                        max_iter: Maximum number of iterations
                                  for greedy search           (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        reps:   Number of times to repeat 
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        search_epochs: Number of epochs to use to identify the
                                best rank 1 update. If None, the epochs argument
                                is used.                    (default=None)
                        loss_threshold: if loss gets below this threshold, 
                            discrete optimization is stopped
                                                            (default=1e-5)
                        initial_epochs: Number of epochs after which the 
                            learning rate is reduced and optimization is restarted
                            if there is no improvement in the loss.
                                                            (default=None)
                        rank_increment: how much should a rank be increase at
                            each discrete iterations        (default=1)
                        stop_on_plateau: a dictionary containing keys
                                mode  (min/max)
                                patience
                                threshold
                            used to stop continuous optimization when plateau
                            is detected                     (default=None)
                        filename: pickle filename to save the loss_record after each
                        continuous optimization.            (default=None)
        max_iter:    Fallback for the maximum number of greedy iterations, 
                     used only when other_args has no 'max_iter' entry 
                     (default=None, i.e. 10 iterations)

    Returns:
        best_network: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm.
        best_loss:   The loss reported by the training routine for
                     best_network
        loss_record: History of the search (always returned). This is a 
                     list of dictionaries with keys
                           iter,num_params,network,loss,train_loss_hist,val_loss_hist
                     where
                     iter: iteration of the discrete optimization (-1 for the
                        untrained input network, 0 for the trained input network)
                     num_params: number of parameters of best network for this iteration
                     network: list of tensors for the best network for this iteration
                     loss: loss achieved by the best network in this iteration
                     train_loss_hist, val_loss_hist: histories of losses for the 
                        continuous optimization for this iteration starting from 
                        previous best_network to the epoch where the new best 
                        network was found

    Raises:
        AssertionError: if initial_epochs is set but is not smaller than 
                     search_epochs and epochs.
    """
    # A shared mutable default would leak optimizer states and stop
    # conditions from one call to the next: use a fresh dict per call.
    if other_args is None:
        other_args = {}

    # Check input and initialize local record variables
    epochs = other_args.get('epochs', 10)
    # other_args['max_iter'] takes precedence (historical behaviour); the
    # keyword argument is only a fallback instead of being silently ignored.
    max_iter = other_args.get('max_iter', 10 if max_iter is None else max_iter)
    dprint = other_args.get('dprint', True)
    search_epochs = other_args.get('search_epochs')
    if search_epochs is None:  # documented: fall back to `epochs`
        search_epochs = epochs
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs')
    rank_increment = other_args.get('rank_increment', 1)
    filename = other_args.get('filename')

    # initial_epochs=None (the default) disables the restart heuristic, in
    # which case there is nothing to compare (None < int is a TypeError).
    if initial_epochs is not None:
        assert (search_epochs is None or initial_epochs < search_epochs) and (not epochs or initial_epochs < epochs), "initial_epochs must be smaller than search_epochs and epochs"

    stop_on_plateau = other_args.get('stop_on_plateau')
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        other_args["stop_condition"] = lambda train_loss, val_loss: detect_plateau(train_loss)

    other_args['hist'] = True # we always track the history of losses

    # Function to maybe print, conditioned on `dprint`
    m_print = lambda s: print(s) if dprint else None

    def save_record():
        # Overwrite the pickle file with the current loss_record, if requested
        if filename:
            with open(filename, "wb") as f:
                pickle.dump(loss_record, f)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{'iter':-1,'network':tensor_list,
            'num_params':cc.num_params(tensor_list),
            'train_loss_hist':[0],
            'val_loss_hist':[0]}]  

    # Stop condition for the discrete optimization procedure
    stop_cond = lambda loss: loss < loss_threshold

    # `stage` starts at -1 so that stage 0 is the training of the input
    # network; the greedy search then runs for stages 1..max_iter.
    stage = -1
    # np.inf rather than np.infty: the latter alias was removed in NumPy 2.0
    best_loss, best_network, best_network_optimizer_state = np.inf, None, None

    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if stage == 0: # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(tensor_list, initial_epochs, 
                                    train_data, loss_fun, epochs=epochs, 
                                    val_data=val_data,other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint: cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")

            initialNetwork = cc.copy_network(tensor_list) 
            best_network = cc.copy_network(tensor_list) 
            loss_record[0]["loss"] = first_loss
            loss_record.append({'iter':stage,
                'network':best_network,
                'num_params':cc.num_params(best_network),
                'loss':best_loss,
                'train_loss_hist':hist[0][:best_epoch],
                'val_loss_hist':hist[1][:best_epoch]})

            save_record()

        else:
            m_print(f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n")  

            # NOTE(review): other_args['is_reg'] is not forwarded here;
            # is_reg is always False for the search (unchanged behaviour).
            best_network,best_loss, best_network_optimizer_state,best_train_lost_hist,best_val_loss_hist = \
                find_best_edge(initialNetwork=initialNetwork,
                                    train_data=train_data,
                                    val_data=val_data,
                                    loss_fun=loss_fun,
                                    rank_increment=rank_increment,
                                    training_args=other_args,
                                    prev_best_loss=best_loss,
                                    search_epochs=search_epochs,
                                    padding_noise=1e-6,
                                    is_reg=False)

            # train network to convergence for the best rank increment,
            # resuming from the optimizer state found during the search
            print('\ntraining best network until max_epochs/convergence...')
            other_args["load_optimizer_state"] = best_network_optimizer_state
            current_network_optimizer_state = {}
            other_args["save_optimizer_state"] = True
            other_args["optimizer_state"] = current_network_optimizer_state

            if stop_on_plateau:
                detect_plateau._reset()
            [best_network, _, best_loss, best_epoch, hist] = training(best_network, initial_epochs, train_data, 
                    loss_fun, val_data=val_data, epochs=epochs, 
                    other_args=other_args)

            other_args["load_optimizer_state"] = None
            # Histories of the search phase followed by the full training
            best_train_lost_hist += deepcopy(hist[0][:best_epoch])
            best_val_loss_hist += deepcopy(hist[1][:best_epoch])

            # The next iteration starts from the network just trained
            initialNetwork  = cc.copy_network(best_network)

            loss_record.append({'iter':stage,
                'network':best_network,
                'num_params':cc.num_params(best_network),
                'loss':best_loss,
                'train_loss_hist':best_train_lost_hist,
                'val_loss_hist':best_val_loss_hist})

            print('\nbest TN:')
            cc.print_ranks(best_network)
            print('number of params:',cc.num_params(best_network))
            print([(r['iter'],r['num_params'],float(r['loss']),float(r['train_loss_hist'][0]),float(r['train_loss_hist'][-1])) for r in loss_record])

            save_record()

    return best_network, best_loss, loss_record
Example #8
0
def greedy_find_best_edge(initialNetwork,train_data,val_data,loss_fun,rank_increment,training_args,prev_best_loss,allowed_edges=None,search_epochs=10,padding_noise=1e-6,is_reg=False):
    """
    Try a rank increment on every allowed edge of the network, train each 
    candidate for a few epochs and return the candidate with the lowest loss.

    Each candidate is trained in two phases of `search_epochs` epochs each:
    first only the newly added slices are optimized (unless is_reg is set), 
    then all parameters.

        initialNetwork: list of tensors representing the initial network
        train_data, val_data, loss_fun: passed to the training function
        allowed_edges: a list of (i,j) pairs of edges whose rank may be 
                    incremented. If None (or empty), all edges (i,j) with 
                    i < j are considered.           (default=None)
        rank_increment: how much should the rank of the new edge be increased
        padding_noise: passed to cc.increase_rank together with 
                    rank_increment (noise level used for the new slices)
        training_args: dictionary of arguments passed to the 
                    training function. It is not modified: a shallow copy 
                    is made for every candidate edge.
        prev_best_loss: loss of the network before the increment (only used
                    for printing)
        search_epochs: number of epochs for each of the two search phases
                                                    (default=10)
        is_reg: if True, the first phase also optimizes all parameters
                                                    (default=False)

    Returns:
        (best_network, best_search_loss, best_network_optimizer_state,
         best_train_lost_hist, best_val_loss_hist) for the best candidate.
        best_network and the optimizer state are None if no candidate 
        reached a loss below infinity.
    """

    if not allowed_edges:
        ndims = len(initialNetwork)
        allowed_edges = [(i,j) for i in range(ndims) for j in range(i+1,ndims)]

    # The plateau detector lives in this function: its stop condition is 
    # put in the per-edge copy of the arguments (below) rather than in 
    # training_args, so that the caller's own stop condition is not replaced 
    # by a closure over a detector the caller cannot reset.
    stop_on_plateau = training_args.get('stop_on_plateau')
    detect_plateau = DetectPlateau(**stop_on_plateau) if stop_on_plateau else None

    initial_epochs = training_args.get('initial_epochs')

    # np.inf rather than np.infty: the latter alias was removed in NumPy 2.0
    best_search_loss = np.inf
    best_network, best_network_optimizer_state = None, None
    best_train_lost_hist = []
    best_val_loss_hist = []

    for (i,j) in allowed_edges: 
        currentNetwork = cc.copy_network(initialNetwork)
        #increase rank along a chosen dimension
        currentNetwork = cc.increase_rank(currentNetwork,i, j, rank_increment, padding_noise)
        currentNetwork = cc.make_trainable(currentNetwork)
        print('\ntesting rank increment for i =', i, 'j = ', j)

        ### Search optimization phase
        # we do only a few epochs to identify the most promising rank update
        # we first optimize only the new slices for a few epochs
        print("optimize new slices for a few epochs")
        search_args = dict(training_args)
        # The training routine stores the optimizer state in this dict
        current_network_optimizer_state = {}
        search_args["save_optimizer_state"] = True
        search_args["optimizer_state"] = current_network_optimizer_state
        if detect_plateau is not None:
            search_args["stop_condition"] = lambda train_loss,val_loss : detect_plateau(train_loss)

        if not is_reg:
            search_args["only_new_slice"] = (i,j)
        if detect_plateau is not None:
            detect_plateau._reset()

        # `initial_epochs` (and not training_args["initial_epochs"]) so that 
        # a missing key means "no restart heuristic" instead of a KeyError
        [currentNetwork, _, current_loss, best_epoch, hist] = training(currentNetwork, initial_epochs, train_data, 
            loss_fun, val_data=val_data, epochs=search_epochs, other_args=search_args)
        train_lost_hist = deepcopy(hist[0][:best_epoch])
        val_loss_hist = deepcopy(hist[1][:best_epoch])

        # We then optimize all parameters for a few epochs, resuming from 
        # the optimizer state saved by the first phase
        print("\noptimize all parameters for a few epochs")
        search_args["load_optimizer_state"] = dict(current_network_optimizer_state)
        search_args["only_new_slice"] = False
        if detect_plateau is not None:
            detect_plateau._reset()
        [currentNetwork, _, current_loss, best_epoch, hist] = training(currentNetwork, initial_epochs, train_data, 
            loss_fun, val_data=val_data, epochs=search_epochs, 
            other_args=search_args)
        train_lost_hist += deepcopy(hist[0][:best_epoch])
        val_loss_hist += deepcopy(hist[1][:best_epoch])

        print(f"\nCurrent loss is {current_loss:.7f}    Best loss from previous discrete optim is {prev_best_loss}")
        if best_search_loss > current_loss:
            best_search_loss = current_loss
            best_network = currentNetwork
            best_network_optimizer_state = deepcopy(current_network_optimizer_state)
            best_train_lost_hist = train_lost_hist
            best_val_loss_hist = val_loss_hist
            print('-> best rank update so far:', i,j)

    return best_network,best_search_loss,best_network_optimizer_state,best_train_lost_hist,best_val_loss_hist
Example #9
0
def training(tensor_list, initial_epochs, train_data, 
            loss_fun, val_data, epochs, other_args):
    """
    This function run the continuous optimization routine with a small tweak:
    if there is no improvement of the loss in the first [initial_epochs] epochs, 
    the learning rate is reduced by a factor of 0.5 and optimization is 
    restarted from the beginning.
    If initial epochs is None, the continuous_optim from core_code is called directly. 
    All arguments are the same as cc.continuous_optim 
    except for [initial_epochs] and the following additional optional arguments 
    for other_args:
        only_new_slice: either falsy, in which case optimization is performed on all 
                        parameters, or if set to a tuple of indices (corresponding 
                        to a newly incremented edge) it will only optimize the new 
                        slices of the corresponding core tensors. (default=False)

    other_args itself is deep-copied before being changed; only the dict 
    stored under other_args["optimizer_state"] (if any) is shared with the 
    continuous optimization so that the caller gets the saved optimizer state.

    Returns:
        If initial_epochs is None: whatever cc.continuous_optim returns.
        Otherwise: [network, first_loss, current_loss, best_epoch, hist] if 
        other_args has a "hist" key, else [network, first_loss, current_loss].
        hist is a list of lists concatenating the histories (each truncated at 
        its best epoch) of the initial and of the remaining epochs, and 
        best_epoch is the sum of both best epochs. Note that first_loss is the 
        one reported for the remaining epochs, not for the initial ones.
    """

    only_new_slice = other_args.get('only_new_slice', False)
    if only_new_slice:
        i, j = only_new_slice

        def grad_masking_function(tensor_list):
            # Zero out every gradient entry except the last slice of cores 
            # i and j along the axis of the edge (i,j).
            # The number of axes is taken from the network being masked: the 
            # enclosing function has no network variable bound yet when 
            # initial_epochs is None.
            n_axes = len(tensor_list)

            def zero_old_slices(grad, edge_axis):
                # Bring edge_axis to the front (permute returns a view, so 
                # the in-place product acts on grad itself)
                order = [edge_axis] + [a for a in range(n_axes) if a != edge_axis]
                grad.permute(order)[:-1,:,...] *= 0

            for k in range(n_axes):
                if k == i:
                    zero_old_slices(tensor_list[i].grad, j)
                elif k == j:
                    zero_old_slices(tensor_list[j].grad, i)
                else:
                    tensor_list[k].grad *= 0
    else:
        grad_masking_function = None

    args = deepcopy(other_args)
    args["grad_masking_function"] = grad_masking_function

    if initial_epochs is None:
        return cc.continuous_optim(tensor_list, train_data, 
            loss_fun, val_data=val_data, epochs=epochs, other_args=args)

    args["hist"] = True
    # Use the caller's dict (not its deep copy) so that the saved optimizer 
    # state is visible to the caller
    current_network_optimizer_state = other_args.get("optimizer_state", {})
    args["save_optimizer_state"] = True
    args["optimizer_state"] = current_network_optimizer_state
    remaining_epochs = epochs - initial_epochs if epochs else None

    # Run the first initial_epochs epochs; as long as the training loss went 
    # up over these epochs, halve the learning rate and start again from 
    # the input network.
    currentNetwork = cc.copy_network(tensor_list)
    while True:
        [currentNetwork, first_loss, current_loss, best_epoch, hist] = cc.continuous_optim(currentNetwork, train_data, 
            loss_fun, val_data=val_data, epochs=initial_epochs, other_args=args)
        if not (hist[0][0] < hist[0][-1]):
            break

        loaded_state = args.get("load_optimizer_state")
        if loaded_state:
            # The learning rate is read from the loaded optimizer state
            param_group = loaded_state["optimizer_state"]['param_groups'][0]
            param_group['lr'] /= 2
            lr = param_group['lr']
        else:
            # 1e-3 is the documented default learning rate, used when the 
            # caller did not set 'lr' (TODO confirm against continuous_optim)
            args["lr"] = args.get("lr", 1e-3) / 2
            lr = args["lr"]
        print(f"\n[!] No progress in first {initial_epochs} epochs, starting again with smaller learning rate ({lr})")
        currentNetwork = cc.copy_network(tensor_list)

    hist_initial = [h[:best_epoch].tolist() for h in hist]
    best_epoch_initial = best_epoch

    # Continue for the remaining epochs from the optimizer state saved above
    args["load_optimizer_state"] = current_network_optimizer_state
    [currentNetwork, first_loss, current_loss, best_epoch, hist] = cc.continuous_optim(currentNetwork, train_data, 
            loss_fun, val_data=val_data, epochs=remaining_epochs, other_args=args)
    
    hist = [h_init + h[:best_epoch].tolist() for h_init,h in zip(hist_initial,hist)]
    best_epoch += best_epoch_initial

    return [currentNetwork, first_loss, current_loss, best_epoch, hist] if "hist" in other_args else [currentNetwork, first_loss, current_loss]
def discrete_optim_template(tensor_list,
                            train_data,
                            loss_fun,
                            val_data=None,
                            other_args=None):
    """
    Train a tensor network using discrete optimization over TN ranks

    This is a template: the rank-search step inside the main loop is left
    as a TODO for concrete discrete optimization routines to fill in.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network
        loss_fun:    Scalar-valued loss function of the type 
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used
                     for early stopping within continuous optimization
                     calls within discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization, 
                     with some options below (feel free to add more).
                     None is treated as an empty dictionary. The
                     dictionary passed in is never modified.

                        epochs: Number of epochs for 
                                continuous optimization     (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        reps:   Number of times to repeat 
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        dhist:  Whether to return losses
                                from intermediate TNs       (default=False)
    
    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm. The TN ranks of better_list will be larger
                     than those of tensor_list.
        first_loss:  Initial loss of the model on the validation set, 
                     before any training. If no val set is provided, the
                     first training loss is instead returned
        best_loss:   The value of the validation/training loss for the
                     model output as better_list
        loss_record: Only returned if dhist=True in other_args. A tuple
                     with one PyTorch tensor per discrete optimization
                     stage, each built from the loss recorded at the end
                     of that stage.
                     NOTE(review): the initial loss is not part of this
                     record; confirm whether callers expect
                     loss_record[0] to be the initial loss.

    If the stop condition already holds for the input network, the loop
    body never runs and (None, None, None) is returned (plus an empty
    loss_record when dhist=True).
    """
    # Avoid a shared mutable default; None stands for "no extra options"
    if other_args is None:
        other_args = {}

    # Check input and initialize local record variables
    # (`epochs` is meant for the continuous optimization calls that the
    # TODO section below is expected to make)
    epochs = other_args.get('epochs', 10)
    dprint = other_args.get('dprint', True)
    dhist = other_args.get('dhist', False)
    first_loss, best_loss, best_network = None, None, None
    loss_record = []  # Only filled in and returned when dhist is True

    # Function to maybe print, conditioned on `dprint`
    def m_print(s):
        if dprint:
            print(s)

    # Function to record loss information
    def record_loss(new_loss, new_network):
        # Load record variables from outer scope
        nonlocal best_loss, best_network

        # Check for best loss
        if best_loss is None or new_loss < best_loss:
            best_loss, best_network = new_loss, new_network

        # Add new loss to our loss record
        if dhist:
            loss_record.append(new_loss)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)

    # Define a function giving the stop condition for the discrete
    # optimization procedure. I'm using a simple example here which could
    # work for greedy or random walk searches, but for other optimization
    # methods this could be trivial
    stop_cond = generate_stop_cond(cc.get_indims(tensor_list))

    # Iteratively increment ranks of tensor_list and train via
    # continuous_optim, at each stage using a search procedure to
    # test out different ranks before choosing just one to increase
    stage = 0
    better_network, better_loss = tensor_list, 1e10
    while not stop_cond(better_network):
        if first_loss is None:
            # Record initial loss of TN model. Work on a copy of the
            # options so that silencing output/history for this probe
            # does not leak into later optimization calls or the caller
            first_args = dict(other_args)
            first_args["print"] = first_args["hist"] = False
            _, first_loss, _ = cc.continuous_optim(tensor_list,
                                                   train_data,
                                                   loss_fun,
                                                   epochs=1,
                                                   val_data=val_data,
                                                   other_args=first_args)
            m_print("Initial model has TN ranks")
            if dprint: cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {first_loss:.3f}")
            continue
        m_print(f"STAGE {stage}")

        # Try out training different network ranks and assign network
        # with best ranks to better_network
        #
        # TODO: This is your part to write! Use new variables for the loss
        #       and network being tried out, with the network being
        #       initialized from better_network. When you've found a better
        #       TN, make sure to update better_network and better_loss to
        #       equal the parameters and loss of that TN.
        # At some point this code will call the continuous optimization
        # loop, in which case you can use the following command:
        #
        # trained_tn, init_loss, final_loss = cc.continuous_optim(my_tn,
        #                                         train_data, loss_fun,
        #                                         epochs=epochs,
        #                                         val_data=val_data,
        #                                         other_args=other_args)

        # Record the loss associated with the best network from this
        # discrete optimization loop
        record_loss(better_loss, better_network)
        stage += 1

    if dhist:
        loss_record = tuple(torch.tensor(fr) for fr in loss_record)
        return best_network, first_loss, best_loss, loss_record
    else:
        return best_network, first_loss, best_loss