def run_training(log_file):

    # load the data
    print(150 * '*')
    with open("./mnist.pkl", "rb") as fid:
        dataset = pickle.load(fid, encoding='latin1')
    train_x, train_y = dataset[0]
    test_x, test_y = dataset[1]
    train_num = train_x.shape[0]
    test_num = test_x.shape[0]

    # label_index: list of lists
    # e.g. label_index[3] contains the indices of every training sample with label 3
    num_classes = FLAGS.num_classes
    label_index = [[] for i in range(num_classes)]
    for i in range(len(train_y)):
        label_index[train_y[i]].append(i)

    # construct the computation graph
    images = tf.placeholder(tf.float32, shape=[None, 1, 28, 28])
    labels = tf.placeholder(tf.int32, shape=[None])
    lr = tf.placeholder(tf.float32)

    features, _ = mnist_net(images)
    centers = func.construct_center(features, FLAGS.num_classes)
    loss = func.dce_loss(features, labels, centers, FLAGS.temp)
    eval_correct = func.evaluation(features, labels, centers)
    find_wrong = func.find_wrong(features, labels, centers)
    train_op = func.training(loss, lr)

    # counts = tf.get_variable('counts', [FLAGS.num_classes], dtype=tf.int32,
    # initializer=tf.constant_initializer(0), trainable=False)
    # add_op, count_op, average_op = net.init_centers(features, labels, centers, counts)

    init = tf.global_variables_initializer()

    # init the variables
    sess = tf.Session()
    sess.run(init)
    #compute_centers(sess, add_op, count_op, average_op, images, labels, train_x, train_y)

    # run the computation graph (training and test)
    epoch = 1
    loss_before = np.inf
    score_before = 0.0
    stopping = 0
    index = list(range(train_num))
    np.random.shuffle(index)
    batch_size = FLAGS.batch_size
    batch_num = train_num // batch_size if train_num % batch_size == 0 else train_num // batch_size + 1
    #saver = tf.train.Saver(max_to_keep=1)

    ratio = FLAGS.ratio
    iter = 0
    step = 0
    log_loss = 0
    log_acc = 0
    # pdb.set_trace()
    # train the framework with the training data
    while stopping < FLAGS.stop:
        time1 = time.time()
        loss_now = 0.0
        score_now = 0.0
        iter += 1
        if iter % 100 == 0:
            print("iter : {}".format(iter))
        wrong_list = np.zeros(num_classes)
        for i in range(batch_num):

            # per-class sampling probability: classes that were misclassified more
            # often in the previous batch get a larger share; the remaining
            # probability mass is spread uniformly over all classes
            class_prob = ratio * wrong_list / batch_size + (
                1 - ratio * sum(wrong_list) / batch_size) / num_classes
            input_num_list = [
                math.ceil(batch_size * class_prob[j])
                for j in range(num_classes)
            ]
            batch_indices = []
            for k in range(num_classes):
                rand_smpl = [
                    label_index[k][j] for j in sorted(
                        random.sample(range(len(label_index[k])),
                                      input_num_list[k]))
                ]
                batch_indices += rand_smpl
            np.random.shuffle(batch_indices)
            # pdb.set_trace()
            batch_x = train_x[batch_indices[:batch_size]]
            batch_y = train_y[batch_indices[:batch_size]]
            # reshape the flat 784-dim vectors into NCHW image tensors
            batch_x_reshp = np.reshape(batch_x, (-1, 1, 28, 28))
            # batch_x.shape : (50, 784)
            # images : <tf.Tensor 'Placeholder:0' shape=(?, 1, 28, 28) dtype=float32>
            ##result = sess.run([train_op, loss, eval_correct], feed_dict={images:batch_x,
            ##    labels:batch_y, lr:FLAGS.learning_rate})
            result = sess.run([train_op, loss, eval_correct, find_wrong],
                              feed_dict={
                                  images: batch_x_reshp,
                                  labels: batch_y,
                                  lr: FLAGS.learning_rate
                              })
            loss_now += result[1]
            score_now += result[2]
            wrong_list = result[3]
            log_loss += result[1]
            log_acc += result[2]
            if (step % FLAGS.log_period == 0) and (step != 0):
                log_file.write('{}, loss, {:.3f}, acc, {:.3f}\n'.format(
                    step, log_loss / FLAGS.log_period,
                    log_acc / batch_size / FLAGS.log_period * 100))
                log_loss = 0
                log_acc = 0
            step += 1

        score_now /= train_num

        print('epoch {}: training: loss --> {:.3f}, acc --> {:.3f}%'.format(
            epoch, loss_now, score_now * 100))

        if loss_now > loss_before or score_now < score_before:
            stopping += 1
            FLAGS.learning_rate *= FLAGS.decay
            print("\033[1;31;40mdecay learning rate {}th time!\033[0m".format(
                stopping))

        loss_before = loss_now
        score_before = score_now

        #checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        #saver.save(sess, checkpoint_file, global_step=epoch)

        epoch += 1
        np.random.shuffle(index)

        time2 = time.time()
        print('time for this epoch: {:.3f} minutes'.format(
            (time2 - time1) / 60.0))

    # pdb.set_trace()

    # test the framework with the test data
    test_score = do_eval(sess, eval_correct, images, labels, test_x, test_y)
    print('accuracy on the test dataset: {:.3f}%'.format(test_score * 100))

    sess.close()
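A standalone sketch of the class-sampling rule used in the training loop above, with illustrative numbers (ratio, batch_size, and wrong_list here are made up, not taken from the run):

import math
import numpy as np

num_classes = 10
batch_size = 50
ratio = 0.5
# pretend the previous batch had these per-class error counts
wrong_list = np.array([0, 0, 5, 0, 0, 10, 0, 0, 0, 0], dtype=np.float64)

# each class gets an extra share proportional to its errors,
# and the remaining probability mass is spread uniformly
class_prob = ratio * wrong_list / batch_size + (
    1 - ratio * wrong_list.sum() / batch_size) / num_classes
input_num_list = [math.ceil(batch_size * class_prob[j]) for j in range(num_classes)]

print(class_prob)        # sums to 1.0
print(input_num_list)    # per-class sample counts; their sum is >= batch_size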
Example #2
def run_training():

    # load the data
    print (150*'*')
    #with open("mnist.data", "rb") as fid:
    with open("/home/ubuntu/codes/Convolutional-Prototype-Learning/mnist.pkl", "rb") as fid:
        dataset = pickle.load(fid, encoding='latin1')
    train_x, train_y = dataset[0]
    test_x, test_y = dataset[1]
    train_num = train_x.shape[0]
    test_num = test_x.shape[0]

    # construct the computation graph
    images = tf.placeholder(tf.float32, shape=[None,1,28,28])
    labels = tf.placeholder(tf.int32, shape=[None])
    lr= tf.placeholder(tf.float32)

    features, _ = mnist_net(images)
    centers = func.construct_center(features, FLAGS.num_classes)
    loss1 = func.dce_loss(features, labels, centers, FLAGS.temp)
    loss2 = func.pl_loss(features, labels, centers)
    loss = loss1 + FLAGS.weight_pl * loss2
    eval_correct = func.evaluation(features, labels, centers)
    train_op = func.training(loss, lr)
    
    #counts = tf.get_variable('counts', [FLAGS.num_classes], dtype=tf.int32,
    #    initializer=tf.constant_initializer(0), trainable=False)
    #add_op, count_op, average_op = net.init_centers(features, labels, centers, counts)

    init = tf.global_variables_initializer()

    # initialize the variables
    sess = tf.Session()
    sess.run(init)
    #compute_centers(sess, add_op, count_op, average_op, images, labels, train_x, train_y)

    # run the computation graph (train and test process)
    epoch = 1
    loss_before = np.inf
    score_before = 0.0
    stopping = 0
    index = list(range(train_num))
    np.random.shuffle(index)
    batch_size = FLAGS.batch_size
    batch_num = train_num//batch_size if train_num % batch_size==0 else train_num//batch_size+1
    #saver = tf.train.Saver(max_to_keep=1)

    # train the framework with the training data
    while stopping<FLAGS.stop:
        time1 = time.time()
        loss_now = 0.0
        score_now = 0.0
    
        for i in range(batch_num):
            batch_x = train_x[index[i*batch_size:(i+1)*batch_size]]
            batch_y = train_y[index[i*batch_size:(i+1)*batch_size]]
            batch_x_reshp = np.reshape(batch_x, (batch_size, 1, 28, 28))
            result = sess.run([train_op, loss, eval_correct], feed_dict={images:batch_x_reshp,
                labels:batch_y, lr:FLAGS.learning_rate})
            loss_now += result[1]
            score_now += result[2]
        score_now /= train_num

        print ('epoch {}: training: loss --> {:.3f}, acc --> {:.3f}%'.format(epoch, loss_now, score_now*100))
        #print sess.run(centers)
    
        if loss_now > loss_before or score_now < score_before:
            stopping += 1
            FLAGS.learning_rate *= FLAGS.decay
            print ("\033[1;31;40mdecay learning rate {}th time!\033[0m".format(stopping))
            
        loss_before = loss_now
        score_before = score_now

        #checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        #saver.save(sess, checkpoint_file, global_step=epoch)

        epoch += 1
        np.random.shuffle(index)

        time2 = time.time()
        print ('time for this epoch: {:.3f} minutes'.format((time2-time1)/60.0))
    
    # pdb.set_trace()
    # test the framework with the test data
    test_score = do_eval(sess, eval_correct, images, labels, test_x, test_y)
    print ('accuracy on the test dataset: {:.3f}%'.format(test_score*100))

    sess.close()
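The examples above combine func.dce_loss and func.pl_loss, which are not shown here. Below is a rough NumPy sketch of a distance-based cross-entropy plus prototype loss, under the assumption that dce_loss softmaxes negative squared distances scaled by temp and pl_loss penalizes the squared distance to the ground-truth center; the function and variable names are illustrative, not the repo's API:

import numpy as np

def dce_pl_sketch(features, labels, centers, temp=1.0, weight_pl=0.01):
    # features: (N, D) embeddings, centers: (C, D) class prototypes, labels: (N,)
    # squared Euclidean distance from every feature to every center -> (N, C)
    dists = ((features[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    # distance-based cross-entropy: softmax over -temp * distance
    logits = -temp * dists
    logits -= logits.max(axis=1, keepdims=True)          # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    dce = -log_probs[np.arange(len(labels)), labels].mean()
    # prototype loss: pull each feature towards its own class center
    pl = dists[np.arange(len(labels)), labels].mean()
    return dce + weight_pl * pl

# toy usage with random embeddings and prototypes
rng = np.random.default_rng(0)
feats = rng.normal(size=(8, 2))
cents = rng.normal(size=(10, 2))
labs = rng.integers(0, 10, size=8)
print(dce_pl_sketch(feats, labs, cents))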
def run_training():

    # load the data
    print(150 * '*')
    train_x = mnist.train.images
    train_y = mnist.train.labels
    train_y = np.argmax(train_y, axis=1)
    test_x = mnist.test.images
    test_y = mnist.test.labels
    test_y = np.argmax(test_y, axis=1)
    # train_x, train_y = dataset[0]
    # test_x, test_y = dataset[1]
    # train_x = [np.reshape(x, (784, 1)) for x in train_x]
    # train_y = [vectorized_label(x) for x in train_y]
    train_num = train_x.shape[0]
    # test_num = test_x.shape[0]
    print("train:" + str(train_x.shape) + "label:" + str(train_y.shape))
    # construct the computation graph
    images = tf.placeholder(tf.float32, shape=[None, 784])
    labels = tf.placeholder(tf.int32, shape=[None])
    lr = tf.placeholder(tf.float32)

    features, _ = mnist_net(images)
    centers = func.construct_center(features, FLAGS.num_classes)
    loss1 = func.dce_loss(features, labels, centers, FLAGS.temp)
    loss2 = func.pl_loss(features, labels, centers)
    loss = loss1 + FLAGS.weight_pl * loss2
    eval_correct = func.evaluation(features, labels, centers)
    train_op = func.training(loss, lr)

    #counts = tf.get_variable('counts', [FLAGS.num_classes], dtype=tf.int32,
    #    initializer=tf.constant_initializer(0), trainable=False)
    #add_op, count_op, average_op = net.init_centers(features, labels, centers, counts)

    init = tf.global_variables_initializer()

    # initialize the variables
    sess = tf.Session()
    sess.run(init)
    #compute_centers(sess, add_op, count_op, average_op, images, labels, train_x, train_y)

    # run the computation graph (train and test process)
    epoch = 1
    loss_before = np.inf
    score_before = 0.0
    stopping = 0
    index = list(range(train_num))
    np.random.shuffle(index)
    batch_size = FLAGS.batch_size
    # print("batch size",batch_size)
    batch_num = train_num // batch_size if train_num % batch_size == 0 else train_num // batch_size + 1
    #saver = tf.train.Saver(max_to_keep=1)

    # train the framework with the training data
    while stopping < FLAGS.stop:
        time1 = time.time()
        loss_now = 0.0
        score_now = 0.0

        print("centers:", sess.run(centers))

        for i in range(batch_num):
            batch_x = train_x[index[i * batch_size:(i + 1) * batch_size]]
            batch_y = train_y[index[i * batch_size:(i + 1) * batch_size]]
            # print("train images:",batch_x.shape,"label:",batch_y.shape)
            result = sess.run([train_op, loss, eval_correct],
                              feed_dict={
                                  images: batch_x,
                                  labels: batch_y,
                                  lr: FLAGS.learning_rate
                              })
            loss_now += result[1]
            score_now += result[2]
        score_now /= train_num

        print('epoch {}: training: loss --> {:.3f}, acc --> {:.3f}%'.format(
            epoch, loss_now, score_now * 100))
        #print sess.run(centers)

        if loss_now > loss_before or score_now < score_before:
            stopping += 1
            FLAGS.learning_rate *= FLAGS.decay
            print("\033[1;31;40mdecay learning rate {}th time!\033[0m".format(
                stopping))

        loss_before = loss_now
        score_before = score_now

        #checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        #saver.save(sess, checkpoint_file, global_step=epoch)

        epoch += 1
        np.random.shuffle(index)

        time2 = time.time()
        print('time for this epoch: {:.3f} minutes'.format(
            (time2 - time1) / 60.0))

        # test the framework with the test data
        test_score = do_eval(sess, eval_correct, images, labels, test_x,
                             test_y)
        print('accuracy on the test dataset: {:.10f}%'.format(test_score * 100))

    sess.close()
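do_eval is referenced throughout these examples but never defined. A minimal sketch of what such a helper might look like, assuming eval_correct returns the number of correctly classified samples in the fed batch (the name do_eval_sketch and the batch_size default are illustrative):

def do_eval_sketch(sess, eval_correct, images, labels, test_x, test_y, batch_size=100):
    # run the evaluation op over the test set in batches and
    # return the fraction of correctly classified samples
    test_num = test_x.shape[0]
    correct = 0.0
    for i in range(0, test_num, batch_size):
        batch_x = test_x[i:i + batch_size]
        batch_y = test_y[i:i + batch_size]
        correct += sess.run(eval_correct,
                            feed_dict={images: batch_x, labels: batch_y})
    return correct / test_num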
def run_training():

    # load the data
    print (150*'*')
    #with open("mnist.data", "rb") as fid:
    with open("/home/ubuntu/datasets/mnist.pkl", "rb") as fid:
        dataset = pickle.load(fid, encoding='latin1')
    train_x, train_y = dataset[0]
    test_x, test_y = dataset[1]
    train_num = train_x.shape[0]
    test_num = test_x.shape[0]

    # construct the computation graph
    images = tf.placeholder(tf.float32, shape=[None,1,28,28])
    labels = tf.placeholder(tf.int32, shape=[None])
    lr= tf.placeholder(tf.float32)

    features, _ = mnist_net(images)
    centers = func.construct_center(features, FLAGS.num_classes)
    loss1 = func.dce_loss(features, labels, centers, FLAGS.temp)
    loss2 = func.pl_loss(features, labels, centers)
    loss = loss1 + FLAGS.weight_pl * loss2
    eval_correct = func.evaluation(features, labels, centers)
    train_op = func.training(loss, lr)
    
    #counts = tf.get_variable('counts', [FLAGS.num_classes], dtype=tf.int32,
    #    initializer=tf.constant_initializer(0), trainable=False)
    #add_op, count_op, average_op = net.init_centers(features, labels, centers, counts)
  
    sess = tf.Session()
    load_saver = tf.train.Saver()
    os.makedirs(FLAGS.log_dir, exist_ok=True)
    file_list = os.listdir(FLAGS.log_dir)
    keep_last_int = 0
    last_load_file_name = ''
    for name in file_list:
        if len(name.split('.')) < 2:
            continue
        if keep_last_int < int(name.split('.')[1].split('-')[1]):
            keep_last_int = int(name.split('.')[1].split('-')[1])
            last_load_file_name = '.'.join(name.split('.')[:2])
    load_file = os.path.join(FLAGS.log_dir, last_load_file_name)
    if os.path.isfile(load_file+".meta"):
        load_saver.restore(sess, load_file)
    else:
        init = tf.global_variables_initializer()

        # initialize the variables (the session was already created above)
        sess.run(init)
        #compute_centers(sess, add_op, count_op, average_op, images, labels, train_x, train_y)

        # run the computation graph (train and test process)
        epoch = 1
        loss_before = np.inf
        score_before = 0.0
        stopping = 0
        index = list(range(train_num))
        np.random.shuffle(index)
        batch_size = FLAGS.batch_size
        batch_num = train_num//batch_size if train_num % batch_size==0 else train_num//batch_size+1
        saver = tf.train.Saver(max_to_keep=1)

        # train the framework with the training data
        while stopping<FLAGS.stop:
            time1 = time.time()
            loss_now = 0.0
            score_now = 0.0
        
            for i in range(batch_num):
                batch_x = train_x[index[i*batch_size:(i+1)*batch_size]]
                batch_y = train_y[index[i*batch_size:(i+1)*batch_size]]
                batch_x_reshp = np.reshape(batch_x, (batch_size, 1, 28, 28))
                result = sess.run([train_op, loss, eval_correct], feed_dict={images:batch_x_reshp,
                    labels:batch_y, lr:FLAGS.learning_rate})
                # features_eval = sess.run([features], feed_dict={images:batch_x_reshp, 
                #    labels:batch_y, lr:FLAGS.learning_rate})
                # features_eval.shape (1, 50, 2)
                loss_now += result[1]
                score_now += result[2]
            score_now /= train_num

            print ('epoch {}: training: loss --> {:.3f}, acc --> {:.3f}%'.format(epoch, loss_now, score_now*100))
            print (sess.run(centers))
        
            if loss_now > loss_before or score_now < score_before:
                stopping += 1
                FLAGS.learning_rate *= FLAGS.decay
                print ("\033[1;31;40mdecay learning rate {}th time!\033[0m".format(stopping))
                
            loss_before = loss_now
            score_before = score_now

            checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
            saver.save(sess, checkpoint_file, global_step=epoch)

            epoch += 1
            np.random.shuffle(index)

            time2 = time.time()
            print ('time for this epoch: {:.3f} minutes'.format((time2-time1)/60.0))

            break # For testing only the first episode
        
    #pdb.set_trace() 
    # test the framework with the test data
    test_score, eval_dots_perclass = do_eval(sess, eval_correct, images, labels, test_x, test_y, features) 
    compute_overlap(eval_dots_perclass)
    # eval_dots_perclass
    # len(eval_dots_perclass) : 10 [num_categories=10]
    # eval_dots_perclass[0] : dots of category 0
    # eval_dots_perclass[0][0] : (x, y) of index 0 of category 0
    print ('accuracy on the test dataset: {:.3f}%'.format(test_score*100))
    # pdb.set_trace()

    sess.close()
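The filename parsing above (splitting on '.' and '-') can usually be replaced by TensorFlow's built-in checkpoint lookup. A sketch under that assumption, reusing the example's FLAGS.log_dir; behavior is only equivalent if the directory contains standard Saver checkpoints:

import os
import tensorflow as tf

# let TensorFlow find the newest checkpoint instead of parsing filenames
os.makedirs(FLAGS.log_dir, exist_ok=True)
load_file = tf.train.latest_checkpoint(FLAGS.log_dir)  # None if no checkpoint exists
sess = tf.Session()
load_saver = tf.train.Saver()
if load_file is not None:
    load_saver.restore(sess, load_file)
else:
    sess.run(tf.global_variables_initializer())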
Example #5
def run_training():
    # load the data
    print (150*'*')
    HU2012 = sio.loadmat('./data/HU2012/2012_Houston.mat')
    data_IN = HU2012['spectraldata']
    gt_IN = HU2012['gt_2012']
    print (data_IN.shape)
    data = data_IN.reshape(np.prod(data_IN.shape[:2]),np.prod(data_IN.shape[2:]))
    gt = gt_IN.reshape(np.prod(gt_IN.shape[:2]),)

    trainingIndexf = './data/Houston2012trainingIndex.mat'
    train_indices = sio.loadmat(trainingIndexf)['trainingIndex']
    train_indices_rows = sio.loadmat(trainingIndexf)['trainingIndex_rows']
    train_indices_cols = sio.loadmat(trainingIndexf)['trainingIndex_cols']
    testingIndexf = './data/Houston2012testingIndex.mat'
    test_indices = sio.loadmat(testingIndexf)['testingIndex']  
    test_indices_rows = sio.loadmat(testingIndexf)['testingIndex_rows']  
    test_indices_cols = sio.loadmat(testingIndexf)['testingIndex_cols'] 

    train_indices = np.squeeze(train_indices-1)
    test_indices = np.squeeze(test_indices-1)
    height, width = gt_IN.shape

    Y=gt_IN.T
    Y = Y.reshape(height*width,)
    train_y = Y[train_indices]-1
    test_y = Y[test_indices] - 1

    classes_num = np.max(gt)
    
    data = preprocessing.scale(data)
    whole_data = data.reshape(data_IN.shape[0], data_IN.shape[1], data_IN.shape[2])

    whole_data, pca = applyPCA(whole_data, numComponents = FLAGS.numComponents)
    img_channels = whole_data.shape[2]
    PATCH_LENGTH = int((FLAGS.window_size-1)/2)
    padded_data = zeroPadding.zeroPadding_3D(whole_data, PATCH_LENGTH)
    train_data = np.zeros((train_indices.shape[0], FLAGS.window_size, FLAGS.window_size, img_channels))
    test_data = np.zeros((test_indices.shape[0], FLAGS.window_size, FLAGS.window_size, img_channels))
    
    train_assign = indexToAssignment(np.squeeze(train_indices_rows-1), np.squeeze(train_indices_cols-1), PATCH_LENGTH)
    for i in range(len(train_assign)):
        train_data[i] = selectNeighboringPatch(padded_data,train_assign[i][0],train_assign[i][1],PATCH_LENGTH)

    test_assign = indexToAssignment(np.squeeze(test_indices_rows-1), np.squeeze(test_indices_cols-1), PATCH_LENGTH)
    for i in range(len(test_assign)):
        test_data[i] = selectNeighboringPatch(padded_data,test_assign[i][0],test_assign[i][1],PATCH_LENGTH)
    
    Xtrain = train_data.reshape(train_data.shape[0], train_data.shape[1], train_data.shape[2],img_channels)
    Xtest = test_data.reshape(test_data.shape[0], test_data.shape[1], test_data.shape[2], img_channels)
    train_x = Xtrain.reshape(-1,train_data.shape[1], train_data.shape[2],img_channels,1)
    test_x = Xtest.reshape(-1, test_data.shape[1], test_data.shape[2],img_channels,1)
    train_num = train_x.shape[0]
    test_num = test_x.shape[0]

    # construct the computation graph
    images = tf.placeholder(tf.float32, shape=[None,FLAGS.window_size,FLAGS.window_size,img_channels,1])
    labels = tf.placeholder(tf.int32, shape=[None])
    lr= tf.placeholder(tf.float32)

    features = res4_model_ss(images,[1],[1])
    prototypes = func.construct_center(features, classes_num, 1)
    
    loss1 = func.dce_loss(features, labels, prototypes, FLAGS.temp)
    loss2 = func.dis_loss(features, labels, prototypes)
    loss = loss1 + FLAGS.weight_dis * loss2

    eval_correct = func.evaluation(features, labels, prototypes)
    train_op = func.training(loss, lr)

    # initialize the variables
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    # run the computation graph (train and test process)
    epoch = 1
    index = list(range(train_num))
    np.random.shuffle(index)
    batch_size = FLAGS.batch_size
    batch_num = train_num//batch_size if train_num % batch_size ==0 else train_num//batch_size+1
    train_start= time.time()

    # train the framework with the training data
    while epoch<FLAGS.epoch_num:
        time1 = time.time()
        loss_now = 0.0
        score_now = 0.0
        for i in range(batch_num):
            batch_x = train_x[index[i*batch_size:(i+1)*batch_size]]
            batch_y = train_y[index[i*batch_size:(i+1)*batch_size]]
            result = sess.run([train_op, loss, eval_correct], feed_dict={images:batch_x,
                labels:batch_y, lr:FLAGS.learning_rate})
            loss_now += result[1]
            score_now += result[2][1]
        score_now /= train_num
        print ('epoch {}: training: loss --> {:.3f}, acc --> {:.3f}%'.format(epoch, loss_now, score_now*100))
        FLAGS.learning_rate -= FLAGS.decay
        epoch += 1
        np.random.shuffle(index)
        time2 = time.time()
        print ('time for this epoch: {:.3f} minutes'.format((time2-time1)/60.0))
    print()
    print('time for the whole training phase: '+str(time.time()-train_start)+' s')   
    # test the framework with the test data
    # init_prototypes_value = sess.run(prototypes) # get the variable of prototypes
    test_start= time.time()
    pred_labels, test_score = do_eval(sess, eval_correct, images, labels, test_x, test_y)
    print('time for the whole testing phase: '+str(time.time()-test_start)+' s')
    sess.close()    
    pred_labels = np.int8(pred_labels)  
    test_y = np.int8(test_y) 
    # confusion matrix: rows = predicted class, columns = true class
    matrix = np.zeros((classes_num, classes_num))
    with open('prediction.txt', 'w') as f:
        for i in range(test_num):
            pre_label = pred_labels[i]
            f.write(str(pre_label) + '\n')
            matrix[pre_label, test_y[i]] += 1
    # the file is closed automatically by the with-statement
    print()
    print('The confusion matrix is:')
    print(np.int_(matrix))

    # overall accuracy (OA)
    OA = np.trace(matrix) / float(test_num)
    # per-class accuracy (recall): correct predictions / true samples of each class
    ua = np.diag(matrix) / np.sum(matrix, axis=0)
    # per-class precision: correct predictions / predicted samples of each class
    precision = np.diag(matrix) / np.sum(matrix, axis=1)
    # Kappa coefficient
    matrix = np.mat(matrix)
    Po = OA
    xsum = np.sum(matrix, axis=1)
    ysum = np.sum(matrix, axis=0)
    Pe = float(ysum * xsum) / (np.sum(matrix) ** 2)
    Kappa = float((Po - Pe) / (1 - Pe))
    # print the classification result:
    # per-class accuracy, average accuracy (AA), OA, Kappa, then per-class precision
    for i in ua:
        print(i)
    print(str(np.sum(ua) / matrix.shape[0]))          # average accuracy (AA)
    print(str(OA))
    print(str(Kappa))
    print()
    for i in precision:
        print(i)
    print(str(np.sum(precision) / matrix.shape[0]))   # average precision
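A small self-contained check of the confusion-matrix metrics above (OA, per-class accuracy, precision, Kappa), using a toy 3-class matrix rather than the Houston data:

import numpy as np

# toy confusion matrix: rows = predicted class, columns = true class
matrix = np.array([[50.,  2.,  3.],
                   [ 5., 40.,  0.],
                   [ 1.,  4., 45.]])
test_num = matrix.sum()

OA = np.trace(matrix) / test_num                     # overall accuracy
ua = np.diag(matrix) / matrix.sum(axis=0)            # per-class accuracy (recall)
precision = np.diag(matrix) / matrix.sum(axis=1)     # per-class precision
Pe = float(matrix.sum(axis=0) @ matrix.sum(axis=1)) / (matrix.sum() ** 2)
Kappa = (OA - Pe) / (1 - Pe)

print(OA, ua, precision, Kappa)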
Example #6
def run_training():
    # load the data
    print (150*'*')
    if FLAGS.dataset == 'mnist':
        train_x, train_y = mnist.train.images.reshape(-1,1,28,28), mnist.train.labels
        test_x, test_y = mnist.test.images.reshape(-1,1,28,28), mnist.test.labels
        # data normalization.
        # test_x = test_x - np.mean(test_x, axis = 0)
        # test_x = test_x / (np.std(test_x, axis = 0) + 1e-8)
        # train_x = train_x - np.mean(train_x, axis = 0)
        # train_x= train_x / (np.std(train_x, axis = 0) + 1e-8)
        train_num = train_x.shape[0]
        test_num = test_x.shape[0]
        # construct the computation graph
        images_new = tf.placeholder(tf.float32, shape=[None,1,28,28])
        images = tf.placeholder(tf.float32, shape=[None,1,28,28])
        features, logits = mnist_net(images)

        if FLAGS.use_augmentation:
            augment = data_augmentor(images_new)

        labels = tf.placeholder(tf.int32, shape=[None])
        lr= tf.placeholder(tf.float32)
        # print('test!')

    elif FLAGS.dataset == 'cifar10':
        from keras.datasets import cifar10
        (xtrain, ytrain), (xtest, ytest) = cifar10.load_data()
        ytrain = np.squeeze(ytrain)
        ytest = np.squeeze(ytest)
        xtrain = xtrain.astype('float32')
        xtrain = xtrain / 255.0
        xtest = xtest.astype('float32')
        xtest = xtest / 255.0

        # xtrain, ytrain, xtest, ytest = load_cifar()
        train_x, train_y = xtrain.reshape(-1,32,32,3), ytrain
        test_x, test_y = xtest.reshape(-1,32,32,3), ytest
        # data normalization.
        # test_x = test_x - np.mean(test_x, axis = 0)
        # test_x = test_x / np.std(test_x, axis = 0)
        # train_x = train_x - np.mean(train_x, axis = 0)
        # train_x= train_x / np.std(train_x, axis = 0)
        #another normalization method.
        train_x = train_x - [0.491,0.482,0.447]
        train_x = train_x / [0.247,0.243,0.262]
        test_x = test_x - [0.491,0.482,0.447]
        test_x = test_x / [0.247,0.243,0.262]        

        train_num = train_x.shape[0]
        test_num = test_x.shape[0]
        # construct the computation graph
        images = tf.placeholder(tf.float32, shape=[None,32,32,3])
        # images_normalized = tf.map_fn(lambda img: tf.image.per_image_standardization(img),
        #                        images)
        images_new = tf.placeholder(tf.float32, shape=[None,32,32,3])
        labels = tf.placeholder(tf.int32, shape=[None])
        lr= tf.placeholder(tf.float32)

        if FLAGS.model == 'resnet':
            if FLAGS.use_augmentation:
                augment = data_augmentor(images_new)
            features, logits = inference(images, FLAGS.num_residual_blocks, FLAGS.num_classes,\
             reuse=False, weight_decay = FLAGS.weight_decay)
        else:
            if FLAGS.use_augmentation:
                augment = data_augmentor(images_new)
            features, logits = cifar_net1(images, [0.1,0.5,0.5])	



    elif FLAGS.dataset == 'cifar100':
        from keras.datasets import cifar100
        (xtrain, ytrain), (xtest, ytest) = cifar100.load_data()
        ytrain = np.squeeze(ytrain)
        ytest = np.squeeze(ytest)
        xtrain = xtrain.astype('float32')
        xtrain = xtrain / 255.0
        xtest = xtest.astype('float32')
        xtest = xtest / 255.0
        # xtrain, ytrain, xtest, ytest = load_cifar100()
        train_x, train_y = xtrain.reshape(-1,32,32, 3), ytrain
        test_x, test_y = xtest.reshape(-1,32,32,3), ytest
        #another normalization method.
        train_x = train_x - [0.4914, 0.4822, 0.4465]
        train_x = train_x / [0.2023, 0.1994, 0.2010]
        test_x = test_x - [0.4914, 0.4822, 0.4465]
        test_x = test_x / [0.2023, 0.1994, 0.2010]    

        train_num = train_x.shape[0]
        test_num = test_x.shape[0]

        # construct the computation graph
        images = tf.placeholder(tf.float32, shape=[None,32,32,3])
        images_new = tf.placeholder(tf.float32, shape=[None,32,32,3])
        labels = tf.placeholder(tf.int32, shape=[None])
        lr= tf.placeholder(tf.float32)

        if FLAGS.model == 'densenet':
            if FLAGS.use_augmentation:
                augment = data_augmentor(images_new)
            features, logits = densenet_bc(images, num_classes = FLAGS.num_classes,is_training = True, growth_rate = 12,drop_rate = 0,\
            	depth = 100, for_imagenet = False, reuse = False, scope='test')
            # inference(images, FLAGS.num_residual_blocks, reuse=False, weight_decay = FLAGS.weight_decay)
        elif FLAGS.model == 'resnet':
            if FLAGS.use_augmentation:
                augment = data_augmentor(images_new)
            features, logits = inference(images, FLAGS.num_residual_blocks, FLAGS.num_classes,\
            	reuse=False, weight_decay = FLAGS.weight_decay)
        else:
            if FLAGS.use_augmentation:
                augment = data_augmentor(images_new)
            features, logits = cifar_net1(images, [0.1,0.5,0.5])	

    # import ipdb
    # ipdb.set_trace()
    if FLAGS.loss == 'cpl': 
        centers = func.construct_center(features, FLAGS)

        if FLAGS.use_dot_product:
            loss1 = func.dot_dce_loss(features, labels, centers, FLAGS.temp, FLAGS)
            eval_correct = func.evaluation_dot_product(features, labels, centers, FLAGS)
        else:
            loss1 = func.dce_loss(features, labels, centers, FLAGS.temp, FLAGS)
            eval_correct = func.evaluation(features, labels, centers, FLAGS)

        loss2 = FLAGS.weight_pl * func.pl_loss(features, labels, centers, FLAGS)
        loss = loss1 + loss2
        
        if FLAGS.model == 'resnet':
            # reg_losses = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            loss = tf.add_n([loss] + reg_losses)#loss + reg_losses
        if FLAGS.model == 'densenet':
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            loss = tf.add_n([loss] + reg_losses)#loss + reg_losses

        train_op = func.training(loss, FLAGS, lr)

    elif FLAGS.loss == "softmax":
        loss = func.softmax_loss(logits, labels)
        eval_correct = func.evaluation_softmax(logits, labels)
        if FLAGS.model == 'resnet':
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            loss = tf.add_n([loss] + reg_losses)#loss + reg_losses
        if FLAGS.model == 'densenet':
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            loss = tf.add_n([loss] + reg_losses)#loss + reg_losses

        train_op = func.training(loss, FLAGS, lr)

    init = tf.global_variables_initializer()
    # initialize the variables
    sess = tf.Session()
    sess.run(init)
    saver = tf.train.Saver(max_to_keep=1)
    if FLAGS.restore:
        saver.restore(sess,FLAGS.restore)
        print('restore the model from: ', FLAGS.restore)

    # run the computation graph (train and test process)
    stopping = 0
    # shuffle the index list in place (shuffling a temporary copy would have no effect)
    index = list(range(train_num))
    np.random.shuffle(index)
    batch_size = FLAGS.batch_size
    batch_num = train_num//batch_size if train_num % batch_size==0 else train_num//batch_size+1
    

    # train the framework with the training data
    acc_save = []
    steps = 0
    for epoch in range(FLAGS.num_epoches):
        time1 = time.time()
        loss_now = 0.0
        score_now = 0.0
        loss_dce = 0.0
        loss_pl = 0.0
        reg_loss = 0.0

        for i in range(batch_num):

            # if loss_now > loss_before or score_now < score_before:
            # if (epoch + 1) % FLAGS.decay_step == 0:
            if epoch < 10:
                # use the base learning rate for the first 10 epochs; afterwards
                # learning_rate_temp keeps its previous value and is decayed below
                # learning_rate_temp = FLAGS.learning_rate * float(steps) / float(batch_num * 10)
                learning_rate_temp = FLAGS.learning_rate

            batch_x = train_x[index[i*batch_size:(i+1)*batch_size]]
            batch_y = train_y[index[i*batch_size:(i+1)*batch_size]]
            # mask_temp = compute_mask(FLAGS, batch_y)

            if FLAGS.loss == 'cpl': 
                if FLAGS.model == 'resnet' or FLAGS.model == 'densenet':
                    if FLAGS.use_augmentation:
                        batch_x = augment.output(sess, batch_x)
                    result = sess.run([train_op, loss, loss1, loss2, eval_correct, reg_losses, centers, features],\
                    feed_dict={images:batch_x, labels:batch_y, lr:learning_rate_temp})
                    reg_loss += np.sum(result[5])
                    # print(result[-2])

                else:
                    if FLAGS.use_augmentation:
                        batch_x = augment.output(sess, batch_x)
                    result = sess.run([train_op, loss, loss1, loss2, eval_correct, centers, features],\
                    feed_dict={images:batch_x, labels:batch_y, lr:learning_rate_temp})
                    print(result[-2])

                loss_now += result[1]
                score_now += result[4]
                loss_dce += result[2]
                loss_pl += result[3]

                if i == 0:
                    features_container = np.asarray(result[-1])
                    label_container = batch_y
                else:
                    features_container = np.concatenate((features_container, result[-1]), axis=0)
                    label_container = np.concatenate((label_container, batch_y), axis=0)


            elif FLAGS.loss == 'softmax':
                if FLAGS.model == 'resnet' or FLAGS.model == 'densenet':
                    if FLAGS.use_augmentation:
                        batch_x = augment.output(sess, batch_x)

                    result = sess.run([train_op, loss, eval_correct, reg_losses],\
                    feed_dict={images:batch_x, labels:batch_y, lr:learning_rate_temp})

                    reg_loss += np.sum(result[3])
                else:
                    if FLAGS.use_augmentation:
                        batch_x = augment.output(sess, batch_x)

                    result = sess.run([train_op, loss, eval_correct],\
                    feed_dict={images:batch_x, labels:batch_y, lr:learning_rate_temp})

                loss_now += result[1]
                score_now += result[2]
            steps += 1

        # for visualization. 
        if FLAGS.loss == 'cpl' and FLAGS.dataset == 'mnist':
            if epoch % 10 == 0:
                centers_container = result[-2]
                func.visualize(features_container, label_container, epoch, centers_container, FLAGS)

        # if epoch +1 == 150 or epoch +1 == 225:
        if (epoch + 1) % FLAGS.decay_step == 0:
            stopping += 1
            learning_rate_temp *= FLAGS.decay_rate
            print ("\033[1;31;40mdecay learning rate {}th time!\033[0m".format(stopping))

        score_now /= train_num
        loss_now /= batch_num

        if FLAGS.loss == 'cpl':
            loss_dce = loss_dce / batch_num
            loss_pl = loss_pl / batch_num
            if FLAGS.model == 'resnet' or FLAGS.model == 'densenet':
                reg_loss /= batch_num
                print('epoch {}: training: loss --> {:.3f}, dce_loss --> {:.3f}, '
                      'pl_loss --> {:.3f}, reg_loss --> {:.3f}, acc --> {:.3f}%'.format(
                          epoch, loss_now, loss_dce, loss_pl, reg_loss, score_now * 100))
            else:
                print('epoch {}: training: loss --> {:.3f}, dce_loss --> {:.3f}, '
                      'pl_loss --> {:.3f}, acc --> {:.3f}%'.format(
                          epoch, loss_now, loss_dce, loss_pl, score_now * 100))
        elif FLAGS.loss == 'softmax':
            if FLAGS.model == 'resnet' or FLAGS.model == 'densenet':
                reg_loss /= batch_num
                print('epoch {}: training: loss --> {:.3f}, reg_loss --> {:.3f}, '
                      'acc --> {:.3f}%'.format(epoch, loss_now, reg_loss, score_now * 100))
            else:
                print('epoch {}: training: loss --> {:.3f}, acc --> {:.3f}%'.format(
                    epoch, loss_now, score_now * 100))


        # epoch += 1
        np.random.shuffle(index)
        time2 = time.time()
        print ('time for this epoch: {:.3f} minutes'.format((time2-time1)/60.0))


        # test the framework with the test data
        if (epoch + 1) % FLAGS.print_step == 0:
            # test_score, logits_test = do_eval(sess, eval_correct, images, labels, test_x, test_y, logits)
            # np.save('./cifar10_logits.npy', logits_test)
            test_score = do_eval(sess, eval_correct, images, labels, test_x, test_y)
            print ('epoch:{}, accuracy on the test dataset: {:.3f}%'.format(epoch, test_score*100))
            acc_save.append(test_score)
            temp = np.amax(np.asarray(acc_save))
            print('best test acc:',temp)

        # saving the model.
        if not os.path.isdir(os.path.join(FLAGS.log_dir, FLAGS.dataset)):  
            os.makedirs(os.path.join(FLAGS.log_dir, FLAGS.dataset))
        checkpoint_file = os.path.join(str(os.path.join(FLAGS.log_dir, FLAGS.dataset)),\
         'model_'+FLAGS.loss+'_'+str(FLAGS.use_dot_product)+'_'+str(FLAGS.learning_rate)+'_'+str(FLAGS.batch_size)+'.ckpt')
        saver.save(sess, checkpoint_file, global_step=epoch)



    acc_save = np.asarray(acc_save)
    np.save('./acc_test_cifar10_original_paper.npy', acc_save)
    sess.close()
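The epoch loop above decays learning_rate_temp by FLAGS.decay_rate every FLAGS.decay_step epochs. A tiny standalone illustration of that step schedule, with made-up values for the flags (not the run's real settings):

# illustrative only: how a step decay evolves the learning rate
base_lr, decay_rate, decay_step = 0.1, 0.1, 150
lr = base_lr
for epoch in range(300):
    if (epoch + 1) % decay_step == 0:
        lr *= decay_rate
print(lr)  # decayed twice over 300 epochs: 0.1 * 0.1 ** 2 == 0.001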