def main():
    #数据集加载
    dataset = Market1501()

    #训练数据处理器
    transform_train = T.Compose([
        T.Random2DTransform(height, width),  #尺度统一,随机裁剪
        T.RandomHorizontalFlip(),  #水平翻转
        T.ToTensor(),  #图片转张量
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224,
                                                     0.225]),  #归一化,参数固定
    ])

    #测试数据处理器
    transform_test = T.Compose([
        T.Resize((height, width)),  #尺度统一
        T.ToTensor(),  #图片转张量
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224,
                                                     0.225]),  #归一化,参数固定
    ])

    #train数据集吞吐器
    train_data_loader = DataLoader(
        ImageDataset(dataset.train,
                     transform=transform_train),  #自定义的数据集,使用训练数据处理器
        batch_size=train_batch_size,  #一个批次的大小(一个批次有多少个图片张量)
        drop_last=True,  #丢弃最后无法称为一整个批次的数据
    )
    print("train_data_loader inited")

    #query数据集吞吐器
    query_data_loader = DataLoader(
        ImageDataset(dataset.query,
                     transform=transform_test),  #自定义的数据集,使用测试数据处理器
        batch_size=test_batch_size,  #一个批次的大小(一个批次有多少个图片张量)
        shuffle=False,  #不重排
        drop_last=True,  #丢弃最后无法称为一整个批次的数据
    )
    print("query_data_loader inited")

    #gallery数据集吞吐器
    gallery_data_loader = DataLoader(
        ImageDataset(dataset.gallery,
                     transform=transform_test),  #自定义的数据集,使用测试数据处理器
        batch_size=test_batch_size,  #一个批次的大小(一个批次有多少个图片张量)
        shuffle=False,  #不重排
        drop_last=True,  #丢弃最后无法称为一整个批次的数据
    )
    print("gallery_data_loader inited\n")

    #加载模型
    model = ReIDNet(num_classes=751,
                    loss={'softmax'})  #指定分类的数量,与使用的损失函数以便决定模型输出何种计算结果
    print("=>ReIDNet loaded")
    print("Model size: {:.5f}M\n".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    #损失函数
    criterion_class = nn.CrossEntropyLoss()
    """
    优化器
    参数1,待优化的参数
    参数2,学习率
    参数3,权重衰减
    """
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=train_lr,
                                weight_decay=5e-04)
    """
    动态学习率
    参数1,指定使用的优化器
    参数2,mode,可选择‘min’(min表示当监控量停止下降的时候,学习率将减小)或者‘max’(max表示当监控量停止上升的时候,学习率将减小)
    参数3,factor,代表学习率每次降低多少
    参数4,patience,容忍网路的性能不提升的次数,高于这个次数就降低学习率
    参数5,min_lr,学习率的下限
    """
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               mode='min',
                                               factor=dy_step_gamma,
                                               patience=10,
                                               min_lr=0.0001)

    #如果是测试
    if evaluate:
        test(model, query_data_loader, gallery_data_loader)
        return 0
    #如果是训练
    print('————model start training————\n')
    bt = time.time()  #训练的开始时间
    for epoch in range(start_epoch, end_epoch):
        model.train(True)
        train(epoch, model, criterion_class, optimizer, scheduler,
              train_data_loader)
    et = time.time()  #训练的结束时间
    print('**模型训练结束, 保存最终参数到{}**\n'.format(final_model_path))
    torch.save(model.state_dict(), final_model_path)
    print('————训练总用时{:.2f}小时————'.format((et - bt) / 3600.0))
Пример #2
0
            x_maxrange = new_width - self.width
            y_maxrange = new_height - self.height
            # 计算随机裁剪XY轴起点
            x_start = int(round(random.uniform(0, x_maxrange)))
            y_start = int(round(random.uniform(0, y_maxrange)))
            # 进行裁剪
            img = resize_img.crop((x_start, y_start, x_start + self.width,
                                   y_start + self.height))
        return img


if __name__ == '__main__':
    from dataset_manager import Market1501
    from dataset_loader import ImageDataset

    dataset = Market1501()
    train_loader = ImageDataset(dataset.train)
    plt.figure()
    j = 1
    # 从训练集中获取前两张图片进行处理,并使用matplot显示图片
    for batch_id, (img, pid, cid) in enumerate(train_loader):
        if (batch_id < 2):
            transform = Random2DTransform(64, 64, 0.5)
            img_t = transform(img)
            img_t = np.array(img_t)
            plt.subplot(1, 2, j)
            plt.imshow(img)  # 显示图片
            plt.savefig()
            j = j + 1
            plt.subplot(1, 2, j)
            plt.imshow(img_t)  # 显示图片