def policy_and_value_net(n_actions, n_controls, vocab_size, bottom_layers_fn,
                         two_towers):
  """A policy and value net function."""

  # Layers.
  # With the current logits, one head computes action probabilities and the
  # other computes the value function.
  # NOTE: LogSoftmax is used instead of Softmax for numerical stability.

  @tl.layer()
  def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
    """Splits logits for actions in different controls and flattens controls."""
    return np.reshape(x, (x.shape[0], -1, n_actions))

  if vocab_size is None:
    # In continuous policies every element of the output sequence corresponds
    # to an observation.
    n_preds_per_input = n_controls
    kwargs = {}
  else:
    # In discrete policies every element of the output sequence corresponds to
    # a symbol in the discrete representation, and each control takes 1 symbol.
    n_preds_per_input = 1
    kwargs = {"vocab_size": vocab_size}

  if two_towers:
    layers = [
        tl.Dup(),
        tl.Parallel(
            [bottom_layers_fn(**kwargs),
             tl.Dense(n_preds_per_input * n_actions),
             FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
             tl.LogSoftmax()],
            [bottom_layers_fn(**kwargs),
             tl.Dense(n_preds_per_input),
             tl.Flatten()],
        )
    ]
  else:
    layers = [
        bottom_layers_fn(**kwargs),
        tl.Dup(),
        tl.Parallel(
            [tl.Dense(n_preds_per_input * n_actions),
             FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
             tl.LogSoftmax()],
            [tl.Dense(n_preds_per_input), tl.Flatten()],
        )
    ]
  return tl.Model(layers)
def WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10,
               bn_momentum=0.9, mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group. Total layers = 6n + 4.
    widen_factor: int, widening factor of each group. k=1 is vanilla resnet.
    n_output_classes: int, number of distinct output classes.
    bn_momentum: float, momentum in BatchNorm.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
  return tl.Serial(
      tl.ToFloat(),
      tl.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(n_blocks, 16 * widen_factor, bn_momentum=bn_momentum,
                      mode=mode),
      WideResnetGroup(n_blocks, 32 * widen_factor, (2, 2),
                      bn_momentum=bn_momentum, mode=mode),
      WideResnetGroup(n_blocks, 64 * widen_factor, (2, 2),
                      bn_momentum=bn_momentum, mode=mode),
      tl.BatchNorm(momentum=bn_momentum, mode=mode),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
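# Illustrative usage sketch for WideResnet, assuming `trax.layers` is imported
# as `tl` and that WideResnetGroup (defined elsewhere in the same module) is
# in scope. The CIFAR-10-style 32x32x3 input and batch size of 8 are
# assumptions chosen to match the final (8, 8) AvgPool.
import numpy as np
from trax import shapes

wrn = WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10)
images = np.zeros((8, 32, 32, 3), dtype=np.float32)
wrn.init(shapes.signature(images))
log_probs = wrn(images)  # Expected shape: (8, 10).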
def AtariCnnBody(n_frames=4, hidden_sizes=(32, 64, 64), output_size=512,
                 mode='train', kernel_initializer=None, padding='VALID'):
  """An Atari CNN."""
  del mode

  # TODO(jonni): Include link to paper?
  # Input shape: (B, T, H, W, C)
  # Output shape: (B, T, output_size)
  return tl.Serial(
      _BytesToFloats(),
      _FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
      tl.Conv(hidden_sizes[0], (8, 8), (4, 4), padding=padding,
              kernel_initializer=kernel_initializer),
      tl.Relu(),
      tl.Conv(hidden_sizes[1], (4, 4), (2, 2), padding=padding,
              kernel_initializer=kernel_initializer),
      tl.Relu(),
      tl.Conv(hidden_sizes[2], (3, 3), (1, 1), padding=padding,
              kernel_initializer=kernel_initializer),
      tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),
  )
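# A rough shape walkthrough for AtariCnnBody with its defaults and VALID
# padding. The 84x84 frame size is an assumption (standard preprocessed Atari
# frames), not stated above:
#   input          (B, T, 84, 84, C)
#   _FrameStack    (B, T, 84, 84, 4*C)
#   Conv 8x8 /4    (B, T, 20, 20, 32)    # (84 - 8)/4 + 1 = 20
#   Conv 4x4 /2    (B, T,  9,  9, 64)    # (20 - 4)/2 + 1 = 9
#   Conv 3x3 /1    (B, T,  7,  7, 64)    # (9 - 3)/1 + 1 = 7
#   Flatten(2)     (B, T, 3136)          # 7 * 7 * 64
#   Dense + Relu   (B, T, 512)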
def test_train_mnist(self):
  """Train MNIST model (almost) fully, to compare to other implementations.

  Evals for cross-entropy loss and accuracy are run every 50 steps; their
  values are visible in the test log.
  """
  gin.parse_config([
      'batch_fn.batch_size_per_device = 256',
      'batch_fn.eval_batch_size = 256',
  ])
  mnist_model = tl.Serial(
      tl.Flatten(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(10),
      tl.LogSoftmax(),
  )
  task = training.TrainTask(
      itertools.cycle(_mnist_dataset().train_stream(1)),
      tl.CrossEntropyLoss(),
      adafactor.Adafactor(.02))
  eval_task = training.EvalTask(
      itertools.cycle(_mnist_dataset().eval_stream(1)),
      [tl.CrossEntropyLoss(), tl.AccuracyScalar()],
      names=['CrossEntropyLoss', 'AccuracyScalar'],
      eval_at=lambda step_n: step_n % 50 == 0,
      eval_N=10)

  training_session = training.Loop(mnist_model, task, eval_task=eval_task)
  training_session.run(n_steps=1000)
  self.assertEqual(training_session.current_step(), 1000)
def test_policy_and_value_net(self):
  observation_shape = (3, 4, 5)
  n_actions = 2
  n_controls = 3
  batch = 2
  time_steps = 10
  observations = np.random.uniform(
      size=(batch, time_steps) + observation_shape)
  actions = np.random.randint(
      n_actions, size=(batch, time_steps - 1, n_controls))
  (pnv_model, _) = policy_based_utils.policy_and_value_net(
      bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
      observation_space=gym.spaces.Box(
          shape=observation_shape, low=0, high=1),
      action_space=gym.spaces.MultiDiscrete((n_actions,) * n_controls),
      vocab_size=None,
      two_towers=True,
  )
  input_signature = shapes.signature((observations, actions))
  _, _ = pnv_model.init(input_signature)

  (action_logits, values) = pnv_model((observations, actions))

  # Output is a pair: action log-probabilities first, then value predictions.
  self.assertEqual((batch, time_steps, n_controls, n_actions),
                   action_logits.shape)
  self.assertEqual((batch, time_steps), values.shape)
def test_policy_and_value_net(self):
  observation_shape = (3, 4, 5)
  batch_observation_shape = (1, 1) + observation_shape
  n_actions = 2
  n_controls = 3
  pnv_model = ppo.policy_and_value_net(
      n_controls=n_controls,
      n_actions=n_actions,
      vocab_size=None,
      bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
      two_towers=True,
  )
  input_signature = ShapeDtype(batch_observation_shape)
  _, _ = pnv_model.init(input_signature)

  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(
      size=(batch, time_steps) + observation_shape)
  pnv_output = pnv_model(batch_of_observations)

  # Output is a pair: action log-probabilities first, then value predictions.
  self.assertEqual(2, len(pnv_output))
  self.assertEqual(
      (batch, time_steps * n_controls, n_actions), pnv_output[0].shape)
  self.assertEqual((batch, time_steps * n_controls), pnv_output[1].shape)
def test_train_mnist(self):
  """Train MNIST model (almost) fully, to compare to other implementations.

  Evals for cross-entropy loss and accuracy are run every 50 steps; their
  values are visible in the test log.
  """
  mnist_model = tl.Serial(
      tl.Flatten(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(10),
      tl.LogSoftmax(),
  )
  task = training.TrainTask(
      itertools.cycle(_mnist_dataset().train_stream(1)),
      tl.CrossEntropyLoss(),
      adafactor.Adafactor(.02))
  eval_task = training.EvalTask(
      itertools.cycle(_mnist_dataset().eval_stream(1)),
      [tl.CrossEntropyLoss(), tl.Accuracy()],
      n_eval_batches=10)

  training_session = training.Loop(
      mnist_model,
      [task],
      eval_tasks=[eval_task],
      eval_at=lambda step_n: step_n % 50 == 0)
  training_session.run(n_steps=1000)
  self.assertEqual(training_session.step, 1000)
def AtariCnnBody(n_frames=4, hidden_sizes=(32, 64, 64), output_size=512,
                 mode='train', kernel_initializer=None):
  """An Atari CNN."""
  del mode

  # TODO(jonni): Include link to paper?
  # Input shape: (B, T, H, W, C)
  # Output shape: (B, T, output_size)
  return tl.Serial(
      tl.Fn(lambda x: x / 255.0),  # Convert unsigned bytes to float.
      _FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
      tl.Conv(hidden_sizes[0], (8, 8), (4, 4), padding='SAME',
              kernel_initializer=kernel_initializer),
      tl.Relu(),
      tl.Conv(hidden_sizes[1], (4, 4), (2, 2), 'SAME',
              kernel_initializer=kernel_initializer),
      tl.Relu(),
      tl.Conv(hidden_sizes[2], (3, 3), (1, 1), 'SAME',
              kernel_initializer=kernel_initializer),
      tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),
  )
def test_two_outputs_pass(self):
  layer = tl.AssertFunction(
      '...cd->...x,...cd',
      tl.Branch(
          tl.Flatten(n_axes_to_keep=2),
          tl.Dropout(rate=0.1),
      ))
  x = np.ones((1, 2, 3, 4))
  layer(x)
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train',
             norm=tl.BatchNorm, non_linearity=tl.Relu):
  """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.
    norm: `Layer` used for normalization, e.g. BatchNorm or FilterResponseNorm.
    non_linearity: `Layer` used as a non-linearity; if norm is BatchNorm then
      this is a Relu, otherwise for FilterResponseNorm this should be
      ThresholdedLinearUnit.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """

  # A ConvBlock configured with the given norm, non-linearity and mode.
  def Resnet50ConvBlock(filter_multiplier=1, strides=(2, 2)):
    filters = [
        filter_multiplier * dim for dim in [d_hidden, d_hidden, 4 * d_hidden]
    ]
    return ConvBlock(3, filters, strides, norm, non_linearity, mode)

  # Same as above for IdentityBlock.
  def Resnet50IdentityBlock(filter_multiplier=1):
    filters = [
        filter_multiplier * dim for dim in [d_hidden, d_hidden, 4 * d_hidden]
    ]
    return IdentityBlock(3, filters, norm, non_linearity, mode)

  return tl.Serial(
      tl.ToFloat(),
      tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
      norm(mode=mode),
      non_linearity(),
      tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
      Resnet50ConvBlock(strides=(1, 1)),
      [Resnet50IdentityBlock() for _ in range(2)],
      Resnet50ConvBlock(2),
      [Resnet50IdentityBlock(2) for _ in range(3)],
      Resnet50ConvBlock(4),
      [Resnet50IdentityBlock(4) for _ in range(5)],
      Resnet50ConvBlock(8),
      [Resnet50IdentityBlock(8) for _ in range(2)],
      tl.AvgPool(pool_size=(7, 7)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
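# Illustrative usage sketch for the configurable Resnet50 above, assuming
# `trax.layers` is imported as `tl` and that ConvBlock/IdentityBlock (defined
# elsewhere in the same module) are in scope. The 224x224x3 ImageNet-style
# input and batch size of 2 are assumptions chosen to match the final (7, 7)
# AvgPool.
import numpy as np
from trax import shapes

resnet = Resnet50(d_hidden=64, n_output_classes=1001, mode='eval')
images = np.zeros((2, 224, 224, 3), dtype=np.float32)
resnet.init(shapes.signature(images))
log_probs = resnet(images)  # Expected shape: (2, 1001).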
def get_model(num_classes):
  """Returns an MLP classifier with two 512-unit hidden layers."""
  return tl.Serial(
      tl.Flatten(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(num_classes),
      tl.LogSoftmax(),
  )
def test_multi_output_rank_fail(self):
  layer = tl.AssertFunction(
      '...34->...x,...y',
      tl.Branch(
          tl.Flatten(n_axes_to_keep=3),
          tl.Serial(),
      ))
  x = np.ones((1, 2, 3, 4))
  with self.assertRaises(tl.LayerError):
    layer(x)
def test_too_many_outputs_fail(self):
  layer = tl.AssertFunction(
      '...cd->...x,...cd,...cd,...cd',
      tl.Branch(
          tl.Flatten(n_axes_to_keep=2),
          tl.Dropout(rate=0.1),
          tl.Serial(),
      ))
  x = np.ones((1, 2, 3, 4))
  with self.assertRaises(tl.LayerError):
    layer(x)
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train'):
  """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """
  return tl.Model(
      tl.ToFloat(),
      tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
      ConvBlock(3, [d_hidden, d_hidden, 4 * d_hidden], (1, 1), mode=mode),
      IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
      IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
      ConvBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], (2, 2),
                mode=mode),
      IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], mode=mode),
      IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], mode=mode),
      IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], mode=mode),
      ConvBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], (2, 2),
                mode=mode),
      IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode),
      IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode),
      IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode),
      IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode),
      IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode),
      ConvBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], (2, 2),
                mode=mode),
      IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], mode=mode),
      IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], mode=mode),
      tl.AvgPool(pool_size=(7, 7)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def RecommenderTransformer(n_classes_in, embedding_size, n_out_classes,
                           dropout_rate):
  """A small self-attention recommender classifier."""
  transformer = tl.Serial(
      tl.Embedding(n_classes_in, d_feature=embedding_size),
      tl.Dropout(dropout_rate),
      tl.SelfAttention(2),
      tl.Flatten(),
      tl.Dropout(dropout_rate),
      # tl.DotProductCausalAttention(4),
      tl.Dense(n_out_classes),
      tl.LogSoftmax())
  print(str(transformer))
  return transformer
def MLP(n_hidden_layers=2, d_hidden=512, activation_fn=tl.Relu,
        n_output_classes=10, mode="train"):
  """A multi-layer feedforward (perceptron) network."""
  del mode

  return tl.Model(
      tl.Flatten(),
      [[tl.Dense(d_hidden), activation_fn()]
       for _ in range(n_hidden_layers)],
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def model(mode):
  del mode
  return layers.Serial(
      layers.Parallel(
          layers.Flatten(),  # Observation stack.
          layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
      ),
      layers.Concatenate(),
      layers.Dense(n_units=1),
      layers.Dup(),
      layers.Parallel(
          layers.Dense(n_units=obs_shape[1]),  # New observation.
          None,  # Reward.
      ))
def PureMLP(
    layer_widths=(128, 64), activation_fn=tl.Relu, out_activation=False,
    flatten=True, mode='train'):
  """A "multilayer perceptron" (MLP) network.

  This is a classic fully connected feedforward network, with one or more
  layers and a (nonlinear) activation function between each layer. For
  historical reasons, such networks are often called multilayer perceptrons;
  but they are more accurately described as multilayer networks, where each
  layer + activation function is a perceptron-like unit (see, e.g.,
  [https://en.wikipedia.org/wiki/Multilayer_perceptron#Terminology]).

  Args:
    layer_widths: Tuple of ints telling the number of layers and the width of
      each layer. For example, setting `layer_widths=(128, 64, 32)` would
      yield 3 layers with successive widths of 128, 64, and 32.
    activation_fn: Layer that computes a nonlinear activation between pairs
      of fully connected layers. An activation function typically acts
      elementwise, and its output has the same shape and dtype as its input.
    out_activation: If True, include a copy of the activation function as the
      last layer in the network.
    flatten: If True, insert a layer at the head of the network to flatten
      the input tensor into a matrix of shape (batch_size, -1).
    mode: Ignored.

  Returns:
    An assembled MLP network with the specified layers. This network can
    either be initialized and trained as a full model, or can be used as a
    building block in a larger network.
  """
  del mode

  layers = []
  for width in layer_widths:
    layers.append(tl.Dense(width))
    layers.append(activation_fn())

  if not out_activation:
    # Don't need the last activation.
    layers.pop()

  return tl.Serial(
      [tl.Flatten()] if flatten else [],
      layers,
  )
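# Illustrative usage sketch for PureMLP, assuming `trax.layers` is imported as
# `tl` as above. The 28x28 input size and batch size of 32 are arbitrary
# choices for the example.
import numpy as np
from trax import shapes

mlp = PureMLP(layer_widths=(128, 64, 10), flatten=True)
x = np.zeros((32, 28, 28), dtype=np.float32)
mlp.init(shapes.signature(x))
y = mlp(x)  # Shape (32, 10); no trailing activation, since out_activation
            # defaults to False and the last activation_fn() is popped.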
def _build_model(two_heads):
  cls_head = tl.Serial(tl.Dense(10), tl.LogSoftmax())
  if two_heads:
    reg_head = tl.Dense(1)
    heads = tl.Branch(cls_head, reg_head)
  else:
    heads = cls_head
  return tl.Serial(
      tl.Fn('ScaleInput', lambda x: x / 255),
      tl.Flatten(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(512),
      tl.Relu(),
      heads,
  )
def PureMLP(
    hidden_dims=(128, 64), activation_fn=tl.Relu, flatten=True, mode='train'):
  """A multi-layer feedforward (perceptron) network."""
  del mode

  layers = []
  for hidden_dim in hidden_dims:
    layers.append(tl.Dense(hidden_dim))
    layers.append(activation_fn())

  # Don't need the last activation.
  layers.pop()

  return tl.Serial(
      [tl.Flatten()] if flatten else [],
      layers,
  )
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
  """An Atari CNN."""
  del mode

  # TODO(jonni): Include link to paper?
  # Input shape: (B, T, H, W, C)
  # Output shape: (B, T, output_size)
  return tl.Serial(
      tl.Fn(lambda x: x / 255.0),  # Convert unsigned bytes to float.
      _FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
      tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),
  )
def MLP(d_hidden=512, n_hidden_layers=2, activation_fn=tl.Relu,
        n_output_classes=10, mode='train'):
  """A multi-layer feedforward (perceptron) network."""
  del mode

  # Define a function rather than a variable, so that multiple copies will
  # each be their own object with their own weights.
  def DensePlusActivation():
    return [tl.Dense(d_hidden), activation_fn()]

  return tl.Serial(
      tl.Flatten(),
      [DensePlusActivation() for _ in range(n_hidden_layers)],
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def RawPolicy(seq_model, n_controls, n_actions):
  """Wraps a sequence model in a policy interface.

  The resulting model takes as input observation and action sequences, but
  only uses the observations. Adds output heads for action logits and value
  predictions.

  Args:
    seq_model: Trax sequence model taking as input and outputting a sequence
      of continuous vectors.
    n_controls: Number of controls.
    n_actions: Number of action categories in each control.

  Returns:
    A model of signature (obs, act) -> (act_logits, values), with shapes:
      obs: (batch_size, length + 1, obs_depth)
      act: (batch_size, length, n_controls)
      act_logits: (batch_size, length, n_controls, n_actions)
      values: (batch_size, length)
  """
  @tl.layer()
  def SplitControls(x, **unused_kwargs):  # pylint: disable=invalid-name
    """Splits logits for actions in different controls."""
    return np.reshape(x, x.shape[:2] + (n_controls, n_actions))

  action_head = [
      # Predict all action logits at the same time.
      tl.Dense(n_controls * n_actions),
      # Then group them into separate controls, adding a new dimension.
      SplitControls(),  # pylint: disable=no-value-for-parameter
      # Needed because there is one fewer action than observations.
      DropLastTimestep(),  # pylint: disable=no-value-for-parameter
      tl.LogSoftmax(),
  ]
  return tl.Serial([            # (obs, act)
      tl.Select([0], n_in=2),   # (obs,)
      seq_model,                # (obs_hidden,)
      tl.Dup(),                 # (obs_hidden, obs_hidden)
      tl.Parallel(
          action_head,
          [tl.Dense(1), tl.Flatten()],
      )                         # (act_logits, values)
  ])
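# An illustrative shape trace of the action head above, assuming batch_size=2,
# length=10, n_controls=3, n_actions=4, and a seq_model that maps observations
# of shape (2, 11, obs_depth) to hidden vectors of shape (2, 11, d_model):
#   tl.Dense(n_controls * n_actions)  -> (2, 11, 12)
#   SplitControls()                   -> (2, 11, 3, 4)
#   DropLastTimestep()                -> (2, 10, 3, 4)
#   tl.LogSoftmax()                   -> (2, 10, 3, 4), normalized over the
#                                        last (actions) axis.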
def get_model(n_output_classes=10):
  """Simple CNN to classify Fashion MNIST."""
  model = tl.Serial(
      tl.ToFloat(),
      tl.Conv(32, (3, 3), (1, 1), "SAME"),
      tl.LayerNorm(),
      tl.Relu(),
      tl.MaxPool(),
      tl.Conv(64, (3, 3), (1, 1), "SAME"),
      tl.LayerNorm(),
      tl.Relu(),
      tl.MaxPool(),
      tl.Flatten(),
      tl.Dense(n_output_classes),
  )
  return model
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
  """An Atari CNN."""
  del mode

  # TODO(jonni): Include link to paper?
  # Input shape: (B, T, H, W, C)
  # Output shape: (B, T, output_size)
  return tl.Model(
      tl.ToFloat(),
      tl.Div(divisor=255.0),
      # Set up n_frames successive game frames, concatenated on the last axis.
      FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
      tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),
  )
def test_reduce_rank_explicit_fail2(self):
  layer = tl.AssertFunction('abcde->abcd', tl.Flatten(n_axes_to_keep=3))
  x = np.ones((1, 2, 3, 4, 5))
  with self.assertRaises(tl.LayerError):
    layer(x)
def test_reduce_rank_ellipsis_pass(self):
  layer = tl.AssertFunction('...ab->...c', tl.Flatten(n_axes_to_keep=3))
  x = np.ones((1, 2, 3, 4, 5))
  layer(x)
def test_reduce_rank_explicit_pass(self):
  layer = tl.AssertFunction('xyzab->xyzc', tl.Flatten(n_axes_to_keep=3))
  x = np.ones((1, 2, 3, 4, 5))
  layer(x)
def test_reduce_rank_to_one_pass(self):
  layer = tl.AssertFunction('abcde->x', tl.Flatten(n_axes_to_keep=0))
  x = np.ones((1, 2, 3, 4, 5))
  layer(x)
def test_combined_loss(self):
  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (1, 1) + OBS

  net = ppo.policy_and_value_net(
      n_controls=1,
      n_actions=A,
      vocab_size=None,
      bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
      two_towers=True,
  )

  input_signature = ShapeDtype(batch_observation_shape)
  old_params, _ = net.init(input_signature)
  new_params, state = net.init(input_signature)

  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T + 1))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  (new_log_probabs, value_predictions_new) = (
      net(observations, weights=new_params, state=state))
  (old_log_probabs, value_predictions_old) = (
      net(observations, weights=old_params, state=state))

  gamma = 0.99
  lambda_ = 0.95
  epsilon = 0.2
  value_weight = 1.0
  entropy_weight = 0.01

  nontrainable_params = {
      'gamma': gamma,
      'lambda': lambda_,
      'epsilon': epsilon,
      'value_weight': value_weight,
      'entropy_weight': entropy_weight,
  }

  rewards_to_actions = np.eye(value_predictions_old.shape[1])
  (value_loss_1, _) = ppo.value_loss_given_predictions(
      value_predictions_new, rewards, mask, gamma=gamma,
      value_prediction_old=value_predictions_old, epsilon=epsilon)
  (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
      new_log_probabs, old_log_probabs, value_predictions_old, actions,
      rewards_to_actions, rewards, mask, gamma=gamma, lambda_=lambda_,
      epsilon=epsilon)
  (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = (
      ppo.combined_loss(new_params, old_log_probabs, value_predictions_old,
                        net, observations, actions, rewards_to_actions,
                        rewards, mask,
                        nontrainable_params=nontrainable_params,
                        state=state)
  )

  # Test that these compute at all and are self-consistent.
  self.assertGreater(entropy_bonus, 0.0)
  self.assertNear(value_loss_1, value_loss_2, 1e-6)
  self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
  self.assertNear(
      combined_loss,
      ppo_loss_2 + (value_weight * value_loss_2) -
      (entropy_weight * entropy_bonus),
      1e-6)