Пример #1
0
def main():
    """Train and evaluate the SEM-PCYC zero-shot sketch-based image
    retrieval model.

    Workflow: parse CLI options, resolve dataset/checkpoint/log paths from
    the config file, load the zero-shot splits and semantic embeddings,
    build paired train / unpaired test data loaders, train with early
    stopping (unless ``args.test``), then reload the best checkpoint,
    evaluate on the test set and save qualitative retrieval results.

    Raises:
        Exception: if ``args.dataset`` is neither 'Sketchy' nor 'TU-Berlin'
            (after stripping a variant suffix such as '_extended').
    """

    # Parse options
    args = Options().parse()
    print('Parameters:\t' + str(args))

    # Flag sanity checks: these options only make sense for Sketchy.
    if args.filter_sketch:
        assert args.dataset == 'Sketchy'
    if args.split_eccv_2018:
        assert args.dataset == 'Sketchy_extended' or args.dataset == 'Sketchy'
    if args.gzs_sbir:
        # Generalized zero-shot SBIR is evaluation-only: force test mode.
        args.test = True

    # Read the config file and
    config = utils.read_config()
    path_dataset = config['path_dataset']
    path_aux = config['path_aux']

    # modify the log and check point paths
    # A name like 'Sketchy_extended' is split into the base dataset name
    # and a variant suffix (ds_var) that selects the photo directory below.
    ds_var = None
    if '_' in args.dataset:
        token = args.dataset.split('_')
        args.dataset = token[0]
        ds_var = token[1]

    # Extra path component encoding the evaluation protocol in use.
    str_aux = ''
    if args.split_eccv_2018:
        str_aux = 'split_eccv_2018'
    if args.gzs_sbir:
        str_aux = os.path.join(str_aux, 'generalized')
    # Sort so the same set of semantic models always maps to one path.
    args.semantic_models = sorted(args.semantic_models)
    model_name = '+'.join(args.semantic_models)
    root_path = os.path.join(path_dataset, args.dataset)
    path_sketch_model = os.path.join(path_aux, 'CheckPoints', args.dataset,
                                     'sketch')
    path_image_model = os.path.join(path_aux, 'CheckPoints', args.dataset,
                                    'image')
    path_cp = os.path.join(path_aux, 'CheckPoints', args.dataset, str_aux,
                           model_name, str(args.dim_out))
    path_log = os.path.join(path_aux, 'LogFiles', args.dataset, str_aux,
                            model_name, str(args.dim_out))
    path_results = os.path.join(path_aux, 'Results', args.dataset, str_aux,
                                model_name, str(args.dim_out))
    # Collect the semantic embedding files; sem_dim accumulates the
    # dimensionality of each embedding (taken from one sample value of the
    # class->vector dict stored in each .npy file).
    files_semantic_labels = []
    sem_dim = 0
    for f in args.semantic_models:
        fi = os.path.join(path_aux, 'Semantic', args.dataset, f + '.npy')
        files_semantic_labels.append(fi)
        sem_dim += list(np.load(fi,
                                allow_pickle=True).item().values())[0].shape[0]

    print('Checkpoint path: {}'.format(path_cp))
    print('Logger path: {}'.format(path_log))
    print('Result path: {}'.format(path_results))

    # Parameters for transforming the images
    transform_image = transforms.Compose(
        [transforms.Resize((args.im_sz, args.im_sz)),
         transforms.ToTensor()])
    transform_sketch = transforms.Compose(
        [transforms.Resize((args.sk_sz, args.sk_sz)),
         transforms.ToTensor()])

    # Load the dataset
    print('Loading data...', end='')

    if args.dataset == 'Sketchy':
        if ds_var == 'extended':
            photo_dir = 'extended_photo'  # photo or extended_photo
            photo_sd = ''
        else:
            photo_dir = 'photo'
            photo_sd = 'tx_000000000000'
        sketch_dir = 'sketch'
        sketch_sd = 'tx_000000000000'
        splits = utils.load_files_sketchy_zeroshot(
            root_path=root_path,
            split_eccv_2018=args.split_eccv_2018,
            photo_dir=photo_dir,
            sketch_dir=sketch_dir,
            photo_sd=photo_sd,
            sketch_sd=sketch_sd)
    elif args.dataset == 'TU-Berlin':
        photo_dir = 'images'
        sketch_dir = 'sketches'
        photo_sd = ''
        sketch_sd = ''
        splits = utils.load_files_tuberlin_zeroshot(root_path=root_path,
                                                    photo_dir=photo_dir,
                                                    sketch_dir=sketch_dir,
                                                    photo_sd=photo_sd,
                                                    sketch_sd=sketch_sd)
    else:
        raise Exception('Wrong dataset.')

    # Combine the valid and test set into test set

    # Generalized ZS-SBIR: mix a fraction of (seen) training samples into
    # the test set so retrieval is evaluated over seen + unseen classes.
    if args.gzs_sbir:
        perc = 0.2
        _, idx_sk = np.unique(splits['tr_fls_sk'], return_index=True)
        tr_fls_sk_ = splits['tr_fls_sk'][idx_sk]
        tr_clss_sk_ = splits['tr_clss_sk'][idx_sk]
        _, idx_im = np.unique(splits['tr_fls_im'], return_index=True)
        tr_fls_im_ = splits['tr_fls_im'][idx_im]
        tr_clss_im_ = splits['tr_clss_im'][idx_im]
        if args.dataset == 'Sketchy' and args.filter_sketch:
            # Keep a single sketch per photo (file names look like
            # '<photo>-<k>'; dedupe on the photo part).
            _, idx_sk = np.unique([f.split('-')[0] for f in tr_fls_sk_],
                                  return_index=True)
            tr_fls_sk_ = tr_fls_sk_[idx_sk]
            tr_clss_sk_ = tr_clss_sk_[idx_sk]
        # Sample perc * |test| items from the (deduped) training pool.
        # NOTE(review): not seeded here, so the sampled subset differs
        # between runs unless a global seed is set elsewhere.
        idx_sk = np.sort(
            np.random.choice(tr_fls_sk_.shape[0],
                             int(perc * splits['te_fls_sk'].shape[0]),
                             replace=False))
        idx_im = np.sort(
            np.random.choice(tr_fls_im_.shape[0],
                             int(perc * splits['te_fls_im'].shape[0]),
                             replace=False))
        splits['te_fls_sk'] = np.concatenate(
            (tr_fls_sk_[idx_sk], splits['te_fls_sk']), axis=0)
        splits['te_clss_sk'] = np.concatenate(
            (tr_clss_sk_[idx_sk], splits['te_clss_sk']), axis=0)
        splits['te_fls_im'] = np.concatenate(
            (tr_fls_im_[idx_im], splits['te_fls_im']), axis=0)
        splits['te_clss_im'] = np.concatenate(
            (tr_clss_im_[idx_im], splits['te_clss_im']), axis=0)

    # class dictionary
    dict_clss = utils.create_dict_texts(splits['tr_clss_im'])

    data_train = DataGeneratorPaired(args.dataset,
                                     root_path,
                                     photo_dir,
                                     sketch_dir,
                                     photo_sd,
                                     sketch_sd,
                                     splits['tr_fls_sk'],
                                     splits['tr_fls_im'],
                                     splits['tr_clss_im'],
                                     transforms_sketch=transform_sketch,
                                     transforms_image=transform_image)

    data_test_sketch = DataGeneratorSketch(args.dataset,
                                           root_path,
                                           sketch_dir,
                                           sketch_sd,
                                           splits['te_fls_sk'],
                                           splits['te_clss_sk'],
                                           transforms=transform_sketch)
    data_test_image = DataGeneratorImage(args.dataset,
                                         root_path,
                                         photo_dir,
                                         photo_sd,
                                         splits['te_fls_im'],
                                         splits['te_clss_im'],
                                         transforms=transform_image)
    print('Done')

    # Class-balanced sampling; one "epoch" is epoch_size * batch_size draws.
    train_sampler = WeightedRandomSampler(data_train.get_weights(),
                                          num_samples=args.epoch_size *
                                          args.batch_size,
                                          replacement=True)

    # PyTorch train loader
    train_loader = DataLoader(dataset=data_train,
                              batch_size=args.batch_size,
                              sampler=train_sampler,
                              num_workers=args.num_workers,
                              pin_memory=True)
    # PyTorch test loader for sketch
    test_loader_sketch = DataLoader(dataset=data_test_sketch,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.num_workers,
                                    pin_memory=True)
    # PyTorch test loader for image
    test_loader_image = DataLoader(dataset=data_test_image,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    # Model parameters
    params_model = dict()
    # Paths to pre-trained sketch and image models
    params_model['path_sketch_model'] = path_sketch_model
    params_model['path_image_model'] = path_image_model
    # Dimensions
    params_model['dim_out'] = args.dim_out
    params_model['sem_dim'] = sem_dim
    # Number of classes
    params_model['num_clss'] = len(dict_clss)
    # Weight (on losses) parameters
    params_model['lambda_se'] = args.lambda_se
    params_model['lambda_im'] = args.lambda_im
    params_model['lambda_sk'] = args.lambda_sk
    params_model['lambda_gen_cyc'] = args.lambda_gen_cyc
    params_model['lambda_gen_adv'] = args.lambda_gen_adv
    params_model['lambda_gen_cls'] = args.lambda_gen_cls
    params_model['lambda_gen_reg'] = args.lambda_gen_reg
    params_model['lambda_disc_se'] = args.lambda_disc_se
    params_model['lambda_disc_sk'] = args.lambda_disc_sk
    params_model['lambda_disc_im'] = args.lambda_disc_im
    params_model['lambda_regular'] = args.lambda_regular
    # Optimizers' parameters
    params_model['lr'] = args.lr
    params_model['momentum'] = args.momentum
    params_model['milestones'] = args.milestones
    params_model['gamma'] = args.gamma
    # Files with semantic labels
    params_model['files_semantic_labels'] = files_semantic_labels
    # Class dictionary
    params_model['dict_clss'] = dict_clss

    # Model
    sem_pcyc_model = SEM_PCYC(params_model)

    cudnn.benchmark = True

    # Logger
    print('Setting logger...', end='')
    logger = Logger(path_log, force=True)
    print('Done')

    # Check cuda
    print('Checking cuda...', end='')
    # Check if CUDA is enabled
    # BUGFIX: the original used bitwise '&', which binds tighter than '>',
    # so the condition reduced to `args.ngpu > 0` and the CUDA availability
    # check was silently skipped. Use logical 'and'.
    if args.ngpu > 0 and torch.cuda.is_available():
        print('*Cuda exists*...', end='')
        sem_pcyc_model = sem_pcyc_model.cuda()
    print('Done')

    best_map = 0
    early_stop_counter = 0

    # Epoch for loop
    if not args.test:
        print('***Train***')
        for epoch in range(args.epochs):

            # NOTE(review): schedulers are stepped before training each
            # epoch — pre-1.1 PyTorch ordering; kept as-is on purpose.
            sem_pcyc_model.scheduler_gen.step()
            sem_pcyc_model.scheduler_disc.step()
            sem_pcyc_model.scheduler_ae.step()

            # train on training set
            losses = train(train_loader, sem_pcyc_model, epoch, args)

            # evaluate on validation set, map_ since map is already there
            print('***Validation***')
            valid_data = validate(test_loader_sketch, test_loader_image,
                                  sem_pcyc_model, epoch, args)
            map_ = np.mean(valid_data['aps@all'])

            print(
                'mAP@all on validation set after {0} epochs: {1:.4f} (real), {2:.4f} (binary)'
                .format(epoch + 1, map_, np.mean(valid_data['aps@all_bin'])))

            del valid_data

            # Checkpoint on improvement; otherwise count towards early stop.
            if map_ > best_map:
                best_map = map_
                early_stop_counter = 0
                utils.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': sem_pcyc_model.state_dict(),
                        'best_map': best_map
                    },
                    directory=path_cp)
            else:
                if args.early_stop == early_stop_counter:
                    break
                early_stop_counter += 1

            # Logger step
            logger.add_scalar('semantic autoencoder loss',
                              losses['aut_enc'].avg)
            logger.add_scalar('generator adversarial loss',
                              losses['gen_adv'].avg)
            logger.add_scalar('generator cycle consistency loss',
                              losses['gen_cyc'].avg)
            logger.add_scalar('generator classification loss',
                              losses['gen_cls'].avg)
            logger.add_scalar('generator regression loss',
                              losses['gen_reg'].avg)
            logger.add_scalar('generator loss', losses['gen'].avg)
            logger.add_scalar('semantic discriminator loss',
                              losses['disc_se'].avg)
            logger.add_scalar('sketch discriminator loss',
                              losses['disc_sk'].avg)
            logger.add_scalar('image discriminator loss',
                              losses['disc_im'].avg)
            logger.add_scalar('discriminator loss', losses['disc'].avg)
            logger.add_scalar('mean average precision', map_)
            logger.step()

    # load the best model yet
    best_model_file = os.path.join(path_cp, 'model_best.pth')
    if os.path.isfile(best_model_file):
        print("Loading best model from '{}'".format(best_model_file))
        checkpoint = torch.load(best_model_file)
        epoch = checkpoint['epoch']
        best_map = checkpoint['best_map']
        sem_pcyc_model.load_state_dict(checkpoint['state_dict'])
        print("Loaded best model '{0}' (epoch {1}; mAP@all {2:.4f})".format(
            best_model_file, epoch, best_map))
        print('***Test***')
        valid_data = validate(test_loader_sketch, test_loader_image,
                              sem_pcyc_model, epoch, args)
        print(
            'Results on test set: mAP@all = {1:.4f}, Prec@100 = {0:.4f}, mAP@200 = {3:.4f}, Prec@200 = {2:.4f}, '
            'Time = {4:.6f} || mAP@all (binary) = {6:.4f}, Prec@100 (binary) = {5:.4f}, mAP@200 (binary) = {8:.4f}, '
            'Prec@200 (binary) = {7:.4f}, Time (binary) = {9:.6f} '.format(
                valid_data['prec@100'], np.mean(valid_data['aps@all']),
                valid_data['prec@200'], np.mean(valid_data['aps@200']),
                valid_data['time_euc'], valid_data['prec@100_bin'],
                np.mean(valid_data['aps@all_bin']), valid_data['prec@200_bin'],
                np.mean(valid_data['aps@200_bin']), valid_data['time_bin']))
        print('Saving qualitative results...', end='')
        path_qualitative_results = os.path.join(path_results,
                                                'qualitative_results')
        utils.save_qualitative_results(root_path,
                                       sketch_dir,
                                       sketch_sd,
                                       photo_dir,
                                       photo_sd,
                                       splits['te_fls_sk'],
                                       splits['te_fls_im'],
                                       path_qualitative_results,
                                       valid_data['aps@all'],
                                       valid_data['sim_euc'],
                                       valid_data['str_sim'],
                                       save_image=args.save_image_results,
                                       nq=args.number_qualit_results,
                                       best=args.save_best_results)
        print('Done')
    else:
        print(
            "No best model found at '{}'. Exiting...".format(best_model_file))
        exit()
Пример #2
0
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn.functional import interpolate

from models.build_model import build_netG
from data.customdataset import CustomDataset
from models.losses import gdloss
from options import Options
"""Pre-Training Generator"""

opt = Options().parse()
opt.phase = 'train'
opt.nEpochs = 10
opt.save_fre = 10
opt.dataset = 'iseg'
print(opt)

data_set = CustomDataset(opt)
print('Image numbers:', data_set.img_size)

dataloader = torch.utils.data.DataLoader(data_set,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=int(opt.workers))

generator = build_netG(opt)

if opt.gpu_ids != '-1':
    num_gpus = len(opt.gpu_ids.split(','))
else:
Пример #3
0
# Loaders for the two evaluation cohorts — `load_case` is a project helper
# defined elsewhere; its `normal` flag selects which cohort is returned.
normal_dataloader=load_case(normal=True)
abnormal_dataloader=load_case(normal=False)
# Build the options object by hand (no CLI parsing) and fill in the
# hyper-parameters the BeatGAN constructor reads.
opt = Options()
opt.nc=1
opt.nz=50
opt.isize=320
opt.ndf=32
opt.ngf=32
opt.batchsize=64
opt.ngpu=1
opt.istest=True  # NOTE(review): presumably puts the model in test mode — confirm in BeatGAN
opt.lr=0.001
opt.beta1=0.5
# The following are unused for this evaluation run, hence left as None.
opt.niter=None
opt.dataset=None
opt.model = None
opt.outf=None



# Instantiate the model (`device` must be defined earlier in the script)
# and restore pre-trained generator/discriminator weights onto the CPU.
model=BeatGAN(opt,None,device)
model.G.load_state_dict(torch.load('model/beatgan_folder_0_G.pkl',map_location='cpu'))
model.D.load_state_dict(torch.load('model/beatgan_folder_0_D.pkl',map_location='cpu'))


# Switch both networks to inference mode (fixes dropout / batch-norm).
model.G.eval()
model.D.eval()
with torch.no_grad():

    abnormal_input=[]