Example #1
class ConditionalOptimizer(tf.train.Optimizer):
    """Conditional optimizer."""
    def __init__(self, optimizer_name, lr, hparams, use_tpu=False):  # pylint: disable=super-init-not-called
        tf.logging.info("Using optimizer %s", optimizer_name)

        mlperf_log.transformer_print(key=mlperf_log.OPT_NAME,
                                     value=optimizer_name,
                                     hparams=hparams)
        mlperf_log.transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA1,
                                     value=hparams.optimizer_adam_beta1,
                                     hparams=hparams)
        mlperf_log.transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA2,
                                     value=hparams.optimizer_adam_beta2,
                                     hparams=hparams)
        mlperf_log.transformer_print(key=mlperf_log.OPT_HP_ADAM_EPSILON,
                                     value=hparams.optimizer_adam_epsilon,
                                     hparams=hparams)

        self._opt = registry.optimizer(optimizer_name)(lr, hparams)
        if _mixed_precision_is_enabled(hparams):
            if not hparams.mixed_precision_optimizer_loss_scaler:
                tf.logging.warning(
                    "Using mixed precision without a loss scaler will "
                    "likely cause numerical errors.")
            elif hparams.mixed_precision_optimizer_loss_scaler != "exponential":
                raise ValueError("Mixed precision training only supports the "
                                 "exponential loss scaler")
            else:
                tf.logging.info("Using Exponential Update Loss Scaler")
                manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
                    init_loss_scale=2**15,
                    incr_every_n_steps=2000,
                    decr_every_n_nan_or_inf=2,
                    incr_ratio=2,
                    decr_ratio=0.5)
                self._opt = LossScaleOptimizer(self._opt, manager)

        self._zero_grads = hparams.optimizer_zero_grads

    def compute_gradients(self, loss, var_list=None, **kwargs):  # pylint: disable=arguments-differ
        gradients = self._opt.compute_gradients(loss, var_list, **kwargs)

        def cast_grad(g, v):
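            # Cast each gradient to its variable's dtype (mixed-precision runs
            # can produce float16 grads for float32 variables) and, if enabled,
            # replace missing gradients with zeros so every variable receives
            # an update tensor.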
            if v is not None and g is not None:
                g = common_layers.cast_like(g, v)
            if self._zero_grads and g is None:
                g = tf.zeros_like(v)
            return (g, v)

        gradients = [cast_grad(g, v) for g, v in gradients]
        return gradients

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        return self._opt.apply_gradients(grads_and_vars,
                                         global_step=global_step,
                                         name=name)
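These examples use LossScaleOptimizer and the loss-scale managers without showing their imports. A minimal sketch of the TF 1.x contrib wrapping pattern they assume (the exact module path may vary between 1.x releases) could look like this:

import tensorflow as tf
from tensorflow.contrib.mixed_precision import (
    ExponentialUpdateLossScaleManager, LossScaleOptimizer)

# Any tf.train optimizer can be wrapped; Adam here is only an illustration.
base_opt = tf.train.AdamOptimizer(1e-3)
manager = ExponentialUpdateLossScaleManager(init_loss_scale=2**15,
                                            incr_every_n_steps=2000)
opt = LossScaleOptimizer(base_opt, manager)  # scales the loss, unscales the gradients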
Example #2
    def __init__(self, rows, cols, channels, classes, latent, tpu=False):

        if tpu:
            set_floatx('float16')
            set_epsilon(1e-4)

        # Input shape
        self.img_rows = rows
        self.img_cols = cols
        self.channels = channels
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        self.num_of_classes = classes

        # size of the latent vector (z) fed to the generator
        self.latent_dim = latent

        #optimizer = Adam(0.0002, 0.5)

        optimizer = tf.train.AdamOptimizer(0.0002, 0.5)
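        # The second positional argument of tf.train.AdamOptimizer is beta1,
        # mirroring the Keras Adam(0.0002, 0.5) call commented out above.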

        loss_scale_manager = FixedLossScaleManager(5000)

        loss_scale_optimizer = LossScaleOptimizer(optimizer,
                                                  loss_scale_manager)

        losses = ['binary_crossentropy', 'sparse_categorical_crossentropy']

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()

        self.discriminator.compile(loss=losses,
                                   optimizer=loss_scale_optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise and the target label as input
        # and generates the corresponding digit of that label
        noise = Input(shape=(self.latent_dim, ))
        label = Input(shape=(1, ))
        img = self.generator([noise, label])

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated image as input and determines validity
        # and the label of that image
        valid, target_label = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model([noise, label], [valid, target_label])
        self.combined.compile(loss=losses, optimizer=loss_scale_optimizer)
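The names used above but not imported in the snippet (set_floatx, set_epsilon, Input, Model, FixedLossScaleManager, LossScaleOptimizer) would typically come from a TF 1.x plus Keras setup along these lines; this import block is an assumption, not part of the original example:

import tensorflow as tf
from keras.backend import set_floatx, set_epsilon
from keras.layers import Input
from keras.models import Model
from tensorflow.contrib.mixed_precision import (
    FixedLossScaleManager, LossScaleOptimizer)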
Example #3
    def __init__(self, optimizer_name, lr, hparams, use_tpu=False):  # pylint: disable=super-init-not-called
        tf.logging.info("Using optimizer %s", optimizer_name)

        mlperf_log.transformer_print(key=mlperf_log.OPT_NAME,
                                     value=optimizer_name,
                                     hparams=hparams)
        mlperf_log.transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA1,
                                     value=hparams.optimizer_adam_beta1,
                                     hparams=hparams)
        mlperf_log.transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA2,
                                     value=hparams.optimizer_adam_beta2,
                                     hparams=hparams)
        mlperf_log.transformer_print(key=mlperf_log.OPT_HP_ADAM_EPSILON,
                                     value=hparams.optimizer_adam_epsilon,
                                     hparams=hparams)

        self._opt = registry.optimizer(optimizer_name)(lr, hparams)
        if _mixed_precision_is_enabled(hparams):
            if not hparams.mixed_precision_optimizer_loss_scaler:
                tf.logging.warning(
                    "Using mixed precision without a loss scaler will "
                    "likely cause numerical errors.")
            elif hparams.mixed_precision_optimizer_loss_scaler != "exponential":
                raise ValueError("Mixed precision training only supports the "
                                 "exponential loss scaler")
            else:
                tf.logging.info("Using Exponential Update Loss Scaler")
                manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
                    init_loss_scale=2**15,
                    incr_every_n_steps=2000,
                    decr_every_n_nan_or_inf=2,
                    incr_ratio=2,
                    decr_ratio=0.5)
                self._opt = LossScaleOptimizer(self._opt, manager)

        self._zero_grads = hparams.optimizer_zero_grads
Example #4
def cifar10_model_fn(features, labels, params):
    """Model function for CIFAR-10."""
    print('PARAMS', params['fp16'])
    tf.summary.image('images', features, max_outputs=6)

    inputs = tf.reshape(features, [-1, HEIGHT, WIDTH, DEPTH])
    if params['fp16']:
        inputs = tf.cast(inputs, tf.float16)

    logits = densenet_121.get_model(inputs, NUM_CLASSES, params['data_format'], params['efficient'])
    logits = tf.cast(logits, tf.float32)

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)

    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss.
    loss = cross_entropy + WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    # Scale the learning rate linearly with the batch size. When the batch size
    # is 128, the learning rate should be 0.1.
    initial_learning_rate = 0.1 * params['batch_size'] / 128
    batches_per_epoch = NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
    boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
    values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values)

    # Create a tensor named learning_rate for logging purposes
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate,
        momentum=MOMENTUM)

    if params['fp16']:
        # Choose a loss scale manager which decides how to pick the right loss scale
        # throughout the training process.
        loss_scale_manager = ExponentialUpdateLossScaleManager(128, 100)
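        # (The positional arguments map to init_loss_scale=128 and
        # incr_every_n_steps=100.)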
        # Wraps the original optimizer in a LossScaleOptimizer.
        optimizer = LossScaleOptimizer(optimizer, loss_scale_manager)

    compression = hvd.Compression.fp16 if params['fp16'] else hvd.Compression.none

    optimizer = hvd.DistributedOptimizer(optimizer, compression=compression)

    # Batch norm requires update ops to be added as a dependency to the train_op
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step)

    accuracy = tf.metrics.accuracy(
        tf.argmax(labels, axis=1), predictions['classes'])
    metrics = {'accuracy': accuracy}

    # Create a tensor named train_accuracy for logging purposes
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])

    return train_op, loss, global_step
Example #5
    def _init_optimiser(self):
        r"""Computes the batch loss to be minimised."""

        self._init_lr_decay()

        self._loss_weights = tf.sequence_mask(
            lengths=self._decoder._labels_len, dtype=self._hparams.dtype)

        if self._hparams.loss_fun is None:
            if self._hparams.label_smoothing <= 0.0:
                softmax_loss_fun = None
            else:
                print(
                    'Using the slower "softmax_cross_entropy" instead of "sparse_softmax_cross_entropy" '
                    'since label smoothing is nonzero')
                from .devel import smoothed_cross_entropy
                num_classes = tf.shape(self._decoder._logits)[2]
                softmax_loss_fun = smoothed_cross_entropy(
                    num_classes, self._hparams.label_smoothing)
        elif self._hparams.loss_fun == 'focal_loss':
            from .devel import focal_loss
            softmax_loss_fun = focal_loss
        elif self._hparams.loss_fun == 'mc_loss':
            from .devel import mc_loss
            softmax_loss_fun = mc_loss
        else:
            raise ValueError('Unknown loss function {}'.format(
                self._hparams.loss_fun))

        self.batch_loss = seq2seq.sequence_loss(
            logits=self._decoder._logits,
            targets=self._decoder._labels,
            weights=self._loss_weights,
            softmax_loss_function=softmax_loss_fun,
            average_across_batch=True,
            average_across_timesteps=True)

        reg_loss = 0

        if self._hparams.recurrent_l2_regularisation is not None:
            regularisable_vars = _get_trainable_vars(self._hparams.cell_type)
            reg = tf.contrib.layers.l2_regularizer(
                scale=self._hparams.recurrent_l2_regularisation)
            reg_loss = tf.contrib.layers.apply_regularization(
                reg, regularisable_vars)

        if self._hparams.video_processing is not None:
            if 'cnn' in self._hparams.video_processing:
                # we regularise the cnn vars by specifying a regulariser in conv2d
                reg_variables = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                reg_loss += tf.reduce_sum(reg_variables)

        self.batch_loss = self.batch_loss + reg_loss

        if self._hparams.regress_aus is True:
            loss_weight = self._hparams.kwargs.get('au_loss_weight', 10.0)
            self.batch_loss += loss_weight * self._video_encoder.au_loss

        if self._hparams.loss_scaling > 1:
            self.batch_loss *= self._hparams.loss_scaling

        if self._hparams.optimiser == 'Adam':
            optimiser = tf.train.AdamOptimizer(
                learning_rate=self.current_lr,
                epsilon=1e-8 if self._hparams.dtype == tf.float32 else 1e-4,
            )
        elif self._hparams.optimiser == 'Nadam':
            from tensorflow.contrib.opt import NadamOptimizer
            optimiser = NadamOptimizer(learning_rate=self.current_lr, )
        elif self._hparams.optimiser == 'AdamW':
            from tensorflow.contrib.opt import AdamWOptimizer
            optimiser = AdamWOptimizer(
                learning_rate=self.current_lr,
                weight_decay=self._hparams.weight_decay,
                epsilon=1e-8 if self._hparams.dtype == tf.float32 else 1e-4,
            )
        elif self._hparams.optimiser == 'Momentum':
            optimiser = tf.train.MomentumOptimizer(
                learning_rate=self.current_lr,
                momentum=0.9,
                use_nesterov=False)
        else:
            raise Exception('Unsupported optimiser, try Adam')

        variables = tf.trainable_variables()
        gradients = tf.gradients(
            self.batch_loss,
            variables)  # not compatible with Nvidia AMP (fp16)
        # gradients = optimiser.compute_gradients(self.batch_loss, variables)

        summaries = []
        for grad, variable in zip(gradients, variables):
            if isinstance(grad, tf.IndexedSlices):
                value = grad.values
            else:
                value = grad
            summary = tf.summary.histogram("%s-grad" % variable.name, value)
            summaries.append(summary)

        if self._hparams.dtype == tf.float16:
            # adapted from https://github.com/joeyearsley/efficient_densenet_tensorflow/blob/master/train.py
            # Choose a loss scale manager which decides how to pick the right loss scale
            # throughout the training process.
            loss_scale_manager = ExponentialUpdateLossScaleManager(128, 100)
            # Wrap the original optimizer in a LossScaleOptimizer.
            optimiser = LossScaleOptimizer(optimiser, loss_scale_manager)

        if self._hparams.loss_scaling > 1:
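            # Undo the manual loss scaling applied to batch_loss above so the
            # applied gradients are back in their unscaled range.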
            gradients = [
                tf.div(grad, self._hparams.loss_scaling) for grad in gradients
            ]

        if self._hparams.clip_gradients is True:
            gradients, self.global_norm = tf.clip_by_global_norm(
                gradients, self._hparams.max_gradient_norm)

        if self._hparams.batch_normalisation is True:
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.train_op = optimiser.apply_gradients(
                    grads_and_vars=zip(gradients, variables),
                    global_step=tf.train.get_global_step())
        else:
            self.train_op = optimiser.apply_gradients(
                grads_and_vars=zip(gradients, variables),
                global_step=tf.train.get_global_step())
Example #6
def create_optimizer(loss,
                     init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     use_tpu,
                     hvd=None,
                     use_fp16=False,
                     amp=False):
    """Creates an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()

    # avoid step change in learning rate at end of warmup phase
    decayed_lr_at_crossover = init_lr * (
        1.0 - float(num_warmup_steps) / float(num_train_steps))
    adjusted_init_lr = init_lr * (init_lr / decayed_lr_at_crossover)
    print(
        'decayed_learning_rate_at_crossover_point = %e, adjusted_init_lr = %e'
        % (decayed_lr_at_crossover, adjusted_init_lr))
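    # Worked example with illustrative numbers (not from the original code):
    # init_lr = 1e-4, num_train_steps = 10000, num_warmup_steps = 1000 gives
    # decayed_lr_at_crossover = 9e-5 and adjusted_init_lr = 1e-4 / 0.9 (~1.11e-4);
    # the linear decay of adjusted_init_lr is then exactly init_lr at step 1000,
    # matching the warmup line, so the learning rate has no step change there.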

    learning_rate = tf.constant(value=adjusted_init_lr,
                                shape=[],
                                dtype=tf.float32)

    # Implements linear decay of the learning rate.
    learning_rate = tf.train.polynomial_decay(learning_rate,
                                              global_step,
                                              num_train_steps,
                                              end_learning_rate=0.0,
                                              power=1.0,
                                              cycle=False)

    # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
    # learning rate will be `global_step/num_warmup_steps * init_lr`.
    if num_warmup_steps:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = init_lr * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    # It is recommended that you use this optimizer for fine tuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    optimizer = AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

    if use_tpu:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    else:
        if hvd is not None:
            from horovod.tensorflow.compression import Compression
            optimizer = hvd.DistributedOptimizer(optimizer,
                                                 sparse_as_dense=True,
                                                 compression=Compression.none)
        if use_fp16 or amp:
            loss_scale_manager = ExponentialUpdateLossScaleManager(
                init_loss_scale=2**32,
                incr_every_n_steps=1000,
                decr_every_n_nan_or_inf=2,
                decr_ratio=0.5)
            optimizer = LossScaleOptimizer(optimizer, loss_scale_manager)

    tvars = tf.trainable_variables()
    # grads_and_vars = optimizer.compute_gradients(loss, tvars)
    grads = tf.gradients(loss, tf.trainable_variables())
    grads_and_vars = list(zip(grads, tf.trainable_variables()))
    grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
    grads, tvars = list(zip(*grads_and_vars))
    all_are_finite = tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads]) \
        if use_fp16 or amp else tf.constant(True, dtype=tf.bool)

    # This is how the model was pre-trained.
    # Ensure the global norm is a finite number
    # to prevent clip_by_global_norm from misbehaving.
    (clipped_grads, _) = tf.clip_by_global_norm(
        grads,
        clip_norm=1.0,
        use_norm=tf.cond(all_are_finite, lambda: tf.global_norm(grads),
                         lambda: tf.constant(1.0)))

    train_op = optimizer.apply_gradients(list(zip(clipped_grads, tvars)),
                                         global_step=global_step)

    # Normally the global step update is done inside of `apply_gradients`.
    # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
    # a different optimizer, you should probably take this line out.
    new_global_step = tf.cond(all_are_finite, lambda: global_step + 1,
                              lambda: global_step)
    new_global_step = tf.identity(new_global_step, name='step_update')
    train_op = tf.group(train_op, [global_step.assign(new_global_step)])
    return train_op
Example #7
def create_gan_model():

    gen = Generator(letant_size=100, tpu=True)
    dis = Discriminator(128, True)

    labels = tf.placeholder(tf.int32, shape=[None])
    real_images = dis.get_input_tensor()
    z = gen.get_input_tensor()

    G = gen.get_generator(z, labels)
    print("gen :", G)

    D_output_real, D_logits_real = dis.get_discriminator(real_images, labels)
    D_output_fake, D_logits_fake = dis.get_discriminator(G, labels, reuse=True)

    #############################
    #                           #
    #      Loss functions       #
    #                           #
    #############################

    D_loss, G_loss = loss(labels, D_output_real, D_logits_real, D_output_fake,
                          D_logits_fake, G)

    # Get all the trainable variables
    tvars = tf.trainable_variables()

    d_vars = [var for var in tvars if 'dis' in var.name]
    g_vars = [var for var in tvars if 'gen' in var.name]

    # Standard Optimizers
    D_trainer = tf.train.AdamOptimizer(0.002, 0.5)
    G_trainer = tf.train.AdamOptimizer(0.002, 0.5)

    loss_scale_manager_D = FixedLossScaleManager(5000)
    loss_scale_manager_G = FixedLossScaleManager(5000)

    loss_scale_optimizer_D = LossScaleOptimizer(D_trainer,
                                                loss_scale_manager_D)
    loss_scale_optimizer_G = LossScaleOptimizer(G_trainer,
                                                loss_scale_manager_G)

    grads_variables_D = loss_scale_optimizer_D.compute_gradients(
        D_loss, d_vars)
    grads_variables_G = loss_scale_optimizer_G.compute_gradients(
        G_loss, g_vars)

    training_step_op_D = loss_scale_optimizer_D.apply_gradients(
        grads_variables_D)
    training_step_op_G = loss_scale_optimizer_G.apply_gradients(
        grads_variables_G)

    init = tf.global_variables_initializer()

    samples = []

    batch_size = 128
    epochs = 100

    saver = tf.train.Saver(var_list=g_vars)

    with tf.Session() as sess:

        sess.run(init)

        start = time.time()

        # Recall an epoch is an entire run through the training data
        for e in range(epochs):
            # // is floor (integer) division
            num_batches = mnist.train.num_examples // batch_size

            for i in range(num_batches):
                # Grab batch of images
                batch = mnist.train.next_batch(batch_size)

                # Get images, reshape and rescale to pass to D
                batch_images = batch[0].astype(np.float16).reshape(
                    (batch_size, 784))
                batch_images = batch_images * 2 - 1

                # Z (random latent noise data for Models)
                # -1 to 1 because of tanh activation
                batch_z = np.random.uniform(-1, 1, size=(batch_size, 100))

                # Run optimizers, no need to save outputs, we won't use them
                _ = sess.run(training_step_op_D,
                             feed_dict={
                                 real_images: batch_images,
                                 z: batch_z
                             })
                _ = sess.run(training_step_op_G, feed_dict={z: batch_z})

            print("Currently on Epoch {} of {} total...".format(e + 1, epochs))

            # Sample from generator as we're training for viewing afterwards
            #sample_z = np.random.uniform(-1, 1, size=(1, 100))
            # gen_sample = sess.run(generator(z, reuse=True), feed_dict={z: sample_z})
            #
            # samples.append(gen_sample)
            saver.save(sess, '/models/model.ckpt')

        end = time.time()
        print(end - start)
Example #8
class ACGAN(object):
    model_name = "ACGAN"  # name for checkpoint

    def __init__(self,
                 sess,
                 epoch,
                 batch_size,
                 z_dim,
                 dataset_name,
                 checkpoint_dir,
                 result_dir,
                 log_dir,
                 tpu=False):
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if tpu:
            self.dtype = tf.float16
            self.nptype = np.float16
        else:
            self.dtype = tf.float32
            self.nptype = np.float32

        if dataset_name in ('mnist', 'fashion-mnist', 'quick_draw', 'cifar10'):
            # parameters
            self.input_height = 32
            self.input_width = 32
            self.output_height = 32
            self.output_width = 32

            self.z_dim = z_dim  # dimension of noise-vector
            self.y_dim = 10  # dimension of code-vector (label)
            self.c_dim = 3

            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64  # number of generated images to be saved

            # code
            self.len_discrete_code = 10  # categorical distribution (i.e. label)
            self.len_continuous_code = 2  # gaussian distribution (e.g. rotation, thickness)

            # load mnist
            #self.data_X, self.data_y = load_mnist(self.dataset_name)

            # load quick draw
            #self.data_X, self.data_y = load_quick_draw(self.dataset_name, tpu)
            self.data_X, self.data_y = load_cifar10(tpu)

            # get number of batches for a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            raise NotImplementedError

    def classifier(self, x, is_training=True, reuse=False):
        with tf.variable_scope("classifier", reuse=reuse):

            net = fc(x, 128, scope='c_fc1', activation_fn=None)

            # Batch normalization should be computed in float32
            net = tf.cast(net, tf.float32)
            net = bn(net, is_training=is_training, scope='c_bn1')

            # Cast back to float16 so the fully connected layer can use Tensor Cores
            net = tf.cast(net, tf.float16)
            net = tf.nn.leaky_relu(net, alpha=0.2)
            out_logit = fc(net, self.y_dim, scope='c_fc2', activation_fn=None)

            # Softmax should be computed in float32
            out_logit = tf.cast(out_logit, tf.float32)
            out = tf.nn.softmax(out_logit)

            return out, out_logit

    # def classifier(self, x, is_training=True, reuse=False):
    #     # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    #     # Architecture : (64)5c2s-(128)5c2s_BL-FC1024_BL-FC128_BL-FC12S
    #     # All layers except the last two layers are shared by discriminator
    #     with tf.variable_scope("classifier", reuse=reuse):
    #
    #         net = lrelu(bn(fc(x, 128, scope='c_fc1', activation_fn=None), is_training=is_training, scope='c_bn1'))
    #
    #         out_logit = fc(net, self.y_dim, scope='c_fc2', activation_fn=None)
    #         out = tf.nn.softmax(out_logit)
    #
    #         print("classsifier: ",out, out_logit )
    #
    #
    #         return out, out_logit

    def discriminator(self, x, is_training=True, reuse=False):
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
        with tf.variable_scope("discriminator", reuse=reuse):

            net = lrelu(
                conv2d(x, 64, 4, 4, 2, 2, name='d_conv1',
                       data_type=self.dtype))
            net = lrelu(
                bn(conv2d(net,
                          128,
                          4,
                          4,
                          2,
                          2,
                          name='d_conv2',
                          data_type=self.dtype),
                   is_training=is_training,
                   scope='d_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(
                bn(fc(net, 1024, scope='d_fc3', activation_fn=None),
                   is_training=is_training,
                   scope='d_bn3'))
            #out_logit = linear(net, 1, scope='d_fc4', data_type=self.dtype)
            #net = tf.cast(net, tf.float32)
            out_logit = fc(net, 1, scope='d_fc4', activation_fn=None)
            out = tf.nn.sigmoid(out_logit)
            print("discriminator: ", out, out_logit, net)

            return out, out_logit, net

    def generator(self, z, y, is_training=True, reuse=False):
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
        with tf.variable_scope("generator", reuse=reuse):

            # merge noise and code
            z = concat([z, y], 1)

            net = tf.nn.relu(
                bn(fc(z, 1024, scope='g_fc1', activation_fn=None),
                   is_training=is_training,
                   scope='g_bn1'))
            net = tf.nn.relu(
                bn(fc(net, 128 * 8 * 8, scope='g_fc2', activation_fn=None),
                   is_training=is_training,
                   scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 8, 8, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 16, 16, 64],
                            4,
                            4,
                            2,
                            2,
                            name='g_dc3',
                            data_type=self.dtype),
                   is_training=is_training,
                   scope='g_bn3'))

            out = tf.nn.sigmoid(
                deconv2d(net, [self.batch_size, 32, 32, 3],
                         4,
                         4,
                         2,
                         2,
                         name='g_dc4',
                         data_type=self.dtype))

            print("generator: ", out)
            return out

    def build_model(self):
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size
        """ Graph Input """
        # images
        self.inputs = tf.placeholder(self.dtype, [bs] + image_dims,
                                     name='real_images')

        # labels
        self.y = tf.placeholder(self.dtype, [bs, self.y_dim], name='y')

        # noises
        self.z = tf.placeholder(self.dtype, [bs, self.z_dim], name='z')
        """ Loss Function """
        ## 1. GAN Loss
        # output of D for real images
        D_real, D_real_logits, input4classifier_real = self.discriminator(
            self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, self.y, is_training=True, reuse=False)
        D_fake, D_fake_logits, input4classifier_fake = self.discriminator(
            G, is_training=True, reuse=True)

        # get loss for discriminator
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_real_logits, labels=tf.ones_like(D_real)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_fake_logits, labels=tf.zeros_like(D_fake)))

        self.d_loss = tf.add(d_loss_real, d_loss_fake)

        # get loss for generator
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_fake_logits, labels=tf.ones_like(D_fake)))

        ## 2. Information Loss
        code_fake, code_logit_fake = self.classifier(input4classifier_fake,
                                                     is_training=True,
                                                     reuse=False)
        code_real, code_logit_real = self.classifier(input4classifier_real,
                                                     is_training=True,
                                                     reuse=True)

        # For real samples
        q_real_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=code_logit_real,
                                                       labels=self.y))

        # For fake samples
        q_fake_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=code_logit_fake,
                                                       labels=self.y))

        # get information loss
        self.q_loss = tf.add(q_fake_loss, q_real_loss)
        """ Training """
        # divide trainable variables into a group for D and a group for G
        print(1)
        t_vars = tf.trainable_variables()

        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]
        q_vars = [
            var for var in t_vars
            if ('d_' in var.name) or ('c_' in var.name) or ('g_' in var.name)
        ]

        print(d_vars)
        print(g_vars)
        print(q_vars)

        print(2)
        # optimizers
        # with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        #     self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
        #         .minimize(self.d_loss, var_list=d_vars)
        #     self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
        #         .minimize(self.g_loss, var_list=g_vars)
        #     self.q_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
        #         .minimize(self.q_loss, var_list=q_vars)

        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                  beta1=self.beta1)

            self.g_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                  beta1=self.beta1)

            self.q_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                  beta1=self.beta1)

            scale = 2

            self.loss_scale_manager_D = FixedLossScaleManager(scale)
            self.loss_scale_manager_G = FixedLossScaleManager(scale)
            self.loss_scale_manager_Q = FixedLossScaleManager(scale)

            print(3)

            self.loss_scale_optimizer_D = LossScaleOptimizer(
                self.d_optim, self.loss_scale_manager_D)
            self.loss_scale_optimizer_G = LossScaleOptimizer(
                self.g_optim, self.loss_scale_manager_G)
            self.loss_scale_optimizer_Q = LossScaleOptimizer(
                self.q_optim, self.loss_scale_manager_Q)

            print(4)

            self.grads_variables_D = self.loss_scale_optimizer_D.compute_gradients(
                self.d_loss)
            self.grads_variables_G = self.loss_scale_optimizer_G.compute_gradients(
                self.g_loss)
            self.grads_variables_Q = self.loss_scale_optimizer_Q.compute_gradients(
                self.q_loss)

            print(self.grads_variables_D)
            print(self.grads_variables_G)
            print(self.grads_variables_Q)

            self.q_grads = [(g, v) for (g, v) in self.grads_variables_Q
                            if g is not None]
            print('New Q_grad:', self.q_grads)

            self.training_step_op_D = self.loss_scale_optimizer_D.apply_gradients(
                self.grads_variables_D)
            self.training_step_op_G = self.loss_scale_optimizer_G.apply_gradients(
                self.grads_variables_G)
            self.training_step_op_Q = self.loss_scale_optimizer_Q.apply_gradients(
                self.grads_variables_Q)
        """ Testing """
        # for test
        self.fake_images = self.generator(self.z,
                                          self.y,
                                          is_training=False,
                                          reuse=True)
        """ Summary """
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        q_loss_sum = tf.summary.scalar("q_loss", self.q_loss)
        q_real_sum = tf.summary.scalar("q_real_loss", q_real_loss)
        q_fake_sum = tf.summary.scalar("q_fake_loss", q_fake_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])
        self.q_sum = tf.summary.merge([q_loss_sum, q_real_sum, q_fake_sum])

    def train(self):

        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualize training results
        self.sample_z = np.random.uniform(-1,
                                          1,
                                          size=(self.batch_size, self.z_dim))
        self.test_codes = self.data_y[0:self.batch_size]

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(
            self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exits
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = int(checkpoint_counter / self.num_batches)
            start_batch_id = checkpoint_counter - start_epoch * self.num_batches
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx * self.batch_size:(idx + 1) *
                                           self.batch_size]
                batch_codes = self.data_y[idx * self.batch_size:(idx + 1) *
                                          self.batch_size]

                batch_z = np.random.uniform(
                    -1, 1, [self.batch_size, self.z_dim]).astype(self.nptype)

                # update D network
                _, summary_str, d_loss = self.sess.run(
                    [self.training_step_op_D, self.d_sum, self.d_loss],
                    feed_dict={
                        self.inputs: batch_images,
                        self.y: batch_codes,
                        self.z: batch_z
                    })
                self.writer.add_summary(summary_str, counter)

                # update G & Q network
                _, summary_str_g, g_loss, _, summary_str_q, q_loss = self.sess.run(
                    [
                        self.training_step_op_G, self.g_sum, self.g_loss,
                        self.training_step_op_Q, self.q_sum, self.q_loss
                    ],
                    feed_dict={
                        self.z: batch_z,
                        self.y: batch_codes,
                        self.inputs: batch_images
                    })
                self.writer.add_summary(summary_str_g, counter)
                self.writer.add_summary(summary_str_q, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={
                                                self.z: self.sample_z,
                                                self.y: self.test_codes
                                            })
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(
                        samples[:manifold_h * manifold_w, :, :, :],
                        [manifold_h, manifold_w], './' +
                        check_folder(self.result_dir + '/' + self.model_dir) +
                        '/' + self.model_name +
                        '_train_{:02d}_{:04d}.png'.format(epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))
        """ random noise, random discrete code, fixed continuous code """
        y = np.random.choice(self.len_discrete_code, self.batch_size)
        y_one_hot = np.zeros((self.batch_size, self.y_dim))
        y_one_hot[np.arange(self.batch_size), y] = 1

        samples = self.sess.run(self.fake_images,
                                feed_dict={
                                    self.z: z_sample,
                                    self.y: y_one_hot
                                })

        save_images(
            samples[:image_frame_dim * image_frame_dim, :, :, :],
            [image_frame_dim, image_frame_dim],
            check_folder(self.result_dir + '/' + self.model_dir) + '/' +
            self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
        """ specified condition, random noise """
        n_styles = 10  # must be less than or equal to self.batch_size

        np.random.seed()
        si = np.random.choice(self.batch_size, n_styles)

        for l in range(self.len_discrete_code):
            y = np.zeros(self.batch_size, dtype=np.int64) + l
            y_one_hot = np.zeros((self.batch_size, self.y_dim))
            y_one_hot[np.arange(self.batch_size), y] = 1

            samples = self.sess.run(self.fake_images,
                                    feed_dict={
                                        self.z: z_sample,
                                        self.y: y_one_hot
                                    })
            save_images(
                samples[:image_frame_dim * image_frame_dim, :, :, :],
                [image_frame_dim, image_frame_dim],
                check_folder(self.result_dir + '/' + self.model_dir) + '/' +
                self.model_name + '_epoch%03d' % epoch +
                '_test_class_%d.png' % l)

            samples = samples[si, :, :, :]

            if l == 0:
                all_samples = samples
            else:
                all_samples = np.concatenate((all_samples, samples), axis=0)
        """ save merged images to check style-consistency """
        canvas = np.zeros_like(all_samples)
        for s in range(n_styles):
            for c in range(self.len_discrete_code):
                canvas[s * self.len_discrete_code +
                       c, :, :, :] = all_samples[c * n_styles + s, :, :, :]

        save_images(
            canvas, [n_styles, self.len_discrete_code],
            check_folder(self.result_dir + '/' + self.model_dir) + '/' +
            self.model_name + '_epoch%03d' % epoch +
            '_test_all_classes_style_by_style.png')

    @property
    def model_dir(self):
        return "{}_{}_{}_{}".format(self.model_name, self.dataset_name,
                                    self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir,
                                      self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir,
                                     self.model_name + '.model'),
                        global_step=step)

    def load(self, checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir,
                                      self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess,
                               os.path.join(checkpoint_dir, ckpt_name))
            counter = int(
                next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
Example #9
    def build_model(self):
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size
        """ Graph Input """
        # images
        self.inputs = tf.placeholder(self.dtype, [bs] + image_dims,
                                     name='real_images')

        # labels
        self.y = tf.placeholder(self.dtype, [bs, self.y_dim], name='y')

        # noises
        self.z = tf.placeholder(self.dtype, [bs, self.z_dim], name='z')
        """ Loss Function """
        ## 1. GAN Loss
        # output of D for real images
        D_real, D_real_logits, input4classifier_real = self.discriminator(
            self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, self.y, is_training=True, reuse=False)
        D_fake, D_fake_logits, input4classifier_fake = self.discriminator(
            G, is_training=True, reuse=True)

        # get loss for discriminator
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_real_logits, labels=tf.ones_like(D_real)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_fake_logits, labels=tf.zeros_like(D_fake)))

        self.d_loss = tf.add(d_loss_real, d_loss_fake)

        # get loss for generator
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_fake_logits, labels=tf.ones_like(D_fake)))

        ## 2. Information Loss
        code_fake, code_logit_fake = self.classifier(input4classifier_fake,
                                                     is_training=True,
                                                     reuse=False)
        code_real, code_logit_real = self.classifier(input4classifier_real,
                                                     is_training=True,
                                                     reuse=True)

        # For real samples
        q_real_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=code_logit_real,
                                                       labels=self.y))

        # For fake samples
        q_fake_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=code_logit_fake,
                                                       labels=self.y))

        # get information loss
        self.q_loss = tf.add(q_fake_loss, q_real_loss)
        """ Training """
        # divide trainable variables into a group for D and a group for G
        print(1)
        t_vars = tf.trainable_variables()

        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]
        q_vars = [
            var for var in t_vars
            if ('d_' in var.name) or ('c_' in var.name) or ('g_' in var.name)
        ]

        print(d_vars)
        print(g_vars)
        print(q_vars)

        print(2)
        # optimizers
        # with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        #     self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
        #         .minimize(self.d_loss, var_list=d_vars)
        #     self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
        #         .minimize(self.g_loss, var_list=g_vars)
        #     self.q_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
        #         .minimize(self.q_loss, var_list=q_vars)

        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                  beta1=self.beta1)

            self.g_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                  beta1=self.beta1)

            self.q_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                  beta1=self.beta1)

            scale = 2

            self.loss_scale_manager_D = FixedLossScaleManager(scale)
            self.loss_scale_manager_G = FixedLossScaleManager(scale)
            self.loss_scale_manager_Q = FixedLossScaleManager(scale)

            print(3)

            self.loss_scale_optimizer_D = LossScaleOptimizer(
                self.d_optim, self.loss_scale_manager_D)
            self.loss_scale_optimizer_G = LossScaleOptimizer(
                self.g_optim, self.loss_scale_manager_G)
            self.loss_scale_optimizer_Q = LossScaleOptimizer(
                self.q_optim, self.loss_scale_manager_Q)

            print(4)

            self.grads_variables_D = self.loss_scale_optimizer_D.compute_gradients(
                self.d_loss)
            self.grads_variables_G = self.loss_scale_optimizer_G.compute_gradients(
                self.g_loss)
            self.grads_variables_Q = self.loss_scale_optimizer_Q.compute_gradients(
                self.q_loss)

            print(self.grads_variables_D)
            print(self.grads_variables_G)
            print(self.grads_variables_Q)

            self.q_grads = [(g, v) for (g, v) in self.grads_variables_Q
                            if g is not None]
            print('New Q_grad:', self.q_grads)

            self.training_step_op_D = self.loss_scale_optimizer_D.apply_gradients(
                self.grads_variables_D)
            self.training_step_op_G = self.loss_scale_optimizer_G.apply_gradients(
                self.grads_variables_G)
            self.training_step_op_Q = self.loss_scale_optimizer_Q.apply_gradients(
                self.grads_variables_Q)
        """ Testing """
        # for test
        self.fake_images = self.generator(self.z,
                                          self.y,
                                          is_training=False,
                                          reuse=True)
        """ Summary """
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        q_loss_sum = tf.summary.scalar("q_loss", self.q_loss)
        q_real_sum = tf.summary.scalar("q_real_loss", q_real_loss)
        q_fake_sum = tf.summary.scalar("q_fake_loss", q_fake_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])
        self.q_sum = tf.summary.merge([q_loss_sum, q_real_sum, q_fake_sum])
Example #10
    # Create training graph
    with tf.device('/gpu:0'), \
         tf.variable_scope(
             # Note: This forces trainable variables to be stored as float32
             'fp32_storage', custom_getter=float32_variable_storage_getter):
        data, target, loss = create_simple_model(nbatch, nin, nout, dtype)
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Adam optimizer (the second positional argument of AdamOptimizer is beta1)
        model_opt = tf.train.AdamOptimizer(learning_rate, beta1=momentum)
        # Note: Loss scaling can improve numerical stability for fp16 training
        scale_size = 128  # there is no single correct scale size; 128 is a reasonable starting point

        loss_scale_manager = FixedLossScaleManager(scale_size)

        loss_scale_optimizer = LossScaleOptimizer(model_opt,
                                                  loss_scale_manager)

        grads_variables = loss_scale_optimizer.compute_gradients(
            loss, variables)
        # Optional gradient manipulation would go here, for example:
        # grads_variables = [(g, v) for (g, v) in grads_variables if g is not None]

        training_opt = loss_scale_optimizer.apply_gradients(grads_variables)
        init_op = tf.global_variables_initializer()
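A hedged usage sketch for the graph above (assumed wiring, not part of the original example; it presumes data and target are placeholders created by create_simple_model and that their dtypes match the arrays fed here):

import numpy as np

np_data = np.random.uniform(size=(nbatch, nin)).astype(np.float16)
np_target = np.random.uniform(size=(nbatch, nout)).astype(np.float32)

with tf.Session() as sess:
    sess.run(init_op)
    for step in range(10):
        _, loss_val = sess.run([training_opt, loss],
                               feed_dict={data: np_data, target: np_target})
        print(step, loss_val)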
Example #11
  def __init__(self, optimizer_name, lr, hparams, use_tpu=False):  # pylint: disable=super-init-not-called
    tf.logging.info("Using optimizer %s", optimizer_name)

    mlperf_log.transformer_print(key=mlperf_log.OPT_NAME,
                                 value=optimizer_name,
                                 hparams=hparams)
    mlperf_log.transformer_print(
        key=mlperf_log.OPT_HP_ADAM_BETA1, value=hparams.optimizer_adam_beta1,
        hparams=hparams)
    mlperf_log.transformer_print(
        key=mlperf_log.OPT_HP_ADAM_BETA2, value=hparams.optimizer_adam_beta2,
        hparams=hparams)
    mlperf_log.transformer_print(
        key=mlperf_log.OPT_HP_ADAM_EPSILON,
        value=hparams.optimizer_adam_epsilon,
        hparams=hparams)

    if optimizer_name == "Adam":
      # We change the default epsilon for Adam.
      # Using LazyAdam as it's much faster for large vocabulary embeddings.
      self._opt = tf.contrib.opt.LazyAdamOptimizer(
          lr,
          beta1=hparams.optimizer_adam_beta1,
          beta2=hparams.optimizer_adam_beta2,
          epsilon=hparams.optimizer_adam_epsilon)
    elif optimizer_name == "MultistepAdam":
      self._opt = multistep_optimizer.MultistepAdamOptimizer(
          lr,
          beta1=hparams.optimizer_adam_beta1,
          beta2=hparams.optimizer_adam_beta2,
          epsilon=hparams.optimizer_adam_epsilon,
          n=hparams.optimizer_multistep_accumulate_steps)
    elif optimizer_name == "Momentum":
      self._opt = tf.train.MomentumOptimizer(
          lr,
          momentum=hparams.optimizer_momentum_momentum,
          use_nesterov=hparams.optimizer_momentum_nesterov)
    elif optimizer_name == "YellowFin":
      self._opt = yellowfin.YellowFinOptimizer(
          learning_rate=lr, momentum=hparams.optimizer_momentum_momentum)
    elif optimizer_name == "TrueAdam":
      self._opt = tf.train.AdamOptimizer(
          lr,
          beta1=hparams.optimizer_adam_beta1,
          beta2=hparams.optimizer_adam_beta2,
          epsilon=hparams.optimizer_adam_epsilon)
    elif optimizer_name == "AdamW":
      # Openai gpt used weight decay.
      # Given the internals of AdamW, weight decay dependent on the
      # learning rate is chosen to match the openai implementation.
      # The weight decay update to each parameter is applied before the adam
      # gradients computation, which is different from that described
      # in the paper and in the openai implementation:
      # https://arxiv.org/pdf/1711.05101.pdf
      self._opt = tf.contrib.opt.AdamWOptimizer(
          0.01*lr,
          lr,
          beta1=hparams.optimizer_adam_beta1,
          beta2=hparams.optimizer_adam_beta2,
          epsilon=hparams.optimizer_adam_epsilon)
    elif optimizer_name == "Adafactor":
      self._opt = adafactor.adafactor_optimizer_from_hparams(hparams, lr)
    else:
      self._opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[optimizer_name](lr)
    if _mixed_precision_is_enabled(hparams):
      if not hparams.mixed_precision_optimizer_loss_scaler:
        tf.logging.warning("Using mixed precision without a loss scaler will "
                           "likely cause numerical errors.")
      elif hparams.mixed_precision_optimizer_loss_scaler != "exponential":
        raise ValueError("Mixed precision training only supports the "
                         "exponential loss scaler")
      else:
        tf.logging.info("Using Exponential Update Loss Scaler")
        manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
            init_loss_scale=2**15,
            incr_every_n_steps=2000,
            decr_every_n_nan_or_inf=2,
            incr_ratio=2,
            decr_ratio=0.5)
        self._opt = LossScaleOptimizer(self._opt, manager)

    self._zero_grads = hparams.optimizer_zero_grads