def train(self):
        '''train the model

        Runs the (possibly distributed) training loop inside a
        tf.train.MonitoredTrainingSession. Every step accumulates gradients
        over conf['numbatches_to_aggregate'] minibatches, applies them with
        self.update_op and resets the accumulators. When a validation op is
        configured (self.update_loss is not None), the chief periodically
        evaluates the model and, depending on the 'num_tries', 'go_back',
        'valid_adapt' and 'reset_tries' config entries, stores the validated
        model, restores the previous one, halves the learning rate or
        terminates training. With 'go_back' enabled, non-chief workers wait
        until the chief has finished validating.
        '''

        #look for the master if distributed training is done
        master = self.server.target

        #start the session and standard services
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        #config.log_device_placement = True

        chief_only_hooks = []

        #optionally load the initial model values from self.init_filename
        if self.init_filename is not None:
            init_hook = hooks.LoadAtBegin(self.init_filename, self.model)
            chief_only_hooks.append(init_hook)

        #create a hook for saving the final model
        save_hook = hooks.SaveAtEnd(
            os.path.join(self.expdir, 'model', 'network.ckpt'), self.model)
        chief_only_hooks.append(save_hook)

        #create a hook for saving and restoring the validated model
        validation_hook = hooks.ValidationSaveHook(
            os.path.join(self.expdir, 'logdir', 'validated.ckpt'), self.model)
        chief_only_hooks.append(validation_hook)

        #number of times validation performance was worse
        num_tries = 0

        with self.graph.as_default():
            with tf.train.MonitoredTrainingSession(
                    master=master,
                    is_chief=self.is_chief,
                    checkpoint_dir=os.path.join(self.expdir, 'logdir'),
                    scaffold=self.scaffold,
                    hooks=[hooks.StopHook(self.done)],
                    chief_only_hooks=chief_only_hooks,
                    config=config) as sess:

                #set the number of steps
                self.set_num_steps.run(session=sess)

                #start the training loop
                #pylint: disable=E1101
                while not (sess.should_stop()
                           or self.should_stop.eval(session=sess)):

                    #check if validation is due
                    if (self.update_loss is not None
                            and self.should_validate.eval(session=sess)):
                        if self.is_chief:
                            print('WORKER %d: validating model' %
                                  self.task_index)

                            #get the previous validation loss
                            prev_val_loss = self.best_validation.eval(
                                session=sess)

                            #reset the validation loss
                            self.validation_loss.initializer.run(session=sess)

                            #start time
                            start = time.time()

                            #compute the validation loss
                            for _ in range(self.valbatches):
                                self.update_loss.run(session=sess)

                            #get the current validation loss
                            validation_loss = self.validation_loss.eval(
                                session=sess)

                            #note the space after the comma: the two literals
                            #are concatenated
                            print('WORKER %d: validation loss:%.6g, '
                                  'time: %f sec' %
                                  (self.task_index, validation_loss,
                                   time.time() - start))

                            #check if the validation loss is better
                            if validation_loss >= prev_val_loss:

                                print('WORKER %d: validation loss is worse!' %
                                      self.task_index)

                                #stop once validation performance was worse
                                #more than conf['num_tries'] times (the
                                #counter is checked before it is incremented)
                                if self.conf['num_tries'] != 'None':
                                    if num_tries == int(
                                            self.conf['num_tries']):
                                        validation_hook.restore()
                                        print(
                                            'WORKER %d: terminating training' %
                                            self.task_index)
                                        self.terminate.run(session=sess)
                                        break

                                num_tries += 1

                                if self.conf['go_back'] == 'True':

                                    #wait until all workers are at validation
                                    #point
                                    while not self.all_waiting.eval(
                                            session=sess):
                                        time.sleep(1)
                                    self.reset_waiting.run(session=sess)

                                    print('WORKER %d: loading previous model' %
                                          self.task_index)

                                    #load the previous model
                                    validation_hook.restore()
                                else:
                                    self.update_validated_step.run(
                                        session=sess)

                                if self.conf['valid_adapt'] == 'True':
                                    print('WORKER %d: halving learning rate' %
                                          self.task_index)
                                    self.half_lr.run(session=sess)
                                    validation_hook.save()

                            else:
                                if self.conf['reset_tries'] == 'True':
                                    num_tries = 0

                                #set the validated step
                                self.update_validated_step.run(session=sess)
                                self.update_best.run(session=sess)
                                self.reset_waiting.run(session=sess)

                                #store the validated model
                                validation_hook.save()

                        else:
                            #non-chief workers: signal that this worker is
                            #waiting and block until validation is no longer
                            #due or training should stop
                            if (self.conf['go_back'] == 'True'
                                    and self.update_loss is not None):
                                self.waiting.run(session=sess)
                                while (self.should_validate.eval(session=sess)
                                       and not self.should_stop.eval(
                                           session=sess)):
                                    time.sleep(1)

                                if self.should_stop.eval(session=sess):
                                    break

                    #start time
                    start = time.time()

                    #First, accumulate the gradients
                    for _ in range(int(self.conf['numbatches_to_aggregate'])):
                        sess.run(
                            fetches=[self.update_gradients, self.acc_loss])

                    #Finally, apply the gradients
                    _, loss, lr, global_step, num_steps = sess.run(fetches=[
                        self.update_op, self.total_loss, self.learning_rate,
                        self.global_step, self.num_steps
                    ])

                    #reset the gradients for the next step
                    sess.run(fetches=[self.reset_grad, self.reset_loss])

                    print(('WORKER %d: step %d/%d loss: %.6g, learning rate: '
                           '%f, time: %f sec') %
                          (self.task_index, global_step, num_steps, loss, lr,
                           time.time() - start))
Beispiel #2
0
    def train(self, testing=False):
        '''train the model

        Builds the training graph with self._create_graph() and runs the
        (possibly distributed) training loop inside a
        tf.train.MonitoredTrainingSession. Every iteration reads a batch of
        data and performs outputs['local_steps'] update steps on it. When a
        validation op is configured (outputs['update_loss'] is not None), the
        chief periodically evaluates the model and, depending on the
        'num_tries', 'go_back', 'valid_adapt' and 'reset_tries' config
        entries, stores the validated model, restores the previous one,
        halves the learning rate or terminates training. Summaries are
        written to <expdir>/logdir. After the session ends, self.model is
        pickled to <expdir>/model/model.pkl.

        args:
            testing: if true only the graph will be created for debugging
                purposes
        '''

        #look for the master if distributed training is done
        master = self.server.target

        #start the session and standard services
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        #number of times validation performance was worse
        num_tries = 0

        #check if this is the chief worker
        is_chief = self.task_index == 0

        #create the graph
        graph = tf.Graph()
        with graph.as_default():
            outputs = self._create_graph()
            scaffold = tf.train.Scaffold()

        if testing:
            return

        with graph.as_default():

            #create a hook for saving the final model
            save_hook = hooks.SaveAtEnd(
                os.path.join(self.expdir, 'model', 'network.ckpt'),
                self.model.variables)

            #create a hook for saving and restoring the validated model
            validation_hook = hooks.ValidationSaveHook(
                os.path.join(self.expdir, 'logdir', 'validated.ckpt'),
                self.model)

            with tf.train.MonitoredTrainingSession(
                    master=master,
                    is_chief=is_chief,
                    checkpoint_dir=os.path.join(self.expdir, 'logdir'),
                    scaffold=scaffold,
                    hooks=([hooks.StopHook(outputs['done'])] +
                           self.hooks(outputs)),
                    chief_only_hooks=([save_hook, validation_hook] +
                                      self.chief_only_hooks(outputs)),
                    config=config) as sess:

                #create the summary writer
                summary_writer = tf.summary.FileWriter(
                    os.path.join(self.expdir, 'logdir'))

                #close the writer on every exit path (normal end, break or
                #exception) so buffered summaries are flushed to disk
                try:
                    #start the training loop
                    #pylint: disable=E1101
                    while not (sess.should_stop() or
                               outputs['should_stop'].eval(session=sess)):

                        #check if validation is due
                        if (outputs['update_loss'] is not None and
                                outputs['should_validate'].eval(
                                    session=sess)):
                            if is_chief:
                                print('WORKER %d: validating model' %
                                      self.task_index)

                                #get the previous validation loss
                                prev_val_loss = outputs['best_validation'].eval(
                                    session=sess)

                                #initialize validation
                                outputs['init_validation'].run(session=sess)

                                #compute the validation loss, writing the
                                #evaluation summaries (if any) per batch
                                for i in range(outputs['valbatches']):
                                    _, summary = sess.run(fetches=[
                                        outputs['update_loss'],
                                        outputs['eval_summaries']
                                    ])
                                    if summary is not None:
                                        summary_writer.add_summary(summary, i)
                                summary, global_step = sess.run(fetches=[
                                    outputs['val_loss_summary'],
                                    outputs['global_step']
                                ])
                                summary_writer.add_summary(summary,
                                                           global_step)

                                #get the current validation loss
                                validation_loss = outputs[
                                    'validation_loss'].eval(session=sess)

                                print('WORKER %d: validation loss: %f' %
                                      (self.task_index, validation_loss))

                                #check if the validation loss is better
                                if validation_loss >= prev_val_loss:

                                    print('WORKER %d: validation loss is '
                                          'worse' % self.task_index)

                                    #stop once validation performance was
                                    #worse more than conf['num_tries'] times
                                    #(the counter is checked before it is
                                    #incremented)
                                    if self.conf['num_tries'] != 'None':
                                        if num_tries == int(
                                                self.conf['num_tries']):
                                            validation_hook.restore()
                                            print('WORKER %d: terminating '
                                                  'training' %
                                                  self.task_index)
                                            outputs['terminate'].run(
                                                session=sess)
                                            break

                                    num_tries += 1

                                    if self.conf['go_back'] == 'True':

                                        #wait until all workers are at
                                        #validation point
                                        while not outputs['all_waiting'].eval(
                                                session=sess):
                                            time.sleep(1)
                                        outputs['reset_waiting'].run(
                                            session=sess)

                                        print('WORKER %d: loading previous '
                                              'model' % self.task_index)

                                        #load the previous model
                                        validation_hook.restore()
                                    else:
                                        outputs['update_validated_step'].run(
                                            session=sess)

                                    if self.conf['valid_adapt'] == 'True':
                                        print('WORKER %d: halving learning '
                                              'rate' % self.task_index)
                                        outputs['half_lr'].run(session=sess)
                                        validation_hook.save()

                                else:
                                    if self.conf['reset_tries'] == 'True':
                                        num_tries = 0

                                    #set the validated step
                                    outputs['update_validated_step'].run(
                                        session=sess)
                                    outputs['update_best'].run(session=sess)
                                    outputs['reset_waiting'].run(session=sess)

                                    #store the validated model
                                    validation_hook.save()

                            else:
                                #non-chief workers: signal that this worker
                                #is waiting and block until validation is no
                                #longer due or training should stop. The
                                #enclosing condition already guarantees that
                                #outputs['update_loss'] is not None.
                                if self.conf['go_back'] == 'True':
                                    outputs['waiting'].run(session=sess)
                                    while (outputs['should_validate'].eval(
                                            session=sess) and
                                           not outputs['should_stop'].eval(
                                               session=sess)):
                                        time.sleep(1)

                                    if outputs['should_stop'].eval(
                                            session=sess):
                                        break

                        #start time
                        start = time.time()

                        #read in the next batch of data
                        local_steps, _ = sess.run(
                            [outputs['local_steps'], outputs['read_data']])

                        for _ in range(local_steps):
                            #update the model
                            (_, loss, lr, global_step, memory, limit,
                             summary) = sess.run(fetches=[
                                 outputs['update_op'],
                                 outputs['loss'],
                                 outputs['learning_rate'],
                                 outputs['global_step'],
                                 outputs['memory_usage'],
                                 outputs['memory_limit'],
                                 outputs['training_summaries']])

                            summary_writer.add_summary(summary, global_step)

                            if memory is not None:
                                memory_line = (
                                    '\n\t peak memory usage: %d/%d MB' %
                                    (memory / 1e6, limit / 1e6))
                            else:
                                memory_line = ''

                            print(('WORKER %d: step %d/%d loss: %f, learning '
                                   'rate: %f \n\t time elapsed: %f sec%s') %
                                  (self.task_index, global_step,
                                   outputs['num_steps'], loss, lr,
                                   time.time() - start, memory_line))

                        outputs['increment_step'].run(session=sess)
                finally:
                    summary_writer.close()

        #store the model file
        modelfile = os.path.join(self.expdir, 'model', 'model.pkl')
        with open(modelfile, 'wb') as fid:
            pickle.dump(self.model, fid)
Beispiel #3
0
    def train(self):
        '''train the model

        Multi-task variant. Runs the (possibly distributed) training loop
        inside a tf.train.MonitoredTrainingSession. Every step works as
        follows:
          - accumulate gradients over conf['numbatches_to_aggregate']
            minibatches;
          - normalize them;
          - apply them task by task with every task trainer's
            apply_gradients op, recording the parameter change each task
            causes.
        When a validation op is configured (self.process_val_batch is not
        None), the chief periodically evaluates every validation task. Per
        task, and depending on the 'num_tries', 'go_back', 'valid_adapt' and
        'reset_tries' config entries, it updates that task's best loss,
        restores the previous validated model, halves the learning rate or
        terminates training. The current model is stored as the new
        validated model only when no validation task got worse in that
        round.
        '''

        #look for the master if distributed training is done
        master = self.server.target

        #start the session and standard services
        config = tf.ConfigProto(device_count={'CPU': 1})
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        #config.log_device_placement = True

        chief_only_hooks = []

        #optionally load the initial model values from self.init_filename
        if self.init_filename is not None:
            init_hook = hooks.LoadAtBegin(self.init_filename, self.models)
            chief_only_hooks.append(init_hook)

        #create a hook for saving the final model
        save_hook = hooks.SaveAtEnd(
            os.path.join(self.expdir, 'model', 'network.ckpt'), self.models)
        chief_only_hooks.append(save_hook)

        #create a hook for saving and restoring the validated model
        validation_hook = hooks.ValidationSaveHook(
            os.path.join(self.expdir, 'logdir', 'validated.ckpt'), self.models)
        chief_only_hooks.append(validation_hook)

        #number of times validation performance was worse, per validation
        #task (indexed like self.val_task_trainers)
        num_tries = np.zeros(len(self.val_task_trainers))

        #collect the parameters of all task trainers without duplicates,
        #sorted by name so that their order is deterministic
        all_params = set()
        for task_trainer in self.task_trainers:
            all_params.update(task_trainer.params)
        all_params = sorted(all_params, key=lambda par: par.name)

        with self.graph.as_default():
            with tf.train.MonitoredTrainingSession(
                    master=master,
                    is_chief=self.is_chief,
                    checkpoint_dir=os.path.join(self.expdir, 'logdir'),
                    scaffold=self.scaffold,
                    hooks=[hooks.StopHook(self.done)],
                    chief_only_hooks=chief_only_hooks,
                    config=config) as sess:

                #set the number of steps
                self.set_num_steps.run(session=sess)

                #print the params that will be updated
                print('parameters that will be trained:')
                for ind, param in enumerate(all_params):
                    print('param ind %i: %s' % (ind, param.name))

                #start the training loop
                #pylint: disable=E1101
                while not (sess.should_stop()
                           or self.should_stop.eval(session=sess)):

                    ##Validation part
                    #check if validation is due
                    if (self.process_val_batch is not None
                            and self.should_validate.eval(session=sess)):
                        if self.is_chief:
                            print('WORKER %d: validating model' %
                                  self.task_index)

                            #get the previous validation loss for each
                            #validation task
                            prev_val_loss_all_tasks = sess.run(
                                self.best_validation_all_tasks)

                            #reset the validation loss
                            self.reset_val_loss_norm.run(session=sess)

                            #start time
                            start = time.time()

                            #compute the validation loss
                            for _ in range(self.valbatches):
                                self.process_val_batch.run(session=sess)

                            #get the current validation loss
                            validation_loss, val_loss_all_tasks = sess.run([
                                self.validation_loss, self.val_loss_all_tasks
                            ])

                            print_str = ('WORKER %d: validation loss:%.6g, '
                                         'time: %f sec' %
                                         (self.task_index, validation_loss,
                                          time.time() - start))
                            #if multiple tasks, also print individual task
                            #losses; they are indexed like the validation
                            #task trainers
                            if len(val_loss_all_tasks) > 1:
                                for ind, loss_task in enumerate(
                                        val_loss_all_tasks):
                                    print_str += (
                                        ', task_loss %s: %.6g' %
                                        (self.val_task_trainers[ind].task_name,
                                         loss_task))
                            print(print_str)

                            #check if the validation loss is better, for
                            #every task, and collect what has to be done
                            terminate_train = False
                            restore_validation = False
                            continue_validation = True
                            do_halve_lr = False
                            any_task_worse = False
                            for task_ind, val_task in enumerate(
                                    self.val_task_trainers):
                                if (val_loss_all_tasks[task_ind] >=
                                        prev_val_loss_all_tasks[task_ind]):
                                    any_task_worse = True
                                    print('WORKER %d: validation loss is '
                                          'worse for %s!' %
                                          (self.task_index,
                                           val_task.task_name))

                                    #count how many times validation
                                    #performance was worse for this task
                                    num_tries[task_ind] += 1
                                    if (self.conf['num_tries'] != 'None' and
                                            num_tries[task_ind] == int(
                                                self.conf['num_tries'])):
                                        terminate_train = True

                                    if self.conf['go_back'] == 'True':
                                        continue_validation = False
                                        restore_validation = True
                                    else:
                                        continue_validation = True

                                    if self.conf['valid_adapt'] == 'True':
                                        do_halve_lr = True

                                else:
                                    sess.run(
                                        self.update_best_all_tasks[task_ind])
                                    if self.conf['reset_tries'] == 'True':
                                        num_tries[task_ind] = 0

                            #decide what to do for training based on the
                            #above task validations
                            if terminate_train:
                                validation_hook.restore()
                                print('WORKER %d: terminating training' %
                                      self.task_index)
                                self.terminate.run(session=sess)
                                break

                            if restore_validation:
                                #wait until all workers are at validation
                                #point
                                while not self.all_waiting.eval(session=sess):
                                    time.sleep(1)
                                self.reset_waiting.run(session=sess)

                                print('WORKER %d: loading previous model' %
                                      self.task_index)

                                #load the previous model
                                validation_hook.restore()

                            if continue_validation:
                                self.update_validated_step.run(session=sess)

                            if do_halve_lr:
                                print('WORKER %d: halving learning rate' %
                                      self.task_index)
                                self.half_lr.run(session=sess)
                                validation_hook.save()

                            #the current model becomes the validated model
                            #only if no task got worse in this round;
                            #num_tries cannot be used for this because it is
                            #not reset when reset_tries is not 'True'
                            if not any_task_worse:
                                self.reset_waiting.run(session=sess)

                                #store the validated model
                                validation_hook.save()

                        else:
                            #non-chief workers: signal that this worker is
                            #waiting and block until validation is no longer
                            #due or training should stop
                            if (self.conf['go_back'] == 'True'
                                    and self.process_val_batch is not None):
                                self.waiting.run(session=sess)
                                while (self.should_validate.eval(session=sess)
                                       and not self.should_stop.eval(
                                           session=sess)):
                                    time.sleep(1)

                                if self.should_stop.eval(session=sess):
                                    break

                    ##Training part
                    #start time
                    start = time.time()

                    #reset the gradients for the next step
                    sess.run(fetches=[self.reset_grad_loss_norm])

                    old_param_values = sess.run(all_params)

                    #First, accumulate the gradients
                    for _ in range(int(self.conf['numbatches_to_aggregate'])):
                        sess.run([self.process_minibatch])

                    #Then, normalize the gradients
                    sess.run([self.normalize_gradients])

                    #Finally, apply the gradients for each task optimizer.
                    #Get the variable values before and after the update, so
                    #stepsizes for each task can be displayed.
                    old_task_param_values = []
                    new_task_param_values = []
                    task_params_diff = []
                    loss_all_tasks = []

                    for ind, task_trainer in enumerate(self.task_trainers):
                        #the values before this task's update are the values
                        #after the previous task's update
                        if ind == 0:
                            old_task_param_values.append(old_param_values)
                        else:
                            old_task_param_values.append(
                                new_task_param_values[ind - 1])

                        #Apply the gradients in the task optimizer and get
                        #the task loss. For the last task, also run
                        #self.other_update_op and fetch the step statistics.
                        if ind + 1 < len(self.task_trainers):
                            _, task_loss = sess.run([
                                task_trainer.apply_gradients,
                                task_trainer.normalized_loss
                            ])
                        else:
                            _, _, task_loss, lr, global_step, num_steps = \
                                sess.run(fetches=[
                                    task_trainer.apply_gradients,
                                    self.other_update_op,
                                    task_trainer.normalized_loss,
                                    self.learning_rate,
                                    self.global_step,
                                    self.num_steps])
                        loss_all_tasks.append(task_loss)
                        #get the variable values after update
                        new_task_param_values.append(sess.run(all_params))

                        #Step size per variable: mean absolute difference
                        #between new and old values, multiplied by 10000
                        #(just for printing format purposes)
                        task_params_diff.append([
                            10000.0 * np.mean(np.abs(new_val - old_val))
                            for new_val, old_val in zip(
                                new_task_param_values[ind],
                                old_task_param_values[ind])
                        ])

                    #Calculate loss and step size over all task optimizations
                    loss = np.mean(loss_all_tasks)
                    new_param_values = new_task_param_values[-1]
                    params_diff = [
                        10000.0 * np.mean(np.abs(new_val - old_val))
                        for new_val, old_val in zip(new_param_values,
                                                    old_param_values)
                    ]

                    ##Output prompt
                    #Start the printing string with most important information
                    print_str = (
                        ('WORKER %d: step %d/%d loss: %.6g, learning rate: %f, '
                         'time: %.2f sec') %
                        (self.task_index, global_step, num_steps, loss, lr,
                         time.time() - start))

                    #if multiple tasks, also print individual task losses
                    if len(loss_all_tasks) > 1:
                        print_str += ' ('
                        for ind, loss_task in enumerate(loss_all_tasks):
                            print_str += (
                                '%s: %.6g. ' %
                                (self.task_trainers[ind].task_name, loss_task))
                        print_str += ')'

                    if 'print_var_updates' in self.conf and self.conf[
                            'print_var_updates'] == 'True':
                        #print the average variable step size
                        print_str += '\n Av param upd (*10000): %.3f' % np.mean(
                            np.array(params_diff))
                        #if multiple tasks, also print individual task
                        #average variable step size
                        if len(task_params_diff) > 1:
                            print_str += ' ('
                            for ind, task_param_diff in enumerate(
                                    task_params_diff):
                                print_str += '%s: %.3f; ' % (
                                    self.task_trainers[ind].task_name,
                                    np.mean(np.array(task_param_diff)))
                            print_str += ')'

                        #For each variable (eg weights layer 1) print the
                        #average step size
                        print_str += ' ('
                        for par_ind, param_diff in enumerate(params_diff):
                            if par_ind > 0:
                                print_str += ';'
                            print_str += '%i: %.3f ' % (par_ind, param_diff)
                            #if multiple tasks, also print for each variable
                            #the individual task average step size
                            if len(task_params_diff) > 1:
                                print_str += '{'
                                print_str += '+'.join(
                                    '%.3f' % task_param_diff[par_ind]
                                    for task_param_diff in task_params_diff)
                                print_str += '} '
                        print_str += ')'

                    #print the complete string
                    print(print_str)