Example 1
def logvars(variables, label, print_variables=False):
    if print_variables:
        tf.logging.info("%s (%s parameters): %s", label,
                        paramcount(variables), pps(variables))
    else:
        tf.logging.info("%s (%s parameters)", label, paramcount(variables))
    return variables
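
A minimal usage sketch for the helper above, assuming TensorFlow 1.x. paramcount and pps are module-level helpers whose real implementations are not shown in this example, so the stubs below (a parameter counter and a name/shape pretty-printer) are assumptions for illustration only; logvars is the function defined above.

import numpy as np
import tensorflow as tf

def paramcount(vs):
    # Assumed stub: total number of scalar parameters across the given variables.
    return sum(int(np.prod(v.shape.as_list())) for v in vs)

def pps(vs):
    # Assumed stub: human-readable listing of variable names and shapes.
    return ", ".join("%s %s" % (v.name, v.shape.as_list()) for v in vs)

tf.logging.set_verbosity(tf.logging.INFO)
w = tf.get_variable("w", shape=[4, 8])
logvars(tf.trainable_variables(), "trainable variables", print_variables=True)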
Example 2
def create_train_op(loss, params):
    tf.logging.info("create_train_op(loss=%s, params=%s)", loss, pps(params))
    lr = params["lr"]
    global_step = tf.train.get_global_step()
    assert global_step is not None
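    # Optionally replace the constant learning rate with a cosine decay schedule that includes warmup.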
    if "warmup_steps" in params.keys():
        tf.logging.info(
            'create_train_op: lr = cosine_decay_with_warmup(%s, %s, %s, warmup_steps=%s)',
            global_step, lr, params["max_steps"], params["warmup_steps"])
        lr = cosine_decay_with_warmup(global_step,
                                      lr,
                                      params["max_steps"],
                                      warmup_steps=params["warmup_steps"])

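    # Select the optimizer: Adam, AdamW (Adam with decoupled weight decay),
    # or Adafactor, optionally extended with decoupled weight decay.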
    if params["opt_name"] == "adam":
        if not "weight_decay" in params.keys():
            optimizer = tf.train.AdamOptimizer(learning_rate=lr,
                                               beta1=params["beta1"],
                                               beta2=params["beta2"],
                                               epsilon=params["epsilon"])
            tf.logging.info(
                'create_train_op: optimizer = tf.train.AdamOptimizer(learning_rate=%s, beta1=%s, beta2=%s, epsilon=%s)',
                lr, params["beta1"], params["beta2"], params["epsilon"])

        else:
            optimizer = tf.contrib.opt.AdamWOptimizer(
                learning_rate=lr,
                weight_decay=lr * params["weight_decay"],
                beta1=params["beta1"],
                beta2=params["beta2"],
                epsilon=params["epsilon"])
            tf.logging.info(
                'create_train_op: optimizer = tf.contrib.opt.AdamWOptimizer(learning_rate=%s, weight_decay=lr*%s, beta1=%s, beta2=%s, epsilon=%s)',
                lr, params["weight_decay"], params["beta1"], params["beta2"],
                params["epsilon"])

    elif params["opt_name"] == "adafactor":
        if params["decay_type"] == "adam":
            decay_rate = adafactor_decay_rate_adam(params["beta2"])
        elif params["decay_type"] == "pow":
            decay_rate = adafactor_decay_rate_pow(params["decay_exponent"])
        elif params["decay_type"] == "none":
            decay_rate = None
        else:
            raise ValueError("unknown optimizer_adafactor_decay_type")

        if not "weight_decay" in params.keys():
            optimizer = AdafactorOptimizer(learning_rate=lr,
                                           decay_rate=decay_rate,
                                           beta1=params["beta1"],
                                           name="Adafactor")
            tf.logging.info(
                'create_train_op: optimizer = AdafactorOptimizer(learning_rate=%s, decay_rate=%s, beta1=%s)',
                lr, decay_rate, params["beta1"])

        else:
            AdafactorWOptimizer = tf.contrib.opt.extend_with_decoupled_weight_decay(
                AdafactorOptimizer)

            optimizer = AdafactorWOptimizer(
                weight_decay=params["weight_decay"] * lr,
                learning_rate=lr,
                decay_rate=decay_rate,
                beta1=params["beta1"],
                name="AdafactorW")
            tf.logging.info(
                'create_train_op: optimizer = AdafactorWOptimizer(weight_decay=lr*%s, learning_rate=%s, decay_rate=%s, beta1=%s)',
                params["weight_decay"], lr, decay_rate, params["beta1"])

    else:
        raise ValueError("Unknown optimizer type!")

    if params["use_tpu"]:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    update_ops = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # To update batchnorm, if present
    only_train_transformer_layers = params.get('only_train_transformer_layers', False)

    def should_train_variable(v):
        if only_train_transformer_layers:
            if '/h' not in v.name and '/ln_f' not in v.name:
                tf.logging.info("NOT training variable: %s", v)
                return False
            #for i in range(1):
            #  if ('/h%01d/' % i) in v.name:
            #    return False
            #  if ('/h%02d/' % i) in v.name:
            #    return False
        tf.logging.info("    training variable: %s", v)
        return True

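    # Partition variables into trainable, excluded-from-training, other global, and local sets for logging.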
    train_vars = [
        v for v in tf.trainable_variables() if should_train_variable(v)
    ]
    non_train_vars = [
        v for v in tf.trainable_variables() if not should_train_variable(v)
    ]
    other_vars = [
        v for v in tf.global_variables()
        if v not in train_vars and v not in non_train_vars
    ]
    local_vars = [v for v in tf.local_variables()]

    def paramcount(vs):
        # Total number of scalar parameters across the given variables.
        return sum(np.prod(v.shape.as_list()) for v in vs)

    def logvars(variables, label, print_variables=False):
        if print_variables:
            tf.logging.info("%s (%s parameters): %s", label,
                            paramcount(variables), pps(variables))
        else:
            tf.logging.info("%s (%s parameters)", label, paramcount(variables))
        return variables

    tf.logging.info(
        "Training %d parameters (%.2fM) out of %d parameters (%.2fM)" % (
            paramcount(train_vars),
            paramcount(train_vars) / (1024.0 * 1024.0),
            paramcount(tf.trainable_variables()),
            paramcount(tf.trainable_variables()) / (1024.0 * 1024.0),
        ))

    tf.logging.info("---------")
    tf.logging.info("Variable details:")
    logvars(train_vars, "trainable variables", print_variables=True)
    logvars(non_train_vars, "non-trainable variables", print_variables=True)
    logvars(other_vars, "other global variables", print_variables=True)
    logvars(local_vars, "other local variables", print_variables=True)

    tf.logging.info("---------")
    tf.logging.info("Variable summary:")
    logvars(train_vars, "trainable variables")
    logvars(non_train_vars, "non-trainable variables")
    logvars(other_vars, "other global variables")
    logvars(local_vars, "other local variables")

    tf.logging.info("---------")
    tf.logging.info("Gradient options:")
    #use_memory_saving_gradients=True
    use_memory_saving_gradients = params.get('memory_saving_gradients', False)
    colocate_gradients_with_ops = params.get('colocate_gradients', True)
    gate_gradients = None
    tf.logging.info("use_memory_saving_gradients=%s",
                    use_memory_saving_gradients)
    tf.logging.info("colocate_gradients_with_ops=%s",
                    colocate_gradients_with_ops)
    tf.logging.info("gate_gradients=%s", gate_gradients)
    if use_memory_saving_gradients:
        #grads = memory_saving_gradients.gradients(loss, train_vars, colocate_gradients_with_ops=colocate_gradients_with_ops, checkpoints='memory')
        #grads = memory_saving_gradients.gradients_memory if i == 0 else memory_saving_gradients.gradients_speed
        #grads = memory_saving_gradients.gradients_speed if i == 0 else memory_saving_gradients.gradients_speed
        grads = memory_saving_gradients.gradients
        grads = grads(loss,
                      train_vars,
                      colocate_gradients_with_ops=colocate_gradients_with_ops,
                      gate_gradients=gate_gradients)
    else:
        grads = gradients.gradients(
            loss,
            train_vars,
            colocate_gradients_with_ops=colocate_gradients_with_ops,
            gate_gradients=gate_gradients)

    grads = list(zip(grads, train_vars))
    disconnected_grads = [v for g, v in grads if g is None]
    grads = [(g, v) if g is not None else (tf.zeros_like(v), v)
             for g, v in grads]  # replace disconnected gradients with zeros

    tf.logging.info("---------")
    tf.logging.info("Gradient details:")
    tf.logging.info("%s", pps(grads))
    tf.logging.info("Disconnected gradients:")
    tf.logging.info("%s", pps(disconnected_grads))
    tf.logging.info("---------")

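    # Apply the (zero-filled where disconnected) gradients and bundle the op with any UPDATE_OPS.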
    #train_op = optimizer.minimize(loss, global_step=global_step)
    train_op = optimizer.apply_gradients(grads, global_step=global_step)
    train_op = tf.group([train_op, update_ops], name="train_op")

    return train_op
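
A hedged sketch of how create_train_op might be driven in a plain TF 1.x session. The params keys mirror the ones read above, the toy loss is illustrative only, and create_train_op is assumed to be in scope together with its helpers (pps, cosine_decay_with_warmup, AdafactorOptimizer, memory_saving_gradients, gradients, np), none of which are defined here.

import tensorflow as tf

params = {
    "lr": 1e-4,
    "max_steps": 100000,
    # "warmup_steps": 2000,      # optional: enables cosine_decay_with_warmup(global_step, lr, max_steps, warmup_steps)
    "opt_name": "adam",          # or "adafactor" (then also set "decay_type")
    "beta1": 0.9,
    "beta2": 0.999,
    "epsilon": 1e-8,
    # "weight_decay": 0.01,      # optional: switches to the decoupled-weight-decay variant
    "use_tpu": False,
    "only_train_transformer_layers": False,  # optional
    "memory_saving_gradients": False,        # optional
    "colocate_gradients": True,              # optional
}

tf.logging.set_verbosity(tf.logging.INFO)
global_step = tf.train.get_or_create_global_step()

# Toy loss over a single trainable variable, purely for illustration.
w = tf.get_variable("w", shape=[8, 8])
loss = tf.reduce_sum(tf.square(w))

train_op = create_train_op(loss, params)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    for _ in range(3):
        sess.run(train_op)
    tf.logging.info("global_step = %s", sess.run(global_step))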