示例#1
0
 def set_batch_iterator_func(self):
     '''
     Create self.batch_iterator_func from the self.batch_iterator factory.

     The fresh generator is wrapped in a ProcessGenerator only when self.conf
     is set and self.conf['training']['use_process_generator'] is present and
     truthy; otherwise the plain generator is used directly.
     '''
     # Read the flag from the instance's own config (not a module-level
     # ``conf``), so the method works wherever the instance was configured.
     if (self.conf is not None
             and 'use_process_generator' in self.conf['training']
             and self.conf['training']['use_process_generator']):
         self.batch_iterator_func = ProcessGenerator(self.batch_iterator())
     else:
         self.batch_iterator_func = self.batch_iterator()
示例#2
0
 def set_batch_iterator_func(self):
     '''Start a new batch generator and run it inside a ProcessGenerator.'''
     fresh_generator = self.batch_iterator()
     self.batch_iterator_func = ProcessGenerator(fresh_generator)
示例#3
0
class MPIModel():
    '''
    Data-parallel wrapper that trains one Keras model replica per MPI rank.

    Each rank computes weight deltas for its own mini-batch
    (train_on_batch_and_get_deltas). The deltas are averaged over the active
    replicas with MPI Allreduce (sync_deltas) and passed through
    ``optimizer``. Rank 0 applies them and broadcasts the resulting weights
    to every rank (set_new_weights).
    '''

    def __init__(self,
                 model,
                 optimizer,
                 comm,
                 batch_iterator,
                 batch_size,
                 num_replicas=None,
                 warmup_steps=1000,
                 lr=0.01,
                 num_batches_minimum=100):
        '''
        Argument list:
          - model: Keras model replica trained on this rank
          - optimizer: update-rule object providing set_lr() and get_deltas()
            (see set_new_weights)
          - comm: MPI communicator used for all collective operations
          - batch_iterator: zero-argument callable returning a batch generator
          - batch_size: per-replica mini-batch size
          - num_replicas: ranks that contribute to each update. None, < 1 or
            more than comm.Get_size() selects all ranks.
          - warmup_steps: number of steps at the start of epoch 0 during
            which only one replica contributes to each update
          - lr: requested learning rate; clipped at max_lr and divided by
            (1 + num_replicas / 100)
          - num_batches_minimum: minimum number of update steps per epoch
        '''
        # Seed the Python and NumPy RNGs with this rank's index. The rank is
        # read from ``comm`` so the class does not depend on a module-level
        # ``task_index`` global.
        rank = comm.Get_rank()
        random.seed(rank)
        np.random.seed(rank)
        self.start_time = time.time()
        self.epoch = 0
        self.num_so_far = 0
        self.num_so_far_accum = 0
        self.num_so_far_indiv = 0
        self.model = model
        self.optimizer = optimizer
        self.max_lr = 0.1
        # Keras optimizers are compiled with this fixed rate (see compile);
        # train_on_batch_and_get_deltas divides it back out so the real rate
        # can be applied through self.optimizer in set_new_weights.
        self.DUMMY_LR = 0.001
        self.comm = comm
        self.batch_size = batch_size
        self.batch_iterator = batch_iterator
        self.set_batch_iterator_func()
        self.warmup_steps = warmup_steps
        self.num_batches_minimum = num_batches_minimum
        self.num_workers = comm.Get_size()
        self.task_index = rank
        self.history = cbks.History()
        if (num_replicas is None or num_replicas < 1
                or num_replicas > self.num_workers):
            self.num_replicas = self.num_workers
        else:
            self.num_replicas = num_replicas
        # Clip the requested rate at max_lr, then damp it as the number of
        # replicas grows.
        base_lr = lr if lr < self.max_lr else self.max_lr
        self.lr = base_lr / (1.0 + self.num_replicas / 100.0)

    def set_batch_iterator_func(self):
        '''Start a fresh batch generator wrapped in a ProcessGenerator.'''
        self.batch_iterator_func = ProcessGenerator(self.batch_iterator())

    def close(self):
        '''Shut down the ProcessGenerator that feeds training batches.'''
        self.batch_iterator_func.__exit__()

    def set_lr(self, lr):
        '''Overwrite the stored base learning rate ``self.lr``.'''
        self.lr = lr

    def save_weights(self, path, overwrite=False):
        '''Save the wrapped model's weights to ``path``.'''
        self.model.save_weights(path, overwrite=overwrite)

    def load_weights(self, path):
        '''Load the wrapped model's weights from ``path``.'''
        self.model.load_weights(path)

    def compile(self, optimizer, clipnorm, loss='mse'):
        '''
        Compile the Keras model with an optimizer selected by name.

        Every optimizer is created with the fixed rate self.DUMMY_LR; the
        effective learning rate is applied later in set_new_weights.

        Argument list:
          - optimizer: one of 'sgd', 'momentum_sgd', 'tf_momentum_sgd',
            'adam', 'tf_adam', 'rmsprop', 'nadam'
          - clipnorm: gradient-norm clipping passed to the Keras optimizers
            (not used by the tf_* variants)
          - loss: loss passed to model.compile

        Exits the process with status 1 for an unknown optimizer name.
        '''
        if optimizer == 'sgd':
            optimizer_class = SGD(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'momentum_sgd':
            optimizer_class = SGD(lr=self.DUMMY_LR,
                                  clipnorm=clipnorm,
                                  decay=1e-6,
                                  momentum=0.9)
        elif optimizer == 'tf_momentum_sgd':
            optimizer_class = TFOptimizer(
                tf.train.MomentumOptimizer(learning_rate=self.DUMMY_LR,
                                           momentum=0.9))
        elif optimizer == 'adam':
            optimizer_class = Adam(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'tf_adam':
            optimizer_class = TFOptimizer(
                tf.train.AdamOptimizer(learning_rate=self.DUMMY_LR))
        elif optimizer == 'rmsprop':
            optimizer_class = RMSprop(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'nadam':
            optimizer_class = Nadam(lr=self.DUMMY_LR, clipnorm=clipnorm)
        else:
            print("Optimizer not implemented yet")
            # sys.exit, unlike the site-provided exit(), is always available.
            sys.exit(1)
        self.model.compile(optimizer=optimizer_class, loss=loss)

    def train_on_batch_and_get_deltas(self, X_batch, Y_batch, verbose=False):
        '''
        Compute one replica's weight update for a single mini-batch.

        Runs one Keras train step, restores the pre-step weights and returns
        the weight difference rescaled by 1 / self.DUMMY_LR (and by
        1 / loss_scale_factor when that factor is not 1).

        It performs calls to: subtract_params, multiply_params

        Argument list:
          - X_batch: input data for one mini-batch as a Numpy array
          - Y_batch: labels for one mini-batch as a Numpy array
          - verbose: set verbosity level (currently unused)

        Returns:
          - deltas: a list of model weight updates
          - loss: scalar training loss

        NOTE(review): loss_scale_factor is read from the module-level
        ``conf``; this class stores no configuration of its own.
        '''
        weights_before_update = self.model.get_weights()

        loss = self.model.train_on_batch(X_batch, Y_batch)

        weights_after_update = self.model.get_weights()
        # Undo the local step; the averaged update is applied in
        # set_new_weights.
        self.model.set_weights(weights_before_update)

        # unscale before subtracting (the Keras step used self.DUMMY_LR)
        weights_before_update = multiply_params(weights_before_update,
                                                1.0 / self.DUMMY_LR)
        weights_after_update = multiply_params(weights_after_update,
                                               1.0 / self.DUMMY_LR)

        deltas = subtract_params(weights_after_update, weights_before_update)

        # unscale loss
        if conf['model']['loss_scale_factor'] != 1.0:
            deltas = multiply_params(deltas,
                                     1.0 / conf['model']['loss_scale_factor'])

        return deltas, loss

    def get_new_weights(self, deltas):
        '''Return the current model weights with ``deltas`` added.'''
        return add_params(self.model.get_weights(), deltas)

    def _resolve_num_replicas(self, num_replicas):
        '''Return ``num_replicas``, defaulting to all workers when None.'''
        return self.num_workers if num_replicas is None else num_replicas

    def mpi_average_gradients(self, arr, num_replicas=None):
        '''
        Average a Numpy array element-wise over the first num_replicas ranks.

        Ranks with task_index >= num_replicas zero their ``arr`` in place, so
        they contribute nothing to the Allreduce sum. A float16-specific
        reduction op (mpi_sum_f16) is used when K.floatx() is 'float16'.

        Returns a new array holding the sum divided by num_replicas.
        '''
        num_replicas = self._resolve_num_replicas(num_replicas)
        if self.task_index >= num_replicas:
            arr *= 0.0
        arr_global = np.empty_like(arr)
        if K.floatx() == 'float16':
            self.comm.Allreduce(arr, arr_global, op=mpi_sum_f16)
        else:
            self.comm.Allreduce(arr, arr_global, op=MPI.SUM)
        arr_global /= num_replicas
        return arr_global

    def mpi_average_scalars(self, val, num_replicas=None):
        '''
        Compute the arithmetic mean of a scalar over num_replicas ranks.

        It performs calls to: MPIModel.mpi_sum_scalars

        Argument list:
          - val: value averaged, scalar
          - num_replicas: the size of the ensemble the average is performed
            over; None means all workers

        Returns:
          - val_global: scalar arithmetic mean over num_replicas
        '''
        # Resolve the default before dividing (dividing by None would raise).
        num_replicas = self._resolve_num_replicas(num_replicas)
        val_global = self.mpi_sum_scalars(val, num_replicas)
        return val_global / num_replicas

    def mpi_sum_scalars(self, val, num_replicas=None):
        '''
        Sum a scalar over all ranks with an MPI allreduce (op=MPI.SUM).

        Ranks with task_index >= num_replicas contribute 0, so the result is
        the sum over the first num_replicas ranks.

        Argument list:
          - val: value summed, scalar
          - num_replicas: number of contributing ranks; None means all workers

        Returns:
          - val_global: scalar sum over num_replicas ranks
        '''
        num_replicas = self._resolve_num_replicas(num_replicas)
        if self.task_index >= num_replicas:
            val *= 0.0
        return self.comm.allreduce(val, op=MPI.SUM)

    def sync_deltas(self, deltas, num_replicas=None):
        '''Average every array in ``deltas`` across the active replicas.'''
        # default is to reduce the deltas from all workers
        return [self.mpi_average_gradients(delta, num_replicas)
                for delta in deltas]

    def set_new_weights(self, deltas, num_replicas=None):
        '''
        Average ``deltas`` across replicas and apply them to the model.

        The averaged deltas are transformed by self.optimizer at the effective
        learning rate. Rank 0 adds them to its weights, and the result is
        broadcast so that every rank ends up with identical weights.
        '''
        # Resolve once so get_effective_lr never sees None.
        num_replicas = self._resolve_num_replicas(num_replicas)
        global_deltas = self.sync_deltas(deltas, num_replicas)
        effective_lr = self.get_effective_lr(num_replicas)

        self.optimizer.set_lr(effective_lr)
        global_deltas = self.optimizer.get_deltas(global_deltas)

        if self.comm.rank == 0:
            new_weights = self.get_new_weights(global_deltas)
        else:
            new_weights = None
        new_weights = self.comm.bcast(new_weights, root=0)
        self.model.set_weights(new_weights)

    def build_callbacks(self, conf, callbacks_list):
        '''
        Set up logging and history, based on Keras Callbacks
        https://github.com/fchollet/keras/blob/fbc9a18f0abc5784607cd4a2a3886558efa3f794/keras/callbacks.py

        Callbacks always included: BaseLogger, self.history and a CSVLogger
        writing to conf['paths']['csvlog_save_path']. EarlyStopping is added
        when "earlystop" appears in conf['callbacks']['list'].
        Other possible callbacks to add in future: RemoteMonitor,
        LearningRateScheduler

        Argument list:
          - conf: configuration dictionary. Relevant "callbacks" parameters:
              list: names of additional callbacks
              mode: one of {auto, min, max}. The decision to overwrite the
                current save file is based on either the maximization or the
                minimization of the monitored quantity.
              monitor: quantity used for early stopping
              patience: number of epochs without improvement before early
                stopping
          - callbacks_list: currently ignored; conf['callbacks']['list'] is
            used instead (NOTE(review): confirm callers rely on this)

        Returns: a cbks.CallbackList of the configured callbacks
        '''
        mode = conf['callbacks']['mode']
        monitor = conf['callbacks']['monitor']
        patience = conf['callbacks']['patience']
        csvlog_save_path = conf['paths']['csvlog_save_path']
        # CSV callback is on by default. exist_ok avoids a race when several
        # MPI ranks create the directory at the same time.
        os.makedirs(csvlog_save_path, exist_ok=True)

        callbacks_list = conf['callbacks']['list']

        callbacks = [cbks.BaseLogger()]
        callbacks += [self.history]
        callbacks += [
            cbks.CSVLogger("{}callbacks-{}.log".format(
                csvlog_save_path,
                datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")))
        ]

        if "earlystop" in callbacks_list:
            callbacks += [
                cbks.EarlyStopping(patience=patience,
                                   monitor=monitor,
                                   mode=mode)
            ]
        if "lr_scheduler" in callbacks_list:
            pass

        return cbks.CallbackList(callbacks)

    def train_epoch(self):
        '''
        Run distributed mini-batch SGD for one (effective) epoch.

        Batches are drawn from self.batch_iterator_func, which is restarted on
        StopIteration. Drawing continues until the samples seen by the
        ensemble in this epoch reach the training-set size reported by the
        iterator (num_total) and at least self.num_batches_minimum updates have
        been made. During the first self.warmup_steps steps of epoch 0, only
        one replica contributes to each update.

        The first batch is also run once up front to force compilation, and
        that result is discarded. For batches the iterator flags as warm-up,
        deltas are computed but not applied.

        It performs calls to: MPIModel.train_on_batch_and_get_deltas,
        MPIModel.set_new_weights

        Argument list: Empty

        Returns a tuple:
          - step: number of weight updates performed in this call
          - ave_loss: running mean of the replica-averaged loss (-1 if none)
          - curr_loss: replica-averaged loss of the last update (-1 if none)
          - num_so_far: number of samples seen by the ensemble so far
          - effective_epochs: num_so_far / num_total

        Intermediate outputs and logging: per-step progress line (task index,
        step, ETA, samples seen, losses, timing) and an end-of-epoch summary,
        printed via print_unique.
        '''

        verbose = False
        first_run = True
        step = 0
        loss_averager = Averager()
        t_start = time.time()

        batch_iterator_func = self.batch_iterator_func
        num_total = 1
        ave_loss = -1
        curr_loss = -1
        t0 = 0
        t1 = 0
        t2 = 0

        while (self.num_so_far - self.epoch *
               num_total) < num_total or step < self.num_batches_minimum:

            try:
                batch_xs, batch_ys, batches_to_reset, num_so_far_curr, num_total, is_warmup_period = next(
                    batch_iterator_func)
            except StopIteration:
                print("Resetting batch iterator.")
                self.num_so_far_accum = self.num_so_far_indiv
                self.set_batch_iterator_func()
                batch_iterator_func = self.batch_iterator_func
                batch_xs, batch_ys, batches_to_reset, num_so_far_curr, num_total, is_warmup_period = next(
                    batch_iterator_func)
            self.num_so_far_indiv = self.num_so_far_accum + num_so_far_curr

            # Only rank 0 contributes during the initial warm-up steps.
            warmup_phase = (step < self.warmup_steps and self.epoch == 0)
            num_replicas = 1 if warmup_phase else self.num_replicas

            self.num_so_far = self.mpi_sum_scalars(self.num_so_far_indiv,
                                                   num_replicas)

            # run the model once to force compilation. Don't actually use
            # these values.
            if first_run:
                first_run = False
                t0_comp = time.time()
                _, _ = self.train_on_batch_and_get_deltas(
                    batch_xs, batch_ys, verbose)
                self.comm.Barrier()
                sys.stdout.flush()
                print_unique(
                    'Compilation finished in {:.2f}s'.format(time.time() -
                                                             t0_comp))
                t_start = time.time()
                sys.stdout.flush()

            if np.any(batches_to_reset):
                reset_states(self.model, batches_to_reset)

            t0 = time.time()
            deltas, loss = self.train_on_batch_and_get_deltas(
                batch_xs, batch_ys, verbose)
            t1 = time.time()
            if not is_warmup_period:
                self.set_new_weights(deltas, num_replicas)
                t2 = time.time()
                write_str_0 = self.calculate_speed(t0, t1, t2, num_replicas)

                curr_loss = self.mpi_average_scalars(1.0 * loss, num_replicas)
                loss_averager.add_val(curr_loss)
                ave_loss = loss_averager.get_val()
                eta = self.estimate_remaining_time(
                    t0 - t_start, self.num_so_far - self.epoch * num_total,
                    num_total)
                write_str = '\r[{}] step: {} [ETA: {:.2f}s] [{:.2f}/{}], loss: {:.5f} [{:.5f}] | walltime: {:.4f} | '.format(
                    self.task_index, step, eta, 1.0 * self.num_so_far,
                    num_total, ave_loss, curr_loss,
                    time.time() - self.start_time)
                print_unique(write_str + write_str_0)
                step += 1
            else:
                print_unique('\r[{}] warmup phase, num so far: {}'.format(
                    self.task_index, self.num_so_far))

        effective_epochs = 1.0 * self.num_so_far / num_total
        epoch_previous = self.epoch
        self.epoch = effective_epochs
        print_unique(
            '\nEpoch {:.2f} finished ({:.2f} epochs passed) in {:.2f} seconds.\n'
            .format(1.0 * self.epoch, self.epoch - epoch_previous,
                    t2 - t_start))
        return (step, ave_loss, curr_loss, self.num_so_far, effective_epochs)

    def estimate_remaining_time(self, time_so_far, work_so_far, work_total):
        '''Linearly extrapolate the time left to finish ``work_total``.'''
        eps = 1e-6  # avoids division by zero before any work is done
        total_time = 1.0 * time_so_far * work_total / (work_so_far + eps)
        return total_time - time_so_far

    def get_effective_lr(self, num_replicas):
        '''Return self.lr * num_replicas, clipped (with a warning) at max_lr.'''
        effective_lr = self.lr * num_replicas
        if effective_lr > self.max_lr:
            print_unique(
                'Warning: effective learning rate set to {}, larger than maximum {}. Clipping.'
                .format(effective_lr, self.max_lr))
            effective_lr = self.max_lr
        return effective_lr

    def get_effective_batch_size(self, num_replicas):
        '''Return the ensemble batch size (per-replica size * replicas).'''
        return self.batch_size * num_replicas

    def calculate_speed(self,
                        t0,
                        t_after_deltas,
                        t_after_update,
                        num_replicas,
                        verbose=False):
        '''
        Format throughput and timing statistics for one update step.

        Argument list:
          - t0: time before computing deltas
          - t_after_deltas: time after computing deltas
          - t_after_update: time after synchronising and applying the update
          - num_replicas: replicas that contributed to the update
          - verbose: also print the string via print_unique

        Returns: the formatted statistics string
        '''
        effective_batch_size = self.get_effective_batch_size(num_replicas)
        t_calculate = t_after_deltas - t0
        t_sync = t_after_update - t_after_deltas
        t_tot = t_after_update - t0

        examples_per_sec = effective_batch_size / t_tot
        frac_calculate = t_calculate / t_tot
        frac_sync = t_sync / t_tot

        print_str = '{:.2E} Examples/sec | {:.2E} sec/batch [{:.1%} calc., {:.1%} synch.]'.format(
            examples_per_sec, t_tot, frac_calculate, frac_sync)
        print_str += '[batch = {} = {}*{}] [lr = {:.2E} = {:.2E}*{}]'.format(
            effective_batch_size, self.batch_size, num_replicas,
            self.get_effective_lr(num_replicas), self.lr, num_replicas)
        if verbose:
            print_unique(print_str)
        return print_str
示例#4
0
def train(conf, shot_list_train, shot_list_validate, loader):
    '''
    Train a Keras model in a single process, epoch by epoch.

    The model is built and compiled through plasma.models.builder and
    resumed from the latest saved epoch. Batches come from a ProcessGenerator
    over loader.training_batch_generator_partial_reset. After every epoch the
    weights are saved. When conf['training']['validation_frac'] > 0, the
    model is also evaluated on shot_list_validate, and the just-saved weights
    are deleted if the monitored quantity did not improve.

    Argument list:
      - conf: configuration dictionary
      - shot_list_train: shots used for training
      - shot_list_validate: shots used for validation
      - loader: data loader providing the batch generator

    NOTE(review): relies on module-level names ``backend``, ``lr`` and
    ``optimizer_class`` that are not defined in this function.
    '''
    loader.set_inference_mode(False)
    np.random.seed(1)

    validation_losses = []
    validation_roc = []
    training_losses = []
    print('validate: {} shots, {} disruptive'.format(
        len(shot_list_validate), shot_list_validate.num_disruptive()))
    print('training: {} shots, {} disruptive'.format(
        len(shot_list_train), shot_list_train.num_disruptive()))

    # The backend environment must be configured before keras is imported.
    if backend == 'tf' or backend == 'tensorflow':
        first_time = "tensorflow" not in sys.modules
        if first_time:
            import tensorflow as tf
            os.environ['KERAS_BACKEND'] = 'tensorflow'
            from keras.backend.tensorflow_backend import set_session
            config = tf.ConfigProto(device_count={"GPU": 1})
            set_session(tf.Session(config=config))
    else:
        os.environ['KERAS_BACKEND'] = 'theano'
        os.environ['THEANO_FLAGS'] = 'device=gpu,floatX=float32'
        import theano

    from keras.utils.generic_utils import Progbar
    from keras import backend as K
    from plasma.models import builder

    print('Build model...', end='')
    specific_builder = builder.ModelBuilder(conf)
    train_model = specific_builder.build_model(False)
    print('Compile model', end='')
    train_model.compile(optimizer=optimizer_class(),
                        loss=conf['data']['target'].loss)
    print('...done')

    # load the latest epoch we did. Returns -1 if none exist yet
    e = specific_builder.load_model_weights(train_model)
    e_start = e
    batch_generator = partial(loader.training_batch_generator_partial_reset,
                              shot_list=shot_list_train)
    batch_iterator = ProcessGenerator(batch_generator())
    # Ensure the background batch producer is shut down even if training is
    # interrupted by an exception (previously it was only closed at the end).
    try:
        num_epochs = conf['training']['num_epochs']
        num_at_once = conf['training']['num_shots_at_once']
        lr_decay = conf['model']['lr_decay']
        print('{} epochs left to go'.format(num_epochs - 1 - e))
        num_so_far_accum = 0
        num_so_far = 0
        num_total = np.inf

        if conf['callbacks']['mode'] == 'max':
            best_so_far = -np.inf
            cmp_fn = max
        else:
            best_so_far = np.inf
            cmp_fn = min

        while e < num_epochs - 1:
            e += 1
            print('\nEpoch {}/{}'.format(e + 1, num_epochs))
            pbar = Progbar(len(shot_list_train))

            # decay learning rate each epoch:
            K.set_value(train_model.optimizer.lr, lr * lr_decay**(e))

            num_batches_minimum = 100
            num_batches_current = 0
            training_losses_tmp = []

            # Draw batches until this epoch's share of samples has been seen
            # and at least num_batches_minimum batches were trained on.
            while (num_so_far < (e - e_start) * num_total
                   or num_batches_current < num_batches_minimum):
                num_so_far_old = num_so_far
                try:
                    (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                     num_total, is_warmup_period) = next(batch_iterator)
                except StopIteration:
                    print("Resetting batch iterator.")
                    num_so_far_accum = num_so_far
                    batch_iterator = ProcessGenerator(batch_generator())
                    (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                     num_total, is_warmup_period) = next(batch_iterator)
                if np.any(batches_to_reset):
                    reset_states(train_model, batches_to_reset)
                if not is_warmup_period:
                    num_so_far = num_so_far_accum + num_so_far_curr

                    num_batches_current += 1

                    loss = train_model.train_on_batch(batch_xs, batch_ys)
                    training_losses_tmp.append(loss)
                    pbar.add(num_so_far - num_so_far_old,
                             values=[("train loss", loss)])
                    loader.verbose = False  # True during the first iteration
                else:
                    # Warm-up batches only run predict(); no weight update.
                    _ = train_model.predict(
                        batch_xs, batch_size=conf['training']['batch_size'])

            e = e_start + 1.0 * num_so_far / num_total
            sys.stdout.flush()
            ave_loss = np.mean(training_losses_tmp)
            training_losses.append(ave_loss)
            specific_builder.save_model_weights(train_model, int(round(e)))

            if conf['training']['validation_frac'] > 0.0:
                print("prediction on GPU...")
                _, _, _, roc_area, loss = make_predictions_and_evaluate_gpu(
                    conf, shot_list_validate, loader)
                validation_losses.append(loss)
                validation_roc.append(roc_area)

                epoch_logs = {}
                epoch_logs['val_roc'] = roc_area
                epoch_logs['val_loss'] = loss
                epoch_logs['train_loss'] = ave_loss
                best_so_far = cmp_fn(epoch_logs[conf['callbacks']['monitor']],
                                     best_so_far)
                # only keep model weights if the tracked quantity improved
                if best_so_far != epoch_logs[conf['callbacks']['monitor']]:
                    print("Not saving model weights")
                    specific_builder.delete_model_weights(train_model,
                                                          int(round(e)))

                if conf['training']['ranking_difficulty_fac'] != 1.0:
                    _, _, _, roc_area_train, loss_train = (
                        make_predictions_and_evaluate_gpu(
                            conf, shot_list_train, loader))
                    batch_iterator.__exit__()
                    batch_generator = partial(
                        loader.training_batch_generator_partial_reset,
                        shot_list=shot_list_train)
                    batch_iterator = ProcessGenerator(batch_generator())
                    num_so_far_accum = num_so_far

            print('=========Summary========')
            print('Training Loss Numpy: {:.3e}'.format(training_losses[-1]))
            if conf['training']['validation_frac'] > 0.0:
                print('Validation Loss: {:.3e}'.format(validation_losses[-1]))
                print('Validation ROC: {:.4f}'.format(validation_roc[-1]))
                if conf['training']['ranking_difficulty_fac'] != 1.0:
                    print('Train Loss: {:.3e}'.format(loss_train))
                    print('Train ROC: {:.4f}'.format(roc_area_train))

        # plot_losses(conf,[training_losses],specific_builder,name='training')
        if conf['training']['validation_frac'] > 0.0:
            plot_losses(conf,
                        [training_losses, validation_losses, validation_roc],
                        specific_builder,
                        name='training_validation_roc')
    finally:
        batch_iterator.__exit__()
    print('...done')
示例#5
0
class MPIModel():
    def __init__(self,
                 model,
                 optimizer,
                 comm,
                 batch_iterator,
                 batch_size,
                 num_replicas=None,
                 warmup_steps=1000,
                 lr=0.01,
                 num_batches_minimum=100,
                 conf=None):
        '''
        Wrap one replica of ``model`` for MPI data-parallel training.

        Argument list:
          - model: Keras model trained on this rank
          - optimizer: update-rule object (set_new_weights calls its set_lr
            and get_deltas)
          - comm: MPI communicator; its size and rank are cached
          - batch_iterator: zero-argument callable returning a batch generator
          - batch_size: per-replica mini-batch size
          - num_replicas: replicas per update. None, < 1 or more than
            comm.Get_size() selects all workers.
          - warmup_steps: number of warm-up steps (stored as given)
          - lr: requested learning rate; clipped at max_lr and divided by
            (1 + num_replicas / 100)
          - num_batches_minimum: minimum batches per epoch (stored as given)
          - conf: optional configuration dictionary
        '''
        # Seed both RNGs from the global task index (presumably the MPI rank).
        seed = g.task_index
        random.seed(seed)
        np.random.seed(seed)
        self.conf = conf
        self.start_time = time.time()

        # Progress counters.
        self.epoch = 0
        self.num_so_far = 0
        self.num_so_far_accum = 0
        self.num_so_far_indiv = 0

        self.model = model
        self.optimizer = optimizer
        self.max_lr = 0.1
        # Fixed rate the Keras optimizers are created with (see compile()).
        self.DUMMY_LR = 0.001
        self.batch_size = batch_size
        self.batch_iterator = batch_iterator
        # Depends on self.conf and self.batch_iterator being set already.
        self.set_batch_iterator_func()
        self.warmup_steps = warmup_steps
        self.num_batches_minimum = num_batches_minimum

        # TODO(KGF): duplicate/may be in conflict with global_vars.py
        self.comm = comm
        self.num_workers = comm.Get_size()
        self.task_index = comm.Get_rank()
        self.history = cbks.History()
        self.model.stop_training = False

        use_all_workers = (num_replicas is None or num_replicas < 1
                           or num_replicas > self.num_workers)
        self.num_replicas = (self.num_workers if use_all_workers
                             else num_replicas)

        base_lr = lr if lr < self.max_lr else self.max_lr
        self.lr = base_lr / (1.0 + self.num_replicas / 100.0)

    def set_batch_iterator_func(self):
        if (self.conf is not None
                and 'use_process_generator' in conf['training']
                and conf['training']['use_process_generator']):
            self.batch_iterator_func = ProcessGenerator(self.batch_iterator())
        else:
            self.batch_iterator_func = self.batch_iterator()

    def close(self):
        # TODO(KGF): extend __exit__() fn capability when this member
        # = self.batch_iterator() (i.e. is not a ProcessGenerator())
        if (self.conf is not None
                and 'use_process_generator' in conf['training']
                and conf['training']['use_process_generator']):
            self.batch_iterator_func.__exit__()

    def set_lr(self, lr):
        self.lr = lr

    # KGF: Unused. model.*() called directly in builder.py
    # def save_weights(self, path, overwrite=False):
    #     self.model.save_weights(path, overwrite=overwrite)

    # def load_weights(self, path):
    #     self.model.load_weights(path)

    def compile(self, optimizer, clipnorm, loss='mse'):
        '''
        Compile the wrapped Keras model with an optimizer chosen by name.

        All optimizers are created with the placeholder rate self.DUMMY_LR.
        If self.conf['training']['timeline_prof'] is truthy, the model is
        compiled with TensorFlow full-trace RunOptions/RunMetadata, which are
        stored on self.run_options / self.run_metadata. Afterwards,
        ensure_equal_weights() broadcasts task 0's weights to all tasks.

        Argument list:
          - optimizer: 'sgd', 'momentum_sgd', 'tf_momentum_sgd', 'adam',
            'tf_adam', 'rmsprop' or 'nadam'
          - clipnorm: gradient-norm clipping for the Keras optimizers
            (not used by the tf_* variants)
          - loss: loss passed to model.compile

        Exits the process with status 1 for an unrecognised optimizer name.
        '''
        import sys
        # TODO(KGF): check the following import taken from runner.py
        # Was not in this file, originally.
        from tensorflow.keras.optimizers import (SGD, Adam, RMSprop, Nadam)
        if optimizer == 'sgd':
            optimizer_class = SGD(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'momentum_sgd':
            optimizer_class = SGD(lr=self.DUMMY_LR,
                                  clipnorm=clipnorm,
                                  decay=1e-6,
                                  momentum=0.9)
        elif optimizer == 'tf_momentum_sgd':
            # TODO(KGF): removed TFOptimizer wrapper from here and below
            # may not work anymore? See
            # https://github.com/tensorflow/tensorflow/issues/22780
            optimizer_class = tf.train.MomentumOptimizer(
                learning_rate=self.DUMMY_LR, momentum=0.9)
        elif optimizer == 'adam':
            optimizer_class = Adam(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'tf_adam':
            optimizer_class = tf.train.AdamOptimizer(
                learning_rate=self.DUMMY_LR)
        elif optimizer == 'rmsprop':
            optimizer_class = RMSprop(lr=self.DUMMY_LR, clipnorm=clipnorm)
        elif optimizer == 'nadam':
            optimizer_class = Nadam(lr=self.DUMMY_LR, clipnorm=clipnorm)
        else:
            print("Optimizer not implemented yet")
            # sys.exit, unlike the site-provided exit(), is always available.
            sys.exit(1)

        # Timeline profiler: the flag comes from this instance's config
        # (a missing key is treated as disabled).
        timeline_prof = (self.conf is not None
                         and self.conf['training'].get('timeline_prof', False))
        if timeline_prof:
            self.run_options = tf.RunOptions(
                trace_level=tf.RunOptions.FULL_TRACE)
            self.run_metadata = tf.RunMetadata()
            self.model.compile(optimizer=optimizer_class,
                               loss=loss,
                               options=self.run_options,
                               run_metadata=self.run_metadata)
        else:
            self.model.compile(optimizer=optimizer_class, loss=loss)

        self.ensure_equal_weights()

    def ensure_equal_weights(self):
        '''Broadcast task 0's model weights and load them on every task.'''
        local_weights = self.model.get_weights() if g.task_index == 0 else None
        synced_weights = g.comm.bcast(local_weights, root=0)
        self.model.set_weights(synced_weights)

    def train_on_batch_and_get_deltas(self, X_batch, Y_batch, verbose=False):
        '''
        The purpose of the method is to perform a single gradient update over
        one mini-batch for one model replica. Given a mini-batch, it first
        accesses the current model weights and performs a single gradient
        update over the mini-batch. It then gets the new model weights,
        restores the old ones, and returns the weight updates (deltas)
        rescaled by 1 / self.DUMMY_LR. The deltas are also rescaled by
        1 / loss_scale_factor when that factor is not 1.

        Both 'return_sequences' and 'loss_scale_factor' are read from
        self.conf['model'].

        It performs calls to: subtract_params, multiply_params

        Argument list:
          - X_batch: input data for one mini-batch as a Numpy array
          - Y_batch: labels for one mini-batch as a Numpy array
          - verbose: set verbosity level (currently unused)

        Returns:
          - deltas: a list of model weight updates
          - loss: scalar training loss

        '''
        weights_before_update = self.model.get_weights()

        return_sequences = self.conf['model']['return_sequences']
        if not return_sequences:
            # Keep only the last entry along axis 1 (presumably the final
            # time step).
            Y_batch = Y_batch[:, -1, :]
        loss = self.model.train_on_batch(X_batch, Y_batch)

        weights_after_update = self.model.get_weights()
        # Restore the pre-step weights; the caller is expected to apply the
        # (averaged) deltas, e.g. via set_new_weights.
        self.model.set_weights(weights_before_update)

        # unscale before subtracting (the Keras step used self.DUMMY_LR)
        weights_before_update = multiply_params(weights_before_update,
                                                1.0 / self.DUMMY_LR)
        weights_after_update = multiply_params(weights_after_update,
                                               1.0 / self.DUMMY_LR)

        deltas = subtract_params(weights_after_update, weights_before_update)

        # unscale loss; read from self.conf like return_sequences above
        # instead of a module-level ``conf``.
        loss_scale_factor = self.conf['model']['loss_scale_factor']
        if loss_scale_factor != 1.0:
            deltas = multiply_params(deltas, 1.0 / loss_scale_factor)

        return deltas, loss

    def get_new_weights(self, deltas):
        '''Return the model's current weights with ``deltas`` added.'''
        current_weights = self.model.get_weights()
        return add_params(current_weights, deltas)

    def mpi_average_gradients(self, arr, num_replicas=None):
        '''
        Average a Numpy array element-wise over the first num_replicas ranks.

        Ranks with task_index >= num_replicas zero ``arr`` in place before
        the Allreduce, so they add nothing to the sum. mpi_sum_f16 is used as
        the reduction op when K.floatx() is 'float16', otherwise MPI.SUM.
        num_replicas defaults to all workers.

        Returns a new array: the global sum divided by num_replicas.
        '''
        replicas = self.num_workers if num_replicas is None else num_replicas
        if self.task_index >= replicas:
            arr *= 0.0
        reduce_op = mpi_sum_f16 if K.floatx() == 'float16' else MPI.SUM
        summed = np.empty_like(arr)
        self.comm.Allreduce(arr, summed, op=reduce_op)
        summed /= replicas
        return summed

    def mpi_average_scalars(self, val, num_replicas=None):
        '''
        The purpose of the method is to calculate a simple scalar arithmetic
        mean over num_replicas.

        It performs calls to: MPIModel.mpi_sum_scalars

        Argument list:
          - val: value averaged, scalar
          - num_replicas: the size of the ensemble the average is performed
            over; None means all workers

        Returns:
          - val_global: scalar arithmetic mean over num_replicas
        '''
        # Resolve the default here too: dividing by None would raise.
        if num_replicas is None:
            num_replicas = self.num_workers
        val_global = self.mpi_sum_scalars(val, num_replicas)
        return val_global / num_replicas

    def mpi_sum_scalars(self, val, num_replicas=None):
        '''
        Sum a scalar over all ranks with an MPI allreduce (op=MPI.SUM).

        Ranks with task_index >= num_replicas contribute 0, so the result is
        the sum over the first num_replicas ranks.

        Argument list:
          - val: value summed, scalar
          - num_replicas: number of contributing ranks; None means all workers

        Returns:
          - val_global: scalar sum over num_replicas ranks
        '''
        if num_replicas is None:
            num_replicas = self.num_workers
        if self.task_index >= num_replicas:
            val *= 0.0
        return self.comm.allreduce(val, op=MPI.SUM)

    def sync_deltas(self, deltas, num_replicas=None):
        """Average every per-layer delta across replicas.

        Each element of ``deltas`` is passed, in order, to
        ``mpi_average_gradients`` (by default all workers are reduced), and
        the averaged results are returned as a new list.
        """
        return [self.mpi_average_gradients(layer_delta, num_replicas)
                for layer_delta in deltas]

    def set_new_weights(self, deltas, num_replicas=None):
        """Synchronize ``deltas`` across replicas and apply them to the model.

        The deltas are averaged over replicas, the optimizer's learning rate
        is set to ``get_effective_lr(num_replicas)``, the optimizer converts
        the averaged deltas into updates, and the updated weights are written
        back into the model.
        """
        averaged = self.sync_deltas(deltas, num_replicas)
        self.optimizer.set_lr(self.get_effective_lr(num_replicas))
        updates = self.optimizer.get_deltas(averaged)
        self.model.set_weights(self.get_new_weights(updates))

    def build_callbacks(self, conf, callbacks_list):
        '''
        Set up logging and history. It is based on Keras Callbacks
        https://github.com/fchollet/keras/blob/fbc9a18f0abc5784607cd4a2a3886558efa3f794/keras/callbacks.py

        Currently used callbacks include: BaseLogger, History (self.history),
        CSVLogger and, when requested, EarlyStopping.
        Other possible callbacks to add in future: RemoteMonitor,
        LearningRateScheduler

        Argument list:
        - conf: There is a "callbacks" section in conf.yaml file.

        Relevant parameters are:
        - list: Parameter specifying additional callbacks. "earlystop" adds
          EarlyStopping; "lr_scheduler" is accepted but currently does nothing.

        - metrics: List of quantities monitored during training and validation
          (not read by this method)

        - mode: one of {auto, min, max}. The decision to overwrite the current
          save file is made based on either the maximization or the
          minimization of the monitored quantity. For val_acc, this should be
          max, for val_loss this should be min, etc. In auto mode, the
          direction is automatically inferred from the name of the monitored
          quantity.

        - monitor: Quantity used for early stopping, has to be from the list
          of metrics

        - patience: Number of epochs used to decide on whether to apply early
          stopping or continue training

        Additionally conf['paths']['csvlog_save_path'] is the directory the
        CSV log is written into; it is created if missing.

        - callbacks_list: currently ignored; the list is always re-read from
          conf['callbacks']['list']

        Returns: a cbks.CallbackList wrapping the configured callbacks
        '''

        mode = conf['callbacks']['mode']
        monitor = conf['callbacks']['monitor']
        patience = conf['callbacks']['patience']
        csvlog_save_path = conf['paths']['csvlog_save_path']
        # CSV callback is on by default. exist_ok avoids the exists/makedirs
        # race when several processes (e.g. MPI ranks) run this concurrently.
        os.makedirs(csvlog_save_path, exist_ok=True)

        # NOTE: the callbacks_list argument is overridden by the config value.
        callbacks_list = conf['callbacks']['list']
        callbacks = [cbks.BaseLogger()]
        callbacks += [self.history]
        # os.path.join keeps the log inside csvlog_save_path even when the
        # configured path has no trailing separator (identical result when
        # it does).
        log_name = "callbacks-{}.log".format(
            datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
        callbacks += [
            cbks.CSVLogger(os.path.join(csvlog_save_path, log_name))
        ]

        if "earlystop" in callbacks_list:
            callbacks += [
                cbks.EarlyStopping(patience=patience,
                                   monitor=monitor,
                                   mode=mode)
            ]
        if "lr_scheduler" in callbacks_list:
            pass

        return cbks.CallbackList(callbacks)

    def add_noise(self, X):
        '''
        Randomly zero out slices of a batch along axis 1.

        For every index pair (i, j) over axes 0 and 2, the slice X[i, :, j]
        is set to 0.0 with probability ``prob``. ``prob`` is 0.05 when
        self.conf['training']['noise'] is True; otherwise the configured
        value itself is used as the probability.

        Argument list:
          - X: 3-D array (NOTE(review): the meaning of each axis is not
            visible here — confirm against the loader)

        Returns:
          - X: the same array, modified in place
        '''
        if self.conf['training']['noise'] is True:
            prob = 0.05
        else:
            prob = self.conf['training']['noise']
        for i in range(X.shape[0]):
            for j in range(X.shape[2]):
                # random.random() is uniform on [0, 1), so this fires with
                # probability exactly ``prob``. The former
                # ``randint(0, 100) < prob * 100`` drew from 101 values and
                # was biased (e.g. 5/101 instead of 0.05).
                if random.random() < prob:
                    X[i, :, j] = 0.0
        return X

    def train_epoch(self):
        '''
        Perform distributed mini-batch SGD for one epoch.

        It takes the batch iterator function and a NN model from the MPIModel
        object and fetches mini-batches in a while-loop until the number of
        samples seen by the ensemble of workers (num_so_far) exceeds the
        training dataset size (num_total) and at least num_batches_minimum
        steps have been taken.

        NOTE: "sample" = "an entire shot" within this description

        During each iteration, the gradient updates (deltas) and the loss are
        calculated for each model replica in the ensemble, weights are
        averaged over the ensemble, and the new weights are set.

        Options 'timeline_prof', 'step_limit' and 'noise' are read from
        self.conf['training'] when self.conf is not None.

        It performs calls to: MPIModel.train_on_batch_and_get_deltas,
        MPIModel.set_new_weights methods

        Argument list: Empty

        Returns:
          - step: final iteration number
          - ave_loss: model loss averaged over iterations within this epoch
          - curr_loss: training loss averaged over replicas at final iteration
          - num_so_far: the cumulative number of samples seen by the ensemble
            of replicas up to the end of the final iteration (step) of this
            epoch
          - effective_epochs: num_so_far / num_total, the fractional number of
            epochs completed so far

        Intermediate outputs and logging: debug printout of task_index (MPI),
        epoch number, number of samples seen to a current epoch, average
        training loss
        '''

        verbose = False
        first_run = True
        step = 0
        loss_averager = Averager()
        t_start = time.time()

        # The None-check and the lookup must refer to the same object, so the
        # options are read from self.conf (previously from a global ``conf``).
        timeline_prof = False
        if (self.conf is not None
                and self.conf['training']['timeline_prof']):
            timeline_prof = True
        step_limit = 0
        if (self.conf is not None
                and self.conf['training']['step_limit'] > 0):
            step_limit = self.conf['training']['step_limit']

        batch_iterator_func = self.batch_iterator_func
        num_total = 1
        ave_loss = -1
        curr_loss = -1
        t0 = 0
        t1 = 0
        t2 = 0

        while ((self.num_so_far - self.epoch * num_total) < num_total
               or step < self.num_batches_minimum):
            # NOTE(review): ``>`` lets step_limit + 1 updates run before
            # stopping — confirm whether ``>=`` was intended.
            if step_limit > 0 and step > step_limit:
                print('reached step limit')
                break
            try:
                (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                 num_total, is_warmup_period) = next(batch_iterator_func)
            except StopIteration:
                g.print_unique("Resetting batch iterator.")
                self.num_so_far_accum = self.num_so_far_indiv
                self.set_batch_iterator_func()
                batch_iterator_func = self.batch_iterator_func
                (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                 num_total, is_warmup_period) = next(batch_iterator_func)
            self.num_so_far_indiv = self.num_so_far_accum + num_so_far_curr

            # if batches_to_reset:
            # self.model.reset_states(batches_to_reset)

            warmup_phase = (step < self.warmup_steps and self.epoch == 0)
            num_replicas = 1 if warmup_phase else self.num_replicas

            # Collective call: every rank must reach it each iteration.
            self.num_so_far = self.mpi_sum_scalars(self.num_so_far_indiv,
                                                   num_replicas)

            # run the model once to force compilation. Don't actually use
            # these values.
            if first_run:
                first_run = False
                t0_comp = time.time()
                _, _ = self.train_on_batch_and_get_deltas(
                    batch_xs, batch_ys, verbose)
                self.comm.Barrier()
                sys.stdout.flush()
                # TODO(KGF): check line feed/carriage returns around this
                g.print_unique(
                    '\nCompilation finished in {:.2f}s'.format(time.time() -
                                                               t0_comp))
                t_start = time.time()
                sys.stdout.flush()

            if np.any(batches_to_reset):
                reset_states(self.model, batches_to_reset)
            if (self.conf is not None
                    and 'noise' in self.conf['training']
                    and self.conf['training']['noise'] is not False):
                batch_xs = self.add_noise(batch_xs)
            t0 = time.time()
            deltas, loss = self.train_on_batch_and_get_deltas(
                batch_xs, batch_ys, verbose)
            t1 = time.time()
            if not is_warmup_period:
                self.set_new_weights(deltas, num_replicas)
                t2 = time.time()
                write_str_0 = self.calculate_speed(t0, t1, t2, num_replicas)
                curr_loss = self.mpi_average_scalars(1.0 * loss, num_replicas)
                loss_averager.add_val(curr_loss)
                ave_loss = loss_averager.get_ave()
                eta = self.estimate_remaining_time(
                    t0 - t_start, self.num_so_far - self.epoch * num_total,
                    num_total)
                write_str = (
                    '\r[{}] step: {} [ETA: {:.2f}s] [{:.2f}/{}], '.format(
                        self.task_index, step, eta, 1.0 * self.num_so_far,
                        num_total) +
                    'loss: {:.5f} [{:.5f}] | '.format(ave_loss, curr_loss) +
                    'walltime: {:.4f} | '.format(time.time() -
                                                 self.start_time))
                g.write_unique(write_str + write_str_0)

                if timeline_prof:
                    # dump profile
                    tl = timeline.Timeline(self.run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    # dump file per iteration
                    with open('timeline_%s.json' % step, 'w') as f:
                        f.write(ctf)

                step += 1
            else:
                g.write_unique('\r[{}] warmup phase, num so far: {}'.format(
                    self.task_index, self.num_so_far))

        effective_epochs = 1.0 * self.num_so_far / num_total
        epoch_previous = self.epoch
        self.epoch = effective_epochs
        g.write_unique(
            # TODO(KGF): "a total of X epochs within this session" ?
            '\nFinished training epoch {:.2f} '.format(self.epoch)
            # TODO(KGF): "precisely/exactly X epochs just passed"?
            +
            'during this session ({:.2f} epochs passed)'.format(self.epoch -
                                                                epoch_previous)
            + ' in {:.2f} seconds\n'.format(t2 - t_start))
        return (step, ave_loss, curr_loss, self.num_so_far, effective_epochs)

    def estimate_remaining_time(self, time_so_far, work_so_far, work_total):
        """Extrapolate the time left, assuming a constant rate of progress.

        The projected total is ``time_so_far * work_total / work_so_far``,
        with 1e-6 added to the denominator so zero progress does not divide
        by zero. The time already spent is then subtracted.
        """
        zero_guard = 1e-6
        projected_total = (1.0 * time_so_far * work_total
                           / (work_so_far + zero_guard))
        return projected_total - time_so_far

    def get_effective_lr(self, num_replicas):
        """Scale the base learning rate by the replica count, capped at max_lr.

        A warning is written through ``g.write_unique`` when the cap applies.
        """
        scaled_lr = self.lr * num_replicas
        if not (scaled_lr > self.max_lr):
            return scaled_lr
        g.write_unique(
            'Warning: effective learning rate set to {}, '.format(
                scaled_lr) +
            'larger than maximum {}. Clipping.'.format(self.max_lr))
        return self.max_lr

    def get_effective_batch_size(self, num_replicas):
        """Return the global batch size: per-replica size times replicas."""
        per_replica = self.batch_size
        return per_replica * num_replicas

    def calculate_speed(self,
                        t0,
                        t_after_deltas,
                        t_after_update,
                        num_replicas,
                        verbose=False):
        """Build a one-line throughput/timing summary for one training step.

        The step is split into a compute phase (t0 -> t_after_deltas) and a
        synchronization phase (t_after_deltas -> t_after_update). The summary
        reports examples/sec over the global batch, seconds per batch, each
        phase's share of the step, and the batch-size and learning-rate
        breakdowns. It is written via ``g.write_unique`` only when
        ``verbose`` is set, and always returned.
        """
        global_batch = self.get_effective_batch_size(num_replicas)
        step_time = t_after_update - t0
        calc_share = (t_after_deltas - t0) / step_time
        sync_share = (t_after_update - t_after_deltas) / step_time
        throughput = global_batch / step_time

        timing_part = '{:.2E} Examples/sec | {:.2E} sec/batch '.format(
            throughput, step_time)
        share_part = '[{:.1%} calc., {:.1%} sync.]'.format(calc_share,
                                                          sync_share)
        config_part = '[batch = {} = {}*{}] [lr = {:.2E} = {:.2E}*{}]'.format(
            global_batch, self.batch_size, num_replicas,
            self.get_effective_lr(num_replicas), self.lr, num_replicas)
        summary = timing_part + share_part + config_part
        if verbose:
            g.write_unique(summary)
        return summary
示例#6
0
def train(conf,shot_list_train,shot_list_validate,loader):
    '''
    Train a single-process Keras model on shot_list_train.

    Builds and compiles the model via plasma.models.builder and resumes from
    the latest saved epoch. For each remaining epoch it streams batches from a
    ProcessGenerator, applies lr * lr_decay**e, trains, and saves the weights.
    When conf['training']['validation_frac'] > 0 it also evaluates on
    shot_list_validate and deletes the just-saved weights if the monitored
    quantity (conf['callbacks']['monitor']) did not improve.

    Argument list:
      - conf: configuration dict
      - shot_list_train: training shots
      - shot_list_validate: validation shots
      - loader: data loader providing training_batch_generator_partial_reset

    The background ProcessGenerator is shut down (``__exit__``) both on
    normal completion and when an exception escapes the training loop.
    '''
    loader.set_inference_mode(False)
    np.random.seed(1)

    validation_losses = []
    validation_roc = []
    training_losses = []
    print('validate: {} shots, {} disruptive'.format(len(shot_list_validate),shot_list_validate.num_disruptive()))
    print('training: {} shots, {} disruptive'.format(len(shot_list_train),shot_list_train.num_disruptive()))

    if backend == 'tf' or backend == 'tensorflow':
        first_time = "tensorflow" not in sys.modules
        if first_time:
            import tensorflow as tf
            os.environ['KERAS_BACKEND'] = 'tensorflow'
            from keras.backend.tensorflow_backend import set_session
            config = tf.ConfigProto(device_count={"GPU":1})
            set_session(tf.Session(config=config))
    else:
        os.environ['KERAS_BACKEND'] = 'theano'
        os.environ['THEANO_FLAGS'] = 'device=gpu,floatX=float32'
        import theano

    from keras.utils.generic_utils import Progbar
    from keras import backend as K
    from plasma.models import builder

    print('Build model...',end='')
    specific_builder = builder.ModelBuilder(conf)
    train_model = specific_builder.build_model(False)
    print('Compile model',end='')
    train_model.compile(optimizer=optimizer_class(),loss=conf['data']['target'].loss)
    print('...done')

    # load the latest epoch we did. Returns -1 if none exist yet
    e = specific_builder.load_model_weights(train_model)
    e_start = e
    batch_generator = partial(loader.training_batch_generator_partial_reset,shot_list=shot_list_train)
    batch_iterator = ProcessGenerator(batch_generator())
    try:
        num_epochs = conf['training']['num_epochs']
        num_at_once = conf['training']['num_shots_at_once']
        lr_decay = conf['model']['lr_decay']
        print('{} epochs left to go'.format(num_epochs - 1 - e))
        num_so_far_accum = 0
        num_so_far = 0
        # Placeholder until the first batch reports the real dataset size.
        num_total = np.inf

        if conf['callbacks']['mode'] == 'max':
            best_so_far = -np.inf
            cmp_fn = max
        else:
            best_so_far = np.inf
            cmp_fn = min

        while e < num_epochs-1:
            e += 1
            print('\nEpoch {}/{}'.format(e+1,num_epochs))
            pbar = Progbar(len(shot_list_train))

            # decay learning rate each epoch:
            K.set_value(train_model.optimizer.lr, lr*lr_decay**(e))

            num_batches_minimum = 100
            num_batches_current = 0
            training_losses_tmp = []

            while num_so_far < (e - e_start)*num_total or num_batches_current < num_batches_minimum:
                num_so_far_old = num_so_far
                try:
                    batch_xs,batch_ys,batches_to_reset,num_so_far_curr,num_total,is_warmup_period = next(batch_iterator)
                except StopIteration:
                    print("Resetting batch iterator.")
                    num_so_far_accum = num_so_far
                    batch_iterator = ProcessGenerator(batch_generator())
                    batch_xs,batch_ys,batches_to_reset,num_so_far_curr,num_total,is_warmup_period = next(batch_iterator)
                if np.any(batches_to_reset):
                    reset_states(train_model,batches_to_reset)
                if not is_warmup_period:
                    num_so_far = num_so_far_accum+num_so_far_curr

                    num_batches_current +=1

                    loss = train_model.train_on_batch(batch_xs,batch_ys)
                    training_losses_tmp.append(loss)
                    pbar.add(num_so_far - num_so_far_old, values=[("train loss", loss)])
                    loader.verbose=False  # True during the first iteration
                else:
                    # Warmup batches are only run through predict(); no
                    # train_on_batch call is made for them.
                    _ = train_model.predict(batch_xs,batch_size=conf['training']['batch_size'])

            e = e_start+1.0*num_so_far/num_total
            sys.stdout.flush()
            ave_loss = np.mean(training_losses_tmp)
            training_losses.append(ave_loss)
            specific_builder.save_model_weights(train_model,int(round(e)))

            if conf['training']['validation_frac'] > 0.0:
                print("prediction on GPU...")
                _,_,_,roc_area,loss = make_predictions_and_evaluate_gpu(conf,shot_list_validate,loader)
                validation_losses.append(loss)
                validation_roc.append(roc_area)

                epoch_logs = {}
                epoch_logs['val_roc'] = roc_area
                epoch_logs['val_loss'] = loss
                epoch_logs['train_loss'] = ave_loss
                best_so_far = cmp_fn(epoch_logs[conf['callbacks']['monitor']],best_so_far)
                # only keep model weights if the quantity we are tracking is improving
                if best_so_far != epoch_logs[conf['callbacks']['monitor']]:
                    print("Not saving model weights")
                    specific_builder.delete_model_weights(train_model,int(round(e)))

                if conf['training']['ranking_difficulty_fac'] != 1.0:
                    _,_,_,roc_area_train,loss_train = make_predictions_and_evaluate_gpu(conf,shot_list_train,loader)
                    batch_iterator.__exit__()
                    batch_generator = partial(loader.training_batch_generator_partial_reset,shot_list=shot_list_train)
                    batch_iterator = ProcessGenerator(batch_generator())
                    num_so_far_accum = num_so_far

            print('=========Summary========')
            print('Training Loss Numpy: {:.3e}'.format(training_losses[-1]))
            if conf['training']['validation_frac'] > 0.0:
                print('Validation Loss: {:.3e}'.format(validation_losses[-1]))
                print('Validation ROC: {:.4f}'.format(validation_roc[-1]))
                if conf['training']['ranking_difficulty_fac'] != 1.0:
                    print('Train Loss: {:.3e}'.format(loss_train))
                    print('Train ROC: {:.4f}'.format(roc_area_train))

        # plot_losses(conf,[training_losses],specific_builder,name='training')
        if conf['training']['validation_frac'] > 0.0:
            plot_losses(conf,[training_losses,validation_losses,validation_roc],specific_builder,name='training_validation_roc')
    finally:
        # Always stop the background batch producer, even when training
        # raises; this targets whichever ProcessGenerator is current.
        batch_iterator.__exit__()
    print('...done')
示例#7
0
 def set_batch_iterator_func(self):
   """(Re)create the batch source by wrapping a fresh self.batch_iterator() in a ProcessGenerator."""
   fresh_generator = self.batch_iterator()
   self.batch_iterator_func = ProcessGenerator(fresh_generator)
示例#8
0
class MPIModel():
  def __init__(self,model,optimizer,comm,batch_iterator,batch_size,num_replicas=None,warmup_steps=1000,lr=0.01,num_batches_minimum=100):
    '''
    Set up data-parallel training state for one MPI rank.

    Argument list:
      - model: Keras model held by this rank
      - optimizer: object exposing set_lr/get_deltas, used to turn averaged
        deltas into weight updates
      - comm: MPI communicator; its size and rank give num_workers and
        task_index
      - batch_iterator: callable whose result is wrapped in a
        ProcessGenerator by set_batch_iterator_func
      - batch_size: per-replica batch size
      - num_replicas: number of replicas averaged over; None, < 1 or
        > num_workers fall back to num_workers
      - warmup_steps: number of initial steps (first epoch only) that
        train_epoch runs with a single replica
      - lr: base learning rate; stored as lr/(1+num_replicas/100), with lr
        first capped at max_lr (0.1)
      - num_batches_minimum: minimum number of steps per train_epoch call
    '''
    self.comm = comm
    self.num_workers = comm.Get_size()
    self.task_index = comm.Get_rank()
    # Seed both RNGs from this rank's index in ``comm`` (previously a
    # module-level ``task_index``). Seeding still happens before the batch
    # iterator is created below.
    random.seed(self.task_index)
    np.random.seed(self.task_index)
    self.start_time = time.time()
    self.epoch = 0
    self.num_so_far = 0
    self.num_so_far_accum = 0
    self.num_so_far_indiv = 0
    self.model = model
    self.optimizer = optimizer
    self.max_lr = 0.1
    # Placeholder learning rate given to the Keras optimizer; it is divided
    # back out in train_on_batch_and_get_deltas.
    self.DUMMY_LR = 0.001
    self.batch_size = batch_size
    self.batch_iterator = batch_iterator
    self.set_batch_iterator_func()
    self.warmup_steps=warmup_steps
    self.num_batches_minimum=num_batches_minimum
    self.history = cbks.History()
    self.model.stop_training = False
    if num_replicas is None or num_replicas < 1 or num_replicas > self.num_workers:
        self.num_replicas = self.num_workers
    else:
        self.num_replicas = num_replicas
    self.lr = lr/(1.0+self.num_replicas/100.0) if (lr < self.max_lr) else self.max_lr/(1.0+self.num_replicas/100.0)


  def set_batch_iterator_func(self):
    '''(Re)create the batch source by wrapping a fresh self.batch_iterator() in a ProcessGenerator.'''
    fresh_generator = self.batch_iterator()
    self.batch_iterator_func = ProcessGenerator(fresh_generator)

  def close(self):
    self.batch_iterator_func.__exit__()

  def set_lr(self,lr):
    self.lr = lr

  def save_weights(self,path,overwrite=False):
    self.model.save_weights(path,overwrite=overwrite)

  def load_weights(self,path):
    self.model.load_weights(path)

  def compile(self,optimizer,clipnorm,loss='mse'):
    if optimizer == 'sgd':
        optimizer_class = SGD(lr=self.DUMMY_LR,clipnorm=clipnorm)
    elif optimizer == 'momentum_sgd':
        optimizer_class = SGD(lr=self.DUMMY_LR, clipnorm=clipnorm, decay=1e-6, momentum=0.9)
    elif optimizer == 'tf_momentum_sgd':
        optimizer_class = TFOptimizer(tf.train.MomentumOptimizer(learning_rate=self.DUMMY_LR,momentum=0.9))
    elif optimizer == 'adam':
        optimizer_class = Adam(lr=self.DUMMY_LR,clipnorm=clipnorm)
    elif optimizer == 'tf_adam':
        optimizer_class = TFOptimizer(tf.train.AdamOptimizer(learning_rate=self.DUMMY_LR))
    elif optimizer == 'rmsprop':
        optimizer_class = RMSprop(lr=self.DUMMY_LR,clipnorm=clipnorm)
    elif optimizer == 'nadam':
        optimizer_class = Nadam(lr=self.DUMMY_LR,clipnorm=clipnorm)
    else:
        print("Optimizer not implemented yet")
        exit(1)
    self.model.compile(optimizer=optimizer_class,loss=loss)
    self.ensure_equal_weights()
    
  def ensure_equal_weights(self):
    if task_index == 0:
        new_weights = self.model.get_weights()
    else:
        new_weights = None
    nw = comm.bcast(new_weights,root=0)
    self.model.set_weights(nw)



  def train_on_batch_and_get_deltas(self,X_batch,Y_batch,verbose=False):
    '''
    Compute one replica's weight deltas for a mini-batch without keeping them.

    The model takes one optimizer step on the batch, the resulting weight
    change is measured, and the model is then restored to its pre-step
    weights. Both snapshots are divided by self.DUMMY_LR before subtracting,
    and the deltas are divided by conf['model']['loss_scale_factor'] when
    that factor is not 1.0.

    It performs calls to: subtract_params, multiply_params

    Argument list:
      - X_batch: input data for one mini-batch as a Numpy array
      - Y_batch: labels for one mini-batch as a Numpy array
      - verbose: set verbosity level (currently unused)

    Returns:
      - deltas: a list of model weight updates
      - loss: scalar training loss
    '''
    # Snapshot, step, snapshot again, then roll the model back.
    before = self.model.get_weights()
    loss = self.model.train_on_batch(X_batch,Y_batch)
    after = self.model.get_weights()
    self.model.set_weights(before)

    # Remove the placeholder learning rate from both snapshots.
    inv_dummy_lr = 1.0/self.DUMMY_LR
    before = multiply_params(before,inv_dummy_lr)
    after = multiply_params(after,inv_dummy_lr)
    deltas = subtract_params(after,before)

    # Undo loss scaling, if any.
    scale_factor = conf['model']['loss_scale_factor']
    if scale_factor != 1.0:
        deltas = multiply_params(deltas,1.0/scale_factor)
    return deltas,loss


  def get_new_weights(self,deltas):
    '''Return the model's current weights plus ``deltas`` (via add_params); set_weights is not called here.'''
    current_weights = self.model.get_weights()
    return add_params(current_weights,deltas)

  def mpi_average_gradients(self,arr,num_replicas=None):
    '''
    Average an array over the first num_replicas MPI ranks.

    Every rank in self.comm takes part in the Allreduce; ranks whose
    task_index is >= num_replicas contribute zeros, and the sum is divided
    by num_replicas.

    Argument list:
      - arr: numpy array of this rank's local values
      - num_replicas: participating replicas; defaults to self.num_workers

    Returns:
      - arr_global: new array holding the averaged values
    '''
    if num_replicas is None:
      num_replicas = self.num_workers
    if self.task_index >= num_replicas:
      # Fresh zero buffer instead of ``arr *= 0.0``: avoids mutating the
      # caller's array and stops NaN/inf (NaN * 0 == NaN) from poisoning
      # the global sum on every rank.
      arr = np.zeros_like(arr)
    arr_global = np.empty_like(arr)
    if K.floatx() == 'float16':
        self.comm.Allreduce(arr,arr_global,op=mpi_sum_f16)
    else:
        self.comm.Allreduce(arr,arr_global,op=MPI.SUM)
    arr_global /= num_replicas
    return arr_global



  def mpi_average_scalars(self,val,num_replicas=None):
    '''
    The purpose of the method is to calculate a simple scalar arithmetic mean over num_replicas.

    It performs calls to: MPIModel.mpi_sum_scalars

    Argument list: 
      - val: value averaged, scalar
      - num_replicas: the size of the ensemble an average is perfromed over

    Returns:  
      - val_global: scalar arithmetic mean over num_replicas
    '''
    val_global = self.mpi_sum_scalars(val,num_replicas)
    val_global /= num_replicas
    return val_global


  def mpi_sum_scalars(self,val,num_replicas=None):
    '''
    Calculate the sum of a scalar over num_replicas using an MPI allreduce
    with op=MPI.SUM.

    Every rank in self.comm takes part in the collective; ranks whose
    task_index is >= num_replicas contribute 0.0.

    Argument list:
      - val: value summed, scalar
      - num_replicas: the size of the ensemble the sum is performed over;
        defaults to self.num_workers

    Returns:
      - val_global: sum of val over num_replicas
    '''
    if num_replicas is None:
      num_replicas = self.num_workers
    if self.task_index >= num_replicas:
      # Assign rather than ``val *= 0.0`` so NaN/inf on an excluded rank
      # cannot leak into the global sum.
      val = 0.0
    return self.comm.allreduce(val,op=MPI.SUM)



  def sync_deltas(self,deltas,num_replicas=None):
    global_deltas = []
    #default is to reduce the deltas from all workers
    for delta in deltas:
      global_deltas.append(self.mpi_average_gradients(delta,num_replicas))
    return global_deltas 

  def set_new_weights(self,deltas,num_replicas=None):
    global_deltas = self.sync_deltas(deltas,num_replicas)
    effective_lr = self.get_effective_lr(num_replicas)

    self.optimizer.set_lr(effective_lr)
    global_deltas = self.optimizer.get_deltas(global_deltas)

    new_weights = self.get_new_weights(global_deltas)
    self.model.set_weights(new_weights)

  def build_callbacks(self,conf,callbacks_list):
    '''
    Set up logging and history. It is based on Keras Callbacks
    https://github.com/fchollet/keras/blob/fbc9a18f0abc5784607cd4a2a3886558efa3f794/keras/callbacks.py

    Currently used callbacks include: BaseLogger, History (self.history),
    CSVLogger and, when requested, EarlyStopping.
    Other possible callbacks to add in future: RemoteMonitor, LearningRateScheduler

    Argument list:
      - conf: There is a "callbacks" section in conf.yaml file. Relevant parameters are:
           list: additional callbacks; "earlystop" adds EarlyStopping, "lr_scheduler" is currently a no-op
           metrics: List of quantities monitored during training and validation (not read here)
           mode: one of {auto, min, max}. The decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.
           monitor: Quantity used for early stopping, has to be from the list of metrics
           patience: Number of epochs used to decide on whether to apply early stopping or continue training
        plus conf['paths']['csvlog_save_path'], the directory the CSV log is written into (created if missing)
      - callbacks_list: currently ignored; the list is re-read from conf['callbacks']['list']
    Returns: a cbks.CallbackList wrapping the configured callbacks
    '''
    mode = conf['callbacks']['mode']
    monitor = conf['callbacks']['monitor']
    patience = conf['callbacks']['patience']
    csvlog_save_path = conf['paths']['csvlog_save_path']
    # CSV callback is on by default. exist_ok avoids the exists/makedirs race
    # when several processes (e.g. MPI ranks) run this concurrently.
    os.makedirs(csvlog_save_path, exist_ok=True)

    # NOTE: the callbacks_list argument is overridden by the config value.
    callbacks_list = conf['callbacks']['list']

    callbacks = [cbks.BaseLogger()]
    callbacks += [self.history]
    # os.path.join keeps the log inside csvlog_save_path even without a
    # trailing separator (identical result when one is present).
    log_name = "callbacks-{}.log".format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
    callbacks += [cbks.CSVLogger(os.path.join(csvlog_save_path,log_name))]

    if "earlystop" in callbacks_list:
        callbacks += [cbks.EarlyStopping(patience=patience, monitor=monitor, mode=mode)]
    if "lr_scheduler" in callbacks_list:
        pass

    return cbks.CallbackList(callbacks)

  def train_epoch(self):
    '''
    Perform distributed mini-batch SGD for one epoch.

    It takes the batch iterator function and a NN model from the MPIModel
    object and fetches mini-batches in a while-loop until the number of
    samples seen by the ensemble of workers (num_so_far) reaches the training
    dataset size (num_total) and at least num_batches_minimum steps were taken.

    During each non-warmup iteration, the gradient updates (deltas) and the
    loss are calculated for each model replica in the ensemble, deltas are
    averaged over the ensemble, and the new weights are set. During warmup
    batches (is_warmup_period) the deltas are computed but not applied.

    It performs calls to: MPIModel.train_on_batch_and_get_deltas,
    MPIModel.set_new_weights methods

    Argument list: Empty

    Returns:
      - step: number of weight updates performed in this call
      - ave_loss: value of loss_averager.get_val() after the last update
        (NOTE(review): presumably the mean of the per-step losses — confirm
        against Averager)
      - curr_loss: training loss averaged over replicas at the final update
      - num_so_far: cumulative number of samples, summed over replicas
      - effective_epochs: num_so_far / num_total

    Intermediate outputs and logging: debug printout of task_index (MPI),
    epoch number, number of samples seen to a current epoch, average
    training loss
    '''

    verbose = False
    first_run = True
    step = 0
    loss_averager = Averager()
    t_start = time.time()

    batch_iterator_func = self.batch_iterator_func
    # Placeholder; the real dataset size arrives with the first batch.
    num_total = 1
    ave_loss = -1
    curr_loss = -1
    t0 = 0 
    t1 = 0 
    t2 = 0

    while (self.num_so_far-self.epoch*num_total) < num_total or step < self.num_batches_minimum:

      try:
          batch_xs,batch_ys,batches_to_reset,num_so_far_curr,num_total,is_warmup_period = next(batch_iterator_func)
      except StopIteration:
          # Generator exhausted: remember progress so far and start a new one.
          print("Resetting batch iterator.")
          self.num_so_far_accum = self.num_so_far_indiv
          self.set_batch_iterator_func()
          batch_iterator_func = self.batch_iterator_func
          batch_xs,batch_ys,batches_to_reset,num_so_far_curr,num_total,is_warmup_period = next(batch_iterator_func)
      self.num_so_far_indiv = self.num_so_far_accum+num_so_far_curr

      # if batches_to_reset:
        # self.model.reset_states(batches_to_reset)

      # During the first warmup_steps of epoch 0, only one replica is counted.
      warmup_phase = (step < self.warmup_steps and self.epoch == 0)
      num_replicas = 1 if warmup_phase else self.num_replicas

      # Collective call (comm.allreduce inside): every rank must execute it
      # each iteration, in the same order.
      self.num_so_far = self.mpi_sum_scalars(self.num_so_far_indiv,num_replicas)

      #run the model once to force compilation. Don't actually use these values.
      if first_run:
        first_run = False
        t0_comp = time.time()
        _,_ = self.train_on_batch_and_get_deltas(batch_xs,batch_ys,verbose)
        self.comm.Barrier()
        sys.stdout.flush()
        print_unique('Compilation finished in {:.2f}s'.format(time.time()-t0_comp))
        # Exclude compilation time from the ETA estimate.
        t_start = time.time()
        sys.stdout.flush()  
      
      if np.any(batches_to_reset):
        reset_states(self.model,batches_to_reset)

      t0 = time.time()
      deltas,loss = self.train_on_batch_and_get_deltas(batch_xs,batch_ys,verbose)
      t1 = time.time()
      if not is_warmup_period:
        # Average deltas across replicas and apply them (collective).
        self.set_new_weights(deltas,num_replicas)
        t2 = time.time()
        write_str_0 = self.calculate_speed(t0,t1,t2,num_replicas)

        curr_loss = self.mpi_average_scalars(1.0*loss,num_replicas)
        #if self.task_index == 0:
          #print(self.model.get_weights()[0][0][:4])
        loss_averager.add_val(curr_loss)
        ave_loss = loss_averager.get_val()
        eta = self.estimate_remaining_time(t0 - t_start,self.num_so_far-self.epoch*num_total,num_total)
        write_str = '\r[{}] step: {} [ETA: {:.2f}s] [{:.2f}/{}], loss: {:.5f} [{:.5f}] | walltime: {:.4f} | '.format(self.task_index,step,eta,1.0*self.num_so_far,num_total,ave_loss,curr_loss,time.time()-self.start_time)
        print_unique(write_str + write_str_0)
        step += 1
      else:
        print_unique('\r[{}] warmup phase, num so far: {}'.format(self.task_index,self.num_so_far))
        

      

    # self.epoch is tracked as a fractional count of passes over the data.
    effective_epochs = 1.0*self.num_so_far/num_total
    epoch_previous = self.epoch
    self.epoch = effective_epochs
    print_unique('\nEpoch {:.2f} finished ({:.2f} epochs passed) in {:.2f} seconds.\n'.format(1.0*self.epoch,self.epoch-epoch_previous,t2 - t_start))
    return (step,ave_loss,curr_loss,self.num_so_far,effective_epochs)


  def estimate_remaining_time(self,time_so_far,work_so_far,work_total):
    eps = 1e-6
    total_time = 1.0*time_so_far*work_total/(work_so_far + eps)
    return total_time - time_so_far

  def get_effective_lr(self,num_replicas):
    effective_lr = self.lr * num_replicas
    if effective_lr > self.max_lr:
      print_unique('Warning: effective learning rate set to {}, larger than maximum {}. Clipping.'.format(effective_lr,self.max_lr))
      effective_lr = self.max_lr
    return effective_lr

  def get_effective_batch_size(self,num_replicas):
    return self.batch_size*num_replicas

  def calculate_speed(self,t0,t_after_deltas,t_after_update,num_replicas,verbose=False):
    effective_batch_size = self.get_effective_batch_size(num_replicas)
    t_calculate = t_after_deltas - t0
    t_sync = t_after_update - t_after_deltas
    t_tot = t_after_update - t0

    examples_per_sec = effective_batch_size/t_tot
    frac_calculate = t_calculate/t_tot
    frac_sync = t_sync/t_tot

    print_str = '{:.2E} Examples/sec | {:.2E} sec/batch [{:.1%} calc., {:.1%} synch.]'.format(examples_per_sec,t_tot,frac_calculate,frac_sync)
    print_str += '[batch = {} = {}*{}] [lr = {:.2E} = {:.2E}*{}]'.format(effective_batch_size,self.batch_size,num_replicas,self.get_effective_lr(num_replicas),self.lr,num_replicas)
    if verbose:
      print_unique(print_str)
    return print_str