예제 #1
0
# Environment rollouts are generated in chunks of run_len steps; num_runs is
# derived from config.DATA.num_samples so the total is about num_samples.
run_len = 500
# NOTE(review): floor division drops any remainder, and num_samples < run_len
# gives num_runs == 0 -- confirm that case is rejected upstream.
source_loader = generators.FunctionLoader(
    generate_envs_data,
    {'env': env,
     'num_runs': config.DATA.num_samples // run_len,
     'run_len': run_len})

# Bound to data_transformers rather than `transformers`: rebinding the name
# `transformers` to this list shadowed the `transformers` module, so any later
# `transformers.<Class>` lookup in this module raised AttributeError.
data_transformers = [
    transformers.TorchVisionTransformerComposition(config.DATA.transform, config.DATA.shape),
    transformers.TypeTransformer(config.EXPERIMENT.device),
]

# The trailing 8 is passed positionally to SequenceDictDataSet; presumably a
# sequence length -- TODO confirm against its signature.
data = SequenceDictDataSet(source_loader, data_transformers, 8)

trainer = setup_trainer(MONetTrainer, monet, training_config, data)

# Checkpoint handler writing under <run_path>/checkpoints.
checkpointing = file_handler.EpochCheckpointHandler(os.path.join(run_path, 'checkpoints'))
trainer.register_handler(checkpointing)

# File handler for the 'imgs' log entry, writing under <run_path>/data.
regular_logging = file_handler.EpochFileHandler(os.path.join(run_path, 'data'), log_name_list=['imgs'])
trainer.register_handler(regular_logging)

# Quantities forwarded to TensorBoard (presumably every
# config.EXPERIMENT.log_every steps) under <run_path>/logging.
tb_logging_list = ['average_elbo', 'trans_lik', 'log_z_f', 'img_lik_forward', 'elbo', 'z_s', 'img_lik_mean', 'p_x_loss', 'p_x_loss_mean']
tb_logger = tb_handler.NStepTbHandler(config.EXPERIMENT.log_every, run_path, 'logging', log_name_list=tb_logging_list)
trainer.register_handler(tb_logger)

# For the 'm-stove' model, initialise the image model's background weights
# from all training data before training starts.
if config.EXPERIMENT.model == 'm-stove':
    trainer.model.img_model.init_background_weights(trainer.train_dataloader.dataset.get_all())

trainer.check_ready()
trainer.train(config.TRAINING.epochs, train_only=True, pretrain=config.TRAINING.pretrain, visdom=False)
            'run_len': 25000
        })

    # TorchVision transform composition built from config.DATA.transform and
    # config.DATA.shape, followed by a TypeTransformer for
    # config.EXPERIMENT.device.
    data_transformers = [
        transformers.TorchVisionTransformerComposition(config.DATA.transform,
                                                       config.DATA.shape),
        transformers.TypeTransformer(config.EXPERIMENT.device)
    ]
    print('Loading data')
    data = BasicDataSet(source_loader, data_transformers)
    print('Setting up trainer')
    trainer = setup_trainer(MONetTrainer, monet, training_config, data)

    # Per-game checkpoint directory. os.makedirs(exist_ok=True) replaces the
    # exists()/mkdir() pair: there is no window for another process to create
    # the directory in between, and a missing run_path is created too.
    check_path = os.path.join(run_path, 'checkpoints_{}'.format(game))
    os.makedirs(check_path, exist_ok=True)
    checkpointing = file_handler.EpochCheckpointHandler(check_path)
    trainer.register_handler(checkpointing)

    # Per-game TensorBoard directory. The handler receives run_path and the
    # sub-directory name separately, so the name is formatted once and reused.
    log_name = 'logging_{}'.format(game)
    log_path = os.path.join(run_path, log_name)
    os.makedirs(log_path, exist_ok=True)
    tb_logger = tb_handler.NStepTbHandler(
        config.EXPERIMENT.log_every,
        run_path,
        log_name,
        log_name_list=['loss', 'kl_loss', 'mask_loss', 'p_x_loss'])
    print('Running training')
    trainer.register_handler(tb_logger)

    # MONet init block
    # for w in trainer.model.parameters():
    #     std_init = 0.01