def __init__(self, qEncoderParam=None, qDecoderParam=None, aEncoderParam=None, aDecoderParam=None, imgFeatureSize=0, verbose=1):
    """Set up the AQM questioner agent.

    Builds an optional generative questioner and an optional approximate
    answerer (each only when both of its encoder/decoder configs are
    supplied), then initializes the per-dialog bookkeeping state.

    Args:
        qEncoderParam/qDecoderParam: configs for the Questioner sub-model.
        aEncoderParam/aDecoderParam: configs for the approximate Answerer.
        imgFeatureSize: image feature dimensionality passed to Questioner.
        verbose: verbosity flag forwarded to the sub-models.
    """
    super(AQMQuestioner, self).__init__()

    # Generative questioner: created only when both configs are present.
    haveQuestionerCfg = qEncoderParam is not None and qDecoderParam is not None
    if haveQuestionerCfg:
        self.questioner = Questioner(qEncoderParam, qDecoderParam,
                                     imgFeatureSize, verbose)
        self.startToken = qEncoderParam['startToken']
        self.endToken = qEncoderParam['endToken']

    # Approximate answerer used for question scoring, likewise optional.
    haveAnswererCfg = aEncoderParam is not None and aDecoderParam is not None
    if haveAnswererCfg:
        self.appAnswerer = Answerer(aEncoderParam, aDecoderParam, verbose)

    # Dialog history accumulated across rounds (consumed by the Guesser).
    self.questions = []
    self.quesLens = []
    self.answers = []
    self.ansLens = []

    # Grounding inputs — presumably filled in later by the caller; TODO confirm.
    self.image = None
    self.caption = None
    self.captionLens = None
    self.dataset = None
    self.dataset_split = None

    self.logSoftmax = nn.LogSoftmax(dim=0)
def loadModel(params, agent='abot', overwrite=False):
    """Construct an Answerer ('abot') or Questioner ('qbot') model, optionally
    restoring its configuration, weights and optimizer state from a checkpoint.

    Args:
        params (dict): Options dict. Keys read include 'useGPU', 'continue',
            'numEpochs', the checkpoint path under 'startFrom' (abot) /
            'qstartFrom' (qbot), and all encoder/decoder options listed below.
        agent (str): 'abot' or 'qbot'.
        overwrite (bool): When False, operate on a copy so the caller's
            dict is left untouched.

    Returns:
        tuple: (model, loadedParams, optim_state) where loadedParams holds
        options that came only from the checkpoint and optim_state is the
        saved optimizer state, or None when no checkpoint was loaded.

    Raises:
        ValueError: If agent is neither 'abot' nor 'qbot'.
    """
    if not overwrite:
        params = params.copy()
    loadedParams = {}

    # Everything consumed by encoderParam / decoderParam below.
    encoderOptions = [
        'encoder', 'vocabSize', 'embedSize', 'rnnHiddenSize', 'numLayers',
        'useHistory', 'useIm', 'imgEmbedSize', 'imgFeatureSize', 'numRounds',
        'dropout'
    ]
    decoderOptions = [
        'decoder', 'vocabSize', 'embedSize', 'rnnHiddenSize', 'numLayers',
        'dropout'
    ]
    modelOptions = encoderOptions + decoderOptions

    mdict = None
    # Run-level settings belong to the caller; stash them so checkpoint
    # options cannot clobber them (restored after the merge below).
    gpuFlag = params['useGPU']
    continueFlag = params['continue']
    numEpochs = params['numEpochs']
    startArg = 'startFrom' if agent == 'abot' else 'qstartFrom'

    if continueFlag:
        assert params[startArg], "Can't continue training without a checkpoint"

    # Load a checkpoint from disk if a path was given.
    if params[startArg]:
        print('Loading model (weights and config) from {}'.format(
            params[startArg]))
        if gpuFlag:
            mdict = torch.load(params[startArg])
        else:
            # Remap any CUDA tensors onto the CPU.
            mdict = torch.load(params[startArg],
                               map_location=lambda storage, location: storage)

        # Model options are the union of the standard options defined above
        # and the parameters recorded in the checkpoint.
        modelOptions = list(set(modelOptions).union(set(mdict['params'])))
        for opt in modelOptions:
            if opt not in params:
                # Option exists only in the checkpoint; adopt it (needed to
                # resume training with the same architecture).
                if continueFlag:
                    print("Loaded option '%s' from checkpoint" % opt)
                params[opt] = mdict['params'][opt]
                loadedParams[opt] = mdict['params'][opt]
            elif params[opt] != mdict['params'][opt]:
                # Checkpoint architecture options win over the command line.
                if continueFlag:
                    print("Overwriting param '%s'" % str(opt))
                params[opt] = mdict['params'][opt]

        # Restore the caller's run-level settings stashed above.
        params['continue'] = continueFlag
        params['numEpochs'] = numEpochs
        params['useGPU'] = gpuFlag

        if params['continue']:
            assert 'ckpt_lRate' in params, (
                "Checkpoint does not have info for restoring learning rate "
                "and optimizer.")

    # Initialize model class.
    encoderParam = {k: params[k] for k in encoderOptions}
    decoderParam = {k: params[k] for k in decoderOptions}

    # Start/end tokens occupy the last two vocabulary slots.
    encoderParam['startToken'] = encoderParam['vocabSize'] - 2
    encoderParam['endToken'] = encoderParam['vocabSize'] - 1
    decoderParam['startToken'] = decoderParam['vocabSize'] - 2
    decoderParam['endToken'] = decoderParam['vocabSize'] - 1

    if agent == 'abot':
        encoderParam['type'] = params['encoder']
        decoderParam['type'] = params['decoder']
        encoderParam['isAnswerer'] = True
        from visdial.models.answerer import Answerer
        model = Answerer(encoderParam, decoderParam)
    elif agent == 'qbot':
        encoderParam['type'] = params['qencoder']
        decoderParam['type'] = params['qdecoder']
        encoderParam['isAnswerer'] = False
        encoderParam['useIm'] = False
        from visdial.models.questioner import Questioner
        model = Questioner(encoderParam, decoderParam,
                           imgFeatureSize=encoderParam['imgFeatureSize'])
    else:
        # Previously an unknown agent fell through and crashed later with a
        # confusing NameError on 'model'; fail fast with a clear message.
        raise ValueError(
            "Unknown agent '%s'; expected 'abot' or 'qbot'" % agent)

    if params['useGPU']:
        model.cuda()

    # Clamp gradients parameter-by-parameter on encoder and decoder.
    # NOTE: model.parameters() should be used here, otherwise immediate
    # child modules in model will not have gradient clamping.
    for p in model.encoder.parameters():
        p.register_hook(clampGrad)
    for p in model.decoder.parameters():
        p.register_hook(clampGrad)

    # Restore weights / optimizer state when a checkpoint was loaded.
    if mdict:
        model.load_state_dict(mdict['model'])
        optim_state = mdict['optimizer']
    else:
        optim_state = None

    return model, loadedParams, optim_state