Example #1
0
def randomwalk_decomposition(goal_tn):
    """Recover ``goal_tn`` with a random-walk rank search and return its loss record.

    The search starts from a random rank-1 network whose core ``k`` takes its
    input dimension from ``goal_tn[k].shape[k]``. Every starting core is
    divided by 10 so the initial network is close to zero.

    Args:
        goal_tn: list of tensors of the target network. It is used both to
                 read the input dimensions and as the training target.

    Returns:
        The ``loss_record`` produced by ``randomwalk_optim``.
    """
    recovery_loss = cc.tensor_recovery_loss
    dims = [core.shape[axis] for axis, core in enumerate(goal_tn)]
    start = cc.random_tn(dims, rank=1)

    # Shrink the starting network towards zero (in-place division per core).
    for idx, _core in enumerate(start):
        start[idx] /= 10
    start = cc.make_trainable(start)

    plateau_settings = {'mode': 'min', 'patience': 100, 'threshold': 1e-7}
    search_settings = {
        'cprint': True,
        'epochs': 10000,
        'max_iter': 20,
        'lr': 0.01,
        'optim': 'RMSprop',
        'search_epochs': 80,
        'cvg_threshold': 1e-10,
        'stop_on_plateau': plateau_settings,
        'dyn_print': True,
        'initial_epochs': 10,
    }
    _, _, record = randomwalk_optim(start, goal_tn, recovery_loss,
                                    other_args=search_settings)
    return record
Example #2
0
def randomwalk_regression(train_data, val_data=None):
    """Fit a regression tensor network with a random-walk rank search.

    The starting network is a random rank-1 network whose core input
    dimensions are ``t.shape[1]`` for each ``t`` in ``train_data[0]``. Every
    core is divided by 10 so the starting point is close to zero.

    Args:
        train_data: training data forwarded to ``randomwalk_optim``; its first
                    element is read to obtain the input dimensions.
        val_data:   optional validation data forwarded to ``randomwalk_optim``.

    Returns:
        The ``loss_record`` produced by ``randomwalk_optim``.
    """
    reg_loss = cc.regression_loss
    dims = [sample.shape[1] for sample in train_data[0]]
    start = cc.random_tn(dims, rank=1)

    # Shrink the starting network towards zero (in-place division per core).
    for idx, _core in enumerate(start):
        start[idx] /= 10
    start = cc.make_trainable(start)

    plateau_settings = {'mode': 'min', 'patience': 50, 'threshold': 1e-7}
    search_settings = {
        'cprint': True,
        'epochs': None,
        'max_iter': 20,
        'lr': 0.01,
        'optim': 'RMSprop',
        'search_epochs': 80,
        'cvg_threshold': 1e-10,
        'bsize': 100,
        'is_reg': True,
        'stop_on_plateau': plateau_settings,
        'dyn_print': True,
        'initial_epochs': 10,
    }
    _, _, record = randomwalk_optim(start,
                                    train_data,
                                    reg_loss,
                                    val_data=val_data,
                                    other_args=search_settings)
    return record
Example #3
0
def greedy_completion(dataset, input_dims, initial_network=None, filename=None):
    """Run a greedy rank search for tensor completion and return its loss record.

    Args:
        dataset:         data passed to ``greedy_optim`` and evaluated with
                         ``cc.completion_loss``.
        input_dims:      input dimensions used to build a random rank-1
                         starting network when ``initial_network`` is not given.
        initial_network: optional starting network. When truthy it is used
                         as-is and no random network is drawn.
        filename:        forwarded to ``greedy_optim`` as
                         ``other_args['filename']``.

    Returns:
        The ``loss_record`` returned by ``greedy_optim``.
    """
    loss_fun = cc.completion_loss

    if initial_network:
        base_tn = initial_network
    else:
        # Random rank-1 starting network, used at the scale cc.random_tn
        # produces (no rescaling is applied).
        base_tn = cc.make_trainable(cc.random_tn(input_dims, rank=1))

    def lr_scheduler(optimizer):
        # Halve the learning rate after 50 scheduler steps without an
        # improvement larger than the 1e-7 threshold.
        return ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                 patience=50, verbose=True, threshold=1e-7)

    # NOTE(review): the search can presumably be restricted to tensor-ring
    # edges [(i, i + 1) ...] + [(0, ndims - 1)] by adding an 'allowed_edges'
    # entry to other_args; confirm that greedy_optim forwards it.
    _, _, loss_record = greedy_optim(
        base_tn,
        dataset,
        loss_fun,
        find_best_edge=greedy_find_best_edge,
        other_args={
            'cprint': True,
            'epochs': 20000,
            'max_iter': 100,
            'lr': 0.01,
            'optim': 'RMSprop',
            'search_epochs': 20,
            'cvg_threshold': 1e-10,
            'dyn_print': True,
            'initial_epochs': 10,
            'bsize': -1,
            'rank_increment': 2,
            'lr_scheduler': lr_scheduler,
            'filename': filename,
        })

    return loss_record
Example #4
0
def randomwalk_optim(tensor_list,
                     train_data,
                     loss_fun,
                     val_data=None,
                     other_args=None,
                     max_iter=None):
    """
    Train a tensor network using discrete optimization over TN ranks.

    After an initial continuous optimization, each discrete iteration picks a
    random pair of cores (i, j) and increases the rank of the edge between them
    by ``rank_increment``. It then trains the new network for ``search_epochs``
    epochs, where only the new slices receive gradient unless ``is_reg``.
    Finally it trains the network for ``epochs`` epochs, or until convergence,
    resuming the optimizer state saved by the search phase. The random walk
    always accepts the new network.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network (target tensor)
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used
                     for early stopping within continuous optimization
                     calls within discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization,
                     with some options below (feel free to add more).
                     The dictionary is copied, so the caller's dict is
                     never modified.
                        epochs: Number of epochs for
                                continuous optimization     (default=10)
                        max_iter: Maximum number of iterations
                                  for random walk search    (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        reps:   Number of times to repeat
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        search_epochs: Number of epochs used to train the
                                new slices. If not given, the epochs
                                argument is used.           (default=epochs)
                        loss_threshold: if loss gets below this threshold,
                            discrete optimization is stopped
                                                            (default=1e-5)
                        initial_epochs: Number of epochs after which the
                            learning rate is reduced and optimization is restarted
                            if there is no improvement in the loss.
                                                            (default=None)
                        rank_increment: how much a rank is increased at
                            each discrete iteration         (default=1)
                        is_reg: if True, all parameters are trained during
                            the search phase (no gradient masking)
                                                            (default=False)
                        stop_on_plateau: a dictionary containing keys
                                mode  (min/max)
                                patience
                                threshold
                            used to stop continuous optimization when plateau
                            is detected                     (default=None)
                        gradient_hook: this is a hack, please ignore...
        max_iter:    Number of discrete iterations used when ``other_args``
                     has no 'max_iter' key                  (default=10)

    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm. The TN ranks of better_list will be larger
                     than those of tensor_list.
        best_loss:   The value of the validation/training loss for the
                     model output as better_list
        loss_record: History of all losses for discrete and continuous
                     optimization. This is a list of dictionaries (the first
                     one, with iter=-1, describes the input network) with keys
                           iter,num_params,network,loss,train_loss_hist,val_loss_hist
                     where
                     iter: iteration of the discrete optimization
                     num_params: number of parameters of best network for this iteration
                     network: list of tensors for the best network for this iteration
                     loss: loss achieved by the best network in this iteration
                     train_loss_hist / val_loss_hist: history of losses for the
                        continuous optimization for this iteration starting from
                        the previous best network up to the epoch where the new
                        best network was found
    """
    # Work on a private copy: the bookkeeping keys written below ('hist',
    # 'stop_condition', optimizer-state keys, ...) must neither leak into the
    # caller's dict nor accumulate in a shared default across calls.
    other_args = {} if other_args is None else dict(other_args)

    # Read settings, falling back to the documented defaults.
    epochs = other_args.get('epochs', 10)
    max_iter = other_args.get('max_iter', 10 if max_iter is None else max_iter)
    dprint = other_args.get('dprint', True)
    search_epochs = other_args.get('search_epochs', epochs)
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs', None)
    rank_increment = other_args.get('rank_increment', 1)
    gradient_hook = other_args.get('gradient_hook', None)
    is_reg = other_args.get('is_reg', False)
    stop_on_plateau = other_args.get('stop_on_plateau', None)

    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        other_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))

    other_args['hist'] = True  # we always track the history of losses

    def m_print(s):
        """Print ``s`` only when discrete-optimization printing is enabled."""
        if dprint:
            print(s)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    n_cores = len(tensor_list)
    loss_record = [{
        'iter': -1,
        'network': tensor_list,
        'num_params': cc.num_params(tensor_list),
        'train_loss_hist': [0],
        'val_loss_hist': [0]
    }]

    # Stop condition for the discrete optimization procedure.
    stop_cond = lambda loss: loss < loss_threshold

    # Iteratively increment ranks of tensor_list and train via continuous
    # optimization. Stage 0 is the initial training; each later stage grows
    # one randomly chosen edge.
    stage = -1
    best_loss, best_network = np.inf, None

    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if stage == 0:  # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list,
                initial_epochs,
                train_data,
                loss_fun,
                epochs=epochs,
                val_data=val_data,
                other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")

            initialNetwork = cc.copy_network(tensor_list)
            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': hist[0][:best_epoch],
                'val_loss_hist': hist[1][:best_epoch]
            })

        else:
            if n_cores < 2:
                # No pair of cores exists, so no edge rank can be increased.
                m_print("Fewer than two cores: no edge to grow, stopping.")
                break

            m_print(
                f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n"
            )

            i, j = torch.randperm(n_cores)[:2].tolist()
            currentNetwork = cc.copy_network(initialNetwork)
            # increase rank along a chosen dimension
            currentNetwork = cc.increase_rank(currentNetwork, i, j,
                                              rank_increment, 1e-6)
            currentNetwork = cc.make_trainable(currentNetwork)
            print('\ntesting rank increment for i =', i, 'j = ', j)

            ### Search optimization phase
            # we do only a few epochs to identify the most promising rank update

            def grad_masking_function(cores):
                """Zero every gradient entry except the new slices of edge (i, j).

                Core i carries the edge on axis j and core j on axis i. The
                edge axis is moved to the front and all but the last
                ``rank_increment`` slices are zeroed.
                NOTE(review): like the previous ``[:-1]`` indexing, this
                assumes cc.increase_rank appends the new slices at the end
                of the edge axis — confirm.
                """
                n = len(cores)
                for k, core in enumerate(cores):
                    if k == i:
                        edge_axis = j
                    elif k == j:
                        edge_axis = i
                    else:
                        core.grad *= 0
                        continue
                    order = [edge_axis] + [d for d in range(n) if d != edge_axis]
                    # permute returns a view, so the in-place product
                    # writes through to the gradient.
                    core.grad.permute(order)[:-rank_increment, ...] *= 0

            # we first optimize only the new slices for a few epochs
            print("optimize new slices for a few epochs")
            search_args = dict(other_args)
            current_network_optimizer_state = {}
            search_args["save_optimizer_state"] = True
            search_args["optimizer_state"] = current_network_optimizer_state
            if not is_reg:
                search_args["grad_masking_function"] = grad_masking_function
            if stop_on_plateau:
                detect_plateau._reset()
            currentNetwork, _, current_loss, best_epoch, hist = training(
                currentNetwork,
                initial_epochs,
                train_data,
                loss_fun,
                val_data=val_data,
                epochs=search_epochs,
                other_args=search_args)

            best_network = currentNetwork
            best_network_optimizer_state = deepcopy(
                current_network_optimizer_state)
            best_train_loss_hist = deepcopy(hist[0][:best_epoch])
            best_val_loss_hist = deepcopy(hist[1][:best_epoch])
            print('-> best rank update so far:', i, j)

            # train network to convergence, resuming the search optimizer state
            print('\ntraining best network until max_epochs/convergence...')
            other_args["load_optimizer_state"] = best_network_optimizer_state
            current_network_optimizer_state = {}
            other_args["save_optimizer_state"] = True
            other_args["optimizer_state"] = current_network_optimizer_state

            if gradient_hook:  # A bit of a hack only used for tensor completion - ignore
                other_args["grad_masking_function"] = gradient_hook
            if stop_on_plateau:
                detect_plateau._reset()

            best_network, _, best_loss, best_epoch, hist = training(
                best_network,
                initial_epochs,
                train_data,
                loss_fun,
                val_data=val_data,
                epochs=epochs,
                other_args=other_args)

            if gradient_hook:
                other_args["grad_masking_function"] = None

            other_args["load_optimizer_state"] = None
            best_train_loss_hist += deepcopy(hist[0][:best_epoch])
            best_val_loss_hist += deepcopy(hist[1][:best_epoch])

            initialNetwork = cc.copy_network(best_network)

            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': best_train_loss_hist,
                'val_loss_hist': best_val_loss_hist
            })

            print('\nbest TN:')
            cc.print_ranks(best_network)
            print('number of params:', cc.num_params(best_network))
            # Summary: (iter, num_params, loss, first train loss, last train
            # loss); NaN stands in for an empty train history instead of
            # raising IndexError.
            summary = []
            for r in loss_record:
                train_hist = r['train_loss_hist']
                if len(train_hist) > 0:
                    first, last = float(train_hist[0]), float(train_hist[-1])
                else:
                    first = last = float('nan')
                summary.append((r['iter'], r['num_params'], float(r['loss']),
                                first, last))
            print(summary)

    return best_network, best_loss, loss_record
def main(args):
    """Compare random-walk and random-search rank searches on a regression
    task, then run a separate chain-recovery experiment.

    ``args`` must provide ``ntrain``, ``nval``, ``path`` and ``maxiter``.
    ``path`` is a file loaded with ``torch.load`` as the goal network, and it
    is also used to build the figure file names under ``./figures/``.

    NOTE(review): the random-walk call unpacks 7 values and the random-search
    call unpacks 6. The ``randomwalk_optim``/``randomsearch_optim`` versions
    that accompany this code return only 3 (best_network, best_loss,
    loss_record). Confirm which API the ``rw``/``rs`` modules expose.
    """

    def search_settings():
        # A fresh dict per call: the optimizers may write bookkeeping keys
        # into the dict they receive.
        return {
            'dhist': True,
            'optim': 'RMSprop',
            'max_iter': args.maxiter,
            'epochs': None,  # early stopping
            'lr': 0.001
        }

    def save_figure(method, suffix, xlabel, ylabel, *series):
        # One 4x3 figure per plot, saved as ./figures/<path><method><suffix>_.pdf
        plt.figure(figsize=(4, 3))
        plt.plot(*series)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.savefig('./figures/' + args.path + method + suffix + '_.pdf',
                    bbox_inches='tight')

    train_count = args.ntrain
    val_count = args.nval
    # NOTE(review): torch.load unpickles the file; only point path at trusted data.
    goal_tn = torch.load(args.path)
    base_tn = cc.make_trainable(cc.random_tn([7] * 5, rank=1))
    loss_fun = cc.regression_loss
    train_data = cc.generate_regression_data(goal_tn, train_count, noise=1e-6)
    val_data = cc.generate_regression_data(goal_tn, val_count, noise=1e-6)

    ### random walk
    (best_network, first_loss, best_loss, loss_record, loss_hist, param_count,
     d_loss_hist) = rw.randomwalk_optim(base_tn,
                                        train_data,
                                        loss_fun,
                                        val_data=val_data,
                                        other_args=search_settings())

    save_figure('_randomwalk', '_trainloss', 'Epoch', 'Training loss',
                loss_hist[0])
    save_figure('_randomwalk', '_trainloss_numparam', 'Number of parameters',
                'Training loss', param_count, d_loss_hist[0])
    save_figure('_randomwalk', '_valloss_numparam', 'Number of parameters',
                'Validation loss', param_count, loss_record)

    ### TODO: Add greedy

    ### random search
    (best_network, first_loss, best_loss, loss_record, param_count,
     d_loss_hist) = rs.randomsearch_optim(base_tn,
                                          train_data,
                                          loss_fun,
                                          val_data=val_data,
                                          other_args=search_settings())

    save_figure('_randomsearch', '_trainloss_numparam', 'Number of parameters',
                'Training loss', param_count, d_loss_hist[0])
    save_figure('_randomsearch', '_valloss_numparam', 'Number of parameters',
                'Validation loss', param_count, loss_record)

    ### tensor decomposition of a chain-shaped target
    torch.manual_seed(0)
    # Six cores of input dimension 4. Row k of rank_list presumably lists the
    # ranks of edges (k, k+1), (k, k+2), ...; only neighbouring cores get a
    # rank above 1 (2, 3, 6, 5, 4), which makes the target a chain.
    input_dims = [4] * 6
    chain_ranks = [2, 3, 6, 5, 4]
    rank_list = [[rank] + [1] * (len(chain_ranks) - 1 - k)
                 for k, rank in enumerate(chain_ranks)]

    # Parameters to control the experimental behavior
    exp_params = {'print': False, 'epochs': 200}

    loss_fun = cc.tensor_recovery_loss
    base_tn = cc.random_tn(input_dims, rank=1)
    goal_tn = cc.random_tn(input_dims, rank=rank_list)
    base_tn = cc.make_trainable(base_tn)
    _, _, better_loss = discrete_optim_template(base_tn,
                                                goal_tn,
                                                loss_fun,
                                                other_args=exp_params)
    print('better loss = ', better_loss)
Example #7
0
def randomsearch_optim(tensor_list,
                       train_data,
                       loss_fun,
                       val_data=None,
                       other_args=None):
    """
    Train a tensor network using discrete optimization over TN ranks.

    After training the input network, each discrete iteration draws random
    edge ranks in [1, max_rank]. It builds a fresh random network with those
    ranks, limited to ``max_params`` parameters by ``_limit_random_tn``, and
    trains it. The best network seen so far is kept.

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network
        loss_fun:    Scalar-valued loss function of the type
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        val_data:    The data used for validation, which can be used
                     for early stopping within continuous optimization
                     calls within discrete optimization loop
        other_args:  Dictionary of other arguments for the optimization,
                     with some options below (feel free to add more).
                     The dictionary is copied, so the caller's dict is
                     never modified.
                        epochs: Number of epochs for
                                continuous optimization     (default=10)
                        max_iter: Maximum number of iterations
                                  for random search         (default=10)
                        optim:  Choice of Pytorch optimizer (default='SGD')
                        lr:     Learning rate for optimizer (default=1e-3)
                        bsize:  Minibatch size for training (default=100)
                        max_rank: Maximum rank to search for  (default=7)
                        max_params: Parameter budget passed to
                                _limit_random_tn            (default=10000)
                        reps:   Number of times to repeat
                                training data per epoch     (default=1)
                        cprint: Whether to print info from
                                continuous optimization     (default=True)
                        dprint: Whether to print info from
                                discrete optimization       (default=True)
                        loss_threshold: if loss gets below this threshold,
                            discrete optimization is stopped
                                                            (default=1e-5)
                        initial_epochs: Number of epochs after which the
                            learning rate is reduced and optimization is restarted
                            if there is no improvement in the loss.
                                                            (default=None)
                        stop_on_plateau: a dictionary containing keys
                                mode  (min/max)
                                patience
                                threshold
                            used to stop continuous optimization when plateau
                            is detected                     (default=None)

    Returns:
        better_list: List of tensors with same length as tensor_list, but
                     having been optimized using the discrete optimization
                     algorithm.
        best_loss:   The value of the validation/training loss for the
                     model output as better_list
        loss_record: History of all losses for discrete and continuous
                     optimization. This is a list of dictionaries (the first
                     one, with iter=-1, describes the input network) with keys
                           iter,num_params,network,loss,train_loss_hist,val_loss_hist
                     where
                     iter: iteration of the discrete optimization
                     num_params: number of parameters of the network trained in this iteration
                     network: list of tensors for the network trained in this iteration
                     loss: loss achieved by that network
                     train_loss_hist / val_loss_hist: history of losses for the
                        continuous optimization of that network up to its best epoch
    """
    # Work on a private copy so bookkeeping keys ('hist', 'stop_condition')
    # neither leak into the caller's dict nor accumulate in a shared default.
    other_args = {} if other_args is None else dict(other_args)

    # Read settings, falling back to the documented defaults.
    epochs = other_args.get('epochs', 10)
    max_iter = other_args.get('max_iter', 10)
    max_rank = other_args.get('max_rank', 7)
    max_params = other_args.get('max_params', 10000)
    dprint = other_args.get('dprint', True)
    loss_threshold = other_args.get('loss_threshold', 1e-5)
    initial_epochs = other_args.get('initial_epochs', None)
    # Keep track of continuous optimization history
    other_args['hist'] = True

    stop_on_plateau = other_args.get('stop_on_plateau', None)
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        other_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))

    def m_print(s):
        """Print ``s`` only when discrete-optimization printing is enabled."""
        if dprint:
            print(s)

    # Copy tensor_list so the original is unchanged
    tensor_list = cc.copy_network(tensor_list)
    loss_record = [{
        'iter': -1,
        'network': tensor_list,
        'num_params': cc.num_params(tensor_list),
        'train_loss_hist': [0],
        'val_loss_hist': [0]
    }]

    # Stop condition for the discrete optimization procedure.
    stop_cond = lambda loss: loss < loss_threshold

    input_dims = cc.get_indims(tensor_list)
    n_cores = len(input_dims)
    n_edges = n_cores * (n_cores - 1) // 2  # one rank per pair of cores

    # Stage 0 trains the input network; each later stage trains a fresh
    # network with randomly drawn ranks.
    stage = -1
    best_loss, best_network = np.inf, None

    while not stop_cond(best_loss) and stage < max_iter:
        stage += 1
        if stage == 0:  # first continuous optimization
            tensor_list, first_loss, best_loss, best_epoch, hist = training(
                tensor_list,
                initial_epochs,
                train_data,
                loss_fun,
                epochs=epochs,
                val_data=val_data,
                other_args=other_args)
            m_print("Initial model has TN ranks")
            if dprint:
                cc.print_ranks(tensor_list)
            m_print(f"Initial loss is {best_loss:.7f}")

            best_network = cc.copy_network(tensor_list)
            loss_record[0]["loss"] = first_loss
            loss_record.append({
                'iter': stage,
                'network': best_network,
                'num_params': cc.num_params(best_network),
                'loss': best_loss,
                'train_loss_hist': hist[0][:best_epoch],
                'val_loss_hist': hist[1][:best_epoch]
            })

        else:
            m_print(
                f"\n\n**** Discrete optimization - iteration {stage} ****\n\n\n"
            )
            # Create new TN with random ranks in [1, max_rank]
            ranks = torch.randint(low=1, high=max_rank + 1,
                                  size=(n_edges, 1)).view(-1, ).tolist()
            rank_list = _make_ranks(ranks)
            current_network = _limit_random_tn(input_dims,
                                               rank=rank_list,
                                               max_params=max_params)
            current_network = cc.make_trainable(current_network)
            if stop_on_plateau:
                detect_plateau._reset()
            current_network, _, current_loss, best_epoch, hist = training(
                current_network,
                initial_epochs,
                train_data,
                loss_fun,
                val_data=val_data,
                epochs=epochs,
                other_args=other_args)
            loss_record.append({
                'iter': stage,
                'network': current_network,
                'num_params': cc.num_params(current_network),
                'loss': current_loss,
                'train_loss_hist': hist[0][:best_epoch],
                'val_loss_hist': hist[1][:best_epoch]
            })
            if best_loss > current_loss:
                best_network = current_network
                best_loss = current_loss

            print('\nbest TN:')
            cc.print_ranks(best_network)
            print('number of params:', cc.num_params(best_network))
            # Summary: (iter, num_params, loss, first train loss, last train
            # loss); NaN stands in for an empty train history (best epoch 0)
            # instead of raising IndexError.
            summary = []
            for r in loss_record:
                train_hist = r['train_loss_hist']
                if len(train_hist) > 0:
                    first, last = float(train_hist[0]), float(train_hist[-1])
                else:
                    first = last = float('nan')
                summary.append((r['iter'], r['num_params'], float(r['loss']),
                                first, last))
            print(summary)

    return best_network, best_loss, loss_record
Example #8
0
def greedy_find_best_edge(initialNetwork, train_data, val_data, loss_fun,
                          rank_increment, training_args, prev_best_loss,
                          allowed_edges=None, search_epochs=10,
                          padding_noise=1e-6, is_reg=False):
    """Try a rank increase on every allowed edge and return the best candidate.

    For each edge (i, j), a copy of ``initialNetwork`` has the rank between
    cores i and j increased by ``rank_increment``. It is then trained in two
    short phases of ``search_epochs`` epochs each. The first phase passes
    ``only_new_slice=(i, j)`` to ``training`` (not set when ``is_reg``) so that
    only the new slices are optimized. The second phase trains all parameters
    and resumes the optimizer state saved by the first. The candidate with the
    lowest final loss is kept.

        initialNetwork: list of tensors representing the initial network
        train_data, val_data, loss_fun: forwarded to ``training``
        rank_increment: how much the rank of the tested edge is increased
        training_args: dictionary of arguments passed to the training
                    function. If it contains 'stop_on_plateau', a
                    'stop_condition' entry is written into it. A missing
                    'initial_epochs' key means None.
        prev_best_loss: loss from the previous discrete iteration; only
                    printed for comparison.
        allowed_edges: a list of allowed edges of rank more than
                    one in the tensor network. If None (or empty), all
                    pairs i < j are considered.       (default=None)
        search_epochs: number of epochs of each search phase (default=10)
        padding_noise: standard deviation of the normal distribution used to
                    initialize the new slices of the core tensors receiving
                    the new edge                      (default=1e-6)
        is_reg:     if True, the first phase does not restrict training to
                    the new slices                    (default=False)

    Returns:
        (best_network, best_search_loss, best_network_optimizer_state,
         best_train_loss_hist, best_val_loss_hist). When no edge is tried,
         these are (None, inf, None, [], []).
    """
    if not allowed_edges:
        ndims = len(initialNetwork)
        allowed_edges = [(i, j) for i in range(ndims) for j in range(i + 1, ndims)]

    stop_on_plateau = training_args.get('stop_on_plateau')
    if stop_on_plateau:
        detect_plateau = DetectPlateau(**stop_on_plateau)
        training_args["stop_condition"] = (
            lambda train_loss, val_loss: detect_plateau(train_loss))

    # Read with a default so both search phases work when the key is absent.
    initial_epochs = training_args.get('initial_epochs')

    best_search_loss = np.inf
    best_network = None
    best_network_optimizer_state = None
    best_train_loss_hist = []
    best_val_loss_hist = []

    for (i, j) in allowed_edges:
        current_network = cc.copy_network(initialNetwork)
        # increase rank along a chosen dimension
        current_network = cc.increase_rank(current_network, i, j,
                                           rank_increment, padding_noise)
        current_network = cc.make_trainable(current_network)
        print('\ntesting rank increment for i =', i, 'j = ', j)

        ### Search optimization phase
        # we do only a few epochs to identify the most promising rank update
        # we first optimize only the new slices for a few epochs
        print("optimize new slices for a few epochs")
        search_args = dict(training_args)
        current_network_optimizer_state = {}
        search_args["save_optimizer_state"] = True
        search_args["optimizer_state"] = current_network_optimizer_state

        if not is_reg:
            search_args["only_new_slice"] = (i, j)
        if stop_on_plateau:
            detect_plateau._reset()

        current_network, _, current_loss, best_epoch, hist = training(
            current_network, initial_epochs, train_data, loss_fun,
            val_data=val_data, epochs=search_epochs, other_args=search_args)
        train_loss_hist = deepcopy(hist[0][:best_epoch])
        val_loss_hist = deepcopy(hist[1][:best_epoch])

        # We then optimize all parameters for a few epochs, resuming the
        # optimizer state saved by the first phase.
        print("\noptimize all parameters for a few epochs")
        search_args["load_optimizer_state"] = dict(current_network_optimizer_state)
        search_args["only_new_slice"] = False
        if stop_on_plateau:
            detect_plateau._reset()
        current_network, _, current_loss, best_epoch, hist = training(
            current_network, initial_epochs, train_data, loss_fun,
            val_data=val_data, epochs=search_epochs, other_args=search_args)
        train_loss_hist += deepcopy(hist[0][:best_epoch])
        val_loss_hist += deepcopy(hist[1][:best_epoch])

        print(f"\nCurrent loss is {current_loss:.7f}    Best loss from previous discrete optim is {prev_best_loss}")
        if best_search_loss > current_loss:
            best_search_loss = current_loss
            best_network = current_network
            best_network_optimizer_state = deepcopy(current_network_optimizer_state)
            best_train_loss_hist = train_loss_hist
            best_val_loss_hist = val_loss_hist
            print('-> best rank update so far:', i, j)

    return (best_network, best_search_loss, best_network_optimizer_state,
            best_train_loss_hist, best_val_loss_hist)