# Train MONet on frame sequences generated from `env`.
# NOTE(review): this run relies on names defined outside this view (env, config,
# monet, training_config, run_path, generate_envs_data, setup_trainer,
# SequenceDictDataSet, MONetTrainer, generators, transformers, file_handler,
# tb_handler) — confirm they are in scope where this code runs.

# Length of each generated run; the sample budget is split into runs of this
# length (previously the literal 500 was duplicated in two places).
run_len = 500
source_loader = generators.FunctionLoader(
    generate_envs_data,
    {
        'env': env,
        'num_runs': config.DATA.num_samples // run_len,
        'run_len': run_len,
    },
)

# Named `data_transformers` rather than `transformers` so the `transformers`
# module is not shadowed. Rebinding that name would make it a local inside a
# function, so the module lookups below would raise UnboundLocalError. At
# module level it would break later `transformers.<...>` uses.
data_transformers = [
    transformers.TorchVisionTransformerComposition(
        config.DATA.transform, config.DATA.shape),
    transformers.TypeTransformer(config.EXPERIMENT.device),
]
# NOTE(review): the third argument (8) to SequenceDictDataSet is not explained
# here — presumably a sequence length; confirm against its definition.
data = SequenceDictDataSet(source_loader, data_transformers, 8)

trainer = setup_trainer(MONetTrainer, monet, training_config, data)

# Per-epoch checkpoints under <run_path>/checkpoints.
checkpointing = file_handler.EpochCheckpointHandler(
    os.path.join(run_path, 'checkpoints'))
trainer.register_handler(checkpointing)

# Per-epoch file logging of the 'imgs' entry under <run_path>/data.
regular_logging = file_handler.EpochFileHandler(
    os.path.join(run_path, 'data'), log_name_list=['imgs'])
trainer.register_handler(regular_logging)

# TensorBoard logging of these scalar entries every `log_every` steps.
tb_logging_list = [
    'average_elbo',
    'trans_lik',
    'log_z_f',
    'img_lik_forward',
    'elbo',
    'z_s',
    'img_lik_mean',
    'p_x_loss',
    'p_x_loss_mean',
]
tb_logger = tb_handler.NStepTbHandler(
    config.EXPERIMENT.log_every, run_path, 'logging',
    log_name_list=tb_logging_list)
trainer.register_handler(tb_logger)

if config.EXPERIMENT.model == 'm-stove':
    # Initialise the image model's background weights from the full dataset
    # held by the training dataloader.
    trainer.model.img_model.init_background_weights(
        trainer.train_dataloader.dataset.get_all())

trainer.check_ready()
trainer.train(
    config.TRAINING.epochs,
    train_only=True,
    pretrain=config.TRAINING.pretrain,
    visdom=False,
)
'run_len': 25000 }) data_transformers = [ transformers.TorchVisionTransformerComposition(config.DATA.transform, config.DATA.shape), transformers.TypeTransformer(config.EXPERIMENT.device) ] print('Loading data') data = BasicDataSet(source_loader, data_transformers) print('Setting up trainer') trainer = setup_trainer(MONetTrainer, monet, training_config, data) check_path = os.path.join(run_path, 'checkpoints_{}'.format(game)) if not os.path.exists(check_path): os.mkdir(check_path) checkpointing = file_handler.EpochCheckpointHandler(check_path) trainer.register_handler(checkpointing) log_path = os.path.join(run_path, 'logging_{}'.format(game)) if not os.path.exists(log_path): os.mkdir(log_path) tb_logger = tb_handler.NStepTbHandler( config.EXPERIMENT.log_every, run_path, 'logging_{}'.format(game), log_name_list=['loss', 'kl_loss', 'mask_loss', 'p_x_loss']) print('Running training') trainer.register_handler(tb_logger) # MONet init block # for w in trainer.model.parameters(): # std_init = 0.01