def load_from_dir(root_dir, model_index=None, G_weights=None, verbose=False):
    """Rebuild a trained deformator, generator and shift predictor from a run dir.

    ``root_dir`` is expected to contain the ``args.json`` written by training,
    a ``models/`` folder with ``deformator_<index>.pt`` and
    ``shift_predictor_<index>.pt`` checkpoints and, optionally, a
    ``directions.json`` annotation file.

    Args:
        root_dir: Training output directory.
        model_index: Checkpoint index to load. ``None`` selects the largest
            numeric index among ``models/deformator*`` files.
        G_weights: Generator weights path. ``None`` falls back to
            ``args['gan_weights']``; if that is missing or not an existing
            file, ``WEIGHTS[args['gan_type']]`` is used instead.
        verbose: Print which checkpoint index / weights are chosen.

    Returns:
        Tuple ``(deformator, G, shift_predictor)``, each put in eval mode and
        moved to CUDA. If ``directions.json`` exists, its content is attached
        to the deformator as ``deformator.directions_dict`` (an OrderedDict).

    Raises:
        ValueError: ``model_index`` is ``None`` and no numbered deformator
            checkpoint exists, or ``args['shift_predictor']`` is unknown.
    """
    with open(os.path.join(root_dir, 'args.json')) as args_file:
        args = json.load(args_file)
    models_dir = os.path.join(root_dir, 'models')

    if model_index is None:
        indices = []
        for name in os.listdir(models_dir):
            if not name.startswith('deformator'):
                continue
            # 'deformator_12000.pt' -> '12000'; skip files without a numeric suffix.
            suffix = name.split('.')[0].split('_')[-1]
            if suffix.isdigit():
                indices.append(int(suffix))
        if not indices:
            raise ValueError(
                'no deformator checkpoints found in {}'.format(models_dir))
        model_index = max(indices)
        if verbose:
            print('using max index {}'.format(model_index))

    if G_weights is None:
        G_weights = args.get('gan_weights')
    if G_weights is None or not os.path.isfile(G_weights):
        if verbose:
            print('Using default local G weights')
        G_weights = WEIGHTS[args['gan_type']]

    if args['gan_type'] == 'BigGAN':
        G = make_big_gan(G_weights, args['target_class']).eval()
    elif args['gan_type'] in ['ProgGAN', 'PGGAN']:
        G = make_proggan(G_weights)
    else:
        G = make_external(G_weights)

    deformator = LatentDeformator(
        G.dim_z, type=DEFORMATOR_TYPE_DICT[args['deformator']])

    # A missing (or null) 'shift_predictor' entry means the ResNet default.
    predictor_type = args.get('shift_predictor') or 'ResNet'
    if predictor_type == 'ResNet':
        # Rebuild with the same size the trainer used, when one was recorded,
        # so the saved state dict matches the architecture.
        predictor_size = args.get('shift_predictor_size')
        if predictor_size is None:
            shift_predictor = ResNetShiftPredictor(G.dim_z)
        else:
            shift_predictor = ResNetShiftPredictor(G.dim_z, predictor_size)
    elif predictor_type == 'LeNet':
        shift_predictor = LeNetShiftPredictor(
            G.dim_z, 1 if args['gan_type'] == 'SN_MNIST' else 3)
    else:
        raise ValueError(
            'unknown shift predictor type: {!r}'.format(predictor_type))

    deformator_model_path = os.path.join(
        models_dir, 'deformator_{}.pt'.format(model_index))
    shift_model_path = os.path.join(
        models_dir, 'shift_predictor_{}.pt'.format(model_index))
    # Load onto CPU first: the modules are moved to CUDA below, and this avoids
    # failures when the checkpoint was saved on a device that is not present.
    if os.path.isfile(deformator_model_path):
        deformator.load_state_dict(
            torch.load(deformator_model_path, map_location='cpu'))
    if os.path.isfile(shift_model_path):
        shift_predictor.load_state_dict(
            torch.load(shift_model_path, map_location='cpu'))

    # Optional human annotation of the discovered directions; key order is kept.
    directions_json = os.path.join(root_dir, 'directions.json')
    if os.path.isfile(directions_json):
        with open(directions_json, 'r') as f:
            deformator.directions_dict = json.load(
                f, object_pairs_hook=OrderedDict)

    return (deformator.eval().cuda(),
            G.eval().cuda(),
            shift_predictor.eval().cuda())
def main():
    """Train a latent deformator and shift predictor (TrainOptions-based entry point).

    Extends the ``TrainOptions`` parser with one flag per ``Params`` field plus
    the run flags below. It optionally merges a JSON file of overrides
    (``--args``) and builds the generator selected by ``--gan_type``. It then
    builds a deformator targeting the z or w latent space, a shift predictor
    and one transform model per transform in ``['zoom', 'shiftx', 'color',
    'shifty']``, and runs ``Trainer.train``. With ``--continue_train`` the
    deformator and shift predictor are restored from ``--deformator_path`` /
    ``--shift_predictor_path``.

    NOTE(review): another ``def main()`` follows later in this file and rebinds
    the module-level name ``main``, so this definition is shadowed — confirm
    which entry point is intended.
    """
    def boolean(value):
        """Parse a CLI boolean; ``type=bool`` would turn the string 'False' into True."""
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise ValueError('not a boolean: {!r}'.format(value))

    tOption = TrainOptions()
    parser = tOption.parser
    for key, val in Params().__dict__.items():
        parser.add_argument('--{}'.format(key),
                            type=boolean if isinstance(val, bool) else type(val),
                            default=val)
    parser.add_argument('--args', type=str, default=None,
                        help='json with all arguments')
    parser.add_argument('--out', type=str, default='./output')
    parser.add_argument('--gan_type', type=str, choices=WEIGHTS.keys(),
                        default='StyleGAN')
    parser.add_argument('--gan_weights', type=str, default=None)
    parser.add_argument('--target_class', type=int, default=239)
    parser.add_argument('--json', type=str)
    parser.add_argument('--deformator', type=str, default='proj',
                        choices=DEFORMATOR_TYPE_DICT.keys())
    parser.add_argument('--deformator_random_init', type=boolean, default=False)
    parser.add_argument('--shift_predictor_size', type=int)
    parser.add_argument('--shift_predictor', type=str,
                        choices=['ResNet', 'LeNet'], default='ResNet')
    parser.add_argument('--shift_distribution_key', type=str,
                        choices=SHIFT_DISTRIDUTION_DICT.keys())
    parser.add_argument('--seed', type=int, default=2)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--continue_train', type=boolean, default=False)
    # NOTE(review): the two resume defaults point at different iterations
    # (90000 vs 190000) — confirm this is intended.
    parser.add_argument('--deformator_path', type=str,
                        default='output/models/deformator_90000.pt')
    parser.add_argument('--shift_predictor_path', type=str,
                        default='output/models/shift_predictor_190000.pt')
    args = tOption.parse()

    # Merge JSON overrides before using any argument, so 'device' and 'seed'
    # coming from the JSON file are actually applied.
    if args.args is not None:
        with open(args.args) as args_json:
            args.__dict__.update(**json.load(args_json))

    torch.cuda.set_device(args.device)
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # NOTE: this entry point does not save run params (args.json / command.sh)
    # into args.out; load_from_dir() expects an args.json in a run directory.

    # Init models.
    weights_path = (args.gan_weights if args.gan_weights is not None
                    else WEIGHTS[args.gan_type])
    if args.gan_type == 'BigGAN':
        G = make_big_gan(weights_path, args.target_class).eval()
    elif args.gan_type == 'StyleGAN':
        G = make_stylegan(
            weights_path, net_info[args.stylegan.dataset]['resolution']).eval()
    elif args.gan_type == 'ProgGAN':
        G = make_proggan(weights_path).eval()
    else:
        G = make_external(weights_path).eval()

    # Decide whether the latent code being deformed is z or w. Only StyleGAN
    # exposes a w space; every other model works in z (G.dim_z), which
    # previously left target_dim unassigned.
    if args.model == 'stylegan':
        if args.stylegan.latent not in ('z', 'w'):
            raise ValueError('unknown latent space')
        target_dim = G.dim_z if args.stylegan.latent == 'z' else G.dim_w
    else:
        target_dim = G.dim_z

    if args.shift_predictor == 'ResNet':
        shift_predictor = ResNetShiftPredictor(
            args.direction_size, args.shift_predictor_size).cuda()
    elif args.shift_predictor == 'LeNet':
        shift_predictor = LeNetShiftPredictor(
            args.direction_size, 1 if args.gan_type == 'SN_MNIST' else 3).cuda()

    if args.continue_train:
        deformator = LatentDeformator(
            direction_size=args.direction_size, out_dim=target_dim,
            type=DEFORMATOR_TYPE_DICT[args.deformator]).cuda()
        # Checkpoints are loaded on CPU and copied into the CUDA modules.
        deformator.load_state_dict(
            torch.load(args.deformator_path, map_location=torch.device('cpu')))
        shift_predictor.load_state_dict(
            torch.load(args.shift_predictor_path,
                       map_location=torch.device('cpu')))
    else:
        deformator = LatentDeformator(
            direction_size=args.direction_size, out_dim=target_dim,
            type=DEFORMATOR_TYPE_DICT[args.deformator],
            random_init=args.deformator_random_init).cuda()

    # Build one transform graph per transform type.
    graph_kwargs = util.set_graph_kwargs(args)
    transform_type = ['zoom', 'shiftx', 'color', 'shifty']
    transform_model = EasyDict()
    for a_type in transform_type:
        model_cls = graphs.find_model_using_name(args.model, a_type)
        transform_model[a_type] = EasyDict(model=model_cls(**graph_kwargs))

    # Training.
    args.shift_distribution = SHIFT_DISTRIDUTION_DICT[
        args.shift_distribution_key]
    trainer = Trainer(params=Params(**args.__dict__), out_dir=args.out,
                      out_json=args.json, continue_train=args.continue_train)
    trainer.train(G, deformator, shift_predictor, transform_model)
def main():
    """Train a latent deformator and shift predictor from command-line flags.

    Adds one flag per ``Params`` field (default ``None``) plus the run flags
    below, and optionally merges a JSON file of overrides (``--args``). It
    saves the resolved arguments to ``<out>/args.json`` and the invoking
    command to ``<out>/command.sh``. It then builds the generator, deformator
    and shift predictor, and runs ``Trainer.train``.
    """
    def boolean(value):
        """Parse a CLI boolean; ``type=bool`` would turn the string 'False' into True."""
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise ValueError('not a boolean: {!r}'.format(value))

    parser = argparse.ArgumentParser(description='Latent space rectification')
    for key, val in Params().__dict__.items():
        parser.add_argument('--{}'.format(key),
                            type=boolean if isinstance(val, bool) else type(val),
                            default=None)
    parser.add_argument('--args', type=str, default=None,
                        help='json with all arguments')
    parser.add_argument('--out', type=str, required=True)
    parser.add_argument('--gan_type', type=str, choices=WEIGHTS.keys())
    parser.add_argument('--gan_weights', type=str, default=None)
    parser.add_argument('--target_class', type=int, default=239)
    parser.add_argument('--json', type=str)
    parser.add_argument('--deformator', type=str, default='ortho',
                        choices=DEFORMATOR_TYPE_DICT.keys())
    parser.add_argument('--deformator_random_init', type=boolean, default=False)
    parser.add_argument('--shift_predictor_size', type=int)
    parser.add_argument('--shift_predictor', type=str,
                        choices=['ResNet', 'LeNet'], default='ResNet')
    parser.add_argument('--shift_distribution_key', type=str,
                        choices=SHIFT_DISTRIDUTION_DICT.keys())
    parser.add_argument('--seed', type=int, default=2)
    parser.add_argument('--device', type=int, default=0)
    args = parser.parse_args()

    # Merge JSON overrides before using any argument, so 'device' and 'seed'
    # coming from the JSON file are actually applied (and match args.json).
    if args.args is not None:
        with open(args.args) as args_json:
            args.__dict__.update(**json.load(args_json))

    torch.cuda.set_device(args.device)
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # Save run params; args.json is what load_from_dir() reads back later.
    os.makedirs(args.out, exist_ok=True)
    with open(os.path.join(args.out, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)
    with open(os.path.join(args.out, 'command.sh'), 'w') as command_file:
        command_file.write(' '.join(sys.argv))
        command_file.write('\n')

    # Init models.
    # NOTE(review): --gan_type has no default; without it (or --gan_weights),
    # WEIGHTS[None] raises KeyError unless the JSON overrides supply it.
    weights_path = (args.gan_weights if args.gan_weights is not None
                    else WEIGHTS[args.gan_type])
    if args.gan_type == 'BigGAN':
        G = make_big_gan(weights_path, args.target_class).eval()
    elif args.gan_type == 'ProgGAN':
        G = make_proggan(weights_path).eval()
    else:
        G = make_external(weights_path).eval()

    deformator = LatentDeformator(
        G.dim_z, type=DEFORMATOR_TYPE_DICT[args.deformator],
        random_init=args.deformator_random_init).cuda()

    if args.shift_predictor == 'ResNet':
        shift_predictor = ResNetShiftPredictor(
            G.dim_z, args.shift_predictor_size).cuda()
    elif args.shift_predictor == 'LeNet':
        shift_predictor = LeNetShiftPredictor(
            G.dim_z, 1 if args.gan_type == 'SN_MNIST' else 3).cuda()

    # Training: resolve string keys to their configured objects
    # (done after args.json is written, so the JSON keeps the plain keys).
    args.shift_distribution = SHIFT_DISTRIDUTION_DICT[
        args.shift_distribution_key]
    args.deformation_loss = DEFORMATOR_LOSS_DICT[args.deformation_loss]
    trainer = Trainer(params=Params(**args.__dict__), out_dir=args.out,
                      out_json=args.json)
    trainer.train(G, deformator, shift_predictor)