import os
import math
import datetime
import multiprocessing
import cPickle

import numpy as np
import theano
import theano.tensor as T

## project-local modules; SampleBoundingBox, ParseResponse, GetResponseValueDims,
## Response2LabelType, ResNet4DistMatrix, InitializeModelSpecs, TrainModel and
## Cleanup are assumed to be imported from elsewhere in this project
import config
import DataProcessor
import FeatureUtils
import LabelUtils
import TrainUtils
import ParseCommandLine

def TrainByOneBatch(batch, train, modelSpecs, forRefState=False):
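    """Train the model on one minibatch.

    batch is a list of protein locations; the real feature data is loaded,
    assembled into tensors and cropped here, then fed to the compiled train
    function. Returns (train_loss, train_errors, param_L2).
    """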

    ## batch is a list of protein locations, so we need to load the real data here
    minibatch = DataProcessor.LoadRealData(batch, modelSpecs)

    ## verify that the data has the same input dimension as the model specification
    FeatureUtils.CheckModelNDataConsistency(modelSpecs, minibatch)

    onebatch, names4onebatch = DataProcessor.AssembleOneBatch(
        minibatch, modelSpecs, forRefState=forRefState)
    x1d, x2d, x1dmask, x2dmask = onebatch[0:4]

    ## crop a large protein to deal with limited GPU memory; for sequential and
    ## embedding features, the theano model itself crops based upon the bounding box
    bounds = SampleBoundingBox((x2d.shape[1], x2d.shape[2]),
                               modelSpecs['maxbatchSize'])

    #x1d_new = x1d[:, bounds[1]:bounds[3], :]
    x1d_new = x1d
    x2d_new = x2d[:, bounds[0]:bounds[2], bounds[1]:bounds[3], :]
    #x1dmask_new = x1dmask[:, bounds[1]:x1dmask.shape[1] ]
    x1dmask_new = x1dmask
    ## the row dimension of x2dmask spans only the padded positions (an assumption
    ## based on how the masks are assembled), hence the slice up to x2dmask.shape[1]
    x2dmask_new = x2dmask[:, bounds[0]:x2dmask.shape[1], bounds[1]:bounds[3]]

    inputs = [x1d_new, x2d_new, x1dmask_new, x2dmask_new]

    ## if embedding is used
    ##if any( k in modelSpecs['seq2matrixMode'] for k in ('SeqOnly', 'Seq+SS') ):
    if config.EmbeddingUsed(modelSpecs):
        embed = onebatch[4]
        #embed_new = embed[:, bounds[1]:bounds[3], : ]
        embed_new = embed
        inputs.append(embed_new)

        remainings = onebatch[5:]
    else:
        remainings = onebatch[4:]

    ## crop the ground truth and weight matrices
    for x2d0 in remainings:
        if len(x2d0.shape) == 3:
            inputs.append(x2d0[:, bounds[0]:bounds[2], bounds[1]:bounds[3]])
        else:
            inputs.append(x2d0[:, bounds[0]:bounds[2], bounds[1]:bounds[3], :])

    ## add the bounding box to the input list
    inputs.append(bounds)

    if config.TrainByRefLoss(modelSpecs):
        if forRefState:
            inputs.append(np.int32(-1))
        else:
            inputs.append(np.int32(1))

    train_loss, train_errors, param_L2 = train(*inputs)

    return train_loss, train_errors, param_L2
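
## A minimal sketch (an assumption, not this project's implementation) of what
## SampleBoundingBox is expected to produce: a square crop encoded as
## (top, left, bottom, right), sampled from an nRows x nCols matrix so that the
## cropped area stays within maxbatchSize. TrainByOneBatch and PrepareInput4Train
## rely only on these properties of the returned box.
def SampleBoundingBoxSketch(shape, maxbatchSize):
    nRows, nCols = shape
    ## the largest square whose area does not exceed maxbatchSize
    boxsize = min(min(nRows, nCols), int(math.floor(math.sqrt(maxbatchSize))))
    top = np.random.randint(0, nRows - boxsize + 1)
    left = np.random.randint(0, nCols - boxsize + 1)
    return np.array([top, left, top + boxsize, left + boxsize], dtype=np.int32)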

def PrepareInput4Train(data, modelSpecs, floatType=np.float32, forRefState=False, UseSharedMemory=False):
    """Assemble one minibatch from pre-loaded data, sampling a bounding box for
    any protein too large to fit into one batch, and append the box (and the
    RefState flag, if used) to the assembled batch.
    """
    if not bool(data):
        print 'ERROR: the input data for PrepareInput4Train is empty'
        exit(1)

    ## the maximum sequence length that fits into one batch without cropping
    allowedLen = int(math.floor(math.sqrt(modelSpecs['maxbatchSize'])))
    bounds = []
    for d in data:
        if d['seqLen'] < allowedLen:
            bounds.append(None)
            continue
        box = SampleBoundingBox((d['seqLen'], d['seqLen']), modelSpecs['maxbatchSize'])
        bounds.append(box)

    #print allowedLen
    #print bounds
    onebatch, _ = DataProcessor.AssembleOneBatch(data, modelSpecs, forRefState=forRefState, bounds=bounds, floatType=floatType, bUseSharedMemory=UseSharedMemory)

    ## determine the bounding box
    maxSeqLen = max([d['seqLen'] for d in data])
    #print maxSeqLen

    if maxSeqLen > allowedLen and len(data) > 1:
        print 'ERROR: a minibatch with a large protein may contain only that protein: ', [d['name'] for d in data]
        exit(1)

    if maxSeqLen <= allowedLen:
        box = np.array([0, 0, maxSeqLen, maxSeqLen]).astype(np.int32)
    else:
        ## in this case, len(data) == 1 and len(bounds) == 1
        assert bounds[0] is not None
        box = np.array(bounds[0]).astype(np.int32)

    onebatch.append(box)

    if config.TrainByRefLoss(modelSpecs):
        if forRefState:
            onebatch.append(np.int32(-1))
        else:
            onebatch.append(np.int32(1))
    return onebatch

def AddLabel2OneBatch(names,
                      batch,
                      modelSpecs,
                      sharedLabelPool,
                      sharedLabelWeightPool,
                      floatType=theano.config.floatX):
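    """Append the label and label-weight tensors for the named proteins to an
    assembled input batch. Labels are cropped by the batch's bounding box and
    all matrices are aligned at the bottom/right corner of the output tensors.
    Returns the extended batch with labels and weights inserted before the
    bounding box (and the RefState flag, if present).
    """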

    numSeqs = len(names)
    for name in names:
        if (name not in sharedLabelPool) or (name not in sharedLabelWeightPool):
            print 'ERROR: the label or label weight matrix does not exist for protein ', name
            exit(1)

    seqLens = [sharedLabelWeightPool[name].shape[0] for name in names]

    ## get the bounding box for this batch
    if not config.TrainByRefLoss(modelSpecs):
        box = batch[-1]
    else:
        box = batch[-2]

    top, left, bottom, right = box
    assert bottom - top == right - left
    boxsize = bottom - top

    if boxsize < max(seqLens) and numSeqs > 1:
        ## make sure that there is only one protein in this batch
        print 'ERROR: when one batch has a large protein, it can contain only that protein'
        exit(1)

    ## we crop pairwise labels at this step to save memory and computational time
    maxMatrixSize = min(boxsize, max(seqLens))

    ## Y shall be a list of 3D or 4D tensors (including the batch dimension), one for each response
    Y = []
    for response in modelSpecs['responses']:
        labelName, labelType, _ = ParseResponse(response)
        dataType = np.int16
        if not config.IsDiscreteLabel(labelType):
            dataType = floatType
        rValDims = GetResponseValueDims(response)
        if rValDims == 1:
            y = np.zeros(shape=(numSeqs, maxMatrixSize, maxMatrixSize),
                         dtype=dataType)
        else:
            y = np.zeros(shape=(numSeqs, maxMatrixSize, maxMatrixSize, rValDims),
                         dtype=dataType)
        Y.append(y)

    ## when Y is empty, the weight matrices are useless, so weightMatrix shall also be empty
    weightMatrix = []
    if bool(Y) and config.UseSampleWeight(modelSpecs):
        ## create one separate array per response; list multiplication would make
        ## every response alias the same underlying array
        weightMatrix = [
            np.zeros(shape=(numSeqs, maxMatrixSize, maxMatrixSize),
                     dtype=floatType) for _ in modelSpecs['responses']
        ]

    for j, (name, seqLen) in enumerate(zip(names, seqLens)):

        ## we align all matrices in the bottom/right corner
        ## posInX and posInY are the starting position of one protein in the final output tensor
        ## here X and Y refer to x-axis and y-axis
        posInX = -min(boxsize, seqLen)
        posInY = -min(boxsize, seqLen)

        for y, response in zip(Y, modelSpecs['responses']):

            if boxsize < seqLen:
                tmp = sharedLabelPool[name][response][top:bottom, left:right]
            else:
                tmp = sharedLabelPool[name][response]
            if len(y.shape) == 3:
                y[j, posInX:, posInY:] = tmp
            else:
                y[j, posInX:, posInY:, :] = tmp

        labelWeightMatrix = sharedLabelWeightPool[name]
        for w, response in zip(weightMatrix, modelSpecs['responses']):
            if boxsize < seqLen:
                w[j, posInX:, posInY:] = labelWeightMatrix[response][top:bottom, left:right]
            else:
                w[j, posInX:, posInY:] = labelWeightMatrix[response]

    ## the input batch contains bounding box
    tail = 1

    ## check to see if the input batch contains one flag for RefState
    if config.TrainByRefLoss(modelSpecs):
        tail += 1

    newbatch = batch[:-tail]
    newbatch.extend(Y)
    newbatch.extend(weightMatrix)
    newbatch.extend(batch[-tail:])

    return newbatch
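
## After AddLabel2OneBatch, a complete training batch is laid out as:
##   [x1d, x2d, x1dmask, x2dmask, (embedding,) label tensors..., weight tensors...,
##    bounding box, (RefState flag)]
## i.e., the same order in which TrainByOneBatch assembles its inputs.
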
def main(argv):
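    """Entry point: parse the command line, compute label distribution and
    feature statistics, start the train and validation data loader processes,
    train the model, and pickle the resulting model.
    """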

    modelSpecs = InitializeModelSpecs()
    modelSpecs = ParseCommandLine.ParseArguments(argv, modelSpecs)

    startTime = datetime.datetime.now()

    trainMetaData = DataProcessor.LoadMetaData(modelSpecs['trainFile'])
    FeatureUtils.DetermineFeatureDimensionBySampling(trainMetaData, modelSpecs)
    ## calculate label distribution and weight at the very beginning
    print 'Calculating label distribution...'
    LabelUtils.CalcLabelDistributionNWeightBySampling(trainMetaData,
                                                      modelSpecs)

    if config.TrainByRefLoss(modelSpecs) or config.UseRefState(modelSpecs):
        print 'Calculating feature expectation by sampling...'
        FeatureUtils.CalcFeatureExpectBySampling(trainMetaData, modelSpecs)

    ## trainMetaData is a list of groups; each group contains a set of related proteins
    ## (seq-template alignments) and the files for their features
    trainDataLocation = DataProcessor.SampleProteinInfo(trainMetaData)
    trainSeqData = DataProcessor.SplitData2Batches(
        trainDataLocation,
        numDataPoints=modelSpecs['minibatchSize'],
        modelSpecs=modelSpecs)
    print 'approximate #batches for train data: ', len(trainSeqData)

    #global trainSharedQ, stopTrainDataLoader, trainDataLoaders, trainSharedLabelPool, trainSharedLabelWeightPool
    global trainSharedQ, stopTrainDataLoader, trainDataLoaders
    trainSharedQ = multiprocessing.Queue(config.QSize(modelSpecs))
    stopTrainDataLoader = multiprocessing.Event()
    #trainSharedLabelPool = multiprocessing.Manager().dict()
    #trainSharedLabelWeightPool = multiprocessing.Manager().dict()
    #print stopTrainDataLoader

    numTrainDataLoaders = config.NumTrainDataLoaders(modelSpecs)
    metaDatas = DataProcessor.SplitMetaData(trainMetaData, numTrainDataLoaders)

    trainDataLoaders = []
    for i, metaData in enumerate(metaDatas):
        #trainDataLoader = multiprocessing.Process(name='TrainDataLoader ' + str(i) + ' for ' + str(os.getpid()), target=TrainUtils.TrainDataLoader, args=(trainSharedQ, metaData, modelSpecs, True, True))
        trainDataLoader = multiprocessing.Process(
            name='TrainDataLoader ' + str(i) + ' for ' + str(os.getpid()),
            target=TrainUtils.TrainDataLoader2,
            args=(trainSharedQ, stopTrainDataLoader, metaData, modelSpecs,
                  True, True))
        #trainDataLoader = multiprocessing.Process(name='TrainDataLoader ' + str(i) + ' for ' + str(os.getpid()), target=TrainUtils.TrainDataLoader3, args=(trainSharedQ, trainSharedLabelPool, trainSharedLabelWeightPool, stopTrainDataLoader, metaData, modelSpecs, True, True))
        trainDataLoader.daemon = True
        trainDataLoaders.append(trainDataLoader)

    print 'start the train data loaders...'
    for trainDataLoader in trainDataLoaders:
        trainDataLoader.start()

    validMetaData = DataProcessor.LoadMetaData(modelSpecs['validFile'])
    validDataLocation = DataProcessor.SampleProteinInfo(validMetaData)

    ## split data into batches, but do not load the real data from disk
    #validSeqData = DataProcessor.SplitData2Batches(validDataLocation, numDataPoints=modelSpecs['minibatchSize'], modelSpecs=modelSpecs)
    validSeqData = DataProcessor.SplitData2Batches(validDataLocation,
                                                   numDataPoints=500 * 500,
                                                   modelSpecs=modelSpecs)
    print '#batches for validation data: ', len(validSeqData)

    global validSharedQ, validDataLoader, stopValidDataLoader
    validSharedQ = multiprocessing.Queue(len(validSeqData))
    stopValidDataLoader = multiprocessing.Event()
    #print stopValidDataLoader
    ## shared memory is a limited resource, so avoid using it as much as possible
    ## here we do not use shared array for validation data since we only need to load it once
    #validDataLoader = multiprocessing.Process(name='ValidDataLoader for '+str(os.getpid()), target=TrainUtils.ValidDataLoader, args=(validSharedQ, validSeqData, modelSpecs, True, False))
    validDataLoader = multiprocessing.Process(
        name='ValidDataLoader for ' + str(os.getpid()),
        target=TrainUtils.ValidDataLoader2,
        args=(validSharedQ, stopValidDataLoader, validSeqData, modelSpecs,
              True, False))
    print 'start the validation data loader...'
    validDataLoader.start()
    """
	if modelSpecs.has_key('ScaleLoss4Cost') and (modelSpecs['ScaleLoss4Cost'] is True):
		##calculate the average weight per minibatch
		maxDeviation = DataProcessor.CalcAvgWeightPerBatch(trainSeqDataset, modelSpecs)
		print 'maxWeightDeviation=', maxDeviation
	"""

    beforeTrainTime = datetime.datetime.now()
    print 'time spent before training :', beforeTrainTime - startTime

    result = TrainModel(modelSpecs=modelSpecs,
                        trainValidData=(trainSeqData, validSeqData))

    ## merge modelSpecs and result
    resultModel = modelSpecs.copy()
    resultModel.update(result)

    modelFile = TrainUtils.GenerateModelFileName(resultModel)
    print 'Writing the resultant model to ', modelFile
    with open(modelFile, 'wb') as fh:
        cPickle.dump(resultModel, fh, cPickle.HIGHEST_PROTOCOL)

    afterTrainTime = datetime.datetime.now()
    print 'time spent on training:', afterTrainTime - beforeTrainTime

    ## clean up again
    print 'Cleaning up again...'
    Cleanup()
def BuildModel(modelSpecs, forTrain=True):
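    """Define the symbolic input variables and build the ResNet4DistMatrix
    predictor. For training, also create the label and weight tensor variables,
    the bounding box and (optionally) the trainByRefLoss flag; for prediction,
    labelList and weightList are empty and therefore not returned.
    """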
    rng = np.random.RandomState()

    ## x is for sequential features and y for matrix (or pairwise) features
    x = T.tensor3('x')
    y = T.tensor4('y')

    ## mask for x and y, respectively
    xmask = T.bmatrix('xmask')
    ymask = T.btensor3('ymask')

    xem = None
    ##if any( k in modelSpecs['seq2matrixMode'] for k in ('SeqOnly', 'Seq+SS') ):
    if config.EmbeddingUsed(modelSpecs):
        xem = T.tensor3('xem')

    ## bounding box for cropping a large protein distance matrix; the crop may start at any position
    box = None
    if forTrain:
        box = T.ivector('boundingbox')

    ## trainByRefLoss can be either 1 or -1; when this variable exists, the model is trained
    ## using both the reference loss and the loss on real data
    trainByRefLoss = None
    if forTrain and config.TrainByRefLoss(modelSpecs):
        trainByRefLoss = T.iscalar('trainByRefLoss')

    distancePredictor = ResNet4DistMatrix(rng,
                                          seqInput=x,
                                          matrixInput=y,
                                          mask_seq=xmask,
                                          mask_matrix=ymask,
                                          embedInput=xem,
                                          boundingbox=box,
                                          modelSpecs=modelSpecs)

    ## labelList is a list of label tensors, each having shape (batchSize, seqLen, seqLen) or (batchSize, seqLen, seqLen, valueDims[response] )
    labelList = []
    if forTrain:
        ## when this model is used for training. We need to define the label variable
        for response in modelSpecs['responses']:
            labelType = Response2LabelType(response)
            rValDims = GetResponseValueDims(response)

            if labelType.startswith('Discrete'):
                if rValDims > 1:
                    ## if one response is a vector, then we use a 4-d tensor
                    ## wtensor is for 16bit integer
                    labelList.append(T.wtensor4('Tlabel4' + response))
                else:
                    labelList.append(T.wtensor3('Tlabel4' + response))
            else:
                if rValDims > 1:
                    labelList.append(T.tensor4('Tlabel4' + response))
                else:
                    labelList.append(T.tensor3('Tlabel4' + response))

    ## weightList is a list of label weight tensors, each having shape (batchSize, seqLen, seqLen)
    weightList = []
    if len(labelList) > 0 and config.UseSampleWeight(modelSpecs):
        weightList = [
            T.tensor3('Tweight4' + response)
            for response in modelSpecs['responses']
        ]

    ## for prediction, both labelList and weightList are empty
    if forTrain:
        return distancePredictor, x, y, xmask, ymask, xem, labelList, weightList, box, trainByRefLoss
    else:
        return distancePredictor, x, y, xmask, ymask, xem
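
## A minimal sketch (hypothetical, not this project's TrainModel) of how the symbolic
## variables returned by BuildModel could be compiled into the train callable consumed
## by TrainByOneBatch. MakeOutputs and MakeUpdates are assumed callbacks: given the
## predictor and the label/weight variables, they return the symbolic
## (loss, errors, paramL2) expressions and the optimizer updates, which the real
## code derives from distancePredictor.
def CompileTrainFunctionSketch(modelSpecs, MakeOutputs, MakeUpdates):
    variables = BuildModel(modelSpecs, forTrain=True)
    distancePredictor, x, y, xmask, ymask, xem, labelList, weightList, box, trainByRefLoss = variables

    ## assemble the input variables in the same order used by TrainByOneBatch
    inputVariables = [x, y, xmask, ymask]
    if xem is not None:
        inputVariables.append(xem)
    inputVariables.extend(labelList)
    inputVariables.extend(weightList)
    inputVariables.append(box)
    if trainByRefLoss is not None:
        inputVariables.append(trainByRefLoss)

    loss, errors, paramL2 = MakeOutputs(distancePredictor, labelList, weightList)
    updates = MakeUpdates(loss)
    return theano.function(inputVariables, [loss, errors, paramL2], updates=updates, on_unused_input='warn')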