Ejemplo n.º 1
0
def test1(in_channels, out_channels, net_name,
          weights_path, test_image_root, batch_size, resize=None, crop_offset=None, **kwargs):
    """
    Run a trained network over a directory of test images and display every
    decoded prediction next to its original image.

    :param in_channels: number of input channels, forwarded to ``create_net``
    :param out_channels: number of output channels, forwarded to ``create_net``
    :param net_name: network name, forwarded to ``create_net``
    :param weights_path: path of the checkpoint loaded with ``torch.load``
    :param test_image_root: directory holding the test images
    :param batch_size: number of images per batch
    :param resize: network input size, forwarded to ``test_data_generator``
    :param crop_offset: crop offset, indexed as ``[0]`` and ``[1]``; ``None``
        is treated as ``(0, 0)`` when the prediction is restored
    :param kwargs: extra keyword arguments forwarded to ``create_net``
    :return: None, the results are shown with matplotlib
    """
    # The original expression picked "cpu" in both branches, so a GPU was
    # never used even when one was available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = create_net(in_channels, out_channels, net_name, **kwargs)
    net.eval()  # inference mode for BatchNorm / Dropout
    net = net.to(device)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    net.load_state_dict(torch.load(weights_path, map_location=device))

    # The default crop_offset=None used to crash on crop_offset[0] below.
    # NOTE(review): assumes recover_image accepts a (0, 0) offset - confirm.
    offset = crop_offset if crop_offset is not None else (0, 0)

    generator, data_size = test_data_generator(test_image_root, batch_size, resize, crop_offset)
    epoch_size = int(data_size / batch_size)  # number of full batches
    with torch.no_grad():
        for _ in range(epoch_size):
            images, original_images, original_hw_size = next(generator)

            predicts = net(images.to(device))  # shape=(n,c,h,w)
            # Per-pixel class index, shape=(n,h,w); moved to the CPU because
            # the result is converted with .numpy() below.
            convert = torch.softmax(predicts, dim=1).argmax(dim=1).cpu()

            # (n,h,w) => (n,h,w,3)
            decode_image = decode_rgb(convert)

            for index, original_image in enumerate(original_images):
                original_size = original_hw_size[index]
                restored_size = (original_size[0] - offset[0],
                                 original_size[1] - offset[1])
                result_image = recover_image(decode_image[index].numpy(),
                                             restored_size,
                                             offset)

                plt.subplot(1, 2, 1)
                plt.imshow(result_image)
                plt.subplot(1, 2, 2)
                plt.imshow(original_image)
                plt.show()
Ejemplo n.º 2
0
def style_transfer_video(video_file, checkpoint_path, out_path):
    """
    Stylize a video at half resolution and write the result to ``out_path``.

    Frames are read from ``video_file``, resized to half the source width and
    height, pushed through ``net`` in batches of ``BATCH_SIZE`` and written
    with an ``h264`` fourcc at the source frame rate.

    :param video_file: path of the input video
    :param checkpoint_path: checkpoint restored with ``tf.train.Saver``
    :param out_path: path of the output video
    :return: None
    """
    cap = cv2.VideoCapture(video_file)
    video_writer = None
    try:
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        vid_size = (width // 2, height // 2)
        fourcc = cv2.VideoWriter_fourcc(*'h264')
        video_writer = cv2.VideoWriter(out_path, fourcc, fps, vid_size)

        g = tf.Graph()
        batch_shape = (BATCH_SIZE, vid_size[1], vid_size[0], 3)

        with g.as_default(), tf.Session() as sess:
            img_placeholder = tf.placeholder(tf.float32,
                                             shape=batch_shape,
                                             name='img_placeholder')
            preds = net(img_placeholder)
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)

            X = np.zeros(batch_shape, dtype=np.float32)

            def style_and_write(count):
                # Pad a partial batch by repeating its last real frame so the
                # fixed-size placeholder can still be fed.
                for j in range(count, BATCH_SIZE):
                    X[j] = X[count - 1]

                _preds = sess.run(preds, feed_dict={img_placeholder: X})

                # Only the first `count` outputs belong to real frames.
                for j in range(count):
                    style_frame = np.clip(_preds[j], 0, 255).astype(np.uint8)
                    style_frame = cv2.cvtColor(style_frame, cv2.COLOR_RGB2BGR)
                    video_writer.write(style_frame)

            frame_count = 0  # frames buffered in the current batch
            frames_read = 0  # frames read from the source so far
            while True:
                ret, frame = cap.read()
                if not ret or frame is None:
                    break
                # Resize to the exact writer size: scaling by fx=fy=0.5 may
                # round odd dimensions differently from width // 2.
                frame = cv2.resize(frame, vid_size)
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                X[frame_count] = frame
                frame_count += 1
                frames_read += 1
                if frame_count == BATCH_SIZE:
                    style_and_write(frame_count)
                    frame_count = 0
                    print("Wrote %d frames" % frames_read)

            # Flush the trailing partial batch.
            if frame_count != 0:
                style_and_write(frame_count)
    finally:
        # Release both handles even on failure; the writer has to be released
        # for the output file to be finalized.
        cap.release()
        if video_writer is not None:
            video_writer.release()
Ejemplo n.º 3
0
def style_transfer(img, checkpoint_path):
    """
    Run ``net`` on a single image and return the first (only) prediction.

    :param img: array-like image exposing ``.shape``; it is fed as a batch of one
    :param checkpoint_path: checkpoint restored with ``tf.train.Saver``
    :return: ``net`` output for ``img`` (first element of the batch result)
    """
    single_batch_shape = (1,) + img.shape
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as sess:
        img_placeholder = tf.placeholder(tf.float32, shape=single_batch_shape, name='img_placeholder')
        preds = net(img_placeholder)
        tf.train.Saver().restore(sess, checkpoint_path)

        # Batch of one: the only slot is filled with the input image.
        feed = np.empty(single_batch_shape, dtype=np.float32)
        feed[0] = img

        return sess.run(preds, feed_dict={img_placeholder: feed})[0]
Ejemplo n.º 4
0
def test(in_channels, out_channels, net_name,
         weights_path, image_path, resize=None, crop_offset=None, **kwargs):
    """
    Run a trained network on one image and display the decoded prediction
    next to the original image.

    :param in_channels: number of input channels, forwarded to ``create_net``
    :param out_channels: number of output channels, forwarded to ``create_net``
    :param net_name: network name, forwarded to ``create_net``
    :param weights_path: path of the checkpoint loaded with ``torch.load``
    :param image_path: path of the test image
    :param resize: network input size, forwarded to ``read_image``
    :param crop_offset: crop offset, indexed as ``[0]`` and ``[1]``; ``None``
        is treated as ``(0, 0)`` when the prediction is restored
    :param kwargs: extra keyword arguments forwarded to ``create_net``
    :return: None, the results are shown with matplotlib
    """
    # The original expression picked "cpu" in both branches, so a GPU was
    # never used even when one was available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = create_net(in_channels, out_channels, net_name, **kwargs)
    net.eval()  # inference mode for BatchNorm / Dropout
    net = net.to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    net.load_state_dict(torch.load(weights_path, map_location=device))

    # The default crop_offset=None used to crash on crop_offset[0] below.
    # NOTE(review): assumes recover_image accepts a (0, 0) offset - confirm.
    offset = crop_offset if crop_offset is not None else (0, 0)

    with torch.no_grad():
        image, ori_size, ori_image = read_image(image_path, resize, crop_offset)
        predicts = net(image.to(device))
        # Per-pixel class index, shape=(n,h,w); moved to the CPU because the
        # decoded result is converted with .numpy() below.
        convert = torch.softmax(predicts, dim=1).argmax(dim=1).cpu()
        # Repeat the decoded map along a trailing axis: shape=(n,h,w,3)
        decode_image = decode(convert).unsqueeze(dim=-1). \
            expand(convert.shape[0], convert.shape[1], convert.shape[2], 3).contiguous()
        restored_size = (ori_size[0] - offset[0], ori_size[1] - offset[1])
        recover_img = recover_image(decode_image.numpy(), restored_size, offset)

        for img in recover_img:
            plt.subplot(1, 2, 1)
            plt.imshow(img)
            plt.subplot(1, 2, 2)
            plt.imshow(ori_image)
            plt.show()
Ejemplo n.º 5
0
def hello_world(payload):
    """
    Classify the doodle carried by ``payload`` and emit the predicted label.

    The doodle is saved to ``doodle.p`` and ``imageToSave.png``, preprocessed,
    classified with the lazily created module-level ``model``, reported
    through ``emit`` and finally handed to ``buildLayout``.

    :param payload: mapping with the keys ``'doodle'`` (encoded image string),
        ``'coords'`` (four values, converted to ``int`` in place) and ``'id'``
    :return: None
    """
    global model
    # Skip the first 22 characters of the doodle string.
    # NOTE(review): presumably the "data:image/png;base64," prefix - confirm.
    doodle = payload['doodle'][22:]
    coords = payload['coords']

    # Converted in place so the caller's list is updated as before.
    for position in range(4):
        coords[position] = int(coords[position])

    id_ = payload['id']
    # Context managers close the files even if writing fails; the original
    # left both handles to the garbage collector.
    with open('doodle.p', 'wb') as dump_file:
        pickle.dump(doodle, dump_file)
    with open("imageToSave.png", "wb") as image_file:
        image_file.write(base64.urlsafe_b64decode(doodle))

    _, x, y = preprocess('./imageToSave.png',
                         n=30,
                         brightness=120,
                         coords=coords)
    # NOTE(review): tmp.png is presumably written by preprocess() - confirm.
    img = misc.imread('tmp.png').astype('float32')
    img = np.stack((img, ) * 3)  # replicate the image into 3 channels
    img = torch.Tensor(normalizing(img))
    img = img.view(1, 3, 256, 256)
    img = Variable(img)

    # Build and load the classifier on first use only.
    if model is None:
        model = YoloClassifier(utils.net('tiny_yolo', in_channels=3),
                               class_size=class_size,
                               batch_size=1)
        model = load_checkpoint(model)

    out = model(img)
    _, preds = torch.max(out.data, 1)

    preds = preds.numpy()
    preds = preds[0]
    print("\nI think it's a : " + str(class_keys[preds]))

    emit('message', {'message': str(class_keys[preds])})
    buildLayout(coords, x, y, preds, id_)
Ejemplo n.º 6
0
def single_level_net_steps(net, opt, Xcur, Ycur, num_steps, mb_size, loss_func,
                           clip_gradient, sampling_weights):
    """
    Run ``num_steps`` minibatch updates of ``net`` on ``(Xcur, Ycur)``.

    Each step samples ``mb_size`` row indices with ``np.random.choice``
    (weighted by ``sampling_weights``), computes ``loss_func`` on the batch,
    back-propagates and steps ``opt``.

    :param net: model to train
    :param opt: optimizer bound to ``net``'s parameters
    :param Xcur: training inputs, converted with ``to_variable(...).float()``
    :param Ycur: training targets, converted with ``to_variable(...).float()``
    :param num_steps: number of minibatch updates
    :param mb_size: minibatch size
    :param loss_func: callable ``loss_func(pred, target)`` returning the loss
    :param clip_gradient: max gradient norm, or ``None`` to disable clipping
    :param sampling_weights: sampling probabilities passed as ``p`` to
        ``np.random.choice`` (``None`` means uniform)
    :return: None
    """
    # TODO: make sure there are no wasteful copying
    torch.set_num_threads(1)
    Xcur = to_variable(Xcur).float()
    Ycur = to_variable(Ycur).float()
    # The population size is loop-invariant; an int makes np.random.choice
    # sample from range(n) without building an index list on every step.
    num_samples = len(Xcur)
    for _ in range(num_steps):
        # TODO: reweight sampling weights here
        idxs = np.random.choice(num_samples, mb_size, p=sampling_weights)
        xbatch = Xcur[idxs]
        ybatch = Ycur[idxs]
        pred = net(xbatch).squeeze(1)
        assert pred.shape == ybatch.shape
        loss = loss_func(pred, ybatch)
        opt.zero_grad()
        loss.backward()
        if clip_gradient is not None:
            clip_grad_norm_(net.parameters(), clip_gradient)
        opt.step()
Ejemplo n.º 7
0
def valid(net, csv_path, load_data, batch_size, resize, crop_offset, num_classes):
    """
    Evaluate ``net`` on the data listed in ``csv_path`` and return the mean IoU.

    The network is switched to eval mode and is left in that mode on return.

    :param net: network model (expected to already live on the active device)
    :param csv_path: CSV data list with ``image`` and ``label`` columns
    :param load_data: data loading function, forwarded to ``data_generator``
    :param batch_size: batch size
    :param resize: network input size, forwarded to ``data_generator``
    :param crop_offset: crop offset, forwarded to ``data_generator``
    :param num_classes: number of classes, forwarded to ``get_miou``
    :return: average of the per-batch ``get_miou`` values, or ``0.0`` when the
        CSV holds fewer rows than one full batch
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net.eval()  # inference mode for BatchNorm / Dropout

    df = pd.read_csv(csv_path)
    epoch_size = int(len(df) / batch_size)  # number of full batches
    if epoch_size == 0:
        # Nothing to evaluate; the original divided by zero here.
        return 0.0

    generator = data_generator(load_data,
                               np.array(df['image']),
                               np.array(df['label']),
                               batch_size, resize, crop_offset)

    miou = 0.0
    with torch.no_grad():
        for batch_index in range(1, epoch_size + 1):
            images, labels = next(generator)
            images = images.to(device)
            labels = labels.to(device)

            predicts = net(images)

            iou = get_miou(predicts, labels, num_classes)
            print("valid {}/{} iou".format(batch_index, epoch_size), iou)
            miou += iou
    return miou / epoch_size
Ejemplo n.º 8
0
    def train_step(self, num_steps=1):
        """
        Run ``num_steps`` minibatch updates.

        For ``nn_type == "num_tables"`` there is one network per table-count
        bucket: each bucket trains on its own slice of the training data,
        either in this process (``single_threaded_nt``) or in a process pool
        running ``single_level_net_steps``. Otherwise the single ``self.net``
        is trained on minibatches drawn with ``subquery_sampling_weights``.

        :param num_steps: number of minibatch updates per network
        :return: None
        """
        if self.nn_type != "num_tables":
            # Usual case: a single network.
            # TODO: replace this with dataloader (...)
            for _ in range(num_steps):
                picked = np.random.choice(np.arange(len(self.Xtrain)),
                                          self.mb_size,
                                          p=self.subquery_sampling_weights)
                batch_x = self.Xtrain[picked]
                batch_y = self.Ytrain[picked]
                loss = self.loss(self.net(batch_x).squeeze(1), batch_y)
                self.optimizer.zero_grad()
                loss.backward()
                if self.clip_gradient is not None:
                    clip_grad_norm_(self.net.parameters(), self.clip_gradient)
                self.optimizer.step()
            return

        assert self.net is None
        table_ranges = self.train_num_table_mapping

        if not self.single_threaded_nt:
            # Collect one argument tuple per bucket, then train them in a pool.
            jobs = []
            for nt in table_ranges:
                lo, hi = table_ranges[nt]
                bucket = self._map_num_tables(nt)
                bucket_net = self.nets[bucket]
                bucket_opt = self.optimizers[bucket]
                x_slice = self.Xtrain[lo:hi].cpu().detach().numpy()
                y_slice = self.Ytrain[lo:hi].cpu().detach().numpy()
                weights = self._normalize_priorities(
                    self.subquery_sampling_weights[lo:hi])

                # TODO: make mb_size dependent on the level?
                jobs.append((bucket_net, bucket_opt, x_slice, y_slice,
                             num_steps, self.mb_size, self.loss,
                             self.clip_gradient, weights))

            # launch single-threaded processes for each
            # TODO: might be better to launch pool of 4 + 2T each, so we
            # don't waste resources on levels that finish fast?
            with Pool2(processes=4) as pool:
                pool.starmap(single_level_net_steps, jobs)
            return

        # Single-threaded: train every bucket in turn, sampling uniformly.
        for nt in table_ranges:
            lo, hi = table_ranges[nt]
            bucket = self._map_num_tables(nt)
            bucket_net = self.nets[bucket]
            bucket_opt = self.optimizers[bucket]
            x_slice = self.Xtrain[lo:hi]
            y_slice = self.Ytrain[lo:hi]
            for _ in range(num_steps):
                # TODO: reweight sampling weights here
                picked = np.random.choice(np.arange(len(x_slice)),
                                          self.mb_size)
                batch_x = x_slice[picked]
                batch_y = y_slice[picked]
                pred = bucket_net(batch_x).squeeze(1)
                assert pred.shape == batch_y.shape
                loss = self.loss(pred, batch_y)
                bucket_opt.zero_grad()
                loss.backward()
                if self.clip_gradient is not None:
                    clip_grad_norm_(bucket_net.parameters(),
                                    self.clip_gradient)
                bucket_opt.step()
Ejemplo n.º 9
0
def fast_style_transfer(data_in,
                        paths_out,
                        checkpoint_dir,
                        device_t='/gpu:0',
                        batch_size=4):
    """
    Stylize a list of images in fixed-size batches and save the results.

    Inputs that do not fill a whole batch are processed by a recursive call
    with ``batch_size=1``. Afterwards the first output path is handed to
    ``send_to_ps`` and, for path inputs, every file in the directory of the
    first input is deleted. Note that the recursive call runs this
    post-processing for the leftover items as well.

    :param data_in: list of image paths, or of image arrays exposing ``.shape``
    :param paths_out: output paths, one per input
    :param checkpoint_dir: checkpoint restored with ``tf.train.Saver``
    :param device_t: TensorFlow device string used for the graph
    :param batch_size: maximum number of images per batch
    :return: None
    """
    is_paths = isinstance(data_in[0], str)

    if is_paths:
        img_shape = get_img(data_in[0]).shape
    else:
        # Was `X[0].shape`: X is only assigned further down, so array inputs
        # always failed with UnboundLocalError.
        img_shape = data_in[0].shape

    g = tf.Graph()
    batch_size = min(len(paths_out), batch_size)
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True

    with g.as_default(), g.device(device_t), tf.Session(
            config=soft_config) as sess:

        batch_shape = (batch_size, ) + img_shape
        img_placeholder = tf.placeholder(tf.float32,
                                         shape=batch_shape,
                                         name='img_placeholder')

        preds = net(img_placeholder)
        saver = tf.train.Saver()

        saver.restore(sess, checkpoint_dir)
        num_iters = len(paths_out) // batch_size  # number of full batches

        for i in range(num_iters):
            pos = i * batch_size
            curr_batch_out = paths_out[pos:pos + batch_size]

            if is_paths:
                X = np.zeros(batch_shape, dtype=np.float32)
                for j, path_in in enumerate(data_in[pos:pos + batch_size]):
                    X[j] = get_img(path_in)
            else:
                X = data_in[pos:pos + batch_size]

            _preds = sess.run(preds, feed_dict={img_placeholder: X})
            for j, path_out in enumerate(curr_batch_out):
                save_img(path_out, _preds[j])

    # Items that did not fill a whole batch are handled one at a time.
    remaining_in = data_in[num_iters * batch_size:]
    remaining_out = paths_out[num_iters * batch_size:]

    if len(remaining_in) > 0:
        fast_style_transfer(remaining_in,
                            remaining_out,
                            checkpoint_dir,
                            device_t=device_t,
                            batch_size=1)

    print("style transfer completed.")
    send_to_ps(paths_out[0])

    # Clear the temp folder. Only meaningful for path inputs: calling
    # os.path.dirname on an image array raises TypeError.
    if is_paths:
        temppath = os.path.dirname(data_in[0])
        for f in os.listdir(temppath):
            os.remove(os.path.join(temppath, f))
def train(in_channels,
          out_channels,
          net_name,
          lr,
          csv_path,
          data_path,
          batch_size,
          resize,
          crop_offset,
          epoch_begin,
          epoch_num,
          num_classes,
          save_model,
          load_state_dict_path=None,
          loss_weights=None,
          **kwargs):
    """
    Train a network with a cross-entropy loss and save a checkpoint per epoch.

    :param in_channels: number of input channels, forwarded to ``create_net``
    :param out_channels: number of output channels, forwarded to ``create_net``
    :param net_name: network name, forwarded to ``create_net``
    :param lr: initial learning rate; multiplied by 0.1 after every epoch
    :param csv_path: CSV data list, only used to derive the batches per epoch
    :param data_path: data location forwarded to ``hpmp_data_generator``
    :param batch_size: batch size
    :param resize: network input size (not referenced by this function)
    :param crop_offset: crop offset (not referenced by this function)
    :param epoch_begin: first epoch index (allows resuming)
    :param epoch_num: epoch index at which training stops (exclusive)
    :param num_classes: number of classes
    :param save_model: callable ``save_model(net, model_name)``
    :param load_state_dict_path: optional path of pretrained weights
    :param loss_weights: optional per-class weights, shape=(num_classes)
    :param kwargs: extra keyword arguments forwarded to ``create_net``
    :return: None
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = create_net(in_channels, out_channels, net_name, **kwargs)
    net.train()  # training mode for BatchNorm / Dropout
    net = net.to(device)
    if load_state_dict_path is not None:
        # map_location lets a checkpoint saved on another device load here.
        net.load_state_dict(
            torch.load(load_state_dict_path, map_location=device))

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1,
                                                gamma=0.1,
                                                last_epoch=-1)

    # Convert the class weights once. The original re-wrapped the already
    # converted tensor with torch.Tensor(...) on every batch, which fails as
    # soon as that tensor lives on a CUDA device.
    if loss_weights is not None:
        loss_weights = torch.as_tensor(loss_weights,
                                       dtype=torch.float32,
                                       device=device)

    df = pd.read_csv(csv_path)
    generator = hpmp_data_generator(data_path, batch_size)
    epoch_size = int(len(df) / batch_size)  # number of full batches per epoch
    for epoch in range(epoch_begin, epoch_num):
        print("The epoch {} start.".format(epoch))
        start = datetime.datetime.now()
        epoch_loss = 0.0
        last_miou = None  # mIoU of the most recent batch, None if no batch ran
        for batch_index in range(1, epoch_size + 1):
            images, labels = next(generator)
            images = images.to(device)
            labels = labels.to(device)

            predicts = net(images)

            loss = create_multi_loss(loss_type=LossType.ce_loss,
                                     predicts=predicts,
                                     labels=labels,
                                     num_classes=num_classes,
                                     loss_weights=loss_weights)

            last_miou = get_miou(predicts, labels, num_classes)

            print("loss {}/{}".format(batch_index, epoch_size), loss)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # The reported mIoU is that of the last batch of the epoch.
        print("The epoch {} end, epoch loss:{}, miou:{}, execution time:{}".
              format(epoch, epoch_loss,
                     last_miou.item() if last_miou is not None else float("nan"),
                     datetime.datetime.now() - start))
        # Read the rate from the optimizer: scheduler.get_lr() is not meant
        # to be called outside scheduler.step() and can misreport the value.
        print("The current epoch {} learning rate {}.".format(
            epoch,
            optimizer.param_groups[0]['lr']))
        scheduler.step()
        model_name = "ckpt_%d_%.2f.pth" % (epoch, epoch_loss)
        save_model(net, model_name)
Ejemplo n.º 11
0
def valid(net, csv_path, load_data, batch_size, resize, crop_offset,
          num_classes, load_classes, anchors, iou_thres, conf_thres, nms_thres,
          img_size):
    """
    Evaluate a detection network and return per-class precision statistics.

    The network is switched to eval mode and is left in that mode on return.

    :param net: network model (expected to already live on the active device)
    :param csv_path: CSV data list with ``image`` and ``label`` columns
    :param load_data: data loading function, forwarded to the data generator
    :param batch_size: batch size
    :param resize: network input size, forwarded to the data generator
    :param crop_offset: crop offset, forwarded to the data generator
    :param num_classes: number of classes
    :param load_classes: class loading function, forwarded to the data generator
    :param anchors: anchors forwarded to ``get_yolo_output``
    :param iou_thres: IoU above which a prediction counts as correct when
        the batch statistics are collected
    :param conf_thres: confidence below which predictions are dropped by
        non-maximum suppression
    :param nms_thres: IoU above which boxes are treated as duplicates by
        non-maximum suppression
    :param img_size: ``(height, width)`` used to rescale the target boxes
    :return: ``(precision, recall, AP, f1, ap_class)`` from ``ap_per_class``,
        or ``None`` when no statistics were collected
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net.eval()  # inference mode for BatchNorm / Dropout

    df = pd.read_csv(csv_path)
    generator = detection_data_generator(load_data,
                                         np.array(df['image']),
                                         np.array(df['label']),
                                         batch_size,
                                         load_classes=load_classes,
                                         resize=resize,
                                         crop_offset=crop_offset)
    epoch_size = int(len(df) / batch_size)  # number of full batches

    # Loop-invariant factor for scaling the box columns to (w, h, w, h) of
    # img_size; only built when at least one batch will be processed.
    box_scale = torch.FloatTensor(
        [img_size[1], img_size[0], img_size[1], img_size[0]]
    ) if epoch_size > 0 else None

    targets = []
    sample_metrics = []  # List of tuples (TP, confs, pred)
    with torch.no_grad():
        for _ in range(epoch_size):
            images, labels = next(generator)
            images = images.to(device)
            # labels stay on the CPU: they are only used for the statistics.

            predicts = net(images)

            yolo_outputs, _, _ = get_yolo_output(
                predicts=predicts,
                num_classes=num_classes,
                input_size=images.shape[-2:],
                anchors=anchors,
                cuda=torch.cuda.is_available(),
                labels=None,
            )

            outputs = non_max_suppression(yolo_outputs.detach().cpu(),
                                          conf_thres=conf_thres,
                                          nms_thres=nms_thres)

            # Column 1 of every label row is collected as the target class.
            targets += labels[:, 1].tolist()
            # Convert the box columns with xyhw2xyxy, then rescale them.
            labels[:, 2:] = xyhw2xyxy(labels[:, 2:])
            labels[:, 2:] = labels[:, 2:] * box_scale
            sample_metrics += get_batch_statistics(outputs,
                                                   labels,
                                                   iou_threshold=iou_thres)

    if len(sample_metrics) == 0:
        return None

    # Concatenate the per-sample statistics column by column.
    true_positives, pred_scores, pred_labels = [
        np.concatenate(x, 0) for x in zip(*sample_metrics)
    ]
    precision, recall, AP, f1, ap_class = ap_per_class(
        true_positives, pred_scores, pred_labels, targets)

    return precision, recall, AP, f1, ap_class
Ejemplo n.º 12
0
def train_valid(in_channels,
                out_channels,
                net_name,
                lr,
                train_csv_path,
                load_train_data,
                valid_csv_path,
                load_valid_data,
                batch_size,
                resize,
                crop_offset,
                epoch_begin,
                epoch_num,
                num_classes,
                load_classes,
                anchors,
                lr_strategy,
                save_model,
                load_state_dict_path=None,
                loss_type: LossType = LossType.ce_loss,
                loss_weights=None,
                load_state_dict=None,
                **kwargs):
    """
    Train a detection network, validating and saving a checkpoint every epoch.

    :param in_channels: number of input channels, forwarded to ``create_net``
    :param out_channels: number of output channels, forwarded to ``create_net``
    :param net_name: network name, forwarded to ``create_net``
    :param lr: initial learning rate; adjusted per batch by ``ajust_learning_rate``
    :param train_csv_path: CSV data list of the training set
    :param load_train_data: training data loading function
    :param valid_csv_path: CSV data list of the validation set
    :param load_valid_data: validation data loading function
    :param batch_size: batch size
    :param resize: network input size; also passed to ``valid`` as ``img_size``
    :param crop_offset: crop offset
    :param epoch_begin: first epoch index (allows resuming)
    :param epoch_num: epoch index at which training stops (exclusive)
    :param num_classes: number of classes
    :param load_classes: callable returning the class names, indexable by class id
    :param anchors: anchors forwarded to ``get_yolo_output``
    :param lr_strategy: strategy forwarded to ``ajust_learning_rate``
    :param save_model: callable ``save_model(net, model_name)``
    :param load_state_dict_path: optional path of pretrained weights, used
        when ``load_state_dict`` is not given
    :param loss_type: loss type (not referenced by this function)
    :param loss_weights: per-class weights; accepted for interface
        compatibility but not consumed by the YOLO loss computed here
    :param load_state_dict: optional callable returning a state dict; takes
        precedence over ``load_state_dict_path``
    :param kwargs: extra keyword arguments forwarded to ``create_net``
    :return: None
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = create_net(in_channels, out_channels, net_name, **kwargs)
    net = net.to(device)
    if load_state_dict is not None:
        net.load_state_dict(load_state_dict())
    elif load_state_dict_path is not None:
        # The documented path parameter used to be silently ignored.
        net.load_state_dict(
            torch.load(load_state_dict_path, map_location=device))

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    df = pd.read_csv(train_csv_path)
    generator = detection_data_generator(load_train_data,
                                         np.array(df['image']),
                                         np.array(df['label']),
                                         batch_size,
                                         load_classes=load_classes,
                                         resize=resize,
                                         crop_offset=crop_offset)
    epoch_size = int(len(df) / batch_size)  # number of full batches per epoch
    best_net = {'mAP': 0, 'name': ''}
    for epoch in range(epoch_begin, epoch_num):
        # valid() leaves the network in eval mode, so training mode has to be
        # restored at the start of every epoch, not just once up front.
        net.train()

        print("The epoch {} start.".format(epoch))
        start = datetime.datetime.now()
        epoch_loss = 0.0
        for batch_index in range(1, epoch_size + 1):
            images, labels = next(generator)
            images = images.to(device)
            labels = labels.to(device)

            lr = ajust_learning_rate(optimizer, lr_strategy, epoch,
                                     batch_index - 1, epoch_size)

            predicts = net(images)

            yolo_outputs, loss, metrics_table = get_yolo_output(
                predicts=predicts,
                num_classes=num_classes,
                input_size=images.shape[-2:],
                anchors=anchors,
                cuda=torch.cuda.is_available(),
                labels=labels,
                ignore_threshold=0.5,
                obj_scale=1,
                noobj_scale=5,
                coord_scale=5,
                cls_scale=1)

            print("batch_index/epoch_size/epoch/lr/loss {}/{}/{}/{}/{}".format(
                batch_index, epoch_size, epoch, lr, loss))
            if metrics_table is not None and len(metrics_table) > 0:
                # One line per metrics entry: header of metric names followed
                # by the matching values, 'grid_size' being used as prefix.
                for item in metrics_table:
                    metric_keys = [key for key in item.keys()
                                   if key != 'grid_size']
                    keys_str = "/".join(metric_keys)
                    values_str = "/".join(
                        [str(item[key]) for key in metric_keys])
                    print(
                        str(item['grid_size'][0]) +
                        "/batch_index/epoch_size/epoch/output_loss/" +
                        keys_str + " {}/{}/{}/{}/".format(
                            batch_index, epoch_size, epoch, loss) + values_str)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print("The epoch {} end, epoch loss:{}, execution time:{}".format(
            epoch, epoch_loss,
            datetime.datetime.now() - start))

        valid_result = valid(net=net,
                             csv_path=valid_csv_path,
                             load_data=load_valid_data,
                             batch_size=batch_size,
                             resize=resize,
                             crop_offset=crop_offset,
                             num_classes=num_classes,
                             load_classes=load_classes,
                             anchors=anchors,
                             iou_thres=0.5,
                             conf_thres=0.5,
                             nms_thres=0.5,
                             img_size=resize)
        if valid_result is not None:
            precision, recall, ap, f1, ap_class = valid_result

            print("Average Precisions:")
            class_names = load_classes()  # loaded once instead of per class
            for i, c in enumerate(ap_class):
                print(f"+ Class '{c}' ({class_names[c]}) - AP: {ap[i]}")

            m_ap = ap.mean()
        else:
            # No validation statistics: the checkpoint is recorded with mAP 0.
            m_ap = 0

        model_name = "ckpt_%d_%.2f_%.2f.pth" % (epoch, epoch_loss, m_ap)
        if m_ap > best_net['mAP']:
            best_net['mAP'] = m_ap
            best_net['name'] = model_name

        save_model(net, model_name)
        print("The current epoch {} mAP {}.".format(epoch, m_ap))

    print("This is the best model", best_net)