Exemplo n.º 1
0
    def test_save_sup_load_rl(self):
        """Save a checkpoint from one model and load it into another.

        A MockModel built on ``spinn.spinn_core_model.BaseModel`` is saved
        through its ModelTrainer, and the resulting checkpoint is loaded
        (with ``cpu=True``) by the trainer of a MockModel built on
        ``spinn.rl_spinn.BaseModel``. Both models are then handed to
        ``compare_models``.
        """
        model_to_save = MockModel(spinn.spinn_core_model.BaseModel,
                                  default_args())

        # Parse command line flags.
        get_flags()
        FLAGS(sys.argv)

        # The context managers close (and thereby remove) both temporary
        # files even when save/load/compare raises, so a failing test does
        # not leave them behind. Exit order is checkpoint first, then log.
        with tempfile.NamedTemporaryFile() as log_temp, \
                tempfile.NamedTemporaryFile() as ckpt_temp:
            logger = afs_safe_logger.ProtoLogger(log_temp.name)
            FLAGS.ckpt_path = '.'

            trainer_to_save = ModelTrainer(model_to_save, logger, FLAGS)

            model_to_load = MockModel(spinn.rl_spinn.BaseModel,
                                      default_args())
            trainer_to_load = ModelTrainer(model_to_load, logger, FLAGS)

            # Save to and load from temporary file.
            trainer_to_save.save(ckpt_temp.name)
            trainer_to_load.load(ckpt_temp.name, cpu=True)

            compare_models(model_to_save, model_to_load)
Exemplo n.º 2
0
        for index, eval_set in enumerate(eval_iterators):
            log_entry.Clear()
            acc = evaluate(FLAGS,
                           model,
                           eval_set,
                           log_entry,
                           logger,
                           trainer,
                           vocabulary,
                           show_sample=True,
                           eval_index=index)
            print(log_entry)
            logger.LogEntry(log_entry)
    else:
        train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
                   logger)


if __name__ == '__main__':
    # Script entry point: call get_flags(), then parse the command line
    # into FLAGS.
    get_flags()
    FLAGS(sys.argv)

    # Pass the parsed flags through flag_defaults before they are used.
    flag_defaults(FLAGS)

    # Only the RLSPINN model type is accepted here; anything else aborts.
    if FLAGS.model_type == "RLSPINN":
        run(only_forward=FLAGS.expanded_eval_only_mode)
    else:
        raise Exception("Reinforce is only implemented for RLSPINN.")