Ejemplo n.º 1
0
    def gather_options(self):
        """Assemble the complete option parser and parse the command line.

        Options are registered in stages: the basic options from
        ``self.initialize``, then the options of the model selected by
        ``--model``, then those of the dataset selected by ``--dataset_mode``,
        then the shared network options.  Because the model and dataset
        setters are chosen from command-line values, the partially built
        parser is parsed with ``parse_known_args`` between stages.

        The final parser is stored on ``self.parser``.  If
        ``self.initialized`` is already true, that stored parser is reused
        as-is instead of being rebuilt (re-running the option setters on it
        would register the same arguments a second time).

        Returns:
            argparse.Namespace: the fully parsed options.

        Raises:
            RuntimeError: if ``self.initialized`` is true but no parser is
                stored on ``self.parser``.
        """
        if self.initialized:
            # Without this branch ``parser`` would be unbound below and the
            # call would fail with UnboundLocalError.
            parser = getattr(self, 'parser', None)
            if parser is None:
                raise RuntimeError(
                    'gather_options: options are already initialized but no '
                    'parser is stored on self.parser')
            return parser.parse_args()

        # initialize parser with basic options
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        # get the basic options; unknown arguments are expected here because
        # the model/dataset-specific options are not registered yet
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)

        # parse again: the model setter may have changed defaults (e.g.
        # ``dataset_mode`` via ``set_defaults``), and the dataset setter
        # below must be chosen from the updated value
        opt, _ = parser.parse_known_args()

        # modify dataset-related parser options
        dataset_option_setter = data.get_option_setter(opt.dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        # modify networks-related parser options
        parser = networks.modify_commandline_options(parser, self.isTrain)

        # final strict parse: every option is registered now, so unknown
        # arguments are reported as errors
        opt = parser.parse_args()
        self.parser = parser
        return opt
Ejemplo n.º 2
0
 def modify_commandline_options(parser, is_train):
     """Register this model's command-line options on ``parser``.

     Sets the ``netG`` default to ``'sub_mobile_spade'`` and adds the
     generator options.  When ``is_train`` is true, it also adds the
     restore, loss-weight and evaluation options and overrides several
     training defaults.  In that case ``--real_stat_path`` is required.
     Finally the parser is passed through
     ``networks.modify_commandline_options``.

     Args:
         parser: the ``argparse.ArgumentParser`` to extend.
         is_train: whether the training-only options are registered.

     Returns:
         The parser returned by ``networks.modify_commandline_options``.
     """
     assert isinstance(parser, argparse.ArgumentParser)
     parser.set_defaults(netG='sub_mobile_spade')
     # The option offers three normalizations, not just "instance or not".
     parser.add_argument('--separable_conv_norm', type=str, default='instance',
                         choices=('none', 'instance', 'batch'),
                         help='which normalization to use for the separable convolutions')
     parser.add_argument('--norm_G', type=str, default='spadesyncbatch3x3',
                         help='instance normalization or batch normalization')
     parser.add_argument('--num_upsampling_layers',
                         choices=('normal', 'more', 'most'), default='more',
                         help="If 'more', adds upsampling layer between the two middle resnet blocks. "
                              "If 'most', also add one more upsampling + resnet layer at the end of the generator")
     if is_train:
         parser.add_argument('--restore_G_path', type=str, default=None,
                             help='the path to restore the generator')
         parser.add_argument('--restore_D_path', type=str, default=None,
                             help='the path to restore the discriminator')
         parser.add_argument('--real_stat_path', type=str, required=True,
                             help='the path to load the ground-truth images information to compute FID.')
         parser.add_argument('--lambda_gan', type=float, default=1, help='weight for gan loss')
         parser.add_argument('--lambda_feat', type=float, default=10, help='weight for gan feature loss')
         parser.add_argument('--lambda_vgg', type=float, default=10, help='weight for vgg loss')
         parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
         # Passing the flag disables TTUR, so the help must say so.
         parser.add_argument('--no_TTUR', action='store_true', help='do not use the TTUR training scheme')
         parser.add_argument('--no_fid', action='store_true', help='No FID evaluation during training')
         parser.add_argument('--no_mIoU', action='store_true', help='No mIoU evaluation during training '
                                                                    '(e.g. when CUDA memory is insufficient)')
         parser.set_defaults(netD='multi_scale', ndf=64, dataset_mode='cityscapes', batch_size=16,
                             print_freq=50, save_latest_freq=10000000000, save_epoch_freq=10,
                             nepochs=100, nepochs_decay=100, init_type='xavier')
     parser = networks.modify_commandline_options(parser, is_train)
     return parser
Ejemplo n.º 3
0
 def modify_commandline_options(parser, is_train):
     """Add the latent-size option and generator default to ``parser``.

     Registers ``--z_dim`` (int, default 256) and makes
     ``'sub_mobile_spade'`` the default ``netG``.  Then it hands the parser
     to ``networks.modify_commandline_options``.

     Args:
         parser: the ``argparse.ArgumentParser`` to extend.
         is_train: forwarded unchanged to the network option setter.

     Returns:
         Whatever ``networks.modify_commandline_options`` returns.
     """
     assert isinstance(parser, argparse.ArgumentParser)
     parser.add_argument(
         '--z_dim',
         type=int,
         default=256,
         help="dimension of the latent z vector",
     )
     parser.set_defaults(netG='sub_mobile_spade')
     return networks.modify_commandline_options(parser, is_train)
Ejemplo n.º 4
0
 def modify_commandline_options(parser, is_train):
     """Register the SPADE distiller options on top of the parent's options.

     Starts from ``SPADEDistiller``'s parent ``modify_commandline_options``.
     It then adds the pretrained-generator and pruning options and overrides
     several training defaults (including ``teacher_ngf=64`` and
     ``student_ngf=48``).  Finally it passes the parser through
     ``networks.modify_commandline_options``.

     Args:
         parser: the ``argparse.ArgumentParser`` to extend.
         is_train: forwarded to the parent and network option setters.

     Returns:
         The parser returned by ``networks.modify_commandline_options``.
     """
     parser = super(SPADEDistiller,
                    SPADEDistiller).modify_commandline_options(
                        parser, is_train)
     assert isinstance(parser, argparse.ArgumentParser)
     parser.add_argument('--restore_pretrained_G_path',
                         type=str,
                         default=None,
                         help='the path to restore pretrained G')
     parser.add_argument('--pretrained_student_G_path',
                         type=str,
                         default=None,
                         help='the path for pretrained student G')
     parser.add_argument('--pretrained_netG',
                         type=str,
                         default='mobile_spade',
                         help='specify pretrained generator architecture',
                         # argparse does not validate the default against
                         # ``choices``; list the default too so it can also be
                         # passed explicitly on the command line.
                         choices=['mobile_spade', 'inception_spade'])
     parser.add_argument(
         '--pretrained_ngf',
         type=int,
         default=64,
         help='the base number of filters of the pretrained generator')
     parser.add_argument(
         '--pretrained_norm_G',
         type=str,
         default='spadesyncbatch3x3',
         help=
         'instance normalization or batch normalization of the student model'
     )
     parser.add_argument('--target_flops',
                         type=float,
                         default=0,
                         help='target flops')
     parser.add_argument('--prune_cin_lb',
                         type=int,
                         default=1,
                         help='lower bound for input channel number')
     parser.add_argument('--prune_only',
                         action='store_true',
                         help='prune without training')
     parser.add_argument('--prune_continue',
                         action='store_true',
                         help='continue training after pruning all layers')
     parser.add_argument('--prune_logging_verbose',
                         action='store_true',
                         help='logging verbose for pruning')
     parser.set_defaults(netD='multi_scale',
                         dataset_mode='cityscapes',
                         batch_size=16,
                         print_freq=50,
                         save_latest_freq=10000000000,
                         save_epoch_freq=10,
                         nepochs=100,
                         nepochs_decay=100,
                         init_type='xavier',
                         teacher_ngf=64,
                         student_ngf=48)
     # NOTE(review): if the parent setter already applies
     # networks.modify_commandline_options, this second call may try to
     # register the same arguments again -- confirm against the parent class.
     parser = networks.modify_commandline_options(parser, is_train)
     return parser
Ejemplo n.º 5
0
    def modify_commandline_options(cls, parser: argparse.ArgumentParser,
                                   is_train):
        """Extend the parent's options with the SAMS model options.

        The incoming parser is first wrapped in a new
        ``argparse.ArgumentParser`` (``parents=[parser]``, ``add_help=False``)
        and handed to the parent class's ``modify_commandline_options``.  On
        the result, this method:

        * sets the ``person_inputs``, ``n_frames_total`` and ``batch_size``
          defaults;
        * adds ``--encoder_input``, the four generator loss weights
          (``--wt_l1``, ``--wt_vgg``, ``--wt_multiscale``, ``--wt_temporal``,
          all floats defaulting to 1.0) and ``--norm_D``;
        * applies ``networks.modify_commandline_options`` and then
          ``gan_options.modify_commandline_options``.

        Returns:
            The parser returned by ``gan_options.modify_commandline_options``.
        """
        wrapped = argparse.ArgumentParser(parents=[parser], add_help=False)
        parser = super(SamsModel, cls).modify_commandline_options(
            wrapped, is_train)

        parser.set_defaults(person_inputs=("agnostic", "densepose", "flow"))
        parser.add_argument(
            "--encoder_input",
            default="flow",
            help="which of the --person_inputs to use as the encoder segmap "
                 "input (only 1 allowed).",
        )
        # n_frames_total: number of previous frames fed as input is
        # n_frames_total - 1.  The effective batch size becomes
        # n_frames_total * batch_size.
        parser.set_defaults(n_frames_total=5, batch_size=4)

        # One float weight per generator loss term, registered in this order.
        loss_weights = (
            ("wt_l1", "l1 loss"),
            ("wt_vgg", "vgg loss"),
            ("wt_multiscale", "adversarial multiscale loss"),
            ("wt_temporal", "adversarial temporal loss"),
        )
        for dest, loss_name in loss_weights:
            parser.add_argument(
                "--" + dest,
                type=float,
                default=1.0,
                help="Weight applied to {} in the generator".format(loss_name),
            )

        parser.add_argument(
            "--norm_D",
            type=str,
            default="spectralinstance",
            help="instance normalization or batch normalization",
        )
        from models import networks

        with_network_opts = networks.modify_commandline_options(
            parser, is_train)
        return gan_options.modify_commandline_options(
            with_network_opts, is_train)
Ejemplo n.º 6
0
 def modify_commandline_options(parser, is_train):
     """Apply the SPADE supernet training defaults on top of the parent's options.

     Calls ``SPADESupernet``'s parent ``modify_commandline_options`` and then
     overrides a fixed set of defaults.  These include a
     ``'super_mobile_spade'`` student generator, ``teacher_ngf=64`` and
     ``student_ngf=48``.  The result is passed through
     ``networks.modify_commandline_options``.

     Args:
         parser: the ``argparse.ArgumentParser`` to extend.
         is_train: forwarded to the parent and network option setters.

     Returns:
         Whatever ``networks.modify_commandline_options`` returns.
     """
     base_setter = super(SPADESupernet, SPADESupernet).modify_commandline_options
     parser = base_setter(parser, is_train)
     assert isinstance(parser, argparse.ArgumentParser)
     supernet_defaults = dict(
         netD='multi_scale',
         dataset_mode='cityscapes',
         batch_size=16,
         print_freq=50,
         save_latest_freq=10000000000,
         save_epoch_freq=10,
         nepochs=100,
         nepochs_decay=100,
         init_type='xavier',
         teacher_ngf=64,
         student_ngf=48,
         student_netG='super_mobile_spade',
         ndf=64,
     )
     parser.set_defaults(**supernet_defaults)
     return networks.modify_commandline_options(parser, is_train)
Ejemplo n.º 7
0
 def modify_commandline_options(parser, is_train):
     """Register the SPADE distiller's pretrained-generator options.

     Starts from ``SPADEDistiller``'s parent ``modify_commandline_options``.
     It adds four ``--pretrained_*``/restore options (in a fixed order) and
     overrides the training defaults, including ``teacher_ngf=64`` and
     ``student_ngf=48``.  The parser is then passed through
     ``networks.modify_commandline_options``.

     Args:
         parser: the ``argparse.ArgumentParser`` to extend.
         is_train: forwarded to the parent and network option setters.

     Returns:
         Whatever ``networks.modify_commandline_options`` returns.
     """
     base_setter = super(SPADEDistiller, SPADEDistiller).modify_commandline_options
     parser = base_setter(parser, is_train)
     assert isinstance(parser, argparse.ArgumentParser)

     # (flag, add_argument keyword arguments), registered in this order.
     pretrained_options = (
         ('--restore_pretrained_G_path',
          dict(type=str, default=None,
               help='the path to restore pretrained G')),
         ('--pretrained_netG',
          dict(type=str, default='mobile_spade',
               help='specify pretrained generator architecture',
               choices=['mobile_spade'])),
         ('--pretrained_ngf',
          dict(type=int, default=64,
               help='the base number of filters of the pretrained generator')),
         ('--pretrained_norm_G',
          dict(type=str, default='spadesyncbatch3x3',
               help='instance normalization or batch normalization of the student model')),
     )
     for flag, spec in pretrained_options:
         parser.add_argument(flag, **spec)

     distiller_defaults = dict(
         netD='multi_scale',
         dataset_mode='cityscapes',
         batch_size=16,
         print_freq=50,
         save_latest_freq=10000000000,
         save_epoch_freq=10,
         nepochs=100,
         nepochs_decay=100,
         init_type='xavier',
         teacher_ngf=64,
         student_ngf=48,
     )
     parser.set_defaults(**distiller_defaults)
     return networks.modify_commandline_options(parser, is_train)
Ejemplo n.º 8
0
 def modify_commandline_options(parser, is_train):
     """Apply the shared network options and return the resulting parser.

     Args:
         parser: the ``argparse.ArgumentParser`` to extend.
         is_train: forwarded to ``networks.modify_commandline_options``.

     Returns:
         The parser returned by ``networks.modify_commandline_options``.
     """
     # Keep the returned parser, as the other option setters do.  A setter
     # may return a new (e.g. parent-wrapped) parser, and discarding it would
     # silently drop the options it registered.
     parser = networks.modify_commandline_options(parser, is_train)
     return parser
Ejemplo n.º 9
0
 def modify_commandline_options(parser, is_train):
     """Apply the shared network options and return the resulting parser.

     Args:
         parser: the ``argparse.ArgumentParser`` to extend.
         is_train: forwarded to ``networks.modify_commandline_options``.

     Returns:
         The parser returned by ``networks.modify_commandline_options``.
     """
     # Keep the returned parser, as the other option setters do.  A setter
     # may return a new (e.g. parent-wrapped) parser, and discarding it would
     # silently drop the options it registered.
     parser = networks.modify_commandline_options(parser, is_train)
     return parser