def game5():
    """Run a 10-situation / 20-function game and plot its message information.

    Builds a ``Game`` with situation, message, prediction, func and hidden
    sizes of 10, 2, 10, 20 and 64, runs ``game.play()``, and then calls
    ``plot_messages_information`` with an extra argument of 40.
    """
    situation_size = 10
    message_size = 2
    prediction_size = 10
    func_size = 20
    hidden_size = 64

    game = Game(
        situation_size,
        message_size,
        prediction_size,
        func_size,
        hidden_size,
        # NOTE(review): the meaning of this extra positional argument is not
        # visible here; confirm against Game's constructor signature.
        1.2,
    )
    game.play()
    plot_messages_information(game, 40)
def game3b():
    """Run a small 2-prediction / 4-function game and analyse its messages.

    Builds a ``Game`` with situation, message, prediction, func and hidden
    sizes of 10, 2, 2, 4 and 64, and runs ``game.play()``. It then plots
    message information, tries to predict the information from the messages,
    and clusters the messages.
    """
    situation_size = 10
    message_size = 2
    prediction_size = 2
    func_size = 4
    hidden_size = 64

    game = Game(
        situation_size,
        message_size,
        prediction_size,
        func_size,
        hidden_size,
    )
    game.play()

    # Post-hoc analyses of the messages produced by the played game.
    plot_messages_information(game)
    predict_information_from_messages(game)
    clusterize_messages(game)
def game3():
    """Run a context-enabled game with two hidden layers and analyse its messages.

    Builds a ``Game`` with ``use_context=True`` and runs ``game.play()``. It
    then plots message information (extra argument 40), tries to predict the
    information from the messages, and clusters the messages.
    """
    situation_size = 10
    information_size = 4
    message_size = 2
    prediction_size = 2
    hidden_sizes = (64, 64)

    # NOTE(review): this call's positional layout (information_size, then a
    # tuple of hidden sizes, plus use_context) differs from the other game*()
    # functions, which pass func_size and a single hidden_size. Confirm it
    # still matches Game's current signature.
    game = Game(
        situation_size,
        information_size,
        message_size,
        prediction_size,
        hidden_sizes,
        use_context=True,
    )
    game.play()

    # Post-hoc analyses of the messages produced by the played game.
    plot_messages_information(game, 40)
    predict_information_from_messages(game)
    clusterize_messages(game)
def game7():
    """Train a game over a decreasing learning-rate schedule and report its loss.

    Steps:
    1. Build a ``Game`` with situation, message, prediction, func and hidden
       sizes of 10, 2, 10, 4 and 64.
    2. Call ``play_game`` for 1000 (presumably epochs) at each learning rate
       0.01, 0.001 and 0.0001.
    3. Log the first recorded epoch loss once, and the latest epoch loss
       after every run.
    4. Plot message information and call ``game.average_messages(100)``.
    5. Log the criterion loss of ``game.discrete_forward`` against
       ``game._target`` on freshly sampled random situations.
    """
    situation_size, message_size, prediction_size, func_size, hidden_size = (
        10,
        2,
        10,
        4,
        64,
    )
    game = Game(situation_size, message_size, prediction_size, func_size, hidden_size)

    for run_index, lr in enumerate([0.01, 0.001, 0.0001]):
        play_game(game, 1000, learning_rate=lr)
        if run_index == 0:
            # Log the very first recorded epoch once, right after the first run.
            logging.info(
                f"Epoch {game.loss_per_epoch[0][0]}:\t{game.loss_per_epoch[0][1]:.2e}"
            )
        logging.info(
            f"Epoch {game.loss_per_epoch[-1][0]}:\t{game.loss_per_epoch[-1][1]:.2e}"
        )

    plot_messages_information(game, 40)

    # Compute the average messages
    game.average_messages(100)

    replications_per_func = 10
    situations = torch.randn(
        replications_per_func * game.func_size, game.situation_size
    )
    # Produces [0, 1, ..., func_size - 1] tiled replications_per_func times,
    # giving one switch per row of `situations`. This is identical to
    # concatenating replications_per_func copies of torch.arange(func_size).
    func_switches = torch.arange(game.func_size).repeat(replications_per_func)

    # Pure evaluation: no gradients are needed, so skip building an autograd
    # graph. The computed loss value is unchanged.
    with torch.no_grad():
        targets = game._target(situations, func_switches)
        loss = game.criterion(
            game.discrete_forward(situations, func_switches), targets
        ).item()
    logging.info(f"Loss: {loss:.2e}")