import numpy as np
import tensorflow as tf

# Assumed to be defined elsewhere in this repo: pps (a pretty-printer),
# cosine_decay_with_warmup, AdafactorOptimizer, adafactor_decay_rate_adam,
# adafactor_decay_rate_pow, and the memory_saving_gradients module (gradient
# checkpointing, as in OpenAI's gradient-checkpointing project).


def paramcount(variables):
    """Total number of scalar parameters across a list of variables."""
    return sum(np.prod(v.shape.as_list()) for v in variables)


def logvars(variables, label, print_variables=False):
    """Log a labeled variable group and its parameter count; returns the
    list unchanged so it can be used inline."""
    if print_variables:
        tf.logging.info("%s (%s parameters): %s", label, paramcount(variables),
                        pps(variables))
    else:
        tf.logging.info("%s (%s parameters)", label, paramcount(variables))
    return variables
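# Example usage (a sketch): logvars is typically called once per variable
# group after the graph is built, e.g.
#
#   logvars(tf.trainable_variables(), "trainable variables")
#
# which logs the label plus the total parameter count and hands the list
# back unchanged.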
def create_train_op(loss, params):
    tf.logging.info("create_train_op(loss=%s, params=%s)", loss, pps(params))
    lr = params["lr"]
    global_step = tf.train.get_global_step()
    assert global_step is not None

    # Optionally wrap the base learning rate in a cosine-decay schedule with
    # linear warmup.
    if "warmup_steps" in params:
        tf.logging.info(
            'create_train_op: lr = cosine_decay_with_warmup(%s, %s, %s, warmup_steps=%s)',
            global_step, lr, params["max_steps"], params["warmup_steps"])
        lr = cosine_decay_with_warmup(global_step, lr, params["max_steps"],
                                      warmup_steps=params["warmup_steps"])

    if params["opt_name"] == "adam":
        if "weight_decay" not in params:
            optimizer = tf.train.AdamOptimizer(learning_rate=lr,
                                               beta1=params["beta1"],
                                               beta2=params["beta2"],
                                               epsilon=params["epsilon"])
            tf.logging.info(
                'create_train_op: optimizer = tf.train.AdamOptimizer(learning_rate=%s, beta1=%s, beta2=%s, epsilon=%s)',
                lr, params["beta1"], params["beta2"], params["epsilon"])
        else:
            # AdamW: decoupled weight decay, scaled by the learning rate so
            # the decay follows the schedule.
            optimizer = tf.contrib.opt.AdamWOptimizer(
                learning_rate=lr,
                weight_decay=lr * params["weight_decay"],
                beta1=params["beta1"],
                beta2=params["beta2"],
                epsilon=params["epsilon"])
            tf.logging.info(
                'create_train_op: optimizer = tf.contrib.opt.AdamWOptimizer(learning_rate=%s, weight_decay=lr*%s, beta1=%s, beta2=%s, epsilon=%s)',
                lr, params["weight_decay"], params["beta1"], params["beta2"],
                params["epsilon"])
    elif params["opt_name"] == "adafactor":
        if params["decay_type"] == "adam":
            decay_rate = adafactor_decay_rate_adam(params["beta2"])
        elif params["decay_type"] == "pow":
            decay_rate = adafactor_decay_rate_pow(params["decay_exponent"])
        elif params["decay_type"] == "none":
            decay_rate = None
        else:
            raise ValueError(
                "Unknown Adafactor decay_type: %s" % params["decay_type"])
        if "weight_decay" not in params:
            optimizer = AdafactorOptimizer(learning_rate=lr,
                                           decay_rate=decay_rate,
                                           beta1=params["beta1"],
                                           name="Adafactor")
            tf.logging.info(
                'create_train_op: optimizer = AdafactorOptimizer(learning_rate=%s, decay_rate=%s, beta1=%s)',
                lr, decay_rate, params["beta1"])
        else:
            AdafactorWOptimizer = tf.contrib.opt.extend_with_decoupled_weight_decay(
                AdafactorOptimizer)
            optimizer = AdafactorWOptimizer(
                weight_decay=params["weight_decay"] * lr,
                learning_rate=lr,
                decay_rate=decay_rate,
                beta1=params["beta1"],
                name="AdafactorW")
            tf.logging.info(
                'create_train_op: optimizer = AdafactorWOptimizer(weight_decay=lr*%s, learning_rate=%s, decay_rate=%s, beta1=%s)',
                params["weight_decay"], lr, decay_rate, params["beta1"])
    else:
        raise ValueError("Unknown optimizer type!")

    if params["use_tpu"]:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    # To update batchnorm moving statistics, if present.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    only_train_transformer_layers = params.get(
        'only_train_transformer_layers', False)

    def should_train_variable(v):
        # When fine-tuning, optionally restrict training to the transformer
        # blocks (/h*) and the final layer norm (/ln_f), freezing everything
        # else (e.g. the embeddings).
        if only_train_transformer_layers:
            if '/h' not in v.name and '/ln_f' not in v.name:
                tf.logging.info("NOT training variable: %s", v)
                return False
        tf.logging.info("    training variable: %s", v)
        return True

    train_vars = [
        v for v in tf.trainable_variables() if should_train_variable(v)
    ]
    non_train_vars = [
        v for v in tf.trainable_variables() if not should_train_variable(v)
    ]
    other_vars = [
        v for v in tf.global_variables()
        if v not in train_vars and v not in non_train_vars
    ]
    local_vars = list(tf.local_variables())

    # Local redefinitions shadow the module-level helpers; logvars closes
    # over the local paramcount lambda.
    paramcount = lambda vs: sum([np.prod(v.shape.as_list()) for v in vs])

    def logvars(variables, label, print_variables=False):
        if print_variables:
            tf.logging.info("%s (%s parameters): %s", label,
                            paramcount(variables), pps(variables))
        else:
            tf.logging.info("%s (%s parameters)", label, paramcount(variables))
        return variables

    tf.logging.info(
        "Training %d parameters (%.2fM) out of %d parameters (%.2fM)" % (
            paramcount(train_vars),
            paramcount(train_vars) / (1024.0 * 1024.0),
            paramcount(tf.trainable_variables()),
            paramcount(tf.trainable_variables()) / (1024.0 * 1024.0),
        ))

    tf.logging.info("---------")
    tf.logging.info("Variable details:")
    logvars(train_vars, "trainable variables", print_variables=True)
    logvars(non_train_vars, "non-trainable variables", print_variables=True)
    logvars(other_vars, "other global variables", print_variables=True)
    logvars(local_vars, "other local variables", print_variables=True)

    tf.logging.info("---------")
    tf.logging.info("Variable summary:")
    logvars(train_vars, "trainable variables")
    logvars(non_train_vars, "non-trainable variables")
    logvars(other_vars, "other global variables")
    logvars(local_vars, "other local variables")

    tf.logging.info("---------")
    tf.logging.info("Gradient options:")
    use_memory_saving_gradients = params.get('memory_saving_gradients', False)
    colocate_gradients_with_ops = params.get('colocate_gradients', True)
    gate_gradients = None
    tf.logging.info("use_memory_saving_gradients=%s",
                    use_memory_saving_gradients)
    tf.logging.info("colocate_gradients_with_ops=%s",
                    colocate_gradients_with_ops)
    tf.logging.info("gate_gradients=%s", gate_gradients)

    if use_memory_saving_gradients:
        # Gradient checkpointing: recompute activations during the backward
        # pass, trading compute for memory.
        grads = memory_saving_gradients.gradients(
            loss, train_vars,
            colocate_gradients_with_ops=colocate_gradients_with_ops,
            gate_gradients=gate_gradients)
    else:
        grads = tf.gradients(
            loss, train_vars,
            colocate_gradients_with_ops=colocate_gradients_with_ops,
            gate_gradients=gate_gradients)

    grads = list(zip(grads, train_vars))
    disconnected_grads = [v for g, v in grads if g is None]
    # Replace disconnected gradients with zeros so apply_gradients does not
    # choke on None entries.
    grads = [(g, v) if g is not None else (tf.zeros_like(v), v)
             for g, v in grads]

    tf.logging.info("---------")
    tf.logging.info("Gradient details:")
    tf.logging.info("%s", pps(grads))
    tf.logging.info("Disconnected gradients:")
    tf.logging.info("%s", pps(disconnected_grads))
    tf.logging.info("---------")

    train_op = optimizer.apply_gradients(grads, global_step=global_step)
    # Run the update ops (e.g. batchnorm moving averages) alongside the
    # training step.
    train_op = tf.group([train_op] + update_ops, name="train_op")
    return train_op
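# A minimal sketch of how create_train_op might be wired into a
# tf.estimator model_fn. The params keys below mirror the ones read above;
# compute_loss is a hypothetical stand-in for this repo's model/loss code,
# and a real model_fn would only build the train op in TRAIN mode.
def example_model_fn(features, labels, mode, params):
    loss = compute_loss(features, labels, params)  # hypothetical helper
    train_op = create_train_op(loss, params)
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

example_params = {
    "opt_name": "adam",   # or "adafactor" (then decay_type etc. are needed)
    "lr": 1e-4,
    "beta1": 0.9,
    "beta2": 0.999,
    "epsilon": 1e-8,
    "use_tpu": False,
    # Optional keys, read only if present:
    # "warmup_steps": 1000, "max_steps": 100000,  # enable the cosine schedule
    # "weight_decay": 0.01,                       # switches Adam -> AdamW
    # "only_train_transformer_layers": True,
    # "memory_saving_gradients": True,
    # "colocate_gradients": True,
}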