Exemple #1
0
def train_dssm_with_minibatch(bin_file_train_1, bin_file_train_2, dssm_file_1_simple, dssm_file_2_simple, outputdir, ntrial, shift, max_iteration):
    """Train a DSSM on paired minibatches and save both sub-models after every pass.

    Minibatch ``i`` of the first stream (the "Q" side) is always fed together
    with minibatch ``i`` of the second stream (the "D" side). The number of
    batches per pass is taken from the first stream only, and a trailing batch
    whose ``SegSize`` differs from ``BatchSize`` is skipped.

    Args:
        bin_file_train_1: Passed to ``nc.InputStream`` to open the Q-side stream.
        bin_file_train_2: Passed to ``nc.InputStream`` to open the D-side stream.
        dssm_file_1_simple: Passed to ``load_simpledssmmodel``. The loaded model
            supplies the layer description that is written back with every
            snapshot; its parameters are reset with ``reset_params_by_random(0)``.
        dssm_file_2_simple: Same as above for the D side, reset with
            ``reset_params_by_random(1)``.
        outputdir: Directory that receives ``yw_dssm_Q_<n>`` and
            ``yw_dssm_D_<n>`` (``n == 0`` is the state before training).
        ntrial: Forwarded to ``basic_utilities.generate_index`` and ``DSSM``.
        shift: Forwarded to ``basic_utilities.generate_index`` and ``DSSM``.
        max_iteration: Passes are numbered from 1 and run while the pass number
            is ``<= max_iteration``.

    Returns:
        None. Progress is printed and the models are written to ``outputdir``.
    """
    # NOTE: print is written in the single-argument parenthesised form, which
    # prints the same text as the Python 2 print statement.

    # --- input streams ---------------------------------------------------
    # NOTE(review): the original author states that InputStream loads the whole
    # file unmodified -- confirm against nc.InputStream.
    stream_q = nc.InputStream(bin_file_train_1)
    stream_d = nc.InputStream(bin_file_train_2)

    # --- network layout and starting parameters --------------------------
    model_q = load_simpledssmmodel(dssm_file_1_simple)
    model_q.reset_params_by_random(0)
    activations_q = [T.tanh] * model_q.mlink_num

    model_d = load_simpledssmmodel(dssm_file_2_simple)
    model_d.reset_params_by_random(1)
    activations_d = [T.tanh] * model_d.mlink_num

    # Snapshot number 0 is the state before any update has been applied.
    initial_path_q = os.path.join(outputdir, "yw_dssm_Q_0")
    initial_path_d = os.path.join(outputdir, "yw_dssm_D_0")
    save_simpledssmmodel(initial_path_q, model_q)
    save_simpledssmmodel(initial_path_d, model_d)

    # --- index structures -------------------------------------------------
    # Built once for the full batch size; the same indexes are passed to every
    # training call, which is why an incomplete last batch cannot be used.
    mbsize = stream_q.BatchSize
    indexes = basic_utilities.generate_index(mbsize, ntrial, shift)

    # --- model and compiled training step ---------------------------------
    dssm = DSSM(model_q.params, activations_q, model_d.params, activations_d,
                mbsize, ntrial, shift)

    index_q = T.ivector('dssm_index_Q')
    index_d = T.ivector('dssm_index_D')
    input_q = T.matrix('dssm_input_Q')
    input_d = T.matrix('dssm_input_D')

    learning_rate = 0.1
    momentum = 0.0

    cost = dssm.output_train(index_q, index_d, input_q, input_d)
    step_updates = basic_utilities.gradient_updates_momentum(
        cost, dssm.params, learning_rate, momentum, mbsize)
    # One call returns the cost for the batch and applies the updates.
    train_step = theano.function(
        [index_q, index_d, input_q, input_d],
        cost,
        updates=step_updates,
        mode=functionmode)

    # --- training passes ----------------------------------------------------
    iteration = 1
    while iteration <= max_iteration:
        print("Iteration %d--------------" % iteration)
        print("Each iteration contains %d minibatches" % stream_q.nTotalBatches)

        trainLoss = 0.0

        last_batch_is_full = stream_q.BatchSize == stream_q.minibatches[-1].SegSize
        usefulbatches = stream_q.nTotalBatches if last_batch_is_full else stream_q.nTotalBatches - 1
        print("After removing the last incomplete batch, we need to process %d batches" % usefulbatches)

        # Fresh buffers for this pass; setaminibatch is handed the same two
        # arrays for every batch.
        batch_q = np.zeros((stream_q.BatchSize, model_q.in_num_list[0]), dtype=numpy.float32)
        batch_d = np.zeros((stream_d.BatchSize, model_d.in_num_list[0]), dtype=numpy.float32)

        for batch_no in range(usefulbatches):
            stream_q.setaminibatch(batch_q, batch_no)
            stream_d.setaminibatch(batch_d, batch_no)

            batch_loss = train_step(indexes[0], indexes[1], batch_q, batch_d)
            trainLoss += batch_loss

            # Progress line: batch number, batch loss, running loss of the pass.
            if batch_no % 100 == 0:
                print("%d\t%f\t%f" % (batch_no, batch_loss, trainLoss))

        print("all batches in this iteraton is processed")
        print("trainLoss = %f" % trainLoss)

        # Save the current parameters of each side as a separate model file,
        # reusing the layer description of the corresponding loaded model.
        values_q = [weight.get_value() for weight in dssm.params_Q]
        snapshot_path_q = os.path.join(outputdir, "yw_dssm_Q_%d" % iteration)
        save_simpledssmmodel(
            snapshot_path_q,
            SimpleDSSMModelFormat(model_q.mlayer_num, model_q.layer_info,
                                  model_q.mlink_num, model_q.in_num_list,
                                  model_q.out_num_list, values_q))

        values_d = [weight.get_value() for weight in dssm.params_D]
        snapshot_path_d = os.path.join(outputdir, "yw_dssm_D_%d" % iteration)
        save_simpledssmmodel(
            snapshot_path_d,
            SimpleDSSMModelFormat(model_d.mlayer_num, model_d.layer_info,
                                  model_d.mlink_num, model_d.in_num_list,
                                  model_d.out_num_list, values_d))

        print("Iteration %d-------------- is finished" % iteration)

        iteration += 1

    print("-----The whole train process is finished-------\n")
def train_dssm_with_minibatch(ps):
    """Train a SimpleDSSM_1 model whose two sides each read two input streams.

    Attributes read from ``ps``:
        bin_file_train_src_1, bin_file_train_src_2: Passed to
            ``nc.InputStream`` for the two source-side streams.
        bin_file_train_tgt_1, bin_file_train_tgt_2: Passed to
            ``nc.InputStream`` for the two target-side streams.
        SimpleDSSM_1_NetworkStructure_src, SimpleDSSM_1_NetworkStructure_tgt:
            Exactly three whitespace-separated groups of colon-separated layer
            sizes, e.g. ``"49000:128:128 100000:128:128 256:128"``. Every size
            is parsed with ``int()``, so it must be an integer literal.
        outputdir: Directory that receives ``yw_dssm_Q_<n>`` and
            ``yw_dssm_D_<n>`` (``n == 0`` is the state before training).
        ntrial, shift: Forwarded to ``basic_utilities.generate_index`` and
            ``SimpleDSSM_1``.
        max_iteration: Passes are numbered from 1 and run while the pass
            number is ``<= ps.max_iteration``.

    The number of batches per pass comes from the first source stream only; a
    trailing batch whose ``SegSize`` differs from ``BatchSize`` is skipped.

    Returns:
        None. Progress is printed and the models are written to
        ``ps.outputdir``.
    """
    # NOTE: print is written in the single-argument parenthesised form, which
    # prints the same text as the Python 2 print statement.

    def _zero_weight_groups(structure):
        # Turn "a:b:c d:e:f g:h" into three lists of zero float32 matrices,
        # one (in_size, out_size) matrix per adjacent pair of sizes.
        groups = structure.split()
        assert len(groups) == 3
        weights = [[], [], []]
        for slot, group in enumerate(groups):
            sizes = group.split(":")
            for size_in, size_out in zip(sizes[:-1], sizes[1:]):
                layer = np.zeros((int(size_in), int(size_out)), dtype=np.float32)
                weights[slot].append(layer)
        return weights

    def _current_values(param_groups):
        # Copy the current value of every parameter, keeping the three-list layout.
        values = [[], [], []]
        for slot, group in enumerate(param_groups):
            for weight in group:
                values[slot].append(weight.get_value())
        return values

    # 1. Input streams: two per side.
    stream_src_1 = nc.InputStream(ps.bin_file_train_src_1)
    stream_src_2 = nc.InputStream(ps.bin_file_train_src_2)
    stream_tgt_1 = nc.InputStream(ps.bin_file_train_tgt_1)
    stream_tgt_2 = nc.InputStream(ps.bin_file_train_tgt_2)

    # 2. Zero-filled weight matrices that only fix the shapes.
    weights_src = _zero_weight_groups(ps.SimpleDSSM_1_NetworkStructure_src)
    weights_tgt = _zero_weight_groups(ps.SimpleDSSM_1_NetworkStructure_tgt)

    # 3. Wrap them in model objects; the parameters are reset with arguments
    #    0 (source) and 1 (target).
    model_src = SimpleDSSM1Model_1_Format(weights_src)
    model_src.reset_params_by_random(0)

    model_tgt = SimpleDSSM1Model_1_Format(weights_tgt)
    model_tgt.reset_params_by_random(1)

    # 4. Snapshot number 0 is the state before any update has been applied.
    initial_path_src = os.path.join(ps.outputdir, "yw_dssm_Q_0")
    initial_path_tgt = os.path.join(ps.outputdir, "yw_dssm_D_0")
    save_simpledssmmodel(initial_path_src, model_src)
    save_simpledssmmodel(initial_path_tgt, model_tgt)

    # 5. Index structures, built once for the full batch size.
    mbsize = stream_src_1.BatchSize
    indexes = basic_utilities.generate_index(mbsize, ps.ntrial, ps.shift)

    # 6. The model itself; each params argument is a list of lists of matrices.
    dssm = SimpleDSSM_1(model_src.params, model_tgt.params, mbsize, ps.ntrial, ps.shift)

    # 7. Symbolic inputs and the compiled training step.
    index_q = T.ivector("dssm_index_Q")
    index_d = T.ivector("dssm_index_D")

    input_q_1 = T.matrix("dssm_input_Q_1")
    input_q_2 = T.matrix("dssm_input_Q_2")
    input_d_1 = T.matrix("dssm_input_D_1")
    input_d_2 = T.matrix("dssm_input_D_2")

    learning_rate = 0.1
    momentum = 0.0

    symbolic_inputs = [index_q, index_d, input_q_1, input_q_2, input_d_1, input_d_2]
    train_output = dssm.output_train_1(
        index_q, index_d, input_q_1, input_q_2, input_d_1, input_d_2
    )
    step_updates = basic_utilities.gradient_updates_momentum(
        train_output, dssm.params, learning_rate, momentum, mbsize
    )
    # One call returns the training output for the batch and applies the updates.
    train_step = theano.function(
        symbolic_inputs, train_output, updates=step_updates, mode=functionmode
    )

    iteration = 1
    while iteration <= ps.max_iteration:
        print("Iteration %d--------------" % iteration)
        print("Each iteration contains %d minibatches" % stream_src_1.nTotalBatches)

        trainLoss = 0.0

        last_batch_is_full = stream_src_1.BatchSize == stream_src_1.minibatches[-1].SegSize
        if last_batch_is_full:
            usefulbatches = stream_src_1.nTotalBatches
        else:
            usefulbatches = stream_src_1.nTotalBatches - 1
        print(
            "After removing the last incomplete batch (if there is one), we need to process %d batches"
            % usefulbatches
        )

        # Fresh buffers for this pass. Width = number of rows of the first
        # weight matrix of the group that consumes the stream.
        batch_src_1 = np.zeros((mbsize, model_src.params[0][0].shape[0]), dtype=numpy.float32)
        batch_src_2 = np.zeros((mbsize, model_src.params[1][0].shape[0]), dtype=numpy.float32)

        batch_tgt_1 = np.zeros((mbsize, model_tgt.params[0][0].shape[0]), dtype=numpy.float32)
        batch_tgt_2 = np.zeros((mbsize, model_tgt.params[1][0].shape[0]), dtype=numpy.float32)

        for batch_no in range(usefulbatches):
            stream_src_1.setaminibatch(batch_src_1, batch_no)
            stream_src_2.setaminibatch(batch_src_2, batch_no)
            stream_tgt_1.setaminibatch(batch_tgt_1, batch_no)
            stream_tgt_2.setaminibatch(batch_tgt_2, batch_no)

            batch_loss = train_step(
                indexes[0], indexes[1], batch_src_1, batch_src_2, batch_tgt_1, batch_tgt_2
            )
            trainLoss += batch_loss

            # Progress line: batch number, batch loss, running loss of the pass.
            if batch_no % 100 == 0:
                print("%d\t%f\t%f" % (batch_no, batch_loss, trainLoss))

        print("all batches in this iteraton is processed")
        print("trainLoss = %f" % trainLoss)

        # Save the current parameters of each side as a separate model file.
        values_src = _current_values(dssm.params_Q)
        snapshot_path_src = os.path.join(ps.outputdir, "yw_dssm_Q_%d" % iteration)
        save_simpledssmmodel(snapshot_path_src, SimpleDSSM1Model_1_Format(values_src))

        values_tgt = _current_values(dssm.params_D)
        snapshot_path_tgt = os.path.join(ps.outputdir, "yw_dssm_D_%d" % iteration)
        save_simpledssmmodel(snapshot_path_tgt, SimpleDSSM1Model_1_Format(values_tgt))

        print("Iteration %d-------------- is finished" % iteration)

        iteration += 1

    print("-----The whole train process is finished-------\n")
Exemple #3
0
def train_dssm_with_minibatch(ps):
    """Run minibatch training of a SimpleDSSM_1 model and save it after each pass.

    Both the source ("Q") side and the target ("D") side are fed from two
    input streams. Batch ``i`` of all four streams is used in the same
    training call.

    ``ps`` must provide:
        * ``bin_file_train_src_1``, ``bin_file_train_src_2``,
          ``bin_file_train_tgt_1``, ``bin_file_train_tgt_2`` -- arguments for
          ``nc.InputStream``.
        * ``SimpleDSSM_1_NetworkStructure_src`` and
          ``SimpleDSSM_1_NetworkStructure_tgt`` -- three whitespace-separated
          groups of colon-separated layer sizes such as
          ``"49000:128:128 100000:128:128 256:128"``; each size goes through
          ``int()`` and therefore has to be an integer literal.
        * ``outputdir`` -- where ``yw_dssm_Q_<n>`` / ``yw_dssm_D_<n>`` are
          written (``n == 0`` is saved before training starts).
        * ``ntrial``, ``shift`` -- forwarded to
          ``basic_utilities.generate_index`` and ``SimpleDSSM_1``.
        * ``max_iteration`` -- passes run while the 1-based pass number is
          ``<= ps.max_iteration``.

    Only the first source stream decides how many batches a pass has, and a
    last batch whose ``SegSize`` is not ``BatchSize`` is left out.

    Returns nothing; output is the printed progress and the saved model files.
    """
    # NOTE: print is written in the single-argument parenthesised form, which
    # prints the same text as the Python 2 print statement.

    # 1. Open the four input streams.
    src_stream_a = nc.InputStream(ps.bin_file_train_src_1)
    src_stream_b = nc.InputStream(ps.bin_file_train_src_2)

    tgt_stream_a = nc.InputStream(ps.bin_file_train_tgt_1)
    tgt_stream_b = nc.InputStream(ps.bin_file_train_tgt_2)

    # 2. Build zero matrices of the requested shapes, source side first.
    #    Adjacent sizes in a group give one (in_size, out_size) matrix.
    src_groups = ps.SimpleDSSM_1_NetworkStructure_src.split()
    assert(len(src_groups) == 3)
    src_weights = [[], [], []]
    for group_no, group_text in enumerate(src_groups):
        dims = group_text.split(':')
        for rows, cols in zip(dims[:-1], dims[1:]):
            matrix = np.zeros((int(rows), int(cols)), dtype=np.float32)
            src_weights[group_no].append(matrix)

    tgt_groups = ps.SimpleDSSM_1_NetworkStructure_tgt.split()
    assert(len(tgt_groups) == 3)
    tgt_weights = [[], [], []]
    for group_no, group_text in enumerate(tgt_groups):
        dims = group_text.split(':')
        for rows, cols in zip(dims[:-1], dims[1:]):
            matrix = np.zeros((int(rows), int(cols)), dtype=np.float32)
            tgt_weights[group_no].append(matrix)

    # 3. Model wrappers; params is a list of three weight lists.
    src_model = SimpleDSSM1Model_1_Format(src_weights)
    src_model.reset_params_by_random(0)

    tgt_model = SimpleDSSM1Model_1_Format(tgt_weights)
    tgt_model.reset_params_by_random(1)

    # 4. Save the untrained models as snapshot 0.
    src_out = os.path.join(ps.outputdir, "yw_dssm_Q_0")
    tgt_out = os.path.join(ps.outputdir, "yw_dssm_D_0")
    save_simpledssmmodel(src_out, src_model)
    save_simpledssmmodel(tgt_out, tgt_model)

    # 5. Index structures for a full-size minibatch; reused for every batch.
    mbsize = src_stream_a.BatchSize
    indexes = basic_utilities.generate_index(mbsize, ps.ntrial, ps.shift)

    # 6. Instantiate the model.
    dssm = SimpleDSSM_1(src_model.params, tgt_model.params, mbsize, ps.ntrial, ps.shift)

    # 7. Symbolic inputs.
    sym_index_q = T.ivector('dssm_index_Q')
    sym_index_d = T.ivector('dssm_index_D')

    sym_q_a = T.matrix('dssm_input_Q_1')
    sym_q_b = T.matrix('dssm_input_Q_2')
    sym_d_a = T.matrix('dssm_input_D_1')
    sym_d_b = T.matrix('dssm_input_D_2')

    learning_rate = 0.1
    momentum = 0.0

    train_output = dssm.output_train_1(sym_index_q, sym_index_d, sym_q_a, sym_q_b, sym_d_a, sym_d_b)
    # Compiled step: returns train_output and applies the gradient updates.
    run_batch = theano.function(
        [sym_index_q, sym_index_d, sym_q_a, sym_q_b, sym_d_a, sym_d_b],
        train_output,
        updates=basic_utilities.gradient_updates_momentum(
            train_output, dssm.params, learning_rate, momentum, mbsize),
        mode=functionmode)

    iteration = 1
    while iteration <= ps.max_iteration:
        print("Iteration %d--------------" % iteration)
        print("Each iteration contains %d minibatches" % src_stream_a.nTotalBatches)

        trainLoss = 0.0

        # Drop the final batch when it is not a full one.
        if src_stream_a.BatchSize != src_stream_a.minibatches[-1].SegSize:
            usefulbatches = src_stream_a.nTotalBatches - 1
        else:
            usefulbatches = src_stream_a.nTotalBatches
        print("After removing the last incomplete batch (if there is one), we need to process %d batches" % usefulbatches)

        # Dense buffers, allocated anew for each pass; the column count is the
        # row count of the first weight matrix of the matching group.
        src_buf_a = np.zeros((mbsize, src_model.params[0][0].shape[0]), dtype=numpy.float32)
        src_buf_b = np.zeros((mbsize, src_model.params[1][0].shape[0]), dtype=numpy.float32)

        tgt_buf_a = np.zeros((mbsize, tgt_model.params[0][0].shape[0]), dtype=numpy.float32)
        tgt_buf_b = np.zeros((mbsize, tgt_model.params[1][0].shape[0]), dtype=numpy.float32)

        for step in range(usefulbatches):
            src_stream_a.setaminibatch(src_buf_a, step)
            src_stream_b.setaminibatch(src_buf_b, step)
            tgt_stream_a.setaminibatch(tgt_buf_a, step)
            tgt_stream_b.setaminibatch(tgt_buf_b, step)

            step_output = run_batch(indexes[0], indexes[1], src_buf_a, src_buf_b, tgt_buf_a, tgt_buf_b)
            trainLoss += step_output

            # Every 100th batch: batch number, its output, running total.
            if step % 100 == 0:
                print("%d\t%f\t%f" % (step, step_output, trainLoss))

        print("all batches in this iteraton is processed")
        print("trainLoss = %f" % trainLoss)

        # Write the current weights of each side to its own file.
        q_values = [[], [], []]
        for group_no, group in enumerate(dssm.params_Q):
            for param in group:
                q_values[group_no].append(param.get_value())

        src_out = os.path.join(ps.outputdir, "yw_dssm_Q_%d" % iteration)
        save_simpledssmmodel(src_out, SimpleDSSM1Model_1_Format(q_values))

        d_values = [[], [], []]
        for group_no, group in enumerate(dssm.params_D):
            for param in group:
                d_values[group_no].append(param.get_value())
        tgt_out = os.path.join(ps.outputdir, "yw_dssm_D_%d" % iteration)
        save_simpledssmmodel(tgt_out, SimpleDSSM1Model_1_Format(d_values))

        print("Iteration %d-------------- is finished" % iteration)

        iteration += 1

    print("-----The whole train process is finished-------\n")
Exemple #4
0
def train_dssm_with_minibatch(ps):
    """Train a SimpleDSSM_0 model on paired minibatches, saving it after every pass.

    Attributes read from ``ps``:
        QFILE: Passed to ``nc.InputStream`` for the source ("Q") stream.
        DFILE: Passed to ``nc.InputStream`` for the target ("D") stream.
        SimpleDSSM_0_NetworkStructure_src, SimpleDSSM_0_NetworkStructure_tgt:
            Colon-separated layer sizes, e.g. ``"49000:128:128"``. Every size
            is parsed with ``int()``, so it must be an integer literal.
        MODELPATH: Directory that receives ``yw_dssm_Q_<n>`` and
            ``yw_dssm_D_<n>`` (``n == 0`` is the state before training).
        NTRIAL, SHIFT: Forwarded to ``basic_utilities.generate_index`` and
            ``SimpleDSSM_0``.
        MAX_ITER: Passes are numbered from 1 and run while the pass number is
            ``<= ps.MAX_ITER``.

    The number of batches per pass comes from the source stream only; a
    trailing batch whose ``SegSize`` differs from ``BatchSize`` is skipped.

    Returns:
        None. Progress is printed and the models are written to
        ``ps.MODELPATH``.
    """
    # NOTE: print is written in the single-argument parenthesised form, which
    # prints the same text as the Python 2 print statement.

    def _zero_layers(structure):
        # "a:b:c" -> [zeros((a, b)), zeros((b, c))] as float32 matrices.
        sizes = structure.split(':')
        layers = []
        for size_in, size_out in zip(sizes[:-1], sizes[1:]):
            layers.append(np.zeros((int(size_in), int(size_out)), dtype=np.float32))
        return layers

    # 1. Input streams.
    # NOTE(review): the original author states that InputStream loads the whole
    # file unmodified -- confirm against nc.InputStream.
    stream_src = nc.InputStream(ps.QFILE)
    stream_tgt = nc.InputStream(ps.DFILE)

    # 2. Zero-filled weight matrices that only fix the shapes.
    weights_src = _zero_layers(ps.SimpleDSSM_0_NetworkStructure_src)
    weights_tgt = _zero_layers(ps.SimpleDSSM_0_NetworkStructure_tgt)

    # 3. Wrap them in model objects; the parameters are reset with arguments
    #    0 (source) and 1 (target).
    model_src = SimpleDSSMModel_0_Format(weights_src)
    model_src.reset_params_by_random(0)

    model_tgt = SimpleDSSMModel_0_Format(weights_tgt)
    model_tgt.reset_params_by_random(1)

    # Snapshot number 0 is the state before any update has been applied.
    initial_path_src = os.path.join(ps.MODELPATH, "yw_dssm_Q_0")
    initial_path_tgt = os.path.join(ps.MODELPATH, "yw_dssm_D_0")
    save_simpledssmmodel(initial_path_src, model_src)
    save_simpledssmmodel(initial_path_tgt, model_tgt)

    # 4. Index structures, built once for the full batch size; this is why an
    #    incomplete last batch is skipped rather than trained on.
    mbsize = stream_src.BatchSize
    indexes = basic_utilities.generate_index(mbsize, ps.NTRIAL, ps.SHIFT)

    # 5. The model and its compiled training step.
    dssm = SimpleDSSM_0(model_src.params, model_tgt.params, mbsize, ps.NTRIAL, ps.SHIFT)

    index_q = T.ivector('dssm_index_Q')
    index_d = T.ivector('dssm_index_D')
    input_q = T.matrix('dssm_input_Q')
    input_d = T.matrix('dssm_input_D')

    learning_rate = 0.1
    momentum = 0.0

    train_output = dssm.output_train(index_q, index_d, input_q, input_d)
    step_updates = basic_utilities.gradient_updates_momentum(
        train_output, dssm.params, learning_rate, momentum, mbsize)
    # One call returns the training output for the batch and applies the updates.
    train_step = theano.function(
        [index_q, index_d, input_q, input_d],
        train_output,
        updates=step_updates,
        mode=functionmode)

    iteration = 1
    while iteration <= ps.MAX_ITER:
        print("Iteration %d--------------" % iteration)
        print("Each iteration contains %d minibatches" % stream_src.nTotalBatches)

        trainLoss = 0.0

        last_batch_is_full = stream_src.BatchSize == stream_src.minibatches[-1].SegSize
        usefulbatches = stream_src.nTotalBatches if last_batch_is_full else stream_src.nTotalBatches - 1
        print("After removing the last incomplete batch, we need to process %d batches" % usefulbatches)

        # Fresh buffers for this pass; width = rows of the first weight matrix.
        batch_src = np.zeros((stream_src.BatchSize, model_src.params[0].shape[0]),
                             dtype=numpy.float32)
        batch_tgt = np.zeros((stream_tgt.BatchSize, model_tgt.params[0].shape[0]),
                             dtype=numpy.float32)

        for batch_no in range(usefulbatches):
            stream_src.setaminibatch(batch_src, batch_no)
            stream_tgt.setaminibatch(batch_tgt, batch_no)

            batch_loss = train_step(indexes[0], indexes[1], batch_src, batch_tgt)
            trainLoss += batch_loss

            # Progress line: batch number, batch loss, running loss of the pass.
            if batch_no % 100 == 0:
                print("%d\t%f\t%f" % (batch_no, batch_loss, trainLoss))

        print("all batches in this iteraton is processed")
        print("trainLoss = %f" % trainLoss)

        # Save the current parameters of each side as a separate model file.
        values_src = [weight.get_value() for weight in dssm.params_Q]
        snapshot_path_src = os.path.join(ps.MODELPATH, "yw_dssm_Q_%d" % iteration)
        save_simpledssmmodel(snapshot_path_src, SimpleDSSMModel_0_Format(values_src))

        values_tgt = [weight.get_value() for weight in dssm.params_D]
        snapshot_path_tgt = os.path.join(ps.MODELPATH, "yw_dssm_D_%d" % iteration)
        save_simpledssmmodel(snapshot_path_tgt, SimpleDSSMModel_0_Format(values_tgt))

        print("Iteration %d-------------- is finished" % iteration)

        iteration += 1

    print("-----The whole train process is finished-------\n")
def train_dssm_with_minibatch(ps):
    """Run minibatch training of a SimpleDSSM_0 model and save it after each pass.

    The source ("Q") side reads ``ps.QFILE`` and the target ("D") side reads
    ``ps.DFILE``; batch ``i`` of both streams is used in the same training call.

    ``ps`` must provide:
        * ``QFILE``, ``DFILE`` -- arguments for ``nc.InputStream``.
        * ``SimpleDSSM_0_NetworkStructure_src`` and
          ``SimpleDSSM_0_NetworkStructure_tgt`` -- colon-separated layer sizes
          such as ``"49000:128:128"``; each size goes through ``int()`` and
          therefore has to be an integer literal.
        * ``MODELPATH`` -- where ``yw_dssm_Q_<n>`` / ``yw_dssm_D_<n>`` are
          written (``n == 0`` is saved before training starts).
        * ``NTRIAL``, ``SHIFT`` -- forwarded to
          ``basic_utilities.generate_index`` and ``SimpleDSSM_0``.
        * ``MAX_ITER`` -- passes run while the 1-based pass number is
          ``<= ps.MAX_ITER``.

    Only the source stream decides how many batches a pass has, and a last
    batch whose ``SegSize`` is not ``BatchSize`` is left out.

    Returns nothing; output is the printed progress and the saved model files.
    """
    # NOTE: print is written in the single-argument parenthesised form, which
    # prints the same text as the Python 2 print statement.

    # 1. Open the two input streams.
    # NOTE(review): the original author states that InputStream loads the whole
    # file unmodified -- confirm against nc.InputStream.
    src_stream = nc.InputStream(ps.QFILE)
    tgt_stream = nc.InputStream(ps.DFILE)

    # 2. Build zero matrices of the requested shapes, source side first.
    #    Adjacent sizes give one (in_size, out_size) matrix.
    src_dims = ps.SimpleDSSM_0_NetworkStructure_src.split(':')
    src_weights = []
    for rows, cols in zip(src_dims[:-1], src_dims[1:]):
        matrix = np.zeros((int(rows), int(cols)), dtype=np.float32)
        src_weights.append(matrix)

    tgt_dims = ps.SimpleDSSM_0_NetworkStructure_tgt.split(':')
    tgt_weights = []
    for rows, cols in zip(tgt_dims[:-1], tgt_dims[1:]):
        matrix = np.zeros((int(rows), int(cols)), dtype=np.float32)
        tgt_weights.append(matrix)

    # 3. Model wrappers around the weight lists.
    src_model = SimpleDSSMModel_0_Format(src_weights)
    src_model.reset_params_by_random(0)

    tgt_model = SimpleDSSMModel_0_Format(tgt_weights)
    tgt_model.reset_params_by_random(1)

    # Save the untrained models as snapshot 0.
    src_out = os.path.join(ps.MODELPATH, "yw_dssm_Q_0")
    tgt_out = os.path.join(ps.MODELPATH, "yw_dssm_D_0")
    save_simpledssmmodel(src_out, src_model)
    save_simpledssmmodel(tgt_out, tgt_model)

    # 4. Index structures for a full-size minibatch. They are reused for every
    #    batch, so a shorter last batch is ignored instead of being trained on.
    mbsize = src_stream.BatchSize
    indexes = basic_utilities.generate_index(mbsize, ps.NTRIAL, ps.SHIFT)

    # 5. Instantiate the model.
    dssm = SimpleDSSM_0(src_model.params, tgt_model.params, mbsize, ps.NTRIAL, ps.SHIFT)

    # Symbolic inputs.
    sym_index_q = T.ivector('dssm_index_Q')
    sym_index_d = T.ivector('dssm_index_D')
    sym_input_q = T.matrix('dssm_input_Q')
    sym_input_d = T.matrix('dssm_input_D')

    learning_rate = 0.1
    momentum = 0.0

    train_output = dssm.output_train(sym_index_q, sym_index_d, sym_input_q, sym_input_d)
    # Compiled step: returns train_output and applies the gradient updates.
    run_batch = theano.function(
        [sym_index_q, sym_index_d, sym_input_q, sym_input_d],
        train_output,
        updates=basic_utilities.gradient_updates_momentum(
            train_output, dssm.params, learning_rate, momentum, mbsize),
        mode=functionmode)

    iteration = 1
    while iteration <= ps.MAX_ITER:
        print("Iteration %d--------------" % iteration)
        print("Each iteration contains %d minibatches" % src_stream.nTotalBatches)

        trainLoss = 0.0

        # Drop the final batch when it is not a full one.
        if src_stream.BatchSize != src_stream.minibatches[-1].SegSize:
            usefulbatches = src_stream.nTotalBatches - 1
        else:
            usefulbatches = src_stream.nTotalBatches
        print("After removing the last incomplete batch, we need to process %d batches" % usefulbatches)

        # Dense buffers, allocated anew for each pass; the column count is the
        # row count of the first weight matrix of the matching side.
        src_buf = np.zeros((src_stream.BatchSize, src_model.params[0].shape[0]), dtype=numpy.float32)
        tgt_buf = np.zeros((tgt_stream.BatchSize, tgt_model.params[0].shape[0]), dtype=numpy.float32)

        for step in range(usefulbatches):
            src_stream.setaminibatch(src_buf, step)
            tgt_stream.setaminibatch(tgt_buf, step)

            step_output = run_batch(indexes[0], indexes[1], src_buf, tgt_buf)
            trainLoss += step_output

            # Every 100th batch: batch number, its output, running total.
            if step % 100 == 0:
                print("%d\t%f\t%f" % (step, step_output, trainLoss))

        print("all batches in this iteraton is processed")
        print("trainLoss = %f" % trainLoss)

        # Write the current weights of each side to its own file.
        q_values = [param.get_value() for param in dssm.params_Q]
        src_out = os.path.join(ps.MODELPATH, "yw_dssm_Q_%d" % iteration)
        save_simpledssmmodel(src_out, SimpleDSSMModel_0_Format(q_values))

        d_values = [param.get_value() for param in dssm.params_D]
        tgt_out = os.path.join(ps.MODELPATH, "yw_dssm_D_%d" % iteration)
        save_simpledssmmodel(tgt_out, SimpleDSSMModel_0_Format(d_values))

        print("Iteration %d-------------- is finished" % iteration)

        iteration += 1

    print("-----The whole train process is finished-------\n")