Example #1
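# The snippet below is an excerpt: `args`, `init_database`, `make_app`, the
# `register_*` helpers, and the `get_*` path helpers come from the surrounding
# project. It also relies on imports it does not show; a best guess:
#   import os.path as osp
#   import tornado.ioloop
#   from jacinle.logging import get_logger; logger = get_logger(__file__)
#   from jacinle.utils.imp import load_source
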
def main():
    init_database(args.logdir)

    # Look for an optional `jacmldash.py` config file in the current working directory.
    py_filename = osp.join('jacmldash.py')
    if osp.isfile(py_filename):
        logger.critical('Loading JacMLDash config: {}.'.format(
            osp.abspath(py_filename)))
        config = load_source(py_filename)
        if hasattr(config, 'ui_methods'):
            register_ui_methods(config.ui_methods)
        if hasattr(config, 'run_methods'):
            register_run_methods(config.run_methods)
        if hasattr(config, 'custom_pages'):
            register_custom_pages(config.custom_pages)

    if args.cli:
        from IPython import embed
        embed()
        return

    app = make_app(
        [
            'mldash.web.app.index', 'mldash.web.app.experiment',
            'mldash.plugins.tensorboard.handler',
            'mldash.plugins.trashbin.handler', 'mldash.plugins.star.handler',
            'mldash.plugins.viewer.handler'
        ], {
            'gzip': True,
            'debug': args.debug,
            'xsrf_cookies': True,
            'static_path': get_static_path(),
            'template_path': get_template_path(),
            'ui_methods': get_ui_methods(),
            "cookie_secret": "20f42d0ae6548e88cf9788e725b298bd",
            "session_secret":
            "3cdcb1f00803b6e78ab50b466a40b9977db396840c28307f428b25e2277f1bcc",
            "frontend_secret":
            "asdjikfh98234uf9pidwja09f9adsjp9fd6840c28307f428b25e2277f1bcc",
            "cookie_prefix": 'jac_',
            'session_engine': 'off',
        })

    app.listen(args.port, xheaders=True)

    logger.critical('Mainloop started. Port: {}.'.format(args.port))
    loop = tornado.ioloop.IOLoop.current()
    loop.start()
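
The `args` object above comes from the project's command-line parser, which this excerpt omits. A minimal sketch of a parser that would satisfy the attribute accesses in the snippet (flag names and defaults are assumptions, not the project's actual CLI):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--logdir', default='.', help='directory holding the experiment database')
parser.add_argument('--port', type=int, default=8888, help='HTTP port for the dashboard')
parser.add_argument('--cli', action='store_true', help='drop into an IPython shell instead of serving')
parser.add_argument('--debug', action='store_true', help='enable Tornado debug mode')
args = parser.parse_args()
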
Example #2
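# Excerpt: `args`, `escape_desc_name`, `ensure_path`, and `load_source` come
# from the surrounding project. It also relies on imports it does not show;
# a best guess:
#   import time
#   import os.path as osp
#   import torch.cuda as cuda
#   from jacinle.logging import get_logger; logger = get_logger(__file__)
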
# filenames
args.series_name = args.dataset
args.desc_name = escape_desc_name(args.desc)
args.run_name = 'run-{}'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))

# devices (the directories are set up in main() below)

if args.use_gpu:
    nr_devs = cuda.device_count()
    if args.force_gpu and nr_devs == 0:
        nr_devs = 1  # --force-gpu: assume one device even if detection fails
    assert nr_devs > 0, 'No GPU device available'
    args.gpus = list(range(nr_devs))
    args.gpu_parallel = (nr_devs > 1)

# Load the model description file and merge its configs with those given on the command line.
desc = load_source(args.desc)
configs = desc.configs
args.configs.apply(configs)


def main():
    args.dump_dir = ensure_path(osp.join('dumps', args.series_name, args.desc_name, args.expr))

    if not args.debug:
        args.ckpt_dir = ensure_path(osp.join(args.dump_dir, 'checkpoints'))
        args.meta_dir = ensure_path(osp.join(args.dump_dir, 'meta'))
        args.meta_file = osp.join(args.meta_dir, args.run_name + '.json')
        args.log_file = osp.join(args.meta_dir, args.run_name + '.log')
        args.meter_file = osp.join(args.meta_dir, args.run_name + '.meter.json')

        logger.critical('Writing logs to file: "{}".'.format(args.log_file))
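
As in the previous example, `args` is built elsewhere. A sketch of the flags this excerpt reads (names inferred from the attribute accesses; defaults are assumptions):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--desc', required=True, help='path to the model description file')
parser.add_argument('--dataset', required=True, help='dataset name; also used as the series name')
parser.add_argument('--expr', default='default', help='experiment name used in the dump path')
parser.add_argument('--use-gpu', action='store_true', help='train on GPU')
parser.add_argument('--force-gpu', action='store_true', help='assume one GPU even if none is detected')
parser.add_argument('--debug', action='store_true', help='skip writing checkpoints and log files')
args = parser.parse_args()
# NOTE: `args.configs` (used via args.configs.apply above) is a project-specific
# config object and is not reproduced in this sketch.
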
Example #3
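# Excerpt: `args`, `configs`, `desc`, `TrainerEnv`, `train_epoch`, and
# `validate_epoch` come from the surrounding project. It also relies on
# imports it does not show; a best guess:
#   import os.path as osp
#   import torch
#   import torch.backends.cudnn as cudnn
#   from jacinle.logging import get_logger; logger = get_logger(__file__)
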
def main_train(train_dataset, validation_dataset, extra_dataset=None):
    logger.critical('Building the model.')
    model = desc.make_model(args)
    if args.version == 'v3':
        desc_pred = load_source(args.pred_model_path)
        model.build_temporal_prediction_model(args, desc_pred)
    elif args.version == 'v4':
        desc_pred = load_source(args.pred_model_path)
        desc_spatial_pred = load_source(args.pred_spatial_model_path)
        model.build_temporal_prediction_model(args, desc_pred, desc_spatial_pred)
    elif args.version == 'v2_1':
        model.make_relation_embedding_for_unseen_events(args)

    if args.use_gpu:
        model.cuda()
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    if hasattr(desc, 'make_optimizer'):
        logger.critical('Building customized optimizer.')
        optimizer = desc.make_optimizer(model, args.lr)
    else:
        from jactorch.optim import AdamW
        if args.freeze_learner_flag == 1:
            if args.reconstruct_flag:
                parameters = list(model._model_pred.parameters()) + list(model._decoder.parameters())
                trainable_parameters = filter(lambda x: x.requires_grad, parameters)
            elif args.version == 'v4':
                trainable_parameters = filter(lambda x: x.requires_grad, model._model_pred.parameters())
            else:
                # The original code left `trainable_parameters` undefined here, which
                # would raise a NameError below; fail explicitly instead.
                raise ValueError('freeze_learner_flag=1 requires reconstruct_flag or version v4.')
        else:
            trainable_parameters = filter(lambda x: x.requires_grad, model.parameters())
        optimizer = AdamW(trainable_parameters, args.lr, weight_decay=configs.train.weight_decay)

    if args.acc_grad > 1:
        from jactorch.optim import AccumGrad
        optimizer = AccumGrad(optimizer, args.acc_grad)
        logger.warning('Using accumulated grad={:d}; effective iterations per epoch={:d}.'.format(args.acc_grad, int(args.iters_per_epoch / args.acc_grad)))

    trainer = TrainerEnv(model, optimizer)

    if args.resume:
        extra = trainer.load_checkpoint(args.resume)
        if extra:
            args.start_epoch = extra['epoch']
            logger.critical('Resume from epoch {}.'.format(args.start_epoch))
    elif args.load:
        if trainer.load_weights(args.load):
            logger.critical('Loaded weights from pretrained model: "{}".'.format(args.load))
        if args.version == 'v3':
            if args.pretrain_pred_model_path:
                model._model_pred.load_state_dict(torch.load(args.pretrain_pred_model_path))
                logger.critical('Loaded weights from pretrained temporal model: "{}".'.format(args.pretrain_pred_model_path))
        elif args.version == 'v4':
            if args.pretrain_pred_spatial_model_path:
                model._model_spatial_pred.load_state_dict(torch.load(args.pretrain_pred_spatial_model_path))
                logger.critical('Loaded weights from pretrained spatial model: "{}".'.format(args.pretrain_pred_spatial_model_path))
            if args.pretrain_pred_feature_model_path:
                model._model_pred.load_state_dict(torch.load(args.pretrain_pred_feature_model_path))
                logger.critical('Loaded weights from pretrained feature model: "{}".'.format(args.pretrain_pred_feature_model_path))
            if args.pretrain_pred_model_path:
                # NOTE: if both paths are given, this overwrites the feature weights loaded above.
                model._model_pred.load_state_dict(torch.load(args.pretrain_pred_model_path))
                logger.critical('Loaded weights from pretrained temporal model: "{}".'.format(args.pretrain_pred_model_path))
        elif args.version == 'v2_1':
            model.reasoning.embedding_relation_future.load_state_dict(model.reasoning.embedding_relation.state_dict())
            model.reasoning.embedding_relation_counterfact.load_state_dict(model.reasoning.embedding_relation.state_dict())
            logger.critical('Copied the original relation weights into the counterfactual and future relation embeddings.')
    if args.use_tb and not args.debug:
        from jactorch.train.tb import TBLogger, TBGroupMeters
        tb_logger = TBLogger(args.tb_dir)
        meters = TBGroupMeters(tb_logger)
        logger.critical('Writing tensorboard logs to: "{}".'.format(args.tb_dir))
    else:
        from jacinle.utils.meter import GroupMeters
        meters = GroupMeters()

    if not args.debug:
        logger.critical('Writing meter logs to file: "{}".'.format(args.meter_file))

    if args.clip_grad:
        logger.info('Registering the clip_grad hook: {}.'.format(args.clip_grad))
        def clip_grad(self, loss):
            from torch.nn.utils import clip_grad_norm_
            clip_grad_norm_(self.model.parameters(), max_norm=args.clip_grad)
        trainer.register_event('backward:after', clip_grad)

    if hasattr(desc, 'customize_trainer'):
        desc.customize_trainer(trainer)

    if args.embed:
        from IPython import embed; embed()

    # Shuffle the training data unless debugging (keeps batches deterministic).
    shuffle_flag = not args.debug

    logger.critical('Building the data loader.')
    validation_dataloader = validation_dataset.make_dataloader(args.batch_size, shuffle=False, drop_last=False, nr_workers=args.data_workers)
    if extra_dataset is not None:
        extra_dataloader = extra_dataset.make_dataloader(args.batch_size, shuffle=False, drop_last=False, nr_workers=args.data_workers)

    if args.evaluate:
        meters.reset()
        model.eval()
        validate_epoch(0, trainer, validation_dataloader, meters)
        if extra_dataset is not None:
            validate_epoch(0, trainer, extra_dataloader, meters, meter_prefix='validation_extra')
        logger.critical(meters.format_simple('Validation', {k: v for k, v in meters.avg.items() if v != 0}, compressed=False))
        return meters


    for epoch in range(args.start_epoch + 1, args.epochs + 1):
        meters.reset()

        model.train()

        this_train_dataset = train_dataset
        train_dataloader = this_train_dataset.make_dataloader(args.batch_size, shuffle=shuffle_flag, drop_last=True, nr_workers=args.data_workers)

        for _ in range(args.enums_per_epoch):
            train_epoch(epoch, trainer, train_dataloader, meters)

        if epoch % args.validation_interval == 0:
            model.eval()
            validate_epoch(epoch, trainer, validation_dataloader, meters)

        if not args.debug:
            meters.dump(args.meter_file)

        logger.critical(meters.format_simple(
            'Epoch = {}'.format(epoch),
            {k: v for k, v in meters.avg.items() if epoch % args.validation_interval == 0 or not k.startswith('validation')},
            compressed=False
        ))

        if epoch % args.save_interval == 0 and not args.debug:
            fname = osp.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch))
            trainer.save_checkpoint(fname, dict(epoch=epoch, meta_file=args.meta_file))

        # Simple step schedule: after 60% of the epochs, train at one tenth of the base learning rate.
        if epoch > int(args.epochs * 0.6):
            trainer.set_learning_rate(args.lr * 0.1)