Exemplo n.º 1
0
    def train_epoch(self):
        '''
        Run distributed mini-batch SGD until one epoch's worth of samples
        has been consumed.

        Batches are drawn from ``self.batch_iterator_func`` in a loop that
        keeps going while fewer than ``num_total`` samples have been seen by
        the ensemble since ``self.epoch``, or while fewer than
        ``self.num_batches_minimum`` training steps have been taken. When
        the iterator is exhausted it is rebuilt through
        ``self.set_batch_iterator_func()`` and the per-worker sample count
        accumulated so far is carried over.

        For every batch the deltas and the loss are computed with
        ``self.train_on_batch_and_get_deltas``. Unless the iterator flags
        the batch as belonging to its warm-up period, the deltas are applied
        with ``self.set_new_weights`` and the loss is averaged over replicas
        with ``self.mpi_average_scalars``.

        Argument list: Empty

        Returns a tuple:
          - step: number of non-warm-up training steps taken in this call
          - ave_loss: value reported by the loss Averager after the last
            step (-1 if no step was taken)
          - curr_loss: replica-averaged loss of the last step (-1 if none)
          - num_so_far: cumulative number of samples seen by the ensemble
          - effective_epochs: num_so_far / num_total (also stored in
            self.epoch)

        Logging: a message once the first (compilation) batch has run, one
        progress line per training step or a warm-up notice, and a summary
        line when the loop ends.
        '''
        be_verbose = False
        needs_compile = True
        step = 0
        loss_tracker = Averager()
        t_start = time.time()

        batches = self.batch_iterator_func
        # Placeholder until the iterator reports the real dataset size.
        num_total = 1
        ave_loss = curr_loss = -1
        t0 = t1 = t2 = 0

        while True:
            epoch_incomplete = (
                self.num_so_far - self.epoch * num_total) < num_total
            if not (epoch_incomplete
                    or step < self.num_batches_minimum):
                break

            try:
                batch = next(batches)
            except StopIteration:
                print("Resetting batch iterator.")
                # Remember what this worker has seen before starting over.
                self.num_so_far_accum = self.num_so_far_indiv
                self.set_batch_iterator_func()
                batches = self.batch_iterator_func
                batch = next(batches)
            (batch_xs, batch_ys, batches_to_reset, samples_this_pass,
             num_total, is_warmup_period) = batch
            self.num_so_far_indiv = (self.num_so_far_accum
                                     + samples_this_pass)

            # During the first warmup_steps of epoch 0 act as one replica.
            in_warmup_steps = (step < self.warmup_steps
                               and self.epoch == 0)
            num_replicas = 1 if in_warmup_steps else self.num_replicas

            self.num_so_far = self.mpi_sum_scalars(self.num_so_far_indiv,
                                                   num_replicas)

            # Run the model once to force compilation; the results are
            # thrown away and the epoch timer restarts afterwards.
            if needs_compile:
                needs_compile = False
                compile_t0 = time.time()
                _, _ = self.train_on_batch_and_get_deltas(
                    batch_xs, batch_ys, be_verbose)
                self.comm.Barrier()
                sys.stdout.flush()
                compile_msg = 'Compilation finished in {:.2f}s'.format(
                    time.time() - compile_t0)
                print_unique(compile_msg)
                t_start = time.time()
                sys.stdout.flush()

            if np.any(batches_to_reset):
                reset_states(self.model, batches_to_reset)

            t0 = time.time()
            deltas, loss = self.train_on_batch_and_get_deltas(
                batch_xs, batch_ys, be_verbose)
            t1 = time.time()

            if is_warmup_period:
                # Deltas are not applied for warm-up batches.
                print_unique('\r[{}] warmup phase, num so far: {}'.format(
                    self.task_index, self.num_so_far))
                continue

            self.set_new_weights(deltas, num_replicas)
            t2 = time.time()
            speed_str = self.calculate_speed(t0, t1, t2, num_replicas)

            curr_loss = self.mpi_average_scalars(1.0 * loss, num_replicas)
            loss_tracker.add_val(curr_loss)
            ave_loss = loss_tracker.get_val()

            seen_this_epoch = self.num_so_far - self.epoch * num_total
            eta = self.estimate_remaining_time(t0 - t_start,
                                               seen_this_epoch, num_total)
            progress_str = '\r[{}] step: {} [ETA: {:.2f}s] [{:.2f}/{}], loss: {:.5f} [{:.5f}] | walltime: {:.4f} | '.format(
                self.task_index, step, eta, 1.0 * self.num_so_far,
                num_total, ave_loss, curr_loss,
                time.time() - self.start_time)
            print_unique(progress_str + speed_str)
            step += 1

        effective_epochs = 1.0 * self.num_so_far / num_total
        epochs_before = self.epoch
        self.epoch = effective_epochs
        summary = '\nEpoch {:.2f} finished ({:.2f} epochs passed) in {:.2f} seconds.\n'.format(
            1.0 * self.epoch, self.epoch - epochs_before, t2 - t_start)
        print_unique(summary)
        return (step, ave_loss, curr_loss, self.num_so_far, effective_epochs)
Exemplo n.º 2
0
def train(conf, shot_list_train, shot_list_validate, loader):
    '''
    Train the model described by ``conf`` on ``shot_list_train``.

    The function selects the Keras backend, builds and compiles the model,
    resumes from the most recently saved epoch, and then trains epoch by
    epoch. Batches come from a ``ProcessGenerator`` wrapping
    ``loader.training_batch_generator_partial_reset``. After each epoch the
    weights are saved; when ``conf['training']['validation_frac'] > 0`` the
    model is also evaluated on ``shot_list_validate`` and the weights just
    saved are deleted again unless the monitored quantity
    (``conf['callbacks']['monitor']``) is the best seen so far.

    The names ``backend``, ``optimizer_class`` and ``lr`` are not defined
    here; they are read from the enclosing module scope.

    Arguments:
      - conf: nested configuration mapping; the keys read are under
        'training', 'model', 'data' and 'callbacks'
      - shot_list_train: training shots (needs len() and num_disruptive())
      - shot_list_validate: validation shots (same interface)
      - loader: data loader providing set_inference_mode() and
        training_batch_generator_partial_reset()

    Returns: None
    '''
    loader.set_inference_mode(False)
    np.random.seed(1)

    validation_losses = []
    validation_roc = []
    training_losses = []
    print('validate: {} shots, {} disruptive'.format(
        len(shot_list_validate), shot_list_validate.num_disruptive()))
    print('training: {} shots, {} disruptive'.format(
        len(shot_list_train), shot_list_train.num_disruptive()))

    if backend == 'tf' or backend == 'tensorflow':
        # Only configure the session if tensorflow was not imported before.
        first_time = "tensorflow" not in sys.modules
        if first_time:
            import tensorflow as tf
            os.environ['KERAS_BACKEND'] = 'tensorflow'
            from keras.backend.tensorflow_backend import set_session
            config = tf.ConfigProto(device_count={"GPU": 1})
            set_session(tf.Session(config=config))
    else:
        os.environ['KERAS_BACKEND'] = 'theano'
        os.environ['THEANO_FLAGS'] = 'device=gpu,floatX=float32'
        import theano

    from keras.utils.generic_utils import Progbar
    from keras import backend as K
    from plasma.models import builder

    print('Build model...', end='')
    specific_builder = builder.ModelBuilder(conf)
    train_model = specific_builder.build_model(False)
    print('Compile model', end='')
    train_model.compile(optimizer=optimizer_class(),
                        loss=conf['data']['target'].loss)
    print('...done')

    # load the latest epoch we did. Returns -1 if none exist yet
    e = specific_builder.load_model_weights(train_model)
    e_start = e

    num_epochs = conf['training']['num_epochs']
    lr_decay = conf['model']['lr_decay']
    print('{} epochs left to go'.format(num_epochs - 1 - e))
    num_so_far_accum = 0
    num_so_far = 0
    num_total = np.inf

    if conf['callbacks']['mode'] == 'max':
        best_so_far = -np.inf
        cmp_fn = max
    else:
        best_so_far = np.inf
        cmp_fn = min

    batch_generator = partial(loader.training_batch_generator_partial_reset,
                              shot_list=shot_list_train)
    batch_iterator = ProcessGenerator(batch_generator())
    # Everything that uses the iterator runs inside try/finally so that the
    # iterator is released even when training or evaluation raises.
    try:
        while e < num_epochs - 1:
            e += 1
            print('\nEpoch {}/{}'.format(e + 1, num_epochs))
            pbar = Progbar(len(shot_list_train))

            # decay learning rate each epoch:
            K.set_value(train_model.optimizer.lr, lr * lr_decay**(e))

            num_batches_minimum = 100
            num_batches_current = 0
            training_losses_tmp = []

            while (num_so_far < (e - e_start) * num_total
                   or num_batches_current < num_batches_minimum):
                num_so_far_old = num_so_far
                try:
                    (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                     num_total, is_warmup_period) = next(batch_iterator)
                except StopIteration:
                    print("Resetting batch iterator.")
                    # Carry the running count over to the fresh iterator.
                    num_so_far_accum = num_so_far
                    batch_iterator = ProcessGenerator(batch_generator())
                    (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                     num_total, is_warmup_period) = next(batch_iterator)
                if np.any(batches_to_reset):
                    reset_states(train_model, batches_to_reset)
                if not is_warmup_period:
                    num_so_far = num_so_far_accum + num_so_far_curr

                    num_batches_current += 1

                    loss = train_model.train_on_batch(batch_xs, batch_ys)
                    training_losses_tmp.append(loss)
                    pbar.add(num_so_far - num_so_far_old,
                             values=[("train loss", loss)])
                    loader.verbose = False  # True during the first iteration
                else:
                    # Warm-up batches are only run forward, never trained on.
                    _ = train_model.predict(
                        batch_xs, batch_size=conf['training']['batch_size'])

            # e becomes the (possibly fractional) number of epochs completed.
            e = e_start + 1.0 * num_so_far / num_total
            sys.stdout.flush()
            ave_loss = np.mean(training_losses_tmp)
            training_losses.append(ave_loss)
            specific_builder.save_model_weights(train_model, int(round(e)))

            if conf['training']['validation_frac'] > 0.0:
                print("prediction on GPU...")
                _, _, _, roc_area, loss = make_predictions_and_evaluate_gpu(
                    conf, shot_list_validate, loader)
                validation_losses.append(loss)
                validation_roc.append(roc_area)

                epoch_logs = {}
                epoch_logs['val_roc'] = roc_area
                epoch_logs['val_loss'] = loss
                epoch_logs['train_loss'] = ave_loss
                monitored = epoch_logs[conf['callbacks']['monitor']]
                best_so_far = cmp_fn(monitored, best_so_far)
                # only keep model weights if the tracked quantity improved
                if best_so_far != monitored:
                    print("Not saving model weights")
                    specific_builder.delete_model_weights(
                        train_model, int(round(e)))

                if conf['training']['ranking_difficulty_fac'] != 1.0:
                    (_, _, _, roc_area_train,
                     loss_train) = make_predictions_and_evaluate_gpu(
                         conf, shot_list_train, loader)
                    # Replace the iterator after evaluating on the train set.
                    batch_iterator.__exit__()
                    batch_generator = partial(
                        loader.training_batch_generator_partial_reset,
                        shot_list=shot_list_train)
                    batch_iterator = ProcessGenerator(batch_generator())
                    num_so_far_accum = num_so_far

            print('=========Summary========')
            print('Training Loss Numpy: {:.3e}'.format(training_losses[-1]))
            if conf['training']['validation_frac'] > 0.0:
                print('Validation Loss: {:.3e}'.format(validation_losses[-1]))
                print('Validation ROC: {:.4f}'.format(validation_roc[-1]))
                if conf['training']['ranking_difficulty_fac'] != 1.0:
                    print('Train Loss: {:.3e}'.format(loss_train))
                    print('Train ROC: {:.4f}'.format(roc_area_train))

        if conf['training']['validation_frac'] > 0.0:
            plot_losses(conf,
                        [training_losses, validation_losses, validation_roc],
                        specific_builder,
                        name='training_validation_roc')
    finally:
        batch_iterator.__exit__()
    print('...done')
Exemplo n.º 3
0
def train(conf, shot_list_train, shot_list_validate, loader):
    '''
    Build, compile and train the model configured by ``conf``.

    Training resumes from the last saved epoch (``load_model_weights``
    returns -1 if there is none) and runs until
    ``conf['training']['num_epochs']`` epochs are done. Each epoch draws
    batches from a ``ProcessGenerator`` until the running sample count
    reaches the epoch target and at least 100 training batches were used.
    Weights are saved after every epoch. If
    ``conf['training']['validation_frac'] > 0`` the model is evaluated on
    ``shot_list_validate`` and the freshly saved weights are deleted when
    the monitored quantity did not reach a new best.

    ``backend``, ``optimizer_class`` and ``lr`` are read from the enclosing
    module scope; they are not defined in this function.

    Arguments:
      - conf: nested configuration mapping
      - shot_list_train: shots used for training
      - shot_list_validate: shots used for validation
      - loader: data loader that produces the training batches

    Returns: None
    '''
    loader.set_inference_mode(False)
    np.random.seed(1)

    validation_losses = []
    validation_roc = []
    training_losses = []
    print('validate: {} shots, {} disruptive'.format(
        len(shot_list_validate), shot_list_validate.num_disruptive()))
    print('training: {} shots, {} disruptive'.format(
        len(shot_list_train), shot_list_train.num_disruptive()))

    if backend == 'tf' or backend == 'tensorflow':
        # Session setup is skipped when tensorflow is already imported.
        first_time = "tensorflow" not in sys.modules
        if first_time:
            import tensorflow as tf
            os.environ['KERAS_BACKEND'] = 'tensorflow'
            from keras.backend.tensorflow_backend import set_session
            config = tf.ConfigProto(device_count={"GPU": 1})
            set_session(tf.Session(config=config))
    else:
        os.environ['KERAS_BACKEND'] = 'theano'
        os.environ['THEANO_FLAGS'] = 'device=gpu,floatX=float32'
        import theano

    from keras.utils.generic_utils import Progbar
    from keras import backend as K
    from plasma.models import builder

    print('Build model...', end='')
    specific_builder = builder.ModelBuilder(conf)
    train_model = specific_builder.build_model(False)
    print('Compile model', end='')
    train_model.compile(optimizer=optimizer_class(),
                        loss=conf['data']['target'].loss)
    print('...done')

    # load the latest epoch we did. Returns -1 if none exist yet
    e = specific_builder.load_model_weights(train_model)
    e_start = e

    num_epochs = conf['training']['num_epochs']
    lr_decay = conf['model']['lr_decay']
    print('{} epochs left to go'.format(num_epochs - 1 - e))
    num_so_far_accum = 0
    num_so_far = 0
    num_total = np.inf

    if conf['callbacks']['mode'] == 'max':
        best_so_far = -np.inf
        cmp_fn = max
    else:
        best_so_far = np.inf
        cmp_fn = min

    batch_generator = partial(loader.training_batch_generator_partial_reset,
                              shot_list=shot_list_train)
    batch_iterator = ProcessGenerator(batch_generator())
    # try/finally guarantees the batch iterator is released on every exit
    # path, including exceptions raised while training or evaluating.
    try:
        while e < num_epochs - 1:
            e += 1
            print('\nEpoch {}/{}'.format(e + 1, num_epochs))
            pbar = Progbar(len(shot_list_train))

            # decay learning rate each epoch:
            K.set_value(train_model.optimizer.lr, lr * lr_decay**(e))

            num_batches_minimum = 100
            num_batches_current = 0
            training_losses_tmp = []

            while (num_so_far < (e - e_start) * num_total
                   or num_batches_current < num_batches_minimum):
                num_so_far_old = num_so_far
                try:
                    (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                     num_total, is_warmup_period) = next(batch_iterator)
                except StopIteration:
                    print("Resetting batch iterator.")
                    # Keep counting from where the old iterator stopped.
                    num_so_far_accum = num_so_far
                    batch_iterator = ProcessGenerator(batch_generator())
                    (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                     num_total, is_warmup_period) = next(batch_iterator)
                if np.any(batches_to_reset):
                    reset_states(train_model, batches_to_reset)
                if not is_warmup_period:
                    num_so_far = num_so_far_accum + num_so_far_curr

                    num_batches_current += 1

                    loss = train_model.train_on_batch(batch_xs, batch_ys)
                    training_losses_tmp.append(loss)
                    pbar.add(num_so_far - num_so_far_old,
                             values=[("train loss", loss)])
                    loader.verbose = False  # True during the first iteration
                else:
                    # Warm-up batches only get a forward pass.
                    _ = train_model.predict(
                        batch_xs, batch_size=conf['training']['batch_size'])

            # From here on e is a float: epochs completed so far.
            e = e_start + 1.0 * num_so_far / num_total
            sys.stdout.flush()
            ave_loss = np.mean(training_losses_tmp)
            training_losses.append(ave_loss)
            specific_builder.save_model_weights(train_model, int(round(e)))

            if conf['training']['validation_frac'] > 0.0:
                print("prediction on GPU...")
                _, _, _, roc_area, loss = make_predictions_and_evaluate_gpu(
                    conf, shot_list_validate, loader)
                validation_losses.append(loss)
                validation_roc.append(roc_area)

                epoch_logs = {}
                epoch_logs['val_roc'] = roc_area
                epoch_logs['val_loss'] = loss
                epoch_logs['train_loss'] = ave_loss
                monitored = epoch_logs[conf['callbacks']['monitor']]
                best_so_far = cmp_fn(monitored, best_so_far)
                # only keep model weights if the tracked quantity improved
                if best_so_far != monitored:
                    print("Not saving model weights")
                    specific_builder.delete_model_weights(
                        train_model, int(round(e)))

                if conf['training']['ranking_difficulty_fac'] != 1.0:
                    (_, _, _, roc_area_train,
                     loss_train) = make_predictions_and_evaluate_gpu(
                         conf, shot_list_train, loader)
                    # Start a fresh iterator after evaluating on train data.
                    batch_iterator.__exit__()
                    batch_generator = partial(
                        loader.training_batch_generator_partial_reset,
                        shot_list=shot_list_train)
                    batch_iterator = ProcessGenerator(batch_generator())
                    num_so_far_accum = num_so_far

            print('=========Summary========')
            print('Training Loss Numpy: {:.3e}'.format(training_losses[-1]))
            if conf['training']['validation_frac'] > 0.0:
                print('Validation Loss: {:.3e}'.format(validation_losses[-1]))
                print('Validation ROC: {:.4f}'.format(validation_roc[-1]))
                if conf['training']['ranking_difficulty_fac'] != 1.0:
                    print('Train Loss: {:.3e}'.format(loss_train))
                    print('Train ROC: {:.4f}'.format(roc_area_train))

        if conf['training']['validation_frac'] > 0.0:
            plot_losses(conf,
                        [training_losses, validation_losses, validation_roc],
                        specific_builder,
                        name='training_validation_roc')
    finally:
        batch_iterator.__exit__()
    print('...done')
Exemplo n.º 4
0
    def train_epoch(self):
        '''
        Perform distributed mini-batch SGD for
        one epoch.  It takes the batch iterator function and a NN model from
        MPIModel object, fetches mini-batches in a while-loop until number of
        samples seen by the ensemble of workers (num_so_far) exceeds the
        training dataset size (num_total).

        NOTE: "sample" = "an entire shot" within this description

        During each iteration, the gradient updates (deltas) and the loss are
        calculated for each model replica in the ensemble, weights are averaged
        over ensemble, and the new weights are set.

        Optional settings are read from self.conf['training'] when self.conf
        is not None: 'timeline_prof' (dump one profiling trace file per
        step), 'step_limit' (stop early once step exceeds it, if > 0) and
        'noise' (pass each input batch through self.add_noise unless the
        value is False or the key is absent).

        Argument list: Empty

        Returns:
          - step: final iteration number
          - ave_loss: model loss averaged over iterations within this epoch
          - curr_loss: training loss averaged over replicas at final iteration
          - num_so_far: the cumulative number of samples seen by the ensemble
        of replicas up to the end of the final iteration (step) of this epoch
          - effective_epochs: num_so_far / num_total (also stored in
        self.epoch)

        Intermediate outputs and logging: debug printout of task_index (MPI),
        epoch number, number of samples seen to a current epoch, average
        training loss

        '''

        verbose = False
        first_run = True
        step = 0
        loss_averager = Averager()
        t_start = time.time()

        # Read every optional setting from self.conf once, behind a single
        # None guard, so the method never touches an undefined name and
        # never dereferences a missing configuration.
        timeline_prof = False
        step_limit = 0
        add_noise = False
        if self.conf is not None:
            training_conf = self.conf['training']
            if training_conf['timeline_prof']:
                timeline_prof = True
            if training_conf['step_limit'] > 0:
                step_limit = training_conf['step_limit']
            add_noise = ('noise' in training_conf
                         and training_conf['noise'] is not False)

        batch_iterator_func = self.batch_iterator_func
        # Placeholder until the iterator reports the real dataset size.
        num_total = 1
        ave_loss = -1
        curr_loss = -1
        t0 = 0
        t1 = 0
        t2 = 0

        while ((self.num_so_far - self.epoch * num_total) < num_total
               or step < self.num_batches_minimum):
            if step_limit > 0 and step > step_limit:
                print('reached step limit')
                break
            try:
                (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                 num_total, is_warmup_period) = next(batch_iterator_func)
            except StopIteration:
                g.print_unique("Resetting batch iterator.")
                # Carry this worker's running count over to the new iterator.
                self.num_so_far_accum = self.num_so_far_indiv
                self.set_batch_iterator_func()
                batch_iterator_func = self.batch_iterator_func
                (batch_xs, batch_ys, batches_to_reset, num_so_far_curr,
                 num_total, is_warmup_period) = next(batch_iterator_func)
            self.num_so_far_indiv = self.num_so_far_accum + num_so_far_curr

            # During the first warmup_steps of epoch 0 act as one replica.
            warmup_phase = (step < self.warmup_steps and self.epoch == 0)
            num_replicas = 1 if warmup_phase else self.num_replicas

            self.num_so_far = self.mpi_sum_scalars(self.num_so_far_indiv,
                                                   num_replicas)

            # run the model once to force compilation. Don't actually use these
            # values.
            if first_run:
                first_run = False
                t0_comp = time.time()
                _, _ = self.train_on_batch_and_get_deltas(
                    batch_xs, batch_ys, verbose)
                self.comm.Barrier()
                sys.stdout.flush()
                # TODO(KGF): check line feed/carriage returns around this
                g.print_unique(
                    '\nCompilation finished in {:.2f}s'.format(time.time() -
                                                               t0_comp))
                # Restart the epoch timer so compilation is not counted.
                t_start = time.time()
                sys.stdout.flush()

            if np.any(batches_to_reset):
                reset_states(self.model, batches_to_reset)
            if add_noise:
                batch_xs = self.add_noise(batch_xs)
            t0 = time.time()
            deltas, loss = self.train_on_batch_and_get_deltas(
                batch_xs, batch_ys, verbose)
            t1 = time.time()
            if not is_warmup_period:
                self.set_new_weights(deltas, num_replicas)
                t2 = time.time()
                write_str_0 = self.calculate_speed(t0, t1, t2, num_replicas)
                curr_loss = self.mpi_average_scalars(1.0 * loss, num_replicas)
                loss_averager.add_val(curr_loss)
                ave_loss = loss_averager.get_ave()
                eta = self.estimate_remaining_time(
                    t0 - t_start, self.num_so_far - self.epoch * num_total,
                    num_total)
                write_str = (
                    '\r[{}] step: {} [ETA: {:.2f}s] [{:.2f}/{}], '.format(
                        self.task_index, step, eta, 1.0 * self.num_so_far,
                        num_total) +
                    'loss: {:.5f} [{:.5f}] | '.format(ave_loss, curr_loss) +
                    'walltime: {:.4f} | '.format(time.time() -
                                                 self.start_time))
                g.write_unique(write_str + write_str_0)

                if timeline_prof:
                    # dump profile
                    tl = timeline.Timeline(self.run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    # dump file per iteration
                    with open('timeline_%s.json' % step, 'w') as f:
                        f.write(ctf)

                step += 1
            else:
                g.write_unique('\r[{}] warmup phase, num so far: {}'.format(
                    self.task_index, self.num_so_far))

        effective_epochs = 1.0 * self.num_so_far / num_total
        epoch_previous = self.epoch
        self.epoch = effective_epochs
        g.write_unique(
            # TODO(KGF): "a total of X epochs within this session" ?
            '\nFinished training epoch {:.2f} '.format(self.epoch)
            # TODO(KGF): "precisely/exactly X epochs just passed"?
            +
            'during this session ({:.2f} epochs passed)'.format(self.epoch -
                                                                epoch_previous)
            + ' in {:.2f} seconds\n'.format(t2 - t_start))
        return (step, ave_loss, curr_loss, self.num_so_far, effective_epochs)
Exemplo n.º 5
0
  def train_epoch(self):
    '''
    Perform distributed mini-batch SGD for one epoch.

    Mini-batches are pulled from self.batch_iterator_func while the number
    of samples seen by the ensemble since self.epoch is below num_total, or
    while fewer than self.num_batches_minimum training steps have run. An
    exhausted iterator is rebuilt with self.set_batch_iterator_func(), and
    the per-worker sample count is carried over through
    self.num_so_far_accum.

    Each iteration computes deltas and loss through
    self.train_on_batch_and_get_deltas. If the batch is not flagged as
    warm-up by the iterator, the deltas are applied with
    self.set_new_weights and the loss is averaged over replicas.

    Argument list: Empty

    Returns a tuple:
      - step: number of non-warm-up training steps taken in this call
      - ave_loss: value reported by the loss Averager after the last step
        (-1 if no step was taken)
      - curr_loss: replica-averaged loss of the last step (-1 if none)
      - num_so_far: cumulative number of samples seen by the ensemble
      - effective_epochs: num_so_far / num_total (also stored in self.epoch)

    Intermediate outputs and logging: compilation time after the first
    batch, a progress line per step (or a warm-up notice), and a summary
    line at the end.
    '''
    verbose_flag = False
    compiled = False
    step = 0
    averager = Averager()
    t_start = time.time()

    batch_iter = self.batch_iterator_func
    # Placeholder until the iterator reports the real dataset size.
    num_total = 1
    ave_loss, curr_loss = -1, -1
    t0, t1, t2 = 0, 0, 0

    while (self.num_so_far - self.epoch*num_total < num_total
           or step < self.num_batches_minimum):

      try:
        fetched = next(batch_iter)
      except StopIteration:
        print("Resetting batch iterator.")
        # Remember what this worker has seen before starting over.
        self.num_so_far_accum = self.num_so_far_indiv
        self.set_batch_iterator_func()
        batch_iter = self.batch_iterator_func
        fetched = next(batch_iter)
      (batch_xs, batch_ys, batches_to_reset,
       num_so_far_curr, num_total, is_warmup_period) = fetched
      self.num_so_far_indiv = self.num_so_far_accum + num_so_far_curr

      # During the first warmup_steps of epoch 0 act as a single replica.
      if step < self.warmup_steps and self.epoch == 0:
        num_replicas = 1
      else:
        num_replicas = self.num_replicas

      self.num_so_far = self.mpi_sum_scalars(self.num_so_far_indiv,
                                             num_replicas)

      # Run the model once to force compilation. The results are discarded
      # and the epoch timer restarts afterwards.
      if not compiled:
        compiled = True
        compile_began = time.time()
        _, _ = self.train_on_batch_and_get_deltas(batch_xs, batch_ys,
                                                  verbose_flag)
        self.comm.Barrier()
        sys.stdout.flush()
        compile_took = time.time() - compile_began
        print_unique('Compilation finished in {:.2f}s'.format(compile_took))
        t_start = time.time()
        sys.stdout.flush()

      if np.any(batches_to_reset):
        reset_states(self.model, batches_to_reset)

      t0 = time.time()
      deltas, loss = self.train_on_batch_and_get_deltas(batch_xs, batch_ys,
                                                        verbose_flag)
      t1 = time.time()

      if is_warmup_period:
        # Deltas are not applied for warm-up batches.
        print_unique('\r[{}] warmup phase, num so far: {}'.format(
            self.task_index, self.num_so_far))
      else:
        self.set_new_weights(deltas, num_replicas)
        t2 = time.time()
        speed_info = self.calculate_speed(t0, t1, t2, num_replicas)

        curr_loss = self.mpi_average_scalars(1.0*loss, num_replicas)
        averager.add_val(curr_loss)
        ave_loss = averager.get_val()

        done_in_epoch = self.num_so_far - self.epoch*num_total
        eta = self.estimate_remaining_time(t0 - t_start, done_in_epoch,
                                           num_total)
        status = '\r[{}] step: {} [ETA: {:.2f}s] [{:.2f}/{}], loss: {:.5f} [{:.5f}] | walltime: {:.4f} | '.format(
            self.task_index, step, eta, 1.0*self.num_so_far, num_total,
            ave_loss, curr_loss, time.time() - self.start_time)
        print_unique(status + speed_info)
        step += 1

    effective_epochs = 1.0*self.num_so_far/num_total
    previous_epoch = self.epoch
    self.epoch = effective_epochs
    closing_msg = '\nEpoch {:.2f} finished ({:.2f} epochs passed) in {:.2f} seconds.\n'.format(
        1.0*self.epoch, self.epoch - previous_epoch, t2 - t_start)
    print_unique(closing_msg)
    return (step, ave_loss, curr_loss, self.num_so_far, effective_epochs)