def test():
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    """ Load model """
    input_image = Input(shape=(None, None, 3), name='image', dtype=tf.float32)
    region, affinity = VGG16_UNet(input_tensor=input_image, weights=None)
    model = Model(inputs=[input_image], outputs=[region, affinity])
    model.load_weights(FLAGS.trained_model)
    """ For test images in a folder """
    gt_all_imgs = load_data(
        "/home/ldf/CRAFT_keras/data/CTW/gt.pkl")  # SynthText, CTW
    t = time.time()
    PPR_list = list()
    num_list = list()
    dif_list = list()
    """ Test images """
    for k, [image_path, word_boxes, words,
            char_boxes_list] in enumerate(gt_all_imgs):

        image = load_image(image_path)
        start_time = time.time()
        bboxes, score_text = predict(model, image, FLAGS.text_threshold,
                                     FLAGS.link_threshold, FLAGS.low_text)
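        # CRAFT post-processing thresholds: text_threshold filters the character
        # region score, link_threshold filters the affinity (link) score between
        # characters, and low_text is the lower bound used when expanding text areas.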
        """ Compute single pic's PPR and num """
        PPR, single_num, diff = compute_PPR(bboxes, char_boxes_list)

        PPR_list.append(PPR)
        num_list.append(single_num)
        dif_list.append(np.mean(diff))

    print("elapsed time : {}s".format(time.time() - t))
    result, MPD = compute_final_result(PPR_list, num_list, dif_list)
    print("PPR", result)
    print("MPD", MPD)
Example #2
def interactive(args):
    config = Config(args)
    config.model_output = args.checkpoint_dir

    config.batch_size = 1

    logger.info("Loading data...")
    start = time.time()
    train_vocab = load_data(args.pointer,
                            interactive=True,
                            vocab_size=args.vocab_size)
    logger.info("took %.2f seconds", time.time() - start)

    with tf.Graph().as_default():
        logger.info("Building model...", )
        start = time.time()
        model = Summarizer(config, train_vocab, False, args.attention,
                           args.beam_search, args.bidirectional, args.pointer)
        logger.info("took %.2f seconds", time.time() - start)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(init)
            restore_checkpoint(sess, saver, config.model_output, args.pointer)
            # Interactive mode
            model.interactive(sess)
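A minimal sketch of what the restore_checkpoint helper used above could look like, assuming it simply loads the latest checkpoint from the output directory when one exists; the pointer argument is kept only for signature compatibility and ignored here:

def restore_checkpoint(sess, saver, model_output, pointer=False):
    # Restore the most recent checkpoint from model_output into the session, if any.
    ckpt = tf.train.latest_checkpoint(model_output)
    if ckpt is not None:
        saver.restore(sess, ckpt)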
Example #3
def train_avg_pay(data_path, data_type, metrics_path):
    """

    :param data_path: path to the data
    :param data_type: type of the data
    :param metrics_path: path where the loss metrics are recorded
    :return:
    """

    pay_feature_cat = [
        'policy', 'coverage', 'industry', 'occupation', 'city', 'avg_age'
    ]
    pay_feature_con = ['log_amount', 'log_num', 'sex_ratio']
    pay_label = 'avg_pay'
    pay_train, pay_validate, pay_test = \
        load_data(data_path, data_type, pay_feature_cat, pay_feature_con, pay_label)

    pay_model_path = data_path + 'insurance_price/keras_avg_pay/'
    pay_metrics_file = metrics_path + 'avg_pay_metrics.csv'

    dnn = Dnn(pay_model_path)
    dnn.dnn_net(input_shape=177)
    dnn.train_model(pay_metrics_file, pay_train, pay_label, epoch=150)
    dnn.save_model()
    # Training set
    pay_train_y, pay_train_yp = dnn.predict(pay_train, pay_label)
    pay_train_error = dnn.cal_error(pay_train_y, pay_train_yp)
    # Validation set
    pay_validate_y, pay_validate_yp = dnn.predict(pay_validate, pay_label)
    pay_validate_error = dnn.cal_error(pay_validate_y, pay_validate_yp)
    dnn.save_result(pay_model_path, pay_validate_y, pay_validate_yp)
    # Test set
    pay_test_y, pay_test_yp = dnn.predict(pay_test, pay_label)
    pay_test_error = dnn.cal_error(pay_test_y, pay_test_yp)

    with open(pay_metrics_file, 'a') as f:
        writer = csv.writer(f)
        writer.writerow(['train', ' '.join(str(i) for i in pay_train_error)])
        writer.writerow(['validation', ' '.join(str(i) for i in pay_validate_error)])
        writer.writerow(['test', ' '.join(str(i) for i in pay_test_error)])

    fmt = '%6.3f' * 3
    print('Average payment per claim:', 'train', fmt % pay_train_error, 'validation',
          fmt % pay_validate_error, 'test', fmt % pay_test_error)
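A hypothetical call to train_avg_pay; the paths and the data_type value below are placeholders, not values taken from this repository:

train_avg_pay(data_path='/path/to/data/',
              data_type='csv',
              metrics_path='/path/to/metrics/')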
Example #4
def train(args):
    config = Config(args)
    config.model_output = args.checkpoint_dir

    logger.info("Loading data...", )
    start = time.time()
    if not args.load:
        train_vocab, train, dev = load_and_preprocess_data(args)
    else:
        train_vocab, train, dev = load_data(args.pointer,
                                            vocab_size=args.vocab_size)

    logger.info("took %.2f seconds", time.time() - start)

    print("")

    plot_length_dist(train[0], "train_length_dist.png")
    plot_length_dist(train[1], "label_length_dist.png")
    #return

    with tf.Graph().as_default():
        logger.info("Building model...", )
        start = time.time()
        #model = test_model.Test(config, train_vocab, labels_vocab)
        model = Summarizer(config,
                           train_vocab,
                           True,
                           args.attention,
                           bidirectional=args.bidirectional,
                           pointer=args.pointer)
        logger.info("took %.2f seconds", time.time() - start)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(init)
            restore_checkpoint(sess, saver, config.model_output)
            model.fit(sess, saver, train, dev)
            #model.fit(sess, saver, args.data_train)
            return
Example #5
def predict(args):
    config = Config(args)
    config.model_output = args.checkpoint_dir

    #config.batch_size = 1
    logger.info("Loading data...")
    start = time.time()
    if not args.load:
        train_vocab, test = load_and_preprocess_data(args, predict=True)
    else:
        train_vocab, test = load_data(args.pointer,
                                      predict=True,
                                      vocab_size=args.vocab_size)
    test_articles, test_headlines = read_data(args.data_test,
                                              args.data_test_labels)
    logger.info("took %.2f seconds", time.time() - start)
    print("")
    with tf.Graph().as_default():
        logger.info("Building model...", )
        start = time.time()
        model = Summarizer(config, train_vocab, False, args.attention,
                           args.beam_search, args.bidirectional, args.pointer)
        logger.info("took %.2f seconds", time.time() - start)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(init)
            restore_checkpoint(sess, saver, config.model_output, args.pointer)
            logger.info("Predicting...")
            start = time.time()

            predicted = model.predict(sess, test)
            logger.info("took %.2f seconds", time.time() - start)
            save_predictions(test_headlines, predicted, args)
Example #6
		if args.which_data == 'cifar10':
			X.append(images.numpy()[0])
		else:
			if loaded_model.inputs[0].shape[1:] == images.numpy()[0].shape:
				X.append(images.numpy()[0])
			else:	
				if is_input_2d:
					X.append(images.numpy()[0].reshape(-1,)) # since (1,x,x,x)
				else:
					X.append(images.numpy()[0].reshape(1,-1))
		y.append(labels.item())

	X = np.asarray(X)
	y = np.asarray(y)
elif args.which_data in ['GTSRB', 'imdb', 'reuters', 'lfw', 'us_airline']: # gtsrb
	train_data, test_data = data_util.load_data(args.which_data, args.datadir, path_to_female_names = args.female_lst_file)
	if bool(args.is_train):
		X,y = train_data
	else:
		X,y = test_data

loaded_model = load_model(args.model)
loaded_model.summary()

ret_raw = False #True
if args.which_data in ['cifar10', 'GTSRB', 'lfw', 'us_airline']: # and also GTSRB
	predicteds = loaded_model.predict(X)
else:
	if loaded_model.inputs[0].shape[1:] == images.numpy()[0].shape:
		predicteds = loaded_model.predict(X)
	else:
Example #7
parser.add_argument("-dest", action="store", default=".")
parser.add_argument("-patch_aggr", action='store', default=None, type=int)
parser.add_argument("-num_label", type=int, default=10)
parser.add_argument("-batch_size", type=int, default=None)
parser.add_argument("-target_layer_idx",
                    action="store",
                    default=-1,
                    type=int,
                    help="an index to the layer to localiser nws")

args = parser.parse_args()
os.makedirs(args.dest, exist_ok=True)
loc_dest = os.path.join(args.dest, "loc")
os.makedirs(loc_dest, exist_ok=True)

train_data, test_data = data_util.load_data(args.which_data, args.datadir)
predef_indices_to_wrong = data_util.get_misclf_for_rq2(
    args.target_indices_file, percent=0.1, seed=args.seed)

num_label = args.num_label
iter_num = args.iter_num
t1 = time.time()

patched_model_name, indices_to_target_inputs, indices_to_patched = auto_patch.patch(
    num_label,
    test_data,
    target_layer_idx=args.target_layer_idx,
    max_search_num=iter_num,
    search_method='DE',
    which=args.which,
    loc_method=args.loc_method,
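Example #8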
def train():
    input_image = Input(shape=(None, None, 3), name='image', dtype=tf.float32)
    input_box = Input(shape=(None, 4, 2), name='word_box', dtype=tf.int32)
    input_word_length = Input(shape=(None,), name='word_length', dtype=tf.int32)
    input_region = Input(shape=(None, None), name='region', dtype=tf.float32)
    input_affinity = Input(shape=(None, None), name='affinity', dtype=tf.float32)
    input_confidence = Input(shape=(None, None), name='confidence', dtype=tf.float32)
    input_fg_mask = Input(shape=(None, None), name='fg_mask', dtype=tf.float32)
    input_bg_mask = Input(shape=(None, None), name='bg_mask', dtype=tf.float32)

    region, affinity = VGG16_UNet(input_tensor=input_image, weights='imagenet')

    region_gt = Lambda(lambda x: x)(input_region)
    affinity_gt = Lambda(lambda x: x)(input_affinity)
    confidence_gt = Lambda(lambda x: x)(input_confidence)

    loss_funs = [craft_mse_loss, craft_mae_loss, craft_huber_loss]

    loss_out = Lambda(loss_funs[2], output_shape=(1,), name='craft')(
        [region_gt, affinity_gt, region, affinity, confidence_gt, input_fg_mask, input_bg_mask])

    model = Model(inputs=[input_image, input_box, input_word_length, input_region,
                          input_affinity, input_confidence, input_fg_mask, input_bg_mask],
                  outputs=loss_out)

    callback_model = Model(inputs=[input_image, input_box, input_word_length, input_region,
                                   input_affinity, input_confidence],
                           outputs=[region, affinity, region_gt, affinity_gt])

    test_model = Model(inputs=[input_image], outputs=[region, affinity])
    test_model.summary()

    weight_path = 'weights/Syn_CTW_400k_linear_800.h5'
    if os.path.exists(weight_path):
        test_model.load_weights(weight_path)

    optimizer = Adam(lr=FLAGS.learning_rate)
    model.compile(loss={'craft': lambda y_true, y_pred: y_pred}, optimizer=optimizer)
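    # The model's single output is already the CRAFT loss value, so the Keras loss
    # for the 'craft' output simply passes the prediction through unchanged.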
    
    CTW_data_path = '/home/ldf/CRAFT_keras/data/CTW/'
    true_sample_list = load_data(os.path.join(FLAGS.truth_data_path, r'gt.pkl'))
    CTW_sample_list = load_data(os.path.join(CTW_data_path, 'gt.pkl'))
    num_Eng = len(CTW_sample_list) // 2
    np.random.shuffle(true_sample_list)
    train_sample_list = true_sample_list[:num_Eng] + CTW_sample_list

    np.random.shuffle(train_sample_list)

    if FLAGS.use_fake:
        pseudo_sample_list = load_data(os.path.join(FLAGS.pseudo_data_path, r'pse_gt.pkl'))
        np.random.shuffle(pseudo_sample_list)
        train_generator = SampleGenerator(test_model, [train_sample_list, pseudo_sample_list], [1, 5], [False, True],
                                          FLAGS.img_size, FLAGS.batch_size)

    else:
        train_generator = SampleGenerator(test_model, [train_sample_list], [1], [False],
                                          FLAGS.img_size, FLAGS.batch_size)


    steps_per_epoch = 1000

    tensor_board = CraftTensorBoard(log_dir=r'logs',
                                    write_graph=True,
                                    test_model=test_model,
                                    callback_model=callback_model,
                                    data_generator=train_generator,
                                    )

    model.fit_generator(generator=train_generator.next_train(),
                        steps_per_epoch=steps_per_epoch,
                        initial_epoch=0,
                        epochs=FLAGS.max_epochs,
                        callbacks=[train_generator, tensor_board]
                        )
Example #9
def main():
    print('load data...')
    priors, train, orders, products, aisles, departments, sample_submission, order_streaks = data_util.load_data()

    groupby_features_train = pd.DataFrame()
    groupby_features_test = pd.DataFrame()

    if (not os.path.exists(Configure.groupby_features_train_path)) or \
            (not os.path.exists(Configure.groupby_features_test_path)):

        # # Product part

        # Products information ----------------------------------------------------------------
        # add order information to priors set
        priors_orders_detail = orders.merge(right=priors, how='inner', on='order_id')

        # create new variables
        ## _user_buy_product_times: running count of how many times the user has bought this product
        priors_orders_detail.loc[:,'_user_buy_product_times'] = priors_orders_detail.groupby(['user_id', 'product_id']).cumcount() + 1
        # _prod_tot_cnts: total number of times the product was purchased (a proxy for how well it is liked)
        # _reorder_tot_cnts_of_this_prod: total number of times the product was reordered
        ### I find the next two hard to interpret; consider renaming them ++++++++++++++++++++++++++
        # _prod_order_once: total number of first-time purchases of the product
        # _prod_order_more_than_once: total number of times the product was purchased more than once
        agg_dict = {'user_id':{'_prod_tot_cnts':'count'}, 
                    'reordered':{'_prod_reorder_tot_cnts':'sum'}, 
                    '_user_buy_product_times': {'_prod_buy_first_time_total_cnt':lambda x: sum(x==1),
                                                '_prod_buy_second_time_total_cnt':lambda x: sum(x==2)}}
        prd = ka_add_groupby_features_1_vs_n(priors_orders_detail, ['product_id'], agg_dict)

        # _prod_reorder_prob: this metric is not easy to interpret
        # _prod_reorder_ratio: product reorder rate
        prd['_prod_reorder_prob'] = prd._prod_buy_second_time_total_cnt / prd._prod_buy_first_time_total_cnt
        prd['_prod_reorder_ratio'] = prd._prod_reorder_tot_cnts / prd._prod_tot_cnts
        prd['_prod_reorder_times'] = 1 + prd._prod_reorder_tot_cnts / prd._prod_buy_first_time_total_cnt
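        # Worked example: if 100 users bought a product at least once
        # (_prod_buy_first_time_total_cnt = 100) and 40 of them bought it a second
        # time (_prod_buy_second_time_total_cnt = 40), then _prod_reorder_prob = 0.4.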

        # # User Part

        # _user_total_orders: total number of orders per user
        # other summary statistics could be added here ++++++++++++++++++++++++++
        # _user_sum_days_since_prior_order: sum of days since the prior order; this can only be computed from the orders table, since priors_orders_detail is not unique at the order level
        # _user_mean_days_since_prior_order: mean days since the prior order
        agg_dict_2 = {'order_number':{'_user_total_orders':'max'},
                      'days_since_prior_order':{'_user_sum_days_since_prior_order':'sum', 
                                                '_user_mean_days_since_prior_order': 'mean'}}
        users = ka_add_groupby_features_1_vs_n(orders[orders.eval_set == 'prior'], ['user_id'], agg_dict_2)

        # _user_reorder_ratio: total number of reorders / total number of purchases after the first order
        # _user_total_products: total number of products the user has bought
        # _user_distinct_products: number of distinct products the user has bought
        # agg_dict_3 = {'reordered':
        #               {'_user_reorder_ratio': 
        #                lambda x: sum(priors_orders_detail.ix[x.index,'reordered']==1)/
        #                          sum(priors_orders_detail.ix[x.index,'order_number'] > 1)},
        #               'product_id':{'_user_total_products':'count', 
        #                             '_user_distinct_products': lambda x: x.nunique()}}
        # us = ka_add_groupby_features_1_vs_n(priors_orders_detail, ['user_id'], agg_dict_3)

        us = pd.concat([
        priors_orders_detail.groupby('user_id')['product_id'].count().rename('_user_total_products'),
        priors_orders_detail.groupby('user_id')['product_id'].nunique().rename('_user_distinct_products'),
        (priors_orders_detail.groupby('user_id')['reordered'].sum() /
        priors_orders_detail[priors_orders_detail['order_number'] > 1].groupby('user_id')['order_number'].count()).rename('_user_reorder_ratio')
        ], axis=1).reset_index()
        users = users.merge(us, how='inner')

        # average number of products per order
        # max / min number of products per order ++++++++++++++
        users['_user_average_basket'] = users._user_total_products / users._user_total_orders

        us = orders[orders.eval_set != "prior"][['user_id', 'order_id', 'eval_set', 'days_since_prior_order']]
        us.rename(index=str, columns={'days_since_prior_order': 'time_since_last_order'}, inplace=True)

        users = users.merge(us, how='inner')


        # # Database Part

        # many more features could still be added here
        # _up_order_count: number of times the user bought the product
        # _up_first_order_number: order number of the user's first purchase of the product
        # _up_last_order_number: order number of the user's last purchase of the product
        # _up_average_cart_position: average add-to-cart position of the product
        agg_dict_4 = {'order_number':{'_up_order_count': 'count', 
                                      '_up_first_order_number': 'min', 
                                      '_up_last_order_number':'max'}, 
                      'add_to_cart_order':{'_up_average_cart_position': 'mean'}}

        data = ka_add_groupby_features_1_vs_n(df=priors_orders_detail, 
                                                              group_columns_list=['user_id', 'product_id'], 
                                                              agg_dict=agg_dict_4)

        data = data.merge(prd, how='inner', on='product_id').merge(users, how='inner', on='user_id')
        # number of purchases of the product / total number of orders
        # most recent order number - last order in which the product was bought
        # number of purchases of the product / number of orders between the first purchase and the most recent order
        data['_up_order_rate'] = data._up_order_count / data._user_total_orders
        data['_up_order_since_last_order'] = data._user_total_orders - data._up_last_order_number
        data['_up_order_rate_since_first_order'] = data._up_order_count / (data._user_total_orders - data._up_first_order_number + 1)
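        # Worked example: a user with 10 orders who bought the product 3 times,
        # first in order 4 and last in order 9, gets _up_order_rate = 3/10,
        # _up_order_since_last_order = 10 - 9 = 1, and
        # _up_order_rate_since_first_order = 3 / (10 - 4 + 1) = 3/7.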

        # add user_id to train set
        train = train.merge(right=orders[['order_id', 'user_id']], how='left', on='order_id')
        data = data.merge(train[['user_id', 'product_id', 'reordered']], on=['user_id', 'product_id'], how='left')
        data = pd.merge(data, products[['product_id', 'aisle_id', 'department_id']], how='left', on='product_id')
        transform_categorical_data(data, ['aisle_id', 'department_id'])
        data = data.merge(order_streaks[['user_id', 'product_id', 'order_streak']], on=['user_id', 'product_id'], how='left')


        # release Memory
        # del train, prd, users
        # gc.collect()
        # release Memory
        #del priors_orders_detail
        del orders, order_streaks
        gc.collect()

        starting_size = sys.getsizeof(data)
        i = 0
        for c, dtype in zip(data.columns, data.dtypes):
            if 'int' in str(dtype):
                if min(data[c]) >=0:
                    max_int =  max(data[c])
                    if max_int <= 255:
                        data[c] = data[c].astype(np.uint8)
                    elif max_int <= 65535:
                        data[c] = data[c].astype(np.uint16)
                    elif max_int <= 4294967295:
                        data[c] = data[c].astype(np.uint32)
                    i += 1
        print("Number of colums adjusted: {}\n".format(i))
        ## Changing known reorderd col to smaller int size
        data['reordered'] = np.nan_to_num(data['reordered']).astype(np.uint8)
        data.loc[data['reordered'] == 0, 'reordered'] = np.nan
        print("Reduced size {:.2%}".format(float(sys.getsizeof(data))/float(starting_size)))


        # # Create Train / Test
        train = data.loc[data.eval_set == "train",:]
        #train.drop(['eval_set', 'user_id', 'product_id', 'order_id'], axis=1, inplace=True)
        #train.loc[:, 'reordered'] = train.reordered.fillna(0)

        test = data.loc[data.eval_set == "test",:]
        #test.drop(['eval_set', 'user_id', 'product_id', 'order_id', 'reordered'], axis=1, inplace=True)
        #groupby_features_train = train
        #groupby_features_test = test

        # with open(Configure.groupby_features_train_path, "wb") as f:
        #     cPickle.dump(groupby_features_train, f, -1)
        # with open(Configure.groupby_features_test_path, "wb") as f:
        #     cPickle.dump(groupby_features_test, f, -1)

        print('train:', train.shape, ', test:', test.shape)
        print("Save data...")
        data_util.save_dataset(train, test)
        
    else:
        with open(Configure.groupby_features_train_path, "rb") as f:
            groupby_features_train = cPickle.load(f)
        with open(Configure.groupby_features_test_path, "rb") as f:
            groupby_features_test = cPickle.load(f)
Example #10
def train():
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    input_image = Input(shape=(None, None, 3), name='image', dtype=tf.float32)
    input_box = Input(shape=(None, 4, 2), name='word_box', dtype=tf.int32)
    input_word_length = Input(shape=(None,), name='word_length', dtype=tf.int32)
    input_region = Input(shape=(None, None), name='region', dtype=tf.float32)
    input_affinity = Input(shape=(None, None), name='affinity', dtype=tf.float32)
    input_confidence = Input(shape=(None, None), name='confidence', dtype=tf.float32)
    input_fg_mask = Input(shape=(None, None), name='fg_mask', dtype=tf.float32)
    input_bg_mask = Input(shape=(None, None), name='bg_mask', dtype=tf.float32)

    region, affinity = VGG16_UNet(input_tensor=input_image, weights='imagenet')

    # if FLAGS.use_fake:
    #     region_gt, affinity_gt, confidence_gt = \
    #         Fake(input_box, input_word_length, input_region, input_affinity, input_confidence, name='fake')(region)
    # else:
    #     region_gt = Lambda(lambda x: x)(input_region)
    #     affinity_gt = Lambda(lambda x: x)(input_affinity)
    #     confidence_gt = Lambda(lambda x: x)(input_confidence)
    region_gt = Lambda(lambda x: x)(input_region)
    affinity_gt = Lambda(lambda x: x)(input_affinity)
    confidence_gt = Lambda(lambda x: x)(input_confidence)

    loss_funs = [craft_mse_loss, craft_mae_loss, craft_huber_loss]

    loss_out = Lambda(loss_funs[2], output_shape=(1,), name='craft')(
        [region_gt, affinity_gt, region, affinity, confidence_gt, input_fg_mask, input_bg_mask])

    model = Model(inputs=[input_image, input_box, input_word_length, input_region,
                          input_affinity, input_confidence, input_fg_mask, input_bg_mask],
                  outputs=loss_out)

    callback_model = Model(inputs=[input_image, input_box, input_word_length, input_region,
                                   input_affinity, input_confidence],
                           outputs=[region, affinity, region_gt, affinity_gt])

    test_model = Model(inputs=[input_image], outputs=[region, affinity])
    test_model.summary()

    weight_path = r'weights/weight.h5'
    if os.path.exists(weight_path):
        test_model.load_weights(weight_path)

    # optimizer = SGD(lr=FLAGS.learning_rate, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
    optimizer = Adam(lr=FLAGS.learning_rate)
    model.compile(loss={'craft': lambda y_true, y_pred: y_pred}, optimizer=optimizer)

    true_sample_list = load_data(os.path.join(FLAGS.truth_data_path, r'gt.pkl'))

    train_sample_list = true_sample_list

    np.random.shuffle(train_sample_list)

    if FLAGS.use_fake:
        pseudo_sample_list = load_data(os.path.join(FLAGS.pseudo_data_path, r'gt.pkl'))
        np.random.shuffle(pseudo_sample_list)
        train_generator = SampleGenerator(test_model, [train_sample_list, pseudo_sample_list], [5, 1], [False, True],
                                          FLAGS.img_size, FLAGS.batch_size)
        # tensor_board_data_generator = SampleGenerator(test_model, [pseudo_sample_list], [1], [True],
        #                                               FLAGS.img_size, FLAGS.batch_size)
    else:
        train_generator = SampleGenerator(test_model, [train_sample_list], [1], [False],
                                          FLAGS.img_size, FLAGS.batch_size)
        # tensor_board_data_generator = SampleGenerator(test_model, [train_sample_list], [1], [False],
        #                                               FLAGS.img_size, FLAGS.batch_size)

    # train_generator.init_sample(True)

    # val_pkl_path = os.path.join(FLAGS.val_data_path, r'gt.pkl')
    # if os.path.exists(val_pkl_path):
    #     val_sample_list = load_data(val_pkl_path)

    steps_per_epoch = 1000

    tensor_board = CraftTensorBoard(log_dir=r'logs',
                                    write_graph=False,
                                    test_model=test_model,
                                    callback_model=callback_model,
                                    data_generator=train_generator,
                                    )

    model.fit_generator(generator=train_generator.next_train(),
                        steps_per_epoch=steps_per_epoch,
                        initial_epoch=0,
                        epochs=FLAGS.max_epochs,
                        callbacks=[train_generator, tensor_board]
                        )