Example #1
0
File: main.py  Project: augustdemi/demi
            def summary(maml_result, set):
                """Log pre-/post-update predictions for one data split.

                NOTE(review): relies on names from the enclosing scope
                (itr, trained_model_dir, FLAGS, os, np, print_summary).
                ``set`` shadows the builtin but is part of the signature.
                """
                print(set)
                print('Iteration ' + str(itr))

                # Pre-update ("task a") predictions and labels, stacked.
                # length = num_of_task * N * K
                preds_a = np.vstack(np.array(maml_result[-2][0]))
                labels_a = np.vstack(np.array(maml_result[-2][1]))

                save_path = "./logs/result/train/{}/{}/".format(trained_model_dir, FLAGS.au_idx)
                if FLAGS.keep_train_dir:
                    save_path += 'sbjt{}.ubs_{}.numstep{}.updatelr{}.metalr{}'.format(
                        FLAGS.sbjt_start_idx, FLAGS.train_update_batch_size,
                        FLAGS.num_updates, FLAGS.train_update_lr, FLAGS.meta_lr)
                if not os.path.exists(save_path):
                    os.makedirs(save_path)

                print_summary(preds_a, labels_a,
                              log_dir=save_path + "/outa_" + set + "_" + str(itr) + ".txt")
                print("------------------------------------------------------------------------------------")

                # Post-update ("task b") results from the last inner step.
                preds_b = np.vstack(np.array(maml_result[-1][0][FLAGS.num_updates - 1]))
                labels_b = np.vstack(np.array(maml_result[-1][1][FLAGS.num_updates - 1]))
                print_summary(preds_b, labels_b,
                              log_dir=save_path + "/outb_" + set + "_" + str(itr) + ".txt")
                print("================================================================================")
Example #2
0
File: test.py  Project: augustdemi/demi
def main():
    """Build a MAML graph, restore trained weights, and test per subject.

    Depending on FLAGS.model ('m*' = TF checkpoint models, 's*' = VAE-based
    models) and FLAGS.all_sub_model, last-layer weights are loaded once or
    per subject; predictions/labels from ``test_each_subject`` are stacked
    across subjects and summarized via ``print_summary``.

    NOTE(review): mutates FLAGS in place and relies on module-level names
    (FLAGS, DataGenerator, MAML, test_each_subject, print_summary,
    feature_layer, start_time) -- confirm they are defined at module scope.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu
    TOTAL_NUM_AU = 8
    all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1
        temp_kshot = FLAGS.update_batch_size
        FLAGS.update_batch_size = 1
    if FLAGS.model.startswith('m2'):
        # m2 graphs are constructed with a single inner update step;
        # the original value is restored after graph construction.
        temp_num_updates = FLAGS.num_updates
        FLAGS.num_updates = 1

    data_generator = DataGenerator()

    dim_output = data_generator.num_classes
    dim_input = data_generator.dim_input

    inputa, inputb, labela, labelb = data_generator.make_data_tensor()
    metatrain_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

    model = MAML(dim_input, dim_output)
    model.construct_model(input_tensors=metatrain_input_tensors, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()

    # (removed unused ``loader`` alias of the same Saver)
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=20)

    sess = tf.InteractiveSession()

    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.update_batch_size = temp_kshot
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.model.startswith('m2'):
        FLAGS.num_updates = temp_num_updates

    # -1 means "inherit the test-time value".
    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    print('initial weights: ', sess.run('model/b1:0'))
    print("========================================================================================")

    ################## Test ##################
    def _load_weight_m(trained_model_dir):
        """Restore per-AU checkpoints and stack their last-layer w/b.

        Returns (w_arr, b_arr); both remain None when no checkpoint at all
        could be restored.
        """
        all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']
        if FLAGS.au_idx < TOTAL_NUM_AU: all_au = [all_au[FLAGS.au_idx]]
        w_arr = None
        b_arr = None
        for au in all_au:
            model_file = None
            print('model file dir: ', FLAGS.logdir + '/' + au + '/' + trained_model_dir)
            model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + au + '/' + trained_model_dir)
            print("model_file from ", au, ": ", model_file)
            if (model_file == None):
                print(
                    "############################################################################################")
                print("####################################################################### None for ", au)
                print(
                    "############################################################################################")
            else:
                if FLAGS.test_iter > 0:
                    # Prefer the checkpoint of the requested iteration over
                    # the latest one, if that iteration exists on disk.
                    files = os.listdir(model_file[:model_file.index('model')])
                    if 'model' + str(FLAGS.test_iter) + '.index' in files:
                        model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                        print("model_file by test_iter > 0: ", model_file)
                    else:
                        print(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>", files)
                print("Restoring model weights from " + model_file)

                saver.restore(sess, model_file)
                w = sess.run('model/w1:0')
                b = sess.run('model/b1:0')
                print("updated weights from ckpt: ", b)
                print('----------------------------------------------------------')
                if w_arr is None:
                    w_arr = w
                    b_arr = b
                else:
                    # Stack weights column-wise and biases row-wise, one per AU.
                    w_arr = np.hstack((w_arr, w))
                    b_arr = np.vstack((b_arr, b))

        return w_arr, b_arr

    def _load_weight_s(sbjt_start_idx):
        """Load last-layer w/b from VAE ('s*') models via ``feature_layer``.

        Returns (w_arr, b_arr) stacked per AU for s3/s4/other variants.
        NOTE(review): for the s1 branch w_arr/b_arr are never assigned and
        the return raises UnboundLocalError -- confirm s1 is handled by a
        different code path before relying on it here.
        """
        batch_size = 10
        # If one model was trained with all AUs, loading that single model is enough.
        if FLAGS.model.startswith('s1'):
            three_layers = feature_layer(batch_size, TOTAL_NUM_AU)
            three_layers.loadWeight(FLAGS.vae_model_to_test, FLAGS.au_idx, num_au_for_rm=TOTAL_NUM_AU)
        # With a separate model per AU, the per-AU weights must be stacked.
        else:
            three_layers = feature_layer(batch_size, 1)
            all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']
            if FLAGS.au_idx < TOTAL_NUM_AU: all_au = [all_au[FLAGS.au_idx]]
            w_arr = None
            b_arr = None
            for au in all_au:
                # Model-variant-specific checkpoint naming schemes.
                if FLAGS.model.startswith('s3'):
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter100'
                elif FLAGS.model.startswith('s4'):
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_subject' + str(
                        sbjt_start_idx + 1) + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter10_maml_adad' + str(FLAGS.test_iter)
                else:
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter200_kshot10_iter10_nobatch_adam_noinit'
                three_layers.loadWeight(load_model_path, au)
                print('=============== Model S loaded from ', load_model_path)
                w = three_layers.model_intensity.layers[-1].get_weights()[0]
                b = three_layers.model_intensity.layers[-1].get_weights()[1]
                print('----------------------------------------------------------')
                if w_arr is None:
                    w_arr = w
                    b_arr = b
                else:
                    w_arr = np.hstack((w_arr, w))
                    b_arr = np.vstack((b_arr, b))

        return w_arr, b_arr

    def _load_weight_m0(trained_model_dir):
        """Restore one all-AU checkpoint; return its last-layer (w, b).

        Returns (None, None) when no checkpoint is found.
        """
        # Bug fix: w/b were previously only assigned in the else-branch, so a
        # missing checkpoint raised UnboundLocalError at the return below.
        w = None
        b = None
        model_file = None
        print('--------- model file dir: ', FLAGS.logdir + trained_model_dir)
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + trained_model_dir)
        print(">>>> model_file from all_aus: ", model_file)
        if (model_file == None):
            print("####################################################################### None for all_aus")
        else:
            if FLAGS.test_iter > 0:
                # Prefer the checkpoint of the requested iteration.
                files = os.listdir(model_file[:model_file.index('model')])
                if 'model' + str(FLAGS.test_iter) + '.index' in files:
                    model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                    print(">>>> model_file2: ", model_file)
                else:
                    print(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>", files)
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
            w = sess.run('model/w1:0')
            b = sess.run('model/b1:0')
            print("updated weights from ckpt: ", b)
            print('----------------------------------------------------------')
        return w, b

    print("<<<<<<<<<<<< CONCATENATE >>>>>>>>>>>>>>")
    save_path = "./logs/result/"
    y_hat = []
    y_lab = []
    if FLAGS.all_sub_model:  # the model was trained using all subjects
        print('---------------- all sub model ----------------')
        # Weight loading only needs to run once: the model does not differ
        # per subject.
        if FLAGS.model.startswith('m'):
            trained_model_dir = '/cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
                FLAGS.meta_batch_size) + '.ubs_' + str(
                FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
                FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)
            if FLAGS.model.startswith('m0'):
                w_arr, b_arr = _load_weight_m0(trained_model_dir)
            else:
                w_arr, b_arr = _load_weight_m(trained_model_dir)  # a different model per AU

        ### test per each subject and concatenate
        for i in range(FLAGS.sbjt_start_idx, FLAGS.sbjt_start_idx + FLAGS.num_test_tasks):
            if FLAGS.model.startswith('s'):
                w_arr, b_arr = _load_weight_s(i)

            result = test_each_subject(w_arr, b_arr, i)
            y_hat.append(result[0])
            y_lab.append(result[1])
            print("y_hat shape:", result[0].shape)
            print("y_lab shape:", result[1].shape)
            print(">> y_hat_all shape:", np.vstack(y_hat).shape)
            print(">> y_lab_all shape:", np.vstack(y_lab).shape)
        print_summary(np.vstack(y_hat), np.vstack(y_lab), log_dir=save_path + "/" + "test.txt")
    else:  # the model was trained per subject: only possible for the VAE and
        # MAML train_test cases, plus the local-weight test case.
        for subj_idx in range(FLAGS.sbjt_start_idx, FLAGS.sbjt_start_idx + FLAGS.num_test_tasks):
            if FLAGS.model.startswith('s'):
                w_arr, b_arr = _load_weight_s(subj_idx)
            else:
                trained_model_dir = '/sbjt' + str(subj_idx) + '.ubs_' + str(
                    FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
                    FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)
                w_arr, b_arr = _load_weight_m(trained_model_dir)
            result = test_each_subject(w_arr, b_arr, subj_idx)
            y_hat.append(result[0])
            y_lab.append(result[1])
            print("y_hat shape:", result[0].shape)
            print("y_lab shape:", result[1].shape)
            print(">> y_hat_all shape:", np.vstack(y_hat).shape)
            print(">> y_lab_all shape:", np.vstack(y_lab).shape)
        print_summary(np.vstack(y_hat), np.vstack(y_lab),
                      log_dir=save_path + "/test.txt")

    # ``start_time`` is presumably a module-level timestamp -- TODO confirm.
    end_time = datetime.now()
    elapse = end_time - start_time
    print("=======================================================")
    print(">>>>>> elapse time: " + str(elapse))
    print("=======================================================")
Example #3
0
    #path = r_path + iter + '/' + kshot + 'kshot/seed' + str(seed)

    # Aggregate per-subject prediction pickles under ``path`` and report
    # per-AU f1-scores plus their average across subjects.
    y_lab_all = []
    y_hat_all = []
    f1_scores_per_seed = []
    for subject_idx in range(13):  # 13 test subjects -- TODO confirm
        print("=============================subject_idx: ", subject_idx)
        # NOTE(review): ``file`` shadows the builtin, and the pickle file is
        # never explicitly closed; latin1 suggests a Python-2-written pickle
        # -- confirm.
        file = pickle.load(open(
            path + '/predicted_subject' + str(subject_idx) + '.pkl', 'rb'),
                           encoding='latin1')
        y_lab = file['y_lab']
        y_hat = file['y_hat']
        y_lab_all.append(y_lab)
        y_hat_all.append(y_hat)
        out = print_summary(y_hat,
                            y_lab,
                            log_dir="./logs/result/" + "/test.txt")
        # Row 5 of the summary presumably holds per-AU f1-scores -- confirm
        # against print_summary's output layout.
        f1_score = list(out['data'][5])
        # add avg throughout all AUs as the last elt
        f1_score.append(np.average(f1_score))
        # stack each subject's f1-score
        f1_scores_per_seed.append(f1_score)
        print("-- num of samples:", len(file['all_used_frame_set']))

    print(">> y_lab_all shape:", np.vstack(y_lab_all).shape)
    print(">> y_hat_all shape:", np.vstack(y_hat_all).shape)

    print('-------------------- avg --------------------')
    # Average f1 per AU (and the trailing overall average) across subjects.
    averaged_f1 = np.average(f1_scores_per_seed, axis=0)
    print(averaged_f1)
    print('---------------- concatenated ---------------')
Example #4
0
def train(model, data_generator, saver, sess, trained_model_dir, resume_itr=0):
    """Run the meta-training loop.

    Periodically writes TF summaries, saves checkpoints plus pickled
    global and per-subject fast weights, and evaluates validation
    f1-score and reconstruction MSE.

    Args:
        model: constructed MAML model exposing train_op / summ_op /
            fast_weight_w / fast_weight_b.
        data_generator: provides shuffled batches and validation data.
        saver: tf.train.Saver used for checkpointing.
        sess: active TF session.
        trained_model_dir: subdirectory under FLAGS.logdir for outputs.
        resume_itr: iteration to resume from (0 = fresh start).

    NOTE(review): ``train_writer`` is only created when FLAGS.log is set
    but is used unconditionally at summary intervals -- confirm FLAGS.log
    is always true during training. The function may continue beyond the
    visible lines.
    """
    print("===============> Final in weight: ",
          sess.run('model/w1:0').shape,
          sess.run('model/b1:0').shape)
    SUMMARY_INTERVAL = 10
    SAVE_INTERVAL = 5000

    if FLAGS.log:
        train_writer = tf.summary.FileWriter(
            FLAGS.logdir + '/' + trained_model_dir, sess.graph)

    feed_dict = {}

    print('Done initializing, starting training.')
    aus = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']

    total_val_stat = []  # (mean f1, std, 95% CI) per validation pass
    total_mse = []  # mean reconstruction MSE per validation pass
    two_layer = feature_layer(1, FLAGS.num_au)
    data_generator.get_validation_data()
    all_val_feat_vec = data_generator.val_feat_vec
    all_val_frame = data_generator.val_frame
    val_subjects = os.listdir(FLAGS.val_data_folder)
    val_subjects.sort()
    print('total validation subjects: ', val_subjects)

    # NOTE(review): this sklearn import is shadowed by the local scalar
    # ``f1_score`` assigned in the validation loop below.
    from sklearn.metrics import f1_score
    for itr in range(resume_itr + 1, FLAGS.metatrain_iterations + 1):
        # Optionally re-shuffle the meta-batch every FLAGS.shuffle_batch iters.
        if FLAGS.shuffle_batch > 0 and itr % FLAGS.shuffle_batch == 0:
            print(
                '=============================================================shuffle data, iteration:',
                itr)
            inputa, inputb, labela, labelb, _ = data_generator.shuffle_data(
                itr, FLAGS.update_batch_size, aus)
            feed_dict = {
                model.inputa: inputa,
                model.inputb: inputb,
                model.labela: labela,
                model.labelb: labelb
            }

        # NOTE(review): both branches assign 100, so the itr <= 1000 split is
        # currently dead -- SAVE_INTERVAL is always 100 inside the loop.
        if itr <= 1000:
            SAVE_INTERVAL = 100
        else:
            SAVE_INTERVAL = 100

        input_tensors = [model.train_op]

        if (itr % SUMMARY_INTERVAL == 0):
            input_tensors.extend([model.summ_op])

        # Fast (per-subject) weights are always fetched last, so result[-2]
        # and result[-1] are valid regardless of whether summ_op was added.
        input_tensors.extend([model.fast_weight_w])
        input_tensors.extend([model.fast_weight_b])
        result = sess.run(input_tensors, feed_dict)

        if (itr % SUMMARY_INTERVAL == 0):
            # result[1] is the merged summary only on summary iterations.
            train_writer.add_summary(result[1], itr)

        if (itr % SAVE_INTERVAL == 0) or (itr == FLAGS.metatrain_iterations):

            print("================================================ iter:",
                  itr)
            print()
            saver.save(
                sess,
                FLAGS.logdir + '/' + trained_model_dir + '/model' + str(itr))

            # Collect the three-layer global weights from the graph.
            w = []
            b = []
            w.append(sess.run('model/w1:0'))
            w.append(sess.run('model/w2:0'))
            w.append(sess.run('model/w3:0'))
            b.append(sess.run('model/b1:0'))
            b.append(sess.run('model/b2:0'))
            b.append(sess.run('model/b3:0'))

            ### save global weight ###
            # protocol=2 keeps the pickles readable from Python 2 -- confirm.
            with open(
                    FLAGS.logdir + '/' + trained_model_dir + "/two_layers" +
                    str(itr) + ".pkl", 'wb') as out:
                pickle.dump({'w': w, 'b': b}, out, protocol=2)

            ### save local weight ###
            with open(
                    FLAGS.logdir + '/' + trained_model_dir +
                    "/per_sub_weight" + str(itr) + ".pkl", 'wb') as out:
                pickle.dump({
                    'w': result[-2],
                    'b': result[-1]
                },
                            out,
                            protocol=2)

            ### validation ###
            # Load the just-saved global weights into the Keras-side model.
            two_layer.loadWeight(FLAGS.vae_model, w=w, b=b)
            print('--------------------------------------------------------')
            print("[Main] loaded soft bias to be evaluated: ", b[0])
            print("[Main] loaded z_mean bias to be evaluated : ", b[1][:4])
            print('--------------------------------------------------------')

            total_val_cnt = 0
            f1_scores = []
            mse = []
            for i in range(len(val_subjects)):
                eval_vec = all_val_feat_vec[i]
                eval_frame = all_val_frame[i]
                y_lab = data_generator.labels[i][eval_frame]
                print(
                    '---------------- len of eval_frame ---------------------')
                print(len(eval_frame))
                # Reconstruction error of the feature vectors.
                y_reconst = two_layer.model_reconst.predict(eval_vec)
                one_sub_mse = np.average(
                    np.power((np.array(eval_vec) - y_reconst), 2))
                print('==================== mse of {} ===================='.
                      format(val_subjects[i]))
                print(one_sub_mse)
                print(
                    '========================================================='
                )
                # One-hot encode the binary labels for print_summary.
                y_true = np.array([np.eye(2)[label] for label in y_lab])
                y_pred = two_layer.model_intensity.predict(eval_vec)
                print('y_true shape: ', y_true.shape)
                print('y_pred shape: ', y_pred.shape)
                total_val_cnt += int(y_true.shape[0])
                out = print_summary(y_pred,
                                    y_true,
                                    log_dir="./logs/result/" + "/test.txt")
                # Row 5 presumably holds per-AU f1-scores -- TODO confirm.
                # This assignment shadows the sklearn import above.
                f1_score = np.average(list(out['data'][5]))
                f1_scores.append(f1_score)
                mse.append(one_sub_mse)
            means = np.mean(f1_scores, 0)
            stds = np.std(f1_scores, 0)
            ci95 = 1.96 * stds / np.sqrt(total_val_cnt)
            mean_mse = np.average(mse)
            total_val_stat.append((means, stds, ci95))
            total_mse.append(mean_mse)
            print(
                '================================================================'
            )
            print('total_val_cnt: ', total_val_cnt)
            print(
                '(Mean validation f1-score, stddev, and confidence intervals), Mean reconst. loss'
            )
            # NOTE(review): the iteration label assumes SAVE_INTERVAL was
            # constant across all validation passes.
            for i in range(len(total_val_stat)):
                print('iter:', (i + 1) * SAVE_INTERVAL, total_val_stat[i],
                      total_mse[i])