# Example #1
def train_whole_dataset_begin(para_dic, model_name):
    """Train the model named by model_name on the whole dataset.

    Builds a fresh TF graph/session, trains for `epochs` epochs over
    `data_size // file_size` training files per epoch (visited in random
    order), evaluates every `loop_eval_num` loops, lets the user interrupt
    between evaluations, and keeps the `best_model_number` best checkpoints
    ranked by test accuracy.

    Args:
        para_dic: parameter dict decoded by get_para_from_dic().
        model_name: 'resnet_link', 'link_cnn' or 'otto_resnet' -- selects
            the module whose inference() builds the network.

    Returns:
        'Done' on completion or user interrupt, 'error' for an unknown
        model_name.
    """
    # choose model
    if model_name == 'resnet_link':
        # Bug fix: this branch used to import otto_resnet_model, which made
        # 'resnet_link' indistinguishable from 'otto_resnet'.  cross_valid()
        # uses Resnet_link_model for this architecture, so do the same here.
        import Resnet_link_model as rl
        model = rl
    elif model_name == 'link_cnn':
        import Link_CNN_model as lc
        model = lc
    elif model_name == 'otto_resnet':
        import otto_resnet_model as orm
        model = orm
    else:
        print('Error model name')
        return 'error'

    # let the data-provider module know which model is in use
    dpm.model = model

    # get all para first
    [SPAN, dir, epochs, data_size, file_size, loop_eval_num, batch_size,
     train_file_size, valid_file_size, test_file_size, reg, lr_rate,
     lr_decay, keep_prob_v, log_dir, module_dir, eval_last_num, epoch, loop,
     best_model_number, best_model_acc_dic, best_model_dir_dic] = get_para_from_dic(para_dic)

    # give each model its own module/log directory.
    # Bug fix: the module_dir assignment had been swallowed into the comment
    # above it ("...nameodule_dir = ..."), so checkpoints of every model
    # landed in the same base folder while logs were separated.
    module_dir = module_dir + model_name + '/'
    log_dir = log_dir + model_name + '/'

    # create the acc dic and dir dic: seed the "best" accuracies with
    # 0, -1, -2, ... so any real accuracy beats them, and give each slot a
    # placeholder directory name.
    best_model_acc_dic = np.arange(0.0, -best_model_number, -1.0).tolist()
    best_model_dir_dic = []
    for i in range(best_model_number):
        best_model_dir_dic.append('%s' % best_model_acc_dic[i])

    max_step = train_file_size // batch_size
    loops = data_size // file_size
    log = Log()
    create_dir(log_dir)
    create_dir(module_dir)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # inputs
            input_x = tf.placeholder(tf.float32, [None, 96, 96, 1], name='input_x')
            para_pl = tf.placeholder(tf.float32, [None, 41], name='para_pl')
            input_y = tf.placeholder(tf.float32, [None, 9], name='input_y')
            train_phase = tf.placeholder(tf.bool, name='train_phase')
            keep_prob = tf.placeholder(tf.float32, name='keep_prob')

            # logits
            y_pred, parameters = model.inference(input_x, para_pl, train_phase, keep_prob)

            # loss (regularized over `parameters`)
            loss_value = loss(input_y, y_pred, reg, parameters)

            # train
            train_step = tf.train.AdamOptimizer(lr_rate).minimize(loss_value)

            # predict
            correct_num, accuracy = corr_num_acc(input_y, y_pred)

            # placeholders, bundled for the train/eval helpers
            placeholders = (input_x, para_pl, input_y, train_phase, keep_prob)
            train_pl = input_x, para_pl, input_y, train_phase, keep_prob, train_step, loss_value, accuracy

            sess.run(tf.global_variables_initializer())

            while epoch < epochs:

                # show the epoch num
                words_log_print_epoch(epoch, epochs, log)

                # visit the training files in a fresh random order each epoch
                loop_indexs = dpm.get_file_random_seq_indexs(loops)

                # caution: loop is not in sequence (epoch may resume mid-way
                # if `loop` came in non-zero from para_dic)
                while loop < loops:
                    before_time = time.time()

                    train_file = "train_set_%d.csv" % loop_indexs[loop]

                    loop_loss_v, loop_acc = do_train_file(sess, train_pl, dir, train_file, SPAN, max_step, batch_size, keep_prob_v)

                    words_log_print_loop(loop, loops, loop_loss_v, loop_acc, log)

                    loop += 1

                    # each loop_eval_num loops (and at epoch end), do evaluation
                    if loop % loop_eval_num == 0 or loop == loops:
                        # show the time
                        time_show(before_time, loop_eval_num, loop, loops, epoch, epochs, log)
                        # store the parameter first
                        eval_parameters = (loop, loop_indexs, SPAN, sess, batch_size, correct_num, placeholders, log)
                        # here only evaluate last eval_last_num files
                        evaluate_last_x_files(eval_last_num, eval_parameters, dir)

                        # ask if the user wants to interrupt (press i);
                        # interrupt_flow may mutate temp_para, so unpack it
                        # again afterwards to pick up any edits.
                        temp_para = [SPAN, dir, epochs, data_size, file_size, loop_eval_num, batch_size, train_file_size, valid_file_size, test_file_size, reg, lr_rate, lr_decay, keep_prob_v, log_dir, module_dir, eval_last_num, epoch, loop, best_model_number, best_model_acc_dic, best_model_dir_dic]

                        answer = fm.interrupt_flow(temp_para, sess, log, loop_indexs)
                        if answer == 'Done':
                            return 'Done'

                        [SPAN, dir, epochs, data_size, file_size, loop_eval_num, batch_size, train_file_size, valid_file_size, test_file_size, reg, lr_rate, lr_decay, keep_prob_v, log_dir, module_dir, eval_last_num, epoch, loop, best_model_number, best_model_acc_dic, best_model_dir_dic] = temp_para

                # reset loop for the next epoch
                loop = 0
                # each epoch decay the lr_rate.
                # NOTE(review): AdamOptimizer was constructed with the initial
                # lr_rate above; decaying this Python float does not change the
                # optimizer unless something re-reads it -- confirm intent.
                lr_rate *= lr_decay

                # store the parameter first
                test_parameter = loops, epoch, SPAN, sess, batch_size, correct_num, placeholders, log, dir
                # do the test evaluate
                test_acc = evaluate_test(test_parameter)

                temp_best_acc = np.array(best_model_acc_dic)
                # only keep the best_model_number best models
                if test_acc > temp_best_acc.min():
                    small_index = temp_best_acc.argmin()
                    temp_best_acc[small_index] = test_acc
                    module_path = module_dir + "%.4f_epoch%d/" % (test_acc, epoch)
                    # evict the worst stored module before reusing its slot
                    del_dir(best_model_dir_dic[small_index])
                    best_model_dir_dic[small_index] = module_path
                    best_model_acc_dic = temp_best_acc.tolist()
                    # store module for this epoch
                    store_module(module_dir, test_acc, epoch, sess, log, loop_indexs)
                # store log file every epoch
                store_log(log_dir, test_acc, epoch, log)

                epoch += 1
    return 'Done'
def cross_valid(para_dic):
    """Random-search cross validation for the Resnet_link model.

    Draws 7 regularization values and 7 learning rates via
    random_uniform_array (exponent ranges suggest a log scale -- confirm
    against random_uniform_array), then trains one fresh graph/session per
    (reg, lr_rate) pair.  Per-trial logs and checkpoints are written under
    logs/Restnet_link_cross_valid/ and modules/Restnet_link_cross_valid/.

    Returns 'Done' if the user interrupts via fm.interrupt_flow; otherwise
    returns None after all trials finish.
    """

    import Resnet_link_model as rl
    model = rl

    # share the chosen model with dpm (data-provider module, presumably)
    dpm.model = model
    # get all para first
    [
        SPAN, dir, epochs, data_size, file_size, loop_eval_num, batch_size,
        train_file_size, valid_file_size, test_file_size, reg, lr_rate,
        lr_decay, keep_prob_v, log_dir, module_dir, eval_last_num, epoch, loop,
        best_model_number, best_model_acc_dic, best_model_dir_dic
    ] = get_para_from_dic(para_dic)

    # base log dir; both log_dir and module_dir are overwritten with
    # per-trial paths inside the hyperparameter loop below

    log_dir = 'logs/Restnet_link_cross_valid/'

    # create the acc dic and dir dic: seed accuracies with 0, -1, -2, ...
    # so any real accuracy beats them
    best_model_acc_dic = np.arange(0.0, -best_model_number, -1.0).tolist()
    best_model_dir_dic = []
    for i in range(best_model_number):
        best_model_dir_dic.append('%s' % best_model_acc_dic[i])

    max_step = train_file_size // batch_size
    loops = data_size // file_size

    # hypers: 7 x 7 random grid
    regs = random_uniform_array(7, -5, -1)
    lr_rates = random_uniform_array(7, -7, -2)

    count_total = len(regs) * len(lr_rates)
    count = 0

    print 'Begin to cross valid'
    print time.strftime('%Y-%m-%d %H:%M:%S')

    for reg in regs:
        for lr_rate in lr_rates:
            log = Log()

            # per-trial output dirs, tagged with the hypers and trial count
            log_dir = 'logs/Restnet_link_cross_valid/r%.4f_l%.4f_count%d/' % (
                reg, lr_rate, count)
            module_dir = 'modules/Restnet_link_cross_valid/r%.4f_l%.4f_count%d/' % (
                reg, lr_rate, count)
            create_dir(log_dir)
            create_dir(module_dir)

            # show hyper info
            words = '\nhyper\n'
            words += 'reg is %f\n' % reg
            words += 'lr_rate is %f\n' % lr_rate
            words += 'keep_prob_v is %f\n' % keep_prob_v
            words_log_print(words, log)

            # also persist the hypers next to the trial's logs
            filename = log_dir + 'hypers'
            hyper_info = '\nhyper\n'
            hyper_info += 'reg is %f\n' % reg
            hyper_info += 'lr_rate is %f\n' % lr_rate
            hyper_info += 'keep_prob_v is %f\n' % keep_prob_v
            f = file(filename, 'w+')
            f.write(hyper_info)
            f.close()

            with tf.Graph().as_default():
                with tf.Session() as sess:
                    # inputs
                    input_x = tf.placeholder(tf.float32, [None, 304, 48, 2],
                                             name='input_x')
                    para_pl = tf.placeholder(tf.float32, [None, 21],
                                             name='para_pl')
                    input_y = tf.placeholder(tf.float32, [None, 6],
                                             name='input_y')
                    train_phase = tf.placeholder(tf.bool, name='train_phase')
                    keep_prob = tf.placeholder(tf.float32, name='keep_prob')

                    # logits
                    y_pred, parameters = model.inference(
                        input_x, para_pl, train_phase, keep_prob)

                    # loss (regularized over `parameters`)
                    loss_value = loss(input_y, y_pred, reg, parameters)

                    # train
                    train_step = tf.train.AdamOptimizer(lr_rate).minimize(
                        loss_value)

                    # predict
                    correct_num, accuracy = corr_num_acc(input_y, y_pred)

                    # placeholders, bundled for the train/eval helpers
                    placeholders = (input_x, para_pl, input_y, train_phase,
                                    keep_prob)
                    train_pl = input_x, para_pl, input_y, train_phase, keep_prob, train_step, loss_value, accuracy

                    sess.run(tf.global_variables_initializer())

                    while epoch < epochs:

                        words = time.strftime('%Y-%m-%d %H:%M:%S')
                        words_log_print(words, log)

                        # show the epoch num
                        words_log_print_epoch(epoch, epochs, log)

                        # fresh random file order each epoch
                        loop_indexs = dpm.get_file_random_seq_indexs(loops)

                        # caution: loop is not in sequence

                        while loop < loops:
                            before_time = time.time()

                            train_file = "Raw_data_%d_train.csv" % loop_indexs[
                                loop]

                            loop_loss_v, loop_acc = do_train_file(
                                sess, train_pl, dir, train_file, SPAN,
                                max_step, batch_size, keep_prob_v)

                            words_log_print_loop(loop, loops, loop_loss_v,
                                                 loop_acc, log)

                            loop += 1

                            # each loop_eval_num loops (and at epoch end), evaluate
                            if loop % loop_eval_num == 0 or loop == loops:

                                words = time.strftime('%Y-%m-%d %H:%M:%S')
                                words_log_print(words, log)

                                # show the time
                                time_show(before_time, loop_eval_num, loop,
                                          loops, epoch, epochs, log, count,
                                          count_total)
                                # store the parameter first
                                eval_parameters = (loop, loop_indexs, SPAN,
                                                   sess, batch_size,
                                                   correct_num, placeholders,
                                                   log)
                                # here only evaluate last eval_last_num files
                                evaluate_last_x_files(eval_last_num,
                                                      eval_parameters, dir)

                                # ask if the user wants to interrupt (press i);
                                # interrupt_flow may mutate temp_para, so it is
                                # unpacked again below to pick up any edits
                                temp_para = [
                                    SPAN, dir, epochs, data_size, file_size,
                                    loop_eval_num, batch_size, train_file_size,
                                    valid_file_size, test_file_size, reg,
                                    lr_rate, lr_decay, keep_prob_v, log_dir,
                                    module_dir, eval_last_num, epoch, loop,
                                    best_model_number, best_model_acc_dic,
                                    best_model_dir_dic
                                ]

                                answer = fm.interrupt_flow(
                                    temp_para, sess, log, loop_indexs)
                                if answer == 'Done':
                                    return 'Done'

                                [
                                    SPAN, dir, epochs, data_size, file_size,
                                    loop_eval_num, batch_size, train_file_size,
                                    valid_file_size, test_file_size, reg,
                                    lr_rate, lr_decay, keep_prob_v, log_dir,
                                    module_dir, eval_last_num, epoch, loop,
                                    best_model_number, best_model_acc_dic,
                                    best_model_dir_dic
                                ] = temp_para

                        # reset loop for the next epoch
                        loop = 0
                        # each epoch decay the lr_rate.
                        # NOTE(review): the optimizer was built with the
                        # trial's lr_rate; decaying this Python float does not
                        # change the optimizer unless re-read -- confirm intent
                        lr_rate *= lr_decay

                        # store the parameter first
                        test_parameter = loops, epoch, SPAN, sess, batch_size, correct_num, placeholders, log, dir
                        # do the test evaluate
                        test_acc = evaluate_test(test_parameter)

                        # store log file every epoch

                        store_log(log_dir, test_acc, epoch, log)

                        epoch += 1

                    # reset counters so the next (reg, lr_rate) trial starts
                    # from scratch
                    epoch = 0
                    loop = 0

            count += 1