Code Example #1
    def __call__(self, cfg, i=None):
        trainer = None
        try:
            cfg.pop('local_rank', None)  # TODO: properly handle distributed
            resume = cfg.pop('resume', False)
            save = cfg.pop('save', False)
            if i is not None:
                # Give each trial its own log directory suffix.
                orig_suffix = cfg.setdefault('trainer_config',
                                             {}).get('log_suffix', '')
                cfg['trainer_config']['log_suffix'] = os.path.join(
                    orig_suffix, f'trial{i}/')
            trainer = self.make_trainer(**cfg)
            trainer.logger.add_scalars('config', flatten_dict(cfg))
            # num_epochs may be a single int or an iterable of checkpoints to train to.
            epochs = cfg['num_epochs'] if isinstance(
                cfg['num_epochs'], Iterable) else [cfg['num_epochs']]
            if resume:
                # resume=True loads the latest checkpoint; a string is treated as a path.
                trainer.load_checkpoint(None if resume is True else resume)
            epochs = [e for e in epochs if e > trainer.epoch]
            for epoch in epochs:
                trainer.train_to(epoch)
                if save:
                    cfg['saved_at'] = trainer.save_checkpoint()
            outcome = trainer.ckpt['outcome']
            # Roll out the trained model on the test set and save predicted vs.
            # ground-truth trajectories to disk.
            trajectories = []
            for mb in trainer.dataloaders['test']:
                trajectories.append(
                    pred_and_gt_ode(trainer.dataloaders['test'].dataset,
                                    trainer.model, mb))
            torch.save(
                np.concatenate(trajectories),
                f"./{cfg['network']}_{cfg['net_config']['group']}_{i}.t")
        except Exception as e:
            if self.strict:
                raise
            outcome = e
        if trainer is not None:
            del trainer
        return cfg, outcome
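
Below is a minimal usage sketch of how a trial callable like this might be driven. The `Trial` class name, its constructor arguments, and the driver loop are assumptions for illustration, not part of the code above; the config keys mirror the ones the method reads.

import copy

# Hypothetical driver: `Trial` stands in for whatever class defines __call__ above,
# assumed to take a make_trainer factory and a strict flag. Values are illustrative.
trial = Trial(make_trainer=make_trainer, strict=True)
base_cfg = {
    'network': 'HNN',                      # illustrative value
    'net_config': {'group': 'SO2'},        # illustrative value
    'num_epochs': [50, 100],               # train to 50 epochs, then to 100
    'trainer_config': {'log_suffix': ''},
    'save': True,
    'resume': False,
}
for i in range(3):
    cfg_out, outcome = trial(copy.deepcopy(base_cfg), i=i)
    print(i, outcome)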
Code Example #2
    def __call__(self, cfg, i=None):
        cfg.pop('local_rank', None)  # TODO: properly handle distributed
        if i is not None:
            # Give each trial its own log directory suffix.
            orig_suffix = cfg.setdefault('trainer_config', {}).get('log_suffix', '')
            cfg['trainer_config']['log_suffix'] = os.path.join(orig_suffix, f'trial{i}/')
        trainer = self.make_trainer(**cfg)
        trainer.logger.add_scalars('config', flatten_dict(cfg))
        trainer.train(cfg['num_epochs'])
        # The trial outcome is the final row of logged metrics.
        outcome = trainer.logger.scalar_frame.iloc[-1:]
        trainer.logger.save_object(trainer.model.state_dict(), suffix='checkpoints/final.state')
        trainer.logger.save_object(trainer.logger.scalar_frame, suffix='scalars.df')

        return cfg, outcome
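
One way the returned (cfg, outcome) pairs could be collected across trials, assuming `scalar_frame` is a pandas DataFrame so that `outcome` is a one-row slice; the `trial` instance and the `configs` list are illustrative names, not defined above.

import pandas as pd

# Illustrative aggregation of per-trial outcomes; `trial` and `configs` are assumed names.
rows = []
for i, cfg in enumerate(configs):
    cfg_out, outcome = trial(cfg, i=i)   # outcome: one-row DataFrame (last logged epoch)
    rows.append(outcome.assign(trial=i))
summary = pd.concat(rows).set_index('trial')
print(summary)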
Code Example #3
    def __call__(self, cfg, i=None):
        cfg.pop('local_rank', None)  # TODO: properly handle distributed
        # Pop 'save' here so it isn't forwarded to make_trainer below.
        save = cfg.pop('save', True)
        if i is not None:
            # Give each trial its own log directory suffix.
            orig_suffix = cfg.setdefault('trainer_cfg',
                                         {}).get('log_suffix', '')
            cfg['trainer_cfg']['log_suffix'] = os.path.join(
                orig_suffix, f'trial{i}/')
        trainer = self.make_trainer(**cfg)
        trainer.logger.add_scalars('config', flatten_dict(cfg))

        trainer.train(cfg['num_epochs'])
        # The trial outcome is the row of logged metrics at the best checkpoint's epoch.
        outcome = trainer.logger.scalar_frame.iloc[trainer.ckpt[0]:trainer.ckpt[0] + 1]
        if save:
            trainer.logger.save_object(
                trainer.ckpt[1], suffix=f'checkpoints/{trainer.ckpt[0]}.state')
        self.trainer = trainer
        return cfg, outcome
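
Since this variant keeps the trainer on the object (`self.trainer`), the best checkpoint can be inspected after the call. The sketch below assumes `trainer.ckpt` is an (epoch, state_dict) pair, as the indexing above suggests, and uses a hypothetical `trial` instance and `cfg` dict.

# Hypothetical post-hoc inspection; `trial` is an instance of the class defining __call__ above.
cfg_out, outcome = trial(cfg, i=0)
best_epoch, best_state = trial.trainer.ckpt     # assumed (epoch, state_dict) pair
trial.trainer.model.load_state_dict(best_state)
trial.trainer.model.eval()                      # switch to eval mode for analysis
print(f'best epoch: {best_epoch}')
print(outcome)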