Example No. 1
# mobilenet_v3_large and rename_var are assumed to be imported from the
# project's model definition and checkpoint-conversion utilities (not shown here).
def main():
    # https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz
    ckpt_path = './v3-large_224_1.0_float/pristine/model.ckpt-540000'
    save_path = './pre_mobilev3.h5'
    m = mobilenet_v3_large(input_shape=(224, 224, 3), num_classes=1001, include_top=True)
    m_info = [(i.name.replace(":0", ""), list(i.shape))
              for i in m.weights]
    weights = rename_var(ckpt_path, m_info)
    m.set_weights(weights)
    m.save_weights(save_path)
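rename_var is a project-specific helper; as a rough, hypothetical sketch (the real mapping between TF-Slim checkpoint variable names and Keras weight names lives in the project's conversion utilities), it reads the checkpoint and returns tensors ordered to match model.weights:

import tensorflow as tf

def rename_var_sketch(ckpt_path, model_weight_info):
    # model_weight_info is the list of (keras_weight_name, shape) pairs built above
    reader = tf.train.load_checkpoint(ckpt_path)
    ckpt_shapes = reader.get_variable_to_shape_map()
    weights = []
    for keras_name, shape in model_weight_info:
        ckpt_name = keras_name  # placeholder: the real helper remaps Keras names to checkpoint names
        assert ckpt_name in ckpt_shapes and list(ckpt_shapes[ckpt_name]) == list(shape), ckpt_name
        weights.append(reader.get_tensor(ckpt_name))
    return weights  # ready for model.set_weights()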
Example No. 2
import os
import tensorflow as tf
from tqdm import tqdm
# generate_ds and mobilenet_v3_large are assumed to be imported from the
# project's dataset and model modules (not shown here).
def main():
    data_root = "D:/My_code/pythonProject/deep-learning-for-image-processing-master/data_set/flower_data/flower_photos"  # get data root path

    if not os.path.exists("./save_weights"):
        os.makedirs("./save_weights")

    im_height = 224
    im_width = 224
    batch_size = 16
    epochs = 20
    num_classes = 5

    # data generator with data augmentation
    train_ds, val_ds = generate_ds(data_root, im_height, im_width, batch_size)

    # create model
    model = mobilenet_v3_large(input_shape=(im_height, im_width, 3),
                               num_classes=num_classes,
                               include_top=True)

    # load weights
    pre_weights_path = './weights_mobilenet_v3_large_224_1.0_float.h5'
    assert os.path.exists(pre_weights_path), "cannot find {}".format(
        pre_weights_path)
    model.load_weights(pre_weights_path, by_name=True, skip_mismatch=True)

    # freeze layers: train only the last two layers
    for layer in model.layers:
        if layer.name not in ["Conv_2", "Logits/Conv2d_1c_1x1"]:
            layer.trainable = False
        else:
            print("training: " + layer.name)

    model.summary()

    # use the Keras low-level API for the training loop
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')

    val_loss = tf.keras.metrics.Mean(name='val_loss')
    val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='val_accuracy')

    @tf.function
    def train_step(train_images, train_labels):
        with tf.GradientTape() as tape:
            output = model(train_images, training=True)
            loss = loss_object(train_labels, output)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss(loss)
        train_accuracy(train_labels, output)

    @tf.function
    def val_step(val_images, val_labels):
        output = model(val_images, training=False)
        loss = loss_object(val_labels, output)

        val_loss(loss)
        val_accuracy(val_labels, output)

    best_test_loss = float('inf')
    for epoch in range(epochs):
        # reset metric history at the start of every epoch
        train_loss.reset_states()
        train_accuracy.reset_states()
        val_loss.reset_states()
        val_accuracy.reset_states()

        # train
        train_bar = tqdm(train_ds)
        for images, labels in train_bar:
            train_step(images, labels)

            # print train process
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(
                epoch + 1, epochs, train_loss.result(),
                train_accuracy.result())

        # validate
        val_bar = tqdm(val_ds, colour='green')
        for images, labels in val_bar:
            val_step(images, labels)

            # print val process
            val_bar.desc = "valid epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(
                epoch + 1, epochs, val_loss.result(), val_accuracy.result())

        if val_loss.result() < best_test_loss:
            best_test_loss = val_loss.result()
            model.save_weights("./save_weights/resMobileNetV3.ckpt",
                               save_format="tf")
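generate_ds is likewise a project helper; a minimal, hypothetical stand-in built on the public Keras API (the split ratio and seed below are assumptions) could look like this:

import tensorflow as tf

def generate_ds_sketch(data_root, im_height, im_width, batch_size, val_split=0.2):
    # build training/validation tf.data pipelines from a directory of per-class sub-folders
    common = dict(validation_split=val_split, seed=123,
                  image_size=(im_height, im_width), batch_size=batch_size)
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_root, subset="training", **common)
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_root, subset="validation", **common)
    # the real helper additionally shuffles, normalizes and augments the images
    return train_ds.prefetch(tf.data.AUTOTUNE), val_ds.prefetch(tf.data.AUTOTUNE)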
Example No. 3
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
# mobilenet_v3_large is assumed to be imported from the project's model definition (not shown here).
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    batch_size = 16
    epochs = 5

    data_transform = {
        "train":
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "val":
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }

    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    data_root = os.path.abspath(os.getcwd())  # get data root path
    image_path = os.path.join(data_root, "data_set",
                              "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(
        image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(
        image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0,
              8])  # number of dataloader workers
    print('Using {} dataloader workers per process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(
        image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(
        train_num, val_num))

    # create model: instantiate the network with 5 output classes
    # net = MobileNetV2(num_classes=5)
    net = mobilenet_v3_large(num_classes=5)

    # load pretrained weights
    # download url: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
    # model_weight_path = "./mobilenet_v2.pth"
    # model_weight_path = ".\\pytorch_classification\\Test6_mobilenet\\mobilenet_v2.pth"
    model_weight_path = ".\\pytorch_classification\\Test6_mobilenet\\mobilenet_v3_large.pth"

    assert os.path.exists(model_weight_path), "file {} does not exist.".format(
        model_weight_path)
    pre_weights = torch.load(model_weight_path, map_location=device)
    # torch.load returns the pretrained parameters as a state dict (a dict of tensors)

    # The pretrained model was trained on ImageNet with 1000 output classes,
    # while this task has 5 classes, so the classifier weights cannot be reused.
    # delete classifier weights
    # pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}  # later variant: keep only tensors whose element counts match
    pre_dict = {k: v
                for k, v in pre_weights.items()
                if "classifier" not in k}  # original variant: drop classifier keys
    # keys containing "classifier" belong to the final fully connected layers

    missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)
    # strict=False lets load_state_dict skip the missing classifier keys

    # freeze all weights of the feature-extraction backbone
    # (comment out this loop to fine-tune the whole network instead)
    for param in net.features.parameters():
        # net.parameters() would freeze everything; net.features.parameters()
        # covers only the backbone, leaving the classifier trainable
        param.requires_grad = False  # no gradients are computed or updates applied
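    # optional sanity check: list the parameters that will still be updated;
    # with the backbone frozen, typically only the classifier weights remain
    trainable_names = [name for name, p in net.named_parameters() if p.requires_grad]
    print("trainable parameters:", trainable_names)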

    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    best_acc = 0.0
    save_path = './MobileNetV2.pth'  # change to your own save path (the network trained here is MobileNetV3-large)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(
                epoch + 1, epochs, loss)

        # validate
        net.eval()
        acc = 0.0  # accumulated number of correct predictions this epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')
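To reuse the best checkpoint saved above, here is a minimal inference sketch (assuming the same mobilenet_v3_large definition, the class_indices.json written during training, and the save_path used above; predict_sketch and its arguments are hypothetical names):

from PIL import Image

def predict_sketch(image_path, weights_path='./MobileNetV2.pth'):
    # weights_path matches save_path in the training script above
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with open('class_indices.json', 'r') as f:
        idx_to_class = json.load(f)  # e.g. {"0": "daisy", "1": "dandelion", ...}
    net = mobilenet_v3_large(num_classes=5)
    net.load_state_dict(torch.load(weights_path, map_location=device))
    net.to(device)
    net.eval()
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    img = transform(Image.open(image_path).convert('RGB')).unsqueeze(0)  # shape [1, 3, 224, 224]
    with torch.no_grad():
        prob = torch.softmax(net(img.to(device)), dim=1).squeeze(0)
    cls = int(torch.argmax(prob))
    return idx_to_class[str(cls)], prob[cls].item()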