def setup(): if not horovod_installed: return False global horovod_initialized if horovod_initialized: return hvd hvd.init() horovod_initialized = True horovod_num_worker = hvd.size() horovod_rank = hvd.rank() # verify that MPI multi-threading is supported. assert hvd.mpi_threads_supported() # make sure MPI is not re-initialized. import mpi4py.rc mpi4py.rc.initialize = False # import mpi4py from mpi4py import MPI comm = MPI.COMM_WORLD # check size and rank are syncronized assert horovod_num_worker == comm.Get_size() assert horovod_rank == comm.Get_rank() return hvd
def setup_horovod():
    """Initialise Horovod for TensorFlow and return MPI-bound helpers.

    Returns:
        ``(hvd, MPI, is_root, mpi_average)``: the Horovod module, the mpi4py
        ``MPI`` module, ``True`` on rank 0, and a collective averaging helper.
    """
    import horovod.tensorflow as hvd

    hvd.init()
    # Horovod must have been built with MPI multi-threading support.
    assert hvd.mpi_threads_supported()

    from mpi4py import MPI

    assert hvd.size() == MPI.COMM_WORLD.Get_size()
    is_root = hvd.rank() == 0

    def mpi_average(local_list):
        """Average the elements of ``local_list`` across all ranks.

        Collective: every rank must call it. Returns ``(mean, total_count)``
        on rank 0 and ``(None, None)`` on every other rank.
        """
        values = [float(item) for item in local_list]
        world = MPI.COMM_WORLD
        gathered_sums = world.gather(sum(values), root=0)
        gathered_counts = world.gather(len(values), root=0)
        if not is_root:
            return None, None
        total_count = sum(gathered_counts)
        return sum(gathered_sums) / total_count, total_count

    return hvd, MPI, is_root, mpi_average
def init_workers(distributed=False):
    """Describe this process's place in the (possibly distributed) job.

    Args:
        distributed: request Horovod data-parallel mode. It is used only when
            the module-level ``no_horovod`` flag is falsy.

    Returns:
        A ``SimpleNamespace`` with ``rank``, ``size``, ``local_rank``,
        ``local_size`` and ``comm``. In single-process mode these are
        ``0, 1, 0, 1, None``.
    """
    if not distributed or no_horovod:
        print("not doing distributed")
        return SimpleNamespace(rank=0, size=1, local_rank=0, local_size=1, comm=None)

    hvd.init()
    # Horovod must have been built with MPI multi-threading support.
    assert hvd.mpi_threads_supported()
    from mpi4py import MPI

    world = MPI.COMM_WORLD
    assert hvd.size() == world.Get_size()
    print("Rank: {}, Size: {}".format(hvd.rank(), hvd.size()))
    return SimpleNamespace(
        rank=hvd.rank(),
        size=hvd.size(),
        local_rank=hvd.local_rank(),
        local_size=hvd.local_size(),
        comm=world,
    )
def start_training(config):
    """Build the network described by ``config`` and run the training loop.

    The loop runs until ``max_steps``. This comes from ``config.MAX_EPOCHS``
    when present, otherwise from ``config.MAX_STEPS``. Along the way it:

    * writes training summaries on step 0 and every ``config.SUMMARISE_STEPS``;
    * saves checkpoints on step 0, the last step and every ``config.SAVE_STEPS``;
    * evaluates on the validation set on step 0 and every ``config.TEST_STEPS``.

    When ``config.IS_DISTRIBUTION`` is set, training runs data-parallel under
    Horovod/MPI. Each rank trains on its own slice of a shuffle broadcast from
    rank 0. Only rank 0 writes summaries, checkpoints and exported graphs.

    Args:
        config: project configuration object. Reads, among others,
            ``IS_DISTRIBUTION``, ``NETWORK_CLASS``, ``NETWORK``, ``DATASET``,
            ``BATCH_SIZE``, ``IS_PRETRAIN`` and the step-interval settings.
    """
    if config.IS_DISTRIBUTION:
        import horovod.tensorflow as hvd

        # initialize Horovod.
        hvd.init()
        num_worker = hvd.size()
        rank = hvd.rank()
        # verify that MPI multi-threading is supported.
        assert hvd.mpi_threads_supported()
        # make sure MPI is not re-initialized (hvd.init already did it).
        import mpi4py.rc
        mpi4py.rc.initialize = False
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        # check size and rank are synchronized
        assert num_worker == comm.Get_size()
        assert rank == comm.Get_rank()
    else:
        num_worker = 1
        rank = 0

    ModelClass = config.NETWORK_CLASS
    # The network constructor receives config.NETWORK entries with lower-cased keys.
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    if "TRAIN_VALIDATION_SAVING_SIZE" in config.DATASET.keys():
        use_train_validation_saving = config.DATASET.TRAIN_VALIDATION_SAVING_SIZE > 0
    else:
        use_train_validation_saving = False

    if use_train_validation_saving:
        # Best accuracy seen so far on the train_validation_saving split; with
        # that split enabled a checkpoint is written only when it improves.
        top_train_validation_saving_set_accuracy = 0

    train_dataset = setup_dataset(config, "train", rank)
    print("train dataset num:", train_dataset.num_per_epoch)

    if use_train_validation_saving:
        train_validation_saving_dataset = setup_dataset(
            config, "train_validation_saving", rank)
        print("train_validation_saving dataset num:",
              train_validation_saving_dataset.num_per_epoch)

    validation_dataset = setup_dataset(config, "validation", rank)
    print("validation dataset num:", validation_dataset.num_per_epoch)

    is_object_detection = ModelClass.__module__.startswith(
        "lmnet.networks.object_detection")

    graph = tf.Graph()
    with graph.as_default():
        if is_object_detection:
            model = ModelClass(
                classes=train_dataset.classes,
                num_max_boxes=train_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        elif ModelClass.__module__.startswith("lmnet.networks.segmentation"):
            model = ModelClass(
                classes=train_dataset.classes,
                label_colors=train_dataset.label_colors,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        else:
            model = ModelClass(
                classes=train_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        global_step = tf.Variable(0, name="global_step", trainable=False)
        is_training_placeholder = tf.placeholder(
            tf.bool, name="is_training_placeholder")

        images_placeholder, labels_placeholder = model.placeholderes()

        output = model.inference(images_placeholder, is_training_placeholder)
        if is_object_detection:
            loss = model.loss(output, labels_placeholder, is_training_placeholder)
        else:
            loss = model.loss(output, labels_placeholder)
        opt = model.optimizer(global_step)
        if config.IS_DISTRIBUTION:
            # add Horovod Distributed Optimizer
            opt = hvd.DistributedOptimizer(opt)
        train_op = model.train(loss, opt, global_step)
        metrics_ops_dict, metrics_update_op = model.metrics(
            output, labels_placeholder)
        # TODO(wakisaka): Deal with many networks.
        model.summary(output, labels_placeholder)

        summary_op = tf.summary.merge_all()

        metrics_summary_op, metrics_placeholders = executor.prepare_metrics(
            metrics_ops_dict)

        init_op = tf.global_variables_initializer()
        reset_metrics_op = tf.local_variables_initializer()
        if config.IS_DISTRIBUTION:
            # add Horovod broadcasting variables from rank 0 to all
            bcast_global_variables_op = hvd.broadcast_global_variables(0)

        if use_train_validation_saving:
            saver = tf.train.Saver(max_to_keep=1)
        else:
            saver = tf.train.Saver(max_to_keep=None)

        if config.IS_PRETRAIN:
            pretrain_var_list = [
                var for var in tf.global_variables()
                if var.name.startswith(tuple(config.PRETRAIN_VARS))
            ]
            print("pretrain_vars", [var.name for var in pretrain_var_list])
            pretrain_saver = tf.train.Saver(pretrain_var_list, name="pretrain_saver")

    if config.IS_DISTRIBUTION:
        # For distributed training: pin each process to its local-rank GPU.
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True, visible_device_list=str(hvd.local_rank())))
    else:
        # TODO(wakisaka): For debug, e.g.
        # tf.ConfigProto(gpu_options=tf.GPUOptions(
        #     allow_growth=True, per_process_gpu_memory_fraction=0.1))
        # or tf.ConfigProto(log_device_placement=True).
        session_config = tf.ConfigProto()
    # TODO(wakisaka): XLA JIT
    # session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    sess = tf.Session(graph=graph, config=session_config)
    sess.run([init_op, reset_metrics_op])

    def _feed_dict(dataset, is_training):
        """Pull the next batch from ``dataset`` and build the session feed."""
        images, labels = dataset.feed()
        return {
            is_training_placeholder: is_training,
            images_placeholder: images,
            labels_placeholder: labels,
        }

    def _metrics_summary():
        """Read the accumulated metric values and render them as a summary."""
        metrics_values = sess.run(list(metrics_ops_dict.values()))
        metrics_feed_dict = dict(zip(metrics_placeholders, metrics_values))
        metrics_summary, = sess.run([metrics_summary_op], feed_dict=metrics_feed_dict)
        return metrics_summary

    # Summary writers exist only on rank 0; every use below is guarded.
    if rank == 0:
        train_writer = tf.summary.FileWriter(
            environment.TENSORBOARD_DIR + "/train", sess.graph)
        if use_train_validation_saving:
            train_val_saving_writer = tf.summary.FileWriter(
                environment.TENSORBOARD_DIR + "/train_validation_saving")
        val_writer = tf.summary.FileWriter(environment.TENSORBOARD_DIR + "/validation")

    if config.IS_PRETRAIN:
        print("------- Load pretrain data ----------")
        pretrain_saver.restore(
            sess, os.path.join(config.PRETRAIN_DIR, config.PRETRAIN_FILE))
        sess.run(tf.assign(global_step, 0))

    # for recovery
    ckpt = tf.train.get_checkpoint_state(environment.CHECKPOINTS_DIR)
    if ckpt and ckpt.model_checkpoint_path:
        print("--------- Restore last checkpoint -------------")
        saver.restore(sess, ckpt.model_checkpoint_path)
        last_step = sess.run(global_step)
        # TODO(wakisaka): tensorflow v1.3 remain previous event log in tensorboard.
        # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/python/training/supervisor.py#L1072
        # Writers only exist on rank 0; other ranks would raise NameError here.
        if rank == 0:
            train_writer.add_session_log(
                SessionLog(status=SessionLog.START), global_step=last_step + 1)
            val_writer.add_session_log(
                SessionLog(status=SessionLog.START), global_step=last_step + 1)
        print("recovered. last step", last_step)

    if config.IS_DISTRIBUTION:
        # broadcast variables from rank 0 to all other processes
        sess.run(bcast_global_variables_op)
        # calculate step per epoch for each node
        train_num_per_epoch = train_dataset.num_per_epoch
        num_per_nodes = (train_num_per_epoch + num_worker - 1) // num_worker
        # At least 1, so `step % step_per_epoch` cannot divide by zero when a
        # shard is smaller than one batch.
        step_per_epoch = max(1, num_per_nodes // config.BATCH_SIZE)
        begin_index = (train_num_per_epoch * rank) // num_worker
        end_index = begin_index + num_per_nodes

    last_step = sess.run(global_step)

    # Calculate max steps. The priority of config.MAX_EPOCHS is higher than config.MAX_STEPS.
    if "MAX_EPOCHS" in config:
        max_steps = int(train_dataset.num_per_epoch / config.BATCH_SIZE * config.MAX_EPOCHS)
    else:
        max_steps = config.MAX_STEPS
    print("max_steps: {}".format(max_steps))

    for step in range(last_step, max_steps):
        print("step", step)

        if config.IS_DISTRIBUTION:
            # Re-shard at every epoch boundary, and also on the first step of
            # this run. Otherwise a run resumed mid-epoch would not apply this
            # rank's shard until the next boundary.
            if step == last_step or step % step_per_epoch == 0:
                # rank 0 shuffles; the same order is broadcast to every rank.
                indices = train_dataset.get_shuffle_index() if rank == 0 else None
                indices = comm.bcast(indices, 0)
                # update each dataset with its slice of the shuffled indices
                train_dataset.update_dataset(indices[begin_index:end_index])

        feed_dict = _feed_dict(train_dataset, True)

        is_summary_step = step == 0 or (step + 1) % config.SUMMARISE_STEPS == 0
        if is_summary_step and rank == 0:
            # Runtime statistics for develop: pass
            # options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) and a
            # tf.RunMetadata() to sess.run, then train_writer.add_run_metadata.
            sess.run(reset_metrics_op)
            _, summary, _ = sess.run(
                [train_op, summary_op, metrics_update_op], feed_dict=feed_dict)
            train_writer.add_summary(summary, step + 1)
            train_writer.add_summary(_metrics_summary(), step + 1)
        else:
            sess.run([train_op], feed_dict=feed_dict)

        to_be_saved = (step == 0 or (step + 1) == max_steps
                       or (step + 1) % config.SAVE_STEPS == 0)

        if to_be_saved and rank == 0:
            if use_train_validation_saving:
                sess.run(reset_metrics_op)
                train_validation_saving_step_size = int(math.ceil(
                    train_validation_saving_dataset.num_per_epoch / config.BATCH_SIZE))
                print("train_validation_saving_step_size",
                      train_validation_saving_step_size)

                for train_validation_saving_step in range(train_validation_saving_step_size):
                    print("train_validation_saving_step", train_validation_saving_step)
                    feed_dict = _feed_dict(train_validation_saving_dataset, False)
                    if train_validation_saving_step % config.SUMMARISE_STEPS == 0:
                        summary, _ = sess.run(
                            [summary_op, metrics_update_op], feed_dict=feed_dict)
                        train_val_saving_writer.add_summary(summary, step + 1)
                    else:
                        sess.run([metrics_update_op], feed_dict=feed_dict)

                train_val_saving_writer.add_summary(_metrics_summary(), step + 1)

                current_train_validation_saving_set_accuracy = sess.run(
                    metrics_ops_dict["accuracy"])
                if current_train_validation_saving_set_accuracy > top_train_validation_saving_set_accuracy:
                    top_train_validation_saving_set_accuracy = current_train_validation_saving_set_accuracy
                    print("New top train_validation_saving accuracy is: ",
                          top_train_validation_saving_set_accuracy)
                    _save_checkpoint(saver, sess, global_step, step)
            else:
                _save_checkpoint(saver, sess, global_step, step)

            if step == 0:
                # check create pb on only first step.
                minimal_graph = tf.graph_util.convert_variables_to_constants(
                    sess,
                    sess.graph.as_graph_def(add_shapes=True),
                    ["output"],
                )
                pb_name = "minimal_graph_with_shape_{}.pb".format(step + 1)
                pbtxt_name = "minimal_graph_with_shape_{}.pbtxt".format(step + 1)
                tf.train.write_graph(minimal_graph, environment.CHECKPOINTS_DIR,
                                     pb_name, as_text=False)
                tf.train.write_graph(minimal_graph, environment.CHECKPOINTS_DIR,
                                     pbtxt_name, as_text=True)

        if step == 0 or (step + 1) % config.TEST_STEPS == 0:
            # Every rank evaluates; only rank 0 writes summaries.
            sess.run(reset_metrics_op)
            test_step_size = int(math.ceil(validation_dataset.num_per_epoch / config.BATCH_SIZE))
            print("test_step_size", test_step_size)

            for test_step in range(test_step_size):
                print("test_step", test_step)
                feed_dict = _feed_dict(validation_dataset, False)
                if test_step % config.SUMMARISE_STEPS == 0:
                    summary, _ = sess.run(
                        [summary_op, metrics_update_op], feed_dict=feed_dict)
                    if rank == 0:
                        val_writer.add_summary(summary, step + 1)
                else:
                    sess.run([metrics_update_op], feed_dict=feed_dict)

            metrics_summary = _metrics_summary()
            if rank == 0:
                val_writer.add_summary(metrics_summary, step + 1)

    # training loop end.
    print("reach max step")
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os

import horovod.tensorflow as hvd
import mpi4py

# Let Horovod bring up MPI first, then tell mpi4py not to initialise it again.
# mpi4py reads ``rc.initialize`` only when ``mpi4py.MPI`` is first imported,
# so the flag must be set *before* that import to have any effect.
hvd.init()
mpi4py.rc.initialize = False
from mpi4py import MPI

comm = MPI.COMM_WORLD
assert hvd.mpi_threads_supported()
assert hvd.size() == comm.Get_size()
rank = comm.Get_rank()

# Only rank 0 owns the payload before the broadcast.
if rank == 0:
    s = 'abcdef'
    data = {'key1': [7, 2.72, 2 + 3j], 'key2': ('abc', 'xyz')}
else:
    s = None
    data = None
print('before broadcasting: process %d has %s' % (rank, data), s)

# Collective: every rank receives rank 0's objects.
data = comm.bcast(data, root=0)
s = comm.bcast(s, root=0)
import horovod.tensorflow as hvd

# Start Horovod (this brings up MPI underneath).
hvd.init()

# Report, then require, MPI multi-threading support.
print(hvd.mpi_threads_supported())
assert hvd.mpi_threads_supported()

from mpi4py import MPI

# Report Horovod's world size and check it against MPI's COMM_WORLD.
horovod_world_size = hvd.size()
print(horovod_world_size)
assert horovod_world_size == MPI.COMM_WORLD.Get_size()
def evaluate(
    *,
    flow_constructor,
    seed,
    restore_checkpoint,
    total_bs,
    iw_samples=1024,  # 4096 is too slow for ImageNet
    dtype=tf.float32,
    dataset='imagenet32',
    samples_filename='samples.png',
    extra_dims=3,
):
    """Evaluate a trained flow checkpoint on an ImageNet validation split.

    Rank 0 restores ``restore_checkpoint`` and broadcasts the weights to all
    Horovod ranks. The function then checks that every rank loaded the
    validation data in the same order. Finally it prints the validation
    bits/dim and an importance-weighted bits/dim estimate that uses
    ``iw_samples`` samples per validation example.

    Args:
        flow_constructor: zero-argument callable returning
            ``(dequant_flow, flow, posterior_flow)``.
        seed: base seed; each rank is seeded with ``rank + size * seed``.
        restore_checkpoint: checkpoint path (``~`` is expanded).
        total_bs: global batch size. It must be divisible by the number of
            ranks and must divide ``iw_samples``.
        iw_samples: number of importance samples per validation example.
        dtype: TensorFlow float dtype used for the placeholder and data.
        dataset: one of ``'imagenet32'``, ``'imagenet64'``, ``'imagenet64_5bit'``.
        samples_filename: output path used by ``run_sampling_only``, which is
            currently not called.
        extra_dims: not used by this function.
    """
    import horovod.tensorflow as hvd

    # Initialize Horovod
    hvd.init()
    # Verify that MPI multi-threading is supported.
    assert hvd.mpi_threads_supported()
    from mpi4py import MPI

    assert hvd.size() == MPI.COMM_WORLD.Get_size()
    is_root = hvd.rank() == 0

    def mpi_average(local_list):
        """Collective mean over all ranks: ``(mean, count)`` on root, ``(None, None)`` elsewhere."""
        local_list = list(map(float, local_list))
        sums = MPI.COMM_WORLD.gather(sum(local_list), root=0)
        counts = MPI.COMM_WORLD.gather(len(local_list), root=0)
        sum_counts = sum(counts) if is_root else None
        avg = (sum(sums) / sum_counts) if is_root else None
        return avg, sum_counts

    restore_checkpoint = os.path.expanduser(restore_checkpoint)

    # Seeding and logging setup
    seed_all(hvd.rank() + hvd.size() * seed)
    assert total_bs % hvd.size() == 0
    local_bs = total_bs // hvd.size()
    assert iw_samples % total_bs == 0

    if is_root:
        print('===== EVALUATING {} ({} IW samples) ====='.format(
            restore_checkpoint, iw_samples))

    # Load data. The dataset is held as an npy array in RAM, one copy per MPI
    # process (tf.Records would be more efficient; this keeps bits/dim exact
    # and avoids data-loading bugs). Only the validation split is used here,
    # so the much larger training split is not loaded.
    assert dataset in ['imagenet32', 'imagenet64', 'imagenet64_5bit']
    if is_root:
        print('Loading data')
    MPI.COMM_WORLD.Barrier()
    if dataset == 'imagenet32':
        data_val = np.load('../valid_32x32.npy')
        assert np.max(data_val) <= 255
        assert np.min(data_val) >= 0
        assert data_val.dtype == 'uint8'
    elif dataset == 'imagenet64':
        data_val = np.load('../valid_64x64.npy')
        assert np.max(data_val) <= 255
        assert np.min(data_val) >= 0
    elif dataset == 'imagenet64_5bit':
        # Rank 0 quantises to 5 bits and caches the result on disk.
        if is_root:
            data_val = np.load('../valid_64x64.npy')
            data_val = np.floor(data_val / 8.)
            data_val = data_val.astype('uint8')
            assert np.max(data_val) <= 31
            assert np.min(data_val) >= 0
            np.save('../valid_64x64_5bit.npy', data_val)
            del data_val
        # All ranks wait until the cached file exists, then load it.
        MPI.COMM_WORLD.Barrier()
        data_val = np.load('../valid_64x64_5bit.npy')
    data_val = data_val.astype(dtype.as_numpy_dtype)
    img_shp = list(data_val.shape[1:])
    if dataset == 'imagenet32':
        assert img_shp == [32, 32, 3]
    else:
        assert img_shp == [64, 64, 3]
    if is_root:
        print('Image shape:', img_shp)
    bpd_scale_factor = 1. / (np.log(2) * np.prod(img_shp))

    # Build graph
    if is_root:
        print('Building graph')
    dequant_flow, flow, posterior_flow = flow_constructor()
    x_sym = tf.placeholder(dtype, [local_bs] + img_shp)
    # This is a fake training graph, built only to mimic flow_training so that
    # the checkpoint's variables can be loaded by the saver.
    build_forward(
        x=x_sym,
        dequant_flow=dequant_flow,
        flow=flow,
        posterior_flow=posterior_flow,
        flow_kwargs=dict(init=False, ema=None, dropout_p=0, verbose=is_root),  # note dropout is 0: it doesn't matter
    )

    # EMA
    params = tf.trainable_variables()
    if is_root:
        print('Parameters', sum(np.prod(p.get_shape().as_list()) for p in params))
    # The decay is irrelevant here: ema.apply only creates the shadow variables
    # so the checkpoint's EMA weights can be restored; no EMA update is run.
    ema = tf.train.ExponentialMovingAverage(decay=0.9999999999999)
    ema.apply(params)

    # Validation and sampling (with EMA)
    if is_root:
        print('===== Validation graph =====')
    val_flow_kwargs = dict(init=False, dropout_p=0, ema=ema, verbose=is_root)
    val_loss_sym, val_logratio_sym, _ = build_forward(
        x=x_sym,
        dequant_flow=dequant_flow,
        flow=flow,
        posterior_flow=posterior_flow,
        flow_kwargs=val_flow_kwargs)
    allgathered_val_logratios_sym = hvd.allgather(val_logratio_sym)

    if is_root:
        print('===== Sampling graph =====')
    samples_sym, _ = flow.sample(local_bs, flow_kwargs=val_flow_kwargs)
    allgathered_samples_sym = hvd.allgather(tf.to_float(samples_sym))

    assert len(tf.trainable_variables()) == len(params)

    def run_iw_eval(sess):
        """Print the importance-weighted bits/dim over the whole validation set."""
        if is_root:
            print('Running IW eval with {} samples...'.format(iw_samples))
        # Go through one example at a time
        all_val_losses = []
        for i_example in (trange if is_root else range)(len(data_val)):
            # take this single example and tile it
            batch_x = np.tile(data_val[i_example, None, ...], (local_bs, 1, 1, 1))
            # repeatedly evaluate logd for the IWAE bound
            batch_logratios = np.concatenate([
                sess.run(allgathered_val_logratios_sym, {x_sym: batch_x})
                for _ in range(iw_samples // total_bs)
            ]).astype(np.float64)
            assert batch_logratios.shape == (iw_samples, )
            # log [1/n \sum_i exp(r_i)] = log [exp(-b) 1/n \sum_i exp(r_i + b)]
            #                           = -b + log [1/n \sum_i exp(r_i + b)]
            shift = batch_logratios.max()
            all_val_losses.append(
                -bpd_scale_factor *
                (shift + np.log(np.mean(np.exp(batch_logratios - shift)))))
            if i_example % 100 == 0 and is_root:
                print(i_example, np.mean(all_val_losses))
        if is_root:
            print(f'Final ({len(data_val)}):', np.mean(all_val_losses))

    def run_sampling_only(sess, *, prefix=dataset, dump_to_tensorboard=True, save_jpg=False):
        """Draw samples on all ranks; rank 0 tiles them into ``samples_filename``."""
        samples = sess.run(allgathered_samples_sym)
        if is_root:
            print('samples gathered from the session')
            if dataset == 'imagenet64_5bit':
                # Quantized values, so map 5-bit samples back to the 8-bit range.
                samples = np.floor(np.clip(samples, 0, 31))
                samples = samples * 8
                samples = samples.astype('uint8')
            import cv2
            samples = tile_imgs(np.floor(np.clip(samples, 0, 255)).astype('uint8'))
            cv2.imwrite(samples_filename, samples)

    def run_validation(sess):
        """Print the validation bits/dim, averaged over all ranks' shards."""
        data_val_shard = np.array_split(data_val, hvd.size(), axis=0)[hvd.rank()]
        shard_losses = [
            sess.run(val_loss_sym, {x_sym: val_batch})
            for val_batch, in iterbatches([data_val_shard],
                                          batch_size=local_bs,
                                          include_final_partial_batch=False)
        ]
        val_loss, total_count = mpi_average(shard_losses)
        if is_root:
            for k, v in [
                ('val_bpd', bpd_scale_factor * val_loss),
                ('num_val_examples', total_count * local_bs),
            ]:
                print(k, v)

    # Run
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(
        hvd.local_rank())  # Pin GPU to local rank (one GPU per process)
    with tf.Session(config=config) as sess:
        if is_root:
            print('Initializing')
        sess.run(tf.global_variables_initializer())
        # Restore from checkpoint on rank 0, then broadcast to every rank.
        if is_root:
            print('Restoring checkpoint:', restore_checkpoint)
            saver = tf.train.Saver()
            saver.restore(sess, restore_checkpoint)
            print('Broadcasting initial parameters')
        sess.run(hvd.broadcast_global_variables(0))
        sess.graph.finalize()

        # Make sure data is the same on all MPI processes
        tmp_inds = [0, 183, 3, 6, 20, 88]
        check_batch = np.ascontiguousarray(data_val[tmp_inds])
        gathered_batches = np.zeros(
            (hvd.size(), *check_batch.shape),
            check_batch.dtype) if is_root else None
        MPI.COMM_WORLD.Gather(check_batch, gathered_batches, root=0)
        if is_root:
            assert all(np.allclose(check_batch, b)
                       for b in gathered_batches), 'data must be in the same order!'
            print('data ordering ok')

        # Run validation
        run_validation(sess)
        run_iw_eval(sess)
def train(
    *,
    flow_constructor,
    logdir,
    lr_schedule,
    dropout_p,
    seed,
    init_bs,
    total_bs,
    ema_decay,
    steps_per_log,
    max_grad_norm,
    dtype=tf.float32,
    scale_loss=None,
    dataset='imagenet32',
    steps_per_samples=20000,
    steps_per_dump=5000,
    n_epochs=2,
    restore_checkpoint=None,
    dump_samples_to_tensorboard=True,
    save_jpg=True,
):
    """Train a flow model on ImageNet with Horovod data parallelism.

    Every rank iterates over the full (unsharded) training set with batch
    size ``total_bs // size``. Gradients are averaged by Horovod's
    ``DistributedOptimizer`` and an EMA of the parameters is maintained.

    On a fresh run, a data-dependent init is run on an ``init_bs`` batch. On a
    restore, the step counter and the learning-rate schedule both resume from
    the checkpoint's step. Rank 0 logs every ``steps_per_log`` steps, and
    samples and validates every ``steps_per_samples`` steps. It saves a
    checkpoint every ``steps_per_dump`` steps.

    Args:
        flow_constructor: zero-argument callable returning
            ``(dequant_flow, flow, posterior_flow)``.
        logdir: base log directory; ``_mpi<size>_<timestamp>`` is appended.
        lr_schedule: callable mapping the LR step counter to a learning rate.
        dropout_p: dropout probability for the training graph.
        seed: base seed; each rank is seeded with ``rank + size * seed``.
        init_bs: batch size for the data-dependent init.
        total_bs: global batch size, divisible by the number of ranks.
        ema_decay: decay of the parameter EMA used for validation/sampling.
        steps_per_log: logging interval, also the loss-history window.
        max_grad_norm: global-norm clipping threshold, or ``None``.
        dtype: TensorFlow float dtype used for placeholders and data.
        scale_loss: optional loss-scaling factor (gradients are unscaled).
        dataset: one of ``'imagenet32'``, ``'imagenet64'``, ``'imagenet64_5bit'``.
        steps_per_samples: sampling + validation interval.
        steps_per_dump: checkpoint interval.
        n_epochs: number of passes over the training data.
        restore_checkpoint: checkpoint path ending in ``-<step>``, or ``None``.
        dump_samples_to_tensorboard: whether samples are written to TensorBoard.
        save_jpg: forwarded to ``run_sampling``; its JPEG code is commented out.
    """
    import horovod.tensorflow as hvd

    # Initialize Horovod
    hvd.init()
    # Verify that MPI multi-threading is supported.
    assert hvd.mpi_threads_supported()
    from mpi4py import MPI

    assert hvd.size() == MPI.COMM_WORLD.Get_size()
    is_root = hvd.rank() == 0

    def mpi_average(local_list):
        """Collective mean over all ranks: ``(mean, count)`` on root, ``(None, None)`` elsewhere."""
        local_list = list(map(float, local_list))
        sums = MPI.COMM_WORLD.gather(sum(local_list), root=0)
        counts = MPI.COMM_WORLD.gather(len(local_list), root=0)
        sum_counts = sum(counts) if is_root else None
        avg = (sum(sums) / sum_counts) if is_root else None
        return avg, sum_counts

    # Seeding and logging setup
    seed_all(hvd.rank() + hvd.size() * seed)
    assert total_bs % hvd.size() == 0
    local_bs = total_bs // hvd.size()

    # Only rank 0 creates directories and a logger; `logger` stays None elsewhere.
    logger = None
    logdir = '{}_mpi{}_{}'.format(os.path.expanduser(logdir), hvd.size(), time.time())
    checkpointdir = os.path.join(logdir, 'checkpoints')
    profiledir = os.path.join(logdir, 'profiling')
    if is_root:
        print('Floating point format:', dtype)
        pprint(locals())
        os.makedirs(logdir)
        os.makedirs(checkpointdir)
        os.makedirs(profiledir)
        logger = TensorBoardOutput(logdir)

    # Load data. Each dataset is held as an npy array in RAM, one copy per MPI
    # process. This isn't efficient (tf.Records read from disk would be
    # better), but it keeps reported bits/dim exact and avoids data-loading
    # bugs. imagenet32 is small enough for 8 processes in about 40GB RAM. For
    # imagenet64 without enough CPU RAM for 8 processes, use fewer processes
    # and adjust the batch-size / model-size tradeoff accordingly.
    assert dataset in ['imagenet32', 'imagenet64', 'imagenet64_5bit']
    if is_root:
        print('Loading data')
    MPI.COMM_WORLD.Barrier()
    if dataset == 'imagenet32':
        data_train = np.load('../train_32x32.npy')
        data_val = np.load('../valid_32x32.npy')
        assert data_train.dtype == 'uint8'
        assert np.max(data_train) <= 255
        assert np.min(data_train) >= 0
        assert np.max(data_val) <= 255
        assert np.min(data_val) >= 0
        assert data_val.dtype == 'uint8'
    elif dataset == 'imagenet64':
        data_train = np.load('../train_64x64.npy')
        data_val = np.load('../valid_64x64.npy')
        assert data_train.dtype == 'uint8'
        assert np.max(data_train) <= 255
        assert np.min(data_train) >= 0
        assert np.max(data_val) <= 255
        assert np.min(data_val) >= 0
    elif dataset == 'imagenet64_5bit':
        # Similar loading as above, quantized to 5 bits. Rank 0 quantizes and
        # caches both splits on disk.
        if is_root:
            data_train = np.load('../train_64x64.npy')
            data_train = np.floor(data_train / 8.)
            data_train = data_train.astype('uint8')
            assert np.max(data_train) <= 31
            assert np.min(data_train) >= 0
            np.save('../train_64x64_5bit.npy', data_train)
            del data_train
            data_val = np.load('../valid_64x64.npy')
            data_val = np.floor(data_val / 8.)
            data_val = data_val.astype('uint8')
            assert np.max(data_val) <= 31
            assert np.min(data_val) >= 0
            np.save('../valid_64x64_5bit.npy', data_val)
            del data_val
        # All ranks wait for the cached files, then load them.
        MPI.COMM_WORLD.Barrier()
        data_train = np.load('../train_64x64_5bit.npy')
        data_val = np.load('../valid_64x64_5bit.npy')
    data_train = data_train.astype(dtype.as_numpy_dtype)
    data_val = data_val.astype(dtype.as_numpy_dtype)
    img_shp = list(data_train.shape[1:])
    if dataset == 'imagenet32':
        assert img_shp == [32, 32, 3]
    else:
        assert img_shp == [64, 64, 3]
    if is_root:
        print('Training data: {}, Validation data: {}'.format(
            data_train.shape[0], data_val.shape[0]))
        print('Image shape:', img_shp)
    bpd_scale_factor = 1. / (np.log(2) * np.prod(img_shp))

    # Build graph
    if is_root:
        print('Building graph')
    dequant_flow, flow, posterior_flow = flow_constructor()
    # Data-dependent init (only needed when not restoring)
    if restore_checkpoint is None:
        if is_root:
            print('===== Init graph =====')
        x_init_sym = tf.placeholder(dtype, [init_bs] + img_shp)
        init_syms, _ = build_forward(x=x_init_sym,
                                     dequant_flow=dequant_flow,
                                     flow=flow,
                                     posterior_flow=posterior_flow,
                                     flow_kwargs=dict(init=True,
                                                      dropout_p=dropout_p,
                                                      verbose=is_root))
    # Training
    if is_root:
        print('===== Training graph =====')
    x_sym = tf.placeholder(dtype, [local_bs] + img_shp)
    loss_sym, _ = build_forward(x=x_sym,
                                dequant_flow=dequant_flow,
                                flow=flow,
                                posterior_flow=posterior_flow,
                                flow_kwargs=dict(dropout_p=dropout_p, verbose=is_root))

    # EMA
    params = tf.trainable_variables()
    if is_root:
        print('Parameters', sum(np.prod(p.get_shape().as_list()) for p in params))
    ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
    maintain_averages_op = tf.group(ema.apply(params))
    # Op for setting the ema params to the current non-ema params (for use after data-dependent init)
    name2var = {v.name: v for v in tf.global_variables()}
    copy_params_to_ema = tf.group([
        name2var[p.name.replace(':0', '') + '/ExponentialMovingAverage:0'].assign(p)
        for p in params
    ])

    # Validation and sampling (with EMA)
    if is_root:
        print('===== Validation graph =====')
    val_loss_sym, _ = build_forward(x=x_sym,
                                    dequant_flow=dequant_flow,
                                    flow=flow,
                                    posterior_flow=posterior_flow,
                                    flow_kwargs=dict(dropout_p=0, ema=ema, verbose=is_root))
    if is_root:
        print('===== Sampling graph =====')
    samples_sym, _ = flow.sample(local_bs,
                                 flow_kwargs=dict(dropout_p=0., ema=ema, verbose=is_root))
    allgathered_samples_sym = hvd.allgather(tf.to_float(samples_sym))

    assert len(tf.trainable_variables()) == len(params)

    def run_sampling(sess, i_step, *, prefix=dataset, dump_to_tensorboard=True, save_jpg=False):
        """Draw samples on all ranks (collective); rank 0 logs them."""
        samples = sess.run(allgathered_samples_sym)
        if is_root:
            print('samples gathered from the session')
            if dataset == 'imagenet64_5bit':
                # Quantized values, so map 5-bit samples back to the 8-bit range.
                samples = np.floor(np.clip(samples, 0, 31))
                samples = samples * 8
                samples = samples.astype('uint8')
            # np.save('samples_' + prefix + '.npy', samples)
            # if save_jpg:
            #     samples = tile_imgs(np.floor(np.clip(samples, 0, 255)).astype('uint8'))
            #     cv2.imwrite('samples_' + prefix + '_' + str(i_step) + '.jpg', samples)
            if dump_to_tensorboard:
                # Turn this off if TensorBoard crashes on sample dumps. Note the
                # np.save above is commented out, so no .npy copy is written.
                logger.writekvs(
                    [('samples', tile_imgs(np.clip(samples, 0, 255).astype(np.uint8)))],
                    i_step)

    def run_validation(sess, i_step):
        """Log validation bits/dim (EMA weights), averaged over all ranks' shards."""
        data_val_shard = np.array_split(data_val, hvd.size(), axis=0)[hvd.rank()]
        shard_losses = np.concatenate([
            sess.run([val_loss_sym], {x_sym: val_batch})
            for val_batch, in iterbatches([data_val_shard],
                                          batch_size=local_bs,
                                          include_final_partial_batch=False)
        ])
        val_loss, total_count = mpi_average(shard_losses)
        if is_root:
            logger.writekvs([('val_bpd', bpd_scale_factor * val_loss),
                             ('num_val_examples', total_count * local_bs)], i_step)

    # Optimization
    lr_sym = tf.placeholder(dtype, [], 'lr')
    optimizer = hvd.DistributedOptimizer(tf.train.AdamOptimizer(lr_sym))

    if scale_loss is None:
        grads_and_vars = optimizer.compute_gradients(loss_sym, var_list=params)
    else:
        # Scale the loss up, then divide the gradients back down.
        grads_and_vars = [(g / scale_loss, v)
                          for (g, v) in optimizer.compute_gradients(
                              loss_sym * scale_loss, var_list=params)]

    if max_grad_norm is not None:
        clipped_grads, grad_norm_sym = tf.clip_by_global_norm(
            [g for (g, _) in grads_and_vars], max_grad_norm)
        grads_and_vars = [
            (cg, v) for (cg, (_, v)) in zip(clipped_grads, grads_and_vars)
        ]
    else:
        grad_norm_sym = tf.constant(0.)
    # One training step applies the gradients and updates the EMA.
    opt_sym = tf.group(optimizer.apply_gradients(grads_and_vars), maintain_averages_op)

    def loop(sess: tf.Session):
        """Initialise or restore the variables, then run the training epochs."""
        i_step = 0
        i_step_lr = 0
        if is_root:
            print('Initializing')
        sess.run(tf.global_variables_initializer())

        if restore_checkpoint is not None:
            # Restore from the checkpoint given by the launcher; the step is
            # parsed from its "-<step>" suffix. Only rank 0 holds a saver.
            restore_step = int(restore_checkpoint.split('-')[-1])
            if is_root:
                saver = tf.train.Saver()
                print('Restoring checkpoint:', restore_checkpoint)
                print('Restoring from step:', restore_step)
                saver.restore(sess, restore_checkpoint)
                print('Loaded checkpoint')
            else:
                saver = None
            i_step = restore_step
            # The LR schedule resumes at the checkpoint step (no warm-up), which
            # suits restarts after a Horovod crash / machine shutdown. If
            # training stopped because of NaN/Inf, warming up again from a
            # recent working checkpoint is recommended: use `i_step_lr = 0`.
            i_step_lr = restore_step
        else:
            if is_root:
                print('Data dependent init')
            sess.run(
                init_syms, {
                    x_init_sym: data_train[np.random.randint(0, data_train.shape[0], init_bs)]
                })
            sess.run(copy_params_to_ema)
            saver = tf.train.Saver() if is_root else None
        if is_root:
            print('Broadcasting initial parameters')
        sess.run(hvd.broadcast_global_variables(0))
        sess.graph.finalize()

        if is_root:
            print('Training')
            print('Parameters(M)',
                  sum(np.prod(p.get_shape().as_list()) for p in params) / 1024. / 1024.)

        loss_hist = deque(maxlen=steps_per_log)
        # 2 epochs are sufficient to see good results on Imagenet. After 2
        # epochs gains are marginal, but important for good bits/dim.
        for i_epoch in range(n_epochs):
            epoch_start_t = time.time()
            for i_epoch_step, (batch, ) in enumerate(
                    iterbatches(  # non-sharded: each gpu goes through the whole dataset
                        [data_train],
                        batch_size=local_bs,
                        include_final_partial_batch=False,
                    )):
                lr = lr_schedule(i_step_lr)
                loss, _ = sess.run([loss_sym, opt_sym], {x_sym: batch, lr_sym: lr})
                loss_hist.append(loss)

                # Restart the clock after the very first step so that graph
                # warm-up is excluded from the throughput numbers.
                if i_epoch == i_epoch_step == 0:
                    epoch_start_t = time.time()

                if i_step % steps_per_log == 0:
                    # Collective: every rank contributes its recent mean loss.
                    loss_hist_means = MPI.COMM_WORLD.gather(float(np.mean(loss_hist)), root=0)
                    steps_per_sec = (i_epoch_step + 1) / (time.time() - epoch_start_t)
                    if is_root:
                        kvs = [
                            ('iter', i_step),
                            ('epoch', i_epoch + i_epoch_step * local_bs / data_train.shape[0]),  # epoch for this gpu
                            ('bpd', float(np.mean(loss_hist_means) * bpd_scale_factor)),
                            ('lr', float(lr)),
                            ('fps', steps_per_sec * total_bs),  # fps calculated over all gpus (this epoch)
                            ('sps', steps_per_sec),
                        ]
                        logger.writekvs(kvs, i_step)

                # Validation on Imagenet could be skipped: the val set is big and
                # overfitting is extremely hard, so validating the final
                # checkpoint is fine and saves wall-clock time. Alternatively set
                # steps_per_samples high in the launcher to find a balance.
                if i_step > 0 and i_step % steps_per_samples == 0 and i_step_lr > 0:
                    run_sampling(sess,
                                 i_step=i_step,
                                 dump_to_tensorboard=dump_samples_to_tensorboard,
                                 save_jpg=save_jpg)
                    print('Run Validation...')
                    run_validation(sess, i_step)

                if i_step % steps_per_dump == 0 and i_step > 0 and i_step_lr > 0:
                    if saver is not None:
                        saver.save(sess, os.path.join(checkpointdir, 'model'), global_step=i_step)

                i_step += 1
                i_step_lr += 1
            # End of epoch

    # Train
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(
        hvd.local_rank())  # Pin GPU to local rank (one GPU per process)
    with tf.Session(config=config) as sess:
        loop(sess)
def _ceil_div(numerator, denominator):
    """Return ``ceil(numerator / denominator)`` using integer arithmetic only."""
    return -(-numerator // denominator)


def _default_valid_names(valid_names, infer_names):
    """Resolve the column/metric names passed to ``evaluate``.

    Explicit ``valid_names`` win.  Otherwise, when ``infer_names`` is non-empty,
    the result is ``infer_names[0]``, then every remaining name suffixed with
    ``'_y'``, then the remaining names unchanged.  Returns ``None`` when neither
    is available.
    """
    if valid_names is not None:
        return valid_names
    if not infer_names:
        return None
    rest = list(infer_names[1:])
    return [infer_names[0]] + [name + '_y' for name in rest] + rest


def _torch_checkpoint_state(model, optimizer, step, num_steps_per_epoch):
    """Build the dict written to ``*.pyt`` files on the torch path."""
    # Unwrap torch.nn.DataParallel (applied in `train`) so the saved keys do
    # not depend on whether the model was wrapped.
    net = model.module if hasattr(model, 'module') else model
    return {
        'epoch': int(step / num_steps_per_epoch),
        'step': step,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict(),
    }


def train(model,
          loss_fn,
          Dataset=None,
          dataset=None,
          valid_dataset=None,
          valid_dataset2=None,
          test_dataset=None,
          evaluate_fn=None,
          inference_fn=None,
          eval_fn=None,
          write_valid=True,
          valid_names=None,
          infer_names=None,
          infer_debug_names=None,
          valid_write_fn=None,
          infer_write_fn=None,
          valid_suffix='.valid',
          infer_suffix='.infer',
          write_streaming=False,
          optimizer=None,
          param_groups=None,
          init_fn=None,
          sep=','):
    """Run the eager-TensorFlow or PyTorch training loop configured by ``FLAGS``.

    The backend is chosen by ``FLAGS.torch``.  Horovod is enabled when the
    ``OMPI_COMM_WORLD_RANK`` environment variable is present.  Depending on
    ``FLAGS.work_mode`` and environment variables (``EVFIRST``, ``METRIC``,
    ``TEST``, ``INFER``, ``SHOW``), the function may run a one-off
    validation/inference and terminate the process via ``sys.exit(0)``.  It
    also exits once ``num_epochs`` epochs have been trained.

    Args:
      model: Keras-style model (TF path) or ``torch.nn.Module`` (torch path).
      loss_fn: Called as ``loss_fn(y, y_)`` on the TF path and as
        ``loss_fn(y_, y)`` on the torch path (``y_`` is the model output).
      Dataset: Dataset class.  It is instantiated as ``Dataset('train')`` when
        ``dataset`` is not given, and must provide ``make_batch`` and
        ``num_examples_per_epoch``.
      dataset: Ready-made training dataset (only supported with
        ``FLAGS.torch_only``).  ``len(dataset)`` is used as the example count.
      valid_dataset, valid_dataset2, test_dataset: Optional pre-built datasets.
        They are built from ``FLAGS.valid_input`` / ``FLAGS.test_input`` (or the
        held-out fold) when omitted.  ``valid_dataset2`` is iterated with
        ``next`` to get one batch for the periodic validation loss.
      evaluate_fn / eval_fn: Custom evaluation hook, or metric function used
        with the module-level ``evaluate`` helper.
      inference_fn: Custom inference hook.  Otherwise ``inference`` is used.
      write_valid: When true, ``evaluate`` receives a checkpoint path
        (presumably so it writes predictions — confirm against ``evaluate``).
      valid_names, infer_names, infer_debug_names, valid_write_fn,
        infer_write_fn, valid_suffix, infer_suffix, write_streaming, sep:
        Forwarded to ``evaluate`` / ``inference``.
      optimizer: Optional pre-built optimizer.
      param_groups: Torch parameter groups for the default Adamax optimizer.
      init_fn: Currently unused.
    """
    import sys

    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ
    # `hvd` is a local name of this function (it is bound by an import below),
    # so it must be bound on every path that later calls hvd.size()/hvd.rank().
    hvd = None

    if Dataset is None:
        assert dataset
    logging.info('Dataset', Dataset, 'dataset', dataset, 'valid_dataset',
                 valid_dataset, 'test_dataset', test_dataset, loss_fn)

    if FLAGS.torch:
        torch.manual_seed(FLAGS.seed or 0)
        if torch.cuda.device_count():
            torch.cuda.manual_seed(FLAGS.seed or 0)
        if use_horovod:
            import horovod.torch as hvd
            from mpi4py import MPI
            hvd.init()
            assert hvd.mpi_threads_supported()
            assert hvd.size() == MPI.COMM_WORLD.Get_size()
            torch.cuda.set_device(hvd.local_rank())
        elif torch.cuda.device_count() > 1:
            # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
            model = torch.nn.DataParallel(model)
        model.to(device)
    elif use_horovod:
        import horovod.tensorflow as hvd
        hvd.init()  # no-op if already initialized by the launcher

    inputs = sorted(gezi.list_files(FLAGS.train_input))
    all_inputs = inputs

    batch_size = melt.batch_size()
    num_gpus = melt.num_gpus()
    # Evaluation may use its own batch size; fall back to the training one.
    eval_batch_size = FLAGS.eval_batch_size or batch_size

    if dataset is None:
        if FLAGS.fold is not None:
            # Hold out the current fold's shard from training.
            fold_suffixes = ('%d.record' % FLAGS.fold, '%d.tfrecord' % FLAGS.fold)
            inputs = [x for x in inputs if not x.endswith(fold_suffixes)]
        logging.info('inputs', len(inputs), inputs[:100])

    num_folds = FLAGS.num_folds or len(inputs) + 1

    if dataset is None:
        dataset = Dataset('train')
        assert len(inputs) > 0
        train_dataset = dataset.make_batch(batch_size, inputs,
                                           simple_parse=FLAGS.simple_parse)
        num_examples = dataset.num_examples_per_epoch('train')
    else:
        assert FLAGS.torch_only, 'only torch only currently support input dataset not Dataset class type, because we do not have len function there'
        train_dataset = dataset
        num_examples = len(train_dataset)
    num_all_examples = num_examples

    # ---------------------------------------------------------------- validation data
    if valid_dataset is None:
        valid_inputs = None
        if FLAGS.valid_input:
            valid_inputs = gezi.list_files(FLAGS.valid_input)
        elif FLAGS.fold is not None:
            # Files dropped from training form the validation set; FLAGS.test_aug
            # selects the augmented ('aug') or non-augmented copies.
            train_input_set = set(inputs)
            use_aug = bool(FLAGS.test_aug)
            valid_inputs = [
                x for x in all_inputs
                if ('aug' in x) == use_aug and x not in train_input_set
            ]
        logging.info('valid_inputs', valid_inputs)

    num_valid_examples = None
    if valid_dataset is not None:
        num_valid_examples = len(valid_dataset)
    elif valid_inputs:
        valid_dataset = dataset.make_batch(eval_batch_size, valid_inputs,
                                           subset='valid',
                                           hvd_shard=FLAGS.horovod_eval)
        valid_dataset2 = dataset.make_batch(batch_size, valid_inputs,
                                            subset='valid', repeat=True,
                                            initializable=False,
                                            hvd_shard=False)
    else:
        valid_dataset2 = None
    # A single persistent iterator, so each periodic valid-loss probe sees a
    # new batch; this covers a caller-supplied valid_dataset2 as well.
    valid_dataset2_iter = iter(valid_dataset2) if valid_dataset2 else None

    if num_examples:
        if FLAGS.fold is not None:
            num_examples = int(num_examples * (num_folds - 1) / num_folds)
        num_steps_per_epoch = _ceil_div(num_examples, batch_size)
    else:
        num_steps_per_epoch = None
    logging.info('num_train_examples:', num_examples)
    if use_horovod and num_examples:
        num_steps_per_epoch = _ceil_div(num_examples, batch_size * hvd.size())

    if num_valid_examples is None:
        if FLAGS.valid_input:
            num_valid_examples = dataset.num_examples_per_epoch('valid')
        elif FLAGS.fold is not None and num_examples:
            num_valid_examples = int(num_all_examples * (1 / num_folds))
    num_valid_steps_per_epoch = (_ceil_div(num_valid_examples, eval_batch_size)
                                 if num_valid_examples else None)
    if use_horovod and FLAGS.horovod_eval and num_valid_examples:
        num_valid_steps_per_epoch = _ceil_div(num_valid_examples,
                                              eval_batch_size * hvd.size())
    logging.info('num_valid_examples:', num_valid_examples)

    # ---------------------------------------------------------------- test data
    if test_dataset is None:
        if FLAGS.test_input:
            test_inputs = gezi.list_files(FLAGS.test_input)
            logging.info('test_inputs', test_inputs)
        else:
            test_inputs = None

    num_test_examples = None
    if test_dataset is not None:
        num_test_examples = len(test_dataset)
    elif test_inputs:
        test_dataset = dataset.make_batch(eval_batch_size, test_inputs,
                                          subset='test')
        num_test_examples = dataset.num_examples_per_epoch('test')
    num_test_steps_per_epoch = (_ceil_div(num_test_examples, eval_batch_size)
                                if num_test_examples else None)
    if use_horovod and FLAGS.horovod_eval and num_test_examples:
        num_test_steps_per_epoch = _ceil_div(num_test_examples,
                                             eval_batch_size * hvd.size())
    logging.info('num_test_examples:', num_test_examples)

    # ---------------------------------------------------------------- summaries / variables
    # tf.summary.FileWriter is not compatible with eager execution, hence
    # tf.contrib.summary.  All three writers currently share FLAGS.log_dir.
    summary = tf.contrib.summary
    writer = summary.create_file_writer(FLAGS.log_dir)
    writer_train = summary.create_file_writer(FLAGS.log_dir)
    writer_valid = summary.create_file_writer(FLAGS.log_dir)

    global_step = tf.train.get_or_create_global_step()
    learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")
    tf.add_to_collection('learning_rate', learning_rate)
    learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
    collected_lr_weights = tf.get_collection('learning_rate_weights')
    learning_rate_weights = collected_lr_weights[-1] if collected_lr_weights else None

    # ckpt2/ receives checkpoints every FLAGS.save_interval_epochs (which may
    # be fractional, e.g. 0.1), useful for large datasets.
    ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt')
    ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2')
    for path in (ckpt_dir, ckpt_dir2):
        os.makedirs(path, exist_ok=True)

    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    if not latest_checkpoint:
        latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir2)
    logging.info('Latest checkpoint:', latest_checkpoint)

    # FLAGS.model_dir may itself name a checkpoint prefix.
    if os.path.exists(FLAGS.model_dir + '.index'):
        latest_checkpoint = FLAGS.model_dir
    if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode:
        latest_checkpoint = FLAGS.model_dir

    checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt')

    # Variables tracked by tf.train.Checkpoint on both backends.
    ckpt_vars = dict(learning_rate=learning_rate,
                     learning_rate_weight=learning_rate_weight,
                     global_step=global_step)
    if learning_rate_weights is not None:
        ckpt_vars['learning_rate_weights'] = learning_rate_weights

    # True only for lele's schedule-driven torch optimizers (noam / bert).
    is_dynamic_opt = False
    if not FLAGS.torch:
        try:
            optimizer = optimizer or melt.get_optimizer(FLAGS.optimizer)(learning_rate)
        except Exception:
            logging.warning(f'Fail to using {FLAGS.optimizer} use adam instead')
            optimizer = melt.get_optimizer('adam')(learning_rate)

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer,
                                         **ckpt_vars)
        checkpoint.restore(latest_checkpoint)
        checkpoint2 = copy.deepcopy(checkpoint)

        start_epoch = (int(latest_checkpoint.split('-')[-1])
                       if latest_checkpoint and 'ckpt' in latest_checkpoint else 0)
    else:
        # https://github.com/horovod/horovod/blob/master/examples/pytorch_mnist.py
        if optimizer is None:
            import lele
            is_dynamic_opt = True
            if FLAGS.optimizer == 'noam':
                base_optimizer = torch.optim.Adamax(model.parameters(), lr=0)
                if use_horovod:
                    base_optimizer = hvd.DistributedOptimizer(base_optimizer)
                optimizer = lele.training.optimizers.NoamOpt(128, 2, 4000,
                                                             base_optimizer)
            elif FLAGS.optimizer == 'bert':
                num_train_steps = int(
                    num_steps_per_epoch * (FLAGS.num_decay_epochs or FLAGS.num_epochs))
                if FLAGS.warmup_steps and use_horovod:
                    FLAGS.warmup_steps = max(int(FLAGS.warmup_steps / hvd.size()), 1)
                num_warmup_steps = (FLAGS.warmup_steps
                                    or int(num_steps_per_epoch * FLAGS.warmup_epochs)
                                    or int(num_train_steps * FLAGS.warmup_proportion))
                logging.info('num_train_steps', num_train_steps, 'num_warmup_steps',
                             num_warmup_steps, 'warmup_proportion',
                             FLAGS.warmup_proportion)
                base_optimizer = torch.optim.Adamax(model.parameters(), lr=0)
                if use_horovod:
                    base_optimizer = hvd.DistributedOptimizer(base_optimizer)
                optimizer = lele.training.optimizers.BertOpt(
                    FLAGS.learning_rate, FLAGS.min_learning_rate, num_train_steps,
                    num_warmup_steps, base_optimizer)
            else:
                is_dynamic_opt = False
                optimizer = torch.optim.Adamax(
                    param_groups if param_groups else model.parameters(),
                    lr=FLAGS.learning_rate)
                if use_horovod:
                    optimizer = hvd.DistributedOptimizer(optimizer)

        start_epoch = 0
        default_torch_path = os.path.join(FLAGS.model_dir, 'latest.pyt')
        latest_path = (latest_checkpoint + '.pyt'
                       if latest_checkpoint else default_torch_path)
        if not os.path.exists(latest_path):
            latest_path = default_torch_path
        if os.path.exists(latest_path):
            logging.info('loading torch model from', latest_path)
            torch_state = torch.load(latest_path)
            if not FLAGS.torch_finetune:
                start_epoch = torch_state['epoch']
                global_step.assign(torch_state['step'] + 1)
            load_torch_model(model, latest_path)
            if FLAGS.torch_load_optimizer:
                optimizer.load_state_dict(torch_state['optimizer'])

        # TODO: restoring learning_rate this way means a restart cannot change it.
        checkpoint = tf.train.Checkpoint(**ckpt_vars)
        try:
            checkpoint.restore(latest_checkpoint)
        except Exception:
            # Best effort: training proceeds from the torch state alone.
            logging.warning(f'Fail to restore tf checkpoint {latest_checkpoint}')
        # Always bound: it is saved every FLAGS.save_interval_epochs below.
        checkpoint2 = copy.deepcopy(checkpoint)

    if FLAGS.torch and is_dynamic_opt:
        # Resume lele's schedule at the restored step (assumes `_step` drives
        # optimizer.rate() — confirm in lele.training.optimizers).
        optimizer._step = global_step.numpy()

    logging.info('optimizer:', optimizer)

    if FLAGS.torch_lr:
        learning_rate.assign(optimizer.rate(1))
    if FLAGS.torch:
        learning_rate.assign(optimizer.param_groups[0]['lr'])
        logging.info('learning rate got from pytorch latest.py as',
                     learning_rate.numpy())

    learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor)
    if learning_rate_weights is not None:
        learning_rate_weights.assign(learning_rate_weights *
                                     FLAGS.learning_rate_start_factor)

    # TODO: fractional epochs (e.g. 0.1) are not supported here.
    num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024

    # ---------------------------------------------------------------- optional first validation
    will_valid = (valid_dataset and not FLAGS.work_mode == 'test'
                  and 'SHOW' not in os.environ and 'QUICK' not in os.environ)
    if global_step.numpy() == 0:
        will_valid = False
    if gezi.get_env('EVFIRST') == '1':
        will_valid = True
    if gezi.get_env('EVFIRST') == '0':
        will_valid = False

    if will_valid:
        logging.info('----------valid')
        if hasattr(model, 'eval'):
            model.eval()
        names = None
        if evaluate_fn is not None:
            vals, names = evaluate_fn(model, valid_dataset,
                                      tf.train.latest_checkpoint(ckpt_dir),
                                      num_valid_steps_per_epoch)
        elif eval_fn:
            model_path = latest_checkpoint if write_valid else None
            names = _default_valid_names(valid_names, infer_names)
            logging.info('model_path:', model_path, 'model_dir:', FLAGS.model_dir)
            vals, names = evaluate(model, valid_dataset, eval_fn, model_path,
                                   names, valid_write_fn, write_streaming,
                                   num_valid_steps_per_epoch, num_valid_examples,
                                   suffix=valid_suffix, sep=sep)
        if names:
            logging.info2(
                'epoch:%.2f/%d step:%d' % (global_step.numpy() / num_steps_per_epoch,
                                           num_epochs, global_step.numpy()),
                ['%s:%.4f' % (name, val) for name, val in zip(names, vals)])

    if FLAGS.work_mode == 'valid' or gezi.get_env('METRIC') == '1':
        sys.exit(0)

    if ('test' in FLAGS.work_mode or gezi.get_env('TEST') == '1'
            or gezi.get_env('INFER') == '1'):
        logging.info('--------test/inference')
        if test_dataset:
            if hasattr(model, 'eval'):
                model.eval()
            if inference_fn is None:
                assert latest_checkpoint
                inference(model, test_dataset, latest_checkpoint, infer_names,
                          infer_debug_names, infer_write_fn, write_streaming,
                          num_test_steps_per_epoch, num_test_examples,
                          suffix=infer_suffix, sep=sep)
            else:
                inference_fn(model, test_dataset,
                             tf.train.latest_checkpoint(ckpt_dir),
                             num_test_steps_per_epoch)
        sys.exit(0)

    if 'SHOW' in os.environ:
        num_epochs = start_epoch + 1

    class PytObj(object):
        """Wrap a plain value so it exposes ``.numpy()`` like a TF tensor."""

        def __init__(self, x):
            self.x = x

        def numpy(self):
            return self.x

    class PytMean(object):
        """Running mean mimicking ``tfe.metrics.Mean`` for torch losses.

        The accumulator resets on the first update after ``result()`` was read.
        """

        def __init__(self):
            self._val = 0.
            self.count = 0
            self.is_call = True

        def clear(self):
            self._val = 0
            self.count = 0

        def __call__(self, val):
            if not self.is_call:
                self.clear()
                self.is_call = True
            self._val += val.item()
            self.count += 1

        def result(self):
            if self.is_call:
                self.is_call = False
            val = self._val / self.count if self.count else 0
            return PytObj(val)

    Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean

    num_insts = 0

    if FLAGS.learning_rate_decay_factor > 0:
        # If finetuning changes batch size, set num_steps_per_decay explicitly:
        # global_step / decay_steps no longer maps onto epochs.
        assert FLAGS.num_steps_per_decay or (
            FLAGS.num_epochs_per_decay and num_steps_per_epoch
        ), 'must set num_steps_per_epoch or num_epochs_per_decay and num_steps_per_epoch'
        decay_steps = FLAGS.num_steps_per_decay or int(
            num_steps_per_epoch * FLAGS.num_epochs_per_decay)
        decay_start_step = FLAGS.decay_start_step or int(
            num_steps_per_epoch * FLAGS.decay_start_epoch)
        # lr <- max(lr * decay_factor, min_learning_rate) every decay_steps steps.
        logging.info(
            'learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'
            .format(FLAGS.learning_rate_decay_factor, FLAGS.num_epochs_per_decay,
                    decay_steps, FLAGS.decay_start_epoch, decay_start_step))

    # ---------------------------------------------------------------- training
    if hasattr(model, 'train'):
        model.train()

    timer = gezi.Timer()
    loss_avg = Mean()
    is_chief = not use_horovod or hvd.rank() == 0

    num_epochs = num_epochs if num_epochs else 0
    # NOTE(review): min(num_epochs, 1) caps torch_only runs at a single pass
    # over train_dataset — confirm this is intended.
    loops = min(num_epochs, 1) if FLAGS.torch_only else 1

    # Keras models whose call() takes `training` get it explicitly.
    model_accepts_training = (not FLAGS.torch and
                              'training' in inspect.getfullargspec(model.call).args)

    def loss_fn_(x, y, training=True):
        """Forward pass + loss, with the argument order each backend expects."""
        if model_accepts_training:
            y_ = model(x, training=training)
        else:
            y_ = model(x)
        return loss_fn(y, y_) if not FLAGS.torch else loss_fn(y_, y)

    for loop_idx in range(loops):
        for i, (x, y) in enumerate(train_dataset):
            if FLAGS.torch:
                x, y = to_torch(x, y)
                if is_dynamic_opt:
                    learning_rate.assign(optimizer.rate())

            if not FLAGS.torch:
                loss, grads = melt.eager.grad(model, x, y, loss_fn)
                grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients)
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
                # Horovod: broadcast initial state from rank 0 after the first
                # gradient step (so optimizer slots exist), per
                # horovod/examples/tensorflow_mnist_eager.py.
                if use_horovod and loop_idx == 0 and i == 0:
                    hvd.broadcast_variables(model.variables, root_rank=0)
                    hvd.broadcast_variables(optimizer.variables(), root_rank=0)
            else:
                optimizer.zero_grad()
                loss = loss_fn_(x, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               FLAGS.clip_gradients)
                optimizer.step()

            global_step.assign_add(1)
            loss_avg(loss)

            first_input = list(x.values())[0] if isinstance(x, dict) else x
            cur_batch_size = first_input.shape[FLAGS.batch_size_dim]
            num_insts += int(cur_batch_size)

            if global_step.numpy() % FLAGS.interval_steps == 0:
                elapsed = timer.elapsed()
                steps_per_second = FLAGS.interval_steps / elapsed
                instances_per_second = num_insts / elapsed
                num_insts = 0

                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = (num_steps_per_epoch / FLAGS.interval_steps *
                                       elapsed / 3600)
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch)

                valid_loss = None
                if valid_dataset2:
                    # valid_dataset2 repeats, so the iterator never runs dry.
                    vx, vy = next(valid_dataset2_iter)
                    if FLAGS.torch:
                        vx, vy = to_torch(vx, vy)
                    if hasattr(model, 'eval'):
                        model.eval()
                    valid_loss = loss_fn_(vx, vy, training=False)
                    valid_loss = (valid_loss.numpy() if not FLAGS.torch
                                  else valid_loss.item())
                    if hasattr(model, 'train'):
                        model.train()

                if is_chief:
                    train_loss = loss_avg.result().numpy()
                    log_fields = [
                        'epoch:%.2f/%d' % ((global_step.numpy() / num_steps_per_epoch),
                                           num_epochs),
                        'step:%d' % global_step.numpy(),
                        'elapsed:[%.2f]' % elapsed,
                        'batch_size:[%d]' % cur_batch_size,
                        'gpus:[%d]' % num_gpus,
                        'batches/s:[%.2f]' % steps_per_second,
                        'insts/s:[%d]' % instances_per_second,
                        '%s' % epoch_time_info,
                        'lr:[%.6f]' % learning_rate.numpy(),
                        'train_loss:[%.4f]' % train_loss,
                    ]
                    if valid_loss is not None:
                        log_fields.append('valid_loss:[%.4f]' % valid_loss)
                    logging.info2(*log_fields)

                    if global_step.numpy() % FLAGS.valid_interval_steps == 0:
                        if valid_loss is not None:
                            with writer_valid.as_default(), summary.always_record_summaries():
                                summary.scalar('loss/valid', valid_loss)
                                writer_valid.flush()
                        with writer_train.as_default(), summary.always_record_summaries():
                            summary.scalar('loss/train_avg', train_loss)
                            summary.scalar('learning_rate', learning_rate.numpy())
                            summary.scalar('other/batch_size', cur_batch_size)
                            summary.scalar('other/epoch', melt.epoch())
                            summary.scalar('perf/steps_per_second', steps_per_second)
                            summary.scalar('perf/instances_per_second',
                                           instances_per_second)
                            writer_train.flush()

            if (valid_dataset and FLAGS.metric_eval_interval_steps
                    and global_step.numpy()
                    and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0):
                if hasattr(model, 'eval'):
                    model.eval()
                vals, names = None, None
                if evaluate_fn is not None:
                    vals, names = evaluate_fn(model, valid_dataset, None,
                                              num_valid_steps_per_epoch)
                elif eval_fn:
                    names = _default_valid_names(valid_names, infer_names)
                    vals, names = evaluate(model, valid_dataset, eval_fn, None,
                                           names, valid_write_fn, write_streaming,
                                           num_valid_steps_per_epoch,
                                           num_valid_examples, sep=sep)
                if is_chief and vals and names:
                    with writer_valid.as_default(), summary.always_record_summaries():
                        for name, val in zip(names, vals):
                            summary.scalar(f'step_eval/{name}', val)
                        writer_valid.flush()

                if FLAGS.torch and not FLAGS.torch_lr:
                    # The TF learning_rate variable drives the torch optimizer's lr.
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = learning_rate.numpy()

                if hasattr(model, 'train'):
                    model.train()

                if is_chief and names and vals:
                    logging.info2(
                        'epoch:%.2f/%d' % ((global_step.numpy() / num_steps_per_epoch),
                                           num_epochs),
                        'valid_step:%d' % global_step.numpy(), 'valid_metrics',
                        ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

            if is_chief:
                step = global_step.numpy()
                if FLAGS.torch and step % FLAGS.save_interval_steps == 0:
                    torch.save(
                        _torch_checkpoint_state(model, optimizer, step,
                                                num_steps_per_epoch),
                        os.path.join(FLAGS.model_dir, 'latest.pyt'))
                # TODO: using both checkpoint and checkpoint2 does not work, hence only checkpoint2 is saved.
                if FLAGS.save_interval_epochs and step % int(
                        num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
                    checkpoint2.save(checkpoint_prefix2)
                    if FLAGS.torch:
                        torch.save(
                            _torch_checkpoint_state(model, optimizer, step,
                                                    num_steps_per_epoch),
                            tf.train.latest_checkpoint(ckpt_dir2) + '.pyt')

            if FLAGS.learning_rate_decay_factor > 0:
                if (global_step.numpy() >= decay_start_step
                        and global_step.numpy() % decay_steps == 0):
                    lr = max(learning_rate.numpy() * FLAGS.learning_rate_decay_factor,
                             FLAGS.min_learning_rate)
                    if lr < learning_rate.numpy():
                        learning_rate.assign(lr)
                        if FLAGS.torch:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate.numpy()

            if i == 0:
                try:
                    if not FLAGS.torch:
                        logging.info(model.summary())
                    else:
                        logging.info(model)
                except Exception:
                    traceback.print_exc()
                    logging.info(
                        'Fail to do model.summary() may be you have layer define in init but not used in call'
                    )
                if 'SHOW' in os.environ:
                    sys.exit(0)

            if valid_dataset and global_step.numpy() % int(
                    num_steps_per_epoch * FLAGS.valid_interval_epochs) == 0:
                if hasattr(model, 'eval'):
                    model.eval()
                vals, names = None, None
                if evaluate_fn is not None:
                    vals, names = evaluate_fn(model, valid_dataset,
                                              tf.train.latest_checkpoint(ckpt_dir),
                                              num_valid_steps_per_epoch)
                elif eval_fn:
                    model_path = (tf.train.latest_checkpoint(ckpt_dir)
                                  if write_valid else None)
                    print('---------metric evaluate step', global_step.numpy(),
                          'model_path:', model_path)
                    names = _default_valid_names(valid_names, infer_names)
                    vals, names = evaluate(model, valid_dataset, eval_fn, model_path,
                                           names, valid_write_fn, write_streaming,
                                           num_valid_steps_per_epoch,
                                           num_valid_examples, suffix=valid_suffix,
                                           sep=sep)

                if is_chief:
                    if vals and names:
                        logging.info2(
                            'epoch:%.2f/%d' % (global_step.numpy() / num_steps_per_epoch,
                                               num_epochs),
                            'step:%d' % global_step.numpy(), 'valid_metrics',
                            ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])
                    with writer.as_default(), summary.always_record_summaries():
                        # contrib summaries default to the global step; temporarily
                        # set it to the validation-round index so eval/* is
                        # plotted per round, then restore it.
                        saved_step = global_step.numpy()
                        global_step.assign(int(saved_step / int(
                            num_steps_per_epoch * FLAGS.valid_interval_epochs)))
                        if vals and names:
                            for name, val in zip(names, vals):
                                summary.scalar(f'eval/{name}', val)
                        writer.flush()
                        global_step.assign(saved_step)

                # Resume training mode; previously the model stayed in eval mode.
                if hasattr(model, 'train'):
                    model.train()

            if test_dataset and global_step.numpy() % int(
                    num_steps_per_epoch * FLAGS.inference_interval_epochs) == 0:
                if hasattr(model, 'eval'):
                    model.eval()
                if inference_fn is None:
                    inference(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir),
                              infer_names, infer_debug_names, infer_write_fn,
                              write_streaming, num_test_steps_per_epoch,
                              num_test_examples, suffix=infer_suffix, sep=sep)
                else:
                    inference_fn(model, test_dataset,
                                 tf.train.latest_checkpoint(ckpt_dir),
                                 num_test_steps_per_epoch)
                if hasattr(model, 'train'):
                    model.train()

            if (num_epochs and global_step.numpy() % num_steps_per_epoch == 0
                    and int(global_step.numpy() / num_steps_per_epoch) == num_epochs):
                logging.info(f'Finshed training of {num_epochs} epochs')
                sys.exit(0)