Example #1
def training(tensor_list, initial_epochs, train_data, 
            loss_fun, val_data, epochs, other_args):
    """
    This function run the continuous optimization routine with a small tweak: if there is no improvement of the loss in the 
    first [initial_epochs] epochs, the learning rate is reduced by a factor of 0.5 and optimization is restarted from the beginning,
    All arguments are the same as cc.continuous_optim except for [initial_epochs].
    """

    if initial_epochs is None:
        return cc.continuous_optim(tensor_list, train_data, 
            loss_fun, val_data=val_data, epochs=epochs, other_args=other_args)

    args = deepcopy(other_args)
    args["hist"] = True
    current_network_optimizer_state = args["optimizer_state"] if "optimizer_state" in args else {}
    args["save_optimizer_state"] = True
    args["optimizer_state"] = current_network_optimizer_state
    currentNetwork = cc.copy_network(tensor_list)
    remaining_epochs = epochs - initial_epochs if epochs else None
    hist = None
    # Keep restarting (with a halved learning rate) while the probe run of
    # [initial_epochs] epochs ends with a higher loss than it started with
    while hist is None or (hist[0][0] < hist[0][-1]):
        if hist:
            if "load_optimizer_state" in args and args["load_optimizer_state"]:
                args["load_optimizer_state"]["optimizer_state"]['param_groups'][0]['lr'] /= 2
                lr = args["load_optimizer_state"]["optimizer_state"]['param_groups'][0]['lr']
            else:
                args["lr"] /= 2
                lr = args["lr"]
            print(f"\n[!] No progress in first {initial_epochs} epochs, starting again with smaller learning rate ({lr})")
            currentNetwork = cc.copy_network(tensor_list)
        [currentNetwork, first_loss, current_loss, hist] = cc.continuous_optim(currentNetwork, train_data, 
            loss_fun, val_data=val_data, epochs=initial_epochs, other_args=args)

    args["load_optimizer_state"] = current_network_optimizer_state
    [currentNetwork, first_loss, current_loss, hist] = cc.continuous_optim(currentNetwork, train_data, 
            loss_fun, val_data=val_data, epochs=remaining_epochs, other_args=args)
    
    return [currentNetwork, first_loss, current_loss, hist] if "hist" in other_args else [currentNetwork, first_loss, current_loss]
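To make the restart heuristic above concrete, here is a self-contained sketch of the same idea on a toy least-squares problem, using only plain PyTorch. The fit_with_restarts helper and the quadratic loss are illustrative assumptions for this demo, not part of the course's core_code:

import torch

# Illustrative sketch of the restart heuristic implemented by training()
# above: run a few probe epochs, and if the loss has not improved, halve
# the learning rate and restart from the initial weights.
def fit_with_restarts(w_init, target, lr=2.0, probe_epochs=20):
    while True:
        w = w_init.clone().requires_grad_(True)
        optimizer = torch.optim.SGD([w], lr=lr)
        losses = []
        for _ in range(probe_epochs):
            optimizer.zero_grad()
            loss = ((w - target) ** 2).sum()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        if losses[-1] < losses[0]:  # progress was made: keep this run
            return w, losses
        lr /= 2  # no progress: halve the learning rate and restart
        print(f"No progress in first {probe_epochs} epochs, retrying with lr={lr}")

w, hist = fit_with_restarts(torch.ones(3), torch.zeros(3))
print(f"final loss: {hist[-1]:.2e}")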
Example #2
def discrete_optim_template(tensor_list,
                            train_data,
                            loss_fun,
                            val_data=None,
                            other_args=dict()):
    """
    Train a tensor network using discrete optimization over TN ranks
    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network (the target tensor)
        loss_fun:    Scalar-valued loss function of the type 
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within continuous optimization calls
                     inside the discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization, 
                     with some options below (feel free to add more)
                        epochs: Number of epochs for 
                                continuous optimization     (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        reps:   Number of times to repeat 
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        dhist:  Whether to return losses
                                from intermediate TNs       (default=False)
    
    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm. The TN ranks of better_list will be larger
                     than those of tensor_list.
        first_loss:  Initial loss of the model on the validation set, 
                     before any training. If no val set is provided, the
                     first training loss is instead returned
        best_loss:   The value of the validation/training loss for the
                     model output as better_list
        loss_record: If dhist=True in other_args, all values of best_loss
                     associated with intermediate optimized TNs will be
                     returned as a PyTorch vector, with loss_record[0]
                     giving the initial loss of the model, and
                     loss_record[-1] equal to best_loss value returned
                     by discrete_optim.
    """
    # Check input and initialize local record variables
    epochs = other_args['epochs'] if 'epochs' in other_args else 10
    dprint = other_args['dprint'] if 'dprint' in other_args else True
    dhist = other_args['dhist'] if 'dhist' in other_args else False
    loss_rec, first_loss, best_loss, best_network = [], None, None, None
    loss_record = []  # best loss found at each discrete stage

    # Function to maybe print, conditioned on `dprint`
    m_print = lambda s: print(s) if dprint else None

    # Function to record loss information
    def record_loss(new_loss, new_network):
        # Load record variables from outer scope
        nonlocal loss_rec, first_loss, best_loss, best_network

        # Check for best loss
        if best_loss is None or new_loss < best_loss:
            best_loss, best_network = new_loss, new_network

        # Add new loss to our loss record
        if not dhist: return
        nonlocal loss_record
        loss_record.append(new_loss)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)

    # Define a function giving the stop condition for the discrete
    # optimization procedure. I'm using a simple example here which could
    # work for greedy or random walk searches, but for other optimization
    # methods this could be trivial
    stop_cond = generate_stop_cond(cc.get_indims(tensor_list))

    # Iteratively increment ranks of tensor_list and train via
    # continuous_optim, at each stage using a search procedure to
    # test out different ranks before choosing just one to increase
    stage = 0
    better_network, better_loss = tensor_list, 1e10
    initialNetwork = cc.copy_network(tensor_list)  #example_tn
    while not stop_cond(better_network):
        if first_loss is None:
            # Record initial loss of TN model
            first_args = other_args.copy()
            first_args["print"] = first_args["hist"] = False
            _, first_loss, _ = cc.continuous_optim(tensor_list,
                                                   train_data,
                                                   loss_fun,
                                                   epochs=1,
                                                   val_data=val_data,
                                                   other_args=first_args)
            m_print("Initial model has TN ranks")
            m_print(f"Initial loss is {first_loss:.3f}")
            m_print("Performing initial optimization...")
            first_args["print"] = True
            initialNetwork, _, _ = cc.continuous_optim(tensor_list,
                                                       train_data,
                                                       loss_fun,
                                                       epochs=2000,
                                                       val_data=val_data,
                                                       other_args=first_args)
            continue
        m_print(f"STAGE {stage}")
        stage += 1

        # Try out training different network ranks and assign network
        # with best ranks to better_network
        #
        # TODO: This is your part to write! Use new variables for the loss
        #       and network being tried out, with the network being
        #       initialized from better_network. When you've found a better
        #       TN, make sure to update better_network and better_loss to
        #       equal the parameters and loss of that TN.
        # At some point this code will call the continuous optimization
        # loop, in which case you can use the following command:
        #
        # trained_tn, init_loss, final_loss = cc.continuous_optim(my_tn,
        #                                         train_data, loss_fun,
        #                                         epochs=epochs,
        #                                         val_data=val_data,
        #                                         other_args=other_args)

        prev_loss = float('inf')
        num_cores = len(tensor_list)
        # Parameter budget: the number of entries in the target tensor
        max_params = torch.prod(torch.tensor(cc.get_indims(train_data)))
        for i in range(num_cores):
            for j in range(i + 1, num_cores):
                m_print(f"Testing i={i}, j={j}")

                # Reset currentNetwork to continue the greedy search from
                # the best network found so far
                currentNetwork = cc.copy_network(initialNetwork)
                # Increase the rank along the chosen edge (i, j) by one
                currentNetwork = cc.increase_rank(currentNetwork, i, j, 1,
                                                  1e-6)

                # Break if the number of parameters in the candidate network
                # exceeds the number of entries in the target tensor
                stop_cond = generate_stop_cond(cc.get_indims(currentNetwork))
                if stop_cond(currentNetwork):
                    break

                # Solve the continuous optimization part (train_data is the target tensor)
                [currentNetwork, first_loss,
                 current_loss] = cc.continuous_optim(currentNetwork,
                                                     train_data,
                                                     loss_fun,
                                                     epochs=epochs,
                                                     val_data=val_data,
                                                     other_args=other_args)
                if prev_loss > current_loss:
                    prev_loss = current_loss
                    better_network = currentNetwork
                    better_loss = current_loss
                    m_print(f"BEST CONFIG ({cc.num_params(better_network)} params)")
                else:
                    m_print("Not best config")
                m_print(f"{first_loss:.3f} -> {current_loss:.3f}\n")

        # Update the current point to the new point (better_network) that
        # gave the lower loss
        initialNetwork = cc.copy_network(better_network)
        loss_record.append(better_loss)

    if dhist:
        return better_network, first_loss, better_loss, loss_record
    return better_network, first_loss, better_loss
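generate_stop_cond is defined elsewhere in the assignment code; a plausible self-contained sketch of it, matching the parameter-budget check used above, might look like the following (the exact signature and semantics are an assumption):

import torch

def generate_stop_cond(indims):
    # Plausible sketch (assumption): stop the discrete search once the
    # candidate network has at least as many parameters as the dense
    # target tensor it is meant to compress.
    max_params = torch.prod(torch.tensor(indims)).item()

    def stop_cond(tensor_list):
        num_params = sum(t.numel() for t in tensor_list)
        return num_params >= max_params

    return stop_cond

# Example: a 4x5x6 target (120 entries) vs. a 3-core candidate network
stop = generate_stop_cond([4, 5, 6])
cores = [torch.randn(4, 2), torch.randn(2, 5, 2), torch.randn(2, 6)]
print(stop(cores))  # False: 8 + 20 + 12 = 40 < 120 parameters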
Example #3
def greedy_optim(tensor_list, train_data, loss_fun, 
                            val_data=None, other_args=dict(), max_iter=None):
    """
    Train a tensor network using discrete optimization over TN ranks
    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network (the target tensor)
        loss_fun:    Scalar-valued loss function of the type 
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within continuous optimization calls
                     inside the discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization, 
                     with some options below (feel free to add more)
                        epochs: Number of epochs for 
                                continuous optimization     (default=10)
                        max_iter: Maximum number of iterations
                                  for greedy search           (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        reps:   Number of times to repeat 
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        dhist:  Whether to return losses
                                from intermediate TNs       (default=False)
                        search_epochs: Number of epochs used to identify the
                                best rank-1 update. If None, the epochs
                                argument is used            (default=None)
                        loss_threshold: If the loss falls below this
                                threshold, discrete optimization
                                is stopped                  (default=1e-5)
                        initial_epochs: Number of epochs after which the
                                learning rate is reduced and optimization
                                is restarted if there is no improvement
                                in the loss                 (default=None)
    
    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm. The TN ranks of better_list will be larger
                     than those of tensor_list.
        first_loss:  Initial loss of the model on the validation set, 
                     before any training. If no val set is provided, the
                     first training loss is instead returned
        best_loss:   The value of the validation/training loss for the
                     model output as better_list
        loss_record: If dhist=True in other_args, all values of best_loss
                     associated with intermediate optimized TNs will be
                     returned as a PyTorch vector, with loss_record[0]
                     giving the initial loss of the model, and
                     loss_record[-1] equal to best_loss value returned
                     by discrete_optim.
    """
    # Check input and initialize local record variables
    epochs = other_args['epochs'] if 'epochs' in other_args else 10
    if max_iter is None:
        max_iter = other_args['max_iter'] if 'max_iter' in other_args else 10
    dprint = other_args['dprint'] if 'dprint' in other_args else True
    cprint = other_args['cprint'] if 'cprint' in other_args else True
    dhist = other_args['dhist'] if 'dhist' in other_args else False
    search_epochs = other_args['search_epochs'] if 'search_epochs' in other_args else None
    loss_threshold = other_args['loss_threshold'] if 'loss_threshold' in other_args else 1e-5
    initial_epochs = other_args['initial_epochs'] if 'initial_epochs' in other_args else None
    stop_cond = lambda loss: loss < loss_threshold



    # Function to maybe print, conditioned on `dprint`
    m_print = lambda s: print(s) if dprint else None


    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)


    # Iteratively increment ranks of tensor_list and train via
    # continuous_optim, at each stage using a search procedure to 
    # test out different ranks before choosing just one to increase
    stage = 0
    loss_record, best_loss, best_network = [], np.inf, None
    


    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if best_loss == np.inf: # first continuous optimization run
            tensor_list, first_loss, best_loss = training(
                tensor_list, initial_epochs, train_data, loss_fun, epochs=epochs, 
                val_data=val_data,other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint: cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")

            initialNetwork = cc.copy_network(tensor_list) 
            best_network = cc.copy_network(tensor_list) 
            loss_record += [first_loss,best_loss]

            


        m_print(f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n")  
        best_search_loss = best_loss
        best_network_optimizer_state = None  # set when a rank update improves the loss

        for i in range(len(initialNetwork)):
            for j in range(i+1, len(initialNetwork)):
                currentNetwork = cc.copy_network(initialNetwork)
                #increase rank along a chosen dimension
                currentNetwork = cc.increase_rank(currentNetwork,i, j, 1, 1e-6)



                print('\ntesting rank increment for i =', i, 'j =', j)
                if search_epochs: # run only a few epochs to identify the most promising rank update

                    # Function to zero out the gradients of all entries
                    # except for the newly added slices of cores i and j
                    def grad_masking_function(tensor_list):
                        nonlocal i, j
                        for k in range(len(tensor_list)):
                            if k == i:
                                # permute() returns a view, so the in-place
                                # multiply masks the original gradient
                                tensor_list[i].grad.permute([j] + list(range(0, j)) + list(range(j + 1, len(tensor_list))))[:-1, :, ...] *= 0
                            elif k == j:
                                tensor_list[j].grad.permute([i] + list(range(0, i)) + list(range(i + 1, len(tensor_list))))[:-1, :, ...] *= 0
                            else:
                                tensor_list[k].grad *= 0

                    # we first optimize only the new slices for a few epochs
                    print("optimize new slices for a few epochs")
                    search_args = dict(other_args)
                    search_args["hist"] = True
                    current_network_optimizer_state = {}
                    search_args["save_optimizer_state"] = True
                    search_args["optimizer_state"] = current_network_optimizer_state
                    search_args["grad_masking_function"] = grad_masking_function
                    [currentNetwork, first_loss, current_loss, hist] = training(currentNetwork, initial_epochs, train_data, 
                        loss_fun, val_data=val_data, epochs=search_epochs, other_args=search_args)
                    first_loss = hist[0][0]

                    # We then optimize all parameters for a few epochs,
                    # resuming from the optimizer state of the slice-only run
                    # (first_loss from the run above is preserved via `_`)
                    print("\noptimize all parameters for a few epochs")
                    search_args["grad_masking_function"] = None
                    search_args["load_optimizer_state"] = dict(current_network_optimizer_state)
                    [currentNetwork, _, current_loss, hist] = training(currentNetwork, initial_epochs, train_data, 
                        loss_fun, val_data=val_data, epochs=search_epochs, 
                        other_args=search_args)
                    search_args["load_optimizer_state"] = None

                else: # we fully optimize the network in the search phase
                    [currentNetwork, first_loss, current_loss] = cc.continuous_optim(currentNetwork, train_data, 
                        loss_fun, val_data=val_data, epochs=epochs, 
                        other_args=other_args)
                

                m_print(f"\nCurrent loss is {current_loss:.7f}    Best loss from previous discrete optim is {best_loss}")
                if best_search_loss > current_loss:
                    best_search_loss = current_loss
                    best_network = currentNetwork
                    best_network_optimizer_state = deepcopy(current_network_optimizer_state)
                    print('-> best rank update so far:', i,j)
        

        best_loss = best_search_loss
        # train network to convergence for the best rank increment (if search_epochs is set, 
        # otherwise the best network is already trained to convergence / max_epochs)
        if search_epochs:
            print('\ntraining best network until max_epochs/convergence...')

            other_args["load_optimizer_state"] = best_network_optimizer_state
            current_network_optimizer_state = {}
            other_args["save_optimizer_state"] = True
            other_args["optimizer_state"] = current_network_optimizer_state
            [best_network, first_loss, best_loss] = training(best_network, initial_epochs, train_data, 
                    loss_fun, val_data=val_data, epochs=epochs, 
                    other_args=other_args)
            other_args["load_optimizer_state"] = None

        initialNetwork  = cc.copy_network(best_network)
        loss_record.append((stage, cc.num_params(best_network), float(best_loss)))
        print('\nbest TN:')
        cc.print_ranks(best_network)
        print('number of params:', cc.num_params(best_network))
        print(loss_record)
    if dhist:
        return best_network, first_loss, best_loss, loss_record
    return best_network, first_loss, best_loss
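The save_optimizer_state / load_optimizer_state flags above presumably shuttle a PyTorch optimizer state_dict between the search run and the follow-up training run, so momentum and related statistics are not thrown away. A minimal self-contained illustration of that hand-off (the flag semantics are an assumption about core_code):

import torch

# Save an optimizer's state after a search run and restore it for a
# follow-up run on a copy of the winning network, so training resumes
# warm instead of restarting cold.
w = torch.ones(3, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)
(w ** 2).sum().backward()
opt.step()
saved_state = opt.state_dict()  # analogous to "save_optimizer_state"

w2 = w.detach().clone().requires_grad_(True)  # the winning candidate
opt2 = torch.optim.SGD([w2], lr=0.1, momentum=0.9)
opt2.load_state_dict(saved_state)  # analogous to "load_optimizer_state"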
Example #4
def training(tensor_list, initial_epochs, train_data, 
            loss_fun, val_data, epochs, other_args):
    """
    This function run the continuous optimization routine with a small tweak:
    if there is no improvement of the loss in the first [initial_epochs] epochs, 
    the learning rate is reduced by a factor of 0.5 and optimization is 
    restarted from the beginning.
    If initial epochs is None, the continuous_optim from core_code is called directly. 
    All arguments are the same as cc.continuous_optim 
    except for [initial_epochs] and the following additional optional arguments 
    for other_args:
        only_new_slice: either None, in which case optimization is performed on all 
                        parameters, or if set to a tuple of indices (corresponding 
                        to a newly incremented edge) it will only optimize the new 
                        slices of the corresponding core tensors. (default=False)                                                        
    """

    only_new_slice  = other_args['only_new_slice'] if 'only_new_slice' in other_args else False
    if only_new_slice:
        i,j = only_new_slice
        def grad_masking_function(tensor_list):
            nonlocal i, j
            for k in range(len(tensor_list)):
                if k == i:
                    # permute() returns a view, so the in-place multiply
                    # zeroes all but the last slice along the (i, j) edge
                    tensor_list[i].grad.permute([j] + list(range(0, j)) + list(range(j + 1, len(tensor_list))))[:-1, :, ...] *= 0
                elif k == j:
                    tensor_list[j].grad.permute([i] + list(range(0, i)) + list(range(i + 1, len(tensor_list))))[:-1, :, ...] *= 0
                else:
                    tensor_list[k].grad *= 0
    else:
        grad_masking_function = None

    args = deepcopy(other_args)
    args["grad_masking_function"] = grad_masking_function

    if initial_epochs is None:
        return cc.continuous_optim(tensor_list, train_data, 
            loss_fun, val_data=val_data, epochs=epochs, other_args=args)

    args["hist"] = True
    current_network_optimizer_state = args["optimizer_state"] if "optimizer_state" in args else {}
    args["save_optimizer_state"] = True
    args["optimizer_state"] = current_network_optimizer_state
    currentNetwork = cc.copy_network(tensor_list)
    remaining_epochs = epochs - initial_epochs if epochs else None
    hist = None
    # Keep restarting (with a halved learning rate) while the probe run of
    # [initial_epochs] epochs ends with a higher loss than it started with
    while hist is None or (hist[0][0] < hist[0][-1]):
        if hist:
            if "load_optimizer_state" in args and args["load_optimizer_state"]:
                args["load_optimizer_state"]["optimizer_state"]['param_groups'][0]['lr'] /= 2
                lr = args["load_optimizer_state"]["optimizer_state"]['param_groups'][0]['lr']
            else:
                args["lr"] /= 2
                lr = args["lr"]
            print(f"\n[!] No progress in first {initial_epochs} epochs, starting again with smaller learning rate ({lr})")
            currentNetwork = cc.copy_network(tensor_list)
        [currentNetwork, first_loss, current_loss, best_epoch, hist] = cc.continuous_optim(currentNetwork, train_data, 
            loss_fun, val_data=val_data, epochs=initial_epochs, other_args=args)

    hist_initial = [h[:best_epoch].tolist() for h in hist]
    best_epoch_initial = best_epoch

    args["load_optimizer_state"] = current_network_optimizer_state
    [currentNetwork, first_loss, current_loss, best_epoch, hist] = cc.continuous_optim(currentNetwork, train_data, 
            loss_fun, val_data=val_data, epochs=remaining_epochs, other_args=args)
    
    hist = [h_init + h[:best_epoch].tolist() for h_init,h in zip(hist_initial,hist)]
    best_epoch += best_epoch_initial

    return [currentNetwork, first_loss, current_loss, best_epoch, hist] if "hist" in other_args else [currentNetwork, first_loss, current_loss]
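The grad-masking trick above relies on permute() returning a view of the gradient tensor, so an in-place multiply on a slice of the permuted view zeroes the corresponding entries of the original .grad. A self-contained check of that behavior:

import torch

# permute() returns a view, so zeroing a slice of the permuted gradient
# zeroes it in the original tensor as well; only the last slice along
# axis 1 of core.grad survives untouched.
core = torch.randn(4, 3, 5, requires_grad=True)
core.sum().backward()
core.grad.permute(1, 0, 2)[:-1, :, ...] *= 0
assert core.grad[:, :-1, :].abs().sum() == 0  # masked slices are zero
assert core.grad[:, -1, :].abs().sum() > 0    # the new slice keeps its gradient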
Example #5
def discrete_optim_template(tensor_list,
                            train_data,
                            loss_fun,
                            val_data=None,
                            other_args=dict()):
    """
    Train a tensor network using discrete optimization over TN ranks

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network
        loss_fun:    Scalar-valued loss function of the type 
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used for
                     early stopping within continuous optimization calls
                     inside the discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization, 
                     with some options below (feel free to add more)

                        epochs: Number of epochs for 
                                continuous optimization     (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        reps:   Number of times to repeat 
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        dhist:  Whether to return losses
                                from intermediate TNs       (default=False)
    
    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm. The TN ranks of better_list will be larger
                     than those of tensor_list.
        first_loss:  Initial loss of the model on the validation set, 
                     before any training. If no val set is provided, the
                     first training loss is instead returned
        best_loss:   The value of the validation/training loss for the
                     model output as better_list
        loss_record: If dhist=True in other_args, all values of best_loss
                     associated with intermediate optimized TNs will be
                     returned as a PyTorch vector, with loss_record[0]
                     giving the initial loss of the model, and
                     loss_record[-1] equal to best_loss value returned
                     by discrete_optim.
    """
    # Check input and initialize local record variables
    epochs = other_args['epochs'] if 'epochs' in other_args else 10
    dprint = other_args['dprint'] if 'dprint' in other_args else True
    dhist = other_args['dhist'] if 'dhist' in other_args else False
    loss_rec, first_loss, best_loss, best_network = [], None, None, None
    if dhist: loss_record = []  # (train_record, val_record)

    # Function to maybe print, conditioned on `dprint`
    m_print = lambda s: print(s) if dprint else None

    # Function to record loss information
    def record_loss(new_loss, new_network):
        # Load record variables from outer scope
        nonlocal loss_rec, first_loss, best_loss, best_network

        # Check for best loss
        if best_loss is None or new_loss < best_loss:
            best_loss, best_network = new_loss, new_network

        # Add new loss to our loss record
        if not dhist: return
        nonlocal loss_record
        loss_record.append(new_loss)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)

    # Define a function giving the stop condition for the discrete
    # optimization procedure. I'm using a simple example here which could
    # work for greedy or random walk searches, but for other optimization
    # methods this could be trivial
    stop_cond = generate_stop_cond(cc.get_indims(tensor_list))

    # Iteratively increment ranks of tensor_list and train via
    # continuous_optim, at each stage using a search procedure to
    # test out different ranks before choosing just one to increase
    stage = 0
    better_network, better_loss = tensor_list, 1e10
    while not stop_cond(better_network):
        if first_loss is None:
            # Record initial loss of TN model
            # Copy other_args so the caller's dictionary is not mutated
            first_args = other_args.copy()
            first_args["print"] = first_args["hist"] = False
            _, first_loss, _ = cc.continuous_optim(tensor_list,
                                                   train_data,
                                                   loss_fun,
                                                   epochs=1,
                                                   val_data=val_data,
                                                   other_args=first_args)
            m_print("Initial model has TN ranks")
            if dprint: cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {first_loss:.3f}")
            continue
        m_print(f"STAGE {stage}")

        # Try out training different network ranks and assign network
        # with best ranks to better_network
        #
        # TODO: This is your part to write! Use new variables for the loss
        #       and network being tried out, with the network being
        #       initialized from better_network. When you've found a better
        #       TN, make sure to update better_network and better_loss to
        #       equal the parameters and loss of that TN.
        # At some point this code will call the continuous optimization
        # loop, in which case you can use the following command:
        #
        # trained_tn, init_loss, final_loss = cc.continuous_optim(my_tn,
        #                                         train_data, loss_fun,
        #                                         epochs=epochs,
        #                                         val_data=val_data,
        #                                         other_args=other_args)

        # Record the loss associated with the best network from this
        # discrete optimization loop
        record_loss(better_loss, better_network)
        stage += 1

    if dhist:
        # Return the recorded losses as a single PyTorch vector
        loss_record = torch.tensor(loss_record)
        return best_network, first_loss, best_loss, loss_record
    else:
        return best_network, first_loss, best_loss
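For intuition about the cc.increase_rank(network, i, j, 1, 1e-6) calls used throughout these examples, here is a self-contained sketch of what such a rank-1 increment plausibly does for two matrix cores sharing one edge: each core gets one extra slice of 1e-6-scaled noise, so the new edge barely perturbs the represented tensor. The helper below is an assumption for illustration, not core_code's implementation:

import torch

def increase_pair_rank(a, b, pad=1, eps=1e-6):
    # a: (m, r) and b: (r, n) share one rank-r edge; pad that edge with
    # `pad` new slices of eps-scaled noise so a @ b is almost unchanged.
    m, _ = a.shape
    _, n = b.shape
    a_new = torch.cat([a, eps * torch.randn(m, pad)], dim=1)
    b_new = torch.cat([b, eps * torch.randn(pad, n)], dim=0)
    return a_new, b_new

a, b = torch.randn(3, 2), torch.randn(2, 4)
a2, b2 = increase_pair_rank(a, b)
print(torch.norm(a @ b - a2 @ b2))  # tiny: the padded slices contribute O(eps^2)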