Exemplo n.º 1
0
def get_optimizer(args,
                  network=None,
                  param=None,
                  resume=False,
                  model_path=None):
    """Build an Adam optimizer with separate parameter groups for the model.

    Two groups are created: the CNN image-model parameters (which receive
    weight decay ``args.wd``) and every other parameter of ``network``
    (plus any extra ``param`` iterable), which receive no weight decay.

    Args:
        args: namespace with ``lr``, ``wd``, ``adam_alpha``, ``adam_beta``,
            ``epsilon`` (and nothing else read here).
        network: a ``nn.DataParallel``-wrapped model exposing
            ``network.module.image_model``.
        param: optional extra iterable of parameters appended to the
            non-CNN group.
        resume: when True, restore optimizer state from ``model_path``.
        model_path: checkpoint path (only used when ``resume`` is True).

    Returns:
        torch.optim.Adam: the configured optimizer.

    Side effects: prints the total parameter count and re-seeds the
    ``random``/``numpy``/``torch`` RNGs with a fresh random seed.
    """
    # Identify CNN parameters by object id so the remaining parameters can
    # be split into their own group. A set gives O(1) membership tests
    # (the original used a list, making this loop O(n^2)).
    cnn_params = set(map(id, network.module.image_model.parameters()))
    other_params = [p for p in network.parameters() if id(p) not in cnn_params]
    if param is not None:
        other_params.extend(list(param))
    param_groups = [{
        'params': other_params
    }, {
        'params': network.module.image_model.parameters(),
        'weight_decay': args.wd
    }]

    optimizer = torch.optim.Adam(param_groups,
                                 lr=args.lr,
                                 betas=(args.adam_alpha, args.adam_beta),
                                 eps=args.epsilon)

    if resume:
        check_file(model_path, 'model_file')
        checkpoint = torch.load(model_path)
        optimizer.load_state_dict(checkpoint['optimizer'])

    # Fixed format spec: original '%2.f' printed zero decimal places
    # (width 2, precision 0); '%.2f' was clearly intended.
    print('Total params: %.2fM' %
          (sum(p.numel() for p in network.parameters()) / 1000000.0))

    # Re-seed all RNGs with a fresh random seed (kept from the original;
    # NOTE(review): seeding inside an optimizer factory is surprising —
    # consider moving this to a dedicated setup function).
    manualSeed = random.randint(1, 10000)
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    # torch.cuda.manual_seed_all(manualSeed)

    return optimizer
Exemplo n.º 2
0
def get_network(args, model_path=None):
    """Construct the model and load its weights from a checkpoint.

    Args:
        args: configuration namespace passed to ``Model``; its
            ``start_epoch`` attribute is overwritten from the checkpoint.
        model_path: path of the checkpoint file to load. Required.

    Returns:
        nn.DataParallel: the model with checkpoint weights loaded.

    Raises:
        ValueError: if ``model_path`` is not supplied.
    """
    # Fail fast before building the (potentially expensive) model.
    # Fixed: compare against None with 'is', not '=='.
    if model_path is None:
        raise ValueError(
            'Supply the model path with --model_path while testing')

    network = Model(args)
    network = nn.DataParallel(network)
    # cudnn.benchmark = True
    args.start_epoch = 0

    check_file(model_path, 'model_file')
    checkpoint = torch.load(model_path)
    args.start_epoch = checkpoint['epoch'] + 1
    network.load_state_dict(checkpoint['network'])
    print('==> Loading checkpoint "{}"'.format(model_path))

    return network
def network_config(args, split='train', param=None, resume=False, model_path=None):
    """Build the model (DataParallel on GPU) and, unless testing, its optimizer.

    Weight loading has three modes:
      * ``resume=True``  — restore network and optimizer state from
        ``model_path`` and continue from the saved epoch.
      * ``model_path`` given without resume — load an external pretrained
        CNN into the ``image_model`` branch only.
      * neither — train from scratch.

    Args:
        args: namespace read for ``image_model``, ``wd``, ``lr``,
            ``adam_alpha``, ``adam_beta``, ``epsilon``; ``start_epoch``
            is written back.
        split: ``'test'`` skips optimizer creation (returns ``None`` for it).
        param: optional extra parameters appended to the non-CNN group.
        resume: resume training from ``model_path``.
        model_path: checkpoint or pretrained-weights path.

    Returns:
        tuple: ``(network, optimizer)`` where ``optimizer`` is ``None``
        when ``split == 'test'``.
    """
    network = Model(args)
    network = nn.DataParallel(network).cuda()
    cudnn.benchmark = True
    args.start_epoch = 0

    # ---- network weights ----
    checkpoint = None  # defensive: only populated on the resume path
    if resume:
        directory.check_file(model_path, 'model_file')
        checkpoint = torch.load(model_path)
        args.start_epoch = checkpoint['epoch'] + 1
        network.load_state_dict(checkpoint['network'])
        print('==> Loading checkpoint "{}"'.format(model_path))
    elif model_path is not None:
        # Load a pretrained CNN into the image branch only.
        print('==> Loading from pretrained models')
        network_dict = network.state_dict()
        if args.image_model == 'mobilenet_v1':
            # mobilenet checkpoints nest weights under 'state_dict' with a
            # 'module.' prefix (7 characters) that must be stripped.
            cnn_pretrained = torch.load(model_path)['state_dict']
            start = 7
        else:
            cnn_pretrained = torch.load(model_path)
            start = 0
        # Remap pretrained keys onto this model's image_model namespace and
        # keep only the keys that actually exist here.
        prefix = 'module.image_model.'
        pretrained_dict = {prefix + k[start:]: v
                           for k, v in cnn_pretrained.items()}
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in network_dict}
        network_dict.update(pretrained_dict)
        network.load_state_dict(network_dict)

    # ---- optimizer ----
    if split == 'test':
        optimizer = None
    else:
        # Split parameters so the CNN group alone gets weight decay.
        # Set membership is O(1); the original list made this O(n^2).
        cnn_params = set(map(id, network.module.image_model.parameters()))
        other_params = [p for p in network.parameters()
                        if id(p) not in cnn_params]
        if param is not None:
            other_params.extend(list(param))
        param_groups = [{'params': other_params},
                        {'params': network.module.image_model.parameters(),
                         'weight_decay': args.wd}]
        optimizer = torch.optim.Adam(
            param_groups,
            lr=args.lr, betas=(args.adam_alpha, args.adam_beta),
            eps=args.epsilon)
        if resume:
            optimizer.load_state_dict(checkpoint['optimizer'])

    # Fixed format spec: original '%2.f' printed zero decimal places.
    print('Total params: %.2fM' %
          (sum(p.numel() for p in network.parameters()) / 1000000.0))

    # Re-seed all RNGs with a fresh random seed.
    manualSeed = random.randint(1, 10000)
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)

    return network, optimizer
Exemplo n.º 4
0
def network_config(args,
                   split='train',
                   param=None,
                   resume=False,
                   model_path=None,
                   param2=None):
    """Build the model (DataParallel on GPU) and, unless testing, its optimizer.

    Multi-branch variant: pretrained CNN weights are replicated into the
    ``branch2_``/``branch3_`` copies of the image model, and the CNN and
    language-model parameter groups train at one tenth of the base lr.

    Args:
        args: namespace read for ``wd``, ``lr``, ``adam_alpha``,
            ``adam_beta``, ``epsilon``; ``start_epoch`` is written back.
        split: ``'test'`` skips optimizer creation (returns ``None`` for it).
        param: optional extra parameters appended to the non-CNN group.
        resume: resume training from ``model_path``.
        model_path: checkpoint or pretrained-weights path.
        param2: second optional extra parameter iterable, also appended
            to the non-CNN group.

    Returns:
        tuple: ``(network, optimizer)`` where ``optimizer`` is ``None``
        when ``split == 'test'``.
    """
    network = Model(args)
    network = nn.DataParallel(network).cuda()
    cudnn.benchmark = True
    args.start_epoch = 0

    # ---- network weights ----
    checkpoint = None  # defensive: only populated on the resume path
    if resume:
        directory.check_file(model_path, 'model_file')
        checkpoint = torch.load(model_path)
        args.start_epoch = checkpoint['epoch'] + 1
        network.load_state_dict(checkpoint['network'])
        print('==> Loading checkpoint "{}"'.format(model_path))
    elif model_path is not None:
        # Load pretrained CNN weights, fanning each tensor out to every
        # branch copy of the image model that has a matching key.
        print('==> Loading from pretrained models')
        network_dict = network.state_dict()
        cnn_pretrained = torch.load(model_path)
        network_keys = network_dict.keys()
        prefix = 'module.image_model.'
        update_pretrained_dict = {}
        for k, v in cnn_pretrained.items():
            matched = False
            for branch in (prefix, prefix + 'branch2_', prefix + 'branch3_'):
                key = branch + k
                if key in network_keys:
                    update_pretrained_dict[key] = v
                    matched = True
            if not matched:
                print("warning: " + k + ' not load')
        network_dict.update(update_pretrained_dict)
        network.load_state_dict(network_dict)

    # ---- optimizer ----
    if split == 'test':
        optimizer = None
    else:
        # CNN and language-model parameters form their own groups (reduced
        # lr, CNN also gets weight decay); everything else is the base group.
        # A single id-set replaces the original's concatenated lists,
        # giving O(1) membership tests.
        backbone_params = set(map(id, network.module.image_model.parameters()))
        backbone_params |= set(map(id, network.module.language_model.parameters()))
        other_params = [p for p in network.parameters()
                        if id(p) not in backbone_params]

        if param is not None:
            other_params.extend(list(param))

        if param2 is not None:
            other_params.extend(list(param2))

        param_groups = [{
            'params': other_params
        }, {
            'params': network.module.image_model.parameters(),
            'weight_decay': args.wd,
            'lr': args.lr / 10
        }, {
            'params': network.module.language_model.parameters(),
            'lr': args.lr / 10
        }]
        optimizer = torch.optim.Adam(param_groups,
                                     lr=args.lr,
                                     betas=(args.adam_alpha, args.adam_beta),
                                     eps=args.epsilon)
        if resume:
            optimizer.load_state_dict(checkpoint['optimizer'])

    # Fixed format spec: original '%2.f' printed zero decimal places.
    print('Total params: %.2fM' %
          (sum(p.numel() for p in network.parameters()) / 1000000.0))

    # Re-seed all RNGs with a fresh random seed.
    manualSeed = random.randint(1, 10000)
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)

    return network, optimizer