import os
import time
import pickle

import numpy as np
import tensorflow as tf

#assumption: hooks is the project-local module providing the session hooks
#(LoadAtBegin, SaveAtEnd, ValidationSaveHook, StopHook) used below
import hooks


def train(self):
    '''train the model'''

    #look for the master if distributed training is done
    master = self.server.target

    #start the session and standard services
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    #config.log_device_placement = True

    chief_only_hooks = []

    if self.init_filename is not None:
        init_hook = hooks.LoadAtBegin(self.init_filename, self.model)
        chief_only_hooks.append(init_hook)

    #create a hook for saving the final model
    save_hook = hooks.SaveAtEnd(
        os.path.join(self.expdir, 'model', 'network.ckpt'),
        self.model)
    chief_only_hooks.append(save_hook)

    #create a hook for saving and restoring the validated model
    validation_hook = hooks.ValidationSaveHook(
        os.path.join(self.expdir, 'logdir', 'validated.ckpt'),
        self.model)
    chief_only_hooks.append(validation_hook)

    #number of times validation performance was worse
    num_tries = 0

    with self.graph.as_default():
        with tf.train.MonitoredTrainingSession(
                master=master,
                is_chief=self.is_chief,
                checkpoint_dir=os.path.join(self.expdir, 'logdir'),
                scaffold=self.scaffold,
                hooks=[hooks.StopHook(self.done)],
                chief_only_hooks=chief_only_hooks,
                config=config) as sess:

            #set the number of steps
            self.set_num_steps.run(session=sess)

            #start the training loop
            #pylint: disable=E1101
            while not (sess.should_stop() or
                       self.should_stop.eval(session=sess)):

                #check if validation is due
                if (self.update_loss is not None
                        and self.should_validate.eval(session=sess)):

                    if self.is_chief:
                        print('WORKER %d: validating model'
                              % self.task_index)

                        #get the previous validation loss
                        prev_val_loss = self.best_validation.eval(
                            session=sess)

                        #reset the validation loss
                        self.validation_loss.initializer.run(session=sess)

                        #start time
                        start = time.time()

                        #compute the validation loss
                        for _ in range(self.valbatches):
                            self.update_loss.run(session=sess)

                        #get the current validation loss
                        validation_loss = self.validation_loss.eval(
                            session=sess)

                        print('WORKER %d: validation loss: %.6g, '
                              'time: %f sec'
                              % (self.task_index, validation_loss,
                                 time.time() - start))

                        #check if the validation loss is better
                        if validation_loss >= prev_val_loss:
                            print('WORKER %d: validation loss is worse!'
                                  % self.task_index)

                            #check how many times validation performance
                            #was worse
                            if self.conf['num_tries'] != 'None':
                                if num_tries == int(self.conf['num_tries']):
                                    validation_hook.restore()
                                    print('WORKER %d: terminating training'
                                          % self.task_index)
                                    self.terminate.run(session=sess)
                                    break

                            num_tries += 1

                            if self.conf['go_back'] == 'True':
                                #wait until all workers are at the
                                #validation point
                                while not self.all_waiting.eval(
                                        session=sess):
                                    time.sleep(1)
                                self.reset_waiting.run(session=sess)

                                print('WORKER %d: loading previous model'
                                      % self.task_index)

                                #load the previous model
                                validation_hook.restore()
                            else:
                                self.update_validated_step.run(
                                    session=sess)

                            if self.conf['valid_adapt'] == 'True':
                                print('WORKER %d: halving learning rate'
                                      % self.task_index)
                                self.half_lr.run(session=sess)
                                validation_hook.save()
                        else:
                            if self.conf['reset_tries'] == 'True':
                                num_tries = 0

                            #set the validated step
                            self.update_validated_step.run(session=sess)
                            self.update_best.run(session=sess)
                            self.reset_waiting.run(session=sess)

                            #store the validated model
                            validation_hook.save()
                    else:
                        if (self.conf['go_back'] == 'True'
                                and self.update_loss is not None):
                            self.waiting.run(session=sess)
                            while (self.should_validate.eval(session=sess)
                                   and not self.should_stop.eval(
                                       session=sess)):
                                time.sleep(1)

                    if self.should_stop.eval(session=sess):
                        break

                #start time
                start = time.time()

                #first, accumulate the gradients
                for _ in range(int(self.conf['numbatches_to_aggregate'])):
                    sess.run(fetches=[self.update_gradients, self.acc_loss])

                #finally, apply the accumulated gradients
                _, loss, lr, global_step, num_steps = sess.run(fetches=[
                    self.update_op,
                    self.total_loss,
                    self.learning_rate,
                    self.global_step,
                    self.num_steps])

                #reset the gradients for the next step
                sess.run(fetches=[self.reset_grad, self.reset_loss])

                print(('WORKER %d: step %d/%d loss: %.6g, '
                       'learning rate: %f, time: %f sec')
                      % (self.task_index, global_step, num_steps, loss,
                         lr, time.time() - start))
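
#The hooks module above is project-specific. As a hedged illustration only,
#a save-at-end hook like the one used in train() could be built on
#tf.train.SessionRunHook roughly as follows; the class name, constructor
#arguments, and behaviour are assumptions, not the project's actual
#implementation.
class SaveAtEndSketch(tf.train.SessionRunHook):
    '''sketch of a hook that saves a set of variables when the session ends'''

    def __init__(self, filename, variables):
        self.filename = filename
        self.variables = variables
        self._saver = None

    def begin(self):
        #the saver has to be created before the graph is finalized
        self._saver = tf.train.Saver(self.variables)

    def end(self, session):
        #called once when the monitored session is closed normally
        self._saver.save(session, self.filename)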
def train(self, testing=False):
    '''train the model

    args:
        testing: if True, only the graph will be created, for debugging
            purposes
    '''

    #look for the master if distributed training is done
    master = self.server.target

    #start the session and standard services
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    #number of times validation performance was worse
    num_tries = 0

    #check if this is the chief worker
    is_chief = self.task_index == 0

    #create the graph
    graph = tf.Graph()
    with graph.as_default():
        outputs = self._create_graph()
        scaffold = tf.train.Scaffold()

    if testing:
        return

    with graph.as_default():
        #create a hook for saving the final model
        save_hook = hooks.SaveAtEnd(
            os.path.join(self.expdir, 'model', 'network.ckpt'),
            self.model.variables)

        #create a hook for saving and restoring the validated model
        validation_hook = hooks.ValidationSaveHook(
            os.path.join(self.expdir, 'logdir', 'validated.ckpt'),
            self.model)

        with tf.train.MonitoredTrainingSession(
                master=master,
                is_chief=is_chief,
                checkpoint_dir=os.path.join(self.expdir, 'logdir'),
                scaffold=scaffold,
                hooks=[hooks.StopHook(outputs['done'])] +
                      self.hooks(outputs),
                chief_only_hooks=[save_hook, validation_hook] +
                                 self.chief_only_hooks(outputs),
                config=config) as sess:

            #create the summary writer
            summary_writer = tf.summary.FileWriter(
                os.path.join(self.expdir, 'logdir'))

            #start the training loop
            #pylint: disable=E1101
            while not (sess.should_stop() or
                       outputs['should_stop'].eval(session=sess)):

                #check if validation is due
                if (outputs['update_loss'] is not None
                        and outputs['should_validate'].eval(session=sess)):

                    if is_chief:
                        print('WORKER %d: validating model'
                              % self.task_index)

                        #get the previous validation loss
                        prev_val_loss = outputs['best_validation'].eval(
                            session=sess)

                        #initialize validation
                        outputs['init_validation'].run(session=sess)

                        #compute the validation loss
                        for i in range(outputs['valbatches']):
                            _, summary = sess.run(fetches=[
                                outputs['update_loss'],
                                outputs['eval_summaries']])
                            if summary is not None:
                                summary_writer.add_summary(summary, i)

                        summary, global_step = sess.run(fetches=[
                            outputs['val_loss_summary'],
                            outputs['global_step']])
                        summary_writer.add_summary(summary, global_step)

                        #get the current validation loss
                        validation_loss = outputs['validation_loss'].eval(
                            session=sess)

                        print('WORKER %d: validation loss: %f'
                              % (self.task_index, validation_loss))

                        #check if the validation loss is better
                        if validation_loss >= prev_val_loss:
                            print('WORKER %d: validation loss is worse'
                                  % self.task_index)

                            #check how many times validation performance
                            #was worse
                            if self.conf['num_tries'] != 'None':
                                if num_tries == int(self.conf['num_tries']):
                                    validation_hook.restore()
                                    print('WORKER %d: terminating training'
                                          % self.task_index)
                                    outputs['terminate'].run(session=sess)
                                    break

                            num_tries += 1

                            if self.conf['go_back'] == 'True':
                                #wait until all workers are at the
                                #validation point
                                while not outputs['all_waiting'].eval(
                                        session=sess):
                                    time.sleep(1)
                                outputs['reset_waiting'].run(session=sess)

                                print('WORKER %d: loading previous model'
                                      % self.task_index)

                                #load the previous model
                                validation_hook.restore()
                            else:
                                outputs['update_validated_step'].run(
                                    session=sess)

                            if self.conf['valid_adapt'] == 'True':
                                print('WORKER %d: halving learning rate'
                                      % self.task_index)
                                outputs['half_lr'].run(session=sess)
                                validation_hook.save()
                        else:
                            if self.conf['reset_tries'] == 'True':
                                num_tries = 0

                            #set the validated step
                            outputs['update_validated_step'].run(
                                session=sess)
                            outputs['update_best'].run(session=sess)
                            outputs['reset_waiting'].run(session=sess)

                            #store the validated model
                            validation_hook.save()
                    else:
                        if (self.conf['go_back'] == 'True'
                                and outputs['update_loss'] is not None):
                            outputs['waiting'].run(session=sess)
                            while (outputs['should_validate'].eval(
                                       session=sess)
                                   and not outputs['should_stop'].eval(
                                       session=sess)):
                                time.sleep(1)

                    if outputs['should_stop'].eval(session=sess):
                        break

                #start time
                start = time.time()

                #read in the next batch of data
                local_steps, _ = sess.run(
                    [outputs['local_steps'], outputs['read_data']])

                for _ in range(local_steps):

                    #update the model
                    _, loss, lr, global_step, memory, limit, summary = \
                        sess.run(
                            fetches=[outputs['update_op'],
                                     outputs['loss'],
                                     outputs['learning_rate'],
                                     outputs['global_step'],
                                     outputs['memory_usage'],
                                     outputs['memory_limit'],
                                     outputs['training_summaries']])

                    summary_writer.add_summary(summary, global_step)

                    if memory is not None:
                        memory_line = '\n\t peak memory usage: %d/%d MB' % (
                            memory / 1e6, limit / 1e6)
                    else:
                        memory_line = ''

                    print(('WORKER %d: step %d/%d loss: %f, '
                           'learning rate: %f \n\t time elapsed: %f sec%s')
                          % (self.task_index, global_step,
                             outputs['num_steps'], loss, lr,
                             time.time() - start, memory_line))

                outputs['increment_step'].run(session=sess)

    #store the model file
    modelfile = os.path.join(self.expdir, 'model', 'model.pkl')
    with open(modelfile, 'wb') as fid:
        pickle.dump(self.model, fid)
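
#Hedged companion example: the model pickled at the end of train() above
#could later be loaded back (e.g. for decoding or testing) as sketched
#below; this helper is hypothetical and relies on the module-level os and
#pickle imports, with expdir being the same experiment directory used
#during training.
def load_trained_model(expdir):
    '''load the model object that train() stored in expdir/model/model.pkl'''
    with open(os.path.join(expdir, 'model', 'model.pkl'), 'rb') as fid:
        return pickle.load(fid)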
def train(self):
    '''train the model'''

    #look for the master if distributed training is done
    master = self.server.target

    #start the session and standard services
    config = tf.ConfigProto(device_count={'CPU': 1})
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    #config.log_device_placement = True

    chief_only_hooks = []

    if self.init_filename is not None:
        init_hook = hooks.LoadAtBegin(self.init_filename, self.models)
        chief_only_hooks.append(init_hook)

    #create a hook for saving the final model
    save_hook = hooks.SaveAtEnd(
        os.path.join(self.expdir, 'model', 'network.ckpt'),
        self.models)
    chief_only_hooks.append(save_hook)

    #create a hook for saving and restoring the validated model
    validation_hook = hooks.ValidationSaveHook(
        os.path.join(self.expdir, 'logdir', 'validated.ckpt'),
        self.models)
    chief_only_hooks.append(validation_hook)

    #number of times validation performance was worse, per validation task
    num_tries = np.zeros(len(self.val_task_trainers))

    #determine all parameters
    all_params = []
    for ind, _ in enumerate(self.task_trainers):
        all_params += self.task_trainers[ind].params
    all_params = list(set(all_params))
    all_params = sorted(all_params, key=lambda par: par.name)

    with self.graph.as_default():
        with tf.train.MonitoredTrainingSession(
                master=master,
                is_chief=self.is_chief,
                checkpoint_dir=os.path.join(self.expdir, 'logdir'),
                scaffold=self.scaffold,
                hooks=[hooks.StopHook(self.done)],
                chief_only_hooks=chief_only_hooks,
                config=config) as sess:

            #set the number of steps
            self.set_num_steps.run(session=sess)

            #print the params that will be updated
            print('parameters that will be trained:')
            for ind, param in enumerate(all_params):
                print('param ind %i: %s' % (ind, param.name))

            #start the training loop
            #pylint: disable=E1101
            while not (sess.should_stop() or
                       self.should_stop.eval(session=sess)):

                ##Validation part

                #check if validation is due
                if (self.process_val_batch is not None
                        and self.should_validate.eval(session=sess)):

                    if self.is_chief:
                        print('WORKER %d: validating model'
                              % self.task_index)

                        #get the previous validation loss for each
                        #validation task
                        prev_val_loss_all_tasks = sess.run(
                            self.best_validation_all_tasks)

                        #reset the validation loss
                        self.reset_val_loss_norm.run(session=sess)

                        #start time
                        start = time.time()

                        #compute the validation loss
                        for _ in range(self.valbatches):
                            self.process_val_batch.run(session=sess)

                        #get the current validation loss
                        [validation_loss, val_loss_all_tasks] = sess.run(
                            [self.validation_loss,
                             self.val_loss_all_tasks])

                        print_str = ('WORKER %d: validation loss: %.6g, '
                                     'time: %f sec'
                                     % (self.task_index, validation_loss,
                                        time.time() - start))
                        #if multiple tasks, also print the individual task
                        #losses
                        if len(val_loss_all_tasks) > 1:
                            for ind, loss_task in enumerate(
                                    val_loss_all_tasks):
                                print_str += (
                                    ', task_loss %s: %.6g'
                                    % (self.task_trainers[ind].task_name,
                                       loss_task))
                        print(print_str)

                        #check if the validation loss is better, for every
                        #task
                        terminate_train = False
                        restore_validation = False
                        continue_validation = True
                        do_halve_lr = False
                        for task_ind, val_task in enumerate(
                                self.val_task_trainers):
                            if (val_loss_all_tasks[task_ind] >=
                                    prev_val_loss_all_tasks[task_ind]):
                                print('WORKER %d: validation loss is '
                                      'worse for %s!'
                                      % (self.task_index,
                                         val_task.task_name))

                                #check how many times validation
                                #performance was worse
                                num_tries[task_ind] += 1
                                if self.conf['num_tries'] != 'None':
                                    if num_tries[task_ind] == int(
                                            self.conf['num_tries']):
                                        terminate_train = True

                                if self.conf['go_back'] == 'True':
                                    continue_validation = False
                                    restore_validation = True
                                else:
                                    continue_validation = True

                                if self.conf['valid_adapt'] == 'True':
                                    do_halve_lr = True
                            else:
                                sess.run(
                                    self.update_best_all_tasks[task_ind])
                                if self.conf['reset_tries'] == 'True':
                                    num_tries[task_ind] = 0

                        #decide what to do for training based on the above
                        #task validations
                        if terminate_train:
                            validation_hook.restore()
                            print('WORKER %d: terminating training'
                                  % self.task_index)
                            self.terminate.run(session=sess)
                            break

                        if restore_validation:
                            #wait until all workers are at the validation
                            #point
                            while not self.all_waiting.eval(session=sess):
                                time.sleep(1)
                            self.reset_waiting.run(session=sess)

                            print('WORKER %d: loading previous model'
                                  % self.task_index)

                            #load the previous model
                            validation_hook.restore()

                        if continue_validation:
                            self.update_validated_step.run(session=sess)

                        if do_halve_lr:
                            print('WORKER %d: halving learning rate'
                                  % self.task_index)
                            self.half_lr.run(session=sess)
                            validation_hook.save()

                        # if np.sum(num_tries) == 0:
                        self.reset_waiting.run(session=sess)

                        #store the validated model
                        validation_hook.save()
                    else:
                        if (self.conf['go_back'] == 'True'
                                and self.process_val_batch is not None):
                            self.waiting.run(session=sess)
                            while (self.should_validate.eval(session=sess)
                                   and not self.should_stop.eval(
                                       session=sess)):
                                time.sleep(1)

                    if self.should_stop.eval(session=sess):
                        break

                ##Training part

                #start time
                start = time.time()

                #reset the gradients for the next step
                sess.run(fetches=[self.reset_grad_loss_norm])

                old_param_values = sess.run(all_params)

                #first, accumulate the gradients
                for _ in range(int(self.conf['numbatches_to_aggregate'])):
                    _ = sess.run([self.process_minibatch])
                    #debugging alternative that also fetches the batch loss:
                    #_, batch_loss, batch_loss_norm = sess.run(
                    #    fetches=[self.process_minibatch,
                    #             self.task_trainers[0].batch_loss,
                    #             self.task_trainers[0].batch_loss_norm])
                    #print('batchloss: %.6g, batch_loss_norm: %.6g, '
                    #      'batch_normalized_loss: %.6g'
                    #      % (batch_loss, batch_loss_norm,
                    #         batch_loss / (batch_loss_norm + 1e-20)))

                #then, normalize the gradients
                _ = sess.run([self.normalize_gradients])

                #finally, apply the gradients for each task optimizer. Get
                #the variable values before and after the update, so the
                #step sizes for each task can be displayed.
                old_task_param_values = []
                new_task_param_values = []
                task_params_diff = []
                loss_all_tasks = []

                for ind, task_trainer in enumerate(self.task_trainers):

                    #get the variable values before the update
                    if ind == 0:
                        old_task_param_values.append(old_param_values)
                    else:
                        old_task_param_values.append(
                            new_task_param_values[ind - 1])

                    #apply the gradients in the task optimizer and get the
                    #task loss. If it is the last task, also fetch the
                    #learning rate and step counters.
                    if ind + 1 < len(self.task_trainers):
                        [_, task_loss] = sess.run([
                            task_trainer.apply_gradients,
                            task_trainer.normalized_loss])
                    else:
                        _, _, task_loss, lr, global_step, num_steps = \
                            sess.run(
                                fetches=[task_trainer.apply_gradients,
                                         self.other_update_op,
                                         task_trainer.normalized_loss,
                                         self.learning_rate,
                                         self.global_step,
                                         self.num_steps])
                    loss_all_tasks.append(task_loss)

                    #get the variable values after the update
                    new_task_param_values.append(sess.run(all_params))

                    #calculate the step size for each variable as the
                    #difference between the old and new variable values,
                    #averaged per variable type (e.g. weights of layer 1).
                    #Multiply by 10000, purely for printing format purposes.
                    task_params_diff.append([
                        10000.0 * np.mean(
                            np.abs(new_task_param_values[ind][par_ind] -
                                   old_task_param_values[ind][par_ind]))
                        for par_ind in range(
                            len(new_task_param_values[ind]))])

                #calculate the loss and step size over all task
                #optimizations
                loss = np.mean(loss_all_tasks)
                new_param_values = new_task_param_values[-1]
                params_diff = [
                    10000.0 * np.mean(
                        np.abs(new_param_values[ind] -
                               old_param_values[ind]))
                    for ind in range(len(new_param_values))]

                #single-fetch alternative:
                #_, loss, loss_all_tasks, lr, global_step, num_steps, \
                #    new_param_values = sess.run(
                #        fetches=[self.update_op,
                #                 self.total_loss,
                #                 self.loss_all_tasks,
                #                 self.learning_rate,
                #                 self.global_step,
                #                 self.num_steps,
                #                 all_params])

                ##Output prompt

                #start the printed string with the most important
                #information
                print_str = (('WORKER %d: step %d/%d loss: %.6g, '
                              'learning rate: %f, time: %.2f sec')
                             % (self.task_index, global_step, num_steps,
                                loss, lr, time.time() - start))

                #if multiple tasks, also print the individual task losses
                if len(loss_all_tasks) > 1:
                    print_str += ' ('
                    for ind, loss_task in enumerate(loss_all_tasks):
                        print_str += ('%s: %.6g. '
                                      % (self.task_trainers[ind].task_name,
                                         loss_task))
                    print_str += ')'

                if ('print_var_updates' in self.conf and
                        self.conf['print_var_updates'] == 'True'):
                    #print the average variable step size
                    print_str += ('\n Av param upd (*10000): %.3f'
                                  % np.mean(np.array(params_diff)))

                    #if multiple tasks, also print the average variable
                    #step size per task
                    if len(task_params_diff) > 1:
                        print_str += ' ('
                        for ind, task_param_diff in enumerate(
                                task_params_diff):
                            print_str += ('%s: %.3f; '
                                          % (self.task_trainers[
                                              ind].task_name,
                                             np.mean(np.array(
                                                 task_param_diff))))
                        print_str += ')'

                    #for each variable type (e.g. weights of layer 1),
                    #print the average step size
                    print_str += ' ('
                    for par_ind, param in enumerate(all_params):
                        if par_ind > 0:
                            print_str += ';'
                        print_str += ('%i: %.3f '
                                      % (par_ind, params_diff[par_ind]))
                        #if multiple tasks, also print the individual task
                        #average step size per variable type
                        if len(task_params_diff) > 1:
                            print_str += '{'
                            for ind, task_param_diff in enumerate(
                                    task_params_diff):
                                if ind > 0:
                                    print_str += '+'
                                print_str += ('%.3f'
                                              % task_param_diff[par_ind])
                            print_str += '} '
                    print_str += ')'

                #print the complete string
                print(print_str)
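
#A self-contained sketch (not the project's actual ops) of the
#accumulate-normalize-apply gradient pattern used in the training loops
#above: gradients of several minibatches are summed into non-trainable
#buffers, averaged, and only then applied. The helper name is hypothetical
#and it assumes every trainable variable receives a gradient from loss.
def accumulated_gradient_ops(loss, optimizer, num_batches):
    '''returns (accumulate_op, apply_op, reset_op) TensorFlow operations'''

    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)

    #non-trainable buffers that hold the running gradient sums
    buffers = [tf.Variable(tf.zeros_like(var), trainable=False)
               for var in tvars]

    #add the current minibatch gradients to the buffers
    accumulate_op = tf.group(*[buf.assign_add(grad)
                               for buf, grad in zip(buffers, grads)])

    #apply the averaged gradients
    apply_op = optimizer.apply_gradients(
        [(buf / float(num_batches), var)
         for buf, var in zip(buffers, tvars)])

    #zero the buffers for the next aggregation cycle
    reset_op = tf.group(*[buf.assign(tf.zeros_like(buf))
                          for buf in buffers])

    return accumulate_op, apply_op, reset_op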