Code example #1
def run():
    """ Run the script"""
    #############
    ### Setup ###
    #############

    config = train_args_seq_smnist.parse_cmd_arguments()
    device, writer, logger = sutils.setup_environment(config)
    dhandlers = ctu._generate_tasks(config, logger)

    # We will use the namespace below to share miscellaneous information between
    # functions.
    shared = Namespace()
    shared.feature_size = dhandlers[0].in_shape[0]

    # Plot images.
    if config.show_plots:
        figure_dir = os.path.join(config.out_dir, 'figures')
        if not os.path.exists(figure_dir):
            os.makedirs(figure_dir)

        for t, dh in enumerate(dhandlers):
            dh.plot_samples('Test Samples - Task %d' % t,
                dh.get_train_inputs()[:8], outputs=dh.get_train_outputs()[:8],
                show=True, filename=os.path.join(figure_dir,
                    'test_samples_task_%d.png' % t))

    target_net, hnet, dnet = stu.generate_networks(config, shared, dhandlers,
                                                   device)

    # generate masks if needed
    ctx_masks = None
    if config.use_masks:
        ctx_masks = stu.generate_binary_masks(config, device, target_net)

    # We store the target network weights (excluding potential context-mod
    # weights) after every task. In this way, we can quantify changes and
    # observe the "stiffness" of EWC.
    shared.tnet_weights = []
    # We store the context-mod weights (or all weights) coming from the hypernet
    # after every task, in order to quantify "forgetting". Note, the hnet
    # regularizer should keep them fixed.
    shared.hnet_out = []

    # Get the task-specific functions for loss and accuracy.
    task_loss_func = ctu.get_loss_func(config, device, logger, ewc_loss=False)
    accuracy_func = ctu.get_accuracy_func(config)
    ewc_loss_func = ctu.get_loss_func(config, device, logger, ewc_loss=True) \
        if config.use_ewc else None

    replay_fcts = None
    if config.use_replay:
        replay_fcts = dict()
        replay_fcts['rec_loss'] = ctu.get_vae_rec_loss_func()
        replay_fcts['distill_loss'] = ctu.get_distill_loss_func()
        replay_fcts['soft_trgt_acc'] = ctu.get_soft_trgt_acc_func()

    if config.multitask:
        summary_keywords = hpsearch_mt._SUMMARY_KEYWORDS
        summary_filename = hpsearch_mt._SUMMARY_FILENAME
    else:
        summary_keywords = hpsearch_cl._SUMMARY_KEYWORDS
        summary_filename = hpsearch_cl._SUMMARY_FILENAME

    ########################
    ### Train classifier ###
    ########################

    # Train the network task by task. Testing on all tasks is run after 
    # finishing training on each task.
    ret, train_loss, test_loss, test_acc = sts.train_tasks(dhandlers,
        target_net, hnet, dnet, device, config, shared, logger, writer,
        ctx_masks, summary_keywords, summary_filename,
        task_loss_func=task_loss_func, accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func, replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)

    writer.close()

    if ret == -1:
        logger.info('Program finished successfully.')

        if config.show_plots:
            plt.show()
    else:
        logger.error('Only %d tasks have completed training.' % (ret+1))
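The comments in this script describe measuring EWC's "stiffness" and hypernetwork "forgetting" from the snapshots collected in `shared.tnet_weights` and `shared.hnet_out`. Below is a minimal sketch of such a change metric; the helper name `weight_change_norm` and the plain Euclidean distance are illustrative assumptions, not part of this codebase.

import torch

def weight_change_norm(snapshot, current):
    # Hypothetical helper: Euclidean distance between two lists of equally
    # shaped weight tensors, e.g., an entry of `shared.hnet_out` and a
    # fresh hypernetwork output.
    assert len(snapshot) == len(current)
    sq_sum = 0.
    for w_old, w_new in zip(snapshot, current):
        sq_sum += torch.sum((w_new.detach() - w_old.detach())**2).item()
    return sq_sum ** 0.5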
Code example #2
def run(config, *args, dataset='copy'):
    """Run the script.

    Args:
        config: The default config for the current dataset.
        dataset (str, optional): The dataset being analysed.
    """

    fig_size = [2.8, 2.55]
    configure_matplotlib_params(fig_size=fig_size)

    ### Parse the command-line arguments.
    parser = argparse.ArgumentParser(description= \
        'Studying the dimensionality of the hidden space.')
    parser.add_argument('out_dir', type=str, nargs='?',
                        default='./out/hyperparam_search',
                        help='The output directory of the runs. ' +
                             'Default: %(default)s.')
    parser.add_argument('--redo_analyses', action='store_true',
                        help='If enabled, all analyses will be done even ' +
                             'if some previous results had been stored and ' +
                             'could have been loaded.')
    parser.add_argument('--do_kernel_pca', action='store_true',
                        help='If enabled, kernel PCA will also be used to ' +
                             'compute the number of hidden dimensions.')
    parser.add_argument('--do_supervised_dimred', action='store_true',
                        help='If enabled, supervised dimensionality reduction '+
                             'will be performed to compute the number of '+
                             'task-relevant hidden dimensions.')
    parser.add_argument('--p_var', type=float, default=0.75,
                        help='The amount of variance that needs to be ' +
                             'explained to determine the number of ' +
                             'dimensions. Default: %(default)s.')
    parser.add_argument('--n_pcs', type=int, default=10,
                        help='The number of principal components to be taken '+
                             'into account for the plots and analyses. ' +
                             'Default: %(default)s.')
    parser.add_argument('--n_samples', type=int, default=1000,
                        help='The number of samples to be used for the PCA ' +
                             'analyses in the different settings. This ' +
                             'ensures that comparisons when pulling data ' +
                             'from different tasks are made with an equal ' +
                             'amount of data. Default: %(default)s.')
    parser.add_argument('--timesteps_for_analysis', type=str, default='output',
                        choices=['all', 'input', 'output', 'stop', 'last', \
                                 'stop_plus_one'],
                        help='The timesteps to be used for the PCA analyses. '+
                             'Options are: all timesteps, only during input ' +
                             'presentation, only during output presentation, '+
                             'only upon stop flag, only one timestep after ' +
                             'the flag or only the very last output timestep.'+
                             ' Default: %(default)s.')
    parser.add_argument('--num_tasks', type=int, default=-1,
                        help='Number of tasks to consider. Default: ' +
                             '%(default)s.')
    parser.add_argument('--sup_dimred_criterion', type=int, default=-1,
                        help='Accuracy to be obtained when projecting the ' +
                             'hidden activity into a lower-dimensional ' +
                             'subspace in a supervised fashion. For a value '+
                             'of -1, no criterion is used and so all ' +
                             'possible numbers of dimensions are explored. '+
                             'Default: %(default)s.')
    parser.add_argument('--sdr_orth_strength', type=float, default=1e3,
                        help='The strength of the orthogonal regularizer ' +
                             'when doing supervised dimensionality reduction ' +
                             '(option "do_supervised_dimred"). ' +
                             'Default: %(default)s.')
    parser.add_argument('--sdr_lr', type=float, default=1e-2,
                        help='The learning rate when doing supervised ' +
                             'dimensionality reduction ' +
                             '(option "do_supervised_dimred"). ' +
                             'Default: %(default)s.')
    parser.add_argument('--sdr_batch_size', type=int, default=64,
                        help='The batch size when doing supervised ' +
                             'dimensionality reduction ' +
                             '(option "do_supervised_dimred"). ' +
                             'Default: %(default)s.')
    parser.add_argument('--sdr_n_iter', type=int, default=100,
                        help='The number of training iterations per ' +
                             'projection column when doing supervised ' +
                             'dimensionality reduction ' +
                             '(option "do_supervised_dimred"). ' +
                             'Default: %(default)s.')
    cmd_args = parser.parse_args()

    if (dataset == 'audioset' and cmd_args.timesteps_for_analysis not in \
            ['all', 'last']) or (dataset == 'seq_smnist' and \
            cmd_args.timesteps_for_analysis != 'all') \
            or (dataset == 'student_teacher' \
            and cmd_args.timesteps_for_analysis != 'all'):
        warnings.warn('A subset of timesteps for the analysis can only be ' +
            'selected for the Copy Task currently. Using all timesteps.')
        # Note that for Sequential SMNIST, the last timestep needs to be 
        # correctly selected, which is not currently implemented.
        setattr(cmd_args, 'timesteps_for_analysis', 'all')

    copy_task = False
    if dataset == 'copy':
        copy_task = True

    # Define directory where to store the results of all current analyses.
    results_dir = os.path.join(cmd_args.out_dir, 'pca_analyses')
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
    # Overwrite the output directory.
    config.out_dir = os.path.join(results_dir, 'sim_files')

    ### Set up environment using some general command line arguments.
    device, writer, logger = sutils.setup_environment(config)

    ### Check if the current directory corresponds to a single run or not.
    # Folder names have changed between versions; allow both old and new
    # naming schemes.
    out_dirs = glob.glob(os.path.join(cmd_args.out_dir, '20*'))
    out_dirs.extend(glob.glob(os.path.join(cmd_args.out_dir, 'sim*')))
    if len(out_dirs) == 0:
        out_dirs = [cmd_args.out_dir]

    ### Store the results of the different runs in the same dictionary.
    results = {}
    for i, out_dir in enumerate(out_dirs):
        sdr_args = {
            'lambda_ortho': cmd_args.sdr_orth_strength,
            'lr': cmd_args.sdr_lr,
            'n_iter': cmd_args.sdr_n_iter,
            'batch_size': cmd_args.sdr_batch_size
        }

        results[out_dir], settings = analyse_single_run(out_dir, device,
            writer, logger, *args, redo_analyses=cmd_args.redo_analyses,
            n_samples=cmd_args.n_samples, do_kernel_pca=cmd_args.do_kernel_pca,
            timesteps_for_analysis=cmd_args.timesteps_for_analysis,
            copy_task=copy_task, num_tasks=cmd_args.num_tasks,
            do_supervised_dimred=cmd_args.do_supervised_dimred,
            sup_dimred_criterion=cmd_args.sup_dimred_criterion,
            sup_dimred_args=sdr_args)

        # Ensure all runs have comparable properties
        if i == 0:
            common_settings = settings.copy()
        for key in settings.keys():
            assert settings[key] == common_settings[key]
    num_tasks = common_settings['num_tasks']
    if cmd_args.num_tasks != -1:
        num_tasks = cmd_args.num_tasks
        assert num_tasks <= common_settings['num_tasks']

    ### Check if there are identical runs with different random seeds.
    seed_groups, num_seeds, num_runs = group_runs(results)
    print('\nThe analysis was done with %i seeds '%num_seeds + 
        'for each of the %i different runs.'%num_runs)

    ### Pickle the results.
    with open(os.path.join(results_dir, 'results.pickle'), 'wb') as handle:
        pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(results_dir, 'seed_groups.pickle'), 'wb') as handle:
        pickle.dump(seed_groups, handle, protocol=pickle.HIGHEST_PROTOCOL)

    ### Plot.
    for task_id in range(num_tasks):
        if num_runs > 1:
            if copy_task:
                ssp.plot_per_ts(results, seed_groups, task_id=task_id,
                    path=results_dir, key_name='accuracy')
            ssp.plot_per_ts(results, seed_groups, task_id=task_id,
                path=results_dir, key_name='dimension')
            ssp.plot_per_ts(results, seed_groups, task_id=task_id,
                path=results_dir, key_name='dimension', internal=False)
            if cmd_args.do_kernel_pca:
                ssp.plot_per_ts(results, seed_groups, key_name='dimension',
                    task_id=task_id, path=results_dir, kernel=True)
                ssp.plot_per_ts(results, seed_groups, key_name='dimension',
                    task_id=task_id, path=results_dir, kernel=True,
                    internal=False)

    if copy_task:
        if settings['permute_time'] and not \
                settings['permute_width'] and not settings['permute_xor']:
            ssp.plot_accuracy_vs_bptt_steps(results, seed_groups,
                task_id=task_id, path=results_dir)

    # Make multi-task plots.
    if num_tasks > 1:
        ssp.plot_importance_vs_task(results, seed_groups, path=results_dir)
        ssp.plot_dimension_vs_task(results, seed_groups, path=results_dir)
        ssp.plot_dimension_vs_task(results, seed_groups, path=results_dir,
            internal=False)
        ssp.plot_dimension_vs_task(results, seed_groups, path=results_dir,
            onto_task_1=True)
        ssp.print_dimension_vs_task(results, seed_groups, p_var=cmd_args.p_var)
        if num_runs == 1:
            ssp.plot_per_ts(results, seed_groups, path=results_dir,
                key_name='dimension')
            ssp.plot_per_ts(results, seed_groups, path=results_dir,
                key_name='dimension', internal=False)
            if cmd_args.do_kernel_pca:
                ssp.plot_per_ts(results, seed_groups, key_name='dimension',
                    path=results_dir, kernel=True)
            if copy_task:
                ssp.plot_per_ts(results, seed_groups, path=results_dir,
                    key_name='accuracy') 
    if cmd_args.do_supervised_dimred:
        ssp.plot_supervised_dimension_vs_task(results, seed_groups,
            path=results_dir)
        ssp.plot_supervised_dimension_vs_task(results, seed_groups,
            path=results_dir, key='accu')
        ssp.plot_supervised_dimension_vs_task(results, seed_groups,
            path=results_dir, key='accu', stop_bit=True)

    # Make multi-run plots.
    if num_runs > 1:
        ssp.plot_across_runs(results, seed_groups, path=results_dir,
            do_kernel_pca=cmd_args.do_kernel_pca, n_pcs=cmd_args.n_pcs,
            p_var=cmd_args.p_var)
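The `--p_var` option above determines dimensionality by thresholding the cumulative explained variance. The following self-contained sketch shows this standard criterion; it is illustrative and not the implementation in the repository's `pca_utils` module.

import numpy as np

def num_dims_for_variance(X, p_var=0.75):
    # Number of principal components needed to explain a fraction `p_var`
    # of the variance of `X`, a [num_samples, num_features] array of
    # hidden states.
    Xc = X - X.mean(axis=0)
    s = np.linalg.svd(Xc, compute_uv=False)
    cum_ratio = np.cumsum(s**2 / np.sum(s**2))
    # First index at which the cumulative ratio reaches `p_var` (+1 turns
    # the index into a count; the clip guards against rounding near 1).
    return int(min(np.searchsorted(cum_ratio, p_var) + 1, len(cum_ratio)))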
Code example #3
def run(config, experiment='resnet'):
    """Run the training.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - ``resnet``: CIFAR-10/100 with Resnet-32.
            - ``zenke``: CIFAR-10/100 with Zenkenet.
    """
    assert experiment in ['resnet', 'zenke']

    script_start = time()

    device, writer, logger = sutils.setup_environment(
        config, logger_name='det_cl_cifar_%s' % experiment)
    # TODO Adapt script to allow checkpointing of models using
    # `utils.torch_ckpts` (i.e., we should be able to continue training or just
    # test an existing checkpoint).
    #config.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')

    # Container for variables shared across functions.
    shared = Namespace()
    shared.experiment = experiment

    ### Load datasets (i.e., create tasks).
    dhandlers = tutils.load_datasets(config,
                                     shared,
                                     logger,
                                     data_dir='../datasets')
    print('handlers training: ', dhandlers[0].num_train_samples)

    # data,taskcla,inputsize=mixceleba.get(seed=config.seed,args=config)
    data, taskcla, inputsize = mixemnist.get(seed=config.seed, args=config)

    print('Input size =', inputsize, '\nTask info =', taskcla)

    ### Create main network.
    # TODO Allow main net only training.
    mnet = tutils.get_main_model(config,
                                 shared,
                                 logger,
                                 device,
                                 no_weights=not config.mnet_only)

    ### Create the hypernetwork.
    if config.mnet_only:
        hnet = None
    else:
        hnet = tutils.get_hnet_model(config, mnet, logger, device)

    ### Initialize the performance measures that should be tracked during
    ### training.
    tutils.setup_summary_dict(config, shared, mnet, hnet=hnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    writer.add_hparams(hparam_dict={
        **vars(config),
        **{
            'num_weights_main': shared.summary['num_weights_main'],
            'num_weights_hyper': shared.summary['num_weights_hyper'],
            'num_weights_ratio': shared.summary['num_weights_ratio'],
        }
    },
                       metric_dict={})

    # FIXME: Method "calc_fix_target_reg" expects a None value.
    # But `writer.add_hparams` can't deal with `None` values.
    if config.cl_reg_batch_size == -1:
        config.cl_reg_batch_size = None

    # We keep the hnet output right after training to measure forgetting.
    weights_after_training = []

    ######################
    ### Start Training ###
    ######################
    for j in range(config.num_tasks):
        logger.info('Starting training of task %d ...' % (j + 1))

        data = dhandlers[j]

        # It might be that tasks are very similar and we can transfer knowledge
        # from the previous solution.
        if hnet is not None and config.init_with_prev_emb and j > 0:
            last_emb = hnet.get_task_emb(j - 1).detach().clone()
            hnet.get_task_emb(j).data = last_emb

        # Training from scratch -- create new network instance!
        # -> No transfer possible.
        if j > 0 and config.train_from_scratch:
            # FIXME Since we simply overwrite the current network, future
            # testing on this new network for old tasks doesn't make sense. So
            # we shouldn't report `final` accuracies.
            if config.mnet_only:
                logger.info(
                    'From scratch training: Creating new main network.')
                mnet = tutils.get_main_model(config,
                                             shared,
                                             logger,
                                             device,
                                             no_weights=not config.mnet_only)
            else:
                logger.info(
                    'From scratch training: Creating new hypernetwork.')
                hnet = tutils.get_hnet_model(config, mnet, logger, device)

        ################################
        ### Train and test on task j ###
        ################################
        train(j, data, mnet, hnet, device, config, shared, writer, logger)

        ### Final test run.
        if hnet is not None:
            weights = hnet.forward(j)
            # Push to CPU to avoid growing GPU memory when solving very long
            # task sequences.
            weights = [w.detach().clone().cpu() for w in weights]
            weights_after_training.append(weights)

        test_acc, _ = test(j, data, mnet, hnet, device, shared, config, writer,
                           logger)

        logger.info('### Accuracy of task %d / %d:  %.3f' % \
                    (j+1, config.num_tasks, test_acc))
        logger.info('### Finished training task: %d' % (j + 1))
        shared.summary['acc_during'][j] = test_acc

        # Backup results so far.
        tutils.save_summary_dict(config, shared, experiment)

    shared.summary['acc_avg_during'] = np.mean(shared.summary['acc_during'])

    logger.info('### Accuracy of individual tasks after training: %s' % \
                (str(shared.summary['acc_during'])))
    logger.info('### Average of these accuracies: %.2f' % \
                (shared.summary['acc_avg_during']))
    writer.add_scalar('final/during_acc_avg', shared.summary['acc_avg_during'])

    #########################################
    ### Test continual learning scenarios ###
    #########################################
    test_multiple(dhandlers, mnet, hnet, device, config, shared, writer,
                  logger)

    #########################
    ### Run some analysis ###
    #########################
    if not config.mnet_only:
        analysis(dhandlers, mnet, hnet, device, config, shared, writer, logger,
                 weights_after_training)

    ### Write final summary.
    shared.summary['finished'] = 1
    tutils.save_summary_dict(config, shared, experiment)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - script_start))
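The FIXME above works around `writer.add_hparams` rejecting `None` values by mutating the config in place. A more general (hypothetical) approach is to sanitize the hparam dictionary before logging; the helper below is an assumption for illustration, not part of this codebase.

def sanitize_hparams(hparam_dict, sentinel=-1):
    # TensorBoard's `add_hparams` only accepts int, float, str, bool or
    # tensor values, so map `None` entries to a sentinel.
    return {k: (sentinel if v is None else v) for k, v in hparam_dict.items()}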
Code example #4
def run(config, experiment='regression_ewc'):
    """Run the EWC training for the given experiment.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_ewc'``: Regression tasks with EWC
            - ``'gmm_ewc'``: GMM Data with EWC
            - ``'split_mnist_ewc'``: SplitMNIST with EWC
            - ``'perm_mnist_ewc'``: PermutedMNIST with EWC
            - ``'cifar_resnet_ewc'``: CIFAR-10/100 with EWC
    """
    assert experiment in ['regression_ewc', 'gmm_ewc', 'split_mnist_ewc',
                          'perm_mnist_ewc', 'cifar_resnet_ewc']

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
        logger_name=experiment)

    rutils.backup_cli_command(config)

    is_classification = True
    if 'regression' in experiment:
        is_classification = False

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    ### Generate network.
    mnet, _, _, _ = pcutils.generate_networks(config, shared, logger, dhandlers,
        device, create_mnet=True, create_hnet=False, create_hhnet=False,
        create_dis=False)

    if not is_classification:
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after most recent call of test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones).
    shared.during_weights = [-1] * config.num_tasks

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the softmax temperature per-task with one. Might be changed
    # later on to calibrate the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device) \
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
    else:
        rutils.setup_summary_dict(config, shared, 'ewc', config.num_tasks, mnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    writer.add_hparams(hparam_dict={**vars(config), **{
        'num_weights_main': shared.summary['num_weights_main'] \
            if is_classification else shared.summary['aa_num_weights_main']
    }}, metric_dict={})

    if is_classification:
        during_acc_criterion = pmutils.parse_performance_criterion(config,
            shared, logger)

    ### Train on tasks sequentially.
    for i in range(config.num_tasks):
        logger.info('### Training on task %d ###' % (i+1))
        data = dhandlers[i]

        # Train the network.
        shared.num_trained += 1
        train(i, data, mnet, device, config, shared, logger, writer)

        ### Test networks.
        test_ids = None
        if hasattr(config, 'full_test_interval') and \
                config.full_test_interval != -1:
            if i == config.num_tasks-1 or \
                    (i > 0 and i % config.full_test_interval == 0):
                test_ids = None # Test on all tasks.
            else:
                test_ids = [i] # Only test on current task.
        test(dhandlers[:(i+1)], mnet, device, config, shared, logger, writer,
             test_ids=test_ids)

        ### Check if last task got "acceptable" accuracy ###
        if is_classification:
            curr_dur_acc = shared.summary['acc_task_given_during'][i]
        if is_classification and i < config.num_tasks-1 \
                and during_acc_criterion[i] != -1 \
                and during_acc_criterion[i] > curr_dur_acc:
            logger.error('During accuracy of task %d too small (%f < %f).' % \
                         (i+1, curr_dur_acc, during_acc_criterion[i]))
            logger.error('Training of future tasks will be skipped.')
            writer.close()
            exit(1)

        if config.train_from_scratch and i < config.num_tasks-1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            pmutils.checkpoint_nets(config, shared, i, mnet, None)

            mnet, _, _, _ = pcutils.generate_networks(config, shared, logger,
                dhandlers, device, create_mnet=True, create_hnet=False,
                create_hhnet=False, create_dis=False)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks-1, mnet, None)

    ### Plot final classification scores.
    if is_classification:
        logger.info('During accuracies (task identity given): ' + \
                    '%s (avg: %.2f%%).' % \
            (np.array2string(np.array(shared.summary['acc_task_given_during']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_given_during']))
        logger.info('Final accuracies (task identity given): ' + \
                    '%s (avg: %.2f%%).' % \
            (np.array2string(np.array(shared.summary['acc_task_given']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_given']))

    if is_classification and config.cl_scenario == 3 and config.split_head_cl3:
        logger.info('During accuracies (task identity inferred): ' +
                    '%s (avg: %.2f%%).' % \
            (np.array2string(np.array( \
                shared.summary['acc_task_inferred_ent_during']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_inferred_ent_during']))
        logger.info('Final accuracies (task identity inferred): ' +
                    '%s (avg: %.2f%%).' % \
            (np.array2string(np.array(shared.summary['acc_task_inferred_ent']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_inferred_ent']))

    ### Plot final regression scores.
    if not is_classification:
        logger.info('During MSE values after training each task: %s' % \
            np.array2string(np.array(shared.summary['aa_mse_during']),
                            precision=5, separator=','))
        logger.info('Final MSE values after training on all tasks: %s' % \
            np.array2string(np.array(shared.summary['aa_mse_final']),
                            precision=5, separator=','))
        logger.info('Final MSE mean %.4f.' % \
                    (shared.summary['aa_mse_during_mean']))

    ### Write final summary.
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))
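For context, the regularizer behind EWC training scripts like this one is the quadratic penalty lambda/2 * sum_i F_i * (theta_i - theta_i*)^2 from Kirkpatrick et al. (2017). The sketch below spells out that term; it is illustrative and not the `utils.ewc_regularizer` module imported elsewhere in these examples.

import torch

def ewc_penalty(params, params_star, fisher, ewc_lambda=1.):
    # `params` are the current parameters, `params_star` those checkpointed
    # after the previous task and `fisher` the diagonal Fisher estimates
    # (all lists of equally shaped tensors).
    penalty = 0.
    for p, p_star, f in zip(params, params_star, fisher):
        penalty = penalty + (f * (p - p_star.detach())**2).sum()
    return 0.5 * ewc_lambda * penalty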
Code example #5
def run():
    """Run the script.

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    script_start = time()

    mode = 'regression_bbb'
    config = train_args.parse_cmd_arguments(mode=mode)

    device, writer, logger = sutils.setup_environment(config, logger_name=mode)

    train_utils.backup_cli_command(config)

    ### Create tasks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    ### Generate networks.
    use_hnet = not config.mnet_only
    mnet, hnet = train_utils.generate_gauss_networks(
        config,
        logger,
        dhandlers,
        device,
        create_hnet=use_hnet,
        non_gaussian=config.mean_only)

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers
    # Mean and variance of prior that is used for variational inference.
    if config.mean_only:  # No prior-matching can be performed.
        shared.prior_mean = None
        shared.prior_logvar = None
        shared.prior_std = None
    else:
        plogvar = np.log(config.prior_variance)
        pstd = np.sqrt(config.prior_variance)
        shared.prior_mean = [torch.zeros(*s).to(device) \
                             for s in mnet.orig_param_shapes]
        shared.prior_logvar = [plogvar * torch.ones(*s).to(device) \
                               for s in mnet.orig_param_shapes]
        shared.prior_std = [pstd * torch.ones(*s).to(device) \
                            for s in mnet.orig_param_shapes]

    # Note, all MSE values are measured on a validation set if given, otherwise
    # on the training set. All samples in the validation set are expected to
    # lie inside the training range. Test samples may lie outside the training
    # range.
    # The MSE value achieved right after training on the corresponding task.
    shared.during_mse = np.ones(num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones). Note, weights refer to mean and variances (e.g., the output of the
    # hypernetwork).
    shared.during_weights = [-1] * num_tasks
    # MSE achieved after most recent call of test method.
    shared.current_mse = np.ones(num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some main networks have stuff to store such as batch statistics for
    # batch norm. So it is wise to always checkpoint mnets as well!
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')

    ### Initialize the performance measures that should be tracked during
    ### training.
    train_utils.setup_summary_dict(config,
                                   shared,
                                   'bbb',
                                   num_tasks,
                                   mnet,
                                   hnet=hnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    writer.add_hparams(hparam_dict={
        **vars(config),
        **{
            'num_weights_main': shared.summary['aa_num_weights_main'],
            'num_weights_hyper': shared.summary['aa_num_weights_hyper'],
            'num_weights_ratio': shared.summary['aa_num_weights_ratio'],
        }
    },
                       metric_dict={})

    ### Train on tasks sequentially.
    for i in range(num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]
        # Train the network.
        train(i, data, mnet, hnet, device, config, shared, logger, writer)

        ### Test networks.
        test(dhandlers[:(i + 1)], mnet, hnet, device, config, shared, logger,
             writer)

        if config.train_from_scratch and i < num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            pmutils.checkpoint_nets(config, shared, i, mnet, hnet)

            mnet, hnet = train_utils.generate_gauss_networks(
                config,
                logger,
                dhandlers,
                device,
                create_hnet=use_hnet,
                non_gaussian=config.mean_only)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, num_tasks - 1, mnet, hnet)

    logger.info('During MSE values after training each task: %s' % \
          np.array2string(shared.during_mse, precision=5, separator=','))
    logger.info('Final MSE values after training on all tasks: %s' % \
          np.array2string(shared.current_mse, precision=5, separator=','))
    logger.info('Final MSE mean %.4f (std %.4f).' %
                (shared.current_mse.mean(), shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - script_start))

    return shared.current_mse, shared.during_mse
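The prior mean and log-variance stored in `shared` above enter the prior-matching term of variational inference. For diagonal Gaussians this KL divergence has a closed form, sketched below per parameter tensor; this is a reference implementation of the formula only, not the repository's training code.

import torch

def gauss_kl(mu_q, logvar_q, mu_p, logvar_p):
    # KL( N(mu_q, var_q) || N(mu_p, var_p) ) for diagonal Gaussians,
    # summed over all dimensions.
    kl = 0.5 * (logvar_p - logvar_q - 1.
                + (torch.exp(logvar_q) + (mu_q - mu_p)**2)
                / torch.exp(logvar_p))
    return kl.sum()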
Code example #6
    parser.add_argument('out_dir',
                        type=str,
                        help='The output directory of the simulation to be ' +
                        'analyzed.')
    args = parser.parse_args()
    out_dir = args.out_dir

    # Temporary simulation directory, required by method `setup_environment`.
    args.out_dir = os.path.join(
        tempfile.gettempdir(),
        'tmp_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
    args.loglevel_info = False
    args.random_seed = 42  # Note, this script doesn't perform random computation
    args.deterministic_run = False
    args.no_cuda = True
    device, writer, logger = sutils.setup_environment(args)

    # FIXME Code below copied from script `state_space_analysis`.
    # Load the config
    if not os.path.exists(out_dir):
        raise ValueError('The directory "%s" does not exist.' % out_dir)
    with open(os.path.join(out_dir, "config.pickle"), "rb") as f:
        config = pickle.load(f)
    # Overwrite the directory if it's not the same as the original.
    if config.out_dir != out_dir:
        config.out_dir = out_dir
    # Check for old command-line arguments and make them compatible with the
    # new version.
    config = sta.update_cli_args(config)

    # FIXME only for copy task!
    generate_tasks_func = copytu.generate_copy_tasks
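`sta.update_cli_args` keeps old result folders loadable by patching arguments that newer script versions expect onto the unpickled config. A hypothetical sketch of that compatibility pattern follows; the argument name and default shown are assumptions, not the actual function.

def update_cli_args(config, new_defaults=None):
    # Illustrative shim: add arguments that old runs never stored.
    if new_defaults is None:
        new_defaults = {'num_tasks': 1}  # Hypothetical example default.
    for name, default in new_defaults.items():
        if not hasattr(config, name):
            setattr(config, name, default)
    return config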
Code example #7
import matplotlib.pyplot as plt
import numpy as np
import torch

from data.timeseries.rnd_rec_teacher import RndRecTeacher
from mnets.simple_rnn import SimpleRNN
import sequential.ht_analyses.state_space_analysis as ssa
import sequential.student_teacher.train_args_st as sta
import sequential.student_teacher.train_utils_st as stu
from sequential.ht_analyses import pca_utils
import utils.ewc_regularizer as ewc
import utils.sim_utils as sutils

if __name__ == '__main__':
    config = sta.parse_cmd_arguments(mode='student_teacher')
    device, writer, logger = sutils.setup_environment(config)
    writer.close()

    num_samples = 1000
    ### Construct datasets ###
    scenario = 2
    if scenario == 0:  # All tasks identical -> hidden dim should remain constant
        d1 = RndRecTeacher(num_test=num_samples, rseed=1)
        d2 = RndRecTeacher(num_test=num_samples, rseed=1)
        d3 = RndRecTeacher(num_test=num_samples, rseed=1)
        dhandlers = [d1, d2, d3]
    elif scenario == 1:  # All tasks different -> hidden dim should increase
        d1 = RndRecTeacher(num_test=num_samples, rseed=1)
        d2 = RndRecTeacher(num_test=num_samples, rseed=2)
        d3 = RndRecTeacher(num_test=num_samples, rseed=3)
        dhandlers = [d1, d2, d3]
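The expectation stated in the comments above (identical tasks leave the hidden dimensionality constant, distinct tasks increase it) can be checked with a variance-explained estimate such as the `num_dims_for_variance` sketch given after Code example #2. The hypothetical helper below assumes that function is available; `numpy` is already imported as `np` in this script.

def pooled_dim_increase(hidden_states_per_task, p_var=0.75):
    # `hidden_states_per_task`: list of [num_samples, num_hidden] arrays.
    # Compares per-task dimensionality estimates with the dimensionality
    # of all tasks pooled together.
    per_task = [num_dims_for_variance(H, p_var)
                for H in hidden_states_per_task]
    pooled = num_dims_for_variance(np.vstack(hidden_states_per_task), p_var)
    return per_task, pooled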
Code example #8
def run(method='avb'):
    """Run the script.

    Args:
        method (str, optional): The VI algorithm. Possible values are:

            - ``'avb'``
            - ``'ssge'``

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    script_start = time()
    mode = 'regression_' + method
    use_dis = False  # whether a discriminator network is used
    if method == 'avb':
        use_dis = True
    config = train_args.parse_cmd_arguments(mode=mode)

    device, writer, logger = sutils.setup_environment(config, logger_name=mode)

    train_utils.backup_cli_command(config)

    if config.prior_focused:
        logger.info('Running a prior-focused CL experiment ...')
    else:
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers
    shared.prior_focused = config.prior_focused

    ### Generate networks and environment
    mnet, hnet, hhnet, dnet = pcu.generate_networks(config,
                                                    shared,
                                                    logger,
                                                    dhandlers,
                                                    device,
                                                    create_dis=use_dis)

    # Mean and variance of prior that is used for variational inference.
    # For a prior-focused training, this prior will only be used for the
    # first task.
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device) \
                         for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device) \
                        for s in mnet.param_shapes]

    # Note, all MSE values are measured on a validation set if given, otherwise
    # on the training set. All samples in the validation set are expected to
    # lie inside the training range. Test samples may lie outside the training
    # range.
    # The MSE value achieved right after training on the corresponding task.
    shared.during_mse = np.ones(num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones). Note, weights refer to mean and variances (e.g., the output of the
    # hypernetwork).
    shared.during_weights = [-1] * num_tasks
    # MSE achieved after most recent call of test method.
    shared.current_mse = np.ones(num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have stuff to store such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if they
    # were constructed without weights.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')
    #shared.ckpt_dis_fn = os.path.join(shared.ckpt_dir, 'dis_task_%d')

    ### Initialize the performance measures that should be tracked during
    ### training.
    train_utils.setup_summary_dict(config,
                                   shared,
                                   method,
                                   num_tasks,
                                   mnet,
                                   hnet=hnet,
                                   hhnet=hhnet,
                                   dis=dnet)
    logger.info('Ratio num hnet weights / num mnet weights: %f.' %
                shared.summary['aa_num_weights_hm_ratio'])
    if hhnet is not None:
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.' %
                    shared.summary['aa_num_weights_hhm_ratio'])
    if mode == 'regression_avb' and dnet is not None:
        # A discriminator only exists for AVB.
        logger.info('Ratio num dis weights / num mnet weights: %f.' %
                    shared.summary['aa_num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['aa_num_weights_hm_ratio'],
        'num_weights_hhm_ratio': shared.summary['aa_num_weights_hhm_ratio']
    }
    if mode == 'regression_avb':
        hparams_extra_dict['num_weights_dm_ratio'] = \
            shared.summary['aa_num_weights_dm_ratio']
    writer.add_hparams(hparam_dict={
        **vars(config),
        **hparams_extra_dict
    },
                       metric_dict={})

    ### Train on tasks sequentially.
    for i in range(num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]

        # Train the network.
        tvi.train(i,
                  data,
                  mnet,
                  hnet,
                  hhnet,
                  dnet,
                  device,
                  config,
                  shared,
                  logger,
                  writer,
                  method=method)

        # Test networks.
        train_bbb.test(dhandlers[:(i + 1)],
                       mnet,
                       hnet,
                       device,
                       config,
                       shared,
                       logger,
                       writer,
                       hhnet=hhnet)

        if config.train_from_scratch and i < num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            # Note, we only need the discriminator as helper for training,
            # so we don't checkpoint it.
            pmu.checkpoint_nets(config,
                                shared,
                                i,
                                mnet,
                                hnet,
                                hhnet=hhnet,
                                dis=None)

            mnet, hnet, hhnet, dnet = pcu.generate_networks(
                config, shared, logger, dhandlers, device)

        elif config.store_during_models:
            logger.info('Checkpointing current model ...')
            pmu.checkpoint_nets(config,
                                shared,
                                i,
                                mnet,
                                hnet,
                                hhnet=hhnet,
                                dis=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmu.checkpoint_nets(config,
                            shared,
                            num_tasks - 1,
                            mnet,
                            hnet,
                            hhnet=hhnet,
                            dis=None)

    logger.info('During MSE values after training each task: %s' % \
                np.array2string(shared.during_mse, precision=5, separator=','))
    logger.info('Final MSE values after training on all tasks: %s' % \
                np.array2string(shared.current_mse, precision=5, separator=','))
    logger.info('Final MSE mean %.4f (std %.4f).' %
                (shared.current_mse.mean(), shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - script_start))

    return shared.current_mse, shared.during_mse
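Since `run` returns per-task final and during-training MSE arrays, forgetting can be read off directly by the caller. A hypothetical usage sketch (assuming the entry point above is importable):

final_mse, during_mse = run(method='avb')
# Positive entries mark tasks whose performance degraded after training
# on subsequent tasks.
forgetting = final_mse - during_mse
print('Avg. forgetting (MSE increase): %f' % forgetting.mean())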
Code example #9
def run(config, experiment='split_mnist_avb'):
    """Run the training.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - "gmm_avb": Synthetic GMM data with Posterior Replay via AVB
            - "gmm_avb_pf": Synthetic GMM data with Prior-Focused CL via AVB
            - "split_mnist_avb": Split MNIST with Posterior Replay via AVB
            - "perm_mnist_avb": Permuted MNIST with Posterior Replay via AVB
            - "split_mnist_avb_pf": Split MNIST with Prior-Focused CL via AVB
            - "perm_mnist_avb_pf": Permuted MNIST with Prior-Focused CL via AVB
            - "cifar_zenke_avb": CIFAR-10/100 with Posterior Replay using a
              ZenkeNet and AVB
            - "cifar_resnet_avb": CIFAR-10/100 with Posterior Replay using a
              Resnet and AVB
            - "cifar_zenke_avb_pf": CIFAR-10/100 with Prior-Focused CL using a
              ZenkeNet and AVB
            - "cifar_resnet_avb_pf": CIFAR-10/100 with Prior-Focused CL using a
              Resnet and AVB
            - "gmm_ssge": Synthetic GMM data with Posterior Replay via SSGE
            - "gmm_ssge_pf": Synthetic GMM data with Prior-Focused CL via SSGE
            - "split_mnist_ssge": Split MNIST with Posterior Replay via SSGE
            - "perm_mnist_ssge": Permuted MNIST with Posterior Replay via SSGE
            - "split_mnist_ssge_pf": Split MNIST with Prior-Focused CL via SSGE
            - "perm_mnist_ssge_pf": Permuted MNIST with Prior-Focused CL via
              SSGE
            - "cifar_resnet_ssge": CIFAR-10/100 with Posterior Replay using a
              Resnet and SSGE
            - "cifar_resnet_ssge_pf": CIFAR-10/100 with Prior-Focused CL using a
              Resnet and SSGE
    """
    assert experiment in [
        'gmm_avb', 'gmm_avb_pf', 'split_mnist_avb', 'split_mnist_avb_pf',
        'perm_mnist_avb', 'perm_mnist_avb_pf', 'cifar_zenke_avb',
        'cifar_zenke_avb_pf', 'cifar_resnet_avb', 'cifar_resnet_avb_pf',
        'gmm_ssge', 'gmm_ssge_pf', 'split_mnist_ssge', 'split_mnist_ssge_pf',
        'perm_mnist_ssge', 'perm_mnist_ssge_pf', 'cifar_resnet_ssge',
        'cifar_resnet_ssge_pf'
    ]

    script_start = time()

    if 'avb' in experiment:
        method = 'avb'
        use_dis = True  # whether a discriminator network is used
    elif 'ssge' in experiment:
        method = 'ssge'
        use_dis = False

    device, writer, logger = sutils.setup_environment(config,
                                                      logger_name=experiment +
                                                      'logger')

    rutils.backup_cli_command(config)

    if experiment.endswith('pf'):
        prior_focused_cl = True
        logger.info('Running a prior-focused CL experiment ...')
    else:
        prior_focused_cl = False
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers = pmutils.load_datasets(config, logger, experiment, writer)

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.prior_focused = prior_focused_cl

    ### Generate networks.
    mnet, hnet, hhnet, dis = pcutils.generate_networks(config,
                                                       shared,
                                                       logger,
                                                       dhandlers,
                                                       device,
                                                       create_dis=use_dis)
    if method == 'ssge':
        assert dis is None

    ### Add more information to shared.
    # Mean and variance of prior that is used for variational inference.
    # For a prior-focused training, this prior will only be used for the
    # first task.
    #plogvar = np.log(config.prior_variance)
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device) \
                         for s in mnet.param_shapes]
    #shared.prior_logvar = [plogvar * torch.ones(*s).to(device) \
    #                       for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device) \
                        for s in mnet.param_shapes]

    # The output weights of the hyper-hyper network right after training on
    # a task (can be used to assess how close the final weights are to the
    # original ones).
    shared.during_weights = [-1] * config.num_tasks if hhnet is not None \
        else None

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have stuff to store such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if they
    # were constructed without weights.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')
    #shared.ckpt_dis_fn = os.path.join(shared.ckpt_dir, 'dis_task_%d')

    # Initialize the softmax temperature per-task with one. Might be changed
    # later on to calibrate the temperature.
    shared.softmax_temp = [torch.ones(1).to(device) \
                           for _ in range(config.num_tasks)]
    shared.num_trained = 0

    # Set up coresets iff regularization on all tasks is allowed.
    if config.coreset_size != -1 and config.past_and_future_coresets:
        for i in range(config.num_tasks):
            pmutils.update_coreset(config,
                                   shared,
                                   i,
                                   dhandlers[i],
                                   None,
                                   None,
                                   device,
                                   logger,
                                   None,
                                   hhnet=None,
                                   method='avb')

    ### Initialize summary.
    pcutils.setup_summary_dict(config,
                               shared,
                               experiment,
                               mnet,
                               hnet=hnet,
                               hhnet=hhnet,
                               dis=dis)
    logger.info('Ratio num hnet weights / num mnet weights: %f.' %
                shared.summary['num_weights_hm_ratio'])
    if 'num_weights_hhm_ratio' in shared.summary.keys():
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.' %
                    shared.summary['num_weights_hhm_ratio'])
    if method == 'avb':
        logger.info('Ratio num dis weights / num mnet weights: %f.' %
                    shared.summary['num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['num_weights_hm_ratio'],
    }
    if 'num_weights_dm_ratio' in shared.summary.keys():
        hparams_extra_dict = {**hparams_extra_dict,
            **{'num_weights_dm_ratio': \
               shared.summary['num_weights_dm_ratio']}}
    if 'num_weights_hhm_ratio' in shared.summary.keys():
        hparams_extra_dict = {**hparams_extra_dict,
            **{'num_weights_hhm_ratio': \
               shared.summary['num_weights_hhm_ratio']}}
    writer.add_hparams(hparam_dict={
        **vars(config),
        **hparams_extra_dict
    },
                       metric_dict={})

    during_acc_criterion = pmutils.parse_performance_criterion(
        config, shared, logger)

    ### Train on tasks sequentially.
    for i in range(config.num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]

        # Train the network.
        shared.num_trained += 1
        if config.distill_iter == -1:
            tvi.train(i,
                      data,
                      mnet,
                      hnet,
                      hhnet,
                      dis,
                      device,
                      config,
                      shared,
                      logger,
                      writer,
                      method=method)
        else:
            assert hhnet is not None
            # Train main network only.
            tvi.train(i,
                      data,
                      mnet,
                      hnet,
                      None,
                      dis,
                      device,
                      config,
                      shared,
                      logger,
                      writer,
                      method=method)

            # Distill `hnet` into `hhnet`.
            train_bbb.distill_net(i, data, mnet, hnet, hhnet, device, config,
                                  shared, logger, writer)

            # Create a new main network before training the next task.
            mnet, hnet, _, _ = pcutils.generate_networks(config,
                                                         shared,
                                                         logger,
                                                         dhandlers,
                                                         device,
                                                         create_dis=False,
                                                         create_hhnet=False)

        ### Temperature Calibration.
        if config.calibrate_temp:
            pcutils.calibrate_temperature(i, data, mnet, hnet, hhnet, device,
                                          config, shared, logger, writer)

        ### Test networks.
        test_ids = None
        if config.full_test_interval != -1:
            if i == config.num_tasks-1 or \
                    (i > 0 and i % config.full_test_interval == 0):
                test_ids = None  # Test on all tasks.
            else:
                test_ids = [i]  # Only test on current task.
        tvi.test(dhandlers[:(i + 1)],
                 mnet,
                 hnet,
                 hhnet,
                 device,
                 config,
                 shared,
                 logger,
                 writer,
                 test_ids=test_ids,
                 method=method)

        ### Check if last task got "acceptable" accuracy ###
        curr_dur_acc = shared.summary['acc_task_given_during'][i]
        if i < config.num_tasks-1 and during_acc_criterion[i] != -1 \
                and during_acc_criterion[i] > curr_dur_acc:
            logger.error('During accuracy of task %d too small (%f < %f).' % \
                         (i+1, curr_dur_acc, during_acc_criterion[i]))
            logger.error('Training of future tasks will be skipped.')
            writer.close()
            exit(1)

        if config.train_from_scratch and i < config.num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            # Note, we only need the discriminator as helper for training,
            # so we don't checkpoint it.
            pmutils.checkpoint_nets(config,
                                    shared,
                                    i,
                                    mnet,
                                    hnet,
                                    hhnet=hhnet,
                                    dis=None)

            mnet, hnet, hhnet, dis = pcutils.generate_networks(
                config, shared, logger, dhandlers, device, create_dis=use_dis)

        elif dis is not None and not config.no_dis_reinit and \
                i < config.num_tasks-1:
            logger.debug('Reinitializing discriminator network ...')
            # FIXME Build a new network as this init doesn't affect batchnorm
            # weights atm.
            dis.custom_init(normal_init=config.normal_init,
                            normal_std=config.std_normal_init,
                            zero_bias=True)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config,
                                shared,
                                config.num_tasks - 1,
                                mnet,
                                hnet,
                                hhnet=hhnet,
                                dis=None)

    logger.info('During accuracies (task identity given): %s (avg: %.2f%%).' % \
        (np.array2string(np.array(shared.summary['acc_task_given_during']),
                         precision=2, separator=','),
         shared.summary['acc_avg_task_given_during']))
    logger.info('Final accuracies (task identity given): %s (avg: %.2f%%).' % \
        (np.array2string(np.array(shared.summary['acc_task_given']),
                         precision=2, separator=','),
         shared.summary['acc_avg_task_given']))

    logger.info('During accuracies (task identity inferred): ' +
                '%s (avg: %.2f%%).' % \
        (np.array2string(np.array( \
            shared.summary['acc_task_inferred_ent_during']),
                         precision=2, separator=','),
         shared.summary['acc_avg_task_inferred_ent_during']))
    logger.info('Final accuracies (task identity inferred): ' +
                '%s (avg: %.2f%%).' % \
        (np.array2string(np.array(shared.summary['acc_task_inferred_ent']),
                         precision=2, separator=','),
         shared.summary['acc_avg_task_inferred_ent']))

    logger.info('### Avg. during accuracy (CL scenario %d): %.4f.' %
                (config.cl_scenario, shared.summary['acc_avg_during']))
    logger.info('### Avg. final accuracy (CL scenario %d): %.4f.' %
                (config.cl_scenario, shared.summary['acc_avg_final']))

    ### Write final summary.
    shared.summary['finished'] = 1
    pmutils.save_summary_dict(config, shared, experiment)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - script_start))
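The early-abort criterion above (skipping all future tasks once a finished task's "during" accuracy falls below its per-task threshold) can be summarized in a small standalone helper. The sketch below is illustrative only; `check_during_accs` and its arguments are hypothetical names, and the convention that -1 disables a task's check is taken from the code above.

def check_during_accs(during_accs, criteria):
    """Return the index of the first task violating its threshold, else -1.

    Args:
        during_accs (list): "During" accuracy (in percent) per finished task.
        criteria (list): Minimum required accuracy per task; -1 disables the
            check for that task.
    """
    for i, (acc, crit) in enumerate(zip(during_accs, criteria)):
        if crit != -1 and acc < crit:
            return i
    return -1

# Example: task 1 reached 40% but required 50%, so training would be aborted.
assert check_during_accs([95.0, 40.0], [-1, 50.0]) == 1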
Code example #10
def run():
    """ Run the script"""
    #############
    ### Setup ###
    #############

    config = train_args_copy.parse_cmd_arguments()
    device, writer, logger = sutils.setup_environment(config)
    dhandlers = ctu.generate_copy_tasks(config, logger, writer=writer)
    plc.visualise_data(dhandlers, config, device)

    # We will use the namespace below to share miscellaneous information between
    # functions.
    shared = Namespace()
    shared.feature_size = dhandlers[0].in_shape[0]

    if (config.permute_time or config.permute_width) and not \
            config.scatter_pattern and not config.permute_xor_separate and \
            not config.permute_xor_iter > 1:
        chance = ctu.compute_chance_level(dhandlers, config)
        logger.info('Chance level for perfect during accuracies: %.2f' %
                    chance)

    # A bit ugly; find a nicer way. The problem is that if this is overwritten
    # before the tasks are generated, the task with the shortest sequences is
    # always chosen.
    if config.last_task_only:
        config.num_tasks = 1

    target_net, hnet, dnet = stu.generate_networks(config, shared, dhandlers,
                                                   device)

    # Generate binary masks if needed.
    ctx_masks = None
    if config.use_masks:
        ctx_masks = stu.generate_binary_masks(config, device, target_net)

    # We store the target network weights (excluding potential context-mod
    # weights) after every task. In this way, we can quantify changes and
    # observe the "stiffness" of EWC.
    shared.tnet_weights = []
    # We store the context-mod weights (or all weights) coming from the
    # hypernet after every task, in order to quantify "forgetting". Note, the
    # hnet regularizer should keep them fixed.
    shared.hnet_out = []

    # Get the task-specific functions for loss and accuracy.
    task_loss_func = ctu.get_copy_loss_func(config,
                                            device,
                                            logger,
                                            ewc_loss=False)
    accuracy_func = ctu.get_accuracy
    ewc_loss_func = ctu.get_copy_loss_func(config, device, logger, \
        ewc_loss=True) if config.use_ewc else None

    replay_fcts = None
    if config.use_replay:
        replay_fcts = dict()
        replay_fcts['rec_loss'] = ctu.get_vae_rec_loss_func()
        replay_fcts['distill_loss'] = ctu.get_distill_loss_func()
        replay_fcts['soft_trgt_acc'] = ctu.get_soft_trgt_acc_func()

    if config.multitask:
        summary_keywords = hpsearch_mt._SUMMARY_KEYWORDS
        summary_filename = hpsearch_mt._SUMMARY_FILENAME
    else:
        summary_keywords = hpsearch_cl._SUMMARY_KEYWORDS
        summary_filename = hpsearch_cl._SUMMARY_FILENAME

    ########################
    ### Train classifier ###
    ########################

    # Train the network task by task. Testing on all tasks is run after
    # finishing training on each task.
    ret, train_loss, test_loss, test_acc = sts.train_tasks(
        dhandlers,
        target_net,
        hnet,
        dnet,
        device,
        config,
        shared,
        logger,
        writer,
        ctx_masks,
        summary_keywords,
        summary_filename,
        task_loss_func=task_loss_func,
        accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func,
        replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)

    writer.close()

    if ret == -1:
        logger.info('Program finished successfully.')

        if config.show_plots:
            plt.show()
    else:
        logger.error('Only %d tasks have completed training.' % (ret + 1))
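The call to `stu.generate_binary_masks` above is repository-specific. As a rough sketch of the underlying idea (one fixed random binary mask per task and per weight tensor, as used by masking-based continual-learning methods), one could write the following; the function name, the 50% keep probability, and the mask layout are assumptions, not taken from the original code.

import torch

def generate_binary_masks_sketch(num_tasks, param_shapes, keep_prob=0.5,
                                 device='cpu'):
    """Create one list of fixed random binary masks per task."""
    masks = []
    for _ in range(num_tasks):
        task_masks = [(torch.rand(s, device=device) < keep_prob).float()
                      for s in param_shapes]
        masks.append(task_masks)
    return masks

# masks[t][j] would be multiplied elementwise with the j-th weight tensor
# whenever task t is trained or evaluated.
masks = generate_binary_masks_sketch(3, [(10, 5), (5,)])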
Code example #11
def run():
    """Run the script"""
    #############
    ### Setup ###
    #############

    config = train_args_pos.parse_cmd_arguments()
    device, writer, logger = sutils.setup_environment(config)
    dhandlers = ctu.generate_tasks(config, logger, writer=writer)

    # Load preprocessed word embeddings, see
    # :mod:`data.timeseries.preprocess_mud` for details.
    wembs_path = '../../datasets/sequential/mud/embeddings.pickle'
    wemb_lookups = eu.generate_emb_lookups(config,
                                           filename=wembs_path,
                                           device=device)
    assert len(wemb_lookups) == config.num_tasks

    # We will use the namespace below to share miscellaneous information between
    # functions.
    shared = Namespace()
    # The embedding size is fixed due to the use of pretrained polyglot
    # embeddings.
    # FIXME Could be made configurable in the future in case we don't initialize
    # embeddings via polyglot.
    shared.feature_size = 64
    shared.word_emb_lookups = wemb_lookups

    target_net, hnet, dnet = stu.generate_networks(config, shared, dhandlers,
                                                   device)

    # Generate binary masks if needed.
    ctx_masks = None
    if config.use_masks:
        ctx_masks = stu.generate_binary_masks(config, device, target_net)

    # We store the target network weights (excluding potential context-mod
    # weights) after every task. In this way, we can quantify changes and
    # observe the "stiffness" of EWC.
    shared.tnet_weights = []
    # We store the context-mod weights (or all weights) coming from the
    # hypernet after every task, in order to quantify "forgetting". Note, the
    # hnet regularizer should keep them fixed.
    shared.hnet_out = []

    # Get the task-specific functions for loss and accuracy.
    task_loss_func = ctu.get_loss_func(config, device, logger, ewc_loss=False)
    accuracy_func = ctu.get_accuracy_func(config)
    ewc_loss_func = ctu.get_loss_func(config, device, logger, \
        ewc_loss=True) if config.use_ewc else None

    replay_fcts = None
    if config.use_replay:
        replay_fcts = dict()
        replay_fcts['rec_loss'] = ctu.get_vae_rec_loss_func()
        replay_fcts['distill_loss'] = ctu.get_distill_loss_func()
        replay_fcts['soft_trgt_acc'] = ctu.get_soft_trgt_acc_func()

    if config.multitask:
        summary_keywords = hpsearch_mt._SUMMARY_KEYWORDS
        summary_filename = hpsearch_mt._SUMMARY_FILENAME
    else:
        summary_keywords = hpsearch_cl._SUMMARY_KEYWORDS
        summary_filename = hpsearch_cl._SUMMARY_FILENAME

    ########################
    ### Train classifier ###
    ########################

    shared.f_scores = None

    # Train the network task by task. Testing on all tasks is run after
    # finishing training on each task.
    ret, train_loss, test_loss, test_acc = sts.train_tasks(
        dhandlers,
        target_net,
        hnet,
        dnet,
        device,
        config,
        shared,
        logger,
        writer,
        ctx_masks,
        summary_keywords,
        summary_filename,
        task_loss_func=task_loss_func,
        accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func,
        replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)

    writer.close()

    if ret == -1:
        logger.info('Program finished successfully.')

        if config.show_plots:
            plt.show()
    else:
        logger.error('Only %d tasks have completed training.' % (ret + 1))
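`eu.generate_emb_lookups` is specific to this repository. Conceptually, it loads one pretrained polyglot embedding matrix per task and wraps it in a frozen lookup table. The sketch below illustrates that idea under the assumption that the pickle file stores a list with one numpy array of shape `(vocab_size, 64)` per task; this file format is an assumption, not verified against the code.

import pickle

import torch

def load_emb_lookups_sketch(filename, device='cpu'):
    """Load pickled embedding matrices as frozen ``nn.Embedding`` lookups."""
    with open(filename, 'rb') as f:
        emb_matrices = pickle.load(f)  # Assumed: list of numpy arrays.
    lookups = []
    for mat in emb_matrices:
        emb = torch.nn.Embedding.from_pretrained(
            torch.from_numpy(mat).float(), freeze=True)
        lookups.append(emb.to(device))
    return lookups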
Code example #12
def run(config, experiment='regression_mt'):
    """Run the Multitask training for the given experiment.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_mt'``: Regression tasks with multitask training
            - ``'gmm_mt'``: GMM Data with multitask training
            - ``'split_mnist_mt'``: SplitMNIST with multitask training
            - ``'perm_mnist_mt'``: PermutedMNIST with multitask training
            - ``'cifar_resnet_mt'``: CIFAR-10/100 with multitask training
    """
    assert experiment in [
        'regression_mt', 'gmm_mt', 'split_mnist_mt', 'perm_mnist_mt',
        'cifar_resnet_mt'
    ]

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
                                                      logger_name=experiment)

    rutils.backup_cli_command(config)

    is_classification = True
    if 'regression' in experiment:
        is_classification = False

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    ### Generate network.
    # Note, for multitask training we only need the main network.
    mnet, _, _, _ = pcutils.generate_networks(config,
                                              shared,
                                              logger,
                                              dhandlers,
                                              device,
                                              create_mnet=True,
                                              create_hnet=False,
                                              create_hhnet=False,
                                              create_dis=False)

    if not is_classification:
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after the most recent call of the test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the per-task softmax temperatures to one. They might be
    # adjusted later on when calibrating the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device) \
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
    else:
        rutils.setup_summary_dict(config, shared, 'mt', config.num_tasks, mnet)

    # Add hparams to TensorBoard, such that runs are easier to identify.
    writer.add_hparams(hparam_dict={**vars(config), **{
        'num_weights_main': shared.summary['num_weights_main'] \
            if is_classification else shared.summary['aa_num_weights_main']
    }}, metric_dict={})

    ### Train on all tasks.
    logger.info('### Training ###')
    # Note, since we are training on all tasks, all output heads can at all
    # times be considered trained!
    shared.num_trained = config.num_tasks
    train(dhandlers, mnet, device, config, shared, logger, writer)

    logger.info('### Testing ###')
    test(dhandlers,
         mnet,
         device,
         config,
         shared,
         logger,
         writer,
         test_ids=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks - 1, mnet,
                                None)

    ### Log final classification scores.
    if is_classification:
        logger.info('Final accuracies (task identity given): ' + \
                    '%s (avg: %.2f%%).' % \
            (np.array2string(np.array(shared.summary['acc_task_given']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_given']))

    ### Log final regression scores.
    if not is_classification:
        logger.info('Final MSE values after training on all tasks: %s' % \
            np.array2string(np.array(shared.summary['aa_mse_final']),
                            precision=5, separator=','))
        logger.info('Final MSE mean %.4f.' % \
                    (shared.summary['aa_mse_during_mean']))

    ### Write final summary.
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - script_start))
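The per-task softmax temperatures above are initialized to one and may later be tuned via temperature scaling (Guo et al., 2017). The sketch below shows the general technique, not the repository's own `calibrate_temperature`: a single temperature is fit by minimizing the negative log-likelihood on held-out logits, and the optimization runs in log-space to keep the temperature positive. All names and hyperparameters are illustrative assumptions.

import torch
import torch.nn.functional as F

def calibrate_temp_sketch(logits, labels, lr=0.01, n_iter=100):
    """Fit a single softmax temperature by minimizing the NLL."""
    log_temp = torch.zeros(1, requires_grad=True)
    optim = torch.optim.Adam([log_temp], lr=lr)
    for _ in range(n_iter):
        optim.zero_grad()
        loss = F.cross_entropy(logits / log_temp.exp(), labels)
        loss.backward()
        optim.step()
    return log_temp.exp().detach()

# Usage: temp = calibrate_temp_sketch(val_logits, val_labels); at test time,
# divide the logits by temp before applying the softmax.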