示例#1
0
    def __init__(self, bn=False, mid_feat=False, channel_in=3):
        """Build the VGG16 convolutional stages ``conv1`` .. ``conv5``.

        Every stage is an ``nn.Sequential`` of project ``Conv2d`` layers
        (kernel-size argument 3, ``same_padding=True``).  ``conv1``-``conv4``
        end with ``nn.MaxPool2d(2, ceil_mode=True)``; ``conv5`` has no pooling.
        All five stages are passed to
        ``network.set_trainable(..., requires_grad=False)``.

        Args:
            bn: forwarded as ``bn=`` to every ``Conv2d`` layer.
            mid_feat: stored on ``self.mid_feat``; it is not read inside
                ``__init__`` (presumably consumed by ``forward`` - not visible
                here).
            channel_in: number of channels of the tensor fed to ``conv1``.
                Defaults to 3, the value that used to be hard-coded, so
                existing callers are unaffected.
        """
        super(VGG16, self).__init__()

        def stage(widths, pool=True):
            # One Conv2d per consecutive (in, out) pair of ``widths``, created
            # in order so Sequential child indices (and hence parameter order
            # and state_dict keys) match the original hand-written layout.
            layers = [Conv2d(c_in, c_out, 3, same_padding=True, bn=bn)
                      for c_in, c_out in zip(widths[:-1], widths[1:])]
            if pool:
                layers.append(nn.MaxPool2d(2, ceil_mode=True))
            return nn.Sequential(*layers)

        self.conv1 = stage((channel_in, 64, 64))
        self.conv2 = stage((64, 128, 128))
        # set_trainable is defined in the project's ``network`` module;
        # presumably it toggles requires_grad on the stage's parameters.
        network.set_trainable(self.conv1, requires_grad=False)
        network.set_trainable(self.conv2, requires_grad=False)

        self.conv3 = stage((128, 256, 256, 256))
        self.conv4 = stage((256, 512, 512, 512))
        self.conv5 = stage((512, 512, 512, 512), pool=False)

        network.set_trainable(self.conv3, requires_grad=False)
        network.set_trainable(self.conv4, requires_grad=False)
        network.set_trainable(self.conv5, requires_grad=False)
        self.mid_feat = mid_feat
示例#2
0
    def __init__(self, bn=False, channel_in=3):
        """Create the five convolutional stages of VGG16 (conv1 .. conv5).

        Each stage is an ``nn.Sequential`` of project ``Conv2d`` layers
        (kernel-size argument 3, ``same_padding=True``); ``conv1``-``conv4``
        finish with ``nn.MaxPool2d(2)`` and ``conv5`` has no pooling.

        Args:
            bn: passed as ``bn=`` to every ``Conv2d`` layer.
            channel_in: channel count of the input consumed by ``conv1``.

        Only ``conv1`` and ``conv2`` are handed to
        ``network.set_trainable(..., requires_grad=False)``; ``conv3``-``conv5``
        are left as constructed.
        """
        super(VGG16, self).__init__()

        def make_stage(channels, with_pool=True):
            # Layers are built left to right so the Sequential children keep
            # the numbering 0, 1, 2, ... of the hand-written version.
            modules = []
            for n_in, n_out in zip(channels, channels[1:]):
                modules.append(
                    Conv2d(n_in, n_out, 3, same_padding=True, bn=bn))
            if with_pool:
                modules.append(nn.MaxPool2d(2))
            return nn.Sequential(*modules)

        self.conv1 = make_stage((channel_in, 64, 64))
        self.conv2 = make_stage((64, 128, 128))
        # set_trainable comes from the project's ``network`` module; presumably
        # it disables gradients for the given stage's parameters.
        for frozen_stage in (self.conv1, self.conv2):
            network.set_trainable(frozen_stage, requires_grad=False)

        self.conv3 = make_stage((128, 256, 256, 256))
        self.conv4 = make_stage((256, 512, 512, 512))
        self.conv5 = make_stage((512, 512, 512, 512), with_pool=False)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
    net = RPN(not args.use_normal_anchors)
    #if args.resume_training:
    #    print 'Resume training from: {}'.format(args.res ume_model)
    #    if len(args.resume_model) == 0:
    #        raise Exception('[resume_model] not specified')
    #    network.load_net(args.resume_model, net)
    #    optimizer = torch.optim.SGD([
    #            {'params': list(net.parameters())[26:]}, 
    #            ], lr=args.lr, momentum=args.momentum, weight_decay=0.0005)
	print 'Training from scratch...Initializing network...'
	optimizer = torch.optim.SGD(list(net.parameters())[26:], lr=args.lr, momentum=args.momentum, weight_decay=0.0005)

    network.set_trainable(net.features, requires_grad=False)
    #net.cuda()

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    best_recall = np.array([0.0, 0.0])

    for epoch in range(0, args.max_epoch):
        
        # Training
        train(train_loader, net, optimizer, epoch)

        # Testing
        recall = test(test_loader, net)
        print('Epoch[{epoch:d}]: '