Beispiel #1
0
def load_from_dir(root_dir, model_index=None, G_weights=None, verbose=False):
    """Restore a (deformator, generator, shift predictor) triple from a run directory.

    ``root_dir`` must contain the ``args.json`` of the run and a ``models``
    sub-directory with ``deformator_<idx>.pt`` / ``shift_predictor_<idx>.pt``
    checkpoints. If ``root_dir/directions.json`` exists it is attached to the
    deformator as the ``directions_dict`` attribute (key order preserved).

    Args:
        root_dir: run output directory.
        model_index: checkpoint index to load; when None, the largest index
            found among ``models/deformator*`` files is used.
        G_weights: generator weights path; when None, ``args['gan_weights']``
            is used. If that is missing or not an existing file, the default
            ``WEIGHTS[args['gan_type']]`` is used instead.
        verbose: print which checkpoint index / weights were selected.

    Returns:
        ``(deformator, G, shift_predictor)``, each switched to eval mode and
        moved to CUDA.

    Raises:
        ValueError: if ``model_index`` is None and no deformator checkpoint
            exists, or if ``args['shift_predictor']`` is not a known predictor.
    """
    with open(os.path.join(root_dir, 'args.json')) as args_file:
        args = json.load(args_file)

    models_dir = os.path.join(root_dir, 'models')
    if model_index is None:
        # 'deformator_<idx>.pt' -> <idx>
        indices = [
            int(name.split('.')[0].split('_')[-1])
            for name in os.listdir(models_dir)
            if name.startswith('deformator')
        ]
        if not indices:
            raise ValueError(
                'no deformator checkpoints found in {}'.format(models_dir))
        model_index = max(indices)

        if verbose:
            print('using max index {}'.format(model_index))

    if G_weights is None:
        G_weights = args.get('gan_weights')
    if G_weights is None or not os.path.isfile(G_weights):
        if verbose:
            print('Using default local G weights')
        G_weights = WEIGHTS[args['gan_type']]

    gan_type = args['gan_type']
    if gan_type == 'BigGAN':
        G = make_big_gan(G_weights, args['target_class']).eval()
    elif gan_type in ('ProgGAN', 'PGGAN'):
        G = make_proggan(G_weights)
    else:
        G = make_external(G_weights)

    deformator = LatentDeformator(
        G.dim_z, type=DEFORMATOR_TYPE_DICT[args['deformator']])

    # A missing 'shift_predictor' entry falls back to the ResNet predictor.
    predictor_name = args.get('shift_predictor', 'ResNet')
    if predictor_name == 'ResNet':
        shift_predictor = ResNetShiftPredictor(G.dim_z)
    elif predictor_name == 'LeNet':
        # NOTE(review): second argument is presumably the input channel
        # count (1 for SN_MNIST, 3 otherwise) — confirm in LeNetShiftPredictor.
        shift_predictor = LeNetShiftPredictor(
            G.dim_z, 1 if gan_type == 'SN_MNIST' else 3)
    else:
        raise ValueError(
            'unknown shift predictor: {}'.format(predictor_name))

    deformator_model_path = os.path.join(
        models_dir, 'deformator_{}.pt'.format(model_index))
    shift_model_path = os.path.join(
        models_dir, 'shift_predictor_{}.pt'.format(model_index))
    # Load checkpoints onto the CPU first; all modules are moved to CUDA on
    # return, so this also works for checkpoints saved from another device.
    if os.path.isfile(deformator_model_path):
        deformator.load_state_dict(
            torch.load(deformator_model_path, map_location=torch.device('cpu')))
    if os.path.isfile(shift_model_path):
        shift_predictor.load_state_dict(
            torch.load(shift_model_path, map_location=torch.device('cpu')))

    # try to load dims annotation
    directions_json = os.path.join(root_dir, 'directions.json')
    if os.path.isfile(directions_json):
        with open(directions_json, 'r') as f:
            directions_dict = json.load(f, object_pairs_hook=OrderedDict)
        setattr(deformator, 'directions_dict', directions_dict)

    return deformator.eval().cuda(), G.eval().cuda(), shift_predictor.eval(
    ).cuda()
Beispiel #2
0
def main():
    """Train a UNet segmentation model from BigGAN samples and a pretrained deformator.

    Parses the command line (including one optional flag per field of
    ``SegmentationTrainParams``), seeds torch, selects the CUDA device, saves
    the run parameters via ``save_command_run_params`` and runs
    ``train_segmentation`` with the optional validation image/mask dirs.
    """

    def boolean(value):
        """argparse ``type`` for bool flags: true/false, 1/0, yes/no, on/off."""
        lowered = value.strip().lower()
        if lowered in ('true', '1', 'yes', 'y', 'on'):
            return True
        if lowered in ('false', '0', 'no', 'n', 'off', ''):
            return False
        # argparse reports a ValueError as "invalid boolean value: ..."
        raise ValueError(value)

    parser = argparse.ArgumentParser(
        description='GAN-based unsupervised segmentation train')
    parser.add_argument('--out', type=str, required=True)
    parser.add_argument('--gan_weights', type=str, default=WEIGHTS['BigGAN'])
    parser.add_argument('--deformator_weights', type=str, required=True)
    parser.add_argument('--deformator_type',
                        type=str,
                        choices=DEFORMATOR_TYPE_DICT.keys(),
                        required=True)
    parser.add_argument('--background_dim', type=int, required=True)
    parser.add_argument('--classes',
                        type=int,
                        nargs='*',
                        default=list(range(1000)))
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=2)

    parser.add_argument('--val_images_dir', type=str, default=None)
    parser.add_argument('--val_masks_dir', type=str, default=None)

    for key, val in SegmentationTrainParams().__dict__.items():
        if key == 'synthezing':
            # this field is always taken as a raw string
            val_type = str
        elif isinstance(val, bool):
            # type=bool would turn every non-empty string, even 'False', into True
            val_type = boolean
        else:
            val_type = type(val)
        parser.add_argument('--{}'.format(key), type=val_type, default=None)

    args = parser.parse_args()
    torch.random.manual_seed(args.seed)

    torch.cuda.set_device(args.device)
    # save run params
    save_command_run_params(args)

    # `--classes` given with no values yields an empty list
    if not args.classes:
        print('using all ImageNet')
        args.classes = list(range(1000))
    G = make_big_gan(args.gan_weights, args.classes).eval().cuda()
    deformator = LatentDeformator(
        G.dim_z, type=DEFORMATOR_TYPE_DICT[args.deformator_type])
    deformator.load_state_dict(
        torch.load(args.deformator_weights, map_location=torch.device('cpu')))
    deformator.eval().cuda()

    model = UNet().train().cuda()
    train_params = SegmentationTrainParams(**args.__dict__)
    print(f'run train with p: {train_params.__dict__}')

    train_segmentation(G,
                       deformator,
                       model,
                       train_params,
                       args.background_dim,
                       args.out,
                       val_dirs=[args.val_images_dir, args.val_masks_dir])
Beispiel #3
0
def load_generator(args, G_weights):
    """Construct the generator selected by ``args['gan_type']``.

    * ``'BigGAN'``: built for ``args['target_class']`` and put in eval mode.
    * ``'ProgGAN'``: built from the weights only.
    * any type containing ``'StyleGAN2'``: built with
      ``args['gan_resolution']`` and ``args['w_shift']``.
    * anything else: built with ``make_sngan``.
    """
    kind = args['gan_type']

    if kind == 'BigGAN':
        return make_big_gan(G_weights, args['target_class']).eval()
    if kind == 'ProgGAN':
        return make_proggan(G_weights)
    if 'StyleGAN2' in kind:
        return make_style_gan2(args['gan_resolution'], G_weights,
                               args['w_shift'])
    return make_sngan(G_weights)
Beispiel #4
0
def load_generator(args, G_weights, shift_in_w):
    """Construct the generator selected by ``args['gan_type']``.

    * ``'BigGAN'``: built for ``args['target_class']`` and put in eval mode.
    * ``'ProgGAN'``: built from the weights only.
    * any type containing ``'StyleGAN2'``: built with ``args['resolution']``
      and the caller-supplied ``shift_in_w`` flag.
    * ``'GLOW_tensorflow'`` / ``'GLOW_pt_celeba'`` / ``'GLOW_pt_anime'``:
      built with ``make_GLOW`` for the tensorflow / pytorch backend.
    * anything else: built with ``make_sngan``.
    """
    kind = args['gan_type']

    if kind == 'BigGAN':
        return make_big_gan(G_weights, args['target_class']).eval()
    if kind == 'ProgGAN':
        return make_proggan(G_weights)
    if 'StyleGAN2' in kind:
        return make_style_gan2(args['resolution'], G_weights, shift_in_w)
    if kind == 'GLOW_tensorflow':
        # NOTE(review): this call passes two positional arguments while the
        # pytorch branch passes three (gan_type first); one of the two is
        # likely out of sync with make_GLOW's signature — confirm.
        return make_GLOW(G_weights, "tensorflow")
    if kind in ('GLOW_pt_celeba', 'GLOW_pt_anime'):
        return make_GLOW(kind, G_weights, "pytorch")
    return make_sngan(G_weights)
def main():
    """Train a UNet segmentation model from BigGAN samples and a pretrained deformator.

    Command-line options may be overridden by a JSON file passed via
    ``--args``. The effective arguments are written to ``<out>/args.json`` and
    the invoking command to ``<out>/command.sh``. After training, the model is
    evaluated and scored into ``<out>/score.json`` when ``--val_images_dir``
    is given.
    """

    def boolean(value):
        """argparse ``type`` for bool flags: true/false, 1/0, yes/no, on/off."""
        lowered = value.strip().lower()
        if lowered in ('true', '1', 'yes', 'y', 'on'):
            return True
        if lowered in ('false', '0', 'no', 'n', 'off', ''):
            return False
        # argparse reports a ValueError as "invalid boolean value: ..."
        raise ValueError(value)

    parser = argparse.ArgumentParser(
        description='GAN-based unsupervised segmentation train')
    parser.add_argument('--args',
                        type=str,
                        default=None,
                        help='json with all arguments')

    parser.add_argument('--out', type=str, required=True)
    parser.add_argument('--gan_weights', type=str, default=WEIGHTS['BigGAN'])
    parser.add_argument('--deformator_weights', type=str, required=True)
    parser.add_argument('--deformator_type',
                        type=str,
                        choices=DEFORMATOR_TYPE_DICT.keys(),
                        required=True)
    parser.add_argument('--background_dim', type=int, required=True)
    parser.add_argument('--classes',
                        type=int,
                        nargs='*',
                        default=list(range(1000)))
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=2)

    parser.add_argument('--val_images_dir', type=str)
    parser.add_argument('--val_masks_dir', type=str)

    for key, val in SegmentationTrainParams().__dict__.items():
        # type=bool would turn every non-empty string, even 'False', into True
        val_type = boolean if isinstance(val, bool) else type(val)
        parser.add_argument('--{}'.format(key), type=val_type, default=None)

    args = parser.parse_args()

    # Merge the JSON overrides *before* seeding / choosing the device so that
    # 'seed' and 'device' given there actually take effect.
    if args.args is not None:
        with open(args.args) as args_json:
            args.__dict__.update(json.load(args_json))

    torch.random.manual_seed(args.seed)
    torch.cuda.set_device(args.device)

    # save run params
    os.makedirs(args.out, exist_ok=True)
    with open(os.path.join(args.out, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)
    with open(os.path.join(args.out, 'command.sh'), 'w') as command_file:
        command_file.write(' '.join(sys.argv))
        command_file.write('\n')

    # `--classes` given with no values yields an empty list
    if not args.classes:
        print('using all ImageNet')
        args.classes = list(range(1000))
    G = make_big_gan(args.gan_weights, args.classes).eval().cuda()
    deformator = LatentDeformator(
        G.dim_z, type=DEFORMATOR_TYPE_DICT[args.deformator_type])
    deformator.load_state_dict(
        torch.load(args.deformator_weights, map_location=torch.device('cpu')))
    deformator.eval().cuda()

    model = UNet().train().cuda()
    train_params = SegmentationTrainParams(**args.__dict__)
    print('run train with params: {}'.format(train_params.__dict__))

    train_segmentation(G, deformator, model, train_params, args.background_dim,
                       args.out)

    if args.val_images_dir is not None:
        # NOTE(review): val_masks_dir is forwarded even when it was not given
        # (None), and the trailing 128 is interpreted by evaluate(),
        # presumably as an image size — confirm both.
        evaluate(model, args.val_images_dir, args.val_masks_dir,
                 os.path.join(args.out, 'score.json'), 128)
def main():
    """Train a latent deformator and shift predictor, plus auxiliary transform models.

    Options come from ``TrainOptions`` extended with one flag per ``Params``
    field and the flags below; a JSON file passed via ``--args`` overrides
    them. With ``--continue_train`` the deformator and shift predictor are
    restored from ``--deformator_path`` / ``--shift_predictor_path``. For each
    of 'zoom', 'shiftx', 'color' and 'shifty' a transform model is built via
    ``graphs.find_model_using_name`` and handed to ``Trainer.train``.
    """

    def boolean(value):
        """argparse ``type`` for bool flags: true/false, 1/0, yes/no, on/off."""
        lowered = value.strip().lower()
        if lowered in ('true', '1', 'yes', 'y', 'on'):
            return True
        if lowered in ('false', '0', 'no', 'n', 'off', ''):
            return False
        # argparse reports a ValueError as "invalid boolean value: ..."
        raise ValueError(value)

    tOption = TrainOptions()

    for key, val in Params().__dict__.items():
        # type=bool would turn every non-empty string, even 'False', into True
        tOption.parser.add_argument('--{}'.format(key),
                                    type=boolean if isinstance(val, bool)
                                    else type(val),
                                    default=val)

    tOption.parser.add_argument('--args',
                                type=str,
                                default=None,
                                help='json with all arguments')
    tOption.parser.add_argument('--out', type=str, default='./output')
    tOption.parser.add_argument('--gan_type',
                                type=str,
                                choices=WEIGHTS.keys(),
                                default='StyleGAN')
    tOption.parser.add_argument('--gan_weights', type=str, default=None)
    tOption.parser.add_argument('--target_class', type=int, default=239)
    tOption.parser.add_argument('--json', type=str)

    tOption.parser.add_argument('--deformator',
                                type=str,
                                default='proj',
                                choices=DEFORMATOR_TYPE_DICT.keys())
    tOption.parser.add_argument('--deformator_random_init',
                                type=boolean,
                                default=False)

    tOption.parser.add_argument('--shift_predictor_size', type=int)
    tOption.parser.add_argument('--shift_predictor',
                                type=str,
                                choices=['ResNet', 'LeNet'],
                                default='ResNet')
    tOption.parser.add_argument('--shift_distribution_key',
                                type=str,
                                choices=SHIFT_DISTRIDUTION_DICT.keys())

    tOption.parser.add_argument('--seed', type=int, default=2)
    tOption.parser.add_argument('--device', type=int, default=0)

    tOption.parser.add_argument('--continue_train', type=boolean, default=False)
    tOption.parser.add_argument('--deformator_path',
                                type=str,
                                default='output/models/deformator_90000.pt')
    tOption.parser.add_argument(
        '--shift_predictor_path',
        type=str,
        default='output/models/shift_predictor_190000.pt')

    args = tOption.parse()

    # Merge the JSON overrides *before* seeding / choosing the device so that
    # 'seed' and 'device' given there actually take effect.
    if args.args is not None:
        with open(args.args) as args_json:
            args.__dict__.update(json.load(args_json))

    torch.cuda.set_device(args.device)
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # NOTE: this variant does not write args.json / command.sh to args.out.

    # init models
    if args.gan_weights is not None:
        weights_path = args.gan_weights
    else:
        weights_path = WEIGHTS[args.gan_type]

    if args.gan_type == 'BigGAN':
        G = make_big_gan(weights_path, args.target_class).eval()
    elif args.gan_type == 'StyleGAN':
        G = make_stylegan(
            weights_path,
            net_info[args.stylegan.dataset]['resolution']).eval()
    elif args.gan_type == 'ProgGAN':
        G = make_proggan(weights_path).eval()
    else:
        G = make_external(weights_path).eval()

    # Decide whether the deformator produces a latent code in z or in w.
    if args.model == 'stylegan':
        latent = args.stylegan.latent
        if latent not in ('z', 'w'):
            raise ValueError('unknown latent space: {}'.format(latent))
        target_dim = G.dim_z if latent == 'z' else G.dim_w
    else:
        # Non-StyleGAN models: deform the generator's z space (target_dim was
        # otherwise left unbound here).
        target_dim = G.dim_z

    if args.shift_predictor == 'ResNet':
        shift_predictor = ResNetShiftPredictor(
            args.direction_size, args.shift_predictor_size).cuda()
    elif args.shift_predictor == 'LeNet':
        # NOTE(review): second argument is presumably the input channel
        # count (1 for SN_MNIST, 3 otherwise) — confirm.
        shift_predictor = LeNetShiftPredictor(
            args.direction_size,
            1 if args.gan_type == 'SN_MNIST' else 3).cuda()

    if args.continue_train:
        deformator = LatentDeformator(
            direction_size=args.direction_size,
            out_dim=target_dim,
            type=DEFORMATOR_TYPE_DICT[args.deformator]).cuda()
        deformator.load_state_dict(
            torch.load(args.deformator_path, map_location=torch.device('cpu')))

        shift_predictor.load_state_dict(
            torch.load(args.shift_predictor_path,
                       map_location=torch.device('cpu')))
    else:
        deformator = LatentDeformator(
            direction_size=args.direction_size,
            out_dim=target_dim,
            type=DEFORMATOR_TYPE_DICT[args.deformator],
            random_init=args.deformator_random_init).cuda()

    # transform models, one per transform type
    graph_kwargs = util.set_graph_kwargs(args)

    transform_type = ['zoom', 'shiftx', 'color', 'shifty']
    transform_model = EasyDict()
    for a_type in transform_type:
        model = graphs.find_model_using_name(args.model, a_type)
        transform_model[a_type] = EasyDict(model=model(**graph_kwargs))

    # training
    args.shift_distribution = SHIFT_DISTRIDUTION_DICT[
        args.shift_distribution_key]
    trainer = Trainer(params=Params(**args.__dict__),
                      out_dir=args.out,
                      out_json=args.json,
                      continue_train=args.continue_train)
    trainer.train(G, deformator, shift_predictor, transform_model)
Beispiel #7
0
def main():
    """Latent space rectification: train a deformator and a shift predictor on a GAN.

    Command-line options (one per ``Params`` field plus the flags below) may
    be overridden by a JSON file passed via ``--args``. The effective
    arguments are written to ``<out>/args.json`` and the invoking command to
    ``<out>/command.sh`` before training starts.
    """

    def boolean(value):
        """argparse ``type`` for bool flags: true/false, 1/0, yes/no, on/off."""
        lowered = value.strip().lower()
        if lowered in ('true', '1', 'yes', 'y', 'on'):
            return True
        if lowered in ('false', '0', 'no', 'n', 'off', ''):
            return False
        # argparse reports a ValueError as "invalid boolean value: ..."
        raise ValueError(value)

    parser = argparse.ArgumentParser(description='Latent space rectification')
    for key, val in Params().__dict__.items():
        # type=bool would turn every non-empty string, even 'False', into True
        val_type = boolean if isinstance(val, bool) else type(val)
        parser.add_argument('--{}'.format(key), type=val_type, default=None)

    parser.add_argument('--args',
                        type=str,
                        default=None,
                        help='json with all arguments')
    parser.add_argument('--out', type=str, required=True)
    parser.add_argument('--gan_type', type=str, choices=WEIGHTS.keys())
    parser.add_argument('--gan_weights', type=str, default=None)
    parser.add_argument('--target_class', type=int, default=239)
    parser.add_argument('--json', type=str)

    parser.add_argument('--deformator',
                        type=str,
                        default='ortho',
                        choices=DEFORMATOR_TYPE_DICT.keys())
    parser.add_argument('--deformator_random_init', type=boolean, default=False)

    parser.add_argument('--shift_predictor_size', type=int)
    parser.add_argument('--shift_predictor',
                        type=str,
                        choices=['ResNet', 'LeNet'],
                        default='ResNet')
    parser.add_argument('--shift_distribution_key',
                        type=str,
                        choices=SHIFT_DISTRIDUTION_DICT.keys())

    parser.add_argument('--seed', type=int, default=2)
    parser.add_argument('--device', type=int, default=0)

    args = parser.parse_args()

    # Merge the JSON overrides *before* seeding / choosing the device so that
    # 'seed' and 'device' given there actually take effect.
    if args.args is not None:
        with open(args.args) as args_json:
            args.__dict__.update(json.load(args_json))

    torch.cuda.set_device(args.device)
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # save run params
    os.makedirs(args.out, exist_ok=True)
    with open(os.path.join(args.out, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)
    with open(os.path.join(args.out, 'command.sh'), 'w') as command_file:
        command_file.write(' '.join(sys.argv))
        command_file.write('\n')

    # init models
    if args.gan_weights is not None:
        weights_path = args.gan_weights
    else:
        weights_path = WEIGHTS[args.gan_type]

    if args.gan_type == 'BigGAN':
        G = make_big_gan(weights_path, args.target_class).eval()
    elif args.gan_type == 'ProgGAN':
        G = make_proggan(weights_path).eval()
    else:
        G = make_external(weights_path).eval()

    deformator = LatentDeformator(
        G.dim_z,
        type=DEFORMATOR_TYPE_DICT[args.deformator],
        random_init=args.deformator_random_init).cuda()

    if args.shift_predictor == 'ResNet':
        shift_predictor = ResNetShiftPredictor(
            G.dim_z, args.shift_predictor_size).cuda()
    elif args.shift_predictor == 'LeNet':
        # NOTE(review): second argument is presumably the input channel
        # count (1 for SN_MNIST, 3 otherwise) — confirm.
        shift_predictor = LeNetShiftPredictor(
            G.dim_z, 1 if args.gan_type == 'SN_MNIST' else 3).cuda()

    # training: resolve the string keys into the objects Params expects
    args.shift_distribution = SHIFT_DISTRIDUTION_DICT[
        args.shift_distribution_key]
    args.deformation_loss = DEFORMATOR_LOSS_DICT[args.deformation_loss]
    trainer = Trainer(params=Params(**args.__dict__),
                      out_dir=args.out,
                      out_json=args.json)
    trainer.train(G, deformator, shift_predictor)