def CNNnet(tr_data, tr_label, te_data, te_label, sam_num, n_way, com_label_str, com_img_str, str1 = 'none', iter_num=5000):
    """Train the ``cf.Net`` classifier with early stopping and save its curves.

    A fresh default graph is built, then training batches are drawn with
    ``spb.get_test_data_depth_concat`` from ``tr_data``/``tr_label`` for up to
    ``iter_num`` steps. After step 100, training stops as soon as more than
    two values exist in the curve returned by ``spb.get_global_loss`` and its
    latest value is not below the previous one. The test set is evaluated on
    every 50th step and on the stopping step. Finally
    ``[train_log, global_curve, test_log]`` is passed to ``spb.save_result``
    together with ``str1``.

    Rows of ``train_log`` are ``[step, loss, acc, batch_size]``; rows of
    ``test_log`` are ``[step, loss, acc, acc]`` (accuracy appears twice).

    Returns:
        None.
    """
    # NOTE(review): the caller's ``com_label_str`` is overwritten here, so the
    # argument has no effect; confirm whether hard-coding 'cat' is intended.
    com_label_str = 'cat'

    tf.reset_default_graph()
    (x, y_, is_train, y_conv, loss, accuracy,
     train_step) = cf.Net(sam_num, n_way, com_label_str, com_img_str)
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        train_log = []
        test_log = []
        global_curve = []
        stop = False
        train_index = spb.com_train_data(tr_label, 1, com_label_str)

        for step in range(iter_num):
            batch_in, batch_out = spb.get_test_data_depth_concat(
                sam_num, n_way, com_label_str, com_img_str, tr_data, tr_label)
            fetched = sess.run(
                [train_step, loss, accuracy, y_conv],
                feed_dict={x: batch_in, y_: batch_out, is_train: True})
            tr_loss, tr_acc = fetched[1], fetched[2]
            train_log.append(
                np.asarray([step, tr_loss, tr_acc, batch_out.shape[0]]))

            # Curve recomputed from the full training log on every step.
            global_curve = spb.get_global_loss(train_log, train_index)

            # Stop once (after step 100) the latest curve value is not below
            # the previous one.
            stop = bool(step > 100 and len(global_curve) > 2
                        and global_curve[-1] >= global_curve[-2])

            if stop or step % 50 == 0:
                batch_in, batch_out = spb.get_test_data_depth_concat(
                    sam_num, n_way, com_label_str, com_img_str, te_data, te_label)
                te_loss, te_acc, _ = sess.run(
                    [loss, accuracy, y_conv],
                    feed_dict={x: batch_in, y_: batch_out, is_train: False})
                # Accuracy is recorded twice to keep the four-column layout.
                test_log.append(np.asarray([step, te_loss, te_acc, te_acc]))
                print(step, tr_loss, tr_acc, '||', te_acc, te_acc)

            if stop:
                break

        spb.save_result(
            [np.stack(train_log, 0), global_curve, np.stack(test_log, 0)],
            str1)

    return None
    
def run_model(w_num, h_num, sam_size, tim_num, tr_bat, te_bat, iter_num,
              test_once, da_str, GPU_str, aug_str):
    """Train a UNET on the configured dataset, testing it periodically.

    The output folder name is built by ``spb.com_mul_str`` from the dataset
    settings. The test images/masks are saved there first. Each training step
    is logged as ``[step, da.tr_ind, mean_loss, batch_size]``. Every
    ``test_once`` steps, and on the step where ``spb.ch_glob`` returns a
    truthy value, the model is evaluated with ``Unet.pred`` and the logs plus
    a model checkpoint are saved. A truthy ``spb.ch_glob`` result also ends
    training.

    Args:
        w_num, h_num, sam_size, tim_num, tr_bat, te_bat, da_str, aug_str:
            forwarded to ``spda.dataset`` (and partly to the folder name).
            ``tr_bat`` is also the mini-batch size used by ``Unet.pred``.
        iter_num: maximum number of training steps.
        test_once: evaluate/save every ``test_once`` steps; must be non-zero.
        GPU_str: forwarded to ``model.UNET``.

    Raises:
        ValueError: if ``test_once`` is 0 (it is used as a modulus).
    """
    # Fail before the expensive dataset/model setup rather than with a
    # ZeroDivisionError after the first training step.
    if test_once == 0:
        raise ValueError('test_once must be non-zero')

    # saving folder
    sa_str = spb.com_mul_str(
        [da_str, sam_size, tim_num, aug_str, w_num, h_num])
    # construct dataset
    da = spda.dataset(w_num, h_num, sam_size, tim_num, tr_bat, te_bat, da_str,
                      aug_str)

    # save test images and masks
    spb.save_result(da.te_i, sa_str, 'te_img.npy')
    spb.save_result(da.te_m, sa_str, 'te_mask.npy')

    # construct UNET
    Unet = model.UNET(da.i_w, da.i_h, da.i_ch, da.m_w, da.m_h, da.o_ch, 0.0001,
                      sa_str, 'yes', GPU_str)
    Unet.build_framework()

    # per-step training rows and per-evaluation test rows
    tr_res = []
    te_res = []

    def save_progress():
        """Write the training/test logs and a model checkpoint."""
        spb.save_result(np.stack(tr_res, 0), sa_str, 'tr_res.npy')
        spb.save_list(te_res, sa_str, 'te_res.npy')
        Unet.save()

    # True while training has advanced past the last save_progress() call.
    unsaved = False
    for i in range(iter_num):
        # train
        tr_img, tr_mask = da.get_tr_bat_img()
        tr_loss, _ = Unet.train(tr_img, tr_mask)
        tr_res.append(
            np.array([i, da.tr_ind,
                      np.mean(tr_loss), tr_img.shape[0]]))
        unsaved = True

        print('train', i, np.mean(tr_loss), time.ctime())

        # a truthy result forces a final evaluation and ends training
        flag = spb.ch_glob(tr_res)

        # test
        if (i + 1) % test_once == 0 or flag:
            te_img, te_mask = da.get_test_bat_img()
            te_loss, te_dice = Unet.pred(te_img, te_mask, tr_bat)
            te_res.append([i, np.mean(te_loss), te_dice])

            print('test', i, np.mean(te_loss), te_dice[-1], time.ctime())

            save_progress()
            unsaved = False

        if flag:
            break

    # Save only if the last step was not already saved in the loop. This also
    # avoids np.stack([]) failing when no training step ran (iter_num <= 0).
    if unsaved:
        save_progress()
    def pred(self, input_data, output_data, b_num):
        """Run inference over the whole test set in mini-batches and score it.

        For every sample, the image, ground-truth mask and network output are
        written under ``self.sa_str + '/te_pred/'`` as ``<k>_img.npy``,
        ``<k>_mask.npy`` and ``<k>_pred.npy``. Here ``k`` is the sample's row
        index in ``input_data``. The per-batch losses are saved to
        ``te_loss.npy`` in ``self.sa_str``. Predictions are then read back
        via ``spb.read_pred(self.sa_str)`` (presumably the files just written;
        confirm), converted with ``spb.pred_to_mask`` and scored with
        ``spb.mean_dice``.

        Args:
            input_data: test images, indexed as ``[n, :, :, :]``.
            output_data: matching masks, indexed the same way.
            b_num: mini-batch size; must be positive.

        Returns:
            ``(losses, dice)``: the per-batch losses joined with
            ``np.concatenate``, and the value from ``spb.mean_dice``.

        Raises:
            ValueError: if ``b_num`` is not positive or ``input_data`` is empty.
        """
        n_samples = input_data.shape[0]
        if b_num <= 0:
            raise ValueError('b_num must be a positive batch size, got %r' % (b_num,))
        if n_samples == 0:
            raise ValueError('input_data contains no samples')

        pred_dir = self.sa_str + '/te_pred/'
        los = []
        sav_ind = 0
        # Walk the test set exactly once, in order. A fixed cap of 1000
        # batches would silently drop samples from larger test sets.
        for st_num in range(0, n_samples, b_num):
            bat_img = input_data[st_num:st_num + b_num, :, :, :]
            bat_mask = output_data[st_num:st_num + b_num, :, :, :]
            te_loss, te_out = self.sess.run(
                [self.loss, self.output],
                feed_dict={
                    self.input_data: bat_img,
                    self.output_mask: bat_mask,
                    self.is_train: False,
                    self.lr: self.learning_rate
                })
            los.append([te_loss])

            # save the result, one file triple per sample
            for j in range(bat_img.shape[0]):
                spb.save_result(bat_img[j], pred_dir,
                                str(sav_ind) + '_img.npy')
                spb.save_result(bat_mask[j], pred_dir,
                                str(sav_ind) + '_mask.npy')
                spb.save_result(te_out[j], pred_dir,
                                str(sav_ind) + '_pred.npy')
                sav_ind = sav_ind + 1

        spb.save_result(np.array(los), self.sa_str, 'te_loss.npy')

        # calculate DICE
        img_arr, mask_arr, pred_arr = spb.read_pred(self.sa_str)
        pred_mask = spb.pred_to_mask(pred_arr)
        Dice = spb.mean_dice(mask_arr, pred_mask)

        return np.concatenate(los), Dice