Example #1
def train(clusterfile, job_name, task_index, ssh_command, expdir):
    """ does everything for ss training

	Args:
		clusterfile: the file where all the machines in the cluster are
			specified if None, local training will be done
		job_name: one of ps or worker in the case of distributed training
		task_index: the task index in this job
		ssh_command: the command to use for ssh, if 'None' no tunnel will be
			created
		expdir: the experiments directory
	"""

    # read the database config file
    parsed_database_cfg = configparser.ConfigParser()
    parsed_database_cfg.read(os.path.join(expdir, 'database.cfg'))

    # read the ss config file
    model_cfg = configparser.ConfigParser()
    model_cfg.read(os.path.join(expdir, 'model.cfg'))

    # read the trainer config file
    parsed_trainer_cfg = configparser.ConfigParser()
    parsed_trainer_cfg.read(os.path.join(expdir, 'trainer.cfg'))
    trainer_cfg = dict(parsed_trainer_cfg.items('trainer'))

    # read the decoder config file
    evaluator_cfg = configparser.ConfigParser()
    evaluator_cfg.read(os.path.join(expdir, 'evaluator.cfg'))

    # read the loss config file
    losses_cfg_file = os.path.join(expdir, 'loss.cfg')
    if not os.path.isfile(losses_cfg_file):
        warnings.warn(
            'In future versions it will be required to provide a loss config file',
            Warning)
        losses_cfg_available = False
        losses_cfg = None
    else:
        losses_cfg_available = True
        losses_cfg = configparser.ConfigParser()
        losses_cfg.read(losses_cfg_file)

    # Get the config files for each training stage. Each training stage has a different
    # segment length and its network is initialized with the network of the previous
    # training stage
    segment_lengths = trainer_cfg['segment_lengths'].split(' ')
    # segment_lengths = [segment_lengths[-1]]
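    # A hypothetical trainer.cfg excerpt that drives this staged training (the
    # values below are illustrative, not taken from a real recipe) could be:
    #
    #   [trainer]
    #   trainer = multi_task
    #   segment_lengths = 2s 10s full
    #   tasks = task1 task2
    #
    # Each entry in segment_lengths becomes a subdirectory of expdir holding the
    # stage-specific database.cfg and trainer.cfg read below.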

    val_sum = dict()
    for seg_len_ind, segment_length in enumerate(segment_lengths):

        segment_expdir = os.path.join(expdir, segment_length)

        segment_parsed_database_cfg = configparser.ConfigParser()
        segment_parsed_database_cfg.read(
            os.path.join(segment_expdir, 'database.cfg'))

        segment_parsed_trainer_cfg = configparser.ConfigParser()
        segment_parsed_trainer_cfg.read(
            os.path.join(segment_expdir, 'trainer.cfg'))
        segment_trainer_cfg = dict(segment_parsed_trainer_cfg.items('trainer'))

        if 'multi_task' in segment_trainer_cfg['trainer']:
            segment_tasks_cfg = dict()
            for task in segment_trainer_cfg['tasks'].split(' '):
                segment_tasks_cfg[task] = dict(
                    segment_parsed_trainer_cfg.items(task))
        else:
            segment_tasks_cfg = None

        # If this is the first segment length and there is no previously validated
        # training session for it yet, a separately trained model may be used to
        # bootstrap the current model
        if seg_len_ind == 0 and \
         not os.path.exists(os.path.join(segment_expdir, 'logdir', 'validated.ckpt.index')) and \
         'init_file' in segment_trainer_cfg:
            if not os.path.exists(segment_trainer_cfg['init_file'] + '.index'):
                raise BaseException(
                    'The requested bootstrapping model does not exist: %s' %
                    segment_trainer_cfg['init_file'])
            init_filename = segment_trainer_cfg['init_file']
            print('Using the following model for bootstrapping: %s' %
                  init_filename)

        # If the above bootstrapping does not apply and there is no previously
        # validated training session, use the model of the previous segment length
        # as initialization for the current one
        elif seg_len_ind > 0 and not os.path.exists(
                os.path.join(segment_expdir, 'logdir',
                             'validated.ckpt.index')):
            init_filename = os.path.join(expdir,
                                         segment_lengths[seg_len_ind - 1],
                                         'model', 'network.ckpt')
            if not os.path.exists(init_filename + '.index'):
                init_filename = None

        else:
            init_filename = None

        # if this training stage has already successfully finished, skip it
        if segment_lengths[seg_len_ind] != 'full' \
         and os.path.exists(os.path.join(expdir, segment_lengths[seg_len_ind], 'model', 'network.ckpt.index')):
            print('Already found a fully trained model for segment length %s' %
                  segment_length)
        else:
            tr = trainer_factory.factory(segment_trainer_cfg['trainer'])(
                conf=segment_trainer_cfg,
                tasksconf=segment_tasks_cfg,
                dataconf=segment_parsed_database_cfg,
                modelconf=model_cfg,
                evaluatorconf=evaluator_cfg,
                lossesconf=losses_cfg,
                expdir=segment_expdir,
                init_filename=init_filename,
                task_index=task_index)

            print('starting training for segment length: %s' % segment_length)

            # train the model
            best_val_loss = tr.train()
            if best_val_loss is not None:
                if tr.acc_steps:
                    val_sum[segment_length] = {
                        task: round(loss * 1e5) / 1e5
                        for loss, task in zip(best_val_loss, tr.tasks)
                    }
                else:
                    val_sum[segment_length] = round(best_val_loss * 1e5) / 1e5

            # best_val_losses, all_tasks = tr.train()
            # if best_val_losses is not None:
            # 	val_sum[segment_length] = {task: float(loss) for (loss, task) in zip(best_val_losses, all_tasks)}

    if val_sum and 'full' in val_sum:
        out_file = os.path.join(expdir, 'val_sum.json')
        with open(out_file, 'w') as fid:
            print('the validation loss ...')
            print(val_sum)
            print('... will be saved to %s' % out_file)
            json.dump(val_sum, fid)
    else:
        print('Did not find a validation loss to save')
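
For reference, val_sum maps each segment length to its best validation loss: a single
rounded float, or a per-task dictionary when the trainer exposes accumulation steps
(tr.acc_steps). A hypothetical val_sum.json with made-up numbers could look like:

{
    "2s": 0.73214,
    "full": {"task1": 0.51127, "task2": 0.48991}
}
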
Example #2
def train_lm(clusterfile,
             job_name,
             task_index,
             ssh_command,
             expdir):

    ''' does everything for language model training

    Args:
        clusterfile: the file where all the machines in the cluster are
            specified if None, local training will be done
        job_name: one of ps or worker in the case of distributed training
        task_index: the task index in this job
        ssh_command: the command to use for ssh, if 'None' no tunnel will be
            created
        expdir: the experiments directory
    '''
    #read the database config file
    parsed_database_cfg = configparser.ConfigParser()
    parsed_database_cfg.read(expdir + '/database.cfg')
    database_cfg = dict(parsed_database_cfg.items('database'))

    #read the lm config file
    parsed_nnet_cfg = configparser.ConfigParser()
    parsed_nnet_cfg.read(expdir + '/model/lm.cfg')
    nnet_cfg = dict(parsed_nnet_cfg.items('lm'))

    #read the trainer config file
    parsed_trainer_cfg = configparser.ConfigParser()
    parsed_trainer_cfg.read(expdir + '/trainer.cfg')
    trainer_cfg = dict(parsed_trainer_cfg.items('trainer'))

    #read the decoder config file
    parsed_decoder_cfg = configparser.ConfigParser()
    parsed_decoder_cfg.read(expdir + '/model/decoder.cfg')
    decoder_cfg = dict(parsed_decoder_cfg.items('decoder'))

    #create the cluster and server
    server = create_server.create_server(
        clusterfile=clusterfile,
        job_name=job_name,
        task_index=task_index,
        expdir=expdir,
        ssh_command=ssh_command)

    #copy the alphabet to the model
    if (job_name == 'ps' and task_index == 0) or job_name == 'local':
        shutil.copyfile(os.path.join(database_cfg['train_dir'], 'alphabet'),
                        os.path.join(FLAGS.expdir, 'model', 'alphabet'))

    #the ps should just wait
    if job_name == 'ps':
        server.join()

    #create the coder
    with open(os.path.join(database_cfg['train_dir'], 'alphabet')) as fid:
        alphabet = fid.read().split(' ')
    coder = target_coder.TargetCoder(alphabet)

    #read the number of utterances
    with open(os.path.join(database_cfg['train_dir'], 'numlines')) as fid:
        num_utt = int(fid.read())

    #read the maximum length
    with open(os.path.join(database_cfg['train_dir'], 'max_num_chars')) as fid:
        max_length = int(fid.read())
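    # Note: the training directory is assumed to contain the files 'alphabet'
    # (space-separated symbols), 'numlines' (number of utterances),
    # 'max_num_chars' (maximum target length) and 'text' (the training text),
    # as read here and by the dispenser below.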

    #create a batch dispenser for the training data
    dispenser = batchdispenser.LmBatchDispenser(
        target_coder=coder,
        size=int(trainer_cfg['batch_size']),
        textfile=os.path.join(database_cfg['train_dir'], 'text'),
        max_length=max_length,
        num_utt=num_utt)

    #create a reader for the validation data
    if 'dev_dir' in database_cfg:

        #read the maximum length
        with open(os.path.join(database_cfg['dev_dir'],
                               'max_num_chars')) as fid:
            max_length = int(fid.read())

        #create a text reader for the validation data
        val_reader = text_reader.TextReader(
            textfile=os.path.join(database_cfg['dev_dir'], 'text'),
            max_length=max_length,
            coder=coder)

        val_targets = val_reader.as_dict()

    else:
        if int(trainer_cfg['valid_utt']) > 0:
            val_dispenser = dispenser.split(int(trainer_cfg['valid_utt']))
            val_reader = val_dispenser.textreader
            val_targets = val_reader.asdict()
        else:
            val_reader = None
            val_targets = None

    #encode the validation targets
    if val_targets is not None:
        for utt in val_targets:
            val_targets[utt] = dispenser.textreader.coder.encode(
                val_targets[utt])

    #create the classifier
    classifier = lm_factory.factory(
        conf=nnet_cfg,
        output_dim=coder.num_labels)

    #create the callable for the decoder
    decoder = partial(
        decoder_factory.factory,
        conf=decoder_cfg,
        classifier=classifier,
        input_dim=1,
        max_input_length=val_reader.max_length,
        coder=coder,
        expdir=expdir)

    #create the trainer
    tr = trainer_factory.factory(
        conf=trainer_cfg,
        decoder=decoder,
        classifier=classifier,
        input_dim=1,
        reconstruction_dim=1,
        dispenser=dispenser,
        val_reader=val_reader,
        val_targets=val_targets,
        expdir=expdir,
        server=server,
        task_index=task_index)

    #train the classifier
    tr.train()
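
A minimal local entry point for train_lm might look as follows. This is a sketch: the
argument handling below is not part of the example above, only the function signature is.

if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser(description='train a language model')
    parser.add_argument('expdir', help='the experiments directory')
    parser.add_argument('--clusterfile', default=None,
                        help='cluster specification file, None for local training')
    parser.add_argument('--job_name', default='local',
                        help='one of ps, worker or local')
    parser.add_argument('--task_index', type=int, default=0,
                        help='the task index within this job')
    parser.add_argument('--ssh_command', default='None',
                        help="command to use for ssh, 'None' for no tunnel")
    args = parser.parse_args()

    train_lm(clusterfile=args.clusterfile,
             job_name=args.job_name,
             task_index=args.task_index,
             ssh_command=args.ssh_command,
             expdir=args.expdir)
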
Example #3
def train(clusterfile,
          job_name,
          task_index,
          ssh_command,
          expdir,
          testing=False):
    ''' does everything for asr training

    Args:
        clusterfile: the file where all the machines in the cluster are
            specified if None, local training will be done
        job_name: one of ps or worker in the case of distributed training
        task_index: the task index in this job
        ssh_command: the command to use for ssh, if 'None' no tunnel will be
            created
        expdir: the experiments directory
        testing: if true only the graph will be created for debugging purposes
    '''

    #read the database config file
    database_cfg = configparser.ConfigParser()
    database_cfg.read(os.path.join(expdir, 'database.cfg'))

    #read the asr config file
    model_cfg = configparser.ConfigParser()
    model_cfg.read(os.path.join(expdir, 'model.cfg'))

    #read the trainer config file
    trainer_cfg = configparser.ConfigParser()
    trainer_cfg.read(os.path.join(expdir, 'trainer.cfg'))

    #read the decoder config file
    evaluator_cfg = configparser.ConfigParser()
    evaluator_cfg.read(os.path.join(expdir, 'validation_evaluator.cfg'))

    #create the cluster and server
    server = create_server.create_server(clusterfile=clusterfile,
                                         job_name=job_name,
                                         task_index=task_index,
                                         expdir=expdir,
                                         ssh_command=ssh_command)

    #parameter server
    if job_name == 'ps':

        print('starting parameter server')

        #create the parameter server
        ps = trainer.ParameterServer(conf=trainer_cfg,
                                     modelconf=model_cfg,
                                     dataconf=database_cfg,
                                     server=server,
                                     task_index=task_index)

        #let the ps wait until all workers are finished
        ps.join()

        print('parameter server stopped')

        return

    #create the trainer
    tr = trainer_factory.factory(trainer_cfg.get('trainer', 'trainer'))(
        conf=trainer_cfg,
        dataconf=database_cfg,
        modelconf=model_cfg,
        evaluatorconf=evaluator_cfg,
        expdir=expdir,
        server=server,
        task_index=task_index)

    print('starting training')

    #train the model
    tr.train(testing)
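
Since tr.train(testing) only builds the graph when testing is True (per the docstring), a
quick local sanity check of an experiment directory could be run along these lines. This is
a sketch: the path is a placeholder and job_name='local' is assumed for non-distributed
runs, mirroring the train_lm example above.

train(clusterfile=None,
      job_name='local',
      task_index=0,
      ssh_command='None',
      expdir='/path/to/expdir',
      testing=True)
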
Example #4
def train(clusterfile, job_name, task_index, ssh_command, expdir):

	""" does everything for ss training

	Args:
		clusterfile: the file where all the machines in the cluster are
			specified if None, local training will be done
		job_name: one of ps or worker in the case of distributed training
		task_index: the task index in this job
		ssh_command: the command to use for ssh, if 'None' no tunnel will be
			created
		expdir: the experiments directory
	"""

	# read the database config file
	parsed_database_cfg = configparser.ConfigParser()
	parsed_database_cfg.read(os.path.join(expdir, 'database.cfg'))

	# read the ss config file
	model_cfg = configparser.ConfigParser()
	model_cfg.read(os.path.join(expdir, 'model.cfg'))

	# read the trainer config file
	parsed_trainer_cfg = configparser.ConfigParser()
	parsed_trainer_cfg.read(os.path.join(expdir, 'trainer.cfg'))
	trainer_cfg = dict(parsed_trainer_cfg.items('trainer'))

	# read the decoder config file
	evaluator_cfg = configparser.ConfigParser()
	evaluator_cfg.read(os.path.join(expdir, 'evaluator.cfg'))
	
	# Get the config files for each training stage. Each training stage has a different
	# segment length and its network is initialized with the network of the previous
	# training stage
	segment_lengths = trainer_cfg['segment_lengths'].split(' ')
	# segment_lengths = [segment_lengths[-1]]

	val_sum = dict()
	for seg_len_ind, segment_length in enumerate(segment_lengths):
	
		segment_expdir = os.path.join(expdir, segment_length)

		segment_parsed_database_cfg = configparser.ConfigParser()
		segment_parsed_database_cfg.read(
			os.path.join(segment_expdir, 'database.cfg'))

		segment_parsed_trainer_cfg = configparser.ConfigParser()
		segment_parsed_trainer_cfg.read(
			os.path.join(segment_expdir, 'trainer.cfg'))
		segment_trainer_cfg = dict(segment_parsed_trainer_cfg.items('trainer'))

		if 'multi_task' in segment_trainer_cfg['trainer']:
			segment_tasks_cfg = dict()
			for task in segment_trainer_cfg['tasks'].split(' '):
				segment_tasks_cfg[task] = dict(segment_parsed_trainer_cfg.items(task))
		else:
			segment_tasks_cfg = None

		# If there is no previously validated training session, use the model of the
		# previous segment length as initialization for the current one
		if seg_len_ind > 0 and not os.path.exists(os.path.join(segment_expdir, 'logdir', 'validated.ckpt.index')):
			init_filename = os.path.join(expdir, segment_lengths[seg_len_ind-1], 'model', 'network.ckpt')
			if not os.path.exists(init_filename + '.index'):
				init_filename = None

		else:
			init_filename = None

		# if this training stage has already successfully finished, skip it
		if os.path.exists(os.path.join(expdir, segment_lengths[seg_len_ind], 'model', 'network.ckpt.index')):
			print('Already found a fully trained model for segment length %s' % segment_length)
		else:
			tr = trainer_factory.factory(segment_trainer_cfg['trainer'])(
				conf=segment_trainer_cfg,
				tasksconf=segment_tasks_cfg,
				dataconf=segment_parsed_database_cfg,
				modelconf=model_cfg,
				evaluatorconf=evaluator_cfg,
				expdir=segment_expdir,
				init_filename=init_filename,
				task_index=task_index)

			print('starting training for segment length: %s' % segment_length)

			# train the model
			best_val_losses, all_tasks = tr.train()

			if best_val_losses is not None:
				val_sum[segment_length] = {task: float(loss) for (loss, task) in zip(best_val_losses, all_tasks)}

	if val_sum and 'full' in val_sum:
		out_file = os.path.join(expdir, 'val_sum.json')
		with open(out_file, 'w') as fid:
			json.dump(val_sum, fid)
	else:
		print('Did not find a validation loss to save')
Example #5
def train(clusterfile, job_name, task_index, ssh_command, expdir):
    ''' does everything for ss training

    Args:
        clusterfile: the file where all the machines in the cluster are
            specified if None, local training will be done
        job_name: one of ps or worker in the case of distributed training
        task_index: the task index in this job
        ssh_command: the command to use for ssh, if 'None' no tunnel will be
            created
        expdir: the experiments directory
    '''

    #read the database config file
    parsed_database_cfg = configparser.ConfigParser()
    parsed_database_cfg.read(os.path.join(expdir, 'database.cfg'))

    #read the ss config file
    model_cfg = configparser.ConfigParser()
    model_cfg.read(os.path.join(expdir, 'model.cfg'))

    #read the trainer config file
    parsed_trainer_cfg = configparser.ConfigParser()
    parsed_trainer_cfg.read(os.path.join(expdir, 'trainer.cfg'))
    trainer_cfg = dict(parsed_trainer_cfg.items('trainer'))

    #read the decoder config file
    evaluator_cfg = configparser.ConfigParser()
    evaluator_cfg.read(os.path.join(expdir, 'evaluator.cfg'))

    #Get the config files for each training stage. Each training stage has a different
    #segment length and its network is initialized with the network of the previous
    #training stage
    segment_lengths = trainer_cfg['segment_lengths'].split(' ')
    #segment_lengths = [segment_lengths[-1]]
    #os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    for i, segment_length in enumerate(segment_lengths):

        segment_expdir = os.path.join(expdir, segment_length)

        segment_parsed_database_cfg = configparser.ConfigParser()
        segment_parsed_database_cfg.read(
            os.path.join(segment_expdir, 'database.cfg'))

        segment_parsed_trainer_cfg = configparser.ConfigParser()
        segment_parsed_trainer_cfg.read(
            os.path.join(segment_expdir, 'trainer.cfg'))
        segment_trainer_cfg = dict(segment_parsed_trainer_cfg.items('trainer'))

        if segment_trainer_cfg['trainer'] == 'multi_task':
            segment_tasks_cfg = dict()
            for task in segment_trainer_cfg['tasks'].split(' '):
                segment_tasks_cfg[task] = dict(
                    segment_parsed_trainer_cfg.items(task))
        else:
            segment_tasks_cfg = None

        #If there is no previously validated training session, use the model of the
        #previous segment length as initialization for the current one
        if i > 0 and not os.path.exists(
                os.path.join(segment_expdir, 'logdir',
                             'validated.ckpt.index')):
            init_filename = os.path.join(expdir, segment_lengths[i - 1],
                                         'model', 'network.ckpt')
            if not os.path.exists(init_filename + '.index'):
                init_filename = None

        else:
            init_filename = None

        #if this training stage has already successfully finished, skip it
        if os.path.exists(
                os.path.join(expdir, segment_lengths[i], 'model',
                             'network.ckpt.index')):
            print('Already found a fully trained model for segment length %s' % segment_length)
        else:

            #create the cluster and server
            server = create_server.create_server(clusterfile=clusterfile,
                                                 job_name=job_name,
                                                 task_index=task_index,
                                                 expdir=expdir,
                                                 ssh_command=ssh_command)

            #parameter server
            if job_name == 'ps':
                raise Exception('Parameter server is currently not implemented correctly')
                ##create the parameter server
                #ps = multi_task_trainer.ParameterServer(
                #conf=segment_trainer_cfg,
                #tasksconf=segment_tasks_cfg,
                #modelconf=model_cfg,
                #dataconf=segment_parsed_database_cfg,
                #server=server,
                #task_index=task_index)

                #if task_index ==0:
                ##let the ps wait until all workers are finished
                #ps.join()
                #return

            tr = trainer_factory.factory(segment_trainer_cfg['trainer'])(
                conf=segment_trainer_cfg,
                tasksconf=segment_tasks_cfg,
                dataconf=segment_parsed_database_cfg,
                modelconf=model_cfg,
                evaluatorconf=evaluator_cfg,
                expdir=segment_expdir,
                init_filename=init_filename,
                server=server,
                task_index=task_index)

            print('starting training for segment length: %s' % segment_length)

            #train the model
            tr.train()
Example #6
def train_asr(clusterfile, job_name, task_index, ssh_command, expdir):
    ''' does everything for asr training

    Args:
        clusterfile: the file where all the machines in the cluster are
            specified if None, local training will be done
        job_name: one of ps or worker in the case of distributed training
        task_index: the task index in this job
        ssh_command: the command to use for ssh, if 'None' no tunnel will be
            created
        expdir: the experiments directory
    '''

    #read the database config file
    parsed_database_cfg = configparser.ConfigParser()
    parsed_database_cfg.read(os.path.join(expdir, 'database.cfg'))
    database_cfg = dict(parsed_database_cfg.items('database'))

    #read the features config file
    parsed_feat_cfg = configparser.ConfigParser()
    parsed_feat_cfg.read(os.path.join(expdir, 'model', 'features.cfg'))
    feat_cfg = dict(parsed_feat_cfg.items('features'))

    #read the asr config file
    parsed_nnet_cfg = configparser.ConfigParser()
    parsed_nnet_cfg.read(os.path.join(expdir, 'model', 'asr.cfg'))
    nnet_cfg = dict(parsed_nnet_cfg.items('asr'))

    #read the trainer config file
    parsed_trainer_cfg = configparser.ConfigParser()
    parsed_trainer_cfg.read(os.path.join(expdir, 'trainer.cfg'))
    trainer_cfg = dict(parsed_trainer_cfg.items('trainer'))

    #read the decoder config file
    parsed_decoder_cfg = configparser.ConfigParser()
    parsed_decoder_cfg.read(os.path.join(expdir, 'model', 'decoder.cfg'))
    decoder_cfg = dict(parsed_decoder_cfg.items('decoder'))

    #create the cluster and server
    server = create_server.create_server(clusterfile=clusterfile,
                                         job_name=job_name,
                                         task_index=task_index,
                                         expdir=expdir,
                                         ssh_command=ssh_command)

    #the ps should just wait
    if job_name == 'ps':
        server.join()

    featdir = os.path.join(database_cfg['train_dir'], feat_cfg['name'])

    #create the coder
    with open(os.path.join(database_cfg['train_dir'], 'alphabet')) as fid:
        alphabet = fid.read().split(' ')
    coder = target_coder.TargetCoder(alphabet)

    #create a feature reader for the training data
    with open(featdir + '/maxlength', 'r') as fid:
        max_length = int(fid.read())

    featreader = feature_reader.FeatureReader(scpfile=featdir +
                                              '/feats_shuffled.scp',
                                              cmvnfile=featdir + '/cmvn.scp',
                                              utt2spkfile=featdir + '/utt2spk',
                                              max_length=max_length)

    #read the feature dimension
    with open(featdir + '/dim', 'r') as fid:
        input_dim = int(fid.read())
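    # The feature directory is assumed to follow a Kaldi-style layout with
    # 'feats_shuffled.scp', 'cmvn.scp' and 'utt2spk', plus the 'maxlength' and
    # 'dim' files read above.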

    #the path to the text file
    textfile = os.path.join(database_cfg['train_dir'], 'targets')

    #create a batch dispenser for the training data
    dispenser = batchdispenser.AsrBatchDispenser(
        feature_reader=featreader,
        target_coder=coder,
        size=int(trainer_cfg['batch_size']),
        target_path=textfile)

    #create a reader for the validation data
    if 'dev_data' in database_cfg:
        featdir = database_cfg['dev_dir'] + '/' + feat_cfg['name']

        with open(featdir + '/maxlength', 'r') as fid:
            max_length = int(fid.read())

        val_reader = feature_reader.FeatureReader(
            scpfile=featdir + '/feats.scp',
            cmvnfile=featdir + '/cmvn.scp',
            utt2spkfile=featdir + '/utt2spk',
            max_length=max_length)

        textfile = os.path.join(database_cfg['dev_dir'], 'targets')

        #read the validation targets
        with open(textfile) as fid:
            lines = fid.readlines()

        val_targets = dict()
        for line in lines:
            splitline = line.strip().split(' ')
            val_targets[splitline[0]] = ' '.join(splitline[1:])
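        # The targets file is assumed to hold one utterance per line, the
        # utterance id followed by its transcription, e.g. (hypothetical):
        #   utt_001 the cat sat on the mat
        # so splitline[0] is the id and the rest is the target text.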

    else:
        if int(trainer_cfg['valid_utt']) > 0:
            val_dispenser = dispenser.split(int(trainer_cfg['valid_utt']))
            val_reader = val_dispenser.feature_reader
            val_targets = val_dispenser.target_dict
        else:
            val_reader = None
            val_targets = None

    #encode the validation targets
    if val_targets is not None:
        for utt in val_targets:
            val_targets[utt] = dispenser.target_coder.encode(val_targets[utt])

    #create the classifier
    classifier = asr_factory.factory(conf=nnet_cfg,
                                     output_dim=coder.num_labels)

    #create the callable for the decoder
    decoder = partial(decoder_factory.factory,
                      conf=decoder_cfg,
                      classifier=classifier,
                      input_dim=input_dim,
                      max_input_length=val_reader.max_length,
                      coder=coder,
                      expdir=expdir)

    #create the trainer
    tr = trainer_factory.factory(conf=trainer_cfg,
                                 decoder=decoder,
                                 classifier=classifier,
                                 input_dim=input_dim,
                                 dispenser=dispenser,
                                 val_reader=val_reader,
                                 val_targets=val_targets,
                                 expdir=expdir,
                                 server=server,
                                 task_index=task_index)

    print('starting training')

    #train the classifier
    tr.train()
Example #7
def train_asr(clusterfile,
              job_name,
              task_index,
              ssh_command,
              expdir):

    ''' does everything for asr training
    Args:
        clusterfile: the file where all the machines in the cluster are
            specified if None, local training will be done
        job_name: one of ps or worker in the case of distributed training
        task_index: the task index in this job
        ssh_command: the command to use for ssh, if 'None' no tunnel will be
            created
        expdir: the experiments directory
    '''

    #read the database config file
    parsed_database_cfg = configparser.ConfigParser()
    parsed_database_cfg.read(os.path.join(expdir, 'database.cfg'))
    database_cfg = dict(parsed_database_cfg.items('database'))

    #read the features config file
    parsed_feat_cfg = configparser.ConfigParser()
    parsed_feat_cfg.read(os.path.join(expdir, 'model', 'features.cfg'))
    feat_cfg = dict(parsed_feat_cfg.items('features'))

    #read the asr config file
    parsed_nnet_cfg = configparser.ConfigParser()
    parsed_nnet_cfg.read(os.path.join(expdir, 'model', 'asr.cfg'))
    nnet_cfg = dict(parsed_nnet_cfg.items('asr'))

    #read the trainer config file
    parsed_trainer_cfg = configparser.ConfigParser()
    parsed_trainer_cfg.read(os.path.join(expdir, 'trainer.cfg'))
    trainer_cfg = dict(parsed_trainer_cfg.items('trainer'))

    #read the decoder config file
    parsed_decoder_cfg = configparser.ConfigParser()
    parsed_decoder_cfg.read(os.path.join(expdir, 'model', 'decoder.cfg'))
    decoder_cfg = dict(parsed_decoder_cfg.items('decoder'))

    #distinguish between the three implemented training modes
    if database_cfg['train_mode'] == 'supervised':
        nonsupervised = False
    elif database_cfg['train_mode'] == 'nonsupervised' or\
            database_cfg['train_mode'] == 'semisupervised':
        nonsupervised = True
    else:
        raise Exception('Wrong kind of training mode')

    #when (partly) nonsupervised, determine which features are used for the
    #reconstruction; currently two options are implemented
    if nonsupervised:
        if trainer_cfg['reconstruction_features'] == 'audio_samples':
            audio_used = True
        elif trainer_cfg['reconstruction_features'] == 'input_features':
            audio_used = False
        else:
            raise Exception(
                'Unknown specification for the reconstruction features')

    #read the quant config file if nonsupervised training and samples used
    if nonsupervised:
        if audio_used:
            parsed_quant_cfg = configparser.ConfigParser()
            parsed_quant_cfg.read(os.path.join(expdir,
                                               'model', 'quantization.cfg'))
            quant_cfg = dict(parsed_quant_cfg.items('features'))

    #based on the other settings, compute and overwrite samples_per_hlfeature
    #and unpredictable_samples in the classifier config dictionary
    if nonsupervised:
        if audio_used:
            rate_after_quant = int(quant_cfg['quant_rate'])
            win_length = float(feat_cfg['winlen'])
            win_shift = float(feat_cfg['winstep'])
            samples_one_window = int(win_length * rate_after_quant)
            samples_one_shift = int(win_shift * rate_after_quant)
            # NOTE: this is only relevant when using a listener with a pyramidal
            # structure and should be adapted otherwise
            time_compression = 2**int(nnet_cfg['listener_numlayers'])
            #store values in config dictionary
            nnet_cfg['samples_per_hlfeature'] = (
                samples_one_shift * time_compression)
            nnet_cfg['unpredictable_samples'] = (
                samples_one_window
                + (time_compression - 1) * samples_one_shift
                - nnet_cfg['samples_per_hlfeature'])
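            # Worked example with hypothetical settings: quant_rate = 16000,
            # winlen = 0.025, winstep = 0.01 and listener_numlayers = 3 give
            # samples_one_window = 400, samples_one_shift = 160 and
            # time_compression = 8, hence samples_per_hlfeature = 160 * 8 = 1280
            # and unpredictable_samples = (400 + 7 * 160) - 1280 = 240.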


    #create the cluster and server
    server = create_server.create_server(
        clusterfile=clusterfile,
        job_name=job_name,
        task_index=task_index,
        expdir=expdir,
        ssh_command=ssh_command)

    #the ps should just wait
    if job_name == 'ps':
        server.join()

    # path to where the training samples are stored
    featdir = os.path.join(database_cfg['train_dir'], feat_cfg['name'])

    #create the coder
    with open(os.path.join(database_cfg['train_dir'], 'alphabet')) as fid:
        alphabet = fid.read().split(' ')
    coder = target_coder.TargetCoder(alphabet)

    #create a feature reader for the training data
    with open(featdir + '/maxlength', 'r') as fid:
        max_length = int(fid.read())

    featreader = feature_reader.FeatureReader(
        scpfile=featdir + '/feats_shuffled.scp',
        cmvnfile=featdir + '/cmvn.scp',
        utt2spkfile=featdir + '/utt2spk',
        max_length=max_length)

    #read the feature dimension
    with open(featdir + '/dim', 'r') as fid:
        input_dim = int(fid.read())

    #the path to the text file
    textfile = os.path.join(database_cfg['train_dir'], 'targets')


    # If nonsupervised and audio used, we also need to read samples
    # these can be done with a second feature reader
    if nonsupervised:
        if audio_used:
            featdir2 = os.path.join(database_cfg['train_dir'],
                                    quant_cfg['name'])

            with open(featdir2 + '/maxlength', 'r') as fid:
                max_length_audio = int(fid.read())

            audioreader = feature_reader.FeatureReader(
                scpfile=featdir2 + '/feats_shuffled.scp',
                cmvnfile=None,
                utt2spkfile=None,
                max_length=max_length_audio)

    ## create a batch dispenser, depending on which situation we're in
    if not nonsupervised:
    # in the normal supervised training mode, regular dispenser is needed
        if 'las_ignoring_mode' in trainer_cfg:
            if trainer_cfg['las_ignoring_mode'] == 'True':
            # if we ignore unlabeled examples
                dispenser = batchdispenser.AsrTextBatchDispenser(
                    feature_reader=featreader,
                    target_coder=coder,
                    size=int(trainer_cfg['batch_size']),
                    target_path=textfile)
            elif trainer_cfg['las_ignoring_mode'] == 'False':
            # if we choose to process the unlabeled examples
                if 'fixed_ratio' in trainer_cfg:
                    if trainer_cfg['fixed_ratio'] == 'True':
                    # if we choose to process with batches with fixed
                    # labeled/unlabeled ratio
                        dispenser = \
                            batchdispenser.AsrTextBatchDispenserAltFixRatio(
                                feature_reader=featreader,
                                target_coder=coder,
                                size=int(trainer_cfg['batch_size']),
                                target_path=textfile,
                                percentage_unlabeled=1-float(
                                    database_cfg['part_labeled']))
                    elif trainer_cfg['fixed_ratio'] == 'False':
                        # if the fixed ratio is not used
                        dispenser = batchdispenser.AsrTextBatchDispenserAlt(
                            feature_reader=featreader,
                            target_coder=coder,
                            size=int(trainer_cfg['batch_size']),
                            target_path=textfile)
                    else:
                        raise Exception('wrong information in fixed_ratio var')
                else:
                # if fixed ratio is not specified, we choose to do without it
                    dispenser = batchdispenser.AsrTextBatchDispenserAlt(
                        feature_reader=featreader,
                        target_coder=coder,
                        size=int(trainer_cfg['batch_size']),
                        target_path=textfile)
            else:
                raise Exception('wrong information in LAS_ignoring_mode var')
        else:
        # if no specification is made about the ignoring, ignore the unlabeled
            dispenser = batchdispenser.AsrTextBatchDispenser(
                feature_reader=featreader,
                target_coder=coder,
                size=int(trainer_cfg['batch_size']),
                target_path=textfile)
    else:
        # when doing (partly) nonsupervised extra reconstruction features needed
        if audio_used:
        # when the audio is the reconstruction feature
            if 'fixed_ratio' in trainer_cfg:
                if trainer_cfg['fixed_ratio'] == 'True':
                    # if specified to work with fixed lab/unlab ratio batches
                    dispenser = \
                        batchdispenser.AsrTextAndAudioBatchDispenserFixRatio(
                            feature_reader=featreader,
                            audio_reader=audioreader,
                            target_coder=coder,
                            size=int(trainer_cfg['batch_size']),
                            target_path=textfile,
                            percentage_unlabeled=1-float(
                                database_cfg['part_labeled']))
                elif trainer_cfg['fixed_ratio'] == 'False':
                # if specified to not use the fixed ratio
                    dispenser = batchdispenser.AsrTextAndAudioBatchDispenser(
                        feature_reader=featreader,
                        audio_reader=audioreader,
                        target_coder=coder,
                        size=int(trainer_cfg['batch_size']),
                        target_path=textfile)
                else:
                    raise Exception('wrong information in fixed_ratio var')
            else:
            # without specification, suppose no fixed ratio batches
                dispenser = batchdispenser.AsrTextAndAudioBatchDispenser(
                    feature_reader=featreader,
                    audio_reader=audioreader,
                    target_coder=coder,
                    size=int(trainer_cfg['batch_size']),
                    target_path=textfile)
        else:
        # if no audio is used, the input features are used
            if 'fixed_ratio' in trainer_cfg:
                if trainer_cfg['fixed_ratio'] == 'True':
                # if specified to work with fixed labeled/unlabeled ratio batches
                    dispenser = \
                        batchdispenser.AsrTextAndFeatBatchDispenserFixRatio(
                            feature_reader=featreader,
                            target_coder=coder,
                            size=int(trainer_cfg['batch_size']),
                            target_path=textfile,
                            percentage_unlabeled=1-float(
                                database_cfg['part_labeled']))
                elif trainer_cfg['fixed_ratio'] == 'False':
                # if specified to not use the fixed ratio
                    dispenser = batchdispenser.AsrTextAndFeatBatchDispenser(
                        feature_reader=featreader,
                        target_coder=coder,
                        size=int(trainer_cfg['batch_size']),
                        target_path=textfile)
                else:
                    raise Exception('wrong information in fixed_ratio var')
            else:
            # without specification, suppose no fixed ratio batches
                dispenser = batchdispenser.AsrTextAndFeatBatchDispenser(
                    feature_reader=featreader,
                    target_coder=coder,
                    size=int(trainer_cfg['batch_size']),
                    target_path=textfile)
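
    # To summarise the branching above (derived directly from this code):
    # supervised training uses AsrTextBatchDispenser, or the 'Alt'/'AltFixRatio'
    # variants when unlabeled examples are kept; (semi-)nonsupervised training
    # uses AsrTextAndAudioBatchDispenser(/FixRatio) when reconstructing audio
    # samples, and AsrTextAndFeatBatchDispenser(/FixRatio) when reconstructing
    # the input features.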


    # read the validation data. If there are text targets, they only matter for
    # the validation data. If training is purely nonsupervised, we must validate
    # on the reconstructed features
    if 'dev_data' in database_cfg:
        # create a reader for the validation inputs
        featdir = database_cfg['dev_dir'] + '/' +  feat_cfg['name']

        with open(featdir + '/maxlength', 'r') as fid:
            max_length = int(fid.read())

        val_reader = feature_reader.FeatureReader(
            scpfile=featdir + '/feats.scp',
            cmvnfile=featdir + '/cmvn.scp',
            utt2spkfile=featdir + '/utt2spk',
            max_length=max_length)

        textfile = os.path.join(database_cfg['dev_dir'], 'targets')

        #read the validation text targets
        with open(textfile) as fid:
            lines = fid.readlines()

            val_text_targets = dict()
            for line in lines:
                splitline = line.strip().split(' ')
                val_text_targets[splitline[0]] = ' '.join(splitline[1:])

        if nonsupervised:
        #also store the reconstruction targets
            val_rec_targets = dict()
            if audio_used:
                audiodir = database_cfg['dev_dir'] + '/' +  quant_cfg['name']
                with open(audiodir + '/maxlength', 'r') as fid:
                    max_length_audio = int(fid.read())
                val_audio_reader = feature_reader.FeatureReader(
                    scpfile=audiodir + '/feats.scp',
                    cmvnfile=None,
                    utt2spkfile=audiodir + '/utt2spk',
                    max_length=max_length_audio)
                for _ in range(val_audio_reader.num_utt):
                    utt_id, audio, _ = val_audio_reader.get_utt()
                    val_rec_targets[utt_id] = audio
            else: #input features are used
                for _ in range(val_reader.num_utt):
                    utt_id, feat, _ = val_reader.get_utt()
                    val_rec_targets[utt_id] = feat
        else:
            with open(textfile) as fid:
                lines = fid.readlines()

                val_rec_targets = dict()
                for line in lines:
                    splitline = line.strip().split(' ')
                    val_rec_targets[splitline[0]] = None

        val_targets = dict()
        for utt_id in val_text_targets:
            val_targets[utt_id] = (val_text_targets[utt_id],
                                   val_rec_targets[utt_id])

    else:
        if int(trainer_cfg['valid_utt']) > 0:
            val_dispenser = dispenser.split(int(trainer_cfg['valid_utt']))
            val_reader = val_dispenser.feature_reader
            val_targets = val_dispenser.target_dict
        else:
            val_reader = None
            val_targets = None

    #encode the validation targets
    if val_targets is not None:
        for utt in val_targets:
            val_targets[utt] = (dispenser.target_coder.encode(
                val_targets[utt][0]), val_targets[utt][1])


    #create the classifier
    if nonsupervised:
        if audio_used:
            output_dim_second_el = int(quant_cfg['quant_levels'])
        else: # input features used
            output_dim_second_el = input_dim
    else: # only supervised training
        output_dim_second_el = None

    classifier = asr_factory.factory(
        conf=nnet_cfg,
        output_dim=(coder.num_labels, output_dim_second_el))

    #create the callable for the decoder
    decoder = partial(
        decoder_factory.factory,
        conf=decoder_cfg,
        classifier=classifier,
        input_dim=input_dim,
        max_input_length=val_reader.max_length,
        coder=coder,
        expdir=expdir)

    #create the trainer
    if nonsupervised:
        if audio_used:
            reconstruction_dim = 1
        else:
            reconstruction_dim = input_dim
    else:
        reconstruction_dim = 1

    tr = trainer_factory.factory(
        conf=trainer_cfg,
        decoder=decoder,
        classifier=classifier,
        input_dim=input_dim,
        reconstruction_dim=reconstruction_dim,
        dispenser=dispenser,
        val_reader=val_reader,
        val_targets=val_targets,
        expdir=expdir,
        server=server,
        task_index=task_index)

    print('starting training')

    #train the classifier
    tr.train()