예제 #1
0
                        default='model',
                        help='the name of the model')

    args = parser.parse_args()
    # Fall back to the training batch size when no test batch size is given.
    args.batch_size_test = (args.batch_size if args.batch_size_test is None
                            else args.batch_size_test)
    config_visible_gpu(args.gpu)

    train_loader, test_loader = mnist(batch_size=args.batch_size,
                                      batch_size_test=args.batch_size_test)

    model = ConvNet1(input_size=[28, 28], input_channels=1, output_class=10)

    device_ids, model = parse_device_alloc(device_config=None, model=model)

    lr_func = parse_lr(policy=args.lr_policy, epoch_num=args.epoch_num)
    optimizer = parse_optim(policy=args.optim_policy,
                            params=model.parameters())

    # Record every parsed command-line option plus the per-epoch learning
    # rates. vars() is the documented way to turn an argparse Namespace into
    # a dict (the original used the private Namespace._get_kwargs()); copy it
    # so later edits to setup_config do not write back into args.
    setup_config = dict(vars(args))
    setup_config['lr_list'] = [lr_func(idx) for idx in range(args.epoch_num)]
    # exist_ok=True avoids the race between a separate exists() check and
    # makedirs(); it still raises if the path exists but is not a directory.
    os.makedirs(args.output_folder, exist_ok=True)

    tricks = {}
    if args.snapshots is not None:
        tricks['snapshots'] = args.snapshots

    results = train_test(setup_config=setup_config,
                         model=model,
                         train_loader=train_loader,
                         test_loader=test_loader,
                         epoch_num=args.epoch_num,
예제 #2
0
                                   mat_file=args.model2load,
                                   device=device,
                                   param2load=args.load_mode)
        elif args.model2load.endswith('.pth'):
            model = load_pth_model(model=model,
                                   name='lenet',
                                   pth_file=args.model2load,
                                   device=device,
                                   param2load=args.load_mode)
        else:
            raise ValueError('The format of %s is not supported' %
                             args.model2load)

    # Parse the optimizer.
    # With no frozen mode, optimize every parameter of the model. Otherwise
    # freeze one parameter group ('fc' or 'conv', case-insensitive) by turning
    # off requires_grad, and hand only the other group to the optimizer.
    if args.frozen_mode is None:
        optimizer = parse_optim(policy=args.optim, params=model.parameters())
    else:
        frozen_mode = args.frozen_mode.lower()
        if frozen_mode == 'fc':
            frozen_params, trainable_params = model.fc_params, model.conv_params
        elif frozen_mode == 'conv':
            frozen_params, trainable_params = model.conv_params, model.fc_params
        else:
            # Raised before any parameter is touched.
            raise ValueError('Unrecognized frozen mode: %s' % args.frozen_mode)
        for param in frozen_params():
            param.requires_grad = False
        optimizer = parse_optim(policy=args.optim, params=trainable_params())