def run():
    """Run the copy-task continual-learning script.

    Parses command-line options, builds the copy-task data handlers and the
    networks, trains task-by-task and finally logs the resulting accuracies.
    """
    #############
    ### Setup ###
    #############
    config = train_args_copy.parse_cmd_arguments()
    device, writer, logger = sutils.setup_environment(config)

    dhandlers = ctu.generate_copy_tasks(config, logger, writer=writer)
    plc.visualise_data(dhandlers, config, device)

    # Namespace used to share miscellaneous information between functions.
    shared = Namespace()
    shared.feature_size = dhandlers[0].in_shape[0]

    # Log the chance level only for plain permutation setups, i.e. no scatter
    # pattern and no XOR-based permutation variants.
    plain_permutation = (config.permute_time or config.permute_width) \
        and not config.scatter_pattern and not config.permute_xor_separate \
        and not config.permute_xor_iter > 1
    if plain_permutation:
        chance = ctu.compute_chance_level(dhandlers, config)
        logger.info('Chance level for perfect during accuracies: %.2f' % chance)

    # A bit ugly, find a nicer way (problem is, if you overwrite this before
    # generating the tasks, always the task with shortest sequences is chosen).
    if config.last_task_only:
        config.num_tasks = 1

    target_net, hnet, dnet = stu.generate_networks(config, shared, dhandlers,
                                                   device)

    # Generate binary context masks if requested.
    ctx_masks = stu.generate_binary_masks(config, device, target_net) \
        if config.use_masks else None

    # Target-network weights (excluding potential context-mod weights) are
    # stored after every task so changes, i.e. the "stiffness" of EWC, can be
    # quantified.
    shared.tnet_weights = []
    # Hypernet outputs (context-mod weights, or all weights) are stored after
    # every task in order to quantify "forgetting". Note, the hnet regularizer
    # should keep them fixed.
    shared.hnet_out = []

    # Task-specific functions for loss and accuracy.
    task_loss_func = ctu.get_copy_loss_func(config, device, logger,
                                            ewc_loss=False)
    accuracy_func = ctu.get_accuracy
    if config.use_ewc:
        ewc_loss_func = ctu.get_copy_loss_func(config, device, logger,
                                               ewc_loss=True)
    else:
        ewc_loss_func = None

    replay_fcts = None
    if config.use_replay:
        replay_fcts = {
            'rec_loss': ctu.get_vae_rec_loss_func(),
            'distill_loss': ctu.get_distill_loss_func(),
            'soft_trgt_acc': ctu.get_soft_trgt_acc_func(),
        }

    # Pick the hyperparameter-search summary layout matching the regime.
    hpsearch = hpsearch_mt if config.multitask else hpsearch_cl
    summary_keywords = hpsearch._SUMMARY_KEYWORDS
    summary_filename = hpsearch._SUMMARY_FILENAME

    ########################
    ### Train classifier ###
    ########################
    # Train the network task by task. Testing on all tasks is run after
    # finishing training on each task.
    ret, train_loss, test_loss, test_acc = sts.train_tasks(
        dhandlers, target_net, hnet, dnet, device, config, shared, logger,
        writer, ctx_masks, summary_keywords, summary_filename,
        task_loss_func=task_loss_func, accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func, replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)
    writer.close()

    if ret == -1:
        logger.info('Program finished successfully.')
        if config.show_plots:
            plt.show()
    else:
        logger.error('Only %d tasks have completed training.' % (ret + 1))
def run():
    """Run the sequential-SMNIST continual-learning script.

    Parses command-line options, builds the task data handlers and networks,
    trains task-by-task and finally logs the resulting accuracies.
    """
    #############
    ### Setup ###
    #############
    config = train_args_seq_smnist.parse_cmd_arguments()
    device, writer, logger = sutils.setup_environment(config)

    dhandlers = ctu._generate_tasks(config, logger)

    # Namespace used to share miscellaneous information between functions.
    shared = Namespace()
    shared.feature_size = dhandlers[0].in_shape[0]

    # Plot a few samples per task.
    # NOTE(review): the titles say "Test Samples" but the data comes from
    # get_train_inputs/get_train_outputs — confirm which split is intended.
    if config.show_plots:
        figure_dir = os.path.join(config.out_dir, 'figures')
        if not os.path.exists(figure_dir):
            os.makedirs(figure_dir)
        for t, dh in enumerate(dhandlers):
            dh.plot_samples('Test Samples - Task %d' % t,
                            dh.get_train_inputs()[:8],
                            outputs=dh.get_train_outputs()[:8], show=True,
                            filename=os.path.join(figure_dir,
                                                  'test_samples_task_%d.png' % t))

    target_net, hnet, dnet = stu.generate_networks(config, shared, dhandlers,
                                                   device)

    # Generate binary context masks if requested.
    ctx_masks = stu.generate_binary_masks(config, device, target_net) \
        if config.use_masks else None

    # Target-network weights (excluding potential context-mod weights) are
    # stored after every task so changes, i.e. the "stiffness" of EWC, can be
    # quantified.
    shared.tnet_weights = []
    # Hypernet outputs (context-mod weights, or all weights) are stored after
    # every task in order to quantify "forgetting". Note, the hnet regularizer
    # should keep them fixed.
    shared.hnet_out = []

    # Task-specific functions for loss and accuracy.
    task_loss_func = ctu.get_loss_func(config, device, logger, ewc_loss=False)
    accuracy_func = ctu.get_accuracy_func(config)
    if config.use_ewc:
        ewc_loss_func = ctu.get_loss_func(config, device, logger,
                                          ewc_loss=True)
    else:
        ewc_loss_func = None

    replay_fcts = None
    if config.use_replay:
        replay_fcts = {
            'rec_loss': ctu.get_vae_rec_loss_func(),
            'distill_loss': ctu.get_distill_loss_func(),
            'soft_trgt_acc': ctu.get_soft_trgt_acc_func(),
        }

    # Pick the hyperparameter-search summary layout matching the regime.
    hpsearch = hpsearch_mt if config.multitask else hpsearch_cl
    summary_keywords = hpsearch._SUMMARY_KEYWORDS
    summary_filename = hpsearch._SUMMARY_FILENAME

    ########################
    ### Train classifier ###
    ########################
    # Train the network task by task. Testing on all tasks is run after
    # finishing training on each task.
    ret, train_loss, test_loss, test_acc = sts.train_tasks(
        dhandlers, target_net, hnet, dnet, device, config, shared, logger,
        writer, ctx_masks, summary_keywords, summary_filename,
        task_loss_func=task_loss_func, accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func, replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)
    writer.close()

    if ret == -1:
        logger.info('Program finished successfully.')
        if config.show_plots:
            plt.show()
    else:
        logger.error('Only %d tasks have completed training.' % (ret + 1))
def run():
    """Run the part-of-speech-tagging continual-learning script.

    Parses command-line options, builds the task data handlers, loads the
    pretrained word-embedding lookups, builds the networks, trains
    task-by-task and finally logs the resulting accuracies.
    """
    #############
    ### Setup ###
    #############
    config = train_args_pos.parse_cmd_arguments()
    device, writer, logger = sutils.setup_environment(config)

    dhandlers = ctu.generate_tasks(config, logger, writer=writer)

    # Load preprocessed word embeddings, see
    # :mod:`data.timeseries.preprocess_mud` for details.
    embeddings_file = '../../datasets/sequential/mud/embeddings.pickle'
    embedding_lookups = eu.generate_emb_lookups(config,
                                                filename=embeddings_file,
                                                device=device)
    assert len(embedding_lookups) == config.num_tasks

    # Namespace used to share miscellaneous information between functions.
    shared = Namespace()
    # The embedding size is fixed due to the use of pretrained polyglot
    # embeddings.
    # FIXME Could be made configurable in the future in case we don't
    # initialize embeddings via polyglot.
    shared.feature_size = 64
    shared.word_emb_lookups = embedding_lookups

    target_net, hnet, dnet = stu.generate_networks(config, shared, dhandlers,
                                                   device)

    # Generate binary context masks if requested.
    ctx_masks = stu.generate_binary_masks(config, device, target_net) \
        if config.use_masks else None

    # Target-network weights (excluding potential context-mod weights) are
    # stored after every task so changes, i.e. the "stiffness" of EWC, can be
    # quantified.
    shared.tnet_weights = []
    # Hypernet outputs (context-mod weights, or all weights) are stored after
    # every task in order to quantify "forgetting". Note, the hnet regularizer
    # should keep them fixed.
    shared.hnet_out = []

    # Task-specific functions for loss and accuracy.
    task_loss_func = ctu.get_loss_func(config, device, logger, ewc_loss=False)
    accuracy_func = ctu.get_accuracy_func(config)
    if config.use_ewc:
        ewc_loss_func = ctu.get_loss_func(config, device, logger,
                                          ewc_loss=True)
    else:
        ewc_loss_func = None

    replay_fcts = None
    if config.use_replay:
        replay_fcts = {
            'rec_loss': ctu.get_vae_rec_loss_func(),
            'distill_loss': ctu.get_distill_loss_func(),
            'soft_trgt_acc': ctu.get_soft_trgt_acc_func(),
        }

    # Pick the hyperparameter-search summary layout matching the regime.
    hpsearch = hpsearch_mt if config.multitask else hpsearch_cl
    summary_keywords = hpsearch._SUMMARY_KEYWORDS
    summary_filename = hpsearch._SUMMARY_FILENAME

    ########################
    ### Train classifier ###
    ########################
    shared.f_scores = None

    # Train the network task by task. Testing on all tasks is run after
    # finishing training on each task.
    ret, train_loss, test_loss, test_acc = sts.train_tasks(
        dhandlers, target_net, hnet, dnet, device, config, shared, logger,
        writer, ctx_masks, summary_keywords, summary_filename,
        task_loss_func=task_loss_func, accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func, replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)
    writer.close()

    if ret == -1:
        logger.info('Program finished successfully.')
        if config.show_plots:
            plt.show()
    else:
        logger.error('Only %d tasks have completed training.' % (ret + 1))