Example #1
    def load_checkpoint(self):
        # Restart training (useful with OAR idempotent jobs)
        params = self.params
        modelname = params['modelname']
        iterators_state = {}
        history = {}
        if osp.exists(osp.join(modelname, 'model.pth')):
            self.warn('Picking up where we left off')
            # load model's weights
            saved_state = torch.load(osp.join(modelname, 'model.pth'))
            saved = list(saved_state)
            required = list(self.model.state_dict())
            # nn.DataParallel prefixes parameter names with "module.";
            # add the prefix if the checkpoint was saved without it.
            if required[0].startswith('module.') and not saved[0].startswith('module.'):
                for k in saved:
                    saved_state['module.%s' % k] = saved_state.pop(k)
            # Drop the auxiliary "increment" buffers; re-list the keys
            # since they may have been renamed above.
            for k in list(saved_state):
                if "increment" in k:
                    del saved_state[k]
            self.model.load_state_dict(saved_state)
            # load the optimizer's last state:
            self.optimizer.load(
                torch.load(osp.join(modelname, 'optimizer.pth')))
            history = pl(osp.join(modelname, 'trackers.pkl'))
            iterators_state = {'batch_offset': history['batch_offset'],
                               'epoch': history['epoch']}

        elif params['start_from']:
            start_from = params['start_from']
            # Start from a pre-trained model:
            self.warn('Starting from %s' % start_from)
            if params['start_from_best']:
                flag = '-best'
                self.warn('Starting from the best saved model')
            else:
                flag = ''
            # load model's weights
            saved_weights = torch.load(osp.join(start_from, 'model%s.pth' % flag))
            if params['reverse']:
                # adapt the loaded weights for the reverse task
                saved_weights = self.reverse_weights(saved_weights)
            self.model.load_state_dict(saved_weights)
            del saved_weights
            # load the optimizer's last state:
            if not params['optim']['reset']:
                self.optimizer.load(
                    torch.load(osp.join(start_from, 'optimizer%s.pth' % flag)))
            history = pl(osp.join(start_from, 'trackers%s.pkl' % flag))
        self.trackers.update(history)
        self.epoch = self.trackers['epoch']
        self.iteration = self.trackers['iteration']
        # start with eval:
        # self.iteration -= 1
        return iterators_state
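
The "module." rename above is the standard fix for loading a checkpoint saved from a bare model into one wrapped in torch.nn.DataParallel, which registers the wrapped model under the attribute "module". A minimal standalone sketch of the same idea (hypothetical helper, not part of the original class):

    def add_module_prefix(saved_state):
        # DataParallel state_dict keys carry a "module." prefix;
        # add it to every key of a checkpoint saved without the wrapper.
        return {'module.%s' % k: v for k, v in saved_state.items()}

    # usage: wrapped_model.load_state_dict(add_module_prefix(torch.load('model.pth')))
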
Example #2
 def __init__(self, job_name, params):
     super().__init__()
     self.version = "tok"
     self.logger = logging.getLogger(job_name)
     self.margin_sim = params['margin_sim']
     self.normalize_batch = params['normalize_batch']
     self.penalize_confidence = params['penalize_confidence']
     if self.margin_sim:
         self.logger.warning('Clipping similarities below %.2f' %
                             self.margin_sim)
     self.limited = params['limited_vocab_sim']
     self.alpha = params['alpha_word']
     self.tau_word = params['tau_word']
     # Load the similarity matrix:
     M = pl(params['similarity_matrix'])
     self.dense = isinstance(M, np.ndarray)
     self.rare = params['promote_rarity']
     if self.dense:
         M = M - 1  # turn similarities into negative distances: -D_ij = M_ij - 1
         if self.rare:
             IDF = pl(params['rarity_matrix'])
             M -= self.tau_word * self.rare * IDF
             del IDF
         M = M.astype(np.float32)
         M = Variable(torch.from_numpy(M)).cuda()
         self.Sim_Matrix = M
         n, d = self.Sim_Matrix.size()
     else:
         if self.rare:
             IDF = pl(params['rarity_matrix'])
             self.IDF = sparse_torch(IDF).cuda()
             del IDF
         self.Sim_Matrix = sparse_torch(M).cuda()
         n, d = self.Sim_Matrix.size()
     del M
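
These excerpts rely on two helpers that are not shown: pl, which from its use on .pkl files looks like a pickle loader, and sparse_torch, which appears to convert a SciPy sparse matrix into a torch sparse tensor. A plausible sketch under those assumptions (the project's actual implementations may differ):

    import pickle

    import numpy as np
    import torch

    def pl(path):
        # Assumed behaviour: unpickle an object from disk.
        with open(path, 'rb') as f:
            return pickle.load(f)

    def sparse_torch(M):
        # Assumed behaviour: SciPy sparse matrix -> torch sparse tensor.
        M = M.tocoo()
        indices = torch.from_numpy(np.vstack((M.row, M.col))).long()
        values = torch.from_numpy(M.data.astype(np.float32))
        return torch.sparse.FloatTensor(indices, values, torch.Size(M.shape))

The legacy torch.sparse.FloatTensor constructor matches the Variable-era PyTorch these excerpts target; newer code would use torch.sparse_coo_tensor instead.
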
Example #3
    def load_checkpoint(self):
        # Restart training (useful with OAR idempotent jobs)
        params = self.params
        modelname = params['modelname']
        iterators_state = {}
        history = {}
        print('Checking for checkpoint at', osp.join(modelname, "model.pth"))
        if osp.exists(osp.join(modelname, 'model.pth')):
            self.warn('Picking up where we left off')
            # load model's weights
            saved_state = torch.load(osp.join(modelname, 'model.pth'))
            check = list(saved_state)
            # drop the auxiliary "increment" buffers
            for k in check:
                if "increment" in k:
                    del saved_state[k]
            self.model.load_state_dict(saved_state)
            # load the optimizer's last state:
            self.optimizer.load(
                torch.load(osp.join(modelname, 'optimizer.pth')))
            history = pl(osp.join(modelname, 'trackers.pkl'))
            print('Loaded history:', list(history))
            iterators_state = {'src_iterators': history['src_iterators'],
                               'trg_iterators': history['trg_iterators']}

        elif params['start_from']:
            start_from = params['start_from']
            # Start from a pre-trained model:
            self.warn('Starting from %s' % start_from)
            if params['start_from_best']:
                flag = '-best'
                self.warn('Starting from the best saved model')
            else:
                flag = ''
            # load model's weights
            saved_weights = torch.load(osp.join(start_from, 'model%s.pth' % flag))
            if params['reverse']:
                # adapt the loaded weights for the reverse task
                saved_weights = self.reverse_weights(saved_weights)
            self.model.load_state_dict(saved_weights)
            # load the optimizer's last state:
            if not params['optim']['reset']:
                self.optimizer.load(
                    torch.load(osp.join(start_from, 'optimizer%s.pth' % flag)))
            history = pl(osp.join(start_from, 'trackers%s.pkl' % flag))

        self.trackers.update(history)
        self.epoch = self.trackers['epoch']
        self.iteration = self.trackers['iteration']
        # start with eval:
        # self.iteration -= 1
        return iterators_state
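
For context, a sketch of how a training driver might consume the returned iterators_state to resume mid-stream (hypothetical caller; the iterator API and its set_state method are assumptions, not shown in the excerpts):

    def resume(trainer, src_iterator, trg_iterator):
        # Restores model, optimizer and trackers, then fast-forwards
        # the data iterators to where the previous run stopped.
        state = trainer.load_checkpoint()
        if state:
            src_iterator.set_state(state['src_iterators'])  # assumed API
            trg_iterator.set_state(state['trg_iterators'])  # assumed API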