def test_network(args, network=None, data_set=None):
    """Evaluate a network on the evaluation split and return its scores.

    Args:
        args: configuration namespace. Attributes read here:
            - ``gpu_no``: ``>= 0`` selects CUDA.
            - ``net``, ``data_set`` and ``load_path``: only when ``network``
              is None.
            ``get_data_set`` may read further attributes.
        network: model to evaluate. When None, it is built from ``args.net``
            (``resnet()`` for ``'resnet50'``, otherwise
            ``VGG(args.net, args.data_set)``). If ``args.load_path`` is set,
            it is then initialised from that checkpoint's ``'state_dict'``.
        data_set: dataset to evaluate on. Defaults to
            ``get_data_set(args, train_flag=False)``.

    Returns:
        ``(network, data_set, (top1, top5))``: the network and dataset that
        were used, plus the two values returned by ``test_step``.
    """
    device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")

    if network is None:
        if args.net == 'resnet50':
            network = resnet()
        else:
            network = VGG(args.net, args.data_set)
        if args.load_path:
            # Load onto the CPU first. This lets a GPU-saved checkpoint be
            # restored in a CPU-only run; load_state_dict copies the tensors
            # into the (still CPU-resident) parameters.
            check_point = torch.load(args.load_path, map_location="cpu")
            network.load_state_dict(check_point['state_dict'])
    network.to(device)

    if data_set is None:
        data_set = get_data_set(args, train_flag=False)
    # Deterministic order with one sample per batch, as in the original
    # evaluation setup.
    data_loader = torch.utils.data.DataLoader(data_set,
                                              batch_size=1,
                                              shuffle=False)

    top1, top5 = test_step(network, data_loader, device)

    return network, data_set, (top1, top5)
# Example #2
# 0
def prune_network(args, network=None):
    """Prune a VGG network's channels and optionally retrain it.

    NOTE(review): if this lives in the same module as another
    ``prune_network`` definition further down, that later one shadows this
    function. Confirm which one is intended.

    Args:
        args: configuration namespace. Attributes read here:
            - ``gpu_no``, ``prune_layers``, ``prune_channels``,
              ``independent_prune_flag`` and ``retrain_flag``.
            - ``vgg``, ``data_set`` and ``load_path``: only when ``network``
              is None.
            - ``retrain_epoch`` and ``retrain_lr``: only when retraining.
        network: model to prune. When None, a ``VGG(args.vgg, args.data_set)``
            is built and, if ``args.load_path`` is set, initialised from that
            checkpoint's ``'state_dict'``.

    Returns:
        The pruned network, retrained via ``train_network`` when
        ``args.retrain_flag`` is set.

    Side effects:
        When retraining, overwrites ``args.epoch``, ``args.lr`` and
        ``args.lr_milestone``.
    """
    device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")

    if network is None:
        network = VGG(args.vgg, args.data_set)
        if args.load_path:
            # Map to the CPU so GPU-saved checkpoints also load in CPU-only runs.
            check_point = torch.load(args.load_path, map_location="cpu")
            network.load_state_dict(check_point['state_dict'])

    # prune network
    network = prune_step(network, args.prune_layers, args.prune_channels,
                         args.independent_prune_flag)
    network = network.to(device)
    print("-*-" * 10 + "\n\tPrune network\n" + "-*-" * 10)
    print(network)

    if args.retrain_flag:
        # update arguments for retraining the pruned network
        args.epoch = args.retrain_epoch
        args.lr = args.retrain_lr
        args.lr_milestone = None  # don't decay learning rate

        network = train_network(args, network)

    return network
def train_network(args, network=None, data_set=None):
    """Train ``network`` for epochs ``[args.start_epoch, args.epoch)``.

    A checkpoint is saved after every epoch.

    Args:
        args: configuration namespace. Attributes read here:
            - ``gpu_no``, ``resume_flag``, ``start_epoch``, ``epoch``,
              ``batch_size``, ``print_freq`` and ``save_path``.
            - ``vgg`` and ``data_set``: only when ``network`` is None.
            - ``load_path``: only when resuming.
            ``get_optimizer`` and ``get_data_set`` may read further attributes.
        network: model to train. Defaults to ``VGG(args.vgg, args.data_set)``.
        data_set: training dataset. Defaults to
            ``get_data_set(args, train_flag=True)``.

    Returns:
        The trained network, on the selected device.

    Side effects:
        When resuming, overwrites ``args.start_epoch`` with the checkpoint's
        epoch. After every epoch, writes ``args.save_path + "check_point.pth"``.
    """
    device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")
    print("1. Finish check device: ", device)

    if network is None:
        network = VGG(args.vgg, args.data_set)
    network = network.to(device)
    print("2. Finish create network")

    if data_set is None:
        data_set = get_data_set(args, train_flag=True)
    print("3. Finish load dataset")

    loss_calculator = Loss_Calculator()

    optimizer, scheduler = get_optimizer(network, args)

    if args.resume_flag:
        # Map to the CPU so GPU-saved checkpoints also load in CPU-only runs.
        # load_state_dict copies the values onto the model's current device.
        check_point = torch.load(args.load_path, map_location="cpu")
        network.load_state_dict(check_point['state_dict'])
        loss_calculator.loss_seq = check_point['loss_seq']
        args.start_epoch = check_point['epoch']  # update start epoch
        # Restore optimizer/scheduler state so that momentum buffers and the
        # LR schedule continue instead of restarting. Older checkpoints did
        # not store these keys, so they are optional.
        if check_point.get('optimizer') is not None:
            optimizer.load_state_dict(check_point['optimizer'])
        if scheduler is not None and check_point.get('scheduler') is not None:
            scheduler.load_state_dict(check_point['scheduler'])

    print("-*-" * 10 + "\n\tTrain network\n" + "-*-" * 10)

    # One shuffling loader is enough: with shuffle=True, each new iteration
    # over the loader draws a fresh permutation, so every epoch is reshuffled.
    data_loader = torch.utils.data.DataLoader(data_set,
                                              batch_size=args.batch_size,
                                              shuffle=True)

    for epoch in range(args.start_epoch, args.epoch):
        # train one epoch
        train_step(network, data_loader, loss_calculator, optimizer, device,
                   epoch, args.print_freq)

        # adjust learning rate
        if scheduler is not None:
            scheduler.step()

        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': network.state_dict(),
                'loss_seq': loss_calculator.loss_seq,
                'optimizer': optimizer.state_dict(),
                'scheduler': (scheduler.state_dict()
                              if scheduler is not None else None),
            }, args.save_path + "check_point.pth")

    return network
def prune_network(args, network=None, resnet_prune_layer=1):
    """Prune a ResNet-50 or VGG network and optionally retrain it.

    Args:
        args: configuration namespace. Attributes read here:
            - ``gpu_no``, ``net``, ``prune_layers``,
              ``independent_prune_flag`` and ``retrain_flag``.
            - ``data_set`` and ``load_path``: only when ``network`` is None.
            - ``prune_channels``: VGG only.
            - ``retrain_epoch`` and ``retrain_lr``: only when retraining.
        network: model to prune. When None, it is built from ``args.net``
            (``resnet()`` for ``'resnet50'``, otherwise
            ``VGG(args.net, args.data_set)``). If ``args.load_path`` is set,
            it is then initialised from that checkpoint's ``'state_dict'``.
        resnet_prune_layer: which ResNet pruning routine to apply when
            ``args.net == 'resnet50'``: 1, 2 or 3 selects ``prune_resnet_1``,
            ``_2`` or ``_3``. Defaults to 1, the previously hard-coded value.
            It is ignored for VGG.

    Returns:
        The pruned network, retrained via ``train_network`` when
        ``args.retrain_flag`` is set.

    Raises:
        ValueError: if ``args.net == 'resnet50'`` and ``resnet_prune_layer``
            is not 1, 2 or 3. Previously such a value silently skipped
            pruning.

    Side effects:
        When retraining, overwrites ``args.epoch``, ``args.lr`` and
        ``args.lr_milestone``.
    """
    device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")

    if network is None:
        if args.net == 'resnet50':
            network = resnet()
        else:
            network = VGG(args.net, args.data_set)
        if args.load_path:
            # Map to the CPU so GPU-saved checkpoints also load in CPU-only runs.
            check_point = torch.load(args.load_path, map_location="cpu")
            network.load_state_dict(check_point['state_dict'])

    # prune network
    if args.net == 'resnet50':
        # Explicit if/elif (rather than a lookup table) so that only the
        # selected prune_resnet_* function is ever referenced.
        if resnet_prune_layer == 1:
            network = prune_resnet_1(network, args.prune_layers, args.independent_prune_flag)
        elif resnet_prune_layer == 2:
            network = prune_resnet_2(network, args.prune_layers, args.independent_prune_flag)
        elif resnet_prune_layer == 3:
            network = prune_resnet_3(network, args.prune_layers, args.independent_prune_flag)
        else:
            raise ValueError(
                "resnet_prune_layer must be 1, 2 or 3, got {!r}".format(resnet_prune_layer))
    else:
        network = prune_step(network, args.prune_layers, args.prune_channels,
                             args.independent_prune_flag)
    network = network.to(device)
    print("-*-"*10 + "\n\tPrune network\n" + "-*-"*10)
    print(network)

    if args.retrain_flag:
        # update arguments for retraining pruned network
        args.epoch = args.retrain_epoch
        args.lr = args.retrain_lr
        args.lr_milestone = None  # don't decay learning rate

        network = train_network(args, network)

    return network