def get_model(args):
    """Build a ``SentimentClassifier`` and optionally restore saved weights.

    When ``args.load`` names a checkpoint it is read with ``torch.load``. If
    the loaded object contains an ``'args'`` entry, those saved arguments
    replace ``args`` as the source of the model hyper-parameters; if it
    contains an ``'sd'`` entry, that entry is used as the state dict,
    otherwise the loaded object itself is.

    Side effects on ``args``: ``heads_per_class``, ``use_softmax``,
    ``classes`` and ``dual_thresh`` are overwritten from the effective model
    arguments (with fallbacks for the last two when they are unavailable).

    Args:
        args: Namespace-like object. This function reads ``load``, ``model``,
            ``cuda``, ``fp16`` and ``neurons`` from it, plus ``label_key``
            when the model arguments carry no usable ``classes``.

    Returns:
        The constructed model, moved to CUDA / converted to half precision
        as requested by ``args.cuda`` / ``args.fp16``.
    """
    sd = None
    model_args = args
    if args.load:
        # Map storages onto the first GPU when CUDA is available, else CPU.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        sd = torch.load(args.load, map_location=device)
        if 'args' in sd:
            model_args = sd['args']
        if 'sd' in sd:
            sd = sd['sd']

    ntokens = model_args.data_size
    concat_pools = (model_args.concat_max, model_args.concat_min,
                    model_args.concat_mean)
    # NOTE(review): the branch tests ``args.model`` while the constructor is
    # given ``model_args.model``; the two can differ when the checkpoint
    # supplies its own args -- confirm this is intended.
    if args.model == 'transformer':
        model = SentimentClassifier(model_args.model, ntokens, None, None,
                                    None, model_args.classifier_hidden_layers,
                                    model_args.classifier_dropout, None,
                                    concat_pools, False, model_args)
    else:
        model = SentimentClassifier(
            model_args.model, ntokens, model_args.emsize, model_args.nhid,
            model_args.nlayers, model_args.classifier_hidden_layers,
            model_args.classifier_dropout, model_args.all_layers, concat_pools,
            False, model_args)
    args.heads_per_class = model_args.heads_per_class
    args.use_softmax = model_args.use_softmax

    # Fall back to the single label key when the model arguments have no
    # ``classes`` attribute or it is not iterable (e.g. None).
    try:
        args.classes = list(model_args.classes)
    except (AttributeError, TypeError):
        args.classes = [args.label_key]

    # Model arguments lacking either attribute disable dual thresholding.
    try:
        args.dual_thresh = (model_args.dual_thresh
                            and not model_args.joint_binary_train)
    except AttributeError:
        args.dual_thresh = False

    if args.cuda:
        model.cuda()

    if args.fp16:
        model.half()

    if sd is not None:
        try:
            model.load_state_dict(sd)
        except (RuntimeError, KeyError):
            # The state dict did not match the model. Retry on the assumption
            # that it holds weight-normalized parameters: apply weight norm,
            # load the state dict into the language-model encoder only, then
            # remove weight norm from the whole model again.
            if hasattr(model.lm_encoder, 'rnn'):
                apply_weight_norm(model.lm_encoder.rnn)
            else:
                apply_weight_norm(model.lm_encoder)
            model.lm_encoder.load_state_dict(sd)
            remove_weight_norm(model)

    if args.neurons > 0:
        print('WARNING. Setting neurons %s' % str(args.neurons))
        model.set_neurons(args.neurons)
    return model
# Example #2
# 0
# Script fragment: parse arguments, build the classifier, restore its weights
# from ``args.load_model`` and optionally restrict it to ``args.neurons``.
# ``parser``, ``data_parser`` and ``data_config`` are defined earlier in the
# script, outside this fragment.
data_parser.set_defaults(split='1.', data='data/binary_sst/train.csv')

args = parser.parse_args()

# Use the GPU whenever one is available.
args.cuda = torch.cuda.is_available()

train_data, val_data, test_data = data_config.apply(args)
# ``data_size`` is read only after ``data_config.apply(args)`` has run --
# presumably that call sets it; verify against data_config.
ntokens = args.data_size
# NOTE(review): the literal 0.0 is the sixth positional argument -- presumably
# a dropout rate; confirm against the SentimentClassifier signature.
model = SentimentClassifier(args.model, ntokens, args.emsize, args.nhid,
                            args.nlayers, 0.0, args.all_layers)
if args.cuda:
    model.cuda()

if args.fp16:
    model.half()

# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
# Without CUDA, remap every storage to CPU so a checkpoint saved from a GPU
# can still be read; with CUDA available torch.load's default is kept.
with open(args.load_model, 'rb') as f:
    sd = torch.load(
        f, map_location=None if args.cuda else (lambda storage, loc: storage))

try:
    model.load_state_dict(sd)
except (RuntimeError, KeyError):
    # The state dict did not match the plain model. Retry with weight norm
    # applied to the encoder RNN, then strip weight norm again after loading.
    apply_weight_norm(model.encoder.rnn)
    model.load_state_dict(sd)
    remove_weight_norm(model)

if args.neurons > 0:
    model.set_neurons(args.neurons)