def update_argparser(parser):
    """Add residual-network options and div2k training defaults to *parser*.

    ``models.update_argparser`` runs first. A preliminary
    ``parser.parse_known_args()`` then reads the dataset chosen so far, with
    unrecognised flags ignored. The four network options are always
    registered. Training defaults are set only for datasets whose name
    starts with ``'div2k'``.

    Args:
        parser: argparse parser, extended in place.

    Raises:
        NotImplementedError: if the dataset name does not start with
            ``'div2k'``. The network options have already been added by then.
    """
    models.update_argparser(parser)
    known, _unused = parser.parse_known_args()

    # Each row: (flag, help text, default value, type converter).
    option_table = (
        ('--num_blocks', 'Number of residual blocks in networks.', 16, int),
        ('--num_residual_units', 'Number of residual units in networks.',
         32, int),
        ('--width_multiplier', 'Width multiplier inside residual blocks.',
         4, float),
        ('--temporal_size', 'Number of frames for burst input.', None, int),
    )
    for flag, description, fallback, converter in option_table:
        parser.add_argument(flag,
                            help=description,
                            default=fallback,
                            type=converter)

    if not known.dataset.startswith('div2k'):
        raise NotImplementedError(
            'Needs to tune hyper parameters for new dataset.')

    parser.set_defaults(
        train_epochs=30,
        learning_rate_milestones=(20, 25),
        learning_rate_decay=0.2,
        save_checkpoints_epochs=1,
        lr_patch_size=48,
        train_temporal_size=1,
        eval_temporal_size=1,
    )
# Example #2
def update_argparser(parser):
    """Add residual-network options and training defaults for video data.

    ``models.update_argparser`` runs first. A preliminary
    ``parser.parse_known_args()`` then reads the dataset chosen so far.
    Options and defaults are added only when the dataset name starts with
    ``'video'``.

    Args:
        parser: argparse parser, extended in place.

    Raises:
        NotImplementedError: if the dataset name does not start with
            ``'video'``. Nothing is added to *parser* in that case.
    """
    models.update_argparser(parser)
    known, _unused = parser.parse_known_args()
    if not known.dataset.startswith('video'):
        raise NotImplementedError(
            'Needs to tune hyper parameters for new dataset.')

    # Each row: (flag, help text, default value). All are parsed as int.
    # NOTE(review): '--width_multiplier' mixes underscore/dash flag styles and
    # uses type=int, so fractional multipliers are rejected. Confirm intended.
    option_table = (
        ('--num-blocks', 'Number of residual blocks in networks', 16),
        ('--num-residual-units', 'Number of residual units in networks', 32),
        ('--width_multiplier', 'Width multiplier inside residual blocks', 4),
    )
    for flag, description, fallback in option_table:
        parser.add_argument(flag,
                            help=description,
                            default=fallback,
                            type=int)

    parser.set_defaults(
        train_epochs=20,
        learning_rate_milestones=(15, 18),
        save_checkpoints_epochs=1,
        lr_patch_size=64,
        train_temporal_size=1,
        eval_temporal_size=1,
    )
# Example #3
def update_argparser(parser):
    """Add recurrent / non-local network options and step-based defaults.

    ``models.update_argparser`` runs first. Unlike the dataset-specific
    variants, this one registers the same options and defaults whatever
    dataset is selected.

    Args:
        parser: argparse parser, extended in place.
    """
    models.update_argparser(parser)
    # No preliminary parse_known_args() here. Nothing below depends on
    # already-parsed values. With argparse's default add_help=True, parsing
    # early would also make -h/--help exit before the options below exist,
    # so the help text would be incomplete.
    parser.add_argument('--num-steps',
                        help='Number of steps in recurrent networks',
                        default=12,
                        type=int)
    parser.add_argument('--num-filters',
                        help='Number of filters in networks',
                        default=128,
                        type=int)
    parser.add_argument('--non-local-field-size',
                        help='Size of receptive field in non-local blocks',
                        default=35,
                        type=int)
    parser.add_argument(
        '--init-ckpt',
        help='Checkpoint path to initialize',
        default=None,
        type=str,
    )
    parser.set_defaults(
        train_steps=500000,
        # Presumably (step boundaries, rates) for a piecewise-constant
        # schedule: 5 boundaries, 6 rates. Confirm against the consumer.
        learning_rate=((100000, 200000, 300000, 400000, 450000),
                       (1e-3, 5e-4, 2.5e-4, 1.25e-4, 6.25e-5, 3.125e-5)),
        save_checkpoints_steps=20000,
        save_summary_steps=1000,
    )
# Example #4
def update_argparser(parser):
    """Add CIFAR-10 network/mixup options and step-based training defaults.

    ``models.update_argparser`` runs first. A preliminary
    ``parser.parse_known_args()`` then reads the dataset chosen so far.
    Options and defaults are added only when the dataset is exactly
    ``'cifar10'``.

    Args:
        parser: argparse parser, extended in place.

    Raises:
        NotImplementedError: for any other dataset. Nothing is added to
            *parser* in that case.
    """
    models.update_argparser(parser)
    known, _unused = parser.parse_known_args()
    if known.dataset != 'cifar10':
        raise NotImplementedError(
            'Needs to tune hyper parameters for new dataset.')

    parser.add_argument('--num-layers',
                        type=int,
                        default=110,
                        help='Number of layers in networks')
    parser.add_argument('--mixup',
                        type=float,
                        default=0.0,
                        help='Hyper parameter for mixup training')

    # Learning-rate schedule as (boundaries, values): 3 boundaries, 4 values.
    boundaries = (32000, 48000, 120000)
    rates = (0.1, 0.01, 0.001, 0.0002)
    parser.set_defaults(
        train_steps=150000,
        learning_rate=(boundaries, rates),
        save_checkpoints_steps=5000,
    )
# Example #5
def update_argparser(parser):
    """Add residual-network options and step-based defaults for div2k.

    ``models.update_argparser`` runs first. A preliminary
    ``parser.parse_known_args()`` then reads the dataset chosen so far.
    Options and defaults are added only when the dataset is exactly
    ``'div2k'``.

    Args:
        parser: argparse parser, extended in place.

    Raises:
        NotImplementedError: for any other dataset. Nothing is added to
            *parser* in that case.
    """
    models.update_argparser(parser)
    known, _unused = parser.parse_known_args()
    if known.dataset != 'div2k':
        raise NotImplementedError(
            'Needs to tune hyper parameters for new dataset.')

    parser.add_argument('--num-blocks',
                        type=int,
                        default=16,
                        help='Number of residual blocks in networks')
    parser.add_argument('--num-residual-units',
                        type=int,
                        default=64,
                        help='Number of residual units in networks')

    # Learning-rate schedule as (boundaries, values): 1 boundary, 2 values.
    boundaries = (1000000, )
    rates = (1e-4, 5e-5)
    parser.set_defaults(
        train_steps=1500000,
        learning_rate=(boundaries, rates),
        save_checkpoints_steps=50000,
        save_summary_steps=10000,
    )