class EIG_classifier_early_layers(nn.Module):
    """Expose intermediate activations of a pretrained EIG_classifier.

    Loads the EIG_classifier checkpoint for the given imageset and slices its
    vision module into four sequential stages so that the feature map after
    each stage can be read out by ``forward``.

    FIX: the original class contained a second, byte-identical ``__init__``
    pasted after ``forward`` (dead code that silently redefined the method);
    the duplicate has been removed.
    """

    def __init__(self, imageset='fiv'):
        """Load the checkpoint for `imageset` ('fiv' or 'bfm') and split the
        vision module's children into four conv stages."""
        super(EIG_classifier_early_layers, self).__init__()

        self.eig_classifier = EIG_classifier()
        self.eig_classifier.load_state_dict(
            torch.load(
                os.path.join(CONFIG['PATHS', 'checkpoints'], 'eig_classifier',
                             'checkpoint_' + imageset +
                             '.pth.tar'))['state_dict'])

        self.vision_module = self.eig_classifier.vision_module
        # Slice the vision module's children once instead of re-listing them
        # for every stage (behavior unchanged, same index boundaries).
        children = list(self.vision_module.children())
        self.vision_module_conv1 = nn.Sequential(*children[:4])
        self.vision_module_conv2 = nn.Sequential(*children[4:8])
        self.vision_module_conv3 = nn.Sequential(*children[8:10])
        self.vision_module_conv4 = nn.Sequential(*children[10:12])

    def forward(self, x, segment=False, add_offset=False):
        """Run `x` through the four conv stages and return the flattened
        activations after each stage.

        Parameters
        ----------
        x : tensor
            Input image batch on the GPU, in 0..255 range (a per-channel mean
            is subtracted below before the conv stages).
        segment : bool
            If True, first segment the (upsampled) input with the classifier's
            segmentation head and feed the segmented image instead.
        add_offset : bool
            Forwarded to ``process_segmentation``.

        Returns
        -------
        list of numpy.ndarray
            Four flattened per-stage activation vectors (first batch element),
            detached and moved to CPU.
        """
        if segment:
            x = nn.Upsample(size=(192, 192),
                            mode='bilinear',
                            align_corners=True)(x)
            segment_vols = self.eig_classifier.segmentation(x / 255.)[0]
            segmented = process_segmentation(segment_vols.detach(), x / 255,
                                             add_offset)
            x = segmented * 255

        # Per-channel pixel means subtracted from the raw 0..255 input.
        # NOTE(review): ordering suggests BGR-style dataset means — confirm
        # against the training pipeline.
        mean = torch.FloatTensor(
            [104.0510072177276, 112.51448910834733,
             116.67603893449996]).cuda()

        out_interim = []
        x = x - mean.view(1, -1, 1, 1)
        x = self.vision_module_conv1(x)
        out_interim.append(x.detach()[0].cpu().numpy().flatten())
        x = self.vision_module_conv2(x)
        out_interim.append(x.detach()[0].cpu().numpy().flatten())
        x = self.vision_module_conv3(x)
        out_interim.append(x.detach()[0].cpu().numpy().flatten())
        x = self.vision_module_conv4(x)
        out_interim.append(x.detach()[0].cpu().numpy().flatten())

        return out_interim
def main():
    """Fine-tune the EIG network's f2 and train f3 to obtain EIG_CLASSIFIER.

    Parses command-line options, builds the model on the GPU, optionally
    resumes from a checkpoint, selects the dataset(s) for the chosen
    imageset, and runs the training loop, checkpointing after every epoch.
    """
    # FIX: description said "Fine the" — corrected to "Fine-tune the".
    parser = argparse.ArgumentParser(description='Fine-tune the EIG networks f2 and train f3 to obtain EIG_CLASSIFIER')
    parser.add_argument('--out', type=str, default='', metavar='PATH',
                        help='Directory to output the result if other than what is specified in the config')
    parser.add_argument('--imageset', default='bfm', type=str,
                        help='Train with BFM (bfm) images or FIV (fiv) images?')
    parser.add_argument("--image", "-is", type=int, nargs="+",
                        default=(3, 227, 227), help="Image size. Def: (3, 227, 227)")
    parser.add_argument("--num-classes", "-nc", type=int, metavar="N",
                        default=25, help="Number of unique individual identities. Default: 25.")
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=20, type=int,
                        metavar='N', help='mini-batch size (default: 20)')
    parser.add_argument('--lr', '--learning-rate', default=0.0005, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--print-freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')

    # `args` is read by other functions in this module (e.g. train), so it is
    # kept module-global as in the original.
    global args
    args = parser.parse_args()

    assert args.imageset in ['bfm', 'fiv'], 'set imageset to either bfm or fiv; e.g., --imageset fiv'

    # create model using the pretrained alexnet.
    print("=> Construct the model...")

    model = EIG_classifier()
    model.cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # finetune SFCL and the new identity layer only; the rest of the network
    # keeps its pretrained weights.
    optimizer = torch.optim.SGD([
        {'params': model.fc_layers.parameters(), 'lr': 0.0005},
        {'params': model.classifier.parameters(), 'lr': 0.0005}
    ])

    if args.out == '':
        out_path = os.path.join(CONFIG['PATHS', 'checkpoints'], 'eig_classifier', datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
        os.makedirs(out_path)
    else:
        out_path = args.out

    # FIX: only attempt to resume when a checkpoint path was actually given.
    # The original unconditionally called torch.load(args.resume) and crashed
    # whenever --resume was left at its default ''.
    if args.resume:
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))

    print("Current run location: {}".format(out_path))

    # Initialize the datasets.
    if args.imageset == 'fiv':
        train_loader = datasets.BFMId(os.path.join(CONFIG['PATHS', 'databases'], 'FIV_segment_bootstrap.hdf5'), raw_image=True, input_shape=args.image)
        epochs = args.start_epoch + 20
    else:  # FIX: this `else` was mis-indented (5 spaces) — a SyntaxError.
        d_full = datasets.BFM09(os.path.join(CONFIG['PATHS', 'databases'], 'bfm09_backface_culling_id_ft.hdf5'), raw_image=True, input_shape=args.image)
        # FIX: config key was misspelled 'databses' in the original.
        d_segment = datasets.BFM09(os.path.join(CONFIG['PATHS', 'databases'], 'bfm09_backface_culling_id_ft_segment.hdf5'), raw_image=True, input_shape=args.image)
        train_loader = (d_full, d_segment)
        epochs = args.start_epoch + 2

    for epoch in range(args.start_epoch, epochs):

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # save a checkpoint after every epoch
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, False, os.path.join(out_path, 'checkpoint_' + args.imageset + '.pth.tar'))
# ---- Beispiel #4 ----
def load_image(image, size):
    """Resize `image` and return it as a 3-channel, CHW float32 array.

    Grayscale (2-D) inputs are replicated into three identical channels;
    inputs with more than three channels (e.g. RGBA) are truncated to the
    first three. The result has shape (3, H, W) and dtype float32.
    """
    resized = np.asarray(image.resize(size))
    if resized.ndim == 2:
        # grayscale: replicate the single plane across three channels
        resized = np.stack((resized,) * 3, axis=-1)
    else:
        # keep only the first three channels (drops e.g. alpha)
        resized = resized[..., :3]
    # HWC -> CHW
    return np.moveaxis(resized, 2, 0).astype(np.float32)


# Registry of available models, keyed by short model name.
# NOTE(review): 'vgg' maps to a factory (a lambda taking one argument `d`)
# while 'eig' / 'eig_classifier' are already-constructed instances — callers
# must handle both forms; confirm this asymmetry is intentional.
models_d = {
    'eig': EIG(),
    'eig_classifier': EIG_classifier(),
    'vgg': lambda d: VGG_Linear_Decoder(1, d),
}

# Input image size per model. Presumably passed to an image resize
# (width/height are equal here, so the order is not observable) — confirm
# the intended (W, H) vs (H, W) convention at the call site.
image_sizes = {
    'eig': (227, 227),
    'eig_classifier': (227, 227),
    'vgg': (224, 224),
}

# Output HDF5 filename per model.
filenames_d = {
    'eig': 'eig.hdf5',
    'eig_classifier': 'eig_classifier.hdf5',
    'vgg': 'vgg.hdf5',
}