def set_batch_iterator_func(self):
    if (self.conf is not None
            and 'use_process_generator' in conf['training']
            and conf['training']['use_process_generator']):
        self.batch_iterator_func = ProcessGenerator(self.batch_iterator())
    else:
        self.batch_iterator_func = self.batch_iterator()
def set_batch_iterator_func(self):
    self.batch_iterator_func = ProcessGenerator(self.batch_iterator())
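The two variants above differ only in whether the batch generator is wrapped in a ProcessGenerator (which prefetches batches in a separate process) or consumed directly. ProcessGenerator itself is not defined in this listing; the following is a minimal, hypothetical sketch of such a wrapper (the class name, queue size, and sentinel handling are assumptions, not the project's implementation), included only to illustrate why close() has to call __exit__() on it.

# Hypothetical sketch only -- not the project's ProcessGenerator.
import multiprocessing as mp

_SENTINEL = 'STOP'


def _producer(gen_fn, queue):
    # run the generator function in a child process and push batches to the queue
    for item in gen_fn():
        queue.put(item)
    queue.put(_SENTINEL)


class BackgroundGenerator:
    def __init__(self, gen_fn, max_prefetch=4):
        # gen_fn is a picklable zero-argument callable that returns a generator
        self.queue = mp.Queue(max_prefetch)
        self.process = mp.Process(target=_producer, args=(gen_fn, self.queue),
                                  daemon=True)
        self.process.start()

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item == _SENTINEL:
            raise StopIteration
        return item

    def __exit__(self, *args):
        # mirrors how the code above tears down the wrapper via __exit__()
        self.process.terminate()


def _example_batches():
    # stand-in for self.batch_iterator()
    for i in range(3):
        yield i


if __name__ == '__main__':
    it = BackgroundGenerator(_example_batches)
    print(list(it))   # [0, 1, 2]
    it.__exit__()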
class MPIModel():
    def __init__(self, model, optimizer, comm, batch_iterator, batch_size,
                 num_replicas=None, warmup_steps=1000, lr=0.01,
                 num_batches_minimum=100):
        random.seed(task_index)
        np.random.seed(task_index)
        self.start_time = time.time()
        self.epoch = 0
        self.num_so_far = 0
        self.num_so_far_accum = 0
        self.num_so_far_indiv = 0
        self.model = model
        self.optimizer = optimizer
        self.max_lr = 0.1
        self.DUMMY_LR = 0.001
        self.comm = comm
        self.batch_size = batch_size
        self.batch_iterator = batch_iterator
        self.set_batch_iterator_func()
        self.warmup_steps = warmup_steps
        self.num_batches_minimum = num_batches_minimum
        self.num_workers = comm.Get_size()
        self.task_index = comm.Get_rank()
        self.history = cbks.History()
        if num_replicas is None or num_replicas < 1 or num_replicas > self.num_workers:
            self.num_replicas = self.num_workers
        else:
            self.num_replicas = num_replicas
        self.lr = lr / (1.0 + self.num_replicas / 100.0) if (
            lr < self.max_lr) else self.max_lr / (1.0 + self.num_replicas / 100.0)

    def set_batch_iterator_func(self):
        self.batch_iterator_func = ProcessGenerator(self.batch_iterator())

    def close(self):
        self.batch_iterator_func.__exit__()

    def set_lr(self, lr):
        self.lr = lr

    def save_weights(self, path, overwrite=False):
        self.model.save_weights(path, overwrite=overwrite)

    def load_weights(self, path):
        self.model.load_weights(path)

    def compile(self, optimizer, clipnorm, loss='mse'):
        if optimizer == 'sgd':
            optimizer_class = SGD(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'momentum_sgd':
            optimizer_class = SGD(lr=self.DUMMY_LR, clipnorm=clipnorm,
                                  decay=1e-6, momentum=0.9)
        elif optimizer == 'tf_momentum_sgd':
            optimizer_class = TFOptimizer(tf.train.MomentumOptimizer(
                learning_rate=self.DUMMY_LR, momentum=0.9))
        elif optimizer == 'adam':
            optimizer_class = Adam(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'tf_adam':
            optimizer_class = TFOptimizer(tf.train.AdamOptimizer(
                learning_rate=self.DUMMY_LR))
        elif optimizer == 'rmsprop':
            optimizer_class = RMSprop(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'nadam':
            optimizer_class = Nadam(lr=self.DUMMY_LR, clipnorm=clipnorm)
        else:
            print("Optimizer not implemented yet")
            exit(1)
        self.model.compile(optimizer=optimizer_class, loss=loss)

    def train_on_batch_and_get_deltas(self, X_batch, Y_batch, verbose=False):
        '''
        The purpose of the method is to perform a single gradient update over
        one mini-batch for one model replica. Given a mini-batch, it first
        saves the current model weights, performs a single gradient update
        over the mini-batch, reads the new model weights, computes the weight
        updates (deltas) by subtracting the two weight sets, and unscales them
        by the dummy learning rate.

        It performs calls to: subtract_params, multiply_params

        Argument list:
          - X_batch: input data for one mini-batch as a Numpy array
          - Y_batch: labels for one mini-batch as a Numpy array
          - verbose: set verbosity level (currently unused)

        Returns:
          - deltas: a list of model weight updates
          - loss: scalar training loss
        '''
        weights_before_update = self.model.get_weights()
        loss = self.model.train_on_batch(X_batch, Y_batch)
        weights_after_update = self.model.get_weights()
        self.model.set_weights(weights_before_update)

        # unscale before subtracting
        weights_before_update = multiply_params(weights_before_update,
                                                1.0 / self.DUMMY_LR)
        weights_after_update = multiply_params(weights_after_update,
                                               1.0 / self.DUMMY_LR)

        deltas = subtract_params(weights_after_update, weights_before_update)

        # unscale loss
        if conf['model']['loss_scale_factor'] != 1.0:
            deltas = multiply_params(deltas,
                                     1.0 / conf['model']['loss_scale_factor'])
        return deltas, loss

    def get_new_weights(self, deltas):
        return add_params(self.model.get_weights(), deltas)

    def mpi_average_gradients(self, arr, num_replicas=None):
        if num_replicas == None:
            num_replicas = self.num_workers
        if self.task_index >= num_replicas:
            arr *= 0.0
        arr_global = np.empty_like(arr)
        if K.floatx() == 'float16':
            self.comm.Allreduce(arr, arr_global, op=mpi_sum_f16)
        else:
            self.comm.Allreduce(arr, arr_global, op=MPI.SUM)
        arr_global /= num_replicas
        return arr_global

    def mpi_average_scalars(self, val, num_replicas=None):
        '''
        The purpose of the method is to calculate a simple scalar arithmetic
        mean over num_replicas.

        It performs calls to: MPIModel.mpi_sum_scalars

        Argument list:
          - val: value to be averaged, scalar
          - num_replicas: the size of the ensemble the average is performed over

        Returns:
          - val_global: scalar arithmetic mean over num_replicas
        '''
        val_global = self.mpi_sum_scalars(val, num_replicas)
        val_global /= num_replicas
        return val_global

    def mpi_sum_scalars(self, val, num_replicas=None):
        '''
        The purpose of the method is to calculate a simple scalar sum over
        num_replicas using an MPI allreduce with fixed op=MPI.SUM.

        Argument list:
          - val: value to be summed, scalar
          - num_replicas: the size of the ensemble the sum is performed over

        Returns:
          - val_global: scalar sum over num_replicas
        '''
        if num_replicas == None:
            num_replicas = self.num_workers
        if self.task_index >= num_replicas:
            val *= 0.0
        val_global = 0.0
        val_global = self.comm.allreduce(val, op=MPI.SUM)
        return val_global

    def sync_deltas(self, deltas, num_replicas=None):
        global_deltas = []
        # default is to reduce the deltas from all workers
        for delta in deltas:
            global_deltas.append(self.mpi_average_gradients(delta, num_replicas))
        return global_deltas

    def set_new_weights(self, deltas, num_replicas=None):
        global_deltas = self.sync_deltas(deltas, num_replicas)
        effective_lr = self.get_effective_lr(num_replicas)
        self.optimizer.set_lr(effective_lr)
        global_deltas = self.optimizer.get_deltas(global_deltas)
        if self.comm.rank == 0:
            new_weights = self.get_new_weights(global_deltas)
        else:
            new_weights = None
        new_weights = self.comm.bcast(new_weights, root=0)
        self.model.set_weights(new_weights)

    def build_callbacks(self, conf, callbacks_list):
        '''
        The purpose of the method is to set up logging and history.
        It is based on Keras Callbacks
        https://github.com/fchollet/keras/blob/fbc9a18f0abc5784607cd4a2a3886558efa3f794/keras/callbacks.py

        Currently used callbacks include: BaseLogger, CSVLogger, EarlyStopping.
        Other possible callbacks to add in future: RemoteMonitor,
        LearningRateScheduler

        Argument list:
          - conf: There is a "callbacks" section in conf.yaml file. Relevant
            parameters are:
              - list: Parameter specifying additional callbacks, read in the
                driver script and passed as an argument of type list (see next
                arg)
              - metrics: List of quantities monitored during training and
                validation
              - mode: one of {auto, min, max}. The decision to overwrite the
                current save file is made based on either the maximization or
                the minimization of the monitored quantity. For val_acc, this
                should be max, for val_loss this should be min, etc. In auto
                mode, the direction is automatically inferred from the name of
                the monitored quantity.
              - monitor: Quantity used for early stopping, has to be from the
                list of metrics
              - patience: Number of epochs used to decide on whether to apply
                early stopping or continue training
          - callbacks_list: uses callbacks.list configuration parameter,
            specifies the list of additional callbacks

        Returns: modified list of callbacks
        '''
        mode = conf['callbacks']['mode']
        monitor = conf['callbacks']['monitor']
        patience = conf['callbacks']['patience']
        csvlog_save_path = conf['paths']['csvlog_save_path']
        # CSV callback is on by default
        if not os.path.exists(csvlog_save_path):
            os.makedirs(csvlog_save_path)

        callbacks_list = conf['callbacks']['list']

        callbacks = [cbks.BaseLogger()]
        callbacks += [self.history]
        callbacks += [cbks.CSVLogger("{}callbacks-{}.log".format(
            csvlog_save_path,
            datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")))]

        if "earlystop" in callbacks_list:
            callbacks += [cbks.EarlyStopping(patience=patience,
                                             monitor=monitor, mode=mode)]
        if "lr_scheduler" in callbacks_list:
            pass

        return cbks.CallbackList(callbacks)

    def train_epoch(self):
        '''
        The purpose of the method is to perform distributed mini-batch SGD for
        one epoch. It takes the batch iterator function and a NN model from
        the MPIModel object, and fetches mini-batches in a while-loop until
        the number of samples seen by the ensemble of workers (num_so_far)
        exceeds the training dataset size (num_total).

        During each iteration, the gradient updates (deltas) and the loss are
        calculated for each model replica in the ensemble, weights are
        averaged over the ensemble, and the new weights are set.

        It performs calls to: MPIModel.get_deltas, MPIModel.set_new_weights
        methods

        Argument list: Empty

        Returns:
          - step: epoch number
          - ave_loss: training loss averaged over replicas
          - curr_loss:
          - num_so_far: the number of samples seen by ensemble of replicas to
            a current epoch (step)

        Intermediate outputs and logging: debug printout of task_index (MPI),
        epoch number, number of samples seen to a current epoch, average
        training loss
        '''
        verbose = False
        first_run = True
        step = 0
        loss_averager = Averager()
        t_start = time.time()

        batch_iterator_func = self.batch_iterator_func
        num_total = 1
        ave_loss = -1
        curr_loss = -1
        t0 = 0
        t1 = 0
        t2 = 0

        while (self.num_so_far - self.epoch * num_total) < num_total or step < self.num_batches_minimum:
            try:
                batch_xs, batch_ys, batches_to_reset, num_so_far_curr, num_total, is_warmup_period = next(
                    batch_iterator_func)
            except StopIteration:
                print("Resetting batch iterator.")
                self.num_so_far_accum = self.num_so_far_indiv
                self.set_batch_iterator_func()
                batch_iterator_func = self.batch_iterator_func
                batch_xs, batch_ys, batches_to_reset, num_so_far_curr, num_total, is_warmup_period = next(
                    batch_iterator_func)
            self.num_so_far_indiv = self.num_so_far_accum + num_so_far_curr

            # if batches_to_reset:
            #     self.model.reset_states(batches_to_reset)

            warmup_phase = (step < self.warmup_steps and self.epoch == 0)
            num_replicas = 1 if warmup_phase else self.num_replicas

            self.num_so_far = self.mpi_sum_scalars(self.num_so_far_indiv,
                                                   num_replicas)

            # run the model once to force compilation. Don't actually use these values.
            if first_run:
                first_run = False
                t0_comp = time.time()
                _, _ = self.train_on_batch_and_get_deltas(batch_xs, batch_ys,
                                                          verbose)
                self.comm.Barrier()
                sys.stdout.flush()
                print_unique('Compilation finished in {:.2f}s'.format(
                    time.time() - t0_comp))
                t_start = time.time()
                sys.stdout.flush()

            if np.any(batches_to_reset):
                reset_states(self.model, batches_to_reset)

            t0 = time.time()
            deltas, loss = self.train_on_batch_and_get_deltas(batch_xs,
                                                              batch_ys, verbose)
            t1 = time.time()
            if not is_warmup_period:
                self.set_new_weights(deltas, num_replicas)
                t2 = time.time()
                write_str_0 = self.calculate_speed(t0, t1, t2, num_replicas)

                curr_loss = self.mpi_average_scalars(1.0 * loss, num_replicas)
                # if self.task_index == 0:
                #     print(self.model.get_weights()[0][0][:4])
                loss_averager.add_val(curr_loss)
                ave_loss = loss_averager.get_val()

                eta = self.estimate_remaining_time(
                    t0 - t_start, self.num_so_far - self.epoch * num_total,
                    num_total)
                write_str = '\r[{}] step: {} [ETA: {:.2f}s] [{:.2f}/{}], loss: {:.5f} [{:.5f}] | walltime: {:.4f} | '.format(
                    self.task_index, step, eta, 1.0 * self.num_so_far,
                    num_total, ave_loss, curr_loss,
                    time.time() - self.start_time)
                print_unique(write_str + write_str_0)
                step += 1
            else:
                print_unique('\r[{}] warmup phase, num so far: {}'.format(
                    self.task_index, self.num_so_far))

        effective_epochs = 1.0 * self.num_so_far / num_total
        epoch_previous = self.epoch
        self.epoch = effective_epochs
        print_unique(
            '\nEpoch {:.2f} finished ({:.2f} epochs passed) in {:.2f} seconds.\n'.format(
                1.0 * self.epoch, self.epoch - epoch_previous, t2 - t_start))
        return (step, ave_loss, curr_loss, self.num_so_far, effective_epochs)

    def estimate_remaining_time(self, time_so_far, work_so_far, work_total):
        eps = 1e-6
        total_time = 1.0 * time_so_far * work_total / (work_so_far + eps)
        return total_time - time_so_far

    def get_effective_lr(self, num_replicas):
        effective_lr = self.lr * num_replicas
        if effective_lr > self.max_lr:
            print_unique(
                'Warning: effective learning rate set to {}, larger than maximum {}. Clipping.'.format(
                    effective_lr, self.max_lr))
            effective_lr = self.max_lr
        return effective_lr

    def get_effective_batch_size(self, num_replicas):
        return self.batch_size * num_replicas

    def calculate_speed(self, t0, t_after_deltas, t_after_update, num_replicas,
                        verbose=False):
        effective_batch_size = self.get_effective_batch_size(num_replicas)
        t_calculate = t_after_deltas - t0
        t_sync = t_after_update - t_after_deltas
        t_tot = t_after_update - t0

        examples_per_sec = effective_batch_size / t_tot
        frac_calculate = t_calculate / t_tot
        frac_sync = t_sync / t_tot

        print_str = '{:.2E} Examples/sec | {:.2E} sec/batch [{:.1%} calc., {:.1%} synch.]'.format(
            examples_per_sec, t_tot, frac_calculate, frac_sync)
        print_str += '[batch = {} = {}*{}] [lr = {:.2E} = {:.2E}*{}]'.format(
            effective_batch_size, self.batch_size, num_replicas,
            self.get_effective_lr(num_replicas), self.lr, num_replicas)
        if verbose:
            print_unique(print_str)
        return print_str
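train_on_batch_and_get_deltas and get_new_weights rely on the helpers subtract_params, multiply_params and add_params, which are imported from elsewhere and are not defined in this listing. Below is a minimal sketch consistent with how they are called here, under the assumption that they act elementwise on lists of NumPy weight arrays; the toy arrays and DUMMY_LR value are illustrative only.

import numpy as np


def multiply_params(params, scale):
    # scale every weight array in the list by a scalar
    return [scale * p for p in params]


def subtract_params(params_a, params_b):
    # elementwise difference of two weight lists with the same structure
    return [a - b for a, b in zip(params_a, params_b)]


def add_params(params_a, params_b):
    # elementwise sum of two weight lists with the same structure
    return [a + b for a, b in zip(params_a, params_b)]


# Example: recover a "delta" the way train_on_batch_and_get_deltas does,
# using toy arrays in place of real model weights.
w_before = [np.zeros((2, 2)), np.zeros(3)]
w_after = [np.full((2, 2), 0.001), np.full(3, 0.002)]
DUMMY_LR = 0.001
deltas = subtract_params(multiply_params(w_after, 1.0 / DUMMY_LR),
                         multiply_params(w_before, 1.0 / DUMMY_LR))
print(deltas[0])  # the unscaled update for the first weight array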
def train(conf, shot_list_train, shot_list_validate, loader):
    loader.set_inference_mode(False)
    np.random.seed(1)

    validation_losses = []
    validation_roc = []
    training_losses = []
    print('validate: {} shots, {} disruptive'.format(
        len(shot_list_validate), shot_list_validate.num_disruptive()))
    print('training: {} shots, {} disruptive'.format(
        len(shot_list_train), shot_list_train.num_disruptive()))

    if backend == 'tf' or backend == 'tensorflow':
        first_time = "tensorflow" not in sys.modules
        if first_time:
            import tensorflow as tf
            os.environ['KERAS_BACKEND'] = 'tensorflow'
            from keras.backend.tensorflow_backend import set_session
            config = tf.ConfigProto(device_count={"GPU": 1})
            set_session(tf.Session(config=config))
    else:
        os.environ['KERAS_BACKEND'] = 'theano'
        os.environ['THEANO_FLAGS'] = 'device=gpu,floatX=float32'
        import theano

    from keras.utils.generic_utils import Progbar
    from keras import backend as K
    from plasma.models import builder

    print('Build model...', end='')
    specific_builder = builder.ModelBuilder(conf)
    train_model = specific_builder.build_model(False)
    print('Compile model', end='')
    train_model.compile(optimizer=optimizer_class(),
                        loss=conf['data']['target'].loss)
    print('...done')

    # load the latest epoch we did. Returns -1 if none exist yet
    e = specific_builder.load_model_weights(train_model)
    e_start = e
    batch_generator = partial(loader.training_batch_generator_partial_reset,
                              shot_list=shot_list_train)
    batch_iterator = ProcessGenerator(batch_generator())

    num_epochs = conf['training']['num_epochs']
    num_at_once = conf['training']['num_shots_at_once']
    lr_decay = conf['model']['lr_decay']
    print('{} epochs left to go'.format(num_epochs - 1 - e))

    num_so_far_accum = 0
    num_so_far = 0
    num_total = np.inf

    if conf['callbacks']['mode'] == 'max':
        best_so_far = -np.inf
        cmp_fn = max
    else:
        best_so_far = np.inf
        cmp_fn = min

    while e < num_epochs - 1:
        e += 1
        print('\nEpoch {}/{}'.format(e + 1, num_epochs))
        pbar = Progbar(len(shot_list_train))

        # decay learning rate each epoch:
        K.set_value(train_model.optimizer.lr, lr * lr_decay**(e))
        # print('Learning rate: {}'.format(train_model.optimizer.lr.get_value()))

        num_batches_minimum = 100
        num_batches_current = 0
        training_losses_tmp = []

        while num_so_far < (e - e_start) * num_total or num_batches_current < num_batches_minimum:
            num_so_far_old = num_so_far
            try:
                batch_xs, batch_ys, batches_to_reset, num_so_far_curr, num_total, is_warmup_period = next(
                    batch_iterator)
            except StopIteration:
                print("Resetting batch iterator.")
                num_so_far_accum = num_so_far
                batch_iterator = ProcessGenerator(batch_generator())
                batch_xs, batch_ys, batches_to_reset, num_so_far_curr, num_total, is_warmup_period = next(
                    batch_iterator)

            if np.any(batches_to_reset):
                reset_states(train_model, batches_to_reset)

            if not is_warmup_period:
                num_so_far = num_so_far_accum + num_so_far_curr
                num_batches_current += 1

                loss = train_model.train_on_batch(batch_xs, batch_ys)
                training_losses_tmp.append(loss)
                pbar.add(num_so_far - num_so_far_old,
                         values=[("train loss", loss)])
                loader.verbose = False  # True during the first iteration
            else:
                _ = train_model.predict(
                    batch_xs, batch_size=conf['training']['batch_size'])

        e = e_start + 1.0 * num_so_far / num_total
        sys.stdout.flush()
        ave_loss = np.mean(training_losses_tmp)
        training_losses.append(ave_loss)
        specific_builder.save_model_weights(train_model, int(round(e)))

        if conf['training']['validation_frac'] > 0.0:
            print("prediction on GPU...")
            _, _, _, roc_area, loss = make_predictions_and_evaluate_gpu(
                conf, shot_list_validate, loader)
            validation_losses.append(loss)
            validation_roc.append(roc_area)

            epoch_logs = {}
            epoch_logs['val_roc'] = roc_area
            epoch_logs['val_loss'] = loss
            epoch_logs['train_loss'] = ave_loss
            best_so_far = cmp_fn(epoch_logs[conf['callbacks']['monitor']],
                                 best_so_far)
            if best_so_far != epoch_logs[conf['callbacks']['monitor']]:
                # only save model weights if quantity we are tracking is improving
                print("Not saving model weights")
                specific_builder.delete_model_weights(train_model, int(round(e)))

        if conf['training']['ranking_difficulty_fac'] != 1.0:
            _, _, _, roc_area_train, loss_train = make_predictions_and_evaluate_gpu(
                conf, shot_list_train, loader)
            batch_iterator.__exit__()
            batch_generator = partial(
                loader.training_batch_generator_partial_reset,
                shot_list=shot_list_train)
            batch_iterator = ProcessGenerator(batch_generator())
            num_so_far_accum = num_so_far

        print('=========Summary========')
        print('Training Loss Numpy: {:.3e}'.format(training_losses[-1]))
        if conf['training']['validation_frac'] > 0.0:
            print('Validation Loss: {:.3e}'.format(validation_losses[-1]))
            print('Validation ROC: {:.4f}'.format(validation_roc[-1]))
        if conf['training']['ranking_difficulty_fac'] != 1.0:
            print('Train Loss: {:.3e}'.format(loss_train))
            print('Train ROC: {:.4f}'.format(roc_area_train))

    # plot_losses(conf, [training_losses], specific_builder, name='training')
    if conf['training']['validation_frac'] > 0.0:
        plot_losses(conf, [training_losses, validation_losses, validation_roc],
                    specific_builder, name='training_validation_roc')
    batch_iterator.__exit__()
    print('...done')
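The weight-saving policy in train() keeps an epoch's weights only when the monitored quantity improves in the configured direction. A small standalone illustration of that best_so_far / cmp_fn bookkeeping, with hypothetical per-epoch values:

# Illustration only: hypothetical monitored values for mode='max' (e.g. val_roc).
mode = 'max'
cmp_fn = max if mode == 'max' else min
best_so_far = -float('inf') if mode == 'max' else float('inf')

for monitored in [0.71, 0.69, 0.75]:
    best_so_far = cmp_fn(monitored, best_so_far)
    if best_so_far != monitored:
        print('epoch value {:.2f}: not saving model weights'.format(monitored))
    else:
        print('epoch value {:.2f}: new best, weights kept'.format(monitored))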
class MPIModel():
    def __init__(self, model, optimizer, comm, batch_iterator, batch_size,
                 num_replicas=None, warmup_steps=1000, lr=0.01,
                 num_batches_minimum=100, conf=None):
        random.seed(g.task_index)
        np.random.seed(g.task_index)
        self.conf = conf
        self.start_time = time.time()
        self.epoch = 0
        self.num_so_far = 0
        self.num_so_far_accum = 0
        self.num_so_far_indiv = 0
        self.model = model
        self.optimizer = optimizer
        self.max_lr = 0.1
        self.DUMMY_LR = 0.001
        self.batch_size = batch_size
        self.batch_iterator = batch_iterator
        self.set_batch_iterator_func()
        self.warmup_steps = warmup_steps
        self.num_batches_minimum = num_batches_minimum
        # TODO(KGF): duplicate/may be in conflict with global_vars.py
        self.comm = comm
        self.num_workers = comm.Get_size()
        self.task_index = comm.Get_rank()
        self.history = cbks.History()
        self.model.stop_training = False
        if (num_replicas is None or num_replicas < 1
                or num_replicas > self.num_workers):
            self.num_replicas = self.num_workers
        else:
            self.num_replicas = num_replicas
        self.lr = (lr / (1.0 + self.num_replicas / 100.0) if (lr < self.max_lr)
                   else self.max_lr / (1.0 + self.num_replicas / 100.0))

    def set_batch_iterator_func(self):
        if (self.conf is not None
                and 'use_process_generator' in conf['training']
                and conf['training']['use_process_generator']):
            self.batch_iterator_func = ProcessGenerator(self.batch_iterator())
        else:
            self.batch_iterator_func = self.batch_iterator()

    def close(self):
        # TODO(KGF): extend __exit__() fn capability when this member
        # = self.batch_iterator() (i.e. is not a ProcessGenerator())
        if (self.conf is not None
                and 'use_process_generator' in conf['training']
                and conf['training']['use_process_generator']):
            self.batch_iterator_func.__exit__()

    def set_lr(self, lr):
        self.lr = lr

    # KGF: Unused. model.*() called directly in builder.py
    # def save_weights(self, path, overwrite=False):
    #     self.model.save_weights(path, overwrite=overwrite)

    # def load_weights(self, path):
    #     self.model.load_weights(path)

    def compile(self, optimizer, clipnorm, loss='mse'):
        # TODO(KGF): check the following import taken from runner.py
        # Was not in this file, originally.
        from tensorflow.keras.optimizers import (SGD, Adam, RMSprop, Nadam)
        if optimizer == 'sgd':
            optimizer_class = SGD(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'momentum_sgd':
            optimizer_class = SGD(lr=self.DUMMY_LR, clipnorm=clipnorm,
                                  decay=1e-6, momentum=0.9)
        elif optimizer == 'tf_momentum_sgd':
            # TODO(KGF): removed TFOptimizer wrapper from here and below
            # may not work anymore? See
            # https://github.com/tensorflow/tensorflow/issues/22780
            optimizer_class = tf.train.MomentumOptimizer(
                learning_rate=self.DUMMY_LR, momentum=0.9)
        elif optimizer == 'adam':
            optimizer_class = Adam(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'tf_adam':
            optimizer_class = tf.train.AdamOptimizer(
                learning_rate=self.DUMMY_LR)
        elif optimizer == 'rmsprop':
            optimizer_class = RMSprop(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'nadam':
            optimizer_class = Nadam(lr=self.DUMMY_LR, clipnorm=clipnorm)
        else:
            print("Optimizer not implemented yet")
            exit(1)

        # Timeline profiler
        if (self.conf is not None and conf['training']['timeline_prof']):
            self.run_options = tf.RunOptions(
                trace_level=tf.RunOptions.FULL_TRACE)
            self.run_metadata = tf.RunMetadata()
            self.model.compile(optimizer=optimizer_class, loss=loss,
                               options=self.run_options,
                               run_metadata=self.run_metadata)
        else:
            self.model.compile(optimizer=optimizer_class, loss=loss)
        self.ensure_equal_weights()

    def ensure_equal_weights(self):
        if g.task_index == 0:
            new_weights = self.model.get_weights()
        else:
            new_weights = None
        nw = g.comm.bcast(new_weights, root=0)
        self.model.set_weights(nw)

    def train_on_batch_and_get_deltas(self, X_batch, Y_batch, verbose=False):
        '''
        The purpose of the method is to perform a single gradient update over
        one mini-batch for one model replica. Given a mini-batch, it first
        saves the current model weights, performs a single gradient update
        over the mini-batch, reads the new model weights, computes the weight
        updates (deltas) by subtracting the two weight sets, and unscales them
        by the dummy learning rate.

        It performs calls to: subtract_params, multiply_params

        Argument list:
          - X_batch: input data for one mini-batch as a Numpy array
          - Y_batch: labels for one mini-batch as a Numpy array
          - verbose: set verbosity level (currently unused)

        Returns:
          - deltas: a list of model weight updates
          - loss: scalar training loss
        '''
        weights_before_update = self.model.get_weights()

        return_sequences = self.conf['model']['return_sequences']
        if not return_sequences:
            Y_batch = Y_batch[:, -1, :]
        loss = self.model.train_on_batch(X_batch, Y_batch)

        weights_after_update = self.model.get_weights()
        self.model.set_weights(weights_before_update)

        # unscale before subtracting
        weights_before_update = multiply_params(weights_before_update,
                                                1.0 / self.DUMMY_LR)
        weights_after_update = multiply_params(weights_after_update,
                                               1.0 / self.DUMMY_LR)

        deltas = subtract_params(weights_after_update, weights_before_update)

        # unscale loss
        if conf['model']['loss_scale_factor'] != 1.0:
            deltas = multiply_params(deltas,
                                     1.0 / conf['model']['loss_scale_factor'])

        return deltas, loss

    def get_new_weights(self, deltas):
        return add_params(self.model.get_weights(), deltas)

    def mpi_average_gradients(self, arr, num_replicas=None):
        if num_replicas is None:
            num_replicas = self.num_workers
        if self.task_index >= num_replicas:
            arr *= 0.0
        arr_global = np.empty_like(arr)
        if K.floatx() == 'float16':
            self.comm.Allreduce(arr, arr_global, op=mpi_sum_f16)
        else:
            self.comm.Allreduce(arr, arr_global, op=MPI.SUM)
        arr_global /= num_replicas
        return arr_global

    def mpi_average_scalars(self, val, num_replicas=None):
        '''
        The purpose of the method is to calculate a simple scalar arithmetic
        mean over num_replicas.

        It performs calls to: MPIModel.mpi_sum_scalars

        Argument list:
          - val: value to be averaged, scalar
          - num_replicas: the size of the ensemble the average is performed over

        Returns:
          - val_global: scalar arithmetic mean over num_replicas
        '''
        val_global = self.mpi_sum_scalars(val, num_replicas)
        val_global /= num_replicas
        return val_global

    def mpi_sum_scalars(self, val, num_replicas=None):
        '''
        The purpose of the method is to calculate a simple scalar sum over
        num_replicas using an MPI allreduce with fixed op=MPI.SUM.

        Argument list:
          - val: value to be summed, scalar
          - num_replicas: the size of the ensemble the sum is performed over

        Returns:
          - val_global: scalar sum over num_replicas
        '''
        if num_replicas is None:
            num_replicas = self.num_workers
        if self.task_index >= num_replicas:
            val *= 0.0
        val_global = 0.0
        val_global = self.comm.allreduce(val, op=MPI.SUM)
        return val_global

    def sync_deltas(self, deltas, num_replicas=None):
        global_deltas = []
        # default is to reduce the deltas from all workers
        for delta in deltas:
            global_deltas.append(self.mpi_average_gradients(delta, num_replicas))
        return global_deltas

    def set_new_weights(self, deltas, num_replicas=None):
        global_deltas = self.sync_deltas(deltas, num_replicas)
        effective_lr = self.get_effective_lr(num_replicas)
        self.optimizer.set_lr(effective_lr)
        global_deltas = self.optimizer.get_deltas(global_deltas)
        new_weights = self.get_new_weights(global_deltas)
        self.model.set_weights(new_weights)

    def build_callbacks(self, conf, callbacks_list):
        '''
        The purpose of the method is to set up logging and history.
        It is based on Keras Callbacks
        https://github.com/fchollet/keras/blob/fbc9a18f0abc5784607cd4a2a3886558efa3f794/keras/callbacks.py

        Currently used callbacks include: BaseLogger, CSVLogger, EarlyStopping.
        Other possible callbacks to add in future: RemoteMonitor,
        LearningRateScheduler

        Argument list:
          - conf: There is a "callbacks" section in conf.yaml file. Relevant
            parameters are:
              - list: Parameter specifying additional callbacks, read in the
                driver script and passed as an argument of type list (see next
                arg)
              - metrics: List of quantities monitored during training and
                validation
              - mode: one of {auto, min, max}. The decision to overwrite the
                current save file is made based on either the maximization or
                the minimization of the monitored quantity. For val_acc, this
                should be max, for val_loss this should be min, etc. In auto
                mode, the direction is automatically inferred from the name of
                the monitored quantity.
              - monitor: Quantity used for early stopping, has to be from the
                list of metrics
              - patience: Number of epochs used to decide on whether to apply
                early stopping or continue training
          - callbacks_list: uses callbacks.list configuration parameter,
            specifies the list of additional callbacks

        Returns: modified list of callbacks
        '''
        mode = conf['callbacks']['mode']
        monitor = conf['callbacks']['monitor']
        patience = conf['callbacks']['patience']
        csvlog_save_path = conf['paths']['csvlog_save_path']
        # CSV callback is on by default
        if not os.path.exists(csvlog_save_path):
            os.makedirs(csvlog_save_path)

        callbacks_list = conf['callbacks']['list']

        callbacks = [cbks.BaseLogger()]
        callbacks += [self.history]
        callbacks += [cbks.CSVLogger("{}callbacks-{}.log".format(
            csvlog_save_path,
            datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")))]

        if "earlystop" in callbacks_list:
            callbacks += [cbks.EarlyStopping(patience=patience,
                                             monitor=monitor, mode=mode)]
        if "lr_scheduler" in callbacks_list:
            pass

        return cbks.CallbackList(callbacks)

    def add_noise(self, X):
        if self.conf['training']['noise'] is True:
            prob = 0.05
        else:
            prob = self.conf['training']['noise']
        for i in range(0, X.shape[0]):
            for j in range(0, X.shape[2]):
                a = random.randint(0, 100)
                if a < prob * 100:
                    X[i, :, j] = 0.0
        return X

    def train_epoch(self):
        '''
        Perform distributed mini-batch SGD for one epoch. It takes the batch
        iterator function and a NN model from the MPIModel object, and fetches
        mini-batches in a while-loop until the number of samples seen by the
        ensemble of workers (num_so_far) exceeds the training dataset size
        (num_total).

        NOTE: "sample" = "an entire shot" within this description

        During each iteration, the gradient updates (deltas) and the loss are
        calculated for each model replica in the ensemble, weights are
        averaged over the ensemble, and the new weights are set.

        It performs calls to: MPIModel.get_deltas, MPIModel.set_new_weights
        methods

        Argument list: Empty

        Returns:
          - step: final iteration number
          - ave_loss: model loss averaged over iterations within this epoch
          - curr_loss: training loss averaged over replicas at final iteration
          - num_so_far: the cumulative number of samples seen by the ensemble
            of replicas up to the end of the final iteration (step) of this
            epoch

        Intermediate outputs and logging: debug printout of task_index (MPI),
        epoch number, number of samples seen to a current epoch, average
        training loss
        '''
        verbose = False
        first_run = True
        step = 0
        loss_averager = Averager()
        t_start = time.time()

        timeline_prof = False
        if (self.conf is not None and conf['training']['timeline_prof']):
            timeline_prof = True
        step_limit = 0
        if (self.conf is not None and conf['training']['step_limit'] > 0):
            step_limit = conf['training']['step_limit']

        batch_iterator_func = self.batch_iterator_func
        num_total = 1
        ave_loss = -1
        curr_loss = -1
        t0 = 0
        t1 = 0
        t2 = 0

        while ((self.num_so_far - self.epoch * num_total) < num_total
               or step < self.num_batches_minimum):
            if step_limit > 0 and step > step_limit:
                print('reached step limit')
                break
            try:
                (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                 num_total, is_warmup_period) = next(batch_iterator_func)
            except StopIteration:
                g.print_unique("Resetting batch iterator.")
                self.num_so_far_accum = self.num_so_far_indiv
                self.set_batch_iterator_func()
                batch_iterator_func = self.batch_iterator_func
                (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                 num_total, is_warmup_period) = next(batch_iterator_func)
            self.num_so_far_indiv = self.num_so_far_accum + num_so_far_curr

            # if batches_to_reset:
            #     self.model.reset_states(batches_to_reset)

            warmup_phase = (step < self.warmup_steps and self.epoch == 0)
            num_replicas = 1 if warmup_phase else self.num_replicas

            self.num_so_far = self.mpi_sum_scalars(self.num_so_far_indiv,
                                                   num_replicas)

            # run the model once to force compilation. Don't actually use these
            # values.
            if first_run:
                first_run = False
                t0_comp = time.time()
                # print('input_dimension:', batch_xs.shape)
                # print('output_dimension:', batch_ys.shape)
                _, _ = self.train_on_batch_and_get_deltas(batch_xs, batch_ys,
                                                          verbose)
                self.comm.Barrier()
                sys.stdout.flush()
                # TODO(KGF): check line feed/carriage returns around this
                g.print_unique('\nCompilation finished in {:.2f}s'.format(
                    time.time() - t0_comp))
                t_start = time.time()
                sys.stdout.flush()

            if np.any(batches_to_reset):
                reset_states(self.model, batches_to_reset)

            if ('noise' in self.conf['training'].keys()
                    and self.conf['training']['noise'] is not False):
                batch_xs = self.add_noise(batch_xs)

            t0 = time.time()
            deltas, loss = self.train_on_batch_and_get_deltas(batch_xs,
                                                              batch_ys, verbose)
            t1 = time.time()

            if not is_warmup_period:
                self.set_new_weights(deltas, num_replicas)
                t2 = time.time()
                write_str_0 = self.calculate_speed(t0, t1, t2, num_replicas)

                curr_loss = self.mpi_average_scalars(1.0 * loss, num_replicas)
                # g.print_unique(self.model.get_weights()[0][0][:4])
                loss_averager.add_val(curr_loss)
                ave_loss = loss_averager.get_ave()

                eta = self.estimate_remaining_time(
                    t0 - t_start, self.num_so_far - self.epoch * num_total,
                    num_total)
                write_str = (
                    '\r[{}] step: {} [ETA: {:.2f}s] [{:.2f}/{}], '.format(
                        self.task_index, step, eta, 1.0 * self.num_so_far,
                        num_total)
                    + 'loss: {:.5f} [{:.5f}] | '.format(ave_loss, curr_loss)
                    + 'walltime: {:.4f} | '.format(
                        time.time() - self.start_time))
                g.write_unique(write_str + write_str_0)

                if timeline_prof:
                    # dump profile
                    tl = timeline.Timeline(self.run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    # dump file per iteration
                    with open('timeline_%s.json' % step, 'w') as f:
                        f.write(ctf)

                step += 1
            else:
                g.write_unique('\r[{}] warmup phase, num so far: {}'.format(
                    self.task_index, self.num_so_far))

        effective_epochs = 1.0 * self.num_so_far / num_total
        epoch_previous = self.epoch
        self.epoch = effective_epochs
        g.write_unique(
            # TODO(KGF): "a total of X epochs within this session" ?
            '\nFinished training epoch {:.2f} '.format(self.epoch)
            # TODO(KGF): "precisely/exactly X epochs just passed"?
            + 'during this session ({:.2f} epochs passed)'.format(
                self.epoch - epoch_previous)
            # '\nEpoch {:.2f} finished training ({:.2f} epochs passed)'.format(
            #     1.0 * self.epoch, self.epoch - epoch_previous)
            + ' in {:.2f} seconds\n'.format(t2 - t_start))
        return (step, ave_loss, curr_loss, self.num_so_far, effective_epochs)

    def estimate_remaining_time(self, time_so_far, work_so_far, work_total):
        eps = 1e-6
        total_time = 1.0 * time_so_far * work_total / (work_so_far + eps)
        return total_time - time_so_far

    def get_effective_lr(self, num_replicas):
        effective_lr = self.lr * num_replicas
        if effective_lr > self.max_lr:
            g.write_unique(
                'Warning: effective learning rate set to {}, '.format(
                    effective_lr)
                + 'larger than maximum {}. Clipping.'.format(self.max_lr))
            effective_lr = self.max_lr
        return effective_lr

    def get_effective_batch_size(self, num_replicas):
        return self.batch_size * num_replicas

    def calculate_speed(self, t0, t_after_deltas, t_after_update, num_replicas,
                        verbose=False):
        effective_batch_size = self.get_effective_batch_size(num_replicas)
        t_calculate = t_after_deltas - t0
        t_sync = t_after_update - t_after_deltas
        t_tot = t_after_update - t0

        examples_per_sec = effective_batch_size / t_tot
        frac_calculate = t_calculate / t_tot
        frac_sync = t_sync / t_tot

        print_str = (
            '{:.2E} Examples/sec | {:.2E} sec/batch '.format(
                examples_per_sec, t_tot)
            + '[{:.1%} calc., {:.1%} sync.]'.format(frac_calculate, frac_sync))
        print_str += '[batch = {} = {}*{}] [lr = {:.2E} = {:.2E}*{}]'.format(
            effective_batch_size, self.batch_size, num_replicas,
            self.get_effective_lr(num_replicas), self.lr, num_replicas)
        if verbose:
            g.write_unique(print_str)
        return print_str
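For reference, a worked example (with hypothetical numbers) of the replica scaling performed by __init__, get_effective_lr and get_effective_batch_size above:

# Illustration only: hypothetical replica count, base lr, and batch size.
num_replicas = 16
base_lr = 0.01        # the 'lr' constructor argument
max_lr = 0.1
batch_size = 128

# __init__: the per-replica learning rate is damped as the ensemble grows
lr = (base_lr if base_lr < max_lr else max_lr) / (1.0 + num_replicas / 100.0)

# get_effective_lr: the averaged update is rescaled by the replica count,
# then clipped at max_lr (the code prints a warning before clipping)
effective_lr = min(lr * num_replicas, max_lr)

# get_effective_batch_size: samples consumed per synchronous weight update
effective_batch_size = batch_size * num_replicas

print(lr, effective_lr, effective_batch_size)
# ~0.00862, 0.1 (the unclipped value would be ~0.138), 2048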