Example #1
def test(**kwargs):
    conf.parse(kwargs)

    model = Network().eval()

    if conf.LOAD_MODEL_PATH:
        print(conf.LOAD_MODEL_PATH)
        model.load_state_dict(torch.load(conf.CHECKPOINTS_ROOT + conf.LOAD_MODEL_PATH))

    device = torch.device('cuda:0' if conf.USE_GPU else 'cpu')
    model.to(device)

    test_set = ImageFolder(conf.TEST_DATA_ROOT, transform)
    test_loader = DataLoader(test_set, conf.BATCH_SIZE,
                             shuffle=False,
                             num_workers=conf.NUM_WORKERS)

    results = list()

    with torch.no_grad():
        for step, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outs = model(inputs)
            pred = torch.max(outs, 1)[1]
            # print((targets == pred).float())
            # (prob_top_k, idxs_top_k) = probability.topk(3, dim=1)

            acc = (pred == targets).float().sum() / len(targets)
            results += ((pred == targets).float().to('cpu').numpy().tolist())

            print('[%5d] acc: %.3f' % (step + 1, acc))

        results = np.array(results)
        print('Top 1 acc: %.3f' % (np.sum(results) / len(results)))
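Entry points like `test(**kwargs)` above (and `train(**kwargs)` in Example #7) parse keyword overrides into a global `conf` object, which pairs naturally with a command-line dispatcher. A minimal sketch, assuming python-fire is available and that the flags simply mirror `conf`'s attribute names (both are assumptions, not shown in the example):

import fire  # assumption: the project dispatches its entry points with python-fire

if __name__ == '__main__':
    # e.g. `python main.py test --LOAD_MODEL_PATH=best.pth --USE_GPU=False`
    fire.Fire()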
Example #2
File: test.py Project: 530824679/YOLOv2
def predict_image():
    image_path = "/home/chenwei/HDD/Project/datasets/object_detection/FDDB2016/convert/images/2002_07_19_big_img_130.jpg"

    image = cv2.imread(image_path)
    image_size = image.shape[:2]
    input_shape = [model_params['input_height'], model_params['input_width']]
    image_data = pre_process(image, input_shape)
    image_data = image_data[np.newaxis, ...]

    input = tf.placeholder(shape=[1, None, None, 3], dtype=tf.float32)

    network = Network(is_train=False)
    logits = network.build_network(input)
    output = network.reorg_layer(logits, model_params['anchors'])

    checkpoints = "./checkpoints/model.ckpt-128"
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, checkpoints)
        bboxes, obj_probs, class_probs = sess.run(
            output, feed_dict={input: image_data})

    bboxes, scores, class_id = postprocess(bboxes,
                                           obj_probs,
                                           class_probs,
                                           image_shape=image_size,
                                           input_shape=input_shape)

    img_detection = visualization(image, bboxes, scores, class_id,
                                  model_params["classes"])
    cv2.imshow("result", img_detection)
    cv2.waitKey(0)
Example #3
def test_classify_digits():
    retina = Retina(32)
    layer_level1 = Layer(8, 'layer_1')
    layer_level2 = Layer(4, 'layer_2')
    layer_level3 = Layer(1, 'layer_3')
    layers = [layer_level1, layer_level2, layer_level3]
    ConnectTypes.rectangle_connect(retina.vision_cells, layer_level1, 0, 0)
    ConnectTypes.rectangle_connect(layer_level1.nodes, layer_level2, 0, 0)
    ConnectTypes.rectangle_connect(layer_level2.nodes, layer_level3, 0, 0)

    network = Network(layers, retina)
    cca_v1 = CommonCorticalAlgorithmV1(network)

    number_training_timesteps = 1
    t = 0
    print_to_console = True
    # train network on digit dataset to form memory and temporal groups
    with ZipFile('model/datasets/digit_0.zip') as archive:
        for entry in archive.infolist():
            with archive.open(entry) as file:
                binary_image = Image.open(file)
                if print_to_console:
                    print('timestep = ' + str(t))
                input_layer = retina.see_binary_image(binary_image, print_to_console)

                # run 1 time step for all levels in hierarchy?
                cca_v1.learn_one_time_step(input_layer)
                _save_model_at_current_timestep(t, network)
                t += 1
                # now we have trained the network using cca_v1 on data set
                if t >= number_training_timesteps:
                    break
Example #4
    def __init__(self,
                 num_samples,
                 burn_in,
                 population_size,
                 topology,
                 train_data,
                 test_data,
                 directory,
                 temperature,
                 swap_sample,
                 parameter_queue,
                 problem_type,
                 main_process,
                 event,
                 active_chains,
                 num_accepted,
                 swap_interval,
                 max_limit=5,
                 min_limit=-5):
        # Multiprocessing attributes
        multiprocessing.Process.__init__(self)
        self.process_id = temperature
        self.parameter_queue = parameter_queue
        self.signal_main = main_process
        self.event = event
        self.active_chains = active_chains
        self.num_accepted = num_accepted
        self.event.clear()
        self.signal_main.clear()
        # Parallel Tempering attributes
        self.temperature = temperature
        self.swap_sample = swap_sample
        self.swap_interval = swap_interval
        self.burn_in = burn_in
        # MCMC attributes
        self.num_samples = num_samples

        self.topology = topology
        self.pop_size = population_size
        self.train_data = train_data
        self.test_data = test_data
        self.problem_type = problem_type
        self.directory = directory
        self.w_size = (topology[0] * topology[1]) + (
            topology[1] * topology[2]) + topology[1] + topology[2]
        self.neural_network = Network(topology, train_data, test_data)
        self.min_limits = np.repeat(min_limit, self.w_size)
        self.max_limits = np.repeat(max_limit, self.w_size)
        self.initialize_sampling_parameters()
        max_limit_vel = (self.weights_stepsize) * (self.weights_stepsize) * 10
        min_limit_vel = self.weights_stepsize * self.weights_stepsize * -10
        self.min_limits_vel = np.repeat(min_limit_vel, self.w_size)
        self.max_limits_vel = np.repeat(max_limit_vel, self.w_size)
        self.create_directory(directory)
        PSO.__init__(self, self.pop_size, self.w_size, self.max_limits,
                     self.min_limits, self.neural_network.evaluate_fitness,
                     self.problem_type, self.max_limits_vel,
                     self.min_limits_vel)
Example #5
def predict():
    fasterRCNN = Network()
    fasterRCNN.build(is_training=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, os.path.join(CHECKPOINTS_PATH, "model_final.ckpt"))
        print("Model restored.")
        base_extractor = VGG16(include_top=False)
        extractor = Model(inputs=base_extractor.input, outputs=base_extractor.get_layer('block5_conv3').output)
        predict_img_names = os.listdir(PREDICT_IMG_DATA_PATH)

        for predict_img_name in predict_img_names:
            img_data, img_info = get_predict_data(predict_img_name)
            features = extractor.predict(img_data, steps=1)
            rois, scores, regression_parameter = sess.run(
                [fasterRCNN._predictions["rois"], fasterRCNN._predictions["cls_prob"],
                 fasterRCNN._predictions["bbox_pred"]],
                feed_dict={fasterRCNN.feature_map: features,
                           fasterRCNN.image_info: img_info})

            boxes = rois[:, 1:5] / img_info[2]
            scores = np.reshape(scores, [scores.shape[0], -1])
            regression_parameter = np.reshape(regression_parameter, [regression_parameter.shape[0], -1])
            pred_boxes = bbox_transform_inv(boxes, regression_parameter)
            pred_boxes = clip_boxes(pred_boxes, [img_info[0] / img_info[2], img_info[1] / img_info[2]])

            result_list = []
            for class_index, class_name in enumerate(CLASSES[1:]):
                class_index += 1  # offset by 1 because the background class is skipped
                cls_boxes = pred_boxes[:, 4 * class_index:4 * (class_index + 1)]  # TODO:
                cls_scores = scores[:, class_index]
                detections = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
                keep = nms(detections, NMS_THRESH)
                detections = detections[keep, :]

                inds = np.where(detections[:, -1] >= CONF_THRESH)[0]  # filter detections by confidence threshold
                for i in inds:
                    result_for_a_class = []
                    bbox = detections[i, :4]
                    score = detections[i, -1]
                    result_for_a_class.append(predict_img_name)
                    result_for_a_class.append(class_name)
                    result_for_a_class.append(score)
                    for coordinate in bbox:
                        result_for_a_class.append(coordinate)
                    result_list.append(result_for_a_class)
                    # result_for_a_class = [fileName,class_name,score,x1,y1,x2,y2]
            if len(result_list) == 0:
                continue

            if TXT_RESULT_WANTED:
                write_txt_result(result_list)

            if IS_VISIBLE:
                visualization(result_list)
Example #6
File: main.py Project: Carl-Rabbit/IMP
def run(network_lines, k, model, network_input=None):
    global rrset_func, heap, top_idx

    if model == 'IC':
        rrset_func = get_ic_rrset
    elif model == 'LT':
        rrset_func = get_lt_rrset
    else:
        raise Exception('model type error')

    network: Network

    if network_input:
        network = network_input
    else:
        network = Network(network_lines)

    # step 1
    # t0 = time.time()
    # rrsets: List[Set] = sampling(network, k, e, l)
    rrsets: List[List] = sampling(network, k, model)

    del network

    # step 2
    # print(f'len(rrsets) = {len(rrsets)}')

    # t1 = time.time()
    # init vtx2sid_lst
    vtx2sid_list = get_vtx2sid_lst(rrsets)
    del rrsets

    # init heap
    init_heap(vtx2sid_list)
    del vtx2sid_list

    # t2 = time.time()

    # node selection
    seeds = node_selection(k)
    # t3 = time.time()

    # print(f'sampling: {t1 - t0}\ninit heap: {t2 - t1}\nselection: {t3 - t2}')

    # add 1 to each seed because vertex indices are stored starting at 0
    output_lst = [str(i + 1) for i in seeds]
    print('\n'.join(output_lst))

    print('write to test folder')
    write_lines(
        '../DatasetOnTestPlatform/my_' + str(model).lower() + '_seeds.txt',
        output_lst)
Example #7
def train(**kwargs):
    conf.parse(kwargs)

    # train_set = DataSet(cfg, train=True, test=False)
    train_set = ImageFolder(conf.TRAIN_DATA_ROOT, transform)
    train_loader = DataLoader(train_set, conf.BATCH_SIZE,
                              shuffle=True,
                              num_workers=conf.NUM_WORKERS)

    model = Network()

    if conf.LOAD_MODEL_PATH:
        print(conf.LOAD_MODEL_PATH)
        model.load_state_dict(torch.load(conf.CHECKPOINTS_ROOT + conf.LOAD_MODEL_PATH))

    device = torch.device('cuda:0' if conf.USE_GPU else 'cpu')
    criterion = nn.CrossEntropyLoss().to(device)
    lr = conf.LEARNING_RATE
    optim = torch.optim.Adam(params=model.parameters(),
                             lr=lr,
                             weight_decay=conf.WEIGHT_DECAY)
    model.to(device)

    for epoch in range(conf.MAX_EPOCH):

        model.train()
        running_loss = 0
        for step, (inputs, targets) in tqdm(enumerate(train_loader)):

            inputs, targets = inputs.to(device), targets.to(device)
            optim.zero_grad()
            outs = model(inputs)
            loss = criterion(outs, targets)
            loss.backward()
            optim.step()

            running_loss += loss.item()
            if step % conf.PRINT_FREQ == conf.PRINT_FREQ - 1:
                running_loss = running_loss / conf.PRINT_FREQ
                print('[%d, %5d] loss: %.3f' % (epoch + 1, step + 1, running_loss))
                # vis.plot('loss', running_loss)
                running_loss = 0



        torch.save(model.state_dict(), conf.CHECKPOINTS_ROOT + time.strftime('%Y-%m-%d-%H-%M-%S.pth'))

        # decay the learning rate once per epoch and apply it to every param group
        lr *= conf.LEARNING_RATE_DECAY
        for param_group in optim.param_groups:
            param_group['lr'] = lr
Example #8
def train():
    fasterRCNN = Network()
    fasterRCNN.build(is_training=True)
    train_op = tf.train.MomentumOptimizer(learning_rate=0.001,
                                          momentum=0.9).minimize(
                                              fasterRCNN._losses['total_loss'])
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)

        base_extractor = VGG16(include_top=False)
        extractor = Model(
            inputs=base_extractor.input,
            outputs=base_extractor.get_layer('block5_conv3').output)
        train_img_names = os.listdir(TRAIN_IMG_DATA_PATH)
        trained_times = 0

        for epoch in range(1, MAX_EPOCH + 1):
            random.shuffle(train_img_names)
            for train_img_name in train_img_names:
                img_data, boxes, img_info = get_train_data(train_img_name)
                features = extractor.predict(img_data, steps=1)
                sess.run(train_op,
                         feed_dict={
                             fasterRCNN.feature_map: features,
                             fasterRCNN.gt_boxes: boxes,
                             fasterRCNN.image_info: img_info
                         })

                trained_times += 1
                if trained_times % 10 == 0:
                    total_loss = sess.run(fasterRCNN._losses['total_loss'],
                                          feed_dict={
                                              fasterRCNN.feature_map: features,
                                              fasterRCNN.gt_boxes: boxes,
                                              fasterRCNN.image_info: img_info
                                          })
                    print('epoch:{}, trained_times:{}, loss:{}'.format(
                        epoch, trained_times, total_loss))

            if epoch % 10 == 0:
                save_path = saver.save(
                    sess,
                    os.path.join(CHECKPOINTS_PATH,
                                 "model_" + str(epoch) + ".ckpt"))
                print("Model saved in path: %s" % save_path)
        save_path = saver.save(
            sess, os.path.join(CHECKPOINTS_PATH, "model_final.ckpt"))
        print("Model saved in path: %s" % save_path)
Example #9
def get_drugs_related_info(disease_pairs):
    networks = []
    for i in disease_pairs:
        networks.append(Network())
    networks = get_common_drugs(disease_pairs, networks, True)
    drugs = []
    for network in networks:
        pair_drugs = network.get_nodes_by_label('Drug')
        pair_drugs_ids = []
        for d in pair_drugs:
            pair_drugs_ids.append(d.id)
        drugs.append(pair_drugs_ids)
    networks = get_given_drugs_related_info(disease_pairs, drugs)
    return networks
Example #10
def predict():
    def map_char(char):
        return labels[char]

    if request.method == "POST":
        img = request.get_json()
        img = preprocess(img)
        # plt.imshow(img[0])
        # plt.savefig('image_debug.jpg')
        # plt.close()
        net = Network(network_config)
        character, confidence = net.predict_with_pretrained_weights(img)
        print('received', character, confidence)
        data = {
            "character": map_char(character),
            "confidence": float(int(confidence * 100)) / 100.
        }
        return jsonify(data)
Example #11
File: test.py Project: 530824679/YOLOv2
def predict_video():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    capture = cv2.VideoCapture(0)

    input = tf.placeholder(shape=[1, None, None, 3], dtype=tf.float32)

    network = Network(is_train=False)
    logits = network.build_network(input)
    output = network.reorg_layer(logits, model_params['anchors'])

    checkpoints = "./checkpoints/model.ckpt-128"
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        saver.restore(sess, checkpoints)

        while True:
            ref, image = capture.read()

            image_size = image.shape[:2]
            input_shape = [
                model_params['input_height'], model_params['input_width']
            ]
            image_data = pre_process(image, input_shape)
            image_data = image_data[np.newaxis, ...]

            bboxes, obj_probs, class_probs = sess.run(
                output, feed_dict={input: image_data})

            bboxes, scores, class_id = postprocess(bboxes,
                                                   obj_probs,
                                                   class_probs,
                                                   image_shape=image_size,
                                                   input_shape=input_shape)

            img_detection = visualization(image, bboxes, scores, class_id,
                                          model_params["classes"])
            cv2.imshow("result", img_detection)
            cv2.waitKey(1)

    cv2.destroyAllWindows()
Example #12
File: test.py Project: 530824679/YOLOv3
def predict_image():
    image_path = "/home/chenwei/HDD/Project/datasets/object_detection/VOCdevkit/VOC2007/JPEGImages/000066.jpg"
    image = cv2.imread(image_path)
    image_size = image.shape[:2]
    input_shape = [model_params['input_height'], model_params['input_width']]
    image_data = preporcess(image, input_shape)
    image_data = image_data[np.newaxis, ...]

    input = tf.placeholder(shape=[1, input_shape[0], input_shape[1], 3],
                           dtype=tf.float32)

    model = Network(len(model_params['classes']),
                    model_params['anchors'],
                    is_train=False)
    with tf.variable_scope('yolov3'):
        logits = model.build_network(input)
        output = model.inference(logits)

    checkpoints = "/home/chenwei/HDD/Project/YOLOv3/weights/yolov3.ckpt"
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, checkpoints)
        bboxes, obj_probs, class_probs = sess.run(
            output, feed_dict={input: image_data})

    bboxes, scores, class_max_index = postprocess(bboxes,
                                                  obj_probs,
                                                  class_probs,
                                                  image_shape=image_size,
                                                  input_shape=input_shape)

    resize_ratio = min(input_shape[1] / image_size[1],
                       input_shape[0] / image_size[0])
    dw = (input_shape[1] - resize_ratio * image_size[1]) / 2
    dh = (input_shape[0] - resize_ratio * image_size[0]) / 2
    bboxes[:, [0, 2]] = (bboxes[:, [0, 2]] - dw) / resize_ratio
    bboxes[:, [1, 3]] = (bboxes[:, [1, 3]] - dh) / resize_ratio

    img_detection = visualization(image, bboxes, scores, class_max_index,
                                  model_params["classes"])
    cv2.imshow("result", img_detection)
    cv2.waitKey(0)
Example #13
def get_disease_pairs_info(disease_pairs, writing_files):
    networks = []
    for disease_pair in disease_pairs:
        networks.append(Network())
    networks = get_common_genes(disease_pairs, networks, writing_files)
    networks = get_common_drugs(disease_pairs, networks, writing_files)
    networks = get_common_rnas(disease_pairs, networks, writing_files)
    networks = get_common_variants(disease_pairs, networks, writing_files)
    if writing_files:
        for index, disease_pair in enumerate(disease_pairs):
            temp_id1 = disease_pair[0].replace(':', '-')
            temp_id2 = disease_pair[1].replace(':', '-')
            path = '../analysis/disease_pairs/' + temp_id1 + '_' + temp_id2
            try:
                os.mkdir(path)
            except FileExistsError:
                pass
            network = networks[index]
            network.save(path + '/' + temp_id1 + '_' + temp_id2 + '_full_graph.json')
    return networks
Example #14
def remove_optimizers_params():
    ckpt_path = ''
    class_num = 2
    save_dir = 'shrinked_ckpt'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    anchors = [[676, 197], [763, 250], [684, 283], [868, 231], [745, 273],
               [544, 391], [829, 258], [678, 316, 713, 355]]

    image = tf.placeholder(tf.float32, [1, 416, 416, 3])
    model = Network(class_num, anchors, False)
    with tf.variable_scope('yolov3'):
        feature_maps = model.build_network(image)

    saver_to_restore = tf.train.Saver()
    saver_to_save = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver_to_restore.restore(sess, ckpt_path)
        saver_to_save.save(sess, save_dir + '/shrinked')
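Because the graph in this example defines only the model variables (no optimizer is ever built), restoring from the full training checkpoint and immediately re-saving yields a checkpoint without optimizer slot variables. One way to sanity-check the result is to list what the shrunk checkpoint actually contains; a minimal sketch, assuming the save path used above:

import tensorflow as tf

# print every variable name/shape stored in the shrunk checkpoint (path from the save above)
for name, shape in tf.train.list_variables('shrinked_ckpt/shrinked'):
    print(name, shape)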
Example #15
File: main.py Project: Carl-Rabbit/IMP
def imm_main():
    global start_time, time_limit

    file_name, seed_cnt, model, time_limit = param_parse()

    # time_limit *= 10      # for test

    start_time = time.time()

    # print_tmp('Params:', file_name, seed_cnt, model, time_limit)

    # network_lines = read_lines(file_name)
    # run(network_lines, seed_cnt, model)

    network = Network()
    with open(file_name, 'r') as fp:
        network.parse_first_line(fp.readline())
        for _ in range(network.m):
            network.parse_data_line(fp.readline())
    run(None, seed_cnt, model, network)

    sys.stdout.flush()
    print('time cost: ', time.time() - start_time)
Example #16
def weights_to_ckpt():
    num_class = 80
    image_size = 416
    anchors = [[676, 197], [763, 250], [684, 283], [868, 231], [745, 273],
               [544, 391], [829, 258], [678, 316, 713, 355]]
    weight_path = '../weights/yolov3.weights'
    save_path = '../weights/yolov3.ckpt'

    model = Network(num_class, anchors, False)
    with tf.Session() as sess:
        inputs = tf.placeholder(tf.float32, [1, image_size, image_size, 3])

        with tf.variable_scope('yolov3'):
            feature_maps = model.build_network(inputs)

        saver = tf.train.Saver(var_list=tf.global_variables(scope='yolov3'))

        load_ops = load_weights(tf.global_variables(scope='yolov3'),
                                weight_path)
        sess.run(load_ops)
        saver.save(sess, save_path=save_path)
        print('TensorFlow model checkpoint has been saved to {}'.format(
            save_path))
Example #17
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def parse_args():
	parser = argparse.ArgumentParser()
	parser.add_argument('image', help='Load image to classify')
	parser.add_argument('model', help='Load trained model')

	return parser.parse_args()

def to_img(tensor: torch.Tensor):
	return ToPILImage()(tensor.cpu().view(1, 28, 28))

if __name__ == '__main__':
	args = parse_args()

	net = Network(1, 64, 5, 10).to(device)
	net.load_state_dict(torch.load(args.model, map_location=device))

	image = Image.open(args.image).convert('L').resize((28, 28))
	image = ToTensor()(image).to(device)

	image = image.view(1, *image.shape)
	
	pred, feat = net.predict_with_feature(image)
	pred = F.softmax(pred, dim=1)

	grid = plt.GridSpec(3, 2, wspace=0.2, hspace=0.2)

	orig_img_plot = plt.subplot(grid[0, 0])
	orig_img_plot.set_title('input image')
	orig_img_plot.xaxis.set_major_locator(plt.NullLocator())
Example #18
    print(f"{gct()} : model init")
    det = RFDet(
        cfg.TRAIN.score_com_strength,
        cfg.TRAIN.scale_com_strength,
        cfg.TRAIN.NMS_THRESH,
        cfg.TRAIN.NMS_KSIZE,
        cfg.TRAIN.TOPK,
        cfg.MODEL.GAUSSIAN_KSIZE,
        cfg.MODEL.GAUSSIAN_SIGMA,
        cfg.MODEL.KSIZE,
        cfg.MODEL.padding,
        cfg.MODEL.dilation,
        cfg.MODEL.scale_list,
    )
    des = HardNetNeiMask(cfg.HARDNET.MARGIN, cfg.MODEL.COO_THRSH)
    model = Network(det, des, cfg.LOSS.SCORE, cfg.LOSS.PAIR, cfg.PATCH.SIZE,
                    cfg.TRAIN.TOPK)

    print(f"{gct()} : to device")
    device = torch.device("cuda")
    model = model.to(device)
    resume = args.resume
    print(f"{gct()} : in {resume}")
    checkpoint = torch.load(resume)
    model.load_state_dict(checkpoint["state_dict"])

    ###############################################################################
    # detect and compute
    ###############################################################################
    img1_path, img2_path = args.imgpath.split("@")
    kp1, des1, img1, _, _ = model.detectAndCompute(img1_path, device,
                                                   (600, 460))
Example #19
        f.write(os.path.join(config['Neo4j']['bin-path'], 'neo4j-admin'))
        f.write(' import ' +
                '--database %s ' % config['Neo4j']['database-name'] +
                ' '.join(['--nodes %s' % x
                          for x in node_import_files]) + ' ' + ' '.join([
                              '--relationships rel_%s.csv' % x
                              for x in network.edge_labels()
                          ]) + ' > import.log\n')


if __name__ == '__main__':
    with io.open('../data/config.json', 'r', encoding='utf-8',
                 newline='') as f:
        config = json.load(f)

    network = Network()
    # Import
    graphs = [
        '../data/EBI-GOA-miRNA/graph.json',
        '../data/miRTarBase/graph.json',
        '../data/RNAInter/graph.json',
        '../data/DisGeNet/graph.json',
        '../data/DrugBank/graph.json',
        '../data/DrugCentral/graph.json',
        '../data/GWAS-Catalog/graph.json',
        '../data/HGNC/graph.json',
        '../data/HPO/graph.json',
        '../data/MED-RT/graph.json',
        '../data/NDF-RT/graph.json',
        '../data/OMIM/graph.json',
        '../data/HuGE-Navigator/graph.json',
Example #20
    # Creating CNN model
    det = RFDet(
        cfg.TRAIN.score_com_strength,
        cfg.TRAIN.scale_com_strength,
        cfg.TRAIN.NMS_THRESH,
        cfg.TRAIN.NMS_KSIZE,
        args.k,
        cfg.MODEL.GAUSSIAN_KSIZE,
        cfg.MODEL.GAUSSIAN_SIGMA,
        cfg.MODEL.KSIZE,
        cfg.MODEL.padding,
        cfg.MODEL.dilation,
        cfg.MODEL.scale_list,
    )
    des = HardNetNeiMask(cfg.HARDNET.MARGIN, cfg.MODEL.COO_THRSH)
    model = Network(det, des, cfg.LOSS.SCORE, cfg.LOSS.PAIR, cfg.PATCH.SIZE,
                    args.k)
    model = model.to(device=device)
    checkpoint = torch.load(model_file)
    model.load_state_dict(checkpoint["state_dict"])

    random.seed(cfg.PROJ.SEED)
    torch.manual_seed(cfg.PROJ.SEED)
    np.random.seed(cfg.PROJ.SEED)

    root_dir = '/home/wang/workspace/RFSLAM_offline/RFNET/data/'
    csv_file = None
    seq = None
    a = None
    if args.data == 'v':
        csv_file = 'hpatch_view.csv'
        root_dir += 'hpatch_v_sequence'
Example #21
calculationStart = time.clock()

startTs = datetime.datetime(2019, 1, 1, 7, 0, 0)
totalSteps = 20000  #2500
timeStep = 1
jamDensity = 124
medianValueTime = 50
random.seed(10)

vehicleId = 0
GEN_VEH_DIST = 'normal_whole'  # ["uniform", "random", "random_whole", "normal_whole"]
STRATEGY = 'vol_sim'  # ['vol_sim', 'vol_dist', 'random', 'fix']
MULTIVEH = 1  #[default=1, 2, 3,...]
NO_CHARGE = False

network = Network(startTs)

#fNode = open("C:/Users/lyy90/OneDrive/Documents/GitHub/meso_v2.0/Sioux Falls network/nodes-SiouxFalls_gong.csv")
fNode = open("F:/meso_v2.0/Sioux Falls network/nodes-SiouxFalls_gong.csv")
fNode.readline()
#fLane = open("C:/Users/lyy90/OneDrive/Documents/GitHub/meso_v2.0/Sioux Falls network/lanes-SiouxFalls_gong.csv")
fLane = open("F:/meso_v2.0/Sioux Falls network/lanes-SiouxFalls_gong.csv")
fLane.readline()
#pOd = 'C:/Users/lyy90/OneDrive/Documents/GitHub/meso_v2.0/OD_data'
pOd = "F:/meso_v2.0/OD_data"

readNodes(fNode, network)
readLanes(fLane, network)
tsPairNodePairTypeMap = readOd(pOd)
#print(tsPairNodePairTypeMap)
genVehicle(tsPairNodePairTypeMap, GEN_VEH_DIST, vehicleId, medianValueTime,
Example #22
# -*- coding: utf-8 -*-

import os

from config import Config

from model.data import Data
from model.network import Network

if __name__ == '__main__':
    dir_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                            Config.CURRENT_MODEL_BASE_PATH)
    data = Data()
    data.load_data_from_file(os.path.join(dir_path,
                                          'data.nosync/all_data.npy'))
    model = Network(data, os.path.join(dir_path, 'log.nosync/network/run1'))
    model.train()
Example #23
 def setUp(self):
     self.net = Network()
     path = '../../ratdata/bowden/rat_items'
     self.items = np.loadtxt(path, dtype=np.character)
     self.nr_words = 10
     self.net.max_visited = self.nr_words
Example #24
#     5000,
#     10000,
#     50000,
#     100000,
#     500000,
#     1000000,
#     5000000,
# ]

cnt_list = [5_0000 * i for i in range(1, 101)]

seed_cnt = 500

if __name__ == '__main__':
    network_lines = read_lines('../../DatasetOnTestPlatform/NetHEPT.txt')
    network: Network = Network(network_lines)

    create_workers(network.vs)

    for cnt in cnt_list:
        rrsets = []  # clear it
        t = -time.time()
        fill_rrsets(rrsets, cnt)
        t += time.time()
        print(f'fill_rrsets, {cnt}, {t}')

        t = -time.time()
        _ = node_selection(rrsets, seed_cnt)
        t += time.time()
        print(f'selection, {cnt}, {t}')
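The `t = -time.time(); ...; t += time.time()` idiom above simply accumulates elapsed wall-clock time around each phase. An equivalent, arguably more conventional sketch (the `do_work` function is a placeholder standing in for `fill_rrsets` / `node_selection`):

import time

def do_work():
    # placeholder workload for the timed call
    sum(range(1_000_000))

start = time.perf_counter()
do_work()
elapsed = time.perf_counter() - start
print(f'do_work took {elapsed:.3f}s')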
Example #25
            'F = F | E = T, G = T': 0.3,
            'F = T | E = T, G = F': 0.8,
            'F = F | E = T, G = F': 0.2,
            'F = T | E = F, G = T': 0.15,
            'F = F | E = F, G = T': 0.85,
            'F = T | E = F, G = F': 0.35,
            'F = F | E = F, G = F': 0.65,
        }
    }, {
        'G': {
            'G = T': 0.85,
            'G = F': 0.15
        }
    }]

    net = Network(possibleValues, parents, CPTs)
    order = net.topoSort()

    query = {'F': 'T'}

    condition = {
        'D': 'F',
    }

    num = {**query, **condition}

    num_factors = net.getFactors(num)
    den_factors = net.getFactors(condition)

    num = sum_product_ve(num_factors, order, possibleValues)
    den = sum_product_ve(den_factors, order, possibleValues)
Example #26
def get_given_drugs_related_info(disease_pairs, drugs):   # first disease pair with first drug array
    all_networks = []   # contains an array for each disease pair
    for index, disease_pair in enumerate(disease_pairs):
        networks_per_drug = []  # contains a network for each drug
        pair_drugs_ids = drugs[index]
        temp_id1 = disease_pair[0].replace(':', '-')
        temp_id2 = disease_pair[1].replace(':', '-')
        path = '../analysis/disease_pairs/' + temp_id1 + '_' + temp_id2
        for drug_id in pair_drugs_ids:
            try:
                os.mkdir(path)
            except FileExistsError:
                pass

            network = Network()
            d1 = Disease([disease_pair[0]], [])
            network.add_node(d1)
            d2 = Disease([disease_pair[1]], [])
            network.add_node(d2)
            drug = Drug([drug_id], [])
            network.add_node(drug)
            temp_drug_id = drug_id.replace(':', '-')
            with io.open(path + '/' + temp_id1 + '_' + temp_id2 + '_' + temp_drug_id + '_results.txt', 'w', encoding='utf-8', newline='') as results_file:
                results_file.write('In this file all information about the connection between ' + disease_pair[0] +
                                   ' and ' + disease_pair[1] + ' and the drug ' + drug_id + ' is summarized:\n')

                # the drug INDICATES, CONTRAINDICATES or INDUCES the disease
                query = """ MATCH (d:Disease)-[a]-(n:Drug) WHERE {d1_id} IN d.ids AND {n_id} in n.ids RETURN distinct(type(a)) """
                d1_results = session.run(query, parameters={'d1_id': disease_pair[0], 'n_id': drug_id})
                for result in d1_results:
                    results_file.write(drug_id + ' ' + result['(type(a))'] + ' ' + disease_pair[0] + '\n')
                    network.add_edge(Edge(drug, d1, result['(type(a))'], {}))
                query = """ MATCH (d:Disease)-[a]-(n:Drug) WHERE {d2_id} IN d.ids AND {n_id} in n.ids RETURN distinct(type(a)) """
                d2_results = session.run(query, parameters={'d2_id': disease_pair[1], 'n_id': drug_id})
                for result in d2_results:
                    results_file.write(drug_id + ' ' + result['(type(a))'] + ' ' + disease_pair[1] + '\n')
                    network.add_edge(Edge(drug, d2, result['(type(a))'], {}))

                # the drug targets a gene which is associated to the disease
                d1_genes = set()
                query = """ MATCH (n:Drug)-[:TARGETS]-(g:Gene)-[:ASSOCIATES_WITH]-(d:Disease) WHERE {d1_id} IN d.ids AND {n_id} in n.ids RETURN g.`_id` """
                d1_results = session.run(query, parameters={'d1_id': disease_pair[0], 'n_id': drug_id})
                for gene in d1_results:
                    d1_genes.add(gene['g.`_id`'])
                    g = Gene([gene['g.`_id`']], [])
                    network.add_node(g)
                    network.add_edge(Edge(drug, g, 'TARGETS', {'actions': []})) #TODO
                    network.add_edge(Edge(g, d1, 'ASSOCIATES_WITH', {}))
                d2_genes = set()
                query = """ MATCH (n:Drug)-[:TARGETS]-(g:Gene)-[:ASSOCIATES_WITH]-(d:Disease) WHERE {d2_id} IN d.ids AND {n_id} in n.ids RETURN g.`_id` """
                d2_results = session.run(query, parameters={'d2_id': disease_pair[1], 'n_id': drug_id})
                for gene in d2_results:
                    d2_genes.add(gene['g.`_id`'])
                    g = Gene([gene['g.`_id`']], [])
                    network.add_node(g)
                    network.add_edge(Edge(drug, g, 'TARGETS', {'actions': []})) #TODO
                    network.add_edge(Edge(g, d2, 'ASSOCIATES_WITH', {}))

                common_drug_genes = d1_genes.intersection(d2_genes) # genes associated to the drug and both diseases
                # relevant_genes are all genes associated to at least one disease and the drug, below the common genes
                # with the most disease associated references are added
                relevant_genes = d1_genes.union(d2_genes)
                if len(d1_genes) > 0:
                    nbr = str(len(d1_genes))
                    d1_genes = str(d1_genes)
                    d1_genes = d1_genes.replace('{', '')
                    d1_genes = d1_genes.replace('}', '')
                    d1_genes = d1_genes.replace('\'', '')
                    results_file.write(drug_id + ' targets following ' + nbr + ' genes which are associated to ' + disease_pair[0] + ': ' + d1_genes + '\n')
                if len(d2_genes) > 0:
                    nbr = str(len(d2_genes))
                    d2_genes = str(d2_genes)
                    d2_genes = d2_genes.replace('{', '')
                    d2_genes = d2_genes.replace('}', '')
                    d2_genes = d2_genes.replace('\'', '')
                    results_file.write(drug_id + ' targets following ' + nbr + ' genes which are associated to ' + disease_pair[1] + ': ' + d2_genes + '\n')
                if len(common_drug_genes) > 0:
                    nbr = str(len(common_drug_genes))
                    cdgs = str(common_drug_genes)
                    cdgs = cdgs.replace('{', '')
                    cdgs = cdgs.replace('}', '')
                    cdgs = cdgs.replace('\'', '')
                    results_file.write('The disease pair has ' + nbr + ' common genes which are targeted by the drug: ' + cdgs + '\n')

                # add the common genes with the most disease associated references
                # no given num_pmids is similar to num_pmids = 0
                all_d1_genes, all_d2_genes = get_genes(disease_pair)
                all_common_genes = all_d1_genes.intersection(all_d2_genes)
                relevant_common_genes = []  # the genes with the most cited gene-disease association, threshold 10
                if len(all_common_genes) > 0:
                    results_file.write('The disease pair has ' + str(len(all_common_genes)) + ' common genes, not considering the connection to the drug.'
                                        ' Following genes have the most references regarding their connection to both diseases:\n')
                    for gene in all_common_genes:
                        query = """ MATCH (d1:Disease)-[a]-(g:Gene) WHERE {g_id} IN g.ids AND {d1_id} IN d1.ids RETURN a.num_pmids """
                        results = session.run(query, parameters={'g_id': gene, 'd1_id': disease_pair[0]})
                        num_pmids = 0
                        for result in results:  # multiple edges to the same gene
                            temp = result['a.num_pmids']
                            if temp is not None:
                                num_pmids = num_pmids + temp
                        query = """ MATCH (d2:Disease)-[a]-(g:Gene) WHERE {g_id} IN g.ids AND {d2_id} IN d2.ids RETURN a.num_pmids """
                        results = session.run(query, parameters={'g_id': gene, 'd2_id': disease_pair[1]})
                        for result in results:  # multiple edges to the same gene
                            temp = result['a.num_pmids']
                            if temp is not None:
                                num_pmids = num_pmids + temp
                        relevant_common_genes.append([gene, num_pmids])
                    # sort by number of pmids
                    relevant_common_genes = sorted(relevant_common_genes, key=lambda item: item[1], reverse=True)
                    relevant_common_genes = relevant_common_genes[:10]  # threshold
                    rcgs = str(relevant_common_genes)
                    rcgs = rcgs[1:-1]
                    rcgs = rcgs.replace('\'', '')
                    results_file.write(rcgs + '\n')
                    for g in relevant_common_genes:
                        gene = Gene([g[0]], [])
                        network.add_node(gene)
                        network.add_edge(Edge(gene, d1, 'ASSOCIATES_WITH', {}))
                        network.add_edge(Edge(gene, d2, 'ASSOCIATES_WITH', {}))
                        relevant_genes.add(g[0])

                # add the common disease associated variants with most references
                # no given num_pmids is similar to num_pmids = 0
                disease_variants = {}
                query = """ MATCH (d1:Disease)-[a]-(v:Variant)--(d2:Disease) WHERE {d1_id} in d1.ids AND {d2_id} in d2.ids RETURN distinct(a.num_pmids), v.`_id` """
                results = session.run(query, parameters={'d1_id': disease_pair[0], 'd2_id': disease_pair[1]})
                for variant in results:
                    num_pmids = variant['(a.num_pmids)']
                    if num_pmids is None:
                        num_pmids = 0
                    var_id = variant['v.`_id`']
                    if var_id in disease_variants:
                        temp = disease_variants[var_id]
                        disease_variants[var_id] = temp + num_pmids
                    else:
                        disease_variants[var_id] = num_pmids
                query = """ MATCH (d2:Disease)-[a]-(v:Variant)--(d1:Disease) WHERE {d1_id} in d1.ids AND {d2_id} in d2.ids RETURN distinct(a.num_pmids), v.`_id` """
                results = session.run(query, parameters={'d1_id': disease_pair[0], 'd2_id': disease_pair[1]})
                for variant in results:
                    num_pmids = variant['(a.num_pmids)']
                    if num_pmids is None:
                        num_pmids = 0
                    var_id = variant['v.`_id`']
                    if var_id in disease_variants:
                        temp = disease_variants[var_id]
                        disease_variants[var_id] = temp + num_pmids
                    else:
                        disease_variants[var_id] = num_pmids
                dvs = ''
                i = 0
                for key, value in sorted(disease_variants.items(), key=lambda item: item[1], reverse=True):
                    if i < 9:   # threshold
                        num_pmids = disease_variants[key]
                        variant = Variant([key], [])
                        network.add_node(variant)
                        network.add_edge(Edge(variant, d1, 'ASSOCIATES_WITH', {}))
                        network.add_edge(Edge(variant, d2, 'ASSOCIATES_WITH', {}))
                        dvs = dvs + key + ':' + str(num_pmids) + ' PMIDs, '
                        i += 1
                dvs = dvs[:-2]

                # add the gene associated variants with smallest pvalues
                # if no pvalue is given, pvalue is set to 1
                gene_variants = []
                for gene in relevant_genes:
                    query = """ MATCH (g:Gene)-[a]-(v:Variant) WHERE {g_id} in g.ids RETURN v.`_id`, a.pvalue, type(a) """
                    results = session.run(query, parameters={'g_id': gene})
                    for variant in results:
                        pvalue = variant['a.pvalue']
                        if pvalue is None:
                            pvalue = 1
                        else:
                            pvalue = float(pvalue)
                        gene_variants.append([variant['v.`_id`'] + '-' + gene, pvalue, variant['type(a)']])
                gene_variants = sorted(gene_variants, key=lambda item: item[1])
                gene_variants = gene_variants[:10]  # threshold
                for v in gene_variants:
                    temp = v[0].split('-')
                    v_id = temp[0]
                    g_id = temp[1]
                    variant = Variant([v_id], [])
                    network.add_node(variant)
                    gene = Gene([g_id], [])
                    network.add_node(gene)
                    network.add_edge(Edge(gene, variant, v[2], {'pvalue': v[1]}))
                if len(gene_variants) > 0:
                    gvs = str(gene_variants)
                    gvs = gvs[1:-1]
                    gvs = gvs.replace('\'', '')
                else:
                    gvs = ''

                if len(disease_variants) > 0 or len(gene_variants) > 0:
                    results_file.write('The disease pair has at least ' + str(i) + ' variants associated to both diseases: ' +
                                           dvs + ' and at least ' + str(len(gene_variants)) + ' gene associated variants: ' + gvs + '\n')

                # dict with RNA name as key and an array as value
                # first array position is the number of regulated genes, second position is an array with the gene names
                relevant_rnas = {}
                for gene in relevant_genes:
                    query = """ MATCH (r:RNA)--(g:Gene) WHERE {g_id} in g.ids AND NOT r.label_id CONTAINS "MRNA" return r.`_id` """
                    results = session.run(query, parameters={'g_id': gene})
                    for result in results:
                        key = result['r.`_id`']
                        if key in relevant_rnas:
                            value = relevant_rnas[key]
                            genes = value[1]
                            if gene not in genes:
                                genes.add(gene)
                                relevant_rnas[key] = [value[0] + 1, genes]
                        else:
                            genes = set()
                            genes.add(gene)
                            relevant_rnas[key] = [1, genes]

                if len(relevant_rnas) > 0:
                    i = 0
                    for key, value in sorted(relevant_rnas.items(), key=lambda item: item[1], reverse=True):
                    # sort by the number of regulated genes
                        if i > 9:   # threshold
                            break
                        elif value[0] > 1:  # only add and print RNAs which regulate more than one gene
                            if i == 0:
                                results_file.write('RNAs with the number and names of the genes they regulate: \n')
                            rna_id = key
                            for gene_id in value[1]:
                                rna = RNA([rna_id], [])
                                network.add_node(rna)
                                gene = Gene([gene_id], [])
                                network.add_node(gene)
                                network.add_edge(Edge(rna, gene, 'REGULATES', {}))
                            regulated_genes = str(value[1])
                            regulated_genes = regulated_genes[1:-1]
                            regulated_genes = regulated_genes.replace('\'', '')
                            results_file.write(rna_id + '\t' + str(value[0]) + '\t' + regulated_genes + '\n')
                            i += 1

                    # append regulating RNAs to one RNA which regulates the most genes, MRNAs are not added
                    for key, value in sorted(relevant_rnas.items(), key=lambda item: item[1], reverse=True):
                        if value[0] > 1:
                            most_relevant_rna = RNA([key], [])
                            network.add_node(most_relevant_rna)
                            query = """ MATCH (r:RNA)--(n:RNA) WHERE {r_id} in r.ids AND NOT n.label_id CONTAINS "MRNA" RETURN n.`_id`, labels(n) """
                            results = session.run(query, parameters={'r_id': key})
                            reg_rnas = ''
                            for result in results:
                                rna_id = result['n.`_id`']
                                types = result['labels(n)']
                                for type in types:
                                    if type != 'RNA':
                                        if type == 'CircRNA':
                                            rna = CircRNA([rna_id], [])
                                        if type == 'ERNA':
                                            rna = ERNA([rna_id], [])
                                        if type == 'LncRNA':
                                            rna = LncRNA([rna_id], [])
                                        if type == 'MiRNA':
                                            rna = MiRNA([rna_id], [])
                                        if type == 'NcRNA':
                                            rna = NcRNA([rna_id], [])
                                        if type == 'PiRNA':
                                            rna = PiRNA([rna_id], [])
                                        if type == 'Pseudogene':
                                            rna = Pseudogene([rna_id], [])
                                        if type == 'Ribozyme':
                                            rna = Ribozyme([rna_id], [])
                                        if type == 'RRNA':
                                            rna = RRNA([rna_id], [])
                                        if type == 'ScaRNA':
                                            rna = ScaRNA([rna_id], [])
                                        if type == 'ScRNA':
                                            rna = ScRNA([rna_id], [])
                                        if type == 'SnoRNA':
                                            rna = SnoRNA([rna_id], [])
                                        if type == 'SnRNA':
                                            rna = SnRNA([rna_id], [])
                                        network.add_node(rna)
                                        network.add_edge(Edge(rna, most_relevant_rna, 'REGULATES', {}))
                                        reg_rnas = reg_rnas + rna_id + ', '
                            reg_rnas = reg_rnas[:-2]
                            results_file.write(key + ' is the RNA which regulates the most genes in this subgraph. It is regulated by ' + reg_rnas + '.\n')
                        break
            json_file = path + '/' + temp_id1 + '_' + temp_id2 + '_' + temp_drug_id + '_graph.json'
            network.save(json_file)
            draw_drug_subgraph(json_file)
            networks_per_drug.append(network)
        all_networks.append(networks_per_drug)
    return all_networks
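This example queries a Neo4j database through a module-level `session` object that the snippet does not show. A minimal sketch of how such a session might be created with the official Python driver (the URI and credentials are placeholders; the `{param}` placeholders in the Cypher above suggest an older Bolt server, where `$param` is the modern equivalent):

from neo4j import GraphDatabase  # assumption: the official neo4j Python driver is used

# connection details are placeholders, not taken from the example
driver = GraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
session = driver.session()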
Example #27
def main():
    train_dataset = MNIST(root='./data',
                          train=True,
                          download=True,
                          transform=transforms.ToTensor())
    test_dataset = MNIST(root='./data',
                         train=False,
                         download=True,
                         transform=transforms.ToTensor())

    train_loader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=2)
    test_loader = DataLoader(test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=2)

    net = Network(1, 64, 5, 10)

    if USE_CUDA:
        net = net.cuda()

    opt = optim.SGD(net.parameters(),
                    lr=LEARNING_RATE,
                    weight_decay=WEIGHT_DECAY,
                    momentum=.9,
                    nesterov=True)

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')

    for epoch in range(1, EPOCHS + 1):
        print('[Epoch %d]' % epoch)

        train_loss = 0
        train_correct, train_total = 0, 0

        start_point = time.time()

        for inputs, labels in train_loader:
            inputs, labels = Variable(inputs), Variable(labels)
            if USE_CUDA:
                inputs, labels = inputs.cuda(), labels.cuda()

            opt.zero_grad()

            # net(inputs) yields logits; log_softmax + nll_loss is the cross-entropy on those logits
            preds = F.log_softmax(net(inputs), dim=1)

            loss = F.nll_loss(preds, labels)
            loss.backward()

            opt.step()

            train_loss += loss.item()

            train_correct += (preds.argmax(dim=1) == labels).sum().item()
            train_total += len(preds)

        print('train-acc : %.4f%% train-loss : %.5f' %
              (100 * train_correct / train_total,
               train_loss / len(train_loader)))
        print('elapsed time: %ds' % (time.time() - start_point))

        test_loss = 0
        test_correct, test_total = 0, 0

        for inputs, labels in test_loader:
            with torch.no_grad():
                inputs, labels = Variable(inputs), Variable(labels)

                if USE_CUDA:
                    inputs, labels = inputs.cuda(), labels.cuda()

                logits = net(inputs)
                preds = F.softmax(logits, dim=1)

                test_loss += F.cross_entropy(logits, labels).item()

                test_correct += (preds.argmax(dim=1) == labels).sum().item()
                test_total += len(preds)

        print('test-acc : %.4f%% test-loss : %.5f' %
              (100 * test_correct / test_total, test_loss / len(test_loader)))

        torch.save(net.state_dict(),
                   './checkpoint/checkpoint-%04d.bin' % epoch)
Example #28
 def setUp(self):
     self.net = Network()
Example #29
from model.embedding import Embedding
from model.network import Network

from extractor.utterance.all import UtteranceAll

import numpy as np
import jieba
import os
import sys
# import thulac

# thulac_seg = thulac.thulac(seg_only=True)

base_path = os.path.dirname(os.path.realpath(__file__))
model = Network(None, os.path.join(base_path, 'log.nosync/network_demo/run1'))
model.load_model(
    os.path.join(base_path,
                 'model_checkpoint/lstm_early_stopping_without_conc'))

path = os.path.join(base_path, 'word2vec_model/wiki.zh.nosync/wiki.zh.vec')
embedding = Embedding()
embedding.load_w2v_model(path, False)

max_steps = model.n_input_steps
n_embedding = model.n_embedding

categories = [
    '开关语音播报', '打电话', '发短信', '发邮件', '导航', '离职倾向', 'KPI', '访问网站', '会议室预定',
    '设置提醒', '查日程安排', '查会议安排', '查会议室安排情况', '查月度工作任务', '查工作任务完成情况', '查月度预算执行情况',
    '查当月费用报销情况', '查借款情况', '查应收款', '查应付款', '查考勤', '查出差情况', '查天气', '查股票', '讲笑话',
Example #30
File: train.py Project: 530824679/YOLOv2
def train():
    start_step = 0
    log_step = solver_params['log_step']
    restore = solver_params['restore']
    checkpoint_dir = path_params['checkpoints_dir']
    checkpoints_name = path_params['checkpoints_name']
    tfrecord_dir = path_params['tfrecord_dir']
    tfrecord_name = path_params['train_tfrecord_name']
    log_dir = path_params['logs_dir']
    batch_size = solver_params['batch_size']

    # Configure GPU options
    gpu_options = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(gpu_options=gpu_options)

    # Parse the training samples and annotations from the TFRecord
    data = tfrecord.TFRecord()
    train_tfrecord = os.path.join(tfrecord_dir, tfrecord_name)
    data_num = total_sample(train_tfrecord)
    batch_num = int(math.ceil(float(data_num) / batch_size))
    dataset = data.create_dataset(train_tfrecord,
                                  batch_num,
                                  batch_size=batch_size,
                                  is_shuffle=True)
    iterator = dataset.make_one_shot_iterator()
    images, y_true = iterator.get_next()

    images.set_shape([None, 416, 416, 3])
    y_true.set_shape([None, 13, 13, 5, 6])

    # Build the network
    network = Network(is_train=True)
    logits = network.build_network(images)

    # Compute the loss
    total_loss, diou_loss, confs_loss, class_loss = network.calc_loss(
        logits, y_true)

    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(solver_params['lr'],
                                               global_step,
                                               solver_params['decay_steps'],
                                               solver_params['decay_rate'],
                                               staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(total_loss, global_step=global_step)

    # TensorBoard scalar summaries
    tf.summary.scalar("learning_rate", learning_rate)
    tf.summary.scalar('total_loss', total_loss)
    tf.summary.scalar("diou_loss", diou_loss)
    tf.summary.scalar("confs_loss", confs_loss)
    tf.summary.scalar("class_loss", class_loss)

    # Merge summaries and create the TensorBoard writer
    summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir,
                                           graph=tf.get_default_graph(),
                                           flush_secs=60)

    # Model saving
    save_variable = tf.global_variables()
    saver = tf.train.Saver(save_variable, max_to_keep=50)

    with tf.Session(config=config) as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        if restore:
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                stem = os.path.basename(ckpt.model_checkpoint_path)
                restore_step = int(stem.split('.')[0].split('-')[-1])
                start_step = restore_step
                sess.run(global_step.assign(restore_step))
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Restoring from {}'.format(ckpt.model_checkpoint_path))
            else:
                print("Failed to find a checkpoint")

        if solver_params['pre_train']:
            pretrained = np.load(path_params['pretrain_weights'],
                                 allow_pickle=True).item()
            for variable in tf.trainable_variables():
                for key in pretrained.keys():
                    key2 = variable.name.rstrip(':0')
                    if (key == key2):
                        sess.run(tf.assign(variable, pretrained[key]))

        summary_writer.add_graph(sess.graph)

        print('\n----------- start to train -----------\n')
        for epoch in range(start_step + 1, solver_params['epoches']):
            train_epoch_loss, train_epoch_diou_loss, train_epoch_confs_loss, train_epoch_class_loss = [], [], [], []
            for index in tqdm(range(batch_num)):
                _, summary_, loss_, diou_loss_, confs_loss_, class_loss_, global_step_, lr = sess.run(
                    [
                        train_op, summary_op, total_loss, diou_loss,
                        confs_loss, class_loss, global_step, learning_rate
                    ])

                train_epoch_loss.append(loss_)
                train_epoch_diou_loss.append(diou_loss_)
                train_epoch_confs_loss.append(confs_loss_)
                train_epoch_class_loss.append(class_loss_)

                summary_writer.add_summary(summary_, global_step_)

            train_epoch_loss, train_epoch_diou_loss, train_epoch_confs_loss, train_epoch_class_loss = np.mean(
                train_epoch_loss), np.mean(train_epoch_diou_loss), np.mean(
                    train_epoch_confs_loss), np.mean(train_epoch_class_loss)
            print(
                "Epoch: {}, global_step: {}, lr: {:.8f}, total_loss: {:.3f}, diou_loss: {:.3f}, confs_loss: {:.3f}, class_loss: {:.3f}"
                .format(epoch, global_step_, lr, train_epoch_loss,
                        train_epoch_diou_loss, train_epoch_confs_loss,
                        train_epoch_class_loss))
            saver.save(sess,
                       os.path.join(checkpoint_dir, checkpoints_name),
                       global_step=epoch)

        sess.close()
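The function above is only defined, never invoked in this excerpt; a minimal entry point to launch the training loop might simply be (an assumption, not shown in the project file):

if __name__ == '__main__':
    train()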