コード例 #1
0
def prune_network(args, network=None):
    """Channel-prune a VGG network and, if requested, fine-tune the result.

    When ``network`` is None, a fresh VGG is constructed from ``args`` and
    its weights are restored from ``args.load_path`` if one is given.

    Args:
        args: parsed command-line arguments (prune/retrain settings).
        network: an already-built model to prune, or None to build one.

    Returns:
        The pruned (and possibly retrained) network, moved to the target device.
    """
    device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")

    if network is None:
        network = VGG(args.vgg, args.data_set)
        if args.load_path:
            ckpt = torch.load(args.load_path)
            network.load_state_dict(ckpt['state_dict'])

    # Drop the requested channels from the requested layers.
    network = prune_step(network, args.prune_layers, args.prune_channels,
                         args.independent_prune_flag)
    network = network.to(device)
    banner = "-*-" * 10
    print(banner + "\n\tPrune network\n" + banner)
    print(network)

    if args.retrain_flag:
        # Repurpose the training arguments for the fine-tuning run; the LR
        # milestone schedule is disabled so the rate stays constant.
        args.epoch = args.retrain_epoch
        args.lr = args.retrain_lr
        args.lr_milestone = None

        network = train_network(args, network)

    return network
コード例 #2
0
def main():
    """Run neural style transfer by optimizing the input image with L-BFGS."""
    parser = make_parser()
    args = parser.parse_args()
    use_gpu = torch.cuda.is_available()

    content_image, style_image, input_image = read_images(args, use_gpu)

    # MODEL
    vgg = VGG()
    loss = Loss()
    if use_gpu:
        # Move the already-built modules to the GPU. (Previously a second
        # VGG()/Loss() was constructed here, discarding the first instances.)
        vgg = vgg.cuda()
        loss = loss.cuda()
    for param in vgg.parameters():
        param.requires_grad = False  # VGG is a frozen feature extractor

    # OPTIMIZER — L-BFGS optimizes the pixels of input_image directly.
    learning_rate = args.lr
    optimizer = optim.LBFGS([input_image], lr=learning_rate)
    num_iterations = args.iter
    losses = []

    # Target features are fixed, so compute them once up front.
    content_3_2 = vgg(content_image, ["3_2"])[0]
    style_features = vgg(style_image, ["1_1", "2_1", "3_1", "4_1", "5_1"])

    def closure():
        """Re-evaluate the style/content loss; required by LBFGS.step()."""
        optimizer.zero_grad()

        input_features = vgg(input_image,
                             ["1_1", "2_1", "3_1", "4_1", "5_1", "3_2"])
        input_3_2 = input_features[-1]
        input_features = input_features[:-1]

        total_loss = loss(input_features, input_3_2, content_3_2,
                          style_features)
        # FIX: .item() extracts the scalar loss; the old
        # `.data.cpu().numpy()[0]` indexes a 0-d array, which raises an
        # IndexError on modern PyTorch/NumPy.
        losses.append(total_loss.item())
        total_loss.backward()
        input_image.data.clamp_(0, 1)  # keep pixel values in a valid range

        return total_loss

    for i in range(num_iterations):
        optimizer.step(closure)

        if i % 3 == 0:
            print(i / num_iterations * 100, "%")
    print("100.0 %")
    graph_losses(losses)

    # Convert the optimized CHW float tensor to an HWC uint8 image and save.
    output = Image.fromarray((input_image.data.squeeze() * 255).permute(
        1, 2, 0).cpu().numpy().astype(np.uint8))
    output.save(args.output)
    show_image(input_image)
コード例 #3
0
def train_vgg(batch_size, epoch):
    """Train a VGG11 classifier, saving weights and printing metrics per epoch.

    Args:
        batch_size: number of samples per training batch.
        epoch: total number of epochs to train.
    """
    dataLoader = DataLoader()
    # Record per-epoch loss and accuracy for plotting afterwards.
    train_loss_results = []
    train_accuracy_results = []
    # Build the VGG network (build_vgg supports 'vgg11' / 'vgg16' / 'vgg19').
    model = VGG.build_vgg('vgg11')
    # Optimizer.
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.8)
    # Loss function and running metrics.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    # Batches per epoch — hoisted out of the epoch loop since the training-set
    # size does not change between epochs (previously recomputed every epoch).
    num_iter = dataLoader.num_train // batch_size
    for e in range(epoch):
        # Reset the metrics at the start of the next epoch.
        train_loss.reset_states()
        train_accuracy.reset_states()
        for i in range(num_iter):
            images, labels = dataLoader.get_batch_train(batch_size)
            with tf.GradientTape() as tape:
                preds = model(images, training=True)  # forward pass
                loss = loss_object(labels, preds)  # batch loss
                # loss += sum(model.losses)  # add regularization losses if any
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            train_loss(loss)  # accumulate running loss
            train_accuracy(labels, preds)  # accumulate running accuracy
        train_loss_results.append(train_loss.result())
        train_accuracy_results.append(train_accuracy.result())
        # Checkpoint the weights after every epoch.
        model.save_weights("./weight/" + str(e + 1) + "_epoch_vgg11_weight.h5")
        print('Epoch {}, loss:{}, Accuracy:{}%'.format(e + 1, train_loss.result(), train_accuracy.result() * 100))
    show_loss_plot(train_loss_results, train_accuracy_results)
コード例 #4
0
def test_vgg(model_path, batch_size):
    """Evaluate a saved VGG11 model on a batch of test data.

    Args:
        model_path: path to the saved .h5 weight file.
        batch_size: number of test samples to evaluate on.
    """
    dataLoader = DataLoader()
    # Build the VGG network (build_vgg supports 'vgg11' / 'vgg16' / 'vgg19').
    model = VGG.build_vgg('vgg11')
    model.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    # Build with a concrete input shape so the saved weights can be loaded.
    model.build((1, 32, 32, 1))
    model.load_weights(model_path)
    # Show the model architecture. FIX: summary() prints itself and returns
    # None, so the old `print(model.summary())` emitted a spurious "None" line.
    model.summary()
    # Evaluate the model on a batch of held-out test data.
    test_images, test_labels = dataLoader.get_batch_test(batch_size)
    model.evaluate(test_images, test_labels, verbose=2)
コード例 #5
0
def test_network(args, network=None, data_set=None):
    """Evaluate a network on the test split and report top-1/top-5 accuracy.

    Args:
        args: parsed command-line arguments (net choice, device, load path).
        network: an already-built model, or None to build one from ``args``.
        data_set: a test dataset, or None to load it from ``args``.

    Returns:
        (network, data_set, (top1, top5)) so callers can reuse the loaded
        model and dataset.
    """
    device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")

    if network is None:
        # Build the requested architecture, then optionally restore weights.
        # (The checkpoint-loading code was previously duplicated per branch.)
        if args.net == 'resnet50':
            network = resnet()
        else:
            network = VGG(args.net, args.data_set)
        if args.load_path:
            check_point = torch.load(args.load_path)
            network.load_state_dict(check_point['state_dict'])
    network.to(device)

    if data_set is None:
        data_set = get_data_set(args, train_flag=False)
    # batch_size=1, shuffle=False: evaluate one sample at a time in order.
    data_loader = torch.utils.data.DataLoader(data_set,
                                              batch_size=1,
                                              shuffle=False)

    top1, top5 = test_step(network, data_loader, device)

    return network, data_set, (top1, top5)
コード例 #6
0
def prune_network(args, network=None):
    """Prune a ResNet-50 or VGG network and optionally fine-tune the result.

    Args:
        args: parsed command-line arguments (net choice, prune/retrain settings).
        network: an already-built model to prune, or None to build one.

    Returns:
        The pruned (and possibly retrained) network, moved to the target device.
    """
    resnet_prune_layer = 1  # which of the three ResNet pruning strategies to use
    device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")

    if network is None:
        # Build the requested architecture, then optionally restore weights.
        # (The checkpoint-loading code was previously duplicated per branch.)
        if args.net == 'resnet50':
            network = resnet()
        else:
            network = VGG(args.net, args.data_set)
        if args.load_path:
            check_point = torch.load(args.load_path)
            network.load_state_dict(check_point['state_dict'])

    # prune network
    if args.net == 'resnet50':
        # Dispatch on the strategy index instead of a chain of independent ifs.
        prune_fn = {1: prune_resnet_1,
                    2: prune_resnet_2,
                    3: prune_resnet_3}[resnet_prune_layer]
        network = prune_fn(network, args.prune_layers, args.independent_prune_flag)
    else:
        network = prune_step(network, args.prune_layers, args.prune_channels,
                             args.independent_prune_flag)
    network = network.to(device)
    print("-*-" * 10 + "\n\tPrune network\n" + "-*-" * 10)
    print(network)

    if args.retrain_flag:
        # Reconfigure the training arguments for fine-tuning the pruned network.
        args.epoch = args.retrain_epoch
        args.lr = args.retrain_lr
        args.lr_milestone = None  # don't decay learning rate

        network = train_network(args, network)

    return network
コード例 #7
0
def train_network(args, network=None, data_set=None):
    """Train (or resume training of) a VGG network, checkpointing each epoch.

    Args:
        args: parsed command-line arguments (epochs, batch size, paths, ...).
        network: an already-built model, or None to build one from ``args``.
        data_set: a training dataset, or None to load it from ``args``.

    Returns:
        The trained network.
    """
    device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")
    print("1. Finish check device: ", device)

    if network is None:
        network = VGG(args.vgg, args.data_set)
    network = network.to(device)
    print("2. Finish create network")

    if data_set is None:
        data_set = get_data_set(args, train_flag=True)
    print("3. Finish load dataset")

    loss_calculator = Loss_Calculator()

    optimizer, scheduler = get_optimizer(network, args)

    if args.resume_flag:
        # Restore weights, loss history, and the epoch counter from checkpoint.
        check_point = torch.load(args.load_path)
        network.load_state_dict(check_point['state_dict'])
        loss_calculator.loss_seq = check_point['loss_seq']
        args.start_epoch = check_point['epoch']  # update start epoch

    # A DataLoader with shuffle=True reshuffles at the start of every epoch,
    # so one loader suffices (previously a new one was rebuilt per epoch).
    data_loader = torch.utils.data.DataLoader(data_set,
                                              batch_size=args.batch_size,
                                              shuffle=True)

    print("-*-" * 10 + "\n\tTrain network\n" + "-*-" * 10)
    for epoch in range(args.start_epoch, args.epoch):
        # train one epoch
        train_step(network, data_loader, loss_calculator, optimizer, device,
                   epoch, args.print_freq)

        # adjust learning rate
        if scheduler is not None:
            scheduler.step()

        # Checkpoint after every epoch so training can be resumed.
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': network.state_dict(),
                'loss_seq': loss_calculator.loss_seq
            }, args.save_path + "check_point.pth")

    return network