def get_model(args):
    """Build a ``SentimentClassifier`` and optionally restore it from a checkpoint.

    If ``args.load`` is a non-empty value, the checkpoint is read with
    ``torch.load``. When the loaded object contains an ``'args'`` entry, that
    entry replaces ``args`` as the source of model hyper-parameters; when it
    contains an ``'sd'`` entry, that entry is used as the state dict.

    Side effects on ``args`` (values are copied from the effective model
    arguments so the caller sees what the model was built with):
    ``heads_per_class``, ``use_softmax``, ``classes`` and ``dual_thresh``.

    Args:
        args: Namespace-like object. Read here: ``load``, ``model``,
            ``label_key`` (fallback only), ``cuda``, ``fp16``, ``neurons``,
            plus every hyper-parameter attribute when no checkpoint
            overrides them.

    Returns:
        The constructed model, moved to CUDA / cast to half precision as
        requested, with the checkpoint weights loaded when one was given.
    """
    sd = None
    model_args = args
    if args.load is not None and args.load != '':
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        # Map tensors to the GPU when one is present, otherwise to the CPU,
        # so checkpoints saved on GPU still load on CPU-only machines.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        sd = torch.load(args.load, map_location=device)
        if 'args' in sd:
            model_args = sd['args']
        if 'sd' in sd:
            sd = sd['sd']

    ntokens = model_args.data_size
    concat_pools = (
        model_args.concat_max,
        model_args.concat_min,
        model_args.concat_mean,
    )
    # NOTE(review): the branch is selected by the caller's ``args.model`` while
    # the constructor receives ``model_args.model`` -- confirm these always
    # agree when a checkpoint overrides the arguments.
    if args.model == 'transformer':
        model = SentimentClassifier(
            model_args.model,
            ntokens,
            None,
            None,
            None,
            model_args.classifier_hidden_layers,
            model_args.classifier_dropout,
            None,
            concat_pools,
            False,
            model_args,
        )
    else:
        model = SentimentClassifier(
            model_args.model,
            ntokens,
            model_args.emsize,
            model_args.nhid,
            model_args.nlayers,
            model_args.classifier_hidden_layers,
            model_args.classifier_dropout,
            model_args.all_layers,
            concat_pools,
            False,
            model_args,
        )

    args.heads_per_class = model_args.heads_per_class
    args.use_softmax = model_args.use_softmax

    # The model arguments may lack ``classes`` (AttributeError) or hold a
    # non-iterable such as None (TypeError); fall back to the single label key.
    try:
        args.classes = list(model_args.classes)
    except (AttributeError, TypeError):
        args.classes = [args.label_key]

    # Either attribute may be missing from the model arguments; in that case
    # dual thresholding is disabled.
    try:
        args.dual_thresh = (
            model_args.dual_thresh and not model_args.joint_binary_train
        )
    except AttributeError:
        args.dual_thresh = False

    if args.cuda:
        model.cuda()

    if args.fp16:
        model.half()

    if sd is not None:
        try:
            model.load_state_dict(sd)
        except Exception:
            # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt
            # and SystemExit still propagate.
            # If the state dict has weight-normalized parameters, apply weight
            # norm to the encoder, load the state dict into the encoder, then
            # remove weight norm from the model again.
            if hasattr(model.lm_encoder, 'rnn'):
                apply_weight_norm(model.lm_encoder.rnn)
            else:
                apply_weight_norm(model.lm_encoder)
            model.lm_encoder.load_state_dict(sd)
            remove_weight_norm(model)

    if args.neurons > 0:
        print('WARNING. Setting neurons %s' % str(args.neurons))
        model.set_neurons(args.neurons)
    return model
if args.fp16: model.half() with open(args.load_model, 'rb') as f: sd = torch.load(f) try: model.load_state_dict(sd) except: apply_weight_norm(model.encoder.rnn) model.load_state_dict(sd) remove_weight_norm(model) if args.neurons > 0: model.set_neurons(args.neurons) # uses similar function as transform from transfer.py def classify(model, text): model.eval() labels = np.array([]) first_label = True def get_batch(batch): (text, timesteps), labels = batch text = Variable(text).long() timesteps = Variable(timesteps).long() labels = Variable(labels).long() if args.cuda: text, timesteps, labels = text.cuda(), timesteps.cuda(