Example #1
def run():
    """ Run the script"""
    #############
    ### Setup ###
    #############

    config = train_args_copy.parse_cmd_arguments()
    device, writer, logger = sutils.setup_environment(config)
    dhandlers = ctu.generate_copy_tasks(config, logger, writer=writer)
    plc.visualise_data(dhandlers, config, device)

    # We will use the namespace below to share miscellaneous information between
    # functions.
    shared = Namespace()
    shared.feature_size = dhandlers[0].in_shape[0]

    if (config.permute_time or config.permute_width) and \
            not config.scatter_pattern and not config.permute_xor_separate and \
            config.permute_xor_iter <= 1:
        chance = ctu.compute_chance_level(dhandlers, config)
        logger.info('Chance level for perfect "during" accuracies: %.2f' %
                    chance)

    # A bit ugly; find a nicer way. (The problem is that if this value is
    # overwritten before the tasks are generated, the task with the shortest
    # sequences is always chosen.)
    if config.last_task_only:
        config.num_tasks = 1

    target_net, hnet, dnet = stu.generate_networks(config, shared, dhandlers,
                                                   device)

    # Generate masks if needed.
    ctx_masks = None
    if config.use_masks:
        ctx_masks = stu.generate_binary_masks(config, device, target_net)

    # We store the target network weights (excluding potential context-mod
    # weights) after every task. In this way, we can quantify changes and
    # observe the "stiffness" of EWC.
    shared.tnet_weights = []
    # We store the context-mod weights (or all weights) coming from the
    # hypernet after every task, in order to quantify "forgetting". Note that
    # the hnet regularizer should keep them fixed.
    shared.hnet_out = []

    # Get the task-specific functions for loss and accuracy.
    task_loss_func = ctu.get_copy_loss_func(config,
                                            device,
                                            logger,
                                            ewc_loss=False)
    accuracy_func = ctu.get_accuracy
    ewc_loss_func = (ctu.get_copy_loss_func(config, device, logger,
                                            ewc_loss=True)
                     if config.use_ewc else None)

    replay_fcts = None
    if config.use_replay:
        replay_fcts = dict()
        replay_fcts['rec_loss'] = ctu.get_vae_rec_loss_func()
        replay_fcts['distill_loss'] = ctu.get_distill_loss_func()
        replay_fcts['soft_trgt_acc'] = ctu.get_soft_trgt_acc_func()

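    # Pick the summary keywords/filename expected by the corresponding
    # hyperparameter-search post-processing (multitask vs. continual learning).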
    if config.multitask:
        summary_keywords = hpsearch_mt._SUMMARY_KEYWORDS
        summary_filename = hpsearch_mt._SUMMARY_FILENAME
    else:
        summary_keywords = hpsearch_cl._SUMMARY_KEYWORDS
        summary_filename = hpsearch_cl._SUMMARY_FILENAME

    ########################
    ### Train classifier ###
    ########################

    # Train the network task by task. Testing on all tasks is run after
    # finishing training on each task.
    ret, train_loss, test_loss, test_acc = sts.train_tasks(
        dhandlers,
        target_net,
        hnet,
        dnet,
        device,
        config,
        shared,
        logger,
        writer,
        ctx_masks,
        summary_keywords,
        summary_filename,
        task_loss_func=task_loss_func,
        accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func,
        replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)

    writer.close()

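    # ``sts.train_tasks`` returns -1 if all tasks were trained successfully;
    # otherwise ``ret`` is the zero-based index of the last task that
    # completed training.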
    if ret == -1:
        logger.info('Program finished successfully.')

        if config.show_plots:
            plt.show()
    else:
        logger.error('Only %d tasks have completed training.' % (ret + 1))
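
The snapshots collected in ``shared.tnet_weights`` make it possible to quantify how far the target network's weights move between tasks; under strong EWC regularization these distances should stay small. Below is a minimal, self-contained sketch of such a measurement. The helper ``weight_drift`` is hypothetical and not part of the repository:

import torch

def weight_drift(snapshots):
    """Euclidean distance between consecutive weight snapshots.

    Args:
        snapshots: List of per-task weight snapshots, each a list of tensors
            (as stored in ``shared.tnet_weights``).

    Returns:
        List of floats, one per pair of consecutive tasks.
    """
    drifts = []
    for prev, curr in zip(snapshots[:-1], snapshots[1:]):
        # Sum of squared differences across all parameter tensors.
        sq = sum(torch.sum((p - c)**2) for p, c in zip(prev, curr))
        drifts.append(torch.sqrt(sq).item())
    return drifts

# Example: two fake snapshots of a two-parameter network.
snap_a = [torch.zeros(3, 3), torch.zeros(3)]
snap_b = [torch.ones(3, 3), torch.ones(3)]
print(weight_drift([snap_a, snap_b]))  # -> [3.4641...] (sqrt(12))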
Example #2
def run():
    """ Run the script"""
    #############
    ### Setup ###
    #############

    config = train_args_seq_smnist.parse_cmd_arguments()
    device, writer, logger = sutils.setup_environment(config)
    dhandlers = ctu._generate_tasks(config, logger)

    # We will use the namespace below to share miscellaneous information between
    # functions.
    shared = Namespace()
    shared.feature_size = dhandlers[0].in_shape[0]

    # Plot images.
    if config.show_plots:
        figure_dir = os.path.join(config.out_dir, 'figures')
        os.makedirs(figure_dir, exist_ok=True)

        for t, dh in enumerate(dhandlers):
            dh.plot_samples('Training Samples - Task %d' % t,
                dh.get_train_inputs()[:8], outputs=dh.get_train_outputs()[:8],
                show=True, filename=os.path.join(figure_dir,
                    'train_samples_task_%d.png' % t))

    target_net, hnet, dnet = stu.generate_networks(config, shared, dhandlers,
                                                   device)

    # Generate masks if needed.
    ctx_masks = None
    if config.use_masks:
        ctx_masks = stu.generate_binary_masks(config, device, target_net)

    # We store the target network weights (excluding potential context-mod
    # weights) after every task. In this way, we can quantify changes and
    # observe the "stiffness" of EWC.
    shared.tnet_weights = []
    # We store the context-mod weights (or all weights) coming from the
    # hypernet after every task, in order to quantify "forgetting". Note that
    # the hnet regularizer should keep them fixed.
    shared.hnet_out = []

    # Get the task-specific functions for loss and accuracy.
    task_loss_func = ctu.get_loss_func(config, device, logger, ewc_loss=False)
    accuracy_func = ctu.get_accuracy_func(config)
    ewc_loss_func = (ctu.get_loss_func(config, device, logger, ewc_loss=True)
                     if config.use_ewc else None)

    replay_fcts = None
    if config.use_replay:
        replay_fcts = dict()
        replay_fcts['rec_loss'] = ctu.get_vae_rec_loss_func()
        replay_fcts['distill_loss'] = ctu.get_distill_loss_func()
        replay_fcts['soft_trgt_acc'] = ctu.get_soft_trgt_acc_func()

    if config.multitask:
        summary_keywords = hpsearch_mt._SUMMARY_KEYWORDS
        summary_filename = hpsearch_mt._SUMMARY_FILENAME
    else:
        summary_keywords = hpsearch_cl._SUMMARY_KEYWORDS
        summary_filename = hpsearch_cl._SUMMARY_FILENAME

    ########################
    ### Train classifier ###
    ########################

    # Train the network task by task. Testing on all tasks is run after 
    # finishing training on each task.
    ret, train_loss, test_loss, test_acc = sts.train_tasks(dhandlers,
        target_net, hnet, dnet, device, config, shared, logger, writer,
        ctx_masks, summary_keywords, summary_filename,
        task_loss_func=task_loss_func, accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func, replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)

    writer.close()

    if ret == -1:
        logger.info('Program finished successfully.')

        if config.show_plots:
            plt.show()
    else:
        logger.error('Only %d tasks have completed training.' % (ret + 1))
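
For replay-based training, ``replay_fcts`` bundles a VAE reconstruction loss, a distillation loss, and a soft-target accuracy function. As an illustration only (an assumption, not the repository's implementation), a temperature-scaled distillation loss like the one returned by ``ctu.get_distill_loss_func`` typically looks like this:

import torch
import torch.nn.functional as F

def soft_target_distill_loss(student_logits, teacher_logits, T=2.0):
    """Cross-entropy between temperature-softened teacher and student."""
    log_p = F.log_softmax(student_logits / T, dim=1)
    q = F.softmax(teacher_logits / T, dim=1)
    # The T**2 factor keeps gradient magnitudes comparable across temperatures
    # (Hinton et al., 2015).
    return -(q * log_p).sum(dim=1).mean() * T**2

# Example usage with random logits for a batch of 4 samples and 10 classes.
student = torch.randn(4, 10)
teacher = torch.randn(4, 10)
print(soft_target_distill_loss(student, teacher))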
Example #3
def run():
    """Run the script"""
    #############
    ### Setup ###
    #############

    config = train_args_pos.parse_cmd_arguments()
    device, writer, logger = sutils.setup_environment(config)
    dhandlers = ctu.generate_tasks(config, logger, writer=writer)

    # Load preprocessed word embeddings, see
    # :mod:`data.timeseries.preprocess_mud` for details.
    wembs_path = '../../datasets/sequential/mud/embeddings.pickle'
    wemb_lookups = eu.generate_emb_lookups(config,
                                           filename=wembs_path,
                                           device=device)
    assert len(wemb_lookups) == config.num_tasks

    # We will use the namespace below to share miscellaneous information between
    # functions.
    shared = Namespace()
    # The embedding size is fixed due to the use of pretrained polyglot
    # embeddings.
    # FIXME Could be made configurable in the future in case we don't initialize
    # embeddings via polyglot.
    shared.feature_size = 64
    shared.word_emb_lookups = wemb_lookups

    target_net, hnet, dnet = stu.generate_networks(config, shared, dhandlers,
                                                   device)

    # Generate masks if needed.
    ctx_masks = None
    if config.use_masks:
        ctx_masks = stu.generate_binary_masks(config, device, target_net)

    # We store the target network weights (excluding potential context-mod
    # weights) after every task. In this way, we can quantify changes and
    # observe the "stiffness" of EWC.
    shared.tnet_weights = []
    # We store the context-mod weights (or all weights) coming from the
    # hypernet after every task, in order to quantify "forgetting". Note that
    # the hnet regularizer should keep them fixed.
    shared.hnet_out = []

    # Get the task-specific functions for loss and accuracy.
    task_loss_func = ctu.get_loss_func(config, device, logger, ewc_loss=False)
    accuracy_func = ctu.get_accuracy_func(config)
    ewc_loss_func = (ctu.get_loss_func(config, device, logger, ewc_loss=True)
                     if config.use_ewc else None)

    replay_fcts = None
    if config.use_replay:
        replay_fcts = dict()
        replay_fcts['rec_loss'] = ctu.get_vae_rec_loss_func()
        replay_fcts['distill_loss'] = ctu.get_distill_loss_func()
        replay_fcts['soft_trgt_acc'] = ctu.get_soft_trgt_acc_func()

    if config.multitask:
        summary_keywords = hpsearch_mt._SUMMARY_KEYWORDS
        summary_filename = hpsearch_mt._SUMMARY_FILENAME
    else:
        summary_keywords = hpsearch_cl._SUMMARY_KEYWORDS
        summary_filename = hpsearch_cl._SUMMARY_FILENAME

    ########################
    ### Train classifier ###
    ########################

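    # Placeholder; downstream evaluation may fill this with per-task F-scores
    # for the PoS-tagging tasks.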
    shared.f_scores = None

    # Train the network task by task. Testing on all tasks is run after
    # finishing training on each task.
    ret, train_loss, test_loss, test_acc = sts.train_tasks(
        dhandlers,
        target_net,
        hnet,
        dnet,
        device,
        config,
        shared,
        logger,
        writer,
        ctx_masks,
        summary_keywords,
        summary_filename,
        task_loss_func=task_loss_func,
        accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func,
        replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)

    writer.close()

    if ret == -1:
        logger.info('Program finished successfully.')

        if config.show_plots:
            plt.show()
    else:
        logger.error('Only %d tasks have completed training.' % (ret + 1))
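
The PoS-tagging variant relies on pretrained 64-dimensional polyglot word embeddings loaded from a pickle file. A minimal sketch of how such a lookup could be constructed (the pickle layout and the helper name ``load_emb_lookup`` are assumptions, not the repository's API):

import pickle

import torch
import torch.nn as nn

def load_emb_lookup(filename, device):
    """Build a frozen embedding layer from pickled pretrained vectors.

    Assumes the pickle holds an array-like of shape (vocab_size, 64).
    """
    with open(filename, 'rb') as f:
        vectors = pickle.load(f)
    weights = torch.as_tensor(vectors, dtype=torch.float32)
    # freeze=True: the pretrained embeddings are not fine-tuned.
    emb = nn.Embedding.from_pretrained(weights, freeze=True)
    return emb.to(device)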