] print('Loading data') data = BasicDataSet(source_loader, data_transformers) print('Setting up trainer') trainer = setup_trainer(MONetTrainer, monet, training_config, data) check_path = os.path.join(run_path, 'checkpoints_{}'.format(game)) if not os.path.exists(check_path): os.mkdir(check_path) checkpointing = file_handler.EpochCheckpointHandler(check_path) trainer.register_handler(checkpointing) log_path = os.path.join(run_path, 'logging_{}'.format(game)) if not os.path.exists(log_path): os.mkdir(log_path) tb_logger = tb_handler.NStepTbHandler( config.EXPERIMENT.log_every, run_path, 'logging_{}'.format(game), log_name_list=['loss', 'kl_loss', 'mask_loss', 'p_x_loss']) print('Running training') trainer.register_handler(tb_logger) # MONet init block # for w in trainer.model.parameters(): # std_init = 0.01 # torch.nn.init.normal_(w, mean=0., std=std_init) trainer.model.module.init_background_weights( trainer.train_dataloader.dataset.get_all()) trainer.train(config.TRAINING.epochs, train_only=True, visdom=False) env.close()
transformers.TypeTransformer(config.EXPERIMENT.device) ] data = SequenceDictDataSet(source_loader, transformers, 8) trainer = setup_trainer(MONetTrainer, monet, training_config, data) checkpointing = file_handler.EpochCheckpointHandler(os.path.join(run_path, 'checkpoints')) trainer.register_handler(checkpointing) regular_logging = file_handler.EpochFileHandler(os.path.join(run_path, 'data'), log_name_list=['imgs']) trainer.register_handler(regular_logging) tb_logging_list = ['average_elbo', 'trans_lik', 'log_z_f', 'img_lik_forward', 'elbo', 'z_s', 'img_lik_mean', 'p_x_loss', 'p_x_loss_mean'] tb_logger = tb_handler.NStepTbHandler(config.EXPERIMENT.log_every, run_path, 'logging', log_name_list=tb_logging_list) trainer.register_handler(tb_logger) if config.EXPERIMENT.model == 'm-stove': trainer.model.img_model.init_background_weights(trainer.train_dataloader.dataset.get_all()) trainer.check_ready() trainer.train(config.TRAINING.epochs, train_only=True, pretrain=config.TRAINING.pretrain, visdom=False) if config.TRAINING.pretrain: monet.img_model.beta = config.MODULE.MONET.beta trainer = setup_trainer(MONetTrainer, monet, training_config, data) checkpointing = file_handler.EpochCheckpointHandler(os.path.join(run_path, 'checkpoints')) trainer.register_handler(checkpointing)
data_transformers = [ transformers.TorchVisionTransformerComposition(config.DATA.transform, config.DATA.shape), transformers.TypeTransformer(config.EXPERIMENT.device) ] print('Loading data') data = BasicDataSet(source_loader, data_transformers) trainer = setup_trainer(MONetTrainer, monet, training_config, data) checkpointing = file_handler.EpochCheckpointHandler( os.path.join(run_path, 'checkpoints')) trainer.register_handler(checkpointing) tb_logger = tb_handler.NStepTbHandler(config.EXPERIMENT.log_every, run_path, 'logging', log_name_list=[ 'loss', 'kl_loss', 'mask_loss', 'p_x_loss', 'mse', 'p_x_loss_mean' ]) trainer.register_handler(tb_logger) regular_logging = file_handler.NStepFileHandler( 150, os.path.join(run_path, 'data'), log_name_list=['reconstruction', 'masks']) trainer.register_handler(regular_logging) # MONet init block # for w in trainer.model.parameters(): # std_init = 0.01 # torch.nn.init.normal_(w, mean=0., std=std_init) trainer.model.init_background_weights(