def main():
    init_database(args.logdir)

    py_filename = osp.join('jacmldash.py')
    if osp.isfile(py_filename):
        logger.critical('Loading JacMLDash config: {}.'.format(osp.abspath(py_filename)))
        config = load_source(py_filename)

        if hasattr(config, 'ui_methods'):
            register_ui_methods(config.ui_methods)
        if hasattr(config, 'run_methods'):
            register_run_methods(config.run_methods)
        if hasattr(config, 'custom_pages'):
            register_custom_pages(config.custom_pages)

    if args.cli:
        from IPython import embed
        embed()
        return

    app = make_app([
        'mldash.web.app.index',
        'mldash.web.app.experiment',
        'mldash.plugins.tensorboard.handler',
        'mldash.plugins.trashbin.handler',
        'mldash.plugins.star.handler',
        'mldash.plugins.viewer.handler',
    ], {
        'gzip': True,
        'debug': args.debug,
        'xsrf_cookies': True,
        'static_path': get_static_path(),
        'template_path': get_template_path(),
        'ui_methods': get_ui_methods(),
        'cookie_secret': '20f42d0ae6548e88cf9788e725b298bd',
        'session_secret': '3cdcb1f00803b6e78ab50b466a40b9977db396840c28307f428b25e2277f1bcc',
        'frontend_secret': 'asdjikfh98234uf9pidwja09f9adsjp9fd6840c28307f428b25e2277f1bcc',
        'cookie_prefix': 'jac_',
        'session_engine': 'off',
    })

    app.listen(args.port, xheaders=True)
    logger.critical('Mainloop started. Port: {}.'.format(args.port))

    loop = tornado.ioloop.IOLoop.current()
    loop.start()
# filenames
args.series_name = args.dataset
args.desc_name = escape_desc_name(args.desc)
args.run_name = 'run-{}'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))

# directories
if args.use_gpu:
    nr_devs = cuda.device_count()
    if args.force_gpu and nr_devs == 0:
        nr_devs = 1
    assert nr_devs > 0, 'No GPU device available'
    args.gpus = [i for i in range(nr_devs)]
    args.gpu_parallel = (nr_devs > 1)

desc = load_source(args.desc)
configs = desc.configs
args.configs.apply(configs)


def main():
    args.dump_dir = ensure_path(osp.join('dumps', args.dataset_name, args.desc_name, args.expr))

    if not args.debug:
        args.ckpt_dir = ensure_path(osp.join(args.dump_dir, 'checkpoints'))
        args.meta_dir = ensure_path(osp.join(args.dump_dir, 'meta'))
        args.meta_file = osp.join(args.meta_dir, args.run_name + '.json')
        args.log_file = osp.join(args.meta_dir, args.run_name + '.log')
        args.meter_file = osp.join(args.meta_dir, args.run_name + '.meter.json')

        logger.critical('Writing logs to file: "{}".'.format(args.log_file))
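# For reference, the layout created above under args.dump_dir. This comment is
# illustrative and only restates the paths built in main(); <dataset_name>,
# <desc_name>, <expr>, and <run_name> stand for the corresponding args values:
#
#   dumps/<dataset_name>/<desc_name>/<expr>/
#       checkpoints/                         args.ckpt_dir
#       meta/                                args.meta_dir
#           <run_name>.json                  args.meta_file
#           <run_name>.log                   args.log_file
#           <run_name>.meter.json            args.meter_file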
def main_train(train_dataset, validation_dataset, extra_dataset=None):
    logger.critical('Building the model.')
    model = desc.make_model(args)

    if args.version == 'v3':
        desc_pred = load_source(args.pred_model_path)
        model.build_temporal_prediction_model(args, desc_pred)
    elif args.version == 'v4':
        desc_pred = load_source(args.pred_model_path)
        desc_spatial_pred = load_source(args.pred_spatial_model_path)
        model.build_temporal_prediction_model(args, desc_pred, desc_spatial_pred)
    elif args.version == 'v2_1':
        model.make_relation_embedding_for_unseen_events(args)

    if args.use_gpu:
        model.cuda()
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    if hasattr(desc, 'make_optimizer'):
        logger.critical('Building customized optimizer.')
        optimizer = desc.make_optimizer(model, args.lr)
    else:
        from jactorch.optim import AdamW

        if args.freeze_learner_flag == 1:
            if args.reconstruct_flag:
                parameters = list(model._model_pred.parameters()) + list(model._decoder.parameters())
                trainable_parameters = filter(lambda x: x.requires_grad, parameters)
            elif args.version == 'v4':
                trainable_parameters = filter(lambda x: x.requires_grad, model._model_pred.parameters())
        else:
            trainable_parameters = filter(lambda x: x.requires_grad, model.parameters())
        optimizer = AdamW(trainable_parameters, args.lr, weight_decay=configs.train.weight_decay)

    if args.acc_grad > 1:
        from jactorch.optim import AccumGrad
        optimizer = AccumGrad(optimizer, args.acc_grad)
        logger.warning('Use accumulated grad={:d}, effective iterations per epoch={:d}.'.format(
            args.acc_grad, int(args.iters_per_epoch / args.acc_grad)))

    trainer = TrainerEnv(model, optimizer)

    if args.resume:
        extra = trainer.load_checkpoint(args.resume)
        if extra:
            args.start_epoch = extra['epoch']
            logger.critical('Resume from epoch {}.'.format(args.start_epoch))
    elif args.load:
        if trainer.load_weights(args.load):
            logger.critical('Loaded weights from pretrained model: "{}".'.format(args.load))
        if args.version == 'v3':
            if args.pretrain_pred_model_path:
                model._model_pred.load_state_dict(torch.load(args.pretrain_pred_model_path))
                logger.critical('Loaded weights from pretrained temporal model: "{}".'.format(args.pretrain_pred_model_path))
        elif args.version == 'v4':
            if args.pretrain_pred_spatial_model_path:
                model._model_spatial_pred.load_state_dict(torch.load(args.pretrain_pred_spatial_model_path))
                logger.critical('Loaded spatial models from pretrained temporal model: "{}".'.format(args.pretrain_pred_spatial_model_path))
            if args.pretrain_pred_feature_model_path:
                model._model_pred.load_state_dict(torch.load(args.pretrain_pred_feature_model_path))
                logger.critical('Loaded feature models from pretrained temporal model: "{}".'.format(args.pretrain_pred_feature_model_path))
            if args.pretrain_pred_model_path:
                model._model_pred.load_state_dict(torch.load(args.pretrain_pred_model_path))
                logger.critical('Loaded weights from pretrained temporal model: "{}".'.format(args.pretrain_pred_model_path))
        elif args.version == 'v2_1':
            model.reasoning.embedding_relation_future.load_state_dict(model.reasoning.embedding_relation.state_dict())
            model.reasoning.embedding_relation_counterfact.load_state_dict(model.reasoning.embedding_relation.state_dict())
            logger.critical('Copied original relation weights into counterfact and future relation.')

    if args.use_tb and not args.debug:
        from jactorch.train.tb import TBLogger, TBGroupMeters

        tb_logger = TBLogger(args.tb_dir)
        meters = TBGroupMeters(tb_logger)
        logger.critical('Writing tensorboard logs to: "{}".'.format(args.tb_dir))
    else:
        from jacinle.utils.meter import GroupMeters
        meters = GroupMeters()
    if not args.debug:
        logger.critical('Writing meter logs to file: "{}".'.format(args.meter_file))

    if args.clip_grad:
        logger.info('Registering the clip_grad hook: {}.'.format(args.clip_grad))

        def clip_grad(self, loss):
            from torch.nn.utils import clip_grad_norm_
            clip_grad_norm_(self.model.parameters(), max_norm=args.clip_grad)

        trainer.register_event('backward:after', clip_grad)

    if hasattr(desc, 'customize_trainer'):
        desc.customize_trainer(trainer)

    if args.embed:
        from IPython import embed
        embed()

    if args.debug:
        shuffle_flag = False
    else:
        shuffle_flag = True

    logger.critical('Building the data loader.')
    validation_dataloader = validation_dataset.make_dataloader(args.batch_size, shuffle=False, drop_last=False, nr_workers=args.data_workers)
    if extra_dataset is not None:
        extra_dataloader = extra_dataset.make_dataloader(args.batch_size, shuffle=False, drop_last=False, nr_workers=args.data_workers)

    if args.evaluate:
        meters.reset()
        model.eval()
        validate_epoch(0, trainer, validation_dataloader, meters)
        if extra_dataset is not None:
            validate_epoch(0, trainer, extra_dataloader, meters, meter_prefix='validation_extra')
        logger.critical(meters.format_simple('Validation', {k: v for k, v in meters.avg.items() if v != 0}, compressed=False))
        return meters

    for epoch in range(args.start_epoch + 1, args.epochs + 1):
        meters.reset()

        model.train()

        this_train_dataset = train_dataset
        train_dataloader = this_train_dataset.make_dataloader(args.batch_size, shuffle=shuffle_flag, drop_last=True, nr_workers=args.data_workers)

        for enum_id in range(args.enums_per_epoch):
            train_epoch(epoch, trainer, train_dataloader, meters)

        if epoch % args.validation_interval == 0:
            model.eval()
            validate_epoch(epoch, trainer, validation_dataloader, meters)

        if not args.debug:
            meters.dump(args.meter_file)

        logger.critical(meters.format_simple(
            'Epoch = {}'.format(epoch),
            {k: v for k, v in meters.avg.items() if epoch % args.validation_interval == 0 or not k.startswith('validation')},
            compressed=False
        ))

        if epoch % args.save_interval == 0 and not args.debug:
            fname = osp.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch))
            trainer.save_checkpoint(fname, dict(epoch=epoch, meta_file=args.meta_file))

        if epoch > int(args.epochs * 0.6):
            trainer.set_learning_rate(args.lr * 0.1)
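# Illustration only (not part of the training script above): the 'backward:after'
# hook registered in main_train() clips the global gradient norm after each backward
# pass. A minimal plain-PyTorch sketch of the same idea, assuming a toy linear model
# and an SGD optimizer:
#
#   import torch
#   from torch.nn.utils import clip_grad_norm_
#
#   model = torch.nn.Linear(4, 2)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   x, y = torch.randn(8, 4), torch.randn(8, 2)
#
#   optimizer.zero_grad()
#   loss = torch.nn.functional.mse_loss(model(x), y)
#   loss.backward()
#   clip_grad_norm_(model.parameters(), max_norm=1.0)  # same call as in the clip_grad hook
#   optimizer.step()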