Example #1
def get_optimizer(learning_rate):
    """
  Return the tensor of SGD optimizer.

  Args:
    learning_rate: a `float32` tensor as the learning rate.

  Returns:
    optimizer: an optimizer.

  """

    if FLAGS.optimizer == 'adam':
        return tf.train.AdamOptimizer(learning_rate=learning_rate,
                                      beta1=FLAGS.beta1)
    elif FLAGS.optimizer == 'nadam':
        return NadamOptimizer(learning_rate=learning_rate, beta1=FLAGS.beta1)
    elif FLAGS.optimizer == 'adadelta':
        return tf.train.AdadeltaOptimizer(learning_rate=learning_rate,
                                          rho=FLAGS.rho)
    elif FLAGS.optimizer == 'rmsprop':
        return tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                         decay=FLAGS.rmsprop_decay,
                                         momentum=FLAGS.rmsprop_momentum)
    elif FLAGS.optimizer == 'amsgrad':
        return AmsGrad(learning_rate=learning_rate, beta1=FLAGS.beta1)
    else:
        raise ValueError(
            "Unsupported optimizer %r; supported optimizers are: "
            "adam, nadam, adadelta, rmsprop, amsgrad" % FLAGS.optimizer)
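
This helper assumes a module-level `FLAGS` object. A minimal sketch of how it might be wired up (the flag defaults and the toy loss below are assumptions; only the flag names follow from the attributes used above):

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string('optimizer', 'adam', 'one of: adam, nadam, adadelta, rmsprop, amsgrad')
flags.DEFINE_float('beta1', 0.9, 'first-moment decay for adam/nadam/amsgrad')
flags.DEFINE_float('rho', 0.95, 'decay rate for adadelta')
flags.DEFINE_float('rmsprop_decay', 0.9, 'decay for rmsprop')
flags.DEFINE_float('rmsprop_momentum', 0.0, 'momentum for rmsprop')
FLAGS = flags.FLAGS

# Toy loss so the sketch is runnable end to end.
w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())
loss = tf.square(w - 3.0)

optimizer = get_optimizer(learning_rate=1e-3)
train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())
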
Example #2
def _get_optimizer(learning_rate, optimizer_name):
    if optimizer_name == Optimizers.MOMENTUM.value:
        return tf.train.MomentumOptimizer(learning_rate,
                                          momentum=0.9,
                                          use_nesterov=True)
    if optimizer_name == Optimizers.ADAM.value:
        return tf.train.AdamOptimizer(learning_rate)
    if optimizer_name == Optimizers.ADADELTA.value:
        return tf.train.AdadeltaOptimizer(learning_rate)
    if optimizer_name == Optimizers.RMSPROP.value:
        return tf.train.RMSPropOptimizer(learning_rate)
    if optimizer_name == Optimizers.NADAM.value:
        return NadamOptimizer(learning_rate)
    raise NotImplementedError(optimizer_name + " optimizer not supported")
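
Example #2 dispatches on an `Optimizers` enum that is not part of the snippet. A plausible definition, inferred from the members referenced above (the string values are assumptions), could be:

from enum import Enum

class Optimizers(Enum):
    # Hypothetical enum backing _get_optimizer; only the member names are taken from the code above.
    MOMENTUM = 'momentum'
    ADAM = 'adam'
    ADADELTA = 'adadelta'
    RMSPROP = 'rmsprop'
    NADAM = 'nadam'

# e.g. _get_optimizer(1e-3, Optimizers.NADAM.value)
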
Example #3
def Fashion_CNN(input_shape, num_classes, learning_rate, graph):

    with graph.as_default():

        #is_train = tf.placeholder(tf.bool)
        img = tf.placeholder(tf.float32, input_shape)

        labels = tf.placeholder(tf.float32, shape=(None, num_classes))
        lr = tf.placeholder(tf.float32)

        # first 3 convolutions approximate Conv(7,7):
        layer = conv_layer(img, 64)
        layer = conv_layer(layer, 64)
        layer = conv_layer(layer, 64)
        layer = MaxPooling2D()(layer)
        layer = dropout(layer, keep_prob=0.7)
        layer = conv_layer(layer, 128, shape=(-1, 14, 14, -1))
        layer = conv_layer(layer, 128, shape=(-1, 14, 14, -1))
        layer = conv_layer(layer, 64, (1, 1), shape=(-1, 14, 14, -1))
        layer = MaxPooling2D()(layer)
        layer = Flatten()(layer)
        layer = dropout(layer, keep_prob=0.7)
        layer = fc_layer(layer, 2048)
        layer = dropout(layer)
        layer = fc_layer(layer, 512)
        layer = dropout(layer)
        layer = fc_layer(layer, 256)
        layer = dropout(layer)
        layer = Dense(num_classes, kernel_initializer='glorot_normal')(layer)
        layer = batch_norm(layer,
                           updates_collections=None,
                           center=True,
                           scale=True)
        preds = activations.softmax(layer)

        # L2 weight penalty over all kernel variables, weighted by beta below
        lossL2 = tf.add_n([
            tf.nn.l2_loss(v) for v in tf.trainable_variables()
            if 'kernel' in v.name
        ])

        beta = 1e-7
        loss = (tf.reduce_mean(losses.categorical_crossentropy(labels, preds))
                + beta * lossL2)
        train_step = NadamOptimizer(learning_rate=lr).minimize(loss)

        acc_value = tf.reduce_mean(metrics.categorical_accuracy(labels, preds))

        return img, labels, lr, train_step, loss, acc_value
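
Fashion_CNN only builds the graph and returns handles; nothing is executed. A minimal training-loop sketch around those handles (the random data, batch size, and learning rate are assumptions used only so the sketch runs):

import numpy as np
import tensorflow as tf

graph = tf.Graph()
img, labels, lr, train_step, loss, acc_value = Fashion_CNN(
    input_shape=(None, 28, 28, 1), num_classes=10, learning_rate=1e-3, graph=graph)

with graph.as_default():
    init_op = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    for step in range(100):
        batch_x = np.random.rand(128, 28, 28, 1).astype(np.float32)
        batch_y = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, size=128)]
        _, loss_val, acc_val = sess.run(
            [train_step, loss, acc_value],
            feed_dict={img: batch_x, labels: batch_y, lr: 1e-3})
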
Example #4
    def _init_optimiser(self):
        r"""
                Computes the batch_loss function to be minimised
                """

        self._init_lr_decay()

        self._loss_weights = tf.sequence_mask(
            lengths=self._decoder._labels_len, dtype=self._hparams.dtype)

        if self._hparams.loss_fun is None:
            if self._hparams.label_smoothing <= 0.0:
                softmax_loss_fun = None
            else:
                print(
                    'Using the slower "softmax_cross_entropy" instead of "sparse_softmax_cross_entropy" '
                    'since label smoothing is nonzero')
                from .devel import smoothed_cross_entropy
                num_classes = tf.shape(self._decoder._logits)[2]
                softmax_loss_fun = smoothed_cross_entropy(
                    num_classes, self._hparams.label_smoothing)
        elif self._hparams.loss_fun == 'focal_loss':
            from .devel import focal_loss
            softmax_loss_fun = focal_loss
        elif self._hparams.loss_fun == 'mc_loss':
            from .devel import mc_loss
            softmax_loss_fun = mc_loss
        else:
            raise ValueError('Unknown loss function {}'.format(
                self._hparams.loss_fun))

        self.batch_loss = seq2seq.sequence_loss(
            logits=self._decoder._logits,
            targets=self._decoder._labels,
            weights=self._loss_weights,
            softmax_loss_function=softmax_loss_fun,
            average_across_batch=True,
            average_across_timesteps=True)

        reg_loss = 0

        if self._hparams.recurrent_l2_regularisation is not None:
            regularisable_vars = _get_trainable_vars(self._hparams.cell_type)
            reg = tf.contrib.layers.l2_regularizer(
                scale=self._hparams.recurrent_l2_regularisation)
            reg_loss = tf.contrib.layers.apply_regularization(
                reg, regularisable_vars)

        if self._hparams.video_processing is not None:
            if 'cnn' in self._hparams.video_processing:
                # we regularise the cnn vars by specifying a regulariser in conv2d
                reg_variables = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                reg_loss += tf.reduce_sum(reg_variables)

        self.batch_loss = self.batch_loss + reg_loss

        if self._hparams.regress_aus is True:
            loss_weight = self._hparams.kwargs.get('au_loss_weight', 10.0)
            self.batch_loss += loss_weight * self._video_encoder.au_loss

        if self._hparams.loss_scaling > 1:
            self.batch_loss *= self._hparams.loss_scaling

        if self._hparams.optimiser == 'Adam':
            optimiser = tf.train.AdamOptimizer(
                learning_rate=self.current_lr,
                epsilon=1e-8 if self._hparams.dtype == tf.float32 else 1e-4,
            )
        elif self._hparams.optimiser == 'Nadam':
            from tensorflow.contrib.opt import NadamOptimizer
            optimiser = NadamOptimizer(learning_rate=self.current_lr)
        elif self._hparams.optimiser == 'AdamW':
            from tensorflow.contrib.opt import AdamWOptimizer
            optimiser = AdamWOptimizer(
                learning_rate=self.current_lr,
                weight_decay=self._hparams.weight_decay,
                epsilon=1e-8 if self._hparams.dtype == tf.float32 else 1e-4,
            )
        elif self._hparams.optimiser == 'Momentum':
            optimiser = tf.train.MomentumOptimizer(
                learning_rate=self.current_lr,
                momentum=0.9,
                use_nesterov=False)
        else:
            raise ValueError('Unsupported optimiser {}, try Adam'.format(
                self._hparams.optimiser))

        variables = tf.trainable_variables()
        gradients = tf.gradients(
            self.batch_loss,
            variables)  # not compatible with Nvidia AMP (fp16)
        # gradients = optimiser.compute_gradients(self.batch_loss, variables)

        summaries = []
        for grad, variable in zip(gradients, variables):
            if isinstance(grad, tf.IndexedSlices):
                value = grad.values
            else:
                value = grad
            summary = tf.summary.histogram("%s-grad" % variable.name, value)
            summaries.append(summary)

        if self._hparams.dtype == tf.float16:
            # ripped from https://github.com/joeyearsley/efficient_densenet_tensorflow/blob/master/train.py
            # Choose a loss scale manager which decides how to pick the right loss scale
            # throughout the training process.
            loss_scale_manager = ExponentialUpdateLossScaleManager(128, 100)
            # Wrap the original optimiser in a LossScaleOptimizer. The scaling only takes
            # effect if gradients are computed through the wrapper
            # (i.e. via optimiser.compute_gradients).
            optimiser = LossScaleOptimizer(optimiser, loss_scale_manager)

        if self._hparams.loss_scaling > 1:
            gradients = [
                tf.div(grad, self._hparams.loss_scaling) for grad in gradients
            ]

        if self._hparams.clip_gradients is True:
            gradients, self.global_norm = tf.clip_by_global_norm(
                gradients, self._hparams.max_gradient_norm)

        if self._hparams.batch_normalisation is True:
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.train_op = optimiser.apply_gradients(
                    grads_and_vars=zip(gradients, variables),
                    global_step=tf.train.get_global_step())
        else:
            self.train_op = optimiser.apply_gradients(
                grads_and_vars=zip(gradients, variables),
                global_step=tf.train.get_global_step())
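
As noted in the fp16 branch above, the loss-scale wrapper only helps if gradients are computed through it. For reference, a minimal, self-contained sketch of the TF 1.x contrib loss-scaling pattern (independent of this class; the toy loss is an assumption):

import tensorflow as tf
from tensorflow.contrib.mixed_precision import (
    ExponentialUpdateLossScaleManager, LossScaleOptimizer)

w = tf.get_variable('w', shape=[], dtype=tf.float32)
loss = tf.square(w - 1.0)

base_opt = tf.train.AdamOptimizer(learning_rate=1e-3, epsilon=1e-4)
manager = ExponentialUpdateLossScaleManager(init_loss_scale=128, incr_every_n_steps=100)
opt = LossScaleOptimizer(base_opt, manager)

# compute_gradients/apply_gradients on the wrapper scale the loss up, unscale the
# gradients, and skip updates on overflow automatically.
grads_and_vars = opt.compute_gradients(loss)
train_op = opt.apply_gradients(
    grads_and_vars, global_step=tf.train.get_or_create_global_step())
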
Example #5
 def create_optimizer(self):
     lr = self.get_current_step_learning_rate()
     epsilon = self.config.float("optimizer_epsilon", 1e-16)
     use_locking = self.use_locking
     momentum = self.config.float("momentum", 0.0)
     optim_config = self.config.typed_value("optimizer")
     if optim_config:
         if isinstance(optim_config, str):
             optim_config = {"class": optim_config}
         assert isinstance(optim_config, dict)
         optim_config = optim_config.copy()
         optim_class_name = optim_config.pop("class")
         optim_class = get_optimizer_class(optim_class_name)
         from Util import collect_class_init_kwargs
         optim_class_kwargs = collect_class_init_kwargs(optim_class)
         if "epsilon" in optim_class_kwargs:
             optim_config.setdefault("epsilon", epsilon)
         if "momentum" in optim_class_kwargs and momentum:
             optim_config.setdefault("momentum", momentum)
         if "use_locking" in optim_class_kwargs and use_locking:
             optim_config.setdefault("use_locking", use_locking)
         assert "learning_rate" not in optim_config, "learning_rate will be set implicitly"
         optim_config["learning_rate"] = lr
         print("Create optimizer %s with options %r." %
               (optim_class, optim_config),
               file=log.v2)
         optimizer = optim_class(**optim_config)
         assert isinstance(optimizer, tf.train.Optimizer)
     elif self.config.bool("adam", False):
         assert not momentum
         print("Create Adam optimizer.", file=log.v2)
         # Default TF values: learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8.
         # Default Keras values: lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8.
         # Our Theano default values: beta1=0.9, beta2=0.999, epsilon=1e-16
         # https://github.com/openai/improved-gan/blob/master/imagenet/train_imagenet.py: beta1=0.5
         optimizer = tf.train.AdamOptimizer(learning_rate=lr,
                                            epsilon=epsilon,
                                            use_locking=use_locking)
     elif self.config.bool("nadam", False):
         assert_min_tf_version((1, 2, 0),
                               "NadamOptimizer introduced in TF 1.2.0")
         assert not momentum
         print("Create NAdam optimizer.", file=log.v2)
         # TF default values: like Adam: beta1=0.9, beta2=0.999, epsilon=1e-8
         # Our Theano default values: decay=0.004, beta1=0.9, beta2=0.999, epsilon=1e-8
         from tensorflow.contrib.opt import NadamOptimizer
         optimizer = NadamOptimizer(learning_rate=lr,
                                    epsilon=epsilon,
                                    use_locking=use_locking)
     elif self.config.bool("adadelta", False):
         assert not momentum
         print("Create Adadelta optimizer.", file=log.v2)
         optimizer = tf.train.AdadeltaOptimizer(learning_rate=lr,
                                                epsilon=epsilon,
                                                use_locking=use_locking)
     elif self.config.bool("adagrad", False):
         assert not momentum
         print("Create Adagrad optimizer.", file=log.v2)
         optimizer = tf.train.AdagradOptimizer(learning_rate=lr,
                                               use_locking=use_locking)
     elif self.config.is_of_type("rmsprop", float):
         print("Create RMSProp optimizer with decay %f." %
               self.config.float("rmsprop", 0.9),
               file=log.v2)
         optimizer = tf.train.RMSPropOptimizer(
             decay=self.config.float("rmsprop", 0.9),
             learning_rate=lr,
             momentum=momentum,
             epsilon=epsilon,
             use_locking=use_locking)
     elif self.config.bool("rmsprop", False):
         print("Create RMSProp optimizer.", file=log.v2)
         optimizer = tf.train.RMSPropOptimizer(learning_rate=lr,
                                               momentum=momentum,
                                               epsilon=epsilon,
                                               use_locking=use_locking)
     elif momentum:
         print("Create Momentum optimizer.", file=log.v2)
         optimizer = tf.train.MomentumOptimizer(learning_rate=lr,
                                                momentum=momentum,
                                                use_locking=use_locking)
     else:
         print("Create SGD optimizer.", file=log.v2)
         optimizer = tf.train.GradientDescentOptimizer(
             learning_rate=lr, use_locking=use_locking)
     self.optimizer = optimizer
     self.reset_optim_op()
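
The "optimizer" config entry consumed above is either a class name or a dict whose "class" key is popped and whose remaining keys are forwarded to the optimizer constructor (learning_rate, and possibly epsilon/momentum/use_locking, are filled in automatically). A hedged example of what such an entry might look like:

# Hypothetical config values; everything except "class" is passed to the constructor,
# and "learning_rate" must not be set here since create_optimizer() injects it.
optimizer = "NadamOptimizer"                       # string shorthand
optimizer = {"class": "NadamOptimizer",            # dict form with extra kwargs
             "beta1": 0.9,
             "beta2": 0.999}
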
Example #6
    def __init__(self,
                 sess,
                 towers,
                 precision,
                 layer_sizes=(128, 128, 64, 8, 1),
                 activation_fn=activations.celu,
                 fit_charges=False):
        """
        A queue-enabled multi-gpu trainer. Construction of this class will also
        finalize and initialize all the variables pertaining to the input session.

        Parameters
        ----------
        sess: tf.Session
            The TensorFlow session in which the model's variables are created and run.

        layer_sizes: sequence of ints
            Defines the shapes of the intermediate atomic nn layers.

        fit_charges: bool
            Whether or not we fit partial charges.

        precision: tf.dtype
            Should be either tf.float32 or tf.float64.

        """
        self.towers = towers
        self.num_towers = len(towers)

        assert fit_charges is False
        assert (precision is tf.float32) or (precision is tf.float64)
        assert self.num_towers > 0
        self.precision = precision

        self.x_enq = tf.placeholder(dtype=precision)
        self.y_enq = tf.placeholder(dtype=precision)
        self.z_enq = tf.placeholder(dtype=precision)
        self.a_enq = tf.placeholder(dtype=tf.int32)
        self.m_enq = tf.placeholder(dtype=tf.int32)
        self.yt_enq = tf.placeholder(dtype=precision)
        self.bi_enq = tf.placeholder(dtype=tf.int32)

        dtypes = [
            precision,  # Xs
            precision,  # Ys
            precision,  # Zs
            tf.int32,  # As
            tf.int32,  # mol ids
            precision,  # Y true values
            tf.int32,  # b_idxs
        ]

        qtypes = [
            self.x_enq,
            self.y_enq,
            self.z_enq,
            self.a_enq,
            self.m_enq,
            self.yt_enq,
            self.bi_enq,
        ]

        # force fitting
        self.force_enq_x = tf.placeholder(dtype=precision,
                                          name="dx")  # (batch_size, 1)
        self.force_enq_y = tf.placeholder(dtype=precision,
                                          name="dy")  # (batch_size, 1)
        self.force_enq_z = tf.placeholder(dtype=precision,
                                          name="dz")  # (batch_size, 1)
        dtypes.extend([precision, precision, precision])
        qtypes.extend([self.force_enq_x, self.force_enq_y, self.force_enq_z])

        queue = tf.FIFOQueue(capacity=20 * self.num_towers, dtypes=dtypes)

        self.put_op = queue.enqueue(qtypes)
        self.sess = sess
        self.non_trainable_variables = []

        with tf.device('/cpu:0'):
            self.learning_rate = tf.get_variable('learning_rate',
                                                 tuple(),
                                                 precision,
                                                 tf.constant_initializer(1e-4),
                                                 trainable=False)
            self.optimizer = NadamOptimizer(learning_rate=self.learning_rate,
                                            beta1=0.9,
                                            beta2=0.999,
                                            epsilon=1e-8)  # default is 1e-8

            self.global_step = tf.get_variable('global_step',
                                               tuple(),
                                               tf.int32,
                                               tf.constant_initializer(0),
                                               trainable=False)
            self.decr_learning_rate = tf.assign(
                self.learning_rate, tf.multiply(self.learning_rate, 0.8))
            self.global_epoch_count = tf.get_variable(
                'global_epoch_count',
                tuple(),
                tf.int32,
                tf.constant_initializer(0),
                trainable=False)
            self.local_epoch_count = tf.get_variable(
                'local_epoch_count',
                tuple(),
                tf.int32,
                tf.constant_initializer(0),
                trainable=False)
            self.incr_global_epoch_count = tf.assign(
                self.global_epoch_count, tf.add(self.global_epoch_count, 1))
            self.incr_local_epoch_count = tf.assign(
                self.local_epoch_count, tf.add(self.local_epoch_count, 1))
            self.reset_local_epoch_count = tf.assign(self.local_epoch_count, 0)

            # these data elements are unordered since the gpu grabs the batches in different orders
            self.tower_grads = []  # average is order invariant
            self.tower_force_grads = []  # yell at yutong for naming this
            self.tower_preds = []
            self.tower_bids = []
            self.tower_l2s = []
            self.tower_mos = []
            self.tower_coord_grads = []
            self.tower_features = []
            self.all_models = []
            self.tower_exp_loss = []
            self.tower_force_rmses = []
            self.parameters = []

            # parameters within a tower are shared.
            with tf.variable_scope(tf.get_variable_scope()):
                for tower_idx, tower_device in enumerate(towers):
                    with tf.device(tower_device):
                        with tf.name_scope("%s_%d" %
                                           ("tower", tower_idx)) as scope:

                            with tf.device('/cpu:0'):
                                get_op = queue.dequeue()
                                (x_deq, y_deq, z_deq, a_deq, m_deq, labels,
                                 bi_deq) = get_op[:7]
                                dx_deq, dy_deq, dz_deq = get_op[7:10]

                                self.tower_bids.append(bi_deq)
                                mol_atom_counts = tf.segment_sum(
                                    tf.ones_like(m_deq), m_deq)
                                mol_offsets = tf.cumsum(mol_atom_counts,
                                                        exclusive=True)

                                scatter_idxs, gather_idxs, atom_counts = ani_mod.ani_sort(
                                    a_deq)

                                self.tower_mos.append(mol_offsets)

                            with tf.device(tower_device):
                                f0, f1, f2, f3 = ani_mod.featurize(
                                    x_deq,
                                    y_deq,
                                    z_deq,
                                    a_deq,
                                    mol_offsets,
                                    mol_atom_counts,
                                    scatter_idxs,
                                    atom_counts,
                                    name="ani_op_" + str(tower_idx))
                                feat_size = f0.op.get_attr("feature_size")

                                # TODO: optimize in C++ code directly to avoid reshape
                                f0 = tf.reshape(f0, (-1, feat_size))
                                f1 = tf.reshape(f1, (-1, feat_size))
                                f2 = tf.reshape(f2, (-1, feat_size))
                                f3 = tf.reshape(f3, (-1, feat_size))

                            # print(f0.shape, f1.shape, f2.shape, f3.shape)

                            self.tower_features.append(
                                tf.gather(tf.concat([f0, f1, f2, f3], axis=0),
                                          gather_idxs))

                            tower_model_near = MoleculeNN(
                                type_map=["H", "C", "N", "O"],
                                precision=precision,
                                atom_type_features=[f0, f1, f2, f3],
                                gather_idxs=gather_idxs,
                                layer_sizes=(feat_size, ) + layer_sizes,
                                activation_fn=activation_fn,
                                prefix="near_")

                            # avoid duplicate parameters from later towers since the variables are shared.
                            if tower_idx == 0:
                                self.parameters.extend(
                                    tower_model_near.get_parameters())

                            self.all_models.append(tower_model_near)
                            tower_near_energy = tf.segment_sum(
                                tower_model_near.atom_outputs, m_deq)

                            if fit_charges:
                                tower_model_charges = MoleculeNN(
                                    type_map=["H", "C", "N", "O"],
                                    atom_type_features=[f0, f1, f2, f3],
                                    gather_idxs=gather_idxs,
                                    layer_sizes=(feat_size, ) + layer_sizes,
                                    precision=precision,
                                    prefix="charge_")

                                if tower_idx == 0:
                                    self.parameters.extend(
                                        tower_model_charges.get_parameters())

                                self.all_models.append(tower_model_charges)
                                tower_charges = tower_model_charges.atom_outputs

                                # (ytz + stevenso): we want to normalize the computed charge per molecule.
                                # Note that this only works for *neutral* molecules. For molecules that have a formal charge
                                # we would need to handle the correction differently, or turn off the normalization entirely.
                                # tower_charges_per_mol = tf.segment_sum(tower_charges, m_deq) # per molecule charge
                                # tower_charges_per_mol = tf.divide(tower_charges_per_mol, tf.cast(mol_atom_counts, dtype=precision)) # per molecule avg charge
                                # tower_charge_correction = tf.gather(tower_charges_per_mol, m_deq) # generate the per atom correction
                                # tower_charges = tf.subtract(tower_charges, tower_charge_correction) # zero out the charge

                                tower_far_energy = ani_mod.ani_charge(
                                    x_deq, y_deq, z_deq, tower_charges,
                                    mol_offsets, mol_atom_counts)
                                tower_pred = tf.add(tower_near_energy,
                                                    tower_far_energy)
                            else:
                                tower_pred = tower_near_energy

                            tf.get_variable_scope().reuse_variables()

                            self.tower_preds.append(tower_pred)
                            tower_l2 = tf.squared_difference(
                                tower_pred, labels)

                            self.tower_l2s.append(tower_l2)
                            tower_rmse = tf.sqrt(tf.reduce_mean(tower_l2))
                            tower_exp_loss = tf.exp(
                                tf.cast(tower_rmse, dtype=tf.float64))
                            self.tower_exp_loss.append(tower_exp_loss)
                            tower_grad = self.optimizer.compute_gradients(
                                tower_exp_loss)
                            self.tower_grads.append(tower_grad)

                            p_dx, p_dy, p_dz = tf.gradients(
                                tower_pred, [x_deq, y_deq, z_deq])

                            self.tower_coord_grads.append([p_dx, p_dy, p_dz])

                            # forces are the negative of the gradient
                            f_dx, f_dy, f_dz = -p_dx, -p_dy, -p_dz

                            # optionally fit to the forces
                            dx_l2 = tf.pow(f_dx - dx_deq, 2)
                            dy_l2 = tf.pow(f_dy - dy_deq, 2)
                            dz_l2 = tf.pow(f_dz - dz_deq, 2)
                            dx_l2 = tf.sqrt(tf.reduce_mean(dx_l2))
                            dy_l2 = tf.sqrt(tf.reduce_mean(dy_l2))
                            dz_l2 = tf.sqrt(tf.reduce_mean(dz_l2))
                            # (todo): triple check that F = -grad(V)
                            tower_force_rmse = dx_l2 + dy_l2 + dz_l2
                            self.tower_force_rmses.append(tower_force_rmse)
                            tower_force_exp_loss = tf.exp(
                                tf.cast(tower_force_rmse, dtype=tf.float64))

                            tower_force_grad = self.optimizer.compute_gradients(
                                tower_force_exp_loss)
                            self.tower_force_grads.append(tower_force_grad)

            def build_train_op(grads):
                apply_gradient_op = self.optimizer.apply_gradients(
                    average_gradients(grads), global_step=self.global_step)
                variable_averages = tf.train.ExponentialMovingAverage(
                    0.9999, self.global_step)
                variables_averages_op = variable_averages.apply(
                    tf.trainable_variables())
                return tf.group(apply_gradient_op, variables_averages_op)

            self.train_op = build_train_op(self.tower_grads)
            self.train_op_forces = build_train_op(self.tower_force_grads)

        ws = self._weight_matrices()
        max_norm_ops = []
        for w in ws:
            max_norm_ops.append(tf.assign(w, tf.clip_by_norm(w, 2.0, axes=1)))
        self.max_norm_ops = max_norm_ops

        self.unordered_l2s = tf.squeeze(tf.concat(self.tower_l2s, axis=0))
        #self.unordered_l2s += l2_norm_k * tf.norm(ws) # one way to enforce an l2 norm

        self.global_initializer_op = tf.global_variables_initializer()
        self.saver = tf.train.Saver()
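
Both tower examples call an `average_gradients` helper that is not included in the snippet. A sketch of what it presumably does, following the classic TensorFlow multi-GPU pattern (this is an assumption about the project-local helper):

import tensorflow as tf

def average_gradients(tower_grads):
    # tower_grads: one [(grad, var), ...] list per tower; all towers share the same variables.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars if g is not None]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # Variables are shared across towers, so take the one from the first tower.
        _, var = grad_and_vars[0]
        average_grads.append((grad, var))
    return average_grads
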
Example #7
class TrainerMultiTower():
    def __init__(self,
                 sess,
                 towers,
                 precision,
                 layer_sizes=(128, 128, 64, 8, 1),
                 activation_fn=activations.celu,
                 fit_charges=False):
        """
        A queue-enabled multi-gpu trainer. Construction of this class will also
        finalize and initialize all the variables pertaining to the input session.

        Parameters
        ----------
        sess: tf.Session
            The TensorFlow session in which the model's variables are created and run.

        layer_sizes: sequence of ints
            Defines the shapes of the intermediate atomic nn layers.

        fit_charges: bool
            Whether or not we fit partial charges.

        precision: tf.dtype
            Should be either tf.float32 or tf.float64.

        """
        self.towers = towers
        self.num_towers = len(towers)

        assert fit_charges is False
        assert (precision is tf.float32) or (precision is tf.float64)
        assert self.num_towers > 0
        self.precision = precision

        self.x_enq = tf.placeholder(dtype=precision)
        self.y_enq = tf.placeholder(dtype=precision)
        self.z_enq = tf.placeholder(dtype=precision)
        self.a_enq = tf.placeholder(dtype=tf.int32)
        self.m_enq = tf.placeholder(dtype=tf.int32)
        self.yt_enq = tf.placeholder(dtype=precision)
        self.bi_enq = tf.placeholder(dtype=tf.int32)

        dtypes = [
            precision,  # Xs
            precision,  # Ys
            precision,  # Zs
            tf.int32,  # As
            tf.int32,  # mol ids
            precision,  # Y true values
            tf.int32,  # b_idxs
        ]

        qtypes = [
            self.x_enq,
            self.y_enq,
            self.z_enq,
            self.a_enq,
            self.m_enq,
            self.yt_enq,
            self.bi_enq,
        ]

        # force fitting
        self.force_enq_x = tf.placeholder(dtype=precision,
                                          name="dx")  # (batch_size, 1)
        self.force_enq_y = tf.placeholder(dtype=precision,
                                          name="dy")  # (batch_size, 1)
        self.force_enq_z = tf.placeholder(dtype=precision,
                                          name="dz")  # (batch_size, 1)
        dtypes.extend([precision, precision, precision])
        qtypes.extend([self.force_enq_x, self.force_enq_y, self.force_enq_z])

        queue = tf.FIFOQueue(capacity=20 * self.num_towers, dtypes=dtypes)

        self.put_op = queue.enqueue(qtypes)
        self.sess = sess
        self.non_trainable_variables = []

        with tf.device('/cpu:0'):
            self.learning_rate = tf.get_variable('learning_rate',
                                                 tuple(),
                                                 precision,
                                                 tf.constant_initializer(1e-4),
                                                 trainable=False)
            self.optimizer = NadamOptimizer(learning_rate=self.learning_rate,
                                            beta1=0.9,
                                            beta2=0.999,
                                            epsilon=1e-8)  # default is 1e-8

            self.global_step = tf.get_variable('global_step',
                                               tuple(),
                                               tf.int32,
                                               tf.constant_initializer(0),
                                               trainable=False)
            self.decr_learning_rate = tf.assign(
                self.learning_rate, tf.multiply(self.learning_rate, 0.8))
            self.global_epoch_count = tf.get_variable(
                'global_epoch_count',
                tuple(),
                tf.int32,
                tf.constant_initializer(0),
                trainable=False)
            self.local_epoch_count = tf.get_variable(
                'local_epoch_count',
                tuple(),
                tf.int32,
                tf.constant_initializer(0),
                trainable=False)
            self.incr_global_epoch_count = tf.assign(
                self.global_epoch_count, tf.add(self.global_epoch_count, 1))
            self.incr_local_epoch_count = tf.assign(
                self.local_epoch_count, tf.add(self.local_epoch_count, 1))
            self.reset_local_epoch_count = tf.assign(self.local_epoch_count, 0)

            # these data elements are unordered since the gpu grabs the batches in different orders
            self.tower_grads = []  # average is order invariant
            self.tower_force_grads = []  # yell at yutong for naming this
            self.tower_preds = []
            self.tower_bids = []
            self.tower_l2s = []
            self.tower_mos = []
            self.tower_coord_grads = []
            self.tower_features = []
            self.all_models = []
            self.tower_exp_loss = []
            self.tower_force_rmses = []
            self.parameters = []

            # parameters within a tower are shared.
            with tf.variable_scope(tf.get_variable_scope()):
                for tower_idx, tower_device in enumerate(towers):
                    with tf.device(tower_device):
                        with tf.name_scope("%s_%d" %
                                           ("tower", tower_idx)) as scope:

                            with tf.device('/cpu:0'):
                                get_op = queue.dequeue()
                                (x_deq, y_deq, z_deq, a_deq, m_deq, labels,
                                 bi_deq) = get_op[:7]
                                dx_deq, dy_deq, dz_deq = get_op[7:10]

                                self.tower_bids.append(bi_deq)
                                mol_atom_counts = tf.segment_sum(
                                    tf.ones_like(m_deq), m_deq)
                                mol_offsets = tf.cumsum(mol_atom_counts,
                                                        exclusive=True)

                                scatter_idxs, gather_idxs, atom_counts = ani_mod.ani_sort(
                                    a_deq)

                                self.tower_mos.append(mol_offsets)

                            with tf.device(tower_device):
                                f0, f1, f2, f3 = ani_mod.featurize(
                                    x_deq,
                                    y_deq,
                                    z_deq,
                                    a_deq,
                                    mol_offsets,
                                    mol_atom_counts,
                                    scatter_idxs,
                                    atom_counts,
                                    name="ani_op_" + str(tower_idx))
                                feat_size = f0.op.get_attr("feature_size")

                                # TODO: optimize in C++ code directly to avoid reshape
                                f0 = tf.reshape(f0, (-1, feat_size))
                                f1 = tf.reshape(f1, (-1, feat_size))
                                f2 = tf.reshape(f2, (-1, feat_size))
                                f3 = tf.reshape(f3, (-1, feat_size))

                            # print(f0.shape, f1.shape, f2.shape, f3.shape)

                            self.tower_features.append(
                                tf.gather(tf.concat([f0, f1, f2, f3], axis=0),
                                          gather_idxs))

                            tower_model_near = MoleculeNN(
                                type_map=["H", "C", "N", "O"],
                                precision=precision,
                                atom_type_features=[f0, f1, f2, f3],
                                gather_idxs=gather_idxs,
                                layer_sizes=(feat_size, ) + layer_sizes,
                                activation_fn=activation_fn,
                                prefix="near_")

                            # avoid duplicate parameters from later towers since the variables are shared.
                            if tower_idx == 0:
                                self.parameters.extend(
                                    tower_model_near.get_parameters())

                            self.all_models.append(tower_model_near)
                            tower_near_energy = tf.segment_sum(
                                tower_model_near.atom_outputs, m_deq)

                            if fit_charges:
                                tower_model_charges = MoleculeNN(
                                    type_map=["H", "C", "N", "O"],
                                    atom_type_features=[f0, f1, f2, f3],
                                    gather_idxs=gather_idxs,
                                    layer_sizes=(feat_size, ) + layer_sizes,
                                    precision=precision,
                                    prefix="charge_")

                                if tower_idx == 0:
                                    self.parameters.extend(
                                        tower_model_charges.get_parameters())

                                self.all_models.append(tower_model_charges)
                                tower_charges = tower_model_charges.atom_outputs

                                # (ytz + stevenso): we want to normalize the computed charge per molecule.
                                # Note that this only works for *neutral* molecules. For molecules that have a formal charge
                                # we would need to handle the correction differently, or turn off the normalization entirely.
                                # tower_charges_per_mol = tf.segment_sum(tower_charges, m_deq) # per molecule charge
                                # tower_charges_per_mol = tf.divide(tower_charges_per_mol, tf.cast(mol_atom_counts, dtype=precision)) # per molecule avg charge
                                # tower_charge_correction = tf.gather(tower_charges_per_mol, m_deq) # generate the per atom correction
                                # tower_charges = tf.subtract(tower_charges, tower_charge_correction) # zero out the charge

                                tower_far_energy = ani_mod.ani_charge(
                                    x_deq, y_deq, z_deq, tower_charges,
                                    mol_offsets, mol_atom_counts)
                                tower_pred = tf.add(tower_near_energy,
                                                    tower_far_energy)
                            else:
                                tower_pred = tower_near_energy

                            tf.get_variable_scope().reuse_variables()

                            self.tower_preds.append(tower_pred)
                            tower_l2 = tf.squared_difference(
                                tower_pred, labels)

                            self.tower_l2s.append(tower_l2)
                            tower_rmse = tf.sqrt(tf.reduce_mean(tower_l2))
                            tower_exp_loss = tf.exp(
                                tf.cast(tower_rmse, dtype=tf.float64))
                            self.tower_exp_loss.append(tower_exp_loss)
                            tower_grad = self.optimizer.compute_gradients(
                                tower_exp_loss)
                            self.tower_grads.append(tower_grad)

                            p_dx, p_dy, p_dz = tf.gradients(
                                tower_pred, [x_deq, y_deq, z_deq])

                            self.tower_coord_grads.append([p_dx, p_dy, p_dz])

                            # forces are the negative of the gradient
                            f_dx, f_dy, f_dz = -p_dx, -p_dy, -p_dz

                            # optionally fit to the forces
                            dx_l2 = tf.pow(f_dx - dx_deq, 2)
                            dy_l2 = tf.pow(f_dy - dy_deq, 2)
                            dz_l2 = tf.pow(f_dz - dz_deq, 2)
                            dx_l2 = tf.sqrt(tf.reduce_mean(dx_l2))
                            dy_l2 = tf.sqrt(tf.reduce_mean(dy_l2))
                            dz_l2 = tf.sqrt(tf.reduce_mean(dz_l2))
                            # (todo): triple check that F = -grad(V)
                            tower_force_rmse = dx_l2 + dy_l2 + dz_l2
                            self.tower_force_rmses.append(tower_force_rmse)
                            tower_force_exp_loss = tf.exp(
                                tf.cast(tower_force_rmse, dtype=tf.float64))

                            tower_force_grad = self.optimizer.compute_gradients(
                                tower_force_exp_loss)
                            self.tower_force_grads.append(tower_force_grad)

            def build_train_op(grads):
                apply_gradient_op = self.optimizer.apply_gradients(
                    average_gradients(grads), global_step=self.global_step)
                variable_averages = tf.train.ExponentialMovingAverage(
                    0.9999, self.global_step)
                variables_averages_op = variable_averages.apply(
                    tf.trainable_variables())
                return tf.group(apply_gradient_op, variables_averages_op)

            self.train_op = build_train_op(self.tower_grads)
            self.train_op_forces = build_train_op(self.tower_force_grads)

        ws = self._weight_matrices()
        max_norm_ops = []
        for w in ws:
            max_norm_ops.append(tf.assign(w, tf.clip_by_norm(w, 2.0, axes=1)))
        self.max_norm_ops = max_norm_ops

        self.unordered_l2s = tf.squeeze(tf.concat(self.tower_l2s, axis=0))
        #self.unordered_l2s += l2_norm_k * tf.norm(ws) # one way to enforce an l2 norm

        self.global_initializer_op = tf.global_variables_initializer()
        self.saver = tf.train.Saver()
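
    # Hedged construction sketch (the tower device strings below are assumptions):
    #   sess = tf.Session()
    #   trainer = TrainerMultiTower(sess, towers=['/gpu:0', '/gpu:1'],
    #                               precision=tf.float32)
    #   trainer.initialize()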

    def initialize(self):
        """
        Randomly initialize the parameters in the trainer's underlying model.
        """
        self.sess.run(self.global_initializer_op)

    def save_numpy(self, npz_file):
        """
        Save the parameters into a numpy npz. For the sake of consistency, we require that
        the npz_file ends in .npz

        .. note:: This saves the entire state of all variables (including non-trainable ones
            like the learning rate, etc.)

        Parameters
        ----------
        npz_file: str
            filename to save under. Must end in .npz

        """
        _, file_ext = os.path.splitext(npz_file)
        assert file_ext == ".npz"
        save_objs = {}
        all_vars = tf.global_variables()
        for var, val in zip(all_vars, self.sess.run(all_vars)):
            save_objs[var.name] = val
        np.savez(npz_file, **save_objs)

    def load_numpy(self, npz_file, strict=True):
        """
        Load a numpy checkpoint file.

        Parameters
        ----------
        npz_file: str
            filename to load

        strict: bool (optional)
            Whether dtypes must match exactly. Defaults to True. If you are
            converting a 64-bit checkpoint file into lossy 32-bit (or vice versa),
            set strict to False to perform the conversion automatically.

        """
        objs = np.load(npz_file, allow_pickle=False)
        assign_ops = []
        for k in objs.keys():
            tfo = self.sess.graph.get_tensor_by_name(k)
            npa = objs[k]
            if tfo.dtype.as_numpy_dtype != npa.dtype and strict is True:
                msg = "Cannot deserialize " + str(
                    tfo.dtype.as_numpy_dtype) + " into " + str(npa.dtype)
                msg += ". You may want to set strict=False."
                raise TypeError(msg)
            assign_ops.append(
                tf.assign(tfo, objs[k].astype(tfo.dtype.as_numpy_dtype)))
        self.sess.run(assign_ops)
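
    # Hedged usage sketch: round-trip the trainer state through an .npz file.
    #   trainer.save_numpy('model_state.npz')
    #   trainer.load_numpy('model_state.npz', strict=True)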

    def save(self, save_dir):
        """
        (DEPRECATED) Save the entire model to a given directory. Use save_numpy instead.

        Parameters
        ----------
        save_dir: str
            Path of the save_dir. If the path does not exist then it will
            be created automatically.

        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        save_path = os.path.join(save_dir, "model.ckpt")
        self.saver.save(self.sess, save_path)

    def load(self, save_dir):
        """
        (DEPRECATED) Load an existing model from an existing directory and initialize
        the trainer's Session variables. Use load_numpy instead.

        Parameters
        ----------
        save_dir: str
            Directory containing the checkpoint file. This should be the same
            as what was passed into save().

        .. note:: It is expected that this directory exists.

        """
        save_path = os.path.join(save_dir, "model.ckpt")
        self.saver.restore(self.sess, save_path)

    def _weight_matrices(self):
        weights = []
        # vars are always shared so we can just grab them by the first tower
        for ann in self.all_models[0].anns:
            for W in ann.Ws:
                weights.append(W)
        return weights

    def _biases(self):
        biases = []
        for ann in self.all_models[0].anns:
            for b in ann.bs:
                biases.append(b)
        return biases

    def get_train_op_rmse(self):
        return self.train_op_rmse

    def get_train_op_exp(self):
        return self.train_op_exp

    def get_loss_op(self):
        return self.exp_loss

    def eval_abs_rmse(self, dataset, batch_size=1024):
        """
        Evaluates the absolute RMSE, in kcal/mol, of the y-values of the given dataset.

        Parameters
        ----------
        dataset: khan.RawDataset
            Dataset for evaluation

        batch_size: int (optional)
            Size of each batch used during prediction.

        Returns
        -------
        float
            A scalar for the RMSE of the dataset

        """
        # Todo: add support for force errors.
        test_l2s = self.feed_dataset(dataset,
                                     shuffle=False,
                                     target_ops=[self.unordered_l2s],
                                     batch_size=batch_size)
        return np.sqrt(np.mean(
            flatten_results(test_l2s))) * HARTREE_TO_KCAL_PER_MOL

    def coordinate_gradients(self, dataset, batch_size=1024):
        """
        Compute gradients with respect to the (x,y,z) coordinates of a dataset.

        Parameters
        ----------
        dataset: khan.RawDataset
            Dataset from which we predict.

        batch_size: int (optional)
            Size of each batch used during prediction.

        Returns
        -------
        list of gradients
            Returns a list of num_atoms x 3 gradients for each molecule

        """
        results = self.feed_dataset(dataset,
                                    shuffle=False,
                                    target_ops=[
                                        self.tower_coord_grads, self.tower_mos,
                                        self.tower_bids
                                    ],
                                    batch_size=batch_size)

        for (grad, mo, tids) in results:
            bidxs = np.argsort(tids)
            sorted_grads = np.take(grad, bidxs, axis=0)
            sorted_mos = np.take(mo, bidxs, axis=0)

            for (mo, grad_all) in zip(sorted_mos, sorted_grads):
                # mo is an exclusive prefix sum so the first element is zero
                mo = mo[1:]
                grad_x, grad_y, grad_z = grad_all
                grad_xs = np.split(grad_x, mo)
                grad_ys = np.split(grad_y, mo)
                grad_zs = np.split(grad_z, mo)
                for x, y, z in zip(grad_xs, grad_ys, grad_zs):
                    grad_xyz = np.vstack([x, y, z]).transpose()
                    yield grad_xyz

    def featurize(self, dataset, batch_size=1024):
        """
        Featurize a given dataset.

        Parameters
        ----------
        dataset: khan.RawDataset
            Dataset used for featurization.

        batch_size: int (optional)
            Size of each batch.

        Returns
        -------
        list of np.ndarray
            Returns a list of numpy arrays corresponding to the features
            of each molecule in the dataset.

        .. note:: This should be used for investigative/debug purposes only. This returns tensors that are
            extremely large in size (hint: 600GB if iterating over gdb8 dataset)

        """

        results = self.feed_dataset(
            dataset,
            shuffle=False,
            target_ops=[self.tower_features, self.tower_mos, self.tower_bids],
            batch_size=batch_size)

        for (feats, mos, tids) in results:
            bidxs = np.argsort(tids)
            sorted_mos = np.take(mos, bidxs, axis=0)
            sorted_feats = np.take(feats, bidxs, axis=0)

            for (mo, feat) in zip(sorted_mos, sorted_feats):
                # mo is an exclusive prefix sum so the first element is zero
                mo = mo[1:]
                feats = np.split(feat, mo)
                for f in feats:
                    yield f

    def predict(self, dataset, batch_size=2048):
        """
        Infer y-values given a dataset.

        Parameters
        ----------
        dataset: khan.RawDataset
            Dataset from which we predict.

        batch_size: int (optional)
            Size of each batch used during prediction.

        Returns
        -------
        list of floats
            Returns a list of predicted [y0, y1, y2...] in the same order as the dataset Xs [x0, x1, x2...]

        """
        results = self.feed_dataset(
            dataset,
            shuffle=False,
            target_ops=[self.tower_preds, self.tower_bids],
            batch_size=batch_size)
        ordered_ys = []
        for (ys, ids) in results:
            sorted_ys = np.take(ys, np.argsort(ids), axis=0)
            ordered_ys.extend(np.concatenate(sorted_ys, axis=0))
        return ordered_ys

    def eval_rel_rmse(self, dataset, group_ys, batch_size=1024):
        """
        Evaluates the relative RMSE, in kcal/mol, of the y-values of the given dataset.

        Parameters
        ----------
        dataset: khan.RawDataset
            Dataset for evaluation. The y-values are ignored.

        group_ys: list of list of floats
            group_ys will be used in-place of the dataset's true y values.

        batch_size: int (optional)
            Size of each batch used during prediction.

        Returns
        -------
        float
            A scalar for the RMSE of the dataset

        """
        ordered_ys = self.predict(dataset, batch_size)
        return ed_harder_rmse(group_ys, ordered_ys) * HARTREE_TO_KCAL_PER_MOL

    def eval_eh_rmse(self, dataset, group_ys, batch_size=1024):
        """
        (DEPRECATED) renamed to eval_rel_rmse
        """
        return self.eval_rel_rmse(dataset, group_ys, batch_size)

    def feed_dataset(self,
                     dataset,
                     shuffle,
                     target_ops,
                     batch_size,
                     fuzz=None,
                     before_hooks=None):
        """
        Feed a dataset into the trainer under arbitrary ops.

        Parameters
        ----------
        dataset: khan.RawDataset
            Input dataset that may or may not have y-values depending on the target_ops

        shuffle: bool
            Whether or not we randomly shuffle the data before feeding.

        target_ops: list of tf.Tensors
            tf.Tensors for which we wish to obtain values for.

        batch_size: int
            Size of the batch for which we iterate the dataset over.

        before_hooks: list of tf.Ops (optional)
            List of tensorflow ops which we run before every batch. Note that currently
            these ops must have no feed_dict dependency.

        Returns
        -------
        A generator that yields results of the specified op in increments of batch_size.

        .. note:: You must ensure that the resulting generator is fully iterated over to ensure
            proper termination of the submission threads. Furthermore, the consumer of the
            iterable should block as little as possible, since flushing of the queue assumes
            that the results are consumed promptly.

        """
        def submitter():

            accum = 0
            g_b_idx = 0

            # suppose we have 4 gpus and 5 batches
            # the distribution schedule is as follows:
            # gpu   0 1 2 3
            # bid0  1 1 1 1
            # bid1  1 0 0 0

            # suppose we have 3 gpus and 5 batches
            # the distribution schedule is as follows:
            # gpu   0 1 2
            # bid0  1 1 1
            # bid1  1 1 0
            try:
                n_batches = dataset.num_batches(batch_size)
                for b_idx, (mol_xs, mol_idxs, mol_yts, mol_grads) in enumerate(
                        dataset.iterate(batch_size=batch_size,
                                        shuffle=shuffle,
                                        fuzz=fuzz)):
                    atom_types = (mol_xs[:, 0]).astype(np.int32)
                    if before_hooks:
                        self.sess.run(before_hooks)

                    feed_dict = {
                        self.x_enq: mol_xs[:, 1],
                        self.y_enq: mol_xs[:, 2],
                        self.z_enq: mol_xs[:, 3],
                        self.a_enq: atom_types,
                        self.m_enq: mol_idxs,
                        self.yt_enq: mol_yts,
                        self.bi_enq: b_idx
                    }

                    if mol_grads is not None:
                        feed_dict[self.force_enq_x] = mol_grads[:, 0]
                        feed_dict[self.force_enq_y] = mol_grads[:, 1]
                        feed_dict[self.force_enq_z] = mol_grads[:, 2]
                    else:
                        num_mols = mol_xs.shape[0]
                        feed_dict[self.force_enq_x] = np.zeros(
                            (num_mols, 0), dtype=self.precision.as_numpy_dtype)
                        feed_dict[self.force_enq_y] = np.zeros(
                            (num_mols, 0), dtype=self.precision.as_numpy_dtype)
                        feed_dict[self.force_enq_z] = np.zeros(
                            (num_mols, 0), dtype=self.precision.as_numpy_dtype)

                    self.sess.run(self.put_op, feed_dict=feed_dict)
                    g_b_idx += 1

                # division across multiple towers
                remainder = n_batches % self.num_towers
                if remainder:
                    for _ in range(self.num_towers - remainder):
                        if before_hooks:
                            self.sess.run(before_hooks)

                        feed_dict = {
                            self.x_enq:
                            np.zeros((0, 1),
                                     dtype=self.precision.as_numpy_dtype),
                            self.y_enq:
                            np.zeros((0, 1),
                                     dtype=self.precision.as_numpy_dtype),
                            self.z_enq:
                            np.zeros((0, 1),
                                     dtype=self.precision.as_numpy_dtype),
                            self.a_enq:
                            np.zeros((0, ), dtype=np.int32),
                            self.m_enq:
                            np.zeros((0, ), dtype=np.int32),
                            self.yt_enq:
                            np.zeros((0, )),
                            self.bi_enq:
                            b_idx,
                        }

                        feed_dict[self.force_enq_x] = np.zeros(
                            (0, 1), dtype=self.precision.as_numpy_dtype)
                        feed_dict[self.force_enq_y] = np.zeros(
                            (0, 1), dtype=self.precision.as_numpy_dtype)
                        feed_dict[self.force_enq_z] = np.zeros(
                            (0, 1), dtype=self.precision.as_numpy_dtype)

                        self.sess.run(self.put_op, feed_dict=feed_dict)
                        b_idx += 1

            except Exception as e:
                print("QueueError:", e)

        executor = ThreadPoolExecutor(4)
        executor.submit(submitter)

        n_tower_batches = -(-dataset.num_batches(batch_size=batch_size) //
                            self.num_towers)

        for i in range(n_tower_batches):
            yield self.sess.run(target_ops)

    # run the actual training
    # (ytz) - this is maintained by jminuse for the sake of convenience for now.
    # This is HOTMERGED - I'd avoid calling this code if possible; it seriously needs refactoring.
    def train(self,
              save_dir,
              rd_train,
              rd_test,
              rd_gdb11,
              eval_names,
              eval_datasets,
              eval_groups,
              batch_size,
              max_local_epoch_count=25,
              max_batch_size=1e4,
              max_global_epoch_count=1000):

        train_ops = [
            self.global_epoch_count, self.learning_rate,
            self.local_epoch_count, self.unordered_l2s, self.train_op
        ]
        start_time = time.time()
        best_test_score = self.eval_abs_rmse(rd_test)
        global_epoch = 0
        while batch_size < max_batch_size and global_epoch <= max_global_epoch_count:  # bigger batches as fitting goes on make updates less exploratory
            while self.sess.run(
                    self.local_epoch_count
            ) < max_local_epoch_count and global_epoch <= max_global_epoch_count:
                for step in range(
                        2
                ):  # how many rounds to perform before checking test rmse. Evaluation takes about as long as training for the same number of points, so it can be a waste to evaluate every time.
                    train_step_time = time.time()
                    train_results = list(
                        self.feed_dataset(rd_train,
                                          shuffle=True,
                                          target_ops=train_ops,
                                          batch_size=batch_size,
                                          before_hooks=self.max_norm_ops))
                    train_abs_rmse = np.sqrt(
                        np.mean(flatten_results(
                            train_results, pos=3))) * HARTREE_TO_KCAL_PER_MOL
                    print(
                        '%s Training step %d: train RMSE %.2f kcal/mol in %.1fs'
                        % (save_dir, step, train_abs_rmse,
                           time.time() - train_step_time))
                global_epoch = train_results[0][0]
                learning_rate = train_results[0][1]
                local_epoch_count = train_results[0][2]
                test_abs_rmse_time = time.time()
                test_abs_rmse = self.eval_abs_rmse(rd_test)
                #print('test_abs_rmse_time', time.time()-test_abs_rmse_time )
                time_per_epoch = time.time() - start_time
                start_time = time.time()
                print(save_dir, end=' ')
                print(time.strftime("%Y-%m-%d %H:%M:%S"),
                      'tpe:',
                      "{0:.2f}s,".format(time_per_epoch),
                      'g-epoch',
                      global_epoch,
                      'l-epoch',
                      local_epoch_count,
                      'lr',
                      "{0:.0e}".format(learning_rate),
                      'batch_size',
                      batch_size,
                      '| train/test abs rmse:',
                      "{0:.2f} kcal/mol,".format(train_abs_rmse),
                      "{0:.2f} kcal/mol".format(test_abs_rmse),
                      end='')

                if test_abs_rmse < best_test_score:
                    self.save_best_params()
                    best_test_score = test_abs_rmse
                    self.sess.run([
                        self.incr_global_epoch_count,
                        self.reset_local_epoch_count
                    ])
                else:
                    self.sess.run([
                        self.incr_global_epoch_count,
                        self.incr_local_epoch_count
                    ])

                gdb11_abs_rmse = self.eval_abs_rmse(rd_gdb11)
                print(' | gdb11 abs rmse',
                      "{0:.2f} kcal/mol | ".format(gdb11_abs_rmse),
                      end='')
                for name, ff_data, ff_groups in zip(eval_names, eval_datasets,
                                                    eval_groups):
                    print(name, "abs/rel rmses", "{0:.2f} kcal/mol,".format(self.eval_abs_rmse(ff_data)), \
                        "{0:.2f} kcal/mol | ".format(self.eval_eh_rmse(ff_data, ff_groups)), end='')

                print('')
                self.save(save_dir)

            print(
                "========== Decreasing learning rate, increasing batch size =========="
            )
            self.load_best_params()
            self.sess.run(self.decr_learning_rate)
            self.sess.run(self.reset_local_epoch_count)
            batch_size = int(batch_size * 1.2)
            max_local_epoch_count = int(
                max_local_epoch_count * 1.2
            )  # increase this since higher batch size means fewer actual gradient steps per epoch
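The loop above implements a plateau schedule: train a couple of epochs, check the test RMSE, and once it has not improved for max_local_epoch_count epochs, reload the best weights, decay the learning rate, and grow the batch size. Below is a minimal standalone sketch of that control flow; train_epochs() and evaluate() are hypothetical callables, and the 0.5 decay factor is an assumption (the real factor is whatever self.decr_learning_rate applies).

# Minimal sketch of the plateau schedule above; train_epochs() and evaluate()
# are hypothetical callables, and the 0.5 decay factor is an assumption.
def plateau_schedule(train_epochs, evaluate, batch_size, learning_rate,
                     max_local_epoch_count=25, max_batch_size=1e4,
                     max_global_epoch_count=1000):
    best_score = evaluate()          # e.g. test-set RMSE; lower is better
    global_epoch = 0
    local_epoch = 0
    while batch_size < max_batch_size and global_epoch <= max_global_epoch_count:
        while local_epoch < max_local_epoch_count and global_epoch <= max_global_epoch_count:
            train_epochs(2, batch_size, learning_rate)   # 2 epochs per test check
            score = evaluate()
            global_epoch += 1
            if score < best_score:   # improved: remember it, keep this setting
                best_score = score
                local_epoch = 0
            else:                    # stalled: count epochs without improvement
                local_epoch += 1
        # plateau reached: decay the step size, enlarge batches, allow longer plateaus
        learning_rate *= 0.5
        batch_size = int(batch_size * 1.2)
        max_local_epoch_count = int(max_local_epoch_count * 1.2)
        local_epoch = 0
    return best_score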
Ejemplo n.º 8
0
    def __init__(self, **optimizer_kwargs):
        self._model = optimizer_kwargs["model"]

        self._individual_learning_rate = optimizer_kwargs[
            "individual_learning_rate"]

        self._learning_rate = optimizer_kwargs["learning_rate"]
        self._rescale_learning_rate = optimizer_kwargs["rescale_learning_rate"]
        self._d_p = None
        self._n_reg = None

        post_optimizer = optimizer_kwargs[
            "post_optimizer"] if "post_optimizer" in optimizer_kwargs else None
        if post_optimizer is None:
            self._post_optimizer = super()

        elif post_optimizer == "Momentum":
            self._post_optimizer = MomentumOptimizer(
                learning_rate=optimizer_kwargs["learning_rate"],
                momentum=0.95,
                use_locking=False,
                name="MomentumOptimizer")

        elif post_optimizer == "RMSProp":
            self._post_optimizer = RMSPropOptimizer(
                learning_rate=optimizer_kwargs["learning_rate"],
                decay=0.9,
                epsilon=1e-5,
                use_locking=False,
                name="RMSPropOptimizer")

        elif post_optimizer == "Adam":
            self._post_optimizer = AdamOptimizer(
                learning_rate=optimizer_kwargs["learning_rate"],
                beta1=0.9,
                beta2=0.999,
                epsilon=1e-8,
                use_locking=False,
                name="AdamOptimizer")
        elif post_optimizer == "Nadam":
            self._post_optimizer = NadamOptimizer(
                learning_rate=optimizer_kwargs["learning_rate"],
                beta1=0.9,
                beta2=0.999,
                epsilon=1e-8,
                use_locking=False,
                name="NadamOptimizer")

        elif post_optimizer == "Nesterov":
            self._post_optimizer = MomentumOptimizer(
                learning_rate=optimizer_kwargs["learning_rate"],
                momentum=0.95,
                use_locking=False,
                use_nesterov=True,
                name="NesterovMomentumOptimizer")
        elif post_optimizer == "NesterovConst":
            self._post_optimizer = NesterovConst(
                model=self._model,
                learning_rate=optimizer_kwargs["learning_rate"],
                use_locking=False,
                name="NesterovConstOptimizer")

        else:
            raise ValueError(
                "Unknown post optimizer. Must be one of: None, Momentum, "
                "RMSProp, Adam, Nadam, Nesterov, NesterovConst")

        super().__init__(self._learning_rate)
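The constructor above resolves the post_optimizer string through a long elif chain. The same dispatch can also be expressed with a dict of builders; the sketch below is only an illustration of that pattern under TensorFlow 1.x, not this class's actual implementation (Nadam and NesterovConst are omitted because they come from non-core modules).

import tensorflow as tf  # TF 1.x, matching the snippets above

def make_post_optimizer(name, learning_rate):
    """Dict-based version of the elif dispatch above (illustrative sketch)."""
    builders = {
        "Momentum": lambda: tf.train.MomentumOptimizer(learning_rate, momentum=0.95),
        "Nesterov": lambda: tf.train.MomentumOptimizer(learning_rate, momentum=0.95,
                                                       use_nesterov=True),
        "RMSProp": lambda: tf.train.RMSPropOptimizer(learning_rate, decay=0.9,
                                                     epsilon=1e-5),
        "Adam": lambda: tf.train.AdamOptimizer(learning_rate, beta1=0.9,
                                               beta2=0.999, epsilon=1e-8),
    }
    if name not in builders:
        raise ValueError("Unknown post optimizer: %r" % name)
    return builders[name]()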
Ejemplo n.º 9
0
layer = fc_layer(layer, 256)
layer = dropout(layer, is_training=is_train)
layer = Dense(10, kernel_initializer='glorot_normal')(layer)
layer = batch_norm(layer,
                   updates_collections=None,
                   center=True,
                   scale=True,
                   is_training=is_train)
preds = activations.softmax(layer)

lossL2 = tf.add_n(
    [tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'kernel' in v.name])

beta = 1e-7
loss = tf.reduce_mean(losses.categorical_crossentropy(labels, preds))
train_step = NadamOptimizer(learning_rate=lr).minimize(loss)

# Initialize all variables
init_op = tf.global_variables_initializer()
sess.run(init_op)

acc_value = tf.reduce_mean(metrics.categorical_accuracy(labels, preds))


def accuracy(data, n):
    l = []
    for i in range(n):
        batch = data.next_batch(100)
        # The original snippet is truncated here; closing the feed_dict (with
        # the is_train placeholder fed False for evaluation) and returning the
        # mean accuracy is a minimal completion, not the original code.
        acc = acc_value.eval(feed_dict={
            img: batch[0],
            labels: batch[1],
            is_train: False
        })
        l.append(acc)
    return sum(l) / len(l)
Ejemplo n.º 10
0
    def _build_graph(self):
        growth_rate = self.growth_rate
        layers_per_block = self.layers_per_block
        # first - initial 3 x 3 conv to first_output_features
        with tf.variable_scope("Initial_convolution"):
            output = self.conv2d(self.images,
                                 out_features=self.first_output_features,
                                 kernel_size=3)

        # add N required blocks
        for block in range(self.total_blocks):
            with tf.variable_scope("Block_%d" % block):
                output = self.add_block(output, growth_rate, layers_per_block)
            # the last block is not followed by a transition layer
            if block != self.total_blocks - 1:
                with tf.variable_scope("Transition_after_block_%d" % block):
                    output = self.transition_layer(output)

        with tf.variable_scope("Transition_to_classes"):
            logits = self.transition_layer_to_classes(output)
        prediction = tf.nn.softmax(logits)

        # Losses
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                    labels=self.labels))
        self.cross_entropy = cross_entropy
        l2_loss = tf.add_n(
            [tf.nn.l2_loss(var) for var in tf.trainable_variables()])

        # optimizer and train step
        if self.optimizer == "adam":
            print("adam optimizer: %.4f, %.4f" % (self.beta1, self.beta2))
            optimizer = optimizer_all.Adam(learning_rate=self.learning_rate,
                                           beta1=self.beta1,
                                           beta2=self.beta2,
                                           epsilon=self.epsilon)
        elif self.optimizer == "adaShift":
            print("adaShift optimizer: %.4f, %.4f, %d" %
                  (self.beta1, self.beta2, self.keep_num))
            optimizer = optimizer_all.AdaShift(
                learning_rate=self.learning_rate,
                beta1=self.beta1,
                beta2=self.beta2,
                epsilon=self.epsilon,
                keep_num=self.keep_num,
                pred_g_op=self.pred_g_op,
                use_mov=self.use_mov,
                mov_num=self.mov_num)
        elif self.optimizer == "amsgrad":
            print("amsgrad optimizer: %.4f, %.4f" % (self.beta1, self.beta2))
            optimizer = optimizer_all.AMSGrad(learning_rate=self.learning_rate,
                                              beta1=self.beta1,
                                              beta2=self.beta2,
                                              epsilon=self.epsilon)
        elif self.optimizer == "adamspace":
            print("adamspace optimizer: %.4f, %.4f" % (self.beta1, self.beta2))
            optimizer = optimizer_all.AdamSpace(
                learning_rate=self.learning_rate,
                beta1=self.beta1,
                beta2=self.beta2,
                epsilon=self.epsilon)
        elif self.optimizer == "nadam":
            print("nadam optimizer: %.4f, %.4f" % (self.beta1, self.beta2))
            optimizer = NadamOptimizer(learning_rate=self.learning_rate,
                                       beta1=self.beta1,
                                       beta2=self.beta2,
                                       epsilon=self.epsilon)
        else:
            print("momentum optimizer")
            optimizer = tf.train.MomentumOptimizer(self.learning_rate,
                                                   self.nesterov_momentum,
                                                   use_nesterov=True)
        self.train_step = optimizer.minimize(cross_entropy +
                                             l2_loss * self.weight_decay)

        correct_prediction = tf.equal(tf.argmax(prediction, 1),
                                      tf.argmax(self.labels, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
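The train step above minimizes cross_entropy + l2_loss * weight_decay, i.e. plain L2 weight decay folded into the objective before it reaches the chosen optimizer. Below is a minimal self-contained sketch of that pattern for the nadam branch, assuming NadamOptimizer is tf.contrib.opt.NadamOptimizer (TF 1.x) and using toy shapes and a made-up weight_decay value.

import tensorflow as tf  # TF 1.x
from tensorflow.contrib.opt import NadamOptimizer

# Toy graph: a single dense layer, illustrative shapes and values only.
x = tf.placeholder(tf.float32, [None, 4])
y = tf.placeholder(tf.float32, [None, 2])
w = tf.get_variable("w", [4, 2])
b = tf.get_variable("b", [2], initializer=tf.zeros_initializer())
logits = tf.matmul(x, w) + b

cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))
l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()])
weight_decay = 1e-4  # assumed value; the real one comes from the model config

optimizer = NadamOptimizer(learning_rate=1e-3, beta1=0.9, beta2=0.999,
                           epsilon=1e-8)
train_step = optimizer.minimize(cross_entropy + weight_decay * l2_loss)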
Ejemplo n.º 11
0
    def _set_train_or_infer(self, res, reverse_target_vocab_table, hparams):
        """Set up training and inference."""
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            self.train_loss = res[1]
            self.word_count = tf.reduce_sum(
                self.iterator.source_sequence_length) + tf.reduce_sum(
                    self.iterator.target_sequence_length)
        elif self.mode == tf.contrib.learn.ModeKeys.EVAL:
            self.eval_loss = res[1]
        elif self.mode == tf.contrib.learn.ModeKeys.INFER:
            self.infer_logits, _, self.final_context_state, self.sample_id = res
            self.sample_words = reverse_target_vocab_table.lookup(
                tf.to_int64(self.sample_id))

        if self.mode != tf.contrib.learn.ModeKeys.INFER:
            ## Count the number of predicted words for computing perplexity (ppl).
            self.predict_count = tf.reduce_sum(
                self.iterator.target_sequence_length)

        params = tf.trainable_variables()

        # Gradients and SGD update operation for training the model.
        # Arrange for the embedding vars to appear at the beginning.
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            self.learning_rate = tf.constant(hparams.learning_rate)
            # warm-up
            self.learning_rate = self._get_learning_rate_warmup(hparams)
            # decay
            self.learning_rate = self._get_learning_rate_decay(hparams)

            if hparams.optimizer == "sgd":
                opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            elif hparams.optimizer == 'adam':  # lr 1e-3 0.9 0.999
                opt = optimizer_all.Adam(learning_rate=self.learning_rate,
                                         beta1=hparams.beta1,
                                         beta2=hparams.beta2,
                                         epsilon=hparams.epsilon)
            elif hparams.optimizer == 'adaShift':  # 0.01  10  0.9  0.999  1e-8   'max'|'none'|
                print("adashift optimizer %s" % hparams.pred_g_op)
                opt = optimizer_all.AdaShift(learning_rate=self.learning_rate,
                                             keep_num=hparams.keep_num,
                                             beta1=hparams.beta1,
                                             beta2=hparams.beta2,
                                             epsilon=hparams.epsilon,
                                             pred_g_op=hparams.pred_g_op,
                                             use_mov=(hparams.use_mov == 1),
                                             mov_num=hparams.mov_num)
            elif hparams.optimizer == "amsgrad":
                opt = optimizer_all.AMSGrad(learning_rate=self.learning_rate,
                                            beta1=hparams.beta1,
                                            beta2=hparams.beta2,
                                            epsilon=hparams.epsilon)
            elif hparams.optimizer == "nadam":
                opt = NadamOptimizer(learning_rate=self.learning_rate,
                                     beta1=hparams.beta1,
                                     beta2=hparams.beta2,
                                     epsilon=hparams.epsilon)
            elif hparams.optimizer == "adamspace":
                opt = optimizer_all.AdamSpace(learning_rate=self.learning_rate,
                                              beta1=hparams.beta1,
                                              beta2=hparams.beta2,
                                              epsilon=hparams.epsilon)
            else:
                # No recognized optimizer name (possibly misspelled); fall back to momentum.
                opt = tf.train.MomentumOptimizer(
                    learning_rate=self.learning_rate,
                    momentum=0.9,
                    use_nesterov=False)

            # Gradients
            gradients = tf.gradients(
                self.train_loss,
                params,
                colocate_gradients_with_ops=hparams.colocate_gradients_with_ops)

            clipped_grads, grad_norm_summary, grad_norm = model_helper.gradient_clip(
                gradients, max_gradient_norm=hparams.max_gradient_norm)
            self.grad_norm_summary = grad_norm_summary
            self.grad_norm = grad_norm

            self.update = opt.apply_gradients(zip(clipped_grads, params),
                                              global_step=self.global_step)

            # Summary
            self.train_summary = self._get_train_summary()
        elif self.mode == tf.contrib.learn.ModeKeys.INFER:
            self.infer_summary = self._get_infer_summary(hparams)

        # Print trainable variables
        utils.print_out("# Trainable variables")
        utils.print_out("Format: <name>, <shape>, <(soft) device placement>")
        for param in params:
            utils.print_out(
                "  %s, %s, %s" %
                (param.name, str(param.get_shape()), param.op.device))
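The update above clips gradients by global norm through model_helper.gradient_clip before apply_gradients. The sketch below shows what such a helper typically does with the standard TF 1.x primitive; the exact return values are an assumption inferred from how they are unpacked above, not the helper's verified implementation.

import tensorflow as tf  # TF 1.x

def gradient_clip_sketch(gradients, max_gradient_norm):
    """Clip-by-global-norm sketch mirroring the unpacking used above."""
    clipped_grads, grad_norm = tf.clip_by_global_norm(gradients, max_gradient_norm)
    grad_norm_summary = [
        tf.summary.scalar("grad_norm", grad_norm),
        tf.summary.scalar("clipped_gradient", tf.global_norm(clipped_grads))
    ]
    return clipped_grads, grad_norm_summary, grad_norm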