Ejemplo n.º 1
0
def MakeFeatureModel(modelName='vgg16'):
    """Return a pretrained VGG network used as a perceptual-loss feature extractor.

    Args:
        modelName: Backbone to build; must be exactly 'vgg16' or 'vgg19'
            (the comparison is case-sensitive).

    Returns:
        The result of ``VGG16.vgg16(...)`` or ``VGG19.vgg19(...)`` called with
        ``pretrained=True`` (ImageNet weights, per the original comments) and
        ``feat_ex=True`` so the network outputs feature maps.

    Raises:
        ValueError: If ``modelName`` is not one of the supported backbones.
    """
    supported = ['vgg16', 'vgg19']
    if modelName not in supported:
        raise ValueError(
            'Invalid model name; available models: {}'.format(supported))

    # Both constructors accept the same flags; only the factory differs.
    if modelName == 'vgg16':
        builder = VGG16.vgg16
    else:
        builder = VGG19.vgg19
    return builder(pretrained=True, feat_ex=True)
Ejemplo n.º 2
0
        print('Best score: ', round(best_score, 5), ' @ threshold =', best)
    return best


if __name__ == '__main__':

    img_rows, img_cols = 224, 224  # Resolution of inputs
    channels = 3
    num_classes = 17
    test_weights_path = 'models/vgg19_weights.full.h5'

    # Build the VGG19 classifier and load its trained weights.
    model = vgg19(img_rows=img_rows,
                  img_cols=img_cols,
                  channels=channels,
                  num_classes=num_classes)
    model.load_weights(test_weights_path)

    # Load every validation image and build a multi-hot target vector from
    # its space-separated tag string.
    X_valid = []
    Y_valid = []
    for f, tags in tqdm(df_valid.values, miniters=1000):
        img_path = 'data/valid/{}.jpg'.format(f)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # rather than raising; fail here with the offending path.
            raise FileNotFoundError('Could not read image: {}'.format(img_path))
        # cv2.resize expects dsize as (width, height).
        img = cv2.resize(img, (img_cols, img_rows),
                         interpolation=cv2.INTER_AREA)
        targets = np.zeros(num_classes)
        # split() with no argument ignores repeated/trailing whitespace, so no
        # empty tag is ever looked up in label_map.
        for t in tags.split():
            targets[label_map[t]] = 1
        X_valid.append(img)
        Y_valid.append(targets)
Ejemplo n.º 3
0
    model.load_state_dict(torch.load(path, device))
    model.eval()
    return model


# Use the GPU when CUDA is available and fall back to the CPU otherwise.
# Valid device types here are 'cuda' and 'cpu' ('gpu' is not accepted by torch).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained models keyed by the name that predict() uses to select one.
# Each entry is loaded via load_return_model(path, model) after the freshly
# constructed network has been moved to `device`.
models = {
    'sha': load_return_model(sha_path,
                             vgg16dres(map_location=device).to(device)),
    'ucf': load_return_model(ucf_path,
                             vgg16dres1(map_location=device).to(device)),
    'shb': load_return_model(shb_path,
                             vgg16dres1(map_location=device).to(device)),
    'dm_shb': load_return_model(dm_shb_path,
                                vgg19().to(device)),
    'dm_sha': load_return_model(dm_sha_path,
                                vgg19().to(device)),
}


def predict(inp, model):
    inp = Image.fromarray(inp.astype('uint8'), 'RGB')
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    inp = inp.to(device)
    with torch.set_grad_enabled(False):
        outputs, _ = models[model](inp)
    count = torch.sum(outputs).item()
    vis_img = outputs[0, 0].cpu().numpy()
    # normalize density map values from 0 to 1, then map it to 0-255.
    vis_img = (vis_img - vis_img.min()) / (vis_img.max() - vis_img.min() +
Ejemplo n.º 4
0
def _reference_style_config(train_style):
    """Return (dataset_filter, feature_file, img_list_file) for a training style."""
    if train_style == 'cufs':
        return (['CUHK_student', 'AR', 'XM2VTS'],
                './data/cufs_feature_dataset.pth',
                './data/cufs_reference_img_list.txt')
    if train_style == 'cufsf':
        return (['CUFSF'],
                './data/cufsf_feature_dataset.pth',
                './data/cufsf_reference_img_list.txt')
    # A real exception (not assert) so the check survives `python -O`.
    raise ValueError('Train style {} not supported.'.format(train_style))


def _append_log(log_dir, msg):
    """Append one line to <log_dir>/log.txt, closing the file even on error."""
    with open(os.path.join(log_dir, 'log.txt'), 'a+') as log_file:
        log_file.write(msg + '\n')


def train(args):
    """Train the sketch generator ``Gnet`` against the discriminator ``Dnet``.

    Each iteration first updates Dnet with a least-squares (MSE) GAN loss on
    real reference sketches vs. detached generated sketches, then updates Gnet
    with a weighted sum of the ``loss.feature_mrf_loss_func`` style loss
    (computed with a VGG19 feature network), the adversarial loss and a
    total-variation loss. Generator and discriminator weights are written to
    ``args.save_weight_path`` after every epoch.

    Args:
        args: Parsed options. Attributes read here: seed, train_data,
            batch_size, Gnorm, Dnorm, vgg19_weight, gpus (comma-separated GPU
            ids), resume, save_weight_path, epochs, glr, dlr, train_style
            ('cufs' or 'cufsf'), flayers (mask over VGG layers r11..r51),
            topk, meanshift, weight (style / adversarial / TV loss weights).
            ``args.epochs`` is overwritten in place: the incoming value is
            rescaled so that roughly ``args.epochs * 1000`` images are seen,
            with a minimum of 4 epochs.

    Raises:
        ValueError: If ``args.train_style`` is not 'cufs' or 'cufsf'.
        FileNotFoundError: If ``args.resume`` is set but no ``*-*.pth``
            checkpoint exists in ``args.save_weight_path``.
    """
    # NOTE(review): cudnn.benchmark=True picks algorithms by timing, which can
    # undermine the run-to-run reproducibility deterministic=True aims for;
    # PyTorch's reproducibility notes recommend benchmark=False for that.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # -------------------- Load data ----------------------------------
    transform = transforms.Compose([
        Rescale((224, 224)),
        ColorJitter(0.5, 0.5, 0.5, 0.3, 0.5),
        ToTensor(),
    ])
    dataset = FaceDataset(args.train_data, True, transform=transform)
    data_loader = DataLoader(dataset,
                             shuffle=True,
                             batch_size=args.batch_size,
                             drop_last=True,
                             num_workers=4)

    # ----------------- Define networks ---------------------------------
    Gnet = SketchNet(in_channels=3, out_channels=1, norm_type=args.Gnorm)
    Dnet = DNet(norm_type=args.Dnorm)
    # NOTE(review): vgg19_model is not moved to `device` here; this assumes
    # vgg19() already places it on the right device -- confirm.
    vgg19_model = vgg19(args.vgg19_weight)

    # Skip empty entries so args.gpus == '' means "no GPU ids" rather than
    # crashing on int('').
    gpu_ids = [int(x) for x in args.gpus.split(',') if x.strip()]
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    Gnet.to(device)
    Dnet.to(device)
    if gpu_ids:
        Gnet = nn.DataParallel(Gnet, device_ids=gpu_ids)
        Dnet = nn.DataParallel(Dnet, device_ids=gpu_ids)
        vgg19_model = nn.DataParallel(vgg19_model, device_ids=gpu_ids)

    Gnet.train()
    Dnet.train()

    if args.resume:
        weights = glob(os.path.join(args.save_weight_path, '*-*.pth'))
        if not weights:
            raise FileNotFoundError(
                'No checkpoint matching *-*.pth found in {}'.format(
                    args.save_weight_path))
        # Checkpoints are saved below as "epochs-XXX-G.pth"/"epochs-XXX-D.pth"
        # with a zero-padded epoch, so the lexicographically last file is the
        # newest epoch; dropping the final 5 chars ("G.pth"/"D.pth") leaves
        # the shared prefix.
        weight_path = sorted(weights)[-1][:-5]
        Gnet.load_state_dict(torch.load(weight_path + 'G.pth'))
        Dnet.load_state_dict(torch.load(weight_path + 'D.pth'))

    # ---------------- set optimizer and learning rate ---------------------
    # Convert the requested budget (in thousands of images) into epochs.
    args.epochs = np.ceil(args.epochs * 1000 / len(dataset))
    args.epochs = max(int(args.epochs), 4)
    # LR milestones are in EPOCHS: decay by 10x at 1/4 and 1/2 of training.
    ms = [int(1. / 4 * args.epochs), int(2. / 4 * args.epochs)]

    optim_G = torch.optim.AdamW(Gnet.parameters(), args.glr)
    optim_D = torch.optim.AdamW(Dnet.parameters(), args.dlr)
    scheduler_G = MultiStepLR(optim_G, milestones=ms, gamma=0.1)
    scheduler_D = MultiStepLR(optim_D, milestones=ms, gamma=0.1)
    mse_crit = nn.MSELoss()

    # ---------------------- Define reference styles and feature loss layers ----------
    ref_style_dataset, ref_feature, ref_img_list = _reference_style_config(
        args.train_style)

    vgg_feature_layers = ['r11', 'r21', 'r31', 'r41', 'r51']
    feature_loss_layers = list(
        itertools.compress(vgg_feature_layers, args.flayers))
    utils.print_network(Gnet)
    utils.print_network(Dnet)
    print("Initialized")
    log = logger.Logger(args.save_weight_path)

    for e in range(args.epochs):
        sample_count = 0
        for batch_idx, batch_data in enumerate(data_loader):
            # ---------------- Load data -------------------
            start = time()
            train_img, train_img_org = [
                utils.tensorToVar(x) for x in batch_data
            ]
            topk_sketch_img, topk_photo_img = search_dataset.find_photo_sketch_batch(
                train_img_org,
                ref_feature,
                ref_img_list,
                vgg19_model,
                dataset_filter=ref_style_dataset,
                topk=args.topk)
            random_real_sketch = search_dataset.get_real_sketch_batch(
                train_img.size(0),
                ref_img_list,
                dataset_filter=ref_style_dataset)
            end = time()
            data_time = end - start
            sample_count += train_img.size(0)

            # ---------------- Generator forward -------------------
            start = time()
            fake_sketch = Gnet(train_img)

            # ---------------- Discriminator update -------------------
            # The fake is detached so loss_D does not backprop into Gnet and,
            # crucially, so the in-place optim_D.step() below cannot invalidate
            # a graph that loss_G still needs (PyTorch >= 1.5 raises on that).
            fake_score_d = Dnet(fake_sketch.detach())
            real_score = Dnet(random_real_sketch)
            real_label = torch.ones_like(fake_score_d)
            fake_label = torch.zeros_like(fake_score_d)
            loss_D = 0.5 * mse_crit(fake_score_d, fake_label) + 0.5 * mse_crit(
                real_score, real_label)

            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

            # ---------------- Generator update -------------------
            # Re-score the fake with the just-updated discriminator so this
            # graph contains no tensor modified in place by optim_D.step().
            fake_score = Dnet(fake_sketch)

            train_img_org_vgg = img_process.subtract_mean_batch(
                train_img_org, 'face')
            topk_sketch_img_vgg = img_process.subtract_mean_batch(
                topk_sketch_img, 'sketch')
            topk_photo_img_vgg = img_process.subtract_mean_batch(
                topk_photo_img, 'face')
            # The generator emits a single channel; expand it to match the
            # 3-channel photo before mean subtraction for the VGG loss.
            fake_sketch_vgg = img_process.subtract_mean_batch(
                fake_sketch.expand_as(train_img_org), 'sketch', args.meanshift)

            style_loss = loss.feature_mrf_loss_func(
                fake_sketch_vgg,
                topk_sketch_img_vgg,
                vgg19_model,
                feature_loss_layers, [train_img_org_vgg, topk_photo_img_vgg],
                topk=args.topk)

            # weight[0]: style, weight[1]: adversarial, weight[2]: TV.
            adv_loss = mse_crit(fake_score, real_label) * args.weight[1]
            tv_loss = loss.total_variation(fake_sketch) * args.weight[2]
            loss_G = style_loss * args.weight[0] + adv_loss + tv_loss

            optim_G.zero_grad()
            loss_G.backward()
            optim_G.step()

            end = time()
            train_time = end - start

            # ----------------- Print result and log the output -------------------
            log.iterLogUpdate(loss_G.item())
            if batch_idx % 100 == 0:
                log.draw_loss_curve()

            msg = ("{:%Y-%m-%d %H:%M:%S}\tEpoch [{:03d}/{:03d}]\t"
                   "Batch [{:03d}/{:03d}]\tData: {:.2f}  Train: {:.2f}\t"
                   "Loss: G-{:.4f}, Adv-{:.4f}, tv-{:.4f}, D-{:.4f}").format(
                datetime.now(),
                e, args.epochs, sample_count, len(dataset),
                data_time, train_time,
                loss_G.item(), float(adv_loss), float(tv_loss), float(loss_D))
            print(msg)
            _append_log(args.save_weight_path, msg)

        # Milestones in `ms` are in epochs, so the schedulers advance once
        # per epoch (stepping them per batch would decay the LR almost
        # immediately).
        scheduler_G.step()
        scheduler_D.step()

        save_weight_name = "epochs-{:03d}-".format(e)
        G_cpu_model = copy.deepcopy(Gnet).cpu()
        D_cpu_model = copy.deepcopy(Dnet).cpu()
        torch.save(
            G_cpu_model.state_dict(),
            os.path.join(args.save_weight_path, save_weight_name + 'G.pth'))
        torch.save(
            D_cpu_model.state_dict(),
            os.path.join(args.save_weight_path, save_weight_name + 'D.pth'))
Ejemplo n.º 5
0
 vgg_last_weights_path = 'models/vgg19_weights.last.h5'
 vgg_test_weights_path = 'models/vgg19_weights.test.h5'
 vgg_full_weights_path = 'models/vgg19_weights.full.h5'
 # Define densenet weights paths
 densenet_best_weights_path = 'models/densenet169_weights.best.h5'
 densenet_last_weights_path = 'models/densenet169_weights.last.h5'
 densenet_test_weights_path = 'models/densenet169_weights.test.h5'
 densenet_full_weights_path = 'models/densenet169_weights.full.h5'
 # Define inception weights paths
 inception_best_weights_path = 'models/inceptionv4_weights.best.h5'
 inception_last_weights_path = 'models/inceptionv4_weights.last.h5'
 inception_test_weights_path = 'models/inceptionv4_weights.test.h5'
 inception_full_weights_path = 'models/inceptionv4_weights.full.h5'
 # Load vgg models
 vgg_best = vgg19(img_rows=img_rows,
                  img_cols=img_cols,
                  channels=channels,
                  num_classes=num_classes)
 vgg_best.load_weights(vgg_best_weights_path)
 vgg_last = vgg19(img_rows=img_rows,
                  img_cols=img_cols,
                  channels=channels,
                  num_classes=num_classes)
 vgg_last.load_weights(vgg_last_weights_path)
 vgg_test = vgg19(img_rows=img_rows,
                  img_cols=img_cols,
                  channels=channels,
                  num_classes=num_classes)
 vgg_test.load_weights(vgg_test_weights_path)
 vgg_full = vgg19(img_rows=img_rows,
                  img_cols=img_cols,
                  channels=channels,