def get_model(args):
    """Build a SentimentClassifier, optionally restoring weights from a checkpoint.

    If ``args.load`` is a non-empty path, the checkpoint is read with
    ``torch.load``.  When the loaded object contains an ``'args'`` entry, those
    saved arguments (not the caller's ``args``) drive model construction; when
    it contains an ``'sd'`` entry, that entry is used as the state dict.

    Side effects: sets ``args.heads_per_class``, ``args.use_softmax``,
    ``args.classes`` and ``args.dual_thresh`` from the model arguments.

    Args:
        args: run configuration namespace. Reads ``load``, ``cuda``, ``fp16``,
            ``neurons`` and (as a fallback) ``label_key``.

    Returns:
        The constructed model, moved to GPU / half precision as requested and
        with checkpoint weights loaded when a checkpoint was given.
    """
    sd = None
    model_args = args
    if args.load:
        # Map tensors to the first GPU when one exists, otherwise to the CPU, so
        # a GPU-saved checkpoint can still be read on a CPU-only machine.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        sd = torch.load(args.load, map_location=device)
        if 'args' in sd:
            model_args = sd['args']
        if 'sd' in sd:
            sd = sd['sd']

    ntokens = model_args.data_size
    concat_pools = model_args.concat_max, model_args.concat_min, model_args.concat_mean
    # Branch on the architecture that is actually being built (model_args.model,
    # possibly taken from the checkpoint), not on the caller's args.model.
    if model_args.model == 'transformer':
        model = SentimentClassifier(model_args.model, ntokens, None, None,
                                    None, model_args.classifier_hidden_layers,
                                    model_args.classifier_dropout, None,
                                    concat_pools, False, model_args)
    else:
        model = SentimentClassifier(
            model_args.model, ntokens, model_args.emsize, model_args.nhid,
            model_args.nlayers, model_args.classifier_hidden_layers,
            model_args.classifier_dropout, model_args.all_layers, concat_pools,
            False, model_args)
    args.heads_per_class = model_args.heads_per_class
    args.use_softmax = model_args.use_softmax
    try:
        args.classes = list(model_args.classes)
    except (AttributeError, TypeError):
        # model_args may lack `classes` or hold a non-iterable (e.g. None);
        # fall back to the single configured label column.
        args.classes = [args.label_key]

    try:
        args.dual_thresh = model_args.dual_thresh and not model_args.joint_binary_train
    except AttributeError:
        # Either flag may be absent from model_args.
        args.dual_thresh = False

    if args.cuda:
        model.cuda()

    if args.fp16:
        model.half()

    if sd is not None:
        try:
            model.load_state_dict(sd)
        except (RuntimeError, KeyError):
            # Key mismatch (current PyTorch raises RuntimeError, older releases
            # raised KeyError). If the state dict has weight-normalized
            # parameters, apply weight norm to the encoder, load the state dict
            # into it, then remove weight norm from the model.
            if hasattr(model.lm_encoder, 'rnn'):
                apply_weight_norm(model.lm_encoder.rnn)
            else:
                apply_weight_norm(model.lm_encoder)
            model.lm_encoder.load_state_dict(sd)
            remove_weight_norm(model)

    if args.neurons > 0:
        print('WARNING. Setting neurons %s' % str(args.neurons))
        model.set_neurons(args.neurons)
    return model
Exemplo n.º 2
0
if args.fp16:
    model.half()

# Map storages to the CPU while reading so a GPU-saved checkpoint also loads on
# a CPU-only machine; load_state_dict copies the values onto the model's device.
with open(args.load_model, 'rb') as f:
    sd = torch.load(f, map_location='cpu')

try:
    model.load_state_dict(sd)
except (RuntimeError, KeyError):
    # Key mismatch (RuntimeError in current PyTorch, KeyError in older
    # releases). Presumably the checkpoint stores weight-normalized RNN
    # parameters: apply weight norm to the encoder RNN so the keys line up,
    # load, then remove weight norm from the model again.
    apply_weight_norm(model.encoder.rnn)
    model.load_state_dict(sd)
    remove_weight_norm(model)

if args.neurons > 0:
    model.set_neurons(args.neurons)


# uses similar function as transform from transfer.py
def classify(model, text):
    model.eval()
    labels = np.array([])
    first_label = True

    def get_batch(batch):
        (text, timesteps), labels = batch
        text = Variable(text).long()
        timesteps = Variable(timesteps).long()
        labels = Variable(labels).long()
        if args.cuda:
            text, timesteps, labels = text.cuda(), timesteps.cuda(