Example No. 1
def maximalPlasticitySearch(args, manager):
    """ Phase 1. Coarse finetuning gridsearch."""
    start_time = time.time()
    finetune_lr, finetune_acc = finetune_single_task.lr_grid_single_task(
        args, manager, save_models_mode=args.save_models_mode)
    args.phase1_elapsed_time = time.time() - start_time
    utils.print_timing(args.phase1_elapsed_time, "PHASE 1 FT GRID")
    return finetune_lr, finetune_acc
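
The accuracy returned here serves as the plasticity reference for Phase 2: task training (Example No. 4 below) must stay within a fixed drop margin of it. A minimal sketch of that handoff, assuming a hypothetical drop_margin fraction p (the code passes the factor 1 - p as args.inv_drop_margin):

def plasticity_threshold(finetune_acc, drop_margin=0.2):
    # Phase 2 must reach at least A_ft * (1 - p), mirroring the convergence check in Example No. 4.
    return finetune_acc * (1.0 - drop_margin)

print(round(plasticity_threshold(0.70), 4))  # 0.56: a 20% drop is tolerated on a 0.70 finetuning accuracy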
Example No. 2
def eval_all_models_all_tasks(args, manager, ds_paths, model_paths):
    """
    Each model is evaluated on all its seen tasks.
    Outputs accuracy/forgetting dictionaries containing, for each task, the sequence of test results.
    """
    acc_all = {}
    forgetting_all = {}

    start_time = time.time()
    # All task results
    for dataset_index in range(args.test_starting_task_count - 1, args.test_max_task_count):
        method_performances = {manager.method.eval_name: {}}
        out_filepath = os.path.join(args.out_path,
                                    utils.get_perf_output_filename(manager.method.eval_name, dataset_index))
        args.eval_dset_idx = dataset_index
        if not args.test_overwrite_mode and not args.debug:  # safety check
            if os.path.exists(out_filepath):
                print("EVAL already done, can only rerun in overwrite mode")
                break

        # Test
        try:
            seq_acc, seq_forgetting, seq_head_acc = eval_task_steps_accuracy(args, manager, ds_paths, model_paths)

            if len(seq_acc[dataset_index]) == 0:
                msg = "SKIPPING SAVING: acc empty: ", seq_acc[dataset_index]
                print(msg)
                raise Exception(msg)

            # Collect results
            acc_all[dataset_index] = seq_acc[dataset_index]
            forgetting_all[dataset_index] = seq_forgetting[dataset_index]
            method_performances[manager.method.eval_name]['seq_res'] = seq_acc
            method_performances[manager.method.eval_name]['seq_forgetting'] = seq_forgetting
            method_performances[manager.method.eval_name]['seq_head_acc'] = seq_head_acc

            # Save results
            if not args.debug:
                torch.save(method_performances, out_filepath)
                print("Saved results to: ", out_filepath)

        except Exception as e:
            print("TESTING ERROR: ", e)
            print("No results saved...")
            traceback.print_exc()
            break

    elapsed_time = time.time() - start_time
    utils.print_timing(elapsed_time, title="TOTAL EVAL")
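
The file saved above is a plain dict keyed by the method's eval name, holding the 'seq_res', 'seq_forgetting' and 'seq_head_acc' sequences. Reading it back is a single torch.load; a minimal sketch, with a hypothetical path and eval name:

import torch

out_filepath = "results/SI/test_performances_task3.pth"  # hypothetical; built via utils.get_perf_output_filename
eval_name = "SI"                                         # hypothetical manager.method.eval_name

perf = torch.load(out_filepath)
seq_acc = perf[eval_name]['seq_res']                # per evaluated task: sequence of test accuracies
seq_forgetting = perf[eval_name]['seq_forgetting']  # per evaluated task: sequence of forgetting values
print(seq_acc)
print(seq_forgetting)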
Example No. 3
def lr_grid_single_task(args, manager, save_models_mode='keep_none'):
    """
    Finetunes from the starting model and acquires the best lr and acc. LR gridsearch, with #finetune_iterations per LR.
    Makes a symbolic link to the overall best iteration, corresponding to the obtained best lr.
    """

    # Init
    manager.store_policy = StoragePolicy(save_models_mode)
    args.task_name = manager.dataset.get_taskname(args.task_counter)
    manager.ft_parent_exp_dir = os.path.join(manager.parent_exp_dir,
                                             'task_' + str(args.task_counter),
                                             'FT_LR_GRIDSEARCH')
    utils.create_dir(manager.ft_parent_exp_dir)
    print("FINETUNE LR GRIDSEARCH: Task ", args.task_name)

    # Logfile
    logfile_parent_dir = os.path.join(manager.ft_parent_exp_dir, 'log')
    utils.create_dir(logfile_parent_dir)
    logfile = os.path.join(logfile_parent_dir,
                           utils.get_now() + '_finetune_grid.log')
    utils.append_to_file(logfile, "FINETUNE GRIDSEARCH LOG: Processed LRs")

    # Load Checkpoint
    processed_lrs = {}
    grid_checkpoint_file = os.path.join(manager.ft_parent_exp_dir,
                                        'grid_checkpoint.pth')
    if os.path.exists(grid_checkpoint_file):
        checkpoint = torch.load(grid_checkpoint_file)
        processed_lrs = checkpoint['processed_lrs']

        print("STARTING FROM CHECKPOINT: ", checkpoint)
        utils.append_to_file(logfile, "STARTING FROM CHECKPOINT")

    ########################################################
    # PRESTEPS
    args.presteps_elapsed_time = 0
    if hasattr(manager.method, 'grid_prestep'):
        manager.method.grid_prestep(args, manager)

    ########################################################
    # LR GRIDSEARCH
    best_acc = 0
    best_lr = None
    manager.best_exp_grid_node_dirname = None
    best_iteration_batch_dirs = []
    for lr in args.lrs:
        print("\n", "<" * 20, "LR ", lr, ">" * 20)
        accum_acc = 0
        best_iteration_dir = None
        best_iteration_acc = 0
        iteration_batch_dirs = []
        if lr not in processed_lrs:
            processed_lrs[lr] = {'acc': []}

        for finetune_iteration in range(0, args.finetune_iterations):
            print("\n", "-" * 20, "FT ITERATION ", finetune_iteration,
                  "-" * 20)
            start_time = time.time()

            # Paths
            exp_grid_node_dirname = "lr=" + str(
                utils.float_to_scientific_str(lr))
            if args.finetune_iterations > 1:
                exp_grid_node_dirname += "_it" + str(finetune_iteration)
            manager.gridsearch_exp_dir = os.path.join(
                manager.ft_parent_exp_dir, exp_grid_node_dirname)
            iteration_batch_dirs.append(manager.gridsearch_exp_dir)

            if finetune_iteration < len(processed_lrs[lr]['acc']):
                acc = processed_lrs[lr]['acc'][finetune_iteration]
                utils.set_random(finetune_iteration)
                print("RESTORING FROM CHECKPOINT: ITERATION = ",
                      finetune_iteration, "ACC = ", acc)
            else:
                # Set new seed for reproducibility
                utils.set_random(finetune_iteration)

                # Only actually saved when in save_model mode
                utils.create_dir(manager.gridsearch_exp_dir)

                # TRAIN
                model, acc = manager.method.grid_train(args, manager, lr)

                # Append results
                processed_lrs[lr]['acc'].append(acc)
                msg = "LR = {}, FT Iteration {}/{}, Acc = {}".format(
                    lr, finetune_iteration + 1, args.finetune_iterations, acc)
                print(msg)
                utils.append_to_file(logfile, msg)

            # New best
            if acc > best_iteration_acc:
                if args.finetune_iterations > 1:
                    msg = "=> NEW BEST FT ITERATION {}/{}: (Attempt '{}': Acc '{}' > best attempt Acc '{}')" \
                        .format(finetune_iteration + 1,
                                args.finetune_iterations,
                                finetune_iteration,
                                acc,
                                best_iteration_acc)
                    print(msg)
                    utils.append_to_file(logfile, msg)

                best_iteration_acc = acc
                best_iteration_dir = manager.gridsearch_exp_dir

            accum_acc = accum_acc + acc

            # update logfile/checkpoint
            torch.save({'processed_lrs': processed_lrs}, grid_checkpoint_file)

            # Save iteration hyperparams
            if hasattr(manager.method,
                       "grid_chkpt") and manager.method.grid_chkpt:
                it_elapsed_time = time.time() - start_time
                hyperparams = {
                    'val_acc': acc,
                    'lr': lr,
                    'iteration_elapsed_time': it_elapsed_time,
                    'args': vars(args),
                    'manager': vars(manager)
                }
                utils.print_timing(it_elapsed_time, 'TRAIN')
                manager.save_hyperparams(manager.gridsearch_exp_dir,
                                         hyperparams)
        avg_acc = accum_acc / args.finetune_iterations
        print("Done FT iterations\n")
        print("LR AVG ACC = ", avg_acc, ", BEST OF LRs ACC = ", best_acc)

        # New it-avg best
        if avg_acc > best_acc:
            best_lr = lr
            best_acc = avg_acc
            manager.best_exp_grid_node_dirname = best_iteration_dir  # Keep ref to best in all attempts
            print("UPDATE best lr = {}".format(best_lr))
            print("UPDATE best lr acc= {}".format(best_acc))

            utils.append_to_file(logfile,
                                 "UPDATE best lr = {}".format(best_lr))
            utils.append_to_file(logfile,
                                 "UPDATE best lr acc= {}\n".format(best_acc))

            # Clean all from previous best
            if manager.store_policy.only_keep_best:
                for out_dir in best_iteration_batch_dirs:
                    if os.path.exists(out_dir):
                        shutil.rmtree(out_dir, ignore_errors=True)
                        print("[CLEANUP] removing {}".format(out_dir))
            best_iteration_batch_dirs = iteration_batch_dirs
        else:
            if manager.store_policy.only_keep_best:
                for out_dir in iteration_batch_dirs:
                    if os.path.exists(out_dir):
                        shutil.rmtree(out_dir, ignore_errors=True)
                        print("[CLEANUP] removing {}".format(out_dir))
        if manager.store_policy.keep_none:
            for out_dir in iteration_batch_dirs:
                if os.path.exists(out_dir):
                    shutil.rmtree(out_dir, ignore_errors=True)
                    print("[CLEANUP] removing {}".format(out_dir))
    print("FINETUNE DONE: best_lr={}, best_acc={}".format(best_lr, best_acc))

    ########################################################
    # POSTPROCESS
    if hasattr(manager.method, 'grid_poststep'):
        manager.method.grid_poststep(args, manager)

    return best_lr, best_acc
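
Stripped of checkpointing, storage policies and logging, the core of this gridsearch is small: for each candidate lr, average the validation accuracy over a few finetuning runs and keep the lr with the best average. A self-contained toy sketch, with a stand-in for manager.method.grid_train:

import random

def toy_grid_train(lr, seed):
    # Stand-in for manager.method.grid_train: pretend accuracy peaks around lr=1e-3.
    random.seed(seed)
    return random.uniform(0.4, 0.8) - abs(lr - 1e-3)

def lr_grid(lrs, finetune_iterations=3):
    best_lr, best_acc = None, 0.0
    for lr in lrs:
        accs = [toy_grid_train(lr, seed=it) for it in range(finetune_iterations)]
        avg_acc = sum(accs) / finetune_iterations
        if avg_acc > best_acc:
            best_lr, best_acc = lr, avg_acc
    return best_lr, best_acc

print(lr_grid([1e-2, 1e-3, 1e-4]))  # lr=1e-3 wins in this toy setup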
Example No. 4
    def stabilityDecay(self, args, manager, finetune_lr, finetune_acc):
        """ Phase 2. """
        args.lr = finetune_lr  # Set current lr based on previous phase
        manager.heuristic_exp_dir = os.path.join(
            manager.parent_exp_dir, 'task_' + str(args.task_counter),
            'TASK_TRAINING')
        if hasattr(manager.method, 'train_init'):  # Setting some paths etc.
            manager.method.train_init(args, manager)

        chkpt_loaded = self.load_chkpt(manager)  # Always Load checkpoints
        if not chkpt_loaded:  # Init state
            self.attempts = 0
            self.hyperparams_backup = copy.deepcopy(self.hyperparams)
        if self.check_succes(manager):  # Skip this phase
            manager.best_model_path = os.path.join(
                manager.heuristic_exp_dir, 'best_model.pth.tar')  # Set paths
            return

        # PRESTEPS FOR METHODS
        args.presteps_elapsed_time = 0
        if hasattr(manager.method, 'prestep'):
            manager.method.prestep(args, manager)

        # CONTINUE
        max_attempts = args.max_attempts_per_task  # * len(manager.method.hyperparams)
        converged = False
        while not converged and self.attempts < max_attempts:
            print(" => ATTEMPT {}/{}: Hyperparams {}".format(
                self.attempts, max_attempts - 1, self.hyperparams))
            start_time = time.time()
            try:
                manager.method.hyperparams = self.hyperparams
                model, task_lr_acc = manager.method.train(
                    args, manager, self.hyperparams)
            except:
                traceback.print_exc()
                sys.exit(1)

            # Val accuracy should be at least a fraction (inv_drop_margin) of the finetuning accuracy
            threshold = finetune_acc * args.inv_drop_margin  # A_ft * (1 - p) as defined in the paper

            ########################################
            # CONVERGE POLICY
            ########################################
            if task_lr_acc >= threshold:
                print('CONVERGED, (acc = ', task_lr_acc, ") >= (threshold = ",
                      threshold, ")")
                converged = True
                args.convergence_iteration_elapsed_time = time.time() - start_time
                utils.print_timing(args.convergence_iteration_elapsed_time,
                                   "PHASE 2 CONVERGED FINAL IT")

            ########################################
            # DECAY POLICY
            ########################################
            else:
                print('DECAY HYPERPARAMS, (acc = ', task_lr_acc,
                      ") < (threshold = ", threshold, ")")
                self.hyperparamDecay(args, manager)
                self.attempts += 1

                # Cleanup unless last attempt
                if self.attempts < max_attempts:
                    print('CLEANUP of previous model')
                    utils.rm_dir(manager.heuristic_exp_dir)
                else:
                    print("RETAINING LAST ATTEMPT MODEL")
                    converged = True

            # CHECKPOINT
            self._save_chkpt(args, manager, threshold, task_lr_acc)
            self._print_status()  # Framework Status

        # POST PREP
        manager.best_model_path = os.path.join(manager.heuristic_exp_dir,
                                               'best_model.pth.tar')
        manager.create_success_token(manager.heuristic_exp_dir)
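
The attempt loop implements the stability decay of the framework: train with the current hyperparameters and, if the validation accuracy drops below finetune_acc * args.inv_drop_margin, decay the hyperparameters and retry, up to max_attempts_per_task times. A minimal self-contained sketch with a single hypothetical regularization strength and decay factor (the real hyperparamDecay is method-specific):

def stability_decay(train_fn, finetune_acc, inv_drop_margin=0.8,
                    reg_strength=400.0, decay=0.5, max_attempts=5):
    # Decay the regularization strength until the val accuracy clears the threshold.
    threshold = finetune_acc * inv_drop_margin  # A_ft * (1 - p)
    for attempt in range(max_attempts):
        acc = train_fn(reg_strength)
        if acc >= threshold:
            return reg_strength, acc            # converged
        reg_strength *= decay                   # decay policy: trade stability for plasticity
    return reg_strength, acc                    # keep the last attempt, as the code above does

# Toy train_fn: a weaker regularizer lets the new task reach a bit more accuracy.
print(stability_decay(lambda reg: 0.7 - 0.0004 * reg, finetune_acc=0.7))  # converges at 200.0 with acc ~0.62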
Example No. 5
def preprocess_merge_IMM(method, model_paths, datasets_path, batch_size, overwrite=False, debug=True):
    """
    Create and save all merged models.
    :param model_paths: chronologically ordered list of paths to the model saved for each trained task.
    """
    merged_model_paths = []
    IMM_mode = method.mode
    merge_model_name = 'best_model_' + IMM_mode + '_merge.pth.tar'

    last_task_idx = len(model_paths) - 1

    # Avoiding memory overload when merged models already exist
    if not overwrite:
        for task_list_index in range(len(model_paths) - 1, 0, -1):
            merged_model_path = os.path.join(os.path.dirname(model_paths[task_list_index]), merge_model_name)

            if os.path.exists(merged_model_path):
                print("SKIPPING, MERGE ALREADY EXISTS for task ", task_list_index)
                last_task_idx = task_list_index
            else:
                break

    # Load models in memory
    models = [torch.load(model_path) for model_path in model_paths]
    print("MODELS TO PROCESS:")
    print('\n'.join(model_paths[:last_task_idx + 1]))
    print("LOADED ", len(models), " MODELS in MEMORY")

    # Keep first model (no merge needed)
    merged_model_paths.append(model_paths[0])

    # Head param names
    last_layer_index = str(len(models[0].classifier._modules) - 1)
    head_param_names = ['classifier.{}.{}'.format(last_layer_index, name) for name, p in
                        models[0].classifier._modules[last_layer_index].named_parameters()]

    if debug:
        print("HEAD PARAM NAMES")
        [print(name) for name in head_param_names]

    # equal_alpha = 1 / len(model_paths)
    # alphas = [equal_alpha for model_path in range(0, len(model_paths))]

    # Calculate precisions and sum of all precisions
    if IMM_mode == method.modes[1]:
        start_time = time.time()
        print("MODE IMM PREPROCESSING")
        precision_matrices = []
        sum_precision_matrices = []  # All summed of previous tasks (first task not included)
        precision_name = 'precision_' + IMM_mode + '.pth.tar'
        sum_precision_matrix = None
        for task_list_index in range(0, last_task_idx + 1):
            print("TASK ", task_list_index)
            precision_out_file_path = os.path.join(os.path.dirname(model_paths[task_list_index]), precision_name)
            sum_precision_out_file_path = os.path.join(os.path.dirname(model_paths[task_list_index]),
                                                       "sum_" + precision_name)

            if os.path.exists(precision_out_file_path) and not overwrite:
                precision_matrix = torch.load(precision_out_file_path)
                print('LOADED PRECISION MATRIX FOR TASK {} : {}'.format(task_list_index, precision_out_file_path))
            else:
                # get model and data
                model = models[task_list_index]
                dsets = torch.load(datasets_path[task_list_index])

                dset_loaders = {
                    x: torch.utils.data.DataLoader(dsets[x], batch_size=batch_size, shuffle=True, num_workers=8,
                                                   pin_memory=True)
                    for x in ['train', 'val']}

                # get parameters precision estimation
                if debug:
                    print("PARAM NAMES")
                    [print(n) for n, p in model.named_parameters() if p.requires_grad]
                model.params = {n: p for n, p in model.named_parameters() if p.requires_grad}
                precision_matrix = diag_fisher(model, dset_loaders, exclude_params=head_param_names)

                assert [precision_matrix.keys()] == [
                    {name for name, p in model.named_parameters() if name not in head_param_names}]
                del model, dset_loaders, dsets

                print("Saving precision matrix: ", precision_out_file_path)
                torch.save(precision_matrix, precision_out_file_path)
            precision_matrices.append(precision_matrix)

            # Update sum
            # Make sum matrix for each of the tasks! (incremental sum)
            if sum_precision_matrix is None:
                sum_precision_matrix = precision_matrix
            else:
                if os.path.exists(sum_precision_out_file_path) and not overwrite:
                    sum_precision_matrix = torch.load(sum_precision_out_file_path)
                    print('LOADED SUM-PRECISION MATRIX FOR TASK {} : {}'.format(task_list_index,
                                                                                sum_precision_out_file_path))
                else:
                    if debug:
                        for name, p in sum_precision_matrix.items():
                            print("{}: {} -> {}".format(name, p.shape, precision_matrix[name].shape))
                    sum_precision_matrix = {name: p + precision_matrix[name]
                                            for name, p in sum_precision_matrix.items()}
                    # Sanity check: the running sum should differ from the latest precision matrix
                    assert any((p != precision_matrix[name]).any().item()
                               for name, p in sum_precision_matrix.items())

                    # Save
                    torch.save(sum_precision_matrix, sum_precision_out_file_path)
                    print("Saving SUM precision matrix: ", sum_precision_out_file_path)

                sum_precision_matrices.append(sum_precision_matrix)
        elapsed_time = time.time() - start_time
        utils.print_timing(elapsed_time, title="MODE IMM IWS")

    # Create merged model for each task (except first)
    start_time = time.time()
    for task_list_index in range(1, last_task_idx + 1):
        out_file_path = os.path.join(os.path.dirname(model_paths[task_list_index]), merge_model_name)

        # Mean IMM
        if IMM_mode == method.modes[0]:
            merged_model = IMM_merge_models(models, task_list_index, head_param_names, mean_mode=True)
        # Mode IMM
        elif IMM_mode == method.modes[1]:
            merged_model = IMM_merge_models(models, task_list_index, head_param_names, precision=precision_matrices,
                                            sum_precision=sum_precision_matrices[task_list_index - 1], mean_mode=False)
        else:
            raise ValueError("IMM mode is not supported: ", str(IMM_mode))

        # Save merged model on same spot as best_model
        torch.save(merged_model, out_file_path)
        merged_model_paths.append(out_file_path)
        print(" => SAVED MERGED MODEL: ", out_file_path)

        del merged_model
    del models
    elapsed_time = time.time() - start_time
    utils.print_timing(elapsed_time, title="IMM MERGING")

    print("MERGED MODELS:")
    print('\n'.join(merged_model_paths))

    return merged_model_paths
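
The two merge variants differ only in how the shared parameters are combined (the task-specific heads are handled separately via head_param_names): mean IMM takes a plain average over the task models, mode IMM a precision-weighted average using the diagonal Fisher estimates computed above. A self-contained sketch on state-dict-like tensors, assuming equal task weights; the real IMM_merge_models additionally deals with the per-task heads:

import torch

def imm_merge(state_dicts, precisions=None):
    # Average shared parameters across tasks; weight by diagonal precision for mode IMM.
    merged = {}
    for name in state_dicts[0]:
        params = torch.stack([sd[name] for sd in state_dicts])
        if precisions is None:                   # mean IMM
            merged[name] = params.mean(dim=0)
        else:                                    # mode IMM
            prec = torch.stack([p[name] for p in precisions])
            merged[name] = (prec * params).sum(dim=0) / prec.sum(dim=0).clamp(min=1e-12)
    return merged

# Two toy "models" holding a single shared weight tensor each
sd1, sd2 = {'w': torch.ones(2)}, {'w': 3 * torch.ones(2)}
print(imm_merge([sd1, sd2]))                                         # mean IMM: 2.0
print(imm_merge([sd1, sd2], precisions=[{'w': torch.ones(2)},
                                        {'w': 3 * torch.ones(2)}]))  # mode IMM: 2.5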
Example No. 6
def termination_protocol(since, best_acc):
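    """Report total training time and the best validation accuracy."""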
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    utils.print_timing(time_elapsed, "TRAINING ONLY")
    print('Best val Acc: {:.4f}'.format(best_acc))